def _expand_and_optimize_ir(torchscript):
    """
    Lower a torch.jit.ScriptModule to an optimized torch._C.Graph.

    The module's ``forward`` graph is run through a fixed, version-gated
    sequence of TorchScript JIT passes (inlining, dead-code elimination,
    op canonicalization, renumbering and, on PyTorch < 1.6.0, peephole and
    constant propagation), with lint checks between stages. Renaming of
    getattr nodes and collection of parameters is then delegated to
    ``TorchConverter._jit_pass_lower_graph``.

    Args:
        torchscript: a ``torch.jit.ScriptModule`` whose ``forward.graph`` is
            optimized.

    Returns:
        A ``(graph, params_dict)`` tuple: the optimized ``torch._C.Graph`` and
        the dict of model parameter names to tensors, both as returned by
        ``TorchConverter._jit_pass_lower_graph``.
    """
    jit = _torch._C

    # NOTE(review): the passes below run directly on the object returned by
    # ``torchscript.forward.graph``; confirm whether that graph is shared with
    # the module (i.e. whether the module itself is mutated as a side effect).
    graph = torchscript.forward.graph

    # Inline function and method calls.
    jit._jit_pass_inline(graph)
    # Inline each fork() section at its call site and replace uses of the
    # matching wait() results with the values the inlined section produces.
    jit._jit_pass_inline_fork_wait(graph)
    # Starting from the return node, keep only nodes that feed the output or
    # have side effects; everything else is eliminated.
    jit._jit_pass_dce(graph)
    # Check graph well-formedness and invariants.
    jit._jit_pass_lint(graph)

    # Rewrite a few specific op patterns (add, sub, mul, div, chunk). The
    # pass was renamed in PyTorch 1.6.0.
    before_1_6 = version_lt(_torch, '1.6.0')
    if not before_1_6:
        jit._jit_pass_canonicalize_graph_fuser_ops(graph)
    else:
        jit._jit_pass_canonicalize_ops(graph)
        jit._jit_pass_lint(graph)
        # Small peephole optimizations, e.g. eliminating no-op 'expand' nodes
        # and simplifying x.t().t() to x. Disabled from 1.6.0 onwards because
        # it wrongly captures the shape of dummy inputs during tracing.
        jit._jit_pass_peephole(graph, addmm_fusion_enabled=False)
    jit._jit_pass_lint(graph)

    # Renumber the graph so structurally equivalent graphs get the same
    # numbers; the returned graph is the one used from here on.
    graph = jit._jit_pass_canonicalize(graph)
    jit._jit_pass_lint(graph)

    # From 1.6.0, only scalar values can be pulled out of prim::Constant.
    # Constant propagation removes `listConstruct` and yields list values, so
    # on 1.6.0+ it is skipped to keep scalars, and our own constant
    # propagation interprets `listConstruct` instead.
    if before_1_6:
        jit._jit_pass_constant_propagation(graph)
    # No extra DCE needed: constant propagation already includes it.
    jit._jit_pass_lint(graph)

    # Rename the getattr nodes in the graph and build the params dict.
    lowered_graph, params_dict = TorchConverter._jit_pass_lower_graph(
        graph, torchscript)
    return lowered_graph, params_dict
class TestImageResample(TensorFlow2BaseTest):
    @pytest.mark.skipif(
        condition=version_lt(tf, "2.4") or version_ge(tf, "2.5"),
        reason="tfa.image.resample requires TF 2.4+. On TF 2.5+, TF package has this symbol missing",
    )
    @pytest.mark.parametrize(
        "use_cpu_only, backend, data_warp_shapes",
        itertools.product(
            [True, False],
            backends,
            [
                # Each entry is [data_shape, warp_shape] where
                #   data_shape = (Batch, Hin, Win, C)
                #   warp_shape = (Batch, Hout, Wout, 2)
                [(1, 3, 3, 1), (1, 3, 3, 2)],  # output size equals input size
                [(2, 5, 5, 3), (2, 3, 3, 2)],  # down-sampling
                [(3, 6, 6, 1), (3, 8, 8, 2)],  # up-sampling
            ],
        ),
    )
    def test_resample(self, use_cpu_only, backend, data_warp_shapes):
        """Build a TF2 graph around tfa.image.resampler and run it through
        ``run_compare_tf2``.

        Marked xfail on the "neuralnetwork" backend; skipped when
        tensorflow_addons cannot be imported.
        """
        if backend[0] == "neuralnetwork":
            pytest.xfail("nn backend not supported")

        tfa = pytest.importorskip("tensorflow_addons")
        data_shape, warp_shape = data_warp_shapes

        @make_tf_graph([data_shape, warp_shape])
        def resampler_graph(x, warp):
            return tfa.image.resampler(data=x, warp=warp)

        model, inputs, outputs = resampler_graph

        # The warp range deliberately exceeds the input sizes so that more
        # padding modes get exercised. Data is generated before warp.
        data_values = random_gen(data_shape, -100, 100)
        warp_values = random_gen(warp_shape, -15, 15)
        input_dict = dict(zip(inputs, (data_values, warp_values)))

        self.run_compare_tf2(
            model,
            input_dict,
            outputs,
            use_cpu_only=use_cpu_only,
            backend=backend,
        )
def _expand_and_optimize_ir(torchscript):
    """
    Lower a torch.jit.ScriptModule to an optimized torch._C.Graph.

    The module's ``forward`` graph is first lowered with
    ``torch._C._jit_pass_lower_graph`` so it no longer references other
    modules. It is then run through a fixed, version-gated sequence of JIT
    passes (inlining, dead-code elimination, in-place op removal, op
    canonicalization, renumbering and, on PyTorch < 1.6.0, peephole and
    constant propagation), with lint checks between stages.

    Args:
        torchscript: a ``torch.jit.ScriptModule``; its ``forward.graph`` and
            underlying ``_c`` module are read.

    Returns:
        A ``(graph, params_dict)`` tuple: the optimized ``torch._C.Graph`` and
        a dict mapping the debug names of the graph's trailing
        ``len(params)`` inputs to the parameter values returned by the
        lowering pass.
    """
    jit = _torch._C

    # Recursively replace attribute accesses with the sub-graphs of the
    # referenced modules, making the graph self-contained. ``params`` holds
    # the "trainable" inputs to the lowered graph.
    graph, params = jit._jit_pass_lower_graph(
        torchscript.forward.graph, torchscript._c)

    # Inline function and method calls.
    jit._jit_pass_inline(graph)
    # Inline each fork() section at its call site and replace uses of the
    # matching wait() results with the values the inlined section produces.
    jit._jit_pass_inline_fork_wait(graph)
    # Keep only nodes that feed the output or have side effects.
    jit._jit_pass_dce(graph)
    # Check graph well-formedness and invariants.
    jit._jit_pass_lint(graph)

    # Replace in-place ops with out-of-place equivalents, e.g.
    #   %foo = aten::add_(%foo, %n)   becomes   %foo.2 = aten::add(%foo, %n)
    jit._jit_pass_remove_inplace_ops(graph)
    jit._jit_pass_dce(graph)
    jit._jit_pass_lint(graph)

    # Rewrite a few specific op patterns (add, sub, mul, div, chunk). The
    # pass was renamed in PyTorch 1.6.0.
    before_1_6 = version_lt(_torch, '1.6.0')
    if not before_1_6:
        jit._jit_pass_canonicalize_graph_fuser_ops(graph)
    else:
        jit._jit_pass_canonicalize_ops(graph)
        jit._jit_pass_lint(graph)
        # Small peephole optimizations, e.g. eliminating no-op 'expand' nodes
        # and simplifying x.t().t() to x. Disabled from 1.6.0 onwards because
        # it wrongly captures the shape of dummy inputs during tracing.
        jit._jit_pass_peephole(graph, addmm_fusion_enabled=False)
    jit._jit_pass_lint(graph)

    # Renumber the graph so structurally equivalent graphs get the same
    # numbers; the returned graph is the one used from here on.
    graph = jit._jit_pass_canonicalize(graph)
    jit._jit_pass_lint(graph)

    # From 1.6.0, only scalar values can be pulled out of prim::Constant.
    # Constant propagation removes `listConstruct` and yields list values, so
    # on 1.6.0+ it is skipped to keep scalars, and our own constant
    # propagation interprets `listConstruct` instead.
    if before_1_6:
        jit._jit_pass_constant_propagation(graph)
    # No extra DCE needed: constant propagation already includes it.
    jit._jit_pass_lint(graph)

    # The parameters are the trailing inputs of the graph; pair each of their
    # debug names with the corresponding parameter value.
    input_names = [value.debugName() for value in graph.inputs()]
    first_param_index = len(input_names) - len(params)
    params_dict = {
        name: param
        for name, param in zip(input_names[first_param_index:], params)
    }
    return graph, params_dict