Example #1
    def execute(self, context):
        aws_hook = AwsBaseHook(self.aws_credentials_id, client_type='s3')
        credentials = aws_hook.get_credentials()
        redshift = PostgresHook(postgres_conn_id=self.redshift_conn_id)
        self.log.info('Date: ' + self.execution_date)
        date = parser.parse(self.execution_date)

        self.log.info("Backfill_data: {}".format(self.backfill_data))
        s3_bucket_key = "s3://{}/{}".format(self.s3_bucket, self.s3_key)
        if self.backfill_data:
            s3_path = s3_bucket_key + '/' + str(date.year) + '/' + str(
                date.month)
        else:
            s3_path = s3_bucket_key
        self.log.info("S3 path: {}".format(s3_path))

        self.log.info("Deleting data from table {}.".format(self.table))

        try:
            redshift.run("DELETE FROM {}".format(self.table))
        except Exception as ex:
            self.log.info("Table {} does not exist or could not be cleared: {}".format(self.table, ex))

        copy_sql = self.COPY_SQL.format(self.table, s3_path,
                                        credentials.access_key,
                                        credentials.secret_key, self.region,
                                        self.json_path)
        self.log.info(
            "SQL Statement Executing on Redshift: {}".format(copy_sql))
        redshift.run(copy_sql)
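The COPY_SQL class attribute referenced above is not shown in this example. A minimal sketch of what such a template might look like, assuming the six positional arguments keep the order used in the format call (table, S3 path, access key, secret key, region, JSON path):

    # Hypothetical template; the operator's real COPY_SQL attribute is not shown above.
    COPY_SQL = """
        COPY {}
        FROM '{}'
        ACCESS_KEY_ID '{}'
        SECRET_ACCESS_KEY '{}'
        REGION '{}'
        FORMAT AS JSON '{}'
    """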
Example #2
    def test_get_resource_type_returns_a_boto3_resource_of_the_requested_type(
            self):
        hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='dynamodb')
        resource_from_hook = hook.get_resource_type('dynamodb')

        # create a DynamoDB table via the boto3 resource returned by the hook
        table = resource_from_hook.create_table(  # pylint: disable=no-member
            TableName='test_airflow',
            KeySchema=[
                {
                    'AttributeName': 'id',
                    'KeyType': 'HASH'
                },
            ],
            AttributeDefinitions=[{
                'AttributeName': 'id',
                'AttributeType': 'S'
            }],
            ProvisionedThroughput={
                'ReadCapacityUnits': 10,
                'WriteCapacityUnits': 10
            })

        table.meta.client.get_waiter('table_exists').wait(
            TableName='test_airflow')

        self.assertEqual(table.item_count, 0)
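This test creates a table through the hook, so it only makes sense against a mocked DynamoDB endpoint. In Airflow's test suite such tests are typically wrapped in a moto decorator; a minimal sketch of the scaffolding, assuming a moto release that still ships mock_dynamodb2:

# Hypothetical scaffolding; the class name and decorator are assumptions.
import unittest

from moto import mock_dynamodb2

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook


class TestAwsBaseHook(unittest.TestCase):
    @mock_dynamodb2
    def test_get_resource_type_returns_a_boto3_resource_of_the_requested_type(self):
        ...  # body as shown above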
Example #3
 def expand_role(self):
     if 'Model' not in self.config:
         return
     config = self.config['Model']
     if 'ExecutionRoleArn' in config:
         hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
         config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
Example #4
    def test_get_session_returns_a_boto3_session(self):
        hook = AwsBaseHook(aws_conn_id='aws_default', resource_type='dynamodb')
        session_from_hook = hook.get_session()
        resource_from_session = session_from_hook.resource('dynamodb')
        table = resource_from_session.create_table(  # pylint: disable=no-member
            TableName='test_airflow',
            KeySchema=[
                {
                    'AttributeName': 'id',
                    'KeyType': 'HASH'
                },
            ],
            AttributeDefinitions=[{
                'AttributeName': 'id',
                'AttributeType': 'S'
            }],
            ProvisionedThroughput={
                'ReadCapacityUnits': 10,
                'WriteCapacityUnits': 10
            },
        )

        table.meta.client.get_waiter('table_exists').wait(
            TableName='test_airflow')

        assert table.item_count == 0
Example #5
    def execute(self, context):
        """
        Upload the file using boto3 S3 client
        Args:
            context:

        Returns:

        """
        aws_hook = AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type='s3')
        aws_credentials = aws_hook.get_credentials()
        aws_access_key_id = aws_credentials.access_key
        aws_secret_access_key = aws_credentials.secret_key
        s3client = boto3.client('s3',
                                region_name=self.region_name,
                                aws_access_key_id=aws_access_key_id,
                                aws_secret_access_key=aws_secret_access_key)
        try:
            self.fp = os.path.join(self.working_dir, self.fn)
            s3_key = self.s3_folder + self.fn
            s3_path = 's3://' + self.s3_bucket.rstrip('/') + '/' + s3_key
            self.log.info(f'uploading {self.fp} to {s3_path}')
            response = s3client.upload_file(self.fp, self.s3_bucket, s3_key)
            self.log.info(response)
        except ClientError as e:
            self.log.error(e)
            raise
Example #6
 def expand_role(self) -> None:
     """Expands an IAM role name into an ARN."""
     if 'TrainingJobDefinition' in self.config:
         config = self.config['TrainingJobDefinition']
         if 'RoleArn' in config:
             hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
             config['RoleArn'] = hook.expand_role(config['RoleArn'])
Example #7
    def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
        self._create_clusters()
        hook = AwsBaseHook(aws_conn_id='aws_default', client_type='redshift')
        client_from_hook = hook.get_conn()

        clusters = client_from_hook.describe_clusters()['Clusters']
        self.assertEqual(len(clusters), 2)
Example #8
 def test_expand_role(self):
     conn = boto3.client('iam', region_name='us-east-1')
     conn.create_role(RoleName='test-role', AssumeRolePolicyDocument='some policy')
     hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
     arn = hook.expand_role('test-role')
     expect_arn = conn.get_role(RoleName='test-role').get('Role').get('Arn')
     self.assertEqual(arn, expect_arn)
Example #9
    def get_iam_token(self, conn):
        """
        Uses AWSHook to retrieve a temporary password to connect to Postgres
        or Redshift. Port is required. If none is provided, default is used for
        each service
        """
        from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

        redshift = conn.extra_dejson.get('redshift', False)
        aws_conn_id = conn.extra_dejson.get('aws_conn_id', 'aws_default')
        aws_hook = AwsBaseHook(aws_conn_id)
        login = conn.login
        if conn.port is None:
            port = 5439 if redshift else 5432
        else:
            port = conn.port
        if redshift:
            # Pull the cluster-identifier from the beginning of the Redshift URL
            # e.g. my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com returns my-cluster
            cluster_identifier = conn.extra_dejson.get('cluster-identifier', conn.host.split('.')[0])
            client = aws_hook.get_client_type('redshift')
            cluster_creds = client.get_cluster_credentials(
                DbUser=conn.login,
                DbName=self.schema or conn.schema,
                ClusterIdentifier=cluster_identifier,
                AutoCreate=False)
            token = cluster_creds['DbPassword']
            login = cluster_creds['DbUser']
        else:
            client = aws_hook.get_client_type('rds')
            token = client.generate_db_auth_token(conn.host, port, conn.login)
        return login, token, port
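For reference, a sketch of an Airflow connection that would exercise the Redshift branch above. The redshift, aws_conn_id and cluster-identifier keys are the ones read in this snippet; the iam flag is an assumption about what the surrounding hook checks before calling get_iam_token:

# Hypothetical connection; the "iam" key is an assumption, the other extras are read above.
from airflow.models import Connection

redshift_conn = Connection(
    conn_id='redshift_iam',
    conn_type='postgres',
    host='my-cluster.ccdre4hpd39h.us-east-1.redshift.amazonaws.com',
    login='awsuser',
    schema='dev',
    port=5439,
    extra='{"iam": true, "redshift": true, '
          '"cluster-identifier": "my-cluster", "aws_conn_id": "aws_default"}',
)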
Example #10
 def expand_role(self):
     if 'Model' not in self.config:
         return
     hook = AwsBaseHook(self.aws_conn_id)
     config = self.config['Model']
     if 'ExecutionRoleArn' in config:
         config['ExecutionRoleArn'] = hook.expand_role(
             config['ExecutionRoleArn'])
Example #11
 def expand_role(self) -> None:
     """Expands an IAM role name into an ARN."""
     if 'Model' not in self.config:
         return
     config = self.config['Model']
     if 'ExecutionRoleArn' in config:
         hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
         config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
Example #12
    def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
        client = boto3.client('emr', region_name='us-east-1')
        if client.list_clusters()['Clusters']:
            raise ValueError('AWS not properly mocked')

        hook = AwsBaseHook(aws_conn_id='aws_default', client_type='emr')
        client_from_hook = hook.get_client_type('emr')

        self.assertEqual(client_from_hook.list_clusters()['Clusters'], [])
Example #13
 def _inject_aws_credentials(self) -> None:
     if TRANSFER_SPEC in self.body and AWS_S3_DATA_SOURCE in self.body[TRANSFER_SPEC]:
         aws_hook = AwsBaseHook(self.aws_conn_id, resource_type="s3")
         aws_credentials = aws_hook.get_credentials()
         aws_access_key_id = aws_credentials.access_key  # type: ignore[attr-defined]
         aws_secret_access_key = aws_credentials.secret_key  # type: ignore[attr-defined]
         self.body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY] = {
             ACCESS_KEY_ID: aws_access_key_id,
             SECRET_ACCESS_KEY: aws_secret_access_key,
         }
Example #14
 def _inject_aws_credentials(self):
     if TRANSFER_SPEC in self.body and AWS_S3_DATA_SOURCE in self.body[TRANSFER_SPEC]:
         aws_hook = AwsBaseHook(self.aws_conn_id)
         aws_credentials = aws_hook.get_credentials()
         aws_access_key_id = aws_credentials.access_key
         aws_secret_access_key = aws_credentials.secret_key
         self.body[TRANSFER_SPEC][AWS_S3_DATA_SOURCE][AWS_ACCESS_KEY] = {
             ACCESS_KEY_ID: aws_access_key_id,
             SECRET_ACCESS_KEY: aws_secret_access_key,
         }
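The constants used above are not defined in this snippet; they typically map to the camelCase field names of the Storage Transfer Service API. Assuming TRANSFER_SPEC = 'transferSpec', AWS_S3_DATA_SOURCE = 'awsS3DataSource', AWS_ACCESS_KEY = 'awsAccessKey', ACCESS_KEY_ID = 'accessKeyId' and SECRET_ACCESS_KEY = 'secretAccessKey', the body ends up looking roughly like this after injection:

# Sketch of the transfer job body after _inject_aws_credentials; field names are assumptions.
body = {
    "transferSpec": {
        "awsS3DataSource": {
            "bucketName": "my-source-bucket",
            "awsAccessKey": {
                "accessKeyId": "<access key from the hook>",
                "secretAccessKey": "<secret key from the hook>",
            },
        },
        "gcsDataSink": {"bucketName": "my-destination-bucket"},
    },
}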
Example #15
    def test_get_credentials_from_login_without_token(self, mock_get_connection):
        mock_connection = Connection(login='******',
                                     password='******',
                                     )

        mock_get_connection.return_value = mock_connection
        hook = AwsBaseHook(aws_conn_id='aws_default', client_type='spam')
        credentials_from_hook = hook.get_credentials()
        self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
        self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
        self.assertIsNone(credentials_from_hook.token)
Example #16
 def test_get_credentials_from_login_with_token(self, mock_get_connection):
     mock_connection = Connection(login='******',
                                  password='******',
                                  extra='{"aws_session_token": "test_token"}'
                                  )
     mock_get_connection.return_value = mock_connection
     hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
     credentials_from_hook = hook.get_credentials()
     self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
     self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
     self.assertEqual(credentials_from_hook.token, 'test_token')
Example #17
 def test_get_credentials_from_extra_with_s3_config_and_profile(
         self, mock_get_connection, mock_parse_s3_config):
     mock_connection = Connection(extra='{"s3_config_format": "aws", '
                                  '"profile": "test", '
                                  '"s3_config_file": "aws-credentials", '
                                  '"region_name": "us-east-1"}')
     mock_get_connection.return_value = mock_connection
     hook = AwsBaseHook()
     hook._get_credentials(region_name=None)
     mock_parse_s3_config.assert_called_once_with('aws-credentials', 'aws',
                                                  'test')
Example #18
 def test_get_credentials_from_extra_without_token(self, mock_get_connection):
     mock_connection = Connection(
         extra='{"aws_access_key_id": "aws_access_key_id",'
         '"aws_secret_access_key": "aws_secret_access_key"}'
     )
     mock_get_connection.return_value = mock_connection
     hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
     credentials_from_hook = hook.get_credentials()
     self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
     self.assertEqual(credentials_from_hook.secret_key, 'aws_secret_access_key')
     self.assertIsNone(credentials_from_hook.token)
Example #19
    def test_get_credentials_from_role_arn(self, mock_get_connection):
        mock_connection = Connection(extra='{"role_arn":"arn:aws:iam::123456:role/role_arn"}')
        mock_get_connection.return_value = mock_connection
        hook = AwsBaseHook(aws_conn_id='aws_default', client_type='airflow_test')
        credentials_from_hook = hook.get_credentials()
        self.assertIn("ASIA", credentials_from_hook.access_key)

        # We assert the length instead of actual values as the values are random:
        # Details: https://github.com/spulec/moto/commit/ab0d23a0ba2506e6338ae20b3fde70da049f7b03
        self.assertEqual(20, len(credentials_from_hook.access_key))
        self.assertEqual(40, len(credentials_from_hook.secret_key))
        self.assertEqual(356, len(credentials_from_hook.token))
Example #20
 def test_get_credentials_from_extra_with_token(self, mock_get_connection):
     mock_connection = Connection(
         extra='{"aws_access_key_id": "aws_access_key_id",'
         '"aws_secret_access_key": "aws_secret_access_key",'
         ' "aws_session_token": "session_token"}')
     mock_get_connection.return_value = mock_connection
     hook = AwsBaseHook()
     credentials_from_hook = hook.get_credentials()
     self.assertEqual(credentials_from_hook.access_key, 'aws_access_key_id')
     self.assertEqual(credentials_from_hook.secret_key,
                      'aws_secret_access_key')
     self.assertEqual(credentials_from_hook.token, 'session_token')
Example #21
 def test_get_credentials_from_extra_with_token(self, mock_get_connection):
     mock_connection = Connection(
         extra='{"aws_access_key_id": "aws_access_key_id",'
         '"aws_secret_access_key": "aws_secret_access_key",'
         ' "aws_session_token": "session_token"}')
     mock_get_connection.return_value = mock_connection
     hook = AwsBaseHook(aws_conn_id='aws_default',
                        client_type='airflow_test')
     credentials_from_hook = hook.get_credentials()
     assert credentials_from_hook.access_key == 'aws_access_key_id'
     assert credentials_from_hook.secret_key == 'aws_secret_access_key'
     assert credentials_from_hook.token == 'session_token'
Example #22
    def get_iam_token(self, conn):
        """
        Uses AWSHook to retrieve a temporary password to connect to MySQL
        Port is required. If none is provided, default 3306 is used
        """
        from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

        aws_conn_id = conn.extra_dejson.get('aws_conn_id', 'aws_default')
        aws_hook = AwsBaseHook(aws_conn_id, client_type='rds')
        if conn.port is None:
            port = 3306
        else:
            port = conn.port
        client = aws_hook.get_conn()
        token = client.generate_db_auth_token(conn.host, port, conn.login)
        return token, port
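A minimal sketch of how the returned token and port could then be used as a short-lived password. The mysql.connector usage and CA bundle path are illustrative assumptions, not part of the hook:

# Illustrative only; assumes mysql-connector-python is installed and an RDS CA bundle is on disk.
import mysql.connector

token, port = mysql_hook.get_iam_token(conn)   # mysql_hook and conn are placeholders
db = mysql.connector.connect(
    host=conn.host,
    port=port,
    user=conn.login,
    password=token,                        # the short-lived IAM token acts as the password
    ssl_ca='/path/to/rds-ca-bundle.pem',   # hypothetical CA bundle path
)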
Example #23
def get_glue_operator_args(job_name, worker_number=5):
    glue_client = AwsBaseHook(aws_conn_id='aws_default', client_type='glue').get_client_type(
        client_type='glue', region_name='us-west-2')
    response = glue_client.start_job_run(JobName=job_name)
    job_id = response['JobRunId']
    print("Job {} ID: {}".format(job_name, job_id))
    while True:
        status = glue_client.get_job_run(JobName=job_name, RunId=job_id)
        state = status['JobRun']['JobRunState']
        if state == 'SUCCEEDED':
            print('Glue job {} run ID {} succeeded'.format(job_name, job_id))
            break
        if state in ['STOPPED', 'FAILED', 'TIMEOUT', 'STOPPING']:
            print('Glue job {} run ID {} is in {} state'.format(job_name, job_id, state))
            raise Exception('Glue job {} run {} ended in state {}'.format(job_name, job_id, state))
        time.sleep(10)
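One way this helper could be wired into a DAG is through a PythonOperator. A minimal sketch, assuming Airflow 2's airflow.operators.python module and a Glue job named 'my_glue_job':

# Hypothetical DAG wiring for the helper above.
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator

with DAG(dag_id='glue_job_example', start_date=datetime(2021, 1, 1),
         schedule_interval=None, catchup=False) as dag:
    run_glue_job = PythonOperator(
        task_id='run_glue_job',
        python_callable=get_glue_operator_args,
        op_kwargs={'job_name': 'my_glue_job'},
    )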
Example #24
    def delete_ecs_task_definition(aws_conn_id: str,
                                   task_definition: str) -> None:
        """
        Delete all revisions of the given ECS task definition.

        :param aws_conn_id: id of the aws connection to use when creating boto3 client/resource
        :type aws_conn_id: str
        :param task_definition: family prefix for task definition to delete in aws ecs
        :type task_definition: str
        """
        hook = AwsBaseHook(
            aws_conn_id=aws_conn_id,
            client_type="ecs",
        )
        response = hook.conn.list_task_definitions(
            familyPrefix=task_definition,
            status="ACTIVE",
            sort="ASC",
            maxResults=100,
        )
        revisions = [
            arn.split(":")[-1] for arn in response["taskDefinitionArns"]
        ]
        for revision in revisions:
            hook.conn.deregister_task_definition(
                taskDefinition=f"{task_definition}:{revision}")
Example #25
    def create_ecs_cluster(aws_conn_id: str, cluster_name: str) -> None:
        """
        Create an ECS cluster with the given name.

        If the specified cluster already exists, it is left unchanged and no new cluster is created.

        :param aws_conn_id: id of the aws connection to use when creating boto3 client/resource
        :type aws_conn_id: str
        :param cluster_name: name of the cluster to create in aws ecs
        :type cluster_name: str
        """
        hook = AwsBaseHook(
            aws_conn_id=aws_conn_id,
            client_type="ecs",
        )
        hook.conn.create_cluster(
            clusterName=cluster_name,
            capacityProviders=[
                "FARGATE_SPOT",
                "FARGATE",
            ],
            defaultCapacityProviderStrategy=[
                {
                    "capacityProvider": "FARGATE_SPOT",
                    "weight": 1,
                    "base": 0,
                },
                {
                    "capacityProvider": "FARGATE",
                    "weight": 1,
                    "base": 0,
                },
            ],
        )
Example #26
 def get_hook(self):
     """Create and return an AwsHook."""
     if not self.hook:
         self.hook = AwsBaseHook(aws_conn_id=self.aws_conn_id,
                                 client_type='ecs',
                                 region_name=self.region_name)
     return self.hook
Example #27
    def execute(self, context):
        self.log.info('StageToRedshiftOperator start')
        aws_hook = AwsBaseHook(aws_conn_id='aws_credentials',
                               resource_type='*')
        credentials = aws_hook.get_credentials()
        redshift = PostgresHook(postgres_conn_id=self.redshift_conn_id)

        self.log.info("Clearing data from destination Redshift table")
        redshift.run("DELETE FROM {}".format(self.table))

        self.log.info("Copying data from S3 to Redshift")
        rendered_key = self.s3_key.format(**context)
        s3_path = "s3://{}/{}".format(self.s3_bucket, rendered_key)
        formatted_sql = StageToRedshiftOperator.copy_sql_stmt.format(
            self.table, s3_path, credentials.access_key,
            credentials.secret_key, self.ignore_headers, self.delimiter)
        redshift.run(formatted_sql)
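The copy_sql_stmt template referenced above is not shown in this example. A plausible sketch, assuming the six positional arguments keep the order used in the format call (table, S3 path, access key, secret key, ignore_headers, delimiter):

    # Hypothetical template matching the six positional arguments formatted above.
    copy_sql_stmt = """
        COPY {}
        FROM '{}'
        ACCESS_KEY_ID '{}'
        SECRET_ACCESS_KEY '{}'
        IGNOREHEADER {}
        DELIMITER '{}'
    """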
Example #28
    def create_ecs_task_definition(aws_conn_id: str, task_definition: str,
                                   container: str, image: str,
                                   execution_role_arn: str, awslogs_group: str,
                                   awslogs_region: str,
                                   awslogs_stream_prefix: str) -> None:
        """
        Create an ECS task definition with the given name.

        :param aws_conn_id: id of the aws connection to use when creating boto3 client/resource
        :type aws_conn_id: str
        :param task_definition: family name for task definition to create in aws ecs
        :type task_definition: str
        :param container: name of the container
        :type container: str
        :param image: image used to start a container,
            format: `registry_id`.dkr.ecr.`region`.amazonaws.com/`repository_name`:`tag`
        :type image: str
        :param execution_role_arn: task execution role that the Amazon ECS container agent can assume,
            format: arn:aws:iam::`registry_id`:role/`role_name`
        :type execution_role_arn: str
        :param awslogs_group: awslogs group option in log configuration
        :type awslogs_group: str
        :param awslogs_region: awslogs region option in log configuration
        :type awslogs_region: str
        :param awslogs_stream_prefix: awslogs stream prefix option in log configuration
        :type awslogs_stream_prefix: str
        """
        hook = AwsBaseHook(
            aws_conn_id=aws_conn_id,
            client_type="ecs",
        )
        hook.conn.register_task_definition(
            family=task_definition,
            executionRoleArn=execution_role_arn,
            networkMode="awsvpc",
            containerDefinitions=[
                {
                    "name": container,
                    "image": image,
                    "cpu": 256,
                    "memory": 512,  # hard limit
                    "memoryReservation": 512,  # soft limit
                    "logConfiguration": {
                        "logDriver": "awslogs",
                        "options": {
                            "awslogs-group": awslogs_group,
                            "awslogs-region": awslogs_region,
                            "awslogs-stream-prefix": awslogs_stream_prefix,
                        },
                    },
                },
            ],
            requiresCompatibilities=[
                "FARGATE",
            ],
            cpu="256",  # task cpu limit (total of all containers)
            memory="512",  # task memory limit (total of all containers)
        )
Example #29
    def execute(self, context):
        aws_hook = AwsBaseHook(self.aws_credentials_id, client_type="s3")
        credentials = aws_hook.get_credentials()
        redshift = PostgresHook(postgres_conn_id=self.redshift_conn_id)

        self.log.info("Clearing data from destination Redshift table")
        redshift.run("DELETE FROM {}".format(self.table))

        self.log.info("Copying data from S3 to Redshift")
        rendered_key = self.s3_key.format(**context)
        s3_path = "s3://{}/{}".format(self.s3_bucket, rendered_key)
        formatted_sql = StageToRedshiftOperator.copy_sql.format(
            self.table,
            s3_path,
            credentials.access_key,
            credentials.secret_key,
            self.json_format,
        )
        redshift.run(formatted_sql)
Example #30
    def execute(self, context):
        aws_hook = AwsBaseHook(self.aws_credentials, client_type='s3')
        self.log.info(aws_hook)
        credentials = aws_hook.get_credentials()
        redshift_hook = PostgresHook(self.redshift_conn)

        self.s3_bucket = self.s3_bucket + self.s3_key

        sql = self.COPY_SQL.format(
            self.table,
            self.s3_bucket,
            credentials.access_key,
            credentials.secret_key,
            self.json_path)

        self.log.info(sql)
        redshift_hook.run(sql)
        self.log.info("Stage Complete")
        self.log.info(self.run_date)