def test_upsert_document(self, mock_cosmos):
    """upsert_document() with explicit names should CreateItem in that collection.

    The mocked client's CreateItem returns ``{'id': <document_id>}`` so the
    value handed back by the hook can be compared with the id passed in.
    """
    document_id = str(uuid.uuid4())
    client = mock_cosmos.return_value
    client.CreateItem.return_value = {'id': document_id}

    cosmos_hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
    created = cosmos_hook.upsert_document(
        {'data1': 'somedata'},
        database_name=self.test_database_name,
        collection_name=self.test_collection_name,
        document_id=document_id,
    )

    # The payload sent to CreateItem is expected to carry the supplied id.
    collection_link = 'dbs/' + self.test_database_name + '/colls/' + self.test_collection_name
    expected_calls = [
        mock.call().CreateItem(collection_link, {'data1': 'somedata', 'id': document_id}),
    ]
    mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
    mock_cosmos.assert_has_calls(expected_calls)
    logging.getLogger().info(created)
    self.assertEqual(created['id'], document_id)
def test_delete_database(self, mock_cosmos):
    """delete_database() should call DeleteDatabase with the database's ``dbs/<name>`` link."""
    hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
    hook.delete_database(self.test_database_name)
    # Derive the link from the fixture attribute rather than hard-coding the
    # literal, so the assertion keeps matching if test_database_name changes.
    expected_calls = [mock.call().DeleteDatabase('dbs/' + self.test_database_name)]
    mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
    mock_cosmos.assert_has_calls(expected_calls)
def test_create_container_default(self, mock_cosmos):
    """create_collection() without a database name should target the default database."""
    hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
    hook.create_collection(self.test_collection_name)
    # Build the database link from the fixture (as the other default-database
    # tests do) instead of duplicating its value as a string literal.
    expected_calls = [
        mock.call().CreateContainer(
            'dbs/' + self.test_database_default,
            {'id': self.test_collection_name},
        )
    ]
    mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
    mock_cosmos.assert_has_calls(expected_calls)
def test_insert_documents(self, mock_cosmos):
    """insert_documents() should issue one CreateItem per document in the default collection.

    The order of the CreateItem calls is not asserted (``any_order=True``).
    """
    doc_ids = [str(uuid.uuid4()) for _ in range(3)]
    documents = [
        {'id': doc_id, 'data': 'data{}'.format(idx)}
        for idx, doc_id in enumerate(doc_ids, start=1)
    ]

    cosmos_hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
    inserted = cosmos_hook.insert_documents(documents)

    collection_link = (
        'dbs/' + self.test_database_default + '/colls/' + self.test_collection_default
    )
    # Fresh dict literals (not the `documents` objects) so the expectation is
    # independent of anything the hook might do to the inputs.
    expected_calls = [
        mock.call().CreateItem(collection_link, {'data': 'data{}'.format(idx), 'id': doc_id})
        for idx, doc_id in enumerate(doc_ids, start=1)
    ]
    logging.getLogger().info(inserted)
    mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
    mock_cosmos.assert_has_calls(expected_calls, any_order=True)
def execute(self, context: Dict[Any, Any]) -> None:
    """Make sure the target database and collection exist, then upsert ``self.document``.

    :param context: task context; not read by this method.
    """
    cosmos = AzureCosmosDBHook(azure_cosmos_conn_id=self.azure_cosmos_conn_id)

    # Provision only what is missing: database first, then its collection.
    if not cosmos.does_database_exist(self.database_name):
        cosmos.create_database(self.database_name)
    if not cosmos.does_collection_exist(self.collection_name, self.database_name):
        cosmos.create_collection(self.collection_name, self.database_name)

    # With both containers guaranteed to exist, write the document.
    cosmos.upsert_document(self.document, self.database_name, self.collection_name)
def test_upsert_document_default(self, mock_cosmos):
    """upsert_document() with no names should CreateItem in the default database/collection."""
    document_id = str(uuid.uuid4())
    client = mock_cosmos.return_value
    client.CreateItem.return_value = {'id': document_id}

    cosmos_hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
    created = cosmos_hook.upsert_document({'id': document_id})

    collection_link = (
        'dbs/' + self.test_database_default + '/colls/' + self.test_collection_default
    )
    expected_calls = [mock.call().CreateItem(collection_link, {'id': document_id})]
    mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
    mock_cosmos.assert_has_calls(expected_calls)
    logging.getLogger().info(created)
    assert created['id'] == document_id
def poke(self, context: dict) -> bool:
    """Return True once the configured document can be fetched from Cosmos DB.

    :param context: task context; not read by this method.
    :return: True if ``get_document`` returns something other than None for
        ``self.document_id`` in ``self.database_name``/``self.collection_name``,
        otherwise False.
    """
    self.log.info("*** Entering poke")
    # Pass the connection id by keyword, matching every other AzureCosmosDBHook
    # construction, instead of relying on positional parameter order.
    hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.azure_cosmos_conn_id)
    document = hook.get_document(self.document_id, self.database_name, self.collection_name)
    return document is not None
def get_hook(self):
    """Return a hook instance matching this connection's ``conn_type``.

    Each supported connection type maps to the module that defines its hook
    class, the class name, and the keyword argument through which that hook
    receives the connection id. Only the module for the requested type is
    imported, so hook modules for other connection types are never loaded.

    :raises AirflowException: if ``conn_type`` is not a supported type.
    """
    import importlib

    # conn_type -> (module path, hook class name, connection-id keyword argument)
    hook_specs = {
        'mysql': ('airflow.providers.mysql.hooks.mysql', 'MySqlHook', 'mysql_conn_id'),
        'google_cloud_platform': ('airflow.gcp.hooks.bigquery', 'BigQueryHook', 'bigquery_conn_id'),
        'postgres': ('airflow.providers.postgres.hooks.postgres', 'PostgresHook', 'postgres_conn_id'),
        'pig_cli': ('airflow.providers.apache.pig.hooks.pig', 'PigCliHook', 'pig_cli_conn_id'),
        'hive_cli': ('airflow.providers.apache.hive.hooks.hive', 'HiveCliHook', 'hive_cli_conn_id'),
        'presto': ('airflow.providers.presto.hooks.presto', 'PrestoHook', 'presto_conn_id'),
        'hiveserver2': (
            'airflow.providers.apache.hive.hooks.hive', 'HiveServer2Hook', 'hiveserver2_conn_id'),
        'sqlite': ('airflow.providers.sqlite.hooks.sqlite', 'SqliteHook', 'sqlite_conn_id'),
        'jdbc': ('airflow.providers.jdbc.hooks.jdbc', 'JdbcHook', 'jdbc_conn_id'),
        'mssql': ('airflow.providers.microsoft.mssql.hooks.mssql', 'MsSqlHook', 'mssql_conn_id'),
        'odbc': ('airflow.providers.odbc.hooks.odbc', 'OdbcHook', 'odbc_conn_id'),
        'oracle': ('airflow.providers.oracle.hooks.oracle', 'OracleHook', 'oracle_conn_id'),
        'vertica': ('airflow.providers.vertica.hooks.vertica', 'VerticaHook', 'vertica_conn_id'),
        'cloudant': ('airflow.providers.cloudant.hooks.cloudant', 'CloudantHook', 'cloudant_conn_id'),
        'jira': ('airflow.providers.jira.hooks.jira', 'JiraHook', 'jira_conn_id'),
        'redis': ('airflow.providers.redis.hooks.redis', 'RedisHook', 'redis_conn_id'),
        'wasb': ('airflow.providers.microsoft.azure.hooks.wasb', 'WasbHook', 'wasb_conn_id'),
        'docker': ('airflow.providers.docker.hooks.docker', 'DockerHook', 'docker_conn_id'),
        'azure_data_lake': (
            'airflow.providers.microsoft.azure.hooks.azure_data_lake',
            'AzureDataLakeHook',
            'azure_data_lake_conn_id',
        ),
        'azure_cosmos': (
            'airflow.providers.microsoft.azure.hooks.azure_cosmos',
            'AzureCosmosDBHook',
            'azure_cosmos_conn_id',
        ),
        'cassandra': (
            'airflow.providers.apache.cassandra.hooks.cassandra', 'CassandraHook', 'cassandra_conn_id'),
        # MongoHook takes the generic ``conn_id`` keyword rather than a prefixed one.
        'mongo': ('airflow.providers.mongo.hooks.mongo', 'MongoHook', 'conn_id'),
        'gcpcloudsql': ('airflow.gcp.hooks.cloud_sql', 'CloudSQLDatabaseHook', 'gcp_cloudsql_conn_id'),
        'grpc': ('airflow.providers.grpc.hooks.grpc', 'GrpcHook', 'grpc_conn_id'),
    }

    spec = hook_specs.get(self.conn_type)
    if spec is None:
        raise AirflowException("Unknown hook type {}".format(self.conn_type))

    module_path, class_name, conn_id_kwarg = spec
    hook_class = getattr(importlib.import_module(module_path), class_name)
    return hook_class(**{conn_id_kwarg: self.conn_id})
def test_create_container_exception(self, mock_cosmos):
    """create_collection(None) must be rejected with an AirflowException."""
    cosmos_hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
    with self.assertRaises(AirflowException):
        cosmos_hook.create_collection(None)
def test_delete_database_exception(self, mock_cosmos):
    """delete_database(None) must be rejected with an AirflowException."""
    cosmos_hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
    with self.assertRaises(AirflowException):
        cosmos_hook.delete_database(None)
def test_client(self, mock_cosmos):
    """A new hook has ``_conn`` unset, and get_conn() returns a CosmosClient."""
    cosmos_hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
    self.assertIsNone(cosmos_hook._conn)
    client = cosmos_hook.get_conn()
    self.assertIsInstance(client, CosmosClient)
def test_create_container_exception(self, mock_cosmos):
    """create_collection(None) must be rejected with an AirflowException."""
    cosmos_hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
    # Callable form of pytest.raises: fails the test if no exception is raised.
    pytest.raises(AirflowException, cosmos_hook.create_collection, None)
def test_client(self, mock_cosmos):
    """A new hook has ``_conn`` unset, and get_conn() returns a CosmosClient."""
    cosmos_hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
    assert cosmos_hook._conn is None
    client = cosmos_hook.get_conn()
    assert isinstance(client, CosmosClient)
def test_delete_database_exception(self, mock_cosmos):
    """delete_database(None) must be rejected with an AirflowException."""
    cosmos_hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
    # Callable form of pytest.raises: fails the test if no exception is raised.
    pytest.raises(AirflowException, cosmos_hook.delete_database, None)