def main():
    # config for training
    config = Config()
    config.use_bow = False

    # config for validation
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60
    valid_config.use_bow = False

    # configuration for testing
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1
    test_config.use_bow = False

    pp(config)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir, word2vec=FLAGS.word2vec_path,
                           word2vec_dim=config.embed_size)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    train_meta, valid_meta, test_meta = meta_corpus.get("train"), meta_corpus.get("valid"), meta_corpus.get("test")
    train_dial, valid_dial, test_dial = dial_corpus.get("train"), dial_corpus.get("valid"), dial_corpus.get("test")

    # convert to numeric inputs/outputs that fit into TF models
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    # begin training
    sess_config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
    # sess_config.gpu_options.allow_growth = True
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.45

    with tf.Session(config=sess_config) as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w, config.init_w)
        scope = "model"
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = KgRnnCVAE(sess, config, api,
                              log_dir=None if FLAGS.forward_only else log_dir,
                              forward=False, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = KgRnnCVAE(sess, valid_config, api, log_dir=None,
                                    forward=False, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            test_model = KgRnnCVAE(sess, test_config, api, log_dir=None,
                                   forward=True, scope=scope)
            test_model.prepare_mul_ref()

        logger.info("Created computation graphs")
        if api.word2vec is not None and not FLAGS.forward_only:
            logger.info("Loaded word2vec")
            sess.run(model.embedding.assign(np.array(api.word2vec)))

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "run.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        logger.info("Created models with fresh parameters.")
        sess.run(tf.global_variables_initializer())

        if ckpt:
            logger.info("Reading dm models parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)

        if not FLAGS.forward_only:
            dm_checkpoint_path = os.path.join(ckp_dir, model.__class__.__name__ + ".ckpt")
            global_t = 1
            patience = 10  # wait for at least 10 epochs before stopping
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf
            for epoch in range(config.max_epoch):
                logger.info(">> Epoch %d with lr %f" % (
                    epoch, sess.run(model.learning_rate_cyc, {model.global_t: global_t})))

                # begin training
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size, config.backward_size,
                                          config.step_size, shuffle=True)
                global_t, train_loss = model.train(global_t, sess, train_feed,
                                                   update_limit=config.update_limit)

                # begin validation
                logger.record_tabular("Epoch", epoch)
                logger.record_tabular("Mode", "Val")
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)
                valid_loss = valid_model.valid("ELBO_VALID", sess, valid_feed)

                logger.record_tabular("Epoch", epoch)
                logger.record_tabular("Mode", "Test")
                test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                     valid_config.step_size, shuffle=False, intra_shuffle=False)
                valid_model.valid("ELBO_TEST", sess, test_feed)

                # test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                #                      test_config.step_size, shuffle=True, intra_shuffle=False)
                # test_model.test_mul_ref(sess, test_feed, num_batch=5)

                done_epoch = epoch + 1

                # only save a model if the dev loss is smaller
                # Decrease learning rate if no improvement was seen over last 3 times.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    sess.run(model.learning_rate_decay_op)

                if valid_loss < best_dev_loss:
                    if valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience, done_epoch * config.patient_increase)
                        dev_loss_threshold = valid_loss

                    # still save the best train model
                    if FLAGS.save_model:
                        logger.info("Save model!!")
                        model.saver.save(sess, dm_checkpoint_path)
                    best_dev_loss = valid_loss

                if (epoch % 3) == 2:
                    tmp_model_path = os.path.join(ckp_dir, model.__class__.__name__ + str(epoch) + ".ckpt")
                    model.saver.save(sess, tmp_model_path)

                if config.early_stop and patience <= done_epoch:
                    logger.info("!!Early stop due to run out of patience!!")
                    break

            logger.info("Best validation loss %f" % best_dev_loss)
            logger.info("Done training")
        else:
            # begin validation
            valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            valid_model.valid("ELBO_VALID", sess, valid_feed)

            test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                 valid_config.step_size, shuffle=False, intra_shuffle=False)
            valid_model.valid("ELBO_TEST", sess, test_feed)

            dest_f = open(os.path.join(log_dir, "test.txt"), "wb")
            test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                 test_config.step_size, shuffle=False, intra_shuffle=False)
            test_model.test_mul_ref(sess, test_feed, num_batch=None, repeat=5, dest=dest_f)
            dest_f.close()
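# A minimal, hypothetical driver for the variant above (a sketch, not part of
# the original file): it assumes the TF 1.x tf.app.flags API and defines only
# the flags this main() actually reads; the flag names follow the FLAGS.*
# references above, while the default values are placeholders.
import tensorflow as tf

tf.app.flags.DEFINE_string("data_dir", "data/swda_corpus.p", "Pickled SWDA dialog corpus (placeholder path).")
tf.app.flags.DEFINE_string("word2vec_path", None, "Optional pre-trained word2vec file.")
tf.app.flags.DEFINE_bool("forward_only", False, "Only run decoding on the test set.")
tf.app.flags.DEFINE_bool("save_model", True, "Save checkpoints during training.")
FLAGS = tf.app.flags.FLAGS

if __name__ == "__main__":
    main()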
def main():
    # config for training
    config = Config()

    # config for validation
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60

    # configuration for testing
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1

    pp(config)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir, word2vec=FLAGS.word2vec_path,
                           word2vec_dim=config.embed_size)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    train_meta, valid_meta, test_meta = meta_corpus.get("train"), meta_corpus.get("valid"), meta_corpus.get("test")
    train_dial, valid_dial, test_dial = dial_corpus.get("train"), dial_corpus.get("valid"), dial_corpus.get("test")

    # convert to numeric input outputs that fit into TF models
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    if FLAGS.forward_only or FLAGS.resume:
        log_dir = os.path.join(FLAGS.work_dir, FLAGS.test_path)
    else:
        log_dir = os.path.join(FLAGS.work_dir, "run" + str(int(time.time())))

    # begin training
    if True:
        scope = "model"
        model = KgRnnCVAE(config, api,
                          log_dir=None if FLAGS.forward_only else log_dir,
                          scope=scope)
        print("Created computation graphs")

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "run.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = get_checkpoint_state(ckp_dir)
        print("Created models with fresh parameters.")
        model.apply(lambda m: [torch.nn.init.uniform(p.data, -1.0 * config.init_w, config.init_w)
                               for p in m.parameters()])

        # Load word2vec weight
        if api.word2vec is not None and not FLAGS.forward_only:
            print("Loaded word2vec")
            model.embedding.weight.data.copy_(torch.from_numpy(np.array(api.word2vec)))
            model.embedding.weight.data[0].fill_(0)

        if ckpt:
            print("Reading dm models parameters from %s" % ckpt)
            model.load_state_dict(torch.load(ckpt))

        # turn to cuda
        model.cuda()

        if not FLAGS.forward_only:
            dm_checkpoint_path = os.path.join(ckp_dir, model.__class__.__name__ + "-%d.pth")
            global_t = 1
            patience = 10  # wait for at least 10 epochs before stopping
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf
            for epoch in range(config.max_epoch):
                print(">> Epoch %d with lr %f" % (epoch, model.learning_rate))

                # begin training
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size, config.backward_size,
                                          config.step_size, shuffle=True)
                global_t, train_loss = model.train_model(global_t, train_feed,
                                                         update_limit=config.update_limit)

                # begin validation
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)
                model.eval()
                valid_loss = model.valid_model("ELBO_VALID", valid_feed)

                test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                     test_config.step_size, shuffle=True, intra_shuffle=False)
                model.test_model(test_feed, num_batch=5)
                model.train()

                done_epoch = epoch + 1

                # only save a model if the dev loss is smaller
                # Decrease learning rate if no improvement was seen over last 3 times.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    model.learning_rate_decay()

                if valid_loss < best_dev_loss:
                    if valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience, done_epoch * config.patient_increase)
                        dev_loss_threshold = valid_loss

                    # still save the best train model
                    if FLAGS.save_model:
                        print("Save model!!")
                        torch.save(model.state_dict(), dm_checkpoint_path % (epoch))
                    best_dev_loss = valid_loss

                if config.early_stop and patience <= done_epoch:
                    print("!!Early stop due to run out of patience!!")
                    break

            print("Best validation loss %f" % best_dev_loss)
            print("Done training")
        else:
            # begin validation
            valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            model.eval()
            model.valid_model("ELBO_VALID", valid_feed)

            test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                 valid_config.step_size, shuffle=False, intra_shuffle=False)
            model.valid_model("ELBO_TEST", test_feed)

            dest_f = open(os.path.join(log_dir, "test.txt"), "wb")
            test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                 test_config.step_size, shuffle=False, intra_shuffle=False)
            model.test_model(test_feed, num_batch=None, repeat=10, dest=dest_f)
            model.train()
            dest_f.close()
def main():
    ## random seeds
    seed = FLAGS.seed
    # tf.random.set_seed(seed)
    np.random.seed(seed)

    ## config for training
    config = Config()
    pid = PIDControl(FLAGS.exp_KL)

    # config for validation
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60

    # configuration for testing
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1

    pp(config)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir, word2vec=FLAGS.word2vec_path,
                           word2vec_dim=config.embed_size)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    train_meta, valid_meta, test_meta = meta_corpus.get("train"), meta_corpus.get("valid"), meta_corpus.get("test")
    train_dial, valid_dial, test_dial = dial_corpus.get("train"), dial_corpus.get("valid"), dial_corpus.get("test")

    # convert to numeric inputs/outputs that fit into TF models
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    if FLAGS.forward_only or FLAGS.resume:
        # log_dir = os.path.join(FLAGS.work_dir, FLAGS.test_path)
        log_dir = os.path.join(FLAGS.work_dir, FLAGS.model_name)
    else:
        log_dir = os.path.join(FLAGS.work_dir, FLAGS.model_name)

    ## begin training
    with tf.Session() as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w, config.init_w)
        scope = "model"
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = KgRnnCVAE(sess, config, api,
                              log_dir=None if FLAGS.forward_only else log_dir,
                              forward=False, pid_control=pid, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = KgRnnCVAE(sess, valid_config, api, log_dir=None,
                                    forward=False, pid_control=pid, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            test_model = KgRnnCVAE(sess, test_config, api, log_dir=None,
                                   forward=True, pid_control=pid, scope=scope)

        print("Created computation graphs")
        if api.word2vec is not None and not FLAGS.forward_only:
            print("Loaded word2vec")
            sess.run(model.embedding.assign(np.array(api.word2vec)))

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "configure.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        print("*******checkpoint path: ", ckp_dir)
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        print("Created models with fresh parameters.")
        sess.run(tf.global_variables_initializer())

        if ckpt:
            print("Reading dm models parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)

        ### save log when running
        if not FLAGS.forward_only:
            logfileName = "train.log"
        else:
            logfileName = "test.log"
        fw_log = open(os.path.join(log_dir, logfileName), "w")
        print("log directory >>> : ", os.path.join(log_dir, "run.log"))

        if not FLAGS.forward_only:
            print('--start training now---')
            dm_checkpoint_path = os.path.join(ckp_dir, model.__class__.__name__ + ".ckpt")
            global_t = 1
            patience = 20  # wait for at least 20 epochs before stopping
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf
            pbar = tqdm(total=config.max_epoch)

            ## epoch start training
            for epoch in range(config.max_epoch):
                pbar.update(1)
                print(">> Epoch %d with lr %f" % (epoch, model.learning_rate.eval()))

                ## begin training
                FLAGS.mode = 'train'
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size, config.backward_size,
                                          config.step_size, shuffle=True)
                global_t, train_loss = model.train(global_t, sess, train_feed,
                                                   update_limit=config.update_limit)

                FLAGS.mode = 'valid'
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)
                test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                     valid_config.step_size, shuffle=False, intra_shuffle=False)
                elbo, nll, ppl, au_count, kl_loss = valid_model.valid("ELBO_TEST", sess, valid_feed, test_feed)
                print('middle test nll: {} ppl: {} ActiveUnit: {} kl_loss:{}\n'.format(nll, ppl, au_count, kl_loss))
                fw_log.write('epoch:{} testing nll:{} ppl:{} ActiveUnit:{} kl_loss:{} elbo:{}\n'.format(
                    epoch, nll, ppl, au_count, kl_loss, elbo))
                fw_log.flush()

                '''
                ## begin validation
                FLAGS.mode = 'valid'
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)
                valid_loss = valid_model.valid("ELBO_VALID", sess, valid_feed)

                ## test model
                FLAGS.mode = 'test'
                test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                     test_config.step_size, shuffle=True, intra_shuffle=False)
                test_model.test(sess, test_feed, num_batch=5)

                done_epoch = epoch + 1
                # only save a model if the dev loss is smaller
                # Decrease learning rate if no improvement was seen over last 3 times.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    sess.run(model.learning_rate_decay_op)
                if valid_loss < best_dev_loss:
                    if valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience, done_epoch * config.patient_increase)
                        dev_loss_threshold = valid_loss
                    # still save the best train model
                    if FLAGS.save_model:
                        print("Save model!!")
                        model.saver.save(sess, dm_checkpoint_path, global_step=epoch)
                    best_dev_loss = valid_loss
                if config.early_stop and patience <= done_epoch:
                    print("!!Early stop due to run out of patience!!")
                    break
                ## print("Best validation loss %f" % best_dev_loss)
                '''

            print("Done training and save checkpoint")
            if FLAGS.save_model:
                print("Save model!!")
                model.saver.save(sess, dm_checkpoint_path, global_step=epoch)

            # begin validation
            print('--------after training to testing now-----')
            FLAGS.mode = 'test'
            # valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
            #                       valid_config.step_size, shuffle=False, intra_shuffle=False)
            # valid_model.valid("ELBO_VALID", sess, valid_feed)
            valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                 valid_config.step_size, shuffle=False, intra_shuffle=False)
            elbo, nll, ppl, au_count, kl_loss = valid_model.valid("ELBO_TEST", sess, valid_feed, test_feed)
            print('final test nll: {} ppl: {} ActiveUnit: {} kl_loss:{}\n'.format(nll, ppl, au_count, kl_loss))
            fw_log.write('Final testing nll:{} ppl:{} ActiveUnit:{} kl_loss:{} elbo:{}\n'.format(
                nll, ppl, au_count, kl_loss, elbo))

            dest_f = open(os.path.join(log_dir, FLAGS.test_res), "wb")
            test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                 test_config.step_size, shuffle=False, intra_shuffle=False)
            test_model.test(sess, test_feed, num_batch=None, repeat=10, dest=dest_f)
            dest_f.close()
            print("****testing done****")
        else:
            # begin validation
            print('*' * 89)
            print('--------testing now-----')
            print('*' * 89)
            FLAGS.mode = 'test'
            valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            # valid_model.valid("ELBO_VALID", sess, valid_feed)
            test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                 valid_config.step_size, shuffle=False, intra_shuffle=False)
            elbo, nll, ppl, au_count, kl_loss = valid_model.valid("ELBO_TEST", sess, valid_feed, test_feed)
            print('final test nll: {} ppl: {} ActiveUnit: {} kl_loss:{}\n'.format(nll, ppl, au_count, kl_loss))
            fw_log.write('Final testing nll:{} ppl:{} ActiveUnit:{} kl_loss:{} elbo:{}\n'.format(
                nll, ppl, au_count, kl_loss, elbo))

            # dest_f = open(os.path.join(log_dir, FLAGS.test_res), "wb")
            # test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
            #                      test_config.step_size, shuffle=False, intra_shuffle=False)
            # test_model.test(sess, test_feed, num_batch=None, repeat=10, dest=dest_f)
            # dest_f.close()
            print("****testing done****")

        fw_log.close()
def main():
    # config for training
    config = Config()

    # config for validation
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60

    # configuration for testing
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1

    pp(config)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir, word2vec=FLAGS.word2vec_path,
                           word2vec_dim=config.embed_size)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    '''item in meta looks like:
    ([0.46, 0.6666666666666666, 1, 0], [0.48, 0.6666666666666666, 1, 0], 29)
    '''
    train_meta, valid_meta, test_meta = meta_corpus.get("train"), meta_corpus.get("valid"), meta_corpus.get("test")
    '''item in dial looks like: ([utt_id, ...], speaker_id, [dialogact_id, [...]])
    '''
    train_dial, valid_dial, test_dial = dial_corpus.get("train"), dial_corpus.get("valid"), dial_corpus.get("test")

    # convert to numeric inputs/outputs that fit into TF models
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    if FLAGS.forward_only or FLAGS.resume:
        log_dir = os.path.join(FLAGS.work_dir, FLAGS.test_path)
    else:
        log_dir = os.path.join(FLAGS.work_dir, "run" + str(int(time.time())))

    # begin training
    with tf.Session() as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w, config.init_w)
        scope = "model"
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = KgRnnCVAE(sess, config, api,
                              log_dir=None if FLAGS.forward_only else log_dir,
                              forward=False, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = KgRnnCVAE(sess, valid_config, api, log_dir=None,
                                    forward=False, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            test_model = KgRnnCVAE(sess, test_config, api, log_dir=None,
                                   forward=True, scope=scope)

        print("Created computation graphs")
        if api.word2vec is not None and not FLAGS.forward_only:
            print("Loaded word2vec")
            sess.run(model.embedding.assign(np.array(api.word2vec)))

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "run.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        print("Created models with fresh parameters.")
        sess.run(tf.global_variables_initializer())

        if ckpt:
            print("Reading dm models parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)

        if not FLAGS.forward_only:
            dm_checkpoint_path = os.path.join(ckp_dir, model.__class__.__name__ + ".ckpt")
            global_t = 1
            patience = 10  # wait for at least 10 epochs before stopping
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf
            for epoch in range(config.max_epoch):
                print(">> Epoch %d with lr %f" % (epoch, model.learning_rate.eval()))

                # begin training
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size, config.backward_size,
                                          config.step_size, shuffle=True)
                global_t, train_loss = model.train(global_t, sess, train_feed,
                                                   update_limit=config.update_limit)

                # begin validation
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)
                valid_loss = valid_model.valid("ELBO_VALID", sess, valid_feed)

                test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                     test_config.step_size, shuffle=True, intra_shuffle=False)
                test_model.test(sess, test_feed, num_batch=5)

                done_epoch = epoch + 1

                # only save a model if the dev loss is smaller
                # Decrease learning rate if no improvement was seen over last 3 times.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    sess.run(model.learning_rate_decay_op)

                if valid_loss < best_dev_loss:
                    if valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience, done_epoch * config.patient_increase)
                        dev_loss_threshold = valid_loss

                    # still save the best train model
                    if FLAGS.save_model:
                        print("Save model!!")
                        model.saver.save(sess, dm_checkpoint_path, global_step=epoch)
                    best_dev_loss = valid_loss

                if config.early_stop and patience <= done_epoch:
                    print("!!Early stop due to run out of patience!!")
                    break

            print("Best validation loss %f" % best_dev_loss)
            print("Done training")
        elif not FLAGS.demo:
            # begin validation
            valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            valid_model.valid("ELBO_VALID", sess, valid_feed)

            test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                 valid_config.step_size, shuffle=False, intra_shuffle=False)
            valid_model.valid("ELBO_TEST", sess, test_feed)

            dest_f = open(os.path.join(log_dir, "test.txt"), "wb")
            test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                 test_config.step_size, shuffle=False, intra_shuffle=False)
            test_model.test(sess, test_feed, num_batch=None, repeat=10, dest=dest_f)
            dest_f.close()
        else:
            # interactive demo mode
            infer_api = Corpus(FLAGS.vocab, FLAGS.da_vocab, FLAGS.topic_vocab)
            while True:
                demo_sent = raw_input('Please input your sentence:\t')
                if demo_sent == '' or demo_sent.isspace():
                    print('See you next time!')
                    break
                speaker = raw_input('Are you speaker A?: 1 or 0\t')
                topic = raw_input('what is your topic:\t')
                # demo_sent = 'Hello'
                # speaker = '1'
                # topic = 'MUSIC'
                meta, dial = infer_api.format_input(demo_sent, speaker, topic)
                infer_feed = SWDADataLoader("Infer", dial, meta, config)
                infer_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                      test_config.step_size, shuffle=False, intra_shuffle=False)
                pred_strs, pred_das = test_model.inference(sess, infer_feed, num_batch=None, repeat=3)
                if pred_strs and pred_das:
                    print('Here is your answer')
                    for da, answer in zip(pred_strs, pred_das):
                        print('{} >> {}'.format(da, answer))
def main(model_type):
    # config for training
    config = Config()

    # config for validation
    valid_config = Config()
    valid_config.keep_prob = 1.0
    valid_config.dec_keep_prob = 1.0
    valid_config.batch_size = 60

    # configuration for testing
    test_config = Config()
    test_config.keep_prob = 1.0
    test_config.dec_keep_prob = 1.0
    test_config.batch_size = 1

    pp(config)

    # which model to run
    if model_type == "kgcvae":
        model_class = KgRnnCVAE
        backward_size = config.backward_size
        config.use_hcf = valid_config.use_hcf = test_config.use_hcf = True
    elif model_type == "cvae":
        model_class = KgRnnCVAE
        backward_size = config.backward_size
        config.use_hcf = valid_config.use_hcf = test_config.use_hcf = False
    elif model_type == 'hierbaseline':
        model_class = HierBaseline
        backward_size = config.backward_size
    else:
        raise ValueError("This shouldn't happen.")

    # LDA model
    ldamodel = LDAModel(config, trained_model_path=FLAGS.lda_model_path,
                        id2word_path=FLAGS.id2word_path)

    # get data set
    api = SWDADialogCorpus(FLAGS.data_dir, word2vec=FLAGS.word2vec_path,
                           word2vec_dim=config.embed_size,
                           vocab_dict_path=FLAGS.vocab_dict_path,
                           lda_model=ldamodel, imdb=FLAGS.use_imdb)
    dial_corpus = api.get_dialog_corpus()
    meta_corpus = api.get_meta_corpus()

    train_meta, valid_meta, test_meta = meta_corpus.get("train"), meta_corpus.get("valid"), meta_corpus.get("test")
    train_dial, valid_dial, test_dial = dial_corpus.get("train"), dial_corpus.get("valid"), dial_corpus.get("test")

    # convert to numeric inputs/outputs that fit into TF models
    train_feed = SWDADataLoader("Train", train_dial, train_meta, config)
    valid_feed = SWDADataLoader("Valid", valid_dial, valid_meta, config)
    test_feed = SWDADataLoader("Test", test_dial, test_meta, config)

    # if you're testing an existing implementation or resuming training
    if FLAGS.forward_only or FLAGS.resume:
        log_dir = os.path.join(FLAGS.work_dir + "_" + FLAGS.model_type, FLAGS.test_path)
    else:
        log_dir = os.path.join(FLAGS.work_dir + "_" + FLAGS.model_type, "run" + str(int(time.time())))

    # begin training
    with tf.Session() as sess:
        initializer = tf.random_uniform_initializer(-1.0 * config.init_w, config.init_w)
        scope = "model"
        with tf.variable_scope(scope, reuse=None, initializer=initializer):
            model = model_class(sess, config, api,
                                log_dir=None if FLAGS.forward_only else log_dir,
                                forward=False, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            valid_model = model_class(sess, valid_config, api, log_dir=None,
                                      forward=False, scope=scope)
        with tf.variable_scope(scope, reuse=True, initializer=initializer):
            test_model = model_class(sess, test_config, api, log_dir=None,
                                     forward=True, scope=scope)

        print("Created computation graphs")
        if api.word2vec is not None and not FLAGS.forward_only:
            print("Loaded word2vec")
            sess.run(model.embedding.assign(np.array(api.word2vec)))

        # write config to a file for logging
        if not FLAGS.forward_only:
            with open(os.path.join(log_dir, "run.log"), "wb") as f:
                f.write(pp(config, output=False))

        # create a folder by force
        ckp_dir = os.path.join(log_dir, "checkpoints")
        if not os.path.exists(ckp_dir):
            os.mkdir(ckp_dir)

        ckpt = tf.train.get_checkpoint_state(ckp_dir)
        print("Created models with fresh parameters.")
        sess.run(tf.global_variables_initializer())

        if ckpt:
            print("Reading dm models parameters from %s" % ckpt.model_checkpoint_path)
            model.saver.restore(sess, ckpt.model_checkpoint_path)

        # if you're training a model
        if not FLAGS.forward_only:
            dm_checkpoint_path = os.path.join(ckp_dir, model.__class__.__name__ + ".ckpt")
            global_t = 1
            patience = 10  # wait for at least 10 epochs before stopping
            dev_loss_threshold = np.inf
            best_dev_loss = np.inf

            # train for at most config.max_epoch epochs; save the model after an epoch
            # if it improves enough on the current best
            for epoch in range(config.max_epoch):
                print(">> Epoch %d with lr %f" % (epoch, model.learning_rate.eval()))

                # begin training
                if train_feed.num_batch is None or train_feed.ptr >= train_feed.num_batch:
                    train_feed.epoch_init(config.batch_size, backward_size,
                                          config.step_size, shuffle=True)
                global_t, train_loss = model.train(global_t, sess, train_feed,
                                                   update_limit=config.update_limit)

                # begin validation and testing
                valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                      valid_config.step_size, shuffle=False, intra_shuffle=False)
                valid_loss = valid_model.valid("ELBO_VALID", sess, valid_feed)

                test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                     test_config.step_size, shuffle=True, intra_shuffle=False)
                # TODO change this batch size back to a reasonably large number
                test_model.test(sess, test_feed, num_batch=1)

                done_epoch = epoch + 1

                # only save a model if the dev loss is smaller
                # Decrease learning rate if no improvement was seen over last 3 times.
                if config.op == "sgd" and done_epoch > config.lr_hold:
                    sess.run(model.learning_rate_decay_op)

                # TODO this change makes the model always save; change back when the corpus is not trivial
                if True:  # valid_loss < best_dev_loss:
                    if True:  # valid_loss <= dev_loss_threshold * config.improve_threshold:
                        patience = max(patience, done_epoch * config.patient_increase)
                        # dev_loss_threshold = valid_loss

                    # still save the best train model
                    if FLAGS.save_model:
                        print("Save model!!")
                        model.saver.save(sess, dm_checkpoint_path, global_step=epoch)
                    # best_dev_loss = valid_loss

                if config.early_stop and patience <= done_epoch:
                    print("!!Early stop due to run out of patience!!")
                    break

            # print("Best validation loss %f" % best_dev_loss)
            print("Done training")
        # else if you're just testing an existing model
        else:
            # begin validation
            valid_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                  valid_config.step_size, shuffle=False, intra_shuffle=False)
            valid_model.valid("ELBO_VALID", sess, valid_feed)

            test_feed.epoch_init(valid_config.batch_size, valid_config.backward_size,
                                 valid_config.step_size, shuffle=False, intra_shuffle=False)
            valid_model.valid("ELBO_TEST", sess, test_feed)

            # begin testing
            dest_f = open(os.path.join(log_dir, "test.txt"), "wb")
            test_feed.epoch_init(test_config.batch_size, test_config.backward_size,
                                 test_config.step_size, shuffle=False, intra_shuffle=False)
            test_model.test(sess, test_feed, num_batch=None, repeat=10, dest=dest_f)
            dest_f.close()
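# A hypothetical entry point for the model_type variant above (a sketch, not
# from the original source): it assumes a string flag FLAGS.model_type holding
# one of the values handled in main() ("kgcvae", "cvae", "hierbaseline"),
# consistent with the FLAGS.model_type reference used to build log_dir.
if __name__ == "__main__":
    main(FLAGS.model_type)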