Source code for azure.servicebus._base_handler

# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import collections
import logging
import uuid
import time
from datetime import timedelta
from typing import cast, Optional, Tuple, TYPE_CHECKING, Dict, Any, Callable, Type

try:
    from urllib import quote_plus  # type: ignore
except ImportError:
    from urllib.parse import quote_plus

import uamqp
from uamqp import utils
from uamqp.message import MessageProperties

from ._common._configuration import Configuration
from .exceptions import (
    ServiceBusError,
    ServiceBusAuthorizationError,
    _create_servicebus_exception
)
from ._common.utils import create_properties
from ._common.constants import (
    CONTAINER_PREFIX,
    MANAGEMENT_PATH_SUFFIX,
    TOKEN_TYPE_SASTOKEN,
    MGMT_REQUEST_OP_TYPE_ENTITY_MGMT,
    ASSOCIATEDLINKPROPERTYNAME
)

if TYPE_CHECKING:
    from azure.core.credentials import TokenCredential

_AccessToken = collections.namedtuple("AccessToken", "token expires_on")
_LOGGER = logging.getLogger(__name__)


def _parse_conn_str(conn_str):
    # type: (str) -> Tuple[str, str, str, str]
    endpoint = None
    shared_access_key_name = None
    shared_access_key = None
    entity_path = None  # type: Optional[str]
    for element in conn_str.split(";"):
        key, _, value = element.partition("=")
        if key.lower() == "endpoint":
            endpoint = value.rstrip("/")
        elif key.lower() == "hostname":
            endpoint = value.rstrip("/")
        elif key.lower() == "sharedaccesskeyname":
            shared_access_key_name = value
        elif key.lower() == "sharedaccesskey":
            shared_access_key = value
        elif key.lower() == "entitypath":
            entity_path = value
    if not all([endpoint, shared_access_key_name, shared_access_key]):
        raise ValueError(
            "Invalid connection string. Should be in the format: "
            "Endpoint=sb://<FQDN>/;SharedAccessKeyName=<KeyName>;SharedAccessKey=<KeyValue>"
        )
    entity = cast(str, entity_path)
    left_slash_pos = cast(str, endpoint).find("//")
    if left_slash_pos != -1:
        host = cast(str, endpoint)[left_slash_pos + 2:]
    else:
        host = str(endpoint)
    return host, str(shared_access_key_name), str(shared_access_key), entity


def _generate_sas_token(uri, policy, key, expiry=None):
    # type: (str, str, str, Optional[timedelta]) -> _AccessToken
    """Create a shared access signiture token as a string literal.
    :returns: SAS token as string literal.
    :rtype: str
    """
    if not expiry:
        expiry = timedelta(hours=1)  # Default to 1 hour.

    abs_expiry = int(time.time()) + expiry.seconds
    encoded_uri = quote_plus(uri).encode("utf-8")  # pylint: disable=no-member
    encoded_policy = quote_plus(policy).encode("utf-8")  # pylint: disable=no-member
    encoded_key = key.encode("utf-8")

    token = utils.create_sas_token(encoded_policy, encoded_key, encoded_uri, expiry)
    return _AccessToken(token=token, expires_on=abs_expiry)


def _convert_connection_string_to_kwargs(conn_str, shared_key_credential_type, **kwargs):
    # type: (str, Type, Any) -> Dict[str, Any]
    host, policy, key, entity_in_conn_str = _parse_conn_str(conn_str)
    queue_name = kwargs.get("queue_name")
    topic_name = kwargs.get("topic_name")
    if not (queue_name or topic_name or entity_in_conn_str):
        raise ValueError("Entity name is missing. Please specify `queue_name` or `topic_name`"
                         " or use a connection string including the entity information.")

    if queue_name and topic_name:
        raise ValueError("`queue_name` and `topic_name` can not be specified simultaneously.")

    entity_in_kwargs = queue_name or topic_name
    if entity_in_conn_str and entity_in_kwargs and (entity_in_conn_str != entity_in_kwargs):
        raise ServiceBusAuthorizationError(
            "Entity names do not match, the entity name in connection string is {};"
            " the entity name in parameter is {}.".format(entity_in_conn_str, entity_in_kwargs)
        )

    kwargs["fully_qualified_namespace"] = host
    kwargs["entity_name"] = entity_in_conn_str or entity_in_kwargs
    kwargs["credential"] = shared_key_credential_type(policy, key)
    return kwargs


[docs]class ServiceBusSharedKeyCredential(object): """The shared access key credential used for authentication. :param str policy: The name of the shared access policy. :param str key: The shared access key. """ def __init__(self, policy, key): # type: (str, str) -> None self.policy = policy self.key = key self.token_type = TOKEN_TYPE_SASTOKEN
[docs] def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument # type: (str, Any) -> _AccessToken if not scopes: raise ValueError("No token scope provided.") return _generate_sas_token(scopes[0], self.policy, self.key)
class BaseHandler: # pylint:disable=too-many-instance-attributes def __init__( self, fully_qualified_namespace, entity_name, credential, **kwargs ): # type: (str, str, TokenCredential, Any) -> None self.fully_qualified_namespace = fully_qualified_namespace self._entity_name = entity_name subscription_name = kwargs.get("subscription_name") self._mgmt_target = self._entity_name + (("/Subscriptions/" + subscription_name) if subscription_name else '') self._mgmt_target = "{}{}".format(self._mgmt_target, MANAGEMENT_PATH_SUFFIX) self._credential = credential self._container_id = CONTAINER_PREFIX + str(uuid.uuid4())[:8] self._config = Configuration(**kwargs) self._running = False self._handler = None # type: uamqp.AMQPClient self._auth_uri = None self._properties = create_properties() def __enter__(self): self._open_with_retry() return self def __exit__(self, *args): self.close() def _handle_exception(self, exception): # type: (BaseException) -> ServiceBusError error, error_need_close_handler, error_need_raise = _create_servicebus_exception(_LOGGER, exception, self) if error_need_close_handler: self._close_handler() if error_need_raise: raise error return error def _backoff( self, retried_times, last_exception, timeout=None, entity_name=None ): # type: (int, Exception, Optional[float], str) -> None entity_name = entity_name or self._container_id backoff = self._config.retry_backoff_factor * 2 ** retried_times if backoff <= self._config.retry_backoff_max and ( timeout is None or backoff <= timeout ): # pylint:disable=no-else-return time.sleep(backoff) _LOGGER.info( "%r has an exception (%r). Retrying...", format(entity_name), last_exception, ) else: _LOGGER.info( "%r operation has timed out. Last exception before timeout is (%r)", entity_name, last_exception, ) raise last_exception def _do_retryable_operation(self, operation, timeout=None, **kwargs): # type: (Callable, Optional[float], Any) -> Any require_last_exception = kwargs.pop("require_last_exception", False) require_timeout = kwargs.pop("require_timeout", False) retried_times = 0 max_retries = self._config.retry_total while retried_times <= max_retries: try: if require_timeout: kwargs["timeout"] = timeout return operation(**kwargs) except StopIteration: raise except Exception as exception: # pylint: disable=broad-except last_exception = self._handle_exception(exception) if require_last_exception: kwargs["last_exception"] = last_exception retried_times += 1 if retried_times > max_retries: _LOGGER.info( "%r operation has exhausted retry. Last exception: %r.", self._container_id, last_exception, ) raise last_exception self._backoff( retried_times=retried_times, last_exception=last_exception, timeout=timeout ) def _mgmt_request_response(self, mgmt_operation, message, callback, keep_alive_associated_link=True, **kwargs): # type: (str, uamqp.Message, Callable, bool, Any) -> uamqp.Message self._open() application_properties = {} # Some mgmt calls do not support an associated link name (such as list_sessions). Most do, so on by default. if keep_alive_associated_link: try: application_properties = {ASSOCIATEDLINKPROPERTYNAME:self._handler.message_handler.name} except AttributeError: pass mgmt_msg = uamqp.Message( body=message, properties=MessageProperties( reply_to=self._mgmt_target, encoding=self._config.encoding, **kwargs ), application_properties=application_properties ) try: return self._handler.mgmt_request( mgmt_msg, mgmt_operation, op_type=MGMT_REQUEST_OP_TYPE_ENTITY_MGMT, node=self._mgmt_target.encode(self._config.encoding), timeout=5000, callback=callback ) except Exception as exp: # pylint: disable=broad-except raise ServiceBusError("Management request failed: {}".format(exp), exp) def _mgmt_request_response_with_retry(self, mgmt_operation, message, callback, **kwargs): # type: (bytes, Dict[str, Any], Callable, Any) -> Any return self._do_retryable_operation( self._mgmt_request_response, mgmt_operation=mgmt_operation, message=message, callback=callback, **kwargs ) def _open(self): # pylint: disable=no-self-use raise ValueError("Subclass should override the method.") def _open_with_retry(self): return self._do_retryable_operation(self._open) def _close_handler(self): if self._handler: self._handler.close() self._handler = None self._running = False def close(self): # type: () -> None """Close down the handler links (and connection if the handler uses a separate connection). If the handler has already closed, this operation will do nothing. :rtype: None """ self._close_handler()