Exemple #1
0
    def test_cpu_fallback(self):
        """conv2d on fake CUDA tensors, with and without fallback kernels.

        Well-formed inputs must yield a fake CUDA output of shape
        [1, 8, 5, 5] in both configurations. Inputs whose channel counts
        disagree (the filters expect 20, the input supplies 7) must still
        raise RuntimeError when fallback kernels are allowed.
        """
        def fake_mode(allow_fallback):
            # A fresh FakeTensorMode is built on each call, as in each
            # separate `with` statement.
            return enable_torch_dispatch_mode(
                FakeTensorMode(inner=None, allow_fallback_kernels=allow_fallback))

        def check_valid_conv():
            weight = torch.randn(8, 4, 3, 3).cuda()
            image = torch.randn(1, 4, 5, 5).cuda()
            result = torch.nn.functional.conv2d(image, weight, padding=1)
            self.assertEqual(result.device.type, "cuda")
            self.assertEqual(list(result.size()), [1, 8, 5, 5])

        with fake_mode(False):
            check_valid_conv()

        with fake_mode(True):
            # Intentionally mismatched inputs.
            weight = torch.randn(8, 20, 3, 3).cuda()
            image = torch.randn(1, 7, 10, 5).cuda()
            with self.assertRaises(RuntimeError):
                torch.nn.functional.conv2d(image, weight, padding=1)

        with fake_mode(True):
            check_valid_conv()
    def test_nested_enable_torch_dispatch_mode(self) -> None:
        """Enabling a second, different mode while one is active is rejected."""
        class A(LoggingTensorMode):
            pass

        outer, inner = LoggingTensorMode, A
        expected_msg = "there is already an active mode"
        with self.assertRaisesRegex(ValueError, expected_msg):
            with enable_torch_dispatch_mode(outer):
                with enable_torch_dispatch_mode(inner):
                    pass
    def test_nesting_with_same_enable_torch_dispatch_mode(self) -> None:
        # "nested" enable_torch_dispatch_modes are allowed if they're the same mode. It's the equivalent of
        # a noop, so it will only write once to the log
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.]))
            log_input("x", x)
            with enable_torch_dispatch_mode(LoggingTensor):
                with enable_torch_dispatch_mode(LoggingTensor):
                    x + x

        self.assertExpectedInline(
            '\n'.join(logs), '''\
$0 = input('x')
$1 = torch._ops.aten.add.Tensor($0, $0)''')
Exemple #4
0
    def test_schema_info_bind_basic(self):
        """Binding argument values lets SchemaInfo detect runtime aliasing.

        ``x.add(x)`` passes the same tensor as inputs 0 and 1.

        - Before any values are bound, ``may_alias`` is expected to report
          no aliasing between them.
        - After binding, it is expected to report that they alias. Binding
          is done one argument at a time with ``add_argument_value`` on one
          SchemaInfo, and all at once with ``add_argument_values`` on another.
        """
        class SchemaInfoBindTestMode(TorchDispatchMode):
            def __init__(self, test_self):
                self.test_self = test_self

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                # kwargs defaults to None, and ``func(*args, **kwargs)`` below
                # would raise TypeError on None, so normalize it first.
                kwargs = {} if kwargs is None else kwargs
                named_arg_list = normalize_function(
                    func,
                    args,
                    kwargs,
                    normalize_to_only_use_kwargs=True
                ).kwargs
                schema_info_value_test = torch._C._SchemaInfo(func._schema)
                schema_info_values_test = torch._C._SchemaInfo(func._schema)
                both_infos = (schema_info_value_test, schema_info_values_test)

                first_input = torch._C._SchemaArgument(
                    torch._C._SchemaArgType.input, 0)
                second_input = torch._C._SchemaArgument(
                    torch._C._SchemaArgType.input, 1)

                # Without bound values only the schema is consulted.
                for info in both_infos:
                    self.test_self.assertFalse(
                        info.may_alias(first_input, second_input))

                # Bind values: individually on one SchemaInfo, in bulk on the other.
                for name in named_arg_list:
                    schema_info_value_test.add_argument_value(
                        name, named_arg_list[name])
                schema_info_values_test.add_argument_values(named_arg_list)

                # With the real values bound, both must see the aliasing.
                for info in both_infos:
                    self.test_self.assertTrue(
                        info.may_alias(first_input, second_input))

                return func(*args, **kwargs)

        x = torch.rand((3, 3))
        schemaInfoCheck = SchemaInfoBindTestMode(self)
        with enable_torch_dispatch_mode(schemaInfoCheck):
            x.add(x)
Exemple #5
0
 def test_schema_check_mode_functionality_training_op(self):
     """A BatchNorm1d forward gives the same output under SchemaCheckMode."""
     inp = torch.rand((3, 3), requires_grad=True)
     norm = torch.nn.BatchNorm1d(3, track_running_stats=True)
     reference = norm(inp)
     with enable_torch_dispatch_mode(SchemaCheckMode()):
         checked = norm(inp)
     self.assertEqual(reference, checked)
Exemple #6
0
 def test_schema_check_mode_mutated_aliasing_resize_(self):
     """resize_ is recorded as mutating its input and aliasing it to output 0."""
     t = torch.rand((3, 3), requires_grad=False)
     mode = SchemaCheckMode()
     with enable_torch_dispatch_mode(mode):
         t.resize_(9)
     expected_mutated = [('aten::resize_', 'input')]
     expected_aliasing = [('aten::resize_', 'input', 'output_0')]
     self.assertEqual(expected_mutated, mode.mutated)
     self.assertEqual(expected_aliasing, mode.aliasing)
Exemple #7
0
 def test_schema_check_mode_operator_order_without_grad(self):
     """SchemaCheckMode records ops in call order, including the factory call."""
     mode = SchemaCheckMode()
     with enable_torch_dispatch_mode(mode):
         torch.rand((3, 3), requires_grad=False).relu().sin()
     self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], mode.ops)
Exemple #8
0
 def test_alias_check_fail_outputs_unexpectedly_aliasing(self):
     """SchemaCheckMode reports aminmax outputs that alias unexpectedly.

     ``IncorrectAliasTensor`` (defined elsewhere) is presumably built to make
     the two outputs alias. Only the dispatched call is wrapped in
     assertRaisesRegex, so a RuntimeError during setup cannot satisfy the
     assertion by accident.
     """
     x = torch.rand((3, 3))
     s = SchemaCheckMode()
     with self.assertRaisesRegex(RuntimeError,
                                 "Outputs 0 and 1 alias unexpectedly"):
         with enable_torch_dispatch_mode(s):
             IncorrectAliasTensor(x).aminmax(dim=0)
Exemple #9
0
 def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(
         self):
     """aminmax may write both outputs into one ``out`` tensor under SchemaCheckMode.

     Both results land in the same buffer; the test expects it to end up
     holding the max.
     """
     src = torch.rand((3, 3))
     shared_out = torch.zeros(3)
     with enable_torch_dispatch_mode(SchemaCheckMode()):
         torch.aminmax(src, dim=0, out=[shared_out, shared_out])
     expected = torch.amax(src, dim=0)
     self.assertEqual(expected, shared_out)
Exemple #10
0
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Re-dispatch ``func`` under the FakeTensorMode of its FakeTensor args.

        Args:
            func: the operator overload being dispatched.
            types: tensor types participating in this dispatch.
            args: positional arguments for ``func``.
            kwargs: keyword arguments for ``func``; ``None`` means none.

        Returns:
            - For ``prim.device``: the device to report.
            - ``NotImplemented``: if a tensor type other than FakeTensor (or
              a subclass) or plain ``torch.Tensor`` participates.
            - Otherwise: the result of ``func`` run with the FakeTensor
              arguments' fake mode enabled.

        Raises:
            AssertionError: if FakeTensor arguments come from different modes.
        """
        # ``kwargs`` defaults to None; normalize it so ``func(**kwargs)``
        # below cannot raise TypeError when no keyword arguments are given.
        if kwargs is None:
            kwargs = {}

        # need to handle here to avoid infinite recursion
        # see [in_kernel_invocation]
        if func == torch.ops.prim.device.default:
            assert len(args) == 1 and isinstance(args[0], FakeTensor)
            if args[0].fake_mode.in_kernel_invocation:
                # Report "meta" while the mode flags an in-kernel invocation.
                return torch.device("meta")
            else:
                return args[0].fake_device

        # Because fake mode can return NotImplemented (if it sees a subclass
        # it doesn't know how to deal with), this test here is important
        # because the next dispatch after a fake mode will attempt to use
        # subclasses of tensors to dispatch, and any FakeTensor arguments
        # will be considered eligible.
        if any(not issubclass(t, FakeTensor) and t is not torch.Tensor
               for t in types):
            return NotImplemented

        # Find the (single) fake mode shared by all FakeTensor arguments.
        # If none are present, fake_mode stays None and is passed as-is.
        fake_mode = None
        for arg in itertools.chain(
                tree_flatten(args)[0],
                tree_flatten(kwargs)[0]):
            if isinstance(arg, FakeTensor):
                if fake_mode is None:
                    fake_mode = arg.fake_mode
                else:
                    assert fake_mode is arg.fake_mode, "Mixing modes NYI"

        with enable_torch_dispatch_mode(fake_mode):
            return func(*args, **kwargs)
Exemple #11
0
 def test_schema_check_mode_mutated_aliasing_none(self):
     """The test expects relu followed by sin to record no mutation or aliasing."""
     inp = torch.rand((3, 3), requires_grad=True)
     mode = SchemaCheckMode()
     with enable_torch_dispatch_mode(mode):
         inp.relu().sin()
     for recorded in (mode.mutated, mode.aliasing):
         self.assertEqual([], recorded)
 def test_enable_torch_dispatch_mode_respects_no_dispatch(self) -> None:
     """no_dispatch() bypasses an active tensor-class mode.

     Under ``enable_torch_dispatch_mode(LoggingTensor)`` a factory call must
     produce a LoggingTensor. Inside ``no_dispatch()`` the same call is
     expected to bypass the mode, so its result must equal ``z.elem``.
     """
     with enable_torch_dispatch_mode(LoggingTensor):
         z = torch.ones([2, 3])
         # Check against the tensor class used as the mode. A tensor is
         # never an instance of a TorchDispatchMode such as
         # LoggingTensorMode, and ``z.elem`` below relies on z being a
         # LoggingTensor.
         self.assertTrue(isinstance(z, LoggingTensor))
         with no_dispatch():
             expected = torch.ones([2, 3])
             self.assertEqual(z.elem, expected)
Exemple #13
0
 def test_type_as(self):
     """type_as onto a CUDA fake tensor yields a CUDA FakeTensor."""
     with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
         cpu_t = torch.rand([16, 1], device="cpu")
         cuda_t = torch.rand([4, 4], device="cuda")
         converted = cpu_t.type_as(cuda_t)
         self.assertEqual(converted.device.type, "cuda")
         self.assertTrue(isinstance(converted, FakeTensor))
Exemple #14
0
 def test_schema_check_tensor_functionality_mutable_inputs(self):
     """In-place sinh_ gives the same result with and without SchemaCheckMode."""
     reference = torch.rand((3, 3), requires_grad=False)
     checked = torch.clone(reference)
     reference.sinh_()
     with enable_torch_dispatch_mode(SchemaCheckMode()):
         checked.sinh_()
     self.assertEqual(reference, checked)
Exemple #15
0
    def test_mode(self):
        """Factory calls under FakeTensorMode produce FakeTensors.

        The add mixes a FakeTensor created outside the mode with one created
        inside it; it must complete without raising.
        """
        x = FakeTensor.from_tensor(torch.rand([1]))
        mode = FakeTensorMode(inner=None)
        with enable_torch_dispatch_mode(mode):
            y = torch.rand([4], device="cpu")
            x + y

        self.assertTrue(isinstance(y, FakeTensor))
Exemple #16
0
 def test_new(self):
     """Tensor.new under FakeTensorMode keeps the source tensor's device."""
     with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
         cpu_src = torch.rand([16, 1])
         self.checkType(cpu_src.new(10, 10), "cpu", [10, 10])
         self.checkType(cpu_src.new([1, 2, 3, 4]), "cpu", [4])
         cuda_src = torch.rand([4, 4], device='cuda')
         # With neither sizes nor data, the test expects a size-[0] tensor.
         self.checkType(cuda_src.new(device='cuda'), "cuda", [0])
Exemple #17
0
 def test_non_kwarg_device(self):
     """Tensor.to with a positional device: same device is a no-op, else it moves."""
     with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
         src = torch.rand([16, 1], device="cpu")
         self.assertIs(src, src.to(torch.device("cpu")))
         moved = src.to(torch.device("cuda"))
         self.assertEqual(moved.device.type, "cuda")
Exemple #18
0
    def test_data_dependent_operator(self):
        """nonzero raises DynamicOutputShapeException without fallback kernels."""
        mode = FakeTensorMode(inner=None, allow_fallback_kernels=False)
        with enable_torch_dispatch_mode(mode):
            x = torch.rand([10, 10])

            with self.assertRaises(DynamicOutputShapeException):
                torch.nonzero(x)
Exemple #19
0
 def test_binary_op_type_promotion(self):
     """float / int64 promotes to float and stays on CPU under FakeTensorMode."""
     with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
         numerator = torch.empty([2, 2], dtype=torch.float)
         denominator = torch.empty([2, 2], dtype=torch.int64)
         quotient = numerator / denominator
         self.assertEqual(quotient.dtype, torch.float)
         self.assertEqual(quotient.device.type, "cpu")
Exemple #20
0
 def test_schema_check_tensor_functionality_kwarg_tensor(self):
     """stft with a tensor passed by keyword (window=) matches under SchemaCheckMode."""
     signal = torch.rand((3, 5))
     window = torch.rand(4)

     def run_stft():
         return torch.stft(signal, 4, win_length=4, window=window,
                           return_complex=True)

     expected = run_stft()
     with enable_torch_dispatch_mode(SchemaCheckMode()):
         actual = run_stft()
     self.assertEqual(expected, actual)
 def test_enable_torch_dispatch_mode_error(self) -> None:
     """Passing a tensor *instance* rather than a mode or class raises ValueError."""
     bad_arg = LoggingTensor(torch.empty([]))
     expected_msg = "expected to get TorchDispatchMode, Tensor-like class, or None"
     with self.assertRaisesRegex(ValueError, expected_msg):
         with enable_torch_dispatch_mode(bad_arg):
             pass
Exemple #22
0
 def test_throw(self):
     """lerp over fake tensors on mismatched devices (cuda vs cpu) must raise."""
     mode = FakeTensorMode(inner=None)
     x = torch.tensor(0.)  # TODO: tensor() errors
     with enable_torch_dispatch_mode(mode):
         fake_x = mode.from_tensor(x)
         cuda_t = torch.rand([4, 4], device="cuda")
         cpu_t = torch.rand([4, 4], device="cpu")
         with self.assertRaises(Exception):
             torch.lerp(fake_x, cuda_t, cpu_t)
Exemple #23
0
 def test_shape_take_not_device(self):
     """resize_as_ takes the other tensor's shape but keeps its own device."""
     with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
         target = torch.empty(1, device="cpu")
         template = torch.empty(8, 8, device="cuda")
         resized = target.resize_as_(template)
         self.assertEqual(resized.shape, (8, 8))
         self.assertEqual(resized.device.type, "cpu")
         self.assertTrue(isinstance(resized, FakeTensor))
Exemple #24
0
 def test_randperm(self):
     """randperm under FakeTensorMode matches real randperm's tensor metadata."""
     real_default = torch.randperm(10)
     real_cpu = torch.randperm(5, device="cpu")
     with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
         prims.utils.compare_tensor_meta(real_default, torch.randperm(10))
         prims.utils.compare_tensor_meta(real_cpu,
                                         torch.randperm(5, device="cpu"))
Exemple #25
0
 def test_schema_check_tensor_functionality_list_input(self):
     """multi_dot over a list of tensors matches under SchemaCheckMode."""
     mats = [torch.rand((3, 3)) for _ in range(3)]
     expected = torch.linalg.multi_dot(mats)
     with enable_torch_dispatch_mode(SchemaCheckMode()):
         actual = torch.linalg.multi_dot(mats)
     self.assertEqual(expected, actual)
    def test_enable_torch_dispatch_mode_replace(self):
        """``replace=`` swaps the active tensor-class mode A for B."""
        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                # Every op returns a fresh scalar of the dispatching class.
                return cls(torch.zeros(()))

        class B(A):
            pass

        with enable_torch_dispatch_mode(A):
            with enable_torch_dispatch_mode(B, replace=A):
                result = torch.zeros(())
                self.assertTrue(isinstance(result, B))
Exemple #27
0
def maybe_disable_fake_tensor_mode():
    """Return a context manager that suspends an active FakeTensorMode.

    If the current torch dispatch mode is a FakeTensorMode, the returned
    context enables that mode's ``inner`` in its place (via ``replace=``).
    Otherwise it returns a no-op ``nullcontext()``.
    """
    # TODO: figure out if this API generally makes sense and bake it into the
    # library
    current = torch._C._get_torch_dispatch_mode()
    if not isinstance(current, FakeTensorMode):
        return nullcontext()
    return enable_torch_dispatch_mode(current.inner, replace=current)
Exemple #28
0
 def test_mutation_check_fail(self):
     """SchemaCheckMode rejects mutation of an argument not declared mutable.

     ``IncorrectAliasTensor`` (defined elsewhere) is presumably built to
     mutate ``input`` during ``sub``. Only the dispatched call is wrapped in
     assertRaisesRegex, so a RuntimeError during setup cannot satisfy the
     assertion by accident.
     """
     x = torch.rand((3, 3))
     y = torch.rand((3, 3))
     with self.assertRaisesRegex(
             RuntimeError,
             "Argument input is not defined as mutable but was mutated"):
         with enable_torch_dispatch_mode(SchemaCheckMode()):
             IncorrectAliasTensor(x).sub(IncorrectAliasTensor(y))
 def test_schema_correctness(self, device, dtype, op):
     """Every sample input of ``op`` runs under a fresh SchemaCheckMode."""
     # torch.equal isn't currently supported with torch.complex32, so that
     # dtype is skipped. complex64 and complex128 also have known errors but
     # are not skipped here.
     if dtype != torch.complex32:
         samples = op.sample_inputs(device, dtype, requires_grad=False)
         for s in samples:
             with enable_torch_dispatch_mode(SchemaCheckMode()):
                 op(s.input, *s.args, **s.kwargs)
    def test_enable_torch_dispatch_mode_subclass_priority(self) -> None:
        """An enabled mode takes precedence over argument subclass priority.

        With no mode, dispatch goes to the more-derived argument class
        (B beats A). With a tensor-class mode enabled, that class handles
        the op whatever the argument types are.
        """
        class ErrorA(RuntimeError):
            pass

        class ErrorB(RuntimeError):
            pass

        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorA

        class B(A):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorB

        a = A(torch.empty(1))
        b = B(torch.empty(1))

        # No mode: the more-derived argument type wins.
        with self.assertRaises(ErrorA):
            a + a
        with self.assertRaises(ErrorB):
            a + b

        # B has precedence over A due to the subclass relationship, yet
        # modes take precedence over arguments.
        with self.assertRaises(ErrorA), enable_torch_dispatch_mode(A):
            b + b
        with self.assertRaises(ErrorB), enable_torch_dispatch_mode(B):
            a + a
        with self.assertRaises(ErrorB), enable_torch_dispatch_mode(B):
            a + b