Example #1
    def version_7(cls, ctx, node, **kwargs):
        tfl_while_inputs = node.input
        output_shapes = node.output_shapes
        output_dtypes = node.output_dtypes
        output_names = node.output

        cond_name = node.get_attr_str("cond_subgraph_index")
        cond_graph = find_function(cond_name)
        cond_graph.parent_graph = ctx

        body_name = node.get_attr_str("body_subgraph_index")
        body = find_function(body_name)
        body.parent_graph = ctx

        ctx.remove_node(node.name)

        cond_binding = parameter_binding(cond_graph, tfl_while_inputs)
        cond_outputs = inline_subgraph(ctx, cond_graph, cond_name, cond_binding)

        # Potential scan output candidates are identified in the body subgraph using tfl_scan_output_rewriter.
        # They can then be optimized in this tfl loop handler provided they are not used in the cond subgraph.
        scan_outputs = sorted(body.scan_outputs, reverse=True)
        def input_is_unused(g, index):
            return len(g.find_output_consumers(g.inputs[index])) == 0
        scan_outputs = [(i, out) for i, out in scan_outputs if input_is_unused(cond_graph, i)]

        for idx, _ in scan_outputs:
            del tfl_while_inputs[idx]
            output_shapes.append(output_shapes.pop(idx))
            output_dtypes.append(output_dtypes.pop(idx))
            output_names.append(output_names.pop(idx))

        max_iterations = ctx.make_const(utils.make_name("max_iterations"), np.array(np.iinfo(np.int64).max))

        loop_node = ctx.make_node("Loop", [max_iterations.output[0], cond_outputs[0]] + tfl_while_inputs,
                                  output_count=len(output_shapes), name=node.name + "_loop",
                                  shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True)

        output_map = dict(zip(output_names, loop_node.output))

        # shift output consumers
        for k, v in output_map.items():
            ctx.replace_all_inputs(k, v)  # ops=ctx.get_nodes()

        body = wire_tfl_while_body(body, loop_node.inputs, output_shapes, output_dtypes, cond_graph, scan_outputs)

        for i in range(len(scan_outputs)):
            squeeze_node = GraphBuilder(body).make_squeeze(
                {'data': body.outputs[-1-i], "axes": [0]}, return_node=True)
            body.outputs[-1-i] = squeeze_node.output[0]

        loop_node.set_body_graph_as_attr("body", body)
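The handler above wires a converted body graph onto an ONNX Loop node whose contract is: body inputs are (iteration_num, cond_in, loop_vars...) and body outputs are (cond_out, loop_vars..., scan_outputs...). As a standalone point of reference (not tf2onnx code; all names below are invented for the sketch), a minimal Loop with one loop variable can be built with onnx.helper like this:

import onnx
from onnx import TensorProto, helper

# Body graph: (iteration_num, cond_in, v_in) -> (cond_out, v_out).
# It increments the single loop variable and keeps the condition true,
# so only the trip count terminates the loop.
body = helper.make_graph(
    nodes=[
        helper.make_node("Identity", ["cond_in"], ["cond_out"]),
        helper.make_node("Constant", [], ["one"],
                         value=helper.make_tensor("one_t", TensorProto.INT64, [], [1])),
        helper.make_node("Add", ["v_in", "one"], ["v_out"]),
    ],
    name="loop_body",
    inputs=[
        helper.make_tensor_value_info("iteration_num", TensorProto.INT64, []),
        helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []),
        helper.make_tensor_value_info("v_in", TensorProto.INT64, []),
    ],
    outputs=[
        helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []),
        helper.make_tensor_value_info("v_out", TensorProto.INT64, []),
    ],
)

# Loop node inputs: (max_trip_count, initial cond, initial loop_vars...);
# outputs: (final loop_vars..., scan_outputs...). No scan outputs here.
loop = helper.make_node("Loop", ["max_iter", "cond", "v0"], ["v_final"], body=body)
graph = helper.make_graph(
    [loop], "main",
    [helper.make_tensor_value_info("max_iter", TensorProto.INT64, []),
     helper.make_tensor_value_info("cond", TensorProto.BOOL, []),
     helper.make_tensor_value_info("v0", TensorProto.INT64, [])],
    [helper.make_tensor_value_info("v_final", TensorProto.INT64, [])])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.checker.check_model(model)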
Example #2
    def version_1(cls, ctx, node, **kwargs):
        """V2 control flow - If"""
        inputs = node.input[1:]

        if node.type == "If" and len(inputs) == 0:
            # this comes from the re-writers
            return

        output_shapes = node.output_shapes
        output_dtypes = node.output_dtypes
        ctx.remove_node(node.name)

        # replace the original node
        if_node = ctx.make_node("If",
                                node.input[:1],
                                name=node.name,
                                output_count=len(output_shapes),
                                shapes=output_shapes,
                                dtypes=output_dtypes,
                                skip_conversion=True)

        for branch in ["then_branch", "else_branch"]:
            func_name = node.get_attr_str(branch)
            g = find_function(func_name)
            g.parent_graph = ctx
            wire_if_branch(ctx, g, inputs, output_shapes, output_dtypes,
                           func_name, node.name)
            if_node.set_body_graph_as_attr(branch, g)
Example #3
    def version_1(cls, ctx, node, **kwargs):
        """V2 control flow - If"""
        inputs = node.input[1:]

        output_shapes = node.output_shapes
        output_dtypes = node.output_dtypes
        ctx.remove_node(node.name)

        # replace the original node
        branches = {}
        for branch in ["then_branch", "else_branch"]:
            func_name = node.get_attr_str(branch)
            g = find_function(func_name)
            g.parent_graph = ctx
            wire_if_branch(ctx, g, inputs, output_shapes, output_dtypes,
                           func_name, node.name)
            branches[branch] = g

        _ = ctx.make_node("If",
                          node.input[:1],
                          name=node.name,
                          output_count=len(output_shapes),
                          shapes=output_shapes,
                          dtypes=output_dtypes,
                          skip_conversion=True,
                          branches=branches)
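Both If handlers above attach the converted branch graphs to a plain ONNX If node, which selects a branch from a scalar boolean input. For reference (not tf2onnx code; the names below are invented for the sketch), a minimal If built with onnx.helper looks like this:

import onnx
from onnx import TensorProto, helper

def const_branch(name, value):
    # A branch graph has no inputs; its outputs map positionally onto the If outputs.
    out_name = name + "_out"
    out = helper.make_tensor_value_info(out_name, TensorProto.FLOAT, [1])
    const = helper.make_node(
        "Constant", [], [out_name],
        value=helper.make_tensor(name + "_val", TensorProto.FLOAT, [1], [value]))
    return helper.make_graph([const], name, [], [out])

if_node = helper.make_node(
    "If", ["cond"], ["result"],
    then_branch=const_branch("then_branch", 1.0),
    else_branch=const_branch("else_branch", 0.0))

graph = helper.make_graph(
    [if_node], "main",
    [helper.make_tensor_value_info("cond", TensorProto.BOOL, [])],
    [helper.make_tensor_value_info("result", TensorProto.FLOAT, [1])])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.checker.check_model(model)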
Example #4
    def version_7(cls, ctx, node, **kwargs):
        tfl_while_inputs = node.input
        output_shapes = node.output_shapes
        output_dtypes = node.output_dtypes
        output_names = node.output

        cond_name = node.get_attr_str("cond_subgraph_index")
        cond_graph = find_function(cond_name)
        cond_graph.parent_graph = ctx

        body_name = node.get_attr_str("body_subgraph_index")
        body = find_function(body_name)
        body.parent_graph = ctx

        ctx.remove_node(node.name)

        cond_binding = parameter_binding(cond_graph, tfl_while_inputs)
        cond_outputs = inline_subgraph(ctx, cond_graph, cond_name,
                                       cond_binding)

        max_iterations = ctx.make_const(utils.make_name("max_iterations"),
                                        np.array(np.iinfo(np.int64).max))

        loop_node = ctx.make_node("Loop",
                                  [max_iterations.output[0], cond_outputs[0]] +
                                  tfl_while_inputs,
                                  output_count=len(output_shapes),
                                  name=node.name + "_loop",
                                  shapes=output_shapes,
                                  dtypes=output_dtypes,
                                  skip_conversion=True)

        output_map = dict(zip(output_names, loop_node.output))

        # shift output consumers
        for k, v in output_map.items():
            ctx.replace_all_inputs(k, v)  # ops=ctx.get_nodes()

        body = wire_tfl_while_body(body, loop_node.inputs, output_shapes,
                                   output_dtypes, cond_graph)

        loop_node.set_body_graph_as_attr("body", body)
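parameter_binding and inline_subgraph are tf2onnx internals; conceptually, inlining the cond subgraph means copying its nodes into the parent graph with each formal input renamed to the tensor the while node actually receives. A rough, hypothetical sketch of that renaming step (the helper name and data structures below are invented for illustration, not the real implementation):

def bind_and_inline(parent_nodes, sub_input_names, sub_nodes, actual_inputs, prefix):
    # Map each formal subgraph input to the tensor passed to the while node.
    binding = dict(zip(sub_input_names, actual_inputs))
    for node in sub_nodes:
        parent_nodes.append({
            "op": node["op"],
            # Inputs resolve through the binding, or to another copied node's output.
            "inputs": [binding.get(i, prefix + i) for i in node["inputs"]],
            # Outputs are prefixed so they cannot clash with parent tensor names.
            "outputs": [prefix + o for o in node["outputs"]],
        })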
Example #5
    def version_7(cls, ctx, node, **kwargs):
        # the tensorflow while input is:
        #   loop_counter, max_iterations, [loop_vars]
        # cond and body use the same inputs
        # outputs are identical to inputs
        tf_while_inputs = node.input

        # the onnx loop input is:
        #   max_iterations, cond, [loop_vars]
        # body uses the inputs:
        #   iteration, cond, [loop_vars]
        # the onnx loop output is:
        #   cond [v_final_and_scan_outputs]

        output_shapes = node.output_shapes
        output_dtypes = node.output_dtypes
        # node.output must be copied as some element
        # may be removed from output_names below
        output_names = node.output.copy()

        # Make maximum_iterations int64 and replace -1 (tf) with maxsize (onnx). If the const node has no other consumers,
        # modify it in place. Otherwise, make a new const node and leave the original unchanged.
        maximum_iterations_name = node.input[1]
        maximum_iterations = node.inputs[1].get_tensor_value()
        if maximum_iterations == -1:
            maximum_iterations = np.iinfo(np.int64).max
        consumers = ctx.find_output_consumers(maximum_iterations_name)
        external_consumers = [
            c for c in consumers if c != node and c.type != 'TensorListReserve'
        ]
        if len(external_consumers) == 0:
            ctx.remove_node(node.inputs[1].name)
        else:
            maximum_iterations_name = utils.make_name(node.inputs[1].name)
        ctx.make_const(maximum_iterations_name,
                       np.array(maximum_iterations, dtype=np.int64))
        node.input[1] = maximum_iterations_name

        cond_name = node.get_attr_str("cond")
        cond_graph = find_function(cond_name)
        cond_graph.parent_graph = ctx

        body_name = node.get_attr_str("body")
        body = find_function(body_name)
        body.parent_graph = ctx

        loop_vars = []  # passed into the loop
        body_input_to_state_var = {}  # Map from body input name to state var name
        cond_input_to_state_var = {}
        to_remove = []
        input_idx_to_remove = []
        # remove TensorListReserve
        for idx, name in enumerate(tf_while_inputs):
            if idx == 1:
                # onnx does not know maximum_iterations in the body so move this to a state var
                body_input_to_state_var[
                    body.func_inputs[idx]] = maximum_iterations_name
                cond_input_to_state_var[
                    cond_graph.func_inputs[idx]] = maximum_iterations_name
                continue
            if idx < 2:
                # skip  [0,1] loop_counter, max_iterations
                continue
            n = node.inputs[idx]
            if n.type in ["TensorListReserve", "TensorListResize"]:
                # there is no equivalent step in onnx and we should remove it.
                to_remove.append((idx, n))
                continue

            # tensor arrays we read from can't be loop_vars and we fetch them from the outer context instead
            if body.func_inputs[idx] in body.ta_reads:
                body_input_to_state_var[body.func_inputs[idx]] = name
                cond_input_to_state_var[cond_graph.func_inputs[idx]] = name
                input_idx_to_remove.append(idx)
            else:
                loop_vars.append(name)

        # loop_vars that become state_vars need to be removed from output as well
        for idx in reversed(input_idx_to_remove):
            del output_shapes[idx]
            del output_dtypes[idx]
            del output_names[idx]
            del body.outputs[idx]

        removed_scan_outputs = {}
        # remove tensor array that are passed in to the loop
        for idx, n in reversed(to_remove):
            ctx.remove_node(n.name)
            # make the node output bad
            ctx.replace_all_inputs(ctx.get_nodes(), n.output[0], "@@ALLOC")
            del body.func_inputs[idx]
            del cond_graph.func_inputs[idx]
            del tf_while_inputs[idx]
            # save the index of the scan output
            removed_scan_outputs[body.outputs[idx]] = idx
            del body.outputs[idx]
            # FIXME: Output shapes may be in wrong order if there are multiple scan outputs
            output_shapes.append(output_shapes[idx])
            output_dtypes.append(output_dtypes[idx])
            output_names.append(output_names[idx])
            del output_shapes[idx]
            del output_dtypes[idx]
            del output_names[idx]

        utils.make_sure(
            len(removed_scan_outputs) <= 1,
            "converter only supports while loops with a single scan output")

        ctx.remove_node(node.name)

        # In onnx 'cond' is a variable, not a function. We need to inject the subgraph into the main graph
        # before the loop and into the body.
        cond_binding = parameter_binding(cond_graph, tf_while_inputs)
        cond_outputs = inline_subgraph(ctx, cond_graph, cond_name,
                                       cond_binding)
        # onnx Loop op outputs only loop_vars so we need shift output dtypes/shapes and consumers
        output_shapes = output_shapes[2:]
        output_dtypes = output_dtypes[2:]
        output_names = output_names[2:]

        loop_node = ctx.make_node(
            "Loop", [maximum_iterations_name, cond_outputs[0]] + loop_vars,
            output_count=len(output_shapes),
            name=node.name + "_loop",
            shapes=output_shapes,
            dtypes=output_dtypes,
            skip_conversion=True)

        output_map = dict(zip(output_names, loop_node.output))

        # shift output consumers
        for k, v in output_map.items():
            ctx.replace_all_inputs(ctx.get_nodes(), k, v)

        wire_while_body(ctx, body, loop_node.inputs, body_input_to_state_var,
                        cond_input_to_state_var, output_shapes, output_dtypes,
                        body_name, node.name, cond_graph, tf_while_inputs,
                        removed_scan_outputs)

        # if there was a tensorflow variant type, bind in a real type here
        # FIXME: I don't think this is needed anymore
        for i, n in enumerate(body.inputs):
            if body.get_dtype(n.output[0]) == onnx_pb.TensorProto.UNDEFINED:
                body.set_dtype(n.output[0], ctx.get_dtype(loop_node.input[i]))
        loop_node.set_body_graph_as_attr("body", body)
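The maximum_iterations handling above rests on one detail: TF encodes "no limit" as -1, while the ONNX Loop trip count must be a concrete int64 scalar, so -1 is replaced by the largest int64 value. A tiny self-contained illustration (the helper name is made up):

import numpy as np

def normalize_max_iterations(value):
    # TF uses -1 for "unlimited"; ONNX Loop needs a real int64 trip count.
    if value == -1:
        return np.int64(np.iinfo(np.int64).max)
    return np.int64(value)

assert normalize_max_iterations(-1) == np.iinfo(np.int64).max
assert normalize_max_iterations(10) == 10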
Example #6
    def version_7(cls, ctx, node, **kwargs):
        # the tensorflow while input is:
        #   loop_counter, max_iterations, [loop_vars]
        # cond and body use the same inputs
        # outputs are identical to inputs
        tf_while_inputs = node.input

        # the onnx loop input is:
        #   max_iterations, cond, [loop_vars]
        # body uses the inputs:
        #   iteration, cond, [loop_vars]
        # the onnx loop output is:
        #   cond [v_final_and_scan_outputs]

        output_shapes = node.output_shapes
        output_dtypes = node.output_dtypes

        # make maximum_iterations int64 and replace -1(tf) with maxsize(onnx)
        maximum_iterations_name = node.input[1]
        maximum_iterations = node.inputs[1].get_tensor_value()
        ctx.remove_node(node.inputs[1].name)
        if maximum_iterations == -1:
            maximum_iterations = sys.maxsize
        ctx.make_const(maximum_iterations_name,
                       np.array(maximum_iterations, dtype=np.int64))

        cond_name = node.get_attr_str("cond")
        cond_graph = find_function(cond_name)
        cond_graph.parent_graph = ctx

        body_name = node.get_attr_str("body")
        body = find_function(body_name)
        body.parent_graph = ctx

        loop_vars = []  # passed into the loop
        state_vars = {}  # comes from outer context
        to_remove = []
        input_idx_to_remove = []
        # remove TensorListReserve
        for idx, name in enumerate(tf_while_inputs):
            if idx == 1:
                # onnx does not know maximum_iterations in the body so move this to a state var
                state_vars[body.func_inputs[idx]] = maximum_iterations_name
                continue
            if idx < 2:
                # skip  [0,1] loop_counter, max_iterations
                continue
            n = node.inputs[idx]
            if n.type in ["TensorListReserve", "TensorListResize"]:
                # there is no equivalent step in onnx and we should remove it.
                # But we make this an identity to keep the loop_vars the same on input and output
                # of the body but there should be no access to this argument in the body.
                to_remove.append((idx, n))
                continue

            # tensor arrays we read from can't be loop_vars and we fetch them from the outer context instead
            if body.func_inputs[idx] in body.ta_reads:
                state_vars[body.func_inputs[idx]] = name
                input_idx_to_remove.append(idx)
            else:
                loop_vars.append(name)

        # loop_vars that become state_vars need to be removed from output as well
        for idx in reversed(input_idx_to_remove):
            del output_shapes[idx]
            del output_dtypes[idx]
            del body.outputs[idx]

        # remove tensor array that are passed in to the loop
        for idx, n in reversed(to_remove):
            ctx.remove_node(n.name)
            # make the node output bad
            ctx.replace_all_inputs(ctx.get_nodes(), n.output[0], "@@ALLOC")
            del body.func_inputs[idx]
            del cond_graph.func_inputs[idx]
            del tf_while_inputs[idx]

        ctx.remove_node(node.name)

        # In onnx 'cond' is a variable, not a function. We need to inject the subgraph into the main graph
        # before the loop and into the body.
        cond_binding = parameter_binding(cond_graph, tf_while_inputs)
        cond_outputs = inline_subgraph(ctx, cond_graph, cond_name,
                                       cond_binding)
        # onnx Loop op outputs only loop_vars so we need shift output dtypes/shapes and consumers
        output_map = {
            node.output[i + 2]: node.output[i]
            for i in range(len(node.output) - 2)
        }
        output_shapes = output_shapes[2:]
        output_dtypes = output_dtypes[2:]

        loop_node = ctx.make_node(
            "Loop", [maximum_iterations_name, cond_outputs[0]] + loop_vars,
            output_count=len(output_shapes),
            name=node.name,
            shapes=output_shapes,
            dtypes=output_dtypes,
            skip_conversion=True)
        # shift output consumers
        for k, v in output_map.items():
            ctx.replace_all_inputs(ctx.get_nodes(), k, v)

        wire_while_body(ctx, body, loop_node.inputs, state_vars, output_shapes,
                        output_dtypes, body_name, node.name, cond_graph,
                        tf_while_inputs)

        # if there was a tensorflow variant type, bind in a real type here
        for i, n in enumerate(body.inputs):
            if body.get_dtype(n.output[0]) == onnx_pb.TensorProto.UNDEFINED:
                body.set_dtype(n.output[0], ctx.get_dtype(loop_node.input[i]))
        loop_node.set_body_graph_as_attr("body", body)
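The comment block at the top of this handler states the ONNX Loop I/O contract. A rough pure-Python model of those semantics (illustrative only, not how any runtime implements Loop) makes the cond/loop_vars/scan_outputs split concrete:

import numpy as np

def run_loop(max_trip_count, cond, loop_vars, body):
    """body(iteration, cond, loop_vars) -> (cond, loop_vars, scan_values)."""
    scan_accum = None
    for i in range(max_trip_count):
        if not cond:
            break
        cond, loop_vars, scan_values = body(i, cond, loop_vars)
        if scan_accum is None:
            scan_accum = [[] for _ in scan_values]
        for acc, value in zip(scan_accum, scan_values):
            acc.append(value)
    # Final loop_vars come first, then each scan output stacked along axis 0.
    scan_outputs = [np.stack(acc) for acc in (scan_accum or [])]
    return list(loop_vars) + scan_outputs

# Sum 0..4 while also emitting each partial sum as a scan output.
def body(i, cond, loop_vars):
    (total,) = loop_vars
    total = total + i
    return True, [total], [total]

print(run_loop(5, True, [0], body))  # [10, array([ 0,  1,  3,  6, 10])]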
Example #7
    def version_7(cls, ctx, node, **kwargs):
        # the tensorflow while input is:
        #   loop_counter, max_iterations, [loop_vars]
        # cond and body use the same inputs
        # outputs are identical to inputs
        tf_while_inputs = node.input

        # the onnx loop input is:
        #   max_iterations, cond, [loop_vars]
        # body uses the inputs:
        #   iteration, cond, [loop_vars]
        # the onnx loop output is:
        #   cond [v_final_and_scan_outputs]

        output_shapes = node.output_shapes
        output_dtypes = node.output_dtypes
        # node.output must be copied as some element
        # may be removed from output_names below
        output_names = node.output.copy()

        # Make maximum_iterations int64 and replace -1 (tf) with maxsize (onnx). If the const node has no other
        # consumers, modify it in place. Otherwise, make a new const node and leave the original unchanged.
        # If maximum_iterations is not const, add a Cast node (cast to int64).
        maximum_iterations_name = node.input[1]
        if node.inputs[1].is_const():
            maximum_iterations = node.inputs[1].get_tensor_value()
            if maximum_iterations == -1:
                maximum_iterations = np.iinfo(np.int64).max
            consumers = ctx.find_output_consumers(maximum_iterations_name)
            external_consumers = [
                c for c in consumers
                if c != node and c.type != 'TensorListReserve'
            ]
            if len(external_consumers) == 0:
                ctx.remove_node(node.inputs[1].name)
            else:
                maximum_iterations_name = utils.make_name(node.inputs[1].name)
            ctx.make_const(maximum_iterations_name,
                           np.array(maximum_iterations, dtype=np.int64))
            ctx.replace_input(node, node.input[1], maximum_iterations_name, 1)
            maximum_iterations_int64 = maximum_iterations_name
        else:
            cast_inputs = [maximum_iterations_name]
            attr = {"to": onnx_pb.TensorProto.INT64}
            cast_name = node.name + "_cast"
            cast_node = ctx.make_node("Cast",
                                      cast_inputs,
                                      attr,
                                      name=cast_name)
            maximum_iterations_int64 = cast_node.output[0]

        cond_name = node.get_attr_str("cond")
        cond_graph = find_function(cond_name)
        cond_graph.parent_graph = ctx

        body_name = node.get_attr_str("body")
        body = find_function(body_name)
        body.parent_graph = ctx

        loop_vars = []  # passed into the loop
        body_input_to_state_var = {}  # Map from body input name to state var name
        cond_input_to_state_var = {}
        scan_outputs = []
        input_idx_to_remove = []
        idx_to_ragged_writes = dict(body.ragged_variant_list_writes)
        # remove TensorListReserve
        for idx, name in enumerate(tf_while_inputs):
            if idx == 1:
                # onnx does not know maximum_iterations in the body so move this to a state var
                body_input_to_state_var[
                    body.input_names[idx]] = maximum_iterations_name
                cond_input_to_state_var[
                    cond_graph.input_names[idx]] = maximum_iterations_name
                continue
            if idx < 2:
                # skip  [0,1] loop_counter, max_iterations
                continue
            n = node.inputs[idx]
            if n.type in ["TensorListReserve", "TensorListResize"]:
                # there is no equivalent step in onnx and we should remove it.
                output_shape = None
                output_dtype = n.get_attr_value("element_dtype")
                is_ragged = False
                if n.type == "TensorListReserve" and n.inputs[0].is_const(
                ) and not n.inputs[0].is_scalar():
                    output_shape = [-1] + n.inputs[0].get_tensor_value(
                        as_list=True)
                if idx in idx_to_ragged_writes:
                    output_shape = None
                    output_dtype = body.get_dtype(
                        idx_to_ragged_writes[idx].input[0])
                    is_ragged = True
                    loop_vars.append(name)
                scan_outputs.append(
                    (idx, n, output_shape, output_dtype, is_ragged))
                continue

            # tensor arrays we read from can't be loop_vars and we fetch them from the outer context instead
            if body.input_names[idx] in body.ta_reads:
                body_input_to_state_var[body.input_names[idx]] = name
                cond_input_to_state_var[cond_graph.input_names[idx]] = name
                input_idx_to_remove.append(idx)
            else:
                loop_vars.append(name)

        # loop_vars that become state_vars need to be removed from output as well
        for idx in reversed(input_idx_to_remove):
            del output_shapes[idx]
            del output_dtypes[idx]
            del output_names[idx]
            del body.outputs[idx]

        scan_output_names = []
        ragged_scan_output_names = []
        ragged_scan_output_to_len = {}

        # remove tensor arrays that are passed in to the loop
        for idx, n, output_shape, output_dtype, is_ragged in reversed(scan_outputs):
            if is_ragged:
                out = n.output[0]
                ctx.remove_node(n.name)
                seq_empty = ctx.make_node("SequenceEmpty", [],
                                          attr={'dtype': output_dtype},
                                          name=n.name,
                                          outputs=[out],
                                          shapes=[None],
                                          dtypes=[utils.SeqType(output_dtype)])
                ctx.replace_all_inputs(n.output[0], seq_empty.output[0])
                # Ragged tensors also must track the length of each row
                output_shapes.append([-1])
                output_dtypes.append(TensorProto.INT64)
                output_shapes[idx] = None
                output_dtypes[idx] = utils.SeqType(output_dtype)
                body_ragged_name = utils.make_name("ragged_scan_output")
                external_ragged_name = utils.make_name("ragged_output")
                scan_output_names.append(body_ragged_name)
                output_names.append(external_ragged_name)
                ragged_scan_output_names.append(body_ragged_name)
                ragged_scan_output_to_len[
                    output_names[idx]] = external_ragged_name
                continue
            ctx.remove_node(n.name)
            # make the node output bad
            ctx.replace_all_inputs(n.output[0], "@@ALLOC")  # ops=ctx.get_nodes()
            del body.inputs[idx]
            del cond_graph.inputs[idx]
            del tf_while_inputs[idx]
            scan_output_names.append(body.outputs[idx])
            del body.outputs[idx]
            output_shapes.append(output_shape)
            output_dtypes.append(output_dtype)
            output_names.append(output_names[idx])
            del output_shapes[idx]
            del output_dtypes[idx]
            del output_names[idx]

        ctx.remove_node(node.name)

        # In onnx 'cond' is a variable, not a function. We need to inject the subgraph into the main graph
        # before the loop and into the body.
        cond_binding = parameter_binding(cond_graph, tf_while_inputs)
        cond_outputs = inline_subgraph(ctx, cond_graph, cond_name,
                                       cond_binding)
        # onnx Loop op outputs only loop_vars so we need shift output dtypes/shapes and consumers
        output_shapes = output_shapes[2:]
        output_dtypes = output_dtypes[2:]
        output_names = output_names[2:]

        branches = {"body": body}
        loop_node = ctx.make_node(
            "Loop", [maximum_iterations_int64, cond_outputs[0]] + loop_vars,
            output_count=len(output_shapes),
            name=node.name + "_loop",
            shapes=output_shapes,
            dtypes=output_dtypes,
            skip_conversion=True,
            branches=branches)

        output_map = dict(zip(output_names, loop_node.output))

        # shift output consumers
        for k, v in output_map.items():
            if k not in ragged_scan_output_to_len.values():
                ctx.replace_all_inputs(k, v)  # ops=ctx.get_nodes()

        ragged_scan_output_to_len = {
            output_map[k]: output_map[v]
            for k, v in ragged_scan_output_to_len.items()
        }

        wire_while_body(ctx, body, loop_node, body_input_to_state_var,
                        cond_input_to_state_var, output_shapes, output_dtypes,
                        body_name, node.name, cond_graph, tf_while_inputs,
                        scan_output_names, ragged_scan_output_names)

        loop_node.ragged_scan_output_to_len = ragged_scan_output_to_len
        # if there was a tensorflow variant type, bind in a real type here
        # FIXME: I don't think this is needed anymore
        for i, n in enumerate(body.inputs):
            if body.get_dtype(n.output[0]) == onnx_pb.TensorProto.UNDEFINED:
                body.set_dtype(n.output[0], ctx.get_dtype(loop_node.input[i]))
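In the ragged branch above, the variant TensorList stays a sequence-typed output and an extra int64 scan output records each row's length (the appended [-1]/INT64 shape and dtype). In plain numpy terms, purely for illustration, a ragged result is the pair "rows" plus "per-row lengths":

import numpy as np

# Rows of different lengths produced across loop iterations.
rows = [np.array([1, 2, 3]), np.array([4]), np.array([5, 6])]

# What the loop emits: a sequence-typed output holding the rows themselves and
# an int64 scan output of per-row lengths (the extra output added above).
row_lens = np.array([len(r) for r in rows], dtype=np.int64)

# Consumers can rebuild a dense representation plus lengths from the pair.
flat = np.concatenate(rows)
assert row_lens.tolist() == [3, 1, 2] and flat.tolist() == [1, 2, 3, 4, 5, 6]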
Example #8
def rewrite_gru_tf2(g, ops):
    pattern1 = make_grucell_pattern("Identity")
    pattern2 = keras_gru_pattern

    for pattern in [pattern1, pattern2]:
        matcher = GraphMatcher(pattern, allow_reorder=True)
        match_results = list(matcher.match_ops(ops))
        for match_result in match_results:
            activation_op = match_result.get_op("optional_activation")
            activations = ["Sigmoid", activation_op.type]
            if activation_op.type not in ["Relu", "Tanh", "Sigmoid"]:
                continue

            if pattern is pattern1:
                concat = match_result.get_op("cell_inputs")
                if len(concat.inputs) != 3:
                    continue
                get_item = concat.inputs[0]
                init_state = concat.inputs[1]
            else:
                get_item = match_result.get_op("gru_input")
                init_state = match_result.get_op("state")
            if not get_item.type == "TensorListGetItem":
                continue
            x_e = get_item.inputs[0]
            if not x_e.is_graph_input():
                continue
            x_idx = g.input_names.index(x_e.output[0])
            if not init_state.is_graph_input():
                continue
            init_state_idx = g.input_names.index(init_state.output[0])

            cell_output = match_result.get_op("cell_output")
            final_consumers = g.find_output_consumers(cell_output.output[0])
            select_ops = [n for n in final_consumers if n.type == "Select"]

            def has_tensor_list_consumer(n):
                return any(c.type == "TensorListSetItem"
                           for c in g.find_output_consumers(n.output[0]))

            select_ops = [n for n in select_ops if has_tensor_list_consumer(n)]
            if len(select_ops) == 1:
                greater_eq = select_ops[0].inputs[0]
                if greater_eq.type != "GreaterEqual":
                    continue
                seq_len = greater_eq.inputs[1]
                if not seq_len.is_graph_input():
                    continue
                seq_len_idx = g.input_names.index(seq_len.output[0])
                final_consumers = g.find_output_consumers(
                    select_ops[0].output[0])
            else:
                seq_len_idx = None

            tensor_set_items = [
                n for n in final_consumers if n.type == "TensorListSetItem"
            ]
            if len(tensor_set_items) != 1:
                continue

            if not tensor_set_items[0].inputs[0].is_graph_input():
                continue
            out_idx = g.input_names.index(tensor_set_items[0].input[0])

            hk = match_result.get_op("hidden_kernel")
            while hk.type == "Identity":
                hk = hk.inputs[0]
            if not hk.is_graph_input():
                continue
            hk_idx = g.input_names.index(hk.output[0])

            hb = match_result.get_op("hidden_bias")
            if not hb.is_graph_input():
                continue
            hb_idx = g.input_names.index(hb.output[0])

            gk = match_result.get_op("gate_kernel")
            while gk.type == "Identity":
                gk = gk.inputs[0]
            if not gk.is_graph_input():
                continue
            gk_idx = g.input_names.index(gk.output[0])

            gb = match_result.get_op("gate_bias")
            if not gb.is_graph_input():
                continue
            gb_idx = g.input_names.index(gb.output[0])

            bias_add = match_result.get_op("bias_add")
            if bias_add is not None and bias_add.data_format != "NHWC":
                continue

            g.gru_rewriter_context = {
                "x_idx": x_idx,
                "out_idx": out_idx,
                "initial_state_idx": init_state_idx,
                "hidden_kernel_idx": hk_idx,
                "hidden_bias_idx": hb_idx,
                "gate_kernel_idx": gk_idx,
                "gate_bias_idx": gb_idx,
                "seq_len_idx": seq_len_idx,
                "activations": activations,
                "from_keras": pattern is pattern2,
                "linear_before_reset": 1 if pattern is pattern2 else 0,
            }

    for op in ops:
        if op.is_while():
            body_graph = find_function(op.get_attr_str("body"))
            if body_graph.gru_rewriter_context is None:
                continue
            body_context = body_graph.gru_rewriter_context
            hk = op.input[body_context["hidden_kernel_idx"]]
            hb = op.input[body_context["hidden_bias_idx"]]
            gk = op.input[body_context["gate_kernel_idx"]]
            gb = op.input[body_context["gate_bias_idx"]]
            if not all(g.is_const(w) for w in [hk, hb, gk, gb]):
                continue
            hk_const = g.get_tensor_value(hk, as_list=False)
            hb_const = g.get_tensor_value(hb, as_list=False)
            gk_const = g.get_tensor_value(gk, as_list=False)
            gb_const = g.get_tensor_value(gb, as_list=False)

            initial_state_sq = op.input[body_context["initial_state_idx"]]
            initial_state = GraphBuilder(g).make_unsqueeze({
                "data": initial_state_sq,
                "axes": [0]
            })

            context = UnitRnnContext()
            context.from_keras = body_context["from_keras"]
            context.weights.update({
                "hidden_kernel": hk_const,
                "hidden_bias": hb_const,
                "gate_kernel": gk_const,
                "gate_bias": gb_const
            })
            context.attributes["activations"] = body_context["activations"]
            context.attributes["linear_before_reset"] = body_context[
                "linear_before_reset"]
            tensor_array_inp = op.inputs[body_context["x_idx"]]
            if not tensor_array_inp.type == "TensorListFromTensor":
                continue

            final_consumers = g.find_output_consumers(
                op.output[body_context["out_idx"]])
            output_ys = [
                n.output[0] for n in final_consumers
                if n.type == "TensorListStack"
            ]

            context.onnx_input_ids["X"] = tensor_array_inp.input[0]
            if body_context["seq_len_idx"] is None:
                context.onnx_input_ids["sequence_lens"] = ""
            else:
                context.onnx_input_ids["sequence_lens"] = op.input[
                    body_context["seq_len_idx"]]
            context.onnx_input_ids["initial_state"] = initial_state

            gru_rewriter = GRUUnitRewriter(g)
            gru_rewriter.process_weights_and_bias(context)
            gru_node = gru_rewriter.create_rnn_node(context)
            squeeze_output = GraphBuilder(g).make_squeeze(
                {"data": gru_node.output[0], "axes": [1]})
            for output in output_ys:
                g.replace_all_inputs(output, squeeze_output)

            f_state_squeeze = GraphBuilder(g).make_squeeze(
                {"data": gru_node.output[1], "axes": [0]})
            g.replace_all_inputs(op.output[body_context["initial_state_idx"]],
                                 f_state_squeeze)

    return g.get_nodes()
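The squeezes at the end of this rewriter follow from the ONNX GRU output layout: Y has shape [seq_length, num_directions, batch, hidden] and Y_h has shape [num_directions, batch, hidden], so a single-direction GRU drops axis 1 of Y and axis 0 of Y_h to match the TF tensor-list layout. A shape-only numpy check of that reasoning (sizes are arbitrary placeholders):

import numpy as np

seq_len, batch, hidden = 7, 4, 16
y = np.zeros((seq_len, 1, batch, hidden))    # GRU output Y, num_directions == 1
y_h = np.zeros((1, batch, hidden))           # GRU output Y_h

assert np.squeeze(y, axis=1).shape == (seq_len, batch, hidden)   # axes=[1] above
assert np.squeeze(y_h, axis=0).shape == (batch, hidden)          # axes=[0] above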
Example #9
def rewriter_lstm_tf2(g, ops):

    pattern1 = make_lstm_pattern(enter_or_id="Identity")  # TF LSTM
    pattern2 = make_lstm_pattern(from_keras=True, use_bias=False)  # keras LSTM
    pattern3 = make_lstm_pattern(from_keras=True,
                                 use_bias=True)  # keras LSTM with bias

    for pattern in [pattern1, pattern2, pattern3]:
        matcher = GraphMatcher(pattern, allow_reorder=False)
        match_results = list(matcher.match_ops(ops))

        for match_result in match_results:
            from_keras = pattern != pattern1
            activations_fgh = [
                match_result.get_op("ft").type,
                match_result.get_op("gt").type,
                match_result.get_op("ct'").type
            ]
            supported_activations = ['Relu', 'Sigmoid', 'Tanh']
            if any(f not in supported_activations for f in activations_fgh):
                continue

            # extract input x_t
            if from_keras:
                get_item = match_result.get_op("xt")
            else:
                concat = match_result.get_op("xh")
                if len(concat.inputs) != 3:
                    continue
                get_item = concat.inputs[0]
            if not get_item.type == "TensorListGetItem":
                continue
            x_e = get_item.inputs[0]
            if not x_e.is_graph_input():
                continue
            x_idx = g.input_names.index(x_e.output[0])

            # extract output h_t
            ht_mul = match_result.get_op("ht")
            final_consumers = g.find_output_consumers(ht_mul.output[0])
            select_ops = [n for n in final_consumers if n.type == "Select"]

            def has_tensor_list_consumer(n):
                return any(c.type == "TensorListSetItem"
                           for c in g.find_output_consumers(n.output[0]))

            select_ops = [n for n in select_ops if has_tensor_list_consumer(n)]
            if len(select_ops) == 1:
                greater_eq = select_ops[0].inputs[0]
                if greater_eq.type != "GreaterEqual":
                    continue
                seq_len = greater_eq.inputs[1]
                if not seq_len.is_graph_input():
                    continue
                seq_len_idx = g.input_names.index(seq_len.output[0])
                final_consumers = g.find_output_consumers(
                    select_ops[0].output[0])
            else:
                seq_len_idx = None

            tensor_set_items = [
                n for n in final_consumers if n.type == "TensorListSetItem"
            ]
            if len(tensor_set_items) != 1:
                continue

            if not tensor_set_items[0].inputs[0].is_graph_input():
                continue
            out_idx = g.input_names.index(tensor_set_items[0].input[0])

            # extract input h_(t-1) and c_(t-1)
            init_state = match_result.get_op("ht-1") if from_keras else concat.inputs[1]
            if init_state.is_graph_input():
                # c and h are separate
                h_idx = g.input_names.index(init_state.output[0])
                c_e = match_result.get_op("c")
                if not c_e.is_graph_input():
                    continue
                c_idx = g.input_names.index(c_e.output[0])
                ch_info = {
                    "state_is_tuple": True,
                    "c_idx": c_idx,
                    "h_idx": h_idx,
                }
            else:
                # c and h are concatenated
                if not init_state.type == "Slice":
                    continue
                ch_e = init_state.inputs[0]
                if not ch_e.is_graph_input():
                    continue
                ch_idx = g.input_names.index(ch_e.output[0])

                c_e = match_result.get_op("c")
                if not c_e.type == "Slice" or c_e.input[0] != ch_e.output[0]:
                    continue
                ch_info = {
                    "state_is_tuple": False,
                    "ch_idx": ch_idx,
                }

            # extract weights and bias
            w_idx = hk_idx = gk_idx = 0
            ft_bias = None

            if from_keras:
                # hidden kernel
                hk = match_result.get_op("R")
                while hk.type == "Identity":
                    hk = hk.inputs[0]
                if not hk.is_graph_input():
                    continue
                hk_idx = g.input_names.index(hk.output[0])

                # gate kernel
                gk = match_result.get_op("W")
                while gk.type == "Identity":
                    gk = gk.inputs[0]
                if not gk.is_graph_input():
                    continue
                gk_idx = g.input_names.index(gk.output[0])

                # Wb and Rb are concatenated
                b_idx = None
                if pattern is pattern3:
                    bias_add = match_result.get_op("bias_add")
                    if bias_add is not None and bias_add.data_format != "NHWC":
                        continue

                    b_e = match_result.get_op("cell_bias")
                    while b_e.type == "Identity":
                        b_e = b_e.inputs[0]
                    if not b_e.is_graph_input():
                        continue
                    b_idx = g.input_names.index(b_e.output[0])

            else:
                # W and R are concatenated
                w_e = match_result.get_op("cell_kernel")
                if not w_e.is_graph_input():
                    continue
                w_idx = g.input_names.index(w_e.output[0])

                bias_add = match_result.get_op("bias_add")
                if bias_add is not None and bias_add.data_format != "NHWC":
                    continue

                b_e = match_result.get_op("cell_bias")
                if not b_e.is_graph_input():
                    continue
                b_idx = g.input_names.index(b_e.output[0])

                ft_bias_node = match_result.get_op("ft_bias")
                if not ft_bias_node.is_const():
                    continue
                if g.get_dtype(ft_bias_node.output[0]) != g.get_dtype(
                        b_e.output[0]):
                    continue
                ft_bias = ft_bias_node.get_tensor_value(as_list=False)

            g.lstm_rewriter_context = {
                # common
                "x_idx": x_idx,
                "out_idx": out_idx,
                "seq_len_idx": seq_len_idx,
                "bias_idx": b_idx,
                "from_keras": from_keras,
                "activations_fgh": activations_fgh,
                **ch_info,  # {state_is_tuple, h_idx, c_idx} or {state_is_tuple, ch_idx}

                # TF
                "weight_idx": w_idx,
                "ft_bias": ft_bias,

                # Keras
                "w_idx": gk_idx,
                "r_idx": hk_idx,
            }

    for op in ops:
        if op.is_while():
            body_graph = find_function(op.get_attr_str("body"))
            if body_graph.lstm_rewriter_context is None:
                continue
            body_context = body_graph.lstm_rewriter_context

            # parse weights
            consts = []
            if body_context["from_keras"]:
                wx = op.input[body_context["w_idx"]]
                wh = op.input[body_context["r_idx"]]
                wx_const = g.get_tensor_value(wx, as_list=False)
                wh_const = g.get_tensor_value(wh, as_list=False)
                consts.extend([wx, wh])
            else:
                w = op.input[body_context["weight_idx"]]
                w_const = g.get_tensor_value(w, as_list=False)
                consts.append(w)

            # parse bias
            if body_context["bias_idx"] is not None:
                b = op.input[body_context["bias_idx"]]
                b_const = g.get_tensor_value(b, as_list=False)
                consts.append(b)
            else:
                b_const = None

            if not all(g.is_const(c) for c in consts):
                continue

            # parse states
            if body_context["state_is_tuple"]:
                initial_c_sq = op.input[body_context["c_idx"]]
                initial_h_sq = op.input[body_context["h_idx"]]
                initial_c = GraphBuilder(g).make_unsqueeze({
                    "data": initial_c_sq,
                    "axes": [0]
                })
                initial_h = GraphBuilder(g).make_unsqueeze({
                    "data": initial_h_sq,
                    "axes": [0]
                })
            else:
                initial_ch = op.input[body_context["ch_idx"]]
                if not g.is_const(initial_ch):
                    continue
                initial_ch_const = g.get_tensor_value(initial_ch,
                                                      as_list=False)
                if not len(initial_ch_const.shape) == 2:
                    continue
                initial_ch_const = np.expand_dims(initial_ch_const, axis=0)
                initial_c_const, initial_h_const = np.split(initial_ch_const, 2, axis=2)
                initial_c = g.make_const(utils.make_name("initial_c"),
                                         initial_c_const).output[0]
                initial_h = g.make_const(utils.make_name("initial_h"),
                                         initial_h_const).output[0]

            # build LSTMContext
            context = LSTMContext()
            context.from_keras = body_context["from_keras"]

            if context.from_keras:
                context.weights.append({
                    "w": wx_const,
                    "r": wh_const,
                    "bias": b_const
                })
            else:
                context.weights.append({
                    "weight": w_const,
                    "bias": b_const,
                    "ft_bias": body_context["ft_bias"]
                })

            context.onnx_input_ids.append({})
            context.input_size.append(None)
            context.hidden_size.append(None)
            context.attributes.append(
                {"activations": body_context['activations_fgh']})
            tensor_array_inp = op.inputs[body_context["x_idx"]]
            if not tensor_array_inp.type == "TensorListFromTensor":
                continue

            final_consumers = g.find_output_consumers(
                op.output[body_context["out_idx"]])
            output_ys = [
                n.output[0] for n in final_consumers
                if n.type == "TensorListStack"
            ]

            context.onnx_input_ids[0]["X"] = tensor_array_inp.input[0]
            if body_context["seq_len_idx"] is None:
                context.onnx_input_ids[0]["sequence_lens"] = ""
            else:
                context.onnx_input_ids[0]["sequence_lens"] = op.input[
                    body_context["seq_len_idx"]]
            context.onnx_input_ids[0]["initial_c"] = initial_c
            context.onnx_input_ids[0]["initial_h"] = initial_h

            lstm_rewriter = LSTMRewriter(g)
            lstm_rewriter.num_lstm_layers = 1

            lstm_rewriter.process_weights_and_bias(context)
            lstm_node = lstm_rewriter.create_rnn_node(context)[0]

            squeeze_output = GraphBuilder(g).make_squeeze(
                {"data": lstm_node.output[0], "axes": [1]})
            for output in output_ys:
                g.replace_all_inputs(output, squeeze_output)

            if body_context["state_is_tuple"]:
                c_squeeze = GraphBuilder(g).make_squeeze(
                    {"data": lstm_node.output[2], "axes": [0]})
                h_squeeze = GraphBuilder(g).make_squeeze(
                    {"data": lstm_node.output[1], "axes": [0]})
                g.replace_all_inputs(op.output[body_context["c_idx"]],
                                     c_squeeze)
                g.replace_all_inputs(op.output[body_context["h_idx"]],
                                     h_squeeze)
            else:
                concat_ch = g.make_node(
                    "Concat", [lstm_node.output[2], lstm_node.output[1]],
                    attr={"axis": 2}).output[0]
                ch_squeeze = GraphBuilder(g).make_squeeze({
                    "data": concat_ch,
                    "axes": [0]
                })
                ch_output = op.output[body_context["ch_idx"]]
                g.replace_all_inputs(ch_output, ch_squeeze)

    return g.get_nodes()
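When state_is_tuple is False above, the fused [batch, 2*hidden] constant is expanded to rank 3 and split along axis 2 into initial_c and initial_h. The same shape arithmetic in numpy (sizes are arbitrary placeholders):

import numpy as np

batch, hidden = 4, 8
fused_ch = np.zeros((batch, 2 * hidden))              # c and h concatenated
fused_ch = np.expand_dims(fused_ch, axis=0)           # -> [1, batch, 2*hidden]
initial_c, initial_h = np.split(fused_ch, 2, axis=2)  # two [1, batch, hidden] halves
assert initial_c.shape == initial_h.shape == (1, batch, hidden)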
Example #10
def rewriter_lstm_tf2(g, ops):
    pattern1 = make_lstmcell_pattern("Identity")

    for pattern in [pattern1]:
        matcher = GraphMatcher(pattern, allow_reorder=False)
        match_results = list(matcher.match_ops(ops))
        for match_result in match_results:
            concat = match_result.get_op("xh")
            if len(concat.inputs) != 3:
                continue
            get_item = concat.inputs[0]
            if not get_item.type == "TensorListGetItem":
                continue
            x_e = get_item.inputs[0]
            if not x_e.is_graph_input():
                continue
            x_idx = g.input_names.index(x_e.output[0])

            ht_mul = match_result.get_op("ht")
            final_consumers = g.find_output_consumers(ht_mul.output[0])
            select_ops = [n for n in final_consumers if n.type == "Select"]
            def has_tensor_list_consumer(n):
                return any(c.type == "TensorListSetItem" for c in g.find_output_consumers(n.output[0]))
            select_ops = [n for n in select_ops if has_tensor_list_consumer(n)]
            if len(select_ops) == 1:
                greater_eq = select_ops[0].inputs[0]
                if greater_eq.type != "GreaterEqual":
                    continue
                seq_len = greater_eq.inputs[1]
                if not seq_len.is_graph_input():
                    continue
                seq_len_idx = g.input_names.index(seq_len.output[0])
                final_consumers = g.find_output_consumers(select_ops[0].output[0])
            else:
                seq_len_idx = None

            tensor_set_items = [n for n in final_consumers if n.type == "TensorListSetItem"]
            if len(tensor_set_items) != 1:
                continue

            if not tensor_set_items[0].inputs[0].is_graph_input():
                continue
            out_idx = g.input_names.index(tensor_set_items[0].input[0])

            if concat.inputs[1].is_graph_input():
                # c and h are separate
                h_idx = g.input_names.index(concat.input[1])
                c_e = match_result.get_op("c")
                if not c_e.is_graph_input():
                    continue
                c_idx = g.input_names.index(c_e.output[0])
                ch_info = {
                    "state_is_tuple": True,
                    "c_idx": c_idx,
                    "h_idx": h_idx,
                }
            else:
                # c and h are concatenated
                if not concat.inputs[1].type == "Slice":
                    continue
                ch_e = concat.inputs[1].inputs[0]
                if not ch_e.is_graph_input():
                    continue
                ch_idx = g.input_names.index(ch_e.output[0])

                c_e = match_result.get_op("c")
                if not c_e.type == "Slice" or c_e.input[0] != ch_e.output[0]:
                    continue
                ch_info = {
                    "state_is_tuple": False,
                    "ch_idx": ch_idx,
                }

            w_e = match_result.get_op("cell_kernel")
            if not w_e.is_graph_input():
                continue
            w_idx = g.input_names.index(w_e.output[0])

            bias_add = match_result.get_op("bias_add")
            if bias_add is not None and bias_add.data_format != "NHWC":
                continue

            b_e = match_result.get_op("cell_bias")
            if not b_e.is_graph_input():
                continue
            b_idx = g.input_names.index(b_e.output[0])

            ft_bias_node = match_result.get_op("ft_bias")
            if not ft_bias_node.is_const():
                continue
            if g.get_dtype(ft_bias_node.output[0]) != g.get_dtype(b_e.output[0]):
                continue
            ft_bias = ft_bias_node.get_tensor_value(as_list=False)

            g.lstm_rewriter_context = {
                "x_idx": x_idx,
                "out_idx": out_idx,
                "weight_idx": w_idx,
                "bias_idx": b_idx,
                "ft_bias": ft_bias,
                "seq_len_idx": seq_len_idx,
                **ch_info
            }

    for op in ops:
        if op.is_while():
            body_graph = find_function(op.get_attr_str("body"))
            if body_graph.lstm_rewriter_context is None:
                continue
            body_context = body_graph.lstm_rewriter_context
            w = op.input[body_context["weight_idx"]]
            b = op.input[body_context["bias_idx"]]
            if not g.is_const(w) or not g.is_const(b):
                continue
            w_const = g.get_tensor_value(w, as_list=False)
            b_const = g.get_tensor_value(b, as_list=False)

            if body_context["state_is_tuple"]:
                initial_c_sq = op.input[body_context["c_idx"]]
                initial_h_sq = op.input[body_context["h_idx"]]
                initial_c = GraphBuilder(g).make_unsqueeze({"data": initial_c_sq, "axes": [0]})
                initial_h = GraphBuilder(g).make_unsqueeze({"data": initial_h_sq, "axes": [0]})
            else:
                initial_ch = op.input[body_context["ch_idx"]]
                if not g.is_const(initial_ch):
                    continue
                initial_ch_const = g.get_tensor_value(initial_ch, as_list=False)
                if not len(initial_ch_const.shape) == 2:
                    continue
                initial_ch_const = np.expand_dims(initial_ch_const, axis=0)
                initial_c_const, initial_h_const = np.split(initial_ch_const, 2, axis=2)
                initial_c = g.make_const(utils.make_name("initial_c"), initial_c_const).output[0]
                initial_h = g.make_const(utils.make_name("initial_h"), initial_h_const).output[0]

            context = LSTMContext()
            context.weights.append({"weight": w_const, "bias": b_const, "ft_bias": body_context["ft_bias"]})
            context.onnx_input_ids.append({})
            context.input_size.append(None)
            context.hidden_size.append(None)
            context.attributes.append({})
            tensor_array_inp = op.inputs[body_context["x_idx"]]
            if not tensor_array_inp.type == "TensorListFromTensor":
                continue

            final_consumers = g.find_output_consumers(op.output[body_context["out_idx"]])
            output_ys = [n.output[0] for n in final_consumers if n.type == "TensorListStack"]

            context.onnx_input_ids[0]["X"] = tensor_array_inp.input[0]
            if body_context["seq_len_idx"] is None:
                context.onnx_input_ids[0]["sequence_lens"] = ""
            else:
                context.onnx_input_ids[0]["sequence_lens"] = op.input[body_context["seq_len_idx"]]
            context.onnx_input_ids[0]["initial_c"] = initial_c
            context.onnx_input_ids[0]["initial_h"] = initial_h

            lstm_rewriter = LSTMRewriter(g)
            lstm_rewriter.num_lstm_layers = 1
            lstm_rewriter.process_weights_and_bias(context)
            lstm_node = lstm_rewriter.create_rnn_node(context)[0]
            squeeze_output = GraphBuilder(g).make_squeeze({"data": lstm_node.output[0], "axes": [1]})
            for output in output_ys:
                g.replace_all_inputs(output, squeeze_output)

            if body_context["state_is_tuple"]:
                c_squeeze = GraphBuilder(g).make_squeeze({"data": lstm_node.output[2], "axes": [0]})
                h_squeeze = GraphBuilder(g).make_squeeze({"data": lstm_node.output[1], "axes": [0]})
                g.replace_all_inputs(op.output[body_context["c_idx"]], c_squeeze)
                g.replace_all_inputs(op.output[body_context["h_idx"]], h_squeeze)
            else:
                concat_ch = g.make_node("Concat", [lstm_node.output[2], lstm_node.output[1]],
                                        attr={"axis": 2}).output[0]
                ch_squeeze = GraphBuilder(g).make_squeeze({"data": concat_ch, "axes": [0]})
                ch_output = op.output[body_context["ch_idx"]]
                g.replace_all_inputs(ch_output, ch_squeeze)

    return g.get_nodes()
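On the output side of that non-tuple path, the rewriters rebuild the fused state by concatenating Y_c and Y_h (each [num_directions, batch, hidden]) along axis 2 and squeezing axis 0. A numpy sketch of the resulting shape (sizes are arbitrary placeholders):

import numpy as np

batch, hidden = 4, 8
y_c = np.zeros((1, batch, hidden))   # LSTM output 2 (Y_c), single direction
y_h = np.zeros((1, batch, hidden))   # LSTM output 1 (Y_h), single direction
fused = np.squeeze(np.concatenate([y_c, y_h], axis=2), axis=0)
assert fused.shape == (batch, 2 * hidden)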