# Source code for azure.core.pipeline.policies._authentication

# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
import time
from typing import TYPE_CHECKING, Optional, TypeVar, MutableMapping, Any, Union, cast
from azure.core.credentials import TokenCredential, SupportsTokenInfo, TokenRequestOptions, TokenProvider
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.transport import HttpResponse as LegacyHttpResponse, HttpRequest as LegacyHttpRequest
from azure.core.rest import HttpResponse, HttpRequest
from . import HTTPPolicy, SansIOHTTPPolicy
from ...exceptions import ServiceRequestError

if TYPE_CHECKING:
    # pylint:disable=unused-import
    from azure.core.credentials import (
        AccessToken,
        AccessTokenInfo,
        AzureKeyCredential,
        AzureSasCredential,
    )

# Constrained type variables used to parameterize the policies in this module. Each one may be bound to
# either the azure.core.rest type or the legacy azure.core.pipeline.transport type (see imports above),
# so a single policy implementation can serve both request/response families.
HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse)
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)


# pylint:disable=too-few-public-methods
class _BearerTokenCredentialPolicyBase:
    """Shared plumbing for bearer-token authentication policies.

    Stores the credential, the scopes to request and the most recently acquired token, and
    provides helpers to validate the request scheme, write the Authorization header and
    decide whether the cached token has to be replaced.

    :param credential: The credential.
    :type credential: ~azure.core.credentials.TokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    """

    def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None:
        super().__init__()
        self._credential = credential
        self._scopes = scopes
        # Most recently acquired token; None means no usable token is cached.
        self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None
        self._enable_cae: bool = kwargs.get("enable_cae", False)

    @staticmethod
    def _enforce_https(request: PipelineRequest[HTTPRequestType]) -> None:
        """Refuse to attach a bearer token to a request whose URL does not start with "https".

        The per-call ``enforce_https`` option is moved out of the request options and into the
        pipeline context, so it survives retries without being handed to the transport. Only an
        explicit ``False`` is recorded, because enforcement is already the default.

        :param ~azure.core.pipeline.PipelineRequest request: The outgoing request.
        :raises ~azure.core.exceptions.ServiceRequestError: If enforcement is on and the URL is not HTTPS.
        """
        opt_out = request.context.options.pop("enforce_https", None)
        if opt_out is False:
            request.context["enforce_https"] = False

        if not request.context.get("enforce_https", True):
            return
        if request.http_request.url.lower().startswith("https"):
            return
        raise ServiceRequestError(
            "Bearer token authentication is not permitted for non-TLS protected (non-https) URLs."
        )

    @staticmethod
    def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
        """Write ``Bearer <token>`` into the Authorization header, overwriting any previous value.

        :param MutableMapping[str, str] headers: The HTTP request headers.
        :param str token: The OAuth token.
        """
        headers["Authorization"] = f"Bearer {token}"

    @property
    def _need_new_token(self) -> bool:
        """Whether the cached token is absent, past its ``refresh_on`` time, or expires within 300 seconds.

        :rtype: bool
        """
        current_time = time.time()
        cached = self._token
        if not cached:
            return True
        # Only AccessTokenInfo-style tokens may carry a refresh hint; others fall through to expiry.
        refresh_on = getattr(cached, "refresh_on", None)
        if refresh_on and refresh_on <= current_time:
            return True
        return cached.expires_on - current_time < 300

    def _request_token(self, *scopes: str, **kwargs: Any) -> None:
        """Acquire a fresh token from the credential and cache it on the policy.

        When CAE is enabled on the policy, ``enable_cae`` is added to the keyword arguments unless
        the caller already supplied it. Credentials exposing ``get_token_info`` receive only the
        keyword arguments whose names are keys of ``TokenRequestOptions``, bundled as ``options``;
        any remaining keyword arguments are not forwarded on that path. Other credentials receive
        every keyword argument through ``get_token``.

        :param str scopes: The type of access needed.
        """
        if self._enable_cae:
            kwargs.setdefault("enable_cae", self._enable_cae)

        credential = self._credential
        if not hasattr(credential, "get_token_info"):
            self._token = cast(TokenCredential, credential).get_token(*scopes, **kwargs)
            return

        known_keys = TokenRequestOptions.__annotations__  # pylint:disable=no-member
        options: TokenRequestOptions = {}
        # Snapshot the matching names first so kwargs is not mutated while being iterated.
        for option_name in [key for key in kwargs if key in known_keys]:
            options[option_name] = kwargs.pop(option_name)  # type: ignore[literal-required]
        self._token = cast(SupportsTokenInfo, credential).get_token_info(*scopes, options=options)


class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Adds a bearer token Authorization header to requests.

    :param credential: The credential.
    :type credential: ~azure.core.credentials.TokenProvider
    :param str scopes: Lets you specify the type of access needed.
    :keyword bool enable_cae: Indicates whether to enable Continuous Access Evaluation (CAE) on all requested
        tokens. Defaults to False.
    :raises: :class:`~azure.core.exceptions.ServiceRequestError`
    """

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Called before the policy sends a request.

        The base implementation authorizes the request with a bearer token, reusing the cached
        token unless none is cached or it is due for renewal.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        """
        self._enforce_https(request)

        # Short-circuits, so the freshness check only runs when a token is cached.
        stale = self._token is None or self._need_new_token
        if stale:
            self._request_token(*self._scopes)
        access_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token)
        self._update_headers(request.http_request.headers, access_token.token)

    def authorize_request(self, request: PipelineRequest[HTTPRequestType], *scopes: str, **kwargs: Any) -> None:
        """Acquire a token from the credential and authorize the request with it.

        Keyword arguments are passed on to the credential's token-acquisition method. The token
        will be cached and used to authorize future requests.

        :param ~azure.core.pipeline.PipelineRequest request: the request
        :param str scopes: required scopes of authentication
        """
        self._request_token(*scopes, **kwargs)
        access_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token)
        self._update_headers(request.http_request.headers, access_token.token)

    def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
        """Authorize request with a bearer token and send it to the next policy

        On a 401 response the cached token is discarded. If the response also carries a
        ``WWW-Authenticate`` challenge and :meth:`on_challenge` authorizes the request, the
        request is sent once more and that second response is returned.

        :param request: The pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: The pipeline response object
        :rtype: ~azure.core.pipeline.PipelineResponse
        """
        self.on_request(request)
        try:
            response = self.next.send(request)
        except Exception:  # pylint:disable=broad-except
            self.on_exception(request)
            raise

        self.on_response(request, response)
        if response.http_response.status_code != 401:
            return response

        # The service rejected the cached token, so it must not be reused.
        self._token = None
        if "WWW-Authenticate" not in response.http_response.headers:
            return response
        if not self.on_challenge(request, response):
            return response

        # The challenge handler fetched a token for the new target, so the Authorization header
        # must stay on the request: clear the 'insecure_domain_change' tag before re-sending.
        request.context.options.pop("insecure_domain_change", False)
        try:
            response = self.next.send(request)
            self.on_response(request, response)
        except Exception:  # pylint:disable=broad-except
            self.on_exception(request)
            raise
        return response

    def on_challenge(
        self, request: PipelineRequest[HTTPRequestType], response: PipelineResponse[HTTPRequestType, HTTPResponseType]
    ) -> bool:
        """Authorize request according to an authentication challenge

        This method is called when the resource provider responds 401 with a WWW-Authenticate header.
        The base implementation declines every challenge.

        :param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
        :param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
        :returns: a bool indicating whether the policy should send the request
        :rtype: bool
        """
        # pylint:disable=unused-argument
        return False

    def on_response(
        self, request: PipelineRequest[HTTPRequestType], response: PipelineResponse[HTTPRequestType, HTTPResponseType]
    ) -> None:
        """Executed after the request comes back from the next policy.

        The base implementation does nothing; subclasses may override it.

        :param request: Request to be modified after returning from the policy.
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: Pipeline response object
        :type response: ~azure.core.pipeline.PipelineResponse
        """

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Executed when an exception is raised while executing the next policy.

        This method is executed inside the exception handler; the exception is re-raised afterwards.

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        # pylint: disable=unused-argument
        return
class AzureKeyCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Adds a key header for the provided credential.

    Each request gets the header ``name`` set to the credential's current key, optionally
    preceded by ``prefix`` and a single space.

    :param credential: The credential used to authenticate requests.
    :type credential: ~azure.core.credentials.AzureKeyCredential
    :param str name: The name of the key header used for the credential.
    :keyword str prefix: The name of the prefix for the header value if any.
    :raises: ValueError or TypeError
    """

    def __init__(  # pylint: disable=unused-argument
        self,
        credential: "AzureKeyCredential",
        name: str,
        *,
        prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        # Validation order matters: a missing ``key`` attribute is reported before any name problem.
        if not hasattr(credential, "key"):
            raise TypeError("String is not a supported credential input type. Use an instance of AzureKeyCredential.")
        if not name:
            raise ValueError("name can not be None or empty")
        if not isinstance(name, str):
            raise TypeError("name must be a string.")

        self._credential = credential
        self._name = name
        if prefix:
            self._prefix = prefix + " "
        else:
            self._prefix = ""

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        # The key is read on every request rather than cached at construction time.
        header_value = f"{self._prefix}{self._credential.key}"
        request.http_request.headers[self._name] = header_value
class AzureSasCredentialPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """Adds a shared access signature to query for the provided credential.

    :param credential: The credential used to authenticate requests.
    :type credential: ~azure.core.credentials.AzureSasCredential
    :raises: ValueError or TypeError
    """

    def __init__(self, credential: "AzureSasCredential", **kwargs: Any) -> None:  # pylint: disable=unused-argument
        super().__init__()
        if not credential:
            raise ValueError("credential can not be None")
        self._credential = credential

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Append the credential's signature to the request URL's query string.

        A single leading "?" on the signature is dropped. If the URL already has a query, the
        signature is appended with "&" unless it already occurs somewhere in the URL. Otherwise
        it starts a new query (reusing a trailing "?" if the URL already ends with one).

        :param request: The Pipeline request object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        http_request = request.http_request
        url = http_request.url
        has_query = bool(http_request.query)
        raw_signature = self._credential.signature
        signature = raw_signature[1:] if raw_signature.startswith("?") else raw_signature

        if has_query:
            # Skipping a signature already present presumably avoids duplicating it when the
            # policy runs again for the same request (e.g. on retry) — TODO confirm.
            if signature not in url:
                url = f"{url}&{signature}"
        elif url.endswith("?"):
            url += signature
        else:
            url = f"{url}?{signature}"
        http_request.url = url