示例#1
0
 def parse_sample(self, samples, i):
     """Decode sample i of a batch into (root, nodes, tokens, token levels)."""
     # Map index tensors back to token strings, then build the tree first;
     # BPE is undone only after the tree traversal.
     raw = [self._i2tok[idx.item()] for idx in samples[i]]
     root = build_tree(raw)
     inorder, nodes = tree_to_text(root)
     # Drop structural markers before joining for un-BPE string surgery.
     kept = (tok for tok in inorder if tok not in {'<p>', '<s>', '</s>', '<end>'})
     joined = " ".join(kept)
     # NOTE(review): '<end>' is filtered out above, so the "@@ <end>" pattern
     # presumably never occurs at this point — confirm against other decode paths.
     cleaned = joined.replace("@@ <end>", " <end>").replace("@@ ", "")
     out_tokens = cleaned.split()
     levels = [(node.value, node.level) for node in nodes]
     return root, nodes, out_tokens, levels
示例#2
0
def print_samples(samples, data, n=min(args.batch_size, 5)):
    """Print up to n decoded samples: tokens, (token, level) pairs, and the tree."""
    for idx in range(n):
        sample_tokens = inds2toks(i2tok, samples[idx].cpu().tolist())
        tree_root = build_tree(sample_tokens)
        text_tokens, tree_nodes = tree_to_text(tree_root)
        level_pairs = [(nd.value, nd.level) for nd in tree_nodes]
        print(' '.join(text_tokens))
        print(' '.join(str(pair) for pair in level_pairs))
        print(print_tree(tree_root))
        print()
示例#3
0
def print_samples(samples, data, n=None):
    """Print up to five predicted/ground-truth sentence pairs and the parse tree.

    Args:
        samples: batch of predicted index tensors; samples[i] is one sample.
        data: batch tuple whose first element holds ground-truth index tensors.
        n: number of samples to print; defaults to 5 and is capped at 5.
    """
    # Bug fix: min(None, 5) raises TypeError, so calling with the default
    # n=None crashed. Treat None as "print up to 5".
    n = 5 if n is None else min(n, 5)
    for i in range(n):
        tokensa = [x.split() for x in TRG.reverse(samples[i].unsqueeze(0), unbpe=True)][0]
        root = build_tree(tokensa)
        tokens, nodes = tree_to_text(root)
        gt_tokens = [x.split() for x in TRG.reverse(data[0][i:i + 1].cpu(), unbpe=True)][0]
        print('ACTUAL:\t%s' % ' '.join(gt_tokens))
        print('PRED:\t%s' % ' '.join(tokens))
        print(print_tree(root))
        print()
示例#4
0
def print_samples(samples, data, n=min(args.batch_size, 5)):
    """Print up to n samples: ACTUAL/PRED token lines, token levels, and the tree."""
    for idx in range(n):
        pred_tokens = inds2toks(i2tok, samples[idx].cpu().tolist())
        tree_root = build_tree(pred_tokens)
        pred_tokens, tree_nodes = tree_to_text(tree_root)
        level_pairs = [(nd.value, nd.level) for nd in tree_nodes]
        # Strip end-of-sentence and padding indices from the ground truth.
        end_idx, pad_idx = tok2i['</s>'], tok2i['<p>']
        gt_inds = [t for t in data[0][idx].cpu().tolist() if t != end_idx and t != pad_idx]
        gt_tokens = inds2toks(i2tok, gt_inds)
        print('ACTUAL:\t%s' % ' '.join(gt_tokens))
        print('PRED:\t%s' % ' '.join(pred_tokens))
        print(' '.join(str(pair) for pair in level_pairs))
        print(print_tree(tree_root))
        print()
示例#5
0
    def update(self, scores, samples, batch):
        """Accumulate per-sentence evaluation metrics over one batch.

        Sums exact match, precision/recall/F1, sentence BLEU statistics, and
        tree-shape scores into self._metrics; 'n_sent'/'n_batch' serve as the
        normalizers for later averaging.
        """
        batch_size = batch[0].size(0)
        for idx in range(batch_size):
            pred_tokens = inds2toks(self._i2tok, samples[idx].cpu().tolist())
            tree_root = build_tree(pred_tokens)
            pred_tokens, tree_nodes = tree_to_text(tree_root)
            # Ground truth: strip end-of-sentence and padding indices.
            end_idx = self._tok2i['</s>']
            pad_idx = self._tok2i['<p>']
            gt_inds = [t for t in batch[0][idx].cpu().tolist()
                       if t != end_idx and t != pad_idx]
            gt_tokens = inds2toks(self._i2tok, gt_inds)

            self._metrics['em'] += self._exact_match(pred_tokens, gt_tokens)
            prec, rec, f1 = self._prec_recall_f1_score(pred_tokens, gt_tokens)

            # BLEU: sum the per-sentence statistics; normalized elsewhere.
            bleu = self._sentence_bleu(pred_tokens, gt_tokens)
            self._metrics['bleu'] += bleu.score
            for order in range(4):
                self._metrics['precisions-%d' % (order + 1)] += bleu.precisions[order]
            self._metrics['brevity_penalty'] += bleu.bp
            self._metrics['sys_len'] += bleu.sys_len
            self._metrics['ref_len'] += bleu.ref_len

            # Sentence-level averages scores over sentences.
            self._metrics['precision'] += prec
            self._metrics['recall'] += rec
            self._metrics['f1'] += f1

            self._metrics['depth_score'] += self._depth_score(tree_nodes)
            self._metrics['avg_span'] += self._avg_span(tree_nodes)

            # Normalizer for the above summed metrics.
            self._metrics['n_sent'] += 1.0

            if self.bleu_to_file:
                self._save_tokens(pred_tokens, gt_tokens)

        self._metrics['n_batch'] += 1.0
示例#6
0
def eval_single(model, sentence, TRG=None):
    """Decode a single whitespace-tokenized sentence and return tree-structured output.

    Returns:
        (output dict, input tensor, scores, preds).
    """
    model.eval()
    words = sentence.split()  # NOTE(wellecks): split tokenizer
    unk_idx = model.tok2i['<unk>']
    index_seq = [model.tok2i.get(w, unk_idx) for w in words]
    index_seq.append(model.tok2i['</s>'])
    x = th.tensor([index_seq], dtype=th.long, device=model.device)
    scores, preds = model.forward(x)
    raw_tokens = data.inds2toks(model.i2tok, preds.cpu().tolist()[0])
    model.train()  # restore training mode after inference

    root = tree_util.build_tree(raw_tokens)
    inorder_tokens, nodes = tree_util.tree_to_text(root)

    level_pairs = [(nd.value, nd.level) for nd in nodes]
    output = {
        'raw_tokens': raw_tokens,
        'tree_string': tree_util.print_tree(root),
        'inorder_tokens': inorder_tokens,
        'genorder_tokens': common_eval.get_genorder_tokens(raw_tokens),
        'token_levels': level_pairs,
        'gt_tokens': words
    }
    return output, x, scores, preds
示例#7
0
def eval_single(model, sentence, TRG=None):
    """Decode one batch-of-one example and return tree-structured outputs.

    Args:
        model: model exposing forward, tok2i, i2tok.
        sentence: batch object with .src and .trg index tensors.
        TRG: optional field used to reverse BPE; when None, indices are
            mapped straight back to tokens via model.i2tok.

    Returns:
        (output dict, scores, preds).
    """
    scores, preds = model.forward(xs=sentence.src, oracle=None, num_samples=1)
    # NOTE(review): this line dereferences TRG unconditionally, so the TRG-less
    # fallback below is reachable only if callers always pass TRG here — confirm.
    raw_tokens = [x.split() for x in TRG.reverse(preds, unbpe=True)][0]
    model.train()

    root = tree_util.build_tree(raw_tokens)
    inorder_tokens, nodes = tree_util.tree_to_text(root)
    tokens_levels = [(node.value, node.level) for node in nodes]

    if TRG is not None:
        trg = [x.split() for x in TRG.reverse(sentence.trg[0], unbpe=True)][0]
        src = [x.split() for x in TRG.reverse(sentence.src[0], unbpe=True)][0]
    else:
        # Strip structural markers from the raw index sequences.
        special = (model.tok2i['</s>'], model.tok2i['<p>'], model.tok2i['<s>'])
        gt_inds = [x for x in sentence.trg[0][0].cpu().tolist() if x not in special]
        # Bug fix: this branch runs exactly when TRG is None, so the original
        # TRG.inds2toks(...) always raised AttributeError. Map indices back to
        # tokens directly through the model's vocabulary instead.
        trg = [model.i2tok[x] for x in gt_inds]

        src_inds = [x for x in sentence.src[0][0].cpu().tolist() if x not in special]
        src = [model.i2tok[x] for x in src_inds]

    output = {
        'raw_tokens': raw_tokens,
        'tree_string': tree_util.print_tree(root),
        'inorder_tokens': inorder_tokens,
        'genorder_tokens': common_eval.get_genorder_tokens(raw_tokens),
        'token_levels': tokens_levels,
        'gt_tokens': trg,
        'src_tokens': src
    }
    return output, scores, preds
示例#8
0
def convert_samples(samples, ground_truth, i2tok, tok2i):
    """Convert index sequences into per-sample dicts of token and tree views.

    Args:
        samples: list of index lists, or a tensor convertible via .tolist().
        ground_truth: optional tensor of reference indices, aligned with samples.
        i2tok: index-to-token mapping.
        tok2i: token-to-index mapping (for filtering '</s>' / '<p>').
    """
    if not isinstance(samples, list):
        samples = samples.clone().cpu().tolist()
    if ground_truth is not None:
        ground_truth = ground_truth.clone().cpu().tolist()

    converted = []
    for idx, sample_inds in enumerate(samples):
        raw_tokens = data.inds2toks(i2tok, sample_inds)
        root = tree_util.build_tree(raw_tokens)
        inorder_tokens, tree_nodes = tree_util.tree_to_text(root)
        entry = {
            'raw_tokens': raw_tokens,
            'tree_string': tree_util.print_tree(root),
            'inorder_tokens': inorder_tokens,
            'genorder_tokens': get_genorder_tokens(raw_tokens),
            'token_levels': [(nd.value, nd.level) for nd in tree_nodes],
        }
        if ground_truth is not None:
            # Drop end-of-sentence and padding indices from the reference.
            end_idx, pad_idx = tok2i['</s>'], tok2i['<p>']
            gt_inds = [t for t in ground_truth[idx] if t != end_idx and t != pad_idx]
            entry['gt_tokens'] = data.inds2toks(i2tok, gt_inds)
        converted.append(entry)
    return converted