def test_insert_batch_items_dynamodb_table(self):
    # Batch-write ten generated items through the hook and verify that the
    # table reports all ten of them.
    table_name = 'test_airflow'
    dynamodb_hook = AwsDynamoDBHook(
        aws_conn_id='aws_default',
        table_name=table_name,
        table_keys=['id'],
        region_name='us-east-1',
    )

    # this table needs to be created in production
    key_schema = [{'AttributeName': 'id', 'KeyType': 'HASH'}]
    attribute_definitions = [{'AttributeName': 'id', 'AttributeType': 'S'}]
    throughput = {'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10}
    dynamodb_hook.get_conn().create_table(
        TableName=table_name,
        KeySchema=key_schema,
        AttributeDefinitions=attribute_definitions,
        ProvisionedThroughput=throughput,
    )
    table = dynamodb_hook.get_conn().Table(table_name)

    # Every item gets a fresh UUID as its hash key, so none overwrite another.
    items = []
    for _ in range(10):
        items.append({'id': str(uuid.uuid4()), 'name': 'airflow'})
    dynamodb_hook.write_batch_data(items)

    # Block on the 'table_exists' waiter before reading the item count.
    waiter = table.meta.client.get_waiter('table_exists')
    waiter.wait(TableName=table_name)
    self.assertEqual(table.item_count, 10)
def execute(self, context) -> None:
    """Scan the configured DynamoDB table and upload the remaining data to S3.

    The scan itself is delegated to ``_scan_dynamodb_and_upload_to_s3``, which
    receives a temporary file and returns the file object that still has to
    be uploaded. That file is pushed to S3 only when the scan returned
    normally.

    :param context: task execution context (not used by this method).
    :raises Exception: whatever the scan helper or the final upload raises is
        propagated unchanged.
    """
    hook = AwsDynamoDBHook(aws_conn_id=self.aws_conn_id)
    table = hook.get_conn().Table(self.dynamodb_table_name)
    # Hand the helper its own (shallow) copy of the configured scan kwargs,
    # or an empty dict when none were configured.
    scan_kwargs = copy(self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {}

    with NamedTemporaryFile() as initial_file:
        # If the scan raises, the exception simply propagates and nothing is
        # uploaded. Unlike an ``err`` flag checked in ``finally``, this also
        # holds for BaseException subclasses such as KeyboardInterrupt or
        # SystemExit, which ``except Exception`` never sees.
        last_file = self._scan_dynamodb_and_upload_to_s3(initial_file, scan_kwargs, table)
        try:
            _upload_file_to_s3(last_file, self.s3_bucket_name, self.s3_key_prefix, self.aws_conn_id)
        finally:
            # NOTE(review): the helper returns a file object and may hand back
            # a different one than it was given (presumably after rolling
            # over to a new temp file - confirm). The ``with`` block only
            # closes the file it created, so close a replacement explicitly.
            if last_file is not initial_file:
                last_file.close()
def test_get_conn_returns_a_boto3_connection(self):
    # A hook built with the default AWS connection id must hand back a
    # connection object rather than None.
    connection = AwsDynamoDBHook(aws_conn_id='aws_default').get_conn()
    self.assertIsNotNone(connection)
class TestHiveToDynamoDBOperator(unittest.TestCase):
    """Tests for ``HiveToDynamoDBOperator`` run against a mocked DynamoDB.

    NOTE(review): the ``skipIf(mock_dynamodb2 is None, ...)`` guards imply that
    ``mock_dynamodb2`` may be ``None`` when the mocking package is missing.
    Applying ``@mock_dynamodb2`` as a decorator would then raise ``TypeError``
    while the class body is evaluated, before any skip can take effect. The
    mock is therefore entered as a context manager inside each test body,
    which only runs when the test was not skipped.
    """

    # Single source of truth for the table every test in this class uses.
    TABLE_NAME = 'test_airflow'

    def setUp(self):
        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
        dag = DAG('test_dag_id', default_args=args)
        self.dag = dag
        self.sql = 'SELECT 1'
        self.hook = AwsDynamoDBHook(aws_conn_id='aws_default', region_name='us-east-1')

    @staticmethod
    def process_data(data, *args, **kwargs):
        """Convert a DataFrame into a list of record dicts via its JSON form."""
        return json.loads(data.to_json(orient='records'))

    def _create_test_table(self):
        """Create the string-keyed (``id``) table the operator writes into."""
        # this table needs to be created in production
        self.hook.get_conn().create_table(
            TableName=self.TABLE_NAME,
            KeySchema=[
                {'AttributeName': 'id', 'KeyType': 'HASH'},
            ],
            AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}],
            ProvisionedThroughput={'ReadCapacityUnits': 10, 'WriteCapacityUnits': 10},
        )

    def _assert_item_count(self, expected):
        """Wait on the ``table_exists`` waiter, then check the table's item count."""
        table = self.hook.get_conn().Table(self.TABLE_NAME)
        table.meta.client.get_waiter('table_exists').wait(TableName=self.TABLE_NAME)
        self.assertEqual(table.item_count, expected)

    @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamodb2 package not present')
    def test_get_conn_returns_a_boto3_connection(self):
        with mock_dynamodb2():
            hook = AwsDynamoDBHook(aws_conn_id='aws_default')
            self.assertIsNotNone(hook.get_conn())

    @mock.patch(
        'airflow.providers.apache.hive.hooks.hive.HiveServer2Hook.get_pandas_df',
        return_value=pd.DataFrame(data=[('1', 'sid')], columns=['id', 'name']),
    )
    @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamodb2 package not present')
    def test_get_records_with_schema(self, mock_get_pandas_df):
        with mock_dynamodb2():
            self._create_test_table()

            operator = airflow.providers.amazon.aws.transfers.hive_to_dynamodb.HiveToDynamoDBOperator(
                sql=self.sql,
                table_name=self.TABLE_NAME,
                task_id='hive_to_dynamodb_check',
                table_keys=['id'],
                dag=self.dag,
            )
            operator.execute(None)

            # The patched Hive hook returns a single row.
            self._assert_item_count(1)

    @mock.patch(
        'airflow.providers.apache.hive.hooks.hive.HiveServer2Hook.get_pandas_df',
        return_value=pd.DataFrame(data=[('1', 'sid'), ('1', 'gupta')], columns=['id', 'name']),
    )
    @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamodb2 package not present')
    def test_pre_process_records_with_schema(self, mock_get_pandas_df):
        with mock_dynamodb2():
            self._create_test_table()

            operator = airflow.providers.amazon.aws.transfers.hive_to_dynamodb.HiveToDynamoDBOperator(
                sql=self.sql,
                table_name=self.TABLE_NAME,
                task_id='hive_to_dynamodb_check',
                table_keys=['id'],
                pre_process=self.process_data,
                dag=self.dag,
            )
            operator.execute(None)

            # Both mocked rows share the hash key '1'; the test expects them
            # to end up as a single item.
            self._assert_item_count(1)