def test_get_conn_from_connection_with_schema(self, mock_connect):
    conn = Connection(login='******', password='******', host='host', schema='schema')
    hook = PostgresHook(connection=conn, schema='schema-override')
    hook.get_conn()
    mock_connect.assert_called_once_with(
        user='******', password='******', host='host', dbname='schema-override', port=None
    )
def transfer_function(ds, **kwargs):
    query = "SELECT * FROM source_city_table"

    # source hook
    source_hook = PostgresHook(postgres_conn_id='postgres_conn', schema='airflow')
    source_conn = source_hook.get_conn()

    # destination hook
    destination_hook = PostgresHook(postgres_conn_id='postgres_conn', schema='airflow')
    destination_conn = destination_hook.get_conn()

    source_cursor = source_conn.cursor()
    destination_cursor = destination_conn.cursor()

    source_cursor.execute(query)
    records = source_cursor.fetchall()
    if records:
        execute_values(destination_cursor, "INSERT INTO target_city_table VALUES %s", records)
        destination_conn.commit()

    source_cursor.close()
    destination_cursor.close()
    source_conn.close()
    destination_conn.close()
    print("Data transferred successfully!")
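A minimal sketch of how a callable like transfer_function above might be wired into a DAG. The dag_id, schedule, and start_date are assumptions, and execute_values comes from psycopg2.extras (the snippet above relies on it being imported at module level).

# Hypothetical wiring for transfer_function; dag_id, schedule, and start_date are assumptions.
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator
from psycopg2.extras import execute_values  # used inside transfer_function

with DAG(
    dag_id="city_table_transfer",        # assumed name
    start_date=datetime(2023, 1, 1),     # assumed date
    schedule_interval="@daily",
    catchup=False,
) as dag:
    transfer_city_table = PythonOperator(
        task_id="transfer_city_table",
        python_callable=transfer_function,
    )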
def add_execution_in_database(
    table, data=None, connection_id="postgres_report_connection"
):
    """Dynamically registers information in a PostgreSQL database."""
    if not data:
        logging.info(
            "Cannot insert `empty data` into the database. Please verify your data attributes."
        )
        return
    data = dict(data)

    hook = PostgresHook(postgres_conn_id=connection_id)
    try:
        hook.get_conn()
    except AirflowException:
        logging.info("Cannot insert data. Connection '%s' is not configured.", connection_id)
        return

    if data.get("payload"):
        data["payload"] = json.dumps(data["payload"])

    columns = list(data.keys())
    values = list(data.values())
    try:
        hook.insert_rows(table, [values], target_fields=columns)
    except (AirflowException, ProgrammingError) as exc:
        logging.error(exc)
    else:
        logging.info("Registering `%s` into '%s' table.", data, table)
def upload_to_postgres(*args, **kwargs):
    ti = kwargs['ti']
    file_date, data = ti.xcom_pull(task_ids="get_data_from_api")
    file_date = file_date.strftime("%Y-%m-%d %H:%M:%S%z")
    data = [extract_content_id(i) for i in list(data)]

    # DbApiHook.insert_rows
    pg_hook = PostgresHook(postgres_conn_id='postgres_conn', schema='seoulbike')
    rows = [
        [st_data['stationName'], st_data['rackTotCnt'], st_data['parkingBikeTotCnt'], file_date]
        for st_data in data
    ]
    pg_hook.insert_rows(table='bike_realtime_log_tz', rows=rows)  # autocommit
    pg_hook.get_conn().close()
def posgres_database_check(**context):
    try:
        pg_hook = PostgresHook(postgres_conn_id="smirn08_database")
        pg_hook.get_conn()
    except psycopg2.OperationalError as connection_error:
        context["task_instance"].xcom_push(key="error_message", value=connection_error)
        logging.warning("Error connecting to postgres", exc_info=True)
        return ["telegram_pg_allert"]
    else:
        return ["download_and_preprocessing_orders"]
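Since the callable above returns a list of downstream task ids, it appears intended for branching; a minimal sketch under that assumption (the dag object is assumed to be defined elsewhere).

# Hedged sketch: branching on the database check above; `dag` is assumed to exist.
from airflow.operators.python import BranchPythonOperator

database_check = BranchPythonOperator(
    task_id="posgres_database_check",
    python_callable=posgres_database_check,
    dag=dag,
)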
def run_sql(sql: str, connection_id: str) -> None:
    """Executes SQL statements on an Airflow connection."""
    hook = PostgresHook(postgres_conn_id=connection_id)
    try:
        hook.get_conn()
    except OperationalError:
        logging.info("Connection `%s` is not configured.", connection_id)
        return

    try:
        hook.run(sql)
    except (AirflowException, ProgrammingError, OperationalError) as exc:
        logging.error(exc)
def updateData(self, table_name, data):
    '''
    This function updates county table data when required
    :param table_name: county name
    :param data: Data is a list of tuples
    :return: None
    '''
    pg_hook = PostgresHook(postgres_conn_id="postgres_default", schema="airflow")
    conn = pg_hook.get_conn()
    cursor = conn.cursor()
    cursor.execute(
        "INSERT INTO " + table_name + " (id, test_date, new_positives, cumulative_number_of_positives, "
        "total_number_of_test_performed, cumulative_number_of_test_performed, "
        "load_date, state) VALUES (%s, %s, %s, %s, %s, %s, %s, %s) ON CONFLICT (id) "
        "DO UPDATE SET test_date = excluded.test_date, new_positives = excluded.new_positives, "
        "cumulative_number_of_positives = excluded.cumulative_number_of_positives, "
        "total_number_of_test_performed = excluded.total_number_of_test_performed, "
        "cumulative_number_of_test_performed = excluded.cumulative_number_of_test_performed, "
        "load_date = excluded.load_date, state = excluded.state;",
        data[0])
    conn.commit()
    conn.close()
def update_app_load_stats(table_name, record_count, load_type, load_status):
    print('update load stats')
    pg_hook = PostgresHook(postgres_conn_id=config.app_afw_meta_db)
    con = pg_hook.get_conn()
    cursor = con.cursor()
    insert_sql = """insert into app_load_stats(dag_name, tenant_name, table_name, load_type,
                        record_count, triggered_by, load_status)
                    values ('{dag_name}', '{tenant_name}', '{table_name}', '{load_type}',
                        {record_count}, '{triggered_by}', {load_status})
                 """.format(
        dag_name=dag_name,
        tenant_name=tenant_name,
        table_name=table_name,
        load_type=load_type,
        load_status=load_status,
        record_count=record_count,
        triggered_by=config.owner_name)
    try:
        cursor.execute(insert_sql)
        con.commit()
    except Exception as ERROR:
        print('error inserting load stats into metadata table ', ERROR)
    finally:
        cursor.close()
        con.close()
def create_table_in_rs(bucket_name, tenant_name, schema_name, table_name):
    set_query = ' SET AUTOCOMMIT = ON '
    current_dir = pathlib.Path(__file__).parent
    print('current_dir---', current_dir)
    file_name = str(current_dir) + '/db/ddl/' + table_name + '.sql'
    with open(file_name, 'r') as f:
        ddl_query = f.read().replace("\n", '')
    print('ddl_query ---', ddl_query)
    pg_hook = PostgresHook(postgres_conn_id=config.rs_conn_id)
    con = pg_hook.get_conn()
    location = 's3://' + bucket_name + '/' + tenant_name + '/' + table_name + '/'
    ddl_query = ddl_query.format(schema_name=schema_name, location=location)
    print('ddl_query ---', ddl_query)
    cursor = con.cursor()
    cursor.execute(set_query)
    cursor.execute(ddl_query)
    con.commit()
    cursor.close()
    con.close()
def export_to_csv():
    db = PostgresHook(postgres_conn_id="postgres_covid")
    with open(CSV_FN, "w") as f:
        with closing(db.get_conn()) as conn:
            with closing(conn.cursor()) as cur:
                cur.copy_expert(
                    "COPY (SELECT * FROM covid_us) TO STDOUT CSV HEADER;", f)
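copy_expert also works in the opposite direction with COPY ... FROM STDIN; a minimal sketch under the assumption of the same postgres_covid connection, with a hypothetical covid_us_staging table and csv_path default.

# Hedged sketch: loading a CSV back with COPY ... FROM STDIN.
# covid_us_staging and the csv_path default are assumptions, not from the snippet above.
from contextlib import closing

def import_from_csv(csv_path="covid_us.csv"):
    db = PostgresHook(postgres_conn_id="postgres_covid")
    with open(csv_path, "r") as f:
        with closing(db.get_conn()) as conn:
            with closing(conn.cursor()) as cur:
                cur.copy_expert("COPY covid_us_staging FROM STDIN CSV HEADER;", f)
            conn.commit()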
def purgeDupes(**kwargs):
    """
    Call postgres function to purge dupes from lf_store table
    """
    pgHook = PostgresHook(postgres_conn_id=POSTGRES_DB)
    pgConn = None
    try:
        pgConn = pgHook.get_conn()
        pgCursor = pgConn.cursor()
        pgCursor.callproc('store_remove_dupes')
        result = pgCursor.fetchall()
        pgConn.commit()
        recCount = result[0][0]
        logging.info(
            'DEDUPE COMPLETED SUCCESSFULLY - {} RECORDS DELETED FROM LF_STORE'.format(recCount))
    except Exception as e:
        logging.error('Error trying to run store_remove_dupes step')
        raise e
    finally:
        if pgConn:
            pgCursor.close()
            pgConn.close()
def pg_conn(self):
    postgres_hook = PostgresHook(postgres_conn_id="postgres_dwh", schema="test")
    conn = postgres_hook.get_conn()
    cur = conn.cursor()
    return conn, cur
def execute(self, context): self.log.info("Context {}".format(context)) table = self.table task_id = self.task_ids deleteTable = self.deleteLoad data = context['task_instance'].xcom_pull(task_ids=task_id)['data'] insertTable = self.insertTable redshift = PostgresHook(postgres_conn_id=self.redshift_conn_id) self.log.info("Write table {}".format(table)) insertTable = self.insertTable redshift = PostgresHook(postgres_conn_id="redshift") if deleteTable: sql = f"DELETE FROM {table}" redshift.run(sql) self.log.info("Clean the table: {}".format(table)) with redshift.get_conn() as conn: with conn.cursor() as cur: for row in data: try: cur.execute(insertTable, list(row)) conn.commit() except Exception as e: logging.error(e) conn.rollback()
def __init__(
        self,
        # dag_config=None,
        sql,
        postgres_conn_id,
        group_name,
        *args, **kwargs):
    super(ListTable, self).__init__(*args, **kwargs)
    # self.dag = dag_config
    hook = PostgresHook(postgres_conn_id=postgres_conn_id)
    conn = hook.get_conn()
    cursor = conn.cursor()
    cursor.execute(sql)
    rows = cursor.fetchall()
    self.tiList = []

    # TODO: dynamically generate different Operators instead of a fixed Operator type
    # generate dynamic tasks
    for row in rows:
        self.tiList.append(MultiplyBy5Operator(
            task_id="{0}.{1}".format(group_name, row[0]),
            my_operator_param=row[0],
            dag=self.dag))

    # set task dependencies
    self.dag.task_dict[self.task_id].set_downstream(self.tiList)
    DummyOperator(task_id="{0}_{1}".format('done', group_name),
                  dag=self.dag).set_upstream(self.tiList)
    cursor.close()
def dump_to_dest(self, src_conn_id: str, t_schema: str, t_name: str):
    """ Loads a dump from the source into the destination """
    # avoid parse to python dictionary (keeps postgres json)
    register_adapter(dict, Json)
    register_json(oid=3802, array_oid=3807, globally=True)

    src_hook = PostgresHook(postgres_conn_id=src_conn_id)
    src_conn = src_hook.get_conn()
    src_cursor = src_conn.cursor()
    src_cursor.execute(f'select count(0) from {t_name};')
    qtd = src_cursor.fetchone()[0]

    dest_cursor = self.conn.cursor()
    dest_cursor.execute(f'TRUNCATE TABLE {t_schema}.{t_name};')
    self.conn.commit()

    if qtd > 0:
        with tempfile.NamedTemporaryFile() as temp_file:
            print('Generating dump for table:', t_name, 'rows:', qtd)
            src_hook.bulk_dump(t_name, temp_file.name)
            print('Loading dump for table:', f'{t_schema}.{t_name}', 'rows:', qtd)
            self.hook.bulk_load(f'{t_schema}.{t_name}', temp_file.name)
    else:
        print('No dump generated for table:', t_name, 'because it has 0 rows')
def validate_row_count(schema):
    '''
    Validate the row count for each table.
    If the row count for a table is less than the minimum defined for that table,
    log an error and fail the task.
    If the row count for a table is greater than or equal to the minimum defined for that table,
    log the info and succeed the task.
    '''
    # extract only the required tables from the dataframe (src/stg/core tables list)
    DF_STG_SRC_TABLES = DF_ROW_CNT_VALDTN[DF_ROW_CNT_VALDTN['table'].str.contains(schema + '.')]

    pghook = PostgresHook('postgres_local')
    conn = pghook.get_conn()
    cursor = conn.cursor()

    # validate the row count for each table
    for ix, row in DF_STG_SRC_TABLES.iterrows():
        table = row[0]
        min_rows = int(row[1])
        cursor.execute(VALIDATE_ROW_CNT_SQL.format(table))
        result = cursor.fetchall()
        row_cnt = int(result[0][0])
        if row_cnt < min_rows:
            logging.error('Row count validation FAILED for: ' + table +
                          '. Number of rows in the table = ' + str(row_cnt) +
                          ', Minimum rows expected = ' + str(min_rows))
            sys.exit(200)
        else:
            logging.info('Row count validation PASSED for: ' + table +
                         '. Number of rows in the table = ' + str(row_cnt))
def get_postgres_cursor(self):
    postgres = PostgresHook(postgres_conn_id=self.postgres_conn_id)
    logging.info("Successfully connected to Postgres database %s", self.postgres_conn_id)
    pg_conn = postgres.get_conn()
    pg_cursor = pg_conn.cursor()
    return pg_conn, pg_cursor
def mergeStore(**kwargs):
    """
    Call postgres function to merge from lf_store to store table
    """
    pgHook = PostgresHook(postgres_conn_id=POSTGRES_DB)
    pgConn = None
    try:
        pgConn = pgHook.get_conn()
        pgCursor = pgConn.cursor()
        pgCursor.callproc('store_merge')
        result = pgCursor.fetchall()
        pgConn.commit()
        recCount = result[0][0]
        logging.info(
            'MERGE COMPLETED SUCCESSFULLY - {} STORE RECORDS INSERTED/UPDATED'.format(recCount))
    except Exception as e:
        logging.error('Error trying to run store_merge step')
        raise e
    finally:
        if pgConn:
            pgCursor.close()
            pgConn.close()
def gen_error_reports(statfile, logfile, tablename, **kwargs):
    # database hook
    db_hook = PostgresHook(postgres_conn_id='postgres_default', schema='airflow')
    db_conn = db_hook.get_conn()
    db_cursor = db_conn.cursor()

    # ('extApp.log', '22995', '23 Jul 2020', '02:53:13,527', None, 'extApp', 'Unrecognized SSL message, plaintext connection?')
    sql = f"SELECT * FROM {tablename}"
    sql_output = "COPY ({0}) TO STDOUT WITH CSV HEADER".format(sql)
    with open(logfile, 'w') as f_output:
        db_cursor.copy_expert(sql_output, f_output)

    sql = f"SELECT error, count(*) as occurrence FROM {tablename} group by error ORDER BY occurrence DESC"
    # db_cursor.execute(sql, group)
    # get the generated id back
    # Use the COPY function on the SQL we created above.
    sql_output = "COPY ({0}) TO STDOUT WITH CSV HEADER".format(sql)
    # Set up a variable to store our file path and name.
    with open(statfile, 'w') as f_output:
        db_cursor.copy_expert(sql_output, f_output)

    db_cursor.execute(sql)
    print("The number of error type: ", db_cursor.rowcount)
    row = db_cursor.fetchone()
    print('first record: ', row)
    push_message(kwargs['ti'], row[1] if row else 0)

    db_cursor.close()
    db_conn.close()
    print(
        f"Two reports are generated from table: {tablename}\n1. {statfile}\n2. {logfile}"
    )
def createNYAndUSTable(self, table_name):
    '''
    This function defines the schema of the NY and US hospital data tables
    and creates a new table if it does not exist
    :param table_name: Us or Ny
    :return: None
    '''
    '''(test_date timestamp, new_positives real, cumulative_number_of_positives real,
    total_number_of_test_performed real, cumulative_number_of_test_performed real,
    load_date timestamp)'''
    pg_hook = PostgresHook(postgres_conn_id="postgres_default", schema="airflow")
    conn = pg_hook.get_conn()
    cursor = conn.cursor()
    query = '''CREATE TABLE IF NOT EXISTS ''' + table_name + ''' (
        id TIMESTAMP PRIMARY KEY,
        totalTestResults integer,
        totalTestResultsIncrease integer,
        positive integer,
        positiveIncrease integer,
        hospitalizedCurrently integer,
        inIcuCurrently integer,
        onVentilatorCurrently integer,
        death integer
    );'''
    cursor.execute(query)
    conn.commit()
    conn.close()
def fill_tables(schemaName="", execute_date="", table_type=""):
    request = """
        select tbl_name, tbl_fill_query, tbl_del_query
        from {0}.f_meta_tables, {0}.f_meta_type
        where tbl_type_id = type_id and type_name='{1}'
        order by tbl_id
    """.format(schemaName, table_type)
    pg_hook = PostgresHook()
    conn = pg_hook.get_conn()
    cursor = conn.cursor()
    cursor.execute(request)
    sources = cursor.fetchall()
    for tbl_name, tbl_fill_query, tbl_del_query in sources:
        try:
            if tbl_del_query is not None and tbl_fill_query is not None:
                if execute_date is not None:
                    cursor.execute(tbl_del_query.format(schemaName, execute_date))
                else:
                    cursor.execute(tbl_del_query.format(schemaName))
        except Exception as e:
            raise Exception('Error: %s Query: %s' % (e, tbl_del_query))
        try:
            if tbl_fill_query is not None:
                if execute_date is not None:
                    cursor.execute(tbl_fill_query.format(schemaName, execute_date))
                else:
                    cursor.execute(tbl_fill_query.format(schemaName))
            else:
                raise Exception("Query for fill %s is empty!" % tbl_name)
        except Exception as e:
            raise Exception('Error: %s Query: %s' % (e, tbl_fill_query))
    cursor.execute('commit')
def load_trip_data_to_redshift(*args, **kwargs):
    redshift_hook = PostgresHook("redshift")
    connection = redshift_hook.get_conn()
    cur = connection.cursor()
    logging.info(f"redshift_hook connection: {connection}")
    execution_date = kwargs["execution_date"]
    year = execution_date.year
    month = execution_date.month
    s3_location = f"s3://udacity-dend/data-pipelines/divvy/partitioned/{year}/{month}/divvy_trips.csv"
    t0 = time()
    query = ("""
        copy trips
        from '{}'
        iam_role '{}'
        region 'us-west-2'
        IGNOREHEADER 1
        DELIMITER ',';
    """).format(s3_location, 'arn:aws:iam::850186040772:role/dwhRole')
    logging.info(f"executing~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~: {query}")
    cur.execute(query)
    connection.commit()
    loadTime = time() - t0
    logging.info("===~~~~~~~~~~~~~~~~ DONE IN: {0:.2f} sec\n".format(loadTime))
    connection.close()
def tearDownClass(cls):
    postgres = PostgresHook()
    with postgres.get_conn() as conn:
        with conn.cursor() as cur:
            for table in TABLES:
                cur.execute(
                    "DROP TABLE IF EXISTS {} CASCADE;".format(table))
def create_tables():  # (cur, conn):
    """
    Runs all the create table queries defined in sql_queries.py
    :param cur: cursor to the database
    :param conn: database connection reference
    """
    postgres_hook = PostgresHook(postgres_conn_id="postgres_dwh", schema="test")
    conn = postgres_hook.get_conn()
    cur = conn.cursor()
    for query in create_table_queries:
        cur.execute(query)
        conn.commit()


# def main():
#     """
#     Driver main function.
#     """
#     cur, conn = create_database()
#     drop_tables(cur, conn)
#     print("Table dropped successfully!!")
#     create_tables(cur, conn)
#     print("Table created successfully!!")
#     conn.close()


# if __name__ == "__main__":
#     main()
def execute(self, context):
    logging.info('Executing: ' + str(self.sql))
    src_pg = PostgresHook(postgres_conn_id=self.src_postgres_conn_id)
    dest_pg = PostgresHook(postgres_conn_id=self.dest_postgress_conn_id)
    logging.info(
        "Transferring Postgres query results into other Postgres database."
    )
    conn = src_pg.get_conn()
    cursor = conn.cursor()
    cursor.execute(self.sql, self.parameters)
    if self.pg_preoperator:
        logging.info("Running Postgres preoperator")
        dest_pg.run(self.pg_preoperator)
    logging.info("Inserting rows into Postgres")
    dest_pg.insert_rows(table=self.pg_table, rows=cursor)
    if self.pg_postoperator:
        logging.info("Running Postgres postoperator")
        dest_pg.run(self.pg_postoperator)
    logging.info("Done.")
def execute(self, context):
    '''
    Data quality check for Redshift tables
    Parameters:
        1) redshift_conn_id: redshift cluster connection
        2) queries: list of sql queries to be executed
    '''
    self.log.info('DataQualityCheckOperator - start')
    redshift = PostgresHook(self.redshift_conn_id)
    conn = redshift.get_conn()
    cursor = conn.cursor()

    for query in self.count_queries:
        cursor.execute(query)
        results = cursor.fetchone()
        for row in results:
            self.log.info(f'{row} records')

    for query in self.show_queries:
        self.log.info(query)
        cursor.execute(query)
        results = cursor.fetchall()
        for row in results:
            self.log.info(row)

    self.log.info('DataQualityCheckOperator - complete')
def execute(self, context):
    hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
    pg = PostgresHook(postgres_conn_id=self.postgres_conn_id)

    logging.info("Dumping postgres query results to local file")
    conn = pg.get_conn()
    cursor = conn.cursor()
    cursor.execute(self.sql, self.parameters)
    with NamedTemporaryFile("wb") as f:
        # csv is presumably unicodecsv imported as csv: the stdlib writer has no encoding argument
        csv_writer = csv.writer(f, delimiter=self.delimiter, encoding="utf-8")
        field_dict = OrderedDict()
        for field in cursor.description:
            field_dict[field[0]] = self.type_map(field[1])
        csv_writer.writerows(cursor)
        f.flush()
        cursor.close()
        conn.close()
        logging.info("Loading file into Hive")
        hive.load_file(f.name,
                       self.hive_table,
                       field_dict=field_dict,
                       create=self.create,
                       partition=self.partition,
                       delimiter=self.delimiter,
                       recreate=self.recreate)
def execute(self, context):
    s3_hook: S3Hook = S3Hook(self.s3_conn)
    keys: List[str] = s3_hook.list_keys(prefix=self.bim_path, bucket_name=self.bim_bucket)
    keys = list(filter(lambda key: key.endswith('.xlsx'), keys))
    print(keys)
    result = []
    with Pool() as pool:
        records = pool.map(
            partial(process_xlsx, s3_hook=s3_hook, bim_bucket=self.bim_bucket),
            keys)
    print(records)
    for recs in records:
        result.extend(recs)
    print(result)

    pg_hook = PostgresHook('pg_default')
    conn = pg_hook.get_conn()
    c = conn.cursor()
    try:
        c.execute("BEGIN")
        params = Json(result)
        c.callproc("fset_bim", [params])
        results = c.fetchone()[0]
        c.execute("COMMIT")
    # except Exception as e:
    #     results = {"error": str(e)}
    finally:
        c.close()
    return results
def load_station_data_to_redshift(*args, **kwargs):
    redshift_hook = PostgresHook("redshift")
    connection = redshift_hook.get_conn()
    cur = connection.cursor()
    logging.info(f"redshift_hook connection: {connection}")
    t0 = time()
    query = ("""
        copy stations
        from '{}'
        iam_role '{}'
        region 'us-west-2'
        IGNOREHEADER 1
        DELIMITER ',';
    """).format(
        "s3://udacity-dend/data-pipelines/divvy/unpartitioned/divvy_stations_2017.csv",
        'arn:aws:iam::850186040772:role/dwhRole')
    logging.info(f"executing~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~: {query}")
    cur.execute(query)
    connection.commit()
    loadTime = time() - t0
    logging.info("===~~~~~~~~~~~~~~~~ DONE IN: {0:.2f} sec\n".format(loadTime))
    connection.close()
def save_user_to_database(tablename, records):
    if not records:
        print("Empty records!")
        logging.debug("Empty records")
        return
    logging.info(f'Total number of records to insert: {len(records)}')

    # database hook
    db_hook = PostgresHook(postgres_conn_id='postgres_default', schema='airflow')
    db_conn = db_hook.get_conn()
    db_cursor = db_conn.cursor()

    for p in records:
        print("Inserting " + p['UserName'])
        logging.info(f'Inserting: {p["UserName"]}')
        sql = """INSERT INTO {} (FullName, Email, UserName, JobTitle, ScannedDate, id,
                 ManagerId, ProfileImageAddress, State, UserType, ProcessDate)
                 VALUES (%s, %s, %s, %s, to_timestamp(%s / 1000), %s, %s, %s, %s, %s, %s)""".format(tablename)
        db_cursor.execute(sql, (p['FullName'], p['Email'], p['UserName'], p["JobTitle"],
                                p['ScannedDate'], p['id'], p["ManagerId"], p["ProfileImageAddress"],
                                p["State"], p["UserType"], p["ProcessDate"]))
        logging.debug(f'Inserted: {p["UserName"]} successfully')
        db_conn.commit()

    # select * from data_ingest_20210501 where username = '******'
    db_cursor.close()
    db_conn.close()
    print(f" -> {len(records)} records are saved to table: {tablename}.")
    logging.info(f" -> {len(records)} records are saved to table: {tablename}.")
def _query_postgres(self):
    """
    Queries Postgres and returns a cursor to the results.
    """
    postgres = PostgresHook(postgres_conn_id=self.postgres_conn_id)
    conn = postgres.get_conn()
    cursor = conn.cursor()
    cursor.execute(self.sql, self.parameters)
    return cursor
def test_bulk_dump(self):
    hook = PostgresHook()
    input_data = ["foo", "bar", "baz"]

    with hook.get_conn() as conn:
        with conn.cursor() as cur:
            cur.execute("CREATE TABLE {} (c VARCHAR)".format(self.table))
            values = ",".join("('{}')".format(data) for data in input_data)
            cur.execute("INSERT INTO {} VALUES {}".format(self.table, values))
            conn.commit()

            with NamedTemporaryFile() as f:
                hook.bulk_dump(self.table, f.name)
                f.seek(0)
                results = [line.rstrip().decode("utf-8") for line in f.readlines()]

    self.assertEqual(sorted(input_data), sorted(results))
def test_bulk_load(self):
    hook = PostgresHook()
    input_data = ["foo", "bar", "baz"]

    with hook.get_conn() as conn:
        with conn.cursor() as cur:
            cur.execute("CREATE TABLE {} (c VARCHAR)".format(self.table))
            conn.commit()

            with NamedTemporaryFile() as f:
                f.write("\n".join(input_data).encode("utf-8"))
                f.flush()
                hook.bulk_load(self.table, f.name)

            cur.execute("SELECT * FROM {}".format(self.table))
            results = [row[0] for row in cur.fetchall()]

    self.assertEqual(sorted(input_data), sorted(results))
def setUp(self):
    postgres = PostgresHook()
    with postgres.get_conn() as conn:
        with conn.cursor() as cur:
            for table in TABLES:
                cur.execute("DROP TABLE IF EXISTS {} CASCADE;".format(table))
                cur.execute("CREATE TABLE {}(some_str varchar, some_num integer);"
                            .format(table))

            cur.execute(
                "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
                ('mock_row_content_1', 42)
            )
            cur.execute(
                "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
                ('mock_row_content_2', 43)
            )
            cur.execute(
                "INSERT INTO postgres_to_gcs_operator VALUES(%s, %s);",
                ('mock_row_content_3', 44)
            )
class RedshiftToS3Transfer(BaseOperator):
    """
    Executes an UNLOAD command to s3 as a CSV with headers

    :param schema: reference to a specific schema in redshift database
    :type schema: string
    :param table: reference to a specific table in redshift database
    :type table: string
    :param s3_bucket: reference to a specific S3 bucket
    :type s3_bucket: string
    :param s3_key: reference to a specific S3 key
    :type s3_key: string
    :param redshift_conn_id: reference to a specific redshift database
    :type redshift_conn_id: string
    :param aws_conn_id: reference to a specific S3 connection
    :type aws_conn_id: string
    :param unload_options: reference to a list of UNLOAD options
    :type unload_options: list
    """

    template_fields = ()
    template_ext = ()
    ui_color = '#ededed'

    @apply_defaults
    def __init__(
            self,
            schema,
            table,
            s3_bucket,
            s3_key,
            redshift_conn_id='redshift_default',
            aws_conn_id='aws_default',
            unload_options=tuple(),
            autocommit=False,
            parameters=None,
            include_header=False,
            *args, **kwargs):
        super(RedshiftToS3Transfer, self).__init__(*args, **kwargs)
        self.schema = schema
        self.table = table
        self.s3_bucket = s3_bucket
        self.s3_key = s3_key
        self.redshift_conn_id = redshift_conn_id
        self.aws_conn_id = aws_conn_id
        self.unload_options = unload_options
        self.autocommit = autocommit
        self.parameters = parameters
        self.include_header = include_header

        if self.include_header and \
           'PARALLEL OFF' not in [uo.upper().strip() for uo in unload_options]:
            self.unload_options = list(unload_options) + ['PARALLEL OFF', ]

    def execute(self, context):
        self.hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id)
        credentials = self.s3.get_credentials()
        unload_options = '\n\t\t\t'.join(self.unload_options)

        if self.include_header:
            self.log.info("Retrieving headers from %s.%s...", self.schema, self.table)

            columns_query = """SELECT column_name
                                FROM information_schema.columns
                                WHERE table_schema = '{schema}'
                                AND   table_name = '{table}'
                                ORDER BY ordinal_position
                            """.format(schema=self.schema, table=self.table)

            cursor = self.hook.get_conn().cursor()
            cursor.execute(columns_query)
            rows = cursor.fetchall()
            columns = [row[0] for row in rows]
            column_names = ', '.join("{0}".format(c) for c in columns)
            column_headers = ', '.join("\\'{0}\\'".format(c) for c in columns)
            column_castings = ', '.join("CAST({0} AS text) AS {0}".format(c)
                                        for c in columns)

            select_query = """SELECT {column_names} FROM
                                (SELECT 2 sort_order, {column_castings}
                                 FROM {schema}.{table}
                                UNION ALL
                                SELECT 1 sort_order, {column_headers})
                              ORDER BY sort_order"""\
                .format(column_names=column_names,
                        column_castings=column_castings,
                        column_headers=column_headers,
                        schema=self.schema,
                        table=self.table)
        else:
            select_query = "SELECT * FROM {schema}.{table}"\
                .format(schema=self.schema, table=self.table)

        unload_query = """
                    UNLOAD ('{select_query}')
                    TO 's3://{s3_bucket}/{s3_key}/{table}_'
                    with credentials
                    'aws_access_key_id={access_key};aws_secret_access_key={secret_key}'
                    {unload_options};
                    """.format(select_query=select_query,
                               table=self.table,
                               s3_bucket=self.s3_bucket,
                               s3_key=self.s3_key,
                               access_key=credentials.access_key,
                               secret_key=credentials.secret_key,
                               unload_options=unload_options)

        self.log.info('Executing UNLOAD command...')
        self.hook.run(unload_query, self.autocommit)
        self.log.info("UNLOAD command complete...")
class RedshiftToS3Transfer(BaseOperator):
    """
    Executes an UNLOAD command to s3 as a CSV with headers

    :param schema: reference to a specific schema in redshift database
    :type schema: str
    :param table: reference to a specific table in redshift database
    :type table: str
    :param s3_bucket: reference to a specific S3 bucket
    :type s3_bucket: str
    :param s3_key: reference to a specific S3 key
    :type s3_key: str
    :param redshift_conn_id: reference to a specific redshift database
    :type redshift_conn_id: str
    :param aws_conn_id: reference to a specific S3 connection
    :type aws_conn_id: str
    :param verify: Whether or not to verify SSL certificates for S3 connection.
        By default SSL certificates are verified.
        You can provide the following values:

        - ``False``: do not validate SSL certificates. SSL will still be used
          (unless use_ssl is False), but SSL certificates will not be verified.
        - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
          You can specify this argument if you want to use a different
          CA cert bundle than the one used by botocore.
    :type verify: bool or str
    :param unload_options: reference to a list of UNLOAD options
    :type unload_options: list
    """

    template_fields = ()
    template_ext = ()
    ui_color = '#ededed'

    @apply_defaults
    def __init__(
            self,
            schema,
            table,
            s3_bucket,
            s3_key,
            redshift_conn_id='redshift_default',
            aws_conn_id='aws_default',
            verify=None,
            unload_options=tuple(),
            autocommit=False,
            include_header=False,
            *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.schema = schema
        self.table = table
        self.s3_bucket = s3_bucket
        self.s3_key = s3_key
        self.redshift_conn_id = redshift_conn_id
        self.aws_conn_id = aws_conn_id
        self.verify = verify
        self.unload_options = unload_options
        self.autocommit = autocommit
        self.include_header = include_header

        if self.include_header and \
           'PARALLEL OFF' not in [uo.upper().strip() for uo in unload_options]:
            self.unload_options = list(unload_options) + ['PARALLEL OFF', ]

    def execute(self, context):
        self.hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
        credentials = self.s3.get_credentials()
        unload_options = '\n\t\t\t'.join(self.unload_options)

        if self.include_header:
            self.log.info("Retrieving headers from %s.%s...", self.schema, self.table)

            columns_query = """SELECT column_name
                                FROM information_schema.columns
                                WHERE table_schema = '{schema}'
                                AND   table_name = '{table}'
                                ORDER BY ordinal_position
                            """.format(schema=self.schema, table=self.table)

            cursor = self.hook.get_conn().cursor()
            cursor.execute(columns_query)
            rows = cursor.fetchall()
            columns = [row[0] for row in rows]
            column_names = ', '.join("{0}".format(c) for c in columns)
            column_headers = ', '.join("\\'{0}\\'".format(c) for c in columns)
            column_castings = ', '.join("CAST({0} AS text) AS {0}".format(c)
                                        for c in columns)

            select_query = """SELECT {column_names} FROM
                                (SELECT 2 sort_order, {column_castings}
                                 FROM {schema}.{table}
                                UNION ALL
                                SELECT 1 sort_order, {column_headers})
                              ORDER BY sort_order"""\
                .format(column_names=column_names,
                        column_castings=column_castings,
                        column_headers=column_headers,
                        schema=self.schema,
                        table=self.table)
        else:
            select_query = "SELECT * FROM {schema}.{table}"\
                .format(schema=self.schema, table=self.table)

        unload_query = """
                    UNLOAD ('{select_query}')
                    TO 's3://{s3_bucket}/{s3_key}/{table}_'
                    with credentials
                    'aws_access_key_id={access_key};aws_secret_access_key={secret_key}'
                    {unload_options};
                    """.format(select_query=select_query,
                               table=self.table,
                               s3_bucket=self.s3_bucket,
                               s3_key=self.s3_key,
                               access_key=credentials.access_key,
                               secret_key=credentials.secret_key,
                               unload_options=unload_options)

        self.log.info('Executing UNLOAD command...')
        self.hook.run(unload_query, self.autocommit)
        self.log.info("UNLOAD command complete...")
def tearDown(self):
    postgres = PostgresHook()
    with postgres.get_conn() as conn:
        with conn.cursor() as cur:
            for table in TABLES:
                cur.execute("DROP TABLE IF EXISTS {} CASCADE;".format(table))
class RedshiftToS3Transfer(BaseOperator):
    """
    Executes an UNLOAD command to s3 as a CSV with headers

    :param schema: reference to a specific schema in redshift database
    :type schema: string
    :param table: reference to a specific table in redshift database
    :type table: string
    :param s3_bucket: reference to a specific S3 bucket
    :type s3_bucket: string
    :param s3_key: reference to a specific S3 key
    :type s3_key: string
    :param redshift_conn_id: reference to a specific redshift database
    :type redshift_conn_id: string
    :param s3_conn_id: reference to a specific S3 connection
    :type s3_conn_id: string
    :param options: reference to a list of UNLOAD options
    :type options: list
    """

    template_fields = ()
    template_ext = ()
    ui_color = '#ededed'

    @apply_defaults
    def __init__(
            self,
            schema,
            table,
            s3_bucket,
            s3_key,
            redshift_conn_id='redshift_default',
            s3_conn_id='s3_default',
            unload_options=tuple(),
            autocommit=False,
            parameters=None,
            *args, **kwargs):
        super(RedshiftToS3Transfer, self).__init__(*args, **kwargs)
        self.schema = schema
        self.table = table
        self.s3_bucket = s3_bucket
        self.s3_key = s3_key
        self.redshift_conn_id = redshift_conn_id
        self.s3_conn_id = s3_conn_id
        self.unload_options = unload_options
        self.autocommit = autocommit
        self.parameters = parameters

    def execute(self, context):
        self.hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
        self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
        a_key, s_key = self.s3.get_credentials()
        unload_options = ('\n\t\t\t').join(self.unload_options)

        _log.info("Retrieving headers from %s.%s..." % (self.schema, self.table))

        columns_query = """SELECT column_name
                            FROM information_schema.columns
                            WHERE table_schema = '{0}'
                            AND   table_name = '{1}'
                            ORDER BY ordinal_position
                        """.format(self.schema, self.table)

        cursor = self.hook.get_conn().cursor()
        cursor.execute(columns_query)
        rows = cursor.fetchall()
        # materialize the column list so it can be iterated twice below
        columns = [row[0] for row in rows]
        column_names = (', ').join(map(lambda c: "\\'{0}\\'".format(c), columns))
        column_castings = (', ').join(map(lambda c: "CAST({0} AS text) AS {0}".format(c),
                                          columns))

        unload_query = """
                        UNLOAD ('SELECT {0}
                        UNION ALL
                        SELECT {1} FROM {2}.{3}')
                        TO 's3://{4}/{5}/{3}_'
                        with
                        credentials 'aws_access_key_id={6};aws_secret_access_key={7}'
                        {8};
                        """.format(column_names, column_castings, self.schema, self.table,
                                   self.s3_bucket, self.s3_key, a_key, s_key, unload_options)

        _log.info('Executing UNLOAD command...')
        self.hook.run(unload_query, self.autocommit)
        _log.info("UNLOAD command complete...")