Example #1
def from_graph_def(graph_def, name=None, input_names=None, output_names=None, opset=None, custom_ops=None,
                   custom_op_handlers=None, custom_rewriter=None, inputs_as_nchw=None, extra_opset=None,
                   shape_override=None, target=None, large_model=False, tensors_to_rename=None, output_path=None):
    """Returns a ONNX model_proto for a tensorflow graphdef.

    Args:
        graph_def: the graphdef we want to convert
        input_names: list of input names
        output_names: list of output names
        name: A name for the graph
        opset: the opset to be used for the ONNX model, default is the latest
        custom_ops: if a model contains ops not recognized by onnx runtime,
            you can tag these ops with a custom op domain so that the
            runtime can still open the model. Type is a dictionary `{op name: domain}`.
        target: list of workarounds applied to help certain platforms
        custom_op_handlers: dictionary of custom ops handlers
        custom_rewriter: list of custom graph rewriters
        extra_opset: list of extra opsets, for example the opsets used by custom ops
        shape_override: dict with inputs that override the shapes given by tensorflow
        inputs_as_nchw: transpose inputs in list from nchw to nhwc
        large_model: use the ONNX external tensor storage format
        tensors_to_rename: an optional dictionary mapping tensor names to new names
        output_path: save model to output_path

    Returns:
        An ONNX model_proto and an external_tensor_storage dict.
    """
    if not input_names:
        raise ValueError("input_names needs to be provided")
    if not output_names:
        raise ValueError("output_names needs to be provided")
    if not name:
        name = "unknown"
    initialized_tables = None

    with tf.device("/cpu:0"):
        with tf.Graph().as_default() as tf_graph:
            with tf_loader.tf_session(graph=tf_graph) as sess:
                tf.import_graph_def(graph_def, name='')
                # Freeze variables to constants, drop resource-typed inputs, and
                # run tf2onnx's TensorFlow-level graph optimizations on the frozen graph.
                frozen_graph = tf_loader.freeze_session(sess, input_names=input_names, output_names=output_names)
                input_names = tf_loader.inputs_without_resource(sess, input_names)
                frozen_graph = tf_loader.tf_optimize(input_names, output_names, frozen_graph)

    model_proto, external_tensor_storage = _convert_common(
        frozen_graph,
        name=name,
        continue_on_error=True,
        target=target,
        opset=opset,
        custom_op_handlers=custom_ops,
        extra_opset=extra_opset,
        shape_override=shape_override,
        input_names=input_names,
        output_names=output_names,
        inputs_as_nchw=inputs_as_nchw,
        large_model=large_model,
        tensors_to_rename=tensors_to_rename,
        initialized_tables=initialized_tables,
        output_path=output_path)

    return model_proto, external_tensor_storage
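
A minimal usage sketch for the function above, assuming it is exposed as tf2onnx.convert.from_graph_def; the .pb path, tensor names, and opset version are placeholders:

import tensorflow as tf
from tf2onnx import convert

# Load a serialized GraphDef from disk (placeholder path).
with tf.io.gfile.GFile("frozen_model.pb", "rb") as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())

# Convert to ONNX; tensor names include the output index (":0").
model_proto, _ = convert.from_graph_def(
    graph_def,
    name="example",
    input_names=["input:0"],
    output_names=["output:0"],
    opset=13,
    output_path="model.onnx",
)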
Example #2
    def load(self, model_path: Union[str, Path], **_) -> Model:
        if isinstance(model_path, Path):
            model_path = model_path.as_posix()

        get_model = load_from_file(model_path, "model", GET_MODEL_FN_NAME)
        get_serving_input_receiver_fn = load_from_file(model_path, "model", GET_SERVING_INPUT_RECEIVER_FN)

        if get_model is None:
            raise RuntimeError(f"Could not find {GET_MODEL_FN_NAME} in {model_path}")
        if get_serving_input_receiver_fn is None:
            raise RuntimeError(f"Could not find {GET_SERVING_INPUT_RECEIVER_FN} in {model_path}")

        model_args = filter_fn_args(self._model_args, fn=get_model)
        serving_input_receiver_args = filter_fn_args(self._model_args, fn=get_serving_input_receiver_fn)

        session_config = create_session_config(allow_growth=True)
        tf.compat.v1.reset_default_graph()
        with tf.compat.v1.Session(config=session_config) as sess:
            estimator = get_model(**model_args)
            serving_input_receiver_fn = get_serving_input_receiver_fn(**serving_input_receiver_args)

            # Build the inference graph by invoking the estimator's model_fn
            # in PREDICT mode on the serving-time input tensors.
            input_receiver = serving_input_receiver_fn()
            estimator_spec = estimator.model_fn(
                features=input_receiver.features,
                labels=None,
                mode=tf.estimator.ModeKeys.PREDICT,
                config=estimator.config,
            )

            input_tensors_dict = input_receiver.receiver_tensors
            output_tensors_dict = estimator_spec.predictions
            inputs_dict = {k: tensor2tensor_spec(tensor) for k, tensor in input_tensors_dict.items()}
            outputs_dict = {k: tensor2tensor_spec(tensor) for k, tensor in output_tensors_dict.items()}

            input_tensor_names = [t.name for t in inputs_dict.values()]
            output_tensor_names = [t.name for t in outputs_dict.values()]

            # Restore weights from the latest checkpoint, then freeze the graph
            # (variables folded into constants) and drop resource-typed or
            # redundant inputs.
            graph_saver = estimator_spec.scaffold.saver or tf.compat.v1.train.Saver(sharded=True)
            graph_saver.restore(sess, estimator.latest_checkpoint())

            input_tensor_names = inputs_without_resource(sess, input_tensor_names)
            frozen_graph = freeze_session(sess, input_names=input_tensor_names, output_names=output_tensor_names)
            input_tensor_names = remove_redundant_inputs(frozen_graph, input_tensor_names)

        # Optimize the frozen graph in a fresh default graph and session.
        tf.compat.v1.reset_default_graph()
        with tf.compat.v1.Session(config=estimator.config.session_config):
            frozen_graph = tf_optimize(input_tensor_names, output_tensor_names, frozen_graph)
        tf.compat.v1.reset_default_graph()

        precision = _infer_model_precision(frozen_graph, inputs_dict, outputs_dict)

        return Model(frozen_graph, precision, inputs_dict, outputs_dict)
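
For context, a minimal sketch of the user-supplied model module this loader expects, assuming GET_MODEL_FN_NAME and GET_SERVING_INPUT_RECEIVER_FN resolve to the function names below; the estimator, layer, and input shape are purely illustrative:

import tensorflow as tf

def get_model(model_dir="checkpoints", **kwargs):
    # Returns a tf.estimator.Estimator; its model_fn is invoked above in PREDICT mode.
    def model_fn(features, labels, mode, params):
        logits = tf.compat.v1.layers.dense(features["input"], units=10)
        return tf.estimator.EstimatorSpec(mode=mode, predictions={"logits": logits})

    return tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir)

def get_serving_input_receiver_fn(**kwargs):
    # Describes serving-time inputs; supplies the receiver_tensors and features
    # used by the loader to build the inference graph.
    def serving_input_receiver_fn():
        inp = tf.compat.v1.placeholder(tf.float32, shape=[None, 128], name="input")
        return tf.estimator.export.ServingInputReceiver(
            features={"input": inp}, receiver_tensors={"input": inp})

    return serving_input_receiver_fn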