예제 #1
0
    def checkScriptRaisesRegex(self,
                               script,
                               inputs,
                               exception,
                               regex,
                               outputs=None,
                               capture_output=False,
                               profiling=ProfilingMode.PROFILING):
        """
        Assert that calling ``script`` with ``inputs`` raises ``exception``
        with a message matching ``regex`` along three paths: plain Python,
        the string frontend (``torch.jit.CompilationUnit``) and the Python
        AST frontend (``torch.jit.script``).

        On both compiled paths the function is invoked twice (a profiling
        run, then an optimized run). The outer expectation also covers the
        compilation step itself, so a matching error raised while compiling
        satisfies it. ``outputs``, ``capture_output`` and ``profiling`` are
        accepted but not read by this body.
        """

        def raises():
            # A fresh context manager for every expected failure.
            return self.assertRaisesRegex(exception, regex)

        with enable_profiling_mode_for_profiling_tests():
            # 1) plain Python
            with raises():
                script(*inputs)

            # 2) string frontend: recompile the function from its source text
            with raises():
                unit = torch.jit.CompilationUnit(
                    textwrap.dedent(inspect.getsource(script)))
                from_source = getattr(unit, script.__name__)
                with raises():
                    from_source(*inputs)  # profiling run
                from_source(*inputs)  # optimized run, caught by the outer context

            # 3) Python AST frontend, same two-run pattern
            with raises():
                from_ast = torch.jit.script(script)
                with raises():
                    from_ast(*inputs)  # profiling run
                from_ast(*inputs)  # optimized run, caught by the outer context
예제 #2
0
    def test_clamp(self):
        # Checks that clamp applied to a sum is fully fused in the forward and
        # in the backward graph, for each min/max variant defined below and
        # for both a regular and a NaN second operand.
        def func2(a, b):
            return torch.clamp(a + b, min=0, max=2)

        def funcInf(a, b):
            return torch.clamp(a + b, min=0, max=float('inf'))

        def funcNegInf(a, b):
            return torch.clamp(a + b, min=float('-inf'), max=0)

        def funcOptMin(a, b):
            return torch.clamp(a + b, max=2)

        def funcOptMax(a, b):
            return torch.clamp(a + b, min=0)

        lhs = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
        rhs = torch.randn(4, 4, dtype=torch.float, device='cuda')
        nan_operand = torch.tensor(float('nan'), dtype=torch.float, device='cuda')

        variants = (func2, funcInf, funcNegInf, funcOptMin, funcOptMax)
        operand_pairs = ((lhs, rhs), (lhs, nan_operand))

        for fn, (first, second) in product(variants, operand_pairs):
            scripted = self.checkScript(fn, (first, second),
                                        profiling=ProfilingMode.PROFILING)

            # Forward graph: only the size bookkeeping ops may stay unfused.
            forward = scripted.graph_for(first, second)
            self.assertAllFused(
                forward,
                except_for={'aten::size', 'aten::_size_if_not_equal'})

            # Backward graph: run the scripted function once more and warm up
            # its backward pass before inspecting the graph.
            result = scripted(first, second)
            with enable_profiling_mode_for_profiling_tests():
                warmup_backward(result.sum())
            self.assertAllFused(
                backward_graph(scripted),
                except_for={'aten::Float', 'aten::_grad_sum_to_size'})
예제 #3
0
    def test_bias_as_arg(self):
        # Compares the scripted and eager versions of a linear+relu function
        # whose optional bias argument is first None and then a real tensor.

        with enable_profiling_mode_for_profiling_tests():

            def method1(x, weight, bias: Optional[torch.Tensor]):
                return torch.nn.functional.linear(x, weight, bias).relu() + 2

            N = 10
            x = torch.rand(N, N, requires_grad=True)
            weight = torch.rand(N, N, requires_grad=True)

            for with_bias in (False, True):
                # The bias tensor is only created on the second pass, i.e.
                # after the bias-less case has been fully checked.
                bias = torch.rand(N, N, requires_grad=True) if with_bias else None
                call_args = (x, weight, bias)
                scripted = self.checkScript(method1, call_args)
                # check_types requires last_graph on scripted to be set, so we just skip it
                check_against_reference(self,
                                        scripted,
                                        method1,
                                        lambda out: out,
                                        call_args,
                                        check_types=False)
예제 #4
0
 def _perform_ad_subgraph_slicing(self, fn, *input_sizes):
     """
     Script ``fn``, run it once on random inputs and return its graph.

     One ``torch.randn`` tensor with ``requires_grad=True`` is created per
     entry of ``input_sizes``. The scripted function is called with
     ``profile_and_replay=True`` and the result of ``graph_for`` on the same
     inputs is returned. Everything runs with autodiff subgraph inlining
     disabled and inside ``enable_profiling_mode_for_profiling_tests``.
     """
     with disable_autodiff_subgraph_inlining(), \
             enable_profiling_mode_for_profiling_tests():
         scripted = torch.jit.script(fn)
         sample_inputs = [
             torch.randn(shape, requires_grad=True) for shape in input_sizes
         ]
         scripted(*sample_inputs, profile_and_replay=True)
         return scripted.graph_for(*sample_inputs)
예제 #5
0
    def test_diff_graph_inline_threshold(self):
        # Checks the differentiable-graph inlining threshold: a function with
        # two differentiable nodes must keep exactly one
        # prim::DifferentiableGraph, a function with a single node must keep
        # none.
        with enable_profiling_mode_for_profiling_tests():
            NUM_RUNS = 1
            with num_profiled_runs(NUM_RUNS):

                #  two nodes: expected to stay grouped in a DifferentiableGraph
                #  see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/runtime/graph_executor_impl.h#L49
                @torch.jit.script
                def foo(x):
                    return torch.sigmoid(torch.sigmoid(x))

                #  a single node: expected NOT to produce a DifferentiableGraph
                @torch.jit.script
                def bar(x):
                    return torch.sigmoid(x)

                # Named `sample` rather than `input` to avoid shadowing the builtin.
                sample = torch.rand([4, 4], requires_grad=True)

                # Each function is called twice before its graph is inspected
                # (presumably one profiling run plus one optimized run, given
                # NUM_RUNS = 1 -- confirm against num_profiled_runs).
                foo(sample)
                foo(sample)

                bar(sample)
                bar(sample)

                # The debugging print of foo's graph was dropped: it only
                # added noise to the test output.
                self.assertGraphContainsExactly(foo.graph_for(sample),
                                                'prim::DifferentiableGraph', 1)
                self.assertGraphContainsExactly(bar.graph_for(sample),
                                                'prim::DifferentiableGraph', 0)
예제 #6
0
    def test_differentiable_graph_ops_requires_grad(self):
        # Checks that the outputs of a scripted function that contains a
        # prim::DifferentiableGraph carry the same requires_grad flags, dtypes
        # and values as the eager outputs.
        with_grad = torch.randn(8, 2, dtype=torch.float).requires_grad_()
        without_grad = torch.randn(8, 2, dtype=torch.float)

        def t(x: torch.Tensor, y: torch.Tensor):
            o = x + 1.0
            o1 = torch.relu(o)
            o = y + 1.5
            o2 = torch.relu(o)
            o3 = o1 + o2
            return o1, o2, o3

        with enable_profiling_mode_for_profiling_tests():

            scripted = torch.jit.script(t)
            # Two scripted runs before the first comparison; only the second
            # result is compared.
            scripted(with_grad, without_grad)
            jit_outputs = scripted(with_grad, without_grad)
            eager_outputs = t(with_grad, without_grad)

            graph = scripted.graph_for(with_grad, without_grad)
            FileCheck().check("prim::DifferentiableGraph").run(graph)

            # validate the differentiableGraphOps are marking proper requires_grad
            for expected, actual in zip(eager_outputs, jit_outputs):
                self.assertEqual(expected.requires_grad, actual.requires_grad)
                self.assertEqual(expected, actual)

            # one more run to trigger fusion
            jit_outputs = scripted(with_grad, without_grad)
            for expected, actual in zip(eager_outputs, jit_outputs):
                self.assertEqual(expected.dtype, actual.dtype)
                self.assertEqual(expected.requires_grad, actual.requires_grad)
                self.assertEqual(expected, actual)
예제 #7
0
    def checkScriptRaisesRegex(self,
                               script,
                               inputs,
                               exception,
                               regex,
                               name=None,
                               outputs=None,
                               capture_output=False,
                               frames_up=1,
                               profiling=ProfilingMode.PROFILING):
        """
        Checks that a given function will throw the correct exception,
        when executed with normal python, the string frontend, and the
        AST frontend. Logic taken from `checkScript` (see comments there
        for details)

        ``script`` is either a callable or a string of source code. For a
        string, the code is executed against the variables returned by
        ``self.get_frame_vars(frames_up)`` and the function to call is looked
        up under ``name``; the AST frontend check is skipped in that case.
        ``outputs``, ``capture_output`` and ``profiling`` are accepted but not
        read by this body.
        """
        with enable_profiling_mode_for_profiling_tests():
            # Normal Python
            with self.assertRaisesRegex(exception, regex):
                if isinstance(script, str):
                    # Execute the source so the function it defines becomes
                    # available in `frame` under `name`.
                    frame = self.get_frame_vars(frames_up)
                    the_locals: Dict[str, Any] = {}
                    execWrapper(script, glob=frame, loc=the_locals)
                    frame.update(the_locals)

                    python_fn = frame[name]
                else:
                    python_fn = script

                python_fn(*inputs)

            # String frontend
            # The outer context also covers compilation, so a matching error
            # raised while building the CompilationUnit satisfies it.
            with self.assertRaisesRegex(exception, regex):
                if isinstance(script, str):
                    cu = torch.jit.CompilationUnit(script,
                                                   _frames_up=frames_up)
                    string_frontend = getattr(cu, name)
                else:
                    source = textwrap.dedent(inspect.getsource(script))
                    cu = torch.jit.CompilationUnit(source,
                                                   _frames_up=frames_up)
                    string_frontend = getattr(cu, script.__name__)

                # profiling run
                with self.assertRaisesRegex(exception, regex):
                    string_frontend(*inputs)
                # optimized run
                # (its exception is the one caught by the outer context)
                string_frontend(*inputs)

            # Python AST frontend
            # Only possible when `script` is a real Python callable.
            if not isinstance(script, str):
                with self.assertRaisesRegex(exception, regex):
                    ge = torch.jit.script(python_fn)
                    # profiling run
                    with self.assertRaisesRegex(exception, regex):
                        ge(*inputs)
                    # optimized run
                    ge(*inputs)
예제 #8
0
    def test_chunk_constant_script_ad(self):
        # Checks that a scripted torch.chunk with a constant chunk count does
        # not end up inside a prim::DifferentiableGraph.
        @torch.jit.script
        def func(x):
            x1, x2 = torch.chunk(x, 2)
            return (x1, x2)

        sample = torch.rand(6, 10).requires_grad_()
        with disable_autodiff_subgraph_inlining(), \
                enable_profiling_mode_for_profiling_tests():
            # The result itself is not needed; the call only populates the
            # graph inspected below.
            func(sample, profile_and_replay=True)
            checker = FileCheck().check_not("prim::DifferentiableGraph")
            checker.run(func.graph_for(sample))
    def test_chunk_constant_script_ad(self):
        # Checks, via assertAutodiffNode, how prim::ConstantChunk is placed
        # relative to the autodiff subgraph of a scripted torch.chunk call.
        @torch.jit.script
        def func(x):
            x1, x2 = torch.chunk(x, 2)
            return (x1, x2)

        sample = torch.rand(6, 10).requires_grad_()
        with disable_autodiff_subgraph_inlining(), \
                enable_profiling_mode_for_profiling_tests():
            # The result itself is not needed; the call only populates the
            # graph inspected below.
            func(sample, profile_and_replay=True)
            scripted_graph = func.graph_for(sample)
            self.assertAutodiffNode(scripted_graph, True,
                                    ['prim::ConstantChunk'], [])
예제 #10
0
    def test_constructed_bias(self):
        # Compares the scripted and eager versions of a linear layer whose
        # bias is computed inside the function as a product of two inputs.

        with enable_profiling_mode_for_profiling_tests():
            def method1(x, weight, b1, b2):
                bias = b1 * b2
                return torch.nn.functional.linear(x, weight, bias)

            dim = 10
            # Four square inputs created in order: x, weight, b1, b2.
            call_args = tuple(
                torch.rand(dim, dim, requires_grad=True) for _ in range(4))

            scripted = self.checkScript(method1, call_args)
            # check_types requires last_graph on scripted to be set, so we just skip it
            check_against_reference(self,
                                    scripted,
                                    method1,
                                    lambda out: out,
                                    call_args,
                                    check_types=False)
예제 #11
0
    def test_fuser_iou(self):
        # This checks if most of Intersection over Union is fused.
        # In particular, the backward contains many _grad_sum_to_size.
        #
        # The box areas are width * height, i.e. (x2 - x1) * (y2 - y1). An
        # earlier form subtracted y2 from itself, which made both areas
        # identically zero.
        def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):
            ltx = torch.max(b1x1, b2x1)  # [N,M]
            lty = torch.max(b1y1, b2y1)
            rbx = torch.min(b1x2, b2x2)
            rby = torch.min(b1y2, b2y2)

            w = (rbx - ltx).clamp(min=0, max=float('inf'))  # [N,M]
            h = (rby - lty).clamp(min=0, max=float('inf'))  # [N,M]
            inter = w * h  # [N,M]

            area1 = (b1x2 - b1x1) * (b1y2 - b1y1)  # [N,1]
            area2 = (b2x2 - b2x1) * (b2y2 - b2y1)  # [1,M]
            iou = inter / (area1 + area2 - inter)
            return iou

        box1 = torch.randn(5, 4, requires_grad=True)
        box2 = torch.randn(5, 4, requires_grad=True)
        # unsqueezing can currently not be fused
        b1x1 = box1[:, 0].unsqueeze(1)  # [N,1]
        b1y1 = box1[:, 1].unsqueeze(1)
        b1x2 = box1[:, 2].unsqueeze(1)
        b1y2 = box1[:, 3].unsqueeze(1)
        b2x1 = box2[:, 0].unsqueeze(0)  # [1,N]
        b2y1 = box2[:, 1].unsqueeze(0)
        b2x2 = box2[:, 2].unsqueeze(0)
        b2y2 = box2[:, 3].unsqueeze(0)

        s = self.checkScript(iou,
                             (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2))
        # Forward graph: only the size/broadcast bookkeeping may stay unfused.
        self.assertAllFused(s.graph_for(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1,
                                        b2x2, b2y2),
                            except_for={
                                'aten::size', 'prim::BroadcastSizes',
                                'aten::_size_if_not_equal'
                            })

        with enable_profiling_mode_for_profiling_tests(True):
            # Backward graph: run forward, warm up the backward pass for all
            # eight inputs, then apply the same fusion check.
            c = s(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)
            warmup_backward(c.sum(),
                            [b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2])
            graph = backward_graph(s)
            self.assertAllFused(graph,
                                except_for={
                                    'aten::size', 'prim::BroadcastSizes',
                                    'aten::_size_if_not_equal'
                                })
예제 #12
0
    def _test_reinforcement_learning(self, device, test_export_import=True):
        """
        Trace-check a small two-layer policy network on ``device``.

        The module and a random ``(1, 4)`` input are handed to
        ``self.checkTrace``; ``test_export_import`` is forwarded as its
        ``export_import`` argument.
        """
        class Policy(nn.Module):
            def __init__(self):
                super().__init__()
                self.affine1 = nn.Linear(4, 128)
                self.affine2 = nn.Linear(128, 2)

            def forward(self, x):
                x = F.relu(self.affine1(x))
                action_scores = self.affine2(x)
                return F.softmax(action_scores, dim=1)

        with enable_profiling_mode_for_profiling_tests():
            policy = Policy().to(device)
            observation = torch.rand(1, 4, device=device)
            self.checkTrace(policy, (observation,),
                            export_import=test_export_import)
예제 #13
0
    def _test_vae(self, device, check_export_import=True, quantized=False):
        """
        Trace-check a small variational autoencoder on ``device``.

        With ``quantized`` set, the linear modules are quantized first and the
        trace check runs without export/import, with unused inputs allowed and
        without requiring input grads. Otherwise the plain model is checked
        inside ``enable_profiling_mode_for_profiling_tests`` and
        ``check_export_import`` is forwarded as ``export_import``.
        """
        class VAE(nn.Module):
            def __init__(self):
                super().__init__()

                self.fc1 = nn.Linear(784, 400)
                self.fc21 = nn.Linear(400, 20)
                self.fc22 = nn.Linear(400, 20)
                self.fc3 = nn.Linear(20, 400)
                self.fc4 = nn.Linear(400, 784)

            def encode(self, x):
                h1 = F.relu(self.fc1(x))
                return self.fc21(h1), self.fc22(h1)

            def reparameterize(self, mu, logvar):
                # Outside training the mean is used directly.
                if not self.training:
                    return mu
                std = torch.exp(0.5 * logvar)
                eps = torch.randn_like(std)
                return eps.mul(std).add_(mu)

            def decode(self, z):
                h3 = F.relu(self.fc3(z))
                return torch.sigmoid(self.fc4(h3))

            def forward(self, x):
                mu, logvar = self.encode(x.view(-1, 784))
                z = self.reparameterize(mu, logvar)
                return self.decode(z), mu, logvar

        if not quantized:
            with enable_profiling_mode_for_profiling_tests():
                # eval() is present because randn_like makes this nondeterministic
                model = VAE().to(device).eval()
                batch = torch.rand(128, 1, 28, 28, device=device)
                self.checkTrace(model, (batch, ),
                                export_import=check_export_import)
            return

        vae = VAE().to(device).eval()
        torch.jit.quantized.quantize_linear_modules(vae)
        batch = torch.rand(128, 1, 28, 28, device=device)
        # We don't do export/import checks because we would need to call
        # _unpack and _pack
        self.checkTrace(vae, (batch, ),
                        export_import=False,
                        allow_unused=True,
                        inputs_require_grads=False)
예제 #14
0
    def test_lstm_cuda(self):
        # Runs LSTMCellS through checkScript on CUDA training inputs.
        inputs = get_lstm_inputs('cuda', training=True)
        module = self.checkScript(LSTMCellS, inputs)
        # NOTE(review): this unconditional return makes every statement below
        # unreachable, so only the checkScript call above is exercised. The
        # reason is not visible here; presumably the fusion checks were
        # switched off on purpose -- confirm before removing either the
        # return or the dead code.
        return
        forward_graph = module.graph_for(*inputs)
        # Expect a single fusion group (subgraphs included) and exactly two
        # nodes once profiling nodes are stripped.
        self.assertGraphContainsExactly(
            forward_graph, FUSION_GROUP, 1, consider_subgraphs=True)
        self.assertTrue(len(strip_profiling_nodes(forward_graph.nodes())) == 2)
        # Everything is differentiable but TupleConstruct return
        FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
            .check_next("return").run(str(forward_graph))

        with enable_profiling_mode_for_profiling_tests(True):
            # Warm up the backward pass, then fetch the backward graph.
            hy, cy = module(*inputs)
            warmup_backward((hy + cy).sum())
            backward = backward_graph(module)
        self.assertAllFused(backward, except_for=("aten::t", "aten::mm",
                                                  "aten::_grad_sum_to_size"))
예제 #15
0
    def test_requires_grad_for_tensor_list(self):
        # Checks that a scripted function returning a tensor and a list of
        # tensors reports the same requires_grad flags as the eager function.
        # unittest assertions are used instead of bare `assert` so the checks
        # are not stripped when Python runs with -O.

        with enable_profiling_mode_for_profiling_tests():

            # output & var_list[0] should have requires_grad set to True
            def func(input0: torch.Tensor, input1: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
                var_list = [input0, input1]
                var = torch.cat(var_list)
                output = var + 1.0
                return output, var_list
            jit_f = torch.jit.script(func)
            input0 = torch.randn((2,), requires_grad=True)
            input1 = torch.randn((2,))
            output_ref = func(input0, input1)
            # The scripted function is run twice and checked after each run.
            for _ in range(2):
                output = jit_f(input0, input1)
                self.assertEqual(output_ref[0].requires_grad, output[0].requires_grad)
                self.assertEqual(output_ref[1][0].requires_grad, output[1][0].requires_grad)
                self.assertEqual(output_ref[1][1].requires_grad, output[1][1].requires_grad)
예제 #16
0
    def test_bias_as_module_attr(self):
        # Compares scripted and eager versions of a module wrapping
        # nn.Linear, once without and once with a bias.

        with enable_profiling_mode_for_profiling_tests():

            class M(torch.nn.Module):
                def __init__(self, has_bias):
                    super().__init__()
                    self.ll = torch.nn.Linear(10, 10, has_bias)

                def forward(self, x, y):
                    return self.ll(x + y) * x + y

            sample = torch.rand(10, 10, requires_grad=True)
            call_args = (sample, sample)

            eager_no_bias = M(False)
            jit_no_bias = torch.jit.script(eager_no_bias)
            # Three runs of the scripted module before comparing.
            for _ in range(3):
                jit_no_bias(*call_args)
            # The biased module is built here, before the first comparison.
            eager_with_bias = M(True)
            check_against_reference(self,
                                    jit_no_bias,
                                    eager_no_bias,
                                    lambda out: out,
                                    call_args,
                                    check_types=False)

            jit_with_bias = torch.jit.script(eager_with_bias)
            for _ in range(3):
                jit_with_bias(*call_args)
            check_against_reference(self,
                                    jit_with_bias,
                                    eager_with_bias,
                                    lambda out: out,
                                    call_args,
                                    check_types=False)
예제 #17
0
def check_against_reference(self, func, reference_func, output_func, args, kwargs=None,
                            allow_unused=True, check_types=True, no_grad=False):
    """Check that ``func`` matches ``reference_func`` in outputs and gradients.

    Both callables are run through ``self.runAndSaveRNG`` on clones of
    ``args`` (plus ``kwargs``). Three comparisons are made in order:

    1. outputs on inputs that do not require grad;
    2. outputs and first-order gradients on inputs that keep their
       ``requires_grad`` flag;
    3. outputs, first-order and second-order gradients (skipped when
       ``self._testMethodName`` is in ``nn_functional_single_grad``).

    ``output_func`` is applied to the raw outputs before they are summed for
    differentiation. ``allow_unused`` is forwarded to ``torch.autograd.grad``.
    ``check_types`` additionally calls ``check_output_types`` on the no-grad
    outputs of ``func``. ``no_grad`` stops after step 1.
    """
    kwargs = kwargs if kwargs else {}

    def allSum(vs):
        # Weighted sum of all non-None floating/complex outputs; the (i + 1)
        # factor gives each output a different weight.
        if isinstance(vs, torch.Tensor):
            vs = (vs,)
        return sum((i + 1) * v.sum()
                   for i, v in enumerate(vs)
                   if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16))

    def clone_inputs(requires_grad):
        # Detached copies of the tensor args (non-tensors are passed through),
        # plus the list of those copies that require grad.
        inputs = [
            arg.detach().clone().requires_grad_(requires_grad and arg.requires_grad)
            if isinstance(arg, torch.Tensor) else arg for arg in args
        ]
        return inputs, [input for input in inputs if isinstance(input, torch.Tensor) and input.requires_grad]

    nograd_inputs, nograd_tensors = clone_inputs(False)
    recording_inputs, recording_tensors = clone_inputs(True)

    # test no gradients case
    outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
    with enable_profiling_mode_for_profiling_tests():
        outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
    self.assertEqual(outputs, outputs_test)

    if check_types:
        check_output_types(self, func, outputs_test, nograd_inputs, kwargs)

    if no_grad:
        # skip grad tests
        return

    with enable_profiling_mode_for_profiling_tests():
        # test single grad case
        outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
        grads = torch.autograd.grad(allSum(outputs), recording_tensors,
                                    allow_unused=allow_unused)
        outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
        grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
                                         allow_unused=allow_unused)
        self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)
        # test the grad grad case
        if self._testMethodName in nn_functional_single_grad:
            return

        # Reference side: differentiate l1, then differentiate a scalar built
        # from those gradients to obtain second-order gradients.
        outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
        l1 = allSum(outputs)
        grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
                                    allow_unused=allow_unused)

        l2 = (allSum(grads) * l1)
        grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)
        # Tested side: same computation on freshly cloned inputs.
        recording_inputs, recording_tensors = clone_inputs(True)
        outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
        l1_test = allSum(outputs_test)
        grads_test = torch.autograd.grad(
            l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)

        l2_test = (allSum(grads_test) * l1_test)
        grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused)

        self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)
        for g2, g2_test in zip(grads2, grads2_test):
            # Pairs that are None on both sides are skipped.
            # NOTE(review): a pair with None on only one side is passed to
            # torch.allclose as-is -- confirm that is the intended failure mode.
            if g2 is None and g2_test is None:
                continue
            self.assertTrue(torch.allclose(g2, g2_test, atol=5e-4, rtol=1e-4))
예제 #18
0
def check_against_reference(self,
                            func,
                            reference_func,
                            output_func,
                            args,
                            kwargs=None,
                            allow_unused=True,
                            check_types=True,
                            no_grad=False,
                            no_gradgrad=False):
    """Verifies a function performs identically to some reference implementation.

    Commonly, this is used to verify that a JIT implementation
    (func) matches the behavior of the eager implementation
    (reference_func).

    Both callables are run through ``self.runAndSaveRNG`` on clones of
    ``args`` (plus ``kwargs``). ``output_func`` is applied to the raw outputs
    of both before they are summed for differentiation. Three comparisons are
    made in order:

    1. outputs on inputs that do not require grad (``no_grad`` stops here);
    2. outputs and first-order gradients;
    3. outputs, first-order and second-order gradients, skipped when
       ``no_gradgrad`` is set or ``self._testMethodName`` is in
       ``nn_functional_single_grad``.

    ``allow_unused`` is forwarded to ``torch.autograd.grad``;
    ``check_types`` additionally calls ``check_output_types`` on the no-grad
    outputs of ``func``.
    """
    kwargs = kwargs if kwargs else {}

    def allSum(vs):
        # Weighted sum of all non-None floating/complex outputs; the (i + 1)
        # factor gives each output a different weight.
        if isinstance(vs, torch.Tensor):
            vs = (vs, )
        return sum(
            (i + 1) * v.sum() for i, v in enumerate(vs)
            if v is not None and v.dtype in floating_and_complex_types_and(
                torch.half, torch.bfloat16))

    def clone_tensor(t, preserve_requires_grad):
        # Detached copy that requires grad only if asked to and `t` did.
        require_grad = preserve_requires_grad and t.requires_grad
        return t.detach().clone().requires_grad_(require_grad)

    def clone_inputs(preserve_requires_grad: bool):
        # Clones tensors and iterables of tensors; other args pass through.
        inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = []

        for arg in args:
            if isinstance(arg, torch.Tensor):
                inputs.append(clone_tensor(arg, preserve_requires_grad))
            elif is_iterable_of_tensors(arg):
                inputs.append(
                    [clone_tensor(t, preserve_requires_grad) for t in arg])
            else:
                inputs.append(arg)

        return inputs

    # Returns tensors in args that requires_grad, including tensors in TensorList args
    def get_recording_tensors(args):
        recording_tensors: List[torch.Tensor] = []

        for arg in args:
            if isinstance(arg, torch.Tensor) and arg.requires_grad:
                recording_tensors.append(arg)
            elif is_iterable_of_tensors(arg):
                recording_tensors.extend(filter(lambda t: t.requires_grad,
                                                arg))

        return recording_tensors

    # test no gradients case
    nograd_inputs = clone_inputs(preserve_requires_grad=False)
    outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
    with enable_profiling_mode_for_profiling_tests():
        outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
    self.assertEqual(outputs, outputs_test)

    if check_types:
        check_output_types(self, func, outputs_test, nograd_inputs, kwargs)

    if no_grad:
        # skip grad tests
        return

    with enable_profiling_mode_for_profiling_tests():
        # test single grad case
        recording_inputs = clone_inputs(preserve_requires_grad=True)
        recording_tensors = get_recording_tensors(recording_inputs)
        outputs = output_func(
            self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
        grads = torch.autograd.grad(allSum(outputs),
                                    recording_tensors,
                                    allow_unused=allow_unused)
        outputs_test = output_func(
            self.runAndSaveRNG(func, recording_inputs, kwargs))
        grads_test = torch.autograd.grad(allSum(outputs_test),
                                         recording_tensors,
                                         allow_unused=allow_unused)
        self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)
        # test the grad grad case
        if self._testMethodName in nn_functional_single_grad or no_gradgrad:
            return

        # Reference side: differentiate l1, then differentiate a scalar built
        # from those gradients to obtain second-order gradients.
        outputs = output_func(
            self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
        l1 = allSum(outputs)
        grads = torch.autograd.grad(l1,
                                    recording_tensors,
                                    create_graph=True,
                                    allow_unused=allow_unused)

        l2 = (allSum(grads) * l1)
        grads2 = torch.autograd.grad(l2,
                                     recording_tensors,
                                     allow_unused=allow_unused)
        # Tested side: same computation on freshly cloned inputs.
        recording_inputs = clone_inputs(preserve_requires_grad=True)
        recording_tensors = get_recording_tensors(recording_inputs)
        outputs_test = output_func(
            self.runAndSaveRNG(func, recording_inputs, kwargs))
        l1_test = allSum(outputs_test)
        grads_test = torch.autograd.grad(l1_test,
                                         recording_tensors,
                                         create_graph=True,
                                         allow_unused=allow_unused)

        l2_test = (allSum(grads_test) * l1_test)
        grads2_test = torch.autograd.grad(l2_test,
                                          recording_tensors,
                                          allow_unused=allow_unused)

        self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)
        for g2, g2_test in zip(grads2, grads2_test):
            # Pairs that are None on both sides are skipped; everything else
            # is compared with a loosened tolerance.
            if g2 is None and g2_test is None:
                continue
            self.assertEqual(g2, g2_test, atol=5e-4, rtol=1e-4)
예제 #19
0
    def test_aliased_outputs(self):
        # Parses hand-written IR graphs, runs the create-autodiff-subgraphs
        # pass on each (with threshold argument 1) and uses FileCheck to
        # verify which nodes end up inside a prim::DifferentiableGraph when
        # outputs alias each other.

        with enable_profiling_mode_for_profiling_tests():


            # Case 1: aliasing between relu and t
            # is within a DifferentiableGraph. It should be valid
            # to merge both relu and t in one graph
            input_str = """
    graph(%a : Tensor):
        %b : Tensor = aten::relu(%a)
        %2 : Tensor = aten::t(%b)
        return (%2)
    """

            graph = torch._C.parse_ir(input_str)
            torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
            FileCheck().check("with prim::DifferentiableGraph") \
                .check("aten::relu").check("aten::t") \
                .run(graph)

            # Case 2: aliasing between relu and split_with_sizes
            # are both outputs of a Diff graph. It should be invalid
            # to merge both split_with_sizes in relu in one graph
            # i.e. relu and split_with_sizes should be in different
            # differentiable graphs
            input_str = """
    graph(%a : Tensor):
        %b : Tensor = aten::relu(%a)
        %0 : int[] = prim::Constant[value=[2, 2, 1]]()
        %1 : int = prim::Constant[value=0]()
        %2 : Tensor[] = aten::split_with_sizes(%b, %0, %1)
        %3 : (Tensor[], Tensor[]) = prim::TupleConstruct(%b, %2)
        return (%3)
"""

            graph = torch._C.parse_ir(input_str)
            torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
            # After the subgraph header, relu must appear and
            # split_with_sizes must not follow it.
            FileCheck().check("Tensor = prim::DifferentiableGraph") \
                .check("with prim::DifferentiableGraph") \
                .check("Tensor = aten::relu") \
                .check_not("aten::split_with_sizes") \
                .run(graph)

            # Case 3: two aliased nodes in a graph.
            # Both `split_with_sizes` should be unfused
            input_str = """
    graph(%a : Tensor):
        %b : Tensor = aten::relu(%a)
        %s1 : int[] = prim::Constant[value=[2, 2, 1]]()
        %s2 : int[] = prim::Constant[value=[3, 1]]()
        %1 : int = prim::Constant[value=0]()
        %2 : Tensor[] = aten::split_with_sizes(%b, %s1, %1)
        %3 : Tensor[] = aten::split_with_sizes(%b, %s2, %1)
        %4 : (Tensor, Tensor[]) = prim::TupleConstruct(%b, %2, %3)
        return (%4)
"""

            graph = torch._C.parse_ir(input_str)
            torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
            FileCheck().check("Tensor = prim::DifferentiableGraph") \
                .check("with prim::DifferentiableGraph") \
                .check("Tensor = aten::relu") \
                .check_not("aten::split_with_sizes") \
                .run(graph)

            # Case 4: the aliased output has a descendant
            # Both should be unfused. Note, %3 comes before %2
            # to test that we unfuse in the reverse topo order
            input_str = """
    graph(%a : Tensor):
        %b : Tensor = aten::relu(%a)
        %0 : int[] = prim::Constant[value=[2, 2, 1]]()
        %1 : int = prim::Constant[value=0]()
        %2 : Tensor = aten::t(%b)
        %3 : Tensor = aten::gelu(%2)
        %4 : (Tensor, Tensor, Tensor[]) = prim::TupleConstruct(%b, %3, %2)
        return (%4)
"""

            graph = torch._C.parse_ir(input_str)
            torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
            FileCheck().check("Tensor = prim::DifferentiableGraph") \
                .check("with prim::DifferentiableGraph") \
                .check("Tensor = aten::relu") \
                .check_not("aten::t") \
                .run(graph)

            # Case 5: multiple aliased groups
            # Both should be unfused. Note, %3 comes before %2
            # to test that we unfuse in the reverse topo order
            input_str = """
    graph(%a : Tensor):
        %b : Tensor = aten::relu(%a)
        %c : Tensor = aten::abs(%a)
        %0 : int[] = prim::Constant[value=[2, 2, 1]]()
        %1 : int = prim::Constant[value=0]()
        %d : Tensor = aten::t(%c)
        %2 : Tensor = aten::t(%b)
        %3 : Tensor = aten::gelu(%2)
        %4 : (Tensor, Tensor, Tensor[]) = prim::TupleConstruct(%3, %2, %d, %b, %c, %b)
        return (%4)
"""

            graph = torch._C.parse_ir(input_str)
            torch._C._jit_pass_create_autodiff_subgraphs(graph, 1)
            FileCheck().check("Tensor = prim::DifferentiableGraph") \
                .check("with prim::DifferentiableGraph") \
                .check("Tensor = aten::relu") \
                .check_not("aten::t") \
                .run(graph)
예제 #20
0
    def checkScript(self,
                    script,
                    inputs,
                    name='func',
                    optimize=True,
                    inputs_requires_grad=False,
                    capture_output=False,
                    frames_up=1,
                    profiling=ProfilingMode.PROFILING):
        """
        Checks that a scripted function produces the same outputs as its
        plain Python counterpart when both are called with ``inputs``.

        ``script`` is either a source string (``name`` then selects the
        function defined by that source) or a Python callable. For a
        callable, its source is first checked through the string frontend
        by a recursive call, and the callable itself is then compiled with
        ``torch.jit.script``.

        The scripted function is invoked twice (a profiling run followed by
        an optimized run); both results are compared against the Python
        result. When ``inputs_requires_grad`` is set, the scripted function
        receives ``inputs`` mapped through ``detach().requires_grad_()``.
        When ``capture_output`` is set, stdout is captured around every call
        and, outside Windows, the first scripted run's captured output is
        passed to ``assertExpected``.

        ``profiling`` is only forwarded to the recursive call; it is not
        otherwise read here.

        Returns the scripted function.
        """
        with torch.jit.optimized_execution(optimize), \
                enable_profiling_mode_for_profiling_tests():
            if not isinstance(script, str):
                # A Python callable: validate the string frontend first by
                # recursing on the function's own source. The recursion adds
                # one stack frame, hence frames_up=2.
                fn_source = textwrap.dedent(inspect.getsource(script))
                self.checkScript(fn_source,
                                 inputs,
                                 script.__name__,
                                 optimize=optimize,
                                 inputs_requires_grad=inputs_requires_grad,
                                 capture_output=capture_output,
                                 profiling=profiling,
                                 frames_up=2)

                # Then exercise the Python (AST) frontend.
                scripted_fn = torch.jit.script(script, _frames_up=1)
                python_fn = script
            else:
                # Source string: compile it into a Script function ...
                compilation_unit = torch.jit.CompilationUnit(
                    script, _frames_up=frames_up)

                # ... and also exec it as ordinary Python, using the
                # variables of the frame `frames_up` levels up as globals,
                # so the reference function can be called later.
                scope = self.get_frame_vars(frames_up)
                defined = {}
                execWrapper(script, glob=scope, loc=defined)
                scope.update(defined)

                python_fn = scope[name]
                scripted_fn = getattr(compilation_unit, name)

            def _as_grad_leaf(t):
                return t.detach().requires_grad_()

            script_args = (do_input_map(_as_grad_leaf, inputs)
                           if inputs_requires_grad else inputs)

            if capture_output:
                with self.capture_stdout() as first_run_stdout:
                    first_run = scripted_fn(*script_args)
                with self.capture_stdout():
                    second_run = scripted_fn(*script_args)
                with self.capture_stdout():
                    reference = python_fn(*inputs)
                if not IS_WINDOWS:
                    self.assertExpected(first_run_stdout[0], subname='stdout')
                self.assertEqual(reference, second_run)
            else:
                # profiling run
                first_run = scripted_fn(*script_args)
                # optimized run
                second_run = scripted_fn(*script_args)
                if TEST_BAILOUTS:
                    self.checkBailouts(scripted_fn, inputs, second_run)
                reference = python_fn(*inputs)

            self.assertEqual(reference, first_run)
            self.assertEqual(first_run, second_run)
            return scripted_fn
예제 #21
0
 def _test_mnist(self, device, check_export_import=True):
     """
     Runs ``checkTrace`` on an eval-mode ``MnistNet`` moved to ``device``,
     using one random ``(5, 1, 28, 28)`` input created on that device.

     ``check_export_import`` is forwarded to ``checkTrace`` as
     ``export_import``.
     """
     with enable_profiling_mode_for_profiling_tests():
         # eval() is required: dropout would otherwise make the result
         # nondeterministic
         model = MnistNet().to(device).eval()
         example_inputs = (torch.rand(5, 1, 28, 28, device=device), )
         self.checkTrace(model,
                         example_inputs,
                         export_import=check_export_import)