Example #1
    def _vmap_test(self,
                   op,
                   inputs,
                   in_dims=0,
                   out_dims=0,
                   check_view=False,
                   check_propagates_grad=True):
        result = vmap(op, in_dims, out_dims)(*inputs)
        reference_result = reference_vmap(op, inputs, in_dims, out_dims)
        self.assertEqual(result, reference_result)
        op_has_single_return = not isinstance(result, tuple)

        if check_view:
            result_as_tuple = (result, ) if op_has_single_return else result
            for output in result_as_tuple:
                self.assertEqual(
                    output.data_ptr() -
                    output.storage_offset() * output.element_size(),
                    inputs[0].data_ptr(),
                    msg="result was not a view of the first input!")

        if not check_propagates_grad:
            return
        # Assuming inputs[0] is a floating-point tensor, check that the vmap
        # operation propagates the requires_grad flag to the zeroth output.
        # Some vmap operators are implemented in a way that assumes they are
        # composite with respect to autograd. If the operator is ever changed
        # to not be composite with respect to autograd, the following check
        # should fail.
        inputs_clone = list(inputs)
        inputs_clone[0] = inputs[0].clone().requires_grad_()
        result = vmap(op, in_dims, out_dims)(*inputs_clone)
        result_as_tuple = (result, ) if op_has_single_return else result
        self.assertTrue(result_as_tuple[0].requires_grad)
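The view check in _vmap_test relies on a pointer identity: for a tensor that views another tensor's storage, data_ptr() equals the base tensor's data_ptr() plus storage_offset() * element_size(). A minimal standalone sketch of that identity on an ordinary slice (illustrative only, not part of the test suite; it assumes the base tensor starts at storage offset 0, as a freshly allocated tensor does):

import torch

base = torch.rand(3, 5)
view = base[:, 1:3]  # a view into base's storage
# Undo the view's offset into the shared storage to recover the base pointer.
assert view.data_ptr() - view.storage_offset() * view.element_size() == base.data_ptr()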
Example #2
    def test_slice(self):
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        test(lambda t: t[0:1], (torch.rand(B0, 3, 5),))
        test(lambda t: t[:, 1:3], (torch.rand(3, 5, B0),), in_dims=2)
        test(vmap(lambda t: t[:, 0:1], in_dims=2), (torch.rand(3, 5, B0, B1),), in_dims=2)
        test(vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2),
             (torch.rand(3, 5, B0, B1, B2),), in_dims=2)
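For intuition, the first slice test above is equivalent to slicing the batched tensor directly: mapping t[0:1] over the leading batch dimension produces the same result as x[:, 0:1]. A small sketch of that equivalence using torch.vmap (illustrative only, not part of the test suite):

import torch

x = torch.rand(7, 3, 5)
batched = torch.vmap(lambda t: t[0:1])(x)                      # slice each (3, 5) example
looped = torch.stack([x[i, 0:1] for i in range(x.shape[0])])   # explicit loop-and-stack reference
assert torch.equal(batched, looped)
assert torch.equal(batched, x[:, 0:1])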
Example #3
    def test_unsupported_op_err_msg(self):
        def foo(x):
            return torch.cos(x)

        x = torch.randn(3)
        with self.assertRaisesRegex(RuntimeError,
                                    'NYI: Calling aten::cos inside of vmap'):
            vmap(foo)(x)
Example #4
    def test_nested_with_same_map_dim(self):
        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        output = vmap(vmap(torch.mul))(x, y)
        self.assertEqual(output, x * y)

        output = vmap(vmap(vmap(torch.mul)))(x, y)
        self.assertEqual(output, x * y)
Example #5
    def test_t(self):
        op = torch.t
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        test(op, (torch.rand(B0, 2, 5),))
        test(op, (torch.rand(2, B0, 5),), in_dims=1)
        test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
        test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)
Example #6
    def test_unsupported_inplace_op_err_msg(self):
        def foo(x):
            return x.cos_()

        x = torch.randn(3)
        with self.assertRaisesRegex(RuntimeError,
                                    'Batching rule not implemented'):
            vmap(foo)(x)
Example #7
    def test_select(self):
        op = torch.select
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        test(op, (torch.rand(B0, 2, 5), 0, 0), in_dims=(0, None, None))
        test(op, (torch.rand(2, B0, 5), 1, 1), in_dims=(1, None, None))
        test(vmap(lambda t: op(t, 1, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
        test(vmap(vmap(lambda t: op(t, 1, 1), in_dims=1)), (torch.rand(B1, 2, B0, B2, 5),), in_dims=2)
Example #8
    def test_out_dims_edge_case(self):
        def foo(x):
            return x

        # Test that we accept out_dims=(1,) for a function with one output.
        tensor = torch.randn(2, 3)
        expected = vmap(foo, out_dims=1)(tensor)
        result = vmap(foo, out_dims=(1, ))(tensor)
        self.assertEqual(result, expected)
Example #9
    def test_reshape(self):
        test = self._vmap_test
        B0, B1, B2 = 7, 11, 13
        op = torch.reshape
        test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None), check_view=True)
        test(op, (torch.rand(2, B0, 5), [1, 1, 10]), in_dims=(1, None), check_view=False)
        test(vmap(lambda t: t.reshape([-1])), (torch.rand(B0, B1, 2, 5),), check_view=True)
        test(vmap(vmap(lambda t: t.reshape([-1]), in_dims=2), in_dims=1),
             (torch.rand(3, B1, 2, B2, 5, B0),), in_dims=5, check_view=False)
Example #10
    def test_non_tensor_output_raises(self):
        with self.assertRaisesRegex(ValueError, "got type <class 'float'> as the return"):
            output = vmap(lambda x: 3.14)(torch.ones(3))

        def multiple_outputs(x):
            return x, 3

        with self.assertRaisesRegex(ValueError, "got type <class 'int'> for return 1"):
            vmap(multiple_outputs)(torch.ones(3))
Example #11
    def test_out_dim_out_of_bounds_err_msg(self):
        # TODO(rzou): This error message isn't that great. It comes straight
        # from maybe_wrap_dim. Consider wrapping it in a try-catch in C++ and
        # adding some more context to the error message in the future.
        msg = 'Dimension out of range'
        x = torch.randn(2, 3, 5)
        with self.assertRaisesRegex(IndexError, msg):
            vmap(lambda x: x, out_dims=3)(x)
        with self.assertRaisesRegex(IndexError, msg):
            vmap(lambda x: x, out_dims=-4)(x)
Example #12
    def test_unsupported_inplace_op_err_msg(self):
        def foo(x):
            return x.cos_()

        x = torch.randn(3)
        # TODO(rzou): Yeah, this error message is pretty bad because the
        # dispatcher's fallback mechanism doesn't work for ops that don't support
        # boxing. Fix the error message at some point.
        with self.assertRaisesRegex(RuntimeError,
                                    'Tried to call KernelFunction::call'):
            vmap(foo)(x)
Example #13
    def test_nested_with_different_map_dim(self):
        x = torch.randn(2, 3)
        y = torch.randn(5, 3)
        output = vmap(lambda x: vmap(lambda y: x * y)(y))(x)
        self.assertEqual(output.shape, (2, 5, 3))
        self.assertEqual(output, x.view(2, 1, 3) * y)

        z = torch.randn(7, 3)
        output = vmap(lambda x: vmap(lambda y: vmap(lambda z: x * y * z)(z))(y))(x)
        self.assertEqual(output.shape, (2, 5, 7, 3))
        self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z)
Example #14
    def test_unfold(self):
        op = torch.Tensor.unfold
        test = self._vmap_view_test
        B0, B1, B2 = 3, 2, 5

        test(op, (torch.rand(B0, 7, 11), 0, 2, 1), in_dims=(0, None, None, None))
        test(op, (torch.rand(7, B0, 11), 1, 4, 2), in_dims=(1, None, None, None))
        test(vmap(op, in_dims=(0, None, None, None)),
             (torch.rand(B1, 7, B0, 11), 1, 5, 1), in_dims=(2, None, None, None))
        test(vmap(vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)),
             (torch.rand(B1, 7, B0, 11, B2), -1, 2, 4), in_dims=(2, None, None, None))
Example #15
    def test_narrow(self):
        op = torch.narrow
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13

        test(op, (torch.rand(B0, 2, 5), -1, 1, 3), in_dims=(0, None, None, None))
        test(op, (torch.rand(2, B0, 5), 1, 1, 3), in_dims=(1, None, None, None))
        test(vmap(op, in_dims=(0, None, None, None)),
             (torch.rand(B1, 2, B0, 5), 1, 0, 0), in_dims=(2, None, None, None))
        test(vmap(vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)),
             (torch.rand(B1, 2, B0, 5, B2), -1, 2, 3), in_dims=(2, None, None, None))
Example #16
    def test_expand_as(self):
        op = torch.Tensor.expand_as
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        test(op, (torch.rand(B0, 1, 5), torch.rand(B0, 2, 3, 5)))
        test(op, (torch.rand(B0, 1, 5), torch.rand(2, 3, 5)), in_dims=(0, None))
        test(op, (torch.rand(1, 5), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))
        test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B0, B1, 2, 3, 5)))
        test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B1, B0, 2, 3, 5)), in_dims=(0, 1))
        test(vmap(op), (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None))
        test(vmap(vmap(op)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5)))
Example #17
    def test_diagonal(self):
        tensor = torch.randn(3, 5, 7, 11, 13)
        test = self._vmap_view_test
        op = torch.diagonal
        test(op, (tensor, 1, 0, 1), in_dims=(0, None, None, None))
        test(op, (tensor, 0, 2, -1), in_dims=(0, None, None, None))
        test(op, (tensor, 2, 1, 2), in_dims=(1, None, None, None))
        test(op, (tensor, 0, -2, -1), in_dims=(1, None, None, None), out_dims=1)
        test(vmap(lambda t: op(t, 0, 0, -1)), (tensor,), in_dims=1, out_dims=1)
        test(vmap(vmap(lambda t: op(t, 0, 0, 1), in_dims=1), in_dims=3),
             (tensor,), in_dims=1, out_dims=1)
Example #18
    def test_T_numpy(self):
        def op(t):
            return t.T

        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        test(op, (torch.rand(B0, 2, 3, 5),))
        test(op, (torch.rand(B0),))
        test(op, (torch.rand(2, B0, 3, 5),), in_dims=1)
        test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
        test(vmap(op), (torch.rand(B1, 2, B0, 3, 5),), in_dims=2)
        test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 3, B2, 5),), in_dims=2)
Example #19
    def test_chunk(self):
        test = self._vmap_view_test
        op = torch.chunk
        B0, B1, B2 = 7, 11, 13

        # tests for torch.chunk(self, chunks: int, dim)
        test(op, (torch.rand(B0, 2, 1024), 15, -1), in_dims=(0, None, None))
        test(op, (torch.rand(2, B0, 1024), 9, 1), in_dims=(1, None, None))
        test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 4, 0),
             in_dims=(2, None, None))
        test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
             (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)
Example #20
    def test_unbind(self):
        test = self._vmap_view_test
        op = torch.unbind
        B0, B1, B2 = 7, 11, 13

        test(op, (torch.rand(B0, 2, 1024), -1), in_dims=(0, None))
        test(op, (torch.rand(B0, 2, 0),))
        test(op, (torch.rand(2, B0, 7), 0), in_dims=(1, None))
        test(vmap(op, in_dims=(0, None)), (torch.rand(B1, 1023, B0, 5), 1),
             in_dims=(2, None))
        test(vmap(vmap(lambda t: op(t, dim=1), in_dims=2)),
             (torch.rand(B1, 2, B0, 32, B2),), in_dims=2)
Example #21
    def test_none_in_dims(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        # None in_dim for a Tensor means we don't map over it
        output = vmap(torch.mul, (0, None))(x, y)
        self.assertEqual(output.shape, (2, 2, 3))
        self.assertEqual(output, x.view(2, 1, 3) * y)

        # None in_dim for non-tensor arguments
        output = vmap(torch.mul, (0, None))(x, 2)
        self.assertEqual(output, x * 2)
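A None in_dim leaves that argument unmapped, so the first call above is equivalent to looping over x along dim 0 and broadcasting each slice against the full y. A small sketch of that equivalence using torch.vmap (illustrative only, not part of the test suite):

import torch

x = torch.randn(2, 3)
y = torch.randn(2, 3)
batched = torch.vmap(torch.mul, (0, None))(x, y)
# Reference: each x[i] (shape (3,)) broadcasts against the full y (shape (2, 3)).
looped = torch.stack([x[i] * y for i in range(x.shape[0])])
assert torch.allclose(batched, looped)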
Example #22
    def test_view(self):
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        op = torch.Tensor.view

        # We should error out if the view would produce an incorrect result
        with self.assertRaises(RuntimeError):
            vmap(op, in_dims=(1, None))(torch.rand(2, B0, 5), [10])

        test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None))
        test(op, (torch.rand(B0, 4, 5), [1, 2, 1, 10]), in_dims=(0, None))
        test(vmap(lambda t: t.view([-1])), (torch.rand(B0, B1, 2, 5, 3),))
        test(vmap(vmap(lambda t: t.reshape([-1])), in_dims=1),
             (torch.rand(B2, B0, B1, 3, 2, 5),), in_dims=1)
Example #23
    def test_func_with_no_inputs(self):
        expected_msg = 'got no inputs'

        def foo():
            return torch.randn(3)

        def bar(x):
            return torch.randn(3)

        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(foo)()

        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(bar)()
Example #24
    def test_non_zero_in_dims(self):
        tensor = torch.randn(2, 3, 5)

        # Implicit out_dims = 0; vmap will move the batch dim to the front.
        output = vmap(lambda x: x, (1, ))(tensor)
        self.assertEqual(output, tensor.permute(1, 0, 2))
        self.assertEqual(output.data_ptr(), tensor.data_ptr())

        x = torch.randn(2, 3)
        y = torch.randn(3, 2)
        output = vmap(torch.mul, (0, 1))(x, y)
        self.assertEqual(output, x * y.t())
        output = vmap(torch.mul, (1, 0))(x, y)
        self.assertEqual(output, x.t() * y)
Example #25
    def test_reshape_as(self):
        test = self._vmap_test
        B0, B1, B2 = 7, 11, 13
        op = torch.Tensor.reshape_as
        test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)), check_view=True)
        test(op, (torch.rand(2 * 5), torch.rand(B0, 2, 5)), in_dims=(None, 0), check_view=True)
        test(op, (torch.rand(B0, 2 * 5), torch.rand(2, 5)), in_dims=(0, None), check_view=True)

        test(op, (torch.rand(2, B0, 5), torch.rand(1, 1, 10)), in_dims=(1, None), check_view=False)

        test(vmap(op), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)), check_view=True)
        test(vmap(vmap(op, in_dims=(2, None)), in_dims=(1, None)),
             (torch.rand(3, B1, 2, B2, 5, B0), torch.rand(B0, 3 * 2 * 5)),
             in_dims=(5, 0), check_view=False)
Example #26
    def test_nested_out_dims(self):
        y = torch.randn(2, 3, 5, 7)

        # Inner vmap has non-zero out_dim
        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y)
        self.assertEqual(result.shape, (2, 5, 3, 7))
        self.assertEqual(result, y.permute(0, 2, 1, 3))

        # all vmaps have non-zero out_dim
        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y),
                      out_dims=1)(y)
        self.assertEqual(result.shape, (5, 2, 3, 7))
        self.assertEqual(result, y.permute(2, 0, 1, 3))

        # throwing in some negative out_dims
        result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y),
                      out_dims=-1)(y)
        self.assertEqual(result.shape, (5, 7, 3, 2))
        self.assertEqual(result, y.permute(2, 3, 1, 0))

        # testing fn that isn't the identity
        x = torch.randn(2, 3)
        y = torch.randn(5, 3)
        result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x),
                      out_dims=-1)(y)
        self.assertEqual(result.shape, (3, 2, 5))
        self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0))
Example #27
    def _test_unary(self, op, getter, device):
        test = self._vmap_test
        B0, B1 = 7, 11

        self._assert_doesnt_use_vmap_fallback([op], [getter([B0], device)])

        # Single vmap, various in_dims / out_dims
        test(op, [getter([B0, 3], device)])
        test(op, [getter([2, 5, B0, 3], device)], in_dims=2)
        test(op, [getter([2, 5, B0, 3], device)], in_dims=2, out_dims=2)

        # Doubly nested vmap
        test(vmap(op), [getter([B0, B1], device)])
        test(vmap(op), [getter([B1, 2, 5, B0, 3], device)], in_dims=2)
        test(vmap(op, in_dims=2), [getter([2, 5, B0, B1, 3], device)],
             in_dims=2, out_dims=2)
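The getter argument is not defined in this snippet; in the test suite it is supplied by the caller. One plausible implementation, assuming it only needs to build a random floating-point tensor of a given shape on a given device (a hypothetical helper, shown for context):

import torch

def make_rand(shape, device):
    # Hypothetical getter: a random float tensor of the requested shape on `device`.
    return torch.rand(*shape, device=device)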
Example #28
def naive_vmap():
    I_N = torch.eye(num_classes)
    # torch._C._debug_only_display_vmap_fallback_warnings(True)
    L = []

    def get_jacobian(v):
        j = torch.autograd.grad(output[i, :],
                                model.parameters(),
                                v,
                                retain_graph=True)
        jac_persample = []
        for j_ in j:
            jac_persample.append(j_.view(-1))
        for name, param in model.named_parameters():
            param.grad = None
        return torch.cat(jac_persample, 0)

    for i in range(BATCH_SIZE):
        jacobian = torch.vmap(get_jacobian)(I_N)
        L.append(jacobian)

    jac = torch.cat(L, 0)
    jac = jac.reshape(BATCH_SIZE, num_classes, -1)
    jac = jac.permute(1, 0, 2)
    jac = jac.reshape(BATCH_SIZE * num_classes, -1)
    JJT = torch.matmul(jac, jac.permute(1, 0)) / BATCH_SIZE
    return JJT
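naive_vmap refers to several names (model, output, num_classes, BATCH_SIZE) that are assumed to exist in the enclosing scope. A hypothetical setup under which the function above would make sense (a sketch only; the actual model and data are not part of the snippet):

import torch
import torch.nn as nn

BATCH_SIZE, num_features, num_classes = 4, 10, 3
model = nn.Linear(num_features, num_classes)   # hypothetical classifier
x = torch.randn(BATCH_SIZE, num_features)
output = model(x)                              # shape (BATCH_SIZE, num_classes)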
Example #29
    def _assert_uses_vmap_fallback(self, vmap_args, inputs):
        with warnings.catch_warnings(record=True) as wa:
            result = vmap(*vmap_args)(*inputs)
            self.assertEqual(len(wa), 2)
            self.assertRegex(
                str(wa[-1].message),
                r'falling back to slow \(for loop and stack\) implementation')
Example #30
    def test_multiple_outputs(self):
        def foo(x):
            return x * x, x * x * x

        x = torch.randn(3)
        outputs = vmap(foo)(x)
        self.assertEqual(outputs[0], x * x)
        self.assertEqual(outputs[1], x * x * x)