def test_lower_linear(self):
    # linear is one of the main use cases of removing mutation, so add a test so it doesn't regress
    @torch.jit.script
    def foo(x):
        return F.linear(x, torch.randn(20, 20), torch.randn(20))

    self.run_pass('inline', foo.graph)
    self.run_pass('peephole', foo.graph)
    self.run_pass('constant_propagation', foo.graph)
    FileCheck().check("aten::add_").run(foo.graph)
    input = torch.randn(20, 20)
    with freeze_rng_state():
        out1 = foo(input)
    self.run_pass('remove_mutation', foo.graph)
    FileCheck().check_not("aten::add_").run(foo.graph)
    with freeze_rng_state():
        out2 = foo(input)
    self.assertEqual(out1, out2)
def test_special_mapped_op(self):
    def test_successful():
        x = torch.tensor([2, 2])
        y = torch.tensor([2, 4])
        x.zero_()
        y.fill_(3)
        return x, y

    fn = torch.jit.script(test_successful)
    graph = fn.graph
    self.run_pass('remove_mutation', graph)
    FileCheck().check_not("aten::zero_").check_not("aten::fill_").run(graph)
    self.assertEqual(test_successful(), fn())

    # full_like is not implemented for a tensor fill value
    def test_unsuccessful():
        x = torch.tensor([2, 2])
        y = torch.tensor([2, 4])
        x.fill_(y)
        return x + x

    fn = torch.jit.script(test_unsuccessful)
    graph = fn.graph
    self.run_pass('remove_mutation', graph)
    FileCheck().check('aten::fill_').run(graph)

    def normal():
        return torch.rand(2, 1, 3, 4).normal_()

    fn = torch.jit.script(normal)
    graph = fn.graph
    self.run_pass('remove_mutation', graph)
    FileCheck().check_not("normal_").run(graph)

    with freeze_rng_state():
        out_eager = normal()
    with freeze_rng_state():
        out_script = fn()
    self.assertEqual(out_eager, out_script)