Example #1
    def custom_rules_test_base(self,
                               device,
                               dtype,
                               op,
                               allow_eager_fail=False):
        try:
            samples = op.sample_inputs(device, dtype, requires_grad=False)
            sample_input = first_sample(self, samples)
            input_args = [sample_input.input, *sample_input.args]
            expected_res = op(*input_args, **sample_input.kwargs)

        except Exception as e:
            if allow_eager_fail:
                return
            else:
                raise e

        func = op.get_op()
        traced_fn = create_traced_fn(self, func)

        # Have to run the traced function to actually generate the trace
        traced_fn(sample_input.input, *sample_input.args,
                  **sample_input.kwargs)

        # Run the Dtype Analysis
        graph = traced_fn.graph  # Note this is a cached graph
        input_tensors = [t for t in input_args if isinstance(t, torch.Tensor)]
        input_tensors += [
            v for v in sample_input.kwargs.values()
            if isinstance(v, torch.Tensor)
        ]
        self.prop_dtype_on_graph(graph, input_tensors)
        self.assert_output_dtype_equal(expected_res, graph)
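
The helper above relies on test-suite utilities (first_sample, create_traced_fn, prop_dtype_on_graph, assert_output_dtype_equal). As a minimal standalone sketch of the same idea, assuming only the public torch.jit API and a hypothetical toy op add_fn:

import torch

# Trace a toy function, run the eager and traced versions, and check that the
# traced result preserves the dtype eager mode produces (here the
# float32 + float64 -> float64 type promotion).
def add_fn(x, y):
    return x + y

x = torch.ones(3, dtype=torch.float32)
y = torch.ones(3, dtype=torch.float64)

eager_out = add_fn(x, y)
traced_fn = torch.jit.trace(add_fn, (x, y))  # running trace records the graph
traced_out = traced_fn(x, y)

assert eager_out.dtype == traced_out.dtype == torch.float64
print(traced_fn.graph)  # the cached graph a dtype analysis would inspect
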
Example #2
    def test_aliases(self):
        # Tests that op aliases are correctly normalized.
        # Does not check other properties such as correctness because
        # the common method registry is tested for those in test_jit.py.

        op_registry = {}
        for op in method_tests():
            op_registry[op[0]] = op

        for alias, mapping in op_alias_mappings.items():
            assert alias in op_registry, "Test not found for {} alias".format(alias)

            name, self_size, args, kwargs, output_process_fn = get_defaults(*op_registry[alias])

            def fn(*inputs, **kwargs):
                attr = getattr(inputs[0], name)
                output = attr(*inputs[1:], **kwargs)
                return output_process_fn(output)

            self_variable = create_input((self_size,))[0][0]
            args_variable, kwargs_variable = create_input(args, requires_grad=False, call_kwargs=kwargs)

            traced_fn = create_traced_fn(self, fn)
            inputs = (self_variable,) + args_variable
            traced_fn(*inputs, **kwargs)
            last_graph = traced_fn.last_graph
            FileCheck().check(mapping).check_not(alias).run(last_graph)

            script_fn = create_script_fn(self, name, 'method', output_process_fn)
            script_fn(*inputs, **kwargs)
            last_graph = script_fn.last_graph
            FileCheck().check(mapping).check_not(alias).run(last_graph)
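
The alias test depends on test_jit helpers (method_tests, op_alias_mappings, create_input, create_traced_fn, create_script_fn); the FileCheck pattern itself needs only the public API. A minimal sketch of that check()/check_not()/run() pattern on a traced graph (illustrating the mechanism, not the alias-normalization pass):

import torch
from torch.testing import FileCheck

def fn(x):
    return x.abs()

# Trace a trivial function and assert which aten ops do (and do not) appear
# in the recorded graph, mirroring the checks run on last_graph above.
traced = torch.jit.trace(fn, torch.randn(3))
FileCheck().check("aten::abs").check_not("aten::relu").run(str(traced.graph))
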
Example #3
    def test_variant_consistency_jit(self, device, dtype, op):
        _requires_grad = op.supports_autograd and (
            dtype.is_floating_point or op.supports_complex_autograd)
        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)

        for sample in samples:
            # Acquires variants to test
            func = op.get_op()
            method = op.get_method()
            variants = {
                # TODO: inplace tests currently fail, fix and add inplace variant
                'function': func,
                'method': method,
            }

            # Test traced and scripted consistency
            for func_type, variant in variants.items():
                if variant is None:
                    continue

                # Create accessor for script function variant
                name = op.name + '_' if func_type == 'inplace' else op.name

                # run with disable_autodiff_subgraph_inlining(True) to test
                #   autodiff support. Context manager forces the graph to contain
                #   DifferentiableGraph nodes if they are present
                with disable_autodiff_subgraph_inlining():
                    # Check scripted forward, grad, and grad grad
                    script_fn = create_script_fn(self, name, func_type)

                    def out_fn(output):
                        # Processes the output for autograd
                        if sample.output_process_fn_grad is not None:
                            return sample.output_process_fn_grad(output)
                        return output

                    check_against_reference(self,
                                            script_fn,
                                            func,
                                            out_fn,
                                            (sample.input, ) + sample.args,
                                            sample.kwargs,
                                            no_grad=not _requires_grad)

                    # Check traced forward, grad, and grad grad
                    traced_fn = create_traced_fn(self, variant)
                    check_against_reference(self,
                                            traced_fn,
                                            func,
                                            out_fn,
                                            (sample.input, ) + sample.args,
                                            sample.kwargs,
                                            no_grad=not _requires_grad)

                    # Check alias annotation schema for correctness (make
                    #   sure inputs that aren't supposed to be modified aren't)
                    # Note: only runs in float32 and int32 because the schema isn't affected by dtype,
                    #   so running it on all dtypes would be excessive
                    if dtype in [torch.float32, torch.int32]:
                        check_alias_annotation(name,
                                               (sample.input, ) + sample.args,
                                               sample.kwargs,
                                               func_type=func_type,
                                               aten_name=op.aten_name)

                    # Check autodifferentiation of nodes for traced and scripted graphs; this only needs to run once per sample
                    if dtype is torch.float32:
                        # Sandcastle doesn't fuse nodes
                        if IS_SANDCASTLE:
                            # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                            nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                            fusible_nodes = []
                        else:
                            nonfusible_nodes = op.autodiff_nonfusible_nodes
                            fusible_nodes = op.autodiff_fusible_nodes

                        self.assertAutodiffNode(traced_fn.last_graph,
                                                op.assert_autodiffed,
                                                nonfusible_nodes,
                                                fusible_nodes)
                        self.assertAutodiffNode(script_fn.last_graph,
                                                op.assert_autodiffed,
                                                nonfusible_nodes,
                                                fusible_nodes)
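
check_against_reference, create_script_fn, and disable_autodiff_subgraph_inlining come from torch.testing._internal; the consistency property they verify can be sketched by hand with public APIs only (a toy eager_fn stands in for the op under test):

import torch

def eager_fn(x):
    return (x * x).sum()

# The scripted variant must agree with eager mode in both the forward output
# and the first-order gradient.
script_fn = torch.jit.script(eager_fn)
x = torch.randn(4, requires_grad=True)

eager_out, script_out = eager_fn(x), script_fn(x)
assert torch.allclose(eager_out, script_out)

eager_grad, = torch.autograd.grad(eager_out, x)
script_grad, = torch.autograd.grad(script_out, x)
assert torch.allclose(eager_grad, script_grad)
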
Example #4
    def test_variant_consistency_jit(self, device, dtype, op):
        samples = op.sample_inputs(device, dtype, requires_grad=True)
        if len(samples) == 0:
            self.skipTest("Skipped! No sample inputs!")

        for sample in samples:

            # Acquires variants to test
            method = op.get_method()
            inplace = op.get_inplace()
            variants = (v for v in (method, inplace) if v is not None)

            # Adds function variant to variant list
            # TODO: inplace tests currently fail
            # variants = (v for v in (op, method, inplace) if v is not None)
            variants = (v for v in (op, method) if v is not None)

            # Test traced and scripted consistency
            for variant in variants:
                # Create accessor for script function variant
                if variant is op:
                    name = op.name
                    func_type = 'function'
                elif variant is method:
                    name = op.name
                    func_type = 'method'
                else:  # variant is inplace
                    assert variant is inplace
                    name = op.name + "_"
                    func_type = 'inplace'

                # run with disable_autodiff_subgraph_inlining(True) to test
                #   autodiff support. Context manager forces the graph to contain
                #   DifferentiableGraph nodes if they are present
                with disable_autodiff_subgraph_inlining():

                    def fn(*inputs, **kwargs):
                        attr = getattr(inputs[0], name)
                        output = attr(*inputs[1:], **kwargs)
                        return op.output_func(output)

                    # bfloat16 grad doesn't work for some operators
                    dtypes_to_grad_check = floating_and_complex_types_and(torch.half) \
                        if op.skip_bfloat16_grad else floating_and_complex_types_and(torch.half, torch.bfloat16)

                    # Check scripted forward, grad, and grad grad
                    script_fn = create_script_fn(self, name, func_type,
                                                 op.output_func)

                    check_against_reference(
                        self,
                        script_fn,
                        fn, (*sample.input, ) + sample.args,
                        sample.kwargs,
                        no_grad=(dtype not in dtypes_to_grad_check))

                    # Check traced forward, grad, and grad grad
                    traced_fn = create_traced_fn(self, variant)
                    check_against_reference(
                        self,
                        traced_fn,
                        fn, (*sample.input, ) + sample.args,
                        sample.kwargs,
                        no_grad=(dtype not in dtypes_to_grad_check))

                    # Check alias annotation schema for correctness (make
                    #   sure inputs that aren't supposed to be modified aren't)
                    # Note: only runs in float32 and int32 because the schema isn't affected by dtype,
                    #   so running it on all dtypes would be excessive
                    if dtype in [torch.float32, torch.int32]:
                        check_alias_annotation(name,
                                               (*sample.input, ) + sample.args,
                                               sample.kwargs)

                    # Check autodifferentiation of nodes for traced and scripted graphs; this only needs to run once per sample
                    if dtype is torch.float32:
                        # Sandcastle doesn't fuse nodes
                        if IS_SANDCASTLE:
                            # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                            nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                            fusible_nodes = []
                        else:
                            nonfusible_nodes = op.autodiff_nonfusible_nodes
                            fusible_nodes = op.autodiff_fusible_nodes

                        self.assertAutodiffNode(traced_fn.last_graph,
                                                op.assert_autodiffed,
                                                nonfusible_nodes,
                                                fusible_nodes)
                        self.assertAutodiffNode(script_fn.last_graph,
                                                op.assert_autodiffed,
                                                nonfusible_nodes,
                                                fusible_nodes)
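
For reference, the three variants the loop distinguishes, sketched with torch.add as an assumed concrete example, together with the getattr-based dispatch used inside fn():

import torch

x = torch.randn(3)
y = torch.randn(3)

func_out = torch.add(x, y)        # 'function' variant
method_out = x.add(y)             # 'method' variant
inplace_out = x.clone().add_(y)   # 'inplace' variant (trailing underscore)

assert torch.allclose(func_out, method_out)
assert torch.allclose(func_out, inplace_out)

# fn() above resolves the method variant by name at call time:
assert torch.allclose(getattr(x, "add")(y), func_out)
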
Example #5
    def test_variant_consistency_jit(self, device, dtype, op):
        _requires_grad = op.supports_autograd and (
            dtype.is_floating_point
            or op.supports_complex_autograd(torch.device(device).type))

        include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
        samples = op.sample_inputs(
            device,
            dtype,
            requires_grad=_requires_grad,
            include_conjugated_inputs=include_conjugated_inputs)

        # Acquires variants to test
        func = op.get_op()
        method = op.get_method()
        variants = {
            # TODO: inplace tests currently fail, fix and add inplace variant
            'function': func,
            'method': method,
        }

        # TODO: find a better way to standardize this on op registration itself
        has_fake_function = op.name in ["resize_", 'resize_as_']

        if has_fake_function:
            variants = {'method': getattr(torch.Tensor, op.name)}
            samples = op.sample_inputs(device, dtype, requires_grad=False)

        support_script = op.supports_scripting

        tested = False
        for sample in samples:
            # Test traced and scripted consistency
            for func_type, variant in variants.items():
                if variant is None:
                    continue

                # Scripting and check_alias_annotation do not work with lambdas.
                # Lambdas are typically used to simulate methods without
                # functional variants, so rely on the other variant for testing
                # for now.
                if is_lambda(variant):
                    continue

                tested = True

                # Create accessor for script function variant
                name = op.name + '_' if func_type == 'inplace' else op.name

                # run with disable_autodiff_subgraph_inlining(True) to test
                #   autodiff support. Context manager forces the graph to contain
                #   DifferentiableGraph nodes if they are present
                with disable_autodiff_subgraph_inlining():
                    # Check scripted forward, grad, and grad grad
                    if support_script:
                        script_fn = create_script_fn(self, name, func_type)

                    def out_fn(output):
                        # Processes the output for autograd
                        if sample.output_process_fn_grad is not None:
                            return sample.output_process_fn_grad(output)
                        return output

                    def get_sample():
                        return clone_input_helper(
                            sample.input
                        ) if op.name[-1] == '_' else sample.input

                    if support_script:
                        check_against_reference(
                            self,
                            script_fn,
                            func,
                            out_fn, (get_sample(), ) + sample.args,
                            sample.kwargs,
                            no_grad=not _requires_grad,
                            no_gradgrad=not op.supports_gradgrad)

                    # Check traced forward, grad, and grad grad
                    # TODO: fix tracing here
                    supports_tracing = not has_fake_function
                    if op.assert_jit_shape_analysis:
                        self.assertTrue(supports_tracing)

                    if supports_tracing:
                        traced_fn = create_traced_fn(self, variant)
                        check_against_reference(
                            self,
                            traced_fn,
                            func,
                            out_fn, (get_sample(), ) + sample.args,
                            sample.kwargs,
                            no_grad=not _requires_grad,
                            no_gradgrad=not op.supports_gradgrad)

                    # Check alias annotation schema for correctness (make
                    #   sure inputs that aren't supposed to be modified aren't)
                    # Note: only runs in float32 because the schema isn't affected by dtype,
                    #   so running it on all dtypes would be excessive
                    if dtype == torch.float32:
                        # TODO: no reason why we can't run this with the tracing graph
                        if support_script and op.name != "rsub":
                            check_alias_annotation(name, (get_sample(), ) +
                                                   sample.args,
                                                   sample.kwargs,
                                                   func_type=func_type,
                                                   aten_name=op.aten_name)

                        # TODO: use script graph as well
                        checked_shape_analysis = False
                        if supports_tracing:
                            out = variant(get_sample(), *sample.args,
                                          **sample.kwargs)

                            # right now, only a single Tensor output or a tuple of Tensor outputs is supported
                            # TODO: support lists of Tensor outputs
                            tuple_of_tensors = isinstance(out, tuple) and all([
                                isinstance(elem, torch.Tensor) for elem in out
                            ])

                            if isinstance(out,
                                          torch.Tensor) or tuple_of_tensors:
                                if tuple_of_tensors:
                                    sizes = [elem.size() for elem in out]
                                else:
                                    sizes = out.size()
                                self.checkShapeAnalysis(
                                    sizes, traced_fn.graph,
                                    op.assert_jit_shape_analysis)
                                checked_shape_analysis = True
                        if op.assert_jit_shape_analysis:
                            self.assertTrue(checked_shape_analysis)

                    # Check autodifferentiation of nodes for traced and scripted graphs; this only needs to run once per sample
                    if dtype is torch.float32:
                        # Sandcastle doesn't fuse nodes
                        if IS_SANDCASTLE:
                            # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                            nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                            fusible_nodes = []
                        else:
                            nonfusible_nodes = op.autodiff_nonfusible_nodes
                            fusible_nodes = op.autodiff_fusible_nodes

                        if supports_tracing:
                            self.assertAutodiffNode(traced_fn.last_graph,
                                                    op.assert_autodiffed,
                                                    nonfusible_nodes,
                                                    fusible_nodes)
                        if support_script:
                            self.assertAutodiffNode(script_fn.last_graph,
                                                    op.assert_autodiffed,
                                                    nonfusible_nodes,
                                                    fusible_nodes)
        assert tested, "JIT Test does not execute any logic"
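
A short sketch of why get_sample() clones for inplace variants (a plain .clone() stands in for clone_input_helper, which also handles nested sample inputs):

import torch

sample_input = torch.ones(3)

def get_sample(name):
    # names ending in '_' mutate their input, so each call gets a fresh copy;
    # otherwise later checks would run on a corrupted sample
    return sample_input.clone() if name.endswith('_') else sample_input

get_sample('add_').add_(1)                       # mutates a copy only
assert torch.equal(sample_input, torch.ones(3))  # the original sample is intact
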
Example #6
    def indiv_variant_test_jit(self, device, dtype, op, sample, func_type,
                               variant, has_fake_function):
        _requires_grad = (dtype in op.supported_backward_dtypes(
            torch.device(device).type))
        support_script = op.supports_scripting
        # Create accessor for script function variant
        name = op.name + '_' if func_type == 'inplace' else op.name

        # run with disable_autodiff_subgraph_inlining(True) to test
        #   autodiff support. Context manager forces the graph to contain
        #   DifferentiableGraph nodes if they are present
        with disable_autodiff_subgraph_inlining():
            # Check scripted forward, grad, and grad grad
            if support_script:
                script_fn = create_script_fn(self, name, func_type)

            def out_fn(output):
                # Processes the output for autograd
                if sample.output_process_fn_grad is not None:
                    return sample.output_process_fn_grad(output)
                return output

            def get_sample():
                return clone_input_helper(
                    sample.input) if op.name[-1] == '_' else sample.input

            if support_script:
                check_against_reference(self,
                                        script_fn,
                                        op.get_op(),
                                        out_fn, (get_sample(), ) + sample.args,
                                        sample.kwargs,
                                        no_grad=not _requires_grad,
                                        no_gradgrad=not op.supports_gradgrad)

            # Check traced forward, grad, and grad grad
            # TODO: fix tracing here
            supports_tracing = op.supports_tracing and not has_fake_function
            if op.assert_jit_shape_analysis:
                self.assertTrue(supports_tracing)

            if supports_tracing:
                traced_fn = create_traced_fn(self, variant)
                check_against_reference(self,
                                        traced_fn,
                                        op.get_op(),
                                        out_fn, (get_sample(), ) + sample.args,
                                        sample.kwargs,
                                        no_grad=not _requires_grad,
                                        no_gradgrad=not op.supports_gradgrad)

            # Check alias annotation schema for correctness (make
            #   sure inputs that aren't supposed to be modified aren't)
            # Note: only runs in float32 because the schema isn't affected by dtype,
            #   so running it on all dtypes would be excessive
            if dtype == torch.float32:
                # TODO: no reason why we can't run this with the tracing graph
                if support_script and op.name != "rsub":
                    check_alias_annotation(name,
                                           (get_sample(), ) + sample.args,
                                           sample.kwargs,
                                           func_type=func_type,
                                           aten_name=op.aten_name)

                # TODO: use script graph as well
                checked_shape_analysis = False
                if supports_tracing:
                    out = variant(get_sample(), *sample.args, **sample.kwargs)

                    # right now, only a single Tensor output or a tuple of Tensor outputs is supported
                    # TODO: support lists of Tensor outputs
                    tuple_of_tensors = isinstance(out, tuple) and all(
                        [isinstance(elem, torch.Tensor) for elem in out])

                    if isinstance(out, torch.Tensor) or tuple_of_tensors:
                        if tuple_of_tensors:
                            sizes = [elem.size() for elem in out]
                        else:
                            sizes = out.size()
                        self.checkShapeAnalysis(sizes, traced_fn.graph,
                                                op.assert_jit_shape_analysis)
                        checked_shape_analysis = True
                if op.assert_jit_shape_analysis:
                    self.assertTrue(checked_shape_analysis)

            # Check autodifferentiation of nodes for traced and scripted graphs; this only needs to run once per sample
            if dtype is torch.float32:
                # Sandcastle doesn't fuse nodes
                if IS_SANDCASTLE:
                    # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
                    nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
                    fusible_nodes = []
                else:
                    nonfusible_nodes = op.autodiff_nonfusible_nodes
                    fusible_nodes = op.autodiff_fusible_nodes

                if supports_tracing:
                    self.assertAutodiffNode(traced_fn.last_graph,
                                            op.assert_autodiffed,
                                            nonfusible_nodes, fusible_nodes)
                if support_script:
                    self.assertAutodiffNode(script_fn.last_graph,
                                            op.assert_autodiffed,
                                            nonfusible_nodes, fusible_nodes)
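
Finally, a sketch of the no_grad gating driven by supported_backward_dtypes: when the dtype has no backward support the reference comparison is forward-only, otherwise gradients are compared as well (torch.sin is an assumed stand-in for the op under test):

import torch

def check(fn, x, no_grad):
    out = fn(x)
    if no_grad:
        return out
    # with backward support, also compare a first-order gradient
    grad, = torch.autograd.grad(out.sum(), x)
    return out, grad

x = torch.randn(3, requires_grad=True)
check(torch.sin, x, no_grad=False)              # dtype supports backward
check(torch.sin, torch.randn(3), no_grad=True)  # forward-only path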