Example #1
    def test_clamp(self):
        def func2(a, b):
            return torch.clamp(a + b, min=0, max=2)

        def funcInf(a, b):
            return torch.clamp(a + b, min=0, max=float('inf'))

        def funcOptMin(a, b):
            return torch.clamp(a + b, max=2)

        def funcOptMax(a, b):
            return torch.clamp(a + b, min=0)

        a = torch.randn(4, 4, dtype=torch.float, device='cuda',
                        requires_grad=True)
        b = torch.randn(4, 4, dtype=torch.float, device='cuda')
        nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda')

        funcs = (func2, funcInf, funcOptMin, funcOptMax)
        for f, inputs in product(funcs, [[a, b], [a, nan]]):
            inp1, inp2 = inputs
            s = self.checkScript(f, (inp1, inp2),
                                 profiling=ProfilingMode.PROFILING)
            self.assertAllFused(
                s.graph_for(inp1, inp2),
                except_for={'aten::size', 'aten::_size_if_not_equal'})
            c = s(inp1, inp2)
            with enable_profiling_mode():
                warmup_backward(c.sum())
            graph = backward_graph(s)
            self.assertAllFused(
                graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'})
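The NaN input matters because torch.clamp propagates NaN rather than clamping it, so a fused kernel has to reproduce eager behaviour on that path. A minimal eager-mode sketch (standalone; the values are hypothetical):

import torch

x = torch.tensor([float('nan'), -1.0, 3.0])
print(torch.clamp(x, min=0, max=2))  # tensor([nan, 0., 2.])
print(torch.clamp(x, min=0))         # tensor([nan, 0., 3.])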
Example #2
    def checkScriptRaisesRegex(self,
                               script,
                               inputs,
                               exception,
                               regex,
                               outputs=None,
                               capture_output=False,
                               profiling=ProfilingMode.PROFILING):
        """
        Checks that a given function will throw the correct exception,
        when executed with normal python, the string frontend, and the AST frontend
        """

        with enable_profiling_mode():
            # normal python
            with self.assertRaisesRegex(exception, regex):
                script(*inputs)
            # string frontend
            with self.assertRaisesRegex(exception, regex):
                source = textwrap.dedent(inspect.getsource(script))
                cu = torch.jit.CompilationUnit(source)
                ge = getattr(cu, script.__name__)
                # profiling run
                with self.assertRaisesRegex(exception, regex):
                    ge(*inputs)
                # optimized run
                ge(*inputs)
            # python AST frontend
            with self.assertRaisesRegex(exception, regex):
                ge = torch.jit.script(script)
                # profiling run
                with self.assertRaisesRegex(exception, regex):
                    ge(*inputs)
                # optimized run
                ge(*inputs)
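A hypothetical usage sketch, assuming a JitTestCase-style class with this helper defined on it, and assuming the list-index error message contains "out of range":

from typing import List

import torch

class TestRaises(JitTestCase):  # JitTestCase assumed to be in scope
    def test_bad_index(self):
        def bad_index(x):
            # type: (List[int]) -> int
            return x[3]

        # The same IndexError should surface from eager Python, the
        # string frontend, and torch.jit.script.
        self.checkScriptRaisesRegex(bad_index, ([],), Exception,
                                    "out of range")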
Example #3
    def _perform_ad_subgraph_slicing(self, fn, *input_sizes):
        with disable_autodiff_subgraph_inlining():
            with enable_profiling_mode():
                ge = torch.jit.script(fn)
                inputs = [torch.randn(size, requires_grad=True)
                          for size in input_sizes]
                ge(*inputs, profile_and_replay=True)
                return ge.graph_for(*inputs)
Example #4
def do_test(model, bailout, print_diff):
    logging.basicConfig(filename='jit_' + model.replace('/', '-') +
                        str(int(time.time())) + '.log',
                        filemode='w',
                        level=logging.DEBUG)
    with enable_profiling_mode():
        logging.info("loading profiled %s", model)
        jm = torch.jit.load(model)
        # jm.eval()
        logging.info("running profiled %s", model)
        with freeze_rng_state():
            po = jm()
        logging.info("running profiled2 %s", model)
        with freeze_rng_state():
            po2 = jm()
        if not test_allclose(po, po2):
            logging.error("profiled and profiled2 outputs aren't equal")
            if print_diff:
                logging.error("po : %s", str(po))
                logging.error("po2 : %s", str(po2))
        logging.info("running optimized %s", model)
        with freeze_rng_state():
            jo = jm()
        if not test_allclose(po, jo):
            logging.error("profiled and optimized outputs aren't equal")
            if print_diff:
                logging.error("po : %s", str(po))
                logging.error("jo : %s", str(jo))
        plan = get_plan(jm)
        num_bailouts = plan.code.num_bailouts()
        logging.info("number of bailouts: %d", num_bailouts)
        if bailout:
            logging.info("triggering bailout %d", bailout)
            plan.code.request_bailout(bailout)
            with freeze_rng_state():
                bo = jm()
            if not test_allclose(bo, jo):
                logging.error("bailout %d and optimized outputs aren't equal",
                              bailout)
                if print_diff:
                    logging.error("bo : %s", str(bo))
                    logging.error("jo : %s", str(jo))
        else:
            for i in range(num_bailouts):
                logging.info("triggering bailout %d", i)
                plan.code.request_bailout(i)
                with freeze_rng_state():
                    bo = jm()
                if not test_allclose(bo, jo):
                    logging.error(
                        "bailout %d and optimized outputs aren't equal", i)
                    if print_diff:
                        logging.error("bo : %s", str(bo))
                        logging.error("jo : %s", str(jo))
Example #5
    def test_chunk_constant_script_ad(self):
        @torch.jit.script
        def func(x):
            x1, x2 = torch.chunk(x, 2)
            return (x1, x2)

        input = torch.rand(6, 10).requires_grad_()
        with disable_autodiff_subgraph_inlining():
            with enable_profiling_mode():
                output = func(input, profile_and_replay=True)
                self.assertAutodiffNode(func.graph_for(input), True,
                                        ['prim::ConstantChunk'], [])
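torch.chunk splits along dim 0 by default, so each half of the (6, 10) input above is (3, 10); a quick eager check:

import torch

x = torch.rand(6, 10)
x1, x2 = torch.chunk(x, 2)
print(x1.shape, x2.shape)  # torch.Size([3, 10]) torch.Size([3, 10])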
Example #6
    def test_fuser_iou(self):
        # This checks if most of Intersection over Union is fused.
        # In particular, the backward contains many _grad_sum_to_size.
        def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):
            ltx = torch.max(b1x1, b2x1)  # [N,M]
            lty = torch.max(b1y1, b2y1)
            rbx = torch.min(b1x2, b2x2)
            rby = torch.min(b1y2, b2y2)

            w = (rbx - ltx).clamp(min=0, max=float('inf'))  # [N,M]
            h = (rby - lty).clamp(min=0, max=float('inf'))  # [N,M]
            inter = w * h  # [N,M]

            area1 = (b1x2 - b1x1) * (b1y2 - b1y1)  # [N,1]
            area2 = (b2x2 - b2x1) * (b2y2 - b2y1)  # [1,M]
            iou = inter / (area1 + area2 - inter)
            return iou

        box1 = torch.randn(5, 4, requires_grad=True)
        box2 = torch.randn(5, 4, requires_grad=True)
        # unsqueeze cannot currently be fused
        b1x1 = box1[:, 0].unsqueeze(1)  # [N,1]
        b1y1 = box1[:, 1].unsqueeze(1)
        b1x2 = box1[:, 2].unsqueeze(1)
        b1y2 = box1[:, 3].unsqueeze(1)
        b2x1 = box2[:, 0].unsqueeze(0)  # [1,M]
        b2y1 = box2[:, 1].unsqueeze(0)
        b2x2 = box2[:, 2].unsqueeze(0)
        b2y2 = box2[:, 3].unsqueeze(0)

        s = self.checkScript(iou,
                             (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2))
        self.assertAllFused(s.graph_for(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1,
                                        b2x2, b2y2),
                            except_for={
                                'aten::size', 'prim::BroadcastSizes',
                                'aten::_size_if_not_equal'
                            })

        with enable_profiling_mode(True):
            c = s(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)
            warmup_backward(c.sum(),
                            [b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2])
            graph = backward_graph(s)
            self.assertAllFused(graph,
                                except_for={
                                    'aten::size', 'prim::BroadcastSizes',
                                    'aten::_size_if_not_equal'
                                })
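The iou() math can be sanity-checked in eager mode on two identical unit boxes (hypothetical values), where the intersection equals both areas and the IoU is exactly 1; this assumes the iou function above is in scope:

import torch

zero = torch.zeros(1, 1)
one = torch.ones(1, 1)
# box1 == box2 == the unit box (0, 0)-(1, 1), so inter == area1 == area2
print(iou(zero, zero, one, one, zero, zero, one, one))  # tensor([[1.]])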
Example #7
    def _test_reinforcement_learning(self, device, test_export_import=True):
        class Policy(nn.Module):
            def __init__(self):
                super(Policy, self).__init__()
                self.affine1 = nn.Linear(4, 128)
                self.affine2 = nn.Linear(128, 2)

            def forward(self, x):
                x = F.relu(self.affine1(x))
                action_scores = self.affine2(x)
                return F.softmax(action_scores, dim=1)

        with enable_profiling_mode():
            self.checkTrace(Policy().to(device),
                            (torch.rand(1, 4, device=device), ),
                            export_import=test_export_import)
Example #8
    def _test_vae(self, device, check_export_import=True, quantized=False):
        class VAE(nn.Module):
            def __init__(self):
                super(VAE, self).__init__()

                self.fc1 = nn.Linear(784, 400)
                self.fc21 = nn.Linear(400, 20)
                self.fc22 = nn.Linear(400, 20)
                self.fc3 = nn.Linear(20, 400)
                self.fc4 = nn.Linear(400, 784)

            def encode(self, x):
                h1 = F.relu(self.fc1(x))
                return self.fc21(h1), self.fc22(h1)

            def reparameterize(self, mu, logvar):
                if self.training:
                    std = torch.exp(0.5 * logvar)
                    eps = torch.randn_like(std)
                    return eps.mul(std).add_(mu)
                else:
                    return mu

            def decode(self, z):
                h3 = F.relu(self.fc3(z))
                return torch.sigmoid(self.fc4(h3))

            def forward(self, x):
                mu, logvar = self.encode(x.view(-1, 784))
                z = self.reparameterize(mu, logvar)
                return self.decode(z), mu, logvar

        if quantized:
            vae = VAE().to(device).eval()
            torch.jit.quantized.quantize_linear_modules(vae)
            # We don't do export/import checks because we would need to call
            # _unpack and _pack
            self.checkTrace(vae, (torch.rand(128, 1, 28, 28, device=device), ),
                            export_import=False,
                            allow_unused=True,
                            inputs_require_grads=False)
        else:
            with enable_profiling_mode():
                # eval() is present because randn_like makes this nondeterministic
                self.checkTrace(VAE().to(device).eval(),
                                (torch.rand(128, 1, 28, 28, device=device), ),
                                export_import=check_export_import)
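reparameterize implements the standard reparameterization trick, z = mu + eps * std with std = exp(0.5 * logvar) and eps ~ N(0, I), which keeps sampling differentiable with respect to mu and logvar. A standalone check that the sampled z matches the requested statistics:

import torch

mu = torch.full((100000,), 2.0)
logvar = torch.zeros(100000)            # logvar = 0  ->  std = 1
std = torch.exp(0.5 * logvar)
z = mu + torch.randn_like(std) * std
print(z.mean().item(), z.std().item())  # approximately 2.0 and 1.0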
Example #9
    def test_lstm_cuda(self):
        inputs = get_lstm_inputs('cuda', training=True)
        module = self.checkScript(LSTMCellS, inputs)
        forward_graph = module.graph_for(*inputs)
        self.assertGraphContainsExactly(forward_graph,
                                        'prim::FusionGroup',
                                        1,
                                        consider_subgraphs=True)
        self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2)
        # Everything is differentiable but TupleConstruct return
        FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
            .check_next("return").run(str(forward_graph))

        with enable_profiling_mode(True):
            hy, cy = module(*inputs)
            warmup_backward((hy + cy).sum())
            backward = backward_graph(module)
        self.assertAllFused(backward,
                            except_for=("aten::t", "aten::mm",
                                        "aten::_grad_sum_to_size"))
Example #10
def do_detection_test(name):
    with enable_profiling_mode():
        torch.manual_seed(0)
        print("testing ", name)
        model = models.detection.__dict__[name](num_classes=50,
                                                pretrained_backbone=False)
        model.eval()
        input_shape = (3, 300, 300)
        x = torch.rand(input_shape)
        model_input = [x]
        out = model(model_input)
        scripted_model = torch.jit.script(model)
        scripted_model(model_input)
        scripted_model(model_input)
        plan = get_plan(scripted_model)
        num_bailouts = plan.code.num_bailouts()
        print(num_bailouts)
        for i in range(num_bailouts):
            plan.code.request_bailout(i)
            bailout_output = scripted_model(model_input)
        open("detection_{}".format(name), 'a').close()
Example #11
def do_classification_test(model_name):
    with enable_profiling_mode():
        input_shape = (1, 3, 224, 224)
        if model_name in ['inception_v3']:
            input_shape = (1, 3, 299, 299)
        print("testing ", model_name)
        open("classification_{}".format(model_name), 'a').close()
        torch.manual_seed(0)
        model = models.__dict__[model_name](num_classes=50)
        scripted_model = torch.jit.script(model)
        scripted_model.eval()
        x = torch.rand(input_shape)
        py_output = model(x)
        scripted_model(x)
        opt_output = scripted_model(x)
        # assert torch.allclose(py_output, opt_output)
        plan = get_plan(scripted_model)
        num_bailouts = plan.code.num_bailouts()
        print(num_bailouts)
        for i in range(num_bailouts):
            plan.code.request_bailout(i)
            bailout_output = scripted_model(x)
Example #12
def do_segmentation_test(model_name):
    with enable_profiling_mode():
        # Passing a num_classes other than the default 1000 makes the test
        # more demanding.
        print("testing", model_name)
        torch.manual_seed(0)
        # Create an empty marker file for this model.
        open("segmentation_{}".format(model_name), 'a').close()
        model = models.segmentation.__dict__[model_name](
            num_classes=50, pretrained_backbone=False)
        model.eval()
        scripted_model = torch.jit.script(model)
        input_shape = (1, 3, 300, 300)
        x = torch.rand(input_shape)
        # out = model(x)
        scripted_model(x)
        scripted_model(x)
        plan = get_plan(scripted_model)
        num_bailouts = plan.code.num_bailouts()
        print(num_bailouts)
        for i in range(num_bailouts):
            plan.code.request_bailout(i)
            bailout_output = scripted_model(x)
Example #13
    def _test_mnist(self, device, check_export_import=True):
        # eval() is present because dropout makes this nondeterministic
        with enable_profiling_mode():
            self.checkTrace(MnistNet().to(device).eval(),
                            (torch.rand(5, 1, 28, 28, device=device), ),
                            export_import=check_export_import)
Example #14
    def checkScript(self,
                    script,
                    inputs,
                    name='func',
                    optimize=True,
                    inputs_requires_grad=False,
                    capture_output=False,
                    frames_up=1,
                    profiling=ProfilingMode.PROFILING):
        with torch.jit.optimized_execution(optimize):
            with enable_profiling_mode():
                if isinstance(script, str):
                    # Compile the string to a Script function
                    cu = torch.jit.CompilationUnit(script,
                                                   _frames_up=frames_up)

                    # Execute the Python function so we can run it later and get its
                    # outputs

                    frame = self.get_frame_vars(frames_up)
                    the_locals = {}
                    execWrapper(script, glob=frame, loc=the_locals)
                    frame.update(the_locals)

                    python_fn = frame[name]
                    scripted_fn = getattr(cu, name)
                else:
                    # Check the string frontend first
                    source = textwrap.dedent(inspect.getsource(script))
                    self.checkScript(source,
                                     inputs,
                                     script.__name__,
                                     capture_output=capture_output,
                                     profiling=profiling,
                                     frames_up=2)

                    # Continue checking the Python frontend
                    scripted_fn = torch.jit.script(script, _frames_up=1)
                    python_fn = script

                if inputs_requires_grad:
                    recording_inputs = do_input_map(
                        lambda t: t.detach().requires_grad_(), inputs)
                else:
                    recording_inputs = inputs

                if capture_output:
                    with self.capture_stdout() as script_stdout:
                        script_outputs = scripted_fn(*recording_inputs)
                    with self.capture_stdout() as opt_script_stdout:
                        opt_script_outputs = scripted_fn(*recording_inputs)
                    with self.capture_stdout() as _python_stdout:
                        python_outputs = python_fn(*inputs)
                    if not IS_WINDOWS:
                        self.assertExpected(script_stdout[0], subname='stdout')
                    self.assertEqual(python_outputs, opt_script_outputs)
                else:
                    # profiling run
                    script_outputs = scripted_fn(*recording_inputs)
                    # optimized run
                    opt_script_outputs = scripted_fn(*recording_inputs)
                    python_outputs = python_fn(*inputs)
                self.assertEqual(python_outputs, script_outputs)
                self.assertEqual(script_outputs, opt_script_outputs)
                return scripted_fn
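A hypothetical usage sketch of checkScript, assuming a JitTestCase-style class with this helper defined on it: the function is run through a profiling run and an optimized run, both outputs are compared against eager Python, and the ScriptFunction is returned for graph inspection:

import torch

class TestAddRelu(JitTestCase):  # JitTestCase assumed to be in scope
    def test_add_relu(self):
        def add_relu(a, b):
            return torch.relu(a + b)

        a, b = torch.randn(3), torch.randn(3)
        scripted = self.checkScript(add_relu, (a, b))
        # The returned ScriptFunction exposes the graph specialized
        # to these inputs.
        graph = scripted.graph_for(a, b)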