def checkScriptRaisesRegex(self, script, inputs, exception, regex,
                           outputs=None, capture_output=False,
                           profiling=ProfilingMode.PROFILING):
    """
    Checks that a given function will throw the correct exception
    when executed with normal Python, the string frontend,
    and the AST frontend.
    """
    with enable_profiling_mode_for_profiling_tests():
        # normal python
        with self.assertRaisesRegex(exception, regex):
            script(*inputs)
        # string frontend
        with self.assertRaisesRegex(exception, regex):
            source = textwrap.dedent(inspect.getsource(script))
            cu = torch.jit.CompilationUnit(source)
            ge = getattr(cu, script.__name__)
            # profiling run
            with self.assertRaisesRegex(exception, regex):
                ge(*inputs)
            # optimized run
            ge(*inputs)
        # python AST frontend
        with self.assertRaisesRegex(exception, regex):
            ge = torch.jit.script(script)
            # profiling run
            with self.assertRaisesRegex(exception, regex):
                ge(*inputs)
            # optimized run
            ge(*inputs)
def test_clamp(self):
    def func2(a, b):
        return torch.clamp(a + b, min=0, max=2)

    def funcInf(a, b):
        return torch.clamp(a + b, min=0, max=float('inf'))

    def funcNegInf(a, b):
        return torch.clamp(a + b, min=float('-inf'), max=0)

    def funcOptMin(a, b):
        return torch.clamp(a + b, max=2)

    def funcOptMax(a, b):
        return torch.clamp(a + b, min=0)

    a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
    b = torch.randn(4, 4, dtype=torch.float, device='cuda')
    nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda')

    funcs = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax)
    for f, inputs in product(funcs, [[a, b], [a, nan]]):
        inp1, inp2 = inputs
        s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
        self.assertAllFused(s.graph_for(inp1, inp2),
                            except_for={'aten::size', 'aten::_size_if_not_equal'})
        c = s(inp1, inp2)
        with enable_profiling_mode_for_profiling_tests():
            warmup_backward(c.sum())
        graph = backward_graph(s)
        self.assertAllFused(graph,
                            except_for={'aten::Float', 'aten::_grad_sum_to_size'})
def test_bias_as_arg(self):
    with enable_profiling_mode_for_profiling_tests():
        def method1(x, weight, bias: Optional[torch.Tensor]):
            return torch.nn.functional.linear(x, weight, bias).relu() + 2

        N = 10
        x = torch.rand(N, N, requires_grad=True)
        weight = torch.rand(N, N, requires_grad=True)
        bias = None
        scripted = self.checkScript(method1, (x, weight, bias))
        # check_types requires last_graph on scripted to be set, so we just skip it
        check_against_reference(self, scripted, method1, lambda x: x,
                                (x, weight, bias), check_types=False)
        bias = torch.rand(N, N, requires_grad=True)
        scripted = self.checkScript(method1, (x, weight, bias))
        # check_types requires last_graph on scripted to be set, so we just skip it
        check_against_reference(self, scripted, method1, lambda x: x,
                                (x, weight, bias), check_types=False)
def _perform_ad_subgraph_slicing(self, fn, *input_sizes):
    with disable_autodiff_subgraph_inlining():
        with enable_profiling_mode_for_profiling_tests():
            ge = torch.jit.script(fn)
            inputs = [torch.randn(size, requires_grad=True) for size in input_sizes]
            ge(*inputs, profile_and_replay=True)
            return ge.graph_for(*inputs)
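# A hedged usage sketch for _perform_ad_subgraph_slicing (this test case and
# its sizes are hypothetical, not part of the suite above): script a simple
# function, run it under profiling with autodiff subgraph inlining disabled,
# then inspect the resulting graph for differentiable subgraphs.
def _example_ad_subgraph_slicing_usage(self):
    def fn(x, y, z):
        return x * y + z

    # each size becomes a torch.randn(size, requires_grad=True) input
    graph = self._perform_ad_subgraph_slicing(fn, [2, 2], [2, 2], [2, 2])
    # with inlining disabled, the differentiable ops are expected to be
    # grouped into a single prim::DifferentiableGraph node
    self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)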
def test_diff_graph_inline_threshold(self):
    with enable_profiling_mode_for_profiling_tests():
        NUM_RUNS = 1
        with num_profiled_runs(NUM_RUNS):
            @torch.jit.script
            def foo(x):
                # two nodes should be fused
                # see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/runtime/graph_executor_impl.h#L49
                return torch.sigmoid(torch.sigmoid(x))

            @torch.jit.script
            def bar(x):
                # a single node should NOT be fused (below the inline threshold)
                return torch.sigmoid(x)

            input = torch.rand([4, 4], requires_grad=True)
            foo(input)
            foo(input)

            bar(input)
            bar(input)

            print(foo.graph_for(input))
            self.assertGraphContainsExactly(foo.graph_for(input), 'prim::DifferentiableGraph', 1)
            self.assertGraphContainsExactly(bar.graph_for(input), 'prim::DifferentiableGraph', 0)
def test_differentiable_graph_ops_requires_grad(self):
    x = torch.randn(8, 2, dtype=torch.float).requires_grad_()
    y = torch.randn(8, 2, dtype=torch.float)

    def t(x: torch.Tensor, y: torch.Tensor):
        o = x + 1.0
        o1 = torch.relu(o)
        o = y + 1.5
        o2 = torch.relu(o)
        o3 = o1 + o2
        return o1, o2, o3

    with enable_profiling_mode_for_profiling_tests():
        t_jit = torch.jit.script(t)
        jit_o = t_jit(x, y)
        jit_o = t_jit(x, y)
        o = t(x, y)

        FileCheck().check("prim::DifferentiableGraph").run(t_jit.graph_for(x, y))
        # validate that the DifferentiableGraph ops mark requires_grad properly
        for oo, jit_oo in zip(o, jit_o):
            self.assertEqual(oo.requires_grad, jit_oo.requires_grad)
            self.assertEqual(oo, jit_oo)

        # one more run to trigger fusion
        jit_o = t_jit(x, y)
        for oo, jit_oo in zip(o, jit_o):
            self.assertEqual(oo.dtype, jit_oo.dtype)
            self.assertEqual(oo.requires_grad, jit_oo.requires_grad)
            self.assertEqual(oo, jit_oo)
def checkScriptRaisesRegex(self, script, inputs, exception, regex,
                           name=None, outputs=None, capture_output=False,
                           frames_up=1, profiling=ProfilingMode.PROFILING):
    """
    Checks that a given function will throw the correct exception
    when executed with normal Python, the string frontend,
    and the AST frontend. Logic taken from `checkScript`
    (see comments there for details)
    """
    with enable_profiling_mode_for_profiling_tests():
        # Normal Python
        with self.assertRaisesRegex(exception, regex):
            if isinstance(script, str):
                frame = self.get_frame_vars(frames_up)
                the_locals: Dict[str, Any] = {}
                execWrapper(script, glob=frame, loc=the_locals)
                frame.update(the_locals)
                python_fn = frame[name]
            else:
                python_fn = script
            python_fn(*inputs)

        # String frontend
        with self.assertRaisesRegex(exception, regex):
            if isinstance(script, str):
                cu = torch.jit.CompilationUnit(script, _frames_up=frames_up)
                string_frontend = getattr(cu, name)
            else:
                source = textwrap.dedent(inspect.getsource(script))
                cu = torch.jit.CompilationUnit(source, _frames_up=frames_up)
                string_frontend = getattr(cu, script.__name__)

            with self.assertRaisesRegex(exception, regex):
                string_frontend(*inputs)
            # optimized run
            string_frontend(*inputs)

        # Python AST frontend
        if not isinstance(script, str):
            with self.assertRaisesRegex(exception, regex):
                ge = torch.jit.script(python_fn)
                # profiling run
                with self.assertRaisesRegex(exception, regex):
                    ge(*inputs)
                # optimized run
                ge(*inputs)
def test_chunk_constant_script_ad(self):
    @torch.jit.script
    def func(x):
        x1, x2 = torch.chunk(x, 2)
        return (x1, x2)

    input = torch.rand(6, 10).requires_grad_()
    with disable_autodiff_subgraph_inlining():
        with enable_profiling_mode_for_profiling_tests():
            output = func(input, profile_and_replay=True)
            FileCheck().check_not("prim::DifferentiableGraph").run(func.graph_for(input))
def test_chunk_constant_script_ad(self):
    @torch.jit.script
    def func(x):
        x1, x2 = torch.chunk(x, 2)
        return (x1, x2)

    input = torch.rand(6, 10).requires_grad_()
    with disable_autodiff_subgraph_inlining():
        with enable_profiling_mode_for_profiling_tests():
            output = func(input, profile_and_replay=True)
            self.assertAutodiffNode(func.graph_for(input), True, ['prim::ConstantChunk'], [])
def test_constructed_bias(self):
    with enable_profiling_mode_for_profiling_tests():
        def method1(x, weight, b1, b2):
            bias = b1 * b2
            return torch.nn.functional.linear(x, weight, bias)

        N = 10
        x = torch.rand(N, N, requires_grad=True)
        weight = torch.rand(N, N, requires_grad=True)
        b1 = torch.rand(N, N, requires_grad=True)
        b2 = torch.rand(N, N, requires_grad=True)
        scripted = self.checkScript(method1, (x, weight, b1, b2))
        # check_types requires last_graph on scripted to be set, so we just skip it
        check_against_reference(self, scripted, method1, lambda x: x,
                                (x, weight, b1, b2), check_types=False)
def test_fuser_iou(self):
    # This checks if most of Intersection over Union is fused.
    # In particular, the backward contains many _grad_sum_to_size.
    def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):
        ltx = torch.max(b1x1, b2x1)  # [N,M]
        lty = torch.max(b1y1, b2y1)
        rbx = torch.min(b1x2, b2x2)
        rby = torch.min(b1y2, b2y2)

        w = (rbx - ltx).clamp(min=0, max=float('inf'))  # [N,M]
        h = (rby - lty).clamp(min=0, max=float('inf'))  # [N,M]
        inter = w * h  # [N,M]

        area1 = (b1x2 - b1x1) * (b1y2 - b1y1)  # [N,1]
        area2 = (b2x2 - b2x1) * (b2y2 - b2y1)  # [1,M]
        iou = inter / (area1 + area2 - inter)
        return iou

    box1 = torch.randn(5, 4, requires_grad=True)
    box2 = torch.randn(5, 4, requires_grad=True)
    # unsqueezing can currently not be fused
    b1x1 = box1[:, 0].unsqueeze(1)  # [N,1]
    b1y1 = box1[:, 1].unsqueeze(1)
    b1x2 = box1[:, 2].unsqueeze(1)
    b1y2 = box1[:, 3].unsqueeze(1)
    b2x1 = box2[:, 0].unsqueeze(0)  # [1,M]
    b2y1 = box2[:, 1].unsqueeze(0)
    b2x2 = box2[:, 2].unsqueeze(0)
    b2y2 = box2[:, 3].unsqueeze(0)

    s = self.checkScript(iou, (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2))
    self.assertAllFused(s.graph_for(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2),
                        except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'})

    with enable_profiling_mode_for_profiling_tests(True):
        c = s(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)
        warmup_backward(c.sum(), [b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2])
        graph = backward_graph(s)
        self.assertAllFused(graph,
                            except_for={'aten::size', 'prim::BroadcastSizes', 'aten::_size_if_not_equal'})
def _test_reinforcement_learning(self, device, test_export_import=True):
    class Policy(nn.Module):
        def __init__(self):
            super(Policy, self).__init__()
            self.affine1 = nn.Linear(4, 128)
            self.affine2 = nn.Linear(128, 2)

        def forward(self, x):
            x = F.relu(self.affine1(x))
            action_scores = self.affine2(x)
            return F.softmax(action_scores, dim=1)

    with enable_profiling_mode_for_profiling_tests():
        self.checkTrace(Policy().to(device), (torch.rand(1, 4, device=device),),
                        export_import=test_export_import)
def _test_vae(self, device, check_export_import=True, quantized=False):
    class VAE(nn.Module):
        def __init__(self):
            super(VAE, self).__init__()
            self.fc1 = nn.Linear(784, 400)
            self.fc21 = nn.Linear(400, 20)
            self.fc22 = nn.Linear(400, 20)
            self.fc3 = nn.Linear(20, 400)
            self.fc4 = nn.Linear(400, 784)

        def encode(self, x):
            h1 = F.relu(self.fc1(x))
            return self.fc21(h1), self.fc22(h1)

        def reparameterize(self, mu, logvar):
            if self.training:
                std = torch.exp(0.5 * logvar)
                eps = torch.randn_like(std)
                return eps.mul(std).add_(mu)
            else:
                return mu

        def decode(self, z):
            h3 = F.relu(self.fc3(z))
            return torch.sigmoid(self.fc4(h3))

        def forward(self, x):
            mu, logvar = self.encode(x.view(-1, 784))
            z = self.reparameterize(mu, logvar)
            return self.decode(z), mu, logvar

    if quantized:
        vae = VAE().to(device).eval()
        torch.jit.quantized.quantize_linear_modules(vae)
        # We don't do export/import checks because we would need to call
        # _unpack and _pack
        self.checkTrace(vae, (torch.rand(128, 1, 28, 28, device=device),),
                        export_import=False, allow_unused=True,
                        inputs_require_grads=False)
    else:
        with enable_profiling_mode_for_profiling_tests():
            # eval() is present because randn_like makes this nondeterministic
            self.checkTrace(VAE().to(device).eval(),
                            (torch.rand(128, 1, 28, 28, device=device),),
                            export_import=check_export_import)
def test_lstm_cuda(self):
    inputs = get_lstm_inputs('cuda', training=True)
    module = self.checkScript(LSTMCellS, inputs)
    # NOTE: the early return below skips the fusion and backward checks
    return
    forward_graph = module.graph_for(*inputs)
    self.assertGraphContainsExactly(
        forward_graph, FUSION_GROUP, 1, consider_subgraphs=True)
    self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2)
    # Everything is differentiable but TupleConstruct return
    FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
        .check_next("return").run(str(forward_graph))

    with enable_profiling_mode_for_profiling_tests(True):
        hy, cy = module(*inputs)
        warmup_backward((hy + cy).sum())
        backward = backward_graph(module)
        self.assertAllFused(backward,
                            except_for=("aten::t", "aten::mm", "aten::_grad_sum_to_size"))
def test_requires_grad_for_tensor_list(self):
    with enable_profiling_mode_for_profiling_tests():
        # output & var_list[0] should have requires_grad set to True
        def func(input0: torch.Tensor, input1: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
            var_list = [input0, input1]
            var = torch.cat(var_list)
            output = var + 1.0
            return output, var_list

        jit_f = torch.jit.script(func)
        input0 = torch.randn((2,), requires_grad=True)
        input1 = torch.randn((2,))
        output_ref = func(input0, input1)
        for i in range(2):
            output = jit_f(input0, input1)
            assert output_ref[0].requires_grad == output[0].requires_grad
            assert output_ref[1][0].requires_grad == output[1][0].requires_grad
            assert output_ref[1][1].requires_grad == output[1][1].requires_grad
def test_bias_as_module_attr(self):
    with enable_profiling_mode_for_profiling_tests():
        class M(torch.nn.Module):
            def __init__(self, has_bias):
                super(M, self).__init__()
                self.ll = torch.nn.Linear(10, 10, has_bias)

            def forward(self, x, y):
                return self.ll(x + y) * x + y

        x = torch.rand(10, 10, requires_grad=True)
        no_bias = M(False)
        scripted_no_bias = torch.jit.script(no_bias)
        scripted_no_bias(x, x)
        scripted_no_bias(x, x)
        scripted_no_bias(x, x)

        has_bias = M(True)
        check_against_reference(self, scripted_no_bias, no_bias, lambda x: x,
                                (x, x,), check_types=False)
        scripted_has_bias = torch.jit.script(has_bias)
        scripted_has_bias(x, x)
        scripted_has_bias(x, x)
        scripted_has_bias(x, x)
        check_against_reference(self, scripted_has_bias, has_bias, lambda x: x,
                                (x, x,), check_types=False)
def check_against_reference(self, func, reference_func, output_func, args,
                            kwargs=None, allow_unused=True, check_types=True,
                            no_grad=False):
    kwargs = kwargs if kwargs else {}

    def allSum(vs):
        if isinstance(vs, torch.Tensor):
            vs = (vs,)
        return sum((i + 1) * v.sum()
                   for i, v in enumerate(vs)
                   if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16))

    def clone_inputs(requires_grad):
        inputs = [
            arg.detach().clone().requires_grad_(requires_grad and arg.requires_grad)
            if isinstance(arg, torch.Tensor) else arg for arg in args
        ]
        return inputs, [input for input in inputs
                        if isinstance(input, torch.Tensor) and input.requires_grad]

    nograd_inputs, nograd_tensors = clone_inputs(False)
    recording_inputs, recording_tensors = clone_inputs(True)

    # test no gradients case
    outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
    with enable_profiling_mode_for_profiling_tests():
        outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
    self.assertEqual(outputs, outputs_test)

    if check_types:
        check_output_types(self, func, outputs_test, nograd_inputs, kwargs)

    if no_grad:
        # skip grad tests
        return

    with enable_profiling_mode_for_profiling_tests():
        # test single grad case
        outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
        grads = torch.autograd.grad(allSum(outputs), recording_tensors,
                                    allow_unused=allow_unused)
        outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
        grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
                                         allow_unused=allow_unused)
        self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)

        # test the grad grad case
        if self._testMethodName in nn_functional_single_grad:
            return

        outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
        l1 = allSum(outputs)
        grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
                                    allow_unused=allow_unused)
        l2 = (allSum(grads) * l1)
        grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)

        recording_inputs, recording_tensors = clone_inputs(True)

        outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
        l1_test = allSum(outputs_test)
        grads_test = torch.autograd.grad(
            l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)
        l2_test = (allSum(grads_test) * l1_test)
        grads2_test = torch.autograd.grad(l2_test, recording_tensors,
                                          allow_unused=allow_unused)

        self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)
        for g2, g2_test in zip(grads2, grads2_test):
            if g2 is None and g2_test is None:
                continue
            self.assertTrue(torch.allclose(g2, g2_test, atol=5e-4, rtol=1e-4))
def check_against_reference(self, func, reference_func, output_func, args,
                            kwargs=None, allow_unused=True, check_types=True,
                            no_grad=False, no_gradgrad=False):
    """Verifies a function performs identically to some reference implementation.

    Commonly, this is used to verify that a JIT implementation
    (output_func) matches the behavior of the eager implementation
    (reference_func).
    """
    kwargs = kwargs if kwargs else {}

    def allSum(vs):
        if isinstance(vs, torch.Tensor):
            vs = (vs,)
        return sum((i + 1) * v.sum()
                   for i, v in enumerate(vs)
                   if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16))

    def clone_tensor(t, preserve_requires_grad):
        require_grad = preserve_requires_grad and t.requires_grad
        return t.detach().clone().requires_grad_(require_grad)

    def clone_inputs(preserve_requires_grad: bool):
        inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = []

        for arg in args:
            if isinstance(arg, torch.Tensor):
                inputs.append(clone_tensor(arg, preserve_requires_grad))
            elif is_iterable_of_tensors(arg):
                inputs.append([clone_tensor(t, preserve_requires_grad) for t in arg])
            else:
                inputs.append(arg)

        return inputs

    # Returns tensors in args that requires_grad, including tensors in TensorList args
    def get_recording_tensors(args):
        recording_tensors: List[torch.Tensor] = []

        for arg in args:
            if isinstance(arg, torch.Tensor) and arg.requires_grad:
                recording_tensors.append(arg)
            elif is_iterable_of_tensors(arg):
                recording_tensors.extend(filter(lambda t: t.requires_grad, arg))

        return recording_tensors

    # test no gradients case
    nograd_inputs = clone_inputs(preserve_requires_grad=False)
    outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
    with enable_profiling_mode_for_profiling_tests():
        outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
    self.assertEqual(outputs, outputs_test)

    if check_types:
        check_output_types(self, func, outputs_test, nograd_inputs, kwargs)

    if no_grad:
        # skip grad tests
        return

    with enable_profiling_mode_for_profiling_tests():
        # test single grad case
        recording_inputs = clone_inputs(preserve_requires_grad=True)
        recording_tensors = get_recording_tensors(recording_inputs)
        outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
        grads = torch.autograd.grad(allSum(outputs), recording_tensors,
                                    allow_unused=allow_unused)
        outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
        grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
                                         allow_unused=allow_unused)
        self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)

        # test the grad grad case
        if self._testMethodName in nn_functional_single_grad or no_gradgrad:
            return

        outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
        l1 = allSum(outputs)
        grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
                                    allow_unused=allow_unused)
        l2 = (allSum(grads) * l1)
        grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)

        recording_inputs = clone_inputs(preserve_requires_grad=True)
        recording_tensors = get_recording_tensors(recording_inputs)

        outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
        l1_test = allSum(outputs_test)
        grads_test = torch.autograd.grad(
            l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)
        l2_test = (allSum(grads_test) * l1_test)
        grads2_test = torch.autograd.grad(l2_test, recording_tensors,
                                          allow_unused=allow_unused)

        self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)
        for g2, g2_test in zip(grads2, grads2_test):
            if g2 is None and g2_test is None:
                continue
            self.assertEqual(g2, g2_test, atol=5e-4, rtol=1e-4)
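# A minimal usage sketch for check_against_reference (this test case is
# hypothetical, not part of the suite above). It scripts an eager function and
# verifies that outputs and gradients agree; check_types is skipped because
# last_graph may not be set on the scripted function, matching the pattern in
# the bias tests above.
def _example_check_against_reference_usage(self):
    def eager_fn(x, weight):
        return torch.nn.functional.linear(x, weight).relu()

    scripted_fn = torch.jit.script(eager_fn)
    x = torch.rand(4, 4, requires_grad=True)
    weight = torch.rand(4, 4, requires_grad=True)
    # output_func is the identity: compare the raw outputs directly
    check_against_reference(self, scripted_fn, eager_fn, lambda outs: outs,
                            (x, weight), check_types=False)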
def test_aliased_outputs(self):
    with enable_profiling_mode_for_profiling_tests():
        # Case 1: the aliasing between relu and t stays within a single
        # DifferentiableGraph, so it is valid to merge relu and t into
        # one graph
        input_str = """
graph(%a : Tensor):
    %b : Tensor = aten::relu(%a)
    %2 : Tensor = aten::t(%b)
    return (%2)
"""
        graph = torch._C.parse_ir(input_str)
        torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
        FileCheck().check("with prim::DifferentiableGraph") \
            .check("aten::relu").check("aten::t") \
            .run(graph)

        # Case 2: relu and split_with_sizes alias each other and both are
        # outputs of the differentiable graph. It would be invalid to merge
        # them into one graph, i.e. relu and split_with_sizes should end up
        # in different differentiable graphs
        input_str = """
graph(%a : Tensor):
    %b : Tensor = aten::relu(%a)
    %0 : int[] = prim::Constant[value=[2, 2, 1]]()
    %1 : int = prim::Constant[value=0]()
    %2 : Tensor[] = aten::split_with_sizes(%b, %0, %1)
    %3 : (Tensor, Tensor[]) = prim::TupleConstruct(%b, %2)
    return (%3)
"""
        graph = torch._C.parse_ir(input_str)
        torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
        FileCheck().check("Tensor = prim::DifferentiableGraph") \
            .check("with prim::DifferentiableGraph") \
            .check("Tensor = aten::relu") \
            .check_not("aten::split_with_sizes") \
            .run(graph)

        # Case 3: two nodes alias the same value.
        # Both `split_with_sizes` should be unfused
        input_str = """
graph(%a : Tensor):
    %b : Tensor = aten::relu(%a)
    %s1 : int[] = prim::Constant[value=[2, 2, 1]]()
    %s2 : int[] = prim::Constant[value=[3, 1]]()
    %1 : int = prim::Constant[value=0]()
    %2 : Tensor[] = aten::split_with_sizes(%b, %s1, %1)
    %3 : Tensor[] = aten::split_with_sizes(%b, %s2, %1)
    %4 : (Tensor, Tensor[], Tensor[]) = prim::TupleConstruct(%b, %2, %3)
    return (%4)
"""
        graph = torch._C.parse_ir(input_str)
        torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
        FileCheck().check("Tensor = prim::DifferentiableGraph") \
            .check("with prim::DifferentiableGraph") \
            .check("Tensor = aten::relu") \
            .check_not("aten::split_with_sizes") \
            .run(graph)

        # Case 4: the aliased output (%2) has a descendant (%3).
        # Both should be unfused. Note that %3 comes before %2 in the tuple,
        # to test that we unfuse in reverse topological order
        input_str = """
graph(%a : Tensor):
    %b : Tensor = aten::relu(%a)
    %0 : int[] = prim::Constant[value=[2, 2, 1]]()
    %1 : int = prim::Constant[value=0]()
    %2 : Tensor = aten::t(%b)
    %3 : Tensor = aten::gelu(%2)
    %4 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%b, %3, %2)
    return (%4)
"""
        graph = torch._C.parse_ir(input_str)
        torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
        FileCheck().check("Tensor = prim::DifferentiableGraph") \
            .check("with prim::DifferentiableGraph") \
            .check("Tensor = aten::relu") \
            .check_not("aten::t") \
            .run(graph)

        # Case 5: multiple aliased groups. All aliased nodes should be
        # unfused. Again, %3 comes before %2 in the tuple, to test that we
        # unfuse in reverse topological order
        input_str = """
graph(%a : Tensor):
    %b : Tensor = aten::relu(%a)
    %c : Tensor = aten::abs(%a)
    %0 : int[] = prim::Constant[value=[2, 2, 1]]()
    %1 : int = prim::Constant[value=0]()
    %d : Tensor = aten::t(%c)
    %2 : Tensor = aten::t(%b)
    %3 : Tensor = aten::gelu(%2)
    %4 : (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) = prim::TupleConstruct(%3, %2, %d, %b, %c, %b)
    return (%4)
"""
        graph = torch._C.parse_ir(input_str)
        torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
        FileCheck().check("Tensor = prim::DifferentiableGraph") \
            .check("with prim::DifferentiableGraph") \
            .check("Tensor = aten::relu") \
            .check_not("aten::t") \
            .run(graph)
def checkScript(self, script, inputs, name='func', optimize=True,
                inputs_requires_grad=False, capture_output=False,
                frames_up=1, profiling=ProfilingMode.PROFILING):
    with torch.jit.optimized_execution(optimize):
        with enable_profiling_mode_for_profiling_tests():
            if isinstance(script, str):
                # Compile the string to a Script function
                # with enable_profiling_mode():
                cu = torch.jit.CompilationUnit(script, _frames_up=frames_up)

                # Execute the Python function so we can run it later and get its
                # outputs
                frame = self.get_frame_vars(frames_up)
                the_locals = {}
                execWrapper(script, glob=frame, loc=the_locals)
                frame.update(the_locals)

                python_fn = frame[name]
                scripted_fn = getattr(cu, name)
            else:
                # Check the string frontend first
                source = textwrap.dedent(inspect.getsource(script))
                self.checkScript(source, inputs, script.__name__,
                                 optimize=optimize,
                                 inputs_requires_grad=inputs_requires_grad,
                                 capture_output=capture_output,
                                 profiling=profiling, frames_up=2)

                # Continue checking the Python frontend
                scripted_fn = torch.jit.script(script, _frames_up=1)
                python_fn = script

            if inputs_requires_grad:
                recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs)
            else:
                recording_inputs = inputs

            if capture_output:
                with self.capture_stdout() as script_stdout:
                    script_outputs = scripted_fn(*recording_inputs)
                with self.capture_stdout() as opt_script_stdout:
                    opt_script_outputs = scripted_fn(*recording_inputs)
                with self.capture_stdout() as _python_stdout:
                    python_outputs = python_fn(*inputs)
                if not IS_WINDOWS:
                    self.assertExpected(script_stdout[0], subname='stdout')
                self.assertEqual(python_outputs, opt_script_outputs)
            else:
                # profiling run
                script_outputs = scripted_fn(*recording_inputs)
                # optimized run
                opt_script_outputs = scripted_fn(*recording_inputs)
                if TEST_BAILOUTS:
                    self.checkBailouts(scripted_fn, inputs, opt_script_outputs)
                python_outputs = python_fn(*inputs)
            self.assertEqual(python_outputs, script_outputs)
            self.assertEqual(script_outputs, opt_script_outputs)
            return scripted_fn
def _test_mnist(self, device, check_export_import=True):
    # eval() is present because dropout makes this nondeterministic
    with enable_profiling_mode_for_profiling_tests():
        self.checkTrace(MnistNet().to(device).eval(),
                        (torch.rand(5, 1, 28, 28, device=device),),
                        export_import=check_export_import)