Source code for azure.keyvault.secrets._shared.client_base

# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from copy import deepcopy
from enum import Enum
from typing import Any
from urllib.parse import urlparse

from azure.core import CaseInsensitiveEnumMeta
from azure.core.credentials import TokenCredential
from azure.core.pipeline.policies import HttpLoggingPolicy
from azure.core.rest import HttpRequest, HttpResponse
from azure.core.tracing.decorator import distributed_trace

from . import ChallengeAuthPolicy
from .._generated import KeyVaultClient as _KeyVaultClient
from .._generated import models as _models
from .._generated._serialization import Serializer
from .._sdk_moniker import SDK_MONIKER


[docs] class ApiVersion(str, Enum, metaclass=CaseInsensitiveEnumMeta): """Key Vault API versions supported by this package""" #: this is the default version V7_5 = "7.5" V7_4 = "7.4" V7_3 = "7.3" V7_2 = "7.2" V7_1 = "7.1" V7_0 = "7.0" V2016_10_01 = "2016-10-01"
DEFAULT_VERSION = ApiVersion.V7_5 _SERIALIZER = Serializer() _SERIALIZER.client_side_validation = False def _format_api_version(request: HttpRequest, api_version: str) -> HttpRequest: """Returns a request copy that includes an api-version query parameter if one wasn't originally present. :param request: The HTTP request being sent. :type request: ~azure.core.rest.HttpRequest :param str api_version: The service API version that the request should include. :returns: A copy of the request that includes an api-version query parameter. :rtype: azure.core.rest.HttpRequest """ request_copy = deepcopy(request) params = {"api-version": api_version} # By default, we want to use the client's API version query = urlparse(request_copy.url).query if query: request_copy.url = request_copy.url.partition("?")[0] existing_params = {p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")]} params.update(existing_params) # If an api-version was provided, this will overwrite our default # Reconstruct the query parameters onto the URL query_params = [] for k, v in params.items(): query_params.append("{}={}".format(k, v)) query = "?" + "&".join(query_params) request_copy.url = request_copy.url + query return request_copy class KeyVaultClientBase(object): # pylint:disable=protected-access def __init__(self, vault_url: str, credential: TokenCredential, **kwargs: Any) -> None: if not credential: raise ValueError( "credential should be an object supporting the TokenCredential protocol, " "such as a credential from azure-identity" ) if not vault_url: raise ValueError("vault_url must be the URL of an Azure Key Vault") try: self.api_version = kwargs.pop("api_version", DEFAULT_VERSION) # If API version was provided as an enum value, need to make a plain string for 3.11 compatibility if hasattr(self.api_version, "value"): self.api_version = self.api_version.value self._vault_url = vault_url.strip(" /") client = kwargs.get("generated_client") if client: # caller provided a configured client -> only models left to initialize self._client = client models = kwargs.get("generated_models") self._models = models or _models return http_logging_policy = HttpLoggingPolicy(**kwargs) http_logging_policy.allowed_header_names.update( {"x-ms-keyvault-network-info", "x-ms-keyvault-region", "x-ms-keyvault-service-version"} ) verify_challenge = kwargs.pop("verify_challenge_resource", True) self._client = _KeyVaultClient( api_version=self.api_version, authentication_policy=ChallengeAuthPolicy(credential, verify_challenge_resource=verify_challenge), sdk_moniker=SDK_MONIKER, http_logging_policy=http_logging_policy, **kwargs ) self._models = _models except ValueError as exc: # Ignore pyright error that comes from not identifying ApiVersion as an iterable enum raise NotImplementedError( f"This package doesn't support API version '{self.api_version}'. " + "Supported versions: " + f"{', '.join(v.value for v in ApiVersion)}" # pyright: ignore[reportGeneralTypeIssues] ) from exc @property def vault_url(self) -> str: return self._vault_url def __enter__(self) -> "KeyVaultClientBase": self._client.__enter__() return self def __exit__(self, *args: Any) -> None: self._client.__exit__(*args) def close(self) -> None: """Close sockets opened by the client. Calling this method is unnecessary when using the client as a context manager. """ self._client.close() @distributed_trace def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs a network request using the client's existing pipeline. The request URL can be relative to the vault URL. The service API version used for the request is the same as the client's unless otherwise specified. This method does not raise if the response is an error; to raise an exception, call `raise_for_status()` on the returned response object. For more information about how to send custom requests with this method, see https://aka.ms/azsdk/dpcodegen/python/send_request. :param request: The network request you want to make. :type request: ~azure.core.rest.HttpRequest :keyword bool stream: Whether the response payload will be streamed. Defaults to False. :return: The response of your network call. Does not do error handling on your response. :rtype: ~azure.core.rest.HttpResponse """ request_copy = _format_api_version(request, self.api_version) path_format_arguments = { "vaultBaseUrl": _SERIALIZER.url("vault_base_url", self._vault_url, "str", skip_quote=True), } request_copy.url = self._client._client.format_url(request_copy.url, **path_format_arguments) return self._client._client.send_request(request_copy, stream=stream, **kwargs)