Code example #1
    def execute(self, context: 'Context') -> None:
        gcs_hook = GCSHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )

        presto_hook = PrestoHook(presto_conn_id=self.presto_conn_id)

        with NamedTemporaryFile("w+") as temp_file:
            self.log.info("Downloading data from %s", self.source_object)
            gcs_hook.download(
                bucket_name=self.source_bucket,
                object_name=self.source_object,
                filename=temp_file.name,
            )

            data = list(csv.reader(temp_file))
            fields = tuple(data[0])  # the header row supplies the target column names
            rows = [tuple(row) for row in data[1:]]

            self.log.info("Inserting data into %s", self.presto_table)
            presto_hook.insert_rows(table=self.presto_table, rows=rows, target_fields=fields)
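For orientation, a minimal sketch of wiring this kind of transfer into a DAG. The operator name and import path match the Google provider this execute method appears to come from, but the bucket, object, and table names are placeholders:

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.transfers.gcs_to_presto import GCSToPrestoOperator

with DAG(dag_id="gcs_to_presto_example", start_date=datetime(2021, 1, 1), schedule_interval=None) as dag:
    load_csv = GCSToPrestoOperator(
        task_id="load_csv",
        source_bucket="my-bucket",                    # placeholder bucket
        source_object="data/input.csv",               # placeholder object path
        presto_table="memory.default.input_table",    # placeholder target table
    )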
Code example #2
def drop_db():
    hook = PrestoHook()
    with hook.get_conn() as conn:
        with closing(conn.cursor()) as cur:
            cur.execute(DELETE_QUERY)
            # Presto does not execute queries until the result is fetched. :-(
            cur.fetchone()
Code example #3
def query(self):
    """Queries presto and returns a cursor to the results."""
    presto = PrestoHook(presto_conn_id=self.presto_conn_id)
    conn = presto.get_conn()
    cursor = conn.cursor()
    self.log.info("Executing: %s", self.sql)
    cursor.execute(self.sql)
    return _PrestoToGCSPrestoCursorAdapter(cursor)
Code example #4
def data(self):
    """Retrieve data from table"""
    table = request.args.get("table")
    # NOTE: interpolating a request parameter straight into SQL is an
    # injection risk; real code should validate `table` against an allowlist.
    sql = f"SELECT * FROM {table} LIMIT 1000;"
    hook = PrestoHook(PRESTO_CONN_ID)
    df = hook.get_pandas_df(sql)
    return df.to_html(
        classes="table table-striped table-bordered table-hover",
        index=False,
        na_rep='',
    )
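A minimal sketch of the allowlist check mentioned in the comment above; ALLOWED_TABLES and safe_table are hypothetical helpers, not part of the original view:

# Hypothetical guard; ALLOWED_TABLES is an assumption for illustration.
ALLOWED_TABLES = {"airports", "airlines", "flights"}

def safe_table(name: str) -> str:
    # Reject anything not explicitly allowed before it reaches the SQL string.
    if name not in ALLOWED_TABLES:
        raise ValueError(f"table {name!r} is not allowed")
    return name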
Code example #5
def test_should_record_records_with_kerberos_auth(self):
    conn_url = ('presto://airflow@presto:7778/?'
                'auth=kerberos&kerberos__service_name=HTTP&'
                'verify=False&'
                'protocol=https')
    with mock.patch.dict('os.environ',
                         AIRFLOW_CONN_PRESTO_DEFAULT=conn_url):
        hook = PrestoHook()
        sql = "SELECT name FROM tpch.sf1.customer ORDER BY custkey ASC LIMIT 3"
        records = hook.get_records(sql)
        assert [['Customer#000000001'], ['Customer#000000002'],
                ['Customer#000000003']] == records
Code example #6
    def execute(self, context: 'Context') -> None:
        presto = PrestoHook(presto_conn_id=self.presto_conn_id)
        self.log.info("Extracting data from Presto: %s", self.sql)
        results = presto.get_records(self.sql)

        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
        if self.mysql_preoperator:
            self.log.info("Running MySQL preoperator")
            self.log.info(self.mysql_preoperator)
            mysql.run(self.mysql_preoperator)

        self.log.info("Inserting rows into MySQL")
        mysql.insert_rows(table=self.mysql_table, rows=results)
Code example #7
File: test_presto.py Project: phucbui95/airflow
def test_get_conn_invalid_auth(self, mock_get_connection):
    mock_get_connection.return_value = Connection(
        login='******',
        password='******',
        host='host',
        schema='hive',
        extra=json.dumps({'auth': 'kerberos'}),
    )
    with self.assertRaisesRegex(
        AirflowException, re.escape("Kerberos authorization doesn't support password.")
    ):
        PrestoHook().get_conn()
Code example #8
File: test_presto.py Project: phucbui95/airflow
    def test_get_conn_verify(self, current_verify, expected_verify):
        patcher_connect = patch('airflow.providers.presto.hooks.presto.prestodb.dbapi.connect')
        patcher_get_connections = patch('airflow.providers.presto.hooks.presto.PrestoHook.get_connection')

        with patcher_connect as mock_connect, patcher_get_connections as mock_get_connection:
            mock_get_connection.return_value = Connection(
                login='******', host='host', schema='hive', extra=json.dumps({'verify': current_verify})
            )
            mock_verify = mock.PropertyMock()
            type(mock_connect.return_value._http_session).verify = mock_verify

            conn = PrestoHook().get_conn()
            mock_verify.assert_called_once_with(expected_verify)
            self.assertEqual(mock_connect.return_value, conn)
Code example #9
    def test_get_conn_kerberos_auth(self, mock_get_connection, mock_connect,
                                    mock_auth):
        mock_get_connection.return_value = Connection(
            login='******',
            host='host',
            schema='hive',
            extra=json.dumps({
                'auth': 'kerberos',
                'kerberos__config': 'TEST_KERBEROS_CONFIG',
                'kerberos__service_name': 'TEST_SERVICE_NAME',
                'kerberos__mutual_authentication': 'TEST_MUTUAL_AUTHENTICATION',
                'kerberos__force_preemptive': True,
                'kerberos__hostname_override': 'TEST_HOSTNAME_OVERRIDE',
                'kerberos__sanitize_mutual_error_response': True,
                'kerberos__principal': 'TEST_PRINCIPAL',
                'kerberos__delegate': 'TEST_DELEGATE',
                'kerberos__ca_bundle': 'TEST_CA_BUNDLE',
            }),
        )

        conn = PrestoHook().get_conn()
        mock_connect.assert_called_once_with(
            catalog='hive',
            host='host',
            port=None,
            http_scheme='http',
            schema='hive',
            source='airflow',
            user='******',
            isolation_level=0,
            auth=mock_auth.return_value,
        )
        mock_auth.assert_called_once_with(
            ca_bundle='TEST_CA_BUNDLE',
            config='TEST_KERBEROS_CONFIG',
            delegate='TEST_DELEGATE',
            force_preemptive=True,
            hostname_override='TEST_HOSTNAME_OVERRIDE',
            mutual_authentication='TEST_MUTUAL_AUTHENTICATION',
            principal='TEST_PRINCIPAL',
            sanitize_mutual_error_response=True,
            service_name='TEST_SERVICE_NAME',
        )
        assert mock_connect.return_value == conn
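The kerberos__* extras above are the same settings that code example #5 passes through a connection URI. A minimal sketch of defining such a connection via an environment variable (host, port, and service name are placeholders):

import os

# Placeholder host and port; the kerberos__* query parameters end up in the
# connection's extras, just like the json.dumps(...) block above.
os.environ["AIRFLOW_CONN_PRESTO_DEFAULT"] = (
    "presto://airflow@presto-host:8080/?"
    "auth=kerberos&kerberos__service_name=HTTP&protocol=https"
)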
Code example #10
File: test_presto.py Project: phucbui95/airflow
    def test_get_conn_basic_auth(self, mock_get_connection, mock_connect, mock_basic_auth):
        mock_get_connection.return_value = Connection(
            login='login', password='password', host='host', schema='hive'
        )

        conn = PrestoHook().get_conn()
        mock_connect.assert_called_once_with(
            catalog='hive',
            host='host',
            port=None,
            http_scheme='http',
            schema='hive',
            source='airflow',
            user='******',
            isolation_level=0,
            auth=mock_basic_auth.return_value,
        )
        mock_basic_auth.assert_called_once_with('login', 'password')
        self.assertEqual(mock_connect.return_value, conn)
Code example #11
    def execute(self, context: 'Context') -> None:
        gcs_hook = GCSHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )

        presto_hook = PrestoHook(presto_conn_id=self.presto_conn_id)

        with NamedTemporaryFile("w+") as temp_file:
            self.log.info("Downloading data from %s", self.source_object)
            gcs_hook.download(
                bucket_name=self.source_bucket,
                object_name=self.source_object,
                filename=temp_file.name,
            )

            data = csv.reader(temp_file)
            rows = (tuple(row) for row in data)
            self.log.info("Inserting data into %s", self.presto_table)

            if self.schema_fields:
                presto_hook.insert_rows(table=self.presto_table,
                                        rows=rows,
                                        target_fields=self.schema_fields)
            elif self.schema_object:
                blob = gcs_hook.download(
                    bucket_name=self.source_bucket,
                    object_name=self.schema_object,
                )
                schema_fields = json.loads(blob.decode("utf-8"))
                presto_hook.insert_rows(table=self.presto_table,
                                        rows=rows,
                                        target_fields=schema_fields)
            else:
                presto_hook.insert_rows(table=self.presto_table, rows=rows)
Code example #12
def test_should_record_records(self):
    hook = PrestoHook()
    sql = "SELECT name FROM tpch.sf1.customer ORDER BY custkey ASC LIMIT 3"
    records = hook.get_records(sql)
    assert [['Customer#000000001'], ['Customer#000000002'],
            ['Customer#000000003']] == records
Code example #13
File: presto_to_slack.py Project: kosteev/airflow
def _get_presto_hook(self) -> PrestoHook:
    return PrestoHook(presto_conn_id=self.presto_conn_id)
Code example #14
def get_hook(self):
    if self.conn_type == 'mysql':
        from airflow.providers.mysql.hooks.mysql import MySqlHook
        return MySqlHook(mysql_conn_id=self.conn_id)
    elif self.conn_type == 'google_cloud_platform':
        from airflow.gcp.hooks.bigquery import BigQueryHook
        return BigQueryHook(bigquery_conn_id=self.conn_id)
    elif self.conn_type == 'postgres':
        from airflow.providers.postgres.hooks.postgres import PostgresHook
        return PostgresHook(postgres_conn_id=self.conn_id)
    elif self.conn_type == 'pig_cli':
        from airflow.providers.apache.pig.hooks.pig import PigCliHook
        return PigCliHook(pig_cli_conn_id=self.conn_id)
    elif self.conn_type == 'hive_cli':
        from airflow.providers.apache.hive.hooks.hive import HiveCliHook
        return HiveCliHook(hive_cli_conn_id=self.conn_id)
    elif self.conn_type == 'presto':
        from airflow.providers.presto.hooks.presto import PrestoHook
        return PrestoHook(presto_conn_id=self.conn_id)
    elif self.conn_type == 'hiveserver2':
        from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook
        return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
    elif self.conn_type == 'sqlite':
        from airflow.providers.sqlite.hooks.sqlite import SqliteHook
        return SqliteHook(sqlite_conn_id=self.conn_id)
    elif self.conn_type == 'jdbc':
        from airflow.providers.jdbc.hooks.jdbc import JdbcHook
        return JdbcHook(jdbc_conn_id=self.conn_id)
    elif self.conn_type == 'mssql':
        from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
        return MsSqlHook(mssql_conn_id=self.conn_id)
    elif self.conn_type == 'odbc':
        from airflow.providers.odbc.hooks.odbc import OdbcHook
        return OdbcHook(odbc_conn_id=self.conn_id)
    elif self.conn_type == 'oracle':
        from airflow.providers.oracle.hooks.oracle import OracleHook
        return OracleHook(oracle_conn_id=self.conn_id)
    elif self.conn_type == 'vertica':
        from airflow.providers.vertica.hooks.vertica import VerticaHook
        return VerticaHook(vertica_conn_id=self.conn_id)
    elif self.conn_type == 'cloudant':
        from airflow.providers.cloudant.hooks.cloudant import CloudantHook
        return CloudantHook(cloudant_conn_id=self.conn_id)
    elif self.conn_type == 'jira':
        from airflow.providers.jira.hooks.jira import JiraHook
        return JiraHook(jira_conn_id=self.conn_id)
    elif self.conn_type == 'redis':
        from airflow.providers.redis.hooks.redis import RedisHook
        return RedisHook(redis_conn_id=self.conn_id)
    elif self.conn_type == 'wasb':
        from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
        return WasbHook(wasb_conn_id=self.conn_id)
    elif self.conn_type == 'docker':
        from airflow.providers.docker.hooks.docker import DockerHook
        return DockerHook(docker_conn_id=self.conn_id)
    elif self.conn_type == 'azure_data_lake':
        from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook
        return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
    elif self.conn_type == 'azure_cosmos':
        from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook
        return AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
    elif self.conn_type == 'cassandra':
        from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
        return CassandraHook(cassandra_conn_id=self.conn_id)
    elif self.conn_type == 'mongo':
        from airflow.providers.mongo.hooks.mongo import MongoHook
        return MongoHook(conn_id=self.conn_id)
    elif self.conn_type == 'gcpcloudsql':
        from airflow.gcp.hooks.cloud_sql import CloudSQLDatabaseHook
        return CloudSQLDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
    elif self.conn_type == 'grpc':
        from airflow.providers.grpc.hooks.grpc import GrpcHook
        return GrpcHook(grpc_conn_id=self.conn_id)
    raise AirflowException("Unknown hook type {}".format(self.conn_type))
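A hedged usage sketch for this dispatch: on the Airflow versions it comes from, get_hook is a method of the Connection model, so a hook can be obtained from a stored connection id ('my_presto_conn' is a placeholder):

from airflow.hooks.base_hook import BaseHook  # airflow.hooks.base in Airflow 2.x

# 'my_presto_conn' is a placeholder id; with conn_type='presto' the dispatch
# above returns a PrestoHook.
conn = BaseHook.get_connection("my_presto_conn")
hook = conn.get_hook()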
Code example #15
File: presto_check.py Project: harishjami1382/test2
def get_db_hook(self):
    return PrestoHook(presto_conn_id=self.presto_conn_id)
Code example #16
def get_presto_hook():
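    # The first positional argument is the connection id (presto_conn_id);
    # here it resolves to a connection named 'trino'.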
    presto_hook = PrestoHook('trino')
    return presto_hook
Code example #17
def drop_db():
    hook = PrestoHook()
    hook.run(DELETE_QUERY)
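Compare with code example #2: this variant delegates to hook.run instead of managing a cursor and an explicit fetch itself.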
Code example #18
File: hive_stats.py Project: folly3/airflow-1
    def execute(self, context: Optional[Dict[str, Any]] = None) -> None:
        metastore = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id)
        table = metastore.get_table(table_name=self.table)
        field_types = {col.name: col.type for col in table.sd.cols}

        exprs: Any = {('', 'count'): 'COUNT(*)'}
        for col, col_type in list(field_types.items()):
            if self.assignment_func:
                assign_exprs = self.assignment_func(col, col_type)
                if assign_exprs is None:
                    assign_exprs = self.get_default_exprs(col, col_type)
            else:
                assign_exprs = self.get_default_exprs(col, col_type)
            exprs.update(assign_exprs)
        exprs.update(self.extra_exprs)
        exprs = OrderedDict(exprs)
        exprs_str = ",\n        ".join([v + " AS " + k[0] + '__' + k[1] for k, v in exprs.items()])

        where_clause_ = ["{} = '{}'".format(k, v) for k, v in self.partition.items()]
        where_clause = " AND\n        ".join(where_clause_)
        sql = "SELECT {exprs_str} FROM {table} WHERE {where_clause};".format(
            exprs_str=exprs_str, table=self.table, where_clause=where_clause
        )

        presto = PrestoHook(presto_conn_id=self.presto_conn_id)
        self.log.info('Executing SQL check: %s', sql)
        row = presto.get_first(hql=sql)
        self.log.info("Record: %s", row)
        if not row:
            raise AirflowException("The query returned None")

        part_json = json.dumps(self.partition, sort_keys=True)

        self.log.info("Deleting rows from previous runs if they exist")
        mysql = MySqlHook(self.mysql_conn_id)
        sql = """
        SELECT 1 FROM hive_stats
        WHERE
            table_name='{table}' AND
            partition_repr='{part_json}' AND
            dttm='{dttm}'
        LIMIT 1;
        """.format(
            table=self.table, part_json=part_json, dttm=self.dttm
        )
        if mysql.get_records(sql):
            sql = """
            DELETE FROM hive_stats
            WHERE
                table_name='{table}' AND
                partition_repr='{part_json}' AND
                dttm='{dttm}';
            """.format(
                table=self.table, part_json=part_json, dttm=self.dttm
            )
            mysql.run(sql)

        self.log.info("Pivoting and loading cells into the Airflow db")
        rows = [
            (self.ds, self.dttm, self.table, part_json) + (r[0][0], r[0][1], r[1]) for r in zip(exprs, row)
        ]
        mysql.insert_rows(
            table='hive_stats',
            rows=rows,
            target_fields=[
                'ds',
                'dttm',
                'table_name',
                'partition_repr',
                'col',
                'metric',
                'value',
            ],
        )
Code example #19
def init_db():
    hook = PrestoHook()
    hook.run(CREATE_QUERY)
    hook.run(LOAD_QUERY)
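CREATE_QUERY and LOAD_QUERY are module-level constants defined elsewhere in the original test file; hypothetical placeholders consistent with the tpch data used in the other tests might look like:

# Hypothetical placeholders; the real statements live elsewhere in the module.
CREATE_QUERY = "CREATE TABLE memory.default.test_customer (name varchar)"
LOAD_QUERY = "INSERT INTO memory.default.test_customer SELECT name FROM tpch.sf1.customer LIMIT 3"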