Ejemplo n.º 1
0
def generate(args):
    """
    Use the trained model for decoding.

    Depending on the flags in ``args`` this computes perplexity, generates
    from seeds (``schema``), runs the narrative-cloze ranking evaluation, or
    reconstructs the validation data.

    Args
        args (argparse.Namespace): parsed command-line options. This function
            reads ``cuda``, ``vocab``, ``ranking``, ``valid_data``,
            ``batch_size``, ``load``, ``nohier``, ``perplexity`` and
            ``schema``; the dispatched helpers may read more.
    """
    use_cuda = bool(args.cuda) and torch.cuda.is_available()
    if args.cuda and not use_cuda:
        print("You do not have CUDA, turning cuda off")

    # Load the vocab
    vocab = du.load_vocab(args.vocab)

    # Batches are always built with device=-1 here; the model itself is
    # placed on the GPU (if requested) when it is loaded below.
    if args.ranking:  # default is HARD one, the 'Inverse Narrative Cloze' in the paper
        dataset = du.NarrativeClozeDataset(args.valid_data,
                                           vocab,
                                           src_seq_length=MAX_EVAL_SEQ_LEN,
                                           min_seq_length=MIN_EVAL_SEQ_LEN,
                                           LM=False)
        # Ranking is decoded one instance at a time (batch size fixed to 1).
        batches = BatchIter(dataset,
                            1,
                            sort_key=lambda x: len(x.actual),
                            train=False,
                            device=-1)
    else:
        dataset = du.SentenceDataset(args.valid_data,
                                     vocab,
                                     src_seq_length=MAX_EVAL_SEQ_LEN,
                                     min_seq_length=MIN_EVAL_SEQ_LEN,
                                     add_eos=False)  # TODO: put in filter pred later
        batches = BatchIter(dataset,
                            args.batch_size,
                            sort_key=lambda x: len(x.text),
                            train=False,
                            device=-1)

    data_len = len(dataset)

    # Create the model.
    # NOTE: torch.load unpickles arbitrary Python objects; only load
    # checkpoints from trusted sources.
    with open(args.load, 'rb') as fi:
        if use_cuda:
            model = torch.load(fi, map_location=torch.device('cuda'))
        else:
            model = torch.load(fi, map_location=lambda storage, loc: storage)

    if not hasattr(model.latent_root, 'nohier'):
        # Checkpoints saved before 'nohier' existed lack the attribute.
        model.latent_root.set_nohier(args.nohier)

    model.decoder.eval()
    model.set_use_cuda(use_cuda)

    if args.perplexity:
        loss = calc_perplexity(args, model, batches, vocab, data_len)
        print("Loss = {}".format(loss))
    elif args.schema:
        generate_from_seed(args, model, batches, vocab, data_len)
    elif args.ranking:
        do_ranking(args, model, batches, vocab, data_len, use_cuda)
    else:
        # Default: reconstruction of the validation data.
        reconstruct(args, model, batches, vocab)
Ejemplo n.º 2
0
def do_ranking(model, vocab):
    """
    Evaluate ``model`` on the narrative-cloze ranking task.

    For every batch, the true continuation (``actual``) and five distractors
    (``dist1`` .. ``dist5``) are each scored by perplexity. The batch counts
    as correct when the true continuation gets the lowest perplexity.

    Relies on module-level ``args`` (``data``, ``batch_size``, ``emb_type``,
    ``vocab2``, ``role_data``, ``max_decode``), ``device`` and ``use_cuda``.
    When ``args.emb_type`` is set, a parallel role dataset is loaded and each
    candidate's role sequence is also passed to ``calc_perplexity``.

    Args
        model: trained model; must provide ``init_hidden(batch_size)``.
        vocab: vocabulary used to build the text dataset.

    Returns
        float: ranking accuracy in percent (0.0 if no batch was evaluated).
    """
    dataset = du.NarrativeClozeDataset(args.data,
                                       vocab,
                                       src_seq_length=MAX_EVAL_SEQ_LEN,
                                       min_seq_length=MIN_EVAL_SEQ_LEN)
    batches = BatchIter(dataset,
                        args.batch_size,
                        sort_key=lambda x: len(x.actual),
                        train=False,
                        device=device)

    if args.emb_type:
        print("RANKING WITH ROLE EMB")
        vocab2 = du.load_vocab(args.vocab2)
        role_dataset = du.NarrativeClozeDataset(
            args.role_data,
            vocab2,
            src_seq_length=MAX_EVAL_SEQ_LEN,
            min_seq_length=MIN_EVAL_SEQ_LEN)
        role_batches = BatchIter(role_dataset,
                                 args.batch_size,
                                 sort_key=lambda x: len(x.actual),
                                 train=False,
                                 device=device)

        assert len(dataset) == len(
            role_dataset), "Dataset and Role dataset must be of same length."
        batch_pairs = zip(batches, role_batches)
    else:  # model without role embeddings
        print("RANKING WITHOUT ROLE EMB.")
        batch_pairs = ((bl, None) for bl in batches)

    ranked_acc = 0.0
    num_evaluated = 0
    for iteration, (bl, rbl) in enumerate(batch_pairs):
        num_evaluated = iteration + 1
        if num_evaluated % 25 == 0:
            print("iteration {}".format(num_evaluated))

        pps = _ranking_perplexities(model, vocab, bl, rbl)
        # Low perplexity == top ranked; the correct answer is always candidate 0.
        if np.argmin(pps) == 0:
            ranked_acc += 1

        if num_evaluated == args.max_decode:
            print("Max decode reached. Exiting.")
            break

    # Convert to percent with float arithmetic (no integer division) and
    # guard against an empty batch iterator.
    ranked_acc = 100.0 * ranked_acc / num_evaluated if num_evaluated else 0.0
    print("Average acc(%): {}".format(ranked_acc))
    return ranked_acc


def _ranking_perplexities(model, vocab, bl, rbl=None):
    """Return the perplexities of the six ranking candidates of batch ``bl``.

    ``rbl`` is the matching role batch, or None for models without role
    embeddings. The candidate order is actual, dist1, ..., dist5.
    """
    candidates = ("actual", "dist1", "dist2", "dist3", "dist4", "dist5")
    # Interleave each candidate with its target: actual, actual_tgt, dist1, ...
    text_vars = _to_ranking_vars(
        [getattr(bl, name) for cand in candidates for name in (cand, cand + "_tgt")])
    if rbl is None:
        role_vars = [None] * len(candidates)
    else:
        # Role targets are not needed, only the role inputs.
        role_vars = _to_ranking_vars([getattr(rbl, cand) for cand in candidates])

    pps = []
    for src, tgt, role in zip(text_vars[0::2], text_vars[1::2], role_vars):
        # Fresh hidden state before every candidate sentence.
        hidden = model.init_hidden(args.batch_size)
        extra = () if role is None else (role[0],)
        nll = calc_perplexity(args, model, src[0], vocab, tgt[0], tgt[1],
                              hidden, *extra)
        pp = torch.exp(nll)
        # .cpu() is required before .numpy() when running on the GPU;
        # reshape(-1) also tolerates a 0-dim result.
        pps.append(pp.data.cpu().numpy().reshape(-1)[0])
    return pps


def _to_ranking_vars(tuples):
    """Wrap each ``(tensor, lengths)`` batch field as ``(Variable, lengths)`` for inference."""
    wrapped = []
    for tup in tuples:
        data = tup[0].cuda() if use_cuda else tup[0]
        wrapped.append((Variable(data, volatile=True), tup[1]))
    return wrapped
Ejemplo n.º 3
0
def generate(args):
    """
    Use the trained model for decoding.

    Depending on the flags in ``args`` this computes chain perplexity,
    generates from seeds (``schema``), runs the narrative-cloze ranking
    evaluation, or reconstructs the validation data.

    Args
        args (argparse.Namespace): parsed command-line options. This function
            reads ``cuda``, ``vocab``, ``frame_vocab_address``, ``ranking``,
            ``valid_narr``, ``valid_data``, ``valid_frames``, ``num_clauses``,
            ``batch_size``, ``load``, ``nohier``, ``perplexity`` and
            ``schema``; the dispatched helpers may read more.

    Returns
        The chain perplexity (``exp`` of the NLL) when ``args.perplexity`` is
        set, the ranking accuracy when ``args.ranking`` is set, otherwise None.
    """
    use_cuda = bool(args.cuda) and torch.cuda.is_available()
    if args.cuda and not use_cuda:
        print("You do not have CUDA, turning cuda off")

    # Load the vocabularies. The word vocab is returned as a pair here.
    # NOTE(review): the frame vocab is not unpacked -- confirm that
    # load_vocab(..., is_Frame=True) returns a single vocab object.
    vocab, _ = du.load_vocab(args.vocab)
    vocab2 = du.load_vocab(args.frame_vocab_address, is_Frame=True)

    # Batches are always built with device=-1 here; the model itself is
    # placed on the GPU (if requested) when it is loaded below.
    if args.ranking:  # default is HARD one, the 'Inverse Narrative Cloze' in the paper
        dataset = du.NarrativeClozeDataset(args.valid_narr,
                                           vocab,
                                           src_seq_length=MAX_EVAL_SEQ_LEN,
                                           min_seq_length=MIN_EVAL_SEQ_LEN,
                                           LM=False)
        print('ranking_dataset: ', len(dataset))
        # Ranking is decoded one instance at a time (batch size fixed to 1).
        batches = BatchIter(dataset,
                            1,
                            sort_key=lambda x: len(x.actual),
                            train=False,
                            device=-1)
    else:
        dataset = du.SentenceDataset(path=args.valid_data,
                                     path2=args.valid_frames,
                                     vocab=vocab,
                                     vocab2=vocab2,
                                     num_clauses=args.num_clauses,
                                     add_eos=False,
                                     is_ref=True,
                                     obsv_prob=0.0,
                                     print_valid=True)
        batches = BatchIter(dataset,
                            args.batch_size,
                            sort_key=lambda x: len(x.text),
                            train=False,
                            device=-1)

    data_len = len(dataset)

    # Create the model.
    # NOTE: torch.load unpickles arbitrary Python objects; only load
    # checkpoints from trusted sources.
    with open(args.load, 'rb') as fi:
        if use_cuda:
            model = torch.load(fi, map_location=torch.device('cuda'))
        else:
            model = torch.load(fi, map_location=lambda storage, loc: storage)

    if not hasattr(model.latent_root, 'nohier'):
        # Checkpoints saved before 'nohier' existed lack the attribute.
        model.latent_root.set_nohier(args.nohier)

    model.decoder.eval()
    model.set_use_cuda(use_cuda)

    if args.perplexity:
        print('calculating perplexity')
        nll = calc_perplexity(args, model, batches, vocab, data_len)
        ppl = np.exp(nll)
        print("Chain-NLL = {}".format(nll))
        print("Chain-PPL = {}".format(ppl))
        return ppl
    elif args.schema:
        generate_from_seed(args, model, batches, vocab, data_len)
    elif args.ranking:
        ranked_acc = do_ranking(args, model, batches, vocab, data_len,
                                use_cuda)
        return ranked_acc
    else:
        # Default: reconstruction of the validation data.
        reconstruct(args, model, batches, vocab)