Example #1
def main(config):
    prepare_dirs_loggers(config, os.path.basename(__file__))

    corpus_client = getattr(corpora, config.corpus_client)(config)
    corpus_client.vocab, corpus_client.rev_vocab, corpus_client.unk_id = load_vocab(
        config.vocab)

    # warmup_data = maluuba_client.get_seed_responses(len(maluuba_client.domain_descriptions))
    # maluuba_corpus = maluuba_client.get_corpus()
    # train_dial, valid_dial = maluuba_corpus['train'], maluuba_corpus['valid']
    corpus = corpus_client.get_corpus()
    train_dial, valid_dial, test_dial = (corpus['train'], corpus['valid'],
                                         corpus['test'])

    evaluator = evaluators.BleuEntEvaluator("SMD", corpus_client.ent_metas)

    laed_z = load_laed_features(config.laed_z_folder)
    config.laed_z_size = laed_z['dialog'][0].shape[-1]

    # laed_z['dialog'] is assumed to hold one feature entry per dialog in
    # train/valid/test order, so everything past train+valid is the test slice.
    laed_z_test = laed_z['dialog'][len(train_dial) + len(valid_dial):]
    test_feed = data_loaders.ZslLASMDDialDataLoader("Test", test_dial,
                                                    laed_z_test, [], config)
    if config.action_match:
        if config.use_ptr:
            model = ZeroShotLAPtrHRED(corpus_client, config)
        else:
            raise NotImplementedError()
    else:
        raise NotImplementedError()

    session_dir = os.path.join(config.log_dir, config.load_sess)
    test_file = os.path.join(
        session_dir, "{}-test-{}.txt".format(get_time(), config.gen_type))
    model_file = os.path.join(session_dir, "model")

    if config.use_gpu:
        model.cuda()
    config.batch_size = 20
    model.load_state_dict(torch.load(model_file))

    # run the model on the test dataset.
    validate(model, test_feed, config)

    with open(test_file, "w") as f:
        hred_utils.generate(model,
                            test_feed,
                            config,
                            evaluator,
                            num_batch=None,
                            dest_f=f)
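
The only non-obvious bookkeeping in Example #1 is how the precomputed LAED features are split: laed_z['dialog'] is sliced with len(train_dial) + len(valid_dial): to recover the test portion, which only works if the features were dumped in train/valid/test order. Below is a minimal sketch of that slicing with toy data; the variable names mirror the example, but the sizes and the np.zeros features are made up.

import numpy as np

# Toy splits standing in for train_dial / valid_dial / test_dial.
train_dial = [None] * 6
valid_dial = [None] * 2
test_dial = [None] * 3

# One feature vector per dialog, concatenated in train/valid/test order
# (the layout the slice in the example assumes for laed_z['dialog']).
laed_z_dialog = [np.zeros(16) for _ in range(len(train_dial) + len(valid_dial) + len(test_dial))]
laed_z_size = laed_z_dialog[0].shape[-1]  # mirrors config.laed_z_size

# Same slice as in the example: everything past train+valid is the test portion.
laed_z_test = laed_z_dialog[len(train_dial) + len(valid_dial):]
assert len(laed_z_test) == len(test_dial)
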
Example #2
def main(config):
    corpus_client = getattr(corpora, config.corpus_client)(config)
    corpus_client.vocab, corpus_client.rev_vocab, corpus_client.unk_id = load_vocab(
        config.vocab)
    prepare_dirs_loggers(config, os.path.basename(__file__))

    dial_corpus = corpus_client.get_corpus()
    train_dial, valid_dial, test_dial = (dial_corpus['train'],
                                         dial_corpus['valid'],
                                         dial_corpus['test'])

    evaluator = evaluators.BleuEvaluator("CornellMovie")

    # create data loader that feed the deep models
    train_feed = data_loaders.SMDDialogSkipLoader("Train", train_dial, config)
    valid_feed = data_loaders.SMDDialogSkipLoader("Valid", valid_dial, config)
    test_feed = data_loaders.SMDDialogSkipLoader("Test", test_dial, config)
    model = dialog_models.VAE(corpus_client, config)

    if config.forward_only:
        test_file = os.path.join(
            config.log_dir, config.load_sess,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.log_dir, config.load_sess, "model")
    else:
        test_file = os.path.join(
            config.session_dir,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.session_dir,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.session_dir, "model")

    if config.use_gpu:
        model.cuda()

    if not config.forward_only:
        try:
            engine.train(model,
                         train_feed,
                         valid_feed,
                         test_feed,
                         config,
                         evaluator,
                         gen=dialog_utils.generate_vae)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")
Example #3
def main(config):
    corpus_client = getattr(corpora, config.corpus_client)(config)
    corpus_client.vocab, corpus_client.rev_vocab, corpus_client.unk_id = load_vocab(
        config.vocab)
    prepare_dirs_loggers(config, os.path.basename(__file__))

    dial_corpus = corpus_client.get_corpus()
    train_dial, valid_dial, test_dial = (dial_corpus['train'],
                                         dial_corpus['valid'],
                                         dial_corpus['test'])

    evaluator = evaluators.BleuEvaluator("CornellMovie")

    # create data loader that feed the deep models
    train_feed = data_loaders.SMDDialogSkipLoader("Train", train_dial, config)
    valid_feed = data_loaders.SMDDialogSkipLoader("Valid", valid_dial, config)
    test_feed = data_loaders.SMDDialogSkipLoader("Test", test_dial, config)
    model = dialog_models.StED(corpus_client, config)

    if config.forward_only:
        test_file = os.path.join(
            config.log_dir, config.load_sess,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.log_dir, config.load_sess,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.log_dir, config.load_sess, "model")
    else:
        test_file = os.path.join(
            config.session_dir,
            "{}-test-{}.txt".format(get_time(), config.gen_type))
        dump_file = os.path.join(config.session_dir,
                                 "{}-z.pkl".format(get_time()))
        model_file = os.path.join(config.session_dir, "model")

    if config.use_gpu:
        model.cuda()

    if not config.forward_only:
        try:
            engine.train(model,
                         train_feed,
                         valid_feed,
                         test_feed,
                         config,
                         evaluator,
                         gen=dialog_utils.generate_with_adv)
        except KeyboardInterrupt:
            print("Training stopped by keyboard.")

    config.batch_size = 10
    model.load_state_dict(torch.load(model_file))
    engine.validate(model, valid_feed, config)
    engine.validate(model, test_feed, config)

    dialog_utils.generate_with_adv(model, test_feed, config, None, num_batch=0)
    selected_clusters = utt_utils.latent_cluster(model,
                                                 train_feed,
                                                 config,
                                                 num_batch=None)
    selected_outs = dialog_utils.selective_generate(model, test_feed, config,
                                                    selected_clusters)
    print(len(selected_outs))

    with open(dump_file + '.json', 'w') as f:
        json.dump(selected_clusters, f, indent=2)

    with open(dump_file + '.out.json', 'w') as f:
        json.dump(selected_outs, f, indent=2)

    with open(dump_file, "wb") as f:
        print("Dumping test to {}".format(dump_file))
        dialog_utils.dump_latent(model, test_feed, config, f, num_batch=None)

    with open(test_file, "w") as f:
        print("Saving test to {}".format(test_file))
        dialog_utils.gen_with_cond(model,
                                   test_feed,
                                   config,
                                   num_batch=None,
                                   dest_f=f)

    with open(test_file + '.txt', "w") as f:
        print("Saving test to {}".format(test_file + '.txt'))
        dialog_utils.generate(model,
                              test_feed,
                              config,
                              evaluator,
                              num_batch=None,
                              dest_f=f)
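
Example #3 reloads the trained weights with model.load_state_dict(torch.load(model_file)) before running validation and generation. The following is a generic PyTorch sketch of that restore-then-evaluate step; the tiny nn.Linear model and the 'model' file name are placeholders for dialog_models.StED and the session's model file.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                       # stand-in for the trained dialog model
torch.save(model.state_dict(), "model")       # what training is assumed to leave behind

restored = nn.Linear(4, 2)                    # must be built with the same architecture/config
restored.load_state_dict(torch.load("model"))
restored.eval()                               # disable dropout etc. before validate/generate

with torch.no_grad():
    print(restored(torch.zeros(1, 4)).shape)  # torch.Size([1, 2])
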
Example #4
def main(config):
    laed_config = load_config(config.model)
    laed_config.use_gpu = config.use_gpu
    laed_config = process_config(laed_config)

    setattr(laed_config, 'black_domains', config.black_domains)
    setattr(laed_config, 'black_ratio', config.black_ratio)
    setattr(laed_config, 'include_domain', True)
    setattr(laed_config, 'include_example', False)
    setattr(laed_config, 'include_state', True)
    setattr(laed_config, 'entities_file', 'NeuralDialog-ZSDG/data/stanford/kvret_entities.json')
    setattr(laed_config, 'action_match', True)
    setattr(laed_config, 'batch_size', config.batch_size)
    setattr(laed_config, 'data_dir', config.data_dir)
    setattr(laed_config, 'include_eod', False) # for StED model
    setattr(laed_config, 'domain_description', config.domain_description)

    if config.process_seed_data:
        assert config.corpus_client[:3] == 'Zsl', 'Incompatible corpus_client for --process_seed_data flag'
    corpus_client = getattr(corpora, config.corpus_client)(laed_config)
    if config.vocab:
        corpus_client.vocab, corpus_client.rev_vocab, corpus_client.unk_token = load_vocab(config.vocab)
    prepare_dirs_loggers(config, os.path.basename(__file__))

    dial_corpus = corpus_client.get_corpus()
    # train_dial, valid_dial, test_dial = dial_corpus['train'], dial_corpus['valid'], dial_corpus['test']
    # all_dial = train_dial + valid_dial + test_dial
    # all_utts = reduce(lambda x, y: x + y, all_dial, [])

    model = load_model(config.model, config.model_name, config.model_type, corpus_client=corpus_client)

    if config.use_gpu:
        model.cuda()

    for dataset_name in ['train', 'valid', 'test']:
        dataset = dial_corpus[dataset_name]
        feed_data = dataset if config.model_type == 'dialog' else reduce(lambda x, y: x + y, dataset, [])

        # create data loader that feed the deep models
        if config.process_seed_data:
            seed_utts = corpus_client.get_seed_responses(utt_cnt=len(corpus_client.domain_descriptions))
        main_feed = getattr(data_loaders, config.data_loader)("Test", feed_data, laed_config)

        features = process_data_feed(model, main_feed, laed_config)
        if config.data_loader == 'SMDDialogSkipLoader':
            pad_mode = 'start_end'
        elif config.data_loader == 'SMDDataLoader':
            pad_mode = 'start'
        else:
            pad_mode = None
        features = deflatten_laed_features(features, dataset, pad_mode=pad_mode)
        assert sum(map(len, dataset)) == sum(map(lambda x: x.shape[0], features))

        if not os.path.exists(config.out_folder):
            os.makedirs(config.out_folder)
        with open(os.path.join(config.out_folder, 'dialogs_{}.pkl'.format(dataset_name)), 'wb') as result_out:
            pickle.dump(features, result_out)

    if config.process_seed_data:
        seed_utts = corpus_client.get_seed_responses(utt_cnt=len(corpus_client.domain_descriptions))
        seed_feed = data_loaders.PTBDataLoader("Seed", seed_utts, laed_config)
        seed_features = process_data_feed(model, seed_feed, laed_config)
        with open(os.path.join(config.out_folder, 'seed_utts.pkl'), 'wb') as result_out:
            pickle.dump(seed_features, result_out)
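
Example #4 flattens each dialog-level split into utterances with reduce when model_type is not 'dialog', and pickles one feature file per split. Below is a self-contained sketch of those two steps with toy data; the folder and file names are placeholders, reduce comes from functools on Python 3, and the pickle files are opened in binary mode.

import os
import pickle
from functools import reduce

# Toy corpus: each split is a list of dialogs, each dialog a list of utterances.
dial_corpus = {
    'train': [['hi', 'hello'], ['bye']],
    'valid': [['how are you', 'fine']],
    'test': [['thanks', 'welcome']],
}

out_folder = 'laed_features'
os.makedirs(out_folder, exist_ok=True)

for dataset_name in ['train', 'valid', 'test']:
    dataset = dial_corpus[dataset_name]
    # Utterance-level feed, as in the non-'dialog' branch of the loop above.
    flat = reduce(lambda x, y: x + y, dataset, [])
    with open(os.path.join(out_folder, 'dialogs_{}.pkl'.format(dataset_name)), 'wb') as result_out:
        pickle.dump(flat, result_out)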