Example No. 1
    def __exit__(self, exc_type: Any, exc_val: Any,
                 exc_tb: Any):  # type: ignore[override]
        if torch._jit_internal.is_scripting():
            return

        # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
        if self.device == 'cpu':
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_cpu_enabled(self.prev)
            torch.set_autocast_cpu_dtype(self.prev_fastdtype)
        elif self.device == 'xpu':
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.xpu.set_autocast_xpu_enabled(
                self.prev)  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_dtype(
                self.prev_fastdtype)  # type: ignore[attr-defined]
        else:
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_enabled(self.prev)
            torch.set_autocast_gpu_dtype(self.prev_fastdtype)
        torch.set_autocast_cache_enabled(self.prev_cache_enabled)
        return False
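
Taken together with the matching __enter__, this __exit__ only drops the autocast cache once the outermost region is left, i.e. when torch.autocast_decrement_nesting() returns 0. A minimal usage sketch of that nesting behaviour, assuming a PyTorch build (1.10 or newer) with CPU autocast support:

import torch

a = torch.randn(8, 8)
b = torch.randn(8, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):      # nesting level 1
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):  # nesting level 2
        inner = a @ b   # runs in bfloat16; intermediate casts may be cached
    outer = a @ b       # still inside an autocast region, so the cache is kept
# Leaving the outer block brings the nesting counter back to 0,
# which is when the cache in the __exit__ above is actually cleared.
print(inner.dtype, outer.dtype)  # torch.bfloat16 torch.bfloat16
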
Example No. 2
    def on_batch_end(self, runner: IRunner) -> None:
        """On batch end event

        Args:
            runner: current runner
        """
        # Drop the cache when we exit to a nesting level
        # that's outside any instance of autocast.
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        torch.set_autocast_enabled(self.prev_autocast_state)

        if not runner.is_train_loader:
            return

        loss = runner.batch_metrics[self.metric_key]

        self._accumulation_counter += 1
        need_gradient_step = (self._accumulation_counter %
                              self.accumulation_steps == 0)

        self.scaler.scale(loss).backward()

        if need_gradient_step:
            self.grad_step(
                optimizer=self._optimizer,
                grad_clip_fn=self.grad_clip_fn,
            )

            utils.maybe_recursive_call(self._optimizer, "zero_grad")
            self._accumulation_counter = 0
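
For context, the pattern this callback implements (scaled backward on every batch, optimizer step only every accumulation_steps batches) looks roughly like the following standalone sketch; the model, data, and hyperparameters are illustrative, not part of the Catalyst API:

import torch
from torch import nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
# GradScaler is a no-op when disabled, so this also runs on CPU-only machines.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
accumulation_steps = 4

for step in range(8):
    x, y = torch.randn(16, 4), torch.randn(16, 1)
    loss = nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()               # gradients accumulate across batches
    if (step + 1) % accumulation_steps == 0:    # the "need_gradient_step" condition
        scaler.step(optimizer)                  # unscale, then optimizer.step()
        scaler.update()
        optimizer.zero_grad()                   # reset for the next accumulation window
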
Example No. 3
    def __enter__(self):
        if torch._jit_internal.is_scripting():
            assert self.fast_dtype is not None
            return self

        self.prev_cache_enabled = torch.is_autocast_cache_enabled()
        if self.device == 'cpu':
            self.prev = torch.is_autocast_cpu_enabled()
            self.prev_fastdtype = torch.get_autocast_cpu_dtype()
            torch.set_autocast_cpu_enabled(self._enabled)
            torch.set_autocast_cpu_dtype(
                self.fast_dtype)  # type: ignore[arg-type]
            torch.autocast_increment_nesting()
        elif self.device == 'xpu':
            self.prev = torch.xpu.is_autocast_xpu_enabled(
            )  # type: ignore[attr-defined]
            self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype(
            )  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_enabled(
                self._enabled)  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_dtype(
                self.fast_dtype)  # type: ignore[attr-defined]
            torch.autocast_increment_nesting()
        else:
            self.prev = torch.is_autocast_enabled()
            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
            torch.set_autocast_gpu_dtype(
                self.fast_dtype)  # type: ignore[arg-type]
            torch.set_autocast_enabled(self._enabled)
            torch.autocast_increment_nesting()
        torch.set_autocast_cache_enabled(self._cache_enabled)
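
The important detail is that __enter__ snapshots the previous enabled flag and fast dtype before overwriting them, so the matching __exit__ can restore whatever was in effect outside the region. A short sketch of that observable behaviour on CPU, assuming PyTorch 1.10+:

import torch

print(torch.is_autocast_cpu_enabled())        # False outside any region
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    print(torch.is_autocast_cpu_enabled())    # True inside
    print(torch.get_autocast_cpu_dtype())     # torch.bfloat16
print(torch.is_autocast_cpu_enabled())        # False again: the saved state was restored
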
Example No. 4
    def __exit__(self, *args):
        if self._enabled:
            # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_enabled(self.prev)
        return False
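
This is the older, CUDA-only autocast: __exit__ only restores state when the context was created with enabled=True. The enabled flag is what lets you locally switch autocast off inside a larger autocast region, e.g. for numerically sensitive ops. A sketch of that idea with the current torch.autocast API (CPU used here so it runs anywhere):

import torch

a = torch.randn(8, 8)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    low = a @ a                                             # bfloat16 under autocast
    with torch.autocast(device_type="cpu", enabled=False):  # locally opt out
        full = a @ a                                        # plain float32
print(low.dtype, full.dtype)  # torch.bfloat16 torch.float32
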
Example No. 5
    def on_batch_start(self, runner: IRunner) -> None:
        """On batch start event

        Args:
            runner: current runner
        """
        self.prev_autocast_state = torch.is_autocast_enabled()
        torch.set_autocast_enabled(True)
        torch.autocast_increment_nesting()
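
This is essentially the body of autocast's __enter__ inlined into a callback: save the global flag, enable autocast, bump the nesting counter, and rely on on_batch_end (Example No. 2) to undo all three. A standalone illustration of the same low-level calls around a forward pass (the matmul only actually runs in reduced precision on a CUDA tensor):

import torch

x = torch.randn(4, 4)

prev = torch.is_autocast_enabled()
torch.set_autocast_enabled(True)
torch.autocast_increment_nesting()
try:
    y = x @ x  # would run in the autocast GPU dtype if x lived on a CUDA device
finally:
    # Mirror of on_batch_end: drop the cache only at nesting level 0, then restore.
    if torch.autocast_decrement_nesting() == 0:
        torch.clear_autocast_cache()
    torch.set_autocast_enabled(prev)
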
Example No. 6
    def __exit__(self, *args):
        # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
        if self.device == 'cpu':
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_cpu_enabled(self.prev)
            torch.set_autocast_cpu_dtype(self.prev_fastdtype)
        else:
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_enabled(self.prev)
            torch.set_autocast_gpu_dtype(self.prev_fastdtype)
        return False
Example No. 7
    def __enter__(self):
        if self.device == 'cpu':
            self.prev = torch.is_autocast_cpu_enabled()
            self.prev_fastdtype = torch.get_autocast_cpu_dtype()
            torch.set_autocast_cpu_enabled(self._enabled)
            torch.set_autocast_cpu_dtype(self.fast_dtype)
            torch.autocast_increment_nesting()
        else:
            self.prev = torch.is_autocast_enabled()
            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
            torch.set_autocast_gpu_dtype(self.fast_dtype)
            torch.set_autocast_enabled(self._enabled)
            torch.autocast_increment_nesting()
Example No. 8
    def on_batch_end(self, runner: "IRunner") -> None:
        """On batch end event

        Args:
            runner: current runner
        """
        if self.use_amp:
            # Drop the cache when we exit to a nesting level
            # that's outside any instance of autocast.
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_enabled(self.prev_autocast_state)

        if not runner.is_train_loader:
            return

        loss = runner.batch_metrics[self.metric_key]

        self._accumulation_counter += 1
        need_gradient_step = (self._accumulation_counter %
                              self.accumulation_steps == 0)

        # @TODO: speedup with re-definition ``on_stage_start``
        if self.use_apex:
            from apex import amp

            # Need to set ``delay_unscale``
            # according to
            # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
            delay_unscale = not need_gradient_step
            with amp.scale_loss(loss,
                                self._optimizer,
                                delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
        elif self.use_amp:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

        if need_gradient_step:
            self.grad_step(
                optimizer=self._optimizer,
                grad_clip_fn=self.grad_clip_fn,
            )
            if not self.use_fast_zero_grad:
                maybe_recursive_call(self._optimizer, "zero_grad")
            else:
                maybe_recursive_call(self._optimizer, zero_grad)
            self._accumulation_counter = 0
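
The use_fast_zero_grad branch passes a zero_grad helper (in Catalyst this presumably sets each parameter's .grad to None instead of refilling the existing tensors with zeros). Plain PyTorch exposes the same trade-off via set_to_none, as in this small sketch (PyTorch 1.7+):

import torch
from torch import nn

model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

model(torch.randn(2, 4)).sum().backward()

opt.zero_grad(set_to_none=False)  # keep the grad tensors, fill them with zeros
opt.zero_grad(set_to_none=True)   # drop the grad tensors entirely (cheaper; the default in newer versions)
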
Example No. 9
    def __enter__(self):
        self.prev = torch.is_autocast_enabled()
        torch.set_autocast_enabled(self._enabled)
        torch.autocast_increment_nesting()
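
This last fragment is the __enter__ half that pairs with the __exit__ in Example No. 4. Put together they amount to a tiny context manager; an illustrative sketch combining them (not the actual torch.cuda.amp.autocast implementation):

import torch

class _SimpleAutocast:
    """Minimal sketch of the __enter__/__exit__ pair shown above (illustrative only)."""

    def __init__(self, enabled: bool = True):
        self._enabled = enabled

    def __enter__(self):
        self.prev = torch.is_autocast_enabled()
        torch.set_autocast_enabled(self._enabled)
        torch.autocast_increment_nesting()
        return self

    def __exit__(self, *args):
        # Drop the cache only when the outermost autocast region is exited.
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        torch.set_autocast_enabled(self.prev)
        return False

# Usage: behaves like a bare-bones torch.cuda.amp.autocast for CUDA tensors.
with _SimpleAutocast():
    pass
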