コード例 #1
0
 def options(self, expiry, strike, right, database="Quotes", table=None):
     """Create an option-quote client, publish it as module-level ``app``, connect and run it.

     expiry, strike, right: option contract fields forwarded to ``Wrapper``.
     database: target database name forwarded to ``Wrapper``.
     table: target table; defaults to ``self.symbol`` when omitted.

     Connects to 127.0.0.1:4001 with clientId 123 and then calls ``app.run()`` in
     the calling thread.
     """
     global app
     # Identity check: ``== None`` would defer to a custom __eq__ on ``table``.
     if table is None:
         table = self.symbol
     app = EClient(
         Wrapper(self.qdate, self.symbol, self.currency, database, table,
                 expiry, strike, right, self.window))
     app.connect(host="127.0.0.1", port=4001, clientId=123)
     app.run()
コード例 #2
0
ファイル: IBRoute.py プロジェクト: jecker7/kazaam
 def __init__(self):
     """Initialise both API base classes, connect, and start the message loop in a thread.

     NOTE(review): ``addr``, ``port`` and ``client_id`` are not defined in this
     snippet; they are presumably module-level settings -- confirm.
     """
     EWrapper.__init__(self)
     # This object is its own wrapper; EClient.__init__ requires it explicitly.
     EClient.__init__(self, wrapper=self)
     self.order_id = None
     self.connect(addr, port, client_id)
     # Process incoming messages in the background so the constructor returns.
     # (The former trailing ``EClient(wrapper).connect("")`` created a stray
     # second client and always raised: connect() needs host, port and clientId.)
     thread = Thread(target=self.run)
     thread.start()
コード例 #3
0
 def equities(self, database="Quotes", table="Spot"):
     """Build a Wrapper-backed EClient for this symbol and run it.

     The client is stored in the module-level ``app``, connected to
     127.0.0.1:4001 with clientId 123, and its message loop is run in the
     calling thread.
     """
     global app
     quote_wrapper = Wrapper(self.qdate, self.symbol, self.currency, database,
                             table, window=self.window)
     app = EClient(quote_wrapper)
     app.connect(host="127.0.0.1", port=4001, clientId=123)
     app.run()
コード例 #4
0
 def __init__(self, host, port, client_id, initial_cash: float):
     """Set up broker bookkeeping, connect the IB client and start its reader thread.

     host, port, client_id: TWS/IB Gateway connection parameters.
     initial_cash: forwarded to the base class constructor.
     """
     import threading

     super().__init__(initial_cash)
     self.client = EClient(self)

     # Empty containers and id slot for order/position tracking.
     self.positions = {}
     self.orders: List[Order] = []
     self._next_valid_id = None
     self.ib_order_id_to_order: Dict[int, Order] = {}
     self.contract_code_to_detail = {}

     self.client.connect(host, port, client_id)
     consumer = threading.Thread(name="ib_msg_consumer", target=self.client.run)
     consumer.start()
コード例 #5
0
class IBBroker(Broker):
    """
    Interactive Brokers Broker class. Main purpose of this class is to connect to the API of IB broker and send
    the orders. It provides the functionality, which allows to retrieve a.o. the currently open positions and the
    value of the portfolio.

    Request methods hold ``self.lock`` for the whole request/response round trip: they clear
    ``self.action_event_lock``, send the request and then wait until the wrapper sets the event again
    (or the timeout expires).

    Parameters
    -----------
    contract_ticker_mapper: IBContractTickerMapper
        mapper which provides the functionality that allows to map a ticker from any data provider
        (BloombergTicker, PortaraTicker etc.) onto the contract object from the Interactive Brokers API
    clientId: int
        id of the Broker client
    host: str
        IP address
    port: int
        socket port

    Raises
    -------
    ConnectionError
        if the wrapper does not signal readiness (nextValidOrderId) within ``waiting_time`` seconds
    """
    def __init__(self,
                 contract_ticker_mapper: IBContractTickerMapper,
                 clientId: int = 0,
                 host: str = "127.0.0.1",
                 port: int = 7497):
        super().__init__(contract_ticker_mapper)
        self.logger = ib_logger.getChild(self.__class__.__name__)
        # Lock that synchronizes entries into the functions and makes sure we have a synchronous communication
        # with the client
        self.lock = Lock()
        self.orders_placement_lock = Lock()
        self.waiting_time = 30  # expressed in seconds
        # Lock that informs us that wrapper received the response
        self.action_event_lock = Event()
        self.wrapper = IBWrapper(self.action_event_lock,
                                 contract_ticker_mapper)
        self.client = EClient(wrapper=self.wrapper)
        self.clientId = clientId
        self.client.connect(host, port, self.clientId)

        # Run the client in the separate thread so that the execution of the program can go on
        # now we will have 3 threads:
        # - thread of the main program
        # - thread of the client
        # - thread of the wrapper
        thread = Thread(target=self.client.run)
        thread.start()

        # This will be released after the client initialises and wrapper receives the nextValidOrderId
        if not self._wait_for_results():
            raise ConnectionError("IB Broker was not initialized correctly")

    def get_portfolio_value(self) -> float:
        """ Return the NetLiquidation account-summary value; raises BrokerException on timeout. """
        with self.lock:
            request_id = 1
            self._reset_action_lock()
            self.client.reqAccountSummary(request_id, 'All', 'NetLiquidation')
            wait_result = self._wait_for_results()
            self.client.cancelAccountSummary(request_id)

            if wait_result:
                return self.wrapper.net_liquidation
            else:
                error_msg = 'Time out while getting portfolio value'
                self.logger.error(error_msg)
                raise BrokerException(error_msg)

    def get_portfolio_tag(self, tag: str) -> float:
        """ Return the value of the account-summary ``tag`` (read from ``wrapper.tmp_value``);
        raises BrokerException on timeout. """
        with self.lock:
            request_id = 2
            self._reset_action_lock()
            self.client.reqAccountSummary(request_id, 'All', tag)
            wait_result = self._wait_for_results()
            self.client.cancelAccountSummary(request_id)

            if wait_result:
                return self.wrapper.tmp_value
            else:
                error_msg = 'Time out while getting portfolio tag: {}'.format(
                    tag)
                self.logger.error(error_msg)
                raise BrokerException(error_msg)

    def get_positions(self) -> List[BrokerPosition]:
        """ Return the positions collected by the wrapper; raises BrokerException on timeout. """
        with self.lock:
            self._reset_action_lock()
            self.wrapper.reset_position_list()
            self.client.reqPositions()

            if self._wait_for_results():
                return self.wrapper.position_list
            else:
                error_msg = 'Time out while getting positions'
                self.logger.error(error_msg)
                raise BrokerException(error_msg)

    def get_liquid_hours(self, contract: IBContract) -> QFDataFrame:
        """ Returns a QFDataFrame containing information about liquid hours of the given contract.

        Each ";"-separated trading-hours session that is not "CLOSED" is split on "-" into FROM/TO
        columns, parsed with the "%Y%m%d:%H%M" format. The frame's ``name`` is the contract symbol.
        Raises BrokerException on timeout. """
        with self.lock:
            self._reset_action_lock()
            request_id = 3
            self.client.reqContractDetails(request_id, contract)

            if self._wait_for_results():
                contract_details = self.wrapper.contract_details
                liquid_hours = contract_details.tradingHours.split(";")
                liquid_hours_df = QFDataFrame.from_records(
                    [hours.split("-") for hours in liquid_hours if not hours.endswith("CLOSED")],
                    columns=["FROM", "TO"])
                for col in liquid_hours_df.columns:
                    liquid_hours_df[col] = to_datetime(liquid_hours_df[col],
                                                       format="%Y%m%d:%H%M")

                liquid_hours_df.name = contract_details.contract.symbol
                return liquid_hours_df

            else:
                error_msg = 'Time out while getting contract details'
                self.logger.error(error_msg)
                raise BrokerException(error_msg)

    def get_contract_details(self, contract: IBContract) -> ContractDetails:
        """ Return the ContractDetails received for ``contract``; raises BrokerException on timeout. """
        with self.lock:
            self._reset_action_lock()
            request_id = 4
            self.client.reqContractDetails(request_id, contract)

            if self._wait_for_results():
                return self.wrapper.contract_details
            else:
                error_msg = 'Time out while getting contract details'
                self.logger.error(error_msg)
                raise BrokerException(error_msg)

    def place_orders(self, orders: Sequence[Order]) -> Sequence[int]:
        """ Place ``orders`` one by one and return their IB order ids, in the same order.

        If placement is not confirmed in time, the id is recovered by diffing the open orders against those
        that existed before (or were placed earlier in this call). Raises BrokerException when no id can be
        determined for an order. """
        with self.orders_placement_lock:
            open_order_ids = {o.id for o in self.get_open_orders()}

            order_ids_list = []
            for order in orders:
                self.logger.info('Placing Order: {}'.format(order))
                # Explicit None checks: a valid id of 0 must not be mistaken for a failed placement.
                order_id = self._execute_single_order(order)
                if order_id is None:
                    order_id = self._find_newly_added_order_id(
                        order, open_order_ids)
                if order_id is None:
                    error_msg = f"Not able to place order: {order}"
                    self.logger.error(error_msg)
                    raise BrokerException(error_msg)
                # Remember this id so an identical later order in the same batch can still be disambiguated.
                open_order_ids.add(order_id)
                order_ids_list.append(order_id)
            return order_ids_list

    def cancel_order(self, order_id: int):
        """ Cancel ``order_id``; raises OrderCancellingException if the wrapper does not confirm in time. """
        with self.lock:
            self.logger.info('Cancel order: {}'.format(order_id))
            self._reset_action_lock()
            self.wrapper.set_cancel_order_id(order_id)
            self.client.cancelOrder(order_id)

            if not self._wait_for_results():
                error_msg = 'Time out while cancelling order id {} : \n'.format(
                    order_id)
                self.logger.error(error_msg)
                raise OrderCancellingException(error_msg)

    def get_open_orders(self) -> List[Order]:
        """ Return the open orders collected by the wrapper; raises BrokerException on timeout. """
        with self.lock:
            self._reset_action_lock()
            self.wrapper.reset_order_list()
            self.client.reqOpenOrders()

            if self._wait_for_results():
                return self.wrapper.order_list
            else:
                error_msg = 'Timeout while getting open orders'
                self.logger.error(error_msg)
                raise BrokerException(error_msg)

    def cancel_all_open_orders(self):
        """
        There is no way to check if cancelling of all orders was finished.
        One can only get open orders and confirm that the list is empty
        """
        with self.lock:
            self.client.reqGlobalCancel()
            self.logger.info('cancel_all_open_orders')

    def stop(self):
        """ Stop the Broker client and disconnect from the interactive brokers. """
        with self.lock:
            self.client.disconnect()
            self.logger.info(
                "Disconnecting from the interactive brokers client")

    def _find_newly_added_order_id(self, order: Order,
                                   order_ids_existing_before: Set[int]):
        """ Given the set of order ids open before placing the given order, try to compute the id of the recently
         placed order. Returns None unless exactly one new open order equals ``order``. """
        orders_matching_given_order = {
            o.id
            for o in self.get_open_orders() if o == order
        }
        order_ids = orders_matching_given_order.difference(
            order_ids_existing_before)
        return next(iter(order_ids)) if len(order_ids) == 1 else None

    def _execute_single_order(self, order) -> Optional[int]:
        """ Send one order to IB and return its id if the wrapper confirms within 10 seconds, else None.

        Raises ValueError for unsupported time-in-force or execution styles. """
        with self.lock:
            # Translate first so an unsupported order fails before an order id is consumed
            # and before the wrapper is told to wait for it.
            ib_contract = self.contract_ticker_mapper.ticker_to_contract(
                order.ticker)
            ib_order = self._to_ib_order(order)

            order_id = self.wrapper.next_order_id()
            self._reset_action_lock()
            self.wrapper.set_waiting_order_id(order_id)

            self.client.placeOrder(order_id, ib_contract, ib_order)
            if self._wait_for_results(10):
                return order_id
            return None

    def _wait_for_results(self, waiting_time: Optional[int] = None) -> bool:
        """ Block until the action event is set or the timeout expires.

        waiting_time: timeout in seconds; None means ``self.waiting_time`` (an explicit 0 is honoured).
        Returns True if the event was set, False on timeout. """
        if waiting_time is None:
            waiting_time = self.waiting_time
        return self.action_event_lock.wait(waiting_time)

    def _reset_action_lock(self):
        """ threads calling wait() will block until set() is called"""
        self.action_event_lock.clear()

    def _to_ib_order(self, order: Order):
        """ Convert a qf Order into an ibapi Order (action, quantity, order type, time in force). """
        ib_order = IBOrder()
        ib_order.action = 'BUY' if order.quantity > 0 else 'SELL'
        ib_order.totalQuantity = abs(order.quantity)

        ib_order = self._set_execution_style(ib_order, order.execution_style)

        time_in_force = order.time_in_force
        tif_str = self._map_to_tif_str(time_in_force)
        ib_order.tif = tif_str

        return ib_order

    def _map_to_tif_str(self, time_in_force):
        """ Map a TimeInForce value to IB's tif string; raises ValueError for unsupported values. """
        if time_in_force == TimeInForce.GTC:
            tif_str = "GTC"
        elif time_in_force == TimeInForce.DAY:
            tif_str = "DAY"
        elif time_in_force == TimeInForce.OPG:
            tif_str = "OPG"
        else:
            raise ValueError("Not supported TimeInForce {tif:s}".format(
                tif=str(time_in_force)))

        return tif_str

    def _set_execution_style(self, ib_order, execution_style):
        """ Set orderType (and auxPrice for stop orders) on ``ib_order`` from the execution style.

        Raises ValueError for styles without a mapping (consistent with _map_to_tif_str) instead of sending
        an order with no order type to IB. """
        if isinstance(execution_style, MarketOrder):
            ib_order.orderType = "MKT"
        elif isinstance(execution_style, StopOrder):
            ib_order.orderType = "STP"
            ib_order.auxPrice = execution_style.stop_price
        else:
            raise ValueError("Not supported execution style {style:s}".format(
                style=str(execution_style)))
        return ib_order
コード例 #6
0
        print("ContractDetailsEnd. ", reqId, "\n")

    def tickPrice(self, reqId, tickType, price, attrib):
        """Delegate to the base handler, then print the tick and its attribute flags."""
        super().tickPrice(reqId, tickType, price, attrib)
        fields = (
            "Tick Price. Ticker Id:", reqId,
            "tickType:", tickType,
            "Price:", price,
            "CanAutoExecute:", attrib.canAutoExecute,
            "PastLimit", attrib.pastLimit,
        )
        print(*fields)

    def tickSnapshotEnd(self, reqId):
        """Delegate to the base handler, then print a marker for ``reqId``."""
        base = super()
        base.tickSnapshotEnd(reqId)
        print("TickSnapshotEnd:", reqId)


# Connect a MyWrapper-backed client to TWS on localhost:7497 and report the session.
wrapper = MyWrapper()
app = EClient(wrapper)
app.connect("127.0.0.1", 7497, clientId=0)
print("serverVersion:{} connectionTime:{}".format(app.serverVersion(),
                                                 app.twsConnectionTime()))

# Describe the XAUUSD commodity contract to request quotes for.
from ibapi.contract import Contract
contract = Contract()
contract.symbol, contract.secType = "XAUUSD", "CMDTY"
contract.exchange, contract.currency = "SMART", "USD"

# Request market data under request id 1 (no generic ticks, snapshot flags off),
# then process incoming messages in this thread.
app.reqMktData(1, contract, "", False, False, [])
app.run()

if __name__ == '__main__':
    pass
コード例 #7
0
class IbApi(EWrapper):
    """
    EWrapper implementation that bridges the IB TWS API to a trading gateway.

    IB callbacks are converted into gateway data objects (TickData, OrderData,
    TradeData, PositionData, AccountData, ContractData) and forwarded to
    ``self.gateway``; requests are sent through ``self.client``, an EClient
    bound to this wrapper. Contract data is cached in a local shelve file.
    """
    data_filename = "ib_contract_data.db"
    data_filepath = str(get_file_path(data_filename))

    local_tz = get_localzone()

    def __init__(self, gateway: BaseGateway):
        """Initialise request counters, caches and the EClient bound to this wrapper."""
        super().__init__()

        self.gateway = gateway
        self.gateway_name = gateway.gateway_name

        self.status = False

        self.reqid = 0
        self.orderid = 0
        self.clientid = 0
        self.history_reqid = 0
        self.account = ""
        self.ticks = {}
        self.orders = {}
        self.accounts = {}
        self.contracts = {}

        self.tick_exchange = {}
        self.subscribed = {}
        self.data_ready = False

        self.history_req = None
        self.history_condition = Condition()
        self.history_buf = []

        self.client = EClient(self)

    def connectAck(self):  # pylint: disable=invalid-name
        """
        Callback when connection is established.

        Marks the API as connected, replays cached contracts, and resets
        data_ready so subscriptions are re-sent on the next 2104 notice (see error()).
        """
        self.status = True
        self.gateway.write_log("IB TWS连接成功")

        self.load_contract_data()

        self.data_ready = False

    def connectionClosed(self):  # pylint: disable=invalid-name
        """
        Callback when connection is closed.
        """
        self.status = False
        self.gateway.write_log("IB TWS连接断开")

    def nextValidId(self, orderId: int):  # pylint: disable=invalid-name
        """
        Callback of next valid orderid. Only the first id received seeds self.orderid.
        """
        super().nextValidId(orderId)

        if not self.orderid:
            self.orderid = orderId

    def currentTime(self, time: int):  # pylint: disable=invalid-name
        """
        Callback of current server time of IB (epoch seconds); logged to the gateway.
        """
        super().currentTime(time)

        dt = datetime.fromtimestamp(time)
        time_string = dt.strftime("%Y-%m-%d %H:%M:%S.%f")

        msg = f"服务器时间: {time_string}"
        self.gateway.write_log(msg)

    def error(self, reqId: TickerId, errorCode: int, errorString: str):  # pylint: disable=invalid-name
        """
        Callback of error caused by specific request.

        Wakes query_history() if the failing request is the pending history
        request. On code 2104 (first time since connect) marks data as ready,
        queries server time and re-sends all stored subscriptions.
        """
        super().error(reqId, errorCode, errorString)
        if reqId == self.history_reqid:
            with self.history_condition:
                self.history_condition.notify()

        msg = f"信息通知,代码:{errorCode},内容: {errorString}"
        self.gateway.write_log(msg)

        # Market data server is connected
        if errorCode == 2104 and not self.data_ready:
            self.data_ready = True

            self.client.reqCurrentTime()

            reqs = list(self.subscribed.values())
            self.subscribed.clear()
            for req in reqs:
                self.subscribe(req)

    def tickPrice(  # pylint: disable=invalid-name
            self, reqId: TickerId, tickType: TickType, price: float,
            attrib: TickAttrib):
        """
        Callback of tick price update.
        """
        super().tickPrice(reqId, tickType, price, attrib)

        if tickType not in TICKFIELD_IB2VT:
            return

        tick = self.ticks[reqId]
        name = TICKFIELD_IB2VT[tickType]
        setattr(tick, name, price)

        # Update name into tick data.
        contract = self.contracts.get(tick.vt_symbol, None)
        if contract:
            tick.name = contract.name

        # Forex of IDEALPRO and Spot Commodity has no tick time and last price.
        # We need to calculate locally.
        exchange = self.tick_exchange[reqId]
        if exchange is Exchange.IDEALPRO or "CMDTY" in tick.symbol:
            tick.last_price = (tick.bid_price_1 + tick.ask_price_1) / 2
            tick.datetime = datetime.now(self.local_tz)
        self.gateway.on_tick(copy(tick))

    def tickSize(self, reqId: TickerId, tickType: TickType, size: int):  # pylint: disable=invalid-name
        """
        Callback of tick volume update.
        """
        super().tickSize(reqId, tickType, size)

        if tickType not in TICKFIELD_IB2VT:
            return

        tick = self.ticks[reqId]
        name = TICKFIELD_IB2VT[tickType]
        setattr(tick, name, size)

        self.gateway.on_tick(copy(tick))

    def tickString(self, reqId: TickerId, tickType: TickType, value: str):  # pylint: disable=invalid-name
        """
        Callback of tick string update; only LAST_TIMESTAMP is used, to set the tick time.
        """
        super().tickString(reqId, tickType, value)

        if tickType != TickTypeEnum.LAST_TIMESTAMP:
            return

        tick = self.ticks[reqId]
        dt = datetime.fromtimestamp(int(value))
        tick.datetime = self.local_tz.localize(dt)

        self.gateway.on_tick(copy(tick))

    def orderStatus(  # pylint: disable=invalid-name
        self,
        orderId: OrderId,
        status: str,
        filled: float,
        remaining: float,
        avgFillPrice: float,
        permId: int,
        parentId: int,
        lastFillPrice: float,
        clientId: int,
        whyHeld: str,
        mktCapPrice: float,
    ):
        """
        Callback of order status update. Ignored for orders not seen via openOrder().
        """
        super().orderStatus(
            orderId,
            status,
            filled,
            remaining,
            avgFillPrice,
            permId,
            parentId,
            lastFillPrice,
            clientId,
            whyHeld,
            mktCapPrice,
        )

        orderid = str(orderId)
        order = self.orders.get(orderid, None)
        if not order:
            return

        order.traded = filled

        # To filter PendingCancel status
        order_status = STATUS_IB2VT.get(status, None)
        if order_status:
            order.status = order_status

        self.gateway.on_order(copy(order))

    def openOrder(  # pylint: disable=invalid-name
        self,
        orderId: OrderId,
        ib_contract: Contract,
        ib_order: Order,
        orderState: OrderState,
    ):
        """
        Callback when opening new order; caches it and forwards a copy to the gateway.
        """
        super().openOrder(orderId, ib_contract, ib_order, orderState)

        orderid = str(orderId)
        order = OrderData(
            symbol=generate_symbol(ib_contract),
            exchange=EXCHANGE_IB2VT.get(ib_contract.exchange, Exchange.SMART),
            type=ORDERTYPE_IB2VT[ib_order.orderType],
            orderid=orderid,
            direction=DIRECTION_IB2VT[ib_order.action],
            volume=ib_order.totalQuantity,
            gateway_name=self.gateway_name,
        )

        if order.type == OrderType.LIMIT:
            order.price = ib_order.lmtPrice
        elif order.type == OrderType.STOP:
            order.price = ib_order.auxPrice

        self.orders[orderid] = order
        self.gateway.on_order(copy(order))

    def updateAccountValue(  # pylint: disable=invalid-name
            self, key: str, val: str, currency: str, accountName: str):
        """
        Callback of account update; stores mapped fields per "<account>.<currency>".
        """
        super().updateAccountValue(key, val, currency, accountName)

        if not currency or key not in ACCOUNTFIELD_IB2VT:
            return

        accountid = f"{accountName}.{currency}"
        account = self.accounts.get(accountid, None)
        if not account:
            account = AccountData(accountid=accountid,
                                  gateway_name=self.gateway_name)
            self.accounts[accountid] = account

        name = ACCOUNTFIELD_IB2VT[key]
        setattr(account, name, float(val))

    def updatePortfolio(  # pylint: disable=invalid-name
        self,
        contract: Contract,
        position: float,
        marketPrice: float,
        marketValue: float,
        averageCost: float,
        unrealizedPNL: float,
        realizedPNL: float,
        accountName: str,
    ):
        """
        Callback of position update.

        The average cost is divided by the contract multiplier (1 if it is not an integer string).
        """
        super().updatePortfolio(
            contract,
            position,
            marketPrice,
            marketValue,
            averageCost,
            unrealizedPNL,
            realizedPNL,
            accountName,
        )

        if contract.exchange:
            exchange = EXCHANGE_IB2VT.get(contract.exchange, None)
        elif contract.primaryExchange:
            exchange = EXCHANGE_IB2VT.get(contract.primaryExchange, None)
        else:
            exchange = Exchange.SMART  # Use smart routing for default

        if not exchange:
            msg = f"存在不支持的交易所持仓{generate_symbol(contract)} {contract.exchange} {contract.primaryExchange}"
            self.gateway.write_log(msg)
            return

        try:
            ib_size = int(contract.multiplier)
        except ValueError:
            ib_size = 1
        price = averageCost / ib_size

        pos = PositionData(
            symbol=generate_symbol(contract),
            exchange=exchange,
            direction=Direction.NET,
            volume=position,
            price=price,
            pnl=unrealizedPNL,
            gateway_name=self.gateway_name,
        )
        self.gateway.on_position(pos)

    def updateAccountTime(self, timeStamp: str):  # pylint: disable=invalid-name
        """
        Callback of account update time; pushes a copy of every cached account to the gateway.
        """
        super().updateAccountTime(timeStamp)
        for account in self.accounts.values():
            self.gateway.on_account(copy(account))

    def contractDetails(self, reqId: int, contractDetails: ContractDetails):  # pylint: disable=invalid-name
        """
        Callback of contract data update; new contracts are published and persisted.
        """
        super().contractDetails(reqId, contractDetails)

        # Generate symbol from ib contract details
        ib_contract = contractDetails.contract
        if not ib_contract.multiplier:
            ib_contract.multiplier = 1

        symbol = generate_symbol(ib_contract)

        # Generate contract
        contract = ContractData(
            symbol=symbol,
            exchange=EXCHANGE_IB2VT[ib_contract.exchange],
            name=contractDetails.longName,
            product=PRODUCT_IB2VT[ib_contract.secType],
            size=int(ib_contract.multiplier),
            pricetick=contractDetails.minTick,
            net_position=True,
            history_data=True,
            stop_supported=True,
            gateway_name=self.gateway_name,
        )

        if contract.vt_symbol not in self.contracts:
            self.gateway.on_contract(contract)

            self.contracts[contract.vt_symbol] = contract
            self.save_contract_data()

    def execDetails(self, reqId: int, contract: Contract,
                    execution: Execution):  # pylint: disable=invalid-name
        """
        Callback of trade data update.
        """
        super().execDetails(reqId, contract, execution)

        dt = datetime.strptime(execution.time, "%Y%m%d  %H:%M:%S")
        dt = self.local_tz.localize(dt)

        trade = TradeData(
            symbol=generate_symbol(contract),
            exchange=EXCHANGE_IB2VT.get(contract.exchange, Exchange.SMART),
            orderid=str(execution.orderId),
            tradeid=str(execution.execId),
            direction=DIRECTION_IB2VT[execution.side],
            price=execution.price,
            volume=execution.shares,
            datetime=dt,
            gateway_name=self.gateway_name,
        )

        self.gateway.on_trade(trade)

    def managedAccounts(self, accountsList: str):
        """
        Callback of all sub accountid.

        If no account was configured, the last non-empty code in the
        comma-separated list is used; account updates are then requested.
        """
        super().managedAccounts(accountsList)

        if not self.account:
            for account_code in accountsList.split(","):
                # accountsList may carry a trailing comma; never select "" as the account.
                if account_code:
                    self.account = account_code

        self.gateway.write_log(f"当前使用的交易账号为{self.account}")
        self.client.reqAccountUpdates(True, self.account)

    def historicalData(self, reqId: int, ib_bar: IbBarData):
        """
        Callback of history data update; appends a BarData to the pending history buffer.
        """
        dt = datetime.strptime(ib_bar.date, "%Y%m%d %H:%M:%S")
        dt = self.local_tz.localize(dt)

        bar = BarData(symbol=self.history_req.symbol,
                      exchange=self.history_req.exchange,
                      datetime=dt,
                      interval=self.history_req.interval,
                      volume=ib_bar.volume,
                      open_price=ib_bar.open,
                      high_price=ib_bar.high,
                      low_price=ib_bar.low,
                      close_price=ib_bar.close,
                      gateway_name=self.gateway_name)

        self.history_buf.append(bar)

    def historicalDataEnd(self, reqId: int, start: str, end: str):
        """
        Callback of history data finished; wakes the thread waiting in query_history().
        """
        with self.history_condition:
            self.history_condition.notify()

    def connect(self, host: str, port: int, clientid: int, account: str):
        """
        Connect to TWS and start the client message loop in a new thread (no-op if already connected).
        """
        if self.status:
            return

        self.host = host
        self.port = port
        self.clientid = clientid
        self.account = account

        self.client.connect(host, port, clientid)
        self.thread = Thread(target=self.client.run)
        self.thread.start()

    def check_connection(self):
        """Reconnect with the stored host/port/clientid if the client is no longer connected."""
        if self.client.isConnected():
            return

        if self.status:
            self.close()

        self.client.connect(self.host, self.port, self.clientid)

        self.thread = Thread(target=self.client.run)
        self.thread.start()

    def close(self):
        """
        Disconnect to TWS.
        """
        if not self.status:
            return

        self.status = False
        self.client.disconnect()

    def subscribe(self, req: SubscribeRequest):
        """
        Subscribe tick data update.

        Requests contract details and market data under two consecutive
        request ids; the tick buffer is keyed by the market-data id.
        """
        if not self.status:
            return

        if req.exchange not in EXCHANGE_VT2IB:
            self.gateway.write_log(f"不支持的交易所{req.exchange}")
            return

        # Filter duplicate subscribe
        if req.vt_symbol in self.subscribed:
            return
        self.subscribed[req.vt_symbol] = req

        # Extract ib contract detail
        ib_contract = generate_ib_contract(req.symbol, req.exchange)
        if not ib_contract:
            self.gateway.write_log("代码解析失败,请检查格式是否正确")
            return

        # Get contract data from TWS.
        self.reqid += 1
        self.client.reqContractDetails(self.reqid, ib_contract)

        # Subscribe tick data and create tick object buffer.
        self.reqid += 1
        self.client.reqMktData(self.reqid, ib_contract, "", False, False, [])

        tick = TickData(
            symbol=req.symbol,
            exchange=req.exchange,
            datetime=datetime.now(self.local_tz),
            gateway_name=self.gateway_name,
        )
        self.ticks[self.reqid] = tick
        self.tick_exchange[self.reqid] = req.exchange

    def send_order(self, req: OrderRequest):
        """
        Send a new order. Returns its vt_orderid, or "" if the order was rejected locally.
        """
        if not self.status:
            return ""

        if req.exchange not in EXCHANGE_VT2IB:
            self.gateway.write_log(f"不支持的交易所:{req.exchange}")
            return ""

        if req.type not in ORDERTYPE_VT2IB:
            self.gateway.write_log(f"不支持的价格类型:{req.type}")
            return ""

        self.orderid += 1

        ib_contract = generate_ib_contract(req.symbol, req.exchange)
        if not ib_contract:
            return ""

        ib_order = Order()
        ib_order.orderId = self.orderid
        ib_order.clientId = self.clientid
        ib_order.action = DIRECTION_VT2IB[req.direction]
        ib_order.orderType = ORDERTYPE_VT2IB[req.type]
        ib_order.totalQuantity = req.volume
        ib_order.account = self.account

        if req.type == OrderType.LIMIT:
            ib_order.lmtPrice = req.price
        elif req.type == OrderType.STOP:
            ib_order.auxPrice = req.price

        self.client.placeOrder(self.orderid, ib_contract, ib_order)
        self.client.reqIds(1)

        order = req.create_order_data(str(self.orderid), self.gateway_name)
        self.gateway.on_order(order)
        return order.vt_orderid

    def cancel_order(self, req: CancelRequest):
        """
        Cancel an existing order.
        """
        if not self.status:
            return

        self.client.cancelOrder(int(req.orderid))

    def query_history(self, req: HistoryRequest):
        """
        Request historical bars for ``req`` and block (no timeout) until
        historicalDataEnd() or an error for this request wakes us up.

        The look-back is capped at 180 days. Returns the collected BarData list.
        """
        self.history_req = req

        self.reqid += 1

        ib_contract = generate_ib_contract(req.symbol, req.exchange)

        if req.end:
            end = req.end
            end_str = end.strftime("%Y%m%d %H:%M:%S")
        else:
            end = datetime.now(self.local_tz)
            end_str = ""

        delta = end - req.start
        days = min(delta.days, 180)  # IB only provides 6-month data
        duration = f"{days} D"
        bar_size = INTERVAL_VT2IB[req.interval]

        if req.exchange == Exchange.IDEALPRO:
            bar_type = "MIDPOINT"
        else:
            bar_type = "TRADES"

        self.history_reqid = self.reqid
        # Hold the condition while sending: historicalDataEnd()/error() must
        # acquire it before notifying, so the notification cannot fire before
        # we are waiting (previously a fast reply could be lost -> hang).
        with self.history_condition:
            self.client.reqHistoricalData(self.reqid, ib_contract, end_str,
                                          duration, bar_size, bar_type, 0, 1,
                                          False, [])
            self.history_condition.wait()  # Wait for async data return

        history = self.history_buf
        self.history_buf = []  # Create new buffer list
        self.history_req = None

        return history

    def load_contract_data(self):
        """Load cached contracts from the shelve file and publish each to the gateway."""
        with shelve.open(self.data_filepath) as f:
            self.contracts = f.get("contracts", {})

        for contract in self.contracts.values():
            self.gateway.on_contract(contract)

        self.gateway.write_log("本地缓存合约信息加载成功")

    def save_contract_data(self):
        """Persist the current contract cache to the shelve file."""
        with shelve.open(self.data_filepath) as f:
            f["contracts"] = self.contracts
コード例 #8
0
    def historicalDataEnd(self, reqId: int, start: str, end: str):
        """Report the end of a historical-data request and disconnect the module-level ``app``.

        This is treated as the logical end of the program.
        """
        summary = ("HistoricalDataEnd. ReqId:", reqId, "from", start, "to", end)
        print(*summary)
        app.disconnect()
        print("finished")

    def error(self, reqId, errorCode, errorString):
        """Print any error/notice message from TWS; these can arrive at any time."""
        parts = ["Error. Id: ", reqId, " Code: ", errorCode, " Msg: ", errorString]
        print(*parts)

    def start(self):
        """Request one month of daily USD/JPY MIDPOINT bars ending 180 days before now.

        Uses request id 4102 and the module-level ``app`` client.
        """
        window_end = datetime.datetime.today() - datetime.timedelta(days=180)
        queryTime = window_end.strftime("%Y%m%d %H:%M:%S")

        # USD.JPY cash contract on IDEALPRO (used in place of Japanese equity data).
        fx = Contract()
        fx.secType, fx.symbol = "CASH", "USD"
        fx.currency, fx.exchange = "JPY", "IDEALPRO"

        app.reqHistoricalData(4102, fx, queryTime, "1 M", "1 day", "MIDPOINT",
                              1, 1, False, [])


# Step 1: hand a MyWrapper instance to EClient.
app = EClient(MyWrapper())
# Step 2: connect to TWS / IB Gateway on localhost:7497 as client 123.
app.connect(host="127.0.0.1", port=7497, clientId=123)
# Step 3: run the message loop in this thread.
app.run()
コード例 #9
0
class App:
    """Blocking helper around the callback-based IB API (EClient / EWrapper).

    Each request method installs the wrapper callbacks it needs through
    :meth:`wrap`. It then either blocks on a local ``queue.Queue`` until the
    response is complete, or leaves its results on one of the shared queues
    created by :meth:`resetData`, tagged with the request id. The
    ``EClient.run`` message loop runs in a background thread started by
    ``__init__``.
    """

    def __init__(self, ip_addr='127.0.0.1', port=7497, clientId=1):
        """Connect to TWS / IB Gateway and wait for the first valid request id.

        Args:
            ip_addr: Host of the TWS / IB Gateway API.
            port: API port.
            clientId: API client id for this session.

        Raises:
            ConnectionError: if ``EClient.connect`` did not leave the client
                connected (otherwise the wait for the first id would block forever).
        """
        # Created before any callback is installed, so nextValidId never sees
        # an unbound ``q``.
        q = queue.Queue()

        # Wrapper Methods
        def nextValidId(reqId):
            q.put(reqId)

        def connectionClosed():
            print('CONNECTION HAS CLOSED')

        # advancedOrderRejectJson is accepted (and ignored) so ibapi releases
        # that pass a fourth argument to error() do not raise TypeError.
        def error(reqId, errorCode: int, errorString: str,
                  advancedOrderRejectJson=""):
            # 2104 / 2106 are skipped as noise (presumably IB's informational
            # "data farm connection is OK" notices -- confirm).
            if errorCode not in (2104, 2106):
                print("Error. Id: ", reqId, " Code: ", errorCode, " Msg: ",
                      errorString)
                if errorCode == 10167 and 'Displaying delayed market data' in errorString:
                    pass
                elif reqId != -1:
                    # Request-specific error: hand it to data_errors_q consumers.
                    self.data_errors_q.put((errorString, reqId))
                if 'pacing violation' in errorString:
                    self.slowdown = True

        # Wrapper Methods End

        self.wrapper = EWrapper()
        self.client = EClient(self.wrapper)
        self.ip_addr = ip_addr
        self.my_port = port
        self.my_clientId = clientId
        # The queues must exist before any callback can fire.
        self.resetData()

        # Install callbacks before connecting so connection-time errors reach them.
        self.wrap(error)
        self.wrap(connectionClosed)
        self.wrap(nextValidId)

        self.client.connect(ip_addr, port, clientId)
        if not self.client.isConnected():
            raise ConnectionError('Could not connect to {}:{} with clientId {}'
                                  .format(ip_addr, port, clientId))

        self._thread = Thread(target=self.client.run)
        self._thread.start()
        # Once we get a reqID, we know we can start
        self._reqId = q.get()

    def wrap(self, method):
        """Install ``method`` as the wrapper callback with the same name.

        The function is set on this App's wrapper *instance*. The shared
        ``EWrapper`` class, and with it every other wrapper in the process, is
        left untouched. Being an instance attribute, the function is called
        with the callback arguments only (no wrapper/self argument).
        """
        setattr(self.wrapper, method.__name__, method)

    def getReqId(self):
        """Return the next request id and advance the internal counter."""
        reqId = self._reqId
        self._reqId += 1
        return reqId

    def resetData(self):
        """(Re)create all shared result queues and bookkeeping containers."""
        # Historical Data
        self.hist_data_q = queue.Queue()
        self.hist_data_dict_q = {}

        # Fundamental Data
        self.fundamental_data_q = queue.Queue()
        self.slowdown = False

        # Price Data
        self.price_queue = queue.Queue()
        self.close_price_queue = queue.Queue()

        # Dict to map reqId w/ Symbols
        self.reqId_map = {}

        # Errors
        self.data_errors_q = queue.Queue()

    ### Client Functions (with wrapper methods as nested functions) ###

    def getAccounts(self):
        """Block until IB reports managed accounts; return the raw ``accountsList`` value."""
        def managedAccounts(accountsList):
            q.put(accountsList)

        q = queue.Queue()
        self.wrap(managedAccounts)
        self.client.reqManagedAccts()
        return q.get()

    def getPositions(self, account):
        """Return the positions of ``account`` as a DataFrame (see getDFPositions).

        Blocks until ``positionMultiEnd`` is received.
        """
        def positionMulti(reqId: int, account: str, modelCode: str,
                          contract: Contract, pos: float, avgCost: float):
            positions.append((contract, pos, avgCost))

        def positionMultiEnd(reqId: int):
            q.put(None)

        positions = []
        q = queue.Queue()
        self.wrap(positionMulti)
        self.wrap(positionMultiEnd)
        self.client.reqPositionsMulti(self.getReqId(), account, "")
        q.get()
        return self.getDFPositions(positions)

    def getOrders(self):
        """Return open orders as a DataFrame (see getDFOrders); blocks until ``openOrderEnd``."""
        def openOrder(orderId, contract, order, orderState):
            orders.append((contract, order, orderState))

        def openOrderEnd():
            q.put(None)

        orders = []
        q = queue.Queue()
        self.wrap(openOrder)
        self.wrap(openOrderEnd)
        self.client.reqOpenOrders()
        q.get()
        return self.getDFOrders(orders)

    def sellPosition(self, ticker, secType, orders, positions):
        """Place a market SELL order for the whole long position in ``ticker``.

        Args:
            ticker: Symbol, or a ``[symbol, currency]`` list (see createContract).
            secType: IB security type, e.g. "STK".
            orders: DataFrame from getOrders(), used to skip duplicate orders.
            positions: DataFrame from getPositions().

        Nothing is sent when there is no matching position, the position is
        not long, or an identical PreSubmitted order already exists.
        """
        pos = self.getPosDetails(ticker, secType, positions)
        if pos.empty:
            print('No matching position for: ' + str(ticker))
            return
        if pos.shape[0] > 1:
            print('Multiple matching positions, defaulting to first record')
        contract = self.createContract(ticker, secType, "USD", "SMART")
        quantity = int(pos['pos'].iloc[0])
        if quantity > 0:
            order = Order()
            order.action = "SELL"
            order.orderType = "MKT"
            order.totalQuantity = quantity
            if not self.duplicateOrder(ticker, secType, order, orders):
                print('Placing SELL order for: ' + str(ticker))
                self.place_order(contract, order)

    def getHistoricalData(self, contract, duration):
        '''
        Requests historical daily prices

          Input:
            duration: Duration string e.g. "1 Y", "6 M", "3 D", etc

        Non-blocking: when the data is complete, ``(DataFrame, reqId)`` is put
        on ``self.hist_data_q``. The frame has columns ``date`` and ``price``
        (the close of each "1 day" MIDPOINT bar).
        '''
        def historicalData(reqId: int, bar):
            self.hist_data_dict_q[reqId].put(bar)

        def historicalDataEnd(reqId: int, start: str, end: str):
            dates = []
            prices = []
            while not self.hist_data_dict_q[reqId].empty():
                bar = self.hist_data_dict_q[reqId].get()
                dates.append(bar.date)
                prices.append(bar.close)
            data = {'date': dates, 'price': prices}
            df = pandas.DataFrame(data=data)
            self.hist_data_q.put((df, reqId))

        self.wrap(historicalData)
        self.wrap(historicalDataEnd)
        queryTime = datetime.datetime.today().strftime("%Y%m%d %H:%M:%S")
        reqId = self.getReqId()
        # Register the per-request queue BEFORE sending the request: the
        # background message loop may deliver bars as soon as it is out.
        self.reqId_map[reqId] = contract.symbol
        self.hist_data_dict_q[reqId] = queue.Queue()
        self.client.reqHistoricalData(reqId, contract, queryTime, duration,
                                      "1 day", "MIDPOINT", 1, 1, False, [])

    def getPrice(self, contract):
        '''
        Requests last trade price

        Non-blocking: last / delayed-last prices are put on ``price_queue`` and
        previous / delayed close prices on ``close_price_queue``, as
        ``(price, reqId)``.
        '''
        def tickPrice(reqId, tickType, price: float, attrib):
            if price == -1:
                # print("No Price Data currently available")
                pass
            # Last price
            elif tickType == 4:
                self.price_queue.put((price, reqId))
            # Previous Close Price
            elif tickType == 9:
                self.close_price_queue.put((price, reqId))
            # Delayed Last Price
            elif tickType == 68:
                self.price_queue.put((price, reqId))
            # Delayed Close Price
            elif tickType == 75:
                self.close_price_queue.put((price, reqId))

        self.wrap(tickPrice)
        if contract.currency != 'USD':
            self.client.reqMarketDataType(3)
        reqId = self.getReqId()
        # Map the id before the request so consumers can always resolve it.
        self.reqId_map[reqId] = contract.symbol
        self.client.reqMktData(reqId, contract, "", True, False, [])
        if contract.currency != 'USD':
            # Go back to live/frozen
            self.client.reqMarketDataType(2)

    def findContracts(self, sybmol):
        """Send a reqMatchingSymbols request for the given symbol pattern."""
        self.client.reqMatchingSymbols(self.getReqId(), sybmol)

    def place_order(self, contract, order):
        """Submit ``order`` for ``contract``, using a fresh request id as the order id."""
        self.client.placeOrder(self.getReqId(), contract, order)

    def getContractDetails(self,
                           symbol,
                           secType,
                           currency=None,
                           exchange=None):
        """Return the list of contract-details objects IB reports for the query.

        Blocks until ``contractDetailsEnd`` is received. ``currency`` and
        ``exchange`` are only set on the query contract when given.
        """
        def contractDetails(reqId: int, contractDetails):
            contract_details.append(contractDetails)

        def contractDetailsEnd(reqId: int):
            q.put(None)

        contract_details = []
        q = queue.Queue()
        self.wrap(contractDetails)
        self.wrap(contractDetailsEnd)
        contract = Contract()
        contract.symbol = symbol
        contract.secType = secType
        if currency is not None:
            contract.currency = currency
        if exchange is not None:
            contract.exchange = exchange
        self.client.reqContractDetails(self.getReqId(), contract)
        q.get()
        return contract_details

    def getYield(self, contract, data_type=3):
        """Return the first "YIELD" field of generic tick 258, divided by 100.

        Blocks until a ';'-separated tick string arrives, then cancels the
        market-data subscription. Returns 0 when no YIELD field was present.
        ``data_type`` is the market-data type used for non-USD contracts.
        """
        def tickString(reqId, tickType, value: str):
            for val in value.split(';'):
                if 'YIELD' in val:
                    div.append(float(val.split('=')[1]) / 100)
            # If ';' in the response, then we know we got the data
            if ';' in value:
                self.client.cancelMktData(reqId)
                q.put(None)

        div = []
        q = queue.Queue()
        self.wrap(tickString)
        # Switch to live (1) frozen (2) delayed (3) delayed frozen (4).
        # MarketDataTypeEnum.DELAYED
        if contract.currency != 'USD':
            self.client.reqMarketDataType(data_type)
        self.client.reqMktData(self.getReqId(), contract, "258", False, False,
                               [])
        if contract.currency != 'USD':
            # Go back to live/frozen
            self.client.reqMarketDataType(2)
        q.get()
        return div[0] if div else 0

    def getFinStatements(self, contract, data_type):
        """Request fundamental data; ``(data, reqId)`` is put on ``fundamental_data_q``."""
        def fundamentalData(reqId, data: str):
            self.fundamental_data_q.put((data, reqId))
            self.client.cancelFundamentalData(reqId)

        self.wrap(fundamentalData)
        reqId = self.getReqId()
        # Map the id before the request so consumers can always resolve it.
        self.reqId_map[reqId] = contract.symbol
        self.client.reqFundamentalData(reqId, contract, data_type, [])

    ### Client Functions End ###

    ### HELPER FUNCTIONS ###

    def getDFPositions(self, positions):
        """Convert ``(contract, pos, avgCost)`` tuples into a DataFrame with
        columns symbol, secType, currency, pos, avg_cost."""
        symbols = []
        types = []
        currencies = []
        sizes = []
        avg_costs = []
        for contract, size, cost in positions:
            symbols.append(contract.symbol)
            types.append(contract.secType)
            currencies.append(contract.currency)
            sizes.append(size)
            avg_costs.append(cost)
        data = {
            'symbol': symbols,
            'secType': types,
            'currency': currencies,
            'pos': sizes,
            'avg_cost': avg_costs
        }
        return pandas.DataFrame(data=data)

    def getDFOrders(self, orders):
        """Convert ``(contract, order, orderState)`` tuples into a DataFrame with
        columns symbol, secType, action, quantity, status."""
        symbols = []
        types = []
        actions = []
        quantities = []
        status = []
        for contract, order, orderState in orders:
            symbols.append(contract.symbol)
            types.append(contract.secType)
            actions.append(order.action)
            quantities.append(order.totalQuantity)
            status.append(orderState.status)
        data = {
            'symbol': symbols,
            'secType': types,
            'action': actions,
            'quantity': quantities,
            'status': status
        }
        return pandas.DataFrame(data=data)

    def createContract(self,
                       symbol,
                       secType,
                       currency,
                       exchange,
                       primaryExchange=None,
                       right=None,
                       strike=None,
                       expiry=None):
        """Build a Contract.

        ``symbol`` may be a ``[symbol, currency]`` list (foreign stocks). In
        that case the list's currency is used, and ``currency`` and
        ``primaryExchange`` are ignored. The option fields are set only when
        truthy.
        """
        contract = Contract()
        if isinstance(symbol, list):
            # Foreign stocks
            print(symbol[0], symbol[1])
            contract.symbol = symbol[0]
            contract.currency = symbol[1]
        else:
            contract.symbol = symbol
            contract.currency = currency
            if primaryExchange:
                contract.primaryExchange = primaryExchange
        contract.secType = secType
        contract.exchange = exchange
        if right:
            contract.right = right
        if strike:
            contract.strike = strike
        if expiry:
            contract.lastTradeDateOrContractMonth = expiry
        return contract

    def createOptionContract(self, symbol, currency, exchange,
                             expiry="201901", strike=150, right="C",
                             multiplier="100"):
        """Build an OPT contract.

        The option terms default to expiry "201901", strike 150, right "C" and
        multiplier "100".
        """
        contract = Contract()
        contract.symbol = symbol
        contract.secType = "OPT"
        contract.exchange = exchange
        contract.currency = currency
        contract.lastTradeDateOrContractMonth = expiry
        contract.strike = strike
        contract.right = right
        contract.multiplier = multiplier
        return contract

    @staticmethod
    def _symbolMask(positions, ticker):
        """Boolean mask of rows whose 'symbol' equals ``ticker``.

        Matching is literal (no regex), so e.g. "BRK.B" does not match "BRKXB".
        A ``[symbol, currency]`` list (the createContract foreign-stock form)
        is matched on its symbol.
        """
        symbol = ticker[0] if isinstance(ticker, list) else ticker
        return positions['symbol'] == symbol

    def portfolioCheck(self, ticker, positions):
        '''
        Output: Boolean
            True if ticker is in portfolio, is a stock, and position is > 0
            False otherwise
        '''
        stock_rows = positions[self._symbolMask(positions, ticker)
                               & (positions['secType'] == 'STK')]
        return bool((stock_rows['pos'] > 0).any())

    def calcOrderSize(self, price, size):
        '''
        Determines how large the order should be

        Input:
            price: Current share price (float), must be > 0
            size: How large we want our order to be, in dollar terms (int)

        Output:
            int: number of shares to buy
            Will default to 1 if price > size

        Raises:
            ValueError: if price is not positive (would divide by zero or
            yield a negative share count)
        '''
        if price <= 0:
            raise ValueError('price must be positive, got {}'.format(price))
        if price > size:
            return 1
        return int(size / price)

    def getPosDetails(self, ticker, secType, positions):
        '''
        Returns a dataframe of position details given a ticker and security type
        (both matched literally)
        '''
        mask = (self._symbolMask(positions, ticker)
                & (positions['secType'] == secType))
        return positions[mask]

    def duplicateOrder(self, ticker, secType, order, orders):
        """True if ``orders`` already holds a PreSubmitted order with the same
        symbol, secType, action and quantity as ``order``."""
        if orders.empty:
            return False
        return bool(((orders['symbol'] == ticker) &
                     (orders['secType'] == secType) &
                     (orders['action'] == order.action) &
                     (orders['quantity'] == order.totalQuantity) &
                     (orders['status'] == 'PreSubmitted')).any())

    @staticmethod
    def _asList(value):
        """Normalise an xmltodict node to a list: None -> [], single dict -> [dict]."""
        if value is None:
            return []
        return value if isinstance(value, list) else [value]

    def _parsePeriod(self, statements, coaMap):
        """Flatten one fiscal period's statements into a dict.

        Keys are ``coaCodes.coaCode_map[coaCode]`` and values are floats.
        Unknown codes are reported and skipped.
        """
        parsed = {}
        for item in statements:  # one per statement; item['@Type'] is INC, BAL or CAS
            for i in self._asList(item['lineItem']):
                try:
                    parsed[coaCodes.coaCode_map[
                        i['@coaCode']]] = float(i['#text'])
                except KeyError:
                    print('Could not find coaCode!!!')
                    print(i['@coaCode'])
                    print(coaMap)
        return parsed

    def parseFinancials(self, data, quarterly=False):
        """Parse a financial-statements XML document (root ``ReportFinancialStatements``).

        Only periods whose statements are a list and whose first statement's
        source is one of ``accepted_reports`` are used, in document order.

        Returns:
            ``(qtr1, qtr2, qtr3, qtr4)`` from the interim periods when
            ``quarterly``; otherwise ``(current_annual, prev_annual)``. Each
            item is a dict from _parsePeriod, or None when unavailable. The
            error paths return the same tuple length.
        """
        accepted_reports = ["10-K", "10-Q", "Interim Report", "ARS"]
        n_reports = 4 if quarterly else 2
        no_data = (None,) * n_reports

        fundamental_data = xmltodict.parse(data)
        if fundamental_data['ReportFinancialStatements'][
                'FinancialStatements'] is None:
            print('No Fundamental Data')
            return no_data
        try:
            statements = fundamental_data['ReportFinancialStatements'][
                'FinancialStatements']
            coaMap = statements['COAMap']
            annuals = statements['AnnualPeriods']['FiscalPeriod']
            interims = statements['InterimPeriods']['FiscalPeriod']
        except (KeyError, TypeError):
            print('ERROR with fundamental data')
            print(fundamental_data)
            return no_data

        reports = []
        # xmltodict yields a bare dict when there is a single period.
        for period in self._asList(interims if quarterly else annuals):
            period_statements = period['Statement']
            if (isinstance(period_statements, list)
                    and period_statements[0]['FPHeader']['Source']['#text']
                    in accepted_reports):
                reports.append(self._parsePeriod(period_statements, coaMap))
                if len(reports) == n_reports:
                    break
        reports.extend([None] * (n_reports - len(reports)))
        return tuple(reports)
コード例 #10
0
ファイル: ib_broker.py プロジェクト: mborraty/qf-lib
class IBBroker(Broker):
    """Synchronous facade over the IB API client/wrapper pair.

    Each public call sends a request through ``self.client`` and, where a
    response is expected, blocks until ``IBWrapper`` sets
    ``action_event_lock`` or ``waiting_time`` seconds pass. ``self.lock``
    serializes these request/response round trips.
    """

    def __init__(self, host: str = "127.0.0.1", port: int = 7497, client_id: int = 0):
        """Connect to TWS / IB Gateway and wait until the wrapper signals readiness.

        Args:
            host: TWS / IB Gateway host.
            port: TWS / IB Gateway API port.
            client_id: API client id.

        Raises:
            ConnectionError: if no readiness signal arrives within ``waiting_time`` seconds.
        """
        self.logger = ib_logger.getChild(self.__class__.__name__)
        # lock that synchronizes entries into the functions and
        # makes sure we have a synchronous communication with client
        self.lock = Lock()
        self.waiting_time = 30  # expressed in seconds
        # event that informs us that wrapper received the response
        self.action_event_lock = Event()
        self.wrapper = IBWrapper(self.action_event_lock)
        self.client = EClient(wrapper=self.wrapper)
        self.client.connect(host, port, clientId=client_id)

        # run the client in a separate thread so that the execution of the program can go on
        # now we will have 3 threads:
        # - thread of the main program
        # - thread of the client
        # - thread of the wrapper
        thread = Thread(target=self.client.run)
        thread.start()

        # this will be released after the client initialises and wrapper receives the nextValidOrderId
        if not self._wait_for_results():
            # Don't leave the connection (and the message loop started above)
            # open for a broker object that will never be returned.
            self.client.disconnect()
            raise ConnectionError("IB Broker was not initialized correctly")

    def get_portfolio_value(self) -> float:
        """Return the account-summary NetLiquidation value stored by the wrapper.

        Raises:
            BrokerException: on timeout.
        """
        with self.lock:
            request_id = 1
            self._reset_action_lock()
            self.client.reqAccountSummary(request_id, 'All', 'NetLiquidation')
            wait_result = self._wait_for_results()
            self.client.cancelAccountSummary(request_id)

            if wait_result:
                return self.wrapper.net_liquidation
            error_msg = 'Time out while getting portfolio value'
            self.logger.error('===> %s', error_msg)
            raise BrokerException(error_msg)

    def get_portfolio_tag(self, tag: str) -> float:
        """Return the account-summary value for ``tag`` (wrapper's ``tmp_value``).

        Raises:
            BrokerException: on timeout.
        """
        with self.lock:
            request_id = 2
            self._reset_action_lock()
            self.client.reqAccountSummary(request_id, 'All', tag)
            wait_result = self._wait_for_results()
            self.client.cancelAccountSummary(request_id)

            if wait_result:
                return self.wrapper.tmp_value
            error_msg = 'Time out while getting portfolio tag: {}'.format(tag)
            self.logger.error('===> %s', error_msg)
            raise BrokerException(error_msg)

    def get_positions(self) -> List[Position]:
        """Return the positions collected by the wrapper.

        Raises:
            BrokerException: on timeout.
        """
        with self.lock:
            self._reset_action_lock()
            self.wrapper.reset_position_list()
            self.client.reqPositions()

            if self._wait_for_results():
                return self.wrapper.position_list
            error_msg = 'Time out while getting positions'
            self.logger.error('===> %s', error_msg)
            raise BrokerException(error_msg)

    def place_orders(self, orders: Sequence[Order]) -> Sequence[int]:
        """Place each order in turn and return the IB order ids, in input order."""
        order_ids_list = []
        for order in orders:
            self.logger.info('Placing Order: %s', order)
            order_ids_list.append(self._execute_single_order(order))
        return order_ids_list

    def cancel_order(self, order_id: int):
        """Cancel ``order_id`` and wait for the wrapper's confirmation.

        Raises:
            OrderCancellingException: on timeout.
        """
        with self.lock:
            self.logger.info('cancel_order: %s', order_id)
            self._reset_action_lock()
            self.wrapper.set_cancel_order_id(order_id)
            self.client.cancelOrder(order_id)

            if not self._wait_for_results():
                error_msg = 'Time out while cancelling order id {} : \n'.format(order_id)
                self.logger.error('===> %s', error_msg)
                raise OrderCancellingException(error_msg)

    def get_open_orders(self) -> List[Order]:
        """Return the open orders collected by the wrapper.

        Raises:
            BrokerException: on timeout.
        """
        with self.lock:
            self._reset_action_lock()
            self.wrapper.reset_order_list()
            self.client.reqOpenOrders()

            if self._wait_for_results():
                return self.wrapper.order_list
            error_msg = 'Time out while getting orders'
            self.logger.error('===> %s', error_msg)
            raise BrokerException(error_msg)

    def cancel_all_open_orders(self):
        """
        There is no way to check if cancelling of all orders was finished.
        One can only get open orders and confirm that the list is empty.
        """
        with self.lock:
            self.client.reqGlobalCancel()
            self.logger.info('cancel_all_open_orders')

    def _execute_single_order(self, order) -> int:
        """Send one order and return its IB order id once the wrapper confirms it.

        Raises:
            BrokerException: on timeout.
            ValueError: if the order's time-in-force or execution style is unsupported.
        """
        with self.lock:
            order_id = self.wrapper.next_order_id()

            self._reset_action_lock()
            self.wrapper.set_waiting_order_id(order_id)

            ib_contract = self._to_ib_contract(order.contract)
            ib_order = self._to_ib_order(order)
            self.client.placeOrder(order_id, ib_contract, ib_order)

            if self._wait_for_results():
                return order_id
            error_msg = 'Time out while placing the trade for: \n\torder: {}'.format(order)
            self.logger.error('===> %s', error_msg)
            raise BrokerException(error_msg)

    def _wait_for_results(self) -> bool:
        """ Wait up to self.waiting_time seconds; True if the event was set. """
        return self.action_event_lock.wait(self.waiting_time)

    def _reset_action_lock(self):
        """ threads calling wait() will block until set() is called"""
        self.action_event_lock.clear()

    def _to_ib_contract(self, contract: Contract):
        """Map symbol, security type and exchange onto an IB contract."""
        ib_contract = IBContract()
        ib_contract.symbol = contract.symbol
        ib_contract.secType = contract.security_type
        ib_contract.exchange = contract.exchange
        return ib_contract

    def _to_ib_order(self, order: Order):
        """Build an IB order: positive quantity -> BUY, otherwise SELL of abs(quantity)."""
        ib_order = IBOrder()
        ib_order.action = 'BUY' if order.quantity > 0 else 'SELL'
        ib_order.totalQuantity = abs(order.quantity)

        self._set_execution_style(ib_order, order.execution_style)
        ib_order.tif = self._map_to_tif_str(order.time_in_force)

        return ib_order

    def _map_to_tif_str(self, time_in_force):
        """Map a TimeInForce to IB's tif string; raise ValueError if unsupported."""
        if time_in_force == TimeInForce.GTC:
            tif_str = "GTC"
        elif time_in_force == TimeInForce.DAY:
            tif_str = "DAY"
        elif time_in_force == TimeInForce.OPG:
            tif_str = "OPG"
        else:
            raise ValueError("Not supported TimeInForce {tif:s}".format(tif=str(time_in_force)))

        return tif_str

    def _set_execution_style(self, ib_order, execution_style):
        """Set the IB order type (and stop price) for the execution style.

        Raises:
            ValueError: for unsupported styles, instead of sending an order
                with no order type.
        """
        if isinstance(execution_style, MarketOrder):
            ib_order.orderType = "MKT"
        elif isinstance(execution_style, StopOrder):
            ib_order.orderType = "STP"
            ib_order.auxPrice = execution_style.stop_price
        else:
            raise ValueError("Not supported execution style {}".format(execution_style))
コード例 #11
0
ファイル: wrapper.py プロジェクト: ajmal017/backtrader-ib-api
class RequestWrapper(EWrapper):
    """ Wrapper that turns the callback-based IB API Wrapper into a blocking API, by collecting results into tables
    and returning the complete tables.

    Every public ``request_*`` method registers a response object under a fresh request id, sends the request
    through the EClient, blocks until the response's ``finished`` event is set (or ``timeout`` elapses) and then
    returns the response's ``table``. Callbacks delivered on the EClient thread are routed by request id to the
    pending response object (see ``_handle_callback``).
    """

    # Values _request_historical accepts for the ``whatToShow`` argument of reqHistoricalData.
    REQUEST_OPTIONS_HISTORICAL_TYPE = [
        "TRADES",
        "MIDPOINT",
        "BID",
        "ASK",
        "BID_ASK",
        "HISTORICAL_VOLATILITY",
        "OPTION_IMPLIED_VOLATILITY",
    ]

    # Values _request_historical accepts for the ``barSizeSetting`` argument of reqHistoricalData.
    REQUEST_OPTIONS_BAR_SIZE = [
        "1 sec",
        "5 secs",
        "15 secs",
        "30 secs",
        "1 min",
        "2 mins",
        "3 mins",
        "5 mins",
        "15 mins",
        "30 mins",
        "1 hour",
        "1 day",
    ]

    def __init__(self, timeout: int = None):
        """
        Create an EWrapper to provide blocking access to the callback-based IB API.
        :param timeout: Amount of time in seconds to wait for a response before giving up. Use None to never give up.
        """
        EWrapper.__init__(self)
        self.timeout = timeout
        self._app = None
        # Set by connectAck once the TWS connection has been acknowledged.
        self.connected = Event()
        # request id -> response object collecting the callbacks for that request
        self.pending_responses = {}
        self.next_request_id = 0
        # Background thread running the EClient message loop (created by start_app).
        self.thread = None

    def start_app(self, host: str, port: int, client_id: int) -> bool:
        """ Start a connection to the IB TWS application in a background thread and wait for it to be acknowledged.
        :param host: Hostname to connect to, usually 127.0.0.1
        :param port: Port to connect to, configurable and differs for live vs paper trading.
        :param client_id: Client ID setting for the TWS API
        :return: True if connectAck arrived before the timeout, False otherwise (a warning is logged).
        """
        self._app = EClient(wrapper=self)
        self.connected.clear()
        self.next_request_id = 0
        # Request ids restart at 0, so responses registered on a previous connection must not linger.
        self.pending_responses = {}
        self._app.connect(host, port, client_id)
        self.thread = Thread(target=self._app.run, daemon=True)
        self.thread.start()
        # connectAck will set the connected event once called
        is_connected = self.connected.wait(timeout=self.timeout)
        if not is_connected:
            logger.warning(f"No connection acknowledgement from IB TWS at {host}:{port} "
                           f"within {self.timeout} seconds")
        return is_connected

    def stop_app(self):
        """ Disconnect from the IB TWS and wait for the background thread to end.
        Does nothing for parts that were never started (start_app not called yet).
        """
        if self._app is not None:
            self._app.disconnect()
        if self.thread is not None:
            self.thread.join()

    @property
    def app(self):
        """ The currently running application representing the connection to the IB TWS """
        return self._app

    def request_stock_details(self, ticker: str, **kwargs):
        """ Performs a search using the ticker and provides a table of results including
        the general information about each match.
        :param ticker: stock ticker to search

        :Keyword Arguments:
            * *exchange* (``str``) --
              Exchange to look on, i.e. "SMART"
            * *currency* (``str``) --
              Currency to report information in, i.e. "USD"
        """
        contract = self._get_stock_contract(ticker, **kwargs)
        return self._run_request(
            StockDetailsResponse(),
            lambda request_id: self._app.reqContractDetails(request_id, contract))

    def request_option_params(self, ticker: str, contract_id: int):
        """ Request options expiration and strike information about the provided stock ticker and contract_id.
        :param ticker: stock ticker with available options
        :param contract_id: contract ID of the stock with available options, returned by request_stock_details
        """
        return self._run_request(
            OptionParamsResponse(),
            lambda request_id: self._app.reqSecDefOptParams(
                request_id,
                ticker,
                "",  # Leave blank so it will return all exchange options
                "STK",
                contract_id))

    def request_option_chain(self,
                             ticker: str,
                             exchange: str,
                             expiration: str,
                             currency="USD"):
        """ Request a list of all the options available for a given ticker and expiration.
        :param ticker: stock ticker with available options
        :param exchange: exchange of the options contracts
        :param expiration: expiration of the options contracts, in YYYYMMDD format
        :param currency: currency to report information in
        """
        # do not use _get_option_contract shortcut because we are leaving right and strike blank
        contract = Contract()
        contract.secType = "OPT"
        contract.symbol = ticker
        contract.exchange = exchange
        contract.currency = currency
        contract.lastTradeDateOrContractMonth = expiration
        return self._run_request(
            OptionDetailsResponse(),
            lambda request_id: self._app.reqContractDetails(request_id, contract))

    def request_stock_trades_history(self, ticker: str, **kwargs):
        """ Request historical data for stock trades for the given ticker
        :param ticker: stock ticker to search

        :Keyword Arguments:
            * *exchange* (``str``) --
              Exchange to look on, i.e. "SMART"
            * *currency* (``str``) --
              Currency to report information in, i.e. "USD"
            * *duration* (``str``) --
              Amount of time to collect data for, i.e. "5 d" for five days of data.
            * *bar_size* (``str``) --
              Time interval that data is reported in, i.e. "30 mins" provides 30 minute bars
            * *query_time* (``str``) --
              End (latest, most recent) datetime of the returned historical data, in format "%Y%m%d %H:%M:%S".
              Defaults to the current local time at the moment of the call.
            * *after_hours* (``bool``) --
              If True, data from outside normal market hours for this security are also returned.
        """
        contract = self._get_stock_contract(ticker, **kwargs)
        return self._run_request(
            HistoricalTradesResponse(),
            lambda request_id: self._request_historical(request_id, contract, "TRADES", **kwargs))

    def request_stock_iv_history(self, ticker: str, **kwargs):
        """ Request historical data for stock implied volatility for the given ticker
        :param ticker: stock ticker to search

        :Keyword Arguments:
            * *exchange* (``str``) --
              Exchange to look on, i.e. "SMART"
            * *currency* (``str``) --
              Currency to report information in, i.e. "USD"
            * *duration* (``str``) --
              Amount of time to collect data for, i.e. "5 d" for five days of data.
            * *bar_size* (``str``) --
              Time interval that data is reported in, i.e. "30 mins" provides 30 minute bars
            * *query_time* (``str``) --
              End (latest, most recent) datetime of the returned historical data, in format "%Y%m%d %H:%M:%S".
              Defaults to the current local time at the moment of the call.
            * *after_hours* (``bool``) --
              If True, data from outside normal market hours for this security are also returned.
        """
        contract = self._get_stock_contract(ticker, **kwargs)
        return self._run_request(
            HistoricalDataResponse(),
            lambda request_id: self._request_historical(request_id, contract,
                                                        "OPTION_IMPLIED_VOLATILITY", **kwargs))

    def request_stock_hv_history(self, ticker: str, **kwargs):
        """ Request historical data for stock historical volatility for the given ticker
        :param ticker: stock ticker to search

        :Keyword Arguments:
            * *exchange* (``str``) --
              Exchange to look on, i.e. "SMART"
            * *currency* (``str``) --
              Currency to report information in, i.e. "USD"
            * *duration* (``str``) --
              Amount of time to collect data for, i.e. "5 d" for five days of data.
            * *bar_size* (``str``) --
              Time interval that data is reported in, i.e. "30 mins" provides 30 minute bars
            * *query_time* (``str``) --
              End (latest, most recent) datetime of the returned historical data, in format "%Y%m%d %H:%M:%S".
              Defaults to the current local time at the moment of the call.
            * *after_hours* (``bool``) --
              If True, data from outside normal market hours for this security are also returned.
        """
        contract = self._get_stock_contract(ticker, **kwargs)
        return self._run_request(
            HistoricalDataResponse(),
            lambda request_id: self._request_historical(request_id, contract,
                                                        "HISTORICAL_VOLATILITY", **kwargs))

    def request_option_trades_history(self, ticker: str, expiration: str,
                                      strike: float, right: str, **kwargs):
        """ Request historical data for option trades for the given options contract
        :param ticker: stock ticker with available options
        :param expiration: expiration of the options contract, in "%Y%m%d" format
        :param strike: strike price of the options contract
        :param right: "C" for call options and "P" for put options
        :raises ValueError: if ``right`` is not "C" or "P"

        :Keyword Arguments:
            * *exchange* (``str``) --
              Exchange to look on, i.e. "SMART"
            * *currency* (``str``) --
              Currency to report information in, i.e. "USD"
            * *duration* (``str``) --
              Amount of time to collect data for, i.e. "5 d" for five days of data.
            * *bar_size* (``str``) --
              Time interval that data is reported in, i.e. "30 mins" provides 30 minute bars
            * *query_time* (``str``) --
              End (latest, most recent) datetime of the returned historical data, in format "%Y%m%d %H:%M:%S".
              Defaults to the current local time at the moment of the call.
            * *after_hours* (``bool``) --
              If True, data from outside normal market hours for this security are also returned.
        """
        contract = self._get_option_contract(ticker, expiration, strike, right,
                                             **kwargs)
        return self._run_request(
            HistoricalTradesResponse(),
            lambda request_id: self._request_historical(request_id, contract, "TRADES", **kwargs))

    def request_option_bidask_history(self, ticker: str, expiration: str,
                                      strike: float, right: str, **kwargs):
        """ Request historical data for option bid and ask for the given options contract
        :param ticker: stock ticker with available options
        :param expiration: expiration of the options contract, in "%Y%m%d" format
        :param strike: strike price of the options contract
        :param right: "C" for call options and "P" for put options
        :raises ValueError: if ``right`` is not "C" or "P"

        :Keyword Arguments:
            * *exchange* (``str``) --
              Exchange to look on, i.e. "SMART"
            * *currency* (``str``) --
              Currency to report information in, i.e. "USD"
            * *duration* (``str``) --
              Amount of time to collect data for, i.e. "5 d" for five days of data.
            * *bar_size* (``str``) --
              Time interval that data is reported in, i.e. "30 mins" provides 30 minute bars
            * *query_time* (``str``) --
              End (latest, most recent) datetime of the returned historical data, in format "%Y%m%d %H:%M:%S".
              Defaults to the current local time at the moment of the call.
            * *after_hours* (``bool``) --
              If True, data from outside normal market hours for this security are also returned.
        """
        contract = self._get_option_contract(ticker, expiration, strike, right,
                                             **kwargs)
        return self._run_request(
            HistoricalBidAskResponse(),
            lambda request_id: self._request_historical(request_id, contract, "BID_ASK", **kwargs))

    def _start_request(self, response: Response) -> int:
        """ Gets a request id for a new request, associates it with the given response object,
        then returns the new request id.
        """
        current_id = self.next_request_id
        self.next_request_id += 1
        self.pending_responses[current_id] = response
        return current_id

    def _run_request(self, response: Response, send_request):
        """ Register ``response`` under a new request id, send the request and block until it completes.

        :param response: response object that collects the callbacks for this request
        :param send_request: callable taking the new request id and sending the request through the EClient
        :return: ``response.table``; possibly incomplete if the wait timed out (a warning is logged then)
        """
        request_id = self._start_request(response)
        try:
            send_request(request_id)
            if not response.finished.wait(timeout=self.timeout):
                logger.warning(f"Request {request_id} timed out after {self.timeout} seconds; "
                               f"returning possibly incomplete results")
        finally:
            # Always unregister (also when sending raised), so pending_responses does not grow with every
            # request; callbacks arriving after this point are reported by _handle_callback as unexpected.
            self.pending_responses.pop(request_id, None)
        return response.table

    @staticmethod
    def _get_stock_contract(ticker: str,
                            exchange="SMART",
                            currency="USD",
                            **_):
        """ Helper function for creating a contract object for use in querying
        data for stocks. The ticker is stored in ``localSymbol``; extra keyword
        arguments are accepted and ignored so callers can forward their kwargs.
        """
        contract = Contract()
        contract.secType = "STK"
        contract.localSymbol = ticker
        contract.exchange = exchange
        contract.currency = currency
        return contract

    @staticmethod
    def _get_option_contract(ticker: str,
                             expiration: str,
                             strike: float,
                             right: str,
                             exchange="SMART",
                             currency="USD",
                             **_):
        """ Helper function for creating a contract object for use in querying
        data for options. Extra keyword arguments are accepted and ignored.

        :raises ValueError: if ``right`` is not "C" or "P"
        """
        if right not in ["C", "P"]:
            raise ValueError(f"Invalid right: {right}")
        contract = Contract()
        contract.secType = "OPT"
        contract.symbol = ticker
        contract.exchange = exchange
        contract.currency = currency
        contract.lastTradeDateOrContractMonth = expiration
        contract.strike = strike
        contract.right = right
        return contract

    def _request_historical(
            self,
            request_id: int,
            contract: Contract,
            data_type: str,
            duration="5 d",
            bar_size="30 mins",
            query_time=None,
            after_hours=False,
            **_):
        """ Helper function used to send a request for historical data

        :param query_time: end datetime in "%Y%m%d %H:%M:%S" format; None (the default) means the current
                           local time, evaluated on every call.
        :raises ValueError: if ``data_type`` or ``bar_size`` is not one of the supported values
        """
        if data_type not in self.REQUEST_OPTIONS_HISTORICAL_TYPE:
            raise ValueError(
                f"Invalid data type '{data_type}'. Valid options: {self.REQUEST_OPTIONS_HISTORICAL_TYPE}"
            )

        if bar_size not in self.REQUEST_OPTIONS_BAR_SIZE:
            raise ValueError(
                f"Invalid bar size '{bar_size}'. Valid options: {self.REQUEST_OPTIONS_BAR_SIZE}"
            )

        if query_time is None:
            # Computed here rather than as a signature default, which would be frozen at class-definition time.
            query_time = datetime.today().strftime("%Y%m%d %H:%M:%S")

        self._app.reqHistoricalData(reqId=request_id,
                                    contract=contract,
                                    endDateTime=query_time,
                                    durationStr=duration,
                                    barSizeSetting=bar_size,
                                    whatToShow=data_type,
                                    useRTH=0 if after_hours else 1,
                                    formatDate=1,
                                    keepUpToDate=False,
                                    chartOptions=[])

    def _handle_callback(self, callback_name, request_id, *args):
        """ Helper function for IB API callbacks to call to notify the pending
        response object of new data. Callbacks for unknown request ids are logged and dropped.
        """
        try:
            response = self.pending_responses[request_id]
        except KeyError:
            logger.error(f"Unexpected callback {callback_name} had invalid "
                         f"request id '{request_id}'")
            return

        response.handle_response(callback_name, *args)

    # ------------------------------------------------------------------------------------------------------------------
    # Callbacks from the IB TWS
    # ------------------------------------------------------------------------------------------------------------------

    def error(self, req_id: TickerId, error_code: int, error_string: str,
              advanced_order_reject_json: str = ""):
        """This event is called when there is an error with the
        communication or when TWS wants to send a message to the client.

        Codes 2000-9999 and 10167 are only logged; any other code is forwarded to the
        pending response for ``req_id`` as an "error" callback.
        ``advanced_order_reject_json`` is accepted (and ignored) because newer ibapi
        versions pass it as an extra argument.
        """
        logger.error(
            f"{error_string} (req_id:{req_id}, error_code:{error_code})")

        if 2000 <= error_code < 10000:  # non-fatal
            pass
        elif error_code == 10167:  # delayed market data instead
            pass
        else:
            logger.error("Ending response since error code is fatal")
            self._handle_callback("error", req_id, error_code, error_string)

    def connectAck(self):
        """ Called once the connection is acknowledged; releases start_app's wait. """
        super().connectAck()
        logger.info("Connection successful.")
        self.connected.set()

    def contractDetails(self, request_id: int, *args):
        """ One contract-details result; forwarded unchanged to the pending response. """
        super().contractDetails(request_id, *args)
        self._handle_callback("contractDetails", request_id, *args)

    def contractDetailsEnd(self, request_id: int):
        """ Marks the end of the contract-details results for ``request_id``. """
        super().contractDetailsEnd(request_id)
        self._handle_callback("contractDetailsEnd", request_id)

    def securityDefinitionOptionParameter(self, request_id: int, *args):
        """ One option-parameter result; forwarded unchanged to the pending response. """
        super().securityDefinitionOptionParameter(request_id, *args)
        self._handle_callback("securityDefinitionOptionParameter", request_id,
                              *args)

    def securityDefinitionOptionParameterEnd(self, request_id: int):
        """ Called when all callbacks to securityDefinitionOptionParameter are
        complete

        reqId - the ID used in the call to securityDefinitionOptionParameter """
        super().securityDefinitionOptionParameterEnd(request_id)
        self._handle_callback("securityDefinitionOptionParameterEnd",
                              request_id)

    def historicalData(self, request_id: int, *args):
        """ returns the requested historical data bars

        The remaining positional arguments are forwarded unchanged to super() and to the
        pending response. Field reference:

        request_id - the request's identifier
        date  - the bar's date and time (either as a yyyymmdd hh:mm:ss formatted
             string or as system time according to the request)
        open  - the bar's open point
        high  - the bar's high point
        low   - the bar's low point
        close - the bar's closing point
        volume - the bar's traded volume if available
        barCount - the number of trades during the bar's timespan (only available
            for TRADES).
        average -   the bar's Weighted Average Price
        """
        super().historicalData(request_id, *args)
        self._handle_callback("historicalData", request_id, *args)

    def historicalDataEnd(self, request_id: int, *args):
        """ Marks the ending of the historical bars reception. Only the request id is forwarded
        to the pending response; the extra arguments go to super() only. """
        super().historicalDataEnd(request_id, *args)
        self._handle_callback("historicalDataEnd", request_id)
コード例 #12
0
    my_wrapper = MyWrapper()

    my_wrapper.current_date = start_date
    my_wrapper.sampling_rate = "5"  # minutes
    my_wrapper.symbol = stock_symbol

    while datetime.datetime.strptime(my_wrapper.current_date,
                                     time_format) > datetime.datetime.strptime(
                                         end_date, time_format):
        try:
            # the duplication of the following three lines allows us to extract continuous data on the same stock
            # another effect of this implementation is at least two rounds of data extraction
            # to sum up, the following code works but sometimes will extract extra data - its OK by me :)
            app = EClient(my_wrapper)
            my_wrapper.did_something = False
            app.connect('127.0.0.1', 7496, clientId=123)
            main()
            if not my_wrapper.did_something:
                # if my_wrapper didn't change the 'just_starting' flag - it is stuck and we need to stop querying the stock
                break
        finally:
            if my_wrapper.the_app_is_down:
                print('It seems like the app is down, closing connection')
                exit()
            else:
                app = EClient(my_wrapper)
                my_wrapper.did_something = False
                app.connect('127.0.0.1', 7496, clientId=123)
                main()
                if not my_wrapper.did_something:
                    # if my_wrapper didn't change the 'just_starting' flag - it is stuck and we need to stop querying the stock