# Source code for azure.ai.ml.entities._assets._artifacts.data

# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import os
import re
from typing import Dict, Optional, Type, Union
from pathlib import Path
from os import PathLike

from azure.ai.ml.entities._assets import Artifact
from .artifact import ArtifactStorageInfo
from azure.ai.ml._utils.utils import is_url, load_yaml
from azure.ai.ml.constants import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, SHORT_URI_FORMAT, AssetTypes
from azure.ai.ml._restclient.v2022_05_01.models import (
    DataContainerData,
    DataContainerDetails,
    DataType,
    DataVersionBaseData,
    DataVersionBaseDetails,
    UriFileDataVersion,
    UriFolderDataVersion,
    MLTableData,
)
from azure.ai.ml._utils._arm_id_utils import get_arm_id_object_from_id
from azure.ai.ml._schema import DataSchema
from azure.ai.ml.entities._util import load_from_dict
from azure.ai.ml._ml_exceptions import ValidationException, ErrorCategory, ErrorTarget

# Lookup table keyed by asset-type string (AssetTypes.*) whose values are the
# REST model classes for a data version of that type. Read by
# getModelForDataAssetType; the returned class is what gets instantiated when
# a Data entity is converted to its REST representation.
DataAssetTypeModelMap: Dict[str, Type[DataVersionBaseDetails]] = {
    AssetTypes.URI_FILE: UriFileDataVersion,
    AssetTypes.URI_FOLDER: UriFolderDataVersion,
    AssetTypes.MLTABLE: MLTableData,
}


def getModelForDataAssetType(data_asset_type: str) -> Type[DataVersionBaseDetails]:
    """Look up the REST version-details model class for an asset type.

    :param data_asset_type: Asset type string used as the key into
        ``DataAssetTypeModelMap``.
    :type data_asset_type: str
    :return: The model class registered for ``data_asset_type``.
    :rtype: Type[DataVersionBaseDetails]
    :raises ValidationException: If no model class is registered for
        ``data_asset_type``.
    """
    rest_model = DataAssetTypeModelMap.get(data_asset_type)
    if rest_model is not None:
        return rest_model

    # The same text is used for both the regular and the
    # no-personal-data message fields of the exception.
    error_text = "Unknown DataType {}".format(data_asset_type)
    raise ValidationException(
        message=error_text,
        no_personal_data_message=error_text,
        target=ErrorTarget.ARTIFACT,
        error_category=ErrorCategory.USER_ERROR,
    )


# Lookup table from the REST-side DataType value to the corresponding
# asset-type string (AssetTypes.*). Read by getDataAssetType when a REST
# object is converted back into a Data entity.
DataTypeMap: Dict[DataType, str] = {
    DataType.URI_FILE: AssetTypes.URI_FILE,
    DataType.URI_FOLDER: AssetTypes.URI_FOLDER,
    DataType.MLTABLE: AssetTypes.MLTABLE,
}


def getDataAssetType(data_type: DataType) -> str:
    """Translate a REST ``DataType`` value into an asset-type string.

    :param data_type: Value to look up in ``DataTypeMap``.
    :type data_type: DataType
    :return: The mapped asset-type string, or ``data_type`` itself when the
        map has no entry for it.
    :rtype: str
    """
    if data_type in DataTypeMap:
        return DataTypeMap[data_type]
    # No mapping registered: pass the incoming value through unchanged.
    return data_type


class Data(Artifact):
    """Data for training and scoring.

    :param name: Name of the resource.
    :type name: str
    :param version: Version of the resource.
    :type version: str
    :param description: Description of the resource.
    :type description: str
    :param tags: Tag dictionary. Tags can be added, removed, and updated.
    :type tags: dict[str, str]
    :param properties: The asset property dictionary.
    :type properties: dict[str, str]
    :param path: The path to the asset on the datastore. This can be local or remote
    :type path: str
    :param type: The type of the asset. Valid values are uri_file, uri_folder, mltable.
        Defaults to uri_folder.
    :type type: Literal[AssetTypes.URI_FILE, AssetTypes.URI_FOLDER, AssetTypes.MLTABLE]
    :param kwargs: A dictionary of additional configuration parameters.
    :type kwargs: dict
    """

    def __init__(
        self,
        *,
        name: Optional[str] = None,
        version: Optional[str] = None,
        description: Optional[str] = None,
        tags: Optional[Dict] = None,
        properties: Optional[Dict] = None,
        path: Optional[str] = None,  # if type is mltable, the path has to be a folder.
        type: str = AssetTypes.URI_FOLDER,  # type: ignore
        **kwargs,
    ):
        # Pull the Data-specific options out of kwargs so they are not
        # forwarded to the parent constructor.
        self._skip_validation = kwargs.pop("skip_validation", False)
        self._mltable_schema_url = kwargs.pop("mltable_schema_url", None)
        self._referenced_uris = kwargs.pop("referenced_uris", None)

        # The path setter below reads self.type, so it has to be assigned
        # before anything assigns self.path.
        self.type = type

        super().__init__(
            name=name,
            version=version,
            path=path,
            description=description,
            tags=tags,
            properties=properties,
            **kwargs,
        )
        self.path = path

    @property
    def path(self) -> Optional[Union[str, PathLike]]:
        """Return the stored path of the asset (``self._path``)."""
        return self._path

    @path.setter
    def path(self, value: str) -> None:
        """Store ``value`` through the parent setter, then coerce local folder paths.

        :param value: The new path.
        :type value: str
        """
        # Call the parent setter to resolve the path with base_path if it was a local path
        super(Data, type(self)).path.fset(self, value)

        if self.type != AssetTypes.URI_FOLDER:
            return
        stored_path = self._path
        if stored_path is None or is_url(stored_path):
            return
        # NOTE(review): os.path.join(p, "") appends a trailing separator, which
        # pathlib.Path is documented to drop again - confirm the intent here is
        # only coercion of the stored value to a Path.
        self._path = Path(os.path.join(stored_path, ""))

    @classmethod
    def _load(
        cls,
        data: Optional[Dict] = None,
        yaml_path: Optional[Union[PathLike, str]] = None,
        params_override: Optional[list] = None,
        **kwargs,
    ) -> "Data":
        """Build a ``Data`` entity from an already-parsed dictionary.

        :param data: Dictionary to load from; treated as empty when falsy.
        :param yaml_path: Location of the originating YAML file. Its parent
            directory becomes the base path in the load context; ``./`` is
            used when no path is given.
        :param params_override: Overrides placed in the load context; treated
            as an empty list when falsy.
        :param kwargs: Forwarded to ``_load_from_dict``.
        :return: The constructed entity.
        :rtype: Data
        """
        if yaml_path:
            base_path = Path(yaml_path).parent
        else:
            base_path = Path("./")

        load_context = {
            BASE_PATH_CONTEXT_KEY: base_path,
            PARAMS_OVERRIDE_KEY: params_override or [],
        }
        return Data._load_from_dict(yaml_data=data or {}, context=load_context, **kwargs)

    @classmethod
    def _load_from_dict(cls, yaml_data: Dict, context: Dict, **kwargs) -> "Data":
        """Run ``yaml_data`` through ``DataSchema`` and construct a ``Data`` from the result."""
        loaded_fields = load_from_dict(DataSchema, yaml_data, context, **kwargs)
        return Data(**loaded_fields)

    def _to_dict(self) -> Dict:
        """Dump this entity with ``DataSchema`` using ``./`` as the base path."""
        schema = DataSchema(context={BASE_PATH_CONTEXT_KEY: "./"})
        return schema.dump(self)

    def _to_container_rest_object(self) -> DataContainerData:
        """Build the REST container object for this entity.

        :raises ValidationException: Via ``getModelForDataAssetType`` when
            ``self.type`` has no registered model class.
        """
        version_model = getModelForDataAssetType(self.type)
        container_details = DataContainerDetails(
            properties=self.properties,
            tags=self.tags,
            is_archived=False,
            data_type=version_model.data_type,
        )
        return DataContainerData(properties=container_details)

    def _to_rest_object(self) -> DataVersionBaseData:
        """Build the REST version object for this entity.

        :raises ValidationException: Via ``getModelForDataAssetType`` when
            ``self.type`` has no registered model class.
        """
        version_model = getModelForDataAssetType(self.type)
        version_details = version_model(
            description=self.description,
            is_anonymous=self._is_anonymous,
            tags=self.tags,
            is_archived=False,
            properties=self.properties,
            data_uri=self.path,
        )
        # Only set referenced_uris on models whose attribute map declares it.
        declares_referenced_uris = version_model._attribute_map.get("referenced_uris") is not None
        if declares_referenced_uris:
            version_details.referenced_uris = self._referenced_uris
        return DataVersionBaseData(properties=version_details)

    @classmethod
    def _from_container_rest_object(cls, data_container_rest_object: DataContainerData) -> "Data":
        """Create a ``Data`` entity from a REST container object.

        The result carries the container's name, system data, tags,
        properties and mapped type, plus ``latest_version``.
        """
        container_details: DataContainerDetails = data_container_rest_object.properties
        entity = Data(
            name=data_container_rest_object.name,
            creation_context=data_container_rest_object.system_data,
            tags=container_details.tags,
            properties=container_details.properties,
            type=getDataAssetType(container_details.data_type),
        )
        entity.latest_version = container_details.latest_version
        return entity

    @classmethod
    def _from_rest_object(cls, data_rest_object: DataVersionBaseData) -> "Data":
        """Create a ``Data`` entity from a REST version object.

        Name and version are taken from the parsed ARM id of
        ``data_rest_object``; the remaining fields come from its properties.
        """
        version_details: DataVersionBaseDetails = data_rest_object.properties
        parsed_id = get_arm_id_object_from_id(data_rest_object.id)
        return Data(
            id=data_rest_object.id,
            name=parsed_id.asset_name,
            version=parsed_id.asset_version,
            path=version_details.data_uri,
            type=getDataAssetType(version_details.data_type),
            description=version_details.description,
            tags=version_details.tags,
            properties=version_details.properties,
            creation_context=data_rest_object.system_data,
            is_anonymous=version_details.is_anonymous,
            # Not every version model has this attribute, hence getattr with a default.
            referenced_uris=getattr(version_details, "referenced_uris", None),
        )

    def _update_path(self, asset_artifact: ArtifactStorageInfo) -> None:
        """Point ``self.path`` at the location described by ``asset_artifact``.

        The datastore name is whatever follows ``datastores/`` in
        ``asset_artifact.datastore_arm_id``. When that pattern is not found
        the path is left untouched.
        """
        match = re.search(r"datastores\/(.+)", asset_artifact.datastore_arm_id)
        if match is None:
            return
        self.path = SHORT_URI_FORMAT.format(match.group(1), asset_artifact.relative_path)