# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import json
import os
from pathlib import Path
from marshmallow import INCLUDE, Schema
from typing import Dict, Union
from azure.ai.ml._restclient.v2022_05_01.models import (
ComponentVersionData,
ComponentVersionDetails,
)
from azure.ai.ml._schema.component.command_component import CommandComponentSchema, RestCommandComponentSchema
from azure.ai.ml.entities._job.distribution import (
DistributionConfiguration,
MpiDistribution,
TensorFlowDistribution,
PyTorchDistribution,
)
from azure.ai.ml.entities._job.resource_configuration import ResourceConfiguration
from azure.ai.ml.entities._job.parameterized_command import ParameterizedCommand
from azure.ai.ml.entities._assets import Environment
from azure.ai.ml.constants import BASE_PATH_CONTEXT_KEY, COMPONENT_TYPE, ComponentSource
from azure.ai.ml.constants import NodeType
from azure.ai.ml.entities._component.input_output import ComponentInput, ComponentOutput
from .component import Component
from .._util import validate_attribute_type
from azure.ai.ml._ml_exceptions import ValidationException, ErrorCategory, ErrorTarget
from .._validation import ValidationResult, _ValidationResultBuilder
from ..._schema import PathAwareSchema
from ..._utils.utils import get_all_data_binding_expressions, parse_args_description_from_docstring
class CommandComponent(Component, ParameterizedCommand):
    """Command component version, used to define a command component.

    :param name: Name of the component.
    :type name: str
    :param version: Version of the component.
    :type version: str
    :param description: Description of the component.
    :type description: str
    :param tags: Tag dictionary. Tags can be added, removed, and updated.
    :type tags: dict
    :param display_name: Display name of the component.
    :type display_name: str
    :param command: Command to be executed in component.
    :type command: str
    :param code: Code file or folder that will be uploaded to the cloud for component execution.
    :type code: str
    :param environment: Environment that component will run in.
    :type environment: Union[Environment, str]
    :param distribution: Distribution configuration for distributed training.
    :type distribution: Union[dict, PyTorchDistribution, MpiDistribution, TensorFlowDistribution]
    :param resources: Compute Resource configuration for the component.
    :type resources: Union[dict, ~azure.ai.ml.entities.ResourceConfiguration]
    :param inputs: Inputs of the component.
    :type inputs: dict
    :param outputs: Outputs of the component.
    :type outputs: dict
    :param instance_count: promoted property from resources.instance_count
    :type instance_count: int
    :raises ValidationException: If both ``resources`` and ``instance_count`` are provided.
    """

    def __init__(
        self,
        *,
        name: str = None,
        version: str = None,
        description: str = None,
        tags: Dict = None,
        display_name: str = None,
        command: str = None,
        code: str = None,
        environment: Union[str, Environment] = None,
        distribution: Union[PyTorchDistribution, MpiDistribution, TensorFlowDistribution] = None,
        resources: ResourceConfiguration = None,
        inputs: Dict = None,
        outputs: Dict = None,
        instance_count: int = None,  # promoted property from resources.instance_count
        **kwargs,
    ):
        # Validate that init params have a valid type. This must remain the first
        # statement: locals() is captured before any other local name is bound.
        validate_attribute_type(attrs_to_check=locals(), attr_type_map=self._attr_type_map())

        kwargs[COMPONENT_TYPE] = NodeType.COMMAND

        # Set default base path
        kwargs.setdefault("base_path", Path("."))

        # Component backend doesn't support environment_variables yet,
        # this is to support the case of CommandComponent being the trial of
        # a SweepJob, where environment_variables is stored as part of trial
        environment_variables = kwargs.pop("environment_variables", None)

        super().__init__(
            name=name,
            version=version,
            description=description,
            tags=tags,
            display_name=display_name,
            inputs=inputs,
            outputs=outputs,
            **kwargs,
        )

        # No validation on value passed here because in pipeline job, required code&environment maybe absent
        # and fill in later with job defaults.
        self.command = command
        self.code = code
        self.environment_variables = environment_variables
        self.environment = environment
        self.resources = resources
        self.distribution = distribution

        # check mutual exclusivity of promoted properties
        if self.resources is not None and instance_count is not None:
            msg = "instance_count and resources are mutually exclusive"
            raise ValidationException(
                message=msg,
                target=ErrorTarget.COMPONENT,
                no_personal_data_message=msg,
                error_category=ErrorCategory.USER_ERROR,
            )
        self.instance_count = instance_count

    @property
    def instance_count(self) -> Union[int, None]:
        """
        Return value of promoted property resources.instance_count.

        :return: Value of resources.instance_count, or None when no resources are set.
        :rtype: Optional[int]
        """
        return self.resources.instance_count if self.resources else None

    @instance_count.setter
    def instance_count(self, value: int):
        # A falsy value (None, and also 0) is ignored and leaves resources untouched.
        if not value:
            return
        if not self.resources:
            self.resources = ResourceConfiguration(instance_count=value)
        else:
            self.resources.instance_count = value

    @classmethod
    def _attr_type_map(cls) -> dict:
        """Return the mapping of attribute name to accepted type(s) used for init validation."""
        return {
            "environment": (str, Environment),
            "environment_variables": dict,
            "resources": (dict, ResourceConfiguration),
            "code": (str, os.PathLike),
        }

    def _to_dict(self) -> Dict:
        """Dump the command component content into a dictionary.

        The schema dump (without ``base_path``) is merged over ``self._other_parameter``,
        so schema-dumped keys win when both contain the same key.
        """
        # Distribution inherits from autorest generated class, use as_dict() to dump to json
        # Replace the name of $schema to schema.
        component_schema_dict = self._dump_for_validation()
        component_schema_dict.pop("base_path", None)
        return {**self._other_parameter, **component_schema_dict}

    def _get_environment_id(self) -> Union[str, None]:
        """Return the environment id, or the raw ``environment`` value when it is not an Environment."""
        # handle case when environment is defined inline
        if isinstance(self.environment, Environment):
            return self.environment.id
        return self.environment

    @classmethod
    def _create_schema_for_validation(cls, context) -> Union[PathAwareSchema, Schema]:
        """Create the schema instance used to validate this component with the given context."""
        return CommandComponentSchema(context=context)

    def _customized_validate(self):
        """Run the command-component specific validation (currently the command check)."""
        return self._validate_command()

    def _validate_command(self) -> ValidationResult:
        """Validate that every data binding expression found in ``self.command`` resolves.

        :return: A failed result listing the invalid expressions for the "command" field,
            or a success result when the command is empty or all expressions resolve.
        :rtype: ValidationResult
        """
        if self.command:
            invalid_expressions = [
                expression
                for expression in get_all_data_binding_expressions(self.command, is_singular=False)
                if not self._is_valid_data_binding_expression(expression)
            ]
            if invalid_expressions:
                error_msg = "Invalid data binding expression: {}".format(", ".join(invalid_expressions))
                return _ValidationResultBuilder.from_single_message(error_msg, "command")
        return _ValidationResultBuilder.success()

    def _is_valid_data_binding_expression(self, data_binding_expression: str) -> bool:
        """Check whether a dotted expression can be resolved starting from this component.

        Each dot-separated segment is resolved as an attribute first and, failing that,
        as an item lookup on the current object.

        :param data_binding_expression: Dotted path such as ``inputs.some_name``.
        :type data_binding_expression: str
        :return: True if every segment resolves, False otherwise.
        :rtype: bool
        """
        current_obj = self
        for item in data_binding_expression.split("."):
            if hasattr(current_obj, item):
                current_obj = getattr(current_obj, item)
            else:
                try:
                    current_obj = current_obj[item]
                except Exception:
                    # Any lookup failure means the expression does not resolve.
                    return False
        return True

    @classmethod
    def _load_from_dict(cls, data: Dict, context: Dict, **kwargs) -> "CommandComponent":
        """Build a CommandComponent from a dictionary loaded through CommandComponentSchema.

        ``yaml_str`` and ``_source`` are taken out of ``kwargs``; the remaining kwargs are
        forwarded to the schema's ``load`` call.
        """
        return CommandComponent(
            yaml_str=kwargs.pop("yaml_str", None),
            _source=kwargs.pop("_source", ComponentSource.YAML),
            **(CommandComponentSchema(context=context).load(data, unknown=INCLUDE, **kwargs)),
        )

    def _to_rest_object(self) -> ComponentVersionData:
        """Convert this component into its REST representation."""
        # Convert nested ordered dict to dict.
        # TODO: we may need to use original dict from component YAML(only change code and environment), returning
        # parsed dict might add default value for some field, eg: if we add property "optional" with default value
        # to ComponentInput, it will add field "optional" to all inputs even if user doesn't specify one
        component = json.loads(json.dumps(self._to_dict()))

        properties = ComponentVersionDetails(
            component_spec=component,
            description=self.description,
            is_anonymous=self._is_anonymous,
            properties=self.properties,
            tags=self.tags,
        )
        result = ComponentVersionData(properties=properties)
        result.name = self.name
        return result

    @classmethod
    def _load_from_rest(cls, obj: ComponentVersionData) -> "CommandComponent":
        """Build a CommandComponent from its REST representation.

        The REST object passed in is not modified.
        """
        rest_component_version = obj.properties
        # Work on a shallow copy so that popping keys below does not mutate the
        # component_spec held by the caller's REST object.
        component_spec = dict(rest_component_version.component_spec)

        # "or {}" also tolerates a key that is present with an explicit None value.
        inputs = {
            k: ComponentInput._from_rest_object(v) for k, v in (component_spec.pop("inputs", None) or {}).items()
        }
        outputs = {
            k: ComponentOutput._from_rest_object(v) for k, v in (component_spec.pop("outputs", None) or {}).items()
        }

        distribution = component_spec.pop("distribution", None)
        if distribution:
            distribution = DistributionConfiguration._from_rest_object(distribution)

        command_component = CommandComponent(
            id=obj.id,
            is_anonymous=rest_component_version.is_anonymous,
            creation_context=obj.system_data,
            inputs=inputs,
            outputs=outputs,
            distribution=distribution,
            # use different schema for component from rest since name may be "invalid"
            **RestCommandComponentSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).load(
                component_spec, unknown=INCLUDE
            ),
            _source=ComponentSource.REST,
        )
        return command_component

    @classmethod
    def _parse_args_description_from_docstring(cls, docstring):
        """Delegate to the module-level ``parse_args_description_from_docstring`` helper."""
        return parse_args_description_from_docstring(docstring)

    def __str__(self):
        """Return the ordered YAML form, falling back to the parent ``__str__`` on failure."""
        try:
            return self._ordered_yaml()
        except Exception:
            # Best-effort fallback; KeyboardInterrupt/SystemExit are deliberately
            # not caught so they keep propagating.
            return super().__str__()