Example #1
def do_ranking(model, vocab):
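    """
    Rank the true ending (index 0) against five distractor endings by LM
    perplexity: a batch counts as correct when the true ending gets the
    lowest perplexity. Returns the accuracy as a percentage.
    """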

    dataset = du.NarrativeClozeDataset(args.data,
                                       vocab,
                                       src_seq_length=MAX_EVAL_SEQ_LEN,
                                       min_seq_length=MIN_EVAL_SEQ_LEN)
    batches = BatchIter(dataset,
                        args.batch_size,
                        sort_key=lambda x: len(x.actual),
                        train=False,
                        device=device)

    ranked_acc = 0.0
    if args.emb_type:
        print("RANKING WITH ROLE EMB")
        vocab2 = du.load_vocab(args.vocab2)
        role_dataset = du.NarrativeClozeDataset(
            args.role_data,
            vocab2,
            src_seq_length=MAX_EVAL_SEQ_LEN,
            min_seq_length=MIN_EVAL_SEQ_LEN)
        role_batches = BatchIter(role_dataset,
                                 args.batch_size,
                                 sort_key=lambda x: len(x.actual),
                                 train=False,
                                 device=device)

        assert len(dataset) == len(
            role_dataset), "Dataset and Role dataset must be of same length."

        for iteration, (bl, rbl) in enumerate(zip(batches, role_batches)):

            if (iteration + 1) % 25 == 0:
                print("iteration {}".format(iteration + 1))

            ## DATA STEPS
            all_texts = [
                bl.actual, bl.actual_tgt, bl.dist1, bl.dist1_tgt, bl.dist2,
                bl.dist2_tgt, bl.dist3, bl.dist3_tgt, bl.dist4, bl.dist4_tgt,
                bl.dist5, bl.dist5_tgt
            ]  # each is a (tensor, lengths) tuple

            all_roles = [
                rbl.actual, rbl.dist1, rbl.dist2, rbl.dist3, rbl.dist4,
                rbl.dist5
            ]  # tgts are not needed for role
            assert len(all_roles) == 6, "one role sequence per candidate."

            assert len(all_texts) == 12, "6 candidates x (source, target)."

            all_texts_vars = []
            all_roles_vars = []

            for tup in all_texts:
                data = tup[0].cuda() if use_cuda else tup[0]
                all_texts_vars.append((Variable(data, volatile=True), tup[1]))
            for tup in all_roles:
                data = tup[0].cuda() if use_cuda else tup[0]
                all_roles_vars.append((Variable(data, volatile=True), tup[1]))

            # will iterate two at a time using the iterator and next()
            vars_iter = iter(all_texts_vars)
            roles_iter = iter(all_roles_vars)

            # run the model and collect ppls for all 6 sentences
            pps = []
            for tup in vars_iter:
                ## INIT AND DECODE before every sentence
                hidden = model.init_hidden(args.batch_size)
                next_tup = next(vars_iter)
                role_tup = next(roles_iter)
                nll = calc_perplexity(args, model, tup[0], vocab, next_tup[0],
                                      next_tup[1], hidden, role_tup[0])
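                # perplexity is the exponential of the average negative log-likelihood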
                pp = torch.exp(nll)
                #print("NEG-LOSS {} PPL {}".format(nll.data[0], pp.data[0]))
                pps.append(pp.data.cpu().numpy()[0])  # .cpu() so this works under CUDA too

            # lowest perplexity = top-ranked candidate; the correct answer is always index 0
            assert len(pps) == 6, "6 targets."
            #print("\n")
            all_texts_str = [
                transform(text[0].data.cpu().numpy()[0], vocab.itos)
                for text in all_texts_vars
            ]
            #print("ALL: {}".format(all_texts_str))
            min_index = np.argmin(pps)
            if min_index == 0:
                ranked_acc += 1
                #print("TARGET: {}".format(transform(all_texts_vars[1][0].data.numpy()[0], vocab.itos)))
                #print("CORRECT: {}".format(transform(all_texts_vars[1][0].data.numpy()[0], vocab.itos)))
            #else:
            # print the ones that are wrong
            #print("TARGET: {}".format(transform(all_texts_vars[1][0].data.numpy()[0], vocab.itos)))
            #print("WRONG: {}".format(transform(all_texts_vars[min_index+2][0].data.numpy()[0], vocab.itos)))

            if (iteration + 1) == args.max_decode:
                print("Max decode reached. Exiting.")
                break

        ranked_acc = ranked_acc / (iteration + 1) * 100  # convert to a percentage
        print("Average acc(%): {}".format(ranked_acc))
        return ranked_acc

    else:  # THIS IS FOR MODEL WITHOUT ROLE EMB

        print("RANKING WITHOUT ROLE EMB.")
        for iteration, bl in enumerate(batches):

            if (iteration + 1) % 25 == 0:
                print("iteration {}".format(iteration + 1))

            ## DATA STEPS
            all_texts = [
                bl.actual, bl.actual_tgt, bl.dist1, bl.dist1_tgt, bl.dist2,
                bl.dist2_tgt, bl.dist3, bl.dist3_tgt, bl.dist4, bl.dist4_tgt,
                bl.dist5, bl.dist5_tgt
            ]  # each is a (tensor, lengths) tuple

            assert len(all_texts) == 12, "6 candidates x (source, target)."

            all_texts_vars = []
            for tup in all_texts:
                data = tup[0].cuda() if use_cuda else tup[0]
                all_texts_vars.append((Variable(data, volatile=True), tup[1]))

            # will iterate two at a time using the iterator and next()
            vars_iter = iter(all_texts_vars)

            # run the model for all 6 sentences
            pps = []
            for tup in vars_iter:
                ## INIT AND DECODE before every sentence
                hidden = model.init_hidden(args.batch_size)
                next_tup = next(vars_iter)

                nll = calc_perplexity(args, model, tup[0], vocab, next_tup[0],
                                      next_tup[1], hidden)
                pp = torch.exp(nll)
                #print("NEG-LOSS {} PPL {}".format(nll.data[0], pp.data[0]))
                pps.append(pp.data.cpu().numpy()[0])  # .cpu() so this works under CUDA too

            # lowest perplexity = top-ranked candidate; the correct answer is always index 0
            assert len(pps) == 6, "6 targets."
            #print("\n")
            all_texts_str = [
                transform(text[0].data.cpu().numpy()[0], vocab.itos)
                for text in all_texts_vars
            ]
            #print("ALL: {}".format(all_texts_str))
            min_index = np.argmin(pps)
            if min_index == 0:
                ranked_acc += 1
                #print("TARGET: {}".format(transform(all_texts_vars[1][0].data.numpy()[0], vocab.itos)))
                #print("CORRECT: {}".format(transform(all_texts_vars[1][0].data.numpy()[0], vocab.itos)))
            #else:
            # print the ones that are wrong
            #print("TARGET: {}".format(transform(all_texts_vars[1][0].data.numpy()[0], vocab.itos)))
            #print("WRONG: {}".format(transform(all_texts_vars[min_index+2][0].data.numpy()[0], vocab.itos)))

            if (iteration + 1) == args.max_decode:
                print("Max decode reached. Exiting.")
                break

        ranked_acc = ranked_acc / (iteration + 1) * 100  # convert to a percentage
        print("Average acc(%): {}".format(ranked_acc))
        return ranked_acc
Example #2
def gen_from_seed(model, vocab, eos_id, pad_id, sos_id, tup_id):
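    """
    Generate a continuation from a seed: run the LM over the seed sequence
    (all but its last token) to build up hidden state, then beam-decode
    starting from the last token and print the predicted events.
    """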

    if args.emb_type:  # GEN FROM SEED WITH ROLE EMB
        print("GEN SEED WITH ROLE EMB")
        vocab2 = du.load_vocab(args.vocab2)
        # will use this to feed in role ids in beam decode
        ROLES = [
            vocab2.stoi[TUP_TOK], vocab2.stoi[VERB], vocab2.stoi[SUB],
            vocab2.stoi[OBJ], vocab2.stoi[PREP]
        ]
        dataset = du.LMRoleSentenceDataset(
            args.data,
            vocab,
            args.role_data,
            vocab2,
            src_seq_length=MAX_EVAL_SEQ_LEN,
            min_seq_length=MIN_EVAL_SEQ_LEN)  #put in filter pred later
        batches = BatchIter(dataset,
                            args.batch_size,
                            sort_key=lambda x: len(x.text),
                            train=False,
                            device=device)

        for iteration, bl in enumerate(batches):

            if (iteration + 1) % 25 == 0:
                print("iteration {}".format(iteration + 1))

            ## DATA STEPS
            batch, batch_lens = bl.text
            target, target_lens = bl.target
            role, role_lens = bl.role

            if use_cuda:
                batch = Variable(batch.cuda(), volatile=True)
                role = Variable(role.cuda(), volatile=True)
            else:
                batch = Variable(batch, volatile=True)
                role = Variable(role, volatile=True)

            ## INIT AND DECODE
            hidden = model.init_hidden(args.batch_size)
            # run the model over the seed (all but the last token),
            # feeding the corresponding role id alongside each word
            seq_len = batch.size(1)
            for i in range(seq_len - 1):
                inp = batch[:, i]
                inp = inp.unsqueeze(1)  # add a seq-len dimension of 1
                typ = role[:, i]
                typ = typ.unsqueeze(1)
                _, hidden = model(inp, hidden, typ)

            #print("seq len {}, decode after {} steps".format(seq_len, i+1))
            # seed the beam with the last word in the sequence
            beam_inp = batch[:, i + 1]
            # role input is no longer needed here: the last element's role is assumed to be a preposition
            #role_inp = role[:, i+1]
            #           print("ROLES LIST: {}".format(ROLES))
            #           print("FIRST ID: {}".format(role[:, i+1]))

            # init_beam seeds the beam with the last sequence element; ROLES is a list of role-type ids
            outputs = beam_decode(model,
                                  beam_inp,
                                  hidden,
                                  args.max_len_decode,
                                  args.beam_size,
                                  pad_id,
                                  sos_id,
                                  eos_id,
                                  tup_idx=tup_id,
                                  init_beam=True,
                                  roles=ROLES)
            predicted_events = get_pred_events(outputs, vocab)

            print("CONTEXT: {}".format(
                transform(batch.data.squeeze(), vocab.itos)))
            print("PRED_t: {}".format(
                predicted_events))  # n_best stitched together.

            if (iteration + 1) == args.max_decode:
                print("Max decode reached. Exiting.")
                break

    else:
        print("GEN SEED WITHOUT ROLE EMB")
        dataset = du.LMSentenceDataset(
            args.data,
            vocab,
            src_seq_length=MAX_EVAL_SEQ_LEN,
            min_seq_length=MIN_EVAL_SEQ_LEN)  #put in filter pred later
        batches = BatchIter(dataset,
                            args.batch_size,
                            sort_key=lambda x: len(x.text),
                            train=False,
                            device=device)
        for iteration, bl in enumerate(batches):

            if (iteration + 1) % 25 == 0:
                print("iteration {}".format(iteration + 1))

            ## DATA STEPS
            batch, batch_lens = bl.text
            target, target_lens = bl.target

            if use_cuda:
                batch = Variable(batch.cuda(), volatile=True)
            else:
                batch = Variable(batch, volatile=True)

            ## INIT AND DECODE
            hidden = model.init_hidden(args.batch_size)

            # run the model over the seed (all but the last token)
            seq_len = batch.size(1)
            for i in range(seq_len - 1):
                inp = batch[:, i]
                inp = inp.unsqueeze(1)  # add a seq-len dimension of 1
                _, hidden = model(inp, hidden)

            #print("seq len {}, decode after {} steps".format(seq_len, i+1))
            # seed the beam with the last word in the sequence
            beam_inp = batch[:, i + 1]

            # init_beam seeds the beam with the last sequence element
            outputs = beam_decode(model,
                                  beam_inp,
                                  hidden,
                                  args.max_len_decode,
                                  args.beam_size,
                                  pad_id,
                                  sos_id,
                                  eos_id,
                                  tup_idx=tup_id,
                                  init_beam=True)
            predicted_events = get_pred_events(outputs, vocab)

            print("CONTEXT: {}".format(
                transform(batch.data.squeeze(), vocab.itos)))
            print("PRED_t: {}".format(
                predicted_events))  # n_best stitched together.

            if (iteration + 1) == args.max_decode:
                print("Max decode reached. Exiting.")
                break
Example #3
def get_perplexity(model, vocab):
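    """
    Average the cross-entropy loss returned by calc_perplexity over the
    evaluation batches (optionally conditioning on role embeddings).
    """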
    total_loss = 0.0
    if args.emb_type:  # GET PERPLEXITY WITH ROLE EMB
        print("PERPLEXITY WITH ROLE EMB")
        vocab2 = du.load_vocab(args.vocab2)
        dataset = du.LMRoleSentenceDataset(
            args.data,
            vocab,
            args.role_data,
            vocab2,
            src_seq_length=MAX_EVAL_SEQ_LEN,
            min_seq_length=MIN_EVAL_SEQ_LEN)  #put in filter pred later
        batches = BatchIter(dataset,
                            args.batch_size,
                            sort_key=lambda x: len(x.text),
                            train=False,
                            device=device)

        print("DATASET {}".format(len(dataset)))
        for iteration, bl in enumerate(batches):

            if (iteration + 1) % 25 == 0:
                print("iteration {}".format(iteration + 1))

            ## DATA STEPS
            batch, batch_lens = bl.text
            target, target_lens = bl.target
            role, role_lens = bl.role

            if use_cuda:
                batch = Variable(batch.cuda(), volatile=True)
                target = Variable(target.cuda(), volatile=True)
                role = Variable(role.cuda(), volatile=True)
            else:
                batch = Variable(batch, volatile=True)
                target = Variable(target, volatile=True)
                role = Variable(role, volatile=True)

            ## INIT AND DECODE
            hidden = model.init_hidden(args.batch_size)
            ce_loss = calc_perplexity(args, model, batch, vocab, target,
                                      target_lens, hidden, role)
            #print("Loss {}".format(ce_loss))
            total_loss = total_loss + ce_loss.data[0]

            if (iteration + 1) == args.max_decode:
                print("Max decode reached. Exiting.")
                break

        # after iterating over all examples
        loss = total_loss / (iteration + 1)
        print("Average Loss: {}".format(loss))
        return loss

    else:
        print("PERPLEXITY WITHOUT ROLE EMB")
        dataset = du.LMSentenceDataset(
            args.data,
            vocab,
            src_seq_length=MAX_EVAL_SEQ_LEN,
            min_seq_length=MIN_EVAL_SEQ_LEN)  #put in filter pred later
        batches = BatchIter(dataset,
                            args.batch_size,
                            sort_key=lambda x: len(x.text),
                            train=False,
                            device=device)
        for iteration, bl in enumerate(batches):

            if (iteration + 1) % 25 == 0:
                print("iteration {}".format(iteration + 1))

            ## DATA STEPS
            batch, batch_lens = bl.text
            target, target_lens = bl.target

            if use_cuda:
                batch = Variable(batch.cuda(), volatile=True)
                target = Variable(target.cuda(), volatile=True)
            else:
                batch = Variable(batch, volatile=True)
                target = Variable(target, volatile=True)

            ## INIT AND DECODE
            hidden = model.init_hidden(args.batch_size)
            ce_loss = calc_perplexity(args, model, batch, vocab, target,
                                      target_lens, hidden)
            #print("Loss {}".format(ce_loss))
            total_loss = total_loss + ce_loss.data[0]

            if (iteration + 1) == args.max_decode:
                print("Max decode reached. Exiting.")
                break

        # after iterating over all examples
        loss = total_loss / (iteration + 1)
        print("Average Loss: {}".format(loss))
        return loss
Example #4
def classic_train(args, args_dict, args_info):
    """
    Train the model in the ol' fashioned way, just like grandma used to
    Args
        args (argparse.ArgumentParser)
    """
    if args.cuda and torch.cuda.is_available():
        print("Using cuda")
        use_cuda = True
    elif args.cuda and not torch.cuda.is_available():
        print("You do not have CUDA, turning cuda off")
        use_cuda = False
    else:
        use_cuda = False

    #Load the data
    print("\nLoading Vocab")
    print('args.vocab: ', args.vocab)
    vocab, verb_max_idx = du.load_vocab(args.vocab)
    print("Vocab Loaded, Size {}".format(len(vocab.stoi.keys())))
    print(vocab.itos[:40])
    args_dict["vocab"] = len(vocab.stoi.keys())
    vocab2 = du.load_vocab(args.frame_vocab_address, is_Frame=True)
    print(vocab2.itos[:40])
    print("Frames-Vocab Loaded, Size {}".format(len(vocab2.stoi.keys())))
    total_frames = len(vocab2.stoi.keys())
    args.total_frames = total_frames
    args.num_latent_values = args.total_frames
    print('total frames: ', args.total_frames)
    experiment_name = '{}_eps_{}_num_{}_seed_{}'.format(
        'chain_event', str(args_dict['obsv_prob']), str(args_dict['exp_num']),
        str(args_dict['seed']))

    if args.use_pretrained:
        pretrained = GloVe(name='6B',
                           dim=args.emb_size,
                           unk_init=torch.Tensor.normal_)
        vocab.load_vectors(pretrained)
        print("Vectors Loaded")

    print("Loading Dataset")
    dataset = du.SentenceDataset(path=args.train_data,
                                 path2=args.train_frames,
                                 vocab=vocab,
                                 vocab2=vocab2,
                                 num_clauses=args.num_clauses,
                                 add_eos=False,
                                 is_ref=True,
                                 obsv_prob=args.obsv_prob)

    print("Finished Loading Dataset {} examples".format(len(dataset)))
    batches = BatchIter(dataset,
                        args.batch_size,
                        sort_key=lambda x: len(x.text),
                        train=True,
                        sort_within_batch=True,
                        device=-1)
    data_len = len(dataset)

    if args.load_model:
        print("Loading the Model")
        model = torch.load(args.load_model)
    else:
        print("Creating the Model")
        bidir_mod = 2 if args.bidir else 1
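        # a bidirectional encoder concatenates both directions, doubling the hidden size fed into the latent tree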
        latents = example_tree(
            args.num_latent_values,
            (bidir_mod * args.enc_hid_size, args.latent_dim),
            frame_max=args.total_frames,
            padding_idx=vocab2.stoi['<pad>'],
            use_cuda=use_cuda,
            nohier_mode=args.nohier)  #assume bidirectional

        hidsize = (args.enc_hid_size, args.dec_hid_size)
        model = SSDVAE(args.emb_size,
                       hidsize,
                       vocab,
                       latents,
                       layers=args.nlayers,
                       use_cuda=use_cuda,
                       pretrained=args.use_pretrained,
                       dropout=args.dropout,
                       frame_max=args.total_frames,
                       latent_dim=args.latent_dim,
                       verb_max_idx=verb_max_idx)

    #create the optimizer
    if args.load_opt:
        print("Loading the optimizer state")
        optimizer = torch.load(args.load_opt)
    else:
        print("Creating the optimizer anew")
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    start_time = time.time()  #start of epoch 1
    curr_epoch = 1
    valid_loss = [0.0]
    min_ppl = 1e10
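    # track the best validation perplexity seen so far; a checkpoint is saved only when it improves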
    print("Loading Validation Dataset.")
    val_dataset = du.SentenceDataset(path=args.valid_data,
                                     path2=args.valid_frames,
                                     vocab=vocab,
                                     vocab2=vocab2,
                                     num_clauses=args.num_clauses,
                                     add_eos=False,
                                     is_ref=True,
                                     obsv_prob=0.0,
                                     print_valid=True)

    print("Finished Loading Validation Dataset {} examples.".format(
        len(val_dataset)))
    val_batches = BatchIter(val_dataset,
                            args.batch_size,
                            sort_key=lambda x: len(x.text),
                            train=False,
                            sort_within_batch=True,
                            device=-1)
    for idx, item in enumerate(val_batches):
        # inspect only the first validation batch
        if idx > 0:
            break
        token_rev = [vocab.itos[int(v.numpy())] for v in item.target[0][-1]]
        frame_rev = [vocab2.itos[int(v.numpy())] for v in item.frame[0][-1]]
        ref_frame = [vocab2.itos[int(v.numpy())] for v in item.ref[0][-1]]

        print('token_rev:', token_rev, len(token_rev), "lengths: ",
              item.target[1][-1])
        print('frame_rev:', frame_rev, len(frame_rev), "lengths: ",
              item.frame[1][-1])
        print('ref_frame:', ref_frame, len(ref_frame), "lengths: ",
              item.ref[1][-1])
        print('-' * 50)
    print('Model_named_params: {}'.format(
        [name for name, _ in model.named_parameters()]))

    for iteration, bl in enumerate(
            batches
    ):  # continues indefinitely (reshuffling every epoch) until the epoch limit is reached
        batch, batch_lens = bl.text
        f_vals, f_vals_lens = bl.frame
        target, target_lens = bl.target
        f_ref, _ = bl.ref

        if use_cuda:
            batch = Variable(batch.cuda())
            f_vals = Variable(f_vals.cuda())
        else:
            batch = Variable(batch)
            f_vals = Variable(f_vals)

        model.zero_grad()
        latent_values, latent_root, diff, dec_outputs = model(batch,
                                                              batch_lens,
                                                              f_vals=f_vals)

        topics_dict, real_sentence, next_frames_dict, word_to_frame = show_inference(
            model, batch, vocab, vocab2, f_vals, f_ref, args)
        loss, _ = monolithic_compute_loss(iteration,
                                          model,
                                          target,
                                          target_lens,
                                          latent_values,
                                          latent_root,
                                          diff,
                                          dec_outputs,
                                          use_cuda,
                                          args=args,
                                          topics_dict=topics_dict,
                                          real_sentence=real_sentence,
                                          next_frames_dict=next_frames_dict,
                                          word_to_frame=word_to_frame,
                                          train=True,
                                          show=True)

        # backward propagation
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
        # Optimize
        optimizer.step()

        # run validation every 10 iterations
        if iteration % 10 == 0:
            print("\nFinished Training Epoch/iteration {}/{}".format(
                curr_epoch, iteration))
            # do validation
            valid_logprobs = 0.0
            valid_lengths = 0.0
            valid_loss = 0.0
            with torch.no_grad():
                for v_iteration, bl in enumerate(val_batches):
                    batch, batch_lens = bl.text
                    f_vals, f_vals_lens = bl.frame
                    target, target_lens = bl.target
                    f_ref, _ = bl.ref
                    batch_lens = batch_lens.cpu()
                    if use_cuda:
                        batch = Variable(batch.cuda())
                        f_vals = Variable(f_vals.cuda())
                    else:
                        batch = Variable(batch)
                        f_vals = Variable(f_vals)
                    latent_values, latent_root, diff, dec_outputs = model(
                        batch, batch_lens, f_vals=f_vals)
                    topics_dict, real_sentence, next_frames_dict, word_to_frame = show_inference(
                        model, batch, vocab, vocab2, f_vals, f_ref, args)
                    loss, ce_loss = monolithic_compute_loss(
                        iteration,
                        model,
                        target,
                        target_lens,
                        latent_values,
                        latent_root,
                        diff,
                        dec_outputs,
                        use_cuda,
                        args=args,
                        topics_dict=topics_dict,
                        real_sentence=real_sentence,
                        next_frames_dict=next_frames_dict,
                        word_to_frame=word_to_frame,
                        train=False,
                        show=False)

                    valid_loss = valid_loss + ce_loss.data.clone()
                    valid_logprobs += (ce_loss.data.clone().cpu().numpy() *
                                       target_lens.sum().cpu().data.numpy())
                    valid_lengths += target_lens.sum().cpu().data.numpy()
                    # print("valid_lengths: ",valid_lengths[0])

            nll = valid_logprobs / valid_lengths
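            # nll is the length-weighted average negative log-likelihood per token; perplexity is its exponential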
            ppl = np.exp(nll)
            valid_loss = valid_loss / (v_iteration + 1)
            print("**Validation loss {:.2f}.**\n".format(valid_loss[0]))
            print("**Validation NLL {:.2f}.**\n".format(nll))
            print("**Validation PPL {:.2f}.**\n".format(ppl))
            args_dict_wandb = {
                "val_nll": nll,
                "val_ppl": ppl,
                "valid_loss": valid_loss
            }
            if ppl < min_ppl:
                min_ppl = ppl
                args_dict["min_ppl"] = min_ppl
                dir_path = os.path.dirname(os.path.realpath(__file__))
                save_file = "".join([
                    "_" + str(key) + "_" + str(value)
                    for key, value in args_dict.items() if key != "min_ppl"
                ])
                args_to_md(model="chain", args_dict=args_dict)
                model_path = os.path.join(dir_path, "saved_models",
                                          "chain_" + save_file + ".pt")
                torch.save(model, model_path)
                config_path = os.path.join(dir_path, "saved_configs",
                                           "chain_" + save_file + ".pkl")
                with open(config_path, "wb") as f:
                    pickle.dump((args_dict, args_info), f)
            print('\t==> min_ppl {:4.4f} '.format(min_ppl))
Example #5
def generate(args):
    """
    Use the trained model for decoding
    Args
        args (argparse.ArgumentParser)
    """
    if args.cuda and torch.cuda.is_available():
        device = 0
        use_cuda = True
    elif args.cuda and not torch.cuda.is_available():
        print("You do not have CUDA, turning cuda off")
        device = -1
        use_cuda = False
    else:
        device = -1
        use_cuda = False

    #Load the vocab
    vocab = du.load_vocab(args.vocab)
    eos_id = vocab.stoi[EOS_TOK]
    pad_id = vocab.stoi[PAD_TOK]

    if args.ranking:  # default is the HARD one, the 'Inverse Narrative Cloze' in the paper
        dataset = du.NarrativeClozeDataset(args.valid_data,
                                           vocab,
                                           src_seq_length=MAX_EVAL_SEQ_LEN,
                                           min_seq_length=MIN_EVAL_SEQ_LEN,
                                           LM=False)
        # Batch size during decoding is set to 1
        batches = BatchIter(dataset,
                            1,
                            sort_key=lambda x: len(x.actual),
                            train=False,
                            device=-1)
    else:
        dataset = du.SentenceDataset(args.valid_data,
                                     vocab,
                                     src_seq_length=MAX_EVAL_SEQ_LEN,
                                     min_seq_length=MIN_EVAL_SEQ_LEN,
                                     add_eos=False)  #put in filter pred later
        # Batch size during decoding is set to 1
        batches = BatchIter(dataset,
                            args.batch_size,
                            sort_key=lambda x: len(x.text),
                            train=False,
                            device=-1)

    data_len = len(dataset)

    #Create the model
    with open(args.load, 'rb') as fi:
        if not use_cuda:
            model = torch.load(fi, map_location=lambda storage, loc: storage)
        else:
            model = torch.load(fi, map_location=torch.device('cuda'))

    if not hasattr(model.latent_root, 'nohier'):
        model.latent_root.set_nohier(args.nohier)  #for backwards compatibility

    model.decoder.eval()
    model.set_use_cuda(use_cuda)

    #For reconstruction
    if args.perplexity:
        loss = calc_perplexity(args, model, batches, vocab, data_len)
        print("Loss = {}".format(loss))
    elif args.schema:
        generate_from_seed(args, model, batches, vocab, data_len)
    elif args.ranking:
        do_ranking(args, model, batches, vocab, data_len, use_cuda)
    else:
        #        sample_outputs(model, vocab)
        reconstruct(args, model, batches, vocab)
Example #6
def train(args):
    """
    Train the model in the ol' fashioned way, just like grandma used to
    Args
        args (argparse.ArgumentParser)
    """
    #Load the data
    logging.info("Loading Vocab")
    evocab = du.load_vocab(args.evocab)
    tvocab = du.load_vocab(args.tvocab)
    logging.info("Event Vocab Loaded, Size {}".format(len(evocab.stoi.keys())))
    logging.info("Text Vocab Loaded, Size {}".format(len(tvocab.stoi.keys())))

    if args.use_pretrained:
        pretrained = GloVe(name='6B',
                           dim=args.text_embed_size,
                           unk_init=torch.Tensor.normal_)
        tvocab.load_vectors(pretrained)
        logging.info("Loaded Pretrained Word Embeddings")

    if args.load_model:
        logging.info("Loading the Model")
        model = torch.load(args.load_model, map_location=args.device)
    else:
        logging.info("Creating the Model")
        if args.onehot_events:
            logging.info(
                "Model Type: SemiNaiveAdjustmentEstimatorOneHotEvents")
            model = estimators.SemiNaiveAdjustmentEstimatorOneHotEvents(
                args, evocab, tvocab)
        else:
            logging.info("Model Type: SemiNaiveAdjustmentEstimator")
            model = estimators.SemiNaiveAdjustmentEstimator(
                args, evocab, tvocab)

    if args.finetune:
        assert args.load_model
        logging.info("Finetuning...")
        if args.freeze:
            logging.info("Freezing...")
            for param in model.parameters():
                param.requires_grad = False
        model = estimators.AdjustmentEstimator(args, evocab, tvocab, model)

        #Still finetune the last layer even if freeze is on (when freeze is on, everything else stays frozen)
        model.expected_outcome.event_text_logits_mlp.weight.requires_grad = True
        model.expected_outcome.event_text_logits_mlp.bias.requires_grad = True

        logging.info("Trainable Params: {}".format(
            [x[0] for x in model.named_parameters() if x[1].requires_grad]))

    model = model.to(device=args.device)

    #create the optimizer
    if args.load_opt:
        logging.info("Loading the optimizer state")
        optimizer = torch.load(args.load_opt)
    else:
        if args.optimizer == 'adagrad':
            logging.info("Creating Adagrad optimizer anew")
            optimizer = torch.optim.Adagrad(filter(lambda x: x.requires_grad,
                                                   model.parameters()),
                                            lr=args.lr)
        elif args.optimizer == 'sgd':
            logging.info("Creating SGD optimizer anew")
            optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad,
                                               model.parameters()),
                                        lr=args.lr)
        else:
            logging.info("Creating Adam optimizer anew")
            optimizer = torch.optim.Adam(filter(lambda x: x.requires_grad,
                                                model.parameters()),
                                         lr=args.lr)

    logging.info("Loading Datasets")
    min_size = model.text_encoder.largest_ngram_size  #Add extra pads if text size smaller than largest CNN kernel size

    if args.load_pickle:
        logging.info("Loading Train from Pickled Data")
        with open(args.train_data, 'rb') as pfi:
            pickled_examples = pickle.load(pfi)
        train_dataset = du.InstanceDataset("",
                                           evocab,
                                           tvocab,
                                           min_size=min_size,
                                           pickled_examples=pickled_examples)
    else:
        train_dataset = du.InstanceDataset(args.train_data,
                                           evocab,
                                           tvocab,
                                           min_size=min_size)
    valid_dataset = du.InstanceDataset(args.valid_data,
                                       evocab,
                                       tvocab,
                                       min_size=min_size)

    #Remove UNK events from the e1prev_intext attribute so they don't mess up avg encoders
    #  train_dataset.filter_examples(['e1prev_intext'])  #These take a really long time! Will have to figure something out...
    #  valid_dataset.filter_examples(['e1prev_intext'])
    logging.info("Finished Loading Training Dataset {} examples".format(
        len(train_dataset)))
    logging.info("Finished Loading Valid Dataset {} examples".format(
        len(valid_dataset)))

    train_batches = BatchIter(train_dataset,
                              args.batch_size,
                              sort_key=lambda x: len(x.allprev),
                              train=True,
                              repeat=False,
                              shuffle=True,
                              sort_within_batch=True,
                              device=None)
    valid_batches = BatchIter(valid_dataset,
                              args.batch_size,
                              sort_key=lambda x: len(x.allprev),
                              train=False,
                              repeat=False,
                              shuffle=False,
                              sort_within_batch=True,
                              device=None)
    train_data_len = len(train_dataset)
    valid_data_len = len(valid_dataset)

    loss_func = nn.CrossEntropyLoss()
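    # expected-outcome (e2) prediction is treated as classification over the event vocabulary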

    start_time = time.time()  #start of epoch 1
    best_valid_loss = float('inf')
    best_epoch = args.epochs

    if args.finetune:
        vloss = validation(args, valid_batches, model, loss_func)
        logging.info("Pre Finetune Validation Loss: {}".format(vloss))

    #MAIN TRAINING LOOP
    for curr_epoch in range(args.epochs):
        prev_losses = []
        for iteration, inst in enumerate(train_batches):
            instance = du.send_instance_to(inst, args.device)

            model.train()
            model.zero_grad()
            model_outputs = model(instance)

            exp_outcome_out = model_outputs[
                EXP_OUTCOME_COMPONENT]  # [batch x num_events], output prediction for e2
            exp_outcome_loss = loss_func(exp_outcome_out, instance.e2)
            loss = exp_outcome_loss

            loss.backward()
            torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
            optimizer.step()

            prev_losses.append(loss.cpu().data)
            prev_losses = prev_losses[-50:]
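            # keep a rolling window of the last 50 losses for the running average logged below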

            if (iteration % args.log_every == 0) and iteration != 0:
                past_50_avg = sum(prev_losses) / len(prev_losses)
                logging.info(
                    "Epoch/iteration {}/{}, Past 50 Average Loss {}, Best Val {} at Epoch {}"
                    .format(
                        curr_epoch, iteration, past_50_avg, 'NA' if
                        best_valid_loss == float('inf') else best_valid_loss,
                        'NA' if best_epoch == args.epochs else best_epoch))

            if (iteration % args.validate_after == 0) and iteration != 0:
                logging.info(
                    "Running Validation at Epoch/iteration {}/{}".format(
                        curr_epoch, iteration))
                new_valid_loss = validation(args, valid_batches, model,
                                            loss_func)
                logging.info(
                    "Validation loss at Epoch/iteration {}/{}: {:.3f} - Best Validation Loss: {:.3f}"
                    .format(curr_epoch, iteration, new_valid_loss,
                            best_valid_loss))
                if new_valid_loss < best_valid_loss:
                    logging.info(
                        "New Validation Best...Saving Model Checkpoint")
                    best_valid_loss = new_valid_loss
                    best_epoch = curr_epoch
                    #torch.save(model, "{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, curr_epoch, best_valid_loss))
                    #torch.save(optimizer, "{}.{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, "optimizer", curr_epoch, best_valid_loss))
                    torch.save(model, "{}".format(args.save_model))
                    torch.save(optimizer,
                               "{}_optimizer".format(args.save_model))

        #END OF EPOCH
        logging.info("End of Epoch {}, Running Validation".format(curr_epoch))
        new_valid_loss = validation(args, valid_batches, model, loss_func)
        logging.info(
            "Validation loss at end of Epoch {}: {:.3f} - Best Validation Loss: {:.3f}"
            .format(curr_epoch, new_valid_loss, best_valid_loss))
        if new_valid_loss < best_valid_loss:
            logging.info("New Validation Best...Saving Model Checkpoint")
            best_valid_loss = new_valid_loss
            best_epoch = curr_epoch
            #torch.save(model, "{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, curr_epoch, best_valid_loss))
            #torch.save(optimizer, "{}.{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, "optimizer", curr_epoch, best_valid_loss))
            torch.save(model, "{}".format(args.save_model))
            torch.save(optimizer, "{}_optimizer".format(args.save_model))

        if curr_epoch - best_epoch >= args.stop_after:
            logging.info(
                "No improvement in {} epochs, terminating at epoch {}...".
                format(args.stop_after, curr_epoch))
            logging.info("Best Validation Loss: {:.2f} at Epoch {}".format(
                best_valid_loss, best_epoch))
            break
Example #7
def classic_train(args):
    """
    Train the model in the ol' fashioned way, just like grandma used to
    Args
        args (argparse.ArgumentParser)
    """
    if args.cuda and torch.cuda.is_available():
        print("Using cuda")
        use_cuda = True
    elif args.cuda and not torch.cuda.is_available():
        print("You do not have CUDA, turning cuda off")
        use_cuda = False
    else:
        use_cuda = False

    #Load the data
    print("\nLoading Vocab")
    vocab = du.load_vocab(args.vocab)
    print("Vocab Loaded, Size {}".format(len(vocab.stoi.keys())))

    if args.use_pretrained:
        pretrained = GloVe(name='6B',
                           dim=args.emb_size,
                           unk_init=torch.Tensor.normal_)
        vocab.load_vectors(pretrained)
        print("Vectors Loaded")

    print("Loading Dataset")
    dataset = du.SentenceDataset(args.train_data,
                                 vocab,
                                 args.src_seq_length,
                                 add_eos=False)  #put in filter pred later
    print("Finished Loading Dataset {} examples".format(len(dataset)))
    batches = BatchIter(dataset,
                        args.batch_size,
                        sort_key=lambda x: len(x.text),
                        train=True,
                        sort_within_batch=True,
                        device=-1)
    data_len = len(dataset)

    if args.load_model:
        print("Loading the Model")
        model = torch.load(args.load_model)
    else:
        print("Creating the Model")
        bidir_mod = 2 if args.bidir else 1
        latents = example_tree(
            args.num_latent_values,
            (bidir_mod * args.enc_hid_size, args.latent_dim),
            use_cuda=use_cuda)  #assume bidirectional
        hidsize = (args.enc_hid_size, args.dec_hid_size)
        model = DAVAE(args.emb_size,
                      hidsize,
                      vocab,
                      latents,
                      layers=args.nlayers,
                      use_cuda=use_cuda,
                      pretrained=args.use_pretrained,
                      dropout=args.dropout)

    #create the optimizer
    if args.load_opt:
        print("Loading the optimizer state")
        optimizer = torch.load(args.load_opt)
    else:
        print("Creating the optimizer anew")
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    start_time = time.time()  #start of epoch 1
    curr_epoch = 1
    valid_loss = [0.0]
    for iteration, bl in enumerate(
            batches
    ):  # continues indefinitely (reshuffling every epoch) until the epoch limit is reached
        batch, batch_lens = bl.text
        target, target_lens = bl.target

        if use_cuda:
            batch = Variable(batch.cuda())
        else:
            batch = Variable(batch)

        model.zero_grad()
        latent_values, latent_root, diff, dec_outputs = model(
            batch, batch_lens)
        # train set to True so returns total loss
        loss, _ = monolithic_compute_loss(iteration,
                                          model,
                                          target,
                                          target_lens,
                                          latent_values,
                                          latent_root,
                                          diff,
                                          dec_outputs,
                                          use_cuda,
                                          args=args)

        # backward propagation
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
        # Optimize
        optimizer.step()

        # End of an epoch or a validate_after interval - run validation
        if ((args.batch_size * iteration) % data_len == 0
                or iteration % args.validate_after == 0) and iteration != 0:
            print("\nFinished Training Epoch/iteration {}/{}".format(
                curr_epoch, iteration))

            # do validation
            print("Loading Validation Dataset.")
            val_dataset = du.SentenceDataset(
                args.valid_data, vocab, args.src_seq_length,
                add_eos=False)  #put in filter pred later
            print("Finished Loading Validation Dataset {} examples.".format(
                len(val_dataset)))
            val_batches = BatchIter(val_dataset,
                                    args.batch_size,
                                    sort_key=lambda x: len(x.text),
                                    train=False,
                                    sort_within_batch=True,
                                    device=-1)
            valid_loss = 0.0
            for v_iteration, bl in enumerate(val_batches):
                batch, batch_lens = bl.text
                target, target_lens = bl.target
                batch_lens = batch_lens.cpu()
                if use_cuda:
                    batch = Variable(batch.cuda(), volatile=True)
                else:
                    batch = Variable(batch, volatile=True)

                latent_values, latent_root, diff, dec_outputs = model(
                    batch, batch_lens)
                # train set to False so returns only CE loss
                loss, ce_loss = monolithic_compute_loss(iteration,
                                                        model,
                                                        target,
                                                        target_lens,
                                                        latent_values,
                                                        latent_root,
                                                        diff,
                                                        dec_outputs,
                                                        use_cuda,
                                                        args=args,
                                                        train=False)
                valid_loss = valid_loss + ce_loss.data.clone()

            valid_loss = valid_loss / (v_iteration + 1)
            print("**Validation loss {:.2f}.**\n".format(valid_loss[0]))

            # Check max epochs and break
            if (args.batch_size * iteration) % data_len == 0:
                curr_epoch += 1
            if curr_epoch > args.epochs:
                print("Max epoch {}-{} reached. Exiting.\n".format(
                    curr_epoch, args.epochs))
                break

        # Save the checkpoint
        if iteration % args.save_after == 0 and iteration != 0:
            print("Saving checkpoint for epoch {} at {}.\n".format(
                curr_epoch, args.save_model))
            # curr_epoch and validation stats appended to the model name
            torch.save(
                model, "{}_{}_{}_.epoch_{}.loss_{:.2f}.pt".format(
                    args.save_model, args.commit_c, args.commit2_c, curr_epoch,
                    float(valid_loss[0])))
            torch.save(
                optimizer,
                "{}.{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model,
                                                       "optimizer", curr_epoch,
                                                       float(valid_loss[0])))
Example #8
def do_training(use_cuda=True):
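    """
    Train the LSTM language model (optionally with role-type embeddings)
    with masked cross-entropy, logging every log_interval iterations and
    checkpointing every save_after iterations.
    """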

    # Using our data utils to load data
    vocab = du.load_vocab(args.vocab)
    nvocab = len(vocab.stoi.keys())
    print("*Vocab Loaded, Size {}".format(len(vocab.stoi.keys())))

    if args.pretrained:
        print("using pretrained vectors.")
        pretrained = GloVe(name='6B',
                           dim=args.emsize,
                           unk_init=torch.Tensor.normal_)
        vocab.load_vectors(pretrained)
        print("Vectors Loaded")

    if args.emb_type:
        vocab2 = du.load_vocab(args.vocab2)
        nvocab2 = len(vocab2.stoi.keys())
        print("*Vocab2 Loaded, Size {}".format(len(vocab2.stoi.keys())))

        dataset = du.LMRoleSentenceDataset(args.train_data, vocab,
                                           args.train_type_data, vocab2)
        print("*Train Dataset Loaded {} examples".format(len(dataset)))

        # Build the model: word emb + type emb
        model = LSTMLM(args.emsize,
                       args.nhidden,
                       args.nlayers,
                       nvocab,
                       pretrained=args.pretrained,
                       vocab=vocab,
                       type_emb=args.emb_type,
                       ninput2=args.em2size,
                       nvocab2=nvocab2,
                       dropout=args.dropout,
                       use_cuda=use_cuda)
        print("Building word+type emb model.")

    else:

        dataset = du.LMSentenceDataset(args.train_data, vocab)
        print("*Train Dataset Loaded {} examples".format(len(dataset)))

        # Build the model: word emb
        model = LSTMLM(args.emsize,
                       args.nhidden,
                       args.nlayers,
                       nvocab,
                       pretrained=args.pretrained,
                       vocab=vocab,
                       dropout=args.dropout,
                       use_cuda=use_cuda)
        print("Building word emb model.")

    data_len = len(dataset)
    batches = BatchIter(dataset,
                        args.batch_size,
                        sort_key=lambda x: len(x.text),
                        train=True,
                        sort_within_batch=True,
                        device=-1)

    ## some checks
    tally_parameters(model)

    if use_cuda:
        model = model.cuda()

    lr = args.lr

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    val_loss = [0.0]

    # DO TRAINING

    total_loss = 0.0
    lapse = 1
    faulty = False
    for iteration, bl in enumerate(batches):

        # batch is [batch_size, seq_len]
        batch, batch_lens = bl.text
        if args.emb_type:
            role, role_lens = bl.role

        target, target_lens = bl.target

        # init the hidden state before every batch
        hidden = model.init_hidden(batch.size(0))  # use the actual batch size rather than args.batch_size

        # batch has SOS prepended to it.
        # target has EOS appended to it.
        if use_cuda:
            batch = Variable(batch.cuda())
            target = Variable(target.cuda())
            if args.emb_type:
                role = Variable(role.cuda())
        else:
            batch = Variable(batch)
            target = Variable(target)
            if args.emb_type:
                role = Variable(role)

        # Repackaging is not needed.

        # zero the gradients
        model.zero_grad()
        # run the model
        logits = []
        for i in range(batch.size(1)):
            inp = batch[:, i]
            inp = inp.unsqueeze(1)
            if args.emb_type:
                # handle out-of-index errors by breaking out of the inner loop and skipping this batch
                try:
                    typ = role[:, i]
                    typ = typ.unsqueeze(1)
                    logit, hidden = model(inp, hidden, typ)
                except Exception as e:
                    print("ALERT!! word and type batch error. {}".format(e))
                    faulty = True
                    break
            else:
                # keep updating the hidden state accordingly
                logit, hidden = model(inp, hidden)

            logits += [logit]

        # if this batch was faulty; continue to the next iteration
        if faulty:
            faulty = False
            continue

        # logits is [batch_size, seq_len, vocab_size]
        logits = torch.stack(logits, dim=1)
        if use_cuda:
            loss = masked_cross_entropy(logits, target,
                                        Variable(target_lens.cuda()))
        else:
            loss = masked_cross_entropy(logits, target, Variable(target_lens))

        loss.backward()

        torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)

        # optimize
        optimizer.step()

        # aggregate the stats
        total_loss = total_loss + loss.data.clone()
        lapse += 1

        # print based on log interval
        if (iteration + 1) % args.log_interval == 0:
            print("| iteration {} | loss {:5.2f}".format(
                iteration + 1, loss.data[0]))

        # forcing buffers to write
        sys.stdout.flush()

        # saving only after specified iterations
        if (iteration + 1) % args.save_after == 0:
            # summarize every save after num iterations losses
            avg_loss = total_loss / lapse
            print("||| iteration {} | average loss {:5.2f}".format(
                iteration + 1,
                avg_loss.cpu().numpy()[0]))
            # reset values
            total_loss = 0.0
            lapse = 1

            #torch.save(model, "{}_.epoch_{}.iteration_{}.loss_{:.2f}.pt".format(args.save, curr_epoch, iteration+1, val_loss[0]))
            torch.save(model,
                       "{}_.iteration_{}.pt".format(args.save, iteration + 1))
            torch.save(
                optimizer,
                "{}.{}.iteration_{}.pt".format(args.save, "optimizer",
                                               iteration + 1))
            print(
                "model and optimizer saved for iteration {}".format(iteration +
                                                                    1))
Example #9
def generate(args):
    """
    Use the trained model for decoding
    Args
        args (argparse.ArgumentParser)
    """
    if args.cuda and torch.cuda.is_available():
        device = 0
        use_cuda = True
    elif args.cuda and not torch.cuda.is_available():
        print("You do not have CUDA, turning cuda off")
        device = -1
        use_cuda = False
    else:
        device = -1
        use_cuda = False

    #Load the vocab
    # vocab = du.load_vocab(args.vocab)
    vocab, _ = du.load_vocab(args.vocab)
    vocab2 = du.load_vocab(args.frame_vocab_address, is_Frame=True)

    eos_id = vocab.stoi[EOS_TOK]
    pad_id = vocab.stoi[PAD_TOK]
    if args.ranking:  # default is the HARD one, the 'Inverse Narrative Cloze' in the paper
        dataset = du.NarrativeClozeDataset(args.valid_narr,
                                           vocab,
                                           src_seq_length=MAX_EVAL_SEQ_LEN,
                                           min_seq_length=MIN_EVAL_SEQ_LEN,
                                           LM=False)
        print('ranking_dataset: ', len(dataset))
        # Batch size during decoding is set to 1
        batches = BatchIter(dataset,
                            1,
                            sort_key=lambda x: len(x.actual),
                            train=False,
                            device=-1)
    else:
        # dataset = du.SentenceDataset(args.valid_data, vocab, src_seq_length=MAX_EVAL_SEQ_LEN, min_seq_length=MIN_EVAL_SEQ_LEN, add_eos=False) #put in filter pred later
        dataset = du.SentenceDataset(path=args.valid_data,
                                     path2=args.valid_frames,
                                     vocab=vocab,
                                     vocab2=vocab2,
                                     num_clauses=args.num_clauses,
                                     add_eos=False,
                                     is_ref=True,
                                     obsv_prob=0.0,
                                     print_valid=True)
        # Batch size during decoding is set to 1
        batches = BatchIter(dataset,
                            args.batch_size,
                            sort_key=lambda x: len(x.text),
                            train=False,
                            device=-1)

    data_len = len(dataset)

    #Create the model
    with open(args.load, 'rb') as fi:
        if not use_cuda:
            model = torch.load(fi, map_location=lambda storage, loc: storage)
        else:
            model = torch.load(fi, map_location=torch.device('cuda'))

    if not hasattr(model.latent_root, 'nohier'):
        model.latent_root.set_nohier(args.nohier)  #for backwards compatibility

    model.decoder.eval()
    model.set_use_cuda(use_cuda)

    #For reconstruction
    if args.perplexity:
        print('calculating perplexity')
        loss = calc_perplexity(args, model, batches, vocab, data_len)
        NLL = loss
        PPL = np.exp(loss)  # perplexity is the exponentiated average NLL
        print("Chain-NLL = {}".format(NLL))
        print("Chain-PPL = {}".format(PPL))
        return PPL
    elif args.schema:
        generate_from_seed(args, model, batches, vocab, data_len)
    elif args.ranking:
        ranked_acc = do_ranking(args, model, batches, vocab, data_len,
                                use_cuda)
        return ranked_acc
    else:
        #        sample_outputs(model, vocab)
        reconstruct(args, model, batches, vocab)
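Loading a pickled model onto whichever device is actually available, as generate() does above, is a common portability pattern. A minimal sketch, assuming a hypothetical checkpoint path ckpt_path:

import torch

def load_checkpoint(ckpt_path, use_cuda):
    # map_location remaps the saved tensors: onto the GPU when CUDA is
    # available, otherwise onto CPU, so GPU-trained checkpoints still load
    with open(ckpt_path, 'rb') as fi:
        if use_cuda:
            return torch.load(fi, map_location=torch.device('cuda'))
        return torch.load(fi, map_location=lambda storage, loc: storage)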
Example #10
def train(args):
    """
    Train the model in the ol' fashioned way, just like grandma used to
    Args
        args (argparse.Namespace): parsed command-line arguments
    """
    #Load the data
    logging.info("Loading Vocab")
    evocab = du.load_vocab(args.evocab)
    logging.info("Event Vocab Loaded, Size {}".format(len(evocab.stoi.keys())))

    # extend the vocab with SOS/EOS; stoi and itos must stay in sync,
    # which the asserts below verify
    evocab.stoi[SOS_TOK] = len(evocab.itos)
    evocab.itos.append(SOS_TOK)

    evocab.stoi[EOS_TOK] = len(evocab.itos)
    evocab.itos.append(EOS_TOK)

    assert evocab.stoi[EOS_TOK] == evocab.itos.index(EOS_TOK)
    assert evocab.stoi[SOS_TOK] == evocab.itos.index(SOS_TOK)

    if args.load_model:
        logging.info("Loading the Model")
        model = torch.load(args.load_model, map_location=args.device)
    else:
        logging.info("Creating the Model")
        model = LM.EventLM(args.event_embed_size,
                           args.rnn_hidden_dim,
                           args.rnn_layers,
                           len(evocab.itos),
                           dropout=args.dropout)

    model = model.to(device=args.device)

    #create the optimizer
    if args.load_opt:
        logging.info("Loading the optimizer state")
        optimizer = torch.load(args.load_opt)
    else:
        logging.info("Creating the optimizer anew")
        optimizer = torch.optim.Adam(filter(lambda x: x.requires_grad,
                                            model.parameters()),
                                     lr=args.lr)
    #  optimizer = torch.optim.Adagrad(model.parameters(), lr=args.lr)

    logging.info("Loading Datasets")

    train_dataset = du.LmInstanceDataset(args.train_data, evocab)
    valid_dataset = du.LmInstanceDataset(args.valid_data, evocab)

    #Remove UNK events from the e1prev_intext attribute so they don't mess up avg encoders
    #  train_dataset.filter_examples(['e1prev_intext'])  #These take a really long time! Will have to figure something out...
    #  valid_dataset.filter_examples(['e1prev_intext'])
    logging.info("Finished Loading Training Dataset {} examples".format(
        len(train_dataset)))
    logging.info("Finished Loading Valid Dataset {} examples".format(
        len(valid_dataset)))

    train_batches = BatchIter(train_dataset,
                              args.batch_size,
                              sort_key=lambda x: len(x.text),
                              train=True,
                              repeat=False,
                              shuffle=True,
                              sort_within_batch=True,
                              device=None)
    valid_batches = BatchIter(valid_dataset,
                              args.batch_size,
                              sort_key=lambda x: len(x.text),
                              train=False,
                              repeat=False,
                              shuffle=False,
                              sort_within_batch=True,
                              device=None)
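    # sort_within_batch=True keeps each batch sorted by length (via sort_key),
    # which packed-sequence RNN encoders typically require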
    train_data_len = len(train_dataset)
    valid_data_len = len(valid_dataset)

    start_time = time.time()  #start of epoch 1
    best_valid_loss = float('inf')
    best_epoch = args.epochs

    #MAIN TRAINING LOOP
    for curr_epoch in range(args.epochs):
        prev_losses = []
        for iteration, inst in enumerate(train_batches):
            instance = du.lm_send_instance_to(inst, args.device)

            # use the device-moved instance, not the raw batch
            text_inst, text_lens = instance.text
            target_inst, target_lens = instance.target

            model.train()
            model.zero_grad()

            logits = []
            hidden = None
            for step in range(text_inst.size(1)):
                step_inp = text_inst[:, step]  #get all instances for this step
                step_inp = step_inp.unsqueeze(1)  #[batch X 1]

                logit_i, hidden = model(step_inp, hidden)
                logits += [logit_i]
            logits = torch.stack(logits, dim=1)  #[batch, seq_len, vocab]
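            # NOTE: unrolling one timestep at a time mirrors step-by-step
            # decoding; if the wrapped RNN accepts whole sequences, a single
            # batched call over text_inst would compute the same logits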

            loss = masked_cross_entropy(logits, target_inst, target_lens)
            loss.backward()
            # clip_grad_norm was renamed clip_grad_norm_ in newer PyTorch
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()

            prev_losses.append(loss.item())  # detached scalar loss
            prev_losses = prev_losses[-50:]  # keep a 50-iteration window

            if (iteration % args.log_every == 0) and iteration != 0:
                past_50_avg = sum(prev_losses) / len(prev_losses)
                logging.info(
                    "Epoch/iteration {}/{}, Past 50 Average Loss {}, Best Val {} at Epoch {}"
                    .format(
                        curr_epoch, iteration, past_50_avg, 'NA' if
                        best_valid_loss == float('inf') else best_valid_loss,
                        'NA' if best_epoch == args.epochs else best_epoch))

            if (iteration % args.validate_after == 0) and iteration != 0:
                logging.info(
                    "Running Validation at Epoch/iteration {}/{}".format(
                        curr_epoch, iteration))
                new_valid_loss = validation(args, valid_batches, model)
                logging.info(
                    "Validation loss at Epoch/iteration {}/{}: {:.3f} - Best Validation Loss: {:.3f}"
                    .format(curr_epoch, iteration, new_valid_loss,
                            best_valid_loss))
                if new_valid_loss < best_valid_loss:
                    logging.info(
                        "New Validation Best...Saving Model Checkpoint")
                    best_valid_loss = new_valid_loss
                    best_epoch = curr_epoch
                    #torch.save(model, "{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, curr_epoch, best_valid_loss))
                    #torch.save(optimizer, "{}.{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, "optimizer", curr_epoch, best_valid_loss))
                    torch.save(model, "{}".format(args.save_model))
                    torch.save(optimizer,
                               "{}_optimizer".format(args.save_model))

        #END OF EPOCH
        logging.info("End of Epoch {}, Running Validation".format(curr_epoch))
        new_valid_loss = validation(args, valid_batches, model)
        logging.info(
            "Validation loss at end of Epoch {}: {:.3f} - Best Validation Loss: {:.3f}"
            .format(curr_epoch, new_valid_loss, best_valid_loss))
        if new_valid_loss < best_valid_loss:
            logging.info("New Validation Best...Saving Model Checkpoint")
            best_valid_loss = new_valid_loss
            best_epoch = curr_epoch
            #torch.save(model, "{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, curr_epoch, best_valid_loss))
            #torch.save(optimizer, "{}.{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, "optimizer", curr_epoch, best_valid_loss))
            torch.save(model, "{}".format(args.save_model))
            torch.save(optimizer, "{}_optimizer".format(args.save_model))

        if curr_epoch - best_epoch >= args.stop_after:
            logging.info(
                "No improvement in {} epochs, terminating at epoch {}...".
                format(args.stop_after, curr_epoch))
            logging.info("Best Validation Loss: {:.2f} at Epoch {}".format(
                best_valid_loss, best_epoch))
            break
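masked_cross_entropy is not shown in this excerpt. A minimal sketch of one common implementation, assuming logits of shape [batch, seq_len, vocab], integer targets of shape [batch, seq_len], and true lengths of shape [batch] (all names here are hypothetical):

import torch
import torch.nn.functional as F

def masked_cross_entropy_sketch(logits, targets, lengths):
    batch, seq_len, vocab = logits.size()
    # per-token negative log-likelihood of the gold targets
    log_probs = F.log_softmax(logits.reshape(-1, vocab), dim=1)
    nll = -log_probs.gather(1, targets.reshape(-1, 1)).squeeze(1)
    nll = nll.view(batch, seq_len)
    # mask out padded positions past each sequence's true length
    positions = torch.arange(seq_len, device=logits.device).unsqueeze(0)
    mask = (positions < lengths.to(logits.device).unsqueeze(1)).float()
    # average loss over the real (unmasked) tokens only
    return (nll * mask).sum() / mask.sum()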