Code Example #1
File: hive_to_druid.py  Project: asnir/airflow
    def execute(self, context):
        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        logging.info("Extracting data from Hive")
        hive_table = "druid." + context["task_instance_key_str"].replace(".", "_")
        sql = self.sql.strip().strip(";")
        hql = """\
        set mapred.output.compress=false;
        set hive.exec.compress.output=false;
        DROP TABLE IF EXISTS {hive_table};
        CREATE TABLE {hive_table}
        ROW FORMAT DELIMITED FIELDS TERMINATED BY  '\t'
        STORED AS TEXTFILE
        TBLPROPERTIES ('serialization.null.format' = '')
        AS
        {sql}
        """.format(
            **locals()
        )
        logging.info("Running command:\n {}".format(hql))
        hive.run_cli(hql)

        m = HiveMetastoreHook(self.metastore_conn_id)
        t = m.get_table(hive_table)

        columns = [col.name for col in t.sd.cols]

        hdfs_uri = m.get_table(hive_table).sd.location
        pos = hdfs_uri.find("/user")
        static_path = hdfs_uri[pos:]

        schema, table = hive_table.split(".")

        druid = DruidHook(druid_ingest_conn_id=self.druid_ingest_conn_id)
        logging.info("Inserting rows into Druid")
        logging.info("HDFS path: " + static_path)

        try:
            druid.load_from_hdfs(
                datasource=self.druid_datasource,
                intervals=self.intervals,
                static_path=static_path,
                ts_dim=self.ts_dim,
                columns=columns,
                num_shards=self.num_shards,
                target_partition_size=self.target_partition_size,
                query_granularity=self.query_granularity,
                segment_granularity=self.segment_granularity,
                metric_spec=self.metric_spec,
                hadoop_dependency_coordinates=self.hadoop_dependency_coordinates,
            )
            logging.info("Load seems to have succeeded!")
        finally:
            logging.info("Cleaning up by dropping the temp " "Hive table {}".format(hive_table))
            hql = "DROP TABLE IF EXISTS {}".format(hive_table)
            hive.run_cli(hql)
Code Example #2
def merge_pre_hi_data_task(hive_db, hive_all_hi_table_name, hive_hi_table_name,
                           is_must_have_data, pt, now_hour, pre_hour_day,
                           pre_hour, **kwargs):
    sqoopSchema = SqoopSchemaUpdate()
    hive_columns = sqoopSchema.get_hive_column_name(hive_db,
                                                    hive_all_hi_table_name)

    hql = ADD_HI_SQL.format(db_name=hive_db,
                            hive_all_hi_table_name=hive_all_hi_table_name,
                            hive_hi_table_name=hive_hi_table_name,
                            pt=pt,
                            now_hour=now_hour,
                            pre_hour_day=pre_hour_day,
                            pre_hour=pre_hour,
                            columns=',\n'.join(hive_columns))

    hive_hook = HiveCliHook()

    # Read the SQL

    logging.info('Executing: %s', hql)

    # Execute in Hive
    hive_hook.run_cli(hql)

    # Generate _SUCCESS
    """
    First argument "true": the data directory has a country_code partition; "false": it does not.
    Second argument "true": _SUCCESS is generated only when there is data; "false": _SUCCESS is generated even when there is no data.
    """
    TaskTouchzSuccess().countries_touchz_success(
        pt, hive_db, hive_all_hi_table_name,
        ALL_HI_OSS_PATH % hive_all_hi_table_name, "false", is_must_have_data,
        now_hour)
Code Example #3
    def execute(self, context):
        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

        self.log.info("Dumping MySQL query results to local file")
        conn = mysql.get_conn()
        cursor = conn.cursor()
        cursor.execute(self.sql)
        with NamedTemporaryFile("wb") as f:
            csv_writer = csv.writer(f,
                                    delimiter=self.delimiter,
                                    quoting=self.quoting,
                                    quotechar=self.quotechar,
                                    escapechar=self.escapechar,
                                    encoding="utf-8")
            field_dict = OrderedDict()
            for field in cursor.description:
                field_dict[field[0]] = self.type_map(field[1])
            csv_writer.writerows(cursor)
            f.flush()
            cursor.close()
            conn.close()
            self.log.info("Loading file into Hive")
            hive.load_file(f.name,
                           self.hive_table,
                           field_dict=field_dict,
                           create=self.create,
                           partition=self.partition,
                           delimiter=self.delimiter,
                           recreate=self.recreate,
                           tblproperties=self.tblproperties)
Code Example #4
def execution_data_task_id(ds,**kargs):

    hive_hook = HiveCliHook()

    TaskTouchzSuccess().del_path(ds,db_name,table_name,hdfs_path,"true","true")

    # Read the SQL
    _sql = dim_oride_city_sql_task(ds)

    logging.info('Executing: %s', _sql)

    # Execute in Hive
    hive_hook.run_cli(_sql)

    # Data circuit breaker: the row count must not be 0
    check_key_data_cnt_task(ds)

    # Data circuit breaker
    check_key_data_task(ds)

    # Generate _SUCCESS
    """
    First argument "true": the data directory has a country_code partition; "false": it does not.
    Second argument "true": _SUCCESS is generated only when there is data; "false": _SUCCESS is generated even when there is no data.
    """
    TaskTouchzSuccess().countries_touchz_success(ds,db_name,table_name,hdfs_path,"true","true")
Code Example #5
def execution_data_task_id(ds, dag, **kwargs):
    v_date = kwargs.get('v_execution_date')
    v_day = kwargs.get('v_execution_day')
    v_hour = kwargs.get('v_execution_hour')

    hive_hook = HiveCliHook()

    """
        #功能函数
            alter语句: alter_partition()
            删除分区: delete_partition()
            生产success: touchz_success()

        #参数
            is_countries_online --是否开通多国家业务 默认(true 开通)
            db_name --hive 数据库的名称
            table_name --hive 表的名称
            data_oss_path --oss 数据目录的地址
            is_country_partition --是否有国家码分区,[默认(true 有country_code分区)]
            is_result_force_exist --数据是否强行产出,[默认(true 必须有数据才生成_SUCCESS)] false 数据没有也生成_SUCCESS 
            execute_time --当前脚本执行时间(%Y-%m-%d %H:%M:%S)
            is_hour_task --是否开通小时级任务,[默认(false)]
            frame_type --模板类型(只有 is_hour_task:'true' 时生效): utc 产出分区为utc时间,local 产出分区为本地时间,[默认(utc)]。

        #读取sql
            %_sql(ds,v_hour)

    """

    args = [
        {
            "dag": dag,
            "is_countries_online": "true",
            "db_name": db_name,
            "table_name": table_name,
            "data_oss_path": hdfs_path,
            "is_country_partition": "true",
            "is_result_force_exist": "true",
            "execute_time": v_date,
            "is_hour_task": "true",
            "frame_type": "local"
        }
    ]

    cf = CountriesPublicFrame_dev(args)


    # Read the SQL
    _sql = "\n" + cf.alter_partition() + "\n" + dim_opay_terminal_base_hf_sql_task(ds, v_date)

    logging.info('Executing: %s', _sql)

    # Execute in Hive
    hive_hook.run_cli(_sql)

    # Produce _SUCCESS
    cf.touchz_success()
Code Example #6
File: vertica_to_hive.py  Project: zyh1690/airflow
    def execute(self, context):
        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)

        self.log.info("Dumping Vertica query results to local file")
        conn = vertica.get_conn()
        cursor = conn.cursor()
        cursor.execute(self.sql)
        with NamedTemporaryFile("w") as f:
            csv_writer = csv.writer(f,
                                    delimiter=self.delimiter,
                                    encoding='utf-8')
            field_dict = OrderedDict()
            col_count = 0
            for field in cursor.description:
                col_count += 1
                col_position = "Column{position}".format(position=col_count)
                field_dict[col_position if field[0] == '' else field[0]] = \
                    self.type_map(field[1])
            csv_writer.writerows(cursor.iterate())
            f.flush()
            cursor.close()
            conn.close()
            self.log.info("Loading file into Hive")
            hive.load_file(f.name,
                           self.hive_table,
                           field_dict=field_dict,
                           create=self.create,
                           partition=self.partition,
                           delimiter=self.delimiter,
                           recreate=self.recreate)
Code Example #7
    def execute(self, context):
        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        mssql = MsSqlHook(mssql_conn_id=self.mssql_conn_id)

        logging.info("Dumping Microsoft SQL Server query results to local file")
        conn = mssql.get_conn()
        cursor = conn.cursor()
        cursor.execute(self.sql)
        with NamedTemporaryFile("w") as f:
            csv_writer = csv.writer(f, delimiter=self.delimiter, encoding='utf-8')
            field_dict = OrderedDict()
            col_count = 0
            for field in cursor.description:
                col_count += 1
                col_position = "Column{position}".format(position=col_count)
                field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1])
            csv_writer.writerows(cursor)
            f.flush()
            cursor.close()
            conn.close()
            logging.info("Loading file into Hive")
            hive.load_file(
                f.name,
                self.hive_table,
                field_dict=field_dict,
                create=self.create,
                partition=self.partition,
                delimiter=self.delimiter,
                recreate=self.recreate)
Code Example #8
    def execute(self, context):
        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

        logging.info("Dumping MySQL query results to local file")
        conn = mysql.get_conn()
        cursor = conn.cursor(MySQLdb.cursors.SSCursor)
        cursor.execute(self.sql)
        with NamedTemporaryFile("wb") as f:
            csv_writer = csv.writer(f,
                                    delimiter=self.delimiter,
                                    encoding="utf-8")
            field_dict = OrderedDict()
            for field in cursor.description:
                field_dict[field[0]] = self.type_map(field[1])
            # csv_writer.writerows(cursor)
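            # Fetch rows one at a time from the server-side cursor (SSCursor) to keep memory usage bounded.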
            while True:
                row = cursor.fetchone()
                if not row:
                    break
                csv_writer.writerow(row)
            f.flush()
            cursor.close()
            conn.close()
            logging.info("Loading file into Hive")
            hive.load_file(f.name,
                           self.hive_table,
                           field_dict=field_dict,
                           create=self.create,
                           partition=self.partition,
                           delimiter=self.delimiter,
                           recreate=self.recreate)
Code Example #9
    def test_load_df_with_data_types(self, mock_run_cli):
        d = OrderedDict()
        d['b'] = [True]
        d['i'] = [-1]
        d['t'] = [1]
        d['f'] = [0.0]
        d['c'] = ['c']
        d['M'] = [datetime.datetime(2018, 1, 1)]
        d['O'] = [object()]
        d['S'] = [b'STRING']
        d['U'] = ['STRING']
        d['V'] = [None]
        df = pd.DataFrame(d)

        hook = HiveCliHook()
        hook.load_df(df, 't')

        query = """
            CREATE TABLE IF NOT EXISTS t (
                b BOOLEAN,
                i BIGINT,
                t BIGINT,
                f DOUBLE,
                c STRING,
                M TIMESTAMP,
                O STRING,
                S STRING,
                U STRING,
                V STRING)
            ROW FORMAT DELIMITED
            FIELDS TERMINATED BY ','
            STORED AS textfile
            ;
        """
        assertEqualIgnoreMultipleSpaces(self, mock_run_cli.call_args_list[0][0][0], query)
Code Example #10
    def test_load_file_create_table(self, mock_run_cli):
        filepath = "/path/to/input/file"
        table = "output_table"
        field_dict = OrderedDict([("name", "string"), ("gender", "string")])
        fields = ",\n    ".join([k + ' ' + v for k, v in field_dict.items()])

        hook = HiveCliHook()
        hook.load_file(filepath=filepath,
                       table=table,
                       field_dict=field_dict,
                       create=True,
                       recreate=True)

        create_table = ("DROP TABLE IF EXISTS {table};\n"
                        "CREATE TABLE IF NOT EXISTS {table} (\n{fields})\n"
                        "ROW FORMAT DELIMITED\n"
                        "FIELDS TERMINATED BY ','\n"
                        "STORED AS textfile\n;".format(table=table,
                                                       fields=fields))

        load_data = ("LOAD DATA LOCAL INPATH '{filepath}' "
                     "OVERWRITE INTO TABLE {table} ;\n".format(
                         filepath=filepath, table=table))
        calls = [mock.call(create_table), mock.call(load_data)]
        mock_run_cli.assert_has_calls(calls, any_order=True)
Code Example #11
def execution_data_task_id(ds, **kwargs):
    v_date = kwargs.get('v_execution_date')
    v_day = kwargs.get('v_execution_day')
    v_hour = kwargs.get('v_execution_hour')

    hive_hook = HiveCliHook()

    # Read the SQL
    _sql = dwd_oride_driver_cheating_detection_hi_sql_task(ds, v_hour)

    logging.info('Executing: %s', _sql)

    # Execute in Hive
    hive_hook.run_cli(_sql)

    # Data circuit breaker
    # check_key_data_task(ds)

    # Generate _SUCCESS
    """
    First argument "true": the data directory has a country_code partition; "false": it does not.
    Second argument "true": _SUCCESS is generated only when there is data; "false": _SUCCESS is generated even when there is no data.
    """

    TaskTouchzSuccess().countries_touchz_success(ds, db_name, table_name,
                                                 hdfs_path, "true", "false",
                                                 v_hour)
Code Example #12
    def execute(self, context):
        mssql = MsSqlHook(mssql_conn_id=self.mssql_conn_id)
        self.log.info(
            "Dumping Microsoft SQL Server query results to local file")
        with mssql.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute(self.sql)
                with NamedTemporaryFile("w") as tmp_file:
                    csv_writer = csv.writer(tmp_file,
                                            delimiter=self.delimiter,
                                            encoding='utf-8')
                    field_dict = OrderedDict()
                    col_count = 0
                    for field in cursor.description:
                        col_count += 1
                        col_position = "Column{position}".format(
                            position=col_count)
                        field_dict[col_position if field[0] ==
                                   '' else field[0]] = self.type_map(field[1])
                    csv_writer.writerows(cursor)
                    tmp_file.flush()

                    # Load while the temporary file still exists; NamedTemporaryFile
                    # deletes it when its "with" block exits.
                    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
                    self.log.info("Loading file into Hive")
                    hive.load_file(tmp_file.name,
                                   self.hive_table,
                                   field_dict=field_dict,
                                   create=self.create,
                                   partition=self.partition,
                                   delimiter=self.delimiter,
                                   recreate=self.recreate,
                                   tblproperties=self.tblproperties)
Code Example #13
def execution_data_task_id(ds, dag, **kwargs):
    v_date = kwargs.get('v_execution_date')
    v_day = kwargs.get('v_execution_day')
    v_hour = kwargs.get('v_execution_hour')

    hive_hook = HiveCliHook()
    """
        #功能函数
            alter语句: alter_partition()
            删除分区: delete_partition()
            生产success: touchz_success()

        #参数
            is_countries_online --是否开通多国家业务 默认(true 开通)
            db_name --hive 数据库的名称
            table_name --hive 表的名称
            data_oss_path --oss 数据目录的地址
            is_country_partition --是否有国家码分区,[默认(true 有country_code分区)]
            is_result_force_exist --数据是否强行产出,[默认(true 必须有数据才生成_SUCCESS)] false 数据没有也生成_SUCCESS 
            execute_time --当前脚本执行时间(%Y-%m-%d %H:%M:%S)
            is_hour_task --是否开通小时级任务,[默认(false)]
            frame_type --模板类型(只有 is_hour_task:'true' 时生效): utc 产出分区为utc时间,local 产出分区为本地时间,[默认(utc)]。
            is_offset --是否开启时间前后偏移(影响success 文件)
            execute_time_offset --执行时间偏移值(-1、0、1),在当前执行时间上,前后偏移原有时间,用于产出前后小时分区
            business_key --产品线名称

        #读取sql
            %_sql(ds,v_hour)

    """

    args = [{
        "dag": dag,
        "is_countries_online": "true",
        "db_name": db_name,
        "table_name": table_name,
        "data_oss_path": hdfs_path,
        "is_country_partition": "true",
        "is_result_force_exist": "false",
        "execute_time": v_date,
        "is_hour_task": "true",
        "frame_type": "local",
        "is_offset": "true",
        "execute_time_offset": -1,
        "business_key": "opay"
    }]

    cf = CountriesAppFrame(args)

    # Read the SQL
    _sql = "\n" + cf.alter_partition() + "\n" + app_opay_life_payment_sum_ng_h_sql_task(ds, v_date)

    logging.info('Executing: %s', _sql)

    # Execute in Hive
    hive_hook.run_cli(_sql)

    # Produce _SUCCESS
    cf.touchz_success()
Code Example #14
    def test_run_cli_with_hive_conf(self):
        hql = "set key;\n" \
              "set airflow.ctx.dag_id;\nset airflow.ctx.dag_run_id;\n" \
              "set airflow.ctx.task_id;\nset airflow.ctx.execution_date;\n"

        dag_id_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format']
        task_id_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format']
        execution_date_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][
                'env_var_format']
        dag_run_id_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][
                'env_var_format']
        os.environ[dag_id_ctx_var_name] = 'test_dag_id'
        os.environ[task_id_ctx_var_name] = 'test_task_id'
        os.environ[execution_date_ctx_var_name] = 'test_execution_date'
        os.environ[dag_run_id_ctx_var_name] = 'test_dag_run_id'

        hook = HiveCliHook()
        output = hook.run_cli(hql=hql, hive_conf={'key': 'value'})
        self.assertIn('value', output)
        self.assertIn('test_dag_id', output)
        self.assertIn('test_task_id', output)
        self.assertIn('test_execution_date', output)
        self.assertIn('test_dag_run_id', output)

        del os.environ[dag_id_ctx_var_name]
        del os.environ[task_id_ctx_var_name]
        del os.environ[execution_date_ctx_var_name]
        del os.environ[dag_run_id_ctx_var_name]
Code Example #15
    def test_load_df_with_data_types(self, mock_run_cli):
        d = OrderedDict()
        d['b'] = [True]
        d['i'] = [-1]
        d['t'] = [1]
        d['f'] = [0.0]
        d['c'] = ['c']
        d['M'] = [datetime.datetime(2018, 1, 1)]
        d['O'] = [object()]
        d['S'] = ['STRING'.encode('utf-8')]
        d['U'] = ['STRING']
        d['V'] = [None]
        df = pd.DataFrame(d)

        hook = HiveCliHook()
        hook.load_df(df, 't')

        query = """
            CREATE TABLE IF NOT EXISTS t (
                b BOOLEAN,
                i BIGINT,
                t BIGINT,
                f DOUBLE,
                c STRING,
                M TIMESTAMP,
                O STRING,
                S STRING,
                U STRING,
                V STRING)
            ROW FORMAT DELIMITED
            FIELDS TERMINATED BY ','
            STORED AS textfile
            ;
        """
        assertEqualIgnoreMultipleSpaces(self, mock_run_cli.call_args_list[0][0][0], query)
Code Example #16
    def test_run_cli_with_hive_conf(self):
        hql = "set key;\n" \
              "set airflow.ctx.dag_id;\nset airflow.ctx.dag_run_id;\n" \
              "set airflow.ctx.task_id;\nset airflow.ctx.execution_date;\n"

        dag_id_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format']
        task_id_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format']
        execution_date_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][
                'env_var_format']
        dag_run_id_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][
                'env_var_format']
        os.environ[dag_id_ctx_var_name] = 'test_dag_id'
        os.environ[task_id_ctx_var_name] = 'test_task_id'
        os.environ[execution_date_ctx_var_name] = 'test_execution_date'
        os.environ[dag_run_id_ctx_var_name] = 'test_dag_run_id'

        hook = HiveCliHook()
        output = hook.run_cli(hql=hql, hive_conf={'key': 'value'})
        self.assertIn('value', output)
        self.assertIn('test_dag_id', output)
        self.assertIn('test_task_id', output)
        self.assertIn('test_execution_date', output)
        self.assertIn('test_dag_run_id', output)

        del os.environ[dag_id_ctx_var_name]
        del os.environ[task_id_ctx_var_name]
        del os.environ[execution_date_ctx_var_name]
        del os.environ[dag_run_id_ctx_var_name]
Code Example #17
def get_driver_online_time(ds, **op_kwargs):
    dt = op_kwargs["ds_nodash"]
    conn = get_db_conn('timerange_conn_db')
    mcursor = conn.cursor()
    mcursor.execute(get_driver_id)
    result = mcursor.fetchone()
    conn.commit()
    mcursor.close()
    conn.close()
    processes = []
    max_driver_id = result[0]

    logging.info('max driver id %d', max_driver_id)
    id_list = [x for x in range(1, max_driver_id+1)]
    part_size = 1000
    index = 0
    manager = Manager()
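    # The Manager-backed list is shared with the worker processes spawned below.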
    rows = manager.list([])
    while index < max_driver_id:
        p = Process(target=get_driver_timerange,
                    args=(id_list[index:index + part_size], dt, rows))
        index += part_size
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
    if rows:
        query = """
            INSERT OVERWRITE TABLE oride_dw_ods.{tab_name} PARTITION (dt='{dt}')
            VALUES {value}
        """.format(dt=ds, value=','.join(rows),tab_name=table_name)
        logging.info('import_driver_online_time run sql:%s' % query)
        hive_hook = HiveCliHook()
        hive_hook.run_cli(query)
Code Example #18
    def execute(self, context):
        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

        self.log.info("Dumping MySQL query results to local file")
        conn = mysql.get_conn()
        cursor = conn.cursor()
        cursor.execute(self.sql)
        with NamedTemporaryFile("wb") as f:
            csv_writer = csv.writer(f, delimiter=self.delimiter, encoding="utf-8")
            field_dict = OrderedDict()
            for field in cursor.description:
                field_dict[field[0]] = self.type_map(field[1])
            csv_writer.writerows(cursor)
            f.flush()
            cursor.close()
            conn.close()
            self.log.info("Loading file into Hive")
            hive.load_file(
                f.name,
                self.hive_table,
                field_dict=field_dict,
                create=self.create,
                partition=self.partition,
                delimiter=self.delimiter,
                recreate=self.recreate,
                tblproperties=self.tblproperties)
Code Example #19
 def ddl(self):
     """
     Retrieve table ddl
     """
     table = request.args.get("table")
     sql = "SHOW CREATE TABLE {table};".format(table=table)
     hook = HiveCliHook(HIVE_CLI_CONN_ID)
     return hook.run_cli(sql)
Code Example #20
def dwd_oride_driver_call_record_mid(ds, **kargs):
    hive_hook = HiveCliHook()

    # Read the SQL
    _sql = dwd_oride_driver_call_record_mid_sql_task(ds)
    logging.info('Executing: %s', _sql)

    # Execute in Hive
    hive_hook.run_cli(_sql)
Code Example #21
def execution_data_task_id(ds, **kwargs):
    v_date = kwargs.get('v_execution_date')
    v_day = kwargs.get('v_execution_day')
    v_hour = kwargs.get('v_execution_hour')

    hive_hook = HiveCliHook()
    """
        #功能函数
        alter语句: alter_partition
        删除分区: delete_partition
        生产success: touchz_success

        #参数
        第一个参数true: 所有国家是否上线。false 没有
        第二个参数true: 数据目录是有country_code分区。false 没有
        第三个参数true: 数据有才生成_SUCCESS false 数据没有也生成_SUCCESS 

        #读取sql
        %_sql(ds,v_hour)

        第一个参数ds: 天级任务
        第二个参数v_hour: 小时级任务,需要使用

    """

    cf = CountriesPublicFrame("false", ds, db_name, table_name, hdfs_path,
                              "true", "true")

    v_info = [{
        "table": "oride_db.oride_data.data_user_estimate_records",
        "start_timeThour": "{v_day}T00".format(v_day=v_day),
        "end_dateThour": "{v_day}T23".format(v_day=v_day),
        "depend_dir": "oss://opay-datalake/oride_binlog"
    }]

    hcm = TaskHourSuccessCountMonitor(ds, v_info)

    hcm.HourSuccessCountMonitor()

    # Drop partition
    # cf.delete_partition()

    # Read the SQL
    _sql = "\n" + cf.alter_partition() + "\n" + dwd_oride_passenger_estimate_records_di_sql_task(ds)

    logging.info('Executing: %s', _sql)

    # Execute in Hive
    hive_hook.run_cli(_sql)

    # Data circuit breaker: the row count must not be 0
    # check_key_data_cnt_task(ds)

    # Produce _SUCCESS
    cf.touchz_success()
Code Example #22
    def test_load_file(self, mock_run_cli):
        filepath = "/path/to/input/file"
        table = "output_table"

        hook = HiveCliHook()
        hook.load_file(filepath=filepath, table=table, create=False)

        query = ("LOAD DATA LOCAL INPATH '{filepath}' "
                 "OVERWRITE INTO TABLE {table} ;\n".format(filepath=filepath,
                                                           table=table))
        mock_run_cli.assert_called_with(query)
Code Example #23
    def test_load_file(self, mock_run_cli):
        filepath = "/path/to/input/file"
        table = "output_table"

        hook = HiveCliHook()
        hook.load_file(filepath=filepath, table=table, create=False)

        query = ("LOAD DATA LOCAL INPATH '{filepath}' "
                 "OVERWRITE INTO TABLE {table} ;\n".format(filepath=filepath,
                                                           table=table))
        calls = [mock.call(';'), mock.call(query)]
        mock_run_cli.assert_has_calls(calls, any_order=True)
Code Example #24
    def execute(self, context):
        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.log.info("Extracting data from Hive")
        hive_table = 'druid.' + context['task_instance_key_str'].replace('.', '_')
        sql = self.sql.strip().strip(';')
        tblproperties = ''.join([", '{}' = '{}'"
                                .format(k, v)
                                 for k, v in self.hive_tblproperties.items()])
        hql = """\
        SET mapred.output.compress=false;
        SET hive.exec.compress.output=false;
        DROP TABLE IF EXISTS {hive_table};
        CREATE TABLE {hive_table}
        ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t'
        STORED AS TEXTFILE
        TBLPROPERTIES ('serialization.null.format' = ''{tblproperties})
        AS
        {sql}
        """.format(hive_table=hive_table, tblproperties=tblproperties, sql=sql)
        self.log.info("Running command:\n %s", hql)
        hive.run_cli(hql)

        m = HiveMetastoreHook(self.metastore_conn_id)

        # Get the Hive table and extract the columns
        t = m.get_table(hive_table)
        columns = [col.name for col in t.sd.cols]

        # Get the path on hdfs
        static_path = m.get_table(hive_table).sd.location

        schema, table = hive_table.split('.')

        druid = DruidHook(druid_ingest_conn_id=self.druid_ingest_conn_id)

        try:
            index_spec = self.construct_ingest_query(
                static_path=static_path,
                columns=columns,
            )

            self.log.info("Inserting rows into Druid, hdfs path: %s", static_path)

            druid.submit_indexing_job(index_spec)

            self.log.info("Load seems to have succeeded!")
        finally:
            self.log.info(
                "Cleaning up by dropping the temp Hive table %s",
                hive_table
            )
            hql = "DROP TABLE IF EXISTS {}".format(hive_table)
            hive.run_cli(hql)
Code Example #25
def execution_data_task_id(ds, **kwargs):
    v_date = kwargs.get('v_execution_date')
    v_day = kwargs.get('v_execution_day')
    v_hour = kwargs.get('v_execution_hour')

    hive_hook = HiveCliHook()
    """
        #功能函数
        alter语句: alter_partition
        删除分区: delete_partition
        生产success: touchz_success

        #参数
        第一个参数true: 数据目录是有country_code分区。false 没有
        第二个参数true: 数据有才生成_SUCCESS false 数据没有也生成_SUCCESS 

        #读取sql
        %_sql_task(ds,v_hour)

        第一个参数ds: 天级任务
        第二个参数v_hour: 小时级任务,需要使用

    """

    if datetime.strptime(ds, '%Y-%m-%d').weekday() == 6:
        cf = CountriesPublicFrame("false", ds, db_name, table_name, hdfs_path,
                                  "true", "false")
    else:
        cf = CountriesPublicFrame("false", ds, db_name, table_name, hdfs_path,
                                  "true", "true")

    # Drop partition
    # cf.delete_partition()

    # Assemble the SQL
    _sql = "\n" + cf.alter_partition() + "\n" + app_ocredit_phones_order_base_cube_w_sql_task(ds)

    logging.info('Executing: %s', _sql)

    # Execute in Hive
    hive_hook.run_cli(_sql)

    # Data circuit breaker: the row count must not be 0
    # check_key_data_cnt_task(ds)

    # Data circuit breaker
    # check_key_data_task(ds)

    # Produce _SUCCESS
    cf.touchz_success()
Code Example #26
    def test_load_file(self, mock_run_cli):
        filepath = "/path/to/input/file"
        table = "output_table"

        hook = HiveCliHook()
        hook.load_file(filepath=filepath, table=table, create=False)

        query = (
            "LOAD DATA LOCAL INPATH '{filepath}' "
            "OVERWRITE INTO TABLE {table} \n"
            .format(filepath=filepath, table=table)
        )
        mock_run_cli.assert_called_with(query)
Code Example #27
    def test_load_df_with_optional_parameters(self, mock_to_csv, mock_load_file):
        hook = HiveCliHook()
        b = (True, False)
        for create, recreate in itertools.product(b, b):
            mock_load_file.reset_mock()
            hook.load_df(df=pd.DataFrame({"c": range(0, 10)}),
                         table="t",
                         create=create,
                         recreate=recreate)

            mock_load_file.assert_called_once()
            kwargs = mock_load_file.call_args[1]
            self.assertEqual(kwargs["create"], create)
            self.assertEqual(kwargs["recreate"], recreate)
Code Example #28
    def test_get_proxy_user_value(self):
        from airflow.hooks.hive_hooks import HiveCliHook

        hook = HiveCliHook()
        returner = mock.MagicMock()
        returner.extra_dejson = {'proxy_user': 'a_user_proxy'}  # value was masked as '******' in the listing; restored from the assertion below
        hook.use_beeline = True
        hook.conn = returner

        # Run
        result = hook._prepare_cli_cmd()

        # Verify
        self.assertIn('hive.server2.proxy.user=a_user_proxy', result[2])
Code Example #29
    def test_load_df_with_optional_parameters(self, mock_to_csv, mock_load_file):
        hook = HiveCliHook()
        b = (True, False)
        for create, recreate in itertools.product(b, b):
            mock_load_file.reset_mock()
            hook.load_df(df=pd.DataFrame({"c": range(0, 10)}),
                         table="t",
                         create=create,
                         recreate=recreate)

            assert mock_load_file.call_count == 1
            kwargs = mock_load_file.call_args[1]
            self.assertEqual(kwargs["create"], create)
            self.assertEqual(kwargs["recreate"], recreate)
Code Example #30
def execution_data_task_id(ds, **kwargs):

    v_date = kwargs.get('v_execution_date')
    v_day = kwargs.get('v_execution_day')
    v_hour = kwargs.get('v_execution_hour')

    hive_hook = HiveCliHook()
    """
        #功能函数
        alter语句: alter_partition
        删除分区: delete_partition
        生产success: touchz_success

        #参数
        第一个参数true: 数据目录是有country_code分区。false 没有
        第二个参数true: 数据有才生成_SUCCESS false 数据没有也生成_SUCCESS 

        #读取sql
        %_sql_task(ds,v_hour)

        第一个参数ds: 天级任务
        第二个参数v_hour: 小时级任务,需要使用

    """

    cf = CountriesPublicFrame("false", ds, db_name, table_name, hdfs_path,
                              "true", "true")

    # Drop partition
    # cf.delete_partition()

    # Assemble the SQL
    _sql = "\n" + cf.alter_partition() + "\n" + dwd_oride_strategy_data_invite_df_sql_task(ds)

    logging.info('Executing: %s', _sql)

    # Execute in Hive
    hive_hook.run_cli(_sql)

    # Data circuit breaker: the row count must not be 0
    # check_key_data_cnt_task(ds)

    # Data circuit breaker
    # check_key_data_task(ds)

    # Produce _SUCCESS
    cf.touchz_success()
Code Example #31
File: mapr_tasks_dag.py  Project: verdyr/mapr-airflow
def query_hive(**kwargs):
    ti = kwargs['ti']
    # get sha of latest commit
    v1 = ti.xcom_pull(key=None, task_ids='get_last_commit_task')
    json_value = json.loads(v1)
    sha = json_value['sha']

    hive_cli = HiveCliHook()
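    # Note: interpolating `sha` straight into the HQL string is injection-prone; validating it first would be safer.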
    hql = "select * from mapr_music_updates where commit_sha = '" + sha + "';"
    latest_commit = hive_cli.run_cli(hql)

    changed = latest_commit.find(sha) == -1
    ti.xcom_push(key='sha', value=sha)
    ti.xcom_push(key='is_changed', value=changed)

    return 'reimport_dataset_task' if changed else 'skip_reimport_dataset_task'
Code Example #32
 def execute(self, context):
     self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
     self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
     logging.info("Downloading S3 file")
     if self.wildcard_match:
         if not self.s3.check_for_wildcard_key(self.s3_key):
             raise AirflowException("No key matches {0}".format(self.s3_key))
         s3_key_object = self.s3.get_wildcard_key(self.s3_key)
     else:
         if not self.s3.check_for_key(self.s3_key):
             raise AirflowException(
                 "The key {0} does not exists".format(self.s3_key))
         s3_key_object = self.s3.get_key(self.s3_key)
     with NamedTemporaryFile("w") as f:
         logging.info("Dumping S3 key {0} contents to local"
                      " file {1}".format(s3_key_object.key, f.name))
         s3_key_object.get_contents_to_file(f)
         f.flush()
         self.s3.connection.close()
         if not self.headers:
             logging.info("Loading file into Hive")
             self.hive.load_file(
                 f.name,
                 self.hive_table,
                 field_dict=self.field_dict,
                 create=self.create,
                 partition=self.partition,
                 delimiter=self.delimiter,
                 recreate=self.recreate)
         else:
             with open(f.name, 'r') as tmpf:
                 if self.check_headers:
                     header_l = tmpf.readline()
                     header_line = header_l.rstrip()
                     header_list = header_line.split(self.delimiter)
                     field_names = list(self.field_dict.keys())
                     test_field_match = [h1.lower() == h2.lower() for h1, h2
                                         in zip(header_list, field_names)]
                     if not all(test_field_match):
                         logging.warning("Headers do not match field names"
                                         "File headers:\n {header_list}\n"
                                         "Field names: \n {field_names}\n"
                                         "".format(**locals()))
                         raise AirflowException("Headers do not match the "
                                         "field_dict keys")
                 with NamedTemporaryFile("w") as f_no_headers:
                     tmpf.seek(0)
                     next(tmpf)
                     for line in tmpf:
                         f_no_headers.write(line)
                     f_no_headers.flush()
                     logging.info("Loading file without headers into Hive")
                     self.hive.load_file(
                         f_no_headers.name,
                         self.hive_table,
                         field_dict=self.field_dict,
                         create=self.create,
                         partition=self.partition,
                         delimiter=self.delimiter,
                         recreate=self.recreate)
Code Example #33
 def get_hook(self):
     return HiveCliHook(
                     hive_cli_conn_id=self.hive_cli_conn_id,
                     run_as=self.run_as,
                     mapred_queue=self.mapred_queue,
                     mapred_queue_priority=self.mapred_queue_priority,
                     mapred_job_name=self.mapred_job_name)
Code Example #34
def create_hive_external_table(db, table, conn, **op_kwargs):
    sqoopSchema = SqoopSchemaUpdate()
    response = sqoopSchema.update_hive_schema(
        hive_db=hive_db,
        hive_table=hive_table.format(bs=table),
        mysql_db=db,
        mysql_table=table,
        mysql_conn=conn
    )
    #if response:
    #    return True

    mysql_conn = get_db_conn(conn)
    mcursor = mysql_conn.cursor()
    sql = '''
        select 
            COLUMN_NAME, 
            DATA_TYPE, 
            COLUMN_COMMENT,
            COLUMN_TYPE 
        from information_schema.COLUMNS 
        where TABLE_SCHEMA='{db}' and 
            TABLE_NAME='{table}' 
        order by ORDINAL_POSITION
    '''.format(db=db, table=table)
    # logging.info(sql)
    mcursor.execute(sql)
    res = mcursor.fetchall()
    # logging.info(res)
    columns = []
    for (name, type, comment, co_type) in res:
        if type.upper() == 'DECIMAL':
            columns.append("`%s` %s comment '%s'" % (name, co_type.replace('unsigned', '').replace('signed', ''), comment))
        else:
            columns.append("`%s` %s comment '%s'" % (name, mysql_type_to_hive.get(type.upper(), 'string'), comment))
    mysql_conn.close()
    # SQL to create the Hive data table
    hql = ods_create_table_hql.format(
        db_name=hive_db,
        table_name=hive_table.format(bs=table),
        columns=",\n".join(columns),
        hdfs_path=hdfs_path.format(bs=table)
    )
    logging.info(hql)
    hive_hook = HiveCliHook()
    logging.info('Executing: %s', hql)
    hive_hook.run_cli(hql)
Code Example #35
def merge_pre_hi_with_full_data_task(hive_db, hive_h_his_table_name,
                                     hive_hi_table_name, mysql_db_name,
                                     mysql_table_name, mysql_conn,
                                     sqoop_temp_db_name, sqoop_table_name, pt,
                                     now_hour, pre_day, pre_hour_day, pre_hour,
                                     is_must_have_data, **kwargs):
    sqoopSchema = SqoopSchemaUpdate()

    hive_columns = sqoopSchema.get_hive_column_name(hive_db,
                                                    hive_h_his_table_name)
    mysql_columns = sqoopSchema.get_mysql_column_name(mysql_db_name,
                                                      mysql_table_name,
                                                      mysql_conn)
    pre_day_ms = int(time.mktime(time.strptime(pre_day, "%Y-%m-%d"))) * 1000

    hql = MERGE_HI_WITH_FULL_SQL.format(
        columns=',\n'.join(hive_columns),
        pt=pt,
        now_hour=now_hour,
        db_name=hive_db,
        mysql_db_name=mysql_db_name,
        hive_h_his_table_name=hive_h_his_table_name,
        hive_hi_table_name=hive_hi_table_name,
        mysql_table_name=mysql_table_name,
        pre_day_ms=pre_day_ms,
        mysql_columns=',\n'.join(mysql_columns),
        sqoop_temp_db_name=sqoop_temp_db_name,
        sqoop_table_name=sqoop_table_name)

    hive_hook = HiveCliHook()

    # Read the SQL
    logging.info('Executing: %s', hql)

    # Execute in Hive
    hive_hook.run_cli(hql)

    # Generate _SUCCESS
    """
    First argument "true": the data directory has a country_code partition; "false": it does not.
    Second argument "true": _SUCCESS is generated only when there is data; "false": _SUCCESS is generated even when there is no data.
    """
    TaskTouchzSuccess().countries_touchz_success(
        pt, hive_db, hive_h_his_table_name,
        H_HIS_OSS_PATH % hive_h_his_table_name, "false", is_must_have_data,
        now_hour)
Code Example #36
def execution_data_task_id(ds, **kargs):
    hive_hook = HiveCliHook()

    # Read the SQL
    _sql = dwm_oride_driver_act_w_sql_task(ds)

    logging.info('Executing: %s', _sql)

    # Execute in Hive
    hive_hook.run_cli(_sql)

    # Generate _SUCCESS
    """
    First argument "true": the data directory has a country_code partition; "false": it does not.
    Second argument "true": _SUCCESS is generated only when there is data; "false": _SUCCESS is generated even when there is no data.
    """
    TaskTouchzSuccess().countries_touchz_success('{pt}'.format(pt=airflow.macros.ds_add(ds, +6)), db_name, table_name, hdfs_path, "true", "true")
Code Example #37
def execution_data_task_id(ds, ds_nodash,  **kargs):
    hive_hook = HiveCliHook()

    # Read the SQL
    _sql = app_opay_active_user_report_w_sql_task(ds, ds_nodash)

    logging.info('Executing: %s', _sql)

    # Execute in Hive
    hive_hook.run_cli(_sql)

    # Generate _SUCCESS
    """
    First argument "true": the data directory has a country_code partition; "false": it does not.
    Second argument "true": _SUCCESS is generated only when there is data; "false": _SUCCESS is generated even when there is no data.
    """
    TaskTouchzSuccess().countries_touchz_success(ds, db_name, table_name, hdfs_path, "true", "true")
Code Example #38
def execution_data_task_id(ds, **kargs):
    hive_hook = HiveCliHook()

    # Read the SQL
    _sql = ods_sqoop_base_bd_agent_df_sql_task(ds)

    logging.info('Executing: %s', _sql)

    # Execute in Hive
    hive_hook.run_cli(_sql)

    # Generate _SUCCESS
    """
    First argument "true": the data directory has a country_code partition; "false": it does not.
    Second argument "true": _SUCCESS is generated only when there is data; "false": _SUCCESS is generated even when there is no data.
    """
    TaskTouchzSuccess().countries_touchz_success(ds, db_name, table_name, hdfs_path, "false", "true")
Code Example #39
def execution_act_driver_task(ds, **kargs):
    hive_hook = HiveCliHook()

    # Read the SQL
    _sql = app_oride_act_driver_cohort_w_sql_task(ds)

    # Execute in Hive
    hive_hook.run_cli(_sql)

    # Generate _SUCCESS
    """
    First argument "true": the data directory has a country_code partition; "false": it does not.
    Second argument "true": _SUCCESS is generated only when there is data; "false": _SUCCESS is generated even when there is no data.
    """
    pt = airflow.macros.ds_add(ds, +6)
    hdfs_path = get_table_info(3)[1]
    TaskTouchzSuccess().countries_touchz_success(pt, "oride_dw", get_table_info(3)[0], hdfs_path, "true", "true")
Code Example #40
    def test_load_df(self, mock_to_csv, mock_load_file):
        df = pd.DataFrame({"c": ["foo", "bar", "baz"]})
        table = "t"
        delimiter = ","
        encoding = "utf-8"

        hook = HiveCliHook()
        hook.load_df(df=df,
                     table=table,
                     delimiter=delimiter,
                     encoding=encoding)

        mock_to_csv.assert_called_once()
        kwargs = mock_to_csv.call_args[1]
        self.assertEqual(kwargs["header"], False)
        self.assertEqual(kwargs["index"], False)
        self.assertEqual(kwargs["sep"], delimiter)

        mock_load_file.assert_called_once()
        kwargs = mock_load_file.call_args[1]
        self.assertEqual(kwargs["delimiter"], delimiter)
        self.assertEqual(kwargs["field_dict"], {"c": u"STRING"})
        self.assertTrue(isinstance(kwargs["field_dict"], OrderedDict))
        self.assertEqual(kwargs["table"], table)
Code Example #41
class S3ToHiveTransfer(BaseOperator):
    """
    Moves data from S3 to Hive. The operator downloads a file from S3
    and stores it locally before loading it into a Hive table.
    If the ``create`` or ``recreate`` arguments are set to ``True``,
    ``CREATE TABLE`` and ``DROP TABLE`` statements are generated.
    Hive data types are inferred from the cursor's metadata.

    Note that the table generated in Hive uses ``STORED AS textfile``,
    which isn't the most efficient serialization format. If a
    large amount of data is loaded and/or if the table gets
    queried considerably, you may want to use this operator only to
    stage the data into a temporary table before loading it into its
    final destination using a ``HiveOperator``.

    :param s3_key: The key to be retrieved from S3
    :type s3_key: str
    :param field_dict: A dictionary of the fields name in the file
        as keys and their Hive types as values
    :type field_dict: dict
    :param hive_table: target Hive table, use dot notation to target a
        specific database
    :type hive_table: str
    :param create: whether to create the table if it doesn't exist
    :type create: bool
    :param recreate: whether to drop and recreate the table at every
        execution
    :type recreate: bool
    :param partition: target partition as a dict of partition columns
        and values
    :type partition: dict
    :param headers: whether the file contains column names on the first
        line
    :type headers: bool
    :param check_headers: whether the column names on the first line should be
        checked against the keys of field_dict
    :type check_headers: bool
    :param wildcard_match: whether the s3_key should be interpreted as a Unix
        wildcard pattern
    :type wildcard_match: bool
    :param delimiter: field delimiter in the file
    :type delimiter: str
    :param aws_conn_id: source s3 connection
    :type aws_conn_id: str
    :param hive_cli_conn_id: destination hive connection
    :type hive_cli_conn_id: str
    :param input_compressed: Boolean to determine if file decompression is
        required to process headers
    :type input_compressed: bool
    :param tblproperties: TBLPROPERTIES of the hive table being created
    :type tblproperties: dict
    """

    template_fields = ('s3_key', 'partition', 'hive_table')
    template_ext = ()
    ui_color = '#a0e08c'

    @apply_defaults
    def __init__(
            self,
            s3_key,
            field_dict,
            hive_table,
            delimiter=',',
            create=True,
            recreate=False,
            partition=None,
            headers=False,
            check_headers=False,
            wildcard_match=False,
            aws_conn_id='aws_default',
            hive_cli_conn_id='hive_cli_default',
            input_compressed=False,
            tblproperties=None,
            *args, **kwargs):
        super(S3ToHiveTransfer, self).__init__(*args, **kwargs)
        self.s3_key = s3_key
        self.field_dict = field_dict
        self.hive_table = hive_table
        self.delimiter = delimiter
        self.create = create
        self.recreate = recreate
        self.partition = partition
        self.headers = headers
        self.check_headers = check_headers
        self.wildcard_match = wildcard_match
        self.hive_cli_conn_id = hive_cli_conn_id
        self.aws_conn_id = aws_conn_id
        self.input_compressed = input_compressed
        self.tblproperties = tblproperties

        if (self.check_headers and
                not (self.field_dict is not None and self.headers)):
            raise AirflowException("To check_headers provide " +
                                   "field_dict and headers")

    def execute(self, context):
        # Downloading file from S3
        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id)
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.log.info("Downloading S3 file")

        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}"
                                       .format(self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException(
                    "The key {0} does not exists".format(self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)
        root, file_ext = os.path.splitext(s3_key_object.key)
        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
                NamedTemporaryFile(mode="wb",
                                   dir=tmp_dir,
                                   suffix=file_ext) as f:
            self.log.info("Dumping S3 key {0} contents to local file {1}"
                          .format(s3_key_object.key, f.name))
            s3_key_object.download_fileobj(f)
            f.flush()
            if not self.headers:
                self.log.info("Loading file %s into Hive", f.name)
                self.hive.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate,
                    tblproperties=self.tblproperties)
            else:
                # Decompressing file
                if self.input_compressed:
                    self.log.info("Uncompressing file %s", f.name)
                    fn_uncompressed = uncompress_file(f.name,
                                                      file_ext,
                                                      tmp_dir)
                    self.log.info("Uncompressed to %s", fn_uncompressed)
                    # uncompressed file available now so deleting
                    # compressed file to save disk space
                    f.close()
                else:
                    fn_uncompressed = f.name

                # Testing if header matches field_dict
                if self.check_headers:
                    self.log.info("Matching file header against field_dict")
                    header_list = self._get_top_row_as_list(fn_uncompressed)
                    if not self._match_headers(header_list):
                        raise AirflowException("Header check failed")

                # Deleting top header row
                self.log.info("Removing header from file %s", fn_uncompressed)
                headless_file = (
                    self._delete_top_row_and_compress(fn_uncompressed,
                                                      file_ext,
                                                      tmp_dir))
                self.log.info("Headless file %s", headless_file)
                self.log.info("Loading file %s into Hive", headless_file)
                self.hive.load_file(headless_file,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)

    def _get_top_row_as_list(self, file_name):
        with open(file_name, 'rt') as f:
            header_line = f.readline().strip()
            header_list = header_line.split(self.delimiter)
            return header_list

    def _match_headers(self, header_list):
        if not header_list:
            raise AirflowException("Unable to retrieve header row from file")
        field_names = self.field_dict.keys()
        if len(field_names) != len(header_list):
            self.log.warning("Headers count mismatch"
                              "File headers:\n {header_list}\n"
                              "Field names: \n {field_names}\n"
                              "".format(**locals()))
            return False
        test_field_match = [h1.lower() == h2.lower()
                            for h1, h2 in zip(header_list, field_names)]
        if not all(test_field_match):
            self.log.warning("Headers do not match field names"
                              "File headers:\n {header_list}\n"
                              "Field names: \n {field_names}\n"
                              "".format(**locals()))
            return False
        else:
            return True

    def _delete_top_row_and_compress(
            self,
            input_file_name,
            output_file_ext,
            dest_dir):
        # When output_file_ext is not defined, file is not compressed
        open_fn = open
        if output_file_ext.lower() == '.gz':
            open_fn = gzip.GzipFile
        elif output_file_ext.lower() == '.bz2':
            open_fn = bz2.BZ2File

        os_fh_output, fn_output = \
            tempfile.mkstemp(suffix=output_file_ext, dir=dest_dir)
        with open(input_file_name, 'rb') as f_in,\
                open_fn(fn_output, 'wb') as f_out:
            f_in.seek(0)
            next(f_in)
            for line in f_in:
                f_out.write(line)
        return fn_output
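The S3ToHiveTransfer docstring above lists the operator's parameters; the sketch below shows one hypothetical way to wire it into a DAG. The import path, DAG id, task id, S3 key, schema, and connection ids are illustrative assumptions, not taken from the listing.

# A minimal usage sketch for the operator defined in Code Example #41 (assumptions noted above).
from datetime import datetime

from airflow import DAG
from airflow.operators.s3_to_hive_operator import S3ToHiveTransfer  # assumed Airflow 1.10-era import path

dag = DAG(
    dag_id="example_s3_to_hive",        # hypothetical DAG id
    start_date=datetime(2020, 1, 1),
    schedule_interval="@daily",
)

stage_events = S3ToHiveTransfer(
    task_id="stage_events_to_hive",                               # hypothetical task id
    s3_key="s3://example-bucket/events/{{ ds }}/events.csv",      # hypothetical, templated S3 key
    field_dict={"event_id": "BIGINT", "event_name": "STRING"},    # hypothetical schema
    hive_table="staging.events",                                  # hypothetical target table
    partition={"ds": "{{ ds }}"},
    headers=True,
    check_headers=True,
    delimiter=",",
    recreate=True,
    aws_conn_id="aws_default",
    hive_cli_conn_id="hive_cli_default",
    tblproperties={"serialization.null.format": ""},
    dag=dag,
)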
Code Example #42
    def execute(self, context):
        # Downloading file from S3
        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id)
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.log.info("Downloading S3 file")

        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}"
                                       .format(self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException(
                    "The key {0} does not exists".format(self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)
        root, file_ext = os.path.splitext(s3_key_object.key)
        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
                NamedTemporaryFile(mode="wb",
                                   dir=tmp_dir,
                                   suffix=file_ext) as f:
            self.log.info("Dumping S3 key {0} contents to local file {1}"
                          .format(s3_key_object.key, f.name))
            s3_key_object.download_fileobj(f)
            f.flush()
            if not self.headers:
                self.log.info("Loading file %s into Hive", f.name)
                self.hive.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate,
                    tblproperties=self.tblproperties)
            else:
                # Decompressing file
                if self.input_compressed:
                    self.log.info("Uncompressing file %s", f.name)
                    fn_uncompressed = uncompress_file(f.name,
                                                      file_ext,
                                                      tmp_dir)
                    self.log.info("Uncompressed to %s", fn_uncompressed)
                    # uncompressed file available now so deleting
                    # compressed file to save disk space
                    f.close()
                else:
                    fn_uncompressed = f.name

                # Testing if header matches field_dict
                if self.check_headers:
                    self.log.info("Matching file header against field_dict")
                    header_list = self._get_top_row_as_list(fn_uncompressed)
                    if not self._match_headers(header_list):
                        raise AirflowException("Header check failed")

                # Deleting top header row
                self.log.info("Removing header from file %s", fn_uncompressed)
                headless_file = (
                    self._delete_top_row_and_compress(fn_uncompressed,
                                                      file_ext,
                                                      tmp_dir))
                self.log.info("Headless file %s", headless_file)
                self.log.info("Loading file %s into Hive", headless_file)
                self.hive.load_file(headless_file,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)
Code Example #43
 def test_run_cli(self):
     hook = HiveCliHook()
     hook.run_cli("SHOW DATABASES")
Code Example #44
class S3ToHiveTransfer(BaseOperator):
    """
    Moves data from S3 to Hive. The operator downloads a file from S3,
    stores the file locally before loading it into a Hive table.
    If the ``create`` or ``recreate`` arguments are set to ``True``,
    ``CREATE TABLE`` and ``DROP TABLE`` statements are generated.
    Hive column types are taken from the ``field_dict`` argument.

    Note that the table generated in Hive uses ``STORED AS textfile``
    which isn't the most efficient serialization format. If a
    large amount of data is loaded and/or if the table gets
    queried considerably, you may want to use this operator only to
    stage the data into a temporary table before loading it into its
    final destination using a ``HiveOperator``.

    :param s3_key: The key to be retrieved from S3
    :type s3_key: str
    :param field_dict: A dictionary of the field names in the file
        as keys and their Hive types as values
    :type field_dict: dict
    :param hive_table: target Hive table, use dot notation to target a
        specific database
    :type hive_table: str
    :param create: whether to create the table if it doesn't exist
    :type create: bool
    :param recreate: whether to drop and recreate the table at every
        execution
    :type recreate: bool
    :param partition: target partition as a dict of partition columns
        and values
    :type partition: dict
    :param headers: whether the file contains column names on the first
        line
    :type headers: bool
    :param check_headers: whether the column names on the first line should be
        checked against the keys of field_dict
    :type check_headers: bool
    :param wildcard_match: whether the s3_key should be interpreted as a Unix
        wildcard pattern
    :type wildcard_match: bool
    :param delimiter: field delimiter in the file
    :type delimiter: str
    :param s3_conn_id: source s3 connection
    :type s3_conn_id: str
    :param hive_cli_conn_id: destination hive connection
    :type hive_cli_conn_id: str
    """

    template_fields = ('s3_key', 'partition', 'hive_table')
    template_ext = ()
    ui_color = '#a0e08c'

    @apply_defaults
    def __init__(
            self,
            s3_key,
            field_dict,
            hive_table,
            delimiter=',',
            create=True,
            recreate=False,
            partition=None,
            headers=False,
            check_headers=False,
            wildcard_match=False,
            s3_conn_id='s3_default',
            hive_cli_conn_id='hive_cli_default',
            *args, **kwargs):
        super(S3ToHiveTransfer, self).__init__(*args, **kwargs)
        self.s3_key = s3_key
        self.field_dict = field_dict
        self.hive_table = hive_table
        self.delimiter = delimiter
        self.create = create
        self.recreate = recreate
        self.partition = partition
        self.headers = headers
        self.check_headers = check_headers
        self.wildcard_match = wildcard_match
        self.hive_cli_conn_id = hive_cli_conn_id
        self.s3_conn_id = s3_conn_id

    def execute(self, context):
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
        logging.info("Downloading S3 file")
        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}".format(self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException(
                    "The key {0} does not exists".format(self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)
        with NamedTemporaryFile("w") as f:
            logging.info("Dumping S3 key {0} contents to local"
                         " file {1}".format(s3_key_object.key, f.name))
            s3_key_object.get_contents_to_file(f)
            f.flush()
            self.s3.connection.close()
            if not self.headers:
                logging.info("Loading file into Hive")
                self.hive.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate)
            else:
                with open(f.name, 'r') as tmpf:
                    if self.check_headers:
                        header_l = tmpf.readline()
                        header_line = header_l.rstrip()
                        header_list = header_line.split(self.delimiter)
                        field_names = list(self.field_dict.keys())
                        test_field_match = [h1.lower() == h2.lower() for h1, h2
                                            in zip(header_list, field_names)]
                        if not all(test_field_match):
                            logging.warning("Headers do not match field names"
                                            "File headers:\n {header_list}\n"
                                            "Field names: \n {field_names}\n"
                                            "".format(**locals()))
                            raise AirflowException("Headers do not match the "
                                            "field_dict keys")
                    with NamedTemporaryFile("w") as f_no_headers:
                        tmpf.seek(0)
                        next(tmpf)
                        for line in tmpf:
                            f_no_headers.write(line)
                        f_no_headers.flush()
                        logging.info("Loading file without headers into Hive")
                        self.hive.load_file(
                            f_no_headers.name,
                            self.hive_table,
                            field_dict=self.field_dict,
                            create=self.create,
                            partition=self.partition,
                            delimiter=self.delimiter,
                            recreate=self.recreate)
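For context, here is a minimal sketch of how the S3ToHiveTransfer operator from code example #44 might be used inside a DAG. The DAG id, schedule, S3 key, Hive table and field_dict contents are illustrative assumptions, the import path assumes the Airflow 1.x module layout, and only constructor arguments that appear in the __init__ above are passed.

from datetime import datetime

from airflow import DAG
# Assumed Airflow 1.x location of the operator shown in code example #44.
from airflow.operators.s3_to_hive_operator import S3ToHiveTransfer

dag = DAG(
    dag_id="example_s3_to_hive",        # hypothetical DAG id
    start_date=datetime(2020, 1, 1),
    schedule_interval="@daily",
)

stage_events = S3ToHiveTransfer(
    task_id="stage_events",
    s3_key="data/events/{{ ds }}.csv",  # hypothetical S3 key (templated)
    field_dict={"event_id": "BIGINT",   # column name -> Hive type
                "event_ts": "STRING",
                "payload": "STRING"},
    hive_table="staging.events",        # hypothetical target table
    headers=True,
    check_headers=True,
    delimiter=",",
    recreate=True,
    s3_conn_id="s3_default",
    hive_cli_conn_id="hive_cli_default",
    dag=dag,
)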
コード例 #45
0
ファイル: main.py プロジェクト: AdamUnger/incubator-airflow
    def ddl(self):
        table = request.args.get("table")
        sql = "SHOW CREATE TABLE {table};".format(table=table)
        h = HiveCliHook(HIVE_CLI_CONN_ID)
        return h.run_cli(sql)
コード例 #46
0
    def execute(self, context):
        # Downloading file from S3
        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.log.info("Downloading S3 file")

        if self.wildcard_match:
            if not self.s3.check_for_wildcard_key(self.s3_key):
                raise AirflowException("No key matches {0}"
                                       .format(self.s3_key))
            s3_key_object = self.s3.get_wildcard_key(self.s3_key)
        else:
            if not self.s3.check_for_key(self.s3_key):
                raise AirflowException(
                    "The key {0} does not exists".format(self.s3_key))
            s3_key_object = self.s3.get_key(self.s3_key)

        root, file_ext = os.path.splitext(s3_key_object.key)
        if (self.select_expression and self.input_compressed and
                file_ext.lower() != '.gz'):
            raise AirflowException("GZIP is the only compression " +
                                   "format Amazon S3 Select supports")

        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
                NamedTemporaryFile(mode="wb",
                                   dir=tmp_dir,
                                   suffix=file_ext) as f:
            self.log.info(
                "Dumping S3 key %s contents to local file %s", s3_key_object.key, f.name
            )
            if self.select_expression:
                option = {}
                if self.headers:
                    option['FileHeaderInfo'] = 'USE'
                if self.delimiter:
                    option['FieldDelimiter'] = self.delimiter

                input_serialization = {'CSV': option}
                if self.input_compressed:
                    input_serialization['CompressionType'] = 'GZIP'

                content = self.s3.select_key(
                    bucket_name=s3_key_object.bucket_name,
                    key=s3_key_object.key,
                    expression=self.select_expression,
                    input_serialization=input_serialization
                )
                f.write(content.encode("utf-8"))
            else:
                s3_key_object.download_fileobj(f)
            f.flush()

            if self.select_expression or not self.headers:
                self.log.info("Loading file %s into Hive", f.name)
                self.hive.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate,
                    tblproperties=self.tblproperties)
            else:
                # Decompressing file
                if self.input_compressed:
                    self.log.info("Uncompressing file %s", f.name)
                    fn_uncompressed = uncompress_file(f.name,
                                                      file_ext,
                                                      tmp_dir)
                    self.log.info("Uncompressed to %s", fn_uncompressed)
                    # uncompressed file available now so deleting
                    # compressed file to save disk space
                    f.close()
                else:
                    fn_uncompressed = f.name

                # Testing if header matches field_dict
                if self.check_headers:
                    self.log.info("Matching file header against field_dict")
                    header_list = self._get_top_row_as_list(fn_uncompressed)
                    if not self._match_headers(header_list):
                        raise AirflowException("Header check failed")

                # Deleting top header row
                self.log.info("Removing header from file %s", fn_uncompressed)
                headless_file = (
                    self._delete_top_row_and_compress(fn_uncompressed,
                                                      file_ext,
                                                      tmp_dir))
                self.log.info("Headless file %s", headless_file)
                self.log.info("Loading file %s into Hive", headless_file)
                self.hive.load_file(headless_file,
                                    self.hive_table,
                                    field_dict=self.field_dict,
                                    create=self.create,
                                    partition=self.partition,
                                    delimiter=self.delimiter,
                                    recreate=self.recreate,
                                    tblproperties=self.tblproperties)
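Code example #46 routes the select_expression path through S3Hook.select_key, using the input_serialization dictionary built in the method above. The following standalone sketch mirrors that shape for a gzipped CSV whose first line is a header row; the connection id, bucket, key and query string are illustrative assumptions.

# Standalone sketch of the S3 Select call; the names below are assumptions.
from airflow.hooks.S3_hook import S3Hook  # Airflow 1.x module path (assumed)

s3 = S3Hook(aws_conn_id="aws_default")

# Same structure the operator assembles: CSV options nested under 'CSV',
# CompressionType at the top level of input_serialization.
input_serialization = {
    "CSV": {"FileHeaderInfo": "USE", "FieldDelimiter": ","},
    "CompressionType": "GZIP",
}

content = s3.select_key(
    key="data/events/2020-01-01.csv.gz",    # hypothetical S3 key
    bucket_name="example-bucket",           # hypothetical bucket
    expression="SELECT * FROM S3Object s",  # hypothetical S3 Select query
    input_serialization=input_serialization,
)
print(content[:200])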