Example #1
def test_l3_kbest():
    params, lengths = params_l3()
    dist = SentCFG(params, lengths=lengths)

    _, _, _, spans = dist.argmax
    spans, trees = extract_parses(spans, lengths)
    best_trees = "((0 1) 2)"
    best_spans = [(0, 1, 2), (0, 2, 2)]
    assert spans[0] == best_spans
    assert trees[0] == best_trees

    _, _, _, spans = dist.topk(4)
    # swap the first two dimensions (k-best and batch) before extraction
    perm = (1, 0) + tuple(range(2, spans.dim()))
    spans = spans.permute(perm)
    spans, trees = extract_topk(spans, lengths)
    best_trees = "((0 1) 2)"
    best_spans = [
        [(0, 1, 2), (0, 2, 2)],
        [(0, 1, 2), (0, 2, 2)],
        [(0, 1, 1), (0, 2, 2)],
        [(0, 1, 1), (0, 2, 2)],
    ]
    for i, (span, tree) in enumerate(zip(spans, trees)):
        assert span == best_spans[i]
        assert tree == best_trees
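
For context, here is a minimal sketch of building a SentCFG distribution directly, assuming the standard torch-struct API in which the log-potentials are a (terms, rules, roots) tuple; params_l3, extract_parses and extract_topk above are helpers from the surrounding test module, so the sizes below are purely illustrative:

import torch
from torch_struct import SentCFG

batch, N, NT, T = 2, 5, 3, 4                    # illustrative sizes
terms = torch.randn(batch, N, T)                # word -> pre-terminal scores
rules = torch.randn(batch, NT, NT + T, NT + T)  # binary-rule scores
roots = torch.randn(batch, NT)                  # root non-terminal scores
lengths = torch.tensor([5, 4])

dist = SentCFG((terms, rules, roots), lengths=lengths)
print(dist.partition)         # per-sentence log-partition
_, _, _, spans = dist.argmax  # the span component comes last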
Example #2
def train():
    # model.train()
    losses = []
    for epoch in range(2):
        for i, ex in enumerate(train_iter):
            opt.zero_grad()
            words, lengths = ex.word
            N, batch = words.shape
            words = words.long()
            params = model(words.cuda().transpose(0, 1))
            dist = SentCFG(params, lengths=lengths)
            # dist.partition is the sentence log-likelihood; negate it so
            # that minimizing the loss maximizes the likelihood
            loss = dist.partition.mean()
            (-loss).backward()
            losses.append(loss.detach())
            torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
            opt.step()

            if i > 100:
                break
            if i % 100 == 1:
                print(-torch.tensor(losses).mean(), words.shape)
                losses = []
    if args.debug:
        print(f"saving to {args.debug}...")
        torch.save(params[0], args.debug)
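
The words, lengths = ex.word unpacking suggests a torchtext-style iterator whose text field was built with include_lengths=True, yielding (tokens, lengths) pairs in (seq_len, batch) layout; the transpose(0, 1) then hands the model a batch-first tensor before SentCFG scores the batch.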
Example #3
    def train(self, niter=1):
        for _ in range(niter):
            losses = []
            self.opt.zero_grad()
            params = self.model(self.words)
            dist = SentCFG(params, lengths=self.lengths)
            loss = dist.partition.mean()
            (-loss).backward()
            losses.append(loss.detach())
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 3.0)
            self.opt.step()
Example #4
    def forward_parser(self, captions, lengths):
        params, kl = self.parser(captions)
        dist = SentCFG(params, lengths=lengths)

        # the span indicators are the last element of the argmax tuple
        the_spans = dist.argmax[-1]
        argmax_spans, trees, lprobs = utils.extract_parses(
            the_spans, lengths.tolist(), inc=0)

        ll = dist.partition
        nll = -ll
        kl = torch.zeros_like(nll) if kl is None else kl
        return nll, kl, argmax_spans, trees, lprobs
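
When the parser is variational, as the returned kl term suggests, nll and kl combine into a negative ELBO on the caller's side. A minimal sketch of that use, assuming model exposes this method and that both tensors are per-sentence:

nll, kl, argmax_spans, trees, lprobs = model.forward_parser(captions, lengths)
loss = (nll + kl).mean()  # negative ELBO: reconstruction term plus KL regularizer
loss.backward()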
Example #5
    def train(self, niter=1):
        losses = []
        for i, ex in enumerate(self.train_iter):
            if i == niter:
                break
            self.opt.zero_grad()
            words, lengths = ex.word
            words = words.long()
            params = self.model(words.to(device=self.device).transpose(0, 1))
            dist = SentCFG(params, lengths=lengths)
            loss = dist.partition.mean()
            (-loss).backward()
            losses.append(loss.detach())
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 3.0)
            self.opt.step()
Example #6
    def forward_txt_parser(self, input, lengths):
        params, kl = self.txt_parser(input)
        dist = SentCFG(params, lengths=lengths)

        the_spans = dist.argmax[-1]
        argmax_spans, trees, lprobs = utils.extract_parses(the_spans,
                                                           lengths.tolist(),
                                                           inc=0)

        txt_outputs = self.txt_enc(input, lengths)

        ll, span_margs = dist.inside_im
        nll = -ll
        kl = torch.zeros_like(nll) if kl is None else kl
        return txt_outputs, nll, kl, span_margs, argmax_spans, trees, lprobs
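
Note that dist.inside_im is not part of stock torch-struct; it appears to be a project-specific extension that runs the inside algorithm once and returns the log-partition together with span marginals. A sketch of the closest equivalents in the unmodified library, assuming its standard API:

ll = dist.partition         # per-sentence log-likelihood
marginals = dist.marginals  # posterior marginals, mirroring the structure of the argmax tuple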
Example #7
def eval_trees(args):
    checkpoint = torch.load(args.model, map_location='cpu')
    opt = checkpoint['opt']
    use_mean = True
    # load vocabulary used by the model
    data_path = args.data_path
    #data_path = getattr(opt, "data_path", args.data_path)
    vocab_name = getattr(opt, "vocab_name", args.vocab_name)
    vocab = pickle.load(open(os.path.join(data_path, vocab_name), 'rb'))
    checkpoint['word2idx'] = vocab.word2idx
    opt.vocab_size = len(vocab)

    parser = checkpoint['model']
    parser = make_model(parser, opt)
    parser.cuda()
    parser.eval()

    batch_size = 5
    prefix = args.prefix
    print('Loading dataset', data_path + prefix + args.split)
    data_loader = data.eval_data_iter(data_path,
                                      prefix + args.split,
                                      vocab,
                                      batch_size=batch_size)

    # stats
    trees = list()
    n_word, n_sent = 0, 0
    per_label_f1 = defaultdict(list)
    by_length_f1 = defaultdict(list)
    sent_f1, corpus_f1 = [], [0., 0., 0.]
    total_ll, total_kl = 0., 0.

    pred_out = open(args.out_file, "w")

    for i, (captions, lengths, spans, labels, tags,
            ids) in enumerate(data_loader):
        if isinstance(lengths, list):
            lengths = torch.tensor(lengths).long()
        if torch.cuda.is_available():
            lengths = lengths.cuda()
            captions = captions.cuda()

        params, kl = parser(captions, lengths, use_mean=use_mean)
        dist = SentCFG(params, lengths=lengths)
        # accumulate per-batch NLL and KL for the perplexity statistics below
        total_ll += -dist.partition.sum().item()
        total_kl += kl.sum().item() if kl is not None else 0.

        arg_spans = dist.argmax[-1]
        argmax_spans, _, _ = utils.extract_parses(arg_spans,
                                                  lengths.tolist(),
                                                  inc=0)

        candidate_trees = list()
        bsize = captions.shape[0]
        n_word += (lengths + 1).sum().item()
        n_sent += bsize

        for b in range(bsize):
            max_len = lengths[b].item()
            # keep only non-trivial spans: drop single-word spans and, via
            # [:-1], the whole-sentence span
            pred = [(a[0], a[1]) for a in argmax_spans[b] if a[0] != a[1]]
            pred_set = set(pred[:-1])
            gold = [(l, r) for l, r in spans[b] if l != r]
            gold_set = set(gold[:-1])

            ccaption = captions[b].tolist()[:max_len]
            sent = [vocab.idx2word[int(word)] for word in ccaption]
            iitem = (sent, gold, labels[b], pred)
            json.dump(iitem, pred_out)
            pred_out.write("\n")

            tp, fp, fn = utils.get_stats(pred_set, gold_set)
            corpus_f1[0] += tp
            corpus_f1[1] += fp
            corpus_f1[2] += fn

            overlap = pred_set.intersection(gold_set)
            prec = float(len(overlap)) / (len(pred_set) + 1e-8)
            reca = float(len(overlap)) / (len(gold_set) + 1e-8)

            if len(gold_set) == 0:
                reca = 1.
                if len(pred_set) == 0:
                    prec = 1.
            f1 = 2 * prec * reca / (prec + reca + 1e-8)
            sent_f1.append(f1)

            word_tree = build_parse(argmax_spans[b], captions[b].tolist(),
                                    vocab)
            candidate_trees.append(word_tree)

            for j, gold_span in enumerate(gold[:-1]):
                label = labels[b][j]
                # strip co-indexing / function tags, e.g. "NP-SBJ" -> "NP"
                label = re.split("=|-", label)[0]
                per_label_f1.setdefault(label, [0., 0.])
                per_label_f1[label][0] += 1

                lspan = gold_span[1] - gold_span[0] + 1
                by_length_f1.setdefault(lspan, [0., 0.])
                by_length_f1[lspan][0] += 1

                if gold_span in pred_set:
                    per_label_f1[label][1] += 1
                    by_length_f1[lspan][1] += 1

        # restore the batch's original sentence order before logging
        appended_trees = ['' for _ in range(len(ids))]
        for j in range(len(ids)):
            appended_trees[ids[j] - min(ids)] = candidate_trees[j]
        trees.extend(appended_trees)

    tp, fp, fn = corpus_f1
    prec = tp / (tp + fp)
    recall = tp / (tp + fn)
    corpus_f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.
    sent_f1 = np.mean(np.array(sent_f1))
    recon_ppl = np.exp(total_ll / n_word)
    ppl_elbo = np.exp((total_ll + total_kl) / n_word)
    kl = total_kl / n_sent
    info = '\nReconPPL: {:.2f}, KL: {:.4f}, PPL (Upper Bound): {:.2f}\n' + \
           'Corpus F1: {:.2f}, Sentence F1: {:.2f}'
    info = info.format(recon_ppl, kl, ppl_elbo, corpus_f1 * 100, sent_f1 * 100)
    print(info)

    f1_ids = ["CF1", "SF1", "NP", "VP", "PP", "SBAR", "ADJP", "ADVP"]

    f1s = {"CF1": corpus_f1, "SF1": sent_f1}

    print("\nPER-LABEL-F1 (label, acc)\n")
    for k, v in per_label_f1.items():
        print("{}\t{:.4f} = {}/{}".format(k, v[1] / v[0], v[1], v[0]))
        f1s[k] = v[1] / v[0]

    f1s = ['{:.2f}'.format(float(f1s.get(x, 0.)) * 100) for x in f1_ids]
    print("\t".join(f1_ids))
    print("\t".join(f1s))

    acc = []

    print("\nPER-LENGTH-F1 (length, acc)\n")
    for k, v in sorted(by_length_f1.items()):
        print("{}\t{:.4f} = {}/{}".format(k, v[1] / v[0], v[1], v[0]))
        if v[0] >= 5:
            acc.append((str(k), '{:.2f}'.format(v[1] / v[0])))
    k = [x for x, _ in acc]
    v = [x for _, x in acc]
    print("\t".join(k))
    print("\t".join(v))

    pred_out.close()
    return trees
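
For reference, a minimal sketch of what utils.get_stats presumably computes over the unlabeled span sets (an assumption about the helper, not its actual source):

def get_stats(pred_set, gold_set):
    # hypothetical reimplementation: tp/fp/fn over unlabeled spans
    tp = len(pred_set & gold_set)  # spans both predicted and gold
    fp = len(pred_set - gold_set)  # predicted but not in gold
    fn = len(gold_set - pred_set)  # gold but not predicted
    return tp, fp, fn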