Source code for azure.identity._internal.shared_token_cache

# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import abc
import os
import sys


from msal import TokenCache
from azure.core.exceptions import ClientAuthenticationError
from .._constants import KnownAuthorities

try:
    ABC = abc.ABC
except AttributeError:  # Python 2.7, abc exists, but not ABC
    ABC = abc.ABCMeta("ABC", (object,), {"__slots__": ()})  # type: ignore

try:
    from typing import TYPE_CHECKING
except ImportError:
    TYPE_CHECKING = False

if TYPE_CHECKING:
    # pylint:disable=unused-import,ungrouped-imports
    from typing import Any, Iterable, List, Mapping, Optional
    import msal_extensions
    from azure.core.credentials import AccessToken
    from .._internal import AadClientBase

    CacheItem = Mapping[str, str]


MULTIPLE_ACCOUNTS = """Multiple users were discovered in the shared token cache. If using DefaultAzureCredential, set
the AZURE_USERNAME environment variable to the preferred username. Otherwise,
specify it when constructing SharedTokenCacheCredential.
Discovered accounts: {}"""

MULTIPLE_MATCHING_ACCOUNTS = """Found multiple accounts matching{}{}. If using DefaultAzureCredential, set environment
variables AZURE_USERNAME and AZURE_TENANT_ID with the preferred username and tenant.
Otherwise, specify them when constructing SharedTokenCacheCredential.
Discovered accounts: {}"""

NO_ACCOUNTS = """The shared cache contains no signed-in accounts. To authenticate with SharedTokenCacheCredential, login
through developer tooling supporting Azure single sign on"""

NO_MATCHING_ACCOUNTS = """The cache contains no account matching the specified{}{}. To authenticate with
SharedTokenCacheCredential, login through developer tooling supporting Azure single sign on.
Discovered accounts: {}"""

NO_TOKEN = """Token acquisition failed for user '{}'. To fix, re-authenticate
through developer tooling supporting Azure single sign on"""

# build a dictionary {authority: {its known aliases}}, aliases taken from MSAL.NET's KnownMetadataProvider
KNOWN_ALIASES = {
    alias: aliases  # N.B. aliases includes alias itself
    for aliases in (
        frozenset((KnownAuthorities.AZURE_CHINA, "login.partner.microsoftonline.cn")),
        frozenset((KnownAuthorities.AZURE_PUBLIC_CLOUD, "login.windows.net", "login.microsoft.com", "sts.windows.net")),
        frozenset((KnownAuthorities.AZURE_GOVERNMENT, "login.usgovcloudapi.net")),
    )
    for alias in aliases
}


def _account_to_string(account):
    username = account.get("username")
    home_account_id = account.get("home_account_id", "").split(".")
    tenant_id = home_account_id[-1] if len(home_account_id) == 2 else ""
    return "(username: {}, tenant: {})".format(username, tenant_id)


def _filtered_accounts(accounts, username=None, tenant_id=None):
    """yield accounts matching username and/or tenant_id"""

    filtered_accounts = []
    for account in accounts:
        if username and account.get("username") != username:
            continue
        if tenant_id:
            try:
                _, tenant = account["home_account_id"].split(".")
                if tenant_id != tenant:
                    continue
            except Exception:  # pylint:disable=broad-except
                continue
        filtered_accounts.append(account)
    return filtered_accounts


class SharedTokenCacheBase(ABC):
    def __init__(self, username=None, **kwargs):  # pylint:disable=unused-argument
        # type: (Optional[str], **Any) -> None

        self._authority = kwargs.pop("authority", None) or KnownAuthorities.AZURE_PUBLIC_CLOUD
        self._authority_aliases = KNOWN_ALIASES.get(self._authority) or frozenset((self._authority,))
        self._username = username
        self._tenant_id = kwargs.pop("tenant_id", None)

        cache = kwargs.pop("_cache", None)  # for ease of testing

        if not cache and sys.platform.startswith("win") and "LOCALAPPDATA" in os.environ:
            from msal_extensions.token_cache import WindowsTokenCache

            cache = WindowsTokenCache(
                cache_location=os.path.join(os.environ["LOCALAPPDATA"], ".IdentityService", "msal.cache")
            )

            # prevent writing to the shared cache
            # TODO: seperating deserializing access tokens from caching them would make this cleaner
            cache.add = lambda *_, **__: None

        if cache:
            self._cache = cache
            self._client = self._get_auth_client(
                authority=self._authority, cache=cache, **kwargs
            )  # type: Optional[AadClientBase]
        else:
            self._client = None

    @abc.abstractmethod
    def _get_auth_client(self, **kwargs):
        # type: (**Any) -> AadClientBase
        pass

    def _get_cache_items_for_authority(self, credential_type):
        # type: (TokenCache.CredentialType) -> List[CacheItem]
        """yield cache items matching this credential's authority or one of its aliases"""

        items = []
        for item in self._cache.find(credential_type):
            environment = item.get("environment")
            if environment in self._authority_aliases:
                items.append(item)
        return items

    def _get_accounts_having_matching_refresh_tokens(self):
        # type: () -> Iterable[CacheItem]
        """returns an iterable of cached accounts which have a matching refresh token"""

        refresh_tokens = self._get_cache_items_for_authority(TokenCache.CredentialType.REFRESH_TOKEN)
        all_accounts = self._get_cache_items_for_authority(TokenCache.CredentialType.ACCOUNT)

        accounts = {}
        for refresh_token in refresh_tokens:
            home_account_id = refresh_token.get("home_account_id")
            if not home_account_id:
                continue
            for account in all_accounts:
                # When the token has no family, msal.net falls back to matching client_id,
                # which won't work for the shared cache because we don't know the IDs of
                # all contributing apps. It should be unnecessary anyway because the
                # apps should all belong to the family.
                if home_account_id == account.get("home_account_id") and "family_id" in refresh_token:
                    accounts[account["home_account_id"]] = account
        return accounts.values()

    def _get_account(self, username=None, tenant_id=None):
        # type: (Optional[str], Optional[str]) -> CacheItem
        """returns exactly one account which has a refresh token and matches username and/or tenant_id"""

        accounts = self._get_accounts_having_matching_refresh_tokens()
        if not accounts:
            # cache is empty or contains no refresh token -> user needs to sign in
            raise ClientAuthenticationError(message=NO_ACCOUNTS)

        filtered_accounts = _filtered_accounts(accounts, username, tenant_id)
        if len(filtered_accounts) == 1:
            return filtered_accounts[0]

        # no, or multiple, accounts after filtering -> choose the best error message
        cached_accounts = ", ".join(_account_to_string(account) for account in accounts)
        if username or tenant_id:
            username_string = " username: {}".format(username) if username else ""
            tenant_string = " tenant: {}".format(tenant_id) if tenant_id else ""
            if filtered_accounts:
                message = MULTIPLE_MATCHING_ACCOUNTS.format(username_string, tenant_string, cached_accounts)
            else:
                message = NO_MATCHING_ACCOUNTS.format(username_string, tenant_string, cached_accounts)
        else:
            message = MULTIPLE_ACCOUNTS.format(cached_accounts)

        raise ClientAuthenticationError(message=message)

    def _get_refresh_tokens(self, account):
        return self._cache.find(
            TokenCache.CredentialType.REFRESH_TOKEN, query={"home_account_id": account.get("home_account_id")}
        )

    @staticmethod
    def supported():
        # type: () -> bool
        """Whether the shared token cache is supported on the current platform.

        :rtype: bool
        """
        return sys.platform.startswith("win")