Source code for azure.eventgrid._legacy._helpers

# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
from typing import TYPE_CHECKING, Any, Dict
import json
import hashlib
import hmac
import base64

from urllib.parse import quote

from azure.core.pipeline.transport import HttpRequest
from azure.core.pipeline.policies import (
    AzureKeyCredentialPolicy,
    BearerTokenCredentialPolicy,
)
from azure.core.credentials import AzureKeyCredential, AzureSasCredential
from ._generated._serialization import Serializer
from ._signature_credential_policy import EventGridSasCredentialPolicy
from . import _constants as constants

from ._generated.models import (
    CloudEvent as InternalCloudEvent,
)

if TYPE_CHECKING:
    from datetime import datetime


[docs] def generate_sas( endpoint: str, shared_access_key: str, expiration_date_utc: "datetime", *, api_version: str = constants.DEFAULT_API_VERSION, ) -> str: """Helper method to generate shared access signature given hostname, key, and expiration date. :param str endpoint: The topic endpoint to send the events to. Similar to <YOUR-TOPIC-NAME>.<YOUR-REGION-NAME>-1.eventgrid.azure.net :param str shared_access_key: The shared access key to be used for generating the token :param datetime.datetime expiration_date_utc: The expiration datetime in UTC for the signature. :keyword str api_version: The API Version to include in the signature. If not provided, the default API version will be used. :return: A shared access signature string. :rtype: str .. admonition:: Example: .. literalinclude:: ../samples/basic/sync_samples/sample_generate_sas.py :start-after: [START generate_sas] :end-before: [END generate_sas] :language: python :dedent: 0 :caption: Generate a shared access signature. """ full_endpoint = "{}?apiVersion={}".format(endpoint, api_version) encoded_resource = quote(full_endpoint, safe=constants.SAFE_ENCODE) encoded_expiration_utc = quote(str(expiration_date_utc), safe=constants.SAFE_ENCODE) unsigned_sas = "r={}&e={}".format(encoded_resource, encoded_expiration_utc) signature = quote(_generate_hmac(shared_access_key, unsigned_sas), safe=constants.SAFE_ENCODE) signed_sas = "{}&s={}".format(unsigned_sas, signature) return signed_sas
def _generate_hmac(key, message): decoded_key = base64.b64decode(key) bytes_message = message.encode("ascii") hmac_new = hmac.new(decoded_key, bytes_message, hashlib.sha256).digest() return base64.b64encode(hmac_new) def _get_authentication_policy(credential, bearer_token_policy=BearerTokenCredentialPolicy): if credential is None: raise ValueError("Parameter 'self._credential' must not be None.") if hasattr(credential, "get_token"): return bearer_token_policy(credential, constants.DEFAULT_EVENTGRID_SCOPE) if isinstance(credential, AzureKeyCredential): return AzureKeyCredentialPolicy(credential=credential, name=constants.EVENTGRID_KEY_HEADER) if isinstance(credential, AzureSasCredential): return EventGridSasCredentialPolicy(credential=credential, name=constants.EVENTGRID_TOKEN_HEADER) raise ValueError( "The provided credential should be an instance of a TokenCredential, AzureSasCredential or AzureKeyCredential" ) def _is_cloud_event(event): # type: (Any) -> bool required = ("id", "source", "specversion", "type") try: return all((_ in event for _ in required)) and event["specversion"] == "1.0" except TypeError: return False def _is_eventgrid_event_format(event): # type: (Any) -> bool required = ("subject", "eventType", "data", "dataVersion", "id", "eventTime") try: return all((prop in event for prop in required)) except TypeError: return False def _eventgrid_data_typecheck(event): try: data = event.get("data") except AttributeError: data = event.data if isinstance(data, bytes): raise TypeError( "Data in EventGridEvent cannot be bytes. Please refer to" "https://docs.microsoft.com/en-us/azure/event-grid/event-schema" ) def _cloud_event_to_generated(cloud_event, **kwargs): if isinstance(cloud_event.data, bytes): data_base64 = cloud_event.data data = None else: data = cloud_event.data data_base64 = None return InternalCloudEvent( id=cloud_event.id, source=cloud_event.source, type=cloud_event.type, specversion=cloud_event.specversion, data=data, data_base64=data_base64, time=cloud_event.time, dataschema=cloud_event.dataschema, datacontenttype=cloud_event.datacontenttype, subject=cloud_event.subject, additional_properties=cloud_event.extensions, **kwargs, ) def _from_cncf_events(event): # pylint: disable=inconsistent-return-statements """This takes in a CNCF cloudevent and returns a dictionary. If cloud events library is not installed, the event is returned back. :param event: The event to be serialized :type event: cloudevents.http.CloudEvent :return: The serialized event :rtype: any """ try: from cloudevents.http import to_json return json.loads(to_json(event)) except (AttributeError, ImportError): # means this is not a CNCF event return event except Exception as err: # pylint: disable=broad-except msg = """Failed to serialize the event. Please ensure your CloudEvents is correctly formatted (https://pypi.org/project/cloudevents/)""" raise ValueError(msg) from err def _build_request(endpoint, content_type, events, *, channel_name=None, api_version=constants.DEFAULT_API_VERSION): serialize = Serializer() header_parameters: Dict[str, Any] = {} header_parameters["Content-Type"] = serialize.header("content_type", content_type, "str") if channel_name: header_parameters["aeg-channel-name"] = channel_name query_parameters: Dict[str, Any] = {} query_parameters["api-version"] = serialize.query("api_version", api_version, "str") body = serialize.body(events, "[object]") if body is None: data = None else: data = json.dumps(body) header_parameters["Content-Length"] = str(len(data)) request = HttpRequest(method="POST", url=endpoint, headers=header_parameters, data=data) request.format_parameters(query_parameters) return request