def predict_task(self, task: Task, output_key, batch, results, output_dict=None, run_transform=True,
                 cls_is_bos=True, sep_is_eos=True):
    """Run a single task on ``batch`` and append its decoded results.

    Steps:
      1. Pass the batch through ``self.feed_batch``, which returns an updated
         ``(output_dict, batch)`` pair.
      2. Call ``self.decode_output`` for ``output_key``.
      3. Convert ``output_dict[output_key]['prediction']`` with
         ``task.prediction_to_result``.
      4. Extend ``results[output_key]`` in place with the converted items.

    Args:
        task: Task whose ``prediction_to_result`` converts raw predictions.
        output_key: Key selecting this task's entry in both the output dict
            and ``results``.
        batch: Input batch; replaced by the batch that ``feed_batch`` returns.
        results: Mapping whose ``output_key`` entry is extended in place.
            It must already hold an extendable value, presumably a list
            (TODO: confirm with callers).
        output_dict: Forwarded unchanged to ``feed_batch``. Presumably it
            holds outputs already computed for this batch, or is None
            (unverified here).
        run_transform, cls_is_bos, sep_is_eos: Forwarded unchanged to
            ``feed_batch``.

    Returns:
        The output dict returned by ``feed_batch`` after ``decode_output``
        has been applied to it.
    """
    outputs, fed_batch = self.feed_batch(batch, output_key, output_dict, run_transform, cls_is_bos, sep_is_eos,
                                         results)
    self.decode_output(outputs, fed_batch, output_key)
    # Fetch the destination before computing the converted predictions, to
    # keep the original evaluation order.
    bucket = results[output_key]
    raw_prediction = outputs[output_key]['prediction']
    bucket.extend(task.prediction_to_result(raw_prediction, fed_batch))
    return outputs
def build_transform(self, task: Task) -> Tuple[TransformerSequenceTokenizer, TransformList]:
    """Build the encoder tokenizer and the full transform pipeline for ``task``.

    The tokenizer comes from ``task.build_tokenizer`` applied to
    ``self.config.encoder.transform()``. The pipeline is a ``TransformList``
    holding that tokenizer followed by ``FieldLength('token', 'token_length')``.
    If ``self.config.get('transform', None)`` is truthy, that extra transform
    is inserted at the front of the pipeline.

    Args:
        task: Task used to build the tokenizer.

    Returns:
        A ``(tokenizer, pipeline)`` tuple. The tokenizer object is also the
        first encoder step inside ``pipeline``.
    """
    embedding: ContextualWordEmbedding = self.config.encoder
    tokenizer: TransformerSequenceTokenizer = task.build_tokenizer(embedding.transform())
    # Presumably fills 'token_length' from the 'token' field. Verify against
    # FieldLength.
    pipeline = TransformList(tokenizer, FieldLength('token', 'token_length'))
    prepend = self.config.get('transform', None)
    if not prepend:
        # No extra (or a falsy) transform configured: the pipeline stays as built.
        return tokenizer, pipeline
    # A user-configured transform runs before tokenization.
    pipeline.insert(0, prepend)
    return tokenizer, pipeline
def compute_loss(self,
                 batch: Dict[str, Any],
                 output: Union[torch.Tensor, Dict[str, torch.Tensor], Iterable[torch.Tensor], Any],
                 criterion: Callable,
                 task: Task) -> torch.FloatTensor:
    """Compute the loss for one task by delegating to the task itself.

    Args:
        batch: Input batch, forwarded unchanged.
        output: Model output for this task, forwarded unchanged.
        criterion: Loss criterion, forwarded unchanged.
        task: Task whose ``compute_loss`` does the actual computation.

    Returns:
        Whatever ``task.compute_loss(batch, output, criterion)`` returns.
    """
    loss = task.compute_loss(batch, output, criterion)
    return loss