async def test_client_id_none():
    """Passing client_id=None must not put any client ID into the IMDS query or the Cloud Shell form body."""

    expected_access_token = "****"
    scope = "scope"

    async def fake_send(request, **_):
        assert "client_id" not in request.query  # IMDS
        if request.data:
            assert "client_id" not in request.body  # Cloud Shell
        payload = build_aad_response(access_token=expected_access_token, expires_on="42", resource=scope)
        return mock_response(json_payload=payload)

    # first pass: empty environment (IMDS); second pass: only MSI_ENDPOINT set (Cloud Shell)
    scenarios = ({}, {EnvironmentVariables.MSI_ENDPOINT: "https://localhost"})
    for scenario_environ in scenarios:
        with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, scenario_environ, clear=True):
            credential = ManagedIdentityCredential(client_id=None, transport=mock.Mock(send=fake_send))
            token = await credential.get_token(scope)
        assert token.token == expected_access_token
async def test_app_service_user_assigned_identity():
    """App Service 2017-09-01: MSI_ENDPOINT, MSI_SECRET set

    A credential configured with a user-assigned client ID must send it as the ``clientid`` query parameter. This
    includes the case where extra ``identity_config`` parameters are supplied alongside it.
    """

    expected_token = "****"
    expires_on = 42
    client_id = "some-guid"
    endpoint = "http://localhost:42/token"
    secret = "expected-secret"
    scope = "scope"
    param_name, param_value = "foo", "bar"

    transport = async_validating_transport(
        requests=[
            Request(
                base_url=endpoint,
                method="GET",
                required_headers={"secret": secret, "User-Agent": USER_AGENT},
                required_params={"api-version": "2017-09-01", "clientid": client_id, "resource": scope},
            ),
            Request(
                base_url=endpoint,
                method="GET",
                required_headers={"secret": secret, "User-Agent": USER_AGENT},
                # this credential is also given client_id, so the client ID must still be sent
                # alongside the identity_config parameter
                required_params={
                    "api-version": "2017-09-01",
                    "clientid": client_id,
                    "resource": scope,
                    param_name: param_value,
                },
            ),
        ],
        responses=[
            mock_response(
                json_payload={
                    "access_token": expected_token,
                    "expires_on": "01/01/1970 00:00:{} +00:00".format(expires_on),
                    "resource": scope,
                    "token_type": "Bearer",
                }
            )
        ]
        * 2,
    )

    with mock.patch.dict(
        MANAGED_IDENTITY_ENVIRON,
        {EnvironmentVariables.MSI_ENDPOINT: endpoint, EnvironmentVariables.MSI_SECRET: secret},
        clear=True,
    ):
        credential = ManagedIdentityCredential(client_id=client_id, transport=transport)
        token = await credential.get_token(scope)
        assert token.token == expected_token
        assert token.expires_on == expires_on

        credential = ManagedIdentityCredential(
            client_id=client_id, transport=transport, identity_config={param_name: param_value}
        )
        token = await credential.get_token(scope)
        assert token.token == expected_token
        assert token.expires_on == expires_on
async def test_app_service_2017_09_01():
    """Without the 2019-08-01 environment, the credential must fall back to App Service API version 2017-09-01.

    Both the Linux and the Windows spellings of the 2017-09-01 ``expires_on`` string must parse to the same value.
    """

    access_token = "****"
    expires_on = 42
    expected_token = AccessToken(access_token, expires_on)
    url = "http://localhost:42/token"
    secret = "expected-secret"
    scope = "scope"

    def build_response(expires_on_text):
        return mock_response(
            json_payload={
                "access_token": access_token,
                "expires_on": expires_on_text,
                "resource": scope,
                "token_type": "Bearer",
            }
        )

    expected_request = Request(
        url,
        method="GET",
        required_headers={"secret": secret, "User-Agent": USER_AGENT},
        required_params={"api-version": "2017-09-01", "resource": scope},
    )
    linux_response = build_response("01/01/1970 00:00:{} +00:00".format(expires_on))
    windows_response = build_response("1/1/1970 12:00:{} AM +00:00".format(expires_on))
    transport = async_validating_transport(
        requests=[expected_request, expected_request],
        responses=[linux_response, windows_response],
    )

    app_service_environ = {
        EnvironmentVariables.MSI_ENDPOINT: url,
        EnvironmentVariables.MSI_SECRET: secret,
    }
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, app_service_environ, clear=True):
        # the first pass is answered with the Linux-format response, the second with the Windows-format one
        for _ in range(2):
            credential = ManagedIdentityCredential(transport=transport)
            token = await credential.get_token(scope)
            assert token == expected_token
            assert token.expires_on == expires_on
async def test_azure_ml():
    """Azure ML: MSI_ENDPOINT, MSI_SECRET set (like App Service 2017-09-01 but with a different response format)"""

    expected_token = AccessToken("****", int(time.time()) + 3600)
    url = "http://localhost:42/token"
    secret = "expected-secret"
    scope = "scope"
    client_id = "client"

    def expected_request(**extra_params):
        # fresh header/param dicts for every Request, exactly as separate literals would give
        return Request(
            url,
            method="GET",
            required_headers={"secret": secret, "User-Agent": USER_AGENT},
            required_params=dict({"api-version": "2017-09-01", "resource": scope}, **extra_params),
        )

    expected_requests = [expected_request(), expected_request(clientid=client_id)]
    shared_response = mock_response(
        json_payload={
            "access_token": expected_token.token,
            "expires_in": 3600,
            "expires_on": expected_token.expires_on,
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(requests=expected_requests, responses=[shared_response, shared_response])

    with mock.patch.dict(
        MANAGED_IDENTITY_ENVIRON,
        {EnvironmentVariables.MSI_ENDPOINT: url, EnvironmentVariables.MSI_SECRET: secret},
        clear=True,
    ):
        # system-assigned identity first, then a user-assigned one identified by client_id
        for credential_kwargs in ({}, {"client_id": client_id}):
            credential = ManagedIdentityCredential(transport=transport, **credential_kwargs)
            token = await credential.get_token(scope)
            assert token.token == expected_token.token
            assert token.expires_on == expected_token.expires_on
async def test_cloud_shell_user_assigned_identity():
    """Cloud Shell (only MSI_ENDPOINT set): client_id and identity_config values are POSTed as form data."""

    expected_token = "****"
    expires_on = 42
    client_id = "some-guid"
    endpoint = "http://localhost:42/token"
    scope = "scope"
    param_name, param_value = "foo", "bar"

    def expected_post(form_fields):
        return Request(
            base_url=endpoint,
            method="POST",
            required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
            required_data=form_fields,
        )

    expected_requests = [
        expected_post({"client_id": client_id, "resource": scope}),
        expected_post({"resource": scope, param_name: param_value}),
    ]
    shared_response = mock_response(
        json_payload={
            "access_token": expected_token,
            "expires_in": 0,
            "expires_on": expires_on,
            "not_before": int(time.time()),
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(requests=expected_requests, responses=[shared_response] * 2)

    cloud_shell_environ = {EnvironmentVariables.MSI_ENDPOINT: endpoint}
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, cloud_shell_environ, clear=True):
        # first with a client ID, then with extra identity_config parameters only
        for credential_kwargs in ({"client_id": client_id}, {"identity_config": {param_name: param_value}}):
            credential = ManagedIdentityCredential(transport=transport, **credential_kwargs)
            token = await credential.get_token(scope)
            assert token.token == expected_token
            assert token.expires_on == expires_on
async def test_managed_identity_live(live_managed_identity_config):
    """Live test: set and delete a Key Vault secret to prove the managed identity credential gets a valid token."""
    credential = ManagedIdentityCredential(client_id=live_managed_identity_config["client_id"])

    # do something with Key Vault to verify the credential can get a valid token
    client = SecretClient(live_managed_identity_config["vault_url"], credential, logging_enable=True)
    # close the async client and credential afterwards so their HTTP sessions aren't leaked
    async with credential, client:
        secret = await client.set_secret("managed-identity-test-secret", "value")
        await client.delete_secret(secret.name)
Ejemplo n.º 7
0
async def test_client_id_none_app_service_2017_09_01():
    """client_id=None must not add a client ID query parameter on App Service 2017-09-01.

    This variant needs its own test because 2017-09-01 reports ``expires_on`` as an unusual date string.
    """

    expected_access_token = "****"
    scope = "scope"

    async def fake_send(request, **_):
        # neither spelling of the client ID parameter may appear in the query string
        for forbidden_param in ("client_id", "clientid"):
            assert forbidden_param not in request.query
        payload = build_aad_response(
            access_token=expected_access_token,
            expires_on="01/01/1970 00:00:42 +00:00",
            resource=scope,
        )
        return mock_response(json_payload=payload)

    app_service_environ = {
        EnvironmentVariables.MSI_ENDPOINT: "https://localhost",
        EnvironmentVariables.MSI_SECRET: "secret",
    }
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, app_service_environ, clear=True):
        credential = ManagedIdentityCredential(client_id=None, transport=mock.Mock(send=fake_send))
        token = await credential.get_token(scope)
    assert token.token == expected_access_token
async def test_close(environ):
    """close() on the credential should exit its transport's async context exactly once."""
    mock_transport = AsyncMockTransport()
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True):
        credential = ManagedIdentityCredential(transport=mock_transport)

    await credential.close()

    exit_calls = mock_transport.__aexit__.call_count
    assert exit_calls == 1
Ejemplo n.º 9
0
async def test_managed_identity_live(live_managed_identity_config):
    """Live test: list Key Vault secret properties to prove the managed identity credential gets a valid token."""
    credential = ManagedIdentityCredential(client_id=live_managed_identity_config["client_id"])

    # do something with Key Vault to verify the credential can get a valid token
    client = SecretClient(live_managed_identity_config["vault_url"], credential, logging_enable=True)
    # close the async client and credential afterwards so their HTTP sessions aren't leaked
    async with credential, client:
        async for _ in client.list_properties_of_secrets():
            pass
Ejemplo n.º 10
0
async def test_cloud_shell_live(cloud_shell):
    """Live test: get a Key Vault token from the Cloud Shell MSI endpoint and use it to list secrets."""
    credential = ManagedIdentityCredential()
    async with credential:
        # get_token on the async credential is a coroutine; without await, token.token below would fail
        token = await credential.get_token("https://vault.azure.net")

    # Validate the token by sending a request to the Key Vault. The request is manual because azure-keyvault-secrets
    # can't authenticate in Cloud Shell; the MSI endpoint there doesn't support AADv2 scopes.
    policies = [
        ContentDecodePolicy(),
        AsyncRedirectPolicy(),
        AsyncRetryPolicy(),
        HttpLoggingPolicy(),
    ]
    client = AsyncPipelineClient(cloud_shell["vault_url"], policies=policies)
    list_secrets = client.get(
        "secrets",
        headers={"Authorization": "Bearer " + token.token},
        params={"api-version": "7.0"},
    )
    async with client:
        await client._pipeline.run(list_secrets)
async def test_context_manager(environ):
    """Using the credential with ``async with`` should enter and exit its transport exactly once."""
    mock_transport = AsyncMockTransport()
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True):
        credential = ManagedIdentityCredential(transport=mock_transport)

    async with credential:
        # inside the block the transport has been entered but not yet exited
        assert mock_transport.__aenter__.call_count == 1
        assert mock_transport.__aexit__.call_count == 0

    # leaving the block exits the transport without entering it again
    assert mock_transport.__aenter__.call_count == 1
    assert mock_transport.__aexit__.call_count == 1
async def test_azure_arc_client_id():
    """Azure Arc doesn't support user-assigned managed identity, so requesting a token with a client ID must fail."""
    arc_environ = {
        EnvironmentVariables.IDENTITY_ENDPOINT: "http://localhost:42/token",
        EnvironmentVariables.IMDS_ENDPOINT: "http://localhost:42",
    }
    # patch the real os.environ mapping in place (as the other tests do) instead of swapping it for a plain dict
    with mock.patch.dict("os.environ", arc_environ, clear=True):
        credential = ManagedIdentityCredential(client_id="some-guid")

    with pytest.raises(ClientAuthenticationError):
        await credential.get_token("scope")
Ejemplo n.º 13
0
    def _initialize_credentials(self):
        if self.subscription_id is not None \
           and self.arm_base_url is not None:
            if self.vscode_tenant_id is None:
                self.vscode_tenant_id = self._get_tenant_id(
                    arm_base_url=self.arm_base_url,
                    subscription_id=self.subscription_id)
            if self.shared_cache_tenant_id is None:
                self.shared_cache_tenant_id = self._get_tenant_id(
                    arm_base_url=self.arm_base_url,
                    subscription_id=self.subscription_id)
            if self.interactive_browser_tenant_id is None:
                self.interactive_browser_tenant_id = self._get_tenant_id(
                    arm_base_url=self.arm_base_url,
                    subscription_id=self.subscription_id)

        credentials = []  # type: List[AsyncTokenCredential]
        if not self.exclude_token_file_credential:
            credentials.append(_TokenFileCredential())
        if not self.exclude_environment_credential:
            credentials.append(EnvironmentCredential(authority=self.authority))
        if not self.exclude_managed_identity_credential:
            credentials.append(
                ManagedIdentityCredential(
                    client_id=self.managed_identity_client_id))
        if not self.exclude_shared_token_cache_credential and SharedTokenCacheCredential.supported(
        ):
            try:
                # username and/or tenant_id are only required when the cache contains tokens for multiple identities
                shared_cache = SharedTokenCacheCredential(
                    username=self.shared_cache_username,
                    tenant_id=self.shared_cache_tenant_id,
                    authority=self.authority)
                credentials.append(shared_cache)
            except Exception as ex:  # pylint:disable=broad-except
                _LOGGER.info("Shared token cache is unavailable: '%s'", ex)
        if not self.exclude_visual_studio_code_credential:
            credentials.append(
                VisualStudioCodeCredential(tenant_id=self.vscode_tenant_id))
        if not self.exclude_cli_credential:
            credentials.append(AzureCliCredential())
        if not self.exclude_powershell_credential:
            credentials.append(AzurePowerShellCredential())
        if not self.exclude_interactive_browser_credential:
            credentials.append(
                InteractiveBrowserCredential(
                    tenant_id=self.interactive_browser_tenant_id))
        if not self.exclude_device_code_credential:
            credentials.append(
                DeviceCodeCredential(
                    tenant_id=self.interactive_browser_tenant_id))

        self.credentials = credentials
async def test_token_exchange_tenant_id(tmpdir):
    """Token exchange: get_token should accept a tenant_id and still use the client ID from AZURE_CLIENT_ID."""
    exchange_token = "exchange-token"
    token_file = tmpdir.join("token")
    token_file.write(exchange_token)
    access_token = "***"
    authority = "https://localhost"
    default_client_id = "default_client_id"
    tenant = "tenant_id"
    scope = "scope"
    now = int(time.time())  # one timestamp keeps expires_on and not_before consistent

    success_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": 3600,
            "ext_expires_in": 3600,
            "expires_on": now + 3600,
            "not_before": now,
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(
        requests=[
            Request(
                base_url=authority,
                method="POST",
                required_data={
                    "client_assertion": exchange_token,
                    "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
                    "client_id": default_client_id,
                    "grant_type": "client_credentials",
                    "scope": scope,
                },
            )
        ],
        responses=[success_response],
    )

    mock_environ = {
        EnvironmentVariables.AZURE_AUTHORITY_HOST: authority,
        EnvironmentVariables.AZURE_CLIENT_ID: default_client_id,
        EnvironmentVariables.AZURE_TENANT_ID: tenant,
        EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: token_file.strpath,
    }
    # credential should default to AZURE_CLIENT_ID
    with mock.patch.dict("os.environ", mock_environ, clear=True):
        credential = ManagedIdentityCredential(transport=transport)
        token = await credential.get_token(scope, tenant_id=tenant)
        assert token.token == access_token
async def test_custom_hooks(environ):
    """The credential's pipeline should include azure-core's CustomHookPolicy"""

    scope = "scope"
    expected_token = "***"
    request_hook = mock.Mock()
    response_hook = mock.Mock()
    now = int(time.time())
    expected_response = mock_response(
        json_payload={
            "access_token": expected_token,
            "expires_in": 3600,
            "expires_on": now + 3600,
            "ext_expires_in": 3600,
            "not_before": now,
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(requests=[Request()] * 2, responses=[expected_response] * 2)

    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True):
        credential = ManagedIdentityCredential(
            transport=transport, raw_request_hook=request_hook, raw_response_hook=response_hook
        )
        # request the token while the patched environment is still in effect, as the other tests in this file do,
        # so that any environment read during get_token sees the intended configuration
        await credential.get_token(scope)

    if environ:
        # some environment variables are set, so we're not mocking IMDS and should expect 1 request
        assert request_hook.call_count == 1
        assert response_hook.call_count == 1
        hook_args, _ = response_hook.call_args
        pipeline_response = hook_args[0]
        assert pipeline_response.http_response == expected_response
    else:
        # we're mocking IMDS and should expect 2 requests
        assert request_hook.call_count == 2
        assert response_hook.call_count == 2
        responses = [hook_args[0].http_response for hook_args, _ in response_hook.call_args_list]
        assert responses == [expected_response] * 2
Ejemplo n.º 16
0
async def main(req: func.HttpRequest) -> func.HttpResponse:
    """Azure Functions entry point that exercises ManagedIdentityCredential against Key Vault.

    Returns "test passed" when a secret can be set and then deleted. On any exception, returns the exception's repr
    followed by the values of ``EXPECTED_VARIABLES`` for debugging.
    """
    # capture interesting environment variables for debugging
    env = "\n".join(f"{var}: {os.environ.get(var)}" for var in EXPECTED_VARIABLES)

    try:
        credential = ManagedIdentityCredential(
            client_id=os.environ.get("AZURE_IDENTITY_TEST_MANAGED_IDENTITY_CLIENT_ID")
        )
        # close the credential and client when done so the function doesn't leak HTTP sessions;
        # the credential is entered first so it is closed even if building the client fails
        async with credential:
            # do something with Key Vault to verify the credential can get a valid token
            client = SecretClient(os.environ["AZURE_IDENTITY_TEST_VAULT_URL"], credential, logging_enable=True)
            async with client:
                secret = await client.set_secret("managed-identity-test-secret", "value")
                await client.delete_secret(secret.name)

        return func.HttpResponse("test passed")
    except Exception as ex:  # pylint:disable=broad-except
        # top-level boundary: report the failure in the response body instead of crashing the function
        return func.HttpResponse("test failed: " + repr(ex) + "\n" * 3 + env)
Ejemplo n.º 17
0
async def test_prefers_app_service_2017_09_01():
    """With both App Service API versions configured, the credential must choose 2017-09-01.

    Support for 2019-08-01 was removed due to https://github.com/Azure/azure-sdk-for-python/issues/14670. Remove this
    test once that support is restored.
    """

    access_token = "****"
    expires_on = 42
    expected_token = AccessToken(access_token, expires_on)
    url = "http://localhost:42/token"
    secret = "expected-secret"
    scope = "scope"

    only_request = Request(
        url,
        method="GET",
        required_headers={"secret": secret, "User-Agent": USER_AGENT},
        required_params={"api-version": "2017-09-01", "resource": scope},
    )
    expires_on_templates = (
        "01/01/1970 00:00:{} +00:00",  # linux format
        "1/1/1970 12:00:{} AM +00:00",  # windows format
    )
    transport = async_validating_transport(
        requests=[only_request] * 2,
        responses=[
            mock_response(
                json_payload={
                    "access_token": access_token,
                    "expires_on": template.format(expires_on),
                    "resource": scope,
                    "token_type": "Bearer",
                }
            )
            for template in expires_on_templates
        ],
    )

    both_versions_environ = {
        EnvironmentVariables.IDENTITY_ENDPOINT: url,
        EnvironmentVariables.IDENTITY_HEADER: secret,
        EnvironmentVariables.MSI_ENDPOINT: url,
        EnvironmentVariables.MSI_SECRET: secret,
    }
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, both_versions_environ, clear=True):
        # one credential per queued response: the Linux-format one first, then the Windows-format one
        for _ in range(2):
            credential = ManagedIdentityCredential(transport=transport)
            token = await credential.get_token(scope)
            assert token == expected_token
            assert token.expires_on == expires_on
async def test_context_manager_incomplete_configuration():
    """Entering and leaving ``async with`` on a credential built without arguments must not raise."""
    credential = ManagedIdentityCredential()
    async with credential:
        pass