コード例 #1
0
def _test_account():
    """Smoke-test ``Account``: build it from the RB sample data and take one step.

    Loads the bundled ``RB.csv`` OHLCAV sample, derives factors with
    ``get_factor``, windows them into ``n_step``-long batches with
    ``transfer_2_batch``, reshapes each window into 5 chunks of
    ``n_step // 5`` rows, and then checks the observation shape returned by
    ``Account.reset`` and by the first ``Account.step``.
    """
    # Build the input data.
    n_step = 60
    ohlcav_col_name_list = ["open", "high", "low", "close", "amount", "volume"]
    from ibats_common.example.data import load_data
    md_df = load_data('RB.csv').set_index('trade_date')[ohlcav_col_name_list]
    md_df.index = pd.DatetimeIndex(md_df.index)
    from ibats_common.backend.factor import get_factor, transfer_2_batch
    factors_df = get_factor(md_df, dropna=True)
    df_index, df_columns, data_arr_batch = transfer_2_batch(factors_df,
                                                            n_step=n_step)
    # Keep only the market-data rows that correspond to a factor window.
    md_df = md_df.loc[df_index, :]
    # Split axis 1 into (5, n_step // 5), then move the 5-chunk axis last:
    # result is (batch, n_step // 5, data_arr_batch.shape[2], 5).
    # NOTE(review): assumes data_arr_batch is (batch, n_step, n_factors).
    shape = [
        data_arr_batch.shape[0], 5,
        n_step // 5, data_arr_batch.shape[2]
    ]
    data_factors = np.transpose(data_arr_batch.reshape(shape), [0, 2, 3, 1])
    print(data_arr_batch.shape, '->', shape, '->', data_factors.shape)
    # Build the Account.
    env = Account(md_df, data_factors)
    next_observation = env.reset()
    print('next_observation.shape:', next_observation.shape)
    # NOTE(review): (1, 12, 78, 5) is tied to n_step=60 and to the number of
    # factor columns get_factor yields for RB.csv (presumably 78).
    assert next_observation.shape == (1, 12, 78, 5)
    next_state, reward, done = env.step(1)
    # Validate the observation produced by step(), not the stale one from reset().
    assert next_state.shape == (1, 12, 78, 5)
    assert not done
コード例 #2
0
ファイル: market.py プロジェクト: IBATS/IBATS_Common
def _test_quote_market():
    """Walk ``QuotesMarket`` through a scripted action sequence.

    Builds ``n_step``-long factor windows from the bundled ``RB.csv`` sample.
    It then creates a ``QuotesMarket`` over the ``close``/``open`` columns with
    ``state_with_flag=True``. Each observation is checked to be a 2-element
    sequence whose second element is the position flag, and ``reward`` is
    checked to be zero or non-zero after each action. Finally, an unsupported
    action (4) must raise ``ValueError``.

    NOTE(review): from the flag transitions asserted below, the integer
    actions appear to be 0=close, 1=long, 2=short, 3=keep, and the flags
    0=flat, 1=long, -1=short — confirm against QuotesMarket.
    """
    n_step = 60
    ohlcav_col_name_list = ["open", "high", "low", "close", "amount", "volume"]
    from ibats_common.example.data import load_data
    md_df = load_data('RB.csv').set_index('trade_date')[ohlcav_col_name_list]
    md_df.index = pd.DatetimeIndex(md_df.index)
    from ibats_common.backend.factor import get_factor, transfer_2_batch
    factors_df = get_factor(md_df, dropna=True)
    df_index, df_columns, data_arr_batch = transfer_2_batch(factors_df,
                                                            n_step=n_step)
    # Keep only the market-data rows that correspond to a factor window.
    md_df = md_df.loc[df_index, :]
    # Build the QuotesMarket.
    qm = QuotesMarket(md_df=md_df[['close', 'open']],
                      data_factors=data_arr_batch,
                      state_with_flag=True)
    next_observation = qm.reset()
    assert len(next_observation) == 2
    assert next_observation[0].shape[0] == n_step
    assert next_observation[1] == 0
    # Open long.
    next_observation, reward, done = qm.step(1)
    assert len(next_observation) == 2
    assert next_observation[1] == 1
    assert not done
    # Close the long position: realises a non-zero reward.
    next_observation, reward, done = qm.step(0)
    assert next_observation[1] == 0
    assert reward != 0
    # Closing again while flat does nothing.
    next_observation, reward, done = qm.step(0)
    assert next_observation[1] == 0
    assert reward == 0
    # Holding while flat yields no reward.
    next_observation, reward, done = qm.step(3)
    assert next_observation[1] == 0
    assert reward == 0
    # Open short.
    next_observation, reward, done = qm.step(2)
    assert next_observation[1] == -1
    assert not done
    # Holding the short position yields a non-zero reward.
    next_observation, reward, done = qm.step(3)
    assert next_observation[1] == -1
    assert reward != 0
    # An unsupported action must be rejected; a silent success is a failure.
    try:
        qm.step(4)
    except ValueError:
        print('is ok for not supporting action>3')
    else:
        raise AssertionError('qm.step(4) should raise ValueError for an unsupported action')
コード例 #3
0
def _test_quote_market():
    """Walk ``QuotesMarket`` through a scripted action sequence.

    Loads ``RB.csv``, passing the relative ``folder_path``
    ``../../../example/data``. It builds ``n_step``-long factor windows and
    creates a ``QuotesMarket`` over the ``close``/``open`` columns with
    ``state_with_flag=True``. Each observation is checked to be a 3-element
    sequence whose second element is the position flag (``FLAG_*``), and
    ``reward`` is checked to be zero or non-zero after each ``ACTION_*``.
    Finally, an unsupported action (4) must raise ``ValueError``.
    """
    import os
    n_step = 60
    from ibats_common.example.data import load_data
    md_df = load_data(
        'RB.csv',
        folder_path=os.path.join(os.pardir, os.pardir, os.pardir, 'example', 'data')  # r'..\..\..\example\data'
    ).set_index('trade_date')[DEFAULT_MD_OHLCVA_LABELS]
    md_df.index = pd.DatetimeIndex(md_df.index)
    from ibats_common.backend.factor import get_factor, transfer_2_batch
    factors_df = get_factor(md_df, dropna=True)
    df_index, df_columns, data_arr_batch = transfer_2_batch(factors_df, n_step=n_step)
    # Keep only the market-data rows that correspond to a factor window.
    md_df = md_df.loc[df_index, :]
    # Build the QuotesMarket.
    qm = QuotesMarket(md_df=md_df[['close', 'open']], data_factors=data_arr_batch, state_with_flag=True)
    next_observation = qm.reset()
    assert len(next_observation) == 3
    assert next_observation[0].shape[0] == n_step
    assert next_observation[1] == FLAG_EMPTY
    # Open long.
    next_observation, reward, done = qm.step(ACTION_LONG)
    assert len(next_observation) == 3
    assert next_observation[1] == FLAG_LONG
    assert not done
    # Close the long position: realises a non-zero reward.
    next_observation, reward, done = qm.step(ACTION_CLOSE)
    assert next_observation[1] == FLAG_EMPTY
    assert reward != 0
    # Closing again while flat does nothing.
    next_observation, reward, done = qm.step(ACTION_CLOSE)
    assert next_observation[1] == FLAG_EMPTY
    assert reward == 0
    # Holding while flat yields no reward.
    next_observation, reward, done = qm.step(ACTION_KEEP)
    assert next_observation[1] == FLAG_EMPTY
    assert reward == 0
    # Open short.
    next_observation, reward, done = qm.step(ACTION_SHORT)
    assert next_observation[1] == FLAG_SHORT
    assert not done
    # Holding the short position yields a non-zero reward.
    next_observation, reward, done = qm.step(ACTION_KEEP)
    assert next_observation[1] == FLAG_SHORT
    assert reward != 0
    # An unsupported action must be rejected; a silent success is a failure.
    try:
        qm.step(4)
    except ValueError:
        print('is ok for not supporting action>3')
    else:
        raise AssertionError('qm.step(4) should raise ValueError for an unsupported action')
コード例 #4
0
ファイル: account.py プロジェクト: IBATS/IBATS_Common
def _test_account2():
    """Run ``Account`` through a short-then-long episode and plot the outcome.

    Intended to check that ``plot_data`` returns the expected data. The test
    opens a short position and holds it for half as many steps as there are
    market-data rows. It then opens a long position and holds it until the
    episode reports ``done``. Finally, the first column of ``plot_data()`` is
    plotted against the close price under a timestamped title.
    """
    window = 60
    ohlcav = ["open", "high", "low", "close", "amount", "volume"]

    from ibats_common.example.data import load_data
    raw_df = load_data('RB.csv').set_index('trade_date')
    md_df = raw_df[ohlcav]
    md_df.index = pd.DatetimeIndex(md_df.index)

    from ibats_common.backend.factor import get_factor, transfer_2_batch
    batch_index, _, batches = transfer_2_batch(
        get_factor(md_df, dropna=True),
        n_step=window,
    )
    # Align market data with the rows that own a factor window.
    md_df = md_df.loc[batch_index, :]

    # Split axis 1 into (5, window // 5) and move the 5-chunk axis last.
    chunked_shape = [batches.shape[0], 5, window // 5, batches.shape[2]]
    data_factors = np.transpose(batches.reshape(chunked_shape), [0, 2, 3, 1])
    print(batches.shape, '->', chunked_shape, '->', data_factors.shape)

    env = Account(md_df, data_factors)
    env.reset()

    # Go short, then keep the position for half of the bars.
    env.step(2)
    for _ in range(len(md_df) // 2):
        env.step(3)

    # Go long, then keep the position until the episode ends.
    _, _, done = env.step(1)
    while not done:
        _, _, done = env.step(3)

    # Plot the account value series against the close price.
    value_s = env.plot_data().iloc[:, 0]
    from ibats_utils.mess import datetime_2_str
    from datetime import datetime
    dt_str = datetime_2_str(datetime.now(), '%Y-%m-%d %H_%M_%S')
    title = f'test_account_{dt_str}'
    from ibats_common.analysis.plot import plot_twin
    plot_twin(value_s, md_df["close"], name=title)