Exemple #1
0
    def _load(self) -> None:
        if not self.load_path:
            return

        # Backwards compat with older checkpoint formats. List is newest to
        # oldest known state_dict locations.
        potential_paths = [
            ["state_dict.pth"],
            ["determined", "state_dict.pth"],
            ["pedl", "state_dict.pth"],
            ["checkpoint.pt"],
        ]

        checkpoint: Optional[Dict[str, Any]] = None
        for ckpt_path in potential_paths:
            maybe_ckpt = self.load_path.joinpath(*ckpt_path)
            if maybe_ckpt.exists():
                checkpoint = torch.load(str(maybe_ckpt),
                                        map_location="cpu")  # type: ignore
                break
        if checkpoint is None or not isinstance(checkpoint, dict):
            return

        for callback in self.callbacks.values():
            callback.on_checkpoint_load_start(checkpoint)

        if "model_state_dict" in checkpoint:
            # Backward compatible with older checkpoint format.
            check.not_in("models_state_dict", checkpoint)
            check.eq(len(self.context.models), 1)
            self.context.models[0].load_state_dict(
                checkpoint["model_state_dict"])
        else:
            for idx, model in enumerate(self.context.models):
                model.load_state_dict(checkpoint["models_state_dict"][idx])

        if "optimizer_state_dict" in checkpoint:
            # Backward compatible with older checkpoint format.
            check.not_in("optimizers_state_dict", checkpoint)
            check.eq(len(self.context.optimizers), 1)
            self.context.optimizers[0].load_state_dict(
                checkpoint["optimizer_state_dict"])
        else:
            for idx, optimizer in enumerate(self.context.optimizers):
                optimizer.load_state_dict(
                    checkpoint["optimizers_state_dict"][idx])

        if "lr_scheduler" in checkpoint:
            # Backward compatible with older checkpoint format.
            check.not_in("lr_schedulers_state_dict", checkpoint)
            check.eq(len(self.context.lr_schedulers), 1)
            self.context.lr_schedulers[0].load_state_dict(
                checkpoint["lr_scheduler"])
        else:
            for idx, lr_scheduler in enumerate(self.context.lr_schedulers):
                lr_scheduler.load_state_dict(
                    checkpoint["lr_schedulers_state_dict"][idx])

        if "scaler_state_dict":
            if self.context._scaler:
                self.context._scaler.load_state_dict(
                    checkpoint["scaler_state_dict"])
            else:
                logging.warning(
                    "There exists scaler_state_dict in checkpoint but the experiment is not using "
                    "AMP.")
        else:
            if self.context._scaler:
                logging.warning(
                    "The experiment is using AMP but scaler_state_dict does not exist in the "
                    "checkpoint.")

        if "amp_state" in checkpoint:
            if self.context._use_apex:
                apex.amp.load_state_dict(checkpoint["amp_state"])
            else:
                logging.warning(
                    "There exists amp_state in checkpoint but the experiment is not using Apex."
                )
        else:
            if self.context._use_apex:
                logging.warning(
                    "The experiment is using Apex but amp_state does not exist in the checkpoint."
                )

        if "rng_state" in checkpoint:
            rng_state = checkpoint["rng_state"]
            np.random.set_state(rng_state["np_rng_state"])
            random.setstate(rng_state["random_rng_state"])
            torch.random.set_rng_state(rng_state["cpu_rng_state"])

            if torch.cuda.device_count():
                if "gpu_rng_state" in rng_state:
                    torch.cuda.set_rng_state(
                        rng_state["gpu_rng_state"],
                        device=self.context.distributed.get_local_rank())
                else:
                    logging.warning(
                        "The system has a gpu but no gpu_rng_state exists in the checkpoint."
                    )
            else:
                if "gpu_rng_state" in rng_state:
                    logging.warning(
                        "There exists gpu_rng_state in checkpoint but the system has no gpu."
                    )
        else:
            logging.warning("The checkpoint has no random state to restore.")

        callback_state = checkpoint.get("callbacks", {})
        for name in self.callbacks:
            if name in callback_state:
                self.callbacks[name].load_state_dict(callback_state[name])
            elif util.is_overridden(self.callbacks[name].load_state_dict,
                                    pytorch.PyTorchCallback):
                logging.warning(
                    "Callback '{}' implements load_state_dict(), but no callback state "
                    "was found for that name when restoring from checkpoint. This "
                    "callback will be initialized from scratch")
Exemple #2
0
    def from_configs(
        experiment_config: ExperimentConfig,
        rendezvous_info: RendezvousInfo,
        hparams: Dict[str, Any],
    ) -> "HorovodContext":
        """
        Build the HorovodContext for this trial from the experiment config and rendezvous info.

        Validation is delegated to ``check.is_in`` / ``check.not_in``: the config must carry
        an ``optimizations`` section with the required keys, and settings that now live in
        that section must no longer appear in ``hparams``.
        """

        # Multi-machine trials always go through Horovod. A single-machine,
        # multi-slot trial uses Horovod only when native_parallel is disabled.
        spans_machines = rendezvous_info.get_size() > 1
        spans_slots = experiment_config["resources"]["slots_per_trial"] > 1
        if spans_machines:
            use_horovod = True
        elif spans_slots:
            use_horovod = not experiment_config.native_parallel_enabled()
        else:
            use_horovod = False

        check.is_in("optimizations", experiment_config)
        opt_cfg = cast(Dict[str, Any], experiment_config.get("optimizations"))

        for required_key in (
                "aggregation_frequency",
                "gradient_compression",
                "average_training_metrics",
        ):
            check.is_in(required_key, opt_cfg)

        # Point users who still set these in hparams at the new location.
        for moved_key in (
                "aggregation_frequency",
                "gradient_compression",
                "grad_updates_size_file",
        ):
            check.not_in(
                moved_key,
                hparams,
                f"Please move `{moved_key}` in the experiment config to "
                f"`Optimizations` from `hyperparameters`.",
            )

        aggregation_frequency = cast(int, opt_cfg.get("aggregation_frequency"))
        fp16_compression = cast(bool, opt_cfg.get("gradient_compression"))
        grad_updates_size_file = opt_cfg.get("grad_updates_size_file", None)
        average_aggregated_gradients = cast(
            bool, opt_cfg.get("average_aggregated_gradients"))
        average_training_metrics = cast(
            bool, opt_cfg.get("average_training_metrics"))

        hvd_config = HorovodContext(
            use=use_horovod,
            aggregation_frequency=aggregation_frequency,
            fp16_compression=fp16_compression,
            grad_updates_size_file=grad_updates_size_file,
            average_aggregated_gradients=average_aggregated_gradients,
            average_training_metrics=average_training_metrics,
        )

        # Only report the optimizations when Horovod is actually in use.
        if hvd_config.use:
            if hvd_config.aggregation_frequency > 1:
                logging.info(
                    f"Setting `aggregation_frequency` to {hvd_config.aggregation_frequency} "
                    "to optimize training.")
            if hvd_config.fp16_compression:
                logging.info(
                    "Enabling `gradient_compression` to optimize training.")

        return hvd_config
    def _load(self, load_path: pathlib.Path) -> None:
        # Backwards compat with older checkpoint formats. List is newest to
        # oldest known state_dict locations.
        potential_paths = [
            ["state_dict.pth"],
            ["determined", "state_dict.pth"],
            ["pedl", "state_dict.pth"],
            ["checkpoint.pt"],
        ]

        checkpoint: Optional[Dict[str, Any]] = None
        for ckpt_path in potential_paths:
            maybe_ckpt = load_path.joinpath(*ckpt_path)
            if maybe_ckpt.exists():
                checkpoint = torch.load(str(maybe_ckpt),
                                        map_location="cpu")  # type: ignore
                break
        if checkpoint is None or not isinstance(checkpoint, dict):
            return

        for callback in self.callbacks.values():
            callback.on_checkpoint_load_start(checkpoint)

        if "model_state_dict" in checkpoint:
            # Backward compatible with older checkpoint format.
            check.not_in("models_state_dict", checkpoint)
            check.eq(len(self.context.models), 1)
            self.context.models[0].load_state_dict(
                checkpoint["model_state_dict"])
        else:
            for idx, model in enumerate(self.context.models):
                model_state_dict = checkpoint["models_state_dict"][idx]
                try:
                    model.load_state_dict(model_state_dict)
                except Exception:
                    # If the checkpointed model is non-DDP and the current model is DDP, append
                    # module prefix to the checkpointed data
                    if isinstance(model,
                                  torch.nn.parallel.DistributedDataParallel):
                        logging.debug(
                            "Loading non-DDP checkpoint into a DDP model")
                        self._add_prefix_in_state_dict_if_not_present(
                            model_state_dict, "module.")
                    else:
                        # If the checkpointed model is DDP and we are currently running in
                        # single-slot mode, remove the module prefix from checkpointed data
                        logging.debug(
                            "Loading DDP checkpoint into a non-DDP model")
                        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
                            model_state_dict, "module.")
                    model.load_state_dict(model_state_dict)

        if "optimizer_state_dict" in checkpoint:
            # Backward compatible with older checkpoint format.
            check.not_in("optimizers_state_dict", checkpoint)
            check.eq(len(self.context.optimizers), 1)
            self.context.optimizers[0].load_state_dict(
                checkpoint["optimizer_state_dict"])
        else:
            for idx, optimizer in enumerate(self.context.optimizers):
                optimizer.load_state_dict(
                    checkpoint["optimizers_state_dict"][idx])

        if "lr_scheduler" in checkpoint:
            # Backward compatible with older checkpoint format.
            check.not_in("lr_schedulers_state_dict", checkpoint)
            check.eq(len(self.context.lr_schedulers), 1)
            self.context.lr_schedulers[0].load_state_dict(
                checkpoint["lr_scheduler"])
        else:
            for idx, lr_scheduler in enumerate(self.context.lr_schedulers):
                lr_scheduler.load_state_dict(
                    checkpoint["lr_schedulers_state_dict"][idx])

        if "scaler_state_dict" in checkpoint:
            if self.context._scaler:
                self.context._scaler.load_state_dict(
                    checkpoint["scaler_state_dict"])
            else:
                logging.warning(
                    "There exists scaler_state_dict in checkpoint but the experiment is not using "
                    "AMP.")
        else:
            if self.context._scaler:
                logging.warning(
                    "The experiment is using AMP but scaler_state_dict does not exist in the "
                    "checkpoint.")

        if "amp_state" in checkpoint:
            if self.context._use_apex:
                apex.amp.load_state_dict(checkpoint["amp_state"])
            else:
                logging.warning(
                    "There exists amp_state in checkpoint but the experiment is not using Apex."
                )
        else:
            if self.context._use_apex:
                logging.warning(
                    "The experiment is using Apex but amp_state does not exist in the checkpoint."
                )

        if "rng_state" in checkpoint:
            rng_state = checkpoint["rng_state"]
            np.random.set_state(rng_state["np_rng_state"])
            random.setstate(rng_state["random_rng_state"])
            torch.random.set_rng_state(rng_state["cpu_rng_state"])

            if torch.cuda.device_count():
                if "gpu_rng_state" in rng_state:
                    torch.cuda.set_rng_state(
                        rng_state["gpu_rng_state"],
                        device=self.context.distributed.local_rank)
                else:
                    logging.warning(
                        "The system has a gpu but no gpu_rng_state exists in the checkpoint."
                    )
            else:
                if "gpu_rng_state" in rng_state:
                    logging.warning(
                        "There exists gpu_rng_state in checkpoint but the system has no gpu."
                    )
        else:
            logging.warning("The checkpoint has no random state to restore.")

        callback_state = checkpoint.get("callbacks", {})
        for name in self.callbacks:
            if name in callback_state:
                self.callbacks[name].load_state_dict(callback_state[name])
            elif util.is_overridden(self.callbacks[name].load_state_dict,
                                    pytorch.PyTorchCallback):
                logging.warning(
                    "Callback '{}' implements load_state_dict(), but no callback state "
                    "was found for that name when restoring from checkpoint. This "
                    "callback will be initialized from scratch")

        # Load workload sequencer state.
        wlsq_path = load_path.joinpath("workload_sequencer.pkl")
        if self.wlsq is not None and wlsq_path.exists():
            with wlsq_path.open("rb") as f:
                self.wlsq.load_state(pickle.load(f))