示例#1
0
    def test_lower_linear(self):
        """Guard remove_mutation's handling of ``F.linear`` against regressions.

        Linear is one of the main use cases for removing mutation: after
        inlining, peephole and constant propagation the scripted graph contains
        an in-place ``aten::add_``, which remove_mutation must eliminate without
        changing the numerical result.
        """
        @torch.jit.script
        def foo(x):
            return F.linear(x, torch.randn(20, 20), torch.randn(20))

        self.run_pass('inline', foo.graph)
        self.run_pass('peephole', foo.graph)
        self.run_pass('constant_propagation', foo.graph)
        FileCheck().check("aten::add_").run(foo.graph)

        # Named ``inp`` rather than ``input`` to avoid shadowing the builtin.
        inp = torch.randn(20, 20)
        with freeze_rng_state():
            out1 = foo(inp)

        self.run_pass('remove_mutation', foo.graph)
        FileCheck().check_not("aten::add_").run(foo.graph)

        # The first call to ``foo`` above made its executor cache an optimized
        # copy of the graph, so calling ``foo`` again would still run the
        # pre-pass graph. Build a fresh function from the rewritten graph so
        # the comparison actually exercises remove_mutation's output.
        lowered = torch._C._create_function_from_graph("forward", foo.graph)
        with freeze_rng_state():
            out2 = lowered(inp)
        self.assertEqual(out1, out2)
示例#2
0
    def test_special_mapped_op(self):
        """Check remove_mutation on specially mapped in-place ops.

        ``zero_`` and ``fill_`` with a scalar value, as well as ``normal_``,
        must be gone from the graph after the pass, and the scripted results
        must still match eager execution. ``fill_`` with a tensor fill value
        has to be left in place.
        """
        def script_and_strip(py_fn):
            # Script ``py_fn``, run remove_mutation on its graph in place, and
            # return both the scripted function and that graph.
            scripted = torch.jit.script(py_fn)
            g = scripted.graph
            self.run_pass('remove_mutation', g)
            return scripted, g

        # Scalar zero_/fill_: both in-place ops should be removed by the pass.
        def zero_and_fill():
            a = torch.tensor([2, 2])
            b = torch.tensor([2, 4])
            a.zero_()
            b.fill_(3)
            return a, b

        scripted_ok, g_ok = script_and_strip(zero_and_fill)
        FileCheck().check_not("aten::zero_").check_not("aten::fill_").run(
            g_ok)
        self.assertEqual(zero_and_fill(), scripted_ok())

        # full_like is not implemented for a tensor fill value, so this
        # fill_ must survive the pass. Only the graph is inspected here.
        def fill_with_tensor():
            a = torch.tensor([2, 2])
            b = torch.tensor([2, 4])
            a.fill_(b)
            return a + a

        _, g_tensor_fill = script_and_strip(fill_with_tensor)
        FileCheck().check('aten::fill_').run(g_tensor_fill)

        # normal_ should be removed as well; under an identical RNG state the
        # eager and scripted samples must agree.
        def sample_normal():
            return torch.rand(2, 1, 3, 4).normal_()

        scripted_normal, g_normal = script_and_strip(sample_normal)
        FileCheck().check_not("normal_").run(g_normal)
        with freeze_rng_state():
            eager_out = sample_normal()
        with freeze_rng_state():
            script_out = scripted_normal()
        self.assertEqual(eager_out, script_out)