def assert_functionalization(self, func, inpt, *, reapply_views=False):
    """Check that ``func`` behaves the same with and without functionalization.

    ``func`` is run eagerly on ``inpt`` and, with functionalization enabled,
    on a functional wrapper around a clone of ``inpt``. The clone is taken
    before the eager run, so it sees the original values. The helper then
    asserts two things:

    * Any in-place mutation ``func`` made to its input is visible on the
      functional input once it is synced and unwrapped.
    * Every output matches its eager counterpart in size and value. Tuple
      outputs are compared element-wise and must have the same length.

    Args:
        func: Callable taking one tensor and returning a tensor or a tuple of
            tensors; it may mutate its argument in place.
        inpt: Input tensor. The eager run calls ``func(inpt)`` directly, so
            ``inpt`` itself is mutated if ``func`` mutates its argument.
        reapply_views: Forwarded to ``torch._enable_functionalization``.
    """
    # Clone before the eager run so the functional path starts from the
    # same (unmutated) values.
    input_functional = torch._to_functional_tensor(inpt.clone())
    # Compare outputs (and mutated inputs), with and without functionalization.
    out_ref = func(inpt)

    torch._enable_functionalization(reapply_views=reapply_views)
    try:
        out_functional = func(input_functional)
    finally:
        # Always restore the global mode, even if func raises.
        torch._disable_functionalization()

    # We need to sync the input tensors first, in case there are any queued mutations left.
    torch._sync(input_functional)
    # input mutations should still occur
    self.assertEqual(inpt, torch._from_functional_tensor(input_functional))

    # Handle tests with multi-tensor outputs
    if isinstance(out_ref, tuple) and isinstance(out_functional, tuple):
        out_refs, out_functionals = list(out_ref), list(out_functional)
    else:
        out_refs, out_functionals = [out_ref], [out_functional]

    # zip() silently truncates; make sure no output goes unchecked.
    self.assertEqual(len(out_refs), len(out_functionals))
    for out_ref_, out_functional_ in zip(out_refs, out_functionals):
        self.assertEqual(out_ref_.size(), out_functional_.size())
        torch._sync(out_functional_)
        out_functional_unwrapped = torch._from_functional_tensor(out_functional_)
        self.assertEqual(out_ref_, out_functional_unwrapped)
def assert_functionalization(self, func, inpt):
    """Check that ``func`` gives the same results with and without functionalization.

    ``func`` is run eagerly on ``inpt`` and, with functionalization enabled,
    on a functional wrapper around a clone of ``inpt``. The clone is taken
    before the eager run, so it sees the original values. The helper then
    asserts two things:

    * Each output matches its eager counterpart. Tuple outputs are compared
      element-wise and must have the same length.
    * Any in-place mutation ``func`` made to its input is visible on the
      functional input once it is synced and unwrapped.

    Args:
        func: Callable taking one tensor and returning a tensor or a tuple of
            tensors; it may mutate its argument in place.
        inpt: Input tensor. The eager run passes ``inpt`` itself, so it is
            mutated if ``func`` mutates its argument.
    """
    # Clone before the eager run so the functional path starts from the
    # same (unmutated) values.
    input_functional = torch._to_functional_tensor(inpt.clone())
    # Compare outputs (and mutated inputs), with and without functionalization.
    out_ref = func(inpt)
    torch._enable_functionalization()
    try:
        out_functional = func(input_functional)
    finally:
        torch._disable_functionalization()

    # Multi-output functions return tuples. Normalize both sides to lists so
    # each output can be synced and compared individually.
    if isinstance(out_ref, tuple) and isinstance(out_functional, tuple):
        out_refs, out_functionals = list(out_ref), list(out_functional)
    else:
        out_refs, out_functionals = [out_ref], [out_functional]

    # We need to sync the input tensors first, in case there are any queued mutations left.
    torch._sync(input_functional)
    # zip() silently truncates, so check the counts match first.
    self.assertEqual(len(out_refs), len(out_functionals))
    for out_ref_, out_functional_ in zip(out_refs, out_functionals):
        torch._sync(out_functional_)
        self.assertEqual(out_ref_, torch._from_functional_tensor(out_functional_))
    # input mutations should still occur
    self.assertEqual(inpt, torch._from_functional_tensor(input_functional))
def test_aliases_maintained_after_pass(self):
    """Two views of one input must still alias each other after functionalization.

    Inside the functionalized region the function takes two ``view`` s of its
    input and mutates one of them. Once both results are synced and unwrapped,
    the underlying tensors are expected to alias each other.
    """
    # NOTE(review): functionalization is enabled with no arguments here. If the
    # installed torch defaults ``reapply_views`` to False, views may be replaced
    # by copies and this aliasing check could fail. Confirm against the torch
    # version in use.

    def two_views_one_mutation(inp):
        ones = torch.ones(4, 2)
        first_view = inp.view(4, 2)
        second_view = inp.view(4, 2)
        first_view.add_(ones)
        return first_view, second_view

    wrapped_input = torch._to_functional_tensor(torch.ones(4, 2))
    torch._enable_functionalization()
    try:
        first_out, second_out = two_views_one_mutation(wrapped_input)
        # Flush any pending mutations while still in the functionalized region.
        for pending in (first_out, second_out):
            torch._sync(pending)
    finally:
        torch._disable_functionalization()

    # y and z are aliases inside of the function, and that aliasing relationship should be maintained.
    inner_first = torch._from_functional_tensor(first_out)
    inner_second = torch._from_functional_tensor(second_out)
    self.assertTrue(are_aliased(inner_first, inner_second))
def _str_intern(inp, *, tensor_contents=None):
    """Build the ``repr`` string for the tensor ``inp``.

    The string is assembled from three parts:

    * A prefix: ``nested_tensor(``, ``tensor(``, ``_to_functional_tensor(``,
      or ``<TypeName>(`` for tensor subclasses.
    * A contents string, built per layout (sparse COO, sparse compressed,
      quantized, nested, functional, meta, empty, dense).
    * A list of ``key=value`` suffixes (``device=``, ``size=``, ``nnz=``,
      ``dtype=``, ``layout=``, ``grad_fn=``/``requires_grad=``, ``names=``,
      ``tangent=``, ...).

    These are combined by ``_add_suffixes``, which is defined outside this view.

    Args:
        inp: The tensor to render. It may be a plain tensor, an
            ``nn.Parameter`` or another tensor subclass.
        tensor_contents: Optional pre-rendered contents string. When given,
            it is used as the contents instead of the per-layout rendering.
            Suffixes are still computed. NOTE(review): the functional-tensor
            branch below overwrites the contents unconditionally, so this
            argument is ignored for functional tensors. Confirm intended.

    Returns:
        The repr string. Parameter subclasses (not plain ``Parameter``) are
        additionally wrapped in ``Parameter(...)``.
    """
    is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter
    if inp.is_nested:
        prefix = "nested_tensor("
    elif is_plain_tensor:
        prefix = "tensor("
    else:
        prefix = f"{type(inp).__name__}("
    # Width of the prefix; continuation lines of multi-part contents are
    # aligned to it.
    indent = len(prefix)
    suffixes = []
    custom_contents_provided = tensor_contents is not None
    if custom_contents_provided:
        tensor_str = tensor_contents

    # This is used to extract the primal value and thus disable the forward AD
    # within this function.
    # TODO(albanD) This needs to be updated when more than one level is supported
    # NOTE: `self` is an ordinary local here (this is a free function). It
    # holds the primal of `inp` and may be rebound to a CPU copy below.
    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)

    # Note [Print tensor device]:
    # A general logic here is we only print device when it doesn't match
    # the device specified in default tensor type.
    # Currently torch.set_default_tensor_type() only supports CPU/CUDA, thus
    # torch._C._get_default_device() only returns either cpu or cuda.
    # In other cases, we don't have a way to set them as default yet,
    # and we should always print out device for them.
    if (self.device.type != torch._C._get_default_device() or
            (self.device.type == "cuda" and
             torch.cuda.current_device() != self.device.index) or
            (self.device.type == "mps")):
        suffixes.append("device='" + str(self.device) + "'")

    # Tensor printing performs tensor operations like slice, indexing, etc to make it in a
    # representable format. These operations on ipu/xla/lazy tensor results in compilations. Hence,
    # to avoid compilations, copying the tensor to cpu before printing.
    # The device suffix above was computed first, so it still reports the
    # original device.
    if self.device.type in ["xla", "lazy", "ipu"]:
        self = self.to("cpu")

    # TODO: add an API to map real -> complex dtypes
    _default_complex_dtype = (torch.cdouble if torch.get_default_dtype()
                              == torch.double else torch.cfloat)
    # dtypes that need no explicit "dtype=" suffix in the branches that
    # consult this flag (the meta and empty branches use a stricter check).
    has_default_dtype = self.dtype in (
        torch.get_default_dtype(),
        _default_complex_dtype,
        torch.int64,
        torch.bool,
    )
    if self.is_sparse:
        # Sparse COO: show indices and values as two sub-tensors.
        suffixes.append("size=" + str(tuple(self.shape)))
        # Imported locally (presumably to avoid an import cycle; confirm).
        from torch._subclasses.fake_tensor import FakeTensor

        # Meta/fake tensors are not asked for their nnz here.
        if not self.is_meta and not isinstance(self, FakeTensor):
            suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            indices_prefix = "indices=tensor("
            indices = self._indices().detach()
            indices_str = _tensor_str(indices, indent + len(indices_prefix))
            if indices.numel() == 0:
                indices_str += ", size=" + str(tuple(indices.shape))
            values_prefix = "values=tensor("
            values = self._values().detach()
            values_str = _tensor_str(values, indent + len(values_prefix))
            if values.numel() == 0:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (indices_prefix + indices_str + "),\n" +
                          " " * indent + values_prefix + values_str + ")")
    elif self.layout in {
            torch.sparse_csr,
            torch.sparse_csc,
            torch.sparse_bsr,
            torch.sparse_bsc,
    }:
        # Sparse compressed layouts: show compressed indices, plain indices
        # and values as three sub-tensors.
        suffixes.append("size=" + str(tuple(self.shape)))
        # NOTE(review): unlike the COO branch, _nnz() is not guarded for
        # meta/FakeTensor inputs here. Confirm this is intended.
        suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            # Row-compressed layouts (csr/bsr) use crow/col indices;
            # column-compressed layouts (csc/bsc) use ccol/row indices.
            compressed_indices_method, plain_indices_method = {
                torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
                torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
            }[self.layout]
            if self.layout in {torch.sparse_csr, torch.sparse_bsr}:
                cdimname, pdimname = "row", "column"
            else:
                cdimname, pdimname = "column", "row"
            # e.g. "crow_indices=tensor(" / "col_indices=tensor(".
            compressed_indices_prefix = f"c{cdimname[:3]}_indices=tensor("
            compressed_indices = compressed_indices_method(self).detach()
            compressed_indices_str = _tensor_str(
                compressed_indices, indent + len(compressed_indices_prefix))
            if compressed_indices.numel() == 0:
                compressed_indices_str += ", size=" + str(
                    tuple(compressed_indices.shape))
            plain_indices_prefix = f"{pdimname[:3]}_indices=tensor("
            plain_indices = plain_indices_method(self).detach()
            plain_indices_str = _tensor_str(plain_indices,
                                            indent + len(plain_indices_prefix))
            if plain_indices.numel() == 0:
                plain_indices_str += ", size=" + str(tuple(
                    plain_indices.shape))
            values_prefix = "values=tensor("
            values = self.values().detach()
            values_str = _tensor_str(values, indent + len(values_prefix))
            if values.numel() == 0:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (compressed_indices_prefix + compressed_indices_str +
                          "),\n" + " " * indent + plain_indices_prefix +
                          plain_indices_str + "),\n" + " " * indent +
                          values_prefix + values_str + ")")
    elif self.is_quantized:
        # Quantized: report the quantization parameters as suffixes and show
        # the dequantized values as contents.
        suffixes.append("size=" + str(tuple(self.shape)))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        suffixes.append("quantization_scheme=" + str(self.qscheme()))
        if (self.qscheme() == torch.per_tensor_affine or
                self.qscheme() == torch.per_tensor_symmetric):
            suffixes.append("scale=" + str(self.q_scale()))
            suffixes.append("zero_point=" + str(self.q_zero_point()))
        elif (self.qscheme() == torch.per_channel_affine or
              self.qscheme() == torch.per_channel_symmetric or
              self.qscheme() == torch.per_channel_affine_float_qparams):
            suffixes.append("scale=" + str(self.q_per_channel_scales()))
            suffixes.append("zero_point=" + str(self.q_per_channel_zero_points()))
            suffixes.append("axis=" + str(self.q_per_channel_axis()))
        if not custom_contents_provided:
            tensor_str = _tensor_str(self.dequantize(), indent)
    elif self.is_nested:
        # Nested: render each component tensor (unbound along dim 0) with its
        # own repr, one per indented block inside "[...]".
        if not custom_contents_provided:

            # NOTE(review): the `indent` parameter is unused. Every line gets
            # the same fixed prefix regardless of nesting depth.
            def indented_str(s, indent):
                return "\n".join(f" {line}" for line in s.split("\n"))

            strs = ",\n".join(
                indented_str(str(t), indent + 1)
                for t in torch.ops.aten.unbind.int(self, 0))
            tensor_str = f"[\n{strs}\n]"
    elif torch._is_functional_tensor(self):
        # Functional wrapper: print the wrapped tensor's repr inside
        # "_to_functional_tensor(...)". This overwrites tensor_str even when
        # tensor_contents was supplied.
        prefix = "_to_functional_tensor("
        tensor_str = repr(torch._from_functional_tensor(self))
    else:
        if self.is_meta:
            suffixes.append("size=" + str(tuple(self.shape)))
            # Only the default float dtype is omitted here (stricter than
            # has_default_dtype).
            if self.dtype != torch.get_default_dtype():
                suffixes.append("dtype=" + str(self.dtype))
            # TODO: This implies that ellipses is valid syntax for allocating
            # a meta tensor, which it could be, but it isn't right now
            if not custom_contents_provided:
                tensor_str = "..."
        else:
            if self.numel() == 0 and not self.is_sparse:
                # Explicitly print the shape if it is not (0,), to match NumPy behavior
                if self.dim() != 1:
                    suffixes.append("size=" + str(tuple(self.shape)))

                # In an empty tensor, there are no elements to infer if the dtype
                # should be int64, so it must be shown explicitly.
                if self.dtype != torch.get_default_dtype():
                    suffixes.append("dtype=" + str(self.dtype))
                if not custom_contents_provided:
                    tensor_str = "[]"
            else:
                if not has_default_dtype:
                    suffixes.append("dtype=" + str(self.dtype))

                if not custom_contents_provided:
                    # Any remaining non-strided layout is densified for display.
                    if self.layout != torch.strided:
                        tensor_str = _tensor_str(self.to_dense(), indent)
                    else:
                        tensor_str = _tensor_str(self, indent)

    # Applies to every branch above, including the sparse ones.
    if self.layout != torch.strided:
        suffixes.append("layout=" + str(self.layout))

    # Use inp here to get the original grad_fn and not the one generated by the forward grad
    # unpacking.
    if inp.grad_fn is not None:
        name = type(inp.grad_fn).__name__
        # C++-implemented nodes share a generic Python type name; use the
        # last "::"-separated component of the node's own name instead.
        if name == "CppFunction":
            name = inp.grad_fn.name().rsplit("::", 1)[-1]
        suffixes.append("grad_fn=<{}>".format(name))
    elif inp.requires_grad:
        suffixes.append("requires_grad=True")

    if self.has_names():
        suffixes.append("names={}".format(self.names))

    if tangent is not None:
        suffixes.append("tangent={}".format(tangent))

    string_repr = _add_suffixes(prefix + tensor_str, suffixes, indent,
                                force_newline=self.is_sparse)

    # Check if this instance is flagged as a parameter and change the repr accordingly.
    # Unfortunately, this function has to be aware of this detail.
    # NB: This is currently skipped for plain tensor parameters to maintain BC. In the future,
    # this should be done for those as well to produce a valid repr.
    if isinstance(self, torch.nn.Parameter) and not is_plain_tensor:
        string_repr = f"Parameter({string_repr})"

    return string_repr