def test_cpu_fallback(self):
    with enable_torch_dispatch_mode(
            FakeTensorMode(inner=None, allow_fallback_kernels=False)):
        filters = torch.randn(8, 4, 3, 3).cuda()
        inputs = torch.randn(1, 4, 5, 5).cuda()
        out = torch.nn.functional.conv2d(inputs, filters, padding=1)
        self.assertEqual(out.device.type, "cuda")
        self.assertEqual(list(out.size()), [1, 8, 5, 5])

    with enable_torch_dispatch_mode(
            FakeTensorMode(inner=None, allow_fallback_kernels=True)):
        # intentionally bad inputs
        filters = torch.randn(8, 20, 3, 3).cuda()
        inputs = torch.randn(1, 7, 10, 5).cuda()
        with self.assertRaises(RuntimeError):
            torch.nn.functional.conv2d(inputs, filters, padding=1)

    with enable_torch_dispatch_mode(
            FakeTensorMode(inner=None, allow_fallback_kernels=True)):
        filters = torch.randn(8, 4, 3, 3).cuda()
        inputs = torch.randn(1, 4, 5, 5).cuda()
        out = torch.nn.functional.conv2d(inputs, filters, padding=1)
        self.assertEqual(out.device.type, "cuda")
        self.assertEqual(list(out.size()), [1, 8, 5, 5])
def test_nested_enable_torch_dispatch_mode(self) -> None:
    class A(LoggingTensorMode):
        pass

    with self.assertRaisesRegex(ValueError, "there is already an active mode"):
        with enable_torch_dispatch_mode(LoggingTensorMode):
            with enable_torch_dispatch_mode(A):
                pass
def test_nesting_with_same_enable_torch_dispatch_mode(self) -> None:
    # "nested" enable_torch_dispatch_modes are allowed if they're the same mode.
    # It's the equivalent of a noop, so it will only write once to the log.
    with capture_logs() as logs:
        x = LoggingTensor(torch.tensor([3.]))
        log_input("x", x)
        with enable_torch_dispatch_mode(LoggingTensor):
            with enable_torch_dispatch_mode(LoggingTensor):
                x + x

    self.assertExpectedInline('\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.aten.add.Tensor($0, $0)''')
def test_schema_info_bind_basic(self):
    class SchemaInfoBindTestMode(TorchDispatchMode):
        def __init__(self, test_self):
            self.test_self = test_self

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            named_arg_list = normalize_function(
                func,
                args,
                kwargs,
                normalize_to_only_use_kwargs=True
            ).kwargs
            schema_info_value_test = torch._C._SchemaInfo(func._schema)
            schema_info_values_test = torch._C._SchemaInfo(func._schema)
            self.test_self.assertFalse(schema_info_value_test.may_alias(
                torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
                torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))
            self.test_self.assertFalse(schema_info_values_test.may_alias(
                torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
                torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))
            for i in named_arg_list:
                schema_info_value_test.add_argument_value(i, named_arg_list[i])
            schema_info_values_test.add_argument_values(named_arg_list)
            self.test_self.assertTrue(schema_info_value_test.may_alias(
                torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
                torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))
            self.test_self.assertTrue(schema_info_values_test.may_alias(
                torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
                torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)))

            return func(*args, **kwargs)

    x = torch.rand((3, 3))
    schemaInfoCheck = SchemaInfoBindTestMode(self)
    with enable_torch_dispatch_mode(schemaInfoCheck):
        x.add(x)
def test_schema_check_mode_functionality_training_op(self):
    x = torch.rand((3, 3), requires_grad=True)
    batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
    expected = batch(x)
    with enable_torch_dispatch_mode(SchemaCheckMode()):
        actual = batch(x)
    self.assertEqual(expected, actual)
def test_schema_check_mode_mutated_aliasing_resize_(self):
    actual = torch.rand((3, 3), requires_grad=False)
    schema_check = SchemaCheckMode()
    with enable_torch_dispatch_mode(schema_check):
        actual.resize_(9)
    self.assertEqual([('aten::resize_', 'input')], schema_check.mutated)
    self.assertEqual([('aten::resize_', 'input', 'output_0')], schema_check.aliasing)
def test_schema_check_mode_operator_order_without_grad(self):
    schema_check = SchemaCheckMode()
    with enable_torch_dispatch_mode(schema_check):
        x = torch.rand((3, 3), requires_grad=False)
        x.relu().sin()
    self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops)
def test_alias_check_fail_outputs_unexpectedly_aliasing(self):
    with self.assertRaisesRegex(RuntimeError, "Outputs 0 and 1 alias unexpectedly"):
        x = torch.rand((3, 3))
        s = SchemaCheckMode()
        with enable_torch_dispatch_mode(s):
            IncorrectAliasTensor(x).aminmax(dim=0)
def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(self):
    x = torch.rand((3, 3))
    actual = torch.zeros(3)
    with enable_torch_dispatch_mode(SchemaCheckMode()):
        torch.aminmax(x, dim=0, out=[actual, actual])
    self.assertEqual(torch.amax(x, dim=0), actual)
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    # need to handle here to avoid infinite recursion
    # see [in_kernel_invocation]
    if func == torch.ops.prim.device.default:
        assert len(args) == 1 and isinstance(args[0], FakeTensor)
        if args[0].fake_mode.in_kernel_invocation:
            return torch.device("meta")
        else:
            return args[0].fake_device

    # Because fake mode can return NotImplemented (if it sees a subclass
    # it doesn't know how to deal with), this test here is important
    # because the next dispatch after a fake mode will attempt to use
    # subclasses of tensors to dispatch, and any FakeTensor arguments
    # will be considered eligible.
    if any(not issubclass(t, FakeTensor) and t is not torch.Tensor for t in types):
        return NotImplemented

    fake_mode = None
    for arg in itertools.chain(tree_flatten(args)[0], tree_flatten(kwargs)[0]):
        if isinstance(arg, FakeTensor):
            if fake_mode is None:
                fake_mode = arg.fake_mode
            else:
                assert fake_mode is arg.fake_mode, "Mixing modes NYI"

    with enable_torch_dispatch_mode(fake_mode):
        return func(*args, **kwargs)
def test_schema_check_mode_mutated_aliasing_none(self):
    x = torch.rand((3, 3), requires_grad=True)
    schema_check = SchemaCheckMode()
    with enable_torch_dispatch_mode(schema_check):
        actual = x.relu().sin()
    self.assertEqual([], schema_check.mutated)
    self.assertEqual([], schema_check.aliasing)
def test_enable_torch_dispatch_mode_respects_no_dispatch(self) -> None:
    with enable_torch_dispatch_mode(LoggingTensorMode):
        z = torch.ones([2, 3])
        self.assertTrue(isinstance(z, LoggingTensorMode))
        with no_dispatch():
            expected = torch.ones([2, 3])
            self.assertEqual(z.elem, expected)
def test_type_as(self):
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        x = torch.rand([16, 1], device="cpu")
        y = torch.rand([4, 4], device="cuda")
        out = x.type_as(y)
        self.assertEqual(out.device.type, "cuda")
        self.assertTrue(isinstance(out, FakeTensor))
def test_schema_check_tensor_functionality_mutable_inputs(self):
    expected = torch.rand((3, 3), requires_grad=False)
    actual = torch.clone(expected)
    expected.sinh_()
    with enable_torch_dispatch_mode(SchemaCheckMode()):
        actual.sinh_()
    self.assertEqual(expected, actual)
def test_mode(self):
    x = FakeTensor.from_tensor(torch.rand([1]))
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        y = torch.rand([4], device="cpu")
        out = x + y
        self.assertTrue(isinstance(y, FakeTensor))
def test_new(self):
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        a = torch.rand([16, 1])
        self.checkType(a.new(10, 10), "cpu", [10, 10])
        self.checkType(a.new([1, 2, 3, 4]), "cpu", [4])
        b = torch.rand([4, 4], device='cuda')
        self.checkType(b.new(device='cuda'), "cuda", [0])
def test_non_kwarg_device(self):
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        x = torch.rand([16, 1], device="cpu")
        y = x.to(torch.device("cpu"))
        self.assertIs(x, y)
        z = x.to(torch.device("cuda"))
        self.assertEqual(z.device.type, "cuda")
def test_data_dependent_operator(self):
    with enable_torch_dispatch_mode(
            FakeTensorMode(inner=None, allow_fallback_kernels=False)):
        x = torch.rand([10, 10])
        self.assertRaises(DynamicOutputShapeException, lambda: torch.nonzero(x))
def test_binary_op_type_promotion(self):
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        x = torch.empty([2, 2], dtype=torch.float)
        y = torch.empty([2, 2], dtype=torch.int64)
        out = x / y
        self.assertEqual(out.dtype, torch.float)
        self.assertEqual(out.device.type, "cpu")
def test_schema_check_tensor_functionality_kwarg_tensor(self):
    x = torch.rand((3, 5))
    w = torch.rand((4))
    expected = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
    with enable_torch_dispatch_mode(SchemaCheckMode()):
        actual = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
    self.assertEqual(expected, actual)
def test_enable_torch_dispatch_mode_error(self) -> None:
    z = LoggingTensor(torch.empty([]))
    with self.assertRaisesRegex(
            ValueError,
            "expected to get TorchDispatchMode, Tensor-like class, or None"):
        with enable_torch_dispatch_mode(z):
            pass
def test_throw(self):
    mode = FakeTensorMode(inner=None)
    x = torch.tensor(0.)  # TODO: tensor() errors
    with enable_torch_dispatch_mode(mode):
        x_conv = mode.from_tensor(x)
        y = torch.rand([4, 4], device="cuda")
        z = torch.rand([4, 4], device="cpu")
        self.assertRaises(Exception, lambda: torch.lerp(x_conv, y, z))
def test_shape_take_not_device(self):
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        x = torch.empty(1, device="cpu")
        y = torch.empty(8, 8, device="cuda")
        # resize_as_ should take y's shape but keep x's device
        out = x.resize_as_(y)
        self.assertEqual(out.shape, (8, 8))
        self.assertEqual(out.device.type, "cpu")
        self.assertTrue(isinstance(out, FakeTensor))
def test_randperm(self):
    x = torch.randperm(10)
    y = torch.randperm(5, device="cpu")
    with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
        x1 = torch.randperm(10)
        prims.utils.compare_tensor_meta(x, x1)
        y1 = torch.randperm(5, device="cpu")
        prims.utils.compare_tensor_meta(y, y1)
def test_schema_check_tensor_functionality_list_input(self):
    a = torch.rand((3, 3))
    b = torch.rand((3, 3))
    c = torch.rand((3, 3))
    expected = torch.linalg.multi_dot([a, b, c])
    with enable_torch_dispatch_mode(SchemaCheckMode()):
        actual = torch.linalg.multi_dot([a, b, c])
    self.assertEqual(expected, actual)
def test_enable_torch_dispatch_mode_replace(self):
    class A(torch.Tensor):
        @staticmethod
        def __new__(cls, elem):
            return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            return cls(torch.zeros(()))

    class B(A):
        pass

    with enable_torch_dispatch_mode(A):
        with enable_torch_dispatch_mode(B, replace=A):
            self.assertTrue(isinstance(torch.zeros(()), B))
def maybe_disable_fake_tensor_mode():
    # TODO: figure out if this API generally makes sense and bake it into the
    # library
    mb_fake_mode = torch._C._get_torch_dispatch_mode()
    if isinstance(mb_fake_mode, FakeTensorMode):
        return enable_torch_dispatch_mode(mb_fake_mode.inner, replace=mb_fake_mode)
    else:
        return nullcontext()
def test_mutation_check_fail(self):
    with self.assertRaisesRegex(
            RuntimeError,
            "Argument input is not defined as mutable but was mutated"):
        x = torch.rand((3, 3))
        y = torch.rand((3, 3))
        with enable_torch_dispatch_mode(SchemaCheckMode()):
            IncorrectAliasTensor(x).sub(IncorrectAliasTensor(y))
def test_schema_correctness(self, device, dtype, op):
    # Currently torch.equal isn't supported with torch.complex32.
    # There are also errors with complex64 and complex128.
    if dtype == torch.complex32:
        return
    for sample in op.sample_inputs(device, dtype, requires_grad=False):
        with enable_torch_dispatch_mode(SchemaCheckMode()):
            op(sample.input, *sample.args, **sample.kwargs)
def test_enable_torch_dispatch_mode_subclass_priority(self) -> None:
    class ErrorA(RuntimeError):
        pass

    class ErrorB(RuntimeError):
        pass

    class A(torch.Tensor):
        @staticmethod
        def __new__(cls, elem):
            return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            raise ErrorA

    class B(A):
        @staticmethod
        def __new__(cls, elem):
            return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            raise ErrorB

    a = A(torch.empty(1))
    b = B(torch.empty(1))
    with self.assertRaises(ErrorA):
        a + a
    with self.assertRaises(ErrorB):
        a + b

    # B has precedence over A due to the subclass relationship yet
    # modes take precedence over arguments
    with self.assertRaises(ErrorA):
        with enable_torch_dispatch_mode(A):
            b + b
    with self.assertRaises(ErrorB):
        with enable_torch_dispatch_mode(B):
            a + a
    with self.assertRaises(ErrorB):
        with enable_torch_dispatch_mode(B):
            a + b