# Source code for azure.ai.ml.operations._model_operations

# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import logging
from typing import Iterable, Union, Dict

from azure.ai.ml._artifacts._artifact_utilities import (
    _check_and_upload_path,
    _update_metadata,
    _get_default_datastore_info,
)
from azure.ai.ml._artifacts._constants import (
    ASSET_PATH_ERROR,
    CHANGED_ASSET_PATH_MSG,
    CHANGED_ASSET_PATH_MSG_NO_PERSONAL_DATA,
)
from azure.ai.ml.operations._datastore_operations import DatastoreOperations
from azure.ai.ml._restclient.v2022_05_01 import AzureMachineLearningWorkspaces as ServiceClient052022
from azure.ai.ml._restclient.v2022_02_01_preview.models import ModelVersionData, ListViewType
from azure.ai.ml._utils._registry_utils import get_sas_uri_for_registry_asset, get_asset_body_for_registry_storage

from azure.ai.ml._utils._asset_utils import _create_or_update_autoincrement
from azure.ai.ml._scope_dependent_operations import OperationScope, _ScopeDependentOperations
from azure.ai.ml.entities._assets import Model
from azure.ai.ml.entities._datastore.credentials import AccountKeyCredentials
from os import path, PathLike, getcwd
from azure.ai.ml._utils._storage_utils import get_storage_client, get_ds_name_and_path_prefix
from azure.ai.ml._utils._asset_utils import _get_latest, _resolve_label_to_asset, _archive_or_restore
from azure.ai.ml._utils.utils import resolve_short_datastore_url, validate_ml_flow_folder
from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import (
    AzureMachineLearningWorkspaces as ServiceClient102021Dataplane,
)

from azure.ai.ml._telemetry import AML_INTERNAL_LOGGER_NAMESPACE, ActivityType, monitor_with_activity
from azure.ai.ml._ml_exceptions import ErrorCategory, AssetPathException, ErrorTarget, ValidationException

# Plain module logger, used below for debug/info progress messages.
module_logger = logging.getLogger(__name__)

# Logger under the internal AML namespace; passed to monitor_with_activity and
# given the optional "app_insights_handler" in ModelOperations.__init__.
# Propagation is switched off so its records are not handed to the handlers of
# ancestor loggers.
logger = logging.getLogger(AML_INTERNAL_LOGGER_NAMESPACE + __name__)
logger.propagate = False


class ModelOperations(_ScopeDependentOperations):
    """
    ModelOperations

    You should not instantiate this class directly. Instead, you should create an MLClient instance
    that instantiates it for you and attaches it as an attribute.
    """

    def __init__(
        self,
        operation_scope: OperationScope,
        service_client: Union[ServiceClient052022, ServiceClient102021Dataplane],
        datastore_operations: DatastoreOperations,
        **kwargs: Dict,
    ):
        """Store the REST operation groups and helpers used by the model operations.

        :param operation_scope: Scope forwarded to the base class and used to resolve datastore URLs.
        :type operation_scope: OperationScope
        :param service_client: REST client exposing ``model_versions`` and ``model_containers``.
        :type service_client: Union[ServiceClient052022, ServiceClient102021Dataplane]
        :param datastore_operations: Datastore operations used when uploading and downloading model files.
        :type datastore_operations: DatastoreOperations
        :keyword app_insights_handler: Optional logging handler; when present it is removed from
            ``kwargs`` and attached to the internal telemetry logger.
        """
        super().__init__(operation_scope)
        if "app_insights_handler" in kwargs:
            logger.addHandler(kwargs.pop("app_insights_handler"))
        self._model_versions_operation = service_client.model_versions
        self._model_container_operation = service_client.model_containers
        self._service_client = service_client
        self._datastore_operation = datastore_operations

        # Maps a label to a function which given an asset name,
        # returns the asset associated with the label
        self._managed_label_resolver = {"latest": self._get_latest_version}

    @monitor_with_activity(logger, "Model.CreateOrUpdate", ActivityType.PUBLICAPI)
    def create_or_update(self, model: Model) -> Model:
        """Upload the model's files if needed and create or update the model version.

        :param model: Model asset to create or update.
        :type model: Model
        :return: The model built from the service response. When a registry is targeted and no
            SAS URI is returned for the asset, the existing model is fetched and returned instead.
        :rtype: Model
        :raises AssetPathException: If the service error text equals ``ASSET_PATH_ERROR``, which the
            service raises when an existing asset's path would be changed.
        """
        # TODO: Are we going to implement job_name?
        # Captured before the upload step, which rebinds ``model``.
        name = model.name
        version = model.version

        sas_uri = None
        if self._registry_name:
            sas_uri = get_sas_uri_for_registry_asset(
                service_client=self._service_client,
                name=model.name,
                version=model.version,
                resource_group=self._resource_group_name,
                registry=self._registry_name,
                body=get_asset_body_for_registry_storage(self._registry_name, "models", model.name, model.version),
            )
            if not sas_uri:
                # No SAS URI for the registry asset: return the existing asset instead of uploading.
                module_logger.debug(
                    "Getting the existing asset name: %s, version: %s", model.name, model.version
                )
                return self.get(name=model.name, version=model.version)

        model, indicator_file = _check_and_upload_path(artifact=model, asset_operations=self, sas_uri=sas_uri)

        model.path = resolve_short_datastore_url(model.path, self._operation_scope)
        validate_ml_flow_folder(model.path, model.type)
        model_version_resource = model._to_rest_object()
        auto_increment_version = model._auto_increment_version

        try:
            if auto_increment_version:
                result = _create_or_update_autoincrement(
                    name=model.name,
                    body=model_version_resource,
                    version_operation=self._model_versions_operation,
                    container_operation=self._model_container_operation,
                    workspace_name=self._workspace_name,
                    **self._scope_kwargs,
                )
            elif self._registry_name:
                # Registry writes are long-running operations; block until completion.
                result = self._model_versions_operation.begin_create_or_update(
                    name=name,
                    version=version,
                    body=model_version_resource,
                    registry_name=self._registry_name,
                    **self._scope_kwargs,
                ).result()
            else:
                result = self._model_versions_operation.create_or_update(
                    name=name,
                    version=version,
                    body=model_version_resource,
                    workspace_name=self._workspace_name,
                    **self._scope_kwargs,
                )

            # The registry call may finish without a payload; fetch the created version explicitly.
            if not result and self._registry_name:
                result = self._get(name=model.name, version=model.version)
        except Exception as e:
            # service side raises an exception if we attempt to update an existing asset's path
            if str(e) == ASSET_PATH_ERROR:
                raise AssetPathException(
                    message=CHANGED_ASSET_PATH_MSG,
                    target=ErrorTarget.MODEL,
                    no_personal_data_message=CHANGED_ASSET_PATH_MSG_NO_PERSONAL_DATA,
                    error_category=ErrorCategory.USER_ERROR,
                ) from e
            raise

        model = Model._from_rest_object(result)
        if auto_increment_version and indicator_file:
            datastore_info = _get_default_datastore_info(self._datastore_operation)
            _update_metadata(model.name, model.version, indicator_file, datastore_info)  # update version in storage
        return model

    def _get(self, name: str, version: str = None) -> ModelVersionData:  # name:latest
        """Fetch the raw REST resource for a model.

        With a truthy ``version`` the model-versions operation is queried; otherwise the
        model-containers operation is queried by name only. The call targets the registry when
        ``self._registry_name`` is set and the workspace otherwise.

        :param name: Name of the model.
        :type name: str
        :param version: Version of the model; when falsy the container is returned.
        :type version: str
        :return: The REST object returned by the service.
        """
        if version:
            if self._registry_name:
                return self._model_versions_operation.get(
                    name=name, version=version, registry_name=self._registry_name, **self._scope_kwargs
                )
            return self._model_versions_operation.get(
                name=name, version=version, workspace_name=self._workspace_name, **self._scope_kwargs
            )

        if self._registry_name:
            return self._model_container_operation.get(
                name=name, registry_name=self._registry_name, **self._scope_kwargs
            )
        return self._model_container_operation.get(
            name=name, workspace_name=self._workspace_name, **self._scope_kwargs
        )

    @monitor_with_activity(logger, "Model.Get", ActivityType.PUBLICAPI)
    def get(self, name: str, version: str = None, label: str = None) -> Model:
        """Returns information about the specified model asset.

        :param name: Name of the model.
        :type name: str
        :param version: Version of the model.
        :type version: str
        :param label: Label of the model. (mutually exclusive with version)
        :type label: str
        :return: The requested model.
        :rtype: Model
        :raises ValidationException: If both ``version`` and ``label`` are given, or neither is.
        """
        if version and label:
            msg = "Cannot specify both version and label."
            raise ValidationException(
                message=msg,
                target=ErrorTarget.MODEL,
                no_personal_data_message=msg,
                error_category=ErrorCategory.USER_ERROR,
            )

        if label:
            return _resolve_label_to_asset(self, name, label)

        if not version:
            msg = "Must provide either version or label"
            raise ValidationException(
                message=msg,
                target=ErrorTarget.MODEL,
                no_personal_data_message=msg,
                error_category=ErrorCategory.USER_ERROR,
            )
        # TODO: We should consider adding an exception trigger for internal_model=None
        model_version_resource = self._get(name, version)

        return Model._from_rest_object(model_version_resource)

    @monitor_with_activity(logger, "Model.Download", ActivityType.PUBLICAPI)
    def download(self, name: str, version: str, download_path: Union[PathLike, str, None] = None) -> None:
        """Download files related to a model.

        :param str name: Name of the model.
        :param str version: Version of the model.
        :param Union[PathLike, str] download_path: Local path as download destination, defaults to current working
            directory of the current user. Contents will be overwritten.
        :raise: ResourceNotFoundError if can't find a model matching provided name.
        """
        # Resolve the default at call time. A ``getcwd()`` default in the signature would be
        # evaluated once at import and go stale if the process changes directory afterwards.
        if download_path is None:
            download_path = getcwd()

        model_uri = self.get(name=name, version=version).path
        ds_name, path_prefix = get_ds_name_and_path_prefix(model_uri)
        ds = self._datastore_operation.get(ds_name, include_secrets=True)
        acc_name = ds.account_name

        if isinstance(ds.credentials, AccountKeyCredentials):
            credential = ds.credentials.account_key
        else:
            try:
                credential = ds.credentials.sas_token
            except AttributeError:
                # The datastore credentials carry no SAS token: fall back to the credential held
                # by the datastore operations. Any other error propagates unchanged.
                credential = self._datastore_operation._credential

        container = ds.container_name
        datastore_type = ds.type

        storage_client = get_storage_client(
            credential=credential, container_name=container, storage_account=acc_name, storage_type=datastore_type
        )

        # Files land under <download_path>/<name>.
        path_file = path.join(download_path, name)
        is_directory = storage_client.exists(f"{path_prefix.rstrip('/')}/")
        if is_directory:
            # Keep the remote folder's own name as the last local path component.
            path_file = path.join(path_file, path.basename(path_prefix.rstrip("/")))
        module_logger.info("Downloading the model %s at %s\n", path_prefix, path_file)
        storage_client.download(starts_with=path_prefix, destination=path_file)

    @monitor_with_activity(logger, "Model.Archive", ActivityType.PUBLICAPI)
    def archive(self, name: str, version: str = None, label: str = None) -> None:
        """Archive a model asset.

        :param name: Name of model asset.
        :type name: str
        :param version: Version of model asset.
        :type version: str
        :param label: Label of the model asset. (mutually exclusive with version)
        :type label: str
        """
        _archive_or_restore(
            asset_operations=self,
            version_operation=self._model_versions_operation,
            container_operation=self._model_container_operation,
            is_archived=True,
            name=name,
            version=version,
            label=label,
        )

    @monitor_with_activity(logger, "Model.Restore", ActivityType.PUBLICAPI)
    def restore(self, name: str, version: str = None, label: str = None) -> None:
        """Restore an archived model asset.

        :param name: Name of model asset.
        :type name: str
        :param version: Version of model asset.
        :type version: str
        :param label: Label of the model asset. (mutually exclusive with version)
        :type label: str
        """
        _archive_or_restore(
            asset_operations=self,
            version_operation=self._model_versions_operation,
            container_operation=self._model_container_operation,
            is_archived=False,
            name=name,
            version=version,
            label=label,
        )

    @monitor_with_activity(logger, "Model.List", ActivityType.PUBLICAPI)
    def list(self, name: str = None, *, list_view_type: ListViewType = ListViewType.ACTIVE_ONLY) -> Iterable[Model]:
        """List all model assets in workspace.

        When ``name`` is given the versions of that model are listed; otherwise the model
        containers are listed.

        :param name: Name of the model.
        :type name: Optional[str]
        :param list_view_type: View type for including/excluding (for example) archived models.
            Default: ACTIVE_ONLY. Not forwarded when listing versions of a named model in a registry.
        :type list_view_type: Optional[ListViewType]
        :return: An iterator like instance of Model objects
        :rtype: ~azure.core.paging.ItemPaged[Model]
        """
        if name:
            if self._registry_name:
                return self._model_versions_operation.list(
                    name=name,
                    registry_name=self._registry_name,
                    cls=lambda objs: [Model._from_rest_object(obj) for obj in objs],
                    **self._scope_kwargs,
                )
            return self._model_versions_operation.list(
                name=name,
                workspace_name=self._workspace_name,
                cls=lambda objs: [Model._from_rest_object(obj) for obj in objs],
                list_view_type=list_view_type,
                **self._scope_kwargs,
            )

        if self._registry_name:
            return self._model_container_operation.list(
                registry_name=self._registry_name,
                cls=lambda objs: [Model._from_container_rest_object(obj) for obj in objs],
                list_view_type=list_view_type,
                **self._scope_kwargs,
            )
        return self._model_container_operation.list(
            workspace_name=self._workspace_name,
            cls=lambda objs: [Model._from_container_rest_object(obj) for obj in objs],
            list_view_type=list_view_type,
            **self._scope_kwargs,
        )

    def _get_latest_version(self, name: str) -> Model:
        """Returns the latest version of the asset with the given name.

        Latest is defined as the most recently created, not the most recently updated.
        """
        result = _get_latest(name, self._model_versions_operation, self._resource_group_name, self._workspace_name)
        return Model._from_rest_object(result)