Example #1
0
def test_simulation_broker_register_tick_multi_symbol(
    sim_broker_runner_and_streamer_1s, ):
    """Ask ticks of two symbols reach their own subscribers.

    One receiver per symbol (SPY and SCHW) is subscribed to ask prices on
    the streamer from the ``sim_broker_runner_and_streamer_1s`` fixture.
    The simulation is advanced by four steps and the prices collected by
    each receiver are compared with the expected values.
    """
    _, runner, streamer = sim_broker_runner_and_streamer_1s

    # One price list per symbol; each is filled by its own callback.
    received = {"SPY": [], "SCHW": []}

    def spy_receiver(_, price):
        received["SPY"].append(price)

    def schw_receiver(_, price):
        received["SCHW"].append(price)

    subscriptions = (
        (StockContract(symbol="SPY"), spy_receiver),
        (StockContract(symbol="SCHW"), schw_receiver),
    )
    for contract, receiver in subscriptions:
        streamer.subscribe_to_tick_data(
            contract=contract,
            func=receiver,
            price_type=PriceType.ASK,
        )

    runner.run_sim(step_count=4)

    np.testing.assert_equal(received["SPY"], [413.53, 411.52])
    np.testing.assert_equal(received["SCHW"], [37.442, 37.442])
Example #2
0
def test_tick_data_delivery_order(sim_broker_runner_and_streamer_1s):
    """Check the relative delivery order of SPY and SCHW ask ticks.

    When the SPY receiver handles one of the two marker prices (406.9 or
    407.57), the SCHW receiver must already hold as many ticks as the SPY
    receiver does, counting the SPY tick being handled.  After nine
    simulation steps 17 SPY ticks and 18 SCHW ticks are expected.
    """
    _, runner, streamer = sim_broker_runner_and_streamer_1s
    spy_contract = StockContract(symbol="SPY")
    schw_contract = StockContract(symbol="SCHW")

    spy_ask = []
    schw_ask = []

    def spy_receiver(_, price):
        spy_ask.append(price)
        # Both marker prices carry the same expectation.
        if price == 406.9 or price == 407.57:
            assert len(spy_ask) == len(schw_ask)

    def schw_receiver(_, price):
        schw_ask.append(price)

    for contract, receiver in (
        (spy_contract, spy_receiver),
        (schw_contract, schw_receiver),
    ):
        streamer.subscribe_to_tick_data(
            contract=contract,
            func=receiver,
            price_type=PriceType.ASK,
        )

    runner.run_sim(step_count=9)

    assert len(spy_ask) == 17
    assert len(schw_ask) == 18
def test_cancel_tick_data(streamer):
    """A cancelled tick subscription stops delivering prices.

    Subscribes to SPY market-price ticks, checks that a price arrives and
    is refreshed, then cancels the subscription and checks that no further
    price arrives within ``AWAIT_TIME_OUT``.
    """
    mid = None

    def update_mid(c, price_):
        nonlocal mid
        mid = price_

    def wait_for_mid():
        # Poll until a price is stored or the timeout has elapsed.
        started = time.time()
        while mid is None and time.time() - started <= AWAIT_TIME_OUT:
            streamer.sleep()

    contract = StockContract(symbol="SPY")
    streamer.subscribe_to_tick_data(
        contract=contract,
        func=update_mid,
        price_type=PriceType.MARKET,
    )

    wait_for_mid()
    assert mid is not None

    mid = None
    wait_for_mid()
    assert mid is not None  # was refreshed

    streamer.cancel_tick_data(contract=contract, func=update_mid)
    streamer.sleep()
    mid = None
    wait_for_mid()
    assert mid is None  # did not refresh again
 def _trades_receiver(self, trade: Dict):
     """Forward one incoming trade message to the subscribers of its symbol.

     The ``"sym"`` entry of the message selects the contract, the message
     is converted with ``self._parse_trade`` and every callback stored for
     that contract is called with the result plus its stored keyword
     arguments.  The subscriber lock is held for the whole dispatch.
     """
     contract = StockContract(symbol=trade["sym"])
     parsed_tick = self._parse_trade(trade=trade)
     with self._trade_subscribers_lock:
         callbacks = self._trade_subscribers[contract]
         for callback, extra_kwargs in callbacks.items():
             callback(parsed_tick, **extra_kwargs)
def test_subscribe_to_bars_data(streamer, bar_size):
    """Bars keep arriving, and stay recent, after subscribing to SPY bars.

    Args:
        streamer: streamer under test (fixture).
        bar_size: duration of one bar as a ``timedelta`` (fixture).

    After one bar duration plus a 10 second margin a bar must have been
    delivered and must be recent.  After one more bar duration a newer bar
    must have been delivered.
    """
    latest = None

    def update(bar):
        nonlocal latest
        latest = bar

    # ``timedelta.seconds`` holds only the seconds field and ignores
    # ``days``; ``total_seconds()`` covers the whole bar duration.
    bar_seconds = int(bar_size.total_seconds())

    contract = StockContract(symbol="SPY")
    streamer.subscribe_to_bars(
        contract=contract,
        bar_size=bar_size,
        func=update,
        rth=False,
    )

    streamer.sleep(bar_seconds + 10)

    assert latest is not None
    assert latest.name > datetime.now() - bar_size - timedelta(seconds=10)

    prev = latest.copy()
    latest = None
    streamer.sleep(bar_seconds)

    assert latest is not None
    assert latest.name > datetime.now() - bar_size
    assert latest.name > prev.name
Example #6
0
def test_retrieving_intermittently_cached_intraday(tmpdir, provider):
    """Retrieve bars when only scattered days were requested before.

    Two single-day slices of the recent trading days are requested first,
    then the whole range; the result of the last request is validated
    against the full range.  The test returns early when the provider
    raises ``NotImplementedError``.
    """
    retriever = HistoricalRetriever(provider=provider, hist_data_dir=tmpdir,)

    data = pd.DataFrame()
    dates = generate_trading_days(
        start_date=date.today() - timedelta(days=10),
        end_date=date.today() - timedelta(days=1),
    )

    for date_range in (dates[1:2], dates[-3:-2], dates):
        first_day, last_day = date_range[0], date_range[-1]
        contract = StockContract(symbol="SPY")

        try:
            data = retriever.retrieve_bar_data(
                contract=contract,
                start_date=first_day,
                end_date=last_day,
                bar_size=timedelta(days=1),
            )
        except NotImplementedError:  # todo: fix
            return

    validate_data_range(data=data, start_date=dates[0], end_date=dates[-1])
Example #7
0
def test_retrieving_intermittently_cached_trades(tmpdir, provider):
    """Retrieve trades when only the first and last day were requested before.

    The first trading day, then the last one, then the whole range between
    2020-07-21 and 2020-07-23 are requested; the result of the last request
    is validated against the full range.  The test returns early when the
    provider raises ``NotImplementedError``.
    """
    retriever = HistoricalRetriever(provider=provider, hist_data_dir=tmpdir,)

    data = pd.DataFrame()
    dates = generate_trading_days(
        start_date=date(2020, 7, 21), end_date=date(2020, 7, 23),
    )

    for date_range in (dates[:1], dates[-1:], dates):
        first_day, last_day = date_range[0], date_range[-1]
        contract = StockContract(symbol="SPY")

        try:
            data = retriever.retrieve_trades_data(
                contract=contract, start_date=first_day, end_date=last_day,
            )
        except NotImplementedError:
            return

    validate_data_range(data=data, start_date=dates[0], end_date=dates[-1])
    def _validate_contract(contract: AContract) -> AContract:
        """Return a new ``StockContract`` carrying only the input's symbol.

        Raises:
            TypeError: if ``contract`` is not a ``StockContract``.
        """
        if not isinstance(contract, StockContract):
            raise TypeError(f"Unknown contract type {type(contract)}.")

        return StockContract(symbol=contract.symbol)
def test_cancel_bars_data(streamer):
    """A cancelled bar subscription stops delivering bars.

    Subscribes to 5-second SPY bars, waits 15 seconds and checks that a
    recent bar arrived.  It then cancels the subscription, waits another
    15 seconds and checks that the stored bar is still the same one.
    """
    wait_seconds = 15
    latest = None

    def update(bar):
        nonlocal latest
        latest = bar

    contract = StockContract(symbol="SPY")
    streamer.subscribe_to_bars(
        contract=contract,
        bar_size=timedelta(seconds=5),
        func=update,
        rth=False,
    )

    streamer.sleep(wait_seconds)

    assert latest is not None
    assert latest.name > datetime.now() - timedelta(seconds=wait_seconds)

    streamer.cancel_bars(contract=contract, func=update)
    bar_at_cancel = latest
    streamer.sleep(wait_seconds)

    assert bar_at_cancel.name == latest.name
Example #10
0
def test_simulation_broker_register_tick_resolution_fail(
    sim_broker_runner_and_streamer_15m, ):
    """Subscribing to tick data on the 15m fixture streamer raises ValueError."""
    spy_stock_contract = StockContract(symbol="SPY")
    _, _, streamer = sim_broker_runner_and_streamer_15m

    def ignore_tick(x):
        return None

    with pytest.raises(ValueError):
        streamer.subscribe_to_tick_data(
            contract=spy_stock_contract,
            func=ignore_tick,
        )
Example #11
0
def test_subscribe_to_position_updates():
    """Position subscribers receive an update after each simulated fill.

    A broker starts with 10 SPY shares at an average fill price of 345.6.
    A limit buy order for 2 shares is then filled one share at a time, at
    345.8 and 345.9.  After each fill and broker step the delivered
    position must show the new share count and the new average fill price.
    """
    con = StockContract(symbol="SPY")
    starting_positions = {
        DEFAULT_SIM_ACC: {
            con: Position(
                account=DEFAULT_SIM_ACC,
                contract=con,
                position=10,
                ave_fill_price=345.6,
            ),
        },
    }
    untethered_sim_broker = SimulationBroker(
        sim_streamer=SimulationDataStreamer(),
        starting_funds={Currency.USD: 10_000},
        transaction_cost=1,
        starting_positions=starting_positions,
    )

    pos_updates = []

    def pos_updates_receiver(pos_: Position):
        pos_updates.append(pos_)

    def check_update(update: Position, shares, ave_price):
        # The same four properties are checked after every fill.
        assert update.contract == con
        assert update.position == shares
        assert update.ave_fill_price == ave_price
        assert update.account == DEFAULT_SIM_ACC

    untethered_sim_broker.subscribe_to_position_updates(
        func=pos_updates_receiver,
    )
    sim_order = LimitOrder(action=OrderAction.BUY, quantity=2, limit_price=346)
    sim_trade = Trade(contract=con, order=sim_order)
    untethered_sim_broker.place_trade(trade=sim_trade)

    # First fill: one share at 345.8.
    untethered_sim_broker.simulate_trade_execution(
        trade=sim_trade, price=345.8, n_shares=1,
    )
    untethered_sim_broker.step()

    assert len(pos_updates) == 1

    target_ave_price = (10 * 345.6 + 345.8) / 11
    check_update(pos_updates[0], 11, target_ave_price)

    # Second fill: one more share at 345.9.
    untethered_sim_broker.simulate_trade_execution(
        trade=sim_trade, price=345.9, n_shares=1,
    )
    untethered_sim_broker.step()

    second_update = pos_updates[1]
    target_ave_price = (target_ave_price * 11 + 345.9) / 12
    check_update(second_update, 12, target_ave_price)
Example #12
0
 def __init__(self, start_date: date, end_date: date, bar_size: timedelta):
     """Load cached SPY bars for the given range and reset the bar index.

     Args:
         start_date: first day of the bars to load.
         end_date: last day of the bars to load.
         bar_size: size of the bars to load.

     The bars are read from ``TEST_DATA_DIR`` with ``cache_only=True``.
     """
     retriever = HistoricalRetriever(hist_data_dir=TEST_DATA_DIR)
     spy = StockContract(symbol="SPY")
     request = dict(
         contract=spy,
         bar_size=bar_size,
         start_date=start_date,
         end_date=end_date,
         cache_only=True,
     )
     self._sim_data = retriever.retrieve_bar_data(**request)
     self._sim_idx = 0
Example #13
0
def get_1_spy_mkt_trade(buy: bool) -> Trade:
    """Build a market-order trade for one share of SPY.

    Args:
        buy: truthy for a buy order, falsy for a sell order.

    Returns:
        A ``Trade`` combining the SPY stock contract with the market order.
    """
    contract = StockContract(symbol="SPY")
    side = OrderAction.BUY if buy else OrderAction.SELL
    return Trade(
        contract=contract,
        order=MarketOrder(action=side, quantity=1),
    )
Example #14
0
def test_simulation_broker_init():
    """A new simulation broker reports its starting cash, position and fee."""
    spy_stock_contract = StockContract(symbol="SPY")
    broker = SimulationBroker(
        sim_streamer=SimulationDataStreamer(),
        starting_funds={Currency.USD: 1_000},
        transaction_cost=1,
    )

    usd_cash = broker.acc_cash[Currency.USD]
    assert usd_cash == 1_000

    spy_position = broker.get_position(contract=spy_stock_contract)
    assert spy_position == 0

    fee = broker.get_transaction_fee()
    assert fee == 1
def test_cancel_trades_data(streamer):
    """A cancelled trades subscription stops delivering ticks.

    Subscribes to SPY trades and checks that a tick arrives within a
    second and that another arrives during the following second.  It then
    cancels the subscription and checks that no tick arrives afterwards.
    """
    tick = None
    received = threading.Event()

    def receiver(t):
        nonlocal tick

        tick = t
        if not received.is_set():
            received.set()

    streamer.subscribe_to_trades(
        contract=StockContract(symbol="SPY"), func=receiver,
    )

    # Wait up to one second for the first tick.
    received.wait(1)
    assert tick is not None

    # A further tick must arrive during the next second.
    tick = None
    time.sleep(1)
    assert tick is not None

    streamer.cancel_trades(
        contract=StockContract(symbol="SPY"), func=receiver,
    )

    # Let the cancellation settle before clearing the stored tick.
    time.sleep(1)
    tick = None
    time.sleep(1)
    assert tick is None
Example #16
0
    def _from_ib_contract(self, ib_contract: _IBContract):
        """Convert an IB contract object into the matching project contract.

        Stock, option and forex IB contracts are supported.  A ``conId`` of
        0 is passed on as ``None``.

        Returns:
            A ``StockContract``, ``OptionContract`` or ``ForexContract``, or
            ``None`` (after logging a warning) for any other contract type.

        Raises:
            ValueError: if an option's ``right`` is neither "C" nor "P".
        """
        exchange = self._from_ib_exchange(ib_exchange=ib_contract.exchange)
        currency = self._from_ib_currency(ib_currency=ib_contract.currency)
        raw_con_id = ib_contract.conId
        con_id = None if raw_con_id == 0 else raw_con_id

        if isinstance(ib_contract, _IBStock):
            return StockContract(
                con_id=con_id,
                symbol=ib_contract.symbol,
                exchange=exchange,
                currency=currency,
            )

        if isinstance(ib_contract, _IBOption):
            last_trade_date = _get_opt_trade_date(
                last_trade_date_str=ib_contract.lastTradeDateOrContractMonth,
            )
            ib_right = ib_contract.right
            if ib_right == "C":
                right = Right.CALL
            elif ib_right == "P":
                right = Right.PUT
            else:
                raise ValueError(f"Unknown right type {ib_right}.")
            return OptionContract(
                con_id=con_id,
                symbol=ib_contract.symbol,
                strike=ib_contract.strike,
                right=right,
                multiplier=float(ib_contract.multiplier),
                last_trade_date=last_trade_date,
                exchange=exchange,
                currency=currency,
            )

        if isinstance(ib_contract, _IBForex):
            return ForexContract(
                symbol=ib_contract.symbol,
                con_id=con_id,
                exchange=exchange,
                currency=currency,
            )

        # Unsupported type: report it and hand back no contract.
        logging.warning(
            f"Contract type {ib_contract.secType} not understood."
            f" No contract was built."
        )
        return None
Example #17
0
def test_retrieve_non_cached_trades_data(tmpdir, provider):
    """Retrieve SPY trades for 2020-07-22..23 and validate the range.

    The test returns early when the provider raises ``NotImplementedError``.
    """
    date_span = dict(start_date=date(2020, 7, 22), end_date=date(2020, 7, 23))

    retriever = HistoricalRetriever(provider=provider, hist_data_dir=tmpdir)
    contract = StockContract(symbol="SPY")

    try:
        data = retriever.retrieve_trades_data(contract=contract, **date_span)
    except NotImplementedError:
        return

    validate_data_range(data=data, **date_span)
Example #18
0
def test_simulation_broker_register_daily(sim_broker_runner_and_streamer_15m):
    """Daily SPY bars are delivered to a subscriber during the simulation.

    A ``BarChecker`` for 2020-04-06..07 with one-day bars is subscribed as
    the bar receiver; after the simulation has run, it is asked to confirm
    that all bars were received.
    """
    _, runner, streamer = sim_broker_runner_and_streamer_15m
    one_day = timedelta(days=1)

    checker = BarChecker(
        start_date=date(2020, 4, 6),
        end_date=date(2020, 4, 7),
        bar_size=one_day,
    )
    streamer.subscribe_to_bars(
        contract=StockContract(symbol="SPY"),
        bar_size=one_day,
        func=checker.bar_receiver,
    )

    runner.run_sim()
    checker.assert_all_received()
Example #19
0
def test_retrieve_cached_trades():
    """Cached SPY trades for 2020-06-17..19 load and cover the range."""
    date_span = dict(start_date=date(2020, 6, 17), end_date=date(2020, 6, 19))

    retriever = HistoricalRetriever(hist_data_dir=TEST_DATA_DIR)
    data = retriever.retrieve_trades_data(
        contract=StockContract(symbol="SPY"),
        cache_only=True,
        **date_span,
    )

    assert len(data) != 0

    validate_data_range(data=data, **date_span)
Example #20
0
def test_retriever_cached_daily():
    """Cached daily SPY bars for 2020-04-01..02 load and cover the range."""
    date_span = dict(start_date=date(2020, 4, 1), end_date=date(2020, 4, 2))

    retriever = HistoricalRetriever(hist_data_dir=TEST_DATA_DIR)
    data = retriever.retrieve_bar_data(
        contract=StockContract(symbol="SPY"),
        bar_size=timedelta(days=1),
        cache_only=True,
        **date_span,
    )

    assert len(data) != 0

    validate_data_range(data=data, **date_span)
Example #21
0
def test_retrieve_non_cached_daily(tmpdir, provider):
    """Retrieve daily SPY bars for 2020-04-01..02 and validate the range.

    The test returns early when the provider raises ``NotImplementedError``.
    """
    date_span = dict(start_date=date(2020, 4, 1), end_date=date(2020, 4, 2))

    retriever = HistoricalRetriever(provider=provider, hist_data_dir=tmpdir,)
    contract = StockContract(symbol="SPY")

    try:
        data = retriever.retrieve_bar_data(
            contract=contract,
            bar_size=timedelta(days=1),
            **date_span,
        )
    except NotImplementedError:
        return

    validate_data_range(data=data, **date_span)
def test_subscribe_to_tick_data(streamer):
    """Ask, bid and market-price ticks are delivered for SPY.

    Three receivers are subscribed, one per price type.  The test polls
    until every receiver has stored a value or ``AWAIT_TIME_OUT`` has
    elapsed, then checks the contract passed to the ask receiver and the
    relation between the three prices.
    """
    con = None
    ask = None
    bid = None
    mid = None

    def update_ask(contract_, price_):
        nonlocal con, ask
        con = contract_
        ask = price_

    def update_bid(c, price_):
        nonlocal bid
        bid = price_

    def update_mid(c, price_):
        nonlocal mid
        mid = price_

    contract = StockContract(symbol="SPY")
    streamer.subscribe_to_tick_data(
        contract=contract,
        func=update_ask,
        price_type=PriceType.ASK,
    )
    streamer.subscribe_to_tick_data(
        contract=contract,
        func=update_bid,
        price_type=PriceType.BID,
    )
    streamer.subscribe_to_tick_data(
        contract=contract,
        func=update_mid,
        price_type=PriceType.MARKET,
    )

    # Keep polling while ANY value is still missing.  Joining the checks
    # with ``and`` would stop after the first delivery and leave the other
    # values as None for the assertions below.
    t0 = time.time()
    while (
        (con is None or ask is None or bid is None or mid is None)
        and time.time() - t0 <= AWAIT_TIME_OUT
    ):
        streamer.sleep()

    # Fail with a clear message instead of a TypeError on a None operand.
    assert ask is not None, "no ask price received before the timeout"
    assert bid is not None, "no bid price received before the timeout"
    assert mid is not None, "no market price received before the timeout"

    assert con == contract
    assert ask > bid
    assert mid == (ask + bid) / 2
Example #23
0
def test_simulation_broker_buy(
    sim_broker_runner_and_streamer_15m, buy_1_spy_mkt_trade,
):
    """Buying one SPY share at market lowers cash and opens a position.

    The simulation advances one step, the buy trade is placed, and the
    simulation advances one more step.  Cash must then be 1000 minus the
    fill price of 259.79 minus a fee of 1, and the SPY position must be 1.
    """
    broker, runner, _ = sim_broker_runner_and_streamer_15m
    spy_stock_contract = StockContract(symbol="SPY")

    assert broker.get_position(contract=spy_stock_contract) == 0

    runner.run_sim(step_count=1, cache_only=True)
    broker.place_trade(trade=buy_1_spy_mkt_trade)
    runner.run_sim(step_count=1, cache_only=True)

    spy_2020_4_6_9_45_open = 259.79
    expected_cash = 1000 - spy_2020_4_6_9_45_open - 1
    actual_cash = broker.acc_cash[Currency.USD]

    assert np.isclose(actual_cash, expected_cash)
    assert broker.get_position(contract=spy_stock_contract) == 1
Example #24
0
def test_retrieve_non_cached_trades_data_today_partial(tmpdir, provider):
    """Retrieve SPY trades from yesterday up to today, allowing partial data.

    The data must cover the requested range and its last row must be dated
    today.  The test returns early when the provider raises
    ``NotImplementedError``.
    """
    end_date = date.today()
    start_date = end_date - timedelta(days=1)
    date_span = dict(start_date=start_date, end_date=end_date)

    retriever = HistoricalRetriever(provider=provider, hist_data_dir=tmpdir)
    contract = StockContract(symbol="SPY")

    try:
        data = retriever.retrieve_trades_data(
            contract=contract,
            allow_partial=True,
            rth=False,
            **date_span,
        )
    except NotImplementedError:
        return

    validate_data_range(data=data, **date_span)

    last_row = data.iloc[-1]
    assert last_row.name.date() == date.today()
Example #25
0
def test_simulation_broker_register_tick_single_symbol(
    sim_broker_runner_and_streamer_1s, ):
    """Market, ask and bid ticks of SPY reach their own subscribers.

    The first receiver is subscribed without a ``price_type`` argument,
    the other two with ``PriceType.ASK`` and ``PriceType.BID``.  After four
    simulation steps the prices collected by each receiver are compared
    with the expected values.
    """
    spy_stock_contract = StockContract(symbol="SPY")
    _, runner, streamer = sim_broker_runner_and_streamer_1s

    prices = {"mkt": [], "ask": [], "bid": []}

    def mkt_receiver(_, price):
        prices["mkt"].append(price)

    def ask_receiver(_, price):
        prices["ask"].append(price)

    def bid_receiver(_, price):
        prices["bid"].append(price)

    # The first subscription passes no price type on purpose.
    subscriptions = (
        (mkt_receiver, {}),
        (ask_receiver, {"price_type": PriceType.ASK}),
        (bid_receiver, {"price_type": PriceType.BID}),
    )
    for receiver, extra_kwargs in subscriptions:
        streamer.subscribe_to_tick_data(
            contract=spy_stock_contract,
            func=receiver,
            **extra_kwargs,
        )

    runner.run_sim(step_count=4)

    np.testing.assert_equal(prices["mkt"], [409.99, 410.385])
    np.testing.assert_equal(prices["ask"], [413.53, 411.52])
    np.testing.assert_equal(prices["bid"], [406.45, 409.25])
Example #26
0
def test_simulation_broker_sell(
    sim_broker_runner_and_streamer_15m, sell_1_spy_mkt_trade,
):
    """Selling one SPY share at market raises cash by about the fill price.

    The simulation advances one step, the sell trade is placed, and the
    simulation advances one more step.  Cash must then be close to 1000
    plus 259.79 minus a fee of 1, within a relative tolerance of 0.01.
    """
    broker, runner, _ = sim_broker_runner_and_streamer_15m
    spy_stock_contract = StockContract(symbol="SPY")

    assert broker.get_position(contract=spy_stock_contract) == 0

    runner.run_sim(step_count=1)
    broker.place_trade(trade=sell_1_spy_mkt_trade)
    runner.run_sim(step_count=1)

    spy_2020_4_6_9_30_close = 259.79
    expected_cash = 1000 + spy_2020_4_6_9_30_close - 1

    # The third positional argument of ``np.isclose`` is ``rtol``.
    assert np.isclose(
        broker.acc_cash[Currency.USD], expected_cash, rtol=0.01,
    )
def test_subscribe_to_trades_data(streamer):
    """A trades subscription delivers a ``Tick`` for the SPY symbol.

    Only the first delivered tick is kept; the test waits up to two
    seconds for it.
    """
    tick = None
    received = threading.Event()

    def receiver(t):
        nonlocal tick

        # Ignore every delivery after the first one.
        if received.is_set():
            return
        tick = t
        received.set()

    streamer.subscribe_to_trades(
        contract=StockContract(symbol="SPY"), func=receiver,
    )

    received.wait(2)

    assert isinstance(tick, Tick)
    assert tick.symbol == "SPY"
Example #28
0
def test_historical_bar_aggregator():
    """Aggregate cached 1-minute SPY bars into 5- and 10-minute bars.

    For 2020-04-06..07 the 5-minute aggregation must contain 156 bars, and
    its first bar must match the first five 1-minute bars: their first
    open, fifth close, highest high, lowest low and summed volume.  The
    10-minute aggregation must contain 78 bars.
    """
    date_span = dict(start_date=date(2020, 4, 6), end_date=date(2020, 4, 7))
    one_minute = timedelta(minutes=1)

    retriever = HistoricalRetriever(hist_data_dir=TEST_DATA_DIR)
    contract = StockContract(symbol="SPY")
    base_data = retriever.retrieve_bar_data(
        contract=contract,
        bar_size=one_minute,
        cache_only=True,
        **date_span,
    )

    aggregator = HistoricalAggregator(hist_data_dir=TEST_DATA_DIR)

    def aggregate(minutes):
        # Aggregate the 1-minute base bars into bars of ``minutes`` minutes.
        return aggregator.aggregate_data(
            contract=contract,
            base_bar_size=one_minute,
            target_bar_size=timedelta(minutes=minutes),
            **date_span,
        )

    agg_data = aggregate(5)
    assert len(agg_data) == 156

    first_agg_bar = agg_data.iloc[0]
    first_five = base_data.iloc[:5]
    assert first_agg_bar["open"] == base_data.iloc[0]["open"]
    assert first_agg_bar["close"] == base_data.iloc[4]["close"]
    assert first_agg_bar["high"] == first_five["high"].max()
    assert first_agg_bar["low"] == first_five["low"].min()
    assert first_agg_bar["volume"] == first_five["volume"].sum()

    agg_data = aggregate(10)
    assert len(agg_data) == 78
Example #29
0
def test_retrieving_intermittently_cached_daily(tmpdir, provider):
    """Retrieve daily bars when only scattered days were requested before.

    2020-03-03 and 2020-03-05 are requested on their own first, then the
    range 2020-03-02..06, whose result is validated.  The test returns
    early when the first request raises ``NotImplementedError``.
    """
    retriever = HistoricalRetriever(provider=provider, hist_data_dir=tmpdir)
    contract = StockContract(symbol="SPY")

    def fetch_daily(first_day, last_day):
        # Request daily SPY bars for the inclusive day range.
        return retriever.retrieve_bar_data(
            contract=contract,
            start_date=first_day,
            end_date=last_day,
            bar_size=timedelta(days=1),
        )

    try:
        fetch_daily(date(2020, 3, 3), date(2020, 3, 3))
    except NotImplementedError:  # todo: fix
        return

    fetch_daily(date(2020, 3, 5), date(2020, 3, 5))

    start_date = date(2020, 3, 2)
    end_date = date(2020, 3, 6)
    data = fetch_daily(start_date, end_date)

    validate_data_range(data=data, start_date=start_date, end_date=end_date)
Example #30
0
def test_retrieve_non_cached_intraday(tmpdir, provider):
    """Retrieve 1-minute SPY bars for the past week and check their count.

    The range runs from seven days ago to yesterday.  The number of rows
    must be within ``7 * 60`` of ``5 * 6.5 * 60``, or equal to 4800.  The
    test returns early when the provider raises ``NotImplementedError``.
    """
    start_date = date.today() - timedelta(days=7)
    end_date = date.today() - timedelta(days=1)
    date_span = dict(start_date=start_date, end_date=end_date)

    retriever = HistoricalRetriever(provider=provider, hist_data_dir=tmpdir)
    contract = StockContract(symbol="SPY")

    try:
        data = retriever.retrieve_bar_data(
            contract=contract,
            bar_size=timedelta(minutes=1),
            **date_span,
        )
    except NotImplementedError:
        return

    validate_data_range(data=data, **date_span)

    expected_rows = 5 * 6.5 * 60
    tolerance = 7 * 60
    assert (
        np.isclose(len(data), expected_rows, atol=tolerance)
        or len(data) == 4800  # for outside RTHs IB
    )