Source code for azure.ai.ml.entities._job.distribution

# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------


from typing import Optional, Union, Dict
from azure.ai.ml._restclient.v2022_02_01_preview.models import (
    Mpi as RestMpi,
    PyTorch as RestPyTorch,
    TensorFlow as RestTensorFlow,
    DistributionConfiguration as RestDistributionConfiguration,
    DistributionType as RestDistributionType,
)
from azure.ai.ml.entities._util import SnakeToPascalDescriptor
from azure.ai.ml.entities._mixins import RestTranslatableMixin
from azure.ai.ml.constants import DistributionType

# Maps each SDK-side DistributionType constant to the matching REST-client
# DistributionType enum member. Used by the `type` descriptor transformer on
# DistributionConfiguration, which looks values up with `.get(x, None)`.
SDK_TO_REST = {
    DistributionType.MPI: RestDistributionType.MPI,
    DistributionType.TENSORFLOW: RestDistributionType.TENSOR_FLOW,
    DistributionType.PYTORCH: RestDistributionType.PY_TORCH,
}


class DistributionConfiguration(RestTranslatableMixin):
    """Common base for the SDK distribution configuration classes.

    The concrete classes registered in ``DISTRIBUTION_TYPE_MAP`` are selected
    by distribution type in :meth:`_from_rest_object`.
    """

    # Descriptor bound to the "distribution_type" attribute. The transformer
    # looks the value up in SDK_TO_REST (None when absent); the reverse
    # transformer lower-cases the value.
    # NOTE(review): which direction is applied on get vs. set is defined by
    # SnakeToPascalDescriptor - confirm there.
    type = SnakeToPascalDescriptor(
        "distribution_type", transformer=lambda x: SDK_TO_REST.get(x, None), reverse_transformer=lambda x: x.lower()
    )

    def __init__(self) -> None:
        pass

    @classmethod
    def _from_rest_object(
        cls, obj: Optional[Union[RestDistributionConfiguration, Dict]]
    ) -> Optional["DistributionConfiguration"]:
        """Build the matching SDK distribution object from a REST payload.

        This works for the distribution property of both a Job object and a
        Component object:

        * Distribution of a Job, when returned by MFE, is a
          ``RestDistributionConfiguration``; its instance attributes are used
          and the type is read from the ``"distribution_type"`` key.
        * Distribution of a Component, when returned by MFE, is a ``Dict``,
          e.g. ``{'type': 'Mpi', 'process_count_per_instance': '1'}``; the type
          is read from the ``"type"`` key.

        The input is never modified: the keys are extracted from a shallow
        copy, so the caller's dict / REST object keeps its type information.

        :param obj: REST distribution object, plain dict, or ``None``.
        :return: An instance of the class registered in
            ``DISTRIBUTION_TYPE_MAP`` for the lower-cased type, or ``None``
            when ``obj`` is ``None``.
        :raises AttributeError: If neither type key yields a value.
        :raises KeyError: If the type is not in ``DISTRIBUTION_TYPE_MAP``.
        """
        if obj is None:
            return None

        # Copy before popping: popping directly from ``obj`` (or from the REST
        # object's ``__dict__``) would strip the type from the caller's data.
        if isinstance(obj, RestDistributionConfiguration):
            data = dict(vars(obj))
        else:
            data = dict(obj)

        type_str = data.pop("distribution_type", None) or data.pop("type", None)
        # Use a separate local name instead of rebinding ``cls``.
        distribution_cls = DISTRIBUTION_TYPE_MAP[type_str.lower()]
        return distribution_cls(**data)


class MpiDistribution(RestMpi, DistributionConfiguration):
    """Distribution configuration for MPI jobs.

    :param process_count_per_instance: Number of processes per MPI node.
    :type process_count_per_instance: int
    """

    def __init__(self, *, process_count_per_instance: Optional[int] = None, **kwargs):
        # Hand the process count and any extra keyword arguments to the next
        # initializer in the MRO.
        super().__init__(
            process_count_per_instance=process_count_per_instance,
            **kwargs,
        )
class PyTorchDistribution(RestPyTorch, DistributionConfiguration):
    """PyTorch distribution configuration.

    :param process_count_per_instance: Number of processes per node.
    :type process_count_per_instance: int
    :param kwargs: Extra keyword arguments, forwarded to the base initializer.
    """

    def __init__(self, *, process_count_per_instance: Optional[int] = None, **kwargs):
        # Forward **kwargs to the base initializer, matching MpiDistribution
        # and TensorFlowDistribution; they were previously accepted but
        # silently discarded.
        super(PyTorchDistribution, self).__init__(process_count_per_instance=process_count_per_instance, **kwargs)
class TensorFlowDistribution(RestTensorFlow, DistributionConfiguration):
    """Distribution configuration for TensorFlow jobs.

    :vartype distribution_type: str or ~azure.mgmt.machinelearningservices.models.DistributionType
    :ivar parameter_server_count: Number of parameter server tasks.
    :vartype parameter_server_count: int
    :ivar worker_count: Number of workers. If not specified, will default to the instance count.
    :vartype worker_count: int
    """

    def __init__(
        self,
        *,
        parameter_server_count: Optional[int] = 0,
        worker_count: Optional[int] = None,
        **kwargs
    ):
        # Pass both counts and any extra keyword arguments to the next
        # initializer in the MRO.
        super().__init__(parameter_server_count=parameter_server_count, worker_count=worker_count, **kwargs)
# Maps each SDK DistributionType constant to the SDK class that represents it.
# DistributionConfiguration._from_rest_object indexes this mapping with the
# lower-cased type string to choose which class to instantiate.
DISTRIBUTION_TYPE_MAP = { DistributionType.MPI: MpiDistribution, DistributionType.TENSORFLOW: TensorFlowDistribution, DistributionType.PYTORCH: PyTorchDistribution, }