async def _build_train_data(self, frame_type: FrameType, n: int, max_error: float = 0.01):
    """Build a training set by walking backwards from the most recent frame.

    Starting from the latest eligible date, every stock is visited; features
    are extracted from the bars preceding the label window and one record is
    emitted per stock that yields usable features. When a full pass over
    the stock list has not produced enough records, all windows are moved
    one trading day into the past and the pass is repeated.

    Args:
        frame_type: frame type used to position the initial windows.
            NOTE(review): bars are always loaded with ``FrameType.DAY`` and
            the windows are moved with ``tf.day_shift`` — confirm whether
            non-day frame types are meant to be supported.
        n: number of records to collect.
        max_error: forwarded unchanged to ``self.extract_features``.

    Returns:
        A list of records, each ``[code, int_date_of_x_stop, *features]``.
        NOTE(review): the label ``y`` is computed and used to skip NaN
        samples but is not stored in the record — confirm this is intended.
    """
    label_span = 5   # frames in the label (y) window
    curve_span = 5   # longest curve handed to the feature extractor
    ma_span = 20     # widest moving-average window
    feature_lookback = curve_span + ma_span - 1

    y_stop = tf.floor(arrow.now(tz=cfg.tz), frame_type)
    y_start = tf.shift(y_stop, 1 - label_span, frame_type)
    x_stop = tf.shift(y_start, -1, frame_type)
    x_start = tf.shift(x_stop, -feature_lookback, frame_type)

    records = []
    while len(records) < n:
        for code in Securities().choose(['stock']):
            try:
                security = Security(code)
                x_bars = await security.load_bars(x_start, x_stop, FrameType.DAY)
                y_bars = await security.load_bars(y_start, y_stop, FrameType.DAY)

                # presumably (a, b, axis) for each of three fitted curves —
                # confirm against extract_features
                features = self.extract_features(x_bars, max_error)
                if len(features) == 0:
                    continue

                # best close in the label window relative to the last
                # feature-window close
                gain = np.max(y_bars['close']) / x_bars[-1]['close'] - 1
                if np.isnan(gain):
                    continue

                records.append([code, tf.date2int(x_stop), *features])
            except Exception as e:
                # best effort: one bad security must not abort the whole run
                logger.warning("Failed to extract features for %s (%s)", code, x_stop)
                logger.exception(e)

            collected = len(records)
            if collected >= n:
                break
            if collected % 500 == 0:
                logger.info("got %s records.", collected)

        # slide every window one trading day into the past
        y_stop = tf.day_shift(y_stop, -1)
        y_start = tf.day_shift(y_start, -1)
        x_stop = tf.day_shift(y_start, -1)
        x_start = tf.day_shift(x_start, -1)

    return records
async def test_buy_limit_events(self):
    """Check ``count_buy_limit_event`` against known daily bars.

    For 603390.XSHG over the window ending 2020-08-07 exactly one event is
    expected, located on the 2020-07-28 frame. The second call (000070.XSHE,
    longer window) has no assertions; it only verifies the call completes.
    """
    last_day = arrow.get('2020-8-7').date()

    first_day = tf.day_shift(last_day, -9)
    security = Security('603390.XSHG')
    daily_bars = await security.load_bars(first_day, last_day, FrameType.DAY)
    n_events, positions = count_buy_limit_event(security, daily_bars)
    self.assertEqual(n_events, 1)
    expected_day = arrow.get('2020-7-28').date()
    self.assertEqual(expected_day, daily_bars['frame'][positions[0]])

    security = Security('000070.XSHE')
    first_day = tf.day_shift(last_day, -29)
    daily_bars = await security.load_bars(first_day, last_day, FrameType.DAY)
    count_buy_limit_event(security, daily_bars)
async def predict(self, code, x_end_date: datetime.date, max_error: float = 0.01):
    """Run the model on features extracted from bars ending at ``x_end_date``.

    Args:
        code: security code passed to ``Security``.
        x_end_date: last day of the feature window; the window starts at
            ``tf.day_shift(x_end_date, -29)``.
        max_error: forwarded unchanged to ``self.extract_features``.

    Returns:
        The result of ``self.model.predict`` for the single feature vector,
        or ``None`` (after logging a warning) when no features could be
        extracted.
    """
    security = Security(code)
    window_start = tf.day_shift(x_end_date, -29)
    daily_bars = await security.load_bars(window_start, x_end_date, FrameType.DAY)

    feature_vec = self.extract_features(daily_bars, max_error)
    if len(feature_vec) > 0:
        # the model is fed a batch containing this one sample
        return self.model.predict([feature_vec])

    logger.warning("cannot extract features from %s(%s)", code, x_end_date)
    return None
async def test_cross(self):
    """Check ``signal.cross`` on the MA5 / MA250 lines of 000035.XSHE.

    Daily bars for the window ending 2020-07-24 are loaded, both moving
    averages are computed from the close prices, and the last ten points of
    each are compared. The expected result is ``flag == -1`` at ``idx == 8``.
    """
    last_day = arrow.get('2020-7-24').date()
    first_day = tf.day_shift(last_day, -270)
    security = Security('000035.XSHE')
    daily_bars = await security.load_bars(first_day, last_day, FrameType.DAY)

    short_ma = signal.moving_average(daily_bars['close'], 5)
    long_ma = signal.moving_average(daily_bars['close'], 250)

    cross_flag, cross_idx = signal.cross(short_ma[-10:], long_ma[-10:])
    self.assertEqual(cross_flag, -1)
    self.assertEqual(cross_idx, 8)
def parse_sync_params(
    frame: Union[str, Frame],
    cat: List[str] = None,
    start: Union[str, datetime.date] = None,
    stop: Union[str, Frame] = None,
    delay: int = 0,
    include: str = "",
    exclude: str = "",
) -> Tuple:
    """Parse and complete sync parameters.

    Follows the rules in the [user manual](usage.md#22-如何同步K线数据).

    For minute-level frame types, a ``start`` given without a time of day
    (hour == 0) is corrected to the first frame of that trading day, and a
    ``stop`` given without a time of day is corrected to the last frame of
    that trading day.

    Args:
        frame (Union[str, Frame]): frame type to be synced. The word
            ``frame`` is used here for easy understanding by end users; it
            actually means "FrameType".
        cat (List[str]): which categories are to be synced, e.g.
            ['stock', 'index']. Defaults to None (no category selected).
        start (Union[str, datetime.date], optional): first frame to sync.
            Defaults to None, meaning 999 frames (minute level) or 1000
            frames (otherwise) before ``stop``.
        stop (Union[str, Frame], optional): last frame to sync. Defaults to
            None, meaning the current time floored to ``frame``.
        delay (int, optional): returned unchanged after conversion with
            ``int``. Defaults to 0.
        include (str, optional): securities that must be included,
            separated by whitespace, for example
            "000001.XSHE 000004.XSHE". Takes precedence over ``exclude``.
            Defaults to empty string.
        exclude (str, optional): securities that must be excluded,
            separated by whitespace. Defaults to empty string.

    Returns:
        Tuple of
        - codes (List[str]): securities to sync, without duplicates
        - frame_type (FrameType):
        - start (Frame):
        - stop (Frame):
        - delay (int):

    Raises:
        ValueError: if a minute-level ``stop`` lies in the future.
    """
    frame_type = FrameType(frame)
    now = arrow.now(tz=cfg.tz)

    if frame_type in tf.minute_level_frames:
        if stop:
            stop = arrow.get(stop, tzinfo=cfg.tz)
            if stop.hour == 0:
                # no usable time of day given: use the day's closing frame
                stop = tf.last_min_frame(tf.day_shift(stop.date(), 0), frame_type)
            else:
                stop = tf.floor(stop, frame_type)
        else:
            stop = tf.floor(now.datetime, frame_type)

        if stop > now:
            raise ValueError(f"请勿将同步截止时间设置在未来: {stop}")

        if start:
            start = arrow.get(start, tzinfo=cfg.tz)
            if start.hour == 0:
                # no usable time of day given: use the day's opening frame
                start = tf.first_min_frame(tf.day_shift(start.date(), 0), frame_type)
            else:
                start = tf.floor(start, frame_type)
        else:
            start = tf.shift(stop, -999, frame_type)
    else:
        # "today" is evaluated in the configured timezone, like the rest of
        # this function, rather than in the host's local timezone.
        today = now.date()
        stop = (stop and arrow.get(stop).date()) or today
        if stop == today:
            # pass the current moment rather than the bare date to tf.floor
            stop = now
        stop = tf.floor(stop, frame_type)

        # Only floor a start that was actually supplied; the original passed
        # None into tf.floor when ``start`` was omitted.
        floored_start = tf.floor(arrow.get(start).date(), frame_type) if start else None
        start = floored_start or tf.shift(stop, -1000, frame_type)

    secs = Securities()
    # split() drops empty tokens, so stray or repeated spaces are harmless
    excluded = set((exclude or "").split())
    selected = set(secs.choose(cat or [])) - excluded

    codes = list(selected)
    # Explicit includes override ``exclude``; skip codes already selected and
    # repeated entries so no security is listed twice.
    requested = dict.fromkeys((include or "").split())
    codes.extend(code for code in requested if code not in selected)

    return codes, frame_type, start, stop, int(delay)