Beispiel #1
0
 def _perform_ad_subgraph_slicing(self, fn, *input_sizes):
     """Script ``fn``, run it once on random inputs, and return its graph.

     One ``torch.randn`` tensor with ``requires_grad=True`` is created for
     each entry of ``input_sizes``. The scripted function is invoked with
     ``profile_and_replay=True`` while autodiff subgraph inlining is disabled
     and profiling mode is enabled; the value returned is
     ``graph_for(*inputs)`` of the scripted function.
     """
     with disable_autodiff_subgraph_inlining(), enable_profiling_mode_for_profiling_tests():
         scripted = torch.jit.script(fn)
         tensors = [
             torch.randn(shape, requires_grad=True) for shape in input_sizes
         ]
         # Run once so that a graph specialised for these inputs exists.
         scripted(*tensors, profile_and_replay=True)
         return scripted.graph_for(*tensors)
Beispiel #2
0
    def test_chunk_constant_script_ad(self):
        """Check that a scripted two-way ``torch.chunk`` yields no DifferentiableGraph.

        The scripted function is run once on a ``(6, 10)`` tensor that requires
        grad, then ``FileCheck`` verifies that the string
        ``prim::DifferentiableGraph`` does not occur in the graph for that input.
        """
        @torch.jit.script
        def func(x):
            x1, x2 = torch.chunk(x, 2)
            return (x1, x2)

        sample = torch.rand(6, 10).requires_grad_()
        with disable_autodiff_subgraph_inlining(), enable_profiling_mode_for_profiling_tests():
            # The call result itself is not inspected; only the graph is.
            func(sample, profile_and_replay=True)
            checker = FileCheck().check_not("prim::DifferentiableGraph")
            checker.run(func.graph_for(sample))
    def test_chunk_constant_script_ad(self):
        """Check autodiff handling of a scripted two-way ``torch.chunk``.

        The scripted function is run once on a ``(6, 10)`` tensor that requires
        grad, then the graph for that input is handed to
        ``assertAutodiffNode`` together with ``True``,
        ``['prim::ConstantChunk']`` and an empty list.
        """
        @torch.jit.script
        def func(x):
            x1, x2 = torch.chunk(x, 2)
            return (x1, x2)

        sample = torch.rand(6, 10).requires_grad_()
        with disable_autodiff_subgraph_inlining(), enable_profiling_mode_for_profiling_tests():
            # The call result itself is not inspected; only the graph is.
            func(sample, profile_and_replay=True)
            graph = func.graph_for(sample)
            self.assertAutodiffNode(graph, True, ['prim::ConstantChunk'], [])
Beispiel #4
0
    def test_variant_consistency_jit(self, device, dtype, op):
        """Compare scripted and traced variants of ``op`` with ``op.get_op()``.

        Every sample from ``op.sample_inputs`` is run through the function and
        method variants (a variant that is ``None`` is skipped). Each variant
        is checked with ``check_against_reference`` in scripted and in traced
        form. Alias annotations are checked for float32 and int32 only, and
        autodiff node placement is checked for float32 only.
        """
        wants_grad = op.supports_autograd and (
            dtype.is_floating_point or op.supports_complex_autograd)

        for sample in op.sample_inputs(device, dtype, requires_grad=wants_grad):
            # Acquires variants to test
            # TODO: inplace tests currently fail, fix and add inplace variant
            func = op.get_op()
            variants = {'function': func, 'method': op.get_method()}

            # Test traced and scripted consistency
            for func_type, variant in variants.items():
                if variant is None:
                    continue

                # Create accessor for script function variant
                if func_type == 'inplace':
                    name = op.name + '_'
                else:
                    name = op.name

                def process_output(output):
                    # Processes the output for autograd
                    grad_hook = sample.output_process_fn_grad
                    if grad_hook is None:
                        return output
                    return grad_hook(output)

                def compare_with_reference(jit_fn):
                    # Check forward, grad, and grad grad of `jit_fn` against `func`
                    check_against_reference(self,
                                            jit_fn,
                                            func,
                                            process_output,
                                            (sample.input, ) + sample.args,
                                            sample.kwargs,
                                            no_grad=not wants_grad)

                # run with disable_autodiff_subgraph_inlining(True) to test
                #   autodiff support. Context manager forces the graph to contain
                #   DifferentiableGraph nodes if they are present
                with disable_autodiff_subgraph_inlining():
                    # Scripted variant first, then the traced one
                    scripted = create_script_fn(self, name, func_type)
                    compare_with_reference(scripted)

                    traced = create_traced_fn(self, variant)
                    compare_with_reference(traced)

                    # Check alias annotation schema for correctness (make
                    #   sure inputs that aren't supposed to be modified aren't)
                    # Note: only runs in float32 and int32 because schema isn't
                    #   affected by dtype, so running it on all dtypes would be
                    #   excessive
                    if dtype in (torch.float32, torch.int32):
                        check_alias_annotation(name,
                                               (sample.input, ) + sample.args,
                                               sample.kwargs,
                                               func_type=func_type,
                                               aten_name=op.aten_name)

                    # Check autodifferentiation of nodes for traced and scripted
                    #   graphs, only need to check once per sample
                    if dtype is torch.float32:
                        nonfusible_nodes = op.autodiff_nonfusible_nodes
                        fusible_nodes = op.autodiff_fusible_nodes
                        if IS_SANDCASTLE:
                            # Sandcastle doesn't fuse nodes, so the nodes that are
                            # otherwise expected in FusionGroups inside the
                            # DifferentiableGraphs are treated as nonfusible
                            nonfusible_nodes = nonfusible_nodes + fusible_nodes
                            fusible_nodes = []

                        for jit_fn in (traced, scripted):
                            self.assertAutodiffNode(jit_fn.last_graph,
                                                    op.assert_autodiffed,
                                                    nonfusible_nodes,
                                                    fusible_nodes)
Beispiel #5
0
    def test_variant_consistency_jit(self, device, dtype, op):
        """Compare scripted and traced variants of ``op`` with a getattr-based reference.

        Sample inputs are requested with ``requires_grad=True``; the test is
        skipped when there are none. For every sample, the function and method
        variants are checked with ``check_against_reference`` in scripted and
        in traced form. Gradient checks are disabled (``no_grad=True``) when
        ``dtype`` is not among the dtypes selected by ``op.skip_bfloat16_grad``.
        Alias annotations are checked for float32 and int32 only, and autodiff
        node placement is checked for float32 only.
        """
        samples = op.sample_inputs(device, dtype, requires_grad=True)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")

        # bfloat16 grad doesn't work for some operators. The dtype set does not
        # depend on the sample or the variant, so it is computed once here.
        if op.skip_bfloat16_grad:
            dtypes_to_grad_check = floating_and_complex_types_and(torch.half)
        else:
            dtypes_to_grad_check = floating_and_complex_types_and(torch.half, torch.bfloat16)
        no_grad = dtype not in dtypes_to_grad_check

        for sample in samples:

            # Acquires variants to test
            method = op.get_method()
            inplace = op.get_inplace()

            # TODO: inplace tests currently fail; once they pass, add `inplace`
            #   to the candidates below
            variants = tuple(v for v in (op, method) if v is not None)

            # Test traced and scripted consistency
            for variant in variants:
                # Create accessor for script function variant
                if variant is op:
                    name = op.name
                    func_type = 'function'
                elif variant is method:
                    name = op.name
                    func_type = 'method'
                else:  # variant is inplace
                    assert variant is inplace
                    name = op.name + "_"
                    func_type = 'inplace'

                # run with disable_autodiff_subgraph_inlining(True) to test
                #   autodiff support. Context manager forces the graph to contain
                #   DifferentiableGraph nodes if they are present
                with disable_autodiff_subgraph_inlining():

                    def fn(*inputs, **kwargs):
                        # Reference: look `name` up on the first input and call
                        # it with the remaining inputs
                        attr = getattr(inputs[0], name)
                        output = attr(*inputs[1:], **kwargs)
                        return op.output_func(output)

                    # Check scripted forward, grad, and grad grad
                    script_fn = create_script_fn(self, name, func_type,
                                                 op.output_func)

                    check_against_reference(
                        self,
                        script_fn,
                        fn, (*sample.input, ) + sample.args,
                        sample.kwargs,
                        no_grad=no_grad)

                    # Check traced forward, grad, and grad grad
                    traced_fn = create_traced_fn(self, variant)
                    check_against_reference(
                        self,
                        traced_fn,
                        fn, (*sample.input, ) + sample.args,
                        sample.kwargs,
                        no_grad=no_grad)

                    # Check alias annotation schema for correctness (make
                    #   sure inputs that aren't supposed to be modified aren't)
                    # Note: only runs in float32 and int32 because schema isn't
                    #   affected by dtype, so running it on all dtypes would be
                    #   excessive
                    if dtype in [torch.float32, torch.int32]:
                        check_alias_annotation(name,
                                               (*sample.input, ) + sample.args,
                                               sample.kwargs)

                    # Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample
                    if dtype is torch.float32:
                        # Sandcastle doesn't fuse nodes
                        if IS_SANDCASTLE:
                            # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                            nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                            fusible_nodes = []
                        else:
                            nonfusible_nodes = op.autodiff_nonfusible_nodes
                            fusible_nodes = op.autodiff_fusible_nodes

                        self.assertAutodiffNode(traced_fn.last_graph,
                                                op.assert_autodiffed,
                                                nonfusible_nodes,
                                                fusible_nodes)
                        self.assertAutodiffNode(script_fn.last_graph,
                                                op.assert_autodiffed,
                                                nonfusible_nodes,
                                                fusible_nodes)
Beispiel #6
0
    def test_variant_consistency_jit(self, device, dtype, op):
        """Compare scripted and traced variants of ``op`` with ``op.get_op()``.

        For every sample and every usable variant (``None`` and lambda
        variants are skipped), the scripted function (when
        ``op.supports_scripting``) and the traced function (when the op is not
        one of the "fake function" ops) are checked with
        ``check_against_reference``. For float32 the alias annotations, the
        traced graph's shape analysis and the autodiff node placement are
        checked as well. The test fails if no sample/variant combination was
        exercised.
        """
        _requires_grad = op.supports_autograd and (
            dtype.is_floating_point
            or op.supports_complex_autograd(torch.device(device).type))

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
        samples = op.sample_inputs(
            device,
            dtype,
            requires_grad=_requires_grad,
            include_conjugated_inputs=include_conjugated_inputs)

        # Acquires variants to test
        func = op.get_op()
        method = op.get_method()
        variants = {
            # TODO: inplace tests currently fail, fix and add inplace variant
            'function': func,
            'method': method,
        }

        # TODO: find better way to standardize on op registration itself..
        has_fake_function = op.name in ["resize_", 'resize_as_']

        # For these ops only the torch.Tensor method is tested, and the samples
        # are regenerated without requires_grad.
        if has_fake_function:
            variants = {'method': getattr(torch.Tensor, op.name)}
            samples = op.sample_inputs(device, dtype, requires_grad=False)

        support_script = op.supports_scripting

        # Becomes True once at least one sample/variant pair is exercised;
        # asserted after the loops.
        tested = False
        for sample in samples:
            # Test traced and scripted consistency
            for func_type, variant in variants.items():
                if variant is None:
                    continue

                # scripting and check_alias_analysis do not work with lambdas
                # lambdas are typically used as a way to simulate methods without
                # functional variants, so rely on the other variant for testing
                # for now
                if is_lambda(variant):
                    continue

                tested = True

                # Create accessor for script function variant
                name = op.name + '_' if func_type == 'inplace' else op.name

                # run with disable_autodiff_subgraph_inlining(True) to test
                #   autodiff support. Context manager forces the graph to contain
                #   DifferentiableGraph nodes if they are present
                with disable_autodiff_subgraph_inlining():
                    # Check scripted forward, grad, and grad grad
                    # NOTE: script_fn is only bound when support_script is
                    #   true; every later use is guarded by the same flag.
                    if support_script:
                        script_fn = create_script_fn(self, name, func_type)

                    def out_fn(output):
                        # Processes the output for autograd
                        if sample.output_process_fn_grad is not None:
                            return sample.output_process_fn_grad(output)
                        return output

                    def get_sample():
                        # Ops whose name ends in '_' get their input passed
                        # through clone_input_helper on every call; other ops
                        # receive sample.input itself.
                        return clone_input_helper(
                            sample.input
                        ) if op.name[-1] == '_' else sample.input

                    if support_script:
                        check_against_reference(
                            self,
                            script_fn,
                            func,
                            out_fn, (get_sample(), ) + sample.args,
                            sample.kwargs,
                            no_grad=not _requires_grad,
                            no_gradgrad=not op.supports_gradgrad)

                    # Check traced forward, grad, and grad grad
                    # TODO: fix tracing here
                    supports_tracing = not has_fake_function
                    # An op that asserts JIT shape analysis must be traceable,
                    # because the shape check below runs on the traced graph.
                    if op.assert_jit_shape_analysis:
                        self.assertTrue(supports_tracing)

                    # NOTE: traced_fn is only bound when supports_tracing is
                    #   true; every later use is guarded by the same flag.
                    if supports_tracing:
                        traced_fn = create_traced_fn(self, variant)
                        check_against_reference(
                            self,
                            traced_fn,
                            func,
                            out_fn, (get_sample(), ) + sample.args,
                            sample.kwargs,
                            no_grad=not _requires_grad,
                            no_gradgrad=not op.supports_gradgrad)

                    # Check alias annotation schema for correctness (make
                    #   sure inputs that aren't supposed to be modified aren't)
                    # Note: only runs in float32 because schema isn't affected by dtype,
                    #   so running it on all dtypes would be excessive
                    if dtype == torch.float32:
                        # TODO: no reason why we can't run this with tracing graph
                        if support_script and op.name != "rsub":
                            check_alias_annotation(name, (get_sample(), ) +
                                                   sample.args,
                                                   sample.kwargs,
                                                   func_type=func_type,
                                                   aten_name=op.aten_name)

                        # TODO: use script graph as well
                        checked_shape_analysis = False
                        if supports_tracing:
                            # Run the variant directly to obtain the output
                            # sizes passed to checkShapeAnalysis.
                            out = variant(get_sample(), *sample.args,
                                          **sample.kwargs)

                            # right now, tuple of outputs and tensor output supported
                            # TODO: list of tensor outputs
                            tuple_of_tensors = isinstance(out, tuple) and all([
                                isinstance(elem, torch.Tensor) for elem in out
                            ])

                            if isinstance(out,
                                          torch.Tensor) or tuple_of_tensors:
                                if tuple_of_tensors:
                                    sizes = [elem.size() for elem in out]
                                else:
                                    sizes = out.size()
                                self.checkShapeAnalysis(
                                    sizes, traced_fn.graph,
                                    op.assert_jit_shape_analysis)
                                checked_shape_analysis = True
                        # Fail if the op asserts shape analysis but the output
                        # type was not one of the supported kinds above.
                        if op.assert_jit_shape_analysis:
                            self.assertTrue(checked_shape_analysis)

                    # Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample
                    if dtype is torch.float32:
                        # Sandcastle doesn't fuse nodes
                        if IS_SANDCASTLE:
                            # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                            nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                            fusible_nodes = []
                        else:
                            nonfusible_nodes = op.autodiff_nonfusible_nodes
                            fusible_nodes = op.autodiff_fusible_nodes

                        if supports_tracing:
                            self.assertAutodiffNode(traced_fn.last_graph,
                                                    op.assert_autodiffed,
                                                    nonfusible_nodes,
                                                    fusible_nodes)
                        if support_script:
                            self.assertAutodiffNode(script_fn.last_graph,
                                                    op.assert_autodiffed,
                                                    nonfusible_nodes,
                                                    fusible_nodes)
        # Guards against silently passing when every variant was skipped or
        # there were no samples.
        assert tested, "JIT Test does not execute any logic"
Beispiel #7
0
    def indiv_variant_test_jit(self, device, dtype, op, sample, func_type,
                               variant, has_fake_function):
        """Run the JIT consistency checks for one sample and one variant of ``op``.

        The scripted function (when ``op.supports_scripting``) and the traced
        function (when ``op.supports_tracing`` and ``has_fake_function`` is
        false) are compared with ``op.get_op()`` via
        ``check_against_reference``. Gradients are checked only when ``dtype``
        is among ``op.supported_backward_dtypes`` for the device type. For
        float32 the alias annotations, the traced graph's shape analysis and
        the autodiff node placement are checked as well.
        """
        device_type = torch.device(device).type
        needs_grad = dtype in op.supported_backward_dtypes(device_type)
        can_script = op.supports_scripting

        # Create accessor for script function variant
        if func_type == 'inplace':
            name = op.name + '_'
        else:
            name = op.name

        def process_output(output):
            # Processes the output for autograd
            grad_hook = sample.output_process_fn_grad
            if grad_hook is None:
                return output
            return grad_hook(output)

        def fresh_input():
            # Ops whose name ends in '_' get their input passed through
            # clone_input_helper on every call
            if op.name[-1] == '_':
                return clone_input_helper(sample.input)
            return sample.input

        def compare_with_reference(jit_fn):
            # Check forward, grad, and grad grad of `jit_fn` against op.get_op()
            check_against_reference(self,
                                    jit_fn,
                                    op.get_op(),
                                    process_output,
                                    (fresh_input(), ) + sample.args,
                                    sample.kwargs,
                                    no_grad=not needs_grad,
                                    no_gradgrad=not op.supports_gradgrad)

        # run with disable_autodiff_subgraph_inlining(True) to test
        #   autodiff support. Context manager forces the graph to contain
        #   DifferentiableGraph nodes if they are present
        with disable_autodiff_subgraph_inlining():
            # Scripted variant; `scripted` is only bound when can_script holds
            if can_script:
                scripted = create_script_fn(self, name, func_type)
                compare_with_reference(scripted)

            # Traced variant
            # TODO: fix tracing here
            can_trace = op.supports_tracing and not has_fake_function
            if op.assert_jit_shape_analysis:
                self.assertTrue(can_trace)

            # `traced` is only bound when can_trace holds
            if can_trace:
                traced = create_traced_fn(self, variant)
                compare_with_reference(traced)

            # Check alias annotation schema for correctness (make
            #   sure inputs that aren't supposed to be modified aren't)
            # Note: only runs in float32 because schema isn't affected by dtype,
            #   so running it on all dtypes would be excessive
            if dtype == torch.float32:
                # TODO: no reason why we can't run this with tracing graph
                if can_script and op.name != "rsub":
                    check_alias_annotation(name,
                                           (fresh_input(), ) + sample.args,
                                           sample.kwargs,
                                           func_type=func_type,
                                           aten_name=op.aten_name)

                # TODO: use script graph as well
                shape_checked = False
                if can_trace:
                    result = variant(fresh_input(), *sample.args, **sample.kwargs)

                    # right now, tuple of outputs and tensor output supported
                    # TODO: list of tensor outputs
                    is_tensor_tuple = isinstance(result, tuple) and all(
                        isinstance(item, torch.Tensor) for item in result)

                    if isinstance(result, torch.Tensor) or is_tensor_tuple:
                        sizes = ([item.size() for item in result]
                                 if is_tensor_tuple else result.size())
                        self.checkShapeAnalysis(sizes, traced.graph,
                                                op.assert_jit_shape_analysis)
                        shape_checked = True
                if op.assert_jit_shape_analysis:
                    self.assertTrue(shape_checked)

            # Check autodifferentiation of nodes for traced and scripted graphs,
            #   only need to check once per sample
            if dtype is torch.float32:
                nonfusible_nodes = op.autodiff_nonfusible_nodes
                fusible_nodes = op.autodiff_fusible_nodes
                if IS_SANDCASTLE:
                    # Sandcastle doesn't fuse nodes, so the nodes that are
                    # otherwise expected in FusionGroups inside the
                    # DifferentiableGraphs are treated as nonfusible
                    nonfusible_nodes = nonfusible_nodes + fusible_nodes
                    fusible_nodes = []

                if can_trace:
                    self.assertAutodiffNode(traced.last_graph,
                                            op.assert_autodiffed,
                                            nonfusible_nodes, fusible_nodes)
                if can_script:
                    self.assertAutodiffNode(scripted.last_graph,
                                            op.assert_autodiffed,
                                            nonfusible_nodes, fusible_nodes)