Пример #1
0
    def execute(self, context):
        """Copy data from S3 into a Redshift staging table.

        Clears the target table first (best-effort), then issues a COPY
        from the S3 path built from ``s3_bucket``/``s3_key`` — partitioned
        by execution year/month when ``backfill_data`` is set.
        """
        aws_hook = AwsBaseHook(self.aws_credentials_id, client_type='s3')
        credentials = aws_hook.get_credentials()
        redshift = PostgresHook(postgres_conn_id=self.redshift_conn_id)
        self.log.info('Date:' + self.execution_date)
        date = parser.parse(self.execution_date)

        self.log.info("Backfill_data: {}".format(self.backfill_data))
        s3_bucket_key = "s3://{}/{}".format(self.s3_bucket, self.s3_key)
        if self.backfill_data:
            # Backfills read only the partition for the execution date.
            s3_path = "{}/{}/{}".format(s3_bucket_key, date.year, date.month)
        else:
            s3_path = s3_bucket_key
        self.log.info("S3 path: {}".format(s3_path))

        self.log.info("Deleting data from table {}.".format(self.table))
        try:
            redshift.run("DELETE FROM {}".format(self.table))
        # BUG FIX: the original caught the undefined name
        # `table_does_not_exist`, which itself raises NameError the moment
        # the DELETE fails, and logged a bogus message. Keep the DELETE
        # best-effort by catching the database error and logging it.
        except Exception as ex:
            self.log.info("Table {} does not exist: {}".format(self.table, ex))

        copy_sql = self.COPY_SQL.format(self.table, s3_path,
                                        credentials.access_key,
                                        credentials.secret_key, self.region,
                                        self.json_path)
        self.log.info(
            "SQL Statement Executing on Redshift: {}".format(copy_sql))
        redshift.run(copy_sql)
Пример #2
0
    def execute(self, context):
        """
        Upload the file using boto3 S3 client
        Args:
            context: Airflow task context (unused here)

        Returns:
            None

        Raises:
            ClientError: when the S3 upload fails.
        """
        aws_hook = AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type='s3')
        aws_credentials = aws_hook.get_credentials()
        aws_access_key_id = aws_credentials.access_key
        aws_secret_access_key = aws_credentials.secret_key
        s3client = boto3.client('s3',
                                region_name=self.region_name,
                                aws_access_key_id=aws_access_key_id,
                                aws_secret_access_key=aws_secret_access_key)
        try:
            self.fp = os.path.join(self.working_dir, self.fn)
            s3_key = self.s3_folder + self.fn
            s3_path = 's3://' + self.s3_bucket.rstrip('/') +'/' + s3_key
            self.log.info(f'uploading {self.fp}  to {s3_path}')
            # upload_file returns None on success and raises on failure,
            # so there is nothing meaningful to log from its return value.
            s3client.upload_file(self.fp, self.s3_bucket, s3_key)
        except ClientError as e:
            self.log.error(e)
            # BUG FIX: `raise ClientError(e)` invoked the constructor with
            # the wrong arguments (it requires error_response and
            # operation_name), raising a TypeError that masked the real
            # failure. Re-raise the original exception instead.
            raise
 def _inject_aws_credentials(self) -> None:
     """Inject AWS access/secret keys into the transfer spec's S3 source."""
     # Only act when the body actually describes an S3-backed transfer.
     if TRANSFER_SPEC not in self.body:
         return
     transfer_spec = self.body[TRANSFER_SPEC]
     if AWS_S3_DATA_SOURCE not in transfer_spec:
         return
     creds = AwsBaseHook(self.aws_conn_id, resource_type="s3").get_credentials()
     transfer_spec[AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY] = {
         ACCESS_KEY_ID: creds.access_key,  # type: ignore[attr-defined]
         SECRET_ACCESS_KEY: creds.secret_key,  # type: ignore[attr-defined]
     }
 def _inject_aws_credentials(self):
     """Copy credentials from the AWS hook into the S3 data-source spec."""
     has_s3_source = (TRANSFER_SPEC in self.body
                      and AWS_S3_DATA_SOURCE in self.body[TRANSFER_SPEC])
     if not has_s3_source:
         return
     hook = AwsBaseHook(self.aws_conn_id)
     creds = hook.get_credentials()
     key_block = {
         ACCESS_KEY_ID: creds.access_key,
         SECRET_ACCESS_KEY: creds.secret_key,
     }
     self.body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY] = key_block
Пример #5
0
 def test_get_credentials_from_extra_without_token(self, mock_get_connection):
     """Credentials stored only in `extra` yield no session token."""
     extra_json = ('{"aws_access_key_id": "aws_access_key_id",'
                   '"aws_secret_access_key": "aws_secret_access_key"}')
     mock_get_connection.return_value = Connection(extra=extra_json)
     hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
     creds = hook.get_credentials()
     self.assertEqual(creds.access_key, 'aws_access_key_id')
     self.assertEqual(creds.secret_key, 'aws_secret_access_key')
     self.assertIsNone(creds.token)
Пример #6
0
    def test_get_credentials_from_login_without_token(self, mock_get_connection):
        """Login/password on the connection become the access/secret keys.

        FIX: the fixture's login/password had been redacted to '******',
        which contradicts the assertions below and makes the test fail;
        restored the values the assertions expect.
        """
        mock_connection = Connection(login='aws_access_key_id',
                                     password='aws_secret_access_key',
                                     )

        mock_get_connection.return_value = mock_connection
        hook = AwsBaseHook(aws_conn_id='aws_default', client_type='spam')
        credentials_from_hook = hook.get_credentials()
        self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
        self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
        # No token was configured anywhere, so none should be returned.
        self.assertIsNone(credentials_from_hook.token)
Пример #7
0
 def test_get_credentials_from_login_with_token(self, mock_get_connection):
     """Login/password plus an `extra` session token populate all fields.

     FIX: the fixture's login/password had been redacted to '******',
     which contradicts the assertions below and makes the test fail;
     restored the values the assertions expect.
     """
     mock_connection = Connection(login='aws_access_key_id',
                                  password='aws_secret_access_key',
                                  extra='{"aws_session_token": "test_token"}'
                                  )
     mock_get_connection.return_value = mock_connection
     hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
     credentials_from_hook = hook.get_credentials()
     self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
     self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
     self.assertEqual(credentials_from_hook.token, 'test_token')
Пример #8
0
 def test_get_credentials_from_extra_with_token(self, mock_get_connection):
     """All three credential fields come through from the `extra` JSON."""
     extra_json = ('{"aws_access_key_id": "aws_access_key_id",'
                   '"aws_secret_access_key": "aws_secret_access_key",'
                   ' "aws_session_token": "session_token"}')
     mock_get_connection.return_value = Connection(extra=extra_json)
     hook = AwsBaseHook(aws_conn_id='aws_default',
                        client_type='airflow_test')
     creds = hook.get_credentials()
     assert creds.access_key == 'aws_access_key_id'
     assert creds.secret_key == 'aws_secret_access_key'
     assert creds.token == 'session_token'
Пример #9
0
    def test_get_credentials_from_role_arn(self, mock_get_connection):
        """Assuming a role via `role_arn` yields temporary STS credentials."""
        mock_get_connection.return_value = Connection(
            extra='{"role_arn":"arn:aws:iam::123456:role/role_arn"}')
        hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
        creds = hook.get_credentials()
        # Temporary STS access keys carry the "ASIA" prefix.
        self.assertIn("ASIA", creds.access_key)

        # We assert the length instead of actual values as the values are random:
        # Details: https://github.com/spulec/moto/commit/ab0d23a0ba2506e6338ae20b3fde70da049f7b03
        self.assertEqual(20, len(creds.access_key))
        self.assertEqual(40, len(creds.secret_key))
        self.assertEqual(356, len(creds.token))
 def test_get_credentials_from_extra_with_token(self, mock_get_connection):
     """A default-constructed hook still reads credentials from `extra`."""
     extra_json = ('{"aws_access_key_id": "aws_access_key_id",'
                   '"aws_secret_access_key": "aws_secret_access_key",'
                   ' "aws_session_token": "session_token"}')
     mock_get_connection.return_value = Connection(extra=extra_json)
     creds = AwsBaseHook().get_credentials()
     self.assertEqual(creds.access_key, 'aws_access_key_id')
     self.assertEqual(creds.secret_key, 'aws_secret_access_key')
     self.assertEqual(creds.token, 'session_token')
Пример #11
0
    def execute(self, context):
        """Stage S3 data into Redshift: clear the table, then run COPY."""
        self.log.info('StageToRedshiftOperator start')
        hook = AwsBaseHook(aws_conn_id='aws_credentials', resource_type='*')
        creds = hook.get_credentials()
        pg_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)

        self.log.info("Clearing data from destination Redshift table")
        pg_hook.run("DELETE FROM {}".format(self.table))

        self.log.info("Copying data from S3 to Redshift")
        # Render the templated key against the task context before building
        # the fully-qualified S3 path.
        key = self.s3_key.format(**context)
        source = "s3://{}/{}".format(self.s3_bucket, key)
        copy_stmt = StageToRedshiftOperator.copy_sql_stmt.format(
            self.table, source, creds.access_key,
            creds.secret_key, self.ignore_headers, self.delimiter)
        pg_hook.run(copy_stmt)
    def execute(self, context):
        """Run a COPY statement staging S3 data into a Redshift table."""
        aws_hook = AwsBaseHook(self.aws_credentials, client_type='s3')
        self.log.info(aws_hook)
        credentials = aws_hook.get_credentials()
        redshift_hook = PostgresHook(self.redshift_conn)

        # BUG FIX: previously this mutated self.s3_bucket in place
        # (self.s3_bucket = self.s3_bucket + self.s3_key), so a retried or
        # repeated execute() would keep appending the key. Use a local.
        s3_path = self.s3_bucket + self.s3_key

        sql = self.COPY_SQL.format(
            self.table,
            s3_path,
            credentials.access_key,
            credentials.secret_key,
            self.json_path)

        self.log.info(sql)
        redshift_hook.run(sql)
        self.log.info("Stage Complete")
        self.log.info(self.run_date)
    def execute(self, context):
        """Delete existing rows, then COPY rendered S3 data into Redshift."""
        hook = AwsBaseHook(self.aws_credentials_id, client_type="s3")
        creds = hook.get_credentials()
        pg_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)

        self.log.info("Clearing data from destination Redshift table")
        pg_hook.run("DELETE FROM {}".format(self.table))

        self.log.info("Copying data from S3 to Redshift")
        # The key template may reference task-context fields (e.g. ds).
        key = self.s3_key.format(**context)
        source = "s3://{}/{}".format(self.s3_bucket, key)
        copy_stmt = StageToRedshiftOperator.copy_sql.format(
            self.table, source, creds.access_key,
            creds.secret_key, self.json_format)
        pg_hook.run(copy_stmt)
    def execute(self, context):
        """Truncate the target table and COPY S3 data into Redshift.

        Builds the COPY format clause from ``data_format`` (csv/json),
        renders the templated S3 key from the Airflow context, and runs
        the final SQL over a keepalive-enabled Postgres connection.
        """
        aws_hook = AwsBaseHook(self.aws_credentials_id)
        aws_credentials = aws_hook.get_credentials()
        redshift_conn = PostgresHook(
            postgres_conn_id=self.redshift_conn_id,
            # Keepalives stop long-running COPY statements from being
            # dropped by idle-connection timeouts.
            connect_args={
                'keepalives': 1,
                'keepalives_idle': 60,
                'keepalives_interval': 60
            })

        self.log.debug(f"Truncate Table: {self.table}")
        redshift_conn.run(f"TRUNCATE TABLE {self.table}")

        # FIX: renamed local `format` -> `format_clause`; the original name
        # shadowed the `format` builtin.
        format_clause = ''
        if self.data_format == 'csv' and self.ignore_header > 0:
            format_clause += f"IGNOREHEADER {self.ignore_header}\n"

        if self.data_format == 'csv':
            format_clause += f"DELIMITER '{self.delimiter}'\n"
        elif self.data_format == 'json':
            format_clause += f"FORMAT AS JSON '{self.jsonpath}'\n"
        format_clause += f"{self.copy_opts}"
        self.log.debug(f"format : {format_clause}")

        formatted_key = self.s3_src_bucket_key.format(**context)
        self.log.info(f"Rendered S3 source file key : {formatted_key}")
        s3_url = f"s3://{self.s3_src_bucket_name}/{formatted_key}"
        self.log.debug(f"S3 URL : {s3_url}")
        formatted_sql = self._sql.format(**dict(
            table=self.table,
            source=s3_url,
            access_key=aws_credentials.access_key,
            secret_access_key=aws_credentials.secret_key,
            # keyword must remain `format`: it names the template placeholder
            format=format_clause
        ))
        self.log.debug(f"Base SQL: {self._sql}")

        self.log.info(f"Copying data from S3 to Redshift table {self.table}...")
        redshift_conn.run(formatted_sql)
        self.log.info(f"Finished copying data from S3 to Redshift table {self.table}")
Пример #15
0
    def test_get_credentials_from_gcp_credentials(self):
        """Credentials via assume_role_with_web_identity with Google federation.

        Patches the module's boto3/botocore and the Google id-token import so
        the hook's web-identity assume-role path runs fully mocked, then checks
        both the returned credentials and the exact sequence of calls made.
        """
        mock_connection = Connection(extra=json.dumps({
            "role_arn":
            "arn:aws:iam::123456:role/role_arn",
            "assume_role_method":
            "assume_role_with_web_identity",
            "assume_role_with_web_identity_federation":
            'google',
            "assume_role_with_web_identity_federation_audience":
            'aws-federation.airflow.apache.org',
        }))

        # Store original __import__
        orig_import = __import__
        mock_id_token_credentials = mock.Mock()

        # Intercept only the Google id-token module import; everything else
        # goes through the real __import__.
        def import_mock(name, *args):
            if name == 'airflow.providers.google.common.utils.id_token_credentials':
                return mock_id_token_credentials
            return orig_import(name, *args)

        # The connection is injected via the AIRFLOW_CONN_AWS_DEFAULT env var
        # rather than by patching get_connection.
        with mock.patch(
                'builtins.__import__', side_effect=import_mock
        ), mock.patch.dict(
                'os.environ',
                AIRFLOW_CONN_AWS_DEFAULT=mock_connection.get_uri()
        ), mock.patch(
                'airflow.providers.amazon.aws.hooks.base_aws.boto3'
        ) as mock_boto3, mock.patch(
                'airflow.providers.amazon.aws.hooks.base_aws.botocore'
        ) as mock_botocore, mock.patch(
                'airflow.providers.amazon.aws.hooks.base_aws.botocore.session'
        ) as mock_session:
            hook = AwsBaseHook(aws_conn_id='aws_default',
                               client_type='airflow_test')

            credentials_from_hook = hook.get_credentials()
            # The hook must return exactly the frozen credentials produced by
            # the (mocked) boto3 session.
            mock_get_credentials = mock_boto3.session.Session.return_value.get_credentials
            assert (mock_get_credentials.return_value.get_frozen_credentials.
                    return_value == credentials_from_hook)

        # Two sessions are expected: a plain one first, then one rebuilt on
        # top of the botocore session that carries the deferred credentials.
        mock_boto3.assert_has_calls([
            mock.call.session.Session(
                aws_access_key_id=None,
                aws_secret_access_key=None,
                aws_session_token=None,
                region_name=None,
            ),
            mock.call.session.Session()._session.__bool__(),
            mock.call.session.Session(
                botocore_session=mock_session.Session.return_value,
                region_name=mock_boto3.session.Session.return_value.
                region_name,
            ),
            mock.call.session.Session().get_credentials(),
            mock.call.session.Session().get_credentials().
            get_frozen_credentials(),
        ])
        # botocore must build a web-identity fetcher for the configured role
        # and wrap it in lazily-refreshing credentials.
        mock_fetcher = mock_botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher
        mock_botocore.assert_has_calls([
            mock.call.credentials.AssumeRoleWithWebIdentityCredentialFetcher(
                client_creator=mock_boto3.session.Session.return_value.
                _session.create_client,
                extra_args={},
                role_arn='arn:aws:iam::123456:role/role_arn',
                web_identity_token_loader=mock.ANY,
            ),
            mock.call.credentials.DeferredRefreshableCredentials(
                method='assume-role-with-web-identity',
                refresh_using=mock_fetcher.return_value.fetch_credentials,
                time_fetcher=mock.ANY,
            ),
        ])

        mock_session.assert_has_calls([mock.call.Session()])
        # The Google token must be requested for the configured audience.
        mock_id_token_credentials.assert_has_calls([
            mock.call.get_default_id_token_credentials(
                target_audience='aws-federation.airflow.apache.org')
        ])