Esempio n. 1
0
    def test_get_session_returns_a_boto3_session(self):
        """The session handed out by the hook can be used like a boto3 session.

        A DynamoDB resource is built from the hook's session, a table is
        created through it, and the new table's ``item_count`` is asserted
        to be 0.
        """
        table_name = 'test_airflow'
        key_schema = [{'AttributeName': 'id', 'KeyType': 'HASH'}]
        attribute_definitions = [{'AttributeName': 'id', 'AttributeType': 'S'}]
        throughput = {'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10}

        session = AwsHook(aws_conn_id='aws_default').get_session()
        dynamodb = session.resource('dynamodb')
        table = dynamodb.create_table(  # pylint: disable=no-member
            TableName=table_name,
            KeySchema=key_schema,
            AttributeDefinitions=attribute_definitions,
            ProvisionedThroughput=throughput,
        )

        # Wait on the 'table_exists' waiter before inspecting the table.
        waiter = table.meta.client.get_waiter('table_exists')
        waiter.wait(TableName=table_name)

        self.assertEqual(table.item_count, 0)
Esempio n. 2
0
    def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
        """A 'redshift' client obtained from the hook reports two clusters.

        ``self._create_clusters()`` runs first; the test then expects
        ``describe_clusters`` on the hook's client to list exactly two entries.
        """
        self._create_clusters()

        redshift_client = AwsHook(aws_conn_id='aws_default').get_client_type('redshift')
        response = redshift_client.describe_clusters()

        self.assertEqual(len(response['Clusters']), 2)
Esempio n. 3
0
    def get_iam_token(self, conn):
        """
        Obtain temporary IAM-based credentials for a Postgres or Redshift connection.

        The AWS connection id is read from the ``aws_conn_id`` extra (default
        ``aws_default``). When the ``redshift`` extra is truthy the credentials
        come from the Redshift ``get_cluster_credentials`` call; otherwise an
        RDS auth token is generated. When ``conn.port`` is ``None`` the port
        falls back to 5439 for Redshift and 5432 otherwise.

        :param conn: connection supplying host, login, port, schema and extras
        :return: tuple ``(login, token, port)``
        """
        from airflow.providers.amazon.aws.hooks.aws_hook import AwsHook

        use_redshift = conn.extra_dejson.get('redshift', False)
        hook = AwsHook(conn.extra_dejson.get('aws_conn_id', 'aws_default'))
        login = conn.login

        port = conn.port
        if port is None:
            port = 5439 if use_redshift else 5432

        if not use_redshift:
            rds_client = hook.get_client_type('rds')
            token = rds_client.generate_db_auth_token(conn.host, port, conn.login)
            return login, token, port

        # Unless the 'cluster-identifier' extra overrides it, the cluster identifier
        # is the first label of the Redshift host name, e.g.
        # my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com -> my-cluster
        cluster_identifier = conn.extra_dejson.get('cluster-identifier',
                                                   conn.host.split('.')[0])
        redshift_client = hook.get_client_type('redshift')
        credentials = redshift_client.get_cluster_credentials(
            DbUser=conn.login,
            DbName=self.schema or conn.schema,
            ClusterIdentifier=cluster_identifier,
            AutoCreate=False,
        )
        token = credentials['DbPassword']
        # Return the DbUser from the credentials response rather than conn.login.
        login = credentials['DbUser']
        return login, token, port
Esempio n. 4
0
 def expand_role(self):
     """Expand the ``ExecutionRoleArn`` entry of the ``Model`` config in place.

     Nothing happens when ``self.config`` has no ``Model`` section or that
     section has no ``ExecutionRoleArn`` key. Otherwise the current value is
     passed through ``AwsHook.expand_role`` and the result is written back.
     """
     if 'Model' in self.config and 'ExecutionRoleArn' in self.config['Model']:
         model_config = self.config['Model']
         aws_hook = AwsHook(self.aws_conn_id)
         model_config['ExecutionRoleArn'] = aws_hook.expand_role(
             model_config['ExecutionRoleArn'])
Esempio n. 5
0
 def test_expand_role(self):
     """``expand_role`` maps a role name to the ARN that IAM reports for it."""
     role_name = 'test-role'
     iam_client = boto3.client('iam', region_name='us-east-1')
     iam_client.create_role(RoleName=role_name,
                            AssumeRolePolicyDocument='some policy')

     actual_arn = AwsHook().expand_role(role_name)

     role_description = iam_client.get_role(RoleName=role_name)
     expected_arn = role_description.get('Role').get('Arn')
     self.assertEqual(actual_arn, expected_arn)
Esempio n. 6
0
    def hook(self) -> AwsHook:
        """
        Return the AWS hook for this object, creating it on first use.

        The hook is built from ``self.aws_conn_id`` and kept in ``self._hook``,
        so later calls return the same instance.

        :return: the stored hook
        :rtype: AwsHook
        """
        if self._hook is not None:
            return self._hook
        self._hook = AwsHook(aws_conn_id=self.aws_conn_id)
        return self._hook
Esempio n. 7
0
    def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
        """An 'emr' client obtained from the hook lists no clusters.

        A directly created boto3 client is checked first; if it already
        reports clusters the test aborts with a ``ValueError``.
        """
        direct_client = boto3.client('emr', region_name='us-east-1')
        existing_clusters = direct_client.list_clusters()['Clusters']
        if existing_clusters:
            raise ValueError('AWS not properly mocked')

        hook_client = AwsHook(aws_conn_id='aws_default').get_client_type('emr')
        clusters_from_hook = hook_client.list_clusters()['Clusters']

        self.assertEqual(clusters_from_hook, [])
Esempio n. 8
0
 def _inject_aws_credentials(self):
     """Write AWS access keys into the S3 data source of the transfer spec.

     Acts only when ``self.body`` contains ``TRANSFER_SPEC`` and that spec
     contains ``AWS_S3_DATA_SOURCE``. The keys are taken from the credentials
     returned by an ``AwsHook`` built with ``self.aws_conn_id``.
     """
     if TRANSFER_SPEC not in self.body:
         return
     if AWS_S3_DATA_SOURCE not in self.body[TRANSFER_SPEC]:
         return

     credentials = AwsHook(self.aws_conn_id).get_credentials()
     self.body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY] = {
         ACCESS_KEY_ID: credentials.access_key,
         SECRET_ACCESS_KEY: credentials.secret_key,
     }
Esempio n. 9
0
 def test_get_credentials_from_extra_with_s3_config_and_profile(
         self, mock_get_connection, mock_parse_s3_config):
     """The s3 config file, format and profile from the extras reach the parser.

     The connection extra names an s3 config file, its format and a profile;
     the mocked parser must be called exactly once with those three values.
     """
     extra = ('{"s3_config_format": "aws", '
              '"profile": "test", '
              '"s3_config_file": "aws-credentials", '
              '"region_name": "us-east-1"}')
     mock_get_connection.return_value = Connection(extra=extra)

     AwsHook()._get_credentials(region_name=None)

     mock_parse_s3_config.assert_called_once_with(
         'aws-credentials', 'aws', 'test')
Esempio n. 10
0
 def test_get_credentials_from_extra_without_token(self, mock_get_connection):
     """Access key and secret key given in the extras are returned; the token is None."""
     extra = ('{"aws_access_key_id": "aws_access_key_id",'
              '"aws_secret_access_key": "aws_secret_access_key"}')
     mock_get_connection.return_value = Connection(extra=extra)

     credentials = AwsHook().get_credentials()

     self.assertEqual(credentials.access_key, 'aws_access_key_id')
     self.assertEqual(credentials.secret_key, 'aws_secret_access_key')
     self.assertIsNone(credentials.token)
Esempio n. 11
0
 def test_get_credentials_from_login_with_token(self, mock_get_connection):
     """Login/password supply the key pair and the extra supplies the session token.

     The mocked connection carries the access key id as its login, the secret
     access key as its password, and an ``aws_session_token`` extra. The
     credentials returned by the hook must expose exactly those three values.
     """
     # The fixture values must equal the values asserted below; masked
     # placeholders such as '******' cannot satisfy the assertions.
     mock_connection = Connection(
         login='aws_access_key_id',
         password='aws_secret_access_key',
         extra='{"aws_session_token": "test_token"}')
     mock_get_connection.return_value = mock_connection

     hook = AwsHook()
     credentials_from_hook = hook.get_credentials()

     self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
     self.assertEqual(credentials_from_hook.secret_key,
                      'aws_secret_access_key')
     self.assertEqual(credentials_from_hook.token, 'test_token')
Esempio n. 12
0
    def test_get_credentials_from_role_arn(self, mock_get_connection):
        """A ``role_arn`` extra yields credentials with an "ASIA" key and expected lengths."""
        role_extra = '{"role_arn":"arn:aws:iam::123456:role/role_arn"}'
        mock_get_connection.return_value = Connection(extra=role_extra)

        credentials = AwsHook().get_credentials()

        self.assertIn("ASIA", credentials.access_key)

        # The values themselves are random, so only their lengths are asserted.
        # Details: https://github.com/spulec/moto/commit/ab0d23a0ba2506e6338ae20b3fde70da049f7b03
        self.assertEqual(20, len(credentials.access_key))
        self.assertEqual(40, len(credentials.secret_key))
        self.assertEqual(356, len(credentials.token))
Esempio n. 13
0
    def test_get_credentials_from_login_without_token(self,
                                                      mock_get_connection):
        """Login/password supply the key pair; with no token extra the token is None.

        The mocked connection carries the access key id as its login and the
        secret access key as its password. The credentials returned by the hook
        must expose those two values and a ``None`` token.
        """
        # The fixture values must equal the values asserted below; masked
        # placeholders such as '******' cannot satisfy the assertions.
        mock_connection = Connection(
            login='aws_access_key_id',
            password='aws_secret_access_key',
        )

        mock_get_connection.return_value = mock_connection
        hook = AwsHook()
        credentials_from_hook = hook.get_credentials()

        self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
        self.assertEqual(credentials_from_hook.secret_key,
                         'aws_secret_access_key')
        self.assertIsNone(credentials_from_hook.token)
Esempio n. 14
0
    def get_iam_token(self, conn):
        """
        Generate a temporary RDS auth token for connecting to MySQL.

        The AWS connection id is read from the ``aws_conn_id`` extra (default
        ``aws_default``). When ``conn.port`` is ``None`` port 3306 is used.

        :param conn: connection supplying host, login, port and extras
        :return: tuple ``(token, port)``
        """
        from airflow.providers.amazon.aws.hooks.aws_hook import AwsHook

        hook = AwsHook(conn.extra_dejson.get('aws_conn_id', 'aws_default'))

        port = conn.port
        if port is None:
            port = 3306

        rds_client = hook.get_client_type('rds')
        auth_token = rds_client.generate_db_auth_token(conn.host, port, conn.login)
        return auth_token, port
Esempio n. 15
0
 def get_hook(self):
     """Build a new AwsHook from this object's ``aws_conn_id`` and return it."""
     aws_hook = AwsHook(aws_conn_id=self.aws_conn_id)
     return aws_hook
Esempio n. 16
0
 def expand_role(self):
     """Expand the top-level ``RoleArn`` config entry in place.

     Nothing happens when ``self.config`` has no ``RoleArn`` key. Otherwise
     the value is passed through ``AwsHook.expand_role`` and written back.
     """
     if 'RoleArn' not in self.config:
         return
     aws_hook = AwsHook(self.aws_conn_id)
     self.config['RoleArn'] = aws_hook.expand_role(self.config['RoleArn'])
Esempio n. 17
0
 def get_hook(self):
     """Return a new AwsHook created with this object's ``aws_conn_id``."""
     conn_id = self.aws_conn_id
     return AwsHook(aws_conn_id=conn_id)
Esempio n. 18
0
 def expand_role(self):
     """Expand the ``RoleArn`` entry of the ``TrainingJobDefinition`` config in place.

     Nothing happens when ``self.config`` has no ``TrainingJobDefinition``
     section or that section has no ``RoleArn`` key. Otherwise the value is
     passed through ``AwsHook.expand_role`` and written back.
     """
     if 'TrainingJobDefinition' not in self.config:
         return
     training_config = self.config['TrainingJobDefinition']
     if 'RoleArn' not in training_config:
         return
     aws_hook = AwsHook(self.aws_conn_id)
     training_config['RoleArn'] = aws_hook.expand_role(training_config['RoleArn'])