示例#1
0
文件: dep.py 项目: zouyanjian/HanLP
 def update_metrics(
         self,
         batch: Dict[str, Any],
         output: Union[torch.Tensor, Dict[str, torch.Tensor],
                       Iterable[torch.Tensor], Any],
         prediction: Dict[str, Any],
         metric: Union[MetricDict, Metric]):
     """Forward one batch's predictions and gold labels to the parser metric.

     Delegates to ``BiaffineDependencyParser.update_metric``; nothing is
     returned from this method.

     Args:
         batch: Batch mapping. ``'arc'`` and ``'rel_id'`` are read as the
             gold values; ``'punct_mask'`` is optional and defaults to
             ``None`` when absent.
         output: Indexable model output; only ``output[1]`` is used here
             (presumably the token mask -- TODO confirm against the
             model's forward pass).
         prediction: Star-unpacked into leading positional arguments.
             NOTE(review): annotated as ``Dict``, but unpacking a dict
             would pass its keys; presumably callers hand in a sequence of
             (arc, rel) predictions -- confirm.
         metric: Metric object handed through to ``update_metric``.
     """
     gold_arcs = batch['arc']
     gold_rels = batch['rel_id']
     mask = output[1]
     punct_mask = batch.get('punct_mask', None)
     BiaffineDependencyParser.update_metric(
         self, *prediction, gold_arcs, gold_rels, mask, punct_mask, metric,
         batch)
示例#2
0
 def update_metrics(self, metrics, batch, outputs, mask):
     """Update the dependency metric and the per-task metrics for one batch.

     Args:
         metrics: Mapping of metric objects; the keys ``'deps'``,
             ``'lemmas'``, ``'upos'`` and ``'feats'`` are read.
         batch: Batch mapping supplying gold values under ``'arc'``,
             ``'rel_id'``, ``'lemma_id'``, ``'pos_id'`` and ``'feat_id'``,
             plus an optional ``'punct_mask'``.
         outputs: Model outputs supplying ``'arc_preds'``, ``'rel_preds'``
             and a ``'class_probabilities'`` mapping keyed by task name.
         mask: Passed to every metric update (presumably a token mask --
             TODO confirm).

     Returns:
         The same ``metrics`` object that was passed in.
     """
     predicted_arcs = outputs['arc_preds']
     predicted_rels = outputs['rel_preds']
     punct_mask = batch.get('punct_mask', None)
     BiaffineDependencyParser.update_metric(
         self, predicted_arcs, predicted_rels, batch['arc'],
         batch['rel_id'], mask, punct_mask, metrics['deps'], batch)

     # Each remaining task pairs a metric name with the batch key that
     # holds its gold values.
     task_to_gold_key = (
         ('lemmas', 'lemma_id'),
         ('upos', 'pos_id'),
         ('feats', 'feat_id'),
     )
     for task_name, gold_key in task_to_gold_key:
         scorer = metrics[task_name]
         task_output = outputs['class_probabilities'][task_name]
         gold_values = batch[gold_key]
         # Detach so the metric update does not hold on to the autograd graph.
         scorer(task_output.detach(), gold_values, mask=mask)
     return metrics