Beispiel #1
0
def main(config):
    """Train and evaluate a GAN-based RNN dialog agent on NormMultiWoz.

    The context encoder is warm-started from a previously trained session
    (``config.encoder_sess``); GAN training is skipped when
    ``config.forward_only`` is set, in which case only the saved model is
    evaluated.
    """
    prepare_dirs_loggers(config, os.path.basename(__file__))

    corpus_client = corpora.NormMultiWozCorpus(config)

    dial_corpus = corpus_client.get_corpus()
    train_dial, valid_dial, test_dial = dial_corpus
    # Noise-sample shape consumed by the discriminator validation helper.
    sample_shape = config.batch_size, config.state_noise_dim, config.action_noise_dim
    evaluator = MultiWozEvaluator('SysWoz')
    # Create data loaders that feed the deep models.
    train_feed = data_loaders.BeliefDbDataLoaders("Train", train_dial, config)
    valid_feed = data_loaders.BeliefDbDataLoaders("Valid", valid_dial, config)
    test_feed = data_loaders.BeliefDbDataLoaders("Test", test_dial, config)
    model = GanRnnAgent(corpus_client, config)
    # Warm-start the context encoder from the pre-trained LIRL session.
    load_context_encoder(
        model, os.path.join(config.log_dir, config.encoder_sess, "model_lirl"))

    # Resolve output paths: reuse an old session dir in forward-only mode,
    # otherwise write into the fresh session dir.
    if config.forward_only:
        test_file = os.path.join(
            config.log_dir, config.load_sess,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.log_dir, config.load_sess, "model")
    else:
        test_file = os.path.join(
            config.session_dir,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.session_dir,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.session_dir, "model")

    if config.use_gpu:
        model.cuda()

    print("Evaluate initial model on Validate set")
    engine.disc_validate(model, valid_feed, config, sample_shape)
    print("Start training")

    if config.forward_only is False:
        try:
            engine.gan_train(model, train_feed, valid_feed, test_feed, config,
                             evaluator)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")
    # BUGFIX: original message read "Trainig Done!".
    print("Training Done! Start Testing")
    # Reload the best checkpoint and evaluate on validation and test sets.
    model.load_state_dict(torch.load(model_file))
    engine.disc_validate(model, valid_feed, config, sample_shape)
    engine.disc_validate(model, test_feed, config, sample_shape)
    '''
Beispiel #2
0
def main(config):
    """Train a DiVAE sentence model on the DailyDialog corpus.

    After training (or directly in forward-only mode) the saved checkpoint
    is reloaded, validated, swept, and its latent codes plus generated
    responses are dumped to files in the session directory.
    """
    prepare_dirs_loggers(config, os.path.basename(__file__))

    corpus_client = corpora.DailyDialogCorpus(config)

    dial_corpus = corpus_client.get_corpus()
    train_dial, valid_dial, test_dial = dial_corpus['train'],\
                                        dial_corpus['valid'],\
                                        dial_corpus['test']
    # NOTE(review): the evaluator label says "CornellMovie" although the
    # corpus is DailyDialog — kept as-is since it is only a display name;
    # confirm intent with the authors.
    evaluator = evaluators.BleuEvaluator("CornellMovie")

    # Create data loaders that feed the deep models.
    train_feed = data_loaders.DailyDialogLoader("Train", train_dial, config)
    valid_feed = data_loaders.DailyDialogLoader("Valid", valid_dial, config)
    test_feed = data_loaders.DailyDialogLoader("Test", test_dial, config)
    model = sent_models.DiVAE(corpus_client, config)

    # Resolve output paths: reuse a previous session in forward-only mode.
    if config.forward_only:
        test_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.log_dir, config.load_sess, "model")
    else:
        test_file = os.path.join(config.session_dir,
                                 "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.session_dir, "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.session_dir, "model")

    if config.use_gpu:
        model.cuda()

    if config.forward_only is False:
        try:
            engine.train(model, train_feed, valid_feed,
                         test_feed, config, evaluator, gen=utt_utils.generate)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")

    # Evaluate the saved checkpoint with a fixed batch size.
    config.batch_size = 50
    model.load_state_dict(torch.load(model_file))

    engine.validate(model, test_feed, config)
    utt_utils.sweep(model, test_feed, config, num_batch=50)

    # Single-argument os.path.join(...) wrappers were no-ops; use the paths
    # directly.
    with open(dump_file, "wb") as f:
        print("Dumping test to {}".format(dump_file))
        utt_utils.dump_latent(model, test_feed, config, f, num_batch=None)

    with open(test_file, "wb") as f:
        print("Saving test to {}".format(test_file))
        utt_utils.generate(model, test_feed, config, evaluator, num_batch=None, dest_f=f)
def main(config):
    """Train a dialog-level VAE on the SMD corpus.

    The corpus client class is chosen dynamically via
    ``config.corpus_client`` and its vocabulary is overridden from
    ``config.vocab``. This variant only trains; no post-training
    evaluation or dumping is performed here.
    """
    corpus_client = getattr(corpora, config.corpus_client)(config)
    # Replace the corpus-built vocabulary with a pre-computed one.
    corpus_client.vocab, corpus_client.rev_vocab, corpus_client.unk_id = load_vocab(
        config.vocab)
    prepare_dirs_loggers(config, os.path.basename(__file__))

    dial_corpus = corpus_client.get_corpus()
    train_dial, valid_dial, test_dial = (dial_corpus['train'],
                                         dial_corpus['valid'],
                                         dial_corpus['test'])

    # NOTE(review): label "CornellMovie" does not match the SMD loaders used
    # below — presumably copied from a sibling script; confirm.
    evaluator = evaluators.BleuEvaluator("CornellMovie")

    # create data loader that feed the deep models
    train_feed = data_loaders.SMDDialogSkipLoader("Train", train_dial, config)
    valid_feed = data_loaders.SMDDialogSkipLoader("Valid", valid_dial, config)
    test_feed = data_loaders.SMDDialogSkipLoader("Test", test_dial, config)
    model = dialog_models.VAE(corpus_client, config)

    # Output paths are computed for parity with sibling scripts, but
    # test_file/dump_file/model_file are unused in this training-only variant.
    if config.forward_only:
        test_file = os.path.join(
            config.log_dir, config.load_sess,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.log_dir, config.load_sess, "model")
    else:
        test_file = os.path.join(
            config.session_dir,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.session_dir,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.session_dir, "model")

    if config.use_gpu:
        model.cuda()

    if config.forward_only is False:
        try:
            engine.train(model,
                         train_feed,
                         valid_feed,
                         test_feed,
                         config,
                         evaluator,
                         gen=dialog_utils.generate_vae)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")
Beispiel #4
0
def main(config):
    """Adversarially train a VAE-state GAN reward model on WoZ data.

    Seeds all RNGs from ``config.seed``, loads a pretrained VAE into the
    agent, then runs GAN training (``engine.gan_train``) unless
    ``config.forward_only`` is set.
    """
    prepare_dirs_loggers(config, os.path.basename(__file__))
    # Seed python/torch/numpy RNGs so runs are reproducible.
    manualSeed = config.seed
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    np.random.seed(manualSeed)
    # Noise-sample shape consumed by the disc/gen validation helpers.
    sample_shape = config.batch_size, config.state_noise_dim, config.action_noise_dim

    # BLEU evaluation is disabled for this pipeline.
    evaluator = False

    train_feed = WoZGanDataLoaders("train", config)
    valid_feed = WoZGanDataLoaders("val", config)
    test_feed = WoZGanDataLoaders("test", config)

    # These agents take no corpus client / action-name mapping here.
    action2name = None
    corpus_client = None
    if config.gan_type == 'wgan':
        model = WGanAgent_VAE_State(corpus_client, config, action2name)
    else:
        model = GanAgent_VAE_State(corpus_client, config, action2name)

    logger.info(summary(model, show_weights=False))
    model.discriminator.apply(weights_init)
    model.generator.apply(weights_init)
    model.vae.apply(weights_init)

    # Resolve output paths: reuse an old session dir in forward-only mode.
    if config.forward_only:
        test_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.log_dir, config.load_sess, "model")
    else:
        test_file = os.path.join(config.session_dir,
                                 "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.session_dir, "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.session_dir, "model")
        vocab_file = os.path.join(config.session_dir, "vocab.json")

    if config.use_gpu:
        model.cuda()

    pred_list = []
    generator_samples = []
    print("Evaluate initial model on Validate set")
    model.eval()
    disc_validate(model, valid_feed, config, sample_shape)
    _, sample_batch = gen_validate(model, valid_feed, config, sample_shape, -1)
    generator_samples.append([-1, sample_batch])
    # Real ("human") and generated ("machine") batches used during GAN training.
    machine_data, human_data = build_fake_data(model, valid_feed, config, sample_shape)

    model.train()
    print("Start VAE training")

    # VAE pre-training (engine.vae_train) is assumed to be done already; this
    # script loads a pretrained VAE instead. The session path used to be
    # hard-coded; it can now be overridden via ``config.vae_sess_path``
    # (default keeps the original behavior).
    path = getattr(config, 'vae_sess_path',
                   './logs/2019-09-06T10:50:18.034181-mwoz_gan_vae.py')
    load_model_vae(model, path)

    print("Start GAN training")

    if config.forward_only is False:
        try:
            engine.gan_train(model, machine_data, train_feed, valid_feed, test_feed, config, evaluator, pred_list, generator_samples)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")
    print("Reward Model Training Done ! ")
    print("Saved path: {}".format(model_file))
def main(config):
    """Train StED on the SMD corpus, then dump clusters, latents and samples.

    After training (skipped in forward-only mode) the best checkpoint is
    reloaded, validated, latent clusters are computed on the training set,
    and several artifact files are written next to the session/log dir.
    """
    corpus_client = getattr(corpora, config.corpus_client)(config)
    # Replace the corpus-built vocabulary with a pre-computed one.
    corpus_client.vocab, corpus_client.rev_vocab, corpus_client.unk_id = load_vocab(
        config.vocab)
    prepare_dirs_loggers(config, os.path.basename(__file__))

    dial_corpus = corpus_client.get_corpus()
    train_dial, valid_dial, test_dial = (dial_corpus['train'],
                                         dial_corpus['valid'],
                                         dial_corpus['test'])

    evaluator = evaluators.BleuEvaluator("CornellMovie")

    # Create data loaders that feed the deep models.
    train_feed = data_loaders.SMDDialogSkipLoader("Train", train_dial, config)
    valid_feed = data_loaders.SMDDialogSkipLoader("Valid", valid_dial, config)
    test_feed = data_loaders.SMDDialogSkipLoader("Test", test_dial, config)
    model = dialog_models.StED(corpus_client, config)

    # Resolve output paths: reuse an old session dir in forward-only mode.
    if config.forward_only:
        test_file = os.path.join(
            config.log_dir, config.load_sess,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.log_dir, config.load_sess, "model")
    else:
        test_file = os.path.join(
            config.session_dir,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.session_dir,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.session_dir, "model")

    if config.use_gpu:
        model.cuda()

    if not config.forward_only:
        try:
            engine.train(model,
                         train_feed,
                         valid_feed,
                         test_feed,
                         config,
                         evaluator,
                         gen=dialog_utils.generate_with_adv)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")

    # Evaluate the saved checkpoint with a small fixed batch size.
    config.batch_size = 10
    model.load_state_dict(torch.load(model_file))
    engine.validate(model, valid_feed, config)
    engine.validate(model, test_feed, config)

    dialog_utils.generate_with_adv(model, test_feed, config, None, num_batch=0)
    selected_clusters = utt_utils.latent_cluster(model,
                                                 train_feed,
                                                 config,
                                                 num_batch=None)
    selected_outs = dialog_utils.selective_generate(model, test_feed, config,
                                                    selected_clusters)
    print(len(selected_outs))

    # BUGFIX: json.dump writes str, so under Python 3 these files must be
    # opened in text mode ('w'), not binary ('wb'). Also dropped the no-op
    # single-argument os.path.join(...) wrappers.
    with open(dump_file + '.json', 'w') as f:
        json.dump(selected_clusters, f, indent=2)

    with open(dump_file + '.out.json', 'w') as f:
        json.dump(selected_outs, f, indent=2)

    with open(dump_file, "wb") as f:
        print("Dumping test to {}".format(dump_file))
        dialog_utils.dump_latent(model, test_feed, config, f, num_batch=None)

    # NOTE(review): the two dumps below keep binary mode from the original;
    # if gen_with_cond/generate write str they would need 'w' too — confirm.
    with open(test_file, "wb") as f:
        print("Saving test to {}".format(test_file))
        dialog_utils.gen_with_cond(model,
                                   test_feed,
                                   config,
                                   num_batch=None,
                                   dest_f=f)

    with open(test_file + '.txt', "wb") as f:
        print("Saving test to {}".format(test_file))
        dialog_utils.generate(model,
                              test_feed,
                              config,
                              evaluator,
                              num_batch=None,
                              dest_f=f)
Beispiel #6
0
def main(config):
    """Extract LAED latent features for each dialog split and pickle them.

    Loads a trained LAED model, runs train/valid/test through it, re-nests
    the flat per-utterance features back onto dialogs, and writes one pickle
    per split into ``config.out_folder``. Optionally also encodes seed
    responses when ``config.process_seed_data`` is set.
    """
    laed_config = load_config(config.model)
    laed_config.use_gpu = config.use_gpu
    laed_config = process_config(laed_config)

    # Mirror the runtime options the LAED corpora/loaders expect.
    setattr(laed_config, 'black_domains', config.black_domains)
    setattr(laed_config, 'black_ratio', config.black_ratio)
    setattr(laed_config, 'include_domain', True)
    setattr(laed_config, 'include_example', False)
    setattr(laed_config, 'include_state', True)
    setattr(laed_config, 'entities_file', 'NeuralDialog-ZSDG/data/stanford/kvret_entities.json')
    setattr(laed_config, 'action_match', True)
    setattr(laed_config, 'batch_size', config.batch_size)
    setattr(laed_config, 'data_dir', config.data_dir)
    setattr(laed_config, 'include_eod', False) # for StED model
    setattr(laed_config, 'domain_description', config.domain_description)

    if config.process_seed_data:
        # BUGFIX: error message typo ("coprus_client") corrected.
        assert config.corpus_client[:3] == 'Zsl', 'Incompatible corpus_client for --process_seed_data flag'
    corpus_client = getattr(corpora, config.corpus_client)(laed_config)
    if config.vocab:
        corpus_client.vocab, corpus_client.rev_vocab, corpus_client.unk_token = load_vocab(config.vocab)
    prepare_dirs_loggers(config, os.path.basename(__file__))

    dial_corpus = corpus_client.get_corpus()

    model = load_model(config.model, config.model_name, config.model_type, corpus_client=corpus_client)

    if config.use_gpu:
        model.cuda()

    for dataset_name in ['train', 'valid', 'test']:
        dataset = dial_corpus[dataset_name]
        # Dialog models consume whole dialogs; utterance models get a flat list.
        feed_data = dataset if config.model_type == 'dialog' else reduce(lambda x, y: x + y, dataset, [])

        # create data loader that feed the deep models
        if config.process_seed_data:
            # NOTE(review): this result is unused inside the loop (recomputed
            # after it); kept in case get_seed_responses has side effects.
            seed_utts = corpus_client.get_seed_responses(utt_cnt=len(corpus_client.domain_descriptions))
        main_feed = getattr(data_loaders, config.data_loader)("Test", feed_data, laed_config)

        features = process_data_feed(model, main_feed, laed_config)
        # Padding mode depends on where the loader adds BOS/EOS frames.
        if config.data_loader == 'SMDDialogSkipLoader':
            pad_mode = 'start_end'
        elif config.data_loader == 'SMDDataLoader':
            pad_mode = 'start'
        else:
            pad_mode = None
        features = deflatten_laed_features(features, dataset, pad_mode=pad_mode)
        # Sanity check: one feature row per utterance.
        assert sum(map(len, dataset)) == sum(map(lambda x: x.shape[0], features))

        if not os.path.exists(config.out_folder):
            os.makedirs(config.out_folder)
        # BUGFIX: pickle writes bytes, so the file must be opened in binary
        # mode ('wb'); text mode ('w') raises TypeError on Python 3.
        with open(os.path.join(config.out_folder, 'dialogs_{}.pkl'.format(dataset_name)), 'wb') as result_out:
            pickle.dump(features, result_out)

    if config.process_seed_data:
        seed_utts = corpus_client.get_seed_responses(utt_cnt=len(corpus_client.domain_descriptions))
        seed_feed = data_loaders.PTBDataLoader("Seed", seed_utts, laed_config)
        seed_features = process_data_feed(model, seed_feed, laed_config)
        with open(os.path.join(config.out_folder, 'seed_utts.pkl'), 'wb') as result_out:
            pickle.dump(seed_features, result_out)
def main(config):
    """Train AeED on NormMultiWoz, then evaluate and dump all artifacts.

    After training (skipped in forward-only mode) the saved checkpoint is
    reloaded, BLEU is reported, latent clusters are computed for all three
    splits, and clusters/ids/latents/generations are written to files.
    """
    prepare_dirs_loggers(config, os.path.basename(__file__))

    corpus_client = corpora.NormMultiWozCorpus(config)

    dial_corpus = corpus_client.get_corpus()
    train_dial, valid_dial, test_dial = dial_corpus

    evaluator = MultiWozEvaluator('SysWoz')
    # Create data loaders that feed the deep models.
    train_feed = data_loaders.BeliefDbDataLoaders("Train", train_dial, config)
    valid_feed = data_loaders.BeliefDbDataLoaders("Valid", valid_dial, config)
    test_feed = data_loaders.BeliefDbDataLoaders("Test", test_dial, config)
    model = dialog_models.AeED(corpus_client, config)

    # Resolve output paths: reuse an old session dir in forward-only mode.
    if config.forward_only:
        test_file = os.path.join(
            config.log_dir, config.load_sess,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.log_dir, config.load_sess, "model")
    else:
        test_file = os.path.join(
            config.session_dir,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.session_dir,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.session_dir, "model")

    if config.use_gpu:
        model.cuda()

    if config.forward_only is False:
        try:
            engine.train(model,
                         train_feed,
                         valid_feed,
                         test_feed,
                         config,
                         evaluator,
                         gen=dialog_utils.generate_with_adv)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")

    # Evaluate the saved checkpoint with a small fixed batch size.
    config.batch_size = 10
    model.load_state_dict(torch.load(model_file))
    logger.info("Test Bleu:")
    dialog_utils.generate(model, test_feed, config, evaluator,
                          test_feed.num_batch, False)
    engine.validate(model, valid_feed, config)
    engine.validate(model, test_feed, config)

    dialog_utils.generate_with_adv(model,
                                   test_feed,
                                   config,
                                   None,
                                   num_batch=None)
    # Cluster latent codes on train, then reuse the cluster-name mapping and
    # running action count for test and valid. (The original pre-initialized
    # cluster_name_id/action_count locals that were never read; removed.)
    selected_clusters, index_cluster_id_train, cluster_name_id, action_count = utt_utils.latent_cluster(
        model, train_feed, config, None, 0, num_batch=None)
    _, index_cluster_id_test, cluster_name_id, action_count = utt_utils.latent_cluster(
        model,
        test_feed,
        config,
        cluster_name_id,
        action_count,
        num_batch=None)
    _, index_cluster_id_valid, cluster_name_id, action_count = utt_utils.latent_cluster(
        model,
        valid_feed,
        config,
        cluster_name_id,
        action_count,
        num_batch=None)
    selected_outs = dialog_utils.selective_generate(model, test_feed, config,
                                                    selected_clusters)
    print(len(selected_outs))

    # BUGFIX: json.dump writes str, so under Python 3 these files must be
    # opened in text mode ('w'), not binary ('wb'). Also dropped the no-op
    # single-argument os.path.join(...) wrappers.
    with open(dump_file + '.json', 'w') as f:
        json.dump(selected_clusters, f, indent=2)

    with open(dump_file + '.cluster_id.json.Train', 'w') as f:
        json.dump(index_cluster_id_train, f, indent=2)
    with open(dump_file + '.cluster_id.json.Test', 'w') as f:
        json.dump(index_cluster_id_test, f, indent=2)
    with open(dump_file + '.cluster_id.json.Valid', 'w') as f:
        json.dump(index_cluster_id_valid, f, indent=2)

    with open(dump_file + '.out.json', 'w') as f:
        json.dump(selected_outs, f, indent=2)

    with open(dump_file, "wb") as f:
        print("Dumping test to {}".format(dump_file))
        dialog_utils.dump_latent(model, test_feed, config, f, num_batch=None)

    # NOTE(review): the two dumps below keep binary mode from the original;
    # if gen_with_cond/generate write str they would need 'w' too — confirm.
    with open(test_file, "wb") as f:
        print("Saving test to {}".format(test_file))
        dialog_utils.gen_with_cond(model,
                                   test_feed,
                                   config,
                                   num_batch=None,
                                   dest_f=f)

    with open(test_file + '.txt', "wb") as f:
        print("Saving test to {}".format(test_file))
        dialog_utils.generate(model,
                              test_feed,
                              config,
                              evaluator,
                              num_batch=None,
                              dest_f=f)
    print("All done. Have a nice day")
Beispiel #8
0
def main(config):
    """Adversarially train a SAT-input GAN WoZ agent (reward model).

    Seeds all RNGs from ``config.seed``, builds WoZ GAN data feeds,
    optionally warm-starts an RNN context encoder, and runs
    ``engine.gan_train`` unless ``config.forward_only`` is set.
    """
    prepare_dirs_loggers(config, os.path.basename(__file__))
    # Seed python/torch/numpy RNGs so runs are reproducible.
    manualSeed=config.seed
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    np.random.seed(manualSeed)
    # Noise-sample shape consumed by the disc/gen validation helpers.
    sample_shape = config.batch_size, config.state_noise_dim, config.action_noise_dim

    evaluator = evaluators.BleuEvaluator(os.path.basename(__file__))

    train_feed = WoZGanDataLoaders("train", config)
    valid_feed = WoZGanDataLoaders("val", config)
    test_feed = WoZGanDataLoaders("test", config)


    # action2name = load_action2name(config)
    action2name = None
    corpus_client = None
    # Only the 'gan' + 'sat' combination is supported by this script.
    if config.gan_type=='gan' and config.input_type=='sat':
        model = GanAgent_SAT_WoZ(corpus_client, config, action2name)
    else:
        raise ValueError("No such GAN types: {}".format(config.gan_type))
    logger.info(summary(model, show_weights=False))
    model.discriminator.apply(weights_init)
    model.generator.apply(weights_init)

    # Warm-start the context encoder when states come from an RNN encoder.
    if config.state_type=='rnn':
        load_context_encoder(model, os.path.join(config.log_dir, config.encoder_sess, "model_lirl"))

    # Resolve output paths: reuse an old session dir in forward-only mode.
    if config.forward_only:
        test_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.log_dir, config.load_sess, "model")
    else:
        test_file = os.path.join(config.session_dir,
                                 "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.session_dir, "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.session_dir, "model")
        vocab_file = os.path.join(config.session_dir, "vocab.json")



    if config.use_gpu:
        model.cuda()

    pred_list = []
    generator_samples = []
    print("Evaluate initial model on Validate set")
    model.eval()
    # policy_validate_for_human(model,valid_feed, config, sample_shape)
    disc_validate(model, valid_feed, config, sample_shape)
    _, sample_batch = gen_validate(model,valid_feed, config, sample_shape, -1)
    generator_samples.append([-1, sample_batch])
    # Real ("human") and generated ("machine") batches for GAN training.
    machine_data, human_data = build_fake_data(model, valid_feed, config, sample_shape)


    model.train()
    print("Start training")

    if config.forward_only is False:
        try:
            engine.gan_train(model, machine_data, train_feed, valid_feed, test_feed, config, evaluator, pred_list, generator_samples)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")


    # save_data_for_tsne(human_data, machine_data, generator_samples, pred_list, config)
    print("Training Done ! ")
    # The triple-quoted block below disables the final evaluation pass.
    '''
    model.load_state_dict(torch.load(model_file))
    print("Evaluate final model on Validate set")
    model.eval()
    policy_validate_for_human(model,valid_feed, config, sample_shape)
    disc_validate(model, valid_feed, config, sample_shape)
    gen_validate(model,valid_feed, config, sample_shape)

    print("Evaluate final model on Test set")
    policy_validate_for_human(model,test_feed, config, sample_shape)
    disc_validate(model, test_feed, config, sample_shape)
    gen_validate(model,test_feed, config, sample_shape)
    '''

    # dialog_utils.generate_with_adv(model, test_feed, config, None, num_batch=None)
    # selected_clusters, index_cluster_id_train = utt_utils.latent_cluster(model, train_feed, config, num_batch=None)
    # _, index_cluster_id_test = utt_utils.latent_cluster(model, test_feed, config, num_batch=None)
    # _, index_cluster_id_valid = utt_utils.latent_cluster(model, valid_feed, config, num_batch=None)
    # selected_outs = dialog_utils.selective_generate(model, test_feed, config, selected_clusters)
    # print(len(selected_outs))
    # NOTE(review): the trailing ''' below is unbalanced in this scrape —
    # it likely paired with another stray quote in the original source.
    '''
def main(config):
    """Train the LIRL model on the Movie corpus and evaluate the best epoch.

    In forward-only mode the config is reloaded from the pretrain folder and
    training is skipped; otherwise a fresh timestamped session directory is
    created. After training, the best (or newest) checkpoint is reloaded and
    generation files are written for valid and test.
    """
    set_seed(config.seed)
    start_time = time.strftime('%Y-%m-%d-%H-%M-%S',
                               time.localtime(time.time()))
    stats_path = 'sys_config_log_model'
    if config.forward_only:
        # Resume: reload the saved config of the pretrained session.
        saved_path = os.path.join(stats_path, config.pretrain_folder)
        config = Pack(json.load(open(os.path.join(saved_path, 'config.json'))))
        config['forward_only'] = True
    else:
        saved_path = os.path.join(
            stats_path,
            start_time + '-' + os.path.basename(__file__).split('.')[0])
        if not os.path.exists(saved_path):
            os.makedirs(saved_path)
    config.saved_path = saved_path

    prepare_dirs_loggers(config)
    logger = logging.getLogger()
    logger.info('[START]\n{}\n{}'.format(start_time, '=' * 30))

    corpus = corpora.MovieCorpus(config)
    train_dial, valid_dial, test_dial = corpus.get_corpus()
    evaluator = BleuEvaluator(os.path.basename(__file__))
    # Create data loaders that feed the deep models.
    train_data = MovieDataLoaders("Train", train_dial, config)
    valid_data = MovieDataLoaders("Valid", valid_dial, config)
    test_data = MovieDataLoaders("Test", test_dial, config)

    model = LIRL(corpus, config)
    if config.use_gpu:
        model.cuda()

    best_epoch = None
    if not config.forward_only:
        try:
            best_epoch = train(model,
                               train_data,
                               valid_data,
                               test_data,
                               config,
                               evaluator,
                               gen=task_generate)
        except KeyboardInterrupt:
            print('Training stopped by keyboard.')
    if best_epoch is None:
        # Fall back to the newest saved "<epoch>-model" checkpoint on disk.
        model_ids = sorted([
            int(p.replace('-model', '')) for p in os.listdir(saved_path)
            if 'model' in p and 'rl' not in p
        ])
        best_epoch = model_ids[-1]

    # (Removed a useless `best_epoch = best_epoch` self-assignment and the
    # unused `sample_shape` local from the original.)
    print("$$$ Load {}-model".format(best_epoch))
    config.batch_size = 32
    model.load_state_dict(
        torch.load(os.path.join(saved_path, '{}-model'.format(best_epoch))))

    logger.info("Forward Only Evaluation")

    validate(model, valid_data, config)
    validate(model, test_data, config)

    with open(
            os.path.join(saved_path,
                         '{}_{}_valid_file.txt'.format(start_time,
                                                       best_epoch)), 'w') as f:
        task_generate(model,
                      valid_data,
                      config,
                      evaluator,
                      num_batch=None,
                      dest_f=f)

    with open(
            os.path.join(saved_path,
                         '{}_{}_test_file.txt'.format(start_time, best_epoch)),
            'w') as f:
        task_generate(model,
                      test_data,
                      config,
                      evaluator,
                      num_batch=None,
                      dest_f=f)

    end_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    print('[END]', end_time, '=' * 30)
def main(config):
    """Two-stage training: state-VAE autoencoder, then GAN reward model.

    Uses state/action-embedding WoZ GAN data feeds and the
    ``GanAgent_StateVaeActionSeg`` agent. Both stages are skipped when
    ``config.forward_only`` is set.
    """
    prepare_dirs_loggers(config, os.path.basename(__file__))
    # Seed python/torch/numpy RNGs so runs are reproducible.
    manualSeed = config.seed
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    np.random.seed(manualSeed)
    # Noise-sample shape consumed by the disc/gen validation helpers.
    sample_shape = config.batch_size, config.state_noise_dim, config.action_noise_dim

    evaluator = evaluators.BleuEvaluator(os.path.basename(__file__))

    train_feed = WoZGanDataLoaders_StateActionEmbed("train", config)
    valid_feed = WoZGanDataLoaders_StateActionEmbed("val", config)
    test_feed = WoZGanDataLoaders_StateActionEmbed("test", config)

    # action2name = load_action2name(config)
    action2name = None
    corpus_client = None

    # model = GanAgent_VAE_StateActioneEmbed(corpus_client, config, action2name)
    model = GanAgent_StateVaeActionSeg(corpus_client, config, action2name)

    logger.info(summary(model, show_weights=False))
    model.discriminator.apply(weights_init)
    model.generator.apply(weights_init)
    model.vae.apply(weights_init)

    # Resolve output paths: reuse an old session dir in forward-only mode.
    if config.forward_only:
        test_file = os.path.join(
            config.log_dir, config.load_sess,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.log_dir, config.load_sess, "model")
    else:
        test_file = os.path.join(
            config.session_dir,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.session_dir,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.session_dir, "model")
        vocab_file = os.path.join(config.session_dir, "vocab.json")

    if config.use_gpu:
        model.cuda()

    pred_list = []
    generator_samples = []
    print("Evaluate initial model on Validate set")
    model.eval()
    # policy_validate_for_human(model,valid_feed, config, sample_shape)
    disc_validate(model, valid_feed, config, sample_shape)
    _, sample_batch = gen_validate(model, valid_feed, config, sample_shape, -1)
    generator_samples.append([-1, sample_batch])
    # Real ("human") and generated ("machine") batches for GAN training.
    machine_data, human_data = build_fake_data(model, valid_feed, config,
                                               sample_shape)

    model.train()
    print("Start training")

    # Stage 1: train the VAE (autoencoder), then reload its weights.
    if config.forward_only is False:
        try:
            engine.vae_train(model, train_feed, valid_feed, test_feed, config)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")
    print("AutoEncoder Training Done ! ")
    load_model_vae(model, config)

    # path='logs/2019-09-18T12:20:26.063708-mwoz_gan_vae_StateActionEmbed.py'  # this is embed version
    # path='logs/2019-09-18T12:24:45.517636-mwoz_gan_vae_StateActionEmbed.py'  # this is embed_merged version
    # path='logs/2019-09-18T17:21:35.420069-mwoz_gan_vae_StateActionEmbed.py'  # this is state_vae action_seg version, without hotel domain
    # load_model_vae(model, path)

    # Stage 2: adversarial training of the reward (discriminator) model.
    if config.forward_only is False:
        try:
            engine.gan_train(model, machine_data, train_feed, valid_feed,
                             test_feed, config, evaluator, pred_list,
                             generator_samples)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")
    print("Reward Model Training Done ! ")
    print("Saved path: {}".format(model_file))