Example 1
def build_dataloader(self,
                     data,
                     transform: Callable = None,
                     training=False,
                     device=None,
                     logger: logging.Logger = None,
                     gradient_accumulation=1,
                     **kwargs) -> DataLoader:
    # Build the dataset through the parent parser, then drop any cached samples.
    dataset = BiaffineSecondaryParser.build_dataset(self, data, transform)
    dataset.purge_cache()
    # Only (re)build vocabularies while they are still mutable, i.e. during training.
    if self.vocabs.mutable:
        BiaffineSecondaryParser.build_vocabs(self,
                                             dataset,
                                             logger,
                                             transformer=True)
    # Batch by length; pad first-order arcs with 0 and second-order arcs with False.
    return PadSequenceDataLoader(batch_sampler=self.sampler_builder.build(
        self.compute_lens(data, dataset),
        shuffle=training,
        gradient_accumulation=gradient_accumulation),
                                 device=device,
                                 dataset=dataset,
                                 pad={
                                     'arc': 0,
                                     'arc_2nd': False
                                 })
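
For context, the pad mapping above tells PadSequenceDataLoader which fill value to use for each field when batching variable-length sequences. The following is a minimal, plain-Python sketch of that padding idea only; pad_batch and the sample data are invented for illustration and are not part of HanLP:

from typing import Any, Dict, List

def pad_batch(samples: List[Dict[str, List[Any]]],
              pad: Dict[str, Any]) -> Dict[str, List[List[Any]]]:
    # Pad every listed field to the length of the longest sequence in the batch.
    out: Dict[str, List[List[Any]]] = {}
    for field, fill in pad.items():
        seqs = [sample[field] for sample in samples]
        width = max(len(seq) for seq in seqs)
        out[field] = [seq + [fill] * (width - len(seq)) for seq in seqs]
    return out

batch = [{'arc': [2, 0], 'arc_2nd': [True, False]},
         {'arc': [2, 0, 2], 'arc_2nd': [False, True, True]}]
print(pad_batch(batch, pad={'arc': 0, 'arc_2nd': False}))
# {'arc': [[2, 0, 0], [2, 0, 2]], 'arc_2nd': [[True, False, False], [False, True, True]]}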
Example 2
def update_metrics(self, batch: Dict[str, Any],
                   output: Union[torch.Tensor, Dict[str, torch.Tensor],
                                 Iterable[torch.Tensor], Any],
                   prediction: Dict[str, Any],
                   metric: Union[MetricDict, Metric]):
    # Unpack the predicted arcs/relations and score them against the gold
    # annotations, excluding punctuation via batch['punct_mask'].
    BiaffineSecondaryParser.update_metric(self, *prediction, batch['arc'],
                                          batch['rel_id'], output[1],
                                          batch['punct_mask'], metric,
                                          batch)
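
As a rough illustration of what batch['punct_mask'] is for: attachment metrics conventionally skip punctuation tokens. A self-contained sketch of that idea (toy tensors, not HanLP's metric implementation):

import torch

# Toy gold and predicted head indices for a 5-token sentence (values invented).
arc_gold = torch.tensor([2, 0, 2, 5, 2])
arc_pred = torch.tensor([2, 0, 2, 2, 2])
# True for tokens that count towards the score, False for punctuation.
punct_mask = torch.tensor([True, True, True, True, False])

correct = (arc_pred == arc_gold) & punct_mask
uas = correct.sum().item() / punct_mask.sum().item()
print(f'UAS ignoring punctuation: {uas:.2f}')  # 0.75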
Example 3
def compute_loss(self, batch: Dict[str, Any],
                 output: Union[torch.Tensor, Dict[str, torch.Tensor],
                               Iterable[torch.Tensor], Any],
                 criterion) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
    # output[0] packs the score tensors, output[1] the padding mask.
    return BiaffineSecondaryParser.compute_loss(self, *output[0],
                                                batch['arc'], batch['rel_id'],
                                                output[1], criterion, batch)
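
The *output[0] spread above assumes the forward pass packed its score tensors into output[0] and a mask into output[1]; the exact contents are not shown here, so the layout below is a guess, with placeholder strings standing in for tensors, purely to show how the arguments line up:

def show_call_layout(*args):
    # Stand-in with the same positional spreading as the compute_loss call above.
    return args

output = (('arc_scores', 'rel_scores'), 'mask')   # hypothetical layout, not real tensors
batch = {'arc': 'gold_arc', 'rel_id': 'gold_rel'}

print(show_call_layout(*output[0], batch['arc'], batch['rel_id'], output[1]))
# ('arc_scores', 'rel_scores', 'gold_arc', 'gold_rel', 'mask')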
Example 4
def prediction_to_result(self, prediction: Dict[str, Any],
                         batch: Dict[str, Any]) -> List:
    # Convert raw predictions into human-readable results keyed by the input tokens.
    outputs = []
    return BiaffineSecondaryParser.predictions_to_human(self,
                                                        prediction,
                                                        outputs,
                                                        batch['token'],
                                                        use_pos=False)
Example 5
def input_is_flat(self, data) -> bool:
    # Delegate the single-sentence vs. batch-of-sentences check to the parent parser.
    return BiaffineSecondaryParser.input_is_flat(self, data)
Example 6
def build_criterion(self, **kwargs):
    # Reuse the parent parser's loss criterion unchanged.
    return BiaffineSecondaryParser.build_criterion(self, **kwargs)
Example 7
def build_metric(self, **kwargs):
    # Reuse the parent parser's evaluation metric unchanged.
    return BiaffineSecondaryParser.build_metric(self, **kwargs)
Example 8
def decode_output(self, output: Dict[str, Any], batch: Dict[str, Any],
                  decoder, **kwargs) -> Union[Dict[str, Any], Any]:
    # Decode the packed scores in output[0] into arc/relation predictions,
    # using the mask in output[1].
    return BiaffineSecondaryParser.decode(self, *output[0], output[1], batch)
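
Taken together, these overrides cover the full set of hooks the examples exercise. Purely as an illustration of how such hooks typically fit together (this driver and the feed_batch call are hypothetical, not HanLP API), an evaluation loop might wire them up like this:

def evaluate(task, data, device=None):
    # Illustrative driver only; `task` is assumed to implement the hooks shown above.
    dataloader = task.build_dataloader(data, training=False, device=device)
    criterion = task.build_criterion()
    metric = task.build_metric()
    results = []
    for batch in dataloader:
        output = task.feed_batch(batch)                       # hypothetical forward pass
        loss = task.compute_loss(batch, output, criterion)    # could be logged per batch
        prediction = task.decode_output(output, batch, decoder=None)
        task.update_metrics(batch, output, prediction, metric)
        results.extend(task.prediction_to_result(prediction, batch))
    return metric, results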