예제 #1
0
def wire_tfl_while_body(g, loop_node_inputs, output_shapes, output_dtypes,
                        cond_graph):
    """Turn a TFLite WHILE body subgraph into an ONNX Loop body graph.

    A deep copy of ``g`` is modified and returned; ``g`` itself is untouched.
    Two scalar placeholders (an INT64 iteration counter and a BOOL condition)
    are prepended to the body inputs. The cond subgraph is inlined into the
    body, bound to the body's outputs, and its first output is prepended to
    the body outputs.

    Args:
        g: the body subgraph to convert.
        loop_node_inputs: nodes paired positionally with the new body inputs
            (iteration, cond, loop vars...). The shape of each node's first
            output is copied onto the matching body input.
        output_shapes: not used.
        output_dtypes: not used.
        cond_graph: the cond subgraph to inline into the body.

    Returns:
        The rewired copy of the body graph.
    """
    body = copy.deepcopy(g)

    def _scalar_placeholder(prefix, dtype):
        # Rank-0 placeholder with a uniquified name.
        return body.make_node("Placeholder", [],
                              name=utils.make_name(prefix),
                              output_count=1,
                              dtypes=[dtype],
                              shapes=[[]])

    # ONNX passes (iteration_num, cond) ahead of the loop-carried values.
    iteration_input = _scalar_placeholder("iteration_num", TensorProto.INT64).output[0]
    cond_input = _scalar_placeholder("cond", TensorProto.BOOL).output[0]
    binding = parameter_binding(cond_graph, body.outputs)

    body.func_inputs = [iteration_input, cond_input] + body.func_inputs
    # Pin the input order so the graph library keeps it as listed above.
    body._order_sensitive_inputs = list(map(body.get_node_by_output, body.func_inputs))  # pylint: disable=protected-access

    for source_node, input_name in zip(loop_node_inputs, body.func_inputs):
        body.set_shape(input_name, source_node.output_shapes[0])

    inlined_cond = inline_subgraph(body, cond_graph, "cond__", binding)
    body.outputs = [inlined_cond[0]] + body.outputs
    return body
예제 #2
0
def wire_tfl_while_body(g, loop_node_inputs, output_shapes, output_dtypes,
                        cond_graph, scan_outputs):
    """Wire subgraph graph into main.

    Turn a TFLite WHILE body subgraph into an ONNX Loop body graph, turning
    the selected loop-carried variables into scan outputs.

    A deep copy of ``g`` is modified and returned; ``g`` itself is untouched.

    Args:
        g: the body subgraph to convert.
        loop_node_inputs: nodes paired positionally with the final body inputs
            (iteration_num, cond, remaining loop vars...). The shape of each
            node's first output is copied onto the matching body input.
        output_shapes: not used by this function.
        output_dtypes: not used by this function.
        cond_graph: the cond subgraph. It is inlined into the body (prefix
            "cond__"), bound to the body's outputs.
        scan_outputs: iterable of ``(idx, scan_output)`` pairs. ``idx``
            indexes the body's original inputs/outputs. That loop-carried
            input, everything transitively consuming it, and the matching
            output are removed, and ``scan_output`` is appended to the
            outputs instead. The positional deletions below are only correct
            when pairs arrive in descending ``idx`` order; otherwise earlier
            deletions shift later indices.

    Returns:
        The rewired copy. Its outputs are
        ``[cond, remaining loop vars..., scan outputs...]``, with scan outputs
        appended in the order they were visited.
    """

    g = copy.deepcopy(g)
    # Snapshot of the original input names; g.func_inputs is mutated below,
    # but the scan indices refer to the original positions.
    graph_inputs = g.func_inputs.copy()

    # onnx will pass in cond as argument
    iter_node = g.make_node("Placeholder", [],
                            name=utils.make_name("iteration_num"),
                            output_count=1,
                            dtypes=[TensorProto.INT64],
                            shapes=[[]])
    cond_node = g.make_node("Placeholder", [],
                            name=utils.make_name("cond"),
                            output_count=1,
                            dtypes=[TensorProto.BOOL],
                            shapes=[[]])
    # Presumably maps cond_graph's inputs to the body's output names
    # (the values the next iteration's condition is computed from).
    cond_binding = parameter_binding(cond_graph, g.outputs)

    # Nodes to delete: each scan input's node plus all transitive consumers.
    to_remove = set()
    for idx, scan_output in scan_outputs:
        inp = g.get_node_by_output(graph_inputs[idx])

        # Remove consumers of scan input
        # (iterative DFS: collect the input node and everything downstream of it)
        stack = [inp]
        while stack:
            node = stack.pop()
            if node not in to_remove:
                to_remove.add(node)
                for out in node.output:
                    stack += g.find_output_consumers(out)

        # Remove scan input from cond graph
        # Any binding that pointed at the output being removed is replaced by
        # the "@@ALLOC" string (presumably a sentinel inline_subgraph handles;
        # confirm).
        cond_binding = {
            k: "@@ALLOC" if v == g.outputs[idx] else v
            for k, v in cond_binding.items()
        }
        # Positional deletes: relies on descending idx order (see docstring).
        del g.func_inputs[idx]
        del g.outputs[idx]
        g.outputs.append(scan_output)

    for node in to_remove:
        g.remove_node(node.name)

    # in onnx the body inputs are: index, cond, [loop_vars]
    g.func_inputs = [iter_node.output[0], cond_node.output[0]] + g.func_inputs
    # tell graph lib to keep inputs in order
    g._order_sensitive_inputs = \
        [g.get_node_by_output(name) for name in g.func_inputs]  # pylint: disable=protected-access

    # Copy each feeding node's first-output shape onto the matching body input.
    for p, c in zip(loop_node_inputs, g.func_inputs):
        shape = p.output_shapes[0]
        g.set_shape(c, shape)

    cond_outputs = inline_subgraph(g, cond_graph, "cond__", cond_binding)

    # The inlined condition becomes the body's first output.
    g.outputs = [cond_outputs[0]] + g.outputs
    return g
예제 #3
0
    def version_7(cls, ctx, node, **kwargs):
        """Convert a TFLite WHILE node into an ONNX Loop node.

        The cond subgraph is inlined into ``ctx``, bound to the WHILE inputs,
        to compute the Loop's initial condition. The body subgraph is rewired
        into an ONNX Loop body by ``wire_tfl_while_body``. Consumers of the
        WHILE outputs are redirected to the Loop outputs.

        Candidates in ``body.scan_outputs`` become Loop scan outputs when the
        matching cond-subgraph input has no consumers. Such a variable is
        dropped from the Loop inputs, and its output name/shape/dtype move to
        the end of the output lists.

        Args:
            ctx: graph containing ``node``; the Loop node is created here.
            node: the TFLite WHILE node; it is removed from ``ctx``.
            **kwargs: unused.
        """
        # Work on copies: these lists are trimmed/reordered below and must not
        # alias the removed node's own input/output lists.
        tfl_while_inputs = list(node.input)
        output_shapes = list(node.output_shapes)
        output_dtypes = list(node.output_dtypes)
        output_names = list(node.output)

        cond_name = node.get_attr_str("cond_subgraph_index")
        cond_graph = find_function(cond_name)
        cond_graph.parent_graph = ctx

        body_name = node.get_attr_str("body_subgraph_index")
        body = find_function(body_name)
        body.parent_graph = ctx

        ctx.remove_node(node.name)

        # Evaluate the condition once on the initial loop inputs.
        cond_binding = parameter_binding(cond_graph, tfl_while_inputs)
        cond_outputs = inline_subgraph(ctx, cond_graph, cond_name, cond_binding)

        # Potential scan output candidates are identified in the body subgraph using tfl_scan_output_rewriter.
        # They can then be optimized in this tfl loop handler provided they are not used in the cond subgraph.
        # Descending index order keeps the positional deletions below (and in
        # wire_tfl_while_body) valid for the indices not yet processed.
        scan_outputs = sorted(body.scan_outputs, reverse=True)

        def input_is_unused(g, index):
            # func_inputs is the positional list of input *names*, which is
            # what find_output_consumers looks up. It is also how the body's
            # inputs are indexed in wire_tfl_while_body.
            return not g.find_output_consumers(g.func_inputs[index])

        scan_outputs = [(i, out) for i, out in scan_outputs if input_is_unused(cond_graph, i)]

        # Drop scan variables from the Loop inputs and move their output
        # metadata to the end, in the same order wire_tfl_while_body appends
        # the scan outputs to the body outputs.
        for idx, _ in scan_outputs:
            del tfl_while_inputs[idx]
            output_shapes.append(output_shapes.pop(idx))
            output_dtypes.append(output_dtypes.pop(idx))
            output_names.append(output_names.pop(idx))

        max_iterations = ctx.make_const(utils.make_name("max_iterations"), np.array(np.iinfo(np.int64).max))

        loop_node = ctx.make_node("Loop", [max_iterations.output[0], cond_outputs[0]] + tfl_while_inputs,
                                  output_count=len(output_shapes), name=node.name + "_loop",
                                  shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True)

        # Redirect every consumer of the old WHILE outputs to the Loop outputs.
        for old_name, new_name in zip(output_names, loop_node.output):
            ctx.replace_all_inputs(old_name, new_name)

        body = wire_tfl_while_body(body, loop_node.inputs, output_shapes, output_dtypes, cond_graph, scan_outputs)

        # The scan outputs are the last len(scan_outputs) body outputs. Squeeze
        # axis 0 from each (presumably a unit axis added by the scan-output
        # rewriter -- TODO confirm). Processed last-to-first, as before.
        builder = GraphBuilder(body)
        for i in range(1, len(scan_outputs) + 1):
            squeeze_node = builder.make_squeeze(
                {'data': body.outputs[-i], "axes": [0]}, return_node=True)
            body.outputs[-i] = squeeze_node.output[0]

        loop_node.set_body_graph_as_attr("body", body)
예제 #4
0
    def version_7(cls, ctx, node, **kwargs):
        """Lower a TFLite WHILE node to an ONNX Loop node.

        The cond subgraph is inlined into ``ctx`` against the WHILE inputs to
        produce the Loop's initial condition. The new Loop node takes over the
        WHILE node's outputs. The body subgraph, rewired by
        ``wire_tfl_while_body``, becomes the Loop's ``body`` attribute.
        """
        # Read what is needed from the WHILE node before it leaves ctx.
        while_inputs = node.input
        shapes = node.output_shapes
        dtypes = node.output_dtypes
        while_outputs = node.output

        cond_name = node.get_attr_str("cond_subgraph_index")
        cond_graph = find_function(cond_name)
        cond_graph.parent_graph = ctx

        body_graph = find_function(node.get_attr_str("body_subgraph_index"))
        body_graph.parent_graph = ctx

        ctx.remove_node(node.name)

        initial_cond = inline_subgraph(
            ctx, cond_graph, cond_name,
            parameter_binding(cond_graph, while_inputs))

        # No trip-count limit: iterate until the condition turns false.
        trip_limit = ctx.make_const(
            utils.make_name("max_iterations"),
            np.array(np.iinfo(np.int64).max))

        loop_inputs = [trip_limit.output[0], initial_cond[0]] + while_inputs
        loop_node = ctx.make_node(
            "Loop",
            loop_inputs,
            output_count=len(shapes),
            name=node.name + "_loop",
            shapes=shapes,
            dtypes=dtypes,
            skip_conversion=True,
        )

        # Point consumers of the old WHILE outputs at the Loop outputs.
        for old_output, new_output in dict(zip(while_outputs, loop_node.output)).items():
            ctx.replace_all_inputs(old_output, new_output)

        loop_node.set_body_graph_as_attr(
            "body",
            wire_tfl_while_body(body_graph, loop_node.inputs, shapes,
                                dtypes, cond_graph))