class CInvestor(object):
    """Read-only access to the ``month_investor`` table of the ``stock`` database."""

    def __init__(self, dbinfo=ct.DB_INFO, redis_host=None):
        self.dbname = self.get_dbname()
        self.table = self.get_table_name()
        if redis_host is None:
            self.redis = create_redis_obj()
        else:
            self.redis = create_redis_obj(host=redis_host)
        self.mysql_client = CMySQL(dbinfo, self.dbname, iredis=self.redis)

    @staticmethod
    def get_dbname():
        return "stock"

    @staticmethod
    def get_table_name():
        return "month_investor"

    def get_data(self, mdate=None):
        """Return rows for ``mdate``, or every row when ``mdate`` is None."""
        query = "select * from %s" % self.get_table_name()
        if mdate is not None:
            query += " where date=\"%s\"" % (mdate,)
        return self.mysql_client.get(query)

    def get_data_in_range(self, start_date, end_date):
        """Return rows whose ``date`` lies between the two bounds (inclusive, SQL BETWEEN)."""
        bounds = (self.get_table_name(), start_date, end_date)
        query = "select * from %s where date between \"%s\" and \"%s\"" % bounds
        return self.mysql_client.get(query)
class MonthInvestorCrawler(object):
    """Creates (if needed) and queries the ``month_investor`` table in the ``stock`` database."""

    def __init__(self, dbinfo=ct.DB_INFO, redis_host=None):
        self.dbname = self.get_dbname()
        self.table = self.get_table_name()
        self.redis = create_redis_obj() if redis_host is None else create_redis_obj(host=redis_host)
        self.mysql_client = CMySQL(dbinfo, self.dbname, iredis=self.redis)
        if not self.mysql_client.create_db(self.dbname):
            raise Exception("init stock database failed")
        if not self.create_table():
            raise Exception("init month investor table failed")

    @staticmethod
    def get_dbname():
        return "stock"

    @staticmethod
    def get_table_name():
        return "month_investor"

    def create_table(self):
        """Create the table unless it is already listed by the MySQL client; return success."""
        sql = ('create table if not exists %s('
               'date varchar(10) not null, '
               'new_investor float, '
               'new_natural_person float, '
               'new_non_natural_person float, '
               'final_investor float, '
               'final_natural_person float, '
               'final_natural_a_person float, '
               'final_natural_b_person float, '
               'final_non_natural_person float, '
               'final_non_natural_a_person float, '
               'final_non_natural_b_person float, '
               'final_hold_investor float, '
               'final_a_hold_investor float, '
               'final_b_hold_investor float, '
               'trading_investor float, '
               'trading_a_investor float, '
               'trading_b_investor float, '
               'PRIMARY KEY(date))' % self.table)
        if self.table in self.mysql_client.get_all_tables():
            return True
        return self.mysql_client.create(sql, self.table)

    def get_data(self, mdate=None):
        """Return rows for ``mdate``, or every row when ``mdate`` is None.

        The original body referenced an undefined name ``date`` instead of the
        ``mdate`` parameter and therefore always raised NameError.
        """
        table_name = self.get_table_name()
        if mdate is not None:
            sql = "select * from %s where date=\"%s\"" % (table_name, mdate)
        else:
            sql = "select * from %s" % table_name
        return self.mysql_client.get(sql)

    def get_data_in_range(self, start_date, end_date):
        """Return rows whose ``date`` lies between the two bounds (inclusive, SQL BETWEEN)."""
        table_name = self.get_table_name()
        sql = "select * from %s where date between \"%s\" and \"%s\"" % (table_name, start_date, end_date)
        return self.mysql_client.get(sql)
class RIndexStock:
    """Daily snapshot of every stock's k-line data, stored in per-quarter MySQL tables.

    Tables are named ``<dbname>_day_<year>_<quarter>``. Redis holds two kinds of
    sets used as a cache: one keyed by the database name listing created tables,
    and one per table listing the dates already written to it.
    """

    def __init__(self, dbinfo=ct.DB_INFO, redis_host=None):
        self.redis = create_redis_obj() if redis_host is None else create_redis_obj(host=redis_host)
        self.dbname = self.get_dbname()
        self.redis_host = redis_host
        self.logger = getLogger(__name__)
        self.mysql_client = CMySQL(dbinfo, self.dbname, iredis=self.redis)
        if not self.mysql_client.create_db(self.get_dbname()):
            raise Exception("init rstock database failed")

    @staticmethod
    def get_dbname():
        return ct.RINDEX_STOCK_INFO_DB

    def get_table_name(self, cdate):
        """Return the quarterly table name for a ``YYYY-MM-DD`` date string."""
        cdates = cdate.split('-')
        quarter = (int(cdates[1]) - 1) // 3 + 1
        return "%s_day_%s_%s" % (self.get_dbname(), cdates[0], quarter)

    def is_date_exists(self, table_name, cdate):
        """Return True if redis records ``cdate`` as already stored in ``table_name``."""
        if self.redis.exists(table_name):
            return cdate in set(str(tdate, encoding="utf8") for tdate in self.redis.smembers(table_name))
        return False

    def is_table_exists(self, table_name):
        """Return True if redis records ``table_name`` as created in this database."""
        if self.redis.exists(self.dbname):
            return table_name in set(str(table, encoding="utf8") for table in self.redis.smembers(self.dbname))
        return False

    def create_table(self, table):
        """Create a quarterly table unless the MySQL client already lists it; return success."""
        sql = ('create table if not exists %s('
               'date varchar(10) not null, '
               'code varchar(10) not null, '
               'open float, '
               'high float, '
               'close float, '
               'preclose float, '
               'low float, '
               'volume float, '
               'amount float, '
               'outstanding float, '
               'totals float, '
               'adj float, '
               'aprice float, '
               'pchange float, '
               'turnover float, '
               'sai float, '
               'sri float, '
               'uprice float, '
               'sprice float, '
               'mprice float, '
               'lprice float, '
               'ppercent float, '
               'npercent float, '
               'base float, '
               'ibase bigint, '
               'breakup int, '
               'ibreakup bigint, '
               'pday int, '
               'profit float, '
               'gamekline float, '
               'PRIMARY KEY (date, code))' % table)
        if table in self.mysql_client.get_all_tables():
            return True
        return self.mysql_client.create(sql, table)

    def get_k_data_in_range(self, start_date, end_date):
        """Return all stored rows for trading days starting at ``start_date``.

        ``delta_days(start_date, end_date)`` calendar days are scanned; trading
        days are grouped by quarterly table so each table is queried once.
        Returns an empty DataFrame when nothing is found.
        """
        ndays = delta_days(start_date, end_date)
        date_dmy_format = time.strftime("%m/%d/%Y", time.strptime(start_date, "%Y-%m-%d"))
        data_times = pd.date_range(date_dmy_format, periods=ndays, freq='D')
        # A comprehension also copes with an empty range, where np.vectorize raises.
        date_only_array = [d.strftime('%Y-%m-%d') for d in data_times]
        data_dict = OrderedDict()
        for _date in date_only_array:
            if CCalendar.is_trading_day(_date, redis=self.redis):
                data_dict.setdefault(self.get_table_name(_date), list()).append(_date)
        frames = list()
        for table_dates in data_dict.values():
            table_dates = sorted(table_dates)
            if len(table_dates) == 1:
                df = self.get_data(table_dates[0])
            else:
                df = self.get_data_between(table_dates[0], table_dates[-1])
            if df is not None:
                frames.append(df)
        # DataFrame.append was removed in pandas 2.0; concatenate once instead.
        return pd.concat(frames) if frames else pd.DataFrame()

    def get_data_between(self, start_date, end_date):
        """Return rows between two dates; both dates must map to the same quarterly table."""
        sql = "select * from %s where date between \"%s\" and \"%s\"" % (
            self.get_table_name(start_date), start_date, end_date)
        return self.mysql_client.get(sql)

    def get_data(self, cdate=None):
        """Return rows for ``cdate`` (defaults to today, evaluated at call time)."""
        if cdate is None:
            cdate = datetime.now().strftime('%Y-%m-%d')
        sql = "select * from %s where date=\"%s\"" % (self.get_table_name(cdate), cdate)
        return self.mysql_client.get(sql)

    def get_stock_data(self, cdate, code):
        """Return ``(code, k-line data of code for cdate)``."""
        return (code, CStock(code).get_k_data(cdate))

    def generate_all_data_1(self, cdate, black_list=None):
        """Fetch every listed stock's data for ``cdate`` via queue_process_concurrent_run."""
        failed_list = CStockInfo(redis_host=self.redis_host).get(redis=self.redis).code.tolist()
        if black_list:
            failed_list = list(set(failed_list).difference(set(black_list)))
        cfunc = partial(self.get_stock_data, cdate)
        return queue_process_concurrent_run(cfunc, failed_list, redis_client=self.redis)

    def generate_all_data(self, cdate, black_list=ct.BLACK_LIST):
        """Fetch every listed stock's data for ``cdate`` with a gevent pool.

        Codes whose fetch returns None are retried; a round that makes no
        progress sleeps 600 s before retrying. Returns the de-duplicated rows
        sorted by date, or an empty DataFrame if nothing was fetched.
        """
        from gevent.pool import Pool
        obj_pool = Pool(5000)
        failed_list = CStockInfo(redis_host=self.redis_host).get(redis=self.redis).code.tolist()
        if len(black_list) > 0:
            failed_list = list(set(failed_list).difference(set(black_list)))
        frames = list()
        last_length = len(failed_list)
        cfunc = partial(self.get_stock_data, cdate)
        while last_length > 0:
            self.logger.info("all stock list:%s, cdate:%s", len(failed_list), cdate)
            # Iterate a snapshot: failed_list shrinks as results arrive, and
            # removing from the list being iterated would skip codes.
            for code, tem_df in obj_pool.imap_unordered(cfunc, list(failed_list)):
                if tem_df is not None:
                    tem_df['code'] = code
                    frames.append(tem_df)
                    failed_list.remove(code)
            if len(failed_list) != last_length:
                self.logger.debug("last failed list:%s, current failed list:%s" % (last_length, len(failed_list)))
                last_length = len(failed_list)
            else:
                time.sleep(600)
        obj_pool.join(timeout=5)
        obj_pool.kill()
        if not frames:
            return pd.DataFrame()
        all_df = pd.concat(frames)
        all_df = all_df.drop_duplicates()
        all_df = all_df.sort_values(by='date', ascending=True)
        return all_df.reset_index(drop=True)

    def update(self, end_date=None, num=30):
        """Store data for every trading day in the ``num`` days up to ``end_date`` (default today)."""
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')
        start_date = get_day_nday_ago(end_date, num=num, dformat="%Y-%m-%d")
        succeed = True
        for mdate in get_dates_array(start_date, end_date):
            if CCalendar.is_trading_day(mdate, redis=self.redis):
                if not self.set_day_data(mdate):
                    self.logger.error("set %s data for rstock failed" % mdate)
                    succeed = False
        return succeed

    def set_day_data(self, cdate):
        """Crawl and store one day's data; return True if stored or already present."""
        table_name = self.get_table_name(cdate)
        if not self.is_table_exists(table_name):
            if not self.create_table(table_name):
                self.logger.error("create tick table failed")
                return False
            self.redis.sadd(self.dbname, table_name)
        if self.is_date_exists(table_name, cdate):
            self.logger.debug("existed table:%s, date:%s" % (table_name, cdate))
            return True
        df = self.generate_all_data(cdate)
        if df is None or df.empty:
            return False
        if self.mysql_client.set(df, table_name):
            self.redis.sadd(table_name, cdate)
            return True
        return False
class CLimit:
    """Crawls daily limit-up / limit-down lists from home.flashdata2.jrj.com.cn into MySQL.

    Rows go into ``ct.LIMIT_TABLE``; a redis set keyed by that table name records
    the dates already stored.
    """

    def __init__(self, dbinfo=ct.DB_INFO, redis_host=None):
        self.table = self.get_table_name()
        self.logger = getLogger(__name__)
        # Pass the host by keyword, matching the other create_redis_obj call sites.
        self.redis = create_redis_obj() if redis_host is None else create_redis_obj(host=redis_host)
        self.mysql_client = CMySQL(dbinfo, iredis=self.redis)
        self.header = {
            "Host": "home.flashdata2.jrj.com.cn",
            "Referer": "http://stock.jrj.com.cn/tzzs/zdtwdj/zdforce.shtml",
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/67.0.3396.99 Safari/537.36"
        }
        if not self.create():
            raise Exception("create stock %s failed" % self.table)

    def create(self):
        """Create the limit table unless the MySQL client already lists it; return success."""
        if self.table in self.mysql_client.get_all_tables():
            return True
        sql = ('create table if not exists %s('
               'date varchar(10) not null, '
               'code varchar(6) not null, '
               'price float, '
               'pchange float, '
               'prange float, '
               'concept varchar(50), '
               'fcb float, '
               'flb float, '
               'fdmoney float, '
               'first_time varchar(20), '
               'last_time varchar(20), '
               'open_times varchar(20), '
               'intensity float, '
               'PRIMARY KEY (date, code))' % self.table)
        return bool(self.mysql_client.create(sql, self.table))

    @staticmethod
    def get_table_name():
        return ct.LIMIT_TABLE

    def get_useful_columns(self, dtype):
        """Return the subset of crawled columns kept for ``dtype``."""
        if dtype == ct.LIMIT_UP or dtype == ct.LIMIT_DOWN:
            return ['代码', '所属概念']
        return [
            '代码', '价格', '涨跌幅', '振幅', '封成比', '封流比', '封单金额', '首次涨跌停时间',
            '最后涨跌停时间', '开板次数', '强度'
        ]

    def get_columns(self, dtype):
        """Return the full column list of the crawled JSON rows for ``dtype``."""
        # fcb (seal-to-turnover ratio) = sealed order amount / daily turnover
        # flb (seal-to-float ratio) = sealed order lots / floating share capital
        if ct.LIMIT_UP == dtype or ct.LIMIT_DOWN == dtype:
            return [
                '代码', '名称', '涨跌停时间', '价格', '涨跌幅', '成交额', '振幅', '换手率', '五日涨跌幅',
                '无用', '所属概念(代码)', '所属概念'
            ]
        return [
            '代码', '名称', '价格', '涨跌幅', '封成比', '封流比', '封单金额', '首次涨跌停时间',
            '最后涨跌停时间', '开板次数', '振幅', '强度'
        ]

    def get_url(self, dtype, date):
        """Build the request URL for ``dtype`` on ``date`` (``YYYYMMDD``), cache-busted by a ms timestamp."""
        if ct.LIMIT_UP == dtype:
            segment = "zt"
        elif ct.LIMIT_DOWN == dtype:
            segment = "dt"
        elif ct.LIMIT_UP_INTENSITY == dtype:
            segment = "ztForce"
        else:
            segment = "dtForce"
        return ct.LIMIT_URL_PRIFIX + "%s/%s" % (segment, date) + ct.LIMIT_URL_MID + str(
            int(round(time.time() * 1000)))

    def get_data_from_url(self, date, dtype, retry_times=5, timeout=60):
        """Return the response text, or None after ``retry_times`` failed attempts.

        Failures (exceptions and non-200 responses alike) back off for
        ``2 * attempt`` seconds. ``timeout`` bounds each request so a stalled
        connection cannot hang the crawler forever.
        """
        for i in range(retry_times):
            try:
                response = requests.get(self.get_url(dtype, date), headers=self.header, timeout=timeout)
                if response.status_code == 200:
                    return response.text
                self.logger.error("get limit data failed, dtype:%s, date:%s, status:%s" % (dtype, date, response.status_code))
            except Exception as e:
                self.logger.error(e)
            time.sleep(i * 2)
        return None

    def convert_to_json(self, content):
        """Extract and decode the ``"Data":...};`` payload; return None if absent or invalid."""
        if not content:
            return None
        result = re.findall(r'"Data":(.*)};', content, re.S)
        if not result:
            return None
        try:
            return json.loads(result[0])
        except ValueError as e:
            self.logger.info(e)
            return None

    def gen_df(self, dtype, cdate):
        """Return the merged limit list for ``dtype`` ("UP" or otherwise DOWN) on ``cdate``.

        The intensity list is left-joined on '代码' (stock code) with the concept
        column of the plain limit list; infinite values are replaced by 1000.
        Returns None when either download fails.
        """
        cdate = datetime.strptime(cdate, "%Y-%m-%d").strftime("%Y%m%d")
        table = ct.LIMIT_UP if dtype == "UP" else ct.LIMIT_DOWN
        data = self.get_data_from_url(cdate, table)
        if data is None:
            return None
        limit_df = pd.DataFrame(self.convert_to_json(data), columns=self.get_columns(table))
        limit_df = limit_df[self.get_useful_columns(table)]
        intense_table = ct.LIMIT_UP_INTENSITY if dtype == "UP" else ct.LIMIT_DOWN_INTENSITY
        data = self.get_data_from_url(cdate, intense_table)
        if data is None:
            return None
        intensity_df = pd.DataFrame(self.convert_to_json(data), columns=self.get_columns(intense_table))
        intensity_df = intensity_df[self.get_useful_columns(intense_table)]
        df = pd.merge(intensity_df, limit_df, how='left', on=['代码'])
        df.replace(np.inf, 1000, inplace=True)
        return df

    def get_data(self, cdate=None):
        """Return stored rows for ``cdate`` (defaults to today)."""
        if cdate is None:
            cdate = datetime.now().strftime('%Y-%m-%d')
        return self.mysql_client.get('select * from %s where date=\"%s\"' % (self.table, cdate))

    def crawl_data(self, cdate):
        """Crawl and store the up and down lists for ``cdate``; truthy on success."""
        if self.is_date_exists(self.table, cdate):
            self.logger.debug("existed table:%s, date:%s" % (self.table, cdate))
            return True
        df_up = self.gen_df("UP", cdate)
        df_down = self.gen_df("DOWN", cdate)
        if df_up is None or df_down is None:
            return False
        df = pd.concat([df_up, df_down])
        if df.empty:
            return False
        df = df.reset_index(drop=True)
        # Order follows gen_df: the intensity columns, then the merged concept column.
        df.columns = [
            'code', 'price', 'pchange', 'prange', 'fcb', 'flb', 'fdmoney', 'first_time',
            'last_time', 'open_times', 'intensity', 'concept'
        ]
        df['date'] = cdate
        if self.mysql_client.set(df, self.table):
            return self.redis.sadd(self.table, cdate)
        return False

    def is_date_exists(self, table_name, cdate):
        """Return True if redis records ``cdate`` as stored for ``table_name``."""
        if self.redis.exists(table_name):
            return cdate in set(tdate.decode() for tdate in self.redis.smembers(table_name))
        return False

    def update(self, end_date=None):
        """Crawl every trading day in the 205 days up to ``end_date`` (default today)."""
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')
        start_date = get_day_nday_ago(end_date, num=205, dformat="%Y-%m-%d")
        succeed = True
        for mdate in get_dates_array(start_date, end_date):
            if CCalendar.is_trading_day(mdate, redis=self.redis):
                if not self.crawl_data(mdate):
                    self.logger.error("%s set failed" % mdate)
                    succeed = False
        return succeed
class CStock(TickerHandlerBase):
    """Per-stock storage: daily k-line rows in MySQL, realtime ticks in InfluxDB and redis.

    Each stock gets its own database named ``s<code>``.
    """

    def __init__(self, dbinfo, code):
        self.code = code
        self.dbname = self.get_dbname(code)
        self.redis = create_redis_obj()
        self.name = self.get('name')
        self.data_type_dict = {9: "day"}
        self.influx_client = CInflux(ct.IN_DB_INFO, self.dbname)
        self.mysql_client = CMySQL(dbinfo, self.dbname)
        if not self.create():
            raise Exception("create stock %s table failed" % self.code)

    @staticmethod
    def get_dbname(code):
        return "s%s" % code

    @staticmethod
    def get_redis_name(code):
        return "realtime_%s" % code

    def on_recv_rsp(self, rsp_pb):
        '''Receive tick-by-tick data (get_rt_ticker push handled via TickerHandlerBase).'''
        ret, data = super(CStock, self).on_recv_rsp(rsp_pb)
        return ret, data

    def has_on_market(self, cdate):
        """Return True if the stock was listed strictly before ``cdate`` (``YYYY-MM-DD``).

        Returns False when the listing date is unknown (missing) or recorded as '0'.
        """
        time2Market = self.get('timeToMarket')
        if time2Market is None or str(time2Market) == '0':
            return False
        listed = datetime.strptime(str(time2Market), "%Y%m%d")
        target = datetime.strptime(cdate, "%Y-%m-%d")
        return (target - listed).days > 0

    def is_subnew(self, time2Market=None, timeLimit=365):
        """Return True if listed fewer than ``timeLimit`` days ago.

        ``time2Market`` is a ``YYYYMMDD`` value; when omitted it is read from the
        cached stock info. An unknown or '0' listing date yields False. The '0'
        check now also applies to the looked-up value, not only to the argument.
        """
        if time2Market is None:
            time2Market = self.get('timeToMarket')
        if time2Market is None or str(time2Market) == '0':
            return False
        listed = datetime.strptime(str(time2Market), "%Y%m%d")
        return (datetime.today() - listed).days < timeLimit

    def create(self):
        """Create the influx database and the MySQL day table; return MySQL success."""
        self.create_influx_db()
        return self.create_mysql_table()

    def create_influx_db(self):
        self.influx_client.create()

    def create_mysql_table(self):
        """Create every table named in ``data_type_dict`` that does not exist yet."""
        for table_name in self.data_type_dict.values():
            if table_name not in self.mysql_client.get_all_tables():
                sql = ('create table if not exists %s(cdate varchar(10) not null, open float, high float, '
                       'close float, low float, volume float, amount float, PRIMARY KEY(cdate))' % table_name)
                if not self.mysql_client.create(sql, table_name):
                    return False
        return True

    def create_ticket_table(self, table):
        """Create a tick table unless the MySQL client already lists it; return success."""
        sql = ('create table if not exists %s(date varchar(10) not null, '
               'ctime varchar(8) not null, '
               'price float(5,2), '
               'cchange varchar(10) not null, '
               'volume int not null, '
               'amount int not null, '
               'ctype varchar(9) not null, '
               'PRIMARY KEY (date, ctime, cchange, volume, amount, ctype))' % table)
        if table in self.mysql_client.get_all_tables():
            return True
        return self.mysql_client.create(sql, table)

    def get(self, attribute):
        """Return ``attribute`` of this stock from the pickled stock-info frame in redis, or None."""
        df_byte = self.redis.get(ct.STOCK_INFO)
        if df_byte is None:
            return None
        # NOTE(review): unpickles bytes read from redis; only safe if redis is trusted.
        df = _pickle.loads(df_byte)
        values = df.loc[df.code == self.code][attribute].values
        if len(values) == 0:
            return None
        return values[0]

    def run(self, data):
        """Cache the last row of ``data`` in redis and write ``data`` to InfluxDB."""
        self.redis.set(self.get_redis_name(self.code), _pickle.dumps(data.tail(1), 2))
        self.influx_client.set(data)

    def merge_ticket(self, df):
        """Collapse duplicated tick rows, multiplying volume/amount by the run length.

        NOTE(review): the run detection compares against the group's first index
        (``sindex + 1``) rather than the previous index, and the outer loop only
        terminates once ``drop_duplicates`` removes all subset-duplicates; rows
        that duplicate on the subset but differ in other columns may loop
        forever. Logic intentionally left unchanged — confirm intent.
        """
        subset = ['ctime', 'cchange', 'volume', 'amount', 'ctype']
        ex = df[df.duplicated(subset=subset, keep=False)]
        dlist = list(ex.index)
        while len(dlist) > 0:
            snum = 1
            sindex = dlist[0]
            for _index in range(1, len(dlist)):
                if sindex + 1 == dlist[_index]:
                    snum += 1
                    if _index == len(dlist) - 1:
                        df.drop_duplicates(keep='first', inplace=True)
                        df.at[sindex, 'volume'] = snum * df.loc[sindex]['volume']
                        df.at[sindex, 'amount'] = snum * df.loc[sindex]['amount']
                else:
                    df.drop_duplicates(keep='first', inplace=True)
                    df.at[sindex, 'volume'] = snum * df.loc[sindex]['volume']
                    df.at[sindex, 'amount'] = snum * df.loc[sindex]['amount']
                    sindex = dlist[_index]
                    snum = 1
            df = df.reset_index(drop=True)
            ex = df[df.duplicated(subset=subset, keep=False)]
            dlist = list(ex.index)
        return df

    def get_market(self):
        """Classify the code as SH, SZ or other by its prefix."""
        if self.code.startswith(("6", "500", "550", "510", "7")):
            return ct.MARKET_SH
        if self.code.startswith(("00", "30", "150", "159")):
            return ct.MARKET_SZ
        return ct.MARKET_OTHER

    def get_redis_tick_table(self, cdate):
        """Return the monthly tick table name ``tick_<code>_<year>_<month>``."""
        cdates = cdate.split('-')
        return "tick_%s_%s_%s" % (self.code, cdates[0], cdates[1])

    def is_tick_table_exists(self, tick_table):
        """Return True if redis records ``tick_table`` in this stock's database set."""
        if self.redis.exists(self.dbname):
            return tick_table in set(str(table, encoding="utf8") for table in self.redis.smembers(self.dbname))
        return False

    def is_date_exists(self, tick_table, cdate):
        """Return True if redis records ``cdate`` as stored in ``tick_table``."""
        if self.redis.exists(tick_table):
            return cdate in set(str(tdate, encoding="utf8") for tdate in self.redis.smembers(tick_table))
        return False

    def set_k_data(self):
        """Load the TDX daily CSV for this stock and REPLACE it into the ``day`` table."""
        prestr = "1" if self.get_market() == ct.MARKET_SH else "0"
        filename = "%s%s.csv" % (prestr, self.code)
        df = pd.read_csv("/data/tdx/history/days/%s" % filename, sep=',')
        df = df[['date', 'open', 'high', 'close', 'low', 'amount', 'volume']]
        df['date'] = df['date'].astype(str)
        df['date'] = pd.to_datetime(df.date).dt.strftime("%Y-%m-%d")
        df = df.rename(columns={'date': 'cdate'})
        df = df.reset_index(drop=True)
        return self.mysql_client.set(df, 'day', method=ct.REPLACE)

    def set_ticket(self, cdate=None):
        """Store tick data for ``cdate`` (default today), preferring the local TDX file.

        When both sources have data their total volumes must match, otherwise an
        Exception is raised.
        """
        cdate = datetime.now().strftime('%Y-%m-%d') if cdate is None else cdate
        if not self.has_on_market(cdate):
            logger.debug("not on market code:%s, date:%s" % (self.code, cdate))
            return
        tick_table = self.get_redis_tick_table(cdate)
        if not self.is_tick_table_exists(tick_table):
            if not self.create_ticket_table(tick_table):
                logger.error("create tick table failed")
                return
        if self.is_date_exists(tick_table, cdate):
            logger.debug("existed code:%s, date:%s" % (self.code, cdate))
            return
        df = ts.get_tick_data(self.code, date=cdate)
        tic_file = '%s.tic' % datetime.strptime(cdate, "%Y-%m-%d").strftime("%Y%m%d")
        df_tdx = read_tick(os.path.join(ct.TIC_DIR, tic_file), self.code)
        if not df_tdx.empty:
            # "当天没有数据" is the tushare marker for "no data for this day".
            if df is not None and not df.empty and df.loc[0]['time'].find("当天没有数据") == -1:
                net_volume = df.volume.sum()
                tdx_volume = df_tdx.volume.sum()
                if net_volume != tdx_volume:
                    raise Exception("code:%s, date:%s, net volume:%s, tdx volume:%s not equal" %
                                    (self.code, cdate, net_volume, tdx_volume))
            df = df_tdx
        else:
            if df is None:
                logger.debug("nonedata code:%s, date:%s" % (self.code, cdate))
                return
            if df.empty:
                logger.debug("emptydata code:%s, date:%s" % (self.code, cdate))
                return
            if df.loc[0]['time'].find("当天没有数据") != -1:
                logger.debug("nodata code:%s, date:%s" % (self.code, cdate))
                return
        df.columns = ['ctime', 'price', 'cchange', 'volume', 'amount', 'ctype']
        logger.debug("code:%s date:%s" % (self.code, cdate))
        df = self.merge_ticket(df)
        df['date'] = cdate
        logger.debug("write data code:%s, date:%s, table:%s" % (self.code, cdate, tick_table))
        if self.mysql_client.set(df, tick_table):
            logger.info("start record:%s. table:%s" % (self.code, tick_table))
            self.redis.sadd(tick_table, cdate)

    def get_k_data(self, date=None, dtype=9):
        """Return rows of the ``dtype`` table for ``date``, or all rows when None.

        The day table's date column is ``cdate`` (see create_mysql_table and
        set_k_data); the original query filtered on a non-existent ``date`` column.
        """
        table_name = self.data_type_dict[dtype]
        if date is not None:
            sql = "select * from %s where cdate=\"%s\"" % (table_name, date)
        else:
            sql = "select * from %s" % table_name
        return self.mysql_client.get(sql)

    def is_after_release(self, code_id, _date):
        """Return True if ``_date`` (``YYYY-MM-DD``) is strictly after the listing date.

        ``code_id`` is unused; kept for interface compatibility.
        """
        time2Market = datetime.strptime(str(self.get('timeToMarket')), "%Y%m%d")
        return (datetime.strptime(_date, "%Y-%m-%d") - time2Market).days > 0
class RIndexIndustryInfo:
    """Daily k-line data of every industry index, stored in per-quarter MySQL tables.

    Tables are named ``rindustry_day_<year>_<quarter>``. Redis sets cache the
    created tables (keyed by db name) and stored dates (keyed by table name).
    """

    def __init__(self, dbinfo=ct.DB_INFO, redis_host=None):
        self.redis = create_redis_obj() if redis_host is None else create_redis_obj(host=redis_host)
        self.dbname = self.get_dbname()
        self.logger = getLogger(__name__)
        self.mysql_client = CMySQL(dbinfo, self.dbname, iredis=self.redis)
        if not self.mysql_client.create_db(self.get_dbname()):
            raise Exception("init rindex stock database failed")

    @staticmethod
    def get_dbname():
        return ct.RINDEX_INDUSTRY_INFO_DB

    def get_table_name(self, cdate):
        """Return the quarterly table name for a ``YYYY-MM-DD`` date string."""
        cdates = cdate.split('-')
        return "rindustry_day_%s_%s" % (cdates[0], (int(cdates[1]) - 1) // 3 + 1)

    def is_date_exists(self, table_name, cdate):
        """Return True if redis records ``cdate`` as stored in ``table_name``."""
        if self.redis.exists(table_name):
            return cdate in set(str(tdate, encoding="utf8") for tdate in self.redis.smembers(table_name))
        return False

    def is_table_exists(self, table_name):
        """Return True if redis records ``table_name`` as created in this database."""
        if self.redis.exists(self.dbname):
            return table_name in set(str(table, encoding="utf8") for table in self.redis.smembers(self.dbname))
        return False

    def create_table(self, table):
        """Create a quarterly table unless the MySQL client already lists it; return success."""
        sql = ('create table if not exists %s('
               'date varchar(10) not null, '
               'code varchar(10) not null, '
               'open float, '
               'high float, '
               'close float, '
               'preclose float, '
               'low float, '
               'volume bigint, '
               'amount float, '
               'preamount float, '
               'pchange float, '
               'mchange float, '
               'PRIMARY KEY (date, code))' % table)
        if table in self.mysql_client.get_all_tables():
            return True
        return self.mysql_client.create(sql, table)

    def get_k_data_in_range(self, start_date, end_date):
        """Return stored rows for trading days starting at ``start_date``.

        ``delta_days(start_date, end_date)`` calendar days are scanned and grouped
        by quarterly table. The original called a non-existent ``self.get_data``;
        single-day tables are now read with ``get_k_data``.
        """
        ndays = delta_days(start_date, end_date)
        date_dmy_format = time.strftime("%m/%d/%Y", time.strptime(start_date, "%Y-%m-%d"))
        data_times = pd.date_range(date_dmy_format, periods=ndays, freq='D')
        date_only_array = [d.strftime('%Y-%m-%d') for d in data_times]
        data_dict = OrderedDict()
        for _date in date_only_array:
            if CCalendar.is_trading_day(_date, redis=self.redis):
                data_dict.setdefault(self.get_table_name(_date), list()).append(_date)
        frames = list()
        for table_dates in data_dict.values():
            table_dates = sorted(table_dates)
            if len(table_dates) == 1:
                df = self.get_k_data(table_dates[0])
            else:
                df = self.get_data_between(table_dates[0], table_dates[-1])
            if df is not None:
                frames.append(df)
        # DataFrame.append was removed in pandas 2.0; concatenate once instead.
        return pd.concat(frames) if frames else pd.DataFrame()

    def get_data_between(self, start_date, end_date):
        """Return rows between two dates; both dates must map to the same quarterly table."""
        sql = "select * from %s where date between \"%s\" and \"%s\"" % (
            self.get_table_name(start_date), start_date, end_date)
        return self.mysql_client.get(sql)

    def get_k_data(self, cdate=None):
        """Return rows for ``cdate`` (defaults to today, evaluated at call time)."""
        cdate = datetime.now().strftime('%Y-%m-%d') if cdate is None else cdate
        sql = "select * from %s where date=\"%s\"" % (self.get_table_name(cdate), cdate)
        return self.mysql_client.get(sql)

    def get_industry_data(self, cdate, code):
        """Return ``(code, k-line data of industry index code for cdate)``."""
        return (code, CIndex(code).get_k_data(cdate))

    def generate_data(self, cdate):
        """Fetch every industry index for ``cdate`` with a gevent pool.

        Failed codes are retried every 10 s; after more than 10 rounds with a
        failure an empty DataFrame is returned. The pool is always shut down.
        """
        obj_pool = Pool(500)
        frames = list()
        failed_list = IndustryInfo.get(self.redis).code.tolist()
        cfunc = partial(self.get_industry_data, cdate)
        failed_count = 0
        try:
            while len(failed_list) > 0:
                is_failed = False
                self.logger.debug("restart failed ip len(%s)" % len(failed_list))
                # Iterate a snapshot: failed_list shrinks as results arrive.
                for code, tem_df in obj_pool.imap_unordered(cfunc, list(failed_list)):
                    if tem_df is not None:
                        tem_df['code'] = code
                        frames.append(tem_df)
                        failed_list.remove(code)
                    else:
                        is_failed = True
                if is_failed:
                    failed_count += 1
                    if failed_count > 10:
                        self.logger.info("%s rindustry init failed" % failed_list)
                        return pd.DataFrame()
                    time.sleep(10)
        finally:
            obj_pool.join(timeout=5)
            obj_pool.kill()
        self.mysql_client.changedb(self.get_dbname())
        if not frames:
            return pd.DataFrame()
        all_df = pd.concat(frames)
        if all_df.empty:
            return all_df
        return all_df.reset_index(drop=True)

    def set_data(self, cdate=None):
        """Crawl and store ``cdate`` (default today); truthy on success or if already stored."""
        if cdate is None:
            cdate = datetime.now().strftime('%Y-%m-%d')
        if not CCalendar.is_trading_day(cdate, redis=self.redis):
            return False
        table_name = self.get_table_name(cdate)
        if not self.is_table_exists(table_name):
            if not self.create_table(table_name):
                self.logger.error("create rindex table failed")
                return False
            self.redis.sadd(self.dbname, table_name)
        if self.is_date_exists(table_name, cdate):
            self.logger.debug("existed rindex table:%s, date:%s" % (table_name, cdate))
            return True
        df = self.generate_data(cdate)
        if df.empty:
            return False
        self.redis.set(ct.TODAY_ALL_INDUSTRY, _pickle.dumps(df, 2))
        if self.mysql_client.set(df, table_name):
            return self.redis.sadd(table_name, cdate)
        return False

    def update(self, end_date=None, num=10):
        """Store every trading day in the ``num`` days up to ``end_date`` (default today)."""
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')
        start_date = get_day_nday_ago(end_date, num=num, dformat="%Y-%m-%d")
        succeed = True
        for mdate in get_dates_array(start_date, end_date):
            if CCalendar.is_trading_day(mdate, redis=self.redis):
                if not self.set_data(mdate):
                    self.logger.error("%s rindustry set failed" % mdate)
                    succeed = False
        return succeed
class StockExchange(object):
    """Crawls and stores daily market-overview statistics for the SH or SZ exchange.

    Rows are stored in ``<market>_deal`` inside a database named after the market
    symbol; redis sets cache the created table and the stored dates.
    """

    def __init__(self, market=ct.SH_MARKET_SYMBOL, dbinfo=ct.DB_INFO, redis_host=None):
        self.market = market
        self.dbinfo = dbinfo
        # Dates skipped by update(): one hard-coded date for SH, none for SZ.
        self.balcklist = ['2006-07-10'] if market == ct.SH_MARKET_SYMBOL else list()
        self.logger = getLogger(__name__)
        self.dbname = self.get_dbname(market)
        self.redis = create_redis_obj() if redis_host is None else create_redis_obj(host=redis_host)
        self.header = {"Host": "query.sse.com.cn",
                       "Referer": "http://www.sse.com.cn/market/stockdata/overview/day/",
                       "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_13_6) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/70.0.3538.77 Safari/537.36"}
        self.mysql_client = CMySQL(self.dbinfo, dbname=self.dbname, iredis=self.redis)
        if not self.mysql_client.create_db(self.dbname):
            raise Exception("create %s failed" % self.dbname)

    @staticmethod
    def get_dbname(market):
        return market

    def get_table_name(self):
        return "%s_deal" % self.dbname

    def create_table(self):
        """Create the deal table unless the MySQL client already lists it; return success."""
        table = self.get_table_name()
        if table not in self.mysql_client.get_all_tables():
            sql = ('create table if not exists %s('
                   'date varchar(10) not null, '
                   'name varchar(20) not null, '
                   'amount float, '
                   'number int, '
                   'negotiable_value float, '
                   'market_value float, '
                   'pe float, '
                   'totals float, '
                   'outstanding float, '
                   'volume float, '
                   'transactions float, '
                   'turnover float, '
                   'PRIMARY KEY (date, name))' % table)
            if not self.mysql_client.create(sql, table):
                return False
        return True

    def get_k_data_in_range(self, start_date, end_date):
        sql = "select * from %s where date between \"%s\" and \"%s\"" % (self.get_table_name(), start_date, end_date)
        return self.mysql_client.get(sql)

    def get_k_data(self, cdate=None):
        """Return stored rows for ``cdate`` (defaults to today, evaluated at call time)."""
        if cdate is None:
            cdate = datetime.now().strftime('%Y-%m-%d')
        sql = "select * from %s where date=\"%s\"" % (self.get_table_name(), cdate)
        return self.mysql_client.get(sql)

    def is_table_exists(self, table_name):
        if self.redis.exists(self.dbname):
            return table_name in set(str(table, encoding="utf8") for table in self.redis.smembers(self.dbname))
        return False

    def is_date_exists(self, table_name, cdate):
        if self.redis.exists(table_name):
            return cdate in set(str(tdate, encoding="utf8") for tdate in self.redis.smembers(table_name))
        return False

    def get_url(self):
        """Return the %-format URL template for this market's overview endpoint."""
        if self.market == ct.SH_MARKET_SYMBOL:
            return "http://query.sse.com.cn/marketdata/tradedata/queryTradingByProdTypeData.do?jsonCallBack=jsonpCallback%s&searchDate=%s&prodType=gp&_=%s"
        return "http://www.szse.cn/api/report/ShowReport?SHOWTYPE=xlsx&CATALOGID=1803&TABKEY=%s&txtQueryDate=%s&random=%s"

    def get_sh_type_name(self, dtype):
        """Map an SSE ``productType`` code to its board name, or None if unknown."""
        return {'1': "A股", '2': "B股", '12': "上海市场", '90': "科创板"}.get(dtype)

    def get_data_from_url(self, cdate=None):
        """Crawl the overview for ``cdate`` (default today); empty DataFrame on failure."""
        if cdate is None:
            cdate = datetime.now().strftime('%Y-%m-%d')
        if self.market == ct.SH_MARKET_SYMBOL:
            return self._get_sh_data(cdate)
        return self._get_sz_data(cdate)

    @staticmethod
    def _sh_number(value, cast):
        """Convert an SSE JSON field, treating '' as 0."""
        return 0 if value == '' else cast(value)

    def _get_sh_data(self, cdate):
        """Fetch and parse the SSE overview JSONP for ``cdate``."""
        url = self.get_url() % (int_random(5), cdate, int(round(time.time() * 1000)))
        response = smart_get(requests.get, url, headers=self.header)
        if response is None:
            self.logger.error("get exchange data failed, no response")
            return pd.DataFrame()
        if response.status_code != 200:
            self.logger.error("get exchange data failed, response code:%s" % response.status_code)
            return pd.DataFrame()
        json_result = loads_jsonp(response.text)
        if json_result is None:
            self.logger.error("parse exchange data jsonp failed")
            return pd.DataFrame()
        datas = list()
        for json_obj in json_result['result']:
            name = self.get_sh_type_name(json_obj['productType'])
            if name is None:
                self.logger.error("get unknown type for SH data:%s" % json_obj['productType'])
                return pd.DataFrame()
            if name == "科创板":  # STAR Market rows are skipped
                continue
            volume = self._sh_number(json_obj['trdVol'], float)
            turnover = self._sh_number(json_obj['exchangeRate'], float)
            outstanding = 0 if turnover == 0 else volume / (100 * turnover)
            data = {'amount': self._sh_number(json_obj['trdAmt'], float),
                    'number': self._sh_number(json_obj['istVol'], int),
                    'negotiable_value': self._sh_number(json_obj['negotiableValue'], float),
                    'market_value': self._sh_number(json_obj['marketValue'], float),
                    'pe': self._sh_number(json_obj['profitRate'], float),
                    # No separate total-share field is read; totals mirrors outstanding.
                    'totals': outstanding,
                    'outstanding': outstanding,
                    'volume': volume,
                    'transactions': self._sh_number(json_obj['trdTm'], float),
                    'turnover': turnover}
            if any(data.values()):
                data['name'] = name
                data['date'] = cdate
                datas.append(data)
        return pd.DataFrame.from_dict(datas)

    @staticmethod
    def _sz_value(df, label):
        """Return the '本日数值' (today's value) for indicator ``label`` as float, stripping commas."""
        value = df.loc[df['指标名称'] == label, '本日数值'].values[0]
        if isinstance(value, str):
            value = value.replace(',', '')
        return float(value)

    def _get_sz_data(self, cdate):
        """Fetch and parse the SZSE overview spreadsheets for ``cdate``."""
        datas = list()
        for name, tab in ct.SZ_MARKET_DICT.items():
            url = self.get_url() % (tab, cdate, float_random(17))
            df = smart_get(pd.read_excel, url, usecols=[0, 1])
            if df is None:
                return pd.DataFrame()
            if df.empty:
                continue
            if len(df) == 1 and df.values[0][0] == '没有找到符合条件的数据!':
                continue
            value = partial(self._sz_value, df)
            if name == "深圳市场":
                # Whole-market sheet: amount/volume/transactions are filled in below
                # from the sub-board rows.
                amount = 0
                number = int(value('上市公司数'))
                negotiable_value = value('股票流通市值(元)') / 100000000
                market_value = value('股票总市值(元)') / 100000000
                pe = value('股票平均市盈率')
                totals = value('股票总股本(股)') / 100000000
                outstanding = value('股票流通股本(股)') / 100000000
                volume = 0
                transactions = 0
                turnover = value('股票平均换手率')
            else:
                amount = value('总成交金额(元)') / 100000000
                number = int(value('上市公司数'))
                negotiable_value = value('上市公司流通市值(元)') / 100000000
                market_value = value('上市公司市价总值(元)') / 100000000
                pe = value('平均市盈率(倍)')
                totals = value('总发行股本(股)') / 100000000
                outstanding = value('总流通股本(股)') / 100000000
                volume = value('总成交股数') / 100000000
                transactions = value('总成交笔数') / 10000
                turnover = 100 * volume / outstanding if outstanding else 0
            datas.append({'name': name,
                          'date': cdate,
                          'amount': amount,
                          'number': number,
                          'negotiable_value': negotiable_value,
                          'market_value': market_value,
                          'pe': pe,
                          'totals': totals,
                          'outstanding': outstanding,
                          'volume': volume,
                          'transactions': transactions,
                          'turnover': turnover})
        df = pd.DataFrame.from_dict(datas)
        if not df.empty:
            # The whole-market row was stored with 0 for these columns, so
            # column sum minus its own value equals the sum of the other rows.
            # .loc is required for boolean masks (.at only accepts scalar labels).
            mask = df['name'] == "深圳市场"
            for column in ('amount', 'volume', 'transactions'):
                df.loc[mask, column] = df[column].sum() - df.loc[mask, column]
        return df

    def set_k_data(self, cdate=None):
        """Crawl and store ``cdate`` (default today); True if stored or already present."""
        if cdate is None:
            cdate = datetime.now().strftime('%Y-%m-%d')
        table_name = self.get_table_name()
        if not self.is_table_exists(table_name):
            if not self.create_table():
                self.logger.error("create tick table failed")
                return False
            self.redis.sadd(self.dbname, table_name)
        if self.is_date_exists(table_name, cdate):
            self.logger.debug("existed table:%s, date:%s" % (table_name, cdate))
            return True
        df = self.get_data_from_url(cdate)
        if df.empty:
            self.logger.debug("get data from %s failed, date:%s" % (self.market, cdate))
            return False
        if self.mysql_client.set(df, table_name):
            self.redis.sadd(table_name, cdate)
            return True
        return False

    def update(self, end_date=None, num=10):
        """Store every trading day in the ``num`` days up to ``end_date``.

        ``end_date`` defaults to today; if it equals today it is moved back one
        day. Dates listed in ``balcklist`` are skipped.
        """
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')
        if end_date == datetime.now().strftime('%Y-%m-%d'):
            end_date = get_day_nday_ago(end_date, num=1, dformat="%Y-%m-%d")
        start_date = get_day_nday_ago(end_date, num=num, dformat="%Y-%m-%d")
        succeed = True
        for mdate in get_dates_array(start_date, end_date):
            if mdate in self.balcklist:
                continue
            if CCalendar.is_trading_day(mdate, redis=self.redis):
                if not self.set_k_data(mdate):
                    succeed = False
                    self.logger.info("market %s for %s set failed" % (self.market, mdate))
        return succeed
class Margin(object):
    """Daily margin data stored in per-quarter MySQL tables.

    Rows come from the client returned by ``get_tushare_client()``
    (``margin`` for exchange-level totals, ``margin_detail`` for per-stock
    rows). They are written to tables named ``margin_day_<year>_<quarter>``.
    Two kinds of redis sets act as caches:

    * the set keyed by the database name holds the names of created tables;
    * the set keyed by each table name holds the dates already stored in it.
    """

    def __init__(self, dbinfo=ct.DB_INFO, redis_host=None):
        self.logger = getLogger(__name__)
        self.crawler = get_tushare_client()
        self.dbname = self.get_dbname()
        self.redis = create_redis_obj() if redis_host is None else create_redis_obj(host=redis_host)
        self.mysql_client = CMySQL(self.dbname and dbinfo, self.dbname, iredis=self.redis) if False else CMySQL(dbinfo, self.dbname, iredis=self.redis)
        if not self.mysql_client.create_db(self.dbname):
            raise Exception("init margin database failed")

    @staticmethod
    def get_dbname():
        return "margin"

    def get_table_name(self, cdate):
        """Return the quarterly table name that stores ``cdate`` ('YYYY-MM-DD')."""
        cdates = cdate.split('-')
        quarter = (int(cdates[1]) - 1) // 3 + 1
        return "%s_day_%s_%s" % (self.dbname, cdates[0], quarter)

    def is_date_exists(self, table_name, cdate):
        """Return whether redis records ``cdate`` as already stored in ``table_name``."""
        if self.redis.exists(table_name):
            return self.redis.sismember(table_name, cdate)
        return False

    def create_table(self, table):
        """Create ``table`` unless MySQL already lists it; return truthy on success."""
        sql = ('create table if not exists %s(date varchar(10) not null,'
               ' code varchar(10) not null,'
               ' rzye float,'
               ' rzmre float,'
               ' rzche float,'
               ' rqye float,'
               ' rqyl float,'
               ' rqmcl float,'
               ' rqchl float,'
               ' rzrqye float,'
               ' PRIMARY KEY (date, code))' % table)
        if table in self.mysql_client.get_all_tables():
            return True
        return self.mysql_client.create(sql, table)

    def get_k_data_in_range(self, start_date, end_date):
        """Return all stored rows for trading days from ``start_date`` onwards.

        The number of days scanned is ``delta_days(start_date, end_date)``.
        Dates are grouped by quarterly table so each table is queried once.
        Returns an empty DataFrame when nothing is found.
        """
        ndays = delta_days(start_date, end_date)
        # DatetimeIndex.strftime also handles a zero-length range, unlike
        # np.vectorize, which raises on empty input.
        all_dates = pd.date_range(start=start_date, periods=ndays, freq='D').strftime('%Y-%m-%d')
        dates_by_table = OrderedDict()
        for _date in all_dates:
            if CCalendar.is_trading_day(_date, redis=self.redis):
                dates_by_table.setdefault(self.get_table_name(_date), list()).append(str(_date))
        frames = list()
        for table_dates in dates_by_table.values():
            table_dates = sorted(table_dates)
            if len(table_dates) == 1:
                df = self.get_data(table_dates[0])
            else:
                df = self.get_data_between(table_dates[0], table_dates[-1])
            if df is not None:
                frames.append(df)
        # DataFrame.append was removed in pandas 2.0; concat is the replacement.
        return pd.concat(frames, sort=False) if frames else pd.DataFrame()

    def get_data_between(self, start_date, end_date):
        """Return rows between two dates; both dates must map to the same table."""
        table_name = self.get_table_name(start_date)
        if not self.is_table_exists(table_name):
            return None
        sql = "select * from %s where date between \"%s\" and \"%s\"" % (table_name, start_date, end_date)
        return self.mysql_client.get(sql)

    def get_data(self, cdate=None):
        """Return the rows stored for ``cdate`` (defaults to today, evaluated per call)."""
        if cdate is None:
            cdate = datetime.now().strftime('%Y-%m-%d')
        sql = "select * from %s where date=\"%s\"" % (self.get_table_name(cdate), cdate)
        return self.mysql_client.get(sql)

    def update(self, end_date=None, num=10):
        """Fetch and store margin data for trading days in the last ``num`` days.

        ``end_date`` itself is skipped. Returns False if any day failed.
        """
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')
        start_date = get_day_nday_ago(end_date, num=num, dformat="%Y-%m-%d")
        succeed = True
        for mdate in get_dates_array(start_date, end_date):
            if not CCalendar.is_trading_day(mdate, redis=self.redis):
                continue
            if mdate == end_date:
                continue
            if not self.set_data(mdate):
                self.logger.error("%s set failed" % mdate)
                succeed = False
        return succeed

    def is_table_exists(self, table_name):
        """Return whether redis records ``table_name`` as created."""
        if self.redis.exists(self.dbname):
            return self.redis.sismember(self.dbname, table_name)
        return False

    def set_data(self, cdate=None):
        """Fetch margin totals and details for ``cdate`` and store them.

        ``cdate`` defaults to today, evaluated per call rather than at import
        time. Returns True if the date was already stored, the redis ``sadd``
        result after a successful write, and False on any failure.
        """
        if cdate is None:
            cdate = datetime.now().strftime('%Y-%m-%d')
        table_name = self.get_table_name(cdate)
        if not self.is_table_exists(table_name):
            if not self.create_table(table_name):
                self.logger.error("create tick table failed")
                return False
            self.redis.sadd(self.dbname, table_name)
        if self.is_date_exists(table_name, cdate):
            self.logger.debug("existed table:%s, date:%s" % (table_name, cdate))
            return True
        trade_date = transfer_date_string_to_int(cdate)
        total_df = smart_get(self.crawler.margin, trade_date=trade_date)
        if total_df is None:
            self.logger.error("crawel margin for %s failed" % cdate)
            return False
        total_df = total_df.rename(columns={"trade_date": "date", "exchange_id": "code"})
        # The totals frame gets these two columns filled with 0.
        # NOTE(review): presumably the totals endpoint does not provide them.
        total_df['rqyl'] = 0
        total_df['rqchl'] = 0
        detail_df = smart_get(self.crawler.margin_detail, trade_date=trade_date)
        if detail_df is None:
            self.logger.error("crawel detail margin for %s failed" % cdate)
            return False
        detail_df = detail_df.rename(columns={"trade_date": "date", "ts_code": "code"})
        total_df = pd.concat([total_df, detail_df], sort=False)
        total_df['date'] = pd.to_datetime(total_df.date).dt.strftime("%Y-%m-%d")
        total_df = total_df.reset_index(drop=True)
        if self.mysql_client.set(total_df, table_name):
            time.sleep(1)
            return self.redis.sadd(table_name, cdate)
        return False
class StockConnect(object):
    """Daily cross-market data (database ``<from>2<to>``) crawled by ``MCrawl``.

    ``set_market`` must be called before any other method. It sets the
    database name, the crawler, the MySQL client and the date blacklist.
    Rows go into quarterly tables named ``<dbname>_stock_day_<year>_<quarter>``.
    Created table names and stored dates are cached in redis sets.
    """

    def __init__(self, market_from=ct.SH_MARKET_SYMBOL, market_to=ct.HK_MARKET_SYMBOL, dbinfo=ct.DB_INFO, redis_host=None):
        self.market_from = market_from
        self.market_to = market_to
        self.balcklist = None
        self.crawler = None
        self.mysql_client = None
        self.dbinfo = dbinfo
        self.logger = getLogger(__name__)
        self.redis = create_redis_obj() if redis_host is None else create_redis_obj(host=redis_host)

    def set_market(self, market_from, market_to):
        """Bind to a market pair and create its database; return False on failure."""
        self.market_from = market_from
        self.market_to = market_to
        # Dates that update() always skips for SH/SZ source markets.
        # The reason for each date is not recorded in the code.
        if market_from in [ct.SH_MARKET_SYMBOL, ct.SZ_MARKET_SYMBOL]:
            self.balcklist = ["2018-10-17", "2018-09-25", "2018-07-02", "2018-05-22", "2018-04-02", "2018-03-30"]
        else:
            self.balcklist = list()
        self.dbname = self.get_dbname(market_from, market_to)
        self.crawler = MCrawl(market_from)
        self.mysql_client = CMySQL(self.dbinfo, self.dbname, iredis=self.redis)
        return bool(self.mysql_client.create_db(self.dbname))

    def quit(self):
        self.crawler.quit()

    def close(self):
        self.crawler.close()

    @staticmethod
    def get_dbname(mfrom, mto):
        return "%s2%s" % (mfrom, mto)

    def get_table_name(self, cdate):
        """Return the quarterly table name that stores ``cdate`` ('YYYY-MM-DD')."""
        cdates = cdate.split('-')
        quarter = (int(cdates[1]) - 1) // 3 + 1
        return "%s_stock_day_%s_%s" % (self.dbname, cdates[0], quarter)

    def is_date_exists(self, table_name, cdate):
        """Return whether redis records ``cdate`` as stored in ``table_name``."""
        if self.redis.exists(table_name):
            # Single membership test instead of pulling the whole set over the wire.
            return bool(self.redis.sismember(table_name, cdate))
        return False

    def create_table(self, table):
        """Create ``table`` unless MySQL already lists it; return truthy on success."""
        sql = ('create table if not exists %s(date varchar(10) not null,'
               ' code varchar(10) not null,'
               ' name varchar(90),'
               ' volume int,'
               ' percent float,'
               ' PRIMARY KEY (date, code))' % table)
        if table in self.mysql_client.get_all_tables():
            return True
        return self.mysql_client.create(sql, table)

    def get_k_data_in_range(self, start_date, end_date):
        """Return stored rows for trading days from ``start_date`` onwards.

        The number of days scanned is ``delta_days(start_date, end_date)``.
        Returns an empty DataFrame when nothing is found.
        """
        ndays = delta_days(start_date, end_date)
        # DatetimeIndex.strftime also handles a zero-length range (np.vectorize raises).
        all_dates = pd.date_range(start=start_date, periods=ndays, freq='D').strftime('%Y-%m-%d')
        dates_by_table = OrderedDict()
        for _date in all_dates:
            if CCalendar.is_trading_day(_date, redis=self.redis):
                dates_by_table.setdefault(self.get_table_name(_date), list()).append(str(_date))
        frames = list()
        for table_dates in dates_by_table.values():
            table_dates = sorted(table_dates)
            if len(table_dates) == 1:
                df = self.get_k_data(table_dates[0])
            else:
                df = self.get_data_between(table_dates[0], table_dates[-1])
            if df is not None:
                frames.append(df)
        # DataFrame.append was removed in pandas 2.0; concat is the replacement.
        return pd.concat(frames, sort=False) if frames else pd.DataFrame()

    def get_data_between(self, start_date, end_date):
        """Return rows between two dates; both dates must map to the same table."""
        sql = "select * from %s where date between \"%s\" and \"%s\"" % (self.get_table_name(start_date), start_date, end_date)
        return self.mysql_client.get(sql)

    def get_k_data(self, cdate=None):
        """Return the rows stored for ``cdate`` (defaults to today, evaluated per call)."""
        if cdate is None:
            cdate = datetime.now().strftime('%Y-%m-%d')
        sql = "select * from %s where date=\"%s\"" % (self.get_table_name(cdate), cdate)
        return self.mysql_client.get(sql)

    def update(self, end_date=None, num=10):
        """Crawl and store trading days in the last ``num`` days.

        ``end_date`` and blacklisted dates are skipped. Returns False if any
        day failed.
        """
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')
        start_date = get_day_nday_ago(end_date, num=num, dformat="%Y-%m-%d")
        succeed = True
        for mdate in get_dates_array(start_date, end_date):
            if not CCalendar.is_trading_day(mdate, redis=self.redis):
                continue
            if mdate == end_date or mdate in self.balcklist:
                continue
            if not self.set_data(mdate):
                succeed = False
        return succeed

    def is_table_exists(self, table_name):
        """Return whether redis records ``table_name`` as created."""
        if self.redis.exists(self.dbname):
            return bool(self.redis.sismember(self.dbname, table_name))
        return False

    def set_data(self, cdate=None):
        """Crawl ``cdate`` (defaults to today, evaluated per call) and store it.

        Returns True when the date is already stored or the crawl yields no
        rows, the redis ``sadd`` result after a write, and False on failure.
        """
        if cdate is None:
            cdate = datetime.now().strftime('%Y-%m-%d')
        table_name = self.get_table_name(cdate)
        if not self.is_table_exists(table_name):
            if not self.create_table(table_name):
                self.logger.error("create tick table failed")
                return False
            self.redis.sadd(self.dbname, table_name)
        if self.is_date_exists(table_name, cdate):
            self.logger.debug("existed table:%s, date:%s" % (table_name, cdate))
            return True
        ret, df = self.crawler.crawl(cdate)
        if ret != 0:
            return False
        if df.empty:
            return True
        df = df.reset_index(drop=True)
        df['date'] = cdate
        if self.mysql_client.set(df, table_name):
            return self.redis.sadd(table_name, cdate)
        return False
class CReivew:
    """Collect market, margin, index and industry data for a daily review document.

    ``update`` gathers the data and hands it to ``CDoc.generate``.
    """

    def __init__(self, dbinfo=ct.DB_INFO, redis_host=None):
        self.dbinfo = dbinfo
        self.logger = getLogger(__name__)
        self.tu_client = get_tushare_client()
        self.doc = CDoc()
        self.redis = create_redis_obj() if redis_host is None else create_redis_obj(redis_host)
        self.mysql_client = CMySQL(dbinfo, iredis=self.redis)
        self.margin_client = Margin(dbinfo=dbinfo, redis_host=redis_host)
        self.rstock_client = RIndexStock(dbinfo=dbinfo, redis_host=redis_host)
        self.sh_market_client = StockExchange(ct.SH_MARKET_SYMBOL)
        self.sz_market_client = StockExchange(ct.SZ_MARKET_SYMBOL)
        self.emotion_client = Emotion()

    def get_industry_data(self, cdate):
        """Return industry rows for ``cdate``, sorted by amount, with money flow and industry info."""
        df = RIndexIndustryInfo().get_k_data(cdate)
        if df.empty:
            return df
        df = df.reset_index(drop=True)
        df = df.sort_values(by='amount', ascending=False)
        # change in traded amount versus the previous day, scaled by 1e8
        df['money_change'] = (df['amount'] - df['preamount']) / 1e8
        return pd.merge(df, IndustryInfo.get(), how='left', on=['code'])

    def get_index_data(self, cdate):
        """Return one frame with the ``cdate`` rows of every index in ``ct.TDX_INDEX_DICT``."""
        frames = list()
        for code, name in ct.TDX_INDEX_DICT.items():
            data = CIndex(code).get_k_data(cdate)
            data['name'] = name
            data['code'] = code
            frames.append(data)
        # DataFrame.append was removed in pandas 2.0; concat is the replacement.
        df = pd.concat(frames, sort=False) if frames else pd.DataFrame()
        return df.reset_index(drop=True)

    def get_market_data(self, market, start_date, end_date):
        """Return whole-market daily rows for the SH or SZ exchange, sorted by date."""
        if market == ct.SH_MARKET_SYMBOL:
            df = self.sh_market_client.get_k_data_in_range(start_date, end_date)
            df = df.loc[df.name == '上海市场']
        else:
            df = self.sz_market_client.get_k_data_in_range(start_date, end_date)
            df = df.loc[df.name == '深圳市场']
        df = df.round(2)
        df = df.drop_duplicates()
        df = df.reset_index(drop=True)
        df = df.sort_values(by='date', ascending=True)
        df.negotiable_value = (df.negotiable_value / 2).astype(int)
        return df

    def get_rzrq_info(self, market, start_date, end_date):
        """Return exchange-level margin rows for SH or SZ, amount columns scaled by 1e8."""
        df = self.margin_client.get_k_data_in_range(start_date, end_date)
        # .copy(): the filtered frame is modified below; writing into a .loc
        # slice triggers SettingWithCopy and may not take effect.
        if market == ct.SH_MARKET_SYMBOL:
            df = df.loc[df.code == 'SSE'].copy()
            df['code'] = '上海市场'
        else:
            df = df.loc[df.code == 'SZSE'].copy()
            df['code'] = '深圳市场'
        df = df.round(2)
        for column in ('rzye', 'rzmre', 'rzche', 'rqye', 'rzrqye'):
            df[column] = df[column] / 1e+8
        df = df.drop_duplicates()
        df = df.reset_index(drop=True)
        return df.sort_values(by='date', ascending=True)

    def get_index_df(self, code, start_date, end_date):
        """Return OHLC/volume/amount rows of index ``code`` with its index copied into ``time``."""
        df = CIndex(code).get_k_data_in_range(start_date, end_date)
        df['time'] = df.index.tolist()
        return df[['time', 'open', 'high', 'low', 'close', 'volume', 'amount', 'date']]

    def update(self, cdate=None):
        """Build the review for ``cdate`` (defaults to today, evaluated per call).

        Uses data from the previous 100 days. Errors are logged, not raised.
        """
        if cdate is None:
            cdate = datetime.now().strftime('%Y-%m-%d')
        start_date = get_day_nday_ago(cdate, 100, dformat="%Y-%m-%d")
        end_date = cdate
        try:
            # market info, restricted to dates present for both exchanges
            sh_df = self.get_market_data(ct.SH_MARKET_SYMBOL, start_date, end_date)
            sz_df = self.get_market_data(ct.SZ_MARKET_SYMBOL, start_date, end_date)
            date_list = list(set(sh_df.date.tolist()).intersection(set(sz_df.date.tolist())))
            sh_df = sh_df[sh_df.date.isin(date_list)].reset_index(drop=True)
            sz_df = sz_df[sz_df.date.isin(date_list)].reset_index(drop=True)
            # margin info, restricted to dates present for both exchanges
            sh_rzrq_df = self.get_rzrq_info(ct.SH_MARKET_SYMBOL, start_date, end_date)
            sz_rzrq_df = self.get_rzrq_info(ct.SZ_MARKET_SYMBOL, start_date, end_date)
            date_list = list(set(sh_rzrq_df.date.tolist()).intersection(set(sz_rzrq_df.date.tolist())))
            sh_rzrq_df = sh_rzrq_df[sh_rzrq_df.date.isin(date_list)].reset_index(drop=True)
            sz_rzrq_df = sz_rzrq_df[sz_rzrq_df.date.isin(date_list)].reset_index(drop=True)
            # average price info
            av_df = self.get_index_df('880003', start_date, end_date)
            # limit up and down info
            limit_info = CLimit(self.dbinfo).get_data(cdate)
            stock_info = self.rstock_client.get_data(cdate)
            # keep only stocks with volume > 0
            stock_info = stock_info[stock_info.volume > 0].reset_index(drop=True)
            index_info = self.get_index_data(end_date)
            industry_info = self.get_industry_data(cdate)
            all_stock_info = self.rstock_client.get_k_data_in_range(start_date, end_date)
            self.doc.generate(cdate, sh_df, sz_df, sh_rzrq_df, sz_rzrq_df, av_df, limit_info, stock_info, industry_info, index_info, all_stock_info)
        except Exception as e:
            self.logger.error(e)
            traceback.print_exc()

    def gen_animation(self, sfile=None):
        """Render today's ``ct.ANIMATION_INFO`` prices into an mp4 via ffmpeg.

        Returns None without writing anything when there is no data.
        """
        style.use('fivethirtyeight')
        Writer = animation.writers['ffmpeg']
        writer = Writer(fps=1, metadata=dict(artist='biek'), bitrate=1800)
        _today = datetime.now().strftime('%Y-%m-%d')
        cdata = self.mysql_client.get('select * from %s where date = "%s"' % (ct.ANIMATION_INFO, _today))
        if cdata is None:
            return None
        cdata = cdata.reset_index(drop=True)
        name_list = cdata.name.unique()
        ctime_list = [datetime.strptime(ctime, '%H:%M:%S') for ctime in cdata.time.unique()]
        frame_num = len(ctime_list)
        if 0 == frame_num:
            return None
        # Price-change series per name do not depend on the frame, so compute them once.
        pchange_by_name = dict()
        for name in name_list:
            prices = cdata[cdata.name == name]['price'].tolist()
            pchange_by_name[name] = [0] + [10 * (prices[k] - prices[k - 1]) / prices[0] for k in range(1, len(prices))]
        # Create the figure only after the early returns so it cannot leak.
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)

        def animate(i):
            # frame i shows the first i samples (i runs from 1 to frame_num)
            ax.clear()
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
            ax.xaxis.set_major_locator(mdates.DayLocator())
            ax.set_title('盯盘', fontproperties=get_chinese_font())
            ax.set_xlabel('时间', fontproperties=get_chinese_font())
            ax.set_ylabel('增长', fontproperties=get_chinese_font())
            ax.set_ylim((-6, 6))
            fig.autofmt_xdate()
            for name in name_list:
                pchange_list = pchange_by_name[name]
                ax.plot(ctime_list[0:i], pchange_list[0:i], label=name, linewidth=1.5)
                if pchange_list[i - 1] > 1 or pchange_list[i - 1] < -1:
                    ax.text(ctime_list[i - 1], pchange_list[i - 1], name, font_properties=get_chinese_font())

        # Frames 1..frame_num: frame 0 would read pchange_list[-1], and the
        # last sample would never be drawn.
        ani = animation.FuncAnimation(fig, animate, frames=range(1, frame_num + 1), interval=60000, repeat=False)
        sfile = '/data/animation/%s_animation.mp4' % _today if sfile is None else sfile
        ani.save(sfile, writer)
        plt.close(fig)
class CReivew:
    """Generate the daily stock-review post and its charts under ``self.sdir``.

    NOTE(review): if this definition shares a module with another ``CReivew``,
    this later one shadows it.
    """

    def __init__(self, dbinfo):
        self.dbinfo = dbinfo
        self.sdir = '/data/docs/blog/hellobiek.github.io/source/_posts'
        self.doc = CDoc(self.sdir)
        self.stock_objs = dict()
        self.redis = create_redis_obj()
        self.mysql_client = CMySQL(self.dbinfo, iredis=self.redis)
        self.cal_client = ccalendar.CCalendar(without_init=True)
        self.animating = False
        self.emotion_table = ct.EMOTION_TABLE
        if not self.create_emotion():
            raise Exception("create emotion table failed")

    def create_emotion(self):
        """Create the emotion-score table if MySQL does not list it; return success."""
        if self.emotion_table not in self.mysql_client.get_all_tables():
            sql = 'create table if not exists %s(date varchar(10) not null, score float, PRIMARY KEY (date))' % self.emotion_table
            if not self.mysql_client.create(sql, self.emotion_table):
                return False
        return True

    def get_today_all_stock_data(self, _date):
        """Return the pickled all-stock frame cached in redis, filtered to ``_date``."""
        df_byte = self.redis.get(ct.TODAY_ALL_STOCK)
        if df_byte is None:
            return None
        # NOTE(review): unpickling redis content is only safe if that redis is trusted.
        df = _pickle.loads(df_byte)
        return df[df.date == _date]

    def get_industry_data(self, _date):
        """Return the ``_date`` row of every industry index, named and sorted by amount."""
        df_info = IndustryInfo.get()
        frames = list()
        for code, name in zip(df_info.code, df_info['name']):
            data = CIndex(code).get_k_data(date=_date)
            if data is None:
                continue
            # Name each frame directly; index alignment mislabels rows
            # whenever an industry returns no data.
            data['name'] = name
            frames.append(data)
        df = pd.concat(frames, sort=False) if frames else pd.DataFrame()
        df = df.sort_values(by='amount', ascending=False)
        return df.reset_index(drop=True)

    def emotion_plot(self, dir_name):
        """Plot every stored emotion score to ``<dir_name>/emotion.png``."""
        df = self.mysql_client.get("select * from %s" % self.emotion_table)
        fig = plt.figure()
        x = df.date.tolist()
        xn = range(len(x))
        y = df.score.tolist()
        plt.plot(xn, y)
        for xi, yi in zip(xn, y):
            plt.plot((xi, ), (yi, ), 'ro')
            plt.text(xi, yi, '%s' % yi)
        plt.scatter(xn, y, label='score', color='k', s=25, marker="o")
        plt.xticks(xn, x)
        plt.xlabel('时间', fontproperties=get_chinese_font())
        plt.ylabel('分数', fontproperties=get_chinese_font())
        plt.title('股市情绪', fontproperties=get_chinese_font())
        fig.autofmt_xdate()
        plt.savefig('%s/emotion.png' % dir_name, dpi=1000)
        plt.close(fig)

    def industry_plot(self, dir_name, industry_info):
        """Stack the top 10 industries' amounts into ``<dir_name>/industry.png``.

        ``industry_info`` is assumed sorted by amount with a 0..n-1 index, as
        returned by ``get_industry_data``. It is no longer modified in place.
        """
        colors = ['#F5DEB3', '#A0522D', '#1E90FF', '#FFE4C4', '#00FFFF', '#DAA520', '#3CB371', '#808080', '#ADFF2F', '#4B0082']
        amounts = industry_info.amount / 10000000000
        total_amount = amounts.sum()
        amount_list = amounts.iloc[0:10].tolist()
        x = date.today()
        fig = plt.figure()
        base_line = 0
        for i, amount in enumerate(amount_list):
            label_name = "%s:%s" % (industry_info.loc[i]['name'], 100 * amount / total_amount)
            plt.bar(x, amount, width=0.1, color=colors[i], bottom=base_line, align='center', label=label_name)
            base_line += amount
        plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%m/%d/%Y'))
        plt.gca().xaxis.set_major_locator(mdates.DayLocator())
        plt.xlabel('x轴', fontproperties=get_chinese_font())
        plt.ylabel('y轴', fontproperties=get_chinese_font())
        plt.title('市值分布', fontproperties=get_chinese_font())
        fig.autofmt_xdate()
        plt.legend(loc='upper right', prop=get_chinese_font())
        plt.savefig('%s/industry.png' % dir_name, dpi=1000)
        plt.close(fig)

    def get_limitup_data(self, date):
        return CLimit(self.dbinfo).get_data(date)

    @staticmethod
    def _split_limit_codes(limit_info):
        """Return (limit-up codes, limit-down codes) from ``limit_info``."""
        up = limit_info[(limit_info.pchange > 0) & (limit_info.prange != 0)].code.tolist()
        down = limit_info[limit_info.pchange < 0].code.tolist()
        return up, down

    def gen_market_emotion_score(self, stock_info, limit_info):
        """Store today's score: mean change percent, limit stocks weighted twice."""
        limit_up_list, limit_down_list = self._split_limit_codes(limit_info)
        limit_codes = set(limit_up_list) | set(limit_down_list)
        total = 0
        for _index, pchange in stock_info.changepercent.items():
            code = str(stock_info.loc[_index, 'code']).zfill(6)
            total += 2 * pchange if code in limit_codes else pchange
        aver = total / len(stock_info)
        data = {'date': ["%s" % datetime.now().strftime('%Y-%m-%d')], 'score': [aver]}
        df = pd.DataFrame.from_dict(data)
        if not self.mysql_client.set(df, self.emotion_table):
            raise Exception("set data to emotion failed")

    def static_plot(self, dir_name, stock_info, limit_info):
        """Bar-chart the change-percent distribution into ``<dir_name>/static.png``.

        Limit-up and limit-down stocks get their own bars and are excluded
        from every change-percent bucket. Buckets are [9, inf), [7, 9), ...,
        [-9, -7), (-inf, -9), so every stock lands in exactly one bar.
        """
        colors = ['b', 'r', 'y', 'g', 'm']
        limit_up_list, limit_down_list = self._split_limit_codes(limit_info)
        limit_codes = set(limit_up_list + limit_down_list)
        # Per-row exclusion; codes are zero-padded as in gen_market_emotion_score.
        padded_codes = stock_info.code.astype(str).str.zfill(6)
        changes = stock_info[~padded_codes.isin(limit_codes)].changepercent
        changepercent_list = [9, 7, 5, 3, 1, 0, -1, -3, -5, -7, -9]
        num_list = [len(limit_up_list)]
        name_list = ["涨停"]
        c_length = len(changepercent_list)
        for _index, pchange in enumerate(changepercent_list):
            if 0 == _index:
                num_list.append(int((changes >= pchange).sum()))
                name_list.append(">%s" % pchange)
            elif c_length - 1 == _index:
                num_list.append(int((changes < pchange).sum()))
                name_list.append("<%s" % pchange)
            else:
                p_max_change = changepercent_list[_index - 1]
                num_list.append(int(((changes >= pchange) & (changes < p_max_change)).sum()))
                name_list.append("%s-%s" % (pchange, p_max_change))
        num_list.append(len(limit_down_list))
        name_list.append("跌停")
        fig = plt.figure()
        for i, num in enumerate(num_list):
            plt.bar(i + 1, num, color=colors[i % len(colors)], width=0.3)
            plt.text(i + 1, 15 + num, num, ha='center', font_properties=get_chinese_font())
        plt.xlabel('x轴', fontproperties=get_chinese_font())
        plt.ylabel('y轴', fontproperties=get_chinese_font())
        plt.title('涨跌分布', fontproperties=get_chinese_font())
        plt.xticks(range(1, len(num_list) + 1), name_list, fontproperties=get_chinese_font())
        fig.autofmt_xdate()
        plt.savefig('%s/static.png' % dir_name, dpi=1000)
        plt.close(fig)

    def is_collecting_time(self):
        """Return whether local time is strictly between 21:00:00 and 23:59:59 today."""
        now_time = datetime.now()
        open_time = now_time.replace(hour=21, minute=0, second=0, microsecond=0)
        close_time = now_time.replace(hour=23, minute=59, second=59, microsecond=0)
        return open_time < now_time < close_time

    def get_index_data(self, _date):
        """Return the ``_date`` rows from each index database in ``ct.TDX_INDEX_DICT``."""
        frames = list()
        try:
            for code, name in ct.TDX_INDEX_DICT.items():
                self.mysql_client.changedb(CIndex.get_dbname(code))
                data = self.mysql_client.get("select * from day where date=\"%s\";" % _date)
                data['name'] = name
                frames.append(data)
        finally:
            # switch back to the default database even if a query fails
            self.mysql_client.changedb()
        return pd.concat(frames, sort=False) if frames else pd.DataFrame()

    def update(self):
        """Generate today's review directory and files unless it already exists.

        Errors are logged, not raised.
        """
        _date = datetime.now().strftime('%Y-%m-%d')
        dir_name = os.path.join(self.sdir, "%s-StockReView" % _date)
        try:
            if not os.path.exists(dir_name):
                logger.info("create daily info")
                # stock analysis: keep only stocks with volume > 0
                stock_info = self.get_today_all_stock_data(_date)
                stock_info = stock_info[stock_info.volume > 0].reset_index(drop=True)
                os.makedirs(dir_name, exist_ok=True)
                industry_info = self.get_industry_data(_date)
                index_info = self.get_index_data(_date).reset_index(drop=True)
                limit_info = self.get_limitup_data(_date)
                self.gen_market_emotion_score(stock_info, limit_info)
                self.emotion_plot(dir_name)
                self.static_plot(dir_name, stock_info, limit_info)
                self.doc.generate(stock_info, industry_info, index_info)
                self.gen_animation()
        except Exception as e:
            logger.error(e)

    def gen_animation(self, sfile=None):
        """Render today's ``ct.ANIMATION_INFO`` prices into an mp4 via ffmpeg.

        Returns None without writing anything when there is no data.
        """
        style.use('fivethirtyeight')
        Writer = animation.writers['ffmpeg']
        writer = Writer(fps=1, metadata=dict(artist='biek'), bitrate=1800)
        _today = datetime.now().strftime('%Y-%m-%d')
        cdata = self.mysql_client.get('select * from %s where date = "%s"' % (ct.ANIMATION_INFO, _today))
        if cdata is None:
            return None
        cdata = cdata.reset_index(drop=True)
        name_list = cdata.name.unique()
        ctime_list = [datetime.strptime(ctime, '%H:%M:%S') for ctime in cdata.time.unique()]
        frame_num = len(ctime_list)
        if 0 == frame_num:
            return None
        pchange_by_name = dict()
        for name in name_list:
            prices = cdata[cdata.name == name]['price'].tolist()
            pchange_by_name[name] = [0] + [10 * (prices[k] - prices[k - 1]) / prices[0] for k in range(1, len(prices))]
        # Create the figure only after the early returns so it cannot leak.
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)

        def animate(i):
            # frame i shows the first i samples (i runs from 1 to frame_num)
            ax.clear()
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
            ax.xaxis.set_major_locator(mdates.DayLocator())
            ax.set_title('盯盘', fontproperties=get_chinese_font())
            ax.set_xlabel('时间', fontproperties=get_chinese_font())
            ax.set_ylabel('增长', fontproperties=get_chinese_font())
            ax.set_ylim((-6, 6))
            fig.autofmt_xdate()
            for name in name_list:
                pchange_list = pchange_by_name[name]
                ax.plot(ctime_list[0:i], pchange_list[0:i], label=name, linewidth=1.5)
                if pchange_list[i - 1] > 1 or pchange_list[i - 1] < -1:
                    ax.text(ctime_list[i - 1], pchange_list[i - 1], name, font_properties=get_chinese_font())

        ani = animation.FuncAnimation(fig, animate, frames=range(1, frame_num + 1), interval=60000, repeat=False)
        sfile = '/data/animation/%s_animation.mp4' % _today if sfile is None else sfile
        ani.save(sfile, writer)
        plt.close(fig)

    def get_range_data(self, start_date, end_date, code):
        """Return (code, rows of the stock's ``day`` table between the dates)."""
        sql = "select * from day where date between \"%s\" and \"%s\"" % (start_date, end_date)
        self.mysql_client.changedb(CStock.get_dbname(code))
        return (code, self.mysql_client.get(sql))

    def gen_stocks_trends(self, start_date, end_date, stock_info, max_length):
        """Return daily rows for stocks with exactly ``max_length`` rows in range.

        A ``preclose`` column is added and the ``start_date`` row is dropped.
        Codes whose query returns None are retried until they succeed
        (there is no retry limit).
        """
        obj_pool = Pool(500)
        frames = list()
        failed_list = stock_info.code.tolist()
        cfunc = partial(self.get_range_data, start_date, end_date)
        while len(failed_list) > 0:
            logger.info("restart failed ip len(%s)" % len(failed_list))
            # Iterate a snapshot: failed_list is mutated inside this loop.
            for code, tem_df in obj_pool.imap_unordered(cfunc, list(failed_list)):
                if tem_df is None:
                    continue
                if len(tem_df) == max_length:
                    tem_df = tem_df.sort_values(by='date', ascending=True)
                    tem_df['code'] = code
                    tem_df['preclose'] = tem_df['close'].shift(1)
                    tem_df = tem_df[tem_df.date != start_date]
                    frames.append(tem_df)
                failed_list.remove(code)
        obj_pool.join(timeout=5)
        obj_pool.kill()
        self.mysql_client.changedb()
        return pd.concat(frames, sort=False) if frames else pd.DataFrame()

    def relation_plot(self, df, good_list):
        """Cluster the stocks in ``good_list`` and save a graph to /tmp/relation.png.

        Stocks are clustered by daily open-to-close variation (sparse inverse
        covariance plus affinity propagation). Each cluster's members and
        industries are logged.
        """
        close_prices = np.vstack([df[df.code == code].close.tolist() for code in good_list])
        open_prices = np.vstack([df[df.code == code].open.tolist() for code in good_list])
        # the daily variations of the quotes are what carry most information
        variation = (close_prices - open_prices) * 100 / open_prices
        logger.info("get variation succeed")
        # GraphLassoCV was renamed GraphicalLassoCV (scikit-learn 0.20); the old name is gone since 0.22.
        lasso_cls = getattr(covariance, 'GraphicalLassoCV', None) or covariance.GraphLassoCV
        edge_model = lasso_cls()
        # standardize: correlations recover structure better than covariances
        X = variation.copy().T
        X /= X.std(axis=0)
        edge_model.fit(X)
        logger.info("mode compute succeed")
        _, labels = cluster.affinity_propagation(edge_model.covariance_)
        n_labels = labels.max()
        code_list = np.array(good_list)
        industry_dict = dict()
        industry_df_info = IndustryInfo.get()
        for index, name in industry_df_info['name'].items():
            for code in json.loads(industry_df_info.loc[index]['content']):
                industry_dict[code] = name
        cluster_dict = dict()
        for i in range(n_labels + 1):
            cluster_dict[i] = code_list[labels == i]
            name_list = [CStockInfo.get(code, 'name') for code in cluster_dict[i]]
            logger.info('cluster code %i: %s' % ((i + 1), ', '.join(name_list)))
        cluster_info = dict()
        for group, _code_list in cluster_dict.items():
            industries = cluster_info.setdefault(group, set())
            for code in _code_list:
                # codes without an industry are skipped instead of raising KeyError
                if code in industry_dict:
                    industries.add(industry_dict[code])
            logger.info('cluster inustry %i: %s' % ((group + 1), ', '.join(list(industries))))
        # 2D embedding for layout; the dense solver keeps it reproducible
        node_position_model = manifold.LocallyLinearEmbedding(n_components=2, eigen_solver='dense', n_neighbors=6)
        embedding = node_position_model.fit_transform(X.T).T
        fig = plt.figure(1, facecolor='w', figsize=(10, 8))
        plt.clf()
        ax = plt.axes([0., 0., 1., 1.])
        plt.axis('off')
        # graph of partial correlations
        partial_correlations = edge_model.precision_.copy()
        d = 1 / np.sqrt(np.diag(partial_correlations))
        partial_correlations *= d
        partial_correlations *= d[:, np.newaxis]
        non_zero = (np.abs(np.triu(partial_correlations, k=1)) > 0.02)
        plt.scatter(embedding[0], embedding[1], s=100 * d**2, c=labels, cmap=plt.cm.nipy_spectral)
        start_idx, end_idx = np.where(non_zero)
        segments = [[embedding[:, start], embedding[:, stop]] for start, stop in zip(start_idx, end_idx)]
        values = np.abs(partial_correlations[non_zero])
        # guard: values.max() raises on an empty array (no edges)
        max_value = values.max() if values.size else 1.0
        lc = LineCollection(segments, zorder=0, cmap=plt.cm.hot_r, norm=plt.Normalize(0, .7 * max_value))
        lc.set_array(values)
        lc.set_linewidths(15 * values)
        ax.add_collection(lc)
        # guard: with a single cluster labels.max() is 0
        color_scale = float(max(n_labels, 1))
        # place each label on the side facing away from its nearest neighbour
        for index, (name, label, (x, y)) in enumerate(zip(code_list, labels, embedding.T)):
            dx = x - embedding[0]
            dx[index] = 1
            dy = y - embedding[1]
            dy[index] = 1
            this_dx = dx[np.argmin(np.abs(dy))]
            this_dy = dy[np.argmin(np.abs(dx))]
            if this_dx > 0:
                horizontalalignment = 'left'
                x = x + .002
            else:
                horizontalalignment = 'right'
                x = x - .002
            if this_dy > 0:
                verticalalignment = 'bottom'
                y = y + .002
            else:
                verticalalignment = 'top'
                y = y - .002
            plt.text(x, y, name, size=10, horizontalalignment=horizontalalignment, verticalalignment=verticalalignment,
                     bbox=dict(facecolor='w', edgecolor=plt.cm.nipy_spectral(label / color_scale), alpha=.6))
        # np.ptp instead of ndarray.ptp, which NumPy 2.0 removed
        x_span = np.ptp(embedding[0])
        y_span = np.ptp(embedding[1])
        plt.xlim(embedding[0].min() - .15 * x_span, embedding[0].max() + .10 * x_span)
        plt.ylim(embedding[1].min() - .03 * y_span, embedding[1].max() + .03 * y_span)
        plt.savefig('/tmp/relation.png', dpi=1000)
        plt.close(fig)

    def plot_price_series(self, df, ts1, ts2):
        """Plot close prices of ``ts1`` and ``ts2`` to /tmp/relation/<ts1>_<ts2>.png."""
        fig = plt.figure()
        x = df.loc[df.code == ts1].date.tolist()
        xn = range(len(x))
        y1 = df.loc[df.code == ts1].close.tolist()
        name1 = df[df.code == ts1].name.values[0]
        name2 = df[df.code == ts2].name.values[0]
        y2 = df.loc[df.code == ts2].close.tolist()
        plt.plot(xn, y1, label=name1, linewidth=1.5)
        plt.plot(xn, y2, label=name2, linewidth=1.5)
        plt.xticks(xn, x)
        plt.xlabel('时间', fontproperties=get_chinese_font())
        plt.ylabel('分数', fontproperties=get_chinese_font())
        plt.title('协整关系', fontproperties=get_chinese_font())
        fig.autofmt_xdate()
        plt.legend(loc='upper right', prop=get_chinese_font())
        plt.savefig('/tmp/relation/%s_%s.png' % (ts1, ts2), dpi=1000)
        plt.close(fig)
class BullStockRatio:
    """Compute and persist, per trading day, the percentage of an index's components with ``profit >= 0``.

    One row ``(date, ratio)`` per day is written to a MySQL table named
    ``<db_name>_<ct.BULLSTOCKRATIO_TABLE>``; processed dates are also added to a
    Redis set keyed by that table name so they are skipped on later runs.
    """

    def __init__(self, index_code, dbinfo=ct.DB_INFO, redis_host=None):
        self.dbinfo = dbinfo
        self.index_code = index_code
        self.index_obj = CIndex(index_code, dbinfo=self.dbinfo, redis_host=redis_host)
        self.db_name = self.index_obj.get_dbname(index_code)
        self.logger = getLogger(__name__)
        self.ris = RIndexStock(dbinfo, redis_host)
        self.bull_stock_ratio_table = self.get_table_name()
        self.redis = create_redis_obj() if redis_host is None else create_redis_obj(redis_host)
        self.mysql_client = CMySQL(self.dbinfo, dbname=self.db_name, iredis=self.redis)
        if not self.create():
            raise Exception("create bull stock ratio table failed")

    def delete(self):
        """Remove the ratio table through ``CMySQL.delete``."""
        self.mysql_client.delete(self.bull_stock_ratio_table)

    def get_table_name(self):
        """Return the per-index table name ``<db_name>_<ct.BULLSTOCKRATIO_TABLE>``."""
        return "%s_%s" % (self.db_name, ct.BULLSTOCKRATIO_TABLE)

    def create(self):
        """Create the ratio table unless it already exists.

        Returns:
            False if the table was missing and creating it failed, True otherwise.
        """
        if self.bull_stock_ratio_table not in self.mysql_client.get_all_tables():
            sql = 'create table if not exists %s(date varchar(10) not null, ratio float, PRIMARY KEY (date))' % self.bull_stock_ratio_table
            if not self.mysql_client.create(sql, self.bull_stock_ratio_table):
                return False
        return True

    def is_date_exists(self, table_name, cdate):
        """Return True if ``cdate`` is a member of the Redis set named ``table_name``."""
        if not self.redis.exists(table_name):
            return False
        recorded = {str(tdate, encoding=ct.UTF8) for tdate in self.redis.smembers(table_name)}
        return cdate in recorded

    def get_k_data_between(self, start_date, end_date):
        """Return the stored ratio rows whose date lies between ``start_date`` and ``end_date``."""
        sql = "select * from %s where date between \"%s\" and \"%s\"" % (
            self.get_table_name(), start_date, end_date)
        return self.mysql_client.get(sql)

    def get_components(self, cdate):
        """Return the index's component codes on ``cdate``; an empty list if unavailable.

        For index ``000001`` only codes starting with ``6`` are kept.
        """
        df = self.index_obj.get_components_data(cdate)
        if df is None or df.empty:
            return list()
        if self.index_code == '000001':
            df = df[df.code.str.startswith('6')]
        return df.code.tolist()

    def get_data(self, cdate):
        """Return per-stock data for ``cdate`` from ``RIndexStock``."""
        return self.ris.get_data(cdate)

    def update(self, end_date=None, num=30):
        """Compute ratios for every trading day from ``get_day_nday_ago(end_date, num)`` to ``end_date``.

        The components of ``end_date`` serve as a fallback for days whose own
        component list is unavailable.

        Returns:
            False if the fallback component list is empty or any day failed, else True.
        """
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')
        start_date = get_day_nday_ago(end_date, num=num, dformat="%Y-%m-%d")
        code_list = self.get_components(end_date)
        if 0 == len(code_list):
            self.logger.error("%s code_list for %s is empty" % (end_date, self.index_code))
            return False
        succeed = True
        for mdate in get_dates_array(start_date, end_date):
            if not CCalendar.is_trading_day(mdate, redis=self.redis):
                continue
            if not self.set_ratio(code_list, mdate):
                self.logger.error("set %s score for %s set failed" % (self.index_code, mdate))
                succeed = False
        return succeed

    def get_profit_stocks(self, df):
        """Return the codes of rows in ``df`` whose ``profit`` is non-negative."""
        return df[df.profit >= 0].code.tolist()

    def set_ratio(self, now_code_list, cdate=None):
        """Compute and store the ratio for ``cdate`` (default: today, evaluated at call time).

        Args:
            now_code_list: fallback component codes used when ``cdate`` has none.
            cdate: date string ``YYYY-MM-DD``.

        Returns:
            True if the date was already recorded or no stock matched the
            components; False if stock data is missing or the MySQL write fails;
            otherwise the result of adding ``cdate`` to the Redis set.
        """
        # A datetime.now() default would be frozen at class-definition time.
        if cdate is None:
            cdate = datetime.now().strftime('%Y-%m-%d')
        if self.is_date_exists(self.bull_stock_ratio_table, cdate):
            self.logger.debug("existed date:%s, date:%s" % (self.bull_stock_ratio_table, cdate))
            return True
        code_list = self.get_components(cdate) or now_code_list
        df = self.get_data(cdate)
        if df is None or df.empty:
            return False
        df = df[df.code.isin(code_list)]
        if df.empty:
            return True
        bull_stock_num = len(self.get_profit_stocks(df))
        bull_ration = 100 * bull_stock_num / len(df)
        result = pd.DataFrame.from_dict({'date': [cdate], 'ratio': [bull_ration]})
        if self.mysql_client.set(result, self.bull_stock_ratio_table):
            return self.redis.sadd(self.bull_stock_ratio_table, cdate)
        return False
class CReivew:
    """Daily market review: stores industry amounts and an emotion score, and renders charts.

    Charts are saved into a per-day directory under ``self.sdir``, after which
    ``CDoc.generate`` and ``gen_animation`` are invoked (see ``update``).
    """

    def __init__(self, dbinfo):
        self.sdir = '/data/docs/blog/hellobiek.github.io/source/_posts'
        self.doc = CDoc(self.sdir)
        self.redis = create_redis_obj()
        self.mysql_client = CMySQL(dbinfo)
        self.cal_client = ccalendar.CCalendar(without_init=True)
        # DataFrame from ts.get_today_all(); filled in by update()
        self.trading_info = None
        self.animating = False
        self.emotion_table = ct.EMOTION_TABLE
        self.industry_table = ct.INDUSTRY_TABLE
        if not self.create_industry():
            raise Exception("create industry table failed")
        if not self.create_emotion():
            raise Exception("create emotion table failed")

    def get_industry_name_dict_from_tongdaxin(self, fname):
        """Parse a GBK-encoded TongDaXin industry file into ``{industry_code: industry_name}``.

        Only the section whose first line is ``#TDXNHY`` is read; each inner
        line is ``code|name|...``. The section's last line is skipped.
        """
        industry_dict = dict()
        with open(fname, "rb") as f:
            data = f.read()
        for info in data.decode("gbk").split('######\r\n'):
            xlist = info.split('\r\n')
            if xlist[0] != '#TDXNHY':
                continue
            # drop the header line and the last element of the section
            for z in xlist[1:-1]:
                x = z.split('|')
                industry_dict[x[0]] = x[1]
        return industry_dict

    def get_industry_code_dict_from_tongdaxin(self, fname):
        """Map industry code -> JSON list of stock codes, from a UTF-8 TongDaXin code file.

        Only lines with exactly four ``|``-separated fields are used: field 1 is
        the stock code, field 2 the industry code. Industry ``T00`` is skipped.
        """
        industry_dict = dict()
        with open(fname, "rb") as f:
            data = f.read()
        for x in data.decode("utf-8").split('\r\n'):
            info_list = x.split('|')
            if len(info_list) != 4:
                continue
            industry = info_list[2]
            if industry == "T00":
                continue  # not include B stock
            industry_dict.setdefault(industry, list()).append(info_list[1])
        return {key: json.dumps(codes) for key, codes in industry_dict.items()}

    def get_industry(self):
        """Return a DataFrame with ``name``, ``code`` and ``content`` (JSON code list) per industry."""
        industry_code_dict = self.get_industry_code_dict_from_tongdaxin(ct.TONG_DA_XIN_CODE_PATH)
        industry_name_dict = self.get_industry_name_dict_from_tongdaxin(ct.TONG_DA_XIN_INDUSTRY_PATH)
        data = {
            'name': [industry_name_dict[key] for key in industry_code_dict],
            'code': list(industry_code_dict.keys()),
            'content': list(industry_code_dict.values()),
        }
        return pd.DataFrame.from_dict(data)

    def create_industry(self):
        """Create the industry table if missing; return False only if creation fails."""
        if self.industry_table not in self.mysql_client.get_all_tables():
            sql = 'create table if not exists %s(date varchar(10) not null, code varchar(10) not null, name varchar(20), amount float, PRIMARY KEY (date, code))' % self.industry_table
            if not self.mysql_client.create(sql, self.industry_table):
                return False
        return True

    def create_emotion(self):
        """Create the emotion table if missing; return False only if creation fails."""
        if self.emotion_table not in self.mysql_client.get_all_tables():
            sql = 'create table if not exists %s(date varchar(10) not null, score float, PRIMARY KEY (date))' % self.emotion_table
            if not self.mysql_client.create(sql, self.emotion_table):
                return False
        return True

    def collect_industry_info(self):
        """Sum today's traded ``amount`` per industry and store it in the industry table.

        Raises:
            Exception: if the MySQL write fails.
        """
        industry_df = self.get_industry()
        name_list, icode_list, amount_list = list(), list(), list()
        for index, code_id in industry_df['code'].items():
            code_list = json.loads(industry_df.loc[index]['content'])
            code_info = self.trading_info[self.trading_info.code.isin(code_list)]
            name_list.append(industry_df.loc[index]['name'])
            icode_list.append(code_id)
            amount_list.append(code_info.amount.astype(float).sum())
        df = pd.DataFrame.from_dict({'name': name_list, 'code': icode_list, 'amount': amount_list})
        df['date'] = datetime.now().strftime('%Y-%m-%d')
        if not self.mysql_client.set(df, self.industry_table):
            raise Exception("set data to industry failed")

    def gen_today_industry(self):
        """Return today's top-10 industries by amount plus one aggregated '其他' (other) row.

        The result has columns ``amount``/``name``, is sorted by amount
        descending and indexed 0..n-1.
        """
        sql = "select * from %s where date = '%s';" % (
            self.industry_table, datetime.now().strftime('%Y-%m-%d'))
        df = self.mysql_client.get(sql)
        df = df[['amount', 'name']].sort_values(by='amount', ascending=False)
        total_amount = df.amount.astype(float).sum()
        # Reset the index so the row label len(df) is guaranteed free; otherwise
        # the .loc assignment below could overwrite one of the top-10 rows.
        df = df.iloc[:10].reset_index(drop=True)
        other_amount = total_amount - df.amount.astype(float).sum()
        df.loc[len(df)] = [other_amount, '其他']
        df = df.sort_values(by='amount', ascending=False)
        return df.reset_index(drop=True)

    def emotion_plot(self, dir_name):
        """Plot all stored emotion scores over time into ``<dir_name>/emotion.png``."""
        sql = "select * from %s;" % self.emotion_table
        df = self.mysql_client.get(sql)
        fig = plt.figure()
        try:
            x = df.date.tolist()
            xn = range(len(x))
            y = df.score.tolist()
            plt.plot(xn, y)
            for xi, yi in zip(xn, y):
                plt.plot((xi, ), (yi, ), 'ro')
                plt.text(xi, yi, '%s' % yi)
            plt.scatter(xn, y, label='score', color='k', s=25, marker="o")
            plt.xticks(xn, x)
            plt.xlabel('时间', fontproperties=get_chinese_font())
            plt.ylabel('分数', fontproperties=get_chinese_font())
            plt.title('股市情绪', fontproperties=get_chinese_font())
            fig.autofmt_xdate()
            plt.savefig('%s/emotion.png' % dir_name, dpi=1000)
        finally:
            # update() runs forever; unclosed figures would accumulate.
            plt.close(fig)

    def industry_plot(self, df, dir_name):
        """Stacked bar of today's amount per industry into ``<dir_name>/industry.png``.

        ``df`` is expected to come from ``gen_today_industry`` (index 0..n-1).
        Bar heights are ``amount / 1e8``; each label shows the row's share of
        the total in percent.
        """
        colors = [
            '#F5DEB3', '#A0522D', '#1E90FF', '#FFE4C4', '#00FFFF', '#DAA520',
            '#3CB371', '#808080', '#ADFF2F', '#4B0082', '#ADD8E6'
        ]
        fig = plt.figure()
        try:
            sum_amount = df.amount.sum() / 10000000000
            amount_list = [i / 100000000 for i in df.amount.tolist()]
            x = date.fromtimestamp(time.time())
            base_line = 0
            for i, amount in enumerate(amount_list):
                label_name = "%s:%s" % (df.loc[i]['name'], amount / sum_amount)
                plt.bar(x, amount, width=0.35, color=colors[i % len(colors)],
                        bottom=base_line, align='center', label=label_name)
                base_line += amount
            plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%m/%d/%Y'))
            plt.gca().xaxis.set_major_locator(mdates.DayLocator())
            plt.xlabel('x轴', fontproperties=get_chinese_font())
            plt.ylabel('y轴', fontproperties=get_chinese_font())
            plt.title('市值分布', fontproperties=get_chinese_font())
            fig.autofmt_xdate()
            plt.legend(loc='upper right', prop=get_chinese_font())
            plt.savefig('%s/industry.png' % dir_name, dpi=1000)
        finally:
            plt.close(fig)

    def gen_market_emotion_score(self):
        """Store today's emotion score: mean change percent, moves beyond +/-9.8 counted twice.

        Raises:
            Exception: if the MySQL write fails.
        """
        changepercent_list = self.trading_info.changepercent.tolist()
        total = 0
        for changepercent in changepercent_list:
            if changepercent > 9.8 or changepercent < -9.8:
                total += changepercent * 2
            else:
                total += changepercent
        aver = total / len(changepercent_list)
        data = {
            'date': ["%s" % datetime.now().strftime('%Y-%m-%d')],
            'score': [aver]
        }
        df = pd.DataFrame.from_dict(data)
        if not self.mysql_client.set(df, self.emotion_table):
            raise Exception("set data to emotion failed")

    @staticmethod
    def _count_by_change_buckets(changepercents, thresholds):
        """Count ``changepercents`` per bucket delimited by descending ``thresholds``.

        The buckets partition all non-NaN values: ``> thresholds[0]``, then
        ``(lower, upper]`` for each consecutive pair, then ``<= thresholds[-1]``.

        Returns:
            ``(labels, counts)`` lists of equal length.
        """
        labels = [">%s" % thresholds[0]]
        counts = [int((changepercents > thresholds[0]).sum())]
        for upper, lower in zip(thresholds, thresholds[1:]):
            labels.append("%s-%s" % (lower, upper))
            counts.append(int(((changepercents > lower) & (changepercents <= upper)).sum()))
        labels.append("<%s" % thresholds[-1])
        counts.append(int((changepercents <= thresholds[-1]).sum()))
        return labels, counts

    def static_plot(self, dir_name):
        """Bar chart of stock counts per change-percent bucket into ``<dir_name>/static.png``."""
        colors = ['b', 'r', 'y', 'g', 'm']
        name_list, num_list = self._count_by_change_buckets(
            self.trading_info.changepercent, [9.81, 5, 3, 1, 0, -1, -3, -5, -9.91])
        fig = plt.figure()
        try:
            for i, num in enumerate(num_list):
                plt.bar(i, num, color=colors[i % len(colors)], width=0.1)
                plt.text(i, 1.1 * num, '个数:%s' % num, ha='center', font_properties=get_chinese_font())
            plt.xlabel('x轴', fontproperties=get_chinese_font())
            plt.ylabel('y轴', fontproperties=get_chinese_font())
            plt.title('涨跌分布', fontproperties=get_chinese_font())
            plt.xticks(range(len(num_list)), name_list)
            fig.autofmt_xdate()
            plt.savefig('%s/static.png' % dir_name, dpi=1000)
        finally:
            plt.close(fig)

    @staticmethod
    def _in_window(now_time, start, end):
        """Return True if ``now_time`` lies strictly between ``start`` and ``end`` on its own date.

        ``start`` and ``end`` are ``(hour, minute, second)`` tuples.
        """
        begin = now_time.replace(hour=start[0], minute=start[1], second=start[2], microsecond=0)
        finish = now_time.replace(hour=end[0], minute=end[1], second=end[2], microsecond=0)
        return begin < now_time < finish

    def is_collecting_time(self):
        """True strictly between 16:00:00 and 23:59:59 today."""
        return self._in_window(datetime.now(), (16, 0, 0), (23, 59, 59))

    def update(self, sleep_time):
        """Loop forever, building the daily review on trading days during collecting time.

        The review is only generated when today's directory does not exist yet.
        Any exception is printed and the loop retries after 120 seconds.
        """
        while True:
            try:
                if self.cal_client.is_trading_day() and self.is_collecting_time():
                    self.trading_info = ts.get_today_all()
                    _date = datetime.now().strftime('%Y-%m-%d')
                    dir_name = os.path.join(self.sdir, "%s-StockReView" % _date)
                    if not os.path.exists(dir_name):
                        logger.info("create daily info")
                        os.makedirs(dir_name)
                        self.collect_industry_info()
                        df = self.gen_today_industry()
                        self.industry_plot(df, dir_name)
                        self.gen_market_emotion_score()
                        self.emotion_plot(dir_name)
                        self.static_plot(dir_name)
                        self.doc.generate()
                        self.gen_animation()
                time.sleep(sleep_time)
            except Exception:
                # report immediately rather than after the back-off sleep
                traceback.print_exc()
                time.sleep(120)

    def is_sleep_time(self):
        """True strictly between 11:30:00 and 13:00:00 today."""
        return self._in_window(datetime.now(), (11, 30, 0), (13, 0, 0))

    def is_animate_time(self):
        """True strictly between 09:10:00 and 15:05:00 today."""
        return self._in_window(datetime.now(), (9, 10, 0), (15, 5, 0))

    def gen_animation(self, sfile=None):
        """Render today's per-name ``pchange`` curves from ``ct.ANIMATION_INFO`` into an mp4.

        Args:
            sfile: output path; defaults to ``/data/animation/<today>_animation.mp4``.

        Returns:
            None. Nothing is written when no data is stored for today.
        """
        style.use('fivethirtyeight')
        Writer = animation.writers['ffmpeg']
        writer = Writer(fps=1, metadata=dict(artist='biek'), bitrate=1800)
        cdata = self.mysql_client.get('select * from %s' % ct.ANIMATION_INFO)
        if cdata is None:
            return None
        # NOTE(review): unpickling data read back from MySQL; only safe if that table is trusted.
        cdata = _pickle.loads(cdata)
        _today = datetime.now().strftime('%Y-%m-%d')
        show_data = cdata[cdata.cdate == _today].reset_index(drop=True)
        ctime_list = [datetime.strptime(ctime, '%H:%M:%S') for ctime in show_data.ctime.unique()]
        frame_num = len(ctime_list)
        if 0 == frame_num:
            return None
        # Create the figure only once we know it will be used, so early returns cannot leak it.
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        # Each name appears once per timestamp; plot each series only once.
        names = show_data.name.unique()

        def animate(i):
            ax.clear()
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
            ax.xaxis.set_major_locator(mdates.DayLocator())
            ax.set_title('盯盘', fontproperties=get_chinese_font())
            ax.set_xlabel('时间', fontproperties=get_chinese_font())
            ax.set_ylabel('增长', fontproperties=get_chinese_font())
            ax.set_ylim((-10, 50))
            fig.autofmt_xdate()
            for name in names:
                pchange_list = show_data[show_data.name == name]['pchange'].tolist()
                ax.plot(ctime_list, pchange_list, label=name, linewidth=1.5)
                pchange = pchange_list[-1]
                ctime = ctime_list[-1]
                if pchange > 3:
                    ax.text(ctime, pchange * 2, name, font_properties=get_chinese_font())

        try:
            ani = animation.FuncAnimation(fig, animate, frame_num, interval=60000, repeat=False)
            sfile = '/data/animation/%s_animation.mp4' % _today if sfile is None else sfile
            ani.save(sfile, writer)
        finally:
            plt.close(fig)
class Emotion:
    """Daily market "emotion" score stored in ``ct.EMOTION_TABLE``.

    The score is the mean per-stock ``pchange``, where codes selected from the
    limit table (see ``set_score``) are weighted double.
    """

    def __init__(self, dbinfo=ct.DB_INFO, redis_host=None):
        self.dbinfo = dbinfo
        self.emotion_table = ct.EMOTION_TABLE
        self.redis = create_redis_obj() if redis_host is None else create_redis_obj(redis_host)
        self.mysql_client = CMySQL(self.dbinfo, iredis=self.redis)
        self.rstock_client = RIndexStock(dbinfo, redis_host)
        self.logger = getLogger(__name__)
        if not self.create():
            raise Exception("create emotion table failed")

    def create(self):
        """Create the emotion table if missing; return False only if creation fails."""
        if self.emotion_table not in self.mysql_client.get_all_tables():
            sql = 'create table if not exists %s(date varchar(10) not null, score float, PRIMARY KEY (date))' % self.emotion_table
            if not self.mysql_client.create(sql, self.emotion_table):
                return False
        return True

    def get_score(self, cdate=None):
        """Return stored scores: all rows, or only the row for ``cdate`` when given."""
        if cdate is None:
            sql = "select * from %s" % self.emotion_table
        else:
            sql = "select * from %s where date=\"%s\"" % (self.emotion_table, cdate)
        return self.mysql_client.get(sql)

    def get_stock_data(self, cdate):
        """Return rows for ``cdate`` from the pickled DataFrame cached under ``ct.TODAY_ALL_STOCK``.

        Returns None when nothing is cached in Redis.
        """
        df_byte = self.redis.get(ct.TODAY_ALL_STOCK)
        if df_byte is None:
            return None
        df = _pickle.loads(df_byte)
        # Filter on the requested date (previously compared against the unrelated name `date`).
        return df.loc[df.date == cdate]

    def update(self, end_date=None, num=3):
        """Compute scores for every trading day from ``get_day_nday_ago(end_date, num)`` to ``end_date``.

        Returns:
            True if every trading day succeeded, False otherwise.
        """
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')
        start_date = get_day_nday_ago(end_date, num=num, dformat="%Y-%m-%d")
        succeed = True
        for mdate in get_dates_array(start_date, end_date):
            if not CCalendar.is_trading_day(mdate, redis=self.redis):
                continue
            if not self.set_score(mdate):
                succeed = False
                self.logger.error("set score for %s set failed" % mdate)
        return succeed

    def set_score(self, cdate=None):
        """Compute and store the emotion score for ``cdate`` (default: today, evaluated at call time).

        Codes in ``CLimit`` data with ``pchange > 0 and prange != 0`` or with
        ``pchange < 0`` contribute twice their ``pchange``.

        Returns:
            False if stock or limit data is missing/empty, else the MySQL write result.
        """
        # A datetime.now() default would be frozen at class-definition time.
        if cdate is None:
            cdate = datetime.now().strftime('%Y-%m-%d')
        stock_info = self.rstock_client.get_data(cdate)
        limit_info = CLimit(self.dbinfo).get_data(cdate)
        if stock_info is None or limit_info is None or stock_info.empty or limit_info.empty:
            self.logger.error("info is empty failed")
            return False
        limit_up = limit_info[(limit_info.pchange > 0) & (limit_info.prange != 0)].code
        limit_down = limit_info[limit_info.pchange < 0].code
        # set membership: O(1) per stock instead of scanning a list
        limit_codes = set(limit_up) | set(limit_down)
        total = 0
        for code, pchange in zip(stock_info.code, stock_info.pchange):
            total += 2 * pchange if code in limit_codes else pchange
        aver = total / len(stock_info)
        df = pd.DataFrame.from_dict({'date': [cdate], 'score': [aver]})
        return self.mysql_client.set(df, self.emotion_table)