Example #1
 def test_to_csv_assertlogs(self):
     hook = HiveServer2Hook()
     query = "SELECT * FROM {}".format(self.table)
     csv_filepath = 'query_results.csv'
     with self.assertLogs() as cm:
         hook.to_csv(query,
                     csv_filepath,
                     schema=self.database,
                     delimiter=',',
                     lineterminator='\n',
                     output_header=True,
                     fetch_size=2)
         df = pd.read_csv(csv_filepath, sep=',')
         self.assertListEqual(df.columns.tolist(), self.columns)
         self.assertListEqual(df[self.columns[0]].values.tolist(), [1, 2])
         self.assertEqual(len(df), 2)
         self.assertIn(
             'INFO:airflow.hooks.hive_hooks.HiveServer2Hook:'
             'Written 2 rows so far.', cm.output)
Example #2
    def test_get_conn_with_password(self, mock_connect):
        from airflow.hooks.base_hook import CONN_ENV_PREFIX
        conn_id = "conn_with_password"
        conn_env = CONN_ENV_PREFIX + conn_id.upper()
        conn_value = os.environ.get(conn_env)
        os.environ[
            conn_env] = "jdbc+hive2://conn_id:conn_pass@localhost:10000/default?authMechanism=LDAP"

        HiveServer2Hook(hiveserver2_conn_id=conn_id).get_conn()
        mock_connect.assert_called_once_with(host='localhost',
                                             port=10000,
                                             auth='LDAP',
                                             kerberos_service_name=None,
                                             username='******',
                                             password='******',
                                             database='default')

        if conn_value:
            os.environ[conn_env] = conn_value
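
The same AIRFLOW_CONN_ environment-variable mechanism works outside of tests; a minimal sketch, with a hypothetical connection id and placeholder credentials:

    import os
    from airflow.hooks.hive_hooks import HiveServer2Hook

    # Airflow resolves connections from AIRFLOW_CONN_<CONN_ID> environment variables
    os.environ['AIRFLOW_CONN_MY_HIVE'] = (
        "jdbc+hive2://user:password@hiveserver:10000/default?authMechanism=LDAP")

    hook = HiveServer2Hook(hiveserver2_conn_id='my_hive')
    records = hook.get_records("SELECT 1")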
Example #3
    def get_records(self):
        """Executes a query to obtain a count of records on a table

        Returns:
            int -- quantity of records from a count query
        """

        if self.query_engine == 'hive':
            hook = HiveServer2Hook(self.query_engine_conn_id)

        elif self.query_engine == 'presto':
            hook = PrestoHook(self.query_engine_conn_id)

        else:
            raise ValueError(
                "Unsupported query engine: {}".format(self.query_engine))

        # executes query against Hive or Presto
        res = hook.get_records(self.records_query)

        if len(res) > 1:
            # a count query is expected to return exactly one row
            raise ValueError(
                "Expected a single row, got {} rows".format(len(res)))
        return res[0]
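
The same count check can be written directly against the hook; a minimal standalone sketch, with a placeholder connection id and table name:

    from airflow.hooks.hive_hooks import HiveServer2Hook

    hook = HiveServer2Hook(hiveserver2_conn_id='hiveserver2_default')
    # a COUNT(*) query returns exactly one row with a single column
    rows = hook.get_records("SELECT COUNT(*) FROM my_db.my_table")
    record_count = rows[0][0]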
Example #4
    def execute(self, context):
        hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)

        logging.info('Extracting data from Hive')
        logging.info(self.hql)

        data = hive.get_pandas_df(self.hql, schema=self.schema)
        gcp_hook = GoogleCloudStorageHook(google_cloud_storage_conn_id=self.google_cloud_storage_conn_id)
        logging.info('Inserting rows onto google cloud storage')

        # open in text mode so the JSON strings below can be written directly
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', prefix='tmp') as tmp_file:
            data = data.to_json(orient='records')
            recs = json.loads(data)
            for record in recs:
                tmp_file.write(json.dumps(record))
                tmp_file.write("\n")
            tmp_file.flush()

            remote_file_name = self.file_pattern.format('aa')
            remote_name = os.path.join(self.subdir, remote_file_name)
            gcp_hook.upload(self.bucket, remote_name, tmp_file.name)

        logging.info('Done.')
Example #5
    def test_get_results_with_hive_conf(self):
        hql = ["set key",
               "set airflow.ctx.dag_id",
               "set airflow.ctx.dag_run_id",
               "set airflow.ctx.task_id",
               "set airflow.ctx.execution_date"]

        dag_id_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format']
        task_id_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format']
        execution_date_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][
                'env_var_format']
        dag_run_id_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][
                'env_var_format']
        os.environ[dag_id_ctx_var_name] = 'test_dag_id'
        os.environ[task_id_ctx_var_name] = 'test_task_id'
        os.environ[execution_date_ctx_var_name] = 'test_execution_date'
        os.environ[dag_run_id_ctx_var_name] = 'test_dag_run_id'

        hook = HiveServer2Hook()
        output = '\n'.join(res_tuple[0]
                           for res_tuple
                           in hook.get_results(hql=hql,
                                               hive_conf={'key': 'value'})['data'])
        self.assertIn('value', output)
        self.assertIn('test_dag_id', output)
        self.assertIn('test_task_id', output)
        self.assertIn('test_execution_date', output)
        self.assertIn('test_dag_run_id', output)

        del os.environ[dag_id_ctx_var_name]
        del os.environ[task_id_ctx_var_name]
        del os.environ[execution_date_ctx_var_name]
        del os.environ[dag_run_id_ctx_var_name]
Example #6
    def execute(self, context):
        hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)

        self.log.info("Extracting data from Hive: %s", self.sql)
        hive_conf = context_to_airflow_vars(context)
        if self.hive_conf:
            hive_conf.update(self.hive_conf)
        if self.bulk_load:
            tmp_file = NamedTemporaryFile()
            hive.to_csv(self.sql,
                        tmp_file.name,
                        delimiter='\t',
                        lineterminator='\n',
                        output_header=False,
                        hive_conf=hive_conf)
        else:
            hive_results = hive.get_records(self.sql, hive_conf=hive_conf)

        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

        if self.mysql_preoperator:
            self.log.info("Running MySQL preoperator")
            mysql.run(self.mysql_preoperator)

        self.log.info("Inserting rows into MySQL")
        if self.bulk_load:
            mysql.bulk_load(table=self.mysql_table, tmp_file=tmp_file.name)
            tmp_file.close()
        else:
            mysql.insert_rows(table=self.mysql_table, rows=hive_results)

        if self.mysql_postoperator:
            self.log.info("Running MySQL postoperator")
            mysql.run(self.mysql_postoperator)

        self.log.info("Done.")
Example #7
    def execute(self, context):
        hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)

        self.log.info('Extracting data from Hive')
        self.log.info(self.sql)

        data = hive.get_pandas_df(self.sql, schema=self.schema)
        dynamodb = AwsDynamoDBHook(aws_conn_id=self.aws_conn_id,
                                   table_name=self.table_name,
                                   table_keys=self.table_keys,
                                   region_name=self.region_name)

        self.log.info('Inserting rows into dynamodb')

        if self.pre_process is None:
            dynamodb.write_batch_data(
                json.loads(data.to_json(orient='records')))
        else:
            dynamodb.write_batch_data(
                self.pre_process(data=data,
                                 args=self.pre_process_args,
                                 kwargs=self.pre_process_kwargs))

        self.log.info('Done.')
Example #8
def add_partition(v_execution_date, v_execution_day, v_execution_hour, db_name, table_name, conn_id, hive_table_name,
                  server_name, hive_db, is_must_have_data, **kwargs):
    # generate the _SUCCESS marker file
    """
    First flag "true": the data directory is partitioned by country_code; "false": it is not.
    Second flag "true": generate _SUCCESS only when data exists; "false": generate _SUCCESS even when there is no data.
    """
    TaskTouchzSuccess().countries_touchz_success(
        v_execution_day,
        hive_db,
        hive_table_name,
        ALL_HI_OSS_PATH % hive_table_name,
        "false", is_must_have_data, v_execution_hour)

    sql = '''
            ALTER TABLE {hive_db}.{table} ADD IF NOT EXISTS PARTITION (dt = '{ds}', hour = '{hour}')
        '''.format(hive_db=hive_db, table=hive_table_name, ds=v_execution_day, hour=v_execution_hour)

    hive2_conn = HiveServer2Hook().get_conn()
    cursor = hive2_conn.cursor()
    cursor.execute(sql)

    return
Example #9
 def test_get_records(self):
     hook = HiveServer2Hook()
     query = "SELECT * FROM {}".format(self.table)
     results = hook.get_pandas_df(query, schema=self.database)
     self.assertEqual(len(results), 2)
Example #10
def run_check_table(schema_table_db_name, schema_table_name,
                    target_table_db_name, target_table_name, conn_id,
                    hive_table_name, server_name, **kwargs):
    # SHOW TABLES in oride_db LIKE 'data_aa'
    check_sql = 'SHOW TABLES in %s LIKE \'%s\'' % (HIVE_DB, hive_table_name)
    hive2_conn = HiveServer2Hook().get_conn()
    cursor = hive2_conn.cursor()
    cursor.execute(check_sql)
    if len(cursor.fetchall()) == 0:
        logging.info('Create Hive Table: %s.%s', HIVE_DB, hive_table_name)
        # get table column
        column_sql = '''
                SELECT
                    COLUMN_NAME,
                    DATA_TYPE,
                    NUMERIC_PRECISION,
                    NUMERIC_SCALE,
                    COLUMN_COMMENT
                FROM
                    information_schema.columns
                WHERE
                    table_schema='{db_name}' and table_name='{table_name}'
            '''.format(db_name=schema_table_db_name,
                       table_name=schema_table_name)
        mysql_hook = MySqlHook(conn_id)
        mysql_conn = mysql_hook.get_conn()
        mysql_cursor = mysql_conn.cursor()
        mysql_cursor.execute(column_sql)
        results = mysql_cursor.fetchall()
        rows = []
        for result in results:
            if result[0] == 'dt':
                col_name = '_dt'
            else:
                col_name = result[0]
            if result[1] in ('timestamp', 'varchar', 'char', 'text',
                             'longtext', 'mediumtext', 'json', 'datetime'):
                data_type = 'string'
            elif result[1] == 'decimal':
                data_type = result[1] + "(" + str(result[2]) + "," + str(
                    result[3]) + ")"
            else:
                data_type = result[1]
            rows.append("`%s` %s comment '%s'" %
                        (col_name, data_type, str(result[4]).replace(
                            '\n', '').replace('\r', '')))
        mysql_conn.close()

        # hive create table
        hive_hook = HiveCliHook()
        sql = ODS_CREATE_TABLE_SQL.format(
            db_name=HIVE_DB,
            table_name=hive_table_name,
            columns=",\n".join(rows),
            oss_path=OSS_PATH % ("{server_name}.{db_name}.{table_name}".format(
                server_name=server_name,
                db_name=target_table_db_name,
                table_name=target_table_name)))
        logging.info('Executing: %s', sql)
        hive_hook.run_cli(sql)

    else:
        sqoopSchema = SqoopSchemaUpdate()
        response = sqoopSchema.append_hive_schema(
            hive_db=HIVE_DB,
            hive_table=hive_table_name,
            mysql_db=schema_table_db_name,
            mysql_table=schema_table_name,
            mysql_conn=conn_id,
            oss_path=OSS_PATH % ("{server_name}.{db_name}.{table_name}".format(
                server_name=server_name,
                db_name=target_table_db_name,
                table_name=target_table_name)))
        if response:
            return True
    return
Example #11
 def test_select_conn(self):
     from airflow.hooks.hive_hooks import HiveServer2Hook
     sql = "select 1"
     hook = HiveServer2Hook()
     hook.get_records(sql)
Example #12
 def test_get_results_header(self):
     hook = HiveServer2Hook()
     query = "SELECT * FROM {}".format(self.table)
     results = hook.get_results(query, schema=self.database)
     self.assertListEqual([col[0] for col in results['header']],
                          self.columns)
Example #13
 def test_get_results_data(self):
     hook = HiveServer2Hook()
     query = "SELECT * FROM {}".format(self.table)
     results = hook.get_results(query, schema=self.database)
     self.assertListEqual(results['data'], [(1, 1), (2, 2)])
Example #14
 def test_get_conn(self):
     hook = HiveServer2Hook()
     hook.get_conn()
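
For completeness, a small sketch of working with the raw connection that get_conn() returns, mirroring the cursor usage in the earlier add_partition and run_check_table examples:

    from airflow.hooks.hive_hooks import HiveServer2Hook

    hook = HiveServer2Hook()
    conn = hook.get_conn()
    cursor = conn.cursor()
    cursor.execute("SELECT 1")
    rows = cursor.fetchall()
    cursor.close()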
Example #15
 def get_hook(self):
     if self.conn_type == 'mysql':
         from airflow.hooks.mysql_hook import MySqlHook
         return MySqlHook(mysql_conn_id=self.conn_id)
     elif self.conn_type == 'google_cloud_platform':
         from airflow.gcp.hooks.bigquery import BigQueryHook
         return BigQueryHook(bigquery_conn_id=self.conn_id)
     elif self.conn_type == 'postgres':
         from airflow.hooks.postgres_hook import PostgresHook
         return PostgresHook(postgres_conn_id=self.conn_id)
     elif self.conn_type == 'pig_cli':
         from airflow.hooks.pig_hook import PigCliHook
         return PigCliHook(pig_cli_conn_id=self.conn_id)
     elif self.conn_type == 'hive_cli':
         from airflow.hooks.hive_hooks import HiveCliHook
         return HiveCliHook(hive_cli_conn_id=self.conn_id)
     elif self.conn_type == 'presto':
         from airflow.hooks.presto_hook import PrestoHook
         return PrestoHook(presto_conn_id=self.conn_id)
     elif self.conn_type == 'hiveserver2':
         from airflow.hooks.hive_hooks import HiveServer2Hook
         return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
     elif self.conn_type == 'sqlite':
         from airflow.hooks.sqlite_hook import SqliteHook
         return SqliteHook(sqlite_conn_id=self.conn_id)
     elif self.conn_type == 'jdbc':
         from airflow.hooks.jdbc_hook import JdbcHook
         return JdbcHook(jdbc_conn_id=self.conn_id)
     elif self.conn_type == 'mssql':
         from airflow.hooks.mssql_hook import MsSqlHook
         return MsSqlHook(mssql_conn_id=self.conn_id)
     elif self.conn_type == 'oracle':
         from airflow.hooks.oracle_hook import OracleHook
         return OracleHook(oracle_conn_id=self.conn_id)
     elif self.conn_type == 'vertica':
         from airflow.contrib.hooks.vertica_hook import VerticaHook
         return VerticaHook(vertica_conn_id=self.conn_id)
     elif self.conn_type == 'cloudant':
         from airflow.contrib.hooks.cloudant_hook import CloudantHook
         return CloudantHook(cloudant_conn_id=self.conn_id)
     elif self.conn_type == 'jira':
         from airflow.contrib.hooks.jira_hook import JiraHook
         return JiraHook(jira_conn_id=self.conn_id)
     elif self.conn_type == 'redis':
         from airflow.contrib.hooks.redis_hook import RedisHook
         return RedisHook(redis_conn_id=self.conn_id)
     elif self.conn_type == 'wasb':
         from airflow.contrib.hooks.wasb_hook import WasbHook
         return WasbHook(wasb_conn_id=self.conn_id)
     elif self.conn_type == 'docker':
         from airflow.hooks.docker_hook import DockerHook
         return DockerHook(docker_conn_id=self.conn_id)
     elif self.conn_type == 'azure_data_lake':
         from airflow.contrib.hooks.azure_data_lake_hook import AzureDataLakeHook
         return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
     elif self.conn_type == 'azure_cosmos':
         from airflow.contrib.hooks.azure_cosmos_hook import AzureCosmosDBHook
         return AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
     elif self.conn_type == 'cassandra':
         from airflow.contrib.hooks.cassandra_hook import CassandraHook
         return CassandraHook(cassandra_conn_id=self.conn_id)
     elif self.conn_type == 'mongo':
         from airflow.contrib.hooks.mongo_hook import MongoHook
         return MongoHook(conn_id=self.conn_id)
     elif self.conn_type == 'gcpcloudsql':
         from airflow.gcp.hooks.cloud_sql import CloudSqlDatabaseHook
         return CloudSqlDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
     elif self.conn_type == 'grpc':
         from airflow.contrib.hooks.grpc_hook import GrpcHook
         return GrpcHook(grpc_conn_id=self.conn_id)
     raise AirflowException("Unknown hook type {}".format(self.conn_type))
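
A minimal usage sketch of this dispatcher, assuming a connection whose conn_id is 'hiveserver2_default' and whose conn_type is 'hiveserver2' (Airflow 1.10-style import paths):

    from airflow.hooks.base_hook import BaseHook

    conn = BaseHook.get_connection('hiveserver2_default')
    hook = conn.get_hook()  # dispatches on conn_type; returns a HiveServer2Hook here
    records = hook.get_records("SELECT 1")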
Example #16
 def test_to_csv(self):
     from airflow.hooks.hive_hooks import HiveServer2Hook
     sql = "select 1"
     hook = HiveServer2Hook()
     hook.to_csv(hql=sql, csv_filepath="/tmp/test_to_csv")
Example #17
    def test_mysql_to_hive_verify_loaded_values(self):
        mysql_conn_id = 'airflow_ci'
        mysql_table = 'test_mysql_to_hive'
        hive_table = 'test_mysql_to_hive'

        from airflow.hooks.mysql_hook import MySqlHook
        m = MySqlHook(mysql_conn_id)

        try:
            minmax = (
                255,
                65535,
                16777215,
                4294967295,
                18446744073709551615,
                -128,
                -32768,
                -8388608,
                -2147483648,
                -9223372036854775808
            )

            with m.get_conn() as c:
                c.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
                c.execute("""
                    CREATE TABLE {} (
                        c0 TINYINT   UNSIGNED,
                        c1 SMALLINT  UNSIGNED,
                        c2 MEDIUMINT UNSIGNED,
                        c3 INT       UNSIGNED,
                        c4 BIGINT    UNSIGNED,
                        c5 TINYINT,
                        c6 SMALLINT,
                        c7 MEDIUMINT,
                        c8 INT,
                        c9 BIGINT
                    )
                """.format(mysql_table))
                c.execute("""
                    INSERT INTO {} VALUES (
                        {}, {}, {}, {}, {}, {}, {}, {}, {}, {}
                    )
                """.format(mysql_table, *minmax))

            from airflow.operators.mysql_to_hive import MySqlToHiveTransfer
            t = MySqlToHiveTransfer(
                task_id='test_m2h',
                mysql_conn_id=mysql_conn_id,
                hive_cli_conn_id='beeline_default',
                sql="SELECT * FROM {}".format(mysql_table),
                hive_table=hive_table,
                recreate=True,
                delimiter=",",
                dag=self.dag)
            t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

            from airflow.hooks.hive_hooks import HiveServer2Hook
            h = HiveServer2Hook()
            r = h.get_records("SELECT * FROM {}".format(hive_table))
            self.assertEqual(r[0], minmax)
        finally:
            with m.get_conn() as c:
                c.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
Example #18
 def get_hook(self):
     return HiveServer2Hook(
         hiveserver2_conn_id=self.hive_cli_conn_id)