    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor],
               Optional[torch.Tensor]]:
        has_labels = all(inputs.get(k) is not None for k in self.label_names)
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config,
                                      "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []
        with torch.no_grad():
            if has_labels:
                labels = inputs.pop("labels")
            if self.use_amp:
                with autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)
            logits = outputs[0]
            if has_labels:
                # Multi-label objective: BCEWithLogitsLoss applies the
                # sigmoid to the raw logits internally.
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
            else:
                loss = None
            # TODO: this needs to be fixed and made cleaner later.
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index if has_labels else
                                     self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)
        logits = nested_detach(logits)
        if has_labels:
            labels = nested_detach(labels)
        else:
            labels = None
        return (loss, logits, labels)
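Example #1 pairs a plain forward pass with BCEWithLogitsLoss, which makes the override suitable for multi-label classification. A minimal standalone sketch of that loss computation; the shapes and targets below are illustrative assumptions, not part of the original:

import torch
from torch.nn import BCEWithLogitsLoss

# Hypothetical batch: 2 examples, 4 independent labels (multi-hot targets).
logits = torch.randn(2, 4)                 # raw model outputs, no sigmoid yet
labels = torch.tensor([[1., 0., 1., 0.],
                       [0., 1., 0., 0.]])

loss_fct = BCEWithLogitsLoss()             # fuses sigmoid + BCE for stability
loss = loss_fct(logits, labels)            # scalar, averaged over all entries
print(loss.item())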
Example #2
    def prediction_step(self,
                        model,
                        inputs,
                        prediction_loss_only,
                        ignore_keys=None):
        """Modified to use LXMERT QA score for logits."""
        has_labels = all(inputs.get(k) is not None for k in self.label_names)
        inputs = {k: v for (k, v) in inputs.items() if k != "question_id"}
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            if self.args.fp16 and _use_native_amp:
                with autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)
            if has_labels:
                loss = outputs[0].mean().detach()
                # Limit slice to question_answering_score.
                logits = outputs[1:2]
            else:
                loss = None
                # Limit slice to question_answering_score.
                logits = outputs[0:1]
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index if has_labels
                                     else self.args.past_index - 1]
                logits = (logits[:self.args.past_index - 1] +
                          logits[self.args.past_index:])
        if prediction_loss_only:
            return (loss, None, None)
        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]
        if has_labels:
            labels = nested_detach(
                tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None
        return (loss, logits, labels)
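Note that Example #2 slices the outputs (outputs[1:2]) rather than indexing them (outputs[1]): slicing keeps the result a tuple, which the later len(logits) == 1 unwrapping and nested_detach both expect. A tiny illustration with placeholder values:

# Placeholder stand-in for a model output tuple: (loss, qa_score, hidden).
outputs = ("loss_value", "qa_score", "hidden_states")

logits = outputs[1:2]          # slice -> still a tuple: ("qa_score",)
assert isinstance(logits, tuple)

if len(logits) == 1:           # the unwrapping step used in the examples
    logits = logits[0]
assert logits == "qa_score"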
Example #3
    def prediction_step(self, inputs, label_names=("labels",)):
        has_labels = all(inputs.get(k) is not None for k in label_names)
        inputs = self.prepare_inputs(inputs)
        if hasattr(self.model, "config"):
            ignore_keys = getattr(self.model.config,
                                  "keys_to_ignore_at_inference", [])
        else:
            ignore_keys = []

        if has_labels:
            labels = trainer_pt_utils.nested_detach(
                tuple(inputs.get(name) for name in label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if has_labels:
                loss, outputs = self.compute_loss(
                    self.model, inputs, return_outputs=True)
                loss = loss.mean().detach()
                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items()
                                   if k not in ignore_keys + ["loss"])
                else:
                    logits = outputs[1:]
            else:
                loss, outputs = None, self.model(**inputs)
                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items()
                                   if k not in ignore_keys + ["loss"])
                else:
                    logits = outputs

        logits = trainer_pt_utils.nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        Override loss computation to calculate and log metrics
        during training
        """
        outputs = model(**inputs)

        # Custom logging steps (to log training metrics)
        if (self.state.global_step == 1 and self.args.logging_first_step) or (
                self.args.logging_steps > 0 and self.state.global_step > 0
                and self.state.global_step % self.args.logging_steps == 0):
            labels = None
            has_labels = all(
                inputs.get(k) is not None for k in self.label_names)
            if has_labels:
                labels = nested_detach(
                    tuple(inputs.get(name) for name in self.label_names))
                if len(labels) == 1:
                    labels = labels[0]

            # Compute and log metrics only if labels are available
            if labels is not None:
                metrics = self.compute_scores(
                    EvalPrediction(
                        predictions=(outputs["word_outputs"],
                                     outputs["indexes"]),
                        label_ids=labels,
                    ))
                if self.wandb_callback is not None:
                    self.wandb_callback.update_metrics(metrics)

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
Example #5
    def prediction_step(
        self,
        model,
        inputs,
        prediction_loss_only,
        ignore_keys=None,
    ):
        """
        Slightly changed ~transformers.Trainer.prediction_step using
        confi_prediction method to find prediction of Cotrain and TriTrain Models.
        Used during evaluation step.
        """
        has_labels = all(inputs.get(k) is not None for k in self.label_names)
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config,
                                      "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        with torch.no_grad():
            if self.use_amp:
                with autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)
            if has_labels:
                if self.label_smoother is not None and "labels" in inputs:
                    loss = self.label_smoother(
                        outputs, inputs["labels"]).mean().detach()
                else:
                    loss = (outputs["loss"] if isinstance(outputs, dict) else
                            outputs[0]).mean().detach()
                if isinstance(outputs, dict):
                    logits = {
                        k: v
                        for k, v in outputs.items()
                        if k not in ignore_keys + ["loss"]
                    }
                    logits = self.confi_prediction(**logits)
                else:
                    logits = outputs[1:]
            else:
                loss = None
                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items()
                                   if k not in ignore_keys)
                else:
                    logits = outputs
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index if has_labels else
                                     self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        if has_labels:
            labels = nested_detach(
                tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        return (loss, logits, labels)
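The has_labels branch in Example #5 delegates to self.label_smoother when one is configured; in transformers this is typically a trainer_pt_utils.LabelSmoother, which reads the logits out of the model output and computes a smoothed cross-entropy. A minimal sketch, with epsilon, shapes, and targets chosen purely for illustration:

import torch
from transformers.trainer_pt_utils import LabelSmoother

label_smoother = LabelSmoother(epsilon=0.1)

# Illustrative classification output: batch of 2 examples, 3 classes.
logits = torch.randn(2, 3)
labels = torch.tensor([0, 2])

# The smoother accepts the raw model output (dict or tuple) plus labels,
# the same way it is called in prediction_step above.
loss = label_smoother({"logits": logits}, labels)
print(loss.item())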
Example #6
    def prediction_step(
        self,
        model,
        inputs,
        prediction_loss_only,
        ignore_keys=None,
    ):
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.
            ignore_keys (:obj:`List[str]`, `optional`):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
            labels (each being optional).
        """
        has_labels = all(inputs.get(k) is not None for k in self.label_names)
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config,
                                      "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels:
            labels = nested_detach(
                tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if has_labels:
                loss, outputs = self.compute_loss(model,
                                                  inputs,
                                                  return_outputs=True)
                loss = loss.mean().detach()
                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items()
                                   if k not in ignore_keys + ["loss"])
                else:
                    logits = outputs[1:]
            else:
                loss = None
                outputs = model(**inputs)
                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items()
                                   if k not in ignore_keys)
                else:
                    logits = outputs

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)
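Since Example #6 keeps the stock signature, it can be exercised directly through a Trainer instance. A hedged usage sketch; the checkpoint name, label value, and output directory are illustrative assumptions, and running it requires downloading a real checkpoint:

import torch
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

name = "distilbert-base-uncased"  # any small classification checkpoint works
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(name)

trainer = Trainer(model=model, args=TrainingArguments(output_dir="tmp"))

inputs = tokenizer(["a short example"], return_tensors="pt")
inputs["labels"] = torch.tensor([1])

loss, logits, labels = trainer.prediction_step(
    model, inputs, prediction_loss_only=False)
print(loss, logits.shape, labels)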
Example #7
    def evaluate(self,
                 aikidoka: Aikidoka,
                 kata: Kata,
                 metrics: Metrics = []) -> DojoEvaluation:
        aikidoka.eval()
        aikidoka = Ref(aikidoka)

        # invoke evaluation_started listener event
        self.tell(lambda x: x.evaluation_started(
            OnEvaluationStarted(aikidoka, kata, self.dojo_kun)))
        aikidoka = aikidoka.wrapped

        # load the kata data
        data = kata.load(self.dojo_kun.batch_size)
        # aikidoka.pre_init(data)

        # invoke kata_load_finished listener event
        self.tell(lambda x: x.kata_load_finished(
            OnKataLoaded(aikidoka, kata, data, self.dojo_kun)))

        loss_collector = TensorCollector(self.dojo_kun.local_rank, len(data),
                                         self.dojo_kun.batch_size)
        pred_collector = TensorCollector(self.dojo_kun.local_rank, len(data))
        label_collector = TensorCollector(self.dojo_kun.local_rank, len(data))

        with torch.no_grad():
            for batch_idx, batch in enumerate(data.data_loader):
                run = Run(0, 0, batch_idx, len(data.data_loader))
                batch_ref = Ref(batch)

                # invoke batch_started listener event
                self.tell(lambda x: x.batch_started(
                    OnBatchStarted(aikidoka, self.dojo_kun, batch_ref, run)))
                batch = batch_ref.get_wrapped()

                result = aikidoka(**batch)
                loss = aggregate(result)

                loss_collector.concat(loss, repeat=True)
                pred_collector.concat_all(nested_detach(result[0][1]))
                # label_collector.concat_all(nested_detach(result[0][-1]))
                # pred_collector.concat_all(nested_detach(result[1]))
                label_collector.concat_all(nested_detach(
                    batch["labels"]))  # FIXME

                # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
                if self.dojo_kun.grad_acc_steps is not None and (
                        batch_idx + 1) % self.dojo_kun.grad_acc_steps == 0:
                    loss_collector.gather()
                    pred_collector.gather()
                    label_collector.gather()

                # invoke batch_finished listener event
                self.tell(lambda x: x.batch_finished(
                    OnBatchFinished(aikidoka, batch, run, loss)))

        # Gather all remaining tensors and put them back on the CPU
        loss = loss_collector.finalize()
        pred = pred_collector.finalize()
        label = label_collector.finalize()

        # invoke evaluation_finished listener event
        self.tell(lambda x: x.evaluation_finished(
            OnEvaluationFinished(aikidoka, kata, self.dojo_kun)))

        return DojoEvaluation(pred, loss,
                              compute_metrics(metrics, pred, label))
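Example #7 leans heavily on nested_detach, which walks nested lists/tuples and detaches every tensor it finds while preserving the container structure. A quick standalone illustration:

import torch
from transformers.trainer_pt_utils import nested_detach

t = torch.randn(2, requires_grad=True)
nested = (t, [t, t + 1.0])

detached = nested_detach(nested)
# Same (tuple, list) structure, but no tensor tracks gradients anymore.
assert isinstance(detached, tuple) and isinstance(detached[1], list)
assert all(not x.requires_grad for x in (detached[0], *detached[1]))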