class IBBroker(Broker):
    """
    Interactive Brokers Broker class. Main purpose of this class is to connect to the API of IB broker and send
    the orders. It provides the functionality, which allows to retrieve a.o. the currently open positions and the
    value of the portfolio.

    Every request method follows the same protocol: acquire ``self.lock``, clear ``self.action_event_lock``,
    send the request through ``self.client`` and block until the wrapper sets the event (or a timeout of
    ``self.waiting_time`` seconds elapses).

    Parameters
    -----------
    contract_ticker_mapper: IBContractTickerMapper
        mapper which provides the functionality that allows to map a ticker from any data provider
        (BloombergTicker, PortaraTicker etc.) onto the contract object from the Interactive Brokers API
    clientId: int
        id of the Broker client
    host: str
        IP address
    port: int
        socket port
    """

    def __init__(self, contract_ticker_mapper: IBContractTickerMapper, clientId: int = 0,
                 host: str = "127.0.0.1", port: int = 7497):
        super().__init__(contract_ticker_mapper)
        self.logger = ib_logger.getChild(self.__class__.__name__)

        # Lock that synchronizes entries into the functions and makes sure we have a synchronous communication
        # with the client
        self.lock = Lock()
        # Separate lock serializing whole place_orders() batches; it cannot be self.lock because
        # place_orders() calls methods (get_open_orders, _execute_single_order) that acquire self.lock
        self.orders_placement_lock = Lock()
        self.waiting_time = 30  # expressed in seconds

        # Event that informs us that the wrapper received the response
        self.action_event_lock = Event()
        self.wrapper = IBWrapper(self.action_event_lock, contract_ticker_mapper)
        self.client = EClient(wrapper=self.wrapper)
        self.clientId = clientId
        self.client.connect(host, port, self.clientId)

        # Run the client in the separate thread so that the execution of the program can go on
        # now we will have 3 threads:
        # - thread of the main program
        # - thread of the client
        # - thread of the wrapper
        thread = Thread(target=self.client.run)
        thread.start()

        # This will be released after the client initialises and wrapper receives the nextValidOrderId
        if not self._wait_for_results():
            raise ConnectionError("IB Broker was not initialized correctly")

    def get_portfolio_value(self) -> float:
        """
        Return the NetLiquidation value of the account as stored by the wrapper.

        Raises
        ------
        BrokerException
            if the wrapper does not respond within ``self.waiting_time`` seconds
        """
        with self.lock:
            request_id = 1
            self._reset_action_lock()
            self.client.reqAccountSummary(request_id, 'All', 'NetLiquidation')
            wait_result = self._wait_for_results()
            # The account summary request is cancelled regardless of whether a response arrived
            self.client.cancelAccountSummary(request_id)

            if wait_result:
                return self.wrapper.net_liquidation
            else:
                error_msg = 'Time out while getting portfolio value'
                self.logger.error(error_msg)
                raise BrokerException(error_msg)

    def get_portfolio_tag(self, tag: str) -> float:
        """
        Return the value of an arbitrary account-summary ``tag`` as stored by the wrapper (``tmp_value``).

        Raises
        ------
        BrokerException
            if the wrapper does not respond within ``self.waiting_time`` seconds
        """
        with self.lock:
            request_id = 2
            self._reset_action_lock()
            self.client.reqAccountSummary(request_id, 'All', tag)
            wait_result = self._wait_for_results()
            self.client.cancelAccountSummary(request_id)

            if wait_result:
                return self.wrapper.tmp_value
            else:
                error_msg = 'Time out while getting portfolio tag: {}'.format(tag)
                self.logger.error(error_msg)
                raise BrokerException(error_msg)

    def get_positions(self) -> List[BrokerPosition]:
        """
        Return the positions collected by the wrapper after a ``reqPositions`` call.

        Raises
        ------
        BrokerException
            on timeout
        """
        with self.lock:
            self._reset_action_lock()
            self.wrapper.reset_position_list()
            self.client.reqPositions()

            if self._wait_for_results():
                return self.wrapper.position_list
            else:
                error_msg = 'Time out while getting positions'
                self.logger.error(error_msg)
                raise BrokerException(error_msg)

    def get_liquid_hours(self, contract: IBContract) -> QFDataFrame:
        """
        Returns a QFDataFrame containing information about liquid hours of the given contract.

        The contract's ``tradingHours`` string is split on ``;`` into sessions and each session on ``-`` into
        ``FROM``/``TO`` timestamps (parsed with format ``%Y%m%d:%H%M``). Sessions ending with ``CLOSED`` and
        empty segments are skipped. The frame's ``name`` is set to the contract symbol.

        Raises
        ------
        BrokerException
            on timeout
        """
        with self.lock:
            self._reset_action_lock()
            request_id = 3
            self.client.reqContractDetails(request_id, contract)

            if self._wait_for_results():
                contract_details = self.wrapper.contract_details
                liquid_hours = contract_details.tradingHours.split(";")
                # Empty segments (e.g. from a trailing ';') would otherwise produce a row of missing values
                sessions = [hours.split("-") for hours in liquid_hours
                            if hours and not hours.endswith("CLOSED")]
                liquid_hours_df = QFDataFrame.from_records(sessions, columns=["FROM", "TO"])
                for col in liquid_hours_df.columns:
                    liquid_hours_df[col] = to_datetime(liquid_hours_df[col], format="%Y%m%d:%H%M")

                liquid_hours_df.name = contract_details.contract.symbol
                return liquid_hours_df
            else:
                error_msg = 'Time out while getting contract details'
                self.logger.error(error_msg)
                raise BrokerException(error_msg)

    def get_contract_details(self, contract: IBContract) -> ContractDetails:
        """
        Return the contract details received by the wrapper for ``contract``.

        Raises
        ------
        BrokerException
            on timeout
        """
        with self.lock:
            self._reset_action_lock()
            request_id = 4
            self.client.reqContractDetails(request_id, contract)

            if self._wait_for_results():
                return self.wrapper.contract_details
            else:
                error_msg = 'Time out while getting contract details'
                self.logger.error(error_msg)
                raise BrokerException(error_msg)

    def place_orders(self, orders: Sequence[Order]) -> Sequence[int]:
        """
        Place the given orders one by one and return their IB order ids (in the same order).

        If an order is not confirmed within the placement timeout, its id is recovered by comparing the open
        orders before and after placement.

        Raises
        ------
        BrokerException
            if the id of some order cannot be determined
        """
        with self.orders_placement_lock:
            open_order_ids = {o.id for o in self.get_open_orders()}

            order_ids_list = []
            for order in orders:
                self.logger.info('Placing Order: {}'.format(order))
                order_id = self._execute_single_order(order)
                # Explicit None check: a falsy id (0) is still a valid id and must not trigger the fallback
                if order_id is None:
                    order_id = self._find_newly_added_order_id(order, open_order_ids)

                if order_id is None:
                    error_msg = f"Not able to place order: {order}"
                    self.logger.error(error_msg)
                    raise BrokerException(error_msg)
                else:
                    order_ids_list.append(order_id)

            return order_ids_list

    def cancel_order(self, order_id: int):
        """
        Cancel the order with the given id and wait for the wrapper's confirmation.

        Raises
        ------
        OrderCancellingException
            on timeout
        """
        with self.lock:
            self.logger.info('Cancel order: {}'.format(order_id))
            self._reset_action_lock()
            self.wrapper.set_cancel_order_id(order_id)
            self.client.cancelOrder(order_id)

            if not self._wait_for_results():
                error_msg = 'Time out while cancelling order id {} : \n'.format(order_id)
                self.logger.error(error_msg)
                raise OrderCancellingException(error_msg)

    def get_open_orders(self) -> List[Order]:
        """
        Return the open orders collected by the wrapper after a ``reqOpenOrders`` call.

        Raises
        ------
        BrokerException
            on timeout
        """
        with self.lock:
            self._reset_action_lock()
            self.wrapper.reset_order_list()
            self.client.reqOpenOrders()

            if self._wait_for_results():
                return self.wrapper.order_list
            else:
                error_msg = 'Timeout while getting open orders'
                self.logger.error(error_msg)
                raise BrokerException(error_msg)

    def cancel_all_open_orders(self):
        """
        There is no way to check if cancelling of all orders was finished.
        One can only get open orders and confirm that the list is empty
        """
        with self.lock:
            self.client.reqGlobalCancel()
            self.logger.info('cancel_all_open_orders')

    def stop(self):
        """ Stop the Broker client and disconnect from the interactive brokers. """
        with self.lock:
            self.client.disconnect()
            self.logger.info("Disconnecting from the interactive brokers client")

    def _find_newly_added_order_id(self, order: Order, order_ids_existing_before: Set[int]):
        """
        Given the list of order ids open before placing the given order, try to compute the id of the recently
        placed order.

        Returns the id only when exactly one new open order equals ``order``; otherwise None.
        """
        orders_matching_given_order = {o.id for o in self.get_open_orders() if o == order}
        order_ids = orders_matching_given_order.difference(order_ids_existing_before)
        return next(iter(order_ids)) if len(order_ids) == 1 else None

    def _execute_single_order(self, order) -> Optional[int]:
        """
        Send a single order to IB.

        Returns the order id if the wrapper confirms the order within 10 seconds, otherwise None.
        """
        with self.lock:
            order_id = self.wrapper.next_order_id()

            self._reset_action_lock()
            self.wrapper.set_waiting_order_id(order_id)

            ib_contract = self.contract_ticker_mapper.ticker_to_contract(order.ticker)
            ib_order = self._to_ib_order(order)

            self.client.placeOrder(order_id, ib_contract, ib_order)
            if self._wait_for_results(10):
                return order_id
            return None

    def _wait_for_results(self, waiting_time: Optional[int] = None) -> bool:
        """
        Block until the wrapper sets ``self.action_event_lock`` or the timeout elapses.

        Parameters
        ----------
        waiting_time: Optional[int]
            timeout in seconds; None means ``self.waiting_time``. An explicit 0 is honoured (no waiting).

        Returns
        -------
        bool
            True if the event was set, False on timeout
        """
        if waiting_time is None:
            waiting_time = self.waiting_time
        return self.action_event_lock.wait(waiting_time)

    def _reset_action_lock(self):
        """ threads calling wait() will block until set() is called"""
        self.action_event_lock.clear()

    def _to_ib_order(self, order: Order):
        """Convert a project ``Order`` into an IB order (action, quantity, execution style, time in force)."""
        ib_order = IBOrder()
        ib_order.action = 'BUY' if order.quantity > 0 else 'SELL'
        ib_order.totalQuantity = abs(order.quantity)

        ib_order = self._set_execution_style(ib_order, order.execution_style)

        ib_order.tif = self._map_to_tif_str(order.time_in_force)
        return ib_order

    def _map_to_tif_str(self, time_in_force):
        """
        Map a ``TimeInForce`` value onto IB's string code.

        Raises
        ------
        ValueError
            for unsupported values
        """
        if time_in_force == TimeInForce.GTC:
            tif_str = "GTC"
        elif time_in_force == TimeInForce.DAY:
            tif_str = "DAY"
        elif time_in_force == TimeInForce.OPG:
            tif_str = "OPG"
        else:
            raise ValueError("Not supported TimeInForce {tif:s}".format(tif=str(time_in_force)))

        return tif_str

    def _set_execution_style(self, ib_order, execution_style):
        """
        Set the IB order type (and stop price) from the execution style.

        Raises
        ------
        ValueError
            for unsupported execution styles, instead of sending an order without an order type
        """
        if isinstance(execution_style, MarketOrder):
            ib_order.orderType = "MKT"
        elif isinstance(execution_style, StopOrder):
            ib_order.orderType = "STP"
            ib_order.auxPrice = execution_style.stop_price
        else:
            raise ValueError("Not supported execution style {}".format(execution_style))
        return ib_order
class App:
    """
    Convenience client for the Interactive Brokers API built on a bare ``EWrapper``/``EClient`` pair.

    Wrapper callbacks are defined as nested functions inside the request methods and installed with
    :meth:`wrap`. Blocking requests wait on a local ``queue.Queue`` until the matching ``...End`` callback
    fires. Asynchronous requests push ``(payload, reqId)`` tuples onto the queues created in :meth:`resetData`,
    and ``reqId_map`` maps request ids back to contract symbols.
    """

    def __init__(self, ip_addr='127.0.0.1', port=7497, clientId=1):
        # Wrapper Methods
        def nextValidId(reqId):
            q.put(reqId)

        def connectionClosed():
            print('CONNECTION HAS CLOSED')

        def error(reqId, errorCode: int, errorString: str, advancedOrderRejectJson=""):
            # advancedOrderRejectJson is accepted and ignored; presumably newer ibapi releases pass this extra
            # argument - TODO confirm against the installed ibapi version
            # Codes 2104 and 2106 are not printed
            if errorCode != 2104 and errorCode != 2106:
                print("Error. Id: ", reqId, " Code: ", errorCode, " Msg: ", errorString)
            if errorCode == 10167 and 'Displaying delayed market data' in errorString:
                # Falling back to delayed data is not treated as a data error
                pass
            elif reqId != -1:
                self.data_errors_q.put((errorString, reqId))
            if 'pacing violation' in errorString:
                self.slowdown = True
        # Wrapper Methods End

        self.wrapper = EWrapper()
        self.client = EClient(self.wrapper)
        self.ip_addr = ip_addr
        self.my_port = port
        self.my_clientId = clientId
        self.resetData()

        # Install the callbacks before connecting, so any callback invoked during connect() already uses them
        q = queue.Queue()
        self.wrap(error)
        self.wrap(connectionClosed)
        self.wrap(nextValidId)

        self.client.connect(ip_addr, port, clientId)
        if not self.client.isConnected():
            # Without a connection nextValidId never arrives and q.get() below would block forever
            raise ConnectionError(f"Could not connect to IB at {ip_addr}:{port} (clientId={clientId})")

        self._thread = Thread(target=self.client.run)
        self._thread.start()

        # Once we get a reqID, we know we can start
        self._reqId = q.get()

    def wrap(self, method):
        """
        Install ``method`` as the wrapper callback of the same name.

        The callback is attached to this App's own wrapper instance rather than to the ``EWrapper`` class, so
        callbacks installed here do not leak into other wrapper instances in the same process.
        """
        name = method.__name__.split('.')[-1]
        setattr(self.wrapper, name, method)

    def getReqId(self):
        """Return the next request id and advance the counter."""
        reqId = self._reqId
        self._reqId += 1
        return reqId

    def resetData(self):
        """(Re)create all result queues, the reqId -> symbol map and the pacing flag."""
        # Historical Data
        self.hist_data_q = queue.Queue()
        self.hist_data_dict_q = {}

        # Fundamental Data
        self.fundamental_data_q = queue.Queue()
        self.slowdown = False

        # Price Data
        self.price_queue = queue.Queue()
        self.close_price_queue = queue.Queue()

        # Dict to map reqId w/ Symbols
        self.reqId_map = {}

        # Errors
        self.data_errors_q = queue.Queue()

    ### Client Functions (with wrapper methods as nested functions) ###

    def getAccounts(self):
        """Request the managed accounts and block until the accounts list string is received."""
        def managedAccounts(accountsList):
            q.put(accountsList)

        q = queue.Queue()
        self.wrap(managedAccounts)
        self.client.reqManagedAccts()
        return q.get()

    def getPositions(self, account):
        """Return the positions of ``account`` as a DataFrame (see :meth:`getDFPositions`); blocks until done."""
        def positionMulti(reqId: int, account: str, modelCode: str, contract: Contract, pos: float,
                          avgCost: float):
            positions.append((contract, pos, avgCost))

        def positionMultiEnd(reqId: int):
            q.put(None)

        positions = []
        q = queue.Queue()
        self.wrap(positionMulti)
        self.wrap(positionMultiEnd)
        self.client.reqPositionsMulti(self.getReqId(), account, "")
        q.get()
        return self.getDFPositions(positions)

    def getOrders(self):
        """Return the open orders as a DataFrame (see :meth:`getDFOrders`); blocks until done."""
        def openOrder(orderId, contract, order, orderState):
            orders.append((contract, order, orderState))

        def openOrderEnd():
            q.put(None)

        orders = []
        q = queue.Queue()
        self.wrap(openOrder)
        self.wrap(openOrderEnd)
        self.client.reqOpenOrders()
        q.get()
        return self.getDFOrders(orders)

    def sellPosition(self, ticker, secType, orders, positions):
        """
        Place a market SELL order for the whole long position in ``ticker``/``secType``.

        Does nothing when there is no matching position, the position is not positive, or an identical order
        is already pending (see :meth:`duplicateOrder`). With several matching rows, the first one is used.
        """
        pos = self.getPosDetails(ticker, secType, positions)
        if pos.empty:
            return
        if pos.shape[0] > 1:
            print('Multiple matching positions, defaulting to first record')
        quantity = int(pos['pos'].iloc[0])

        if quantity > 0:
            contract = self.createContract(ticker, secType, "USD", "SMART")
            order = Order()
            order.action = "SELL"
            order.orderType = "MKT"
            order.totalQuantity = quantity
            if not self.duplicateOrder(ticker, secType, order, orders):
                print('Placing SELL order for: ' + ticker)
                self.place_order(contract, order)

    def getHistoricalData(self, contract, duration):
        '''
        Requests historical daily prices (MIDPOINT, 1 day bars) ending now.

        Results are put onto ``self.hist_data_q`` as ``(DataFrame(date, price), reqId)`` once all bars arrive.

        Input:
        duration: Duration string
            e.g. "1 Y", "6 M", "3 D", etc
        '''
        def historicalData(reqId: int, bar):
            self.hist_data_dict_q[reqId].put(bar)

        def historicalDataEnd(reqId: int, start: str, end: str):
            dates = []
            prices = []
            while not self.hist_data_dict_q[reqId].empty():
                bar = self.hist_data_dict_q[reqId].get()
                dates.append(bar.date)
                prices.append(bar.close)
            data = {'date': dates, 'price': prices}
            df = pandas.DataFrame(data=data)
            self.hist_data_q.put((df, reqId))

        self.wrap(historicalData)
        self.wrap(historicalDataEnd)
        queryTime = datetime.datetime.today().strftime("%Y%m%d %H:%M:%S")
        reqId = self.getReqId()
        # Register per-request state BEFORE sending: callbacks run on the client thread and may arrive before
        # reqHistoricalData returns
        self.reqId_map[reqId] = contract.symbol
        self.hist_data_dict_q[reqId] = queue.Queue()
        self.client.reqHistoricalData(reqId, contract, queryTime, duration, "1 day", "MIDPOINT", 1, 1, False,
                                      [])

    def getPrice(self, contract):
        '''
        Requests last trade price (snapshot).

        Last / delayed-last prices go to ``self.price_queue`` and (delayed) close prices to
        ``self.close_price_queue``, as ``(price, reqId)``.
        '''
        def tickPrice(reqId, tickType, price: float, attrib):
            if price == -1:
                # print("No Price Data currently available")
                pass
            # Last price
            elif tickType == 4:
                self.price_queue.put((price, reqId))
            # Previous Close Price
            elif tickType == 9:
                self.close_price_queue.put((price, reqId))
            # Delayed Last Price
            elif tickType == 68:
                self.price_queue.put((price, reqId))
            # Delayed Close Price
            elif tickType == 75:
                self.close_price_queue.put((price, reqId))

        self.wrap(tickPrice)
        if contract.currency != 'USD':
            self.client.reqMarketDataType(3)
        reqId = self.getReqId()
        # Register before sending so the symbol is known when the first tick arrives
        self.reqId_map[reqId] = contract.symbol
        self.client.reqMktData(reqId, contract, "", True, False, [])
        if contract.currency != 'USD':
            # Go back to live/frozen
            self.client.reqMarketDataType(2)

    def findContracts(self, sybmol):
        """Request symbols matching ``sybmol`` (parameter name kept as-is for keyword callers)."""
        self.client.reqMatchingSymbols(self.getReqId(), sybmol)

    def place_order(self, contract, order):
        """Send ``order`` for ``contract`` using the next request id as order id."""
        self.client.placeOrder(self.getReqId(), contract, order)

    def getContractDetails(self, symbol, secType, currency=None, exchange=None):
        """Return the list of contract details matching the given fields; blocks until done."""
        def contractDetails(reqId: int, contractDetails):
            contract_details.append(contractDetails)

        def contractDetailsEnd(reqId: int):
            q.put(None)

        contract_details = []
        q = queue.Queue()
        self.wrap(contractDetails)
        self.wrap(contractDetailsEnd)

        contract = Contract()
        contract.symbol = symbol
        contract.secType = secType
        if currency is not None:
            contract.currency = currency
        if exchange is not None:
            contract.exchange = exchange

        self.client.reqContractDetails(self.getReqId(), contract)
        q.get()
        return contract_details

    def getYield(self, contract, data_type=3):
        """
        Return the first YIELD value (divided by 100) from generic tick "258", or 0 if none was found.

        Blocks until a tick string containing ';' arrives.
        """
        def tickString(reqId, tickType, value: str):
            for val in value.split(';'):
                if 'YIELD' in val:
                    div.append(float(val.split('=')[1]) / 100)
            # If ';' in the response, then we know we got the data
            if ';' in value:
                self.client.cancelMktData(reqId)
                q.put(None)

        div = []
        q = queue.Queue()
        self.wrap(tickString)
        # Switch to live (1) frozen (2) delayed (3) delayed frozen (4).
        # MarketDataTypeEnum.DELAYED
        if contract.currency != 'USD':
            self.client.reqMarketDataType(data_type)
        self.client.reqMktData(self.getReqId(), contract, "258", False, False, [])
        if contract.currency != 'USD':
            # Go back to live/frozen
            self.client.reqMarketDataType(2)
        q.get()
        if div:
            return div[0]
        else:
            return 0

    def getFinStatements(self, contract, data_type):
        """Request fundamental data; results go to ``self.fundamental_data_q`` as ``(data, reqId)``."""
        def fundamentalData(reqId, data: str):
            self.fundamental_data_q.put((data, reqId))
            self.client.cancelFundamentalData(reqId)

        self.wrap(fundamentalData)
        reqId = self.getReqId()
        # Register before sending so the symbol is known when the response arrives
        self.reqId_map[reqId] = contract.symbol
        self.client.reqFundamentalData(reqId, contract, data_type, [])

    ### Client Functions End ###

    ### HELPER FUNCTIONS ###

    def getDFPositions(self, positions):
        """Convert ``(contract, pos, avgCost)`` tuples into a DataFrame with symbol/secType/currency/pos/avg_cost."""
        symbols = []
        types = []
        currencies = []
        sizes = []
        avg_costs = []
        for contract, size, cost in positions:
            symbols.append(contract.symbol)
            types.append(contract.secType)
            currencies.append(contract.currency)
            sizes.append(size)
            avg_costs.append(cost)

        data = {
            'symbol': symbols,
            'secType': types,
            'currency': currencies,
            'pos': sizes,
            'avg_cost': avg_costs
        }
        return pandas.DataFrame(data=data)

    def getDFOrders(self, orders):
        """Convert ``(contract, order, orderState)`` tuples into a DataFrame with symbol/secType/action/quantity/status."""
        symbols = []
        types = []
        actions = []
        quantities = []
        status = []
        for contract, order, orderState in orders:
            symbols.append(contract.symbol)
            types.append(contract.secType)
            actions.append(order.action)
            quantities.append(order.totalQuantity)
            status.append(orderState.status)

        data = {
            'symbol': symbols,
            'secType': types,
            'action': actions,
            'quantity': quantities,
            'status': status
        }
        return pandas.DataFrame(data=data)

    def createContract(self, symbol, secType, currency, exchange, primaryExchange=None, right=None,
                       strike=None, expiry=None):
        """
        Build a ``Contract``. ``symbol`` may be a ``[symbol, currency]`` list (foreign stocks), in which case
        the list's currency overrides ``currency``.
        """
        contract = Contract()
        if type(symbol) is list:
            # Foreign stocks
            print(symbol[0], symbol[1])
            contract.symbol = symbol[0]
            contract.currency = symbol[1]
        else:
            contract.symbol = symbol
            contract.currency = currency

        if primaryExchange:
            contract.primaryExchange = primaryExchange

        contract.secType = secType
        contract.exchange = exchange

        if right:
            contract.right = right
        if strike:
            contract.strike = strike
        if expiry:
            contract.lastTradeDateOrContractMonth = expiry

        return contract

    def createOptionContract(self, symbol, currency, exchange, expiry="201901", strike=150, right="C",
                             multiplier="100"):
        """
        Build an option ``Contract``.

        The previously hard-coded expiry/strike/right/multiplier are now parameters whose defaults keep the
        old values.
        """
        contract = Contract()
        contract.symbol = symbol
        contract.secType = "OPT"
        contract.exchange = exchange
        contract.currency = currency
        contract.lastTradeDateOrContractMonth = expiry
        contract.strike = strike
        contract.right = right
        contract.multiplier = multiplier
        return contract

    def portfolioCheck(self, ticker, positions):
        '''
        Output:
        Boolean
            True if ticker is in portfolio, is a stock, and position is > 0
            False otherwise
        '''
        # Exact comparison: tickers such as "BRK.B" must not be interpreted as regex patterns
        matching_ticker_df = positions[positions['symbol'] == ticker]
        matching_type_df = matching_ticker_df[matching_ticker_df['secType'] == "STK"]
        return (matching_type_df['pos'] > 0).any()

    def calcOrderSize(self, price, size):
        '''
        Determines how large the order should be

        Input:
        price: Current share price (float)
        size: How large we want our order to be, in dollar terms (int)

        Output:
        int: number of shares to buy
            Will default to 1 if price > size
        '''
        if price > size:
            return 1
        else:
            return int(size / price)

    def getPosDetails(self, ticker, secType, positions):
        '''
        Returns a dataframe of position details given a ticker and security type (exact matches)
        '''
        return positions[(positions['symbol'] == ticker) & (positions['secType'] == secType)]

    def duplicateOrder(self, ticker, secType, order, orders):
        """True if ``orders`` already holds a PreSubmitted order with the same symbol, type, action and quantity."""
        if not orders.empty:
            return ((orders['symbol'] == ticker) &
                    (orders['secType'] == secType) &
                    (orders['action'] == order.action) &
                    (orders['quantity'] == order.totalQuantity) &
                    (orders['status'] == 'PreSubmitted')).any()
        else:
            return False

    def parseFinancials(self, data, quarterly=False):
        """
        Parse a ReportFinancialStatements XML document.

        Returns ``(qtr1, qtr2, qtr3, qtr4)`` when ``quarterly`` else ``(current_annual, prev_annual)``; each entry
        is a ``{coa_name: value}`` dict for an accepted report, or None. The tuple shape is the same on every
        path, including errors, so callers can always unpack it.
        """
        accepted_reports = ["10-K", "10-Q", "Interim Report", "ARS"]
        n_slots = 4 if quarterly else 2
        empty_result = (None,) * n_slots

        fundamental_data = xmltodict.parse(data)
        if fundamental_data['ReportFinancialStatements']['FinancialStatements'] is None:
            print('No Fundamental Data')
            return empty_result

        try:
            statements = fundamental_data['ReportFinancialStatements']['FinancialStatements']
            coaMap = statements['COAMap']
            periods_key = 'InterimPeriods' if quarterly else 'AnnualPeriods'
            periods = statements[periods_key]['FiscalPeriod']
        except (KeyError, TypeError):
            print('ERROR with fundamental data')
            print(fundamental_data)
            return empty_result

        # xmltodict yields a single dict instead of a list when there is only one period
        if not isinstance(periods, list):
            periods = [periods]

        results = []
        for period in periods:
            if len(results) == n_slots:
                break
            parsed = self._parseFiscalPeriod(period, accepted_reports, coaMap)
            if parsed is not None:
                results.append(parsed)

        results += [None] * (n_slots - len(results))
        return tuple(results)

    def _parseFiscalPeriod(self, period, accepted_reports, coaMap):
        """
        Parse one FiscalPeriod entry into a ``{coa_name: value}`` dict.

        Returns None when the period's statements are not a list or the report source is not accepted.
        """
        statements = period['Statement']
        if not isinstance(statements, list) or \
                statements[0]['FPHeader']['Source']['#text'] not in accepted_reports:
            return None

        parsed = {}
        for item in statements:
            # item['@Type'] is the statement kind (INC, BAL or CAS)
            line_items = item['lineItem']
            # A statement with a single line item comes back as a dict rather than a list
            if not isinstance(line_items, list):
                line_items = [line_items]
            for i in line_items:
                try:
                    parsed[coaCodes.coaCode_map[i['@coaCode']]] = float(i['#text'])
                except KeyError:
                    print('Could not find coaCode!!!')
                    print(i['@coaCode'])
                    print(coaMap)
        return parsed
class IbApi(EWrapper):
    """
    EWrapper implementation that owns an EClient and forwards IB TWS callbacks to a ``BaseGateway``.

    Market data, orders, trades, positions, accounts and contracts received from TWS are converted into the
    gateway's data objects and pushed through ``self.gateway.on_*``. Contract data is cached on disk in a
    shelve file located at ``data_filepath``.
    """

    data_filename = "ib_contract_data.db"
    data_filepath = str(get_file_path(data_filename))

    local_tz = get_localzone()

    def __init__(self, gateway: BaseGateway):
        """Create the API object bound to ``gateway``; no connection is opened until ``connect`` is called."""
        super().__init__()

        self.gateway = gateway
        self.gateway_name = gateway.gateway_name

        self.status = False  # True between connectAck and connectionClosed/close

        self.reqid = 0
        self.orderid = 0
        self.clientid = 0
        self.history_reqid = 0  # request id of the pending query_history call
        self.account = ""
        self.ticks = {}  # reqid -> TickData buffer
        self.orders = {}  # str(orderId) -> OrderData
        self.accounts = {}  # "account.currency" -> AccountData
        self.contracts = {}  # vt_symbol -> ContractData

        self.tick_exchange = {}  # reqid -> Exchange of the subscribed tick
        self.subscribed = {}  # vt_symbol -> SubscribeRequest, replayed when error code 2104 arrives
        self.data_ready = False

        self.history_req = None
        self.history_condition = Condition()
        self.history_buf = []

        self.client = EClient(self)

    def _localize(self, dt: datetime) -> datetime:
        """
        Attach ``self.local_tz`` to a naive datetime.

        Uses ``localize`` when the timezone object provides it (pytz-style); otherwise falls back to
        ``replace(tzinfo=...)``, which is correct for zoneinfo-style timezones.
        """
        localize = getattr(self.local_tz, "localize", None)
        if localize is not None:
            return localize(dt)
        return dt.replace(tzinfo=self.local_tz)

    def connectAck(self):  # pylint: disable=invalid-name
        """
        Callback when connection is established.
        """
        self.status = True
        self.gateway.write_log("IB TWS连接成功")  # "IB TWS connected"

        self.load_contract_data()

        self.data_ready = False

    def connectionClosed(self):  # pylint: disable=invalid-name
        """
        Callback when connection is closed.
        """
        self.status = False
        self.gateway.write_log("IB TWS连接断开")  # "IB TWS disconnected"

    def nextValidId(self, orderId: int):  # pylint: disable=invalid-name
        """
        Callback of next valid orderid. Only the first id received is used as the local order counter.
        """
        super().nextValidId(orderId)

        if not self.orderid:
            self.orderid = orderId

    def currentTime(self, time: int):  # pylint: disable=invalid-name
        """
        Callback of current server time of IB.
        """
        super().currentTime(time)

        dt = datetime.fromtimestamp(time)
        time_string = dt.strftime("%Y-%m-%d %H:%M:%S.%f")

        msg = f"服务器时间: {time_string}"  # "Server time: ..."
        self.gateway.write_log(msg)

    def error(self, reqId: TickerId, errorCode: int, errorString: str,
              advancedOrderRejectJson: str = ""):  # pylint: disable=invalid-name
        """
        Callback of error caused by specific request.

        An error for the pending history request wakes up ``query_history``. Code 2104 (market data server
        connected) triggers a server-time query and re-subscription of everything in ``self.subscribed``.
        ``advancedOrderRejectJson`` is accepted and ignored; presumably newer ibapi releases pass it - TODO
        confirm against the installed ibapi version.
        """
        super().error(reqId, errorCode, errorString)

        if reqId == self.history_reqid:
            with self.history_condition:
                self.history_condition.notify()

        msg = f"信息通知,代码:{errorCode},内容: {errorString}"  # "Notice, code: ..., content: ..."
        self.gateway.write_log(msg)

        # Market data server is connected
        if errorCode == 2104 and not self.data_ready:
            self.data_ready = True

            self.client.reqCurrentTime()

            reqs = list(self.subscribed.values())
            self.subscribed.clear()
            for req in reqs:
                self.subscribe(req)

    def tickPrice(  # pylint: disable=invalid-name
            self, reqId: TickerId, tickType: TickType, price: float, attrib: TickAttrib):
        """
        Callback of tick price update.
        """
        super().tickPrice(reqId, tickType, price, attrib)

        if tickType not in TICKFIELD_IB2VT:
            return

        tick = self.ticks[reqId]
        name = TICKFIELD_IB2VT[tickType]
        setattr(tick, name, price)

        # Update name into tick data.
        contract = self.contracts.get(tick.vt_symbol, None)
        if contract:
            tick.name = contract.name

        # Forex of IDEALPRO and Spot Commodity has no tick time and last price.
        # We need to calculate locally.
        exchange = self.tick_exchange[reqId]
        if exchange is Exchange.IDEALPRO or "CMDTY" in tick.symbol:
            tick.last_price = (tick.bid_price_1 + tick.ask_price_1) / 2
            tick.datetime = datetime.now(self.local_tz)
        self.gateway.on_tick(copy(tick))

    def tickSize(self, reqId: TickerId, tickType: TickType, size: int):  # pylint: disable=invalid-name
        """
        Callback of tick volume update.
        """
        super().tickSize(reqId, tickType, size)

        if tickType not in TICKFIELD_IB2VT:
            return

        tick = self.ticks[reqId]
        name = TICKFIELD_IB2VT[tickType]
        setattr(tick, name, size)

        self.gateway.on_tick(copy(tick))

    def tickString(self, reqId: TickerId, tickType: TickType, value: str):  # pylint: disable=invalid-name
        """
        Callback of tick string update. Only LAST_TIMESTAMP (an epoch-seconds string) is used.
        """
        super().tickString(reqId, tickType, value)

        if tickType != TickTypeEnum.LAST_TIMESTAMP:
            return

        tick = self.ticks[reqId]
        # Converting directly into local_tz works for pytz and zoneinfo alike and is correct across DST changes
        tick.datetime = datetime.fromtimestamp(int(value), self.local_tz)

        self.gateway.on_tick(copy(tick))

    def orderStatus(  # pylint: disable=invalid-name
            self,
            orderId: OrderId,
            status: str,
            filled: float,
            remaining: float,
            avgFillPrice: float,
            permId: int,
            parentId: int,
            lastFillPrice: float,
            clientId: int,
            whyHeld: str,
            mktCapPrice: float,
    ):
        """
        Callback of order status update.
        """
        super().orderStatus(
            orderId,
            status,
            filled,
            remaining,
            avgFillPrice,
            permId,
            parentId,
            lastFillPrice,
            clientId,
            whyHeld,
            mktCapPrice,
        )

        orderid = str(orderId)
        order = self.orders.get(orderid, None)
        if not order:
            return

        order.traded = filled

        # To filter PendingCancel status
        order_status = STATUS_IB2VT.get(status, None)
        if order_status:
            order.status = order_status

        self.gateway.on_order(copy(order))

    def openOrder(  # pylint: disable=invalid-name
            self,
            orderId: OrderId,
            ib_contract: Contract,
            ib_order: Order,
            orderState: OrderState,
    ):
        """
        Callback when opening new order.
        """
        super().openOrder(orderId, ib_contract, ib_order, orderState)

        orderid = str(orderId)
        order = OrderData(
            symbol=generate_symbol(ib_contract),
            exchange=EXCHANGE_IB2VT.get(ib_contract.exchange, Exchange.SMART),
            type=ORDERTYPE_IB2VT[ib_order.orderType],
            orderid=orderid,
            direction=DIRECTION_IB2VT[ib_order.action],
            volume=ib_order.totalQuantity,
            gateway_name=self.gateway_name,
        )

        if order.type == OrderType.LIMIT:
            order.price = ib_order.lmtPrice
        elif order.type == OrderType.STOP:
            order.price = ib_order.auxPrice

        self.orders[orderid] = order
        self.gateway.on_order(copy(order))

    def updateAccountValue(  # pylint: disable=invalid-name
            self, key: str, val: str, currency: str, accountName: str):
        """
        Callback of account update. Only keys in ACCOUNTFIELD_IB2VT with a currency are stored.
        """
        super().updateAccountValue(key, val, currency, accountName)

        if not currency or key not in ACCOUNTFIELD_IB2VT:
            return

        accountid = f"{accountName}.{currency}"
        account = self.accounts.get(accountid, None)
        if not account:
            account = AccountData(accountid=accountid, gateway_name=self.gateway_name)
            self.accounts[accountid] = account

        name = ACCOUNTFIELD_IB2VT[key]
        setattr(account, name, float(val))

    def updatePortfolio(  # pylint: disable=invalid-name
            self,
            contract: Contract,
            position: float,
            marketPrice: float,
            marketValue: float,
            averageCost: float,
            unrealizedPNL: float,
            realizedPNL: float,
            accountName: str,
    ):
        """
        Callback of position update.
        """
        super().updatePortfolio(
            contract,
            position,
            marketPrice,
            marketValue,
            averageCost,
            unrealizedPNL,
            realizedPNL,
            accountName,
        )

        if contract.exchange:
            exchange = EXCHANGE_IB2VT.get(contract.exchange, None)
        elif contract.primaryExchange:
            exchange = EXCHANGE_IB2VT.get(contract.primaryExchange, None)
        else:
            exchange = Exchange.SMART  # Use smart routing for default

        if not exchange:
            msg = f"存在不支持的交易所持仓{generate_symbol(contract)} {contract.exchange} {contract.primaryExchange}"
            self.gateway.write_log(msg)  # "Position on unsupported exchange ..."
            return

        try:
            ib_size = int(contract.multiplier)
        except (ValueError, TypeError):
            ib_size = 1
        # A zero multiplier would raise ZeroDivisionError inside the callback
        if not ib_size:
            ib_size = 1
        price = averageCost / ib_size

        pos = PositionData(
            symbol=generate_symbol(contract),
            exchange=exchange,
            direction=Direction.NET,
            volume=position,
            price=price,
            pnl=unrealizedPNL,
            gateway_name=self.gateway_name,
        )
        self.gateway.on_position(pos)

    def updateAccountTime(self, timeStamp: str):  # pylint: disable=invalid-name
        """
        Callback of account update time. Pushes every known account to the gateway.
        """
        super().updateAccountTime(timeStamp)
        for account in self.accounts.values():
            self.gateway.on_account(copy(account))

    def contractDetails(self, reqId: int, contractDetails: ContractDetails):  # pylint: disable=invalid-name
        """
        Callback of contract data update. New contracts are pushed to the gateway and the cache is saved.
        """
        super().contractDetails(reqId, contractDetails)

        # Generate symbol from ib contract details
        ib_contract = contractDetails.contract
        if not ib_contract.multiplier:
            ib_contract.multiplier = 1

        symbol = generate_symbol(ib_contract)

        # Generate contract
        contract = ContractData(
            symbol=symbol,
            exchange=EXCHANGE_IB2VT[ib_contract.exchange],
            name=contractDetails.longName,
            product=PRODUCT_IB2VT[ib_contract.secType],
            size=int(ib_contract.multiplier),
            pricetick=contractDetails.minTick,
            net_position=True,
            history_data=True,
            stop_supported=True,
            gateway_name=self.gateway_name,
        )

        if contract.vt_symbol not in self.contracts:
            self.gateway.on_contract(contract)

            self.contracts[contract.vt_symbol] = contract
            self.save_contract_data()

    def execDetails(self, reqId: int, contract: Contract, execution: Execution):  # pylint: disable=invalid-name
        """
        Callback of trade data update.
        """
        super().execDetails(reqId, contract, execution)

        dt = datetime.strptime(execution.time, "%Y%m%d %H:%M:%S")
        dt = self._localize(dt)

        trade = TradeData(
            symbol=generate_symbol(contract),
            exchange=EXCHANGE_IB2VT.get(contract.exchange, Exchange.SMART),
            orderid=str(execution.orderId),
            tradeid=str(execution.execId),
            direction=DIRECTION_IB2VT[execution.side],
            price=execution.price,
            volume=execution.shares,
            datetime=dt,
            gateway_name=self.gateway_name,
        )

        self.gateway.on_trade(trade)

    def managedAccounts(self, accountsList: str):
        """
        Callback of all sub accountid. If no account was configured, the last non-empty code in the list is
        used, then account updates are requested for it.
        """
        super().managedAccounts(accountsList)

        if not self.account:
            for account_code in accountsList.split(","):
                # Skip empty codes (e.g. from a trailing comma) so an empty account is never selected
                if account_code:
                    self.account = account_code

        self.gateway.write_log(f"当前使用的交易账号为{self.account}")  # "Trading account in use: ..."
        self.client.reqAccountUpdates(True, self.account)

    def historicalData(self, reqId: int, ib_bar: IbBarData):
        """
        Callback of history data update. Bars are buffered until historicalDataEnd.
        """
        # Date-only bar timestamps ("%Y%m%d") are accepted as well as date-and-time ones
        if len(ib_bar.date) <= 8:
            dt = datetime.strptime(ib_bar.date, "%Y%m%d")
        else:
            dt = datetime.strptime(ib_bar.date, "%Y%m%d %H:%M:%S")
        dt = self._localize(dt)

        bar = BarData(
            symbol=self.history_req.symbol,
            exchange=self.history_req.exchange,
            datetime=dt,
            interval=self.history_req.interval,
            volume=ib_bar.volume,
            open_price=ib_bar.open,
            high_price=ib_bar.high,
            low_price=ib_bar.low,
            close_price=ib_bar.close,
            gateway_name=self.gateway_name
        )

        self.history_buf.append(bar)

    def historicalDataEnd(self, reqId: int, start: str, end: str):
        """
        Callback of history data finished. Wakes up the thread blocked in query_history.
        """
        with self.history_condition:
            self.history_condition.notify()

    def connect(self, host: str, port: int, clientid: int, account: str):
        """
        Connect to TWS and start the client loop in a new thread (no-op when already connected).
        """
        if self.status:
            return

        self.host = host
        self.port = port
        self.clientid = clientid
        self.account = account
        self.client.connect(host, port, clientid)
        self.thread = Thread(target=self.client.run)
        self.thread.start()

    def check_connection(self):
        """Reconnect with the stored connection settings when the client is no longer connected."""
        if self.client.isConnected():
            return

        if self.status:
            self.close()

        self.client.connect(self.host, self.port, self.clientid)

        self.thread = Thread(target=self.client.run)
        self.thread.start()

    def close(self):
        """
        Disconnect to TWS.
        """
        if not self.status:
            return

        self.status = False
        self.client.disconnect()

    def subscribe(self, req: SubscribeRequest):
        """
        Subscribe tick data update.
        """
        if not self.status:
            return

        if req.exchange not in EXCHANGE_VT2IB:
            self.gateway.write_log(f"不支持的交易所{req.exchange}")  # "Unsupported exchange ..."
            return

        # Filter duplicate subscribe
        if req.vt_symbol in self.subscribed:
            return
        self.subscribed[req.vt_symbol] = req

        # Extract ib contract detail
        ib_contract = generate_ib_contract(req.symbol, req.exchange)
        if not ib_contract:
            self.gateway.write_log("代码解析失败,请检查格式是否正确")  # "Symbol parsing failed, check the format"
            return

        # Get contract data from TWS.
        self.reqid += 1
        self.client.reqContractDetails(self.reqid, ib_contract)

        # Subscribe tick data and create tick object buffer.
        self.reqid += 1
        self.client.reqMktData(self.reqid, ib_contract, "", False, False, [])

        tick = TickData(
            symbol=req.symbol,
            exchange=req.exchange,
            datetime=datetime.now(self.local_tz),
            gateway_name=self.gateway_name,
        )
        self.ticks[self.reqid] = tick
        self.tick_exchange[self.reqid] = req.exchange

    def send_order(self, req: OrderRequest):
        """
        Send a new order. Returns the vt_orderid, or "" when the order could not be sent.
        """
        if not self.status:
            return ""

        if req.exchange not in EXCHANGE_VT2IB:
            self.gateway.write_log(f"不支持的交易所:{req.exchange}")  # "Unsupported exchange: ..."
            return ""

        if req.type not in ORDERTYPE_VT2IB:
            self.gateway.write_log(f"不支持的价格类型:{req.type}")  # "Unsupported price type: ..."
            return ""

        self.orderid += 1

        ib_contract = generate_ib_contract(req.symbol, req.exchange)
        if not ib_contract:
            return ""

        ib_order = Order()
        ib_order.orderId = self.orderid
        ib_order.clientId = self.clientid
        ib_order.action = DIRECTION_VT2IB[req.direction]
        ib_order.orderType = ORDERTYPE_VT2IB[req.type]
        ib_order.totalQuantity = req.volume
        ib_order.account = self.account

        if req.type == OrderType.LIMIT:
            ib_order.lmtPrice = req.price
        elif req.type == OrderType.STOP:
            ib_order.auxPrice = req.price

        self.client.placeOrder(self.orderid, ib_contract, ib_order)
        self.client.reqIds(1)

        order = req.create_order_data(str(self.orderid), self.gateway_name)
        self.gateway.on_order(order)
        return order.vt_orderid

    def cancel_order(self, req: CancelRequest):
        """
        Cancel an existing order.
        """
        if not self.status:
            return

        self.client.cancelOrder(int(req.orderid))

    def query_history(self, req: HistoryRequest):
        """
        Request historical bars for ``req`` and block until they arrive (or an error for the request).

        Returns the list of BarData collected by ``historicalData``.
        """
        self.history_req = req

        self.reqid += 1

        ib_contract = generate_ib_contract(req.symbol, req.exchange)

        if req.end:
            end = req.end
            end_str = end.strftime("%Y%m%d %H:%M:%S")
        else:
            end = datetime.now(self.local_tz)
            end_str = ""

        delta = end - req.start
        # IB only provides 6-month data; request at least one day so the duration string is never "0 D"
        days = max(1, min(delta.days, 180))
        duration = f"{days} D"
        bar_size = INTERVAL_VT2IB[req.interval]

        if req.exchange == Exchange.IDEALPRO:
            bar_type = "MIDPOINT"
        else:
            bar_type = "TRADES"

        self.history_reqid = self.reqid
        # Hold the condition's lock while sending the request: the notify from historicalDataEnd/error (client
        # thread) cannot happen before wait() releases the lock, so a fast response can no longer be lost
        with self.history_condition:
            self.client.reqHistoricalData(
                self.reqid,
                ib_contract,
                end_str,
                duration,
                bar_size,
                bar_type,
                0,
                1,
                False,
                []
            )
            # Wait for async data return
            self.history_condition.wait()

        history = self.history_buf
        self.history_buf = []  # Create new buffer list
        self.history_req = None

        return history

    def load_contract_data(self):
        """Load cached contracts from the shelve file and push them to the gateway."""
        with shelve.open(self.data_filepath) as f:
            self.contracts = f.get("contracts", {})

        for contract in self.contracts.values():
            self.gateway.on_contract(contract)

        self.gateway.write_log("本地缓存合约信息加载成功")  # "Local contract cache loaded"

    def save_contract_data(self):
        """Persist ``self.contracts`` to the shelve file."""
        with shelve.open(self.data_filepath) as f:
            f["contracts"] = self.contracts
class IBBroker(Broker):
    """
    Interactive Brokers broker.

    Connects to the IB API through ``EClient`` and exposes blocking helpers that
    query the account (portfolio value, account-summary tags, positions, open
    orders) and place / cancel orders. Each request is serialized by
    ``self.lock`` and then waits up to ``waiting_time`` seconds on
    ``self.action_event_lock``, an Event handed to ``IBWrapper`` which is
    expected to set it once the corresponding response has been received.

    Parameters
    ----------
    clientId: int
        id of the API client (default 0)
    host: str
        IP address of the TWS / IB Gateway instance (default "127.0.0.1")
    port: int
        socket port of the TWS / IB Gateway instance (default 7497)
    """

    def __init__(self, clientId: int = 0, host: str = "127.0.0.1", port: int = 7497):
        self.logger = ib_logger.getChild(self.__class__.__name__)

        # lock that synchronizes entries into the functions and
        # makes sure we have a synchronous communication with client
        self.lock = Lock()
        self.waiting_time = 30  # expressed in seconds

        # event that informs us that the wrapper received the response
        self.action_event_lock = Event()
        self.wrapper = IBWrapper(self.action_event_lock)
        self.client = EClient(wrapper=self.wrapper)

        # The endpoint used to be hard-coded; the defaults reproduce it exactly
        # (127.0.0.1:7497, client id 0).
        self.clientId = clientId
        self.client.connect(host, port, clientId=self.clientId)

        # run the client in a separate thread so that the execution of the program can go on
        # now we will have 3 threads:
        # - thread of the main program
        # - thread of the client
        # - thread of the wrapper
        thread = Thread(target=self.client.run)
        thread.start()

        # this will be released after the client initialises and wrapper receives the nextValidOrderId
        if not self._wait_for_results():
            raise ConnectionError("IB Broker was not initialized correctly")

    def get_portfolio_value(self) -> float:
        """
        Return the account-summary 'NetLiquidation' value (all accounts, request id 1).

        The summary subscription is cancelled whether or not the response arrived.

        Raises
        ------
        BrokerException
            if no response arrives within ``waiting_time`` seconds.
        """
        with self.lock:
            request_id = 1
            self._reset_action_lock()
            self.client.reqAccountSummary(request_id, 'All', 'NetLiquidation')
            wait_result = self._wait_for_results()
            self.client.cancelAccountSummary(request_id)

            if wait_result:
                return self.wrapper.net_liquidation
            else:
                error_msg = 'Time out while getting portfolio value'
                self.logger.error('===> {}'.format(error_msg))
                raise BrokerException(error_msg)

    def get_portfolio_tag(self, tag: str) -> float:
        """
        Return the account-summary value for ``tag`` (all accounts, request id 2).

        The value is read from ``wrapper.tmp_value``; the summary subscription is
        cancelled whether or not the response arrived.

        Raises
        ------
        BrokerException
            if no response arrives within ``waiting_time`` seconds.
        """
        with self.lock:
            request_id = 2
            self._reset_action_lock()
            self.client.reqAccountSummary(request_id, 'All', tag)
            wait_result = self._wait_for_results()
            self.client.cancelAccountSummary(request_id)

            if wait_result:
                return self.wrapper.tmp_value
            else:
                error_msg = 'Time out while getting portfolio tag: {}'.format(tag)
                self.logger.error('===> {}'.format(error_msg))
                raise BrokerException(error_msg)

    def get_positions(self) -> List[Position]:
        """
        Return the positions collected by the wrapper after a ``reqPositions`` call.

        Raises
        ------
        BrokerException
            if no response arrives within ``waiting_time`` seconds.
        """
        with self.lock:
            self._reset_action_lock()
            self.wrapper.reset_position_list()
            self.client.reqPositions()

            if self._wait_for_results():
                return self.wrapper.position_list
            else:
                error_msg = 'Time out while getting positions'
                self.logger.error('===> {}'.format(error_msg))
                raise BrokerException(error_msg)

    def place_orders(self, orders: Sequence[Order]) -> Sequence[int]:
        """
        Place ``orders`` one after another and return their IB order ids in the same order.

        If one placement times out, a BrokerException propagates and the orders
        placed before it are left in place.
        """
        order_ids_list = []
        for order in orders:
            self.logger.info('Placing Order: {}'.format(order))
            order_id = self._execute_single_order(order)
            order_ids_list.append(order_id)

        return order_ids_list

    def cancel_order(self, order_id: int):
        """
        Cancel the order with ``order_id`` and wait for the wrapper to confirm it.

        Raises
        ------
        OrderCancellingException
            if no confirmation arrives within ``waiting_time`` seconds.
        """
        with self.lock:
            self.logger.info('cancel_order: {}'.format(order_id))
            self._reset_action_lock()
            # tell the wrapper which order id's cancellation it should report
            self.wrapper.set_cancel_order_id(order_id)
            self.client.cancelOrder(order_id)

            if not self._wait_for_results():
                error_msg = 'Time out while cancelling order id {} : \n'.format(order_id)
                self.logger.error('===> {}'.format(error_msg))
                raise OrderCancellingException(error_msg)

    def get_open_orders(self) -> List[Order]:
        """
        Return the open orders collected by the wrapper after a ``reqOpenOrders`` call.

        Raises
        ------
        BrokerException
            if no response arrives within ``waiting_time`` seconds.
        """
        with self.lock:
            self._reset_action_lock()
            self.wrapper.reset_order_list()
            self.client.reqOpenOrders()

            if self._wait_for_results():
                return self.wrapper.order_list
            else:
                error_msg = 'Time out while getting orders'
                self.logger.error('===> {}'.format(error_msg))
                raise BrokerException(error_msg)

    def cancel_all_open_orders(self):
        """
        Send a global cancel request for all open orders.

        There is no way to check whether cancelling all orders has finished.
        One can only request the open orders and confirm that the list is empty.
        """
        with self.lock:
            self.client.reqGlobalCancel()
            self.logger.info('cancel_all_open_orders')

    def _execute_single_order(self, order) -> int:
        """
        Place one order and wait for the wrapper to acknowledge it.

        Returns
        -------
        int
            the order id taken from ``wrapper.next_order_id()``.

        Raises
        ------
        BrokerException
            if no acknowledgement arrives within ``waiting_time`` seconds.
        """
        with self.lock:
            order_id = self.wrapper.next_order_id()

            self._reset_action_lock()
            # tell the wrapper which order id's acknowledgement it should report
            self.wrapper.set_waiting_order_id(order_id)

            ib_contract = self._to_ib_contract(order.contract)
            ib_order = self._to_ib_order(order)

            self.client.placeOrder(order_id, ib_contract, ib_order)
            if self._wait_for_results():
                return order_id
            else:
                error_msg = 'Time out while placing the trade for: \n\torder: {}'.format(order)
                self.logger.error('===> {}'.format(error_msg))
                raise BrokerException(error_msg)

    def _wait_for_results(self) -> bool:
        """
        Block for up to ``self.waiting_time`` seconds until ``action_event_lock`` is set.

        Returns True if the event was set in time, False on timeout.
        """
        wait_result = self.action_event_lock.wait(self.waiting_time)
        return wait_result

    def _reset_action_lock(self):
        """Clear ``action_event_lock``; threads calling wait() will block until set() is called."""
        self.action_event_lock.clear()

    def _to_ib_contract(self, contract: Contract):
        """Build an IBContract from the symbol, security type and exchange of ``contract``."""
        ib_contract = IBContract()
        ib_contract.symbol = contract.symbol
        ib_contract.secType = contract.security_type
        ib_contract.exchange = contract.exchange
        return ib_contract

    def _to_ib_order(self, order: Order):
        """
        Build an IBOrder from ``order``.

        A positive quantity becomes a BUY and anything else a SELL, with the
        absolute quantity as totalQuantity. Order type and time-in-force are
        derived from the order's execution style and time in force.
        """
        ib_order = IBOrder()

        if order.quantity > 0:
            ib_order.action = 'BUY'
        else:
            ib_order.action = 'SELL'

        ib_order.totalQuantity = abs(order.quantity)

        execution_style = order.execution_style
        self._set_execution_style(ib_order, execution_style)

        time_in_force = order.time_in_force
        tif_str = self._map_to_tif_str(time_in_force)
        ib_order.tif = tif_str

        return ib_order

    def _map_to_tif_str(self, time_in_force):
        """
        Map a TimeInForce to IB's tif string ("GTC", "DAY" or "OPG").

        Raises
        ------
        ValueError
            for any other TimeInForce value.
        """
        if time_in_force == TimeInForce.GTC:
            tif_str = "GTC"
        elif time_in_force == TimeInForce.DAY:
            tif_str = "DAY"
        elif time_in_force == TimeInForce.OPG:
            tif_str = "OPG"
        else:
            raise ValueError("Not supported TimeInForce {tif:s}".format(tif=str(time_in_force)))

        return tif_str

    def _set_execution_style(self, ib_order, execution_style):
        # Set ib_order.orderType from the execution style; stop orders also carry
        # their stop price in auxPrice.
        if isinstance(execution_style, MarketOrder):
            ib_order.orderType = "MKT"
        elif isinstance(execution_style, StopOrder):
            ib_order.orderType = "STP"
            ib_order.auxPrice = execution_style.stop_price