コード例 #1
0
    def on_batch_start(self, runner: IRunner) -> None:
        """Switch autocast on when a batch begins.

        The autocast flag that was active before this call is remembered on
        ``self.prev_autocast_state``. Autocast is then enabled and the
        autocast nesting level is incremented.

        Args:
            runner: current runner (not read by this hook)
        """
        # Capture the flag first: the next call overwrites it.
        autocast_was_on = torch.is_autocast_enabled()
        self.prev_autocast_state = autocast_was_on

        torch.set_autocast_enabled(True)
        torch.autocast_increment_nesting()
コード例 #2
0
 def __enter__(self):
     """Enter the autocast region for ``self.device``.

     Stores the current enabled flag and fast dtype on ``self.prev`` and
     ``self.prev_fastdtype``, applies ``self._enabled`` and
     ``self.fast_dtype``, and then increments the autocast nesting level.
     ``'cpu'`` uses the ``*_cpu_*`` torch functions; any other device value
     goes through ``torch.is_autocast_enabled`` / ``set_autocast_enabled``
     and the ``*_gpu_dtype`` functions.

     Nothing is returned.
     """
     targets_cpu = self.device == 'cpu'
     if not targets_cpu:
         # Snapshot first, then overwrite (dtype before the enabled flag).
         self.prev = torch.is_autocast_enabled()
         self.prev_fastdtype = torch.get_autocast_gpu_dtype()
         torch.set_autocast_gpu_dtype(self.fast_dtype)
         torch.set_autocast_enabled(self._enabled)
     else:
         # Snapshot first, then overwrite (enabled flag before the dtype).
         self.prev = torch.is_autocast_cpu_enabled()
         self.prev_fastdtype = torch.get_autocast_cpu_dtype()
         torch.set_autocast_cpu_enabled(self._enabled)
         torch.set_autocast_cpu_dtype(self.fast_dtype)
     # Both device paths end by bumping the nesting level.
     torch.autocast_increment_nesting()
コード例 #3
0
ファイル: autocast_mode.py プロジェクト: xkszltl/pytorch
    def __enter__(self):
        """Enter the autocast context and return the context manager.

        Outside of scripting, the current autocast cache flag, enabled flag
        and fast dtype are stored on ``self.prev_cache_enabled``,
        ``self.prev`` and ``self.prev_fastdtype``. The configured
        ``self._enabled``, ``self.fast_dtype`` and ``self._cache_enabled``
        are then applied and the autocast nesting level is incremented.
        ``'cpu'`` uses the ``*_cpu_*`` torch functions; any other device
        value goes through ``torch.is_autocast_enabled`` /
        ``set_autocast_enabled`` and the ``*_gpu_dtype`` functions.

        Returns:
            ``self`` on both the scripting and the eager path, so that
            ``with ... as ctx`` binds the context manager in either mode.
        """
        if torch._jit_internal.is_scripting():
            # No global autocast state is touched on this path.
            # NOTE(review): presumably kept as an assert so the scripting
            # compiler can narrow an optional dtype - confirm before
            # replacing it with an explicit raise.
            assert self.fast_dtype is not None
            return self

        self.prev_cache_enabled = torch.is_autocast_cache_enabled()
        if self.device == 'cpu':
            self.prev = torch.is_autocast_cpu_enabled()
            self.prev_fastdtype = torch.get_autocast_cpu_dtype()
            torch.set_autocast_cpu_enabled(self._enabled)
            torch.set_autocast_cpu_dtype(
                self.fast_dtype)  # type: ignore[arg-type]
        else:
            self.prev = torch.is_autocast_enabled()
            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
            torch.set_autocast_gpu_dtype(
                self.fast_dtype)  # type: ignore[arg-type]
            torch.set_autocast_enabled(self._enabled)
        # Shared tail of both device paths: bump nesting, then apply the
        # cache setting.
        torch.autocast_increment_nesting()
        torch.set_autocast_cache_enabled(self._cache_enabled)
        # Previously this path fell off the end and returned None while the
        # scripting path returned self.
        return self
コード例 #4
0
 def __enter__(self):
     """Enter the autocast region.

     Stores the autocast flag that was in effect on ``self.prev``, applies
     ``self._enabled`` and increments the autocast nesting level.
     Nothing is returned.
     """
     # Read the old flag before it is overwritten below.
     previously_enabled = torch.is_autocast_enabled()
     self.prev = previously_enabled

     torch.set_autocast_enabled(self._enabled)
     torch.autocast_increment_nesting()
コード例 #5
0
 def __enter__(self):
     """Enter the CPU autocast region.

     Stores the current CPU autocast flag and dtype on ``self.prev`` and
     ``self.prev_dtype``, applies ``self._enabled`` and ``self._dtype``,
     and increments the autocast nesting level. Nothing is returned.
     """
     # Snapshot the state that the setters below overwrite.
     was_enabled = torch.is_autocast_cpu_enabled()
     self.prev = was_enabled
     old_dtype = torch.get_autocast_cpu_dtype()
     self.prev_dtype = old_dtype

     # Apply this context's settings: flag first, then dtype.
     torch.set_autocast_cpu_enabled(self._enabled)
     torch.set_autocast_cpu_dtype(self._dtype)

     torch.autocast_increment_nesting()