def test_clamp(self):
    """Run several ``torch.clamp`` variants through ``checkScript`` on CUDA tensors.

    For each (function, operand pair) combination the forward graph is
    checked with ``assertAllFused`` (size bookkeeping ops are exempt), then a
    backward pass is warmed up under profiling mode and the backward graph is
    checked with its own exemption set.
    """
    # NOTE: the nested functions are handed to ``checkScript``; their names
    # and bodies are kept as-is because the script source is derived from them.
    def func2(a, b):
        return torch.clamp(a + b, min=0, max=2)

    def funcInf(a, b):
        return torch.clamp(a + b, min=0, max=float('inf'))

    def funcOptMin(a, b):
        return torch.clamp(a + b, max=2)

    def funcOptMax(a, b):
        return torch.clamp(a + b, min=0)

    lhs = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
    rhs = torch.randn(4, 4, dtype=torch.float, device='cuda')
    nan_operand = torch.tensor(float('nan'), dtype=torch.float, device='cuda')

    clamp_variants = (func2, funcInf, funcOptMin, funcOptMax)
    operand_pairs = [[lhs, rhs], [lhs, nan_operand]]

    for clamp_fn, (first, second) in product(clamp_variants, operand_pairs):
        scripted = self.checkScript(
            clamp_fn, (first, second), profiling=ProfilingMode.PROFILING)

        # Forward graph: only the size helpers may stay outside the fusion.
        self.assertAllFused(
            scripted.graph_for(first, second),
            except_for={'aten::size', 'aten::_size_if_not_equal'})

        result = scripted(first, second)
        with enable_profiling_mode():
            warmup_backward(result.sum())

        # Backward graph has a different set of permitted unfused ops.
        bwd_graph = backward_graph(scripted)
        self.assertAllFused(
            bwd_graph,
            except_for={'aten::Float', 'aten::_grad_sum_to_size'})
def checkScriptRaisesRegex(self, script, inputs, exception, regex, outputs=None, capture_output=False, profiling=ProfilingMode.PROFILING):
    """
    Assert that ``script`` raises ``exception`` matching ``regex`` when it is
    run as plain Python, when compiled through the string frontend, and when
    compiled through the Python AST frontend.

    ``outputs``, ``capture_output`` and ``profiling`` are accepted but not
    read by this method.
    """
    def expect_failure_on_both_runs(compiled):
        # First (profiling) invocation must raise on its own ...
        with self.assertRaisesRegex(exception, regex):
            compiled(*inputs)
        # ... and the second (optimized) invocation raises into the
        # enclosing assertRaisesRegex context of the caller.
        compiled(*inputs)

    with enable_profiling_mode():
        # Plain Python execution.
        with self.assertRaisesRegex(exception, regex):
            script(*inputs)

        # String frontend: compile the dedented source of ``script``.
        with self.assertRaisesRegex(exception, regex):
            script_source = textwrap.dedent(inspect.getsource(script))
            unit = torch.jit.CompilationUnit(script_source)
            expect_failure_on_both_runs(getattr(unit, script.__name__))

        # Python AST frontend.
        with self.assertRaisesRegex(exception, regex):
            expect_failure_on_both_runs(torch.jit.script(script))
def _perform_ad_subgraph_slicing(self, fn, *input_sizes):
    """Script ``fn``, run it once on random inputs, and return its graph.

    One ``requires_grad`` tensor is created per entry in ``input_sizes``.
    Everything runs with autodiff subgraph inlining disabled and profiling
    mode enabled; the returned value is ``graph_for`` of those inputs.
    """
    with disable_autodiff_subgraph_inlining(), enable_profiling_mode():
        scripted = torch.jit.script(fn)

        sample_inputs = []
        for size in input_sizes:
            sample_inputs.append(torch.randn(size, requires_grad=True))

        scripted(*sample_inputs, profile_and_replay=True)
        return scripted.graph_for(*sample_inputs)
def do_test(model, bailout, print_diff):
    """Load a serialized JIT module and compare its outputs across runs.

    The module at path ``model`` is executed twice ("profiled" and
    "profiled2"), then once more ("optimized"), and finally once per
    requested bailout. Mismatches are logged as errors to a timestamped
    ``jit_*.log`` file rather than raised.

    Args:
        model: path passed to ``torch.jit.load``; also used to build the
            log file name.
        bailout: index of the single bailout to trigger. Any falsy value
            (including ``0``) triggers every bailout reported by the plan.
        print_diff: when true, the mismatching outputs are logged as well.
    """
    log_path = 'jit_' + model.replace('/', '-') + str(int(time.time())) + '.log'
    logging.basicConfig(filename=log_path, filemode='w', level=logging.DEBUG)

    def run_frozen(module):
        # Execute the module with the RNG state frozen for the call.
        with freeze_rng_state():
            return module()

    def report_mismatch(message_args, dumps):
        # Log the mismatch and, if requested, both offending outputs.
        logging.error(*message_args)
        if print_diff:
            for fmt, value in dumps:
                logging.error(fmt, str(value))

    def check_bailout(index, module, plan, optimized_out):
        # Force one bailout and compare its output to the optimized run.
        logging.info("triggering bailout %d ", index)
        plan.code.request_bailout(index)
        bailout_out = run_frozen(module)
        if not test_allclose(bailout_out, optimized_out):
            report_mismatch(
                ("bailout %d and optimized outputs aren't equal", index),
                (("bo : %s", bailout_out), ("jo : %s", optimized_out)))

    with enable_profiling_mode():
        logging.info("loading profiled %s", model)
        jit_module = torch.jit.load(model)

        logging.info("running profiled %s", model)
        profiled_out = run_frozen(jit_module)

        logging.info("running profiled2 %s", model)
        profiled_out2 = run_frozen(jit_module)
        if not test_allclose(profiled_out, profiled_out2):
            report_mismatch(
                ("profiled and profiled2 outputs aren't equal",),
                (("po : %s", profiled_out), ("po2 : %s", profiled_out2)))

        logging.info("running optimized %s", model)
        optimized_out = run_frozen(jit_module)
        if not test_allclose(profiled_out, optimized_out):
            report_mismatch(
                ("profiled and optimized outputs aren't equal",),
                (("po : %s", profiled_out), ("jo : %s", optimized_out)))

        plan = get_plan(jit_module)
        num_bailouts = plan.code.num_bailouts()
        logging.info("number of bailouts: %d", num_bailouts)

        # A truthy ``bailout`` selects just that one; otherwise sweep them all.
        bailout_indices = [bailout] if bailout else range(num_bailouts)
        for index in bailout_indices:
            check_bailout(index, jit_module, plan, optimized_out)
def test_chunk_constant_script_ad(self):
    """Run a scripted two-way ``torch.chunk`` and check its autodiff node.

    The scripted function is executed once with ``profile_and_replay=True``
    (subgraph inlining disabled, profiling mode enabled) and
    ``assertAutodiffNode`` is then called on its graph with
    ``prim::ConstantChunk`` as the only listed node kind.
    """
    @torch.jit.script
    def func(x):
        x1, x2 = torch.chunk(x, 2)
        return (x1, x2)

    sample = torch.rand(6, 10).requires_grad_()

    with disable_autodiff_subgraph_inlining(), enable_profiling_mode():
        # The result is not inspected; it is only kept alive while the graph
        # is checked, as before.
        chunks = func(sample, profile_and_replay=True)
        self.assertAutodiffNode(
            func.graph_for(sample), True, ['prim::ConstantChunk'], [])
def test_fuser_iou(self):
    """Check that most of an Intersection-over-Union computation is fused.

    The backward pass in particular contains many ``_grad_sum_to_size``
    nodes. Both the forward graph and the backward graph are checked with
    ``assertAllFused``; only size/broadcast bookkeeping ops are exempt.
    """
    def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):
        ltx = torch.max(b1x1, b2x1)  # [N,M]
        lty = torch.max(b1y1, b2y1)
        rbx = torch.min(b1x2, b2x2)
        rby = torch.min(b1y2, b2y2)

        w = (rbx - ltx).clamp(min=0, max=float('inf'))  # [N,M]
        h = (rby - lty).clamp(min=0, max=float('inf'))  # [N,M]
        inter = w * h  # [N,M]

        # Box area is width * height, i.e. (x2 - x1) * (y2 - y1). The height
        # factor previously subtracted y2 from itself, so both areas were
        # always zero.
        area1 = (b1x2 - b1x1) * (b1y2 - b1y1)  # [N,1]
        area2 = (b2x2 - b2x1) * (b2y2 - b2y1)  # [1,M]
        iou = inter / (area1 + area2 - inter)
        return iou

    box1 = torch.randn(5, 4, requires_grad=True)
    box2 = torch.randn(5, 4, requires_grad=True)

    # unsqueezing can currently not be fused, so it is done outside ``iou``
    b1x1 = box1[:, 0].unsqueeze(1)  # [N,1]
    b1y1 = box1[:, 1].unsqueeze(1)
    b1x2 = box1[:, 2].unsqueeze(1)
    b1y2 = box1[:, 3].unsqueeze(1)
    b2x1 = box2[:, 0].unsqueeze(0)  # [1,N]
    b2y1 = box2[:, 1].unsqueeze(0)
    b2x2 = box2[:, 2].unsqueeze(0)
    b2y2 = box2[:, 3].unsqueeze(0)

    coords = (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)
    unfused_ok = {
        'aten::size',
        'prim::BroadcastSizes',
        'aten::_size_if_not_equal',
    }

    s = self.checkScript(iou, coords)
    self.assertAllFused(s.graph_for(*coords), except_for=unfused_ok)

    with enable_profiling_mode(True):
        c = s(*coords)
        warmup_backward(c.sum(), list(coords))
        graph = backward_graph(s)
        self.assertAllFused(graph, except_for=unfused_ok)
def _test_reinforcement_learning(self, device, test_export_import=True):
    """Trace-check a small two-layer policy network on ``device``.

    The network maps a ``(1, 4)`` input through ``Linear(4, 128)``, ReLU and
    ``Linear(128, 2)``, and returns a softmax over dimension 1.
    ``test_export_import`` is forwarded to ``checkTrace`` as
    ``export_import``.
    """
    class Policy(nn.Module):
        def __init__(self):
            super(Policy, self).__init__()
            self.affine1 = nn.Linear(4, 128)
            self.affine2 = nn.Linear(128, 2)

        def forward(self, x):
            hidden = F.relu(self.affine1(x))
            return F.softmax(self.affine2(hidden), dim=1)

    with enable_profiling_mode():
        # Build the model before sampling the input, as the original did.
        policy = Policy().to(device)
        observation = torch.rand(1, 4, device=device)
        self.checkTrace(policy, (observation, ),
                        export_import=test_export_import)
def _test_vae(self, device, check_export_import=True, quantized=False):
    """Trace-check a small variational autoencoder on ``device``.

    With ``quantized=True`` the linear modules are quantized first and the
    trace check runs without export/import and without input gradients.
    Otherwise the model is trace-checked under profiling mode and
    ``check_export_import`` is forwarded as ``export_import``.
    """
    class VAE(nn.Module):
        def __init__(self):
            super(VAE, self).__init__()
            self.fc1 = nn.Linear(784, 400)
            self.fc21 = nn.Linear(400, 20)
            self.fc22 = nn.Linear(400, 20)
            self.fc3 = nn.Linear(20, 400)
            self.fc4 = nn.Linear(400, 784)

        def encode(self, x):
            hidden = F.relu(self.fc1(x))
            return self.fc21(hidden), self.fc22(hidden)

        def reparameterize(self, mu, logvar):
            # Outside training mode the mean is used directly.
            if not self.training:
                return mu
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)

        def decode(self, z):
            hidden = F.relu(self.fc3(z))
            return torch.sigmoid(self.fc4(hidden))

        def forward(self, x):
            flat = x.view(-1, 784)
            mu, logvar = self.encode(flat)
            latent = self.reparameterize(mu, logvar)
            return self.decode(latent), mu, logvar

    if quantized:
        vae = VAE().to(device).eval()
        torch.jit.quantized.quantize_linear_modules(vae)
        # Export/import checks are skipped here because they would require
        # calling _unpack and _pack.
        batch = torch.rand(128, 1, 28, 28, device=device)
        self.checkTrace(vae, (batch, ), export_import=False,
                        allow_unused=True, inputs_require_grads=False)
        return

    with enable_profiling_mode():
        # eval() is present because randn_like makes this nondeterministic
        vae = VAE().to(device).eval()
        batch = torch.rand(128, 1, 28, 28, device=device)
        self.checkTrace(vae, (batch, ), export_import=check_export_import)
def test_lstm_cuda(self):
    """Run ``LSTMCellS`` through ``checkScript`` with CUDA training inputs.

    NOTE(review): the bare ``return`` right after ``checkScript`` makes every
    statement below it unreachable, so none of the fusion-group,
    DifferentiableGraph or backward-fusion assertions currently execute.
    This looks like a deliberate way of disabling those checks -- TODO
    confirm, and either restore them or replace the dead code with an
    explicit skip.
    """
    inputs = get_lstm_inputs('cuda', training=True)
    module = self.checkScript(LSTMCellS, inputs)
    # Early exit: everything after this line is dead code (see docstring).
    return
    forward_graph = module.graph_for(*inputs)
    # Expect exactly one fusion group, counting nested subgraphs.
    self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
    # Once profiling nodes are stripped, exactly two nodes should remain.
    self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2)
    # Everything is differentiable but TupleConstruct return
    FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
        .check_next("return").run(str(forward_graph))

    with enable_profiling_mode(True):
        hy, cy = module(*inputs)
        warmup_backward((hy + cy).sum())
        backward = backward_graph(module)
    # The listed ops are permitted to stay outside the fused backward graph.
    self.assertAllFused(backward, except_for=("aten::t", "aten::mm", "aten::_grad_sum_to_size"))
def do_detection_test(name):
    """Script a detection model and re-run it once per bailout point.

    The model ``models.detection.<name>`` is built with 50 classes, run once
    eagerly and twice scripted, and then executed again after requesting
    each bailout reported by its execution plan. Outputs are not compared.
    Finally an empty marker file ``detection_<name>`` is created (or left
    untouched if it already exists).
    """
    with enable_profiling_mode():
        torch.manual_seed(0)
        print("testing ", name)

        build_model = models.detection.__dict__[name]
        detector = build_model(num_classes=50, pretrained_backbone=False)
        detector.eval()

        image = torch.rand((3, 300, 300))
        batch = [image]
        eager_out = detector(batch)

        scripted = torch.jit.script(detector)
        # Two warm-up calls (profiling run, then optimized run).
        for _ in range(2):
            scripted(batch)

        plan = get_plan(scripted)
        bailout_count = plan.code.num_bailouts()
        print(bailout_count)

        for index in range(bailout_count):
            plan.code.request_bailout(index)
            bailout_out = scripted(batch)

        # Touch the marker file for this model.
        with open("detection_{}".format(name), 'a'):
            pass
def do_classification_test(model_name):
    """Script a classification model and re-run it once per bailout point.

    An empty marker file ``classification_<model_name>`` is touched first.
    The model ``models.<model_name>`` is built with 50 classes, scripted,
    run once eagerly and twice scripted, then executed again after
    requesting each bailout reported by its execution plan. The eager and
    optimized outputs are computed but not compared.
    """
    with enable_profiling_mode():
        # inception_v3 takes 299x299 inputs; every other model gets 224x224.
        if model_name in ['inception_v3']:
            input_shape = (1, 3, 299, 299)
        else:
            input_shape = (1, 3, 224, 224)

        print("testing ", model_name)
        with open("classification_{}".format(model_name), 'a'):
            pass

        torch.manual_seed(0)
        classifier = models.__dict__[model_name](num_classes=50)
        scripted = torch.jit.script(classifier)
        scripted.eval()

        batch = torch.rand(input_shape)
        eager_out = classifier(batch)
        scripted(batch)                  # profiling run
        optimized_out = scripted(batch)  # optimized run

        plan = get_plan(scripted)
        bailout_count = plan.code.num_bailouts()
        print(bailout_count)

        for index in range(bailout_count):
            plan.code.request_bailout(index)
            bailout_out = scripted(batch)
def do_segmentation_test(model_name):
    """Script a segmentation model and re-run it once per bailout point.

    An empty marker file ``segmentation_<model_name>`` is touched, the model
    ``models.segmentation.<model_name>`` is built and scripted, run twice,
    and then executed again after requesting each bailout reported by its
    execution plan. Only the scripted model is executed; no outputs are
    compared.
    """
    with enable_profiling_mode():
        print("testing ", model_name)
        torch.manual_seed(0)
        with open("segmentation_{}".format(model_name), 'a'):
            pass

        # Using a class count other than 1000 makes the test stricter.
        build_model = models.segmentation.__dict__[model_name]
        segmenter = build_model(num_classes=50, pretrained_backbone=False)
        segmenter.eval()
        scripted = torch.jit.script(segmenter)

        batch = torch.rand((1, 3, 300, 300))
        # Two warm-up calls (profiling run, then optimized run).
        for _ in range(2):
            scripted(batch)

        plan = get_plan(scripted)
        bailout_count = plan.code.num_bailouts()
        print(bailout_count)

        for index in range(bailout_count):
            plan.code.request_bailout(index)
            bailout_out = scripted(batch)
def _test_mnist(self, device, check_export_import=True):
    """Trace-check ``MnistNet`` on ``device`` with a ``(5, 1, 28, 28)`` input.

    ``check_export_import`` is forwarded to ``checkTrace`` as
    ``export_import``.
    """
    with enable_profiling_mode():
        # eval() is present because dropout makes this nondeterministic
        net = MnistNet().to(device).eval()
        digits = torch.rand(5, 1, 28, 28, device=device)
        self.checkTrace(net, (digits, ), export_import=check_export_import)
def checkScript(self, script, inputs, name='func', optimize=True, inputs_requires_grad=False, capture_output=False, frames_up=1, profiling=ProfilingMode.PROFILING):
    """Check that a script produces the same outputs as its Python version.

    Args:
        script: either TorchScript source (``str``) or a Python callable.
            A callable is first checked through the string frontend using
            its own source, then compiled with ``torch.jit.script``.
        inputs: positional arguments passed to both versions.
        name: name of the function to look up when ``script`` is a string.
        optimize: forwarded to ``torch.jit.optimized_execution``.
        inputs_requires_grad: when true, the scripted function receives
            detached copies of ``inputs`` with ``requires_grad`` set.
        capture_output: when true, stdout is captured for each run and the
            first scripted run's output is compared with the expected file
            (skipped on Windows).
        frames_up: how many frames up to look for the caller's variables.
        profiling: not read here; only forwarded to the recursive
            string-frontend check.

    Returns:
        The compiled function.
    """
    with torch.jit.optimized_execution(optimize):
        with enable_profiling_mode():
            if isinstance(script, str):
                # Compile the string to a Script function
                cu = torch.jit.CompilationUnit(script, _frames_up=frames_up)

                # Execute the Python function so we can run it later and get
                # its outputs
                frame = self.get_frame_vars(frames_up)
                the_locals = {}
                execWrapper(script, glob=frame, loc=the_locals)
                frame.update(the_locals)

                python_fn = frame[name]
                scripted_fn = getattr(cu, name)
            else:
                # Check the string frontend first. Every option is forwarded
                # by keyword: passing ``capture_output`` positionally used to
                # bind it to ``optimize`` and silently dropped ``optimize``,
                # ``inputs_requires_grad`` and ``capture_output``.
                source = textwrap.dedent(inspect.getsource(script))
                self.checkScript(
                    source,
                    inputs,
                    script.__name__,
                    optimize=optimize,
                    inputs_requires_grad=inputs_requires_grad,
                    capture_output=capture_output,
                    profiling=profiling,
                    frames_up=2)

                # Continue checking the Python frontend
                scripted_fn = torch.jit.script(script, _frames_up=1)
                python_fn = script

            if inputs_requires_grad:
                recording_inputs = do_input_map(
                    lambda t: t.detach().requires_grad_(), inputs)
            else:
                recording_inputs = inputs

            if capture_output:
                with self.capture_stdout() as script_stdout:
                    script_outputs = scripted_fn(*recording_inputs)
                with self.capture_stdout() as opt_script_stdout:
                    opt_script_outputs = scripted_fn(*recording_inputs)
                with self.capture_stdout() as _python_stdout:
                    python_outputs = python_fn(*inputs)
                if not IS_WINDOWS:
                    self.assertExpected(script_stdout[0], subname='stdout')
                self.assertEqual(python_outputs, opt_script_outputs)
            else:
                # profiling run
                script_outputs = scripted_fn(*recording_inputs)
                # optimized run
                opt_script_outputs = scripted_fn(*recording_inputs)
                python_outputs = python_fn(*inputs)

            self.assertEqual(python_outputs, script_outputs)
            self.assertEqual(script_outputs, opt_script_outputs)
            return scripted_fn