Code Example #1
 def execute(self, context):
     self.log.info('Executing: %s', self.sql)
     hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
     hook.run(
         self.sql,
         autocommit=self.autocommit,
         parameters=self.parameters)
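
For context, a minimal sketch of wiring such an operator into a DAG (Airflow 1.x-era import path; the connection id, SQL, and table name are hypothetical):

from datetime import datetime

from airflow import DAG
from airflow.operators.oracle_operator import OracleOperator

with DAG(dag_id='oracle_example',
         start_date=datetime(2021, 1, 1),
         schedule_interval=None) as dag:
    run_sql = OracleOperator(
        task_id='run_sql',
        oracle_conn_id='oracle_default',  # hypothetical connection id
        sql='DELETE FROM staging_table WHERE load_date < :cutoff',
        parameters={'cutoff': datetime(2021, 1, 1)},
    )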
Code Example #2
 def execute(self, context):
     logging.info('Executing: %s', self.sql)
     hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
     hook.run(
         self.sql,
         autocommit=self.autocommit,
         parameters=self.parameters)
Code Example #3
 def execute(self, context):
     try:
         self.sql = return_sql(self.sql_file)
         self.log.info('Executing: %s', self.sql)
         hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
         hook.run(self.sql,
                  autocommit=self.autocommit,
                  parameters=self.parameters)
     except Exception:
         raise  # re-raise with the original traceback intact
Code Example #4
 def query(self):
     """
     Queries Oracle and returns a cursor to the results.
     """
     self.log.info('Executing: %s %s', self.sql, self.parameters)
     hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
     conn = hook.get_conn()
     cursor = conn.cursor()
     cursor.execute(self.sql, self.parameters)  # bind the parameters logged above
     return cursor
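
The cursor comes back open for the caller to consume; a short sketch of draining and releasing it (op and process are hypothetical names):

cursor = op.query()
try:
    for row in cursor:  # cx_Oracle cursors are iterable
        process(row)    # hypothetical per-row handler
finally:
    conn = cursor.connection  # cx_Oracle exposes the owning connection
    cursor.close()
    conn.close()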
Code Example #5
    def setUp(self):
        super(TestOracleHookConn, self).setUp()

        self.connection = Connection(login='******',
                                     password='******',
                                     host='host',
                                     port=1521)

        self.db_hook = OracleHook()
        self.db_hook.get_connection = mock.Mock()
        self.db_hook.get_connection.return_value = self.connection
Code Example #6
File: ExtractDataset.py Project: pavannpa/airflow
    def execute(self, context):
        """context : context objects has two sets of information 
						1. Task Instance 
								This contains dag_id,hostname,jobid,key,taskid etc
						2. Task
								This contains dag_id, upstream taskid and downstram taskid etc
								(Note: Even if the design contains to only downstream jobs airflow maps to the parent 
								task which would invoke current operator)
								"""
        # Depending on the source database, initiate the corresponding hook
        self.hook = OracleHook(oracle_conn_id=self.srcdb)
        conn = self.hook.get_conn()
        dataframe = pandas.read_sql(self.sql_stmt, conn)
        tempfilename = "/home/ubuntu/test" + 'temp.csv'
        dataframe.to_csv(tempfilename,
                         sep='|',
                         header=False,
                         index=False,
                         encoding='utf-8')
        conn.close()

        # After persisting the csv file to a temp location, push the file location
        # to the transform operator via XCom rather than passing it as a parameter
        task_instance = context['task_instance']
        task_instance.xcom_push(key=context['dag_run'].run_id,
                                value=tempfilename)
        log.info("Extracttion complete ")
Code Example #7
    def execute(self, context):
        oracle_hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
        azure_data_lake_hook = AzureDataLakeHook(
            azure_data_lake_conn_id=self.azure_data_lake_conn_id)

        self.log.info("Dumping Oracle query results to local file")
        conn = oracle_hook.get_conn()
        cursor = conn.cursor()
        cursor.execute(self.sql, self.sql_params)

        with TemporaryDirectory(prefix='airflow_oracle_to_azure_op_') as temp:
            self._write_temp_file(cursor, os.path.join(temp, self.filename))
            self.log.info("Uploading local file to Azure Data Lake")
            azure_data_lake_hook.upload_file(
                os.path.join(temp, self.filename),
                os.path.join(self.azure_data_lake_path, self.filename))
        cursor.close()
        conn.close()
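
The _write_temp_file helper is not shown in this snippet; a plausible sketch (an assumption, not necessarily the operator's actual implementation) that writes a header row taken from the cursor metadata, then streams the data rows:

import csv

def _write_temp_file(self, cursor, path_to_save):
    with open(path_to_save, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow([column[0] for column in cursor.description])  # header
        writer.writerows(cursor)  # stream the remaining rows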
Code Example #8
File: test_sp.py Project: sagaranin/dags
    def get_data(ds, **kwargs):
        with closing(OracleHook(
                oracle_conn_id='oracle_src').get_conn()) as src_conn:
            src_cursor = src_conn.cursor()
            refCursor1 = src_cursor.callfunc('SYSTEM.GET_DATA1',
                                             cx_Oracle.CURSOR)

            for row1 in refCursor1:
                print(row1)
Code Example #10
 def get_hook(self):
     try:
         if self.conn_type == 'mysql':
             from airflow.hooks.mysql_hook import MySqlHook
             return MySqlHook(mysql_conn_id=self.conn_id)
         elif self.conn_type == 'google_cloud_platform':
             from airflow.contrib.hooks.bigquery_hook import BigQueryHook
             return BigQueryHook(bigquery_conn_id=self.conn_id)
         elif self.conn_type == 'postgres':
             from airflow.hooks.postgres_hook import PostgresHook
             return PostgresHook(postgres_conn_id=self.conn_id)
         elif self.conn_type == 'hive_cli':
             from airflow.hooks.hive_hooks import HiveCliHook
             return HiveCliHook(hive_cli_conn_id=self.conn_id)
         elif self.conn_type == 'presto':
             from airflow.hooks.presto_hook import PrestoHook
             return PrestoHook(presto_conn_id=self.conn_id)
         elif self.conn_type == 'hiveserver2':
             from airflow.hooks.hive_hooks import HiveServer2Hook
             return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
         elif self.conn_type == 'sqlite':
             from airflow.hooks.sqlite_hook import SqliteHook
             return SqliteHook(sqlite_conn_id=self.conn_id)
         elif self.conn_type == 'jdbc':
             from airflow.hooks.jdbc_hook import JdbcHook
             return JdbcHook(jdbc_conn_id=self.conn_id)
         elif self.conn_type == 'mssql':
             from airflow.hooks.mssql_hook import MsSqlHook
             return MsSqlHook(mssql_conn_id=self.conn_id)
         elif self.conn_type == 'oracle':
             from airflow.hooks.oracle_hook import OracleHook
             return OracleHook(oracle_conn_id=self.conn_id)
         elif self.conn_type == 'vertica':
             from airflow.contrib.hooks.vertica_hook import VerticaHook
             return VerticaHook(vertica_conn_id=self.conn_id)
         elif self.conn_type == 'cloudant':
             from airflow.contrib.hooks.cloudant_hook import CloudantHook
             return CloudantHook(cloudant_conn_id=self.conn_id)
         elif self.conn_type == 'jira':
             from airflow.contrib.hooks.jira_hook import JiraHook
             return JiraHook(jira_conn_id=self.conn_id)
         elif self.conn_type == 'redis':
             from airflow.contrib.hooks.redis_hook import RedisHook
             return RedisHook(redis_conn_id=self.conn_id)
         elif self.conn_type == 'wasb':
             from airflow.contrib.hooks.wasb_hook import WasbHook
             return WasbHook(wasb_conn_id=self.conn_id)
         elif self.conn_type == 'docker':
             from airflow.hooks.docker_hook import DockerHook
             return DockerHook(docker_conn_id=self.conn_id)
     except Exception:
         # missing optional dependency or unknown conn_type: return None
         pass
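
The long elif chain (and the broad except, which silently swallows missing dependencies) can be collapsed into a lookup table; a sketch of that refactor, with most entries elided:

import importlib

_HOOK_REGISTRY = {
    'mysql': ('airflow.hooks.mysql_hook', 'MySqlHook', 'mysql_conn_id'),
    'oracle': ('airflow.hooks.oracle_hook', 'OracleHook', 'oracle_conn_id'),
    # ...remaining conn types elided
}

def get_hook(self):
    entry = _HOOK_REGISTRY.get(self.conn_type)
    if entry is None:
        return None  # mirrors the original's silent fall-through
    module_name, class_name, conn_kwarg = entry
    hook_cls = getattr(importlib.import_module(module_name), class_name)
    return hook_cls(**{conn_kwarg: self.conn_id})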
Code Example #11
def get_clients(oracle_conn_id, ds, **context):
    import pandas as pd
    query = """SELECT CLIENT_ID, SRC_CLIENT_ID
                FROM DWH.WC_CLIENT_D
                GROUP BY SRC_CLIENT_ID, CLIENT_ID"""
    conn = OracleHook(oracle_conn_id=oracle_conn_id).get_conn()
    data = pd.read_sql(query, con=conn)
    data.columns = data.columns.str.lower()
    # map each src_client_id to the list of its client_ids,
    # e.g. {34: [283], 37: [1482, 284, 1602, 1543], 57: [285]}
    client_list = data.groupby('src_client_id')['client_id'].apply(list).to_dict()
    return client_list
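
A sketch of exposing this callable as a task, so the returned dict lands in XCom (Airflow 1.x PythonOperator; the connection id and dag object are hypothetical):

from airflow.operators.python_operator import PythonOperator

get_clients_task = PythonOperator(
    task_id='get_clients',
    python_callable=get_clients,
    op_kwargs={'oracle_conn_id': 'oracle_dwh'},  # hypothetical connection id
    provide_context=True,  # Airflow 1.x: supplies ds and the rest of the context
    dag=dag,
)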
Code Example #12
    def setUp(self):
        super(TestOracleHookConn, self).setUp()

        self.connection = Connection(
            login='******',
            password='******',
            host='host',
            port=1521
        )

        self.db_hook = OracleHook()
        self.db_hook.get_connection = mock.Mock()
        self.db_hook.get_connection.return_value = self.connection
Code Example #13
def get_ora_engine(oracle_conn_id="oracle_default") -> Engine:
    connection = OracleHook().get_connection(oracle_conn_id)
    user = connection.login
    password = connection.password
    host = connection.host
    port = connection.port
    db = connection.schema

    try:
        uri = f"oracle+cx_oracle://{user}:{password}@{host}:{port}/{db}?encoding=UTF-8&nencoding=UTF-8"  # noqa: E501
        return create_engine(uri, auto_convert_lobs=True)
    except SQLAlchemyError as e:
        raise AirflowException(str(e)) from e
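
A typical use of the returned Engine is bulk-loading a DataFrame; a minimal sketch (the table name and connection id are hypothetical):

import pandas as pd

engine = get_ora_engine('oracle_default')
df = pd.DataFrame({'id': [1, 2], 'name': ['one', 'two']})
# pandas issues the INSERTs through the SQLAlchemy engine
df.to_sql('demo_table', con=engine, if_exists='replace', index=False)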
Code Example #14
File: test_sp_ab.py Project: sagaranin/dags
    def get_data(ds, **kwargs):
        with closing(OracleHook(oracle_conn_id='oracle_src').get_conn()) as src_conn:
            src_cursor = src_conn.cursor()
            src_cursor.execute(
                """
                DECLARE  
                    l_cursor1 SYS_REFCURSOR;
                    l_cursor2 SYS_REFCURSOR;
                BEGIN
                    OPEN l_cursor1 FOR
                        SELECT 1 as id, 'one' as name from dual union all
                        SELECT 2 as id, 'two' as name from dual;
                    OPEN l_cursor2 FOR
                        SELECT 3 as id, 'three' as name from dual union all
                        SELECT 4 as id, 'four' as name from dual;
                    DBMS_SQL.RETURN_RESULT(l_cursor1);
                    DBMS_SQL.RETURN_RESULT(l_cursor2);
                END;
                """
            )

            for implicitCursor in src_cursor.getimplicitresults():
                for row in implicitCursor:
                    print(row)
Code Example #15
    def execute(self, context):
        # create the source and target connections
        with closing(
                OracleHook(oracle_conn_id=self.oracle_conn_id).get_conn()
        ) as src_conn:
            with closing(
                    PostgresHook(postgres_conn_id=self.postgres_conn_id).
                    get_conn()) as tgt_conn:

                if self.params['is_active']:

                    # if src_query is non-empty, use it; otherwise build the query from the fields, src_schema, and src_table params
                    if self.params['src_query']:
                        select_query = "with data as ({src_query}) select {fields} from data".format(
                            **self.params)
                    else:
                        select_query = "select {fields} from {src_schema}.{src_table}".format(
                            **self.params)

                    # append conditions when present
                    if self.params['use_conditions']:
                        select_query += " where {conditions}".format(
                            **self.params)

                    logging.info(f"Run query: \"{select_query}\"")

                    # open cursors for reading and writing
                    src_cursor = src_conn.cursor("serverCursor")
                    tgt_cursor = tgt_conn.cursor()

                    # run the query
                    src_cursor.execute(select_query)

                    # truncate the stage table
                    tgt_cursor.execute(
                        "truncate table {tgt_schema}.{target_table_prefix}{tgt_table}"
                        .format(**self.params))

                    # process the query result
                    batch_count = 0
                    while True:
                        logging.info(
                            f"Processing batch:\t{batch_count},\tsize: {self.batch_size}"
                        )
                        records = src_cursor.fetchmany(self.batch_size)

                        if records:
                            execute_values(  # insert a batch of rows
                                tgt_cursor,
                                "INSERT INTO {tgt_schema}.{target_table_prefix}{tgt_table} ({fields}) VALUES %s"
                                .format(**self.params), records)
                        else:
                            logging.info(
                                "Data transfer from table {src_schema}.{src_table} is complete"
                                .format(**self.params))
                            break

                        batch_count += 1

                    tgt_conn.commit()

                    src_cursor.close()
                    tgt_cursor.close()

                else:
                    logging.info(
                        "Table {src_schema}.{src_table} is not active in metadata, skipping..."
                        .format(**self.params))
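
This operator leans on several imports the snippet does not show; a sketch of the header it appears to assume (execute_values is the psycopg2 helper that performs the batched INSERT):

import logging
from contextlib import closing

from airflow.hooks.oracle_hook import OracleHook
from airflow.hooks.postgres_hook import PostgresHook
from psycopg2.extras import execute_values  # batched INSERT ... VALUES helper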
Code Example #16
 def get_hook(self):
     if self.conn_type == 'mysql':
         from airflow.hooks.mysql_hook import MySqlHook
         return MySqlHook(mysql_conn_id=self.conn_id)
     elif self.conn_type == 'google_cloud_platform':
         from airflow.gcp.hooks.bigquery import BigQueryHook
         return BigQueryHook(bigquery_conn_id=self.conn_id)
     elif self.conn_type == 'postgres':
         from airflow.hooks.postgres_hook import PostgresHook
         return PostgresHook(postgres_conn_id=self.conn_id)
     elif self.conn_type == 'pig_cli':
         from airflow.hooks.pig_hook import PigCliHook
         return PigCliHook(pig_cli_conn_id=self.conn_id)
     elif self.conn_type == 'hive_cli':
         from airflow.hooks.hive_hooks import HiveCliHook
         return HiveCliHook(hive_cli_conn_id=self.conn_id)
     elif self.conn_type == 'presto':
         from airflow.hooks.presto_hook import PrestoHook
         return PrestoHook(presto_conn_id=self.conn_id)
     elif self.conn_type == 'hiveserver2':
         from airflow.hooks.hive_hooks import HiveServer2Hook
         return HiveServer2Hook(hiveserver2_conn_id=self.conn_id)
     elif self.conn_type == 'sqlite':
         from airflow.hooks.sqlite_hook import SqliteHook
         return SqliteHook(sqlite_conn_id=self.conn_id)
     elif self.conn_type == 'jdbc':
         from airflow.hooks.jdbc_hook import JdbcHook
         return JdbcHook(jdbc_conn_id=self.conn_id)
     elif self.conn_type == 'mssql':
         from airflow.hooks.mssql_hook import MsSqlHook
         return MsSqlHook(mssql_conn_id=self.conn_id)
     elif self.conn_type == 'oracle':
         from airflow.hooks.oracle_hook import OracleHook
         return OracleHook(oracle_conn_id=self.conn_id)
     elif self.conn_type == 'vertica':
         from airflow.contrib.hooks.vertica_hook import VerticaHook
         return VerticaHook(vertica_conn_id=self.conn_id)
     elif self.conn_type == 'cloudant':
         from airflow.contrib.hooks.cloudant_hook import CloudantHook
         return CloudantHook(cloudant_conn_id=self.conn_id)
     elif self.conn_type == 'jira':
         from airflow.contrib.hooks.jira_hook import JiraHook
         return JiraHook(jira_conn_id=self.conn_id)
     elif self.conn_type == 'redis':
         from airflow.contrib.hooks.redis_hook import RedisHook
         return RedisHook(redis_conn_id=self.conn_id)
     elif self.conn_type == 'wasb':
         from airflow.contrib.hooks.wasb_hook import WasbHook
         return WasbHook(wasb_conn_id=self.conn_id)
     elif self.conn_type == 'docker':
         from airflow.hooks.docker_hook import DockerHook
         return DockerHook(docker_conn_id=self.conn_id)
     elif self.conn_type == 'azure_data_lake':
         from airflow.contrib.hooks.azure_data_lake_hook import AzureDataLakeHook
         return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
     elif self.conn_type == 'azure_cosmos':
         from airflow.contrib.hooks.azure_cosmos_hook import AzureCosmosDBHook
         return AzureCosmosDBHook(azure_cosmos_conn_id=self.conn_id)
     elif self.conn_type == 'cassandra':
         from airflow.contrib.hooks.cassandra_hook import CassandraHook
         return CassandraHook(cassandra_conn_id=self.conn_id)
     elif self.conn_type == 'mongo':
         from airflow.contrib.hooks.mongo_hook import MongoHook
         return MongoHook(conn_id=self.conn_id)
     elif self.conn_type == 'gcpcloudsql':
         from airflow.gcp.hooks.cloud_sql import CloudSqlDatabaseHook
         return CloudSqlDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
     elif self.conn_type == 'grpc':
         from airflow.contrib.hooks.grpc_hook import GrpcHook
         return GrpcHook(grpc_conn_id=self.conn_id)
     raise AirflowException("Unknown hook type {}".format(self.conn_type))
Code Example #17
def report(**kwargs):
    hook = OracleHook("New Era - WMP04")

    sql = """
    SELECT   TO_CHAR(O.ACTUALSHIPDATE, 'YYYY-MM') AS SHIP_MONTH,
             TRUNC(O.ACTUALSHIPDATE) AS SHIP_DATE,
             CASE WHEN O.TYPE IN ('HOTSEALCPN', 'MANUAL') THEN 'HEATSEAL BASE INV' ELSE 'OUTBOUND' END ORDER_TYPE,
             OD.LOTTABLE01 PLANT_CODE,
             COUNT(DISTINCT O.ORDERKEY) AS ORDERS,
             COUNT(OD.ORDERKEY) AS LINES,
             SUM(OD.SHIPPEDQTY) AS SHIPPEDQTY
        FROM ORDERS_DW@WMP04 O JOIN ORDERDETAIL_DW@WMP04 OD ON O.ORDERKEY = OD.ORDERKEY
       WHERE O.STATUS = 95
         AND OD.SHIPPEDQTY > 0
         AND O.STORERKEY = '1168'
         AND O.ACTUALSHIPDATE < TRUNC(SYSDATE)
    GROUP BY TO_CHAR(O.ACTUALSHIPDATE, 'YYYY-MM'),
             TRUNC(O.ACTUALSHIPDATE),
             CASE WHEN O.TYPE IN ('HOTSEALCPN', 'MANUAL') THEN 'HEATSEAL BASE INV' ELSE 'OUTBOUND' END,
             OD.LOTTABLE01
    ORDER BY TRUNC(O.ACTUALSHIPDATE) DESC        
    """

    shipped_df = hook.get_pandas_df(sql=sql)

    shipped_df['SHIP_DATE'] = shipped_df['SHIP_DATE'].dt.date

    shipped_df_outbound = shipped_df[shipped_df['ORDER_TYPE'] == 'OUTBOUND']
    shipped_df_heatseal = shipped_df[shipped_df['ORDER_TYPE'] ==
                                     'HEATSEAL BASE INV']

    # Pull in yesterday's (YDAY) detail for outbound orders
    sql = """
    SELECT  TO_CHAR(O.ACTUALSHIPDATE, 'YYYY-MM') AS SHIP_MONTH,
         TRUNC(O.ACTUALSHIPDATE) AS SHIP_DATE,
         O.TYPE ORDER_TYPE,
         OD.LOTTABLE01 PLANT_CODE,
         O.ORDERKEY XPO_ORDERKEY,
         O.EXTERNORDERKEY EXTERNORDERKEY,
         OD.SKU,
         OD.SHIPPEDQTY
    FROM ORDERS_DW@WMP04 O JOIN ORDERDETAIL_DW@WMP04 OD ON O.ORDERKEY = OD.ORDERKEY
   WHERE O.STATUS = 95
     AND OD.SHIPPEDQTY > 0
     AND O.STORERKEY = '1168'
     AND TRUNC(O.ACTUALSHIPDATE) = TRUNC(SYSDATE - 1)
     AND O.TYPE NOT IN ('HOTSEALCPN', 'MANUAL')
ORDER BY O.ORDERKEY,
         OD.SKU       
    """

    shippedYdayDetail = hook.get_pandas_df(sql=sql)

    shippedYdayDetail['SHIP_DATE'] = shippedYdayDetail['SHIP_DATE'].dt.date

    # For outbound orders:
    # build the monthly summary, append it to the daily-level data, and sort with
    # blanks first so each MTD total leads its month; the blank SHIP_DATE cells are
    # then filled with an MTD total label for display
    shipped_df_outbound_month = shipped_df_outbound.groupby(
        ['SHIP_MONTH', 'PLANT_CODE']).agg({
            'SHIPPEDQTY': sum,
            'LINES': sum,
            'ORDERS': sum
        }).reset_index().sort_values('SHIP_MONTH',
                                     ascending=False).round(decimals=2)
    shippedReportOutbound = shipped_df_outbound.append(
        shipped_df_outbound_month).sort_values(['SHIP_MONTH', 'SHIP_DATE'],
                                               ascending=[False, False],
                                               na_position='first')

    # fill the blank SHIP_DATE cells with the matching plant code label
    shippedReportOutbound['SHIP_DATE'] = shippedReportOutbound.apply(
        lambda x: 'MTD ' + x['PLANT_CODE'] + ' Total'
        if pd.isna(x['SHIP_DATE']) else x['SHIP_DATE'],
        axis=1)

    # rearrange the column order
    shippedReportOutbound = shippedReportOutbound[[
        'SHIP_MONTH', 'SHIP_DATE', 'PLANT_CODE', 'ORDERS', 'LINES',
        'SHIPPEDQTY'
    ]]

    # write the outbound data to the report
    writer = pd.ExcelWriter(sourceFile + r'NEW ERA SHIPPED REPORT_' +
                            todayDate + '.xlsx',
                            engine='xlsxwriter')
    shippedReportOutbound.to_excel(writer,
                                   sheet_name='SHIPPED_OUTBOUND',
                                   startrow=0,
                                   index=False)

    # For HEATSEAL BASE INV orders
    shipped_df_heatseal_month = shipped_df_heatseal.groupby(
        ['SHIP_MONTH']).agg({
            'SHIPPEDQTY': sum,
            'LINES': sum,
            'ORDERS': sum
        }).reset_index().sort_values('SHIP_MONTH',
                                     ascending=False).round(decimals=2)
    shippedReportheatseal = shipped_df_heatseal.append(
        shipped_df_heatseal_month).sort_values(['SHIP_MONTH', 'SHIP_DATE'],
                                               ascending=[False, False],
                                               na_position='first')
    shippedReportheatseal['SHIP_DATE'] = shippedReportheatseal[
        'SHIP_DATE'].fillna('MTD Total')
    shippedReportheatseal = shippedReportheatseal[[
        'SHIP_MONTH', 'SHIP_DATE', 'ORDERS', 'LINES', 'SHIPPEDQTY'
    ]]

    # write the heatseal data to the report
    shippedReportheatseal.to_excel(writer,
                                   sheet_name='SHIPPED_HEAT_SEAL',
                                   startrow=0,
                                   index=False)

    # add yesterday's detail data
    shippedYdayDetail.to_excel(writer,
                               sheet_name='YDAY SHIPPED_OUTBOUND DETAIL',
                               startrow=0,
                               index=False)

    # center all columns
    formatCenter = writer.book.add_format()
    formatCenter.set_center_across()
    formatPercent = writer.book.add_format({
        'num_format': '0.0%',
        'align': 'center'
    })
    formatNum = writer.book.add_format({
        'num_format': '#,##0',
        'align': 'center'
    })

    # center-align the given columns
    def set_center(sheetname, cols):
        writer.sheets[sheetname].set_column(cols, None, formatCenter)

    # apply the numeric format to the given columns
    def format_num(sheetname, cols):
        writer.sheets[sheetname].set_column(cols, None, formatNum)

    # adjust column widths to fit the dataframe values and headers, given the df and sheet name
    def adjust_cols_width(df, sheet_name):
        width_list = [
            max([len(str(s)) + 3 for s in df[col].values] + [len(col) + 3])
            for col in df.columns
        ]
        for i in range(0, len(width_list)):
            writer.sheets[sheet_name].set_column(i, i, width_list[i])

    set_center('SHIPPED_OUTBOUND', 'A:F')
    format_num('SHIPPED_OUTBOUND', 'D:F')
    adjust_cols_width(shippedReportOutbound, 'SHIPPED_OUTBOUND')

    set_center('SHIPPED_HEAT_SEAL', 'A:F')
    format_num('SHIPPED_HEAT_SEAL', 'D:F')
    adjust_cols_width(shippedReportheatseal, 'SHIPPED_HEAT_SEAL')

    set_center('YDAY SHIPPED_OUTBOUND DETAIL', 'A:H')
    format_num('YDAY SHIPPED_OUTBOUND DETAIL', 'H:H')
    adjust_cols_width(shippedYdayDetail, 'YDAY SHIPPED_OUTBOUND DETAIL')

    writer.save()
Code Example #18
 def execute(self, context):
     self.logger.info('Executing: %s', self.sql)
     hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
     hook.run(self.sql,
              autocommit=self.autocommit,
              parameters=self.parameters)
Code Example #19
class TestOracleHookConn(unittest.TestCase):

    def setUp(self):
        super(TestOracleHookConn, self).setUp()

        self.connection = Connection(
            login='******',
            password='******',
            host='host',
            port=1521
        )

        self.db_hook = OracleHook()
        self.db_hook.get_connection = mock.Mock()
        self.db_hook.get_connection.return_value = self.connection

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_host(self, mock_connect):
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['user'], 'login')
        self.assertEqual(kwargs['password'], 'password')
        self.assertEqual(kwargs['dsn'], 'host')

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_sid(self, mock_connect):
        dsn_sid = {'dsn': 'dsn', 'sid': 'sid'}
        self.connection.extra = json.dumps(dsn_sid)
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['dsn'],
                         cx_Oracle.makedsn(dsn_sid['dsn'],
                                           self.connection.port, dsn_sid['sid']))

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_service_name(self, mock_connect):
        dsn_service_name = {'dsn': 'dsn', 'service_name': 'service_name'}
        self.connection.extra = json.dumps(dsn_service_name)
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['dsn'], cx_Oracle.makedsn(
            dsn_service_name['dsn'], self.connection.port,
            service_name=dsn_service_name['service_name']))

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_encoding_without_nencoding(self, mock_connect):
        self.connection.extra = json.dumps({'encoding': 'UTF-8'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['encoding'], 'UTF-8')
        self.assertEqual(kwargs['nencoding'], 'UTF-8')

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_encoding_with_nencoding(self, mock_connect):
        self.connection.extra = json.dumps({'encoding': 'UTF-8', 'nencoding': 'gb2312'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['encoding'], 'UTF-8')
        self.assertEqual(kwargs['nencoding'], 'gb2312')

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_nencoding(self, mock_connect):
        self.connection.extra = json.dumps({'nencoding': 'UTF-8'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertNotIn('encoding', kwargs)
        self.assertEqual(kwargs['nencoding'], 'UTF-8')

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_mode(self, mock_connect):
        mode = {
            'sysdba': cx_Oracle.SYSDBA,
            'sysasm': cx_Oracle.SYSASM,
            'sysoper': cx_Oracle.SYSOPER,
            'sysbkp': cx_Oracle.SYSBKP,
            'sysdgd': cx_Oracle.SYSDGD,
            'syskmt': cx_Oracle.SYSKMT,
        }
        first = True
        for m in mode:
            self.connection.extra = json.dumps({'mode': m})
            self.db_hook.get_conn()
            if first:
                assert mock_connect.call_count == 1
                first = False
            args, kwargs = mock_connect.call_args
            self.assertEqual(args, ())
            self.assertEqual(kwargs['mode'], mode.get(m))

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_threaded(self, mock_connect):
        self.connection.extra = json.dumps({'threaded': True})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['threaded'], True)

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_events(self, mock_connect):
        self.connection.extra = json.dumps({'events': True})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['events'], True)

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_purity(self, mock_connect):
        purity = {
            'new': cx_Oracle.ATTR_PURITY_NEW,
            'self': cx_Oracle.ATTR_PURITY_SELF,
            'default': cx_Oracle.ATTR_PURITY_DEFAULT
        }
        first = True
        for p in purity:
            self.connection.extra = json.dumps({'purity': p})
            self.db_hook.get_conn()
            if first:
                assert mock_connect.call_count == 1
                first = False
            args, kwargs = mock_connect.call_args
            self.assertEqual(args, ())
            self.assertEqual(kwargs['purity'], purity.get(p))
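
The extras exercised by these tests correspond to keys in the connection's Extra field; a hedged example of JSON combining several of them (the values are illustrative):

import json

extra = json.dumps({
    "dsn": "dbhost.example.com",
    "service_name": "ORCLPDB1",
    "encoding": "UTF-8",
    "nencoding": "UTF-8",
    "threaded": True,
    "events": True,
    "mode": "sysdba",
    "purity": "self",
})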
Code Example #20
def get_clv(oracle_conn_id, src_client_id, storage_bucket, ds, **context):
    import matplotlib.pyplot
    matplotlib.pyplot.ioff()
    from lifetimes.utils import calibration_and_holdout_data
    from lifetimes.plotting import plot_frequency_recency_matrix
    from lifetimes.plotting import plot_probability_alive_matrix
    from lifetimes.plotting import plot_calibration_purchases_vs_holdout_purchases
    from lifetimes.plotting import plot_period_transactions
    from lifetimes.plotting import plot_history_alive
    from lifetimes.plotting import plot_cumulative_transactions
    from lifetimes.utils import expected_cumulative_transactions
    from lifetimes.utils import summary_data_from_transaction_data
    from lifetimes import BetaGeoFitter
    from lifetimes import GammaGammaFitter
    import datetime
    import pandas as pd
    import datalab.storage as gcs
    conn = OracleHook(oracle_conn_id=oracle_conn_id).get_conn()
    print(src_client_id, context)
    query = context['templates_dict']['query']
    data = pd.read_sql(query, con=conn)
    data.columns = data.columns.str.lower()
    print(data.head())

    # Calculate RFM values
    calibration_end_date = datetime.datetime(2018, 5, 24)
    training_rfm = calibration_and_holdout_data(
        transactions=data,
        customer_id_col='src_user_id',
        datetime_col='pickup_date',
        calibration_period_end=calibration_end_date,
        freq='D',
        monetary_value_col='price_total')
    bgf = BetaGeoFitter(penalizer_coef=0.0)
    bgf.fit(training_rfm['frequency_cal'], training_rfm['recency_cal'],
            training_rfm['T_cal'])
    print(bgf)

    # Matrix charts
    plot_period_transactions_chart = context.get("ds_nodash") + str(
        src_client_id) + '_plot_period_transactions_chart.svg'
    plot_frequency_recency_chart = context.get("ds_nodash") + str(
        src_client_id) + '_plot_frequency_recency_matrix.svg'
    plot_probability_chart = context.get("ds_nodash") + str(
        src_client_id) + '_plot_probability_alive_matrix.svg'
    plot_calibration_vs_holdout_chart = context.get("ds_nodash") + str(
        src_client_id) + '_plot_calibration_vs_holdout_purchases.svg'

    ax0 = plot_period_transactions(bgf, max_frequency=30)
    ax0.figure.savefig(plot_period_transactions_chart, format='svg')
    ax1 = plot_frequency_recency_matrix(bgf)
    ax1.figure.savefig(plot_frequency_recency_chart, format='svg')
    ax2 = plot_probability_alive_matrix(bgf)
    ax2.figure.savefig(plot_probability_chart, format='svg')
    ax3 = plot_calibration_purchases_vs_holdout_purchases(bgf,
                                                          training_rfm,
                                                          n=50)
    ax3.figure.savefig(plot_calibration_vs_holdout_chart, format='svg')
    full_rfm = summary_data_from_transaction_data(
        data,
        customer_id_col='src_user_id',
        datetime_col='pickup_date',
        monetary_value_col='price_total',
        datetime_format=None,
        observation_period_end=None,
        freq='D')
    returning_full_rfm = full_rfm[full_rfm['frequency'] > 0]
    ggf = GammaGammaFitter(penalizer_coef=0)
    ggf.fit(returning_full_rfm['frequency'],
            returning_full_rfm['monetary_value'])

    customer_lifetime = 30  # expected customer lifetime, in months
    clv = ggf.customer_lifetime_value(
        bgf,  # the model used to predict the number of future transactions
        full_rfm['frequency'],
        full_rfm['recency'],
        full_rfm['T'],
        full_rfm['monetary_value'],
        time=customer_lifetime,  # months
        discount_rate=0.01  # monthly discount rate, ~12.7% annually
    ).sort_values(ascending=False)
    full_rfm_with_value = full_rfm.join(clv)

    full_rfm_file = context.get("ds_nodash") + "-src_client_id-" + str(
        src_client_id) + '-icabbi-test.csv'
    full_rfm_with_value.to_csv(full_rfm_file)
    GoogleCloudStorageHook(
        google_cloud_storage_conn_id='google_conn_default').upload(
            bucket=storage_bucket,
            object=str(src_client_id) + "/" + context.get("ds_nodash") + "/" +
            full_rfm_file,
            filename=full_rfm_file)
    GoogleCloudStorageHook(
        google_cloud_storage_conn_id='google_conn_default').upload(
            bucket=storage_bucket,
            object=str(src_client_id) + "/" + context.get("ds_nodash") + "/" +
            plot_period_transactions_chart,
            filename=plot_period_transactions_chart)
    GoogleCloudStorageHook(
        google_cloud_storage_conn_id='google_conn_default').upload(
            bucket=storage_bucket,
            object=str(src_client_id) + "/" + context.get("ds_nodash") + "/" +
            plot_frequency_recency_chart,
            filename=plot_frequency_recency_chart)
    GoogleCloudStorageHook(
        google_cloud_storage_conn_id='google_conn_default').upload(
            bucket=storage_bucket,
            object=str(src_client_id) + "/" + context.get("ds_nodash") + "/" +
            plot_probability_chart,
            filename=plot_probability_chart)
    GoogleCloudStorageHook(
        google_cloud_storage_conn_id='google_conn_default').upload(
            bucket=storage_bucket,
            object=str(src_client_id) + "/" + context.get("ds_nodash") + "/" +
            plot_calibration_vs_holdout_chart,
            filename=plot_calibration_vs_holdout_chart)
Code Example #21
 def execute(self, context):
     src_hook = OracleHook(oracle_conn_id=self.oracle_source_conn_id)
     dest_hook = OracleHook(oracle_conn_id=self.oracle_destination_conn_id)
     self._execute(src_hook, dest_hook, context)
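
The _execute helper is not shown here; a sketch of the batched fetch-and-insert loop it might perform (modeled on the chunked-transfer pattern; attribute names such as source_sql and rows_chunk are assumptions):

def _execute(self, src_hook, dest_hook, context):
    # assumes: from contextlib import closing
    with closing(src_hook.get_conn()) as src_conn:
        cursor = src_conn.cursor()
        cursor.execute(self.source_sql, self.source_sql_params)
        target_fields = [field[0] for field in cursor.description]

        rows_total = 0
        rows = cursor.fetchmany(self.rows_chunk)
        while rows:
            dest_hook.bulk_insert_rows(self.destination_table, rows,
                                       target_fields=target_fields,
                                       commit_every=self.rows_chunk)
            rows_total += len(rows)
            rows = cursor.fetchmany(self.rows_chunk)
        self.log.info("Finished: %s rows transferred", rows_total)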
Code Example #22
 def execute(self, context):
     logging.info("Executing: " + str(self.sql))
     hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
     hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
Code Example #23
class TestOracleHookConn(unittest.TestCase):
    def setUp(self):
        super(TestOracleHookConn, self).setUp()

        self.connection = Connection(login='******',
                                     password='******',
                                     host='host',
                                     port=1521)

        self.db_hook = OracleHook()
        self.db_hook.get_connection = mock.Mock()
        self.db_hook.get_connection.return_value = self.connection

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_host(self, mock_connect):
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['user'], 'login')
        self.assertEqual(kwargs['password'], 'password')
        self.assertEqual(kwargs['dsn'], 'host')

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_sid(self, mock_connect):
        dsn_sid = {'dsn': 'dsn', 'sid': 'sid'}
        self.connection.extra = json.dumps(dsn_sid)
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(
            kwargs['dsn'],
            cx_Oracle.makedsn(dsn_sid['dsn'], self.connection.port,
                              dsn_sid['sid']))

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_service_name(self, mock_connect):
        dsn_service_name = {'dsn': 'dsn', 'service_name': 'service_name'}
        self.connection.extra = json.dumps(dsn_service_name)
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(
            kwargs['dsn'],
            cx_Oracle.makedsn(dsn_service_name['dsn'],
                              self.connection.port,
                              service_name=dsn_service_name['service_name']))

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_encoding_without_nencoding(self, mock_connect):
        self.connection.extra = json.dumps({'encoding': 'UTF-8'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['encoding'], 'UTF-8')
        self.assertEqual(kwargs['nencoding'], 'UTF-8')

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_encoding_with_nencoding(self, mock_connect):
        self.connection.extra = json.dumps({
            'encoding': 'UTF-8',
            'nencoding': 'gb2312'
        })
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['encoding'], 'UTF-8')
        self.assertEqual(kwargs['nencoding'], 'gb2312')

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_nencoding(self, mock_connect):
        self.connection.extra = json.dumps({'nencoding': 'UTF-8'})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertNotIn('encoding', kwargs)
        self.assertEqual(kwargs['nencoding'], 'UTF-8')

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_mode(self, mock_connect):
        mode = {
            'sysdba': cx_Oracle.SYSDBA,
            'sysasm': cx_Oracle.SYSASM,
            'sysoper': cx_Oracle.SYSOPER,
            'sysbkp': cx_Oracle.SYSBKP,
            'sysdgd': cx_Oracle.SYSDGD,
            'syskmt': cx_Oracle.SYSKMT,
        }
        first = True
        for m in mode:
            self.connection.extra = json.dumps({'mode': m})
            self.db_hook.get_conn()
            if first:
                assert mock_connect.call_count == 1
                first = False
            args, kwargs = mock_connect.call_args
            self.assertEqual(args, ())
            self.assertEqual(kwargs['mode'], mode.get(m))

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_threaded(self, mock_connect):
        self.connection.extra = json.dumps({'threaded': True})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['threaded'], True)

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_events(self, mock_connect):
        self.connection.extra = json.dumps({'events': True})
        self.db_hook.get_conn()
        assert mock_connect.call_count == 1
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['events'], True)

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_purity(self, mock_connect):
        purity = {
            'new': cx_Oracle.ATTR_PURITY_NEW,
            'self': cx_Oracle.ATTR_PURITY_SELF,
            'default': cx_Oracle.ATTR_PURITY_DEFAULT
        }
        first = True
        for p in purity:
            self.connection.extra = json.dumps({'purity': p})
            self.db_hook.get_conn()
            if first:
                assert mock_connect.call_count == 1
                first = False
            args, kwargs = mock_connect.call_args
            self.assertEqual(args, ())
            self.assertEqual(kwargs['purity'], purity.get(p))