def _vmap_test(self, op, inputs, in_dims=0, out_dims=0,
               check_view=False, check_propagates_grad=True):
    """Run ``op`` under vmap and compare the result with ``reference_vmap``.

    Args:
        op: the function to vmap.
        inputs: positional arguments for ``op``; ``inputs[0]`` is the tensor
            used by the view check and the requires_grad check.
        in_dims, out_dims: forwarded unchanged to ``vmap`` and ``reference_vmap``.
        check_view: if True, assert that every output shares its storage start
            with ``inputs[0]`` (i.e. is a view of the first input).
        check_propagates_grad: if True, re-run with a ``requires_grad_()`` clone
            of ``inputs[0]`` and assert that the zeroth output requires grad.
    """
    result = vmap(op, in_dims, out_dims)(*inputs)
    reference_result = reference_vmap(op, inputs, in_dims, out_dims)
    self.assertEqual(result, reference_result)
    op_has_single_return = not isinstance(result, tuple)

    if check_view:
        result_as_tuple = (result, ) if op_has_single_return else result
        for output in result_as_tuple:
            # data_ptr() minus the element offset (in bytes) is where the
            # output's storage begins. Comparing it with inputs[0].data_ptr()
            # presumes inputs[0] itself starts at storage offset 0.
            self.assertEqual(
                output.data_ptr() - output.storage_offset() * output.element_size(),
                inputs[0].data_ptr(),
                msg="result was not a view of the first input!")

    if not check_propagates_grad:
        return
    # Assuming input[0] is a floating-point tensor. Check if the vmap
    # operation propagates the requires_grad flag to the zeroth output.
    # Some vmap operators are implemented in a way that assumes that
    # they are composite with respect to autograd. If the operator ever is
    # changed to not be composite with respect to autograd, then the
    # following check should fail.
    inputs_clone = list(inputs)
    inputs_clone[0] = inputs[0].clone().requires_grad_()
    result = vmap(op, in_dims, out_dims)(*inputs_clone)
    result_as_tuple = (result, ) if op_has_single_return else result
    # Index the tuple of outputs, not ``result``. For a single-output op,
    # ``result[0]`` would be the first batch slice of the tensor (and would
    # raise for a 0-dim result) instead of the zeroth output.
    self.assertTrue(result_as_tuple[0].requires_grad)
def test_slice(self):
    # Basic slicing under single and nested vmap, checked with the
    # view-test helper.
    B0, B1, B2 = 7, 11, 13
    view_test = self._vmap_view_test

    # Single vmap: batch dim first, then batch dim last.
    view_test(lambda t: t[0:1], (torch.rand(B0, 3, 5),))
    view_test(lambda t: t[:, 1:3], (torch.rand(3, 5, B0),), in_dims=2)

    # Doubly and triply nested vmap, every level mapping over dim 2.
    doubly_nested = vmap(lambda t: t[:, 0:1], in_dims=2)
    view_test(doubly_nested, (torch.rand(3, 5, B0, B1),), in_dims=2)
    triply_nested = vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2)
    view_test(triply_nested, (torch.rand(3, 5, B0, B1, B2),), in_dims=2)
def test_unsupported_op_err_msg(self):
    # Calling aten::cos inside vmap is expected to raise a NYI RuntimeError.
    def apply_cos(t):
        return torch.cos(t)

    sample = torch.randn(3)
    expected_msg = 'NYI: Calling aten::cos inside of vmap'
    with self.assertRaisesRegex(RuntimeError, expected_msg):
        vmap(apply_cos)(sample)
def test_nested_with_same_map_dim(self):
    # Wrapping torch.mul in 2 and then 3 levels of vmap, all mapping the
    # leading dims of the same inputs, must equal plain elementwise mul.
    x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
    expected = x * y
    for depth in (2, 3):
        nested = torch.mul
        for _ in range(depth):
            nested = vmap(nested)
        self.assertEqual(nested(x, y), expected)
def test_t(self):
    # torch.t under vmap (plain, non-zero in_dims, nested), checked with the
    # view-test helper.
    B0, B1, B2 = 7, 11, 13
    view_test = self._vmap_view_test
    transpose = torch.t

    view_test(transpose, (torch.rand(B0, 2, 5),))
    view_test(transpose, (torch.rand(2, B0, 5),), in_dims=1)
    view_test(vmap(transpose), (torch.rand(B1, 2, B0, 5),), in_dims=2)
    view_test(vmap(vmap(transpose, in_dims=2)),
              (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)
def test_unsupported_inplace_op_err_msg(self):
    # NOTE(review): this file also defines another
    # `test_unsupported_inplace_op_err_msg` (expecting
    # 'Tried to call KernelFunction::call'). If both live in the same class,
    # the later definition silently replaces the earlier one -- rename or
    # delete one of them.
    def cos_inplace(t):
        return t.cos_()

    sample = torch.randn(3)
    with self.assertRaisesRegex(RuntimeError, 'Batching rule not implemented'):
        vmap(cos_inplace)(sample)
def test_select(self):
    # torch.select under vmap (plain, non-zero in_dims, nested), checked with
    # the view-test helper.
    B0, B1, B2 = 7, 11, 13
    view_test = self._vmap_view_test
    select = torch.select
    unmapped = (None, None)

    def select_1_1(t):
        return select(t, 1, 1)

    view_test(select, (torch.rand(B0, 2, 5), 0, 0), in_dims=(0,) + unmapped)
    view_test(select, (torch.rand(2, B0, 5), 1, 1), in_dims=(1,) + unmapped)
    view_test(vmap(select_1_1), (torch.rand(B1, 2, B0, 5),), in_dims=2)
    view_test(vmap(vmap(select_1_1, in_dims=1)),
              (torch.rand(B1, 2, B0, B2, 5),), in_dims=2)
def test_out_dims_edge_case(self):
    # For a single-output function, out_dims=(1,) must be accepted and
    # behave exactly like out_dims=1.
    def identity(t):
        return t

    tensor = torch.randn(2, 3)
    expected = vmap(identity, out_dims=1)(tensor)
    self.assertEqual(vmap(identity, out_dims=(1, ))(tensor), expected)
def test_reshape(self):
    # torch.reshape under vmap. check_view is True only for the cases whose
    # result is asserted to alias the first input.
    B0, B1, B2 = 7, 11, 13
    reshape = torch.reshape

    def flatten(t):
        return t.reshape([-1])

    cases = [
        (reshape, (torch.rand(B0, 2 * 5), [2, 5]),
         {'in_dims': (0, None), 'check_view': True}),
        (reshape, (torch.rand(2, B0, 5), [1, 1, 10]),
         {'in_dims': (1, None), 'check_view': False}),
        (vmap(flatten), (torch.rand(B0, B1, 2, 5),),
         {'check_view': True}),
        (vmap(vmap(flatten, in_dims=2), in_dims=1),
         (torch.rand(3, B1, 2, B2, 5, B0),),
         {'in_dims': 5, 'check_view': False}),
    ]
    for fn, args, kwargs in cases:
        self._vmap_test(fn, args, **kwargs)
def test_non_tensor_output_raises(self):
    # A Python scalar returned alone, or as one element of a tuple, must be
    # rejected with a ValueError naming its type.
    with self.assertRaisesRegex(ValueError, "got type <class 'float'> as the return"):
        vmap(lambda x: 3.14)(torch.ones(3))

    def tensor_and_int(x):
        return x, 3

    with self.assertRaisesRegex(ValueError, "got type <class 'int'> for return 1"):
        vmap(tensor_and_int)(torch.ones(3))
def test_out_dim_out_of_bounds_err_msg(self):
    # TODO(rzou): This error message isn't that great. It comes straight
    # from maybe_wrap_dim. Consider doing a try-catch-(add some context) to
    # the error message in the future in C++
    x = torch.randn(2, 3, 5)
    # The output is 3-D, so valid out_dims are -3..2; both values below are
    # just outside that range.
    for bad_out_dim in (3, -4):
        with self.assertRaisesRegex(IndexError, 'Dimension out of range'):
            vmap(lambda x: x, out_dims=bad_out_dim)(x)
def test_unsupported_inplace_op_err_msg(self):
    # NOTE(review): this file also defines another
    # `test_unsupported_inplace_op_err_msg` (expecting
    # 'Batching rule not implemented'). If both live in the same class, only
    # the later definition is kept -- rename or delete one of them.
    def cos_inplace(t):
        return t.cos_()

    sample = torch.randn(3)
    # TODO(rzou): Yeah, this error message is pretty bad because the
    # dispatcher's fallback mechanism doesn't work for ops that don't support
    # boxing. Fix the error message at some point.
    with self.assertRaisesRegex(RuntimeError, 'Tried to call KernelFunction::call'):
        vmap(cos_inplace)(sample)
def test_nested_with_different_map_dim(self):
    # Nested vmaps over independent inputs: each nesting level contributes
    # its own leading batch dimension to the result.
    a = torch.randn(2, 3)
    b = torch.randn(5, 3)
    out = vmap(lambda a_row: vmap(lambda b_row: a_row * b_row)(b))(a)
    self.assertEqual(out.shape, (2, 5, 3))
    self.assertEqual(out, a.view(2, 1, 3) * b)

    c = torch.randn(7, 3)
    out = vmap(lambda a_row: vmap(
        lambda b_row: vmap(lambda c_row: a_row * b_row * c_row)(c))(b))(a)
    self.assertEqual(out.shape, (2, 5, 7, 3))
    self.assertEqual(out, a.view(2, 1, 1, 3) * b.view(5, 1, 3) * c)
def test_unfold(self):
    # Tensor.unfold(dimension, size, step) under vmap, checked with the
    # view-test helper; the three trailing arguments are never mapped.
    view_test = self._vmap_view_test
    unfold = torch.Tensor.unfold
    B0, B1, B2 = 3, 2, 5
    unmapped = (None, None, None)

    view_test(unfold, (torch.rand(B0, 7, 11), 0, 2, 1), in_dims=(0,) + unmapped)
    view_test(unfold, (torch.rand(7, B0, 11), 1, 4, 2), in_dims=(1,) + unmapped)
    view_test(vmap(unfold, in_dims=(0,) + unmapped),
              (torch.rand(B1, 7, B0, 11), 1, 5, 1),
              in_dims=(2,) + unmapped)
    view_test(vmap(vmap(unfold, in_dims=(2,) + unmapped), in_dims=(0,) + unmapped),
              (torch.rand(B1, 7, B0, 11, B2), -1, 2, 4),
              in_dims=(2,) + unmapped)
def test_narrow(self):
    # torch.narrow(input, dim, start, length) under vmap, checked with the
    # view-test helper (includes a zero-length narrow).
    view_test = self._vmap_view_test
    narrow = torch.narrow
    B0, B1, B2 = 7, 11, 13
    unmapped = (None, None, None)

    view_test(narrow, (torch.rand(B0, 2, 5), -1, 1, 3), in_dims=(0,) + unmapped)
    view_test(narrow, (torch.rand(2, B0, 5), 1, 1, 3), in_dims=(1,) + unmapped)
    view_test(vmap(narrow, in_dims=(0,) + unmapped),
              (torch.rand(B1, 2, B0, 5), 1, 0, 0),
              in_dims=(2,) + unmapped)
    view_test(vmap(vmap(narrow, in_dims=(2,) + unmapped), in_dims=(0,) + unmapped),
              (torch.rand(B1, 2, B0, 5, B2), -1, 2, 3),
              in_dims=(2,) + unmapped)
def test_expand_as(self):
    # Tensor.expand_as under vmap with either or both operands batched, and
    # nested; checked with the view-test helper.
    B0, B1, B2 = 7, 11, 13
    view_test = self._vmap_view_test
    expand_as = torch.Tensor.expand_as

    # Single vmap.
    view_test(expand_as, (torch.rand(B0, 1, 5), torch.rand(B0, 2, 3, 5)))
    view_test(expand_as, (torch.rand(B0, 1, 5), torch.rand(2, 3, 5)),
              in_dims=(0, None))
    view_test(expand_as, (torch.rand(1, 5), torch.rand(B0, 2, 3, 5)),
              in_dims=(None, 0))

    # Doubly nested vmap.
    view_test(vmap(expand_as),
              (torch.rand(B0, B1, 1, 5), torch.rand(B0, B1, 2, 3, 5)))
    view_test(vmap(expand_as),
              (torch.rand(B0, B1, 1, 5), torch.rand(B1, B0, 2, 3, 5)),
              in_dims=(0, 1))
    view_test(vmap(expand_as),
              (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)),
              in_dims=(0, None))

    # Triply nested vmap.
    view_test(vmap(vmap(expand_as)),
              (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5)))
def test_diagonal(self):
    # torch.diagonal(input, offset, dim1, dim2) under vmap, all cases sharing
    # one 5-D tensor, with assorted in_dims/out_dims and nesting.
    x = torch.randn(3, 5, 7, 11, 13)
    view_test = self._vmap_view_test
    diag = torch.diagonal
    unmapped = (None, None, None)

    view_test(diag, (x, 1, 0, 1), in_dims=(0,) + unmapped)
    view_test(diag, (x, 0, 2, -1), in_dims=(0,) + unmapped)
    view_test(diag, (x, 2, 1, 2), in_dims=(1,) + unmapped)
    view_test(diag, (x, 0, -2, -1), in_dims=(1,) + unmapped, out_dims=1)
    view_test(vmap(lambda t: diag(t, 0, 0, -1)), (x,), in_dims=1, out_dims=1)
    view_test(vmap(vmap(lambda t: diag(t, 0, 0, 1), in_dims=1), in_dims=3),
              (x,), in_dims=1, out_dims=1)
def test_T_numpy(self):
    # The numpy-style ``.T`` property under vmap, including a 1-D input and
    # nested vmaps over non-leading dims.
    def numpy_T(t):
        return t.T

    view_test = self._vmap_view_test
    B0, B1, B2 = 7, 11, 13

    for shape in [(B0, 2, 3, 5), (B0,)]:
        view_test(numpy_T, (torch.rand(*shape),))
    view_test(numpy_T, (torch.rand(2, B0, 3, 5),), in_dims=1)
    for shape in [(B1, 2, B0, 5), (B1, 2, B0, 3, 5)]:
        view_test(vmap(numpy_T), (torch.rand(*shape),), in_dims=2)
    view_test(vmap(vmap(numpy_T, in_dims=2)),
              (torch.rand(B1, 2, B0, 3, B2, 5),), in_dims=2)
def test_chunk(self):
    # torch.chunk (a multi-output op) under single and nested vmap, checked
    # with the view-test helper.
    test = self._vmap_view_test
    op = torch.chunk
    B0, B1, B2 = 7, 11, 13

    # tests for torch.chunk(input, chunks: int, dim). Chunk counts are chosen
    # so that several cases do not divide the chunked dim evenly.
    test(op, (torch.rand(B0, 2, 1024), 15, -1), in_dims=(0, None, None))
    test(op, (torch.rand(2, B0, 1024), 9, 1), in_dims=(1, None, None))
    test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 4, 0),
         in_dims=(2, None, None))
    test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
         (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)
def test_unbind(self):
    # torch.unbind (a multi-output op) under vmap, including a zero-size
    # trailing dim, a non-zero in_dim, and nested vmaps.
    B0, B1, B2 = 7, 11, 13
    view_test = self._vmap_view_test
    unbind = torch.unbind

    view_test(unbind, (torch.rand(B0, 2, 1024), -1), in_dims=(0, None))
    view_test(unbind, (torch.rand(B0, 2, 0),))
    view_test(unbind, (torch.rand(2, B0, 7), 0), in_dims=(1, None))
    view_test(vmap(unbind, in_dims=(0, None)),
              (torch.rand(B1, 1023, B0, 5), 1), in_dims=(2, None))

    def unbind_dim1(t):
        return unbind(t, dim=1)

    view_test(vmap(vmap(unbind_dim1, in_dims=2)),
              (torch.rand(B1, 2, B0, 32, B2),), in_dims=2)
def test_none_in_dims(self):
    lhs = torch.randn(2, 3)
    rhs = torch.randn(2, 3)

    # A None in_dim on a Tensor means it is not mapped: every batch entry of
    # lhs is multiplied by the whole of rhs.
    out = vmap(torch.mul, (0, None))(lhs, rhs)
    self.assertEqual(out.shape, (2, 2, 3))
    self.assertEqual(out, lhs.view(2, 1, 3) * rhs)

    # None is also the in_dim for non-Tensor arguments.
    out = vmap(torch.mul, (0, None))(lhs, 2)
    self.assertEqual(out, lhs * 2)
def test_view(self):
    # Tensor.view under vmap: an impossible view must raise, and valid views
    # (including nested ones) are checked with the view-test helper.
    test = self._vmap_view_test
    B0, B1, B2 = 7, 11, 13
    op = torch.Tensor.view

    # We should error out if the view would produce an incorrect result.
    # Mapping over dim 1 leaves each per-sample (2, 5) slice non-contiguous,
    # so flattening it to [10] cannot be expressed as a view.
    with self.assertRaises(RuntimeError):
        vmap(op, in_dims=(1, None))(torch.rand(2, B0, 5), [10])

    test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None))
    test(op, (torch.rand(B0, 4, 5), [1, 2, 1, 10]), in_dims=(0, None))
    test(vmap(lambda t: t.view([-1])), (torch.rand(B0, B1, 2, 5, 3),))
    # Triply nested, with batch dims laid out in memory as (B2, B0, B1). This
    # previously called ``reshape``, which meant ``view`` itself was never
    # exercised with out-of-order batch dims; ``reshape`` has its own tests.
    test(vmap(vmap(lambda t: t.view([-1])), in_dims=1),
         (torch.rand(B2, B0, B1, 3, 2, 5),), in_dims=1)
def test_func_with_no_inputs(self):
    # Calling a vmapped function with zero arguments must raise, whether or
    # not the wrapped function declares parameters.
    def takes_nothing():
        return torch.randn(3)

    def takes_one(x):
        return torch.randn(3)

    for fn in (takes_nothing, takes_one):
        with self.assertRaisesRegex(ValueError, 'got no inputs'):
            vmap(fn)()
def test_non_zero_in_dims(self):
    src = torch.randn(2, 3, 5)
    # out_dims defaults to 0, so the mapped dim (1) ends up in front. The
    # data_ptr check asserts that this happens without copying.
    moved = vmap(lambda t: t, (1, ))(src)
    self.assertEqual(moved, src.permute(1, 0, 2))
    self.assertEqual(moved.data_ptr(), src.data_ptr())

    # Mixed in_dims across two inputs.
    a = torch.randn(2, 3)
    b = torch.randn(3, 2)
    self.assertEqual(vmap(torch.mul, (0, 1))(a, b), a * b.t())
    self.assertEqual(vmap(torch.mul, (1, 0))(a, b), a.t() * b)
def test_reshape_as(self):
    # Tensor.reshape_as under vmap with either or both operands batched.
    # check_view is True only where the result is asserted to alias the
    # first input.
    B0, B1, B2 = 7, 11, 13
    reshape_as = torch.Tensor.reshape_as
    cases = [
        (reshape_as, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)),
         {'check_view': True}),
        (reshape_as, (torch.rand(2 * 5), torch.rand(B0, 2, 5)),
         {'in_dims': (None, 0), 'check_view': True}),
        (reshape_as, (torch.rand(B0, 2 * 5), torch.rand(2, 5)),
         {'in_dims': (0, None), 'check_view': True}),
        (reshape_as, (torch.rand(2, B0, 5), torch.rand(1, 1, 10)),
         {'in_dims': (1, None), 'check_view': False}),
        (vmap(reshape_as), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)),
         {'check_view': True}),
        (vmap(vmap(reshape_as, in_dims=(2, None)), in_dims=(1, None)),
         (torch.rand(3, B1, 2, B2, 5, B0), torch.rand(B0, 3 * 2 * 5)),
         {'in_dims': (5, 0), 'check_view': False}),
    ]
    for fn, args, kwargs in cases:
        self._vmap_test(fn, args, **kwargs)
def test_nested_out_dims(self):
    def nested_identity(inner_out, outer_out):
        # Two nested vmaps over the identity, with configurable out_dims.
        return vmap(lambda t: vmap(lambda x: x, out_dims=inner_out)(t),
                    out_dims=outer_out)

    base = torch.randn(2, 3, 5, 7)

    # Only the inner vmap has a non-zero out_dim.
    result = nested_identity(1, 0)(base)
    self.assertEqual(result.shape, (2, 5, 3, 7))
    self.assertEqual(result, base.permute(0, 2, 1, 3))

    # Both vmaps have non-zero out_dims.
    result = nested_identity(1, 1)(base)
    self.assertEqual(result.shape, (5, 2, 3, 7))
    self.assertEqual(result, base.permute(2, 0, 1, 3))

    # Negative out_dims.
    result = nested_identity(-1, -1)(base)
    self.assertEqual(result.shape, (5, 7, 3, 2))
    self.assertEqual(result, base.permute(2, 3, 1, 0))

    # A function that is not the identity.
    inner_src = torch.randn(2, 3)
    outer_src = torch.randn(5, 3)
    result = vmap(
        lambda outer_row: vmap(lambda inner_row: inner_row * outer_row,
                               out_dims=1)(inner_src),
        out_dims=-1)(outer_src)
    self.assertEqual(result.shape, (3, 2, 5))
    self.assertEqual(result, (outer_src.view(5, 1, 3) * inner_src).permute(2, 1, 0))
def _test_unary(self, op, getter, device):
    """Exercise unary ``op`` under single and doubly-nested vmap.

    ``getter(shape, device)`` is called to build each input; presumably it
    returns a tensor of ``shape`` on ``device`` (TODO confirm at call sites).
    """
    B0, B1 = 7, 11
    check = self._vmap_test

    # Presumably asserts that op has a dedicated batching rule rather than
    # going through the slow vmap fallback (helper not visible here).
    self._assert_doesnt_use_vmap_fallback([op], [getter([B0], device)])

    # Single vmap, various in_dims / out_dims.
    single_cases = [
        ([B0, 3], {}),
        ([2, 5, B0, 3], {'in_dims': 2}),
        ([2, 5, B0, 3], {'in_dims': 2, 'out_dims': 2}),
    ]
    for shape, kwargs in single_cases:
        check(op, [getter(shape, device)], **kwargs)

    # Doubly nested vmap.
    check(vmap(op), [getter([B0, B1], device)])
    check(vmap(op), [getter([B1, 2, 5, B0, 3], device)], in_dims=2)
    check(vmap(op, in_dims=2), [getter([2, 5, B0, B1, 3], device)],
          in_dims=2, out_dims=2)
def naive_vmap():
    """Return J @ J^T / BATCH_SIZE, computed one sample at a time.

    J stacks, for every sample and class, the gradient of ``output[sample, :]``
    (contracted with a one-hot row of the identity) w.r.t. all model
    parameters, flattened. vmap is applied only over the class one-hot vectors.
    Relies on ``num_classes``, ``output``, ``model`` and ``BATCH_SIZE`` from
    the enclosing scope (not visible here).
    """
    one_hots = torch.eye(num_classes)
    per_sample = []

    def get_jacobian(v):
        # ``sample_idx`` is read from the loop below at call time. vmap calls
        # this synchronously inside that loop, so it sees the current sample.
        grads = torch.autograd.grad(output[sample_idx, :], model.parameters(), v,
                                    retain_graph=True)
        flat_grads = [g.view(-1) for g in grads]
        for _name, param in model.named_parameters():
            param.grad = None
        return torch.cat(flat_grads, 0)

    for sample_idx in range(BATCH_SIZE):
        # One row per class: shape (num_classes, total parameter elements).
        per_sample.append(torch.vmap(get_jacobian)(one_hots))

    # Sample-major rows -> class-major rows, then form the Gram matrix.
    stacked = torch.cat(per_sample, 0)
    stacked = stacked.reshape(BATCH_SIZE, num_classes, -1).permute(1, 0, 2)
    stacked = stacked.reshape(BATCH_SIZE * num_classes, -1)
    return torch.matmul(stacked, stacked.permute(1, 0)) / BATCH_SIZE
def _assert_uses_vmap_fallback(self, vmap_args, inputs):
    """Assert that ``vmap(*vmap_args)(*inputs)`` takes the slow fallback path.

    Exactly two warnings are expected; the last one must be the
    "falling back to slow (for loop and stack) implementation" warning.
    """
    with warnings.catch_warnings(record=True) as wa:
        # catch_warnings(record=True) keeps the currently active filters.
        # A "default"/"once" or "ignore" filter (e.g. configured by the test
        # runner) could keep warnings out of ``wa`` and make the exact count
        # below flaky, so force every warning to be recorded.
        warnings.simplefilter('always')
        vmap(*vmap_args)(*inputs)
        self.assertEqual(len(wa), 2)
        self.assertRegex(
            str(wa[-1].message),
            r'falling back to slow \(for loop and stack\) implementation')
def test_multiple_outputs(self):
    # A function returning a tuple of tensors yields a tuple of batched
    # results, one per output.
    def square_and_cube(t):
        return t * t, t * t * t

    x = torch.randn(3)
    square, cube = vmap(square_and_cube)(x)
    self.assertEqual(square, x * x)
    self.assertEqual(cube, x * x * x)