def _evaluate(self, loader):
    """Score the model's predicted trees against the gold trees in ``loader``.

    Args:
        loader: sized iterable of batches; each batch unpacks as
            ``(words, *feats, trees, charts)``.

    Returns:
        A ``(loss, metric)`` pair: the sum of the per-batch loss values
        divided by ``len(loader)``, and the ``SpanMetric`` instance that was
        updated once per batch.

    NOTE(review): no gradient-disabling context is visible in this block --
    confirm that a decorator or the caller provides one.
    """
    self.model.eval()

    loss_sum = 0
    metric = SpanMetric()

    for batch in loader:
        words, *feats, gold_trees, charts = batch

        # Non-padding positions, with the first column dropped.
        token_mask = words.ne(self.args.pad_index)[:, 1:]
        if len(words.shape) >= 3:
            # words has a trailing dimension: a position is kept when any
            # of its entries is non-padding.
            token_mask = token_mask.any(-1)
        # Pairwise validity of (i, j), limited to the strict upper triangle.
        span_mask = (token_mask.unsqueeze(1) & token_mask.unsqueeze(2)).triu_(1)

        s_span, s_pair, s_label = self.model(words, feats)
        loss, s_span = self.model.loss(s_span, s_pair, s_label, charts, span_mask)
        decoded = self.model.decode(s_span, s_label, span_mask)

        # since the evaluation relies on terminals,
        # the tree should be first built and then factorized
        pred_trees = []
        for gold, chart in zip(gold_trees, decoded):
            labeled = [(i, j, self.CHART.vocab[label]) for i, j, label in chart]
            pred_trees.append(Tree.build(gold, labeled))

        loss_sum += loss.item()
        metric(
            [Tree.factorize(t, self.args.delete, self.args.equal) for t in pred_trees],
            [Tree.factorize(t, self.args.delete, self.args.equal) for t in gold_trees],
        )

    return loss_sum / len(loader), metric
def _predict(self, loader):
    """Decode one tree per sentence for every batch in ``loader``.

    Args:
        loader: iterable of ``(words, feats, trees)`` batches; ``words`` must
            be 2-dimensional because its shape is unpacked into two values.

    Returns:
        A dict with key ``'trees'`` (the built trees, in loader order) and,
        only when ``self.args.prob`` is truthy, key ``'probs'`` holding one
        CPU tensor slice of the span scores per sentence.
    """
    self.model.eval()

    out = {'trees': []}
    span_probs = []

    for words, feats, trees in progress_bar(loader):
        _, seq_len = words.shape
        width = seq_len - 1

        # Per-row count of non-padding entries, minus one.
        lens = words.ne(self.args.pad_index).sum(1) - 1
        # (batch, 1, width): True where the column index is below the row's length.
        in_range = lens.new_tensor(range(width)) < lens.view(-1, 1, 1)
        # Broadcast against the strict upper triangle -> (batch, width, width).
        mask = in_range & in_range.new_ones(width, width).triu_(1)

        s_span, s_label = self.model(words, feats)
        if self.args.mbr:
            s_span = self.model.crf(s_span, mask, mbr=True)

        decoded = self.model.decode(s_span, s_label, mask)
        for tree, chart in zip(trees, decoded):
            labeled = [(i, j, self.CHART.vocab[label]) for i, j, label in chart]
            out['trees'].append(Tree.build(tree, labeled))

        if self.args.prob:
            for n, prob in zip(lens, s_span.unbind()):
                # Keep only the rows/columns covered by this sentence's length.
                span_probs.append(prob[:n - 1, 1:n].cpu())

    if self.args.prob:
        out['probs'] = span_probs
    return out
def _predict(self, loader):
    """Decode one tree per sentence for every batch in ``loader``.

    Args:
        loader: iterable of batches; each batch unpacks as
            ``(words, *feats, trees)``.

    Returns:
        A dict with keys ``'trees'`` (the built trees, in loader order) and
        ``'probs'``. ``'probs'`` is a list with one CPU tensor slice of the
        span scores per sentence when ``self.args.prob`` is truthy, and
        ``None`` otherwise.

    NOTE(review): no gradient-disabling context is visible in this block --
    confirm that a decorator or the caller provides one.
    """
    self.model.eval()

    preds = {'trees': [], 'probs': [] if self.args.prob else None}

    for batch in progress_bar(loader):
        words, *feats, trees = batch

        # Non-padding positions, with the first column dropped.
        word_mask = words.ne(self.args.pad_index)[:, 1:]
        # With a trailing dimension on words, a position is kept when any of
        # its entries is non-padding.
        mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
        # Pairwise validity of (i, j), limited to the strict upper triangle.
        mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
        # Row 0 of the strict upper triangle: the number of kept positions
        # after the first one. Computed once and shared by the CRF and the
        # probability slicing below.
        lens = mask[:, 0].sum(-1)

        s_span, s_label = self.model(words, feats)
        if self.args.mbr:
            # Decode from the CRF marginals instead of the raw span scores.
            s_span = ConstituencyCRF(s_span, lens).marginals

        chart_preds = self.model.decode(s_span, s_label, mask)
        preds['trees'].extend([
            Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart])
            for tree, chart in zip(trees, chart_preds)
        ])

        if self.args.prob:
            preds['probs'].extend(
                [prob[:i - 1, 1:i].cpu() for i, prob in zip(lens, s_span)]
            )

    return preds
def _evaluate(self, loader):
    """Score the model's predicted trees against the gold trees in ``loader``.

    Args:
        loader: sized iterable of ``(words, feats, trees, (spans, labels))``
            batches; ``words`` must be 2-dimensional because its shape is
            unpacked into two values.

    Returns:
        A ``(loss, metric)`` pair: the sum of the per-batch loss values
        divided by ``len(loader)``, and the ``BracketMetric`` instance that
        was updated once per batch.
    """
    self.model.eval()

    loss_sum = 0
    metric = BracketMetric()

    for batch in loader:
        words, feats, gold_trees, (spans, labels) = batch
        _, seq_len = words.shape
        width = seq_len - 1

        # Per-row count of non-padding entries, minus one.
        lens = words.ne(self.args.pad_index).sum(1) - 1
        # (batch, 1, width): True where the column index is below the row's length.
        in_range = lens.new_tensor(range(width)) < lens.view(-1, 1, 1)
        # Broadcast against the strict upper triangle -> (batch, width, width).
        mask = in_range & in_range.new_ones(width, width).triu_(1)

        s_span, s_label = self.model(words, feats)
        loss, s_span = self.model.loss(
            s_span, s_label, spans, labels, mask, self.args.mbr
        )
        decoded = self.model.decode(s_span, s_label, mask)

        # since the evaluation relies on terminals,
        # the tree should be first built and then factorized
        pred_trees = []
        for gold, chart in zip(gold_trees, decoded):
            labeled = [(i, j, self.CHART.vocab[label]) for i, j, label in chart]
            pred_trees.append(Tree.build(gold, labeled))

        loss_sum += loss.item()
        metric(
            [Tree.factorize(t, self.args.delete, self.args.equal) for t in pred_trees],
            [Tree.factorize(t, self.args.delete, self.args.equal) for t in gold_trees],
        )

    return loss_sum / len(loader), metric