def execute(self, context):
    aws_hook = AwsBaseHook(self.aws_credentials_id, client_type='s3')
    credentials = aws_hook.get_credentials()
    redshift = PostgresHook(postgres_conn_id=self.redshift_conn_id)

    self.log.info("Date: {}".format(self.execution_date))
    date = parser.parse(self.execution_date)
    self.log.info("Backfill_data: {}".format(self.backfill_data))

    s3_bucket_key = "s3://{}/{}".format(self.s3_bucket, self.s3_key)
    if self.backfill_data:
        s3_path = s3_bucket_key + '/' + str(date.year) + '/' + str(date.month)
    else:
        s3_path = s3_bucket_key
    self.log.info("S3 path: {}".format(s3_path))

    self.log.info("Deleting data from table {}.".format(self.table))
    try:
        redshift.run("DELETE FROM {}".format(self.table))
    except Exception as ex:
        # The target table may not exist yet; log and continue with the COPY.
        self.log.info("Table {} does not exist: {}".format(self.table, ex))

    copy_sql = self.COPY_SQL.format(self.table, s3_path,
                                    credentials.access_key,
                                    credentials.secret_key,
                                    self.region, self.json_path)
    self.log.info("SQL statement executing on Redshift: {}".format(copy_sql))
    redshift.run(copy_sql)
def test_get_resource_type_returns_a_boto3_resource_of_the_requested_type(self):
    hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='dynamodb')
    resource_from_hook = hook.get_resource_type('dynamodb')

    # this table needs to be created in production
    table = resource_from_hook.create_table(  # pylint: disable=no-member
        TableName='test_airflow',
        KeySchema=[
            {'AttributeName': 'id', 'KeyType': 'HASH'},
        ],
        AttributeDefinitions=[
            {'AttributeName': 'id', 'AttributeType': 'S'},
        ],
        ProvisionedThroughput={'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10},
    )

    table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow')

    self.assertEqual(table.item_count, 0)
def expand_role(self):
    if 'Model' not in self.config:
        return
    config = self.config['Model']
    if 'ExecutionRoleArn' in config:
        hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
        config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
def test_get_session_returns_a_boto3_session(self):
    hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='dynamodb')
    session_from_hook = hook.get_session()
    resource_from_session = session_from_hook.resource('dynamodb')
    table = resource_from_session.create_table(  # pylint: disable=no-member
        TableName='test_airflow',
        KeySchema=[
            {'AttributeName': 'id', 'KeyType': 'HASH'},
        ],
        AttributeDefinitions=[
            {'AttributeName': 'id', 'AttributeType': 'S'},
        ],
        ProvisionedThroughput={'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10},
    )

    table.meta.client.get_waiter('table_exists').wait(TableName='test_airflow')

    assert table.item_count == 0
def execute(self, context):
    """Upload the file to S3 using a boto3 client."""
    aws_hook = AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type='s3')
    aws_credentials = aws_hook.get_credentials()
    aws_access_key_id = aws_credentials.access_key
    aws_secret_access_key = aws_credentials.secret_key

    s3client = boto3.client('s3',
                            region_name=self.region_name,
                            aws_access_key_id=aws_access_key_id,
                            aws_secret_access_key=aws_secret_access_key)

    try:
        self.fp = os.path.join(self.working_dir, self.fn)
        s3_key = self.s3_folder + self.fn
        s3_path = 's3://' + self.s3_bucket.rstrip('/') + '/' + s3_key
        self.log.info(f'Uploading {self.fp} to {s3_path}')
        s3client.upload_file(self.fp, self.s3_bucket, s3_key)
    except ClientError as e:
        self.log.error(e)
        raise
def expand_role(self) -> None:
    """Expands an IAM role name into an ARN."""
    if 'TrainingJobDefinition' in self.config:
        config = self.config['TrainingJobDefinition']
        if 'RoleArn' in config:
            hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
            config['RoleArn'] = hook.expand_role(config['RoleArn'])
def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
    self._create_clusters()
    hook = AwsBaseHook(aws_conn_id='aws_default', client_type='redshift')
    client_from_hook = hook.get_conn()

    clusters = client_from_hook.describe_clusters()['Clusters']
    self.assertEqual(len(clusters), 2)
def test_expand_role(self):
    conn = boto3.client('iam', region_name='us-east-1')
    conn.create_role(RoleName='test-role', AssumeRolePolicyDocument='some policy')
    hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
    arn = hook.expand_role('test-role')
    expect_arn = conn.get_role(RoleName='test-role').get('Role').get('Arn')
    self.assertEqual(arn, expect_arn)
def get_iam_token(self, conn):
    """
    Uses AWSHook to retrieve a temporary password to connect to Postgres
    or Redshift. Port is required. If none is provided, the default is used
    for each service.
    """
    from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

    redshift = conn.extra_dejson.get('redshift', False)
    aws_conn_id = conn.extra_dejson.get('aws_conn_id', 'aws_default')
    aws_hook = AwsBaseHook(aws_conn_id)
    login = conn.login
    if conn.port is None:
        port = 5439 if redshift else 5432
    else:
        port = conn.port
    if redshift:
        # Pull the cluster-identifier from the beginning of the Redshift URL,
        # e.g. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns my-cluster
        cluster_identifier = conn.extra_dejson.get('cluster-identifier', conn.host.split('.')[0])
        client = aws_hook.get_client_type('redshift')
        cluster_creds = client.get_cluster_credentials(
            DbUser=conn.login,
            DbName=self.schema or conn.schema,
            ClusterIdentifier=cluster_identifier,
            AutoCreate=False)
        token = cluster_creds['DbPassword']
        login = cluster_creds['DbUser']
    else:
        client = aws_hook.get_client_type('rds')
        token = client.generate_db_auth_token(conn.host, port, conn.login)
    return login, token, port
def expand_role(self):
    if 'Model' not in self.config:
        return
    hook = AwsBaseHook(self.aws_conn_id)
    config = self.config['Model']
    if 'ExecutionRoleArn' in config:
        config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
def expand_role(self) -> None:
    """Expands an IAM role name into an ARN."""
    if 'Model' not in self.config:
        return
    config = self.config['Model']
    if 'ExecutionRoleArn' in config:
        hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
        config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
    client = boto3.client('emr', region_name='us-east-1')
    if client.list_clusters()['Clusters']:
        raise ValueError('AWS not properly mocked')

    hook = AwsBaseHook(aws_conn_id='aws_default', client_type='emr')
    client_from_hook = hook.get_client_type('emr')

    self.assertEqual(client_from_hook.list_clusters()['Clusters'], [])
def _inject_aws_credentials(self) -> None:
    if TRANSFER_SPEC in self.body and AWS_S3_DATA_SOURCE in self.body[TRANSFER_SPEC]:
        aws_hook = AwsBaseHook(self.aws_conn_id, resource_type="s3")
        aws_credentials = aws_hook.get_credentials()
        aws_access_key_id = aws_credentials.access_key  # type: ignore[attr-defined]
        aws_secret_access_key = aws_credentials.secret_key  # type: ignore[attr-defined]
        self.body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY] = {
            ACCESS_KEY_ID: aws_access_key_id,
            SECRET_ACCESS_KEY: aws_secret_access_key,
        }
def _inject_aws_credentials(self):
    if TRANSFER_SPEC in self.body and AWS_S3_DATA_SOURCE in self.body[TRANSFER_SPEC]:
        aws_hook = AwsBaseHook(self.aws_conn_id)
        aws_credentials = aws_hook.get_credentials()
        aws_access_key_id = aws_credentials.access_key
        aws_secret_access_key = aws_credentials.secret_key
        self.body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY] = {
            ACCESS_KEY_ID: aws_access_key_id,
            SECRET_ACCESS_KEY: aws_secret_access_key,
        }
def test_get_credentials_from_login_without_token(self, mock_get_connection):
    mock_connection = Connection(
        login='aws_access_key_id',
        password='aws_secret_access_key',
    )
    mock_get_connection.return_value = mock_connection

    hook = AwsBaseHook(aws_conn_id='aws_default', client_type='spam')
    credentials_from_hook = hook.get_credentials()

    self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
    self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
    self.assertIsNone(credentials_from_hook.token)
def test_get_credentials_from_login_with_token(self, mock_get_connection):
    mock_connection = Connection(
        login='aws_access_key_id',
        password='aws_secret_access_key',
        extra='{"aws_session_token": "test_token"}',
    )
    mock_get_connection.return_value = mock_connection

    hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
    credentials_from_hook = hook.get_credentials()

    self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
    self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
    self.assertEqual(credentials_from_hook.token, 'test_token')
def test_get_credentials_from_extra_with_s3_config_and_profile(self, mock_get_connection,
                                                               mock_parse_s3_config):
    mock_connection = Connection(
        extra='{"s3_config_format": "aws", '
              '"profile": "test", '
              '"s3_config_file": "aws-credentials", '
              '"region_name": "us-east-1"}')
    mock_get_connection.return_value = mock_connection

    hook = AwsBaseHook()
    hook._get_credentials(region_name=None)

    mock_parse_s3_config.assert_called_once_with('aws-credentials', 'aws', 'test')
def test_get_credentials_from_extra_without_token(self, mock_get_connection):
    mock_connection = Connection(
        extra='{"aws_access_key_id": "aws_access_key_id",'
              '"aws_secret_access_key": "aws_secret_access_key"}'
    )
    mock_get_connection.return_value = mock_connection

    hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
    credentials_from_hook = hook.get_credentials()

    self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
    self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
    self.assertIsNone(credentials_from_hook.token)
def test_get_credentials_from_role_arn(self, mock_get_connection):
    mock_connection = Connection(extra='{"role_arn":"arn:aws:iam::123456:role/role_arn"}')
    mock_get_connection.return_value = mock_connection

    hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
    credentials_from_hook = hook.get_credentials()

    self.assertIn("ASIA", credentials_from_hook.access_key)
    # We assert the length instead of actual values as the values are random:
    # Details: https://github.com/spulec/moto/commit/ab0d23a0ba2506e6338ae20b3fde70da049f7b03
    self.assertEqual(20, len(credentials_from_hook.access_key))
    self.assertEqual(40, len(credentials_from_hook.secret_key))
    self.assertEqual(356, len(credentials_from_hook.token))
def test_get_credentials_from_extra_with_token(self, mock_get_connection):
    mock_connection = Connection(
        extra='{"aws_access_key_id": "aws_access_key_id",'
              '"aws_secret_access_key": "aws_secret_access_key",'
              ' "aws_session_token": "session_token"}')
    mock_get_connection.return_value = mock_connection

    hook = AwsBaseHook()
    credentials_from_hook = hook.get_credentials()

    self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
    self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
    self.assertEqual(credentials_from_hook.token, 'session_token')
def test_get_credentials_from_extra_with_token(self, mock_get_connection):
    mock_connection = Connection(
        extra='{"aws_access_key_id": "aws_access_key_id",'
              '"aws_secret_access_key": "aws_secret_access_key",'
              ' "aws_session_token": "session_token"}')
    mock_get_connection.return_value = mock_connection

    hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
    credentials_from_hook = hook.get_credentials()

    assert credentials_from_hook.access_key == 'aws_access_key_id'
    assert credentials_from_hook.secret_key == 'aws_secret_access_key'
    assert credentials_from_hook.token == 'session_token'
def get_iam_token(self, conn):
    """
    Uses AWSHook to retrieve a temporary password to connect to MySQL.
    Port is required. If none is provided, default 3306 is used.
    """
    from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

    aws_conn_id = conn.extra_dejson.get('aws_conn_id', 'aws_default')
    aws_hook = AwsBaseHook(aws_conn_id, client_type='rds')
    if conn.port is None:
        port = 3306
    else:
        port = conn.port
    client = aws_hook.get_conn()
    token = client.generate_db_auth_token(conn.host, port, conn.login)
    return token, port
def get_glue_operator_args(job_name, worker_number=5):
    glue_client = AwsBaseHook(aws_conn_id='aws_default',
                              client_type='glue').get_client_type(client_type='glue',
                                                                  region_name='us-west-2')
    response = glue_client.start_job_run(JobName=job_name)
    job_id = response['JobRunId']
    print("Job {} ID: {}".format(job_name, job_id))

    # Poll the job run until it reaches a terminal state.
    while True:
        status = glue_client.get_job_run(JobName=job_name, RunId=job_id)
        state = status['JobRun']['JobRunState']
        if state == 'SUCCEEDED':
            print('Glue job {} run ID {} succeeded'.format(job_name, job_id))
            break
        if state in ['STOPPED', 'FAILED', 'TIMEOUT', 'STOPPING']:
            print('Glue job {} run ID {} is in {} state'.format(job_name, job_id, state))
            raise Exception('Glue job {} run ID {} did not succeed'.format(job_name, job_id))
        time.sleep(10)
def delete_ecs_task_definition(aws_conn_id: str, task_definition: str) -> None:
    """
    Delete all revisions of given ecs task definition

    :param aws_conn_id: id of the aws connection to use when creating boto3 client/resource
    :type aws_conn_id: str
    :param task_definition: family prefix for task definition to delete in aws ecs
    :type task_definition: str
    """
    hook = AwsBaseHook(
        aws_conn_id=aws_conn_id,
        client_type="ecs",
    )
    response = hook.conn.list_task_definitions(
        familyPrefix=task_definition,
        status="ACTIVE",
        sort="ASC",
        maxResults=100,
    )
    revisions = [arn.split(":")[-1] for arn in response["taskDefinitionArns"]]
    for revision in revisions:
        hook.conn.deregister_task_definition(
            taskDefinition=f"{task_definition}:{revision}",
        )
def create_ecs_cluster(aws_conn_id: str, cluster_name: str) -> None:
    """
    Create ecs cluster with given name

    If the specified cluster already exists, it is not changed and a new cluster
    will not be created.

    :param aws_conn_id: id of the aws connection to use when creating boto3 client/resource
    :type aws_conn_id: str
    :param cluster_name: name of the cluster to create in aws ecs
    :type cluster_name: str
    """
    hook = AwsBaseHook(
        aws_conn_id=aws_conn_id,
        client_type="ecs",
    )
    hook.conn.create_cluster(
        clusterName=cluster_name,
        capacityProviders=[
            "FARGATE_SPOT",
            "FARGATE",
        ],
        defaultCapacityProviderStrategy=[
            {
                "capacityProvider": "FARGATE_SPOT",
                "weight": 1,
                "base": 0,
            },
            {
                "capacityProvider": "FARGATE",
                "weight": 1,
                "base": 0,
            },
        ],
    )
def get_hook(self):
    """Create and return an AwsHook."""
    if not self.hook:
        self.hook = AwsBaseHook(
            aws_conn_id=self.aws_conn_id,
            client_type='ecs',
            region_name=self.region_name,
        )
    return self.hook
def execute(self, context):
    self.log.info('StageToRedshiftOperator start')
    aws_hook = AwsBaseHook(aws_conn_id='aws_credentials', resource_type='*')
    credentials = aws_hook.get_credentials()
    redshift = PostgresHook(postgres_conn_id=self.redshift_conn_id)

    self.log.info("Clearing data from destination Redshift table")
    redshift.run("DELETE FROM {}".format(self.table))

    self.log.info("Copying data from S3 to Redshift")
    rendered_key = self.s3_key.format(**context)
    s3_path = "s3://{}/{}".format(self.s3_bucket, rendered_key)
    formatted_sql = StageToRedshiftOperator.copy_sql_stmt.format(
        self.table,
        s3_path,
        credentials.access_key,
        credentials.secret_key,
        self.ignore_headers,
        self.delimiter)
    redshift.run(formatted_sql)
def create_ecs_task_definition(aws_conn_id: str, task_definition: str, container: str,
                               image: str, execution_role_arn: str, awslogs_group: str,
                               awslogs_region: str, awslogs_stream_prefix: str) -> None:
    """
    Create ecs task definition with given name

    :param aws_conn_id: id of the aws connection to use when creating boto3 client/resource
    :type aws_conn_id: str
    :param task_definition: family name for task definition to create in aws ecs
    :type task_definition: str
    :param container: name of the container
    :type container: str
    :param image: image used to start a container,
        format: `registry_id`.dkr.ecr.`region`.amazonaws.com/`repository_name`:`tag`
    :type image: str
    :param execution_role_arn: task execution role that the Amazon ECS container agent can assume,
        format: arn:aws:iam::`registry_id`:role/`role_name`
    :type execution_role_arn: str
    :param awslogs_group: awslogs group option in log configuration
    :type awslogs_group: str
    :param awslogs_region: awslogs region option in log configuration
    :type awslogs_region: str
    :param awslogs_stream_prefix: awslogs stream prefix option in log configuration
    :type awslogs_stream_prefix: str
    """
    hook = AwsBaseHook(
        aws_conn_id=aws_conn_id,
        client_type="ecs",
    )
    hook.conn.register_task_definition(
        family=task_definition,
        executionRoleArn=execution_role_arn,
        networkMode="awsvpc",
        containerDefinitions=[
            {
                "name": container,
                "image": image,
                "cpu": 256,
                "memory": 512,  # hard limit
                "memoryReservation": 512,  # soft limit
                "logConfiguration": {
                    "logDriver": "awslogs",
                    "options": {
                        "awslogs-group": awslogs_group,
                        "awslogs-region": awslogs_region,
                        "awslogs-stream-prefix": awslogs_stream_prefix,
                    },
                },
            },
        ],
        requiresCompatibilities=[
            "FARGATE",
        ],
        cpu="256",  # task cpu limit (total of all containers)
        memory="512",  # task memory limit (total of all containers)
    )
def execute(self, context):
    aws_hook = AwsBaseHook(self.aws_credentials_id, client_type="s3")
    credentials = aws_hook.get_credentials()
    redshift = PostgresHook(postgres_conn_id=self.redshift_conn_id)

    self.log.info("Clearing data from destination Redshift table")
    redshift.run("DELETE FROM {}".format(self.table))

    self.log.info("Copying data from S3 to Redshift")
    rendered_key = self.s3_key.format(**context)
    s3_path = "s3://{}/{}".format(self.s3_bucket, rendered_key)
    formatted_sql = StageToRedshiftOperator.copy_sql.format(
        self.table,
        s3_path,
        credentials.access_key,
        credentials.secret_key,
        self.json_format,
    )
    redshift.run(formatted_sql)
def execute(self, context):
    aws_hook = AwsBaseHook(self.aws_credentials, client_type='s3')
    self.log.info(aws_hook)
    credentials = aws_hook.get_credentials()
    redshift_hook = PostgresHook(self.redshift_conn)

    # Build the full S3 path by appending the key to the bucket prefix.
    self.s3_bucket = self.s3_bucket + self.s3_key
    sql = self.COPY_SQL.format(
        self.table,
        self.s3_bucket,
        credentials.access_key,
        credentials.secret_key,
        self.json_path)
    self.log.info(sql)
    redshift_hook.run(sql)

    self.log.info("Stage Complete")
    self.log.info(self.run_date)