Example 1
    def test_kwarg_only(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.ones(1, 1))
            z = LoggingTensor(torch.ones(1))
            log_input("x", x)
            log_input("y", y)
            log_input("z", z)
            torch.addmv(x, y, z)
            torch.addmv(x, y, z, beta=1)
            torch.addmv(x, y, z, beta=2)
            torch.addmv(x, y, z, alpha=2)
            torch.addmv(x, y, z, beta=2, alpha=2)

        # The expectation is that beta/alpha don't show up in the log when they
        # take their default values, even if the user explicitly passed the default.
        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('y')
$2 = input('z')
$3 = torch._ops.aten.addmv($0, $1, $2)
$4 = torch._ops.aten.addmv($0, $1, $2)
$5 = torch._ops.aten.addmv($0, $1, $2, beta=2)
$6 = torch._ops.aten.addmv($0, $1, $2, alpha=2)
$7 = torch._ops.aten.addmv($0, $1, $2, beta=2, alpha=2)''')
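
For context, these tests rely on a wrapper subclass along the lines of torch.testing._internal.logging_tensor.LoggingTensor. The snippet below is a minimal, simplified sketch of that idea (MiniLoggingTensor and the print-based logging are illustrative stand-ins, not the real implementation): __torch_dispatch__ intercepts every aten call made on the wrapper, re-runs it on the wrapped .elem tensor, and records which op was called.

import torch
from torch.utils._pytree import tree_map


class MiniLoggingTensor(torch.Tensor):
    # Wrapper subclass: the outer tensor only carries metadata, the data
    # lives in .elem.
    elem: torch.Tensor

    @staticmethod
    def __new__(cls, elem):
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, device=elem.device,
            requires_grad=elem.requires_grad)
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            return t.elem if isinstance(t, MiniLoggingTensor) else t

        def wrap(t):
            return MiniLoggingTensor(t) if isinstance(t, torch.Tensor) else t

        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        print(f"called {func}")  # the real LoggingTensor records a log entry instead of printing
        return tree_map(wrap, out)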
Example 2
    def test_basic(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.0], requires_grad=True))
            log_input("x", x)
            y = x * x
            saved_x = y.grad_fn._saved_self
            grad_y = LoggingTensor(torch.tensor([1.0]))
            log_input("grad_y", grad_y)
            g, = torch.autograd.grad((y,), (x,), (grad_y,))

        self.assertEqual(g.elem, torch.tensor([6.0]))
        with torch.no_grad():
            self.assertEqual(saved_x, x)
            self.assertEqual(saved_x._version, x._version)
            x.add_(2)
            self.assertEqual(saved_x, x)
            # TODO: figure out why this is broken
            # self.assertEqual(saved_x._version, x._version)
        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.aten.mul($0, $0)
$2 = input('grad_y')
$3 = torch._ops.aten.mul($2, $0)
$4 = torch._ops.aten.mul($2, $0)
$5 = torch._ops.aten.add($4, $3)''')
Example 3
    def test_version(self) -> None:
        x = LoggingTensor(torch.ones(1))
        prev_vc = x._version
        x.detach().add_(2)
        cur_vc = x._version
        self.assertNotEqual(prev_vc, cur_vc)
        x.data.add_(2)
        self.assertEqual(cur_vc, x._version)
Example 4
    def test_detach_appears_twice_when_called_once(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
            log_input("x", x)
            x.detach()
        # FIXME: We actually want this to emit a single detach. However,
        # it currently emits two, for reasons unclear to us. Leaving
        # this test here to make sure we don't regress even further (it
        # would be bad if calling .detach() once emits 3+ detaches).
        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.aten.detach.default($0)
$2 = torch._ops.aten.detach.default($1)''')
Example 5
    def test_out(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.zeros(1))
            log_input("x", x)
            log_input("y", y)
            torch.abs(x, out=y)

        self.assertEqual(y.elem, torch.ones(1))
        # TODO: arguably this shouldn't pass and we should complain
        # that out isn't a kwarg
        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('y')
$2 = torch._ops.aten.abs($0, out=$1)''')
Example 6
    def test_format(self) -> None:
        x = LoggingTensor(torch.ones(1))
        s1 = str(x)
        s2 = repr(x)
        s3 = f"{x}"
        self.assertExpectedInline(s1, """LoggingTensor(tensor([1.]))""")
        self.assertEqual(s1, s2)
        self.assertEqual(s1, s3)
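
The three strings above compare equal because, absent explicit overrides, str() falls back to __repr__ and an empty format spec falls back to str(). A plain-Python sketch of the same pattern (ReprDemo is a hypothetical stand-in, not the LoggingTensor implementation):

class ReprDemo:
    def __init__(self, elem):
        self.elem = elem

    def __repr__(self):
        # Render as "ClassName(<repr of the wrapped value>)", the same shape
        # as the expected string in test_format above.
        return f"{type(self).__name__}({self.elem!r})"


x = ReprDemo("tensor([1.])")
assert str(x) == repr(x) == f"{x}"  # str() and format('') both resolve to __repr__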
Example 7
    def test_subclass_creation(self):
        # Make sure these statements run without error.
        # In particular, check that when an internal detach returns
        # subclasses, they are cleanly overwritten.
        class Foo(torch.Tensor):
            pass

        err_msg = "subclass Foo but.*already associated to a python object of type LoggingTensor"
        with self.assertRaisesRegex(RuntimeError, err_msg):
            a = torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2)))
        with self.assertRaisesRegex(RuntimeError, err_msg):
            b = LoggingTensor(torch.rand(2)).as_subclass(Foo)
        with self.assertRaisesRegex(RuntimeError, err_msg):
            Foo(LoggingTensor(torch.rand(2)))

        with self.assertRaisesRegex(TypeError, "Foo must define __torch_dispatch__"):
            torch.Tensor._make_wrapper_subclass(Foo, (2, 2))
Example 8
    def test_enable_python_mode_error(self) -> None:
        with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
            with enable_python_mode(torch.Tensor):
                pass
        z = LoggingTensor(torch.empty([]))
        with self.assertRaisesRegex(ValueError, "must be the type"):
            with enable_python_mode(z):
                pass
Example 9
    def test_mixed_wrappers_valid(self):
        def f(x, y):
            z = x + y
            z.add_(1)
            return z

        x1_not_functional = LoggingTensor(torch.ones(4))
        x2_functional = torch._to_functional_tensor(LoggingTensor(torch.ones(4)))

        with capture_logs() as logs:
            y = f(x1_not_functional, x2_functional)

        # Make sure that functionalization ran the "+" kernel
        # with a functional + non-functional tensor, and wrapped the output appropriately.
        self.assertExpectedInline('\n'.join(logs), """\
$2 = torch._ops.aten.add.Tensor($0, $1)
$3 = torch._ops.aten.add.Tensor($2, 1)""")
Example 10
    def test_enable_torch_dispatch_mode_error(self) -> None:
        z = LoggingTensor(torch.empty([]))
        with self.assertRaisesRegex(
                ValueError,
                "expected to get TorchDispatchMode, Tensor-like class, or None"
        ):
            with enable_torch_dispatch_mode(z):
                pass
Example 11
    def test_multiple_levels_of_wrapping(self):
        def f(x):
            # call an inplace op and have it get logged twice (by the outer + inner wrapper)
            x.add_(1)

        # Test 1: both the inner and outer wrapper are "functionalized"
        x_inner_and_outer_functional = torch._to_functional_tensor(
            InplaceLoggingTensor(
                torch._to_functional_tensor(LoggingTensor(torch.ones(4)))))

        with capture_logs() as logs:
            f(x_inner_and_outer_functional)

        # Since both wrappers were functionalized, they both log "add"
        self.assertExpectedInline(
            '\n'.join(logs), """\
$1 = torch._ops.aten.add.Tensor($0, 1)
$3 = torch._ops.aten.add.Tensor($2, 1)""")

        # Test 2: only the inner wrapper is "functionalized"
        x_only_inner_functional = InplaceLoggingTensor(
            torch._to_functional_tensor(LoggingTensor(torch.ones(4))))

        with capture_logs() as logs:
            f(x_only_inner_functional)

        # Since only the inner wrapper is functionalized, only the inner (first) log shows the functional add; the other stays in-place
        self.assertExpectedInline(
            '\n'.join(logs), """\
$1 = torch._ops.aten.add.Tensor($0, 1)
$3 = torch._ops.aten.add_.Tensor($2, 1)""")

        # Test 3: only the outer wrapper is "functionalized"
        x_only_outer_functional = torch._to_functional_tensor(
            InplaceLoggingTensor(LoggingTensor(torch.ones(4))))

        with capture_logs() as logs:
            f(x_only_outer_functional)

        # Since only the outer wrapper is functionalized, only the outer (second) log shows the functional add; the inner one stays in-place
        self.assertExpectedInline(
            '\n'.join(logs), """\
$1 = torch._ops.aten.add_.Tensor($0, 1)
$3 = torch._ops.aten.add.Tensor($2, 1)""")
Example 12
    def test_wrapper_subclass_serializes(self) -> None:
        with tempfile.TemporaryFile() as f:
            x = LoggingTensor(torch.randn(3))
            torch.save(x, f)
            f.seek(0)
            x_loaded = torch.load(f)
            self.assertTrue(type(x_loaded) is type(x))
            self.assertEqual(x.elem, x_loaded.elem)
            self.assertFalse(x is x_loaded)
Example 13
    def test_torch_ops(self):
        r = make_tensor((2,), device='cpu', dtype=torch.float)
        self.assertEqual(torch.ops.prims.sin(r), torch.sin(r))

        r = LoggingTensor(r)
        with capture_logs() as logs:
            log_input("input", r)
            prims.sin(r)
        self.assertExpectedInline('\n'.join(logs), """\
$0 = input('input')
$1 = torch._ops.prims.sin.default($0)""")
Example 14
    def test_kwarg_only_and_positional_default(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.ones(1))
            log_input("x", x)
            log_input("y", y)
            torch.ops.aten.kl_div(x, y)
            torch.ops.aten.kl_div(x, y, 2)
            torch.ops.aten.kl_div(x, y, log_target=True)
            torch.ops.aten.kl_div(x, y, 2, log_target=True)

        # What we are testing here is that the defaulted `reduction` argument
        # is omitted from the log, even when a kwarg-only argument is also set
        self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = input('y')
$2 = torch._ops.aten.kl_div($0, $1)
$3 = torch._ops.aten.kl_div($0, $1, 2)
$4 = torch._ops.aten.kl_div($0, $1, log_target=True)
$5 = torch._ops.aten.kl_div($0, $1, 2, log_target=True)''')
Example 15
    def get_logs(self, func, inpt):
        input_clone_logging = LoggingTensor(inpt.clone())
        input_functional_logging = torch._to_functional_tensor(input_clone_logging)

        with capture_logs() as logs:
            log_input("input", input_clone_logging)
            torch._enable_functionalization()
            try:
                func(input_functional_logging)
            finally:
                torch._disable_functionalization()
        return logs
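
The important detail in get_logs is the try/finally pairing of torch._enable_functionalization() and torch._disable_functionalization(). A small hypothetical convenience wrapper (not part of the test file) packages that pattern so the disable call cannot be skipped on an error path:

import contextlib

import torch


@contextlib.contextmanager
def functionalization_enabled():
    # Same enable/disable pairing as get_logs above, expressed as a
    # context manager.
    torch._enable_functionalization()
    try:
        yield
    finally:
        torch._disable_functionalization()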
Example 16
    def test_mixed_wrappers_valid(self):
        def f(x, y):
            z = x + y
            z.add_(1)
            return z

        x1_not_functional = LoggingTensor(torch.ones(4))
        x2_functional = torch._to_functional_tensor(
            LoggingTensor(torch.ones(4)))

        with capture_logs() as logs:
            y = f(x1_not_functional, x2_functional)

        # I think the alias trace is coming from the fact that x2 is technically *not*
        # a LoggingTensor (instead it *contains* a LoggingTensor), but x1 *is* a LoggingTensor.
        # The important thing here though is that functionalization ran the "+" kernel
        # with a functional + non-functional tensor, and wrapped the output appropriately.
        self.assertExpectedInline(
            '\n'.join(logs), """\
$2 = torch._ops.aten.add.Tensor($0, $1)
$3 = torch._ops.aten.alias.default($2)
$4 = torch._ops.aten.add.Tensor($3, tensor(1))""")
Example 17
    def test_tolist_numpy_with_python_mode(self) -> None:
        x = LoggingTensor(torch.tensor([2.0, 3.0]))
        with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
            x.tolist()
        with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
            x.numpy()
        with self.assertRaises(AssertionError):
            self.assertEqual(x, None)
Example 18
    def test_nesting_with_same_enable_torch_dispatch_mode(self) -> None:
        # "nested" enable_torch_dispatch_modes are allowed if they're the same mode. It's the equivalent of
        # a noop, so it will only write once to the log
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.]))
            log_input("x", x)
            with enable_torch_dispatch_mode(LoggingTensor):
                with enable_torch_dispatch_mode(LoggingTensor):
                    x + x

        self.assertExpectedInline(
            '\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.aten.add.Tensor($0, $0)''')
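
Here enable_torch_dispatch_mode is given the LoggingTensor class itself; the same hook can also be packaged as a mode object. A minimal sketch, assuming a PyTorch build where torch.utils._python_dispatch.TorchDispatchMode (the base class named in the test_enable_torch_dispatch_mode_error message earlier) exists and its instances work as context managers:

import torch
from torch.utils._python_dispatch import TorchDispatchMode


class PrintingMode(TorchDispatchMode):
    # Unlike a subclass, a mode sees every op dispatched while it is active,
    # including ops whose inputs are all plain tensors.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(f"dispatched: {func}")
        return func(*args, **(kwargs or {}))


with PrintingMode():
    torch.ones(2) + torch.ones(2)  # the factory ops and the add all hit the mode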
Example 19
    def test_custom_autograd(self) -> None:
        escape = [None]

        class Square(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = x**2
                ctx.save_for_backward(x)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                assert isinstance(grad_output, LoggingTensor)
                x, = ctx.saved_tensors
                assert isinstance(x, LoggingTensor)
                escape[0] = x
                return grad_output * 2 * x

        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1), requires_grad=True)
            log_input("x", x)
            x.grad = LoggingTensor(torch.zeros(1))
            log_input("x.grad", x.grad)
            y = Square.apply(x)
            grad_output = LoggingTensor(torch.ones(1))
            log_input("grad_output", grad_output)
            y.backward(grad_output)

        with torch.no_grad():
            self.assertEqual(escape[0], x)
            self.assertEqual(escape[0]._version, x._version)
            # TODO: figure out why x.requires_grad = False doesn't
            # trigger an error for LoggingTensor
            x.add_(2)
            self.assertEqual(escape[0], x)
            # TODO: figure out why this is broken
            # self.assertEqual(escape[0]._version, x._version)

        self.assertExpectedInline(
            '\n'.join(logs), '''\
$0 = input('x')
$1 = input('x.grad')
$2 = torch._ops.aten.pow.Tensor_Scalar($0, 2)
$3 = input('grad_output')
$4 = torch._ops.aten.mul.Tensor($3, 2)
$5 = torch._ops.aten.mul.Tensor($4, $0)
$6 = torch._ops.aten.add_.Tensor($1, $5)''')
Example 20
    def test_autograd_in_attr(self):
        # We want the wrapped Tensor to require gradients!
        true_t = torch.rand(2, requires_grad=True)
        t = LoggingTensor(true_t)

        out = t + 2

        self.assertFalse(out.requires_grad)
        self.assertIsNone(out.grad_fn)

        self.assertTrue(out.elem.requires_grad)
        self.assertIsNotNone(out.elem.grad_fn)

        with self.assertRaisesRegex(RuntimeError, "does not require grad"):
            out.sum().backward()

        out.elem.sum().backward()

        self.assertIsNone(t.grad)
        self.assertIsNotNone(t.elem.grad)
Example 21
    def test_storage_can_be_converted_to_python_object(self):
        with enable_python_mode(LoggingTensor):
            s = torch.Storage()
            z = LoggingTensor(torch.empty([]))
            z.set_(s)
Example 22
class SubclassInfo:

    __slots__ = ['name', 'create_fn', 'closed_under_ops']

    def __init__(self, name, create_fn, closed_under_ops=True):
        self.name = name
        self.create_fn = create_fn  # create_fn(shape) -> tensor instance
        self.closed_under_ops = closed_under_ops


subclass_db = {
    torch.Tensor:
    SubclassInfo('base_tensor', create_fn=lambda shape: torch.randn(shape)),
    NonWrapperTensor:
    SubclassInfo('non_wrapper_tensor',
                 create_fn=lambda shape: NonWrapperTensor(torch.randn(shape))),
    LoggingTensor:
    SubclassInfo('logging_tensor',
                 create_fn=lambda shape: LoggingTensor(torch.randn(shape))),
    SparseTensor:
    SubclassInfo('sparse_tensor',
                 create_fn=lambda shape: SparseTensor.from_dense(
                     torch.randn(shape).relu())),
    DiagTensorBelow:
    SubclassInfo(
        'diag_tensor_below',
        create_fn=lambda shape: DiagTensorBelow(torch.randn(shape)),
        closed_under_ops=False  # sparse semantics
    ),
}
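
A hedged sketch of how a table like subclass_db can be consumed (the loop is illustrative, not the repository's actual parametrization machinery):

for cls, info in subclass_db.items():
    t = info.create_fn((3,))  # create_fn(shape) -> tensor instance
    assert isinstance(t, cls), f"{info.name}: create_fn did not return a {cls.__name__}"
    # info.closed_under_ops tells the harness whether ops on `t` are expected
    # to return the same subclass again (DiagTensorBelow opts out above).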
Example 23
    def test_deepcopy_wrapper_subclass(self) -> None:
        x = LoggingTensor(torch.randn(3))
        x_copy = deepcopy(x)
        self.assertTrue(type(x_copy) is type(x))
        self.assertEqual(x.elem, x_copy.elem)
        self.assertFalse(x is x_copy)
Example 24
    def test_storage(self) -> None:
        # For now, just make sure it doesn't crash.  Ideally, we should
        # return some virtual storage that is safe to work with
        x = LoggingTensor(torch.ones(1))
        self.assertRaises(RuntimeError, lambda: x.storage())
Example 25
    def test_metadata_change_not_allowed(self) -> None:
        x = LoggingTensor(torch.ones(1))
        y = x.data
        self.assertIsInstance(y, LoggingTensor)
        self.assertRaises(RuntimeError, lambda: y.resize_(4))
Example 26
    def test_save_for_backwards_segfault(self):
        inp = torch._to_functional_tensor(LoggingTensor(torch.randn(2, 2))).requires_grad_(True)
        inp.exp()