# Example #1
# 0
async def test_single_account_matching_username():
    """Single cached account whose username matches the requested one: that account is authenticated."""

    username = "spam@eggs"
    rt = "refresh-token"
    requested_scope = "scope"
    cached = populated_cache(
        get_account_event(uid="uid_a", utid="utid", username=username, refresh_token=rt)
    )

    want = "***"
    # the token request must redeem the cached account's refresh token for the requested scope
    expected_request = Request(required_data={"refresh_token": rt, "scope": requested_scope})
    aad_reply = mock_response(json_payload=build_aad_response(access_token=want))
    transport = async_validating_transport(requests=[expected_request], responses=[aad_reply])

    credential = SharedTokenCacheCredential(_cache=cached, transport=transport, username=username)
    result = await credential.get_token(requested_scope)
    assert result.token == want
async def test_app_service_user_assigned_identity():
    """App Service 2017-09-01 with a user-assigned identity (MSI_ENDPOINT and MSI_SECRET set)."""
    # NOTE(review): another test later in this file reuses this function name, so this
    # definition is shadowed at module level and not collected — consider renaming one of them.

    token_value = "****"
    expiry = 42
    user_assigned_id = "some-guid"
    want = AccessToken(token_value, expiry)
    msi_endpoint = "http://localhost:42/token"
    msi_secret = "expected-secret"
    resource = "scope"

    # the identity is named via the "clientid" query parameter; auth is via the "secret" header
    token_request = Request(
        base_url=msi_endpoint,
        method="GET",
        required_headers={"secret": msi_secret, "User-Agent": USER_AGENT},
        required_params={"api-version": "2017-09-01", "clientid": user_assigned_id, "resource": resource},
    )
    # the mocked response gives expires_on as a date string rather than a number
    reply_body = {
        "access_token": token_value,
        "expires_on": "01/01/1970 00:00:{} +00:00".format(expiry),
        "resource": resource,
        "token_type": "Bearer",
    }
    transport = async_validating_transport(
        requests=[token_request], responses=[mock_response(json_payload=reply_body)]
    )

    fake_env = {
        EnvironmentVariables.MSI_ENDPOINT: msi_endpoint,
        EnvironmentVariables.MSI_SECRET: msi_secret,
    }
    with mock.patch("os.environ", fake_env):
        credential = ManagedIdentityCredential(client_id=user_assigned_id, transport=transport)
        token = await credential.get_token(resource)
        assert token == want
# Example #3
# 0
async def test_imds():
    """With no managed-identity environment variables set, the credential goes through IMDS."""
    token_value = "****"
    expiry = 42
    want = AccessToken(token_value, expiry)
    resource = "scope"

    # the first request is an availability probe, so only its URL is matched
    probe = Request(url=Endpoints.IMDS)
    token_request = Request(
        base_url=Endpoints.IMDS,
        method="GET",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_params={"api-version": "2018-02-01", "resource": resource},
    )
    # the probe is answered with an error response
    probe_reply = mock_response(status_code=400, json_payload={"error": "this is an error message"})
    token_reply = mock_response(
        json_payload={
            "access_token": token_value,
            "expires_in": 42,
            "expires_on": expiry,
            "ext_expires_in": 42,
            "not_before": int(time.time()),
            "resource": resource,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(
        requests=[probe, token_request], responses=[probe_reply, token_reply]
    )

    # clearing the environment (e.g. $MSI_ENDPOINT) makes the credential pick ImdsCredential
    with mock.patch.dict("os.environ", clear=True):
        credential = ManagedIdentityCredential(transport=transport)
        token = await credential.get_token(resource)
    assert token == want
async def test_no_user_settings():
    """When VS Code settings are unreadable, the credential uses Public Cloud and the "organizations" tenant."""

    default_authority_url = "https://{}/{}".format(AzureAuthorityHosts.AZURE_PUBLIC_CLOUD, "organizations")
    transport = async_validating_transport(
        requests=[Request(base_url=default_authority_url)],
        responses=[mock_response(json_payload=build_aad_response(access_token="**"))],
    )

    credential = get_credential(transport=transport)

    def fake_get_refresh_token(_):
        return "**"

    with mock.patch(GET_REFRESH_TOKEN, fake_get_refresh_token):
        await credential.get_token("scope")

    # exactly one request, aimed at the default authority, should have been sent
    assert transport.send.call_count == 1
async def test_app_service():
    """App Service environment: both MSI_ENDPOINT and MSI_SECRET are set."""

    token_value = "****"
    expiry = 42
    want = AccessToken(token_value, expiry)
    endpoint = "http://localhost:42/token"
    msi_secret = "expected-secret"
    resource = "scope"

    headers = {"Metadata": "true", "secret": msi_secret, "User-Agent": USER_AGENT}
    params = {"api-version": "2017-09-01", "resource": resource}
    token_request = Request(endpoint, method="GET", required_headers=headers, required_params=params)

    body = {
        "access_token": token_value,
        "expires_on": expiry,
        "resource": resource,
        "token_type": "Bearer",
    }
    transport = async_validating_transport(
        requests=[token_request], responses=[mock_response(json_payload=body)]
    )

    fake_env = {
        EnvironmentVariables.MSI_ENDPOINT: endpoint,
        EnvironmentVariables.MSI_SECRET: msi_secret,
    }
    with mock.patch("os.environ", fake_env):
        credential = ManagedIdentityCredential(transport=transport)
        token = await credential.get_token(resource)
        assert token == want
# Example #6
# 0
async def test_imds_url_override():
    """AZURE_POD_IDENTITY_TOKEN_URL redirects ImdsCredential's token request to that URL."""
    override_url = "https://localhost/token"
    want = "***"
    resource = "scope"
    issued_at = int(time.time())

    token_request = Request(
        base_url=override_url,
        method="GET",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_params={"api-version": "2018-02-01", "resource": resource},
    )
    token_reply = mock_response(
        json_payload={
            "access_token": want,
            "expires_in": 42,
            "expires_on": issued_at + 42,
            "ext_expires_in": 42,
            "not_before": issued_at,
            "resource": resource,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(requests=[token_request], responses=[token_reply])

    overridden_env = {EnvironmentVariables.AZURE_POD_IDENTITY_TOKEN_URL: override_url}
    with mock.patch.dict("os.environ", overridden_env, clear=True):
        credential = ImdsCredential(transport=transport)
        token = await credential.get_token(resource)

    assert token.token == want
# Example #7
# 0
async def test_two_accounts_username_specified():
    """Two cached accounts; the username picks one, and that account's refresh token is redeemed."""

    requested_scope = "scope"
    wanted_rt = "refresh-token-a"
    user_a, user_b = "a@foo", "b@foo"
    cache = populated_cache(
        get_account_event(username=user_a, uid="uid_a", utid="utid", refresh_token=wanted_rt),
        get_account_event(username=user_b, uid="uid_b", utid="utid", refresh_token="refresh_token_b"),
    )

    want = "***"
    transport = async_validating_transport(
        requests=[Request(required_data={"refresh_token": wanted_rt, "scope": requested_scope})],
        responses=[mock_response(json_payload=build_aad_response(access_token=want))],
    )
    credential = SharedTokenCacheCredential(username=user_a, _cache=cache, transport=transport)
    result = await credential.get_token(requested_scope)
    assert result.token == want
# Example #8
# 0
async def test_identity_config_cloud_shell():
    """Cloud Shell (only MSI_ENDPOINT set): identity_config entries are sent in the POST body."""
    extra_key, extra_value = "foo", "bar"
    token_value = "****"
    expiry = 42
    want = AccessToken(token_value, expiry)
    endpoint = "http://localhost:42/token"
    resource = "scope"

    token_request = Request(
        base_url=endpoint,
        method="POST",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_data={"resource": resource, extra_key: extra_value},
    )
    token_reply = mock_response(
        json_payload={
            "access_token": token_value,
            "expires_in": 0,
            "expires_on": expiry,
            "not_before": int(time.time()),
            "resource": resource,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(requests=[token_request], responses=[token_reply])

    # patch the environ seen by the module that defines MsiCredential
    environ_target = MsiCredential.__module__ + ".os.environ"
    with mock.patch.dict(environ_target, {EnvironmentVariables.MSI_ENDPOINT: endpoint}, clear=True):
        credential = MsiCredential(identity_config={extra_key: extra_value}, transport=transport)
        token = await credential.get_token(resource)

    assert token == want
async def test_custom_hooks(environ):
    """raw_request_hook / raw_response_hook must fire for every request the credential sends."""

    resource = "scope"
    token_value = "***"
    on_request = mock.Mock()
    on_response = mock.Mock()
    issued_at = int(time.time())
    canned_response = mock_response(
        json_payload={
            "access_token": token_value,
            "expires_in": 3600,
            "expires_on": issued_at + 3600,
            "ext_expires_in": 3600,
            "not_before": issued_at,
            "resource": resource,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(requests=[Request()] * 2, responses=[canned_response] * 2)

    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True):
        credential = ManagedIdentityCredential(
            transport=transport, raw_request_hook=on_request, raw_response_hook=on_response
        )
    await credential.get_token(resource)

    # with environment variables set a single request is expected; with none set the
    # IMDS path is taken and two requests are expected
    expected_calls = 1 if environ else 2
    assert on_request.call_count == expected_calls
    assert on_response.call_count == expected_calls
    if environ:
        pipeline_response = on_response.call_args[0][0]
        assert pipeline_response.http_response == canned_response
    else:
        seen = [hook_args[0].http_response for hook_args, _ in on_response.call_args_list]
        assert seen == [canned_response] * 2
# Example #10
# 0
async def test_cloud_shell_user_assigned_identity():
    """Cloud Shell (only MSI_ENDPOINT set) with a user-assigned identity: client_id goes in the POST body."""

    token_value = "****"
    expiry = 42
    user_assigned_id = "some-guid"
    want = AccessToken(token_value, expiry)
    endpoint = "http://localhost:42/token"
    resource = "scope"

    token_request = Request(
        base_url=endpoint,
        method="POST",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_data={"client_id": user_assigned_id, "resource": resource},
    )
    token_reply = mock_response(
        json_payload={
            "access_token": token_value,
            "expires_in": 0,
            "expires_on": expiry,
            "not_before": int(time.time()),
            "resource": resource,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(requests=[token_request], responses=[token_reply])

    with mock.patch("os.environ", {EnvironmentVariables.MSI_ENDPOINT: endpoint}):
        credential = ManagedIdentityCredential(client_id=user_assigned_id, transport=transport)
        token = await credential.get_token(resource)
        assert token == want
async def test_reads_cloud_settings(cloud, authority):
    """Without an explicit authority/tenant, the credential takes both from the VS Code settings."""

    tenant = "tenant-id"
    settings = {"azure.cloud": cloud, "azure.tenant": tenant}

    # the single token request must target the configured authority and tenant
    expected_url = "https://{}/{}".format(authority, tenant)
    transport = async_validating_transport(
        requests=[Request(base_url=expected_url)],
        responses=[mock_response(json_payload=build_aad_response(access_token="**"))],
    )

    credential = get_credential(settings, transport=transport)

    def fake_get_refresh_token(_):
        return "**"

    with mock.patch(GET_REFRESH_TOKEN, fake_get_refresh_token):
        await credential.get_token("scope")

    assert transport.send.call_count == 1
# Example #12
# 0
async def test_identity_config():
    """ImdsCredential forwards identity_config entries as query parameters of the token request."""
    extra_key, extra_value = "foo", "bar"
    token_value = "****"
    expiry = 42
    want = AccessToken(token_value, expiry)
    resource = "scope"
    user_assigned_id = "some-guid"

    # first request is the probe (URL match only), second is the actual token request
    probe = Request(base_url=IMDS_URL)
    token_request = Request(
        base_url=IMDS_URL,
        method="GET",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_params={"api-version": "2018-02-01", "resource": resource, extra_key: extra_value},
    )
    probe_reply = mock_response(status_code=400, json_payload={"error": "this is an error message"})
    token_reply = mock_response(
        json_payload={
            "access_token": token_value,
            "expires_in": 42,
            "expires_on": expiry,
            "ext_expires_in": 42,
            "not_before": int(time.time()),
            "resource": resource,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(
        requests=[probe, token_request], responses=[probe_reply, token_reply]
    )

    credential = ImdsCredential(
        client_id=user_assigned_id, identity_config={extra_key: extra_value}, transport=transport
    )
    token = await credential.get_token(resource)

    assert token == want
async def test_imds_user_assigned_identity():
    """IMDS with a user-assigned identity: client_id is sent as a query parameter."""
    token_value = "****"
    expiry = 42
    want = AccessToken(token_value, expiry)
    resource = "scope"
    user_assigned_id = "some-guid"
    token_url = IMDS_AUTHORITY + IMDS_TOKEN_PATH

    probe = Request(base_url=token_url)
    token_request = Request(
        base_url=token_url,
        method="GET",
        required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
        required_params={"api-version": "2018-02-01", "client_id": user_assigned_id, "resource": resource},
    )
    # the probe is answered with an error response
    probe_reply = mock_response(status_code=400, json_payload={"error": "this is an error message"})
    token_reply = mock_response(
        json_payload={
            "access_token": token_value,
            "client_id": user_assigned_id,
            "expires_in": 42,
            "expires_on": expiry,
            "ext_expires_in": 42,
            "not_before": int(time.time()),
            "resource": resource,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(
        requests=[probe, token_request], responses=[probe_reply, token_reply]
    )

    # clearing the environment (e.g. $MSI_ENDPOINT) makes the credential pick ImdsCredential
    with mock.patch.dict("os.environ", clear=True):
        credential = ManagedIdentityCredential(client_id=user_assigned_id, transport=transport)
        token = await credential.get_token(resource)
    assert token == want
# Example #14
# 0
async def test_client_secret_environment_credential():
    """EnvironmentCredential authenticates with the client id/secret/tenant read from the environment."""
    app_id = "fake-client-id"
    app_secret = "fake-client-secret"
    tenant = "fake-tenant-id"
    want = "***"

    # the request must go to the tenant and carry the client credentials in its body
    token_request = Request(
        url_substring=tenant,
        required_data={"client_id": app_id, "client_secret": app_secret},
    )
    token_reply = mock_response(
        json_payload={
            "token_type": "Bearer",
            "expires_in": 42,
            "ext_expires_in": 42,
            "access_token": want,
        }
    )
    transport = async_validating_transport(requests=[token_request], responses=[token_reply])

    env = {
        EnvironmentVariables.AZURE_CLIENT_ID: app_id,
        EnvironmentVariables.AZURE_CLIENT_SECRET: app_secret,
        EnvironmentVariables.AZURE_TENANT_ID: tenant,
    }
    with patch("os.environ", env):
        credential = EnvironmentCredential(transport=transport)
        token = await credential.get_token("scope")

    # expires_on is deliberately not checked: doing so would require monkeypatching time
    assert token.token == want
async def test_authentication_record():
    """A credential built from an AuthenticationRecord authenticates the account the record describes."""
    tenant = "tenant-id"
    app_id = "client-id"
    authority = "localhost"
    oid = "object-id"
    home_account_id = "{}.{}".format(oid, tenant)
    username = "******"
    record = AuthenticationRecord(tenant, app_id, authority, home_account_id, username)

    want_access = "****"
    want_refresh = "**"
    cache = populated_cache(
        get_account_event(
            username,
            oid,
            tenant,
            authority=authority,
            client_id=app_id,
            refresh_token=want_refresh,
        )
    )

    transport = async_validating_transport(
        requests=[Request(authority=authority, required_data={"refresh_token": want_refresh})],
        responses=[mock_response(json_payload=build_aad_response(access_token=want_access))],
    )
    credential = SharedTokenCacheCredential(
        _authentication_record=record, transport=transport, _cache=cache
    )

    result = await credential.get_token("scope")
    assert result.token == want_access
async def test_same_tenant_different_usernames():
    """two cached accounts, same tenant, different usernames

    Without a username the credential cannot choose between the two accounts and must raise
    without using the network. With a username it must redeem that account's refresh token.
    """

    access_token_a = "access-token-a"
    access_token_b = "access-token-b"
    refresh_token_a = "refresh-token-a"
    refresh_token_b = "refresh-token-b"

    upn_a = "spam@eggs"
    upn_b = "eggs@spam"
    tenant_id = "the-tenant"
    account_a = get_account_event(username=upn_a, uid="another-guid", utid=tenant_id, refresh_token=refresh_token_a)
    account_b = get_account_event(username=upn_b, uid="more-guid", utid=tenant_id, refresh_token=refresh_token_b)
    cache = populated_cache(account_a, account_b)

    # with no username specified the credential can't select an identity
    transport = Mock(side_effect=Exception())  # (so it shouldn't use the network)
    credential = SharedTokenCacheCredential(tenant_id=tenant_id, _cache=cache, transport=transport)
    with pytest.raises(CredentialUnavailableError) as ex:
        await credential.get_token("scope")

    assert ex.value.message.startswith(MULTIPLE_MATCHING_ACCOUNTS[: MULTIPLE_MATCHING_ACCOUNTS.index("{")])
    assert tenant_id in ex.value.message

    # with a username specified, the credential should auth the matching account.
    # Each account gets its own access token so the assertion can tell the two results apart.
    scope = "scope"
    for username, refresh_token, access_token in (
        (upn_b, refresh_token_b, access_token_b),
        (upn_a, refresh_token_a, access_token_a),
    ):
        transport = async_validating_transport(
            requests=[Request(required_data={"refresh_token": refresh_token, "scope": scope})],
            responses=[mock_response(json_payload=build_aad_response(access_token=access_token))],
        )
        credential = SharedTokenCacheCredential(username=username, _cache=cache, transport=transport)
        token = await credential.get_token(scope)
        assert token.token == access_token
# Example #17
# 0
async def test_azure_arc(tmpdir):
    """Azure Arc 2019-11-01: a 401 challenge names a key file whose contents become Basic auth."""
    token_value = "****"
    api_version = "2019-11-01"
    expiry = 42
    identity_endpoint = "http://localhost:42/token"
    imds_endpoint = "http://localhost:42"
    resource = "scope"
    secret_key = "XXXX"

    key_dir = tmpdir.mkdir("key")
    key_file = key_dir.join("key_file.key")
    key_file.write(secret_key)
    assert key_file.read() == secret_key
    key_path = os.path.join(key_file.dirname, key_file.basename)

    def token_request(headers):
        # both requests hit the same endpoint with the same query; only the headers differ
        return Request(
            base_url=identity_endpoint,
            method="GET",
            required_headers=headers,
            required_params={"api-version": api_version, "resource": resource},
        )

    unauthenticated = token_request({"Metadata": "true"})
    authenticated = token_request({"Metadata": "true", "Authorization": "Basic {}".format(secret_key)})

    # the first response points the credential at the key file via WWW-Authenticate
    challenge = mock_response(
        status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_path)}
    )
    token_reply = mock_response(
        json_payload={
            "access_token": token_value,
            "expires_on": expiry,
            "resource": resource,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(
        requests=[unauthenticated, authenticated], responses=[challenge, token_reply]
    )

    arc_env = {
        EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint,
        EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint,
    }
    with mock.patch("os.environ", arc_env):
        credential = ManagedIdentityCredential(transport=transport)
        token = await credential.get_token(resource)
        assert token.token == token_value
        assert token.expires_on == expiry
# Example #18
# 0
async def test_azure_ml():
    """Azure ML: MSI_ENDPOINT and MSI_SECRET set, like App Service 2017-09-01 but with a different response body."""

    want = AccessToken("****", int(time.time()) + 3600)
    endpoint = "http://localhost:42/token"
    msi_secret = "expected-secret"
    resource = "scope"
    user_assigned_id = "client"

    def token_request(**extra_params):
        # both requests authenticate with the secret header; the second also names a client id
        return Request(
            endpoint,
            method="GET",
            required_headers={"secret": msi_secret, "User-Agent": USER_AGENT},
            required_params=dict({"api-version": "2017-09-01", "resource": resource}, **extra_params),
        )

    expected_requests = [token_request(), token_request(clientid=user_assigned_id)]
    reply = mock_response(
        json_payload={
            "access_token": want.token,
            "expires_in": 3600,
            "expires_on": want.expires_on,
            "resource": resource,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(requests=expected_requests, responses=[reply] * 2)

    ml_env = {
        EnvironmentVariables.MSI_ENDPOINT: endpoint,
        EnvironmentVariables.MSI_SECRET: msi_secret,
    }
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, ml_env, clear=True):
        # first a system-assigned identity, then a user-assigned one
        for credential_kwargs in ({}, {"client_id": user_assigned_id}):
            credential = ManagedIdentityCredential(transport=transport, **credential_kwargs)
            token = await credential.get_token(resource)
            assert token.token == want.token
            assert token.expires_on == want.expires_on
# Example #19
# 0
async def test_same_username_different_tenants():
    """Two cached accounts share a username but live in different tenants.

    A username alone is ambiguous and must raise, listing both accounts. The tenant
    disambiguates, and each tenant's refresh token is redeemed.
    """

    access_token_a = "access-token-a"
    access_token_b = "access-token-b"
    refresh_token_a = "refresh-token-a"
    refresh_token_b = "refresh-token-b"

    upn = "spam@eggs"
    tenant_a = "tenant-a"
    tenant_b = "tenant-b"
    cache = populated_cache(
        get_account_event(username=upn, uid="another-guid", utid=tenant_a, refresh_token=refresh_token_a),
        get_account_event(username=upn, uid="more-guid", utid=tenant_b, refresh_token=refresh_token_b),
    )

    # a transport that raises if used: the ambiguous case must fail before any network call
    offline = Mock(side_effect=Exception())
    ambiguous = SharedTokenCacheCredential(username=upn, _cache=cache, transport=offline)
    with pytest.raises(ClientAuthenticationError) as ex:
        await ambiguous.get_token("scope")

    message = ex.value.message
    # the message starts with the fixed prefix of MULTIPLE_MATCHING_ACCOUNTS and its last line lists the accounts
    assert message.startswith(MULTIPLE_MATCHING_ACCOUNTS[: MULTIPLE_MATCHING_ACCOUNTS.index("{")])
    listed = message.splitlines()[-1]
    assert listed.count(upn) == 2
    assert tenant_a in listed and tenant_b in listed

    scope = "scope"
    for tenant, refresh_token, access_token in (
        (tenant_a, refresh_token_a, access_token_a),
        (tenant_b, refresh_token_b, access_token_b),
    ):
        transport = async_validating_transport(
            requests=[Request(required_data={"refresh_token": refresh_token, "scope": scope})],
            responses=[mock_response(json_payload=build_aad_response(access_token=access_token))],
        )
        credential = SharedTokenCacheCredential(tenant_id=tenant, _cache=cache, transport=transport)
        token = await credential.get_token(scope)
        assert token.token == access_token
# Example #20
# 0
async def test_app_service_user_assigned_identity():
    """App Service 2017-09-01 with a user-assigned identity, with and without identity_config."""

    want = "****"
    expiry = 42
    user_assigned_id = "some-guid"
    endpoint = "http://localhost:42/token"
    msi_secret = "expected-secret"
    resource = "scope"
    extra_key, extra_value = "foo", "bar"

    def token_request(params):
        return Request(
            base_url=endpoint,
            method="GET",
            required_headers={"secret": msi_secret, "User-Agent": USER_AGENT},
            required_params=params,
        )

    # the plain request names the identity via "clientid"; the second must carry the identity_config entry
    expected_requests = [
        token_request({"api-version": "2017-09-01", "clientid": user_assigned_id, "resource": resource}),
        token_request({"api-version": "2017-09-01", "resource": resource, extra_key: extra_value}),
    ]
    reply = mock_response(
        json_payload={
            "access_token": want,
            "expires_on": "01/01/1970 00:00:{} +00:00".format(expiry),
            "resource": resource,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(requests=expected_requests, responses=[reply] * 2)

    app_service_env = {
        EnvironmentVariables.MSI_ENDPOINT: endpoint,
        EnvironmentVariables.MSI_SECRET: msi_secret,
    }
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, app_service_env, clear=True):
        for extra_kwargs in ({}, {"identity_config": {extra_key: extra_value}}):
            credential = ManagedIdentityCredential(
                client_id=user_assigned_id, transport=transport, **extra_kwargs
            )
            token = await credential.get_token(resource)
            assert token.token == want
            assert token.expires_on == expiry
# Example #21
# 0
async def test_app_service_2017_09_01():
    """App Service MSI 2017-09-01 reports expires_on as a platform-dependent date string.

    Two tokens are requested, each through a newly constructed credential. The first mocked
    response carries the Linux-style string and the second the Windows-style string; both
    must parse to the same ``expires_on`` epoch value.
    """

    scope = "scope"
    url = "http://localhost:42/token"
    secret = "expected-secret"
    access_token = "****"
    expires_on = 42
    expected_token = AccessToken(access_token, expires_on)

    def token_response(expires_on_text):
        # every field except expires_on is identical between the two platform formats
        return mock_response(
            json_payload={
                "access_token": access_token,
                "expires_on": expires_on_text,
                "resource": scope,
                "token_type": "Bearer",
            }
        )

    expected_request = Request(
        url,
        method="GET",
        required_headers={"secret": secret, "User-Agent": USER_AGENT},
        required_params={"api-version": "2017-09-01", "resource": scope},
    )
    transport = async_validating_transport(
        requests=[expected_request] * 2,
        responses=[
            token_response("01/01/1970 00:00:{} +00:00".format(expires_on)),  # linux format
            token_response("1/1/1970 12:00:{} AM +00:00".format(expires_on)),  # windows format
        ],
    )

    app_service_environ = {
        EnvironmentVariables.MSI_ENDPOINT: url,
        EnvironmentVariables.MSI_SECRET: secret,
    }
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, app_service_environ, clear=True):
        # one new credential per mocked response (Linux format first, then Windows)
        for _ in range(2):
            credential = ManagedIdentityCredential(transport=transport)
            token = await credential.get_token(scope)
            assert token == expected_token
            assert token.expires_on == expires_on
Example #22
0
async def test_prefers_app_service_2017_09_01():
    """With both App Service API versions configured, the credential must choose 2017-09-01.

    Support for 2019-08-01 was removed because of
    https://github.com/Azure/azure-sdk-for-python/issues/14670; delete this test once that
    support is restored.
    """

    scope = "scope"
    url = "http://localhost:42/token"
    secret = "expected-secret"
    access_token = "****"
    expires_on = 42
    expected_token = AccessToken(access_token, expires_on)

    def token_response(expires_on_text):
        # only expires_on differs between the Linux and Windows payloads
        return mock_response(
            json_payload={
                "access_token": access_token,
                "expires_on": expires_on_text,
                "resource": scope,
                "token_type": "Bearer",
            }
        )

    # every request must use the 2017-09-01 API shape (secret header, api-version param)
    expected_request = Request(
        url,
        method="GET",
        required_headers={"secret": secret, "User-Agent": USER_AGENT},
        required_params={"api-version": "2017-09-01", "resource": scope},
    )
    transport = async_validating_transport(
        requests=[expected_request] * 2,
        responses=[
            token_response("01/01/1970 00:00:{} +00:00".format(expires_on)),  # linux format
            token_response("1/1/1970 12:00:{} AM +00:00".format(expires_on)),  # windows format
        ],
    )

    both_versions_environ = {
        EnvironmentVariables.IDENTITY_ENDPOINT: url,
        EnvironmentVariables.IDENTITY_HEADER: secret,
        EnvironmentVariables.MSI_ENDPOINT: url,
        EnvironmentVariables.MSI_SECRET: secret,
    }
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, both_versions_environ, clear=True):
        for _ in range(2):
            credential = ManagedIdentityCredential(transport=transport)
            token = await credential.get_token(scope)
            assert token == expected_token
            assert token.expires_on == expires_on
Example #23
0
async def test_cloud_shell_user_assigned_identity():
    """Cloud Shell (only MSI_ENDPOINT set): user-assigned identity selection is sent as form data.

    The first credential selects an identity by ``client_id``. The second passes an arbitrary
    ``identity_config`` entry, which must appear in the request body.
    """

    endpoint = "http://localhost:42/token"
    scope = "scope"
    client_id = "some-guid"
    param_name, param_value = "foo", "bar"
    expected_token = "****"
    expires_on = 42

    def cloud_shell_request(required_data):
        # Cloud Shell token requests are POSTs carrying a Metadata header
        return Request(
            base_url=endpoint,
            method="POST",
            required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
            required_data=required_data,
        )

    token_response = mock_response(
        json_payload={
            "access_token": expected_token,
            "expires_in": 0,
            "expires_on": expires_on,
            "not_before": int(time.time()),
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = async_validating_transport(
        requests=[
            cloud_shell_request({"client_id": client_id, "resource": scope}),
            cloud_shell_request({"resource": scope, param_name: param_value}),
        ],
        responses=[token_response] * 2,
    )

    # keyword sets for the two credentials, matched in order to the two expected requests
    credential_kwargs = (
        {"client_id": client_id},
        {"identity_config": {param_name: param_value}},
    )
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: endpoint}, clear=True):
        for kwargs in credential_kwargs:
            credential = ManagedIdentityCredential(transport=transport, **kwargs)
            token = await credential.get_token(scope)
            assert token.token == expected_token
            assert token.expires_on == expires_on
async def test_token_exchange(tmpdir):
    """With AZURE_FEDERATED_TOKEN_FILE set, the credential exchanges the file's token for an access token.

    Covers three configurations:
    - AZURE_CLIENT_ID supplies the client id by default
    - an explicit ``client_id`` keyword overrides AZURE_CLIENT_ID
    - without AZURE_CLIENT_ID, construction raises ValueError unless ``client_id`` is passed
    """
    exchange_token = "exchange-token"
    token_file = tmpdir.join("token")
    token_file.write(exchange_token)

    access_token = "***"
    authority = "https://localhost"
    default_client_id = "default_client_id"
    nondefault_client_id = "non" + default_client_id
    tenant = "tenant_id"
    scope = "scope"

    success_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": 3600,
            "ext_expires_in": 3600,
            "expires_on": int(time.time()) + 3600,
            "not_before": int(time.time()),
            "resource": scope,
            "token_type": "Bearer",
        }
    )

    def exchange_transport(client_id):
        # a transport expecting exactly one client-assertion token request made as ``client_id``
        return async_validating_transport(
            requests=[
                Request(
                    base_url=authority,
                    method="POST",
                    required_data={
                        "client_assertion": exchange_token,
                        "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
                        "client_id": client_id,
                        "grant_type": "client_credentials",
                        "scope": scope,
                    },
                )
            ],
            responses=[success_response],
        )

    mock_environ = {
        EnvironmentVariables.AZURE_AUTHORITY_HOST: authority,
        EnvironmentVariables.AZURE_CLIENT_ID: default_client_id,
        EnvironmentVariables.AZURE_TENANT_ID: tenant,
        EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: token_file.strpath,
    }

    # credential should default to AZURE_CLIENT_ID
    transport = exchange_transport(default_client_id)
    with mock.patch.dict("os.environ", mock_environ, clear=True):
        credential = ManagedIdentityCredential(transport=transport)
        token = await credential.get_token(scope)
    assert token.token == access_token

    # client_id kwarg should override AZURE_CLIENT_ID
    transport = exchange_transport(nondefault_client_id)
    with mock.patch.dict("os.environ", mock_environ, clear=True):
        credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport)
        token = await credential.get_token(scope)
    assert token.token == access_token

    # AZURE_CLIENT_ID may not have a value, in which case client_id is required
    environ_without_client_id = {
        name: value for name, value in mock_environ.items() if name != EnvironmentVariables.AZURE_CLIENT_ID
    }
    transport = exchange_transport(nondefault_client_id)
    with mock.patch.dict("os.environ", environ_without_client_id, clear=True):
        with pytest.raises(ValueError):
            ManagedIdentityCredential()

        credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport)
        token = await credential.get_token(scope)
    assert token.token == access_token
async def test_auth_code_credential():
    """AuthorizationCodeCredential: redeem the code, then serve from cache, then use the refresh token.

    Three get_token calls share one transport:
    1. the authorization code is redeemed (first request)
    2. the cached access token is returned without a new request
    3. once the cached access token is removed, the refresh token is redeemed (second request)
    """
    client_id = "client id"
    tenant_id = "tenant"
    expected_code = "auth code"
    redirect_uri = "https://localhost"
    expected_access_token = "access"
    expected_refresh_token = "refresh"
    expected_scope = "scope"

    auth_response = build_aad_response(access_token=expected_access_token, refresh_token=expected_refresh_token)
    redeem_code_request = Request(
        url_substring=tenant_id,
        required_data={
            "client_id": client_id,
            "code": expected_code,
            "grant_type": "authorization_code",
            "redirect_uri": redirect_uri,
            "scope": expected_scope,
        },
    )
    redeem_refresh_token_request = Request(
        url_substring=tenant_id,
        required_data={
            "client_id": client_id,
            "grant_type": "refresh_token",
            "refresh_token": expected_refresh_token,
            "scope": expected_scope,
        },
    )
    transport = async_validating_transport(
        requests=[redeem_code_request, redeem_refresh_token_request],
        responses=[mock_response(json_payload=auth_response)] * 2,
    )
    cache = msal.TokenCache()

    credential = AuthorizationCodeCredential(
        client_id=client_id,
        tenant_id=tenant_id,
        authorization_code=expected_code,
        redirect_uri=redirect_uri,
        transport=transport,
        cache=cache,
    )

    async def request_token(expected_send_count):
        # fetch a token, then check how many requests the transport has sent so far
        fetched = await credential.get_token(expected_scope)
        assert fetched.token == expected_access_token
        assert transport.send.call_count == expected_send_count

    # first call redeems the auth code
    await request_token(expected_send_count=1)

    # no auth code -> served from the cache, no new request
    await request_token(expected_send_count=1)

    # no auth code, no cached token -> the refresh token is redeemed
    cache.remove_at(cache.find(cache.CredentialType.ACCESS_TOKEN)[0])
    await request_token(expected_send_count=2)