def main(config):
    """Train (unless ``config.forward_only``) and evaluate a ``StED`` dialog model.

    Steps performed:
      1. Instantiate the corpus client class named by ``config.corpus_client``
         and attach the vocabulary loaded from ``config.vocab``.
      2. Wrap the train/valid/test splits in ``SMDDialogSkipLoader`` feeds.
      3. Train with ``engine.train`` unless ``config.forward_only`` is set
         (Ctrl-C stops training early; evaluation still runs afterwards).
      4. Reload the weights from ``<out_dir>/model``, validate on the valid
         and test feeds, then write cluster / generation artefacts to
         ``<out_dir>``.

    ``<out_dir>`` is ``config.log_dir/config.load_sess`` in forward-only mode
    and ``config.session_dir`` otherwise.

    Args:
        config: experiment configuration object. Attributes read here include
            ``corpus_client``, ``vocab``, ``forward_only``, ``log_dir``,
            ``load_sess``, ``session_dir``, ``gen_type`` and ``use_gpu``.
            NOTE: ``config.batch_size`` is overwritten with 10 before the
            evaluation phase.
    """
    corpus_client = getattr(corpora, config.corpus_client)(config)
    corpus_client.vocab, corpus_client.rev_vocab, corpus_client.unk_id = load_vocab(
        config.vocab)
    prepare_dirs_loggers(config, os.path.basename(__file__))

    dial_corpus = corpus_client.get_corpus()
    train_dial = dial_corpus['train']
    valid_dial = dial_corpus['valid']
    test_dial = dial_corpus['test']

    evaluator = evaluators.BleuEvaluator("CornellMovie")

    # create data loaders that feed the deep models
    train_feed = data_loaders.SMDDialogSkipLoader("Train", train_dial, config)
    valid_feed = data_loaders.SMDDialogSkipLoader("Valid", valid_dial, config)
    test_feed = data_loaders.SMDDialogSkipLoader("Test", test_dial, config)
    model = dialog_models.StED(corpus_client, config)

    # Forward-only runs read from / write to a previously saved session;
    # training runs use the current session directory.
    if config.forward_only:
        out_dir = os.path.join(config.log_dir, config.load_sess)
    else:
        out_dir = config.session_dir
    # Take the timestamp once so the test and dump files of a single run carry
    # the same prefix (get_time() is presumably wall-clock based — confirm).
    timestamp = get_time()
    test_file = os.path.join(out_dir,
                             "{}-test-{}.txt".format(timestamp, config.gen_type))
    dump_file = os.path.join(out_dir, "{}-z.pkl".format(timestamp))
    model_file = os.path.join(out_dir, "model")

    if config.use_gpu:
        model.cuda()

    if not config.forward_only:
        try:
            engine.train(model,
                         train_feed,
                         valid_feed,
                         test_feed,
                         config,
                         evaluator,
                         gen=dialog_utils.generate_with_adv)
        except KeyboardInterrupt:
            # Ctrl-C ends training early and falls through to evaluation.
            print("Training stopped by keyboard.")

    # Evaluation uses a fixed batch size; this mutates the caller's config.
    config.batch_size = 10
    # NOTE(review): assumes engine.train saved weights to model_file — confirm.
    model.load_state_dict(torch.load(model_file))
    engine.validate(model, valid_feed, config)
    engine.validate(model, test_feed, config)

    dialog_utils.generate_with_adv(model, test_feed, config, None, num_batch=0)
    selected_clusters = utt_utils.latent_cluster(model,
                                                 train_feed,
                                                 config,
                                                 num_batch=None)
    selected_outs = dialog_utils.selective_generate(model, test_feed, config,
                                                    selected_clusters)
    print(len(selected_outs))

    # json.dump writes str, so these files must be opened in text mode
    # (binary mode raises TypeError on Python 3).
    with open(dump_file + '.json', 'w') as f:
        json.dump(selected_clusters, f, indent=2)

    with open(dump_file + '.out.json', 'w') as f:
        json.dump(selected_outs, f, indent=2)

    # NOTE(review): the writers below receive binary handles. dump_latent
    # presumably pickles (the target ends in .pkl); confirm gen_with_cond and
    # generate also write bytes to dest_f.
    with open(dump_file, "wb") as f:
        print("Dumping test to {}".format(dump_file))
        dialog_utils.dump_latent(model, test_feed, config, f, num_batch=None)

    with open(test_file, "wb") as f:
        print("Saving test to {}".format(test_file))
        dialog_utils.gen_with_cond(model,
                                   test_feed,
                                   config,
                                   num_batch=None,
                                   dest_f=f)

    plain_gen_file = test_file + '.txt'
    with open(plain_gen_file, "wb") as f:
        # Report the path actually written (it has an extra .txt suffix).
        print("Saving test to {}".format(plain_gen_file))
        dialog_utils.generate(model,
                              test_feed,
                              config,
                              evaluator,
                              num_batch=None,
                              dest_f=f)
def main(config):
    """Train (unless ``config.forward_only``) and evaluate an ``AeED`` model on MultiWOZ.

    Steps performed:
      1. Set up directories/loggers, load the ``NormMultiWozCorpus`` splits
         and wrap them in ``BeliefDbDataLoaders`` feeds.
      2. Train with ``engine.train`` unless ``config.forward_only`` is set
         (Ctrl-C stops training early; evaluation still runs afterwards).
      3. Reload the weights from ``<out_dir>/model``, run
         ``dialog_utils.generate`` on the test feed with
         ``MultiWozEvaluator('SysWoz')``, and validate on the valid and test
         feeds.
      4. Run ``utt_utils.latent_cluster`` on train, test and valid (in that
         order) and ``dialog_utils.selective_generate`` on test. Then write the
         selected clusters, per-split cluster ids, latent dump and generations
         to ``<out_dir>``.

    ``<out_dir>`` is ``config.log_dir/config.load_sess`` in forward-only mode
    and ``config.session_dir`` otherwise.

    Args:
        config: experiment configuration object. Attributes read here include
            ``forward_only``, ``log_dir``, ``load_sess``, ``session_dir``,
            ``gen_type`` and ``use_gpu``. NOTE: ``config.batch_size`` is
            overwritten with 10 before the evaluation phase.
    """
    prepare_dirs_loggers(config, os.path.basename(__file__))

    corpus_client = corpora.NormMultiWozCorpus(config)
    train_dial, valid_dial, test_dial = corpus_client.get_corpus()

    evaluator = MultiWozEvaluator('SysWoz')
    # create data loaders that feed the deep models
    train_feed = data_loaders.BeliefDbDataLoaders("Train", train_dial, config)
    valid_feed = data_loaders.BeliefDbDataLoaders("Valid", valid_dial, config)
    test_feed = data_loaders.BeliefDbDataLoaders("Test", test_dial, config)
    model = dialog_models.AeED(corpus_client, config)

    # Forward-only runs read from / write to a previously saved session;
    # training runs use the current session directory.
    if config.forward_only:
        out_dir = os.path.join(config.log_dir, config.load_sess)
    else:
        out_dir = config.session_dir
    # Take the timestamp once so the test and dump files of a single run carry
    # the same prefix (get_time() is presumably wall-clock based — confirm).
    timestamp = get_time()
    test_file = os.path.join(out_dir,
                             "{}-test-{}.txt".format(timestamp, config.gen_type))
    dump_file = os.path.join(out_dir, "{}-z.pkl".format(timestamp))
    model_file = os.path.join(out_dir, "model")

    if config.use_gpu:
        model.cuda()

    if not config.forward_only:
        try:
            engine.train(model,
                         train_feed,
                         valid_feed,
                         test_feed,
                         config,
                         evaluator,
                         gen=dialog_utils.generate_with_adv)
        except KeyboardInterrupt:
            # Ctrl-C ends training early and falls through to evaluation.
            print("Training stopped by keyboard.")

    # Evaluation uses a fixed batch size; this mutates the caller's config.
    config.batch_size = 10
    # NOTE(review): assumes engine.train saved weights to model_file — confirm.
    model.load_state_dict(torch.load(model_file))
    logger.info("Test Bleu:")
    dialog_utils.generate(model, test_feed, config, evaluator,
                          test_feed.num_batch, False)
    engine.validate(model, valid_feed, config)
    engine.validate(model, test_feed, config)

    dialog_utils.generate_with_adv(model,
                                   test_feed,
                                   config,
                                   None,
                                   num_batch=None)

    # cluster_name_id / action_count are threaded through the successive
    # latent_cluster calls (train -> test -> valid); presumably this keeps the
    # cluster ids consistent across splits — confirm in utt_utils.
    cluster_name_id = None
    action_count = 0
    selected_clusters, index_cluster_id_train, cluster_name_id, action_count = utt_utils.latent_cluster(
        model, train_feed, config, cluster_name_id, action_count, num_batch=None)
    _, index_cluster_id_test, cluster_name_id, action_count = utt_utils.latent_cluster(
        model, test_feed, config, cluster_name_id, action_count, num_batch=None)
    _, index_cluster_id_valid, cluster_name_id, action_count = utt_utils.latent_cluster(
        model, valid_feed, config, cluster_name_id, action_count, num_batch=None)
    selected_outs = dialog_utils.selective_generate(model, test_feed, config,
                                                    selected_clusters)
    print(len(selected_outs))

    # json.dump writes str, so these files must be opened in text mode
    # (binary mode raises TypeError on Python 3). Files are written in this order.
    json_outputs = (
        ('.json', selected_clusters),
        ('.cluster_id.json.Train', index_cluster_id_train),
        ('.cluster_id.json.Test', index_cluster_id_test),
        ('.cluster_id.json.Valid', index_cluster_id_valid),
        ('.out.json', selected_outs),
    )
    for suffix, payload in json_outputs:
        with open(dump_file + suffix, 'w') as f:
            json.dump(payload, f, indent=2)

    # NOTE(review): the writers below receive binary handles. dump_latent
    # presumably pickles (the target ends in .pkl); confirm gen_with_cond and
    # generate also write bytes to dest_f.
    with open(dump_file, "wb") as f:
        print("Dumping test to {}".format(dump_file))
        dialog_utils.dump_latent(model, test_feed, config, f, num_batch=None)

    with open(test_file, "wb") as f:
        print("Saving test to {}".format(test_file))
        dialog_utils.gen_with_cond(model,
                                   test_feed,
                                   config,
                                   num_batch=None,
                                   dest_f=f)

    plain_gen_file = test_file + '.txt'
    with open(plain_gen_file, "wb") as f:
        # Report the path actually written (it has an extra .txt suffix).
        print("Saving test to {}".format(plain_gen_file))
        dialog_utils.generate(model,
                              test_feed,
                              config,
                              evaluator,
                              num_batch=None,
                              dest_f=f)
    print("All done. Have a nice day")