def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
    if torch._jit_internal.is_scripting():
        return
    # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
    if self.device == 'cpu':
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        torch.set_autocast_cpu_enabled(self.prev)
        torch.set_autocast_cpu_dtype(self.prev_fastdtype)
    elif self.device == 'xpu':
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        torch.xpu.set_autocast_xpu_enabled(self.prev)  # type: ignore[attr-defined]
        torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
    else:
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        torch.set_autocast_enabled(self.prev)
        torch.set_autocast_gpu_dtype(self.prev_fastdtype)
    torch.set_autocast_cache_enabled(self.prev_cache_enabled)
    return False
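# This __exit__ (and the matching __enter__ below) is the bookkeeping that runs
# when torch.autocast is used as a context manager. A minimal usage sketch,
# assuming only stock PyTorch and a CUDA device; the tiny model and input are
# placeholders, not taken from the snippets in this file.
import torch

model = torch.nn.Linear(8, 4).cuda()
x = torch.randn(2, 8, device="cuda")

# Entering saves the previous autocast state/dtype, enables autocast, and bumps
# the nesting counter; exiting restores the saved state and clears the cast
# cache once the outermost region is left.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = model(x)  # eligible ops run in float16

print(y.dtype)  # expected: torch.float16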
def on_batch_end(self, runner: IRunner) -> None:
    """On batch end event

    Args:
        runner: current runner
    """
    # Drop the cache when we exit to a nesting level
    # that's outside any instance of autocast.
    if torch.autocast_decrement_nesting() == 0:
        torch.clear_autocast_cache()
    torch.set_autocast_enabled(self.prev_autocast_state)

    if not runner.is_train_loader:
        return

    loss = runner.batch_metrics[self.metric_key]
    self._accumulation_counter += 1
    need_gradient_step = (self._accumulation_counter % self.accumulation_steps == 0)

    self.scaler.scale(loss).backward()

    if need_gradient_step:
        self.grad_step(
            optimizer=self._optimizer,
            grad_clip_fn=self.grad_clip_fn,
        )
        utils.maybe_recursive_call(self._optimizer, "zero_grad")
        self._accumulation_counter = 0
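# The callback above is one way to wire AMP loss scaling and gradient
# accumulation into a runner. A hand-rolled sketch of the same pattern with
# torch.cuda.amp.GradScaler; the model, optimizer, data, and the value of
# accumulation_steps are all hypothetical.
import torch

model = torch.nn.Linear(8, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()
accumulation_steps = 4

for i in range(8):
    x = torch.randn(16, 8, device="cuda")
    target = torch.randint(0, 4, (16,), device="cuda")

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.cross_entropy(model(x), target)

    # Scale the (possibly partial) loss so small float16 gradients don't underflow.
    scaler.scale(loss).backward()

    # Step and zero only every `accumulation_steps` batches, mirroring
    # `need_gradient_step` in the callback above.
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)  # unscales gradients; skips the step on inf/nan
        scaler.update()
        optimizer.zero_grad()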
def __enter__(self):
    if torch._jit_internal.is_scripting():
        assert self.fast_dtype is not None
        return self

    self.prev_cache_enabled = torch.is_autocast_cache_enabled()
    if self.device == 'cpu':
        self.prev = torch.is_autocast_cpu_enabled()
        self.prev_fastdtype = torch.get_autocast_cpu_dtype()
        torch.set_autocast_cpu_enabled(self._enabled)
        torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
        torch.autocast_increment_nesting()
    elif self.device == 'xpu':
        self.prev = torch.xpu.is_autocast_xpu_enabled()  # type: ignore[attr-defined]
        self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
        torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
        torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
        torch.autocast_increment_nesting()
    else:
        self.prev = torch.is_autocast_enabled()
        self.prev_fastdtype = torch.get_autocast_gpu_dtype()
        torch.set_autocast_gpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
        torch.set_autocast_enabled(self._enabled)
        torch.autocast_increment_nesting()
    torch.set_autocast_cache_enabled(self._cache_enabled)
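# The 'cpu' branch above is taken when autocast targets the CPU, where bfloat16
# is the usual reduced-precision dtype. A small sketch assuming nothing beyond
# stock PyTorch; the module and input are placeholders.
import torch

model = torch.nn.Linear(8, 4)  # plain float32 CPU module
x = torch.randn(2, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)

print(y.dtype)  # expected: torch.bfloat16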
def __exit__(self, *args):
    if self._enabled:
        # Drop the cache when we exit to a nesting level that's outside any instance of autocast
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        torch.set_autocast_enabled(self.prev)
    return False
def on_batch_start(self, runner: IRunner) -> None:
    """On batch start event

    Args:
        runner: current runner
    """
    self.prev_autocast_state = torch.is_autocast_enabled()
    torch.set_autocast_enabled(True)
    torch.autocast_increment_nesting()
def __exit__(self, *args):
    # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
    if self.device == 'cpu':
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        torch.set_autocast_cpu_enabled(self.prev)
        torch.set_autocast_cpu_dtype(self.prev_fastdtype)
    else:
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        torch.set_autocast_enabled(self.prev)
        torch.set_autocast_gpu_dtype(self.prev_fastdtype)
    return False
def __enter__(self):
    if self.device == 'cpu':
        self.prev = torch.is_autocast_cpu_enabled()
        self.prev_fastdtype = torch.get_autocast_cpu_dtype()
        torch.set_autocast_cpu_enabled(self._enabled)
        torch.set_autocast_cpu_dtype(self.fast_dtype)
        torch.autocast_increment_nesting()
    else:
        self.prev = torch.is_autocast_enabled()
        self.prev_fastdtype = torch.get_autocast_gpu_dtype()
        torch.set_autocast_gpu_dtype(self.fast_dtype)
        torch.set_autocast_enabled(self._enabled)
        torch.autocast_increment_nesting()
def on_batch_end(self, runner: "IRunner") -> None: """On batch end event Args: runner: current runner """ if self.use_amp: # Drop the cache when we exit to a nesting level # that's outside any instance of autocast. if torch.autocast_decrement_nesting() == 0: torch.clear_autocast_cache() torch.set_autocast_enabled(self.prev_autocast_state) if not runner.is_train_loader: return loss = runner.batch_metrics[self.metric_key] self._accumulation_counter += 1 need_gradient_step = (self._accumulation_counter % self.accumulation_steps == 0) # @TODO: speedup with re-definition ``on_stage_start`` if self.use_apex: from apex import amp # Need to set ``delay_unscale`` # according to # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations delay_unscale = not need_gradient_step with amp.scale_loss(loss, self._optimizer, delay_unscale=delay_unscale) as scaled_loss: scaled_loss.backward() elif self.use_amp: self.scaler.scale(loss).backward() else: loss.backward() if need_gradient_step: self.grad_step( optimizer=self._optimizer, grad_clip_fn=self.grad_clip_fn, ) if not self.use_fast_zero_grad: maybe_recursive_call(self._optimizer, "zero_grad") else: maybe_recursive_call(self._optimizer, zero_grad) self._accumulation_counter = 0
def __enter__(self):
    self.prev = torch.is_autocast_enabled()
    torch.set_autocast_enabled(self._enabled)
    torch.autocast_increment_nesting()
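# The nesting counter touched by autocast_increment_nesting /
# autocast_decrement_nesting is what allows autocast regions to nest: the cast
# cache is only dropped when the outermost region exits. A small sketch with
# stock torch.autocast; model and input are placeholders.
import torch

model = torch.nn.Linear(8, 4).cuda()
x = torch.randn(2, 8, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    y16 = model(x)  # float16 region
    with torch.autocast(device_type="cuda", enabled=False):
        y32 = model(x)  # autocast locally disabled, stays float32
    # Leaving the inner region restores the outer (enabled) state; the cast
    # cache is cleared only when this outer block exits.
    y16_again = model(x)

print(y16.dtype, y32.dtype, y16_again.dtype)  # float16, float32, float16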