def __init__(self,
             hive_conn,
             kerberos_service_name: str = 'hive',
             *args,
             **kwargs):
    super().__init__(*args, **kwargs)
    self.hive_conn = hive_conn
    self.kerberos_service_name = kerberos_service_name
    # Resolve the Airflow connection once instead of three separate lookups.
    connection = HiveServer2Hook.get_connection(self.hive_conn)
    self.host = connection.host
    self.port = connection.port
    self.user = connection.login
    self._conn = self._connect(kerberos_service_name)
    self._cursor = self._conn.cursor()
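For context, a minimal sketch of the `_connect` helper this constructor calls, assuming the PyHive client; the library choice and method body below are assumptions, not part of the original excerpt:

from pyhive import hive

def _connect(self, kerberos_service_name):
    # Hypothetical helper: opens a HiveServer2 session using the host,
    # port, and user resolved from the Airflow connection above.
    return hive.connect(
        host=self.host,
        port=self.port,
        username=self.user,
        auth='KERBEROS',
        kerberos_service_name=kerberos_service_name,
    )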
Example #2
    def execute(self, context):
        hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)

        self.log.info('Extracting data from Hive')
        self.log.info(self.sql)

        data = hive.get_pandas_df(self.sql, schema=self.schema)
        dynamodb = AwsDynamoDBHook(
            aws_conn_id=self.aws_conn_id,
            table_name=self.table_name,
            table_keys=self.table_keys,
            region_name=self.region_name,
        )

        self.log.info('Inserting rows into dynamodb')

        if self.pre_process is None:
            dynamodb.write_batch_data(
                json.loads(data.to_json(orient='records')))
        else:
            dynamodb.write_batch_data(
                self.pre_process(data=data,
                                 args=self.pre_process_args,
                                 kwargs=self.pre_process_kwargs))

        self.log.info('Done.')
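As a usage note, `pre_process` is called with the DataFrame plus the configured args and kwargs, and whatever it returns is handed to write_batch_data; a hedged sketch of such a callable (the function name and the null-stripping behavior are illustrative assumptions):

def strip_nulls(data, args=None, kwargs=None):
    # Convert the DataFrame to DynamoDB-ready dicts, dropping null
    # attributes, which DynamoDB batch writes reject.
    records = json.loads(data.to_json(orient='records'))
    return [
        {key: value for key, value in record.items() if value is not None}
        for record in records
    ]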
Example #3
    def test_get_results_with_hive_conf(self):
        hql = [
            "set key", "set airflow.ctx.dag_id", "set airflow.ctx.dag_run_id",
            "set airflow.ctx.task_id", "set airflow.ctx.execution_date"
        ]

        dag_id_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_ID']['env_var_format']
        task_id_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_TASK_ID']['env_var_format']
        execution_date_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_EXECUTION_DATE'][
                'env_var_format']
        dag_run_id_ctx_var_name = \
            AIRFLOW_VAR_NAME_FORMAT_MAPPING['AIRFLOW_CONTEXT_DAG_RUN_ID'][
                'env_var_format']

        with mock.patch.dict(
                'os.environ', {
                    dag_id_ctx_var_name: 'test_dag_id',
                    task_id_ctx_var_name: 'test_task_id',
                    execution_date_ctx_var_name: 'test_execution_date',
                    dag_run_id_ctx_var_name: 'test_dag_run_id',
                }):
            hook = HiveServer2Hook()
            output = '\n'.join(res_tuple[0] for res_tuple in hook.get_results(
                hql=hql, hive_conf={'key': 'value'})['data'])
        self.assertIn('value', output)
        self.assertIn('test_dag_id', output)
        self.assertIn('test_task_id', output)
        self.assertIn('test_execution_date', output)
        self.assertIn('test_dag_run_id', output)
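The assertions pass because the hook applies the Airflow context variables and the hive_conf mapping as `set key=value` statements before running the queries, so each `set key` probe echoes the configured value. Roughly (an illustrative sketch of that expansion, not the hook's exact code):

env_context = get_context_from_env_var()   # picks up the AIRFLOW_CTX_* values patched above
env_context.update({'key': 'value'})       # the hive_conf argument
for key, value in env_context.items():
    cur.execute("set {}={}".format(key, value))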
Example #4
 def test_get_pandas_df(self):
     hook = HiveServer2Hook()
     query = "SELECT * FROM {}".format(self.table)
     df = hook.get_pandas_df(query, schema=self.database)
     self.assertEqual(len(df), 2)
     self.assertListEqual(df.columns.tolist(), self.columns)
     self.assertListEqual(df[self.columns[0]].values.tolist(), [1, 2])
Example #5
    def execute(self, context: 'Context'):
        hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)

        self.log.info("Extracting data from Hive: %s", self.sql)
        hive_conf = context_to_airflow_vars(context)
        if self.hive_conf:
            hive_conf.update(self.hive_conf)
        if self.bulk_load:
            with NamedTemporaryFile() as tmp_file:
                hive.to_csv(
                    self.sql,
                    tmp_file.name,
                    delimiter='\t',
                    lineterminator='\n',
                    output_header=False,
                    hive_conf=hive_conf,
                )
                mysql = self._call_preoperator()
                mysql.bulk_load(table=self.mysql_table, tmp_file=tmp_file.name)
        else:
            hive_results = hive.get_records(self.sql, hive_conf=hive_conf)
            mysql = self._call_preoperator()
            mysql.insert_rows(table=self.mysql_table, rows=hive_results)

        if self.mysql_postoperator:
            self.log.info("Running MySQL postoperator")
            mysql.run(self.mysql_postoperator)

        self.log.info("Done.")
Example #6
 def execute(self, context):
     with NamedTemporaryFile() as tmp_file:
         self.log.info("Fetching file from Hive")
         hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
         hive.to_csv(hql=self.hql, csv_filepath=tmp_file.name, hive_conf=context_to_airflow_vars(context))
         self.log.info("Pushing to samba")
         samba = SambaHook(samba_conn_id=self.samba_conn_id)
         samba.push_from_local(self.destination_filepath, tmp_file.name)
Example #7
    def test_mysql_to_hive_verify_loaded_values(self):
        mysql_table = 'test_mysql_to_hive'
        hive_table = 'test_mysql_to_hive'

        hook = MySqlHook()

        try:
            minmax = (
                255,
                65535,
                16777215,
                4294967295,
                18446744073709551615,
                -128,
                -32768,
                -8388608,
                -2147483648,
                -9223372036854775808
            )

            with hook.get_conn() as conn:
                conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
                conn.execute("""
                    CREATE TABLE {} (
                        c0 TINYINT   UNSIGNED,
                        c1 SMALLINT  UNSIGNED,
                        c2 MEDIUMINT UNSIGNED,
                        c3 INT       UNSIGNED,
                        c4 BIGINT    UNSIGNED,
                        c5 TINYINT,
                        c6 SMALLINT,
                        c7 MEDIUMINT,
                        c8 INT,
                        c9 BIGINT
                    )
                """.format(mysql_table))
                conn.execute("""
                    INSERT INTO {} VALUES (
                        {}, {}, {}, {}, {}, {}, {}, {}, {}, {}
                    )
                """.format(mysql_table, *minmax))

            op = MySqlToHiveTransferOperator(
                task_id='test_m2h',
                hive_cli_conn_id='hive_cli_default',
                sql="SELECT * FROM {}".format(mysql_table),
                hive_table=hive_table,
                recreate=True,
                delimiter=",",
                dag=self.dag)
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

            hive_hook = HiveServer2Hook()
            result = hive_hook.get_records("SELECT * FROM {}".format(hive_table))
            self.assertEqual(result[0], minmax)
        finally:
            with hook.get_conn() as conn:
                conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
Example #8
 def test_multi_statements(self):
     sqls = [
         "CREATE TABLE IF NOT EXISTS test_multi_statements (i INT)",
         "SELECT * FROM {}".format(self.table),
         "DROP TABLE test_multi_statements",
     ]
     hook = HiveServer2Hook()
     results = hook.get_records(sqls, schema=self.database)
     self.assertListEqual(results, [(1, 1), (2, 2)])
Example #9
 def test_to_csv(self):
     hook = HiveServer2Hook()
     query = "SELECT * FROM {}".format(self.table)
     csv_filepath = 'query_results.csv'
     hook.to_csv(query, csv_filepath, schema=self.database,
                 delimiter=',', lineterminator='\n', output_header=True, fetch_size=2)
     df = pd.read_csv(csv_filepath, sep=',')
     self.assertListEqual(df.columns.tolist(), self.columns)
     self.assertListEqual(df[self.columns[0]].values.tolist(), [1, 2])
     self.assertEqual(len(df), 2)
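One hardening note: the test writes query_results.csv into the current working directory and never deletes it; a hedged variant using a temporary directory (illustrative only, not the original test):

import os
import tempfile

with tempfile.TemporaryDirectory() as tmp_dir:
    csv_filepath = os.path.join(tmp_dir, 'query_results.csv')
    hook.to_csv(query, csv_filepath, schema=self.database,
                delimiter=',', lineterminator='\n', output_header=True, fetch_size=2)
    df = pd.read_csv(csv_filepath, sep=',')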
Example #10
    def test_mysql_to_hive_verify_csv_special_char(self):
        mysql_table = 'test_mysql_to_hive'
        hive_table = 'test_mysql_to_hive'

        from airflow.hooks.mysql_hook import MySqlHook
        hook = MySqlHook()

        try:
            db_record = (
                'c0',
                '["true"]'
            )
            with hook.get_conn() as conn:
                conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
                conn.execute("""
                    CREATE TABLE {} (
                        c0 VARCHAR(25),
                        c1 VARCHAR(25)
                    )
                """.format(mysql_table))
                conn.execute("""
                    INSERT INTO {} VALUES (
                        '{}', '{}'
                    )
                """.format(mysql_table, *db_record))

            from airflow.operators.mysql_to_hive import MySqlToHiveTransfer
            import unicodecsv as csv
            op = MySqlToHiveTransfer(
                task_id='test_m2h',
                hive_cli_conn_id='hive_cli_default',
                sql="SELECT * FROM {}".format(mysql_table),
                hive_table=hive_table,
                recreate=True,
                delimiter=",",
                quoting=csv.QUOTE_NONE,
                quotechar='',
                escapechar='@',
                dag=self.dag)
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

            from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook
            hive_hook = HiveServer2Hook()
            result = hive_hook.get_records("SELECT * FROM {}".format(hive_table))
            self.assertEqual(result[0], db_record)
        finally:
            with hook.get_conn() as conn:
                conn.execute("DROP TABLE IF EXISTS {}".format(mysql_table))
Example #11
    def test_get_conn_with_password(self, mock_connect):
        conn_id = "conn_with_password"
        conn_env = CONN_ENV_PREFIX + conn_id.upper()

        with mock.patch.dict(
            'os.environ',
            {conn_env: "jdbc+hive2://conn_id:conn_pass@localhost:10000/default?authMechanism=LDAP"}
        ):
            HiveServer2Hook(hiveserver2_conn_id=conn_id).get_conn()
            mock_connect.assert_called_once_with(
                host='localhost',
                port=10000,
                auth='LDAP',
                kerberos_service_name=None,
                username='conn_id',
                password='conn_pass',
                database='default')
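For context, CONN_ENV_PREFIX + conn_id.upper() builds the standard Airflow environment-variable name for a connection, AIRFLOW_CONN_<CONN_ID>, so the patch above is equivalent to exporting:

os.environ['AIRFLOW_CONN_CONN_WITH_PASSWORD'] = (
    'jdbc+hive2://conn_id:conn_pass@localhost:10000/default?authMechanism=LDAP'
)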
Example #12
 def test_get_results_data(self):
     hook = HiveServer2Hook()
     query = "SELECT * FROM {}".format(self.table)
     results = hook.get_results(query, schema=self.database)
     self.assertListEqual(results['data'], [(1, 1), (2, 2)])
Example #13
 def test_get_results_header(self):
     hook = HiveServer2Hook()
     query = "SELECT * FROM {}".format(self.table)
     results = hook.get_results(query, schema=self.database)
     self.assertListEqual([col[0] for col in results['header']],
                          self.columns)
Example #14
 def test_get_conn(self):
     hook = HiveServer2Hook()
     hook.get_conn()
Example #15
 def get_hook(self):
     if self.conn_type == 'mysql':
         from airflow.providers.mysql.hooks.mysql import MySqlHook
         return MySqlHook(mysql_conn_id=self.conn_id)
     elif self.conn_type == 'google_cloud_platform':
         from airflow.gcp.hooks.bigquery import BigQueryHook
         return BigQueryHook(bigquery_conn_id=self.conn_id)
     elif self.conn_type == 'postgres':
         from airflow.providers.postgres.hooks.postgres import PostgresHook
         return PostgresHook(postgres_conn_id=self.conn_id)
     elif self.conn_type == 'pig_cli':
         from airflow.providers.apache.pig.hooks.pig import PigCliHook
         return PigCliHook(pig_cli_conn_id=self.conn_id)
     elif self.conn_type == 'hive_cli':
         from airflow.providers.apache.hive.hooks.hive import HiveCliHook
         return HiveCliHook(hive_cli_conn_id=self.conn_id)
     elif self.conn_type == 'presto':
         from airflow.providers.presto.hooks.presto import PrestoHook
         return PrestoHook(presto_conn_id=self.conn_id)
     elif self.conn_type == 'hiveserver2':
         from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook
         return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
     elif self.conn_type == 'sqlite':
         from airflow.providers.sqlite.hooks.sqlite import SqliteHook
         return SqliteHook(sqlite_conn_id=self.conn_id)
     elif self.conn_type == 'jdbc':
         from airflow.providers.jdbc.hooks.jdbc import JdbcHook
         return JdbcHook(jdbc_conn_id=self.conn_id)
     elif self.conn_type == 'mssql':
         from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
         return MsSqlHook(mssql_conn_id=self.conn_id)
     elif self.conn_type == 'odbc':
         from airflow.providers.odbc.hooks.odbc import OdbcHook
         return OdbcHook(odbc_conn_id=self.conn_id)
     elif self.conn_type == 'oracle':
         from airflow.providers.oracle.hooks.oracle import OracleHook
         return OracleHook(oracle_conn_id=self.conn_id)
     elif self.conn_type == 'vertica':
         from airflow.providers.vertica.hooks.vertica import VerticaHook
         return VerticaHook(vertica_conn_id=self.conn_id)
     elif self.conn_type == 'cloudant':
         from airflow.providers.cloudant.hooks.cloudant import CloudantHook
         return CloudantHook(cloudant_conn_id=self.conn_id)
     elif self.conn_type == 'jira':
         from airflow.providers.jira.hooks.jira import JiraHook
         return JiraHook(jira_conn_id=self.conn_id)
     elif self.conn_type == 'redis':
         from airflow.providers.redis.hooks.redis import RedisHook
         return RedisHook(redis_conn_id=self.conn_id)
     elif self.conn_type == 'wasb':
         from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
         return WasbHook(wasb_conn_id=self.conn_id)
     elif self.conn_type == 'docker':
         from airflow.providers.docker.hooks.docker import DockerHook
         return DockerHook(docker_conn_id=self.conn_id)
     elif self.conn_type == 'azure_data_lake':
         from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook
         return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
     elif self.conn_type == 'azure_cosmos':
         from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook
         return AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
     elif self.conn_type == 'cassandra':
         from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
         return CassandraHook(cassandra_conn_id=self.conn_id)
     elif self.conn_type == 'mongo':
         from airflow.providers.mongo.hooks.mongo import MongoHook
         return MongoHook(conn_id=self.conn_id)
     elif self.conn_type == 'gcpcloudsql':
         from airflow.gcp.hooks.cloud_sql import CloudSQLDatabaseHook
         return CloudSQLDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
     elif self.conn_type == 'grpc':
         from airflow.providers.grpc.hooks.grpc import GrpcHook
         return GrpcHook(grpc_conn_id=self.conn_id)
     raise AirflowException("Unknown hook type {}".format(self.conn_type))
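As a design note, this long elif chain could be collapsed into a lazy dispatch table while keeping the deferred imports; a minimal sketch (the table and helper below are an illustration, not Airflow's API):

from importlib import import_module

_HOOK_SPECS = {
    'mysql': ('airflow.providers.mysql.hooks.mysql', 'MySqlHook', 'mysql_conn_id'),
    'postgres': ('airflow.providers.postgres.hooks.postgres', 'PostgresHook', 'postgres_conn_id'),
    'hiveserver2': ('airflow.providers.apache.hive.hooks.hive', 'HiveServer2Hook', 'hiveserver2_conn_id'),
    # ... one entry per supported conn_type ...
}

def get_hook(self):
    try:
        module_path, class_name, conn_id_arg = _HOOK_SPECS[self.conn_type]
    except KeyError:
        raise AirflowException("Unknown hook type {}".format(self.conn_type))
    # Import the module lazily, mirroring the deferred imports above.
    hook_class = getattr(import_module(module_path), class_name)
    return hook_class(**{conn_id_arg: self.conn_id})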