def test_client_id_none():
    """ManagedIdentityCredential must not send a client id when client_id=None."""
    scope = "scope"
    expected_access_token = "****"

    def send(request, **_):
        # The mocked transport rejects any request that carries a client id.
        assert "client_id" not in request.query
        if request.data:
            assert "client_id" not in request.body  # Cloud Shell
        payload = build_aad_response(access_token=expected_access_token, expires_on="42", resource=scope)
        return mock_response(json_payload=(payload))

    def request_token():
        # A fresh credential per call, exactly as each scenario needs.
        transport = mock.Mock(send=send)
        return ManagedIdentityCredential(client_id=None, transport=transport).get_token(scope)

    # IMDS
    assert request_token().token == expected_access_token

    # Cloud Shell
    cloud_shell_environ = {EnvironmentVariables.MSI_ENDPOINT: "https://localhost"}
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, cloud_shell_environ, clear=True):
        cloud_shell_token = request_token()
    assert cloud_shell_token.token == expected_access_token
def test_azure_arc_client_id():
    """Requesting a token with a client id in an Azure Arc environment must fail."""
    arc_environ = {
        EnvironmentVariables.IDENTITY_ENDPOINT: "http://localhost:42/token",
        EnvironmentVariables.IMDS_ENDPOINT: "http://localhost:42",
    }

    # os.environ is replaced wholesale so only the two Arc variables are visible.
    with mock.patch("os.environ", arc_environ):
        arc_credential = ManagedIdentityCredential(client_id="some-guid")
        with pytest.raises(ClientAuthenticationError):
            arc_credential.get_token("scope")
class MsiTokenProvider(TokenProviderBase):
    """
    MSI Token Provider obtains a token from the MSI endpoint
    The args parameter is a dictionary conforming with the ManagedIdentityCredential initializer API arguments
    """

    def __init__(self, kusto_uri: str, msi_args):
        """Store the target URI and the keyword arguments for ManagedIdentityCredential.

        :param kusto_uri: resource the token is requested for; forwarded to the base class.
        :param msi_args: dict of keyword arguments passed verbatim to ManagedIdentityCredential.
        """
        super().__init__(kusto_uri)
        self._msi_args = msi_args
        self._msi_auth_context = None

    @staticmethod
    def name() -> str:
        """Return the provider name, also reported as the "authority" in context()."""
        return "MsiTokenProvider"

    def context(self) -> dict:
        """Return a copy of the MSI arguments with an added "authority" entry."""
        context = self._msi_args.copy()
        context["authority"] = self.name()
        return context

    def _init_impl(self):
        """Create the ManagedIdentityCredential used for later token requests.

        :raises KustoClientError: if the credential cannot be constructed.
        """
        try:
            # Written to the same attribute that __init__ declares (previously a
            # differently named, undeclared attribute was used).
            self._msi_auth_context = ManagedIdentityCredential(**self._msi_args)
        except Exception as e:
            # The message must reference _msi_args; the old code read an attribute that
            # does not exist, which masked the real failure with an AttributeError.
            raise KustoClientError(
                "Failed to initialize MSI ManagedIdentityCredential with [" + str(self._msi_args) + "]\n" + str(e)
            ) from e

    def _get_token_impl(self) -> dict:
        """No separate non-cached acquisition path: always returns None."""
        return None

    def _get_token_from_cache_impl(self) -> dict:
        """Request a token from the credential and return it as a token dict.

        :returns: dict with the MSAL token-type and access-token keys populated.
        :raises KustoClientError: if the token request fails for any reason.
        """
        try:
            msi_token = self._msi_auth_context.get_token(self._kusto_uri)
            return {
                TokenConstants.MSAL_TOKEN_TYPE: TokenConstants.BEARER_TYPE,
                TokenConstants.MSAL_ACCESS_TOKEN: msi_token.token,
            }
        except Exception as e:
            raise KustoClientError(
                "Failed to obtain MSI token for '" + self._kusto_uri + "' with [" + str(self._msi_args) + "]\n" + str(e)
            ) from e
def test_client_id_none_app_service_2017_09_01():
    """client_id=None must be ignored in the App Service 2017-09-01 environment.

    This environment gets its own test because its expires_on value uses a
    date-string format rather than a plain number.
    """
    scope = "scope"
    expected_access_token = "****"
    app_service_environ = {
        EnvironmentVariables.MSI_ENDPOINT: "https://localhost",
        EnvironmentVariables.MSI_SECRET: "secret",
    }

    def send(request, **_):
        # Neither spelling of the client id parameter may appear in the query.
        assert "client_id" not in request.query
        assert "clientid" not in request.query
        payload = build_aad_response(
            access_token=expected_access_token, expires_on="01/01/1970 00:00:42 +00:00", resource=scope
        )
        return mock_response(json_payload=(payload))

    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, app_service_environ, clear=True):
        transport = mock.Mock(send=send)
        token = ManagedIdentityCredential(client_id=None, transport=transport).get_token(scope)

    assert token.token == expected_access_token
def test_app_service_user_assigned_identity():
    """App Service 2017-09-01 (MSI_ENDPOINT and MSI_SECRET set) with a user-assigned identity."""
    scope = "scope"
    secret = "expected-secret"
    endpoint = "http://localhost:42/token"
    client_id = "some-guid"
    expires_on = 42
    expected_token = "****"
    param_name, param_value = "foo", "bar"

    def expected_request(**extra_params):
        # Each expected request gets its own header/param dicts.
        params = {"api-version": "2017-09-01", "clientid": client_id, "resource": scope}
        params.update(extra_params)
        return Request(
            base_url=endpoint,
            method="GET",
            required_headers={"secret": secret, "User-Agent": USER_AGENT},
            required_params=params,
        )

    token_response = mock_response(
        json_payload={
            "access_token": expected_token,
            "expires_on": "01/01/1970 00:00:{} +00:00".format(expires_on),
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = validating_transport(
        requests=[expected_request(), expected_request(**{param_name: param_value})],
        responses=[token_response] * 2,
    )

    app_service_environ = {EnvironmentVariables.MSI_ENDPOINT: endpoint, EnvironmentVariables.MSI_SECRET: secret}
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, app_service_environ, clear=True):
        # First request: the client id alone.
        token = ManagedIdentityCredential(client_id=client_id, transport=transport).get_token(scope)
        assert token.token == expected_token
        assert token.expires_on == expires_on

        # Second request: the client id plus an extra identity_config parameter.
        configured = ManagedIdentityCredential(
            client_id=client_id, transport=transport, identity_config={param_name: param_value}
        )
        token = configured.get_token(scope)
        assert token.token == expected_token
        assert token.expires_on == expires_on
def test_custom_hooks(environ):
    """Raw request/response hooks passed to the credential must be invoked by its pipeline."""
    scope = "scope"
    expected_token = "***"
    now = int(time.time())
    request_hook = mock.Mock()
    response_hook = mock.Mock()

    payload = {
        "access_token": expected_token,
        "expires_in": 3600,
        "expires_on": now + 3600,
        "ext_expires_in": 3600,
        "not_before": now,
        "resource": scope,
        "token_type": "Bearer",
    }
    expected_response = mock_response(json_payload=payload)
    transport = validating_transport(requests=[Request()] * 2, responses=[expected_response] * 2)

    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, environ, clear=True):
        credential = ManagedIdentityCredential(
            transport=transport, raw_request_hook=request_hook, raw_response_hook=response_hook
        )
        credential.get_token(scope)

    if not environ:
        # we're mocking IMDS and should expect 2 requests
        assert request_hook.call_count == 2
        assert response_hook.call_count == 2
        seen = [positional[0].http_response for positional, _ in response_hook.call_args_list]
        assert seen == [expected_response] * 2
        return

    # some environment variables are set, so we're not mocking IMDS and should expect 1 request
    assert request_hook.call_count == 1
    assert response_hook.call_count == 1
    positional, _ = response_hook.call_args
    assert positional[0].http_response == expected_response
def test_client_id_none():
    """client_id=None must never reach the wire, whichever environment is detected."""
    expected_access_token = "****"

    def send(request, **_):
        assert "client_id" not in request.query  # IMDS
        assert "clientid" not in request.query  # App Service 2017-09-01
        if request.data:
            assert "client_id" not in request.body  # Cloud Shell
        payload = build_aad_response(access_token=expected_access_token)
        return mock_response(json_payload=(payload))

    def request_token():
        transport = mock.Mock(send=send)
        return ManagedIdentityCredential(client_id=None, transport=transport).get_token("scope")

    # No managed identity variables patched in.
    assert request_token().token == expected_access_token

    # MSI_ENDPOINT together with MSI_SECRET.
    endpoint_and_secret = {
        EnvironmentVariables.MSI_ENDPOINT: "https://localhost",
        EnvironmentVariables.MSI_SECRET: "secret",
    }
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, endpoint_and_secret, clear=True):
        token = request_token()
    assert token.token == expected_access_token

    # MSI_ENDPOINT only.
    endpoint_only = {EnvironmentVariables.MSI_ENDPOINT: "https://localhost"}
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, endpoint_only, clear=True):
        token = request_token()
    assert token.token == expected_access_token
class MsiTokenProvider(TokenProviderBase):
    """
    MSI Token Provider obtains a token from the MSI endpoint
    The args parameter is a dictionary conforming with the ManagedIdentityCredential initializer API arguments
    """

    def __init__(self, kusto_uri: str, msi_args: dict = None):
        """Store the target URI and the keyword arguments for the managed identity credentials.

        :param kusto_uri: resource the token is requested for; forwarded to the base class.
        :param msi_args: keyword arguments passed verbatim to the sync and async
            ManagedIdentityCredential constructors. ``None`` means "no arguments".
        """
        super().__init__(kusto_uri)
        # The default is None, but the dict is later copied and **-expanded;
        # normalise it here so the default value is actually usable.
        self._msi_args = msi_args if msi_args is not None else {}
        # Credentials are created lazily on the first token request.
        self._msi_auth_context = None
        self._msi_auth_context_async = None

    @staticmethod
    def name() -> str:
        """Return the provider name, also reported as the "authority" in context()."""
        return "MsiTokenProvider"

    def context(self) -> dict:
        """Return a copy of the MSI arguments with an added "authority" entry."""
        context = self._msi_args.copy()
        context["authority"] = self.name()
        return context

    def _init_impl(self):
        """Nothing to initialise eagerly; credentials are built on first use."""
        pass

    def _get_token_impl(self) -> dict:
        """Obtain a token synchronously, creating the credential on first use.

        :returns: dict with the MSAL token-type and access-token keys populated.
        :raises KustoClientError: on any failure creating the credential or requesting the token.
        """
        try:
            if self._msi_auth_context is None:
                self._msi_auth_context = ManagedIdentityCredential(**self._msi_args)

            msi_token = self._msi_auth_context.get_token(self._kusto_uri)
            return {
                TokenConstants.MSAL_TOKEN_TYPE: TokenConstants.BEARER_TYPE,
                TokenConstants.MSAL_ACCESS_TOKEN: msi_token.token,
            }
        except ClientAuthenticationError as e:
            raise KustoClientError(
                "Failed to initialize MSI ManagedIdentityCredential with [{0}]\n{1}".format(self._msi_args, e)
            ) from e
        except Exception as e:
            raise KustoClientError(
                "Failed to obtain MSI token for '{0}' with [{1}]\n{2}".format(self._kusto_uri, self._msi_args, e)
            ) from e

    async def _get_token_impl_async(self) -> Optional[dict]:
        """Obtain a token asynchronously, creating the async credential on first use.

        :returns: dict with the MSAL token-type and access-token keys populated.
        :raises KustoClientError: on any failure creating the credential or requesting the token.
        """
        try:
            if self._msi_auth_context_async is None:
                self._msi_auth_context_async = AsyncManagedIdentityCredential(**self._msi_args)

            msi_token = await self._msi_auth_context_async.get_token(self._kusto_uri)
            return {
                TokenConstants.MSAL_TOKEN_TYPE: TokenConstants.BEARER_TYPE,
                TokenConstants.MSAL_ACCESS_TOKEN: msi_token.token,
            }
        except ClientAuthenticationError as e:
            raise KustoClientError(
                "Failed to initialize MSI async ManagedIdentityCredential with [{0}]\n{1}".format(self._msi_args, e)
            ) from e
        except Exception as e:
            raise KustoClientError(
                "Failed to obtain MSI token for '{0}' with [{1}]\n{2}".format(self._kusto_uri, self._msi_args, e)
            ) from e

    def _get_token_from_cache_impl(self) -> dict:
        """This provider keeps no cache of its own: always returns None."""
        return None
def test_cloud_shell_user_assigned_identity():
    """Cloud Shell (only MSI_ENDPOINT set) with a client id, then with identity_config."""
    scope = "scope"
    endpoint = "http://localhost:42/token"
    client_id = "some-guid"
    expires_on = 42
    expected_token = "****"
    param_name, param_value = "foo", "bar"

    def expected_post(form_data):
        # Both expected requests differ only in the form data they must carry.
        return Request(
            base_url=endpoint,
            method="POST",
            required_headers={"Metadata": "true", "User-Agent": USER_AGENT},
            required_data=form_data,
        )

    token_response = mock_response(
        json_payload={
            "access_token": expected_token,
            "expires_in": 0,
            "expires_on": expires_on,
            "not_before": int(time.time()),
            "resource": scope,
            "token_type": "Bearer",
        }
    )
    transport = validating_transport(
        requests=[
            expected_post({"client_id": client_id, "resource": scope}),
            expected_post({"resource": scope, param_name: param_value}),
        ],
        responses=[token_response] * 2,
    )

    cloud_shell_environ = {EnvironmentVariables.MSI_ENDPOINT: endpoint}
    with mock.patch.dict(MANAGED_IDENTITY_ENVIRON, cloud_shell_environ, clear=True):
        # First request: identified by client id.
        token = ManagedIdentityCredential(client_id=client_id, transport=transport).get_token(scope)
        assert token.token == expected_token
        assert token.expires_on == expires_on

        # Second request: no client id, only the extra identity_config parameter.
        configured = ManagedIdentityCredential(transport=transport, identity_config={param_name: param_value})
        token = configured.get_token(scope)
        assert token.token == expected_token
        assert token.expires_on == expires_on
def test_token_exchange_tenant_id(tmpdir):
    """get_token accepts a tenant_id when the credential is configured for token exchange."""
    scope = "scope"
    tenant = "tenant_id"
    default_client_id = "default_client_id"
    authority = "https://localhost"
    access_token = "***"
    exchange_token = "exchange-token"

    # The federated token is read from a file named by an environment variable.
    token_file = tmpdir.join("token")
    token_file.write(exchange_token)

    issued_at = int(time.time())
    success_response = mock_response(
        json_payload={
            "access_token": access_token,
            "expires_in": 3600,
            "ext_expires_in": 3600,
            "expires_on": issued_at + 3600,
            "not_before": issued_at,
            "resource": scope,
            "token_type": "Bearer",
        }
    )

    expected_form = {
        "client_assertion": exchange_token,
        "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
        "client_id": default_client_id,
        "grant_type": "client_credentials",
        "scope": scope,
    }
    expected_request = Request(base_url=authority, method="POST", required_data=expected_form)
    transport = validating_transport(requests=[expected_request], responses=[success_response])

    mock_environ = {
        EnvironmentVariables.AZURE_AUTHORITY_HOST: authority,
        EnvironmentVariables.AZURE_CLIENT_ID: default_client_id,
        EnvironmentVariables.AZURE_TENANT_ID: tenant,
        EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: token_file.strpath,
    }
    with mock.patch.dict("os.environ", mock_environ, clear=True):
        credential = ManagedIdentityCredential(transport=transport)
        token = credential.get_token(scope, tenant_id="tenant_id")
        assert token.token == access_token
def test_cloud_shell_live(cloud_shell):
    """Acquire a Key Vault token with the default credential and use it against a real vault."""
    token = ManagedIdentityCredential().get_token("https://vault.azure.net")

    # Validate the token by sending a request to the Key Vault. The request is manual because azure-keyvault-secrets
    # can't authenticate in Cloud Shell; the MSI endpoint there doesn't support AADv2 scopes.
    client = PipelineClient(
        cloud_shell["vault_url"],
        policies=[ContentDecodePolicy(), RedirectPolicy(), RetryPolicy(), HttpLoggingPolicy()],
    )
    auth_headers = {"Authorization": "Bearer " + token.token}
    list_secrets = client.get("secrets", headers=auth_headers, params={"api-version": "7.0"})

    with client:
        client._pipeline.run(list_secrets)
def main(req: func.HttpRequest) -> func.HttpResponse:
    """HTTP trigger: list the MySQL server's databases using a managed-identity token as password.

    :param req: incoming HTTP request (its contents are not inspected).
    :returns: 200 response whose body has one line per result row, columns joined by ", ".
    """
    logging.info('Python HTTP trigger function processed a request.')

    # Only the credential actually used is imported; the other names were unused.
    from azure.identity import ManagedIdentityCredential

    managed_identity = ManagedIdentityCredential()
    # NOTE(review): Azure Database for MySQL AAD logins normally expect a token for the
    # "ossrdbms-aad" resource rather than the management endpoint - confirm the scope.
    scope = "https://management.azure.com"
    token = managed_identity.get_token(scope)
    access_token = token.token

    crtpath = 'BaltimoreCyberTrustRoot.crt.pem'
    # Alternative root certificate: 'DigiCertGlobalRootCA.crt.pem'

    # Connect to MySQL
    cnx = mysql.connector.connect(
        user="******",
        password=access_token,
        host="mysqldevSUFFIXflex.mysql.database.azure.com",
        port=3306,
        ssl_ca=crtpath,
        tls_versions=['TLSv1.2'],
    )
    try:
        logging.info(cnx)

        # Show databases
        cursor = cnx.cursor()
        try:
            cursor.execute("SHOW DATABASES")
            result_list = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        # Always release the connection, even when the query fails; previously it leaked
        # on every invocation.
        cnx.close()

    # Build result response text
    result_str = '\n'.join(', '.join(str(v) for v in row) for row in result_list)

    return func.HttpResponse(
        result_str,
        status_code=200,
    )
def _get_managed_identity_credential(self):
    """Return the raw access token string for Azure Storage via managed identity.

    The credential is built with ``self.aad_application_id`` as its client id.
    """
    storage_scope = 'https://storage.azure.com/.default'
    credential = ManagedIdentityCredential(client_id=self.aad_application_id)
    access_token = credential.get_token(storage_scope)
    return access_token.token
class _AadHelper:
    """Selects an authentication method from a connection string builder and produces
    the authorization header value for requests to the Kusto endpoint."""

    # Client id used for the username/password and device-login flows.
    _DEFAULT_CLIENT_ID = "db662dc1-0cfe-4e1c-a843-19a68e65be58"

    authentication_method = None
    auth_context = None
    msi_auth_context = None
    username = None
    kusto_uri = None
    authority_uri = None
    client_id = None
    client_secret = None
    password = None
    thumbprint = None
    private_certificate = None
    public_certificate = None
    msi_params = None
    token = None
    token_provider = None

    def __init__(self, kcsb: "KustoConnectionStringBuilder"):
        """Pick the authentication method from whichever fields of ``kcsb`` are populated.

        Branch order defines precedence. The MSI, caller-token and az-cli branches return
        early and therefore never compute ``authority_uri``.
        """
        self.kusto_uri = "{0.scheme}://{0.hostname}".format(urlparse(kcsb.data_source))
        self.username = None

        if all([kcsb.aad_user_id, kcsb.password]):
            self.authentication_method = AuthenticationMethod.aad_username_password
            self.client_id = self._DEFAULT_CLIENT_ID
            self.username = kcsb.aad_user_id
            self.password = kcsb.password
        elif all([kcsb.application_client_id, kcsb.application_key]):
            self.authentication_method = AuthenticationMethod.aad_application_key
            self.client_id = kcsb.application_client_id
            self.client_secret = kcsb.application_key
        elif all([
            kcsb.application_client_id,
            kcsb.application_certificate,
            kcsb.application_certificate_thumbprint,
        ]):
            self.client_id = kcsb.application_client_id
            self.private_certificate = kcsb.application_certificate
            self.thumbprint = kcsb.application_certificate_thumbprint
            # A public certificate switches the flow to the SNI variant.
            if kcsb.application_public_certificate:
                self.public_certificate = kcsb.application_public_certificate
                self.authentication_method = AuthenticationMethod.aad_application_certificate_sni
            else:
                self.authentication_method = AuthenticationMethod.aad_application_certificate
        elif kcsb.msi_authentication:
            self.authentication_method = AuthenticationMethod.managed_service_identity
            self.msi_params = kcsb.msi_parameters
            return
        elif any([kcsb.user_token, kcsb.application_token]):
            self.token = kcsb.user_token or kcsb.application_token
            self.authentication_method = AuthenticationMethod.aad_token
            return
        elif kcsb.az_cli:
            self.authentication_method = AuthenticationMethod.az_cli_profile
            return
        elif kcsb.token_provider:
            self.authentication_method = AuthenticationMethod.token_provider
            self.token_provider = kcsb.token_provider
        else:
            self.authentication_method = AuthenticationMethod.aad_device_login
            self.client_id = self._DEFAULT_CLIENT_ID

        authority = kcsb.authority_id or "common"
        aad_authority_uri = os.environ.get("AadAuthorityUri", CLOUD_LOGIN_URL)
        # Join base URI and authority with exactly one slash.
        separator = "" if aad_authority_uri.endswith("/") else "/"
        self.authority_uri = aad_authority_uri + separator + authority

    def acquire_authorization_header(self):
        """Acquire tokens from AAD.

        :returns: the authorization header value.
        :raises KustoAuthenticationError: wrapping AdalError/KustoClientError with details of
            the method in use. For methods without a detail mapping (caller-provided token,
            az-cli profile) the original error is re-raised unchanged.
        """
        try:
            return self._acquire_authorization_header()
        except (AdalError, KustoClientError) as error:
            certificate_methods = (
                AuthenticationMethod.aad_application_certificate,
                AuthenticationMethod.aad_application_certificate_sni,
            )
            if self.authentication_method is AuthenticationMethod.aad_username_password:
                kwargs = {"username": self.username, "client_id": self.client_id}
            elif self.authentication_method is AuthenticationMethod.aad_application_key:
                kwargs = {"client_id": self.client_id}
            elif self.authentication_method is AuthenticationMethod.aad_device_login:
                kwargs = {"client_id": self.client_id}
            elif self.authentication_method in certificate_methods:
                kwargs = {"client_id": self.client_id, "thumbprint": self.thumbprint}
            elif self.authentication_method is AuthenticationMethod.managed_service_identity:
                # Copy: the entries added below must not leak into msi_params, which is
                # later expanded into the ManagedIdentityCredential constructor.
                kwargs = dict(self.msi_params or {})
            elif self.authentication_method is AuthenticationMethod.token_provider:
                kwargs = {}
            else:
                raise

            kwargs["resource"] = self.kusto_uri

            if self.authentication_method is AuthenticationMethod.managed_service_identity:
                kwargs["authority"] = AuthenticationMethod.managed_service_identity.value
            elif self.authentication_method is AuthenticationMethod.token_provider:
                kwargs["authority"] = AuthenticationMethod.token_provider.value
            elif self.auth_context is not None:
                kwargs["authority"] = self.auth_context.authority.url
            elif self.authentication_method is AuthenticationMethod.az_cli_profile:
                # NOTE(review): currently unreachable - az_cli_profile re-raises above.
                kwargs["authority"] = AuthenticationMethod.az_cli_profile.value

            raise KustoAuthenticationError(self.authentication_method.value, error, **kwargs) from error

    def _acquire_authorization_header(self) -> str:
        """Produce the authorization header for the configured authentication method.

        Caller-supplied tokens, token providers and MSI are handled directly. The remaining
        methods go through the ADAL authentication context: a previously acquired token is
        tried first, then its refresh token, then a fresh acquisition.

        :raises KustoClientError: if the token provider returns a non-string, or the method
            is not one of the supported ones.
        """
        # Token was provided by caller
        if self.authentication_method is AuthenticationMethod.aad_token:
            return _get_header("Bearer", self.token)

        if self.authentication_method is AuthenticationMethod.token_provider:
            caller_token = self.token_provider()
            if not isinstance(caller_token, str):
                raise KustoClientError(
                    "Token provider returned something that is not a string [" + str(type(caller_token)) + "]"
                )
            return _get_header("Bearer", caller_token)

        # Obtain token from MSI endpoint
        if self.authentication_method == AuthenticationMethod.managed_service_identity:
            msi_token = self.get_token_from_msi()
            return _get_header("Bearer", msi_token.token)

        refresh_token = None

        if self.authentication_method == AuthenticationMethod.az_cli_profile:
            stored_token = _get_azure_cli_auth_token()
            if (
                TokenResponseFields.REFRESH_TOKEN in stored_token
                and TokenResponseFields._CLIENT_ID in stored_token
                and TokenResponseFields._AUTHORITY in stored_token
            ):
                self.client_id = stored_token[TokenResponseFields._CLIENT_ID]
                self.username = stored_token[TokenResponseFields.USER_ID]
                self.authority_uri = stored_token[TokenResponseFields._AUTHORITY]
                refresh_token = stored_token[TokenResponseFields.REFRESH_TOKEN]

        if self.auth_context is None:
            self.auth_context = AuthenticationContext(self.authority_uri)

        if refresh_token is not None:
            token = self.auth_context.acquire_token_with_refresh_token(
                refresh_token, self.client_id, self.kusto_uri
            )
        else:
            token = self.auth_context.acquire_token(self.kusto_uri, self.username, self.client_id)

        if token is not None:
            expiration_date = dateutil.parser.parse(token[TokenResponseFields.EXPIRES_ON])
            # Only reuse the token if it stays valid for at least another minute.
            if expiration_date > datetime.now() + timedelta(minutes=1):
                return _get_header_from_dict(token)
            if TokenResponseFields.REFRESH_TOKEN in token:
                token = self.auth_context.acquire_token_with_refresh_token(
                    token[TokenResponseFields.REFRESH_TOKEN], self.client_id, self.kusto_uri
                )
                if token is not None:
                    return _get_header_from_dict(token)

        # obtain token from AAD
        if self.authentication_method is AuthenticationMethod.aad_username_password:
            token = self.auth_context.acquire_token_with_username_password(
                self.kusto_uri, self.username, self.password, self.client_id
            )
        elif self.authentication_method is AuthenticationMethod.aad_application_key:
            token = self.auth_context.acquire_token_with_client_credentials(
                self.kusto_uri, self.client_id, self.client_secret
            )
        elif self.authentication_method is AuthenticationMethod.aad_device_login:
            code = self.auth_context.acquire_user_code(self.kusto_uri, self.client_id)
            print(code[OAuth2DeviceCodeResponseParameters.MESSAGE])
            webbrowser.open(code[OAuth2DeviceCodeResponseParameters.VERIFICATION_URL])
            token = self.auth_context.acquire_token_with_device_code(self.kusto_uri, code, self.client_id)
        elif self.authentication_method in (
            AuthenticationMethod.aad_application_certificate,
            # The SNI variant was missing (the plain method was listed twice), so SNI
            # authentication always fell through to the error below.
            AuthenticationMethod.aad_application_certificate_sni,
        ):
            token = self.auth_context.acquire_token_with_client_certificate(
                self.kusto_uri, self.client_id, self.private_certificate, self.thumbprint, self.public_certificate
            )
        else:
            raise KustoClientError(
                "Please choose authentication method from azure.kusto.data.security.AuthenticationMethod"
            )

        return _get_header_from_dict(token)

    def get_token_from_msi(self) -> "AccessToken":
        """Return a token from the managed identity credential, creating it on first use.

        :raises KustoClientError: if creating the credential or requesting the token fails.
        """
        try:
            if self.msi_auth_context is None:
                # Create the MSI Authentication object
                self.msi_auth_context = ManagedIdentityCredential(**self.msi_params)

            return self.msi_auth_context.get_token(self.kusto_uri)
        except Exception as e:
            raise KustoClientError(
                "Failed to obtain MSI context for [" + str(self.msi_params) + "]\n" + str(e)
            ) from e
def main(req: func.HttpRequest) -> func.HttpResponse:
    """HTTP trigger: start the "wait" Data Factory pipeline and report its final status.

    ``PipelinePartType`` is read from the query string, falling back to the JSON body.
    Only ``"ADF"`` is handled; the whole JSON body is passed to the pipeline as its
    ``JsonDefinition`` parameter.

    :param req: incoming HTTP request.
    :returns: 200 with the pipeline run status text, or 400 with an error message.
    """
    try:
        logging.info('Python HTTP trigger function processed a request.')

        # The body is optional. It is always parsed (not only when the query parameter is
        # missing) because it is forwarded to the pipeline below; previously a
        # query-string-only request left req_body as "" and crashed on .get().
        try:
            req_body = req.get_json()
        except ValueError:
            req_body = {}
        if not isinstance(req_body, dict):
            req_body = {}

        PipelinePartType = req.params.get('PipelinePartType') or req_body.get('PipelinePartType')
        logging.debug("Request params: %s", req.params)
        logging.debug("Request body: %s", req_body)
        logging.info(f"Processing PipelinePartType: {PipelinePartType}")

        # Reject anything this function cannot process before touching credentials.
        if PipelinePartType != "ADF":
            return func.HttpResponse("Parameters not configured correctly.", status_code=400)

        # load environmental variables
        load_dotenv()
        SubscriptionID = os.getenv("SubscriptionID")
        ResourceGroupName = os.getenv("ResourceGroupName")
        DataFactoryName = os.getenv("DataFactoryName")
        ClientID = os.getenv("ClientID")
        SecretKey = os.getenv("SecretKey")
        TenantID = os.getenv("TenantID")

        credential = ManagedIdentityCredential()
        try:
            # Token check. get_token needs at least one scope; without it the call could
            # never succeed and the managed identity path was dead.
            credential.get_token("https://management.azure.com/.default")
        except Exception as identity_error:
            if not ClientID:
                raise Exception("Managed identity or environment variables are needed") from identity_error
            credentials = ServicePrincipalCredentials(client_id=ClientID, secret=SecretKey, tenant=TenantID)
        else:
            # Previously nothing was assigned on this path, causing a NameError below.
            # NOTE(review): assumes the installed management SDK accepts azure-identity
            # credentials - confirm against the pinned azure-mgmt-datafactory version.
            credentials = credential

        # Azure clients
        adf_client = DataFactoryManagementClient(credentials, SubscriptionID)

        ##
        ## Azure Data Factory Objects
        ##
        logging.info(f"Executing the ADF Pipeline: {PipelinePartType}")
        JsonDefinition = req_body
        run_response = adf_client.pipelines.create_run(
            ResourceGroupName, DataFactoryName, "wait", parameters={"JsonDefinition": JsonDefinition}
        )

        # Poll until the run leaves its non-terminal states. A freshly created run may
        # still be "Queued", so that state must be waited on as well.
        status = adf_client.pipeline_runs.get(ResourceGroupName, DataFactoryName, run_response.run_id).status
        while status in ("Queued", "InProgress"):
            time.sleep(2)
            status = adf_client.pipeline_runs.get(ResourceGroupName, DataFactoryName, run_response.run_id).status

        return func.HttpResponse(status)
    except Exception as e:
        logging.exception("Pipeline request failed")
        # HttpResponse needs text, not the exception object itself.
        return func.HttpResponse(str(e), status_code=400)