# Source code for azure.core.pipeline.transport._aiohttp

# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
import sys
from typing import (
    Any, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload
)
from collections.abc import AsyncIterator

import logging
import asyncio
import codecs
import aiohttp
from multidict import CIMultiDict

from azure.core.configuration import ConnectionConfiguration
from azure.core.exceptions import ServiceRequestError, ServiceResponseError, IncompleteReadError
from azure.core.pipeline import Pipeline

from ._base import HttpRequest
from ._base_async import (
    AsyncHttpTransport,
    AsyncHttpResponse,
    _ResponseStopIteration)
from ...utils._pipeline_transport_rest_shared import _aiohttp_body_helper
from .._tools import is_rest as _is_rest
from .._tools_async import handle_no_stream_rest_response as _handle_no_stream_rest_response
if TYPE_CHECKING:
    from ...rest import (
        HttpRequest as RestHttpRequest,
        AsyncHttpResponse as RestAsyncHttpResponse,
    )

# Matching requests, because why not?
CONTENT_CHUNK_SIZE = 10 * 1024
_LOGGER = logging.getLogger(__name__)

class AioHttpTransport(AsyncHttpTransport):
    """AioHttp HTTP sender implementation.

    Fully asynchronous implementation using the aiohttp library.

    :param session: The client session.
    :param bool session_owner: Session owner. Defaults True.

    :keyword bool use_env_settings: Uses proxy settings from environment. Defaults to True.

    .. admonition:: Example:

        .. literalinclude:: ../samples/test_example_async.py
            :start-after: [START aiohttp]
            :end-before: [END aiohttp]
            :language: python
            :dedent: 4
            :caption: Asynchronous transport with aiohttp.
    """

    def __init__(self, *, session: Optional[aiohttp.ClientSession] = None, loop=None, session_owner=True, **kwargs):
        if loop and sys.version_info >= (3, 10):
            raise ValueError("Starting with Python 3.10, asyncio doesn’t support loop as a parameter anymore")
        self._loop = loop
        # When True, this transport created (or owns) the session and is responsible for closing it.
        self._session_owner = session_owner
        self.session = session
        self.connection_config = ConnectionConfiguration(**kwargs)
        self._use_env_settings = kwargs.pop('use_env_settings', True)

    async def __aenter__(self):
        await self.open()
        return self

    async def __aexit__(self, *args):  # pylint: disable=arguments-differ
        await self.close()

    async def open(self):
        """Opens the connection.

        Creates the aiohttp session lazily (only if this transport owns it),
        then enters the session's async context.
        """
        if not self.session and self._session_owner:
            jar = aiohttp.DummyCookieJar()
            clientsession_kwargs = {
                "trust_env": self._use_env_settings,
                "cookie_jar": jar,
                # Decompression is handled by the response/stream classes so the raw
                # body stays available; see AioHttpStreamDownloadGenerator.
                "auto_decompress": False,
            }
            if self._loop is not None:
                clientsession_kwargs["loop"] = self._loop
            self.session = aiohttp.ClientSession(**clientsession_kwargs)
        if self.session is not None:
            await self.session.__aenter__()

    async def close(self):
        """Closes the connection.

        Only closes the session if this transport owns it; ownership is then
        relinquished so a subsequent open() will not re-enter a closed session.
        """
        if self._session_owner and self.session:
            await self.session.close()
            self._session_owner = False
            self.session = None

    def _build_ssl_config(self, cert, verify):  # pylint: disable=no-self-use
        """Build the value for aiohttp's ``ssl=`` request kwarg.

        :param cert: Client-side certificate chain (args for ``load_cert_chain``), or falsy.
        :param verify: True/False to enable/disable verification, or a CA bundle path.
        :return: An ``ssl.SSLContext`` when a custom context is needed, else ``verify`` unchanged.
        """
        ssl_ctx = None
        if cert or verify not in (True, False):
            import ssl
            if verify not in (True, False):
                # `verify` is a path to a CA bundle file.
                ssl_ctx = ssl.create_default_context(cafile=verify)
            else:
                ssl_ctx = ssl.create_default_context()
            if cert:
                ssl_ctx.load_cert_chain(*cert)
            return ssl_ctx
        return verify

    def _get_request_data(self, request):  # pylint: disable=no-self-use
        """Return the body payload for aiohttp: FormData for multipart files, else raw data.

        :param request: The request being sent; ``request.files`` entries are
            (filename, file_content[, content_type]) tuples keyed by field name.
        :raises ValueError: If a files entry does not have the expected tuple shape.
        """
        if request.files:
            form_data = aiohttp.FormData()
            for form_file, data in request.files.items():
                content_type = data[2] if len(data) > 2 else None
                try:
                    form_data.add_field(form_file, data[1], filename=data[0], content_type=content_type)
                except IndexError as err:
                    # Chain the cause (raise-missing-from): consistent with the
                    # explicit `from err` chaining used in send() below.
                    raise ValueError("Invalid formdata formatting: {}".format(data)) from err
            return form_data
        return request.data

    @overload
    async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpResponse]:
        """Send the request using this HTTP sender.

        Will pre-load the body into memory to be available with a sync method.
        Pass stream=True to avoid this behavior.

        :param request: The HttpRequest object
        :type request: ~azure.core.pipeline.transport.HttpRequest
        :param config: Any keyword arguments
        :return: The AsyncHttpResponse
        :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse

        :keyword bool stream: Defaults to False.
        :keyword dict proxies: dict of proxy to used based on protocol. Proxy is a dict (protocol, url)
        :keyword str proxy: will define the proxy to use all the time
        """

    @overload
    async def send(self, request: "RestHttpRequest", **config: Any) -> Optional["RestAsyncHttpResponse"]:
        """Send the `azure.core.rest` request using this HTTP sender.

        Will pre-load the body into memory to be available with a sync method.
        Pass stream=True to avoid this behavior.

        :param request: The HttpRequest object
        :type request: ~azure.core.rest.HttpRequest
        :param config: Any keyword arguments
        :return: The AsyncHttpResponse
        :rtype: ~azure.core.rest.AsyncHttpResponse

        :keyword bool stream: Defaults to False.
        :keyword dict proxies: dict of proxy to used based on protocol. Proxy is a dict (protocol, url)
        :keyword str proxy: will define the proxy to use all the time
        """

    async def send(self, request, **config):
        """Send the request using this HTTP sender.

        Will pre-load the body into memory to be available with a sync method.
        Pass stream=True to avoid this behavior.

        :param request: The HttpRequest object
        :type request: ~azure.core.pipeline.transport.HttpRequest
        :param config: Any keyword arguments
        :return: The AsyncHttpResponse
        :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse

        :keyword bool stream: Defaults to False.
        :keyword dict proxies: dict of proxy to used based on protocol. Proxy is a dict (protocol, url)
        :keyword str proxy: will define the proxy to use all the time
        """
        await self.open()
        try:
            auto_decompress = self.session.auto_decompress  # type: ignore
        except AttributeError:
            # auto_decompress is introduced in aiohttp 3.7. We need this to handle Python 3.6.
            auto_decompress = False

        proxies = config.pop('proxies', None)
        if proxies and 'proxy' not in config:
            # aiohttp needs a single proxy, so iterating until we found the right protocol

            # Sort by longest string first, so "http" is not used for "https" ;-)
            for protocol in sorted(proxies.keys(), reverse=True):
                if request.url.startswith(protocol):
                    config['proxy'] = proxies[protocol]
                    break
        response: Optional["HTTPResponseType"] = None
        config['ssl'] = self._build_ssl_config(
            cert=config.pop('connection_cert', self.connection_config.cert),
            verify=config.pop('connection_verify', self.connection_config.verify)
        )
        # If we know for sure there is not body, disable "auto content type"
        # Otherwise, aiohttp will send "application/octect-stream" even for empty POST request
        # and that break services like storage signature
        if not request.data and not request.files:
            config['skip_auto_headers'] = ['Content-Type']
        try:
            stream_response = config.pop("stream", False)
            timeout = config.pop('connection_timeout', self.connection_config.timeout)
            read_timeout = config.pop('read_timeout', self.connection_config.read_timeout)
            socket_timeout = aiohttp.ClientTimeout(sock_connect=timeout, sock_read=read_timeout)
            result = await self.session.request(  # type: ignore
                request.method,
                request.url,
                headers=request.headers,
                data=self._get_request_data(request),
                timeout=socket_timeout,
                allow_redirects=False,
                **config
            )
            if _is_rest(request):
                from azure.core.rest._aiohttp import RestAioHttpTransportResponse
                response = RestAioHttpTransportResponse(
                    request=request,
                    internal_response=result,
                    block_size=self.connection_config.data_block_size,
                    decompress=not auto_decompress
                )
                if not stream_response:
                    await _handle_no_stream_rest_response(response)
            else:
                response = AioHttpTransportResponse(request, result,
                                                    self.connection_config.data_block_size,
                                                    decompress=not auto_decompress)
                if not stream_response:
                    await response.load_body()
        except aiohttp.client_exceptions.ClientResponseError as err:
            raise ServiceResponseError(err, error=err) from err
        except aiohttp.client_exceptions.ClientError as err:
            raise ServiceRequestError(err, error=err) from err
        except asyncio.TimeoutError as err:
            raise ServiceResponseError(err, error=err) from err
        return response
class AioHttpStreamDownloadGenerator(AsyncIterator):
    """Streams the response body data.

    :param pipeline: The pipeline object
    :param response: The client response object.
    :param bool decompress: If True which is default, will attempt to decode the body based
        on the *content-encoding* header.
    """
    def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, *, decompress=True) -> None:
        self.pipeline = pipeline
        self.request = response.request
        self.response = response
        self.block_size = response.block_size
        self._decompress = decompress
        internal_response = response.internal_response
        # Missing/absent Content-Length reported as 0.
        self.content_length = int(internal_response.headers.get('Content-Length', 0))
        # Lazily-created zlib decompressor, shared across chunks so state carries over.
        self._decompressor = None

    def __len__(self):
        return self.content_length

    async def __anext__(self):
        internal_response = self.response.internal_response
        try:
            chunk = await internal_response.content.read(self.block_size)
            if not chunk:
                raise _ResponseStopIteration()
            if not self._decompress:
                return chunk
            enc = internal_response.headers.get('Content-Encoding')
            if not enc:
                return chunk
            enc = enc.lower()
            if enc in ("gzip", "deflate"):
                if not self._decompressor:
                    import zlib
                    # 16 + MAX_WBITS tells zlib to expect a gzip header.
                    zlib_mode = 16 + zlib.MAX_WBITS if enc == "gzip" else zlib.MAX_WBITS
                    self._decompressor = zlib.decompressobj(wbits=zlib_mode)
                chunk = self._decompressor.decompress(chunk)
            return chunk
        except _ResponseStopIteration:
            internal_response.close()
            raise StopAsyncIteration()
        except aiohttp.client_exceptions.ClientPayloadError as err:
            # This is the case that server closes connection before we finish the reading. aiohttp library
            # raises ClientPayloadError.
            _LOGGER.warning("Incomplete download: %s", err)
            internal_response.close()
            # `from err` added for explicit exception chaining, consistent with send().
            raise IncompleteReadError(err, error=err) from err
        except Exception as err:
            _LOGGER.warning("Unable to stream download: %s", err)
            internal_response.close()
            raise
class AioHttpTransportResponse(AsyncHttpResponse):
    """Methods for accessing response body data.

    :param request: The HttpRequest object
    :type request: ~azure.core.pipeline.transport.HttpRequest
    :param aiohttp_response: Returned from ClientSession.request().
    :type aiohttp_response: aiohttp.ClientResponse object
    :param block_size: block size of data sent over connection.
    :type block_size: int
    :param bool decompress: If True which is default, will attempt to decode the body based
        on the *content-encoding* header.
    """
    def __init__(self, request: HttpRequest,
                 aiohttp_response: aiohttp.ClientResponse,
                 block_size=None, *, decompress=True) -> None:
        super(AioHttpTransportResponse, self).__init__(request, aiohttp_response, block_size=block_size)
        # https://aiohttp.readthedocs.io/en/stable/client_reference.html#aiohttp.ClientResponse
        self.status_code = aiohttp_response.status
        self.headers = CIMultiDict(aiohttp_response.headers)
        self.reason = aiohttp_response.reason
        self.content_type = aiohttp_response.headers.get('content-type')
        # Raw body bytes, populated by load_body(); None until loaded.
        self._content = None
        # Tracks whether _content has already been decompressed (see _aiohttp_body_helper).
        self._decompressed_content = False
        self._decompress = decompress

    def body(self) -> bytes:
        """Return the whole body as bytes in memory.
        """
        return _aiohttp_body_helper(self)

    def text(self, encoding: Optional[str] = None) -> str:
        """Return the whole body as a string.

        If encoding is not provided, rely on aiohttp auto-detection.

        :param str encoding: The encoding to apply.
        """
        # super().text detects charset based on self._content() which is compressed
        # implement the decoding explicitly here
        body = self.body()

        ctype = self.headers.get(aiohttp.hdrs.CONTENT_TYPE, "").lower()
        mimetype = aiohttp.helpers.parse_mimetype(ctype)

        if not encoding:
            # extract encoding from mimetype, if caller does not specify
            encoding = mimetype.parameters.get("charset")
        if encoding:
            try:
                codecs.lookup(encoding)
            except LookupError:
                encoding = None
        if not encoding:
            if mimetype.type == "application" and (
                    mimetype.subtype == "json" or mimetype.subtype == "rdap"
            ):
                # RFC 7159 states that the default encoding is UTF-8.
                # RFC 7483 defines application/rdap+json
                encoding = "utf-8"
            elif body is None:
                raise RuntimeError(
                    "Cannot guess the encoding of a not yet read body"
                )
            else:
                # Fall through the available charset detectors in preference order.
                try:
                    import cchardet as chardet
                except ImportError:  # pragma: no cover
                    try:
                        import chardet  # type: ignore
                    except ImportError:  # pragma: no cover
                        import charset_normalizer as chardet  # type: ignore[no-redef]
                encoding = chardet.detect(body)["encoding"]
        if encoding == "utf-8" or encoding is None:
            # utf-8-sig transparently strips a BOM if one is present.
            encoding = "utf-8-sig"

        return body.decode(encoding)

    async def load_body(self) -> None:
        """Load in memory the body, so it could be accessible from sync methods."""
        try:
            self._content = await self.internal_response.read()
        except aiohttp.client_exceptions.ClientPayloadError as err:
            # This is the case that server closes connection before we finish the reading. aiohttp library
            # raises ClientPayloadError.
            # `from err` added for explicit exception chaining, consistent with send().
            raise IncompleteReadError(err, error=err) from err

    def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:
        """Generator for streaming response body data.

        :param pipeline: The pipeline object
        :type pipeline: azure.core.pipeline.Pipeline
        :keyword bool decompress: If True which is default, will attempt to decode the body based
            on the *content-encoding* header.
        """
        return AioHttpStreamDownloadGenerator(pipeline, self, **kwargs)

    def __getstate__(self):
        # Be sure body is loaded in memory, otherwise not pickable and let it throw
        self.body()

        state = self.__dict__.copy()
        # Remove the unpicklable entries.
        state['internal_response'] = None  # aiohttp response are not pickable (see headers comments)
        state['headers'] = CIMultiDict(self.headers)  # MultiDictProxy is not pickable
        return state