Code example #1
    def test_clamp(self):
        def func2(a, b):
            return torch.clamp(a + b, min=0, max=2)

        def funcInf(a, b):
            return torch.clamp(a + b, min=0, max=float('inf'))

        def funcOptMin(a, b):
            return torch.clamp(a + b, max=2)

        def funcOptMax(a, b):
            return torch.clamp(a + b, min=0)

        a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
        b = torch.randn(4, 4, dtype=torch.float, device='cuda')
        nan = torch.tensor(float('nan'), dtype=torch.float, device='cuda')

        funcs = (func2, funcInf, funcOptMin, funcOptMax)
        for f, inputs in product(funcs, [[a, b], [a, nan]]):
            inp1, inp2 = inputs
            s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
            self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'})
            c = s(inp1, inp2)
            with enable_profiling_mode():
                warmup_backward(c.sum())
            graph = backward_graph(s)
            self.assertAllFused(graph, except_for={'aten::Float', 'aten::_grad_sum_to_size'})
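The nan input in the loop above matters because torch.clamp propagates NaN rather than clamping it, so the fused kernel must preserve that behavior. A minimal eager-mode illustration (not part of the test):

    import torch

    x = torch.tensor([float('nan'), -1.0, 3.0])
    # NaN passes through clamp; the finite values are clipped to [0, 2].
    print(torch.clamp(x, min=0, max=2))  # expected: tensor([nan, 0., 2.])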
Code example #2
File: jit_utils.py  Project: leohunt/pytorch
    def checkScriptRaisesRegex(self,
                               script,
                               inputs,
                               exception,
                               regex,
                               outputs=None,
                               capture_output=False,
                               profiling=ProfilingMode.PROFILING):
        """
        Checks that a given function will throw the correct exception
        when executed with normal Python, the string frontend, and the AST frontend.
        """

        with enable_profiling_mode():
            # normal python
            with self.assertRaisesRegex(exception, regex):
                script(*inputs)
            # string frontend
            with self.assertRaisesRegex(exception, regex):
                source = textwrap.dedent(inspect.getsource(script))
                cu = torch.jit.CompilationUnit(source)
                ge = getattr(cu, script.__name__)
                # profiling run
                with self.assertRaisesRegex(exception, regex):
                    ge(*inputs)
                # optimized run
                ge(*inputs)
            # python AST frontend
            with self.assertRaisesRegex(exception, regex):
                ge = torch.jit.script(script)
                # profiling run
                with self.assertRaisesRegex(exception, regex):
                    ge(*inputs)
                # optimized run
                ge(*inputs)
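A hedged usage sketch for this helper; the test class, the import path, and bad_index are assumptions for illustration, not from the source:

    import torch
    from jit_utils import JitTestCase  # import path is an assumption

    class ExceptionTests(JitTestCase):  # hypothetical test class
        def test_bad_index_raises(self):
            def bad_index(x):
                return x[5]  # out of bounds for a length-3 tensor

            # The same out-of-bounds error should surface in eager mode,
            # the string frontend, and the AST frontend.
            self.checkScriptRaisesRegex(bad_index, (torch.zeros(3),),
                                        Exception, "index")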
Code example #3
    def _perform_ad_subgraph_slicing(self, fn, *input_sizes):
        with disable_autodiff_subgraph_inlining():
            with enable_profiling_mode():
                ge = torch.jit.script(fn)
                inputs = [
                    torch.randn(size, requires_grad=True)
                    for size in input_sizes
                ]
                ge(*inputs, profile_and_replay=True)
                return ge.graph_for(*inputs)
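A hedged sketch of a call site; two_input_fn and the sizes are illustrative, not from the source:

    def two_input_fn(x, y):
        return x * y + x

    # Hypothetical call from inside the same test class: one size per input.
    graph = self._perform_ad_subgraph_slicing(two_input_fn, (4, 4), (4, 4))
    print(graph)  # inspect how the autodiff subgraphs were sliced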
Code example #4
    def test_chunk_constant_script_ad(self):
        @torch.jit.script
        def func(x):
            x1, x2 = torch.chunk(x, 2)
            return (x1, x2)

        input = torch.rand(6, 10).requires_grad_()
        with disable_autodiff_subgraph_inlining():
            with enable_profiling_mode():
                output = func(input, profile_and_replay=True)
                self.assertAutodiffNode(func.graph_for(input), True,
                                        ['prim::ConstantChunk'], [])
Code example #5
    def test_fuser_iou(self):
        # This checks if most of Intersection over Union is fused.
        # In particular, the backward contains many _grad_sum_to_size.
        def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):
            ltx = torch.max(b1x1, b2x1)  # [N,M]
            lty = torch.max(b1y1, b2y1)
            rbx = torch.min(b1x2, b2x2)
            rby = torch.min(b1y2, b2y2)

            w = (rbx - ltx).clamp(min=0, max=float('inf'))  # [N,M]
            h = (rby - lty).clamp(min=0, max=float('inf'))  # [N,M]
            inter = w * h  # [N,M]

            area1 = (b1x2 - b1x1) * (b1y2 - b1y1)  # [N,1]
            area2 = (b2x2 - b2x1) * (b2y2 - b2y1)  # [1,M]
            iou = inter / (area1 + area2 - inter)
            return iou

        box1 = torch.randn(5, 4, requires_grad=True)
        box2 = torch.randn(5, 4, requires_grad=True)
        # unsqueezing can currently not be fused
        b1x1 = box1[:, 0].unsqueeze(1)  # [N,1]
        b1y1 = box1[:, 1].unsqueeze(1)
        b1x2 = box1[:, 2].unsqueeze(1)
        b1y2 = box1[:, 3].unsqueeze(1)
        b2x1 = box2[:, 0].unsqueeze(0)  # [1,M]
        b2y1 = box2[:, 1].unsqueeze(0)
        b2x2 = box2[:, 2].unsqueeze(0)
        b2y2 = box2[:, 3].unsqueeze(0)

        s = self.checkScript(iou,
                             (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2))
        self.assertAllFused(s.graph_for(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1,
                                        b2x2, b2y2),
                            except_for={
                                'aten::size', 'prim::BroadcastSizes',
                                'aten::_size_if_not_equal'
                            })

        with enable_profiling_mode(True):
            c = s(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)
            warmup_backward(c.sum(),
                            [b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2])
            graph = backward_graph(s)
            self.assertAllFused(graph,
                                except_for={
                                    'aten::size', 'prim::BroadcastSizes',
                                    'aten::_size_if_not_equal'
                                })
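As an eager-mode sanity check of the iou formula itself (hand-picked boxes, not part of the test): identical unit boxes should give an IoU of 1 and disjoint boxes an IoU of 0.

    b1 = torch.tensor([[0.0, 0.0, 1.0, 1.0]])   # one unit box, N=1
    b2 = torch.tensor([[0.0, 0.0, 1.0, 1.0],    # identical box -> IoU 1
                       [2.0, 2.0, 3.0, 3.0]])   # disjoint box  -> IoU 0
    cols1 = [b1[:, i].unsqueeze(1) for i in range(4)]  # [N,1] columns
    cols2 = [b2[:, i].unsqueeze(0) for i in range(4)]  # [1,M] columns
    print(iou(*cols1, *cols2))  # expected: tensor([[1., 0.]])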
Code example #6
    def _test_reinforcement_learning(self, device, test_export_import=True):
        class Policy(nn.Module):
            def __init__(self):
                super(Policy, self).__init__()
                self.affine1 = nn.Linear(4, 128)
                self.affine2 = nn.Linear(128, 2)

            def forward(self, x):
                x = F.relu(self.affine1(x))
                action_scores = self.affine2(x)
                return F.softmax(action_scores, dim=1)

        with enable_profiling_mode():
            self.checkTrace(Policy().to(device), (torch.rand(1, 4, device=device),),
                            export_import=test_export_import)
Code example #7
    def _test_vae(self, device, check_export_import=True, quantized=False):
        class VAE(nn.Module):
            def __init__(self):
                super(VAE, self).__init__()

                self.fc1 = nn.Linear(784, 400)
                self.fc21 = nn.Linear(400, 20)
                self.fc22 = nn.Linear(400, 20)
                self.fc3 = nn.Linear(20, 400)
                self.fc4 = nn.Linear(400, 784)

            def encode(self, x):
                h1 = F.relu(self.fc1(x))
                return self.fc21(h1), self.fc22(h1)

            def reparameterize(self, mu, logvar):
                if self.training:
                    std = torch.exp(0.5 * logvar)
                    eps = torch.randn_like(std)
                    return eps.mul(std).add_(mu)
                else:
                    return mu

            def decode(self, z):
                h3 = F.relu(self.fc3(z))
                return torch.sigmoid(self.fc4(h3))

            def forward(self, x):
                mu, logvar = self.encode(x.view(-1, 784))
                z = self.reparameterize(mu, logvar)
                return self.decode(z), mu, logvar

        if quantized:
            vae = VAE().to(device).eval()
            torch.jit.quantized.quantize_linear_modules(vae)
            # We don't do export/import checks because we would need to call
            # _unpack and _pack
            self.checkTrace(vae, (torch.rand(128, 1, 28, 28, device=device), ),
                            export_import=False,
                            allow_unused=True,
                            inputs_require_grads=False)
        else:
            with enable_profiling_mode():
                # eval() is present because randn_like makes this nondeterministic
                self.checkTrace(VAE().to(device).eval(),
                                (torch.rand(128, 1, 28, 28, device=device), ),
                                export_import=check_export_import)
Code example #8
    def test_lstm_cuda(self):
        inputs = get_lstm_inputs('cuda', training=True)
        module = self.checkScript(LSTMCellS, inputs)
        forward_graph = module.graph_for(*inputs)
        self.assertGraphContainsExactly(
            forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
        self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2)
        # Everything is differentiable but TupleConstruct return
        FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
            .check_next("return").run(str(forward_graph))

        with enable_profiling_mode(True):
            hy, cy = module(*inputs)
            warmup_backward((hy + cy).sum())
            backward = backward_graph(module)
        self.assertAllFused(backward, except_for=("aten::t", "aten::mm",
                                                  "aten::_grad_sum_to_size"))
Code example #9
    def _test_mnist(self, device, check_export_import=True):
        # eval() is present because dropout makes this nondeterministic
        with enable_profiling_mode():
            self.checkTrace(MnistNet().to(device).eval(), (torch.rand(5, 1, 28, 28, device=device),),
                            export_import=check_export_import)
Code example #10
File: jit_utils.py  Project: leohunt/pytorch
    def checkScript(self,
                    script,
                    inputs,
                    name='func',
                    optimize=True,
                    inputs_requires_grad=False,
                    capture_output=False,
                    frames_up=1,
                    profiling=ProfilingMode.PROFILING):
        with torch.jit.optimized_execution(optimize):
            with enable_profiling_mode():
                if isinstance(script, str):
                    # Compile the string to a Script function
                    cu = torch.jit.CompilationUnit(script,
                                                   _frames_up=frames_up)

                    # Execute the Python function so we can run it later and get its
                    # outputs

                    frame = self.get_frame_vars(frames_up)
                    the_locals = {}
                    execWrapper(script, glob=frame, loc=the_locals)
                    frame.update(the_locals)

                    python_fn = frame[name]
                    scripted_fn = getattr(cu, name)
                else:
                    # Check the string frontend first
                    source = textwrap.dedent(inspect.getsource(script))
                    self.checkScript(source,
                                     inputs,
                                     script.__name__,
                                     capture_output=capture_output,
                                     profiling=profiling,
                                     frames_up=2)

                    # Continue checking the Python frontend
                    scripted_fn = torch.jit.script(script, _frames_up=1)
                    python_fn = script

                if inputs_requires_grad:
                    recording_inputs = do_input_map(
                        lambda t: t.detach().requires_grad_(), inputs)
                else:
                    recording_inputs = inputs

                if capture_output:
                    with self.capture_stdout() as script_stdout:
                        script_outputs = scripted_fn(*recording_inputs)
                    with self.capture_stdout() as opt_script_stdout:
                        opt_script_outputs = scripted_fn(*recording_inputs)
                    with self.capture_stdout() as _python_stdout:
                        python_outputs = python_fn(*inputs)
                    if not IS_WINDOWS:
                        self.assertExpected(script_stdout[0], subname='stdout')
                    self.assertEqual(python_outputs, opt_script_outputs)
                else:
                    # profiling run
                    script_outputs = scripted_fn(*recording_inputs)
                    # optimized run
                    opt_script_outputs = scripted_fn(*recording_inputs)
                    python_outputs = python_fn(*inputs)
                self.assertEqual(python_outputs, script_outputs)
                self.assertEqual(script_outputs, opt_script_outputs)
                return scripted_fn
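A hedged usage sketch for checkScript; the test class, the import path, and add_mul are assumptions for illustration:

    import torch
    from jit_utils import JitTestCase  # import path is an assumption

    class CheckScriptExample(JitTestCase):  # hypothetical test class
        def test_add_mul(self):
            def add_mul(x, y):
                return x * y + y

            # Compiles add_mul, runs a profiling pass and an optimized pass,
            # compares both against the eager result, and returns the
            # compiled function for further graph inspection.
            scripted = self.checkScript(add_mul, (torch.ones(2, 2),
                                                  torch.ones(2, 2)))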