Source code for azure.confidentialledger.aio._client_base

# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

from typing import Any, TYPE_CHECKING, Union

from azure.core.pipeline.policies import (
    AsyncBearerTokenCredentialPolicy,
    HttpLoggingPolicy,
)

from .._user_agent import USER_AGENT
from .._generated._generated_ledger.v0_1_preview.aio import (
    ConfidentialLedgerClient as _ConfidentialLedgerClient,
)
from .._shared import ConfidentialLedgerCertificateCredential, DEFAULT_VERSION

if TYPE_CHECKING:
    from azure.core.credentials import TokenCredential


class AsyncConfidentialLedgerClientBase(object):
    """Shared base for the async Confidential Ledger clients.

    Wraps the generated ``ConfidentialLedgerClient`` and takes care of
    validating constructor arguments, normalizing the endpoint URL, and
    wiring certificate / bearer-token authentication into the pipeline
    keyword arguments.

    :keyword str endpoint: URL of the Confidential Ledger. ``https://`` is
        prepended when the value does not already start with that scheme
        (compared case-insensitively).
    :keyword credential: Either a ``ConfidentialLedgerCertificateCredential``
        (client certificate authentication) or a token credential (bearer
        token authentication).
    :keyword str ledger_certificate_path: Path passed to the transport as
        ``connection_verify`` when no custom ``transport`` is supplied.
    :keyword generated_client: An already-configured generated client. When
        given (and truthy), no validation or pipeline setup is performed.
    :keyword api_version: API version handed to the generated client;
        defaults to ``DEFAULT_VERSION``.
    :keyword transport: Custom transport. When supplied, ``connection_cert``
        and ``connection_verify`` are not injected.
    :keyword authentication_policy: Overrides the default bearer-token policy
        used for non-certificate credentials.
    :raises ValueError: If ``endpoint`` is empty or not a string, if
        ``credential`` is falsy, or if ``ledger_certificate_path`` is empty.
    :raises TypeError: If ``ledger_certificate_path`` is not a string.
    :raises NotImplementedError: If the generated client rejects the
        requested API version.
    """

    def __init__(
        self,
        *,
        endpoint: str,
        credential: "Union[ConfidentialLedgerCertificateCredential, TokenCredential]",
        ledger_certificate_path: str,
        **kwargs: Any
    ) -> None:
        client = kwargs.get("generated_client")
        if client:
            # Caller provided a configured client -> nothing left to initialize.
            # The endpoint is stored exactly as given (not validated or
            # normalized) so that the ``endpoint`` property does not fail with
            # AttributeError on this code path.
            self._client = client
            self._endpoint = endpoint
            return

        if not endpoint:
            raise ValueError("Expected endpoint to be a non-empty string")

        if not credential:
            raise ValueError("Expected credential to not be None")

        if not isinstance(ledger_certificate_path, str):
            raise TypeError("ledger_certificate_path must be a string")

        if ledger_certificate_path == "":
            raise ValueError(
                "If not None, ledger_certificate_path must be a non-empty string"
            )

        self._endpoint = self._normalize_endpoint(endpoint)

        self.api_version = kwargs.pop("api_version", DEFAULT_VERSION)

        if not kwargs.get("transport"):
            # Customize the transport layer to use client certificate authentication and validate
            # a self-signed TLS certificate.
            if isinstance(credential, ConfidentialLedgerCertificateCredential):
                # The async version of the client seems to expect a sequence of filenames.
                # azure/core/pipeline/transport/_aiohttp.py:163
                # > ssl_ctx.load_cert_chain(*cert)
                kwargs["connection_cert"] = (credential.certificate_path,)

            kwargs["connection_verify"] = ledger_certificate_path

        http_logging_policy = HttpLoggingPolicy(**kwargs)
        http_logging_policy.allowed_header_names.update(
            {
                "x-ms-keyvault-network-info",
                "x-ms-keyvault-region",
                "x-ms-keyvault-service-version",
            }
        )

        if not isinstance(credential, ConfidentialLedgerCertificateCredential):
            # Only build the default bearer-token policy when the caller did
            # not supply one; a caller-supplied value is left untouched.
            if "authentication_policy" not in kwargs:
                kwargs["authentication_policy"] = AsyncBearerTokenCredentialPolicy(
                    credential,
                    "https://confidential-ledger.azure.com/.default",
                    **kwargs
                )

        try:
            self._client = _ConfidentialLedgerClient(
                self._endpoint,
                api_version=self.api_version,
                http_logging_policy=http_logging_policy,
                sdk_moniker=USER_AGENT,
                **kwargs
            )
        except NotImplementedError as err:
            raise NotImplementedError(
                "This package doesn't support API version '{}'. ".format(
                    self.api_version
                )
                + "Supported versions: 0.1-preview"
            ) from err

    @staticmethod
    def _normalize_endpoint(endpoint: str) -> str:
        """Return *endpoint* trimmed of spaces/slashes and with an https scheme.

        Leading and trailing spaces and ``/`` characters are stripped. If the
        result does not start with ``https://`` (case-insensitive), that prefix
        is prepended; otherwise the value is returned with its original casing.

        :raises ValueError: If *endpoint* is not a string.
        """
        if not isinstance(endpoint, str):
            raise ValueError("Confidential Ledger URL must be a string.")

        trimmed = endpoint.strip(" /")
        # URL schemes are case-insensitive, so "HTTPS://host" must not get a
        # second scheme prepended.
        if trimmed.lower().startswith("https://"):
            return trimmed
        return "https://" + trimmed

    @property
    def endpoint(self) -> str:
        """The URL this client is connected to."""
        return self._endpoint

    async def __aenter__(self) -> "AsyncConfidentialLedgerClientBase":
        """Enter the wrapped client's async context and return this object."""
        await self._client.__aenter__()
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Exit the wrapped client's async context, forwarding exception info."""
        await self._client.__aexit__(*args)

    async def close(self) -> None:
        """Close sockets opened by the client.

        Calling this method is unnecessary when using the client as a context manager.
        """
        await self._client.close()