Example #1
0
 def compute_loss(self, batch: Dict[str, Any],
                  output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any], criterion) -> \
         Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
     """Compute the semantic-dependency loss for one batch.

     ``output`` is expected to be a 3-tuple of ``(arc_scores, rel_scores)``,
     a token mask, and a punctuation mask (the punctuation mask is not used
     here) — TODO confirm against the forward pass that produces it.

     Args:
         batch: Batch dict; ``batch['arc']`` and ``batch['rel_id']`` hold
             the gold arcs and relation ids.
         output: Model output tuple as described above.
         criterion: Loss criterion passed through to the parent parser.

     Returns:
         The loss as computed by ``BiaffineSemanticDependencyParser``.
     """
     score_pair, mask, _punct_mask = output
     arc_scores, rel_scores = score_pair
     # Call the parent implementation by class on purpose, so this
     # subclass's own compute_loss override is not re-entered.
     return BiaffineSemanticDependencyParser.compute_loss(
         self, arc_scores, rel_scores, batch['arc'], batch['rel_id'],
         mask, criterion, batch)
Example #2
0
 def compute_loss(self,
                  arc_scores,
                  rel_scores,
                  arcs,
                  rels,
                  mask,
                  criterion,
                  batch=None):
     """Sum the first-order and second-order parsing losses.

     Splits the packed scores into first- and second-order components,
     computes the first-order loss via the superclass and the
     second-order loss via ``BiaffineSemanticDependencyParser``, and
     returns their sum.

     Args:
         arc_scores: Packed arc scores for both orders.
         rel_scores: Packed relation scores for both orders.
         arcs: Gold arcs for the first-order loss.
         rels: Gold relation ids for the first-order loss.
         mask: Token mask; recomputed for the second-order component.
         criterion: Sequence of two criteria — ``criterion[0]`` for the
             first-order loss, ``criterion[1]`` for the second-order loss.
         batch: Batch dict. NOTE(review): despite the ``None`` default,
             ``batch`` is effectively required — ``batch['arc_2nd']`` and
             ``batch['rel_2nd_id']`` are read unconditionally below.

     Returns:
         The elementwise sum of the first- and second-order losses.
     """
     arc_scores_1st, arc_scores_2nd, rel_scores_1st, rel_scores_2nd = self.unpack_scores(
         arc_scores, rel_scores)
     loss_1st = super().compute_loss(arc_scores_1st, rel_scores_1st, arcs,
                                     rels, mask, criterion[0], batch)
     # The second-order component uses its own mask derived from its scores.
     mask = self.compute_mask(arc_scores_2nd, batch, mask)
     # noinspection PyCallByClass
     # Fix: renamed the misleading local `loss_2st` to `loss_2nd`.
     loss_2nd = BiaffineSemanticDependencyParser.compute_loss(
         self, arc_scores_2nd, rel_scores_2nd, batch['arc_2nd'],
         batch['rel_2nd_id'], mask, criterion[1], batch)
     return loss_1st + loss_2nd