# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import ( # pylint: disable=unused-import
Union,
Optional,
Any,
Iterable,
Dict,
List,
Type,
Tuple,
TYPE_CHECKING,
)
import logging
try:
from urllib.parse import parse_qs, quote
except ImportError:
from urlparse import parse_qs # type: ignore
from urllib2 import quote # type: ignore
import six
from azure.core.configuration import Configuration
from azure.core.exceptions import HttpResponseError
from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import RequestsTransport, HttpTransport
from azure.core.pipeline.policies import (
RedirectPolicy,
ContentDecodePolicy,
BearerTokenCredentialPolicy,
ProxyPolicy,
DistributedTracingPolicy,
HttpLoggingPolicy,
UserAgentPolicy
)
from ._shared_access_signature import QueryStringConstants
from ._constants import STORAGE_OAUTH_SCOPE, SERVICE_HOST_BASE, CONNECTION_TIMEOUT, READ_TIMEOUT
from ._models import LocationMode
from ._authentication import SharedKeyCredentialPolicy
from ._policies import (
StorageHeadersPolicy,
StorageContentValidation,
StorageRequestHook,
StorageResponseHook,
StorageLoggingPolicy,
StorageHosts,
TablesRetryPolicy,
)
from ._error import _process_table_error
from ._models import PartialBatchErrorException
from ._sdk_moniker import SDK_MONIKER
_LOGGER = logging.getLogger(__name__)
_SERVICE_PARAMS = {
"blob": {"primary": "BlobEndpoint", "secondary": "BlobSecondaryEndpoint"},
"queue": {"primary": "QueueEndpoint", "secondary": "QueueSecondaryEndpoint"},
"file": {"primary": "FileEndpoint", "secondary": "FileSecondaryEndpoint"},
"dfs": {"primary": "BlobEndpoint", "secondary": "BlobEndpoint"},
"table": {"primary": "TableEndpoint", "secondary": "TableSecondaryEndpoint"},
}
class StorageAccountHostsMixin(object): # pylint: disable=too-many-instance-attributes
def __init__(
self,
parsed_url, # type: Any
service, # type: str
credential=None, # type: Optional[Any]
**kwargs # type: Any
):
# type: (...) -> None
self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY)
self._hosts = kwargs.get("_hosts")
self.scheme = parsed_url.scheme
if service not in ["blob", "queue", "file-share", "dfs", "table"]:
raise ValueError("Invalid service: {}".format(service))
service_name = service.split('-')[0]
account = parsed_url.netloc.split(".{}.core.".format(service_name))
if 'cosmos' in parsed_url.netloc:
account = parsed_url.netloc.split(".{}.cosmos.".format(service_name))
self.account_name = account[0] if len(account) > 1 else None
secondary_hostname = None
self.credential = format_shared_key_credential(account, credential)
if self.scheme.lower() != "https" and hasattr(self.credential, "get_token"):
raise ValueError("Token credential is only supported with HTTPS.")
if hasattr(self.credential, "account_name"):
self.account_name = self.credential.account_name
secondary_hostname = "{}-secondary.{}.{}".format(
self.credential.account_name, service_name, SERVICE_HOST_BASE)
if not self._hosts:
if len(account) > 1:
secondary_hostname = parsed_url.netloc.replace(account[0], account[0] + "-secondary")
if kwargs.get("secondary_hostname"):
secondary_hostname = kwargs["secondary_hostname"]
primary_hostname = (parsed_url.netloc + parsed_url.path).rstrip('/')
self._hosts = {LocationMode.PRIMARY: primary_hostname, LocationMode.SECONDARY: secondary_hostname}
self.require_encryption = kwargs.get("require_encryption", False)
self.key_encryption_key = kwargs.get("key_encryption_key")
self.key_resolver_function = kwargs.get("key_resolver_function")
self._config, self._pipeline = self._create_pipeline(self.credential, storage_sdk=service, **kwargs)
def __enter__(self):
self._client.__enter__()
return self
def __exit__(self, *args):
self._client.__exit__(*args)
def close(self):
""" This method is to close the sockets opened by the client.
It need not be used when using with a context manager.
"""
self._client.close()
@property
def url(self):
"""The full endpoint URL to this entity, including SAS token if used.
This could be either the primary endpoint,
or the secondary endpoint depending on the current :func:`location_mode`.
"""
return self._format_url(self._hosts[self._location_mode])
@property
def _primary_endpoint(self):
"""The full primary endpoint URL.
:type: str
"""
return self._format_url(self._hosts[LocationMode.PRIMARY])
@property
def _primary_hostname(self):
"""The hostname of the primary endpoint.
:type: str
"""
return self._hosts[LocationMode.PRIMARY]
@property
def _secondary_endpoint(self):
"""The full secondary endpoint URL if configured.
If not available a ValueError will be raised. To explicitly specify a secondary hostname, use the optional
`secondary_hostname` keyword argument on instantiation.
:type: str
:raise ValueError:
"""
if not self._hosts[LocationMode.SECONDARY]:
raise ValueError("No secondary host configured.")
return self._format_url(self._hosts[LocationMode.SECONDARY])
@property
def _secondary_hostname(self):
"""The hostname of the secondary endpoint.
If not available this will be None. To explicitly specify a secondary hostname, use the optional
`secondary_hostname` keyword argument on instantiation.
:type: str or None
"""
return self._hosts[LocationMode.SECONDARY]
@property
def location_mode(self):
"""The location mode that the client is currently using.
By default this will be "primary". Options include "primary" and "secondary".
:type: str
"""
return self._location_mode
@location_mode.setter
def location_mode(self, value):
if self._hosts.get(value):
self._location_mode = value
self._client._config.url = self.url # pylint: disable=protected-access
else:
raise ValueError("No host URL for location mode: {}".format(value))
@property
def api_version(self):
"""The version of the Storage API used for requests.
:type: str
"""
return self._client._config.version # pylint: disable=protected-access
def _format_query_string(self, sas_token, credential, snapshot=None, share_snapshot=None):
query_str = "?"
if snapshot:
query_str += "snapshot={}&".format(self.snapshot)
if share_snapshot:
query_str += "sharesnapshot={}&".format(self.snapshot)
if sas_token and not credential:
query_str += sas_token
elif is_credential_sastoken(credential):
query_str += credential.lstrip("?")
credential = None
return query_str.rstrip("?&"), credential
def _create_pipeline(self, credential, **kwargs):
# type: (Any, **Any) -> Tuple[Configuration, Pipeline]
self._credential_policy = None
if hasattr(credential, "get_token"):
self._credential_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE)
elif isinstance(credential, SharedKeyCredentialPolicy):
self._credential_policy = credential
elif credential is not None:
raise TypeError("Unsupported credential: {}".format(credential))
config = kwargs.get("_configuration") or create_configuration(**kwargs)
if kwargs.get("_pipeline"):
return config, kwargs["_pipeline"]
config.transport = kwargs.get("transport") # type: ignore
kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT)
kwargs.setdefault("read_timeout", READ_TIMEOUT)
if not config.transport:
config.transport = RequestsTransport(**kwargs)
policies = [
config.headers_policy,
config.proxy_policy,
config.user_agent_policy,
StorageContentValidation(),
StorageRequestHook(**kwargs),
self._credential_policy,
ContentDecodePolicy(response_encoding="utf-8"),
RedirectPolicy(**kwargs),
StorageHosts(hosts=self._hosts, **kwargs),
config.retry_policy,
config.logging_policy,
StorageResponseHook(**kwargs),
DistributedTracingPolicy(**kwargs),
HttpLoggingPolicy(**kwargs)
]
return config, Pipeline(config.transport, policies=policies)
def _batch_send(
self, *reqs, # type: HttpRequest
**kwargs
):
"""Given a series of request, do a Storage batch call.
"""
# Pop it here, so requests doesn't feel bad about additional kwarg
raise_on_any_failure = kwargs.pop("raise_on_any_failure", True)
request = self._client._client.post( # pylint: disable=protected-access
url='{}://{}/?comp=batch{}{}'.format(
self.scheme,
self._primary_hostname,
kwargs.pop('sas', None),
kwargs.pop('timeout', None)
),
headers={
'x-ms-version': self.api_version
}
)
policies = [StorageHeadersPolicy()]
if self._credential_policy:
policies.append(self._credential_policy)
request.set_multipart_mixed(
*reqs,
policies=policies,
enforce_https=False
)
pipeline_response = self._pipeline.run(
request, **kwargs
)
response = pipeline_response.http_response
try:
if response.status_code not in [202]:
raise HttpResponseError(response=response)
parts = response.parts()
if raise_on_any_failure:
parts = list(response.parts())
if any(p for p in parts if not 200 <= p.status_code < 300):
error = PartialBatchErrorException(
message="There is a partial failure in the batch operation.",
response=response, parts=parts
)
raise error
return iter(parts)
return parts
except HttpResponseError as error:
_process_table_error(error)
class TransportWrapper(HttpTransport):
"""Wrapper class that ensures that an inner client created
by a `get_client` method does not close the outer transport for the parent
when used in a context manager.
"""
def __init__(self, transport):
self._transport = transport
def send(self, request, **kwargs):
return self._transport.send(request, **kwargs)
def open(self):
pass
def close(self):
pass
def __enter__(self):
pass
def __exit__(self, *args): # pylint: disable=arguments-differ
pass
def format_shared_key_credential(account, credential):
if isinstance(credential, six.string_types):
if len(account) < 2:
raise ValueError("Unable to determine account name for shared key credential.")
credential = {"account_name": account[0], "account_key": credential}
if isinstance(credential, dict):
if "account_name" not in credential:
raise ValueError("Shared key credential missing 'account_name")
if "account_key" not in credential:
raise ValueError("Shared key credential missing 'account_key")
return SharedKeyCredentialPolicy(**credential)
return credential
def parse_connection_str(conn_str, credential, service, keyword_args):
conn_str = conn_str.rstrip(";")
conn_settings = [s.split("=", 1) for s in conn_str.split(";")]
if any(len(tup) != 2 for tup in conn_settings):
raise ValueError("Connection string is either blank or malformed.")
conn_settings = dict(conn_settings)
endpoints = _SERVICE_PARAMS[service]
primary = None
secondary = None
if not credential:
try:
credential = {"account_name": conn_settings["AccountName"], "account_key": conn_settings["AccountKey"]}
except KeyError:
credential = conn_settings.get("SharedAccessSignature")
if endpoints["primary"] in conn_settings:
primary = conn_settings[endpoints["primary"]]
if endpoints["secondary"] in conn_settings:
secondary = conn_settings[endpoints["secondary"]]
else:
if endpoints["secondary"] in conn_settings:
raise ValueError("Connection string specifies only secondary endpoint.")
try:
primary = "{}://{}.{}.{}".format(
conn_settings["DefaultEndpointsProtocol"],
conn_settings["AccountName"],
service,
conn_settings["EndpointSuffix"],
)
secondary = "{}-secondary.{}.{}".format(
conn_settings["AccountName"], service, conn_settings["EndpointSuffix"]
)
except KeyError:
pass
if not primary:
try:
primary = "https://{}.{}.{}".format(
conn_settings["AccountName"], service, conn_settings.get("EndpointSuffix", SERVICE_HOST_BASE)
)
except KeyError:
raise ValueError("Connection string missing required connection details.")
if 'secondary_hostname' not in keyword_args:
keyword_args['secondary_hostname'] = secondary
return primary, credential
def create_configuration(**kwargs):
# type: (**Any) -> Configuration
config = Configuration(**kwargs)
config.headers_policy = StorageHeadersPolicy(**kwargs)
config.user_agent_policy = UserAgentPolicy(sdk_moniker=SDK_MONIKER, **kwargs)
# sdk_moniker="storage-{}/{}".format(kwargs.pop('storage_sdk'), VERSION), **kwargs)
config.retry_policy = kwargs.get("retry_policy") or TablesRetryPolicy(**kwargs)
config.logging_policy = StorageLoggingPolicy(**kwargs)
config.proxy_policy = ProxyPolicy(**kwargs)
# Storage settings
config.max_single_put_size = kwargs.get("max_single_put_size", 64 * 1024 * 1024)
config.copy_polling_interval = 15
# Block blob uploads
config.max_block_size = kwargs.get("max_block_size", 4 * 1024 * 1024)
config.min_large_block_upload_threshold = kwargs.get("min_large_block_upload_threshold", 4 * 1024 * 1024 + 1)
config.use_byte_buffer = kwargs.get("use_byte_buffer", False)
# Page blob uploads
config.max_page_size = kwargs.get("max_page_size", 4 * 1024 * 1024)
# Blob downloads
config.max_single_get_size = kwargs.get("max_single_get_size", 32 * 1024 * 1024)
config.max_chunk_get_size = kwargs.get("max_chunk_get_size", 4 * 1024 * 1024)
# File uploads
config.max_range_size = kwargs.get("max_range_size", 4 * 1024 * 1024)
return config
def parse_query(query_str):
sas_values = QueryStringConstants.to_list()
parsed_query = {k: v[0] for k, v in parse_qs(query_str).items()}
sas_params = ["{}={}".format(k, quote(v, safe='')) for k, v in parsed_query.items() if k in sas_values]
sas_token = None
if sas_params:
sas_token = "&".join(sas_params)
snapshot = parsed_query.get("snapshot") or parsed_query.get("sharesnapshot")
return snapshot, sas_token
def is_credential_sastoken(credential):
if not credential or not isinstance(credential, six.string_types):
return False
sas_values = QueryStringConstants.to_list()
parsed_query = parse_qs(credential.lstrip("?"))
if parsed_query and all([k in sas_values for k in parsed_query.keys()]):
return True
return False