def execute(self, context):
    """
    Executed by task_instance at runtime.

    Queries MongoDB (``aggregate`` when ``self.is_pipeline`` is truthy,
    ``find`` otherwise), runs the result through ``self.transform`` and
    ``self._stringify``, and uploads the resulting string to S3.

    :return: always True
    """
    s3_hook = S3Hook(self.s3_conn_id)
    mongo_hook = MongoHook(self.mongo_conn_id)

    # A plain query goes through find(); a pipeline goes through aggregate().
    if not self.is_pipeline:
        cursor = mongo_hook.find(
            mongo_collection=self.mongo_collection,
            query=self.mongo_query,
            mongo_db=self.mongo_db,
        )
    else:
        cursor = mongo_hook.aggregate(
            mongo_collection=self.mongo_collection,
            aggregate_query=self.mongo_query,
            mongo_db=self.mongo_db,
        )

    # Transform the documents first, then serialise them into one string.
    transformed = self.transform(cursor)
    payload = self._stringify(transformed)

    s3_hook.load_string(
        string_data=payload,
        key=self.s3_key,
        bucket_name=self.s3_bucket,
        replace=self.replace,
    )
    return True
def poke(self, context):
    """Return True when a document matching ``self.query`` exists in ``self.collection``."""
    self.log.info(
        "Sensor check existence of the document "
        "that matches the following query: %s", self.query)
    mongo_hook = MongoHook(self.mongo_conn_id)
    # find_one=True asks the hook for a single document instead of a cursor.
    document = mongo_hook.find(self.collection, self.query, find_one=True)
    if document is None:
        return False
    return True
def execute(self, context):
    """
    Copy the result of every MySQL query into a MongoDB collection.

    ``self.sql_queries[i]`` is executed with ``self.parameters`` and its rows
    are inserted into the collection named ``self.mongo_collections[i]`` of
    the ``weather`` database. When ``self.mysql_preoperator`` is set it is
    executed once before the queries.

    Queries that return no rows are skipped, because pymongo's
    ``insert_many`` refuses an empty document list.
    """
    logging.info('Executing: %s', self.sql_queries)
    mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    mongo_hook = MongoHook(mongo_conn_id=self.mongo_conn_id)

    logging.info("Transferring MySQL query results into MongoDB database.")
    mysql_conn = mysql_hook.get_conn()
    # DictCursor returns each row as a dict, which can be inserted as-is.
    mysql_conn.cursorclass = MySQLdb.cursors.DictCursor
    cursor = mysql_conn.cursor()
    try:
        mongo_conn = mongo_hook.get_conn()
        mongo_db = mongo_conn.weather

        if self.mysql_preoperator:
            logging.info("Running MySQL preoperator")
            cursor.execute(self.mysql_preoperator)

        for index, sql in enumerate(self.sql_queries):
            cursor.execute(sql, self.parameters)
            fetched_rows = list(cursor.fetchall())
            if not fetched_rows:
                # insert_many([]) raises; an empty result is not an error here.
                logging.info("Query %d returned no rows, nothing to insert", index)
                continue
            mongo_db[self.mongo_collections[index]].insert_many(fetched_rows)
    finally:
        # Release the MySQL cursor even when a query or an insert fails.
        cursor.close()

    logging.info("Transfer Done")
def read_csv_and_dump(path, libname):
    """
    Load a CSV with a two-level index and write it, symbol by symbol, to a library.

    The first index level is treated as the symbol; each symbol's rows are
    written to the library ``libname``, overwriting what is stored under that
    symbol. Nothing is written when the library does not exist.

    :param path: location of the CSV file; columns 0 and 1 form the index
    :param libname: name of the target library
    """
    new_data = pd.read_csv(path, index_col=[0, 1])

    hook = MongoHook(conn_id="arctic_mongo")
    # NOTE(review): presumably this project's hook returns an Arctic-style
    # store (it is asked for libraries below) — confirm against the hook.
    store = hook.get_conn()

    if not store.library_exists(libname):
        return

    lib = store.get_library(libname)
    # The original looped over an undefined name ``data`` and wrote through
    # the store rather than the library it had just fetched.
    for symbol in new_data.index.levels[0]:
        symbol_data = new_data.loc[symbol]
        # Simply override the current data for this symbol.
        lib.write(symbol, data=symbol_data)
def execute(self, context):
    """
    Executed by task_instance at runtime.

    Resolves a target path, either ``self.target_path`` or a value pulled
    from XCom with the keyword arguments in ``self.target_xcom``, and feeds
    it to ``self._update_db``: every ``*.json`` file directly inside it when
    it is a directory, or the path itself when it is a file.

    :return: True after processing; False when the target is neither a
        directory nor a file
    """
    with closing(MongoHook(self.conn_id).get_conn()) as client:
        db = client[self.db]
        collection = db[self.collection]

        # NOTE: Pass only one of target_xcom and target_path
        assert bool(self.target_path) ^ bool(self.target_xcom), \
            "exactly one of target_path and target_xcom must be set"

        if self.target_path:
            # Use path given in task definition
            target = self.target_path
        else:
            # Access from XCOM
            target = context['task_instance'].xcom_pull(**self.target_xcom)

        if os.path.isdir(target):
            # Build the pattern from the target itself. The previous
            # f'/{target}/*.json' forced a leading slash, so a relative
            # target was searched under the filesystem root.
            pattern = os.path.join(target, '*.json')
            for filepath in glob.glob(pattern):
                self._update_db(filepath, collection)
        elif os.path.isfile(target):
            self._update_db(target, collection)
        else:
            # Should Never Exit Here...
            return False
        return True
def test_context_manager(self):
    """The hook holds a client inside the ``with`` block and none after it."""
    hook_kwargs = {'conn_id': 'mongo_default', 'mongo_db': 'default'}
    with MongoHook(**hook_kwargs) as ctx_hook:
        ctx_hook.get_conn()
        self.assertIsInstance(ctx_hook, MongoHook)
        self.assertIsNotNone(ctx_hook.client)

    # Checked after the block has exited.
    self.assertIsNone(ctx_hook.client)
def __init__(self, apk_id, apk_version, runner_conf, target_device=None, *args, **kwargs):
    """
    Store the APK coordinates and open the MongoDB connection.

    :param apk_id: identifier of the application under test
    :param apk_version: version of the application under test
    :param runner_conf: runner configuration, forwarded to the parent class
    :param target_device: optional device serial, stored as ``self.serial``
    """
    parent = super(AndroidRunnerOperator, self)
    parent.__init__(*args, queue='android', runner_conf=runner_conf, **kwargs)

    # Paths start out unset.
    self.apk_path = self.test_apk_path = None
    self.apk_id, self.apk_version = apk_id, apk_version
    self.serial = target_device

    mongo_hook = MongoHook(conn_id='stocksdktest_mongo')
    self.mongo_hk = mongo_hook
    self.conn = mongo_hook.get_conn()
def check_mongo_db(**kwargs):
    """
    Report whether a MongoDB collection contains at least one document.

    :param kwargs: must provide ``mongo_uri`` (passed to ``MongoHook``),
        ``mongo_db`` and ``mongo_collection``
    :return: True when the collection holds documents, False otherwise
    """
    mongo_uri = kwargs.get('mongo_uri')
    mongo_db = kwargs.get('mongo_db')
    mongo_collection = kwargs.get('mongo_collection')

    mongo_conn = MongoHook(mongo_uri).get_conn()
    # Grab collection
    collection = mongo_conn.get_database(mongo_db).get_collection(
        mongo_collection)

    # Cursor.count() is deprecated and no longer exists in PyMongo 4;
    # count_documents is its replacement.
    count = collection.count_documents({})

    if count > 0:
        logging.info('Total in mongo db: {}, coll: {}, count: {}'.format(
            mongo_db, mongo_collection, count))
        return True

    logging.info('No data found in mongo db: {}, coll: {}'.format(
        mongo_db, mongo_collection))
    return False
def poke(self, context):
    """
    Check the stored metadata of ``self.symbol`` in library ``self.libname``.

    :return: the result of ``self.python_call_back(self.meta, <stored metadata>)``
        when both the library and the symbol exist; False when either is
        missing or an ``OSError`` is raised while looking them up
    """
    mongo_hook = MongoHook(self.mongo_conn_id, libname=self.libname)
    store = Arctic(mongo_hook.get_conn())
    self.log.info(
        f'Poking for {self.mongo_conn_id}, {self.libname}: {self.symbol}')

    try:
        if not store.library_exists(self.libname):
            return False
        library = store.get_library(self.libname)
        if not library.has_symbol(self.symbol):
            return False
        stored_meta = library.read_meta(self.symbol).metadata
        return self.python_call_back(self.meta, stored_meta)
    except OSError:
        return False
def setUp(self):
    """Register the test connection, seed one document and build the sensor."""
    test_conn = Connection(
        conn_id='mongo_test',
        conn_type='mongo',
        host='mongo',
        port='27017',
        schema='test',
    )
    db.merge_conn(test_conn)

    default_args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
    self.dag = DAG('test_dag_id', default_args=default_args)

    # Seed the document the sensor is expected to find.
    MongoHook('mongo_test').insert_one('foo', {'bar': 'baz'})

    self.sensor = MongoSensor(
        task_id='test_task',
        mongo_conn_id='mongo_test',
        dag=self.dag,
        collection='foo',
        query={'bar': 'baz'},
    )
def content_neo4j_node_creation(**kwargs):
    """
    Look up a Swift object's metadata document and insert it into Neo4j.

    The document is read from the ``swift`` MongoDB database, in the
    collection named by the DAG run's ``swift_container``, by matching
    ``swift_object_id`` against the run's ``swift_id`` (as a string).

    :param kwargs: task context; ``kwargs["dag_run"].conf`` must contain
        ``swift_container`` and ``swift_id``
    :return: None
    """
    from lib.neo4jintegrator import Neo4jIntegrator

    neo4j_host = globals()["GOLD_NEO4J_IP"]
    neo4j_port = globals()["NEO4J_PORT"]
    uri = "bolt://" + neo4j_host + ":" + neo4j_port
    neo4j_user = "******"
    neo4j_pass = "******"
    driver = Neo4jIntegrator(uri, neo4j_user, neo4j_pass)

    meta_base = MongoHook(globals()["MONGO_META_CONN_ID"])
    run_conf = kwargs["dag_run"].conf
    coll = run_conf["swift_container"]
    swift_id = str(run_conf["swift_id"])

    collection = meta_base.get_conn().swift.get_collection(coll)
    doc = collection.find_one({"swift_object_id": swift_id})
    driver.insert_image(doc)
def execute(self, context):
    """
    Insert one sample blog post into the ``posts`` collection.

    The database name is taken from the last path segment of the hook's URI,
    and the URI is trimmed to the part before it.

    :return: the inserted document's id, as a string
    """
    import datetime

    mongo = MongoHook(conn_id=self.conn_id, )
    # Split "<server uri>/<database>" into its two halves.
    server_uri, dbname = mongo.uri.rsplit("/", maxsplit=1)
    mongo.uri = server_uri

    posts = mongo.get_collection("posts", dbname)
    post = {
        "author": "Mike",
        "text": "My first blog post!",
        "tags": ["mongodb", "python", "pymongo"],
        "date": datetime.datetime.utcnow()
    }
    insert_result = posts.insert_one(post)
    return str(insert_result.inserted_id)
def _get_mongo_doc(self):
    """
    Run the configured MongoDB query and stage the result in a temporary file.

    The documents (from ``aggregate`` when ``self.is_pipeline`` is truthy,
    from ``find`` otherwise) are passed through ``self.transform`` and
    ``self._stringify``. Under Python 3 every ``$`` character is removed from
    the resulting string and it is encoded as UTF-8 before being written.

    :return: dict mapping ``self.filename`` to the open temporary file handle;
        the file is deleted when that handle is closed
    """
    mongo_conn = MongoHook(self.mongo_conn_id).get_conn()
    collection = mongo_conn.get_database(self.mongo_db).get_collection(self.mongo_collection)

    if self.is_pipeline:
        results = collection.aggregate(self.mongo_query)
    else:
        results = collection.find(self.mongo_query)

    docs_str = self._stringify(self.transform(results))
    print(docs_str)

    tmp_file_handle = NamedTemporaryFile(delete=True)
    if PY3:
        # The temporary file is opened in binary mode, so Python 3 needs bytes.
        docs_str = docs_str.replace("$", '').encode('utf-8')
    tmp_file_handle.write(docs_str)
    # Push the buffered bytes to disk so a consumer that opens the file by
    # name sees the complete content.
    tmp_file_handle.flush()

    return {self.filename: tmp_file_handle}
def get_mongodb():
    """
    Read ``order_item``, insert the mapped documents back, then print the
    mapped content of the collection as read a second time.
    """
    collection_name = 'order_item'
    mongodb_test = MongoHook(conn_id='mongodb_id')

    mapped_docs = [map_data(doc) for doc in mongodb_test.find(collection_name, {})]
    mongodb_test.insert_many(collection_name, mapped_docs)

    # Second pass: read the collection again, now including the inserted docs.
    mapped_after_insert = [
        map_data(doc) for doc in mongodb_test.find(collection_name, {})
    ]
    print(mapped_after_insert)
def setUp(self):
    """Load the test configuration, seed MongoDB and create the sensor under test."""
    configuration.load_test_config()

    conn_id = 'mongo_test'
    db.merge_conn(Connection(conn_id=conn_id, conn_type='mongo',
                             host='mongo', port='27017', schema='test'))

    self.dag = DAG(
        'test_dag_id',
        default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
    )

    # The document the sensor's query is meant to match.
    seed_hook = MongoHook(conn_id)
    seed_hook.insert_one('foo', {'bar': 'baz'})

    sensor_kwargs = dict(
        task_id='test_task',
        mongo_conn_id=conn_id,
        dag=self.dag,
        collection='foo',
        query={'bar': 'baz'},
    )
    self.sensor = MongoSensor(**sensor_kwargs)
class AndroidRunnerOperator(StockOperator):
    """
    Operator that runs the instrumented test APK on an Android device through
    adb, collects the execution records emitted on logcat, stores them in
    MongoDB and publishes them through XCom.
    """

    @apply_defaults
    def __init__(self,
                 apk_id,
                 apk_version,
                 runner_conf,
                 target_device=None,
                 release_xcom_key="android_release",
                 *args,
                 **kwargs):
        """
        :param apk_id: application id of the APK under test
        :param apk_version: version that must be installed before the tests run
        :param runner_conf: runner configuration; serialized and handed to the
            instrumentation in ``execute``
        :param target_device: optional device serial; a local device is scanned
            for in ``pre_execute`` when omitted
        :param release_xcom_key: XCom key under which the release files are published
        """
        super(AndroidRunnerOperator, self).__init__(queue='android',
                                                    runner_conf=runner_conf,
                                                    *args,
                                                    **kwargs)
        self.apk_id = apk_id
        self.apk_version = apk_version
        self.apk_path = None
        self.test_apk_path = None
        self.serial = target_device
        self.release_xcom_key = release_xcom_key
        self.mongo_hk = MongoHook(conn_id='stocksdktest_mongo')
        self.conn = self.mongo_hk.get_conn()

    def install_apk(self, apk_files):
        """
        Download every release file and install it on the device.

        :param apk_files:
        :type apk_files: list(operators.release_ci_operator.ReleaseFile)
        :raises AirflowException: when ``adb install`` returns a non-zero code
        """
        for apk in apk_files:
            path = '/tmp/%s/%s' % (apk.md5sum, apk.name)
            download_file(url=apk.url, file_path=path, md5=apk.md5sum)
            install_code = exec_adb_cmd(['adb', 'install', '-r', '-t', path],
                                        serial=self.serial)
            if install_code != 0:
                raise AirflowException('Install apk from %s failed' % apk)

    def pre_execute(self, context):
        """
        Make sure a device is connected and runs the expected APK version.

        When the installed version differs from ``self.apk_version`` the
        previous APK and its ``.test`` companion are uninstalled (if present)
        and the release files pulled from XCom are installed.

        :raises AirflowException: when adb cannot start, no device can be
            found or connected, an uninstall fails, or no release list is
            available in XCom
        """
        super(AndroidRunnerOperator, self).pre_execute(context)

        # TODO: TO Debug, so annotate it and return
        # it seems 2 apks have been installed and com.chi.ssetest too
        if not start_adb_server():
            raise AirflowException('ADB Server can not start')

        if not self.serial:
            self.serial = scan_local_device()
            if not self.serial:
                raise AirflowException('can not scan device')

        if not connect_to_device(self.serial):
            print("serial", self.serial)
            raise AirflowException('can not connect to device "%s"' %
                                   self.serial)

        main_apk_version = get_app_version(self.serial, self.apk_id)
        print('Verify App(%s) version: %s, cur is %s' %
              (self.apk_id, self.apk_version, main_apk_version))
        if self.apk_version == main_apk_version:
            # The expected version is already installed.
            return

        if main_apk_version is not None:
            # uninstall previous apk
            if exec_adb_cmd(['adb', 'uninstall', self.apk_id], serial=self.serial) != 0 or \
                    exec_adb_cmd(['adb', 'uninstall', '%s.test' % self.apk_id], serial=self.serial) != 0:
                raise AirflowException('Uninstall previous apk error')

        release_files = self.xcom_pull(context, key=self.release_xcom_key)
        print('release: %s' % release_files)
        if release_files is None or not isinstance(release_files, list):
            # The message used to contain a bare '%s' that was never filled in.
            raise AirflowException(
                'Can not get Android release assets: %s' % (release_files,))
        self.install_apk(release_files)

    @staticmethod
    def protobuf_record_to_dict(record):
        """
        Convert a ``TestExecutionRecord`` into a plain dict.

        :return: the dict, or None (after writing a message to stderr) when
            ``record`` is None or not a ``TestExecutionRecord``
        """
        # sys.stderr is a stream, not a callable: it has to be written to.
        if record is None:
            sys.stderr.write(
                'TextExecutionRecordtoDict Type Error, param is NoneType\n')
            return None
        if not isinstance(record, TestExecutionRecord):
            sys.stderr.write(
                'TextExecutionRecordtoDict Type Error, param is not TestExecutionRecord\n'
            )
            return None
        return {
            'jobID': record.jobID,
            'runnerID': record.runnerID,
            'testcaseID': record.testcaseID,
            'recordID': record.recordID,
            'isPass': record.isPass,
            'startTime': record.startTime,
            'paramData': bytes_to_dict(record.paramData),
            'resultData': bytes_to_dict(record.resultData),
            'exceptionData': bytes_to_dict(record.exceptionData),
        }

    def pre_process_dot(self, record_dict_list):
        """
        Replace ``.`` and ``$`` by ``_`` in the top-level keys of every
        record's ``resultData``, in place.

        :return: the same list that was passed in
        """
        for record_dict in record_dict_list:
            result_data = record_dict['resultData']
            if result_data is None:
                continue
            # Snapshot the keys: the dict is modified while renaming.
            for old_key in list(result_data.keys()):
                new_key = old_key.replace('.', '_').replace('$', '_')
                print('old_key', old_key)
                print('new_key', new_key)
                result_data[new_key] = result_data.pop(old_key)
        return record_dict_list

    def execute(self, context):
        """
        Run the instrumentation on the device and persist the collected records.

        Records parsed from logcat are gathered, the instrumentation script is
        pushed and executed, and the run fails unless every reported
        ``INSTRUMENTATION_STATUS_CODE`` is ``0`` or ``1``. The records are then
        inserted into the ``stockSdkTest`` database and pushed to XCom under
        the runner id; the runner id itself is pushed under the task id.

        :raises AirflowException: when pushing or running the script fails or
            the status codes indicate a failed run
        """
        record_dict_list = list()
        chunk_cache = LogChunkCache()

        def read_record(record_str):
            # Called for every logcat line; keeps the records that parse.
            record = TestExecutionRecord()
            data = parse_logcat(chunk_cache, record_str)
            if data:
                record.ParseFromString(data)
                if len(record.ListFields()) > 0:
                    print("*************************")
                    print(record)
                    record_dict_list.append(
                        AndroidRunnerOperator.protobuf_record_to_dict(record))
                    print("*************************")

        spawn_logcat(serial=self.serial, logger=read_record)

        test_status_code = []

        def check_test_result(line):
            if 'INSTRUMENTATION_STATUS_CODE:' in line:
                # find number in string, https://stackoverflow.com/a/29581287/9797889
                codes = re.findall(
                    r"[-+]?[.]?[\d]+(?:,\d\d\d)*[\.]?\d*(?:[eE][-+]?\d+)?",
                    line)
                # check whether code ONLY contains '0' or '1'
                test_status_code.extend(codes)

        command_to_script(args=[
            'am', 'instrument', '-w', '-r', '-e', 'debug', 'false', '-e',
            'filter', 'com.chi.ssetest.TestcaseFilter', '-e', 'listener',
            'com.chi.ssetest.TestcaseExecutionListener', '-e',
            'collector_file', 'test.log', '-e', 'runner_config',
            base64_encode(self.runner_conf.SerializeToString()),
            'com.chi.ssetest.test/android.support.test.runner.AndroidJUnitRunner'
        ],
                          script_path='/tmp/test.sh')

        cmd_code_push = exec_adb_cmd(
            args=['adb', 'push', '/tmp/test.sh', '/data/local/tmp/'],
            serial=self.serial)
        # Clear logcat before the run; its return code is not checked.
        exec_adb_cmd(args=['adb', 'logcat', '-c'], serial=self.serial)
        cmd_code_exec = exec_adb_cmd(
            args=['adb', 'shell', 'sh', '/data/local/tmp/test.sh'],
            serial=self.serial,
            logger=check_test_result)

        accepted_codes = test_status_code.count('0') + test_status_code.count('1')
        if cmd_code_push != 0 or cmd_code_exec != 0 or len(test_status_code) == 0 or \
                accepted_codes < len(test_status_code):
            raise AirflowException('Android Test Failed')

        client = self.mongo_hk.client
        db = client["stockSdkTest"]
        col = db[self.task_id + str(datetime.date.today())]

        print('Debug Airflow: dict_list:---------------')
        record_dict_list = self.pre_process_dot(record_dict_list)
        print(record_dict_list)
        try:
            col.insert_many(record_dict_list)
        except TypeError as s:
            # Raised, among others, when there are no records to insert.
            print(s)
            # Push the message, not the exception object itself.
            self.xcom_push(context, key=self.task_id, value=str(s))
        finally:
            self.xcom_push(context,
                           key=self.task_id,
                           value=self.runner_conf.runnerID)
            self.xcom_push(context,
                           key=self.runner_conf.runnerID,
                           value=record_dict_list)
'B': dictObj2 } reorderedDictWithReorderedListsInValue = { 'B': dictObj2, 'A': [{ 'Y': 2 }, { 'X': [reorderedDictObj, dictObj2] }] } a = {"L": "M", "N": dictWithListsInValue} b = {"L": "M", "N": reorderedDictWithReorderedListsInValue} # return j1, j2 # return j3,j4 # return a,b r1 = {"1": j1, "2": j3, "3": a} r2 = {"1": j2, "2": j4, "3": b} return r1, r2 if __name__ == '__main__': mongo_hk = MongoHook() mongo_hk.uri = 'mongodb://localhost:27017/' r1, r2 = genTwoCase() a = DataCompareOperator(runner_conf='1', task_id='11', task_id_list=['a', 'b']) res = a.record_compare(r1, r2) print(res) # myclient = pymongo.MongoClient("mongodb://localhost:27017/")
def get_mongodb_connection():
    """Return a ``MongoHook`` bound to the ``playrecipe_mongo`` connection id."""
    return MongoHook(conn_id='playrecipe_mongo')
class DataCompareOperator(BaseOperator):
    """
    Operator that compares the test records two runner tasks stored in
    MongoDB and saves a per-testcase consistency report.
    """

    @apply_defaults
    def __init__(self, runner_conf, task_id_list, *args, **kwargs):
        """
        :param runner_conf: runner configuration, stored as-is
        :param task_id_list: ids of the two tasks whose records are compared;
            only the first two entries are used
        """
        super().__init__(queue='worker', *args, **kwargs)
        self.runner_conf = runner_conf
        self.task_id_list = task_id_list
        self.mongo_hk = MongoHook(conn_id='stocksdktest_mongo')
        self.conn = self.mongo_hk.get_conn()

    def close_connection(self):
        """Close the MongoDB connection held by the hook."""
        self.mongo_hk.close_conn()

    def get_ios_data(self):
        """Placeholder; always returns 0."""
        return 0

    def get_android_data(self):
        """Placeholder; always returns 0."""
        return 0

    def ordered(self, obj):
        """
        Return a recursively sorted form of ``obj`` so that structures that
        differ only in ordering compare equal.

        Dicts become sorted lists of ``(key, ordered(value))`` pairs and lists
        become sorted lists. Raises ``TypeError`` when the elements cannot be
        ordered against each other.
        """
        if isinstance(obj, dict):
            return sorted((k, self.ordered(v)) for k, v in obj.items())
        if isinstance(obj, list):
            return sorted(self.ordered(x) for x in obj)
        return obj

    def get_value_from_path(self, record, path):
        """
        Resolve a ``/``-separated JSON-pointer style ``path`` inside ``record``.

        :param record: nested dict/list structure
        :param path: e.g. ``/a/0/b``; the empty string addresses ``record`` itself
        :raises TypeError: when a path segment cannot be applied to the
            current node
        """
        if path == '':
            return record

        node = record
        for raw_key in path.lstrip('/').split('/'):
            # Undo the JSON-pointer escapes ("~1" is "/", "~0" is "~").
            key = raw_key.replace('~1', '/').replace('~0', '~')
            # A decimal segment indexes a list. The key is parsed, never
            # evaluated.
            if isinstance(node, list) and key.isdecimal():
                node = node[int(key)]
            elif isinstance(node, dict):
                node = node[key]
            else:
                raise TypeError("Error in json path")
        return node

    def record_compare(self, record1, record2):
        """
        Return the comparison of two records.

        :return: dict with ``"Consistency Result"`` (bool) and
            ``"More Infomations"`` (list describing the differences; empty
            when the records are consistent)
        """
        res = (record1 == record2)
        resInfo = []
        if res:
            print("Easy Json Dict , PASS")
        else:
            # If lists are nested, their order must be ignored: sort top-down.
            try:
                res = self.ordered(record1) == self.ordered(record2)
                print("Easy Json Dict With List")
            except TypeError:
                # The content cannot be sorted; compare element by element.
                res = self.my_obj_cmp(record1, record2)
                print("Hard Json Dict With List")
            if not res:
                resInfo = self._describe_differences(record1, record2)
        return {"Consistency Result": res, "More Infomations": resInfo}

    def _describe_differences(self, record1, record2):
        """List the differences that json_patch finds between two records."""
        patches = jsonpatch.make_patch(record1, record2).patch
        details = []
        try:
            for item in patches:
                op = item['op']
                if op == 'replace':
                    src_a = self.get_value_from_path(record1, item['path'])
                    src_b = item['value']
                    # TODO: numeric strings are still compared literally.
                    if src_a != src_b:
                        details.append({
                            'type': 'Data Inconsistency',
                            'location': item['path'],
                            'src_a': src_a,
                            'src_b': src_b
                        })
                elif op == 'add' or op == 'remove':
                    src_a = "not exist in src_a"
                    src_b = "nut exist in src_b"
                    if op == 'add':
                        src_b = item['value']
                    else:
                        src_a = self.get_value_from_path(
                            record1, item['path'])
                    details.append({
                        'type': 'Data Amount Inconsistency',
                        'location': item['path'],
                        'src_a': src_a,
                        'src_b': src_b
                    })
                elif op == 'move' or op == 'copy':
                    # move equals remove and add;
                    # copy equals adding the value found at "from" to "path".
                    src_a = self.get_value_from_path(record1, item['from'])
                    src_b = "nut exist in src_b"
                    details.append({
                        'type': 'Data Amount Inconsistency',
                        'location': item['from'],
                        'src_a': src_a,
                        'src_b': src_b
                    })
                elif op == 'test':
                    print("-----------There is a option Test TODO" + op)
        except TypeError:
            # A path could not be resolved: fall back to the raw patch list.
            return patches
        return details

    def my_list_cmp(self, list1, list2):
        """
        Compare two lists ignoring element order.

        Every element of ``list1`` must be matched by a distinct element of
        ``list2``, so repeated elements are counted.
        """
        if len(list1) != len(list2):
            return False
        unmatched = list(list2)
        for left in list1:
            for index, right in enumerate(unmatched):
                if self.my_obj_cmp(left, right):
                    # Consume the match so it cannot be paired twice.
                    del unmatched[index]
                    break
            else:
                return False
        return True

    def my_obj_cmp(self, obj1, obj2):
        """
        Compare two nested structures, ignoring the order of list elements.

        Lists are compared with ``my_list_cmp``; dicts must have the same key
        set and equal values under every key; anything else is compared
        with ``==``.
        """
        if isinstance(obj1, list):
            # obj1 is a list: obj2 has to be one too.
            if not isinstance(obj2, list):
                return False
            return self.my_list_cmp(obj1, obj2)
        if isinstance(obj1, dict):
            # obj1 is a dict: obj2 has to be one too, with the same keys;
            # the values are then compared recursively.
            if not isinstance(obj2, dict):
                return False
            if set(obj1.keys()) != set(obj2.keys()):
                return False
            return all(self.my_obj_cmp(obj1[k], obj2[k]) for k in obj1)
        return obj1 == obj2

    def execute(self, context):
        """
        Pair the records of the two tasks and store the comparison result.

        Records are paired when they share ``testcaseID`` and ``paramData``
        and carry the runner ids pulled from XCom. For each pair either the
        exception data or the outcome of ``record_compare`` is recorded. The
        result is inserted into the ``test_result<date>`` collection and
        pushed to XCom under the same name.
        """
        myclient = self.mongo_hk.client
        mydb = myclient["stockSdkTest"]
        # Compute the date suffix once so every collection name agrees.
        today = str(datetime.date.today())
        col1 = mydb[self.task_id_list[0] + today]
        col2 = mydb[self.task_id_list[1] + today]

        id1 = self.xcom_pull(context, key=self.task_id_list[0])
        id2 = self.xcom_pull(context, key=self.task_id_list[1])
        print('xcom_pull', id1)
        print('xcom_pull', id2)

        result = {}
        # TODO: Use Mongo To Selection, Maybe every DAG with a collection could be better?
        # Read the second collection once instead of once per record of the first.
        records2 = list(col2.find())
        for x in col1.find():
            for y in records2:
                if x['paramData'] is not None and y['paramData'] is not None \
                        and x['testcaseID'] == y['testcaseID'] \
                        and x['runnerID'] == id1 and y['runnerID'] == id2 \
                        and x['paramData'] == y['paramData']:
                    testcaseID = x['testcaseID']
                    print(x)
                    print(y)
                    r1 = x['resultData']
                    r2 = y['resultData']
                    print(r1)
                    print(r2)
                    resDBItem = {
                        'Result_1': r1,
                        'RunnerID_1': id1,
                        'JobID_1': x['jobID'],
                        'Result_2': r2,
                        'RunnerID_2': id2,
                        'JobID_2': y['jobID'],
                    }
                    if x['exceptionData'] is not None or y['exceptionData'] is not None:
                        print('Exception Explode in ' + testcaseID)
                        resDBItem['Exception_Data_1'] = x['exceptionData']
                        resDBItem['Exception_Data_2'] = y['exceptionData']
                    else:
                        res = self.record_compare(r1, r2)
                        print(res)
                        resDBItem.update(res)
                    result.setdefault(testcaseID, []).append(resDBItem)
        print(result)

        result_name = 'test_result' + today
        col_res = mydb[result_name]
        try:
            col_res.insert_one(result)
        except TypeError as e:
            print(e)
        finally:
            self.xcom_push(context, key=result_name, value=result)
        print("over")
def __init__(self, runner_conf, task_id_list, *args, **kwargs):
    """
    Store the comparison settings and open the MongoDB connection.

    :param runner_conf: runner configuration, stored as-is
    :param task_id_list: task ids, stored as-is
    """
    super().__init__(*args, queue='worker', **kwargs)
    self.runner_conf, self.task_id_list = runner_conf, task_id_list

    mongo_hook = MongoHook(conn_id='stocksdktest_mongo')
    self.mongo_hk = mongo_hook
    self.conn = mongo_hook.get_conn()
import os import logging import fnmatch import json import glob from pymongo.errors import BulkWriteError from typing import List from airflow.contrib.hooks.mongo_hook import MongoHook from common.utils import * from elasticsearch import Elasticsearch from elasticsearch.helpers import bulk # Setting up boto3 hook to AWS S3 s3_hook = S3Hook('my_conn_S3') # Setting up MongoDB hook to mlab server mongodb_hook = MongoHook('mongo_default') ftp_conn_id = "pubmed_ftp" s3bucket = 'case_reports' mongo_folder = 'casereports' def extract_pubmed_data() -> None: """Extracts case-reports from pubmed data and stores result on S3 """ pattern = "*.xml.tar.gz" ftp_path = '/pub/pmc/oa_bulk' root_dir = '/usr/local/airflow' temp_dir = os.path.join(root_dir, 'temp') bucket_name = 'supreme-acrobat-data' prefix = s3bucket + '/pubmed/original'
def get_mongo_hook():
    """Return a ``MongoHook`` built with its default arguments."""
    default_hook = MongoHook()
    return default_hook
def test_srv(self):
    """The URI of the SRV test connection uses the ``mongodb+srv://`` scheme."""
    srv_hook = MongoHook(conn_id='mongo_default_with_srv')
    uri = srv_hook.uri
    self.assertTrue(uri.startswith('mongodb+srv://'))
def get_hook(self):
    """
    Return the hook that matches ``self.conn_type``, bound to ``self.conn_id``.

    The hook module is imported only when its connection type is requested.

    :raises AirflowException: when ``self.conn_type`` is not a known type
    """
    import importlib

    # conn_type -> (module path, hook class, keyword that takes the conn id)
    hook_specs = {
        'mysql': ('airflow.hooks.mysql_hook', 'MySqlHook', 'mysql_conn_id'),
        'google_cloud_platform': ('airflow.gcp.hooks.bigquery', 'BigQueryHook', 'bigquery_conn_id'),
        'postgres': ('airflow.hooks.postgres_hook', 'PostgresHook', 'postgres_conn_id'),
        'pig_cli': ('airflow.hooks.pig_hook', 'PigCliHook', 'pig_cli_conn_id'),
        'hive_cli': ('airflow.hooks.hive_hooks', 'HiveCliHook', 'hive_cli_conn_id'),
        'presto': ('airflow.hooks.presto_hook', 'PrestoHook', 'presto_conn_id'),
        'hiveserver2': ('airflow.hooks.hive_hooks', 'HiveServer2Hook', 'hiveserver2_conn_id'),
        'sqlite': ('airflow.hooks.sqlite_hook', 'SqliteHook', 'sqlite_conn_id'),
        'jdbc': ('airflow.hooks.jdbc_hook', 'JdbcHook', 'jdbc_conn_id'),
        'mssql': ('airflow.hooks.mssql_hook', 'MsSqlHook', 'mssql_conn_id'),
        'oracle': ('airflow.hooks.oracle_hook', 'OracleHook', 'oracle_conn_id'),
        'vertica': ('airflow.contrib.hooks.vertica_hook', 'VerticaHook', 'vertica_conn_id'),
        'cloudant': ('airflow.contrib.hooks.cloudant_hook', 'CloudantHook', 'cloudant_conn_id'),
        'jira': ('airflow.contrib.hooks.jira_hook', 'JiraHook', 'jira_conn_id'),
        'redis': ('airflow.contrib.hooks.redis_hook', 'RedisHook', 'redis_conn_id'),
        'wasb': ('airflow.contrib.hooks.wasb_hook', 'WasbHook', 'wasb_conn_id'),
        'docker': ('airflow.hooks.docker_hook', 'DockerHook', 'docker_conn_id'),
        'azure_data_lake': ('airflow.contrib.hooks.azure_data_lake_hook', 'AzureDataLakeHook',
                            'azure_data_lake_conn_id'),
        'azure_cosmos': ('airflow.contrib.hooks.azure_cosmos_hook', 'AzureCosmosDBHook',
                         'azure_cosmos_conn_id'),
        'cassandra': ('airflow.contrib.hooks.cassandra_hook', 'CassandraHook', 'cassandra_conn_id'),
        'mongo': ('airflow.contrib.hooks.mongo_hook', 'MongoHook', 'conn_id'),
        'gcpcloudsql': ('airflow.gcp.hooks.cloud_sql', 'CloudSqlDatabaseHook', 'gcp_cloudsql_conn_id'),
        'grpc': ('airflow.contrib.hooks.grpc_hook', 'GrpcHook', 'grpc_conn_id'),
    }

    spec = hook_specs.get(self.conn_type)
    if spec is None:
        raise AirflowException("Unknown hook type {}".format(self.conn_type))

    module_path, class_name, conn_id_kwarg = spec
    hook_class = getattr(importlib.import_module(module_path), class_name)
    return hook_class(**{conn_id_kwarg: self.conn_id})
def poke(self, context):
    """Succeed as soon as one document in ``self.collection`` matches ``self.query``."""
    log_template = ("Sensor check existence of the document "
                    "that matches the following query: %s")
    self.log.info(log_template, self.query)

    # Only a single match is requested; None means nothing matched.
    found = MongoHook(self.mongo_conn_id).find(
        self.collection, self.query, find_one=True)
    return found is not None