Example #1
    def _expand_and_optimize_ir(torchscript):
        """
        Given a torch.jit.ScriptModule, convert it to a optimized
        torch._C.Graph and dict of model parameter's names to tensors.
        """
        graph = torchscript.forward.graph

        # From PyTorch code: Inline function and method calls.
        _torch._C._jit_pass_inline(graph)
        # From PyTorch code: This inlines the forked section in the fork()
        # callsite and replaces uses of the result of wait() calls with the
        # values produced from the (now-inlined) forked section.
        _torch._C._jit_pass_inline_fork_wait(graph)
        # Starting from the return node, marks all nodes that feed into the
        # output, as well as nodes with side effects. Any nodes not marked are
        # eliminated.
        _torch._C._jit_pass_dce(graph)
        # From PyTorch code: checks well-formedness and invariants of graph.
        _torch._C._jit_pass_lint(graph)
        # Replaces a couple of specific op patterns (add, sub, mul, div, chunk).
        if version_lt(_torch, '1.6.0'):
            _torch._C._jit_pass_canonicalize_ops(graph)
            _torch._C._jit_pass_lint(graph)

            # From PyTorch code: This pass catches all of the small, easy to
            # catch peephole optimizations you might be interested in doing.
            #     Eliminate no-op 'expand' nodes
            #     Simplify x.t().t() to x
            # This pass is disabled for v1.6.0 and onwards because it wrongly
            # captures the shape of dummy inputs during tracing.
            _torch._C._jit_pass_peephole(graph, addmm_fusion_enabled=False)
        else:
            # The pass was renamed in v1.6.0.
            _torch._C._jit_pass_canonicalize_graph_fuser_ops(graph)
        _torch._C._jit_pass_lint(graph)

        # From PyTorch docs: Renumber the graph so that all structurally
        # equivalent graphs have same numbers.
        graph = _torch._C._jit_pass_canonicalize(graph)
        _torch._C._jit_pass_lint(graph)
        if version_lt(_torch, '1.6.0'):
            # v1.6.0 JIT changes disallow pulling list values out of
            # prim::Constant; we can only pull scalar values. Constant
            # propagation removes `listConstruct` and results in list values.
            # For v1.6.0+ we therefore skip the constant prop pass to keep them
            # as scalars, and rely on our own constant prop to interpret
            # `listConstruct`.
            _torch._C._jit_pass_constant_propagation(graph)
        # NOTE: Don't need another DCE, it's included in constant propagation.
        _torch._C._jit_pass_lint(graph)

        # Get the params_dict and rename the getattr nodes in the graph
        graph, params_dict = TorchConverter._jit_pass_lower_graph(
            graph, torchscript)

        return graph, params_dict
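
The helper above is typically invoked on a traced or scripted model. The following is a minimal usage sketch, assuming _expand_and_optimize_ir is exposed as a staticmethod of the same TorchConverter class whose _jit_pass_lower_graph it calls; the TinyModel module is purely illustrative.

import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.linear(x))

# Trace the model to obtain a torch.jit.ScriptModule with a forward graph.
traced = torch.jit.trace(TinyModel().eval(), torch.rand(1, 4))

# Assumption: the helper is reachable as a staticmethod of TorchConverter.
graph, params_dict = TorchConverter._expand_and_optimize_ir(traced)
print(graph)              # optimized torch._C.Graph
print(list(params_dict))  # parameter names mapped to tensors, e.g. the linear weights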
Example #2
class TestImageResample(TensorFlow2BaseTest):
    @pytest.mark.skipif(
        condition=version_lt(tf, "2.4") or version_ge(tf, "2.5"),
        reason="tfa.image.resampler requires TF 2.4+; on TF 2.5+ the symbol it needs is missing from the TF package",
    )
    @pytest.mark.parametrize(
        "use_cpu_only, backend, data_warp_shapes",
        itertools.product(
            [True, False],
            backends,
            [
                # Data shape format: (Batch, Hin, Win, C)
                # Warp shape format: (Batch, Hout, Wout, 2)
                [(1, 3, 3, 1), (1, 3, 3, 2)],  # no size change
                [(2, 5, 5, 3), (2, 3, 3, 2)],  # down-sampling
                [(3, 6, 6, 1), (3, 8, 8, 2)],  # up-sampling
            ],
        ),
    )
    def test_resample(
        self,
        use_cpu_only,
        backend,
        data_warp_shapes,
    ):
        if backend[0] == "neuralnetwork":
            pytest.xfail("nn backend not supported")

        tfa = pytest.importorskip("tensorflow_addons")

        data_shape, warp_shape = data_warp_shapes

        @make_tf_graph([data_shape, warp_shape])
        def build_model(x, warp):
            return tfa.image.resampler(data=x, warp=warp)

        model, inputs, outputs = build_model
        # Warp values exceed the input spatial sizes in order to exercise more padding modes.
        input_values = [
            random_gen(data_shape, -100, 100),
            random_gen(warp_shape, -15, 15),
        ]
        input_dict = dict(zip(inputs, input_values))
        self.run_compare_tf2(
            model,
            input_dict,
            outputs,
            use_cpu_only=use_cpu_only,
            backend=backend,
        )
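
For reference, the op under test can also be exercised directly in eager mode. This sketch only illustrates the shape contract noted in the parametrization comments (data is (Batch, Hin, Win, C), warp is (Batch, Hout, Wout, 2)) and assumes TF 2.4 with tensorflow_addons installed.

import tensorflow as tf
import tensorflow_addons as tfa

data = tf.random.uniform((2, 5, 5, 3))              # (Batch, Hin, Win, C)
warp = tf.random.uniform((2, 3, 3, 2), maxval=5.0)  # (Batch, Hout, Wout, 2) sampling coordinates

# Bilinearly samples `data` at the (x, y) locations given by `warp`.
out = tfa.image.resampler(data=data, warp=warp)
print(out.shape)  # (2, 3, 3, 3) -> (Batch, Hout, Wout, C)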
Example #3
    def _expand_and_optimize_ir(torchscript):
        """Given a torch.jit.ScriptModule, convert it to a optimized
        torch._C.Graph and dict of model parameter's names to tensors.
        """

        # Recursively replaces all attribute accesses with the sub-graphs of
        # those modules. The resulting graph will be self-contained and will
        # not reference into other modules. Params will contain the "trainable"
        # inputs to the graph.
        graph, params = _torch._C._jit_pass_lower_graph(
            torchscript.forward.graph, torchscript._c)

        # From PyTorch code: Inline function and method calls.
        _torch._C._jit_pass_inline(graph)
        # From PyTorch code: This inlines the forked section in the fork()
        # callsite and replaces uses of the result of wait() calls with the
        # values produced from the (now-inlined) forked section.
        _torch._C._jit_pass_inline_fork_wait(graph)
        # Starting from the return node, marks all nodes that feed into the
        # output, as well as nodes with side effects. Any nodes not marked are
        # eliminated.
        _torch._C._jit_pass_dce(graph)
        # From PyTorch code: checks well-formedness and invariants of graph.
        _torch._C._jit_pass_lint(graph)
        # From PyTorch code: remove all in-place ops and replace them with
        # out-of-place equivalents.
        # e.g.
        #   %foo = aten::add_(%foo, %n)
        # becomes
        #   %foo.2 = aten::add(%foo, %n)
        _torch._C._jit_pass_remove_inplace_ops(graph)
        _torch._C._jit_pass_dce(graph)
        _torch._C._jit_pass_lint(graph)
        # Replaces a couple of specific op patterns (add, sub, mul, div, chunk).
        if version_lt(_torch, '1.6.0'):
            _torch._C._jit_pass_canonicalize_ops(graph)
            _torch._C._jit_pass_lint(graph)

            # From PyTorch code: This pass catches all of the small, easy to
            # catch peephole optimizations you might be interested in doing.
            #     Eliminate no-op 'expand' nodes
            #     Simplify x.t().t() to x
            # This pass is disabled for v1.6.0 and onwards because it wrongly
            # captures the shape of dummy inputs during tracing.
            _torch._C._jit_pass_peephole(graph, addmm_fusion_enabled=False)
        else:
            # The pass was renamed in v1.6.0.
            _torch._C._jit_pass_canonicalize_graph_fuser_ops(graph)
        _torch._C._jit_pass_lint(graph)

        # From PyTorch docs: Renumber the graph so that all structurally
        # equivalent graphs have same numbers.
        graph = _torch._C._jit_pass_canonicalize(graph)
        _torch._C._jit_pass_lint(graph)
        if version_lt(_torch, '1.6.0'):
            # v1.6.0 JIT changes disallow pulling list values out of
            # prim::Constant; we can only pull scalar values. Constant
            # propagation removes `listConstruct` and results in list values.
            # For v1.6.0+ we therefore skip the constant prop pass to keep them
            # as scalars, and rely on our own constant prop to interpret
            # `listConstruct`.
            _torch._C._jit_pass_constant_propagation(graph)
        # NOTE: Don't need another DCE, it's included in constant propagation.
        _torch._C._jit_pass_lint(graph)

        input_and_param_names = [val.debugName() for val in graph.inputs()]
        param_names = input_and_param_names[len(input_and_param_names) -
                                            len(params):]
        params_dict = dict(zip(param_names, params))

        return graph, params_dict
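
Both PyTorch examples guard pass selection with version_lt (and the TF test additionally uses version_ge), helpers whose definitions are not shown here. The following is a minimal sketch of what such helpers might look like, assuming they compare a module's __version__ attribute against a version string; the actual project helpers may differ.

from packaging.version import Version

def version_lt(module, ver):
    """Hypothetical helper: True if module.__version__ is older than `ver`."""
    # Strip any local build suffix such as "+cu101" before comparing.
    return Version(str(module.__version__).split("+")[0]) < Version(ver)

def version_ge(module, ver):
    """Hypothetical helper: True if module.__version__ is `ver` or newer."""
    return not version_lt(module, ver)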