# Source code for azure.eventgrid._helpers

# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from typing import TYPE_CHECKING, Any, Dict
import json
import hashlib
import hmac
import base64

from urllib.parse import quote

from azure.core.pipeline.transport import HttpRequest
from azure.core.pipeline.policies import AzureKeyCredentialPolicy, BearerTokenCredentialPolicy
from azure.core.credentials import AzureKeyCredential, AzureSasCredential
from ._generated._serialization import Serializer
from ._signature_credential_policy import EventGridSasCredentialPolicy
from . import _constants as constants

from ._generated.models import (
    CloudEvent as InternalCloudEvent,
)

if TYPE_CHECKING:
    from datetime import datetime

def generate_sas(
    endpoint: str,
    shared_access_key: str,
    expiration_date_utc: "datetime",
    *,
    api_version: str = constants.DEFAULT_API_VERSION,
) -> str:
    """Helper method to generate shared access signature given hostname, key, and expiration date.

    The returned token has the shape ``r=<resource>&e=<expiration>&s=<signature>``:

    * ``resource`` is ``"<endpoint>?apiVersion=<api_version>"``, URL-quoted with
      ``constants.SAFE_ENCODE`` as the set of characters left unescaped.
    * ``expiration`` is ``str(expiration_date_utc)``, URL-quoted the same way.
    * ``signature`` is ``_generate_hmac(shared_access_key, "r=...&e=...")``,
      URL-quoted the same way.

    :param str endpoint: The topic endpoint to send the events to.
        Similar to <YOUR-TOPIC-NAME>.<YOUR-REGION-NAME>
    :param str shared_access_key: The shared access key to be used for generating the token
    :param datetime.datetime expiration_date_utc: The expiration datetime in UTC for the signature.
        It is rendered with ``str()`` before encoding.
    :keyword str api_version: The API Version to include in the signature.
        If not provided, the default API version will be used.
    :return: A shared access signature string.
    :rtype: str

    .. admonition:: Example:

        .. literalinclude:: ../samples/sync_samples/
            :start-after: [START generate_sas]
            :end-before: [END generate_sas]
            :language: python
            :dedent: 0
            :caption: Generate a shared access signature.
    """
    # NOTE(review): the literalinclude path above appears truncated (no sample
    # file name); confirm the intended sample file.
    full_endpoint = "{}?apiVersion={}".format(endpoint, api_version)
    encoded_resource = quote(full_endpoint, safe=constants.SAFE_ENCODE)
    encoded_expiration_utc = quote(str(expiration_date_utc), safe=constants.SAFE_ENCODE)

    # Only the "r=...&e=..." part is signed; the signature is appended afterwards.
    unsigned_sas = "r={}&e={}".format(encoded_resource, encoded_expiration_utc)
    signature = quote(
        _generate_hmac(shared_access_key, unsigned_sas), safe=constants.SAFE_ENCODE
    )
    signed_sas = "{}&s={}".format(unsigned_sas, signature)
    return signed_sas
def _generate_hmac(key, message):
    """Compute the base64-encoded HMAC-SHA256 of ``message`` keyed by ``key``.

    :param str key: Base64-encoded shared access key; it is decoded before use.
    :param str message: The text to sign; it is encoded as ASCII.
    :return: The base64-encoded digest.
    :rtype: bytes
    :raises binascii.Error: If ``key`` cannot be base64-decoded.
    :raises UnicodeEncodeError: If ``message`` contains non-ASCII characters.
    """
    decoded_key = base64.b64decode(key)
    bytes_message = message.encode("ascii")
    hmac_new = hmac.new(decoded_key, bytes_message, hashlib.sha256).digest()
    return base64.b64encode(hmac_new)


def _get_authentication_policy(credential, bearer_token_policy=BearerTokenCredentialPolicy):
    """Select the azure-core authentication policy that matches ``credential``.

    The checks run in this order:

    * Any object exposing ``get_token`` gets ``bearer_token_policy`` scoped to
      ``constants.DEFAULT_EVENTGRID_SCOPE``.
    * An ``AzureKeyCredential`` gets an ``AzureKeyCredentialPolicy`` using the
      ``constants.EVENTGRID_KEY_HEADER`` header.
    * An ``AzureSasCredential`` gets an ``EventGridSasCredentialPolicy`` using the
      ``constants.EVENTGRID_TOKEN_HEADER`` header.

    :param credential: The credential supplied to the client.
    :param bearer_token_policy: Policy class instantiated for token credentials.
        It is a parameter so callers can supply a different class.
    :return: An instantiated authentication policy.
    :raises ValueError: If ``credential`` is None or of an unsupported type.
    """
    if credential is None:
        raise ValueError("Parameter 'self._credential' must not be None.")
    # Duck-typed: any TokenCredential-like object exposes get_token.
    if hasattr(credential, "get_token"):
        return bearer_token_policy(credential, constants.DEFAULT_EVENTGRID_SCOPE)
    if isinstance(credential, AzureKeyCredential):
        return AzureKeyCredentialPolicy(
            credential=credential, name=constants.EVENTGRID_KEY_HEADER
        )
    if isinstance(credential, AzureSasCredential):
        return EventGridSasCredentialPolicy(
            credential=credential, name=constants.EVENTGRID_TOKEN_HEADER
        )
    raise ValueError(
        "The provided credential should be an instance of a TokenCredential, AzureSasCredential or AzureKeyCredential"
    )


def _is_cloud_event(event: Any) -> bool:
    """Return True if ``event`` looks like a CloudEvents 1.0 mapping.

    The event must contain the keys ``id``, ``source``, ``specversion`` and
    ``type``, and ``event["specversion"]`` must equal ``"1.0"``. Inputs that
    do not support ``in`` or ``[]`` with these keys (raising TypeError) yield
    False instead of raising.
    """
    required = ("id", "source", "specversion", "type")
    try:
        return all(key in event for key in required) and event["specversion"] == "1.0"
    except TypeError:
        return False


def _is_eventgrid_event(event: Any) -> bool:
    """Return True if ``event`` contains every key required of an EventGrid event.

    The required keys are ``subject``, ``eventType``, ``data``, ``dataVersion``,
    ``id`` and ``eventTime``. Inputs that do not support ``in`` (raising
    TypeError) yield False.
    """
    required = ("subject", "eventType", "data", "dataVersion", "id", "eventTime")
    try:
        return all(prop in event for prop in required)
    except TypeError:
        return False


def _eventgrid_data_typecheck(event):
    """Reject EventGrid events whose ``data`` payload is ``bytes``.

    ``event`` may be dict-like (read with ``event.get("data")``) or an object
    with a ``data`` attribute.

    :raises TypeError: If the event's data is ``bytes``.
    """
    try:
        data = event.get("data")
    except AttributeError:
        # Not dict-like: fall back to attribute access.
        data = event.data

    if isinstance(data, bytes):
        raise TypeError(
            "Data in EventGridEvent cannot be bytes. Please refer to "
            "the azure.eventgrid.EventGridEvent documentation."
        )


def _cloud_event_to_generated(cloud_event, **kwargs):
    """Convert a public CloudEvent into the generated ``InternalCloudEvent`` model.

    A ``bytes`` payload is sent as ``data_base64`` with ``data`` left None.
    Any other payload is sent as ``data`` with ``data_base64`` left None.
    The event's ``extensions`` become ``additional_properties``, and any extra
    ``kwargs`` are forwarded to the model constructor.
    """
    if isinstance(cloud_event.data, bytes):
        data_base64 = cloud_event.data
        data = None
    else:
        data = cloud_event.data
        data_base64 = None
    return InternalCloudEvent(
        id=cloud_event.id,
        source=cloud_event.source,
        type=cloud_event.type,
        specversion=cloud_event.specversion,
        data=data,
        data_base64=data_base64,
        time=cloud_event.time,
        dataschema=cloud_event.dataschema,
        datacontenttype=cloud_event.datacontenttype,
        subject=cloud_event.subject,
        additional_properties=cloud_event.extensions,
        **kwargs
    )


def _from_cncf_events(event):  # pylint: disable=inconsistent-return-statements
    """This takes in a CNCF cloudevent and returns a dictionary.

    If cloud events library is not installed, the event is returned back.

    :param event: The event to be serialized
    :type event: cloudevents.http.CloudEvent
    :return: The serialized event
    :rtype: any
    :raises ValueError: If ``to_json`` fails with an error other than
        AttributeError/ImportError.
    """
    try:
        # Imported lazily: the cloudevents package is optional.
        from cloudevents.http import to_json

        return json.loads(to_json(event))
    except (AttributeError, ImportError):
        # means this is not a CNCF event, or the library is not installed
        return event
    except Exception as err:  # pylint: disable=broad-except
        msg = """Failed to serialize the event. Please ensure your CloudEvents is correctly formatted (https://pypi.org/project/cloudevents/)"""
        raise ValueError(msg) from err


def _build_request(
    endpoint,
    content_type,
    events,
    *,
    channel_name=None,
    api_version=constants.DEFAULT_API_VERSION,
):
    """Build the POST request that publishes ``events`` to ``endpoint``.

    :param str endpoint: Target URL.
    :param str content_type: Value for the ``Content-Type`` header.
    :param events: Events serialized as the JSON request body (``'[object]'``).
    :keyword str channel_name: If truthy, sent in the ``aeg-channel-name`` header.
    :keyword str api_version: Sent as the ``api-version`` query parameter.
    :return: The prepared request.
    :rtype: ~azure.core.pipeline.transport.HttpRequest
    """
    serialize = Serializer()
    header_parameters: Dict[str, Any] = {}
    header_parameters['Content-Type'] = serialize.header("content_type", content_type, 'str')
    if channel_name:
        header_parameters['aeg-channel-name'] = channel_name

    query_parameters: Dict[str, Any] = {}
    query_parameters['api-version'] = serialize.query("api_version", api_version, 'str')

    body = serialize.body(events, '[object]')
    if body is None:
        data = None
    else:
        data = json.dumps(body)
        # json.dumps escapes non-ASCII by default (ensure_ascii=True), so the
        # character count equals the encoded byte count.
        header_parameters['Content-Length'] = str(len(data))

    request = HttpRequest(
        method="POST",
        url=endpoint,
        headers=header_parameters,
        data=data,
    )
    request.format_parameters(query_parameters)
    return request