def decode_output(self,
                  output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
                  mask: torch.BoolTensor,
                  batch: Dict[str, Any],
                  decoder,
                  **kwargs) -> Union[Dict[str, Any], Any]:
    """Unpack the model output and hand the scores to the graph decoder.

    Args:
        output: A 3-element iterable laid out as
            ``((arc_scores, rel_scores), mask, punct_mask)``.
        mask: Ignored; the mask carried inside ``output`` is used instead.
        batch: The current batch, forwarded unchanged to the decoder.
        decoder: Not used by this method.
        **kwargs: Not used by this method.

    Returns:
        Whatever ``BiaffineSemanticDependencyParser.decode`` returns for the
        unpacked scores, the mask taken from ``output``, and ``batch``.
    """
    # Outer triple first, then the score pair: same order as the nested
    # tuple-unpacking form. The punctuation mask is not needed for decoding.
    scores, output_mask, _punct_mask = output
    arc_scores, rel_scores = scores
    # Called through the class rather than ``self.decode``, so this always
    # runs ``BiaffineSemanticDependencyParser``'s own implementation.
    # noinspection PyCallByClass
    return BiaffineSemanticDependencyParser.decode(self, arc_scores, rel_scores, output_mask, batch)
def decode(self, arc_scores, rel_scores, mask, batch=None, predicting=None):
    """Decode first-order and second-order predictions and pair them up.

    If ``batch['outputs']`` already holds an ``(output_1st, output_2nd)``
    pair, that pair is reused. Otherwise the scores are split with
    ``self.unpack_scores``, the first-order part is decoded by the parent
    class and the second-order part by
    ``BiaffineSemanticDependencyParser.decode``. When the ``no_cycle`` config
    option is set, the second-order output is replaced by ``(None, graphs)``,
    where ``graphs`` holds one result of ``add_secondary_arcs_by_preds`` per
    sentence with its first entry dropped.

    Args:
        arc_scores: Arc scores, passed to ``self.unpack_scores``.
        rel_scores: Relation scores, passed to ``self.unpack_scores``.
        mask: Mask used for first-order decoding; second-order decoding uses
            the result of ``self.compute_mask``.
        batch: Mapping for the current batch. Despite the ``None`` default it
            is required, because ``batch.get('outputs', ...)`` is called
            unconditionally. ``batch['token']`` is read when ``no_cycle`` is
            enabled.
        predicting: Must be truthy when ``no_cycle`` is enabled.

    Returns:
        ``tuple(zip(output_1st, output_2nd))``: element *i* of the first-order
        output paired with element *i* of the second-order output.

    Raises:
        AssertionError: If ``no_cycle`` is enabled, nothing is cached in
            ``batch['outputs']`` and ``predicting`` is falsy.
    """
    output_1st, output_2nd = batch.get('outputs', (None, None))
    if output_1st is None:
        no_cycle = self.config.get('no_cycle')
        # Raised explicitly (instead of a bare ``assert``) so the check
        # survives ``python -O``, and checked up front so no decoding work is
        # wasted. The exception type is kept for existing callers.
        if no_cycle and not predicting:
            raise AssertionError('No cycle constraint for evaluation is not implemented yet. If you are '
                                 'interested, welcome to submit a pull request.')

        arc_scores_1st, arc_scores_2nd, rel_scores_1st, rel_scores_2nd = self.unpack_scores(
            arc_scores, rel_scores)
        output_1st = super().decode(arc_scores_1st, rel_scores_1st, mask)
        mask = self.compute_mask(arc_scores_2nd, batch, mask)
        # noinspection PyCallByClass
        output_2nd = BiaffineSemanticDependencyParser.decode(
            self, arc_scores_2nd, rel_scores_2nd, mask, batch)

        if no_cycle:
            root_rel_idx = self.vocabs['rel'].token_to_idx.get(
                self.config.get('root', None), None)
            arc_pred_1st, _rel_pred_1st, arc_pred_2nd, rel_pred_2nd = *output_1st, *output_2nd
            # Move everything to NumPy so sentences can be handled one by one.
            arc_scores_2nd_np = arc_scores_2nd.transpose(1, 2).cpu().detach().numpy()
            arc_pred_2nd_np = arc_pred_2nd.cpu().detach().numpy()
            rel_pred_2nd_np = rel_pred_2nd.cpu().detach().numpy()
            trees = arc_pred_1st.cpu().detach().numpy()

            graphs = []
            for sent_arc_scores, sent_arc_preds, sent_rel_preds, tree, tokens in zip(
                    arc_scores_2nd_np, arc_pred_2nd_np, rel_pred_2nd_np, trees, batch['token']):
                sent_len = len(tokens)
                graph = add_secondary_arcs_by_preds(
                    sent_arc_scores, sent_arc_preds[:sent_len, :sent_len], sent_rel_preds,
                    tree[:sent_len], root_rel_idx)
                graphs.append(graph[1:])  # Remove root
            # NOTE: the graphs are not written back into torch tensors, which
            # is why this path is restricted to prediction above.
            output_2nd = None, graphs
    return tuple(zip(output_1st, output_2nd))