示例#1
0
def main(sents):
    """Run the most recent ParaphraseGenerator checkpoint over ``sents``.

    Builds a ``ParaphraseGenerator`` with hyper-parameters frozen from the
    training run, loads the newest file under ``save/`` (by modification
    time), converts ``sents`` with ``inputData`` and runs one forward pass
    in evaluation mode without gradient tracking.

    Args:
        sents: raw input handed to ``inputData`` (presumably a list of
            sentences — TODO confirm against ``inputData``).

    Returns:
        ``out``, the first element of the model's output tuple. (The previous
        version computed it and then dropped it, implicitly returning None.)

    Raises:
        FileNotFoundError: if ``save/`` contains no checkpoint.
    """
    # Hyper-parameters frozen from training; the trailing comments name the
    # Dataloader / CLI source they were taken from in the training script.
    # They must match the checkpoint or load_state_dict will fail.
    max_seq_len = 28
    op = {
        "vocab_sz": 27699,  # data.getVocabSize()
        "max_seq_len": max_seq_len,  # data.getSeqLength()
        "emb_hid_dim": 256,  # args.emb_hid_dim
        "emb_dim": 512,  # args.emb_dim
        "enc_dim": 512,  # args.enc_dim
        "enc_dropout": 0.5,  # args.enc_dropout
        "enc_rnn_dim": 512,  # args.enc_rnn_dim
        "gen_rnn_dim": 512,  # args.gen_rnn_dim
    }

    # Pick the newest checkpoint by modification time.
    files = glob("save/*")
    if not files:
        raise FileNotFoundError("no checkpoint found in save/")
    weight_path = max(files, key=os.path.getmtime)
    print("### Loading weights from {} ###".format(weight_path))

    model = ParaphraseGenerator(op)
    # map_location="cpu" lets a GPU-saved checkpoint deserialize on a
    # CPU-only host; load_state_dict copies into the (CPU) parameters anyway.
    model.load_state_dict(torch.load(weight_path, map_location="cpu"))
    # Inference: disable dropout (enc_dropout=0.5 above).
    model.eval()

    print("Maximum sequence length = {}".format(max_seq_len))

    with open('data/Vocab_Extra', 'rb') as f:
        # NOTE(review): pickle can execute arbitrary code; only load trusted files.
        vocab = pickle.load(f)

    # embed_model is not defined in this function; presumably a module-level
    # global defined elsewhere in the file — confirm.
    # NOTE(review): training feeds phrase.t() (sequence-first); confirm that
    # inputData already returns that layout.
    sents = inputData(sents, max_seq_len, vocab, embed_model)

    with torch.no_grad():
        out, _, _ = model(sents)
    return out
def main():
    """Resume training a ParaphraseGenerator from ``save/model_epoch_0``.

    Reads hyper-parameters and data paths from the CLI (``utils.make_parser``),
    builds a ``Dataloader`` over ``args.input_json`` / ``args.input_ques_h5``,
    loads the hard-coded checkpoint and then, for ``args.n_epoch`` epochs:

    * trains on samples ``[0, train_dataset_len)`` with RMSprop, minimizing
      token cross-entropy (padding ignored) plus
      ``net_utils.JointEmbeddingLoss``;
    * validates on the next ``val_dataset_len`` samples without gradients;
    * logs mean losses and ``evaluate_scores`` metrics to TensorBoard;
    * saves a checkpoint to ``save/model_epoch_{epoch}``.

    Epochs are numbered from 1 because this run resumes after epoch 0.
    """

    # get arguments ---
    print("REMOVED METEOR SCORE FOR NOW")

    parser = utils.make_parser()
    args = parser.parse_args()

    # get data
    data = Dataloader(args.input_json, args.input_ques_h5)

    # model / training hyper-parameters
    op = {
        "vocab_sz": data.getVocabSize(),
        "max_seq_len": data.getSeqLength(),
        "emb_hid_dim": args.emb_hid_dim,
        "emb_dim": args.emb_dim,
        "enc_dim": args.enc_dim,
        "enc_dropout": args.enc_dropout,
        "enc_rnn_dim": args.enc_rnn_dim,
        "gen_rnn_dim": args.gen_rnn_dim,
        "gen_dropout": args.gen_dropout,
        "lr": args.learning_rate,
        "epochs": args.n_epoch,
    }
    print(op)

    # Instantiate the paraphrase generator and resume from a checkpoint.
    # REMEMBER TO CHANGE: the checkpoint path and the epoch offset in the
    # loop below are hard-coded for resuming after epoch 0.
    pgen = ParaphraseGenerator(op)
    # The model is still on CPU here; map_location lets a GPU-saved
    # checkpoint deserialize without the original device being present.
    pgen.load_state_dict(
        torch.load("save/model_epoch_0", map_location="cpu"))
    # Move to the target device before building the optimizer, as the
    # torch.optim documentation recommends.
    pgen = pgen.to(DEVICE)

    # setup logging
    logger = SummaryWriter(os.path.join(LOG_DIR, TIME + args.name))
    # Directory for (currently disabled) sample dumps. The previous
    # subprocess.run(['mkdir', path], shell=True) executed a bare `mkdir` on
    # POSIX (extra list items go to the shell, not to mkdir), so the
    # directory was never created.
    os.makedirs(os.path.join(GEN_DIR, TIME), exist_ok=True)

    # ready model for training
    train_loader = Data.DataLoader(
        Data.Subset(data, range(args.train_dataset_len)),
        batch_size=args.batch_size,
        shuffle=True,
    )
    test_loader = Data.DataLoader(
        Data.Subset(
            data,
            range(args.train_dataset_len,
                  args.train_dataset_len + args.val_dataset_len)),
        batch_size=args.batch_size,
        shuffle=True,
    )

    pgen_optim = optim.RMSprop(pgen.parameters(), lr=op["lr"])
    # Padding positions do not contribute to the token loss.
    cross_entropy_loss = nn.CrossEntropyLoss(ignore_index=data.PAD_token)

    for epoch_ in range(op["epochs"]):

        epoch = epoch_ + 1  # REMEMBER TO CHANGE (offset for the resumed run)

        # ---- training ----
        epoch_l1 = 0
        epoch_l2 = 0
        itr = 0
        ph = []
        pph = []
        gpph = []
        pgen.train()

        for phrase, _, para_phrase, _, _ in tqdm(
                train_loader, ascii=True, desc="epoch" + str(epoch)):

            phrase = phrase.to(DEVICE)
            para_phrase = para_phrase.to(DEVICE)

            # Batches are transposed before entering the model; `out` is
            # permuted back below so that CrossEntropyLoss sees (N, C, L)
            # against a (N, L) target.
            out, enc_out, enc_sim_phrase = pgen(
                phrase.t(),
                sim_phrase=para_phrase.t(),
                train=True,
            )

            loss_1 = cross_entropy_loss(out.permute(1, 2, 0), para_phrase)
            loss_2 = net_utils.JointEmbeddingLoss(enc_out, enc_sim_phrase)

            pgen_optim.zero_grad()
            (loss_1 + loss_2).backward()
            pgen_optim.step()

            # accumulate results
            epoch_l1 += loss_1.item()
            epoch_l2 += loss_2.item()
            ph += net_utils.decode_sequence(data.ix_to_word, phrase)
            pph += net_utils.decode_sequence(data.ix_to_word, para_phrase)
            gpph += net_utils.decode_sequence(data.ix_to_word,
                                              torch.argmax(out, dim=-1).t())

            itr += 1

        # log results; max(itr, 1) avoids ZeroDivisionError on an empty split
        n_batches = max(itr, 1)
        logger.add_scalar("l2_train", epoch_l2 / n_batches, epoch)
        logger.add_scalar("l1_train", epoch_l1 / n_batches, epoch)

        scores = evaluate_scores(gpph, pph)
        for key in scores:
            logger.add_scalar(key + "_train", scores[key], epoch)

        # Sample dumping is disabled; to re-enable:
        # dump_samples(ph, pph, gpph,
        #              os.path.join(GEN_DIR, TIME, str(epoch) + "_train.txt"))

        # ---- validation ----
        epoch_l1 = 0
        epoch_l2 = 0
        itr = 0
        ph = []
        pph = []
        gpph = []
        pgen.eval()

        with torch.no_grad():
            for phrase, _, para_phrase, _, _ in tqdm(
                    test_loader, ascii=True, desc="val" + str(epoch)):

                phrase = phrase.to(DEVICE)
                para_phrase = para_phrase.to(DEVICE)

                out, enc_out, enc_sim_phrase = pgen(phrase.t(),
                                                    sim_phrase=para_phrase.t())

                loss_1 = cross_entropy_loss(out.permute(1, 2, 0), para_phrase)
                loss_2 = net_utils.JointEmbeddingLoss(enc_out, enc_sim_phrase)

                epoch_l1 += loss_1.item()
                epoch_l2 += loss_2.item()
                ph += net_utils.decode_sequence(data.ix_to_word, phrase)
                pph += net_utils.decode_sequence(data.ix_to_word, para_phrase)
                gpph += net_utils.decode_sequence(
                    data.ix_to_word,
                    torch.argmax(out, dim=-1).t())

                itr += 1
                torch.cuda.empty_cache()

        n_batches = max(itr, 1)
        logger.add_scalar("l2_val", epoch_l2 / n_batches, epoch)
        logger.add_scalar("l1_val", epoch_l1 / n_batches, epoch)

        scores = evaluate_scores(gpph, pph)
        for key in scores:
            logger.add_scalar(key + "_val", scores[key], epoch)

        # Sample dumping is disabled; to re-enable:
        # dump_samples(ph, pph, gpph,
        #              os.path.join(GEN_DIR, TIME, str(epoch) + "_val.txt"))

        save_model(pgen, pgen_optim, epoch,
                   "save/model_epoch_{}".format(epoch))

    # wrap ups
    logger.close()
    print("Done !!")
示例#3
0
def main():
    """Train a ParaphraseGenerator from scratch on CUB-200-2011 or Quora.

    ``args.dataset_name`` selects the dataset ('cub' or 'quora'). Both
    splits are padded to a fixed length (21 for CUB, 26 for Quora). For
    each of ``args.n_epoch`` epochs (numbered from 0) this trains with
    RMSprop on token cross-entropy (padding ignored) plus
    ``net_utils.JointEmbeddingLoss``, validates without gradients, logs mean
    losses and ``evaluate_scores`` metrics to TensorBoard, dumps decoded
    samples under ``GEN_DIR/TIME`` and saves a checkpoint under
    ``SAVE_DIR/TIME``.

    Raises:
        NotImplementedError: if ``args.dataset_name`` is not supported.
    """

    # get arguments ---
    parser = utils.make_parser()
    args = parser.parse_args()

    # get data
    if args.dataset_name == 'cub':
        # length 21 since 75% of captions have length <= 19
        dataset, train_loader = get_cub_200_2011_paraphrase_combined_vocab(
            split='train_val',
            no_start_end=args.no_start_end,
            should_pad=True,
            pad_to_length=21,
            d_batch=args.batch_size)
        _, test_loader = get_cub_200_2011_paraphrase_combined_vocab(
            split='test',
            no_start_end=args.no_start_end,
            should_pad=True,
            pad_to_length=21,
            d_batch=args.batch_size)
    elif args.dataset_name == 'quora':
        # length 26 since that was used by the paper authors
        dataset, train_loader = get_quora_paraphrase_dataset_combined_vocab(
            split='train',
            no_start_end=args.no_start_end,
            should_pad=True,
            pad_to_length=26,
            d_batch=args.batch_size)
        _, test_loader = get_quora_paraphrase_dataset_combined_vocab(
            split='test',
            no_start_end=args.no_start_end,
            should_pad=True,
            pad_to_length=26,
            d_batch=args.batch_size)
    else:
        raise NotImplementedError(
            "unsupported dataset_name: {!r}".format(args.dataset_name))

    # model / training hyper-parameters
    op = {
        "vocab_sz": dataset.d_vocab,
        "max_seq_len": dataset.pad_to_length,
        "emb_hid_dim": args.emb_hid_dim,
        "emb_dim": args.emb_dim,
        "enc_dim": args.enc_dim,
        "enc_dropout": args.enc_dropout,
        "enc_rnn_dim": args.enc_rnn_dim,
        "gen_rnn_dim": args.gen_rnn_dim,
        "gen_dropout": args.gen_dropout,
        "lr": args.learning_rate,
        "epochs": args.n_epoch,
    }

    # Instantiate the paraphrase generator and move it to the target device
    # before building the optimizer, as the torch.optim docs recommend.
    pgen = ParaphraseGenerator(op).to(DEVICE)

    # setup logging
    logger = SummaryWriter(os.path.join(LOG_DIR, TIME + args.name))
    # Output directories for sample dumps and checkpoints. os.makedirs also
    # creates missing parents and is portable; the previous
    # subprocess.run(['mkdir', ...], check=False) failed silently when a
    # parent was missing, and the later dump/save calls then crashed.
    os.makedirs(os.path.join(GEN_DIR, TIME), exist_ok=True)
    os.makedirs(os.path.join(SAVE_DIR, TIME), exist_ok=True)

    # ready model for training
    pgen_optim = optim.RMSprop(pgen.parameters(), lr=op["lr"])
    # Padding positions do not contribute to the token loss.
    cross_entropy_loss = nn.CrossEntropyLoss(ignore_index=dataset.pad_token)

    for epoch in range(op["epochs"]):

        # ---- training ----
        epoch_l1 = 0
        epoch_l2 = 0
        itr = 0
        ph = []
        pph = []
        gpph = []
        pgen.train()

        for phrase, _, para_phrase, _ in tqdm(
                train_loader, ascii=True, desc="epoch" + str(epoch)):

            phrase = phrase.to(DEVICE)
            para_phrase = para_phrase.to(DEVICE)

            # Batches are transposed before entering the model; `out` is
            # permuted back below so that CrossEntropyLoss sees (N, C, L)
            # against a (N, L) target.
            out, enc_out, enc_sim_phrase = pgen(phrase.t(),
                                                sim_phrase=para_phrase.t(),
                                                train=True)

            loss_1 = cross_entropy_loss(out.permute(1, 2, 0), para_phrase)
            loss_2 = net_utils.JointEmbeddingLoss(enc_out, enc_sim_phrase)

            pgen_optim.zero_grad()
            (loss_1 + loss_2).backward()
            pgen_optim.step()

            # accumulate results
            epoch_l1 += loss_1.item()
            epoch_l2 += loss_2.item()
            ph += [dataset.decode_caption(p) for p in phrase]
            pph += [dataset.decode_caption(p) for p in para_phrase]
            gpph += [
                dataset.decode_caption(p)
                for p in torch.argmax(out, dim=-1).t()
            ]

            itr += 1
            torch.cuda.empty_cache()

        # log results; max(itr, 1) avoids ZeroDivisionError on an empty split
        n_batches = max(itr, 1)
        logger.add_scalar("l2_train", epoch_l2 / n_batches, epoch)
        logger.add_scalar("l1_train", epoch_l1 / n_batches, epoch)

        scores = evaluate_scores(gpph, pph)
        for key in scores:
            logger.add_scalar(key + "_train", scores[key], epoch)

        dump_samples(ph, pph, gpph,
                     os.path.join(GEN_DIR, TIME,
                                  str(epoch) + "_train.txt"))

        # ---- validation ----
        epoch_l1 = 0
        epoch_l2 = 0
        itr = 0
        ph = []
        pph = []
        gpph = []
        pgen.eval()

        with torch.no_grad():
            for phrase, _, para_phrase, _ in tqdm(
                    test_loader, ascii=True, desc="val" + str(epoch)):

                phrase = phrase.to(DEVICE)
                para_phrase = para_phrase.to(DEVICE)

                out, enc_out, enc_sim_phrase = pgen(phrase.t(),
                                                    sim_phrase=para_phrase.t())

                loss_1 = cross_entropy_loss(out.permute(1, 2, 0), para_phrase)
                loss_2 = net_utils.JointEmbeddingLoss(enc_out, enc_sim_phrase)

                epoch_l1 += loss_1.item()
                epoch_l2 += loss_2.item()
                ph += [dataset.decode_caption(p) for p in phrase]
                pph += [dataset.decode_caption(p) for p in para_phrase]
                gpph += [
                    dataset.decode_caption(p)
                    for p in torch.argmax(out, dim=-1).t()
                ]

                itr += 1
                torch.cuda.empty_cache()

        n_batches = max(itr, 1)
        logger.add_scalar("l2_val", epoch_l2 / n_batches, epoch)
        logger.add_scalar("l1_val", epoch_l1 / n_batches, epoch)

        scores = evaluate_scores(gpph, pph)
        for key in scores:
            logger.add_scalar(key + "_val", scores[key], epoch)

        dump_samples(ph, pph, gpph,
                     os.path.join(GEN_DIR, TIME,
                                  str(epoch) + "_val.txt"))

        save_model(pgen, pgen_optim, epoch,
                   os.path.join(SAVE_DIR, TIME, str(epoch)))

    # wrap ups
    logger.close()
    print("Done !!")