Example 1
    def require_horovod_type(self, horovod_type: str, reason: str) -> None:
        """
        Declare the required type of horovod and give a unique reason as to why it is required.

        The reason makes for clear error reporting if require_horovod_type() is called a second
        time but with a different type.
        """

        known_types = {"tensorflow", "tensorflow.keras", "torch"}
        check.is_in(horovod_type, known_types, "Unknown horovod type requested.")

        if self._poly_hvd_type is not None:
            check.eq(
                horovod_type,
                self._poly_hvd_type,
                f"require_horovod_type() called with with type {horovod_type} after a previous "
                f"call with type {self._poly_hvd_type} in the same process. The reason for the "
                f"first call was '{self._poly_hvd_first_reason}'; the reason for this call is "
                f"'{reason}'.",
            )
        else:
            self._poly_hvd_type = horovod_type
            self._poly_hvd_first_reason = reason
            # If horovod has not been imported yet, try to import it now; a missing
            # horovod install is tolerated here.
            try:
                self._poly_hvd_module = importlib.import_module(f"horovod.{horovod_type}")
            except ImportError:
                pass
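
A hedged usage sketch of the guard above (the `context` object and the reason strings are hypothetical; only require_horovod_type() itself is taken from the snippet):

# Repeated calls with the same type are fine; a second call with a different type
# fails the check.eq() above and reports both reasons.
context.require_horovod_type("torch", "PyTorchTrialContext was created")
context.require_horovod_type("torch", "wrapping an optimizer")  # OK: same type
context.require_horovod_type("tensorflow.keras", "wrapping a Keras model")  # fails: conflicting type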
Example 2
    def on_train_batch_end(self, _: int, logs: Any = None) -> None:
        check.is_in("loss", logs)

        # Remove default keras metrics we aren't interested in like "batch" and
        # "size".
        self.metrics.append(
            {k: v
             for k, v in logs.items() if k not in {"batch", "size"}})
        self.batches_processed += 1
        if self.batches_processed != self.tf_keras_trial_controller.batches_per_step:
            return

        check.is_not_none(
            self.tf_keras_trial_controller.train_response_func,
            "Callback should avoid calling model.predict() or change model.stop_training "
            "as this will affect Determined training behavior",
        )
        response_func = cast(
            workload.ResponseFunc,
            self.tf_keras_trial_controller.train_response_func)

        # TODO(DET-1278): Average training metrics across GPUs when using Horovod.
        num_inputs = (self.tf_keras_trial_controller.batches_per_step *
                      self.tf_keras_trial_controller.batch_size)

        if self.tf_keras_trial_controller.is_chief:
            response_func(det.util.make_metrics(num_inputs, self.metrics))
        else:
            response_func(workload.Skipped())

        self.tf_keras_trial_controller.train_response_func = None
        self.metrics = []
        self.batches_processed = 0

        self.tf_keras_trial_controller.run()
Example 3
            def new_train_batch(batch: pytorch.TorchData, epoch_idx: int,
                                batch_idx: int) -> Any:
                tr_metrics = train_batch(
                    batch=batch,
                    model=model,
                    epoch_idx=epoch_idx,
                    batch_idx=batch_idx,
                )
                if isinstance(tr_metrics, torch.Tensor):
                    tr_metrics = {"loss": tr_metrics}
                check.is_instance(
                    tr_metrics,
                    dict,
                    "train_batch() must return a dictionary "
                    f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
                )
                check.is_in("loss", tr_metrics.keys(),
                            'Please include "loss" in you training metrics.')

                def clip_grads(parameters: Iterator) -> None:
                    for callback in self.callbacks.values():
                        callback.on_before_optimizer_step(parameters)

                self.context.backward(tr_metrics["loss"])
                self.context.step_optimizer(self.context.optimizers[0],
                                            clip_grads=clip_grads)

                return tr_metrics
Example 4
    def from_config(cls, config: Dict[str, Any],
                    container_path: Optional[str]) -> "StorageManager":
        allowed_keys = {
            "host_path", "storage_path", "container_path", "propagation"
        }
        for key in config.keys():
            check.is_in(key, allowed_keys, "extra key in shared_fs config")
        check.is_in("host_path", config,
                    "shared_fs config is missing host_path")
        # Ignore legacy configuration values propagation and container_path.
        base_path = _full_storage_path(config["host_path"],
                                       config.get("storage_path"),
                                       container_path)
        return cls(base_path)
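
An illustrative call to this classmethod, assuming it lives on a StorageManager subclass used for shared_fs storage; the paths below are made up:

# Hypothetical shared_fs config: "host_path" is required and unknown keys are rejected.
config = {"host_path": "/mnt/determined", "storage_path": "checkpoints"}
manager = StorageManager.from_config(config, container_path=None)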
Example 5
def binary_error_rate(predictions: torch.Tensor,
                      labels: torch.Tensor) -> float:
    """Return the classification error rate for binary classification."""
    check.eq(predictions.shape[0], labels.shape[0])
    check.is_in(len(predictions.shape), [1, 2])
    if len(predictions.shape) == 2:
        check.eq(predictions.shape[1], 1)
    check.len_eq(labels.shape, 1, "Labels must be a column vector")

    if len(predictions.shape) > 1:
        predictions = torch.squeeze(predictions)

    errors = torch.sum(
        labels.to(torch.long) != torch.round(predictions).to(torch.long))
    result = float(errors) / predictions.shape[0]  # type: float
    return result
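
A small worked example of binary_error_rate() with made-up values:

import torch

predictions = torch.tensor([0.9, 0.2, 0.7, 0.4])  # shape (4,), predicted probabilities
labels = torch.tensor([1.0, 0.0, 0.0, 0.0])       # shape (4,), ground-truth classes
# torch.round(predictions) -> [1., 0., 1., 0.]; one of the four predictions is wrong.
assert binary_error_rate(predictions, labels) == 0.25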
Example 6
    def on_train_batch_end(self, _: int, logs: Any = None) -> None:
        check.is_in("loss", logs)

        # Remove default keras metrics we aren't interested in like "batch" and
        # "size".
        self.metrics.append(
            {k: v
             for k, v in logs.items() if k not in {"batch", "size"}})
        self.batches_processed += 1
        if self.batches_processed != self.tf_keras_trial_controller.num_batches:
            return

        check.is_not_none(
            self.tf_keras_trial_controller.train_response_func,
            "Callback should avoid calling model.predict(), "
            "as this will affect Determined training behavior",
        )
        response_func = cast(
            workload.ResponseFunc,
            self.tf_keras_trial_controller.train_response_func)

        # TODO(DET-1278): Average training metrics across GPUs when using Horovod.
        num_inputs = (self.tf_keras_trial_controller.num_batches *
                      self.tf_keras_trial_controller.batch_size)

        if self.tf_keras_trial_controller.is_chief:
            response = {
                "metrics":
                det.util.make_metrics(num_inputs, self.metrics),
                "stop_requested":
                self.tf_keras_trial_controller.context.get_stop_requested(),
            }
            response_func(response)
        else:
            response_func(workload.Skipped())

        self.tf_keras_trial_controller.train_response_func = None
        self.metrics = []
        self.batches_processed = 0

        self.tf_keras_trial_controller.run()

        if self.model.stop_training and version.parse(
                tf.__version__) >= version.parse("2.2.0"):
            # Starting with TF 2.2, `model.stop_training` is only checked at the end of epochs.
            raise det.errors.WorkerFinishedGracefully
Example 7
    def wrap_lr_scheduler(
        self,
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
        step_mode: pytorch.LRScheduler.StepMode,
    ) -> torch.optim.lr_scheduler._LRScheduler:
        """
        Returns a wrapped LR scheduler.

        The LR scheduler must use an optimizer wrapped by :meth:`wrap_optimizer`.  If ``apex.amp``
        is in use, the optimizer must also have been configured with :meth:`configure_apex_amp`.
        """
        if isinstance(lr_scheduler,
                      torch.optim.lr_scheduler.ReduceLROnPlateau):
            if step_mode != pytorch.LRScheduler.StepMode.MANUAL_STEP:
                raise det.errors.InvalidExperimentException(
                    "detected that context.wrap_lr_scheduler() was called with an instance of "
                    "torch.optim.lr_scheduer.ReduceLROnPlateau as the lr_scheduler.  This lr "
                    "scheduler class does not have the usual step() parameters, and so it can "
                    "only be used with step_mode=MANUAL_STEP.\n"
                    "\n"
                    "For example, if you wanted to step it on every validation step, you might "
                    "wrap your lr_scheduler and pass it to a callback like this:\n"
                    "\n"
                    "class MyLRStepper(PyTorchCallback):\n"
                    "    def __init__(self, wrapped_lr_scheduler):\n"
                    "        self.wrapped_lr_scheduler = wrapped_lr_scheduler\n"
                    "\n"
                    "    def on_validation_end(self, metrics):\n"
                    '        self.wrapped_lr_scheduler.step(metrics["validation_error"])\n'
                )

        opt = getattr(lr_scheduler, "optimizer", None)
        if opt is not None:
            check.is_in(
                opt,
                self.optimizers,
                "Must use an optimizer that is returned by wrap_optimizer()",
            )
        wrapped = pytorch.LRScheduler(lr_scheduler, step_mode)
        self.lr_schedulers.append(wrapped)

        # Return the original LR scheduler to the user in case they have customizations that we
        # don't care about.
        return lr_scheduler
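
The error message above already sketches the intended pattern; assembled into a self-contained example it might look like this (MyLRStepper, my_optimizer, and the "validation_error" metric name are illustrative, not part of the API):

# Illustrative only: manual stepping of ReduceLROnPlateau from a validation callback.
class MyLRStepper(PyTorchCallback):
    def __init__(self, wrapped_lr_scheduler):
        self.wrapped_lr_scheduler = wrapped_lr_scheduler

    def on_validation_end(self, metrics):
        # ReduceLROnPlateau needs the monitored value passed explicitly to step().
        self.wrapped_lr_scheduler.step(metrics["validation_error"])

# my_optimizer must be an optimizer returned by context.wrap_optimizer().
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(my_optimizer)
lr_scheduler = context.wrap_lr_scheduler(lr_scheduler, pytorch.LRScheduler.StepMode.MANUAL_STEP)
callbacks = {"lr_stepper": MyLRStepper(lr_scheduler)}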
Example 8
    def _LRScheduler(
        self,
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
        step_mode: pytorch.LRScheduler.StepMode,
    ) -> pytorch.LRScheduler:
        """Wraps a LR scheduler. It returns a wrapped LR scheduler.

        The LR scheduler must use an optimizer wrapped by :meth:`Optimizer` and configured with
        :meth:`configure_apex_amp`.
        """

        check.is_in(
            lr_scheduler.optimizer,  # type: ignore
            self.optimizers,
            "Must use an optimizer that is returned by Optimizer()",
        )
        wrapped = pytorch.LRScheduler(lr_scheduler, step_mode)
        self.lr_schedulers.append(wrapped)
        return wrapped
Example 9
    def wrap_lr_scheduler(
        self,
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
        step_mode: pytorch.LRScheduler.StepMode,
    ) -> torch.optim.lr_scheduler._LRScheduler:
        """Returns a wrapped LR scheduler.

        The LR scheduler must use an optimizer wrapped by :meth:`wrap_optimizer`.  If ``apex.amp``
        is in use, the optimizer must also have been configured with :meth:`configure_apex_amp`.
        """

        check.is_in(
            lr_scheduler.optimizer,  # type: ignore
            self.optimizers,
            "Must use an optimizer that is returned by wrap_optimizer()",
        )
        wrapped = pytorch.LRScheduler(lr_scheduler, step_mode)
        self.lr_schedulers.append(wrapped)

        # Return the original LR scheduler to the user in case they have customizations that we
        # don't care about.
        return lr_scheduler
Example 10
    def from_configs(
        experiment_config: ExperimentConfig,
        rendezvous_info: RendezvousInfo,
        hparams: Dict[str, Any],
    ) -> "HorovodContext":
        """
        Create the HorovodContext according to experiment config and rendezvous info for this trial.
        """

        # Horovod is always used for multi-machine distributed training. For
        # single-machine multi-GPU training, Horovod is used when native_parallel is
        # disabled.
        multi_machine_trial = rendezvous_info.get_size() > 1
        multi_slot_trial = experiment_config["resources"]["slots_per_trial"] > 1
        use_horovod = multi_machine_trial or (
            multi_slot_trial and not experiment_config.native_parallel_enabled()
        )

        check.is_in("optimizations", experiment_config)
        optimizations_config = cast(Dict[str, Any], experiment_config.get("optimizations"))

        check.is_in("aggregation_frequency", optimizations_config)
        check.is_in("gradient_compression", optimizations_config)
        check.is_in("average_training_metrics", optimizations_config)

        # Help users migrate from the old locations for these settings, in hparams.
        def error_message_removed_from_hparams(removed_hparam: str) -> str:
            return (
                f"Please move `{removed_hparam}` in the experiment config to "
                f"`Optimizations` from `hyperparameters`."
            )

        check.not_in(
            "aggregation_frequency",
            hparams,
            error_message_removed_from_hparams("aggregation_frequency"),
        )
        check.not_in(
            "gradient_compression",
            hparams,
            error_message_removed_from_hparams("gradient_compression"),
        )
        check.not_in(
            "grad_updates_size_file",
            hparams,
            error_message_removed_from_hparams("grad_updates_size_file"),
        )

        hvd_config = HorovodContext(
            use=use_horovod,
            aggregation_frequency=cast(int, optimizations_config.get("aggregation_frequency")),
            fp16_compression=cast(bool, optimizations_config.get("gradient_compression")),
            grad_updates_size_file=optimizations_config.get("grad_updates_size_file", None),
            average_aggregated_gradients=cast(
                bool, optimizations_config.get("average_aggregated_gradients")
            ),
            average_training_metrics=cast(
                bool, optimizations_config.get("average_training_metrics")
            ),
        )

        if hvd_config.use and hvd_config.aggregation_frequency > 1:
            logging.info(
                f"Setting `aggregation_frequency` to {hvd_config.aggregation_frequency} "
                "to optimize training."
            )

        if hvd_config.use and hvd_config.fp16_compression:
            logging.info("Enabling `gradient_compression` to optimize training.")

        return hvd_config
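
For illustration, the shape of the `optimizations` section these checks expect, written as the Python mapping from_configs() receives (all values are examples only):

# Hypothetical values; the keys mirror the check.is_in() and cast() calls above.
experiment_config["optimizations"] = {
    "aggregation_frequency": 1,
    "gradient_compression": False,
    "average_aggregated_gradients": True,
    "average_training_metrics": False,
}
# The migrated settings must no longer appear under hyperparameters.
assert "aggregation_frequency" not in hparams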
Example 11
    def _train_for_step(self, step_id: int, num_batches: int,
                        total_batches_processed: int) -> workload.Response:
        check.gt(step_id, 0)

        # Set the behavior of certain layers (e.g., dropout) that are different
        # between training and inference.
        for model in self.context.models:
            model.train()

        start = total_batches_processed
        end = start + num_batches

        per_batch_metrics = []  # type: List[Dict]
        num_inputs = 0

        for batch_idx in range(start, end):
            batch = next(self.training_iterator)
            num_inputs += data_length(batch)
            batch = self.context._to_device(batch)

            self.context._current_batch_idx = batch_idx
            self.context._loss_ids = {}
            tr_metrics = self.trial.train_batch(
                batch=batch,
                model=self.context.models[0],
                epoch_idx=self.get_epoch_idx(batch_idx),
                batch_idx=batch_idx,
            )
            if isinstance(tr_metrics, torch.Tensor):
                tr_metrics = {"loss": tr_metrics}
            check.is_instance(
                tr_metrics,
                dict,
                "train_batch() must return a dictionary "
                f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
            )
            check.is_in("loss", tr_metrics.keys(),
                        'Please include "loss" in you training metrics.')

            # Step learning rate of a LRScheduler.
            for lr_scheduler in self.context.lr_schedulers:
                self._auto_step_lr_scheduler_per_batch(batch_idx, lr_scheduler)

            for name, metric in tr_metrics.items():
                # Convert PyTorch metric values to NumPy, so that
                # `det.util.encode_json` handles them properly without
                # needing a dependency on PyTorch.
                if isinstance(metric, torch.Tensor):
                    metric = metric.cpu().detach().numpy()
                tr_metrics[name] = metric

            check.is_in("loss", tr_metrics,
                        'Please include "loss" in your training metrics.')
            per_batch_metrics.append(tr_metrics)

        # Aggregate and reduce training metrics from all the training processes.
        if self.hvd_config.use and self.hvd_config.average_training_metrics:
            per_batch_metrics = self._average_training_metrics(
                per_batch_metrics)
        if self.hvd_config.use:
            num_inputs *= hvd.size()
        metrics = det.util.make_metrics(num_inputs, per_batch_metrics)

        if not self.is_chief:
            # The training metrics are reported only in the chief process.
            return workload.Skipped()

        logging.debug(
            f"Done training step: {num_inputs} records in {num_batches} batches."
        )

        return metrics
Example 12
    def _train_for_step(self, step_id: int,
                        batches_per_step: int) -> workload.Response:
        check.gt(step_id, 0)

        step_idx = step_id - 1
        start = step_idx * batches_per_step
        end = start + batches_per_step

        # Set the behavior of certain layers (e.g., dropout) that are different
        # between training and inference.
        self.model.train()

        per_batch_metrics = []  # type: List[Dict]
        num_inputs = 0

        for batch_idx in range(start, end):
            batch = next(self.training_iterator)
            num_inputs += data_length(batch)

            batch = self._to_device(batch)
            # Forward pass.
            tr_metrics = self.trial.train_batch(
                batch=batch,
                model=self.model,
                epoch_idx=self.get_epoch_idx(batch_idx),
                batch_idx=batch_idx,
            )

            if isinstance(tr_metrics, torch.Tensor):
                tr_metrics = {"loss": tr_metrics}

            check.is_instance(
                tr_metrics,
                dict,
                "train_batch() must return a dictionary "
                "mapping string names to Tensor metrics, got {type(tr_metrics)}",
            )
            check.is_in("loss", tr_metrics.keys(),
                        'Please include "loss" in you training metrics.')

            # Backwards pass.
            loss = tr_metrics["loss"]
            communicate_and_update = (
                batch_idx + 1) % self.hvd_config.aggregation_frequency == 0
            if self.use_amp():
                with apex.amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
                    if self.hvd_config.use and communicate_and_update:
                        self.optimizer.synchronize()
            else:
                loss.backward()

            if communicate_and_update:
                parameters = (self.model.parameters() if not self.use_amp()
                              else apex.amp.master_params(self.optimizer))

                if self.hvd_config.average_aggregated_gradients:
                    self._average_gradients(
                        parameters=parameters,
                        divisor=self.hvd_config.aggregation_frequency)

                self._clip_grads(parameters)

                if self.hvd_config.use and self.use_amp():
                    with self.optimizer.skip_synchronize():
                        self.optimizer.step()
                else:
                    self.optimizer.step()
                self.optimizer.zero_grad()

                if self.lr_helper.should_step_lr(
                        batches_completed=batch_idx + 1,
                        epoch_length=len(self.training_loader),
                        aggregation_frequency=self.hvd_config.aggregation_frequency,
                ):
                    self.lr_helper.step()

            for name, metric in tr_metrics.items():
                # Convert PyTorch metric values to NumPy, so that
                # `det.util.encode_json` handles them properly without
                # needing a dependency on PyTorch.
                if isinstance(metric, torch.Tensor):
                    metric = metric.cpu().detach().numpy()
                tr_metrics[name] = metric

            check.is_in("loss", tr_metrics,
                        'Please include "loss" in your training metrics.')
            per_batch_metrics.append(tr_metrics)

        if self.hvd_config.use and self.hvd_config.average_training_metrics:
            per_batch_metrics = self._average_training_metrics(
                per_batch_metrics)

        if not self.is_chief:
            return workload.Skipped()

        if self.hvd_config.use:
            num_inputs *= hvd.size()

        logging.debug(
            f"Done training step: {num_inputs} records in {batches_per_step} batches."
        )
        return det.util.make_metrics(num_inputs, per_batch_metrics)
Example 13
    def _apply_backwards_compatibility(self) -> None:
        # TODO(DET-3262): remove this backward compatibility of old interface.
        if (util.is_overridden(self.trial.build_model, PyTorchTrial)
                or util.is_overridden(self.trial.optimizer, PyTorchTrial)
                or util.is_overridden(self.trial.create_lr_scheduler,
                                      PyTorchTrial)):
            logging.warning(
                "build_model(), optimizer(), and create_lr_scheduler(), which belong to "
                "the old interface, are deprecated. Please see the following documentation "
                "of PyTorchTrial for the new interface \n"
                f"{PyTorchTrial.__doc__}")
            logging.warning(
                "The callback on_before_optimizer_step is deprecated."
                "Please use context.step_optimizer to clip gradients.")
            check.true(
                util.is_overridden(self.trial.build_model, PyTorchTrial)
                and util.is_overridden(self.trial.optimizer, PyTorchTrial),
                "Both build_model() and optimizer() must be defined "
                "if any of build_model(), optimizer(), and create_lr_scheduler() are defined. "
                "If you want to use the new interface, you should instead instantiate your models, "
                "optimizers, and LR schedulers in __init__ and call context.backward(loss) "
                "and context.step_optimizer(optimizer) in train_batch.",
            )

            model = self.context.wrap_model(self.trial.build_model())
            optim = self.context.wrap_optimizer(self.trial.optimizer(model))

            lr_scheduler = self.trial.create_lr_scheduler(optim)
            if lr_scheduler is not None:
                opt = getattr(lr_scheduler._scheduler, "optimizer", None)
                if opt is not None:
                    check.is_in(
                        opt,
                        self.context.optimizers,
                        "Must use a wrapped optimizer that is passed in by the optimizer "
                        "argument of create_lr_scheduler",
                    )
                self.context.lr_schedulers.append(lr_scheduler)

            if det.ExperimentConfig(self.context.get_experiment_config()
                                    ).mixed_precision_enabled():
                logging.warning(
                    "The experiment configuration field optimization.mixed_precision is deprecated."
                    "Please use configure_apex_amp in __init__ to configrue apex amp. "
                    "See the following documentation of PyTorchTrial for the new interface \n"
                    f"{PyTorchTrial.__doc__}")
                self.context.configure_apex_amp(
                    models=model,
                    optimizers=optim,
                    opt_level=self.context.get_experiment_config().get(
                        "optimizations", {}).get("mixed_precision", "O0"),
                )

            # Backward compatibility: train_batch
            train_batch = cast(Callable, self.trial.train_batch)

            def new_train_batch(batch: pytorch.TorchData, epoch_idx: int,
                                batch_idx: int) -> Any:
                tr_metrics = train_batch(
                    batch=batch,
                    model=model,
                    epoch_idx=epoch_idx,
                    batch_idx=batch_idx,
                )
                if isinstance(tr_metrics, torch.Tensor):
                    tr_metrics = {"loss": tr_metrics}
                check.is_instance(
                    tr_metrics,
                    dict,
                    "train_batch() must return a dictionary "
                    f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
                )
                check.is_in("loss", tr_metrics.keys(),
                            'Please include "loss" in you training metrics.')

                def clip_grads(parameters: Iterator) -> None:
                    for callback in self.callbacks.values():
                        callback.on_before_optimizer_step(parameters)

                self.context.backward(tr_metrics["loss"])
                self.context.step_optimizer(self.context.optimizers[0],
                                            clip_grads=clip_grads)

                return tr_metrics

            self.trial.__setattr__("train_batch", new_train_batch)

            # Backward compatibility: evaluate_batch
            if self._evaluate_batch_defined():
                evaluate_batch = cast(Callable, self.trial.evaluate_batch)

                def new_evaluate_batch(batch: pytorch.TorchData) -> Any:
                    return evaluate_batch(model=model, batch=batch)

                self.trial.__setattr__("evaluate_batch", new_evaluate_batch)

            # Backward compatibility: evaluate_full_dataset
            if self._evaluate_full_dataset_defined():
                evaluate_full_dataset = cast(Callable,
                                             self.trial.evaluate_full_dataset)

                def new_evaluate_full_dataset(
                        data_loader: torch.utils.data.DataLoader) -> Any:
                    return evaluate_full_dataset(model=model,
                                                 data_loader=data_loader)

                self.trial.__setattr__("evaluate_full_dataset",
                                       new_evaluate_full_dataset)
Example 14
    def _train_for_step(self, step_id: int, batches_per_step: int) -> workload.Response:
        check.gt(step_id, 0)

        # Set the behavior of certain layers (e.g., dropout) that are different
        # between training and inference.
        self.context.model.train()

        for callback in self.callbacks.values():
            callback.on_train_step_start(step_id)

        step_idx = step_id - 1
        start = step_idx * batches_per_step
        end = start + batches_per_step

        per_batch_metrics = []  # type: List[Dict]
        num_inputs = 0

        for batch_idx in range(start, end):
            batch = next(self.training_iterator)
            num_inputs += data_length(batch)

            batch = self._to_device(batch)
            # Forward pass.
            tr_metrics = self.trial.train_batch(
                batch=batch,
                model=self.context.model,
                epoch_idx=self.get_epoch_idx(batch_idx),
                batch_idx=batch_idx,
            )

            if isinstance(tr_metrics, torch.Tensor):
                tr_metrics = {"loss": tr_metrics}

            check.is_instance(
                tr_metrics,
                dict,
                "train_batch() must return a dictionary "
                "mapping string names to Tensor metrics, got {type(tr_metrics)}",
            )
            check.is_in("loss", tr_metrics.keys(), 'Please include "loss" in you training metrics.')

            # Backwards pass.
            loss = tr_metrics["loss"]
            communicate_and_update = (batch_idx + 1) % self.hvd_config.aggregation_frequency == 0
            if self.use_amp():
                with apex.amp.scale_loss(loss, self.context.optimizer) as scaled_loss:
                    scaled_loss.backward()
                    if self.hvd_config.use and communicate_and_update:
                        # When using horovod, we need to finish communicating gradient
                        # updates before they are unscaled, which happens when we exit
                        # this context manager.
                        self.context.optimizer.synchronize()
            else:
                loss.backward()

                # Communication needs to be synchronized so that it is completed
                # before we apply gradient clipping and `step()`.
                if communicate_and_update and self.hvd_config.use:
                    self.context.optimizer.synchronize()

            if communicate_and_update:
                parameters = (
                    self.context.model.parameters()
                    if not self.use_amp()
                    else apex.amp.master_params(self.context.optimizer)
                )

                if self.hvd_config.average_aggregated_gradients:
                    self._average_gradients(
                        parameters=parameters, divisor=self.hvd_config.aggregation_frequency
                    )

                # TODO: Remove this check in v0.12.8.
                check.false(
                    self.env.hparams.get("clip_grad_l2_norm", None)
                    or self.env.hparams.get("clip_grad_val", None),
                    "Please specify gradient clipping via callbacks.",
                )

                for callback in self.callbacks.values():
                    callback.on_before_optimizer_step(parameters)

                if self.hvd_config.use:
                    with self.context.optimizer.skip_synchronize():
                        self.context.optimizer.step()
                else:
                    self.context.optimizer.step()
                self.context.optimizer.zero_grad()

                # Step learning rate of a LRScheduler.
                if self.context.lr_scheduler is not None:
                    self._auto_step_lr_scheduler_per_batch(batch_idx, self.context.lr_scheduler)

            for name, metric in tr_metrics.items():
                # Convert PyTorch metric values to NumPy, so that
                # `det.util.encode_json` handles them properly without
                # needing a dependency on PyTorch.
                if isinstance(metric, torch.Tensor):
                    metric = metric.cpu().detach().numpy()
                tr_metrics[name] = metric

            check.is_in("loss", tr_metrics, 'Please include "loss" in your training metrics.')
            per_batch_metrics.append(tr_metrics)

        if self.hvd_config.use and self.hvd_config.average_training_metrics:
            per_batch_metrics = self._average_training_metrics(per_batch_metrics)

        if self.hvd_config.use:
            num_inputs *= hvd.size()

        metrics = det.util.make_metrics(num_inputs, per_batch_metrics)

        for callback in self.callbacks.values():
            callback.on_train_step_end(step_id, metrics)

        if not self.is_chief:
            return workload.Skipped()

        logging.debug(f"Done training step: {num_inputs} records in {batches_per_step} batches.")

        return metrics