Example #1
    def test_run_cli_with_hive_conf(self):
        hql = "set key;\n" \
              "set airflow.ctx.dag_id;\nset airflow.ctx.dag_run_id;\n" \
              "set airflow.ctx.task_id;\nset airflow.ctx.execution_date;\n"

        dag_id_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format']
        task_id_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format']
        execution_date_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][
                'env_var_format']
        dag_run_id_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][
                'env_var_format']
        with mock.patch.dict(
                'os.environ', {
                    dag_id_ctx_var_name: 'test_dag_id',
                    task_id_ctx_var_name: 'test_task_id',
                    execution_date_ctx_var_name: 'test_execution_date',
                    dag_run_id_ctx_var_name: 'test_dag_run_id',
                }):
            hook = HiveCliHook()
            output = hook.run_cli(hql=hql, hive_conf={'key': 'value'})
            self.assertIn('value', output)
            self.assertIn('test_dag_id', output)
            self.assertIn('test_task_id', output)
            self.assertIn('test_execution_date', output)
            self.assertIn('test_dag_run_id', output)
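A minimal standalone sketch of the same call outside the test harness; the connection id "hive_cli_default" and the HQL are illustrative assumptions, not taken from the test.

from airflow.providers.apache.hive.hooks.hive import HiveCliHook

# Assumption: a Hive CLI connection named "hive_cli_default" is configured.
hook = HiveCliHook(hive_cli_conn_id="hive_cli_default")
# Entries in hive_conf are forwarded to the Hive command alongside the
# airflow.ctx.* context variables exercised in the test above.
output = hook.run_cli(hql="set key;", hive_conf={"key": "value"})
print(output)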
Example #2
    def execute(self, context):
        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)

        self.log.info("Dumping Vertica query results to local file")
        conn = vertica.get_conn()
        cursor = conn.cursor()
        cursor.execute(self.sql)
        with NamedTemporaryFile("w") as f:
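            # Note: "csv" here is presumably the unicodecsv package rather than
            # the standard library (stdlib csv.writer has no encoding argument).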
            csv_writer = csv.writer(f,
                                    delimiter=self.delimiter,
                                    encoding='utf-8')
            field_dict = OrderedDict()
            col_count = 0
            for field in cursor.description:
                col_count += 1
                col_position = f"Column{col_count}"
                field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1])
            csv_writer.writerows(cursor.iterate())
            f.flush()
            cursor.close()
            conn.close()
            self.log.info("Loading file into Hive")
            hive.load_file(
                f.name,
                self.hive_table,
                field_dict=field_dict,
                create=self.create,
                partition=self.partition,
                delimiter=self.delimiter,
                recreate=self.recreate,
            )
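A small self-contained sketch of the field_dict construction used above: result columns with an empty name fall back to a positional "ColumnN" label. The description rows and the type lookup are hypothetical stand-ins for cursor.description and self.type_map.

from collections import OrderedDict

# Hypothetical stand-ins for cursor.description and self.type_map.
description = [("id", 6), ("", 9)]
type_map = {6: "INT", 9: "STRING"}.get

field_dict = OrderedDict()
for col_count, field in enumerate(description, start=1):
    col_position = f"Column{col_count}"
    field_dict[col_position if field[0] == '' else field[0]] = type_map(field[1])

assert field_dict == OrderedDict([("id", "INT"), ("Column2", "STRING")])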
Example #3
    def test_load_file_create_table(self, mock_run_cli):
        filepath = "/path/to/input/file"
        table = "output_table"
        field_dict = OrderedDict([("name", "string"), ("gender", "string")])
        fields = ",\n    ".join([
            '`{k}` {v}'.format(k=k.strip('`'), v=v)
            for k, v in field_dict.items()
        ])

        hook = HiveCliHook()
        hook.load_file(filepath=filepath,
                       table=table,
                       field_dict=field_dict,
                       create=True,
                       recreate=True)

        create_table = ("DROP TABLE IF EXISTS {table};\n"
                        "CREATE TABLE IF NOT EXISTS {table} (\n{fields})\n"
                        "ROW FORMAT DELIMITED\n"
                        "FIELDS TERMINATED BY ','\n"
                        "STORED AS textfile\n;".format(table=table,
                                                       fields=fields))

        load_data = ("LOAD DATA LOCAL INPATH '{filepath}' "
                     "OVERWRITE INTO TABLE {table} ;\n".format(
                         filepath=filepath, table=table))
        calls = [mock.call(create_table), mock.call(load_data)]
        mock_run_cli.assert_has_calls(calls, any_order=True)
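load_file can also target a specific partition, as the transfer operators elsewhere in this listing do; a minimal sketch, where the connection id, table name, and partition spec are all assumptions.

from collections import OrderedDict
from airflow.providers.apache.hive.hooks.hive import HiveCliHook

# Assumptions: connection "hive_cli_default", table "output_table",
# and a daily "ds" partition are illustrative only.
hook = HiveCliHook(hive_cli_conn_id="hive_cli_default")
hook.load_file(
    "/path/to/input/file",
    "output_table",
    field_dict=OrderedDict([("name", "string"), ("gender", "string")]),
    create=True,
    partition={"ds": "2021-01-01"},
    delimiter=",",
)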
Example #4
    def execute(self, context: Dict[str, str]):
        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)

        self.log.info("Dumping MySQL query results to local file")
        conn = mysql.get_conn()
        cursor = conn.cursor()
        cursor.execute(self.sql)
        with NamedTemporaryFile("wb") as f:
            csv_writer = csv.writer(
                f,
                delimiter=self.delimiter,
                quoting=self.quoting,
                quotechar=self.quotechar,
                escapechar=self.escapechar,
                encoding="utf-8",
            )
            field_dict = OrderedDict()
            for field in cursor.description:
                field_dict[field[0]] = self.type_map(field[1])
            csv_writer.writerows(cursor)
            f.flush()
            cursor.close()
            conn.close()
            self.log.info("Loading file into Hive")
            hive.load_file(
                f.name,
                self.hive_table,
                field_dict=field_dict,
                create=self.create,
                partition=self.partition,
                delimiter=self.delimiter,
                recreate=self.recreate,
                tblproperties=self.tblproperties,
            )
Example #5
    def test_load_df_with_data_types(self, mock_run_cli):
        ord_dict = OrderedDict()
        ord_dict['b'] = [True]
        ord_dict['i'] = [-1]
        ord_dict['t'] = [1]
        ord_dict['f'] = [0.0]
        ord_dict['c'] = ['c']
        ord_dict['M'] = [datetime.datetime(2018, 1, 1)]
        ord_dict['O'] = [object()]
        ord_dict['S'] = [b'STRING']
        ord_dict['U'] = ['STRING']
        ord_dict['V'] = [None]
        df = pd.DataFrame(ord_dict)

        hook = HiveCliHook()
        hook.load_df(df, 't')

        query = """
            CREATE TABLE IF NOT EXISTS t (
                `b` BOOLEAN,
                `i` BIGINT,
                `t` BIGINT,
                `f` DOUBLE,
                `c` STRING,
                `M` TIMESTAMP,
                `O` STRING,
                `S` STRING,
                `U` STRING,
                `V` STRING)
            ROW FORMAT DELIMITED
            FIELDS TERMINATED BY ','
            STORED AS textfile
            ;
        """
        assert_equal_ignore_multiple_spaces(self, mock_run_cli.call_args_list[0][0][0], query)
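A hedged usage sketch of load_df outside the mocked test: when field_dict is omitted, the hook derives the Hive column types from the DataFrame dtypes, as the expected DDL above shows. The connection id and table name are assumptions.

import pandas as pd
from airflow.providers.apache.hive.hooks.hive import HiveCliHook

# Assumption: a Hive CLI connection named "hive_cli_default" exists.
hook = HiveCliHook(hive_cli_conn_id="hive_cli_default")
df = pd.DataFrame({"id": [1, 2], "name": ["alice", "bob"]})
# Column types are inferred from the dtypes (BIGINT and STRING here).
hook.load_df(df, table="t", recreate=True)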
Example #6
    def execute(self, context: Dict[str, str]):
        mssql = MsSqlHook(mssql_conn_id=self.mssql_conn_id)
        self.log.info(
            "Dumping Microsoft SQL Server query results to local file")
        with mssql.get_conn() as conn:
            with conn.cursor() as cursor:
                cursor.execute(self.sql)
                with NamedTemporaryFile("w") as tmp_file:
                    csv_writer = csv.writer(tmp_file,
                                            delimiter=self.delimiter,
                                            encoding='utf-8')
                    field_dict = OrderedDict()
                    col_count = 0
                    for field in cursor.description:
                        col_count += 1
                        col_position = "Column{position}".format(
                            position=col_count)
                        field_dict[col_position if field[0] ==
                                   '' else field[0]] = self.type_map(field[1])
                    csv_writer.writerows(cursor)
                    tmp_file.flush()

            hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
            self.log.info("Loading file into Hive")
            hive.load_file(tmp_file.name,
                           self.hive_table,
                           field_dict=field_dict,
                           create=self.create,
                           partition=self.partition,
                           delimiter=self.delimiter,
                           recreate=self.recreate,
                           tblproperties=self.tblproperties)
Example #7
 def ddl(self):
     """
     Retrieve table ddl
     """
     table = request.args.get("table")
     sql = "SHOW CREATE TABLE {table};".format(table=table)
     hook = HiveCliHook(HIVE_CLI_CONN_ID)
     return hook.run_cli(sql)
Example #8
    def test_load_file_without_create_table(self, mock_run_cli):
        filepath = "/path/to/input/file"
        table = "output_table"

        hook = HiveCliHook()
        hook.load_file(filepath=filepath, table=table, create=False)

        query = ("LOAD DATA LOCAL INPATH '{filepath}' "
                 "OVERWRITE INTO TABLE {table} ;\n".format(filepath=filepath,
                                                           table=table))
        calls = [mock.call(query)]
        mock_run_cli.assert_has_calls(calls, any_order=True)
Example #9
    def test_get_proxy_user_value(self):
        hook = HiveCliHook()
        returner = mock.MagicMock()
        returner.extra_dejson = {'proxy_user': 'a_user_proxy'}
        hook.use_beeline = True
        hook.conn = returner

        # Run
        result = hook._prepare_cli_cmd()

        # Verify
        self.assertIn('hive.server2.proxy.user=a_user_proxy', result[2])
Example #10
    def test_load_df_with_optional_parameters(self, mock_to_csv, mock_load_file):
        hook = HiveCliHook()
        bools = (True, False)
        for create, recreate in itertools.product(bools, bools):
            mock_load_file.reset_mock()
            hook.load_df(df=pd.DataFrame({"c": range(0, 10)}),
                         table="t",
                         create=create,
                         recreate=recreate)

            assert mock_load_file.call_count == 1
            kwargs = mock_load_file.call_args[1]
            self.assertEqual(kwargs["create"], create)
            self.assertEqual(kwargs["recreate"], recreate)
Example #11
 def get_hook(self):
     return HiveCliHook(
         hive_cli_conn_id=self.hive_cli_conn_id,
         run_as=self.run_as,
         mapred_queue=self.mapred_queue,
         mapred_queue_priority=self.mapred_queue_priority,
         mapred_job_name=self.mapred_job_name)
Example #12
 def get_hook(self) -> HiveCliHook:
     """Get Hive cli hook"""
     return HiveCliHook(
         hive_cli_conn_id=self.hive_cli_conn_id,
         run_as=self.run_as,
         mapred_queue=self.mapred_queue,
         mapred_queue_priority=self.mapred_queue_priority,
         mapred_job_name=self.mapred_job_name,
     )
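For reference, the hook accepts the scheduling-related arguments shown in these get_hook helpers directly; a minimal sketch in which every value is illustrative and only the keyword names come from the examples above.

from airflow.providers.apache.hive.hooks.hive import HiveCliHook

hook = HiveCliHook(
    hive_cli_conn_id="hive_cli_default",   # assumed connection id
    run_as="etl_user",                     # illustrative value
    mapred_queue="default",                # illustrative value
    mapred_queue_priority="HIGH",          # illustrative value
    mapred_job_name="example_hive_job",    # illustrative value
)
hook.run_cli("SHOW DATABASES;")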
Example #13
    def test_load_df(self, mock_to_csv, mock_load_file):
        df = pd.DataFrame({"c": ["foo", "bar", "baz"]})
        table = "t"
        delimiter = ","
        encoding = "utf-8"

        hook = HiveCliHook()
        hook.load_df(df=df,
                     table=table,
                     delimiter=delimiter,
                     encoding=encoding)

        assert mock_to_csv.call_count == 1
        kwargs = mock_to_csv.call_args[1]
        self.assertEqual(kwargs["header"], False)
        self.assertEqual(kwargs["index"], False)
        self.assertEqual(kwargs["sep"], delimiter)

        assert mock_load_file.call_count == 1
        kwargs = mock_load_file.call_args[1]
        self.assertEqual(kwargs["delimiter"], delimiter)
        self.assertEqual(kwargs["field_dict"], {"c": "STRING"})
        self.assertTrue(isinstance(kwargs["field_dict"], OrderedDict))
        self.assertEqual(kwargs["table"], table)
Example #14
    def execute(self, context: Dict[str, Any]) -> None:
        hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.log.info("Extracting data from Hive")
        hive_table = 'druid.' + context['task_instance_key_str'].replace(
            '.', '_')
        sql = self.sql.strip().strip(';')
        tblproperties = ''.join([
            ", '{}' = '{}'".format(k, v)
            for k, v in self.hive_tblproperties.items()
        ])
        hql = f"""\
        SET mapred.output.compress=false;
        SET hive.exec.compress.output=false;
        DROP TABLE IF EXISTS {hive_table};
        CREATE TABLE {hive_table}
        ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t'
        STORED AS TEXTFILE
        TBLPROPERTIES ('serialization.null.format' = ''{tblproperties})
        AS
        {sql}
        """
        self.log.info("Running command:\n %s", hql)
        hive.run_cli(hql)

        meta_hook = HiveMetastoreHook(self.metastore_conn_id)

        # Get the Hive table and extract the columns
        table = meta_hook.get_table(hive_table)
        columns = [col.name for col in table.sd.cols]

        # Get the path on hdfs
        static_path = meta_hook.get_table(hive_table).sd.location

        druid = DruidHook(druid_ingest_conn_id=self.druid_ingest_conn_id)

        try:
            index_spec = self.construct_ingest_query(
                static_path=static_path,
                columns=columns,
            )

            self.log.info("Inserting rows into Druid, hdfs path: %s",
                          static_path)

            druid.submit_indexing_job(index_spec)

            self.log.info("Load seems to have succeeded!")
        finally:
            self.log.info("Cleaning up by dropping the temp Hive table %s",
                          hive_table)
            hql = "DROP TABLE IF EXISTS {}".format(hive_table)
            hive.run_cli(hql)
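A short illustration of the TBLPROPERTIES fragment built at the top of execute; the property name and value below are hypothetical.

# With a hypothetical hive_tblproperties value, the join above produces a
# leading-comma fragment appended after 'serialization.null.format' = ''.
hive_tblproperties = {"druid.segment.granularity": "DAY"}
tblproperties = ''.join(", '{}' = '{}'".format(k, v) for k, v in hive_tblproperties.items())
assert tblproperties == ", 'druid.segment.granularity' = 'DAY'"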
Example #15
    def execute(self, context: 'Context'):
        # Downloading file from S3
        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        hive_hook = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
        self.log.info("Downloading S3 file")

        if self.wildcard_match:
            if not s3_hook.check_for_wildcard_key(self.s3_key):
                raise AirflowException(f"No key matches {self.s3_key}")
            s3_key_object = s3_hook.get_wildcard_key(self.s3_key)
        else:
            if not s3_hook.check_for_key(self.s3_key):
                raise AirflowException(f"The key {self.s3_key} does not exist")
            s3_key_object = s3_hook.get_key(self.s3_key)

        _, file_ext = os.path.splitext(s3_key_object.key)
        if self.select_expression and self.input_compressed and file_ext.lower() != '.gz':
            raise AirflowException("GZIP is the only compression format Amazon S3 Select supports")

        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir, NamedTemporaryFile(
            mode="wb", dir=tmp_dir, suffix=file_ext
        ) as f:
            self.log.info("Dumping S3 key %s contents to local file %s",
                          s3_key_object.key, f.name)
            if self.select_expression:
                option = {}
                if self.headers:
                    option['FileHeaderInfo'] = 'USE'
                if self.delimiter:
                    option['FieldDelimiter'] = self.delimiter

                input_serialization = {'CSV': option}
                if self.input_compressed:
                    input_serialization['CompressionType'] = 'GZIP'

                content = s3_hook.select_key(
                    bucket_name=s3_key_object.bucket_name,
                    key=s3_key_object.key,
                    expression=self.select_expression,
                    input_serialization=input_serialization,
                )
                f.write(content.encode("utf-8"))
            else:
                s3_key_object.download_fileobj(f)
            f.flush()

            if self.select_expression or not self.headers:
                self.log.info("Loading file %s into Hive", f.name)
                hive_hook.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate,
                    tblproperties=self.tblproperties,
                )
            else:
                # Decompressing file
                if self.input_compressed:
                    self.log.info("Uncompressing file %s", f.name)
                    fn_uncompressed = uncompress_file(f.name, file_ext,
                                                      tmp_dir)
                    self.log.info("Uncompressed to %s", fn_uncompressed)
                    # uncompressed file available now so deleting
                    # compressed file to save disk space
                    f.close()
                else:
                    fn_uncompressed = f.name

                # Testing if header matches field_dict
                if self.check_headers:
                    self.log.info("Matching file header against field_dict")
                    header_list = self._get_top_row_as_list(fn_uncompressed)
                    if not self._match_headers(header_list):
                        raise AirflowException("Header check failed")

                # Deleting top header row
                self.log.info("Removing header from file %s", fn_uncompressed)
                headless_file = self._delete_top_row_and_compress(
                    fn_uncompressed, file_ext, tmp_dir)
                self.log.info("Headless file %s", headless_file)
                self.log.info("Loading file %s into Hive", headless_file)
                hive_hook.load_file(
                    headless_file,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate,
                    tblproperties=self.tblproperties,
                )
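For clarity, the input_serialization dict assembled in the select_expression branch takes the following shape; the settings (headers enabled, comma delimiter, gzipped input) are hypothetical.

# Shape of the dict passed to S3Hook.select_key above when headers=True,
# delimiter="," and input_compressed=True (all hypothetical settings).
option = {'FileHeaderInfo': 'USE', 'FieldDelimiter': ','}
input_serialization = {'CSV': option}
input_serialization['CompressionType'] = 'GZIP'
assert input_serialization == {
    'CSV': {'FileHeaderInfo': 'USE', 'FieldDelimiter': ','},
    'CompressionType': 'GZIP',
}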
Example #16
 def get_hook(self):
     if self.conn_type == 'mysql':
         from airflow.providers.mysql.hooks.mysql import MySqlHook
         return MySqlHook(mysql_conn_id=self.conn_id)
     elif self.conn_type == 'google_cloud_platform':
         from airflow.gcp.hooks.bigquery import BigQueryHook
         return BigQueryHook(bigquery_conn_id=self.conn_id)
     elif self.conn_type == 'postgres':
         from airflow.providers.postgres.hooks.postgres import PostgresHook
         return PostgresHook(postgres_conn_id=self.conn_id)
     elif self.conn_type == 'pig_cli':
         from airflow.providers.apache.pig.hooks.pig import PigCliHook
         return PigCliHook(pig_cli_conn_id=self.conn_id)
     elif self.conn_type == 'hive_cli':
         from airflow.providers.apache.hive.hooks.hive import HiveCliHook
         return HiveCliHook(hive_cli_conn_id=self.conn_id)
     elif self.conn_type == 'presto':
         from airflow.providers.presto.hooks.presto import PrestoHook
         return PrestoHook(presto_conn_id=self.conn_id)
     elif self.conn_type == 'hiveserver2':
         from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook
         return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
     elif self.conn_type == 'sqlite':
         from airflow.providers.sqlite.hooks.sqlite import SqliteHook
         return SqliteHook(sqlite_conn_id=self.conn_id)
     elif self.conn_type == 'jdbc':
         from airflow.providers.jdbc.hooks.jdbc import JdbcHook
         return JdbcHook(jdbc_conn_id=self.conn_id)
     elif self.conn_type == 'mssql':
         from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
         return MsSqlHook(mssql_conn_id=self.conn_id)
     elif self.conn_type == 'odbc':
         from airflow.providers.odbc.hooks.odbc import OdbcHook
         return OdbcHook(odbc_conn_id=self.conn_id)
     elif self.conn_type == 'oracle':
         from airflow.providers.oracle.hooks.oracle import OracleHook
         return OracleHook(oracle_conn_id=self.conn_id)
     elif self.conn_type == 'vertica':
         from airflow.providers.vertica.hooks.vertica import VerticaHook
         return VerticaHook(vertica_conn_id=self.conn_id)
     elif self.conn_type == 'cloudant':
         from airflow.providers.cloudant.hooks.cloudant import CloudantHook
         return CloudantHook(cloudant_conn_id=self.conn_id)
     elif self.conn_type == 'jira':
         from airflow.providers.jira.hooks.jira import JiraHook
         return JiraHook(jira_conn_id=self.conn_id)
     elif self.conn_type == 'redis':
         from airflow.providers.redis.hooks.redis import RedisHook
         return RedisHook(redis_conn_id=self.conn_id)
     elif self.conn_type == 'wasb':
         from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
         return WasbHook(wasb_conn_id=self.conn_id)
     elif self.conn_type == 'docker':
         from airflow.providers.docker.hooks.docker import DockerHook
         return DockerHook(docker_conn_id=self.conn_id)
     elif self.conn_type == 'azure_data_lake':
         from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook
         return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
     elif self.conn_type == 'azure_cosmos':
         from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook
         return AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
     elif self.conn_type == 'cassandra':
         from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
         return CassandraHook(cassandra_conn_id=self.conn_id)
     elif self.conn_type == 'mongo':
         from airflow.providers.mongo.hooks.mongo import MongoHook
         return MongoHook(conn_id=self.conn_id)
     elif self.conn_type == 'gcpcloudsql':
         from airflow.gcp.hooks.cloud_sql import CloudSQLDatabaseHook
         return CloudSQLDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
     elif self.conn_type == 'grpc':
         from airflow.providers.grpc.hooks.grpc import GrpcHook
         return GrpcHook(grpc_conn_id=self.conn_id)
     raise AirflowException("Unknown hook type {}".format(self.conn_type))
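A hedged usage sketch, assuming this get_hook lives on Airflow's Connection model and that a hive_cli connection named "hive_cli_default" is configured.

from airflow.hooks.base import BaseHook  # airflow.hooks.base_hook on older releases

# Assumption: "hive_cli_default" is a stored connection of type hive_cli,
# so get_hook() resolves to a HiveCliHook.
conn = BaseHook.get_connection("hive_cli_default")
hook = conn.get_hook()
print(hook.run_cli("SHOW DATABASES"))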
Example #17
 def test_run_cli(self):
     hook = HiveCliHook()
     hook.run_cli("SHOW DATABASES")