예제 #1
0
    def testSetTensorShapeEmpty(self):
        """Checks that an empty shapes map leaves the tensor shape as it was."""
        expected_shape = [None, 3, 5]
        tensor = array_ops.placeholder(dtype=dtypes.float32,
                                       shape=[None, 3, 5])
        self.assertEqual(expected_shape, tensor.shape.as_list())

        # The map is empty, so nothing in it refers to the placeholder.
        shapes_map = {}
        util.set_tensor_shapes([tensor], shapes_map)
        self.assertEqual(expected_shape, tensor.shape.as_list())
예제 #2
0
파일: util_test.py 프로젝트: Harryi0/tinyML
    def testSetTensorShapeNoneValid(self):
        """Checks that a placeholder created without a shape can be given one."""
        graph = ops.Graph()
        with graph.as_default():
            tensor = array_ops.placeholder(dtype=dtypes.float32)
        # No shape was supplied when the placeholder was created.
        self.assertEqual(None, tensor.shape)

        expected_shape = [1, 3, 5]
        shapes_map = {"Placeholder": [1, 3, 5]}
        util.set_tensor_shapes([tensor], shapes_map)
        self.assertEqual(expected_shape, tensor.shape.as_list())
예제 #3
0
  def testSetTensorShapeEmpty(self):
    """Checks that an empty shapes map does not alter the tensor shape."""
    expected_shape = [None, 3, 5]
    graph = ops.Graph()
    with graph.as_default():
      tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
    self.assertAllEqual(expected_shape, tensor.shape)

    # The map is empty, so nothing in it refers to the placeholder.
    shapes_map = {}
    util.set_tensor_shapes([tensor], shapes_map)
    self.assertAllEqual(expected_shape, tensor.shape)
예제 #4
0
def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
                       output_arrays, tag_set, signature_key):
    """Converts a SavedModel to a frozen graph.

    Args:
      saved_model_dir: SavedModel directory to convert.
      input_arrays: List of input tensors to freeze graph with. Uses input
        arrays from SignatureDef when none are provided.
      input_shapes: Dict of strings representing input tensor names to list of
        integers representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).
        Automatically determined when input shapes is None
        (e.g., {"foo": None}).
      output_arrays: List of output tensors to freeze graph with. Uses output
        arrays from SignatureDef when none are provided.
      tag_set: Set of tags identifying the MetaGraphDef within the SavedModel
        to analyze. All tags in the tag set must be present.
      signature_key: Key identifying SignatureDef containing inputs and
        outputs.

    Returns:
      frozen_graph_def: Frozen GraphDef.
      in_tensors: List of input tensors for the graph.
      out_tensors: List of output tensors for the graph.

    Raises:
      ValueError:
        SavedModel doesn't contain a MetaGraphDef identified by tag_set.
        signature_key is not in the MetaGraphDef.
        assets/ directory is in the MetaGraphDef.
        input_shapes does not match the length of input_arrays.
        input_arrays or output_arrays are not valid.
    """
    # Read SignatureDef.
    meta_graph = get_meta_graph_def(saved_model_dir, tag_set)
    signature_def = get_signature_def(meta_graph, signature_key)
    inputs, outputs = get_inputs_outputs(signature_def)

    # Check SavedModel for assets directory.
    collection_def = meta_graph.collection_def
    if constants.ASSETS_KEY in collection_def:
        raise ValueError(
            "SavedModels with assets/ directory are not supported.")

    graph = ops.Graph()
    with session.Session(graph=graph) as sess:
        loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir)

        # Gets input and output tensors.
        # TODO(zhixianyan): Use TFLite supported Op list to filter outputs.
        in_tensors = _get_tensors(graph, inputs, input_arrays)
        out_tensors = _get_tensors(graph, outputs, output_arrays)
        util.set_tensor_shapes(in_tensors, input_shapes)

        # Freeze against the tensors that are actually returned. Using the
        # SignatureDef outputs here would ignore a caller-supplied
        # `output_arrays` and could prune the requested outputs out of the
        # frozen graph. The node name is the tensor name without the
        # ":<output index>" suffix.
        output_names = [tensor.name.split(":")[0] for tensor in out_tensors]
        frozen_graph_def = tf_graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), output_names)

        return frozen_graph_def, in_tensors, out_tensors
예제 #5
0
  def testSetTensorShapeDimensionInvalid(self):
    """Checks that a shape incompatible with the tensor is rejected."""
    original_shape = [None, 3, 5]
    tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 3, 5])
    self.assertEqual(original_shape, tensor.shape.as_list())

    # The requested second dimension (5) differs from the placeholder's 3.
    incompatible_shapes = {"Placeholder": [1, 5, 5]}
    with self.assertRaises(ValueError) as error:
      util.set_tensor_shapes([tensor], incompatible_shapes)
    message = str(error.exception)
    self.assertIn("The shape of tensor 'Placeholder' cannot be changed",
                  message)
    # The rejected update must leave the original shape in place.
    self.assertEqual(original_shape, tensor.shape.as_list())
예제 #6
0
  def testSetTensorShapeArrayInvalid(self):
    """Checks that a shapes map naming a tensor not in the list is rejected."""
    original_shape = [None, 3, 5]
    tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 3, 5])
    self.assertEqual(original_shape, tensor.shape.as_list())

    # "invalid-input" is not the name of the tensor passed in.
    unknown_shapes = {"invalid-input": [5, 3, 5]}
    with self.assertRaises(ValueError) as error:
      util.set_tensor_shapes([tensor], unknown_shapes)
    message = str(error.exception)
    self.assertEqual(
        "Invalid tensor 'invalid-input' found in tensor shapes map.", message)
    # The rejected update must leave the original shape in place.
    self.assertEqual(original_shape, tensor.shape.as_list())
def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
                       output_arrays, tag_set, signature_key):
  """Loads a SavedModel and returns its frozen graph with the I/O tensors.

  Args:
    saved_model_dir: Directory of the SavedModel to convert.
    input_arrays: List of input tensors to freeze the graph with. The input
      arrays from the SignatureDef are used when none are provided.
    input_shapes: Dict mapping input tensor names (strings) to input shapes
      given as lists of integers (e.g., {"foo": [1, 16, 16, 3]}). A shape is
      determined automatically when its value is None (e.g., {"foo": None}).
    output_arrays: List of output tensors to freeze the graph with. The output
      arrays from the SignatureDef are used when none are provided.
    tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze. All tags in the tag set must be present.
    signature_key: Key identifying the SignatureDef that contains the inputs
      and outputs.

  Returns:
    frozen_graph_def: Frozen GraphDef.
    in_tensors: List of input tensors for the graph.
    out_tensors: List of output tensors for the graph.

  Raises:
    ValueError:
      SavedModel doesn't contain a MetaGraphDef identified by tag_set.
      signature_key is not in the MetaGraphDef.
      assets/ directory is in the MetaGraphDef.
      input_shapes does not match the length of input_arrays.
      input_arrays or output_arrays are not valid.
  """
  # Resolve the SignatureDef and the input/output names it declares.
  meta_graph_def = get_meta_graph_def(saved_model_dir, tag_set)
  inputs, outputs = get_inputs_outputs(
      get_signature_def(meta_graph_def, signature_key))

  # SavedModels that carry an assets collection are rejected up front.
  if constants.ASSETS_KEY in meta_graph_def.collection_def:
    raise ValueError("SavedModels with assets/ directory are not supported.")

  target_graph = ops.Graph()
  with session.Session(graph=target_graph) as sess:
    loader.load(sess, meta_graph_def.meta_info_def.tags, saved_model_dir)

    # Gets input and output tensors.
    # TODO(zhixianyan): Use TFLite supported Op list to filter outputs.
    in_tensors = _get_tensors(target_graph, inputs, input_arrays)
    out_tensors = _get_tensors(target_graph, outputs, output_arrays)
    util.set_tensor_shapes(in_tensors, input_shapes)

    frozen_graph_def = util.freeze_graph(sess, in_tensors, out_tensors)

  return frozen_graph_def, in_tensors, out_tensors
예제 #8
0
  def testSetTensorShapeNoneValid(self):
    """Checks that a placeholder created without a shape can be given one."""
    tensor = array_ops.placeholder(dtype=dtypes.float32)
    # No shape was supplied when the placeholder was created.
    self.assertEqual(None, tensor.shape)

    expected_shape = [1, 3, 5]
    shapes_map = {"Placeholder": [1, 3, 5]}
    util.set_tensor_shapes([tensor], shapes_map)
    self.assertEqual(expected_shape, tensor.shape.as_list())
예제 #9
0
  def testSetTensorShapeEmpty(self):
    """Checks that an empty shapes map leaves the tensor shape as it was."""
    expected_shape = [None, 3, 5]
    tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 3, 5])
    self.assertEqual(expected_shape, tensor.shape.as_list())

    # The map is empty, so nothing in it refers to the placeholder.
    shapes_map = {}
    util.set_tensor_shapes([tensor], shapes_map)
    self.assertEqual(expected_shape, tensor.shape.as_list())