Ejemplo n.º 1
0
    def connect(self, setting: dict):
        """
        Prepare the OES API configuration and start both connections.

        The password in ``setting`` is replaced in place by its MD5 form, the
        ``vnoes.ini`` config file is rendered from ``config_template``, and
        the market-data and trading logins are each started on their own
        thread.

        :param setting: connection settings. Must contain ``username``,
            ``password``, the ``md_*`` / ``td_*`` server entries and
            ``hdd_serial`` read below, plus whatever placeholders
            ``config_template`` uses. The presence of a ``test`` key switches
            logging to DEBUG on the console.
        """
        # The "md5:" prefix marks an already-hashed password, so calling
        # connect() again with the same dict does not hash the hash.
        if not setting['password'].startswith("md5:"):
            digest = hashlib.md5(setting['password'].encode()).hexdigest()
            setting['password'] = "md5:" + digest

        username = setting['username']
        password = setting['password']

        config_path = str(get_file_path("vnoes.ini"))

        if 'test' in setting:
            log_level = 'DEBUG'
            log_mode = 'console'
        else:
            log_level = 'WARNING'
            log_mode = 'file'

        log_dir = get_file_path('oes')
        log_path = os.path.join(log_dir, 'log.log')
        os.makedirs(log_dir, exist_ok=True)

        # Render the template before opening the file: a missing placeholder
        # then raises without leaving a truncated config file behind.
        content = config_template.format(**setting,
                                         log_level=log_level,
                                         log_mode=log_mode,
                                         log_path=log_path)
        with open(config_path, "wt") as f:
            f.write(content)

        self.md_api.tcp_server = setting['md_tcp_server']
        self.md_api.qry_server = setting['md_qry_server']
        Thread(target=self._connect_md_sync,
               args=(config_path, username, password)).start()

        self.td_api.ord_server = setting['td_ord_server']
        self.td_api.rpt_server = setting['td_rpt_server']
        self.td_api.qry_server = setting['td_qry_server']
        self.td_api.hdd_serial = setting['hdd_serial']
        Thread(target=self._connect_td_sync,
               args=(config_path, username, password)).start()
Ejemplo n.º 2
0
    def __initDataManager(self, file_preffix):
        """
        Open the per-symbol SQLite file and attach a database manager to it.

        The file name is ``<file_preffix><code>_<exchange value>.db`` and is
        resolved through ``get_file_path``. The newest-time marker is
        refreshed once the manager is in place.
        """
        exchange_tag = str(self.__exchange.value)
        db_filename = f"{file_preffix}{self.__code}_{exchange_tag}.db"
        sqlite_db = SqliteDatabase(str(get_file_path(db_filename)))
        self.__database_manager = init_by_sql_databease(sqlite_db)
        self.__updateNewestTime()
Ejemplo n.º 3
0
def init(driver: Driver, settings: dict):
    """
    Build a ``SqlManager`` backed by the SQLite file named in ``settings``.

    :param driver: passed through to ``init_models``.
    :param settings: mapping whose ``"database"`` entry is the file name
        resolved with ``get_file_path``.
    :return: a ``SqlManager`` bound to the ``db_store`` folder and the models
        returned by ``init_models``.
    """
    db_name = settings["database"]
    store_dir = str(get_folder_path('db_store'))
    sqlite_db = SqliteDatabase(str(get_file_path(db_name)))
    models = init_models(sqlite_db, driver)
    bcolz_meta, bar, tick = models
    return SqlManager(store_dir, bcolz_meta, bar, tick)
Ejemplo n.º 4
0
def init():
    """
    Build a ``RecorderDbManager`` on top of the configured SQLite file.

    The file name comes from the ``"database"`` entry of
    ``get_settings("database.")`` and is resolved with ``get_file_path``.
    """
    db_settings = get_settings("database.")
    db_file = str(get_file_path(db_settings["database"]))
    models = init_models(SqliteDatabase(db_file))
    DbCtaTrade, DbCtaSignal, DbCtaPosition, DbCtaParams = models
    return RecorderDbManager(
        DbCtaTrade,
        DbCtaSignal,
        DbCtaPosition,
        DbCtaParams,
    )
Ejemplo n.º 5
0
    def connect(self, setting: dict) -> None:
        """
        Prepare the OES API configuration and connect both APIs.

        The password in ``setting`` is replaced in place by its MD5 form, the
        ``vnoes.ini`` config file is rendered from ``config_template``, and
        then ``md_api.connect`` and ``td_api.connect`` are called with it.

        :param setting: connection settings. Must contain ``username``,
            ``password``, the ``td_*`` server entries, ``hdd_serial``,
            ``customize_ip`` and ``customize_mac`` read below, plus whatever
            placeholders ``config_template`` uses. The presence of a ``test``
            key switches logging to DEBUG on the console.
        """
        # The "md5:" prefix marks an already-hashed password, so calling
        # connect() again with the same dict does not hash the hash.
        if not setting['password'].startswith("md5:"):
            digest = hashlib.md5(setting['password'].encode()).hexdigest()
            setting['password'] = "md5:" + digest

        username = setting['username']
        password = setting['password']

        config_path = str(get_file_path("vnoes.ini"))

        if 'test' in setting:
            log_level = 'DEBUG'
            log_mode = 'console'
        else:
            log_level = 'WARNING'
            log_mode = 'file'

        log_dir = get_file_path('oes')
        log_path = os.path.join(log_dir, 'log.log')
        os.makedirs(log_dir, exist_ok=True)

        # Render the template before opening the file: a missing placeholder
        # then raises without leaving a truncated config file behind.
        content = config_template.format(**setting,
                                         log_level=log_level,
                                         log_mode=log_mode,
                                         log_path=log_path)
        with open(config_path, "wt") as f:
            f.write(content)

        self.md_api.connect(config_path=config_path,
                            username=username,
                            password=password)

        self.td_api.connect(
            config_path=config_path,
            username=username,
            password=password,
            ord_server=setting['td_ord_server'],
            rpt_server=setting['td_rpt_server'],
            hdd_serial=setting['hdd_serial'],
            qry_server=setting['td_qry_server'],
            customize_ip=setting['customize_ip'],
            customize_mac=setting['customize_mac'],
        )
Ejemplo n.º 6
0
    def connect(self, setting: dict):
        """
        Prepare the OES API configuration and start both connections.

        The password in ``setting`` is replaced in place by its MD5 form, the
        ``vnoes.ini`` config file is rendered from ``config_template``, and
        the market-data and trading logins are each started on their own
        thread.

        :param setting: connection settings. Must contain ``username``,
            ``password``, the ``md_*`` / ``td_*`` server entries and
            ``hdd_serial`` read below, plus whatever placeholders
            ``config_template`` uses. The presence of a ``test`` key switches
            logging to DEBUG on the console.
        """
        # The "md5:" prefix marks an already-hashed password, so calling
        # connect() again with the same dict does not hash the hash.
        if not setting['password'].startswith("md5:"):
            digest = hashlib.md5(setting['password'].encode()).hexdigest()
            setting['password'] = "md5:" + digest

        username = setting['username']
        password = setting['password']

        config_path = str(get_file_path("vnoes.ini"))

        if 'test' in setting:
            log_level = 'DEBUG'
            log_mode = 'console'
        else:
            log_level = 'WARNING'
            log_mode = 'file'

        log_dir = get_file_path('oes')
        log_path = os.path.join(log_dir, 'log.log')
        os.makedirs(log_dir, exist_ok=True)

        # Render the template before opening the file: a missing placeholder
        # then raises without leaving a truncated config file behind.
        content = config_template.format(**setting,
                                         log_level=log_level,
                                         log_mode=log_mode,
                                         log_path=log_path)
        with open(config_path, "wt") as f:
            f.write(content)

        self.md_api.tcp_server = setting['md_tcp_server']
        self.md_api.qry_server = setting['md_qry_server']
        Thread(target=self._connect_md_sync,
               args=(config_path, username, password)).start()

        self.td_api.ord_server = setting['td_ord_server']
        self.td_api.rpt_server = setting['td_rpt_server']
        self.td_api.qry_server = setting['td_qry_server']
        self.td_api.hdd_serial = setting['hdd_serial']
        Thread(target=self._connect_td_sync,
               args=(config_path, username, password)).start()
    def __init__(self, settings_dict):
        """
        Create the SQLAlchemy engine for the database file named in settings.

        :param settings_dict: mapping that provides ``'driver'`` (used as the
            URL scheme) and ``'database'`` (file name resolved through
            ``get_file_path``).
        """
        self.settings_dict = settings_dict
        self.file_path_str = get_file_path(settings_dict['database'])

        # The URL is "<driver>://" + separator + path. Windows drive paths
        # need one extra slash; POSIX absolute paths (which already start
        # with "/") need two, giving four slashes in total.
        # NOTE(review): assumes get_file_path returns an absolute path.
        os_str = platform.system()
        if os_str == "Windows":
            sqlite_os = "/"
        else:
            # Previously only "Linux" set the separator, so any other system
            # (e.g. macOS) failed with an unbound local variable below.
            sqlite_os = "//"
            if os_str not in ("Linux", "Darwin"):
                print(f"OS is {os_str}. DBoperation may meet problem.")

        self.engine = create_engine(
            f"{self.settings_dict['driver']}://{sqlite_os}{self.file_path_str}"
        )
Ejemplo n.º 8
0
class IbApi(EWrapper):
    """
    EWrapper implementation that relays IB TWS callbacks to a gateway.

    Callbacks coming from TWS are converted into trader data objects and
    pushed through ``self.gateway``. The request methods near the end of the
    class (subscribe, send_order, cancel_order, query_history) send requests
    through ``self.client``.
    """
    # Contract data is cached between sessions in this shelve file.
    data_filename = "ib_contract_data.db"
    data_filepath = str(get_file_path(data_filename))

    def __init__(self, gateway: BaseGateway):
        """
        Set up empty state and the client that talks to TWS.

        :param gateway: gateway that receives all converted events and logs.
        """
        super().__init__()

        self.gateway = gateway
        self.gateway_name = gateway.gateway_name

        self.status = False

        self.reqid = 0
        self.orderid = 0
        self.clientid = 0
        self.account = ""
        self.ticks = {}
        self.orders = {}
        self.accounts = {}
        self.contracts = {}

        self.tick_exchange = {}

        # State shared between query_history() and the history callbacks.
        self.history_req = None
        self.history_condition = Condition()
        self.history_buf = []

        self.client = IbClient(self)
        self.thread = Thread(target=self.client.run)

    def connectAck(self):  # pylint: disable=invalid-name
        """
        Callback when connection is established.
        """
        self.status = True
        self.gateway.write_log("IB TWS连接成功")

        self.load_contract_data()

    def connectionClosed(self):  # pylint: disable=invalid-name
        """
        Callback when connection is closed.
        """
        self.status = False
        self.gateway.write_log("IB TWS连接断开")

    def nextValidId(self, orderId: int):  # pylint: disable=invalid-name
        """
        Callback of next valid orderid.

        Only the first id received is stored; later ones are ignored because
        send_order() increments the local counter itself.
        """
        super().nextValidId(orderId)

        if not self.orderid:
            self.orderid = orderId

    def currentTime(self, time: int):  # pylint: disable=invalid-name
        """
        Callback of current server time of IB.
        """
        super().currentTime(time)

        dt = datetime.fromtimestamp(time)
        time_string = dt.strftime("%Y-%m-%d %H:%M:%S.%f")

        msg = f"服务器时间: {time_string}"
        self.gateway.write_log(msg)

    def error(self, reqId: TickerId, errorCode: int, errorString: str):  # pylint: disable=invalid-name
        """
        Callback of error caused by specific request.
        """
        super().error(reqId, errorCode, errorString)

        msg = f"信息通知,代码:{errorCode},内容: {errorString}"
        self.gateway.write_log(msg)

    def tickPrice(  # pylint: disable=invalid-name
            self, reqId: TickerId, tickType: TickType, price: float,
            attrib: TickAttrib):
        """
        Callback of tick price update.
        """
        super().tickPrice(reqId, tickType, price, attrib)

        if tickType not in TICKFIELD_IB2VT:
            return

        tick = self.ticks[reqId]
        name = TICKFIELD_IB2VT[tickType]
        setattr(tick, name, price)

        # Update name into tick data.
        contract = self.contracts.get(tick.vt_symbol, None)
        if contract:
            tick.name = contract.name

        # Forex and spot product of IDEALPRO has no tick time and last price.
        # We need to calculate locally.
        exchange = self.tick_exchange[reqId]
        if exchange is Exchange.IDEALPRO:
            tick.last_price = (tick.bid_price_1 + tick.ask_price_1) / 2
            tick.datetime = datetime.now()
        self.gateway.on_tick(copy(tick))

    def tickSize(self, reqId: TickerId, tickType: TickType, size: int):  # pylint: disable=invalid-name
        """
        Callback of tick volume update.
        """
        super().tickSize(reqId, tickType, size)

        if tickType not in TICKFIELD_IB2VT:
            return

        tick = self.ticks[reqId]
        name = TICKFIELD_IB2VT[tickType]
        setattr(tick, name, size)

        self.gateway.on_tick(copy(tick))

    def tickString(self, reqId: TickerId, tickType: TickType, value: str):  # pylint: disable=invalid-name
        """
        Callback of tick string update.

        Only the last-trade timestamp is used; other string ticks are ignored.
        """
        super().tickString(reqId, tickType, value)

        if tickType != TickTypeEnum.LAST_TIMESTAMP:
            return

        tick = self.ticks[reqId]
        tick.datetime = datetime.fromtimestamp(int(value))

        self.gateway.on_tick(copy(tick))

    def orderStatus(  # pylint: disable=invalid-name
        self,
        orderId: OrderId,
        status: str,
        filled: float,
        remaining: float,
        avgFillPrice: float,
        permId: int,
        parentId: int,
        lastFillPrice: float,
        clientId: int,
        whyHeld: str,
        mktCapPrice: float,
    ):
        """
        Callback of order status update.
        """
        super().orderStatus(
            orderId,
            status,
            filled,
            remaining,
            avgFillPrice,
            permId,
            parentId,
            lastFillPrice,
            clientId,
            whyHeld,
            mktCapPrice,
        )

        orderid = str(orderId)
        order = self.orders.get(orderid, None)
        if not order:
            # Status for an order that openOrder() has not registered;
            # there is nothing to update.
            return

        order.traded = filled

        # To filter PendingCancel status
        order_status = STATUS_IB2VT.get(status, None)
        if order_status:
            order.status = order_status

        self.gateway.on_order(copy(order))

    def openOrder(  # pylint: disable=invalid-name
        self,
        orderId: OrderId,
        ib_contract: Contract,
        ib_order: Order,
        orderState: OrderState,
    ):
        """
        Callback when opening new order.
        """
        super().openOrder(orderId, ib_contract, ib_order, orderState)

        orderid = str(orderId)
        order = OrderData(
            symbol=ib_contract.conId,
            exchange=EXCHANGE_IB2VT.get(ib_contract.exchange,
                                        ib_contract.exchange),
            type=ORDERTYPE_IB2VT[ib_order.orderType],
            orderid=orderid,
            direction=DIRECTION_IB2VT[ib_order.action],
            volume=ib_order.totalQuantity,
            gateway_name=self.gateway_name,
        )

        if order.type == OrderType.LIMIT:
            order.price = ib_order.lmtPrice
        elif order.type == OrderType.STOP:
            order.price = ib_order.auxPrice

        self.orders[orderid] = order
        self.gateway.on_order(copy(order))

    def updateAccountValue(  # pylint: disable=invalid-name
            self, key: str, val: str, currency: str, accountName: str):
        """
        Callback of account update.

        Values are accumulated per "<account>.<currency>" and pushed to the
        gateway later, from updateAccountTime().
        """
        super().updateAccountValue(key, val, currency, accountName)

        if not currency or key not in ACCOUNTFIELD_IB2VT:
            return

        accountid = f"{accountName}.{currency}"
        account = self.accounts.get(accountid, None)
        if not account:
            account = AccountData(accountid=accountid,
                                  gateway_name=self.gateway_name)
            self.accounts[accountid] = account

        name = ACCOUNTFIELD_IB2VT[key]
        setattr(account, name, float(val))

    def updatePortfolio(  # pylint: disable=invalid-name
        self,
        contract: Contract,
        position: float,
        marketPrice: float,
        marketValue: float,
        averageCost: float,
        unrealizedPNL: float,
        realizedPNL: float,
        accountName: str,
    ):
        """
        Callback of position update.
        """
        super().updatePortfolio(
            contract,
            position,
            marketPrice,
            marketValue,
            averageCost,
            unrealizedPNL,
            realizedPNL,
            accountName,
        )

        if contract.exchange:
            exchange = EXCHANGE_IB2VT.get(contract.exchange, None)
        elif contract.primaryExchange:
            exchange = EXCHANGE_IB2VT.get(contract.primaryExchange, None)
        else:
            exchange = Exchange.SMART  # Use smart routing for default

        if not exchange:
            msg = f"存在不支持的交易所持仓{contract.conId} {contract.exchange} {contract.primaryExchange}"
            self.gateway.write_log(msg)
            return

        # A missing (None), empty or zero multiplier falls back to 1 so the
        # conversion and the division below cannot raise.
        try:
            ib_size = int(contract.multiplier)
        except (ValueError, TypeError):
            ib_size = 1
        if not ib_size:
            ib_size = 1
        price = averageCost / ib_size

        pos = PositionData(
            symbol=generate_symbol(contract),
            exchange=exchange,
            direction=Direction.NET,
            volume=position,
            price=price,
            pnl=unrealizedPNL,
            gateway_name=self.gateway_name,
        )
        self.gateway.on_position(pos)

    def updateAccountTime(self, timeStamp: str):  # pylint: disable=invalid-name
        """
        Callback of account update time.

        Pushes a copy of every accumulated account to the gateway.
        """
        super().updateAccountTime(timeStamp)
        for account in self.accounts.values():
            self.gateway.on_account(copy(account))

    def contractDetails(self, reqId: int, contractDetails: ContractDetails):  # pylint: disable=invalid-name
        """
        Callback of contract data update.

        Contracts not seen before are pushed to the gateway and the local
        cache file is rewritten.
        """
        super().contractDetails(reqId, contractDetails)

        # Generate symbol from ib contract details
        ib_contract = contractDetails.contract
        if not ib_contract.multiplier:
            ib_contract.multiplier = 1

        symbol = generate_symbol(ib_contract)

        # Generate contract
        contract = ContractData(
            symbol=symbol,
            exchange=EXCHANGE_IB2VT[ib_contract.exchange],
            name=contractDetails.longName,
            product=PRODUCT_IB2VT[ib_contract.secType],
            size=ib_contract.multiplier,
            pricetick=contractDetails.minTick,
            net_position=True,
            history_data=True,
            stop_supported=True,
            gateway_name=self.gateway_name,
        )

        if contract.vt_symbol not in self.contracts:
            self.gateway.on_contract(contract)

            self.contracts[contract.vt_symbol] = contract
            self.save_contract_data()

    def execDetails(self, reqId: int, contract: Contract,
                    execution: Execution):  # pylint: disable=invalid-name
        """
        Callback of trade data update.
        """
        super().execDetails(reqId, contract, execution)

        trade = TradeData(
            symbol=contract.conId,
            exchange=EXCHANGE_IB2VT.get(contract.exchange, contract.exchange),
            orderid=str(execution.orderId),
            tradeid=str(execution.execId),
            direction=DIRECTION_IB2VT[execution.side],
            price=execution.price,
            volume=execution.shares,
            time=datetime.strptime(execution.time, "%Y%m%d  %H:%M:%S"),
            gateway_name=self.gateway_name,
        )

        self.gateway.on_trade(trade)

    def managedAccounts(self, accountsList: str):
        """
        Callback of all sub accountid.

        When no account was given to connect(), the last entry of the
        comma-separated list is used.
        """
        super().managedAccounts(accountsList)

        if not self.account:
            self.account = accountsList.split(",")[-1]

        self.gateway.write_log(f"当前使用的交易账号为{self.account}")
        self.client.reqAccountUpdates(True, self.account)

    def historicalData(self, reqId: int, ib_bar: IbBarData):
        """
        Callback of history data update.

        Each bar is appended to the buffer that query_history() returns.
        """
        dt = datetime.strptime(ib_bar.date, "%Y%m%d %H:%M:%S")

        bar = BarData(symbol=self.history_req.symbol,
                      exchange=self.history_req.exchange,
                      datetime=dt,
                      interval=self.history_req.interval,
                      volume=ib_bar.volume,
                      open_price=ib_bar.open,
                      high_price=ib_bar.high,
                      low_price=ib_bar.low,
                      close_price=ib_bar.close,
                      gateway_name=self.gateway_name)

        self.history_buf.append(bar)

    def historicalDataEnd(self, reqId: int, start: str, end: str):
        """
        Callback of history data finished.

        Wakes up the thread blocked in query_history().
        """
        with self.history_condition:
            self.history_condition.notify()

    def connect(self, host: str, port: int, clientid: int, account: str):
        """
        Connect to TWS.
        """
        if self.status:
            return

        self.clientid = clientid
        self.account = account
        self.client.connect(host, port, clientid)

        # A Thread object can be started only once. After a previous
        # connection ended, the old worker has finished and must be replaced;
        # a worker that is still running is left alone.
        if not self.thread.is_alive():
            if self.thread.ident is not None:
                self.thread = Thread(target=self.client.run)
            self.thread.start()

        self.client.reqCurrentTime()

    def close(self):
        """
        Disconnect to TWS.
        """
        if not self.status:
            return

        self.status = False
        self.client.disconnect()

    def subscribe(self, req: SubscribeRequest):
        """
        Subscribe tick data update.
        """
        if not self.status:
            return

        if req.exchange not in EXCHANGE_VT2IB:
            self.gateway.write_log(f"不支持的交易所{req.exchange}")
            return

        # Extract ib contract detail
        ib_contract = generate_ib_contract(req.symbol, req.exchange)
        if not ib_contract:
            self.gateway.write_log("代码解析失败,请检查格式是否正确")
            return

        # Get contract data from TWS.
        self.reqid += 1
        self.client.reqContractDetails(self.reqid, ib_contract)

        # Subscribe tick data and create tick object buffer.
        self.reqid += 1
        self.client.reqMktData(self.reqid, ib_contract, "", False, False, [])

        tick = TickData(
            symbol=req.symbol,
            exchange=req.exchange,
            datetime=datetime.now(),
            gateway_name=self.gateway_name,
        )
        self.ticks[self.reqid] = tick
        self.tick_exchange[self.reqid] = req.exchange

    def send_order(self, req: OrderRequest):
        """
        Send a new order.

        :return: the vt_orderid of the new order, or "" when the order was
            not sent.
        """
        if not self.status:
            return ""

        if req.exchange not in EXCHANGE_VT2IB:
            self.gateway.write_log(f"不支持的交易所:{req.exchange}")
            return ""

        if req.type not in ORDERTYPE_VT2IB:
            self.gateway.write_log(f"不支持的价格类型:{req.type}")
            return ""

        self.orderid += 1

        ib_contract = generate_ib_contract(req.symbol, req.exchange)
        if not ib_contract:
            return ""

        ib_order = Order()
        ib_order.orderId = self.orderid
        ib_order.clientId = self.clientid
        ib_order.action = DIRECTION_VT2IB[req.direction]
        ib_order.orderType = ORDERTYPE_VT2IB[req.type]
        ib_order.totalQuantity = req.volume
        ib_order.account = self.account

        if req.type == OrderType.LIMIT:
            ib_order.lmtPrice = req.price
        elif req.type == OrderType.STOP:
            ib_order.auxPrice = req.price

        self.client.placeOrder(self.orderid, ib_contract, ib_order)
        self.client.reqIds(1)

        order = req.create_order_data(str(self.orderid), self.gateway_name)
        self.gateway.on_order(order)
        return order.vt_orderid

    def cancel_order(self, req: CancelRequest):
        """
        Cancel an existing order.
        """
        if not self.status:
            return

        self.client.cancelOrder(int(req.orderid))

    def query_history(self, req: HistoryRequest):
        """
        Request historical bars and block until TWS reports the end of data.

        :return: list of the bars received for this request; empty when the
            symbol cannot be converted into an IB contract.
        """
        ib_contract = generate_ib_contract(req.symbol, req.exchange)
        if not ib_contract:
            # Same check as subscribe(): do not send a request without a
            # contract.
            self.gateway.write_log("代码解析失败,请检查格式是否正确")
            return []

        self.history_req = req

        self.reqid += 1

        if req.end:
            end = req.end
            end_str = end.strftime("%Y%m%d %H:%M:%S")
        else:
            end = datetime.now()
            end_str = ""

        delta = end - req.start
        days = min(delta.days, 180)  # IB only provides 6-month data
        duration = f"{days} D"
        bar_size = INTERVAL_VT2IB[req.interval]

        if req.exchange == Exchange.IDEALPRO:
            bar_type = "MIDPOINT"
        else:
            bar_type = "TRADES"

        # Take the lock before sending the request. historicalDataEnd() needs
        # the same lock to notify, so it cannot fire before wait() below has
        # started waiting and the wakeup cannot be lost.
        with self.history_condition:
            self.client.reqHistoricalData(self.reqid, ib_contract, end_str,
                                          duration, bar_size, bar_type, 1, 1,
                                          False, [])
            self.history_condition.wait()  # Wait for async data return

        history = self.history_buf
        self.history_buf = []  # Create new buffer list
        self.history_req = None

        return history

    def load_contract_data(self):
        """
        Load cached contracts from the shelve file and push them to the
        gateway.
        """
        # NOTE: shelve unpickles the stored objects; the cache file must only
        # ever be written by this application.
        with shelve.open(self.data_filepath) as f:
            self.contracts = f.get("contracts", {})

        for contract in self.contracts.values():
            self.gateway.on_contract(contract)

        self.gateway.write_log("本地缓存合约信息加载成功")

    def save_contract_data(self):
        """
        Write all known contracts to the shelve cache file.
        """
        with shelve.open(self.data_filepath) as f:
            f["contracts"] = self.contracts
Ejemplo n.º 9
0
def init_sqlite(settings: dict):
    """
    Return a ``SqliteDatabase`` for the file named by ``settings["database"]``.

    The file name is resolved to a path with ``get_file_path``.
    """
    db_path = str(get_file_path(settings["database"]))
    return SqliteDatabase(db_path)
Ejemplo n.º 10
0
from peewee import (
    AutoField,
    CharField,
    Database,
    FloatField,
    DateTimeField,
    Model,
    SqliteDatabase
)

from vnpy.trader.constant import Interval, Exchange

# Use the SQLite database located in the vn.py runtime directory
from vnpy.trader.utility import get_file_path
from vnpy.trader.setting import SETTINGS
path = get_file_path(SETTINGS["database.database"])

# Alternatively, the database location can be specified manually
# path = "C:\\users\\administrator\\.vntrader\\database.db"

# Create the database object
db = SqliteDatabase(path)


# Define the ORM model classes for the data
class DbBarData(Model):
    """
    Candlestick bar data for database storage.

    Index is defined unique with datetime, interval, symbol
    """
Ejemplo n.º 11
0
class InfluxdbDatabase(BaseDatabase):
    """"""
    overview_filename = "influxdb_overview"
    overview_filepath = str(get_file_path(overview_filename))

    def __init__(self) -> None:
        """
        Create the InfluxDB client from the global ``database.*`` settings.

        ``create_database`` is called with the configured database name right
        after the client is built.
        """
        setting_names = ("database", "user", "password", "host", "port")
        database, user, password, host, port = (
            SETTINGS[f"database.{name}"] for name in setting_names
        )

        self.client = InfluxDBClient(host, port, user, password, database)
        self.client.create_database(database)

    def save_bar_data(self, bars: List[BarData]) -> bool:
        """
        Write bars to the ``bar_data`` measurement and refresh the overview.

        The ``vt_symbol`` and ``interval`` tags are taken from the first bar
        and applied to every point. Each bar's ``datetime`` is replaced in
        place by ``convert_tz(bar.datetime)``.

        :param bars: bars to store.
        :return: ``True`` once the data and the overview are written,
            ``False`` when ``bars`` is empty and nothing was done.
        """
        if not bars:
            return False

        first_bar = bars[0]
        vt_symbol = first_bar.vt_symbol
        interval = first_bar.interval

        json_body = []
        for bar in bars:
            bar.datetime = convert_tz(bar.datetime)

            json_body.append({
                "measurement": "bar_data",
                "tags": {
                    "vt_symbol": vt_symbol,
                    "interval": interval.value
                },
                "time": bar.datetime.isoformat(),
                "fields": {
                    "open_price": bar.open_price,
                    "high_price": bar.high_price,
                    "low_price": bar.low_price,
                    "close_price": bar.close_price,
                    "volume": bar.volume,
                    "open_interest": bar.open_interest,
                }
            })

        self.client.write_points(json_body, batch_size=10000)

        # Update bar overview
        symbol, exchange = extract_vt_symbol(vt_symbol)
        key = f"{vt_symbol}_{interval.value}"

        # The context manager closes the shelve file even if the count query
        # below raises.
        with shelve.open(self.overview_filepath) as f:
            overview = f.get(key, None)

            if not overview:
                overview = BarOverview(
                    symbol=symbol,
                    exchange=exchange,
                    interval=interval
                )
                overview.count = len(bars)
                overview.start = bars[0].datetime
                overview.end = bars[-1].datetime
            else:
                overview.start = min(overview.start, bars[0].datetime)
                overview.end = max(overview.end, bars[-1].datetime)

                # Re-count from the database, since the new bars may overlap
                # with bars stored earlier.
                query = (
                    "select count(close_price) from bar_data"
                    " where vt_symbol=$vt_symbol"
                    " and interval=$interval"
                )
                bind_params = {
                    "vt_symbol": vt_symbol,
                    "interval": interval.value
                }
                result = self.client.query(query, bind_params=bind_params)

                for point in result.get_points():
                    overview.count = point["count"]

            f[key] = overview

        return True

    def save_tick_data(self, ticks: List[TickData]) -> bool:
        """
        Write ticks to the ``tick_data`` measurement.

        The ``vt_symbol`` tag is taken from the first tick and applied to
        every point. Each tick's ``datetime`` is replaced in place by
        ``convert_tz(tick.datetime)``.

        :param ticks: ticks to store.
        :return: ``True`` once the points are written, ``False`` when
            ``ticks`` is empty and nothing was done.
        """
        if not ticks:
            return False

        vt_symbol = ticks[0].vt_symbol

        json_body = []
        for tick in ticks:
            tick.datetime = convert_tz(tick.datetime)

            json_body.append({
                "measurement": "tick_data",
                "tags": {
                    "vt_symbol": vt_symbol
                },
                "time": tick.datetime.isoformat(),
                "fields": {
                    "name": tick.name,
                    "volume": tick.volume,
                    "open_interest": tick.open_interest,
                    "last_price": tick.last_price,
                    "last_volume": tick.last_volume,
                    "limit_up": tick.limit_up,
                    "limit_down": tick.limit_down,

                    "open_price": tick.open_price,
                    "high_price": tick.high_price,
                    "low_price": tick.low_price,
                    "pre_close": tick.pre_close,

                    "bid_price_1": tick.bid_price_1,
                    "bid_price_2": tick.bid_price_2,
                    "bid_price_3": tick.bid_price_3,
                    "bid_price_4": tick.bid_price_4,
                    "bid_price_5": tick.bid_price_5,

                    "ask_price_1": tick.ask_price_1,
                    "ask_price_2": tick.ask_price_2,
                    "ask_price_3": tick.ask_price_3,
                    "ask_price_4": tick.ask_price_4,
                    "ask_price_5": tick.ask_price_5,

                    "bid_volume_1": tick.bid_volume_1,
                    "bid_volume_2": tick.bid_volume_2,
                    "bid_volume_3": tick.bid_volume_3,
                    "bid_volume_4": tick.bid_volume_4,
                    "bid_volume_5": tick.bid_volume_5,

                    "ask_volume_1": tick.ask_volume_1,
                    "ask_volume_2": tick.ask_volume_2,
                    "ask_volume_3": tick.ask_volume_3,
                    "ask_volume_4": tick.ask_volume_4,
                    "ask_volume_5": tick.ask_volume_5,
                }
            })

        self.client.write_points(json_body, batch_size=10000)

        return True

    def load_bar_data(
        self,
        symbol: str,
        exchange: Exchange,
        interval: Interval,
        start: datetime,
        end: datetime
    ) -> List[BarData]:
        """
        Load the stored bars of one symbol and interval.

        Only the date part of ``start`` and ``end`` is used in the time
        filter. Every returned bar carries ``gateway_name="DB"`` and a
        datetime localized with ``DB_TZ``.
        """
        first_day = start.date().isoformat()
        last_day = end.date().isoformat()
        query = (
            "select * from bar_data"
            " where vt_symbol=$vt_symbol"
            " and interval=$interval"
            f" and time >= '{first_day}'"
            f" and time <= '{last_day}';"
        )

        bind_params = {
            "vt_symbol": generate_vt_symbol(symbol, exchange),
            "interval": interval.value
        }

        rows = self.client.query(query, bind_params=bind_params).get_points()

        return [
            BarData(
                symbol=symbol,
                exchange=exchange,
                interval=interval,
                datetime=DB_TZ.localize(
                    datetime.strptime(row["time"], "%Y-%m-%dT%H:%M:%SZ")
                ),
                open_price=row["open_price"],
                high_price=row["high_price"],
                low_price=row["low_price"],
                close_price=row["close_price"],
                volume=row["volume"],
                open_interest=row["open_interest"],
                gateway_name="DB"
            )
            for row in rows
        ]

    def load_tick_data(
        self,
        symbol: str,
        exchange: Exchange,
        start: datetime,
        end: datetime
    ) -> List[TickData]:
        """
        Load the stored ticks of one symbol.

        Only the date part of ``start`` and ``end`` is used in the time
        filter. Every returned tick carries ``gateway_name="DB"`` and a
        datetime localized with ``DB_TZ``.
        """
        first_day = start.date().isoformat()
        last_day = end.date().isoformat()
        query = (
            "select * from tick_data"
            " where vt_symbol=$vt_symbol"
            f" and time >= '{first_day}'"
            f" and time <= '{last_day}';"
        )

        bind_params = {
            "vt_symbol": generate_vt_symbol(symbol, exchange),
        }

        rows = self.client.query(query, bind_params=bind_params).get_points()

        # Columns copied one-to-one from a result row into TickData, in the
        # order they are read: the scalar fields first, then five levels of
        # each depth column.
        scalar_fields = (
            "name", "volume", "open_interest",
            "last_price", "last_volume",
            "limit_up", "limit_down",
            "open_price", "high_price", "low_price", "pre_close",
        )
        depth_prefixes = ("bid_price", "ask_price", "bid_volume", "ask_volume")
        depth_fields = tuple(
            f"{prefix}_{level}"
            for prefix in depth_prefixes
            for level in range(1, 6)
        )
        copied_fields = scalar_fields + depth_fields

        ticks: List[TickData] = []
        for row in rows:
            parsed = datetime.strptime(row["time"], "%Y-%m-%dT%H:%M:%SZ")
            ticks.append(
                TickData(
                    symbol=symbol,
                    exchange=exchange,
                    datetime=DB_TZ.localize(parsed),
                    **{field: row[field] for field in copied_fields},
                    gateway_name="DB"
                )
            )

        return ticks

    def delete_bar_data(
        self,
        symbol: str,
        exchange: Exchange,
        interval: Interval
    ) -> int:
        """
        Delete all bars of one contract/interval and its overview entry.

        Returns the bar count reported by the database before the deletion,
        or 0 when the count query yields no points.
        """
        vt_symbol = generate_vt_symbol(symbol, exchange)
        bind_params = {
            "vt_symbol": vt_symbol,
            "interval": interval.value
        }

        # Query data count
        query1 = (
            "select count(close_price) from bar_data"
            " where vt_symbol=$vt_symbol"
            " and interval=$interval"
        )
        result = self.client.query(query1, bind_params=bind_params)

        # Start from 0 so an empty result cannot leave `count` unbound
        count = 0
        for d in result.get_points():
            count = d["count"]

        # Delete data
        query2 = (
            "drop series from bar_data"
            " where vt_symbol=$vt_symbol"
            " and interval=$interval"
        )
        self.client.query(query2, bind_params=bind_params)

        # Delete overview; the context manager closes the shelf even on error
        key = f"{vt_symbol}_{interval.value}"
        with shelve.open(self.overview_filepath) as f:
            f.pop(key, None)

        return count

    def delete_tick_data(
        self,
        symbol: str,
        exchange: Exchange
    ) -> int:
        """
        Delete all ticks of one contract from the "tick_data" measurement.

        Returns the tick count reported by the database before the deletion,
        or 0 when the count query yields no points.
        """
        bind_params = {
            "vt_symbol": generate_vt_symbol(symbol, exchange),
        }

        # Query data count
        query1 = (
            "select count(last_price) from tick_data"
            " where vt_symbol=$vt_symbol"
        )
        result = self.client.query(query1, bind_params=bind_params)

        # Start from 0 so an empty result cannot leave `count` unbound
        count = 0
        for d in result.get_points():
            count = d["count"]

        # Delete data
        query2 = (
            "drop series from tick_data"
            " where vt_symbol=$vt_symbol"
        )
        self.client.query(query2, bind_params=bind_params)

        return count

    def get_bar_overview(self) -> List[BarOverview]:
        """
        Return data available in database.

        The overviews are read from the shelf at `self.overview_filepath`.
        When the database reports bars but the shelf is empty, the shelf is
        rebuilt through `init_bar_overview()` and read again.
        """
        # Init bar overview if not exists
        query = "select count(close_price) from bar_data"
        result = self.client.query(query)
        data_count = 0
        for d in result.get_points():
            data_count = d["count"]

        with shelve.open(self.overview_filepath) as f:
            overviews = list(f.values())

        # The shelf is closed at this point on purpose: init_bar_overview()
        # opens the same file itself, and a shelf must not be open twice
        # while one side is writing.
        if data_count and not overviews:
            self.init_bar_overview()

            with shelve.open(self.overview_filepath) as f:
                overviews = list(f.values())

        return overviews

    def init_bar_overview(self) -> None:
        """
        Init overview table if not exists.

        Queries the bar count grouped by all tags and stores one BarOverview
        per (vt_symbol, interval) series in the shelf at
        `self.overview_filepath`, keyed as "{vt_symbol}_{interval}".
        """
        query: str = "select count(close_price) from bar_data group by *"
        result = self.client.query(query)

        # The context manager closes the shelf even if a lookup below raises
        with shelve.open(self.overview_filepath) as f:
            for k, v in result.items():
                # k[1] holds the tag values of the series, v its points
                tags = k[1]
                data = list(v)[0]

                vt_symbol = tags["vt_symbol"]
                symbol, exchange = extract_vt_symbol(vt_symbol)
                interval = Interval(tags["interval"])

                overview = BarOverview(
                    symbol=symbol,
                    exchange=exchange,
                    interval=interval,
                    count=data["count"]
                )
                overview.start = self.get_bar_datetime(vt_symbol, interval, 1)
                overview.end = self.get_bar_datetime(vt_symbol, interval, -1)

                key = f"{vt_symbol}_{interval.value}"
                f[key] = overview

    def get_bar_datetime(self, vt_symbol: str, interval: Interval, order: int) -> datetime:
        """
        Look up the timestamp of the first or the last bar of a series.

        A positive `order` selects the first bar, any other value the last.
        The "time" string of the returned point is parsed into a datetime.
        """
        selector = "first" if order > 0 else "last"

        statement = (
            f"select {selector}(close_price), * from bar_data"
            " where vt_symbol=$vt_symbol"
            " and interval=$interval"
        )
        params = {"vt_symbol": vt_symbol, "interval": interval.value}

        rows = self.client.query(statement, bind_params=params).get_points()
        for row in rows:
            moment = datetime.strptime(row["time"], "%Y-%m-%dT%H:%M:%SZ")

        return moment
Ejemplo n.º 12
0
def get_sqlite(dbName):
    """Return a SqliteDatabase for the location get_file_path gives for `dbName`."""
    return SqliteDatabase(str(get_file_path(dbName)))
Ejemplo n.º 13
0
""""""
from datetime import datetime
from typing import List

from peewee import (AutoField, CharField, DateTimeField, FloatField,
                    IntegerField, Model, SqliteDatabase as
                    PeeweeSqliteDatabase, ModelSelect, ModelDelete, chunked,
                    fn)

from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.object import BarData, TickData
from vnpy.trader.utility import get_file_path
from vnpy.trader.database import (BaseDatabase, BarOverview, DB_TZ, convert_tz)

# Module-level database handle: the location of "database.db" is obtained
# from get_file_path and handed to peewee's SqliteDatabase as a string.
path = str(get_file_path("database.db"))
db = PeeweeSqliteDatabase(path)


class DbBarData(Model):
    """
    Peewee model declaring the columns used to store bar data.

    NOTE(review): only the fields visible in this excerpt are documented;
    no Meta class or database binding is shown here.
    """

    # Auto-incrementing integer primary key provided by peewee's AutoField
    id = AutoField()

    # Columns identifying the bar: contract, exchange, timestamp and interval
    symbol: str = CharField()
    exchange: str = CharField()
    # The field name shadows the imported datetime class inside the class body
    datetime: datetime = DateTimeField()
    interval: str = CharField()

    # Numeric bar values stored as floats
    volume: float = FloatField()
    turnover: float = FloatField()
    open_interest: float = FloatField()
Ejemplo n.º 14
0
""""""

from peewee import SqliteDatabase, Model, CharField, DateTimeField, FloatField

#from .constant import Exchange, Interval
#from .object import BarData, TickData
#from .utility import get_file_path
from vnpy.trader.constant import Exchange, Interval
from vnpy.trader.object import BarData, TickData
from vnpy.trader.utility import get_file_path

# File name of the SQLite database; it is passed through get_file_path
# to obtain the location that is opened below.
DB_NAME = "database.db"
dbname = str(get_file_path(DB_NAME))
# NOTE(review): the path is computed a second time here instead of reusing
# `dbname` from the line above.
DB = SqliteDatabase(str(get_file_path(DB_NAME)))


class DbBarData(Model):
    """
    Candlestick bar data for database storage.

    Index is defined unique with vt_symbol, interval and datetime.
    """

    symbol = CharField()
    exchange = CharField()
    datetime = DateTimeField()
    interval = CharField()

    volume = FloatField()
    open_price = FloatField()
    high_price = FloatField()