Example #1
def generate(use_cuda=False, device=-1):

    vocab = du.load_vocab(args.vocab)
    eos_id = vocab.stoi[EOS_TOK]
    pad_id = vocab.stoi[PAD_TOK]
    sos_id = vocab.stoi[SOS_TOK]
    tup_id = vocab.stoi[TUP_TOK]

    # Only one task flag may be set at a time.
    assert sum(map(bool, [args.perplexity, args.seed, args.ranking])) <= 1, \
        "Only 1 can be True at a time."
    # Batch size during decoding is set to 1
    assert args.batch_size == 1, "Set batch size to 1 during decoding."

    # Load the model.
    with open(args.model, 'rb') as f:
        model = torch.load(f, map_location=lambda s, loc: s)

    # set the eval mode
    model.eval()
    # to decode without cuda
    model.use_cuda = False

    # TEMP FIX to work with old models without this parameter.
    #model.type_emb = None

    # TASK SPECIFIC FUNCTION CALLS
    if args.ranking:
        # HARD only. Easy one has been deactivated.
        do_ranking(model, vocab)
    elif args.perplexity:
        get_perplexity(model, vocab)
    elif args.seed:
        gen_from_seed(model, vocab, eos_id, pad_id, sos_id, tup_id)
    else:
        print("NOT IMPLEMENTED. RETURNING.")
        return
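
The functions above read all task options from a module-level args object. A minimal sketch of how those flags might be declared with argparse follows; the flag set is hypothetical and only mirrors the args.* attributes used in these examples, the real parser is not shown here.

import argparse

# Hypothetical parser mirroring the args.* attributes used in these examples.
parser = argparse.ArgumentParser(description="Decode/evaluate a trained event language model.")
parser.add_argument('--model', required=True, help="path to a trained model checkpoint")
parser.add_argument('--vocab', required=True, help="path to the pickled vocabulary")
parser.add_argument('--batch_size', type=int, default=1, help="must be 1 during decoding")
parser.add_argument('--ranking', action='store_true', help="run narrative cloze ranking")
parser.add_argument('--perplexity', action='store_true', help="report average evaluation loss")
parser.add_argument('--seed', action='store_true', help="generate continuations from a seed context")
parser.add_argument('--emb_type', action='store_true', help="use role (type) embeddings")
parser.add_argument('--max_decode', type=int, default=2000, help="stop after this many batches")
args = parser.parse_args()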
Example #2
def do_ranking(model, vocab):

    dataset = du.NarrativeClozeDataset(args.data,
                                       vocab,
                                       src_seq_length=MAX_EVAL_SEQ_LEN,
                                       min_seq_length=MIN_EVAL_SEQ_LEN)
    batches = BatchIter(dataset,
                        args.batch_size,
                        sort_key=lambda x: len(x.actual),
                        train=False,
                        device=device)

    ranked_acc = 0.0
    if args.emb_type:
        print("RANKING WITH ROLE EMB")
        vocab2 = du.load_vocab(args.vocab2)
        role_dataset = du.NarrativeClozeDataset(
            args.role_data,
            vocab2,
            src_seq_length=MAX_EVAL_SEQ_LEN,
            min_seq_length=MIN_EVAL_SEQ_LEN)
        role_batches = BatchIter(role_dataset,
                                 args.batch_size,
                                 sort_key=lambda x: len(x.actual),
                                 train=False,
                                 device=device)

        assert len(dataset) == len(
            role_dataset), "Dataset and Role dataset must be of same length."

        for iteration, (bl, rbl) in enumerate(zip(batches, role_batches)):

            if (iteration + 1) % 25 == 0:
                print("iteration {}".format(iteration + 1))

            ## DATA STEPS
            all_texts = [
                bl.actual, bl.actual_tgt, bl.dist1, bl.dist1_tgt, bl.dist2,
                bl.dist2_tgt, bl.dist3, bl.dist3_tgt, bl.dist4, bl.dist4_tgt,
                bl.dist5, bl.dist5_tgt
            ]  # each is a tup

            all_roles = [
                rbl.actual, rbl.dist1, rbl.dist2, rbl.dist3, rbl.dist4,
                rbl.dist5
            ]  # tgts are not needed for role
            assert len(all_roles) == 6, "6 = 6 * 1."

            assert len(all_texts) == 12, "12 = 6 * 2."

            all_texts_vars = []
            all_roles_vars = []

            if use_cuda:
                for tup in all_texts:
                    all_texts_vars.append((Variable(tup[0].cuda(),
                                                    volatile=True), tup[1]))
                for tup in all_roles:
                    all_roles_vars.append((Variable(tup[0].cuda(),
                                                    volatile=True), tup[1]))

            else:
                for tup in all_texts:
                    all_texts_vars.append((Variable(tup[0],
                                                    volatile=True), tup[1]))
                for tup in all_roles:
                    all_roles_vars.append((Variable(tup[0],
                                                    volatile=True), tup[1]))

            # Iterate two at a time (text, then its target) using the iterator and next().
            vars_iter = iter(all_texts_vars)
            roles_iter = iter(all_roles_vars)

            # run the model and collect ppls for all 6 sentences
            pps = []
            for tup in vars_iter:
                ## INIT AND DECODE before every sentence
                hidden = model.init_hidden(args.batch_size)
                next_tup = next(vars_iter)
                role_tup = next(roles_iter)
                nll = calc_perplexity(args, model, tup[0], vocab, next_tup[0],
                                      next_tup[1], hidden, role_tup[0])
                pp = torch.exp(nll)
                #print("NEG-LOSS {} PPL {}".format(nll.data[0], pp.data[0]))
                pps.append(pp.data.numpy()[0])

            # Lowest perplexity ranks first; the correct ending is always the first candidate.
            assert len(pps) == 6, "6 targets."
            #print("\n")
            all_texts_str = [
                transform(text[0].data.numpy()[0], vocab.itos)
                for text in all_texts_vars
            ]
            #print("ALL: {}".format(all_texts_str))
            min_index = np.argmin(pps)
            if min_index == 0:
                ranked_acc += 1
                #print("TARGET: {}".format(transform(all_texts_vars[1][0].data.numpy()[0], vocab.itos)))
                #print("CORRECT: {}".format(transform(all_texts_vars[1][0].data.numpy()[0], vocab.itos)))
            #else:
            # print the ones that are wrong
            #print("TARGET: {}".format(transform(all_texts_vars[1][0].data.numpy()[0], vocab.itos)))
            #print("WRONG: {}".format(transform(all_texts_vars[min_index+2][0].data.numpy()[0], vocab.itos)))

            if (iteration + 1) == args.max_decode:
                print("Max decode reached. Exiting.")
                break

        ranked_acc = ranked_acc / (iteration + 1) * 100  # multiplying to get percent
        print("Average acc(%): {}".format(ranked_acc))
        return ranked_acc

    else:  # THIS IS FOR MODEL WITHOUT ROLE EMB

        print("RANKING WITHOUT ROLE EMB.")
        for iteration, bl in enumerate(batches):

            if (iteration + 1) % 25 == 0:
                print("iteration {}".format(iteration + 1))

            ## DATA STEPS
            all_texts = [
                bl.actual, bl.actual_tgt, bl.dist1, bl.dist1_tgt, bl.dist2,
                bl.dist2_tgt, bl.dist3, bl.dist3_tgt, bl.dist4, bl.dist4_tgt,
                bl.dist5, bl.dist5_tgt
            ]  # each is a tup

            assert len(all_texts) == 12, "12 = 6 * 2."

            all_texts_vars = []
            if use_cuda:
                for tup in all_texts:
                    all_texts_vars.append((Variable(tup[0].cuda(),
                                                    volatile=True), tup[1]))
            else:
                for tup in all_texts:
                    all_texts_vars.append((Variable(tup[0],
                                                    volatile=True), tup[1]))

            # Iterate two at a time (text, then its target) using the iterator and next().
            vars_iter = iter(all_texts_vars)

            # run the model for all 6 sentences
            pps = []
            for tup in vars_iter:
                ## INIT AND DECODE before every sentence
                hidden = model.init_hidden(args.batch_size)
                next_tup = next(vars_iter)

                nll = calc_perplexity(args, model, tup[0], vocab, next_tup[0],
                                      next_tup[1], hidden)
                pp = torch.exp(nll)
                #print("NEG-LOSS {} PPL {}".format(nll.data[0], pp.data[0]))
                pps.append(pp.data.numpy()[0])

            # Lowest perplexity ranks first; the correct ending is always the first candidate.
            assert len(pps) == 6, "6 targets."
            #print("\n")
            all_texts_str = [
                transform(text[0].data.numpy()[0], vocab.itos)
                for text in all_texts_vars
            ]
            #print("ALL: {}".format(all_texts_str))
            min_index = np.argmin(pps)
            if min_index == 0:
                ranked_acc += 1
                #print("TARGET: {}".format(transform(all_texts_vars[1][0].data.numpy()[0], vocab.itos)))
                #print("CORRECT: {}".format(transform(all_texts_vars[1][0].data.numpy()[0], vocab.itos)))
            #else:
            # print the ones that are wrong
            #print("TARGET: {}".format(transform(all_texts_vars[1][0].data.numpy()[0], vocab.itos)))
            #print("WRONG: {}".format(transform(all_texts_vars[min_index+2][0].data.numpy()[0], vocab.itos)))

            if (iteration + 1) == args.max_decode:
                print("Max decode reached. Exiting.")
                break

        ranked_acc = ranked_acc / (iteration + 1) * 100  # multiplying to get percent
        print("Average acc(%): {}".format(ranked_acc))
        return ranked_acc
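
The ranking code above assumes a transform helper that turns a 1-D sequence of token ids back into a readable string via the vocabulary's itos list. The real helper is not shown in these examples; a minimal sketch consistent with the call sites might be:

def transform(ids, itos):
    # Assumed behaviour, not the project's exact helper: map ids to tokens and join them.
    return " ".join(itos[int(i)] for i in ids)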
Example #3
def get_perplexity(model, vocab):
    total_loss = 0.0
    if args.emb_type:  # GET PERPLEXITY WITH ROLE EMB
        print("PERPLEXITY WITH ROLE EMB")
        vocab2 = du.load_vocab(args.vocab2)
        dataset = du.LMRoleSentenceDataset(
            args.data,
            vocab,
            args.role_data,
            vocab2,
            src_seq_length=MAX_EVAL_SEQ_LEN,
            min_seq_length=MIN_EVAL_SEQ_LEN)  #put in filter pred later
        batches = BatchIter(dataset,
                            args.batch_size,
                            sort_key=lambda x: len(x.text),
                            train=False,
                            device=device)

        print("DATASET {}".format(len(dataset)))
        for iteration, bl in enumerate(batches):

            if (iteration + 1) % 25 == 0:
                print("iteration {}".format(iteration + 1))

            ## DATA STEPS
            batch, batch_lens = bl.text
            target, target_lens = bl.target
            role, role_lens = bl.role

            if use_cuda:
                batch = Variable(batch.cuda(), volatile=True)
                target = Variable(target.cuda(), volatile=True)
                role = Variable(role.cuda(), volatile=True)
            else:
                batch = Variable(batch, volatile=True)
                target = Variable(target, volatile=True)
                role = Variable(role, volatile=True)

            ## INIT AND DECODE
            hidden = model.init_hidden(args.batch_size)
            ce_loss = calc_perplexity(args, model, batch, vocab, target,
                                      target_lens, hidden, role)
            #print("Loss {}".format(ce_loss))
            total_loss = total_loss + ce_loss.data[0]

            if (iteration + 1) == args.max_decode:
                print("Max decode reached. Exiting.")
                break

        # after iterating over all examples
        loss = total_loss / (iteration + 1)
        print("Average Loss: {}".format(loss))
        return loss

    else:
        print("PERPLEXITY WITHOUT ROLE EMB")
        dataset = du.LMSentenceDataset(
            args.data,
            vocab,
            src_seq_length=MAX_EVAL_SEQ_LEN,
            min_seq_length=MIN_EVAL_SEQ_LEN)  #put in filter pred later
        batches = BatchIter(dataset,
                            args.batch_size,
                            sort_key=lambda x: len(x.text),
                            train=False,
                            device=device)
        for iteration, bl in enumerate(batches):

            if (iteration + 1) % 25 == 0:
                print("iteration {}".format(iteration + 1))

            ## DATA STEPS
            batch, batch_lens = bl.text
            target, target_lens = bl.target

            if use_cuda:
                batch = Variable(batch.cuda(), volatile=True)
                target = Variable(target.cuda(), volatile=True)
            else:
                batch = Variable(batch, volatile=True)
                target = Variable(target, volatile=True)

            ## INIT AND DECODE
            hidden = model.init_hidden(args.batch_size)
            ce_loss = calc_perplexity(args, model, batch, vocab, target,
                                      target_lens, hidden)
            #print("Loss {}".format(ce_loss))
            total_loss = total_loss + ce_loss.data[0]

            if (iteration + 1) == args.max_decode:
                print("Max decode reached. Exiting.")
                break

        # after iterating over all examples
        loss = total_loss / (iteration + 1)
        print("Average Loss: {}".format(loss))
        return loss
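
calc_perplexity itself is not shown in these examples. A hypothetical stand-in consistent with how it is called (teacher-forced stepping through the batch, length-masked NLL over the targets, so that torch.exp of the result is a perplexity) could look like the sketch below; the model interface model(inp, hidden[, role]) returning (logits, hidden) is taken from the other examples in this file, and the batch/target alignment is an assumption.

import torch
import torch.nn.functional as F

def calc_perplexity_sketch(model, batch, target, target_lens, hidden, role=None):
    # Hypothetical stand-in for calc_perplexity: average per-token NLL of `target`
    # when `batch` is fed one step at a time (assumes batch and target are aligned).
    total_nll = 0.0
    total_tokens = 0.0
    for step in range(target.size(1)):
        inp = batch[:, step].unsqueeze(1)                       # [batch x 1]
        if role is not None:
            logits, hidden = model(inp, hidden, role[:, step].unsqueeze(1))
        else:
            logits, hidden = model(inp, hidden)
        log_probs = F.log_softmax(logits.view(logits.size(0), -1), dim=-1)
        step_nll = -log_probs.gather(1, target[:, step].unsqueeze(1)).squeeze(1)
        mask = (step < target_lens).float()                     # ignore padded positions
        total_nll = total_nll + (step_nll * mask).sum()
        total_tokens = total_tokens + mask.sum()
    return total_nll / total_tokens                             # exp() of this is the perplexity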
Example #4
def gen_from_seed(model, vocab, eos_id, pad_id, sos_id, tup_id):

    if args.emb_type:  # GEN FROM SEED WITH ROLE EMB
        print("GEN SEED WITH ROLE EMB")
        vocab2 = du.load_vocab(args.vocab2)
        # will use this to feed in role ids in beam decode
        ROLES = [
            vocab2.stoi[TUP_TOK], vocab2.stoi[VERB], vocab2.stoi[SUB],
            vocab2.stoi[OBJ], vocab2.stoi[PREP]
        ]
        dataset = du.LMRoleSentenceDataset(
            args.data,
            vocab,
            args.role_data,
            vocab2,
            src_seq_length=MAX_EVAL_SEQ_LEN,
            min_seq_length=MIN_EVAL_SEQ_LEN)  #put in filter pred later
        batches = BatchIter(dataset,
                            args.batch_size,
                            sort_key=lambda x: len(x.text),
                            train=False,
                            device=device)

        for iteration, bl in enumerate(batches):

            if (iteration + 1) % 25 == 0:
                print("iteration {}".format(iteration + 1))

            ## DATA STEPS
            batch, batch_lens = bl.text
            target, target_lens = bl.target
            role, role_lens = bl.role

            if use_cuda:
                batch = Variable(batch.cuda(), volatile=True)
                role = Variable(role.cuda(), volatile=True)
            else:
                batch = Variable(batch, volatile=True)
                role = Variable(role, volatile=True)

            ## INIT AND DECODE
            hidden = model.init_hidden(args.batch_size)
            #Run the model over the context (all but the last word); the corresponding role ids are known as well.
            seq_len = batch.size(1)
            for i in range(seq_len - 1):
                inp = batch[:, i]
                inp = inp.unsqueeze(1)  # [batch x 1]
                typ = role[:, i]
                typ = typ.unsqueeze(1)
                _, hidden = model(inp, hidden, typ)

            #print("seq len {}, decode after {} steps".format(seq_len, i+1))
            # Seed the beam with the last word in the sequence.
            beam_inp = batch[:, i + 1]
            # Not needed anymore: the last role in the sequence is assumed to be prep.
            #role_inp = role[:, i+1]
            #           print("ROLES LIST: {}".format(ROLES))
            #           print("FIRST ID: {}".format(role[:, i+1]))

            # init_beam initializes the beam with the last sequence element. ROLES is a list of role type ids.
            outputs = beam_decode(model,
                                  beam_inp,
                                  hidden,
                                  args.max_len_decode,
                                  args.beam_size,
                                  pad_id,
                                  sos_id,
                                  eos_id,
                                  tup_idx=tup_id,
                                  init_beam=True,
                                  roles=ROLES)
            predicted_events = get_pred_events(outputs, vocab)

            print("CONTEXT: {}".format(
                transform(batch.data.squeeze(), vocab.itos)))
            print("PRED_t: {}".format(
                predicted_events))  # n_best stitched together.

            if (iteration + 1) == args.max_decode:
                print("Max decode reached. Exiting.")
                break

    else:
        print("GEN SEED WITHOUT ROLE EMB")
        dataset = du.LMSentenceDataset(
            args.data,
            vocab,
            src_seq_length=MAX_EVAL_SEQ_LEN,
            min_seq_length=MIN_EVAL_SEQ_LEN)  #put in filter pred later
        batches = BatchIter(dataset,
                            args.batch_size,
                            sort_key=lambda x: len(x.text),
                            train=False,
                            device=device)
        for iteration, bl in enumerate(batches):

            if (iteration + 1) % 25 == 0:
                print("iteration {}".format(iteration + 1))

            ## DATA STEPS
            batch, batch_lens = bl.text
            target, target_lens = bl.target

            if use_cuda:
                batch = Variable(batch.cuda(), volatile=True)
            else:
                batch = Variable(batch, volatile=True)

            ## INIT AND DECODE
            hidden = model.init_hidden(args.batch_size)

            #run the model first on t-1 events, except last word
            seq_len = batch.size(1)
            for i in range(seq_len - 1):
                inp = batch[:, i]
                inp = inp.unsqueeze(1)  # [batch x 1]
                _, hidden = model(inp, hidden)

            #print("seq len {}, decode after {} steps".format(seq_len, i+1))
            # Seed the beam with the last word in the sequence.
            beam_inp = batch[:, i + 1]

            # init_beam initializes the beam with the last sequence element.
            outputs = beam_decode(model,
                                  beam_inp,
                                  hidden,
                                  args.max_len_decode,
                                  args.beam_size,
                                  pad_id,
                                  sos_id,
                                  eos_id,
                                  tup_idx=tup_id,
                                  init_beam=True)
            predicted_events = get_pred_events(outputs, vocab)

            print("CONTEXT: {}".format(
                transform(batch.data.squeeze(), vocab.itos)))
            print("PRED_t: {}".format(
                predicted_events))  # n_best stitched together.

            if (iteration + 1) == args.max_decode:
                print("Max decode reached. Exiting.")
                break
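
beam_decode is likewise external to these examples. As a rough point of reference, a greedy (beam size 1) decoder seeded with the last context token, under the same assumed model interface, might look like the sketch below; it is a simplified illustration, not the project's beam search.

import torch

def greedy_decode_sketch(model, last_inp, hidden, max_len, eos_id):
    # Greedy stand-in for beam_decode: repeatedly feed back the argmax token,
    # starting from the last context token, until EOS or max_len steps.
    inp = last_inp.unsqueeze(1)                                   # [batch x 1]
    decoded = []
    for _ in range(max_len):
        logits, hidden = model(inp, hidden)
        _, next_tok = logits.view(logits.size(0), -1).max(dim=1)  # [batch]
        decoded.append(next_tok)
        if (next_tok == eos_id).all():
            break
        inp = next_tok.unsqueeze(1)
    return torch.stack(decoded, dim=1)                            # [batch x num_decoded]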
Example #5
    if torch.cuda.is_available():
        if not args.cuda:
            logging.warning("WARNING: You have a CUDA device, so you should probably run with --cuda")
            args.device = torch.device('cpu')
        else:
            args.device = torch.device('cuda')

            logging.info("Using GPU {}".format(torch.cuda.get_device_name(args.device)))

    else:
        args.device = torch.device('cpu')
    


    evocab = du.load_vocab(args.evocab)


    with open(args.pmi_dict, 'r') as fi:
        pmi_dict = json.load(fi)

    with open(args.causal_dict, 'rb') as fi:
        causal_dict = pickle.load(fi)

    evocab_lm = du.convert_to_lm_vocab(copy.deepcopy(evocab))
    lm_model = torch.load(args.lm_model, map_location=args.device)

    
    #Only count subject/object events in the rankings
    so_events = [x for x in evocab.itos
                 if len(x.split('->')) == 2
                 and x.split('->')[1] in ['nsubj', 'dobj', 'iobj']]

    print(len(so_events))
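
The filter above assumes event vocabulary entries of the form predicate->dependency; a quick illustration with hypothetical entries:

# Hypothetical vocabulary entries showing what the subj/obj filter keeps.
example_itos = ['<unk>', 'eat->nsubj', 'eat->dobj', 'run->prep_to', 'say']
kept = [x for x in example_itos
        if len(x.split('->')) == 2 and x.split('->')[1] in ['nsubj', 'dobj', 'iobj']]
print(kept)  # ['eat->nsubj', 'eat->dobj']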
Example #6
def train(args):
    """
    Train the model in the ol' fashioned way, just like grandma used to
    Args
        args (argparse.Namespace): parsed command-line arguments
    """
    #Load the data
    logging.info("Loading Vocab")
    evocab = du.load_vocab(args.evocab)
    tvocab = du.load_vocab(args.tvocab)
    logging.info("Event Vocab Loaded, Size {}".format(len(evocab.stoi.keys())))
    logging.info("Text Vocab Loaded, Size {}".format(len(tvocab.stoi.keys())))

    if args.use_pretrained:
        pretrained = GloVe(name='6B',
                           dim=args.text_embed_size,
                           unk_init=torch.Tensor.normal_)
        tvocab = du.load_vectors(pretrained)
        logging.info("Loaded Pretrained Word Embeddings")

    if args.load_model:
        logging.info("Loading the Model")
        model = torch.load(args.load_model, map_location=args.device)
    else:
        logging.info("Creating the Model")
        if args.onehot_events:
            logging.info(
                "Model Type: SemiNaiveAdjustmentEstimatorOneHotEvents")
            model = estimators.SemiNaiveAdjustmentEstimatorOneHotEvents(
                args, evocab, tvocab)
        else:
            logging.info("Model Type: SemiNaiveAdjustmentEstimator")
            model = estimators.SemiNaiveAdjustmentEstimator(
                args, evocab, tvocab)

    if args.finetune:
        assert args.load_model
        logging.info("Finetuning...")
        if args.freeze:
            logging.info("Freezing...")
            for param in model.parameters():
                param.requires_grad = False
        model = estimators.AdjustmentEstimator(args, evocab, tvocab, model)

        #Still finetune the last layer even if freeze is on (if freeze is on, everything else stays frozen)
        model.expected_outcome.event_text_logits_mlp.weight.requires_grad = True
        model.expected_outcome.event_text_logits_mlp.bias.requires_grad = True

        logging.info("Trainable Params: {}".format(
            [x[0] for x in model.named_parameters() if x[1].requires_grad]))

    model = model.to(device=args.device)

    #create the optimizer
    if args.load_opt:
        logging.info("Loading the optimizer state")
        optimizer = torch.load(args.load_opt)
    else:
        if args.optimizer == 'adagrad':
            logging.info("Creating Adagrad optimizer anew")
            optimizer = torch.optim.Adagrad(filter(lambda x: x.requires_grad,
                                                   model.parameters()),
                                            lr=args.lr)
        elif args.optimizer == 'sgd':
            logging.info("Creating SGD optimizer anew")
            optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad,
                                               model.parameters()),
                                        lr=args.lr)
        else:
            logging.info("Creating Adam optimizer anew")
            optimizer = torch.optim.Adam(filter(lambda x: x.requires_grad,
                                                model.parameters()),
                                         lr=args.lr)

    logging.info("Loading Datasets")
    min_size = model.text_encoder.largest_ngram_size  #Add extra pads if text size smaller than largest CNN kernel size

    if args.load_pickle:
        logging.info("Loading Train from Pickled Data")
        with open(args.train_data, 'rb') as pfi:
            pickled_examples = pickle.load(pfi)
        train_dataset = du.InstanceDataset("",
                                           evocab,
                                           tvocab,
                                           min_size=min_size,
                                           pickled_examples=pickled_examples)
    else:
        train_dataset = du.InstanceDataset(args.train_data,
                                           evocab,
                                           tvocab,
                                           min_size=min_size)
    valid_dataset = du.InstanceDataset(args.valid_data,
                                       evocab,
                                       tvocab,
                                       min_size=min_size)

    #Remove UNK events from the e1prev_intext attribute so they don't mess up avg encoders
    #  train_dataset.filter_examples(['e1prev_intext'])  #These take really long time! Will have to figure something out...
    #  valid_dataset.filter_examples(['e1prev_intext'])
    logging.info("Finished Loading Training Dataset {} examples".format(
        len(train_dataset)))
    logging.info("Finished Loading Valid Dataset {} examples".format(
        len(valid_dataset)))

    train_batches = BatchIter(train_dataset,
                              args.batch_size,
                              sort_key=lambda x: len(x.allprev),
                              train=True,
                              repeat=False,
                              shuffle=True,
                              sort_within_batch=True,
                              device=None)
    valid_batches = BatchIter(valid_dataset,
                              args.batch_size,
                              sort_key=lambda x: len(x.allprev),
                              train=False,
                              repeat=False,
                              shuffle=False,
                              sort_within_batch=True,
                              device=None)
    train_data_len = len(train_dataset)
    valid_data_len = len(valid_dataset)

    loss_func = nn.CrossEntropyLoss()

    start_time = time.time()  #start of epoch 1
    best_valid_loss = float('inf')
    best_epoch = args.epochs

    if args.finetune:
        vloss = validation(args, valid_batches, model, loss_func)
        logging.info("Pre Finetune Validation Loss: {}".format(vloss))

    #MAIN TRAINING LOOP
    for curr_epoch in range(args.epochs):
        prev_losses = []
        for iteration, inst in enumerate(train_batches):
            instance = du.send_instance_to(inst, args.device)

            model.train()
            model.zero_grad()
            model_outputs = model(instance)

            exp_outcome_out = model_outputs[
                EXP_OUTCOME_COMPONENT]  #[batch X num events], output prediction for e2
            exp_outcome_loss = loss_func(exp_outcome_out, instance.e2)
            loss = exp_outcome_loss

            loss.backward()
            torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
            optimizer.step()

            prev_losses.append(loss.cpu().data)
            prev_losses = prev_losses[-50:]

            if (iteration % args.log_every == 0) and iteration != 0:
                past_50_avg = sum(prev_losses) / len(prev_losses)
                logging.info(
                    "Epoch/iteration {}/{}, Past 50 Average Loss {}, Best Val {} at Epoch {}"
                    .format(
                        curr_epoch, iteration, past_50_avg, 'NA' if
                        best_valid_loss == float('inf') else best_valid_loss,
                        'NA' if best_epoch == args.epochs else best_epoch))

            if (iteration % args.validate_after == 0) and iteration != 0:
                logging.info(
                    "Running Validation at Epoch/iteration {}/{}".format(
                        curr_epoch, iteration))
                new_valid_loss = validation(args, valid_batches, model,
                                            loss_func)
                logging.info(
                    "Validation loss at Epoch/iteration {}/{}: {:.3f} - Best Validation Loss: {:.3f}"
                    .format(curr_epoch, iteration, new_valid_loss,
                            best_valid_loss))
                if new_valid_loss < best_valid_loss:
                    logging.info(
                        "New Validation Best...Saving Model Checkpoint")
                    best_valid_loss = new_valid_loss
                    best_epoch = curr_epoch
                    #torch.save(model, "{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, curr_epoch, best_valid_loss))
                    #torch.save(optimizer, "{}.{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, "optimizer", curr_epoch, best_valid_loss))
                    torch.save(model, "{}".format(args.save_model))
                    torch.save(optimizer,
                               "{}_optimizer".format(args.save_model))

        #END OF EPOCH
        logging.info("End of Epoch {}, Running Validation".format(curr_epoch))
        new_valid_loss = validation(args, valid_batches, model, loss_func)
        logging.info(
            "Validation loss at end of Epoch {}: {:.3f} - Best Validation Loss: {:.3f}"
            .format(curr_epoch, new_valid_loss, best_valid_loss))
        if new_valid_loss < best_valid_loss:
            logging.info("New Validation Best...Saving Model Checkpoint")
            best_valid_loss = new_valid_loss
            best_epoch = curr_epoch
            #torch.save(model, "{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, curr_epoch, best_valid_loss))
            #torch.save(optimizer, "{}.{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, "optimizer", curr_epoch, best_valid_loss))
            torch.save(model, "{}".format(args.save_model))
            torch.save(optimizer, "{}_optimizer".format(args.save_model))

        if curr_epoch - best_epoch >= args.stop_after:
            logging.info(
                "No improvement in {} epochs, terminating at epoch {}...".
                format(args.stop_after, curr_epoch))
            logging.info("Best Validation Loss: {:.2f} at Epoch {}".format(
                best_valid_loss, best_epoch))
            break
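
The training loop calls a validation helper that is not shown here. A minimal sketch consistent with its call site (average expected-outcome loss over the validation batches with gradients disabled), reusing du.send_instance_to and EXP_OUTCOME_COMPONENT from the code above, might be:

import torch

def validation_sketch(args, valid_batches, model, loss_func):
    # Hypothetical stand-in for validation(): mean loss over the validation set.
    model.eval()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for inst in valid_batches:
            instance = du.send_instance_to(inst, args.device)
            model_outputs = model(instance)
            loss = loss_func(model_outputs[EXP_OUTCOME_COMPONENT], instance.e2)
            total_loss += loss.item()
            num_batches += 1
    return total_loss / max(num_batches, 1)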
Example #7
def train(args):
    """
    Train the model in the ol' fashioned way, just like grandma used to
    Args
        args (argparse.Namespace): parsed command-line arguments
    """
    #Load the data
    logging.info("Loading Vocab")
    evocab = du.load_vocab(args.evocab)
    logging.info("Event Vocab Loaded, Size {}".format(len(evocab.stoi.keys())))

    evocab.stoi[SOS_TOK] = len(evocab.itos)
    evocab.itos.append(SOS_TOK)

    evocab.stoi[EOS_TOK] = len(evocab.itos)
    evocab.itos.append(EOS_TOK)

    assert evocab.stoi[EOS_TOK] == evocab.itos.index(EOS_TOK)
    assert evocab.stoi[SOS_TOK] == evocab.itos.index(SOS_TOK)

    if args.load_model:
        logging.info("Loading the Model")
        model = torch.load(args.load_model, map_location=args.device)
    else:
        logging.info("Creating the Model")
        model = LM.EventLM(args.event_embed_size,
                           args.rnn_hidden_dim,
                           args.rnn_layers,
                           len(evocab.itos),
                           dropout=args.dropout)

    model = model.to(device=args.device)

    #create the optimizer
    if args.load_opt:
        logging.info("Loading the optimizer state")
        optimizer = torch.load(args.load_opt)
    else:
        logging.info("Creating the optimizer anew")
        optimizer = torch.optim.Adam(filter(lambda x: x.requires_grad,
                                            model.parameters()),
                                     lr=args.lr)
    #  optimizer = torch.optim.Adagrad(model.parameters(), lr=args.lr)

    logging.info("Loading Datasets")

    train_dataset = du.LmInstanceDataset(args.train_data, evocab)
    valid_dataset = du.LmInstanceDataset(args.valid_data, evocab)

    #Remove UNK events from the e1prev_intext attribute so they don't mess up avg encoders
    #  train_dataset.filter_examples(['e1prev_intext'])  #These take really long time! Will have to figure something out...
    #  valid_dataset.filter_examples(['e1prev_intext'])
    logging.info("Finished Loading Training Dataset {} examples".format(
        len(train_dataset)))
    logging.info("Finished Loading Valid Dataset {} examples".format(
        len(valid_dataset)))

    train_batches = BatchIter(train_dataset,
                              args.batch_size,
                              sort_key=lambda x: len(x.text),
                              train=True,
                              repeat=False,
                              shuffle=True,
                              sort_within_batch=True,
                              device=None)
    valid_batches = BatchIter(valid_dataset,
                              args.batch_size,
                              sort_key=lambda x: len(x.text),
                              train=False,
                              repeat=False,
                              shuffle=False,
                              sort_within_batch=True,
                              device=None)
    train_data_len = len(train_dataset)
    valid_data_len = len(valid_dataset)

    start_time = time.time()  #start of epoch 1
    best_valid_loss = float('inf')
    best_epoch = args.epochs

    #MAIN TRAINING LOOP
    for curr_epoch in range(args.epochs):
        prev_losses = []
        for iteration, inst in enumerate(train_batches):
            instance = du.lm_send_instance_to(inst, args.device)

            text_inst, text_lens = inst.text
            target_inst, target_lens = inst.target

            model.train()
            model.zero_grad()

            logits = []
            hidden = None
            for step in range(text_inst.size(1)):
                step_inp = text_inst[:, step]  #get all instances for this step
                step_inp = step_inp.unsqueeze(1)  #[batch X 1]

                logit_i, hidden = model(step_inp, hidden)
                logits += [logit_i]
            logits = torch.stack(logits, dim=1)  #[batch, seq_len, vocab]

            loss = masked_cross_entropy(logits, target_inst, target_lens)
            loss.backward()
            torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
            optimizer.step()

            prev_losses.append(loss.cpu().data)
            prev_losses = prev_losses[-50:]

            if (iteration % args.log_every == 0) and iteration != 0:
                past_50_avg = sum(prev_losses) / len(prev_losses)
                logging.info(
                    "Epoch/iteration {}/{}, Past 50 Average Loss {}, Best Val {} at Epoch {}"
                    .format(
                        curr_epoch, iteration, past_50_avg, 'NA' if
                        best_valid_loss == float('inf') else best_valid_loss,
                        'NA' if best_epoch == args.epochs else best_epoch))

            if (iteration % args.validate_after == 0) and iteration != 0:
                logging.info(
                    "Running Validation at Epoch/iteration {}/{}".format(
                        curr_epoch, iteration))
                new_valid_loss = validation(args, valid_batches, model)
                logging.info(
                    "Validation loss at Epoch/iteration {}/{}: {:.3f} - Best Validation Loss: {:.3f}"
                    .format(curr_epoch, iteration, new_valid_loss,
                            best_valid_loss))
                if new_valid_loss < best_valid_loss:
                    logging.info(
                        "New Validation Best...Saving Model Checkpoint")
                    best_valid_loss = new_valid_loss
                    best_epoch = curr_epoch
                    #torch.save(model, "{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, curr_epoch, best_valid_loss))
                    #torch.save(optimizer, "{}.{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, "optimizer", curr_epoch, best_valid_loss))
                    torch.save(model, "{}".format(args.save_model))
                    torch.save(optimizer,
                               "{}_optimizer".format(args.save_model))

        #END OF EPOCH
        logging.info("End of Epoch {}, Running Validation".format(curr_epoch))
        new_valid_loss = validation(args, valid_batches, model)
        logging.info(
            "Validation loss at end of Epoch {}: {:.3f} - Best Validation Loss: {:.3f}"
            .format(curr_epoch, new_valid_loss, best_valid_loss))
        if new_valid_loss < best_valid_loss:
            logging.info("New Validation Best...Saving Model Checkpoint")
            best_valid_loss = new_valid_loss
            best_epoch = curr_epoch
            #torch.save(model, "{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, curr_epoch, best_valid_loss))
            #torch.save(optimizer, "{}.{}.epoch_{}.loss_{:.2f}.pt".format(args.save_model, "optimizer", curr_epoch, best_valid_loss))
            torch.save(model, "{}".format(args.save_model))
            torch.save(optimizer, "{}_optimizer".format(args.save_model))

        if curr_epoch - best_epoch >= args.stop_after:
            logging.info(
                "No improvement in {} epochs, terminating at epoch {}...".
                format(args.stop_after, curr_epoch))
            logging.info("Best Validation Loss: {:.2f} at Epoch {}".format(
                best_valid_loss, best_epoch))
            break
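
masked_cross_entropy is assumed by the loop above but not defined in these examples. A self-contained sketch of a length-masked cross-entropy with the shapes the code expects (logits [batch, seq_len, vocab], target [batch, seq_len], lengths [batch]):

import torch
import torch.nn.functional as F

def masked_cross_entropy_sketch(logits, target, lengths):
    # A sketch, not the project's implementation: mean NLL over non-padded positions.
    batch_size, seq_len, vocab_size = logits.size()
    log_probs = F.log_softmax(logits.view(-1, vocab_size), dim=-1)        # [batch*seq_len, vocab]
    nll = -log_probs.gather(1, target.view(-1, 1)).squeeze(1)             # [batch*seq_len]
    positions = torch.arange(seq_len, device=target.device).unsqueeze(0)  # [1, seq_len]
    mask = (positions < lengths.unsqueeze(1)).float().view(-1)            # [batch*seq_len]
    return (nll * mask).sum() / mask.sum()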