コード例 #1
0
    def from_keras_model_file(cls,
                              model_file,
                              input_arrays=None,
                              input_shapes=None,
                              output_arrays=None,
                              custom_objects=None):
        """Builds a converter from a tf.keras model saved to a file.

        The Keras backend session is cleared and the learning phase is set to
        False before the model is loaded. The graph of the backend session is
        then frozen and handed to the `cls` constructor.

        Args:
          model_file: Full filepath of HDF5 file containing the tf.keras
            model. Passed to `_keras.models.load_model`.
          input_arrays: List of names of the tensors to use as inputs. When
            empty or None, the loaded model's `inputs` are used.
            (default None)
          input_shapes: Dict of strings representing input tensor names to
            list of integers representing input shapes
            (e.g., {"foo" : [1, 16, 16, 3]}). Forwarded unchanged to
            `_set_tensor_shapes`. (default None)
          output_arrays: List of names of the tensors to use as outputs. When
            empty or None, the loaded model's `outputs` are used.
            (default None)
          custom_objects: Dict mapping names (strings) to custom classes or
            functions to be considered during model deserialization.
            (default None)

        Returns:
          The object created by
          `cls(graph_def, input_tensors, output_tensors)`.
        """
        backend = _keras.backend
        backend.clear_session()
        backend.set_learning_phase(False)
        model = _keras.models.load_model(model_file, custom_objects)
        session = backend.get_session()

        # Explicitly named tensors take precedence; otherwise fall back to the
        # inputs/outputs recorded on the loaded model.
        input_tensors = (
            _get_tensors_from_tensor_names(session.graph, input_arrays)
            if input_arrays else model.inputs)
        output_tensors = (
            _get_tensors_from_tensor_names(session.graph, output_arrays)
            if output_arrays else model.outputs)
        _set_tensor_shapes(input_tensors, input_shapes)

        return cls(_freeze_graph(session, input_tensors, output_tensors),
                   input_tensors, output_tensors)
コード例 #2
0
ファイル: lite.py プロジェクト: Albert-Z-Guo/tensorflow
  def from_keras_model_file(cls,
                            model_file,
                            input_arrays=None,
                            input_shapes=None,
                            output_arrays=None,
                            custom_objects=None):
    """Creates a converter from a tf.keras model file.

    Resets the Keras backend session, switches the learning phase off, loads
    the model, and freezes the backend session's graph for conversion.

    Args:
      model_file: Full filepath of HDF5 file containing the tf.keras model.
      input_arrays: List of input tensor names to freeze the graph with. The
        model's own `inputs` are used when this is empty or None.
        (default None)
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Handed as-is to `_set_tensor_shapes`. (default None)
      output_arrays: List of output tensor names to freeze the graph with. The
        model's own `outputs` are used when this is empty or None.
        (default None)
      custom_objects: Dict mapping names (strings) to custom classes or
        functions to be considered during model deserialization. (default None)

    Returns:
      The result of calling `cls` with the frozen graph definition and the
      selected input and output tensors.
    """
    _keras.backend.clear_session()
    _keras.backend.set_learning_phase(False)
    loaded_model = _keras.models.load_model(model_file, custom_objects)
    session = _keras.backend.get_session()

    # Resolve inputs: default to the model's inputs unless names were given.
    if not input_arrays:
      inputs = loaded_model.inputs
    else:
      inputs = _get_tensors_from_tensor_names(session.graph, input_arrays)

    # Resolve outputs the same way.
    if not output_arrays:
      outputs = loaded_model.outputs
    else:
      outputs = _get_tensors_from_tensor_names(session.graph, output_arrays)

    _set_tensor_shapes(inputs, input_shapes)

    frozen_graph_def = _freeze_graph(session, inputs, outputs)
    return cls(frozen_graph_def, inputs, outputs)
コード例 #3
0
    def from_keras_model_file(cls,
                              model_file,
                              input_arrays=None,
                              input_shapes=None,
                              output_arrays=None,
                              custom_objects=None):
        """Builds a converter from a tf.keras model saved to a file.

        Two paths exist, chosen by `context.executing_eagerly()`:

        * Eager: the model call is traced into a concrete function whose
          variables are converted to constants. `input_arrays` and
          `output_arrays` are rejected on this path.
        * Graph: the Keras backend session is cleared, the model is loaded
          into it, and the session graph is frozen.

        Args:
          model_file: Full filepath of HDF5 file containing the tf.keras
            model.
          input_arrays: List of input tensor names to freeze the graph with.
            In graph mode the model's `inputs` are used when this is empty
            or None. (default None)
          input_shapes: Dict of strings representing input tensor names to
            list of integers representing input shapes
            (e.g., {"foo" : [1, 16, 16, 3]}). Forwarded unchanged to
            `_set_tensor_shapes`. (default None)
          output_arrays: List of output tensor names to freeze the graph
            with. In graph mode the model's `outputs` are used when this is
            empty or None. (default None)
          custom_objects: Dict mapping names (strings) to custom classes or
            functions to be considered during model deserialization.
            (default None)

        Returns:
          The object created by calling `cls` with the graph definition, the
          input and output tensors, and a debug-info function built from the
          corresponding graph.

        Raises:
          ValueError: `input_arrays` or `output_arrays` is given while
            executing eagerly.
        """
        if context.executing_eagerly():
            # Selecting tensors by name is rejected on the eager path.
            if input_arrays or output_arrays:
                raise ValueError(
                    "`input_arrays` and `output_arrays` are unsupported "
                    "with Eager mode. If your model requires any of these "
                    "parameters, please use disable_eager_execution().")

            _keras.backend.set_learning_phase(False)
            model = _keras.models.load_model(model_file, custom_objects)

            concrete_func = _saving_utils.trace_model_call(
                model).get_concrete_function()
            frozen = _convert_to_constants.convert_variables_to_constants_v2(
                concrete_func, lower_control_flow=False)
            _set_tensor_shapes(frozen.inputs, input_shapes)
            return cls(
                frozen.graph.as_graph_def(),
                frozen.inputs,
                frozen.outputs,
                experimental_debug_info_func=_build_debug_info_func(
                    frozen.graph))

        # Graph-mode path: load the model into a fresh backend session.
        backend = _keras.backend
        backend.clear_session()
        backend.set_learning_phase(False)
        model = _keras.models.load_model(model_file, custom_objects)
        session = backend.get_session()

        # Explicitly named tensors take precedence over the model's own.
        input_tensors = (
            _get_tensors_from_tensor_names(session.graph, input_arrays)
            if input_arrays else model.inputs)
        output_tensors = (
            _get_tensors_from_tensor_names(session.graph, output_arrays)
            if output_arrays else model.outputs)
        _set_tensor_shapes(input_tensors, input_shapes)

        frozen_graph_def = _freeze_graph(session, input_tensors,
                                         output_tensors)
        return cls(
            frozen_graph_def,
            input_tensors,
            output_tensors,
            experimental_debug_info_func=_build_debug_info_func(
                session.graph))
コード例 #4
0
    def from_frozen_graph(cls,
                          graph_def_file,
                          input_arrays,
                          output_arrays,
                          input_shapes=None):
        """Creates a converter from a file containing a frozen GraphDef.

        The file is first parsed as a binary GraphDef proto. If that fails,
        its content is parsed again as a text-format proto.

        Args:
          graph_def_file: Full filepath of file containing frozen GraphDef.
          input_arrays: List of input tensors to freeze graph with.
          output_arrays: List of output tensors to freeze graph with.
          input_shapes: Dict of strings representing input tensor names to
            list of integers representing input shapes
            (e.g., {"foo" : [1, 16, 16, 3]}). Required, with one entry per
            name in `input_arrays`, when the graph cannot be imported into
            the session. (default None)

        Returns:
          The object created by calling `cls`, either with the session's
          graph definition and resolved tensors, or (when the graph cannot
          be imported) with the parsed GraphDef plus array names and shapes.

        Raises:
          IOError:
            File not found.
            Unable to parse input file (neither a binary nor a text proto,
            including content that is not valid UTF-8 text).
          ValueError:
            The graph is not frozen.
            input_arrays or output_arrays contains an invalid tensor name
            (presumably raised by `_get_tensors_from_tensor_names`).
            input_shapes is not correctly defined when required
        """
        with _ops.Graph().as_default():
            with _session.Session() as sess:
                # Read GraphDef from file.
                if not _file_io.file_exists(graph_def_file):
                    raise IOError(
                        "File '{0}' does not exist.".format(graph_def_file))
                with _file_io.FileIO(graph_def_file, "rb") as f:
                    file_content = f.read()

                try:
                    graph_def = _graph_pb2.GraphDef()
                    graph_def.ParseFromString(file_content)
                except (_text_format.ParseError, DecodeError):
                    # Not a binary proto; retry as a text-format proto.
                    try:
                        print("Ignore 'tcmalloc: large alloc' warnings.")

                        if not isinstance(file_content, str):
                            if PY3:
                                file_content = six.ensure_text(
                                    file_content, "utf-8")
                            else:
                                file_content = six.ensure_binary(
                                    file_content, "utf-8")
                        graph_def = _graph_pb2.GraphDef()
                        _text_format.Merge(file_content, graph_def)
                    # UnicodeDecodeError covers binary content that is not
                    # valid UTF-8: without it the decoding step above would
                    # leak a codec error instead of the documented IOError.
                    except (_text_format.ParseError, DecodeError,
                            UnicodeDecodeError):
                        raise IOError(
                            "Unable to parse input file '{}'.".format(
                                graph_def_file))

                # Handles models with custom TFLite ops that cannot be resolved in
                # TensorFlow.
                load_model_in_session = True
                try:
                    _import_graph_def(graph_def, name="")
                except _NotFoundError:
                    load_model_in_session = False

                if load_model_in_session:
                    # Check if graph is frozen.
                    if not _is_frozen_graph(sess):
                        raise ValueError(
                            "Please freeze the graph using freeze_graph.py.")

                    # Get input and output tensors.
                    input_tensors = _get_tensors_from_tensor_names(
                        sess.graph, input_arrays)
                    output_tensors = _get_tensors_from_tensor_names(
                        sess.graph, output_arrays)
                    _set_tensor_shapes(input_tensors, input_shapes)

                    return cls(sess.graph_def, input_tensors, output_tensors)
                else:
                    # The graph could not be imported, so no tensors are
                    # available; the converter is built from array names and
                    # caller-supplied shapes instead.
                    if not input_shapes:
                        raise ValueError(
                            "input_shapes must be defined for this model.")
                    if set(input_arrays) != set(input_shapes.keys()):
                        raise ValueError(
                            "input_shapes must contain a value for each item "
                            "in input_array.")

                    input_arrays_with_shape = [(name, input_shapes[name])
                                               for name in input_arrays]
                    return cls(graph_def,
                               input_tensors=None,
                               output_tensors=None,
                               input_arrays_with_shape=input_arrays_with_shape,
                               output_arrays=output_arrays)
コード例 #5
0
ファイル: lite.py プロジェクト: aritratony/tensorflow
  def from_keras_model_file(cls,
                            model_file,
                            input_arrays=None,
                            input_shapes=None,
                            output_arrays=None,
                            custom_objects=None):
    """Creates a converter from a tf.keras model file.

    When `context.executing_eagerly()` is true, the model call is traced into
    a concrete function and its variables are converted to constants.
    Otherwise the Keras backend session is cleared, the model is loaded into
    it, and the session graph is frozen.

    Args:
      model_file: Full filepath of HDF5 file containing the tf.keras model.
      input_arrays: List of input tensor names to freeze the graph with. Not
        accepted in eager mode; in graph mode the model's `inputs` are used
        when this is empty or None. (default None)
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Handed as-is to `_set_tensor_shapes`. (default None)
      output_arrays: List of output tensor names to freeze the graph with. Not
        accepted in eager mode; in graph mode the model's `outputs` are used
        when this is empty or None. (default None)
      custom_objects: Dict mapping names (strings) to custom classes or
        functions to be considered during model deserialization. (default None)

    Returns:
      The result of calling `cls` with the graph definition and the input and
      output tensors.

    Raises:
      ValueError: `input_arrays` or `output_arrays` is given while executing
        eagerly.
    """
    if context.executing_eagerly():
      # Selecting tensors by name is rejected on the eager path.
      if input_arrays or output_arrays:
        raise ValueError("`input_arrays` and `output_arrays` are unsupported "
                         "with Eager mode. If your model requires any of these "
                         "parameters, please use disable_eager_execution().")

      _keras.backend.set_learning_phase(False)
      loaded_model = _keras.models.load_model(model_file, custom_objects)

      traced = _saving_utils.trace_model_call(loaded_model)
      frozen = _convert_to_constants.convert_variables_to_constants_v2(
          traced.get_concrete_function())
      _set_tensor_shapes(frozen.inputs, input_shapes)
      return cls(frozen.graph.as_graph_def(), frozen.inputs, frozen.outputs)

    # Graph-mode path: load the model into a fresh backend session.
    _keras.backend.clear_session()
    _keras.backend.set_learning_phase(False)
    loaded_model = _keras.models.load_model(model_file, custom_objects)
    session = _keras.backend.get_session()

    # Resolve inputs: default to the model's inputs unless names were given.
    if not input_arrays:
      inputs = loaded_model.inputs
    else:
      inputs = _get_tensors_from_tensor_names(session.graph, input_arrays)

    # Resolve outputs the same way.
    if not output_arrays:
      outputs = loaded_model.outputs
    else:
      outputs = _get_tensors_from_tensor_names(session.graph, output_arrays)

    _set_tensor_shapes(inputs, input_shapes)

    frozen_graph_def = _freeze_graph(session, inputs, outputs)
    return cls(frozen_graph_def, inputs, outputs)
コード例 #6
0
ファイル: lite.py プロジェクト: aritratony/tensorflow
  def from_frozen_graph(cls,
                        graph_def_file,
                        input_arrays,
                        output_arrays,
                        input_shapes=None):
    """Creates a converter from a file containing a frozen GraphDef.

    The file is first parsed as a binary GraphDef proto. If that fails, its
    content is parsed again as a text-format proto.

    Args:
      graph_def_file: Full filepath of file containing frozen GraphDef.
      input_arrays: List of input tensors to freeze graph with.
      output_arrays: List of output tensors to freeze graph with.
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Required, with one entry per name in `input_arrays`, when the graph
        cannot be imported into the session. (default None)

    Returns:
      The object created by calling `cls`, either with the session's graph
      definition and resolved tensors, or (when the graph cannot be imported)
      with the parsed GraphDef plus array names and shapes.

    Raises:
      IOError:
        File not found.
        Unable to parse input file (neither a binary nor a text proto,
        including content that is not valid UTF-8 text).
      ValueError:
        The graph is not frozen.
        input_arrays or output_arrays contains an invalid tensor name
        (presumably raised by `_get_tensors_from_tensor_names`).
        input_shapes is not correctly defined when required
    """
    with _ops.Graph().as_default():
      with _session.Session() as sess:
        # Read GraphDef from file.
        if not _file_io.file_exists(graph_def_file):
          raise IOError("File '{0}' does not exist.".format(graph_def_file))
        with _file_io.FileIO(graph_def_file, "rb") as f:
          file_content = f.read()

        try:
          graph_def = _graph_pb2.GraphDef()
          graph_def.ParseFromString(file_content)
        except (_text_format.ParseError, DecodeError):
          # Not a binary proto; retry as a text-format proto.
          try:
            print("Ignore 'tcmalloc: large alloc' warnings.")

            if not isinstance(file_content, str):
              if PY3:
                file_content = file_content.decode("utf-8")
              else:
                file_content = file_content.encode("utf-8")
            graph_def = _graph_pb2.GraphDef()
            _text_format.Merge(file_content, graph_def)
          # UnicodeDecodeError covers binary content that is not valid UTF-8:
          # without it the decoding step above would leak a codec error
          # instead of the documented IOError.
          except (_text_format.ParseError, DecodeError, UnicodeDecodeError):
            raise IOError(
                "Unable to parse input file '{}'.".format(graph_def_file))

        # Handles models with custom TFLite ops that cannot be resolved in
        # TensorFlow.
        load_model_in_session = True
        try:
          _import_graph_def(graph_def, name="")
        except _NotFoundError:
          load_model_in_session = False

        if load_model_in_session:
          # Check if graph is frozen.
          if not _is_frozen_graph(sess):
            raise ValueError("Please freeze the graph using freeze_graph.py.")

          # Get input and output tensors.
          input_tensors = _get_tensors_from_tensor_names(
              sess.graph, input_arrays)
          output_tensors = _get_tensors_from_tensor_names(
              sess.graph, output_arrays)
          _set_tensor_shapes(input_tensors, input_shapes)

          return cls(sess.graph_def, input_tensors, output_tensors)
        else:
          # The graph could not be imported, so no tensors are available; the
          # converter is built from array names and caller-supplied shapes
          # instead.
          if not input_shapes:
            raise ValueError("input_shapes must be defined for this model.")
          if set(input_arrays) != set(input_shapes.keys()):
            raise ValueError("input_shapes must contain a value for each item "
                             "in input_array.")

          input_arrays_with_shape = [
              (name, input_shapes[name]) for name in input_arrays
          ]
          return cls(
              graph_def,
              input_tensors=None,
              output_tensors=None,
              input_arrays_with_shape=input_arrays_with_shape,
              output_arrays=output_arrays)