# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from copy import deepcopy
from enum import Enum
from typing import Any
from urllib.parse import urlparse
from azure.core import CaseInsensitiveEnumMeta
from azure.core.credentials import TokenCredential
from azure.core.pipeline.policies import HttpLoggingPolicy
from azure.core.rest import HttpRequest, HttpResponse
from azure.core.tracing.decorator import distributed_trace
from . import ChallengeAuthPolicy
from .._generated import KeyVaultClient as _KeyVaultClient
from .._generated import models as _models
from .._generated._serialization import Serializer
from .._sdk_moniker import SDK_MONIKER
[docs]
class ApiVersion(str, Enum, metaclass=CaseInsensitiveEnumMeta):
    """Key Vault API versions supported by this package.

    Members mix in ``str``, so each member compares equal to its plain version string
    (for example ``ApiVersion.V7_5 == "7.5"``).
    """

    # NOTE(review): CaseInsensitiveEnumMeta comes from azure.core; presumably it makes member-name
    # lookup case-insensitive -- confirm against the azure-core documentation.

    #: this is the default version (see the module-level DEFAULT_VERSION)
    V7_5 = "7.5"
    V7_4 = "7.4"
    V7_3 = "7.3"
    V7_2 = "7.2"
    V7_1 = "7.1"
    V7_0 = "7.0"
    V2016_10_01 = "2016-10-01"
# API version a client uses when the caller does not pass an `api_version` keyword argument
DEFAULT_VERSION = ApiVersion.V7_5
# Module-level serializer shared by every client; `send_request` uses it to format the vault URL path argument
_SERIALIZER = Serializer()
# NOTE(review): presumably disables the generated serializer's client-side value validation -- confirm in Serializer
_SERIALIZER.client_side_validation = False
def _format_api_version(request: HttpRequest, api_version: str) -> HttpRequest:
"""Returns a request copy that includes an api-version query parameter if one wasn't originally present.
:param request: The HTTP request being sent.
:type request: ~azure.core.rest.HttpRequest
:param str api_version: The service API version that the request should include.
:returns: A copy of the request that includes an api-version query parameter.
:rtype: azure.core.rest.HttpRequest
"""
request_copy = deepcopy(request)
params = {"api-version": api_version} # By default, we want to use the client's API version
query = urlparse(request_copy.url).query
if query:
request_copy.url = request_copy.url.partition("?")[0]
existing_params = {p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")]}
params.update(existing_params) # If an api-version was provided, this will overwrite our default
# Reconstruct the query parameters onto the URL
query_params = []
for k, v in params.items():
query_params.append("{}={}".format(k, v))
query = "?" + "&".join(query_params)
request_copy.url = request_copy.url + query
return request_copy
class KeyVaultClientBase(object):
    """Shared base for Key Vault clients.

    Holds the generated service client, the generated models module, the normalized vault URL and the
    API version string, and exposes context-manager support plus a raw ``send_request`` escape hatch.

    :param str vault_url: URL of the vault the client will access. Surrounding spaces and slashes are stripped.
    :param credential: An object which can provide an access token for the vault.
    :type credential: ~azure.core.credentials.TokenCredential

    :keyword api_version: Service API version to use; an enum member or a plain string. Defaults to
        ``DEFAULT_VERSION``.
    :keyword generated_client: An already-configured generated client to reuse instead of building one.
    :keyword generated_models: Models module to pair with ``generated_client``; the package's own
        generated models are used when omitted.
    :keyword bool verify_challenge_resource: Forwarded to the challenge authentication policy. Defaults to True.

    :raises ValueError: If ``credential`` or ``vault_url`` is missing or empty.
    :raises NotImplementedError: If a ``ValueError`` is raised while the client is being set up; it is
        reported as an unsupported API version.
    """

    # pylint:disable=protected-access
    def __init__(self, vault_url: str, credential: TokenCredential, **kwargs: Any) -> None:
        if not credential:
            raise ValueError(
                "credential should be an object supporting the TokenCredential protocol, "
                "such as a credential from azure-identity"
            )
        if not vault_url:
            raise ValueError("vault_url must be the URL of an Azure Key Vault")

        try:
            self.api_version = kwargs.pop("api_version", DEFAULT_VERSION)
            # An enum member is reduced to its plain string value (needed for Python 3.11 compatibility)
            self.api_version = getattr(self.api_version, "value", self.api_version)
            self._vault_url = vault_url.strip(" /")

            preconfigured_client = kwargs.get("generated_client")
            if preconfigured_client:
                # The caller supplied a ready-made client, so only the models remain to be chosen
                self._client = preconfigured_client
                self._models = kwargs.get("generated_models") or _models
                return

            logging_policy = HttpLoggingPolicy(**kwargs)
            # Let these Key Vault response headers through the logging policy's header allow-list
            logging_policy.allowed_header_names.update(
                {"x-ms-keyvault-network-info", "x-ms-keyvault-region", "x-ms-keyvault-service-version"}
            )

            # Popped only after the logging policy is built, so the option never reaches the generated client
            auth_policy = ChallengeAuthPolicy(
                credential, verify_challenge_resource=kwargs.pop("verify_challenge_resource", True)
            )
            self._client = _KeyVaultClient(
                api_version=self.api_version,
                authentication_policy=auth_policy,
                sdk_moniker=SDK_MONIKER,
                http_logging_policy=logging_policy,
                **kwargs
            )
            self._models = _models
        except ValueError as exc:
            # Ignore pyright error that comes from not identifying ApiVersion as an iterable enum
            supported_versions = ", ".join(v.value for v in ApiVersion)  # pyright: ignore[reportGeneralTypeIssues]
            raise NotImplementedError(
                f"This package doesn't support API version '{self.api_version}'. "
                f"Supported versions: {supported_versions}"
            ) from exc

    @property
    def vault_url(self) -> str:
        """The vault URL given at construction, with surrounding spaces and slashes stripped."""
        return self._vault_url

    def __enter__(self) -> "KeyVaultClientBase":
        """Enter the underlying generated client's context and return this client."""
        self._client.__enter__()
        return self

    def __exit__(self, *args: Any) -> None:
        """Leave the underlying generated client's context, forwarding the exception details."""
        self._client.__exit__(*args)

    def close(self) -> None:
        """Close sockets opened by the client.

        Calling this method is unnecessary when using the client as a context manager.
        """
        self._client.close()

    @distributed_trace
    def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse:
        """Runs a network request using the client's existing pipeline.

        The request URL can be relative to the vault URL. Unless the request specifies its own api-version,
        the client's service API version is used. This method does not raise if the response is an error;
        call `raise_for_status()` on the returned response object to raise an exception. For more information
        about sending custom requests with this method, see https://aka.ms/azsdk/dpcodegen/python/send_request.

        :param request: The network request you want to make.
        :type request: ~azure.core.rest.HttpRequest

        :keyword bool stream: Whether the response payload will be streamed. Defaults to False.

        :return: The response of your network call. Does not do error handling on your response.
        :rtype: ~azure.core.rest.HttpResponse
        """
        # Work on a copy so the caller's request object is left untouched
        versioned_request = _format_api_version(request, self.api_version)
        vault_base_url = _SERIALIZER.url("vault_base_url", self._vault_url, "str", skip_quote=True)

        inner_client = self._client._client
        versioned_request.url = inner_client.format_url(versioned_request.url, vaultBaseUrl=vault_base_url)
        return inner_client.send_request(versioned_request, stream=stream, **kwargs)