def _get_autocast_kwargs(): gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), "dtype": torch.get_autocast_gpu_dtype(), "cache_enabled": torch.is_autocast_cache_enabled()} cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(), "dtype": torch.get_autocast_cpu_dtype(), "cache_enabled": torch.is_autocast_cache_enabled()} return gpu_autocast_kwargs, cpu_autocast_kwargs
def __enter__(self):
    """Enter the autocast region.

    Saves the current global autocast flag/dtype for this context's device
    so ``__exit__`` can restore them, then installs this context's
    ``_enabled``/``fast_dtype`` settings and bumps the nesting counter.
    """
    if torch._jit_internal.is_scripting():
        # Under TorchScript we cannot mutate global autocast state through
        # this context manager; just sanity-check the dtype and bail out.
        assert self.fast_dtype is not None
        return self
    # Remember the cache flag so __exit__ can restore it.
    self.prev_cache_enabled = torch.is_autocast_cache_enabled()
    if self.device == 'cpu':
        self.prev = torch.is_autocast_cpu_enabled()
        self.prev_fastdtype = torch.get_autocast_cpu_dtype()
        torch.set_autocast_cpu_enabled(self._enabled)
        torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
        torch.autocast_increment_nesting()
    elif self.device == 'xpu':
        self.prev = torch.xpu.is_autocast_xpu_enabled()  # type: ignore[attr-defined]
        self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
        torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
        torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
        torch.autocast_increment_nesting()
    else:
        # CUDA (and any other device) uses the "gpu" autocast slots.
        self.prev = torch.is_autocast_enabled()
        self.prev_fastdtype = torch.get_autocast_gpu_dtype()
        torch.set_autocast_gpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
        torch.set_autocast_enabled(self._enabled)
        torch.autocast_increment_nesting()
    torch.set_autocast_cache_enabled(self._cache_enabled)
def forward(ctx, run_function, preserve_rng_state, *args):
    """Checkpoint forward pass: run ``run_function`` under ``no_grad`` and
    stash everything needed to replay it during backward.

    Args:
        ctx: autograd Function context used to stash state.
        run_function: callable re-executed in backward to rebuild the graph.
        preserve_rng_state: if True, capture CPU (and, when CUDA is already
            initialized, per-device GPU) RNG state for exact replay.
        *args: inputs forwarded to ``run_function``; tensors are saved via
            ``save_for_backward``, non-tensors are kept directly on ``ctx``.

    Returns:
        Whatever ``run_function(*args)`` returns.
    """
    check_backward_validity(args)
    ctx.run_function = run_function
    ctx.preserve_rng_state = preserve_rng_state
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    # Reuse the shared helper rather than duplicating the dict construction here.
    ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
    if preserve_rng_state:
        ctx.fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
        # we have no way to anticipate this will happen before we run the function.)
        ctx.had_cuda_in_fwd = False
        if torch.cuda._initialized:
            ctx.had_cuda_in_fwd = True
            ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
    # Save non-tensor inputs in ctx, keep a placeholder None for tensors
    # to be filled out during the backward.
    ctx.inputs = []
    ctx.tensor_indices = []
    tensor_inputs = []
    for i, arg in enumerate(args):
        if torch.is_tensor(arg):
            tensor_inputs.append(arg)
            ctx.tensor_indices.append(i)
            ctx.inputs.append(None)
        else:
            ctx.inputs.append(arg)
    ctx.save_for_backward(*tensor_inputs)
    with torch.no_grad():
        outputs = run_function(*args)
    return outputs
def __init__(self, device_type: str, dtype: Optional[_dtype] = None, enabled: bool = True, cache_enabled: Optional[bool] = None): if torch._jit_internal.is_scripting(): self._enabled = enabled self.device = device_type self.fast_dtype = dtype # TODO: support get_autocast_gpu/cpu_dtype assert dtype is not None return self.device = device_type if self.device == 'cuda': self.fast_dtype = torch.get_autocast_gpu_dtype() elif self.device == 'cpu': self.fast_dtype = torch.get_autocast_cpu_dtype() elif self.device == 'xpu': self.fast_dtype = torch.xpu.get_autocast_xpu_dtype( ) # type: ignore[attr-defined] else: raise RuntimeError( 'User specified autocast device_type must be \'cuda\' or \'cpu\'' ) self._cache_enabled = torch.is_autocast_cache_enabled() if torch.cuda.amp.common.amp_definitely_not_available( ) and self.device == 'cuda': warnings.warn( 'User provided device_type of \'cuda\', but CUDA is not available. Disabling' ) enabled = False if dtype is not None: self.fast_dtype = dtype if cache_enabled is not None: self._cache_enabled = cache_enabled if self.device == 'cpu': supported_dtype = [torch.bfloat16] if self.fast_dtype not in supported_dtype: error_message = 'In CPU autocast, but the target dtype is not supported. Disabling autocast.\n' error_message += 'CPU Autocast only supports dtype of torch.bfloat16 currently.' warnings.warn(error_message) enabled = False if self.device == 'xpu': supported_dtype = [torch.bfloat16, torch.float16] if self.fast_dtype not in supported_dtype: error_message = 'In XPU autocast, but the target dtype is not supported. Disabling autocast.\n' error_message += 'XPU Autocast only supports dtype of torch.bfloat16 currently.' warnings.warn(error_message) enabled = False if self.device == 'cuda': if self.fast_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported( ): raise RuntimeError( 'Current CUDA Device does not support bfloat16. Please switch dtype to float16.' ) self._enabled = enabled
def __enter__(self): self.prev_cache_enabled = torch.is_autocast_cache_enabled() if self.device == 'cpu': self.prev = torch.is_autocast_cpu_enabled() self.prev_fastdtype = torch.get_autocast_cpu_dtype() torch.set_autocast_cpu_enabled(self._enabled) torch.set_autocast_cpu_dtype(self.fast_dtype) torch.autocast_increment_nesting() else: self.prev = torch.is_autocast_enabled() self.prev_fastdtype = torch.get_autocast_gpu_dtype() torch.set_autocast_gpu_dtype(self.fast_dtype) torch.set_autocast_enabled(self._enabled) torch.autocast_increment_nesting() torch.set_autocast_cache_enabled(self._cache_enabled)
def __init__(self, device_type, enabled=True, **kwargs): self.device = device_type if self.device == 'cuda': self.fast_dtype = torch.get_autocast_gpu_dtype() elif self.device == 'cpu': self.fast_dtype = torch.get_autocast_cpu_dtype() else: raise RuntimeError( 'User specified autocast device_type must be \'cuda\' or \'cpu\'' ) self._cache_enabled = torch.is_autocast_cache_enabled() if torch.cuda.amp.common.amp_definitely_not_available( ) and self.device == 'cuda': warnings.warn( 'User provided device_type of \'cuda\', but CUDA is not available. Disabling' ) enabled = False for key, value in kwargs.items(): if key == 'dtype': self.fast_dtype = value if key == 'cache_enabled': self._cache_enabled = value if not ((key == 'dtype') or (key == 'cache_enabled')): raise RuntimeError( 'Unrecognized optional argument supplied to autocast context manager: ' + str(key)) if self.device == 'cpu': supported_dtype = [torch.bfloat16] if self.fast_dtype not in supported_dtype: error_message = 'In CPU autocast, but the target dtype is not supported. Disabling autocast.\n' error_message += 'CPU Autocast only supports dtype of torch.bfloat16 currently.' warnings.warn(error_message) enabled = False if self.device == 'cuda': if self.fast_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported( ): raise RuntimeError( 'Current CUDA Device does not support bfloat16. Please switch dtype to float16.' ) self._enabled = enabled