Example #1
    def __init__(self, trial_inst: det.Trial, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        check.is_instance(trial_inst, PyTorchTrial, "PyTorchTrialController needs a PyTorchTrial")
        self.trial = cast(PyTorchTrial, trial_inst)
        self._check_evaluate_implementation()

        self._init_model_and_optimizer()

        # Validation loader will be undefined on process ranks > 0
        # when the user defines `validate_full_dataset()`.
        self.validation_loader = None  # type: Optional[torch.utils.data.DataLoader]
        self._set_data_loaders()

        # Track whether each warning category has already been logged to the user.
        self.warning_logged = {_WarningLogs.FAILED_MOVING_TO_DEVICE: False}

        self.context.lr_scheduler = self.trial.create_lr_scheduler(self.context.optimizer)

        self.callbacks = self.trial.build_callbacks()

        # If a load path is provided, load weights and restore the data location.
        self._load()
        self._configure_amp()

        if self.hvd_config.use:
            hvd.broadcast_parameters(self.context.model.state_dict(), root_rank=0)
            hvd.broadcast_optimizer_state(self.context.optimizer, root_rank=0)

        self.training_iterator = iter(self.training_loader)
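
For orientation, the controllers in Examples #1 and #2 only reach into the trial through a handful of methods: build_model(), optimizer(), create_lr_scheduler(), build_callbacks(), and train_batch(batch, model, epoch_idx, batch_idx). The sketch below is a hedged, non-authoritative reconstruction of that surface; the class name, layer sizes, and method bodies are placeholders, and a real trial would subclass PyTorchTrial.

# Hypothetical sketch of the trial-side methods the controllers above call.
# Bodies are placeholders, not part of the original examples.
from typing import Any, Dict, Optional

import torch
import torch.nn as nn


class SketchTrial:  # a real trial would subclass PyTorchTrial
    def build_model(self) -> nn.Module:
        return nn.Linear(10, 1)

    def optimizer(self, model: nn.Module) -> torch.optim.Optimizer:
        return torch.optim.SGD(model.parameters(), lr=0.01)

    def create_lr_scheduler(self, optimizer: torch.optim.Optimizer) -> Optional[Any]:
        # Returning None is allowed; the controller then skips LR scheduling.
        return None

    def build_callbacks(self) -> Dict[str, Any]:
        # The controller iterates .values() and invokes hooks such as
        # on_trial_startup(), on_before_optimizer_step(), and on_trial_shutdown().
        return {}

    def train_batch(
        self, batch: Any, model: nn.Module, epoch_idx: int, batch_idx: int
    ) -> Dict[str, torch.Tensor]:
        data, labels = batch
        loss = nn.functional.mse_loss(model(data), labels)
        # The wrapper in Example #5 requires a "loss" entry in the metrics.
        return {"loss": loss}
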
Example #2
    def _init_model(self) -> None:
        self.optimizer = self.trial.optimizer(self.model)
        # TODO: Check that optimizer is not an amp optimizer.

        self._init_device()
        self.model = self.model.to(self.device)

        if self.hvd_config.use:
            use_compression = self.hvd_config.fp16_compression
            self.optimizer = hvd.DistributedOptimizer(
                self.optimizer,
                named_parameters=self.model.named_parameters(),
                backward_passes_per_step=self.hvd_config.aggregation_frequency,
                compression=hvd.Compression.fp16
                if use_compression else hvd.Compression.none,
            )
            logging.debug(
                "Initialized optimizer for distributed and optimized parallel training."
            )
        elif self.n_gpus > 1:
            check.eq(
                self.hvd_config.aggregation_frequency,
                1,
                "Please enable `optimized_parallel` to use aggregation "
                "frequency greater than 1 for single machine multi-GPU "
                "training.",
            )
            self.model = nn.DataParallel(self.model)
            logging.debug("Initialized mode for native parallel training.")

        self.lr_helper = _LRHelper(
            self.trial.create_lr_scheduler(self.optimizer))

        # If a load path is provided, load weights and restore the data location.
        self._load()

        self._configure_amp()

        if self.hvd_config.use:
            hvd.broadcast_parameters(self.model.state_dict(), root_rank=0)
            hvd.broadcast_optimizer_state(self.optimizer, root_rank=0)

        # Initialize training and validation iterators.
        self.training_iterator = iter(self.training_loader)
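
The Horovod calls that Example #2 composes can also be exercised on their own. The following is a minimal sketch, assuming Horovod with PyTorch support is installed and the script is launched under horovodrun; the model, optimizer, and learning rate are placeholders.

# Minimal sketch of the Horovod pattern used above: wrap the optimizer,
# then broadcast weights and optimizer state from rank 0.
import horovod.torch as hvd
import torch
import torch.nn as nn

hvd.init()

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Gradient averaging across workers; hvd.Compression.fp16 would mirror the
# fp16_compression switch in the example.
optimizer = hvd.DistributedOptimizer(
    optimizer,
    named_parameters=model.named_parameters(),
    compression=hvd.Compression.none,
)

# Keep every worker in sync with rank 0, as the controller does after _load().
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
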
Example #3
    def __init__(self, trial_inst: det.Trial, *args: Any,
                 **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        check.is_instance(trial_inst, PyTorchTrial,
                          "PyTorchTrialController needs an PyTorchTrial")
        self.trial = cast(PyTorchTrial, trial_inst)
        self.context = cast(pytorch.PyTorchTrialContext, self.context)
        self.context.experimental._set_allgather_fn(self.allgather_metrics)
        self.callbacks = self.trial.build_callbacks()

        self._apply_backwards_compatibility()

        check.gt_eq(
            len(self.context.models),
            1,
            "Must have at least one model. "
            "This might be caused by not wrapping your model with wrap_model()",
        )
        check.gt_eq(
            len(self.context.optimizers),
            1,
            "Must have at least one optimizer. "
            "This might be caused by not wrapping your optimizer with wrap_optimizer()",
        )
        self._check_evaluate_implementation()

        # Validation loader will be undefined on process ranks > 0
        # when the user defines `validate_full_dataset()`.
        self.validation_loader = None  # type: Optional[torch.utils.data.DataLoader]
        self._set_data_loaders()

        # We don't want the training_iterator shuffling values after we load state
        self.training_iterator = iter(self.training_loader)

        # If a load path is provided, load weights and restore the data location.
        self._load()

        if self.hvd_config.use:
            hvd.broadcast_parameters(self.context._main_model.state_dict(),
                                     root_rank=0)
            for optimizer in self.context.optimizers:
                hvd.broadcast_optimizer_state(optimizer, root_rank=0)
    def run(self) -> None:
        # We create the training_iterator here rather than in __init__ because we have to be careful
        # to trigger its shutdown explicitly, to avoid hangs when the user is using
        # multiprocessing-based parallelism for their dataloader.
        #
        # We create it before loading state because we don't want the training_iterator shuffling
        # values after we load state.
        self.training_iterator = iter(self.training_loader)
        try:
            self._load()

            if self.hvd_config.use:
                hvd.broadcast_parameters(self.context._main_model.state_dict(), root_rank=0)
                for optimizer in self.context.optimizers:
                    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

            with self.prof:
                self._run()

        finally:
            # Explicitly trigger the training iterator's shutdown (which happens in __del__).
            # See the rather long note in pytorch/torch/utils/data/dataloader.py.
            del self.training_iterator
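
The comments in run() above explain why the training iterator is created and destroyed explicitly when the DataLoader uses multiprocessing workers. The same pattern, shown in isolation with a placeholder dataset and an illustrative worker count:

# Sketch of the explicit shutdown pattern for a multiprocessing DataLoader:
# deleting the iterator in a finally block triggers its __del__, which shuts
# down the worker processes even if training raises.
import torch
from torch.utils.data import DataLoader, TensorDataset


def main() -> None:
    dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
    loader = DataLoader(dataset, batch_size=8, num_workers=2)

    iterator = iter(loader)
    try:
        batch = next(iterator)  # training work would happen here
        print(batch[0].shape)
    finally:
        # Mirrors `del self.training_iterator` in the example above.
        del iterator


if __name__ == "__main__":
    main()
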
Example #5
    def __init__(self, trial_inst: det.Trial, *args: Any,
                 **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        check.is_instance(trial_inst, PyTorchTrial,
                          "PyTorchTrialController needs an PyTorchTrial")
        self.trial = cast(PyTorchTrial, trial_inst)
        self.context = cast(PyTorchTrialContext, self.context)
        self.callbacks = self.trial.build_callbacks()

        # TODO(DET-3262): remove this backward compatibility for the old interface.
        if (util.is_overridden(self.trial.build_model, PyTorchTrial)
                or util.is_overridden(self.trial.optimizer, PyTorchTrial)
                or util.is_overridden(self.trial.create_lr_scheduler,
                                      PyTorchTrial)):
            check.true(
                util.is_overridden(self.trial.build_model, PyTorchTrial)
                and util.is_overridden(self.trial.optimizer, PyTorchTrial),
                "Both build_model() and optimizer() must be defined "
                "if any of build_model(), optimizer(), and create_lr_scheduler() are defined. "
                "If you want to use the new interface, you should instead instantiate your models, "
                "optimizers, and LR schedulers in __init__ and call context.backward(loss) "
                "and context.step_optimizer(optimizer) in train_batch.",
            )

            model = self.context._Model(self.trial.build_model())
            optim = self.context._Optimizer(self.trial.optimizer(model))

            lr_scheduler = self.trial.create_lr_scheduler(optim)
            if lr_scheduler is not None:
                self.context.lr_schedulers.append(lr_scheduler)

            if det.ExperimentConfig(self.context.get_experiment_config()
                                    ).mixed_precision_enabled():
                self.context._configure_apex_amp(
                    models=model,
                    optimizers=optim,
                    opt_level=self.context.get_experiment_config().get(
                        "optimizations", {}).get("mixed_precision", "O0"),
                )

            train_batch = self.trial.train_batch

            def new_train_batch(
                    batch: TorchData, model: nn.Module, epoch_idx: int,
                    batch_idx: int) -> Union[torch.Tensor, Dict[str, Any]]:
                tr_metrics = train_batch(batch, model, epoch_idx, batch_idx)
                if isinstance(tr_metrics, torch.Tensor):
                    tr_metrics = {"loss": tr_metrics}
                check.is_instance(
                    tr_metrics,
                    dict,
                    "train_batch() must return a dictionary "
                    f"mapping string names to Tensor metrics, got {type(tr_metrics)}",
                )
                check.is_in("loss", tr_metrics.keys(),
                            'Please include "loss" in you training metrics.')

                def clip_grads(parameters: Iterator) -> None:
                    for callback in self.callbacks.values():
                        callback.on_before_optimizer_step(parameters)

                self.context._backward(tr_metrics["loss"])
                self.context._step_optimizer(self.context.optimizers[0],
                                             clip_grads=clip_grads)
                return tr_metrics

            self.trial.__setattr__("train_batch", new_train_batch)

        check.gt_eq(
            len(self.context.models),
            1,
            "Must have at least one model. "
            "This might be caused by not wrapping your model with Model()",
        )
        check.gt_eq(
            len(self.context.optimizers),
            1,
            "Must have at least one optimizer. "
            "This might be caused by not wrapping your model with Optimizer()",
        )
        self._check_evaluate_implementation()

        # Validation loader will be undefined on process ranks > 0
        # when the user defines `validate_full_dataset()`.
        self.validation_loader = None  # type: Optional[torch.utils.data.DataLoader]
        self._set_data_loaders()

        # If a load path is provided, load weights and restore the data location.
        self._load()

        if self.hvd_config.use:
            hvd.broadcast_parameters(self.context._main_model.state_dict(),
                                     root_rank=0)
            for optimizer in self.context.optimizers:
                hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        self.training_iterator = iter(self.training_loader)
    def run(self) -> None:
        @contextlib.contextmanager
        def defer(fn: Callable, *args: Any) -> Iterator[None]:
            try:
                yield
            finally:
                fn(*args)

        # We define on_shutdown here instead of inside the `for callback in...` loop to ensure we
        # don't bind the loop iteration variable `callback`, which would likely cause us to call
        # on_trial_shutdown() multiple times for the final callback, and not at all for the others.
        def on_shutdown(callback_name: str,
                        on_trial_shutdown: Callable) -> None:
            with self.prof.record_timing(
                    f"callbacks.{callback_name}.on_trial_shutdown"):
                on_trial_shutdown()

        with contextlib.ExitStack() as exit_stack:
            for callback in self.callbacks.values():
                with self.prof.record_timing(
                        f"callbacks.{callback.__class__.__name__}.on_trial_startup"
                ):
                    callback.on_trial_startup(self.steps_completed,
                                              self.env.latest_checkpoint)
                exit_stack.enter_context(
                    defer(on_shutdown, callback.__class__.__name__,
                          callback.on_trial_shutdown))

            self._set_data_loaders()

            # We create the training_iterator here rather than in __init__ because we have to be
            # careful to trigger its shutdown explicitly, to avoid hangs when the user is using
            # multiprocessing-based parallelism for their dataloader.
            #
            # We create it before loading state because we don't want the training_iterator
            # shuffling values after we load state.
            self.training_iterator = iter(self.training_loader)

            def cleanup_iterator() -> None:
                # Explicitly trigger the training iterator's shutdown (which happens in __del__).
                # See the rather long note in pytorch/torch/utils/data/dataloader.py.
                del self.training_iterator

            exit_stack.enter_context(defer(cleanup_iterator))

            # If a load path is provided, load weights and restore the data location.
            if self.env.latest_checkpoint is not None:
                logging.info(
                    f"Restoring trial from checkpoint {self.env.latest_checkpoint}"
                )
                with self.context._core.checkpoint.restore_path(
                        self.env.latest_checkpoint) as load_path:
                    self._load(load_path)

            if self.context.distributed.size > 1 and self.use_horovod:
                hvd.broadcast_parameters(self.context._main_model.state_dict(),
                                         root_rank=0)
                for optimizer in self.context.optimizers:
                    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

            with self.prof:
                for callback in self.callbacks.values():
                    with self.prof.record_timing(
                            f"callbacks.{callback.__class__.__name__}.on_training_start"
                    ):
                        callback.on_training_start()
                self._run()
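
For comparison, the backward-compatibility check in Example #5 points users toward the newer interface: instantiate and wrap models and optimizers in __init__, then call context.backward(loss) and context.step_optimizer(optimizer) in train_batch. The sketch below follows that guidance; the class name, layer sizes, and learning rate are placeholders, and the data-loader and evaluation methods a complete trial also needs are omitted.

# Hypothetical new-interface trial following the guidance in the error
# message above: wrap objects in __init__, then drive backward() and
# step_optimizer() from train_batch().
from typing import Any, Dict

import torch
import torch.nn as nn

from determined import pytorch as det_pytorch


class NewInterfaceTrial(det_pytorch.PyTorchTrial):
    def __init__(self, context: det_pytorch.PyTorchTrialContext) -> None:
        self.context = context
        # wrap_model()/wrap_optimizer() populate context.models and
        # context.optimizers, which the controller's checks require to be
        # non-empty.
        self.model = context.wrap_model(nn.Linear(10, 1))
        self.optimizer = context.wrap_optimizer(
            torch.optim.SGD(self.model.parameters(), lr=0.01)
        )

    def train_batch(
        self, batch: Any, epoch_idx: int, batch_idx: int
    ) -> Dict[str, torch.Tensor]:
        data, labels = batch
        loss = nn.functional.mse_loss(self.model(data), labels)
        self.context.backward(loss)
        self.context.step_optimizer(self.optimizer)
        return {"loss": loss}
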