Code example #1
File: errors_test.py (project: AnishShah/tensorflow)
 def testPickleable(self):
   for error_code in [
       errors.CANCELLED,
       errors.UNKNOWN,
       errors.INVALID_ARGUMENT,
       errors.DEADLINE_EXCEEDED,
       errors.NOT_FOUND,
       errors.ALREADY_EXISTS,
       errors.PERMISSION_DENIED,
       errors.UNAUTHENTICATED,
       errors.RESOURCE_EXHAUSTED,
       errors.FAILED_PRECONDITION,
       errors.ABORTED,
       errors.OUT_OF_RANGE,
       errors.UNIMPLEMENTED,
       errors.INTERNAL,
       errors.UNAVAILABLE,
       errors.DATA_LOSS,
   ]:
     # pylint: disable=protected-access
     exc = errors_impl._make_specific_exception(None, None, None, error_code)
     # pylint: enable=protected-access
     unpickled = pickle.loads(pickle.dumps(exc))
     self.assertEqual(exc.node_def, unpickled.node_def)
     self.assertEqual(exc.op, unpickled.op)
     self.assertEqual(exc.message, unpickled.message)
     self.assertEqual(exc.error_code, unpickled.error_code)
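The same guarantee can be checked against the public exception classes without touching the private factory. A minimal sketch, assuming the TF 1.x-era errors_impl where CancelledError takes (node_def, op, message):

import pickle

from tensorflow.python.framework import errors_impl

# Build a public OpError subclass directly; no failing op is required.
exc = errors_impl.CancelledError(None, None, "work was cancelled")

# The exception survives a pickle round trip with its payload intact.
restored = pickle.loads(pickle.dumps(exc))
assert restored.message == "work was cancelled"
assert restored.error_code == exc.error_code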
Code example #2
def load_file_system_library(library_filename):
  """Loads a TensorFlow plugin, containing file system implementation.

  Pass `library_filename` to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    None.

  Raises:
    RuntimeError: when unable to load the library.
  """
  status = py_tf.TF_NewStatus()
  # Loading the library registers the file system implementations it
  # contains; the handle itself is not needed afterwards.
  lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
  try:
    error_code = py_tf.TF_GetCode(status)
    if error_code != 0:
      error_msg = compat.as_text(py_tf.TF_Message(status))
      # pylint: disable=protected-access
      raise errors_impl._make_specific_exception(
          None, None, error_msg, error_code)
      # pylint: enable=protected-access
  finally:
    py_tf.TF_DeleteStatus(status)
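In the TF 1.x releases these snippets come from, the helper above was exposed as tf.load_file_system_library. A usage sketch; the plugin path is hypothetical, and note that despite the RuntimeError mentioned in the docstring, the body actually raises the OpError subclass matching the status code:

import tensorflow as tf

try:
    # Hypothetical path to a file system plugin built as a shared library.
    tf.load_file_system_library("/opt/plugins/libmy_filesystem.so")
except tf.errors.OpError as e:
    print("could not load plugin: %s" % e.message)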
Code example #3
 def testUniqueClassForEachErrorCode(self):
   for error_code, exc_type in [
       (errors.CANCELLED, errors_impl.CancelledError),
       (errors.UNKNOWN, errors_impl.UnknownError),
       (errors.INVALID_ARGUMENT, errors_impl.InvalidArgumentError),
       (errors.DEADLINE_EXCEEDED, errors_impl.DeadlineExceededError),
       (errors.NOT_FOUND, errors_impl.NotFoundError),
       (errors.ALREADY_EXISTS, errors_impl.AlreadyExistsError),
       (errors.PERMISSION_DENIED, errors_impl.PermissionDeniedError),
       (errors.UNAUTHENTICATED, errors_impl.UnauthenticatedError),
       (errors.RESOURCE_EXHAUSTED, errors_impl.ResourceExhaustedError),
       (errors.FAILED_PRECONDITION, errors_impl.FailedPreconditionError),
       (errors.ABORTED, errors_impl.AbortedError),
       (errors.OUT_OF_RANGE, errors_impl.OutOfRangeError),
       (errors.UNIMPLEMENTED, errors_impl.UnimplementedError),
       (errors.INTERNAL, errors_impl.InternalError),
       (errors.UNAVAILABLE, errors_impl.UnavailableError),
       (errors.DATA_LOSS, errors_impl.DataLossError),
   ]:
     # pylint: disable=protected-access
     self.assertTrue(
         isinstance(
             errors_impl._make_specific_exception(None, None, None,
                                                  error_code), exc_type))
      # error_code_from_exception_type and exception_type_from_error_code
      # should be consistent with each other and with the mapping above.
      self.assertEqual(error_code,
                       errors_impl.error_code_from_exception_type(exc_type))
      self.assertEqual(exc_type,
                       errors_impl.exception_type_from_error_code(error_code))
      # pylint: enable=protected-access
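Both mapping helpers exercised here are public in errors_impl, so the code-to-class correspondence can be checked outside a test harness. A small sketch:

from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl

# Round-trip one status code through both mapping helpers.
exc_type = errors_impl.exception_type_from_error_code(errors.NOT_FOUND)
print(exc_type.__name__)  # NotFoundError
print(errors_impl.error_code_from_exception_type(exc_type) == errors.NOT_FOUND)  # True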
Code example #4
 def testUnknownErrorCodeCausesWarning(self):
   with warnings.catch_warnings(record=True) as w:
     # pylint: disable=protected-access
     exc = errors_impl._make_specific_exception(None, None, None, 37)
     # pylint: enable=protected-access
   self.assertEqual(1, len(w))
   self.assertTrue("Unknown error code: 37" in str(w[0].message))
   self.assertTrue(isinstance(exc, errors_impl.OpError))
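The factory never raises on an unrecognized code; it warns and falls back to the catch-all class, which is what makes the assertion above pass. A standalone sketch of that behavior (37 is not a registered code; the concrete fallback is UnknownError):

import warnings

from tensorflow.python.framework import errors_impl

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # pylint: disable=protected-access
    exc = errors_impl._make_specific_exception(None, None, "boom", 37)
    # pylint: enable=protected-access

print(type(exc).__name__)      # UnknownError, the fallback class
print(str(caught[0].message))  # mentions "Unknown error code: 37"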
Code example #5
def load_op_library(library_filename):
  """Loads a TensorFlow plugin, containing custom ops and kernels.

  Pass "library_filename" to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here. When the
  library is loaded, ops and kernels registered in the library via the
  `REGISTER_*` macros are made available in the TensorFlow process. Note
  that ops with the same name as an existing op are rejected and not
  registered with the process.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    A python module containing the Python wrappers for Ops defined in
    the plugin.

  Raises:
    RuntimeError: when unable to load the library or get the python wrappers.
  """
  status = py_tf.TF_NewStatus()

  lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
  try:
    error_code = py_tf.TF_GetCode(status)
    if error_code != 0:
      error_msg = compat.as_text(py_tf.TF_Message(status))
      # pylint: disable=protected-access
      raise errors_impl._make_specific_exception(
          None, None, error_msg, error_code)
      # pylint: enable=protected-access
  finally:
    py_tf.TF_DeleteStatus(status)

  op_list_str = py_tf.TF_GetOpList(lib_handle)
  op_list = op_def_pb2.OpList()
  op_list.ParseFromString(compat.as_bytes(op_list_str))
  wrappers = py_tf.GetPythonWrappers(op_list_str)

  # Delete the library handle to release any memory held in C
  # that is no longer needed.
  py_tf.TF_DeleteLibraryHandle(lib_handle)

  # Get a unique name for the module.
  module_name = hashlib.md5(wrappers).hexdigest()
  if module_name in sys.modules:
    return sys.modules[module_name]
  module = imp.new_module(module_name)
  # pylint: disable=exec-used
  exec(wrappers, module.__dict__)
  # The library handle was deleted above, so it is not stashed on the
  # module; the generated wrappers do not need it.
  # OpDefs of the list of ops defined in the library.
  module.OP_LIST = op_list
  sys.modules[module_name] = module
  return module
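Usage in TF 1.x went through the public alias tf.load_op_library; the returned module carries one generated wrapper per op the library registered. A sketch with a hypothetical zero_out plugin:

import tensorflow as tf

# Hypothetical custom-op library that registered a "ZeroOut" op.
mod = tf.load_op_library("/opt/plugins/libzero_out.so")
print(mod.OP_LIST)                    # OpList proto of the registered ops
result = mod.zero_out([[1, 2], [3, 4]])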
Code example #6
def load_op_library(library_filename):
    """Loads a TensorFlow plugin, containing custom ops and kernels.

  Pass "library_filename" to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here. When the
  library is loaded, ops and kernels registered in the library via the
  `REGISTER_*` macros are made available in the TensorFlow process. Note
  that ops with the same name as an existing op are rejected and not
  registered with the process.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    A python module containing the Python wrappers for Ops defined in
    the plugin.

  Raises:
    RuntimeError: when unable to load the library or get the python wrappers.
  """
    status = py_tf.TF_NewStatus()

    lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
    try:
        error_code = py_tf.TF_GetCode(status)
        if error_code != 0:
            error_msg = compat.as_text(py_tf.TF_Message(status))
            # pylint: disable=protected-access
            raise errors_impl._make_specific_exception(None, None, error_msg,
                                                       error_code)
            # pylint: enable=protected-access
    finally:
        py_tf.TF_DeleteStatus(status)

    op_list_str = py_tf.TF_GetOpList(lib_handle)
    op_list = op_def_pb2.OpList()
    op_list.ParseFromString(compat.as_bytes(op_list_str))
    wrappers = py_tf.GetPythonWrappers(op_list_str)

    # Get a unique name for the module.
    module_name = hashlib.md5(wrappers).hexdigest()
    if module_name in sys.modules:
        return sys.modules[module_name]
    module = imp.new_module(module_name)
    # pylint: disable=exec-used
    exec(wrappers, module.__dict__)
    # Stash away the library handle for making calls into the dynamic library.
    module.LIB_HANDLE = lib_handle
    # OpDefs of the list of ops defined in the library.
    module.OP_LIST = op_list
    sys.modules[module_name] = module
    return module
Code example #7
File: trt_convert.py (project: ZhangXinNan/tensorflow)
def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
  """Convert an existing calibration graph to inference graph.

  Args:
    calibration_graph_def: the calibration GraphDef object with calibration data
    is_dynamic_op: whether to create dynamic rather than static engines from
      the calibration data
  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
  Raises:
    RuntimeError: if the returned status message is malformed.
  """

  def py2string(inp):
    return inp

  def py3string(inp):
    return inp.decode("utf-8")

  if _six.PY2:
    to_string = py2string
  else:
    to_string = py3string
  is_calib_graph = False
  for n in calibration_graph_def.node:
    if n.op == "TRTEngineOp":
      is_calib_graph = is_calib_graph or not n.attr["calibration_data"].s
  if not is_calib_graph:
    tf_logging.error(
        "Not a calib graph. Doesn't seem to contain any calibration nodes.")
    return None
  graph_str = calibration_graph_def.SerializeToString()
  out = calib_convert(graph_str, is_dynamic_op)
  status = to_string(out[0])
  output_graph_def_string = out[1]
  del graph_str  # Save some memory
  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    msg = status.split(";")
    if len(msg) == 1:
      raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                         int(msg[0]))
    # pylint: enable=protected-access
  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def
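This and the remaining TRT wrappers below all share the same ad hoc status protocol: the C++ side hands back a string that is either "OK" or "<code>;<message>", and the Python side re-raises through _make_specific_exception. The recurring tail of these functions, distilled into a sketch:

from tensorflow.python.framework import errors_impl as _impl

def raise_from_status(status):
    """Decode the 'OK' / '<code>;<message>' strings used by the TRT wrappers."""
    if len(status) < 2:
        raise _impl.UnknownError(None, None, status)
    if status[:2] == "OK":
        return
    parts = status.split(";")
    if len(parts) == 1:
        raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, ";".join(parts[1:]),
                                         int(parts[0]))
    # pylint: enable=protected-access

raise_from_status("OK")                 # returns silently
raise_from_status("3;shape mismatch")   # raises InvalidArgumentError (code 3)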
Code example #8
  def testUnknownErrorCodeCausesWarning(self):
    with warnings.catch_warnings(record=True) as w:
      # pylint: disable=protected-access
      exc = errors_impl._make_specific_exception(None, None, None, 37)
      # pylint: enable=protected-access
    self.assertEqual(1, len(w))
    self.assertTrue("Unknown error code: 37" in str(w[0].message))
    self.assertTrue(isinstance(exc, errors_impl.OpError))

    with warnings.catch_warnings(record=True) as w:
      # pylint: disable=protected-access
      exc = errors_impl.error_code_from_exception_type("Unknown")
      # pylint: enable=protected-access
    self.assertEqual(1, len(w))
    self.assertTrue("Unknown class exception" in str(w[0].message))
    self.assertTrue(isinstance(exc, errors_impl.OpError))
Code example #9
 def testKnownErrorClassForEachErrorCodeInProto(self):
   for error_code in error_codes_pb2.Code.values():
     # pylint: disable=line-too-long
     if error_code in (
         error_codes_pb2.OK, error_codes_pb2.
         DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_
     ):
       continue
     # pylint: enable=line-too-long
     with warnings.catch_warnings(record=True) as w:
       # pylint: disable=protected-access
       exc = errors_impl._make_specific_exception(None, None, None, error_code)
       # pylint: enable=protected-access
     self.assertEqual(0, len(w))  # No warning is raised.
     self.assertTrue(isinstance(exc, errors_impl.OpError))
     self.assertTrue(errors_impl.OpError in exc.__class__.__bases__)
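The enum the test iterates over is the canonical status-code proto; it can be listed directly to see every code TensorFlow defines. A quick sketch:

from tensorflow.core.lib.core import error_codes_pb2

# Print every (number, name) pair in the canonical status-code enum.
for name, number in sorted(error_codes_pb2.Code.items(), key=lambda kv: kv[1]):
    print("%d %s" % (number, name))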
Code example #10
File: trt_convert.py (project: wjjiege/tensorflow)
def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
    """Convert an existing calibration graph to inference graph.

  Args:
    calibration_graph_def: the calibration GraphDef object with calibration data
    is_dynamic_op: whether to create dynamic static engines from calibration

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
  Raises:
    RuntimeError: if the returned status message is malformed.
  """
    # Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't complain
    # even if it cannot find TensorRT library.
    trt_ops.load_trt_ops()
    # pylint: disable=g-import-not-at-top,line-too-long
    from tensorflow.python.compiler.tensorrt.wrap_conversion import calib_convert
    # pylint: enable=g-import-not-at-top,line-too-long

    is_calib_graph = False
    for n in calibration_graph_def.node:
        if n.op == "TRTEngineOp":
            is_calib_graph = is_calib_graph or not n.attr["calibration_data"].s
    if not is_calib_graph:
        tf_logging.error(
            "Not a calib graph. Doesn't seem to contain any calibration nodes."
        )
        return None
    graph_str = calibration_graph_def.SerializeToString()
    out = calib_convert(graph_str, is_dynamic_op)
    status = _to_string(out[0])
    output_graph_def_string = out[1]
    del graph_str  # Save some memory
    if len(status) < 2:
        raise _impl.UnknownError(None, None, status)
    if status[:2] != "OK":
        msg = status.split(";")
        if len(msg) == 1:
            raise RuntimeError("Status message is malformed {}".format(status))
        # pylint: disable=protected-access
        raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                             int(msg[0]))
        # pylint: enable=protected-access
    output_graph_def = graph_pb2.GraphDef()
    output_graph_def.ParseFromString(output_graph_def_string)
    del output_graph_def_string  # Save some memory
    return output_graph_def
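What distinguishes this variant is the import discipline: the TRT bindings are loaded lazily inside the function, so `import tensorflow` succeeds even on machines without the TensorRT shared libraries. The same pattern in isolation, a sketch using the module path from the snippet above:

def calib_convert_lazy(graph_def, is_dynamic_op=False):
    # Deferring the import moves the hard TensorRT dependency from module
    # import time to the first call of this function.
    # pylint: disable=g-import-not-at-top
    from tensorflow.python.compiler.tensorrt.wrap_conversion import calib_convert
    # pylint: enable=g-import-not-at-top
    return calib_convert(graph_def.SerializeToString(), is_dynamic_op)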
Code example #11
File: trt_convert.py (project: sgcm520/tensorflow2)
def calib_graph_to_infer_graph(calibration_graph_def):
  """Convert an existing calibration graph to inference graph.

  Args:
    calibration_graph_def: the calibration GraphDef object with calibration data
  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
  Raises:
    RuntimeError: if the returned status message is malformed.
  """

  def py2string(inp):
    return inp

  def py3string(inp):
    return inp.decode("utf-8")

  if _six.PY2:
    to_string = py2string
  else:
    to_string = py3string

  graph_str = calibration_graph_def.SerializeToString()
  out = calib_convert(graph_str)
  status = to_string(out[0])
  output_graph_def_string = out[1]
  del graph_str  # Save some memory
  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    msg = status.split(";")
    if len(msg) == 1:
      raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                         int(msg[0]))
    # pylint: enable=protected-access
  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def
Code example #12
 def testUniqueClassForEachErrorCode(self):
     for error_code, exc_type in [
         (errors.CANCELLED, errors_impl.CancelledError),
         (errors.UNKNOWN, errors_impl.UnknownError),
         (errors.INVALID_ARGUMENT, errors_impl.InvalidArgumentError),
         (errors.DEADLINE_EXCEEDED, errors_impl.DeadlineExceededError),
         (errors.NOT_FOUND, errors_impl.NotFoundError),
         (errors.ALREADY_EXISTS, errors_impl.AlreadyExistsError),
         (errors.PERMISSION_DENIED, errors_impl.PermissionDeniedError),
         (errors.UNAUTHENTICATED, errors_impl.UnauthenticatedError),
         (errors.RESOURCE_EXHAUSTED, errors_impl.ResourceExhaustedError),
         (errors.FAILED_PRECONDITION, errors_impl.FailedPreconditionError),
         (errors.ABORTED, errors_impl.AbortedError),
         (errors.OUT_OF_RANGE, errors_impl.OutOfRangeError),
         (errors.UNIMPLEMENTED, errors_impl.UnimplementedError),
         (errors.INTERNAL, errors_impl.InternalError),
         (errors.UNAVAILABLE, errors_impl.UnavailableError),
         (errors.DATA_LOSS, errors_impl.DataLossError),
     ]:
         # pylint: disable=protected-access
         self.assertTrue(
             isinstance(
                 errors_impl._make_specific_exception(
                     None, None, None, error_code), exc_type))
Code example #13
File: trt_convert.py (project: PACELab/tensorflow-1)
def create_inference_graph(input_graph_def,
                           outputs,
                           max_batch_size=1,
                           max_workspace_size_bytes=2 << 20):
    """Python wrapper for the TRT transormation.


  Args:
    input_graph_def: GraphDef object containing a model to be transformed.
    outputs: List of tensors or node names for the model outputs.
    max_batch_size: max size for the input batch
    max_workspace_size_bytes: parameter to control memory allocation (in Bytes)

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing subgraphs.

  Raises:
    RuntimeError: if the returned status message is malformed.
  """
    def py2bytes(inp):
        return inp

    def py3bytes(inp):
        return inp.encode("utf-8", errors="surrogateescape")

    def py2string(inp):
        return inp

    def py3string(inp):
        return inp.decode("utf-8")

    if _six.PY2:
        to_bytes = py2bytes
        to_string = py2string
    else:
        to_bytes = py3bytes
        to_string = py3string

    out_names = []
    for i in outputs:
        if isinstance(i, ops.Tensor):
            out_names.append(to_bytes(i.name))
        else:
            out_names.append(to_bytes(i))

    input_graph_def_str = input_graph_def.SerializeToString()

    # TODO(sami): Fix this when we can return status from C++ library
    # There is a problem with the TF internal library setup that doesn't
    # allow us to return a status object from C++. Thus we return a pair
    # of strings, where the first one is the encoded status and the second
    # one is the transformed graph's protobuf string.
    out = trt_convert(input_graph_def_str, out_names, max_batch_size,
                      max_workspace_size_bytes)
    status = to_string(out[0])
    output_graph_def_string = out[1]
    del input_graph_def_str  # Save some memory
    if len(status) < 2:
        raise _impl.UnknownError(None, None, status)
    if status[:2] != "OK":
        msg = status.split(";")
        if len(msg) == 1:
            raise RuntimeError("Status message is malformed {}".format(status))
        # pylint: disable=protected-access
        raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                             int(msg[0]))
        # pylint: enable=protected-access
    output_graph_def = graph_pb2.GraphDef()
    output_graph_def.ParseFromString(output_graph_def_string)
    del output_graph_def_string  # Save some memory
    return output_graph_def
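A representative call for this early four-argument signature; frozen_graph_def is assumed to hold a frozen model and the output name is hypothetical:

import tensorflow.contrib.tensorrt as trt

trt_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph_def,     # assumed: a frozen GraphDef
    outputs=["predictions"],
    max_batch_size=8,
    max_workspace_size_bytes=1 << 30)     # let TRT use up to 1 GiB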
Code example #14
def create_inference_graph(input_graph_def,
                           outputs,
                           max_batch_size=1,
                           max_workspace_size_bytes=2 << 20,
                           precision_mode="FP32",
                           minimum_segment_size=3):
    """Python wrapper for the TRT transormation.

  Args:
    input_graph_def: GraphDef object containing a model to be transformed.
    outputs: list of tensors or node names for the model outputs.
    max_batch_size: max size for the input batch
    max_workspace_size_bytes: parameter to control memory allocation (in Bytes)
    precision_mode: one of 'FP32', 'FP16' and 'INT8'
    minimum_segment_size: the minimum number of nodes required for a subgraph to
      be replaced by TRTEngineOp.

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing subgraphs.

  Raises:
    ValueError: if the provided precision mode is invalid.
    RuntimeError: if the returned status message is malformed.
  """
    supported_precision_modes = {"FP32": 0, "FP16": 1, "INT8": 2}
    if precision_mode.upper() not in supported_precision_modes:
        raise ValueError(
            ("precision mode '{}' is not supported. "
             "It should be one of {}").format(precision_mode,
                                              "{'FP32', 'FP16', 'INT8'}"))
    mode = supported_precision_modes[precision_mode.upper()]

    def py2bytes(inp):
        return inp

    def py3bytes(inp):
        return inp.encode("utf-8", errors="surrogateescape")

    def py2string(inp):
        return inp

    def py3string(inp):
        return inp.decode("utf-8")

    if _six.PY2:
        to_bytes = py2bytes
        to_string = py2string
    else:
        to_bytes = py3bytes
        to_string = py3string

    out_names = []
    for i in outputs:
        if isinstance(i, ops.Tensor):
            out_names.append(to_bytes(i.name))
        else:
            out_names.append(to_bytes(i))

    input_graph_def_str = input_graph_def.SerializeToString()

    # TODO(sami): Fix this when we can return status from C++ library
    # There is a problem with the TF internal library setup that doesn't
    # allow us to return a status object from C++. Thus we return a pair
    # of strings, where the first one is the encoded status and the second
    # one is the transformed graph's protobuf string.
    out = trt_convert(input_graph_def_str, out_names, max_batch_size,
                      max_workspace_size_bytes, mode, minimum_segment_size)
    status = to_string(out[0])
    output_graph_def_string = out[1]
    del input_graph_def_str  # Save some memory
    if len(status) < 2:
        raise _impl.UnknownError(None, None, status)
    if status[:2] != "OK":
        msg = status.split(";")
        if len(msg) == 1:
            raise RuntimeError("Status message is malformed {}".format(status))
        # pylint: disable=protected-access
        raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                             int(msg[0]))
        # pylint: enable=protected-access
    output_graph_def = graph_pb2.GraphDef()
    output_graph_def.ParseFromString(output_graph_def_string)
    del output_graph_def_string  # Save some memory
    return output_graph_def
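The hand-rolled py2string/py3bytes helpers that every variant repeats largely duplicate what tensorflow.python.util.compat (already used by load_op_library above) provides:

from tensorflow.python.util import compat

compat.as_bytes(u"engine_name")   # text -> UTF-8 bytes on Python 2 and 3
compat.as_text(b"OK")             # bytes -> text on Python 2 and 3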
Code example #15
def create_inference_graph(input_graph_def,
                           outputs,
                           max_batch_size=1,
                           max_workspace_size_bytes=2 << 20,
                           precision_mode="FP32",
                           minimum_segment_size=3,
                           is_dynamic_op=False,
                           maximum_cached_engines=1,
                           cached_engine_batches=[]):
    """Python wrapper for the TRT transformation.

  Args:
    input_graph_def: GraphDef object containing a model to be transformed.
    outputs: list of tensors or node names for the model outputs.
    max_batch_size: max size for the input batch
    max_workspace_size_bytes: parameter to control memory allocation (in Bytes)
    precision_mode: one of 'FP32', 'FP16' and 'INT8'
    minimum_segment_size: the minimum number of nodes required for a subgraph to
      be replaced by TRTEngineOp.
    is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
      network and engine at run time.
    maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
    cached_engine_batches: batch sizes used to pre-create cached engines.

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing subgraphs.

  Raises:
    ValueError: if the provided precision mode is invalid.
    RuntimeError: if the returned status message is malformed.
  """
    supported_precision_modes = {"FP32": 0, "FP16": 1, "INT8": 2}
    if precision_mode.upper() not in supported_precision_modes:
        raise ValueError(
            ("precision mode '{}' is not supported. "
             "It should be one of {}").format(precision_mode,
                                              "{'FP32', 'FP16', 'INT8'}"))
    mode = supported_precision_modes[precision_mode.upper()]
    compiled_version = get_linked_tensorrt_version()
    loaded_version = get_loaded_tensorrt_version()
    version_mismatch = False
    if loaded_version[0] < compiled_version[0]:
        tf_logging.error(
            "TensorRT version mismatch. Tensorflow was compiled against " +
            "TensorRT %s but library loaded from environment is TensorRT %s" %
            (".".join([str(x) for x in compiled_version]),
             ".".join([str(x) for x in loaded_version])) +
            ". Please make sure that correct version of TensorRT " +
            "is available in the system and added to ldconfig or LD_LIBRARY_PATH"
        )
        raise RuntimeError("Incompatible TensorRT library version")
    for i in zip(loaded_version, compiled_version):
        if i[0] != i[1]:
            tf_logging.warn("TensorRT mismatch. Compiled against version " +
                            "%s, but loaded %s. Things may not work" %
                            (".".join([str(x) for x in compiled_version]),
                             ".".join([str(x) for x in loaded_version])))
            version_mismatch = True
            break
    if not version_mismatch:
        tf_logging.info("Running against TensorRT version %s" %
                        ".".join([str(x) for x in loaded_version]))

    def py2bytes(inp):
        return inp

    def py3bytes(inp):
        return inp.encode("utf-8", errors="surrogateescape")

    def py2string(inp):
        return inp

    def py3string(inp):
        return inp.decode("utf-8")

    if _six.PY2:
        to_bytes = py2bytes
        to_string = py2string
    else:
        to_bytes = py3bytes
        to_string = py3string

    out_names = []
    for i in outputs:
        if isinstance(i, ops.Tensor):
            out_names.append(to_bytes(i.name))
        else:
            out_names.append(to_bytes(i))

    input_graph_def_str = input_graph_def.SerializeToString()

    # TODO(sami): Fix this when we can return status from C++ library
    # There is a problem with the TF internal library setup that doesn't
    # allow us to return a status object from C++. Thus we return a pair
    # of strings, where the first one is the encoded status and the second
    # one is the transformed graph's protobuf string.
    out = trt_convert(input_graph_def_str, out_names, max_batch_size,
                      max_workspace_size_bytes, mode, minimum_segment_size,
                      is_dynamic_op, maximum_cached_engines,
                      cached_engine_batches)
    status = to_string(out[0])
    output_graph_def_string = out[1]
    del input_graph_def_str  # Save some memory
    if len(status) < 2:
        raise _impl.UnknownError(None, None, status)
    if status[:2] != "OK":
        msg = status.split(";")
        if len(msg) == 1:
            raise RuntimeError("Status message is malformed {}".format(status))
        # pylint: disable=protected-access
        raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                             int(msg[0]))
        # pylint: enable=protected-access
    output_graph_def = graph_pb2.GraphDef()
    output_graph_def.ParseFromString(output_graph_def_string)
    del output_graph_def_string  # Save some memory
    return output_graph_def
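The version guard above hard-fails only when the loaded major version is older than the compiled one, and merely warns on any other component mismatch. The same policy as a condensed sketch, assuming version tuples such as (5, 1, 5):

def check_trt_versions(compiled, loaded):
    """Mirror the policy above: fail on older major, warn on any mismatch."""
    def fmt(v):
        return ".".join(str(x) for x in v)
    if loaded[0] < compiled[0]:
        # A loaded major version older than the compiled one: refuse to run.
        raise RuntimeError("Incompatible TensorRT library version: compiled "
                           "%s, loaded %s" % (fmt(compiled), fmt(loaded)))
    if any(a != b for a, b in zip(loaded, compiled)):
        print("TensorRT mismatch: compiled %s, loaded %s; things may not work"
              % (fmt(compiled), fmt(loaded)))

check_trt_versions((5, 0, 2), (5, 1, 5))  # warns, does not raise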
Code example #16
File: trt_convert.py (project: StephenOman/tensorflow)
def create_inference_graph(input_graph_def,
                           outputs,
                           max_batch_size=1,
                           max_workspace_size_bytes=2 << 20,
                           precision_mode="FP32",
                           minimum_segment_size=3,
                           is_dynamic_op=False,
                           maximum_cached_engines=1,
                           cached_engine_batches=[]):
  """Python wrapper for the TRT transformation.

  Args:
    input_graph_def: GraphDef object containing a model to be transformed.
    outputs: list of tensors or node names for the model outputs.
    max_batch_size: max size for the input batch
    max_workspace_size_bytes: parameter to control memory allocation (in Bytes)
    precision_mode: one of 'FP32', 'FP16' and 'INT8'
    minimum_segment_size: the minimum number of nodes required for a subgraph to
      be replaced by TRTEngineOp.
    is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
      network and engine at run time.
    maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
    cached_engine_batches: batch sizes used to pre-create cached engines.

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing subgraphs.

  Raises:
    ValueError: if the provided precision mode is invalid.
    RuntimeError: if the returned status message is malformed.
  """
  supported_precision_modes = {"FP32": 0, "FP16": 1, "INT8": 2}
  if precision_mode.upper() not in supported_precision_modes:
    raise ValueError(("precision mode '{}' is not supported. "
                      "It should be one of {}").format(
                          precision_mode, "{'FP32', 'FP16', 'INT8'}"))
  mode = supported_precision_modes[precision_mode.upper()]
  compiled_version = get_linked_tensorrt_version()
  loaded_version = get_loaded_tensorrt_version()
  version_mismatch = False
  if loaded_version[0] < compiled_version[0]:
    tf_logging.error(
        "TensorRT version mismatch. Tensorflow was compiled against " +
        "TensorRT %s but library loaded from environment is TensorRT %s" %
        (".".join([str(x) for x in compiled_version]),
         ".".join([str(x) for x in loaded_version])) +
        ". Please make sure that correct version of TensorRT " +
        "is available in the system and added to ldconfig or LD_LIBRARY_PATH"
    )
    raise RuntimeError("Incompatible TensorRT library version")
  for i in zip(loaded_version, compiled_version):
    if i[0] != i[1]:
      tf_logging.warn("TensorRT mismatch. Compiled against version " +
                      "%s, but loaded %s. Things may not work" %
                      (".".join([str(x) for x in compiled_version]),
                       ".".join([str(x) for x in loaded_version])))
      version_mismatch = True
      break
  if not version_mismatch:
    tf_logging.info("Running against TensorRT version %s" % ".".join(
        [str(x) for x in loaded_version]))

  def py2bytes(inp):
    return inp

  def py3bytes(inp):
    return inp.encode("utf-8", errors="surrogateescape")

  def py2string(inp):
    return inp

  def py3string(inp):
    return inp.decode("utf-8")

  if _six.PY2:
    to_bytes = py2bytes
    to_string = py2string
  else:
    to_bytes = py3bytes
    to_string = py3string

  out_names = []
  for i in outputs:
    if isinstance(i, ops.Tensor):
      out_names.append(to_bytes(i.name))
    else:
      out_names.append(to_bytes(i))

  input_graph_def_str = input_graph_def.SerializeToString()

  # TODO(sami): Fix this when we can return status from C++ library
  # There is a problem with the TF internal library setup that doesn't
  # allow us to return a status object from C++. Thus we return a pair
  # of strings, where the first one is the encoded status and the second
  # one is the transformed graph's protobuf string.
  out = trt_convert(input_graph_def_str, out_names, max_batch_size,
                    max_workspace_size_bytes, mode, minimum_segment_size,
                    is_dynamic_op, maximum_cached_engines,
                    cached_engine_batches)
  status = to_string(out[0])
  output_graph_def_string = out[1]
  del input_graph_def_str  # Save some memory
  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    msg = status.split(";")
    if len(msg) == 1:
      raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                         int(msg[0]))
    # pylint: enable=protected-access
  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def
Code example #17
def create_inference_graph(input_graph_def,
                           outputs,
                           max_batch_size=1,
                           max_workspace_size_bytes=2 << 20):
  """Python wrapper for the TRT transormation.


  Args:
    input_graph_def: GraphDef object containing a model to be transformed.
    outputs: List of tensors or node names for the model outputs.
    max_batch_size: max size for the input batch
    max_workspace_size_bytes: parameter to control memory allocation (in Bytes)

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing subgraphs.

  Raises:
    RuntimeError: if the returned status message is malformed.
  """

  def py2bytes(inp):
    return inp

  def py3bytes(inp):
    return inp.encode("utf-8", errors="surrogateescape")

  def py2string(inp):
    return inp

  def py3string(inp):
    return inp.decode("utf-8")

  if _six.PY2:
    to_bytes = py2bytes
    to_string = py2string
  else:
    to_bytes = py3bytes
    to_string = py3string

  out_names = []
  for i in outputs:
    if isinstance(i, ops.Tensor):
      out_names.append(to_bytes(i.name))
    else:
      out_names.append(to_bytes(i))

  input_graph_def_str = input_graph_def.SerializeToString()

  # TODO(sami): Fix this when we can return status from C++ library
  # There is a problem with the TF internal library setup that doesn't
  # allow us to return a status object from C++. Thus we return a pair
  # of strings, where the first one is the encoded status and the second
  # one is the transformed graph's protobuf string.
  out = trt_convert(input_graph_def_str, out_names, max_batch_size,
                    max_workspace_size_bytes)
  status = to_string(out[0])
  output_graph_def_string = out[1]
  del input_graph_def_str  # Save some memory
  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    msg = status.split(";")
    if len(msg) == 1:
      raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                         int(msg[0]))
    # pylint: enable=protected-access
  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def
Code example #18
File: trt_convert.py (project: AndrewTwinz/tensorflow)
def create_inference_graph(input_graph_def,
                           outputs,
                           max_batch_size=1,
                           max_workspace_size_bytes=2 << 20,
                           precision_mode="FP32",
                           minimum_segment_size=3):
  """Python wrapper for the TRT transformation.

  Args:
    input_graph_def: GraphDef object containing a model to be transformed.
    outputs: list of tensors or node names for the model outputs.
    max_batch_size: max size for the input batch
    max_workspace_size_bytes: parameter to control memory allocation (in Bytes)
    precision_mode: one of 'FP32', 'FP16' and 'INT8'
    minimum_segment_size: the minimum number of nodes required for a subgraph to
      be replaced by TRTEngineOp.

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing subgraphs.

  Raises:
    ValueError: if the provided precision mode is invalid.
    RuntimeError: if the returned status message is malformed.
  """
  supported_precision_modes = {"FP32": 0, "FP16": 1, "INT8": 2}
  if precision_mode.upper() not in supported_precision_modes:
    raise ValueError(("precision mode '{}' is not supported. "
                      "It should be one of {}").format(
                          precision_mode, "{'FP32', 'FP16', 'INT8'}"))
  mode = supported_precision_modes[precision_mode.upper()]

  def py2bytes(inp):
    return inp

  def py3bytes(inp):
    return inp.encode("utf-8", errors="surrogateescape")

  def py2string(inp):
    return inp

  def py3string(inp):
    return inp.decode("utf-8")

  if _six.PY2:
    to_bytes = py2bytes
    to_string = py2string
  else:
    to_bytes = py3bytes
    to_string = py3string

  out_names = []
  for i in outputs:
    if isinstance(i, ops.Tensor):
      out_names.append(to_bytes(i.name))
    else:
      out_names.append(to_bytes(i))

  input_graph_def_str = input_graph_def.SerializeToString()

  # TODO(sami): Fix this when we can return status from C++ library
  # There is a problem with the TF internal library setup that doesn't
  # allow us to return a status object from C++. Thus we return a pair
  # of strings, where the first one is the encoded status and the second
  # one is the transformed graph's protobuf string.
  out = trt_convert(input_graph_def_str, out_names, max_batch_size,
                    max_workspace_size_bytes, mode, minimum_segment_size)
  status = to_string(out[0])
  output_graph_def_string = out[1]
  del input_graph_def_str  # Save some memory
  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    msg = status.split(";")
    if len(msg) == 1:
      raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                         int(msg[0]))
    # pylint: enable=protected-access
  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def
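Putting it together for this six-argument variant: load a frozen graph, convert, and use the result. Paths and node names are hypothetical; the API shape follows the TF 1.x tensorflow.contrib.tensorrt module:

import tensorflow as tf
import tensorflow.contrib.tensorrt as trt

# Load a frozen GraphDef from disk (hypothetical model path).
with tf.gfile.GFile("/models/resnet_frozen.pb", "rb") as f:
    frozen_graph_def = tf.GraphDef()
    frozen_graph_def.ParseFromString(f.read())

trt_graph = trt.create_inference_graph(
    input_graph_def=frozen_graph_def,
    outputs=["softmax"],
    max_batch_size=16,
    max_workspace_size_bytes=1 << 30,
    precision_mode="FP16",
    minimum_segment_size=5)   # only replace subgraphs of at least 5 nodes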