# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from __future__ import unicode_literals, annotations
import logging
import asyncio
import time
import functools
from typing import TYPE_CHECKING, Any, Dict, List, Callable, Optional, Union, cast
from azure.core.credentials import (
AccessToken,
AzureSasCredential,
AzureNamedKeyCredential,
)
from .._client_base import (
ClientBase,
_generate_sas_token,
_parse_conn_str,
_get_backoff_time,
)
from .._utils import utc_from_timestamp, parse_sas_credential
from ..exceptions import ClientClosedError
from .._constants import (
JWT_TOKEN_SCOPE,
MGMT_OPERATION,
MGMT_PARTITION_OPERATION,
MGMT_STATUS_CODE,
MGMT_STATUS_DESC,
READ_OPERATION,
)
from ._async_utils import get_dict_with_loop_if_needed
from ._connection_manager_async import get_connection_manager
try:
from ._transport._uamqp_transport_async import UamqpTransportAsync
except ImportError:
UamqpTransportAsync = None # type: ignore
from ._transport._pyamqp_transport_async import PyamqpTransportAsync
if TYPE_CHECKING:
from .._pyamqp.message import Message
from .._pyamqp.aio import AMQPClientAsync
from .._pyamqp.aio._authentication_async import JWTTokenAuthAsync
try:
from uamqp import (
Message as uamqp_Message,
AMQPClientAsync as uamqp_AMQPClientAsync,
)
from uamqp.authentication import JWTTokenAsync as uamqp_JWTTokenAsync
except ImportError:
uamqp_Message = None
uamqp_AMQPClientAsync = None
uamqp_JWTTokenAsync = None
from azure.core.credentials_async import AsyncTokenCredential
try:
from typing_extensions import TypeAlias, Protocol
except ImportError:
Protocol = object # type: ignore
CredentialTypes: TypeAlias = Union[
"EventHubSharedKeyCredential",
AsyncTokenCredential,
AzureSasCredential,
AzureNamedKeyCredential,
]
class AbstractConsumerProducer(Protocol):
@property
def _name(self) -> str:
"""Name of the consumer or producer"""
@_name.setter
def _name(self, value):
pass
@property
def _client(self) -> ClientBaseAsync:
"""The instance of EventHubComsumerClient or EventHubProducerClient"""
@_client.setter
def _client(self, value):
pass
@property
def _handler(self) -> Union["uamqp_AMQPClientAsync", AMQPClientAsync]:
"""The instance of SendClientAsync or ReceiveClientAsync"""
@property
def _internal_kwargs(self) -> Dict[Any, Any]:
"""The dict with an event loop that users may pass in to wrap sync calls to async API.
It's furthur passed to uamqp APIs
"""
@_internal_kwargs.setter
def _internal_kwargs(self, value):
pass
@property
def running(self) -> bool:
"""Whether the consumer or producer is running"""
@running.setter
def running(self, value: bool) -> None:
pass
def _create_handler(self, auth: Union["uamqp_JWTTokenAsync", JWTTokenAuthAsync]) -> None:
pass
_MIXIN_BASE = AbstractConsumerProducer
else:
_MIXIN_BASE = object
_LOGGER = logging.getLogger(__name__)
[docs]
class EventHubSharedKeyCredential:
"""The shared access key credential used for authentication.
:param str policy: The name of the shared access policy.
:param str key: The shared access key.
"""
def __init__(self, policy: str, key: str):
self.policy = policy
self.key = key
self.token_type = b"servicebus.windows.net:sastoken"
[docs]
async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument
if not scopes:
raise ValueError("No token scope provided.")
return _generate_sas_token(scopes[0], self.policy, self.key)
class EventHubSASTokenCredential:
"""The shared access token credential used for authentication.
:param str token: The shared access token string
:param int expiry: The epoch timestamp
"""
def __init__(self, token: str, expiry: int) -> None:
"""
:param str token: The shared access token string
:param int expiry: The epoch timestamp
"""
self.token = token
self.expiry = expiry
self.token_type = b"servicebus.windows.net:sastoken"
async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument
"""
This method is automatically called when token is about to expire.
:param str scopes: The list of scopes for which the token has access.
:return: The token object
:rtype: ~azure.core.credentials.AccessToken
"""
return AccessToken(self.token, self.expiry)
class EventhubAzureNamedKeyTokenCredentialAsync: # pylint: disable=name-too-long
"""The named key credential used for authentication.
:param credential: The AzureNamedKeyCredential that should be used.
:type credential: ~azure.core.credentials.AzureNamedKeyCredential
"""
def __init__(self, azure_named_key_credential: AzureNamedKeyCredential) -> None:
self._credential: AzureNamedKeyCredential = azure_named_key_credential
self.token_type: bytes = b"servicebus.windows.net:sastoken"
async def get_token(self, *scopes, **kwargs) -> AccessToken: # pylint:disable=unused-argument
if not scopes:
raise ValueError("No token scope provided.")
name, key = self._credential.named_key
return _generate_sas_token(scopes[0], name, key)
class EventhubAzureSasTokenCredentialAsync:
"""The shared access token credential used for authentication
when AzureSasCredential is provided.
:param azure_sas_credential: The credential to be used for authentication.
:type azure_sas_credential: ~azure.core.credentials.AzureSasCredential
"""
def __init__(self, azure_sas_credential: AzureSasCredential) -> None:
self._credential = azure_sas_credential
self.token_type = b"servicebus.windows.net:sastoken"
async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument
"""
This method is automatically called when token is about to expire.
:param str scopes: The list of scopes for which the token has access.
:return: The access token.
:rtype: ~azure.core.credentials.AccessToken
"""
signature, expiry = parse_sas_credential(self._credential)
return AccessToken(signature, cast(int, expiry))
class ClientBaseAsync(ClientBase):
def __init__(
self, fully_qualified_namespace: str, eventhub_name: str, credential: "CredentialTypes", **kwargs: Any
) -> None:
self._internal_kwargs = get_dict_with_loop_if_needed(kwargs.get("loop", None))
uamqp_transport = kwargs.get("uamqp_transport", False)
if uamqp_transport and UamqpTransportAsync is None:
raise ValueError("To use the uAMQP transport, please install `uamqp>=1.6.0,<2.0.0`.")
self._amqp_transport = UamqpTransportAsync if uamqp_transport else PyamqpTransportAsync
if isinstance(credential, AzureSasCredential):
self._credential = EventhubAzureSasTokenCredentialAsync(credential) # type: ignore
elif isinstance(credential, AzureNamedKeyCredential):
self._credential = EventhubAzureNamedKeyTokenCredentialAsync(credential) # type: ignore
else:
self._credential = credential # type: ignore
super(ClientBaseAsync, self).__init__(
fully_qualified_namespace=fully_qualified_namespace,
eventhub_name=eventhub_name,
credential=self._credential,
amqp_transport=self._amqp_transport,
**kwargs,
)
kwargs["custom_endpoint_address"] = self._config.custom_endpoint_address
self._conn_manager_async = get_connection_manager(amqp_transport=self._amqp_transport, **kwargs)
def __enter__(self) -> None:
raise TypeError("Asynchronous client must be opened with async context manager.")
@staticmethod
def _from_connection_string(conn_str: str, **kwargs) -> Dict[str, Any]:
host, policy, key, entity, token, token_expiry, emulator = _parse_conn_str(conn_str, **kwargs)
kwargs["fully_qualified_namespace"] = host
kwargs["eventhub_name"] = entity
# Check if emulator is in use, unset tls if it is
if emulator:
kwargs["use_tls"] = False
if token and token_expiry:
kwargs["credential"] = EventHubSASTokenCredential(token, token_expiry)
elif policy and key:
kwargs["credential"] = EventHubSharedKeyCredential(policy, key)
return kwargs
async def _create_auth_async(
self, *, auth_uri: Optional[str] = None
) -> Union["uamqp_JWTTokenAsync", JWTTokenAuthAsync]:
"""
Create an ~uamqp.authentication.SASTokenAuthAsync instance to authenticate
the session.
:keyword auth_uri: The URI to authenticate with.
:paramtype auth_uri: str or None
:return: A JWTTokenAuthAsync instance to authenticate the session.
:rtype: ~uamqp.authentication.JWTTokenAsync or JWTTokenAuthAsync
"""
# if auth_uri is not provided, use the default hub one
entity_auth_uri = auth_uri if auth_uri else self._eventhub_auth_uri
try:
# ignore mypy's warning because token_type is Optional
token_type = self._credential.token_type # type: ignore
except AttributeError:
token_type = b"jwt"
if token_type == b"servicebus.windows.net:sastoken":
return await self._amqp_transport.create_token_auth_async(
entity_auth_uri,
functools.partial(self._credential.get_token, entity_auth_uri),
token_type=token_type,
config=self._config,
update_token=True,
)
return await self._amqp_transport.create_token_auth_async(
entity_auth_uri,
functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE),
token_type=token_type,
config=self._config,
update_token=False,
)
async def _close_connection_async(self) -> None:
await self._conn_manager_async.reset_connection_if_broken()
async def _backoff_async(
self,
retried_times: int,
last_exception: Exception,
timeout_time: Optional[float] = None,
entity_name: Optional[str] = None,
) -> None:
entity_name = entity_name or self._container_id
backoff = _get_backoff_time(
self._config.retry_mode,
self._config.backoff_factor,
self._config.backoff_max,
retried_times,
)
if backoff <= self._config.backoff_max and (
timeout_time is None or time.time() + backoff <= timeout_time
): # pylint:disable=no-else-return
await asyncio.sleep(backoff, **self._internal_kwargs)
_LOGGER.info(
"%r has an exception (%r). Retrying...",
format(entity_name),
last_exception,
)
else:
_LOGGER.info(
"%r operation has timed out. Last exception before timeout is (%r)",
entity_name,
last_exception,
)
raise last_exception
async def _management_request_async(self, mgmt_msg: Union[Message, uamqp_Message], op_type: bytes) -> Any:
retried_times = 0
last_exception = None
while retried_times <= self._config.max_retries:
mgmt_auth = await self._create_auth_async()
mgmt_client = self._amqp_transport.create_mgmt_client(
self._address, mgmt_auth=mgmt_auth, config=self._config
)
try:
conn = await self._conn_manager_async.get_connection(endpoint=self._address.hostname, auth=mgmt_auth)
await mgmt_client.open_async(connection=conn)
while not await mgmt_client.client_ready_async():
await asyncio.sleep(0.05)
cast(Dict[Union[str, bytes], Any], mgmt_msg.application_properties)["security_token"] = (
await self._amqp_transport.get_updated_token_async(mgmt_auth)
)
status_code, description, response = await self._amqp_transport.mgmt_client_request_async(
mgmt_client,
mgmt_msg,
operation=READ_OPERATION,
operation_type=op_type,
status_code_field=MGMT_STATUS_CODE,
description_fields=MGMT_STATUS_DESC,
)
status_code = int(status_code)
if description and isinstance(description, bytes):
description = description.decode("utf-8")
if status_code < 400:
return response
raise self._amqp_transport.get_error(status_code, description)
except asyncio.CancelledError: # pylint: disable=try-except-raise
raise
except Exception as exception: # pylint:disable=broad-except
# If optional dependency is not installed, do not retry.
if isinstance(exception, ImportError):
raise exception
# is_consumer=True passed in here, ALTHOUGH this method is shared by the producer and consumer.
# is_consumer will only be checked if FileNotFoundError is raised by self.mgmt_client.open() due to
# invalid/non-existent connection_verify filepath. The producer will encounter the FileNotFoundError
# when opening the SendClient, so is_consumer=True will not be passed to amqp_transport.handle_exception
# there. This is for uamqp exception parity, which raises FileNotFoundError in the consumer and
# EventHubError in the producer. TODO: Remove `is_consumer` kwarg when resolving issue #27128.
last_exception = await self._amqp_transport._handle_exception_async( # pylint: disable=protected-access
exception, self, is_consumer=True
)
await self._backoff_async(retried_times=retried_times, last_exception=last_exception)
retried_times += 1
if retried_times > self._config.max_retries:
_LOGGER.info("%r returns an exception %r", self._container_id, last_exception)
raise last_exception from None
finally:
await mgmt_client.close_async()
async def _get_eventhub_properties_async(self) -> Dict[str, Any]:
mgmt_msg = self._amqp_transport.build_message(application_properties={"name": self.eventhub_name})
response = await self._management_request_async(mgmt_msg, op_type=MGMT_OPERATION)
output = {}
eh_info: Dict[bytes, Any] = response.value
if eh_info:
output["eventhub_name"] = eh_info[b"name"].decode("utf-8")
output["created_at"] = utc_from_timestamp(float(eh_info[b"created_at"]) / 1000)
output["partition_ids"] = [p.decode("utf-8") for p in eh_info[b"partition_ids"]]
return output
async def _get_partition_ids_async(self) -> List[str]:
return (await self._get_eventhub_properties_async())["partition_ids"]
async def _get_partition_properties_async(self, partition_id: str) -> Dict[str, Any]:
mgmt_msg = self._amqp_transport.build_message(
application_properties={
"name": self.eventhub_name,
"partition": partition_id,
}
)
response = await self._management_request_async(mgmt_msg, op_type=MGMT_PARTITION_OPERATION)
partition_info: Dict[bytes, Union[bytes, int]] = response.value
output: Dict[str, Any] = {}
if partition_info:
output["eventhub_name"] = cast(bytes, partition_info[b"name"]).decode("utf-8")
output["id"] = cast(bytes, partition_info[b"partition"]).decode("utf-8")
output["beginning_sequence_number"] = cast(int, partition_info[b"begin_sequence_number"])
output["last_enqueued_sequence_number"] = cast(int, partition_info[b"last_enqueued_sequence_number"])
output["last_enqueued_offset"] = cast(bytes, partition_info[b"last_enqueued_offset"]).decode("utf-8")
output["is_empty"] = partition_info[b"is_partition_empty"]
output["last_enqueued_time_utc"] = utc_from_timestamp(
float(cast(int, partition_info[b"last_enqueued_time_utc"]) / 1000)
)
return output
async def _close_async(self) -> None:
await self._conn_manager_async.close_connection()
class ConsumerProducerMixin(_MIXIN_BASE):
async def __aenter__(self) -> ConsumerProducerMixin:
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close()
def _check_closed(self) -> None:
if self.closed:
raise ClientClosedError(
"{} has been closed. Please create a new one to handle event data.".format(self._name)
)
async def _open(self) -> None:
"""
Open the EventHubConsumer using the supplied connection.
"""
# pylint: disable=protected-access,line-too-long
if not self.running:
if self._handler:
await self._handler.close_async()
auth = await self._client._create_auth_async(auth_uri=self._client._auth_uri)
self._create_handler(auth)
conn = await self._client._conn_manager_async.get_connection(
endpoint=self._client._address.hostname, auth=auth
)
await self._handler.open_async(connection=conn)
while not await self._handler.client_ready_async():
await asyncio.sleep(0.05, **self._internal_kwargs)
# pylint: disable=protected-access
self._max_message_size_on_link = (
self._client._amqp_transport.get_remote_max_message_size(self._handler)
or self._client._amqp_transport.MAX_MESSAGE_LENGTH_BYTES
)
self.running = True
async def _close_handler_async(self) -> None:
if self._handler:
# close the link (shared connection) or connection (not shared)
await self._handler.close_async()
self.running = False
async def _close_connection_async(self) -> None:
await self._close_handler_async()
await self._client._conn_manager_async.reset_connection_if_broken() # pylint:disable=protected-access
async def _handle_exception(self, exception: Exception, *, is_consumer: bool = False) -> Exception:
# pylint: disable=protected-access
exception = self._client._amqp_transport.check_timeout_exception(self, exception)
return await self._client._amqp_transport._handle_exception_async(exception, self, is_consumer=is_consumer)
async def _do_retryable_operation(
self, operation: Callable[..., Any], timeout: Optional[float] = None, **kwargs: Any
) -> Optional[Any]:
# pylint:disable=protected-access,line-too-long
timeout_time = (time.time() + timeout) if timeout else None
retried_times = 0
last_exception = kwargs.pop("last_exception", None)
operation_need_param = kwargs.pop("operation_need_param", True)
max_retries = self._client._config.max_retries
while retried_times <= max_retries:
try:
if operation_need_param:
return await operation(timeout_time=timeout_time, last_exception=last_exception, **kwargs)
return await operation()
except asyncio.CancelledError: # pylint: disable=try-except-raise
raise
except Exception as exception: # pylint:disable=broad-except
# If optional dependency is not installed, do not retry.
if isinstance(exception, ImportError):
raise exception
last_exception = await self._handle_exception(exception)
await self._client._backoff_async(
retried_times=retried_times,
last_exception=last_exception,
timeout_time=timeout_time,
entity_name=self._name,
)
retried_times += 1
if retried_times > max_retries:
_LOGGER.info(
"%r operation has exhausted retry. Last exception: %r.",
self._name,
last_exception,
)
raise last_exception from None
return None
async def close(self) -> None:
"""
Close down the handler. If the handler has already closed,
this will be a no op.
"""
await self._close_handler_async()
self.closed = True