Code Example #1: test_execute for S3ToRedshiftOperator
    def test_execute(self, mock_run, mock_session):
        access_key = "aws_access_key_id"
        secret_key = "aws_secret_access_key"
        mock_session.return_value = Session(access_key, secret_key)
        mock_session.return_value.access_key = access_key
        mock_session.return_value.secret_key = secret_key
        mock_session.return_value.token = None

        schema = "schema"
        table = "table"
        s3_bucket = "bucket"
        s3_key = "key"
        copy_options = ""

        op = S3ToRedshiftOperator(
            schema=schema,
            table=table,
            s3_bucket=s3_bucket,
            s3_key=s3_key,
            copy_options=copy_options,
            redshift_conn_id="redshift_conn_id",
            aws_conn_id="aws_conn_id",
            task_id="task_id",
            dag=None,
        )
        op.execute(None)

        credentials_block = build_credentials_block(mock_session.return_value)
        copy_query = op._build_copy_query(credentials_block, copy_options)

        assert mock_run.call_count == 1
        assert access_key in copy_query
        assert secret_key in copy_query
        assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0],
                                            copy_query)
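The test above receives mock_run and mock_session as injected mocks, but the decorators that create them are not part of the excerpt. Below is a minimal sketch of how such a test is typically wired up with unittest.mock.patch; the patch targets are assumptions that depend on the provider version, so treat them as placeholders rather than the test's actual decorators.

    from unittest import mock

    # Hypothetical patch targets: the hook method the operator calls to run SQL
    # and the boto3 Session used to resolve AWS credentials. Decorators apply
    # bottom-up, so the innermost patch (run) becomes the first mock argument.
    @mock.patch("boto3.session.Session")
    @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
    def test_execute(self, mock_run, mock_session):
        ...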
Code Example #2: an S3-to-Redshift execute() variant that creates a staging table before COPY
    def execute(self, context) -> None:
        postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
        credentials = s3_hook.get_credentials()
        credentials_block = build_credentials_block(credentials)
        copy_options = '\n\t\t\t'.join(self.copy_options)

        copy_statement = self._build_copy_query(credentials_block, copy_options)

        self.log.info("Creating the staging table...")
        postgres_hook.run(self.create_table_sql)
        self.log.info("Creating the staging table complete...")

        self.log.info('Executing COPY command...')
        postgres_hook.run(copy_statement)
        self.log.info("COPY command complete...")

        self.log.info("Logging the number of rows and files on S3 affected...")
        number_of_rows = postgres_hook.get_first(f"SELECT count(*) FROM {self.schema}.{self.table}")[0]
        number_of_keys_s3 = s3_hook.list_keys(bucket_name=self.s3_bucket, prefix=self.s3_key)

        self.log.info(f"{self.schema}.{self.table} has {number_of_rows} rows")
        self.log.info(f"{self.s3_bucket}/{self.s3_key} has {len(number_of_keys_s3)} files")

        self.log.info("Logging the number of rows and files on S3 affected complete...")
Code Example #3: S3ToRedshiftOperator.execute with IAM role, REPLACE, and UPSERT support
File: s3_to_redshift.py  Project: wolvery/airflow
    def execute(self, context) -> None:
        redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
        conn = S3Hook.get_connection(conn_id=self.aws_conn_id)

        credentials_block = None
        if conn.extra_dejson.get('role_arn', False):
            credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
        else:
            s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
            credentials = s3_hook.get_credentials()
            credentials_block = build_credentials_block(credentials)

        copy_options = '\n\t\t\t'.join(self.copy_options)
        destination = f'{self.schema}.{self.table}'
        copy_destination = f'#{self.table}' if self.method == 'UPSERT' else destination

        copy_statement = self._build_copy_query(copy_destination,
                                                credentials_block,
                                                copy_options)

        sql: Union[list, str]

        if self.method == 'REPLACE':
            sql = [
                "BEGIN;", f"DELETE FROM {destination};", copy_statement,
                "COMMIT"
            ]
        elif self.method == 'UPSERT':
            keys = self.upsert_keys or redshift_hook.get_table_primary_key(
                self.table, self.schema)
            if not keys:
                raise AirflowException(
                    f"No primary key on {self.schema}.{self.table}. Please provide keys on 'upsert_keys'"
                )
            where_statement = ' AND '.join(
                [f'{self.table}.{k} = {copy_destination}.{k}' for k in keys])

            sql = [
                f"CREATE TABLE {copy_destination} (LIKE {destination});",
                copy_statement,
                "BEGIN;",
                f"DELETE FROM {destination} USING {copy_destination} WHERE {where_statement};",
                f"INSERT INTO {destination} SELECT * FROM {copy_destination};",
                "COMMIT",
            ]

        else:
            sql = copy_statement

        self.log.info('Executing COPY command...')
        redshift_hook.run(sql, autocommit=self.autocommit)
        self.log.info("COPY command complete...")
Code Example #4: test_execute_sts_token for RedshiftToS3Operator
    def test_execute_sts_token(
        self,
        table_as_file_name,
        expected_s3_key,
        mock_run,
        mock_session,
    ):
        access_key = "ASIA_aws_access_key_id"
        secret_key = "aws_secret_access_key"
        token = "token"
        mock_session.return_value = Session(access_key, secret_key, token)
        mock_session.return_value.access_key = access_key
        mock_session.return_value.secret_key = secret_key
        mock_session.return_value.token = token
        schema = "schema"
        table = "table"
        s3_bucket = "bucket"
        s3_key = "key"
        unload_options = [
            'HEADER',
        ]

        op = RedshiftToS3Operator(
            schema=schema,
            table=table,
            s3_bucket=s3_bucket,
            s3_key=s3_key,
            unload_options=unload_options,
            include_header=True,
            redshift_conn_id="redshift_conn_id",
            aws_conn_id="aws_conn_id",
            task_id="task_id",
            table_as_file_name=table_as_file_name,
            dag=None,
        )

        op.execute(None)

        unload_options = '\n\t\t\t'.join(unload_options)
        select_query = f"SELECT * FROM {schema}.{table}"
        credentials_block = build_credentials_block(mock_session.return_value)

        unload_query = op._build_unload_query(
            credentials_block, select_query, expected_s3_key, unload_options
        )

        assert mock_run.call_count == 1
        assert access_key in unload_query
        assert secret_key in unload_query
        assert token in unload_query
        assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], unload_query)
Code Example #5: test_build_credentials_block_sts
File: test_redshift.py  Project: ysktir/airflow-1
    def test_build_credentials_block_sts(self, mock_session):
        access_key = "ASIA_aws_access_key_id"
        secret_key = "aws_secret_access_key"
        token = "aws_secret_token"
        mock_session.return_value = Session(access_key, secret_key)
        mock_session.return_value.access_key = access_key
        mock_session.return_value.secret_key = secret_key
        mock_session.return_value.token = token

        credentials_block = build_credentials_block(mock_session.return_value)

        assert access_key in credentials_block
        assert secret_key in credentials_block
        assert token in credentials_block
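The assertions above only require each credential to appear somewhere in the block. For orientation, here is a hypothetical stand-in for build_credentials_block, reconstructed from nothing more than those assertions; the real helper ships with the Amazon provider and its exact formatting may differ.

    def build_credentials_block_sketch(credentials) -> str:
        # Hypothetical: render boto credentials in the key=value;key=value form
        # that Redshift COPY/UNLOAD accepts in a CREDENTIALS clause.
        block = (
            f"aws_access_key_id={credentials.access_key};"
            f"aws_secret_access_key={credentials.secret_key}"
        )
        if credentials.token:
            # Temporary (STS) credentials also need the session token.
            block += f";token={credentials.token}"
        return block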
Code Example #6: RedshiftToS3Operator.execute (UNLOAD via PostgresHook)
File: redshift_to_s3.py  Project: waleedsamy/airflow
    def execute(self, context) -> None:
        postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)

        credentials = s3_hook.get_credentials()
        credentials_block = build_credentials_block(credentials)
        unload_options = '\n\t\t\t'.join(self.unload_options)

        unload_query = self._build_unload_query(credentials_block,
                                                self._select_query,
                                                self.s3_key, unload_options)

        self.log.info('Executing UNLOAD command...')
        postgres_hook.run(unload_query, self.autocommit)
        self.log.info("UNLOAD command complete...")
Code Example #7: test_custom_select_query_unloading for RedshiftToS3Operator
    def test_custom_select_query_unloading(
        self,
        table,
        table_as_file_name,
        expected_s3_key,
        mock_run,
        mock_session,
    ):
        access_key = "aws_access_key_id"
        secret_key = "aws_secret_access_key"
        mock_session.return_value = Session(access_key, secret_key)
        mock_session.return_value.access_key = access_key
        mock_session.return_value.secret_key = secret_key
        mock_session.return_value.token = None
        s3_bucket = "bucket"
        s3_key = "key"
        unload_options = [
            'HEADER',
        ]
        select_query = "select column from table"

        op = RedshiftToS3Operator(
            select_query=select_query,
            table=table,
            table_as_file_name=table_as_file_name,
            s3_bucket=s3_bucket,
            s3_key=s3_key,
            unload_options=unload_options,
            include_header=True,
            redshift_conn_id="redshift_conn_id",
            aws_conn_id="aws_conn_id",
            task_id="task_id",
            dag=None,
        )

        op.execute(None)

        unload_options = '\n\t\t\t'.join(unload_options)
        credentials_block = build_credentials_block(mock_session.return_value)

        unload_query = op._build_unload_query(credentials_block, select_query,
                                              expected_s3_key, unload_options)

        assert mock_run.call_count == 1
        assert access_key in unload_query
        assert secret_key in unload_query
        assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0],
                                            unload_query)
Code Example #8: RedshiftToS3Operator.execute with table_as_file_name
File: redshift_to_s3.py  Project: sylvainczr/airflow
    def execute(self, context) -> None:
        postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)

        credentials = s3_hook.get_credentials()
        credentials_block = build_credentials_block(credentials)
        unload_options = '\n\t\t\t'.join(self.unload_options)
        s3_key = f"{self.s3_key}/{self.table}_" if self.table_as_file_name else self.s3_key
        select_query = f"SELECT * FROM {self.schema}.{self.table}"

        unload_query = self._build_unload_query(credentials_block,
                                                select_query, s3_key,
                                                unload_options)

        self.log.info('Executing UNLOAD command...')
        postgres_hook.run(unload_query, self.autocommit)
        self.log.info("UNLOAD command complete...")
Code Example #9: test_truncate for S3ToRedshiftOperator
    def test_truncate(self, mock_run, mock_session):
        access_key = "aws_access_key_id"
        secret_key = "aws_secret_access_key"
        mock_session.return_value = Session(access_key, secret_key)
        mock_session.return_value.access_key = access_key
        mock_session.return_value.secret_key = secret_key
        mock_session.return_value.token = None

        schema = "schema"
        table = "table"
        s3_bucket = "bucket"
        s3_key = "key"
        copy_options = ""

        op = S3ToRedshiftOperator(
            schema=schema,
            table=table,
            s3_bucket=s3_bucket,
            s3_key=s3_key,
            copy_options=copy_options,
            truncate_table=True,
            redshift_conn_id="redshift_conn_id",
            aws_conn_id="aws_conn_id",
            task_id="task_id",
            dag=None,
        )
        op.execute(None)

        credentials_block = build_credentials_block(mock_session.return_value)
        copy_statement = op._build_copy_query(credentials_block, copy_options)

        truncate_statement = f'TRUNCATE TABLE {schema}.{table};'
        transaction = f"""
                    BEGIN;
                    {truncate_statement}
                    {copy_statement}
                    COMMIT
                    """
        assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0],
                                            transaction)

        assert mock_run.call_count == 1
Code Example #10: RedshiftToS3Operator.execute with IAM role support
    def execute(self, context: 'Context') -> None:
        redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
        conn = S3Hook.get_connection(conn_id=self.aws_conn_id)
        if conn.extra_dejson.get('role_arn', False):
            credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
        else:
            s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
            credentials = s3_hook.get_credentials()
            credentials_block = build_credentials_block(credentials)

        unload_options = '\n\t\t\t'.join(self.unload_options)

        unload_query = self._build_unload_query(credentials_block,
                                                self.select_query, self.s3_key,
                                                unload_options)

        self.log.info('Executing UNLOAD command...')
        redshift_hook.run(unload_query,
                          self.autocommit,
                          parameters=self.parameters)
        self.log.info("UNLOAD command complete...")
Code Example #11: S3-to-Redshift execute() with REPLACE and UPSERT (PostgresHook variant)
    def execute(self, context) -> None:
        postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        credentials = s3_hook.get_credentials()
        credentials_block = build_credentials_block(credentials)
        copy_options = '\n\t\t\t'.join(self.copy_options)
        destination = f'{self.schema}.{self.table}'
        copy_destination = f'#{self.table}' if self.method == 'UPSERT' else destination

        copy_statement = self._build_copy_query(copy_destination, credentials_block, copy_options)

        if self.method == 'REPLACE':
            sql = f"""
            BEGIN;
            DELETE FROM {destination};
            {copy_statement}
            COMMIT
            """
        elif self.method == 'UPSERT':
            keys = self.upsert_keys or postgres_hook.get_table_primary_key(self.table, self.schema)
            if not keys:
                raise AirflowException(
                    f"No primary key on {self.schema}.{self.table}. Please provide keys on 'upsert_keys'"
                )
            where_statement = ' AND '.join([f'{self.table}.{k} = {copy_destination}.{k}' for k in keys])
            sql = f"""
            CREATE TABLE {copy_destination} (LIKE {destination});
            {copy_statement}
            BEGIN;
            DELETE FROM {destination} USING {copy_destination} WHERE {where_statement};
            INSERT INTO {destination} SELECT * FROM {copy_destination};
            COMMIT
            """
        else:
            sql = copy_statement

        self.log.info('Executing COPY command...')
        postgres_hook.run(sql, self.autocommit)
        self.log.info("COPY command complete...")
Code Example #12: S3-to-Redshift execute() with truncate_table
    def execute(self, context) -> None:
        postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        credentials = s3_hook.get_credentials()
        credentials_block = build_credentials_block(credentials)
        copy_options = '\n\t\t\t'.join(self.copy_options)

        copy_statement = self._build_copy_query(credentials_block, copy_options)

        if self.truncate_table:
            delete_statement = f'DELETE FROM {self.schema}.{self.table};'
            sql = f"""
            BEGIN;
            {delete_statement}
            {copy_statement}
            COMMIT
            """
        else:
            sql = copy_statement

        self.log.info('Executing COPY command...')
        postgres_hook.run(sql, self.autocommit)
        self.log.info("COPY command complete...")