# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from __future__ import unicode_literals
import logging
import asyncio
import time
import functools
from typing import TYPE_CHECKING, Any, Dict, List, Callable, Optional, Union, cast
import six
from uamqp import (
authentication,
constants,
errors,
compat,
Message,
AMQPClientAsync,
)
from azure.core.credentials import AccessToken
from .._client_base import ClientBase, _generate_sas_token, _parse_conn_str
from .._utils import utc_from_timestamp
from ..exceptions import ClientClosedError, ConnectError
from .._constants import (
JWT_TOKEN_SCOPE,
MGMT_OPERATION,
MGMT_PARTITION_OPERATION,
MGMT_STATUS_CODE,
MGMT_STATUS_DESC
)
from ._connection_manager_async import get_connection_manager
from ._error_async import _handle_exception
if TYPE_CHECKING:
from azure.core.credentials import TokenCredential
try:
from typing_extensions import Protocol
except ImportError:
Protocol = object # type: ignore
_LOGGER = logging.getLogger(__name__)
class EventHubSharedKeyCredential(object):
    """The shared access key credential used for authentication.

    :param str policy: The name of the shared access policy.
    :param str key: The shared access key.
    """

    def __init__(self, policy: str, key: str) -> None:
        self.policy = policy
        self.key = key
        # Marks tokens from this credential as SAS tokens rather than JWTs.
        self.token_type = b"servicebus.windows.net:sastoken"

    async def get_token(self, *scopes: str, **kwargs: Any):  # pylint:disable=unused-argument
        """Generate a SAS token for the first scope in ``scopes``.

        :param str scopes: Token scopes; only the first one is used.
        :return: Whatever ``_generate_sas_token`` produces for that scope,
         the policy name and the key.
        :raises ValueError: If no scope is supplied.
        """
        if not scopes:
            raise ValueError("No token scope provided.")
        return _generate_sas_token(scopes[0], self.policy, self.key)
class EventHubSASTokenCredential(object):
    """Credential wrapping a pre-generated shared access signature token.

    :param str token: The shared access token string
    :param int expiry: The epoch timestamp
    """

    def __init__(self, token: str, expiry: int) -> None:
        """Store the token and its expiry.

        :param str token: The shared access token string
        :param int expiry: The epoch timestamp
        """
        self.token, self.expiry = token, expiry
        # Same token type marker as the shared-key credential uses.
        self.token_type = b"servicebus.windows.net:sastoken"

    async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:  # pylint:disable=unused-argument
        """Return the stored token and expiry as an ``AccessToken``.

        This method is automatically called when token is about to expire.
        The scopes and keyword arguments are ignored.
        """
        stored = AccessToken(self.token, self.expiry)
        return stored
class ClientBaseAsync(ClientBase):
    """Asynchronous layer on top of ``ClientBase``.

    Provides coroutine versions of authentication, retry back-off,
    management requests and connection clean-up.

    :param str fully_qualified_namespace: Forwarded to ``ClientBase``.
    :param str eventhub_name: Forwarded to ``ClientBase``.
    :param credential: Credential object exposing an awaitable ``get_token``.
    :keyword loop: Optional event loop; forwarded to the connection manager.
    """

    def __init__(
        self,
        fully_qualified_namespace: str,
        eventhub_name: str,
        credential: "TokenCredential",
        **kwargs: Any
    ) -> None:
        # "loop" is popped so it is not forwarded to ClientBase.__init__.
        self._loop = kwargs.pop("loop", None)
        super(ClientBaseAsync, self).__init__(
            fully_qualified_namespace=fully_qualified_namespace,
            eventhub_name=eventhub_name,
            credential=credential,
            **kwargs
        )
        self._conn_manager_async = get_connection_manager(loop=self._loop, **kwargs)

    def __enter__(self):
        """Reject the synchronous context-manager protocol."""
        raise TypeError(
            "Asynchronous client must be opened with async context manager."
        )

    @staticmethod
    def _from_connection_string(conn_str: str, **kwargs) -> Dict[str, Any]:
        """Translate a connection string into constructor keyword arguments.

        :param str conn_str: The connection string to parse.
        :return: ``kwargs`` extended with ``fully_qualified_namespace``,
         ``eventhub_name`` and, when the string carries either a token with
         expiry or a policy with key, a matching ``credential``.
        """
        host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, kwargs)
        kwargs["fully_qualified_namespace"] = host
        kwargs["eventhub_name"] = entity
        # A ready-made SAS token takes precedence over a policy/key pair.
        if token and token_expiry:
            kwargs["credential"] = EventHubSASTokenCredential(token, token_expiry)
        elif policy and key:
            kwargs["credential"] = EventHubSharedKeyCredential(policy, key)
        return kwargs

    async def _create_auth_async(self) -> authentication.JWTTokenAsync:
        """Create an ~uamqp.authentication.JWTTokenAsync instance to authenticate
        the session.

        Credentials without a ``token_type`` attribute are treated as ``b"jwt"``.
        SAS-token credentials are asked for a token scoped to the auth URI and
        the token is fetched before returning; other credentials are scoped to
        ``JWT_TOKEN_SCOPE`` and the token is not fetched here.
        """
        try:
            # ignore mypy's warning because token_type is Optional
            token_type = self._credential.token_type  # type: ignore
        except AttributeError:
            token_type = b"jwt"

        is_sas_token = token_type == b"servicebus.windows.net:sastoken"
        scope = self._auth_uri if is_sas_token else JWT_TOKEN_SCOPE
        auth = authentication.JWTTokenAsync(
            self._auth_uri,
            self._auth_uri,
            functools.partial(self._credential.get_token, scope),
            token_type=token_type,
            timeout=self._config.auth_timeout,
            http_proxy=self._config.http_proxy,
            transport_type=self._config.transport_type,
            custom_endpoint_hostname=self._config.custom_endpoint_hostname,
            port=self._config.connection_port,
            verify=self._config.connection_verify
        )
        if is_sas_token:
            await auth.update_token()
        return auth

    async def _close_connection_async(self) -> None:
        """Ask the connection manager to reset the connection if it is broken."""
        await self._conn_manager_async.reset_connection_if_broken()

    async def _backoff_async(
        self,
        retried_times: int,
        last_exception: Exception,
        timeout_time: Optional[float] = None,
        entity_name: Optional[str] = None,
    ) -> None:
        """Sleep before the next retry, or give up by raising ``last_exception``.

        The delay is ``backoff_factor * 2 ** retried_times``.

        :param int retried_times: Number of retries already performed.
        :param Exception last_exception: Raised when no further wait is allowed.
        :param float timeout_time: Optional absolute deadline (``time.time()`` based).
        :param str entity_name: Name used in log lines; defaults to the container id.
        :raises: ``last_exception`` when the delay exceeds ``backoff_max`` or
         would end after ``timeout_time``.
        """
        entity_name = entity_name or self._container_id
        backoff = self._config.backoff_factor * 2 ** retried_times
        within_max = backoff <= self._config.backoff_max
        within_deadline = timeout_time is None or time.time() + backoff <= timeout_time
        if not (within_max and within_deadline):
            _LOGGER.info(
                "%r operation has timed out. Last exception before timeout is (%r)",
                entity_name,
                last_exception,
            )
            raise last_exception

        # No ``loop`` argument: asyncio.sleep no longer accepts it on
        # Python 3.10+, and the running loop is used instead.
        await asyncio.sleep(backoff)
        _LOGGER.info(
            "%r has an exception (%r). Retrying...",
            entity_name,
            last_exception,
        )

    async def _management_request_async(self, mgmt_msg: Message, op_type: bytes) -> Any:
        """Send a management request, retrying on failure.

        Each attempt builds fresh auth and a fresh ``AMQPClientAsync`` on a
        connection obtained from the connection manager; the client is always
        closed afterwards.

        :param mgmt_msg: The request message; its ``security_token`` application
         property is overwritten with the current auth token on every attempt.
        :param bytes op_type: The management operation type.
        :return: The response message when its status code is below 400.
        :raises: The exception returned by ``_handle_exception`` once retries
         are exhausted, or earlier if ``_backoff_async`` refuses to wait.
        """
        retried_times = 0
        last_exception = None
        while retried_times <= self._config.max_retries:
            mgmt_auth = await self._create_auth_async()
            mgmt_client = AMQPClientAsync(
                self._mgmt_target, auth=mgmt_auth, debug=self._config.network_tracing
            )
            try:
                conn = await self._conn_manager_async.get_connection(
                    self._address.hostname, mgmt_auth
                )
                mgmt_msg.application_properties["security_token"] = mgmt_auth.token
                await mgmt_client.open_async(connection=conn)
                response = await mgmt_client.mgmt_request_async(
                    mgmt_msg,
                    constants.READ_OPERATION,
                    op_type=op_type,
                    status_code_field=MGMT_STATUS_CODE,
                    description_fields=MGMT_STATUS_DESC,
                )
                status_code = int(response.application_properties[MGMT_STATUS_CODE])
                description = response.application_properties.get(MGMT_STATUS_DESC)  # type: Optional[Union[str, bytes]]
                if description and isinstance(description, six.binary_type):
                    description = description.decode('utf-8')
                if status_code < 400:
                    return response
                # Map failing status codes onto exceptions; they are raised
                # inside the try so the generic handler below can retry them.
                if status_code == 401:
                    raise errors.AuthenticationException(
                        "Management authentication failed. Status code: {}, Description: {!r}".format(
                            status_code,
                            description
                        )
                    )
                if status_code == 404:
                    raise ConnectError(
                        "Management connection failed. Status code: {}, Description: {!r}".format(
                            status_code,
                            description
                        )
                    )
                raise errors.AMQPConnectionError(
                    "Management request error. Status code: {}, Description: {!r}".format(
                        status_code,
                        description
                    )
                )
            except asyncio.CancelledError:  # pylint: disable=try-except-raise
                # Cancellation must never be swallowed by the retry logic.
                raise
            except Exception as exception:  # pylint:disable=broad-except
                last_exception = await _handle_exception(exception, self)
                await self._backoff_async(
                    retried_times=retried_times, last_exception=last_exception
                )
                retried_times += 1
                if retried_times > self._config.max_retries:
                    _LOGGER.info(
                        "%r returns an exception %r", self._container_id, last_exception
                    )
                    raise last_exception
            finally:
                await mgmt_client.close_async()

    async def _get_eventhub_properties_async(self) -> Dict[str, Any]:
        """Fetch event hub properties through a management request.

        :return: A dict with ``eventhub_name``, ``created_at`` and
         ``partition_ids``; empty when the response carries no data.
        """
        mgmt_msg = Message(application_properties={"name": self.eventhub_name})
        response = await self._management_request_async(
            mgmt_msg, op_type=MGMT_OPERATION
        )
        output = {}
        eh_info = response.get_data()  # type: Dict[bytes, Any]
        if eh_info:
            output["eventhub_name"] = eh_info[b"name"].decode("utf-8")
            # The raw value is divided by 1000 before conversion to a datetime.
            output["created_at"] = utc_from_timestamp(
                float(eh_info[b"created_at"]) / 1000
            )
            output["partition_ids"] = [
                p.decode("utf-8") for p in eh_info[b"partition_ids"]
            ]
        return output

    async def _get_partition_ids_async(self) -> List[str]:
        """Return the ``partition_ids`` entry of the event hub properties."""
        return (await self._get_eventhub_properties_async())["partition_ids"]

    async def _get_partition_properties_async(
        self, partition_id: str
    ) -> Dict[str, Any]:
        """Fetch the properties of one partition through a management request.

        :param str partition_id: The partition to query.
        :return: A dict with ``eventhub_name``, ``id``,
         ``beginning_sequence_number``, ``last_enqueued_sequence_number``,
         ``last_enqueued_offset``, ``is_empty`` and ``last_enqueued_time_utc``;
         empty when the response carries no data.
        """
        mgmt_msg = Message(
            application_properties={
                "name": self.eventhub_name,
                "partition": partition_id,
            }
        )
        response = await self._management_request_async(
            mgmt_msg, op_type=MGMT_PARTITION_OPERATION
        )
        partition_info = response.get_data()  # type: Dict[bytes, Union[bytes, int]]
        output = {}  # type: Dict[str, Any]
        if partition_info:
            output["eventhub_name"] = cast(bytes, partition_info[b"name"]).decode(
                "utf-8"
            )
            output["id"] = cast(bytes, partition_info[b"partition"]).decode("utf-8")
            output["beginning_sequence_number"] = cast(
                int, partition_info[b"begin_sequence_number"]
            )
            output["last_enqueued_sequence_number"] = cast(
                int, partition_info[b"last_enqueued_sequence_number"]
            )
            output["last_enqueued_offset"] = cast(
                bytes, partition_info[b"last_enqueued_offset"]
            ).decode("utf-8")
            output["is_empty"] = partition_info[b"is_partition_empty"]
            output["last_enqueued_time_utc"] = utc_from_timestamp(
                float(cast(int, partition_info[b"last_enqueued_time_utc"]) / 1000)
            )
        return output

    async def _close_async(self) -> None:
        """Close the connection held by the connection manager."""
        await self._conn_manager_async.close_connection()
if TYPE_CHECKING:

    class AbstractConsumerProducer(Protocol):
        """Members that ConsumerProducerMixin expects on the class it is mixed into.

        Declared only under ``TYPE_CHECKING``, so it exists for static type
        checkers and is never created at runtime.
        """

        @property
        def _name(self) -> str:
            """Name of the consumer or producer."""

        @_name.setter
        def _name(self, value):
            ...

        @property
        def _client(self) -> ClientBaseAsync:
            """The instance of EventHubConsumerClient or EventHubProducerClient."""

        @_client.setter
        def _client(self, value):
            ...

        @property
        def _handler(self) -> AMQPClientAsync:
            """The instance of SendClientAsync or ReceiveClientAsync."""

        @property
        def _loop(self) -> asyncio.AbstractEventLoop:
            """The event loop that users pass in to wrap sync calls to the async API.

            It is further passed to uamqp APIs.
            """

        @_loop.setter
        def _loop(self, value):
            ...

        @property
        def running(self) -> bool:
            """Whether the consumer or producer is running."""

        @running.setter
        def running(self, value):
            ...

        def _create_handler(self, auth: authentication.JWTTokenAsync) -> None:
            ...

    _MIXIN_BASE = AbstractConsumerProducer
else:
    # At runtime the mixin derives from plain ``object``.
    _MIXIN_BASE = object
class ConsumerProducerMixin(_MIXIN_BASE):
    """Shared open/close/retry behaviour for consumer and producer classes.

    The host class is expected to provide ``_name``, ``_client``, ``_handler``,
    ``running``, ``closed`` and ``_create_handler``.
    """

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    def _check_closed(self) -> None:
        """Raise ``ClientClosedError`` if this instance has been closed."""
        if self.closed:
            raise ClientClosedError(
                "{} has been closed. Please create a new one to handle event data.".format(
                    self._name
                )
            )

    async def _open(self) -> None:
        """
        Open the handler on a connection obtained from the client's
        connection manager. Does nothing when already running.

        Any existing handler is closed first; a new one is created with fresh
        auth, opened, and polled until it reports ready.
        """
        # pylint: disable=protected-access,line-too-long
        if self.running:
            return
        if self._handler:
            await self._handler.close_async()
        auth = await self._client._create_auth_async()
        self._create_handler(auth)
        connection = await self._client._conn_manager_async.get_connection(
            self._client._address.hostname, auth
        )
        await self._handler.open_async(connection=connection)
        while not await self._handler.client_ready_async():
            # No ``loop`` argument: asyncio.sleep no longer accepts it on
            # Python 3.10+, and the running loop is used instead.
            await asyncio.sleep(0.05)
        # Fall back to the library default when the peer advertises no limit.
        self._max_message_size_on_link = (
            self._handler.message_handler._link.peer_max_message_size
            or constants.MAX_MESSAGE_LENGTH_BYTES
        )
        self.running = True

    async def _close_handler_async(self) -> None:
        """Close the handler, if any, and mark this instance as not running."""
        if self._handler:
            # close the link (shared connection) or connection (not shared)
            await self._handler.close_async()
        self.running = False

    async def _close_connection_async(self) -> None:
        """Close the handler, then reset the client's connection if it is broken."""
        await self._close_handler_async()
        await self._client._conn_manager_async.reset_connection_if_broken()  # pylint:disable=protected-access

    async def _handle_exception(self, exception: Exception) -> Exception:
        """Pass ``exception`` through the module-level ``_handle_exception``.

        A timeout that occurs while not yet running is first replaced by an
        ``AuthenticationException("Authorization timeout.")``.

        :return: The exception returned by the module-level handler.
        """
        if not self.running and isinstance(exception, compat.TimeoutException):
            exception = errors.AuthenticationException("Authorization timeout.")
        return await _handle_exception(exception, self)

    async def _do_retryable_operation(
        self,
        operation: Callable[..., Any],
        timeout: Optional[float] = None,
        **kwargs: Any
    ) -> Optional[Any]:
        """Run ``operation``, retrying with back-off when it raises.

        :param operation: Coroutine function to run.
        :param float timeout: Optional overall timeout in seconds; a falsy
         value means no deadline.
        :keyword last_exception: Initial value handed to ``operation``.
        :keyword bool operation_need_param: When true (default), ``operation``
         is called with ``timeout_time``, ``last_exception`` and the remaining
         keyword arguments; otherwise it is called with no arguments.
        :return: The result of ``operation``.
        :raises: The last handled exception once retries are exhausted, or
         earlier if the client's ``_backoff_async`` refuses to wait.
        """
        # pylint:disable=protected-access,line-too-long
        timeout_time = (time.time() + timeout) if timeout else None
        retried_times = 0
        last_exception = kwargs.pop("last_exception", None)
        operation_need_param = kwargs.pop("operation_need_param", True)
        max_retries = self._client._config.max_retries

        while retried_times <= max_retries:
            try:
                if operation_need_param:
                    return await operation(
                        timeout_time=timeout_time,
                        last_exception=last_exception,
                        **kwargs
                    )
                return await operation()
            except asyncio.CancelledError:  # pylint: disable=try-except-raise
                # Cancellation must never be swallowed by the retry logic.
                raise
            except Exception as exception:  # pylint:disable=broad-except
                last_exception = await self._handle_exception(exception)
                await self._client._backoff_async(
                    retried_times=retried_times,
                    last_exception=last_exception,
                    timeout_time=timeout_time,
                    entity_name=self._name,
                )
                retried_times += 1
                if retried_times > max_retries:
                    _LOGGER.info(
                        "%r operation has exhausted retry. Last exception: %r.",
                        self._name,
                        last_exception,
                    )
                    raise last_exception
        return None

    async def close(self) -> None:
        """
        Close down the handler. If the handler has already closed,
        this will be a no op.
        """
        await self._close_handler_async()
        self.closed = True