# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from typing import Any, Optional, AsyncIterator as AsyncIteratorType
from collections.abc import AsyncIterator
import logging
import asyncio
import aiohttp
from requests.exceptions import (
ChunkedEncodingError,
StreamConsumedError)
from azure.core.configuration import ConnectionConfiguration
from azure.core.exceptions import ServiceRequestError, ServiceResponseError, AzureError
from azure.core.pipeline import Pipeline
from ._base import HttpRequest
from ._base_async import (
AsyncHttpTransport,
AsyncHttpResponse,
_ResponseStopIteration)
# Default chunk size (bytes) for iterating response content.
# Matching requests, because why not?
CONTENT_CHUNK_SIZE = 10 * 1024
# Module-level logger named after this module.
_LOGGER = logging.getLogger(__name__)
class AioHttpTransport(AsyncHttpTransport):
    """AioHttp HTTP sender implementation.

    Fully asynchronous implementation using the aiohttp library.

    :param session: The client session.
    :param loop: The event loop.
    :param bool session_owner: Session owner. Defaults True.

    :keyword bool use_env_settings: Uses proxy settings from environment. Defaults to True.

    .. admonition:: Example:

        .. literalinclude:: ../samples/test_example_async.py
            :start-after: [START aiohttp]
            :end-before: [END aiohttp]
            :language: python
            :dedent: 4
            :caption: Asynchronous transport with aiohttp.
    """

    def __init__(self, *, session=None, loop=None, session_owner=True, **kwargs):
        self._loop = loop
        self._session_owner = session_owner
        self.session = session
        # Pop the transport-only flag *before* building ConnectionConfiguration,
        # so it is not forwarded along with the connection keyword arguments.
        self._use_env_settings = kwargs.pop('use_env_settings', True)
        self.connection_config = ConnectionConfiguration(**kwargs)

    async def __aenter__(self):
        await self.open()
        return self

    async def __aexit__(self, *args):  # pylint: disable=arguments-differ
        await self.close()

    async def open(self):
        """Opens the connection.

        Creates the aiohttp ClientSession lazily on first use when this
        transport owns the session. A DummyCookieJar is used so cookies are
        never stored or replayed by the transport.
        """
        if not self.session and self._session_owner:
            jar = aiohttp.DummyCookieJar()
            self.session = aiohttp.ClientSession(
                loop=self._loop,
                trust_env=self._use_env_settings,
                cookie_jar=jar
            )
        if self.session is not None:
            await self.session.__aenter__()

    async def close(self):
        """Closes the connection.

        Only closes the underlying session when this transport owns it;
        an externally supplied session is left untouched.
        """
        if self._session_owner and self.session:
            await self.session.close()
            self._session_owner = False
            self.session = None

    def _build_ssl_config(self, cert, verify):  # pylint: disable=no-self-use
        """Build the SSL configuration to pass to aiohttp.

        :param cert: client certificate chain (passed to load_cert_chain), or None.
        :param verify: True/False to enable/disable verification, or a CA bundle path.
        :return: an ssl.SSLContext when a custom context is needed, otherwise
         the boolean ``verify`` value unchanged.
        """
        ssl_ctx = None
        if cert or verify not in (True, False):
            import ssl
            if verify not in (True, False):
                # A string/path value means "verify against this CA bundle".
                ssl_ctx = ssl.create_default_context(cafile=verify)
            else:
                ssl_ctx = ssl.create_default_context()
            if cert:
                ssl_ctx.load_cert_chain(*cert)
            return ssl_ctx
        return verify

    def _get_request_data(self, request):  # pylint: disable=no-self-use
        """Get the request data to hand to aiohttp.

        :param request: the HttpRequest object.
        :return: an aiohttp.FormData for multipart requests, otherwise the raw body.
        :raises ValueError: if a files entry is not a (filename, data[, content_type]) tuple.
        """
        if request.files:
            form_data = aiohttp.FormData()
            for form_file, data in request.files.items():
                content_type = data[2] if len(data) > 2 else None
                try:
                    form_data.add_field(form_file, data[1], filename=data[0], content_type=content_type)
                except IndexError as err:
                    # Chain the original IndexError for easier debugging.
                    raise ValueError("Invalid formdata formatting: {}".format(data)) from err
            return form_data
        return request.data

    async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpResponse]:
        """Send the request using this HTTP sender.

        Will pre-load the body into memory to be available with a sync method.
        Pass stream=True to avoid this behavior.

        :param request: The HttpRequest object
        :type request: ~azure.core.pipeline.transport.HttpRequest
        :param config: Any keyword arguments
        :return: The AsyncHttpResponse
        :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse

        :keyword bool stream: Defaults to False.
        :keyword dict proxies: dict of proxy to used based on protocol. Proxy is a dict (protocol, url)
        :keyword str proxy: will define the proxy to use all the time
        """
        await self.open()
        proxies = config.pop('proxies', None)
        if proxies and 'proxy' not in config:
            # aiohttp needs a single proxy, so iterating until we found the right protocol
            # Sort by longest string first, so "http" is not used for "https" ;-)
            for protocol in sorted(proxies.keys(), reverse=True):
                if request.url.startswith(protocol):
                    config['proxy'] = proxies[protocol]
                    break

        error = None  # type: Optional[AzureError]
        response = None
        config['ssl'] = self._build_ssl_config(
            cert=config.pop('connection_cert', self.connection_config.cert),
            verify=config.pop('connection_verify', self.connection_config.verify)
        )
        # If we know for sure there is no body, disable "auto content type"
        # Otherwise, aiohttp will send "application/octect-stream" even for empty POST request
        # and that break services like storage signature
        if not request.data and not request.files:
            config['skip_auto_headers'] = ['Content-Type']
        try:
            stream_response = config.pop("stream", False)
            result = await self.session.request(
                request.method,
                request.url,
                headers=request.headers,
                data=self._get_request_data(request),
                timeout=config.pop('connection_timeout', self.connection_config.timeout),
                allow_redirects=False,
                **config
            )
            response = AioHttpTransportResponse(request, result, self.connection_config.data_block_size)
            if not stream_response:
                # Pre-load the body so sync accessors (body()) work afterwards.
                await response.load_body()
        except aiohttp.client_exceptions.ClientConnectorError as err:
            error = ServiceRequestError(err, error=err)
        except asyncio.TimeoutError as err:
            error = ServiceResponseError(err, error=err)
        if error:
            raise error
        return response
class AioHttpStreamDownloadGenerator(AsyncIterator):
    """Streams the response body data.

    :param pipeline: The pipeline object
    :param response: The client response object.
    :param block_size: block size of data sent over connection.
    :type block_size: int
    """
    def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse) -> None:
        self.pipeline = pipeline
        self.request = response.request
        self.response = response
        self.block_size = response.block_size
        self.content_length = int(response.internal_response.headers.get('Content-Length', 0))
        # Bytes handed to the consumer so far; used to build the Range header
        # when resuming after a dropped connection.
        self.downloaded = 0

    def __len__(self):
        return self.content_length

    async def __anext__(self):
        retry_active = True
        retry_total = 3
        # NOTE(review): asyncio.sleep takes seconds, so this waits 1000s
        # between retries — presumably milliseconds were intended; confirm.
        retry_interval = 1000
        while retry_active:
            try:
                chunk = await self.response.internal_response.content.read(self.block_size)
                if not chunk:
                    raise _ResponseStopIteration()
                # Count the bytes actually read: the final chunk is usually
                # shorter than block_size, and an inflated counter would make
                # the resume Range header skip data.
                self.downloaded += len(chunk)
                return chunk
            except _ResponseStopIteration:
                self.response.internal_response.close()
                raise StopAsyncIteration()
            except (ChunkedEncodingError, ConnectionError):
                retry_total -= 1
                if retry_total <= 0:
                    # Retries exhausted: propagate the connection error instead
                    # of silently returning None from __anext__.
                    raise
                await asyncio.sleep(retry_interval)
                headers = {'range': 'bytes=' + str(self.downloaded) + '-'}
                # NOTE(review): pipeline.run is not awaited here; if this is an
                # async pipeline it returns a coroutine — confirm against caller.
                resp = self.pipeline.run(self.request, stream=True, headers=headers)
                if resp.status_code == 416:
                    raise
                chunk = await self.response.internal_response.content.read(self.block_size)
                if not chunk:
                    # An async iterator signals completion with
                    # StopAsyncIteration; StopIteration would surface as a
                    # RuntimeError out of the coroutine.
                    raise StopAsyncIteration()
                self.downloaded += len(chunk)
                return chunk
            except StreamConsumedError:
                raise
            except Exception as err:
                _LOGGER.warning("Unable to stream download: %s", err)
                self.response.internal_response.close()
                raise
class AioHttpTransportResponse(AsyncHttpResponse):
    """Methods for accessing response body data.

    :param request: The HttpRequest object
    :type request: ~azure.core.pipeline.transport.HttpRequest
    :param aiohttp_response: Returned from ClientSession.request().
    :type aiohttp_response: aiohttp.ClientResponse object
    :param block_size: block size of data sent over connection.
    :type block_size: int
    """
    def __init__(self, request: HttpRequest, aiohttp_response: aiohttp.ClientResponse, block_size=None) -> None:
        super(AioHttpTransportResponse, self).__init__(request, aiohttp_response, block_size=block_size)
        # https://aiohttp.readthedocs.io/en/stable/client_reference.html#aiohttp.ClientResponse
        self.status_code = aiohttp_response.status
        self.headers = aiohttp_response.headers
        self.reason = aiohttp_response.reason
        self.content_type = aiohttp_response.headers.get('content-type')
        # Populated by load_body(); stays None for streamed responses.
        self._body = None

    def body(self) -> bytes:
        """Return the whole body as bytes in memory.

        :raises ValueError: if load_body() has not been awaited yet
         (e.g. the request was made with stream=True).
        """
        if self._body is None:
            raise ValueError("Body is not available. Call async method load_body, or do your call with stream=False.")
        return self._body

    async def load_body(self) -> None:
        """Load in memory the body, so it could be accessible from sync methods."""
        self._body = await self.internal_response.read()

    def stream_download(self, pipeline) -> AsyncIteratorType[bytes]:
        """Generator for streaming response body data.

        :param pipeline: The pipeline object
        :type pipeline: azure.core.pipeline
        """
        return AioHttpStreamDownloadGenerator(pipeline, self)