def test_bulk_load(self):
    hook = PostgresHook()
    input_data = ["foo", "bar", "baz"]
    with hook.get_conn() as conn:
        with conn.cursor() as cur:
            cur.execute("CREATE TABLE {} (c VARCHAR)".format(self.table))
            conn.commit()

            with NamedTemporaryFile() as f:
                f.write("\n".join(input_data).encode("utf-8"))
                f.flush()
                hook.bulk_load(self.table, f.name)

            cur.execute("SELECT * FROM {}".format(self.table))
            results = [row[0] for row in cur.fetchall()]

    self.assertEqual(sorted(input_data), sorted(results))
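# For context: PostgresHook.bulk_load wraps Postgres COPY. Below is a minimal
# hand-rolled equivalent using psycopg2's copy_expert -- an illustration of the
# mechanism, not the hook's exact implementation; `table` is assumed trusted.
def manual_bulk_load(hook, table, path):
    with closing(hook.get_conn()) as conn:
        with conn.cursor() as cur, open(path, "rb") as f:
            cur.copy_expert("COPY {} FROM STDIN".format(table), f)
        conn.commit()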
def create_temporary_dbfile(self, request_id, buffer):
    pg_hook = PostgresHook(postgres_conn_id=self.postgres_conn_id, schema=self.database)
    with closing(pg_hook.get_conn()) as pg_conn:
        with closing(pg_conn.cursor()) as pg_cursor:
            pg_cursor.execute(
                """
                INSERT INTO imgw.temporary_file (request_id, data)
                VALUES (%s, %s)
                ON CONFLICT (request_id) DO UPDATE SET data = EXCLUDED.data
                RETURNING id
                """,
                (request_id, psycopg2.Binary(buffer)))
            temporary_file_id = pg_cursor.fetchone()
            pg_conn.commit()
            return temporary_file_id[0]
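# A hedged companion sketch to the snippet above: reading the stored bytes back
# out by the id the INSERT returned. Table and column names match the snippet;
# psycopg2 returns bytea columns as memoryview, hence the bytes() conversion.
def read_temporary_dbfile(self, temporary_file_id):
    pg_hook = PostgresHook(postgres_conn_id=self.postgres_conn_id, schema=self.database)
    with closing(pg_hook.get_conn()) as pg_conn:
        with closing(pg_conn.cursor()) as pg_cursor:
            pg_cursor.execute(
                "SELECT data FROM imgw.temporary_file WHERE id = %s",
                (temporary_file_id,))
            row = pg_cursor.fetchone()
            return bytes(row[0]) if row else None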
def execute(self, context):
    aws_hook = AwsBaseHook(self.aws_credentials_id, client_type="s3")
    credentials = aws_hook.get_credentials()
    redshift = PostgresHook(postgres_conn_id=self.redshift_conn_id)

    self.log.info("Clearing data from destination Redshift table")
    redshift.run("DELETE FROM {}".format(self.table))

    self.log.info("Copying data from S3 to Redshift")
    rendered_key = self.s3_key.format(**context)
    s3_path = "s3://{}/{}".format(self.s3_bucket, rendered_key)
    formatted_sql = StageToRedshiftOperator.copy_sql.format(
        self.table,
        s3_path,
        credentials.access_key,
        credentials.secret_key,
        self.json_format,
    )
    redshift.run(formatted_sql)
@classmethod
def setUpClass(cls):
    postgres = PostgresHook()
    with postgres.get_conn() as conn:
        with conn.cursor() as cur:
            for table in TABLES:
                cur.execute(f"DROP TABLE IF EXISTS {table} CASCADE;")
                cur.execute(
                    f"CREATE TABLE {table}(some_str varchar, some_num integer);"
                )

            cur.execute(
                "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
                ('mock_row_content_1', 42))
            cur.execute(
                "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
                ('mock_row_content_2', 43))
            cur.execute(
                "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
                ('mock_row_content_3', 44))
def test_bulk_dump(self):
    hook = PostgresHook()
    input_data = ["foo", "bar", "baz"]
    with hook.get_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(f"CREATE TABLE {self.table} (c VARCHAR)")
            values = ",".join(f"('{data}')" for data in input_data)
            cur.execute(f"INSERT INTO {self.table} VALUES {values}")
            conn.commit()

            with NamedTemporaryFile() as f:
                hook.bulk_dump(self.table, f.name)
                f.seek(0)
                results = [
                    line.rstrip().decode("utf-8") for line in f.readlines()
                ]

    assert sorted(input_data) == sorted(results)
def execute(self, context):
    aws_hook = AwsBaseHook(self.aws_credentials_id)
    aws_credentials = aws_hook.get_credentials()
    redshift_conn = PostgresHook(
        postgres_conn_id=self.redshift_conn_id,
        connect_args={
            'keepalives': 1,
            'keepalives_idle': 60,
            'keepalives_interval': 60
        })

    self.log.debug(f"Truncate Table: {self.table}")
    redshift_conn.run(f"TRUNCATE TABLE {self.table}")

    # COPY options block for the chosen data format
    copy_format = ''
    if self.data_format == 'csv' and self.ignore_header > 0:
        copy_format += f"IGNOREHEADER {self.ignore_header}\n"
    if self.data_format == 'csv':
        copy_format += f"DELIMITER '{self.delimiter}'\n"
    elif self.data_format == 'json':
        copy_format += f"FORMAT AS JSON '{self.jsonpath}'\n"
    copy_format += f"{self.copy_opts}"
    self.log.debug(f"format : {copy_format}")

    formatted_key = self.s3_src_bucket_key.format(**context)
    self.log.info(f"Rendered S3 source file key : {formatted_key}")
    s3_url = f"s3://{self.s3_src_bucket_name}/{formatted_key}"
    self.log.debug(f"S3 URL : {s3_url}")

    formatted_sql = self._sql.format(**dict(
        table=self.table,
        source=s3_url,
        access_key=aws_credentials.access_key,
        secret_access_key=aws_credentials.secret_key,
        format=copy_format
    ))
    self.log.debug(f"Base SQL: {self._sql}")
    self.log.info(f"Copying data from S3 to Redshift table {self.table}...")
    redshift_conn.run(formatted_sql)
    self.log.info(f"Finished copying data from S3 to Redshift table {self.table}")
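# Worked example of the options assembly above: for data_format='csv',
# ignore_header=1, delimiter=',' and copy_opts="REGION 'us-west-2'" (values
# chosen for illustration), the block substituted into the COPY statement is:
#
#   IGNOREHEADER 1
#   DELIMITER ','
#   REGION 'us-west-2'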
def test_bulk_dump(self):
    hook = PostgresHook()
    input_data = ["foo", "bar", "baz"]
    with hook.get_conn() as conn:
        with conn.cursor() as cur:
            cur.execute("CREATE TABLE {} (c VARCHAR)".format(self.table))
            values = ",".join("('{}')".format(data) for data in input_data)
            cur.execute("INSERT INTO {} VALUES {}".format(self.table, values))
            conn.commit()

            with NamedTemporaryFile() as f:
                hook.bulk_dump(self.table, f.name)
                f.seek(0)
                results = [
                    line.rstrip().decode("utf-8") for line in f.readlines()
                ]

    self.assertEqual(sorted(input_data), sorted(results))
def capture_export_wrap(ds, **kwargs):
    from lib.utils import print_time
    db = PostgresHook(postgres_conn_id=postgresConnId)
    conn = db.get_conn()
    try:
        # # generate last month
        # # get the month value of the last month
        # now = datetime.now()
        # last_month = now.month - 1
        # # get the year value of the last month
        # last_year = now.year
        # if last_month == 0:
        #     last_month = 12
        #     last_year = now.year - 1
        # print("last_month:", last_month)
        # print("last_year:", last_year)
        # # get the last month
        # year_month = str(last_year) + "-" + str(last_month)
        # print("year_month:", year_month)

        date = datetime.now().strftime("%Y-%m-%d")
        print("date:", date)

        # check that the required CKAN variables exist
        CKAN_DOMAIN = Variable.get("CKAN_DOMAIN")
        assert CKAN_DOMAIN
        CKAN_DATASET_NAME = Variable.get("CKAN_DATASET_NAME")
        assert CKAN_DATASET_NAME
        CKAN_API_KEY = Variable.get("CKAN_API_KEY")
        assert CKAN_API_KEY

        ckan_config = {
            "CKAN_DOMAIN": CKAN_DOMAIN,
            "CKAN_DATASET_NAME": CKAN_DATASET_NAME,
            "CKAN_API_KEY": CKAN_API_KEY,
        }
        print("ckan_config:", ckan_config)

        capture_export(conn, date, 178, ckan_config)
        return 0
    except Exception as e:
        print("got error when executing SQL:", e)
        raise ValueError('Error executing query') from e
def execute(self, context, testing=False):
    """Run data quality checks for each table in the table list.

    Asserts a list of tables against business-defined SQL metrics.
    """
    self.log.info('DataQualityCheckOperator Starting...')
    self.log.info("Initializing Postgres Master DB Connection...")
    psql_hook = PostgresHook(postgres_conn_id=self._postgres_conn_id)
    try:
        conn = psql_hook.get_conn()
        cursor = conn.cursor(cursor_factory=RealDictCursor)
        for table in self._tables:
            data_quality = dict()
            for name, query in self._queries.items():
                self.log.info(f"Running query: {query}")
                cursor.execute(query)
                result = cursor.fetchone()
                result = result.get('count')
                if not result:
                    error = ("Data quality check FAILED. "
                             f"{table} returned no results "
                             f"for query: {name}")
                    self.log.error(error)
                    raise ValueError(error)
                data_quality[name] = result
            self.log.info(
                f"Data quality check on table '{table}' PASSED\n"
                "Results Summary:\n"
                f"{json.dumps(data_quality, indent=4, sort_keys=True)}")
    except (InterfaceError, OperationalError):
        self.log.error("DataQualityCheckOperator FAILED.")
        self.log.error(traceback.format_exc())
        raise Exception("DataQualityCheckOperator FAILED.")
    except Exception:
        self.log.error("DataQualityCheckOperator FAILED.")
        raise Exception("DataQualityCheckOperator FAILED.")
    finally:
        if not testing:
            conn.close()

    self.log.info('DataQualityCheckOperator SUCCESS!')
    return data_quality
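# A hypothetical `_queries` mapping this operator could be configured with.
# Each query must return a single row exposing a `count` key, since the loop
# above reads `result.get('count')` from a RealDictCursor row:
DATA_QUALITY_QUERIES = {
    'has_rows': "SELECT COUNT(*) AS count FROM ports",
    'no_null_ids': "SELECT COUNT(*) AS count FROM ports WHERE id IS NOT NULL",
}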
def execute(self, context) -> None:
    postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
    s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
    credentials = s3_hook.get_credentials()
    credentials_block = build_credentials_block(credentials)
    copy_options = '\n\t\t\t'.join(self.copy_options)
    destination = f'{self.schema}.{self.table}'
    copy_destination = f'#{self.table}' if self.method == 'UPSERT' else destination

    copy_statement = self._build_copy_query(copy_destination, credentials_block, copy_options)

    if self.method == 'REPLACE':
        sql = f"""
        BEGIN;
        DELETE FROM {destination};
        {copy_statement}
        COMMIT
        """
    elif self.method == 'UPSERT':
        keys = self.upsert_keys or postgres_hook.get_table_primary_key(self.table, self.schema)
        if not keys:
            raise AirflowException(
                f"No primary key on {self.schema}.{self.table}. Please provide keys on 'upsert_keys'"
            )
        where_statement = ' AND '.join([f'{self.table}.{k} = {copy_destination}.{k}' for k in keys])
        sql = f"""
        CREATE TABLE {copy_destination} (LIKE {destination});
        {copy_statement}
        BEGIN;
        DELETE FROM {destination} USING {copy_destination} WHERE {where_statement};
        INSERT INTO {destination} SELECT * FROM {copy_destination};
        COMMIT
        """
    else:
        sql = copy_statement

    self.log.info('Executing COPY command...')
    postgres_hook.run(sql, self.autocommit)
    self.log.info("COPY command complete...")
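# For illustration, the UPSERT branch above expands to SQL of roughly this
# shape for a hypothetical table public.users keyed on `id` (the `#users`
# staging table is session-local in Redshift):
#
#   CREATE TABLE #users (LIKE public.users);
#   COPY #users FROM 's3://...' ...;
#   BEGIN;
#   DELETE FROM public.users USING #users WHERE users.id = #users.id;
#   INSERT INTO public.users SELECT * FROM #users;
#   COMMIT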
def execute(self, context, testing=False):
    """
    Read all data from the Mongo staging DB, process it, and write it to
    the PostgreSQL master DB using an UPSERT SQL query.
    """
    self.log.info('LoadToMasterdbOperator Starting...')
    self.log.info("Initializing Mongo Staging DB Connection...")
    mongo_hook = MongoHook(conn_id=self._mongo_conn_id)
    ports_collection = mongo_hook.get_collection(self._mongo_collection)
    self.log.info("Initializing Postgres Master DB Connection...")
    psql_hook = PostgresHook(postgres_conn_id=self._postgres_conn_id)
    psql_conn = psql_hook.get_conn()
    psql_cursor = psql_conn.cursor()
    self.log.info("Loading Staging data to Master Database...")

    row_count = 0  # keeps the final log safe even if the collection is empty
    try:
        for idx, document in enumerate(ports_collection.find({})):
            document = self._processor.process_item(document)
            staging_id = str(document.get('_id'))
            document['staging_id'] = staging_id
            document.pop('_id')
            psql_cursor.execute(self._sql_query, document)
            psql_conn.commit()
            row_count = idx + 1
    except (OperationalError, UndefinedTable, OperationFailure):
        self.log.error("Writing to database FAILED.")
        self.log.error(traceback.format_exc())
        raise Exception("LoadToMasterdbOperator FAILED.")
    except Exception:
        self.log.error(traceback.format_exc())
        raise Exception("LoadToMasterdbOperator FAILED.")
    finally:
        if not testing:
            self.log.info('Closing database connections...')
            psql_conn.close()
            mongo_hook.close_conn()

    self.log.info(f'UPSERTED {row_count} records into Postgres Database.')
    self.log.info('LoadToMasterdbOperator SUCCESS!')
def execute(self, context=None):
    """
    Format the SQL statements with the params_sql mapping, then execute
    the statements one by one.

    Args:
        context: Airflow task context (unused).
    """
    if self.params_sql is not None:
        commands_formatted = [
            S.SQL(q).format(**self.params_sql) for q in self.commands_stripped
        ]
    else:
        commands_formatted = [S.SQL(q) for q in self.commands_stripped]

    hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
    for qf in commands_formatted:
        self.log.info("Executing Query:{}".format(qf.as_string(hook.get_conn())))
        hook.run((qf,))
def execute(self, context) -> None:
    postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
    s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
    credentials = s3_hook.get_credentials()
    credentials_block = build_credentials_block(credentials)
    copy_options = '\n\t\t\t'.join(self.copy_options)

    copy_statement = self._build_copy_query(credentials_block, copy_options)

    if self.truncate_table:
        delete_statement = f'DELETE FROM {self.schema}.{self.table};'
        sql = f"""
        BEGIN;
        {delete_statement}
        {copy_statement}
        COMMIT
        """
    else:
        sql = copy_statement

    self.log.info('Executing COPY command...')
    postgres_hook.run(sql, self.autocommit)
    self.log.info("COPY command complete...")
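# A plausible sketch of what build_credentials_block produces (hedged: the
# real helper lives in the Amazon provider's utilities): a semicolon-separated
# credentials string for Redshift COPY/UNLOAD, with the session token appended
# when temporary credentials are in play.
def build_credentials_block_sketch(credentials):
    block = (f"aws_access_key_id={credentials.access_key};"
             f"aws_secret_access_key={credentials.secret_key}")
    if credentials.token:
        block += f";token={credentials.token}"
    return block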
def execute(self, context) -> None:
    postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
    conn = S3Hook.get_connection(conn_id=self.aws_conn_id)

    if conn.extra_dejson.get('role_arn', False):
        credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
    else:
        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        credentials = s3_hook.get_credentials()
        credentials_block = build_credentials_block(credentials)

    unload_options = '\n\t\t\t'.join(self.unload_options)

    unload_query = self._build_unload_query(
        credentials_block, self.select_query, self.s3_key, unload_options
    )

    self.log.info('Executing UNLOAD command...')
    postgres_hook.run(unload_query, self.autocommit, parameters=self.parameters)
    self.log.info("UNLOAD command complete...")
def execute(self, context):
    postgres_hook = PostgresHook(postgres_conn_id=self._postgres_conn_id)
    s3_hook = S3Hook(aws_conn_id=self._s3_conn_id)

    with postgres_hook.get_cursor() as cursor:
        cursor.execute(self._query)
        results = cursor.fetchall()
        headers = [col[0] for col in cursor.description]

    data_buffer = io.StringIO()
    csv_writer = csv.writer(data_buffer, quoting=csv.QUOTE_ALL, lineterminator=os.linesep)
    csv_writer.writerow(headers)
    csv_writer.writerows(results)
    data_buffer_binary = io.BytesIO(data_buffer.getvalue().encode())

    s3_hook.load_file_obj(
        file_obj=data_buffer_binary,
        bucket_name=self._s3_bucket,
        key=self._s3_key,
        replace=True,
    )
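# A hedged alternative for large result sets: instead of fetchall() pulling
# everything into memory, let Postgres serialize CSV straight into the buffer
# via COPY. Sketch only; `query` must be a plain SELECT for this to work.
def query_to_csv_buffer(postgres_hook, query):
    buffer = io.StringIO()
    with closing(postgres_hook.get_conn()) as conn:
        with conn.cursor() as cursor:
            cursor.copy_expert(f"COPY ({query}) TO STDOUT WITH CSV HEADER", buffer)
    buffer.seek(0)
    return buffer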
def execute(self, context):
    self.hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
    self.s3 = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
    credentials = self.s3.get_credentials()
    copy_options = '\n\t\t\t'.join(self.copy_options)

    copy_query = """
        COPY {schema}.{table}
        FROM 's3://{s3_bucket}/{s3_key}/{table}'
        with credentials
        'aws_access_key_id={access_key};aws_secret_access_key={secret_key}'
        {copy_options};
    """.format(schema=self.schema,
               table=self.table,
               s3_bucket=self.s3_bucket,
               s3_key=self.s3_key,
               access_key=credentials.access_key,
               secret_key=credentials.secret_key,
               copy_options=copy_options)

    self.log.info('Executing COPY command...')
    self.hook.run(copy_query, self.autocommit)
    self.log.info("COPY command complete...")
def execute(self, context):
    """
    Description: This custom function fills a given fact table with a
                 passed SQL statement.

    Arguments:
        self: Instance of the class
        context: Context dictionary

    Returns:
        None
    """
    # Build connection
    postgres = PostgresHook(postgres_conn_id=self.postgres_conn_id)

    # Render and run the insert statement that fills the fact table
    formatted_sql = LoadFactOperator.insert_sql.format(
        self.table, self.insert_sql_query)
    postgres.run(formatted_sql)

    self.log.info(
        'LoadFactOperator for fact table {} completed'.format(self.table))
def execute(self, context):
    """Establish connections to both MySQL & PostgreSQL databases, open a
    cursor and begin processing the query, loading chunks of rows into
    PostgreSQL. Repeat loading chunks until all rows are processed.
    """
    source = MySqlHook(mysql_conn_id=self.mysql_conn_id)
    target = PostgresHook(postgres_conn_id=self.postgres_conn_id)

    with closing(source.get_conn()) as conn:
        with closing(conn.cursor()) as cursor:
            cursor.execute(self.sql, self.params)
            target_fields = [x[0] for x in cursor.description]

            row_count = 0
            rows = cursor.fetchmany(self.rows_chunk)
            while len(rows) > 0:
                row_count += len(rows)
                target.insert_rows(
                    self.postgres_table,
                    rows,
                    target_fields=target_fields,
                    commit_every=self.rows_chunk,
                )
                rows = cursor.fetchmany(self.rows_chunk)

    self.log.info(f"{row_count} row(s) inserted into {self.postgres_table}.")
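# A hypothetical instantiation of the chunked-transfer operator above (the
# operator class name, connection ids, and table are placeholders):
transfer_orders = MySqlToPostgresOperator(
    task_id='transfer_orders',
    mysql_conn_id='mysql_default',
    postgres_conn_id='postgres_default',
    sql='SELECT * FROM orders WHERE created_at >= %(since)s',
    params={'since': '2021-01-01'},
    postgres_table='public.orders',
    rows_chunk=10000,
)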
def tearDown(self):
    super().tearDown()
    with PostgresHook().get_conn() as conn:
        with conn.cursor() as cur:
            cur.execute("DROP TABLE IF EXISTS {}".format(self.table))
@classmethod
def tearDownClass(cls):
    postgres = PostgresHook()
    with postgres.get_conn() as conn:
        with conn.cursor() as cur:
            for table in TABLES:
                cur.execute("DROP TABLE IF EXISTS {} CASCADE;".format(table))
def earnings_report(ds, **kwargs):
    db = PostgresHook(postgres_conn_id=postgresConnId)
    conn = db.get_conn()
    print("db:", conn)
    cursor = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
    updateCursor = None  # guard so the except block can log safely
    try:
        # these hard coded values are placeholders for the upcoming contracts system
        freetown_stakeholder_uuid = "2a34fa81-0683-4d25-94b9-24843ceec3c4"
        freetown_base_contract_uuid = "483a1f4e-0c52-4b53-b917-5ff4311ded26"
        freetown_base_contract_consolidation_uuid = "a2dc79ec-4556-4cc5-bff1-2dbb5fd35b51"

        cursor.execute("""
            SELECT COUNT(tree_id) capture_count,
                   person_id,
                   stakeholder_uuid,
                   MIN(time_created) consolidation_start_date,
                   MAX(time_created) consolidation_end_date,
                   ARRAY_AGG(tree_id) tree_ids
            FROM (
                SELECT trees.id tree_id, person_id, time_created, stakeholder_uuid,
                       rank() OVER (
                           PARTITION BY person_id
                           ORDER BY time_created ASC
                       )
                FROM trees
                JOIN planter ON trees.planter_id = planter.id
                JOIN entity ON entity.id = planter.person_id
                    AND earnings_id IS NULL
                    AND planter.organization_id IN (
                        select entity_id from getEntityRelationshipChildren(178)
                    )
                    AND time_created > TO_TIMESTAMP('2021-09-01 00:00:00', 'YYYY-MM-DD HH24:MI:SS')
                    AND time_created < TO_TIMESTAMP('2021-11-12 00:00:00', 'YYYY-MM-DD HH24:MI:SS')
                    AND trees.approved = true
                    AND trees.active = true
            ) rank
            GROUP BY person_id, stakeholder_uuid
            ORDER BY person_id;
        """)
        print("SQL result:", cursor.query)

        for row in cursor:
            print(row)
            # calculate the earnings based on FCC logic
            multiplier = (row['capture_count'] - row['capture_count'] % 100) / 10 / 100
            if multiplier > 1:
                multiplier = 1
            print("multiplier " + str(multiplier))

            maxPayout = 1200000
            earningsCurrency = 'SLL'
            earnings = multiplier * maxPayout

            updateCursor = conn.cursor()
            updateCursor.execute("""
                INSERT INTO earnings.earnings(
                    worker_id,
                    contract_id,
                    funder_id,
                    currency,
                    amount,
                    calculated_at,
                    consolidation_rule_id,
                    consolidation_period_start,
                    consolidation_period_end,
                    status
                ) VALUES (%s, %s, %s, %s, %s, NOW(), %s, %s, %s, 'calculated')
                RETURNING *
            """, (
                row['stakeholder_uuid'],
                freetown_base_contract_uuid,
                freetown_stakeholder_uuid,
                earningsCurrency,
                earnings,
                freetown_base_contract_consolidation_uuid,
                row['consolidation_start_date'],
                row['consolidation_end_date']))
            print("SQL result:", updateCursor.query)

            earningsId = updateCursor.fetchone()[0]
            print(earningsId)

            updateCursor.execute("""
                UPDATE trees
                SET earnings_id = %s
                WHERE id = ANY(%s)
            """, (earningsId, row['tree_ids']))
            conn.commit()

        return 0
    except Exception as e:
        print("got error when executing SQL:", e)
        if updateCursor is not None:
            print("SQL result:", updateCursor.query)
        raise ValueError('Error executing query') from e
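# Worked example of the payout math above: capture_count = 250 gives
# multiplier = (250 - 250 % 100) / 10 / 100 = 200 / 1000 = 0.2, so
# earnings = 0.2 * 1200000 = 240000 SLL; from capture_count = 1000 upward
# the multiplier caps at 1, paying out the full 1200000 SLL.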
def execute(self, context):
    self.hook = PostgresHook(postgres_conn_id=self.postgres_conn_id,
                             schema=self.database)
    self.hook.run(self.sql, self.autocommit, parameters=self.parameters)
    for output in self.hook.conn.notices:
        self.log.info(output)
def setUp(self):
    self.oltp_hook = PostgresHook('oltp')
    self.olap_hook = PostgresHook('olap')
# Setting database name
db_name = "userdata"

# The API that we need to call
NY_API = "https://health.data.ny.gov/resource/xdss-u53e.json?"

# These args will get passed on to each operator
# You can override them on a per-task basis during operator initialization
default_args = {
    'owner': 'Anil',
    'start_date': datetime(2020, 3, 1, tzinfo=local_tz),
}

# Using the Postgres hook to get the connection URI and modifying it to
# point at the right database name
result = PostgresHook(postgres_conn_id='postgres_new').get_uri().split("/")
result[3] = db_name
dbURI = "/".join(result)

with DAG('LOAD_NY_COVID_DLY',
         default_args=default_args,
         schedule_interval='0 9 * * *',
         catchup=False,
         template_searchpath='/opt/airflow/') as dag:

    @dag.task
    def getTodayDate():
        """
        Gets the current context of the Airflow task. This context is used
        to get the execution date.
        """
        context = {"test_date": get_current_context()["ds"]}
import os
import csv

from airflow.operators.python_operator import PythonOperator
from airflow.providers.postgres.hooks.postgres import PostgresHook

PostgresConn = PostgresHook(postgres_conn_id='postgresql_conn')


def getconnection():
    PostgresConn.get_conn()
    print("connected")


def writerrecords_aisles():
    records = PostgresConn.get_records(sql='SELECT * FROM aisles')
    print(records)
    if not os.path.exists(os.path.join(os.getcwd(), 'base_data')):
        os.makedirs(os.path.join(os.getcwd(), 'base_data'))
    with open(os.path.join(os.getcwd(), 'base_data/aisles.csv'), 'w') as f:
        writer = csv.writer(f)
        writer.writerows(records)


def writerrecords_clients():
    records = PostgresConn.get_records(sql='SELECT * FROM clients')
    if not os.path.exists(os.path.join(os.getcwd(), 'base_data')):
        os.makedirs(os.path.join(os.getcwd(), 'base_data'))
    with open(os.path.join(os.getcwd(), 'base_data/clients.csv'), 'w') as f:
        writer = csv.writer(f)
        writer.writerows(records)
def get_data(self):
    pg_hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)
    with closing(pg_hook.get_conn()) as conn:
        df = pd.read_sql(self.postgres_sql, conn)
    return df
def tearDown(self):
    tables_to_drop = ['test_postgres_to_postgres', 'test_airflow']
    with PostgresHook().get_conn() as conn:
        with conn.cursor() as cur:
            for table in tables_to_drop:
                cur.execute(f"DROP TABLE IF EXISTS {table}")
def create_tokens(ds, **kwargs):
    walletName = kwargs['dag_run'].conf.get('walletName')
    entityId = kwargs['dag_run'].conf.get('entityId')
    dryRun = kwargs['dag_run'].conf.get('dryRun')

    # print them out
    print('walletName:', walletName)
    print('entityId:', entityId)
    print('dryRun:', dryRun)

    # all three parameters are required
    if walletName is None:
        print('walletName is None')
        return
    if entityId is None:
        print('entityId is None')
        return
    if dryRun is None:
        print('dryRun is None')
        return

    result = 'pending'
    db = PostgresHook(postgres_conn_id='postgres_default')
    connection = db.get_conn()
    cursor = connection.cursor(cursor_factory=psycopg2.extras.DictCursor)
    try:
        # check that the wallet exists (parameterized to avoid SQL injection)
        cursor.execute("SELECT * FROM wallet.wallet WHERE name = %s", (walletName,))
        wallet = cursor.fetchone()
        if wallet is None:
            print('Wallet not found')
            return
        print('Wallet found', wallet)

        remaining = True
        for i in range(1, 100000):
            # if no captures remain, we are done
            if not remaining:
                break

            # fetch the next batch of rows from table 'trees'
            cursor.execute("""
                SELECT id, uuid, token_id
                FROM trees
                WHERE planter_id IN (
                    SELECT id FROM planter
                    WHERE organization_id IN (
                        SELECT entity_id FROM getEntityRelationshipChildren(%s)
                    )
                )
                AND active = true
                AND approved = true
                AND token_id IS NULL
                LIMIT 3000
            """, (entityId,))
            trees = cursor.fetchall()
            print('Trees found', len(trees))

            # a short batch means this is the last pass
            if len(trees) < 3000:
                print('No more trees')
                remaining = False

            # for each tree, create a token
            for capture in trees:
                print('capture', capture)
                tokenData = {
                    'tree_id': capture['id'],
                    'capture_id': capture['uuid'],
                    'wallet_id': wallet['id'],
                }
                print('tokenData', tokenData)

                # create token
                cursor.execute("""
                    INSERT INTO wallet.token (capture_id, wallet_id)
                    VALUES (%s, %s)
                    RETURNING id
                """, (tokenData['capture_id'], tokenData['wallet_id']))
                token = cursor.fetchone()
                print('token', token)
                print('token[id]', token['id'])

                # update tree with token id
                cursor.execute("""
                    UPDATE trees
                    SET token_id = %s
                    WHERE id = %s
                """, (token['id'], capture['id']))
                print('Token created: {}'.format(token))

        # if dryRun is false, then commit
        if not dryRun:
            connection.commit()
            print('Commit')
            result = 'success'
        else:
            print('Dry run, not committing')
            result = 'dry run'
    except Exception as e:
        print(e)
        result = 'error'
    finally:
        cursor.close()
        connection.close()

    print('result', result)
    # return 0 on success, 1 otherwise
    if result == 'success':
        return 0
    else:
        return 1
def drop_db():
    hook = PostgresHook()
    hook.run(DELETE_QUERY)
def get_hook(self):
    if self.conn_type == 'mysql':
        from airflow.providers.mysql.hooks.mysql import MySqlHook
        return MySqlHook(mysql_conn_id=self.conn_id)
    elif self.conn_type == 'google_cloud_platform':
        from airflow.gcp.hooks.bigquery import BigQueryHook
        return BigQueryHook(bigquery_conn_id=self.conn_id)
    elif self.conn_type == 'postgres':
        from airflow.providers.postgres.hooks.postgres import PostgresHook
        return PostgresHook(postgres_conn_id=self.conn_id)
    elif self.conn_type == 'pig_cli':
        from airflow.providers.apache.pig.hooks.pig import PigCliHook
        return PigCliHook(pig_cli_conn_id=self.conn_id)
    elif self.conn_type == 'hive_cli':
        from airflow.providers.apache.hive.hooks.hive import HiveCliHook
        return HiveCliHook(hive_cli_conn_id=self.conn_id)
    elif self.conn_type == 'presto':
        from airflow.providers.presto.hooks.presto import PrestoHook
        return PrestoHook(presto_conn_id=self.conn_id)
    elif self.conn_type == 'hiveserver2':
        from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook
        return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
    elif self.conn_type == 'sqlite':
        from airflow.providers.sqlite.hooks.sqlite import SqliteHook
        return SqliteHook(sqlite_conn_id=self.conn_id)
    elif self.conn_type == 'jdbc':
        from airflow.providers.jdbc.hooks.jdbc import JdbcHook
        return JdbcHook(jdbc_conn_id=self.conn_id)
    elif self.conn_type == 'mssql':
        from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
        return MsSqlHook(mssql_conn_id=self.conn_id)
    elif self.conn_type == 'odbc':
        from airflow.providers.odbc.hooks.odbc import OdbcHook
        return OdbcHook(odbc_conn_id=self.conn_id)
    elif self.conn_type == 'oracle':
        from airflow.providers.oracle.hooks.oracle import OracleHook
        return OracleHook(oracle_conn_id=self.conn_id)
    elif self.conn_type == 'vertica':
        from airflow.providers.vertica.hooks.vertica import VerticaHook
        return VerticaHook(vertica_conn_id=self.conn_id)
    elif self.conn_type == 'cloudant':
        from airflow.providers.cloudant.hooks.cloudant import CloudantHook
        return CloudantHook(cloudant_conn_id=self.conn_id)
    elif self.conn_type == 'jira':
        from airflow.providers.jira.hooks.jira import JiraHook
        return JiraHook(jira_conn_id=self.conn_id)
    elif self.conn_type == 'redis':
        from airflow.providers.redis.hooks.redis import RedisHook
        return RedisHook(redis_conn_id=self.conn_id)
    elif self.conn_type == 'wasb':
        from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
        return WasbHook(wasb_conn_id=self.conn_id)
    elif self.conn_type == 'docker':
        from airflow.providers.docker.hooks.docker import DockerHook
        return DockerHook(docker_conn_id=self.conn_id)
    elif self.conn_type == 'azure_data_lake':
        from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook
        return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
    elif self.conn_type == 'azure_cosmos':
        from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook
        return AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
    elif self.conn_type == 'cassandra':
        from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
        return CassandraHook(cassandra_conn_id=self.conn_id)
    elif self.conn_type == 'mongo':
        from airflow.providers.mongo.hooks.mongo import MongoHook
        return MongoHook(conn_id=self.conn_id)
    elif self.conn_type == 'gcpcloudsql':
        from airflow.gcp.hooks.cloud_sql import CloudSQLDatabaseHook
        return CloudSQLDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
    elif self.conn_type == 'grpc':
        from airflow.providers.grpc.hooks.grpc import GrpcHook
        return GrpcHook(grpc_conn_id=self.conn_id)
    raise AirflowException("Unknown hook type {}".format(self.conn_type))
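# A minimal usage sketch, assuming get_hook above is a method on Airflow's
# Connection model: look up a connection, dispatch on conn_type, and run a
# query through the resulting hook.
conn = Connection.get_connection_from_secrets('postgres_default')
hook = conn.get_hook()
print(hook.get_records('SELECT 1'))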