def doTestMysqlHook(*args, **kwargs):
    sql_hook = MySqlHook(mysql_conn_id="mysql_operator_test_connid")
    sql = "select * from manzeng_predict_src_table;"
    result = sql_hook.get_records(sql)
    for row in result:
        print(row)

    sql = "select max(id) as max_id from manzeng_predict_src_table"
    result = sql_hook.get_records(sql)
    print('maxid:' + str(result[0][0]))
    result = sql_hook.get_first(sql)
    print('maxid:' + str(result[0]))

    # LoggingMixin.log is a property, so it has to be accessed on an instance
    LoggingMixin().log.exception("exception raise test")
    sql_hook.run(
        """insert into manzeng_result_v3(consignor_phone, prediction)
           values('122', '33')"""
    )
def get_signers(instance_id, context):
    db = MySqlHook(mysql_conn_id='mariadb', schema="dapp")
    sql = """
        select instance_id, sign_area_id, sequence, sub_instance_id, sign_section, sign_position,
               sign_action, is_executed, group_culture, group_id, group_name,
               created_date, received_date, approved_date
        from signers
        where instance_id = %s
    """
    tasks = {}
    rows = db.get_records(sql, parameters=[instance_id])
    tasks[SIGNERS] = []
    for row in rows:
        model = {
            'instance_id': row[0],
            'sign_area_id': row[1],
            'sequence': row[2],
            'sub_instance_id': row[3],
            'sign_section': row[4],
            'sign_position': row[5],
            'sign_action': row[6],
            'is_executed': row[7],
            'group_culture': row[8],
            'group_id': row[9],
            'group_name': row[10],
            'created_date': str(row[11]),
            'received_date': str(row[12]),
            'approved_date': str(row[13])
        }
        tasks[SIGNERS].append(model)

    context['ti'].xcom_push(key=SIGNERS, value=tasks[SIGNERS])
    return list(tasks.values())
def get_sign_activity(instance_id, context):
    db = MySqlHook(mysql_conn_id='mariadb', schema="dapp")
    sql = """
        select instance_id, sign_area_id, sequence, sign_area, sign_section, sign_position,
               sign_action, is_comment, is_executed, group_id, user_id, host_address
        from sign_activity
        where instance_id = %s
    """
    tasks = {}
    tasks[SIGN_ACTIVITY] = []
    rows = db.get_records(sql, parameters=[instance_id])
    for row in rows:
        model = {
            'instance_id': row[0],
            'sign_area_id': row[1],
            'sequence': row[2],
            'sign_area': row[3],
            'sign_section': row[4],
            'sign_position': row[5],
            'sign_action': row[6],
            'is_comment': row[7],
            'is_executed': row[8],
            'group_id': row[9],
            'user_id': row[10],
            'host_address': row[11]
        }
        tasks[SIGN_ACTIVITY].append(model)

    context['ti'].xcom_push(key=SIGN_ACTIVITY, value=tasks[SIGN_ACTIVITY])
    return list(tasks.values())
def objects(self):
    where_clause = ''
    if DB_WHITELIST:
        dbs = ",".join(["'" + db + "'" for db in DB_WHITELIST])
        where_clause = "AND b.name IN ({})".format(dbs)
    if DB_BLACKLIST:
        dbs = ",".join(["'" + db + "'" for db in DB_BLACKLIST])
        where_clause = "AND b.name NOT IN ({})".format(dbs)
    sql = """
        SELECT CONCAT(b.NAME, '.', a.TBL_NAME), TBL_TYPE
        FROM TBLS a
        JOIN DBS b ON a.DB_ID = b.DB_ID
        WHERE
            a.TBL_NAME NOT LIKE '%tmp%' AND
            a.TBL_NAME NOT LIKE '%temp%' AND
            b.NAME NOT LIKE '%tmp%' AND
            b.NAME NOT LIKE '%temp%'
        {where_clause}
        LIMIT {LIMIT};
    """.format(where_clause=where_clause, LIMIT=TABLE_SELECTOR_LIMIT)
    h = MySqlHook(METASTORE_MYSQL_CONN_ID)
    d = [{'id': row[0], 'text': row[0]} for row in h.get_records(sql)]
    return json.dumps(d)
def execute(self, context):
    log.info("이찬호")
    db = MySqlHook(mysql_conn_id=self.db_conn_id, schema=self.db_schema)
    sql = """
        select o.id, o.name, o.desc
        from test o
    """

    # initialize the task list buckets
    tasks = {}
    index = 0
    rows = db.get_records(sql)
    for row in rows:
        index += 1
        tasks[f'get_workflow_{index}'] = []

    # distribute the rows across the buckets; shift the modulo into the
    # 1..index range that the bucket keys above use, so bucket 0 never occurs
    resultCounter = 0
    for row in rows:
        resultCounter += 1
        bucket = ((resultCounter - 1) % index) + 1
        model = {'id': str(row[0]), 'name': str(row[1])}
        tasks[f'get_workflow_{bucket}'].append(model)

    # Push the order lists into xcom
    for task in tasks:
        if len(tasks[task]) > 0:
            logging.info(f'Task {task} has {len(tasks[task])} orders.')
            context['ti'].xcom_push(key=task, value=tasks[task])

    return list(tasks.values())
class MysqlQueryOperatorWithTemplatedParams(BaseOperator):

    template_fields = ('sql', 'parameters')
    template_ext = ('.sql',)
    ui_color = '#ededed'

    @apply_defaults
    def __init__(self, sql, mysql_conn_id='mysql_dwh', autocommit=False,
                 parameters=None, *args, **kwargs):
        super(MysqlQueryOperatorWithTemplatedParams, self).__init__(*args, **kwargs)
        self.sql = sql
        self.mysql_conn_id = mysql_conn_id
        self.autocommit = autocommit
        self.parameters = parameters

    def execute(self, context):
        logging.info('Executing: %s with parameters %s', self.sql, self.parameters)
        self.hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
        data = self.hook.get_records(self.sql, parameters=self.parameters)
        # logging.info('Executed: ' + str(x))
        return data
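# A minimal usage sketch for the operator above, assuming an Airflow 1.10-style install.
# The DAG id, table, column names and the templated `parameters` values are hypothetical;
# only the operator itself and its 'mysql_dwh' default connection come from the class above.
from datetime import datetime

from airflow import DAG

with DAG(dag_id="example_templated_mysql_query",
         start_date=datetime(2021, 1, 1),
         schedule_interval=None) as dag:
    fetch_rows = MysqlQueryOperatorWithTemplatedParams(
        task_id="fetch_rows",
        # pyformat placeholder resolved by the MySQL driver from the parameters dict
        sql="SELECT id, name FROM orders WHERE created_at >= %(since)s",
        # 'parameters' is listed in template_fields, so Jinja renders "{{ ds }}" at runtime
        parameters={"since": "{{ ds }}"},
        mysql_conn_id="mysql_dwh",
    )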
def index(self):
    end_date = datetime.now().date()
    start_date = end_date - timedelta(days=1)
    sql = """
        select
            a.dag_id,
            a.state,
            min(a.start_date) as start_date,
            max(a.end_date) as end_date,
            max(a.end_date) - min(a.start_date) as duration,
            b.job_type,
            a.job_id
        from task_instance as a
        join job as b on a.job_id = b.id
        where a.start_date >= "{start_date}"
          and a.start_date < "{end_date}"
          and a.state != 'failed'
        group by a.dag_id, a.job_id
        order by start_date;
    """.format(start_date=start_date, end_date=end_date)
    h = MySqlHook(METASTORE_MYSQL_CONN_ID)
    rows = h.get_records(sql)

    tasks = []
    taskNames = []
    name_set = set()
    time_format = "%Y-%m-%dT%H:%M:%S"
    for row in rows:
        dag_id = row[0]
        state = row[1]
        start_date = row[2]
        end_date = row[3]
        if not end_date:
            end_date = datetime.now()
        duration = str(row[4])
        task = {
            'status': state,
            'taskName': dag_id,
            'startDate': time.mktime(start_date.timetuple()) * 1000,
            'endDate': time.mktime(end_date.timetuple()) * 1000,
            'executionDate': start_date.strftime(time_format),
            'isoStart': start_date.strftime(time_format),
            'isoEnd': end_date.strftime(time_format),
            'duration': duration
        }
        taskNames.append(dag_id)
        name_set.add(dag_id)
        tasks.append(task)

    data = {
        'height': 20 * len(name_set),
        'tasks': tasks,
        'taskNames': taskNames,
        'taskStatus': {'success': 'success'}
    }
    return self.render("scheduler_browser/gantt.html", data=data)
def get_groups(group_id, context):
    db = MySqlHook(mysql_conn_id='mariadb', schema="dbo")
    sql = """
        select group_id, parent_id, name, culture, group_code, sequence, is_childs, depth,
               is_display, email, interface_id, remark, created_date, modified_date
        from groups
        where group_id = %s
    """
    task = {}
    rows = db.get_records(sql, parameters=[group_id])
    for row in rows:
        model = {
            'group_id': row[0],
            'parent_id': row[1],
            'name': row[2],
            'culture': row[3],
            'group_code': row[4],
            'sequence': row[5],
            'is_childs': row[6],
            'depth': row[7],
            'is_display': row[8],
            'email': row[9],
            'interface_id': row[10],
            'remark': row[11],
            'created_date': str(row[12]),
            'modified_date': str(row[13])
        }
        task = model

    context['ti'].xcom_push(key=GROUPS, value=task)
    return task
def get_max_mysql(connection_name, schema, table_name, column):
    logging.info('Executing: SELECT max(' + str(column) + ') FROM ' +
                 str(schema) + '.' + str(table_name))
    hook = MySqlHook(mysql_conn_id=connection_name)
    output = hook.get_records('SELECT max(' + str(column) + ') FROM ' +
                              str(schema) + '.' + str(table_name))
    logging.info(output[0][0])
    return output[0][0]
def filter_db():
    api = MySqlHook()
    data = api.get_records(sql='select * from movie where vote_average > 7')
    # truncate the filter table
    api.run(sql='truncate table movie_filter')
    # insert the filtered rows into the filter table
    api.insert_rows(table='movie_filter', rows=data)
def move_mysql_to_redshift(tablename):
    mysql_hook = MySqlHook(mysql_conn_id='mysql_baseball')
    redshift_hook = PostgresHook(postgres_conn_id='redshift_host')
    sql = "select league_id, name, abbr from " + tablename
    cur = mysql_hook.get_records(sql)
    redshift_hook.insert_rows(tablename, cur)
    return cur
def copy_table(self, mysql_conn_id, postgres_conn_id):
    print("### fetching records from MySQL table ###")
    mysqlserver = MySqlHook(mysql_conn_id)
    sql_query = "SELECT * from clean_store_transactions"
    data = mysqlserver.get_records(sql_query)

    print("### inserting records into Postgres table ###")
    postgresserver = PostgresHook(postgres_conn_id)
    postgres_query = "INSERT INTO clean_store_transactions VALUES(%s, %s, %s, %s, %s, %s, %s, %s, %s);"
    postgresserver.insert_rows(table='clean_store_transactions', rows=data)
    return True
def step2(ds, **kargs):
    mysql_hook = MySqlHook(mysql_conn_id='cloudsql-test')
    items = mysql_hook.get_records(
        "SELECT policyID FROM sample_db_1.SAMPLE_TABLE_4 limit 20")
    # mysql_hook = MySqlHook(mysql_conn_id='local_mysql')
    # items = mysql_hook.get_records("SELECT samepleid FROM cloudtest.SAMPLE_TABLE_5 limit 20")

    mail_list = []
    for r in items:
        # print 'mail:%s ' % r
        mail_list.append('%d' % r)

    # print(mail_list)
    params = ','.join(mail_list)
    logging.info("params:{}".format(params))
def get_Mysql_data():
    def convert(a):
        return str(a).replace('(', '').replace(')', '').replace(',', '').replace("'", '')

    mysql_hook = MySqlHook(mysql_conn_id='test_mysql')
    sql = "SELECT sample_id FROM actg_samples where active_state=1"
    records = mysql_hook.get_records(sql=sql)
    final = []
    for a in records:
        final.append(convert(a))
    return final
def loggin_updateids_from_kv(**kwargs):
    mysql_hook = MySqlHook(mysql_conn_id='cloudsql-test')
    items = mysql_hook.get_records(
        "SELECT samepleid FROM {}.{} where sync_status = 0".format(
            kwargs['export_database'], kwargs['export_table']))
    # mysql_hook = MySqlHook(mysql_conn_id='local_mysql')
    # items = mysql_hook.get_records("SELECT samepleid FROM cloudtest.SAMPLE_TABLE_5 limit 20")

    mail_list = []
    for r in items:
        # print 'mail:%s ' % r
        mail_list.append('%d' % r)

    # print(mail_list)
    kwargs['update_id'] = ','.join(mail_list)
    logging.info("loggin_updateids_from_kv_ids:{}".format(kwargs['update_id']))
def checkRecords(**kwargs):
    # Check for data, if present write results to file
    # If empty, log information to different file
    mysql_hook = MySqlHook(mysql_conn_id='local_mysql')
    # Table test1 contains a list of positive integers
    sql = """
        select * from test1
    """
    records = mysql_hook.get_records(sql)

    # Pushing to Task instance
    if len(records) == 0:
        kwargs['task_instance'].xcom_push(key='check', value=False)
    else:
        kwargs['task_instance'].xcom_push(key='check', value=True)
    kwargs['task_instance'].xcom_push(key='recordcount', value=len(records))
def get_max_updated_at(conn_id, table_name):
    """
    Gets the max updated_at timestamp. If table is empty it returns None

    :param conn_id: Connection ID from the DB to search for
    :type conn_id: str
    :param table_name: name of the table to extract the max updated_at
    :type table_name: str
    :return: str with timestamp or None
    """
    mysql_hook = MySqlHook(conn_id)
    sql_query = "SELECT MAX(updated_at) FROM {}".format(table_name)
    # row 0, column 0 -- only one record is returned
    max_updated_at = mysql_hook.get_records(sql_query)[0][0]
    return max_updated_at
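# A minimal sketch of how get_max_updated_at above could drive an incremental extract.
# The connection id, table name, and updated_at filter are assumptions for illustration;
# only MySqlHook.get_records/parameters usage mirrors the snippets in this collection.
from airflow.hooks.mysql_hook import MySqlHook


def extract_new_rows(conn_id="mysql_dwh", table_name="orders"):
    hook = MySqlHook(conn_id)
    max_updated_at = get_max_updated_at(conn_id, table_name)
    if max_updated_at is None:
        # Empty target table: pull everything on the first run.
        return hook.get_records("SELECT * FROM {}".format(table_name))
    # Otherwise only pull rows newer than the stored high-water mark.
    sql = "SELECT * FROM {} WHERE updated_at > %s".format(table_name)
    return hook.get_records(sql, parameters=[max_updated_at])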
def get_workflow(**context):
    db = MySqlHook(mysql_conn_id='mariadb', schema="djob")
    sql = """
        select workflow_process_id, ngen, site_id, application_id, instance_id, schema_id, name,
               workflow_instance_id, state, retry_count, ready, execute_date, created_date,
               bookmark, version, request, reserved, message
        from workflow_process
        where ready > 0 and retry_count < 10
        limit 1
    """
    task = {}
    rows = db.get_records(sql)
    for row in rows:
        model = {
            'workflow_process_id': row[0],
            'ngen': row[1],
            'site_id': row[2],
            'application_id': row[3],
            'instance_id': row[4],
            'schema_id': row[5],
            'name': row[6],
            'workflow_instance_id': row[7],
            'state': row[8],
            'retry_count': row[9],
            'ready': row[10],
            'execute_date': str(row[11]),
            'created_date': str(row[12]),
            'bookmark': row[13],
            'version': row[14],
            'request': row[15],
            'reserved': row[16],
            'message': row[17]
        }
        task = model

    # only continue if a record was found
    if task != {}:
        context['ti'].xcom_push(key=WORKFLOWS, value=task)
        sql = """
            update workflow_process
            set ready = 0, bookmark = 'start'
            where workflow_process_id = %s
        """
        db.run(sql, autocommit=True, parameters=[task['workflow_process_id']])
def poke(self, context):
    db = MySqlHook(mysql_conn_id='mariadb', schema="djob")
    sql = """
        select workflow_process_id, ngen, site_id, application_id, instance_id, schema_id, name,
               workflow_instance_id, state, retry_count, ready, execute_date, created_date,
               bookmark, version, request, reserved, message
        from workflow_process
        where ready > 0 and retry_count < 10
    """
    tasks = {}
    tasks[WORKFLOW_PROCESS] = []
    rows = db.get_records(sql)
    for row in rows:
        model = {
            'workflow_process_id': row[0],
            'ngen': row[1],
            'site_id': row[2],
            'application_id': row[3],
            'instance_id': row[4],
            'schema_id': row[5],
            'name': row[6],
            'workflow_instance_id': row[7],
            'state': row[8],
            'retry_count': row[9],
            'ready': row[10],
            'execute_date': str(row[11]),
            'created_date': str(row[12]),
            'bookmark': row[13],
            'version': row[14],
            'request': row[15],
            'reserved': row[16],
            'message': row[17]
        }
        tasks[WORKFLOW_PROCESS].append(model)

    # only report success if at least one record was found
    if tasks[WORKFLOW_PROCESS] != []:
        log.info('workflow_process find data')
        context['ti'].xcom_push(key=WORKFLOW_PROCESS, value=tasks[WORKFLOW_PROCESS])
        return True
    else:
        log.info('workflow_process empty data')
        return False
def get_instance(**context):
    workflow = context['ti'].xcom_pull(task_ids=WORKFLOW_START_TASK, key=WORKFLOWS)
    db = MySqlHook(mysql_conn_id='mariadb', schema="dapp")
    instance_id = int(workflow["instance_id"])
    sql = """
        select instance_id, state, form_id, parent_id, workflow_id, subject, creator_culture,
               creator_name, group_culture, group_name, is_urgency, is_comment,
               is_related_document, attach_count, summary, re_draft_group_id, sub_proc_group_id,
               interface_id, created_date, completed_date, creator_id, group_id
        from instances
        where instance_id = %s
    """
    task = {}
    rows = db.get_records(sql, parameters=[instance_id])
    for row in rows:
        model = {
            'instance_id': row[0],
            'state': row[1],
            'form_id': row[2],
            'parent_id': row[3],
            'workflow_id': row[4],
            'subject': row[5],
            'creator_culture': row[6],
            'creator_name': row[7],
            'group_culture': row[8],
            'group_name': row[9],
            'is_urgency': row[10],
            'is_comment': row[11],
            'is_related_document': row[12],
            'attach_count': row[13],
            'summary': row[14],
            're_draft_group_id': row[15],
            'sub_proc_group_id': row[16],
            'interface_id': row[17],
            'created_date': str(row[18]),
            'completed_date': str(row[19]),
            'creator_id': row[20],
            'group_id': row[21]
        }
        task = model

    context['ti'].xcom_push(key=INSTANCES, value=task)
    return task
def check_previous_runs(**kwargs):
    context = kwargs
    current_run_id = context['dag_run'].run_id
    current_dag_id = context['dag_run'].dag_id

    # Connect to mysql and check for any errors for this DAG
    airflow_conn = MySqlHook(mysql_conn_id='deliverbi_mysql_airflow')
    l_error_count = 0
    cmd_sql = f"select count(1) from airflow.dag_run where dag_id = '{current_dag_id}' "
    cmd_sql += f"and run_id <> '{current_run_id}' and state = 'failed'"
    print(cmd_sql)

    airflow_data = airflow_conn.get_records(sql=cmd_sql)
    for row in airflow_data:
        l_error_count = int(str(row[0]))

    print("Found Previous Errors:" + str(l_error_count))
    if l_error_count != 0:
        raise AirflowException("Previous Run in Error so Failing the Current Run")
def get_data_from_mysql(filename, tablename):
    hook = MySqlHook(mysql_conn_id='mysql_baseball')
    sql = "select * from " + tablename
    cur = hook.get_records(sql)
    # f = open(filename,'w')
    # print >>f, cur
    # csv.writer objects have no close(); write through the file handle instead
    with open(filename, "w", newline="") as f:
        writer = csv.writer(f, quoting=csv.QUOTE_NONNUMERIC)
        for row in cur:
            writer.writerow(row)
    return cur
def get_users(user_id, context):
    db = MySqlHook(mysql_conn_id='mariadb', schema="dbo")
    sql = """
        select user_id, name, culture, group_id, employee_num, anonymous_name, email, theme_code,
               date_format_code, time_format_code, time_zone, row_count, language_code,
               interface_id, phone, mobile, fax, icon, addsign_img, is_plural, is_notification,
               is_absence, is_deputy
        from users
        where user_id = %s
    """
    task = {}
    rows = db.get_records(sql, parameters=[user_id])
    for row in rows:
        model = {
            'user_id': row[0],
            'name': row[1],
            'culture': row[2],
            'group_id': row[3],
            'employee_num': row[4],
            'anonymous_name': row[5],
            'email': row[6],
            'theme_code': row[7],
            'date_format_code': row[8],
            'time_format_code': row[9],
            'time_zone': row[10],
            'row_count': row[11],
            'language_code': row[12],
            'interface_id': row[13],
            'phone': row[14],
            'mobile': row[15],
            'fax': row[16],
            'icon': row[17],
            'addsign_img': row[18],
            'is_plural': row[19],
            'is_notification': row[20],
            'is_absence': row[21],
            'is_deputy': row[22]
        }
        task = model

    context['ti'].xcom_push(key=USERS, value=task)
    return task
def get_columns_and_exclude(conn_id, table_name, l_columns_exclude):
    """
    :param conn_id: connection id to connect to
    :param table_name: table to get the columns
    :param l_columns_exclude: list of strings of columns to exclude
    :return: list of strings without the excluded columns
    """
    mysql_hook = MySqlHook(conn_id)
    sql_query = "SHOW COLUMNS FROM {}".format(table_name)
    all_records = mysql_hook.get_records(sql_query)
    l_columns_after_exclude = [
        "t.`{}`".format(l_row[0]) for l_row in all_records
        if l_row[0] not in l_columns_exclude
    ]
    logging.debug("Columns after exclude: '{}'".format(l_columns_after_exclude))
    return l_columns_after_exclude
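# A minimal sketch showing how the "t.`col`"-prefixed list returned above could be turned
# into a SELECT that skips the excluded columns. The connection id, table name, and the
# excluded column names are assumptions for illustration only.
def build_select_statement(conn_id="mysql_dwh", table_name="orders"):
    # Columns come back already qualified with the "t." alias used below.
    columns = get_columns_and_exclude(conn_id, table_name, ["password_hash", "ssn"])
    return "SELECT {} FROM {} t".format(", ".join(columns), table_name)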
def send_aggregate_to_requestbin():
    target = 'http://requestbin.net/r/zorarbzo'
    connection = MySqlHook(mysql_conn_id='mysql_default')
    sql = '''
        SELECT film_name, name, birth_year
        FROM `swapi_data`.`swapi_people_aggregate`;
    '''
    result = connection.get_records(sql)
    data = []
    for item in result:
        data.append({
            "film_name": item[0],
            "name": item[1],
            "birth_year": str(item[2])
        })
    result = requests.post(target, data=json.dumps(data))
    return result
def get_group_users(gid, uid, context):
    db = MySqlHook(mysql_conn_id='mariadb', schema="dbo")
    sql = """
        select b.*, a.relation_type, a.is_master
        from dbo.group_users a
        inner join dbo.groups b on a.group_id = b.group_id
        where a.user_id = %s and a.parent_id = %s
    """
    tasks = {}
    tasks[GROUP_USERS] = []
    rows = db.get_records(sql, parameters=[uid, gid])
    for row in rows:
        model = {
            'group_id': row[0],
            'parent_id': row[1],
            'name': row[2],
            'culture': row[3],
            'group_code': row[4],
            'sequence': row[5],
            'is_childs': row[6],
            'depth': row[7],
            'is_display': row[8],
            'email': row[9],
            'remark': row[10],
            'created_date': str(row[11]),
            'modified_date': str(row[12]),
            'relation_type': row[13],
            'is_master': row[14]
        }
        tasks[GROUP_USERS].append(model)

    context['ti'].xcom_push(key=GROUPS, value=tasks[GROUP_USERS])
    return list(tasks.values())
def getWorkflows(**context):
    db = MySqlHook(mysql_conn_id='mariadb', schema="djob")
    sql = """
        select o.id, o.name, o.desc
        from test o
    """

    # initialize the task list buckets
    tasks = {}
    rowCount = 0
    rows = db.get_records(sql)
    for row in rows:
        rowCount += 1
        tasks[f'order_processing_task_{rowCount}'] = []

    # populate the task list buckets
    # distribute them evenly across the set of buckets
    # records: List[List[Optional[Any]]] = db.get_records(get_orders_query)
    resultCount = 0
    for row in rows:
        resultCount += 1
        model = {'id': str(row[0]), 'name': str(row[1])}
        # tasks[f'order_processing_task_{bucket}'] = []
        tasks[f'order_processing_task_{resultCount}'].append(model)

    items = {}
    # Push the order lists into xcom
    for task in tasks:
        if len(tasks[task]) > 0:
            logging.info(f'Task {task} has {len(tasks[task])} orders.')
            context['ti'].xcom_push(key=task, value=tasks[task])

    return list(tasks.values())
def execute(self, context=None):
    metastore = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id)
    table = metastore.get_table(table_name=self.table)
    field_types = {col.name: col.type for col in table.sd.cols}

    exprs = {
        ('', 'count'): 'COUNT(*)'
    }
    for col, col_type in list(field_types.items()):
        d = {}
        if self.assignment_func:
            d = self.assignment_func(col, col_type)
            if d is None:
                d = self.get_default_exprs(col, col_type)
        else:
            d = self.get_default_exprs(col, col_type)
        exprs.update(d)
    exprs.update(self.extra_exprs)
    exprs = OrderedDict(exprs)
    exprs_str = ",\n        ".join([
        v + " AS " + k[0] + '__' + k[1]
        for k, v in exprs.items()])

    where_clause = ["{} = '{}'".format(k, v) for k, v in self.partition.items()]
    where_clause = " AND\n        ".join(where_clause)
    sql = "SELECT {exprs_str} FROM {table} WHERE {where_clause};".format(
        exprs_str=exprs_str, table=self.table, where_clause=where_clause)

    presto = PrestoHook(presto_conn_id=self.presto_conn_id)
    self.log.info('Executing SQL check: %s', sql)
    row = presto.get_first(hql=sql)
    self.log.info("Record: %s", row)
    if not row:
        raise AirflowException("The query returned None")

    part_json = json.dumps(self.partition, sort_keys=True)

    self.log.info("Deleting rows from previous runs if they exist")
    mysql = MySqlHook(self.mysql_conn_id)
    sql = """
    SELECT 1 FROM hive_stats
    WHERE
        table_name='{table}' AND
        partition_repr='{part_json}' AND
        dttm='{dttm}'
    LIMIT 1;
    """.format(table=self.table, part_json=part_json, dttm=self.dttm)
    if mysql.get_records(sql):
        sql = """
        DELETE FROM hive_stats
        WHERE
            table_name='{table}' AND
            partition_repr='{part_json}' AND
            dttm='{dttm}';
        """.format(table=self.table, part_json=part_json, dttm=self.dttm)
        mysql.run(sql)

    self.log.info("Pivoting and loading cells into the Airflow db")
    rows = [(self.ds, self.dttm, self.table, part_json) + (r[0][0], r[0][1], r[1])
            for r in zip(exprs, row)]
    mysql.insert_rows(
        table='hive_stats',
        rows=rows,
        target_fields=[
            'ds',
            'dttm',
            'table_name',
            'partition_repr',
            'col',
            'metric',
            'value',
        ]
    )
def poke(self, context):
    # Setting default args for ngap2 airflow db
    db_type = 'mysql'
    internal_db = True

    # Reading connection to check db type
    if self.cluster_id is not None:
        conn = BaseHook.get_connection(self.cluster_id)
        logging.info(
            "checking for Extra field in connections to determine db type mysql or postgres or lambda"
        )
        # check for db type from extra args
        for arg_name, arg_val in conn.extra_dejson.items():
            if arg_name in ['db_type']:
                db_type = arg_val
            if arg_name in ['internal_db']:
                internal_db = False

    if self.execution_delta:
        dttm = context['execution_date'] - self.execution_delta
    elif self.execution_delta_json:
        hour = context['execution_date'].strftime('%H')
        delta = self.execution_delta_json[hour]
        hour_d = int(delta.split(':')[0])
        minute_d = int(delta.split(':')[1] if len(delta.split(':')) > 1 else '00')
        final_minutes = 0
        if hour_d < 0:
            final_minutes = (hour_d * 60) - minute_d
        else:
            final_minutes = (hour_d * 60) + minute_d
        dttm = context['execution_date'] - timedelta(minutes=final_minutes)
    else:
        dttm = context['execution_date']

    allowed_states = tuple(self.allowed_states)
    if len(allowed_states) == 1:
        sql = " SELECT ti.task_id FROM task_instance ti WHERE ti.dag_id = '{self.external_dag_id}' " \
              "AND ti.task_id = '{self.external_task_id}' AND ti.state = ('{allowed_states[0]}') " \
              "AND ti.execution_date = '{dttm}';".format(**locals())
    else:
        sql = "SELECT ti.task_id FROM task_instance ti WHERE ti.dag_id = '{self.external_dag_id}' " \
              "AND ti.task_id = '{self.external_task_id}' AND ti.state in {allowed_states} " \
              "AND ti.execution_date = '{dttm}';".format(**locals())

    if self.cluster_id is None:
        logging.info('Poking for '
                     '{self.external_dag_id}.'
                     '{self.external_task_id} on '
                     '{dttm} ... '.format(**locals()))
        TI = TaskInstance
        session = settings.Session()
        count = session.query(TI).filter(
            TI.dag_id == self.external_dag_id,
            TI.task_id == self.external_task_id,
            TI.state.in_(self.allowed_states),
            TI.execution_date == dttm,
        ).count()
        session.commit()
        session.close()
    elif internal_db == True:
        logging.info('Poking for '
                     '{self.external_dag_id}.'
                     '{self.external_task_id} on '
                     '{dttm} on {self.cluster_id} ... '.format(**locals()))
        hook = None
        if db_type == 'mysql':
            from airflow.hooks.mysql_hook import MySqlHook
            hook = MySqlHook(mysql_conn_id=self.cluster_id)
        elif db_type == 'postgres':
            from airflow.hooks.postgres_hook import PostgresHook
            hook = PostgresHook(postgres_conn_id=self.cluster_id)
        else:
            raise Exception("Please specify correct db type")

        records = hook.get_records(sql)
        if not records:
            count = 0
        else:
            if str(records[0][0]) in ('0', '',):
                count = 0
            else:
                count = len(records)
                logging.info('task record found')
    else:
        host = conn.host
        user = conn.login
        password = conn.password
        dbname = conn.schema
        port = conn.port
        url = None
        for arg_name, arg_val in conn.extra_dejson.items():
            if arg_name in ['endpoint']:
                url = arg_val
        if not url:
            raise KeyError('Lambda endpoint is not specified in Extra args')

        logging.info('Poking for '
                     '{self.external_dag_id}.'
                     '{self.external_task_id} on '
                     '{dttm} on {self.cluster_id} ... '.format(**locals()))

        import requests
        from requests.adapters import HTTPAdapter
        from requests.packages.urllib3.util.retry import Retry

        payload = ("{\"ENDPOINT\":\"" + host + "\",\"PORT\":\"" + str(port) +
                   "\",\"DBUSER\":\"" + user + "\",\"DBPASSWORD\":\"" + password +
                   "\",\"DATABASE\":\"" + dbname + "\", \"DB_TYPE\":\"" + db_type +
                   "\",\"QUERY\":\"" + sql + "\"}")
        headers = {
            'Content-Type': "application/json",
            'cache-control': "no-cache",
        }
        session = requests.Session()
        retries = Retry(total=5, backoff_factor=1,
                        status_forcelist=[502, 503, 504, 500])
        session.mount('https://', HTTPAdapter(max_retries=retries))
        response = session.post(url.replace(' ', ''), headers=headers, data=payload)
        if response.status_code == 200:
            count = int(response.text)
        else:
            raise Exception(response.content)

    logging.info(count)
    return count
def execute(self, context):
    self.log.info('Executing: %s', self.sql)
    hook = MySqlHook(mysql_conn_id=self.mysql_conn_id, schema=self.database)
    return hook.get_records(self.sql, parameters=self.parameters)