# Source code for azure.ai.ml.operations._online_endpoint_operations

# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import json
import logging
import time
from typing import Dict, Iterable, Union, Any

from azure.core.polling import LROPoller
from azure.identity import ChainedTokenCredential

from azure.ai.ml._restclient.v2022_02_01_preview import (
    AzureMachineLearningWorkspaces as ServiceClient022022Preview,
)

from azure.ai.ml._restclient.v2022_02_01_preview.models import (
    EndpointAuthKeys,
    EndpointAuthToken,
    OnlineEndpointTrackedResourceArmPaginatedResult,
    RegenerateEndpointKeysRequest,
    KeyType,
)
from ._operation_orchestrator import OperationOrchestrator

from azure.ai.ml.entities._assets import Data
from azure.ai.ml._utils._endpoint_utils import polling_wait, post_and_validate_response
from azure.ai.ml._scope_dependent_operations import OperationsContainer, OperationScope, _ScopeDependentOperations
from azure.ai.ml.operations._local_endpoint_helper import _LocalEndpointHelper
from azure.ai.ml.constants import (
    KEY,
    AzureMLResourceType,
    EndpointInvokeFields,
    LROConfigurations,
    EndpointKeyType,
)
from azure.ai.ml.entities import OnlineEndpoint, OnlineDeployment
from azure.ai.ml._utils._azureml_polling import AzureMLPolling

from azure.ai.ml._telemetry import AML_INTERNAL_LOGGER_NAMESPACE, ActivityType, monitor_with_activity
from azure.ai.ml._ml_exceptions import ValidationException, ErrorCategory, ErrorTarget

# Module-level logger used for user-facing progress and status messages.
module_logger = logging.getLogger(__name__)

# Internal telemetry logger (passed to monitor_with_activity). It lives under the
# AML internal logger namespace and does not forward records to ancestor loggers.
logger = logging.getLogger(AML_INTERNAL_LOGGER_NAMESPACE + __name__)
logger.propagate = False


def _strip_zeroes_from_traffic(traffic: Dict[str, str]) -> Dict[str, str]:
    return {k: v for k, v in traffic.items() if v and int(v) != 0}


class OnlineEndpointOperations(_ScopeDependentOperations):
    """OnlineEndpointOperations

    You should not instantiate this class directly. Instead, you should create an MLClient instance that
    instantiates it for you and attaches it as an attribute.
    """

    def __init__(
        self,
        operation_scope: OperationScope,
        service_client_02_2022_preview: ServiceClient022022Preview,
        all_operations: OperationsContainer,
        local_endpoint_helper: _LocalEndpointHelper,
        credentials: ChainedTokenCredential = None,
        **kwargs: Dict,
    ):
        super().__init__(operation_scope)
        # The telemetry handler is popped so it is not forwarded with the REST-call kwargs below.
        if "app_insights_handler" in kwargs:
            logger.addHandler(kwargs.pop("app_insights_handler"))
        self._online_operation = service_client_02_2022_preview.online_endpoints
        self._online_deployment_operation = service_client_02_2022_preview.online_deployments
        self._all_operations = all_operations
        self._local_endpoint_helper = local_endpoint_helper
        self._credentials = credentials
        # Remaining kwargs are forwarded to every service call made by this class.
        self._init_kwargs = kwargs

    @monitor_with_activity(logger, "OnlineEndpoint.List", ActivityType.PUBLICAPI)
    def list(self, local: bool = False) -> Iterable[OnlineEndpoint]:
        """List endpoints of the workspace.

        :param local: a flag to indicate whether to interact with endpoints in local Docker environment.
            Default: False.
        :type local: bool
        :return: a list of endpoints
        :rtype: Iterable[OnlineEndpoint]
        """
        if local:
            return self._local_endpoint_helper.list()
        return self._online_operation.list(
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            cls=lambda objs: [OnlineEndpoint._from_rest_object(obj) for obj in objs],
            **self._init_kwargs,
        )

    @monitor_with_activity(logger, "OnlineEndpoint.ListKeys", ActivityType.PUBLICAPI)
    def list_keys(self, name: str) -> Union[EndpointAuthKeys, EndpointAuthToken]:
        """List the keys.

        :param name: the endpoint name
        :type name: str
        :raise: Exception if cannot get online credentials
        :return: depending on the auth mode in the endpoint, returns either keys or token
        :rtype: Union[EndpointAuthKeys, EndpointAuthToken]
        """
        return self._get_online_credentials(name=name)

    @monitor_with_activity(logger, "OnlineEndpoint.Get", ActivityType.PUBLICAPI)
    def get(
        self,
        name: str,
        local: bool = False,
    ) -> OnlineEndpoint:
        """Get an Endpoint resource.

        For remote endpoints, every deployment that has neither a traffic nor a mirror-traffic entry is
        added to the returned endpoint's traffic map with a value of 0.

        :param name: Name of the endpoint.
        :type name: str
        :param local: a flag to indicate whether to interact with endpoints in local Docker environment.
            Default: False.
        :type local: bool
        :return: Endpoint object retrieved from the service.
        :rtype: OnlineEndpoint
        """
        if local:
            return self._local_endpoint_helper.get(endpoint_name=name)

        # first get the endpoint
        endpoint = self._online_operation.get(
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            endpoint_name=name,
            **self._init_kwargs,
        )

        deployments_list = self._online_deployment_operation.list(
            endpoint_name=name,
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            cls=lambda objs: [OnlineDeployment._from_rest_object(obj) for obj in objs],
            **self._init_kwargs,
        )

        # populate deployments without traffic with zeroes in traffic map
        converted_endpoint = OnlineEndpoint._from_rest_object(endpoint)
        if deployments_list:
            for deployment in deployments_list:
                if not converted_endpoint.traffic.get(deployment.name) and not converted_endpoint.mirror_traffic.get(
                    deployment.name
                ):
                    converted_endpoint.traffic[deployment.name] = 0
        return converted_endpoint

    @monitor_with_activity(logger, "OnlineEndpoint.BeginDelete", ActivityType.PUBLICAPI)
    def begin_delete(self, name: str = None, local: bool = False, **kwargs: Any) -> Union[LROPoller, None]:
        """Delete an Online Endpoint.

        :param name: Name of the endpoint.
        :type name: str
        :param local: Whether to interact with the endpoint in local Docker environment. Defaults to False.
        :type local: bool
        :keyword no_wait: If True, return the poller immediately instead of blocking until deletion completes.
        :return: For local endpoints, whatever the local helper's ``delete`` returns. For remote endpoints,
            the poller when ``no_wait`` is True; otherwise this method waits for the deletion and returns None.
        :rtype: Optional[LROPoller]
        """
        if local:
            return self._local_endpoint_helper.delete(name=name)

        start_time = time.time()
        path_format_arguments = {
            "endpointName": name,
            "resourceGroupName": self._resource_group_name,
            "workspaceName": self._workspace_name,
        }

        no_wait = kwargs.get("no_wait", False)
        delete_poller = self._online_operation.begin_delete(
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            endpoint_name=name,
            # Polling is disabled entirely when the caller does not want to wait.
            polling=AzureMLPolling(
                LROConfigurations.POLL_INTERVAL,
                path_format_arguments=path_format_arguments,
                **self._init_kwargs,
            )
            if not no_wait
            else False,
            polling_interval=LROConfigurations.POLL_INTERVAL,
            **self._init_kwargs,
        )
        if no_wait:
            module_logger.info(
                f"Delete request initiated. Status can be checked using `az ml online-endpoint show -n {name}`\n"
            )
            return delete_poller

        message = f"Deleting endpoint {name} \n"
        module_logger.warning(
            f" Delete request initiated. If you interrupt this command or it times out while waiting for deletion to complete, status can be checked using `az ml online-endpoint show -n {name}`\n"
        )
        polling_wait(poller=delete_poller, start_time=start_time, message=message, timeout=None)
        return None

    @monitor_with_activity(logger, "OnlineEndpoint.BeginCreateOrUpdate", ActivityType.PUBLICAPI)
    def begin_create_or_update(
        self, endpoint: OnlineEndpoint, local: bool = False, **kwargs: Any
    ) -> Union[LROPoller, OnlineEndpoint]:
        """Create or update an endpoint.

        Note: zero-valued entries are stripped from ``endpoint.traffic`` and ``endpoint.mirror_traffic``
        in place before the request is sent.

        :param endpoint: The endpoint entity.
        :type endpoint: OnlineEndpoint
        :param local: Whether to interact with the endpoint in local Docker environment. Defaults to False.
        :type local: bool
        :keyword no_wait: If True, return the poller immediately instead of waiting for the result.
        :return: For local endpoints, whatever the local helper's ``create_or_update`` returns. For remote
            endpoints, the poller when ``no_wait`` is True; otherwise the resulting OnlineEndpoint.
        :rtype: Union[LROPoller, OnlineEndpoint]
        """
        if local:
            return self._local_endpoint_helper.create_or_update(endpoint=endpoint)

        no_wait = kwargs.get("no_wait", False)
        location = self._get_workspace_location()
        if endpoint.traffic:
            endpoint.traffic = _strip_zeroes_from_traffic(endpoint.traffic)
        if endpoint.mirror_traffic:
            endpoint.mirror_traffic = _strip_zeroes_from_traffic(endpoint.mirror_traffic)

        endpoint_resource = endpoint._to_rest_online_endpoint(location=location)
        orchestrators = OperationOrchestrator(
            operation_container=self._all_operations, operation_scope=self._operation_scope
        )
        # Resolve the compute reference (if the REST properties carry one) to a full ARM id.
        if hasattr(endpoint_resource.properties, "compute"):
            endpoint_resource.properties.compute = orchestrators.get_asset_arm_id(
                endpoint_resource.properties.compute, azureml_type=AzureMLResourceType.COMPUTE
            )
        poller = self._online_operation.begin_create_or_update(
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            endpoint_name=endpoint.name,
            body=endpoint_resource,
            polling=not no_wait,
            **self._init_kwargs,
        )
        if no_wait:
            module_logger.info(
                f"Endpoint Create/Update request initiated. Status can be checked using `az ml online-endpoint show -n {endpoint.name}`\n"
            )
            return poller
        return OnlineEndpoint._from_rest_object(poller.result())

    @monitor_with_activity(logger, "OnlineEndpoint.BeginGenerateKeys", ActivityType.PUBLICAPI)
    def begin_regenerate_keys(
        self, name: str, key_type: str = EndpointKeyType.PRIMARY_KEY_TYPE, **kwargs: Any
    ) -> LROPoller:
        """Regenerate keys for endpoint.

        :param name: The endpoint name.
        :type name: str
        :param key_type: One of "primary", "secondary". Defaults to "primary".
        :type key_type: str
        :keyword no_wait: If True, return the poller immediately instead of waiting for completion.
        :return: The poller when ``no_wait`` is True; otherwise the result of waiting on it.
        :rtype: LROPoller
        :raises ValidationException: if the endpoint does not use key authentication, or ``key_type`` is
            neither "primary" nor "secondary".
        """
        endpoint = self._online_operation.get(
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            endpoint_name=name,
            **self._init_kwargs,
        )
        no_wait = kwargs.get("no_wait", False)
        # Same comparison as _get_online_credentials, so both agree on what "key auth" means.
        if endpoint.properties.auth_mode.lower() == KEY:
            return self._regenerate_online_keys(name=name, key_type=key_type, no_wait=no_wait)
        raise ValidationException(
            message=f"Endpoint '{name}' does not use keys for authentication.",
            target=ErrorTarget.ONLINE_ENDPOINT,
            no_personal_data_message="Endpoint does not use keys for authentication.",
            error_category=ErrorCategory.USER_ERROR,
        )

    @monitor_with_activity(logger, "OnlineEndpoint.Invoke", ActivityType.PUBLICAPI)
    def invoke(
        self,
        endpoint_name: str,
        request_file: str = None,
        deployment_name: str = None,
        input_data: Union[str, Data] = None,
        params_override=None,
        local: bool = False,
        **kwargs,
    ) -> str:
        """Invokes the endpoint with the provided payload.

        :param endpoint_name: the endpoint name
        :type endpoint_name: str
        :param request_file: File containing the JSON request payload. This implementation always opens and
            parses this file, so it must be provided.
        :type request_file: str
        :param deployment_name: Name of a specific deployment to invoke. This is optional. By default requests
            are routed to any of the deployments according to the traffic rules.
        :type deployment_name: str
        :param input_data: Not used by this implementation.
        :type input_data: Union[str, Data]
        :param params_override: Not used by this implementation.
        :param local: a flag to indicate whether to interact with endpoints in local Docker environment.
            Default: False.
        :type local: bool
        :return: Prediction output for online endpoint.
        :rtype: str
        :raises ValidationException: if ``deployment_name`` is given for a remote endpoint and no matching
            deployment exists on it.
        """
        params_override = params_override or []
        with open(request_file, "rb") as f:
            data = json.loads(f.read())
        if local:
            return self._local_endpoint_helper.invoke(
                endpoint_name=endpoint_name, data=data, deployment_name=deployment_name
            )

        # Until this bug is resolved https://msdata.visualstudio.com/Vienna/_workitems/edit/1446538
        # The check lists deployments from the remote service, so it is only applied to remote endpoints.
        if deployment_name:
            self._validate_deployment_name(endpoint_name, deployment_name)

        endpoint = self._online_operation.get(
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            endpoint_name=endpoint_name,
            **self._init_kwargs,
        )
        keys = self._get_online_credentials(name=endpoint_name, auth_mode=endpoint.properties.auth_mode)
        if isinstance(keys, EndpointAuthKeys):
            key = keys.primary_key
        elif isinstance(keys, EndpointAuthToken):
            key = keys.access_token
        else:
            key = ""

        # Copy the shared default headers: mutating the module-level constant would leak the
        # credential and deployment header into every later call.
        headers = dict(EndpointInvokeFields.DEFAULT_HEADER)
        if key:
            headers[EndpointInvokeFields.AUTHORIZATION] = f"Bearer {key}"
        if deployment_name:
            headers[EndpointInvokeFields.MODEL_DEPLOYMENT] = deployment_name

        response = post_and_validate_response(endpoint.properties.scoring_uri, json=data, headers=headers, **kwargs)
        return response.text

    def _get_workspace_location(self) -> str:
        """Return the location of the workspace this client is scoped to."""
        return self._all_operations.all_operations[AzureMLResourceType.WORKSPACE].get(self._workspace_name).location

    def _get_online_credentials(self, name: str, auth_mode: str = None) -> Union[EndpointAuthKeys, EndpointAuthToken]:
        """Fetch the endpoint's keys or token, depending on its auth mode.

        :param name: The endpoint name.
        :param auth_mode: The endpoint's auth mode; when omitted, the endpoint is fetched to read it.
        :return: Keys when the auth mode is key-based, otherwise a token.
        """
        if not auth_mode:
            endpoint = self._online_operation.get(
                resource_group_name=self._resource_group_name,
                workspace_name=self._workspace_name,
                endpoint_name=name,
                **self._init_kwargs,
            )
            auth_mode = endpoint.properties.auth_mode

        if auth_mode.lower() == KEY:
            return self._online_operation.list_keys(
                resource_group_name=self._resource_group_name,
                workspace_name=self._workspace_name,
                endpoint_name=name,
                **self._init_kwargs,
            )
        return self._online_operation.get_token(
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            endpoint_name=name,
            **self._init_kwargs,
        )

    def _regenerate_online_keys(
        self, name: str, key_type: str = EndpointKeyType.PRIMARY_KEY_TYPE, no_wait: bool = False
    ) -> Union[LROPoller[None], None]:
        """Regenerate the primary or secondary key of an endpoint.

        :param name: The endpoint name.
        :param key_type: "primary" or "secondary" (case-insensitive).
        :param no_wait: If True, return the poller without waiting.
        :return: The poller when ``no_wait`` is True; otherwise the result of ``polling_wait``.
        :raises ValidationException: if ``key_type`` is neither "primary" nor "secondary".
        """
        normalized_key_type = key_type.lower()
        # Validate before calling the service so an invalid key type costs no network round-trip.
        if normalized_key_type not in (EndpointKeyType.PRIMARY_KEY_TYPE, EndpointKeyType.SECONDARY_KEY_TYPE):
            msg = "Key type must be 'primary' or 'secondary'."
            raise ValidationException(
                message=msg,
                target=ErrorTarget.ONLINE_ENDPOINT,
                no_personal_data_message=msg,
                error_category=ErrorCategory.USER_ERROR,
            )

        keys = self._online_operation.list_keys(
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            endpoint_name=name,
            **self._init_kwargs,
        )
        if normalized_key_type == EndpointKeyType.PRIMARY_KEY_TYPE:
            key_request = RegenerateEndpointKeysRequest(key_type=KeyType.Primary, key_value=keys.primary_key)
        else:
            key_request = RegenerateEndpointKeysRequest(key_type=KeyType.Secondary, key_value=keys.secondary_key)

        poller = self._online_operation.begin_regenerate_keys(
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            endpoint_name=name,
            body=key_request,
            **self._init_kwargs,
        )
        if no_wait:
            return poller
        return polling_wait(poller=poller, message="regenerate key")

    def _validate_deployment_name(self, endpoint_name: str, deployment_name: str) -> None:
        """Raise if ``deployment_name`` is not one of the remote endpoint's deployments.

        :raises ValidationException: if the endpoint has no deployments or none is named ``deployment_name``.
        """
        deployments_list = self._online_deployment_operation.list(
            endpoint_name=endpoint_name,
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            cls=lambda objs: [obj.name for obj in objs],
            **self._init_kwargs,
        )
        if not deployments_list:
            msg = "No deployment exists for this endpoint"
            raise ValidationException(
                message=msg,
                target=ErrorTarget.ONLINE_ENDPOINT,
                no_personal_data_message=msg,
                error_category=ErrorCategory.USER_ERROR,
            )
        if deployment_name not in deployments_list:
            raise ValidationException(
                message=f"Deployment name {deployment_name} not found for this endpoint",
                target=ErrorTarget.ONLINE_ENDPOINT,
                no_personal_data_message="Deployment name not found for this endpoint",
                error_category=ErrorCategory.USER_ERROR,
            )