# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import logging
from typing import Dict, Any, Union
from azure.ai.ml._schema._sweep.sweep_job import SweepJobSchema
from azure.ai.ml.constants import (
BASE_PATH_CONTEXT_KEY,
TYPE,
JobType,
)
from azure.ai.ml._restclient.v2022_02_01_preview.models import (
SweepJob as RestSweepJob,
SamplingAlgorithm as RestSamplingAlgorithm,
JobBaseData,
TrialComponent,
ManagedIdentity,
UserIdentity,
AmlToken,
)
from azure.ai.ml.entities import Job, CommandJob
from ..parameterized_command import ParameterizedCommand
from azure.ai.ml.entities._component.command_component import CommandComponent
from azure.ai.ml.entities._inputs_outputs import Input, Output
from azure.ai.ml.entities._job.sweep.sampling_algorithm import SamplingAlgorithm
from azure.ai.ml.entities._util import load_from_dict
from .objective import Objective
from .parameterized_sweep import SAMPLING_ALGORITHM_CONSTRUCTOR, ParameterizedSweep
from .search_space import SweepDistribution
from .early_termination_policy import EarlyTerminationPolicy
from azure.ai.ml.entities._job._input_output_helpers import (
to_rest_dataset_literal_inputs,
from_rest_inputs_to_dataset_literal,
validate_inputs_for_command,
from_rest_data_outputs,
to_rest_data_outputs,
validate_key_contains_allowed_characters,
)
from azure.ai.ml._utils.utils import map_single_brackets_and_warn
from azure.ai.ml.entities._job.job_io_mixin import JobIOMixin
from ..job_limits import SweepJobLimits
from azure.ai.ml._ml_exceptions import ErrorTarget, JobException
# Module-level logger, named after this module via ``__name__``.
module_logger = logging.getLogger(__name__)
class SweepJob(Job, ParameterizedSweep, JobIOMixin):
    """Sweep job for hyperparameter tuning.

    :param name: Name of the job.
    :type name: str
    :param display_name: Display name of the job.
    :type display_name: str
    :param description: Description of the job.
    :type description: str
    :param tags: Tag dictionary. Tags can be added, removed, and updated.
    :type tags: dict[str, str]
    :param properties: The asset property dictionary.
    :type properties: dict[str, str]
    :param experiment_name: Name of the experiment the job will be created under,
        if None is provided, job will be created under experiment 'Default'.
    :type experiment_name: str
    :param identity: Identity that training job will use while running on compute.
    :type identity: Union[azure.ai.ml.ManagedIdentity, azure.ai.ml.AmlToken, azure.ai.ml.UserIdentity]
    :param inputs: Inputs to the command.
    :type inputs: dict
    :param outputs: Mapping of output data bindings used in the job.
    :type outputs: dict[str, azure.ai.ml.Output]
    :param sampling_algorithm: The hyperparameter sampling algorithm to use over the
        `search_space`. Defaults to "random".
    :type sampling_algorithm: str
    :param search_space: Dictionary of the hyperparameter search space. The key is the
        name of the hyperparameter and the value is the parameter expression.
    :type search_space: Dict
    :param objective: Metric to optimize for.
    :type objective: Objective
    :param compute: The compute target the job runs on.
    :type compute: str
    :param trial: The job configuration for each trial. Each trial will be provided with
        a different combination of hyperparameter values that the system samples from
        the search_space.
    :type trial: Union[azure.ai.ml.entities.CommandJob, azure.ai.ml.entities.CommandComponent]
    :param early_termination: The early termination policy to use. A trial job is canceled
        when the criteria of the specified policy are met. If omitted, no early termination
        policy will be applied.
    :type early_termination: EarlyTerminationPolicy
    :param limits: Limits for the sweep job.
    :type limits: ~azure.ai.ml.entities.SweepJobLimits
    :param kwargs: A dictionary of additional configuration parameters.
    :type kwargs: dict
    """

    def __init__(
        self,
        *,
        name: str = None,
        description: str = None,
        tags: Dict = None,
        display_name: str = None,
        experiment_name: str = None,
        identity: Union[ManagedIdentity, AmlToken, UserIdentity] = None,
        inputs: Dict[str, Union[Input, str, bool, int, float]] = None,
        outputs: Dict[str, Output] = None,
        compute: str = None,
        limits: SweepJobLimits = None,
        sampling_algorithm: Union[str, SamplingAlgorithm] = None,
        search_space: Dict[str, SweepDistribution] = None,
        objective: Objective = None,
        trial: Union[CommandJob, CommandComponent] = None,
        early_termination: EarlyTerminationPolicy = None,
        **kwargs: Any,
    ):
        # The job type is always forced to "sweep", overriding any caller-supplied value.
        kwargs[TYPE] = JobType.SWEEP

        # Both bases are initialised explicitly (not via super()) because each
        # one takes a different, non-overlapping set of keyword arguments.
        Job.__init__(
            self,
            name=name,
            description=description,
            tags=tags,
            display_name=display_name,
            experiment_name=experiment_name,
            compute=compute,
            **kwargs,
        )
        self.inputs = inputs
        self.outputs = outputs
        self.trial = trial
        self.identity = identity

        ParameterizedSweep.__init__(
            self,
            limits=limits,
            sampling_algorithm=sampling_algorithm,
            objective=objective,
            early_termination=early_termination,
            search_space=search_space,
        )

    def _to_dict(self) -> Dict:
        """Dump this job through ``SweepJobSchema`` using ``"./"`` as the base path.

        :return: The value returned by the schema's ``dump`` call.
        :rtype: Dict
        """
        return SweepJobSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)

    def _to_rest_object(self) -> JobBaseData:
        """Convert this job into its REST representation.

        Note that this method mutates ``self``: properties missing on the sweep job
        are first filled in from the trial (see
        ``_override_missing_properties_from_trial``), and ``self.trial.command`` is
        replaced by the result of ``map_single_brackets_and_warn``.

        :return: A ``JobBaseData`` wrapping the REST sweep job, with ``name`` set
            to this job's name.
        :rtype: JobBaseData
        """
        self._override_missing_properties_from_trial()
        self.trial.command = map_single_brackets_and_warn(self.trial.command)
        search_space = {param: space._to_rest_object() for (param, space) in self.search_space.items()}

        # Validation happens only after the search space has been converted, so a
        # conversion failure surfaces before any validation error.
        validate_inputs_for_command(self.trial.command, self.inputs)
        for key in search_space:
            validate_key_contains_allowed_characters(key)

        trial_component = TrialComponent(
            code_id=self.trial.code,
            environment_id=self.trial.environment,
            command=self.trial.command,
            environment_variables=self.trial.environment_variables,
            resources=self.trial.resources._to_rest_object() if self.trial.resources else None,
        )

        # early_termination, objective and identity are handed to the REST model
        # as-is; only limits, inputs, outputs and the search space are converted.
        sweep_job = RestSweepJob(
            display_name=self.display_name,
            description=self.description,
            experiment_name=self.experiment_name,
            search_space=search_space,
            sampling_algorithm=self._get_rest_sampling_algorithm(),
            limits=self.limits._to_rest_object() if self.limits else None,
            early_termination=self.early_termination,
            properties=self.properties,
            compute_id=self.compute,
            objective=self.objective,
            trial=trial_component,
            tags=self.tags,
            inputs=to_rest_dataset_literal_inputs(self.inputs),
            outputs=to_rest_data_outputs(self.outputs),
            identity=self.identity,
        )
        sweep_job_resource = JobBaseData(properties=sweep_job)
        sweep_job_resource.name = self.name
        return sweep_job_resource

    def _to_component(self, context: Dict = None, **kwargs):
        """Always raise: a sweep job has no component entity.

        :param context: Unused.
        :type context: Dict
        :raises JobException: Unconditionally, targeting ``ErrorTarget.SWEEP_JOB``.
        """
        msg = "no sweep component entity"
        raise JobException(message=msg, no_personal_data_message=msg, target=ErrorTarget.SWEEP_JOB)

    @classmethod
    def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs) -> "SweepJob":
        """Build a ``SweepJob`` from a dictionary loaded through ``SweepJobSchema``.

        :param data: Raw job definition handed to ``load_from_dict``.
        :type data: Dict
        :param context: Load context; must contain ``BASE_PATH_CONTEXT_KEY``.
        :type context: Dict
        :param additional_message: Forwarded unchanged to ``load_from_dict``.
        :type additional_message: str
        :return: The constructed sweep job.
        :rtype: SweepJob
        """
        loaded_schema = load_from_dict(SweepJobSchema, data, context, additional_message, **kwargs)
        # The loaded "trial" entry is a mapping of keyword arguments; wrap it in
        # a ParameterizedCommand before handing it to the constructor.
        loaded_schema["trial"] = ParameterizedCommand(**(loaded_schema["trial"]))
        sweep_job = SweepJob(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_schema)
        return sweep_job

    @classmethod
    def _load_from_rest(cls, obj: JobBaseData) -> "SweepJob":
        """Build a ``SweepJob`` from its REST representation.

        :param obj: REST resource whose ``properties`` hold the REST sweep job.
        :type obj: JobBaseData
        :return: The constructed sweep job.
        :rtype: SweepJob
        """
        properties: RestSweepJob = obj.properties

        # Unpack termination schema
        early_termination = EarlyTerminationPolicy._from_rest_object(properties.early_termination)

        # Unpack sampling algorithm
        sampling_algorithm = SamplingAlgorithm._from_rest_object(properties.sampling_algorithm)

        trial = ParameterizedCommand._load_from_sweep_job(obj.properties)
        # Compute appears in both layers of the YAML, but in only one layer of the REST object.
        # It should be a required field in one place, but cannot be while it is optional in two.
        return SweepJob(
            name=obj.name,
            id=obj.id,
            display_name=properties.display_name,
            description=properties.description,
            properties=properties.properties,
            tags=properties.tags,
            experiment_name=properties.experiment_name,
            services=properties.services,
            status=properties.status,
            creation_context=obj.system_data,
            trial=trial,
            compute=properties.compute_id,
            sampling_algorithm=sampling_algorithm,
            search_space={
                param: SweepDistribution._from_rest_object(dist) for (param, dist) in properties.search_space.items()
            },
            limits=SweepJobLimits._from_rest_object(properties.limits),
            early_termination=early_termination,
            objective=properties.objective,
            inputs=from_rest_inputs_to_dataset_literal(properties.inputs),
            outputs=from_rest_data_outputs(properties.outputs),
            identity=properties.identity,
        )

    def _override_missing_properties_from_trial(self):
        """Fill in sweep-level properties that are unset from a ``CommandJob`` trial.

        Does nothing unless ``self.trial`` is a ``CommandJob``. Otherwise ``compute``,
        ``inputs`` and ``outputs`` are copied from the trial when they are falsy on
        the sweep job, and the trial's ``limits.timeout`` (when set) becomes the
        sweep's ``limits.trial_timeout`` if the sweep does not already define one.
        """
        if not isinstance(self.trial, CommandJob):
            return

        if not self.compute:
            self.compute = self.trial.compute
        if not self.inputs:
            self.inputs = self.trial.inputs
        if not self.outputs:
            self.outputs = self.trial.outputs

        has_trial_limits_timeout = self.trial.limits and self.trial.limits.timeout
        if has_trial_limits_timeout and not self.limits:
            # No sweep limits at all: create them carrying only the trial timeout.
            self.limits = SweepJobLimits(trial_timeout=self.trial.limits.timeout)
        elif has_trial_limits_timeout and not self.limits.trial_timeout:
            # Sweep limits exist but lack a trial timeout: fill in just that field.
            self.limits.trial_timeout = self.trial.limits.timeout