def execute(self, context) -> bool:
    """Is written to depend on transform method"""
    s3_conn = S3Hook(self.aws_conn_id)

    # Grab collection and execute query according to whether or not it is a pipeline
    if self.is_pipeline:
        results = MongoHook(self.mongo_conn_id).aggregate(
            mongo_collection=self.mongo_collection,
            aggregate_query=cast(list, self.mongo_query),
            mongo_db=self.mongo_db,
            allowDiskUse=self.allow_disk_use,
        )
    else:
        results = MongoHook(self.mongo_conn_id).find(
            mongo_collection=self.mongo_collection,
            query=cast(dict, self.mongo_query),
            mongo_db=self.mongo_db,
        )

    # Performs transform then stringifies the docs results into json format
    docs_str = self._stringify(self.transform(results))

    s3_conn.load_string(
        string_data=docs_str,
        key=self.s3_key,
        bucket_name=self.s3_bucket,
        replace=self.replace,
        compression=self.compression,
    )
def execute(self, context) -> bool:
    """Executed by task_instance at runtime"""
    s3_conn = S3Hook(self.s3_conn_id)

    # Grab collection and execute query according to whether or not it is a pipeline
    if self.is_pipeline:
        results = MongoHook(self.mongo_conn_id).aggregate(
            mongo_collection=self.mongo_collection,
            aggregate_query=cast(list, self.mongo_query),
            mongo_db=self.mongo_db,
        )
    else:
        results = MongoHook(self.mongo_conn_id).find(
            mongo_collection=self.mongo_collection,
            query=cast(dict, self.mongo_query),
            mongo_db=self.mongo_db,
        )

    # Performs transform then stringifies the docs results into json format
    docs_str = self._stringify(self.transform(results))

    # Load Into S3
    s3_conn.load_string(
        string_data=docs_str,
        key=self.s3_key,
        bucket_name=self.s3_bucket,
        replace=self.replace,
    )

    return True
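# A hedged usage sketch for the MongoToS3Operator whose two execute() variants
# are shown above. It assumes the self.* attributes referenced there
# (mongo_conn_id, mongo_collection, mongo_query, s3_bucket, s3_key, replace, ...)
# are accepted as constructor arguments; the exact connection argument name
# differs between provider versions (aws_conn_id vs. s3_conn_id), and the
# collection, bucket, and key below are purely illustrative.
from airflow.providers.amazon.aws.transfers.mongo_to_s3 import MongoToS3Operator

export_mongo_to_s3 = MongoToS3Operator(
    task_id="export_mongo_to_s3",
    mongo_conn_id="mongo_default",
    aws_conn_id="aws_default",
    mongo_collection="orders",          # illustrative collection name
    mongo_query={"status": "shipped"},  # a dict runs find(); a list runs an aggregation pipeline
    s3_bucket="my-example-bucket",      # illustrative bucket and key
    s3_key="exports/orders.json",
    replace=True,
)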
def poke(self, context: dict) -> bool:
    self.log.info(
        "Sensor check existence of the document "
        "that matches the following query: %s", self.query)
    hook = MongoHook(self.mongo_conn_id)
    return hook.find(self.collection, self.query, find_one=True) is not None
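# A minimal sketch of wiring the sensor above into a DAG. The connection id,
# collection, and query are borrowed from the test setUp further below; the
# DAG id, schedule, and poke_interval are assumptions for illustration only.
from datetime import datetime

from airflow import DAG
from airflow.providers.mongo.sensors.mongo import MongoSensor

with DAG("mongo_sensor_example",
         start_date=datetime(2021, 1, 1),
         schedule_interval=None) as dag:
    wait_for_document = MongoSensor(
        task_id="wait_for_document",
        mongo_conn_id="mongo_test",
        collection="foo",
        query={"bar": "baz"},
        poke_interval=30,  # seconds between poke() calls
    )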
def extract(
    batch_id, method="GET", http_conn_id="default_api", mongo_conn_id="default_mongo"
):
    http = HttpHook(method, http_conn_id=http_conn_id)
    mongo_conn = MongoHook(mongo_conn_id)
    ids_to_update_coll = mongo_conn.get_collection("ids_to_update", "courts")
    results_to_transform_coll = mongo_conn.get_collection(
        "results_to_transform", "courts"
    )

    # Note/TODO: because we add endpoints back that we couldn't handle, we may
    # get stuck in an infinite loop. Another solution is exiting whenever an
    # exception occurs, but this isn't ideal either
    while ids_to_update_coll.find_one({"batch_id": str(batch_id)}) is not None:
        # find a job to work on
        result = ids_to_update_coll.find_one_and_delete({"batch_id": str(batch_id)})
        api_id = result["api_id"]

        try:
            # transform to get a valid link
            # TODO: this needs to be generalized to any website
            endpoint = f"opinions/{api_id}"

            # pull data in
            response = http.run(endpoint)
            result_data = response.json()

            if response.status_code == 200:
                # store our result into mongo
                results_to_transform_coll.insert_one(
                    {"batch_id": str(batch_id), "data": result_data}
                )
            else:
                # TODO: throw a more specific exception
                raise AirflowException(
                    f"Received {response.status_code} code from {endpoint}."
                )
        except json.JSONDecodeError as j_error:
            print(f"Failed to decode response with {j_error}:\n{response.text}")
            mongo_conn.insert_one(
                "ids_to_update",
                {"api_id": str(api_id), "batch_id": str(batch_id)},
                mongo_db="courts",
            )
        except Exception as error:
            # something went wrong. Log it and return this endpoint to mongoDB so we can try again
            print(f"An exception occurred while processing batch {batch_id}:\n{error}")
            mongo_conn.insert_one(
                "ids_to_update",
                {"api_id": str(api_id), "batch_id": str(batch_id)},
                mongo_db="courts",
            )
def execute(self, context: Dict[str, Any]) -> Any:
    self.http = HttpHook(self.method, http_conn_id=self.http_conn_id)
    self.mongo_conn = MongoHook(self.mongo_conn_id)

    # generate query parameters
    self.query = self.query_builder()

    self.log.info(f"Connecting to: {self.http_conn_id}")

    return_val = self._execute(context)

    self._shutdown()

    return return_val
def execute(self, context):
    mongoHook = MongoHook(conn_id=self.mongo_conn_id)
    self.mongo_db = mongoHook.connection.schema

    log.info('postgres_conn_id: %s', self.postgres_conn_id)
    log.info('mongo_conn_id: %s', self.mongo_conn_id)
    log.info('postgres_sql: %s', self.postgres_sql)
    # log.info('prev_exec_date: %s', self.prev_exec_date)
    log.info('mongo_db: %s', self.mongo_db)
    log.info('mongo_collection: %s', self.mongo_collection)

    well_data = self.get_data()
    most_recent_date = Variable.get("most_recent_date")
    print(most_recent_date)

    filter_query = None
    for index, well in well_data.iterrows():
        if well is not None and well['is_newly_added']:
            print('newly added')
            filter_query = {"Name": {"$eq": well['well_name']}}
        else:
            print('old well')
            filter_query = {
                "$and": [
                    {"Name": {"$eq": well['well_name']}},
                    {"Date": {"$gt": most_recent_date}},
                ]
            }
        # filter_query = { "Date" : { "$gt" : most_recent_date } }

        log.info('mongo filter query: %s', filter_query)
        mongo_well_list = self.transform(
            mongoHook.get_collection(self.mongo_collection).find(filter_query))
        print(len(mongo_well_list))

        if len(mongo_well_list) > 0:
            for doc in mongo_well_list:
                doc["water_cut_calc"] = utils.calc_watercut(
                    doc['OIL_bopd'], doc['WATER_bwpd'])
                doc["gor_calc"] = utils.calc_gor(doc['OIL_bopd'], doc['GAS_mscfd'])

            self.update_records(mongoHook, filter_query, mongo_well_list)
def test_context_manager(self):
    with MongoHook(conn_id='mongo_default', mongo_db='default') as ctx_hook:
        ctx_hook.get_conn()

        assert isinstance(ctx_hook, MongoHook)
        assert ctx_hook.client is not None

    assert ctx_hook.client is None
def test_context_manager(self):
    with MongoHook(conn_id='mongo_default', mongo_db='default') as ctx_hook:
        ctx_hook.get_conn()

        self.assertIsInstance(ctx_hook, MongoHook)
        self.assertIsNotNone(ctx_hook.client)

    self.assertIsNone(ctx_hook.client)
def extract_mongodb(client, dbs, coll, source, task_instance, extract_by_batch=None):
    """
    Export data from MongoDB to JSON.

    Args:
        client: ``MongoClient()``.
        dbs: name of database.
        coll: name of collection.
        source: path of the JSON file to write.
        extract_by_batch: ``int`` batch of rows, default ``None``.

    The initial document id (``ObjectId`` or any unique key) is pulled from the
    ``first_run`` task via XCom; default ``None``.

    Return:
        list_of_docs
    """
    initial_id = task_instance.xcom_pull(task_ids='first_run')

    with client:
        fetch = MongoHook(conn_id='mongo_localhost').find(mongo_collection=coll, mongo_db=dbs)
        list_of_docs = []
        count = 0

        if initial_id is not None:
            # determine which row to start from
            for doc in fetch:
                count += 1
                if initial_id is None:
                    count = 0
                    break
                if initial_id == doc['_id']:
                    break

        if extract_by_batch is None and initial_id is None:
            for docs in fetch:
                docs['_id'] = str(docs['_id'])
                list_of_docs.append(docs)
            print('extract all')
        elif extract_by_batch is None and initial_id is not None:
            for docs in islice(fetch, count):
                docs['_id'] = str(docs['_id'])
                list_of_docs.append(docs)
            print('extract all start at {}'.format(count))
        elif extract_by_batch is not None and initial_id is None:
            for docs in islice(fetch, 0, count + extract_by_batch):
                docs['_id'] = str(docs['_id'])
                list_of_docs.append(docs)
            print('extract_by_batch {} at {}'.format(extract_by_batch, count))
        elif extract_by_batch is not None and initial_id is not None:
            for docs in islice(fetch, count, count + extract_by_batch):
                docs['_id'] = str(docs['_id'])
                list_of_docs.append(docs)
            print('extract_by_batch {} at {}'.format(extract_by_batch, count))

        print("{} rows from {} were extracted".format(len(list_of_docs), coll))
        del fetch

    with open(source, 'w') as json_tripdata:
        json.dump(list_of_docs, json_tripdata, indent=1)

    return list_of_docs
def test_transform_load_operator(
    self, mocker, postgresql, ports_collection, test_dag
):
    """Test if transform_load_operator upserts data into master db."""
    # Create mocks
    mocker.patch.object(
        PostgresHook, "get_conn",
        return_value=postgresql
    )
    mocker.patch.object(
        MongoHook, "get_collection",
        return_value=ports_collection
    )

    # Check if the source table has an item in it
    mongo_hook = MongoHook()
    collection = mongo_hook.get_collection()
    assert collection.count_documents({}) > 0

    # Check if the sink table is initially empty
    cursor = postgresql.cursor()
    cursor.execute("SELECT COUNT(*) FROM ports;")
    initial_result = cursor.fetchone()[0]
    assert initial_result == 0

    # Setup task
    mongo_staging_config = MongoConfig('mongo_default', 'ports')
    postgres_master_config = PostgresConfig('postgres_default')
    task = TransformAndLoadOperator(
        mongo_config=mongo_staging_config,
        postgres_config=postgres_master_config,
        task_id='test',
        processor=PortsItemProcessor(),
        query=SqlQueries.ports_table_insert,
        query_params={"updated_at": datetime.datetime.utcnow()},
        dag=test_dag
    )

    # Execute task and check if it inserted the data successfully
    task.execute(context={}, testing=True)
    cursor.execute("SELECT COUNT(*) FROM ports;")
    after_result = cursor.fetchone()[0]
    assert after_result > 0
def setUp(self):
    db.merge_conn(
        Connection(conn_id='mongo_test', conn_type='mongo',
                   host='mongo', port='27017', schema='test'))
    args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
    self.dag = DAG('test_dag_id', default_args=args)

    hook = MongoHook('mongo_test')
    hook.insert_one('foo', {'bar': 'baz'})

    self.sensor = MongoSensor(
        task_id='test_task',
        mongo_conn_id='mongo_test',
        dag=self.dag,
        collection='foo',
        query={'bar': 'baz'}
    )
def process_source_data():
    fileHook = FSHook('fs_custom')
    mongoHook = MongoHook()
    path = os.path.join(fileHook.get_path(), 'daily_production_data.json')

    df = pd.read_json(path)

    water_cut_calc = []
    gor_calc = []
    for index, row in df.iterrows():
        water_cut_calc.append(
            utils.calc_watercut(row['OIL_bopd'], row['WATER_bwpd']))
        gor_calc.append(utils.calc_gor(row['OIL_bopd'], row['GAS_mscfd']))

    df = df.assign(**{'water_cut_calc': water_cut_calc, 'gor_calc': gor_calc})
    data_dict = df.to_dict("records")

    mongoHook.insert_many('DailyProduction', data_dict, 'fusion_dev_db')

    os.remove(path)
def test_transform_load_operator_database_error(
    self, mocker, postgresql, ports_collection, test_dag
):
    """Test if transform_load_operator handles DB errors."""
    # Create mocks
    mocker.patch.object(
        PostgresHook, "get_conn",
        return_value=postgresql
    )
    mocker.patch.object(
        MongoHook, "get_collection",
        return_value=ports_collection
    )

    # Check if the source table has an item in it
    mongo_hook = MongoHook()
    collection = mongo_hook.get_collection()
    assert collection.count_documents({}) > 0

    # Setup task, intentionally give an unknown table
    mongo_staging_config = MongoConfig('mongo_default', 'ports')
    postgres_master_config = PostgresConfig('postgres_default')
    task = TransformAndLoadOperator(
        mongo_config=mongo_staging_config,
        postgres_config=postgres_master_config,
        task_id='test',
        processor=PortsItemProcessor(),
        query=SqlQueries.ports_table_insert.replace(
            'ports', 'ports_wrong'
        ),
        query_params={"updated_at": datetime.datetime.utcnow()},
        dag=test_dag
    )

    # Execute the task and check if it will raise an UndefinedTable error
    with raises((UndefinedTable, Exception, OperationalError)):
        # Set testing to false to implicitly close the database
        task.execute(context={}, testing=False)
        task.execute(context={}, testing=True)
def check_data(fetch_last_id, task_instance):
    last_objectid_from_transformed_pq = task_instance.xcom_pull(task_ids='first_run')

    with fetch_last_id:
        fetch_last_id = MongoHook(conn_id='mongo_localhost').find().sort({'id': -1})
        for doc in fetch_last_id:
            last_objectid_from_mongodb = doc['_id']
            if last_objectid_from_mongodb == last_objectid_from_transformed_pq:
                return 'bigquery_is_up_to_date'
            else:
                return 'get_data_from_mongodb'
def execute(self, context, testing=False):
    """
    Read all data from mongo db, process it and write to postgresql db.
    Uses UPSERT SQL query to write data.
    """
    self.log.info('LoadToMasterdbOperator Starting...')
    self.log.info("Initializing Mongo Staging DB Connection...")
    mongo_hook = MongoHook(conn_id=self._mongo_conn_id)
    ports_collection = mongo_hook.get_collection(self._mongo_collection)

    self.log.info("Initializing Postgres Master DB Connection...")
    psql_hook = PostgresHook(postgres_conn_id=self._postgres_conn_id)
    psql_conn = psql_hook.get_conn()
    psql_cursor = psql_conn.cursor()

    self.log.info("Loading Staging data to Master Database...")
    try:
        for idx, document in enumerate(ports_collection.find({})):
            document = self._processor.process_item(document)
            staging_id = document.get('_id').__str__()
            document['staging_id'] = staging_id
            document.pop('_id')
            psql_cursor.execute(self._sql_query, document)
            psql_conn.commit()
    except (OperationalError, UndefinedTable, OperationFailure):
        self.log.error("Writing to database FAILED.")
        self.log.error(traceback.format_exc())
        raise Exception("LoadToMasterdbOperator FAILED.")
    except Exception:
        self.log.error(traceback.format_exc())
        raise Exception("LoadToMasterdbOperator FAILED.")
    finally:
        if not testing:
            self.log.info('Closing database connections...')
            psql_conn.close()
            mongo_hook.close_conn()

    self.log.info(f'UPSERTED {idx+1} records into Postgres Database.')
    self.log.info('LoadToMasterdbOperator SUCCESS!')
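# The operator above executes ``self._sql_query`` once per document, passing the
# document dict as the named-parameter set. A minimal sketch of what such an
# UPSERT could look like is below; the column list and the conflict target
# (a unique constraint on staging_id) are assumptions for illustration, not the
# project's actual SqlQueries.ports_table_insert.
ports_table_insert_example = """
    INSERT INTO ports (staging_id, "countryName", "portName", unlocode, coordinates, updated_at)
    VALUES (%(staging_id)s, %(countryName)s, %(portName)s, %(unlocode)s, %(coordinates)s, %(updated_at)s)
    ON CONFLICT (staging_id)
    DO UPDATE SET
        "countryName" = EXCLUDED."countryName",
        "portName"    = EXCLUDED."portName",
        unlocode      = EXCLUDED.unlocode,
        coordinates   = EXCLUDED.coordinates,
        updated_at    = EXCLUDED.updated_at;
"""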
def execute(self, context):
    mongoHook = MongoHook(conn_id=self.mongo_conn_id)

    log.info('odbc_conn_id: %s', self.odbc_conn_id)
    log.info('postgres_conn_id: %s', self.postgres_conn_id)
    log.info('mongo_conn_id: %s', self.mongo_conn_id)
    log.info('mongo_db: %s', mongoHook.connection.schema)
    log.info('mongo_collection: %s', self.mongo_collection)
    log.info('odbc_sql: %s', self.odbc_sql)
    log.info('postgres_sql: %s', self.postgres_sql)
    log.info('postgres_insert_sql: %s', self.postgres_insert_sql)

    mongo_well_list = mongoHook.get_collection(
        self.mongo_collection).distinct("Name")
    log.info('mongo well list: %s', mongo_well_list)

    odbc_well_list = self.get_data()
    log.info('odbc well list: %s', odbc_well_list)

    final_well_list = []
    if not mongo_well_list and len(mongo_well_list) == 0:
        final_well_list = self.prepare_well_list(odbc_well_list, True)
    else:
        mongo_filtered_well_list = self.prepare_well_list(
            mongo_well_list, False)
        new_well_list = list(set(odbc_well_list) - set(mongo_well_list))
        log.info('new well list: %s', new_well_list)
        new_well_list = self.prepare_well_list(new_well_list, True)

        postgres_well_list = self.get_well_data()
        if postgres_well_list.empty == True:
            for item in new_well_list:
                final_well_list.append(item)
        else:
            final_well_list = new_well_list

    log.info('final well list for insert: %s', final_well_list)
    if final_well_list and len(final_well_list) > 0:
        self.insert_data(final_well_list)
def test_transform_load_operator_exception_error(
    self, mocker, postgresql, ports_collection, test_dag
):
    """Test if transform_load_operator handles Exception thrown."""
    # Create mocks
    mocker.patch.object(
        PostgresHook, "get_conn",
        return_value=postgresql
    )
    mocker.patch.object(
        MongoHook, "get_collection",
        return_value=ports_collection
    )

    # Check if the source table has an item in it
    mongo_hook = MongoHook()
    collection = mongo_hook.get_collection()
    assert collection.count_documents({}) > 0

    # Setup task
    mongo_staging_config = MongoConfig('mongo_default', 'ports')
    postgres_master_config = PostgresConfig('postgres_default')
    task = TransformAndLoadOperator(
        mongo_config=mongo_staging_config,
        postgres_config=postgres_master_config,
        task_id='test',
        processor=PortsItemProcessor(),
        query='Wrong SQL query',
        dag=test_dag
    )

    # Execute task and check if it will raise an Exception error
    with raises(Exception):
        task.execute(context={}, testing=True)
def execute(self, context):
    mongoHook = MongoHook(conn_id=self.mongo_conn_id)
    self.mongo_db = mongoHook.connection.schema

    log.info('odbc_conn_id: %s', self.odbc_conn_id)
    log.info('postgres_conn_id: %s', self.postgres_conn_id)
    log.info('mongo_conn_id: %s', self.mongo_conn_id)
    log.info('postgres_sql: %s', self.postgres_sql)
    # log.info('prev_exec_date: %s', self.prev_exec_date)
    log.info('mongo_db: %s', self.mongo_db)
    log.info('mongo_collection: %s', self.mongo_collection)

    odbcHook = OdbcHook(self.odbc_conn_id)

    well_data = self.get_data()
    log.info('postgres well data: %s', well_data)

    most_recent_date = self.get_most_recent_date(mongoHook)
    print(most_recent_date)
    if most_recent_date:
        print('store most recent date inside airflow variable')
        Variable.set("most_recent_date", most_recent_date)

    with closing(odbcHook.get_conn()) as conn:
        for index, well in well_data.iterrows():
            print(well['well_name'], well['is_newly_added'])
            if well is not None and well['is_newly_added']:
                sql = 'SELECT * FROM [dbo].[MV_Amarok_DailyProdWellDemoData] where Name = ?'
                df = pd.read_sql(sql, conn, params=[well['well_name']])
            else:
                sql = 'SELECT * FROM [dbo].[MV_Amarok_DailyProdWellDemoData] where Name = ? and Date > ?'
                df = pd.read_sql(
                    sql, conn, params=[well['well_name'], most_recent_date])

            if not df.empty:
                data_dict = df.to_dict("records")
                self.insert_records(mongoHook, data_dict)
def local_batch(batch_name: str, max_batch_size: int, number_of_batches: int,
                root_path: str) -> None:
    curr_item_num = 0
    batch_id = 1
    id_list = []
    mongo_conn = MongoHook(conn_id="default_mongo")

    # Iterate through all directories and assign each filepath a batch_id.
    for root, directories, files in os.walk(root_path, topdown=False):
        for name in files:
            fp = os.path.join(root, name)

            if curr_item_num == max_batch_size:
                mongo_conn.insert_many("local_results_to_transform",
                                       id_list,
                                       mongo_db="courts")
                curr_item_num = 0
                batch_id %= number_of_batches
                batch_id += 1
                id_list = []

            if curr_item_num < max_batch_size:
                # Each document in local_results_to_transform is a dict of
                # the filepath and the batch it is assigned to.
                id_list.append({
                    "batch_id": f"{batch_name}{batch_id}",
                    "file_path": fp
                })
                curr_item_num += 1

    # Push any remaining filepaths to the local_results_to_transform collection.
    if curr_item_num > 0:
        mongo_conn.insert_many("local_results_to_transform",
                               id_list,
                               mongo_db="courts")
        id_list = []
def test_save_to_json_operator(
    self, mocker, postgresql, ports_collection, test_dag, tmp_path: Path
):
    """Test if save_to_json_operator saves the file on a specified path."""
    # Create mocks
    mocker.patch.object(
        PostgresHook, "get_conn",
        return_value=postgresql
    )
    mocker.patch.object(
        MongoHook, "get_collection",
        return_value=ports_collection
    )

    # Check if the source table has an item in it
    mongo_hook = MongoHook()
    collection = mongo_hook.get_collection()
    assert collection.count_documents({}) > 0

    # Setup some data, transfer staging data to master
    mongo_staging_config = MongoConfig('mongo_default', 'ports')
    postgres_master_config = PostgresConfig('postgres_default')
    transform_load = TransformAndLoadOperator(
        mongo_config=mongo_staging_config,
        postgres_config=postgres_master_config,
        task_id='test',
        processor=PortsItemProcessor(),
        query=SqlQueries.ports_table_insert,
        query_params={"updated_at": datetime.datetime.utcnow()},
        dag=test_dag
    )

    # Execute task and check if it inserted the data successfully
    transform_load.execute(context={}, testing=True)
    pg_hook = PostgresHook()
    cursor = pg_hook.get_conn().cursor()
    cursor.execute("SELECT COUNT(*) FROM ports;")
    after_result = cursor.fetchone()[0]
    assert after_result > 0

    # Alter tmp_path to forcefully create a path
    tmp_path = tmp_path / 'unknown-path'

    # Execute save_to_json to save the data into a json file on tmp_path
    save_to_json = LoadToJsonOperator(
        task_id='export_to_json',
        postgres_config=postgres_master_config,
        query=SqlQueries.select_all_query_to_json,
        path=tmp_path,
        tables=['ports'],
        dag=test_dag
    )
    save_to_json.execute(
        {'execution_date': datetime.datetime(2021, 1, 1)}
    )

    output_path = tmp_path / 'ports_20210101T000000.json'
    expected_data = {
        'ports': [{
            'id': 1,
            'countryName': 'Philippines',
            'portName': 'Aleran/Ozamis',
            'unlocode': 'PH ALE',
            'coordinates': '4234N 00135E'
        }]
    }

    # Read result
    with open(output_path, "r") as f:
        result = json.load(f)

    # Assert
    assert 'ports' in result
    assert result == expected_data
class BaseAPIOperator(BaseOperator):
    """
    Base Operator for API Requests to a main endpoint that generates
    sub-endpoints for further requests.

    :param endpoint: The API endpoint to query.
    :type endpoint: str
    :param parser: Function that parses the endpoint response into a list of
        sub-endpoints. Should return a list of strings.
    :type parser: function that takes a requests.Response object and returns a list of sub-endpoints
    :param response_count: Function that returns number of items in API response.
    :type response_count: Callable[[requests.Response], int]
    :param number_of_batches: Number of batches used in the DAG Run.
    :type number_of_batches: int
    :param http_conn_id: Airflow Connection variable name for the base API URL.
    :type http_conn_id: str
    :param mongo_conn_id: Airflow Connection variable name for the MongoDB.
    :type mongo_conn_id: str
    :param batch_name: Prefix combined with the batch id when labelling documents.
    :type batch_name: str
    :param response_valid: Function that checks if status code is valid. Defaults to 200 status only.
    :type response_valid: Callable[[requests.Response], bool]
    :param query_builder: Function that returns a dictionary of query parameters.
    :type query_builder: Callable[[None], Dict[str, str]]
    :param header: Headers to be added to API request.
    :type header: dict of string key-value pairs
    :param options: Optional keyword arguments for the Requests library get function.
    :type options: dict of string key-value pairs
    :param log_response: Flag to allow for logging Request response. Defaults to False.
    :type log_response: bool
    """

    @apply_defaults
    def __init__(
        self,
        endpoint: str,
        parser: Callable[[requests.Response], list],  # parses a response to gather specific endpoints
        response_count: Callable[[requests.Response], int],  # determines the number of items from query
        number_of_batches: int,
        http_conn_id: str,
        mongo_conn_id: str,
        batch_name: str,
        response_valid: Callable[[requests.Response], bool] = None,
        query_builder: Callable[[None], str] = None,
        header: Optional[Dict[str, str]] = None,
        options: Optional[Dict[str, Any]] = None,
        log_response: bool = False,
        **kwargs,
    ) -> None:
        # delegate to BaseOperator, we don't need to do anything else
        super().__init__(**kwargs)

        self.number_of_batches = number_of_batches

        # API endpoint information, we should only be making GET requests from here.
        # Header is most likely unnecessary
        self.endpoint = endpoint
        self.method = "GET"
        self.query_builder = query_builder or self._default_query_builder
        self.header = header or {}

        self.http_conn_id = http_conn_id
        self.mongo_conn_id = mongo_conn_id
        self.batch_name = batch_name

        # Functions for operating on response data
        self.parser = parser
        self.response_count = response_count
        self.response_valid = response_valid or self._default_response_valid

        # Options is for Requests library functions
        self.options = options or {}

        self.log_response = log_response

        # these get instantiated on execute
        self.http = None
        self.mongo_conn = None

    # Override the execute method; we want any derived classes to override _execute()
    def execute(self, context: Dict[str, Any]) -> Any:
        self.http = HttpHook(self.method, http_conn_id=self.http_conn_id)
        self.mongo_conn = MongoHook(self.mongo_conn_id)

        # generate query parameters
        self.query = self.query_builder()

        self.log.info(f"Connecting to: {self.http_conn_id}")

        return_val = self._execute(context)

        self._shutdown()

        return return_val

    def _execute(self, context: Dict[str, Any]) -> Any:
        raise NotImplementedError(
            "_execute() needs to be defined for subclasses.")

    def _call_once(self, use_query: bool = False) -> Union[requests.Response, None]:
        """
        Execute a single API call.

        :param use_query: If use_query is true, we use the internal query string provided in our request.
        :type use_query: bool (defaults to False)
        """
        response = self.http.run(
            self.endpoint,
            self.query if use_query else {},
            self.header,
            self.options,
        )

        if self.log_response:
            self.log.info(response.url)

        if not self.response_valid(response):
            return None

        return self._to_json(response)

    def _to_json(self, response: requests.Response):
        try:
            return response.json()
        except JSONDecodeError:
            self.log.error(
                f"Failed to convert response to JSON: {response.url}")
            return None

    def _api_id_to_document(self, _id: str, name: str, batch_id: int):
        return {"api_id": str(_id), "batch_id": f"{name}{batch_id}"}

    def _default_query_builder(self) -> dict:
        return {}

    def _default_response_valid(self, response: requests.Response) -> bool:
        """Default response_valid() function. Returns True only on 200."""
        return response.status_code == 200

    def _shutdown(self) -> None:
        """Explicitly close MongoDB connection"""
        if self.mongo_conn:
            self.mongo_conn.close_conn()
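# A minimal sketch of a subclass, assuming only what BaseAPIOperator shows above:
# derived classes implement _execute(), may call the main endpoint themselves,
# and can write parsed sub-endpoint ids to MongoDB via _api_id_to_document().
# The class name is illustrative; the "ids_to_update" collection and "courts"
# database are borrowed from the extract() helper earlier in this file set.
from typing import Any, Dict

from airflow.exceptions import AirflowException


class ExampleAPIToMongoOperator(BaseAPIOperator):
    def _execute(self, context: Dict[str, Any]) -> int:
        # Call the main endpoint once with the built query parameters.
        response = self.http.run(self.endpoint, self.query, self.header, self.options)
        if not self.response_valid(response):
            raise AirflowException(f"Invalid response from {self.endpoint}.")

        # parser() is supplied by the DAG author and returns sub-endpoint ids.
        sub_ids = self.parser(response)
        documents = [
            self._api_id_to_document(api_id, self.batch_name,
                                     (i % self.number_of_batches) + 1)
            for i, api_id in enumerate(sub_ids)
        ]
        if documents:
            self.mongo_conn.insert_many("ids_to_update", documents, mongo_db="courts")
        return len(documents)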
class MongoDatabase:
    """
    Instance of the MongoDB database where all the extracted data will be stored.
    The database URI must be set in the "Connections" of Apache Airflow.
    """

    collection_indexes = {
        # this is the most common index
        'default': [('date', DESCENDING), ('autonomous_region', ASCENDING)],
        'clinic_description': [('type', ASCENDING), ('description', ASCENDING)],
        'population_ar': [('autonomous_region', ASCENDING)],
        'death_causes': [('death_cause', ASCENDING), ('age_range', ASCENDING)],
        'chronic_illnesses': [('illness', ASCENDING)],
        'outbreaks_description': [('date', DESCENDING), ('scope', ASCENDING), ('subscope', ASCENDING)],
        'top_death_causes': [('death_cause', ASCENDING)]
    }

    extracted_db_name = 'covid_extracted_data'
    analyzed_db_name = 'covid_analyzed_data'

    def __init__(self, database_name):
        """
        Connect to the database.
        """
        self.client = MongoHook(conn_id='mongo_covid').get_conn()
        self.db = self.client.get_database(database_name)

    @staticmethod
    def create_collection_index(collection):
        """Create a custom index for a collection, to improve I/O tasks"""
        if len(list(collection.list_indexes())) < 2:
            # No index created yet, let's create it
            if collection.name in MongoDatabase.collection_indexes:
                index = MongoDatabase.collection_indexes[collection.name]
            else:
                index = MongoDatabase.collection_indexes['default']

            collection.create_index(index)

    def read_data(self, collection_name, filters=None, projection=None):
        """
        Read data from the database and return it as a DataFrame.

        :param collection_name: Name of the collection from which the data will be read
        :param filters: (optional) Dictionary with the query filters.
        :param projection: (optional) List of columns to retrieve.
        """
        collection = self.db.get_collection(collection_name)

        if projection:
            projected_fields = {field: 1 for field in projection}
        else:
            projected_fields = {}

        projected_fields['_id'] = 0

        query = collection.find(filters, projected_fields)
        df = pd.DataFrame(query)

        return df

    def store_data(self, collection_name, data, overwrite=True):
        """
        Store data in the database.

        :param collection_name: Name of the collection in which the data will be stored
        :param data: document or documents to be stored in the collection
        :param overwrite: whether to delete the previous data in the collection before storing the new one
        """
        collection = self.db.get_collection(collection_name)

        if overwrite:
            collection.delete_many({})

        MongoDatabase.create_collection_index(collection)

        if type(data) == list:
            # Several documents to be inserted
            collection.insert_many(data)
        elif type(data) == dict:
            # One single document to be inserted
            collection.insert_one(data)

    def __del__(self):
        """When the object is destroyed, the connection with the MongoDB server is released"""
        self.client.close()
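# A small usage sketch for the class above. It assumes the 'mongo_covid'
# connection is configured in Airflow; the documents and filter values are
# illustrative, while the class name, collection name, and database names come
# from the code itself.
db = MongoDatabase(MongoDatabase.extracted_db_name)

# Store a batch of documents (a list of dicts); a single dict would use insert_one().
db.store_data('population_ar', [
    {'autonomous_region': 'Madrid', 'population': 6750000},
    {'autonomous_region': 'Catalonia', 'population': 7700000},
], overwrite=True)

# Read it back as a pandas DataFrame, projecting only two columns.
df = db.read_data('population_ar',
                  filters={'population': {'$gt': 7000000}},
                  projection=['autonomous_region', 'population'])
print(df)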
def test_srv(self):
    hook = MongoHook(conn_id='mongo_default_with_srv')
    self.assertTrue(hook.uri.startswith('mongodb+srv://'))
def test_save_to_json_operator_database_error(
    self, mocker, postgresql, ports_collection, test_dag, tmp_path: Path
):
    """Test if save_to_json_operator can handle errors related to db."""
    # Create mocks
    mocker.patch.object(
        PostgresHook, "get_conn",
        return_value=postgresql
    )
    mocker.patch.object(
        MongoHook, "get_collection",
        return_value=ports_collection
    )

    # Check if the source table has an item in it
    mongo_hook = MongoHook()
    collection = mongo_hook.get_collection()
    assert collection.count_documents({}) > 0

    # Setup some data, transfer staging data to master
    mongo_staging_config = MongoConfig('mongo_default', 'ports')
    postgres_master_config = PostgresConfig('postgres_default')
    transform_load = TransformAndLoadOperator(
        mongo_config=mongo_staging_config,
        postgres_config=postgres_master_config,
        task_id='test',
        processor=PortsItemProcessor(),
        query=SqlQueries.ports_table_insert,
        query_params={"updated_at": datetime.datetime.utcnow()},
        dag=test_dag
    )

    # Execute task and check if it inserted the data successfully
    transform_load.execute(context={}, testing=True)
    pg_hook = PostgresHook()
    cursor = pg_hook.get_conn().cursor()
    cursor.execute("SELECT COUNT(*) FROM ports;")
    after_result = cursor.fetchone()[0]
    assert after_result > 0

    # Execute save_to_json to save the data into a json file on tmp_path
    save_to_json = LoadToJsonOperator(
        task_id='test2',
        postgres_config=postgres_master_config,
        query=SqlQueries.select_all_query_to_json,
        path=tmp_path,
        tables=['foo'],
        dag=test_dag
    )

    with raises((UndefinedTable, OperationalError, Exception)):
        # Set testing = False to implicitly close the database connection
        save_to_json.execute(
            {'execution_date': datetime.datetime(2021, 1, 1)},
            testing=False
        )
        save_to_json.execute(
            {'execution_date': datetime.datetime(2021, 1, 1)},
            testing=True
        )
def get_hook(self):
    if self.conn_type == 'mysql':
        from airflow.providers.mysql.hooks.mysql import MySqlHook
        return MySqlHook(mysql_conn_id=self.conn_id)
    elif self.conn_type == 'google_cloud_platform':
        from airflow.gcp.hooks.bigquery import BigQueryHook
        return BigQueryHook(bigquery_conn_id=self.conn_id)
    elif self.conn_type == 'postgres':
        from airflow.providers.postgres.hooks.postgres import PostgresHook
        return PostgresHook(postgres_conn_id=self.conn_id)
    elif self.conn_type == 'pig_cli':
        from airflow.providers.apache.pig.hooks.pig import PigCliHook
        return PigCliHook(pig_cli_conn_id=self.conn_id)
    elif self.conn_type == 'hive_cli':
        from airflow.providers.apache.hive.hooks.hive import HiveCliHook
        return HiveCliHook(hive_cli_conn_id=self.conn_id)
    elif self.conn_type == 'presto':
        from airflow.providers.presto.hooks.presto import PrestoHook
        return PrestoHook(presto_conn_id=self.conn_id)
    elif self.conn_type == 'hiveserver2':
        from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook
        return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
    elif self.conn_type == 'sqlite':
        from airflow.providers.sqlite.hooks.sqlite import SqliteHook
        return SqliteHook(sqlite_conn_id=self.conn_id)
    elif self.conn_type == 'jdbc':
        from airflow.providers.jdbc.hooks.jdbc import JdbcHook
        return JdbcHook(jdbc_conn_id=self.conn_id)
    elif self.conn_type == 'mssql':
        from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
        return MsSqlHook(mssql_conn_id=self.conn_id)
    elif self.conn_type == 'odbc':
        from airflow.providers.odbc.hooks.odbc import OdbcHook
        return OdbcHook(odbc_conn_id=self.conn_id)
    elif self.conn_type == 'oracle':
        from airflow.providers.oracle.hooks.oracle import OracleHook
        return OracleHook(oracle_conn_id=self.conn_id)
    elif self.conn_type == 'vertica':
        from airflow.providers.vertica.hooks.vertica import VerticaHook
        return VerticaHook(vertica_conn_id=self.conn_id)
    elif self.conn_type == 'cloudant':
        from airflow.providers.cloudant.hooks.cloudant import CloudantHook
        return CloudantHook(cloudant_conn_id=self.conn_id)
    elif self.conn_type == 'jira':
        from airflow.providers.jira.hooks.jira import JiraHook
        return JiraHook(jira_conn_id=self.conn_id)
    elif self.conn_type == 'redis':
        from airflow.providers.redis.hooks.redis import RedisHook
        return RedisHook(redis_conn_id=self.conn_id)
    elif self.conn_type == 'wasb':
        from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
        return WasbHook(wasb_conn_id=self.conn_id)
    elif self.conn_type == 'docker':
        from airflow.providers.docker.hooks.docker import DockerHook
        return DockerHook(docker_conn_id=self.conn_id)
    elif self.conn_type == 'azure_data_lake':
        from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook
        return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
    elif self.conn_type == 'azure_cosmos':
        from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook
        return AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
    elif self.conn_type == 'cassandra':
        from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
        return CassandraHook(cassandra_conn_id=self.conn_id)
    elif self.conn_type == 'mongo':
        from airflow.providers.mongo.hooks.mongo import MongoHook
        return MongoHook(conn_id=self.conn_id)
    elif self.conn_type == 'gcpcloudsql':
        from airflow.gcp.hooks.cloud_sql import CloudSQLDatabaseHook
        return CloudSQLDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
    elif self.conn_type == 'grpc':
        from airflow.providers.grpc.hooks.grpc import GrpcHook
        return GrpcHook(grpc_conn_id=self.conn_id)
    raise AirflowException("Unknown hook type {}".format(self.conn_type))
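# A brief, hedged usage note for get_hook() above: it lives on the Connection
# model, so a stored connection whose conn_type is "mongo" can be resolved to a
# MongoHook. The BaseHook import path below is the Airflow 1.10-era one; in
# Airflow 2.x it is airflow.hooks.base, and the conn_id is an assumption.
from airflow.hooks.base_hook import BaseHook

conn = BaseHook.get_connection("mongo_default")  # assumes this conn_id exists
hook = conn.get_hook()  # for conn_type == "mongo" this returns MongoHook(conn_id="mongo_default")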