def predict_today(datatype, timesteps, data_dim=15):
    # log = logger.log
    nowdate = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
    x_predict, id_predict, name_predict = trade.get_today(
        seg_len=timesteps, datatype=datatype, split=0.1, debug=False)
    network = policy.LSTMPolicy.create_network(timesteps=timesteps, data_dim=data_dim)
    USER_HOME = os.environ['HOME']
    out_directory_path = USER_HOME + '/dw/'
    meta_file = os.path.join(out_directory_path, 'metadata.json')
    weights_path = policy_trainer.get_best_weights(meta_file)
    network.load_weights(weights_path)
    predicts = network.predict(x_predict, batch_size=16)
    v_predicts = pd.DataFrame()
    v_predicts['code'] = id_predict
    v_predicts['name'] = name_predict
    v_predicts['predict'] = predicts
    v_predicts['datain_date'] = nowdate
    db = Db()
    v_predicts = v_predicts.to_dict('records')
    db.insertmany(
        """INSERT INTO predicts(code,name,predict,datain_date)
           VALUES (%(code)s,%(name)s,%(predict)s,%(datain_date)s)""",
        v_predicts)
    log.info('predicts finished')
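# Hedged usage sketch: one way predict_today might be invoked from a daily job. The
# 'cnn' datatype value appears elsewhere in this codebase, but the 30-step window is an
# illustrative assumption, not something the function above prescribes.
if __name__ == '__main__':
    predict_today(datatype='cnn', timesteps=30)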
def insert_hist_trade(self):
    self.set_data()
    db = Db()
    engine = db._get_engine()
    sql_stocklist = "select code,name from stock_code"
    codes = pd.read_sql_query(sql_stocklist, engine)
    codes = codes.to_dict('records')
    i = 1
    for row in codes:
        gta = td.get_hist_data(code=row['code'], start=self.nowdate, end=self.nowdate,
                               ktype='D', retry_count=3, pause=0.001)
        gta['datain_date'] = self.nowtime
        gta['code'] = row['code']
        gta['name'] = row['name']
        gta['c_yearmonthday'] = gta.index
        gta = gta.to_dict('records')
        try:
            db.insertmany(
                """INSERT INTO trade_hist(c_yearmonthday,code,name,open,high,close,low,volume,price_change,p_change,ma5,ma10,ma20,v_ma5,v_ma10,v_ma20,turnover,datain_date)
                   VALUES (%(c_yearmonthday)s,%(code)s,%(name)s,%(open)s,%(high)s,%(close)s,%(low)s,%(volume)s,%(price_change)s,%(p_change)s,%(ma5)s,%(ma10)s,%(ma20)s,%(v_ma5)s,%(v_ma10)s,%(v_ma20)s,%(turnover)s,%(datain_date)s)""",
                gta)
        except Exception as e:
            log.error('insert error:%s ', e)
        log.info('%s stock insert finished,%s,%s', i, row['code'], row['name'].decode('utf-8'))
        i += 1
def conecta():
    # print(config)
    db = Db(config)
    # print(db.config)
    conn = db.connect()
    # print(db)
    return db
def __init__(self, hostname, protocol, username=None, password=None):
    self.db = Db()
    self.connection = Connection()
    self.hostname = protocol + hostname
    self.username = username
    self.password = password
    self.is_logged_in = False
    self.exploit_results = {}
    self.connection.reset_session()
def setUp(self):
    init()
    self.parser = ConfigParser()
    logging.config.fileConfig('test/db_test.cfg')
    logger = logging.getLogger('basic')
    self.parser.read('test/db_test.cfg')
    connection_string = self.parser.get('Db', 'connection_string')
    self.db = Db(connection_string, logger)
    self.db.open()
def get_hist_orgindata(debug=False):
    db = Db()
    engine = db._get_engine()
    sql_stocklist = "select * from trade_hist where code in (select code from trade_hist where high<>0.0 and low<>0.0 group by code having count(code)>100)"
    if debug:
        sql_stocklist += " and code in ('002717','601888','002405')"
    df = pd.read_sql_query(sql_stocklist, engine)
    codes = df['code'].unique()
    # add technical indicators
    df = add_volatility(df)
    df = get_technique(df)
    return df, codes
def get_predict_acc1(debug=False):
    db = Db()
    engine = db._get_engine()
    sql_tradehist = "select code,name,p_change from trade_hist where code in (select code from predict_head where c_yearmonthday in (select max(c_yearmonthday) from predict_head)) order by c_yearmonthday desc"
    sql_predicthead = "select code,predict from predict_head order by c_yearmonthday desc"
    if debug:
        pass
    df_trade = pd.read_sql_query(sql_tradehist, engine).head(2)
    df_predict = pd.read_sql_query(sql_predicthead, engine).head(2)
    df = pd.merge(df_trade, df_predict, on='code')
    df['acc'] = (df.p_change > 0).astype(float)
    return df
def insert_predict_statics():
    db = Db()
    nowdate = time.strftime("%Y-%m-%d", time.localtime(time.time()))
    psummery, headpredict = trade.get_predict()
    psummery['c_yearmonthday'] = nowdate
    headpredict['c_yearmonthday'] = nowdate
    psummery = psummery.to_dict('records')
    headpredict = headpredict.to_dict('records')
    db.insertmany(
        """INSERT INTO predict_head(c_yearmonthday,code,name,predict)
           VALUES (%(c_yearmonthday)s,%(code)s,%(name)s,%(predict)s)""",
        headpredict)
    db.insertmany(
        """INSERT INTO predict_statics(c_yearmonthday,p_cnt,p_mean,p_std,p_min,p25,p50,p75,p_max)
           VALUES (%(c_yearmonthday)s,%(p_cnt)s,%(p_mean)s,%(p_std)s,%(p_min)s,%(p25)s,%(p50)s,%(p75)s,%(p_max)s)""",
        psummery)
def get_predict(debug=False):
    db = Db()
    engine = db._get_engine()
    sql_stocklist = "select * from predicts where datain_date in (select max(datain_date) from predicts) order by predict desc"
    if debug:
        pass
    df = pd.read_sql_query(sql_stocklist, engine)
    headpredict = df.head(2)
    psummery = df.describe().T
    psummery.columns = ['p_cnt', 'p_mean', 'p_std', 'p_min', 'p25', 'p50', 'p75', 'p_max']
    return psummery, headpredict
def insert_predict_acc():
    db = Db()
    nowdate = time.strftime("%Y-%m-%d", time.localtime(time.time()))
    acc1 = trade.get_predict_acc1()
    acc1['c_yearmonthday'] = nowdate
    acc1 = acc1.to_dict('records')
    db.insertmany(
        """INSERT INTO acc1(c_yearmonthday,code,name,predict,p_change,acc)
           VALUES (%(c_yearmonthday)s,%(code)s,%(name)s,%(predict)s,%(p_change)s,%(acc)s)""",
        acc1)
    acc2 = trade.get_predict_acc2()
    acc2['c_yearmonthday'] = nowdate
    acc2 = acc2.to_dict('records')
    db.insertmany(
        """INSERT INTO acc2(c_yearmonthday,p_acc,p_change,h_p_acc,h_p_change)
           VALUES (%(c_yearmonthday)s,%(p_acc)s,%(p_change)s,%(h_p_acc)s,%(h_p_change)s)""",
        acc2)
def insert_today_trade(self):
    self.set_data()
    db = Db()
    gta = td.get_today_all()
    gta['datain_date'] = self.nowtime
    gta['c_yearmonthday'] = self.nowdate
    gta = gta.to_dict('records')
    db.insertmany(
        """INSERT INTO trade_record(c_yearmonthday,code,name,changepercent,trade,open,high,low,settlement,volume,turnoverratio,amount,per,pb,mktcap,nmc,datain_date)
           VALUES (%(c_yearmonthday)s,%(code)s,%(name)s,%(changepercent)s,%(trade)s,%(open)s,%(high)s,%(low)s,%(settlement)s,%(volume)s,%(turnoverratio)s,%(amount)s,%(per)s,%(pb)s,%(mktcap)s,%(nmc)s,%(datain_date)s)""",
        gta)
def mockdb(mocker):
    """A test mock database."""
    mock_init_conn = mocker.patch('db.db.Db.initConnection')
    mock_init_conn.side_effect = mock_db
    db = Db()
    # Tried scoping to module - but mocker is function-scoped, so not able. Is there a bulk delete?
    for r in db.Results.objects():
        r.delete()
    db.Results(file_id="file_1", rule_id="rule_1", state=1, message="").save()
    db.Results(file_id="file_2", rule_id="rule_1", state=1, message="").save()
    db.Results(file_id="file_2", rule_id="rule_2", state=1, message="").save()
    return db
def validate(rule, result):
    logid = "Validate [" + str(os.getpid()) + "] "
    log.debug(logid + " multi name = " + current_process().name)
    db = Db(True, 15)
    resultObj = db.Results.objects(file_id=result.file_id, rule_id=result.rule_id)[0]
    log.debug(logid + "Result = " + str(resultObj))
    notifier = Notifications()
    source = ""
    if 'Source' in rule:
        source = rule['Source']
    else:
        source = rule['source']
    result, message = read_file_and_evaluate(source, resultObj)
    log.debug("Running validation process for " + rule['name'] +
              " got result " + str(result) + " and message " + message)
    if result:
        resultObj.state = 0
    else:
        resultObj.state = 1
        if resultObj.mandatory:
            resultObj.message = "Failed"
        else:
            resultObj.message = "Warning"
    resultObj.save()
    notifier.publish('fileStatus', resultObj.to_json())
def registrar_saida_in(self, hora: str = None) -> bool:
    try:
        where = {"data": self.data, "colaborador_id": self.colaborador_id}
        self.registros_ES[-1].update(
            {"saida": hora if hora else pendulum.now().format(hour_format)})
        return True if Db.update("ponto", where, self) else False
    except:
        raise Exception
def checker():
    print("checker started")
    db = Db(cfg)
    while True:
        time.sleep(cfg.SLEEP_TIME)
        ipPort = db.select_proxy()
        if ipPort:
            proxyDict = {
                "http": ipPort,
                "https": ipPort,
            }
            r = requests.get(cfg.LINK_TO_BE_CHECKED,
                             headers={'User-Agent': random.choice(cfg.USER_AGENT)},
                             proxies=proxyDict)
            if not 200 <= r.status_code <= 299:
                db.delete_row(ipPort)
                logger.info("{} deleted".format(ipPort))
            else:
                logger.error("{} : {}".format(r.status_code, r.text))
def process_item(self, item, spider):
    db = Db()
    if item['table_name'] == 'movie':
        if item['summary']:
            db.movie_update(item)
        else:
            db.movie_insert(item)
    else:
        db.comment_insert(item)
    return item
def get_predict_acc2(debug=False):
    db = Db()
    engine = db._get_engine()
    sql_stocklist = "select * from acc1"
    if debug:
        pass
    df = pd.read_sql_query(sql_stocklist, engine)
    acc2 = df.sort_values('c_yearmonthday', ascending=0)
    acc2 = acc2.head(2)
    acc2 = acc2.groupby('c_yearmonthday').sum()
    acc2_final = pd.DataFrame()
    acc2_final['h_p_acc'] = [df['acc'].sum() / float(df['acc'].count())]
    acc2_final['h_p_change'] = [df['p_change'].sum() / 2.0]
    acc2_final['p_acc'] = [acc2['acc'].sum() / 2.0]
    acc2_final['p_change'] = [acc2['p_change'].sum() / 2.0]
    return acc2_final
def start(token, backend, db, dialog, filename):
    bot = init_backend(token, backend)
    path = db['path']
    del db['path']
    db_path = init_db(path, **db)
    db = Db(db_path)
    dialog = Dialog(db, dialog)
    print('Start')
    while True:
        messages = bot.get_unread_messages()
        if messages["count"] >= 1:
            id, message_id, body, url = bot.get_message_ids_image(messages)
            print('Request:', id, body)
            try:
                if url and url[-3:] == 'jpg':
                    urllib.request.urlretrieve(url, filename)
                    db.save_image(id, body, filename)
                    bot.send_message(id, dialog.ok)
                    continue
                if body.lower() in dialog.common_answer:
                    bot.send_message(id, dialog.common_answer[body.lower()])
                    continue
                for key in dialog.action:
                    if key in body.lower():
                        part = body.lower().split(':')[1].strip()
                        if '&' in body.lower():
                            splitted = part.split('&')
                            part_1, part_2 = splitted[0].strip(), splitted[1].strip()
                            db.hset(id, part_1, part_2)
                            bot.send_message(id, dialog.ok)
                            break
                        data = db.hget(id, part)
                        if data and len(data) > 15000:
                            with open(filename, 'wb') as file:
                                file.write(data)
                            bot.upload_image(id, filename)
                            break
                        elif data:
                            bot.send_message(id, data)
                        else:
                            bot.send_message(id, dialog.text_does_not_exist)
                        break
                else:
                    bot.send_message(id, dialog.don_t_understand)
            except Exception as e:
                print(str(e))
                bot.send_message(id, dialog.error)
        time.sleep(1)
def new_proxy():
    print("new_proxy started")
    db = Db(cfg)
    while True:
        time.sleep(cfg.SLEEP_TIME)
        if db.tot_rows() <= cfg.MAX_IP:
            try:
                proxyjson = requests.get(
                    'http://gimmeproxy.com/api/getProxy?maxCheckPeriod=300&protocol=http').json()
                print(proxyjson)
            except requests.exceptions.RequestException as e:
                logger.error(e)
            except ValueError as e:  # includes simplejson.decoder.JSONDecodeError
                logger.error(e)
            if 'ipPort' in proxyjson.keys():
                db.insert_row(proxyjson['ipPort'])
                logger.info("proxy added: {}".format(proxyjson['ipPort']))
            else:
                logger.error(proxyjson)
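# Hedged sketch of wiring the two proxy-pool workers together: new_proxy() tops up the
# pool while checker() prunes dead entries. Running them in threads is an assumption made
# for illustration; the snippets above do not specify how they are scheduled.
import threading

if __name__ == '__main__':
    threading.Thread(target=new_proxy, daemon=True).start()
    checker()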
def test():
    infile = 'data.csv'
    df = pd.read_csv(infile)
    df = df.groupby('X')
    db = Db(TEMP_DB)
    with open('script.sql', 'w') as file:
        for d in df:
            (table_name, data) = d
            t = Table(db, table_name, data, file, test_mode=True)
            t.prepare_temp_table()
def criar_registro_ES(self, hour_in: str = None, hour_out: str = "") -> Dict[str, str]:
    try:
        self.data = pendulum.now().format(date_format)
        self.registros_ES = []
        self.registros_ES = [
            {"_id": str(uuid1()),
             "entrada": hour_in if hour_in else pendulum.now().format(hour_format),
             "saida": hour_out if hour_out else ""}]
        ret = Db.save('ponto', self)
        return ret if ret else None
    except:
        raise Exception
def find(cls, colaborador_id: str = None, cpf: str = None) -> List[Dict[str, str]]:
    try:
        where = {}
        where.update({"_id": f"{colaborador_id.replace(' ', '')}"} if colaborador_id else {})
        where.update({"cpf": f"{cpf.replace(' ', '')}"} if cpf else {})
        colaboradores = Db.find("colaborador", where)
        return colaboradores if colaboradores else None
    except:
        # TODO EXCEPT
        raise Exception
def mockdb(mocker):
    """A test mock database."""
    mock_init_conn = mocker.patch('db.db.Db.initConnection')
    mock_init_conn.side_effect = mock_db
    db = Db()
    # Tried scoping to module - but mocker is function-scoped, so not able. Is there a bulk delete?
    for r in Results.objects():
        r.delete()
    Rules(name="rule1", source="${file.name}!=badFile", mandatory=False).save()
    Rules(name="rule2", source="${file.size}<500", state=1, mandatory=True).save()
    return db
def start(token, backend, db, filename):
    bot = init_backend(token, backend)
    path = db['path']
    del db['path']
    db_path = init_db(path, **db)
    db = Db(db_path)
    print('Start')
    while True:
        try:
            messages = bot.get_unread_messages()
            if messages["count"] >= 1:
                id, body = bot.get_message_and_id(messages)
                bot.convert_text_to_voice(body)
                uploaded_voice = bot.upload_file(filename, id)
                bot.send_message(id, attach=uploaded_voice)
                db.rpush(id, body)
                print('Request:', id, body)
            time.sleep(1)
        except Exception as e:
            print('Error:', str(e))
def find(cls, ponto_id: str = None, colaborador_id: str = None, data: str = None,
         mes: str = None) -> List[Dict[str, str]]:
    try:
        where = {}
        where.update({"_id": f"{ponto_id.replace(' ', '')}"} if ponto_id else {})
        where.update({"data": f"{data.replace(' ', '')}"} if data else {})
        where.update({"colaborador_id": f"{colaborador_id.replace(' ', '')}"} if colaborador_id else {})
        if mes:
            regex = re.compile(f"\\d\\d\\/{mes.replace(' ', '')}\\/\\d\\d\\d\\d")
            where.update({"data": regex})
        pontos = Db.find("ponto", where, sort_by="data")
        return pontos if pontos else None
    except:
        raise Exception
def registrar_entrada_in(self, hour: str = None) -> bool:
    try:
        if not self.data:
            return True if self.criar_registro_ES() else False
        dif_dias = pendulum.period(pendulum.from_format(self.data, date_format),
                                   pendulum.now()).in_days()
        if dif_dias > 0 and self.registros_ES[-1]["saida"] != "":
            return True if self.criar_registro_ES() else False
        where = {"colaborador_id": self.colaborador_id, "data": self.data}
        self.registros_ES.append(
            {"_id": str(uuid1()),
             "entrada": pendulum.now().format(hour_format) if not hour else hour,
             "saida": ""})
        return True if Db.update("ponto", where, self) else False
    except:
        # TODO EXCEPT
        raise Exception
def mockdb(mocker):
    """A test mock database."""
    mock_init_conn = mocker.patch('db.db.Db.initConnection')
    mock_init_conn.side_effect = mock_db
    db = Db()
    # Tried scoping to module - but mocker is function-scoped, so not able. Is there a bulk delete?
    for r in db.Rules.objects():
        r.delete()
    for p in db.Policies.objects():
        p.delete()
    Rules(name="rule1", source="${file.name}!=badFile", mandatory=False).save()
    Rules(name="rule2", source="${file.size}<500", mandatory=True).save()
    Policies(name="export-data", rules=['rule1', 'rule2']).save()
    Policies(name="bad-policy", rules=['rule1', 'rule2', 'rule3']).save()
    return db
def run():
    infile = 'data.csv'
    result_file = 'results.csv'
    alter_table_file = 'alter_tables.sql'
    df = pd.read_csv(infile)
    df = df.groupby('X')
    db = Db(TEMP_DB)
    total_errors = float(0)
    alter_statements = []
    with open('script.sql', 'w') as file:
        with open(result_file, 'w') as results:
            with open(alter_table_file, 'w') as alter_file:
                for d in df:
                    (table_name, data) = d
                    t = Table(db, table_name, data, sql_log=file,
                              results_log=results, test_mode=False)
                    t.prepare_temp_table()
                    errors = t.test_temp_table()
                    if errors == 0:
                        statement = t.get_alter_table_for_original_table()
                        alter_statements.append(statement)
                        alter_file.write(statement)
                    total_errors += float(errors)
    if total_errors == 0:
        print(f'All tests passed: alter table file generated in {alter_table_file}')
        do_it = input(f"Do you want to apply the changes to the original database ({ORIGINAL_DB}) [yN]:")
        if do_it.lower() == 'y':
            print('Applying modifications to the original database')
            db = Db(ORIGINAL_DB)
            for statement in alter_statements:
                db.execute(statement)
            print('All done')
    else:
        print('Not all tests passed')
        sys.exit(1)
def remover(self) -> bool:
    try:
        return True if Db.delete("colaborador", self) else False
    except:
        # TODO EXCEPT
        return False
class Exploit(object):
    def __init__(self, hostname, protocol, username=None, password=None):
        self.db = Db()
        self.connection = Connection()
        self.hostname = protocol + hostname
        self.username = username
        self.password = password
        self.is_logged_in = False
        self.exploit_results = {}
        self.connection.reset_session()

    def exploit(self, short_name=None):
        if self.connection.verify_socket(self.hostname) is False:
            results = {"error": "Could not connect to host."}
        elif self.username and self.password is not None and not self.login(self.hostname, self.username, self.password):
            results = {"error": "Unable to login with the credentials provided."}
        else:
            if short_name is not None:
                for exploit in self.db.get_exploits_by_exploit_type_short_name(short_name):
                    self.run_exploit(exploit)
            else:
                for exploit_type in self.db.get_exploit_types():
                    for exploit in self.db.get_exploits_by_exploit_type_id(exploit_type.id):
                        self.run_exploit(exploit)
            results = self.get_exploit_results()
        return results

    def run_exploit(self, exploit: DBExploit):
        if exploit.is_authenticated and not self.is_logged_in:
            pass
        else:
            self.validate_response(
                exploit,
                self.do_request(exploit, exploit.exploit_body if exploit.exploit_body is not None else '')
            )

    def validate_response(self, exploit: DBExploit, response):
        if self.get_validator_by_id(exploit.validator_id).validate(response):
            self.exploit_found(exploit)

    def do_request(self, exploit: DBExploit, data):
        url = self.hostname + exploit.exploit_url
        if self.connection.verify_url(url) is False:
            return None
        return self.connection.request(
            hostname=url,
            data=data,
            headers=eval(exploit.exploit_headers) if exploit.exploit_headers is not None else {},
            method=exploit.request_method,
            urlencode=exploit.is_url_encode)

    def exploit_found(self, exploit: DBExploit):
        self.exploit_results.update({
            exploit.id: {
                "name": exploit.name,
                "version": exploit.version,
                "exploiturl": exploit.exploit_url
            }
        })

    def login(self, hostname, username, password):
        self.is_logged_in = self.connection.login(hostname, username, password)
        return self.is_logged_in

    def get_exploit_results(self):
        exploits = self.exploit_results.copy()
        self.exploit_results.clear()
        return exploits

    @staticmethod
    def check_file(file):
        if not os.path.isfile(file) and not os.access(file, os.R_OK):
            print('[X] ' + file + ' file is missing or not readable')
            sys.exit(1)
        else:
            return file

    @staticmethod
    def get_validator_by_id(validator_id):
        attribute = '__validator_id__'
        for name, obj in inspect.getmembers(sys.modules[__name__]):
            if hasattr(obj, attribute) and getattr(obj, attribute) == validator_id:
                return obj()
        raise ValueError('Could not find Validator with validator id %d' % validator_id)
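# Hedged usage sketch for the Exploit class above. The hostname and protocol values are
# placeholders; exploit() returns either an error dict or a mapping of exploit id to
# {name, version, exploiturl} for every exploit whose validator matched the response.
if __name__ == '__main__':
    scanner = Exploit('target.example.com', 'http://')
    for exploit_id, details in scanner.exploit().items():
        print(exploit_id, details)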
def get_hist6years(split=0.2, seg_len=3, debug=False, datatype='cnn', datafile=None, predict_days=18):
    log = logger.log
    db = Db()
    engine = db._get_engine()
    sql_stocklist = "select * from trade_hist where code in (select code from trade_hist where high<>0.0 and low<>0.0 group by code having count(code)>100)"
    if debug:
        sql_stocklist += " and code in ('002717','601888','002405')"
    df = pd.read_sql_query(sql_stocklist, engine)
    # add technical indicators
    df = add_volatility(df)
    stockcodes = df['code'].unique()
    df = get_technique(df, stockcodes)
    X_train = []
    X_valid = []
    Y_train = []
    Y_valid = []
    ID_train = []
    ID_valid = []
    log.info('begin generate train data and validate data.')
    begin_time = time.clock()
    k = 0
    predict_days = predict_days
    for codes in stockcodes:
        temp_df = df[df.code == codes]
        temp_df1 = temp_df.copy(deep=True)
        temp_df1 = temp_df1.sort_values(by='c_yearmonthday', ascending=1)
        tradedaylist = temp_df1['c_yearmonthday'].values
        tradedaylist.sort()
        tradedaylist = tradedaylist[::-1]
        temp_df1 = temp_df1.set_index('c_yearmonthday')
        if len(tradedaylist) < seg_len:
            log.info('not enough trade days ,code is :%s', codes)
            continue
        validdays = np.round(split * len(tradedaylist))
        # validdays = 2
        i = 0
        for day in tradedaylist:
            i += 1
            segdays = tradedaylist[i + predict_days:i + predict_days + seg_len]
            segbegin = segdays[len(segdays) - 1]
            segend = segdays[0]
            if len(segdays) < seg_len:
                break
            data = []
            # for segday in segdays:
            data = temp_df1.loc[segbegin:segend, [
                'open', 'high', 'close', 'low', 'volume', 'price_change', 'p_change',
                'ma5', 'ma10', 'ma20', 'v_ma5', 'v_ma10', 'v_ma20', 'turnover', 'deltat',
                'BIAS_B', 'BIAS_S', 'BOLL_B', 'BOLL_S', 'CCI_B', 'CCI_S', 'DMI_B',
                'DMI_HL', 'DMI_IF1', 'DMI_IF2', 'DMI_MAX1', 'DMI_S', 'KDJ_B', 'KDJ_S',
                'KD_B', 'KD_S', 'MACD', 'MACD_B', 'MACD_DEA', 'MACD_DIFF', 'MACD_EMA_12',
                'MACD_EMA_26', 'MACD_EMA_9', 'MACD_S', 'MA_B', 'MA_S', 'PSY_B',
                'PSY_MYPSY1', 'PSY_S', 'ROC_B', 'ROC_S', 'RSI_B', 'RSI_S', 'VR_B',
                'VR_IF1', 'VR_IF2', 'VR_IF3', 'VR_S', 'XYYH_B', 'XYYH_B1', 'XYYH_B2',
                'XYYH_B3', 'XYYH_CC', 'XYYH_DD'
            ]]
            data = data.values
            if datatype == 'cnn':
                data = [data]
            d1 = tradedaylist[i - 1]
            d3 = tradedaylist[i + predict_days - 1]
            data_tag = temp_df[temp_df.c_yearmonthday == d1][['code', 'name', 'p_change', 'close']]
            data_tag3 = temp_df[temp_df.c_yearmonthday == d3][['code', 'name', 'p_change', 'close']]
            temp_y = data_tag['close'].values[0]
            temp_y3 = data_tag3['close'].values[0]
            temp_y = (temp_y - temp_y3) / temp_y3
            temp_y = to_cate01(temp_y)
            temp_id = data_tag['code'].values[0]
            if (i > 0 and i <= validdays):
                X_valid.append(data)
                ID_valid.append(temp_id)
                Y_valid.append(temp_y)
            else:
                X_train.append(data)
                ID_train.append(temp_id)
                Y_train.append(temp_y)
        k += 1
        samples = 12
        if k % samples == 0:
            print(k)
            log.info('%s stock finished ', k)
            yield ((np.asarray(X_train), np.asarray(Y_train), np.asarray(ID_train)),
                   (np.asarray(X_valid), np.asarray(Y_valid), np.asarray(ID_valid)))
            X_train = []
            X_valid = []
            Y_train = []
            Y_valid = []
            ID_train = []
            ID_valid = []
    yield ((np.asarray(X_train), np.asarray(Y_train), np.asarray(ID_train)),
           (np.asarray(X_valid), np.asarray(Y_valid), np.asarray(ID_valid)))
def get_today(split=0.2, seg_len=3, debug=False, datatype='cnn', datafile=None):
    log = logger.log
    db = Db()
    engine = db._get_engine()
    sql_stocklist = "select * from trade_hist where code in (select code from trade_hist where high<>0.0 and low<>0.0 group by code having count(code)>100)"
    if debug:
        sql_stocklist += " and code in ('002717','601888','002405')"
    df = pd.read_sql_query(sql_stocklist, engine)
    df = add_volatility(df)
    stockcodes = df['code'].unique()
    df = get_technique(df)
    print(stockcodes)
    X_predict = []
    ID_predict = []
    NAME_predict = []
    log.info('begin generate train data and validate data.')
    k = 0
    for codes in stockcodes:
        temp_df = df[df.code == codes]
        temp_df1 = temp_df.copy(deep=True)
        temp_df1 = temp_df1.sort_values(by='c_yearmonthday', ascending=1)
        tradedaylist = temp_df1['c_yearmonthday'].values
        tradedaylist.sort()
        tradedaylist = tradedaylist[::-1]
        temp_df1 = temp_df1.set_index('c_yearmonthday')
        if len(tradedaylist) < seg_len:
            log.info('not enough trade days ,code is :%s', codes)
            continue
        i = 0
        segdays = tradedaylist[i:i + seg_len]
        segbegin = segdays[len(segdays) - 1]
        segend = segdays[0]
        if len(segdays) < seg_len:
            break
        data = []
        data = temp_df1.loc[segbegin:segend, [
            'open', 'high', 'close', 'low', 'volume', 'price_change', 'p_change',
            'ma5', 'ma10', 'ma20', 'v_ma5', 'v_ma10', 'v_ma20', 'turnover', 'deltat',
            'BIAS_B', 'BIAS_S', 'BOLL_B', 'BOLL_S', 'CCI_B', 'CCI_S', 'DMI_B',
            'DMI_HL', 'DMI_IF1', 'DMI_IF2', 'DMI_MAX1', 'DMI_S', 'KDJ_B', 'KDJ_S',
            'KD_B', 'KD_S', 'MACD', 'MACD_B', 'MACD_DEA', 'MACD_DIFF', 'MACD_EMA_12',
            'MACD_EMA_26', 'MACD_EMA_9', 'MACD_S', 'MA_B', 'MA_S', 'PSY_B',
            'PSY_MYPSY1', 'PSY_S', 'ROC_B', 'ROC_S', 'RSI_B', 'RSI_S', 'VR_B',
            'VR_IF1', 'VR_IF2', 'VR_IF3', 'VR_S', 'XYYH_B', 'XYYH_B1', 'XYYH_B2',
            'XYYH_B3', 'XYYH_CC', 'XYYH_DD'
        ]]
        data = data.values
        if datatype == 'cnn':
            data = [data]
        data_tag = temp_df[temp_df.c_yearmonthday == tradedaylist[0]][['code', 'name', 'p_change']]
        temp_id = data_tag['code'].values[0]
        temp_name = data_tag['name'].values[0]
        X_predict.append(data)
        ID_predict.append(temp_id)
        NAME_predict.append(temp_name)
        k += 1
        log.info('%s stock finished ', k)
    return (np.asarray(X_predict), np.asarray(ID_predict), np.asarray(NAME_predict))
def save(self) -> Dict:
    try:
        return self.dict() if Db.save("colaborador", self) else {}
    except:
        # TODO EXCEPT
        raise Exception
def get_histdata(split=0.15, seg_len=3, debug=False, datatype='cnn'):
    db = Db()
    engine = db._get_engine()
    sql_stocklist = "select * from trade_record where code in (select code from trade_record where high<>0.0 and low<>0.0 group by code having count(code)=(select count(distinct c_yearmonthday) from trade_record))"
    if debug:
        sql_stocklist += " and code in ('300138','002372')"
    df = pd.read_sql_query(sql_stocklist, engine)
    stockcodes = df['code'].unique()
    X_train = []
    X_valid = []
    Y_train = []
    Y_valid = []
    ID_train = []
    ID_valid = []
    log.info('begin generate train data and validate data.')
    begin_time = time.clock()
    k = 0
    for codes in stockcodes:
        temp_df = df[df.code == codes]
        tradedaylist = temp_df.copy(deep=True)['c_yearmonthday'].values
        tradedaylist.sort()
        tradedaylist = tradedaylist[::-1]
        if len(tradedaylist) < seg_len:
            log.info('not enough trade days ,code is :%s', codes)
            continue
        validdays = np.round(split * len(tradedaylist))
        i = 0
        for day in tradedaylist:
            i += 1
            segdays = tradedaylist[i:i + seg_len]
            if len(segdays) < seg_len:
                break
            SEG_X = []
            data = []
            for segday in segdays:
                data = temp_df[temp_df.c_yearmonthday == segday][
                    ['changepercent', 'trade', 'open', 'high', 'low', 'settlement',
                     'volume', 'turnoverratio', 'amount', 'per', 'pb', 'mktcap',
                     'nmc', 'deltat']]
                data = data.values
                SEG_X.append(data[0])
            # SEG_X = np.array(SEG_X).T
            if datatype == 'cnn':
                SEG_X = [SEG_X]
            data_tag = temp_df[temp_df.c_yearmonthday == day][['code', 'name', 'changepercent']]
            temp_y = data_tag['changepercent'].values[0]
            temp_y = to_cate01(temp_y)
            temp_id = data_tag['code'].values[0]
            if (i > 0 and i <= validdays):
                X_valid.append(SEG_X)
                ID_valid.append(temp_id)
                Y_valid.append(temp_y)
            else:
                X_train.append(SEG_X)
                ID_train.append(temp_id)
                Y_train.append(temp_y)
        k += 1
        if k % 500 == 0:
            log.info('%s stock finished ', k)
    log.info('generate data finished ,cost time:%s', time.clock() - begin_time)
    log.info('X_train shape is :%s', np.asarray(X_train).shape)
    log.info('Y_train shape is :%s', np.asarray(Y_train).shape)
    log.info('X_valid shape is :%s', np.asarray(X_valid).shape)
    log.info('Y_valid shape is :%s', np.asarray(Y_valid).shape)
    # X_train = normalize(X_train)
    # X_valid = normalize(X_valid)
    if debug:
        print((np.asarray(X_train), np.asarray(Y_train), np.asarray(ID_train)),
              (np.asarray(X_valid), np.asarray(Y_valid), np.asarray(ID_valid)))
        print(np.asarray(X_train[0][0][0]))
    pickle.dump(
        ((np.asarray(X_train), np.asarray(Y_train), np.asarray(ID_train)),
         (np.asarray(X_valid), np.asarray(Y_valid), np.asarray(ID_valid))),
        open(datatype + '_seg' + str(seg_len) + '.pkl', 'wb'))
def run(self):
    session = Db.get_session()
    for f in session.query(File.path):
        self.queue.put((10, PathUpdate(f.path, False)))