Example #1
    def __enter__(self):
        if torch._jit_internal.is_scripting():
            assert self.fast_dtype is not None
            return self

        self.prev_cache_enabled = torch.is_autocast_cache_enabled()
        if self.device == 'cpu':
            self.prev = torch.is_autocast_cpu_enabled()
            self.prev_fastdtype = torch.get_autocast_cpu_dtype()
            torch.set_autocast_cpu_enabled(self._enabled)
            torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
            torch.autocast_increment_nesting()
        elif self.device == 'xpu':
            self.prev = torch.xpu.is_autocast_xpu_enabled()  # type: ignore[attr-defined]
            self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
            torch.autocast_increment_nesting()
        else:
            self.prev = torch.is_autocast_enabled()
            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
            torch.set_autocast_gpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
            torch.set_autocast_enabled(self._enabled)
            torch.autocast_increment_nesting()
        torch.set_autocast_cache_enabled(self._cache_enabled)
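
The save/set/increment pattern above is what makes autocast regions nest: each __enter__ stashes the previous global flags and bumps a nesting counter that __exit__ later unwinds. A minimal usage sketch, not part of the source above, assuming eager mode and an available CUDA device:

import torch

net = torch.nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device='cuda')

# Nesting works because each __enter__ saves the previous enabled flag
# and fast dtype, and the matching __exit__ restores them.
with torch.autocast(device_type='cuda', dtype=torch.float16):
    with torch.autocast(device_type='cuda', enabled=False):
        y32 = net(x)  # the disabled inner region runs in float32
    y16 = net(x)      # autocast applies float16 again
assert y32.dtype == torch.float32 and y16.dtype == torch.float16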
Example #2
    def __init__(self, device_type, enabled=True, **kwargs):
        self.device = device_type
        if self.device == 'cuda':
            self.fast_dtype = torch.get_autocast_gpu_dtype()
        elif self.device == 'cpu':
            self.fast_dtype = torch.get_autocast_cpu_dtype()
        else:
            raise RuntimeError('User specified autocast device_type must be \'cuda\' or \'cpu\'')
        if torch.cuda.amp.common.amp_definitely_not_available() and self.device == 'cuda':
            warnings.warn('User provided device_type of \'cuda\', but CUDA is not available. Disabling')
            enabled = False
        for key, value in kwargs.items():
            if key == 'dtype':
                self.fast_dtype = value
            else:
                raise RuntimeError('Unrecognized optional argument supplied to autocast context manager: ' + str(key))

        if self.device == 'cpu':
            supported_dtype = [torch.bfloat16]
            if self.fast_dtype not in supported_dtype:
                error_message = 'In CPU autocast, but the target dtype is not supported. Disabling autocast.\n'
                error_message += 'CPU Autocast only supports dtype of torch.bfloat16 currently.'
                warnings.warn(error_message)
                enabled = False
        if self.device == 'cuda':
            if self.fast_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
                raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.')
        self._enabled = enabled
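
Since this constructor only validates bfloat16 for CPU, here is a quick sketch of the supported path (hedged; it assumes torch.mm is on CPU autocast's bfloat16 cast list):

import torch

with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    a = torch.randn(4, 4)
    b = torch.mm(a, a)
assert b.dtype == torch.bfloat16  # mm ran in bfloat16 under CPU autocast

An unsupported CPU dtype only warns and disables autocast rather than raising, per the supported_dtype check above.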
Example #3
    def __init__(self,
                 device_type: str,
                 dtype: Optional[_dtype] = None,
                 enabled: bool = True,
                 cache_enabled: Optional[bool] = None):
        if torch._jit_internal.is_scripting():
            self._enabled = enabled
            self.device = device_type
            self.fast_dtype = dtype
            # TODO: support get_autocast_gpu/cpu_dtype
            assert dtype is not None
            return
        self.device = device_type
        if self.device == 'cuda':
            self.fast_dtype = torch.get_autocast_gpu_dtype()
        elif self.device == 'cpu':
            self.fast_dtype = torch.get_autocast_cpu_dtype()
        elif self.device == 'xpu':
            self.fast_dtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
        else:
            raise RuntimeError(
                'User specified autocast device_type must be \'cuda\', \'cpu\' or \'xpu\'')
        self._cache_enabled = torch.is_autocast_cache_enabled()
        if torch.cuda.amp.common.amp_definitely_not_available() and self.device == 'cuda':
            warnings.warn('User provided device_type of \'cuda\', but CUDA is not available. Disabling')
            enabled = False
        if dtype is not None:
            self.fast_dtype = dtype
        if cache_enabled is not None:
            self._cache_enabled = cache_enabled

        if self.device == 'cpu':
            supported_dtype = [torch.bfloat16]
            if self.fast_dtype not in supported_dtype:
                error_message = 'In CPU autocast, but the target dtype is not supported. Disabling autocast.\n'
                error_message += 'CPU Autocast only supports dtype of torch.bfloat16 currently.'
                warnings.warn(error_message)
                enabled = False
        if self.device == 'xpu':
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = 'In XPU autocast, but the target dtype is not supported. Disabling autocast.\n'
                error_message += 'XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.'
                warnings.warn(error_message)
                enabled = False
        if self.device == 'cuda':
            if self.fast_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
                raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.')
        self._enabled = enabled
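
Compared with Example #2, this version adds explicit dtype and cache_enabled parameters. A hedged usage sketch, assuming an available CUDA device:

import torch

model = torch.nn.Linear(16, 16).cuda()
inp = torch.randn(2, 16, device='cuda')

# Both keyword arguments are optional; when omitted, __init__ falls back
# to the global autocast state queried above.
with torch.autocast(device_type='cuda', dtype=torch.float16, cache_enabled=False):
    out = model(inp)
assert out.dtype == torch.float16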
Example #4
def _get_autocast_kwargs():
    gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                           "dtype": torch.get_autocast_gpu_dtype(),
                           "cache_enabled": torch.is_autocast_cache_enabled()}

    cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
                           "dtype": torch.get_autocast_cpu_dtype(),
                           "cache_enabled": torch.is_autocast_cache_enabled()}

    return gpu_autocast_kwargs, cpu_autocast_kwargs
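
The two dicts are keyed so they can be splatted straight back into the autocast constructors. A sketch of the typical consumer, mirroring how checkpoint's backward replays the forward-time autocast state (run_function and detached_inputs are hypothetical placeholders):

gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()

# Recompute the forward under the same autocast state captured earlier.
with torch.enable_grad(), \
     torch.cuda.amp.autocast(**gpu_autocast_kwargs), \
     torch.cpu.amp.autocast(**cpu_autocast_kwargs):
    outputs = run_function(*detached_inputs)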
Example #5
    def __enter__(self):
        if self.device == 'cpu':
            self.prev = torch.is_autocast_cpu_enabled()
            self.prev_fastdtype = torch.get_autocast_cpu_dtype()
            torch.set_autocast_cpu_enabled(self._enabled)
            torch.set_autocast_cpu_dtype(self.fast_dtype)
            torch.autocast_increment_nesting()
        else:
            self.prev = torch.is_autocast_enabled()
            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
            torch.set_autocast_gpu_dtype(self.fast_dtype)
            torch.set_autocast_enabled(self._enabled)
            torch.autocast_increment_nesting()
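
For symmetry, a simplified sketch of the matching __exit__ (attribute names come from the __enter__ above; the nesting and cache calls follow PyTorch's autocast pattern, with scripting and cache-flag restore details omitted):

    def __exit__(self, *exc_info):
        # Leaving the outermost autocast region invalidates the weight cache.
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        # Restore the state saved in __enter__.
        if self.device == 'cpu':
            torch.set_autocast_cpu_enabled(self.prev)
            torch.set_autocast_cpu_dtype(self.prev_fastdtype)
        else:
            torch.set_autocast_enabled(self.prev)
            torch.set_autocast_gpu_dtype(self.prev_fastdtype)
        return False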
Example #6
File: checkpoint.py  Project: pytorch/xla
  def forward(ctx, run_function, preserve_rng_state, *args):
    check_backward_validity(args)
    ctx.run_function = run_function
    ctx.preserve_rng_state = preserve_rng_state
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    ctx.gpu_autocast_kwargs = {
        "enabled": torch.is_autocast_enabled(),
        "dtype": torch.get_autocast_gpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled()
    }
    ctx.cpu_autocast_kwargs = {
        "enabled": torch.is_autocast_cpu_enabled(),
        "dtype": torch.get_autocast_cpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled()
    }
    if preserve_rng_state:
      ctx.fwd_cpu_state = torch.get_rng_state()
      # Don't eagerly initialize the cuda context by accident.
      # (If the user intends that the context is initialized later, within their
      # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
      # we have no way to anticipate this will happen before we run the function.)
      ctx.had_cuda_in_fwd = False
      if torch.cuda._initialized:
        ctx.had_cuda_in_fwd = True
        ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

    # Save non-tensor inputs in ctx, keep a placeholder None for tensors
    # to be filled out during the backward.
    ctx.inputs = []
    ctx.tensor_indices = []
    tensor_inputs = []
    tensor_outputs = []
    for i, arg in enumerate(args):
      if torch.is_tensor(arg):
        tensor_inputs.append(arg)
        ctx.tensor_indices.append(i)
        ctx.inputs.append(None)
      else:
        ctx.inputs.append(arg)

    ctx.save_for_backward(*tensor_inputs)

    with torch.no_grad():
      outputs = run_function(*args)

    return outputs
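
For context, this forward is reached through torch.utils.checkpoint.checkpoint (pytorch/xla mirrors the upstream implementation), which trades memory for a second forward pass during backward. A minimal usage sketch:

import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    # Some activation-heavy computation worth recomputing in backward.
    return torch.relu(x).pow(2)

x = torch.randn(2, 8, requires_grad=True)
# Intermediate activations inside block are not stored; backward re-runs
# block under the autocast and RNG state stashed by the forward above.
y = checkpoint(block, x)
y.sum().backward()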
Example #7
    def test_autocast_fast_dtype(self):
        gpu_fast_dtype = torch.get_autocast_gpu_dtype()
        cpu_fast_dtype = torch.get_autocast_cpu_dtype()
        self.assertEqual(gpu_fast_dtype, torch.half)
        self.assertEqual(cpu_fast_dtype, torch.bfloat16)
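
The defaults asserted here (float16 on GPU, bfloat16 on CPU) are process-global and adjustable. A hedged sketch:

import torch

torch.set_autocast_gpu_dtype(torch.bfloat16)
assert torch.get_autocast_gpu_dtype() == torch.bfloat16
# Restore the stock default so later autocast regions see float16 again.
torch.set_autocast_gpu_dtype(torch.half)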
Example #8
def _cast_if_autocast_enabled(*args):
    if not torch.is_autocast_enabled():
        return args
    else:
        return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype())
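
A helper like this typically serves a custom op that must run outside autocast while still honoring the ambient fast dtype. A sketch under that assumption (my_fused_op is hypothetical):

import torch

def my_fused_op(x, w):
    # Pre-cast inputs to the current autocast dtype, then opt out of
    # autocast so the kernel sees exactly the dtypes we chose.
    x, w = _cast_if_autocast_enabled(x, w)
    with torch.cuda.amp.autocast(enabled=False):
        return torch.matmul(x, w)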
Example #9
def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype:
    if not torch.is_autocast_enabled():
        # Prefer the caller-supplied dtype; fall back to float32.
        return dtype if dtype is not None else torch.float
    else:
        return torch.get_autocast_gpu_dtype()
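
Behavior after the fix, as a hypothetical check (assumes an available CUDA device for the autocast region):

assert _get_current_dtype() == torch.float             # outside autocast
assert _get_current_dtype(torch.float64) == torch.float64
with torch.autocast(device_type='cuda', dtype=torch.float16):
    assert _get_current_dtype() == torch.float16       # ambient fast dtype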