Example #1
        def test_get_results_with_schema(self):
            from airflow.hooks.hive_hooks import HiveServer2Hook
            from unittest.mock import MagicMock

            # Configure
            sql = "select 1"
            schema = "notdefault"
            hook = HiveServer2Hook()
            cursor_mock = MagicMock()
            cursor_mock.__enter__.return_value = cursor_mock
            cursor_mock.execute.return_value = None
            cursor_mock.fetchall.return_value = []
            get_conn_mock = MagicMock()
            get_conn_mock.return_value.__enter__.return_value = get_conn_mock.return_value
            get_conn_mock.return_value.cursor.return_value = cursor_mock
            hook.get_conn = get_conn_mock

            # Run
            hook.get_results(sql, schema)

            # Verify
            get_conn_mock.assert_called_with(schema)
Example #2
 def test_get_pandas_df(self):
     hook = HiveServer2Hook()
     query = "SELECT * FROM {}".format(self.table)
     df = hook.get_pandas_df(query, schema=self.database)
     self.assertEqual(len(df), 2)
     self.assertListEqual(df.columns.tolist(), self.columns)
     self.assertListEqual(df[self.columns[0]].values.tolist(), [1, 2])
Example #3
def add_partition(v_execution_date, v_execution_day, v_execution_hour,
                  target_table_db_name, target_table_name, conn_id,
                  hive_table_name, server_name, hive_db, is_must_have_data,
                  **kwargs):
    # Generate the _SUCCESS marker
    """
    First flag "true": the data directory has a country_code partition; "false": it does not.
    Second flag "true": _SUCCESS is written only when data exists; "false": _SUCCESS is written even when there is no data.
    """
    TaskTouchzSuccess().countries_touchz_success(
        v_execution_day, hive_db, hive_table_name,
        OSS_PATH % ("{server_name}.{db_name}.{table_name}".format(
            server_name=server_name,
            db_name=target_table_db_name,
            table_name=target_table_name)), "false", is_must_have_data,
        v_execution_hour)

    sql = '''
            ALTER TABLE {hive_db}.{table} ADD IF NOT EXISTS PARTITION (dt = '{ds}', hour = '{hour}')
        '''.format(hive_db=hive_db,
                   table=hive_table_name,
                   ds=v_execution_day,
                   hour=v_execution_hour)

    hive2_conn = HiveServer2Hook().get_conn()
    cursor = hive2_conn.cursor()
    cursor.execute(sql)

    return
Example #4
    def execute(self, context):
        hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
        logging.info("Extracting data from Hive")
        logging.info(self.sql)

        if self.bulk_load:
            tmpfile = NamedTemporaryFile()
            hive.to_csv(self.sql,
                        tmpfile.name,
                        delimiter='\t',
                        lineterminator='\n',
                        output_header=False)
        else:
            results = hive.get_records(self.sql)

        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
        if self.mysql_preoperator:
            logging.info("Running MySQL preoperator")
            mysql.run(self.mysql_preoperator)

        logging.info("Inserting rows into MySQL")

        if self.bulk_load:
            mysql.bulk_load(table=self.mysql_table, tmp_file=tmpfile.name)
            tmpfile.close()
        else:
            mysql.insert_rows(table=self.mysql_table, rows=results)

        if self.mysql_postoperator:
            logging.info("Running MySQL postoperator")
            mysql.run(self.mysql_postoperator)

        logging.info("Done.")
Example #5
    def test_get_results_with_hive_conf(self):
        hql = [
            "set key", "set airflow.ctx.dag_id", "set airflow.ctx.dag_run_id",
            "set airflow.ctx.task_id", "set airflow.ctx.execution_date"
        ]

        dag_id_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format']
        task_id_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format']
        execution_date_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][
                'env_var_format']
        dag_run_id_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][
                'env_var_format']
        os.environ[dag_id_ctx_var_name] = 'test_dag_id'
        os.environ[task_id_ctx_var_name] = 'test_task_id'
        os.environ[execution_date_ctx_var_name] = 'test_execution_date'
        os.environ[dag_run_id_ctx_var_name] = 'test_dag_run_id'

        hook = HiveServer2Hook()
        output = '\n'.join(res_tuple[0] for res_tuple in hook.get_results(
            hql=hql, hive_conf={'key': 'value'})['data'])
        self.assertIn('value', output)
        self.assertIn('test_dag_id', output)
        self.assertIn('test_task_id', output)
        self.assertIn('test_execution_date', output)
        self.assertIn('test_dag_run_id', output)

        del os.environ[dag_id_ctx_var_name]
        del os.environ[task_id_ctx_var_name]
        del os.environ[execution_date_ctx_var_name]
        del os.environ[dag_run_id_ctx_var_name]
Example #6
 def test_multi_statements(self):
     from airflow.hooks.hive_hooks import HiveServer2Hook
     sqls = [
         "CREATE TABLE IF NOT EXISTS test_multi_statements (i INT)",
         "DROP TABLE test_multi_statements",
     ]
     hook = HiveServer2Hook()
     hook.get_records(sqls)
Example #7
 def execute(self, context):
     samba = SambaHook(samba_conn_id=self.samba_conn_id)
     hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
     tmpfile = tempfile.NamedTemporaryFile()
     self.log.info("Fetching file from Hive")
     hive.to_csv(hql=self.hql, csv_filepath=tmpfile.name)
     self.log.info("Pushing to samba")
     samba.push_from_local(self.destination_filepath, tmpfile.name)
Example #8
def validate_all_hi_table_exist_task(hive_all_hi_table_name, mysql_table_name, **kwargs):
    check_sql = 'show partitions %s.%s' % (HIVE_DB, hive_all_hi_table_name)
    hive2_conn = HiveServer2Hook().get_conn()
    cursor = hive2_conn.cursor()
    cursor.execute(check_sql)
    if len(cursor.fetchall()) == 0:
        return 'import_table_{}'.format(mysql_table_name)
    else:
        return 'add_partitions_{}'.format(hive_all_hi_table_name)
Example #9
 def execute(self, context):
     samba = SambaHook(samba_conn_id=self.samba_conn_id)
     hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
     with NamedTemporaryFile() as tmp_file:
         self.log.info("Fetching file from Hive")
         hive.to_csv(hql=self.hql, csv_filepath=tmp_file.name,
                     hive_conf=context_to_airflow_vars(context))
         self.log.info("Pushing to samba")
         samba.push_from_local(self.destination_filepath, tmp_file.name)
Example #10
 def test_multi_statements(self):
     sqls = [
         "CREATE TABLE IF NOT EXISTS test_multi_statements (i INT)",
         "SELECT * FROM {}".format(self.table),
         "DROP TABLE test_multi_statements",
     ]
     hook = HiveServer2Hook()
     results = hook.get_records(sqls, schema=self.database)
     self.assertListEqual(results, [(1, 1), (2, 2)])
Example #11
 def test_to_csv(self):
     hook = HiveServer2Hook()
     query = "SELECT * FROM {}".format(self.table)
     csv_filepath = 'query_results.csv'
     hook.to_csv(query, csv_filepath, schema=self.database,
                 delimiter=',', lineterminator='\n', output_header=True)
     df = pd.read_csv(csv_filepath, sep=',')
     self.assertListEqual(df.columns.tolist(), self.columns)
     self.assertListEqual(df[self.columns[0]].values.tolist(), [1, 2])
     self.assertEqual(len(df), 2)
Example #12
    def test_mysql_to_hive_verify_loaded_values(self):
        mysql_conn_id = 'airflow_ci'
        mysql_table = 'test_mysql_to_hive'
        hive_table = 'test_mysql_to_hive'

        from airflow.hooks.mysql_hook import MySqlHook
        m = MySqlHook(mysql_conn_id)

        try:
            minmax = (255, 65535, 16777215, 4294967295, 18446744073709551615,
                      -128, -32768, -8388608, -2147483648,
                      -9223372036854775808)

            with m.get_conn() as c:
                c.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
                c.execute("""
                    CREATE TABLE {} (
                        c0 TINYINT   UNSIGNED,
                        c1 SMALLINT  UNSIGNED,
                        c2 MEDIUMINT UNSIGNED,
                        c3 INT       UNSIGNED,
                        c4 BIGINT    UNSIGNED,
                        c5 TINYINT,
                        c6 SMALLINT,
                        c7 MEDIUMINT,
                        c8 INT,
                        c9 BIGINT
                    )
                """.format(mysql_table))
                c.execute("""
                    INSERT INTO {} VALUES (
                        {}, {}, {}, {}, {}, {}, {}, {}, {}, {}
                    )
                """.format(mysql_table, *minmax))

            from airflow.operators.mysql_to_hive import MySqlToHiveTransfer
            t = MySqlToHiveTransfer(task_id='test_m2h',
                                    mysql_conn_id=mysql_conn_id,
                                    hive_cli_conn_id='beeline_default',
                                    sql="SELECT * FROM {}".format(mysql_table),
                                    hive_table=hive_table,
                                    recreate=True,
                                    delimiter=",",
                                    dag=self.dag)
            t.run(start_date=DEFAULT_DATE,
                  end_date=DEFAULT_DATE,
                  ignore_ti_state=True)

            from airflow.hooks.hive_hooks import HiveServer2Hook
            h = HiveServer2Hook()
            r = h.get_records("SELECT * FROM {}".format(hive_table))
            self.assertEqual(r[0], minmax)
        finally:
            with m.get_conn() as c:
                c.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
Example #13
def add_hi_partition(v_execution_date, v_execution_day, v_execution_hour, db_name, table_name, conn_id, hive_table_name,
                     server_name, hive_db, is_must_have_data, **kwargs):
    sql = '''
            ALTER TABLE {hive_db}.{table} ADD IF NOT EXISTS PARTITION (dt = '{ds}', hour = '{hour}')
        '''.format(hive_db=hive_db, table=hive_table_name, ds=v_execution_day, hour=v_execution_hour)

    hive2_conn = HiveServer2Hook().get_conn()
    cursor = hive2_conn.cursor()
    cursor.execute(sql)

    return
Example #14
 def get_hook(self):
     try:
         if self.conn_type == 'mysql':
             from airflow.hooks.mysql_hook import MySqlHook
             return MySqlHook(mysql_conn_id=self.conn_id)
         elif self.conn_type == 'google_cloud_platform':
             from airflow.contrib.hooks.bigquery_hook import BigQueryHook
             return BigQueryHook(bigquery_conn_id=self.conn_id)
         elif self.conn_type == 'postgres':
             from airflow.hooks.postgres_hook import PostgresHook
             return PostgresHook(postgres_conn_id=self.conn_id)
         elif self.conn_type == 'hive_cli':
             from airflow.hooks.hive_hooks import HiveCliHook
             return HiveCliHook(hive_cli_conn_id=self.conn_id)
         elif self.conn_type == 'presto':
             from airflow.hooks.presto_hook import PrestoHook
             return PrestoHook(presto_conn_id=self.conn_id)
         elif self.conn_type == 'hiveserver2':
             from airflow.hooks.hive_hooks import HiveServer2Hook
             return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
         elif self.conn_type == 'sqlite':
             from airflow.hooks.sqlite_hook import SqliteHook
             return SqliteHook(sqlite_conn_id=self.conn_id)
         elif self.conn_type == 'jdbc':
             from airflow.hooks.jdbc_hook import JdbcHook
             return JdbcHook(jdbc_conn_id=self.conn_id)
         elif self.conn_type == 'mssql':
             from airflow.hooks.mssql_hook import MsSqlHook
             return MsSqlHook(mssql_conn_id=self.conn_id)
         elif self.conn_type == 'oracle':
             from airflow.hooks.oracle_hook import OracleHook
             return OracleHook(oracle_conn_id=self.conn_id)
         elif self.conn_type == 'vertica':
             from airflow.contrib.hooks.vertica_hook import VerticaHook
             return VerticaHook(vertica_conn_id=self.conn_id)
         elif self.conn_type == 'cloudant':
             from airflow.contrib.hooks.cloudant_hook import CloudantHook
             return CloudantHook(cloudant_conn_id=self.conn_id)
         elif self.conn_type == 'jira':
             from airflow.contrib.hooks.jira_hook import JiraHook
             return JiraHook(jira_conn_id=self.conn_id)
         elif self.conn_type == 'redis':
             from airflow.contrib.hooks.redis_hook import RedisHook
             return RedisHook(redis_conn_id=self.conn_id)
         elif self.conn_type == 'wasb':
             from airflow.contrib.hooks.wasb_hook import WasbHook
             return WasbHook(wasb_conn_id=self.conn_id)
         elif self.conn_type == 'docker':
             from airflow.hooks.docker_hook import DockerHook
             return DockerHook(docker_conn_id=self.conn_id)
     except:
         pass
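The get_hook dispatcher above maps a connection's conn_type to the matching hook class, returning None (because of the bare except) when the type is unknown. A minimal sketch of how it is normally reached, assuming a connection with id 'hiveserver2_default' has been configured (Airflow 1.x import paths):

# Fetch a stored connection and let it build the matching hook.
# Assumes a 'hiveserver2_default' connection exists; ids are illustrative.
from airflow.hooks.base_hook import BaseHook

conn = BaseHook.get_connection('hiveserver2_default')  # Connection object
hook = conn.get_hook()                                  # dispatches on conn.conn_type
if hook is not None:
    print(type(hook).__name__)  # HiveServer2Hook for conn_type 'hiveserver2'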
Example #15
 def test_to_csv(self):
     hook = HiveServer2Hook()
     query = "SELECT * FROM {}".format(self.table)
     csv_filepath = 'query_results.csv'
     with self.assertLogs() as cm:
         hook.to_csv(query, csv_filepath, schema=self.database,
                     delimiter=',', lineterminator='\n', output_header=True, fetch_size=2)
         df = pd.read_csv(csv_filepath, sep=',')
         self.assertListEqual(df.columns.tolist(), self.columns)
         self.assertListEqual(df[self.columns[0]].values.tolist(), [1, 2])
         self.assertEqual(len(df), 2)
         self.assertIn('INFO:airflow.hooks.hive_hooks.HiveServer2Hook:'
                       'Written 2 rows so far.', cm.output)
Example #16
        def test_select_conn_with_schema(self, connect_mock):
            from airflow.hooks.hive_hooks import HiveServer2Hook

            # Configure
            hook = HiveServer2Hook()

            # Run
            hook.get_conn(self.nondefault_schema)

            # Verify
            assert connect_mock.called
            (args, kwargs) = connect_mock.call_args_list[0]
            assert kwargs['database'] == self.nondefault_schema
Example #17
def count_data_rows(templates_dict, **kwargs):
    hook = HiveServer2Hook()

    query = """
        SELECT count(*) FROM mydata
    """
    result = hook.get_results(schema=templates_dict['schema'], hql=query)

    if result['data'][0][0] > 100:
        return 'clean_data'
    else:
        return 'stop_flow'
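count_data_rows returns a task id ('clean_data' or 'stop_flow'), which is the contract a branching callable has to satisfy. A hedged sketch of attaching it to a BranchPythonOperator in Airflow 1.x; the DAG name, downstream task ids, and the schema value are illustrative only.

# Illustrative wiring only: DAG name, schema, and downstream tasks are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator

dag = DAG('example_branch_on_row_count', start_date=datetime(2020, 1, 1))

branch = BranchPythonOperator(
    task_id='count_data_rows',
    python_callable=count_data_rows,       # the function defined above
    provide_context=True,                  # Airflow 1.x: pass context as **kwargs
    templates_dict={'schema': 'default'},  # surfaces as templates_dict['schema']
    dag=dag,
)

branch >> DummyOperator(task_id='clean_data', dag=dag)
branch >> DummyOperator(task_id='stop_flow', dag=dag)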
Example #18
        def test_get_records_with_schema(self, get_results_mock):
            from airflow.hooks.hive_hooks import HiveServer2Hook

            # Configure
            sql = "select 1"
            hook = HiveServer2Hook()

            # Run
            hook.get_records(sql, self.nondefault_schema)

            # Verify
            self.assertTrue(self.connect_mock.called)
            (args, kwargs) = self.connect_mock.call_args_list[0]
            self.assertEqual(sql, args[0])
            self.assertEqual(self.nondefault_schema, kwargs['schema'])
Example #19
        def test_get_pandas_df_with_schema(self, get_results_mock):
            from airflow.hooks.hive_hooks import HiveServer2Hook

            # Configure
            sql = "select 1"
            hook = HiveServer2Hook()

            # Run
            hook.get_pandas_df(sql, self.nondefault_schema)

            # Verify
            assert self.connect_mock.called
            (args, kwargs) = self.connect_mock.call_args_list[0]
            assert args[0] == sql
            assert kwargs['schema'] == self.nondefault_schema
Example #20
    def test_mysql_to_hive_verify_csv_special_char(self):
        mysql_table = 'test_mysql_to_hive'
        hive_table = 'test_mysql_to_hive'

        from airflow.hooks.mysql_hook import MySqlHook
        hook = MySqlHook()

        try:
            db_record = (
                'c0',
                '["true"]'
            )
            with hook.get_conn() as conn:
                conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
                conn.execute("""
                    CREATE TABLE {} (
                        c0 VARCHAR(25),
                        c1 VARCHAR(25)
                    )
                """.format(mysql_table))
                conn.execute("""
                    INSERT INTO {} VALUES (
                        '{}', '{}'
                    )
                """.format(mysql_table, *db_record))

            from airflow.operators.mysql_to_hive import MySqlToHiveTransfer
            import unicodecsv as csv
            op = MySqlToHiveTransfer(
                task_id='test_m2h',
                hive_cli_conn_id='hive_cli_default',
                sql="SELECT * FROM {}".format(mysql_table),
                hive_table=hive_table,
                recreate=True,
                delimiter=",",
                quoting=csv.QUOTE_NONE,
                quotechar='',
                escapechar='@',
                dag=self.dag)
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

            from airflow.hooks.hive_hooks import HiveServer2Hook
            hive_hook = HiveServer2Hook()
            result = hive_hook.get_records("SELECT * FROM {}".format(hive_table))
            self.assertEqual(result[0], db_record)
        finally:
            with hook.get_conn() as conn:
                conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
Example #21
    def execute(self, context):
        hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
        tmpfile = tempfile.NamedTemporaryFile()
        self.log.info("Fetching file from Hive")

        hive.to_csv(hql=self.hql, csv_filepath=tmpfile.name)
        self.log.info("Pushing to localfile")

        if not os.path.exists(self.dst_path):
            os.makedirs(self.dst_path)

        target = os.path.join(
            self.dst_path,
            self.dst_filename) if self.dst_filename else self.dst_path

        shutil.copy(tmpfile.name, target)
Example #22
    def test_get_conn_with_password(self, mock_connect):
        conn_id = "conn_with_password"
        conn_env = CONN_ENV_PREFIX + conn_id.upper()

        with patch.dict(
            'os.environ',
            {conn_env: "jdbc+hive2://conn_id:conn_pass@localhost:10000/default?authMechanism=LDAP"}
        ):
            HiveServer2Hook(hiveserver2_conn_id=conn_id).get_conn()
            mock_connect.assert_called_once_with(
                host='localhost',
                port=10000,
                auth='LDAP',
                kerberos_service_name=None,
                username='conn_id',
                password='conn_pass',
                database='default')
Example #23
    def test_get_conn_with_password(self, mock_connect):
        from airflow.hooks.base_hook import CONN_ENV_PREFIX
        conn_id = "conn_with_password"
        conn_env = CONN_ENV_PREFIX + conn_id.upper()
        conn_value = os.environ.get(conn_env)
        os.environ[conn_env] = "jdbc+hive2://conn_id:conn_pass@localhost:10000/default?authMechanism=LDAP"

        HiveServer2Hook(hiveserver2_conn_id=conn_id).get_conn()
        mock_connect.assert_called_with(
            host='localhost',
            port=10000,
            auth='LDAP',
            kerberos_service_name=None,
            username='conn_id',
            password='conn_pass',
            database='default')

        if conn_value:
            os.environ[conn_env] = conn_value
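Both versions of this test rely on Airflow's environment-variable connections: a variable named CONN_ENV_PREFIX plus the upper-cased connection id (i.e. AIRFLOW_CONN_<CONN_ID>) holding a connection URI takes precedence over the metadata database. A small sketch of pointing HiveServer2Hook at such a connection; the host and credentials below are made up.

# Sketch only: the URI (host, credentials, LDAP auth) is entirely fictional.
import os

from airflow.hooks.hive_hooks import HiveServer2Hook

os.environ['AIRFLOW_CONN_MY_HIVESERVER2'] = (
    'jdbc+hive2://some_user:some_pass@hive.example.com:10000/default'
    '?authMechanism=LDAP'
)

hook = HiveServer2Hook(hiveserver2_conn_id='my_hiveserver2')
# hook.get_conn() would now connect to hive.example.com:10000 as some_user via LDAP.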
Example #24
    def get_records(self):
        """Executes a query to obtain a count of records on a table

        Returns:
            int -- quantity of records from a count query
        """

        if self.query_engine == 'hive':
            hook = HiveServer2Hook(self.query_engine_conn_id)

        elif self.query_engine == 'presto':
            hook = PrestoHook(self.query_engine_conn_id)

        # execute the count query against Hive or Presto
        res = hook.get_records(self.records_query)

        if len(res) > 1:
            raise ValueError(
                'Expected a single row from the count query, got %d rows' % len(res))
        return res[0]
Example #25
    def execute(self, context):
        hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)

        logging.info('Extracting data from Hive')
        logging.info(self.hql)

        data = hive.get_pandas_df(self.hql, schema=self.schema)
        gcp_hook = GoogleCloudStorageHook(google_cloud_storage_conn_id=self.google_cloud_storage_conn_id)
        logging.info('Inserting rows onto google cloud storage')

        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', prefix='tmp') as tmp_file:
            data = data.to_json(orient='records')
            recs = json.loads(data)
            for record in recs:
                tmp_file.write(json.dumps(record))
                tmp_file.write("\n")
            tmp_file.flush()

            remote_file_name = self.file_pattern.format('aa')
            remote_name = os.path.join(self.subdir, remote_file_name)
            gcp_hook.upload(self.bucket, remote_name, tmp_file.name)

        logging.info('Done.')
Example #26
    def execute(self, context):
        hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)

        self.log.info('Extracting data from Hive')
        self.log.info(self.sql)

        data = hive.get_pandas_df(self.sql, schema=self.schema)
        dynamodb = AwsDynamoDBHook(aws_conn_id=self.aws_conn_id,
                                   table_name=self.table_name,
                                   table_keys=self.table_keys,
                                   region_name=self.region_name)

        self.log.info('Inserting rows into dynamodb')

        if self.pre_process is None:
            dynamodb.write_batch_data(
                json.loads(data.to_json(orient='records')))
        else:
            dynamodb.write_batch_data(
                self.pre_process(data=data,
                                 args=self.pre_process_args,
                                 kwargs=self.pre_process_kwargs))

        self.log.info('Done.')
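The pre_process hook above receives the Hive result DataFrame together with pre_process_args and pre_process_kwargs, and whatever it returns is handed straight to write_batch_data. A hypothetical callable (its name and transformation are made up) that turns the DataFrame into a list of DynamoDB items:

import json


# Hypothetical pre_process callable: converts the DataFrame to records and
# drops null-valued keys from each one before the batch write.
def drop_null_attributes(data, args=None, kwargs=None):
    records = json.loads(data.to_json(orient='records'))
    return [{k: v for k, v in record.items() if v is not None}
            for record in records]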
Example #27
def run_check_table(schema_table_db_name, schema_table_name,
                    target_table_db_name, target_table_name, conn_id,
                    hive_table_name, server_name, **kwargs):
    # SHOW TABLES in oride_db LIKE 'data_aa'
    check_sql = 'SHOW TABLES in %s LIKE \'%s\'' % (HIVE_DB, hive_table_name)
    hive2_conn = HiveServer2Hook().get_conn()
    cursor = hive2_conn.cursor()
    cursor.execute(check_sql)
    if len(cursor.fetchall()) == 0:
        logging.info('Create Hive Table: %s.%s', HIVE_DB, hive_table_name)
        # get table column
        column_sql = '''
                SELECT
                    COLUMN_NAME,
                    DATA_TYPE,
                    NUMERIC_PRECISION,
                    NUMERIC_SCALE,
                    COLUMN_COMMENT
                FROM
                    information_schema.columns
                WHERE
                    table_schema='{db_name}' and table_name='{table_name}'
            '''.format(db_name=schema_table_db_name,
                       table_name=schema_table_name)
        mysql_hook = MySqlHook(conn_id)
        mysql_conn = mysql_hook.get_conn()
        mysql_cursor = mysql_conn.cursor()
        mysql_cursor.execute(column_sql)
        results = mysql_cursor.fetchall()
        rows = []
        for result in results:
            if result[0] == 'dt':
                col_name = '_dt'
            else:
                col_name = result[0]
            if result[1] in ('timestamp', 'varchar', 'char', 'text', 'longtext',
                             'mediumtext', 'json', 'datetime'):
                data_type = 'string'
            elif result[1] == 'decimal':
                data_type = result[1] + "(" + str(result[2]) + "," + str(result[3]) + ")"
            else:
                data_type = result[1]
            rows.append("`%s` %s comment '%s'" %
                        (col_name, data_type, str(result[4]).replace(
                            '\n', '').replace('\r', '')))
        mysql_conn.close()

        # hive create table
        hive_hook = HiveCliHook()
        sql = ODS_CREATE_TABLE_SQL.format(
            db_name=HIVE_DB,
            table_name=hive_table_name,
            columns=",\n".join(rows),
            oss_path=OSS_PATH % ("{server_name}.{db_name}.{table_name}".format(
                server_name=server_name,
                db_name=target_table_db_name,
                table_name=target_table_name)))
        logging.info('Executing: %s', sql)
        hive_hook.run_cli(sql)

    else:
        sqoopSchema = SqoopSchemaUpdate()
        response = sqoopSchema.append_hive_schema(
            hive_db=HIVE_DB,
            hive_table=hive_table_name,
            mysql_db=schema_table_db_name,
            mysql_table=schema_table_name,
            mysql_conn=conn_id,
            oss_path=OSS_PATH % ("{server_name}.{db_name}.{table_name}".format(
                server_name=server_name,
                db_name=target_table_db_name,
                table_name=target_table_name)))
        if response:
            return True
    return
Example #28
 def test_get_conn(self):
     hook = HiveServer2Hook()
     hook.get_conn()
Example #29
 def test_get_results_header(self):
     hook = HiveServer2Hook()
     query = "SELECT * FROM {}".format(self.table)
     results = hook.get_results(query, schema=self.database)
     self.assertListEqual([col[0] for col in results['header']],
                          self.columns)
Example #30
 def test_get_results_data(self):
     hook = HiveServer2Hook()
     query = "SELECT * FROM {}".format(self.table)
     results = hook.get_results(query, schema=self.database)
     self.assertListEqual(results['data'], [(1, 1), (2, 2)])