def get_dataset(device):
    # with open(params.api_dir, "rb") as fh:
    #     api = pkl.load(fh, encoding='latin1')
    api = SWDADialogCorpus(params.data_dir)
    dial_corpus = api.get_dialog_corpus()
    train_dial, test_dial = dial_corpus.get("train"), dial_corpus.get("test")

    # convert to numeric input outputs
    train_loader = SWDADataLoader("Train", train_dial, params.max_utt_len,
                                  params.max_dialog_len, device=device)
    valid_loader = test_loader = SWDADataLoader("Test", test_dial, params.max_utt_len,
                                                params.max_dialog_len, device=device)
    if api.word2vec is not None:
        return train_loader, valid_loader, test_loader, np.array(api.word2vec)
    else:
        return train_loader, valid_loader, test_loader, None
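
# A minimal sketch (not part of the repo) of how the word2vec matrix returned
# above can be copied into a torch embedding, mirroring what the Torch port
# later in this file does; the function name is illustrative.
import numpy as np
import torch
import torch.nn as nn

def load_word2vec_into_embedding(embedding, word2vec):
    """Copy a pretrained (vocab_size, embed_size) matrix into an nn.Embedding
    and zero row 0, assumed here to be the PAD token."""
    assert tuple(embedding.weight.shape) == word2vec.shape
    with torch.no_grad():
        embedding.weight.copy_(torch.from_numpy(word2vec))
        embedding.weight[0].fill_(0)
    return embedding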
def main():
    # config for training
    config = Config()
    config.use_bow = False

    # config for validation
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60
    valid_config.use_bow = False

    # configuration for testing
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1
    test_config.use_bow = False

    pp(config)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir, word2vec=FLAGS.word2vec_path,
                           word2vec_dim=config.embed_size)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    train_meta, valid_meta, test_meta = meta_corpus.get("train"), meta_corpus.get("valid"), meta_corpus.get("test")
    train_dial, valid_dial, test_dial = dial_corpus.get("train"), dial_corpus.get("valid"), dial_corpus.get("test")

    # convert to numeric inputs/outputs that fit into TF models
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    # begin training
    sess_config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
    # sess_config.gpu_options.allow_growth = True
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.45
    with tf.Session(config=sess_config) as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w, config.init_w)
        scope = "model"
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = KgRnnCVAE(sess, config, api,
                              log_dir=None if FLAGS.forward_only else log_dir,
                              forward=False, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = KgRnnCVAE(sess, valid_config, api, log_dir=None,
                                    forward=False, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            test_model = KgRnnCVAE(sess, test_config, api, log_dir=None,
                                   forward=True, scope=scope)
            test_model.prepare_mul_ref()

        logger.info("Created computation graphs")
        if api.word2vec is not None and not FLAGS.forward_only:
            logger.info("Loaded word2vec")
            sess.run(model.embedding.assign(np.array(api.word2vec)))

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "run.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        logger.info("Created models with fresh parameters.")
        sess.run(tf.global_variables_initializer())

        if ckpt:
            logger.info("Reading dm models parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)

        if not FLAGS.forward_only:
            dm_checkpoint_path = os.path.join(ckp_dir, model.__class__.__name__ + ".ckpt")
            global_t = 1
            patience = 10  # wait for at least 10 epochs before stopping
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf
            for epoch in range(config.max_epoch):
                logger.info(">> Epoch %d with lr %f"
                            % (epoch, sess.run(model.learning_rate_cyc, {model.global_t: global_t})))

                # begin training
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size, config.backward_size,
                                          config.step_size, shuffle=True)
                global_t, train_loss = model.train(global_t, sess, train_feed,
                                                   update_limit=config.update_limit)

                # begin validation
                logger.record_tabular("Epoch", epoch)
                logger.record_tabular("Mode", "Val")
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)
                valid_loss = valid_model.valid("ELBO_VALID", sess, valid_feed)

                logger.record_tabular("Epoch", epoch)
                logger.record_tabular("Mode", "Test")
                test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                     valid_config.step_size, shuffle=False, intra_shuffle=False)
                valid_model.valid("ELBO_TEST", sess, test_feed)

                # test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                #                      test_config.step_size, shuffle=True, intra_shuffle=False)
                # test_model.test_mul_ref(sess, test_feed, num_batch=5)

                done_epoch = epoch + 1
                # only save a model if the dev loss is smaller
                # Decrease learning rate if no improvement was seen over the last 3 times.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    sess.run(model.learning_rate_decay_op)

                if valid_loss < best_dev_loss:
                    if valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience, done_epoch * config.patient_increase)
                        dev_loss_threshold = valid_loss

                    # still save the best train model
                    if FLAGS.save_model:
                        logger.info("Save model!!")
                        model.saver.save(sess, dm_checkpoint_path)
                    best_dev_loss = valid_loss

                if (epoch % 3) == 2:
                    tmp_model_path = os.path.join(ckp_dir, model.__class__.__name__ + str(epoch) + ".ckpt")
                    model.saver.save(sess, tmp_model_path)

                if config.early_stop and patience <= done_epoch:
                    logger.info("!!Early stop due to run out of patience!!")
                    break
            logger.info("Best validation loss %f" % best_dev_loss)
            logger.info("Done training")
        else:
            # begin validation
            valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            valid_model.valid("ELBO_VALID", sess, valid_feed)

            test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                 valid_config.step_size, shuffle=False, intra_shuffle=False)
            valid_model.valid("ELBO_TEST", sess, test_feed)

            dest_f = open(os.path.join(log_dir, "test.txt"), "wb")
            test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                 test_config.step_size, shuffle=False, intra_shuffle=False)
            test_model.test_mul_ref(sess, test_feed, num_batch=None, repeat=5, dest=dest_f)
            dest_f.close()
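
# A minimal sketch (assumption: not part of the repo) isolating the patience
# bookkeeping the training loops in this file share, so the stopping rule is
# easier to read. `improve_threshold` (< 1) and `patient_increase` (> 1)
# mirror the Config fields referenced in main(); the defaults are illustrative.
import numpy as np

def update_patience(valid_loss, state, done_epoch,
                    improve_threshold=0.996, patient_increase=2.0):
    """state holds 'patience', 'threshold', 'best'; returns True to early-stop."""
    if valid_loss < state['best']:
        if valid_loss <= state['threshold'] * improve_threshold:
            # large enough improvement: extend the patience horizon
            state['patience'] = max(state['patience'], done_epoch * patient_increase)
            state['threshold'] = valid_loss
        state['best'] = valid_loss
    return state['patience'] <= done_epoch

# usage mirroring the loop above:
# state = {'patience': 10, 'threshold': np.inf, 'best': np.inf}
# stop = update_patience(valid_loss, state, epoch + 1)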
def main():
    # config for training
    config = Config()
    print("Normal train config:")
    pp(config)

    valid_config = Config()
    valid_config.dropout = 0
    valid_config.batch_size = 20

    # config for test
    test_config = Config()
    test_config.dropout = 0
    test_config.batch_size = 1

    with_sentiment = config.with_sentiment

    ###############################################################################
    # Load data
    ###############################################################################
    # sentiment data path: ../final_data/poem_with_sentiment.txt
    # this path must be passed to LoadPoem explicitly on the command line, since its default is None
    # handles both the pretraining data and the full poem data
    # api = LoadPoem(args.train_data_dir, args.test_data_dir, args.max_vocab_size)
    api = LoadPoem(corpus_path=args.train_data_dir, test_path=args.test_data_dir,
                   max_vocab_cnt=config.max_vocab_cnt, with_sentiment=with_sentiment)

    # alternating training: prepare the large corpus
    poem_corpus = api.get_tokenized_poem_corpus(type=1 + int(with_sentiment))  # corpus for training and validation
    test_data = api.get_tokenized_test_corpus()  # test data
    # three lists; each element of every list is [topic, last_sentence, current_sentence]
    train_poem, valid_poem, test_poem = poem_corpus["train"], poem_corpus["valid"], test_data["test"]

    train_loader = SWDADataLoader("Train", train_poem, config)
    valid_loader = SWDADataLoader("Valid", valid_poem, config)
    test_loader = SWDADataLoader("Test", test_poem, config)

    print("Finish Poem data loading, not pretraining or alignment test")

    if not args.forward_only:
        # LOG #
        # log_start_time = str(datetime.now().strftime('%Y%m%d%H%M'))
        if not os.path.isdir('./output'):
            os.makedirs('./output')
        if not os.path.isdir('./output/{}'.format(args.expname)):
            os.makedirs('./output/{}'.format(args.expname))
        if not os.path.isdir('./output/{}/{}'.format(args.expname, log_start_time)):
            os.makedirs('./output/{}/{}'.format(args.expname, log_start_time))

        # save arguments
        json.dump(vars(args),
                  open('./output/{}/{}/args.json'.format(args.expname, log_start_time), 'w'))

        logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.DEBUG, format="%(message)s")
        fh = logging.FileHandler("./output/{}/{}/logs.txt".format(args.expname, log_start_time))
        # add the handlers to the logger
        logger.addHandler(fh)
        logger.info(vars(args))

        tb_writer = SummaryWriter("./output/{}/{}/tb_logs".format(
            args.expname, log_start_time)) if args.visual else None

        if config.reload_model:
            model = load_model(config.model_name)
        else:
            if args.model == "mCVAE":
                model = CVAE_GMP(config=config, api=api)
            elif args.model == 'CVAE':
                model = CVAE(config=config, api=api)
            else:
                model = Seq2Seq(config=config, api=api)
            if use_cuda:
                model = model.cuda()

        # if corpus.word2vec is not None and args.reload_from < 0:
        #     print("Loaded word2vec")
        #     model.embedder.weight.data.copy_(torch.from_numpy(corpus.word2vec))
        #     model.embedder.weight.data[0].fill_(0)

        ###############################################################################
        # Start training
        ###############################################################################
        # the model is still PoemWAE_GMP, unchanged; this data only force-trains one of its Gaussian priors
        # pretrain = True
        cur_best_score = {
            'min_valid_loss': 100,
            'min_global_itr': 0,
            'min_epoch': 0,
            'min_itr': 0
        }

        train_loader.epoch_init(config.batch_size, shuffle=True)

        # model = load_model(3, 3)
        epoch_id = 0
        global_t = 0
        while epoch_id < config.epochs:
            while True:  # loop through all batches in training data
                # train one batch
                model, finish_train, loss_records, global_t = \
                    train_process(global_t=global_t, model=model, train_loader=train_loader,
                                  config=config, sentiment_data=with_sentiment)
                if finish_train:
                    test_process(model=model, test_loader=test_loader,
                                 test_config=test_config, logger=logger)
                    # evaluate_process(model=model, valid_loader=valid_loader,
                    #                  log_start_time=log_start_time, global_t=global_t,
                    #                  epoch=epoch_id, logger=logger, tb_writer=tb_writer, api=api)
                    # save model after each epoch
                    save_model(model=model, epoch=epoch_id, global_t=global_t,
                               log_start_time=log_start_time)
                    logger.info('Finish epoch %d, current min valid loss: %.4f '
                                'correspond global itr: %d epoch: %d itr: %d \n\n'
                                % (epoch_id,
                                   cur_best_score['min_valid_loss'],
                                   cur_best_score['min_global_itr'],
                                   cur_best_score['min_epoch'],
                                   cur_best_score['min_itr']))
                    # initialize training for the next unlabeled-data epoch
                    # unlabeled_epoch += 1
                    epoch_id += 1
                    train_loader.epoch_init(config.batch_size, shuffle=True)
                    break
                # elif batch_idx >= start_batch + config.n_batch_every_iter:
                #     print("Finish unlabel epoch %d batch %d to %d" %
                #           (unlabeled_epoch, start_batch, start_batch + config.n_batch_every_iter))
                #     start_batch += config.n_batch_every_iter
                #     break

                # write logs
                if global_t % config.log_every == 0:
                    log = 'Epoch id %d: step: %d/%d: ' \
                          % (epoch_id, global_t % train_loader.num_batch, train_loader.num_batch)
                    for loss_name, loss_value in loss_records:
                        if loss_name == 'avg_lead_loss':
                            continue
                        log = log + loss_name + ':%.4f ' % loss_value
                        if args.visual:
                            tb_writer.add_scalar(loss_name, loss_value, global_t)
                    logger.info(log)

                # valid
                if global_t % config.valid_every == 0:
                    # test_process(model=model, test_loader=test_loader, test_config=test_config, logger=logger)
                    valid_process(global_t=global_t, model=model, valid_loader=valid_loader,
                                  valid_config=valid_config,
                                  unlabeled_epoch=epoch_id,  # if sample_rate_unlabeled is not 1, append a 1 at the end here
                                  tb_writer=tb_writer, logger=logger,
                                  cur_best_score=cur_best_score)

                # if batch_idx % (train_loader.num_batch // 3) == 0:
                #     test_process(model=model, test_loader=test_loader, test_config=test_config, logger=logger)
                if global_t % config.test_every == 0:
                    test_process(model=model, test_loader=test_loader,
                                 test_config=test_config, logger=logger)

    # forward_only: testing
    else:
        expname = 'sentInput'
        time = '202101191105'
        model = load_model('./output/{}/{}/model_global_t_13596_epoch3.pckl'.format(expname, time))
        test_loader.epoch_init(test_config.batch_size, shuffle=False)
        if not os.path.exists('./output/{}/{}/test/'.format(expname, time)):
            os.mkdir('./output/{}/{}/test/'.format(expname, time))
        output_file = [
            open('./output/{}/{}/test/output_0.txt'.format(expname, time), 'w'),
            open('./output/{}/{}/test/output_1.txt'.format(expname, time), 'w'),
            open('./output/{}/{}/test/output_2.txt'.format(expname, time), 'w')
        ]
        poem_count = 0
        predict_results = {0: [], 1: [], 2: []}
        titles = {0: [], 1: [], 2: []}
        sentiment_result = {0: [], 1: [], 2: []}

        # Get all poem predictions
        while True:
            model.eval()
            batch = test_loader.next_batch_test()  # test data uses its own batch format
            poem_count += 1
            if poem_count % 10 == 0:
                print("Predicted {} poems".format(poem_count))
            if batch is None:
                break
            title_list = batch  # batch size is 1: one batch generates one poem
            title_tensor = to_tensor(title_list)
            # test() decodes the poem for the current batch; each decoding step's
            # input context is the previous step's output
            for i in range(3):
                sentiment_label = np.zeros(1, dtype=np.int64)
                sentiment_label[0] = int(i)
                sentiment_label = to_tensor(sentiment_label)
                output_poem, output_tokens = model.test(title_tensor, title_list,
                                                        sentiment_label=sentiment_label)
                titles[i].append(output_poem.strip().split('\n')[0])
                predict_results[i] += np.array(output_tokens)[:, :7].tolist()

        # Predict sentiment using the sort net
        from collections import defaultdict
        neg = defaultdict(int)
        neu = defaultdict(int)
        pos = defaultdict(int)
        total = defaultdict(int)
        for i in range(3):
            _, neg[i], neu[i], pos[i] = test_sentiment(predict_results[i])
            total[i] = neg[i] + neu[i] + pos[i]
        for i in range(3):
            print("%d%%\t%d%%\t%d%%" % (neg[i] * 100 / total[i],
                                        neu[i] * 100 / total[i],
                                        pos[i] * 100 / total[i]))
        for i in range(3):
            write_predict_result_to_file(titles[i], predict_results[i],
                                         sentiment_result[i], output_file[i])
            output_file[i].close()

        print("Done testing")
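
# Hedged sketch (assumed helper, not shown in this excerpt): a to_tensor() of
# the kind this script calls on title lists and sentiment labels, moving the
# array onto the GPU when one is available.
import numpy as np
import torch

def to_tensor(array):
    tensor = torch.from_numpy(np.asarray(array))
    return tensor.cuda() if torch.cuda.is_available() else tensor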
def main():
    # config for training
    config = Config()

    # config for validation
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60

    # configuration for testing
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1

    pp(config)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir, word2vec=FLAGS.word2vec_path,
                           word2vec_dim=config.embed_size)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    train_meta, valid_meta, test_meta = meta_corpus.get("train"), meta_corpus.get("valid"), meta_corpus.get("test")
    train_dial, valid_dial, test_dial = dial_corpus.get("train"), dial_corpus.get("valid"), dial_corpus.get("test")

    # convert to numeric inputs/outputs that fit the models
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    if FLAGS.forward_only or FLAGS.resume:
        log_dir = os.path.join(FLAGS.work_dir, FLAGS.test_path)
    else:
        log_dir = os.path.join(FLAGS.work_dir, "run" + str(int(time.time())))

    # begin training
    if True:
        scope = "model"
        model = KgRnnCVAE(config, api,
                          log_dir=None if FLAGS.forward_only else log_dir,
                          scope=scope)
        print("Created computation graphs")

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "run.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = get_checkpoint_state(ckp_dir)
        print("Created models with fresh parameters.")
        model.apply(lambda m: [
            torch.nn.init.uniform_(p.data, -1.0 * config.init_w, config.init_w)
            for p in m.parameters()
        ])

        # Load word2vec weight
        if api.word2vec is not None and not FLAGS.forward_only:
            print("Loaded word2vec")
            model.embedding.weight.data.copy_(torch.from_numpy(np.array(api.word2vec)))
            model.embedding.weight.data[0].fill_(0)

        if ckpt:
            print("Reading dm models parameters from %s" % ckpt)
            model.load_state_dict(torch.load(ckpt))

        # move to cuda
        model.cuda()

        if not FLAGS.forward_only:
            dm_checkpoint_path = os.path.join(ckp_dir, model.__class__.__name__ + "-%d.pth")
            global_t = 1
            patience = 10  # wait for at least 10 epochs before stopping
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf
            for epoch in range(config.max_epoch):
                print(">> Epoch %d with lr %f" % (epoch, model.learning_rate))

                # begin training
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size, config.backward_size,
                                          config.step_size, shuffle=True)
                global_t, train_loss = model.train_model(global_t, train_feed,
                                                         update_limit=config.update_limit)

                # begin validation
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)
                model.eval()
                valid_loss = model.valid_model("ELBO_VALID", valid_feed)

                test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                     test_config.step_size, shuffle=True, intra_shuffle=False)
                model.test_model(test_feed, num_batch=5)
                model.train()

                done_epoch = epoch + 1
                # only save a model if the dev loss is smaller
                # Decrease learning rate if no improvement was seen over the last 3 times.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    model.learning_rate_decay()

                if valid_loss < best_dev_loss:
                    if valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience, done_epoch * config.patient_increase)
                        dev_loss_threshold = valid_loss

                    # still save the best train model
                    if FLAGS.save_model:
                        print("Save model!!")
                        torch.save(model.state_dict(), dm_checkpoint_path % epoch)
                    best_dev_loss = valid_loss

                if config.early_stop and patience <= done_epoch:
                    print("!!Early stop due to run out of patience!!")
                    break
            print("Best validation loss %f" % best_dev_loss)
            print("Done training")
        else:
            # begin validation
            valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            model.eval()
            model.valid_model("ELBO_VALID", valid_feed)

            test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                 valid_config.step_size, shuffle=False, intra_shuffle=False)
            model.valid_model("ELBO_TEST", test_feed)

            dest_f = open(os.path.join(log_dir, "test.txt"), "wb")
            test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                 test_config.step_size, shuffle=False, intra_shuffle=False)
            model.test_model(test_feed, num_batch=None, repeat=10, dest=dest_f)
            model.train()
            dest_f.close()
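
# Hedged sketch (assumption): what a learning_rate_decay() of the kind the
# Torch port above calls might do; the real method lives on the model class
# and is not shown in this excerpt. The decay rate is illustrative.
def learning_rate_decay(optimizer, decay_rate=0.6):
    # scale the learning rate of every parameter group in place
    for group in optimizer.param_groups:
        group['lr'] *= decay_rate
    return optimizer.param_groups[0]['lr']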
def main():
    # config for training
    config = Config()

    # config for validation
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60

    # configuration for testing
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1

    pp(config)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir, word2vec=FLAGS.word2vec_path,
                           word2vec_dim=config.embed_size)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()
    # an item in meta looks like:
    # ([0.46, 0.6666666666666666, 1, 0], [0.48, 0.6666666666666666, 1, 0], 29)
    train_meta, valid_meta, test_meta = meta_corpus.get("train"), meta_corpus.get("valid"), meta_corpus.get("test")
    # an item in dial looks like: ([utt_id, ...], speaker_id, [dialogact_id, [...]])
    train_dial, valid_dial, test_dial = dial_corpus.get("train"), dial_corpus.get("valid"), dial_corpus.get("test")

    # convert to numeric inputs/outputs that fit into TF models
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    if FLAGS.forward_only or FLAGS.resume:
        log_dir = os.path.join(FLAGS.work_dir, FLAGS.test_path)
    else:
        log_dir = os.path.join(FLAGS.work_dir, "run" + str(int(time.time())))

    # begin training
    with tf.Session() as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w, config.init_w)
        scope = "model"
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = KgRnnCVAE(sess, config, api,
                              log_dir=None if FLAGS.forward_only else log_dir,
                              forward=False, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = KgRnnCVAE(sess, valid_config, api, log_dir=None,
                                    forward=False, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            test_model = KgRnnCVAE(sess, test_config, api, log_dir=None,
                                   forward=True, scope=scope)
        print("Created computation graphs")

        if api.word2vec is not None and not FLAGS.forward_only:
            print("Loaded word2vec")
            sess.run(model.embedding.assign(np.array(api.word2vec)))

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "run.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        print("Created models with fresh parameters.")
        sess.run(tf.global_variables_initializer())

        if ckpt:
            print("Reading dm models parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)

        if not FLAGS.forward_only:
            dm_checkpoint_path = os.path.join(ckp_dir, model.__class__.__name__ + ".ckpt")
            global_t = 1
            patience = 10  # wait for at least 10 epochs before stopping
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf
            for epoch in range(config.max_epoch):
                print(">> Epoch %d with lr %f" % (epoch, model.learning_rate.eval()))

                # begin training
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size, config.backward_size,
                                          config.step_size, shuffle=True)
                global_t, train_loss = model.train(global_t, sess, train_feed,
                                                   update_limit=config.update_limit)

                # begin validation
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)
                valid_loss = valid_model.valid("ELBO_VALID", sess, valid_feed)

                test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                     test_config.step_size, shuffle=True, intra_shuffle=False)
                test_model.test(sess, test_feed, num_batch=5)

                done_epoch = epoch + 1
                # only save a model if the dev loss is smaller
                # Decrease learning rate if no improvement was seen over the last 3 times.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    sess.run(model.learning_rate_decay_op)

                if valid_loss < best_dev_loss:
                    if valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience, done_epoch * config.patient_increase)
                        dev_loss_threshold = valid_loss

                    # still save the best train model
                    if FLAGS.save_model:
                        print("Save model!!")
                        model.saver.save(sess, dm_checkpoint_path, global_step=epoch)
                    best_dev_loss = valid_loss

                if config.early_stop and patience <= done_epoch:
                    print("!!Early stop due to run out of patience!!")
                    break
            print("Best validation loss %f" % best_dev_loss)
            print("Done training")
        elif not FLAGS.demo:
            # begin validation
            valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            valid_model.valid("ELBO_VALID", sess, valid_feed)

            test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                 valid_config.step_size, shuffle=False, intra_shuffle=False)
            valid_model.valid("ELBO_TEST", sess, test_feed)

            dest_f = open(os.path.join(log_dir, "test.txt"), "wb")
            test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                 test_config.step_size, shuffle=False, intra_shuffle=False)
            test_model.test(sess, test_feed, num_batch=None, repeat=10, dest=dest_f)
            dest_f.close()
        else:
            infer_api = Corpus(FLAGS.vocab, FLAGS.da_vocab, FLAGS.topic_vocab)
            while True:
                demo_sent = raw_input('Please input your sentence:\t')
                if demo_sent == '' or demo_sent.isspace():
                    print('See you next time!')
                    break
                speaker = raw_input('Are you speaker A?: 1 or 0\t')
                topic = raw_input('what is your topic:\t')
                # demo_sent = 'Hello'
                # speaker = '1'
                # topic = 'MUSIC'
                meta, dial = infer_api.format_input(demo_sent, speaker, topic)
                infer_feed = SWDADataLoader("Infer", dial, meta, config)
                infer_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                      test_config.step_size, shuffle=False, intra_shuffle=False)
                pred_strs, pred_das = test_model.inference(sess, infer_feed,
                                                           num_batch=None, repeat=3)
                if pred_strs and pred_das:
                    print('Here is your answer')
                    for answer, da in zip(pred_strs, pred_das):
                        print('{} >> {}'.format(da, answer))
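
# Hedged sketch (hypothetical: the real Corpus.format_input is outside this
# excerpt). It packages one user utterance into the (meta, dial) shapes noted
# in the comments above -- a dial item is ([utt_id, ...], speaker_id,
# [dialogact_id, [...]]) and a meta item is (profile_a, profile_b, topic_id).
# All ids and the dummy profiles here are illustrative.
def format_input_sketch(rev_vocab, demo_sent, speaker, topic_id, unk="<unk>"):
    utt_ids = [rev_vocab.get(tok, rev_vocab[unk]) for tok in demo_sent.split()]
    dial = [[(utt_ids, int(speaker), [0, []])]]  # one dialog with one turn
    meta = [([0.5] * 4, [0.5] * 4, topic_id)]    # dummy speaker profiles
    return meta, dial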
def main():
    # random seeds
    seed = FLAGS.seed
    # tf.random.set_seed(seed)
    np.random.seed(seed)

    # config for training
    config = Config()
    pid = PIDControl(FLAGS.exp_KL)

    # config for validation
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60

    # configuration for testing
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1

    pp(config)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir, word2vec=FLAGS.word2vec_path,
                           word2vec_dim=config.embed_size)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    train_meta, valid_meta, test_meta = meta_corpus.get("train"), meta_corpus.get("valid"), meta_corpus.get("test")
    train_dial, valid_dial, test_dial = dial_corpus.get("train"), dial_corpus.get("valid"), dial_corpus.get("test")

    # convert to numeric inputs/outputs that fit into TF models
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    if FLAGS.forward_only or FLAGS.resume:
        # log_dir = os.path.join(FLAGS.work_dir, FLAGS.test_path)
        log_dir = os.path.join(FLAGS.work_dir, FLAGS.model_name)
    else:
        log_dir = os.path.join(FLAGS.work_dir, FLAGS.model_name)

    # begin training
    with tf.Session() as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w, config.init_w)
        scope = "model"
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = KgRnnCVAE(sess, config, api,
                              log_dir=None if FLAGS.forward_only else log_dir,
                              forward=False, pid_control=pid, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = KgRnnCVAE(sess, valid_config, api, log_dir=None,
                                    forward=False, pid_control=pid, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            test_model = KgRnnCVAE(sess, test_config, api, log_dir=None,
                                   forward=True, pid_control=pid, scope=scope)
        print("Created computation graphs")

        if api.word2vec is not None and not FLAGS.forward_only:
            print("Loaded word2vec")
            sess.run(model.embedding.assign(np.array(api.word2vec)))

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "configure.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        print("*******checkpoint path: ", ckp_dir)
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        print("Created models with fresh parameters.")
        sess.run(tf.global_variables_initializer())

        if ckpt:
            print("Reading dm models parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)

        # save log while running
        if not FLAGS.forward_only:
            logfileName = "train.log"
        else:
            logfileName = "test.log"
        fw_log = open(os.path.join(log_dir, logfileName), "w")
        print("log directory >>> : ", os.path.join(log_dir, "run.log"))

        if not FLAGS.forward_only:
            print('--start training now---')
            dm_checkpoint_path = os.path.join(ckp_dir, model.__class__.__name__ + ".ckpt")
            global_t = 1
            patience = 20  # wait for at least 20 epochs before stopping
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf
            pbar = tqdm(total=config.max_epoch)

            # epoch loop
            for epoch in range(config.max_epoch):
                pbar.update(1)
                print(">> Epoch %d with lr %f" % (epoch, model.learning_rate.eval()))

                # begin training
                FLAGS.mode = 'train'
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size, config.backward_size,
                                          config.step_size, shuffle=True)
                global_t, train_loss = model.train(global_t, sess, train_feed,
                                                   update_limit=config.update_limit)

                FLAGS.mode = 'valid'
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)
                test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                     valid_config.step_size, shuffle=False, intra_shuffle=False)
                elbo, nll, ppl, au_count, kl_loss = valid_model.valid("ELBO_TEST", sess,
                                                                      valid_feed, test_feed)
                print('middle test nll: {} ppl: {} ActiveUnit: {} kl_loss: {}\n'
                      .format(nll, ppl, au_count, kl_loss))
                fw_log.write('epoch:{} testing nll:{} ppl:{} ActiveUnit:{} kl_loss:{} elbo:{}\n'
                             .format(epoch, nll, ppl, au_count, kl_loss, elbo))
                fw_log.flush()

                '''
                ## begin validation
                FLAGS.mode = 'valid'
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)
                valid_loss = valid_model.valid("ELBO_VALID", sess, valid_feed)

                ## test model
                FLAGS.mode = 'test'
                test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                     test_config.step_size, shuffle=True, intra_shuffle=False)
                test_model.test(sess, test_feed, num_batch=5)

                done_epoch = epoch + 1
                # only save a model if the dev loss is smaller
                # Decrease learning rate if no improvement was seen over last 3 times.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    sess.run(model.learning_rate_decay_op)

                if valid_loss < best_dev_loss:
                    if valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience, done_epoch * config.patient_increase)
                        dev_loss_threshold = valid_loss
                    # still save the best train model
                    if FLAGS.save_model:
                        print("Save model!!")
                        model.saver.save(sess, dm_checkpoint_path, global_step=epoch)
                    best_dev_loss = valid_loss

                if config.early_stop and patience <= done_epoch:
                    print("!!Early stop due to run out of patience!!")
                    break
                ## print("Best validation loss %f" % best_dev_loss)
                '''

            print("Done training and save checkpoint")
            if FLAGS.save_model:
                print("Save model!!")
                model.saver.save(sess, dm_checkpoint_path, global_step=epoch)

            # begin validation
            print('--------after training to testing now-----')
            FLAGS.mode = 'test'
            # valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
            #                       valid_config.step_size, shuffle=False, intra_shuffle=False)
            # valid_model.valid("ELBO_VALID", sess, valid_feed)
            valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                 valid_config.step_size, shuffle=False, intra_shuffle=False)
            elbo, nll, ppl, au_count, kl_loss = valid_model.valid("ELBO_TEST", sess,
                                                                  valid_feed, test_feed)
            print('final test nll: {} ppl: {} ActiveUnit: {} kl_loss: {}\n'
                  .format(nll, ppl, au_count, kl_loss))
            fw_log.write('Final testing nll:{} ppl:{} ActiveUnit:{} kl_loss:{} elbo:{}\n'
                         .format(nll, ppl, au_count, kl_loss, elbo))

            dest_f = open(os.path.join(log_dir, FLAGS.test_res), "wb")
            test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                 test_config.step_size, shuffle=False, intra_shuffle=False)
            test_model.test(sess, test_feed, num_batch=None, repeat=10, dest=dest_f)
            dest_f.close()
            print("****testing done****")
        else:
            # begin testing only
            print('*' * 89)
            print('--------testing now-----')
            print('*' * 89)
            FLAGS.mode = 'test'
            valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            # valid_model.valid("ELBO_VALID", sess, valid_feed)
            test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                 valid_config.step_size, shuffle=False, intra_shuffle=False)
            elbo, nll, ppl, au_count, kl_loss = valid_model.valid("ELBO_TEST", sess,
                                                                  valid_feed, test_feed)
            print('final test nll: {} ppl: {} ActiveUnit: {} kl_loss: {}\n'
                  .format(nll, ppl, au_count, kl_loss))
            fw_log.write('Final testing nll:{} ppl:{} ActiveUnit:{} kl_loss:{} elbo:{}\n'
                         .format(nll, ppl, au_count, kl_loss, elbo))
            # dest_f = open(os.path.join(log_dir, FLAGS.test_res), "wb")
            # test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
            #                      test_config.step_size, shuffle=False, intra_shuffle=False)
            # test_model.test(sess, test_feed, num_batch=None, repeat=10, dest=dest_f)
            # dest_f.close()
            print("****testing done****")
        fw_log.close()
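
# Hedged sketch (assumption): PIDControl(FLAGS.exp_KL) above suggests a
# PID-style controller that tunes the KL weight toward a target KL value, in
# the spirit of ControlVAE; the repo's actual class is not shown, so the gains
# and clamping here are illustrative only.
class PIDControlSketch(object):
    def __init__(self, target_kl, kp=0.01, ki=0.0001):
        self.target_kl = target_kl
        self.kp = kp
        self.ki = ki
        self.i_term = 0.0

    def update(self, measured_kl):
        # raise the weight when KL is below target, lower it on overshoot
        error = self.target_kl - measured_kl
        self.i_term += self.ki * error
        beta = self.kp * error + self.i_term
        return min(max(beta, 0.0), 1.0)  # keep the KL weight in [0, 1]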
def main():
    # config for training
    config = Config()
    print("Normal train config:")
    pp(config)

    valid_config = Config()
    valid_config.dropout = 0
    valid_config.batch_size = 20

    # config for test
    test_config = Config()
    test_config.dropout = 0
    test_config.batch_size = 1

    with_sentiment = config.with_sentiment
    pretrain = False

    ###############################################################################
    # Logs
    ###############################################################################
    log_start_time = str(datetime.now().strftime('%Y%m%d%H%M'))
    if not os.path.isdir('./output'):
        os.makedirs('./output')
    if not os.path.isdir('./output/{}'.format(args.expname)):
        os.makedirs('./output/{}'.format(args.expname))
    if not os.path.isdir('./output/{}/{}'.format(args.expname, log_start_time)):
        os.makedirs('./output/{}/{}'.format(args.expname, log_start_time))

    # save arguments
    json.dump(vars(args),
              open('./output/{}/{}/args.json'.format(args.expname, log_start_time), 'w'))

    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.DEBUG, format="%(message)s")
    fh = logging.FileHandler("./output/{}/{}/logs.txt".format(args.expname, log_start_time))
    # add the handlers to the logger
    logger.addHandler(fh)
    logger.info(vars(args))

    tb_writer = SummaryWriter("./output/{}/{}/tb_logs".format(
        args.expname, log_start_time)) if args.visual else None

    ###############################################################################
    # Model
    ###############################################################################
    # vocab and rev_vocab
    with open(args.vocab_path) as vocab_file:
        vocab = vocab_file.read().strip().split('\n')
        rev_vocab = {vocab[idx]: idx for idx in range(len(vocab))}

    if not pretrain:
        pass  # the model is expected to be reloaded here
        # assert config.reload_model
        # model = load_model(config.model_name)
    else:
        if args.model == "multiVAE":
            model = multiVAE(config=config, vocab=vocab, rev_vocab=rev_vocab)
        else:
            model = CVAE(config=config, vocab=vocab, rev_vocab=rev_vocab)
        if use_cuda:
            model = model.cuda()

    ###############################################################################
    # Load data
    ###############################################################################
    if pretrain:
        from collections import defaultdict
        api = LoadPretrainPoem(corpus_path=args.pretrain_data_dir,
                               vocab_path="data/vocab.txt")

        train_corpus, valid_corpus = defaultdict(list), defaultdict(list)
        divide = 50000
        train_corpus['pos'], valid_corpus['pos'] = api.data['pos'][:divide], api.data['pos'][divide:]
        train_corpus['neu'], valid_corpus['neu'] = api.data['neu'][:divide], api.data['neu'][divide:]
        train_corpus['neg'], valid_corpus['neg'] = api.data['neg'][:divide], api.data['neg'][divide:]

        token_corpus = defaultdict(dict)
        token_corpus['pos'], token_corpus['neu'], token_corpus['neg'] = \
            api.get_tokenized_poem_corpus(train_corpus['pos'], valid_corpus['pos']), \
            api.get_tokenized_poem_corpus(train_corpus['neu'], valid_corpus['neu']), \
            api.get_tokenized_poem_corpus(train_corpus['neg'], valid_corpus['neg'])

        # train_loader_dict = {'pos': }
        train_loader = {
            'pos': SWDADataLoader("Train", token_corpus['pos']['train'], config),
            'neu': SWDADataLoader("Train", token_corpus['neu']['train'], config),
            'neg': SWDADataLoader("Train", token_corpus['neg']['train'], config)
        }
        valid_loader = {
            'pos': SWDADataLoader("Train", token_corpus['pos']['valid'], config),
            'neu': SWDADataLoader("Train", token_corpus['neu']['valid'], config),
            'neg': SWDADataLoader("Train", token_corpus['neg']['valid'], config)
        }

        ###############################################################################
        # Pretrain three VAEs
        ###############################################################################
        epoch_id = 0
        global_t = 0
        init_train_loaders(train_loader, config)
        while epoch_id < config.epochs:
            while True:  # loop through all batches in training data
                # train one batch
                model, finish_train, loss_records, global_t = \
                    pre_train_process(global_t=global_t, model=model, train_loader=train_loader)
                if finish_train:
                    if epoch_id > 5:
                        save_model(model=model, epoch=epoch_id, global_t=global_t,
                                   log_start_time=log_start_time)
                    epoch_id += 1
                    init_train_loaders(train_loader, config)
                    break

                # write logs
                if global_t % config.log_every == 0:
                    pre_log_process(epoch_id=epoch_id, global_t=global_t,
                                    train_loader=train_loader, loss_records=loss_records,
                                    logger=logger, tb_writer=tb_writer)

                # valid
                if global_t % config.valid_every == 0:
                    # test_process(model=model, test_loader=test_loader, test_config=test_config, logger=logger)
                    pre_valid_process(global_t=global_t, model=model,
                                      valid_loader=valid_loader, valid_config=valid_config,
                                      tb_writer=tb_writer, logger=logger)

                if global_t % config.test_every == 0:
                    pre_test_process(model=model, logger=logger)

    ###############################################################################
    # Train the big model
    ###############################################################################
    api = LoadPoem(corpus_path=args.train_data_dir, vocab_path="data/vocab.txt",
                   test_path=args.test_data_dir, max_vocab_cnt=config.max_vocab_cnt,
                   with_sentiment=with_sentiment)
    from collections import defaultdict
    token_corpus = defaultdict(dict)
    token_corpus['pos'], token_corpus['neu'], token_corpus['neg'] = \
        api.get_tokenized_poem_corpus(api.train_corpus['pos'], api.valid_corpus['pos']), \
        api.get_tokenized_poem_corpus(api.train_corpus['neu'], api.valid_corpus['neu']), \
        api.get_tokenized_poem_corpus(api.train_corpus['neg'], api.valid_corpus['neg'])

    train_loader = {
        'pos': SWDADataLoader("Train", token_corpus['pos']['train'], config),
        'neu': SWDADataLoader("Train", token_corpus['neu']['train'], config),
        'neg': SWDADataLoader("Train", token_corpus['neg']['train'], config)
    }
    valid_loader = {
        'pos': SWDADataLoader("Train", token_corpus['pos']['valid'], config),
        'neu': SWDADataLoader("Train", token_corpus['neu']['valid'], config),
        'neg': SWDADataLoader("Train", token_corpus['neg']['valid'], config)
    }

    test_poem = api.get_tokenized_test_corpus()['test']  # test data
    test_loader = SWDADataLoader("Test", test_poem, config)

    print("Finish Poem data loading, not pretraining or alignment test")

    if not args.forward_only:
        # the model is still PoemWAE_GMP, unchanged; this data only force-trains one of its Gaussian priors
        # pretrain = True
        cur_best_score = {
            'min_valid_loss': 100,
            'min_global_itr': 0,
            'min_epoch': 0,
            'min_itr': 0
        }

        # model = load_model(3, 3)
        epoch_id = 0
        global_t = 0
        init_train_loaders(train_loader, config)
        while epoch_id < config.epochs:
            while True:  # loop through all batches in training data
                # train one batch
                model, finish_train, loss_records, global_t = \
                    train_process(global_t=global_t, model=model, train_loader=train_loader)
                if finish_train:
                    if epoch_id > 5:
                        save_model(model=model, epoch=epoch_id, global_t=global_t,
                                   log_start_time=log_start_time)
                    epoch_id += 1
                    init_train_loaders(train_loader, config)
                    break

                # write logs
                if global_t % config.log_every == 0:
                    pre_log_process(epoch_id=epoch_id, global_t=global_t,
                                    train_loader=train_loader, loss_records=loss_records,
                                    logger=logger, tb_writer=tb_writer)

                # valid
                if global_t % config.valid_every == 0:
                    valid_process(global_t=global_t, model=model,
                                  valid_loader=valid_loader, valid_config=valid_config,
                                  tb_writer=tb_writer, logger=logger)

                # if batch_idx % (train_loader.num_batch // 3) == 0:
                #     test_process(model=model, test_loader=test_loader, test_config=test_config, logger=logger)
                if global_t % config.test_every == 0:
                    test_process(model=model, test_loader=test_loader,
                                 test_config=test_config, logger=logger)

    # forward_only: testing
    else:
        expname = 'trainVAE'
        time = '202101231631'
        model = load_model('./output/{}/{}/model_global_t_26250_epoch9.pckl'.format(expname, time))
        test_loader.epoch_init(test_config.batch_size, shuffle=False)
        if not os.path.exists('./output/{}/{}/test/'.format(expname, time)):
            os.mkdir('./output/{}/{}/test/'.format(expname, time))
        output_file = [
            open('./output/{}/{}/test/output_0.txt'.format(expname, time), 'w'),
            open('./output/{}/{}/test/output_1.txt'.format(expname, time), 'w'),
            open('./output/{}/{}/test/output_2.txt'.format(expname, time), 'w')
        ]
        poem_count = 0
        predict_results = {0: [], 1: [], 2: []}
        titles = {0: [], 1: [], 2: []}
        sentiment_result = {0: [], 1: [], 2: []}
        # sent_dict = {0: ['0', '1', '1', '0'], 1: ['2', '1', '2', '2'], 2: ['1', '0', '1', '2']}
        sent_dict = {
            0: ['0', '0', '0', '0'],
            1: ['1', '1', '1', '1'],
            2: ['2', '2', '2', '2']
        }

        # Get all poem predictions
        while True:
            model.eval()
            batch = test_loader.next_batch_test()  # test data uses its own batch format
            poem_count += 1
            if poem_count % 10 == 0:
                print("Predicted {} poems".format(poem_count))
            if batch is None:
                break
            title_list = batch  # batch size is 1: one batch generates one poem
            title_tensor = to_tensor(title_list)
            # test() decodes the poem for the current batch; each decoding step's
            # input context is the previous step's output
            for i in range(3):
                # copy, so appending below does not grow sent_dict across poems
                sent_labels = list(sent_dict[i])
                for _ in range(4):
                    sent_labels.append(str(i))
                output_poem, output_tokens = model.test(title_tensor, title_list,
                                                        sent_labels=sent_labels)
                titles[i].append(output_poem.strip().split('\n')[0])
                predict_results[i] += np.array(output_tokens)[:, :7].tolist()

        # Predict sentiment using the sort net
        from collections import defaultdict
        neg = defaultdict(int)
        neu = defaultdict(int)
        pos = defaultdict(int)
        total = defaultdict(int)
        for i in range(3):
            cur_sent_result, neg[i], neu[i], pos[i] = test_sentiment(predict_results[i])
            sentiment_result[i] = cur_sent_result
            total[i] = neg[i] + neu[i] + pos[i]
        for i in range(3):
            print("%d%%\t%d%%\t%d%%" % (neg[i] * 100 / total[i],
                                        neu[i] * 100 / total[i],
                                        pos[i] * 100 / total[i]))
        for i in range(3):
            write_predict_result_to_file(titles[i], predict_results[i],
                                         sentiment_result[i], output_file[i])
            output_file[i].close()
        print("Done testing")
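
# Hedged sketch of the (assumed) init_train_loaders helper called above: its
# real definition is outside this excerpt, but each call appears to simply
# re-initialize every sentiment-specific loader for a fresh epoch.
def init_train_loaders(train_loader, config):
    for loader in train_loader.values():  # keys: 'pos', 'neu', 'neg'
        loader.epoch_init(config.batch_size, shuffle=True)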
def main(model_type):
    # config for training
    config = Config()

    # config for validation
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60

    # configuration for testing
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1

    pp(config)

    # which model to run
    if model_type == "kgcvae":
        model_class = KgRnnCVAE
        backward_size = config.backward_size
        config.use_hcf, valid_config.use_hcf, test_config.use_hcf = True, True, True
    elif model_type == "cvae":
        model_class = KgRnnCVAE
        backward_size = config.backward_size
        config.use_hcf, valid_config.use_hcf, test_config.use_hcf = False, False, False
    elif model_type == 'hierbaseline':
        model_class = HierBaseline
        backward_size = config.backward_size
    else:
        raise ValueError("This shouldn't happen.")

    # LDA Model
    ldamodel = LDAModel(config, trained_model_path=FLAGS.lda_model_path,
                        id2word_path=FLAGS.id2word_path)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir, word2vec=FLAGS.word2vec_path,
                           word2vec_dim=config.embed_size,
                           vocab_dict_path=FLAGS.vocab_dict_path,
                           lda_model=ldamodel, imdb=FLAGS.use_imdb)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    train_meta, valid_meta, test_meta = meta_corpus.get("train"), meta_corpus.get("valid"), meta_corpus.get("test")
    train_dial, valid_dial, test_dial = dial_corpus.get("train"), dial_corpus.get("valid"), dial_corpus.get("test")

    # convert to numeric inputs/outputs that fit into TF models
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    # if you're testing an existing implementation or resuming training
    if FLAGS.forward_only or FLAGS.resume:
        log_dir = os.path.join(FLAGS.work_dir + "_" + FLAGS.model_type, FLAGS.test_path)
    else:
        log_dir = os.path.join(FLAGS.work_dir + "_" + FLAGS.model_type,
                               "run" + str(int(time.time())))

    # begin training
    with tf.Session() as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w, config.init_w)
        scope = "model"
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = model_class(sess, config, api,
                                log_dir=None if FLAGS.forward_only else log_dir,
                                forward=False, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = model_class(sess, valid_config, api, log_dir=None,
                                      forward=False, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            test_model = model_class(sess, test_config, api, log_dir=None,
                                     forward=True, scope=scope)
        print("Created computation graphs")

        if api.word2vec is not None and not FLAGS.forward_only:
            print("Loaded word2vec")
            sess.run(model.embedding.assign(np.array(api.word2vec)))

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "run.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        print("Created models with fresh parameters.")
        sess.run(tf.global_variables_initializer())

        if ckpt:
            print("Reading dm models parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)

        # if you're training a model
        if not FLAGS.forward_only:
            dm_checkpoint_path = os.path.join(ckp_dir, model.__class__.__name__ + ".ckpt")
            global_t = 1
            patience = 10  # wait for at least 10 epochs before stopping
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf

            # train for at most max_epoch epochs; save the model after an epoch
            # if it improves enough on the current best
            for epoch in range(config.max_epoch):
                print(">> Epoch %d with lr %f" % (epoch, model.learning_rate.eval()))

                # begin training
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size, backward_size,
                                          config.step_size, shuffle=True)
                global_t, train_loss = model.train(global_t, sess, train_feed,
                                                   update_limit=config.update_limit)

                # begin validation and testing
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)
                valid_loss = valid_model.valid("ELBO_VALID", sess, valid_feed)

                test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                     test_config.step_size, shuffle=True, intra_shuffle=False)
                test_model.test(sess, test_feed, num_batch=1)  # TODO change this batch size back to a reasonably large number

                done_epoch = epoch + 1
                # only save a model if the dev loss is smaller
                # Decrease learning rate if no improvement was seen over the last 3 times.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    sess.run(model.learning_rate_decay_op)

                if True:  # valid_loss < best_dev_loss:  TODO this change makes the model always save. Change this back when corpus not trivial
                    if True:  # valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience, done_epoch * config.patient_increase)
                        # dev_loss_threshold = valid_loss

                    # still save the best train model
                    if FLAGS.save_model:
                        print("Save model!!")
                        model.saver.save(sess, dm_checkpoint_path, global_step=epoch)
                    # best_dev_loss = valid_loss

                if config.early_stop and patience <= done_epoch:
                    print("!!Early stop due to run out of patience!!")
                    break
            # print("Best validation loss %f" % best_dev_loss)
            print("Done training")

        # else, you're just testing an existing model
        else:
            # begin validation
            valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            valid_model.valid("ELBO_VALID", sess, valid_feed)

            test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                 valid_config.step_size, shuffle=False, intra_shuffle=False)
            valid_model.valid("ELBO_TEST", sess, test_feed)

            # begin testing
            dest_f = open(os.path.join(log_dir, "test.txt"), "wb")
            test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                 test_config.step_size, shuffle=False, intra_shuffle=False)
            test_model.test(sess, test_feed, num_batch=None, repeat=10, dest=dest_f)
            dest_f.close()
def main():
    # config for training
    config = Config()

    # config for validation
    valid_config = Config()
    # valid_config.keep_prob = 1.0
    # valid_config.dec_keep_prob = 1.0
    # valid_config.batch_size = 60

    # configuration for testing
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1

    config.n_state = FLAGS.n_state
    valid_config.n_state = FLAGS.n_state
    test_config.n_state = FLAGS.n_state

    config.with_direct_transition = FLAGS.with_direct_transition
    valid_config.with_direct_transition = FLAGS.with_direct_transition
    test_config.with_direct_transition = FLAGS.with_direct_transition

    config.with_word_weights = FLAGS.with_word_weights
    valid_config.with_word_weights = FLAGS.with_word_weights
    test_config.with_word_weights = FLAGS.with_word_weights

    pp(config)
    print(config.n_state)
    print(config.with_direct_transition)
    print(config.with_word_weights)

    # get data set
    # api = SWDADialogCorpus(FLAGS.data_dir, word2vec=FLAGS.word2vec_path, word2vec_dim=config.embed_size)
    with open(config.api_dir, "rb") as fh:  # binary mode: the file holds a pickled corpus
        api = pkl.load(fh)
    dial_corpus = api.get_dialog_corpus()
    if config.with_label_loss:
        labeled_dial_labels = api.get_state_corpus(config.max_dialog_len)['labeled']
    # meta_corpus = api.get_meta_corpus()

    # train_meta, valid_meta, test_meta = meta_corpus.get("train"), meta_corpus.get("valid"), meta_corpus.get("test")
    train_dial, labeled_dial, test_dial = dial_corpus.get("train"), dial_corpus.get("labeled"), dial_corpus.get("test")

    # convert to numeric inputs/outputs that fit into TF models
    train_feed = SWDADataLoader("Train", train_dial, config)
    # valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, config)
    if config.with_label_loss:
        labeled_feed = SWDADataLoader("Labeled", labeled_dial, config, labeled=True)
    valid_feed = test_feed

    if FLAGS.forward_only or FLAGS.resume:
        log_dir = os.path.join(FLAGS.work_dir, FLAGS.test_path)
    else:
        log_dir = os.path.join(FLAGS.work_dir, "run" + str(int(time.time())))

    # begin training
    with tf.Session(config=tf.ConfigProto(log_device_placement=True,
                                          allow_soft_placement=True)) as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w, config.init_w)
        scope = "model"
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = VRNN(sess, config, api,
                         log_dir=None if FLAGS.forward_only else log_dir, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = VRNN(sess, valid_config, api, log_dir=None, scope=scope)
        # with tf.variable_scope(scope, reuse=True, initializer=initializer):
        #     test_model = KgRnnCVAE(sess, test_config, api, log_dir=None, forward=True, scope=scope)
        print("Created computation graphs")

        if api.word2vec is not None and not FLAGS.forward_only:
            print("Loaded word2vec")
            sess.run(model.W_embedding.assign(np.array(api.word2vec)))

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "run.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        print("Created models with fresh parameters.")
        sess.run(tf.global_variables_initializer())

        if ckpt:
            print("Reading dm models parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)

        # print([str(op.name) for op in tf.get_default_graph().get_operations()])
        print([str(v.name) for v in tf.global_variables()])
        # import sys
        # sys.exit()

        if not FLAGS.forward_only:
            dm_checkpoint_path = os.path.join(ckp_dir, model.__class__.__name__ + ".ckpt")
            global_t = 1
            patience = config.n_epoch  # wait for at least n_epoch epochs before stopping
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf
            for epoch in range(config.max_epoch):
                print(">> Epoch %d with lr %f" % (epoch, model.learning_rate.eval()))

                # begin training
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size, shuffle=True)
                if config.with_label_loss:
                    labeled_feed.epoch_init(len(labeled_dial), shuffle=False)
                else:
                    labeled_feed = None
                    labeled_dial_labels = None
                global_t, train_loss = model.train(global_t, sess, train_feed,
                                                   labeled_feed, labeled_dial_labels,
                                                   update_limit=config.update_limit)

                # begin validation
                valid_feed.epoch_init(config.batch_size, shuffle=False)
                valid_loss = valid_model.valid("ELBO_VALID", sess, valid_feed,
                                               labeled_feed, labeled_dial_labels)
                """
                test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                     test_config.step_size, shuffle=True, intra_shuffle=False)
                test_model.test(sess, test_feed, num_batch=5)
                """

                done_epoch = epoch + 1
                # only save a model if the dev loss is smaller
                # Decrease learning rate if no improvement was seen over the last 3 times.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    sess.run(model.learning_rate_decay_op)
                """
                if valid_loss < best_dev_loss:
                    if valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience, done_epoch * config.patient_increase)
                        dev_loss_threshold = valid_loss
                    # still save the best train model
                    if FLAGS.save_model:
                        print("Save model!!")
                        model.saver.save(sess, dm_checkpoint_path, global_step=epoch)
                    best_dev_loss = valid_loss
                """
                # still save the best train model
                if FLAGS.save_model:
                    print("Save model!!")
                    model.saver.save(sess, dm_checkpoint_path, global_step=epoch)

                if config.early_stop and patience <= done_epoch:
                    print("!!Early stop due to run out of patience!!")
                    break
            print("Best validation loss %f" % best_dev_loss)
            print("Done training")
        else:
            # begin validation
            global_t = 1
            for epoch in range(1):
                print("test-----------")
                print(">> Epoch %d with lr %f" % (epoch, model.learning_rate.eval()))
                if not FLAGS.use_test_batch:
                    # run over the training batches
                    if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                        train_feed.epoch_init(config.batch_size, shuffle=False)
                    if config.with_label_loss:
                        labeled_feed.epoch_init(len(labeled_dial), shuffle=False)
                    else:
                        labeled_feed = None
                        labeled_dial_labels = None
                    results, fetch_results = model.get_zt(global_t, sess, train_feed,
                                                          update_limit=config.update_limit,
                                                          labeled_feed=labeled_feed,
                                                          labeled_labels=labeled_dial_labels)
                    with open(FLAGS.result_path, "wb") as fh:  # binary mode for pickle
                        pkl.dump(results, fh)
                    with open(FLAGS.result_path + ".param.pkl", "wb") as fh:
                        pkl.dump(fetch_results, fh)
                else:
                    print("use_test_batch")
                    # run over the test batches
                    valid_feed.epoch_init(config.batch_size, shuffle=False)
                    if config.with_label_loss:
                        labeled_feed.epoch_init(len(labeled_dial), shuffle=False)
                    else:
                        labeled_feed = None
                        labeled_dial_labels = None
                    results, fetch_results = model.get_zt(global_t, sess, valid_feed,
                                                          update_limit=config.update_limit,
                                                          labeled_feed=labeled_feed,
                                                          labeled_labels=labeled_dial_labels)
                    with open(FLAGS.result_path, "wb") as fh:
                        pkl.dump(results, fh)
                    with open(FLAGS.result_path + ".param.pkl", "wb") as fh:
                        pkl.dump(fetch_results, fh)
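
# A minimal sketch (not part of the repo) showing how the z_t dumps written
# above can be read back; binary mode matches the pkl.dump calls.
import pickle as pkl

def load_zt_results(result_path):
    with open(result_path, "rb") as fh:
        results = pkl.load(fh)
    with open(result_path + ".param.pkl", "rb") as fh:
        fetch_results = pkl.load(fh)
    return results, fetch_results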
def main():
    # config for training
    config = Config()
    print("Normal train config:")
    # pp(config)

    valid_config = Config()
    valid_config.dropout = 0
    valid_config.batch_size = 20

    # config for test
    test_config = Config()
    test_config.dropout = 0
    test_config.batch_size = 1

    # LOG #
    if not os.path.isdir('./output'):
        os.makedirs('./output')
    if not os.path.isdir('./output/{}'.format(args.expname)):
        os.makedirs('./output/{}'.format(args.expname))
    cur_time = str(datetime.now().strftime('%Y%m%d%H%M'))

    # save arguments
    json.dump(vars(args),
              open('./output/{}/{}_args.json'.format(args.expname, cur_time), 'w'))

    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.DEBUG, format="%(message)s")
    fh = logging.FileHandler("./output/{}/logs_{}.txt".format(args.expname, cur_time))
    # add the handlers to the logger
    logger.addHandler(fh)
    logger.info(vars(args))

    ###############################################################################
    # Load data
    ###############################################################################
    # sentiment data path: ../final_data/poem_with_sentiment.txt
    # this path must be passed to LoadPoem explicitly on the command line, since its default is None
    # handles both the pretraining data and the full poem data
    api = LoadPoem(args.train_data_dir, args.test_data_dir, args.max_vocab_size)

    # alternating training: prepare the large corpus
    poem_corpus = api.get_poem_corpus()  # corpus for training and validation
    test_data = api.get_test_corpus()  # test data
    # three lists; each element of every list is [topic, last_sentence, current_sentence]
    train_poem, valid_poem, test_poem = poem_corpus.get("train"), poem_corpus.get("valid"), test_data.get("test")

    train_loader = SWDADataLoader("Train", train_poem, config)
    valid_loader = SWDADataLoader("Valid", valid_poem, config)
    test_loader = SWDADataLoader("Test", test_poem, config)

    print("Finish Poem data loading, not pretraining or alignment test")

    if not args.forward_only:
        ###############################################################################
        # Define the models and word2vec weight
        ###############################################################################
        # handle the word2vec trained on the Siku corpus
        # if args.model != "Seq2Seq":
        #     logger.info("Start loading siku word2vec")
        #     pretrain_weight = None
        #     if os.path.exists(args.word2vec_path):
        #         pretrain_vec = {}
        #         word2vec = open(args.word2vec_path)
        #         pretrain_data = word2vec.read().split('\n')[1:]
        #         for data in pretrain_data:
        #             data = data.split(' ')
        #             pretrain_vec[data[0]] = [float(item) for item in data[1:-1]]
        #         # nparray (vocab_len, emb_dim)
        #         pretrain_weight = process_pretrain_vec(pretrain_vec, api.vocab)
        #         logger.info("Successfully loaded siku word2vec")

        # with or without pretraining, use the Gaussian mixture model:
        # when pretraining, train a specific Gaussian on specific data;
        # without pretraining, train the Gaussian mixture on the full corpus
        if args.model == "Seq2Seq":
            model = Seq2Seq(config=config, api=api)
        else:
            model = PoemWAE(config=config, api=api)
        if use_cuda:
            model = model.cuda()

        # if corpus.word2vec is not None and args.reload_from < 0:
        #     print("Loaded word2vec")
        #     model.embedder.weight.data.copy_(torch.from_numpy(corpus.word2vec))
        #     model.embedder.weight.data[0].fill_(0)

        ###############################################################################
        # Start training
        ###############################################################################
        # the model is still PoemWAE_GMP, unchanged; this data only force-trains one of its Gaussian priors
        # pretrain = True
        tb_writer = SummaryWriter(
            "./output/{}/{}/{}/logs/".format(args.model, args.expname, args.dataset)
            + datetime.now().strftime('%Y%m%d%H%M')) if args.visual else None

        global_iter = 1
        cur_best_score = {
            'min_valid_loss': 100,
            'min_global_itr': 0,
            'min_epoch': 0,
            'min_itr': 0
        }

        train_loader.epoch_init(config.batch_size, shuffle=True)

        # model = load_model(3, 3)
        batch_idx = 0
        while global_iter < 100:
            batch_idx = 0
            while True:  # loop through all batches in training data
                # train one batch
                model, finish_train, loss_records = \
                    train_process(model=model, train_loader=train_loader,
                                  config=config, sentiment_data=False)
                batch_idx += 1
                if finish_train:
                    test_process(model=model, test_loader=test_loader,
                                 test_config=test_config, logger=logger)
                    evaluate_process(model=model, valid_loader=valid_loader,
                                     global_iter=global_iter, epoch=global_iter,
                                     logger=logger, tb_writer=tb_writer, api=api)
                    # save model after each epoch
                    save_model(model=model, epoch=global_iter,
                               global_iter=global_iter, batch_idx=batch_idx)
                    logger.info('Finish epoch %d, current min valid loss: %.4f '
                                'correspond global_itr: %d epoch: %d itr: %d \n\n'
                                % (global_iter,
                                   cur_best_score['min_valid_loss'],
                                   cur_best_score['min_global_itr'],
                                   cur_best_score['min_epoch'],
                                   cur_best_score['min_itr']))
                    # initialize training for the next unlabeled-data epoch
                    # unlabeled_epoch += 1
                    train_loader.epoch_init(config.batch_size, shuffle=True)
                    break
                # elif batch_idx >= start_batch + config.n_batch_every_iter:
                #     print("Finish unlabel epoch %d batch %d to %d" %
                #           (unlabeled_epoch, start_batch, start_batch + config.n_batch_every_iter))
                #     start_batch += config.n_batch_every_iter
                #     break

                # write logs
                if batch_idx % (train_loader.num_batch // 50) == 0:
                    log = 'Global iter %d: step: %d/%d: ' \
                          % (global_iter, batch_idx, train_loader.num_batch)
                    for loss_name, loss_value in loss_records:
                        log = log + loss_name + ':%.4f ' % loss_value
                        if args.visual:
                            tb_writer.add_scalar(loss_name, loss_value, global_iter)
                    logger.info(log)

                # valid
                if batch_idx % (train_loader.num_batch // 10) == 0:
                    valid_process(model=model, valid_loader=valid_loader,
                                  valid_config=valid_config, global_iter=global_iter,
                                  unlabeled_epoch=global_iter,  # if sample_rate_unlabeled is not 1, append a 1 at the end here
                                  batch_idx=batch_idx, tb_writer=tb_writer,
                                  logger=logger, cur_best_score=cur_best_score)
                    test_process(model=model, test_loader=test_loader,
                                 test_config=test_config, logger=logger)
                    save_model(model=model, epoch=global_iter,
                               global_iter=global_iter, batch_idx=batch_idx)
                # if batch_idx % (train_loader.num_batch // 3) == 0:
                #     test_process(model=model, test_loader=test_loader, test_config=test_config, logger=logger)
            global_iter += 1

    # forward_only: testing
    else:
        # test_global_list = [4, 4, 2]
        # test_epoch_list = [21, 19, 8]
        test_global_list = [8]
        test_epoch_list = [20]
        for i in range(1):
            model = load_model('./output/basic/header_model.pckl')
            model.vocab = api.vocab
            model.rev_vocab = api.rev_vocab
            test_loader.epoch_init(test_config.batch_size, shuffle=False)
            last_title = None
            while True:
                model.eval()  # eval() mainly affects BatchNorm, dropout, etc.
                batch = get_user_input(api.rev_vocab, config.title_size)
                # batch = test_loader.next_batch_test()  # test data uses its own batch format
                if batch is None:
                    break
                title_list, headers, title = batch  # batch size is 1: one batch generates one poem
                if title == last_title:
                    continue
                last_title = title
                title_tensor = to_tensor(title_list)
                # test() decodes the poem for the current batch; each decoding step's
                # input context is the previous step's output
                output_poem = model.test(title_tensor=title_tensor,
                                         title_words=title_list, headers=headers)
                with open('./content_from_remote.txt', 'w') as file:
                    file.write(output_poem)
                print(output_poem)
                print('\n')
        print("Done testing")
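
# Hedged sketch (hypothetical: the real get_user_input is outside this
# excerpt): reading a poem title from stdin and mapping its characters to
# vocab ids padded to title_size, plus per-line acrostic header characters,
# matching the (title_list, headers, title) unpacking above. pad_id and the
# <unk> token are assumptions.
def get_user_input_sketch(rev_vocab, title_size, pad_id=0, unk="<unk>"):
    title = input("Input a poem title (empty line to quit): ").strip()
    if not title:
        return None
    ids = [rev_vocab.get(ch, rev_vocab[unk]) for ch in title]
    ids = (ids + [pad_id] * title_size)[:title_size]  # pad/trim to title_size
    headers = list(title)[:4]  # one header character per poem line
    return [ids], headers, title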