Exemple #1
0
def onnx2onnx_flow(m: onnx.ModelProto,
                   disable_fuse_bn=False,
                   bn_on_skip=False,
                   bn_before_add=False,
                   bgr=False,
                   norm=False,
                   rgba2yynn=False,
                   eliminate_tail=False) -> onnx.ModelProto:
    """Run the general ONNX-to-ONNX optimization pipeline.

    The stages run in a fixed order: preprocessing, optional BN insertion,
    common optimization, optional input-side special passes, optional tail
    removal, and finally postprocessing.

    Args:
        m (ModelProto): the ONNX model to optimize.
        disable_fuse_bn (bool, optional): forwarded to ``combo.preprocess``;
            per the flag name, keeps BN from being fused into Conv.
            Defaults to False.
        bn_on_skip (bool, optional): run ``other.add_bn_on_skip_branch`` on
            the graph. Takes precedence over ``bn_before_add``.
            Defaults to False.
        bn_before_add (bool, optional): run ``other.add_bn_before_add`` and
            ``other.add_bn_before_activation`` on the graph. Ignored when
            ``bn_on_skip`` is set. Defaults to False.
        bgr (bool, optional): run ``special.change_input_from_bgr_to_rgb``
            on the model. Defaults to False.
        norm (bool, optional): run ``special.add_0_5_to_normalized_input``
            on the model. Defaults to False.
        rgba2yynn (bool, optional): run ``special.add_rgb2yynn_node`` on the
            model. Defaults to False.
        eliminate_tail (bool, optional): run
            ``eliminating.remove_useless_last_nodes`` on the graph.
            Defaults to False.

    Returns:
        ModelProto: the model returned by ``combo.postprocess``.
    """
    model = combo.preprocess(m, disable_fuse_bn)

    # BN insertion: the two modes are mutually exclusive, skip-branch mode
    # wins when both flags are given.
    if bn_on_skip:
        other.add_bn_on_skip_branch(model.graph)
    elif bn_before_add:
        other.add_bn_before_add(model.graph)
        other.add_bn_before_activation(model.graph)

    model = combo.common_optimization(model)

    # Input-side special passes; each one is independent of the others and
    # they are applied in this fixed order.
    if bgr:
        special.change_input_from_bgr_to_rgb(model)
    if norm:
        special.add_0_5_to_normalized_input(model)
    if rgba2yynn:
        special.add_rgb2yynn_node(model)

    # Optional removal of trailing nodes before the final cleanup.
    if eliminate_tail:
        eliminating.remove_useless_last_nodes(model.graph)

    return combo.postprocess(model)
def torch_exported_onnx_flow(m: onnx.ModelProto, disable_fuse_bn=False) -> onnx.ModelProto:
    """Run the optimization pipeline for a PyTorch-exported ONNX model.

    The model passes through four ``combo`` stages in order: ``preprocess``,
    ``pytorch_constant_folding``, ``common_optimization`` and ``postprocess``.

    Args:
        m (ModelProto): the ONNX model to optimize.
        disable_fuse_bn (bool, optional): forwarded to ``combo.preprocess``;
            per the flag name, keeps BN from being fused into Conv.
            Defaults to False.

    Returns:
        ModelProto: the model returned by ``combo.postprocess``.
    """
    preprocessed = combo.preprocess(m, disable_fuse_bn)
    folded = combo.pytorch_constant_folding(preprocessed)
    optimized = combo.common_optimization(folded)
    return combo.postprocess(optimized)
Exemple #3
0
# Script fragment: load an ONNX model from args.in_file, run the optimization
# passes selected by the options on `args`, and save the result to `outfile`.
# NOTE(review): `args` and `outfile` are defined outside this fragment
# (presumably by argument parsing earlier in the script) - confirm.
m = onnx.load(args.in_file)
# Disabled pass, kept for reference:
# temp.weight_broadcast(m.graph)
m = combo.preprocess(m, args.disable_fuse_bn)
# Disabled pass, kept for reference:
# temp.fuse_bias_in_consecutive_1x1_conv(m.graph)

# Optional BN insertion. The two modes are mutually exclusive: when
# args.bn_on_skip is set, args.bn_before_add is ignored.
if args.bn_on_skip:
    other.add_bn_on_skip_branch(m.graph)
elif args.bn_before_add:
    other.add_bn_before_add(m.graph)
    other.add_bn_before_activation(m.graph)
# Optionally split ConvTranspose (deconvolution) nodes; note this pass takes
# the whole model rather than m.graph.
if args.split_convtranspose:
    other.split_ConvTranspose(m)

# Common optimization passes; the result replaces `m`.
m = combo.common_optimization(m)
# Optional input-side special passes, applied after the common optimization.
if args.bgr:
    special.change_input_from_bgr_to_rgb(m)
if args.norm:
    special.add_0_5_to_normalized_input(m)

# Optionally remove trailing nodes that are considered useless.
if args.eliminate_tail:
    eliminating.remove_useless_last_nodes(m.graph)

# Final cleanup, then write the optimized model to disk.
m = combo.postprocess(m)
onnx.save(m, outfile)
def _collect_graph_io_names(onnx_nodes):
    """Infer graph-level input and output tensor names from converted nodes.

    Nodes without any output are ignored. A tensor is a graph input when it
    is consumed by a non-Placeholder node but produced by none of them. A
    node's first output is a graph output when the node has at least one
    input, its first input is produced inside the graph, and that first
    output is not consumed by any non-Placeholder node.

    Args:
        onnx_nodes: iterable of node objects exposing ``op_type``, ``input``
            and ``output`` (as produced by tf2onnx's ``tflist_to_onnx``).

    Returns:
        tuple: ``(input_names, output_names)`` as sorted lists of str, so
        that the resulting model signature is the same on every run.
    """
    # Build a filtered list instead of calling remove() while iterating:
    # removing during iteration skips the element that follows each removal.
    nodes = [node for node in onnx_nodes if len(node.output) > 0]

    consumed = set()
    produced = set()
    for node in nodes:
        if node.op_type == 'Placeholder':
            continue
        consumed.update(node.input)
        produced.update(node.output)

    input_names = consumed - produced
    output_names = {
        node.output[0]
        for node in nodes
        if len(node.input) > 0
        and node.input[0] in produced
        and node.output[0] not in consumed
    }
    return sorted(input_names), sorted(output_names)


def tf2onnx_flow(pb_path: str, test_mode=False) -> onnx.ModelProto:
    """Convert a frozen TensorFlow graph (.pb file) into an optimized ONNX model.

    The graph inputs and outputs are inferred from the graph itself, the
    graph is converted with tf2onnx (opset 11), and the result is run through
    the ``combo`` optimization stages.

    Args:
        pb_path (str): input pb file path; must end with ``.pb``.
        test_mode (bool, optional): when True, skip
            ``eliminating.eliminate_shape_changing_after_input``.
            Defaults to False.

    Raises:
        Exception: ``pb_path`` does not end with ``.pb``.

    Returns:
        onnx.ModelProto: converted onnx
    """
    import os

    if not pb_path.endswith('.pb'):
        raise Exception('expect .pb file as input, but got "' + str(pb_path) + '"')

    # Compare (major, minor) numerically rather than concatenating digits, so
    # versions such as "1.6" or a suffixed patch component are ranked correctly.
    major_minor = tuple(int(part) for part in tf2onnx.version.version.split('.')[:2])
    new_tf2onnx_api = major_minor >= (1, 6)

    # The loader module was renamed in tf2onnx 1.6.
    if new_tf2onnx_api:
        from tf2onnx import tf_loader
    else:
        from tf2onnx import loader as tf_loader

    # File name without directory and without the ".pb" suffix.
    model_name = os.path.basename(pb_path)[:-3]

    # always reset tensorflow session at begin
    tf.reset_default_graph()

    with tf.Session() as sess:
        # Parse while the file is still open; the original read from `f`
        # after its `with` block had already exited.
        with gfile.FastGFile(pb_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        # Actually enter the session graph's default-graph context; the bare
        # `sess.graph.as_default()` call in the original had no effect.
        with sess.graph.as_default():
            tf.import_graph_def(graph_def, name='')

            # Only the node list is needed here; the remaining return values
            # are discarded. The return arity differs between API versions.
            if new_tf2onnx_api:
                onnx_nodes, _, _, _, _, _ = tf2onnx.tf_utils.tflist_to_onnx(
                    sess.graph,
                    {})
            else:
                onnx_nodes, _, _, _, _ = tf2onnx.tfonnx.tflist_to_onnx(
                    sess.graph.get_operations(),
                    {})

        graph_input_names, graph_output_names = _collect_graph_io_names(onnx_nodes)

    logging.info('Model Inputs: %s', str(graph_input_names))
    logging.info('Model Outputs: %s', str(graph_output_names))

    graph_def, inputs, outputs = tf_loader.from_graphdef(model_path=pb_path,
                                                         input_names=graph_input_names,
                                                         output_names=graph_output_names)

    with tf.Graph().as_default() as tf_graph:
        tf.import_graph_def(graph_def, name='')

    # Only the session factory differs between tf2onnx versions; the
    # conversion call itself is the same.
    if new_tf2onnx_api:
        session_scope = tf_loader.tf_session(graph=tf_graph)
    else:
        session_scope = tf.Session(graph=tf_graph)
    with session_scope:
        onnx_graph = tf2onnx.tfonnx.process_tf_graph(tf_graph=tf_graph,
                                                     input_names=inputs,
                                                     output_names=outputs,
                                                     opset=11)

    # Optimize with tf2onnx.optimizer
    onnx_graph = tf2onnx.optimizer.optimize_graph(onnx_graph)
    model_proto = onnx_graph.make_model(model_name)

    # Make tf2onnx output compatible with the spec. of onnx.utils.polish_model
    replacing.replace_initializer_with_Constant(model_proto.graph)
    model_proto = onnx.utils.polish_model(model_proto)

    m = combo.preprocess(model_proto)
    m = combo.common_optimization(m)
    m = combo.tensorflow_optimization(m)
    m = combo.postprocess(m)

    if not test_mode:
        eliminating.eliminate_shape_changing_after_input(m.graph)

    return onnx.utils.polish_model(m)