Example #1
def predict_task(self, task: Task, output_key, batch, results, output_dict=None, run_transform=True,
                 cls_is_bos=True, sep_is_eos=True):
    # Forward the batch through the model for this task; feed_batch returns the updated output dict and batch.
    output_dict, batch = self.feed_batch(batch, output_key, output_dict, run_transform, cls_is_bos, sep_is_eos,
                                         results)
    # Decode the raw output into predictions stored under output_key.
    self.decode_output(output_dict, batch, output_key)
    # Convert the predictions into user-facing results and accumulate them under this task's key.
    results[output_key].extend(task.prediction_to_result(output_dict[output_key]['prediction'], batch))
    return output_dict
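
The key contract here is Task.prediction_to_result, which turns the decoded predictions for a batch into user-facing results. Below is a minimal sketch of that contract, assuming a toy tagging task and simple list-of-tokens batches; ToyTaggingTask and the 'pos' key are illustrative assumptions, not HanLP's real API.

# Minimal sketch of the prediction_to_result contract assumed by predict_task.
# ToyTaggingTask and the 'pos' key are illustrative assumptions, not HanLP's real API.
class ToyTaggingTask:
    def prediction_to_result(self, prediction, batch):
        # Pair each predicted tag sequence with its tokens, trimming padding.
        for tags, tokens in zip(prediction, batch['token']):
            yield list(zip(tokens, tags[:len(tokens)]))

results = {'pos': []}
batch = {'token': [['I', 'love', 'NLP']]}
prediction = [['PRP', 'VBP', 'NN', '<pad>']]
results['pos'].extend(ToyTaggingTask().prediction_to_result(prediction, batch))
print(results['pos'])  # [[('I', 'PRP'), ('love', 'VBP'), ('NLP', 'NN')]]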
Example #2
def build_transform(self, task: Task) -> Tuple[TransformerSequenceTokenizer, TransformList]:
    encoder: ContextualWordEmbedding = self.config.encoder
    # Build the task's tokenizer on top of the encoder's transform.
    encoder_transform: TransformerSequenceTokenizer = task.build_tokenizer(encoder.transform())
    # Record the number of tokens per sample under 'token_length'.
    length_transform = FieldLength('token', 'token_length')
    transform = TransformList(encoder_transform, length_transform)
    # An extra user-supplied transform from the config, if present, is inserted first so it runs before tokenization.
    extra_transform = self.config.get('transform', None)
    if extra_transform:
        transform.insert(0, extra_transform)
    return encoder_transform, transform
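
Conceptually, TransformList chains per-sample transforms in order: the optional user transform, then the tokenizer, then the length field. The following is a rough sketch of that composition pattern under those assumptions; the real TransformList and FieldLength classes live in HanLP and differ in detail.

# Rough sketch of the composition pattern build_transform sets up.
# ChainedTransform and field_length are illustrative stand-ins, not HanLP's classes.
from typing import Any, Callable, Dict, List

class ChainedTransform:
    """Applies a list of sample-level transforms in order (assumption: each maps dict -> dict)."""
    def __init__(self, *transforms: Callable[[Dict[str, Any]], Dict[str, Any]]):
        self.transforms: List[Callable] = list(transforms)

    def insert(self, index: int, transform: Callable) -> None:
        self.transforms.insert(index, transform)

    def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        for t in self.transforms:
            sample = t(sample)
        return sample

def field_length(src: str, dst: str) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
    # Conceptually what FieldLength('token', 'token_length') contributes: store the token count.
    def transform(sample: Dict[str, Any]) -> Dict[str, Any]:
        sample[dst] = len(sample[src])
        return sample
    return transform

pipeline = ChainedTransform(field_length('token', 'token_length'))
print(pipeline({'token': ['I', 'love', 'NLP']}))  # {'token': ['I', 'love', 'NLP'], 'token_length': 3}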
Example #3
def compute_loss(self, batch: Dict[str, Any],
                 output: Union[torch.Tensor, Dict[str, torch.Tensor],
                               Iterable[torch.Tensor], Any],
                 criterion: Callable, task: Task) -> torch.FloatTensor:
    # Loss computation is delegated to the task, which knows how to pair its output with the gold labels.
    return task.compute_loss(batch, output, criterion)
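
Delegating the loss to the task keeps the multi-task wrapper agnostic to each task's output format. Here is a hypothetical example of that delegation for a toy classification task; the class name and the 'label' key are assumptions, not HanLP's actual task implementation.

# Hypothetical example of per-task loss delegation: each task pairs its own
# output with the gold labels it expects in the batch.
import torch
import torch.nn.functional as F

class ToyClassificationTask:
    def compute_loss(self, batch, output, criterion):
        # The task knows that 'output' holds logits and that gold labels live under batch['label'].
        return criterion(output, batch['label'])

task = ToyClassificationTask()
output = torch.randn(4, 3)                     # logits for a batch of 4 samples, 3 classes
batch = {'label': torch.tensor([0, 2, 1, 2])}  # gold class indices
loss = task.compute_loss(batch, output, F.cross_entropy)
print(loss.item())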