Пример #1
0
def evaluate_frozen_graph(filename, input_arrays, output_arrays):
  """Loads a frozen graph from disk and returns a callable that runs it.

  The file is first parsed as a binary-serialized GraphDef. If that parse
  raises, the same contents are parsed again as a text-format GraphDef. The
  resulting GraphDef is imported with `name=""`, and tensors are then resolved
  from the graph of a newly created session. That session is kept open and is
  used by the returned callable.

  Args:
    filename: Full filepath of file containing frozen GraphDef.
    input_arrays: Names of the tensors to feed.
    output_arrays: Names of the tensors to fetch.

  Returns:
    Function ([np.ndarray data] : [np.ndarray result]). Elements of `data` are
    paired positionally with `input_arrays` via `zip`, so any surplus on either
    side is ignored. The session is then run on `output_arrays`.
  """
  with _session.Session().as_default() as sess:
    with _file_io.FileIO(filename, "rb") as frozen_file:
      raw = frozen_file.read()

    graph_def = _graph_pb2.GraphDef()
    try:
      graph_def.ParseFromString(raw)
    except (_text_format.ParseError, DecodeError):
      # Not a binary protobuf, so fall back to text format. Merge is handed a
      # native `str`: non-`str` content is UTF-8 decoded on PY3 or encoded on
      # PY2.
      if not isinstance(raw, str):
        raw = raw.decode("utf-8") if PY3 else raw.encode("utf-8")
      _text_format.Merge(raw, graph_def)
    _import_graph_def(graph_def, name="")

    def _lookup(names):
      return _util.get_tensors_from_tensor_names(sess.graph, names)

    inputs = _lookup(input_arrays)
    outputs = _lookup(output_arrays)

    def _evaluate(input_data):
      feed_dict = dict(zip(inputs, input_data))
      return sess.run(outputs, feed_dict)

    return _evaluate
Пример #2
0
def evaluate_frozen_graph(filename, input_arrays, output_arrays):
    """Loads a binary frozen GraphDef and returns a callable that runs it.

    The file contents are parsed as a binary-serialized GraphDef and imported
    with `name=""`. Tensors are then resolved from the graph of a newly
    created session, which stays open for use by the returned callable.

    Args:
      filename: Full filepath of file containing frozen GraphDef.
      input_arrays: Names of the tensors to feed.
      output_arrays: Names of the tensors to fetch.

    Returns:
      Function ([np.ndarray data] : [np.ndarray result]). Elements of `data`
      are paired positionally with `input_arrays` via `zip`, so any surplus
      on either side is ignored. The session is then run on `output_arrays`.
    """
    with _session.Session().as_default() as sess:
        with _file_io.FileIO(filename, "rb") as frozen_file:
            serialized = frozen_file.read()

        graph_def = _graph_pb2.GraphDef()
        graph_def.ParseFromString(serialized)
        _import_graph_def(graph_def, name="")

        def _lookup(names):
            return _util.get_tensors_from_tensor_names(sess.graph, names)

        inputs = _lookup(input_arrays)
        outputs = _lookup(output_arrays)

        def _evaluate(input_data):
            feed_dict = dict(zip(inputs, input_data))
            return sess.run(outputs, feed_dict)

        return _evaluate
Пример #3
0
    def testGetTensorsInvalid(self):
        """Looking up a name that is not in the graph raises ValueError."""
        in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
                                          dtype=dtypes.float32)
        _ = in_tensor + in_tensor

        # Use the session as a context manager so it is closed when the check
        # finishes, instead of being left open until garbage collection.
        with session.Session() as sess:
            with self.assertRaises(ValueError) as error:
                util.get_tensors_from_tensor_names(sess.graph,
                                                   ["invalid-input"])
        self.assertEqual("Invalid tensors 'invalid-input' were found.",
                         str(error.exception))
Пример #4
0
  def testGetTensorsInvalid(self):
    """Looking up a name that is not in the graph raises ValueError."""
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    _ = in_tensor + in_tensor

    # Use the session as a context manager so it is closed when the check
    # finishes, instead of being left open until garbage collection.
    with session.Session() as sess:
      with self.assertRaises(ValueError) as error:
        util.get_tensors_from_tensor_names(sess.graph, ["invalid-input"])
    self.assertEqual("Invalid tensors 'invalid-input' were found.",
                     str(error.exception))
Пример #5
0
  def testGetTensorsValid(self):
    """The op name "Placeholder" resolves to the tensor "Placeholder:0"."""
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    _ = in_tensor + in_tensor

    # Use the session as a context manager so it is closed after the lookup,
    # instead of being left open until garbage collection.
    with session.Session() as sess:
      tensors = util.get_tensors_from_tensor_names(sess.graph, ["Placeholder"])
    self.assertEqual("Placeholder:0", tensors[0].name)
Пример #6
0
def evaluate_frozen_graph(filename, input_arrays, output_arrays):
  """Loads a binary frozen GraphDef and returns a callable that runs it.

  The file contents are parsed as a binary-serialized GraphDef and imported
  with `name=""`. Tensors are then resolved from the graph of a newly created
  session, which stays open for use by the returned callable.

  Args:
    filename: Full filepath of file containing frozen GraphDef.
    input_arrays: Names of the tensors to feed.
    output_arrays: Names of the tensors to fetch.

  Returns:
    Function ([np.ndarray data] : [np.ndarray result]). Elements of `data` are
    paired positionally with `input_arrays` via `zip`, so any surplus on either
    side is ignored. The session is then run on `output_arrays`.
  """
  with _session.Session().as_default() as sess:
    with _file_io.FileIO(filename, "rb") as frozen_file:
      serialized = frozen_file.read()

    graph_def = _graph_pb2.GraphDef()
    graph_def.ParseFromString(serialized)
    _import_graph_def(graph_def, name="")

    def _lookup(names):
      return _util.get_tensors_from_tensor_names(sess.graph, names)

    inputs = _lookup(input_arrays)
    outputs = _lookup(output_arrays)

    def _evaluate(input_data):
      feed_dict = dict(zip(inputs, input_data))
      return sess.run(outputs, feed_dict)

    return _evaluate
Пример #7
0
def _get_tensors(graph,
                 signature_def_tensor_names=None,
                 user_tensor_names=None):
    """Gets the tensors associated with the tensor names.

  Either signature_def_tensor_names or user_tensor_names should be provided. If
  the user provides tensors, the tensors associated with the user provided
  tensor names are provided. Otherwise, the tensors associated with the names in
  the SignatureDef are provided.

  Args:
    graph: GraphDef representing graph.
    signature_def_tensor_names: Tensor names stored in either the inputs or
      outputs of a SignatureDef. (default None)
    user_tensor_names: Tensor names provided by the user. (default None)

  Returns:
    List of tensors.

  Raises:
    ValueError:
      signature_def_tensors and user_tensor_names are undefined or empty.
      user_tensor_names are not valid.
  """
    tensors = []
    if user_tensor_names:
        # Sort the tensor names.
        user_tensor_names = sorted(user_tensor_names)

        tensors = util.get_tensors_from_tensor_names(graph, user_tensor_names)
    elif signature_def_tensor_names:
        tensors = [
            graph.get_tensor_by_name(name)
            for name in sorted(signature_def_tensor_names)
        ]
    else:
        # Throw ValueError if signature_def_tensors and user_tensor_names are both
        # either undefined or empty.
        raise ValueError(
            "Specify either signature_def_tensor_names or user_tensor_names")

    return tensors
def _get_tensors(graph, signature_def_tensor_names=None,
                 user_tensor_names=None):
  """Gets the tensors associated with the tensor names.

  Either signature_def_tensor_names or user_tensor_names should be provided. If
  the user provides tensors, the tensors associated with the user provided
  tensor names are provided. Otherwise, the tensors associated with the names in
  the SignatureDef are provided.

  Args:
    graph: GraphDef representing graph.
    signature_def_tensor_names: Tensor names stored in either the inputs or
      outputs of a SignatureDef. (default None)
    user_tensor_names: Tensor names provided by the user. (default None)

  Returns:
    List of tensors.

  Raises:
    ValueError:
      signature_def_tensors and user_tensor_names are undefined or empty.
      user_tensor_names are not valid.
  """
  tensors = []
  if user_tensor_names:
    # Sort the tensor names.
    user_tensor_names = sorted(user_tensor_names)

    tensors = util.get_tensors_from_tensor_names(graph, user_tensor_names)
  elif signature_def_tensor_names:
    tensors = [
        graph.get_tensor_by_name(name)
        for name in sorted(signature_def_tensor_names)
    ]
  else:
    # Throw ValueError if signature_def_tensors and user_tensor_names are both
    # either undefined or empty.
    raise ValueError(
        "Specify either signature_def_tensor_names or user_tensor_names")

  return tensors