# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# --------------------------------------------------------------------------
"""Traces network calls using the implementation library from the settings."""
import logging
import sys
import urllib.parse
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union, Any, Type
from types import TracebackType
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import SansIOHTTPPolicy
from azure.core.pipeline.transport import (
HttpResponse as LegacyHttpResponse,
HttpRequest as LegacyHttpRequest,
)
from azure.core.rest import HttpResponse, HttpRequest
from azure.core.settings import settings
from azure.core.tracing import SpanKind
if TYPE_CHECKING:
    # Imported for static type checking only; it is referenced through a string
    # annotation, so no runtime import is needed.
    from azure.core.tracing._abstract_span import AbstractSpan

_LOGGER = logging.getLogger(__name__)

# The policy is generic over both the ``azure.core.rest`` request/response types
# and their legacy ``azure.core.pipeline.transport`` counterparts.
HTTPRequestType = TypeVar("HTTPRequestType", HttpRequest, LegacyHttpRequest)
HTTPResponseType = TypeVar("HTTPResponseType", HttpResponse, LegacyHttpResponse)

# Shape of ``sys.exc_info()``: a populated (type, value, traceback) triple while an
# exception is being handled, otherwise three ``None`` values.
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
OptExcInfo = Union[ExcInfo, Tuple[None, None, None]]
def _default_network_span_namer(http_request: HTTPRequestType) -> str:
    """Derive the default network span name from the request URL.

    The name is the path component of the URL; a URL whose path is empty
    (for example ``https://host``) is named ``"/"``.

    :param http_request: The HTTP request
    :type http_request: ~azure.core.pipeline.transport.HttpRequest
    :returns: The string to use as network span name
    :rtype: str
    """
    url_path = urllib.parse.urlparse(http_request.url).path
    # An empty path is falsy, so it falls back to the root path.
    return url_path or "/"
class DistributedTracingPolicy(SansIOHTTPPolicy[HTTPRequestType, HTTPResponseType]):
    """The policy to create spans for Azure calls.

    :keyword network_span_namer: A callable to customize the span name
    :type network_span_namer: callable[[~azure.core.pipeline.transport.HttpRequest], str]
    :keyword tracing_attributes: Attributes to set on all created spans
    :type tracing_attributes: dict[str, str]
    """

    # Request-context key under which the active network span is stored.
    TRACING_CONTEXT = "TRACING_CONTEXT"
    # Header names whose values are copied onto the span as attributes when present.
    _REQUEST_ID = "x-ms-client-request-id"
    _RESPONSE_ID = "x-ms-request-id"
    # Span attribute recording the request's retry count.
    _HTTP_RESEND_COUNT = "http.request.resend_count"

    def __init__(self, **kwargs: Any):
        self._network_span_namer = kwargs.get("network_span_namer", _default_network_span_namer)
        self._tracing_attributes = kwargs.get("tracing_attributes", {})

    def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """Start a client span for the outgoing request and inject its headers.

        Per-call ``network_span_namer`` / ``tracing_attributes`` options override
        the values given to the constructor. Failures while creating the span are
        logged and never interrupt the request.

        :param request: The PipelineRequest object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        ctxt = request.context.options
        # Always consume the per-call tracing options, even when tracing is disabled
        # or span creation fails, so they are not forwarded further down the pipeline.
        namer = ctxt.pop("network_span_namer", self._network_span_namer)
        tracing_attributes = ctxt.pop("tracing_attributes", self._tracing_attributes)
        # The same request context can be seen again on a retry (``retry_count`` is
        # read from it in ``end_span``); drop any span left by a previous attempt so
        # an already-finished span is never annotated or finished a second time.
        request.context.pop(self.TRACING_CONTEXT, None)
        try:
            span_impl_type = settings.tracing_implementation()
            if span_impl_type is None:
                return
            span_name = namer(request.http_request)
            span = span_impl_type(name=span_name, kind=SpanKind.CLIENT)
            for attr, value in tracing_attributes.items():
                span.add_attribute(attr, value)
            span.start()
            # Register the span as soon as it has started so that ``end_span``
            # always ends it, even if header propagation below fails.
            request.context[self.TRACING_CONTEXT] = span
            headers = span.to_header()
            request.http_request.headers.update(headers)
        except Exception as err:  # pylint: disable=broad-except
            _LOGGER.warning("Unable to start network span: %s", err)

    def end_span(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: Optional[HTTPResponseType] = None,
        exc_info: Optional[OptExcInfo] = None,
    ) -> None:
        """Ends the span that is tracing the network and updates its status.

        Does nothing when no span is stored in the request context.

        :param request: The PipelineRequest object
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: The HttpResponse object
        :type response: ~azure.core.rest.HttpResponse or ~azure.core.pipeline.transport.HttpResponse
        :param exc_info: The exception information
        :type exc_info: tuple
        """
        span: Optional["AbstractSpan"] = request.context.get(self.TRACING_CONTEXT)
        if span is None:
            return

        http_request: Union[HttpRequest, LegacyHttpRequest] = request.http_request
        span.set_http_attributes(http_request, response=response)

        retry_count = request.context.get("retry_count")
        if retry_count:
            span.add_attribute(self._HTTP_RESEND_COUNT, retry_count)

        request_id = http_request.headers.get(self._REQUEST_ID)
        if request_id is not None:
            span.add_attribute(self._REQUEST_ID, request_id)

        if response is not None and self._RESPONSE_ID in response.headers:
            span.add_attribute(self._RESPONSE_ID, response.headers[self._RESPONSE_ID])

        if exc_info:
            # Ending through the context-manager exit hands the exception details
            # to the span implementation.
            span.__exit__(*exc_info)
        else:
            span.finish()

    def on_response(
        self,
        request: PipelineRequest[HTTPRequestType],
        response: PipelineResponse[HTTPRequestType, HTTPResponseType],
    ) -> None:
        """End the network span using the received HTTP response.

        :param request: The PipelineRequest object
        :type request: ~azure.core.pipeline.PipelineRequest
        :param response: The PipelineResponse object
        :type response: ~azure.core.pipeline.PipelineResponse
        """
        self.end_span(request, response=response.http_response)

    def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:
        """End the network span with the exception currently being handled.

        :param request: The PipelineRequest object
        :type request: ~azure.core.pipeline.PipelineRequest
        """
        self.end_span(request, exc_info=sys.exc_info())