Example #1
def load_model(
    ckpt_dir: pathlib.Path, metadata: Dict[str, Any], **kwargs: Any
) -> Union[PyTorchTrial, torch.nn.Module]:
    checkpoint = torch.load(ckpt_dir.joinpath("state_dict.pth"), **kwargs)  # type: ignore

    trial_cls, trial_context = experimental._load_trial_on_local(
        ckpt_dir.joinpath("code"),
        training=False,
        config=metadata["experiment_config"],
        hparams=metadata["hparams"],
    )

    trial_context = cast(PyTorchTrialContext, trial_context)
    trial = cast(PyTorchTrial, trial_cls(trial_context))
    if "model_state_dict" in checkpoint:
        # Backward compatible with older checkpoint format.
        model = trial.build_model()
        model.load_state_dict(checkpoint["model_state_dict"])
        return model
    else:
        # Backward compatible with older interface
        if util.is_overridden(trial.build_model, PyTorchTrial):
            model = trial.build_model()
            model.load_state_dict(checkpoint["models_state_dict"][0])
            return model
        else:
            for idx, model in enumerate(trial_context.models):
                model.load_state_dict(checkpoint["models_state_dict"][idx])
            return trial
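
A minimal usage sketch for the load_model helper above. The checkpoint directory and the metadata.json location are assumptions for illustration; extra keyword arguments such as map_location are forwarded to torch.load via **kwargs.

import json
import pathlib

ckpt_dir = pathlib.Path("checkpoints/latest")  # hypothetical checkpoint directory
with ckpt_dir.joinpath("metadata.json").open() as f:  # assumed metadata file
    metadata = json.load(f)

loaded = load_model(ckpt_dir, metadata, map_location="cpu")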
Example #2
    def _load(self) -> None:
        if not self.load_path:
            return

        # Backwards compat with older checkpoint formats. List is newest to
        # oldest known state_dict locations.
        potential_paths = [
            ["state_dict.pth"],
            ["determined", "state_dict.pth"],
            ["pedl", "state_dict.pth"],
            ["checkpoint.pt"],
        ]

        checkpoint = None
        for ckpt_path in potential_paths:
            maybe_ckpt = self.load_path.joinpath(*ckpt_path)
            if maybe_ckpt.exists():
                checkpoint = torch.load(maybe_ckpt,
                                        map_location="cpu")  # type: ignore
                break
        if checkpoint is None:
            # No known checkpoint layout was found under load_path.
            return

        self.context.model.load_state_dict(checkpoint["model_state_dict"])
        self.context.optimizer.load_state_dict(
            checkpoint["optimizer_state_dict"])
        self.lr_helper.load_state_dict(checkpoint.get("lr_scheduler"))

        callback_state = checkpoint.get("callbacks", {})
        for name in self.callbacks:
            if name in callback_state:
                self.callbacks[name].load_state_dict(callback_state[name])
            elif util.is_overridden(self.callbacks[name].load_state_dict,
                                    _callback.PyTorchCallback):
                logging.warning(
                    "Callback '{}' implements load_state_dict(), but no callback state "
                    "was found for that name when restoring from checkpoint. This "
                    "callback will be initialized from scratch.".format(name))
def _check_if_aggregation_frequency_will_work(
    model: tf.keras.Model,
    hvd_config: horovod.HorovodContext,
) -> None:
    if not hvd_config.use or hvd_config.aggregation_frequency == 1:
        return

    if model._is_graph_network or isinstance(model, sequential.Sequential):
        return

    if version.parse(tf.__version__) >= version.parse("2.4.0"):
        return

    if util.is_overridden(model.train_step, tf.keras.Model):
        logging.warning(
            "If you are subclassing tf.keras.Model in TF 2.2 or TF 2.3 and defining "
            "a custom train_step() method, then in order to use aggregation_frequency > 1 "
            "you need to include the following steps in your train_step(): "
            "for each optimizer you must call `aggregated_gradients = "
            "optimizer._aggregate_gradients(grads, vars)` and then call "
            "`optimizer.apply_gradients(zip(aggregated_gradients, vars), "
            "experimental_aggregate_gradients=False)`.")
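
A hedged illustration of the case this check guards against: a subclassed (non-graph) tf.keras.Model that overrides train_step(). Under TF 2.2/2.3 with aggregation_frequency > 1 this is exactly what would trigger the warning above; the model itself is hypothetical.

import tensorflow as tf

class SubclassedModel(tf.keras.Model):  # hypothetical model
    def __init__(self) -> None:
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.dense(inputs)

    def train_step(self, data):
        # Overriding train_step() is what util.is_overridden() detects above.
        return super().train_step(data)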
    def _load(self) -> None:
        if not self.load_path:
            return

        # Backwards compat with older checkpoint formats. List is newest to
        # oldest known state_dict locations.
        potential_paths = [
            ["state_dict.pth"],
            ["determined", "state_dict.pth"],
            ["pedl", "state_dict.pth"],
            ["checkpoint.pt"],
        ]

        checkpoint = None
        for ckpt_path in potential_paths:
            maybe_ckpt = self.load_path.joinpath(*ckpt_path)
            if maybe_ckpt.exists():
                checkpoint = torch.load(str(maybe_ckpt),
                                        map_location="cpu")  # type: ignore
                break
        if checkpoint is None:
            # No known checkpoint layout was found under load_path.
            return

        if "model_state_dict" in checkpoint:
            # Backward compatible with older checkpoint format.
            check.not_in("models_state_dict", checkpoint)
            check.eq(len(self.context.models), 1)
            self.context.models[0].load_state_dict(
                checkpoint["model_state_dict"])
        else:
            for idx, model in enumerate(self.context.models):
                model.load_state_dict(checkpoint["models_state_dict"][idx])

        if "optimizer_state_dict" in checkpoint:
            # Backward compatible with older checkpoint format.
            check.not_in("optimizers_state_dict", checkpoint)
            check.eq(len(self.context.optimizers), 1)
            self.context.optimizers[0].load_state_dict(
                checkpoint["optimizer_state_dict"])
        else:
            for idx, optimizer in enumerate(self.context.optimizers):
                optimizer.load_state_dict(
                    checkpoint["optimizers_state_dict"][idx])

        if "lr_scheduler" in checkpoint:
            # Backward compatible with older checkpoint format.
            check.not_in("lr_schedulers_state_dict", checkpoint)
            check.eq(len(self.context.lr_schedulers), 1)
            self.context.lr_schedulers[0].load_state_dict(
                checkpoint["lr_scheduler"])
        else:
            for idx, lr_scheduler in enumerate(self.context.lr_schedulers):
                lr_scheduler.load_state_dict(
                    checkpoint["lr_schedulers_state_dict"][idx])

        if "amp_state" in checkpoint:
            if self.context._use_amp:
                apex.amp.load_state_dict(checkpoint["amp_state"])
            else:
                logging.warning(
                    "There exists amp_state in checkpoint but the experiment is not using AMP."
                )
        else:
            if self.context._use_amp:
                logging.warning(
                    "The experiment is using AMP but amp_state does not exist in the checkpoint."
                )

        if "rng_state" in checkpoint:
            rng_state = checkpoint["rng_state"]
            np.random.set_state(rng_state["np_rng_state"])
            random.setstate(rng_state["random_rng_state"])
            torch.random.set_rng_state(
                rng_state["cpu_rng_state"])  # type: ignore

            if torch.cuda.device_count():
                if "gpu_rng_state" in rng_state:
                    torch.cuda.set_rng_state(  # type: ignore
                        rng_state["gpu_rng_state"],
                        device=self.context.distributed.get_local_rank())
                else:
                    logging.warning(
                        "The system has a GPU, but no gpu_rng_state exists in the checkpoint."
                    )
            else:
                if "gpu_rng_state" in rng_state:
                    logging.warning(
                        "The checkpoint contains gpu_rng_state, but the system has no GPU."
                    )
        else:
            logging.warning("The checkpoint has no random state to restore.")

        callback_state = checkpoint.get("callbacks", {})
        for name in self.callbacks:
            if name in callback_state:
                self.callbacks[name].load_state_dict(callback_state[name])
            elif util.is_overridden(self.callbacks[name].load_state_dict,
                                    pytorch.PyTorchCallback):
                logging.warning(
                    "Callback '{}' implements load_state_dict(), but no callback state "
                    "was found for that name when restoring from checkpoint. This "
                    "callback will be initialized from scratch.".format(name))
    def _compute_validation_metrics(self) -> workload.Response:
        self.context.experimental.reset_reducers()
        # Set the behavior of certain layers (e.g., dropout) that are
        # different between training and inference.
        for model in self.context.models:
            model.eval()

        for callback in self.callbacks.values():
            if util.is_overridden(callback.on_validation_step_start,
                                  pytorch.PyTorchCallback):
                logging.warning(
                    "on_validation_step_start is now deprecated, please use "
                    "on_validation_start instead")
                callback.on_validation_step_start()

        for callback in self.callbacks.values():
            callback.on_validation_start()

        num_inputs = 0
        metrics = {}  # type: Dict[str, Any]

        if self._evaluate_batch_defined():
            keys = None
            batch_metrics = []

            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            check.gt(len(self.validation_loader), 0)
            for batch in self.validation_loader:
                batch = self.context.to_device(batch)
                num_inputs += pytorch.data_length(batch)

                vld_metrics = self.trial.evaluate_batch(batch=batch)
                # Verify validation metric names are the same across batches.
                if keys is None:
                    keys = vld_metrics.keys()
                else:
                    check.eq(
                        keys,
                        vld_metrics.keys(),
                        "Validation metric names must match across all batches of data.",
                    )
                check.is_instance(
                    vld_metrics,
                    dict,
                    "evaluate_batch() must return a "
                    "dictionary of string names to Tensor "
                    "metrics",
                )
                # TODO: for performance, convert metrics to CPU only at the end of validation.
                batch_metrics.append(
                    self._convert_metrics_to_numpy(vld_metrics))
                if self.env.test_mode:
                    break

            metrics = self._reduce_metrics(
                batch_metrics=batch_metrics,
                keys=keys,
                metrics_reducers=self._prepare_metrics_reducers(keys=keys),
            )

            if self.hvd_config.use:
                num_inputs *= hvd.size()

        else:
            check.true(self._evaluate_full_dataset_defined())
            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            if self.is_chief:
                metrics = self.trial.evaluate_full_dataset(
                    data_loader=self.validation_loader)

                check.is_instance(
                    metrics, dict,
                    f"evaluate_full_dataset() must return a dictionary, got {type(metrics)}.")

                metrics = self._convert_metrics_to_numpy(metrics)
                num_inputs = self.context.get_per_slot_batch_size() * len(
                    self.validation_loader)

        metrics.update(
            self._convert_metrics_to_numpy(
                self.context.experimental.reduce_metrics(for_training=False)))

        if self.hvd_config.use and any(
                util.is_overridden(c.on_validation_end,
                                   pytorch.PyTorchCallback)
                or util.is_overridden(c.on_validation_step_end,
                                      pytorch.PyTorchCallback)
                for c in self.callbacks.values()):
            logging.debug(
                "Broadcasting metrics to all worker processes to execute a "
                "validation step end callback")
            metrics = hvd.broadcast_object(metrics, root_rank=0)

        for callback in self.callbacks.values():
            if util.is_overridden(callback.on_validation_step_end,
                                  pytorch.PyTorchCallback):
                logging.warning(
                    "on_validation_step_end is now deprecated, please use "
                    "on_validation_end instead")
                callback.on_validation_step_end(metrics)

        for callback in self.callbacks.values():
            callback.on_validation_end(metrics)

        if not self.is_chief:
            return workload.Skipped()

        return {"num_inputs": num_inputs, "validation_metrics": metrics}
    def _evaluate_full_dataset_defined(self) -> bool:
        return util.is_overridden(self.trial.evaluate_full_dataset,
                                  PyTorchTrial)

    def _evaluate_batch_defined(self) -> bool:
        return util.is_overridden(self.trial.evaluate_batch, PyTorchTrial)
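
These two helpers make the validation styles mutually exclusive: a trial defines exactly one of evaluate_batch() or evaluate_full_dataset(). A hedged sketch of the two shapes; the trial classes and metric values are placeholders for illustration.

class BatchwiseTrial(PyTorchTrial):  # hypothetical trial
    def evaluate_batch(self, batch):
        return {"validation_loss": 0.0}  # placeholder metric

class FullDatasetTrial(PyTorchTrial):  # hypothetical trial
    def evaluate_full_dataset(self, data_loader):
        return {"validation_loss": 0.0}  # placeholder metric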
Example #8
    def _configure_callbacks(self, user_callbacks: Optional[List]) -> None:
        """
        If we pass a callbacks parameter to model.fit() or model.evaluate() which is a
        pre-constructed CallbackList, Keras will not alter it.  We can use this property to
        configure the exact callback order that we want in our system.

        The implementation is based closely on the real
        tf.keras.callbacks.configure_callbacks(), with the following differences:

          - We always assume we have the original Callbacks list.
          - We prepend and append additional Determined and Horovod callbacks.
          - We create a det.keras.CallbackList instead of the normal tf.keras one.
        """

        callbacks = user_callbacks or []
        check.is_instance(
            callbacks,
            list,
            "the callbacks parameter of model.fit() or model.eval() must be a list of Callbacks",
        )

        if self.env.experiment_config.get_records_per_epoch() is None:
            for cb in callbacks:
                if util.is_overridden(
                        cb.on_epoch_end,
                        tf.keras.callbacks.Callback) and not getattr(
                            cb, "_skip_epoch_end_check", False):
                    if isinstance(cb, keras.callbacks.Callback):
                        # New callbacks must obey the rules.
                        raise AssertionError(
                            "it is unsupported to use a Callback that defines on_epoch_end "
                            f"({type(cb).__name__}) without setting the records_per_epoch value "
                            "in the experiment config")
                    else:
                        # Pre-existing callbacks only get a warning.
                        logging.warning(
                            "It is unsupported to use a Callback that defines on_epoch_end "
                            f"({type(cb).__name__}) without setting the records_per_epoch value in "
                            "the experiment config. Training will continue but on_epoch_end will "
                            "never be called.")

        # Standard post-callback from the real configure_callbacks().
        # Note that we are not including BaseLogger since it is only for averaging metrics over an
        # entire epoch, and we don't report any metrics in on_epoch_end at all.
        self.model.history = keras.callbacks._DeterminedHistory()
        callbacks = callbacks + [self.model.history]

        if self.context._fit_verbose:
            # Our implementation of verbose=True.
            callbacks = [keras.callbacks._DeterminedProgress()] + callbacks

        # Calculate batches per epoch.  We can only handle batches per epoch, not records per epoch,
        # because we would have to communicate after every batch to know how many records were in
        # each batch on each worker in order to trigger on_epoch_end callbacks correctly.
        batches_per_epoch = None
        records_per_epoch = self.env.experiment_config.get_records_per_epoch()
        if records_per_epoch is not None:
            batches_per_epoch = records_per_epoch // self.context.get_global_batch_size()

        # We wrap all of the callbacks in a single Multiplexer.
        self.multiplexer = TrialControllerMultiplexer(
            self,
            callbacks,
            self.is_chief,
            self.batch_size,
            batches_per_epoch,
            self.multiplexer_load_state,
        )
        callbacks = [self.multiplexer]

        if self.hvd_config.use:
            # Horovod synchronization of initial variables should happen even before we enter our
            # control loop, in case we have an initial validation requested.
            callbacks = [hvd.callbacks.BroadcastGlobalVariablesCallback(0)] + callbacks

        # The remainder of Determined control logic is done with a custom CallbackList
        self.callback_list = CallbackList(callbacks)

        # Disable timing of callbacks in some versions of keras by pre-marking these hooks
        # as already timed. Timing can fail in some corner-cases because CallbackList is not
        # designed to allow some callbacks to call other callbacks, and they can interact
        # very poorly.
        if hasattr(self.callback_list, "_timing"):
            self.callback_list._timing["on_train_batch_begin"] = True
            self.callback_list._timing["on_train_batch_end"] = True
            self.callback_list._timing["on_test_batch_begin"] = True
            self.callback_list._timing["on_test_batch_end"] = True
            self.callback_list._timing["on_predict_batch_begin"] = True
            self.callback_list._timing["on_predict_batch_end"] = True

        # callback_model is the model given to callbacks, where we should be checking for
        # stop_training.  In horovod dtrain or non-dtrain, it should always be self.model.
        callback_model = self.model._get_callback_model()
        self.callback_list.set_model(callback_model)

        # Fill in bogus values for most of these... some of them are very complex to calculate.
        set_callback_parameters(
            self.callback_list,
            self.model,
            do_validation=False,
            batch_size=self.batch_size,
            epochs=None,
            steps_per_epoch=None,
            samples=None,
            verbose=False,
            mode=ModeKeys.TRAIN,
        )

        self.callback_list.model.stop_training = False
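
The getattr(cb, "_skip_epoch_end_check", False) escape hatch above can be illustrated with a hedged sketch: a callback that defines on_epoch_end() but opts out of the records_per_epoch requirement by setting the attribute. The callback class is hypothetical.

import tensorflow as tf

class EpochLogger(tf.keras.callbacks.Callback):  # hypothetical callback
    _skip_epoch_end_check = True  # bypasses the records_per_epoch check above

    def on_epoch_end(self, epoch, logs=None):
        print(f"finished epoch {epoch}")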
Example #9
    def __init__(self, trial_inst: det.Trial, *args: Any,
                 **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        check.is_instance(trial_inst, PyTorchTrial,
                          "PyTorchTrialController needs a PyTorchTrial")
        self.trial = cast(PyTorchTrial, trial_inst)
        self.context = cast(PyTorchTrialContext, self.context)
        self.callbacks = self.trial.build_callbacks()

        # TODO(DET-3262): remove this backward compatibility of old interface.
        if (util.is_overridden(self.trial.build_model, PyTorchTrial)
                or util.is_overridden(self.trial.optimizer, PyTorchTrial)
                or util.is_overridden(self.trial.create_lr_scheduler,
                                      PyTorchTrial)):
            check.true(
                util.is_overridden(self.trial.build_model, PyTorchTrial)
                and util.is_overridden(self.trial.optimizer, PyTorchTrial),
                "Both build_model() and optimizer() must be defined "
                "if any of build_model(), optimizer(), and create_lr_scheduler() are defined. "
                "If you want to use the new interface, you should instead instantiate your models, "
                "optimizers, and LR schedulers in __init__ and call context.backward(loss) "
                "and context.step_optimizer(optimizer) in train_batch.",
            )

            model = self.context._Model(self.trial.build_model())
            optim = self.context._Optimizer(self.trial.optimizer(model))

            lr_scheduler = self.trial.create_lr_scheduler(optim)
            if lr_scheduler is not None:
                self.context.lr_schedulers.append(lr_scheduler)

            if det.ExperimentConfig(self.context.get_experiment_config()
                                    ).mixed_precision_enabled():
                self.context._configure_apex_amp(
                    models=model,
                    optimizers=optim,
                    opt_level=self.context.get_experiment_config().get(
                        "optimizations", {}).get("mixed_precision", "O0"),
                )

            train_batch = self.trial.train_batch

            def new_train_batch(
                    batch: TorchData, model: nn.Module, epoch_idx: int,
                    batch_idx: int) -> Union[torch.Tensor, Dict[str, Any]]:
                tr_metrics = train_batch(batch, model, epoch_idx, batch_idx)
                if isinstance(tr_metrics, torch.Tensor):
                    tr_metrics = {"loss": tr_metrics}
                check.is_instance(
                    tr_metrics,
                    dict,
                    "train_batch() must return a dictionary "
                    f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
                )
                check.is_in("loss", tr_metrics.keys(),
                            'Please include "loss" in you training metrics.')

                def clip_grads(parameters: Iterator) -> None:
                    for callback in self.callbacks.values():
                        callback.on_before_optimizer_step(parameters)

                self.context._backward(tr_metrics["loss"])
                self.context._step_optimizer(self.context.optimizers[0],
                                             clip_grads=clip_grads)
                return tr_metrics

            self.trial.__setattr__("train_batch", new_train_batch)

        check.gt_eq(
            len(self.context.models),
            1,
            "Must have at least one model. "
            "This might be caused by not wrapping your model with Model()",
        )
        check.gt_eq(
            len(self.context.optimizers),
            1,
            "Must have at least one optimizer. "
            "This might be caused by not wrapping your optimizer with Optimizer()",
        )
        self._check_evaluate_implementation()

        # Validation loader will be undefined on process ranks > 0
        # when the user defines `evaluate_full_dataset()`.
        self.validation_loader = None  # type: Optional[torch.utils.data.DataLoader]
        self._set_data_loaders()

        # If a load path is provided load weights and restore the data location.
        self._load()

        if self.hvd_config.use:
            hvd.broadcast_parameters(self.context._main_model.state_dict(),
                                     root_rank=0)
            for optimizer in self.context.optimizers:
                hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        self.training_iterator = iter(self.training_loader)
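
The check message above describes the new interface. A hedged sketch of a trial written against it, using the public wrap_model/wrap_optimizer/backward/step_optimizer calls seen elsewhere on this page; the network, optimizer, and learning rate are illustrative.

class NewStyleTrial(PyTorchTrial):  # hypothetical trial
    def __init__(self, context: PyTorchTrialContext) -> None:
        self.context = context
        self.model = context.wrap_model(torch.nn.Linear(10, 1))
        self.optimizer = context.wrap_optimizer(
            torch.optim.SGD(self.model.parameters(), lr=0.01))

    def train_batch(self, batch, epoch_idx, batch_idx):
        data, labels = batch
        loss = torch.nn.functional.mse_loss(self.model(data), labels)
        self.context.backward(loss)
        self.context.step_optimizer(self.optimizer)
        return {"loss": loss}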
    def _load(self, load_path: pathlib.Path) -> None:
        # Backwards compat with older checkpoint formats. List is newest to
        # oldest known state_dict locations.
        potential_paths = [
            ["state_dict.pth"],
            ["determined", "state_dict.pth"],
            ["pedl", "state_dict.pth"],
            ["checkpoint.pt"],
        ]

        checkpoint: Optional[Dict[str, Any]] = None
        for ckpt_path in potential_paths:
            maybe_ckpt = load_path.joinpath(*ckpt_path)
            if maybe_ckpt.exists():
                checkpoint = torch.load(str(maybe_ckpt),
                                        map_location="cpu")  # type: ignore
                break
        if checkpoint is None or not isinstance(checkpoint, dict):
            return

        for callback in self.callbacks.values():
            callback.on_checkpoint_load_start(checkpoint)

        if "model_state_dict" in checkpoint:
            # Backward compatible with older checkpoint format.
            check.not_in("models_state_dict", checkpoint)
            check.eq(len(self.context.models), 1)
            self.context.models[0].load_state_dict(
                checkpoint["model_state_dict"])
        else:
            for idx, model in enumerate(self.context.models):
                model_state_dict = checkpoint["models_state_dict"][idx]
                try:
                    model.load_state_dict(model_state_dict)
                except Exception:
                    # If the checkpointed model is non-DDP and the current model is DDP, append
                    # module prefix to the checkpointed data
                    if isinstance(model,
                                  torch.nn.parallel.DistributedDataParallel):
                        logging.debug(
                            "Loading non-DDP checkpoint into a DDP model")
                        self._add_prefix_in_state_dict_if_not_present(
                            model_state_dict, "module.")
                    else:
                        # If the checkpointed model is DDP and we are currently running in
                        # single-slot mode, remove the module prefix from checkpointed data
                        logging.debug(
                            "Loading DDP checkpoint into a non-DDP model")
                        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
                            model_state_dict, "module.")
                    model.load_state_dict(model_state_dict)

        if "optimizer_state_dict" in checkpoint:
            # Backward compatible with older checkpoint format.
            check.not_in("optimizers_state_dict", checkpoint)
            check.eq(len(self.context.optimizers), 1)
            self.context.optimizers[0].load_state_dict(
                checkpoint["optimizer_state_dict"])
        else:
            for idx, optimizer in enumerate(self.context.optimizers):
                optimizer.load_state_dict(
                    checkpoint["optimizers_state_dict"][idx])

        if "lr_scheduler" in checkpoint:
            # Backward compatible with older checkpoint format.
            check.not_in("lr_schedulers_state_dict", checkpoint)
            check.eq(len(self.context.lr_schedulers), 1)
            self.context.lr_schedulers[0].load_state_dict(
                checkpoint["lr_scheduler"])
        else:
            for idx, lr_scheduler in enumerate(self.context.lr_schedulers):
                lr_scheduler.load_state_dict(
                    checkpoint["lr_schedulers_state_dict"][idx])

        if "scaler_state_dict" in checkpoint:
            if self.context._scaler:
                self.context._scaler.load_state_dict(
                    checkpoint["scaler_state_dict"])
            else:
                logging.warning(
                    "The checkpoint contains scaler_state_dict, but the experiment is not "
                    "using AMP.")
        else:
            if self.context._scaler:
                logging.warning(
                    "The experiment is using AMP, but the checkpoint contains no "
                    "scaler_state_dict.")

        if "amp_state" in checkpoint:
            if self.context._use_apex:
                apex.amp.load_state_dict(checkpoint["amp_state"])
            else:
                logging.warning(
                    "There exists amp_state in checkpoint but the experiment is not using Apex."
                )
        else:
            if self.context._use_apex:
                logging.warning(
                    "The experiment is using Apex but amp_state does not exist in the checkpoint."
                )

        if "rng_state" in checkpoint:
            rng_state = checkpoint["rng_state"]
            np.random.set_state(rng_state["np_rng_state"])
            random.setstate(rng_state["random_rng_state"])
            torch.random.set_rng_state(rng_state["cpu_rng_state"])

            if torch.cuda.device_count():
                if "gpu_rng_state" in rng_state:
                    torch.cuda.set_rng_state(
                        rng_state["gpu_rng_state"],
                        device=self.context.distributed.local_rank)
                else:
                    logging.warning(
                        "The system has a GPU, but no gpu_rng_state exists in the checkpoint."
                    )
            else:
                if "gpu_rng_state" in rng_state:
                    logging.warning(
                        "The checkpoint contains gpu_rng_state, but the system has no GPU."
                    )
        else:
            logging.warning("The checkpoint has no random state to restore.")

        callback_state = checkpoint.get("callbacks", {})
        for name in self.callbacks:
            if name in callback_state:
                self.callbacks[name].load_state_dict(callback_state[name])
            elif util.is_overridden(self.callbacks[name].load_state_dict,
                                    pytorch.PyTorchCallback):
                logging.warning(
                    "Callback '{}' implements load_state_dict(), but no callback state "
                    "was found for that name when restoring from checkpoint. This "
                    "callback will be initialized from scratch.".format(name))

        # Load workload sequencer state.
        wlsq_path = load_path.joinpath("workload_sequencer.pkl")
        if self.wlsq is not None and wlsq_path.exists():
            with wlsq_path.open("rb") as f:
                self.wlsq.load_state(pickle.load(f))
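
The DDP branch above calls _add_prefix_in_state_dict_if_not_present(). A hedged sketch of what such a helper might do, as the in-place inverse of torch's consume_prefix_in_state_dict_if_present(); this is an illustration, not the class's actual implementation.

from typing import Any, Dict

def _add_prefix_in_state_dict_if_not_present(
        state_dict: Dict[str, Any], prefix: str) -> None:
    # Rewrite keys in place so that every key carries the given prefix.
    for key in list(state_dict.keys()):
        if not key.startswith(prefix):
            state_dict[prefix + key] = state_dict.pop(key)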
    def _compute_validation_metrics(self) -> workload.Response:
        self.context.reset_reducers()
        # Set the behavior of certain layers (e.g., dropout) that are
        # different between training and inference.
        for model in self.context.models:
            model.eval()

        step_start_time = time.time()

        for callback in self.callbacks.values():
            if util.is_overridden(callback.on_validation_step_start,
                                  pytorch.PyTorchCallback):
                logging.warning("on_validation_step_start is now deprecated, "
                                "please use on_validation_start instead")
                callback.on_validation_step_start()

        for callback in self.callbacks.values():
            callback.on_validation_start()

        num_inputs = 0
        metrics = {}  # type: Dict[str, Any]

        if self._evaluate_batch_defined():
            keys = None
            batch_metrics = []

            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            check.gt(len(self.validation_loader), 0)
            for callback in self.callbacks.values():
                callback.on_validation_epoch_start()
            for idx, batch in enumerate(self.validation_loader):
                if self.context.experimental._auto_to_device:
                    batch = self.context.to_device(batch)
                num_inputs += self.trial.get_batch_length(batch)

                if has_param(self.trial.evaluate_batch, "batch_idx", 2):
                    vld_metrics = self.trial.evaluate_batch(batch=batch,
                                                            batch_idx=idx)
                else:
                    vld_metrics = self.trial.evaluate_batch(
                        batch=batch)  # type: ignore
                # Verify validation metric names are the same across batches.
                if keys is None:
                    keys = vld_metrics.keys()
                else:
                    check.eq(
                        keys,
                        vld_metrics.keys(),
                        "Validation metric names must match across all batches of data.",
                    )
                check.is_instance(
                    vld_metrics,
                    dict,
                    "evaluate_batch() must return a "
                    "dictionary of string names to Tensor "
                    "metrics",
                )
                # TODO: for performance, convert metrics to CPU only at the end of validation.
                batch_metrics.append(
                    pytorch._convert_metrics_to_numpy(vld_metrics))
                if self.env.test_mode:
                    break

            for callback in self.callbacks.values():
                callback.on_validation_epoch_end(batch_metrics)

            metrics = pytorch._reduce_metrics(
                self.context.distributed,
                batch_metrics=batch_metrics,
                keys=keys,
                metrics_reducers=pytorch._prepare_metrics_reducers(
                    self.trial.evaluation_reducer(), keys=keys),
            )

            # Gather a list of per-worker (num_inputs, num_batches) tuples.
            input_counts = self.context.distributed.gather(
                (num_inputs, idx + 1))
            if self.context.distributed.rank == 0:
                assert input_counts is not None
                # Reshape and sum.
                num_inputs, num_batches = [sum(n) for n in zip(*input_counts)]

        else:
            check.true(self._evaluate_full_dataset_defined())
            self.validation_loader = cast(torch.utils.data.DataLoader,
                                          self.validation_loader)
            if self.is_chief:
                metrics = self.trial.evaluate_full_dataset(
                    data_loader=self.validation_loader)

                check.is_instance(
                    metrics, dict,
                    f"evaluate_full_dataset() must return a dictionary, got {type(metrics)}.")

                metrics = pytorch._convert_metrics_to_numpy(metrics)
                num_inputs = self.context.get_per_slot_batch_size() * len(
                    self.validation_loader)

        metrics.update(
            pytorch._convert_metrics_to_numpy(
                self.context.reduce_metrics(for_training=False)))

        if self.context.distributed.size > 1 and any(
                util.is_overridden(c.on_validation_end,
                                   pytorch.PyTorchCallback)
                or util.is_overridden(c.on_validation_step_end,
                                      pytorch.PyTorchCallback)
                for c in self.callbacks.values()):
            logging.debug(
                "Broadcasting metrics to all worker processes to execute a "
                "validation step end callback")
            metrics = hvd.broadcast_object(metrics, root_rank=0)

        for callback in self.callbacks.values():
            if util.is_overridden(callback.on_validation_step_end,
                                  pytorch.PyTorchCallback):
                logging.warning(
                    "on_validation_step_end is now deprecated, please use on_validation_end instead"
                )
                callback.on_validation_step_end(metrics)

        for callback in self.callbacks.values():
            callback.on_validation_end(metrics)

        if not self.is_chief:
            return {}

        # Skip reporting timings if evaluate_full_dataset() was defined.  This is far less common
        # than evaluate_batch() and we can't know how the user processed their validation data.
        if self._evaluate_batch_defined():
            step_duration = time.time() - step_start_time
            logging.info(
                det.util.make_timing_log("validated", step_duration,
                                         num_inputs, num_batches))

        return {"num_inputs": num_inputs, "validation_metrics": metrics}
Example #12
    def _apply_backwards_compatibility(self) -> None:
        # TODO(DET-3262): remove this backward compatibility of old interface.
        if (util.is_overridden(self.trial.build_model, PyTorchTrial)
                or util.is_overridden(self.trial.optimizer, PyTorchTrial)
                or util.is_overridden(self.trial.create_lr_scheduler,
                                      PyTorchTrial)):
            logging.warning(
                "build_model(), optimizer(), and create_lr_scheduler(), which belong to "
                "the old interface, are deprecated. Please see the following documentation "
                "of PyTorchTrial for the new interface \n"
                f"{PyTorchTrial.__doc__}")
            logging.warning(
                "The callback on_before_optimizer_step is deprecated. "
                "Please use context.step_optimizer to clip gradients.")
            check.true(
                util.is_overridden(self.trial.build_model, PyTorchTrial)
                and util.is_overridden(self.trial.optimizer, PyTorchTrial),
                "Both build_model() and optimizer() must be defined "
                "if any of build_model(), optimizer(), and create_lr_scheduler() are defined. "
                "If you want to use the new interface, you should instead instantiate your models, "
                "optimizers, and LR schedulers in __init__ and call context.backward(loss) "
                "and context.step_optimizer(optimizer) in train_batch.",
            )

            model = self.context.wrap_model(self.trial.build_model())
            optim = self.context.wrap_optimizer(self.trial.optimizer(model))

            lr_scheduler = self.trial.create_lr_scheduler(optim)
            if lr_scheduler is not None:
                opt = getattr(lr_scheduler._scheduler, "optimizer", None)
                if opt is not None:
                    check.is_in(
                        opt,
                        self.context.optimizers,
                        "Must use a wrapped optimizer that is passed in by the optimizer "
                        "argument of create_lr_scheduler",
                    )
                self.context.lr_schedulers.append(lr_scheduler)

            if det.ExperimentConfig(self.context.get_experiment_config()
                                    ).mixed_precision_enabled():
                logging.warning(
                    "The experiment configuration field optimization.mixed_precision is deprecated. "
                    "Please use configure_apex_amp in __init__ to configure apex amp. "
                    "See the following documentation of PyTorchTrial for the new interface \n"
                    f"{PyTorchTrial.__doc__}")
                self.context.configure_apex_amp(
                    models=model,
                    optimizers=optim,
                    opt_level=self.context.get_experiment_config().get(
                        "optimizations", {}).get("mixed_precision", "O0"),
                )

            # Backward compatibility: train_batch
            train_batch = cast(Callable, self.trial.train_batch)

            def new_train_batch(batch: pytorch.TorchData, epoch_idx: int,
                                batch_idx: int) -> Any:
                tr_metrics = train_batch(
                    batch=batch,
                    model=model,
                    epoch_idx=epoch_idx,
                    batch_idx=batch_idx,
                )
                if isinstance(tr_metrics, torch.Tensor):
                    tr_metrics = {"loss": tr_metrics}
                check.is_instance(
                    tr_metrics,
                    dict,
                    "train_batch() must return a dictionary "
                    f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
                )
                check.is_in("loss", tr_metrics.keys(),
                            'Please include "loss" in you training metrics.')

                def clip_grads(parameters: Iterator) -> None:
                    for callback in self.callbacks.values():
                        callback.on_before_optimizer_step(parameters)

                self.context.backward(tr_metrics["loss"])
                self.context.step_optimizer(self.context.optimizers[0],
                                            clip_grads=clip_grads)

                return tr_metrics

            self.trial.__setattr__("train_batch", new_train_batch)

            # Backward compatibility: evaluate_batch
            if self._evaluate_batch_defined():
                evaluate_batch = cast(Callable, self.trial.evaluate_batch)

                def new_evaluate_batch(batch: pytorch.TorchData) -> Any:
                    return evaluate_batch(model=model, batch=batch)

                self.trial.__setattr__("evaluate_batch", new_evaluate_batch)

            # Backward compatibility: evaluate_full_dataset
            if self._evaluate_full_dataset_defined():
                evaluate_full_dataset = cast(Callable,
                                             self.trial.evaluate_full_dataset)

                def new_evaluate_full_dataset(
                        data_loader: torch.utils.data.DataLoader) -> Any:
                    return evaluate_full_dataset(model=model,
                                                 data_loader=data_loader)

                self.trial.__setattr__("evaluate_full_dataset",
                                       new_evaluate_full_dataset)
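
A hedged sketch of the old-interface trial this shim adapts: build_model() and optimizer() construct the objects, and train_batch()/evaluate_batch() still take the model as an argument. Shapes and hyperparameters are illustrative.

class OldStyleTrial(PyTorchTrial):  # hypothetical trial
    def build_model(self) -> torch.nn.Module:
        return torch.nn.Linear(10, 1)

    def optimizer(self, model: torch.nn.Module) -> torch.optim.Optimizer:
        return torch.optim.SGD(model.parameters(), lr=0.01)

    def train_batch(self, batch, model, epoch_idx, batch_idx):
        data, labels = batch
        return torch.nn.functional.mse_loss(model(data), labels)

    def evaluate_batch(self, batch, model):
        data, labels = batch
        return {"loss": torch.nn.functional.mse_loss(model(data), labels)}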
Example #13
    def _load(self, load_path: pathlib.Path) -> None:
        # Right now we will load all checkpoint shards on each node regardless of which
        # checkpoints are needed.
        # TODO (Liam): revisit later to optimize sharded checkpoint loading.

        # Load stateful things tracked by Determined on all slots.
        ckpt_path = f"det_state_dict_rank{self.context.distributed.rank}.pth"
        maybe_ckpt = load_path.joinpath(ckpt_path)

        if not maybe_ckpt.exists():
            return

        checkpoint = torch.load(str(maybe_ckpt),
                                map_location="cpu")  # type: ignore
        if not isinstance(checkpoint, dict):
            raise det.errors.InvalidExperimentException(
                f"Expected checkpoint at {maybe_ckpt} to be a dict "
                f"but got {type(checkpoint).__name__}.")

        for callback in self.callbacks.values():
            callback.on_checkpoint_load_start(checkpoint)

        # We allow users to override load behavior if needed but we default to using
        # the load method provided by DeepSpeed.
        self.trial.load(self.context, load_path)

        if "rng_state" in checkpoint:
            rng_state = checkpoint["rng_state"]
            np.random.set_state(rng_state["np_rng_state"])
            random.setstate(rng_state["random_rng_state"])
            torch.random.set_rng_state(rng_state["cpu_rng_state"])

            if torch.cuda.device_count():
                if "gpu_rng_state" in rng_state:
                    torch.cuda.set_rng_state(
                        rng_state["gpu_rng_state"],
                        device=self.context.distributed.get_local_rank())
                else:
                    logging.warning(
                        "The system has a GPU, but no gpu_rng_state exists in the checkpoint."
                    )
            else:
                if "gpu_rng_state" in rng_state:
                    logging.warning(
                        "The checkpoint contains gpu_rng_state, but the system has no GPU."
                    )
        else:
            logging.warning("The checkpoint has no random state to restore.")

        callback_state = checkpoint.get("callbacks", {})
        for name in self.callbacks:
            if name in callback_state:
                self.callbacks[name].load_state_dict(callback_state[name])
            elif util.is_overridden(self.callbacks[name].load_state_dict,
                                    pytorch.PyTorchCallback):
                logging.warning(
                    "Callback '{}' implements load_state_dict(), but no callback state "
                    "was found for that name when restoring from checkpoint. This "
                    "callback will be initialized from scratch.".format(name))

        # Load workload sequencer state.
        wlsq_path = load_path.joinpath("workload_sequencer.pkl")
        if self.wlsq is not None and wlsq_path.exists():
            with wlsq_path.open("rb") as f:
                self.wlsq.load_state(pickle.load(f))
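
The restore path above expects an rng_state dict with specific keys. A hedged sketch of how the matching state might be captured at save time; the helper name is hypothetical, and the key names mirror the restore code.

from typing import Any, Dict

def _capture_rng_state() -> Dict[str, Any]:  # hypothetical helper
    rng_state = {
        "np_rng_state": np.random.get_state(),
        "random_rng_state": random.getstate(),
        "cpu_rng_state": torch.random.get_rng_state(),
    }
    if torch.cuda.device_count():
        rng_state["gpu_rng_state"] = torch.cuda.get_rng_state()
    return rng_state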
Example #14
    def _compute_validation_metrics(self) -> workload.Response:
        self.context.reset_reducers()
        # Set the behavior of certain layers (e.g., dropout) that are
        # different between training and inference.
        for model in self.context.models:
            model.eval()

        step_start_time = time.time()

        for callback in self.callbacks.values():
            if util.is_overridden(callback.on_validation_step_start,
                                  pytorch.PyTorchCallback):
                logging.warning("on_validation_step_start is now deprecated, "
                                "please use on_validation_start instead")
                callback.on_validation_step_start()

        for callback in self.callbacks.values():
            callback.on_validation_start()

        num_inputs = 0
        keys = None
        batch_metrics = []

        for callback in self.callbacks.values():
            callback.on_validation_epoch_start()

        validation_iterator = iter(
            self.validation_loader) if self.validation_loader else None
        for idx in range(cast(int, self.num_validation_batches)):
            num_inputs += cast(int, self.validation_batch_size)
            # Note that when using pipeline parallelism, each call to evaluate_batch will request
            # self.context.num_micro_batches_per_slot batches from the validation iterator.
            # This is why we set self.num_validation_batches differently for pipeline parallel
            # and no pipeline parallel when building the data loaders.
            vld_metrics = self.trial.evaluate_batch(validation_iterator, idx)
            if self.context._mpu.should_report_metrics:
                if not isinstance(vld_metrics, dict):
                    raise det.errors.InvalidExperimentException(
                        "evaluate_batch must return a dictionary of string names "
                        "to Tensor metrics", )
                # Verify validation metric names are the same across batches.
                if keys is None:
                    keys = vld_metrics.keys()
                else:
                    if keys != vld_metrics.keys():
                        raise det.errors.InvalidExperimentException(
                            "Validation metric names must match across all batches of data.",
                        )
                # TODO: for performance, convert metrics to CPU only at the end of validation.
                batch_metrics.append(
                    pytorch._convert_metrics_to_numpy(vld_metrics))
            if self.env.test_mode:
                break

        # Use `keys if keys is None else list(keys)` rather than `keys and list(keys)`:
        # the latter would return the raw dict_keys object when keys is empty, which
        # breaks zmq_broadcast since it does not know how to serialize dict_keys.
        all_keys = self.context.distributed.gather(
            keys if keys is None else list(keys))
        if self.is_chief:
            all_keys = [k for k in all_keys if k is not None]
            keys = all_keys[0]
        keys = self.context.distributed.broadcast(keys)

        for callback in self.callbacks.values():
            callback.on_validation_epoch_end(batch_metrics)

        metrics = pytorch._reduce_metrics(
            self.context.distributed,
            batch_metrics=batch_metrics,
            keys=keys,
            metrics_reducers=pytorch._prepare_metrics_reducers(
                pytorch.Reducer.AVG, keys=keys),
        )
        metrics.update(
            pytorch._convert_metrics_to_numpy(
                self.context.reduce_metrics(for_training=False)))

        if self.context.distributed.size > 1 and any(
                util.is_overridden(c.on_validation_end,
                                   pytorch.PyTorchCallback)
                or util.is_overridden(c.on_validation_step_end,
                                      pytorch.PyTorchCallback)
                for c in self.callbacks.values()):
            logging.debug(
                "Broadcasting metrics to all worker processes to execute a "
                "validation step end callback")
            metrics = self.context.distributed.broadcast(metrics)

        for callback in self.callbacks.values():
            if util.is_overridden(callback.on_validation_step_end,
                                  pytorch.PyTorchCallback):
                logging.warning(
                    "on_validation_step_end is now deprecated, please use on_validation_end instead"
                )
                callback.on_validation_step_end(metrics)

        for callback in self.callbacks.values():
            callback.on_validation_end(metrics)

        if not self.is_chief:
            return {}

        num_inputs *= self.context._mpu.data_parallel_world_size
        step_duration = time.time() - step_start_time
        logging.info(
            det.util.make_timing_log("validated", step_duration, num_inputs,
                                     cast(int, self.num_validation_batches)))

        self.metric_writer.on_validation_step_end(self.steps_completed,
                                                  metrics)
        return {"num_inputs": num_inputs, "validation_metrics": metrics}