Beispiel #1
0
 def test_keras_with_tf2onnx(self):
     """Convert a one-layer Keras model with ``convert_tensorflow``.

     Builds a Sequential model holding a single ReLU Dense layer
     (10 inputs, 4 units), compiles it, converts it through
     ``onnxmltools.convert_tensorflow`` and checks that the resulting
     ONNX graph contains at least one node.
     """
     import tensorflow.keras as keras

     dense = keras.layers.Dense(units=4, input_shape=(10,), activation='relu')
     model = keras.Sequential([dense])
     model.compile(
         loss='binary_crossentropy',
         optimizer='Adam',
         metrics=['binary_accuracy'],
     )

     converted = onnxmltools.convert_tensorflow(model)
     graph_nodes = converted.graph.node
     self.assertTrue(len(graph_nodes) > 0)
Beispiel #2
0
def get_onnx_model(model_format,
                   model,
                   initial_types: list = None,
                   final_types: list = None):
    """Convert ``model`` to ONNX using the converter for ``model_format``.

    Args:
        model_format: Compared against ``ModelFormat.KERAS``,
            ``ModelFormat.SK_LEARN`` and ``ModelFormat.TENSORFLOW`` to
            select the ``onnxmltools`` converter.
        model: The model object handed to the selected converter.
        initial_types: Forwarded to ``convert_sklearn`` only; the other
            converters do not receive it.
        final_types: Accepted but not used by this function.

    Returns:
        The converter's result, or ``None`` when ``model_format`` matches
        none of the three formats above.
    """
    onnx_model = None
    if model_format == ModelFormat.KERAS:
        onnx_model = onnxmltools.convert_keras(model)
    elif model_format == ModelFormat.SK_LEARN:
        onnx_model = onnxmltools.convert_sklearn(
            model, initial_types=initial_types
        )
    elif model_format == ModelFormat.TENSORFLOW:
        onnx_model = onnxmltools.convert_tensorflow(model)
    return onnx_model
Beispiel #3
0
 def test_keras_with_tf2onnx(self):
     """Convert a frozen Keras graph with ``convert_tensorflow``.

     The test returns early (after a warning) when ``keras2onnx`` cannot
     be imported, and does nothing when ``is_tf2`` is truthy. Otherwise
     it builds a one-layer Dense model, exports its frozen graph via
     ``keras2onnx`` and checks that the converted ONNX graph has at least
     one node.
     """
     try:
         import keras2onnx
     except (ImportError, AssertionError):
         warnings.warn("keras2onnx or one of its dependencies is missing.")
         return

     from keras2onnx.proto import keras
     from keras2onnx.proto.tfcompat import is_tf2

     # tf2onnx is not available for tensorflow 2.0 yet.
     if is_tf2:
         return

     dense = keras.layers.Dense(units=4, input_shape=(10,), activation='relu')
     model = keras.Sequential()
     model.add(dense)
     model.compile(
         loss='binary_crossentropy',
         optimizer='Adam',
         metrics=['binary_accuracy'],
     )

     frozen_graph = keras2onnx.export_tf_frozen_graph(model)
     io_names = keras2onnx.build_io_names_tf2onnx(model)
     converted = onnxmltools.convert_tensorflow(frozen_graph, **io_names)
     self.assertTrue(len(converted.graph.node) > 0)
Beispiel #4
0
def convert_tensorflow_file(filename, opset, input_names, output_names):
    """Convert a serialized TensorFlow ``GraphDef`` file to a checked ONNX model.

    Args:
        filename: Path to a binary file holding a serialized ``GraphDef``.
        opset: Passed to the converter as ``target_opset``.
        input_names: Graph input names forwarded to the converter. A falsy
            value (``None`` or empty) is forwarded as an empty list.
        output_names: Graph output names forwarded to the converter.

    Returns:
        The converted model, after ``onnx.checker.check_model`` has
        accepted it.
    """
    # Imported locally so the heavy dependencies are only needed when this
    # function is actually called.
    from tensorflow.core.framework import graph_pb2
    import onnx

    graph_def = graph_pb2.GraphDef()
    with open(filename, 'rb') as graph_file:
        graph_def.ParseFromString(graph_file.read())

    # Forward the caller's input names; previously a hard-coded empty list
    # was passed and the ``input_names`` argument was silently ignored.
    converted_model = onnxmltools.convert_tensorflow(
        graph_def,
        target_opset=opset,
        input_names=input_names or [],
        output_names=output_names,
    )
    onnx.checker.check_model(converted_model)
    return converted_model