def setUp(self):
    """Prepare the DAG, SQL statement and DynamoDB hook shared by the tests."""
    default_args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
    self.dag = DAG('test_dag_id', default_args=default_args)
    self.sql = 'SELECT 1'
    self.hook = AwsDynamoDBHook(
        aws_conn_id='aws_default',
        region_name='us-east-1',
    )
Ejemplo n.º 2
0
def persist_data(*, table_name="TABLE_NAME", aws_conn_id='aws_default', **kwargs):
    """Write one DynamoDB item describing the image processed by this DAG run.

    The item combines values from the DAG-run ``conf`` (``userId``,
    ``s3Bucket``, ``s3Key``) with values pulled from XCom: ``FaceId`` from the
    ``FaceIndexDetails`` entry and ``thumbnail`` from the ``ThumbnailDetails``
    entry.

    :param table_name: DynamoDB table to write to. Defaults to the original
        ``"TABLE_NAME"`` placeholder so existing callers behave the same;
        pass the real table name as a keyword argument instead of editing
        the code.
    :param aws_conn_id: Airflow connection id handed to the hook.
    :param kwargs: task context; must provide ``ti`` and ``dag_run``.
    """
    hook = AwsDynamoDBHook(table_name=table_name, aws_conn_id=aws_conn_id)

    task_instance = kwargs['ti']
    face_index_details = task_instance.xcom_pull(key='FaceIndexDetails')
    thumbnail_details = task_instance.xcom_pull(key='ThumbnailDetails')
    conf = kwargs['dag_run'].conf

    item = {
        "UserId": conf["userId"],
        "s3Bucket": conf["s3Bucket"],
        "s3Key": conf["s3Key"],
        "faceId": face_index_details['FaceId'],
        "thumbnail": thumbnail_details['thumbnail'],
    }
    hook.write_batch_data([item])
Ejemplo n.º 3
0
    def execute(self, context) -> None:
        """Scan ``self.dynamodb_table_name`` and upload the result to S3.

        The scan output is written to a temporary file by
        ``self._scan_dynamodb_and_upload_to_s3``, and the file that helper
        returns is uploaded with ``_upload_file_to_s3``. The final upload only
        happens when the scan returned normally, so an interrupted or failed
        scan (including ``KeyboardInterrupt``/``SystemExit``) never uploads a
        partial file. Any exception propagates unchanged.
        """
        hook = AwsDynamoDBHook(aws_conn_id=self.aws_conn_id)
        table = hook.get_conn().Table(self.dynamodb_table_name)

        # Shallow copy so the helper can modify the kwargs without touching
        # the operator's configured attribute.
        scan_kwargs = copy(
            self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {}

        with NamedTemporaryFile() as scan_file:
            # The helper may hand back a different file object than the one
            # it received; that returned file is the one to upload.
            upload_file = self._scan_dynamodb_and_upload_to_s3(
                scan_file, scan_kwargs, table)
            try:
                _upload_file_to_s3(upload_file, self.s3_bucket_name,
                                   self.s3_key_prefix, self.aws_conn_id)
            finally:
                # `with` only closes the file it opened; close the returned
                # one as well when it is a different object.
                if upload_file is not scan_file:
                    upload_file.close()
Ejemplo n.º 4
0
    def execute(self, context):
        """Copy the result of ``self.sql`` from Hive into a DynamoDB table.

        The query result is fetched as a pandas DataFrame. When no
        ``pre_process`` callable is configured, the frame is converted to a
        list of per-row records through a JSON round-trip. Otherwise
        ``pre_process`` is called with the keyword arguments ``data``,
        ``args`` and ``kwargs``, and its return value is what gets written.
        """
        hive_hook = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)

        self.log.info('Extracting data from Hive')
        self.log.info(self.sql)
        frame = hive_hook.get_pandas_df(self.sql, schema=self.schema)

        dynamodb_hook = AwsDynamoDBHook(
            aws_conn_id=self.aws_conn_id,
            table_name=self.table_name,
            table_keys=self.table_keys,
            region_name=self.region_name,
        )

        self.log.info('Inserting rows into dynamodb')

        if self.pre_process is not None:
            # pre_process receives `args`/`kwargs` as named keyword arguments,
            # not unpacked.
            items = self.pre_process(
                data=frame,
                args=self.pre_process_args,
                kwargs=self.pre_process_kwargs,
            )
        else:
            items = json.loads(frame.to_json(orient='records'))
        dynamodb_hook.write_batch_data(items)

        self.log.info('Done.')
Ejemplo n.º 5
0
 def execute(self, context) -> None:
     """Scan ``self.dynamodb_table_name`` and upload the result to S3.

     The scan output is written to a temporary file by
     ``self._scan_dynamodb_and_upload_to_s3``, and the file that helper
     returns is uploaded with ``_upload_file_to_s3``. The upload only happens
     when the scan returned normally, so an interrupted or failed scan
     (including ``KeyboardInterrupt``/``SystemExit``) never uploads a partial
     file. Every temporary file object involved is closed; exceptions
     propagate unchanged.
     """
     table = AwsDynamoDBHook().get_conn().Table(self.dynamodb_table_name)
     # Shallow copy so the helper can modify the kwargs without touching the
     # operator's configured attribute.
     scan_kwargs = copy(self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {}

     with NamedTemporaryFile() as scan_file:
         # The helper may return a different file object than it received;
         # the returned one is what gets uploaded.
         upload_file = self._scan_dynamodb_and_upload_to_s3(scan_file, scan_kwargs, table)
         try:
             _upload_file_to_s3(upload_file, self.s3_bucket_name, self.s3_key_prefix)
         finally:
             # `with` closes only the file it opened; close the returned one
             # too when it is a different object.
             if upload_file is not scan_file:
                 upload_file.close()
Ejemplo n.º 6
0
    def test_insert_batch_items_dynamodb_table(self):
        """Writing ten uniquely keyed items leaves ten items in the table."""
        hook = AwsDynamoDBHook(aws_conn_id='aws_default',
                               table_name='test_airflow',
                               table_keys=['id'],
                               region_name='us-east-1')

        # this table needs to be created in production
        hook.get_conn().create_table(
            TableName='test_airflow',
            KeySchema=[
                {
                    'AttributeName': 'id',
                    'KeyType': 'HASH'
                },
            ],
            AttributeDefinitions=[{
                'AttributeName': 'id',
                'AttributeType': 'S'
            }],
            ProvisionedThroughput={
                'ReadCapacityUnits': 10,
                'WriteCapacityUnits': 10
            },
        )

        table = hook.get_conn().Table('test_airflow')
        # Wait for the table before writing to it, not just before reading.
        table.meta.client.get_waiter('table_exists').wait(
            TableName='test_airflow')

        items = [{'id': str(uuid.uuid4()), 'name': 'airflow'}
                 for _ in range(10)]
        hook.write_batch_data(items)

        self.assertEqual(table.item_count, 10)
Ejemplo n.º 7
0
 def test_get_conn_returns_a_boto3_connection(self):
     """The DynamoDB hook hands back a non-None connection object."""
     connection = AwsDynamoDBHook(aws_conn_id='aws_default').get_conn()
     self.assertIsNotNone(connection)
class TestHiveToDynamoDBOperator(unittest.TestCase):
    """Tests for ``HiveToDynamoDBOperator`` against a moto-mocked DynamoDB."""

    def setUp(self):
        """Build the DAG, SQL statement and DynamoDB hook shared by the tests."""
        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
        dag = DAG('test_dag_id', default_args=args)
        self.dag = dag
        self.sql = 'SELECT 1'
        self.hook = AwsDynamoDBHook(aws_conn_id='aws_default',
                                    region_name='us-east-1')

    @staticmethod
    def process_data(data, *args, **kwargs):
        """Convert the DataFrame to a list of per-row record dicts."""
        return json.loads(data.to_json(orient='records'))

    def _create_test_table(self):
        """Create the ``test_airflow`` table (string hash key ``id``)."""
        # this table needs to be created in production
        self.hook.get_conn().create_table(
            TableName='test_airflow',
            KeySchema=[
                {
                    'AttributeName': 'id',
                    'KeyType': 'HASH'
                },
            ],
            AttributeDefinitions=[{
                'AttributeName': 'id',
                'AttributeType': 'S'
            }],
            ProvisionedThroughput={
                'ReadCapacityUnits': 10,
                'WriteCapacityUnits': 10
            },
        )

    def _test_table_item_count(self):
        """Wait for ``test_airflow`` to exist and return its item count."""
        table = self.hook.get_conn().Table('test_airflow')
        table.meta.client.get_waiter('table_exists').wait(
            TableName='test_airflow')
        return table.item_count

    @unittest.skipIf(mock_dynamodb2 is None,
                     'mock_dynamodb2 package not present')
    @mock_dynamodb2
    def test_get_conn_returns_a_boto3_connection(self):
        """The DynamoDB hook hands back a non-None connection object."""
        hook = AwsDynamoDBHook(aws_conn_id='aws_default')
        self.assertIsNotNone(hook.get_conn())

    @mock.patch(
        'airflow.providers.apache.hive.hooks.hive.HiveServer2Hook.get_pandas_df',
        return_value=pd.DataFrame(data=[('1', 'sid')], columns=['id', 'name']),
    )
    @unittest.skipIf(mock_dynamodb2 is None,
                     'mock_dynamodb2 package not present')
    @mock_dynamodb2
    def test_get_records_with_schema(self, mock_get_pandas_df):
        """A single Hive row is written as a single DynamoDB item."""
        self._create_test_table()

        operator = airflow.providers.amazon.aws.transfers.hive_to_dynamodb.HiveToDynamoDBOperator(
            sql=self.sql,
            table_name="test_airflow",
            task_id='hive_to_dynamodb_check',
            table_keys=['id'],
            dag=self.dag,
        )

        operator.execute(None)

        self.assertEqual(self._test_table_item_count(), 1)

    @mock.patch(
        'airflow.providers.apache.hive.hooks.hive.HiveServer2Hook.get_pandas_df',
        return_value=pd.DataFrame(data=[('1', 'sid'), ('1', 'gupta')],
                                  columns=['id', 'name']),
    )
    @unittest.skipIf(mock_dynamodb2 is None,
                     'mock_dynamodb2 package not present')
    @mock_dynamodb2
    def test_pre_process_records_with_schema(self, mock_get_pandas_df):
        """Rows passed through ``pre_process`` are written to DynamoDB."""
        self._create_test_table()

        operator = airflow.providers.amazon.aws.transfers.hive_to_dynamodb.HiveToDynamoDBOperator(
            sql=self.sql,
            table_name='test_airflow',
            task_id='hive_to_dynamodb_check',
            table_keys=['id'],
            pre_process=self.process_data,
            dag=self.dag,
        )

        operator.execute(None)

        # Both mocked rows share id '1', so they collapse to a single item.
        self.assertEqual(self._test_table_item_count(), 1)