def test_clamp(self):
    """Script several ``torch.clamp`` variants on CUDA tensors and run
    ``assertAllFused`` on both the forward and the backward graph.

    Every variant is exercised with a regular second operand and with a
    scalar NaN tensor as the second operand.
    """
    # NOTE: the nested functions are handed to checkScript, so they are
    # kept as plain ``def`` statements.
    def func2(a, b):
        return torch.clamp(a + b, min=0, max=2)

    def funcInf(a, b):
        return torch.clamp(a + b, min=0, max=float('inf'))

    def funcOptMin(a, b):
        return torch.clamp(a + b, max=2)

    def funcOptMax(a, b):
        return torch.clamp(a + b, min=0)

    base = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
    other = torch.randn(4, 4, dtype=torch.float, device='cuda')
    nan_other = torch.tensor(float('nan'), dtype=torch.float, device='cuda')

    clamp_variants = (func2, funcInf, funcOptMin, funcOptMax)
    operand_pairs = [[base, other], [base, nan_other]]

    for clamp_fn, (lhs, rhs) in product(clamp_variants, operand_pairs):
        scripted = self.checkScript(clamp_fn, (lhs, rhs),
                                    profiling=ProfilingMode.PROFILING)
        forward = scripted.graph_for(lhs, rhs)
        self.assertAllFused(
            forward,
            except_for={'aten::size', 'aten::_size_if_not_equal'})

        result = scripted(lhs, rhs)
        with enable_profiling_mode():
            warmup_backward(result.sum())

        backward = backward_graph(scripted)
        self.assertAllFused(
            backward,
            except_for={'aten::Float', 'aten::_grad_sum_to_size'})
def checkScriptRaisesRegex(self, script, inputs, exception, regex,
                           outputs=None, capture_output=False,
                           profiling=ProfilingMode.PROFILING):
    """
    Assert that ``script(*inputs)`` raises ``exception`` matching ``regex``
    in three settings: called as plain Python, compiled through the string
    frontend (``torch.jit.CompilationUnit``), and compiled through the
    Python AST frontend (``torch.jit.script``).

    ``outputs``, ``capture_output`` and ``profiling`` are accepted but not
    read by this method.
    """
    def expect_failure():
        # A fresh context manager for every expectation.
        return self.assertRaisesRegex(exception, regex)

    with enable_profiling_mode():
        # Plain Python execution.
        with expect_failure():
            script(*inputs)

        # String frontend: the outer expectation covers compilation and the
        # second (optimized) call, the inner one covers the profiling call.
        with expect_failure():
            src_text = textwrap.dedent(inspect.getsource(script))
            unit = torch.jit.CompilationUnit(src_text)
            compiled = getattr(unit, script.__name__)
            with expect_failure():
                compiled(*inputs)
            compiled(*inputs)

        # Python AST frontend, same two-run structure as above.
        with expect_failure():
            compiled = torch.jit.script(script)
            with expect_failure():
                compiled(*inputs)
            compiled(*inputs)
def _perform_ad_subgraph_slicing(self, fn, *input_sizes):
    """Script ``fn``, run it once on random inputs, and return its graph.

    One ``torch.randn(size, requires_grad=True)`` tensor is created per
    entry of ``input_sizes``. Scripting, the call (made with
    ``profile_and_replay=True``) and the ``graph_for`` lookup all happen
    with autodiff subgraph inlining disabled and profiling mode enabled.
    """
    with disable_autodiff_subgraph_inlining(), enable_profiling_mode():
        scripted = torch.jit.script(fn)

        tensors = []
        for size in input_sizes:
            tensors.append(torch.randn(size, requires_grad=True))

        scripted(*tensors, profile_and_replay=True)
        return scripted.graph_for(*tensors)
def test_chunk_constant_script_ad(self):
    """Run a scripted two-way ``torch.chunk`` on a grad-requiring input and
    check its graph with ``assertAutodiffNode``, passing
    ``['prim::ConstantChunk']`` and ``[]`` as the two node lists.
    """
    @torch.jit.script
    def func(x):
        x1, x2 = torch.chunk(x, 2)
        return (x1, x2)

    sample = torch.rand(6, 10).requires_grad_()

    with disable_autodiff_subgraph_inlining(), enable_profiling_mode():
        # The result is not inspected; the call only has to execute before
        # the graph is looked up.
        func(sample, profile_and_replay=True)
        graph = func.graph_for(sample)
        self.assertAutodiffNode(graph, True, ['prim::ConstantChunk'], [])
def test_fuser_iou(self):
    """Check that most of an Intersection-over-Union computation is fused.

    The scripted ``iou`` function is checked with ``checkScript``; then both
    its forward graph and its backward graph are passed to
    ``assertAllFused``. In particular, the backward contains many
    ``_grad_sum_to_size``.
    """
    def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):
        ltx = torch.max(b1x1, b2x1)  # [N,M]
        lty = torch.max(b1y1, b2y1)
        rbx = torch.min(b1x2, b2x2)
        rby = torch.min(b1y2, b2y2)

        w = (rbx - ltx).clamp(min=0, max=float('inf'))  # [N,M]
        h = (rby - lty).clamp(min=0, max=float('inf'))  # [N,M]
        inter = w * h  # [N,M]

        # Box area is width * height, i.e. (x2 - x1) * (y2 - y1). The
        # height factor previously subtracted y2 from itself, which made
        # both areas identically zero.
        area1 = (b1x2 - b1x1) * (b1y2 - b1y1)  # [N,1]
        area2 = (b2x2 - b2x1) * (b2y2 - b2y1)  # [1,M]
        iou = inter / (area1 + area2 - inter)
        return iou

    box1 = torch.randn(5, 4, requires_grad=True)
    box2 = torch.randn(5, 4, requires_grad=True)

    # unsqueezing can currently not be fused, so it is done outside of the
    # scripted function
    b1x1 = box1[:, 0].unsqueeze(1)  # [N,1]
    b1y1 = box1[:, 1].unsqueeze(1)
    b1x2 = box1[:, 2].unsqueeze(1)
    b1y2 = box1[:, 3].unsqueeze(1)
    b2x1 = box2[:, 0].unsqueeze(0)  # [1,N]
    b2y1 = box2[:, 1].unsqueeze(0)
    b2x2 = box2[:, 2].unsqueeze(0)
    b2y2 = box2[:, 3].unsqueeze(0)

    coords = (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)

    s = self.checkScript(iou, coords)
    self.assertAllFused(
        s.graph_for(*coords),
        except_for={'aten::size', 'prim::BroadcastSizes',
                    'aten::_size_if_not_equal'})

    with enable_profiling_mode(True):
        c = s(*coords)
        warmup_backward(c.sum(), list(coords))
        graph = backward_graph(s)
        self.assertAllFused(
            graph,
            except_for={'aten::size', 'prim::BroadcastSizes',
                        'aten::_size_if_not_equal'})
def _test_reinforcement_learning(self, device, test_export_import=True):
    """Trace-check a small policy network on ``device``.

    The network is two linear layers (4 -> 128 -> 2) with a ReLU in between
    and a softmax over dim 1 on the output. ``test_export_import`` is
    forwarded to ``checkTrace`` as ``export_import``.
    """
    class Policy(nn.Module):
        def __init__(self):
            super(Policy, self).__init__()
            self.affine1 = nn.Linear(4, 128)
            self.affine2 = nn.Linear(128, 2)

        def forward(self, x):
            hidden = F.relu(self.affine1(x))
            return F.softmax(self.affine2(hidden), dim=1)

    with enable_profiling_mode():
        policy = Policy().to(device)
        example_inputs = (torch.rand(1, 4, device=device),)
        self.checkTrace(policy, example_inputs,
                        export_import=test_export_import)
def _test_vae(self, device, check_export_import=True, quantized=False):
    """Trace-check a small variational autoencoder on ``device``.

    With ``quantized=True`` the linear modules are first rewritten by
    ``torch.jit.quantized.quantize_linear_modules`` and the trace check runs
    without export/import and without input gradients. Otherwise the check
    runs under profiling mode and ``check_export_import`` is forwarded as
    ``export_import``.
    """
    class VAE(nn.Module):
        def __init__(self):
            super(VAE, self).__init__()
            self.fc1 = nn.Linear(784, 400)
            self.fc21 = nn.Linear(400, 20)
            self.fc22 = nn.Linear(400, 20)
            self.fc3 = nn.Linear(20, 400)
            self.fc4 = nn.Linear(400, 784)

        def encode(self, x):
            hidden = F.relu(self.fc1(x))
            return self.fc21(hidden), self.fc22(hidden)

        def reparameterize(self, mu, logvar):
            if not self.training:
                return mu
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)

        def decode(self, z):
            hidden = F.relu(self.fc3(z))
            return torch.sigmoid(self.fc4(hidden))

        def forward(self, x):
            flat = x.view(-1, 784)
            mu, logvar = self.encode(flat)
            latent = self.reparameterize(mu, logvar)
            return self.decode(latent), mu, logvar

    if quantized:
        model = VAE().to(device).eval()
        torch.jit.quantized.quantize_linear_modules(model)
        batch = (torch.rand(128, 1, 28, 28, device=device),)
        # We don't do export/import checks because we would need to call
        # _unpack and _pack
        self.checkTrace(model, batch,
                        export_import=False,
                        allow_unused=True,
                        inputs_require_grads=False)
        return

    with enable_profiling_mode():
        # eval() is present because randn_like makes this nondeterministic
        model = VAE().to(device).eval()
        batch = (torch.rand(128, 1, 28, 28, device=device),)
        self.checkTrace(model, batch, export_import=check_export_import)
def test_lstm_cuda(self):
    """Run ``checkScript`` on ``LSTMCellS`` with the inputs returned by
    ``get_lstm_inputs('cuda', training=True)``.

    NOTE(review): the bare ``return`` right after ``checkScript`` ends the
    test there, so the forward-graph and backward-graph assertions below it
    are currently unreachable. They are kept unchanged; whether they were
    parked on purpose cannot be told from this file -- confirm before
    removing either the ``return`` or the dead code.
    """
    inputs = get_lstm_inputs('cuda', training=True)
    module = self.checkScript(LSTMCellS, inputs)
    # Early exit: nothing below this line executes.
    return
    forward_graph = module.graph_for(*inputs)
    self.assertGraphContainsExactly(
        forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
    self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2)
    # Everything is differentiable but TupleConstruct return
    FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
        .check_next("return").run(str(forward_graph))

    with enable_profiling_mode(True):
        hy, cy = module(*inputs)
        warmup_backward((hy + cy).sum())
        backward = backward_graph(module)
    self.assertAllFused(backward, except_for=("aten::t", "aten::mm", "aten::_grad_sum_to_size"))
def _test_mnist(self, device, check_export_import=True):
    """Trace-check ``MnistNet`` on ``device`` under profiling mode.

    ``check_export_import`` is forwarded to ``checkTrace`` as
    ``export_import``.
    """
    with enable_profiling_mode():
        # eval() is present because dropout makes this nondeterministic
        net = MnistNet().to(device).eval()
        images = (torch.rand(5, 1, 28, 28, device=device),)
        self.checkTrace(net, images, export_import=check_export_import)
def checkScript(self, script, inputs, name='func', optimize=True,
                inputs_requires_grad=False, capture_output=False,
                frames_up=1, profiling=ProfilingMode.PROFILING):
    """Compile ``script`` and check it agrees with plain Python execution.

    ``script`` is either source text or a Python function. For a function,
    its source is first checked through the string frontend (a recursive
    call), then the function itself is compiled with ``torch.jit.script``.
    The compiled function is called twice (profiling run, optimized run) and
    both results are compared with the Python result.

    Args:
        script: source string or Python function to compile.
        inputs: positional arguments for the Python function; the scripted
            function receives the same values (see ``inputs_requires_grad``).
        name: name of the function to look up when ``script`` is a string.
        optimize: passed to ``torch.jit.optimized_execution``.
        inputs_requires_grad: when true, the scripted function is called
            with ``t.detach().requires_grad_()`` mapped over ``inputs``.
        capture_output: when true, stdout of each run is captured and, except
            on Windows, the first scripted run's output is checked with
            ``assertExpected``.
        frames_up: passed to ``torch.jit.CompilationUnit`` as ``_frames_up``
            and to ``get_frame_vars`` when ``script`` is a string.
        profiling: forwarded to the recursive call; not otherwise read here.

    Returns:
        The scripted function.
    """
    with torch.jit.optimized_execution(optimize):
        with enable_profiling_mode():
            if isinstance(script, str):
                # Compile the string to a Script function
                cu = torch.jit.CompilationUnit(script, _frames_up=frames_up)

                # Execute the Python function so we can run it later and get
                # its outputs
                frame = self.get_frame_vars(frames_up)
                the_locals = {}
                execWrapper(script, glob=frame, loc=the_locals)
                frame.update(the_locals)

                python_fn = frame[name]
                scripted_fn = getattr(cu, name)
            else:
                # Check the string frontend first. Options are forwarded by
                # keyword: passing ``capture_output`` positionally would bind
                # it to ``optimize`` and drop the caller's other settings.
                source = textwrap.dedent(inspect.getsource(script))
                self.checkScript(
                    source,
                    inputs,
                    script.__name__,
                    optimize=optimize,
                    inputs_requires_grad=inputs_requires_grad,
                    capture_output=capture_output,
                    profiling=profiling,
                    frames_up=2)

                # Continue checking the Python frontend
                scripted_fn = torch.jit.script(script, _frames_up=1)
                python_fn = script

            if inputs_requires_grad:
                recording_inputs = do_input_map(
                    lambda t: t.detach().requires_grad_(), inputs)
            else:
                recording_inputs = inputs

            if capture_output:
                with self.capture_stdout() as script_stdout:
                    script_outputs = scripted_fn(*recording_inputs)
                with self.capture_stdout() as opt_script_stdout:
                    opt_script_outputs = scripted_fn(*recording_inputs)
                with self.capture_stdout() as _python_stdout:
                    python_outputs = python_fn(*inputs)
                if not IS_WINDOWS:
                    self.assertExpected(script_stdout[0], subname='stdout')
                self.assertEqual(python_outputs, opt_script_outputs)
            else:
                # profiling run
                script_outputs = scripted_fn(*recording_inputs)
                # optimized run
                opt_script_outputs = scripted_fn(*recording_inputs)
                python_outputs = python_fn(*inputs)

            self.assertEqual(python_outputs, script_outputs)
            self.assertEqual(script_outputs, opt_script_outputs)
            return scripted_fn