from typing import Tuple

import torch


def fn(x, y) -> Tuple[torch.Tensor, bool, torch.Tensor, bool, torch.Tensor, bool]:
    # Probe the CPU autocast flag and run a matmul at three nesting levels:
    # outside any autocast region, inside an enabled region, and inside a
    # nested region that explicitly disables autocast again.
    b1 = torch.is_autocast_cpu_enabled()
    v1 = torch.mm(x, y)
    with torch.cpu.amp.autocast(enabled=True):
        b2 = torch.is_autocast_cpu_enabled()
        v2 = torch.mm(x, y)
        with torch.cpu.amp.autocast(enabled=False):
            b3 = torch.is_autocast_cpu_enabled()
            v3 = torch.mm(x, y)
    return (v1, b1, v2, b2, v3, b3)
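A minimal sketch of how such a probe could be exercised, assuming float32 CPU inputs and the default CPU autocast dtype (bfloat16); the tensors `a` and `b` are illustrative:

a = torch.randn(8, 8)  # illustrative float32 CPU inputs
b = torch.randn(8, 8)
v1, b1, v2, b2, v3, b3 = fn(a, b)
assert not b1 and v1.dtype == torch.float32  # no autocast region yet
assert b2 and v2.dtype == torch.bfloat16     # enabled region lowers mm's precision
assert not b3 and v3.dtype == torch.float32  # nested enabled=False restores float32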
def __enter__(self):
    if torch._jit_internal.is_scripting():
        assert self.fast_dtype is not None
        return self

    # Stash the previous autocast state so __exit__ can restore it, then
    # install this region's settings for the backend chosen by self.device.
    self.prev_cache_enabled = torch.is_autocast_cache_enabled()
    if self.device == 'cpu':
        self.prev = torch.is_autocast_cpu_enabled()
        self.prev_fastdtype = torch.get_autocast_cpu_dtype()
        torch.set_autocast_cpu_enabled(self._enabled)
        torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
        torch.autocast_increment_nesting()
    elif self.device == 'xpu':
        self.prev = torch.xpu.is_autocast_xpu_enabled()  # type: ignore[attr-defined]
        self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
        torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
        torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
        torch.autocast_increment_nesting()
    else:
        self.prev = torch.is_autocast_enabled()
        self.prev_fastdtype = torch.get_autocast_gpu_dtype()
        torch.set_autocast_gpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
        torch.set_autocast_enabled(self._enabled)
        torch.autocast_increment_nesting()
    torch.set_autocast_cache_enabled(self._cache_enabled)
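This bookkeeping backs the public context manager, so a caller only ever touches the documented API:

import torch

x = torch.randn(4, 4)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    assert torch.is_autocast_cpu_enabled()
    y = torch.mm(x, x)  # dispatched in bfloat16 under CPU autocast
assert not torch.is_autocast_cpu_enabled()  # previous state restored on exit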
def _get_autocast_kwargs():
    # Snapshot the ambient autocast state for both device types so it can be
    # reproduced exactly later (e.g. when checkpointing replays a forward).
    gpu_autocast_kwargs = {
        "enabled": torch.is_autocast_enabled(),
        "dtype": torch.get_autocast_gpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled(),
    }
    cpu_autocast_kwargs = {
        "enabled": torch.is_autocast_cpu_enabled(),
        "dtype": torch.get_autocast_cpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled(),
    }
    return gpu_autocast_kwargs, cpu_autocast_kwargs
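A sketch of how the two dicts are typically consumed: capture the state once, then replay some work under exactly that state. The `replay` helper and its arguments are hypothetical:

import torch

gpu_kwargs, cpu_kwargs = _get_autocast_kwargs()  # capture at the original site

def replay(run_function, *inputs):  # hypothetical helper
    # Re-enter autocast with the captured settings before recomputing.
    with torch.enable_grad(), \
            torch.cuda.amp.autocast(**gpu_kwargs), \
            torch.cpu.amp.autocast(**cpu_kwargs):
        return run_function(*inputs)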
def __enter__(self):
    if self.device == 'cpu':
        self.prev = torch.is_autocast_cpu_enabled()
        self.prev_fastdtype = torch.get_autocast_cpu_dtype()
        torch.set_autocast_cpu_enabled(self._enabled)
        torch.set_autocast_cpu_dtype(self.fast_dtype)
        torch.autocast_increment_nesting()
    else:
        self.prev = torch.is_autocast_enabled()
        self.prev_fastdtype = torch.get_autocast_gpu_dtype()
        torch.set_autocast_gpu_dtype(self.fast_dtype)
        torch.set_autocast_enabled(self._enabled)
        torch.autocast_increment_nesting()
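The matching `__exit__` is not shown; a sketch of the restore path it would need, assuming the same `self.device`, `self.prev`, and `self.prev_fastdtype` attributes saved above:

def __exit__(self, exc_type, exc_val, exc_tb):
    # Pop one nesting level; clear the cast cache once the outermost
    # region exits so cached casts don't leak between regions.
    if torch.autocast_decrement_nesting() == 0:
        torch.clear_autocast_cache()
    if self.device == 'cpu':
        torch.set_autocast_cpu_enabled(self.prev)
        torch.set_autocast_cpu_dtype(self.prev_fastdtype)
    else:
        torch.set_autocast_enabled(self.prev)
        torch.set_autocast_gpu_dtype(self.prev_fastdtype)
    return False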
def forward(ctx, run_function, preserve_rng_state, *args):
    # check_backward_validity and get_device_states are helpers from
    # torch.utils.checkpoint, where this function lives.
    check_backward_validity(args)
    ctx.run_function = run_function
    ctx.preserve_rng_state = preserve_rng_state
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    ctx.gpu_autocast_kwargs = {
        "enabled": torch.is_autocast_enabled(),
        "dtype": torch.get_autocast_gpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled(),
    }
    ctx.cpu_autocast_kwargs = {
        "enabled": torch.is_autocast_cpu_enabled(),
        "dtype": torch.get_autocast_cpu_dtype(),
        "cache_enabled": torch.is_autocast_cache_enabled(),
    }
    if preserve_rng_state:
        ctx.fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
        # we have no way to anticipate this will happen before we run the function.)
        ctx.had_cuda_in_fwd = False
        if torch.cuda._initialized:
            ctx.had_cuda_in_fwd = True
            ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

    # Save non-tensor inputs in ctx, keep a placeholder None for tensors
    # to be filled out during the backward.
    ctx.inputs = []
    ctx.tensor_indices = []
    tensor_inputs = []
    for i, arg in enumerate(args):
        if torch.is_tensor(arg):
            tensor_inputs.append(arg)
            ctx.tensor_indices.append(i)
            ctx.inputs.append(None)
        else:
            ctx.inputs.append(arg)

    ctx.save_for_backward(*tensor_inputs)

    with torch.no_grad():
        outputs = run_function(*args)
    return outputs
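The stashed kwargs matter in the matching `backward`, which replays `run_function` under the forward's autocast state; a condensed sketch of just that re-entry (`detached_inputs` is a hypothetical stand-in for the rebuilt, detached argument list):

# Inside the matching backward (condensed sketch):
with torch.enable_grad(), \
        torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
        torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
    outputs = ctx.run_function(*detached_inputs)  # detached_inputs: hypothetical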
def _run_autocast_outofplace(self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None):
    # helper to cast args
    def cast(val, to_type):
        if isinstance(val, torch.Tensor):
            return val.to(to_type) if val.is_floating_point() else val
        elif isinstance(val, collections.abc.Iterable):
            return type(val)(cast(v, to_type) for v in val)
        else:
            return val

    if add_kwargs is None:
        add_kwargs = {}

    self.assertFalse(torch.is_autocast_cpu_enabled())
    with torch.cpu.amp.autocast():
        self.assertTrue(torch.is_autocast_cpu_enabled())
        out_type = out_type if out_type is not None else run_as_type
        output = output_method = None

        # Try module.* variant, if requested:
        if module is not None and hasattr(module, op):
            output = getattr(module, op)(*args, **add_kwargs)
            if isinstance(output, torch.Tensor):
                self.assertTrue(out_type == output.dtype,
                                "autocast for torch.{} produced {}, should produce {}"
                                .format(op, output.dtype, out_type))

        # Try Tensor.* variant:
        if hasattr(torch.Tensor, op):
            output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
            if isinstance(output_method, torch.Tensor):
                self.assertTrue(out_type == output_method.dtype,
                                "autocast for Tensor.{} produced {}, should produce {}"
                                .format(op, output_method.dtype, out_type))

        self.assertTrue((output is not None) or (output_method is not None),
                        "{} not found as an attribute on either Tensor or the requested module {}"
                        .format(op, module))

        # Accounts for ops that return Tensors, iterables, and other non-Tensors.
        # For example, lstm_cell returns a tuple and equal returns bool.
        def compare(first, second):
            if isinstance(first, torch.Tensor):
                return torch.equal(first, second)
            elif isinstance(first, collections.abc.Iterable):
                return all(compare(f, s) for f, s in zip(first, second))
            else:
                return first == second

        # If both torch.* and Tensor.* variants were found, check outputs are identical
        if (output is not None) and (output_method is not None):
            self.assertTrue(type(output) == type(output_method))
            comparison = compare(output, output_method)
            self.assertTrue(comparison,
                            "torch.{0} result did not match Tensor.{0} result".format(op))

        # Compare numerics to Python-side "autocasting" that (we expect) does the same thing
        # as the C++-side autocasting, and should be bitwise accurate.
        output_to_compare = output if output is not None else output_method
        with torch.cpu.amp.autocast(enabled=False):
            self.assertFalse(torch.is_autocast_cpu_enabled())

            if module is not None and hasattr(module, op):
                control = getattr(module, op)(*cast(args, run_as_type), **add_kwargs)
            else:
                control = getattr(args[0].to(run_as_type), op)(*cast(args[1:], run_as_type), **add_kwargs)
            self.assertTrue(type(output_to_compare) == type(control))
            comparison = compare(output_to_compare, control)
            self.assertTrue(comparison, "torch.{} result did not match control".format(op))
        self.assertTrue(torch.is_autocast_cpu_enabled())
    self.assertFalse(torch.is_autocast_cpu_enabled())
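A hypothetical invocation of this harness from a test method of the same TestCase, checking that CPU autocast runs torch.mm in bfloat16; shapes are illustrative:

a = torch.randn(8, 8)  # illustrative float32 CPU operands
b = torch.randn(8, 8)
self._run_autocast_outofplace("mm", (a, b), torch.bfloat16)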
def _assert_autocast_enabled(self):
    if self.trainer.precision_plugin.device == "cpu":
        assert torch.is_autocast_cpu_enabled()
    else:
        assert torch.is_autocast_enabled()
def __enter__(self):
    self.prev = torch.is_autocast_cpu_enabled()
    self.prev_dtype = torch.get_autocast_cpu_dtype()
    torch.set_autocast_cpu_enabled(self._enabled)
    torch.set_autocast_cpu_dtype(self._dtype)
    torch.autocast_increment_nesting()
def fn(x):
    if torch.is_autocast_cpu_enabled():
        return x.relu()
    else:
        return x.sin()
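A quick check of both branches, assuming a float CPU tensor:

import torch

x = torch.randn(4)
assert torch.equal(fn(x), x.sin())       # autocast off -> sin branch
with torch.cpu.amp.autocast():
    assert torch.equal(fn(x), x.relu())  # autocast on -> relu branch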