    def _compute_validation_metrics(self) -> workload.Response:
        self.context.reset_reducers()
        # Set the behavior of certain layers (e.g., dropout) that are
        # different between training and inference.
        for model in self.context.models:
            model.eval()

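        # Record the wall-clock start time so validation throughput can be
        # logged once evaluation finishes.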
        step_start_time = time.time()

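        # Fire the deprecated on_validation_step_start() hooks only for callbacks
        # that actually override them (warning about the deprecation), then run
        # every callback's on_validation_start().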
        for callback in self.callbacks.values():
            if util.is_overridden(callback.on_validation_step_start,
                                  pytorch.PyTorchCallback):
                logging.warning("on_validation_step_start is now deprecated, "
                                "please use on_validation_start instead")
                callback.on_validation_step_start()

        for callback in self.callbacks.values():
            callback.on_validation_start()

        num_inputs = 0
        metrics = {}  # type: Dict[str, Any]

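        # Two evaluation modes are supported: a per-batch evaluate_batch(), whose
        # results are reduced across batches and workers, or
        # evaluate_full_dataset(), which the chief runs over the entire
        # validation loader in a single call.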
        if self._evaluate_batch_defined():
            keys = None
            batch_metrics = []

            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            check.gt(len(self.validation_loader), 0)
            for callback in self.callbacks.values():
                callback.on_validation_epoch_start()
            for idx, batch in enumerate(self.validation_loader):
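                # Move the batch to this worker's device unless the experimental
                # _auto_to_device flag has been turned off on the context.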
                if self.context.experimental._auto_to_device:
                    batch = self.context.to_device(batch)
                num_inputs += self.trial.get_batch_length(batch)

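                # Call the user's evaluate_batch(), passing batch_idx only when
                # their signature accepts it.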
                if has_param(self.trial.evaluate_batch, "batch_idx", 2):
                    vld_metrics = self.trial.evaluate_batch(batch=batch,
                                                            batch_idx=idx)
                else:
                    vld_metrics = self.trial.evaluate_batch(
                        batch=batch)  # type: ignore
                check.is_instance(
                    vld_metrics,
                    dict,
                    "evaluate_batch() must return a dictionary "
                    "of string names to Tensor metrics.",
                )
                # Verify validation metric names are the same across batches.
                if keys is None:
                    keys = vld_metrics.keys()
                else:
                    check.eq(
                        keys,
                        vld_metrics.keys(),
                        "Validation metric names must match across all batches of data.",
                    )
                # TODO: For performance, defer the -> cpu() conversion until the
                # end of validation.
                batch_metrics.append(
                    pytorch._convert_metrics_to_numpy(vld_metrics))
                if self.env.test_mode:
                    break

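            # Notify callbacks that the pass over the validation data is done,
            # handing them the raw (un-reduced) per-batch metrics.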
            for callback in self.callbacks.values():
                callback.on_validation_epoch_end(batch_metrics)

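            # Reduce the per-batch metrics across batches and workers, using the
            # reducers declared by the trial's evaluation_reducer().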
            metrics = pytorch._reduce_metrics(
                self.context.distributed,
                batch_metrics=batch_metrics,
                keys=keys,
                metrics_reducers=pytorch._prepare_metrics_reducers(
                    self.trial.evaluation_reducer(), keys=keys),
            )

            # Gather a list of per-worker (num_inputs, num_batches) tuples.
            input_counts = self.context.distributed.gather(
                (num_inputs, idx + 1))
            if self.context.distributed.rank == 0:
                assert input_counts is not None
                # Transpose into (all num_inputs, all num_batches) and sum each.
                num_inputs, num_batches = [sum(n) for n in zip(*input_counts)]

        else:
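            # evaluate_full_dataset(): only the chief worker runs evaluation,
            # over the entire validation loader in a single call.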
            check.true(self._evaluate_full_dataset_defined())
            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            if self.is_chief:
                metrics = self.trial.evaluate_full_dataset(
                    data_loader=self.validation_loader)

                check.is_instance(
                    metrics, dict,
                    "evaluate_full_dataset() must return a dictionary, "
                    f"got {type(metrics)}.")

                metrics = pytorch._convert_metrics_to_numpy(metrics)
                num_inputs = self.context.get_per_slot_batch_size() * len(
                    self.validation_loader)

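        # Merge in metrics produced by any custom reducers registered on the
        # context, converted to numpy like the rest of the metrics.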
        metrics.update(
            pytorch._convert_metrics_to_numpy(
                self.context.reduce_metrics(for_training=False)))

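        # Reduced metrics exist only on the chief, but on_validation_end() and
        # on_validation_step_end() callbacks may run on every worker, so
        # broadcast the metrics from rank 0 whenever any callback overrides
        # those hooks.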
        if self.context.distributed.size > 1 and any(
                util.is_overridden(c.on_validation_end, pytorch.PyTorchCallback)
                or util.is_overridden(c.on_validation_step_end,
                                      pytorch.PyTorchCallback)
                for c in self.callbacks.values()):
            logging.debug(
                "Broadcasting metrics to all worker processes to execute a "
                "validation step end callback")
            metrics = hvd.broadcast_object(metrics, root_rank=0)

        for callback in self.callbacks.values():
            if util.is_overridden(callback.on_validation_step_end,
                                  pytorch.PyTorchCallback):
                logging.warning(
                    "on_validation_step_end is now deprecated, please use on_validation_end instead"
                )
                callback.on_validation_step_end(metrics)

        for callback in self.callbacks.values():
            callback.on_validation_end(metrics)

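        # Only the chief worker reports validation results; other ranks return
        # an empty response.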
        if not self.is_chief:
            return {}

        # Skip reporting timings if evaluate_full_dataset() was defined.  This is far less common
        # than evaluate_batch() and we can't know how the user processed their validation data.
        if self._evaluate_batch_defined():
            step_duration = time.time() - step_start_time
            logging.info(
                det.util.make_timing_log("validated", step_duration,
                                         num_inputs, num_batches))

        return {"num_inputs": num_inputs, "validation_metrics": metrics}