Beispiel #1
0
    def _evaluate(self, loader):
        """Compute the average loss and span metric of the model on ``loader``.

        Each batch is scored by the model and its loss is taken against the
        gold ``charts``. The decoded span charts are turned into trees, which
        a ``SpanMetric`` compares with the gold ``trees``.

        Args:
            loader: sized iterable of batches that unpack as
                ``(words, *feats, trees, charts)``.

        Returns:
            ``(loss, metric)`` where ``loss`` is the sum of the per-batch
            losses divided by ``len(loader)`` and ``metric`` is the
            accumulated ``SpanMetric``.
        """
        self.model.eval()

        loss_sum, metric = 0, SpanMetric()

        for batch in loader:
            words, *feats, trees, charts = batch
            # drop the first position (presumably BOS); for 3-D ``words`` a
            # position is kept when any of its last-axis entries is non-pad
            token_mask = words.ne(self.args.pad_index)[:, 1:]
            if len(words.shape) >= 3:
                token_mask = token_mask.any(-1)
            # pairwise mask restricted to the strict upper triangle (i < j)
            span_mask = token_mask.unsqueeze(1) & token_mask.unsqueeze(2)
            span_mask = span_mask.triu_(1)
            s_span, s_pair, s_label = self.model(words, feats)
            loss, s_span = self.model.loss(s_span, s_pair, s_label, charts,
                                           span_mask)
            chart_preds = self.model.decode(s_span, s_label, span_mask)
            # evaluation is defined on terminals, so the predicted charts are
            # first built into full trees and only then factorized
            predicted = []
            for gold, chart in zip(trees, chart_preds):
                labelled = [(i, j, self.CHART.vocab[label])
                            for i, j, label in chart]
                predicted.append(Tree.build(gold, labelled))
            loss_sum += loss.item()
            pred_factors = [
                Tree.factorize(t, self.args.delete, self.args.equal)
                for t in predicted
            ]
            gold_factors = [
                Tree.factorize(t, self.args.delete, self.args.equal)
                for t in trees
            ]
            metric(pred_factors, gold_factors)
        loss_sum /= len(loader)

        return loss_sum, metric
Beispiel #2
0
    def _predict(self, loader):
        """Predict a tree for every sentence yielded by ``loader``.

        Args:
            loader: iterable (wrapped in ``progress_bar``) of
                ``(words, feats, trees)`` batches; ``words`` must be 2-D.

        Returns:
            A dict whose ``'trees'`` entry lists the predicted trees in loader
            order. When ``self.args.prob`` is set it also holds ``'probs'``:
            one CPU tensor per sentence, sliced from ``s_span`` (the CRF
            output when ``self.args.mbr`` is set, raw scores otherwise).
        """
        self.model.eval()

        preds = {'trees': []}
        probs = []

        for words, feats, trees in progress_bar(loader):
            # the unpack also enforces that ``words`` is exactly 2-D
            _, seq_len = words.shape
            n_fences = seq_len - 1
            lens = words.ne(self.args.pad_index).sum(1) - 1
            # columns j are valid while j < lens; rows are bounded by the
            # strict upper triangle (i < j)
            fences = lens.new_tensor(range(n_fences))
            valid = fences < lens.view(-1, 1, 1)
            span_mask = valid & valid.new_ones(n_fences, n_fences).triu_(1)
            s_span, s_label = self.model(words, feats)
            if self.args.mbr:
                s_span = self.model.crf(s_span, span_mask, mbr=True)
            chart_preds = self.model.decode(s_span, s_label, span_mask)
            for gold, chart in zip(trees, chart_preds):
                labelled = [(i, j, self.CHART.vocab[label])
                            for i, j, label in chart]
                preds['trees'].append(Tree.build(gold, labelled))
            if self.args.prob:
                for n, prob in zip(lens, s_span.unbind()):
                    probs.append(prob[:n - 1, 1:n].cpu())
        if self.args.prob:
            preds['probs'] = probs

        return preds
Beispiel #3
0
    def _predict(self, loader):
        """Predict a constituency tree for every sentence in ``loader``.

        Args:
            loader: iterable (wrapped in ``progress_bar``) of batches that
                unpack as ``(words, *feats, trees)``.

        Returns:
            A dict with two entries:

            - ``'trees'``: the predicted trees, in loader order.
            - ``'probs'``: ``None`` unless ``self.args.prob`` is set. In that
              case it is a list holding, per sentence, the CPU slice of
              ``s_span`` covering every span ``(i, j)`` with
              ``0 <= i < j <= lens``. ``s_span`` holds CRF marginals when
              ``self.args.mbr`` is set, raw scores otherwise.
        """
        self.model.eval()

        preds = {'trees': [], 'probs': [] if self.args.prob else None}
        for batch in progress_bar(loader):
            words, *feats, trees = batch
            # drop the first position; for 3-D ``words`` a position is kept
            # when any of its last-axis entries is non-pad
            word_mask = words.ne(self.args.pad_index)[:, 1:]
            mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
            # pairwise mask restricted to the strict upper triangle (i < j)
            mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
            # number of valid end points j paired with start 0; with trailing
            # padding this is also the largest valid fence index
            lens = mask[:, 0].sum(-1)
            s_span, s_label = self.model(words, feats)
            if self.args.mbr:
                # replace raw span scores by CRF marginals
                s_span = ConstituencyCRF(s_span, lens).marginals
            chart_preds = self.model.decode(s_span, s_label, mask)
            preds['trees'].extend([
                Tree.build(tree, [(i, j, self.CHART.vocab[label])
                                  for i, j, label in chart])
                for tree, chart in zip(trees, chart_preds)
            ])
            if self.args.prob:
                # valid end points run up to ``n`` inclusive (the mask given
                # to ``decode`` keeps column n), so take rows 0..n-1 and
                # columns 1..n; slicing ``[:n - 1, 1:n]`` would drop every
                # span ending at the last position, incl. the full sentence
                preds['probs'].extend(
                    [prob[:n, 1:n + 1].cpu() for n, prob in zip(lens, s_span)])

        return preds
Beispiel #4
0
    def _evaluate(self, loader):
        """Compute the average loss and bracket metric of the model on ``loader``.

        Each batch is scored by the model and its loss is taken against the
        gold ``spans``/``labels``, with ``self.args.mbr`` forwarded to the
        loss. The decoded charts are turned into trees, which a
        ``BracketMetric`` compares with the gold ``trees``.

        Args:
            loader: sized iterable of batches that unpack as
                ``(words, feats, trees, (spans, labels))``; ``words`` must be
                2-D.

        Returns:
            ``(loss, metric)`` where ``loss`` is the sum of the per-batch
            losses divided by ``len(loader)`` and ``metric`` is the
            accumulated ``BracketMetric``.
        """
        self.model.eval()

        loss_sum, metric = 0, BracketMetric()

        for batch in loader:
            words, feats, trees, (spans, labels) = batch
            # the unpack also enforces that ``words`` is exactly 2-D
            _, seq_len = words.shape
            n_fences = seq_len - 1
            lens = words.ne(self.args.pad_index).sum(1) - 1
            # columns j are valid while j < lens; rows are bounded by the
            # strict upper triangle (i < j)
            fences = lens.new_tensor(range(n_fences))
            valid = fences < lens.view(-1, 1, 1)
            span_mask = valid & valid.new_ones(n_fences, n_fences).triu_(1)
            s_span, s_label = self.model(words, feats)
            loss, s_span = self.model.loss(s_span, s_label, spans, labels,
                                           span_mask, self.args.mbr)
            chart_preds = self.model.decode(s_span, s_label, span_mask)
            # evaluation is defined on terminals, so the predicted charts are
            # first built into full trees and only then factorized
            predicted = []
            for gold, chart in zip(trees, chart_preds):
                labelled = [(i, j, self.CHART.vocab[label])
                            for i, j, label in chart]
                predicted.append(Tree.build(gold, labelled))
            loss_sum += loss.item()
            pred_factors = [
                Tree.factorize(t, self.args.delete, self.args.equal)
                for t in predicted
            ]
            gold_factors = [
                Tree.factorize(t, self.args.delete, self.args.equal)
                for t in trees
            ]
            metric(pred_factors, gold_factors)
        loss_sum /= len(loader)

        return loss_sum, metric