Example #1
# SPDX-License-Identifier: Apache-2.0
"""
tf2onnx.graph_helper - helper class for building a graph, e.g. for making complex nodes
"""

import numpy as np
from tf2onnxnightly import utils, logging

# pylint: disable=missing-docstring

logger = logging.getLogger(__name__)


class GraphBuilder(object):
    """help to build graph"""
    def __init__(self, graph):
        self._g = graph

    @property
    def graph(self):
        return self._g

    def make_slice(self,
                   kwargs,
                   name=None,
                   shapes=None,
                   dtypes=None,
                   return_node=False):
        """
        slice changes its schema at opset 10: it treats some attributes as dynamic input
        so this function has to process inputs according to graph's opset version
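
A minimal usage sketch, assuming a tf2onnx Graph g and the kwargs convention of upstream tf2onnx (keys such as "data", "starts", "ends", "axes"); the tensor name "X:0" is hypothetical:

builder = GraphBuilder(g)
# Depending on g's opset, starts/ends/axes are emitted as attributes (opset < 10)
# or as extra constant inputs (opset >= 10); make_slice hides that difference.
sliced_output = builder.make_slice(
    {"data": "X:0", "starts": [0], "ends": [1], "axes": [0]})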
Example #2
def main():
    args = get_args()
    logging.basicConfig(level=logging.get_verbosity_level(args.verbose))
    if args.debug:
        utils.set_debug_mode(True)

    logger = logging.getLogger(constants.TF2ONNX_PACKAGE_NAME)

    extra_opset = args.extra_opset or []
    tflite_path = None
    custom_ops = {}
    initialized_tables = None
    tensors_to_rename = {}
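    # A hypothetical example of the flag format: --custom-ops "Print,MyOp:ai.example"
    # registers Print under the default "tf" domain and MyOp under "ai.example",
    # each mapped to a default custom-op handler below.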
    if args.custom_ops:
        using_tf_opset = False
        for op in args.custom_ops.split(","):
            if ":" in op:
                op, domain = op.split(":")
            else:
                # default custom ops for tensorflow-onnx are in the "tf" namespace
                using_tf_opset = True
                domain = constants.TENSORFLOW_OPSET.domain
            custom_ops[op] = (make_default_custom_op_handler(domain), [])
        if using_tf_opset:
            extra_opset.append(constants.TENSORFLOW_OPSET)

    if any(opset.domain == constants.CONTRIB_OPS_DOMAIN for opset in extra_opset):
        try:
            import tensorflow_text   # pylint: disable=import-outside-toplevel
        except ModuleNotFoundError:
            logger.warning("tensorflow_text not installed. Model will fail to load if tensorflow_text ops are used.")

    # get the frozen tensorflow model from graphdef, checkpoint or saved_model.
    graph_def = None
    inputs = None
    outputs = None
    model_path = None

    if args.graphdef:
        graph_def, inputs, outputs = tf_loader.from_graphdef(args.graphdef, args.inputs, args.outputs)
        model_path = args.graphdef
    if args.checkpoint:
        graph_def, inputs, outputs = tf_loader.from_checkpoint(args.checkpoint, args.inputs, args.outputs)
        model_path = args.checkpoint
    if args.saved_model:
        graph_def, inputs, outputs, initialized_tables, tensors_to_rename = tf_loader.from_saved_model(
            args.saved_model, args.inputs, args.outputs, args.tag, args.signature_def, args.concrete_function,
            args.large_model, return_initialized_tables=True, return_tensors_to_rename=True)
        model_path = args.saved_model
    if args.keras:
        graph_def, inputs, outputs = tf_loader.from_keras(
            args.keras, args.inputs, args.outputs)
        model_path = args.keras
    if args.tflite:
        tflite_path = args.tflite
        model_path = tflite_path

    if args.verbose:
        logger.info("inputs: %s", inputs)
        logger.info("outputs: %s", outputs)

    if args.rename_inputs:
        tensors_to_rename.update(zip(inputs, args.rename_inputs))
    if args.rename_outputs:
        tensors_to_rename.update(zip(outputs, args.rename_outputs))

    with tf.device("/cpu:0"):
        model_proto, _ = _convert_common(
            graph_def,
            name=model_path,
            continue_on_error=args.continue_on_error,
            target=args.target,
            opset=args.opset,
            custom_op_handlers=custom_ops,
            extra_opset=extra_opset,
            shape_override=args.shape_override,
            input_names=inputs,
            output_names=outputs,
            inputs_as_nchw=args.inputs_as_nchw,
            large_model=args.large_model,
            tensors_to_rename=tensors_to_rename,
            ignore_default=args.ignore_default,
            use_default=args.use_default,
            tflite_path=tflite_path,
            dequantize=args.dequantize,
            initialized_tables=initialized_tables,
            output_frozen_graph=args.output_frozen_graph,
            output_path=args.output)


    # write onnx graph
    logger.info("")
    logger.info("Successfully converted TensorFlow model %s to ONNX", model_path)

    logger.info("Model inputs: %s", [n.name for n in model_proto.graph.input])
    logger.info("Model outputs: %s", [n.name for n in model_proto.graph.output])
    if args.output:
        if args.large_model:
            logger.info("Zipped ONNX model is saved at %s. Unzip before opening in onnxruntime.", args.output)
        else:
            logger.info("ONNX model is saved at %s", args.output)
    else:
        logger.info("To export ONNX model to file, please run with `--output` option")
Example #3

import os

import numpy as np
import PIL.Image

try:
    import tensorflow.contrib.rnn  # pylint: disable=unused-import
except:  # pylint: disable=bare-except
    # not needed for tf-2.0
    pass

try:
    import tensorflow_text  # pylint: disable=unused-import
except ModuleNotFoundError:
    pass

from tf2onnxnightly import tf_loader, logging, optimizer, utils, tf_utils, constants
from tf2onnxnightly.tfonnx import process_tf_graph
from tf2onnxnightly.tf_loader import tf_session, tf_reset_default_graph
from tf2onnxnightly.graph import ExternalTensorStorage

logger = logging.getLogger("run_pretrained")

TEMP_DIR = os.path.join(utils.get_temp_directory(), "run_pretrained")
PERFITER = 1000  # number of iterations used when measuring inference performance


def get_img(shape, path, dtype, should_scale=True):
    """Get image as input."""
    resize_to = shape[1:3]
    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), path)
    img = PIL.Image.open(path)
    # Note: PIL.Image.ANTIALIAS was removed in Pillow 10; use PIL.Image.Resampling.LANCZOS there.
    img = img.resize(resize_to, PIL.Image.ANTIALIAS)
    img_np = np.array(img).astype(dtype)
    img_np = np.stack([img_np] * shape[0], axis=0).reshape(shape)
    if should_scale:
        img_np = img_np / 255
    return img_np
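
A minimal usage sketch, assuming an NHWC model input and a local test image (the path is hypothetical):

# Hypothetical call: batch of 1, 224x224 RGB, scaled to [0, 1].
img = get_img((1, 224, 224, 3), "images/cat.jpg", np.float32)
print(img.shape)  # (1, 224, 224, 3)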