コード例 #1
0
ファイル: trt_convert.py プロジェクト: ZhangXinNan/tensorflow
def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
  """Convert an existing calibration graph to an inference graph.

  Args:
    calibration_graph_def: the calibration GraphDef object with calibration data
    is_dynamic_op: flag passed through unchanged to `calib_convert`; presumably
      selects dynamic rather than static engine creation (TODO: confirm
      against the converter).
  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing calibration nodes,
    or None (after logging an error) when the graph has no TRTEngineOp node
    whose `calibration_data` attribute is empty.
  Raises:
    RuntimeError: if the returned status message is malformed, i.e. it has no
      ";" separator or its leading error code is not an integer.
    UnknownError: if the returned status is shorter than two characters.
    Other errors: for any other non-OK status, the exception built by
      `_impl._make_specific_exception` from the status code and message.
  """
  # The graph qualifies as a calibration graph only if at least one
  # TRTEngineOp node still has an empty calibration_data attribute.
  is_calib_graph = any(
      n.op == "TRTEngineOp" and not n.attr["calibration_data"].s
      for n in calibration_graph_def.node)
  if not is_calib_graph:
    tf_logging.error(
        "Not a calib graph. Doesn't seem to contain any calibration nodes.")
    return None

  graph_str = calibration_graph_def.SerializeToString()
  out = calib_convert(graph_str, is_dynamic_op)
  del graph_str  # Save some memory
  # The status is used as-is on Python 2 and decoded from UTF-8 on Python 3.
  status = out[0] if _six.PY2 else out[0].decode("utf-8")
  output_graph_def_string = out[1]

  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    # A failing status is expected to look like "<int code>;<message>".
    code, separator, detail = status.partition(";")
    if not separator:
      raise RuntimeError("Status message is malformed {}".format(status))
    try:
      error_code = int(code)
    except ValueError:
      # A non-numeric code is also a malformed status; report it with the
      # documented RuntimeError instead of leaking a bare ValueError.
      raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, detail, error_code)
    # pylint: enable=protected-access

  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def
コード例 #2
0
def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
  """Convert an existing calibration graph to an inference graph.

  Args:
    calibration_graph_def: the calibration GraphDef object with calibration data
    is_dynamic_op: flag forwarded as-is to `calib_convert`; presumably chooses
      between dynamic and static engines built from the calibration data
      (TODO: confirm against the converter).
  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
    Returns None, after logging an error, if no TRTEngineOp node with an empty
    `calibration_data` attribute is present.
  Raises:
    RuntimeError: if the returned status message is malformed (no ";"
      separator, or an error code that is not an integer).
    UnknownError: if the returned status has fewer than two characters.
    Other errors: for any other non-OK status, the exception built by
      `_impl._make_specific_exception` from the status code and message.
  """

  def py2string(inp):
    # The status is used without conversion on Python 2.
    return inp

  def py3string(inp):
    # On Python 3 the status is decoded from UTF-8.
    return inp.decode("utf-8")

  to_string = py2string if _six.PY2 else py3string

  # A node counts as a calibration node when it is a TRTEngineOp whose
  # calibration_data attribute is still empty.
  has_calibration_node = any(
      node.op == "TRTEngineOp" and not node.attr["calibration_data"].s
      for node in calibration_graph_def.node)
  if not has_calibration_node:
    tf_logging.error(
        "Not a calib graph. Doesn't seem to contain any calibration nodes.")
    return None

  graph_str = calibration_graph_def.SerializeToString()
  out = calib_convert(graph_str, is_dynamic_op)
  status = to_string(out[0])
  output_graph_def_string = out[1]
  del graph_str  # Save some memory

  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    malformed = "Status message is malformed {}".format(status)
    # Expected failure format: "<int code>;<message>".
    msg = status.split(";")
    if len(msg) == 1:
      raise RuntimeError(malformed)
    try:
      error_code = int(msg[0])
    except ValueError:
      # Keep the documented contract: a bad code is a malformed status, so
      # raise RuntimeError rather than letting ValueError escape.
      raise RuntimeError(malformed)
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                         error_code)
    # pylint: enable=protected-access

  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def
コード例 #3
0
ファイル: trt_convert.py プロジェクト: sgcm520/tensorflow2
def calib_graph_to_infer_graph(calibration_graph_def):
  """Convert an existing calibration graph to an inference graph.

  Args:
    calibration_graph_def: the calibration GraphDef object with calibration data
  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
  Raises:
    RuntimeError: if the returned status message is malformed, i.e. it has no
      ";" separator or its leading error code is not an integer.
    UnknownError: if the returned status is shorter than two characters.
    Other errors: for any other non-OK status, the exception built by
      `_impl._make_specific_exception` from the status code and message.
  """
  graph_str = calibration_graph_def.SerializeToString()
  out = calib_convert(graph_str)
  del graph_str  # Save some memory
  # The status is used as-is on Python 2 and decoded from UTF-8 on Python 3.
  status = out[0] if _six.PY2 else out[0].decode("utf-8")
  output_graph_def_string = out[1]

  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    # A failing status is expected to look like "<int code>;<message>".
    code, separator, detail = status.partition(";")
    if not separator:
      raise RuntimeError("Status message is malformed {}".format(status))
    try:
      error_code = int(code)
    except ValueError:
      # A non-numeric code is also a malformed status; surface it as the
      # documented RuntimeError instead of an unexpected ValueError.
      raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, detail, error_code)
    # pylint: enable=protected-access

  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def