def execute(self, context):
    """Stage data from S3 into a Redshift table via a COPY statement.

    Clears the target table first (best-effort), then runs ``COPY_SQL``
    with credentials from the configured AWS connection. When
    ``backfill_data`` is set, the S3 path is narrowed to the
    year/month of ``execution_date``.

    Args:
        context: Airflow task context (unused here).
    """
    aws_hook = AwsBaseHook(self.aws_credentials_id, client_type='s3')
    credentials = aws_hook.get_credentials()
    redshift = PostgresHook(postgres_conn_id=self.redshift_conn_id)

    self.log.info('Date:' + self.execution_date)
    date = parser.parse(self.execution_date)
    self.log.info("Backfill_data: {}".format(self.backfill_data))

    s3_bucket_key = "s3://{}/{}".format(self.s3_bucket, self.s3_key)
    if self.backfill_data:
        # Narrow the COPY to the year/month partition of this run.
        s3_path = s3_bucket_key + '/' + str(date.year) + '/' + str(date.month)
    else:
        s3_path = s3_bucket_key
    self.log.info("S3 path: {}".format(s3_path))

    self.log.info("Deleting data from table {}.".format(self.table))
    try:
        redshift.run("DELETE FROM {}".format(self.table))
    except Exception:
        # BUG FIX: the original `except table_does_not_exist` named an
        # undefined identifier, so any DELETE failure raised NameError
        # instead of being handled. The delete is best-effort (a missing
        # table is fine); the COPY below surfaces any real problem.
        # Also replaced the leftover debug message "Andrea does not exist".
        self.log.info("Table {} does not exist".format(self.table))

    copy_sql = self.COPY_SQL.format(
        self.table,
        s3_path,
        credentials.access_key,
        credentials.secret_key,
        self.region,
        self.json_path,
    )
    self.log.info("SQL Statement Executing on Redshift: {}".format(copy_sql))
    redshift.run(copy_sql)
def execute(self, context):
    """Upload ``self.fn`` from ``self.working_dir`` to S3 using boto3.

    Builds an S3 client from the Airflow AWS connection credentials and
    uploads the local file to ``s3://<bucket>/<folder><filename>``.

    Args:
        context: Airflow task context (unused here).

    Raises:
        ClientError: re-raised after logging when the upload fails.
    """
    aws_hook = AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type='s3')
    aws_credentials = aws_hook.get_credentials()
    aws_access_key_id = aws_credentials.access_key
    aws_secret_access_key = aws_credentials.secret_key
    s3client = boto3.client(
        's3',
        region_name=self.region_name,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
    )
    try:
        self.fp = os.path.join(self.working_dir, self.fn)
        s3_key = self.s3_folder + self.fn
        s3_path = 's3://' + self.s3_bucket.rstrip('/') + '/' + s3_key
        self.log.info(f'uploading {self.fp} to {s3_path}')
        # NOTE: upload_file returns None on success, so this logs None.
        response = s3client.upload_file(self.fp, self.s3_bucket, s3_key)
        self.log.info(response)
    except ClientError as e:
        self.log.error(e)
        # BUG FIX: the original `raise ClientError(e)` is invalid --
        # ClientError requires (error_response, operation_name) and would
        # itself raise TypeError, masking the real failure. Re-raise the
        # caught exception instead. (Removed the dead trailing `pass`.)
        raise
def _inject_aws_credentials(self) -> None:
    """Attach AWS access credentials to the transfer spec's S3 data source.

    No-op unless ``self.body`` contains a transfer spec with an S3 data
    source; otherwise the access key pair from the configured AWS
    connection is injected under the ``AWS_ACCESS_KEY`` entry.
    """
    if TRANSFER_SPEC not in self.body:
        return
    transfer_spec = self.body[TRANSFER_SPEC]
    if AWS_S3_DATA_SOURCE not in transfer_spec:
        return
    credentials = AwsBaseHook(self.aws_conn_id, resource_type="s3").get_credentials()
    transfer_spec[AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY] = {
        ACCESS_KEY_ID: credentials.access_key,  # type: ignore[attr-defined]
        SECRET_ACCESS_KEY: credentials.secret_key,  # type: ignore[attr-defined]
    }
def _inject_aws_credentials(self):
    """If the request body's transfer spec has an S3 data source, inject
    the AWS access key pair from the configured connection into it."""
    has_s3_source = (
        TRANSFER_SPEC in self.body
        and AWS_S3_DATA_SOURCE in self.body[TRANSFER_SPEC]
    )
    if not has_s3_source:
        return
    credentials = AwsBaseHook(self.aws_conn_id).get_credentials()
    self.body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY] = {
        ACCESS_KEY_ID: credentials.access_key,
        SECRET_ACCESS_KEY: credentials.secret_key,
    }
def test_get_credentials_from_extra_without_token(self, mock_get_connection):
    """Key pair stored in the connection's extra JSON is surfaced by
    get_credentials(); with no session token in extra, token is None."""
    conn = Connection(
        extra='{"aws_access_key_id": "aws_access_key_id",'
              '"aws_secret_access_key": "aws_secret_access_key"}'
    )
    mock_get_connection.return_value = conn
    hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
    creds = hook.get_credentials()
    self.assertEqual(creds.access_key, 'aws_access_key_id')
    self.assertEqual(creds.secret_key, 'aws_secret_access_key')
    self.assertIsNone(creds.token)
def test_get_credentials_from_login_without_token(self, mock_get_connection):
    """get_credentials() maps the connection's login/password to the AWS
    access key pair; with no extra JSON there is no session token.

    BUG FIX: login/password had been replaced by the redaction
    placeholder ``'******'``, which makes the equality assertions below
    fail unconditionally. Restored the values the assertions expect.
    """
    mock_connection = Connection(
        login='aws_access_key_id',
        password='aws_secret_access_key',
    )
    mock_get_connection.return_value = mock_connection
    hook = AwsBaseHook(aws_conn_id='aws_default', client_type='spam')
    credentials_from_hook = hook.get_credentials()
    self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
    self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
    self.assertIsNone(credentials_from_hook.token)
def test_get_credentials_from_login_with_token(self, mock_get_connection):
    """Login/password supply the key pair while extra JSON supplies the
    session token; get_credentials() must expose all three.

    BUG FIX: login/password had been replaced by the redaction
    placeholder ``'******'``, which makes the equality assertions below
    fail unconditionally. Restored the values the assertions expect.
    """
    mock_connection = Connection(
        login='aws_access_key_id',
        password='aws_secret_access_key',
        extra='{"aws_session_token": "test_token"}',
    )
    mock_get_connection.return_value = mock_connection
    hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
    credentials_from_hook = hook.get_credentials()
    self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
    self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
    self.assertEqual(credentials_from_hook.token, 'test_token')
def test_get_credentials_from_extra_with_token(self, mock_get_connection):
    """Extra JSON carrying key, secret and session token is fully exposed
    by get_credentials()."""
    extra_json = (
        '{"aws_access_key_id": "aws_access_key_id",'
        '"aws_secret_access_key": "aws_secret_access_key",'
        ' "aws_session_token": "session_token"}'
    )
    mock_get_connection.return_value = Connection(extra=extra_json)
    hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
    creds = hook.get_credentials()
    assert creds.access_key == 'aws_access_key_id'
    assert creds.secret_key == 'aws_secret_access_key'
    assert creds.token == 'session_token'
def test_get_credentials_from_role_arn(self, mock_get_connection):
    """Assuming a role via ``role_arn`` in extra yields temporary STS
    credentials of the expected shape."""
    mock_get_connection.return_value = Connection(
        extra='{"role_arn":"arn:aws:iam::123456:role/role_arn"}'
    )
    hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
    creds = hook.get_credentials()
    self.assertIn("ASIA", creds.access_key)
    # We assert the length instead of actual values as the values are random:
    # Details: https://github.com/spulec/moto/commit/ab0d23a0ba2506e6338ae20b3fde70da049f7b03
    self.assertEqual(20, len(creds.access_key))
    self.assertEqual(40, len(creds.secret_key))
    self.assertEqual(356, len(creds.token))
def test_get_credentials_from_extra_with_token(self, mock_get_connection):
    """A default-constructed hook reads key, secret and session token from
    the connection's extra JSON."""
    mock_get_connection.return_value = Connection(
        extra='{"aws_access_key_id": "aws_access_key_id",'
              '"aws_secret_access_key": "aws_secret_access_key",'
              ' "aws_session_token": "session_token"}'
    )
    creds = AwsBaseHook().get_credentials()
    self.assertEqual(creds.access_key, 'aws_access_key_id')
    self.assertEqual(creds.secret_key, 'aws_secret_access_key')
    self.assertEqual(creds.token, 'session_token')
def execute(self, context):
    """Stage S3 data into Redshift: clear the destination table, then run
    the class-level COPY template with credentials from the AWS connection.
    The S3 key is rendered with the task context before building the path."""
    self.log.info('StageToRedshiftOperator start')
    credentials = AwsBaseHook(
        aws_conn_id='aws_credentials', resource_type='*'
    ).get_credentials()
    redshift = PostgresHook(postgres_conn_id=self.redshift_conn_id)

    self.log.info("Clearing data from destination Redshift table")
    redshift.run("DELETE FROM {}".format(self.table))

    self.log.info("Copying data from S3 to Redshift")
    rendered_key = self.s3_key.format(**context)
    s3_path = "s3://{}/{}".format(self.s3_bucket, rendered_key)
    copy_stmt = StageToRedshiftOperator.copy_sql_stmt.format(
        self.table,
        s3_path,
        credentials.access_key,
        credentials.secret_key,
        self.ignore_headers,
        self.delimiter,
    )
    redshift.run(copy_stmt)
def execute(self, context):
    """Run a Redshift COPY that stages S3 data into ``self.table``.

    BUG FIX: the original did ``self.s3_bucket = self.s3_bucket +
    self.s3_key``, permanently mutating the operator instance — a
    retried or rescheduled run would append the key again and COPY from
    a bogus path. The concatenation now uses a local variable, producing
    the same SQL on first execution without the state mutation.
    """
    aws_hook = AwsBaseHook(self.aws_credentials, client_type='s3')
    self.log.info(aws_hook)
    credentials = aws_hook.get_credentials()
    redshift_hook = PostgresHook(self.redshift_conn)
    # Same string the SQL used before; no separator is inserted on purpose,
    # matching the original behavior -- verify bucket/key shapes upstream.
    s3_path = self.s3_bucket + self.s3_key
    sql = self.COPY_SQL.format(
        self.table,
        s3_path,
        credentials.access_key,
        credentials.secret_key,
        self.json_path,
    )
    self.log.info(sql)
    redshift_hook.run(sql)
    self.log.info("Stage Complete")
    self.log.info(self.run_date)
def execute(self, context):
    """Delete-and-load staging step: empty the destination Redshift table,
    then COPY the context-rendered S3 key into it using the connection's
    AWS credentials and the class-level COPY template."""
    credentials = AwsBaseHook(
        self.aws_credentials_id, client_type="s3"
    ).get_credentials()
    redshift = PostgresHook(postgres_conn_id=self.redshift_conn_id)

    self.log.info("Clearing data from destination Redshift table")
    redshift.run("DELETE FROM {}".format(self.table))

    self.log.info("Copying data from S3 to Redshift")
    rendered_key = self.s3_key.format(**context)
    s3_path = "s3://{}/{}".format(self.s3_bucket, rendered_key)
    copy_stmt = StageToRedshiftOperator.copy_sql.format(
        self.table,
        s3_path,
        credentials.access_key,
        credentials.secret_key,
        self.json_format,
    )
    redshift.run(copy_stmt)
def execute(self, context):
    """Truncate the target Redshift table, then COPY data from S3 into it.

    COPY options are assembled from ``data_format``: CSV adds
    IGNOREHEADER/DELIMITER clauses, JSON adds FORMAT AS JSON with the
    configured jsonpath; ``copy_opts`` is always appended.

    Args:
        context: Airflow task context, used to render the S3 key template.
    """
    aws_hook = AwsBaseHook(self.aws_credentials_id)
    aws_credentials = aws_hook.get_credentials()
    # Keepalives guard against idle-connection drops during a long COPY.
    redshift_conn = PostgresHook(
        postgres_conn_id=self.redshift_conn_id,
        connect_args={
            'keepalives': 1,
            'keepalives_idle': 60,
            'keepalives_interval': 60
        })

    self.log.debug(f"Truncate Table: {self.table}")
    redshift_conn.run(f"TRUNCATE TABLE {self.table}")

    # Build the format-specific COPY options.
    # FIX: renamed from `format`, which shadowed the builtin of that name.
    copy_format = ''
    if self.data_format == 'csv' and self.ignore_header > 0:
        copy_format += f"IGNOREHEADER {self.ignore_header}\n"
    if self.data_format == 'csv':
        copy_format += f"DELIMITER '{self.delimiter}'\n"
    elif self.data_format == 'json':
        copy_format += f"FORMAT AS JSON '{self.jsonpath}'\n"
    copy_format += f"{self.copy_opts}"
    self.log.debug(f"format : {copy_format}")

    formatted_key = self.s3_src_bucket_key.format(**context)
    self.log.info(f"Rendered S3 source file key : {formatted_key}")
    s3_url = f"s3://{self.s3_src_bucket_name}/{formatted_key}"
    self.log.debug(f"S3 URL : {s3_url}")

    formatted_sql = self._sql.format(**dict(
        table=self.table,
        source=s3_url,
        access_key=aws_credentials.access_key,
        secret_access_key=aws_credentials.secret_key,
        format=copy_format
    ))
    self.log.debug(f"Base SQL: {self._sql}")
    self.log.info(f"Copying data from S3 to Redshift table {self.table}...")
    redshift_conn.run(formatted_sql)
    self.log.info(f"Finished copying data from S3 to Redshift table {self.table}")
def test_get_credentials_from_gcp_credentials(self):
    """Exercise the assume_role_with_web_identity (Google federation) path:
    the hook should import Google's id-token helper, build a botocore
    AssumeRoleWithWebIdentityCredentialFetcher, and hand boto3 the
    resulting deferred-refreshable credentials.
    """
    mock_connection = Connection(extra=json.dumps({
        "role_arn": "arn:aws:iam::123456:role/role_arn",
        "assume_role_method": "assume_role_with_web_identity",
        "assume_role_with_web_identity_federation": 'google',
        "assume_role_with_web_identity_federation_audience": 'aws-federation.airflow.apache.org',
    }))
    # Store original __import__
    orig_import = __import__
    mock_id_token_credentials = mock.Mock()

    def import_mock(name, *args):
        # Intercept only the Google id-token helper; every other import
        # resolves normally so the hook's own imports keep working.
        if name == 'airflow.providers.google.common.utils.id_token_credentials':
            return mock_id_token_credentials
        return orig_import(name, *args)

    # Patch the import machinery, expose the connection via the environment,
    # and replace boto3/botocore inside the hook module so every session and
    # credential-fetcher interaction is recorded on mocks.
    with mock.patch(
        'builtins.__import__', side_effect=import_mock
    ), mock.patch.dict(
        'os.environ', AIRFLOW_CONN_AWS_DEFAULT=mock_connection.get_uri()
    ), mock.patch(
        'airflow.providers.amazon.aws.hooks.base_aws.boto3'
    ) as mock_boto3, mock.patch(
        'airflow.providers.amazon.aws.hooks.base_aws.botocore'
    ) as mock_botocore, mock.patch(
        'airflow.providers.amazon.aws.hooks.base_aws.botocore.session'
    ) as mock_session:
        hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
        credentials_from_hook = hook.get_credentials()
        # The hook must return the frozen credentials of the boto3 session.
        mock_get_credentials = mock_boto3.session.Session.return_value.get_credentials
        assert (mock_get_credentials.return_value.get_frozen_credentials.
                return_value == credentials_from_hook)
        # Expected boto3 flow: an initial keyless session, then a second
        # session built on the botocore session carrying the web-identity
        # credentials, then get_credentials()/get_frozen_credentials().
        mock_boto3.assert_has_calls([
            mock.call.session.Session(
                aws_access_key_id=None,
                aws_secret_access_key=None,
                aws_session_token=None,
                region_name=None,
            ),
            mock.call.session.Session()._session.__bool__(),
            mock.call.session.Session(
                botocore_session=mock_session.Session.return_value,
                region_name=mock_boto3.session.Session.return_value.
                region_name,
            ),
            mock.call.session.Session().get_credentials(),
            mock.call.session.Session().get_credentials().
            get_frozen_credentials(),
        ])
        # botocore must have built the web-identity fetcher for the role ARN
        # and wrapped it in DeferredRefreshableCredentials.
        mock_fetcher = mock_botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher
        mock_botocore.assert_has_calls([
            mock.call.credentials.AssumeRoleWithWebIdentityCredentialFetcher(
                client_creator=mock_boto3.session.Session.return_value.
                _session.create_client,
                extra_args={},
                role_arn='arn:aws:iam::123456:role/role_arn',
                web_identity_token_loader=mock.ANY,
            ),
            mock.call.credentials.DeferredRefreshableCredentials(
                method='assume-role-with-web-identity',
                refresh_using=mock_fetcher.return_value.fetch_credentials,
                time_fetcher=mock.ANY,
            ),
        ])
        mock_session.assert_has_calls([mock.call.Session()])
        # The Google helper must have been asked for credentials scoped to
        # the configured federation audience.
        mock_id_token_credentials.assert_has_calls([
            mock.call.get_default_id_token_credentials(
                target_audience='aws-federation.airflow.apache.org')
        ])