def from_function(func, input_names, output_names, large_model=False):
    """Freeze a concrete function and return an optimized GraphDef.

    Args:
        func: the function handed to the variable-freezing converters.
        input_names: graph input names; they are passed through
            ``inputs_without_resource`` before optimization.
        output_names: graph output names forwarded to ``tf_optimize``.
        large_model: when true, return the result of
            ``convert_variables_to_constants_large_model(func)`` directly and
            skip everything else.

    Returns:
        The value produced by ``tf_optimize`` (or by the large-model
        converter when ``large_model`` is set).

    Raises:
        ValueError: re-raised from freezing unless its message contains
            "incompatible with expected resource".
    """
    if large_model:
        return convert_variables_to_constants_large_model(func)

    try:
        # aggressive_inlining is only passed when TF is not older than 2.2.
        freeze_kwargs = {"lower_control_flow": False}
        if not (get_tf_version() < LooseVersion("2.2")):
            freeze_kwargs["aggressive_inlining"] = True
        frozen_func = convert_variables_to_constants_v2(func, **freeze_kwargs)
    except ValueError as e:
        if "incompatible with expected resource" not in str(e):
            raise e
        # Known freezing failure: retry with the large-model converter and
        # let fix_freezing_errors build the graph def from its result.
        frozen_func = convert_variables_to_constants_large_model(func)
        logger.warning(
            "TF freezing failed. Attempting to fix freezing errors.")
        graph_def = fix_freezing_errors(frozen_func)
    else:
        graph_def = frozen_func.graph.as_graph_def(add_shapes=True)

    # Import into a fresh graph/session pair so the optimizer runs in isolation.
    with tf.Graph().as_default() as tf_graph, \
            tf_session(graph=tf_graph) as sess:
        tf.import_graph_def(graph_def, name='')
        input_names = inputs_without_resource(sess, input_names)
        graph_def = tf_optimize(input_names, output_names, graph_def)
    return graph_def
def _freeze_saved_model_v2(concrete_func, control_flow_v2=False):
  """Freeze ``concrete_func`` and return the graph of the frozen function.

  Args:
    concrete_func: function passed to ``convert_variables_to_constants_v2``.
    control_flow_v2: when true, control flow is NOT lowered
      (``lower_control_flow=not control_flow_v2``).

  Returns:
    The ``.graph`` attribute of the frozen function.
  """
  # Compare the version numerically. As plain strings '2.10.0' sorts before
  # '2.2.0', which would wrongly send newer releases down the pre-2.2 path.
  major, minor = (int(part) for part in tf.__version__.split('.')[:2])

  freeze_kwargs = {'lower_control_flow': not control_flow_v2}
  if (major, minor) >= (2, 2):
    # aggressive_inlining is only passed from TF 2.2 onwards.
    freeze_kwargs['aggressive_inlining'] = True

  return convert_to_constants.convert_variables_to_constants_v2(
      concrete_func, **freeze_kwargs).graph
Exemple #3
0
def convert_keras_model(model):
    """Convert a Keras model into a serialized TFLite flatbuffer.

    The model is traced to a concrete function, frozen, constant-folded with
    grappler and then handed to ``convert_graphdef_to_tflite_flatbuffer``.

    Returns:
      The converted data in serialized format.

    Raises:
      RuntimeError: if eager execution is not enabled.
      ValueError: if an input tensor has ``None`` in any dimension other
        than the first.
    """
    if not tf.executing_eagerly():
        raise RuntimeError(
            "Graph mode is not supported. Please enable eager execution using "
            "tf.enable_eager_execution() when using TensorFlow 1.x")

    concrete = concrete_function_from_keras_model(model)
    # lower_control_flow is only passed from TF 1.15 onwards.
    freeze_kwargs = {}
    if version.parse(tf.__version__) >= version.parse("1.15"):
        freeze_kwargs["lower_control_flow"] = False
    frozen_func = convert_variables_to_constants_v2(concrete, **freeze_kwargs)

    # Resource-typed inputs are dropped from the converter's input list.
    input_tensors = []
    for candidate in frozen_func.inputs:
        if candidate.dtype != tf.dtypes.resource:
            input_tensors.append(candidate)
    output_tensors = frozen_func.outputs

    # Run a constant folding using grappler since we currently don't implement
    # folding for LCE custom ops
    graph_def = run_graph_optimizations(
        frozen_func.graph.as_graph_def(),
        input_tensors,
        output_tensors,
        config=get_grappler_config(["constfold"]),
        graph=frozen_func.graph,
    )

    # Validate input shapes; only the leading dimension may be unknown.
    for tensor in input_tensors:
        # The list may be empty for scalar shapes.
        dims = tensor.shape.as_list()
        if None in dims[1:]:
            raise ValueError(
                "None is only supported in the 1st dimension. Tensor '{0}' has "
                "invalid shape '{1}'.".format(get_tensor_name(tensor), dims))
        if dims and dims[0] is None:
            # Unknown batch size defaults to 1.
            fixed = tensor.shape.as_list()
            fixed[0] = 1
            tensor.set_shape(fixed)

    serialized_graph = graph_def.SerializeToString()
    input_names = [get_tensor_name(t) for t in input_tensors]
    input_types = [
        DataType.Name(t.dtype.as_datatype_enum) for t in input_tensors
    ]
    input_shapes = [t.shape.as_list() for t in input_tensors]
    output_names = [get_tensor_name(t) for t in output_tensors]
    return convert_graphdef_to_tflite_flatbuffer(
        serialized_graph,
        input_names,
        input_types,
        input_shapes,
        output_names,
    )
  def testConstSavedModel(self):
    """Test a basic model with functions to make sure functions are inlined."""
    sample = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.f = def_function.function(lambda x: 2. * x)
    concrete = root.f.get_concrete_function(sample)

    # Round-trip through a SavedModel on disk and use its serving signature.
    export_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save(root, export_dir, concrete)
    loaded = load(export_dir)
    input_func = loaded.signatures["serving_default"]

    # Before freezing: no variables, but a non-empty function library.
    graph_def_before = input_func.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(graph_def_before))
    self.assertTrue(graph_def_before.library.function)

    # After freezing: still no variables and the function library is empty.
    frozen = convert_to_constants.convert_variables_to_constants_v2(input_func)
    graph_def_after = frozen.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(graph_def_after))
    self.assertFalse(graph_def_after.library.function)

    # The frozen graph must evaluate to the same value as the original.
    expected = root.f(sample)
    actual = self._evaluateGraphDef(
        graph_def_after, input_func, [sample.numpy()])
    self.assertEqual(expected.numpy(), actual)
  def testKerasModel(self):
    """Freeze a function wrapping a fitted single-Dense Keras model."""
    sample = constant_op.constant(1., shape=[1, 1])

    # Tiny training set for the one-unit model.
    features = [-1, 0, 1, 2, 3, 4]
    targets = [-3, -1, 1, 3, 5, 7]

    model = keras.models.Sequential(
        [keras.layers.Dense(units=1, input_shape=[1])])
    model.compile(optimizer="sgd", loss="mean_squared_error")
    model.fit(features, targets, epochs=1)

    # Wrap the Keras model so a concrete function can be traced from it.
    @def_function.function
    def to_save(x):
      return model(x)

    input_func = to_save.get_concrete_function(sample)

    # Two variables are expected before freezing.
    graph_def_before = input_func.graph.as_graph_def()
    self.assertEqual(2, self._getNumVariables(graph_def_before))

    # After freezing: no variables and no StatefulPartitionedCall op.
    frozen = convert_to_constants.convert_variables_to_constants_v2(input_func)
    graph_def_after = frozen.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(graph_def_after))
    self.assertFalse(self._hasStatefulPartitionedCallOp(graph_def_after))

    # The frozen graph must evaluate to the same value as the original.
    expected = to_save(sample)
    actual = self._evaluateGraphDef(
        graph_def_after, input_func, [sample.numpy()])
    self.assertEqual(expected.numpy(), actual)
  def testConstructConcreteFunction(self):
    """Freeze a function rebuilt via ``_construct_concrete_function``."""
    sample = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    traced = root.f.get_concrete_function(sample)

    input_func = convert_to_constants._construct_concrete_function(
        traced, traced.graph.as_graph_def())

    # Test if model has enough metadata to be frozen afterwards.
    graph_def_before = input_func.graph.as_graph_def()
    self.assertEqual(2, self._getNumVariables(graph_def_before))

    # After freezing: no variables and no StatefulPartitionedCall op.
    frozen = convert_to_constants.convert_variables_to_constants_v2(input_func)
    graph_def_after = frozen.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(graph_def_after))
    self.assertFalse(self._hasStatefulPartitionedCallOp(graph_def_after))

    # The frozen graph must evaluate to the same value as the original.
    expected = root.f(sample)
    actual = self._evaluateGraphDef(
        graph_def_after, input_func, [sample.numpy()])
    self.assertEqual(expected.numpy(), actual)
Exemple #7
0
def save_model(path, name, model, suffix='tf'):
    """
    Save model for h5/json/pb/tf/frozen

    The output directory is created first (including missing parents), even
    when ``suffix`` later turns out to be unsupported.

    :param path: str. Output directory.
    :param name: str. Base file name without extension. Not used for the
        'pb'/'tf' formats, which save directly into ``path``.
    :param model: tf.keras.models.Model.
    :param suffix: str. For different format: 'h5', 'json' (architecture as
        JSON plus weights as .h5), 'pb'/'tf' (SavedModel) or 'frozen'
        (frozen graph written as ``name + '.pb'``).
    :raises TypeError: if ``suffix`` is not one of the supported formats.
    """
    if not os.path.exists(path):
        # makedirs creates missing parent directories too (os.mkdir fails on
        # nested paths); exist_ok tolerates a concurrent creator.
        os.makedirs(path, exist_ok=True)

    if suffix == 'h5':
        model.save(path + "/" + name + ".h5")
    elif suffix == 'json':
        with open(path + "/" + name + ".json", "w") as json_file:
            json_file.write(model.to_json())
        model.save_weights(path + "/" + name + ".h5")
    elif suffix in ('pb', 'tf'):
        tf.keras.models.save_model(model, path, save_format='tf')
    elif suffix == 'frozen':
        # Trace the model on its first input's spec, then freeze the trace.
        full_model = tf.function(lambda x: model(x))
        full_model = full_model.get_concrete_function(
            x=tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
        frozen_func = convert_variables_to_constants_v2(full_model)
        # The former bare ``frozen_func.graph.as_graph_def()`` call was
        # dropped: its result was discarded, so it only cost a serialization.
        tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                          logdir=path,
                          name=name + '.pb',
                          as_text=False)
    else:
        raise TypeError('suffix=\'%s\'' % suffix)
Exemple #8
0
    def testMultiFunctionModel(self):
        """Test a basic model with Variables."""
        class BasicModel(tracking.AutoTrackable):
            """Object with two tf.functions that lazily create a variable."""

            def __init__(self):
                self.y = None
                self.z = None

            @def_function.function
            def add(self, x):
                # The variable is created the first time the method runs.
                if self.y is None:
                    self.y = variables.Variable(2.)
                return x + self.y

            @def_function.function
            def sub(self, x):
                # The variable is created the first time the method runs.
                if self.z is None:
                    self.z = variables.Variable(3.)
                return x - self.z

        input_data = {"x": constant_op.constant(1., shape=[1])}
        root = BasicModel()
        # Only ``add`` is traced and frozen in this test.
        input_func = root.add.get_concrete_function(input_data["x"])

        graph_def_before = input_func.graph.as_graph_def()
        self.assertEqual(1, self._getNumVariables(graph_def_before))

        # After freezing: no variables and no StatefulPartitionedCall op.
        frozen = convert_to_constants.convert_variables_to_constants_v2(
            input_func)
        graph_def_after = frozen.graph.as_graph_def()
        self.assertEqual(0, self._getNumVariables(graph_def_after))
        self.assertFalse(
            self._hasStatefulPartitionedCallOp(graph_def_after))

        self._testConvertedFunction(root, root.add, frozen, input_data)
  def testVariableSavedModel(self):
    """Test a basic model with Variables with saving/loading the SavedModel."""
    sample = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    concrete = root.f.get_concrete_function(sample)

    # Round-trip through a SavedModel on disk and use its serving signature.
    export_dir = os.path.join(self.get_temp_dir(), "saved_model")
    save(root, export_dir, concrete)
    loaded = load(export_dir)
    input_func = loaded.signatures["serving_default"]

    # The loaded signature is expected to contain a StatefulPartitionedCall.
    graph_def_before = input_func.graph.as_graph_def()
    self.assertTrue(self._hasStatefulPartitionedCallOp(graph_def_before))

    # After freezing: no variables and no StatefulPartitionedCall op.
    frozen = convert_to_constants.convert_variables_to_constants_v2(input_func)
    graph_def_after = frozen.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(graph_def_after))
    self.assertFalse(self._hasStatefulPartitionedCallOp(graph_def_after))

    # The frozen graph must evaluate to the same value as the original.
    expected = root.f(sample)
    actual = self._evaluateGraphDef(
        graph_def_after, input_func, [sample.numpy()])
    self.assertEqual(expected.numpy(), actual)
Exemple #10
0
def freezeModel(model):
    """Freeze a Keras model and write it to ./frozen_models/frozen_graph.pb.

    The model is traced on the shape/dtype of its first input, frozen with
    ``convert_variables_to_constants_v2``, and the operation names, inputs
    and outputs of the frozen function are printed before the graph is
    written in binary form.

    :param model: Keras model; only ``model.inputs[0]`` is used for tracing.
    """
    first_input = model.inputs[0]
    traced = tf.function(lambda x: model(x))
    traced = traced.get_concrete_function(
        tf.TensorSpec(first_input.shape, first_input.dtype))

    # Freeze the traced ConcreteFunction.
    frozen_func = convert_variables_to_constants_v2(traced)
    # NOTE: the result of this call is discarded.
    frozen_func.graph.as_graph_def()

    # Collect every operation name before anything is printed.
    op_names = [op.name for op in frozen_func.graph.get_operations()]
    separator = "-" * 50
    print(separator)
    print("Frozen model layers: ")
    for op_name in op_names:
        print(op_name)

    print(separator)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)

    # Persist the frozen graph as a binary protobuf.
    tf.io.write_graph(
        graph_or_graph_def=frozen_func.graph,
        logdir="./frozen_models",
        name="frozen_graph.pb",
        as_text=False,
    )
Exemple #11
0
    def freeze_model(self):
        """ Freezes the model

        Saves ``self.tf_model`` as a SavedModel, traces it on the shape/dtype
        of its first input, freezes the trace, logs the frozen operation
        names, inputs and outputs, and writes the frozen graph in binary
        form. All paths come from ``self.helpers.confs["model"]``.
        """

        tf.saved_model.save(self.tf_model,
                            self.helpers.confs["model"]["saved_model_dir"])

        first_input = self.tf_model.inputs[0]
        traced = tf.function(lambda x: self.tf_model(x))
        traced = traced.get_concrete_function(
            x=tf.TensorSpec(first_input.shape, first_input.dtype))

        frozen = convert_variables_to_constants_v2(traced)
        # NOTE: the result of this call is discarded.
        frozen.graph.as_graph_def()

        # Collect every operation name before anything is logged.
        op_names = [op.name for op in frozen.graph.get_operations()]
        self.helpers.logger.info("Frozen model layers")
        for op_name in op_names:
            self.helpers.logger.info(op_name)

        self.helpers.logger.info("Frozen model inputs")
        self.helpers.logger.info(frozen.inputs)
        self.helpers.logger.info("Frozen model outputs")
        self.helpers.logger.info(frozen.outputs)

        tf.io.write_graph(
            graph_or_graph_def=frozen.graph,
            logdir=self.helpers.confs["model"]["freezing_log_dir"],
            name=self.helpers.confs["model"]["frozen"],
            as_text=False)
Exemple #12
0
def get_func_from_saved_model(saved_model_dir):
  """Load a SavedModel and return its default serving signature, frozen.

  Args:
    saved_model_dir: directory passed to ``tf.saved_model.load`` (loaded
      with the SERVING tag).

  Returns:
    The result of ``convert_variables_to_constants_v2`` applied to the
    default serving signature.
  """
  loaded = tf.saved_model.load(saved_model_dir, tags=[tag_constants.SERVING])
  serving_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
  serving_func = loaded.signatures[serving_key]
  return convert_to_constants.convert_variables_to_constants_v2(serving_func)
Exemple #13
0
    def testStatelessIf(self):
        """Test whether StatelessIf op freezes correctly."""
        input_data = {"b": constant_op.constant(True)}

        # Constant shared by both branches of the conditional.
        x = constant_op.constant([1., 2.], shape=[1, 2], name="x")

        def true_fn():
            return x

        def false_fn():
            return x + 2

        bool_scalar = tensor_spec.TensorSpec(shape=(), dtype=dtypes.bool)

        @def_function.function(input_signature=[bool_scalar])
        def model(b):
            return cond_v2.cond_v2(b, true_fn, false_fn)

        root = tracking.AutoTrackable()
        root.f = model
        input_func = root.f.get_concrete_function()
        # Run the traced function once before freezing it.
        input_func(**input_data)

        # Freeze without lowering control flow.
        frozen = convert_to_constants.convert_variables_to_constants_v2(
            input_func, lower_control_flow=False)
        graph_def_after = frozen.graph.as_graph_def()
        self.assertEqual(0, self._getNumVariables(graph_def_after))
        self.assertFalse(
            self._hasStatefulPartitionedCallOp(graph_def_after))

        self._testConvertedFunction(root, root.f, frozen, input_data)
Exemple #14
0
def _graph_def_from_saved_model_or_keras_model(filename):
    """
    Utility function that returns GraphDef object from the given SavedModel or HDF5 model.

    The model is loaded with ``tf.keras.models.load_model``, traced with
    ``trace_model_call``, frozen, and exported with shape information
    (``add_shapes=True``).

    :param filename: TensorFlow SavedModel directory or Keras HDF5 model (.h5) file.
    :return: TensorFlow GraphDef object.
    :raises ImportError: if an ImportError occurs anywhere in the process.
    :raises RuntimeError: for any other failure while loading or freezing.
    """
    try:
        import tensorflow as tf
        from tensorflow.python.keras.saving import saving_utils as _saving_utils
        from tensorflow.python.framework import convert_to_constants as _convert_to_constants
        model = tf.keras.models.load_model(filename)
        tf.keras.backend.set_learning_phase(False)
        func = _saving_utils.trace_model_call(model)
        concrete_func = func.get_concrete_function()
        frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
            concrete_func)
        graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
    except ImportError as e:
        # Chain explicitly so the original failure is reported as the cause.
        raise ImportError(
            'Failed to import TensorFlow utilities. {}.'.format(e)) from e
    except Exception as e:
        # Chain explicitly so the original failure is reported as the cause.
        raise RuntimeError(
            'Failed to load SavedModel or .h5 model. {}.'.format(e)) from e
    return graph_def
Exemple #15
0
    def testLoop(self):
        """Freeze a while_loop whose body adds a variable to its input."""
        input_data = {
            "x": constant_op.constant([1., 2., 3., 4.], shape=[2, 2])
        }

        weights = variables.Variable([[0.1, 0.2], [0.3, 0.4]],
                                     dtype=dtypes.float32)

        def condition(x):
            # Keep looping while the element sum stays below 100.
            return math_ops.reduce_sum(x) < 100

        def body(x):
            return math_ops.add(x, weights)

        matrix_spec = tensor_spec.TensorSpec(shape=[2, 2],
                                             dtype=dtypes.float32)

        @def_function.function(input_signature=[matrix_spec])
        def model(x):
            return control_flow_ops.while_loop(condition, body, [x])

        root = tracking.AutoTrackable()
        root.f = model
        input_func = root.f.get_concrete_function()
        # Run the traced function once before freezing it.
        input_func(**input_data)

        # Freeze without lowering control flow.
        frozen = convert_to_constants.convert_variables_to_constants_v2(
            input_func, lower_control_flow=False)
        graph_def_after = frozen.graph.as_graph_def()
        self.assertEqual(0, self._getNumVariables(graph_def_after))
        self.assertFalse(
            self._hasStatefulPartitionedCallOp(graph_def_after))

        self._testConvertedFunction(root, root.f, frozen, input_data)
    def __init__(self, keras_model_path, inputshape, in_nodes, dest_nodes):
        """Load a Keras model, freeze it and build the parser graph from it.

        :param keras_model_path: path given to ``tf.keras.models.load_model``
            (loaded with ``compile=False``).
        :param inputshape: not used in this constructor body.
        :param in_nodes: not used in this constructor body.
        :param dest_nodes: not used in this constructor body.
        :raises ImportError: if the installed TensorFlow is older than 1.8.0.
        """
        if LooseVersion(tensorflow.__version__) < LooseVersion('1.8.0'):
            raise ImportError(
                'Your TensorFlow version %s is outdated. '
                'MMdnn requires tensorflow>=1.8.0' % tensorflow.__version__)

        super(TensorflowParser2, self).__init__()
        self.weight_loaded = True

        import tensorflow as tf
        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
        model = tf.keras.models.load_model(keras_model_path, compile=False)
        # Trace the model on its first input's spec, then freeze the trace.
        full_model = tf.function(lambda x: model(x))
        full_model = full_model.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
        frozen_func = convert_variables_to_constants_v2(full_model)
        # The former bare ``frozen_func.graph.as_graph_def()`` call was
        # dropped: its result was discarded.
        g = frozen_func.graph

        from tensorflow.python.client.session import Session
        from tensorflow.python.training.saver import export_meta_graph
        with Session(graph=g):
            tempdir = tempfile.mkdtemp()
            try:
                meta_graph_def = export_meta_graph(filename=os.path.join(tempdir, 'my-model.meta'))
            finally:
                # Remove the scratch directory even when the export fails.
                shutil.rmtree(tempdir)
            graph_def = meta_graph_def.graph_def

        self.tf_graph = TensorflowGraph(graph_def)
        self.tf_graph.build()
def to_frozen_graph(model: tf.keras.Model):
    """
    Build a frozen, grappler-optimized graph def for a tf.keras model.

    Captured tensors are excluded from the input names and resource-typed
    tensors are excluded from the output names. Freezing runs under the
    "/cpu:0" device scope.

    :param model: The tf.keras model we want to convert.
    :return: The frozen graph def for the provided model.
    """
    concrete_func = saving_utils.trace_model_call(model).get_concrete_function()

    # Names of the tensors the traced graph captured from outside.
    captures = concrete_func.graph._captures # pylint: disable=protected-access
    captured_inputs = [internal.name for _, internal in captures.values()]

    input_names = []
    for tensor in concrete_func.inputs:
        if tensor.name not in captured_inputs:
            input_names.append(tensor.name)

    output_names = []
    for tensor in concrete_func.outputs:
        if tensor.dtype != tf.dtypes.resource:
            output_names.append(tensor.name)

    with tf.device("/cpu:0"):
        frozen_func = convert_variables_to_constants_v2(
            concrete_func,
            lower_control_flow=False,
            aggressive_inlining=True)
        graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
        # Import into a fresh default graph before running the optimizer.
        with tf.Graph().as_default(): # pylint: disable=not-context-manager
            tf.import_graph_def(graph_def, name='')
            frozen_graph = tf_optimize_grappler(
                input_names, output_names, graph_def)

    return frozen_graph
Exemple #18
0
    def _convert_saved_model_v2(self):
        """Convert the input SavedModel in 2.0 format.

        Loads the SavedModel, freezes the selected signature, exports a meta
        graph for Grappler, runs ``self._run_conversion()`` and wraps the
        converted graph def back into a function stored on
        ``self._converted_func``.
        """
        assert context.executing_eagerly()

        self._saved_model = load.load(self._input_saved_model_dir,
                                      self._input_saved_model_tags)
        signature = self._saved_model.signatures[
            self._input_saved_model_signature_key]
        frozen_func = convert_to_constants.convert_variables_to_constants_v2(
            signature)
        self._grappler_meta_graph_def = saver.export_meta_graph(
            graph_def=frozen_func.graph.as_graph_def(),
            graph=frozen_func.graph)

        # Add a collection 'train_op' so that Grappler knows the outputs.
        fetch_collection = meta_graph_pb2.CollectionDef()
        fetch_collection.node_list.value.extend(
            [tensor.name
             for tensor in frozen_func.inputs + frozen_func.outputs])
        self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
            fetch_collection)

        # Run TRT optimizer in Grappler to convert the graph.
        self._run_conversion()
        input_names = [tensor.name for tensor in frozen_func.inputs]
        output_names = [tensor.name for tensor in frozen_func.outputs]
        self._converted_func = wrap_function.function_from_graph_def(
            self._converted_graph_def, input_names, output_names)
Exemple #19
0
def h5_to_pb(h5_save_path):
    """Load a Keras model and write it as ./frozen_models2/model.pb.

    The model is loaded with ``compile=False``, summarized, traced on the
    shape/dtype of its first input and frozen. The operation names, inputs
    and outputs of the frozen function are printed before the graph is
    written in binary form.

    :param h5_save_path: path given to ``tf.keras.models.load_model``.
    """
    model = tf.keras.models.load_model(h5_save_path, compile=False)
    model.summary()

    first_input = model.inputs[0]
    traced = tf.function(lambda Input: model(Input))
    traced = traced.get_concrete_function(
        tf.TensorSpec(first_input.shape, first_input.dtype))

    # Freeze the traced ConcreteFunction.
    frozen_func = convert_variables_to_constants_v2(traced)
    # NOTE: the result of this call is discarded.
    frozen_func.graph.as_graph_def()

    # Collect every operation name before anything is printed.
    op_names = [op.name for op in frozen_func.graph.get_operations()]
    separator = "-" * 50
    print(separator)
    print("Frozen model layers: ")
    for op_name in op_names:
        print(op_name)

    print(separator)
    print("Frozen model inputs: ")
    print(frozen_func.inputs)
    print("Frozen model outputs: ")
    print(frozen_func.outputs)

    # Persist the frozen graph as a binary protobuf.
    tf.io.write_graph(
        graph_or_graph_def=frozen_func.graph,
        logdir="./frozen_models2",
        name="model.pb",
        as_text=False,
    )
Exemple #20
0
 def load_graph(self, model_path):
     """Load a model into a TF1-style session graph and count its ops.

     If the directory containing ``model_path`` ends with 'saved_model', that
     directory is loaded as a SavedModel and its first signature whose key
     does not start with "_" is frozen. Otherwise ``model_path`` is parsed
     as a serialized GraphDef file.

     Returns a tuple ``(graph, self.count_ops(graph))``.
     """
     import tensorflow as tf
     tf.compat.v1.reset_default_graph()
     model_dir = os.path.dirname(model_path)

     if not model_dir.endswith('saved_model'):
         # Serialized GraphDef file: parse it and import it as-is.
         with tf.compat.v1.Session() as sess:
             graph_def = tf.compat.v1.GraphDef()
             with tf.io.gfile.GFile(model_path, 'rb') as model_f:
                 graph_def.ParseFromString(model_f.read())
                 tf.import_graph_def(graph_def, name='')
                 return sess.graph, self.count_ops(sess.graph)

     imported = tf.saved_model.load(model_dir)
     from tensorflow.python.framework.convert_to_constants\
         import convert_variables_to_constants_v2
     # Keep only signature keys that do not start with an underscore.
     public_keys = [
         key for key in imported.signatures.keys() if not key.startswith("_")
     ]
     signature = imported.signatures[public_keys[0]]
     frozen_func = convert_variables_to_constants_v2(
         signature, lower_control_flow=False)
     graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
     with tf.compat.v1.Session() as sess:
         tf.import_graph_def(graph_def, name='')
         return sess.graph, self.count_ops(sess.graph)
def from_function(func, input_names, output_names, large_model=False):
    """Freeze a concrete function and return an optimized GraphDef.

    Args:
        func: the function handed to the variable-freezing converters.
        input_names: graph input names; they are passed through
            ``inputs_without_resource`` before optimization.
        output_names: graph output names forwarded to ``tf_optimize``.
        large_model: when true, return the result of
            ``convert_variables_to_constants_large_model(func)`` directly and
            skip everything else.

    Returns:
        The value produced by ``tf_optimize`` (or by the large-model
        converter when ``large_model`` is set).
    """
    if large_model:
        return convert_variables_to_constants_large_model(func)

    # aggressive_inlining is only passed when TF is not older than 2.2.
    freeze_kwargs = {"lower_control_flow": False}
    if not (get_tf_version() < LooseVersion("2.2")):
        freeze_kwargs["aggressive_inlining"] = True
    frozen_func = convert_variables_to_constants_v2(func, **freeze_kwargs)
    graph_def = frozen_func.graph.as_graph_def(add_shapes=True)

    # Start from a clean default graph before importing the frozen graph.
    tf_reset_default_graph()
    with tf_session() as sess:
        tf.import_graph_def(graph_def, name='')
        input_names = inputs_without_resource(sess, input_names)
        graph_def = tf_optimize(input_names, output_names, graph_def)
    return graph_def
Exemple #22
0
    def testDynamicRnn(self):
        """Freeze a dynamic_rnn built on an LSTMCell(10)."""
        random_batch = np.array(np.random.random_sample((3, 10, 10)),
                                dtype=np.float32)
        input_data = {"x": constant_op.constant(random_batch)}

        cell = rnn_cell_impl.LSTMCell(10)

        sequence_spec = tensor_spec.TensorSpec(shape=[3, 10, 10],
                                               dtype=dtypes.float32)

        @def_function.function(input_signature=[sequence_spec])
        def model(x):
            return rnn.dynamic_rnn(cell, x, dtype=dtypes.float32)

        root = tracking.AutoTrackable()
        root.f = model
        input_func = root.f.get_concrete_function()

        # Freeze without lowering control flow.
        frozen = convert_to_constants.convert_variables_to_constants_v2(
            input_func, lower_control_flow=False)
        graph_def_after = frozen.graph.as_graph_def()
        self.assertEqual(0, self._getNumVariables(graph_def_after))
        self.assertFalse(
            self._hasStatefulPartitionedCallOp(graph_def_after))

        self._testConvertedFunction(root, root.f, frozen, input_data)
    def testConstructConcreteFunction(self):
        """Freeze a function rebuilt via ``_construct_concrete_function``."""
        sample = constant_op.constant(1., shape=[1])
        root = tracking.AutoTrackable()
        root.v1 = variables.Variable(3.)
        root.v2 = variables.Variable(2.)
        root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
        traced = root.f.get_concrete_function(sample)

        input_func = convert_to_constants._construct_concrete_function(
            traced, traced.graph.as_graph_def(), {})

        # Test if model has enough metadata to be frozen afterwards.
        graph_def_before = input_func.graph.as_graph_def()
        self.assertEqual(2, self._getNumVariables(graph_def_before))

        # After freezing: no variables and no StatefulPartitionedCall op.
        frozen = convert_to_constants.convert_variables_to_constants_v2(
            input_func)
        graph_def_after = frozen.graph.as_graph_def()
        self.assertEqual(0, self._getNumVariables(graph_def_after))
        self.assertFalse(
            self._hasStatefulPartitionedCallOp(graph_def_after))

        # The frozen function must return the same value as the original.
        expected = root.f(sample)
        actual = nest.flatten(frozen(sample))
        self.assertEqual(expected.numpy(), actual)
Exemple #24
0
def get_tf_model_proto(tf_model):
    """Freeze ``tf_model`` and write it to models/mobilenet.pb.

    The model is traced on the shape/dtype of its first input, frozen and
    written as a binary protobuf. The output directory is created if needed.

    :param tf_model: model to freeze; only ``tf_model.inputs[0]`` is used
        for tracing.
    :return: path of the written file, ``os.path.join("models", "mobilenet.pb")``.
    """
    # Output location of the frozen .pb model.
    out_dir = "models"
    out_name = "mobilenet.pb"
    os.makedirs(out_dir, exist_ok=True)

    # Wrap the model in a tf.function and trace it on its first input spec.
    first_input = tf_model.inputs[0]
    traced = tf.function(lambda x: tf_model(x))
    traced = traced.get_concrete_function(
        tf.TensorSpec(first_input.shape, first_input.dtype))

    # Freeze the traced concrete function.
    frozen_tf_func = convert_variables_to_constants_v2(traced)
    # NOTE: the result of this call is discarded.
    frozen_tf_func.graph.as_graph_def()

    # Write the frozen graph in binary form.
    tf.io.write_graph(
        graph_or_graph_def=frozen_tf_func.graph,
        logdir=out_dir,
        name=out_name,
        as_text=False,
    )

    return os.path.join(out_dir, out_name)
    def benchmark_tensorrt_inference(self, model):
        """Time per-chip inference of a saved model and dump the statistics.

        Loads ``tensorrt_models/{model}``, freezes its default serving
        signature, runs 50 warmup calls on the first entry of
        ``self.chip_files``, then times one call per entry. Predictions are
        appended to ``self.predictions`` and the timing summary is written to
        ``model_scores/{model}-inference_benchmark.json``.
        """
        # Loading TensorRT model
        print(f"Loading model {model}")
        loaded = tf.saved_model.load(
            f"tensorrt_models/{model}", tags=[tag_constants.SERVING])
        serving_func = loaded.signatures[
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        frozen_func = convert_variables_to_constants_v2(serving_func)

        inference_timings = []
        print("Warmup")
        for _ in range(50):
            chip, chip_id = self.chip_files[0]
            # A leading axis of size 1 is added before the call.
            chip = np.expand_dims(np.array(chip), axis=0)
            frozen_func(tf.constant(chip.astype(np.float32)))

        # Run benchmark
        for chip, chip_id in tqdm(self.chip_files):
            chip = np.expand_dims(np.array(chip), axis=0)
            # Only the frozen call plus the .numpy() conversion is timed.
            start = timer()
            predicted_chip = frozen_func(tf.constant(chip.astype(np.float32)))[0].numpy()
            end = timer()
            inference_timings.append(end - start)
            self.predictions.append((np.squeeze(predicted_chip), chip_id))

        print(f"Mean: {np.mean(inference_timings)}, STD: {np.std(inference_timings)}, MEDIAN: {np.median(inference_timings)}")
        benchmark_summary = {
            "timings": inference_timings,
            "mean": np.mean(inference_timings),
            "std": np.std(inference_timings),
            "median": np.median(inference_timings),
            "90_perc": np.percentile(inference_timings, 90),
        }
        with open(f"model_scores/{model}-inference_benchmark.json", "w") as inference_json:
            json.dump(benchmark_summary, inference_json)
Exemple #26
0
 def save_and_reload(_model):
     """Save ``_model`` to a temporary SavedModel, reload it and freeze it.

     Returns the frozen "serving_default" signature. The temporary directory
     is removed before returning.
     """
     with tempfile.TemporaryDirectory() as scratch_dir:
         tf.saved_model.save(_model, scratch_dir)
         reloaded = tf.saved_model.load(scratch_dir)
         serving_func = reloaded.signatures["serving_default"]
         frozen_func = convert_variables_to_constants_v2(serving_func)
     return frozen_func
Exemple #27
0
def load_graph_func(saved_model_dir: str, saved_model_tags: str,
                    saved_model_signature_key: str):
    """Loads a graph function in TF2.

    Args:
        saved_model_dir: export directory of the SavedModel.
        saved_model_tags: tags passed to the SavedModel loader.
        saved_model_signature_key: key of the signature to freeze.

    Returns:
        The selected signature after ``convert_variables_to_constants_v2``.
    """
    loaded = saved_model_load.load(
        export_dir=saved_model_dir, tags=saved_model_tags)
    signature = loaded.signatures[saved_model_signature_key]
    frozen = convert_to_constants.convert_variables_to_constants_v2(signature)
    return frozen
Exemple #28
0
def model_to_graph(model: Callable) -> tf.Graph:
    """Convert a keras model into a wrapped function graph

    One TensorSpec is built per entry of ``model.inputs``; the list of specs
    is passed as the single argument when tracing. The traced function is
    frozen and the graph of the frozen function is returned.
    """
    input_specs = []
    for tensor in model.inputs:
        input_specs.append(tf.TensorSpec(tensor.shape, tensor.dtype))

    traced = tf.function(lambda x: model(x)).get_concrete_function(input_specs)
    return cc.convert_variables_to_constants_v2(traced).graph
Exemple #29
0
    def testKerasModel(self):
        """Freeze a function wrapping a fitted single-Dense Keras model."""
        sample = constant_op.constant(1., shape=[1, 1])

        # Tiny training set for the one-unit model.
        features = [-1, 0, 1, 2, 3, 4]
        targets = [-3, -1, 1, 3, 5, 7]

        model = keras.models.Sequential(
            [keras.layers.Dense(units=1, input_shape=[1])])
        model.compile(optimizer="sgd", loss="mean_squared_error")
        model.fit(features, targets, epochs=1)

        # Wrap the Keras model so a concrete function can be traced from it.
        @def_function.function
        def to_save(x):
            return model(x)

        input_func = to_save.get_concrete_function(sample)

        # Two variables are expected before freezing.
        graph_def_before = input_func.graph.as_graph_def()
        self.assertEqual(2, self._getNumVariables(graph_def_before))

        # After freezing: no variables and no StatefulPartitionedCall op.
        frozen = convert_to_constants.convert_variables_to_constants_v2(
            input_func)
        graph_def_after = frozen.graph.as_graph_def()
        self.assertEqual(0, self._getNumVariables(graph_def_after))
        self.assertFalse(
            self._hasStatefulPartitionedCallOp(graph_def_after))

        # The frozen function must return the same value as the original.
        expected = to_save(sample)
        actual = nest.flatten(frozen(sample))
        self.assertEqual(expected.numpy(), actual)
Exemple #30
0
def export_model(ca: CAProtoModel, base_filename: str) -> None:
    """
    Write the model's weights to `base_filename` and a JSON description of the
    frozen graph to `base_filename.json`. The JSON document carries the model
    format, the topology (the frozen graph definition converted to a dict) and
    an empty weights manifest.
    :param ca: The model to export
    :param base_filename: Base file name/path at which to export it
    :return: None
    """
    ca.save_weights(base_filename)

    # TODO: do any of these need adjusting?
    trace_kwargs = {
        'x': tf.TensorSpec([None, None, None, ca.channel_n]),
        'fire_rate': tf.constant(0.5),
        'angle': tf.constant(0.0),
        'step_size': tf.constant(1.0),
    }
    concrete = ca.call.get_concrete_function(**trace_kwargs)
    frozen = convert_to_constants.convert_variables_to_constants_v2(concrete)

    topology = MessageToDict(frozen.graph.as_graph_def())
    topology['versions'] = dict(producer='1.14', minConsumer='1.14')
    document = {
        'format': 'graph-model',
        'modelTopology': topology,
        'weightsManifest': [],
    }
    with open(base_filename + '.json', 'w') as out:
        json.dump(document, out)
Exemple #31
0
def export_pb(keras_model, im, file, prefix=colorstr('TensorFlow GraphDef:')):
    """Export a Keras model as a frozen TensorFlow GraphDef ``*.pb`` file.

    YOLOv5 TensorFlow GraphDef *.pb export
    https://github.com/leimao/Frozen_Graph_TensorFlow

    Args:
        keras_model: Keras model to trace. The shape and dtype of its first
            input define the traced signature.
        im: Not used by this function.
        file: Object providing ``with_suffix``; the output is written to the
            same location with a ``.pb`` suffix.
        prefix: Text prepended to every log message.

    Any exception raised during the export is caught and logged rather than
    propagated to the caller.
    """
    try:
        import tensorflow as tf
        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

        LOGGER.info(
            f'\n{prefix} starting export with tensorflow {tf.__version__}...')
        f = file.with_suffix('.pb')

        m = tf.function(lambda x: keras_model(x))  # full model
        m = m.get_concrete_function(
            tf.TensorSpec(keras_model.inputs[0].shape,
                          keras_model.inputs[0].dtype))
        frozen_func = convert_variables_to_constants_v2(m)
        # The graph object is handed to write_graph directly, so there is no
        # need to serialize it with as_graph_def() and discard the result.
        tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                          logdir=str(f.parent),
                          name=f.name,
                          as_text=False)

        LOGGER.info(
            f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
    except Exception as e:
        # Deliberate best-effort: report the failure and carry on.
        LOGGER.info(f'\n{prefix} export failure: {e}')
    def save_frozen_graph_tf2(self):
        """Freeze ``self._model`` and write it to disk as a binary graph file.

        The model is traced using the shape and dtype of its first input, its
        variables are converted to constants, and the operation names, inputs
        and outputs of the frozen function are printed to stdout. The target
        folder and file name come from ``self.build_saved_model_folder()`` and
        ``self.build_frozen_graph_file_name()``.
        """
        from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

        # Convert Keras model to ConcreteFunction
        full_model = tf.function(self._model).get_concrete_function(
            tf.TensorSpec(self._model.inputs[0].shape,
                          self._model.inputs[0].dtype))

        # Get frozen ConcreteFunction. The graph object is passed to
        # write_graph below, so it is not serialized separately here.
        frozen_func = convert_variables_to_constants_v2(full_model)

        layers = [op.name for op in frozen_func.graph.get_operations()]
        print("-" * 50)
        print("Frozen model layers: ")
        for layer in layers:
            print(layer)

        print("-" * 50)
        print("Frozen model inputs: ")
        print(frozen_func.inputs)
        print("Frozen model outputs: ")
        print(frozen_func.outputs)

        # Save frozen graph from frozen ConcreteFunction to hard drive
        saved_model_folder = self.build_saved_model_folder()
        frozen_graph_file_name = self.build_frozen_graph_file_name()
        tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                          logdir=saved_model_folder,
                          name=frozen_graph_file_name,
                          as_text=False)
Exemple #33
0
    def testVariableSavedModel(self):
        """Freeze a variable-backed function after a SavedModel round trip."""
        feed = {"x": constant_op.constant(1., shape=[1])}
        root = tracking.AutoTrackable()
        root.v1 = variables.Variable(3.)
        root.v2 = variables.Variable(2.)
        root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
        traced = root.f.get_concrete_function(feed["x"])

        # Round-trip through disk and pick up the default serving signature.
        export_dir = os.path.join(self.get_temp_dir(), "saved_model")
        save(root, export_dir, traced)
        restored = load(export_dir)
        signature_func = restored.signatures["serving_default"]

        self.assertTrue(
            self._hasStatefulPartitionedCallOp(
                signature_func.graph.as_graph_def()))

        frozen = convert_to_constants.convert_variables_to_constants_v2(
            signature_func)
        frozen_graph_def = frozen.graph.as_graph_def()
        self.assertEqual(0, self._getNumVariables(frozen_graph_def))
        self.assertFalse(self._hasStatefulPartitionedCallOp(frozen_graph_def))

        self._testConvertedFunction(root, root.f, frozen, feed)
    def testConstSavedModel(self):
        """Check that library functions are inlined for a variable-free model."""
        sample = constant_op.constant(1., shape=[1])
        root = tracking.AutoTrackable()
        root.f = def_function.function(lambda x: 2. * x)
        traced = root.f.get_concrete_function(sample)

        # Round-trip through disk and pick up the default serving signature.
        export_dir = os.path.join(self.get_temp_dir(), "saved_model")
        save(root, export_dir, traced)
        restored = load(export_dir)
        signature_func = restored.signatures["serving_default"]

        # Before freezing: no variables, but a non-empty function library.
        loaded_graph_def = signature_func.graph.as_graph_def()
        self.assertEqual(0, self._getNumVariables(loaded_graph_def))
        self.assertTrue(loaded_graph_def.library.function)

        # After freezing: still no variables and the library is empty.
        frozen = convert_to_constants.convert_variables_to_constants_v2(
            signature_func)
        frozen_graph_def = frozen.graph.as_graph_def()
        self.assertEqual(0, self._getNumVariables(frozen_graph_def))
        self.assertFalse(frozen_graph_def.library.function)

        # The frozen function must reproduce the original output.
        expected = root.f(sample)
        actual = nest.flatten(frozen(sample))
        self.assertEqual(expected.numpy(), actual)
Exemple #35
0
  def _convert_saved_model_v2(self):
    """Convert the input SavedModel in 2.0 format.

    Loads the SavedModel, freezes the selected signature, builds the Grappler
    meta graph, calls `self._run_conversion()`, and finally wraps
    `self._converted_graph_def` in a new ConcreteFunction stored in
    `self._converted_func`.
    """
    self._saved_model = load.load(self._input_saved_model_dir,
                                  self._input_saved_model_tags)
    func = self._saved_model.signatures[self._input_saved_model_signature_key]
    frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)

    # Add a collection 'train_op' so that Grappler knows the outputs.
    # NOTE(review): the tensor names are taken from the original `func`, not
    # from `frozen_func` -- confirm they are the same in the frozen graph.
    fetch_collection = meta_graph_pb2.CollectionDef()
    for array in func.inputs + func.outputs:
      fetch_collection.node_list.value.append(array.name)
    self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
        fetch_collection)

    # Run TRT optimizer in Grappler to convert the graph.
    # Presumably this populates `self._converted_graph_def`, which is read
    # below -- verify against `_run_conversion`.
    self._run_conversion()

    def _get_tensor(graph, tensors):
      # Look up each tensor by name in `graph` and copy over its shape.
      new_tensors = []
      for tensor in tensors:
        new_tensor = graph.get_tensor_by_name(tensor.name)
        new_tensor.set_shape(tensor.shape)
        new_tensors.append(new_tensor)
      return new_tensors

    # TODO(laigd): do we need to use different name e.g. "trt_func_graph"?
    converted_graph = func_graph.FuncGraph(func.graph.name)
    with converted_graph.as_default():
      importer.import_graph_def(self._converted_graph_def, name="")

    # Mirror the inputs, outputs and structured signature of the original
    # function graph onto the imported graph.
    converted_graph.inputs = _get_tensor(converted_graph, func.graph.inputs)
    converted_graph.outputs = _get_tensor(converted_graph, func.graph.outputs)
    converted_graph.structured_outputs = func.graph.structured_outputs
    converted_graph.structured_input_signature = (
        func.graph.structured_input_signature)

    # pylint: disable=protected-access
    # TODO(laigd): should we set up the signature as well?
    self._converted_func = function.ConcreteFunction(
        converted_graph, attrs=None, signature=None)
    self._converted_func.add_to_graph()
    # Copy the calling convention and captures from the original function.
    self._converted_func._arg_keywords = func._arg_keywords
    self._converted_func._num_positional_args = func._num_positional_args
    self._converted_func._captured_inputs = func._captured_inputs
    self._converted_func.graph.variables = func.graph.variables
  def testMultiFunctionModel(self):
    """Test freezing one function of a model that defines several."""

    # Each decorated method creates its own variable on first use.
    class BasicModel(tracking.AutoTrackable):

      def __init__(self):
        self.y = None
        self.z = None

      @def_function.function
      def add(self, x):
        if self.y is None:
          self.y = variables.Variable(2.)
        return x + self.y

      @def_function.function
      def sub(self, x):
        if self.z is None:
          self.z = variables.Variable(3.)
        return x - self.z

    sample = constant_op.constant(1., shape=[1])
    model = BasicModel()
    traced_add = model.add.get_concrete_function(sample)

    # The graph traced from `add` is expected to hold a single variable.
    self.assertEqual(
        1, self._getNumVariables(traced_add.graph.as_graph_def()))

    frozen = convert_to_constants.convert_variables_to_constants_v2(
        traced_add)
    frozen_graph_def = frozen.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(frozen_graph_def))
    self.assertFalse(self._hasStatefulPartitionedCallOp(frozen_graph_def))

    # The frozen graph must reproduce the original output.
    expected = model.add(sample)
    actual = self._evaluateGraphDef(frozen_graph_def, traced_add,
                                    [sample.numpy()])
    self.assertEqual(expected.numpy(), actual)
  def testVariableModel(self):
    """Freeze a function that multiplies its input by two variables."""
    sample = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.v1 = variables.Variable(3.)
    root.v2 = variables.Variable(2.)
    root.f = def_function.function(lambda x: root.v1 * root.v2 * x)
    traced = root.f.get_concrete_function(sample)

    # The traced graph is expected to hold two variables before freezing.
    self.assertEqual(2, self._getNumVariables(traced.graph.as_graph_def()))

    frozen = convert_to_constants.convert_variables_to_constants_v2(traced)
    frozen_graph_def = frozen.graph.as_graph_def()
    self.assertEqual(0, self._getNumVariables(frozen_graph_def))
    self.assertFalse(self._hasStatefulPartitionedCallOp(frozen_graph_def))

    # The frozen graph must reproduce the original output.
    expected = root.f(sample)
    actual = self._evaluateGraphDef(frozen_graph_def, traced,
                                    [sample.numpy()])
    self.assertEqual(expected.numpy(), actual)
Exemple #38
0
  def convert(self):
    """Run the conversion on the input SavedModel (2.0 format).

    The selected signature is frozen, handed to `self._run_conversion` as a
    Grappler meta graph, and the converted graph is wrapped back into a
    callable function.

    Returns:
      The TF-TRT converted Function.
    """
    assert not self._converted
    self._saved_model = load.load(self._input_saved_model_dir,
                                  self._input_saved_model_tags)
    signature = self._saved_model.signatures[
        self._input_saved_model_signature_key]
    frozen = convert_to_constants.convert_variables_to_constants_v2(signature)
    meta_graph_def = saver.export_meta_graph(
        graph_def=frozen.graph.as_graph_def(), graph=frozen.graph)

    # Grappler reads the tensors to preserve from the 'train_op' collection.
    fetches = meta_graph_pb2.CollectionDef()
    fetches.node_list.value.extend(
        [tensor.name for tensor in frozen.inputs + frozen.outputs])
    meta_graph_def.collection_def["train_op"].CopyFrom(fetches)

    # Run TRT optimizer in Grappler to convert the graph.
    self._converted_graph_def = self._run_conversion(meta_graph_def)
    input_names = [tensor.name for tensor in frozen.inputs]
    output_names = [tensor.name for tensor in frozen.outputs]
    self._converted_func = wrap_function.function_from_graph_def(
        self._converted_graph_def, input_names, output_names)

    self._converted = True

    # The ConcreteFunction is wrapped in a Function so that numpy arrays are
    # accepted as input.
    @def_function.function
    def wrapper_func(*args, **kwargs):
      return self._converted_func(*args, **kwargs)

    return wrapper_func
  def _convert_saved_model_v2(self):
    """Freeze and convert the input SavedModel (2.0 format).

    Requires eager execution. The Grappler meta graph is stored in
    `self._grappler_meta_graph_def` and the resulting function in
    `self._converted_func`.
    """
    assert context.executing_eagerly()

    self._saved_model = load.load(self._input_saved_model_dir,
                                  self._input_saved_model_tags)
    signature = self._saved_model.signatures[
        self._input_saved_model_signature_key]
    frozen = convert_to_constants.convert_variables_to_constants_v2(signature)
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=frozen.graph.as_graph_def(), graph=frozen.graph)

    # Grappler reads the tensors to preserve from the 'train_op' collection.
    fetches = meta_graph_pb2.CollectionDef()
    fetches.node_list.value.extend(
        [tensor.name for tensor in frozen.inputs + frozen.outputs])
    self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(fetches)

    # Run TRT optimizer in Grappler to convert the graph.
    self._run_conversion()
    input_names = [tensor.name for tensor in frozen.inputs]
    output_names = [tensor.name for tensor in frozen.outputs]
    self._converted_func = wrap_function.function_from_graph_def(
        self._converted_graph_def, input_names, output_names)
Exemple #40
0
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Freezes `self._func`, runs a Grappler pass over the frozen graph, fixes an
    undefined leading input dimension to 1, converts the graph, and optionally
    calibrates and quantizes the result using `self.representative_dataset`.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        Input shape is not specified.
        None value for dimension in input_tensor.
        `representative_dataset` has no input generator.
      TypeError:
        `representative_dataset` is not a `RepresentativeDataset`.
    """
    frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
        self._func)
    # Resource-typed inputs are excluded from the converter's inputs.
    input_tensors = [
        tensor for tensor in frozen_func.inputs
        if tensor.dtype != _dtypes.resource
    ]
    output_tensors = frozen_func.outputs

    # Run a Grappler pass.
    # The layout optimizer is enabled only when SELECT_TF_OPS is the sole
    # supported op set.
    is_only_flex_enabled = set(
        [OpsSet.SELECT_TF_OPS]) == self.target_spec.supported_ops
    config = _get_grappler_config(enable_layout_optimizer=is_only_flex_enabled)
    graph_def = _run_graph_optimizations(
        frozen_func.graph.as_graph_def(),
        input_tensors,
        output_tensors,
        config,
        graph=frozen_func.graph)

    # Checks dimensions in input tensor.
    # Note: this mutates the input tensors in place after the Grappler pass
    # and before they are handed to the converter below.
    for tensor in input_tensors:
      # Note that shape_list might be empty for scalar shapes.
      shape_list = tensor.shape.as_list()
      if None in shape_list[1:]:
        raise ValueError(
            "None is only supported in the 1st dimension. Tensor '{0}' has "
            "invalid shape '{1}'.".format(_get_tensor_name(tensor), shape_list))
      elif shape_list and shape_list[0] is None:
        # Set the batch size to 1 if undefined.
        shape = tensor.shape.as_list()
        shape[0] = 1
        tensor.set_shape(shape)

    # Validate the representative dataset only when one was supplied.
    if self.representative_dataset:
      if not isinstance(self.representative_dataset, RepresentativeDataset):
        raise TypeError("`representative_dataset` must be an instance of "
                        "`RepresentativeDataset`")
      if self.representative_dataset.input_gen is None:
        raise ValueError(
            "Provide an input generator for `representative_dataset`")

    # TODO(shashishekhar): For now use optimizations order is ignored.
    # Both size and latency optimizations decide whether to apply post
    # training optimizations.
    post_training_optimize = bool(
        len(
            set(self.optimizations)
            & set([Optimize.OPTIMIZE_FOR_LATENCY, Optimize.OPTIMIZE_FOR_SIZE])))
    # Do weights only quantization if there is no dataset for calibration.
    weights_only_quantize_flag = (
        post_training_optimize and (self.representative_dataset is None))

    converter_kwargs = {
        "input_format": constants.TENSORFLOW_GRAPHDEF,
        "allow_custom_ops": self.allow_custom_ops,
        "post_training_quantize": weights_only_quantize_flag,
        "target_ops": self.target_spec.supported_ops,
    }

    # Converts model.
    result = _toco_convert_impl(
        input_data=graph_def,
        input_tensors=input_tensors,
        output_tensors=output_tensors,
        **converter_kwargs)

    # With a dataset and an optimization requested, calibrate and quantize the
    # converted result using the dataset's input generator.
    if self.representative_dataset and post_training_optimize:
      calibrate_quantize = _calibrator.Calibrator(result)
      result = calibrate_quantize.calibrate_and_quantize(
          self.representative_dataset.input_gen)

    return result
Exemple #41
0
  def convert(self):
    """Converts a TensorFlow GraphDef based on instance variables.

    Freezes the single function in `self._funcs`, runs a Grappler pass over
    the frozen graph, fixes an undefined leading input dimension to 1,
    converts the graph, and optionally calibrates and quantizes the result.

    Returns:
      The converted data in serialized format.

    Raises:
      ValueError:
        Multiple concrete functions are specified.
        Input shape is not specified.
        Invalid quantization parameters.
    """
    # TODO(b/130297984): Add support for converting multiple function.
    self._target_ops = self.target_spec.supported_ops
    if len(self._funcs) != 1:
      raise ValueError("This converter can only convert a single "
                       "ConcreteFunction. Converting multiple functions is "
                       "under development.")

    frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
        self._funcs[0])
    # Resource-typed inputs are excluded from the converter's inputs.
    input_tensors = [
        tensor for tensor in frozen_func.inputs
        if tensor.dtype != _dtypes.resource
    ]
    output_tensors = frozen_func.outputs

    # Run a Grappler pass.
    graph_def = frozen_func.graph.as_graph_def()
    graph_def = _run_graph_optimizations(
        graph_def,
        input_tensors,
        output_tensors,
        config=self._grappler_config(),
        graph=frozen_func.graph)

    # Checks dimensions in input tensor.
    # Note: this mutates the input tensors in place after the Grappler pass
    # and before they are handed to the converter below.
    for tensor in input_tensors:
      # Note that shape_list might be empty for scalar shapes.
      shape_list = tensor.shape.as_list()
      if None in shape_list[1:]:
        raise ValueError(
            "None is only supported in the 1st dimension. Tensor '{0}' has "
            "invalid shape '{1}'.".format(_get_tensor_name(tensor), shape_list))
      elif shape_list and shape_list[0] is None:
        # Set the batch size to 1 if undefined.
        shape = tensor.shape.as_list()
        shape[0] = 1
        tensor.set_shape(shape)

    self._validate_representative_dataset()

    converter_kwargs = {
        "input_format": constants.TENSORFLOW_GRAPHDEF,
        "allow_custom_ops": self.allow_custom_ops,
        "post_training_quantize": self._is_weight_only_quantize(),
        "target_ops": self.target_spec.supported_ops,
    }

    # Converts model.
    result = _toco_convert_impl(
        input_data=graph_def,
        input_tensors=input_tensors,
        output_tensors=output_tensors,
        **converter_kwargs)

    # Calibration-based quantization is applied to the converted result, with
    # float passed for both type arguments.
    if self._is_calibration_quantize():
      result = self._calibrate_quantize_model(result, constants.FLOAT,
                                              constants.FLOAT)

    return result
Exemple #42
0
  def from_keras_model_file(cls,
                            model_file,
                            input_arrays=None,
                            input_shapes=None,
                            output_arrays=None,
                            custom_objects=None):
    """Builds a TFLiteConverter from a tf.keras model file.

    With eager execution enabled the model is traced and frozen into a
    function graph; otherwise the Keras session graph is frozen.

    Args:
      model_file: Full filepath of HDF5 file containing the tf.keras model.
      input_arrays: List of input tensors to freeze graph with. Uses input
        arrays from SignatureDef when none are provided. (default None)
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None
        (e.g., {"foo" : None}). (default None)
      output_arrays: List of output tensors to freeze graph with. Uses output
        arrays from SignatureDef when none are provided. (default None)
      custom_objects: Dict mapping names (strings) to custom classes or
        functions to be considered during model deserialization. (default None)

    Returns:
      TFLiteConverter class.

    Raises:
      ValueError: `input_arrays` or `output_arrays` is given while eager
        execution is enabled.
    """
    if context.executing_eagerly():
      # Eager path: explicit tensor names cannot be honoured here.
      if input_arrays or output_arrays:
        raise ValueError("`input_arrays` and `output_arrays` are unsupported "
                         "with Eager mode. If your model requires any of these "
                         "parameters, please use disable_eager_execution().")

      _keras.backend.set_learning_phase(False)
      keras_model = _keras.models.load_model(model_file, custom_objects)

      traced = _saving_utils.trace_model_call(keras_model)
      frozen = _convert_to_constants.convert_variables_to_constants_v2(
          traced.get_concrete_function())
      _set_tensor_shapes(frozen.inputs, input_shapes)
      return cls(frozen.graph.as_graph_def(), frozen.inputs, frozen.outputs)

    # Graph path: load the model into a fresh Keras session.
    _keras.backend.clear_session()
    _keras.backend.set_learning_phase(False)
    keras_model = _keras.models.load_model(model_file, custom_objects)
    sess = _keras.backend.get_session()

    # Explicit tensor names take precedence over the model's own
    # inputs/outputs.
    input_tensors = (
        _get_tensors_from_tensor_names(sess.graph, input_arrays)
        if input_arrays else keras_model.inputs)
    output_tensors = (
        _get_tensors_from_tensor_names(sess.graph, output_arrays)
        if output_arrays else keras_model.outputs)
    _set_tensor_shapes(input_tensors, input_shapes)

    frozen_graph_def = _freeze_graph(sess, input_tensors, output_tensors)
    return cls(frozen_graph_def, input_tensors, output_tensors)