Example #1
    def execute(self, context) -> bool:
        """Is written to depend on transform method"""
        s3_conn = S3Hook(self.aws_conn_id)

        # Grab collection and execute query according to whether or not it is a pipeline
        if self.is_pipeline:
            results = MongoHook(self.mongo_conn_id).aggregate(
                mongo_collection=self.mongo_collection,
                aggregate_query=cast(list, self.mongo_query),
                mongo_db=self.mongo_db,
                allowDiskUse=self.allow_disk_use,
            )

        else:
            results = MongoHook(self.mongo_conn_id).find(
                mongo_collection=self.mongo_collection,
                query=cast(dict, self.mongo_query),
                mongo_db=self.mongo_db,
            )

        # Apply the transform, then stringify the resulting documents into JSON
        docs_str = self._stringify(self.transform(results))

        s3_conn.load_string(
            string_data=docs_str,
            key=self.s3_key,
            bucket_name=self.s3_bucket,
            replace=self.replace,
            compression=self.compression,
        )
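
This execute body reads its settings from self (mongo_conn_id, mongo_collection, mongo_query, s3_bucket, s3_key, replace, ...), which is the shape of the Amazon provider's MongoToS3Operator. A minimal sketch of wiring such a transfer into a DAG, assuming that operator class and illustrative connection, collection, bucket and key names:

# Minimal sketch (assumes MongoToS3Operator from the Amazon provider; names are illustrative).
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.transfers.mongo_to_s3 import MongoToS3Operator

with DAG("mongo_to_s3_example", start_date=datetime(2021, 1, 1), schedule_interval=None) as dag:
    export_collection = MongoToS3Operator(
        task_id="export_collection",
        mongo_conn_id="mongo_default",
        mongo_collection="my_collection",      # illustrative collection
        mongo_query={},                        # a dict runs find(); a list runs the aggregate branch
        s3_bucket="my-bucket",                 # illustrative bucket/key
        s3_key="exports/my_collection.json",
        replace=True,
    )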
Example #2
    def execute(self, context) -> bool:
        """Executed by task_instance at runtime"""
        s3_conn = S3Hook(self.s3_conn_id)

        # Grab collection and execute query according to whether or not it is a pipeline
        if self.is_pipeline:
            results = MongoHook(self.mongo_conn_id).aggregate(
                mongo_collection=self.mongo_collection,
                aggregate_query=cast(list, self.mongo_query),
                mongo_db=self.mongo_db,
            )

        else:
            results = MongoHook(self.mongo_conn_id).find(
                mongo_collection=self.mongo_collection,
                query=cast(dict, self.mongo_query),
                mongo_db=self.mongo_db,
            )

        # Apply the transform, then stringify the resulting documents into JSON
        docs_str = self._stringify(self.transform(results))

        # Load Into S3
        s3_conn.load_string(string_data=docs_str,
                            key=self.s3_key,
                            bucket_name=self.s3_bucket,
                            replace=self.replace)

        return True
Example #3
 def poke(self, context: dict) -> bool:
     self.log.info(
         "Sensor check existence of the document "
         "that matches the following query: %s", self.query)
     hook = MongoHook(self.mongo_conn_id)
     return hook.find(self.collection, self.query,
                      find_one=True) is not None
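
poke() returns True as soon as a matching document exists, so the sensor keeps rescheduling until the query hits. A minimal sketch of declaring such a sensor, mirroring the connection, collection and query used in the test setUp further below (`dag` and poke_interval are assumed/illustrative):

# Minimal sketch: a MongoSensor that pokes until a matching document appears.
# Collection, query and poke_interval are illustrative; `dag` is assumed to be in scope.
from airflow.providers.mongo.sensors.mongo import MongoSensor

wait_for_document = MongoSensor(
    task_id="wait_for_document",
    mongo_conn_id="mongo_default",
    collection="foo",
    query={"bar": "baz"},
    poke_interval=60,
    dag=dag,
)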
Example #4
def extract(
    batch_id, method="GET", http_conn_id="default_api", mongo_conn_id="default_mongo"
):

    http = HttpHook(method, http_conn_id=http_conn_id)

    mongo_conn = MongoHook(mongo_conn_id)
    ids_to_update_coll = mongo_conn.get_collection("ids_to_update", "courts")
    results_to_transform_coll = mongo_conn.get_collection(
        "results_to_transform", "courts"
    )

    # Note/TODO: because we add endpoints back that we couldn't handle, we may
    # get stuck in an infinite loop. Another solution is exiting whenever an
    # exception occurs, but this isn't ideal either
    while ids_to_update_coll.find_one({"batch_id": str(batch_id)}) is not None:

        # find a job to work on
        result = ids_to_update_coll.find_one_and_delete({"batch_id": str(batch_id)})
        api_id = result["api_id"]
        try:

            # transform to get a valid link
            # TODO: this needs to be generalized to any website
            endpoint = f"opinions/{api_id}"

            # pull data in
            response = http.run(endpoint)

            result_data = response.json()

            if response.status_code == 200:

                # store our result into mongo
                results_to_transform_coll.insert_one(
                    {"batch_id": str(batch_id), "data": result_data}
                )

            else:
                # TODO: throw a more specific exception
                raise AirflowException(
                    f"Received {response.status_code} code from {endpoint}."
                )

        except json.JSONDecodeError as j_error:
            print("Failed to decode response with {j_error}:\n{response.body}")
            mongo_conn.insert_one(
                "ids_to_update",
                {"api_id": str(api_id), "batch_id": str(batch_id)},
                mongo_db="courts",
            )
        except Exception as error:
            # something went wrong. Log it and return this endpoint to mongoDB so we can try again
            print(f"An exception occured while processing batch {batch_id}:\n{error}")
            mongo_conn.insert_one(
                "ids_to_update",
                {"api_id": str(api_id), "batch_id": str(batch_id)},
                mongo_db="courts",
            )
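
extract() is a plain function, so in a DAG it would typically run through a PythonOperator, one task per batch. A minimal sketch (the batch id is illustrative and `dag` is assumed to be in scope):

# Minimal sketch: one extract task per batch via PythonOperator.
# `dag` is assumed to be defined elsewhere; the batch id is illustrative.
from airflow.operators.python import PythonOperator

extract_batch_1 = PythonOperator(
    task_id="extract_batch_1",
    python_callable=extract,
    op_kwargs={"batch_id": 1,
               "http_conn_id": "default_api",
               "mongo_conn_id": "default_mongo"},
    dag=dag,
)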
Example #5
    def execute(self, context: Dict[str, Any]) -> Any:

        self.http = HttpHook(self.method, http_conn_id=self.http_conn_id)
        self.mongo_conn = MongoHook(self.mongo_conn_id)

        # generate query parameters
        self.query = self.query_builder()

        self.log.info(f"Connecting to: {self.http_conn_id}")

        return_val = self._execute(context)

        self._shutdown()

        return return_val
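Example #6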
    def execute(self, context):
        mongoHook = MongoHook(conn_id=self.mongo_conn_id)
        self.mongo_db = mongoHook.connection.schema
        log.info('postgres_conn_id: %s', self.postgres_conn_id)
        log.info('mongo_conn_id: %s', self.mongo_conn_id)
        log.info('postgres_sql: %s', self.postgres_sql)
        # log.info('prev_exec_date: %s', self.prev_exec_date)
        log.info('mongo_db: %s', self.mongo_db)
        log.info('mongo_collection: %s', self.mongo_collection)

        well_data = self.get_data()
        most_recent_date = Variable.get("most_recent_date")
        print(most_recent_date)
        filter_query = None
        for index, well in well_data.iterrows():
            if well is not None and well['is_newly_added']:
                print('newly added')
                filter_query = {"Name": {"$eq": well['well_name']}}
            else:
                print('old well')
                filter_query = {
                    "$and": [{
                        "Name": {
                            "$eq": well['well_name']
                        }
                    }, {
                        "Date": {
                            "$gt": most_recent_date
                        }
                    }]
                }
                # filter_query = { "Date" : { "$gt" : most_recent_date } }

            log.info('mongo filter query: %s', filter_query)
            mongo_well_list = self.transform(
                mongoHook.get_collection(
                    self.mongo_collection).find(filter_query))
            print(len(mongo_well_list))
            if len(mongo_well_list) > 0:
                for doc in mongo_well_list:
                    doc["water_cut_calc"] = utils.calc_watercut(
                        doc['OIL_bopd'], doc['WATER_bwpd'])
                    doc["gor_calc"] = utils.calc_gor(doc['OIL_bopd'],
                                                     doc['GAS_mscfd'])

                self.update_records(mongoHook, filter_query, mongo_well_list)
Example #7
    def test_context_manager(self):
        with MongoHook(conn_id='mongo_default', mongo_db='default') as ctx_hook:
            ctx_hook.get_conn()

            assert isinstance(ctx_hook, MongoHook)
            assert ctx_hook.client is not None

        assert ctx_hook.client is None
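
The test asserts that the client is closed once the with-block exits. In task code the same pattern might look like this minimal sketch (connection id, database and collection names are illustrative):

# Minimal sketch: MongoHook as a context manager, so the client is closed on exit.
def count_documents():
    with MongoHook(conn_id="mongo_default") as hook:
        collection = hook.get_collection("foo", mongo_db="default")
        return collection.count_documents({})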
Example #8
    def test_context_manager(self):
        with MongoHook(conn_id='mongo_default', mongo_db='default') as ctx_hook:
            ctx_hook.get_conn()

            self.assertIsInstance(ctx_hook, MongoHook)
            self.assertIsNotNone(ctx_hook.client)

        self.assertIsNone(ctx_hook.client)
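Example #9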
def extract_mongodb(client, dbs, coll, source, task_instance, extract_by_batch=None):
    """
    Export data from MongoDB to a JSON file.

    Args:
        client: a ``MongoClient()`` instance.

        dbs: name of the database.

        coll: name of the collection.

        source: path of the JSON file to write.

        task_instance: Airflow task instance; the id of the last processed
            document (``ObjectId`` or any unique key) is pulled from the
            ``first_run`` task via XCom, default ``None``.

        extract_by_batch: ``int`` number of rows per batch, default ``None``.
    Returns:
        list_of_docs
    """
    initial_id = task_instance.xcom_pull(task_ids='first_run')
    with client:
        fetch = MongoHook(conn_id='mongo_localhost').find(mongo_collection=coll,
                                                          query={},
                                                          mongo_db=dbs)
        list_of_docs = []
        count = 0
        if initial_id is not None:
            # Walk the cursor until the last processed document is found,
            # so extraction can resume right after it.
            for doc in fetch:
                count += 1
                if initial_id == doc['_id']:
                    break
            # Rewind so the slicing below starts from the first document again.
            fetch.rewind()

        if extract_by_batch is None and initial_id is None:
            for docs in fetch:
                docs['_id'] = str(docs['_id'])
                list_of_docs.append(docs)
            print('extract all')
        elif extract_by_batch is None and initial_id is not None:
            for docs in islice(fetch, count, None):
                docs['_id'] = str(docs['_id'])
                list_of_docs.append(docs)
            print('extract all starting at {}'.format(count))
        elif extract_by_batch is not None and initial_id is None:
            for docs in islice(fetch, 0, count + extract_by_batch):
                docs['_id'] = str(docs['_id'])
                list_of_docs.append(docs)
            print('extract_by_batch {} at {}'.format(extract_by_batch, count))
        else:  # extract_by_batch is not None and initial_id is not None
            for docs in islice(fetch, count, count + extract_by_batch):
                docs['_id'] = str(docs['_id'])
                list_of_docs.append(docs)
            print('extract_by_batch {} at {}'.format(extract_by_batch, count))
        print('{} rows from {} were extracted'.format(len(list_of_docs), coll))
        del fetch
        with open(source, 'w') as json_tripdata:
            json.dump(list_of_docs, json_tripdata, indent=1)

    return list_of_docs
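Example #10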
    def test_transform_load_operator(
        self, mocker, postgresql, ports_collection, test_dag
    ):
        """Test if transform_load_operator upserts data into master db."""
        # Create mocks
        mocker.patch.object(
            PostgresHook,
            "get_conn",
            return_value=postgresql
        )
        mocker.patch.object(
            MongoHook,
            "get_collection",
            return_value=ports_collection
        )

        # Check if the source table has an item in it
        mongo_hook = MongoHook()
        collection = mongo_hook.get_collection()
        assert collection.count_documents({}) > 0

        # Check if the sink table is initially empty
        cursor = postgresql.cursor()
        cursor.execute("SELECT COUNT(*) FROM ports;")
        initial_result = cursor.fetchone()[0]
        assert initial_result == 0

        # Setup task
        mongo_staging_config = MongoConfig('mongo_default', 'ports')
        postgres_master_config = PostgresConfig('postgres_default')
        task = TransformAndLoadOperator(
            mongo_config=mongo_staging_config,
            postgres_config=postgres_master_config,
            task_id='test',
            processor=PortsItemProcessor(),
            query=SqlQueries.ports_table_insert,
            query_params={"updated_at": datetime.datetime.utcnow()},
            dag=test_dag
        )

        # Execute task and check if it inserted the data successfully
        task.execute(context={}, testing=True)
        cursor.execute("SELECT COUNT(*) FROM ports;")
        after_result = cursor.fetchone()[0]
        assert after_result > 0
Example #11
    def setUp(self):
        db.merge_conn(
            Connection(conn_id='mongo_test',
                       conn_type='mongo',
                       host='mongo',
                       port='27017',
                       schema='test'))

        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
        self.dag = DAG('test_dag_id', default_args=args)

        hook = MongoHook('mongo_test')
        hook.insert_one('foo', {'bar': 'baz'})

        self.sensor = MongoSensor(task_id='test_task',
                                  mongo_conn_id='mongo_test',
                                  dag=self.dag,
                                  collection='foo',
                                  query={'bar': 'baz'})
Example #12
def process_source_data():
    fileHook = FSHook('fs_custom')
    mongoHook = MongoHook()
    path = os.path.join(fileHook.get_path(), 'daily_production_data.json')

    df = pd.read_json(path)
    water_cut_calc = []
    gor_calc = []

    for index, row in df.iterrows():
        water_cut_calc.append(
            utils.calc_watercut(row['OIL_bopd'], row['WATER_bwpd']))
        gor_calc.append(utils.calc_gor(row['OIL_bopd'], row['GAS_mscfd']))

    df = df.assign(**{'water_cut_calc': water_cut_calc, 'gor_calc': gor_calc})

    data_dict = df.to_dict("records")
    mongoHook.insert_many('DailyProduction', data_dict, 'fusion_dev_db')

    os.remove(path)
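Example #13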
    def test_transform_load_operator_database_error(
        self, mocker, postgresql, ports_collection, test_dag
    ):
        """Test if transform_load_operator handles DB errors."""
        # Create mocks
        mocker.patch.object(
            PostgresHook,
            "get_conn",
            return_value=postgresql
        )
        mocker.patch.object(
            MongoHook,
            "get_collection",
            return_value=ports_collection
        )

        # Check if the source table has an item in it
        mongo_hook = MongoHook()
        collection = mongo_hook.get_collection()
        assert collection.count_documents({}) > 0

        # Setup task, intentionally give an unknown table
        mongo_staging_config = MongoConfig('mongo_default', 'ports')
        postgres_master_config = PostgresConfig('postgres_default')
        task = TransformAndLoadOperator(
            mongo_config=mongo_staging_config,
            postgres_config=postgres_master_config,
            task_id='test',
            processor=PortsItemProcessor(),
            query=SqlQueries.ports_table_insert.replace(
                'ports', 'ports_wrong'
            ),
            query_params={"updated_at": datetime.datetime.utcnow()},
            dag=test_dag
        )

        # Execute the task and check if it will raise an UndefinedTable error
        with raises((UndefinedTable, Exception, OperationalError)):
            # Set testing to false to implicitly close the database
            task.execute(context={}, testing=False)
            task.execute(context={}, testing=True)
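Example #14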
def check_data(client, coll, task_instance):

    last_objectid_from_transformed_pq = task_instance.xcom_pull(task_ids='first_run')
    with client:
        # `coll` is the MongoDB collection to inspect (cf. extract_mongodb above).
        # Sort by _id descending so the newest document comes first.
        fetch_last_id = MongoHook(conn_id='mongo_localhost').find(
            mongo_collection=coll, query={}).sort('_id', -1)

        for doc in fetch_last_id:
            last_objectid_from_mongodb = doc['_id']
            if last_objectid_from_mongodb == last_objectid_from_transformed_pq:
                return 'bigquery_is_up_to_date'
            return 'get_data_from_mongodb'
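
check_data returns a task id ('bigquery_is_up_to_date' or 'get_data_from_mongodb'), which is the shape of a branch callable. A minimal sketch of plugging it into a BranchPythonOperator (the `client` object, collection name and `dag` are assumed/illustrative):

# Minimal sketch: check_data as a branch callable; the returned string selects
# the downstream task id. `client` (a MongoClient) and `dag` are assumed to be
# defined elsewhere; the collection name is illustrative.
from airflow.operators.python import BranchPythonOperator

branch_on_new_data = BranchPythonOperator(
    task_id="check_for_new_data",
    python_callable=check_data,
    op_kwargs={"client": client, "coll": "tripdata"},
    dag=dag,
)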
Example #15
    def execute(self, context, testing=False):
        """
        Read all data from mongo db, process it
        and write to postgresql db.

        Uses UPSERT SQL query to write data.
        """
        self.log.info('LoadToMasterdbOperator Starting...')
        self.log.info("Initializing Mongo Staging DB Connection...")
        mongo_hook = MongoHook(conn_id=self._mongo_conn_id)
        ports_collection = mongo_hook.get_collection(self._mongo_collection)
        self.log.info("Initializing Postgres Master DB Connection...")
        psql_hook = PostgresHook(postgres_conn_id=self._postgres_conn_id)
        psql_conn = psql_hook.get_conn()
        psql_cursor = psql_conn.cursor()
        self.log.info("Loading Staging data to Master Database...")
        try:
            for idx, document in enumerate(ports_collection.find({})):
                document = self._processor.process_item(document)
                staging_id = document.get('_id').__str__()
                document['staging_id'] = staging_id
                document.pop('_id')
                psql_cursor.execute(self._sql_query, document)
            psql_conn.commit()
        except (OperationalError, UndefinedTable, OperationFailure):
            self.log.error("Writting to database FAILED.")
            self.log.error(traceback.format_exc())
            raise Exception("LoadToMasterdbOperator FAILED.")
        except Exception:
            self.log.error(traceback.format_exc())
            raise Exception("LoadToMasterdbOperator FAILED.")
        finally:
            if not testing:
                self.log.info('Closing database connections...')
                psql_conn.close()
                mongo_hook.close_conn()
        self.log.info(f'UPSERTED {idx+1} records into Postgres Database.')
        self.log.info('LoadToMasterdbOperator SUCCESS!')
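
The cursor executes self._sql_query with each processed document as a parameter dict, so the UPSERT mentioned in the docstring is presumably an INSERT ... ON CONFLICT statement with named placeholders. The actual SqlQueries.ports_table_insert is not shown here; a hypothetical query of that shape, using the column names seen in the JSON export test further below:

# Hypothetical UPSERT of the shape the operator expects: %(key)s placeholders
# matching the keys of each processed document. Column names and the conflict
# target are illustrative.
PORTS_TABLE_UPSERT = """
    INSERT INTO ports (staging_id, "countryName", "portName", unlocode, coordinates)
    VALUES (%(staging_id)s, %(countryName)s, %(portName)s, %(unlocode)s, %(coordinates)s)
    ON CONFLICT (staging_id) DO UPDATE SET
        "countryName" = EXCLUDED."countryName",
        "portName"    = EXCLUDED."portName",
        unlocode      = EXCLUDED.unlocode,
        coordinates   = EXCLUDED.coordinates;
"""

Example #16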
    def execute(self, context):

        mongoHook = MongoHook(conn_id=self.mongo_conn_id)

        log.info('odbc_conn_id: %s', self.odbc_conn_id)
        log.info('postgres_conn_id: %s', self.postgres_conn_id)
        log.info('mongo_conn_id: %s', self.mongo_conn_id)
        log.info('mongo_db: %s', mongoHook.connection.schema)
        log.info('mongo_collection: %s', self.mongo_collection)
        log.info('odbc_sql: %s', self.odbc_sql)
        log.info('postgres_sql: %s', self.postgres_sql)
        log.info('postgres_insert_sql: %s', self.postgres_insert_sql)

        mongo_well_list = mongoHook.get_collection(
            self.mongo_collection).distinct("Name")
        log.info('mongo well list: %s', mongo_well_list)
        odbc_well_list = self.get_data()
        log.info('odbc well list: %s', odbc_well_list)
        final_well_list = []
        if not mongo_well_list:
            final_well_list = self.prepare_well_list(odbc_well_list, True)
        else:
            mongo_filtered_well_list = self.prepare_well_list(
                mongo_well_list, False)
            new_well_list = list(set(odbc_well_list) - set(mongo_well_list))
            log.info('new well list: %s', new_well_list)
            new_well_list = self.prepare_well_list(new_well_list, True)
            postgres_well_list = self.get_well_data()
            if postgres_well_list.empty:
                for item in new_well_list:
                    final_well_list.append(item)
            else:
                final_well_list = new_well_list

        log.info('final well list for insert: %s', final_well_list)
        if final_well_list:
            self.insert_data(final_well_list)
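Example #17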
    def test_transform_load_operator_exception_error(
        self, mocker, postgresql, ports_collection, test_dag
    ):
        """Test if transform_load_operator handles Exception thrown."""
        # Create mocks
        mocker.patch.object(
            PostgresHook,
            "get_conn",
            return_value=postgresql
        )
        mocker.patch.object(
            MongoHook,
            "get_collection",
            return_value=ports_collection
        )

        # Check if the source table has an item in it
        mongo_hook = MongoHook()
        collection = mongo_hook.get_collection()
        assert collection.count_documents({}) > 0

        # Setup task
        mongo_staging_config = MongoConfig('mongo_default', 'ports')
        postgres_master_config = PostgresConfig('postgres_default')
        task = TransformAndLoadOperator(
            mongo_config=mongo_staging_config,
            postgres_config=postgres_master_config,
            task_id='test',
            processor=PortsItemProcessor(),
            query='Wrong SQL query',
            dag=test_dag
        )

        # Execute task and check if it will raise an Exception error
        with raises(Exception):
            task.execute(context={}, testing=True)
Example #18
    def execute(self, context):
        mongoHook = MongoHook(conn_id=self.mongo_conn_id)
        self.mongo_db = mongoHook.connection.schema

        log.info('odbc_conn_id: %s', self.odbc_conn_id)
        log.info('postgres_conn_id: %s', self.postgres_conn_id)
        log.info('mongo_conn_id: %s', self.mongo_conn_id)
        log.info('postgres_sql: %s', self.postgres_sql)
        # log.info('prev_exec_date: %s', self.prev_exec_date)
        log.info('mongo_db: %s', self.mongo_db)
        log.info('mongo_collection: %s', self.mongo_collection)

        odbcHook = OdbcHook(self.odbc_conn_id)

        well_data = self.get_data()
        log.info('postgres well data: %s', well_data)
        most_recent_date = self.get_most_recent_date(mongoHook)
        print(most_recent_date)
        if most_recent_date:
            print('store most recent date inside airflow variable')
            Variable.set("most_recent_date", most_recent_date)

        with closing(odbcHook.get_conn()) as conn:
            for index, well in well_data.iterrows():
                print(well['well_name'], well['is_newly_added'])
                if well is not None and well['is_newly_added']:
                    sql = 'SELECT *  FROM [dbo].[MV_Amarok_DailyProdWellDemoData] where Name = ?'
                    df = pd.read_sql(sql, conn, params=[well['well_name']])
                else:
                    sql = 'SELECT *  FROM [dbo].[MV_Amarok_DailyProdWellDemoData] where Name = ? and Date > ?'
                    df = pd.read_sql(
                        sql,
                        conn,
                        params=[well['well_name'], most_recent_date])

                if not df.empty:
                    data_dict = df.to_dict("records")
                    self.insert_records(mongoHook, data_dict)
Example #19
def local_batch(batch_name: str, max_batch_size: int, number_of_batches: int,
                root_path: str) -> None:

    curr_item_num = 0
    batch_id = 1
    id_list = []

    mongo_conn = MongoHook(conn_id="default_mongo")

    # Iterate through all directories and assign each filepath a batch_id.
    for root, directories, files in os.walk(root_path, topdown=False):
        for name in files:
            fp = os.path.join(root, name)
            if curr_item_num == max_batch_size:

                mongo_conn.insert_many("local_results_to_transform",
                                       id_list,
                                       mongo_db="courts")

                curr_item_num = 0
                batch_id %= number_of_batches
                batch_id += 1
                id_list = []

            if curr_item_num < max_batch_size:

                # Each document in local_results_to_transform is a dict of
                # the filepath and the batch it is assigned to.
                id_list.append({
                    "batch_id": f"{batch_name}{batch_id}",
                    "file_path": fp
                })
                curr_item_num += 1

    # Push any remaining filepaths to the local_results_to_transform collection.
    if (curr_item_num > 0):
        mongo_conn.insert_many("local_results_to_transform",
                               id_list,
                               mongo_db="courts")
        id_list = []
Example #20
    def test_save_to_json_operator(
        self, mocker, postgresql, ports_collection, test_dag,
        tmp_path: Path
    ):
        """Test if save_to_json_operator saves the file on a specified path"""
        # Create mocks
        mocker.patch.object(
            PostgresHook,
            "get_conn",
            return_value=postgresql
        )
        mocker.patch.object(
            MongoHook,
            "get_collection",
            return_value=ports_collection
        )

        # Check if the source table has an item in it
        mongo_hook = MongoHook()
        collection = mongo_hook.get_collection()
        assert collection.count_documents({}) > 0

        # Setup some data, transfer staging data to master
        mongo_staging_config = MongoConfig('mongo_default', 'ports')
        postgres_master_config = PostgresConfig('postgres_default')
        transform_load = TransformAndLoadOperator(
            mongo_config=mongo_staging_config,
            postgres_config=postgres_master_config,
            task_id='test',
            processor=PortsItemProcessor(),
            query=SqlQueries.ports_table_insert,
            query_params={"updated_at": datetime.datetime.utcnow()},
            dag=test_dag
        )

        # Execute task and check if it inserted the data successfully
        transform_load.execute(context={}, testing=True)
        pg_hook = PostgresHook()
        cursor = pg_hook.get_conn().cursor()
        cursor.execute("SELECT COUNT(*) FROM ports;")
        after_result = cursor.fetchone()[0]
        assert after_result > 0

        # Alter tmp_path to force creation of a not-yet-existing directory
        tmp_path = tmp_path / 'unknown-path'

        # Execute save_to_json to save the data into json file on tmp_path
        save_to_json = LoadToJsonOperator(
            task_id='export_to_json',
            postgres_config=postgres_master_config,
            query=SqlQueries.select_all_query_to_json,
            path=tmp_path,
            tables=['ports'],
            dag=test_dag
        )
        save_to_json.execute(
            {'execution_date': datetime.datetime(2021, 1, 1)}
        )

        output_path = tmp_path / 'ports_20210101T000000.json'

        expected_data = {
            'ports': [{
                'id': 1,
                'countryName': 'Philippines',
                'portName': 'Aleran/Ozamis',
                'unlocode': 'PH ALE',
                'coordinates': '4234N 00135E'
            }]
        }

        # Read result
        with open(output_path, "r") as f:
            result = json.load(f)

        # Assert
        assert 'ports' in result
        assert result == expected_data
Example #21
class BaseAPIOperator(BaseOperator):
    """
    Base Operator for API requests to a main endpoint that generates sub-endpoints for further requests.

    :param endpoint: The API endpoint to query.
    :type endpoint: str
    :param parser: Function that parses the endpoint response, into a list of sub-endpoints.
                   Should return a list of strings.
    :type parser: function that takes a requests.Response object and returns a list of sub-endpoints
    :param response_count: Function that returns number of items in API response
    :type response_count: Callable[[requests.Response], int]
    :param number_of_batches: Number of batches used in the DAG Run.
    :type number_of_batches: int
    :param http_conn_id: Airflow Connection variable name for the base API URL.
    :type http_conn_id: str
    :param mongo_conn_id: Airflow Connection variable name for the MongoDB.
    :type mongo_conn_id: str
    :param response_valid: Function that checks if status code is valid. Defaults to 200 status only.
    :type response_valid: Callable[[requests.Response], bool]
    :param query_builder: Function that returns a Dictionary of query parameters.
    :type query_builder: Callable[[None], Dict[str, str]]
    :param header: Headers to be added to API request.
    :type header: dict of string key-value pairs
    :param options: Optional keyword arguments for the Requests library get function.
    :type options: dict of string key-value pairs
    :param log_response: Flag to allow for logging Request response. Defaults to False.
    :type log_response: bool
    """
    @apply_defaults
    def __init__(
        self,
        endpoint: str,
        parser: Callable[
            [requests.Response],
            list],  # Function that parses a response to gather specific endpoints
        response_count: Callable[
            [requests.Response],
            int],  # Determines the number of items from query
        number_of_batches: int,
        http_conn_id: str,
        mongo_conn_id: str,
        batch_name: str,
        response_valid: Callable[[requests.Response], bool] = None,
        query_builder: Callable[[None], Dict[str, str]] = None,
        header: Optional[Dict[str, str]] = None,
        options: Optional[Dict[str, Any]] = None,
        log_response: bool = False,
        **kwargs,
    ) -> None:

        # delegate to BaseOperator, we don't need to do anything else
        super().__init__(**kwargs)

        self.number_of_batches = number_of_batches

        # API endpoint information, we should only be making GET requests from here
        # Header is most likely unnecessary
        self.endpoint = endpoint
        self.method = "GET"
        self.query_builder = query_builder or self._default_query_builder
        self.header = header or {}

        self.http_conn_id = http_conn_id
        self.mongo_conn_id = mongo_conn_id
        self.batch_name = batch_name

        # Functions for operating on response data
        self.parser = parser
        self.response_count = response_count
        self.response_valid = response_valid or self._default_response_valid

        # Options is for Requests library functions
        self.options = options or {}

        self.log_response = log_response

        # these get instantiated on execute
        self.http = None
        self.mongo_conn = None

    # Override the execute method, we want any derived classes to override
    # _execute()
    def execute(self, context: Dict[str, Any]) -> Any:

        self.http = HttpHook(self.method, http_conn_id=self.http_conn_id)
        self.mongo_conn = MongoHook(self.mongo_conn_id)

        # generate query parameters
        self.query = self.query_builder()

        self.log.info(f"Connecting to: {self.http_conn_id}")

        return_val = self._execute(context)

        self._shutdown()

        return return_val

    def _execute(self, context: Dict[str, Any]) -> Any:
        raise NotImplementedError(
            "_execute() needs to be defined for subclasses.")

    def _call_once(self,
                   use_query: bool = False) -> Union[requests.Response, None]:
        """
        Execute a single API call.

        :param query: If use_query is true, we use the internal query string provided in our request.
        :type query: bool (defaults to False)
        """
        response = self.http.run(
            self.endpoint,
            self.query if use_query else {},
            self.header,
            self.options,
        )

        if self.log_response:
            self.log.info(response.url)

        if not self.response_valid(response):
            return None

        return self._to_json(response)

    def _to_json(self, response: requests.Response):
        try:
            return response.json()
        except JSONDecodeError:
            self.log.error(
                f"Failed to convert response to JSON: {response.url}")
            return None

    def _api_id_to_document(self, _id: str, name: str, batch_id: int):
        return {"api_id": str(_id), "batch_id": f"{name}{batch_id}"}

    def _default_query_builder(self) -> dict:
        return {}

    def _default_response_valid(self, response: requests.Response) -> bool:
        """Default response_valid() function. Returns True only on 200."""
        return response.status_code == 200

    def _shutdown(self) -> None:
        """Explicitly close MongoDB connection"""
        if self.mongo_conn:
            self.mongo_conn.close_conn()
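
BaseAPIOperator leaves _execute() to subclasses. A minimal sketch of a subclass that calls the endpoint once, parses the sub-endpoint ids and spreads them across batches in MongoDB; the class name is hypothetical and the collection/database names are borrowed from the extract() example above for illustration:

# Minimal sketch of a BaseAPIOperator subclass: one request, parse the ids,
# store {"api_id", "batch_id"} documents in MongoDB. Names are illustrative.
from typing import Any, Dict

from airflow.exceptions import AirflowException


class SingleCallAPIOperator(BaseAPIOperator):
    def _execute(self, context: Dict[str, Any]) -> int:
        data = self._call_once(use_query=True)
        if data is None:
            raise AirflowException(f"Invalid response from {self.endpoint}.")

        ids = self.parser(data)
        documents = [
            self._api_id_to_document(_id, self.batch_name,
                                     (i % self.number_of_batches) + 1)
            for i, _id in enumerate(ids)
        ]
        if documents:
            self.mongo_conn.insert_many("ids_to_update", documents,
                                        mongo_db="courts")
        return len(documents)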
Example #22
class MongoDatabase:
    """
        Instance of the MongoDB database where all the extracted data will be stored.
        The database URI must be set in the "Connections" of Apache Airflow.
    """

    collection_indexes = {
        'default':
        [('date', DESCENDING),
         ('autonomous_region', ASCENDING)],  # this is the most common index
        'clinic_description': [('type', ASCENDING),
                               ('description', ASCENDING)],
        'population_ar': [('autonomous_region', ASCENDING)],
        'death_causes': [('death_cause', ASCENDING), ('age_range', ASCENDING)],
        'chronic_illnesses': [('illness', ASCENDING)],
        'outbreaks_description': [('date', DESCENDING), ('scope', ASCENDING),
                                  ('subscope', ASCENDING)],
        'top_death_causes': [('death_cause', ASCENDING)]
    }

    extracted_db_name = 'covid_extracted_data'
    analyzed_db_name = 'covid_analyzed_data'

    def __init__(self, database_name):
        """
            Connect to the database.
        """
        self.client = MongoHook(conn_id='mongo_covid').get_conn()
        self.db = self.client.get_database(database_name)

    @staticmethod
    def create_collection_index(collection):
        """Create a custom index for a collection, to improve I/O tasks"""
        if len(list(collection.list_indexes())) < 2:
            # No index created yet, let's create it
            if collection.name in MongoDatabase.collection_indexes:
                index = MongoDatabase.collection_indexes[collection.name]
            else:
                index = MongoDatabase.collection_indexes['default']

            collection.create_index(index)

    def read_data(self, collection_name, filters=None, projection=None):
        """
            Read data from the database and return it as a DataFrame.
            :param collection_name: Name of the collection from which the data will be read
            :param filters: (optional) Dictionary with the query filters.
            :param projection: (optional) List of columns to retrieve.
        """
        collection = self.db.get_collection(collection_name)
        if projection:
            projected_fields = {field: 1 for field in projection}
        else:
            projected_fields = {}

        projected_fields['_id'] = 0

        query = collection.find(filters, projected_fields)
        df = pd.DataFrame(query)

        return df

    def store_data(self, collection_name, data, overwrite=True):
        """
            Store data in the database.
            :param collection_name: Name of the collection in which the data will be stored
            :param data: document or documents to be stored in the collection
            :param overwrite: whether to delete the previous data in the collection before storing the new one
        """

        collection = self.db.get_collection(collection_name)

        if overwrite:
            collection.delete_many({})

        MongoDatabase.create_collection_index(collection)

        if isinstance(data, list):
            # Several documents to be inserted
            collection.insert_many(data)
        elif isinstance(data, dict):
            # One single document to be inserted
            collection.insert_one(data)

    def __del__(self):
        """When the object is destroyed, the connection with the MongoDB server is released"""
        self.client.close()
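
A short usage sketch for the class above: read a projection of one collection from the extracted database and store an analysed result in the analysed database (the filter values are illustrative; the collection names come from collection_indexes):

# Minimal usage sketch for MongoDatabase. Filter values are illustrative.
extracted_db = MongoDatabase(MongoDatabase.extracted_db_name)
deaths_df = extracted_db.read_data(
    'death_causes',
    filters={'age_range': '80+'},
    projection=['death_cause', 'age_range'],
)

analyzed_db = MongoDatabase(MongoDatabase.analyzed_db_name)
analyzed_db.store_data('top_death_causes', deaths_df.to_dict('records'), overwrite=True)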
Example #23
 def __init__(self, database_name):
     """
         Connect to the database.
     """
     self.client = MongoHook(conn_id='mongo_covid').get_conn()
     self.db = self.client.get_database(database_name)
Example #24
 def test_srv(self):
     hook = MongoHook(conn_id='mongo_default_with_srv')
     self.assertTrue(hook.uri.startswith('mongodb+srv://'))
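
The SRV scheme is driven by the connection's extras. A minimal sketch of registering the connection this test expects, assuming the hook reads an "srv" flag from the extra field (the host is illustrative):

# Minimal sketch: a connection whose extras request the mongodb+srv scheme.
# Assumes the hook reads the "srv" flag from extras; the host is illustrative.
db.merge_conn(
    Connection(conn_id='mongo_default_with_srv',
               conn_type='mongo',
               host='cluster-0.example.mongodb.net',
               extra='{"srv": true}'))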
Example #25
    def test_save_to_json_operator_database_error(
        self, mocker, postgresql, ports_collection, test_dag,
        tmp_path: Path
    ):
        """Test if save_to_json_operator can handle errors related to db."""
        # Create mocks
        mocker.patch.object(
            PostgresHook,
            "get_conn",
            return_value=postgresql
        )
        mocker.patch.object(
            MongoHook,
            "get_collection",
            return_value=ports_collection
        )

        # Check if the source table has an item in it
        mongo_hook = MongoHook()
        collection = mongo_hook.get_collection()
        assert collection.count_documents({}) > 0

        # Setup some data, transfer staging data to master
        mongo_staging_config = MongoConfig('mongo_default', 'ports')
        postgres_master_config = PostgresConfig('postgres_default')
        transform_load = TransformAndLoadOperator(
            mongo_config=mongo_staging_config,
            postgres_config=postgres_master_config,
            task_id='test',
            processor=PortsItemProcessor(),
            query=SqlQueries.ports_table_insert,
            query_params={"updated_at": datetime.datetime.utcnow()},
            dag=test_dag
        )

        # Execute task and check if it inserted the data successfully
        transform_load.execute(context={}, testing=True)
        pg_hook = PostgresHook()
        cursor = pg_hook.get_conn().cursor()
        cursor.execute("SELECT COUNT(*) FROM ports;")
        after_result = cursor.fetchone()[0]
        assert after_result > 0

        # Execute save_to_json to save the data into json file on tmp_path
        save_to_json = LoadToJsonOperator(
            task_id='test2',
            postgres_config=postgres_master_config,
            query=SqlQueries.select_all_query_to_json,
            path=tmp_path,
            tables=['foo'],
            dag=test_dag
        )
        with raises((UndefinedTable, OperationalError, Exception)):
            # Set testing = False to implicitly close the database connection
            save_to_json.execute(
                {'execution_date': datetime.datetime(2021, 1, 1)},
                testing=False
            )
            save_to_json.execute(
                {'execution_date': datetime.datetime(2021, 1, 1)},
                testing=True
            )
Example #26
 def get_hook(self):
     if self.conn_type == 'mysql':
         from airflow.providers.mysql.hooks.mysql import MySqlHook
         return MySqlHook(mysql_conn_id=self.conn_id)
     elif self.conn_type == 'google_cloud_platform':
         from airflow.gcp.hooks.bigquery import BigQueryHook
         return BigQueryHook(bigquery_conn_id=self.conn_id)
     elif self.conn_type == 'postgres':
         from airflow.providers.postgres.hooks.postgres import PostgresHook
         return PostgresHook(postgres_conn_id=self.conn_id)
     elif self.conn_type == 'pig_cli':
         from airflow.providers.apache.pig.hooks.pig import PigCliHook
         return PigCliHook(pig_cli_conn_id=self.conn_id)
     elif self.conn_type == 'hive_cli':
         from airflow.providers.apache.hive.hooks.hive import HiveCliHook
         return HiveCliHook(hive_cli_conn_id=self.conn_id)
     elif self.conn_type == 'presto':
         from airflow.providers.presto.hooks.presto import PrestoHook
         return PrestoHook(presto_conn_id=self.conn_id)
     elif self.conn_type == 'hiveserver2':
         from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook
         return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
     elif self.conn_type == 'sqlite':
         from airflow.providers.sqlite.hooks.sqlite import SqliteHook
         return SqliteHook(sqlite_conn_id=self.conn_id)
     elif self.conn_type == 'jdbc':
         from airflow.providers.jdbc.hooks.jdbc import JdbcHook
         return JdbcHook(jdbc_conn_id=self.conn_id)
     elif self.conn_type == 'mssql':
         from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
         return MsSqlHook(mssql_conn_id=self.conn_id)
     elif self.conn_type == 'odbc':
         from airflow.providers.odbc.hooks.odbc import OdbcHook
         return OdbcHook(odbc_conn_id=self.conn_id)
     elif self.conn_type == 'oracle':
         from airflow.providers.oracle.hooks.oracle import OracleHook
         return OracleHook(oracle_conn_id=self.conn_id)
     elif self.conn_type == 'vertica':
         from airflow.providers.vertica.hooks.vertica import VerticaHook
         return VerticaHook(vertica_conn_id=self.conn_id)
     elif self.conn_type == 'cloudant':
         from airflow.providers.cloudant.hooks.cloudant import CloudantHook
         return CloudantHook(cloudant_conn_id=self.conn_id)
     elif self.conn_type == 'jira':
         from airflow.providers.jira.hooks.jira import JiraHook
         return JiraHook(jira_conn_id=self.conn_id)
     elif self.conn_type == 'redis':
         from airflow.providers.redis.hooks.redis import RedisHook
         return RedisHook(redis_conn_id=self.conn_id)
     elif self.conn_type == 'wasb':
         from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
         return WasbHook(wasb_conn_id=self.conn_id)
     elif self.conn_type == 'docker':
         from airflow.providers.docker.hooks.docker import DockerHook
         return DockerHook(docker_conn_id=self.conn_id)
     elif self.conn_type == 'azure_data_lake':
         from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook
         return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
     elif self.conn_type == 'azure_cosmos':
         from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook
         return AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
     elif self.conn_type == 'cassandra':
         from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
         return CassandraHook(cassandra_conn_id=self.conn_id)
     elif self.conn_type == 'mongo':
         from airflow.providers.mongo.hooks.mongo import MongoHook
         return MongoHook(conn_id=self.conn_id)
     elif self.conn_type == 'gcpcloudsql':
         from airflow.gcp.hooks.cloud_sql import CloudSQLDatabaseHook
         return CloudSQLDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
     elif self.conn_type == 'grpc':
         from airflow.providers.grpc.hooks.grpc import GrpcHook
         return GrpcHook(grpc_conn_id=self.conn_id)
     raise AirflowException("Unknown hook type {}".format(self.conn_type))
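
A short usage sketch: a Connection with conn_type 'mongo' hands back a MongoHook bound to the same connection id (the connection values are illustrative):

# Minimal sketch: get_hook() dispatches on conn_type, so a "mongo" connection
# returns a MongoHook bound to the same conn_id. Values are illustrative.
from airflow.models import Connection

conn = Connection(conn_id='mongo_default', conn_type='mongo', host='mongo', port=27017)
hook = conn.get_hook()      # MongoHook(conn_id='mongo_default')
client = hook.get_conn()    # pymongo MongoClient built from this connection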