コード例 #1
0
ファイル: multi_stock_train.py プロジェクト: luhongkai/fin-ca
        def get_net_data(BEG, END):
            """Load one range of network data from ``env`` and align it to the labels.

            The prices shifted forward by ``PRED_HORIZON`` determine how many
            time steps have a label. Every other series is trimmed to that
            length, so all returned arrays share the same time dimension.

            Args:
                BEG: Start of the range, passed to ``env.get_data_idxs_range``.
                END: End of the range, passed to ``env.get_data_idxs_range``.

            Returns:
                tuple: ``(beg_idx, ds_sz, batch_num, raw_dates, raw_week_days,
                tradeable_mask, port_mask, px, input, labels)`` where ``ds_sz``
                is the aligned time length, ``batch_num`` comes from
                ``get_batches_num(ds_sz, BPTT_STEPS)``, ``raw_week_days`` is an
                int32 array of ``isoweekday()`` values for ``raw_dates``, and
                ``labels`` is the relative price change
                ``(px_pred_hor - px) / px``.
            """
            beg_idx, end_idx = env.get_data_idxs_range(BEG, END)

            raw_dates = env.get_raw_dates(beg_idx, end_idx)
            # Named ``net_input`` so the ``input`` builtin is not shadowed.
            net_input = env.get_input(beg_idx, end_idx)
            px = env.get_adj_close_px(beg_idx, end_idx)
            # Read the config once; PRED_HORIZON and BPTT_STEPS both come from it.
            config = get_config()
            # Prices PRED_HORIZON steps ahead of each step in the range.
            px_pred_hor = env.get_adj_close_px(
                beg_idx + config.PRED_HORIZON,
                end_idx + config.PRED_HORIZON)
            tradeable_mask = env.get_tradeable_mask(beg_idx, end_idx)
            port_mask = env.get_portfolio_mask(beg_idx, end_idx)

            # Only steps with a shifted price can be labelled; axis 1 of the
            # shifted prices is the time axis used to trim everything below.
            ds_sz = px_pred_hor.shape[1]

            raw_dates = raw_dates[:ds_sz]
            raw_week_days = np.full(raw_dates.shape, 0, dtype=np.int32)
            for i, raw_date in enumerate(raw_dates):
                raw_week_days[i] = date_from_timestamp(raw_date).isoweekday()

            net_input = net_input[:, :ds_sz, :]
            tradeable_mask = tradeable_mask[:, :ds_sz]
            port_mask = port_mask[:, :ds_sz]
            px = px[:, :ds_sz]

            # NOTE(review): yields inf/nan wherever px == 0; confirm whether
            # zero prices (e.g. untradeable slots) can occur here.
            labels = (px_pred_hor - px) / px
            batch_num = get_batches_num(ds_sz, config.BPTT_STEPS)

            return (beg_idx, ds_sz, batch_num, raw_dates, raw_week_days,
                    tradeable_mask, port_mask, px, net_input, labels)
コード例 #2
0
def get_net_data(env, BEG, END):
    """Load the raw network data for one range from ``env``.

    Args:
        env: Data source exposing ``get_data_idxs_range``, ``get_raw_dates``,
            ``get_input``, ``get_adj_close_px`` and ``get_tradeable_mask``.
        BEG: Start of the range, passed to ``env.get_data_idxs_range``.
        END: End of the range, passed to ``env.get_data_idxs_range``.

    Returns:
        tuple: ``(beg_idx, end_idx, raw_dates, raw_week_days, tradeable_mask,
        px, input)``. ``raw_week_days`` is an int32 array shaped like
        ``raw_dates`` holding ``date_from_timestamp(d).isoweekday()`` for each
        date (1=Monday .. 7=Sunday, assuming ``date_from_timestamp`` returns a
        ``datetime.date``/``datetime``).
    """
    beg_idx, end_idx = env.get_data_idxs_range(BEG, END)

    raw_dates = env.get_raw_dates(beg_idx, end_idx)
    # Named ``net_input`` so the ``input`` builtin is not shadowed.
    net_input = env.get_input(beg_idx, end_idx)
    px = env.get_adj_close_px(beg_idx, end_idx)
    tradeable_mask = env.get_tradeable_mask(beg_idx, end_idx)

    raw_week_days = np.full(raw_dates.shape, 0, dtype=np.int32)
    for i, raw_date in enumerate(raw_dates):
        raw_week_days[i] = date_from_timestamp(raw_date).isoweekday()

    return beg_idx, end_idx, raw_dates, raw_week_days, tradeable_mask, px, net_input
コード例 #3
0
ファイル: multi_stock_train.py プロジェクト: luhongkai/fin-ca
def build_time_axis(raw_dates):
    dt = []
    for raw_dt in np.nditer(raw_dates):
        dt.append(date_from_timestamp(raw_dt))
    return dt
コード例 #4
0
            prediction = np.zeros((total_stocks))
            fri_px = None
            mon_px = None
            fri_tradable_mask = None
            mon_tradable_mask = None

            cash = 1
            pos = np.zeros((total_stocks))
            pos_px = np.zeros((total_stocks))

            eq = np.zeros((trading_days))
            long_pos_mask = np.full((total_stocks), False)
            short_pos_mask = np.full((total_stocks), False)

            for idx in range(trading_days):
                date = date_from_timestamp(raw_dates[idx])
                curr_px = adj_px[:, idx]
                _tradable_maks = tradable_maks[:, idx]

                open_pos = date.isoweekday() == 1 and fri_px is not None
                # open_pos = date.isoweekday() == 5
                close_pos = date.isoweekday() == 5
                predict = date.isoweekday() == 5

                if predict:
                    fri_px = curr_px
                    fri_tradable_mask = _tradable_maks

                    selection = pred_df['date'] == date
                    curr_pred_df = pred_df.loc[selection]
                    prediction[curr_pred_df['stk_idx'].