예제 #1
0
def test_namespace():
    """Streams wrapped in a Namespace are emitted under a ``world:/`` prefix.

    ``t1`` multiplies ``a1`` and ``a2`` element-wise (7 * 3 == 21 on the first
    step) and keeps its bare name, while ``a1``/``a2`` are also reported via the
    ``world`` namespace. After ``reset()`` the feed replays the same first row.
    """
    left = Array("a1", [7, 8, 9])
    right = Array("a2", [3, 2, 1])

    product = BinOp("t1", operator.mul)(left, right)
    world = Namespace("world")(left, right)

    feed = DataFeed([product, world])
    expected = {"world:/a1": 7, "world:/a2": 3, "t1": 21}

    # The first observation must be reproducible after a reset.
    assert feed.next() == expected
    feed.reset()
    assert feed.next() == expected
예제 #2
0
def test_runs_with__external_feed_only(portfolio):
    """Run a full episode driven only by an external data feed.

    Builds per-asset frames (BTC and ETH) from the last 100 rows of the Coinbase
    CSV fixture and appends technical-analysis features to each with ``ta``.
    Every resulting column is exposed as a ``Stream`` under the ``coinbase``
    namespace. A ``TradingEnvironment`` is then stepped with
    ``use_internal=False`` using random actions until the episode ends.

    Asserts that the final observation has shape ``(50, n_features)``: 50 is
    the configured ``window_size``, and ``n_features`` is the total number of
    BTC and ETH columns fed in.

    Args:
        portfolio: presumably a pytest fixture providing the portfolio passed
            to the environment (defined outside this view).
    """
    df = pd.read_csv("tests/data/input/coinbase_(BTC,ETH)USD_d.csv").tail(100)
    df = df.rename({"Unnamed: 0": "date"}, axis=1)
    df = df.set_index("date")

    ohlcv = ['open', 'high', 'low', 'close', 'volume']

    # Select each asset's columns and take an explicit copy. The TA columns
    # added below are then written to an independent frame rather than to a
    # slice of `df`, which can trigger pandas' SettingWithCopyWarning.
    coinbase_btc = df.loc[:, [name.startswith("BTC") for name in df.columns]].copy()
    coinbase_eth = df.loc[:, [name.startswith("ETH") for name in df.columns]].copy()

    # Use the frame `ta` returns instead of relying on it mutating its argument.
    coinbase_btc = ta.add_all_ta_features(
        coinbase_btc,
        colprefix="BTC:",
        **{k: "BTC:" + k for k in ohlcv}
    )
    coinbase_eth = ta.add_all_ta_features(
        coinbase_eth,
        colprefix="ETH:",
        **{k: "ETH:" + k for k in ohlcv}
    )

    # One stream per column: all BTC columns first, then all ETH columns.
    nodes = [
        Stream(name, list(frame[name]))
        for frame in (coinbase_btc, coinbase_eth)
        for name in frame.columns
    ]
    coinbase = Namespace("coinbase")(*nodes)
    feed = DataFeed([coinbase])

    action_scheme = ManagedRiskOrders()
    reward_scheme = SimpleProfit()

    env = TradingEnvironment(
        portfolio=portfolio,
        action_scheme=action_scheme,
        reward_scheme=reward_scheme,
        feed=feed,
        window_size=50,
        use_internal=False,
        enable_logger=False
    )

    # Take random actions until the environment reports the episode is done.
    done = False
    obs = env.reset()
    while not done:

        action = env.action_space.sample()
        obs, reward, done, info = env.step(action)

    # The observation window covers every column streamed into the feed.
    n_features = coinbase_btc.shape[1] + coinbase_eth.shape[1]
    assert obs.shape == (50, n_features)
예제 #3
0
def test_select():
    """Select picks a single named stream out of the feed's outputs.

    ``a1`` is wrapped in the ``world`` namespace, so it is addressed as
    ``"world:/a1"``. The feed built from the selection emits only that value;
    ``t1`` and ``world:/a2`` are not part of its output.
    """
    a1 = Array("a1", [7, 8, 9])
    a2 = Array("a2", [3, 2, 1])

    t1 = BinOp("t1", operator.mul)(a1, a2)
    a = Namespace("world")(a1, a2)

    s = Select("world:/a1")(t1, a)
    feed = DataFeed([s])

    assert feed.next() == {"world:/a1": 7}