Example #1
def make_onnx_node(g,
                   op_type,
                   inputs,
                   attr=None,
                   output_count=1,
                   skip_conversion=True):
    # Wrap helper.make_node in a tf2onnx Node: generate a unique node name and
    # derive output ids "<name>:0", "<name>:1", ... for output_count outputs.
    if attr is None:
        attr = {}
    node_name = utils.make_name(op_type)
    outputs = [node_name + ":" + str(i) for i in range(output_count)]
    node = Node(helper.make_node(op_type,
                                 inputs,
                                 outputs,
                                 name=node_name,
                                 **attr),
                g,
                skip_conversion=skip_conversion)

    return node
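
A minimal usage sketch, assuming g is an existing tf2onnx Graph that already produces a tensor "x:0"; the Squeeze op and axes value are illustrative only:

# hypothetical call: wrap a single-output Squeeze around the existing tensor "x:0"
squeeze_node = make_onnx_node(g, "Squeeze", ["x:0"], attr={"axes": [0]})
# the wrapper generated one output id such as "Squeeze__1:0" (exact suffix comes from utils.make_name)
print(squeeze_node.output)
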
Example #2
def rewrite_dropout(g, ops):
    # TF dropout at training time computes x / keep_prob * floor(keep_prob + uniform).
    # Match that subgraph and collapse it into a single ONNX Dropout node.
    pattern = \
        OpTypePattern('Mul', name='outputs', inputs=[
            OpTypePattern('RealDiv', name="input2"),
            OpTypePattern('Floor', inputs=[
                OpTypePattern('Add', inputs=[
                    OpTypePattern(None, name="input3"),
                    OpTypePattern('RandomUniform'),
                ])
            ]),
        ])
    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        inputs2 = match.get_op('input2')
        outputs = match.get_op('outputs')
        op_name = utils.make_name("Dropout")
        out_name = op_name + ":0"
        # inputs2.input[0] is the original (pre-division) input tensor x
        new_node = Node(helper.make_node("Dropout", [inputs2.input[0]], [out_name], name=op_name, ratio=1.0), g)
        ops = g.replace_subgraph(ops, match, [inputs2], [outputs], [new_node], [new_node])

    return ops
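
A rewriter like this is typically run over the full node list and the result written back; a sketch using the get_nodes/set_nodes calls seen in the other examples:

ops = g.get_nodes()
ops = rewrite_dropout(g, ops)   # collapse each matched dropout subgraph into one Dropout node
g.set_nodes(ops)                # store the rewritten node list back on the graph
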
Example #3
def rewrite_random_normal(g, ops):
    # TF lowers random_normal to standard_normal * stddev + mean; fold the matched
    # Add(Mul(RandomStandardNormal, stddev), mean) subgraph into one ONNX RandomNormal.
    pattern = \
        OpTypePattern('Add', name='output', inputs=[
            OpTypePattern('Mul', name='input2', inputs=[
                OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"]), "*"
            ]), "*"
        ])

    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        output = match.get_op('output')
        # the Add's second input holds the mean constant
        mean = output.inputs[1].get_tensor_value()[0]
        shape = g.get_shape(output.output[0])
        dtype = output.dtype
        op_name = utils.make_name("RandomNormal")
        out_name = op_name + ":0"
        new_node = Node(helper.make_node("RandomNormal", [], [out_name],
                                         name=op_name, shape=shape, mean=mean, scale=1.0,
                                         dtype=dtype), g)
        ops = g.replace_subgraph(ops, match, [], [output], [], [new_node])

    return ops
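
For reference, the TF construct this pattern targets computes standard_normal(shape) * stddev + mean; a plain NumPy sketch of that semantics (not tf2onnx code):

import numpy as np

def random_normal_reference(shape, mean, stddev, rng=None):
    # what the matched Add(Mul(RandomStandardNormal, stddev), mean) subgraph computes
    rng = rng or np.random.default_rng()
    return rng.standard_normal(shape) * stddev + mean
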
Example #4
def test_rewrite_subgraph(self):
    model_proto = self.sample_net()
    nodes = model_proto.node
    g = tf2onnx.graph.Graph(nodes, output_shapes={}, dtypes={})
    pattern = \
        OpTypePattern('Abs', name='output', inputs=[
            OpTypePattern('Add', name='input')
        ])
    ops = g.get_nodes()
    matcher = GraphMatcher(pattern)
    match_results = list(matcher.match_ops(ops))
    for match in match_results:
        input_node = match.get_op('input')
        output_node = match.get_op('output')
        op_name = tf2onnx.utils.make_name("ReplacedOp")
        out_name = tf2onnx.utils.port_name(op_name)
        # replace the matched Abs(Add(x, y)) subgraph with a single Sub(x, y) node
        new_node = Node(helper.make_node("Sub", input_node.input, [out_name], name=op_name), g)
        ops = g.replace_subgraph(ops, match, [], [output_node], [], [new_node])
    g.topological_sort(ops)
    result = onnx_to_graphviz(g)
    expected = 'digraph { n1 [op_type=Abs] n3 [op_type=Abs] n2 [op_type=Abs] ReplacedOp__2 [op_type=Sub] ' \
               'n6 [op_type=Identity] input -> n1 n1:0 -> n3 n1:0 -> n2 n2:0 -> ReplacedOp__2 ' \
               'n3:0 -> ReplacedOp__2 ReplacedOp__2:0 -> n6 }'
    self.assertEqual(expected, result)
Example #5
    def rewrite(self):
        log.debug("enter custom rnn late rewriter")
        nodes = self.g.get_nodes()
        nodes_to_remove = []
        for scan_node in nodes:
            if scan_node.type != "Scan":
                continue
            log.debug("late write for scan node %s", scan_node.name)
            num_scan_inputs = scan_node.get_attr("num_scan_inputs").i
            if not BodyGraphDict.has_body_graph_info(scan_node.name):
                continue

            body_graph_meta = BodyGraphDict.pop_body_graph_info(scan_node.name)
            onnx_nodes, _ = LoopRewriterBase.find_subgraph(
                body_graph_meta, self.g)
            nodes_to_remove.extend(onnx_nodes)

            log.debug("start creating body graph for scan node %s ",
                      scan_node.name)
            body_graph_initializers = {}
            const_nodes = [
                n for n in onnx_nodes if n.type in ("Const", "ConstV2")
            ]
            for n in const_nodes:
                # Const nodes are removed when the node list is set; they are carried over as initializers instead.
                body_graph_initializers[n.output[0]] = self.g.initializers[
                    n.output[0]]
                onnx_nodes.remove(n)

            onnx_nodes = set(onnx_nodes)

            ops = []
            for op in onnx_nodes:
                onnx_op = op.op
                ops.append(onnx_op)

            body_g = Graph(ops,
                           output_shapes=self.g._output_shapes,
                           dtypes=self.g._dtypes)
            body_g._initializers = body_graph_initializers

            log.debug("start preparing body graph inputs nodes")
            temp_nodes = body_g.get_nodes()
            i = 0
            input_count = len(body_graph_meta.input_ids)
            for input_name, init_input_id in zip(
                    body_graph_meta.input_ids,
                    body_graph_meta.initial_input_ids):
                shape = body_g.get_shape(input_name)
                dtype = body_g.get_dtype(input_name)
                if shape is None:
                    shape = self.g.get_shape(init_input_id)
                    if i >= input_count - num_scan_inputs:
                        loop_input_shape = list(shape)[2:]  # drop the leading [1, time] dims for scan inputs
                    else:
                        loop_input_shape = list(shape)
                else:
                    loop_input_shape = list(shape)

                onnx_input_shape = utils.make_onnx_shape(loop_input_shape)
                val = helper.make_tensor_value_info(input_name, dtype,
                                                    onnx_input_shape)
                body_g.add_model_input(input_name, val)
                i += 1

            log.debug("start preparing body graph outputs nodes")
            new_output_names = []
            for o in body_graph_meta.output_ids:
                # Insert an Identity node: sometimes the same output_id is needed both as a
                # state output and as a scan output, but ONNX does not allow the same
                # output_id to appear more than once in the graph outputs.
                identity_name = utils.make_name("Identity")
                identity_output = utils.port_name(identity_name)
                node = Node(
                    helper.make_node("Identity", [o], [identity_output],
                                     name=identity_name), body_g)
                body_g.set_dtype(identity_output, body_g.get_dtype(o))
                body_g.copy_shape(o, identity_output)
                new_output_names.append(identity_output)
                temp_nodes.append(node)

            body_g.set_nodes(temp_nodes)
            body_g.topological_sort(body_g.get_nodes())

            log.debug("start make graph based on body graph nodes")
            body_g.output_names = new_output_names
            graph = body_g.make_graph("scan body graph")
            scan_node.set_attr("body", graph)

        # remove nodes in body graph from g
        for n in set(nodes_to_remove):
            if n in nodes:
                nodes.remove(n)
            elif self.g.is_initializer(n.output[0]):
                del self.g.initializers[n.output[0]]
            else:
                raise ValueError("error when removing nodes")

        return nodes
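
Assuming this method lives on a rewriter class that stores the graph as self.g (the class name below is hypothetical), a driving call might look like:

rewriter = CustomRnnLateRewriter(g)   # hypothetical class name; rewrite() reads self.g
g.set_nodes(rewriter.rewrite())       # rewrite() returns the surviving top-level nodes
g.topological_sort(g.get_nodes())
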
Example #6
def select_op8(ctx, node, name, args):
    # T output = Select(bool condition, T x, T y)
    # V v_final_and_scan_outputs = Loop(int64 M, B cond, V v_initial)
    utils.make_sure(
        len(node.input) > 1, "Select with only condition is not supported.")

    nodes = []
    true_data_type = ctx.get_dtype(node.input[1])
    false_data_type = ctx.get_dtype(node.input[2])
    true_data_shape = ctx.get_shape(node.input[1])
    false_data_shape = ctx.get_shape(node.input[2])
    utils.make_sure(true_data_type == false_data_type,
                    "select true val and false val have different data types.")
    utils.make_sure(np.array_equal(true_data_shape, false_data_shape),
                    "select true val and false val have different output shapes.")

    condition_shape = ctx.get_shape(node.input[0])
    utils.make_sure(condition_shape is not None, "condition shape is None")
    rank = len(condition_shape)

    utils.make_sure(rank >= 0, "rank should be >= 0")
    val_output_id = None
    if rank > 0:
        # create nodes getting shape of condition
        shape_node_output_shape = [rank]
        shape_node = ctx.make_node("Shape", [node.input[0]],
                                   op_name_scope=node.name,
                                   shapes=[shape_node_output_shape],
                                   dtypes=[onnx_pb.TensorProto.INT64])
        nodes.append(shape_node)

        # TODO(pengwa): move this to leverage rewrite_incomplete_type_support_onnxruntime
        # once the shape inferencing bug is fixed.
        # workaround: onnxruntime does not support Split-2, so add Casts before and after the Split.
        target_dtype = onnx_pb.TensorProto.FLOAT
        shape_f_node = ctx.make_node("Cast", [shape_node.output[0]],
                                     attr={"to": target_dtype},
                                     shapes=[shape_node_output_shape],
                                     dtypes=[target_dtype],
                                     op_name_scope=node.name)
        nodes.append(shape_f_node)

        split_attr = [1 for i in range(rank)]
        output_shapes = [[1] for i in range(rank)]
        output_dtypes = [target_dtype for i in range(rank)]
        split_node = ctx.make_node("Split", [shape_f_node.output[0]],
                                   output_count=rank,
                                   attr={"split": split_attr},
                                   shapes=output_shapes,
                                   dtypes=output_dtypes,
                                   op_name_scope=node.name)
        nodes.append(split_node)

        trip_cnts = []
        for i in range(rank):
            output_id = split_node.output[i]
            output_shape = ctx.get_shape(output_id)
            target_dtype = onnx_pb.TensorProto.INT64
            shape_i_node = ctx.make_node("Cast", [output_id],
                                         attr={"to": target_dtype},
                                         shapes=[output_shape],
                                         dtypes=[target_dtype],
                                         op_name_scope=node.name)
            trip_cnts.append(shape_i_node.output[0])
            nodes.append(shape_i_node)
        # workaround ends

        onnx_nodes = create_loop_op(node.input, true_data_type,
                                    true_data_shape, trip_cnts, rank)
        new_nodes = [Node(n, ctx, skip_conversion=True) for n in onnx_nodes]
        nodes.extend(new_nodes)
        loop_node = new_nodes[-1]
        val_output_id = loop_node.output[1]
    elif rank == 0:
        if_onnx_node, val_output_id = create_if_op(node.input, true_data_type,
                                                   true_data_shape)
        if_node = Node(if_onnx_node, ctx, skip_conversion=True)
        nodes.append(if_node)

    ctx.copy_shape(node.output[0], val_output_id)
    ctx.set_dtype(node.output[0], true_data_type)

    output_node = ctx.make_node("Identity", [val_output_id],
                                name=node.name,
                                shapes=[ctx.get_shape(val_output_id)],
                                dtypes=[true_data_type])
    nodes.append(output_node)

    return nodes
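
For reference, the Select semantics this handler reproduces with Loop/If can be sketched in plain NumPy (not tf2onnx code):

import numpy as np

def select_reference(condition, x, y):
    # rank-0 condition: pick the whole x or y tensor (the If branch above)
    if np.ndim(condition) == 0:
        return x if condition else y
    # element-wise case for a condition shaped like x/y (handled by the Loop branch)
    return np.where(condition, x, y)
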