def _predict(self, loader):
    """Predict labelled semantic-dependency charts for every batch in ``loader``.

    Args:
        loader: iterable of batches; each batch unpacks as ``words, *feats``.
            ``words`` holds token indices and is padded with
            ``self.WORD.pad_index``. It may be 2-D, or 3-D when every token is
            split into sub-pieces (the same convention the variational-inference
            variant of this method handles; TODO confirm this model accepts it).

    Returns:
        dict with ``'labels'``: one ``CoNLL.build_relations`` result per
        sentence, built from a chart whose cells hold a label string or
        ``None``. When ``self.args.prob`` is set, the dict also has
        ``'probs'``: one CPU tensor per sentence of ``s_edge.softmax(-1)``,
        with row 0 dropped.
    """
    self.model.eval()
    charts, probs = [], []
    for words, *feats in progress_bar(loader):
        word_mask = words.ne(self.WORD.pad_index)
        # For sub-piece (3-D) input, a token is real if any of its pieces is
        # non-padding. Without this the pair mask below would become 4-D and
        # `lens` would be a nested list.
        if len(words.shape) >= 3:
            word_mask = word_mask.any(-1)
        # Pairwise mask: cell (b, i, j) is set only when positions i and j are
        # both non-padding.
        mask = word_mask.unsqueeze(1) & word_mask.unsqueeze(2)
        # Row 0 is excluded. Presumably position 0 is the pseudo-root, which
        # never receives incoming relations.
        mask[:, 0] = 0
        # Row 1 counts every non-padding position, root included.
        lens = mask[:, 1].sum(-1).tolist()
        s_edge, s_label = self.model(words, feats)
        edge_preds, label_preds = self.model.decode(s_edge, s_label)
        # Keep a label only where an edge is predicted and the cell is inside
        # the mask; -1 marks "no relation".
        chart_preds = label_preds.masked_fill(~(edge_preds.gt(0) & mask), -1)
        charts.extend(chart[1:i, :i].tolist() for i, chart in zip(lens, chart_preds.unbind()))
        if self.args.prob:
            edge_probs = s_edge.softmax(-1)
            probs.extend(prob[1:i, :i].cpu() for i, prob in zip(lens, edge_probs.unbind()))
    vocab = self.LABEL.vocab
    # Map label indices back to label strings (-1 -> None) before building relations.
    charts = [
        CoNLL.build_relations([[vocab[i] if i >= 0 else None for i in row] for row in chart])
        for chart in charts
    ]
    preds = {'labels': charts}
    if self.args.prob:
        preds['probs'] = probs
    return preds
def _predict(self, loader):
    """Decode labelled semantic-dependency charts with variational inference.

    Each batch from ``loader`` unpacks as ``words, *feats``. Padding is
    detected with ``self.args.pad_index``. When ``words`` has three or more
    dimensions, a position counts as real if any entry along the last axis is
    non-padding.

    Returns:
        dict with two keys:
        - ``'labels'``: one ``CoNLL.build_relations`` result per sentence.
        - ``'probs'``: when ``self.args.prob`` is set, a list of per-sentence
          CPU tensors taken from the post-inference edge scores (presumably
          marginal probabilities, since no softmax is applied). Otherwise
          ``None``.
    """
    self.model.eval()
    want_probs = self.args.prob
    label_charts = []
    edge_scores = [] if want_probs else None
    for words, *feats in progress_bar(loader):
        token_mask = words.ne(self.args.pad_index)
        if len(words.shape) >= 3:
            token_mask = token_mask.any(-1)
        # Pairwise mask over (i, j) positions; row 0 is switched off entirely.
        pair_mask = token_mask.unsqueeze(1) & token_mask.unsqueeze(2)
        pair_mask[:, 0] = 0
        # Row 1 counts all non-padding positions, root included.
        lengths = mask_lengths = pair_mask[:, 1].sum(-1).tolist()
        s_edge, s_sib, s_cop, s_grd, s_label = self.model(words, feats)
        # Second-order scores are folded into refined edge scores by inference.
        s_edge = self.model.inference((s_edge, s_sib, s_cop, s_grd), pair_mask)
        decoded = self.model.decode(s_edge, s_label).masked_fill(~pair_mask, -1)
        for n, chart in zip(lengths, decoded):
            label_charts.append(chart[1:n, :n].tolist())
        if want_probs:
            for n, score in zip(mask_lengths, s_edge.unbind()):
                edge_scores.append(score[1:n, :n].cpu())
    vocab = self.LABEL.vocab
    named_charts = [
        [[vocab[idx] if idx >= 0 else None for idx in row] for row in chart]
        for chart in label_charts
    ]
    return {
        'labels': [CoNLL.build_relations(chart) for chart in named_charts],
        'probs': edge_scores,
    }
def _predict(self, loader):
    """Predict labelled semantic-dependency charts for each batch in ``loader``.

    Each batch unpacks as exactly ``words, feats``. Padding is detected with
    ``self.WORD.pad_index``. Chart decoding is delegated to
    ``self.model.decode(s_edge, s_label, mask)``, and its output is taken
    as-is (entries < 0 become ``None``).

    Returns:
        dict with ``'labels'``: one ``CoNLL.build_relations`` result per
        sentence. When ``self.args.prob`` is set, the dict also has
        ``'probs'``: per-sentence CPU tensors of ``s_edge.softmax(-1)``
        cropped to ``[:n, :n]``, where ``n`` is the count taken from row 0 of
        the pair mask.
    """
    self.model.eval()
    want_probs = self.args.prob
    decoded_charts = []
    sentence_probs = []
    for words, feats in progress_bar(loader):
        pad_mask = words.ne(self.WORD.pad_index)
        pair_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2)
        # Row 0 is left intact here, so it counts every non-padding position.
        lengths = pair_mask[:, 0].sum(-1).tolist()
        s_edge, s_label = self.model(words, feats)
        decoded_charts.extend(self.model.decode(s_edge, s_label, pair_mask))
        if want_probs:
            per_sentence = s_edge.softmax(-1).unbind()
            for n, prob in zip(lengths, per_sentence):
                sentence_probs.append(prob[:n, :n].cpu())
    vocab = self.LABEL.vocab
    # First map every index to a label string (or None), then build relations.
    named_charts = []
    for chart in decoded_charts:
        named_charts.append([[vocab[idx] if idx >= 0 else None for idx in row] for row in chart])
    preds = {'labels': list(map(CoNLL.build_relations, named_charts))}
    if want_probs:
        preds['probs'] = sentence_probs
    return preds