Example No. 1: test_bias_as_arg
    def test_bias_as_arg(self):

        with enable_profiling_mode_for_profiling_tests():

            def method1(x, weight, bias: Optional[torch.Tensor]):
                return torch.nn.functional.linear(x, weight, bias).relu() + 2

            N = 10
            x = torch.rand(N, N, requires_grad=True)
            weight = torch.rand(N, N, requires_grad=True)
            bias = None
            scripted = self.checkScript(method1, (x, weight, bias))
            # check_types requires last_graph on scripted to be set, so we just skip it
            check_against_reference(self,
                                    scripted,
                                    method1,
                                    lambda x: x, (x, weight, bias),
                                    check_types=False)
            bias = torch.rand(N, N, requires_grad=True)
            scripted = self.checkScript(method1, (x, weight, bias))
            # check_types requires last_graph on scripted to be set, so we just skip it
            check_against_reference(self,
                                    scripted,
                                    method1,
                                    lambda x: x, (x, weight, bias),
                                    check_types=False)
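
The Optional[torch.Tensor] annotation on bias is what lets a single scripted function accept both None and a real tensor. A minimal standalone sketch, independent of the test harness above (the function name linear_relu is made up for illustration):

from typing import Optional

import torch


@torch.jit.script
def linear_relu(x, weight, bias: Optional[torch.Tensor]):
    # same body as method1 above; unannotated x and weight default to Tensor in TorchScript
    return torch.nn.functional.linear(x, weight, bias).relu() + 2


x = torch.rand(10, 10)
weight = torch.rand(10, 10)
print(linear_relu(x, weight, None).shape)                # bias omitted
print(linear_relu(x, weight, torch.rand(10, 10)).shape)  # bias provided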
Example No. 2: test_constructed_bias
    def test_constructed_bias(self):

        with enable_profiling_mode_for_profiling_tests():
            def method1(x, weight, b1, b2):
                bias = b1 * b2
                return torch.nn.functional.linear(x, weight, bias)
            N = 10
            x = torch.rand(N, N, requires_grad=True)
            weight = torch.rand(N, N, requires_grad=True)
            b1 = torch.rand(N, N, requires_grad=True)
            b2 = torch.rand(N, N, requires_grad=True)
            scripted = self.checkScript(method1, (x, weight, b1, b2))
            # check_types requires last_graph on scripted to be set, so we just skip it
            check_against_reference(self, scripted, method1, lambda x: x, (x, weight, b1, b2), check_types=False)
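
For context, the point of this test is that the bias is constructed inside the function, so gradients have to flow back through the product b1 * b2. A quick eager-mode sketch of that behaviour (not part of the test file):

import torch

def method1(x, weight, b1, b2):
    bias = b1 * b2
    return torch.nn.functional.linear(x, weight, bias)

N = 10
x = torch.rand(N, N, requires_grad=True)
weight = torch.rand(N, N, requires_grad=True)
b1 = torch.rand(N, N, requires_grad=True)
b2 = torch.rand(N, N, requires_grad=True)

method1(x, weight, b1, b2).sum().backward()
print(b1.grad.shape, b2.grad.shape)  # both torch.Size([10, 10])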
Example No. 3: test_bias_as_module_attr
    def test_bias_as_module_attr(self):

        with enable_profiling_mode_for_profiling_tests():

            class M(torch.nn.Module):
                def __init__(self, has_bias):
                    super().__init__()
                    self.ll = torch.nn.Linear(10, 10, bias=has_bias)

                def forward(self, x, y):
                    return self.ll(x + y) * x + y

            x = torch.rand(10, 10, requires_grad=True)
            no_bias = M(False)
            scripted_no_bias = torch.jit.script(no_bias)
            scripted_no_bias(x, x)
            scripted_no_bias(x, x)
            scripted_no_bias(x, x)
            has_bias = M(True)
            check_against_reference(self,
                                    scripted_no_bias,
                                    no_bias,
                                    lambda x: x,
                                    (x, x),
                                    check_types=False)
            scripted_has_bias = torch.jit.script(has_bias)
            scripted_has_bias(x, x)
            scripted_has_bias(x, x)
            scripted_has_bias(x, x)
            check_against_reference(self,
                                    scripted_has_bias,
                                    has_bias,
                                    lambda x: x,
                                    (x, x),
                                    check_types=False)
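
Without the test helpers, the module-attribute check reduces to scripting the module in both configurations and comparing it against eager execution. A rough sketch of that idea (this is not the actual check_against_reference implementation):

import torch

class M(torch.nn.Module):
    def __init__(self, has_bias):
        super().__init__()
        self.ll = torch.nn.Linear(10, 10, bias=has_bias)

    def forward(self, x, y):
        return self.ll(x + y) * x + y

x = torch.rand(10, 10)
for has_bias in (False, True):
    eager = M(has_bias)
    scripted = torch.jit.script(eager)
    for _ in range(3):          # warm-up runs for the profiling executor
        scripted(x, x)
    torch.testing.assert_close(scripted(x, x), eager(x, x))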
Example No. 4: test_variant_consistency_jit
    def test_variant_consistency_jit(self, device, dtype, op):
        _requires_grad = op.supports_autograd and (
            dtype.is_floating_point or op.supports_complex_autograd)
        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)

        for sample in samples:
            # Acquires variants to test
            func = op.get_op()
            method = op.get_method()
            variants = {
                # TODO: inplace tests currently fail, fix and add inplace variant
                'function': func,
                'method': method,
            }

            # Test traced and scripted consistency
            for func_type, variant in variants.items():
                if variant is None:
                    continue

                # Create accessor for script function variant
                name = op.name + '_' if func_type == 'inplace' else op.name

                # run with disable_autodiff_subgraph_inlining(True) to test
                #   autodiff support. Context manager forces the graph to contain
                #   DifferentiableGraph nodes if they are present
                with disable_autodiff_subgraph_inlining():
                    # Check scripted forward, grad, and grad grad
                    script_fn = create_script_fn(self, name, func_type)

                    def out_fn(output):
                        # Processes the output for autograd
                        if sample.output_process_fn_grad is not None:
                            return sample.output_process_fn_grad(output)
                        return output

                    check_against_reference(self,
                                            script_fn,
                                            func,
                                            out_fn,
                                            (sample.input, ) + sample.args,
                                            sample.kwargs,
                                            no_grad=not _requires_grad)

                    # Check traced forward, grad, and grad grad
                    traced_fn = create_traced_fn(self, variant)
                    check_against_reference(self,
                                            traced_fn,
                                            func,
                                            out_fn,
                                            (sample.input, ) + sample.args,
                                            sample.kwargs,
                                            no_grad=not _requires_grad)

                    # Check alias annotation schema for correctness (make
                    #   sure inputs that aren't supposed to be modified aren't)
                    # Note: only runs in float32 and int32 because schema isn't affected by dtype,
                    #   so running it on all dtypes would be excessive
                    if dtype in [torch.float32, torch.int32]:
                        check_alias_annotation(name,
                                               (sample.input, ) + sample.args,
                                               sample.kwargs,
                                               func_type=func_type,
                                               aten_name=op.aten_name)

                    # Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample
                    if dtype is torch.float32:
                        # Sandcastle doesn't fuse nodes
                        if IS_SANDCASTLE:
                            # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                            nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                            fusible_nodes = []
                        else:
                            nonfusible_nodes = op.autodiff_nonfusible_nodes
                            fusible_nodes = op.autodiff_fusible_nodes

                        self.assertAutodiffNode(traced_fn.last_graph,
                                                op.assert_autodiffed,
                                                nonfusible_nodes,
                                                fusible_nodes)
                        self.assertAutodiffNode(script_fn.last_graph,
                                                op.assert_autodiffed,
                                                nonfusible_nodes,
                                                fusible_nodes)
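
Here op.get_op() and op.get_method() come from the OpInfo machinery; for a concrete op such as torch.add, the 'function' and 'method' variants are simply torch.add and Tensor.add, and both can be scripted and checked against eager execution. A small illustration under that assumption:

import torch


@torch.jit.script
def function_variant(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.add(x, y)  # function variant


@torch.jit.script
def method_variant(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x.add(y)         # method variant


a, b = torch.rand(3, 3), torch.rand(3, 3)
torch.testing.assert_close(function_variant(a, b), a + b)
torch.testing.assert_close(method_variant(a, b), a + b)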
Example No. 5: test_variant_consistency_jit
    def test_variant_consistency_jit(self, device, dtype, op):
        samples = op.sample_inputs(device, dtype, requires_grad=True)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")

        for sample in samples:

            # Acquires variants to test
            method = op.get_method()
            inplace = op.get_inplace()

            # TODO: inplace tests currently fail, so the inplace variant is
            #   excluded from the variant list for now
            # variants = (v for v in (op, method, inplace) if v is not None)
            variants = (v for v in (op, method) if v is not None)

            # Test traced and scripted consistency
            for variant in variants:
                # Create accessor for script function variant
                if variant is op:
                    name = op.name
                    func_type = 'function'
                elif variant is method:
                    name = op.name
                    func_type = 'method'
                else:  # variant is inplace
                    assert variant is inplace
                    name = op.name + "_"
                    func_type = 'inplace'

                # run with disable_autodiff_subgraph_inlining(True) to test
                #   autodiff support. Context manager forces the graph to contain
                #   DifferentiableGraph nodes if they are present
                with disable_autodiff_subgraph_inlining():

                    def fn(*inputs, **kwargs):
                        attr = getattr(inputs[0], name)
                        output = attr(*inputs[1:], **kwargs)
                        return op.output_func(output)

                    # bfloat16 grad doesn't work for some operators
                    dtypes_to_grad_check = floating_and_complex_types_and(torch.half) \
                        if op.skip_bfloat16_grad else floating_and_complex_types_and(torch.half, torch.bfloat16)

                    # Check scripted forward, grad, and grad grad
                    script_fn = create_script_fn(self, name, func_type,
                                                 op.output_func)

                    check_against_reference(
                        self,
                        script_fn,
                        fn, (*sample.input, ) + sample.args,
                        sample.kwargs,
                        no_grad=(dtype not in dtypes_to_grad_check))

                    # Check traced forward, grad, and grad grad
                    traced_fn = create_traced_fn(self, variant)
                    check_against_reference(
                        self,
                        traced_fn,
                        fn, (*sample.input, ) + sample.args,
                        sample.kwargs,
                        no_grad=(dtype not in dtypes_to_grad_check))

                    # Check alias annotation schema for correctness (make
                    #   sure inputs that aren't supposed to be modified aren't)
                    # Note: only runs in float32 and int32 because schema isn't affected by dtype,
                    #   so running it on all dtypes would be excessive
                    if dtype in [torch.float32, torch.int32]:
                        check_alias_annotation(name,
                                               (*sample.input, ) + sample.args,
                                               sample.kwargs)

                    # Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample
                    if dtype is torch.float32:
                        # Sandcastle doesn't fuse nodes
                        if IS_SANDCASTLE:
                            # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                            nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                            fusible_nodes = []
                        else:
                            nonfusible_nodes = op.autodiff_nonfusible_nodes
                            fusible_nodes = op.autodiff_fusible_nodes

                        self.assertAutodiffNode(traced_fn.last_graph,
                                                op.assert_autodiffed,
                                                nonfusible_nodes,
                                                fusible_nodes)
                        self.assertAutodiffNode(script_fn.last_graph,
                                                op.assert_autodiffed,
                                                nonfusible_nodes,
                                                fusible_nodes)
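
The inner fn above dispatches the method variant by name on the first input via getattr. Outside the OpInfo harness the same pattern looks roughly like this (the helper name and the clamp example are made up for illustration):

import torch

def call_method_variant(name, *inputs, **kwargs):
    # resolve e.g. Tensor.clamp on the first tensor, then call it with the rest
    attr = getattr(inputs[0], name)
    return attr(*inputs[1:], **kwargs)

x = torch.rand(4)
out = call_method_variant('clamp', x, min=0.25, max=0.75)
print(out.min().item() >= 0.25, out.max().item() <= 0.75)  # True True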
Example No. 6: test_variant_consistency_jit
    def test_variant_consistency_jit(self, device, dtype, op):
        _requires_grad = op.supports_autograd and (
            dtype.is_floating_point
            or op.supports_complex_autograd(torch.device(device).type))

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
        samples = op.sample_inputs(
            device,
            dtype,
            requires_grad=_requires_grad,
            include_conjugated_inputs=include_conjugated_inputs)

        # Acquires variants to test
        func = op.get_op()
        method = op.get_method()
        variants = {
            # TODO: inplace tests currently fail, fix and add inplace variant
            'function': func,
            'method': method,
        }

        # TODO: find better way to standardize on op registration itself..
        has_fake_function = op.name in ["resize_", "resize_as_"]

        if has_fake_function:
            variants = {'method': getattr(torch.Tensor, op.name)}
            samples = op.sample_inputs(device, dtype, requires_grad=False)

        support_script = op.supports_scripting

        tested = False
        for sample in samples:
            # Test traced and scripted consistency
            for func_type, variant in variants.items():
                if variant is None:
                    continue

                # scripting and check_alias_analysis do not work with lambdas
                # lambdas are typically used as a way to simulate methods without
                # functional variants, so rely on the other variant for testing
                # for now
                if is_lambda(variant):
                    continue

                tested = True

                # Create accessor for script function variant
                name = op.name + '_' if func_type == 'inplace' else op.name

                # run with disable_autodiff_subgraph_inlining(True) to test
                #   autodiff support. Context manager forces the graph to contain
                #   DifferentiableGraph nodes if they are present
                with disable_autodiff_subgraph_inlining():
                    # Check scripted forward, grad, and grad grad
                    if support_script:
                        script_fn = create_script_fn(self, name, func_type)

                    def out_fn(output):
                        # Processes the output for autograd
                        if sample.output_process_fn_grad is not None:
                            return sample.output_process_fn_grad(output)
                        return output

                    def get_sample():
                        return clone_input_helper(
                            sample.input
                        ) if op.name[-1] == '_' else sample.input

                    if support_script:
                        check_against_reference(
                            self,
                            script_fn,
                            func,
                            out_fn, (get_sample(), ) + sample.args,
                            sample.kwargs,
                            no_grad=not _requires_grad,
                            no_gradgrad=not op.supports_gradgrad)

                    # Check traced forward, grad, and grad grad
                    # TODO: fix tracing here
                    supports_tracing = not has_fake_function
                    if op.assert_jit_shape_analysis:
                        self.assertTrue(supports_tracing)

                    if supports_tracing:
                        traced_fn = create_traced_fn(self, variant)
                        check_against_reference(
                            self,
                            traced_fn,
                            func,
                            out_fn, (get_sample(), ) + sample.args,
                            sample.kwargs,
                            no_grad=not _requires_grad,
                            no_gradgrad=not op.supports_gradgrad)

                    # Check alias annotation schema for correctness (make
                    #   sure inputs that aren't supposed to be modified aren't)
                    # Note: only runs in float32 because schema isn't affected by dtype,
                    #   so running it on all dtypes would be excessive
                    if dtype == torch.float32:
                        # TODO: no reason why we can't run this with the tracing graph
                        if support_script and op.name != "rsub":
                            check_alias_annotation(name, (get_sample(), ) +
                                                   sample.args,
                                                   sample.kwargs,
                                                   func_type=func_type,
                                                   aten_name=op.aten_name)

                        # TODO: use script graph as well
                        checked_shape_analysis = False
                        if supports_tracing:
                            out = variant(get_sample(), *sample.args,
                                          **sample.kwargs)

                            # right now, tuple of outputs and tensor output supported
                            # TODO: list of tensor outputs
                            tuple_of_tensors = isinstance(out, tuple) and all([
                                isinstance(elem, torch.Tensor) for elem in out
                            ])

                            if isinstance(out,
                                          torch.Tensor) or tuple_of_tensors:
                                if tuple_of_tensors:
                                    sizes = [elem.size() for elem in out]
                                else:
                                    sizes = out.size()
                                self.checkShapeAnalysis(
                                    sizes, traced_fn.graph,
                                    op.assert_jit_shape_analysis)
                                checked_shape_analysis = True
                        if op.assert_jit_shape_analysis:
                            self.assertTrue(checked_shape_analysis)

                    # Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample
                    if dtype is torch.float32:
                        # Sandcastle doesn't fuse nodes
                        if IS_SANDCASTLE:
                            # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                            nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                            fusible_nodes = []
                        else:
                            nonfusible_nodes = op.autodiff_nonfusible_nodes
                            fusible_nodes = op.autodiff_fusible_nodes

                        if supports_tracing:
                            self.assertAutodiffNode(traced_fn.last_graph,
                                                    op.assert_autodiffed,
                                                    nonfusible_nodes,
                                                    fusible_nodes)
                        if support_script:
                            self.assertAutodiffNode(script_fn.last_graph,
                                                    op.assert_autodiffed,
                                                    nonfusible_nodes,
                                                    fusible_nodes)
        assert tested, "JIT Test does not execute any logic"
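
The shape-analysis branch above has to cope with ops that return a single tensor as well as ops that return a tuple of tensors. A self-contained sketch of that size-collection logic, using torch.topk as an example of a tuple-returning op:

import torch

def collect_sizes(out):
    # mirrors the branch above: tuple-of-tensors vs. single tensor output
    tuple_of_tensors = isinstance(out, tuple) and all(
        isinstance(elem, torch.Tensor) for elem in out)
    if tuple_of_tensors:
        return [elem.size() for elem in out]
    if isinstance(out, torch.Tensor):
        return out.size()
    raise TypeError("only tensor or tuple-of-tensor outputs are handled")

x = torch.rand(5, 8)
print(collect_sizes(torch.relu(x)))       # torch.Size([5, 8])
print(collect_sizes(torch.topk(x, k=3)))  # [torch.Size([5, 3]), torch.Size([5, 3])]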
Example No. 7: indiv_variant_test_jit
    def indiv_variant_test_jit(self, device, dtype, op, sample, func_type,
                               variant, has_fake_function):
        _requires_grad = (dtype in op.supported_backward_dtypes(
            torch.device(device).type))
        support_script = op.supports_scripting
        # Create accessor for script function variant
        name = op.name + '_' if func_type == 'inplace' else op.name

        # run with disable_autodiff_subgraph_inlining(True) to test
        #   autodiff support. Context manager forces the graph to contain
        #   DifferentiableGraph nodes if they are present
        with disable_autodiff_subgraph_inlining():
            # Check scripted forward, grad, and grad grad
            if support_script:
                script_fn = create_script_fn(self, name, func_type)

            def out_fn(output):
                # Processes the output for autograd
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            def get_sample():
                return clone_input_helper(
                    sample.input) if op.name[-1] == '_' else sample.input

            if support_script:
                check_against_reference(self,
                                        script_fn,
                                        op.get_op(),
                                        out_fn, (get_sample(), ) + sample.args,
                                        sample.kwargs,
                                        no_grad=not _requires_grad,
                                        no_gradgrad=not op.supports_gradgrad)

            # Check traced forward, grad, and grad grad
            # TODO: fix tracing here
            supports_tracing = op.supports_tracing and not has_fake_function
            if op.assert_jit_shape_analysis:
                self.assertTrue(supports_tracing)

            if supports_tracing:
                traced_fn = create_traced_fn(self, variant)
                check_against_reference(self,
                                        traced_fn,
                                        op.get_op(),
                                        out_fn, (get_sample(), ) + sample.args,
                                        sample.kwargs,
                                        no_grad=not _requires_grad,
                                        no_gradgrad=not op.supports_gradgrad)

            # Check alias annotation schema for correctness (make
            #   sure inputs that aren't supposed to be modified aren't)
            # Note: only runs in float32 because schema isn't affected by dtype,
            #   so running it on all dtypes would be excessive
            if dtype == torch.float32:
                # TODO: no reason why we can't run this with the tracing graph
                if support_script and op.name != "rsub":
                    check_alias_annotation(name,
                                           (get_sample(), ) + sample.args,
                                           sample.kwargs,
                                           func_type=func_type,
                                           aten_name=op.aten_name)

                # TODO: use script graph as well
                checked_shape_analysis = False
                if supports_tracing:
                    out = variant(get_sample(), *sample.args, **sample.kwargs)

                    # right now, tuple of outputs and tensor output supported
                    # TODO: list of tensor outputs
                    tuple_of_tensors = isinstance(out, tuple) and all(
                        [isinstance(elem, torch.Tensor) for elem in out])

                    if isinstance(out, torch.Tensor) or tuple_of_tensors:
                        if tuple_of_tensors:
                            sizes = [elem.size() for elem in out]
                        else:
                            sizes = out.size()
                        self.checkShapeAnalysis(sizes, traced_fn.graph,
                                                op.assert_jit_shape_analysis)
                        checked_shape_analysis = True
                if op.assert_jit_shape_analysis:
                    self.assertTrue(checked_shape_analysis)

            # Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample
            if dtype is torch.float32:
                # Sandcastle doesn't fuse nodes
                if IS_SANDCASTLE:
                    # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                    nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                    fusible_nodes = []
                else:
                    nonfusible_nodes = op.autodiff_nonfusible_nodes
                    fusible_nodes = op.autodiff_fusible_nodes

                if supports_tracing:
                    self.assertAutodiffNode(traced_fn.last_graph,
                                            op.assert_autodiffed,
                                            nonfusible_nodes, fusible_nodes)
                if support_script:
                    self.assertAutodiffNode(script_fn.last_graph,
                                            op.assert_autodiffed,
                                            nonfusible_nodes, fusible_nodes)
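
For the tracing path, the helpers used above boil down to tracing the variant on the sample inputs and inspecting the resulting graph; the real create_traced_fn does more bookkeeping (e.g. recording last_graph and handling gradients), but the core idea is roughly:

import torch

def variant(x, weight):
    return torch.nn.functional.linear(x, weight)

example_inputs = (torch.rand(4, 4), torch.rand(4, 4))
traced = torch.jit.trace(variant, example_inputs)

# traced outputs should match eager outputs on the traced shapes
torch.testing.assert_close(traced(*example_inputs), variant(*example_inputs))
print(traced.graph)  # the traced IR that the autodiff/shape checks inspect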