def main(config):
    """Train and evaluate a GanRnnAgent on the normalized MultiWOZ corpus.

    Trains with engine.gan_train (unless config.forward_only), then reloads
    the best checkpoint and runs the discriminator validator on the valid
    and test sets.
    """
    prepare_dirs_loggers(config, os.path.basename(__file__))

    corpus_client = corpora.NormMultiWozCorpus(config)
    dial_corpus = corpus_client.get_corpus()
    # NormMultiWozCorpus.get_corpus() returns a (train, valid, test) tuple.
    train_dial, valid_dial, test_dial = dial_corpus

    # (batch, state-noise-dim, action-noise-dim) — shape of the noise samples
    # drawn by the GAN validators.
    sample_shape = config.batch_size, config.state_noise_dim, config.action_noise_dim

    # evaluator = evaluators.BleuEvaluator(os.path.basename(__file__))
    evaluator = MultiWozEvaluator('SysWoz')

    # Create data loaders that feed the deep models.
    train_feed = data_loaders.BeliefDbDataLoaders("Train", train_dial, config)
    valid_feed = data_loaders.BeliefDbDataLoaders("Valid", valid_dial, config)
    test_feed = data_loaders.BeliefDbDataLoaders("Test", test_dial, config)

    model = GanRnnAgent(corpus_client, config)
    # Warm-start the context encoder from a previously trained LIRL session.
    load_context_encoder(
        model, os.path.join(config.log_dir, config.encoder_sess, "model_lirl"))

    # forward_only: evaluate a previously saved session; otherwise write
    # artifacts into the freshly created session directory.
    if config.forward_only:
        test_file = os.path.join(
            config.log_dir, config.load_sess,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.log_dir, config.load_sess, "model")
    else:
        test_file = os.path.join(
            config.session_dir,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.session_dir,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.session_dir, "model")

    if config.use_gpu:
        model.cuda()

    print("Evaluate initial model on Validate set")
    engine.disc_validate(model, valid_feed, config, sample_shape)

    print("Start training")
    if config.forward_only is False:
        try:
            engine.gan_train(model, train_feed, valid_feed, test_feed, config,
                             evaluator)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")

    # Fixed typo ("Trainig") and rejoined the message that had been split
    # across lines in the original.
    print("Training Done!\nStart Testing")
    model.load_state_dict(torch.load(model_file))
    engine.disc_validate(model, valid_feed, config, sample_shape)
    engine.disc_validate(model, test_feed, config, sample_shape)
def main(config):
    """Train and evaluate a DiVAE sentence model on the DailyDialog corpus.

    After training (unless config.forward_only), reloads the saved model,
    validates on the test set, sweeps the latent space, dumps latents to a
    pickle file and writes generated outputs to a text file.
    """
    prepare_dirs_loggers(config, os.path.basename(__file__))

    corpus_client = corpora.DailyDialogCorpus(config)
    dial_corpus = corpus_client.get_corpus()
    train_dial, valid_dial, test_dial = (dial_corpus['train'],
                                         dial_corpus['valid'],
                                         dial_corpus['test'])

    # NOTE(review): evaluator is labeled "CornellMovie" although the corpus
    # is DailyDialog — the label only tags log output, but confirm intent.
    evaluator = evaluators.BleuEvaluator("CornellMovie")

    # Create data loaders that feed the deep models.
    train_feed = data_loaders.DailyDialogLoader("Train", train_dial, config)
    valid_feed = data_loaders.DailyDialogLoader("Valid", valid_dial, config)
    test_feed = data_loaders.DailyDialogLoader("Test", test_dial, config)

    model = sent_models.DiVAE(corpus_client, config)

    if config.forward_only:
        test_file = os.path.join(
            config.log_dir, config.load_sess,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.log_dir, config.load_sess, "model")
    else:
        test_file = os.path.join(
            config.session_dir,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.session_dir,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.session_dir, "model")

    if config.use_gpu:
        model.cuda()

    if config.forward_only is False:
        try:
            engine.train(model, train_feed, valid_feed, test_feed, config,
                         evaluator, gen=utt_utils.generate)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")

    # Larger batch for evaluation only.
    config.batch_size = 50
    model.load_state_dict(torch.load(model_file))
    engine.validate(model, test_feed, config)
    utt_utils.sweep(model, test_feed, config, num_batch=50)

    with open(os.path.join(dump_file), "wb") as f:
        print("Dumping test to {}".format(dump_file))
        utt_utils.dump_latent(model, test_feed, config, f, num_batch=None)

    # NOTE(review): test_file is opened in binary mode although generate()
    # presumably writes text — verify utt_utils.generate encodes its output.
    with open(os.path.join(test_file), "wb") as f:
        print("Saving test to {}".format(test_file))
        utt_utils.generate(model, test_feed, config, evaluator,
                           num_batch=None, dest_f=f)
def main(config):
    """Train a dialog-level VAE on the SMD corpus with an external vocab.

    The corpus class is chosen dynamically via config.corpus_client and its
    vocabulary is overridden from config.vocab before building the loaders.
    """
    corpus_client = getattr(corpora, config.corpus_client)(config)
    # Override the corpus vocabulary with a pre-built one so embeddings line
    # up with other models trained on the same vocab file.
    corpus_client.vocab, corpus_client.rev_vocab, corpus_client.unk_id = load_vocab(
        config.vocab)

    prepare_dirs_loggers(config, os.path.basename(__file__))

    dial_corpus = corpus_client.get_corpus()
    train_dial, valid_dial, test_dial = (dial_corpus['train'],
                                         dial_corpus['valid'],
                                         dial_corpus['test'])

    evaluator = evaluators.BleuEvaluator("CornellMovie")

    # Create data loaders that feed the deep models.
    train_feed = data_loaders.SMDDialogSkipLoader("Train", train_dial, config)
    valid_feed = data_loaders.SMDDialogSkipLoader("Valid", valid_dial, config)
    test_feed = data_loaders.SMDDialogSkipLoader("Test", test_dial, config)

    model = dialog_models.VAE(corpus_client, config)

    # NOTE(review): test_file / dump_file are computed here but unused in
    # this script; kept for parity with the sibling training scripts.
    if config.forward_only:
        test_file = os.path.join(
            config.log_dir, config.load_sess,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.log_dir, config.load_sess, "model")
    else:
        test_file = os.path.join(
            config.session_dir,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.session_dir,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.session_dir, "model")

    if config.use_gpu:
        model.cuda()

    if config.forward_only is False:
        try:
            engine.train(model, train_feed, valid_feed, test_feed, config,
                         evaluator, gen=dialog_utils.generate_vae)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")
def main(config):
    """Train a (W)GAN reward model on top of a pretrained VAE for MultiWOZ.

    Seeds all RNGs, builds the GAN agent, loads a pretrained VAE (path now
    configurable via config.vae_model_path), then runs engine.gan_train.
    """
    prepare_dirs_loggers(config, os.path.basename(__file__))

    # Seed every RNG in use so runs are reproducible.
    manualSeed = config.seed
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    np.random.seed(manualSeed)

    # (batch, state-noise-dim, action-noise-dim) for the GAN validators.
    sample_shape = config.batch_size, config.state_noise_dim, config.action_noise_dim

    # evaluator = evaluators.BleuEvaluator(os.path.basename(__file__))
    evaluator = False

    train_feed = WoZGanDataLoaders("train", config)
    valid_feed = WoZGanDataLoaders("val", config)
    test_feed = WoZGanDataLoaders("test", config)

    # action2name = load_action2name(config)
    action2name = None
    corpus_client = None

    # Alternative agents kept for reference:
    # model = GanAgent_AutoEncoder(corpus_client, config, action2name)
    # model = GanAgent_AutoEncoder_Encode(corpus_client, config, action2name)
    # model = GanAgent_AutoEncoder_State(corpus_client, config, action2name)
    if config.gan_type == 'wgan':
        model = WGanAgent_VAE_State(corpus_client, config, action2name)
    else:
        model = GanAgent_VAE_State(corpus_client, config, action2name)
    logger.info(summary(model, show_weights=False))

    model.discriminator.apply(weights_init)
    model.generator.apply(weights_init)
    model.vae.apply(weights_init)

    if config.forward_only:
        test_file = os.path.join(
            config.log_dir, config.load_sess,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.log_dir, config.load_sess, "model")
    else:
        test_file = os.path.join(
            config.session_dir,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.session_dir,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.session_dir, "model")
    vocab_file = os.path.join(config.session_dir, "vocab.json")

    if config.use_gpu:
        model.cuda()

    pred_list = []
    generator_samples = []

    print("Evaluate initial model on Validate set")
    model.eval()
    # policy_validate_for_human(model, valid_feed, config, sample_shape)
    disc_validate(model, valid_feed, config, sample_shape)
    _, sample_batch = gen_validate(model, valid_feed, config, sample_shape, -1)
    generator_samples.append([-1, sample_batch])
    machine_data, human_data = build_fake_data(model, valid_feed, config,
                                               sample_shape)
    model.train()

    print("Start VAE training")
    # This is for the training of the VAE. If you already have a pretrained
    # model, skip this step and load it below instead.
    # if config.forward_only is False:
    #     try:
    #         engine.vae_train(model, train_feed, valid_feed, test_feed, config)
    #     except KeyboardInterrupt:
    #         print("Training stopped by keyboard.")
    # print("AutoEncoder Training Done ! ")
    # load_model_vae(model, config)

    # Pretrained-VAE session path; overridable via config.vae_model_path
    # (default keeps the historical hard-coded path for compatibility).
    path = getattr(config, 'vae_model_path',
                   './logs/2019-09-06T10:50:18.034181-mwoz_gan_vae.py')
    load_model_vae(model, path)

    print("Start GAN training")
    if config.forward_only is False:
        try:
            engine.gan_train(model, machine_data, train_feed, valid_feed,
                             test_feed, config, evaluator, pred_list,
                             generator_samples)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")
    print("Reward Model Training Done ! ")
    print("Saved path: {}".format(model_file))
def main(config):
    """Train and fully evaluate a StED dialog model on the SMD corpus.

    After training, reloads the checkpoint, validates, clusters the latent
    space, selectively generates, and dumps clusters/latents/outputs.
    """
    corpus_client = getattr(corpora, config.corpus_client)(config)
    # Override the corpus vocabulary with a pre-built one.
    corpus_client.vocab, corpus_client.rev_vocab, corpus_client.unk_id = load_vocab(
        config.vocab)

    prepare_dirs_loggers(config, os.path.basename(__file__))

    dial_corpus = corpus_client.get_corpus()
    train_dial, valid_dial, test_dial = (dial_corpus['train'],
                                         dial_corpus['valid'],
                                         dial_corpus['test'])

    evaluator = evaluators.BleuEvaluator("CornellMovie")

    # Create data loaders that feed the deep models.
    train_feed = data_loaders.SMDDialogSkipLoader("Train", train_dial, config)
    valid_feed = data_loaders.SMDDialogSkipLoader("Valid", valid_dial, config)
    test_feed = data_loaders.SMDDialogSkipLoader("Test", test_dial, config)

    model = dialog_models.StED(corpus_client, config)

    if config.forward_only:
        test_file = os.path.join(
            config.log_dir, config.load_sess,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.log_dir, config.load_sess, "model")
    else:
        test_file = os.path.join(
            config.session_dir,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.session_dir,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.session_dir, "model")

    if config.use_gpu:
        model.cuda()

    if not config.forward_only:
        try:
            engine.train(model, train_feed, valid_feed, test_feed, config,
                         evaluator, gen=dialog_utils.generate_with_adv)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")

    # Smaller batch for evaluation only.
    config.batch_size = 10
    model.load_state_dict(torch.load(model_file))

    engine.validate(model, valid_feed, config)
    engine.validate(model, test_feed, config)
    dialog_utils.generate_with_adv(model, test_feed, config, None, num_batch=0)

    selected_clusters = utt_utils.latent_cluster(model, train_feed, config,
                                                 num_batch=None)
    selected_outs = dialog_utils.selective_generate(model, test_feed, config,
                                                    selected_clusters)
    print(len(selected_outs))

    # json.dump writes str, so the JSON files must be opened in text mode
    # (the original 'wb' raises TypeError on Python 3).
    with open(os.path.join(dump_file + '.json'), 'w') as f:
        json.dump(selected_clusters, f, indent=2)
    with open(os.path.join(dump_file + '.out.json'), 'w') as f:
        json.dump(selected_outs, f, indent=2)

    with open(os.path.join(dump_file), "wb") as f:
        print("Dumping test to {}".format(dump_file))
        dialog_utils.dump_latent(model, test_feed, config, f, num_batch=None)

    with open(os.path.join(test_file), "wb") as f:
        print("Saving test to {}".format(test_file))
        dialog_utils.gen_with_cond(model, test_feed, config, num_batch=None,
                                   dest_f=f)

    with open(os.path.join(test_file + '.txt'), "wb") as f:
        print("Saving test to {}".format(test_file))
        dialog_utils.generate(model, test_feed, config, evaluator,
                              num_batch=None, dest_f=f)
def main(config):
    """Extract LAED latent features for every data split and pickle them.

    Loads a trained LAED model plus its config, runs each of the
    train/valid/test splits through it, re-nests the flat features back per
    dialog, and writes one pickle per split (plus seed-utterance features
    when --process_seed_data is set).
    """
    laed_config = load_config(config.model)
    laed_config.use_gpu = config.use_gpu
    laed_config = process_config(laed_config)

    # Mirror the command-line options onto the loaded model config.
    setattr(laed_config, 'black_domains', config.black_domains)
    setattr(laed_config, 'black_ratio', config.black_ratio)
    setattr(laed_config, 'include_domain', True)
    setattr(laed_config, 'include_example', False)
    setattr(laed_config, 'include_state', True)
    setattr(laed_config, 'entities_file',
            'NeuralDialog-ZSDG/data/stanford/kvret_entities.json')
    setattr(laed_config, 'action_match', True)
    setattr(laed_config, 'batch_size', config.batch_size)
    setattr(laed_config, 'data_dir', config.data_dir)
    setattr(laed_config, 'include_eod', False)  # for StED model
    setattr(laed_config, 'domain_description', config.domain_description)

    if config.process_seed_data:
        # Seed responses are only defined for zero-shot ("Zsl*") corpora.
        assert config.corpus_client[:3] == 'Zsl', \
            'Incompatible coprus_client for --process_seed_data flag'

    corpus_client = getattr(corpora, config.corpus_client)(laed_config)
    if config.vocab:
        # NOTE(review): this sets unk_token while sibling scripts set unk_id
        # from the same load_vocab() tuple — confirm which attribute the
        # corpus classes actually read.
        corpus_client.vocab, corpus_client.rev_vocab, corpus_client.unk_token = \
            load_vocab(config.vocab)

    prepare_dirs_loggers(config, os.path.basename(__file__))
    dial_corpus = corpus_client.get_corpus()

    model = load_model(config.model, config.model_name, config.model_type,
                       corpus_client=corpus_client)
    if config.use_gpu:
        model.cuda()

    for dataset_name in ['train', 'valid', 'test']:
        dataset = dial_corpus[dataset_name]
        # Dialog models consume whole dialogs; utterance models consume the
        # flattened list of utterances.
        feed_data = dataset if config.model_type == 'dialog' \
            else reduce(lambda x, y: x + y, dataset, [])

        main_feed = getattr(data_loaders, config.data_loader)(
            "Test", feed_data, laed_config)
        features = process_data_feed(model, main_feed, laed_config)

        # Padding convention differs per loader; deflatten must match it.
        if config.data_loader == 'SMDDialogSkipLoader':
            pad_mode = 'start_end'
        elif config.data_loader == 'SMDDataLoader':
            pad_mode = 'start'
        else:
            pad_mode = None
        features = deflatten_laed_features(features, dataset,
                                           pad_mode=pad_mode)
        # Sanity check: one feature row per utterance.
        assert sum(map(len, dataset)) == sum(map(lambda x: x.shape[0],
                                                 features))

        if not os.path.exists(config.out_folder):
            os.makedirs(config.out_folder)
        # pickle.dump writes bytes, so the file must be opened in binary
        # mode (the original 'w' raises TypeError on Python 3).
        with open(os.path.join(config.out_folder,
                               'dialogs_{}.pkl'.format(dataset_name)),
                  'wb') as result_out:
            pickle.dump(features, result_out)

    if config.process_seed_data:
        seed_utts = corpus_client.get_seed_responses(
            utt_cnt=len(corpus_client.domain_descriptions))
        seed_feed = data_loaders.PTBDataLoader("Seed", seed_utts, laed_config)
        seed_features = process_data_feed(model, seed_feed, laed_config)
        with open(os.path.join(config.out_folder, 'seed_utts.pkl'),
                  'wb') as result_out:
            pickle.dump(seed_features, result_out)
def main(config):
    """Train and fully evaluate an AeED dialog model on normalized MultiWOZ.

    After training, reloads the checkpoint, reports BLEU and validation
    metrics, clusters the latent space over all splits, selectively
    generates, and dumps clusters/latents/outputs to the session directory.
    """
    prepare_dirs_loggers(config, os.path.basename(__file__))

    corpus_client = corpora.NormMultiWozCorpus(config)
    dial_corpus = corpus_client.get_corpus()
    # NormMultiWozCorpus.get_corpus() returns a (train, valid, test) tuple.
    train_dial, valid_dial, test_dial = dial_corpus

    # evaluator = evaluators.BleuEvaluator(os.path.basename(__file__))
    evaluator = MultiWozEvaluator('SysWoz')

    # Create data loaders that feed the deep models.
    train_feed = data_loaders.BeliefDbDataLoaders("Train", train_dial, config)
    valid_feed = data_loaders.BeliefDbDataLoaders("Valid", valid_dial, config)
    test_feed = data_loaders.BeliefDbDataLoaders("Test", test_dial, config)

    model = dialog_models.AeED(corpus_client, config)

    if config.forward_only:
        test_file = os.path.join(
            config.log_dir, config.load_sess,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.log_dir, config.load_sess, "model")
    else:
        test_file = os.path.join(
            config.session_dir,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.session_dir,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.session_dir, "model")

    if config.use_gpu:
        model.cuda()

    if config.forward_only is False:
        try:
            engine.train(model, train_feed, valid_feed, test_feed, config,
                         evaluator, gen=dialog_utils.generate_with_adv)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")

    # Smaller batch for evaluation only.
    config.batch_size = 10
    model.load_state_dict(torch.load(model_file))

    logger.info("Test Bleu:")
    dialog_utils.generate(model, test_feed, config, evaluator,
                          test_feed.num_batch, False)
    engine.validate(model, valid_feed, config)
    engine.validate(model, test_feed, config)
    dialog_utils.generate_with_adv(model, test_feed, config, None,
                                   num_batch=None)

    # Cluster the latent space; cluster ids/counters are threaded through
    # the three calls so cluster naming stays consistent across splits.
    cluster_name_id = None
    action_count = 0
    selected_clusters, index_cluster_id_train, cluster_name_id, action_count = \
        utt_utils.latent_cluster(model, train_feed, config, None, 0,
                                 num_batch=None)
    _, index_cluster_id_test, cluster_name_id, action_count = \
        utt_utils.latent_cluster(model, test_feed, config, cluster_name_id,
                                 action_count, num_batch=None)
    _, index_cluster_id_valid, cluster_name_id, action_count = \
        utt_utils.latent_cluster(model, valid_feed, config, cluster_name_id,
                                 action_count, num_batch=None)

    selected_outs = dialog_utils.selective_generate(model, test_feed, config,
                                                    selected_clusters)
    print(len(selected_outs))

    # json.dump writes str, so the JSON files must be opened in text mode
    # (the original 'wb' raises TypeError on Python 3).
    with open(os.path.join(dump_file + '.json'), 'w') as f:
        json.dump(selected_clusters, f, indent=2)
    with open(os.path.join(dump_file + '.cluster_id.json.Train'), 'w') as f:
        json.dump(index_cluster_id_train, f, indent=2)
    with open(os.path.join(dump_file + '.cluster_id.json.Test'), 'w') as f:
        json.dump(index_cluster_id_test, f, indent=2)
    with open(os.path.join(dump_file + '.cluster_id.json.Valid'), 'w') as f:
        json.dump(index_cluster_id_valid, f, indent=2)
    with open(os.path.join(dump_file + '.out.json'), 'w') as f:
        json.dump(selected_outs, f, indent=2)

    with open(os.path.join(dump_file), "wb") as f:
        print("Dumping test to {}".format(dump_file))
        dialog_utils.dump_latent(model, test_feed, config, f, num_batch=None)

    with open(os.path.join(test_file), "wb") as f:
        print("Saving test to {}".format(test_file))
        dialog_utils.gen_with_cond(model, test_feed, config, num_batch=None,
                                   dest_f=f)

    with open(os.path.join(test_file + '.txt'), "wb") as f:
        print("Saving test to {}".format(test_file))
        dialog_utils.generate(model, test_feed, config, evaluator,
                              num_batch=None, dest_f=f)

    print("All done. Have a nice day")
def main(config):
    """Train a SAT-input GAN agent on the WoZ GAN data loaders.

    Seeds all RNGs, builds GanAgent_SAT_WoZ (the only supported
    gan_type/input_type combination), optionally warm-starts its RNN state
    encoder, then runs engine.gan_train.
    """
    prepare_dirs_loggers(config, os.path.basename(__file__))

    # Seed every RNG in use so runs are reproducible.
    manualSeed = config.seed
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    np.random.seed(manualSeed)

    # (batch, state-noise-dim, action-noise-dim) for the GAN validators.
    sample_shape = config.batch_size, config.state_noise_dim, config.action_noise_dim
    evaluator = evaluators.BleuEvaluator(os.path.basename(__file__))

    train_feed = WoZGanDataLoaders("train", config)
    valid_feed = WoZGanDataLoaders("val", config)
    test_feed = WoZGanDataLoaders("test", config)

    # action2name = load_action2name(config)
    action2name = None
    corpus_client = None

    if config.gan_type == 'gan' and config.input_type == 'sat':
        model = GanAgent_SAT_WoZ(corpus_client, config, action2name)
    else:
        raise ValueError("No such GAN types: {}".format(config.gan_type))
    logger.info(summary(model, show_weights=False))

    model.discriminator.apply(weights_init)
    model.generator.apply(weights_init)
    if config.state_type == 'rnn':
        # Warm-start the RNN state encoder from a pretrained LIRL session.
        load_context_encoder(
            model,
            os.path.join(config.log_dir, config.encoder_sess, "model_lirl"))

    if config.forward_only:
        test_file = os.path.join(
            config.log_dir, config.load_sess,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.log_dir, config.load_sess, "model")
    else:
        test_file = os.path.join(
            config.session_dir,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.session_dir,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.session_dir, "model")
    vocab_file = os.path.join(config.session_dir, "vocab.json")

    if config.use_gpu:
        model.cuda()

    pred_list = []
    generator_samples = []

    print("Evaluate initial model on Validate set")
    model.eval()
    # policy_validate_for_human(model, valid_feed, config, sample_shape)
    disc_validate(model, valid_feed, config, sample_shape)
    _, sample_batch = gen_validate(model, valid_feed, config, sample_shape, -1)
    generator_samples.append([-1, sample_batch])
    machine_data, human_data = build_fake_data(model, valid_feed, config,
                                               sample_shape)
    model.train()

    print("Start training")
    if config.forward_only is False:
        try:
            engine.gan_train(model, machine_data, train_feed, valid_feed,
                             test_feed, config, evaluator, pred_list,
                             generator_samples)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")
    # save_data_for_tsne(human_data, machine_data, generator_samples,
    #                    pred_list, config)
    print("Training Done ! ")

    # Final-model evaluation kept disabled, as in the original:
    # model.load_state_dict(torch.load(model_file))
    # print("Evaluate final model on Validate set")
    # model.eval()
    # policy_validate_for_human(model, valid_feed, config, sample_shape)
    # disc_validate(model, valid_feed, config, sample_shape)
    # gen_validate(model, valid_feed, config, sample_shape)
    # print("Evaluate final model on Test set")
    # policy_validate_for_human(model, test_feed, config, sample_shape)
    # disc_validate(model, test_feed, config, sample_shape)
    # gen_validate(model, test_feed, config, sample_shape)
def main(config):
    """Train and evaluate a LIRL model on the Movie corpus.

    In forward_only mode the config is reloaded from the pretrained
    session's config.json; otherwise a fresh timestamped session directory
    is created. After training the best checkpoint is reloaded and used to
    validate and to write generated valid/test files.
    """
    set_seed(config.seed)
    start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                               time.localtime(time.time()))
    stats_path = 'sys_config_log_model'

    if config.forward_only:
        saved_path = os.path.join(stats_path, config.pretrain_folder)
        # Replace the live config with the one saved at training time.
        config = Pack(json.load(open(os.path.join(saved_path,
                                                  'config.json'))))
        config['forward_only'] = True
    else:
        saved_path = os.path.join(
            stats_path,
            start_time + '-' + os.path.basename(__file__).split('.')[0])
        if not os.path.exists(saved_path):
            os.makedirs(saved_path)

    config.saved_path = saved_path
    prepare_dirs_loggers(config)
    logger = logging.getLogger()
    logger.info('[START]\n{}\n{}'.format(start_time, '=' * 30))

    corpus = corpora.MovieCorpus(config)
    train_dial, valid_dial, test_dial = corpus.get_corpus()

    sample_shape = config.batch_size, config.state_noise_dim, config.action_noise_dim
    # evaluator = MultiWozEvaluator(os.path.basename(__file__))
    evaluator = BleuEvaluator(os.path.basename(__file__))

    # Create data loaders that feed the deep models.
    train_data = MovieDataLoaders("Train", train_dial, config)
    valid_data = MovieDataLoaders("Valid", valid_dial, config)
    test_data = MovieDataLoaders("Test", test_dial, config)

    model = LIRL(corpus, config)
    if config.use_gpu:
        model.cuda()

    best_epoch = None
    if not config.forward_only:
        try:
            best_epoch = train(model, train_data, valid_data, test_data,
                               config, evaluator, gen=task_generate)
        except KeyboardInterrupt:
            print('Training stopped by keyboard.')

    if best_epoch is None:
        # Training was skipped or interrupted: fall back to the latest
        # saved '<epoch>-model' checkpoint (ignoring RL checkpoints).
        model_ids = sorted([
            int(p.replace('-model', '')) for p in os.listdir(saved_path)
            if 'model' in p and 'rl' not in p
        ])
        best_epoch = model_ids[-1]

    print("$$$ Load {}-model".format(best_epoch))
    # Fixed evaluation batch size; the redundant `best_epoch = best_epoch`
    # self-assignment from the original was removed.
    config.batch_size = 32
    model.load_state_dict(
        torch.load(os.path.join(saved_path, '{}-model'.format(best_epoch))))

    logger.info("Forward Only Evaluation")
    validate(model, valid_data, config)
    validate(model, test_data, config)

    with open(
            os.path.join(saved_path,
                         '{}_{}_valid_file.txt'.format(start_time,
                                                       best_epoch)),
            'w') as f:
        task_generate(model, valid_data, config, evaluator, num_batch=None,
                      dest_f=f)
    with open(
            os.path.join(saved_path,
                         '{}_{}_test_file.txt'.format(start_time,
                                                      best_epoch)),
            'w') as f:
        task_generate(model, test_data, config, evaluator, num_batch=None,
                      dest_f=f)

    end_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                             time.localtime(time.time()))
    print('[END]', end_time, '=' * 30)
def main(config):
    """Train a state-VAE/action-segment GAN agent on WoZ embedded data.

    Seeds all RNGs, builds GanAgent_StateVaeActionSeg, pre-trains its VAE
    with engine.vae_train, reloads the VAE, then trains the GAN reward
    model with engine.gan_train.
    """
    prepare_dirs_loggers(config, os.path.basename(__file__))

    # Seed every RNG in use so runs are reproducible.
    manualSeed = config.seed
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    np.random.seed(manualSeed)

    # (batch, state-noise-dim, action-noise-dim) for the GAN validators.
    sample_shape = config.batch_size, config.state_noise_dim, config.action_noise_dim
    evaluator = evaluators.BleuEvaluator(os.path.basename(__file__))

    train_feed = WoZGanDataLoaders_StateActionEmbed("train", config)
    valid_feed = WoZGanDataLoaders_StateActionEmbed("val", config)
    test_feed = WoZGanDataLoaders_StateActionEmbed("test", config)

    # action2name = load_action2name(config)
    action2name = None
    corpus_client = None

    # model = GanAgent_VAE_StateActioneEmbed(corpus_client, config, action2name)
    model = GanAgent_StateVaeActionSeg(corpus_client, config, action2name)
    logger.info(summary(model, show_weights=False))

    model.discriminator.apply(weights_init)
    model.generator.apply(weights_init)
    model.vae.apply(weights_init)

    if config.forward_only:
        test_file = os.path.join(
            config.log_dir, config.load_sess,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.log_dir, config.load_sess, "model")
    else:
        test_file = os.path.join(
            config.session_dir,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.session_dir,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.session_dir, "model")
    vocab_file = os.path.join(config.session_dir, "vocab.json")

    if config.use_gpu:
        model.cuda()

    pred_list = []
    generator_samples = []

    print("Evaluate initial model on Validate set")
    model.eval()
    # policy_validate_for_human(model, valid_feed, config, sample_shape)
    disc_validate(model, valid_feed, config, sample_shape)
    _, sample_batch = gen_validate(model, valid_feed, config, sample_shape, -1)
    generator_samples.append([-1, sample_batch])
    machine_data, human_data = build_fake_data(model, valid_feed, config,
                                               sample_shape)
    model.train()

    print("Start training")
    # Stage 1: pre-train the VAE.
    if config.forward_only is False:
        try:
            engine.vae_train(model, train_feed, valid_feed, test_feed, config)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")
    print("AutoEncoder Training Done ! ")
    load_model_vae(model, config)
    # Alternative pretrained-VAE sessions kept for reference:
    # path = 'logs/2019-09-18T12:20:26.063708-mwoz_gan_vae_StateActionEmbed.py'  # embed version
    # path = 'logs/2019-09-18T12:24:45.517636-mwoz_gan_vae_StateActionEmbed.py'  # embed_merged version
    # path = 'logs/2019-09-18T17:21:35.420069-mwoz_gan_vae_StateActionEmbed.py'  # state_vae action_seg, no hotel domain
    # load_model_vae(model, path)

    # Stage 2: train the GAN reward model on top of the VAE.
    if config.forward_only is False:
        try:
            engine.gan_train(model, machine_data, train_feed, valid_feed,
                             test_feed, config, evaluator, pred_list,
                             generator_samples)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")
    print("Reward Model Training Done ! ")
    print("Saved path: {}".format(model_file))