def execute(self, context):
    self.log.info('Executing: %s', self.sql)
    hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
    hook.run(
        self.sql,
        autocommit=self.autocommit,
        parameters=self.parameters)
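For context, this `execute` body is what runs when the operator fires; a minimal sketch of wiring such an operator into a DAG (the dag id, schedule, SQL, and `oracle_default` connection id are all illustrative assumptions, using the Airflow 1.10 import path seen elsewhere in these snippets):

# Minimal usage sketch -- dag id, schedule, conn id and SQL are illustrative.
from datetime import datetime

from airflow import DAG
from airflow.operators.oracle_operator import OracleOperator

dag = DAG('oracle_example',
          start_date=datetime(2020, 1, 1),
          schedule_interval=None)

run_sql = OracleOperator(
    task_id='run_sql',
    oracle_conn_id='oracle_default',   # assumed connection id
    sql="UPDATE t SET updated_at = SYSDATE",
    autocommit=True,
    dag=dag,
)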
def execute(self, context):
    logging.info('Executing: ' + str(self.sql))
    hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
    hook.run(
        self.sql,
        autocommit=self.autocommit,
        parameters=self.parameters)
def execute(self, context):
    self.sql = return_sql(self.sql_file)
    self.log.info('Executing: %s', self.sql)
    hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
    hook.run(self.sql,
             autocommit=self.autocommit,
             parameters=self.parameters)
def query(self):
    """Queries Oracle and returns a cursor to the results."""
    self.log.info('Executing: %s %s', self.sql, self.parameters)
    hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
    conn = hook.get_conn()
    cursor = conn.cursor()
    # Pass the logged parameters through to the driver as bind variables.
    cursor.execute(self.sql, self.parameters)
    return cursor
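A parameterized call against this pattern uses cx_Oracle named binds; a short usage sketch (the conn id, table, and bind name are assumptions):

# Illustrative bind-variable query; conn id, table and bind name are assumed.
from airflow.hooks.oracle_hook import OracleHook

hook = OracleHook(oracle_conn_id='oracle_default')
conn = hook.get_conn()
cursor = conn.cursor()
cursor.execute(
    "SELECT id, name FROM employees WHERE id > :min_id",  # :min_id is a named bind
    {"min_id": 100},
)
for row in cursor:
    print(row)
cursor.close()
conn.close()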
def setUp(self):
    super(TestOracleHookConn, self).setUp()
    self.connection = Connection(login='******', password='******',
                                 host='host', port=1521)
    self.db_hook = OracleHook()
    self.db_hook.get_connection = mock.Mock()
    self.db_hook.get_connection.return_value = self.connection
def execute(self, context):
    """context: the context object carries two sets of information:

    1. Task instance -- dag_id, hostname, job_id, key, task_id, etc.
    2. Task -- dag_id, upstream task ids, downstream task ids, etc.
       (Note: even if the design only declares downstream jobs, Airflow
       maps back to the parent task that invokes the current operator.)
    """
    # Depending on the source database, initiate the corresponding hook.
    self.hook = OracleHook(oracle_conn_id=self.srcdb)
    conn = self.hook.get_conn()
    dataframe = pandas.read_sql(self.sql_stmt, conn)
    tempfilename = "/home/ubuntu/test" + 'temp.csv'
    dataframe.to_csv(tempfilename, sep='|', header=False, index=False,
                     encoding='utf-8')
    conn.close()
    # After persisting the CSV to a temp location, push the file path to the
    # transform operator via XCom rather than passing it as a parameter.
    task_instance = context['task_instance']
    task_instance.xcom_push(key=context['dag_run'].run_id, value=tempfilename)
    log.info("Extraction complete")
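A downstream transform task would then pull the path back off XCom; a minimal sketch, assuming the same run-id key and an upstream task id of 'extract' (both assumptions):

# Hypothetical downstream callable; the upstream task id is assumed.
def transform(**context):
    task_instance = context['task_instance']
    tempfilename = task_instance.xcom_pull(
        task_ids='extract',                 # assumed upstream task id
        key=context['dag_run'].run_id)      # same key the extract task pushed
    with open(tempfilename) as f:
        for line in f:
            print(line.rstrip('\n'))        # stand-in for real transform logic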
def execute(self, context):
    oracle_hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
    azure_data_lake_hook = AzureDataLakeHook(
        azure_data_lake_conn_id=self.azure_data_lake_conn_id)

    self.log.info("Dumping Oracle query results to local file")
    conn = oracle_hook.get_conn()
    cursor = conn.cursor()
    cursor.execute(self.sql, self.sql_params)

    with TemporaryDirectory(prefix='airflow_oracle_to_azure_op_') as temp:
        self._write_temp_file(cursor, os.path.join(temp, self.filename))
        self.log.info("Uploading local file to Azure Data Lake")
        azure_data_lake_hook.upload_file(
            os.path.join(temp, self.filename),
            os.path.join(self.azure_data_lake_path, self.filename))
    cursor.close()
    conn.close()
from contextlib import closing

import cx_Oracle

from airflow.hooks.oracle_hook import OracleHook


def get_data(ds, **kwargs):
    with closing(OracleHook(
            oracle_conn_id='oracle_src').get_conn()) as src_conn:
        src_cursor = src_conn.cursor()
        # Call a stored function that returns a REF CURSOR and iterate it.
        refCursor1 = src_cursor.callfunc('SYSTEM.GET_DATA1', cx_Oracle.CURSOR)
        for row1 in refCursor1:
            print(row1)
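To run this callable on a schedule it would typically be wrapped in a PythonOperator; a minimal wiring sketch (the dag object and task id are assumptions):

# Hypothetical wiring -- a dag object is assumed to exist; Airflow 1.x API.
from airflow.operators.python_operator import PythonOperator

get_data_task = PythonOperator(
    task_id='get_data',
    python_callable=get_data,
    provide_context=True,   # passes ds/**kwargs into the callable (Airflow 1.x)
    dag=dag,
)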
def get_hook(self):
    try:
        if self.conn_type == 'mysql':
            from airflow.hooks.mysql_hook import MySqlHook
            return MySqlHook(mysql_conn_id=self.conn_id)
        elif self.conn_type == 'google_cloud_platform':
            from airflow.contrib.hooks.bigquery_hook import BigQueryHook
            return BigQueryHook(bigquery_conn_id=self.conn_id)
        elif self.conn_type == 'postgres':
            from airflow.hooks.postgres_hook import PostgresHook
            return PostgresHook(postgres_conn_id=self.conn_id)
        elif self.conn_type == 'hive_cli':
            from airflow.hooks.hive_hooks import HiveCliHook
            return HiveCliHook(hive_cli_conn_id=self.conn_id)
        elif self.conn_type == 'presto':
            from airflow.hooks.presto_hook import PrestoHook
            return PrestoHook(presto_conn_id=self.conn_id)
        elif self.conn_type == 'hiveserver2':
            from airflow.hooks.hive_hooks import HiveServer2Hook
            return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
        elif self.conn_type == 'sqlite':
            from airflow.hooks.sqlite_hook import SqliteHook
            return SqliteHook(sqlite_conn_id=self.conn_id)
        elif self.conn_type == 'jdbc':
            from airflow.hooks.jdbc_hook import JdbcHook
            return JdbcHook(jdbc_conn_id=self.conn_id)
        elif self.conn_type == 'mssql':
            from airflow.hooks.mssql_hook import MsSqlHook
            return MsSqlHook(mssql_conn_id=self.conn_id)
        elif self.conn_type == 'oracle':
            from airflow.hooks.oracle_hook import OracleHook
            return OracleHook(oracle_conn_id=self.conn_id)
        elif self.conn_type == 'vertica':
            from airflow.contrib.hooks.vertica_hook import VerticaHook
            return VerticaHook(vertica_conn_id=self.conn_id)
        elif self.conn_type == 'cloudant':
            from airflow.contrib.hooks.cloudant_hook import CloudantHook
            return CloudantHook(cloudant_conn_id=self.conn_id)
        elif self.conn_type == 'jira':
            from airflow.contrib.hooks.jira_hook import JiraHook
            return JiraHook(jira_conn_id=self.conn_id)
        elif self.conn_type == 'redis':
            from airflow.contrib.hooks.redis_hook import RedisHook
            return RedisHook(redis_conn_id=self.conn_id)
        elif self.conn_type == 'wasb':
            from airflow.contrib.hooks.wasb_hook import WasbHook
            return WasbHook(wasb_conn_id=self.conn_id)
        elif self.conn_type == 'docker':
            from airflow.hooks.docker_hook import DockerHook
            return DockerHook(docker_conn_id=self.conn_id)
    except Exception:
        # Returning None when no hook can be built is the legacy behaviour;
        # catching Exception (rather than a bare except) avoids swallowing
        # SystemExit and KeyboardInterrupt.
        pass
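The long if/elif ladder can also be expressed as a table of lazily imported hook factories; a sketch of that design alternative (not the library's actual implementation, and the mapping shown is deliberately truncated to three entries):

# Sketch: table-driven dispatch instead of an if/elif ladder (illustrative).
import importlib

_HOOK_FACTORIES = {
    'mysql': ('airflow.hooks.mysql_hook', 'MySqlHook', 'mysql_conn_id'),
    'postgres': ('airflow.hooks.postgres_hook', 'PostgresHook', 'postgres_conn_id'),
    'oracle': ('airflow.hooks.oracle_hook', 'OracleHook', 'oracle_conn_id'),
}

def get_hook(self):
    entry = _HOOK_FACTORIES.get(self.conn_type)
    if entry is None:
        return None
    module_name, class_name, conn_kwarg = entry
    hook_cls = getattr(importlib.import_module(module_name), class_name)
    return hook_cls(**{conn_kwarg: self.conn_id})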
def get_clients(oracle_conn_id, ds, **context):
    import pandas as pd

    from airflow.hooks.oracle_hook import OracleHook

    query = """SELECT CLIENT_ID, SRC_CLIENT_ID
               FROM DWH.WC_CLIENT_D
               GROUP BY SRC_CLIENT_ID, CLIENT_ID"""
    conn = OracleHook(oracle_conn_id=oracle_conn_id).get_conn()
    data = pd.read_sql(query, con=conn)
    data.columns = data.columns.str.lower()
    # Build a dict of src_client_id -> [client_id, ...],
    # e.g. {34: [283], 37: [1482, 284, 1602, 1543], 57: [285]}
    client_list = data.groupby('src_client_id')['client_id'].apply(
        lambda ids: ids.tolist()).to_dict()
    return client_list
# Import paths vary by Airflow version; the hook import below assumes the
# layout used by the other snippets here. SQLAlchemy imports are standard.
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError

from airflow.exceptions import AirflowException
from airflow.hooks.oracle_hook import OracleHook


def get_ora_engine(oracle_conn_id="oracle_default") -> Engine:
    connection = OracleHook().get_connection(oracle_conn_id)
    user = connection.login
    password = connection.password
    host = connection.host
    port = connection.port
    db = connection.schema
    try:
        uri = (
            f"oracle+cx_oracle://{user}:{password}@{host}:{port}/{db}"
            "?encoding=UTF-8&nencoding=UTF-8"
        )
        return create_engine(uri, auto_convert_lobs=True)
    except SQLAlchemyError as e:
        raise AirflowException(str(e)) from e
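Once the engine exists it can be handed straight to pandas; a short usage sketch (the query against DUAL is illustrative):

# Usage sketch -- read_sql accepts a SQLAlchemy engine directly.
import pandas as pd

engine = get_ora_engine("oracle_default")
df = pd.read_sql("SELECT * FROM dual", con=engine)
print(df)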
from contextlib import closing

from airflow.hooks.oracle_hook import OracleHook


def get_data(ds, **kwargs):
    with closing(OracleHook(oracle_conn_id='oracle_src').get_conn()) as src_conn:
        src_cursor = src_conn.cursor()
        src_cursor.execute(
            """
            DECLARE
              l_cursor1 SYS_REFCURSOR;
              l_cursor2 SYS_REFCURSOR;
            BEGIN
              OPEN l_cursor1 FOR
                SELECT 1 AS id, 'one' AS name FROM dual
                UNION ALL
                SELECT 2 AS id, 'two' AS name FROM dual;
              OPEN l_cursor2 FOR
                SELECT 3 AS id, 'three' AS name FROM dual
                UNION ALL
                SELECT 4 AS id, 'four' AS name FROM dual;
              DBMS_SQL.RETURN_RESULT(l_cursor1);
              DBMS_SQL.RETURN_RESULT(l_cursor2);
            END;
            """
        )
        # DBMS_SQL.RETURN_RESULT (Oracle 12c+) exposes the opened cursors as
        # implicit results that cx_Oracle surfaces via getimplicitresults().
        for implicit_cursor in src_cursor.getimplicitresults():
            for row in implicit_cursor:
                print(row)
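If named columns are more convenient than positional tuples, each implicit cursor's description attribute can drive a dict conversion; a small sketch of an alternative loop body, under the assumption that description is populated on the implicit result cursors:

# Sketch: materialize each implicit result set as a list of dicts.
for implicit_cursor in src_cursor.getimplicitresults():
    columns = [col[0] for col in implicit_cursor.description]
    rows = [dict(zip(columns, row)) for row in implicit_cursor]
    print(rows)  # e.g. [{'ID': 1, 'NAME': 'one'}, {'ID': 2, 'NAME': 'two'}]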
def execute(self, context):
    # Create the source and target connections.
    with closing(
            OracleHook(oracle_conn_id=self.oracle_conn_id).get_conn()
    ) as src_conn:
        with closing(
                PostgresHook(postgres_conn_id=self.postgres_conn_id).get_conn()
        ) as tgt_conn:
            if self.params['is_active']:
                # If src_query is non-empty, use it; otherwise build the query
                # from the fields, src_schema and src_table params.
                if self.params['src_query']:
                    select_query = (
                        "with data as ({src_query}) select {fields} from data"
                        .format(**self.params))
                else:
                    select_query = (
                        "select {fields} from {src_schema}.{src_table}"
                        .format(**self.params))
                # Append conditions when present.
                if self.params['use_conditions']:
                    select_query += " where {conditions}".format(**self.params)
                logging.info(f"Run query: \"{select_query}\"")
                # Open cursors for reading and writing.
                # NB: the positional name appears borrowed from psycopg2's
                # named server-side cursors; cx_Oracle's cursor() has no name
                # parameter, so the string is interpreted as the scrollable flag.
                src_cursor = src_conn.cursor("serverCursor")
                tgt_cursor = tgt_conn.cursor()
                # Run the query.
                src_cursor.execute(select_query)
                # Truncate the stage table.
                tgt_cursor.execute(
                    "truncate table {tgt_schema}.{target_table_prefix}{tgt_table}"
                    .format(**self.params))
                # Process the query result in batches.
                batch_count = 0
                while True:
                    logging.info(
                        f"Processing batch:\t{batch_count},\tsize: {self.batch_size}")
                    records = src_cursor.fetchmany(self.batch_size)
                    if records:
                        # Insert one batch of rows.
                        execute_values(
                            tgt_cursor,
                            "INSERT INTO {tgt_schema}.{target_table_prefix}{tgt_table}"
                            " ({fields}) VALUES %s".format(**self.params),
                            records)
                    else:
                        logging.info(
                            "Data transfer from table {src_schema}.{src_table}"
                            " complete".format(**self.params))
                        break
                    batch_count += 1
                tgt_conn.commit()
                src_cursor.close()
                tgt_cursor.close()
            else:
                logging.info(
                    "Table {src_schema}.{src_table} is not active in the"
                    " metadata, skipping...".format(**self.params))
def get_hook(self):
    if self.conn_type == 'mysql':
        from airflow.hooks.mysql_hook import MySqlHook
        return MySqlHook(mysql_conn_id=self.conn_id)
    elif self.conn_type == 'google_cloud_platform':
        from airflow.gcp.hooks.bigquery import BigQueryHook
        return BigQueryHook(bigquery_conn_id=self.conn_id)
    elif self.conn_type == 'postgres':
        from airflow.hooks.postgres_hook import PostgresHook
        return PostgresHook(postgres_conn_id=self.conn_id)
    elif self.conn_type == 'pig_cli':
        from airflow.hooks.pig_hook import PigCliHook
        return PigCliHook(pig_cli_conn_id=self.conn_id)
    elif self.conn_type == 'hive_cli':
        from airflow.hooks.hive_hooks import HiveCliHook
        return HiveCliHook(hive_cli_conn_id=self.conn_id)
    elif self.conn_type == 'presto':
        from airflow.hooks.presto_hook import PrestoHook
        return PrestoHook(presto_conn_id=self.conn_id)
    elif self.conn_type == 'hiveserver2':
        from airflow.hooks.hive_hooks import HiveServer2Hook
        return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
    elif self.conn_type == 'sqlite':
        from airflow.hooks.sqlite_hook import SqliteHook
        return SqliteHook(sqlite_conn_id=self.conn_id)
    elif self.conn_type == 'jdbc':
        from airflow.hooks.jdbc_hook import JdbcHook
        return JdbcHook(jdbc_conn_id=self.conn_id)
    elif self.conn_type == 'mssql':
        from airflow.hooks.mssql_hook import MsSqlHook
        return MsSqlHook(mssql_conn_id=self.conn_id)
    elif self.conn_type == 'oracle':
        from airflow.hooks.oracle_hook import OracleHook
        return OracleHook(oracle_conn_id=self.conn_id)
    elif self.conn_type == 'vertica':
        from airflow.contrib.hooks.vertica_hook import VerticaHook
        return VerticaHook(vertica_conn_id=self.conn_id)
    elif self.conn_type == 'cloudant':
        from airflow.contrib.hooks.cloudant_hook import CloudantHook
        return CloudantHook(cloudant_conn_id=self.conn_id)
    elif self.conn_type == 'jira':
        from airflow.contrib.hooks.jira_hook import JiraHook
        return JiraHook(jira_conn_id=self.conn_id)
    elif self.conn_type == 'redis':
        from airflow.contrib.hooks.redis_hook import RedisHook
        return RedisHook(redis_conn_id=self.conn_id)
    elif self.conn_type == 'wasb':
        from airflow.contrib.hooks.wasb_hook import WasbHook
        return WasbHook(wasb_conn_id=self.conn_id)
    elif self.conn_type == 'docker':
        from airflow.hooks.docker_hook import DockerHook
        return DockerHook(docker_conn_id=self.conn_id)
    elif self.conn_type == 'azure_data_lake':
        from airflow.contrib.hooks.azure_data_lake_hook import AzureDataLakeHook
        return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
    elif self.conn_type == 'azure_cosmos':
        from airflow.contrib.hooks.azure_cosmos_hook import AzureCosmosDBHook
        return AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
    elif self.conn_type == 'cassandra':
        from airflow.contrib.hooks.cassandra_hook import CassandraHook
        return CassandraHook(cassandra_conn_id=self.conn_id)
    elif self.conn_type == 'mongo':
        from airflow.contrib.hooks.mongo_hook import MongoHook
        return MongoHook(conn_id=self.conn_id)
    elif self.conn_type == 'gcpcloudsql':
        from airflow.gcp.hooks.cloud_sql import CloudSqlDatabaseHook
        return CloudSqlDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
    elif self.conn_type == 'grpc':
        from airflow.contrib.hooks.grpc_hook import GrpcHook
        return GrpcHook(grpc_conn_id=self.conn_id)
    raise AirflowException("Unknown hook type {}".format(self.conn_type))
def report(**kwargs):
    hook = OracleHook("New Era - WMP04")
    sql = """
        SELECT TO_CHAR(O.ACTUALSHIPDATE, 'YYYY-MM') AS SHIP_MONTH,
               TRUNC(O.ACTUALSHIPDATE) AS SHIP_DATE,
               CASE
                   WHEN O.TYPE IN ('HOTSEALCPN', 'MANUAL') THEN 'HEATSEAL BASE INV'
                   ELSE 'OUTBOUND'
               END ORDER_TYPE,
               OD.LOTTABLE01 PLANT_CODE,
               COUNT(DISTINCT O.ORDERKEY) AS ORDERS,
               COUNT(OD.ORDERKEY) AS LINES,
               SUM(OD.SHIPPEDQTY) AS SHIPPEDQTY
        FROM ORDERS_DW@WMP04 O
        JOIN ORDERDETAIL_DW@WMP04 OD ON O.ORDERKEY = OD.ORDERKEY
        WHERE O.STATUS = 95
          AND OD.SHIPPEDQTY > 0
          AND O.STORERKEY = '1168'
          AND O.ACTUALSHIPDATE < TRUNC(SYSDATE)
        GROUP BY TO_CHAR(O.ACTUALSHIPDATE, 'YYYY-MM'),
                 TRUNC(O.ACTUALSHIPDATE),
                 CASE
                     WHEN O.TYPE IN ('HOTSEALCPN', 'MANUAL') THEN 'HEATSEAL BASE INV'
                     ELSE 'OUTBOUND'
                 END,
                 OD.LOTTABLE01
        ORDER BY TRUNC(O.ACTUALSHIPDATE) DESC
    """
    shipped_df = hook.get_pandas_df(sql=sql)
    shipped_df['SHIP_DATE'] = shipped_df['SHIP_DATE'].dt.date
    shipped_df_outbound = shipped_df[shipped_df['ORDER_TYPE'] == 'OUTBOUND']
    shipped_df_heatseal = shipped_df[shipped_df['ORDER_TYPE'] == 'HEATSEAL BASE INV']

    # Pull in yesterday's detail for outbound orders.
    sql = """
        SELECT TO_CHAR(O.ACTUALSHIPDATE, 'YYYY-MM') AS SHIP_MONTH,
               TRUNC(O.ACTUALSHIPDATE) AS SHIP_DATE,
               O.TYPE ORDER_TYPE,
               OD.LOTTABLE01 PLANT_CODE,
               O.ORDERKEY XPO_ORDERKEY,
               O.EXTERNORDERKEY EXTERNORDERKEY,
               OD.SKU,
               OD.SHIPPEDQTY
        FROM ORDERS_DW@WMP04 O
        JOIN ORDERDETAIL_DW@WMP04 OD ON O.ORDERKEY = OD.ORDERKEY
        WHERE O.STATUS = 95
          AND OD.SHIPPEDQTY > 0
          AND O.STORERKEY = '1168'
          AND TRUNC(O.ACTUALSHIPDATE) = TRUNC(SYSDATE - 1)
          AND O.TYPE NOT IN ('HOTSEALCPN', 'MANUAL')
        ORDER BY O.ORDERKEY, OD.SKU
    """
    shippedYdayDetail = hook.get_pandas_df(sql=sql)
    shippedYdayDetail['SHIP_DATE'] = shippedYdayDetail['SHIP_DATE'].dt.date

    # For outbound orders: build the monthly rollup, append it to the daily
    # rows, and sort with blanks first so the MTD total leads each month; the
    # blank SHIP_DATE rows are then filled with an MTD label for display.
    shipped_df_outbound_month = shipped_df_outbound.groupby(
        ['SHIP_MONTH', 'PLANT_CODE']).agg({
            'SHIPPEDQTY': sum,
            'LINES': sum,
            'ORDERS': sum
        }).reset_index().sort_values('SHIP_MONTH',
                                     ascending=False).round(decimals=2)
    shippedReportOutbound = shipped_df_outbound.append(
        shipped_df_outbound_month).sort_values(['SHIP_MONTH', 'SHIP_DATE'],
                                               ascending=[False, False],
                                               na_position='first')
    # Fill the blank SHIP_DATE rows with the appropriate plant-code label.
    shippedReportOutbound['SHIP_DATE'] = shippedReportOutbound.apply(
        lambda x: 'MTD ' + x['PLANT_CODE'] + ' Total'
        if pd.isna(x['SHIP_DATE']) else x['SHIP_DATE'],
        axis=1)
    # Rearrange the column order.
    shippedReportOutbound = shippedReportOutbound[[
        'SHIP_MONTH', 'SHIP_DATE', 'PLANT_CODE', 'ORDERS', 'LINES', 'SHIPPEDQTY'
    ]]

    # Write the outbound data to the report.
    writer = pd.ExcelWriter(sourceFile + r'NEW ERA SHIPPED REPORT_' +
                            todayDate + '.xlsx',
                            engine='xlsxwriter')
    shippedReportOutbound.to_excel(writer,
                                   sheet_name='SHIPPED_OUTBOUND',
                                   startrow=0,
                                   index=False)

    # For HEATSEAL BASE INV orders.
    shipped_df_heatseal_month = shipped_df_heatseal.groupby(
        ['SHIP_MONTH']).agg({
            'SHIPPEDQTY': sum,
            'LINES': sum,
            'ORDERS': sum
        }).reset_index().sort_values('SHIP_MONTH',
                                     ascending=False).round(decimals=2)
    shippedReportheatseal = shipped_df_heatseal.append(
        shipped_df_heatseal_month).sort_values(['SHIP_MONTH', 'SHIP_DATE'],
                                               ascending=[False, False],
                                               na_position='first')
    shippedReportheatseal['SHIP_DATE'] = shippedReportheatseal[
        'SHIP_DATE'].fillna('MTD Total')
    shippedReportheatseal = shippedReportheatseal[[
        'SHIP_MONTH', 'SHIP_DATE', 'ORDERS', 'LINES', 'SHIPPEDQTY'
    ]]

    # Write the heatseal data to the report.
    shippedReportheatseal.to_excel(writer,
                                   sheet_name='SHIPPED_HEAT_SEAL',
                                   startrow=0,
                                   index=False)

    # Add yesterday's detail data.
    shippedYdayDetail.to_excel(writer,
                               sheet_name='YDAY SHIPPED_OUTBOUND DETAIL',
                               startrow=0,
                               index=False)

    # Center all columns.
    formatCenter = writer.book.add_format()
    formatCenter.set_center_across()
    formatPercent = writer.book.add_format({
        'num_format': '0.0%',
        'align': 'center'
    })
    formatNum = writer.book.add_format({
        'num_format': '#,##0',
        'align': 'center'
    })

    # Center-align a column range.
    def set_center(sheetname, cols):
        writer.sheets[sheetname].set_column(cols, None, formatCenter)

    # Apply the numeric format to a column range.
    def format_num(sheetname, cols):
        writer.sheets[sheetname].set_column(cols, None, formatNum)

    # Adjust column widths from the dataframe contents and the sheet name.
    def adjust_cols_width(df, sheet_name):
        width_list = [
            max([len(str(s)) + 3 for s in df[col].values] + [len(col) + 3])
            for col in df.columns
        ]
        for i in range(0, len(width_list)):
            writer.sheets[sheet_name].set_column(i, i, width_list[i])

    set_center('SHIPPED_OUTBOUND', 'A:F')
    format_num('SHIPPED_OUTBOUND', 'D:F')
    adjust_cols_width(shippedReportOutbound, 'SHIPPED_OUTBOUND')
    set_center('SHIPPED_HEAT_SEAL', 'A:F')
    format_num('SHIPPED_HEAT_SEAL', 'D:F')
    adjust_cols_width(shippedReportheatseal, 'SHIPPED_HEAT_SEAL')
    set_center('YDAY SHIPPED_OUTBOUND DETAIL', 'A:H')
    format_num('YDAY SHIPPED_OUTBOUND DETAIL', 'H:H')
    adjust_cols_width(shippedYdayDetail, 'YDAY SHIPPED_OUTBOUND DETAIL')
    writer.save()
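As a design note, wrapping the writer in a with block makes the save implicit and exception-safe; a minimal sketch of the same flow (the output filename is illustrative):

# Sketch: context-managed ExcelWriter; 'report.xlsx' is illustrative.
import pandas as pd

with pd.ExcelWriter('report.xlsx', engine='xlsxwriter') as writer:
    shippedReportOutbound.to_excel(writer, sheet_name='SHIPPED_OUTBOUND', index=False)
    shippedReportheatseal.to_excel(writer, sheet_name='SHIPPED_HEAT_SEAL', index=False)
    # the workbook is saved automatically when the block exits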
def execute(self, context):
    self.logger.info('Executing: %s', self.sql)
    hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
    hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
class TestOracleHookConn(unittest.TestCase):
    def setUp(self):
        super(TestOracleHookConn, self).setUp()
        self.connection = Connection(
            login='******',
            password='******',
            host='host',
            port=1521
        )
        self.db_hook = OracleHook()
        self.db_hook.get_connection = mock.Mock()
        self.db_hook.get_connection.return_value = self.connection

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_host(self, mock_connect):
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['user'], 'login')
        self.assertEqual(kwargs['password'], 'password')
        self.assertEqual(kwargs['dsn'], 'host')

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_sid(self, mock_connect):
        dsn_sid = {'dsn': 'dsn', 'sid': 'sid'}
        self.connection.extra = json.dumps(dsn_sid)
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['dsn'],
                         cx_Oracle.makedsn(dsn_sid['dsn'],
                                           self.connection.port,
                                           dsn_sid['sid']))

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_service_name(self, mock_connect):
        dsn_service_name = {'dsn': 'dsn', 'service_name': 'service_name'}
        self.connection.extra = json.dumps(dsn_service_name)
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['dsn'], cx_Oracle.makedsn(
            dsn_service_name['dsn'], self.connection.port,
            service_name=dsn_service_name['service_name']))

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_encoding_without_nencoding(self, mock_connect):
        self.connection.extra = json.dumps({'encoding': 'UTF-8'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['encoding'], 'UTF-8')
        self.assertEqual(kwargs['nencoding'], 'UTF-8')

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_encoding_with_nencoding(self, mock_connect):
        self.connection.extra = json.dumps({'encoding': 'UTF-8',
                                            'nencoding': 'gb2312'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['encoding'], 'UTF-8')
        self.assertEqual(kwargs['nencoding'], 'gb2312')

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_nencoding(self, mock_connect):
        self.connection.extra = json.dumps({'nencoding': 'UTF-8'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertNotIn('encoding', kwargs)
        self.assertEqual(kwargs['nencoding'], 'UTF-8')

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_mode(self, mock_connect):
        mode = {
            'sysdba': cx_Oracle.SYSDBA,
            'sysasm': cx_Oracle.SYSASM,
            'sysoper': cx_Oracle.SYSOPER,
            'sysbkp': cx_Oracle.SYSBKP,
            'sysdgd': cx_Oracle.SYSDGD,
            'syskmt': cx_Oracle.SYSKMT,
        }
        first = True
        for m in mode:
            self.connection.extra = json.dumps({'mode': m})
            self.db_hook.get_conn()
            if first:
                assert mock_connect.call_count == 1
                first = False
            args, kwargs = mock_connect.call_args
            self.assertEqual(args, ())
            self.assertEqual(kwargs['mode'], mode.get(m))

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_threaded(self, mock_connect):
        self.connection.extra = json.dumps({'threaded': True})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['threaded'], True)

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_events(self, mock_connect):
        self.connection.extra = json.dumps({'events': True})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['events'], True)

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_purity(self, mock_connect):
        purity = {
            'new': cx_Oracle.ATTR_PURITY_NEW,
            'self': cx_Oracle.ATTR_PURITY_SELF,
            'default': cx_Oracle.ATTR_PURITY_DEFAULT
        }
        first = True
        for p in purity:
            self.connection.extra = json.dumps({'purity': p})
            self.db_hook.get_conn()
            if first:
                assert mock_connect.call_count == 1
                first = False
            args, kwargs = mock_connect.call_args
            self.assertEqual(args, ())
            self.assertEqual(kwargs['purity'], purity.get(p))
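The for-loops in test_get_conn_mode and test_get_conn_purity re-enter get_conn once per extra; with pytest the same coverage reads more naturally as a parametrized test. A sketch, not the project's actual test code, with import paths assuming the Airflow 1.10 layout used above:

# Sketch: parametrized variant of test_get_conn_mode (pytest assumed).
import json
from unittest import mock

import cx_Oracle
import pytest

from airflow.hooks.oracle_hook import OracleHook
from airflow.models import Connection


@pytest.mark.parametrize("name, expected", [
    ("sysdba", cx_Oracle.SYSDBA),
    ("sysasm", cx_Oracle.SYSASM),
    ("sysoper", cx_Oracle.SYSOPER),
])
@mock.patch("airflow.hooks.oracle_hook.cx_Oracle.connect")
def test_get_conn_mode(mock_connect, name, expected):
    conn = Connection(login="login", password="password",
                      host="host", port=1521)
    conn.extra = json.dumps({"mode": name})
    hook = OracleHook()
    hook.get_connection = mock.Mock(return_value=conn)
    hook.get_conn()
    _, kwargs = mock_connect.call_args
    assert kwargs["mode"] == expected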
def get_clv(oracle_conn_id, src_client_id, storage_bucket, ds, **context):
    import matplotlib.pyplot
    matplotlib.pyplot.ioff()

    from lifetimes.utils import calibration_and_holdout_data
    from lifetimes.plotting import plot_frequency_recency_matrix
    from lifetimes.plotting import plot_probability_alive_matrix
    from lifetimes.plotting import plot_calibration_purchases_vs_holdout_purchases
    from lifetimes.plotting import plot_period_transactions
    from lifetimes.plotting import plot_history_alive
    from lifetimes.plotting import plot_cumulative_transactions
    from lifetimes.utils import expected_cumulative_transactions
    from lifetimes.utils import summary_data_from_transaction_data
    from lifetimes import BetaGeoFitter
    from lifetimes import GammaGammaFitter
    import datetime
    import pandas as pd
    import datalab.storage as gcs

    conn = OracleHook(oracle_conn_id=oracle_conn_id).get_conn()
    print(src_client_id, context)
    query = context['templates_dict']['query']
    data = pd.read_sql(query, con=conn)
    data.columns = data.columns.str.lower()
    print(data.head())

    # Calculate RFM values.
    calibration_end_date = datetime.datetime(2018, 5, 24)
    training_rfm = calibration_and_holdout_data(
        transactions=data,
        customer_id_col='src_user_id',
        datetime_col='pickup_date',
        calibration_period_end=calibration_end_date,
        freq='D',
        monetary_value_col='price_total')

    bgf = BetaGeoFitter(penalizer_coef=0.0)
    bgf.fit(training_rfm['frequency_cal'],
            training_rfm['recency_cal'],
            training_rfm['T_cal'])
    print(bgf)

    # Matrix charts.
    plot_period_transactions_chart = context.get("ds_nodash") + str(
        src_client_id) + '_plot_period_transactions_chart.svg'
    plot_frequency_recency_chart = context.get("ds_nodash") + str(
        src_client_id) + '_plot_frequency_recency_matrix.svg'
    plot_probability_chart = context.get("ds_nodash") + str(
        src_client_id) + '_plot_probability_alive_matrix.svg'
    plot_calibration_vs_holdout_chart = context.get("ds_nodash") + str(
        src_client_id) + '_plot_calibration_vs_holdout_purchases.svg'

    ax0 = plot_period_transactions(bgf, max_frequency=30)
    ax0.figure.savefig(plot_period_transactions_chart, format='svg')
    ax1 = plot_frequency_recency_matrix(bgf)
    ax1.figure.savefig(plot_frequency_recency_chart, format='svg')
    ax2 = plot_probability_alive_matrix(bgf)
    ax2.figure.savefig(plot_probability_chart, format='svg')
    ax3 = plot_calibration_purchases_vs_holdout_purchases(bgf, training_rfm, n=50)
    ax3.figure.savefig(plot_calibration_vs_holdout_chart, format='svg')

    full_rfm = summary_data_from_transaction_data(
        data,
        customer_id_col='src_user_id',
        datetime_col='pickup_date',
        monetary_value_col='price_total',
        datetime_format=None,
        observation_period_end=None,
        freq='D')
    returning_full_rfm = full_rfm[full_rfm['frequency'] > 0]
    ggf = GammaGammaFitter(penalizer_coef=0)
    ggf.fit(returning_full_rfm['frequency'],
            returning_full_rfm['monetary_value'])

    customer_lifetime = 30  # expected customer lifetime, in months
    clv = ggf.customer_lifetime_value(
        bgf,  # the model used to predict the number of future transactions
        full_rfm['frequency'],
        full_rfm['recency'],
        full_rfm['T'],
        full_rfm['monetary_value'],
        time=customer_lifetime,  # months
        discount_rate=0.01  # monthly discount rate ~ 12.7% annually
    ).sort_values(ascending=False)
    full_rfm_with_value = full_rfm.join(clv)

    full_rfm_file = context.get("ds_nodash") + "-src_client_id-" + str(
        src_client_id) + '-icabbi-test.csv'
    full_rfm_with_value.to_csv(full_rfm_file)

    # Upload the CSV and each chart to GCS under its own object name,
    # matching each object to the local file of the same name.
    uploads = [full_rfm_file,
               plot_period_transactions_chart,
               plot_frequency_recency_chart,
               plot_probability_chart,
               plot_calibration_vs_holdout_chart]
    for local_file in uploads:
        GoogleCloudStorageHook(
            google_cloud_storage_conn_id='google_conn_default').upload(
                bucket=storage_bucket,
                object=str(src_client_id) + "/" + context.get("ds_nodash") +
                "/" + local_file,
                filename=local_file)
def execute(self, context):
    src_hook = OracleHook(oracle_conn_id=self.oracle_source_conn_id)
    dest_hook = OracleHook(oracle_conn_id=self.oracle_destination_conn_id)
    self._execute(src_hook, dest_hook, context)
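The _execute helper is not shown; in the stock Oracle-to-Oracle transfer operator it streams rows in chunks from the source cursor into the destination via the hook's bulk_insert_rows. A sketch along those lines (the self.source_sql, self.source_sql_params, self.rows_chunk, and self.destination_table attributes are assumptions):

# Sketch of a chunked copy; the self.* attributes are assumed to be set
# on the operator by its constructor.
from contextlib import closing

def _execute(self, src_hook, dest_hook, context):
    with closing(src_hook.get_conn()) as src_conn:
        cursor = src_conn.cursor()
        cursor.execute(self.source_sql, self.source_sql_params)
        target_fields = [field[0] for field in cursor.description]
        rows = cursor.fetchmany(self.rows_chunk)
        while rows:
            dest_hook.bulk_insert_rows(self.destination_table, rows,
                                       target_fields=target_fields,
                                       commit_every=self.rows_chunk)
            rows = cursor.fetchmany(self.rows_chunk)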
def execute(self, context): logging.info("Executing: " + str(self.sql)) hook = OracleHook(oracle_conn_id=self.oracle_conn_id) hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)