def prediction_step(
    self,
    model: nn.Module,
    inputs: Dict[str, Union[torch.Tensor, Any]],
    prediction_loss_only: bool,
    ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
    has_labels = all(inputs.get(k) is not None for k in self.label_names)
    inputs = self._prepare_inputs(inputs)
    if ignore_keys is None:
        if hasattr(self.model, "config"):
            ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
        else:
            ignore_keys = []

    with torch.no_grad():
        if has_labels:
            # Pop the labels so the model does not compute its own loss;
            # the multi-label BCE loss is computed below instead.
            labels = inputs.pop("labels")
        if self.use_amp:
            with autocast():
                outputs = model(**inputs)
        else:
            outputs = model(**inputs)
        # With the labels popped, the first output element is the logits.
        logits = outputs[0]
        if has_labels:
            loss_fct = BCEWithLogitsLoss()
            loss = loss_fct(logits, labels)
        else:
            loss = None
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]

    if prediction_loss_only:
        return (loss, None, None)

    logits = nested_detach(logits)
    if has_labels:
        labels = nested_detach(labels)
    else:
        labels = None
    return (loss, logits, labels)
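# A minimal, self-contained sketch of the multi-label loss used above; the
# shapes and the 0.5 threshold are illustrative assumptions, not taken from
# the original trainer.
import torch
from torch.nn import BCEWithLogitsLoss

logits = torch.randn(4, 6)                    # (batch_size, num_labels)
labels = torch.randint(0, 2, (4, 6)).float()  # BCEWithLogitsLoss expects float targets
loss = BCEWithLogitsLoss()(logits, labels)
preds = (torch.sigmoid(logits) > 0.5).long()  # per-label decisions at the threshold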
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
    """Modified to use the LXMERT QA score for the logits."""
    has_labels = all(inputs.get(k) is not None for k in self.label_names)
    # "question_id" is bookkeeping only and is not a model argument.
    inputs = {k: v for (k, v) in inputs.items() if k != "question_id"}
    inputs = self._prepare_inputs(inputs)

    with torch.no_grad():
        if self.args.fp16 and _use_native_amp:
            with autocast():
                outputs = model(**inputs)
        else:
            outputs = model(**inputs)
        if has_labels:
            loss = outputs[0].mean().detach()
            # Limit slice to question_answering_score.
            logits = outputs[1:2]
        else:
            loss = None
            # Limit slice to question_answering_score.
            logits = outputs[0:1]
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
            logits = logits[: self.args.past_index - 1] + logits[self.args.past_index :]

    if prediction_loss_only:
        return (loss, None, None)

    logits = nested_detach(logits)
    if len(logits) == 1:
        logits = logits[0]

    if has_labels:
        labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
        if len(labels) == 1:
            labels = labels[0]
    else:
        labels = None

    return (loss, logits, labels)
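# Why the slices above pick out the QA score, sketched with stand-in tuples:
# with labels, LXMERT-style QA models return (loss, question_answering_score, ...);
# without labels, the score comes first. The tuples here are purely illustrative.
with_labels = ("loss", "qa_score", "hidden_states")
assert with_labels[1:2] == ("qa_score",)
without_labels = ("qa_score", "hidden_states")
assert without_labels[0:1] == ("qa_score",)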
def prediction_step(self, inputs, label_names=("labels",)):
    has_labels = all(inputs.get(k) is not None for k in label_names)
    inputs = self.prepare_inputs(inputs)
    if hasattr(self.model, "config"):
        ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
    else:
        ignore_keys = []

    if has_labels:
        labels = trainer_pt_utils.nested_detach(
            tuple(inputs.get(name) for name in label_names))
        if len(labels) == 1:
            labels = labels[0]
    else:
        labels = None

    with torch.no_grad():
        if has_labels:
            # The second argument asks compute_loss to also return the model outputs.
            loss, outputs = self.compute_loss(inputs, True)
            loss = loss.mean().detach()
            if isinstance(outputs, dict):
                logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
            else:
                logits = outputs[1:]
        else:
            loss, outputs = None, self.model(**inputs)
            if isinstance(outputs, dict):
                logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
            else:
                logits = outputs

    logits = trainer_pt_utils.nested_detach(logits)
    if len(logits) == 1:
        logits = logits[0]
    return (loss, logits, labels)
def compute_loss(self, model, inputs):
    """
    Override loss computation to calculate and log metrics during training
    """
    outputs = model(**inputs)

    # Custom logging steps (to log training metrics)
    if (self.state.global_step == 1 and self.args.logging_first_step) or (
            self.args.logging_steps > 0
            and self.state.global_step > 0
            and self.state.global_step % self.args.logging_steps == 0):
        labels = None
        has_labels = all(inputs.get(k) is not None for k in self.label_names)
        if has_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]

        # Compute and log metrics only if labels are available
        if labels is not None:
            metrics = self.compute_scores(
                EvalPrediction(
                    predictions=(outputs["word_outputs"], outputs["indexes"]),
                    label_ids=labels,
                ))
            if self.wandb_callback is not None:
                self.wandb_callback.update_metrics(metrics)

    # Save past state if it exists
    # TODO: this needs to be fixed and made cleaner later.
    if self.args.past_index >= 0:
        self._past = outputs[self.args.past_index]

    # We don't use .loss here since the model may return tuples instead of ModelOutput.
    return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
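# A hedged sketch of what the compute_scores hook might look like; the metric
# and the (word_outputs, indexes) layout are assumptions for illustration only.
import numpy as np
from transformers import EvalPrediction

def compute_scores(eval_pred: EvalPrediction) -> dict:
    word_outputs, _indexes = eval_pred.predictions  # the pair built in compute_loss above
    preds = np.argmax(word_outputs, axis=-1)        # token-level predictions from logits
    return {"word_accuracy": float((preds == eval_pred.label_ids).mean())}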
def prediction_step( self, model, inputs, prediction_loss_only, ignore_keys=None, ): """ Slightly changed ~transformers.Trainer.prediction_step using confi_prediction method to find prediction of Cotrain and TriTrain Models. Used during evaluation step. """ has_labels = all(inputs.get(k) is not None for k in self.label_names) inputs = self._prepare_inputs(inputs) if ignore_keys is None: if hasattr(self.model, "config"): ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) else: ignore_keys = [] with torch.no_grad(): if self.use_amp: with autocast(): outputs = model(**inputs) else: outputs = model(**inputs) if has_labels: if self.label_smoother is not None and "labels" in inputs: loss = self.label_smoother( outputs, inputs["labels"]).mean().detach() else: loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() if isinstance(outputs, dict): logits = { k: v for k, v in outputs.items() if k not in ignore_keys + ["loss"] } logits = self.confi_prediction(**logits) else: logits = outputs[1:] else: loss = None if isinstance(outputs, dict): logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) else: logits = outputs if self.args.past_index >= 0: self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1] if prediction_loss_only: return (loss, None, None) logits = nested_detach(logits) if len(logits) == 1: logits = logits[0] if has_labels: labels = nested_detach( tuple(inputs.get(name) for name in self.label_names)) if len(labels) == 1: labels = labels[0] else: labels = None return (loss, logits, labels)
def prediction_step( self, model, inputs, prediction_loss_only, ignore_keys=None, ): """ Perform an evaluation step on :obj:`model` using obj:`inputs`. Subclass and override to inject custom behavior. Args: model (:obj:`nn.Module`): The model to evaluate. inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): The inputs and targets of the model. The dictionary will be unpacked before being fed to the model. Most models expect the targets under the argument :obj:`labels`. Check your model's documentation for all accepted arguments. prediction_loss_only (:obj:`bool`): Whether or not to return the loss only. ignore_keys (:obj:`Lst[str]`, `optional`): A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions. Return: Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and labels (each being optional). """ has_labels = all(inputs.get(k) is not None for k in self.label_names) inputs = self._prepare_inputs(inputs) if ignore_keys is None: if hasattr(self.model, "config"): ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) else: ignore_keys = [] # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. if has_labels: labels = nested_detach( tuple(inputs.get(name) for name in self.label_names)) if len(labels) == 1: labels = labels[0] else: labels = None with torch.no_grad(): if has_labels: loss, outputs = self.compute_loss(model, inputs, return_outputs=True) loss = loss.mean().detach() if isinstance(outputs, dict): logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) else: logits = outputs[1:] else: loss = None outputs = model(**inputs) if isinstance(outputs, dict): logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) else: logits = outputs # TODO: this needs to be fixed and made cleaner later. if prediction_loss_only: return (loss, None, None) logits = nested_detach(logits) if len(logits) == 1: logits = logits[0] return (loss, logits, labels)
def evaluate(self, aikidoka: Aikidoka, kata: Kata, metrics: Metrics = ()) -> DojoEvaluation:
    aikidoka.eval()
    aikidoka = Ref(aikidoka)

    # invoke evaluation_started listener event
    self.tell(lambda x: x.evaluation_started(OnEvaluationStarted(aikidoka, kata, self.dojo_kun)))
    aikidoka = aikidoka.wrapped

    # load the kata data
    data = kata.load(self.dojo_kun.batch_size)
    # aikidoka.pre_init(data)

    # invoke kata_load_finished listener event
    self.tell(lambda x: x.kata_load_finished(OnKataLoaded(aikidoka, kata, data, self.dojo_kun)))

    loss_collector = TensorCollector(self.dojo_kun.local_rank, len(data), self.dojo_kun.batch_size)
    pred_collector = TensorCollector(self.dojo_kun.local_rank, len(data))
    label_collector = TensorCollector(self.dojo_kun.local_rank, len(data))

    with torch.no_grad():
        for batch_idx, batch in enumerate(data.data_loader):
            run = Run(0, 0, batch_idx, len(data.data_loader))
            batch_ref = Ref(batch)

            # invoke batch_started listener event
            self.tell(lambda x: x.batch_started(OnBatchStarted(aikidoka, self.dojo_kun, batch_ref, run)))
            batch = batch_ref.get_wrapped()

            result = aikidoka(**batch)
            loss = aggregate(result)

            loss_collector.concat(loss, repeat=True)
            pred_collector.concat_all(nested_detach(result[0][1]))
            # label_collector.concat_all(nested_detach(result[0][-1]))
            # pred_collector.concat_all(nested_detach(result[1]))
            label_collector.concat_all(nested_detach(batch["labels"]))  # FIXME

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if self.dojo_kun.grad_acc_steps is not None and (batch_idx + 1) % self.dojo_kun.grad_acc_steps == 0:
                loss_collector.gather()
                pred_collector.gather()
                label_collector.gather()

            # invoke batch_finished listener event
            self.tell(lambda x: x.batch_finished(OnBatchFinished(aikidoka, batch, run, loss)))

    # Gather all remaining tensors and put them back on the CPU
    loss = loss_collector.finalize()
    pred = pred_collector.finalize()
    label = label_collector.finalize()

    # invoke evaluation_finished listener event
    self.tell(lambda x: x.evaluation_finished(OnEvaluationFinished(aikidoka, kata, self.dojo_kun)))

    return DojoEvaluation(pred, loss, compute_metrics(metrics, pred, label))
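# nested_detach, used throughout the snippets above, detaches every tensor in a
# (possibly nested) list/tuple; a quick self-contained demonstration.
import torch
from transformers.trainer_pt_utils import nested_detach

x = torch.ones(2, requires_grad=True)
out = nested_detach((x * 2, [x + 1]))
assert not any(t.requires_grad for t in (out[0], out[1][0]))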