def train_lm(args):
    """Train a forward or backward language model on the selected dataset.

    The direction is controlled by ``args.reverse``; the checkpoint path is
    suffixed with the dataset name and the direction so the two models do
    not overwrite each other.
    """
    if args.data_name not in config.DATASET:
        raise ValueError("Invalid data name!")

    data = DataSet(config.DATASET[args.data_name])
    train_set = data.load_train()
    test_set = data.load_test()
    corpus = Corpus(train_set.file_list, test_set.file_list, args.reverse)

    # Suffix distinguishes the forward LM checkpoint from the backward one.
    direction = "backward" if args.reverse else "forward"
    save_path = args.save_path + '_' + args.data_name + '_' + direction

    hparams = {
        "hidden_size": args.hidden_size,
        "num_layers": args.num_layers,
        "cell_type": args.cell_type,
        "tie_embed": args.tie_embed,
        "rnn_dropout": args.rnn_dropout,
        "hidden_dropout": args.hidden_dropout,
        "num_epochs": args.num_epochs,
        "batch_size": args.batch_size,
        "bptt": args.bptt,
        "log_interval": args.log_interval,
        "save_path": save_path,
        "lr": args.lr,
        "wdecay": args.wdecay,
    }

    lm = LanguageModel(
        vocab_size=corpus.glove_embed.shape[0],
        embed_dim=corpus.glove_embed.shape[1],
        corpus=corpus,
        hparams=hparams,
    )
    best_valid_loss = lm.fit()
    print("Best Valid Loss:", best_valid_loss)
def run_lm_coherence(args):
    """Evaluate sentence-order coherence with pretrained fwd/bwd LMs.

    Loads the two language-model checkpoints named
    ``<lm_name>_forward.pt`` / ``<lm_name>_backward.pt`` (sharing the
    hyperparameters pickled alongside the forward model), scores the
    permuted test set, and logs discrimination and insertion accuracy.

    Returns:
        tuple: (discrimination accuracy, insertion accuracy).

    Raises:
        ValueError: if ``args.data_name`` is not a configured dataset.
    """
    logging.info("Loading data...")
    if args.data_name not in config.DATASET:
        raise ValueError("Invalid data name!")

    dataset = DataSet(config.DATASET[args.data_name])
    train_dataset = dataset.load_train()
    test_df = dataset.load_test_perm()
    test_dataset = dataset.load_test()
    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=1,
                                 shuffle=False)
    corpus = Corpus(train_dataset.file_list, test_dataset.file_list)

    # Both directions were trained with the same hyperparameters, which
    # are stored with the forward checkpoint.
    with open(
            os.path.join(config.CHECKPOINT_PATH,
                         args.lm_name + "_forward.pkl"), "rb") as f:
        hparams = pickle.load(f)

    kwargs = {
        "vocab_size": corpus.glove_embed.shape[0],
        "embed_dim": corpus.glove_embed.shape[1],
        "corpus": corpus,
        "hparams": hparams,
    }

    forward_model = LanguageModel(**kwargs)
    forward_model.load(
        os.path.join(config.CHECKPOINT_PATH, args.lm_name + "_forward.pt"))
    backward_model = LanguageModel(**kwargs)
    backward_model.load(
        os.path.join(config.CHECKPOINT_PATH, args.lm_name + "_backward.pt"))

    logging.info("Results for discrimination:")
    model = LMCoherence(forward_model.lm, backward_model.lm, corpus)
    dis_acc = model.evaluate_dis(test_dataloader, test_df)
    logging.info("Disc Accuracy: {}".format(dis_acc))

    logging.info("Results for insertion:")
    ins_acc = model.evaluate_ins(test_dataloader, test_df)
    # Fixed: this previously logged "Disc Accuracy" for the insertion score.
    logging.info("Insert Accuracy: {}".format(ins_acc))

    return dis_acc, ins_acc
def save_eval_perm(data_name, if_sample=False, random_seed=config.RANDOM_SEED):
    """Generate and persist permuted valid/test articles for coherence eval.

    For every article, negative (shuffled) versions are produced and stored
    in a ``neg_list`` column encoded as "<PUNC>"-joined sentences separated
    by "<BREAK>"; the positive article is re-encoded the same way.
    """
    random.seed(random_seed)

    logging.info("Loading valid and test data...")
    if data_name not in config.DATASET:
        raise ValueError("Invalid data name!")
    dataset = DataSet(config.DATASET[data_name])
    if if_sample:
        valid_dataset = dataset.load_valid_sample()
        test_dataset = dataset.load_test_sample()
    else:
        valid_dataset = dataset.load_valid()
        test_dataset = dataset.load_test()
    valid_df = valid_dataset.article_df
    test_df = test_dataset.article_df

    logging.info("Generating permuted articles...")

    def permute(sentences):
        # Build the negative samples for one article and flatten them into
        # the "<BREAK>" / "<PUNC>" string encoding described above.
        arr = np.array(sentences).squeeze()
        negatives = permute_articles_with_replacement([arr],
                                                      config.NEG_PERM)[0]
        return "<BREAK>".join("<PUNC>".join(neg) for neg in negatives)

    # Keep valid-before-test order so RNG consumption matches prior runs.
    for df in (valid_df, test_df):
        df["neg_list"] = df.sentences.map(permute)
        df["sentences"] = df.sentences.map("<PUNC>".join)

    def count_pairs(df):
        return df.neg_list.map(lambda s: len(s.split("<BREAK>"))).sum()

    logging.info("Number of validation pairs %d" % count_pairs(valid_df))
    logging.info("Number of test pairs %d" % count_pairs(test_df))

    logging.info("Saving...")
    dataset.save_valid_perm(valid_df)
    dataset.save_test_perm(test_df)
    logging.info("Finished!")
# Example #4
def run_bigram_coherence(args):
    """Train a BigramCoherence model and evaluate it on the permuted test set.

    Builds the sentence encoder requested by ``args.sent_encoder``, fits the
    coherence model on the training split, optionally saves the fitted model,
    then reports discrimination and insertion accuracy.

    Returns:
        tuple: (discrimination result, insertion result) as produced by the
        model's ``evaluate_dis`` / ``evaluate_ins`` methods.

    Raises:
        ValueError: for an unknown dataset or sentence-encoder name.
    """
    logging.info("Loading data...")
    if args.data_name not in config.DATASET:
        raise ValueError("Invalid data name!")
    dataset = DataSet(config.DATASET[args.data_name])
    # Create the permuted evaluation files once if they do not exist yet.
    if not os.path.isfile(dataset.test_perm):
        save_eval_perm(args.data_name, random_seed=args.random_seed)

    train_dataset = dataset.load_train(args.portion)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True, drop_last=True)
    valid_dataset = dataset.load_valid(args.portion)
    valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=1,
                                  shuffle=False)
    valid_df = dataset.load_valid_perm()
    test_dataset = dataset.load_test(args.portion)
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=1,
                                 shuffle=False)
    test_df = dataset.load_test_perm()

    logging.info("Loading sent embedding...")
    encoder = args.sent_encoder
    if encoder == "infersent":
        sent_embedding = get_infersent(args.data_name, if_sample=args.test)
        embed_dim = 4096
    elif encoder == "average_glove":
        sent_embedding = get_average_glove(args.data_name,
                                           if_sample=args.test)
        embed_dim = 300
    elif encoder == "lm_hidden":
        corpus = Corpus(train_dataset.file_list, test_dataset.file_list)
        sent_embedding = get_lm_hidden(args.data_name,
                                       "lm_" + args.data_name, corpus)
        embed_dim = 2048
    elif encoder == "s2s_hidden":
        corpus = SentCorpus(train_dataset.file_list, test_dataset.file_list)
        sent_embedding = get_s2s_hidden(args.data_name,
                                        "s2s_" + args.data_name, corpus)
        embed_dim = 2048
    else:
        raise ValueError("Invalid sent encoder name!")

    logging.info("Training BigramCoherence model...")
    hparams = {
        "loss": args.loss,
        "input_dropout": args.input_dropout,
        "hidden_state": args.hidden_state,
        "hidden_layers": args.hidden_layers,
        "hidden_dropout": args.hidden_dropout,
        "num_epochs": args.num_epochs,
        "margin": args.margin,
        "lr": args.lr,
        "l2_reg_lambda": args.l2_reg_lambda,
        "use_bn": args.use_bn,
        "task": "discrimination",
        "bidirectional": args.bidirectional,
    }

    model = BigramCoherence(embed_dim=embed_dim,
                            sent_encoder=sent_embedding,
                            hparams=hparams)
    model.init()
    best_step, valid_acc = model.fit(train_dataloader, valid_dataloader,
                                     valid_df)
    if args.save:
        # Checkpoint name encodes the dataset and validation accuracy.
        model_path = os.path.join(config.CHECKPOINT_PATH,
                                  "%s-%.4f" % (args.data_name, valid_acc))
        torch.save(model, model_path + '.pth')
    model.load_best_state()

    logging.info("Results for discrimination:")
    dis_acc = model.evaluate_dis(test_dataloader, test_df)
    print("Test Acc:", dis_acc)
    logging.info("Disc Accuracy: {}".format(dis_acc[0]))

    logging.info("Results for insertion:")
    ins_acc = model.evaluate_ins(test_dataloader, test_df)
    print("Test Acc:", ins_acc)
    logging.info("Insert Accuracy: {}".format(ins_acc[0]))

    return dis_acc, ins_acc