# Source code for azure.core.pipeline.policies._redirect

# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
"""
This module is the requests implementation of Pipeline ABC
"""
import logging
from urllib.parse import urljoin, urlparse
from typing import Optional, TypeVar, Dict, Any, Union, Type
from typing_extensions import Literal

from azure.core.exceptions import TooManyRedirectsError
from azure.core.pipeline import PipelineResponse, PipelineRequest
from azure.core.pipeline.transport import (
    HttpResponse as LegacyHttpResponse,
    HttpRequest as LegacyHttpRequest,
    AsyncHttpResponse as LegacyAsyncHttpResponse,
)
from azure.core.rest import HttpResponse, HttpRequest, AsyncHttpResponse
from ._base import HTTPPolicy, RequestHistory
from ._utils import get_domain

# Synchronous response types: the new-style ``azure.core.rest`` response and
# the legacy transport response.
HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse)
# Every response type (sync and async, new-style and legacy). Used to type the
# RedirectPolicyBase helpers, which do not depend on the sync/async distinction.
AllHttpResponseType = TypeVar(
    "AllHttpResponseType", HttpResponse, LegacyHttpResponse, AsyncHttpResponse, LegacyAsyncHttpResponse
)
# Request types: the new-style ``azure.core.rest`` request and the legacy transport request.
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
# Bound to RedirectPolicyBase so classmethod constructors (e.g. ``no_redirects``)
# are typed as returning the concrete subclass they are called on.
ClsRedirectPolicy = TypeVar("ClsRedirectPolicy", bound="RedirectPolicyBase")

# Module-level logger; not referenced by the code in this section.
_LOGGER = logging.getLogger(__name__)


def domain_changed(original_domain: Optional[str], url: str) -> bool:
    """Report whether ``url`` points at a different domain than ``original_domain``.

    The domain of ``url`` is always computed with ``get_domain``. When no
    original domain is known (``None`` or empty string), no change is reported.

    :param str original_domain: The original domain.
    :param str url: The new url.
    :rtype: bool
    :return: Whether the domain has changed.
    """
    new_domain = get_domain(url)
    if original_domain:
        return original_domain != new_domain
    # Nothing to compare against, so treat it as "unchanged".
    return False


class RedirectPolicyBase:
    """Redirect configuration and helpers shared by the redirect policies.

    Stores the redirect settings, reads per-request overrides, detects redirect
    responses and rewrites the request so it can be re-sent to the new location.

    :keyword bool permit_redirects: Whether redirects are followed. Defaults to True.
    :keyword int redirect_max: The maximum number of redirects to follow. Defaults to 30.
    :keyword redirect_remove_headers: Additional header names to remove from the
     request whenever a redirect is followed (``Authorization`` is always removed).
    :paramtype redirect_remove_headers: iterable[str]
    :keyword redirect_on_status_codes: Additional status codes to treat as
     redirects, on top of ``REDIRECT_STATUSES``.
    :paramtype redirect_on_status_codes: iterable[int]
    """

    # Status codes treated as redirects by default.
    REDIRECT_STATUSES = frozenset([300, 301, 302, 303, 307, 308])

    # Headers always removed from the request when a redirect is followed.
    REDIRECT_HEADERS_BLACKLIST = frozenset(["Authorization"])

    def __init__(self, **kwargs: Any) -> None:
        self.allow: bool = kwargs.get("permit_redirects", True)
        self.max_redirects: int = kwargs.get("redirect_max", 30)

        remove_headers = set(kwargs.get("redirect_remove_headers", []))
        self._remove_headers_on_redirect = remove_headers.union(self.REDIRECT_HEADERS_BLACKLIST)
        redirect_status = set(kwargs.get("redirect_on_status_codes", []))
        self._redirect_on_status_codes = redirect_status.union(self.REDIRECT_STATUSES)
        # Cooperative call so that other bases in a subclass's MRO (e.g. the
        # HTTPPolicy base of RedirectPolicy) are initialised as well.
        super().__init__()

    @classmethod
    def no_redirects(cls: Type[ClsRedirectPolicy]) -> ClsRedirectPolicy:
        """Disable redirects.

        :return: A redirect policy with redirects disabled.
        :rtype: ~azure.core.pipeline.policies.RedirectPolicy or ~azure.core.pipeline.policies.AsyncRedirectPolicy
        """
        return cls(permit_redirects=False)

    def configure_redirects(self, options: Dict[str, Any]) -> Dict[str, Any]:
        """Configures the redirect settings.

        Per-operation ``permit_redirects`` / ``redirect_max`` values are popped
        from ``options`` (so they are not forwarded further down the pipeline)
        and override the policy-level defaults.

        :param options: Keyword arguments from context.
        :type options: dict
        :return: A dict containing redirect settings and a history of redirects.
        :rtype: dict
        """
        return {
            "allow": options.pop("permit_redirects", self.allow),
            "redirects": options.pop("redirect_max", self.max_redirects),
            "history": [],
        }

    def get_redirect_location(
        self, response: PipelineResponse[Any, AllHttpResponseType]
    ) -> Union[str, None, Literal[False]]:
        """Checks for redirect status code and gets redirect location.

        :param response: The PipelineResponse object
        :type response: ~azure.core.pipeline.PipelineResponse
        :return: Truthy redirect location string if we got a redirect status
         code and valid location. ``None`` if redirect status and no
         location. ``False`` if not a redirect status code.
        :rtype: str or bool or None
        """
        status_code = response.http_response.status_code
        if status_code in (301, 302):
            # 301/302 are only followed automatically for GET and HEAD; for any
            # other method they are reported as "not a redirect".
            if response.http_request.method in ("GET", "HEAD"):
                return response.http_response.headers.get("location")
            return False
        if status_code in self._redirect_on_status_codes:
            return response.http_response.headers.get("location")

        return False

    def increment(
        self, settings: Dict[str, Any], response: PipelineResponse[Any, AllHttpResponseType], redirect_location: str
    ) -> bool:
        """Increment the redirect attempts for this request.

        Records the current request/response in the history, then mutates
        ``response.http_request`` in place for the next hop: its URL is set to
        the resolved redirect location, its method becomes ``GET`` on a 303, and
        the configured headers are removed.

        :param dict settings: The redirect settings
        :param response: A pipeline response object.
        :type response: ~azure.core.pipeline.PipelineResponse
        :param str redirect_location: The redirected endpoint.
        :return: Whether further redirect attempts are remaining.
         False if exhausted; True if more redirect attempts available.
        :rtype: bool
        """
        # TODO: Revise some of the logic here.
        settings["redirects"] -= 1
        settings["history"].append(RequestHistory(response.http_request, http_response=response.http_response))

        # The Location value may be an absolute URL, scheme-relative ("//host/p"),
        # absolute-path ("/p") or path-relative ("p", "?q=1"). urljoin resolves all
        # of these against the URL that produced the redirect (RFC 3986 section 5);
        # an absolute URL simply replaces it. Building "scheme://netloc/" + location
        # by hand would mis-resolve path-relative and scheme-relative values.
        response.http_request.url = urljoin(response.http_request.url, redirect_location)
        if response.http_response.status_code == 303:
            # 303 See Other: the follow-up request is always a GET.
            response.http_request.method = "GET"
        for non_redirect_header in self._remove_headers_on_redirect:
            response.http_request.headers.pop(non_redirect_header, None)
        return settings["redirects"] >= 0


class RedirectPolicy(RedirectPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """A redirect policy.

    A redirect policy in the pipeline can be configured directly or per operation.

    :keyword bool permit_redirects: Whether the client allows redirects. Defaults to True.
    :keyword int redirect_max: The maximum allowed redirects. Defaults to 30.

    .. admonition:: Example:

        .. literalinclude:: ../samples/test_example_sync.py
            :start-after: [START redirect_policy]
            :end-before: [END redirect_policy]
            :language: python
            :dedent: 4
            :caption: Configuring a redirect policy.
    """

    def send(self, request: PipelineRequest[HTTPRequestType]) -> PipelineResponse[HTTPRequestType, HTTPResponseType]:
        """Send the request down the pipeline, following redirects when allowed.

        Each redirect response is recorded and the request is rewritten and
        re-sent until a non-redirect response arrives or the redirect budget
        runs out.

        :param request: The PipelineRequest object
        :type request: ~azure.core.pipeline.PipelineRequest
        :return: Returns the PipelineResponse or raises error if maximum redirects exceeded.
        :rtype: ~azure.core.pipeline.PipelineResponse
        :raises: ~azure.core.exceptions.TooManyRedirectsError if maximum redirects exceeded.
        """
        settings = self.configure_redirects(request.context.options)
        start_domain = get_domain(request.http_request.url) if settings["allow"] else None

        attempts_left: bool = True
        while attempts_left:
            response = self.next.send(request)
            location = self.get_redirect_location(response)
            if not (location and settings["allow"]):
                # Not a followable redirect, or redirects are disabled.
                return response

            attempts_left = self.increment(settings, response, location)
            request.http_request = response.http_request
            if domain_changed(start_domain, request.http_request.url):
                # "insecure_domain_change" is used to indicate that a redirect
                # has occurred to a different domain. This tells the SensitiveHeaderCleanupPolicy
                # to clean up sensitive headers. We need to remove it before sending the request
                # to the transport layer.
                request.context.options["insecure_domain_change"] = True

        raise TooManyRedirectsError(settings["history"])