Exemplo n.º 1
0
def select_op8(ctx, node, name, args):
    """Convert a Select node into Size + Loop + Identity nodes.

    # T output = Select(bool condition, T x, T y)
    # V v_final_and_scan_outputs = Loop(int64 M, B cond, V v_initial)

    The number of elements of the condition (ONNX Size) is used as the Loop
    trip count; the loop nodes themselves come from create_loop_op, and an
    Identity carrying node.name forwards the loop's scan output.

    Args:
        ctx: graph being converted.
        node: the Select node; node.input is (condition, x, y).
        name: unused; part of the handler signature.
        args: unused; part of the handler signature.

    Returns:
        list of new nodes: the Size node, the nodes returned by create_loop_op
        (the last one is assumed to be the Loop), and the final Identity.

    Raises:
        ValueError: if the Select only has its condition input.
    """
    if len(node.input) == 1:
        raise ValueError("Select with only condition is not supported.")
    data_type = ctx.get_dtype(node.input[1])

    # Trip count: number of elements of the condition tensor.
    size_op_name = utils.make_name("Size")
    size_out = port_name(size_op_name)
    size_node = Node(
        helper.make_node("Size", [node.input[0]], [size_out], name=size_op_name),
        ctx)

    result = [size_node] + create_loop_op(ctx, node, size_node.output[0],
                                          data_type)

    # Output 1 of the last created node (the Loop) is the scan output.
    scan_output_id = result[-1].output[1]
    ctx.copy_shape(node.output[0], scan_output_id)
    ctx.set_dtype(node.output[0], data_type)

    identity = helper.make_node("Identity", [scan_output_id],
                                [port_name(node.name)], name=node.name)
    result.append(Node(identity, ctx))
    return result
    def _create_transpose_pairs_after_node(self, node):
        """Insert a Transpose(0, 3, 1, 2) -> Transpose(0, 2, 3, 1) pair after `node`.

        One pair is created for every consumer returned by
        self._get_non_nchw_transpose_output_nodes(node); that consumer is rewired
        (self._g.replace_input) to read the second Transpose's output instead
        of node.output[0].

        Returns:
            list of the created Nodes (two per rewired consumer). When non-empty
            they are passed to self._update_graph_nodes(created, None, True).
        """
        assert len(node.output) == 1  # just support node who has 1 output
        src_output = node.output[0]
        created = []
        for consumer in self._get_non_nchw_transpose_output_nodes(node):
            to_nchw_name = utils.make_name("Transpose")
            to_nchw_out = to_nchw_name + ":0"
            to_nchw = helper.make_node("Transpose", [src_output], [to_nchw_out],
                                       name=to_nchw_name, perm=[0, 3, 1, 2])

            to_nhwc_name = utils.make_name("Transpose")
            to_nhwc_out = to_nhwc_name + ":0"
            to_nhwc = helper.make_node("Transpose", [to_nchw_out], [to_nhwc_out],
                                       name=to_nhwc_name, perm=[0, 2, 3, 1])

            pair = [Node(to_nchw, self._g), Node(to_nhwc, self._g)]
            self._g.replace_input(consumer, src_output, to_nhwc_out)
            created.extend(pair)

        if created:
            self._update_graph_nodes(created, None, True)
        return created
Exemplo n.º 3
0
def create_loop_op(ctx, node, batch_val_input_id, data_type):
    """Create the nodes of an ONNX Loop whose body comes from create_loop_body_graph.

    Args:
        ctx: graph the new nodes belong to; the body graph is also registered
            on it with ctx.add_body_graph.
        node: the op being converted; node.input[0] is handed to
            create_loop_body_graph.
        batch_val_input_id: tensor id used as the Loop trip count.
        data_type: ONNX dtype forwarded to create_loop_body_graph.

    Returns:
        [condition Constant, fake loop-carried Constant, Loop]. The Loop is
        always last; its outputs are port 0 and port 1 of the Loop op name, with
        port 1 being the scan output callers read.
    """
    nodes = []

    # Initial termination condition: constant scalar True. Names come from
    # make_name so that converting several ops in one graph does not create
    # duplicate "condition"/"fake_var" tensors (ONNX requires unique names).
    cond_var_name = utils.make_name("condition")
    true = helper.make_tensor(cond_var_name, TensorProto.BOOL, (), [True])
    init_cond = Node(
        helper.make_node("Constant", [], [cond_var_name], value=true), ctx)
    nodes.append(init_cond)

    # Loop requires at least one loop-carried variable; add an unused one.
    fake_val_name = utils.make_name("fake_var")
    fake_var_init_val = helper.make_tensor(fake_val_name, TensorProto.FLOAT,
                                           (), [0.0])
    fake_var_init_node = Node(
        helper.make_node("Constant", [], [fake_val_name],
                         value=fake_var_init_val), ctx)
    nodes.append(fake_var_init_node)

    op_name = utils.make_name("Loop")
    out_name = port_name(op_name)
    loop_inputs = [
        batch_val_input_id,  # trip count
        cond_var_name,  # termination condition
        fake_val_name,  # initial value of loop-carried dependencies
    ]
    loop_scan_output_id = port_name(op_name, 1)
    loop_node = Node(
        helper.make_node("Loop",
                         loop_inputs, [out_name, loop_scan_output_id],
                         name=op_name), ctx)
    loop_body = create_loop_body_graph(ctx, node, node.input[0], data_type)
    loop_node.set_attr("body", loop_body)
    ctx.add_body_graph(out_name, loop_body)
    nodes.append(loop_node)
    return nodes
Exemplo n.º 4
0
def minmax_op(ctx, node, name, args):
    """Make Min/Max broadcasting explicit for ONNX opsets without it.

    TensorFlow minimum/maximum support broadcast; ONNX <= opset 7 does not.
    For every input whose shape differs from the output shape, an
    Add(input, zeros(output_shape)) node is injected in front of `node` so the
    Add performs the broadcast.

    Args:
        ctx: graph being converted.
        node: the Min/Max node; its inputs are rewired in place.
        name: unused; part of the handler signature.
        args: unused; part of the handler signature.

    Returns:
        `node` itself when no input needs broadcasting, otherwise a list of the
        injected Add nodes followed by `node`.
    """
    shapeo = ctx.get_shape(node.output[0])
    needs_broadcast_op = [i for i, input_id in enumerate(node.input)
                          if ctx.get_shape(input_id) != shapeo]
    if not needs_broadcast_op:
        return node

    new_nodes = []
    for i in needs_broadcast_op:
        input_node = node.inputs[i]
        # Use the exact tensor id `node` consumes: input_node.output[0] would be
        # wrong when the producer has several outputs and a later one is used.
        input_id = node.input[i]
        dtype = ctx.dtypes[input_id]
        zero_name = utils.make_name(input_node.name)
        # NOTE(review): three-argument make_const(name, op_type, value) form;
        # post_optimize_action in this file calls make_const(name, value).
        ctx.make_const(
            zero_name, "Const",
            np.zeros(shapeo, dtype=utils.ONNX_TO_NUMPY_DTYPE[dtype]))
        op_name = utils.make_name(input_node.name)
        output_name = op_name + ":0"
        add_node = Node(
            helper.make_node("Add", [input_id, zero_name],
                             [output_name],
                             name=op_name), ctx)
        node.input[i] = output_name
        new_nodes.append(add_node)
    new_nodes.append(node)
    return new_nodes
Exemplo n.º 5
0
def rewrite_dropout(g, ops):
    """Replace a matched dropout-style subgraph with one ONNX Dropout node.

    Matches Mul(RealDiv(...), Floor(Add(<any>, RandomUniform))). The Dropout is
    fed the RealDiv's first input and takes over the Mul's consumers through
    g.replace_subgraph.

    NOTE(review): ratio is hard-coded to 1.0; the op captured as "input3" (the
    operand added to RandomUniform) is matched but not used.

    Args:
        g: graph owning `ops`.
        ops: list of ops to rewrite.

    Returns:
        the op list returned by the last g.replace_subgraph call (or `ops`
        unchanged when nothing matched).
    """
    dropout_pattern = OpTypePattern('Mul', name='outputs', inputs=[
        OpTypePattern('RealDiv', name="input2"),
        OpTypePattern('Floor', inputs=[
            OpTypePattern('Add', inputs=[
                OpTypePattern(None, name="input3"),
                OpTypePattern('RandomUniform'),
            ])
        ]),
    ])
    # Materialize all matches up front: `ops` is rebound inside the loop.
    found = list(GraphMatcher(dropout_pattern).match_ops(ops))
    for match in found:
        div_op = match.get_op('input2')
        mul_op = match.get_op('outputs')
        dropout_name = utils.make_name("Dropout")
        onnx_dropout = helper.make_node("Dropout", [div_op.input[0]],
                                        [dropout_name + ":0"],
                                        name=dropout_name, ratio=1.0)
        dropout = Node(onnx_dropout, g)
        ops = g.replace_subgraph(ops, match, [div_op], [mul_op], [dropout],
                                 [dropout])
    return ops
Exemplo n.º 6
0
def rewrite_random_normal(g, ops):
    """Fold RandomStandardNormal * scale + mean into a single ONNX RandomNormal.

    Matches Add(Mul(RandomStandardNormal(*), scale), mean) and replaces the
    whole subgraph with RandomNormal(mean, scale) using the Add's output shape
    and dtype. mean and scale are read as the first element of the tensor
    values of the Add's and Mul's second operands.

    Args:
        g: graph owning `ops`.
        ops: list of ops to rewrite.

    Returns:
        the op list returned by the last g.replace_subgraph call (or `ops`
        unchanged when nothing matched).
    """
    pattern = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', name='input2', inputs=[
                OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"]), "*"
            ]), "*"
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        output = match.get_op('output')
        mul = match.get_op('input2')
        mean = output.inputs[1].get_tensor_value()[0]
        # The Mul's second operand is the standard deviation. It used to be
        # ignored (scale fixed at 1.0), silently dropping any stddev != 1.
        scale = mul.inputs[1].get_tensor_value()[0]
        shape = g.get_shape(output.output[0])
        dtype = output.dtype
        op_name = utils.make_name("RandomNormal")
        out_name = op_name + ":0"
        new_node = Node(
            helper.make_node("RandomNormal", [], [out_name],
                             name=op_name,
                             shape=shape,
                             mean=mean,
                             scale=scale,
                             dtype=dtype), g)
        ops = g.replace_subgraph(ops, match, [], [output], [], [new_node])

    return ops
Exemplo n.º 7
0
def make_onnx_node(g,
                   op_type,
                   inputs,
                   attr=None,
                   output_count=1,
                   skip_conversion=True,
                   op_name_scope=None):
    """Create a Node of `op_type` with a freshly generated unique name.

    Args:
        g: graph the node is created in.
        op_type: ONNX op type.
        inputs: input tensor ids.
        attr: keyword attributes forwarded to helper.make_node (None means none).
        output_count: number of outputs, named "<node name>:<index>".
        skip_conversion: forwarded to Node.
        op_name_scope: if truthy, the generated name is based on
            "<op_name_scope>_<op_type>" instead of op_type.

    Returns:
        the created Node.
    """
    onnx_attrs = attr if attr is not None else {}
    basis = "_".join([op_name_scope, op_type]) if op_name_scope else op_type
    node_name = utils.make_name(basis)
    output_ids = ["{}:{}".format(node_name, idx)
                  for idx in np.arange(output_count)]
    onnx_node = helper.make_node(op_type, inputs, output_ids,
                                 name=node_name, **onnx_attrs)
    return Node(onnx_node, g, skip_conversion=skip_conversion)
Exemplo n.º 8
0
 def test_rewrite_subgraph(self):
     """Each Abs(Add(...)) match is replaced by a Sub over the Add's inputs."""
     sample = self.sample_net()
     graph = tf2onnx.graph.Graph(sample.node, output_shapes={}, dtypes={})
     abs_of_add = OpTypePattern('Abs', name='output',
                                inputs=[OpTypePattern('Add', name='input')])
     all_ops = graph.get_nodes()
     found = list(GraphMatcher(abs_of_add).match_ops(all_ops))
     for match in found:
         add_op = match.get_op('input')
         abs_op = match.get_op('output')
         sub_name = tf2onnx.utils.make_name("ReplacedOp")
         sub_onnx = helper.make_node("Sub", add_op.input, [sub_name + ":0"],
                                     name=sub_name)
         sub_node = Node(sub_onnx, graph)
         all_ops = graph.replace_subgraph(all_ops, match, [], [abs_op], [],
                                          [sub_node])
     graph.topological_sort(all_ops)
     actual = onnx_to_graphviz(graph)
     expected = ('digraph { n1 [op_type=Abs] n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__2 [op_type=Sub] '
                 'n6 [op_type=Identity] input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__2 '
                 'n3:0 -> ReplacedOp__2 ReplacedOp__2:0 -> n6 }')
     self.assertEqual(expected, actual)
Exemplo n.º 9
0
def rewrite_random_uniform(g, ops):
    """Fold RandomUniform * (max - min) + <addend> into a single ONNX RandomUniform.

    The bounds are read from the matched Sub: input 0 is the upper bound and
    input 1 the lower bound (first element of each tensor value). Shape and
    dtype come from the matched Add.

    Args:
        g: graph owning `ops`.
        ops: list of ops to rewrite.

    Returns:
        the op list returned by the last g.replace_subgraph call (or `ops`
        unchanged when nothing matched).
    """
    scaled_uniform = OpTypePattern('Add', name='output', inputs=[
        OpTypePattern('Mul', inputs=[
            OpTypePattern('RandomUniform', name='input1'),
            OpTypePattern('Sub', name='input2', inputs=["*", "*"]),
        ]),
        None,
    ])
    # Materialize all matches up front: `ops` is rebound inside the loop.
    found = list(GraphMatcher(scaled_uniform).match_ops(ops))
    for match in found:
        sub_op = match.get_op('input2')
        add_op = match.get_op('output')
        # max is on input 0
        upper = sub_op.inputs[0].get_tensor_value()[0]
        lower = sub_op.inputs[1].get_tensor_value()[0]
        out_shape = g.get_shape(add_op.output[0])
        out_dtype = add_op.dtype
        ru_name = utils.make_name("RandomUniform")
        ru_onnx = helper.make_node("RandomUniform", [], [ru_name + ":0"],
                                   name=ru_name, low=lower, high=upper,
                                   dtype=out_dtype, shape=out_shape)
        ops = g.replace_subgraph(ops, match, [], [add_op], [],
                                 [Node(ru_onnx, g)])
    return ops
Exemplo n.º 10
0
def rewrite_flatten(g, ops):
    """Replace Reshape(x, Pack(StridedSlice(...), *)) with ONNX Flatten(x).

    All consumers of the Reshape output are redirected to the Flatten output,
    the matched ops except the data producer ("input2") are removed from `ops`
    in place, and the Flatten node is appended.

    NOTE(review): the pattern does not check the Pack's second element or the
    StridedSlice bounds; it assumes the matched Reshape collapses all
    non-leading dims, as Flatten with its default axis does.

    Args:
        g: graph owning `ops`.
        ops: list of ops; mutated in place.

    Returns:
        `ops` (the same list object).
    """
    pattern = \
        OpTypePattern('Reshape', name='outputs', inputs=[
            OpTypePattern("*", name="input2"),
            OpTypePattern('Pack', inputs=[
                OpTypePattern('StridedSlice', inputs=[
                    "*", "*", "*", "*",
                ]),
                "*",
            ]),
        ])
    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        inputs2 = match.get_op('input2')
        outputs = match.get_op('outputs')
        op_name = utils.make_name("Flatten")
        out_name = op_name + ":0"
        # Feed Flatten the exact tensor the Reshape consumed; inputs2.output[0]
        # is wrong when input2 has several outputs and a later one is used.
        new_node = Node(helper.make_node("Flatten", [outputs.input[0]], [out_name],
                                         name=op_name), g)
        g.replace_all_inputs(ops, outputs.output[0], out_name)
        to_be_removed = [node for node in match.get_nodes() if node != inputs2]
        # Filter in place so callers holding this list see the change.
        ops[:] = [op for op in ops if op not in to_be_removed]
        ops.append(new_node)
    return ops
Exemplo n.º 11
0
def make_onnx_node(g, op_type, inputs, attr=None, output_count=1, skip_conversion=True):
    """Create a Node of `op_type` with a freshly generated unique name.

    Args:
        g: graph the node is created in.
        op_type: ONNX op type; also the basis for the generated node name.
        inputs: input tensor ids.
        attr: keyword attributes forwarded to helper.make_node; None (the
            default) means no attributes. A None default replaces the former
            shared mutable `{}` default.
        output_count: number of outputs, named "<node name>:<index>".
        skip_conversion: forwarded to Node.

    Returns:
        the created Node.
    """
    if attr is None:
        attr = {}
    node_name = utils.make_name(op_type)
    outputs = [node_name + ":" + str(i) for i in np.arange(output_count)]
    node = Node(
        helper.make_node(op_type, inputs, outputs, name=node_name, **attr),
        g, skip_conversion=skip_conversion)

    return node
Exemplo n.º 12
0
def create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output):
    """Build a RandomUniform or RandomUniformLike Node in [tmin, tmax).

    When ru_op's first input is produced by a Shape op, RandomUniformLike is
    emitted on that Shape's input (so the runtime shape is used). Otherwise a
    RandomUniform with the static shape of output.output[0] is emitted.

    Args:
        g: graph the node is created in.
        tmax: high attribute.
        tmin: low attribute.
        ru_op: the original random-uniform op; only its first input is inspected.
        output: op whose dtype (and, in the static case, shape) is used.

    Returns:
        the created Node (not yet inserted into any op list).
    """
    dtype = output.dtype
    op_name = utils.make_name("RandomUniform")
    out_name = port_name(op_name)
    shape_source = ru_op.inputs[0]
    shared_attrs = dict(name=op_name, low=tmin, high=tmax, dtype=dtype)
    if shape_source.type != "Shape":
        static_shape = g.get_shape(output.output[0])
        onnx_node = helper.make_node("RandomUniform", [], [out_name],
                                     shape=static_shape, **shared_attrs)
    else:
        onnx_node = helper.make_node("RandomUniformLike",
                                     [shape_source.input[0]], [out_name],
                                     **shared_attrs)
    return Node(onnx_node, g)
Exemplo n.º 13
0
def _make_node(g,
               op_type,
               inputs,
               name=None,
               attr=None,
               output_count=1,
               skip_conversion=True,
               op_name_scope=None,
               shapes=None,
               dtypes=None):
    """Create a Node in `g`, optionally recording per-output shapes and dtypes.

    Args:
        g: graph the node is created in.
        op_type: ONNX op type.
        inputs: input tensor ids.
        name: explicit node name; when None a unique name is generated from
            op_type, prefixed with "<op_name_scope>_" if op_name_scope is set.
        attr: keyword attributes forwarded to helper.make_node.
        output_count: number of outputs, named "<node name>:<index>".
        skip_conversion: forwarded to Node.
        shapes: optional per-output shapes; if non-empty it must contain
            exactly output_count entries.
        dtypes: optional per-output dtypes; same length rule as `shapes`.

    Returns:
        the created Node.

    Raises:
        whatever make_sure raises when `shapes`/`dtypes` have the wrong length.
    """
    attr = {} if attr is None else attr

    if name is not None:
        node_name = name
    else:
        basis = "_".join([op_name_scope, op_type]) if op_name_scope else op_type
        node_name = utils.make_name(basis)

    output_ids = [node_name + ":" + str(idx) for idx in range(output_count)]
    node = Node(helper.make_node(op_type, inputs, output_ids,
                                 name=node_name, **attr),
                g, skip_conversion=skip_conversion)

    # None and [] both mean "leave unset".
    if shapes:
        make_sure(len(shapes) == output_count,
                  "output shape count not equal to output count")
        for idx, shape in enumerate(shapes):
            g.set_shape(node.output[idx], shape)

    if dtypes:
        make_sure(len(dtypes) == output_count,
                  "output dtypes count not equal to output count")
        for idx, dtype in enumerate(dtypes):
            g.set_dtype(node.output[idx], dtype)

    return node
    def _create_transpose_pairs_before_node(self, node):
        """Insert a Transpose(0, 3, 1, 2) -> Transpose(0, 2, 3, 1) pair on inputs of `node`.

        Only inputs whose producer is not an NHWC transpose (per
        is_nhwc_transpose) are handled, and each (input id, producer) pair only
        once. `node`'s input is rewired via self._g.replace_input to the output
        of the second Transpose.

        Returns:
            list of the created Nodes (two per rewired input). When non-empty
            they are passed to self._update_graph_nodes(created, None, True).
        """
        targets = []
        for input_id, producer in zip(node.input, node.inputs):
            if is_nhwc_transpose(producer):
                continue
            # check in case node has two inputs coming from a same node output.
            entry = [input_id, producer]
            if entry not in targets:
                targets.append(entry)

        created = []
        for input_id, _ in targets:
            to_nchw_name = utils.make_name("Transpose")
            to_nchw_out = utils.port_name(to_nchw_name)
            to_nchw = helper.make_node("Transpose", [input_id], [to_nchw_out],
                                       name=to_nchw_name, perm=[0, 3, 1, 2])

            to_nhwc_name = utils.make_name("Transpose")
            to_nhwc_out = utils.port_name(to_nhwc_name)
            to_nhwc = helper.make_node("Transpose", [to_nchw_out], [to_nhwc_out],
                                       name=to_nhwc_name, perm=[0, 2, 3, 1])

            pair = [Node(to_nchw, self._g), Node(to_nhwc, self._g)]
            self._g.replace_input(node, input_id, to_nhwc_out)
            created.extend(pair)

        if created:
            self._update_graph_nodes(created, None, True)
        return created
    def post_optimize_action(self):
        """Turn Transposes that only move size-1 axes into Reshapes, then refresh the graph.

        For each Transpose in self.nodes with a known input shape:
          * NHWC->NCHW (is_nchw_transpose) where C == 1 or H == W == 1, or
          * NCHW->NHWC (is_nhwc_transpose) where C == 1 or H == W == 1,
        the permuted shape is computed. If every consumer of the Transpose is a
        Reshape, the Transpose is dropped via self._remove_useless_tranpose;
        otherwise it is replaced by Reshape(op.input[0], <const new shape>)
        writing to the Transpose's own outputs.

        Afterwards the proto is updated and the graph topologically sorted.
        """
        # if channel==1 or height==width==1, replace transpose with reshape
        for op in self.nodes:
            if op.type != "Transpose":
                continue
            in_shape = self._g.get_shape(op.input[0])
            if not in_shape:
                continue

            target_shape = []
            # when transpose is NHWC_TO_NCHW
            if is_nchw_transpose(op) and (
                    in_shape[3] == 1 or (in_shape[1] == 1 and in_shape[2] == 1)):
                target_shape = [in_shape[0], in_shape[3], in_shape[1], in_shape[2]]
            # when transpose is NCHW_TO_NHWC
            if is_nhwc_transpose(op) and (
                    in_shape[1] == 1 or (in_shape[2] == 1 and in_shape[3] == 1)):
                target_shape = [in_shape[0], in_shape[2], in_shape[3], in_shape[1]]
            if not target_shape:
                continue

            consumers = self._g.find_output_consumers(op.output[0])
            if all(consumer.type == "Reshape" for consumer in consumers):
                self._remove_useless_tranpose(op)
                continue

            reshape_name = utils.make_name("reshape")
            const_name = utils.make_name(reshape_name)
            self._g.make_const(const_name, np.array(target_shape, dtype=np.int64))
            reshape_onnx = helper.make_node("Reshape", [op.input[0], const_name],
                                            op.output, name=reshape_name)
            self._update_graph_nodes([Node(reshape_onnx, self._g)], [op], True)
        self._g.update_proto()
        self._g.topological_sort(self._g.get_nodes())
Exemplo n.º 16
0
    def _make_onnx_node(self,
                        operation_type,
                        input_names_with_output_id,
                        attribute=None,
                        output_num=1):
        """Create a Node in self._g with a generated name and optional raw attributes.

        Args:
            operation_type: ONNX op type; also the basis for the node name.
            input_names_with_output_id: input tensor ids.
            attribute: optional iterable of prebuilt attribute protos appended to
                the node's attribute list (skipped when falsy).
            output_num: number of outputs, named "<node name>:<index>".

        Returns:
            the created Node.
        """
        op_name = utils.make_name(operation_type)
        onnx_node = helper.make_node(
            operation_type,
            input_names_with_output_id,
            ["{}:{}".format(op_name, idx) for idx in range(output_num)],
            name=op_name)
        if attribute:
            onnx_node.attribute.extend(attribute)
        return Node(onnx_node, self._g)
    def rewrite(self):
        """Attach body graphs to Scan nodes and strip their body ops from the main graph.

        For every "Scan" node with pending info in BodyGraphDict, this:
          * pops that info and collects the body ops via
            LoopRewriterBase.find_subgraph;
          * moves Const/ConstV2 ops into the body graph as initializers (their
            values are looked up in self.g.initializers);
          * builds a new Graph from the remaining ops, declares its model inputs
            from body_graph_meta.input_ids, and wraps each body output in an
            Identity;
          * sets the resulting graph as the Scan's "body" attribute.
        Afterwards every collected body op is removed from the node list, or, if
        it is not in that list, its output is deleted from self.g.initializers.

        Returns:
            the list obtained from self.g.get_nodes() at entry, with the
            collected body ops removed from it in place.

        Raises:
            ValueError: if a collected op is neither in the node list nor an
                initializer.
        """
        log.debug("enter custom rnn late rewriter")
        nodes = self.g.get_nodes()
        nodes_to_remove = []
        for scan_node in nodes:
            if scan_node.type != "Scan":
                continue
            log.debug("late write for scan node %s", scan_node.name)
            num_scan_inputs = scan_node.get_attr("num_scan_inputs").i
            if not BodyGraphDict.has_body_graph_info(scan_node.name):
                continue

            body_graph_meta = BodyGraphDict.pop_body_graph_info(scan_node.name)
            onnx_nodes, _ = LoopRewriterBase.find_subgraph(
                body_graph_meta, self.g)
            # extend() copies the elements, so the Const ops removed from
            # onnx_nodes below are still scheduled for removal here.
            nodes_to_remove.extend(onnx_nodes)

            log.debug("start creating body graph for scan node %s ",
                      scan_node.name)
            body_graph_initializers = {}
            const_nodes = [
                n for n in onnx_nodes if n.type in ("Const", "ConstV2")
            ]
            for n in const_nodes:
                # Const ops are not kept as body-graph nodes; their values are
                # carried over as body-graph initializers instead.
                # NOTE(review): raises KeyError if the Const's output is not in
                # self.g.initializers.
                body_graph_initializers[n.output[0]] = self.g.initializers[
                    n.output[0]]
                onnx_nodes.remove(n)

            # De-duplicate; iteration order over the set is arbitrary, the body
            # graph is topologically sorted further down.
            onnx_nodes = set(onnx_nodes)

            ops = []
            for op in onnx_nodes:
                onnx_op = op.op
                ops.append(onnx_op)

            # Body graph shares the outer graph's shape/dtype maps (private
            # attributes of self.g).
            body_g = Graph(ops,
                           output_shapes=self.g._output_shapes,
                           dtypes=self.g._dtypes)
            body_g._initializers = body_graph_initializers

            log.debug("start preparing body graph inputs nodes")
            temp_nodes = body_g.get_nodes()
            i = 0
            input_count = len(body_graph_meta.input_ids)
            for input_name, init_input_id in zip(
                    body_graph_meta.input_ids,
                    body_graph_meta.initial_input_ids):
                shape = body_g.get_shape(input_name)
                dtype = body_g.get_dtype(input_name)
                if shape is None:
                    # Fall back to the shape of the outer initial input.
                    # NOTE(review): list(shape) fails if this is None as well.
                    shape = self.g.get_shape(init_input_id)
                    # The last num_scan_inputs inputs are scan inputs; their
                    # first two dims are dropped for the per-iteration shape.
                    if i >= input_count - num_scan_inputs:
                        loop_input_shape = list(shape)[2:]  # delete [1, time,]
                    else:
                        loop_input_shape = list(shape)
                else:
                    loop_input_shape = list(shape)

                onnx_input_shape = utils.make_onnx_shape(loop_input_shape)
                val = helper.make_tensor_value_info(input_name, dtype,
                                                    onnx_input_shape)
                body_g.add_model_input(input_name, val)
                i += 1

            log.debug("start preparing body graph outputs nodes")
            new_output_names = []
            for o in body_graph_meta.output_ids:
                # Insert an Identity node: the same output_id may be needed both
                # as state output and as scan output, but ONNX does not allow one
                # output_id to appear more than once among the graph outputs.
                identity_name = utils.make_name("Identity")
                identity_output = utils.port_name(identity_name)
                node = Node(
                    helper.make_node("Identity", [o], [identity_output],
                                     name=identity_name), body_g)
                body_g.set_dtype(identity_output, body_g.get_dtype(o))
                body_g.copy_shape(o, identity_output)
                new_output_names.append(identity_output)
                temp_nodes.append(node)

            body_g.set_nodes(temp_nodes)
            body_g.topological_sort(body_g.get_nodes())

            log.debug("start make graph based on body graph nodes")
            body_g.output_names = new_output_names
            graph = body_g.make_graph("scan body graph")
            scan_node.set_attr("body", graph)

        # remove nodes in body graph from g
        for n in set(nodes_to_remove):
            if n in nodes:
                nodes.remove(n)
            elif self.g.is_initializer(n.output[0]):
                # Op not in the node list (e.g. a Const materialized as an
                # initializer): drop its initializer from the outer graph.
                del self.g.initializers[n.output[0]]
            else:
                raise ValueError("error when removing nodes")

        return nodes
Exemplo n.º 18
0
def select_op8(ctx, node, name, args):
    """Convert a Select node for opset 8 into an ONNX Loop (rank > 0) or If (rank 0).

    # T output = Select(bool condition, T x, T y)
    # V v_final_and_scan_outputs = Loop(int64 M, B cond, V v_initial)

    For a condition of rank r > 0, the condition's shape is split into r int64
    trip counts (via a float Cast/Split/Cast workaround for onnxruntime) that
    are passed to create_loop_op. For a scalar condition, create_if_op builds a
    single If node.

    Args:
        ctx: graph being converted.
        node: the Select node; node.input is (condition, x, y).
        name: unused; part of the handler signature.
        args: unused; part of the handler signature.

    Returns:
        list of new nodes, ending with an Identity named node.name that
        produces the selected values.

    Raises:
        whatever utils.make_sure raises when only the condition is given, x or y
        is missing, x and y differ in dtype or shape, or the condition shape is
        unknown.
    """
    utils.make_sure(
        len(node.input) > 1, "Select with only condition is not supported.")
    # Both x (input 1) and y (input 2) are read below; check up front instead
    # of failing later with a bare IndexError on a 2-input node.
    utils.make_sure(len(node.input) > 2, "Select requires both x and y inputs.")

    nodes = []
    true_data_type = ctx.get_dtype(node.input[1])
    false_data_type = ctx.get_dtype(node.input[2])
    true_data_shape = ctx.get_shape(node.input[1])
    false_data_shape = ctx.get_shape(node.input[2])
    utils.make_sure(true_data_type == false_data_type,
                    "select true val and false val have different data types.")
    utils.make_sure(np.array_equal(true_data_shape, false_data_shape),
                    "select true val and false val have different output shapes.")

    condition_shape = ctx.get_shape(node.input[0])
    utils.make_sure(condition_shape is not None, "condition shape is None")
    rank = len(condition_shape)

    if rank > 0:
        # create nodes getting shape of condition
        shape_node_output_shape = [rank]
        shape_node = ctx.make_node("Shape", [node.input[0]],
                                   op_name_scope=node.name,
                                   shapes=[shape_node_output_shape],
                                   dtypes=[onnx_pb.TensorProto.INT64])
        nodes.append(shape_node)

        # todo(pengwa), move those leveraging rewrite_incomplete_type_support_onnxruntime after shape inferencing
        # bug is fixed.
        # workaround: onnxruntime does not support Split-2, add casts before and after.
        float_dtype = onnx_pb.TensorProto.FLOAT
        shape_f_node = ctx.make_node("Cast", [shape_node.output[0]],
                                     attr={"to": float_dtype},
                                     shapes=[shape_node_output_shape],
                                     dtypes=[float_dtype],
                                     op_name_scope=node.name)
        nodes.append(shape_f_node)

        # One scalar-sized piece per dimension of the condition.
        split_node = ctx.make_node("Split", [shape_f_node.output[0]],
                                   output_count=rank,
                                   attr={"split": [1] * rank},
                                   shapes=[[1] for _ in range(rank)],
                                   dtypes=[float_dtype] * rank,
                                   op_name_scope=node.name)
        nodes.append(split_node)

        int64_dtype = onnx_pb.TensorProto.INT64
        trip_cnts = []
        for output_id in split_node.output:
            shape_i_node = ctx.make_node("Cast", [output_id],
                                         attr={"to": int64_dtype},
                                         shapes=[ctx.get_shape(output_id)],
                                         dtypes=[int64_dtype],
                                         op_name_scope=node.name)
            trip_cnts.append(shape_i_node.output[0])
            nodes.append(shape_i_node)
        # workaround ends

        onnx_nodes = create_loop_op(node.input, true_data_type,
                                    true_data_shape, trip_cnts, rank)
        new_nodes = [Node(n, ctx, skip_conversion=True) for n in onnx_nodes]
        nodes.extend(new_nodes)
        # The last created node is the Loop; its output 1 carries the result.
        loop_node = new_nodes[-1]
        val_output_id = loop_node.output[1]
    else:
        # Scalar condition: a single If chooses between x and y.
        if_onnx_node, val_output_id = create_if_op(node.input, true_data_type,
                                                   true_data_shape)
        nodes.append(Node(if_onnx_node, ctx, skip_conversion=True))

    ctx.copy_shape(node.output[0], val_output_id)
    ctx.set_dtype(node.output[0], true_data_type)

    output_node = ctx.make_node("Identity", [val_output_id],
                                name=node.name,
                                shapes=[ctx.get_shape(val_output_id)],
                                dtypes=[true_data_type])
    nodes.append(output_node)

    return nodes