Пример #1
0
 def test_parameter_instantiation(self):
     """A tensor created under FakeTensorMode can be wrapped in nn.Parameter."""
     fake_mode = FakeTensorMode(inner=None)
     with enable_torch_dispatch_mode(fake_mode):
         param = torch.nn.parameter.Parameter(torch.rand([4]))
         self.assertTrue(isinstance(param, torch.nn.Parameter))
Пример #2
0
 def test_memoized_conversion_to_meta(self):
     """from_tensor returns the identical fake object for the same real tensor."""
     x = torch.rand(2, 2, 2)
     mode = FakeTensorMode(inner=None)
     # assertIs reports both objects on failure, unlike assertTrue(a is b).
     self.assertIs(mode.from_tensor(x), mode.from_tensor(x))
Пример #3
0
 def test_normalize_device(self):
     """Adding a "cuda" tensor to a "cuda:<current index>" tensor works under fake mode."""
     with FakeTensorMode():
         implicit = torch.empty(1, device="cuda")
         explicit = torch.empty(1, device=f"cuda:{torch.cuda.current_device()}")
         result = implicit + explicit
     self.checkType(result, "cuda", [1])
Пример #4
0
    def test_cudnn_rnn(self):
        """Run aten._cudnn_rnn once eagerly and once under FakeTensorMode.

        In both runs every output must be on CUDA, and output index 4 must be
        the very tensor passed as the fourth positional argument to
        ``_cudnn_rnn`` (``inps[-3]``). In the fake run every output must also
        be a FakeTensor.
        """

        def fn(a0, *rest):
            # ``rest`` holds the 16 flat weight tensors followed by a3, a4, a5.
            weight_list = list(rest[:16])
            a3, a4, a5 = rest[16:]
            # Positional meanings follow the aten::_cudnn_rnn schema; presumably
            # weight_stride0=4, mode=2, hidden_size=2048, proj_size=0,
            # num_layers=2, batch_first=False, dropout=0.0, train=False,
            # bidirectional=True -- confirm against native_functions.yaml.
            return torch.ops.aten._cudnn_rnn(
                a0, weight_list, 4, a3, a4, a5,
                2, 2048, 0, 2, False, 0.0, False, True, [], None,
            )

        # The 16 weight tensors come in four groups of four: the first two
        # groups share one shape pattern, the last two share another.
        vec = [8192]
        first_group = [[8192, 2048], [8192, 2048], vec, vec]
        second_group = [[8192, 4096], [8192, 2048], vec, vec]
        shapes = (
            [[92, 8, 2048]]
            + first_group * 2
            + second_group * 2
            + [[167837696], [4, 8, 2048], [4, 8, 2048]]
        )

        mode = FakeTensorMode(inner=None)
        # Each entry: (context-manager factory, whether outputs must be fake).
        runs = (
            (contextlib.nullcontext, False),
            (lambda: enable_torch_dispatch_mode(mode), True),
        )
        for make_context, expect_fake in runs:
            with make_context():
                inps = tuple(torch.randn(shape).cuda() for shape in shapes)
                out = fn(*inps)
                self.assertIs(out[4], inps[-3])
                for ten in out:
                    if expect_fake:
                        self.assertTrue(isinstance(ten, FakeTensor))
                    self.assertEqual(ten.device.type, 'cuda')
Пример #5
0
    def test_fake_mode_error(self):
        """Indexing a real tensor inside FakeTensorMode raises a non-Fake input error."""
        real = torch.rand([4, 4])

        with self.assertRaisesRegex(Exception, "non-Fake Tensor inputs"):
            with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
                _ = real[0]
Пример #6
0
 def test_from_numpy(self):
     """torch.tensor on a NumPy array inside fake mode passes checkType("cpu", [4, 4])."""
     with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
         zeros = np.zeros([4, 4])
         converted = torch.tensor(zeros)
         self.checkType(converted, "cpu", [4, 4])
Пример #7
0
    def test_constructor(self):
        """A factory call inside FakeTensorMode returns a FakeTensor on the requested device."""
        with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
            x = torch.rand([4, 4], device="cpu")

        # Dedicated assertions report the actual type / device on failure,
        # unlike assertTrue(...) which only says "False is not true".
        self.assertIsInstance(x, FakeTensor)
        self.assertEqual(x.device.type, "cpu")
Пример #8
0
    def test_mode(self):
        """Arithmetic on a tensor created inside FakeTensorMode yields a FakeTensor."""
        with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
            y = torch.rand([4], device="cpu")
            out = y + y

        # assertIsInstance names the actual type on failure.
        self.assertIsInstance(out, FakeTensor)
Пример #9
0
 def test_separate_mode_error(self):
     """Combining fake tensors owned by two different FakeTensorModes raises."""
     with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
         x = torch.empty(2, 2, device="cpu")
     with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
         y = torch.empty(2, 2, device="cpu")
     # The mixed-mode op must happen inside the callable. Passing `y` as an
     # extra assertRaises argument would invoke a zero-argument lambda with
     # one argument and "pass" on the resulting TypeError without ever
     # combining the tensors.
     self.assertRaises(Exception, lambda: x + y)
Пример #10
0
 def test_setitem(self):
     """Slice assignment on a fake tensor does not raise, on both CPU and CUDA."""
     for dev in ("cpu", "cuda"):
         with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
             t = torch.rand([16, 1], device=dev)
             t[..., 0] = 0
Пример #11
0
 def test_memoized_conversion_from_meta(self):
     """The fake tensor converter returns the identical object for repeated meta inputs."""
     x = torch.rand(2, 2).to(device="meta")
     mode = FakeTensorMode(inner=None)
     converter = mode.fake_tensor_converter
     # assertIs reports both objects on failure, unlike assertTrue(a is b).
     self.assertIs(converter(mode, x, "cpu"), converter(mode, x, "cpu"))
Пример #12
0
 def test_data_dependent_operator(self):
     """torch.nonzero raises DynamicOutputShapeException when CPU fallback is disabled."""
     fake_mode = FakeTensorMode(inner=None, allow_cpu_fallback=False)
     with enable_torch_dispatch_mode(fake_mode):
         x = torch.rand([10, 10])
         with self.assertRaises(DynamicOutputShapeException):
             torch.nonzero(x)
Пример #13
0
 def test_index_cuda_with_cpu(self):
     """Indexing a fake CUDA tensor with a default-device int64 index gives a CUDA result."""
     with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
         source = torch.rand([2048], device='cuda')
         index = torch.zeros([36], dtype=torch.int64)
         gathered = source[index]
         self.checkType(gathered, "cuda", [36])
Пример #14
0
 def test_new(self):
     """Tensor.new under fake mode, checked for sizes, data and the device kwarg."""
     with enable_torch_dispatch_mode(FakeTensorMode(inner=None)):
         base = torch.rand([16, 1])
         # Each case is built lazily so every `new` call is immediately
         # followed by its check, in the same order as before.
         for build, device, shape in (
             (lambda: base.new(10, 10), "cpu", [10, 10]),
             (lambda: base.new([1, 2, 3, 4]), "cpu", [4]),
             (lambda: base.new(device='cuda'), "cuda", [0]),
         ):
             self.checkType(build(), device, shape)
Пример #15
0
    def wrapped(*args):
        """Trace ``f`` on ``args`` via ``dispatch_trace`` and return the result.

        ``tracing_mode`` (from the enclosing scope) selects how inputs are
        wrapped before tracing:

        * ``"real"``: inputs are passed through unchanged.
        * ``"fake"``: tensors are converted with ``FakeTensorMode.from_tensor``.
        * ``"symbolic"``: tensors are replaced by FakeTensors over meta tensors
          whose shapes come from ``ShapeEnv.create_shapes_for_args``.

        The returned object gets the ``ShapeEnv`` attached as ``shape_env``.

        Raises:
            AssertionError: if ``tracing_mode`` is not one of the three modes.
        """
        # Placeholders are computed from the caller's original args.
        phs = pytree.tree_map(lambda _: fx.PH, args)  # type: ignore[attr-defined]
        fx_tracer = PythonKeyTracer()
        fake_tensor_mode: Any
        if tracing_mode == "real":
            fake_tensor_mode = nullcontext()
        elif tracing_mode == "fake":
            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True)
        elif tracing_mode == "symbolic":
            fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=False)
        else:
            raise AssertionError(f"Unexpected tracing type: {tracing_mode}")

        proxy_mode = ProxyTorchDispatchMode(fx_tracer)

        def wrap_fake_concrete(x):
            # Convert real tensors into fakes owned by ``fake_tensor_mode``.
            if isinstance(x, torch.Tensor):
                return fake_tensor_mode.from_tensor(x)  # type: ignore[attr-defined]
            return x

        shape_env = ShapeEnv()
        sym_mode = proxy_mode.sym_mode

        # todo: Figure out a more informative name for symints
        def wrap_fake_symbolic(x, sym_shape):
            # Replace a tensor with a FakeTensor backed by a meta tensor of
            # ``sym_shape``, keeping the input's requires_grad and device.
            if isinstance(x, torch.Tensor):
                return FakeTensor(
                    fake_tensor_mode,
                    torch.empty(sym_shape,
                                device="meta",
                                requires_grad=x.requires_grad),
                    x.device,
                )
            return x

        wrap_fn_map = {
            "real": lambda x: x,
            "fake": wrap_fake_concrete,
        }
        if tracing_mode == "symbolic":
            flat_shapes = shape_env.create_shapes_for_args(args)
            flat_args, spec = pytree.tree_flatten(args)
            args = pytree.tree_unflatten(
                [wrap_fake_symbolic(a, s) for a, s in zip(flat_args, flat_shapes)],
                spec,
            )
        else:
            args = pytree.tree_map(wrap_fn_map[tracing_mode], args)

        # FX doesn't support varargs, so we gotta fake up a wrapper
        # TODO: Would be nice to fix this at the source...
        # If the unwrapped callable has no __code__, fall back to the wrapper
        # too, instead of raising AttributeError on the attribute access.
        code = (getattr(inspect.unwrap(f), "__code__", None)
                if hasattr(f, "__code__") else None)
        if code is None or code.co_flags & inspect.CO_VARARGS:
            func = fake_signature(f, len(phs))
        else:
            func = f

        with decompose(decomposition_table), fake_tensor_mode, sym_mode, proxy_mode:  # type: ignore[attr-defined]
            t = dispatch_trace(wrap_key(func, args, fx_tracer),
                               tracer=fx_tracer,
                               concrete_args=tuple(phs))

        # TODO: kind of a bad way to do it, should maybe figure out a better way
        t.shape_env = shape_env  # type: ignore[assignment]
        return t