コード例 #1
0
ファイル: lite_test.py プロジェクト: aeverall/tensorflow
  def testFreezeGraph(self):
    """Converts a graph containing a variable and checks the model's I/O.

    A float32 placeholder is added to a float32 variable of the same shape.
    The converted model is expected to expose exactly one input
    ('Placeholder') and one output ('add'), both float32 with shape
    [1, 16, 16, 3] and quantization parameters (0., 0.).
    """
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    var = variable_scope.get_variable(
        'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
    out_tensor = in_tensor + var

    # The session is only needed until the flatbuffer has been produced; the
    # `with` block closes it instead of leaking it for the rest of the run.
    with session.Session() as sess:
      # Give the variable a value so it can be captured when the graph is
      # frozen from this session.
      sess.run(_global_variables_initializer())

      # Convert model and ensure model is not None.
      converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                    [out_tensor])
      tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    # Shapes are compared as Python lists so a rank mismatch fails the
    # assertion with a readable diff instead of depending on broadcasting.
    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertEqual([1, 16, 16, 3], input_details[0]['shape'].tolist())
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertEqual([1, 16, 16, 3], output_details[0]['shape'].tolist())
    self.assertEqual((0., 0.), output_details[0]['quantization'])
コード例 #2
0
    def testFreezeGraph(self):
        """Converts a graph containing a variable and checks the model's I/O.

        A float32 placeholder is added to a float32 variable of the same
        shape. The converted model is expected to expose exactly one input
        ('Placeholder') and one output ('add'), both float32 with shape
        [1, 16, 16, 3] and quantization parameters (0., 0.).
        """
        in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
                                          dtype=dtypes.float32)
        var = variable_scope.get_variable('weights',
                                          shape=[1, 16, 16, 3],
                                          dtype=dtypes.float32)
        out_tensor = in_tensor + var

        # The session is only needed until the flatbuffer has been produced;
        # the `with` block closes it instead of leaking it for the whole run.
        with session.Session() as sess:
            # Give the variable a value so it can be captured when the graph
            # is frozen from this session.
            sess.run(_global_variables_initializer())

            # Convert model and ensure model is not None.
            converter = lite.TocoConverter.from_session(sess, [in_tensor],
                                                        [out_tensor])
            tflite_model = converter.convert()
        self.assertTrue(tflite_model)

        # Check values from converted model.
        interpreter = Interpreter(model_content=tflite_model)
        interpreter.allocate_tensors()

        # Shapes are compared as Python lists so a rank mismatch fails the
        # assertion with a readable diff instead of depending on broadcasting.
        input_details = interpreter.get_input_details()
        self.assertEqual(1, len(input_details))
        self.assertEqual('Placeholder', input_details[0]['name'])
        self.assertEqual(np.float32, input_details[0]['dtype'])
        self.assertEqual([1, 16, 16, 3], input_details[0]['shape'].tolist())
        self.assertEqual((0., 0.), input_details[0]['quantization'])

        output_details = interpreter.get_output_details()
        self.assertEqual(1, len(output_details))
        self.assertEqual('add', output_details[0]['name'])
        self.assertEqual(np.float32, output_details[0]['dtype'])
        self.assertEqual([1, 16, 16, 3], output_details[0]['shape'].tolist())
        self.assertEqual((0., 0.), output_details[0]['quantization'])
コード例 #3
0
ファイル: lite.py プロジェクト: zyy2020/tensorflow
def _freeze_graph(sess, output_tensors):
  """Returns a frozen GraphDef.

  Freezes a graph with Variables in it by replacing the variables with
  constants that hold their current values in `sess`. Otherwise (the graph is
  already frozen) the existing GraphDef is returned.

  Variables are deliberately NOT (re)initialized here: the values to freeze
  (e.g. trained or restored weights) must already live in `sess`. Running the
  global initializer at this point would silently overwrite them with their
  initial values. Callers must initialize or restore variables beforehand.

  Args:
    sess: TensorFlow Session whose variables already hold the values to
      freeze.
    output_tensors: List of output tensors (only .name is used from this).

  Returns:
    Frozen GraphDef.
  """
  if _is_frozen_graph(sess):
    return sess.graph_def

  output_arrays = [_tensor_name(tensor) for tensor in output_tensors]
  return _tf_graph_util.convert_variables_to_constants(
      sess, sess.graph_def, output_arrays)
コード例 #4
0
ファイル: lite.py プロジェクト: zyy2020/tensorflow
  def from_frozen_graph(cls,
                        graph_def_file,
                        input_arrays,
                        output_arrays,
                        input_shapes=None):
    """Creates a TocoConverter class from a file containing a frozen GraphDef.

    The file may hold either a binary-serialized or a text-format GraphDef;
    binary parsing is attempted first. The GraphDef is imported into a new,
    private graph, so nodes already present in the caller's default graph
    (or left there by an earlier call) cannot collide with or leak into the
    converted model.

    Args:
      graph_def_file: Full filepath of file containing frozen GraphDef.
      input_arrays: List of input tensors to freeze graph with.
      output_arrays: List of output tensors to freeze graph with.
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
        Automatically determined when input shapes is None (e.g., {"foo" :
        None}). (default None)

    Returns:
      TocoConverter class.

    Raises:
      IOError: graph_def_file cannot be opened.
      ValueError:
        Unable to parse input file.
        The graph is not frozen.
        input_arrays or output_arrays contains an invalid tensor name.
    """
    # Local import so this method does not depend on a module-level alias.
    from tensorflow.python.framework import ops as _ops  # pylint: disable=g-import-not-at-top

    # Read GraphDef from file.
    with open(graph_def_file, "rb") as f:
      file_content = f.read()

    graph_def = _graph_pb2.GraphDef()
    try:
      graph_def.ParseFromString(file_content)
    except (_text_format.ParseError, DecodeError):
      # Not a binary GraphDef; retry as text format.
      try:
        print("Ignore 'tcmalloc: large alloc' warnings.")

        if not isinstance(file_content, str):
          if PY3:
            file_content = file_content.decode('utf-8')
          else:
            file_content = file_content.encode('utf-8')
        # Start from an empty message: Merge() adds to existing fields, and
        # the failed binary parse may have left partial data behind.
        graph_def = _graph_pb2.GraphDef()
        _text_format.Merge(file_content, graph_def)
      except (_text_format.ParseError, DecodeError, UnicodeDecodeError):
        # UnicodeDecodeError: the bytes are neither a valid binary GraphDef
        # nor UTF-8 text, which is also a parse failure.
        raise ValueError(
            "Unable to parse input file '{}'.".format(graph_def_file))

    # Import into a dedicated graph rather than the process-wide default
    # graph; the session is bound to that same graph.
    graph = _ops.Graph()
    with graph.as_default(), _session.Session(graph=graph) as sess:
      _import_graph_def(graph_def, name="")

      # Get input and output tensors.
      input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
      output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays)
      _set_tensor_shapes(input_tensors, input_shapes)

      # Check if graph is frozen.
      if not _is_frozen_graph(sess):
        raise ValueError("Please freeze the graph using freeze_graph.py.")

      # Create TocoConverter class.
      return cls(sess.graph_def, input_tensors, output_tensors)