def execute(self, context):
    logging.info('Executing: %s with parameters: %s', self.sql, self.parameters)
    self.hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    data = self.hook.get_records(self.sql, parameters=self.parameters)
    return data
def etl_process(**kwargs): logger.info(kwargs["execution_date"]) file_path = FSHook(FILE_CONNECTION_NAME).get_path() filename = 'time_series_covid19_recovered_global.csv' mysql_connection = MySqlHook(mysql_conn_id=CONNECTION_DB_NAME).get_sqlalchemy_engine() full_path = f'{file_path}/{filename}' recovered = pd.read_csv(full_path, encoding = "ISO-8859-1").rename(columns= {'Lat': 'lat', 'Long': 'lon'}) recovered['lat'] = recovered.lat.astype(str) recovered['lon'] = recovered.lon.astype(str) variables = [ "Province/State", "Country/Region", "lat", "lon" ] new_recovered = pd.melt(frame=recovered, id_vars= variables, var_name="fecha",value_name="recovered") new_recovered["recovered"] = new_recovered["recovered"].astype(int) with mysql_connection.begin() as connection: connection.execute("DELETE FROM airflowcovid.recoverd WHERE 1=1") new_recovered.rename(columns=COLUMNS).to_sql('recoverd', con=connection, schema='airflowcovid', if_exists='append', index=False) os.remove(full_path) logger.info(f"Rows inserted into recoverd table in Mysql")
def execute(self, context):
    mysql_hook = MySqlHook(schema=self.database, mysql_conn_id=self.mysql_conn_id)
    for rows in self._bq_get_data():
        mysql_hook.insert_rows(self.mysql_table, rows, replace=self.replace)
def get_force_run_data(self):
    mysql = MySqlHook(mysql_conn_id=biowardrobe_connection_id)
    with closing(mysql.get_conn()) as conn:
        with closing(conn.cursor()) as cursor:
            cursor.execute("""select uid
                              from labdata
                              where ((forcerun=1 AND libstatus >11 AND deleted=0) OR deleted=1)""")
            return cursor.fetchall()
def bulk_load_teams(table_name, **kwargs):
    local_filepath = '/home/vagrant/airflow/dags/baseballdatabank-master/core/top_teams_final.csv'
    conn = MySqlHook(mysql_conn_id='local_mysql')
    # conn.bulk_load(table_name, local_filepath)
    results = pandas.read_csv(local_filepath, sep='\t',
                              names=['yearID', 'franchID', 'teamID', 'W', 'L',
                                     'percentage', 'franchName'],
                              encoding='utf-8')
    conn.insert_rows(table=table_name, rows=results.values.tolist())
    return table_name
def execute(self, context):
    logging.info(self.__class__.__name__)
    m_hook = MySqlHook(self.mysql_conn_id)
    data = (S3Hook(self.s3_conn_id)
            .get_key(self.s3_key, bucket_name=self.s3_bucket)
            .get_contents_as_string(encoding='utf-8'))
    records = [tuple(record.split(',')) for record in data.split('\n') if record]
    if self.drop_first_row:
        records = records[1:]
    if len(records) < 1:
        logging.info("No records")
        return
    insert_query = '''
        INSERT INTO {schema}.{table} ({columns})
        VALUES ({placeholders})
    '''.format(schema=self.schema, table=self.table,
               columns=', '.join(self.columns),
               placeholders=', '.join('%s' for col in self.columns))
    conn = m_hook.get_conn()
    cur = conn.cursor()
    cur.executemany(insert_query, records)
    cur.close()
    conn.commit()
    conn.close()
def select_monthly_sales():
    connection = MySqlHook(mysql_conn_id='mysql_default')
    connection.run("""
        SELECT count(*) FROM airflow_bi.monthly_item_sales;
    """, autocommit=True)
    return True
def bulk_load_sql(table_name, **kwargs):
    selected_files = kwargs['ti'].xcom_pull(task_ids='select_files')
    conn = MySqlHook(mysql_conn_id='telmetry_mysql')
    import pandas
    for selected_file in selected_files:
        df = pandas.read_csv(selected_file, sep=",", decimal=".", encoding='utf-8')
        df['wheel'] = df['wheel'].str[2:4]
        df['action'] = df['action'].str[2:4]
        connection = conn.get_conn()
        try:
            cursor = connection.cursor()
            sql = ("insert into " + table_name +
                   " (" + ",".join([str(f) for f in df]) + ") values (" +
                   ",".join(["%s"] * len(df.columns)) + ")")
            print("SQL statement is " + sql)
            for index, row in df.iterrows():
                values = [row[c] for c in df]
                print("inserting values " + str(values))
                cursor.execute(sql, values)
            connection.commit()
        finally:
            connection.close()
    return table_name
def cal_avg(**kwargs):
    # pushes an XCom without a specific target
    hook = MySqlHook(mysql_conn_id='mysql_default', schema='test')
    age = hook.get_first('select round(avg(age),0) from user')
    kwargs['ti'].xcom_push(key='age', value=age)
def execute(self, context):
    logging.info('Executing: ' + str(self.sql_queries))
    mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    mongo_hook = MongoHook(mongo_conn_id=self.mongo_conn_id)

    logging.info("Transferring MySQL query results into MongoDB database.")
    mysql_conn = mysql_hook.get_conn()
    mysql_conn.cursorclass = MySQLdb.cursors.DictCursor
    cursor = mysql_conn.cursor()
    mongo_conn = mongo_hook.get_conn()
    mongo_db = mongo_conn.weather

    if self.mysql_preoperator:
        logging.info("Running MySQL preoperator")
        cursor.execute(self.mysql_preoperator)

    for index, sql in enumerate(self.sql_queries):
        cursor.execute(sql, self.parameters)
        fetched_rows = list(cursor.fetchall())
        mongo_db[self.mongo_collections[index]].insert_many(fetched_rows)

    logging.info("Transfer Done")
def etl_process(**kwargs): logger.info(kwargs["execution_date"]) file_path = FSHook(FILE_CONNECTION_NAME).get_path() filename = 'time_series_covid19_deaths_global.csv' mysql_connection = MySqlHook( mysql_conn_id=CONNECTION_DB_NAME).get_sqlalchemy_engine() full_path = f'{file_path}/{filename}' logger.info(full_path) df = pd.read_csv(full_path) df = pd.melt(df, id_vars=['Lat', 'Long', 'Province/State', 'Country/Region'], var_name="RegDate", value_name="Count") df = df[df["Count"] > 0] df = df.rename(columns={ 'Province/State': 'State', 'Country/Region': 'Country' }) df['RegDate'] = pd.to_datetime(df['RegDate']) df['Type'] = 'D' with mysql_connection.begin() as connection: connection.execute("DELETE FROM Covid.Cases WHERE Type='D'") df.to_sql('Cases', con=connection, schema='Covid', if_exists='append', index=False) os.remove(full_path) logger.info(f"Rows inserted confirmed {len(df.index)}")
def execute(self, context):
    self.log.info('Executing: %s', self.sql)
    hook = MySqlHook(mysql_conn_id=self.mysql_conn_id, schema=self.database)
    hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
def build_aggregrate():
    connection = MySqlHook(mysql_conn_id='mysql_default')
    sql = '''
        INSERT INTO `swapi_data`.`swapi_people_aggregate` (film, birth_year, name, film_name)
        SELECT film,
               max(birth_year_number) as birth_year,
               (
                   SELECT name FROM swapi_data.swapi_people
                   WHERE film = t.film
                   ORDER BY birth_year_number DESC LIMIT 0,1
               ) as name,
               film_name
        FROM swapi_data.swapi_people t
        GROUP BY film, film_name;
    '''
    connection.run(sql, autocommit=True, parameters=())
    return True
def set_sign_users(doc, context): db = MySqlHook(mysql_conn_id='mariadb', schema="dapp") # conn = db.get_conn() # cursor = conn.cursor() sql = f""" insert into sign_users(instance_id, sign_area_id, sequence, user_culture, user_id, user_name, responsibility, position, class_position, host_address, reserved_date, delay_time, is_deputy, is_comment) values(%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s) """ # result = cursor.execute(sql) # logging.info(f'cursor result: {result}') db.run(sql, autocommit=True, parameters=[ doc.find("instance_id").text, doc.find('sign_area_id').text, doc.find('sequence').text, doc.find('user_culture').text, doc.find('user_id').text, doc.find('user_name').text, doc.find('responsibility').text, doc.find('position').text, doc.find('class_position').text, doc.find('host_address').text, doc.find('reserved_date').text, doc.find('delay_time').text, doc.find('is_deputy').text, doc.find('is_comment').text ])
def get_groups(group_id, context): db = MySqlHook(mysql_conn_id='mariadb', schema="dbo") sql = f""" select group_id, parent_id, name, culture, group_code, sequence, is_childs, depth, is_display, email, interface_id, remark, created_date, modified_date from groups where group_id = %s """ task = {} rows = db.get_records(sql, parameters=[group_id]) for row in rows: model = { 'group_id': row[0], 'parent_id': row[1], 'name': row[2], 'culture': row[3], 'group_code': row[4], 'sequence': row[5], 'is_childs': row[6], 'depth': row[7], 'is_display': row[8], 'email': row[9], 'remark': row[10], 'created_date': str(row[11]), 'modified_date': str(row[12]) } task = model context['ti'].xcom_push(key=GROUPS, value=task) return task
def get_sign_activity(instance_id, context): db = MySqlHook(mysql_conn_id='mariadb', schema="dapp") sql = f""" select instance_id, sign_area_id, sequence, sign_area, sign_section, sign_position, sign_action, is_comment, is_executed, group_id, user_id, host_address from sign_activity where instance_id = %s """ tasks = {} tasks[SIGN_ACTIVITY] = [] rows = db.get_records(sql, parameters=[instance_id]) for row in rows: model = { 'instance_id': row[0], 'sign_area_id': row[1], 'sequence': row[2], 'sign_area': row[3], 'sign_section': row[4], 'sign_position': row[5], 'sign_action': row[6], 'is_comment': row[7], 'is_executed': row[8], 'group_id': row[9], 'user_id': row[10], 'host_address': row[11] } tasks[SIGN_ACTIVITY].append(model) context['ti'].xcom_push(key=SIGN_ACTIVITY, value=tasks[SIGN_ACTIVITY]) return list(tasks.values())
def get_signers(instance_id, context): db = MySqlHook(mysql_conn_id='mariadb', schema="dapp") sql = f""" select instance_id, sign_area_id, sequence, sub_instance_id, sign_section, sign_position, sign_action, is_executed, group_culture, group_id, group_name, created_date, received_date, approved_date from signers where instance_id = %s """ tasks = {} rows = db.get_records(sql, parameters=[instance_id]) tasks[SIGNERS] = [] for row in rows: model = { 'instance_id': row[0], 'sign_area_id': row[1], 'sequence': row[2], 'sub_instance_id': row[3], 'sign_section': row[4], 'sign_position': row[5], 'sign_action': row[6], 'is_executed': row[7], 'group_culture': row[8], 'group_id': row[9], 'group_name': row[10], 'created_date': str(row[11]), 'received_date': str(row[12]), 'approved_date': str(row[13]) } tasks[SIGNERS].append(model) context['ti'].xcom_push(key=SIGNERS, value=tasks[SIGNERS]) return list(tasks.values())
def branch_download_func(**context):
    biowardrobe_uid = context['dag_run'].conf['biowardrobe_uid'] \
        if 'biowardrobe_uid' in context['dag_run'].conf else None
    if not biowardrobe_uid:
        raise Exception('biowardrobe_id must be provided')

    data = {}
    mysql = MySqlHook(mysql_conn_id=biowardrobe_connection_id)
    with closing(mysql.get_conn()) as conn:
        with closing(conn.cursor()) as cursor:
            data = get_biowardrobe_data(cursor=cursor, biowardrobe_uid=biowardrobe_uid)

    _logger.info("Data: %s", data)
    context['ti'].xcom_push(key='url', value=data['url'])
    context['ti'].xcom_push(key='uid', value=data['uid'])
    context['ti'].xcom_push(key='upload', value=data['upload'])
    context['ti'].xcom_push(key='output_folder', value=data['output_folder'])
    context['ti'].xcom_push(key='email', value=data['email'])

    if re.match("^https?://|^s?ftp://|^-", data['url']):
        return "download_aria2"
    if re.match("^(GSM|SR[ARX])[0-9]+$", data['url']):
        return "download_sra"
    return "download_local"
def integration_procces(**kwargs):
    db_connection = MySqlHook('airflow_db').get_sqlalchemy_engine()

    with db_connection.begin() as transaction:
        df_C = pd.read_sql_table("confirmed", con=transaction, schema='covid')
        df_R = pd.read_sql_table("recovered", con=transaction, schema='covid')
        df_D = pd.read_sql_table("deaths", con=transaction, schema='covid')

    df_C = df_C.drop(columns=['id'])
    df_R = df_R.drop(columns=['id', 'lat', 'long'])
    df_D = df_D.drop(columns=['id', 'lat', 'long'])

    df_C['province_state'] = df_C['province_state'].fillna('')
    df_R['province_state'] = df_R['province_state'].fillna('')
    df_D['province_state'] = df_D['province_state'].fillna('')

    df_C = df_C.rename(columns=COLUMNS_C)
    df_R = df_R.rename(columns=COLUMNS_R)
    df_D = df_D.rename(columns=COLUMNS_D)

    df = pd.merge(df_C, df_R, on=['country_region', 'province_state', 'event_date'])
    df = pd.merge(df, df_D, on=['country_region', 'province_state', 'event_date'])

    df['mortality_rate'] = df['d_cases'] / df['c_cases']
    df['recovery_rate'] = df['r_cases'] / df['c_cases']

    # df_final = df[COLUMNS_VIEW]
    df_final = df

    with db_connection.begin() as transaction:
        transaction.execute("DELETE FROM covid.cases_data WHERE 1=1")
        df_final.to_sql("cases_data", con=transaction, schema="covid",
                        if_exists="append", index=False)
def get_mysql_cursor(self):
    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    logging.info("Successfully connected to MySQL database %s" % (self.mysql_conn_id))
    mysql_conn = mysql.get_conn()
    mysql_cursor = mysql_conn.cursor()
    return mysql_conn, mysql_cursor
def load_rows_to_destination(df, conn_id_dest, table_name_dest, load_all):
    # extract every table as a data frame and convert it to object type -> NaT values then are
    # treated as null objects and can be converted to None
    df = df.astype(object)
    # convert nulls to None (needed in MySQL upload)
    logging.debug("Convert NaN, NaT -> None")
    df = df.where(pd.notnull(df), None)
    target_fields = list(df.keys())
    logging.info("Column fields from source: {}".format(target_fields))
    logging.info("Row Count from chunk source: '{}'".format(df.shape[0]))
    if not load_all:
        # just load the part that has updated_at > last_destination_updated_at
        mysql_hook_load = hooks.MyMysqlHook(conn_id_dest)
        # replace should be false, but cannot be sure that we are not repeating values
        mysql_hook_load.insert_update_on_duplicate_rows(table_name_dest,
                                                        rows=df.values.tolist(),
                                                        columns=target_fields,
                                                        commit_every=1000)
    else:
        # load everything replacing any value if the same PK is found
        mysql_hook_load = MySqlHook(conn_id_dest)
        mysql_hook_load.insert_rows(table_name_dest, df.values.tolist(),
                                    target_fields=target_fields,
                                    commit_every=1000, replace=True)
def execute(self, context): hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id) mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id) self.log.info("Dumping MySQL query results to local file") conn = mysql.get_conn() cursor = conn.cursor() cursor.execute(self.sql) with NamedTemporaryFile("wb") as f: csv_writer = csv.writer(f, delimiter=self.delimiter, quoting=self.quoting, quotechar=self.quotechar, escapechar=self.escapechar, encoding="utf-8") field_dict = OrderedDict() for field in cursor.description: field_dict[field[0]] = self.type_map(field[1]) csv_writer.writerows(cursor) f.flush() cursor.close() conn.close() self.log.info("Loading file into Hive") hive.load_file(f.name, self.hive_table, field_dict=field_dict, create=self.create, partition=self.partition, delimiter=self.delimiter, recreate=self.recreate, tblproperties=self.tblproperties)
def execute(self, context):
    logging.info('Executing: ' + str(self.sql))
    hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
def execute(self, context): log.info("이찬호") db = MySqlHook(mysql_conn_id=self.db_conn_id, schema=self.db_schema) sql = """ select o.id, o.name, o.desc from test o """ # initialize the task list buckets tasks = {} index = 0 rows = db.get_records(sql) for row in rows: index += 1 tasks[f'get_workflow_{index}'] = [] resultCounter = 0 for row in rows: resultCounter += 1 bucket = (resultCounter % index) model = {'id': str(row[0]), 'name': str(row[1])} tasks[f'get_workflow_{bucket}'].append(model) # Push the order lists into xcom for task in tasks: if len(tasks[task]) > 0: logging.info(f'Task {task} has {len(tasks[task])} orders.') context['ti'].xcom_push(key=task, value=tasks[task]) return list(tasks.values())
def execute(self, context): hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id) self.log.info("Extracting data from Hive: %s", self.sql) if self.bulk_load: tmpfile = NamedTemporaryFile() hive.to_csv(self.sql, tmpfile.name, delimiter='\t', lineterminator='\n', output_header=False, hive_conf=context_to_airflow_vars(context)) else: results = hive.get_records(self.sql) mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id) if self.mysql_preoperator: self.log.info("Running MySQL preoperator") mysql.run(self.mysql_preoperator) self.log.info("Inserting rows into MySQL") if self.bulk_load: mysql.bulk_load(table=self.mysql_table, tmp_file=tmpfile.name) tmpfile.close() else: mysql.insert_rows(table=self.mysql_table, rows=results) if self.mysql_postoperator: self.log.info("Running MySQL postoperator") mysql.run(self.mysql_postoperator) self.log.info("Done.")
def execute(self, context): hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id) mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id) self.log.info("Dumping MySQL query results to local file") conn = mysql.get_conn() cursor = conn.cursor() cursor.execute(self.sql) with NamedTemporaryFile("wb") as f: csv_writer = csv.writer(f, delimiter=self.delimiter, encoding="utf-8") field_dict = OrderedDict() for field in cursor.description: field_dict[field[0]] = self.type_map(field[1]) csv_writer.writerows(cursor) f.flush() cursor.close() conn.close() self.log.info("Loading file into Hive") hive.load_file( f.name, self.hive_table, field_dict=field_dict, create=self.create, partition=self.partition, delimiter=self.delimiter, recreate=self.recreate, tblproperties=self.tblproperties)
def partitions(self): """ Retrieve table partitions """ schema, table = request.args.get("table").split('.') sql = """ SELECT a.PART_NAME, a.CREATE_TIME, c.LOCATION, c.IS_COMPRESSED, c.INPUT_FORMAT, c.OUTPUT_FORMAT FROM PARTITIONS a JOIN TBLS b ON a.TBL_ID = b.TBL_ID JOIN DBS d ON b.DB_ID = d.DB_ID JOIN SDS c ON a.SD_ID = c.SD_ID WHERE b.TBL_NAME like '{table}' AND d.NAME like '{schema}' ORDER BY PART_NAME DESC """.format(table=table, schema=schema) hook = MySqlHook(METASTORE_MYSQL_CONN_ID) df = hook.get_pandas_df(sql) return df.to_html( classes="table table-striped table-bordered table-hover", index=False, na_rep='', )
def save_data_into_db():
    mysql_hook = MySqlHook(mysql_conn_id='Covid19')
    with open('data.json') as f:
        data = json.load(f)

    insert = """
        INSERT INTO daily_covid19_reports (
            confirmed, recovered, hospitalized, deaths,
            new_confirmed, new_recovered, new_hospitalized, new_deaths,
            update_date, source, dev_by, server_by)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s);
    """
    mysql_hook.run(insert, parameters=(data['Confirmed'],
                                       data['Recovered'],
                                       data['Hospitalized'],
                                       data['Deaths'],
                                       data['NewConfirmed'],
                                       data['NewRecovered'],
                                       data['NewHospitalized'],
                                       data['NewDeaths'],
                                       datetime.strptime(data['UpdateDate'], '%d/%m/%Y %H:%M'),
                                       data['Source'],
                                       data['DevBy'],
                                       data['SeverBy']))
def index(self):
    end_date = datetime.now().date()
    start_date = end_date - timedelta(days=1)
    sql = """
        select
            a.dag_id,
            a.state,
            min(a.start_date) as start_date,
            max(a.end_date) as end_date,
            max(a.end_date)-min(a.start_date) as duration,
            b.job_type,
            a.job_id
        from task_instance as a
        join job as b ON a.job_id = b.id
        where a.start_date >= "{start_date}" and a.start_date < "{end_date}"
            and a.state != 'failed'
        group by a.dag_id, a.job_id
        order by start_date;
    """.format(start_date=start_date, end_date=end_date)
    h = MySqlHook(METASTORE_MYSQL_CONN_ID)
    rows = h.get_records(sql)

    tasks = []
    taskNames = []
    name_set = set("")
    time_format = "%Y-%m-%dT%H:%M:%S"
    for row in rows:
        dag_id = row[0]
        state = row[1]
        start_date = row[2]
        end_date = row[3]
        if not end_date:
            end_date = datetime.now()
        duration = str(row[4])
        task = {
            'status': state,
            'taskName': dag_id,
            'startDate': time.mktime(start_date.timetuple()) * 1000,
            'endDate': time.mktime(end_date.timetuple()) * 1000,
            'executionDate': start_date.strftime(time_format),
            'isoStart': start_date.strftime(time_format),
            'isoEnd': end_date.strftime(time_format),
            'duration': duration
        }
        taskNames.append(dag_id)
        name_set.add(dag_id)
        tasks.append(task)

    data = {
        'height': 20 * len(name_set),
        'tasks': tasks,
        'taskNames': taskNames,
        'taskStatus': {'success': 'success'}
    }
    return self.render("scheduler_browser/gantt.html", data=data)
def objects(self):
    where_clause = ''
    if DB_WHITELIST:
        dbs = ",".join(["'" + db + "'" for db in DB_WHITELIST])
        where_clause = "AND b.name IN ({})".format(dbs)
    if DB_BLACKLIST:
        dbs = ",".join(["'" + db + "'" for db in DB_BLACKLIST])
        where_clause = "AND b.name NOT IN ({})".format(dbs)
    sql = """
        SELECT CONCAT(b.NAME, '.', a.TBL_NAME), TBL_TYPE
        FROM TBLS a
        JOIN DBS b ON a.DB_ID = b.DB_ID
        WHERE
            a.TBL_NAME NOT LIKE '%tmp%' AND
            a.TBL_NAME NOT LIKE '%temp%' AND
            b.NAME NOT LIKE '%tmp%' AND
            b.NAME NOT LIKE '%temp%'
        {where_clause}
        LIMIT {LIMIT};
    """.format(where_clause=where_clause, LIMIT=TABLE_SELECTOR_LIMIT)
    h = MySqlHook(METASTORE_MYSQL_CONN_ID)
    d = [{'id': row[0], 'text': row[0]} for row in h.get_records(sql)]
    return json.dumps(d)
def _query_mysql(self): """ Queries mysql and returns a cursor to the results. """ mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id) conn = mysql.get_conn() df = pd.read_sql(sql=self.sql, con=conn, index_col=self.index_col) return df
def execute(self, context):
    self.log.info('Executing: %s', self.sql)
    hook = MySqlHook(mysql_conn_id=self.mysql_conn_id, schema=self.database)
    hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
def execute(self, context): hook = MySqlHook(mysql_conn_id=self.mysql_conn_id, schema=self.database) sql = "select first_name from authors" result = hook.get_first(sql) message = "Hello {}".format(result['first_name']) print(message) return message
def _query_mysql(self): """ Queries mysql and returns a cursor to the results. """ mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id) conn = mysql.get_conn() cursor = conn.cursor() cursor.execute(self.sql) return cursor
def test_mysql_hook_test_bulk_dump(self): from airflow.hooks.mysql_hook import MySqlHook hook = MySqlHook('airflow_ci') priv = hook.get_first("SELECT @@global.secure_file_priv") if priv and priv[0]: # Confirm that no error occurs hook.bulk_dump("INFORMATION_SCHEMA.TABLES", os.path.join(priv[0], "TABLES")) else: self.skipTest("Skip test_mysql_hook_test_bulk_load " "since file output is not permitted")
def execute(self, context): hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id) logging.info("Extracting data from Hive") logging.info(self.sql) if self.bulk_load: tmpfile = NamedTemporaryFile() hive.to_csv(self.sql, tmpfile.name, delimiter='\t', lineterminator='\n', output_header=False) else: results = hive.get_records(self.sql) mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id) if self.mysql_preoperator: logging.info("Running MySQL preoperator") mysql.run(self.mysql_preoperator) logging.info("Inserting rows into MySQL") if self.bulk_load: mysql.bulk_load(table=self.mysql_table, tmp_file=tmpfile.name) tmpfile.close() else: mysql.insert_rows(table=self.mysql_table, rows=results) if self.mysql_postoperator: logging.info("Running MySQL postoperator") mysql.run(self.mysql_postoperator) logging.info("Done.")
def execute(self, context):
    presto = PrestoHook(presto_conn_id=self.presto_conn_id)
    self.log.info("Extracting data from Presto: %s", self.sql)
    results = presto.get_records(self.sql)

    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    if self.mysql_preoperator:
        self.log.info("Running MySQL preoperator")
        self.log.info(self.mysql_preoperator)
        mysql.run(self.mysql_preoperator)

    self.log.info("Inserting rows into MySQL")
    mysql.insert_rows(table=self.mysql_table, rows=results)
def test_mysql_hook_test_bulk_dump_mock(self, mock_get_conn):
    mock_execute = mock.MagicMock()
    mock_get_conn.return_value.cursor.return_value.execute = mock_execute

    from airflow.hooks.mysql_hook import MySqlHook
    hook = MySqlHook('airflow_ci')
    table = "INFORMATION_SCHEMA.TABLES"
    tmp_file = "/path/to/output/file"
    hook.bulk_dump(table, tmp_file)

    from airflow.utils.tests import assertEqualIgnoreMultipleSpaces
    mock_execute.assert_called_once()
    query = """
        SELECT * INTO OUTFILE '{tmp_file}'
        FROM {table}
    """.format(tmp_file=tmp_file, table=table)
    assertEqualIgnoreMultipleSpaces(self, mock_execute.call_args[0][0], query)
def test_mysql_to_hive_type_conversion(self, mock_load_file):
    mysql_conn_id = 'airflow_ci'
    mysql_table = 'test_mysql_to_hive'

    from airflow.hooks.mysql_hook import MySqlHook
    m = MySqlHook(mysql_conn_id)

    try:
        with m.get_conn() as c:
            c.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
            c.execute("""
                CREATE TABLE {} (
                    c0 TINYINT,
                    c1 SMALLINT,
                    c2 MEDIUMINT,
                    c3 INT,
                    c4 BIGINT
                )
            """.format(mysql_table))

        from airflow.operators.mysql_to_hive import MySqlToHiveTransfer
        t = MySqlToHiveTransfer(
            task_id='test_m2h',
            mysql_conn_id=mysql_conn_id,
            hive_cli_conn_id='beeline_default',
            sql="SELECT * FROM {}".format(mysql_table),
            hive_table='test_mysql_to_hive',
            dag=self.dag)
        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

        mock_load_file.assert_called_once()
        d = OrderedDict()
        d["c0"] = "SMALLINT"
        d["c1"] = "INT"
        d["c2"] = "INT"
        d["c3"] = "BIGINT"
        d["c4"] = "DECIMAL(38,0)"
        self.assertEqual(mock_load_file.call_args[1]["field_dict"], d)
    finally:
        with m.get_conn() as c:
            c.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
def test_mysql_hook_test_bulk_load(self): records = ("foo", "bar", "baz") import tempfile with tempfile.NamedTemporaryFile() as t: t.write("\n".join(records).encode('utf8')) t.flush() from airflow.hooks.mysql_hook import MySqlHook h = MySqlHook('airflow_ci') with h.get_conn() as c: c.execute(""" CREATE TABLE IF NOT EXISTS test_airflow ( dummy VARCHAR(50) ) """) c.execute("TRUNCATE TABLE test_airflow") h.bulk_load("test_airflow", t.name) c.execute("SELECT dummy FROM test_airflow") results = tuple(result[0] for result in c.fetchall()) self.assertEqual(sorted(results), sorted(records))
def index(self): sql = """ SELECT a.name as db, db_location_uri as location, count(1) as object_count, a.desc as description FROM DBS a JOIN TBLS b ON a.DB_ID = b.DB_ID GROUP BY a.name, db_location_uri, a.desc """.format(**locals()) h = MySqlHook(METASTORE_MYSQL_CONN_ID) df = h.get_pandas_df(sql) df.db = ( '<a href="/admin/metastorebrowserview/db/?db=' + df.db + '">' + df.db + '</a>') table = df.to_html( classes="table table-striped table-bordered table-hover", index=False, escape=False, na_rep='',) return self.render( "metastore_browser/dbs.html", table=table)
def setUp(self):
    super(TestMySqlHookConn, self).setUp()

    self.connection = Connection(
        login='******',
        password='******',
        host='host',
        schema='schema',
    )

    self.db_hook = MySqlHook()
    self.db_hook.get_connection = mock.Mock()
    self.db_hook.get_connection.return_value = self.connection
def execute(self, context=None):
    metastore = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id)
    table = metastore.get_table(table_name=self.table)
    field_types = {col.name: col.type for col in table.sd.cols}

    exprs = {
        ('', 'count'): 'COUNT(*)'
    }
    for col, col_type in list(field_types.items()):
        d = {}
        if self.assignment_func:
            d = self.assignment_func(col, col_type)
            if d is None:
                d = self.get_default_exprs(col, col_type)
        else:
            d = self.get_default_exprs(col, col_type)
        exprs.update(d)
    exprs.update(self.extra_exprs)
    exprs = OrderedDict(exprs)
    exprs_str = ",\n        ".join([
        v + " AS " + k[0] + '__' + k[1]
        for k, v in exprs.items()])

    where_clause = ["{} = '{}'".format(k, v) for k, v in self.partition.items()]
    where_clause = " AND\n        ".join(where_clause)
    sql = "SELECT {exprs_str} FROM {table} WHERE {where_clause};".format(
        exprs_str=exprs_str, table=self.table, where_clause=where_clause)

    presto = PrestoHook(presto_conn_id=self.presto_conn_id)
    self.log.info('Executing SQL check: %s', sql)
    row = presto.get_first(hql=sql)
    self.log.info("Record: %s", row)
    if not row:
        raise AirflowException("The query returned None")

    part_json = json.dumps(self.partition, sort_keys=True)

    self.log.info("Deleting rows from previous runs if they exist")
    mysql = MySqlHook(self.mysql_conn_id)
    sql = """
        SELECT 1 FROM hive_stats
        WHERE
            table_name='{table}' AND
            partition_repr='{part_json}' AND
            dttm='{dttm}'
        LIMIT 1;
    """.format(table=self.table, part_json=part_json, dttm=self.dttm)
    if mysql.get_records(sql):
        sql = """
            DELETE FROM hive_stats
            WHERE
                table_name='{table}' AND
                partition_repr='{part_json}' AND
                dttm='{dttm}';
        """.format(table=self.table, part_json=part_json, dttm=self.dttm)
        mysql.run(sql)

    self.log.info("Pivoting and loading cells into the Airflow db")
    rows = [(self.ds, self.dttm, self.table, part_json) + (r[0][0], r[0][1], r[1])
            for r in zip(exprs, row)]
    mysql.insert_rows(
        table='hive_stats',
        rows=rows,
        target_fields=[
            'ds',
            'dttm',
            'table_name',
            'partition_repr',
            'col',
            'metric',
            'value',
        ]
    )
def test_mysql_to_hive_verify_loaded_values(self):
    mysql_conn_id = 'airflow_ci'
    mysql_table = 'test_mysql_to_hive'
    hive_table = 'test_mysql_to_hive'

    from airflow.hooks.mysql_hook import MySqlHook
    m = MySqlHook(mysql_conn_id)

    try:
        minmax = (
            255,
            65535,
            16777215,
            4294967295,
            18446744073709551615,
            -128,
            -32768,
            -8388608,
            -2147483648,
            -9223372036854775808
        )

        with m.get_conn() as c:
            c.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
            c.execute("""
                CREATE TABLE {} (
                    c0 TINYINT UNSIGNED,
                    c1 SMALLINT UNSIGNED,
                    c2 MEDIUMINT UNSIGNED,
                    c3 INT UNSIGNED,
                    c4 BIGINT UNSIGNED,
                    c5 TINYINT,
                    c6 SMALLINT,
                    c7 MEDIUMINT,
                    c8 INT,
                    c9 BIGINT
                )
            """.format(mysql_table))
            c.execute("""
                INSERT INTO {} VALUES (
                    {}, {}, {}, {}, {}, {}, {}, {}, {}, {}
                )
            """.format(mysql_table, *minmax))

        from airflow.operators.mysql_to_hive import MySqlToHiveTransfer
        t = MySqlToHiveTransfer(
            task_id='test_m2h',
            mysql_conn_id=mysql_conn_id,
            hive_cli_conn_id='beeline_default',
            sql="SELECT * FROM {}".format(mysql_table),
            hive_table=hive_table,
            recreate=True,
            delimiter=",",
            dag=self.dag)
        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

        from airflow.hooks.hive_hooks import HiveServer2Hook
        h = HiveServer2Hook()
        r = h.get_records("SELECT * FROM {}".format(hive_table))
        self.assertEqual(r[0], minmax)
    finally:
        with m.get_conn() as c:
            c.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
def execute(self, context):
    vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)
    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

    tmpfile = None
    result = None
    selected_columns = []
    count = 0

    with closing(vertica.get_conn()) as conn:
        with closing(conn.cursor()) as cursor:
            cursor.execute(self.sql)
            selected_columns = [d.name for d in cursor.description]

            if self.bulk_load:
                tmpfile = NamedTemporaryFile("w")

                self.log.info(
                    "Selecting rows from Vertica to local file %s...",
                    tmpfile.name)
                self.log.info(self.sql)

                csv_writer = csv.writer(tmpfile, delimiter='\t', encoding='utf-8')
                for row in cursor.iterate():
                    csv_writer.writerow(row)
                    count += 1

                tmpfile.flush()
            else:
                self.log.info("Selecting rows from Vertica...")
                self.log.info(self.sql)

                result = cursor.fetchall()
                count = len(result)

            self.log.info("Selected rows from Vertica %s", count)

    if self.mysql_preoperator:
        self.log.info("Running MySQL preoperator...")
        mysql.run(self.mysql_preoperator)

    try:
        if self.bulk_load:
            self.log.info("Bulk inserting rows into MySQL...")
            with closing(mysql.get_conn()) as conn:
                with closing(conn.cursor()) as cursor:
                    cursor.execute("LOAD DATA LOCAL INFILE '%s' INTO "
                                   "TABLE %s LINES TERMINATED BY '\r\n' (%s)" %
                                   (tmpfile.name,
                                    self.mysql_table,
                                    ", ".join(selected_columns)))
                    conn.commit()
            tmpfile.close()
        else:
            self.log.info("Inserting rows into MySQL...")
            mysql.insert_rows(table=self.mysql_table,
                              rows=result,
                              target_fields=selected_columns)
        self.log.info("Inserted rows into MySQL %s", count)
    except (MySQLdb.Error, MySQLdb.Warning):
        self.log.info("Inserted rows into MySQL 0")
        raise

    if self.mysql_postoperator:
        self.log.info("Running MySQL postoperator...")
        mysql.run(self.mysql_postoperator)

    self.log.info("Done")
class TestMySqlHookConn(unittest.TestCase):

    def setUp(self):
        super(TestMySqlHookConn, self).setUp()

        self.connection = Connection(
            login='******',
            password='******',
            host='host',
            schema='schema',
        )

        self.db_hook = MySqlHook()
        self.db_hook.get_connection = mock.Mock()
        self.db_hook.get_connection.return_value = self.connection

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn(self, mock_connect):
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['user'], 'login')
        self.assertEqual(kwargs['passwd'], 'password')
        self.assertEqual(kwargs['host'], 'host')
        self.assertEqual(kwargs['db'], 'schema')

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_port(self, mock_connect):
        self.connection.port = 3307
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['port'], 3307)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_charset(self, mock_connect):
        self.connection.extra = json.dumps({'charset': 'utf-8'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['charset'], 'utf-8')
        self.assertEqual(kwargs['use_unicode'], True)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_cursor(self, mock_connect):
        self.connection.extra = json.dumps({'cursor': 'sscursor'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['cursorclass'], MySQLdb.cursors.SSCursor)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_local_infile(self, mock_connect):
        self.connection.extra = json.dumps({'local_infile': True})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['local_infile'], 1)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_con_unix_socket(self, mock_connect):
        self.connection.extra = json.dumps({'unix_socket': "/tmp/socket"})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['unix_socket'], '/tmp/socket')

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_ssl_as_dictionary(self, mock_connect):
        self.connection.extra = json.dumps({'ssl': SSL_DICT})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['ssl'], SSL_DICT)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_ssl_as_string(self, mock_connect):
        self.connection.extra = json.dumps({'ssl': json.dumps(SSL_DICT)})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['ssl'], SSL_DICT)
def partitions(self): schema, table = request.args.get("table").split('.') sql = """ SELECT a.PART_NAME, a.CREATE_TIME, c.LOCATION, c.IS_COMPRESSED, c.INPUT_FORMAT, c.OUTPUT_FORMAT FROM PARTITIONS a JOIN TBLS b ON a.TBL_ID = b.TBL_ID JOIN DBS d ON b.DB_ID = d.DB_ID JOIN SDS c ON a.SD_ID = c.SD_ID WHERE b.TBL_NAME like '{table}' AND d.NAME like '{schema}' ORDER BY PART_NAME DESC """.format(**locals()) h = MySqlHook(METASTORE_MYSQL_CONN_ID) df = h.get_pandas_df(sql) return df.to_html( classes="table table-striped table-bordered table-hover", index=False, na_rep='',)