def test_cmdarg_parse(self):
    """split_nodename_and_shape separates node names from optional [dims] shape overrides."""
    cmdline = "input/V-1_2:0,input/X:0[1,2,3],Y:1[4,5],Z:3,A:1,B"
    names, overrides = utils.split_nodename_and_shape(cmdline)
    # Every name is kept, in order, with any bracketed shape stripped off.
    self.assertEqual(['input/V-1_2:0', 'input/X:0', 'Y:1', 'Z:3', 'A:1', 'B'], names)
    # Only names that carried a [..] suffix appear in the override mapping.
    self.assertEqual({'Y:1': [4, 5], 'input/X:0': [1, 2, 3]}, overrides)
def get_args():
    """Parse commandline.

    Returns:
        argparse.Namespace. After parsing, list-valued options (``outputs``,
        ``inputs_as_nchw``, ``target``) are split on commas. ``inputs`` is split
        into names and ``shape_override`` is set from it (``None`` when
        ``--inputs`` is absent). ``signature_def`` is wrapped in a list, and
        ``extra_opset`` is converted with ``utils.make_opsetid``.

    Exits via ``parser.error`` on invalid arguments, or with status 1 (after
    printing help) when no model source is given.
    """
    parser = argparse.ArgumentParser(description="Convert tensorflow graphs to ONNX.",
                                     formatter_class=argparse.RawDescriptionHelpFormatter, epilog=_HELP_TEXT)
    parser.add_argument("--input", help="input from graphdef")
    parser.add_argument("--graphdef", help="input from graphdef")
    parser.add_argument("--saved-model", help="input from saved model")
    parser.add_argument("--tag", help="tag to use for saved_model")
    parser.add_argument("--signature_def", help="signature_def from saved_model to use")
    parser.add_argument("--concrete_function", type=int, default=None,
                        help="For TF2.x saved_model, index of func signature in __call__ (--signature_def is ignored)")
    parser.add_argument("--checkpoint", help="input from checkpoint")
    parser.add_argument("--keras", help="input from keras model")
    parser.add_argument("--large_model", help="use the large model format (for models > 2GB)", action="store_true")
    parser.add_argument("--output", help="output model file")
    parser.add_argument("--inputs", help="model input_names")
    parser.add_argument("--outputs", help="model output_names")
    parser.add_argument("--opset", type=int, default=None, help="opset version to use for onnx domain")
    parser.add_argument("--custom-ops", help="comma-separated map of custom ops to domains in format OpName:domain")
    parser.add_argument("--extra_opset", default=None,
                        help="extra opset with format like domain:version, e.g. com.microsoft:1")
    parser.add_argument("--target", default=",".join(constants.DEFAULT_TARGET), choices=constants.POSSIBLE_TARGETS,
                        help="target platform")
    parser.add_argument("--continue_on_error", help="continue_on_error", action="store_true")
    parser.add_argument("--verbose", "-v", help="verbose output, option is additive", action="count")
    parser.add_argument("--debug", help="debug mode", action="store_true")
    parser.add_argument("--output_frozen_graph", help="output frozen tf graph to file")
    parser.add_argument("--fold_const", help="Deprecated. Constant folding is always enabled.",
                        action="store_true")
    # experimental
    parser.add_argument("--inputs-as-nchw", help="transpose inputs as from nhwc to nchw")
    args = parser.parse_args()

    args.shape_override = None
    if args.input:
        # for backward compatibility
        args.graphdef = args.input
    if args.graphdef or args.checkpoint:
        # Both --inputs and --outputs are required for these formats; report
        # the error if either one is missing.
        if not args.inputs or not args.outputs:
            parser.error("graphdef and checkpoint models need to provide inputs and outputs")
    if not any([args.graphdef, args.checkpoint, args.saved_model, args.keras]):
        parser.print_help()
        sys.exit(1)
    if args.inputs:
        args.inputs, args.shape_override = utils.split_nodename_and_shape(args.inputs)
    # Options that accept a comma-separated list are converted to Python lists.
    for name in ("outputs", "inputs_as_nchw", "target"):
        value = getattr(args, name)
        if value:
            setattr(args, name, value.split(","))
    if args.signature_def:
        args.signature_def = [args.signature_def]
    if args.extra_opset:
        tokens = args.extra_opset.split(':')
        if len(tokens) != 2:
            parser.error("invalid extra_opset argument")
        domain, version = tokens
        try:
            version = int(version)
        except ValueError:
            # Report a malformed version as a usage error instead of a traceback.
            parser.error("invalid extra_opset argument")
        args.extra_opset = [utils.make_opsetid(domain, version)]
    return args
def convert_onnx(sess, graph_def, input_path, inputs_op, outputs_op):
    """Convert a TensorFlow graphdef file to ONNX and save it next to the input.

    Args:
        sess: passed through to ``from_graphdef``.
        graph_def: passed through to ``from_graphdef``; replaced by the graph_def
            it returns.
        input_path: path of the graphdef; also used to derive the output path.
        inputs_op: comma-separated input names, optionally with bracketed shape
            overrides (e.g. ``"input/X:0[1,2,3],Y:1"``), or a falsy value.
        outputs_op: comma-separated output names, or a falsy value.

    The output is written to ``input_path`` with a trailing ``.pb`` replaced by
    ``.onnx``, or with ``.onnx`` appended when there is no ``.pb`` suffix.
    """
    # Stays None when no inputs are given; previously it was unbound in that
    # case and the parsed value was never forwarded to the converter.
    shape_override = None
    if inputs_op:
        inputs_op, shape_override = utils.split_nodename_and_shape(inputs_op)
    if outputs_op:
        outputs_op = outputs_op.split(",")

    logging.basicConfig(level=logging.get_verbosity_level(True))
    utils.set_debug_mode(True)
    logger = logging.getLogger(constants.TF2ONNX_PACKAGE_NAME)

    model_path = input_path
    graph_def, inputs_op, outputs_op = from_graphdef(sess, graph_def, model_path, inputs_op, outputs_op)
    graph_def = tf_optimize(inputs_op, outputs_op, graph_def, True)

    with tf.Graph().as_default() as tf_graph:
        tf.import_graph_def(graph_def, name='')
    with tf.Session(graph=tf_graph):
        g = process_tf_graph(tf_graph,
                             continue_on_error=False,
                             target=",".join(constants.DEFAULT_TARGET),
                             opset=10,
                             custom_op_handlers=None,
                             extra_opset=None,
                             shape_override=shape_override,
                             input_names=inputs_op,
                             output_names=outputs_op,
                             inputs_as_nchw=None)

    onnx_graph = optimizer.optimize_graph(g)
    model_proto = onnx_graph.make_model("converted from {}".format(model_path))

    # write onnx graph
    logger.info("")
    logger.info("Successfully converted TensorFlow model %s to ONNX", model_path)
    # Only the trailing extension is swapped. str.replace would also rewrite
    # ".pb" elsewhere in the path, and would return input_path unchanged when
    # there is no ".pb", overwriting the source model.
    suffix = ".pb"
    if input_path.endswith(suffix):
        output_path = input_path[:-len(suffix)] + ".onnx"
    else:
        output_path = input_path + ".onnx"
    utils.save_protobuf(output_path, model_proto)
    logger.info("ONNX model is saved at %s", output_path)
def get_args():
    """Parse commandline.

    Returns:
        argparse.Namespace. ``outputs``, ``inputs_as_nchw`` and ``target`` are
        split on commas. ``inputs`` is split into names and ``shape_override``
        is set from it (``None`` when ``--inputs`` is absent).

    Raises:
        ValueError: a graphdef/checkpoint model is missing --inputs or
            --outputs, or no model source was given at all.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", help="input from graphdef")
    parser.add_argument("--graphdef", help="input from graphdef")
    parser.add_argument("--saved-model", help="input from saved model")
    parser.add_argument("--checkpoint", help="input from checkpoint")
    parser.add_argument("--output", help="output model file")
    parser.add_argument("--inputs", help="model input_names")
    parser.add_argument("--outputs", help="model output_names")
    parser.add_argument("--opset", type=int, default=None, help="onnx opset to use")
    parser.add_argument("--custom-ops", help="list of custom ops")
    parser.add_argument("--target", default=",".join(DEFAULT_TARGET), choices=POSSIBLE_TARGETS,
                        help="target platform")
    parser.add_argument("--continue_on_error", help="continue_on_error", action="store_true")
    parser.add_argument("--verbose", help="verbose output", action="store_true")
    parser.add_argument("--fold_const", help="enable tf constant_folding transformation before conversion",
                        action="store_true")
    # experimental
    parser.add_argument("--inputs-as-nchw", help="transpose inputs as from nhwc to nchw")
    # deprecated, going to be removed some time in the future
    parser.add_argument("--unknown-dim", type=int, default=-1, help="default for unknown dimensions")
    args = parser.parse_args()

    args.shape_override = None
    if args.input:
        # for backward compatibility
        args.graphdef = args.input
    if args.graphdef or args.checkpoint:
        # Both --inputs and --outputs are required; the previous check tested
        # --input (the model path) and combined the conditions with `and`.
        if not args.inputs or not args.outputs:
            raise ValueError("graphdef and checkpoint models need to provide inputs and outputs")
    if not any([args.graphdef, args.checkpoint, args.saved_model]):
        raise ValueError("need input as graphdef, checkpoint or saved_model")
    if args.inputs:
        args.inputs, args.shape_override = utils.split_nodename_and_shape(args.inputs)
    if args.outputs:
        args.outputs = args.outputs.split(",")
    if args.inputs_as_nchw:
        args.inputs_as_nchw = args.inputs_as_nchw.split(",")
    if args.target:
        args.target = args.target.split(",")
    return args
def get_args():
    """Parse commandline.

    Returns:
        argparse.Namespace. After parsing:

        * list-valued options (``outputs``, ``ignore_default``, ``use_default``,
          ``rename_outputs``, ``rename_inputs``, ``inputs_as_nchw``, ``target``,
          ``load_op_libraries``) are split on commas;
        * ``inputs`` is split into names and ``shape_override`` is set from it
          (``None`` when ``--inputs`` is absent);
        * ``signature_def`` is wrapped in a list;
        * ``extra_opset`` is converted with ``utils.make_opsetid``.

    Exits via ``parser.error`` on invalid arguments, or with status 1 (after
    printing help) when no model source is given.
    """
    parser = argparse.ArgumentParser(description="Convert tensorflow graphs to ONNX.",
                                     formatter_class=argparse.RawDescriptionHelpFormatter, epilog=_HELP_TEXT)
    parser.add_argument("--input", help="input from graphdef")
    parser.add_argument("--graphdef", help="input from graphdef")
    parser.add_argument("--saved-model", help="input from saved model")
    parser.add_argument("--tag", help="tag to use for saved_model")
    parser.add_argument("--signature_def", help="signature_def from saved_model to use")
    parser.add_argument("--concrete_function", type=int, default=None,
                        help="For TF2.x saved_model, index of func signature in __call__ (--signature_def is ignored)")
    parser.add_argument("--checkpoint", help="input from checkpoint")
    parser.add_argument("--keras", help="input from keras model")
    parser.add_argument("--tflite", help="input from tflite model")
    parser.add_argument("--tfjs", help="input from tfjs model")
    parser.add_argument("--large_model", help="use the large model format (for models > 2GB)", action="store_true")
    parser.add_argument("--output", help="output model file")
    parser.add_argument("--inputs", help="model input_names (optional for saved_model, keras, and tflite)")
    parser.add_argument("--outputs", help="model output_names (optional for saved_model, keras, and tflite)")
    parser.add_argument("--ignore_default", help="comma-separated list of names of PlaceholderWithDefault "
                                                 "ops to change into Placeholder ops")
    parser.add_argument("--use_default", help="comma-separated list of names of PlaceholderWithDefault ops to "
                                              "change into Identity ops using their default value")
    parser.add_argument("--rename-inputs", help="input names to use in final model (optional)")
    parser.add_argument("--rename-outputs", help="output names to use in final model (optional)")
    parser.add_argument("--use-graph-names", help="(saved model only) skip renaming io using signature names",
                        action="store_true")
    parser.add_argument("--opset", type=int, default=None, help="opset version to use for onnx domain")
    parser.add_argument("--dequantize", help="Remove quantization from model. Only supported for tflite currently.",
                        action="store_true")
    parser.add_argument("--custom-ops", help="Comma-separated map of custom ops to domains in format OpName:domain. "
                                             "Domain 'ai.onnx.converters.tensorflow' is used by default.")
    parser.add_argument("--extra_opset", default=None,
                        help="extra opset with format like domain:version, e.g. com.microsoft:1")
    parser.add_argument("--load_op_libraries",
                        help="comma-separated list of tf op library paths to register before loading model")
    parser.add_argument("--target", default=",".join(constants.DEFAULT_TARGET), choices=constants.POSSIBLE_TARGETS,
                        help="target platform")
    parser.add_argument("--continue_on_error", help="continue_on_error", action="store_true")
    parser.add_argument("--verbose", "-v", help="verbose output, option is additive", action="count")
    parser.add_argument("--debug", help="debug mode", action="store_true")
    parser.add_argument("--output_frozen_graph", help="output frozen tf graph to file")
    parser.add_argument("--fold_const", help="Deprecated. Constant folding is always enabled.",
                        action="store_true")
    # experimental
    parser.add_argument("--inputs-as-nchw", help="transpose inputs as from nhwc to nchw")
    args = parser.parse_args()

    args.shape_override = None
    if args.input:
        # for backward compatibility
        args.graphdef = args.input
    if args.graphdef or args.checkpoint:
        if not args.inputs or not args.outputs:
            parser.error("graphdef and checkpoint models need to provide inputs and outputs")
    if not any([args.graphdef, args.checkpoint, args.saved_model, args.keras, args.tflite, args.tfjs]):
        parser.print_help()
        sys.exit(1)
    if args.inputs:
        args.inputs, args.shape_override = utils.split_nodename_and_shape(args.inputs)
    # Options that accept a comma-separated list are converted to Python lists.
    for name in ("outputs", "ignore_default", "use_default", "rename_outputs", "rename_inputs",
                 "inputs_as_nchw", "target", "load_op_libraries"):
        value = getattr(args, name)
        if value:
            setattr(args, name, value.split(","))
    if args.signature_def:
        args.signature_def = [args.signature_def]
    if args.dequantize:
        if not args.tflite:
            parser.error("dequantize flag is currently only supported for tflite")
    if args.extra_opset:
        tokens = args.extra_opset.split(':')
        if len(tokens) != 2:
            parser.error("invalid extra_opset argument")
        domain, version = tokens
        try:
            version = int(version)
        except ValueError:
            # Report a malformed version as a usage error instead of a traceback.
            parser.error("invalid extra_opset argument")
        args.extra_opset = [utils.make_opsetid(domain, version)]
    return args