Example #1
    def test_distributed_tensor_gatherer(self):
        # Simulate a result with a dataset of size 21, 4 processes and chunks of lengths 2, 3, 1
        world_size = 4
        num_samples = 21
        input_indices = [
            [0, 1, 6, 7, 12, 13, 18, 19],
            [2, 3, 4, 8, 9, 10, 14, 15, 16, 20, 0, 1],
            [5, 11, 17, 2],
        ]

        predictions = np.random.normal(size=(num_samples, 13))
        gatherer = DistributedTensorGatherer(world_size=world_size,
                                             num_samples=num_samples)
        for indices in input_indices:
            gatherer.add_arrays(predictions[indices])
        result = gatherer.finalize()
        self.assertTrue(np.array_equal(result, predictions))

        # With nested tensors
        gatherer = DistributedTensorGatherer(world_size=world_size,
                                             num_samples=num_samples)
        for indices in input_indices:
            gatherer.add_arrays([
                predictions[indices],
                [predictions[indices], predictions[indices]]
            ])
        result = gatherer.finalize()
        self.assertIsInstance(result, list)
        self.assertEqual(len(result), 2)
        self.assertIsInstance(result[1], list)
        self.assertEqual(len(result[1]), 2)
        self.assertTrue(np.array_equal(result[0], predictions))
        self.assertTrue(np.array_equal(result[1][0], predictions))
        self.assertTrue(np.array_equal(result[1][1], predictions))
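
The interleaved `input_indices` above are not arbitrary: they model what a sequential distributed sampler produces when a 21-sample dataset is padded (by wrap-around) to a multiple of 4 processes and each process walks its contiguous shard in chunks of 2, 3, and 1. A minimal sketch (a hypothetical helper, not part of the test suite) that reproduces them:

def simulate_shards(num_samples, world_size, chunk_lengths):
    # Pad the index list by wrapping around to a multiple of world_size,
    # give each process a contiguous shard, then emit one interleaved chunk
    # per step by concatenating every process's slice in rank order.
    padded = [i % num_samples
              for i in range(num_samples + (-num_samples) % world_size)]
    per_process = len(padded) // world_size
    shards = [padded[r * per_process:(r + 1) * per_process]
              for r in range(world_size)]
    steps, start = [], 0
    for length in chunk_lengths:
        steps.append(sum((shard[start:start + length] for shard in shards), []))
        start += length
    return steps

assert simulate_shards(21, 4, [2, 3, 1]) == [
    [0, 1, 6, 7, 12, 13, 18, 19],
    [2, 3, 4, 8, 9, 10, 14, 15, 16, 20, 0, 1],
    [5, 11, 17, 2],
]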
Example #2
from typing import List, Optional, Union

import torch
from torch import Tensor
from transformers.file_utils import is_torch_tpu_available
from transformers.trainer_pt_utils import (DistributedTensorGatherer,
                                           distributed_concat, nested_concat,
                                           nested_numpify)


class TensorCollector:
    collector: Optional[Union[Tensor, List[Tensor]]] = None

    def __init__(self,
                 local_rank: int,
                 num_examples: int,
                 batch_size: Optional[int] = None):
        self.gatherer = DistributedTensorGatherer(
            self._get_world_size(local_rank), num_examples, batch_size)
        self.local_rank = local_rank
        self.batch_size = batch_size
        self.num_examples = num_examples

    def _get_world_size(self, local_rank: int):
        world_size = 1
        if is_torch_tpu_available():
            raise NotImplementedError()
            # world_size = xm.xrt_world_size()
        elif local_rank != -1:
            # noinspection PyUnresolvedReferences
            world_size = torch.distributed.get_world_size()
        return max(1, world_size)

    def concat(self, tensor: Tensor, repeat: bool = False, dim: int = 0):
        # `repeat` expands a per-batch scalar (e.g. a loss) to one entry per
        # sample; it requires `batch_size` to have been set.
        tensors = tensor.repeat(self.batch_size) if repeat else tensor
        self.collector = tensors if self.collector is None else torch.cat(
            (self.collector, tensors), dim=dim)

    def concat_all(self,
                   tensor: Union[Tensor, List[Tensor]],
                   padding_index: int = -100):
        if isinstance(tensor, list) and len(tensor) == 1:
            tensor = tensor[0]
        self.collector = tensor if self.collector is None else nested_concat(
            self.collector, tensor, padding_index)

    def gather(self):
        if self.collector is not None:
            self.gatherer.add_arrays(self._gather_and_numpify(self.collector))
            self.collector = None

    def finalize(self):
        self.gather()
        return self.gatherer.finalize()

    def _gather_and_numpify(self, tensors):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`
        """
        if tensors is None:
            return
        if is_torch_tpu_available():
            # tensors = nested_xla_mesh_reduce(tensors, name)
            raise NotImplementedError()
        elif self.local_rank != -1:
            tensors = distributed_concat(tensors)

        return nested_numpify(tensors)
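
A hypothetical usage sketch for `TensorCollector` (the `model` and `dataloader` names are assumptions, not part of the class): accumulate logits batch by batch, flush them off-device with `gather()`, and `finalize()` into a numpy array ordered like the dataset.

collector = TensorCollector(local_rank=-1,
                            num_examples=len(dataloader.dataset))
model.eval()
with torch.no_grad():
    for batch in dataloader:
        logits = model(**batch).logits
        collector.concat_all(logits)  # accumulate on-device
        collector.gather()            # numpify and hand off to the gatherer
predictions = collector.finalize()    # ordered like the underlying dataset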
Example #3
    def test_distributed_tensor_gatherer_different_shapes(self):
        # Simulate a result with a dataset of size 21, 4 processes and chunks of lengths 2, 3, 1
        world_size = 4
        num_samples = 21
        input_indices = [
            [0, 1, 6, 7, 12, 13, 18, 19],
            [2, 3, 4, 8, 9, 10, 14, 15, 16, 20, 0, 1],
            [5, 11, 17, 2],
        ]
        sequence_lengths = [8, 10, 13]

        predictions = np.random.normal(size=(num_samples, 13))
        gatherer = DistributedTensorGatherer(world_size=world_size,
                                             num_samples=num_samples)
        for indices, seq_length in zip(input_indices, sequence_lengths):
            gatherer.add_arrays(predictions[indices, :seq_length])
        result = gatherer.finalize()

        # Remove the extra samples added at the end for a round multiple of num processes.
        actual_indices = [
            input_indices[0], input_indices[1][:-2], input_indices[2][:-1]
        ]
        for indices, seq_length in zip(actual_indices, sequence_lengths):
            self.assertTrue(
                np.array_equal(result[indices, :seq_length],
                               predictions[indices, :seq_length]))

        # With nested tensors
        predictions = np.random.normal(size=(num_samples, 13))
        gatherer = DistributedTensorGatherer(world_size=world_size,
                                             num_samples=num_samples)
        for indices, seq_length in zip(input_indices, sequence_lengths):
            gatherer.add_arrays(
                [predictions[indices, :seq_length], predictions[indices]])
        result = gatherer.finalize()

        for indices, seq_length in zip(actual_indices, sequence_lengths):
            self.assertTrue(
                np.array_equal(result[0][indices, :seq_length],
                               predictions[indices, :seq_length]))
        self.assertTrue(np.array_equal(result[1], predictions))

        # Check that it also works when the varying seq_length comes second
        gatherer = DistributedTensorGatherer(world_size=world_size,
                                             num_samples=num_samples)
        for indices, seq_length in zip(input_indices, sequence_lengths):
            gatherer.add_arrays(
                [predictions[indices], predictions[indices, :seq_length]])
        result = gatherer.finalize()

        self.assertTrue(np.array_equal(result[0], predictions))
        for indices, seq_length in zip(actual_indices, sequence_lengths):
            self.assertTrue(
                np.array_equal(result[1][indices, :seq_length],
                               predictions[indices, :seq_length]))
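
The `[:seq_length]` slices in the assertions above rely on the gatherer's padding rule: when chunks of different sequence lengths are merged, shorter chunks are padded out to the widest length seen with the padding index (-100 by default), so only the unpadded slice is guaranteed to round-trip. An illustrative numpy-only sketch of that rule (not the library implementation):

import numpy as np

def pad_and_concat(a, b, padding_index=-100):
    # Concatenate along axis 0, padding the narrower array on axis 1.
    width = max(a.shape[1], b.shape[1])
    out = np.full((a.shape[0] + b.shape[0], width), padding_index,
                  dtype=a.dtype)
    out[:a.shape[0], :a.shape[1]] = a
    out[a.shape[0]:, :b.shape[1]] = b
    return out

merged = pad_and_concat(np.ones((2, 8)), np.ones((3, 13)))
assert merged.shape == (5, 13)
assert (merged[:2, 8:] == -100).all()  # the short chunk was padded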
Example #4
    def prediction_loop(self, data_loader, world_size):
        num_examples = len(data_loader.dataset)
        batch_size = data_loader.batch_size
        eval_losses_gatherer = DistributedTensorGatherer(
            world_size, num_examples, make_multiple_of=batch_size)
        preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
        labels_gatherer = DistributedTensorGatherer(world_size, num_examples)
        losses_host, preds_host, labels_host = None, None, None
        self.model.eval()

        for step, inputs in enumerate(data_loader):
            loss, logits, labels = self.prediction_step(inputs)
            losses = loss.repeat(batch_size)
            losses_host = losses if losses_host is None else torch.cat(
                (losses_host, losses), dim=0)
            preds_host = logits if preds_host is None else trainer_pt_utils.nested_concat(
                preds_host, logits, padding_index=-100)
            labels_host = labels if labels_host is None else trainer_pt_utils.nested_concat(
                labels_host, labels, padding_index=-100)
            # Flush to the gatherers every step so host tensors are freed
            # immediately; in a multi-process run these would also need a
            # distributed_concat (cf. `_gather_and_numpify`) before add_arrays.
            eval_losses_gatherer.add_arrays(
                trainer_pt_utils.nested_numpify(losses_host))
            preds_gatherer.add_arrays(
                trainer_pt_utils.nested_numpify(preds_host))
            labels_gatherer.add_arrays(
                trainer_pt_utils.nested_numpify(labels_host))
            losses_host, preds_host, labels_host = None, None, None

        eval_loss = eval_losses_gatherer.finalize()
        preds = preds_gatherer.finalize()
        labels_ids = labels_gatherer.finalize()

        if self.type_score == "PER":
            preds_ids = np.argmax(preds, axis=-1)

            predicted_phonemes = self.processor.batch_decode(
                torch.from_numpy(preds_ids))
            true_phonemes = self.processor.batch_decode(
                torch.from_numpy(labels_ids))

            per = generate_per_score(true_phonemes, predicted_phonemes)

            return per

        elif self.type_score == "WER":
            pred = EvalPrediction(predictions=preds, label_ids=labels_ids)
            pred_logits = pred.predictions
            pred_ids = np.argmax(pred_logits, axis=-1)

            pred.label_ids[pred.label_ids ==
                           -100] = self.processor.tokenizer.pad_token_id

            pred_str = self.processor.batch_decode(pred_ids)

            # we do not want to group tokens when computing the metrics
            label_str = self.processor.batch_decode(pred.label_ids,
                                                    group_tokens=False)

            metrics = compute_wer(pred_str, label_str)
            metrics = denumpify_detensorize(metrics)
            metrics["t_loss"] = eval_loss.mean().item()
            wer = PredictionOutput(preds, labels_ids, metrics).metrics["wer"]

            return wer
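
Why only the loss gatherer is built with `make_multiple_of=batch_size`: the scalar batch loss is expanded with `loss.repeat(batch_size)`, so every `add_arrays` call for losses carries exactly `batch_size` entries even on a short final batch, and `finalize()` truncates the total back down to the number of samples. A toy sketch of that arithmetic (standalone, not the trainer code):

import torch

batch_size, num_examples = 8, 21
steps = -(-num_examples // batch_size)  # 3 batches cover 21 examples
losses = torch.cat([torch.tensor(0.5).repeat(batch_size)
                    for _ in range(steps)])
assert losses.shape[0] == 24             # a round multiple of batch_size...
assert losses.shape[0] >= num_examples   # ...that finalize() truncates to 21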
Example #5
    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with or without labels.
        """
        if not isinstance(dataloader.dataset, collections.abc.Sized):
            raise ValueError("dataset must implement __len__")
        prediction_loss_only = (
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", num_examples)
        logger.info("  Batch size = %d", batch_size)
        losses_host: torch.Tensor = None
        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None

        world_size = 1
        if is_torch_tpu_available():
            world_size = xm.xrt_world_size()
        elif self.args.local_rank != -1:
            world_size = torch.distributed.get_world_size()
        world_size = max(1, world_size)

        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
        if not prediction_loss_only:
            preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
            labels_gatherer = DistributedTensorGatherer(world_size, num_examples)

        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        self.callback_handler.eval_dataloader = dataloader

        for step, inputs in enumerate(dataloader):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
            if loss is not None:
                losses = loss.repeat(batch_size)
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if logits is not None:
                # Deviation from the stock loop: accumulate predicted ids
                # (argmax) instead of full logits to keep host memory low.
                # preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
                logits_reduced = logits.argmax(-1)
                preds_host = (
                    logits_reduced
                    if preds_host is None
                    else nested_concat(preds_host, logits_reduced, padding_index=-100)
                )
            if labels is not None:
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
                if not prediction_loss_only:
                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))

                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host = None, None, None

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
        if not prediction_loss_only:
            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))

        eval_loss = eval_losses_gatherer.finalize()
        preds = preds_gatherer.finalize() if not prediction_loss_only else None
        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}

        if eval_loss is not None:
            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
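
The `_gather_and_numpify` helper called in this loop is not shown in the snippet; a sketch consistent with the variant in Example #2 (the `name` argument only matters for the TPU mesh reduce):

def _gather_and_numpify(self, tensors, name):
    # All-gather across processes, then convert to numpy so the accumulated
    # predictions live in host memory rather than on the device.
    if tensors is None:
        return
    if is_torch_tpu_available():
        tensors = nested_xla_mesh_reduce(tensors, name)
    elif self.args.local_rank != -1:
        tensors = distributed_concat(tensors)
    return nested_numpify(tensors)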
Example #6
    def evaluate(self,
                 dataset,
                 data_collator=None,
                 description="",
                 metric_key_prefix="eval",
                 compute_metrics=None):
        # prediction with a single device

        eval_sampler = SequentialSampler(dataset)
        eval_dataloader = DataLoader(
            dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator
            if data_collator is None else data_collator,
            num_workers=self.args.dataloader_num_workers)

        batch_size = eval_dataloader.batch_size
        num_examples = len(eval_dataloader.dataset)
        logger.info("***** Running {} *****".format(description))
        logger.info("  Num examples = %d", len(dataset))
        logger.info("  Batch size = %d", self.args.eval_batch_size)
        losses_host: torch.Tensor = None
        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None

        world_size = max(1, self.args.world_size)
        compute_metrics = self.compute_metrics if compute_metrics is None else compute_metrics
        prediction_loss_only = True if compute_metrics is None else None

        eval_losses_gatherer = DistributedTensorGatherer(
            world_size, num_examples, make_multiple_of=batch_size)
        if not prediction_loss_only:
            # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
            # a batch size to the sampler)
            make_multiple_of = None
            if hasattr(eval_dataloader, "sampler") and isinstance(
                    eval_dataloader.sampler, SequentialDistributedSampler):
                make_multiple_of = eval_dataloader.sampler.batch_size
            preds_gatherer = DistributedTensorGatherer(
                world_size, num_examples, make_multiple_of=make_multiple_of)
            labels_gatherer = DistributedTensorGatherer(
                world_size, num_examples, make_multiple_of=make_multiple_of)

        model = self._wrap_model(self.model)
        model.eval()

        all_example_ids = []
        start_time = timeit.default_timer()
        for step, inputs in enumerate(tqdm(eval_dataloader)):
            if 'example_ids' in inputs.keys():
                example_ids = inputs.pop('example_ids')
                all_example_ids += example_ids
            loss, logits, labels = self.prediction_step(
                model, inputs, prediction_loss_only)

            if loss is not None:
                losses = loss.repeat(eval_dataloader.batch_size)
                losses_host = losses if losses_host is None else torch.cat(
                    (losses_host, losses), dim=0)
            if logits is not None:
                preds_host = logits if preds_host is None else nested_concat(
                    preds_host, logits, padding_index=-100)
            if labels is not None:
                labels_host = labels if labels_host is None else nested_concat(
                    labels_host, labels, padding_index=-100)

        # Gather all remaining tensors and put them back on the CPU
        eval_losses_gatherer.add_arrays(nested_numpify(losses_host))
        if not prediction_loss_only:
            preds_gatherer.add_arrays(nested_numpify(preds_host))
            labels_gatherer.add_arrays(nested_numpify(labels_host))

        eval_loss = eval_losses_gatherer.finalize()
        preds = preds_gatherer.finalize() if not prediction_loss_only else None
        label_ids = (labels_gatherer.finalize()
                     if not prediction_loss_only else None)

        if compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = compute_metrics(EvalPrediction(predictions=preds,
                                                     label_ids=label_ids),
                                      all_example_ids=all_example_ids
                                      if len(all_example_ids) > 0 else None)
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        eval_time = timeit.default_timer() - start_time
        logger.info("  Evaluation done in total %f secs (%f sec per example)",
                    eval_time, eval_time / len(dataset))

        if eval_loss is not None:
            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return PredictionOutput(
            predictions=preds,
            label_ids=label_ids,
            metrics=metrics,
            example_ids=None if len(all_example_ids) == 0 else all_example_ids)

    def prediction_loop(
            self,
            dataloader: DataLoader,
            description: str,
            prediction_loss_only: Optional[bool] = None,
            ignore_keys: Optional[List[str]] = None,
            metric_key_prefix: str = "eval",
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
        Works both with or without labels.
        """
        if not isinstance(dataloader.dataset, collections.abc.Sized):
            raise ValueError("dataset must implement __len__")
        prediction_loss_only = (
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        if self.args.deepspeed and not self.args.do_train:
            # no harm, but flagging to the user that deepspeed config is ignored for eval
            # flagging only for when --do_train wasn't passed as only then it's redundant
            logger.info("Detected the deepspeed argument but it will not be used for evaluation")

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
            # Note: in torch.distributed mode, there's no point in wrapping the model
            # inside a DistributedDataParallel as we'll be under `no_grad` anyways

        # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
        # ``train`` is running, half it first and then put on device

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", num_examples)
        logger.info("  Batch size = %d", batch_size)
        losses_host: torch.Tensor = None
        preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
        gumbel_host: Union[torch.Tensor, List[torch.Tensor]] = None
        sentence_labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
        sentence_indicator_host: Union[torch.Tensor, List[torch.Tensor]] = None

        world_size = 1
        if is_torch_tpu_available():
            world_size = xm.xrt_world_size()
        elif self.args.local_rank != -1:
            world_size = dist.get_world_size()
        world_size = max(1, world_size)

        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
        if not prediction_loss_only:
            preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
            labels_gatherer = DistributedTensorGatherer(world_size, num_examples)
            gumbel_gatherer = DistributedTensorGatherer(world_size, num_examples)
            sentence_labels_gatherer = DistributedTensorGatherer(world_size, num_examples)
            sentence_indicator_gatherer = DistributedTensorGatherer(world_size, num_examples)

        model.eval()

        if self.args.past_index >= 0:
            self._past = None

        self.callback_handler.eval_dataloader = dataloader

        for step, inputs in enumerate(dataloader):
            loss, logits, labels, gumbel_output, sentence_labels, sentence_indicator = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)

            if loss is not None:
                losses = loss.repeat(batch_size)
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if logits is not None:
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            if labels is not None:
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if gumbel_output is not None:
                gumbel_host = gumbel_output if gumbel_host is None else nested_concat(gumbel_host, gumbel_output, padding_index=-1)
            if sentence_labels is not None:
                sentence_labels_host = sentence_labels if sentence_labels_host is None else nested_concat(sentence_labels_host, sentence_labels, padding_index=-1)
            if sentence_indicator is not None:
                sentence_indicator_host = sentence_indicator if sentence_indicator_host is None else nested_concat(sentence_indicator_host, sentence_indicator, padding_index=-100)

            self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
                if not prediction_loss_only:
                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
                    gumbel_gatherer.add_arrays(self._gather_and_numpify(gumbel_host, "eval_gumbel_output"))
                    sentence_labels_gatherer.add_arrays(self._gather_and_numpify(sentence_labels_host, "eval_sentence_idxs"))
                    sentence_indicator_gatherer.add_arrays(self._gather_and_numpify(sentence_indicator_host, "eval_sentence_indicator"))

                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host, gumbel_host, sentence_labels_host, sentence_indicator_host = None, None, None, None, None, None

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
        if not prediction_loss_only:
            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
            gumbel_gatherer.add_arrays(self._gather_and_numpify(gumbel_host, "eval_gumbel_output"))
            sentence_labels_gatherer.add_arrays(self._gather_and_numpify(sentence_labels_host, "eval_sentence_idxs"))
            sentence_indicator_gatherer.add_arrays(self._gather_and_numpify(sentence_indicator_host, "eval_sentence_indicator"))

        eval_loss = eval_losses_gatherer.finalize()
        preds = preds_gatherer.finalize() if not prediction_loss_only else None
        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
        gumbel_outputs = gumbel_gatherer.finalize() if not prediction_loss_only else None
        sentence_idxs = sentence_labels_gatherer.finalize() if not prediction_loss_only else None
        sentence_indicators = sentence_indicator_gatherer.finalize() if not prediction_loss_only else None

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(preds, label_ids, gumbel_outputs, sentence_idxs, sentence_indicators)
        else:
            metrics = {}

        if eval_loss is not None:
            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
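
Both loops above end with the same key-prefixing idiom; factored out as a standalone helper (a refactoring sketch, not part of the original code) it reads:

def prefix_metrics(metrics, prefix="eval"):
    # Namespace every metric key that is not already prefixed.
    for key in list(metrics.keys()):
        if not key.startswith(f"{prefix}_"):
            metrics[f"{prefix}_{key}"] = metrics.pop(key)
    return metrics

assert prefix_metrics({"loss": 0.1, "eval_wer": 0.2}) == {
    "eval_loss": 0.1, "eval_wer": 0.2}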