Example #1
def build_dataloader(self,
                     data,
                     transform: TransformList = None,
                     training=False,
                     device=None,
                     logger: logging.Logger = None,
                     gradient_accumulation=1,
                     **kwargs) -> DataLoader:
    if not transform:
        transform = TransformList()  # guard the default None before inserting
    transform.insert(0, append_bos)
    dataset = BiaffineDependencyParser.build_dataset(self, data, transform)
    if isinstance(data, str):
        # data is a file path, so discard any stale cached samples.
        dataset.purge_cache()
    if self.vocabs.mutable:
        BiaffineDependencyParser.build_vocabs(self, dataset, logger,
                                              transformer=True)
    if dataset.cache:
        timer = CountdownTimer(len(dataset))
        BiaffineDependencyParser.cache_dataset(self, dataset, timer,
                                               training, logger)
    max_seq_len = self.config.get('max_seq_len', None)
    if max_seq_len and isinstance(data, str):
        # Drop training samples whose subword sequence exceeds the limit.
        dataset.prune(lambda x: len(x['token_input_ids']) > max_seq_len,
                      logger)
    return PadSequenceDataLoader(
        batch_sampler=self.sampler_builder.build(
            self.compute_lens(data, dataset, length_field='FORM'),
            shuffle=training,
            gradient_accumulation=gradient_accumulation),
        device=device,
        dataset=dataset,
        pad=self.get_pad_dict())
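The transform inserted at position 0 above is a per-sample callable. Below is a minimal sketch of that contract with a hypothetical stand-in for HanLP's append_bos (the real helper lives in HanLP and may behave differently):

def append_bos_sketch(sample: dict) -> dict:
    # Hypothetical: prepend a BOS placeholder so index 0 can serve as the
    # dependency root, which is what a transform like append_bos is for.
    sample = dict(sample)
    sample['token'] = ['<bos>'] + list(sample['token'])
    return sample

print(append_bos_sketch({'token': ['I', 'love', 'NLP']})['token'])
# ['<bos>', 'I', 'love', 'NLP']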
Example #2
def build_dataloader(self,
                     data,
                     transform: Callable = None,
                     training=False,
                     device=None,
                     logger: logging.Logger = None,
                     cache=False,
                     gradient_accumulation=1,
                     **kwargs) -> DataLoader:
    dataset = CRFConstituencyParsing.build_dataset(self, data, transform)
    if isinstance(data, str):
        # data is a file path, so discard any stale cached samples.
        dataset.purge_cache()
    if self.vocabs.mutable:
        CRFConstituencyParsing.build_vocabs(self, dataset, logger)
    if dataset.cache:
        timer = CountdownTimer(len(dataset))
        # Deliberately borrow the dependency parser's caching routine.
        # noinspection PyCallByClass
        BiaffineDependencyParser.cache_dataset(self, dataset, timer,
                                               training, logger)
    return PadSequenceDataLoader(
        batch_sampler=self.sampler_builder.build(
            self.compute_lens(data, dataset),
            shuffle=training,
            gradient_accumulation=gradient_accumulation),
        device=device,
        dataset=dataset)
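Both build_dataloader variants return a PadSequenceDataLoader; conceptually, its per-batch job is to pad variable-length fields up to the longest sequence in the batch. A rough sketch in plain PyTorch (not HanLP's actual implementation):

import torch

def pad_batch(seqs, pad_value=0):
    # Pad 1-D tensors to the batch maximum and stack into one matrix.
    max_len = max(len(s) for s in seqs)
    return torch.stack([
        torch.cat([s, s.new_full((max_len - len(s),), pad_value)])
        for s in seqs
    ])

print(pad_batch([torch.tensor([1, 2, 3]), torch.tensor([4, 5])]))
# tensor([[1, 2, 3],
#         [4, 5, 0]])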
Example #3
def update_metrics(self, batch: Dict[str, Any],
                   output: Union[torch.Tensor, Dict[str, torch.Tensor],
                                 Iterable[torch.Tensor], Any],
                   prediction: Dict[str, Any],
                   metric: Union[MetricDict, Metric]):
    # Despite its Dict annotation, prediction is expected to unpack into
    # (arc_preds, rel_preds); output[1] is the token mask returned by the
    # model alongside the scores.
    BiaffineDependencyParser.update_metric(self, *prediction, batch['arc'],
                                           batch['rel_id'], output[1],
                                           batch.get('punct_mask', None),
                                           metric, batch)
Example #4
def compute_loss(self, batch: Dict[str, Any],
                 output: Union[torch.Tensor, Dict[str, torch.Tensor],
                               Iterable[torch.Tensor], Any],
                 criterion) -> Union[torch.FloatTensor,
                                     Dict[str, torch.FloatTensor]]:
    # The model returns ((arc_scores, rel_scores), mask).
    (arc_scores, rel_scores), mask = output
    return BiaffineDependencyParser.compute_loss(self, arc_scores,
                                                 rel_scores, batch['arc'],
                                                 batch['rel_id'], mask,
                                                 criterion, batch)
Example #5
def decode_output(self, output: Union[torch.Tensor, Dict[str, torch.Tensor],
                                      Iterable[torch.Tensor], Any],
                  mask: torch.BoolTensor, batch: Dict[str, Any], decoder,
                  **kwargs) -> Union[Dict[str, Any], Any]:
    # Unpack the model output; its bundled mask supersedes the mask
    # parameter.
    (arc_scores, rel_scores), mask = output
    return BiaffineDependencyParser.decode(self, arc_scores, rel_scores,
                                           mask, batch)
Example #6
def compute_loss(self,
                 arc_scores,
                 rel_scores,
                 arcs,
                 rels,
                 mask,
                 criterion,
                 batch=None):
    parse_loss = BiaffineDependencyParser.compute_loss(
        self, arc_scores, rel_scores, arcs, rels, mask, criterion, batch)
    if self.model.training:
        # Auxiliary masked-LM objective: gather each token's first subword
        # so the subword-level gold ids and mask line up with the
        # token-level predictions, then add the MLM loss to the parse loss.
        gold_input_ids = batch['gold_input_ids']
        pred_input_ids = batch['pred_input_ids']
        input_ids_mask = batch['input_ids_mask']
        token_span = batch['token_span']
        gold_input_ids = batch['gold_input_ids'] = gold_input_ids.gather(
            1, token_span[:, :, 0])
        input_ids_mask = batch['input_ids_mask'] = input_ids_mask.gather(
            1, token_span[:, :, 0])
        mlm_loss = F.cross_entropy(pred_input_ids[input_ids_mask],
                                   gold_input_ids[input_ids_mask])
        return parse_loss + mlm_loss
    return parse_loss
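The gather calls above map subword-level tensors onto token positions: token_span[:, :, 0] stores each token's first-subword index, so gathering along dimension 1 keeps exactly the entries aligned with tokens. A self-contained toy illustration (all shapes and values are made up):

import torch
import torch.nn.functional as F

# One sentence: 3 tokens spread over 5 subwords, vocabulary of 7.
gold_input_ids = torch.randint(0, 7, (1, 5))           # subword-level gold
input_ids_mask = torch.tensor([[False, True, False, True, False]])
token_span = torch.tensor([[[0, 0], [1, 2], [3, 4]]])  # first/last subword
pred_input_ids = torch.randn(1, 3, 7)                  # token-level logits

gold = gold_input_ids.gather(1, token_span[:, :, 0])   # shape (1, 3)
mask = input_ids_mask.gather(1, token_span[:, :, 0])   # shape (1, 3)
mlm_loss = F.cross_entropy(pred_input_ids[mask], gold[mask])
print(mlm_loss.item())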
Example #7
def input_is_flat(self, data) -> bool:
    return BiaffineDependencyParser.input_is_flat(self, data,
                                                  self.config.use_pos)
Example #8
def build_metric(self, **kwargs):
    # Metric construction is likewise delegated to the dependency parser.
    return BiaffineDependencyParser.build_metric(self, **kwargs)
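Every method in these examples uses the same pattern: the mixin invokes a sibling base class's implementation through the class object with an explicit self, rather than relying on super() and the MRO. A minimal sketch with hypothetical classes:

class DependencyMixin:
    def build_metric(self, **kwargs):
        return 'attachment-score metric'

class ConstituencyMixin:
    def build_metric(self, **kwargs):
        return 'bracketing-F1 metric'

class JointParser(ConstituencyMixin, DependencyMixin):
    def build_metric(self, **kwargs):
        # Explicitly select the dependency implementation even though
        # ConstituencyMixin precedes it in the MRO.
        return DependencyMixin.build_metric(self, **kwargs)

print(JointParser().build_metric())  # attachment-score metric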