Example #1
    def on_batch_end(self, runner: IRunner) -> None:
        """On batch end event

        Args:
            runner: current runner
        """
        # Drop the cache when we exit to a nesting level
        # that's outside any instance of autocast.
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        torch.set_autocast_enabled(self.prev_autocast_state)

        if not runner.is_train_loader:
            return

        loss = runner.batch_metrics[self.metric_key]

        self._accumulation_counter += 1
        need_gradient_step = (self._accumulation_counter %
                              self.accumulation_steps == 0)

        self.scaler.scale(loss).backward()

        if need_gradient_step:
            self.grad_step(
                optimizer=self._optimizer,
                grad_clip_fn=self.grad_clip_fn,
            )

            utils.maybe_recursive_call(self._optimizer, "zero_grad")
            self._accumulation_counter = 0
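
The callback above follows the standard ``torch.cuda.amp.GradScaler`` recipe for gradient accumulation. A minimal standalone sketch of that pattern (the toy model, data shapes and ``accumulation_steps`` value are placeholders, not taken from Catalyst):

import torch
from torch import nn

model = nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = nn.MSELoss()
scaler = torch.cuda.amp.GradScaler()
accumulation_steps = 4

for step in range(1, 17):
    inputs = torch.randn(8, 10, device="cuda")
    targets = torch.randn(8, 1, device="cuda")
    with torch.cuda.amp.autocast():
        loss = criterion(model(inputs), targets)
    # Scale the loss so small fp16 gradients are not flushed to zero,
    # then let gradients accumulate across batches.
    scaler.scale(loss).backward()
    if step % accumulation_steps == 0:
        scaler.step(optimizer)  # unscales gradients, then applies the update
        scaler.update()         # adapts the loss scale for the next iteration
        optimizer.zero_grad()
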
Example #2
    def _run_epoch(self, stage: str, epoch: int) -> None:
        """
        Inner method to run an epoch on the Runner,
        with epoch callback events.

        Args:
            stage (str): stage name of interest,
                like "pretrain" / "train" / "finetune" / etc
            epoch (int): epoch index
        """
        self._prepare_for_epoch(stage=stage, epoch=epoch)
        assert self.loaders is not None

        for loader_name, loader in self.loaders.items():
            if len(loader) == 0:
                raise RunnerException(
                    f"DataLoader with name {loader_name} is empty.")

        # @TODO: better solution with train/inference handling?
        self.is_infer_stage = self.stage_name.startswith("infer")
        if not self.is_infer_stage:
            assert self.valid_loader in self.loaders.keys(), (
                f"'{self.valid_loader}' "
                f"should be in provided loaders: {list(self.loaders.keys())}")
        else:
            # @TODO: add check for non distributed run for inference
            assert not any(
                x.startswith(settings.loader_train_prefix)
                for x in self.loaders.keys()
            ), "for inference no train loader should be passed"

        for loader_name, loader in self.loaders.items():
            self.loader_name = loader_name
            self.loader_len = len(loader)
            self.is_train_loader = loader_name.startswith(
                settings.loader_train_prefix)
            self.is_valid_loader = loader_name.startswith(
                settings.loader_valid_prefix)
            self.is_infer_loader = loader_name.startswith(
                settings.loader_infer_prefix)
            utils.maybe_recursive_call(
                self.model,
                "train",
                mode=self.is_train_loader,
            )

            if (isinstance(loader.sampler, DistributedSampler)
                    and not self.is_infer_stage):
                loader.sampler.set_epoch(self.epoch)

            utils.set_global_seed(self.experiment.initial_seed +
                                  self.global_epoch + 1)
            self._run_event("on_loader_start")
            with torch.set_grad_enabled(self.is_train_loader):
                self._run_loader(loader)
            self._run_event("on_loader_end")
Example #3
    def on_batch_end(self, state: State) -> None:
        """On batch end event

        Args:
            state (State): current state
        """
        if not state.is_train_loader:
            return

        loss = state.batch_metrics[self.loss_key]

        self._accumulation_counter += 1
        need_gradient_step = (self._accumulation_counter %
                              self.accumulation_steps == 0)

        # This is a very hacky check for whether we have an AMP optimizer,
        # and it may change in the future.
        # The alternative would be a dedicated AmpOptimizerCallback
        # or another constructor argument.
        if hasattr(self._optimizer, "_amp_stash"):
            from apex import amp

            # Need to set ``delay_unscale``
            # according to
            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
            delay_unscale = not need_gradient_step
            with amp.scale_loss(loss,
                                self._optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if need_gradient_step:
            self.grad_step(
                optimizer=self._optimizer,
                optimizer_wds=self._optimizer_wd,
                grad_clip_fn=self.grad_clip_fn,
            )

            # if self.save_model_grads:
            #     for tag, value in model.named_parameters():
            #         tag = tag.replace(".", "/")
            #         state.model_grads[tag] = value.grad.cpu().numpy()

            utils.maybe_recursive_call(self._optimizer, "zero_grad")

            self._accumulation_counter = 0
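
With apex, the unscale step has to be delayed on every micro-batch except the one that actually calls ``optimizer.step()``, which is what the ``delay_unscale`` flag above controls. A minimal sketch of that pattern, assuming ``model`` and ``optimizer`` were already wrapped with ``amp.initialize``:

from apex import amp

def backward_step(loss, optimizer, need_gradient_step: bool) -> None:
    # Unscale only on the iteration that will apply the update, per
    # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
    with amp.scale_loss(loss, optimizer,
                        delay_unscale=not need_gradient_step) as scaled_loss:
        scaled_loss.backward()
    if need_gradient_step:
        optimizer.step()
        optimizer.zero_grad()
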
Example #4
    def model(self, value: Union[Model, Dict[str, Model]]):
        """Setter for the runner's model, useful for experiment tracing.

        Args:
            value (Union[Model, Dict[str, Model]]): new model.
        """
        if isinstance(value, nn.Module):
            model = value
        elif isinstance(value, dict):
            values_are_models = all(
                isinstance(v, nn.Module) for v in value.values())
            if not values_are_models:
                raise TypeError(
                    "Invalid dict value type, must be `torch.nn.Module`")

            model = value

        else:
            raise TypeError(
                f"Invalid value type, "
                f"must be `torch.nn.Module` or `Dict[str, torch.nn.Module]`, "
                f"got '{type(value)}'")

        if self._device is not None:
            model: Model = utils.maybe_recursive_call(model,
                                                      "to",
                                                      device=self._device)

        self._model = model
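
``utils.maybe_recursive_call`` lets the setter treat a single ``nn.Module`` and a dict of modules uniformly: it applies a method to the object itself or to every dict value. A simplified sketch of that behaviour (not the actual Catalyst implementation, which also accepts callables instead of method names):

from typing import Any, Dict, Union

def maybe_recursive_call(obj: Union[Any, Dict[str, Any]], method: str, **kwargs) -> Any:
    """Call ``obj.<method>(**kwargs)``, or do so for every value of a dict."""
    if isinstance(obj, dict):
        return {key: getattr(value, method)(**kwargs) for key, value in obj.items()}
    return getattr(obj, method)(**kwargs)

# e.g. move one model or a whole dict of models to a device:
# maybe_recursive_call(model, "to", device=torch.device("cuda:0"))
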
Example #5
    def device(self, value: Device):
        """
        Setter for the runner's device.

        Args:
            value (Device): new torch device.

        Raises:
            TypeError: if `value` is not a `torch.device`, `str` or `None`
        """
        if isinstance(value, torch.device):
            self._device = value
        elif isinstance(value, str):
            self._device = torch.device(value)
        elif isinstance(value, type(None)):
            self._device = None
        else:
            raise TypeError(f"Invalid value type "
                            f"must be `str` or `torch.device` "
                            f"got '{type(value)}'")

        if self._model is not None:
            self._model = utils.maybe_recursive_call(self._model,
                                                     "to",
                                                     device=self._device
                                                     or "cpu")
Example #6
    def on_batch_end(self, runner: IRunner) -> None:
        """On batch end event

        Args:
            runner: current runner
        """
        if not runner.is_train_loader:
            return

        loss = runner.batch_metrics[self.metric_key]

        self._accumulation_counter += 1
        need_gradient_step = (self._accumulation_counter %
                              self.accumulation_steps == 0)

        # This is a very hacky check for whether we have an AMP optimizer,
        # and it may change in the future.
        # The alternative would be a dedicated AmpOptimizerCallback
        # or another constructor argument.
        # @TODO: speedup with re-definition of ``on_stage_start``
        if hasattr(self._optimizer, "_amp_stash"):
            from apex import amp

            # Need to set ``delay_unscale``
            # according to
            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
            delay_unscale = not need_gradient_step
            with amp.scale_loss(loss,
                                self._optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if need_gradient_step:
            self.grad_step(
                optimizer=self._optimizer,
                optimizer_wds=self._optimizer_wd,
                grad_clip_fn=self.grad_clip_fn,
            )
            if not self.use_fast_zero_grad:
                utils.maybe_recursive_call(self._optimizer, "zero_grad")
            else:
                utils.maybe_recursive_call(self._optimizer, zero_grad)
            self._accumulation_counter = 0
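
In the ``use_fast_zero_grad`` branch, ``zero_grad`` (passed without quotes) refers to a module-level helper rather than the optimizer method: it drops the gradient tensors instead of zero-filling them. A sketch of that idea, which modern PyTorch also exposes as ``optimizer.zero_grad(set_to_none=True)``:

from torch.optim import Optimizer

def zero_grad(optimizer: Optimizer) -> None:
    # Dropping the tensors skips a memset; the next backward pass
    # allocates fresh gradient buffers.
    for group in optimizer.param_groups:
        for p in group["params"]:
            p.grad = None
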
Example #7
    def device(self, value: Device):
        """Setter for the runner's device.

        Args:
            value (Device): new torch device.
        """
        if isinstance(value, (str, torch.device)):
            self._device = value
        else:
            raise TypeError(f"Invalid value type "
                            f"must be `str` or `torch.device` "
                            f"got '{type(value)}'")

        if self._model is not None:
            self._model = utils.maybe_recursive_call(self._model,
                                                     "to",
                                                     device=self._device)