def execute(self, context):
    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    logging.info("Extracting data from Hive")
    hive_table = "druid." + context["task_instance_key_str"].replace(".", "_")
    sql = self.sql.strip().strip(";")
    hql = """\
    set mapred.output.compress=false;
    set hive.exec.compress.output=false;
    DROP TABLE IF EXISTS {hive_table};
    CREATE TABLE {hive_table}
    ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t'
    STORED AS TEXTFILE
    TBLPROPERTIES ('serialization.null.format' = '')
    AS
    {sql}
    """.format(**locals())
    logging.info("Running command:\n {}".format(hql))
    hive.run_cli(hql)

    m = HiveMetastoreHook(self.metastore_conn_id)
    t = m.get_table(hive_table)

    columns = [col.name for col in t.sd.cols]

    hdfs_uri = m.get_table(hive_table).sd.location
    pos = hdfs_uri.find("/user")
    static_path = hdfs_uri[pos:]

    schema, table = hive_table.split(".")

    druid = DruidHook(druid_ingest_conn_id=self.druid_ingest_conn_id)
    logging.info("Inserting rows into Druid")
    logging.info("HDFS path: " + static_path)

    try:
        druid.load_from_hdfs(
            datasource=self.druid_datasource,
            intervals=self.intervals,
            static_path=static_path,
            ts_dim=self.ts_dim,
            columns=columns,
            num_shards=self.num_shards,
            target_partition_size=self.target_partition_size,
            query_granularity=self.query_granularity,
            segment_granularity=self.segment_granularity,
            metric_spec=self.metric_spec,
            hadoop_dependency_coordinates=self.hadoop_dependency_coordinates,
        )
        logging.info("Load seems to have succeeded!")
    finally:
        logging.info("Cleaning up by dropping the temp "
                     "Hive table {}".format(hive_table))
        hql = "DROP TABLE IF EXISTS {}".format(hive_table)
        hive.run_cli(hql)

def merge_pre_hi_data_task(hive_db, hive_all_hi_table_name, hive_hi_table_name,
                           is_must_have_data, pt, now_hour, pre_hour_day, pre_hour, **kwargs):
    sqoopSchema = SqoopSchemaUpdate()
    hive_columns = sqoopSchema.get_hive_column_name(hive_db, hive_all_hi_table_name)

    hql = ADD_HI_SQL.format(
        db_name=hive_db,
        hive_all_hi_table_name=hive_all_hi_table_name,
        hive_hi_table_name=hive_hi_table_name,
        pt=pt,
        now_hour=now_hour,
        pre_hour_day=pre_hour_day,
        pre_hour=pre_hour,
        columns=',\n'.join(hive_columns))

    hive_hook = HiveCliHook()

    # Read the SQL
    logging.info('Executing: %s', hql)

    # Run it through Hive
    hive_hook.run_cli(hql)

    # Generate the _SUCCESS flag
    """
    First argument  "true": the data directory has a country_code partition; "false": it does not
    Second argument "true": create _SUCCESS only when data exists; "false": create _SUCCESS even without data
    """
    TaskTouchzSuccess().countries_touchz_success(
        pt, hive_db, hive_all_hi_table_name,
        ALL_HI_OSS_PATH % hive_all_hi_table_name,
        "false", is_must_have_data, now_hour)

def execute(self, context):
    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

    self.log.info("Dumping MySQL query results to local file")
    conn = mysql.get_conn()
    cursor = conn.cursor()
    cursor.execute(self.sql)
    with NamedTemporaryFile("wb") as f:
        csv_writer = csv.writer(f, delimiter=self.delimiter,
                                quoting=self.quoting,
                                quotechar=self.quotechar,
                                escapechar=self.escapechar,
                                encoding="utf-8")
        field_dict = OrderedDict()
        for field in cursor.description:
            field_dict[field[0]] = self.type_map(field[1])
        csv_writer.writerows(cursor)
        f.flush()
        cursor.close()
        conn.close()
        self.log.info("Loading file into Hive")
        hive.load_file(
            f.name,
            self.hive_table,
            field_dict=field_dict,
            create=self.create,
            partition=self.partition,
            delimiter=self.delimiter,
            recreate=self.recreate,
            tblproperties=self.tblproperties)

def execution_data_task_id(ds, **kargs):
    hive_hook = HiveCliHook()

    TaskTouchzSuccess().del_path(ds, db_name, table_name, hdfs_path, "true", "true")

    # Read the SQL
    _sql = dim_oride_city_sql_task(ds)

    logging.info('Executing: %s', _sql)

    # Run it through Hive
    hive_hook.run_cli(_sql)

    # Circuit breaker: the row count must not be 0
    check_key_data_cnt_task(ds)

    # Circuit breaker on key data
    check_key_data_task(ds)

    # Generate the _SUCCESS flag
    """
    First argument  "true": the data directory has a country_code partition; "false": it does not
    Second argument "true": create _SUCCESS only when data exists; "false": create _SUCCESS even without data
    """
    TaskTouchzSuccess().countries_touchz_success(ds, db_name, table_name, hdfs_path, "true", "true")

def execution_data_task_id(ds, dag, **kwargs):
    v_date = kwargs.get('v_execution_date')
    v_day = kwargs.get('v_execution_day')
    v_hour = kwargs.get('v_execution_hour')

    hive_hook = HiveCliHook()

    """
    # Helper functions
        ALTER statement:   alter_partition()
        Drop a partition:  delete_partition()
        Produce _SUCCESS:  touchz_success()

    # Parameters
        is_countries_online   -- whether the multi-country business is enabled, default "true" (enabled)
        db_name               -- name of the Hive database
        table_name            -- name of the Hive table
        data_oss_path         -- OSS path of the data directory
        is_country_partition  -- whether there is a country_code partition, default "true"
        is_result_force_exist -- whether output is mandatory; default "true" (create _SUCCESS only when
                                 data exists), "false" creates _SUCCESS even without data
        execute_time          -- execution time of the current script (%Y-%m-%d %H:%M:%S)
        is_hour_task          -- whether the hourly task is enabled, default "false"
        frame_type            -- template type (only effective when is_hour_task is "true"):
                                 "utc" produces partitions in UTC time, "local" in local time; default "utc"

    # Read the SQL
        %_sql(ds, v_hour)
    """

    args = [
        {
            "dag": dag,
            "is_countries_online": "true",
            "db_name": db_name,
            "table_name": table_name,
            "data_oss_path": hdfs_path,
            "is_country_partition": "true",
            "is_result_force_exist": "true",
            "execute_time": v_date,
            "is_hour_task": "true",
            "frame_type": "local"
        }
    ]

    cf = CountriesPublicFrame_dev(args)

    # Read the SQL
    _sql = "\n" + cf.alter_partition() + "\n" + dim_opay_terminal_base_hf_sql_task(ds, v_date)

    logging.info('Executing: %s', _sql)

    # Run it through Hive
    hive_hook.run_cli(_sql)

    # Produce the _SUCCESS flag
    cf.touchz_success()

def execute(self, context):
    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)

    self.log.info("Dumping Vertica query results to local file")
    conn = vertica.get_conn()
    cursor = conn.cursor()
    cursor.execute(self.sql)
    with NamedTemporaryFile("w") as f:
        csv_writer = csv.writer(f, delimiter=self.delimiter, encoding='utf-8')
        field_dict = OrderedDict()
        col_count = 0
        for field in cursor.description:
            col_count += 1
            col_position = "Column{position}".format(position=col_count)
            field_dict[col_position if field[0] == '' else field[0]] = \
                self.type_map(field[1])
        csv_writer.writerows(cursor.iterate())
        f.flush()
        cursor.close()
        conn.close()
        self.log.info("Loading file into Hive")
        hive.load_file(
            f.name,
            self.hive_table,
            field_dict=field_dict,
            create=self.create,
            partition=self.partition,
            delimiter=self.delimiter,
            recreate=self.recreate)

def execute(self, context):
    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    mssql = MsSqlHook(mssql_conn_id=self.mssql_conn_id)

    logging.info("Dumping Microsoft SQL Server query results to local file")
    conn = mssql.get_conn()
    cursor = conn.cursor()
    cursor.execute(self.sql)
    with NamedTemporaryFile("w") as f:
        csv_writer = csv.writer(f, delimiter=self.delimiter, encoding='utf-8')
        field_dict = OrderedDict()
        col_count = 0
        for field in cursor.description:
            col_count += 1
            col_position = "Column{position}".format(position=col_count)
            field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1])
        csv_writer.writerows(cursor)
        f.flush()
        cursor.close()
        conn.close()
        logging.info("Loading file into Hive")
        hive.load_file(
            f.name,
            self.hive_table,
            field_dict=field_dict,
            create=self.create,
            partition=self.partition,
            delimiter=self.delimiter,
            recreate=self.recreate)

def execute(self, context):
    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

    logging.info("Dumping MySQL query results to local file")
    conn = mysql.get_conn()
    cursor = conn.cursor(MySQLdb.cursors.SSCursor)
    cursor.execute(self.sql)
    with NamedTemporaryFile("wb") as f:
        csv_writer = csv.writer(f, delimiter=self.delimiter, encoding="utf-8")
        field_dict = OrderedDict()
        for field in cursor.description:
            field_dict[field[0]] = self.type_map(field[1])
        # csv_writer.writerows(cursor)
        while True:
            row = cursor.fetchone()
            if not row:
                break
            csv_writer.writerow(row)
        f.flush()
        cursor.close()
        conn.close()
        logging.info("Loading file into Hive")
        hive.load_file(
            f.name,
            self.hive_table,
            field_dict=field_dict,
            create=self.create,
            partition=self.partition,
            delimiter=self.delimiter,
            recreate=self.recreate)

def test_load_df_with_data_types(self, mock_run_cli):
    d = OrderedDict()
    d['b'] = [True]
    d['i'] = [-1]
    d['t'] = [1]
    d['f'] = [0.0]
    d['c'] = ['c']
    d['M'] = [datetime.datetime(2018, 1, 1)]
    d['O'] = [object()]
    d['S'] = [b'STRING']
    d['U'] = ['STRING']
    d['V'] = [None]
    df = pd.DataFrame(d)

    hook = HiveCliHook()
    hook.load_df(df, 't')

    query = """
        CREATE TABLE IF NOT EXISTS t (
            b BOOLEAN,
            i BIGINT,
            t BIGINT,
            f DOUBLE,
            c STRING,
            M TIMESTAMP,
            O STRING,
            S STRING,
            U STRING,
            V STRING)
        ROW FORMAT DELIMITED
        FIELDS TERMINATED BY ','
        STORED AS textfile
        ;
    """
    assertEqualIgnoreMultipleSpaces(self, mock_run_cli.call_args_list[0][0][0], query)

def test_load_file_create_table(self, mock_run_cli):
    filepath = "/path/to/input/file"
    table = "output_table"
    field_dict = OrderedDict([("name", "string"), ("gender", "string")])
    fields = ",\n    ".join(
        [k + ' ' + v for k, v in field_dict.items()])

    hook = HiveCliHook()
    hook.load_file(filepath=filepath, table=table,
                   field_dict=field_dict, create=True, recreate=True)

    create_table = ("DROP TABLE IF EXISTS {table};\n"
                    "CREATE TABLE IF NOT EXISTS {table} (\n{fields})\n"
                    "ROW FORMAT DELIMITED\n"
                    "FIELDS TERMINATED BY ','\n"
                    "STORED AS textfile\n;".format(table=table, fields=fields))

    load_data = ("LOAD DATA LOCAL INPATH '{filepath}' "
                 "OVERWRITE INTO TABLE {table} ;\n".format(
                     filepath=filepath, table=table))
    calls = [mock.call(create_table), mock.call(load_data)]
    mock_run_cli.assert_has_calls(calls, any_order=True)

def execution_data_task_id(ds, **kwargs):
    v_date = kwargs.get('v_execution_date')
    v_day = kwargs.get('v_execution_day')
    v_hour = kwargs.get('v_execution_hour')

    hive_hook = HiveCliHook()

    # Read the SQL
    _sql = dwd_oride_driver_cheating_detection_hi_sql_task(ds, v_hour)

    logging.info('Executing: %s', _sql)

    # Run it through Hive
    hive_hook.run_cli(_sql)

    # Circuit breaker on key data
    # check_key_data_task(ds)

    # Generate the _SUCCESS flag
    """
    First argument  "true": the data directory has a country_code partition; "false": it does not
    Second argument "true": create _SUCCESS only when data exists; "false": create _SUCCESS even without data
    """
    TaskTouchzSuccess().countries_touchz_success(ds, db_name, table_name, hdfs_path, "true", "false", v_hour)

def execute(self, context):
    mssql = MsSqlHook(mssql_conn_id=self.mssql_conn_id)

    self.log.info("Dumping Microsoft SQL Server query results to local file")
    with mssql.get_conn() as conn:
        with conn.cursor() as cursor:
            cursor.execute(self.sql)
            with NamedTemporaryFile("w") as tmp_file:
                csv_writer = csv.writer(tmp_file, delimiter=self.delimiter, encoding='utf-8')
                field_dict = OrderedDict()
                col_count = 0
                for field in cursor.description:
                    col_count += 1
                    col_position = "Column{position}".format(position=col_count)
                    field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1])
                csv_writer.writerows(cursor)
                tmp_file.flush()

                hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
                self.log.info("Loading file into Hive")
                hive.load_file(
                    tmp_file.name,
                    self.hive_table,
                    field_dict=field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate,
                    tblproperties=self.tblproperties)

def execution_data_task_id(ds, dag, **kwargs):
    v_date = kwargs.get('v_execution_date')
    v_day = kwargs.get('v_execution_day')
    v_hour = kwargs.get('v_execution_hour')

    hive_hook = HiveCliHook()

    """
    # Helper functions
        ALTER statement:   alter_partition()
        Drop a partition:  delete_partition()
        Produce _SUCCESS:  touchz_success()

    # Parameters
        is_countries_online   -- whether the multi-country business is enabled, default "true" (enabled)
        db_name               -- name of the Hive database
        table_name            -- name of the Hive table
        data_oss_path         -- OSS path of the data directory
        is_country_partition  -- whether there is a country_code partition, default "true"
        is_result_force_exist -- whether output is mandatory; default "true" (create _SUCCESS only when
                                 data exists), "false" creates _SUCCESS even without data
        execute_time          -- execution time of the current script (%Y-%m-%d %H:%M:%S)
        is_hour_task          -- whether the hourly task is enabled, default "false"
        frame_type            -- template type (only effective when is_hour_task is "true"):
                                 "utc" produces partitions in UTC time, "local" in local time; default "utc"
        is_offset             -- whether to enable a forward/backward time offset (affects the success file)
        execute_time_offset   -- execution-time offset value (-1, 0, 1); shifts the current execution time
                                 to produce the previous/next hour partition
        business_key          -- product line name

    # Read the SQL
        %_sql(ds, v_hour)
    """

    args = [{
        "dag": dag,
        "is_countries_online": "true",
        "db_name": db_name,
        "table_name": table_name,
        "data_oss_path": hdfs_path,
        "is_country_partition": "true",
        "is_result_force_exist": "false",
        "execute_time": v_date,
        "is_hour_task": "true",
        "frame_type": "local",
        "is_offset": "true",
        "execute_time_offset": -1,
        "business_key": "opay"
    }]

    cf = CountriesAppFrame(args)

    # Read the SQL
    _sql = "\n" + cf.alter_partition() + "\n" + app_opay_life_payment_sum_ng_h_sql_task(ds, v_date)

    logging.info('Executing: %s', _sql)

    # Run it through Hive
    hive_hook.run_cli(_sql)

    # Produce the _SUCCESS flag
    cf.touchz_success()

def test_run_cli_with_hive_conf(self):
    hql = "set key;\n" \
          "set airflow.ctx.dag_id;\nset airflow.ctx.dag_run_id;\n" \
          "set airflow.ctx.task_id;\nset airflow.ctx.execution_date;\n"

    dag_id_ctx_var_name = \
        AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format']
    task_id_ctx_var_name = \
        AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format']
    execution_date_ctx_var_name = \
        AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][
            'env_var_format']
    dag_run_id_ctx_var_name = \
        AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][
            'env_var_format']

    os.environ[dag_id_ctx_var_name] = 'test_dag_id'
    os.environ[task_id_ctx_var_name] = 'test_task_id'
    os.environ[execution_date_ctx_var_name] = 'test_execution_date'
    os.environ[dag_run_id_ctx_var_name] = 'test_dag_run_id'

    hook = HiveCliHook()
    output = hook.run_cli(hql=hql, hive_conf={'key': 'value'})
    self.assertIn('value', output)
    self.assertIn('test_dag_id', output)
    self.assertIn('test_task_id', output)
    self.assertIn('test_execution_date', output)
    self.assertIn('test_dag_run_id', output)

    del os.environ[dag_id_ctx_var_name]
    del os.environ[task_id_ctx_var_name]
    del os.environ[execution_date_ctx_var_name]
    del os.environ[dag_run_id_ctx_var_name]

def test_load_df_with_data_types(self, mock_run_cli):
    d = OrderedDict()
    d['b'] = [True]
    d['i'] = [-1]
    d['t'] = [1]
    d['f'] = [0.0]
    d['c'] = ['c']
    d['M'] = [datetime.datetime(2018, 1, 1)]
    d['O'] = [object()]
    d['S'] = ['STRING'.encode('utf-8')]
    d['U'] = ['STRING']
    d['V'] = [None]
    df = pd.DataFrame(d)

    hook = HiveCliHook()
    hook.load_df(df, 't')

    query = """
        CREATE TABLE IF NOT EXISTS t (
            b BOOLEAN,
            i BIGINT,
            t BIGINT,
            f DOUBLE,
            c STRING,
            M TIMESTAMP,
            O STRING,
            S STRING,
            U STRING,
            V STRING)
        ROW FORMAT DELIMITED
        FIELDS TERMINATED BY ','
        STORED AS textfile
        ;
    """
    assertEqualIgnoreMultipleSpaces(self, mock_run_cli.call_args_list[0][0][0], query)

def get_driver_online_time(ds, **op_kwargs):
    dt = op_kwargs["ds_nodash"]
    conn = get_db_conn('timerange_conn_db')
    mcursor = conn.cursor()
    mcursor.execute(get_driver_id)
    result = mcursor.fetchone()
    conn.commit()
    mcursor.close()
    conn.close()

    processes = []
    max_driver_id = result[0]
    logging.info('max driver id %d', max_driver_id)
    id_list = [x for x in range(1, max_driver_id + 1)]
    part_size = 1000
    index = 0
    manager = Manager()
    rows = manager.list([])
    while index < max_driver_id:
        p = Process(target=get_driver_timerange,
                    args=(id_list[index:index + part_size], dt, rows))
        index += part_size
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    if rows:
        query = """
            INSERT OVERWRITE TABLE oride_dw_ods.{tab_name} PARTITION (dt='{dt}')
            VALUES {value}
        """.format(dt=ds, value=','.join(rows), tab_name=table_name)
        logging.info('import_driver_online_time run sql:%s' % query)
        hive_hook = HiveCliHook()
        hive_hook.run_cli(query)

def execute(self, context):
    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

    self.log.info("Dumping MySQL query results to local file")
    conn = mysql.get_conn()
    cursor = conn.cursor()
    cursor.execute(self.sql)
    with NamedTemporaryFile("wb") as f:
        csv_writer = csv.writer(f, delimiter=self.delimiter, encoding="utf-8")
        field_dict = OrderedDict()
        for field in cursor.description:
            field_dict[field[0]] = self.type_map(field[1])
        csv_writer.writerows(cursor)
        f.flush()
        cursor.close()
        conn.close()
        self.log.info("Loading file into Hive")
        hive.load_file(
            f.name,
            self.hive_table,
            field_dict=field_dict,
            create=self.create,
            partition=self.partition,
            delimiter=self.delimiter,
            recreate=self.recreate,
            tblproperties=self.tblproperties)

def ddl(self):
    """
    Retrieve table ddl
    """
    table = request.args.get("table")
    sql = "SHOW CREATE TABLE {table};".format(table=table)
    hook = HiveCliHook(HIVE_CLI_CONN_ID)
    return hook.run_cli(sql)

def dwd_oride_driver_call_record_mid(ds, **kargs):
    hive_hook = HiveCliHook()

    # Read the SQL
    _sql = dwd_oride_driver_call_record_mid_sql_task(ds)

    logging.info('Executing: %s', _sql)

    # Run it through Hive
    hive_hook.run_cli(_sql)

def execution_data_task_id(ds, **kwargs):
    v_date = kwargs.get('v_execution_date')
    v_day = kwargs.get('v_execution_day')
    v_hour = kwargs.get('v_execution_hour')

    hive_hook = HiveCliHook()

    """
    # Helper functions
        ALTER statement:   alter_partition
        Drop a partition:  delete_partition
        Produce _SUCCESS:  touchz_success

    # Parameters
        First argument  "true": all countries are online; "false": they are not
        Second argument "true": the data directory has a country_code partition; "false": it does not
        Third argument  "true": create _SUCCESS only when data exists; "false": create _SUCCESS even without data

    # Read the SQL
        %_sql(ds, v_hour)
        First argument ds:      daily task
        Second argument v_hour: required for hourly tasks
    """

    cf = CountriesPublicFrame("false", ds, db_name, table_name, hdfs_path, "true", "true")

    v_info = [{
        "table": "oride_db.oride_data.data_user_estimate_records",
        "start_timeThour": "{v_day}T00".format(v_day=v_day),
        "end_dateThour": "{v_day}T23".format(v_day=v_day),
        "depend_dir": "oss://opay-datalake/oride_binlog"
    }]

    hcm = TaskHourSuccessCountMonitor(ds, v_info)
    hcm.HourSuccessCountMonitor()

    # Drop the partition
    # cf.delete_partition()

    # Read the SQL
    _sql = "\n" + cf.alter_partition() + "\n" + dwd_oride_passenger_estimate_records_di_sql_task(ds)

    logging.info('Executing: %s', _sql)

    # Run it through Hive
    hive_hook.run_cli(_sql)

    # Circuit breaker: the row count must not be 0
    # check_key_data_cnt_task(ds)

    # Produce the _SUCCESS flag
    cf.touchz_success()

def test_load_file(self, mock_run_cli):
    filepath = "/path/to/input/file"
    table = "output_table"

    hook = HiveCliHook()
    hook.load_file(filepath=filepath, table=table, create=False)

    query = ("LOAD DATA LOCAL INPATH '{filepath}' "
             "OVERWRITE INTO TABLE {table} ;\n".format(filepath=filepath, table=table))
    mock_run_cli.assert_called_with(query)

def test_load_file(self, mock_run_cli):
    filepath = "/path/to/input/file"
    table = "output_table"

    hook = HiveCliHook()
    hook.load_file(filepath=filepath, table=table, create=False)

    query = ("LOAD DATA LOCAL INPATH '{filepath}' "
             "OVERWRITE INTO TABLE {table} ;\n".format(filepath=filepath, table=table))
    calls = [mock.call(';'), mock.call(query)]
    mock_run_cli.assert_has_calls(calls, any_order=True)

def execute(self, context):
    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    self.log.info("Extracting data from Hive")
    hive_table = 'druid.' + context['task_instance_key_str'].replace('.', '_')
    sql = self.sql.strip().strip(';')
    tblproperties = ''.join([", '{}' = '{}'".format(k, v)
                             for k, v in self.hive_tblproperties.items()])
    hql = """\
    SET mapred.output.compress=false;
    SET hive.exec.compress.output=false;
    DROP TABLE IF EXISTS {hive_table};
    CREATE TABLE {hive_table}
    ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t'
    STORED AS TEXTFILE
    TBLPROPERTIES ('serialization.null.format' = ''{tblproperties})
    AS
    {sql}
    """.format(hive_table=hive_table, tblproperties=tblproperties, sql=sql)
    self.log.info("Running command:\n %s", hql)
    hive.run_cli(hql)

    m = HiveMetastoreHook(self.metastore_conn_id)

    # Get the Hive table and extract the columns
    t = m.get_table(hive_table)
    columns = [col.name for col in t.sd.cols]

    # Get the path on hdfs
    static_path = m.get_table(hive_table).sd.location

    schema, table = hive_table.split('.')

    druid = DruidHook(druid_ingest_conn_id=self.druid_ingest_conn_id)

    try:
        index_spec = self.construct_ingest_query(
            static_path=static_path,
            columns=columns,
        )

        self.log.info("Inserting rows into Druid, hdfs path: %s", static_path)

        druid.submit_indexing_job(index_spec)

        self.log.info("Load seems to have succeeded!")
    finally:
        self.log.info(
            "Cleaning up by dropping the temp Hive table %s",
            hive_table
        )
        hql = "DROP TABLE IF EXISTS {}".format(hive_table)
        hive.run_cli(hql)

def execution_data_task_id(ds, **kwargs):
    v_date = kwargs.get('v_execution_date')
    v_day = kwargs.get('v_execution_day')
    v_hour = kwargs.get('v_execution_hour')

    hive_hook = HiveCliHook()

    """
    # Helper functions
        ALTER statement:   alter_partition
        Drop a partition:  delete_partition
        Produce _SUCCESS:  touchz_success

    # Parameters
        First argument  "true": the data directory has a country_code partition; "false": it does not
        Second argument "true": create _SUCCESS only when data exists; "false": create _SUCCESS even without data

    # Read the SQL
        %_sql_task(ds, v_hour)
        First argument ds:      daily task
        Second argument v_hour: required for hourly tasks
    """

    if datetime.strptime(ds, '%Y-%m-%d').weekday() == 6:
        cf = CountriesPublicFrame("false", ds, db_name, table_name, hdfs_path, "true", "false")
    else:
        cf = CountriesPublicFrame("false", ds, db_name, table_name, hdfs_path, "true", "true")

    # Drop the partition
    # cf.delete_partition()

    # Assemble the SQL
    _sql = "\n" + cf.alter_partition() + "\n" + app_ocredit_phones_order_base_cube_w_sql_task(ds)

    logging.info('Executing: %s', _sql)

    # Run it through Hive
    hive_hook.run_cli(_sql)

    # Circuit breaker: the row count must not be 0
    # check_key_data_cnt_task(ds)

    # Circuit breaker on key data
    # check_key_data_task(ds)

    # Produce the _SUCCESS flag
    cf.touchz_success()

def test_load_file(self, mock_run_cli):
    filepath = "/path/to/input/file"
    table = "output_table"

    hook = HiveCliHook()
    hook.load_file(filepath=filepath, table=table, create=False)

    query = (
        "LOAD DATA LOCAL INPATH '{filepath}' "
        "OVERWRITE INTO TABLE {table} \n"
        .format(filepath=filepath, table=table)
    )
    mock_run_cli.assert_called_with(query)

def test_load_df_with_optional_parameters(self, mock_to_csv, mock_load_file):
    hook = HiveCliHook()
    b = (True, False)
    for create, recreate in itertools.product(b, b):
        mock_load_file.reset_mock()
        hook.load_df(df=pd.DataFrame({"c": range(0, 10)}),
                     table="t",
                     create=create,
                     recreate=recreate)

        mock_load_file.assert_called_once()
        kwargs = mock_load_file.call_args[1]
        self.assertEqual(kwargs["create"], create)
        self.assertEqual(kwargs["recreate"], recreate)

def test_get_proxy_user_value(self):
    from airflow.hooks.hive_hooks import HiveCliHook

    hook = HiveCliHook()
    returner = mock.MagicMock()
    # The proxy user must match the value asserted below
    returner.extra_dejson = {'proxy_user': 'a_user_proxy'}
    hook.use_beeline = True
    hook.conn = returner

    # Run
    result = hook._prepare_cli_cmd()

    # Verify
    self.assertIn('hive.server2.proxy.user=a_user_proxy', result[2])

def test_load_df_with_optional_parameters(self, mock_to_csv, mock_load_file):
    hook = HiveCliHook()
    b = (True, False)
    for create, recreate in itertools.product(b, b):
        mock_load_file.reset_mock()
        hook.load_df(df=pd.DataFrame({"c": range(0, 10)}),
                     table="t",
                     create=create,
                     recreate=recreate)

        assert mock_load_file.call_count == 1
        kwargs = mock_load_file.call_args[1]
        self.assertEqual(kwargs["create"], create)
        self.assertEqual(kwargs["recreate"], recreate)

def execution_data_task_id(ds, **kwargs):
    v_date = kwargs.get('v_execution_date')
    v_day = kwargs.get('v_execution_day')
    v_hour = kwargs.get('v_execution_hour')

    hive_hook = HiveCliHook()

    """
    # Helper functions
        ALTER statement:   alter_partition
        Drop a partition:  delete_partition
        Produce _SUCCESS:  touchz_success

    # Parameters
        First argument  "true": the data directory has a country_code partition; "false": it does not
        Second argument "true": create _SUCCESS only when data exists; "false": create _SUCCESS even without data

    # Read the SQL
        %_sql_task(ds, v_hour)
        First argument ds:      daily task
        Second argument v_hour: required for hourly tasks
    """

    cf = CountriesPublicFrame("false", ds, db_name, table_name, hdfs_path, "true", "true")

    # Drop the partition
    # cf.delete_partition()

    # Assemble the SQL
    _sql = "\n" + cf.alter_partition() + "\n" + dwd_oride_strategy_data_invite_df_sql_task(ds)

    logging.info('Executing: %s', _sql)

    # Run it through Hive
    hive_hook.run_cli(_sql)

    # Circuit breaker: the row count must not be 0
    # check_key_data_cnt_task(ds)

    # Circuit breaker on key data
    # check_key_data_task(ds)

    # Produce the _SUCCESS flag
    cf.touchz_success()

def query_hive(**kwargs):
    ti = kwargs['ti']

    # get sha of latest commit
    v1 = ti.xcom_pull(key=None, task_ids='get_last_commit_task')
    json_value = json.loads(v1)
    sha = json_value['sha']

    hive_cli = HiveCliHook()
    hql = "select * from mapr_music_updates where commit_sha = '" + sha + "';"
    latest_commit = hive_cli.run_cli(hql)

    changed = latest_commit.find(sha) == -1

    ti.xcom_push(key='sha', value=sha)
    ti.xcom_push(key='is_changed', value=changed)

    return 'reimport_dataset_task' if changed else 'skip_reimport_dataset_task'

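# The callable above returns a downstream task id, which suggests it is meant to drive a
# branch. A minimal wiring sketch, assuming it runs in a DAG where the task ids
# 'get_last_commit_task', 'reimport_dataset_task' and 'skip_reimport_dataset_task' exist;
# the `dag` object and the task_id 'query_hive_task' are hypothetical names.
from airflow.operators.python_operator import BranchPythonOperator

branch_on_commit = BranchPythonOperator(
    task_id='query_hive_task',
    python_callable=query_hive,
    provide_context=True,  # exposes 'ti' in kwargs so xcom_pull/xcom_push work
    dag=dag,
)
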
def execute(self, context):
    self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
    logging.info("Downloading S3 file")
    if self.wildcard_match:
        if not self.s3.check_for_wildcard_key(self.s3_key):
            raise AirflowException("No key matches {0}".format(self.s3_key))
        s3_key_object = self.s3.get_wildcard_key(self.s3_key)
    else:
        if not self.s3.check_for_key(self.s3_key):
            raise AirflowException(
                "The key {0} does not exists".format(self.s3_key))
        s3_key_object = self.s3.get_key(self.s3_key)
    with NamedTemporaryFile("w") as f:
        logging.info("Dumping S3 key {0} contents to local"
                     " file {1}".format(s3_key_object.key, f.name))
        s3_key_object.get_contents_to_file(f)
        f.flush()
        self.s3.connection.close()
        if not self.headers:
            logging.info("Loading file into Hive")
            self.hive.load_file(
                f.name,
                self.hive_table,
                field_dict=self.field_dict,
                create=self.create,
                partition=self.partition,
                delimiter=self.delimiter,
                recreate=self.recreate)
        else:
            with open(f.name, 'r') as tmpf:
                if self.check_headers:
                    header_l = tmpf.readline()
                    header_line = header_l.rstrip()
                    header_list = header_line.split(self.delimiter)
                    field_names = list(self.field_dict.keys())
                    test_field_match = [h1.lower() == h2.lower() for h1, h2
                                        in zip(header_list, field_names)]
                    if not all(test_field_match):
                        logging.warning("Headers do not match field names"
                                        "File headers:\n {header_list}\n"
                                        "Field names: \n {field_names}\n"
                                        "".format(**locals()))
                        raise AirflowException("Headers do not match the "
                                               "field_dict keys")
                with NamedTemporaryFile("w") as f_no_headers:
                    tmpf.seek(0)
                    next(tmpf)
                    for line in tmpf:
                        f_no_headers.write(line)
                    f_no_headers.flush()
                    logging.info("Loading file without headers into Hive")
                    self.hive.load_file(
                        f_no_headers.name,
                        self.hive_table,
                        field_dict=self.field_dict,
                        create=self.create,
                        partition=self.partition,
                        delimiter=self.delimiter,
                        recreate=self.recreate)

def get_hook(self):
    return HiveCliHook(
        hive_cli_conn_id=self.hive_cli_conn_id,
        run_as=self.run_as,
        mapred_queue=self.mapred_queue,
        mapred_queue_priority=self.mapred_queue_priority,
        mapred_job_name=self.mapred_job_name)

def create_hive_external_table(db, table, conn, **op_kwargs):
    sqoopSchema = SqoopSchemaUpdate()
    response = sqoopSchema.update_hive_schema(
        hive_db=hive_db,
        hive_table=hive_table.format(bs=table),
        mysql_db=db,
        mysql_table=table,
        mysql_conn=conn
    )
    # if response:
    #     return True

    mysql_conn = get_db_conn(conn)
    mcursor = mysql_conn.cursor()
    sql = '''
        select COLUMN_NAME, DATA_TYPE, COLUMN_COMMENT, COLUMN_TYPE
        from information_schema.COLUMNS
        where TABLE_SCHEMA='{db}' and
            TABLE_NAME='{table}'
        order by ORDINAL_POSITION
    '''.format(db=db, table=table)
    # logging.info(sql)
    mcursor.execute(sql)
    res = mcursor.fetchall()
    # logging.info(res)
    columns = []
    for (name, type, comment, co_type) in res:
        if type.upper() == 'DECIMAL':
            columns.append("`%s` %s comment '%s'" %
                           (name, co_type.replace('unsigned', '').replace('signed', ''), comment))
        else:
            columns.append("`%s` %s comment '%s'" %
                           (name, mysql_type_to_hive.get(type.upper(), 'string'), comment))
    mysql_conn.close()

    # SQL that creates the Hive table
    hql = ods_create_table_hql.format(
        db_name=hive_db,
        table_name=hive_table.format(bs=table),
        columns=",\n".join(columns),
        hdfs_path=hdfs_path.format(bs=table)
    )
    logging.info(hql)

    hive_hook = HiveCliHook()
    logging.info('Executing: %s', hql)
    hive_hook.run_cli(hql)

def merge_pre_hi_with_full_data_task(hive_db, hive_h_his_table_name, hive_hi_table_name,
                                     mysql_db_name, mysql_table_name, mysql_conn,
                                     sqoop_temp_db_name, sqoop_table_name, pt, now_hour,
                                     pre_day, pre_hour_day, pre_hour, is_must_have_data, **kwargs):
    sqoopSchema = SqoopSchemaUpdate()
    hive_columns = sqoopSchema.get_hive_column_name(hive_db, hive_h_his_table_name)
    mysql_columns = sqoopSchema.get_mysql_column_name(mysql_db_name, mysql_table_name, mysql_conn)
    pre_day_ms = int(time.mktime(time.strptime(pre_day, "%Y-%m-%d"))) * 1000

    hql = MERGE_HI_WITH_FULL_SQL.format(
        columns=',\n'.join(hive_columns),
        pt=pt,
        now_hour=now_hour,
        db_name=hive_db,
        mysql_db_name=mysql_db_name,
        hive_h_his_table_name=hive_h_his_table_name,
        hive_hi_table_name=hive_hi_table_name,
        mysql_table_name=mysql_table_name,
        pre_day_ms=pre_day_ms,
        mysql_columns=',\n'.join(mysql_columns),
        sqoop_temp_db_name=sqoop_temp_db_name,
        sqoop_table_name=sqoop_table_name)

    hive_hook = HiveCliHook()

    # Read the SQL
    logging.info('Executing: %s', hql)

    # Run it through Hive
    hive_hook.run_cli(hql)

    # Generate the _SUCCESS flag
    """
    First argument  "true": the data directory has a country_code partition; "false": it does not
    Second argument "true": create _SUCCESS only when data exists; "false": create _SUCCESS even without data
    """
    TaskTouchzSuccess().countries_touchz_success(
        pt, hive_db, hive_h_his_table_name,
        H_HIS_OSS_PATH % hive_h_his_table_name,
        "false", is_must_have_data, now_hour)

def execution_data_task_id(ds, **kargs):
    hive_hook = HiveCliHook()

    # Read the SQL
    _sql = dwm_oride_driver_act_w_sql_task(ds)

    logging.info('Executing: %s', _sql)

    # Run it through Hive
    hive_hook.run_cli(_sql)

    # Generate the _SUCCESS flag
    """
    First argument  "true": the data directory has a country_code partition; "false": it does not
    Second argument "true": create _SUCCESS only when data exists; "false": create _SUCCESS even without data
    """
    TaskTouchzSuccess().countries_touchz_success(
        '{pt}'.format(pt=airflow.macros.ds_add(ds, +6)),
        db_name, table_name, hdfs_path, "true", "true")

def execution_data_task_id(ds, ds_nodash, **kargs):
    hive_hook = HiveCliHook()

    # Read the SQL
    _sql = app_opay_active_user_report_w_sql_task(ds, ds_nodash)

    logging.info('Executing: %s', _sql)

    # Run it through Hive
    hive_hook.run_cli(_sql)

    # Generate the _SUCCESS flag
    """
    First argument  "true": the data directory has a country_code partition; "false": it does not
    Second argument "true": create _SUCCESS only when data exists; "false": create _SUCCESS even without data
    """
    TaskTouchzSuccess().countries_touchz_success(ds, db_name, table_name, hdfs_path, "true", "true")

def execution_data_task_id(ds, **kargs):
    hive_hook = HiveCliHook()

    # Read the SQL
    _sql = ods_sqoop_base_bd_agent_df_sql_task(ds)

    logging.info('Executing: %s', _sql)

    # Run it through Hive
    hive_hook.run_cli(_sql)

    # Generate the _SUCCESS flag
    """
    First argument  "true": the data directory has a country_code partition; "false": it does not
    Second argument "true": create _SUCCESS only when data exists; "false": create _SUCCESS even without data
    """
    TaskTouchzSuccess().countries_touchz_success(ds, db_name, table_name, hdfs_path, "false", "true")

def execution_act_driver_task(ds, **kargs):
    hive_hook = HiveCliHook()

    # Read the SQL
    _sql = app_oride_act_driver_cohort_w_sql_task(ds)

    # Run it through Hive
    hive_hook.run_cli(_sql)

    # Generate the _SUCCESS flag
    """
    First argument  "true": the data directory has a country_code partition; "false": it does not
    Second argument "true": create _SUCCESS only when data exists; "false": create _SUCCESS even without data
    """
    pt = airflow.macros.ds_add(ds, +6)
    hdfs_path = get_table_info(3)[1]
    TaskTouchzSuccess().countries_touchz_success(pt, "oride_dw", get_table_info(3)[0], hdfs_path, "true", "true")

def test_load_df(self, mock_to_csv, mock_load_file):
    df = pd.DataFrame({"c": ["foo", "bar", "baz"]})
    table = "t"
    delimiter = ","
    encoding = "utf-8"

    hook = HiveCliHook()
    hook.load_df(df=df,
                 table=table,
                 delimiter=delimiter,
                 encoding=encoding)

    mock_to_csv.assert_called_once()
    kwargs = mock_to_csv.call_args[1]
    self.assertEqual(kwargs["header"], False)
    self.assertEqual(kwargs["index"], False)
    self.assertEqual(kwargs["sep"], delimiter)

    mock_load_file.assert_called_once()
    kwargs = mock_load_file.call_args[1]
    self.assertEqual(kwargs["delimiter"], delimiter)
    self.assertEqual(kwargs["field_dict"], {"c": u"STRING"})
    self.assertTrue(isinstance(kwargs["field_dict"], OrderedDict))
    self.assertEqual(kwargs["table"], table)

class S3ToHiveTransfer(BaseOperator):
    """
    Moves data from S3 to Hive. The operator downloads a file from S3,
    stores the file locally before loading it into a Hive table.
    If the ``create`` or ``recreate`` arguments are set to ``True``,
    ``CREATE TABLE`` and ``DROP TABLE`` statements are generated.
    Hive data types are inferred from the cursor's metadata.

    Note that the table generated in Hive uses ``STORED AS textfile``,
    which isn't the most efficient serialization format. If a
    large amount of data is loaded and/or if the table gets
    queried considerably, you may want to use this operator only to
    stage the data into a temporary table before loading it into its
    final destination using a ``HiveOperator``.

    :param s3_key: The key to be retrieved from S3
    :type s3_key: str
    :param field_dict: A dictionary of the fields name in the file
        as keys and their Hive types as values
    :type field_dict: dict
    :param hive_table: target Hive table, use dot notation to target a
        specific database
    :type hive_table: str
    :param create: whether to create the table if it doesn't exist
    :type create: bool
    :param recreate: whether to drop and recreate the table at every execution
    :type recreate: bool
    :param partition: target partition as a dict of partition columns and values
    :type partition: dict
    :param headers: whether the file contains column names on the first line
    :type headers: bool
    :param check_headers: whether the column names on the first line should be
        checked against the keys of field_dict
    :type check_headers: bool
    :param wildcard_match: whether the s3_key should be interpreted as a Unix
        wildcard pattern
    :type wildcard_match: bool
    :param delimiter: field delimiter in the file
    :type delimiter: str
    :param aws_conn_id: source s3 connection
    :type aws_conn_id: str
    :param hive_cli_conn_id: destination hive connection
    :type hive_cli_conn_id: str
    :param input_compressed: Boolean to determine if file decompression is
        required to process headers
    :type input_compressed: bool
    :param tblproperties: TBLPROPERTIES of the hive table being created
    :type tblproperties: dict
    """

    template_fields = ('s3_key', 'partition', 'hive_table')
    template_ext = ()
    ui_color = '#a0e08c'

    @apply_defaults
    def __init__(
            self,
            s3_key,
            field_dict,
            hive_table,
            delimiter=',',
            create=True,
            recreate=False,
            partition=None,
            headers=False,
            check_headers=False,
            wildcard_match=False,
            aws_conn_id='aws_default',
            hive_cli_conn_id='hive_cli_default',
            input_compressed=False,
            tblproperties=None,
            *args, **kwargs):
        super(S3ToHiveTransfer, self).__init__(*args, **kwargs)
        self.s3_key = s3_key
        self.field_dict = field_dict
        self.hive_table = hive_table
        self.delimiter = delimiter
        self.create = create
        self.recreate = recreate
        self.partition = partition
        self.headers = headers
        self.check_headers = check_headers
        self.wildcard_match = wildcard_match
        self.hive_cli_conn_id = hive_cli_conn_id
        self.aws_conn_id = aws_conn_id
        self.input_compressed = input_compressed
        self.tblproperties = tblproperties

        if (self.check_headers and
                not (self.field_dict is not None and self.headers)):
            raise AirflowException("To check_headers provide " +
                                   "field_dict and headers")

    def execute(self, context):
        # Downloading file from S3
        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id)
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.log.info("Downloading S3 file")

        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}"
                                       .format(self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException(
                    "The key {0} does not exists".format(self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)

        root, file_ext = os.path.splitext(s3_key_object.key)
        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
                NamedTemporaryFile(mode="wb",
                                   dir=tmp_dir,
                                   suffix=file_ext) as f:
            self.log.info("Dumping S3 key {0} contents to local file {1}"
                          .format(s3_key_object.key, f.name))
            s3_key_object.download_fileobj(f)
            f.flush()
            if not self.headers:
                self.log.info("Loading file %s into Hive", f.name)
                self.hive.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate,
                    tblproperties=self.tblproperties)
            else:
                # Decompressing file
                if self.input_compressed:
                    self.log.info("Uncompressing file %s", f.name)
                    fn_uncompressed = uncompress_file(f.name,
                                                      file_ext,
                                                      tmp_dir)
                    self.log.info("Uncompressed to %s", fn_uncompressed)
                    # uncompressed file available now so deleting
                    # compressed file to save disk space
                    f.close()
                else:
                    fn_uncompressed = f.name

                # Testing if header matches field_dict
                if self.check_headers:
                    self.log.info("Matching file header against field_dict")
                    header_list = self._get_top_row_as_list(fn_uncompressed)
                    if not self._match_headers(header_list):
                        raise AirflowException("Header check failed")

                # Deleting top header row
                self.log.info("Removing header from file %s", fn_uncompressed)
                headless_file = (
                    self._delete_top_row_and_compress(fn_uncompressed,
                                                      file_ext,
                                                      tmp_dir))
                self.log.info("Headless file %s", headless_file)
                self.log.info("Loading file %s into Hive", headless_file)
                self.hive.load_file(headless_file,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)

    def _get_top_row_as_list(self, file_name):
        with open(file_name, 'rt') as f:
            header_line = f.readline().strip()
            header_list = header_line.split(self.delimiter)
            return header_list

    def _match_headers(self, header_list):
        if not header_list:
            raise AirflowException("Unable to retrieve header row from file")
        field_names = self.field_dict.keys()
        if len(field_names) != len(header_list):
            self.log.warning("Headers count mismatch"
                             "File headers:\n {header_list}\n"
                             "Field names: \n {field_names}\n"
                             "".format(**locals()))
            return False
        test_field_match = [h1.lower() == h2.lower()
                            for h1, h2 in zip(header_list, field_names)]
        if not all(test_field_match):
            self.log.warning("Headers do not match field names"
                             "File headers:\n {header_list}\n"
                             "Field names: \n {field_names}\n"
                             "".format(**locals()))
            return False
        else:
            return True

    def _delete_top_row_and_compress(
            self,
            input_file_name,
            output_file_ext,
            dest_dir):
        # When output_file_ext is not defined, file is not compressed
        open_fn = open
        if output_file_ext.lower() == '.gz':
            open_fn = gzip.GzipFile
        elif output_file_ext.lower() == '.bz2':
            open_fn = bz2.BZ2File

        os_fh_output, fn_output = \
            tempfile.mkstemp(suffix=output_file_ext, dir=dest_dir)
        with open(input_file_name, 'rb') as f_in,\
                open_fn(fn_output, 'wb') as f_out:
            f_in.seek(0)
            next(f_in)
            for line in f_in:
                f_out.write(line)
        return fn_output

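# A minimal usage sketch for the S3ToHiveTransfer defined above, following the docstring's
# advice to stage into a temporary table and then move the data with a HiveOperator.
# The S3 key, table names, and DAG are hypothetical; the connection ids assume the
# defaults ('aws_default', 'hive_cli_default') are configured in Airflow.
from collections import OrderedDict
from datetime import datetime

from airflow import DAG
from airflow.operators.hive_operator import HiveOperator

with DAG(dag_id='s3_to_hive_staging_example',
         start_date=datetime(2019, 1, 1),
         schedule_interval='@daily') as dag:

    stage_events = S3ToHiveTransfer(
        task_id='stage_events',
        s3_key='data/events/{{ ds }}/events.csv',  # hypothetical key
        field_dict=OrderedDict([('event_id', 'BIGINT'),
                                ('payload', 'STRING')]),
        hive_table='staging.events_tmp',
        headers=True,
        check_headers=True,
        recreate=True,
    )

    load_final = HiveOperator(
        task_id='load_final',
        hql="INSERT OVERWRITE TABLE dw.events PARTITION (ds='{{ ds }}') "
            "SELECT event_id, payload FROM staging.events_tmp",
    )

    stage_events >> load_final
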
def execute(self, context):
    # Downloading file from S3
    self.s3 = S3Hook(aws_conn_id=self.aws_conn_id)
    self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    self.log.info("Downloading S3 file")

    if self.wildcard_match:
        if not self.s3.check_for_wildcard_key(self.s3_key):
            raise AirflowException("No key matches {0}"
                                   .format(self.s3_key))
        s3_key_object = self.s3.get_wildcard_key(self.s3_key)
    else:
        if not self.s3.check_for_key(self.s3_key):
            raise AirflowException(
                "The key {0} does not exists".format(self.s3_key))
        s3_key_object = self.s3.get_key(self.s3_key)

    root, file_ext = os.path.splitext(s3_key_object.key)
    with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
            NamedTemporaryFile(mode="wb",
                               dir=tmp_dir,
                               suffix=file_ext) as f:
        self.log.info("Dumping S3 key {0} contents to local file {1}"
                      .format(s3_key_object.key, f.name))
        s3_key_object.download_fileobj(f)
        f.flush()
        if not self.headers:
            self.log.info("Loading file %s into Hive", f.name)
            self.hive.load_file(
                f.name,
                self.hive_table,
                field_dict=self.field_dict,
                create=self.create,
                partition=self.partition,
                delimiter=self.delimiter,
                recreate=self.recreate,
                tblproperties=self.tblproperties)
        else:
            # Decompressing file
            if self.input_compressed:
                self.log.info("Uncompressing file %s", f.name)
                fn_uncompressed = uncompress_file(f.name,
                                                  file_ext,
                                                  tmp_dir)
                self.log.info("Uncompressed to %s", fn_uncompressed)
                # uncompressed file available now so deleting
                # compressed file to save disk space
                f.close()
            else:
                fn_uncompressed = f.name

            # Testing if header matches field_dict
            if self.check_headers:
                self.log.info("Matching file header against field_dict")
                header_list = self._get_top_row_as_list(fn_uncompressed)
                if not self._match_headers(header_list):
                    raise AirflowException("Header check failed")

            # Deleting top header row
            self.log.info("Removing header from file %s", fn_uncompressed)
            headless_file = (
                self._delete_top_row_and_compress(fn_uncompressed,
                                                  file_ext,
                                                  tmp_dir))
            self.log.info("Headless file %s", headless_file)
            self.log.info("Loading file %s into Hive", headless_file)
            self.hive.load_file(headless_file,
                                self.hive_table,
                                field_dict=self.field_dict,
                                create=self.create,
                                partition=self.partition,
                                delimiter=self.delimiter,
                                recreate=self.recreate,
                                tblproperties=self.tblproperties)

def test_run_cli(self):
    hook = HiveCliHook()
    hook.run_cli("SHOW DATABASES")

class S3ToHiveTransfer(BaseOperator):
    """
    Moves data from S3 to Hive. The operator downloads a file from S3,
    stores the file locally before loading it into a Hive table.
    If the ``create`` or ``recreate`` arguments are set to ``True``,
    ``CREATE TABLE`` and ``DROP TABLE`` statements are generated.
    Hive data types are inferred from the cursor's metadata.

    Note that the table generated in Hive uses ``STORED AS textfile``,
    which isn't the most efficient serialization format. If a
    large amount of data is loaded and/or if the table gets
    queried considerably, you may want to use this operator only to
    stage the data into a temporary table before loading it into its
    final destination using a ``HiveOperator``.

    :param s3_key: The key to be retrieved from S3
    :type s3_key: str
    :param field_dict: A dictionary of the fields name in the file
        as keys and their Hive types as values
    :type field_dict: dict
    :param hive_table: target Hive table, use dot notation to target a
        specific database
    :type hive_table: str
    :param create: whether to create the table if it doesn't exist
    :type create: bool
    :param recreate: whether to drop and recreate the table at every execution
    :type recreate: bool
    :param partition: target partition as a dict of partition columns and values
    :type partition: dict
    :param headers: whether the file contains column names on the first line
    :type headers: bool
    :param check_headers: whether the column names on the first line should be
        checked against the keys of field_dict
    :type check_headers: bool
    :param wildcard_match: whether the s3_key should be interpreted as a Unix
        wildcard pattern
    :type wildcard_match: bool
    :param delimiter: field delimiter in the file
    :type delimiter: str
    :param s3_conn_id: source s3 connection
    :type s3_conn_id: str
    :param hive_conn_id: destination hive connection
    :type hive_conn_id: str
    """

    template_fields = ('s3_key', 'partition', 'hive_table')
    template_ext = ()
    ui_color = '#a0e08c'

    @apply_defaults
    def __init__(
            self,
            s3_key,
            field_dict,
            hive_table,
            delimiter=',',
            create=True,
            recreate=False,
            partition=None,
            headers=False,
            check_headers=False,
            wildcard_match=False,
            s3_conn_id='s3_default',
            hive_cli_conn_id='hive_cli_default',
            *args, **kwargs):
        super(S3ToHiveTransfer, self).__init__(*args, **kwargs)
        self.s3_key = s3_key
        self.field_dict = field_dict
        self.hive_table = hive_table
        self.delimiter = delimiter
        self.create = create
        self.recreate = recreate
        self.partition = partition
        self.headers = headers
        self.check_headers = check_headers
        self.wildcard_match = wildcard_match
        self.hive_cli_conn_id = hive_cli_conn_id
        self.s3_conn_id = s3_conn_id

    def execute(self, context):
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
        logging.info("Downloading S3 file")
        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}".format(self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException(
                    "The key {0} does not exists".format(self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)
        with NamedTemporaryFile("w") as f:
            logging.info("Dumping S3 key {0} contents to local"
                         " file {1}".format(s3_key_object.key, f.name))
            s3_key_object.get_contents_to_file(f)
            f.flush()
            self.s3.connection.close()
            if not self.headers:
                logging.info("Loading file into Hive")
                self.hive.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate)
            else:
                with open(f.name, 'r') as tmpf:
                    if self.check_headers:
                        header_l = tmpf.readline()
                        header_line = header_l.rstrip()
                        header_list = header_line.split(self.delimiter)
                        field_names = list(self.field_dict.keys())
                        test_field_match = [h1.lower() == h2.lower() for h1, h2
                                            in zip(header_list, field_names)]
                        if not all(test_field_match):
                            logging.warning("Headers do not match field names"
                                            "File headers:\n {header_list}\n"
                                            "Field names: \n {field_names}\n"
                                            "".format(**locals()))
                            raise AirflowException("Headers do not match the "
                                                   "field_dict keys")
                    with NamedTemporaryFile("w") as f_no_headers:
                        tmpf.seek(0)
                        next(tmpf)
                        for line in tmpf:
                            f_no_headers.write(line)
                        f_no_headers.flush()
                        logging.info("Loading file without headers into Hive")
                        self.hive.load_file(
                            f_no_headers.name,
                            self.hive_table,
                            field_dict=self.field_dict,
                            create=self.create,
                            partition=self.partition,
                            delimiter=self.delimiter,
                            recreate=self.recreate)

def ddl(self):
    table = request.args.get("table")
    sql = "SHOW CREATE TABLE {table};".format(table=table)
    h = HiveCliHook(HIVE_CLI_CONN_ID)
    return h.run_cli(sql)

def execute(self, context):
    # Downloading file from S3
    self.s3 = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
    self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    self.log.info("Downloading S3 file")

    if self.wildcard_match:
        if not self.s3.check_for_wildcard_key(self.s3_key):
            raise AirflowException("No key matches {0}"
                                   .format(self.s3_key))
        s3_key_object = self.s3.get_wildcard_key(self.s3_key)
    else:
        if not self.s3.check_for_key(self.s3_key):
            raise AirflowException(
                "The key {0} does not exists".format(self.s3_key))
        s3_key_object = self.s3.get_key(self.s3_key)

    root, file_ext = os.path.splitext(s3_key_object.key)
    if (self.select_expression and self.input_compressed and
            file_ext.lower() != '.gz'):
        raise AirflowException("GZIP is the only compression " +
                               "format Amazon S3 Select supports")

    with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
            NamedTemporaryFile(mode="wb",
                               dir=tmp_dir,
                               suffix=file_ext) as f:
        self.log.info(
            "Dumping S3 key %s contents to local file %s",
            s3_key_object.key, f.name
        )
        if self.select_expression:
            option = {}
            if self.headers:
                option['FileHeaderInfo'] = 'USE'
            if self.delimiter:
                option['FieldDelimiter'] = self.delimiter

            input_serialization = {'CSV': option}
            if self.input_compressed:
                input_serialization['CompressionType'] = 'GZIP'

            content = self.s3.select_key(
                bucket_name=s3_key_object.bucket_name,
                key=s3_key_object.key,
                expression=self.select_expression,
                input_serialization=input_serialization
            )
            f.write(content.encode("utf-8"))
        else:
            s3_key_object.download_fileobj(f)
        f.flush()

        if self.select_expression or not self.headers:
            self.log.info("Loading file %s into Hive", f.name)
            self.hive.load_file(
                f.name,
                self.hive_table,
                field_dict=self.field_dict,
                create=self.create,
                partition=self.partition,
                delimiter=self.delimiter,
                recreate=self.recreate,
                tblproperties=self.tblproperties)
        else:
            # Decompressing file
            if self.input_compressed:
                self.log.info("Uncompressing file %s", f.name)
                fn_uncompressed = uncompress_file(f.name,
                                                  file_ext,
                                                  tmp_dir)
                self.log.info("Uncompressed to %s", fn_uncompressed)
                # uncompressed file available now so deleting
                # compressed file to save disk space
                f.close()
            else:
                fn_uncompressed = f.name

            # Testing if header matches field_dict
            if self.check_headers:
                self.log.info("Matching file header against field_dict")
                header_list = self._get_top_row_as_list(fn_uncompressed)
                if not self._match_headers(header_list):
                    raise AirflowException("Header check failed")

            # Deleting top header row
            self.log.info("Removing header from file %s", fn_uncompressed)
            headless_file = (
                self._delete_top_row_and_compress(fn_uncompressed,
                                                  file_ext,
                                                  tmp_dir))
            self.log.info("Headless file %s", headless_file)
            self.log.info("Loading file %s into Hive", headless_file)
            self.hive.load_file(headless_file,
                                self.hive_table,
                                field_dict=self.field_dict,
                                create=self.create,
                                partition=self.partition,
                                delimiter=self.delimiter,
                                recreate=self.recreate,
                                tblproperties=self.tblproperties)