import logging
import os
import pickle
import random

import numpy as np
import torch
from torch.utils.data import DataLoader

import config
# Project-local modules (DataSet, Corpus, SentCorpus, LanguageModel,
# LMCoherence, BigramCoherence, permute_articles_with_replacement and the
# get_* sentence-encoder helpers) are assumed to be importable from
# elsewhere in this repository.


def train_lm(args):
    """Train a forward or backward RNN language model on the given dataset."""
    if args.data_name not in config.DATASET:
        raise ValueError("Invalid data name!")
    dataset = DataSet(config.DATASET[args.data_name])
    train_dataset = dataset.load_train()
    test_dataset = dataset.load_test()
    corpus = Corpus(train_dataset.file_list, test_dataset.file_list,
                    args.reverse)
    suffix = "backward" if args.reverse else "forward"
    kwargs = {
        "vocab_size": corpus.glove_embed.shape[0],
        "embed_dim": corpus.glove_embed.shape[1],
        "corpus": corpus,
        "hparams": {
            "hidden_size": args.hidden_size,
            "num_layers": args.num_layers,
            "cell_type": args.cell_type,
            "tie_embed": args.tie_embed,
            "rnn_dropout": args.rnn_dropout,
            "hidden_dropout": args.hidden_dropout,
            "num_epochs": args.num_epochs,
            "batch_size": args.batch_size,
            "bptt": args.bptt,
            "log_interval": args.log_interval,
            "save_path": args.save_path + "_" + args.data_name + "_" + suffix,
            "lr": args.lr,
            "wdecay": args.wdecay,
        },
    }
    lm = LanguageModel(**kwargs)
    best_valid_loss = lm.fit()
    print("Best Valid Loss:", best_valid_loss)
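# A minimal usage sketch for train_lm (left commented so importing this
# module stays side-effect free). The Namespace fields mirror exactly the
# attributes train_lm reads; the values are illustrative assumptions, not
# the repository's defaults.
# from argparse import Namespace
# lm_args = Namespace(
#     data_name="wsj_bigram", reverse=False, hidden_size=512, num_layers=1,
#     cell_type="lstm", tie_embed=False, rnn_dropout=0.5, hidden_dropout=0.5,
#     num_epochs=10, batch_size=64, bptt=35, log_interval=100,
#     save_path="lm", lr=0.001, wdecay=1.2e-6)
# train_lm(lm_args)       # forward LM
# lm_args.reverse = True
# train_lm(lm_args)       # backward LM (suffix switches to "backward")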
def run_lm_coherence(args):
    """Score permuted test articles with pretrained forward/backward LMs."""
    logging.info("Loading data...")
    if args.data_name not in config.DATASET:
        raise ValueError("Invalid data name!")
    dataset = DataSet(config.DATASET[args.data_name])
    train_dataset = dataset.load_train()
    test_df = dataset.load_test_perm()
    test_dataset = dataset.load_test()
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=1,
                                 shuffle=False)
    corpus = Corpus(train_dataset.file_list, test_dataset.file_list)

    # Uncomment to evaluate on the WSJ bigram split instead:
    # dataset = DataSet(config.DATASET["wsj_bigram"])
    # test_df = dataset.load_test_perm()
    # test_dataset = dataset.load_test()
    # test_dataloader = DataLoader(dataset=test_dataset, batch_size=1,
    #                              shuffle=False)

    # Both LMs share the hyperparameters saved with the forward checkpoint.
    with open(os.path.join(config.CHECKPOINT_PATH,
                           args.lm_name + "_forward.pkl"), "rb") as f:
        hparams = pickle.load(f)
    kwargs = {
        "vocab_size": corpus.glove_embed.shape[0],
        "embed_dim": corpus.glove_embed.shape[1],
        "corpus": corpus,
        "hparams": hparams,
    }
    forward_model = LanguageModel(**kwargs)
    forward_model.load(
        os.path.join(config.CHECKPOINT_PATH, args.lm_name + "_forward.pt"))
    backward_model = LanguageModel(**kwargs)
    backward_model.load(
        os.path.join(config.CHECKPOINT_PATH, args.lm_name + "_backward.pt"))

    logging.info("Results for discrimination:")
    model = LMCoherence(forward_model.lm, backward_model.lm, corpus)
    dis_acc = model.evaluate_dis(test_dataloader, test_df)
    logging.info("Disc Accuracy: {}".format(dis_acc))

    logging.info("Results for insertion:")
    ins_acc = model.evaluate_ins(test_dataloader, test_df)
    logging.info("Ins Accuracy: {}".format(ins_acc))

    return dis_acc, ins_acc
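# Note on checkpoint naming: run_lm_coherence expects the files written by
# train_lm. With train_lm's save_path "lm" and data_name "wsj_bigram", the
# checkpoints would be <CHECKPOINT_PATH>/lm_wsj_bigram_forward.{pkl,pt} and
# lm_wsj_bigram_backward.pt, so the matching lm_name here would be
# "lm_wsj_bigram". This coupling is inferred from the two functions above;
# treat it as an assumption.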
def save_eval_perm(data_name, if_sample=False, random_seed=config.RANDOM_SEED):
    """Generate and save permuted validation/test articles for evaluation."""
    random.seed(random_seed)
    logging.info("Loading valid and test data...")
    if data_name not in config.DATASET:
        raise ValueError("Invalid data name!")
    dataset = DataSet(config.DATASET[data_name])
    # dataset.random_seed = random_seed
    if if_sample:
        valid_dataset = dataset.load_valid_sample()
        test_dataset = dataset.load_test_sample()
    else:
        valid_dataset = dataset.load_valid()
        test_dataset = dataset.load_test()
    valid_df = valid_dataset.article_df
    test_df = test_dataset.article_df

    logging.info("Generating permuted articles...")

    def permute(x):
        # Join each negative (permuted) article's sentences with <PUNC> and
        # separate the negatives with <BREAK>.
        x = np.array(x).squeeze()
        # neg_x_list = permute_articles([x], config.NEG_PERM)[0]
        neg_x_list = permute_articles_with_replacement([x], config.NEG_PERM)[0]
        return "<BREAK>".join(["<PUNC>".join(i) for i in neg_x_list])

    valid_df["neg_list"] = valid_df.sentences.map(permute)
    valid_df["sentences"] = valid_df.sentences.map(lambda x: "<PUNC>".join(x))
    valid_nums = valid_df.neg_list.map(
        lambda x: len(x.split("<BREAK>"))).sum()
    test_df["neg_list"] = test_df.sentences.map(permute)
    test_df["sentences"] = test_df.sentences.map(lambda x: "<PUNC>".join(x))
    test_nums = test_df.neg_list.map(lambda x: len(x.split("<BREAK>"))).sum()
    logging.info("Number of validation pairs %d" % valid_nums)
    logging.info("Number of test pairs %d" % test_nums)

    logging.info("Saving...")
    dataset.save_valid_perm(valid_df)
    dataset.save_test_perm(test_df)
    logging.info("Finished!")
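# Example: pre-generate permutations for a dataset key from config.DATASET
# (call left commented; "wsj_bigram" is a key used elsewhere in this file,
# and the seed falls back to config.RANDOM_SEED when omitted).
# save_eval_perm("wsj_bigram", if_sample=False)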
def run_bigram_coherence(args):
    """Train a BigramCoherence model and evaluate discrimination/insertion."""
    logging.info("Loading data...")
    if args.data_name not in config.DATASET:
        raise ValueError("Invalid data name!")
    dataset = DataSet(config.DATASET[args.data_name])
    # dataset.random_seed = args.random_seed
    if not os.path.isfile(dataset.test_perm):
        save_eval_perm(args.data_name, random_seed=args.random_seed)

    train_dataset = dataset.load_train(args.portion)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True, drop_last=True)
    valid_dataset = dataset.load_valid(args.portion)
    valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=1,
                                  shuffle=False)
    valid_df = dataset.load_valid_perm()
    test_dataset = dataset.load_test(args.portion)
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=1,
                                 shuffle=False)
    test_df = dataset.load_test_perm()

    logging.info("Loading sent embedding...")
    if args.sent_encoder == "infersent":
        sent_embedding = get_infersent(args.data_name, if_sample=args.test)
        embed_dim = 4096
    elif args.sent_encoder == "average_glove":
        sent_embedding = get_average_glove(args.data_name, if_sample=args.test)
        embed_dim = 300
    elif args.sent_encoder == "lm_hidden":
        corpus = Corpus(train_dataset.file_list, test_dataset.file_list)
        sent_embedding = get_lm_hidden(args.data_name, "lm_" + args.data_name,
                                       corpus)
        embed_dim = 2048
    elif args.sent_encoder == "s2s_hidden":
        corpus = SentCorpus(train_dataset.file_list, test_dataset.file_list)
        sent_embedding = get_s2s_hidden(args.data_name, "s2s_" + args.data_name,
                                        corpus)
        embed_dim = 2048
    else:
        raise ValueError("Invalid sent encoder name!")

    logging.info("Training BigramCoherence model...")
    kwargs = {
        "embed_dim": embed_dim,
        "sent_encoder": sent_embedding,
        "hparams": {
            "loss": args.loss,
            "input_dropout": args.input_dropout,
            "hidden_state": args.hidden_state,
            "hidden_layers": args.hidden_layers,
            "hidden_dropout": args.hidden_dropout,
            "num_epochs": args.num_epochs,
            "margin": args.margin,
            "lr": args.lr,
            "l2_reg_lambda": args.l2_reg_lambda,
            "use_bn": args.use_bn,
            "task": "discrimination",
            "bidirectional": args.bidirectional,
        },
    }
    model = BigramCoherence(**kwargs)
    model.init()
    best_step, valid_acc = model.fit(train_dataloader, valid_dataloader,
                                     valid_df)
    if args.save:
        model_path = os.path.join(config.CHECKPOINT_PATH,
                                  "%s-%.4f" % (args.data_name, valid_acc))
        # model.save(model_path)
        torch.save(model, model_path + ".pth")
    model.load_best_state()

    # Uncomment to evaluate cross-domain transfer on the WSJ bigram split:
    # dataset = DataSet(config.DATASET["wsj_bigram"])
    # test_dataset = dataset.load_test()
    # test_dataloader = DataLoader(dataset=test_dataset, batch_size=1,
    #                              shuffle=False)
    # test_df = dataset.load_test_perm()
    # if args.sent_encoder == "infersent":
    #     model.sent_encoder = get_infersent("wsj_bigram", if_sample=args.test)
    # elif args.sent_encoder == "average_glove":
    #     model.sent_encoder = get_average_glove("wsj_bigram",
    #                                            if_sample=args.test)
    # else:
    #     model.sent_encoder = get_lm_hidden("wsj_bigram",
    #                                        "lm_" + args.data_name, corpus)

    logging.info("Results for discrimination:")
    dis_acc = model.evaluate_dis(test_dataloader, test_df)
    print("Test Acc:", dis_acc)
    logging.info("Disc Accuracy: {}".format(dis_acc[0]))

    logging.info("Results for insertion:")
    ins_acc = model.evaluate_ins(test_dataloader, test_df)
    print("Test Acc:", ins_acc)
    logging.info("Insert Accuracy: {}".format(ins_acc[0]))

    return dis_acc, ins_acc
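# A minimal driver sketch, assuming this file is run as a script. The
# Namespace fields mirror the attributes run_bigram_coherence reads; the
# values are illustrative assumptions rather than the repository's actual
# CLI defaults.
if __name__ == "__main__":
    from argparse import Namespace

    demo_args = Namespace(
        data_name="wsj_bigram", random_seed=config.RANDOM_SEED, portion=1.0,
        batch_size=128, sent_encoder="average_glove", test=False,
        loss="margin", input_dropout=0.5, hidden_state=500, hidden_layers=1,
        hidden_dropout=0.3, num_epochs=50, margin=5.0, lr=0.001,
        l2_reg_lambda=0.0, use_bn=False, bidirectional=False, save=False)
    dis_acc, ins_acc = run_bigram_coherence(demo_args)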