def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
    """Leave the autocast region and put back the state saved on entry.

    Under TorchScript this is a no-op that returns ``None``. In eager mode it:

    * decrements the global autocast nesting counter and, once the counter
      reaches zero (i.e. we are outside every autocast region), clears the
      autocast cache;
    * restores the per-device enabled flag and dtype from ``self.prev`` and
      ``self.prev_fastdtype`` ('cpu', 'xpu', or the generic/GPU setters for
      any other device string -- presumably CUDA);
    * restores the cache-enabled flag from ``self.prev_cache_enabled``.

    Returns ``False`` in eager mode so an in-flight exception is not suppressed.
    """
    if torch._jit_internal.is_scripting():
        return

    # Every device path starts the same way: leave one nesting level and drop
    # the cached casts once no autocast region remains active.
    if torch.autocast_decrement_nesting() == 0:
        torch.clear_autocast_cache()

    if self.device == 'cpu':
        torch.set_autocast_cpu_enabled(self.prev)
        torch.set_autocast_cpu_dtype(self.prev_fastdtype)
    elif self.device == 'xpu':
        torch.xpu.set_autocast_xpu_enabled(self.prev)  # type: ignore[attr-defined]
        torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
    else:
        torch.set_autocast_enabled(self.prev)
        torch.set_autocast_gpu_dtype(self.prev_fastdtype)

    torch.set_autocast_cache_enabled(self.prev_cache_enabled)
    return False
def __enter__(self):
    """Enter the autocast region, saving the current state for later restoration.

    In eager mode this records the current state on ``self``:

    * ``prev_cache_enabled``: the autocast cache-enabled flag;
    * ``prev``: the per-device autocast enabled flag;
    * ``prev_fastdtype``: the per-device autocast dtype.

    It then applies ``self._enabled`` and ``self.fast_dtype`` for
    ``self.device`` ('cpu', 'xpu', or the generic/GPU setters for any other
    device string, presumably CUDA). Next it increments the autocast nesting
    counter and applies ``self._cache_enabled``.

    Returns:
        ``self`` on both the TorchScript and eager paths. The original code
        returned ``self`` only when scripting and ``None`` in eager mode, so
        ``with ... as ctx:`` bound a different object depending on the mode.
    """
    if torch._jit_internal.is_scripting():
        # Refines Optional[dtype] for TorchScript; no state is touched here.
        assert self.fast_dtype is not None
        return self

    self.prev_cache_enabled = torch.is_autocast_cache_enabled()
    if self.device == 'cpu':
        self.prev = torch.is_autocast_cpu_enabled()
        self.prev_fastdtype = torch.get_autocast_cpu_dtype()
        torch.set_autocast_cpu_enabled(self._enabled)
        torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
    elif self.device == 'xpu':
        self.prev = torch.xpu.is_autocast_xpu_enabled()  # type: ignore[attr-defined]
        self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
        torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
        torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
    else:
        self.prev = torch.is_autocast_enabled()
        self.prev_fastdtype = torch.get_autocast_gpu_dtype()
        torch.set_autocast_gpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
        torch.set_autocast_enabled(self._enabled)
    # Every device path enters exactly one more nesting level; the matching
    # exit presumably decrements it and clears the cache at level zero.
    torch.autocast_increment_nesting()
    torch.set_autocast_cache_enabled(self._cache_enabled)
    return self
def __exit__(self, *args):
    """Leave the CPU autocast region and restore the state saved on entry.

    Decrements the autocast nesting counter and, when it reaches zero (no
    autocast region left), clears the autocast cache. Then restores the CPU
    autocast enabled flag from ``self.prev`` and the dtype from
    ``self.prev_dtype``. Exception details in ``args`` are ignored.

    Returns ``False`` so any in-flight exception propagates.
    """
    remaining_levels = torch.autocast_decrement_nesting()
    if remaining_levels == 0:
        torch.clear_autocast_cache()

    torch.set_autocast_cpu_enabled(self.prev)
    torch.set_autocast_cpu_dtype(self.prev_dtype)
    return False
def __enter__(self):
    """Enter the autocast region for ``self.device``.

    Saves the current autocast enabled flag on ``self.prev`` and the dtype on
    ``self.prev_fastdtype``. For 'cpu' it uses the CPU-specific accessors;
    for any other device string (presumably CUDA) it uses the generic/GPU
    ones. It then applies ``self._enabled`` and ``self.fast_dtype`` and
    increments the autocast nesting counter.

    Returns ``None``.
    """
    if self.device != 'cpu':
        self.prev = torch.is_autocast_enabled()
        self.prev_fastdtype = torch.get_autocast_gpu_dtype()
        torch.set_autocast_gpu_dtype(self.fast_dtype)
        torch.set_autocast_enabled(self._enabled)
    else:
        self.prev = torch.is_autocast_cpu_enabled()
        self.prev_fastdtype = torch.get_autocast_cpu_dtype()
        torch.set_autocast_cpu_enabled(self._enabled)
        torch.set_autocast_cpu_dtype(self.fast_dtype)

    # Shared by both device paths: one more active autocast level.
    torch.autocast_increment_nesting()
def __enter__(self):
    """Enter a CPU autocast region.

    Records the current CPU autocast enabled flag on ``self.prev`` and the
    dtype on ``self.prev_dtype``. It then applies ``self._enabled`` and
    ``self._dtype`` and increments the autocast nesting counter.

    Returns ``None``.
    """
    was_enabled = torch.is_autocast_cpu_enabled()
    self.prev = was_enabled
    old_dtype = torch.get_autocast_cpu_dtype()
    self.prev_dtype = old_dtype

    torch.set_autocast_cpu_enabled(self._enabled)
    torch.set_autocast_cpu_dtype(self._dtype)
    torch.autocast_increment_nesting()