def execute(self, context):
    hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
    logging.info("Extracting data from Hive")
    logging.info(self.sql)

    if self.bulk_load:
        tmpfile = NamedTemporaryFile()
        hive.to_csv(self.sql, tmpfile.name, delimiter='\t',
                    lineterminator='\n', output_header=False)
    else:
        results = hive.get_records(self.sql)

    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    if self.mysql_preoperator:
        logging.info("Running MySQL preoperator")
        mysql.run(self.mysql_preoperator)

    logging.info("Inserting rows into MySQL")
    if self.bulk_load:
        mysql.bulk_load(table=self.mysql_table, tmp_file=tmpfile.name)
        tmpfile.close()
    else:
        mysql.insert_rows(table=self.mysql_table, rows=results)

    if self.mysql_postoperator:
        logging.info("Running MySQL postoperator")
        mysql.run(self.mysql_postoperator)

    logging.info("Done.")

def set_sign_users(doc, context):
    db = MySqlHook(mysql_conn_id='mariadb', schema="dapp")
    sql = """
        insert into sign_users(instance_id, sign_area_id, sequence, user_culture,
            user_id, user_name, responsibility, position, class_position,
            host_address, reserved_date, delay_time, is_deputy, is_comment)
        values(%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
    """
    db.run(sql, autocommit=True, parameters=[
        doc.find("instance_id").text,
        doc.find('sign_area_id').text,
        doc.find('sequence').text,
        doc.find('user_culture').text,
        doc.find('user_id').text,
        doc.find('user_name').text,
        doc.find('responsibility').text,
        doc.find('position').text,
        doc.find('class_position').text,
        doc.find('host_address').text,
        doc.find('reserved_date').text,
        doc.find('delay_time').text,
        doc.find('is_deputy').text,
        doc.find('is_comment').text
    ])

def execute(self, context):
    logging.info('Executing: ' + str(self.sql))
    hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    hook.run(
        self.sql,
        autocommit=self.autocommit,
        parameters=self.parameters)

def execute(self, context):
    self.log.info('Executing: %s', self.sql)
    hook = MySqlHook(mysql_conn_id=self.mysql_conn_id,
                     schema=self.database)
    hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)

def execute(self, context):
    logging.info('Executing: ' + str(self.sql))
    src_mysql = MySqlHook(mysql_conn_id=self.src_mysql_conn_id)
    dest_mysql = MySqlHook(mysql_conn_id=self.dest_mysqls_conn_id)

    logging.info("Transferring Mysql query results into other Mysql database.")
    conn = src_mysql.get_conn()
    cursor = conn.cursor()
    cursor.execute(self.sql, self.query_parameters)

    if self.mysql_preoperator:
        logging.info("Running Mysql preoperator")
        dest_mysql.run(self.mysql_preoperator)

    if cursor.rowcount != 0:
        logging.info("Inserting rows into Mysql")
        # fetch the rows once; iterating the cursor for debug output first
        # would exhaust it before insert_rows can read it
        rows = cursor.fetchall()
        dest_mysql.insert_rows(table=self.dest_table, rows=rows)
        logging.info(str(cursor.rowcount) + " rows inserted")
    else:
        logging.info("No rows inserted")

    if self.mysql_postoperator:
        logging.info("Running Mysql postoperator")
        dest_mysql.run(self.mysql_postoperator)

    logging.info("Done.")

def save_data_into_db():
    mysql_hook = MySqlHook(mysql_conn_id='covid19')

    with open('data.json') as f:
        data = json.load(f)

    insert = """
        INSERT INTO daily_covid19_reports (
            confirmed, recovered, hospitalized, deaths,
            new_confirmed, new_recovered, new_hospitalized, new_deaths,
            update_date, source, dev_by, server_by)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s);
    """
    mysql_hook.run(insert, parameters=(
        data['Confirmed'],
        data['Recovered'],
        data['Hospitalized'],
        data['Deaths'],
        data['NewConfirmed'],
        data['NewRecovered'],
        data['NewHospitalized'],
        data['NewDeaths'],
        datetime.strptime(data['UpdateDate'], '%d/%m/%Y %H:%M'),
        data['Source'],
        data['DevBy'],
        data['SeverBy']))

def build_aggregrate():
    connection = MySqlHook(mysql_conn_id='mysql_default')
    sql = '''
        INSERT INTO `swapi_data`.`swapi_people_aggregate` (film, birth_year, name, film_name)
        SELECT
            film,
            max(birth_year_number) as birth_year,
            (
                SELECT name
                FROM swapi_data.swapi_people
                WHERE film = t.film
                ORDER BY birth_year_number DESC
                LIMIT 0,1
            ) as name,
            film_name
        FROM swapi_data.swapi_people t
        GROUP BY film, film_name;
    '''
    connection.run(sql, autocommit=True, parameters=())
    return True

def store_data_mysql(**kwargs):
    # task instance kwargs provided by the context and op_kwargs
    ti = kwargs['ti']

    # get data from XComs
    list_new_weather = [
        ti.xcom_pull(key=None, task_ids='get_weather_' + str(city).replace(' ', ''))
        for city in cities
    ]

    # connect to MySQL server (and database!) through MySqlHook
    connection = MySqlHook(mysql_conn_id='datascientest_sql_weather')

    # create table if not exists to store weather data
    sql_creation = 'CREATE TABLE IF NOT EXISTS cities_live (id int primary ' \
                   'key not null auto_increment, city varchar(1000), temp_live float, ' \
                   'temp_min float, temp_max float, humidity float, pressure float, ' \
                   'weather_description varchar(1000), wind_speed float, time datetime)'
    connection.run(sql_creation)

    # adding new weather results to MySQL weather table
    for new_weather in list_new_weather:
        sql_new_record = 'INSERT INTO cities_live (city, temp_live, temp_min, temp_max, ' \
                         'humidity, pressure, weather_description, wind_speed, time) VALUES '
        sql_new_record += '(%s,%s,%s,%s,%s,%s,%s,%s,%s)'
        parameters = [
            new_weather['city'], new_weather['temp_live'], new_weather['temp_min'],
            new_weather['temp_max'], new_weather['humidity'], new_weather['pressure'],
            new_weather['weather_description'], new_weather['wind_speed'], new_weather['time']
        ]
        connection.run(sql_new_record, autocommit=True, parameters=parameters)

def select_monthly_sales():
    connection = MySqlHook(mysql_conn_id='mysql_default')
    connection.run("""
        SELECT count(*) FROM airflow_bi.monthly_item_sales;
    """, autocommit=True)
    return True

def load_BitcoinRate(**kwargs):
    ti = kwargs['ti']
    sql = ti.xcom_pull(key=None, task_ids='transform_BitcoinRate')
    connection = MySqlHook(mysql_conn_id=os.environ['AIRFLOW_CONN_ID'])
    connection.run(sql, autocommit=True)
    return True

def set_error(workflow_process_id, message):
    db = MySqlHook(mysql_conn_id='mariadb', schema="djob")
    sql = """
        update workflow_process
        set ready = 1,
            retry_count = retry_count + 1,
            message = %s
        where workflow_process_id = %s
    """
    db.run(sql, autocommit=True, parameters=[message, workflow_process_id])

def insert_counts(date):
    fruitRecs = pandas.read_csv("/tmp/fruit/combined/%s.txt" % (date),
                                sep=" ", names=['type', 'date', 'time'])
    fruitCounts = fruitRecs.groupby('type').count()

    # This connection must be set from the Connection view in the Airflow UI
    mysql_hook = MySqlHook(mysql_conn_id="mysql_fruit", schema="fruit")
    connection = mysql_hook.get_conn()  # raw DB-API connection from the MySqlHook (not used below)

    mysql_hook.run("insert into fruit_counts (rec_date,apples,figs) values ('%s',%s,%s)"
                   % (date, fruitCounts.date.apples, fruitCounts.date.figs))

def set_sign_activity(instance_id, contents, context):
    db = MySqlHook(mysql_conn_id='mariadb', schema="dapp")
    sql = """
        update sign_activity
        set contents = %s
        where instance_id = %s
    """
    db.run(sql, autocommit=True, parameters=[contents, instance_id])

def sql_import(**kwargs): input_file = kwargs['templates_dict']['input_file'] columns = ["WORD", "TIMES"] mysql = MySqlHook(mysql_conn_id='workshop_sql_conn_id') mysql.run("TRUNCATE WORDCOUNT") with open(input_file) as file: reader = csv.reader(file, delimiter=' ') data = list(reader) mysql.insert_rows('WORDCOUNT', data, target_fields=columns)
def save_avg(**kwargs):
    hook = MySqlHook(mysql_conn_id='mysql_default', schema='test')
    ti = kwargs['ti']
    age = ti.xcom_pull(key='age', task_ids='cal_avg')
    # pass the value as a query parameter instead of string-formatting it into the SQL
    hook.run('insert into tongji(average_age) values(%s)',
             autocommit=True, parameters=(age,))

def filter_db():
    api = MySqlHook()
    data = api.get_records(sql='select * from movie where vote_average > 7')

    # truncate the filter table
    api.run(sql='truncate table movie_filter')

    # insert into the filter table
    api.insert_rows(table='movie_filter', rows=data)

def store_data(**kwargs):
    ti = kwargs['ti']
    parsed_records = ti.xcom_pull(key=None, task_ids='download_image')
    connection = MySqlHook(mysql_conn_id='mysql_default')
    for r in parsed_records:
        url = r['url']
        data = json.dumps(r)
        sql = 'INSERT INTO recipes(url,data) VALUES (%s,%s)'
        connection.run(sql, autocommit=True, parameters=(url, data))
    return True

def update_sync_records_from_kv(**kwargs):
    logging.info("update_sync_records_from_kv ids: {}".format(kwargs['update_id']))
    if kwargs['update_id']:
        mysql_hook = MySqlHook(mysql_conn_id='cloudsql-test')
        mysql_hook.run(
            "update {}.{} set sync_status = 1 where samepleid in ({})".format(
                kwargs['export_database'],
                kwargs['export_table'],
                kwargs['update_id']),
            True)
    else:
        logging.info("update_sync_records_from_kv: no records need to update")

def sql_import(**kwargs): input_file = kwargs['templates_dict']['input_file'] columns = [ "CODE", "NUMBER_RELATED_ORDERS", "NUMBER_STATUSES", "NUMBER_PARTNERS", "NUMBER_COMMENTS" ] mysql = MySqlHook(mysql_conn_id='workshop_sql_conn_id') mysql.run("TRUNCATE PROCESSED_ORDER") with open(input_file) as file: reader = csv.reader(file, delimiter=' ') data = list(reader) mysql.insert_rows('PROCESSED_ORDER', data, target_fields=columns)
def execute(self, context):
    presto = PrestoHook(presto_conn_id=self.presto_conn_id)
    self.log.info("Extracting data from Presto: %s", self.sql)
    results = presto.get_records(self.sql)

    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    if self.mysql_preoperator:
        self.log.info("Running MySQL preoperator")
        self.log.info(self.mysql_preoperator)
        mysql.run(self.mysql_preoperator)

    self.log.info("Inserting rows into MySQL")
    mysql.insert_rows(table=self.mysql_table, rows=results)

def execute(self, context):
    mysql_infields = ','.join('`{}`'.format(infield) for infield in self.mysql_infields)
    self.log.info('MySQL fields: %s', mysql_infields)

    s3_hook = S3Hook(self.aws_conn_id)
    mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)

    self.log.info('Listing files in s3://%s/%s', self.s3_bucket, self.s3_prefix)
    s3_infiles = s3_hook.list_keys(self.s3_bucket, prefix=self.s3_prefix,
                                   delimiter=self.s3_delimiter)
    if not s3_infiles:
        raise RuntimeError('no file to process')

    with TemporaryDirectory(prefix='airflow_mysqlloadop_') as tmp_dir:
        # concatenate all S3 objects into one local temp file
        with NamedTemporaryFile('ab', dir=tmp_dir, delete=False) as tmp:
            for s3_infile in s3_infiles:
                self.log.info('Download s3://%s/%s', self.s3_bucket, s3_infile)
                s3_obj = s3_hook.get_key(s3_infile, self.s3_bucket)
                if s3_obj.content_type == 'application/x-directory':
                    self.log.info('Skip directory: s3://%s/%s', self.s3_bucket, s3_infile)
                    continue
                s3_obj.download_fileobj(tmp)
            mysql_infile = tmp.name
        self.log.info('MySQL infile: %s', mysql_infile)

        mysql_sql_fmt = '''
            LOAD DATA LOCAL INFILE '{file}'
            INTO TABLE `{database}`.`{table}`
            FIELDS TERMINATED BY '{seps[0]}' ENCLOSED BY '{seps[1]}'
            LINES TERMINATED BY '{seps[2]}'
            ({fields})
            ;
        '''
        mysql_sql = mysql_sql_fmt.format(
            file=mysql_infile,
            database=self.mysql_database,
            table=self.mysql_table,
            seps=self.mysql_inseps,
            fields=mysql_infields)
        self.log.info('Execute SQL')
        mysql_hook.run(mysql_sql)

def doTestMysqlHook(*args, **kwargs):
    sql_hook = MySqlHook().get_hook(conn_id="mysql_operator_test_connid")

    sql = "select * from manzeng_predict_src_table;"
    result = sql_hook.get_records(sql)
    for row in result:
        print(row)

    sql = "select max(id) as max_id from manzeng_predict_src_table"
    result = sql_hook.get_records(sql)
    print('maxid:' + str(result[0][0]))

    result = sql_hook.get_first(sql)
    print('maxid:' + str(result[0]))

    LoggingMixin.log.exception("exception raise test")

    sql_hook.run(
        """insert into manzeng_result_v3(consignor_phone,prediction) values('122','33')"""
    )

def execute(self, context):
    postgres = PostgresHook(postgres_conn_id=self.redshift_conn_id)
    self.log.info("Extracting data from Redshift: %s", self.sql)
    results = postgres.get_records(self.sql)

    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    if self.mysql_preoperator:
        self.log.info("Running MySQL preoperator")
        self.log.info(self.mysql_preoperator)
        mysql.run(self.mysql_preoperator)

    self.log.info("Inserting rows into MySQL")
    mysql.insert_rows(table=self.mysql_table, rows=results, replace=self.replace)

def get_workflow(**context):
    db = MySqlHook(mysql_conn_id='mariadb', schema="djob")
    sql = """
        select workflow_process_id, ngen, site_id, application_id, instance_id, schema_id,
               name, workflow_instance_id, state, retry_count, ready,
               execute_date, created_date, bookmark, version, request, reserved, message
        from workflow_process
        where ready > 0 and retry_count < 10
        limit 1
    """
    task = {}
    rows = db.get_records(sql)
    for row in rows:
        model = {
            'workflow_process_id': row[0],
            'ngen': row[1],
            'site_id': row[2],
            'application_id': row[3],
            'instance_id': row[4],
            'schema_id': row[5],
            'name': row[6],
            'workflow_instance_id': row[7],
            'state': row[8],
            'retry_count': row[9],
            'ready': row[10],
            'execute_date': str(row[11]),
            'created_date': str(row[12]),
            'bookmark': row[13],
            'version': row[14],
            'request': row[15],
            'reserved': row[16],
            'message': row[17]
        }
        task = model

    # process only if a workflow task was found
    if task != {}:
        context['ti'].xcom_push(key=WORKFLOWS, value=task)
        sql = """
            update workflow_process
            set ready = 0, bookmark = 'start'
            where workflow_process_id = %s
        """
        db.run(sql, autocommit=True, parameters=[task['workflow_process_id']])

def create_table():
    # Create the aggregate table if it does not exist, then clear out any existing rows
    connection = MySqlHook(mysql_conn_id='mysql_default')
    sql = '''CREATE TABLE IF NOT EXISTS `swapi_data`.`swapi_people_aggregate` (
        `id` int(11) NOT NULL auto_increment,
        `film_name` varchar(100) NOT NULL default '',
        `film` varchar(100) NOT NULL default '',
        `name` varchar(100) NOT NULL default '',
        `birth_year` DECIMAL(4,1) NOT NULL default 0,
        PRIMARY KEY (`id`)
    );'''
    connection.run(sql, autocommit=True, parameters=())

    sql = '''DELETE FROM `swapi_data`.`swapi_people_aggregate`;'''
    connection.run(sql, autocommit=True, parameters=())
    return True

def local_to_mysql():
    connection = MySqlHook(mysql_conn_id='youtube_db')
    query = '''
        CREATE TABLE IF NOT EXISTS `group3`.`youtube7` (
            `video_id` VARCHAR(100) NOT NULL,
            `title` VARCHAR(100) NULL,
            `publishedAt` VARCHAR(45) NULL,
            `channelId` VARCHAR(45) NULL,
            `channelTitle` VARCHAR(60) NULL,
            `categoryId` INT NULL,
            `trending_date` DATETIME NULL,
            `tags` LONGTEXT NULL,
            `view_count` INT NULL,
            `likes` INT NULL,
            `dislikes` INT NULL,
            `comment_count` INT NULL,
            `thumbnail_link` VARCHAR(100) NULL,
            `comments_disabled` TINYINT NULL,
            `ratings_disabled` TINYINT NULL,
            `description` LONGTEXT NULL,
            PRIMARY KEY (`video_id`));
    '''
    connection.run(query, autocommit=True)

    # combine all CSV files under /temp into one DataFrame and replace NaN with None
    df = pd.concat(
        [pd.read_csv(f, sep=',') for f in glob.glob('/temp' + "/*.csv")],
        ignore_index=True)
    df = df.where((pd.notnull(df)), None)

    for i, row in df.iterrows():
        query = '''
            INSERT IGNORE INTO group3.youtube7
                (video_id, title, publishedAt, channelId, channelTitle, categoryId,
                 trending_date, tags, view_count, likes, dislikes, comment_count,
                 thumbnail_link, comments_disabled, ratings_disabled, description)
            VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
        '''
        try:
            connection.run(query, autocommit=True, parameters=tuple(row))
        except Exception:
            pass

def execute(self, context):
    dest_mysql = MySqlHook(mysql_conn_id=self.dest_mysqls_conn_id)

    # when a source task id is configured, pull the cursor from XCom; otherwise
    # keep the cursor handed to the operator directly. The original referenced
    # bare `data_cursor` and `kwargs`, which do not exist in this scope;
    # `self.data_cursor` and `context` are the likely intent.
    if self.data_cursor:
        self.cursor = context['ti'].xcom_pull(key=None, task_ids=self.data_cursor)

    logging.info("Transferring cursor into new Mysql database.")

    if self.mysql_preoperator:
        logging.info("Running Mysql preoperator")
        dest_mysql.run(self.mysql_preoperator)

    if self.cursor:
        dest_mysql.insert_rows(table=self.dest_table, rows=self.cursor)
        logging.info("%s rows inserted", self.cursor.rowcount)
    else:
        logging.info("No rows inserted")

    if self.mysql_postoperator:
        logging.info("Running Mysql postoperator")
        dest_mysql.run(self.mysql_postoperator)

    logging.info("Done.")

def set_signers(doc, group, context):
    db = MySqlHook(mysql_conn_id='mariadb', schema="dapp")
    sql = """
        insert into signers(instance_id, sign_area_id, sequence, sub_instance_id,
            sign_section, sign_position, sign_action, is_executed,
            group_culture, group_id, group_name,
            created_date, received_date, approved_date)
        values(%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s);
    """
    is_executed = True
    sub_instance_id = 0
    sign_action = STATUS_00  # already approved

    db.run(sql, autocommit=True, parameters=[
        doc.find('instance_id').text,
        doc.find('sign_area_id').text,
        doc.find('sequence').text,
        sub_instance_id,
        doc.attrib['sign_section'],
        doc.attrib['sign_position'],
        sign_action,
        is_executed,
        group['culture'],
        doc.find('group_id').text,
        group['name'],
        datetime.now(),
        datetime.now(),
        datetime.now()
    ])

def store_people(records):
    connection = MySqlHook(mysql_conn_id='mysql_default')
    for person in records:
        name = person['name']
        birth_year = person['birth_year']
        url = person['url']
        films = person['films']
        film_names_master = {
            "http://swapi.dev/api/films/1/": "A New Hope",
            "http://swapi.dev/api/films/2/": "The Empire Strikes Back",
            "http://swapi.dev/api/films/3/": "Return of the Jedi",
            "http://swapi.dev/api/films/4/": "The Phantom Menace",
            "http://swapi.dev/api/films/5/": "Attack of the Clones",
            "http://swapi.dev/api/films/6/": "Revenge of the Sith",
            "http://swapi.dev/api/films/7/": "The Force Awakens",
        }
        for film in films:
            film_name = film_names_master[film]
            sql = ('INSERT INTO `swapi_data`.`swapi_people`'
                   '(name, birth_year, film, url, film_name) VALUES (%s, %s, %s, %s, %s)')
            connection.run(sql, autocommit=True,
                           parameters=(name, birth_year, film, url, film_name))
    return True

def execute(self, context):
    vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)
    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

    tmpfile = None
    result = None
    selected_columns = []

    count = 0
    with closing(vertica.get_conn()) as conn:
        with closing(conn.cursor()) as cursor:
            cursor.execute(self.sql)
            selected_columns = [d.name for d in cursor.description]

            if self.bulk_load:
                tmpfile = NamedTemporaryFile("w")

                self.log.info(
                    "Selecting rows from Vertica to local file %s...",
                    tmpfile.name)
                self.log.info(self.sql)

                # note: a csv writer that accepts an "encoding" argument is assumed here
                # (e.g. unicodecsv imported as csv); the stdlib csv module does not take it
                csv_writer = csv.writer(tmpfile, delimiter='\t', encoding='utf-8')
                for row in cursor.iterate():
                    csv_writer.writerow(row)
                    count += 1

                tmpfile.flush()
            else:
                self.log.info("Selecting rows from Vertica...")
                self.log.info(self.sql)

                result = cursor.fetchall()
                count = len(result)

            self.log.info("Selected rows from Vertica %s", count)

    if self.mysql_preoperator:
        self.log.info("Running MySQL preoperator...")
        mysql.run(self.mysql_preoperator)

    try:
        if self.bulk_load:
            self.log.info("Bulk inserting rows into MySQL...")
            with closing(mysql.get_conn()) as conn:
                with closing(conn.cursor()) as cursor:
                    cursor.execute(
                        "LOAD DATA LOCAL INFILE '%s' INTO "
                        "TABLE %s LINES TERMINATED BY '\r\n' (%s)" %
                        (tmpfile.name,
                         self.mysql_table,
                         ", ".join(selected_columns)))
                    conn.commit()
            tmpfile.close()
        else:
            self.log.info("Inserting rows into MySQL...")
            mysql.insert_rows(table=self.mysql_table,
                              rows=result,
                              target_fields=selected_columns)
        self.log.info("Inserted rows into MySQL %s", count)
    except (MySQLdb.Error, MySQLdb.Warning):
        self.log.info("Inserted rows into MySQL 0")
        raise

    if self.mysql_postoperator:
        self.log.info("Running MySQL postoperator...")
        mysql.run(self.mysql_postoperator)

    self.log.info("Done")

def execute(self, context=None):
    metastore = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id)
    table = metastore.get_table(table_name=self.table)
    field_types = {col.name: col.type for col in table.sd.cols}

    exprs = {('', 'count'): 'COUNT(*)'}
    for col, col_type in list(field_types.items()):
        d = {}
        if self.assignment_func:
            d = self.assignment_func(col, col_type)
            if d is None:
                d = self.get_default_exprs(col, col_type)
        else:
            d = self.get_default_exprs(col, col_type)
        exprs.update(d)
    exprs.update(self.extra_exprs)
    exprs = OrderedDict(exprs)
    exprs_str = ",\n        ".join(
        [v + " AS " + k[0] + '__' + k[1] for k, v in exprs.items()])

    where_clause = [
        "{} = '{}'".format(k, v) for k, v in self.partition.items()
    ]
    where_clause = " AND\n        ".join(where_clause)
    sql = "SELECT {exprs_str} FROM {table} WHERE {where_clause};".format(
        exprs_str=exprs_str, table=self.table, where_clause=where_clause)

    presto = PrestoHook(presto_conn_id=self.presto_conn_id)
    self.log.info('Executing SQL check: %s', sql)
    row = presto.get_first(hql=sql)
    self.log.info("Record: %s", row)
    if not row:
        raise AirflowException("The query returned None")

    part_json = json.dumps(self.partition, sort_keys=True)

    self.log.info("Deleting rows from previous runs if they exist")
    mysql = MySqlHook(self.mysql_conn_id)
    sql = """
    SELECT 1 FROM hive_stats
    WHERE
        table_name='{table}' AND
        partition_repr='{part_json}' AND
        dttm='{dttm}'
    LIMIT 1;
    """.format(table=self.table, part_json=part_json, dttm=self.dttm)
    if mysql.get_records(sql):
        sql = """
        DELETE FROM hive_stats
        WHERE
            table_name='{table}' AND
            partition_repr='{part_json}' AND
            dttm='{dttm}';
        """.format(table=self.table, part_json=part_json, dttm=self.dttm)
        mysql.run(sql)

    self.log.info("Pivoting and loading cells into the Airflow db")
    rows = [(self.ds, self.dttm, self.table, part_json) + (r[0][0], r[0][1], r[1])
            for r in zip(exprs, row)]
    mysql.insert_rows(
        table='hive_stats',
        rows=rows,
        target_fields=[
            'ds',
            'dttm',
            'table_name',
            'partition_repr',
            'col',
            'metric',
            'value',
        ]
    )