import json
import logging

import smart_open
from airflow.hooks.base_hook import BaseHook


def get_extra_from_conn(conn_id):
    """
    Obtain the extra fields from an Airflow connection.

    Parameters
    ----------
    conn_id : str
        Airflow Connection ID

    Returns
    -------
    dict
        extra kwargs
    """
    conn = BaseHook.get_connection(conn_id)
    return json.loads(conn.extra)
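
# A minimal usage sketch; "my_redshift" is an assumed connection id whose Extra
# field holds JSON such as {"iam_role": "arn:aws:iam::..."} (illustrative, not
# part of the original code):
#
#     extra = get_extra_from_conn("my_redshift")
#     iam_role = extra.get("iam_role")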


def get_redshift_uri(conn_id):
    """
    Build a Redshift JDBC URI from a given Airflow connection id.
    """
    conn = BaseHook.get_connection(conn_id)
    if not conn.host:
        return ""

    uri = (
        f"jdbc:redshift://{conn.host}:{conn.port}/{conn.schema}"
        f"?user={conn.login}&password={conn.password}"
    )
    extra = conn.extra_dejson
    params = [f"{k}={v}" for k, v in extra.items()]

    if params:
        # The query string already starts with "?", so any extra connection
        # parameters are appended with "&".
        uri += "&" + "&".join(params)

    return uri
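
# A hedged setup sketch: Airflow connections can also be defined through
# AIRFLOW_CONN_* environment variables, so the function above could be
# exercised against a connection like the following (all values are
# placeholders):
#
#     export AIRFLOW_CONN_MY_REDSHIFT='postgres://etl_user:secret@example-cluster.abc123.eu-west-1.redshift.amazonaws.com:5439/analytics?tcpKeepAlive=true'
#
#     get_redshift_uri('my_redshift')
#     # -> 'jdbc:redshift://example-cluster...:5439/analytics'
#     #    '?user=etl_user&password=secret&tcpKeepAlive=true'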
    def get_records_batch(self, hook, query_filter):
        """Chunk the records and stream them to S3 by the specified batch size."""

        if query_filter == '':
            query_filter = 'WHERE'
        else:
            query_filter = query_filter + ' AND '

        count_sql_max = """
        SELECT max({0}) as c FROM {1} """.format(self.primary_key,
                                                 self.mssql_table)

        count_sql_min = """
        SELECT min({0}) as c FROM {1}  WHERE {0}>0""".format(
            self.primary_key, self.mssql_table)

        count_sql_max_incremental = """
                SELECT max({0}) as c FROM {1} """.format(
            self.incremental_key, self.mssql_table)

        count_sql_min_incremental = """
                        SELECT min({0}) as c FROM {1} """.format(
            self.incremental_key, self.mssql_table)

        # if query_filter != 'WHERE':
        #     # Remove the AND from the query filter so you're only batching
        #     # for incremental loads within your timerange. Assumes
        #     # primary_key is incremental.
        #     count_sql_max += query_filter.split("AND")[0]
        #     count_sql_min += query_filter.split("AND")[0]

        # Primary-key bounds that drive the batching loop below.
        count = int(hook.get_pandas_df(count_sql_max)['c'][0])
        min_count = int(hook.get_pandas_df(count_sql_min)['c'][0])

        # Incremental-key bounds, used to tag the S3 key with a date range.
        max_date = hook.get_pandas_df(count_sql_max_incremental)['c'][0]
        min_date = hook.get_pandas_df(count_sql_min_incremental)['c'][0]

        s3_conn = BaseHook.get_connection(self.s3_conn_id)
        s3_creds = s3_conn.extra_dejson

        s3_key = self.s3_key
        if self.s3_key_suffix:
            # Tag the key with the incremental date range,
            # e.g. data.json -> data<min>-<max>.json.
            name, extension = s3_key.rsplit('.', 1)
            s3_key = '{}{}-{}.{}'.format(name, min_date, max_date, extension)

        s3_key = '{}/{}'.format(self.s3_bucket, s3_key)

        url = 's3://{}:{}@{}'.format(s3_creds['aws_access_key_id'],
                                     s3_creds['aws_secret_access_key'], s3_key)

        logging.info('Initiating record retrieval in batches.')
        logging.info('Query: {0}'.format(count_sql_min))
        logging.info('Start Date: {0}'.format(self.start))
        logging.info('End Date: {0}'.format(self.end))
        logging.info('smallest_number: {0}'.format(min_count))
        logging.info('count: {0}'.format(count))

        # Smart Open is a library for efficiently streaming large files to S3.
        # Streaming data to S3 here so it doesn't break the task container.
        # https://pypi.python.org/pypi/smart_open
        # Does this here because smart_open doesn't yet support an
        # append mode and doing it as a function was causing the file to be
        # overwritten every time.

        with smart_open.smart_open(url, 'wb') as fout:
            logging.info("First Row: {0}".format(min_count))
            logging.info("Total Rows: {0}".format(count))
            logging.info("Batch Size: {0}".format(self.batchsize))
            # range() is half-open, so iterate up to count + 1 to make sure the
            # row with the maximum primary key lands in the final batch.
            for batch in range(min_count, count + 1, self.batchsize):
                query = """
                    SELECT *
                    FROM {table}
                    {query_filter} {primary_key} >= {batch}
                    AND {primary_key} < {batch_two};
                    """.format(table=self.mssql_table,
                               primary_key=self.primary_key,
                               query_filter=query_filter,
                               batch=batch,
                               batch_two=batch + self.batchsize)

                logging.info(query)

                # Run the query; rows are expected to be dict-like
                # (column name -> value).
                results = list(hook.get_records(query))
                logging.info(
                    'Successfully performed query for batch {0}-{1}.'.format(
                        batch, (batch + self.batchsize)))

                # Lower-case the column names and stringify non-null values.
                results = [
                    dict([k.lower(), str(v)] if v is not None else [k, v]
                         for k, v in i.items()) for i in results
                ]
                # Serialize to newline-delimited JSON and encode as bytes
                # before streaming to S3.
                results = '\n'.join([json.dumps(i) for i in results])
                results = results.encode('utf-8')
                logging.info("Uploading!")
                fout.write(results)
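
# A standalone sketch of the batching arithmetic used above: each loop
# iteration covers the half-open primary-key window [batch, batch + batchsize).
# The helper name and the example values below are illustrative only.
def batch_windows(min_id, max_id, batchsize):
    """Yield the (low, high) key windows the loop above iterates over."""
    for low in range(min_id, max_id + 1, batchsize):
        yield low, low + batchsize

# list(batch_windows(1, 10, 4)) -> [(1, 5), (5, 9), (9, 13)]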
Example 4
def load_data(**context):
    postgres_hook = PostgresHook('admin_postgres')
    tickers = postgres_hook.get_records(
        'select yf_code from tickers where fetch_from_yahoo_finance')
    logging.info('Loaded %d tickers from db.' % len(tickers))
    tickers = [x[0] for x in tickers]
    frequency = '1d'
    start_dt = parse_execution_date(
        context['yesterday_ds']) - timedelta(days=7)
    data = yf.download(tickers=tickers,
                       start=start_dt,
                       end=context['tomorrow_ds'],
                       interval=frequency,
                       auto_adjust=True,
                       group_by='ticker',
                       progress=False,
                       threads=True)

    columns_mapping = {
        'Date': 'ts',
        'Open': 'open',
        'High': 'high',
        'Low': 'low',
        'Close': 'close',
        'Volume': 'volume',
        'Adj Close': 'adj_close'
    }
    ch_columns = ['ticker', 'frequency', 'source', 'type'] + list(
        columns_mapping.values())
    df = None
    for ticker in tickers:
        try:
            _df = data[ticker].copy()
        except KeyError:
            logging.error('Ticker %s not found in data' % ticker)
            continue
        _df = _df.reset_index()
        _df['ticker'] = ticker
        _df['frequency'] = frequency
        _df['source'] = 'yfinance'
        _df['type'] = 'history'
        _df = _df.rename(columns=columns_mapping)
        if 'adj_close' not in _df.columns:
            _df['adj_close'] = np.nan
        _df = _df[ch_columns]
        _df = _df[~_df.close.isna()]

        if df is None:
            df = _df
        else:
            df = pd.concat([df, _df])

    logging.info('Prepared df with shape (%s, %s)' % df.shape)
    ch_conn = BaseHook.get_connection('rocket_clickhouse')
    data_json_each = ''
    df.reset_index(drop=True, inplace=True)
    for i in df.index:
        json_str = df.loc[i].to_json(date_format='iso')
        data_json_each += json_str + '\n'

    # ch_conn.host is assumed to hold the full ClickHouse HTTP endpoint
    # (e.g. http://clickhouse.example.com:8123), not just a hostname.
    result = requests.post(
        url=ch_conn.host,
        data=data_json_each,
        params=dict(
            query='insert into rocket.events format JSONEachRow',
            user=ch_conn.login,
            password=ch_conn.password,
            date_time_input_format='best_effort',
        ))
    if result.ok:
        logging.info('Insert ok.')
    else:
        raise requests.HTTPError('Request response code: %d. Message: %s' %
                                 (result.status_code, result.text))
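
# A self-contained sketch of the same JSONEachRow insert for two hand-built
# rows; the endpoint, credentials and values are placeholders and are not taken
# from the 'rocket_clickhouse' connection used above.
import json

import requests

rows = [
    {"ticker": "AAPL", "frequency": "1d", "source": "yfinance", "type": "history",
     "ts": "2021-03-01T00:00:00", "open": 127.8, "high": 130.1, "low": 127.1,
     "close": 129.6, "volume": 116400000, "adj_close": 129.6},
    {"ticker": "MSFT", "frequency": "1d", "source": "yfinance", "type": "history",
     "ts": "2021-03-01T00:00:00", "open": 235.9, "high": 237.5, "low": 233.2,
     "close": 236.9, "volume": 25300000, "adj_close": 236.9},
]
# One JSON object per line, exactly what ClickHouse's JSONEachRow format expects.
payload = "\n".join(json.dumps(r) for r in rows)
response = requests.post(
    "http://clickhouse.example.com:8123",
    data=payload,
    params={
        "query": "insert into rocket.events format JSONEachRow",
        "user": "default",
        "password": "",
    },
)
response.raise_for_status()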
Example 5
def get_tblnm_list(conn_id):
    """Return the list of table names visible through the given connection."""
    # BaseHook.get_hook resolves the connection's conn_type to the matching
    # concrete hook (e.g. PostgresHook, MySqlHook).
    db_hook = BaseHook.get_hook(conn_id=conn_id)
    return db_hook.get_sqlalchemy_engine().table_names()
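
# A minimal usage sketch; "my_postgres" is an assumed connection id and the
# output is illustrative only.
#
#     get_tblnm_list("my_postgres")
#     # -> ['tickers', 'events', ...]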