def drop_db():
    """Run ``DELETE_QUERY`` through a ``TrinoHook`` built with default arguments.

    The single result row is fetched deliberately: Trino does not execute a
    statement until its result is fetched.
    """
    trino_hook = TrinoHook()
    with trino_hook.get_conn() as connection, closing(connection.cursor()) as cursor:
        cursor.execute(DELETE_QUERY)
        # Fetching is what actually makes Trino run the statement.
        cursor.fetchone()
示例#2
0
 def query(self):
     """Execute ``self.sql`` on Trino and return the cursor wrapped in an adapter.

     The connection and cursor are not closed here. The returned
     ``_TrinoToGCSTrinoCursorAdapter`` holds the live cursor, presumably to
     read results from it later, so confirm before adding any cleanup.
     """
     hook = TrinoHook(trino_conn_id=self.trino_conn_id)
     connection = hook.get_conn()
     trino_cursor = connection.cursor()
     self.log.info("Executing: %s", self.sql)
     trino_cursor.execute(self.sql)
     return _TrinoToGCSTrinoCursorAdapter(trino_cursor)
示例#3
0
 def test_should_record_records_with_kerberos_auth(self):
     """Read three customer names from Trino over HTTPS with Kerberos auth.

     The connection is injected through the ``AIRFLOW_CONN_TRINO_DEFAULT``
     environment variable, which the default ``TrinoHook()`` then picks up.
     """
     # The authority must be a literal ``user@host``. ``trino.example.com`` is
     # assumed to be the Kerberos-enabled test host; confirm this against the
     # integration environment.
     conn_url = (
         'trino://airflow@trino.example.com:7778/?'
         'auth=kerberos&kerberos__service_name=HTTP&'
         'verify=False&'
         'protocol=https'
     )
     with mock.patch.dict('os.environ', AIRFLOW_CONN_TRINO_DEFAULT=conn_url):
         hook = TrinoHook()
         sql = "SELECT name FROM tpch.sf1.customer ORDER BY custkey ASC LIMIT 3"
         records = hook.get_records(sql)
         assert [['Customer#000000001'], ['Customer#000000002'], ['Customer#000000003']] == records
    def execute(self, context: Dict) -> None:
        """Copy the result rows of ``self.sql`` from Trino into ``self.mysql_table``.

        When ``self.mysql_preoperator`` is truthy, it is logged and run on
        MySQL after the Trino query and before the rows are inserted.
        """
        trino_hook = TrinoHook(trino_conn_id=self.trino_conn_id)
        self.log.info("Extracting data from Trino: %s", self.sql)
        # get_records returns the whole result set at once, so every row is
        # presumably held in memory before the MySQL insert.
        fetched_rows = trino_hook.get_records(self.sql)

        mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
        preoperator_sql = self.mysql_preoperator
        if preoperator_sql:
            self.log.info("Running MySQL preoperator")
            self.log.info(preoperator_sql)
            mysql_hook.run(preoperator_sql)

        self.log.info("Inserting rows into MySQL")
        mysql_hook.insert_rows(table=self.mysql_table, rows=fetched_rows)
示例#5
0
    def execute(self, context: 'Context') -> None:
        """Load a CSV object from GCS into ``self.trino_table``.

        Column names passed as ``target_fields`` come from the first available
        source:

        1. ``self.schema_fields`` (if truthy);
        2. a JSON document stored at ``self.schema_object`` in
           ``self.source_bucket`` (presumably a list of column names);
        3. none, in which case ``insert_rows`` is called without ``target_fields``.
        """
        gcs_hook = GCSHook(
            gcp_conn_id=self.gcp_conn_id,
            delegate_to=self.delegate_to,
            impersonation_chain=self.impersonation_chain,
        )

        trino_hook = TrinoHook(trino_conn_id=self.trino_conn_id)

        # The csv module requires newline="" on the file it parses. Otherwise
        # universal-newline translation rewrites line breaks inside quoted
        # fields before the parser sees them.
        with NamedTemporaryFile("w+", newline="") as temp_file:
            self.log.info("Downloading data from %s", self.source_object)
            gcs_hook.download(
                bucket_name=self.source_bucket,
                object_name=self.source_object,
                filename=temp_file.name,
            )

            data = csv.reader(temp_file)
            # Lazy generator: rows are read from temp_file while insert_rows
            # consumes them, so each insert must stay inside this ``with``.
            rows = (tuple(row) for row in data)
            self.log.info("Inserting data into %s", self.trino_table)
            if self.schema_fields:
                trino_hook.insert_rows(table=self.trino_table,
                                       rows=rows,
                                       target_fields=self.schema_fields)
            elif self.schema_object:
                # Called without ``filename``, download returns the object's
                # bytes, which are decoded as UTF-8 JSON here.
                blob = gcs_hook.download(
                    bucket_name=self.source_bucket,
                    object_name=self.schema_object,
                )
                schema_fields = json.loads(blob.decode("utf-8"))
                trino_hook.insert_rows(table=self.trino_table,
                                       rows=rows,
                                       target_fields=schema_fields)
            else:
                trino_hook.insert_rows(table=self.trino_table, rows=rows)
示例#6
0
 def test_get_conn_invalid_auth(self, mock_get_connection):
     """A Kerberos connection that also carries a password must be rejected."""
     kerberos_extra = json.dumps({'auth': 'kerberos'})
     mock_get_connection.return_value = Connection(
         login='******',
         password='******',
         host='host',
         schema='hive',
         extra=kerberos_extra,
     )
     expected_error = re.escape("Kerberos authorization doesn't support password.")
     with pytest.raises(AirflowException, match=expected_error):
         TrinoHook().get_conn()
示例#7
0
    def test_get_conn_verify(self, current_verify, expected_verify):
        """The ``verify`` extra is assigned to the client's HTTP session."""
        connect_path = 'airflow.providers.trino.hooks.trino.trino.dbapi.connect'
        get_connection_path = 'airflow.providers.trino.hooks.trino.TrinoHook.get_connection'

        with patch(connect_path) as mock_connect, patch(get_connection_path) as mock_get_connection:
            mock_get_connection.return_value = Connection(
                login='******', host='host', schema='hive', extra=json.dumps({'verify': current_verify})
            )
            # A PropertyMock on the session *type* records assignments to
            # ``_http_session.verify`` as calls with the assigned value.
            verify_property = mock.PropertyMock()
            session_type = type(mock_connect.return_value._http_session)
            session_type.verify = verify_property

            conn = TrinoHook().get_conn()

            verify_property.assert_called_once_with(expected_verify)
            assert mock_connect.return_value == conn
示例#8
0
    def test_get_conn_kerberos_auth(self, mock_get_connection, mock_connect, mock_auth):
        """``kerberos__*`` extras reach the auth class with the prefix stripped."""
        kerberos_extra = {
            'auth': 'kerberos',
            'kerberos__config': 'TEST_KERBEROS_CONFIG',
            'kerberos__service_name': 'TEST_SERVICE_NAME',
            'kerberos__mutual_authentication': 'TEST_MUTUAL_AUTHENTICATION',
            'kerberos__force_preemptive': True,
            'kerberos__hostname_override': 'TEST_HOSTNAME_OVERRIDE',
            'kerberos__sanitize_mutual_error_response': True,
            'kerberos__principal': 'TEST_PRINCIPAL',
            'kerberos__delegate': 'TEST_DELEGATE',
            'kerberos__ca_bundle': 'TEST_CA_BUNDLE',
        }
        mock_get_connection.return_value = Connection(
            login='******',
            host='host',
            schema='hive',
            extra=json.dumps(kerberos_extra),
        )

        conn = TrinoHook().get_conn()

        # The client is built with the auth object returned by the auth class.
        mock_connect.assert_called_once_with(
            catalog='hive',
            host='host',
            port=None,
            http_scheme='http',
            schema='hive',
            source='airflow',
            user='******',
            isolation_level=0,
            auth=mock_auth.return_value,
        )
        mock_auth.assert_called_once_with(
            ca_bundle='TEST_CA_BUNDLE',
            config='TEST_KERBEROS_CONFIG',
            delegate='TEST_DELEGATE',
            force_preemptive=True,
            hostname_override='TEST_HOSTNAME_OVERRIDE',
            mutual_authentication='TEST_MUTUAL_AUTHENTICATION',
            principal='TEST_PRINCIPAL',
            sanitize_mutual_error_response=True,
            service_name='TEST_SERVICE_NAME',
        )
        assert mock_connect.return_value == conn
示例#9
0
    def test_get_conn_basic_auth(self, mock_get_connection, mock_connect, mock_basic_auth):
        """A connection with login and password uses basic auth with those credentials."""
        # Shared locals keep the Connection, the ``user`` passed to connect
        # and the BasicAuthentication assertion in agreement.
        login, password = 'login', 'password'
        mock_get_connection.return_value = Connection(
            login=login, password=password, host='host', schema='hive'
        )

        conn = TrinoHook().get_conn()
        mock_connect.assert_called_once_with(
            catalog='hive',
            host='host',
            port=None,
            http_scheme='http',
            schema='hive',
            source='airflow',
            user=login,
            isolation_level=0,
            auth=mock_basic_auth.return_value,
        )
        mock_basic_auth.assert_called_once_with(login, password)
        assert mock_connect.return_value == conn
示例#10
0
 def test_should_record_records(self):
     """The first three TPC-H customers are returned in ``custkey`` order."""
     expected = [['Customer#000000001'], ['Customer#000000002'], ['Customer#000000003']]
     records = TrinoHook().get_records(
         "SELECT name FROM tpch.sf1.customer ORDER BY custkey ASC LIMIT 3"
     )
     assert expected == records
示例#11
0
 def get_hook(self) -> TrinoHook:
     """Return a new ``TrinoHook`` bound to ``self.trino_conn_id``."""
     conn_id = self.trino_conn_id
     return TrinoHook(trino_conn_id=conn_id)