def persist(strategy, results):
    """Write one `code-name-score` line per result to <OUTPUT>/<strategy>.txt."""
    out_dir = ConfigUtils.get_stock("OUTPUT")
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)
    with open(out_dir + "/" + strategy + ".txt", 'w') as wf:
        for e in results:
            wf.write(e[0] + "-" + e[1] + "-" + str(e[2]) + '\n')
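# Each line written by persist() has the form "<code>-<name>-<score>", e.g.
# (illustrative values): sh.600000-浦发银行-1.25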
def update_trades():
    try:
        et = stock_utils.get_recently_trade_date()
        st = ConfigUtils.get_stock("START_DATE")
        # Log in to baostock and print the login response.
        lg = bs.login()
        print('login respond error_code:' + lg.error_code + ', error_msg:' + lg.error_msg)
        pd_names = pd.read_csv(ConfigUtils.get_stock("STOCK_NAME"))
        data = list(pd_names.iterrows())
        # Split the stock list into number_kernel chunks and download in parallel.
        number_kernel = 1
        size = int((len(data) + number_kernel - 1) / number_kernel)
        p = Pool(number_kernel)
        for i in range(number_kernel):
            start = size * i
            _end = size * (i + 1)
            end = len(data) if _end > len(data) else _end
            p.apply_async(loading_stock, args=(data[start:end], st, et))
        p.close()
        p.join()
        print('all subprocesses done.')
        bs.logout()
    except IOError as e:
        print("Update Data Error ", e)
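# Chunking example for update_trades(): with 4000 stocks and number_kernel = 4,
# size = int((4000 + 4 - 1) / 4) = 1000, so the worker slices are
# data[0:1000], data[1000:2000], data[2000:3000] and data[3000:4000].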
def prepare():
    # Create each missing directory; clean stale data files if it already exists.
    dirs = [ConfigUtils.get_stock("DATA_DIR"), ConfigUtils.get_stock("DB_DIR")]
    for d in dirs:
        if os.path.exists(d):
            clean_files()
        else:
            os.makedirs(d)
def clean_files():
    # Remove every regular file under DATA_DIR.
    for the_file in os.listdir(ConfigUtils.get_stock("DATA_DIR")):
        file_path = os.path.join(ConfigUtils.get_stock("DATA_DIR"), the_file)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
        except Exception as e:
            print(e)
def syncup_transaction():
    from api_modules.server.models import db_session
    stock_path = ConfigUtils.get_stock('STOCK_NAME')
    stock_pd = pd.read_csv(stock_path, encoding='utf-8')
    for index, elem in stock_pd.to_dict('index').items():
        code, name = elem['code'], elem['code_name']
        code_name = (code, name)
        stock_data = stock_utils.read_data(code_name)
        if stock_data is None:
            continue
        stock_data.fillna(0, inplace=True)
        # Sync only the most recent bar for each stock.
        tmp_stock = stock_data.tail(1)
        try:
            tras = list()
            for idx, ele in tmp_stock.to_dict('index').items():
                tra = Transaction(code=code, name=name)
                tra.build_from_dict(ele)
                tra.set_type(1)
                tras.append(tra)
            db_session.add_all(tras)
            db_session.commit()
        except IntegrityError:
            db_session.rollback()
        except InvalidRequestError:
            db_session.rollback()
    print("syncup transaction done.")
def __init__(self):
    """Load the prepared dataset; backtest data could be added here as well."""
    self.stock_info = pd.read_csv(ConfigUtils.get_stock("STOCKS_DATESET"),
                                  encoding='utf-8', low_memory=False)
    print("finish load dataset.")
    self.build()
def init_shares():
    from api_modules.server.models import db_session
    industry_path = ConfigUtils.get_stock('STOCK_INDUSTRY')
    stock_path = ConfigUtils.get_stock('STOCK_NAME')
    industry_pd = pd.read_csv(industry_path, encoding='utf-8')
    stock_pd = pd.read_csv(stock_path, encoding='utf-8')
    share_pd = pd.merge(stock_pd, industry_pd, how='left', on=['code', 'code_name'])
    shares = list()
    # Build one Share per stock; industry fields stay None when missing.
    for index, row in share_pd.iterrows():
        code = row['code']
        name = row['code_name']
        industry = None if pd.isna(row['industry']) else row['industry']
        industryClassification = None if pd.isna(row['industryClassification']) else row['industryClassification']
        tmp_share = Share(code=code, name=name, industry=industry,
                          industryClassification=industryClassification)
        shares.append(tmp_share)
    # Insert into the database.
    db_session.add_all(shares)
    db_session.commit()
def run_pipline_stg(filter_kc=True):
    logging.info("************************ process start ***************************************")
    pd_names = pd.read_csv(ConfigUtils.get_stock("STOCK_NAME"))
    strategies = {
        '海龟交易法则': turtle_trade.check_enter,
        '潜伏慢阳3d': turtle_trade.continue_increase3,
        '潜伏慢阳5d': turtle_trade.continue_increase5,
        '潜伏慢阳7d': turtle_trade.continue_increase7,
        'BigInc': turtle_trade.big_inc,
        'BreakInc': turtle_trade.break_inc,
        'BackInc': turtle_trade.feedback_inc,
        'Demon': turtle_trade.demon_inc,
        '连续含1y': turtle_trade.inc_cyin1,
        '连续含2y': turtle_trade.inc_cyin2,
        '放量上涨_1.5': enter.check_volume,
        '大涨': enter.check_continuous_inc,
        'ATR': enter.check_breakthrough,
        '突破平台': breakthrough_platform.check,
        '均线多头': keep_increasing.check,
        '停机坪': parking_apron.check,
        '回踩年线': backtrace_ma250.check,
        '低ATR成长策略': low_atr.check_low_increase,
        # 'pe': pe.check,
    }
    stg_result_dict = collections.defaultdict(list)
    for index, row in pd_names.iterrows():
        code = row['code']
        name = row['code_name']
        # Optionally skip certain boards (see is_jiucaiban); always skip ST stocks.
        if filter_kc and stock_utils.is_jiucaiban(code):
            continue
        if "ST" in name:
            continue
        code_name = (code, name)
        print(code_name)
        data = stock_utils.read_data(code_name)
        if data is None:
            continue
        for strategy, strategy_func in strategies.items():
            r = strategy_func(code_name, data)
            if r and r > 0.0:
                stg_result_dict[strategy].append((code, name, r))
    # Persist each strategy's hits, best score first.
    for strategy, results in stg_result_dict.items():
        outs = sorted(results, key=lambda e: e[2], reverse=True)
        stock_utils.persist(strategy, outs)
    logging.info("************************ process end ***************************************")
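# Every strategy above follows the same contract: it is called as
# strategy_func(code_name, data), where code_name is a (code, name) tuple and
# data is the DataFrame returned by stock_utils.read_data(), and it returns a
# positive score when the stock qualifies. A minimal sketch of a custom
# strategy under that contract (the function and its 5% threshold are
# illustrative assumptions, not part of the project):
def example_recent_gain(code_name, data):
    """Score stocks whose close rose more than 5% over the last 5 bars."""
    if data is None or len(data) < 6:
        return 0.0
    gain = (data['close'].iloc[-1] - data['close'].iloc[-6]) / data['close'].iloc[-6]
    return gain if gain > 0.05 else 0.0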
def loading_stock(rows, st, et):
    for index, row in rows:
        code = row['code']
        name = row['code_name']
        k_rs = bs.query_history_k_data_plus(code,
                                            ConfigUtils.get_stock("STOCK_FIELDS"),
                                            start_date=st, end_date=et)
        data_list = []
        while k_rs.error_code == '0' and k_rs.next():
            # Fetch one record at a time and collect them.
            data_list.append(k_rs.get_row_data())
        result = pd.DataFrame(data_list, columns=k_rs.fields)
        print(result.tail())
        if not os.path.exists(ConfigUtils.get_stock("DATA_DIR")):
            os.makedirs(ConfigUtils.get_stock("DATA_DIR"))
        result.to_csv(os.path.join(ConfigUtils.get_stock("DATA_DIR"),
                                   code + "_" + name + ".csv"), index=False)
        print("Downloading :" + code + " , name :" + name)
def init_dates():
    from api_modules.server.models import db_session
    date_path = ConfigUtils.get_stock('STOCKS_DATE')
    date_pd = pd.read_csv(date_path, encoding='utf-8')
    dates = list()
    for index, row in date_pd.iterrows():
        calendar_date = row['calendar_date']
        is_trading_day = int(row['is_trading_day'])
        tmp_cal = CalendarDate(calendar_date=calendar_date, is_trading_day=is_trading_day)
        dates.append(tmp_cal)
    # Insert into the database.
    db_session.add_all(dates)
    db_session.commit()
def get_recently_trade_date():
    dt = datetime.date.today()
    if os.path.exists(ConfigUtils.get_stock("STOCKS_DATE")):
        trade_dates = pd.read_csv(ConfigUtils.get_stock("STOCKS_DATE"), header=0)
        trade_date_dict = trade_dates.set_index("calendar_date")['is_trading_day'].to_dict()
        tmp_dt = str(dt)
        if tmp_dt in trade_date_dict:
            if trade_date_dict[tmp_dt] == 1:
                return tmp_dt
            else:
                # Today is not a trading day: walk backwards day by day until
                # the most recent trading day is found.
                dt_num = 1
                dt_pass = str(dt - datetime.timedelta(days=dt_num))
                while dt_pass in trade_date_dict and trade_date_dict[dt_pass] == 0:
                    dt_num += 1
                    dt_pass = str(dt - datetime.timedelta(days=dt_num))
                if dt_pass in trade_date_dict:
                    return dt_pass
        print("Date does not exist; reload trade dates.")
    else:
        print("Date does not exist; reloading trade dates.")
        init_trade_date()
        print("Date loading finished.")
    return None
def init_trade_date():
    # Log in to baostock and print the login response.
    lg = bs.login()
    print('login respond error_code:' + lg.error_code)
    print('login respond error_msg:' + lg.error_msg)
    # Query the trading calendar.
    st = ConfigUtils.get_stock("START_DATE")
    et = ConfigUtils.get_stock("END_DATE")
    print(st, et)
    rs = bs.query_trade_dates(start_date=st, end_date=et)
    print('query_trade_dates respond error_code:' + rs.error_code)
    print('query_trade_dates respond error_msg:' + rs.error_msg)
    # Collect the result set, one record at a time.
    data_list = []
    while rs.error_code == '0' and rs.next():
        data_list.append(rs.get_row_data())
    result = pd.DataFrame(data_list, columns=rs.fields)
    # Write the result set to CSV.
    result.to_csv(ConfigUtils.get_stock("STOCKS_DATE"), index=False)
    bs.logout()
def get_stocks(config=None):
    if config:
        # Read (code, name) pairs from the first sheet of an Excel workbook,
        # skipping the header row and the final row.
        data = xlrd.open_workbook(config)
        table = data.sheets()[0]
        rows_count = table.nrows
        codes = table.col_values(0)[1:rows_count - 1]
        names = table.col_values(1)[1:rows_count - 1]
        return list(zip(codes, names))
    else:
        # Recover (code, name) pairs from the downloaded CSV file names,
        # which loading_stock() writes as "<code>_<name>.csv".
        data_files = os.listdir(ConfigUtils.get_stock('DATA_DIR'))
        stocks = []
        for file in data_files:
            code_name = os.path.splitext(file)[0]
            code, name = code_name.split("_", 1)
            stocks.append((code, name))
        return stocks
def init_transaction():
    from api_modules.server.models import db_session
    stock_path = ConfigUtils.get_stock('STOCK_NAME')
    stock_pd = pd.read_csv(stock_path, encoding='utf-8')
    for index, elem in stock_pd.to_dict('index').items():
        code, name = elem['code'], elem['code_name']
        code_name = (code, name)
        stock_data = stock_utils.read_data(code_name)
        if stock_data is None:
            continue
        stock_data.fillna(0, inplace=True)
        tras = list()
        for idx, ele in stock_data.to_dict('index').items():
            tra = Transaction(code=code, name=name)
            tra.build_from_dict(ele)
            tra.set_type(1)
            tras.append(tra)
        db_session.add_all(tras)
        db_session.commit()
def get_all_stock_names():
    # Log in to baostock and print the login response.
    lg = bs.login()
    print('login respond error_code:' + lg.error_code + ', error_msg:' + lg.error_msg)
    dt = stock_utils.get_recently_trade_date()
    k_rs = bs.query_all_stock(day=dt)
    print(k_rs)
    data_list = []
    while k_rs.error_code == '0' and k_rs.next():
        # Fetch one record at a time and collect them.
        data_list.append(k_rs.get_row_data())
    result = pd.DataFrame(data_list, columns=k_rs.fields)
    print(result.tail())
    result.to_csv(ConfigUtils.get_stock("STOCK_NAME"), index=False)
    print("init all stock names")
    bs.logout()
def trade_run():
    cerebro = bt.Cerebro()
    data = DBDataFeed(
        # Local database; the URI comes from the config's engine entry.
        db_uri=ConfigUtils.get_mysql('engine'),
        dataname="sh.600004",
        fromdate=datetime.date(2010, 1, 1),
    )
    cerebro.adddata(data)
    cerebro.addstrategy(SMACross)
    cerebro.addsizer(bt.sizers.AllInSizerInt)
    cerebro.broker.set_cash(100000)
    cerebro.addanalyzer(bt.analyzers.AnnualReturn, _name="annual_returns")
    cerebro.addanalyzer(bt.analyzers.DrawDown, _name="draw_down")
    cerebro.addanalyzer(bt.analyzers.Transactions, _name="transactions")
    results = cerebro.run()
    # Log the analyzer results.
    for result in results:
        annual_returns = result.analyzers.annual_returns.get_analysis()
        log.info("annual returns:")
        for year, ret in annual_returns.items():
            log.info("\t {} {}%, ".format(year, round(ret * 100, 2)))
        draw_down = result.analyzers.draw_down.get_analysis()
        log.info(
            "drawdown={drawdown}%, moneydown={moneydown}, drawdown len={len}, "
            "max.drawdown={max.drawdown}, max.moneydown={max.moneydown}, "
            "max.len={max.len}".format(**draw_down))
        transactions = result.analyzers.transactions.get_analysis()
        log.info("transactions")
    # Plot the run.
    # cerebro.plot()
    b = Bokeh(style="bar", tabs="multi", scheme=Tradimo())
    cerebro.plot(b)
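# SMACross is referenced by trade_run() above but not defined in this section.
# A minimal sketch of such a moving-average cross strategy in backtrader
# (class name and periods are assumptions; relies on `import backtrader as bt`):
class SMACrossSketch(bt.Strategy):
    params = dict(fast=10, slow=30)

    def __init__(self):
        fast_sma = bt.ind.SMA(period=self.p.fast)
        slow_sma = bt.ind.SMA(period=self.p.slow)
        # > 0 when the fast SMA crosses above the slow one, < 0 on the way down.
        self.crossover = bt.ind.CrossOver(fast_sma, slow_sma)

    def next(self):
        if not self.position and self.crossover > 0:
            self.buy()
        elif self.position and self.crossover < 0:
            self.close()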
def get_all_stock_industries():
    lg = bs.login()
    print('login respond error_code:' + lg.error_code)
    print('login respond error_msg:' + lg.error_msg)
    # Query the industry classification data.
    rs = bs.query_stock_industry(date='2020-08-01')
    # rs = bs.query_stock_basic(code_name="浦发银行")
    print('query_stock_industry error_code:' + rs.error_code)
    print('query_stock_industry respond error_msg:' + rs.error_msg)
    # Collect the result set, one record at a time.
    industry_list = []
    while rs.error_code == '0' and rs.next():
        industry_list.append(rs.get_row_data())
    result = pd.DataFrame(industry_list, columns=rs.fields)
    # Write the result set to CSV.
    result.to_csv(ConfigUtils.get_stock("STOCK_INDUSTRY"), index=False)
    print(result)
    # Log out.
    bs.logout()
def read_data(code_name):
    code = code_name[0]
    name = code_name[1]
    df = None
    file_name = str(code) + '_' + str(name) + '.csv'
    file_path = ConfigUtils.get_stock("DATA_DIR") + "/" + file_name
    if os.path.exists(file_path):
        try:
            df = pd.read_csv(file_path)
        except pd.errors.EmptyDataError:
            df = None
    if df is not None and not df.empty:
        # Cast the price/volume columns to float.
        df["open"] = df['open'].astype(float)
        df["high"] = df["high"].astype(float)
        df["low"] = df["low"].astype(float)
        df["close"] = df["close"].astype(float)
        df["preclose"] = df["preclose"].astype(float)
        df["volume"] = df["volume"].astype(float)
        df["pctChg"] = df["pctChg"].astype(float)
        return df
    return None
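# Usage sketch for read_data(); the code/name pair is illustrative and assumes
# the corresponding CSV was downloaded by loading_stock():
# df = read_data(("sh.600004", "白云机场"))
# if df is not None:
#     print(df[["date", "close", "pctChg"]].tail())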
# -*- encoding: UTF-8 -*-
import os
import random
import datetime

import click
import tqdm
import pandas as pd
import numpy as np
from sklearn import metrics

from api_modules.server.utils import stock_utils
from api_modules.server.utils.config import ConfigUtils
from api_modules.server.sequoia.stg import LGB
from api_modules.server.models import get_recently_trade_date, Account
from api_modules.server.apis import user_api
from api_modules.server.apis import account_api

root = ConfigUtils.get_stock("DATA")
model_root = ConfigUtils.get_stock("MODELS")
model_name = 'lgb.pkl'
thresh_hold = 0.6

# Train/validation/test date splits.
trn_date_min = '2010-01-01'
trn_date_max = '2020-07-31'
val_date_min = '2020-08-03'
val_date_max = '2020-08-07'
test_date_min = '2020-08-08'
test_date_max = '2020-08-21'

feature_cols = [
    'close_transform', 'open_transform', 'high_transform', 'low_transform'
]
trn_col = [
def run_gen_ds():
    gen_dataset()
    click.echo("finish dataset {}!".format(ConfigUtils.get_stock("STOCKS_DATESET")))
def gen_dataset():
    start_dt = '2019-01-01'
    date_info = pd.read_csv(ConfigUtils.get_stock("STOCKS_DATE"), encoding='utf-8')
    industry_info = pd.read_csv(ConfigUtils.get_stock("STOCK_INDUSTRY"), encoding='utf-8')
    company_info = pd.read_csv(ConfigUtils.get_stock("STOCK_NAME"), encoding='utf-8')
    company_info = company_info.merge(industry_info, how='left', on=['code', 'code_name'])

    # Build the time index: map each trading day since start_dt to an integer id.
    # Ids must increase with time so that the "+ i + 1" shifts below fetch past
    # closes and the "- i - 1" shifts fetch future highs/lows.
    dt = datetime.date.today()
    tmp_list = sorted([row['calendar_date'] for idx, row in date_info.iterrows()
                       if start_dt <= row['calendar_date'] < str(dt)
                       and row['is_trading_day'] == 1])
    date_map = dict(zip(tmp_list, range(len(tmp_list))))

    # Read each stock's price history.
    remove_stock = []
    tmp_list = []
    for i, row in tqdm.tqdm(company_info.iterrows()):
        code, name = row["code"], row["code_name"]
        path = os.path.join(ConfigUtils.get_stock("DATA_DIR"), code + "_" + name + ".csv")
        if not os.path.exists(path):
            continue
        tmp_df = pd.read_csv(path)
        tmp_df = tmp_df[tmp_df.date >= start_dt]
        # Drop recently listed stocks and the sz.300*/sh.688* boards.
        if len(tmp_df) < 60 or code.startswith('sz.300') or code.startswith('sh.688'):
            remove_stock.append(code)
            continue
        tmp_df = tmp_df.sort_values('date', ascending=True).reset_index()
        tmp_list.append(tmp_df)
    stock_info = pd.concat(tmp_list)

    ts_code_map = dict(zip(stock_info['code'].unique(), range(stock_info['code'].nunique())))
    stock_info = stock_info.reset_index()
    stock_info['ts_code_id'] = stock_info['code'].map(ts_code_map)
    stock_info['trade_date_id'] = stock_info['date'].map(date_map)
    # Unique (stock, date) key: stock id in the high digits, date id in the low.
    stock_info['ts_date_id'] = (10000 + stock_info['ts_code_id']) * 10000 + stock_info['trade_date_id']
    stock_info = stock_info.merge(company_info, how='left', on='code')

    # Feature engineering: relative change of OHLC against the previous close.
    col = ['close', 'open', 'high', 'low']
    feature_col = []
    for tmp_col in col:
        stock_info[tmp_col + '_transform'] = (stock_info[tmp_col] - stock_info['preclose']) / stock_info['preclose']
        feature_col.append(tmp_col + '_transform')

    # Return of today's close against each of the previous 5 closes.
    for i in range(5):
        tmp_df = pd.DataFrame(stock_info, columns=['ts_date_id', 'close'])
        tmp_df = tmp_df.rename(columns={'close': 'close_shift_{}'.format(i + 1)})
        feature_col.append('close_shift_{}'.format(i + 1))
        tmp_df['ts_date_id'] = tmp_df['ts_date_id'] + i + 1
        stock_info = stock_info.merge(tmp_df, how='left', on='ts_date_id')
    stock_info.drop('level_0', axis=1, inplace=True)
    for i in range(5):
        stock_info['close_shift_{}'.format(i + 1)] = (stock_info['close'] - stock_info['close_shift_{}'.format(i + 1)]) / stock_info['close_shift_{}'.format(i + 1)]
    # stock_info.dropna(inplace=True)

    # Label construction: relative high/low over the next 5 trading days.
    use_col = []
    for i in range(5):
        tmp_df = stock_info[['ts_date_id', 'high', 'low']]
        tmp_df = tmp_df.rename(columns={'high': 'high_shift_{}'.format(i + 1),
                                        'low': 'low_shift_{}'.format(i + 1)})
        use_col.append('high_shift_{}'.format(i + 1))
        use_col.append('low_shift_{}'.format(i + 1))
        tmp_df['ts_date_id'] = tmp_df['ts_date_id'] - i - 1
        stock_info = stock_info.merge(tmp_df, how='left', on='ts_date_id')
    # stock_info.dropna(inplace=True)
    for i in range(5):
        stock_info['high_shift_{}'.format(i + 1)] = (stock_info['high_shift_{}'.format(i + 1)] - stock_info['close']) / stock_info['close']
        stock_info['low_shift_{}'.format(i + 1)] = (stock_info['low_shift_{}'.format(i + 1)] - stock_info['close']) / stock_info['close']
    tmp_array = stock_info[use_col].values
    max_increse = np.max(tmp_array, axis=1)
    min_increse = np.min(tmp_array, axis=1)
    stock_info['label_max'] = max_increse
    stock_info['label_min'] = min_increse
    stock_info['change'] = (stock_info['high'] - stock_info['low']) / stock_info['preclose']

    # Positive label: more than +6% gain and at most -3% drawdown within 5 days.
    stock_info['label_final'] = (stock_info['label_max'] > 0.06) & (stock_info['label_min'] > -0.03)
    stock_info['label_final'] = stock_info['label_final'].apply(lambda x: int(x))
    stock_info = stock_info.reset_index()
    stock_info.drop('index', axis=1, inplace=True)
    stock_info.to_csv(ConfigUtils.get_stock("STOCKS_DATESET"), index=False)
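# Worked example of the labelling rule in gen_dataset() (illustrative numbers):
# today's close is 10.0 and over the next 5 trading days the highest high is
# 10.8 and the lowest low is 9.8, so label_max = 0.08 and label_min = -0.02;
# label_final = int((0.08 > 0.06) & (-0.02 > -0.03)) = 1, a positive sample.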
from flask import current_app
from sqlalchemy import exc, Column, Integer, FLOAT, String, DateTime, create_engine, distinct, and_, PrimaryKeyConstraint
import datetime
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from api_modules.server.utils.config import ConfigUtils

app = current_app

# Connect to the database.
engine = create_engine(ConfigUtils.get_mysql('engine'), echo=False, pool_size=100, pool_recycle=3600)

# Declarative base class and scoped session.
Base = declarative_base()
session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db_session = scoped_session(session_factory)


class Fund(Base):
    __tablename__ = 'funds'
    __table_args__ = {"mysql_charset": "utf8"}

    id = Column(Integer, primary_key=True)
    code = Column(String(40), unique=True)
    name = Column(String(40), unique=True)
    type = Column(String(20), default=None)
    scale = Column(FLOAT, default=None)
    positions = Column(String(20000), default=None)
    update_date = Column(DateTime, default=datetime.datetime.now)
    create_date = Column(DateTime, default=datetime.datetime.now)
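# Minimal usage sketch for the session and model above (illustrative values):
# fund = Fund(code="000001", name="Example Fund", type="mixed", scale=12.5)
# db_session.add(fund)
# db_session.commit()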