def main(config):
    """Train (unless forward-only) and evaluate a zero-shot HRED variant on SMD.

    Builds the Stanford (SMD) corpus with warm-up seed responses, selects one of
    four model classes along two axes (action matching x pointer decoding),
    optionally trains, then reloads the best checkpoint and generates test-set
    responses into a timestamped file inside the session directory.

    Args:
        config: parsed experiment configuration (argparse-style namespace).
            Mutated in place: ``batch_size`` is forced to 20 for evaluation.
    """
    prepare_dirs_loggers(config, os.path.basename(__file__))

    corpus_client = ZslStanfordCorpus(config)
    warmup_data = corpus_client.get_seed_responses(config.target_example_cnt)
    dial_corpus = corpus_client.get_corpus()
    train_dial = dial_corpus['train']
    valid_dial = dial_corpus['valid']
    test_dial = dial_corpus['test']

    evaluator = evaluators.BleuEntEvaluator("SMD", corpus_client.ent_metas)

    # Data loaders feeding the models; only the training feed mixes in the
    # warm-up seed responses.
    train_feed = data_loaders.ZslSMDDialDataLoader("Train", train_dial, config,
                                                   warmup_data)
    valid_feed = data_loaders.ZslSMDDialDataLoader("Valid", valid_dial, config)
    test_feed = data_loaders.ZslSMDDialDataLoader("Test", test_dial, config)

    # Model choice along two independent flags: action matching selects the
    # zero-shot variants; use_ptr selects the pointer-network decoder.
    if config.action_match:
        model_class = models.ZeroShotPtrHRED if config.use_ptr else models.ZeroShotHRED
    else:
        model_class = models.PtrHRED if config.use_ptr else models.HRED
    model = model_class(corpus_client, config)

    # Inference-only runs reuse a previously saved session; training runs
    # write into the freshly created session directory. Output paths are
    # derived once from the resolved session_dir.
    if config.forward_only:
        session_dir = os.path.join(config.log_dir, config.load_sess)
    else:
        session_dir = config.session_dir
    test_file = os.path.join(
        session_dir, "{}-test-{}.txt".format(get_time(), config.gen_type))
    model_file = os.path.join(session_dir, "model")

    if config.use_gpu:
        model.cuda()

    if not config.forward_only:
        try:
            train(model, train_feed, valid_feed, test_feed, config,
                  evaluator, gen=hred_utils.generate)
        except KeyboardInterrupt:
            # Allow manual early stopping; fall through to evaluation of
            # whatever checkpoint was last saved.
            print("Training stopped by keyboard.")

    # Evaluate the saved checkpoint on the test set with a fixed batch size.
    config.batch_size = 20
    model.load_state_dict(torch.load(model_file))
    validate(model, test_feed, config)
    with open(test_file, "wb") as f:
        hred_utils.generate(model, test_feed, config, evaluator,
                            num_batch=None, dest_f=f)
def main(config):
    """Train (unless forward-only) and evaluate a dialog model chosen by config.

    Resolves the corpus client, data-loader class and model via the
    ``get_corpus_client`` / ``get_data_loader`` / ``get_model`` factories,
    optionally trains, then reloads the saved checkpoint and writes generated
    test responses to a timestamped file in the session directory.

    Args:
        config: parsed experiment configuration (argparse-style namespace).
            Mutated in place: ``batch_size`` is forced to 10 for evaluation.
    """
    prepare_dirs_loggers(config, os.path.basename(__file__))

    corpus_client_class = get_corpus_client(config)
    train_client = corpus_client_class(config)
    corpus = train_client.get_corpus()
    train_dial = corpus['train']
    valid_dial = corpus['valid']
    test_dial = corpus['test']

    evaluator = evaluators.BleuEntEvaluator("SMD", train_client.ent_metas)

    # Data loaders feeding the models, one per split.
    data_loader_class = get_data_loader(config)
    train_feed = data_loader_class("Train", train_dial, config)
    valid_feed = data_loader_class("Valid", valid_dial, config)
    test_feed = data_loader_class("Test", test_dial, config)

    model = get_model(config, train_client)

    # Inference-only runs reuse a previously saved session; training runs
    # write into the freshly created session directory. Output paths are
    # derived once from the resolved session_dir.
    if config.forward_only:
        session_dir = os.path.join(config.log_dir, config.load_sess)
    else:
        session_dir = config.session_dir
    test_file = os.path.join(
        session_dir, "{}-test-{}.txt".format(get_time(), config.gen_type))
    model_file = os.path.join(session_dir, "model")

    if config.use_gpu:
        model.cuda()

    if not config.forward_only:
        try:
            train(model, train_feed, valid_feed, test_feed, config,
                  evaluator, gen=hred_utils.generate)
        except KeyboardInterrupt:
            # Allow manual early stopping; fall through to evaluation of
            # whatever checkpoint was last saved.
            print("Training stopped by keyboard.")

    # Evaluate the saved checkpoint on the test set with a fixed batch size.
    config.batch_size = 10
    model.load_state_dict(torch.load(model_file))
    validate(model, test_feed, config)
    with open(test_file, "wb") as f:
        hred_utils.generate(model, test_feed, config, evaluator,
                            num_batch=None, dest_f=f)
def main(config):
    """Evaluation-only driver: run a saved LAED-conditioned pointer HRED on SMD.

    Loads the corpus client named by ``config.corpus_client``, injects an
    externally built vocabulary, loads precomputed LAED latent features, and
    evaluates a previously trained ``ZeroShotLAPtrHRED`` checkpoint on the
    test split, writing generated responses to a timestamped file.

    Args:
        config: parsed experiment configuration (argparse-style namespace).
            Mutated in place: ``laed_z_size`` is set from the loaded features
            and ``batch_size`` is forced to 20.

    Raises:
        NotImplementedError: for any configuration other than
            ``action_match`` with ``use_ptr``.
    """
    prepare_dirs_loggers(config, os.path.basename(__file__))

    corpus_client = getattr(corpora, config.corpus_client)(config)
    # Override the client's vocabulary with the one saved alongside the
    # trained model so token ids line up with the checkpoint.
    corpus_client.vocab, corpus_client.rev_vocab, corpus_client.unk_id = \
        load_vocab(config.vocab)

    corpus = corpus_client.get_corpus()
    train_dial = corpus['train']
    valid_dial = corpus['valid']
    test_dial = corpus['test']

    evaluator = evaluators.BleuEntEvaluator("SMD", corpus_client.ent_metas)

    # LAED features are stored for train+valid+test concatenated in order;
    # slice off the tail that corresponds to the test dialogs.
    laed_z = load_laed_features(config.laed_z_folder)
    config.laed_z_size = laed_z['dialog'][0].shape[-1]
    laed_z_test = laed_z['dialog'][len(train_dial) + len(valid_dial):]

    test_feed = data_loaders.ZslLASMDDialDataLoader("Test", test_dial,
                                                    laed_z_test, [], config)

    # Only the action-matching pointer variant is supported here.
    if config.action_match:
        if config.use_ptr:
            model = ZeroShotLAPtrHRED(corpus_client, config)
        else:
            raise NotImplementedError()
    else:
        raise NotImplementedError()

    session_dir = os.path.join(config.log_dir, config.load_sess)
    test_file = os.path.join(
        session_dir, "{}-test-{}.txt".format(get_time(), config.gen_type))
    model_file = os.path.join(config.log_dir, config.load_sess, "model")

    if config.use_gpu:
        model.cuda()

    # Evaluate the saved checkpoint on the test set with a fixed batch size.
    config.batch_size = 20
    model.load_state_dict(torch.load(model_file))
    validate(model, test_feed, config)
    with open(test_file, "wb") as f:
        hred_utils.generate(model, test_feed, config, evaluator,
                            num_batch=None, dest_f=f)
def main(config):
    """Train (unless forward-only) and evaluate with per-domain seed budgets.

    Builds warm-up seed responses with a per-domain utterance budget
    (``source_example_cnt`` by default, ``target_example_cnt`` for domains
    listed in ``config.black_domains``), picks the LAED-aware data loader when
    LAED feature folders are configured, optionally trains, then reloads the
    saved checkpoint and writes generated test responses to a timestamped file.

    Args:
        config: parsed experiment configuration (argparse-style namespace).
            Mutated in place: ``batch_size`` is forced to 10 for evaluation.
    """
    prepare_dirs_loggers(config, os.path.basename(__file__))

    train_client = getattr(corpora, config.corpus_client)(config)

    # Per-domain seed-response budget: held-out ("black") domains get the
    # target-domain count, every other domain the source-domain count.
    utt_cnt_map = defaultdict(lambda: config.source_example_cnt)
    for black_domain in config.black_domains:
        utt_cnt_map[black_domain] = config.target_example_cnt
    warmup_data = train_client.get_seed_responses(utt_cnt_map)

    train_corpus = train_client.get_corpus()
    train_dial = train_corpus['train']
    valid_dial = train_corpus['valid']
    test_dial = train_corpus['test']

    evaluator = evaluators.BleuEntEvaluator("SMD", train_client.ent_metas)

    # Use the LAED-conditioned loader only when LAED feature folders are given.
    if config.laed_z_folders:
        data_loader_class = data_loaders.ZslLASMDDialDataLoader
    else:
        data_loader_class = data_loaders.ZslSMDDialDataLoader

    # Data loaders feeding the models; only the training feed mixes in the
    # warm-up seed responses.
    train_feed = data_loader_class("Train", train_dial, config, warmup_data)
    valid_feed = data_loader_class("Valid", valid_dial, config)
    test_feed = data_loader_class("Test", test_dial, config)

    model = get_model(train_client, config)

    # Inference-only runs reuse a previously saved session; training runs
    # write into the freshly created session directory. Output paths are
    # derived once from the resolved session_dir.
    if config.forward_only:
        session_dir = os.path.join(config.log_dir, config.load_sess)
    else:
        session_dir = config.session_dir
    test_file = os.path.join(
        session_dir, "{}-test-{}.txt".format(get_time(), config.gen_type))
    model_file = os.path.join(session_dir, "model")

    if config.use_gpu:
        model.cuda()

    if not config.forward_only:
        try:
            train(model, train_feed, valid_feed, test_feed, config,
                  evaluator, gen=hred_utils.generate)
        except KeyboardInterrupt:
            # Allow manual early stopping; fall through to evaluation of
            # whatever checkpoint was last saved.
            print("Training stopped by keyboard.")

    # Evaluate the saved checkpoint on the test set with a fixed batch size.
    config.batch_size = 10
    model.load_state_dict(torch.load(model_file))
    validate(model, test_feed, config)
    with open(test_file, "wb") as f:
        hred_utils.generate(model, test_feed, config, evaluator,
                            num_batch=None, dest_f=f)