def main(argv):
    del argv
    random.seed(123)
    show = np_common.Show(True)
    if FLAGS.r:
        data_source = tushare_data.DataSource(20000101, '', '', 1,
                                              20000101, 20200306)
        while True:
            code_index = random.randint(0, len(data_source.code_list) - 1)
            ts_code = data_source.code_list[code_index]
            print(ts_code)
            code_name = data_source.name_list[code_index]
            pp_data = data_source.LoadStockPPData(ts_code, True)
            if len(pp_data) == 0:
                return
            day_index = random.randint(0, len(pp_data) - 1)
            show.SetTitle('%s - %s' % (ts_code, code_name))
            RandShow(pp_data, show, day_index)
    else:
        data_source = tushare_data.DataSource(0, '', '', 1, 20000101, 20200603)
        ts_code = FLAGS.c
        code_name = data_source.StockName(ts_code)
        data_source.DownloadStockData(ts_code)
        data_source.UpdateStockPPData(ts_code)
        pp_data = data_source.LoadStockPPData(ts_code, True)
        show.SetTitle('%s - %s' % (ts_code, code_name))
        ShowAStockLowLevel(pp_data, show, [PPI_open, PPI_vol])
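# --- Illustrative sketch (not part of the original source). The driver above
# reads FLAGS.r and FLAGS.c, which suggests an absl-style flag setup roughly
# like the one below; the flag names mirror the usage above, but the defaults,
# help strings, and entry point are assumptions.
from absl import app, flags

flags.DEFINE_bool('r', False, 'Pick stocks at random instead of using --c.')
flags.DEFINE_string('c', '000001.SZ', 'ts_code of the stock to display.')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    app.run(main)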
def ShowAStock_(ts_code,
                show_values=[
                    PPI_open, PPI_close, PPI_close_5_avg, PPI_close_10_avg,
                    PPI_close_30_avg, PPI_close_100_avg, PPI_vol,
                    PPI_vol_10_avg, PPI_vol_100_avg
                ]):
    plt, mdate, _, _, zhfont = base_common.ImportMatPlot()
    # current_date_int = int(base_common.CurrentDate())
    current_date_int = 20200403
    o_data_source = tushare_data.DataSource(current_date_int, '', '', 1,
                                            0, current_date_int)
    o_data_source.DownloadStockData(ts_code)
    o_data_source.UpdateStockPPData(ts_code)
    pp_data = o_data_source.LoadStockPPData(ts_code)
    preprocess.PPRegionCompute(pp_data, PPI_close, PPI_close_5_avg, 500,
                               len(pp_data), np.mean)
    stock_name = o_data_source.StockName(ts_code)
    if len(pp_data) == 0:
        return
    title = "%s - %s" % (ts_code, stock_name)
    title = unicode(title, "utf-8")
    fig1 = plt.figure(dpi=70, figsize=(32, 10))
    ax1 = fig1.add_subplot(1, 1, 1)
    ax1.xaxis.set_major_formatter(mdate.DateFormatter('%Y-%m-%d'))
    plt.title(title, fontproperties=zhfont)
    plt.xlabel('date')
    plt.ylabel('price')
    date_str_arr = np.array(["%.0f" % x for x in pp_data[:, PPI_trade_date]])
    xs = [datetime.strptime(d, '%Y%m%d').date() for d in date_str_arr]
    plt.grid(True)
    vol_list = [
        PPI_vol, PPI_vol_5_avg, PPI_vol_10_avg, PPI_vol_30_avg, PPI_vol_100_avg
    ]
    print(set(show_values) - set(vol_list))
    if len(set(show_values) - set(vol_list)) > 0:
        # Both price and volume columns are requested: rescale volume so the
        # largest volume value reaches half of the highest close price.
        close_max = max(pp_data[:, PPI_close])
        vol_max = max(pp_data[:, PPI_vol])
        vol_ratio = 1.0 / vol_max * close_max / 2
    else:
        vol_ratio = 1.0
    for col_index in show_values:
        name = ''
        if col_index in vol_list:
            # plt.bar(xs, pp_data[:, col_index] * vol_ratio, 0.8)
            plt.plot(xs, pp_data[:, col_index] * vol_ratio,
                     label=PPI_name[col_index], linewidth=1)
        else:
            plt.plot(xs, pp_data[:, col_index],
                     label=PPI_name[col_index], linewidth=1)
    plt.gcf().autofmt_xdate()
    plt.legend()
    # plt.ylim([0, 5])
    plt.show()
    plt.pause(1000)
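# --- Illustrative sketch (not part of the original source). ShowAStock_ above
# rescales volume so it can share the price axis: the largest volume value is
# mapped to half of the highest close price. The standalone snippet below shows
# the same scaling on synthetic data.
import numpy as np
import matplotlib.pyplot as plt

close = np.array([10.0, 10.5, 11.2, 10.8, 11.5])
vol = np.array([1.2e6, 3.4e6, 2.1e6, 4.8e6, 2.9e6])

vol_ratio = 1.0 / vol.max() * close.max() / 2  # same formula as in ShowAStock_
plt.plot(close, label='close', linewidth=1)
plt.plot(vol * vol_ratio, label='vol (scaled)', linewidth=1)
plt.legend()
plt.show()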
def main(argv):
    del argv
    o_data_source = tushare_data.DataSource(20000101, '', '', 1,
                                            20120101, 20200106,
                                            False, False, True)
    o_feature = feature.Feature(30, feature.FUT_D5_NORM, 1, False, False)
    # o_feature = feature.Feature(30, feature.FUT_D5_NORM, 1, False, False)
    # o_feature = feature.Feature(30, feature.FUT_5REGION5_NORM, 5, False, False)
    # o_feature = feature.Feature(30, feature.FUT_2AVG5_NORM, 5, False, False)
    o_wave = ExtremeWave(o_data_source, o_feature, 2, 2, False, 0, 0.1, 5)
    split_date = 20180101
    o_dl_model = dl_model.DLModel(
        '%s_%u' % (o_wave.setting_name, split_date),
        o_feature.feature_unit_num,
        o_feature.feature_unit_size,
        32, 10240, 0.004, 'mean_absolute_tp0_max_ratio_error')
    if FLAGS.mode == 'data':
        o_data_source.DownloadData()
        o_data_source.UpdatePPData()
    elif FLAGS.mode == 'testall':
        o_wave.TradeTestAll()
    elif FLAGS.mode == 'test':
        o_wave.TradeTestStock(FLAGS.c, FLAGS.show)
    elif FLAGS.mode == 'show':
        o_wave.ShowTradePP(FLAGS.c)
    elif FLAGS.mode == 'train':
        tf, tl, vf, vl, td = o_wave.GetDataset(split_date)
        tl = tl * 100.0
        vl = vl * 100.0
        o_dl_model.Train(tf, tl, vf, vl, FLAGS.epoch)
    elif FLAGS.mode == 'rtest':
        tf, tl, tf, tl, ta = o_wave.GetDataset(split_date)
        o_dl_model.LoadModel(FLAGS.epoch)
        o_wave.RTest(o_dl_model, tf, ta, False)
    exit()
def main(argv):
    del argv
    o_data_source = tushare_data.DataSource(20000101, '', '', 1,
                                            20000101, 20200403,
                                            False, False, True)
    o_feature = feature.Feature(30, feature.FUT_D5_NORM_PCT, 1, False, False)
    # o_feature = feature.Feature(30, feature.FUT_D5_NORM, 1, False, False)
    # o_feature = feature.Feature(30, feature.FUT_5REGION5_NORM, 5, False, False)
    # o_feature = feature.Feature(30, feature.FUT_2AVG5_NORM, 5, False, False)
    o_trade = Breakup(o_data_source, o_feature, PPI_close_100_avg, 10, 10, 3.0, 0.1)
    split_date = 20180101
    o_dl_model = dl_model.DLModel(
        '%s_%u' % (o_trade.setting_name, split_date),
        o_feature.feature_unit_num,
        o_feature.feature_unit_size,
        32, 10240, 0.004, 'mean_absolute_tp0_max_ratio_error')
    if FLAGS.mode == 'data':
        o_data_source.DownloadData()
        o_data_source.UpdatePPData()
    elif FLAGS.mode == 'testall':
        o_trade.TradeTestAll(True, FLAGS.show)
    elif FLAGS.mode == 'test':
        o_data_source.DownloadStockData(FLAGS.c)
        o_data_source.UpdateStockPPData(FLAGS.c)
        start_time = time.time()
        # o_trade.TradeTestStock(FLAGS.c, FLAGS.show)
        o_trade.Test(FLAGS.c)
    elif FLAGS.mode == 'train':
        tf, tl, vf, vl, td = o_trade.GetDataset(split_date)
        tl = tl * 100.0
        vl = vl * 100.0
        o_dl_model.Train(tf, tl, vf, vl, FLAGS.epoch)
    elif FLAGS.mode == 'rtest':
        tf, tl, tf, tl, ta = o_trade.GetDataset(split_date)
        o_dl_model.LoadModel(FLAGS.epoch)
        o_trade.RTest(o_dl_model, tf, ta, False)
    elif FLAGS.mode == 'dsw':
        dataset = o_trade.ShowDSW3DDataset()
    elif FLAGS.mode == 'show':
        dataset = o_trade.ShowTradePP(FLAGS.c)
    exit()
def main(argv):
    del argv
    o_data_source = tushare_data.DataSource(20000101, '', '', 1,
                                            20100101, 20200306,
                                            False, False, True)
    o_feature = feature.Feature(7, feature.FUT_D5_NORM, 1, False, False)
    o_vol_wave = VolWave(o_data_source, o_feature, 0.1)
    # NOTE: the 'train' and 'rtest' modes below reference split_date and
    # o_dl_model, so they require uncommenting the following lines first.
    # split_date = 20180101
    # o_dl_model = dl_model.DLModel('%s_%u' % (o_vol_wave.setting_name, split_date),
    #                               o_feature.feature_unit_num,
    #                               o_feature.feature_unit_size,
    #                               32, 10240, 0.004,
    #                               'mean_absolute_tp0_max_ratio_error')
    if FLAGS.mode == 'data':
        o_data_source.DownloadData()
        o_data_source.UpdatePPData()
    elif FLAGS.mode == 'testall':
        o_vol_wave.TradeTestAll(True, FLAGS.show)
    elif FLAGS.mode == 'test':
        o_data_source.DownloadStockData(FLAGS.c)
        o_data_source.UpdateStockPPData(FLAGS.c)
        start_time = time.time()
        o_vol_wave.TradeTestStock(FLAGS.c, FLAGS.show)
        print(time.time() - start_time)
    elif FLAGS.mode == 'train':
        tf, tl, vf, vl, td = o_vol_wave.GetDataset(split_date)
        tl = tl * 100.0
        vl = vl * 100.0
        o_dl_model.Train(tf, tl, vf, vl, FLAGS.epoch)
    elif FLAGS.mode == 'rtest':
        tf, tl, tf, tl, ta = o_vol_wave.GetDataset(split_date)
        o_dl_model.LoadModel(FLAGS.epoch)
        o_vol_wave.RTest(o_dl_model, tf, ta, False)
    elif FLAGS.mode == 'dsw':
        dataset = o_vol_wave.ShowDSW3DDataset()
    elif FLAGS.mode == 'show':
        dataset = o_vol_wave.ShowTradePP(FLAGS.c)
    exit()
            (self.code_index_map[ts_code], ts_code))

    def CreateDSFa3DDataset(self):
        dataset_file_name = self.FileNameDSFa3DDataset()
        self.dataset = np.zeros((len(self.date_list), len(self.code_list),
                                 self.feature.unit_size))
        # base_common.ListMultiThread(CreateDSFa3DSplitMTFunc, self, 1, self.code_list)
        for ts_code in self.code_list:
            self.CreateDSFa3DSplitStock(ts_code)
        base_common.MKFileDirs(dataset_file_name)
        np.save(dataset_file_name, self.dataset)
        self.dataset = None

    def GetDSFa3DDataset(self):
        dataset_file_name = self.FileNameDSFa3DDataset()
        if not os.path.exists(dataset_file_name):
            self.CreateDSFa3DDataset()
        return np.load(dataset_file_name)


if __name__ == "__main__":
    data_source = tushare_data.DataSource(20000101, '', '', 1,
                                          20000101, 20200106,
                                          False, False, True)
    data_source.ShowStockCodes()
    o_feature = feature.Feature(30, feature.FUT_D5_NORM, 1, False, False)
    o_dataset = DSFa3DDataset(data_source, o_feature)
    temp_dataset = o_dataset.GetDSFa3DDataset()
    print("dataset: {}".format(temp_dataset.shape))
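# --- Illustrative sketch (not part of the original source). GetDSFa3DDataset
# above uses a build-once cache: create the .npy file if it is missing, then
# always load from disk. The helper below isolates that pattern; the path and
# builder function are placeholders.
import os
import numpy as np

def load_or_build(cache_path, build_fn):
    """Return the array cached at cache_path, building and saving it if absent."""
    if not os.path.exists(cache_path):
        cache_dir = os.path.dirname(cache_path)
        if cache_dir and not os.path.isdir(cache_dir):
            os.makedirs(cache_dir)
        np.save(cache_path, build_fn())
    return np.load(cache_path)

# Example: cube = load_or_build('./cache/dsfa3d.npy',
#                               lambda: np.zeros((100, 50, 30)))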
def main(argv):
    del argv
    end_date = 20200306
    split_date = 20100101
    o_data_source = tushare_data.DataSource(20000101, '', '', 10,
                                            20000101, end_date,
                                            False, False, True)
    o_feature = feature.Feature(10, feature.FUT_D5_NORM, 1, False, False)
    obj = OpenClose(o_data_source, o_feature, not FLAGS.overlap_feature)
    o_dl_model = dl_model.DLModel(
        '%s_%u' % (obj.setting_name, split_date),
        o_feature.feature_unit_num,
        o_feature.feature_unit_size,
        # 32, 10240, 0.04, 'mean_absolute_tp0_max_ratio_error')  # rtest<0
        # 4, 10240, 0.04, 'mean_absolute_tp0_max_ratio_error')  # rtest<0
        # 4, 10240, 0.01, 'mean_absolute_tp0_max_ratio_error')  # rtest:0.14
        32, 10240, 0.03, 'mean_absolute_tp_max_ratio_error_tanhmap', 50)  # rtest:0.62
        # 16, 10240, 0.01, 'mean_absolute_tp0_max_ratio_error')  # rtest<0
        # 16, 10240, 0.01, 'mean_absolute_tp_max_ratio_error_tanhmap', 100)
    if FLAGS.mode == 'datasource':
        o_data_source.DownloadData()
        o_data_source.UpdatePPData()
    elif FLAGS.mode == 'dataset':
        obj.CreateDataSet()
    elif FLAGS.mode == 'public_dataset':
        obj.CreateDataSet()
        public_dataset = obj.PublicDataset()
        file_name = './public/data/dataset.npy'
        np.save(file_name, public_dataset)
    elif FLAGS.mode == 'train':
        tf, tl, vf, vl, td = obj.GetDataset(split_date)
        # tf, tl, vf, vl, va = obj.GetDatasetRandom(0.5)
        train_epoch = FLAGS.epoch if FLAGS.epoch > 0 else 250
        o_dl_model.Train(tf, tl, vf, vl, train_epoch)
    elif FLAGS.mode == 'rtest':
        tf, tl, vf, vl, va = obj.GetDataset(split_date)
        # tf, tl, vf, vl, va = obj.GetDatasetRandom(0.5)
        o_dl_model.LoadModel(FLAGS.epoch)
        obj.RTest(o_dl_model, vf, va, False)
    # elif FLAGS.mode == 'dqntest':
    #     o_dl_model.LoadModel(FLAGS.epoch)
    #     o_dsfa = dsfa3d_dataset.DSFa3DDataset(o_data_source, o_feature)
    #     o_dqn_test = dqn_test.DQNTest(o_dsfa, split_date, o_dl_model)
    #     o_dqn_test.Test(1, FLAGS.pt, True, FLAGS.show)
    # elif FLAGS.mode == 'dqntestall':
    #     o_dl_model.LoadModel(FLAGS.epoch)
    #     o_dsfa = dsfa3d_dataset.DSFa3DDataset(o_data_source, o_feature)
    #     o_dqn_test = dqn_test.DQNTest(o_dsfa, split_date, o_dl_model)
    #     o_dqn_test.TestAllModels(1, FLAGS.pt)
    # elif FLAGS.mode == 'predict':
    #     o_dl_model.LoadModel(FLAGS.epoch)
    #     o_data_source.SetPPDataDailyUpdate(20180101, 20200323)
    #     o_dsfa = dsfa3d_dataset.DSFa3DDataset(o_data_source, o_feature)
    #     o_dqn_test = dqn_test.DQNTest(o_dsfa, split_date, o_dl_model)
    #     o_dqn_test.Test(1, FLAGS.pt, True, FLAGS.show)
    elif FLAGS.mode == 'dsw':
        dataset = obj.ShowDSW3DDataset()
    elif FLAGS.mode == 'show':
        dataset = obj.ShowTradePP(FLAGS.c)
    elif FLAGS.mode == 'showlabel':
        dataset = obj.ShowLabel()
    elif FLAGS.mode == 'debug':
        dataset = np.load(obj.FileNameDataset())
        print("dataset: {}".format(dataset.shape))
        dataset = np_common.Sort2D(dataset, [obj.index_increase], [False])
        dataset = dataset[:5]
        obj.ShowDataSet(dataset, 'dataset')
    elif FLAGS.mode == 'clean':
        obj.Clean()
        o_dl_model.Clean()
    elif FLAGS.mode == 'pp':
        o_data_source.ShowStockPPData(FLAGS.c, FLAGS.date)
    elif FLAGS.mode == 'vol':
        o_data_source.ShowAvgVol(100000)
    exit()
def main(argv):
    del argv
    code_filter = (
        '000001.SZ,000002.SZ,000063.SZ,000538.SZ,000541.SZ,'
        '000550.SZ,000560.SZ,000561.SZ,000584.SZ,000625.SZ,'
        '000650.SZ,000651.SZ,000721.SZ,000800.SZ,000802.SZ,'
        '000858.SZ,000898.SZ,000913.SZ,000927.SZ,000932.SZ,'
        '000937.SZ,000938.SZ,000951.SZ,000959.SZ,001696.SZ,'
        '600000.SH,600006.SH,600085.SH,600104.SH,600109.SH,'
        '600115.SH,600137.SH,600177.SH,600198.SH,600199.SH,'
        '600600.SH,600601.SH,600609.SH,600612.SH,600623.SH,'
        '600624.SH,600664.SH,600679.SH,600702.SH,600718.SH,'
        '600809.SH')
    o_data_source = tushare_data.DataSource(20000101, '', code_filter, 1,
                                            20130101, 20200106,
                                            False, False, True)
    o_feature = feature.Feature(30, feature.FUT_D5_NORM, 1, False, False)
    # o_feature = feature.Feature(30, feature.FUT_D5_NORM, 1, False, False)
    # o_feature = feature.Feature(30, feature.FUT_5REGION5_NORM, 5, False, False)
    # o_feature = feature.Feature(30, feature.FUT_2AVG5_NORM, 5, False, False)
    o_avg_wave = AvgWave(o_data_source, o_feature, PPI_close_30_avg,
                         MODE_GRAD, 0, 0, 0.1)
    split_date = 20180101
    o_dl_model = dl_model.DLModel(
        '%s_%u' % (o_avg_wave.setting_name, split_date),
        o_feature.feature_unit_num,
        o_feature.feature_unit_size,
        32, 10240, 0.004, 'mean_absolute_tp0_max_ratio_error')
    if FLAGS.mode == 'data':
        o_data_source.DownloadData()
        o_data_source.UpdatePPData()
    elif FLAGS.mode == 'testall':
        o_avg_wave.TradeTestAll(True, FLAGS.show)
    elif FLAGS.mode == 'test':
        o_data_source.DownloadStockData(FLAGS.c)
        o_data_source.UpdateStockPPData(FLAGS.c)
        start_time = time.time()
        o_avg_wave.TradeTestStock(FLAGS.c, FLAGS.show)
        print(time.time() - start_time)
    elif FLAGS.mode == 'train':
        tf, tl, vf, vl, td = o_avg_wave.GetDataset(split_date)
        tl = tl * 100.0
        vl = vl * 100.0
        o_dl_model.Train(tf, tl, vf, vl, FLAGS.epoch)
    elif FLAGS.mode == 'rtest':
        tf, tl, tf, tl, ta = o_avg_wave.GetDataset(split_date)
        o_dl_model.LoadModel(FLAGS.epoch)
        o_avg_wave.RTest(o_dl_model, tf, ta, False)
    elif FLAGS.mode == 'dsw':
        dataset = o_avg_wave.ShowDSW3DDataset()
    elif FLAGS.mode == 'show':
        dataset = o_avg_wave.ShowTradePP(FLAGS.c)
    exit()
def main(argv):
    del argv
    # NOTE: code_filter is defined here but not passed to DataSource below
    # (the code-filter argument is '').
    code_filter = (
        '000001.SZ,000002.SZ,000005.SZ,000006.SZ,000009.SZ,000012.SZ,'
        '000016.SZ,000021.SZ,000027.SZ,000031.SZ,000036.SZ,000039.SZ,'
        '000050.SZ,000059.SZ,000060.SZ,000063.SZ,000066.SZ,000069.SZ,'
        '000078.SZ,000088.SZ,000089.SZ,000401.SZ,000402.SZ,000410.SZ,'
        '000413.SZ,000422.SZ,000425.SZ,000503.SZ,000507.SZ,000510.SZ,'
        '000518.SZ,000520.SZ,000528.SZ,000540.SZ,000554.SZ,000559.SZ,'
        '000563.SZ,000571.SZ,000572.SZ,000592.SZ,000598.SZ,000601.SZ,'
        '000616.SZ,000625.SZ,000627.SZ,000629.SZ,000630.SZ,000636.SZ,'
        '000650.SZ,000651.SZ,000652.SZ,000656.SZ,000659.SZ,000667.SZ,'
        '000670.SZ,000680.SZ,000682.SZ,000683.SZ,000686.SZ,000690.SZ,'
        '000709.SZ,000717.SZ,000718.SZ,000720.SZ,000723.SZ,000727.SZ,'
        '000728.SZ,000735.SZ,000750.SZ,000751.SZ,000758.SZ,000767.SZ,'
        '000768.SZ,000776.SZ,000778.SZ,000783.SZ,000789.SZ,000793.SZ,'
        '000795.SZ,000800.SZ,000806.SZ,000807.SZ,000816.SZ,000822.SZ,'
        '000823.SZ,000825.SZ,000829.SZ,000830.SZ,000831.SZ,000839.SZ,'
        '000851.SZ,000858.SZ,000859.SZ,000868.SZ,000876.SZ,000877.SZ,'
        '000878.SZ,000882.SZ,000886.SZ,000897.SZ,000898.SZ,000917.SZ,'
        '000927.SZ,000930.SZ,000931.SZ,000932.SZ,000933.SZ,000937.SZ,'
        '000959.SZ,600000.SH,600006.SH,600060.SH,600063.SH,600067.SH,'
        '600068.SH,600069.SH,600078.SH,600089.SH,600100.SH,600103.SH,'
        '600104.SH,600108.SH,600109.SH,600110.SH,600111.SH,600115.SH,'
        '600122.SH,600123.SH,600125.SH,600127.SH,600151.SH,600153.SH,'
        '600157.SH,600158.SH,600160.SH,600162.SH,600166.SH,600169.SH,'
        '600170.SH,600171.SH,600175.SH,600176.SH,600177.SH,600183.SH,'
        '600186.SH,600187.SH,600188.SH,600196.SH,600198.SH,600200.SH,'
        '600206.SH,600208.SH,600209.SH,600210.SH,600212.SH,600216.SH,'
        '600219.SH,600220.SH,600221.SH,600226.SH,600239.SH,600266.SH,'
        '600601.SH,600606.SH,600609.SH,600611.SH,600624.SH,600630.SH,'
        '600635.SH,600642.SH,600643.SH,600649.SH,600651.SH,600652.SH,'
        '600653.SH,600660.SH,600662.SH,600664.SH,600667.SH,600673.SH,'
        '600674.SH,600675.SH,600677.SH,600688.SH,600690.SH,600703.SH,'
        '600704.SH,600705.SH,600711.SH,600717.SH,600718.SH,600720.SH,'
        '600736.SH,600737.SH,600739.SH,600740.SH,600741.SH,600755.SH,'
        '600759.SH,600770.SH,600776.SH,600777.SH,600782.SH,600787.SH,'
        '600789.SH,600795.SH,600797.SH,600800.SH,600804.SH,600805.SH,'
        '600808.SH,600811.SH,600812.SH,600816.SH,600820.SH,600837.SH,'
        '600839.SH,600846.SH,600851.SH,600863.SH,600868.SH,600871.SH,'
        '600872.SH,600873.SH,600874.SH,600875.SH,600877.SH,600879.SH,'
        '600881.SH,600886.SH,600887.SH,600895.SH,601607.SH')
    # end_date = 20200106
    end_date = 20200306
    # end_date = 20190221
    split_date = 20100101
    o_data_source = tushare_data.DataSource(20000101, '', '', 2,
                                            20000101, end_date,
                                            False, False, True)
    # o_feature = feature.Feature(30, feature.FUT_D5_NORM, 1, False, False)
    o_feature = feature.Feature(20, feature.FUT_D5_NORM, 1, False, False)
    # o_feature = feature.Feature(7, feature.FUT_5REGION5_NORM, 5, False, False)
    # o_feature = feature.Feature(30, feature.FUT_5REGION5_NORM, 5, False, False)
    # o_feature = feature.Feature(30, feature.FUT_D3_NORM, 1, False, False)
    # o_dqn_fix = DQNFix(o_data_source, o_feature, 6, DECAY_MODE_EXP, 0.6,
    #                    not FLAGS.overlap_feature)
    o_dqn_fix = DQNFix(o_data_source, o_feature, 20, False,
                       DECAY_MODE_EXP, 0.9, not FLAGS.overlap_feature)
    # o_dqn_fix = DQNFix(o_data_source, o_feature, 30, DECAY_MODE_EXP, 0.9,
    #                    not FLAGS.overlap_feature)
    o_dl_model = dl_model.DLModel(
        '%s_%u' % (o_dqn_fix.setting_name, split_date),
        o_feature.feature_unit_num,
        o_feature.feature_unit_size,
        # 32, 10240, 0.04, 'mean_absolute_tp0_max_ratio_error')  # rtest<0
        # 4, 10240, 0.04, 'mean_absolute_tp0_max_ratio_error')  # rtest<0
        # 4, 10240, 0.01, 'mean_absolute_tp0_max_ratio_error')  # rtest:0.14
        0, 10240, 0.03, 'mean_absolute_tp_max_ratio_error_tanhmap', 50)  # rtest:0.62
        # 16, 10240, 0.01, 'mean_absolute_tp0_max_ratio_error')  # rtest<0
        # 16, 10240, 0.01, 'mean_absolute_tp_max_ratio_error_tanhmap', 100)
    if FLAGS.mode == 'datasource':
        o_data_source.DownloadData()
        o_data_source.UpdatePPData()
    elif FLAGS.mode == 'dataset':
        o_dqn_fix.CreateDataSet()
    elif FLAGS.mode == 'public_dataset':
        o_dqn_fix.CreateDataSet()
        public_dataset = o_dqn_fix.PublicDataset()
        file_name = './public/data/dataset.npy'
        np.save(file_name, public_dataset)
    elif FLAGS.mode == 'train':
        tf, tl, vf, vl, td = o_dqn_fix.GetDataset(split_date)
        # tf, tl, vf, vl, va = o_dqn_fix.GetDatasetRandom(0.5)
        train_epoch = FLAGS.epoch if FLAGS.epoch > 0 else 250
        o_dl_model.Train(tf, tl, vf, vl, train_epoch)
    elif FLAGS.mode == 'rtest':
        tf, tl, vf, vl, va = o_dqn_fix.GetDataset(split_date)
        # tf, tl, vf, vl, va = o_dqn_fix.GetDatasetRandom(0.5)
        o_dl_model.LoadModel(FLAGS.epoch)
        o_dqn_fix.RTest(o_dl_model, vf, va, False)
    elif FLAGS.mode == 'dqntest':
        o_dl_model.LoadModel(FLAGS.epoch)
        o_dsfa = dsfa3d_dataset.DSFa3DDataset(o_data_source, o_feature)
        o_dqn_test = dqn_test.DQNTest(o_dsfa, split_date, o_dl_model)
        o_dqn_test.Test(1, FLAGS.pt, True, FLAGS.show)
    elif FLAGS.mode == 'dqntestall':
        o_dl_model.LoadModel(FLAGS.epoch)
        o_dsfa = dsfa3d_dataset.DSFa3DDataset(o_data_source, o_feature)
        o_dqn_test = dqn_test.DQNTest(o_dsfa, split_date, o_dl_model)
        o_dqn_test.TestAllModels(1, FLAGS.pt)
    elif FLAGS.mode == 'predict':
        o_dl_model.LoadModel(FLAGS.epoch)
        o_data_source.SetPPDataDailyUpdate(20180101, 20200323)
        o_dsfa = dsfa3d_dataset.DSFa3DDataset(o_data_source, o_feature)
        o_dqn_test = dqn_test.DQNTest(o_dsfa, split_date, o_dl_model)
        o_dqn_test.Test(1, FLAGS.pt, True, FLAGS.show)
    elif FLAGS.mode == 'dsw':
        dataset = o_dqn_fix.ShowDSW3DDataset()
    elif FLAGS.mode == 'show':
        dataset = o_dqn_fix.ShowTradePP(FLAGS.c)
    elif FLAGS.mode == 'showlabel':
        dataset = o_dqn_fix.ShowLabel()
    elif FLAGS.mode == 'debug':
        dataset = np.load(o_dqn_fix.FileNameDataset())
        print("dataset: {}".format(dataset.shape))
        dataset = np_common.Sort2D(dataset, [o_dqn_fix.index_increase], [False])
        dataset = dataset[:5]
        o_dqn_fix.ShowDataSet(dataset, 'dataset')
    elif FLAGS.mode == 'clean':
        o_dqn_fix.Clean()
        o_dl_model.Clean()
    elif FLAGS.mode == 'pp':
        o_data_source.ShowStockPPData(FLAGS.c, FLAGS.date)
    elif FLAGS.mode == 'vol':
        o_data_source.ShowAvgVol(100000)
    exit()