def test_kwarg_only(self) -> None:
    with capture_logs() as logs:
        x = LoggingTensor(torch.ones(1))
        y = LoggingTensor(torch.ones(1, 1))
        z = LoggingTensor(torch.ones(1))
        log_input("x", x)
        log_input("y", y)
        log_input("z", z)
        torch.addmv(x, y, z)
        torch.addmv(x, y, z, beta=1)
        torch.addmv(x, y, z, beta=2)
        torch.addmv(x, y, z, alpha=2)
        torch.addmv(x, y, z, beta=2, alpha=2)

    # The expectation is that beta/alpha don't show up when they take their
    # default values, even if the user explicitly specified them.
    self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('y')
$2 = input('z')
$3 = torch._ops.aten.addmv($0, $1, $2)
$4 = torch._ops.aten.addmv($0, $1, $2)
$5 = torch._ops.aten.addmv($0, $1, $2, beta=2)
$6 = torch._ops.aten.addmv($0, $1, $2, alpha=2)
$7 = torch._ops.aten.addmv($0, $1, $2, beta=2, alpha=2)''')
def test_basic(self) -> None:
    with capture_logs() as logs:
        x = LoggingTensor(torch.tensor([3.0], requires_grad=True))
        log_input("x", x)
        y = x * x
        saved_x = y.grad_fn._saved_self
        grad_y = LoggingTensor(torch.tensor([1.0]))
        log_input("grad_y", grad_y)
        g, = torch.autograd.grad((y,), (x,), (grad_y,))

    self.assertEqual(g.elem, torch.tensor([6.0]))
    with torch.no_grad():
        self.assertEqual(saved_x, x)
        self.assertEqual(saved_x._version, x._version)
        x.add_(2)
        self.assertEqual(saved_x, x)
        # TODO: figure out why broken
        # self.assertEqual(saved_x._version, x._version)
    self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.aten.mul($0, $0)
$2 = input('grad_y')
$3 = torch._ops.aten.mul($2, $0)
$4 = torch._ops.aten.mul($2, $0)
$5 = torch._ops.aten.add($4, $3)''')
def test_version(self) -> None:
    x = LoggingTensor(torch.ones(1))
    prev_vc = x._version
    x.detach().add_(2)
    cur_vc = x._version
    self.assertNotEqual(prev_vc, cur_vc)
    x.data.add_(2)
    self.assertEqual(cur_vc, x._version)
def test_detach_appears_twice_when_called_once(self) -> None:
    with capture_logs() as logs:
        x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
        log_input("x", x)
        x.detach()
    # FIXME: We actually want this to emit a single detach. However,
    # it currently emits two, for reasons unclear to us.  Leaving
    # this test here to make sure we don't regress even further (it
    # would be bad if calling .detach() once emits 3+ detaches).
    self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.aten.detach.default($0)
$2 = torch._ops.aten.detach.default($1)''')
def test_out(self) -> None:
    with capture_logs() as logs:
        x = LoggingTensor(torch.ones(1))
        y = LoggingTensor(torch.zeros(1))
        log_input("x", x)
        log_input("y", y)
        torch.abs(x, out=y)

    self.assertEqual(y.elem, torch.ones(1))
    # TODO: arguably this shouldn't pass and we should complain
    # that out isn't a kwarg
    self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('y')
$2 = torch._ops.aten.abs($0, out=$1)''')
def test_format(self) -> None:
    x = LoggingTensor(torch.ones(1))
    s1 = str(x)
    s2 = repr(x)
    s3 = f"{x}"
    self.assertExpectedInline(s1, """LoggingTensor(tensor([1.]))""")
    self.assertEqual(s1, s2)
    self.assertEqual(s1, s3)
def test_subclass_creation(self):
    # Make sure these statements run without error.
    # In particular, check that when the internal detach returns
    # subclasses, these are cleanly overwritten.
    class Foo(torch.Tensor):
        pass

    err_msg = "subclass Foo but.*already associated to a python object of type LoggingTensor"
    with self.assertRaisesRegex(RuntimeError, err_msg):
        a = torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2)))
    with self.assertRaisesRegex(RuntimeError, err_msg):
        b = LoggingTensor(torch.rand(2)).as_subclass(Foo)
    with self.assertRaisesRegex(RuntimeError, err_msg):
        Foo(LoggingTensor(torch.rand(2)))

    with self.assertRaisesRegex(TypeError, "Foo must define __torch_dispatch__"):
        torch.Tensor._make_wrapper_subclass(Foo, (2, 2))
def test_enable_python_mode_error(self) -> None:
    with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
        with enable_python_mode(torch.Tensor):
            pass
    z = LoggingTensor(torch.empty([]))
    with self.assertRaisesRegex(ValueError, "must be the type"):
        with enable_python_mode(z):
            pass
def test_mixed_wrappers_valid(self):
    def f(x, y):
        z = x + y
        z.add_(1)
        return z

    x1_not_functional = LoggingTensor(torch.ones(4))
    x2_functional = torch._to_functional_tensor(LoggingTensor(torch.ones(4)))

    with capture_logs() as logs:
        y = f(x1_not_functional, x2_functional)

    # Make sure that functionalization ran the "+" kernel
    # with a functional + non-functional tensor, and wrapped the output appropriately.
    self.assertExpectedInline('\n'.join(logs), """\
$2 = torch._ops.aten.add.Tensor($0, $1)
$3 = torch._ops.aten.add.Tensor($2, 1)""")
def test_enable_torch_dispatch_mode_error(self) -> None:
    z = LoggingTensor(torch.empty([]))
    with self.assertRaisesRegex(
            ValueError, "expected to get TorchDispatchMode, Tensor-like class, or None"):
        with enable_torch_dispatch_mode(z):
            pass
def test_multiple_levels_of_wrapping(self):
    def f(x):
        # call an inplace op and have it get logged twice (by the outer + inner wrapper)
        x.add_(1)

    # Test 1: both the inner and outer wrapper are "functionalized"
    x_inner_and_outer_functional = torch._to_functional_tensor(
        InplaceLoggingTensor(
            torch._to_functional_tensor(LoggingTensor(torch.ones(4)))))

    with capture_logs() as logs:
        f(x_inner_and_outer_functional)

    # Since both wrappers were functionalized, they both log "add"
    self.assertExpectedInline('\n'.join(logs), """\
$1 = torch._ops.aten.add.Tensor($0, 1)
$3 = torch._ops.aten.add.Tensor($2, 1)""")

    # Test 2: only the inner wrapper is "functionalized"
    x_only_inner_functional = InplaceLoggingTensor(
        torch._to_functional_tensor(LoggingTensor(torch.ones(4))))

    with capture_logs() as logs:
        f(x_only_inner_functional)

    # Since only the inner wrapper is functionalized, only the inner (first) log
    # shows the functional add
    self.assertExpectedInline('\n'.join(logs), """\
$1 = torch._ops.aten.add.Tensor($0, 1)
$3 = torch._ops.aten.add_.Tensor($2, 1)""")

    # Test 3: only the outer wrapper is "functionalized"
    x_only_outer_functional = torch._to_functional_tensor(
        InplaceLoggingTensor(LoggingTensor(torch.ones(4))))

    with capture_logs() as logs:
        f(x_only_outer_functional)

    # Only the outer add_ is functionalized: since only the outer wrapper is
    # functionalized, only the outer (second) log shows the functional add
    self.assertExpectedInline('\n'.join(logs), """\
$1 = torch._ops.aten.add_.Tensor($0, 1)
$3 = torch._ops.aten.add.Tensor($2, 1)""")
def test_wrapper_subclass_serializes(self) -> None:
    with tempfile.TemporaryFile() as f:
        x = LoggingTensor(torch.randn(3))
        torch.save(x, f)
        f.seek(0)
        x_loaded = torch.load(f)
        self.assertTrue(type(x_loaded) is type(x))
        self.assertEqual(x.elem, x_loaded.elem)
        self.assertFalse(x is x_loaded)
def test_torch_ops(self):
    r = make_tensor((2,), device='cpu', dtype=torch.float)
    self.assertEqual(torch.ops.prims.sin(r), torch.sin(r))
    r = LoggingTensor(r)
    with capture_logs() as logs:
        log_input("input", r)
        prims.sin(r)
    self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.prims.sin.default($0)""")
def test_kwarg_only_and_positional_default(self) -> None:
    with capture_logs() as logs:
        x = LoggingTensor(torch.ones(1))
        y = LoggingTensor(torch.ones(1))
        log_input("x", x)
        log_input("y", y)
        torch.ops.aten.kl_div(x, y)
        torch.ops.aten.kl_div(x, y, 2)
        torch.ops.aten.kl_div(x, y, log_target=True)
        torch.ops.aten.kl_div(x, y, 2, log_target=True)

    # What we are testing here is that we omit reduction
    # if it is defaulted, even if a kwarg is set
    self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('y')
$2 = torch._ops.aten.kl_div($0, $1)
$3 = torch._ops.aten.kl_div($0, $1, 2)
$4 = torch._ops.aten.kl_div($0, $1, log_target=True)
$5 = torch._ops.aten.kl_div($0, $1, 2, log_target=True)''')
def get_logs(self, func, inpt):
    input_clone_logging = LoggingTensor(inpt.clone())
    input_functional_logging = torch._to_functional_tensor(input_clone_logging)

    with capture_logs() as logs:
        log_input("input", input_clone_logging)
        torch._enable_functionalization()
        try:
            func(input_functional_logging)
        finally:
            torch._disable_functionalization()
    return logs
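# Illustrative sketch (hypothetical test, not from the original suite), assuming it
# lives on the same test class as get_logs above: run a simple in-place op through
# the functionalization pass and check that the captured trace records the input
# and an add, which functionalization should emit in its out-of-place form.
def test_get_logs_usage_sketch(self):
    logs = self.get_logs(lambda t: t.add_(1), torch.ones(2))
    # The exact log format belongs to LoggingTensor, so only check loosely here.
    self.assertTrue(any("input" in line for line in logs))
    self.assertTrue(any("aten.add" in line for line in logs))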
def test_mixed_wrappers_valid(self):
    def f(x, y):
        z = x + y
        z.add_(1)
        return z

    x1_not_functional = LoggingTensor(torch.ones(4))
    x2_functional = torch._to_functional_tensor(LoggingTensor(torch.ones(4)))

    with capture_logs() as logs:
        y = f(x1_not_functional, x2_functional)

    # I think the alias trace is coming from the fact that x2 is technically *not*
    # a LoggingTensor (instead it *contains* a LoggingTensor), but x1 *is* a LoggingTensor.
    # The important thing here though is that functionalization ran the "+" kernel
    # with a functional + non-functional tensor, and wrapped the output appropriately.
    self.assertExpectedInline('\n'.join(logs), """\
$2 = torch._ops.aten.add.Tensor($0, $1)
$3 = torch._ops.aten.alias.default($2)
$4 = torch._ops.aten.add.Tensor($3, tensor(1))""")
def test_tolist_numpy_with_python_mode(self) -> None:
    x = LoggingTensor(torch.tensor([2.0, 3.0]))
    with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
        x.tolist()
    with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
        x.numpy()
    with self.assertRaises(AssertionError):
        self.assertEqual(x, None)
def test_nesting_with_same_enable_torch_dispatch_mode(self) -> None:
    # "nested" enable_torch_dispatch_modes are allowed if they're the same mode.
    # It's the equivalent of a no-op, so it will only write once to the log.
    with capture_logs() as logs:
        x = LoggingTensor(torch.tensor([3.]))
        log_input("x", x)
        with enable_torch_dispatch_mode(LoggingTensor):
            with enable_torch_dispatch_mode(LoggingTensor):
                x + x

    self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.aten.add.Tensor($0, $0)''')
def test_custom_autograd(self) -> None:
    escape = [None]

    class Square(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            y = x ** 2
            ctx.save_for_backward(x)
            return y

        @staticmethod
        def backward(ctx, grad_output):
            assert isinstance(grad_output, LoggingTensor)
            x, = ctx.saved_tensors
            assert isinstance(x, LoggingTensor)
            escape[0] = x
            return grad_output * 2 * x

    with capture_logs() as logs:
        x = LoggingTensor(torch.ones(1), requires_grad=True)
        log_input("x", x)
        x.grad = LoggingTensor(torch.zeros(1))
        log_input("x.grad", x.grad)
        y = Square.apply(x)
        grad_output = LoggingTensor(torch.ones(1))
        log_input("grad_output", grad_output)
        y.backward(grad_output)

    with torch.no_grad():
        self.assertEqual(escape[0], x)
        self.assertEqual(escape[0]._version, x._version)
        # TODO: figure out why x.requires_grad = False doesn't
        # trigger an error for LoggingTensor
        x.add_(2)
        self.assertEqual(escape[0], x)
        # TODO: figure out why this is broken
        # self.assertEqual(escape[0]._version, x._version)

    self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('x.grad')
$2 = torch._ops.aten.pow.Tensor_Scalar($0, 2)
$3 = input('grad_output')
$4 = torch._ops.aten.mul.Tensor($3, 2)
$5 = torch._ops.aten.mul.Tensor($4, $0)
$6 = torch._ops.aten.add_.Tensor($1, $5)''')
def test_autograd_in_attr(self):
    # We want the wrapped Tensor to require gradients!
    true_t = torch.rand(2, requires_grad=True)
    t = LoggingTensor(true_t)

    out = t + 2

    self.assertFalse(out.requires_grad)
    self.assertIsNone(out.grad_fn)

    self.assertTrue(out.elem.requires_grad)
    self.assertIsNotNone(out.elem.grad_fn)

    with self.assertRaisesRegex(RuntimeError, "does not require grad"):
        out.sum().backward()

    out.elem.sum().backward()

    self.assertIsNone(t.grad)
    self.assertIsNotNone(t.elem.grad)
def test_storage_can_be_converted_to_python_object(self):
    with enable_python_mode(LoggingTensor):
        s = torch.Storage()
        z = LoggingTensor(torch.empty([]))
        z.set_(s)
class SubclassInfo:
    __slots__ = ['name', 'create_fn', 'closed_under_ops']

    def __init__(self, name, create_fn, closed_under_ops=True):
        self.name = name
        self.create_fn = create_fn  # create_fn(shape) -> tensor instance
        self.closed_under_ops = closed_under_ops


subclass_db = {
    torch.Tensor: SubclassInfo(
        'base_tensor',
        create_fn=lambda shape: torch.randn(shape)),
    NonWrapperTensor: SubclassInfo(
        'non_wrapper_tensor',
        create_fn=lambda shape: NonWrapperTensor(torch.randn(shape))),
    LoggingTensor: SubclassInfo(
        'logging_tensor',
        create_fn=lambda shape: LoggingTensor(torch.randn(shape))),
    SparseTensor: SubclassInfo(
        'sparse_tensor',
        create_fn=lambda shape: SparseTensor.from_dense(torch.randn(shape).relu())),
    DiagTensorBelow: SubclassInfo(
        'diag_tensor_below',
        create_fn=lambda shape: DiagTensorBelow(torch.randn(shape)),
        closed_under_ops=False  # sparse semantics
    ),
}
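# A small usage sketch (illustrative only, not part of the original file): iterate
# subclass_db and build one instance per registered subclass via its create_fn.
# The helper name below is hypothetical.
def _example_instantiate_all_subclasses(shape=(2, 2)):
    instances = {}
    for cls, info in subclass_db.items():
        t = info.create_fn(shape)
        # Each create_fn is expected to return an instance of its key class;
        # info.closed_under_ops additionally records whether ops on that
        # subclass are expected to return the same subclass.
        assert isinstance(t, cls), f"{info.name} did not produce a {cls.__name__}"
        instances[info.name] = t
    return instances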
def test_deepcopy_wrapper_subclass(self) -> None:
    x = LoggingTensor(torch.randn(3))
    x_copy = deepcopy(x)
    self.assertTrue(type(x_copy) is type(x))
    self.assertEqual(x.elem, x_copy.elem)
    self.assertFalse(x is x_copy)
def test_storage(self) -> None:
    # For now, just make sure it doesn't crash.  Ideally, we should
    # return some virtual storage that is safe to work with
    x = LoggingTensor(torch.ones(1))
    self.assertRaises(RuntimeError, lambda: x.storage())
def test_metadata_change_not_allowed(self) -> None:
    x = LoggingTensor(torch.ones(1))
    y = x.data
    self.assertIsInstance(y, LoggingTensor)
    self.assertRaises(RuntimeError, lambda: y.resize_(4))
def test_save_for_backwards_segfault(self):
    inp = torch._to_functional_tensor(LoggingTensor(torch.randn(2, 2))).requires_grad_(True)
    inp.exp()