Example #1
0
    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        """Leave the autocast region and restore the saved autocast state.

        Decrements the autocast nesting counter and, when it reaches zero
        (no enclosing autocast region remains), clears the autocast cache.
        Then restores the enabled flag and fast dtype for ``self.device``
        from ``self.prev`` / ``self.prev_fastdtype``, and the cache-enabled
        flag from ``self.prev_cache_enabled``.

        Args:
            exc_type: Exception type from the ``with`` body (not inspected).
            exc_val: Exception value from the ``with`` body (not inspected).
            exc_tb: Traceback from the ``with`` body (not inspected).

        Returns:
            ``False`` so that any exception raised in the ``with`` body
            propagates. Under TorchScript it returns ``None`` immediately
            without touching any state.
        """
        if torch._jit_internal.is_scripting():
            return

        # The nesting call is the same for every device, so do it once up
        # front. Drop the cache when we exit to a nesting level that's outside
        # any instance of autocast.
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()

        # Restore the per-device enabled flag and dtype. Each `type: ignore`
        # is on the line that performs the attribute access, which is the line
        # mypy reports the error on (a continuation line would not suppress it).
        if self.device == 'cpu':
            torch.set_autocast_cpu_enabled(self.prev)
            torch.set_autocast_cpu_dtype(self.prev_fastdtype)
        elif self.device == 'xpu':
            torch.xpu.set_autocast_xpu_enabled(self.prev)  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
        else:
            torch.set_autocast_enabled(self.prev)
            torch.set_autocast_gpu_dtype(self.prev_fastdtype)
        torch.set_autocast_cache_enabled(self.prev_cache_enabled)
        return False
Example #2
0
    def __enter__(self):
        """Enter the autocast region for ``self.device``.

        Stores the current cache-enabled flag in ``self.prev_cache_enabled``
        and the device's current enabled flag and fast dtype in ``self.prev``
        and ``self.prev_fastdtype``. Then applies ``self._enabled`` and
        ``self.fast_dtype``, increments the autocast nesting counter, and
        applies ``self._cache_enabled``.

        Returns:
            ``self`` on both the TorchScript path and the eager path, so
            ``with ... as ctx:`` always binds this context manager.
        """
        if torch._jit_internal.is_scripting():
            assert self.fast_dtype is not None
            return self

        self.prev_cache_enabled = torch.is_autocast_cache_enabled()
        # Each `type: ignore` is on the line that performs the flagged access,
        # which is the line mypy reports the error on.
        if self.device == 'cpu':
            self.prev = torch.is_autocast_cpu_enabled()
            self.prev_fastdtype = torch.get_autocast_cpu_dtype()
            torch.set_autocast_cpu_enabled(self._enabled)
            torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
        elif self.device == 'xpu':
            self.prev = torch.xpu.is_autocast_xpu_enabled()  # type: ignore[attr-defined]
            self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
        else:
            self.prev = torch.is_autocast_enabled()
            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
            # This branch sets the dtype before the enabled flag (the reverse
            # of the cpu/xpu branches); the original order is kept.
            torch.set_autocast_gpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
            torch.set_autocast_enabled(self._enabled)
        # Every branch ended with the nesting increment, so do it once here
        # (still before the cache flag is applied, as before).
        torch.autocast_increment_nesting()
        torch.set_autocast_cache_enabled(self._cache_enabled)
        # Match the TorchScript path above, which returns self; previously the
        # eager path implicitly returned None.
        return self
Example #3
0
 def __exit__(self, *args):
     """Leave the CPU autocast region and restore the prior CPU settings.

     The autocast nesting counter is decremented. When it drops to zero (we
     are outside every autocast region), the autocast cache is cleared. The
     CPU enabled flag and dtype are then reset from ``self.prev`` and
     ``self.prev_dtype`` (presumably captured on entry; confirm against
     ``__enter__``).

     Always returns ``False``, so exceptions from the ``with`` body propagate.
     """
     depth_after_exit = torch.autocast_decrement_nesting()
     left_outermost_region = depth_after_exit == 0
     if left_outermost_region:
         torch.clear_autocast_cache()
     torch.set_autocast_cpu_enabled(self.prev)
     torch.set_autocast_cpu_dtype(self.prev_dtype)
     return False
Example #4
0
 def __enter__(self):
     """Activate autocast for ``self.device`` and remember what was active.

     The currently active enabled flag and dtype are saved to ``self.prev``
     and ``self.prev_fastdtype``. Then ``self._enabled`` and ``self.fast_dtype``
     are applied and the autocast nesting counter is incremented. Any device
     other than ``'cpu'`` goes through the GPU autocast setters.
     """
     on_cpu = self.device == 'cpu'
     if on_cpu:
         self.prev = torch.is_autocast_cpu_enabled()
         self.prev_fastdtype = torch.get_autocast_cpu_dtype()
         torch.set_autocast_cpu_enabled(self._enabled)
         torch.set_autocast_cpu_dtype(self.fast_dtype)
     else:
         self.prev = torch.is_autocast_enabled()
         self.prev_fastdtype = torch.get_autocast_gpu_dtype()
         # On this path the dtype is applied before the enabled flag.
         torch.set_autocast_gpu_dtype(self.fast_dtype)
         torch.set_autocast_enabled(self._enabled)
     # Both paths finish by bumping the nesting depth.
     torch.autocast_increment_nesting()
Example #5
0
 def __enter__(self):
     """Switch CPU autocast to ``self._enabled`` / ``self._dtype``.

     The previously active CPU enabled flag and dtype are stashed in
     ``self.prev`` and ``self.prev_dtype`` first. The autocast nesting
     counter is incremented last.
     """
     self.prev, self.prev_dtype = (
         torch.is_autocast_cpu_enabled(),
         torch.get_autocast_cpu_dtype(),
     )
     torch.set_autocast_cpu_enabled(self._enabled)
     torch.set_autocast_cpu_dtype(self._dtype)
     torch.autocast_increment_nesting()