示例#1
0
def get_args():
    """Parse the command line of the pre-trained model test runner.

    Returns:
        argparse.Namespace: the parsed options. ``target`` is always a list
        (the raw value split on commas) and ``extra_opset``, when supplied,
        is replaced by a one-element list built with ``utils.make_opsetid``.

    Raises:
        ValueError: if ``--extra_opset`` does not consist of exactly two
            ``:``-separated parts, or its version part is not an integer.
    """
    parser = argparse.ArgumentParser()
    add_option = parser.add_argument
    add_option(
        "--cache",
        default="/tmp/pre-trained",
        help="pre-trained models cache dir",
    )
    add_option(
        "--config",
        default="tests/run_pretrained_models.yaml",
        help="yaml config to use",
    )
    add_option("--tests", help="tests to run")
    add_option("--target", default="", help="target platform")
    add_option(
        "--backend",
        default="onnxruntime",
        choices=["caffe2", "onnxmsrtnext", "onnxruntime"],
        help="backend to use",
    )
    add_option("--opset", type=int, default=None, help="opset to use")
    add_option(
        "--extra_opset",
        default=None,
        help="extra opset with format like domain:version, e.g. com.microsoft:1",
    )
    add_option(
        "--verbose",
        "-v",
        help="verbose output, option is additive",
        action="count",
    )
    add_option("--debug", help="debug mode", action="store_true")
    add_option("--list", help="list tests", action="store_true")
    add_option("--onnx-file", help="create onnx file in directory")
    add_option("--perf", help="capture performance numbers")
    add_option(
        "--fold_const",
        help="enable tf constant_folding transformation before conversion",
        action="store_true",
    )
    add_option(
        "--include-disabled",
        help="include disabled tests",
        action="store_true",
    )
    args = parser.parse_args()

    # Targets arrive as one comma-separated string; hand them on as a list.
    args.target = args.target.split(",")

    if not args.extra_opset:
        return args

    # Expected shape is "domain:version".
    pieces = args.extra_opset.split(':')
    if len(pieces) != 2:
        raise ValueError("invalid extra_opset argument")
    domain, version = pieces
    args.extra_opset = [utils.make_opsetid(domain, int(version))]
    return args
 def to_onnx(self,
             tf_graph,
             opset=None,
             extra_opset=None,
             shape_override=None,
             input_names=None,
             const_node_values=None,
             initialized_tables=None,
             tflite_path=None):
     """Run ``process_tf_graph`` on ``tf_graph`` and return its result.

     Args:
         tf_graph: graph handed to ``process_tf_graph`` as its first argument.
         opset: forwarded unchanged as ``opset``.
         extra_opset: optional iterable of opset ids. It is copied, never
             modified; when ``self.use_custom_ops`` is truthy the contrib-ops
             opset id (version 1) is appended to the copy.
         shape_override: forwarded unchanged.
         input_names: forwarded unchanged.
         const_node_values: forwarded unchanged.
         initialized_tables: forwarded unchanged.
         tflite_path: forwarded unchanged.

     Returns:
         Whatever ``process_tf_graph`` returns.
     """
     # Work on a private copy: appending to the caller's list would make the
     # contrib opset pile up when one list is shared across several calls.
     extra_opset = list(extra_opset) if extra_opset is not None else []
     if self.use_custom_ops:
         extra_opset.append(
             utils.make_opsetid(constants.CONTRIB_OPS_DOMAIN, 1))
     return process_tf_graph(tf_graph,
                             continue_on_error=False,
                             opset=opset,
                             extra_opset=extra_opset,
                             target=Test.target,
                             shape_override=shape_override,
                             input_names=input_names,
                             output_names=self.output_names,
                             const_node_values=const_node_values,
                             initialized_tables=initialized_tables,
                             tflite_path=tflite_path,
                             dequantize=self.dequantize)
示例#3
0
def get_args():
    """Parse and validate the converter command line.

    Returns:
        argparse.Namespace: the parsed options after post-processing:
        ``shape_override`` is added; ``inputs``/``shape_override`` are
        replaced by the pair returned from ``utils.split_nodename_and_shape``;
        ``outputs``, ``inputs_as_nchw`` and ``target`` are split on commas;
        ``signature_def`` is wrapped in a list; ``extra_opset`` becomes a
        one-element list built with ``utils.make_opsetid``.

    Exits through ``parser.error`` when a graphdef/checkpoint model is given
    without both ``--inputs`` and ``--outputs`` or when ``--extra_opset`` is
    malformed, and with status 1 (after printing help) when no model source
    is given.
    """
    parser = argparse.ArgumentParser(description="Convert tensorflow graphs to ONNX.",
                                     formatter_class=argparse.RawDescriptionHelpFormatter, epilog=_HELP_TEXT)
    parser.add_argument("--input", help="input from graphdef")
    parser.add_argument("--graphdef", help="input from graphdef")
    parser.add_argument("--saved-model", help="input from saved model")
    parser.add_argument("--tag", help="tag to use for saved_model")
    parser.add_argument("--signature_def", help="signature_def from saved_model to use")
    parser.add_argument("--concrete_function", type=int, default=None,
                        help="For TF2.x saved_model, index of func signature in __call__ (--signature_def is ignored)")
    parser.add_argument("--checkpoint", help="input from checkpoint")
    parser.add_argument("--keras", help="input from keras model")
    parser.add_argument("--large_model", help="use the large model format (for models > 2GB)", action="store_true")
    parser.add_argument("--output", help="output model file")
    parser.add_argument("--inputs", help="model input_names")
    parser.add_argument("--outputs", help="model output_names")
    parser.add_argument("--opset", type=int, default=None, help="opset version to use for onnx domain")
    parser.add_argument("--custom-ops", help="comma-separated map of custom ops to domains in format OpName:domain")
    parser.add_argument("--extra_opset", default=None,
                        help="extra opset with format like domain:version, e.g. com.microsoft:1")
    parser.add_argument("--target", default=",".join(constants.DEFAULT_TARGET), choices=constants.POSSIBLE_TARGETS,
                        help="target platform")
    parser.add_argument("--continue_on_error", help="continue_on_error", action="store_true")
    parser.add_argument("--verbose", "-v", help="verbose output, option is additive", action="count")
    parser.add_argument("--debug", help="debug mode", action="store_true")
    parser.add_argument("--output_frozen_graph", help="output frozen tf graph to file")
    parser.add_argument("--fold_const", help="Deprecated. Constant folding is always enabled.",
                        action="store_true")
    # experimental
    parser.add_argument("--inputs-as-nchw", help="transpose inputs as from nhwc to nchw")
    args = parser.parse_args()

    args.shape_override = None
    if args.input:
        # for backward compatibility: --input is kept as an alias of --graphdef
        args.graphdef = args.input
    if args.graphdef or args.checkpoint:
        # Both name lists are mandatory here. The check must look at
        # --inputs (not the --input alias) and fail if either one is missing.
        if not args.inputs or not args.outputs:
            parser.error("graphdef and checkpoint models need to provide inputs and outputs")
    if not any([args.graphdef, args.checkpoint, args.saved_model, args.keras]):
        parser.print_help()
        sys.exit(1)
    if args.inputs:
        args.inputs, args.shape_override = utils.split_nodename_and_shape(args.inputs)
    if args.outputs:
        args.outputs = args.outputs.split(",")
    if args.inputs_as_nchw:
        args.inputs_as_nchw = args.inputs_as_nchw.split(",")
    if args.target:
        args.target = args.target.split(",")
    if args.signature_def:
        args.signature_def = [args.signature_def]
    if args.extra_opset:
        tokens = args.extra_opset.split(':')
        if len(tokens) != 2:
            parser.error("invalid extra_opset argument")
        try:
            version = int(tokens[1])
        except ValueError:
            # Report a non-numeric version like any other malformed value
            # instead of leaking a traceback; parser.error() does not return.
            parser.error("invalid extra_opset argument")
        args.extra_opset = [utils.make_opsetid(tokens[0], version)]

    return args
 def _run_test_case(self, func, output_names_with_port, feed_dict,
                    **kwargs):
     """Delegate to ``run_test_case`` with the contrib-ops opset (version 1) enabled.

     The contrib opset is passed through ``process_args["extra_opset"]``; an
     empty list is passed as the third positional argument and any extra
     keyword arguments are forwarded untouched.
     """
     contrib_opset = utils.make_opsetid(constants.CONTRIB_OPS_DOMAIN, 1)
     return self.run_test_case(func,
                               feed_dict,
                               [],
                               output_names_with_port,
                               process_args={"extra_opset": [contrib_opset]},
                               **kwargs)
def get_args():
    """Parse the command line of the pre-trained model test runner.

    Returns:
        argparse.Namespace: the parsed options. ``target`` is split on commas
        into a list, each ``skip_*_tests`` option becomes a bool that is True
        only when the raw text equals "true" ignoring case, and
        ``extra_opset``, when supplied, is replaced by a one-element list
        built with ``utils.make_opsetid``.

    Raises:
        ValueError: if ``--extra_opset`` does not consist of exactly two
            ``:``-separated parts, or its version part is not an integer.
    """
    parser = argparse.ArgumentParser()
    cache_dir = os.path.join(tempfile.gettempdir(), 'pre-trained')
    parser.add_argument("--cache", default=cache_dir, help="pre-trained models cache dir")
    parser.add_argument("--config", default="tests/run_pretrained_models.yaml", help="yaml config to use")
    parser.add_argument("--tests", help="tests to run")
    parser.add_argument("--target", default="", help="target platform")
    parser.add_argument("--backend", default="onnxruntime", choices=["onnxruntime"], help="backend to use")
    parser.add_argument("--opset", type=int, default=None, help="opset to use")
    parser.add_argument("--extra_opset", default=None,
                        help="extra opset with format like domain:version, e.g. com.microsoft:1")
    # The skip switches take a textual value and are turned into bools below.
    skip_options = (
        ("--skip_tf_tests", "skip non-tflite tests"),
        ("--skip_tflite_tests", "skip tflite tests"),
        ("--skip_tfjs_tests", "skip tfjs tests"),
    )
    for option, description in skip_options:
        parser.add_argument(option, help=description, default="False")
    parser.add_argument("--verbose", "-v", help="verbose output, option is additive", action="count")
    parser.add_argument("--debug", help="debug mode", action="store_true")
    parser.add_argument("--list", help="list tests", action="store_true")
    parser.add_argument("--onnx-file", help="create onnx file in directory")
    parser.add_argument("--perf", help="capture performance numbers")
    parser.add_argument("--include-disabled", help="include disabled tests", action="store_true")
    args = parser.parse_args()

    args.target = args.target.split(",")
    for attr in ("skip_tf_tests", "skip_tflite_tests", "skip_tfjs_tests"):
        raw_value = getattr(args, attr)
        setattr(args, attr, raw_value.upper() == "TRUE")

    if not args.extra_opset:
        return args

    # Expected shape is "domain:version".
    pieces = args.extra_opset.split(':')
    if len(pieces) != 2:
        raise ValueError("invalid extra_opset argument")
    domain, version = pieces
    args.extra_opset = [utils.make_opsetid(domain, int(version))]
    return args
示例#6
0
    def test_extra_opset(self):
        """Extra opsets must survive conversion and a graph -> model proto -> graph round trip."""
        expected_extra = [
            utils.make_opsetid(constants.MICROSOFT_DOMAIN, 1),
            utils.make_opsetid("my.domain", 1024),
        ]
        with tf.Session() as session:
            # Minimal graph: output = identity(input1 + input1).
            placeholder = tf.placeholder(tf.float32, [2, 3], name="input1")
            doubled = tf.add(placeholder, placeholder)
            tf.identity(doubled, name="output")

            graph = process_tf_graph(session.graph, opset=self.config.opset, extra_opset=expected_extra)
            self.assertEqual(graph.opset, self.config.opset)
            self.assertEqual(graph.extra_opset, expected_extra)

            # Round-trip through an optimized model proto and verify that the
            # opset information is still intact afterwards.
            optimized_proto = GraphUtil.optimize_model_proto(graph.make_model("test"))
            graph = GraphUtil.create_graph_from_onnx_model(optimized_proto)
            self.assertEqual(graph.opset, self.config.opset)
            self.assertEqual(graph.extra_opset, expected_extra)
示例#7
0
def test_ms_domain(versions=None):
    """Build a ``parameterized.expand`` decorator that supplies MS-domain opsets as extra_opset.

    Each generated case receives a one-element list holding the opset id for
    one version; the case name is the test function name suffixed with that
    version. With ``versions=None`` every version from 1 up to and including
    ``_MAX_MS_OPSET_VERSION`` is used.
    """
    def _custom_name_func(testcase_func, param_num, param):
        del param_num  # part of the callback signature, not needed here
        opset_id = param.args[0]
        return "%s_%s" % (testcase_func.__name__, opset_id.version)

    selected = list(range(1, _MAX_MS_OPSET_VERSION + 1)) if versions is None else versions
    opsets = [[utils.make_opsetid(constants.MICROSOFT_DOMAIN, version)] for version in selected]
    return parameterized.expand(opsets, testcase_func_name=_custom_name_func)
示例#8
0
def get_args():
    """Parse and validate the converter command line.

    Returns:
        argparse.Namespace: the parsed options after post-processing:
        ``shape_override`` is added; ``inputs``/``shape_override`` are
        replaced by the pair returned from ``utils.split_nodename_and_shape``;
        ``outputs``, ``inputs_as_nchw`` and ``target`` are split on commas;
        ``extra_opset`` becomes a one-element list built with
        ``utils.make_opsetid``.

    Raises:
        ValueError: if a graphdef/checkpoint model is given without both
            ``--inputs`` and ``--outputs``, if no model source is given, or
            if ``--extra_opset`` is not of the form ``domain:version`` with
            an integer version.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", help="input from graphdef")
    parser.add_argument("--graphdef", help="input from graphdef")
    parser.add_argument("--saved-model", help="input from saved model")
    parser.add_argument("--checkpoint", help="input from checkpoint")
    parser.add_argument("--output", help="output model file")
    parser.add_argument("--inputs", help="model input_names")
    parser.add_argument("--outputs", help="model output_names")
    parser.add_argument("--opset",
                        type=int,
                        default=None,
                        help="opset version to use for onnx domain")
    parser.add_argument("--custom-ops", help="list of custom ops")
    parser.add_argument(
        "--extra_opset",
        default=None,
        help="extra opset with format like domain:version, e.g. com.microsoft:1"
    )
    parser.add_argument("--target",
                        default=",".join(constants.DEFAULT_TARGET),
                        choices=constants.POSSIBLE_TARGETS,
                        help="target platform")
    parser.add_argument("--continue_on_error",
                        help="continue_on_error",
                        action="store_true")
    parser.add_argument("--verbose",
                        help="verbose output",
                        action="store_true")
    parser.add_argument(
        "--fold_const",
        help="enable tf constant_folding transformation before conversion",
        action="store_true")
    # experimental
    parser.add_argument("--inputs-as-nchw",
                        help="transpose inputs as from nhwc to nchw")
    # deprecated, going to be removed some time in the future
    parser.add_argument("--unknown-dim",
                        type=int,
                        default=-1,
                        help="default for unknown dimensions")
    args = parser.parse_args()

    args.shape_override = None
    if args.input:
        # for backward compatibility: --input is kept as an alias of --graphdef
        args.graphdef = args.input
    if args.graphdef or args.checkpoint:
        # Both name lists are mandatory here. The check must look at
        # --inputs (not the --input alias) and fail if either one is missing.
        if not args.inputs or not args.outputs:
            raise ValueError(
                "graphdef and checkpoint models need to provide inputs and outputs"
            )
    if not any([args.graphdef, args.checkpoint, args.saved_model]):
        raise ValueError("need input as graphdef, checkpoint or saved_model")
    if args.inputs:
        args.inputs, args.shape_override = utils.split_nodename_and_shape(
            args.inputs)
    if args.outputs:
        args.outputs = args.outputs.split(",")
    if args.inputs_as_nchw:
        args.inputs_as_nchw = args.inputs_as_nchw.split(",")
    if args.target:
        args.target = args.target.split(",")

    if args.extra_opset:
        # Expected shape is "domain:version".
        tokens = args.extra_opset.split(':')
        if len(tokens) != 2:
            raise ValueError("invalid extra_opset argument")
        args.extra_opset = [utils.make_opsetid(tokens[0], int(tokens[1]))]
    return args
示例#9
0
def get_args():
    """Parse and validate the converter command line.

    Returns:
        argparse.Namespace: the parsed options after post-processing:
        ``shape_override`` is added; ``inputs``/``shape_override`` are
        replaced by the pair returned from ``utils.split_nodename_and_shape``;
        ``outputs``, ``ignore_default``, ``use_default``, ``rename_outputs``,
        ``rename_inputs``, ``inputs_as_nchw``, ``target`` and
        ``load_op_libraries`` are split on commas; ``signature_def`` is
        wrapped in a list; ``extra_opset`` becomes a one-element list built
        with ``utils.make_opsetid``.

    Exits through ``parser.error`` when a graphdef/checkpoint model is given
    without both ``--inputs`` and ``--outputs``, when ``--dequantize`` is used
    without ``--tflite``, or when ``--extra_opset`` is malformed; exits with
    status 1 (after printing help) when no model source is given.
    """
    parser = argparse.ArgumentParser(description="Convert tensorflow graphs to ONNX.",
                                     formatter_class=argparse.RawDescriptionHelpFormatter, epilog=_HELP_TEXT)
    parser.add_argument("--input", help="input from graphdef")
    parser.add_argument("--graphdef", help="input from graphdef")
    parser.add_argument("--saved-model", help="input from saved model")
    parser.add_argument("--tag", help="tag to use for saved_model")
    parser.add_argument("--signature_def", help="signature_def from saved_model to use")
    parser.add_argument("--concrete_function", type=int, default=None,
                        help="For TF2.x saved_model, index of func signature in __call__ (--signature_def is ignored)")
    parser.add_argument("--checkpoint", help="input from checkpoint")
    parser.add_argument("--keras", help="input from keras model")
    parser.add_argument("--tflite", help="input from tflite model")
    parser.add_argument("--tfjs", help="input from tfjs model")
    parser.add_argument("--large_model", help="use the large model format (for models > 2GB)", action="store_true")
    parser.add_argument("--output", help="output model file")
    parser.add_argument("--inputs", help="model input_names (optional for saved_model, keras, and tflite)")
    parser.add_argument("--outputs", help="model output_names (optional for saved_model, keras, and tflite)")
    parser.add_argument("--ignore_default", help="comma-separated list of names of PlaceholderWithDefault "
                                                 "ops to change into Placeholder ops")
    parser.add_argument("--use_default", help="comma-separated list of names of PlaceholderWithDefault ops to "
                                              "change into Identity ops using their default value")
    parser.add_argument("--rename-inputs", help="input names to use in final model (optional)")
    parser.add_argument("--rename-outputs", help="output names to use in final model (optional)")
    parser.add_argument("--use-graph-names", help="(saved model only) skip renaming io using signature names",
                        action="store_true")
    parser.add_argument("--opset", type=int, default=None, help="opset version to use for onnx domain")
    parser.add_argument("--dequantize", help="Remove quantization from model. Only supported for tflite currently.",
                        action="store_true")
    parser.add_argument("--custom-ops", help="Comma-separated map of custom ops to domains in format OpName:domain. "
                                             "Domain 'ai.onnx.converters.tensorflow' is used by default.")
    parser.add_argument("--extra_opset", default=None,
                        help="extra opset with format like domain:version, e.g. com.microsoft:1")
    parser.add_argument("--load_op_libraries",
                        help="comma-separated list of tf op library paths to register before loading model")
    parser.add_argument("--target", default=",".join(constants.DEFAULT_TARGET), choices=constants.POSSIBLE_TARGETS,
                        help="target platform")
    parser.add_argument("--continue_on_error", help="continue_on_error", action="store_true")
    parser.add_argument("--verbose", "-v", help="verbose output, option is additive", action="count")
    parser.add_argument("--debug", help="debug mode", action="store_true")
    parser.add_argument("--output_frozen_graph", help="output frozen tf graph to file")
    parser.add_argument("--fold_const", help="Deprecated. Constant folding is always enabled.",
                        action="store_true")
    # experimental
    parser.add_argument("--inputs-as-nchw", help="transpose inputs as from nhwc to nchw")
    args = parser.parse_args()

    args.shape_override = None
    if args.input:
        # for backward compatibility: --input is kept as an alias of --graphdef
        args.graphdef = args.input
    if args.graphdef or args.checkpoint:
        if not args.inputs or not args.outputs:
            parser.error("graphdef and checkpoint models need to provide inputs and outputs")
    if not any([args.graphdef, args.checkpoint, args.saved_model, args.keras, args.tflite, args.tfjs]):
        parser.print_help()
        sys.exit(1)
    if args.inputs:
        args.inputs, args.shape_override = utils.split_nodename_and_shape(args.inputs)
    if args.outputs:
        args.outputs = args.outputs.split(",")
    if args.ignore_default:
        args.ignore_default = args.ignore_default.split(",")
    if args.use_default:
        args.use_default = args.use_default.split(",")
    if args.rename_outputs:
        args.rename_outputs = args.rename_outputs.split(",")
    if args.rename_inputs:
        args.rename_inputs = args.rename_inputs.split(",")
    if args.inputs_as_nchw:
        args.inputs_as_nchw = args.inputs_as_nchw.split(",")
    if args.target:
        args.target = args.target.split(",")
    if args.signature_def:
        args.signature_def = [args.signature_def]
    if args.dequantize:
        if not args.tflite:
            parser.error("dequantize flag is currently only supported for tflite")
    if args.extra_opset:
        tokens = args.extra_opset.split(':')
        if len(tokens) != 2:
            parser.error("invalid extra_opset argument")
        try:
            version = int(tokens[1])
        except ValueError:
            # Report a non-numeric version like any other malformed value
            # instead of leaking a traceback; parser.error() does not return.
            parser.error("invalid extra_opset argument")
        args.extra_opset = [utils.make_opsetid(tokens[0], version)]
    if args.load_op_libraries:
        args.load_op_libraries = args.load_op_libraries.split(",")
    return args