예제 #1
0
    def send_order(self, strategy: StrategyTemplate, vt_symbol: str,
                   direction: Direction, offset: Offset, price: float,
                   volume: float, lock: bool) -> List[str]:
        """
        Create a limit order and register it in the order books.

        The price is rounded with the price tick stored for ``vt_symbol``
        and the order id comes from the incremented limit order counter.
        ``strategy`` and ``lock`` are accepted but not used in this method.

        Returns a one-element list holding the vt_orderid of the new order.
        """
        tick_size = self.priceticks[vt_symbol]
        rounded_price = round_to(price, tick_size)
        symbol, exchange = extract_vt_symbol(vt_symbol)

        self.limit_order_count += 1

        order_fields = {
            "symbol": symbol,
            "exchange": exchange,
            "orderid": str(self.limit_order_count),
            "direction": direction,
            "offset": offset,
            "price": rounded_price,
            "volume": volume,
            "status": Status.SUBMITTING,
            "datetime": self.datetime,
            "gateway_name": self.gateway_name,
        }
        new_order = OrderData(**order_fields)

        # The same order object is tracked both as active and in the full
        # limit order record.
        for book in (self.active_limit_orders, self.limit_orders):
            book[new_order.vt_orderid] = new_order

        return [new_order.vt_orderid]
예제 #2
0
 def get_position(self, vt_symbol: str, direction: Direction = Direction.NET, gateway_name: str = ''):
     """
     Query the position held for a contract in the account.

     Positions are cached under "<gateway>.<vt_symbol>.<direction value>".
     When no position is cached, an empty PositionData is created, stored
     and returned. If the contract itself is unknown, a ContractData is
     built from the engine's size / price tick / margin rate lookups.
     """
     gateway_name = gateway_name or self.gateway_name
     cache_key = f'{gateway_name}.{vt_symbol}.{direction.value}'

     cached = self.positions.get(cache_key, None)
     if cached:
         return cached

     contract = self.get_contract(vt_symbol)
     if not contract:
         self.write_log(f'{vt_symbol}合约信息不存在,构造一个')
         symbol, exchange = extract_vt_symbol(vt_symbol)
         # Choose the product type from the engine's contract type.
         product = (
             Product.FUTURES if self.contract_type == 'future'
             else Product.EQUITY if self.contract_type == 'stock'
             else Product.SPOT
         )
         contract = ContractData(
             gateway_name=gateway_name,
             name=vt_symbol,
             product=product,
             symbol=symbol,
             exchange=exchange,
             size=self.get_size(vt_symbol),
             pricetick=self.get_price_tick(vt_symbol),
             margin_rate=self.get_margin_rate(vt_symbol),
         )

     fresh = PositionData(
         gateway_name=gateway_name,
         symbol=contract.symbol,
         exchange=contract.exchange,
         direction=direction,
     )
     self.positions[cache_key] = fresh
     return fresh
예제 #3
0
    def init_bar_overview(self) -> None:
        """
        Init overview table if not exists.

        Counts the stored bars per tag group, builds one BarOverview per
        (vt_symbol, interval) pair and writes it into the shelve file at
        ``self.overview_filepath`` under the key "<vt_symbol>_<interval>".
        """
        # The context manager closes the shelf even when the query or one of
        # the per-series lookups raises, so the file handle is never leaked.
        with shelve.open(self.overview_filepath) as f:
            query: str = "select count(close_price) from bar_data group by *"
            result = self.client.query(query)

            for k, v in result.items():
                # The second element of the result key carries the tag values.
                tags = k[1]
                data = list(v)[0]

                vt_symbol = tags["vt_symbol"]
                symbol, exchange = extract_vt_symbol(vt_symbol)
                interval = Interval(tags["interval"])

                overview = BarOverview(
                    symbol=symbol,
                    exchange=exchange,
                    interval=interval,
                    count=data["count"]
                )
                overview.start = self.get_bar_datetime(vt_symbol, interval, 1)
                overview.end = self.get_bar_datetime(vt_symbol, interval, -1)

                key = f"{vt_symbol}_{interval.value}"
                f[key] = overview
예제 #4
0
파일: engine.py 프로젝트: lalacat/vnpy_test
    def run_downloading(
        self,
        vt_symbol: str,
        interval: str,
        start: datetime,
        end: datetime
    ):
        """
        Query bar data from RQData.

        Downloaded bars are saved to the database. When nothing is returned
        only the failure is logged: nothing is saved and no completion
        message is written. The thread handle is cleared in every case,
        including when the query or the save raises.
        """
        self.write_log(f"{vt_symbol}-{interval}开始下载历史数据")

        try:
            symbol, exchange = extract_vt_symbol(vt_symbol)
            data = rqdata_client.query_bar(
                symbol, exchange, Interval(interval), start, end
            )

            if data:
                database_manager.save_bar_data(data)
                self.write_log(f"{vt_symbol}-{interval}历史数据下载完成")
            else:
                self.write_log(f"数据下载失败,无法获取{vt_symbol}的历史数据")
        finally:
            # Clear thread object handler.
            self.thread = None
예제 #5
0
    def __init__(self, strategy_engine: StrategyEngine, strategy_name: str,
                 vt_symbols: List[str], setting: dict):
        """
        Set up one bar generator per symbol and classify the symbols.

        A symbol containing "-C-" is taken as the call option and its strike
        is parsed from the text after the marker; a symbol containing "P" is
        taken as the put option; any other symbol is the futures leg.
        """
        super().__init__(strategy_engine, strategy_name, vt_symbols, setting)

        self.bgs: Dict[str, BarGenerator] = {}
        self.last_tick_time: datetime = None

        def on_bar(bar: BarData):
            """Placeholder callback: generated bars are ignored."""
            pass

        # Obtain contract info
        for vt_symbol in self.vt_symbols:
            symbol, exchange = extract_vt_symbol(vt_symbol)

            # Match the full "-C-" marker (CFFEX/DCE option format) rather
            # than a bare "C": a symbol that merely contains the letter would
            # otherwise reach the split below and fail to unpack.
            if "-C-" in symbol:
                self.call_symbol = vt_symbol
                _, strike_str = symbol.split("-C-")  # For CFFEX/DCE options
                self.strike_price = int(strike_str)
            elif "P" in symbol:
                # Put detection is kept as before so already-working symbol
                # sets are classified identically.
                self.put_symbol = vt_symbol
            else:
                self.futures_symbol = vt_symbol

            self.bgs[vt_symbol] = BarGenerator(on_bar)
예제 #6
0
    def run_downloading(self, vt_symbol: str, interval: str, start: datetime,
                        end: datetime):
        """
        Download bar history for one contract and save it to the database.

        The gateway is queried when the contract reports ``history_data``;
        otherwise the request goes to RQData. The outcome is logged and the
        thread handle is cleared at the end.
        """
        self.write_log(f"{vt_symbol}-{interval}开始下载历史数据")

        symbol, exchange = extract_vt_symbol(vt_symbol)
        request = HistoryRequest(
            symbol=symbol,
            exchange=exchange,
            interval=Interval(interval),
            start=start,
            end=end,
        )

        contract = self.main_engine.get_contract(vt_symbol)

        if not (contract and contract.history_data):
            # No gateway history available: fall back to RQData.
            bars = rqdata_client.query_history(request)
        else:
            bars = self.main_engine.query_history(request, contract.gateway_name)

        if not bars:
            self.write_log(f"数据下载失败,无法获取{vt_symbol}的历史数据")
        else:
            database_manager.save_bar_data(bars)
            self.write_log(f"{vt_symbol}-{interval}历史数据下载完成")

        # Clear thread object handler.
        self.thread = None
예제 #7
0
    def rq_download(
        self,
        vt_symbol: str,
        interval: str,
        start: datetime,
        end: datetime,
    ):
        """
        Initialise the RQData client, download bars and save them.

        The result (success or failure) is reported with ``print``.
        """
        rqdata_client.init()
        symbol, exchange = extract_vt_symbol(vt_symbol)

        request = HistoryRequest(
            symbol=symbol,
            exchange=exchange,
            interval=Interval(interval),
            start=start,
            end=end,
        )
        bars = rqdata_client.query_history(request)

        if not bars:
            print(f"数据下载失败,无法得到 {vt_symbol} 的数据")
            return

        database_manager.save_bar_data(bars)
        print(f"{vt_symbol}-{interval} 历史数据下载完成")
예제 #8
0
    def save_to_database(data: List[dict], vt_symbol: str, rq_interval: str):
        """
        Convert raw rows into BarData objects and save them to the database.

        ``rq_interval`` is translated through ``INTERVAL_RQ2VT``; when it has
        no mapping nothing is saved and None is returned. Each row must
        provide "datetime" (formatted "%Y-%m-%d %H:%M:%S"), "open", "high",
        "low", "close" and "volume". A ``data`` of None saves an empty list.
        """
        interval = INTERVAL_RQ2VT.get(rq_interval)
        # Check the mapped interval, not the raw argument: an unmapped
        # interval would otherwise produce bars with interval=None.
        if not interval:
            return None

        symbol, exchange = extract_vt_symbol(vt_symbol)
        exchange = Exchange(exchange)
        dt_format = "%Y-%m-%d %H:%M:%S"

        res_list: List[BarData] = []
        if data is not None:
            for row in data:
                bar = BarData(symbol=symbol,
                              exchange=exchange,
                              interval=interval,
                              datetime=datetime.strptime(
                                  row['datetime'], dt_format),
                              open_price=row["open"],
                              high_price=row["high"],
                              low_price=row["low"],
                              close_price=row["close"],
                              volume=row["volume"],
                              gateway_name="RQ_WEB")
                res_list.append(bar)
        database_manager.save_bar_data(res_list)
예제 #9
0
    def load_bar(
        self,
        vt_symbol: str,
        days: int,
        interval: Interval,
        callback: Callable[[BarData], None]
    ):
        """
        Feed the last ``days`` days of bars to ``callback`` one by one.

        RQData is tried first; the local database is used when RQData
        returns nothing.
        """
        symbol, exchange = extract_vt_symbol(vt_symbol)
        end = datetime.now()
        start = end - timedelta(days)

        history = self.query_bar_from_rq(symbol, exchange, interval, start, end)
        if not history:
            db_query = {
                "symbol": symbol,
                "exchange": exchange,
                "interval": interval,
                "start": start,
                "end": end,
            }
            history = database_manager.load_bar_data(**db_query)

        for bar in history:
            callback(bar)
예제 #10
0
파일: engine.py 프로젝트: zpf4934/vnpy
    def load_bar(self, vt_symbol: str, days: int,
                 interval: Interval) -> List[BarData]:
        """
        Return the bars of the last ``days`` days for ``vt_symbol``.

        Source order: gateway history when the contract offers it, otherwise
        RQData; the database is the fallback when either returns nothing.
        """
        symbol, exchange = extract_vt_symbol(vt_symbol)
        end = datetime.now(get_localzone())
        start = end - timedelta(days)
        contract: ContractData = self.main_engine.get_contract(vt_symbol)

        if not (contract and contract.history_data):
            # No gateway history: ask RQData.
            history = self.query_bar_from_rq(
                symbol, exchange, interval, start, end
            )
        else:
            request = HistoryRequest(
                symbol=symbol,
                exchange=exchange,
                interval=interval,
                start=start,
                end=end,
            )
            history = self.main_engine.query_history(
                request, contract.gateway_name
            )

        if history:
            return history

        return database_manager.load_bar_data(
            symbol=symbol,
            exchange=exchange,
            interval=interval,
            start=start,
            end=end,
        )
예제 #11
0
    def load_bar(self, vt_symbol: str, days: int, interval: Interval,
                 callback: Callable[[BarData], None]):
        """
        Push the last ``days`` days of bars into ``callback``.

        Bars come from the gateway when the contract offers history data,
        otherwise from RQData, with the database as the final fallback.
        """
        symbol, exchange = extract_vt_symbol(vt_symbol)
        end = datetime.now()
        start = end - timedelta(days)

        contract = self.main_engine.get_contract(vt_symbol)

        if not (contract and contract.history_data):
            history = self.query_bar_from_rq(
                symbol, exchange, interval, start, end
            )
        else:
            request = HistoryRequest(
                symbol=symbol,
                exchange=exchange,
                interval=interval,
                start=start,
                end=end,
            )
            history = self.main_engine.query_history(
                request, contract.gateway_name
            )

        if not history:
            history = database_manager.load_bar_data(
                symbol=symbol,
                exchange=exchange,
                interval=interval,
                start=start,
                end=end,
            )

        for bar in history:
            callback(bar)
예제 #12
0
def load_bar_data(vt_symbol: str, interval: Interval, start: datetime,
                  end: datetime):
    """
    Load bars for ``vt_symbol`` from the database manager.

    The vt_symbol is split into symbol and exchange before the query; the
    database manager's result is returned unchanged.
    """
    symbol, exchange = extract_vt_symbol(vt_symbol)
    query_args = (symbol, exchange, interval, start, end)
    return database_manager.load_bar_data(*query_args)
예제 #13
0
파일: engine.py 프로젝트: arthurlirui/vnpy
    def load_market_trade(self, vt_symbol: str, callback: Callable):
        """
        Query market trades through the contract's gateway and pass each
        one to ``callback``.

        The request's start is the current time.
        NOTE(review): ``contract`` is used without a None check, so an
        unknown vt_symbol would fail on ``contract.gateway_name`` - confirm
        callers only pass known contracts.
        """
        contract = self.main_engine.get_contract(vt_symbol)
        symbol, exchange = extract_vt_symbol(vt_symbol)

        request = HistoryRequest(
            symbol=symbol,
            exchange=exchange,
            start=datetime.now(),
        )
        market_trades = self.main_engine.query_market_trade(
            req=request,
            gateway_name=contract.gateway_name,
        )

        for market_trade in market_trades:
            callback(market_trade)
예제 #14
0
def load_bar_data(
    spread: SpreadData,
    interval: Interval,
    start: datetime,
    end: datetime,
    pricetick: float = 0
):
    """
    Build spread bars from the stored bars of every leg of ``spread``.

    For each timestamp of the last loaded leg, a spread bar is produced only
    when every leg has a bar at that timestamp. The spread price is the sum
    of ``price_multiplier * close_price`` over the legs (rounded to
    ``pricetick`` when it is non-zero) and is used for open/high/low/close.
    The sum of ``abs(price_multiplier) * close_price`` is stored on the bar
    as ``value``.

    Returns the list of spread bars; a spread without legs yields [].
    """
    # Load bar data of each spread leg
    leg_bars: Dict[str, Dict] = {}

    # Timestamps of the most recently loaded leg drive the spread timeline.
    # Starting from an empty mapping lets a spread without legs return []
    # instead of failing on a name that was never bound.
    reference_bars: Dict[datetime, BarData] = {}

    for vt_symbol in spread.legs.keys():
        symbol, exchange = extract_vt_symbol(vt_symbol)

        bar_data: List[BarData] = database_manager.load_bar_data(
            symbol, exchange, interval, start, end
        )

        reference_bars = {bar.datetime: bar for bar in bar_data}
        leg_bars[vt_symbol] = reference_bars

    # Calculate spread bar data
    spread_bars: List[BarData] = []

    for dt in reference_bars:
        spread_price = 0
        spread_value = 0
        spread_available = True

        for leg in spread.legs.values():
            leg_bar = leg_bars[leg.vt_symbol].get(dt, None)

            if leg_bar:
                price_multiplier = spread.price_multipliers[leg.vt_symbol]
                spread_price += price_multiplier * leg_bar.close_price
                spread_value += abs(price_multiplier) * leg_bar.close_price
            else:
                spread_available = False

        if spread_available:
            if pricetick:
                spread_price = round_to(spread_price, pricetick)

            spread_bar = BarData(
                symbol=spread.name,
                # Resolve LOCAL on the enum class of the leg exchange:
                # reaching one member through another member is not
                # supported on every Python version. ``exchange`` is always
                # bound here because bars only exist once a leg was loaded.
                exchange=type(exchange).LOCAL,
                datetime=dt,
                interval=interval,
                open_price=spread_price,
                high_price=spread_price,
                low_price=spread_price,
                close_price=spread_price,
                gateway_name="SPREAD",
            )
            spread_bar.value = spread_value
            spread_bars.append(spread_bar)

    return spread_bars
예제 #15
0
def load_data(vt_symbol: str, interval: str, start: datetime,
              end: datetime) -> pd.DataFrame:
    """
    Load bars from the database and convert them with ``vt_bar_to_df``.

    ``interval`` is given as a string and converted to ``Interval`` before
    the query.
    """
    symbol, exchange = extract_vt_symbol(vt_symbol)
    bars = database_manager.load_bar_data(
        symbol, exchange, Interval(interval), start=start, end=end
    )
    frame = vt_bar_to_df(bars)
    return frame
예제 #16
0
    def is_trading(self, vt_symbol, current_time) -> bool:
        """
        Trading-time check used to filter ticks.

        Returns False inside the [drop_start, drop_end) window for every
        symbol, and inside [rest_start, rest_end) for symbols on DCE, SHFE
        or CZCE. Returns True otherwise.
        """
        _, exchange = extract_vt_symbol(vt_symbol)

        in_drop_window = self.drop_start <= current_time < self.drop_end
        if in_drop_window:
            return False

        exchanges_with_rest = [Exchange.DCE, Exchange.SHFE, Exchange.CZCE]
        if exchange not in exchanges_with_rest:
            return True

        if self.rest_start <= current_time < self.rest_end:
            return False
        return True
예제 #17
0
def vt_symbol_to_tq_symbol(vt_symbol: str, bar_type: str):
    """
    Convert a vt_symbol into the symbol text used for TQ queries.

    bar_type: "trading", "index", "main"
      - "trading": "<exchange>.<symbol>"
      - "index":   "KQ.i@<exchange>.<symbol passed through strip_digt>"
      - "main":    "KQ.m@<exchange>.<symbol passed through strip_digt>"

    Raises ValueError for any other bar_type.
    """
    symbol, exchange = extract_vt_symbol(vt_symbol)

    if bar_type == "trading":
        return f"{exchange.value}.{symbol}"

    if bar_type == "index":
        marker = "i"
    elif bar_type == "main":
        marker = "m"
    else:
        raise ValueError("The bar_type argument must be trading, index or main")

    return f"KQ.{marker}@{exchange.value}.{strip_digt(symbol)}"
예제 #18
0
파일: engine.py 프로젝트: xiumingxu/vnpy-xx
    def get_position(self, vt_symbol: str, direction: Direction):
        """
        Return the position stored for (vt_symbol, direction).

        When none is stored yet, a new PositionData is created for the
        symbol, cached under that key and returned.
        """
        key = (vt_symbol, direction)

        if key not in self.positions:
            symbol, exchange = extract_vt_symbol(vt_symbol)
            self.positions[key] = PositionData(
                symbol=symbol,
                exchange=exchange,
                direction=direction,
                gateway_name=GATEWAY_NAME,
            )

        return self.positions[key]
예제 #19
0
    def load_tick(self, vt_symbol: str, days: int,
                  callback: Callable[[TickData], None]):
        """
        Load the last ``days`` days of ticks from the database and pass
        each one to ``callback``.
        """
        symbol, exchange = extract_vt_symbol(vt_symbol)
        end = datetime.now()
        start = end - timedelta(days)

        tick_query = {
            "symbol": symbol,
            "exchange": exchange,
            "start": start,
            "end": end,
        }

        for tick in database_manager.load_tick_data(**tick_query):
            callback(tick)
예제 #20
0
    def run_downloading(self, vt_symbol: str, interval: str, start: datetime,
                        end: datetime):
        """
        Download bar history and save it to the database.

        Source order: the gateway when the contract offers history data,
        else RQData when "rqdata.username" is set, else TQData when
        "tqdata.username" is set; with neither configured the download is
        treated as empty. A vt_symbol that cannot be parsed is logged and
        the method returns early. Exceptions raised while querying or
        saving are logged with their traceback. The thread handle is
        cleared on every path.
        """
        self.write_log(f"{vt_symbol}-{interval}开始下载历史数据")

        try:
            symbol, exchange = extract_vt_symbol(vt_symbol)
        except ValueError:
            self.write_log(f"{vt_symbol}解析失败,请检查交易所后缀")
            self.thread = None
            return

        request = HistoryRequest(
            symbol=symbol,
            exchange=exchange,
            interval=Interval(interval),
            start=start,
            end=end,
        )

        contract = self.main_engine.get_contract(vt_symbol)

        try:
            if contract and contract.history_data:
                bars = self.main_engine.query_history(
                    request, contract.gateway_name
                )
            elif SETTINGS["rqdata.username"]:
                bars = rqdata_client.query_history(request)
            elif SETTINGS["tqdata.username"]:
                bars = tqdata_client.query_history(request)
            else:
                bars = []

            if not bars:
                self.write_log(f"数据下载失败,无法获取{vt_symbol}的历史数据")
            else:
                database_manager.save_bar_data(bars)
                self.write_log(f"{vt_symbol}-{interval}历史数据下载完成")
        except Exception:
            self.write_log(f"数据下载失败,触发异常:\n{traceback.format_exc()}")

        # Clear thread object handler.
        self.thread = None
예제 #21
0
    def load_bar(
        self,
        vt_symbol: str,
        days: int,
        interval: Interval,
        callback: Callable[[BarData], None],
        use_database: bool
    ):
        """
        Feed the last ``days`` days of bars to ``callback``.

        With ``use_database`` False the gateway (when the contract offers
        history data) or RQData is queried first; the database is used when
        that yields nothing or when ``use_database`` is True.
        """
        symbol, exchange = extract_vt_symbol(vt_symbol)
        end = datetime.now()
        start = end - timedelta(days)

        window = {
            "symbol": symbol,
            "exchange": exchange,
            "interval": interval,
            "start": start,
            "end": end,
        }
        history = []

        # Skip gateway and RQData when use_database is True.
        if not use_database:
            contract = self.main_engine.get_contract(vt_symbol)

            if not (contract and contract.history_data):
                # Try RQData; an empty result falls through to the database.
                history = self.query_bar_from_rq(
                    symbol, exchange, interval, start, end
                )
            else:
                request = HistoryRequest(**window)
                history = self.main_engine.query_history(
                    request, contract.gateway_name
                )

        # Load from the database when nothing was obtained above.
        if not history:
            history = database_manager.load_bar_data(**window)

        for bar in history:
            callback(bar)
예제 #22
0
def save_tdx_data(file_path,vt_symbol:str,future_download:bool,interval: Interval = Interval.MINUTE):
    """
    Save lc1 one-minute data exported from TDX (Tongdaxin).

    The binary file is decoded into a DataFrame, the packed date/time
    fields are expanded into an integer ``datetime`` timestamp column, the
    helper columns are dropped, and symbol / exchange / interval columns
    are added before the frame is handed to ``move_df_to_db``.

    NOTE(review): the original note says futures datetimes are aligned to
    Wenhua Finance; that presumably happens inside ``move_df_to_db`` -
    confirm there.

    Returns whatever ``move_df_to_db(df, future_download)`` returns.
    """
    print("%(processName)s %(message)s save_tdx_data函数")

    symbol,exchange= extract_vt_symbol(vt_symbol)
    # Read the binary file using the per-bar record layout.
    dt = np.dtype([
        ('date', 'u2'),
        ('time', 'u2'),
        ('open_price', 'f4'),
        ('high_price', 'f4'),
        ('low_price', 'f4'),
        ('close_price', 'f4'),
        ('amount', 'f4'),
        ('volume', 'u4'),
        ('reserve','u4')])
    data = np.fromfile(file_path, dtype=dt)
    df = pd.DataFrame(data, columns=data.dtype.names)
    # Unpack the packed date (year/month/day) and time (minutes of the day).
    df.eval('''
    year=floor(date/2048)+2004
    month=floor((date%2048)/100)
    day=floor(date%2048%100)
    hour = floor(time/60)
    minute = time%60
    ''',inplace=True)

    # A plain timestamp is needed rather than datetime64[ns], so the
    # datetime64[ns] column is converted into an integer timestamp here.
    df['datetime2']=pd.to_datetime(df.loc[:,['year','month','day','hour','minute']])
    # Timezone handling: the decoded wall-clock times are 8 hours ahead of
    # UTC, so seconds are counted from 08:00 on 1970-01-01. This is the same
    # instant as '1970-01-01T00:00:00-08:00', written without the deprecated
    # timezone-offset form of datetime64 parsing.
    epoch_utc8 = np.datetime64('1970-01-01T08:00:00')
    df['datetime3'] =((df['datetime2'] - epoch_utc8) / np.timedelta64(1, 's'))
    df['datetime'] = df['datetime3'].astype(int)
    # The int timestamp is turned into a datetime later, when the BarData
    # objects are filled in.

    # Drop the helper columns. ``columns=`` is used because the axis can no
    # longer be passed positionally in pandas 2.x.
    df.drop(columns=['date','time','year','month','day','hour','minute',"amount","reserve",'datetime2','datetime3'],inplace=True)

    # Fill in the contract information.
    df['symbol'] = symbol
    df['exchange'] = exchange
    df['interval'] = interval
    # Store the prepared frame in the database.
    return move_df_to_db(df,future_download)
예제 #23
0
    def run_downloading(
        self,
        vt_symbol: str,
        interval: str,
        start: datetime,
        end: datetime
    ):
        """
        Download bar history and save it to the database.

        The gateway is queried when the contract offers history data,
        otherwise RQData. Exceptions raised while querying or saving are
        logged with their traceback; the thread handle is cleared at the
        end.
        """
        self.write_log(f"{vt_symbol}-{interval} start downloading historical data ")

        symbol, exchange = extract_vt_symbol(vt_symbol)
        request = HistoryRequest(
            symbol=symbol,
            exchange=exchange,
            interval=Interval(interval),
            start=start,
            end=end,
        )

        contract = self.main_engine.get_contract(vt_symbol)

        try:
            if not (contract and contract.history_data):
                # No gateway history: use RQData.
                bars = rqdata_client.query_history(request)
            else:
                bars = self.main_engine.query_history(
                    request, contract.gateway_name
                )

            if not bars:
                self.write_log(f" data download failed , unable to get {vt_symbol} historical data ")
            else:
                database_manager.save_bar_data(bars)
                self.write_log(f"{vt_symbol}-{interval} historical data download is complete ")
        except Exception:
            self.write_log(f" data download failed , trigger abnormal :\n{traceback.format_exc()}")

        # Clear thread object handler.
        self.thread = None
예제 #24
0
    def load_bar(
        self,
        vt_symbol: str,
        days: int,
        interval: Interval,
        callback: Callable[[BarData], None],
        use_database: bool
    ):
        """
        Feed bars for the ``days`` days ending at the start of the current
        hour to ``callback``.

        With ``use_database`` False the gateway (when the contract offers
        history data) or RQData is queried first; the database is used when
        that yields nothing or when ``use_database`` is True.
        """
        symbol, exchange = extract_vt_symbol(vt_symbol)

        now = datetime.now(get_localzone())
        # Truncate to the start of the current hour. This rebuilds the value
        # from its fields, so the result carries no tzinfo.
        # NOTE(review): confirm that dropping the timezone here is intended.
        end = datetime(
            year=now.year,
            month=now.month,
            day=now.day,
            hour=now.hour,
            minute=0,
        )
        start = end - timedelta(days)

        window = {
            "symbol": symbol,
            "exchange": exchange,
            "interval": interval,
            "start": start,
            "end": end,
        }
        history = []

        # Skip gateway and RQData when use_database is True.
        if not use_database:
            contract = self.main_engine.get_contract(vt_symbol)

            if not (contract and contract.history_data):
                history = self.query_bar_from_rq(
                    symbol, exchange, interval, start, end
                )
            else:
                request = HistoryRequest(**window)
                history = self.main_engine.query_history(
                    request, contract.gateway_name
                )

        if not history:
            history = database_manager.load_bar_data(**window)

        for bar in history:
            callback(bar)
예제 #25
0
파일: engine.py 프로젝트: zhangjf76/vnpy
    def load_bar(self, vt_symbol: str, days: int, interval: Interval,
                 callback: Callable[[BarData], None], use_database: bool):
        """
        Feed the last ``days`` days of bars to ``callback`` and log which
        source supplied them and how many were loaded.

        With ``use_database`` False the gateway (when the contract offers
        history data) or ``query_bar_from_rq`` is tried first; the database
        is used when that yields nothing or when ``use_database`` is True.
        """
        symbol, exchange = extract_vt_symbol(vt_symbol)
        end = datetime.now()
        start = end - timedelta(days)

        window = {
            "symbol": symbol,
            "exchange": exchange,
            "interval": interval,
            "start": start,
            "end": end,
        }
        history = []
        data_source = ""

        # Skip gateway and remote source when use_database is True.
        if not use_database:
            contract = self.main_engine.get_contract(vt_symbol)

            if not (contract and contract.history_data):
                data_source = "tdx数据源"
                history = self.query_bar_from_rq(
                    symbol, exchange, interval, start, end
                )
            else:
                data_source = "行情服务器"
                request = HistoryRequest(**window)
                history = self.main_engine.query_history(
                    request, contract.gateway_name
                )

        if not history:
            data_source = "本地数据库"
            history = database_manager.load_bar_data(**window)

        # Log the source and the number of bars loaded.
        self.write_log(f"从{data_source}加载{len(history)}条数据")
        for bar in history:
            callback(bar)
예제 #26
0
    def get_bar(self, vt_symbol: str, bar_type: str, interval: Interval, size: int = 200):
        """
        Fetch a kline series (cached per symbol / bar type / interval) and
        push every row through ``on_tqdata_bar`` as a converted bar.

        Raises KeyError when ``interval`` has no entry in
        ``INTERVAL_MAP_VT2TQ``.
        """
        print(vt_symbol, bar_type, interval, size)
        symbol, exchange = extract_vt_symbol(vt_symbol)
        vt_tq_symbol = f"{symbol}.{bar_type}"

        tq_interval = INTERVAL_MAP_VT2TQ.get(interval, None)
        if tq_interval is None:
            raise KeyError("The interval can only be daily, hour or minute")

        cache_key = f"{vt_symbol}_{bar_type}_{interval.value}"
        frame = self.data_dict.get(cache_key, None)
        if frame is None:
            # Not cached yet: request the series and remember it.
            tq_symbol = vt_symbol_to_tq_symbol(vt_symbol, bar_type)
            print('get_bar_arguments', tq_symbol, tq_interval, size)
            frame = self.tqapi.get_kline_serial(tq_symbol, tq_interval, size)
            self.data_dict[cache_key] = frame
            print(frame)

        for _, row in frame.iterrows():
            converted = self.to_vt_bar(row, vt_tq_symbol, exchange, interval)
            self.on_tqdata_bar(converted)
예제 #27
0
    def update_portfolio_setting(self,
                                 portfolio_name: str,
                                 model_name: str,
                                 interest_rate: float,
                                 chain_underlying_map: Dict[str, str],
                                 inverse: bool = False,
                                 precision: int = 0) -> None:
        """
        Apply pricing settings to a portfolio and persist them.

        Each chain is bound to its underlying contract. An underlying whose
        name contains "LOCAL" gets a ContractData built here (INDEX product,
        zero size and price tick); any other underlying is looked up in the
        main engine. The settings are then stored under
        "portfolio_settings" and saved.
        """
        portfolio = self.get_portfolio(portfolio_name)

        for chain_symbol, underlying_symbol in chain_underlying_map.items():
            if "LOCAL" not in underlying_symbol:
                underlying = self.main_engine.get_contract(underlying_symbol)
            else:
                symbol, exchange = extract_vt_symbol(underlying_symbol)
                underlying = ContractData(
                    symbol=symbol,
                    exchange=exchange,
                    name="",
                    product=Product.INDEX,
                    size=0,
                    pricetick=0,
                    gateway_name=APP_NAME,
                )
            portfolio.set_chain_underlying(chain_symbol, underlying)

        portfolio.set_interest_rate(interest_rate)
        portfolio.set_pricing_model(PRICING_MODELS[model_name])
        portfolio.set_inverse(inverse)
        portfolio.set_precision(precision)

        stored = self.setting.setdefault("portfolio_settings", {})
        stored[portfolio_name] = dict(
            model_name=model_name,
            interest_rate=interest_rate,
            chain_underlying_map=chain_underlying_map,
            inverse=inverse,
            precision=precision,
        )
        self.save_setting()
예제 #28
0
파일: engine.py 프로젝트: hun1982qhu/CTA
    def load_bar(self, vt_symbol: str, days: int, interval: Interval,
                 callback: Callable[[BarData], None], use_database: bool):
        """
        Feed the last ``days`` days of bars to ``callback``.

        With ``use_database`` False the gateway (when the contract offers
        history data) or the datafeed is queried first; ``self.database``
        is used when that yields nothing or when ``use_database`` is True.
        """
        symbol, exchange = extract_vt_symbol(vt_symbol)
        end = datetime.now(LOCAL_TZ)
        start = end - timedelta(days)

        window = {
            "symbol": symbol,
            "exchange": exchange,
            "interval": interval,
            "start": start,
            "end": end,
        }
        history = []

        # Skip gateway and datafeed when use_database is True.
        if not use_database:
            contract = self.main_engine.get_contract(vt_symbol)

            if not (contract and contract.history_data):
                history = self.query_bar_from_datafeed(
                    symbol, exchange, interval, start, end
                )
            else:
                request = HistoryRequest(**window)
                history = self.main_engine.query_history(
                    request, contract.gateway_name
                )

        if not history:
            history = self.database.load_bar_data(**window)

        for bar in history:
            callback(bar)
예제 #29
0
def load_symbol_data(vt_symbol, start, end):
    """
    Read one-minute bars from the database as a close-price series.

    :param vt_symbol: symbol with exchange suffix, e.g. XBTUSD.BITMEX
    :param start: start date as a "%Y%m%d" string
    :param end: end date as a "%Y%m%d" string
    :return: pandas Series of close prices indexed by bar datetime
    """
    symbol, exchange = extract_vt_symbol(vt_symbol)
    date_format = "%Y%m%d"
    start = datetime.strptime(start, date_format)
    end = datetime.strptime(end, date_format)

    # Materialise once so both columns come from a single pass of the data.
    bars = list(
        database_manager.load_bar_data(
            symbol, exchange, Interval.MINUTE, start, end
        )
    )

    timestamps = [bar.datetime for bar in bars]
    closes = [bar.close_price for bar in bars]
    return pd.Series(closes, index=timestamps)
예제 #30
0
    def load_bar(self, vt_symbol: str, days: int, interval: Interval,
                 callback: Callable[[BarData], None]):
        """
        Load historical bars for the last ``days`` days and pass each one
        to ``callback``.
        """
        symbol, exchange = extract_vt_symbol(vt_symbol)
        end = datetime.now()
        start = end - timedelta(days)

        # Query bars from RQData by default, if not found, load from database.
        # TODO
        # CTA currently loads history from RQData, which has to change
        # because no RQData account is available. Initial proposal: request
        # the data with query_history through the main engine and let the
        # gateway deliver it. OKEX history comes from OKEX while futures
        # history comes from the database, so the sources cannot be unified
        # here and should be dispatched inside the gateway.
        history = self.query_bar_from_rq(symbol, exchange, interval, start, end)
        if not history:
            db_query = {
                "symbol": symbol,
                "exchange": exchange,
                "interval": interval,
                "start": start,
                "end": end,
            }
            history = database_manager.load_bar_data(**db_query)

        for bar in history:
            callback(bar)