# Source code for azure.core.pipeline.transport._aiohttp

# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from typing import Any, Optional, AsyncIterator as AsyncIteratorType
from collections.abc import AsyncIterator

import logging
import asyncio
import aiohttp

from requests.exceptions import (
    ChunkedEncodingError,
    StreamConsumedError)

from azure.core.configuration import ConnectionConfiguration
from azure.core.exceptions import ServiceRequestError, ServiceResponseError, AzureError
from azure.core.pipeline import Pipeline

from ._base import HttpRequest
from ._base_async import (
    AsyncHttpTransport,
    AsyncHttpResponse,
    _ResponseStopIteration)

# Chunk size in bytes (10 KiB); same value as the requests library's CONTENT_CHUNK_SIZE.
CONTENT_CHUNK_SIZE = 1024 * 10

# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)


class AioHttpTransport(AsyncHttpTransport):
    """AioHttp HTTP sender implementation.

    Fully asynchronous implementation using the aiohttp library.

    :param session: The client session.
    :param loop: The event loop.
    :param bool session_owner: Session owner. Defaults True.

    :keyword bool use_env_settings: Uses proxy settings from environment. Defaults to True.

    .. admonition:: Example:

        .. literalinclude:: ../samples/test_example_async.py
            :start-after: [START aiohttp]
            :end-before: [END aiohttp]
            :language: python
            :dedent: 4
            :caption: Asynchronous transport with aiohttp.
    """

    def __init__(self, *, session=None, loop=None, session_owner=True, **kwargs):
        self._loop = loop
        self._session_owner = session_owner
        self.session = session
        # All remaining keyword arguments (including use_env_settings) are
        # forwarded to ConnectionConfiguration; use_env_settings is read afterwards.
        self.connection_config = ConnectionConfiguration(**kwargs)
        self._use_env_settings = kwargs.pop('use_env_settings', True)

    async def __aenter__(self):
        await self.open()
        return self

    async def __aexit__(self, *args):  # pylint: disable=arguments-differ
        await self.close()

    async def open(self):
        """Opens the connection.

        If no session was supplied and this transport owns its session, an
        ``aiohttp.ClientSession`` is created with a ``DummyCookieJar`` and
        ``trust_env`` set from ``use_env_settings``. Whatever session is set is
        then entered via its ``__aenter__``.
        """
        if not self.session and self._session_owner:
            jar = aiohttp.DummyCookieJar()
            self.session = aiohttp.ClientSession(
                loop=self._loop,
                trust_env=self._use_env_settings,
                cookie_jar=jar
            )
        if self.session is not None:
            await self.session.__aenter__()

    async def close(self):
        """Closes the connection.

        Only a session owned by this transport is closed. Ownership is then
        dropped and ``session`` reset to None, so a later ``open()`` will not
        create a new session.
        """
        if self._session_owner and self.session:
            await self.session.close()
            self._session_owner = False
            self.session = None

    def _build_ssl_config(self, cert, verify):  # pylint: disable=no-self-use
        """Translate ``cert``/``verify`` settings into aiohttp's ``ssl`` argument.

        :param cert: Client certificate. Either a single path to a file holding
         both certificate and private key, or a (certfile, keyfile) tuple/list.
         Falsy means no client certificate.
        :param verify: True/False to enable/disable server certificate
         verification, or any other value (e.g. a path), which is used as the
         CA bundle (``cafile``).
        :return: An ``ssl.SSLContext`` when a custom context is needed,
         otherwise ``verify`` unchanged.
        """
        if not cert and verify in (True, False):
            # No custom context needed; aiohttp accepts a bool for ``ssl``.
            return verify

        import ssl
        if verify not in (True, False):
            ssl_ctx = ssl.create_default_context(cafile=verify)
        else:
            ssl_ctx = ssl.create_default_context()
            if not verify:
                # Only a client cert forced us to build a context. Keep honoring
                # verify=False, which create_default_context() would otherwise
                # override (it enables hostname checking and CERT_REQUIRED).
                # check_hostname must be disabled before verify_mode is lowered.
                ssl_ctx.check_hostname = False
                ssl_ctx.verify_mode = ssl.CERT_NONE
        if cert:
            if isinstance(cert, (tuple, list)):
                ssl_ctx.load_cert_chain(*cert)
            else:
                # Single file containing both certificate and private key;
                # unpacking it with * would split a path string into characters.
                ssl_ctx.load_cert_chain(cert)
        return ssl_ctx

    def _get_request_data(self, request):  # pylint: disable=no-self-use
        """Build the body to pass to aiohttp for ``request``.

        :param request: The HttpRequest. When ``request.files`` is set, each
         value is indexed as (filename, content[, content_type]).
        :return: An ``aiohttp.FormData`` for file uploads, else ``request.data``.
        :raises ValueError: If a files entry has too few elements.
        """
        if request.files:
            form_data = aiohttp.FormData()
            for form_file, data in request.files.items():
                content_type = data[2] if len(data) > 2 else None
                try:
                    form_data.add_field(form_file, data[1], filename=data[0], content_type=content_type)
                except IndexError as err:
                    raise ValueError("Invalid formdata formatting: {}".format(data)) from err
            return form_data
        return request.data

    async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpResponse]:
        """Send the request using this HTTP sender.

        Will pre-load the body into memory to be available with a sync method.
        Pass stream=True to avoid this behavior.

        :param request: The HttpRequest object
        :type request: ~azure.core.pipeline.transport.HttpRequest
        :param config: Any keyword arguments
        :return: The AsyncHttpResponse
        :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse

        :keyword bool stream: Defaults to False.
        :keyword dict proxies: dict of proxies to use, keyed by protocol (protocol -> url)
        :keyword str proxy: will define the proxy to use all the time
        :raises ~azure.core.exceptions.ServiceRequestError: On aiohttp ClientConnectorError.
        :raises ~azure.core.exceptions.ServiceResponseError: On asyncio.TimeoutError.
        """
        await self.open()

        proxies = config.pop('proxies', None)
        if proxies and 'proxy' not in config:
            # aiohttp needs a single proxy, so iterate until we find the right protocol.
            # Reverse lexicographic order tries "https" before "http", so "http"
            # does not shadow "https" for an https URL.
            for protocol in sorted(proxies.keys(), reverse=True):
                if request.url.startswith(protocol):
                    config['proxy'] = proxies[protocol]
                    break

        error = None  # type: Optional[AzureError]
        response = None
        config['ssl'] = self._build_ssl_config(
            cert=config.pop('connection_cert', self.connection_config.cert),
            verify=config.pop('connection_verify', self.connection_config.verify)
        )
        # If we know for sure there is no body, disable "auto content type".
        # Otherwise, aiohttp will send "application/octet-stream" even for an empty
        # POST request, and that breaks services like storage signature.
        if not request.data and not request.files:
            config['skip_auto_headers'] = ['Content-Type']
        try:
            stream_response = config.pop("stream", False)
            timeout = config.pop('connection_timeout', self.connection_config.timeout)
            read_timeout = config.pop('read_timeout', self.connection_config.read_timeout)
            socket_timeout = aiohttp.ClientTimeout(sock_connect=timeout, sock_read=read_timeout)
            result = await self.session.request(
                request.method,
                request.url,
                headers=request.headers,
                data=self._get_request_data(request),
                timeout=socket_timeout,
                allow_redirects=False,
                **config
            )
            response = AioHttpTransportResponse(request, result, self.connection_config.data_block_size)
            if not stream_response:
                await response.load_body()
        except aiohttp.client_exceptions.ClientConnectorError as err:
            error = ServiceRequestError(err, error=err)
        except asyncio.TimeoutError as err:
            error = ServiceResponseError(err, error=err)

        if error:
            raise error
        return response
class AioHttpStreamDownloadGenerator(AsyncIterator):
    """Streams the response body data.

    :param pipeline: The pipeline object, used to re-issue a ranged request on
     connection errors.
    :param response: The response whose body is streamed. Its ``block_size``
     sets the read size and its ``internal_response`` is the aiohttp response
     being read.
    """

    def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse) -> None:
        self.pipeline = pipeline
        self.request = response.request
        self.response = response
        self.block_size = response.block_size
        self.content_length = int(response.internal_response.headers.get('Content-Length', 0))
        # Number of body bytes actually returned so far; used as the Range offset on retry.
        self.downloaded = 0

    def __len__(self):
        return self.content_length

    async def __anext__(self):
        """Return the next chunk of the body, retrying on connection errors.

        :raises StopAsyncIteration: When the body is exhausted. The underlying
         response is closed first.
        """
        retry_total = 3
        retry_interval = 1  # seconds: asyncio.sleep() takes seconds, not milliseconds
        while True:
            try:
                chunk = await self.response.internal_response.content.read(self.block_size)
                if not chunk:
                    raise _ResponseStopIteration()
                # Count bytes actually received; the last chunk may be shorter than block_size.
                self.downloaded += len(chunk)
                return chunk
            except _ResponseStopIteration:
                self.response.internal_response.close()
                raise StopAsyncIteration()
            except (ChunkedEncodingError, ConnectionError):
                retry_total -= 1
                if retry_total <= 0:
                    # Out of retries: surface the original error instead of yielding None.
                    raise
                await asyncio.sleep(retry_interval)
                headers = {'range': 'bytes=' + str(self.downloaded) + '-'}
                resp = self.pipeline.run(self.request, stream=True, headers=headers)
                if asyncio.iscoroutine(resp):
                    # An async pipeline's run() is a coroutine and must be awaited.
                    resp = await resp
                # NOTE(review): resp is treated as an HTTP response (status_code), and the
                # retried read below still comes from self.response, not from resp.
                # Confirm against the pipeline API whether self.response should be switched.
                if resp.status_code == 416:
                    raise
                # Loop back to the read so that further errors are also retried.
                continue
            except StreamConsumedError:
                raise
            except Exception as err:
                _LOGGER.warning("Unable to stream download: %s", err)
                self.response.internal_response.close()
                raise
class AioHttpTransportResponse(AsyncHttpResponse):
    """Methods for accessing response body data.

    :param request: The HttpRequest object
    :type request: ~azure.core.pipeline.transport.HttpRequest
    :param aiohttp_response: Returned from ClientSession.request().
    :type aiohttp_response: aiohttp.ClientResponse object
    :param block_size: block size of data sent over connection.
    :type block_size: int
    """

    def __init__(self, request: HttpRequest, aiohttp_response: aiohttp.ClientResponse, block_size=None) -> None:
        super().__init__(request, aiohttp_response, block_size=block_size)
        # https://aiohttp.readthedocs.io/en/stable/client_reference.html#aiohttp.ClientResponse
        raw_headers = aiohttp_response.headers
        self.status_code = aiohttp_response.status
        self.headers = raw_headers
        self.reason = aiohttp_response.reason
        self.content_type = raw_headers.get('content-type')
        # Filled in by load_body(); None means the body has not been read yet.
        self._body = None

    def body(self) -> bytes:
        """Return the whole body as bytes in memory.

        :raises ValueError: If load_body() has not been awaited yet.
        """
        loaded = self._body
        if loaded is not None:
            return loaded
        raise ValueError("Body is not available. Call async method load_body, or do your call with stream=False.")

    async def load_body(self) -> None:
        """Load in memory the body, so it could be accessible from sync methods."""
        payload = await self.internal_response.read()
        self._body = payload

    def stream_download(self, pipeline) -> AsyncIteratorType[bytes]:
        """Generator for streaming response body data.

        :param pipeline: The pipeline object
        :type pipeline: azure.core.pipeline
        """
        return AioHttpStreamDownloadGenerator(pipeline=pipeline, response=self)