def testPickleable(self):
  for error_code in [
      errors.CANCELLED,
      errors.UNKNOWN,
      errors.INVALID_ARGUMENT,
      errors.DEADLINE_EXCEEDED,
      errors.NOT_FOUND,
      errors.ALREADY_EXISTS,
      errors.PERMISSION_DENIED,
      errors.UNAUTHENTICATED,
      errors.RESOURCE_EXHAUSTED,
      errors.FAILED_PRECONDITION,
      errors.ABORTED,
      errors.OUT_OF_RANGE,
      errors.UNIMPLEMENTED,
      errors.INTERNAL,
      errors.UNAVAILABLE,
      errors.DATA_LOSS,
  ]:
    # pylint: disable=protected-access
    exc = errors_impl._make_specific_exception(None, None, None, error_code)
    # pylint: enable=protected-access
    unpickled = pickle.loads(pickle.dumps(exc))
    self.assertEqual(exc.node_def, unpickled.node_def)
    self.assertEqual(exc.op, unpickled.op)
    self.assertEqual(exc.message, unpickled.message)
    self.assertEqual(exc.error_code, unpickled.error_code)
def load_file_system_library(library_filename):
  """Loads a TensorFlow plugin containing a file system implementation.

  Pass `library_filename` to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here.

  Args:
    library_filename: Path to the plugin. Relative or absolute filesystem
      path to a dynamic library file.

  Returns:
    None.

  Raises:
    RuntimeError: when unable to load the library.
  """
  status = py_tf.TF_NewStatus()
  lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
  try:
    error_code = py_tf.TF_GetCode(status)
    if error_code != 0:
      error_msg = compat.as_text(py_tf.TF_Message(status))
      # pylint: disable=protected-access
      raise errors_impl._make_specific_exception(
          None, None, error_msg, error_code)
      # pylint: enable=protected-access
  finally:
    py_tf.TF_DeleteStatus(status)
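# Usage sketch (not from the original source). The plugin path and the
# "myfs://" scheme below are hypothetical stand-ins for a shared library that
# registers a file system via TensorFlow's REGISTER_FILE_SYSTEM mechanism.
load_file_system_library("/tmp/libmy_file_system.so")  # hypothetical path
# Once loaded, paths using the plugin's scheme work with TensorFlow's file
# APIs, e.g. tf.gfile.Exists("myfs://bucket/object").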
def testUniqueClassForEachErrorCode(self):
  for error_code, exc_type in [
      (errors.CANCELLED, errors_impl.CancelledError),
      (errors.UNKNOWN, errors_impl.UnknownError),
      (errors.INVALID_ARGUMENT, errors_impl.InvalidArgumentError),
      (errors.DEADLINE_EXCEEDED, errors_impl.DeadlineExceededError),
      (errors.NOT_FOUND, errors_impl.NotFoundError),
      (errors.ALREADY_EXISTS, errors_impl.AlreadyExistsError),
      (errors.PERMISSION_DENIED, errors_impl.PermissionDeniedError),
      (errors.UNAUTHENTICATED, errors_impl.UnauthenticatedError),
      (errors.RESOURCE_EXHAUSTED, errors_impl.ResourceExhaustedError),
      (errors.FAILED_PRECONDITION, errors_impl.FailedPreconditionError),
      (errors.ABORTED, errors_impl.AbortedError),
      (errors.OUT_OF_RANGE, errors_impl.OutOfRangeError),
      (errors.UNIMPLEMENTED, errors_impl.UnimplementedError),
      (errors.INTERNAL, errors_impl.InternalError),
      (errors.UNAVAILABLE, errors_impl.UnavailableError),
      (errors.DATA_LOSS, errors_impl.DataLossError),
  ]:
    # pylint: disable=protected-access
    self.assertTrue(
        isinstance(
            errors_impl._make_specific_exception(None, None, None, error_code),
            exc_type))
    # error_code_from_exception_type and exception_type_from_error_code
    # should be consistent with each other.
    self.assertEqual(error_code,
                     errors_impl.error_code_from_exception_type(exc_type))
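# Illustration (not from the original source) of the round-trip mapping this
# test exercises, using the errors_impl helpers named in its comment:
assert errors_impl.exception_type_from_error_code(
    errors.NOT_FOUND) is errors_impl.NotFoundError
assert errors_impl.error_code_from_exception_type(
    errors_impl.NotFoundError) == errors.NOT_FOUND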
def testUnknownErrorCodeCausesWarning(self):
  with warnings.catch_warnings(record=True) as w:
    # pylint: disable=protected-access
    exc = errors_impl._make_specific_exception(None, None, None, 37)
    # pylint: enable=protected-access
  self.assertEqual(1, len(w))
  self.assertTrue("Unknown error code: 37" in str(w[0].message))
  self.assertTrue(isinstance(exc, errors_impl.OpError))
def load_op_library(library_filename):
  """Loads a TensorFlow plugin containing custom ops and kernels.

  Pass "library_filename" to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here. When the
  library is loaded, ops and kernels registered in the library via the
  `REGISTER_*` macros are made available in the TensorFlow process. Note
  that ops with the same name as an existing op are rejected and not
  registered with the process.

  Args:
    library_filename: Path to the plugin. Relative or absolute filesystem
      path to a dynamic library file.

  Returns:
    A python module containing the Python wrappers for Ops defined in the
    plugin.

  Raises:
    RuntimeError: when unable to load the library or get the python wrappers.
  """
  status = py_tf.TF_NewStatus()
  lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
  try:
    error_code = py_tf.TF_GetCode(status)
    if error_code != 0:
      error_msg = compat.as_text(py_tf.TF_Message(status))
      # pylint: disable=protected-access
      raise errors_impl._make_specific_exception(
          None, None, error_msg, error_code)
      # pylint: enable=protected-access
  finally:
    py_tf.TF_DeleteStatus(status)

  op_list_str = py_tf.TF_GetOpList(lib_handle)
  op_list = op_def_pb2.OpList()
  op_list.ParseFromString(compat.as_bytes(op_list_str))
  wrappers = py_tf.GetPythonWrappers(op_list_str)

  # Delete the library handle to release any memory held in C
  # that is no longer needed.
  py_tf.TF_DeleteLibraryHandle(lib_handle)

  # Get a unique name for the module.
  module_name = hashlib.md5(wrappers).hexdigest()
  if module_name in sys.modules:
    return sys.modules[module_name]
  module = imp.new_module(module_name)
  # pylint: disable=exec-used
  exec(wrappers, module.__dict__)
  # Stash away the library handle for making calls into the dynamic library.
  module.LIB_HANDLE = lib_handle
  # OpDefs of the list of ops defined in the library.
  module.OP_LIST = op_list
  sys.modules[module_name] = module
  return module
def load_op_library(library_filename):
  """Loads a TensorFlow plugin containing custom ops and kernels.

  Pass "library_filename" to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here. When the
  library is loaded, ops and kernels registered in the library via the
  `REGISTER_*` macros are made available in the TensorFlow process. Note
  that ops with the same name as an existing op are rejected and not
  registered with the process.

  Args:
    library_filename: Path to the plugin. Relative or absolute filesystem
      path to a dynamic library file.

  Returns:
    A python module containing the Python wrappers for Ops defined in the
    plugin.

  Raises:
    RuntimeError: when unable to load the library or get the python wrappers.
  """
  status = py_tf.TF_NewStatus()
  lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
  try:
    error_code = py_tf.TF_GetCode(status)
    if error_code != 0:
      error_msg = compat.as_text(py_tf.TF_Message(status))
      # pylint: disable=protected-access
      raise errors_impl._make_specific_exception(None, None, error_msg,
                                                 error_code)
      # pylint: enable=protected-access
  finally:
    py_tf.TF_DeleteStatus(status)

  op_list_str = py_tf.TF_GetOpList(lib_handle)
  op_list = op_def_pb2.OpList()
  op_list.ParseFromString(compat.as_bytes(op_list_str))
  wrappers = py_tf.GetPythonWrappers(op_list_str)

  # Get a unique name for the module.
  module_name = hashlib.md5(wrappers).hexdigest()
  if module_name in sys.modules:
    return sys.modules[module_name]
  module = imp.new_module(module_name)
  # pylint: disable=exec-used
  exec(wrappers, module.__dict__)
  # Stash away the library handle for making calls into the dynamic library.
  module.LIB_HANDLE = lib_handle
  # OpDefs of the list of ops defined in the library.
  module.OP_LIST = op_list
  sys.modules[module_name] = module
  return module
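# Usage sketch (not from the original source), following the standard custom-op
# workflow: "zero_out.so" and the "ZeroOut" op are the usual tutorial
# placeholders, assumed to be built separately with REGISTER_OP and
# REGISTER_KERNEL_BUILDER.
zero_out_module = load_op_library("./zero_out.so")  # hypothetical library
# The generated wrapper is exposed as a snake_case function on the module.
result = zero_out_module.zero_out([[1, 2], [3, 4]])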
def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
  """Converts an existing calibration graph to an inference graph.

  Args:
    calibration_graph_def: the calibration GraphDef object with calibration
      data.
    is_dynamic_op: whether to create dynamic rather than static engines from
      the calibration data.

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
  Raises:
    RuntimeError: if the returned status message is malformed.
  """

  def py2string(inp):
    return inp

  def py3string(inp):
    return inp.decode("utf-8")

  if _six.PY2:
    to_string = py2string
  else:
    to_string = py3string

  is_calib_graph = False
  for n in calibration_graph_def.node:
    if n.op == "TRTEngineOp":
      is_calib_graph = is_calib_graph or not n.attr["calibration_data"].s
  if not is_calib_graph:
    tf_logging.error(
        "Not a calib graph. Doesn't seem to contain any calibration nodes.")
    return None

  graph_str = calibration_graph_def.SerializeToString()
  out = calib_convert(graph_str, is_dynamic_op)
  status = to_string(out[0])
  output_graph_def_string = out[1]
  del graph_str  # Save some memory
  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    msg = status.split(";")
    if len(msg) == 1:
      raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                         int(msg[0]))
    # pylint: enable=protected-access
  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def
def testUnknownErrorCodeCausesWarning(self):
  with warnings.catch_warnings(record=True) as w:
    # pylint: disable=protected-access
    exc = errors_impl._make_specific_exception(None, None, None, 37)
    # pylint: enable=protected-access
  self.assertEqual(1, len(w))
  self.assertTrue("Unknown error code: 37" in str(w[0].message))
  self.assertTrue(isinstance(exc, errors_impl.OpError))

  with warnings.catch_warnings(record=True) as w:
    # pylint: disable=protected-access
    exc = errors_impl.error_code_from_exception_type("Unknown")
    # pylint: enable=protected-access
  self.assertEqual(1, len(w))
  self.assertTrue("Unknown class exception" in str(w[0].message))
  self.assertTrue(isinstance(exc, errors_impl.OpError))
def testKnownErrorClassForEachErrorCodeInProto(self):
  for error_code in error_codes_pb2.Code.values():
    # pylint: disable=line-too-long
    if error_code in (
        error_codes_pb2.OK,
        error_codes_pb2.DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_
    ):
      continue
    # pylint: enable=line-too-long
    with warnings.catch_warnings(record=True) as w:
      # pylint: disable=protected-access
      exc = errors_impl._make_specific_exception(None, None, None, error_code)
      # pylint: enable=protected-access
    self.assertEqual(0, len(w))  # No warning is raised.
    self.assertTrue(isinstance(exc, errors_impl.OpError))
    self.assertTrue(errors_impl.OpError in exc.__class__.__bases__)
def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
  """Converts an existing calibration graph to an inference graph.

  Args:
    calibration_graph_def: the calibration GraphDef object with calibration
      data.
    is_dynamic_op: whether to create dynamic rather than static engines from
      the calibration data.

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
  Raises:
    RuntimeError: if the returned status message is malformed.
  """
  # Lazily load the TF-TRT C bindings, so `import tensorflow` doesn't complain
  # even if it cannot find the TensorRT library.
  trt_ops.load_trt_ops()
  # pylint: disable=g-import-not-at-top,line-too-long
  from tensorflow.python.compiler.tensorrt.wrap_conversion import calib_convert
  # pylint: enable=g-import-not-at-top,line-too-long

  is_calib_graph = False
  for n in calibration_graph_def.node:
    if n.op == "TRTEngineOp":
      is_calib_graph = is_calib_graph or not n.attr["calibration_data"].s
  if not is_calib_graph:
    tf_logging.error(
        "Not a calib graph. Doesn't seem to contain any calibration nodes.")
    return None

  graph_str = calibration_graph_def.SerializeToString()
  out = calib_convert(graph_str, is_dynamic_op)
  status = _to_string(out[0])
  output_graph_def_string = out[1]
  del graph_str  # Save some memory
  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    msg = status.split(";")
    if len(msg) == 1:
      raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                         int(msg[0]))
    # pylint: enable=protected-access
  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def
def calib_graph_to_infer_graph(calibration_graph_def):
  """Converts an existing calibration graph to an inference graph.

  Args:
    calibration_graph_def: the calibration GraphDef object with calibration
      data.

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
  Raises:
    RuntimeError: if the returned status message is malformed.
  """

  def py2string(inp):
    return inp

  def py3string(inp):
    return inp.decode("utf-8")

  if _six.PY2:
    to_string = py2string
  else:
    to_string = py3string

  graph_str = calibration_graph_def.SerializeToString()
  out = calib_convert(graph_str)
  status = to_string(out[0])
  output_graph_def_string = out[1]
  del graph_str  # Save some memory
  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    msg = status.split(";")
    if len(msg) == 1:
      raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                         int(msg[0]))
    # pylint: enable=protected-access
  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def
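# Usage sketch (not from the original source) of the INT8 workflow this
# function completes. `calib_graph` is assumed to be a GraphDef produced by
# create_inference_graph(..., precision_mode="INT8") after its calibration
# nodes have been run on representative inputs; the output path is
# hypothetical.
infer_graph = calib_graph_to_infer_graph(calib_graph)
if infer_graph is not None:
  with open("/tmp/trt_int8_graph.pb", "wb") as f:
    f.write(infer_graph.SerializeToString())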
def testUniqueClassForEachErrorCode(self):
  for error_code, exc_type in [
      (errors.CANCELLED, errors_impl.CancelledError),
      (errors.UNKNOWN, errors_impl.UnknownError),
      (errors.INVALID_ARGUMENT, errors_impl.InvalidArgumentError),
      (errors.DEADLINE_EXCEEDED, errors_impl.DeadlineExceededError),
      (errors.NOT_FOUND, errors_impl.NotFoundError),
      (errors.ALREADY_EXISTS, errors_impl.AlreadyExistsError),
      (errors.PERMISSION_DENIED, errors_impl.PermissionDeniedError),
      (errors.UNAUTHENTICATED, errors_impl.UnauthenticatedError),
      (errors.RESOURCE_EXHAUSTED, errors_impl.ResourceExhaustedError),
      (errors.FAILED_PRECONDITION, errors_impl.FailedPreconditionError),
      (errors.ABORTED, errors_impl.AbortedError),
      (errors.OUT_OF_RANGE, errors_impl.OutOfRangeError),
      (errors.UNIMPLEMENTED, errors_impl.UnimplementedError),
      (errors.INTERNAL, errors_impl.InternalError),
      (errors.UNAVAILABLE, errors_impl.UnavailableError),
      (errors.DATA_LOSS, errors_impl.DataLossError),
  ]:
    # pylint: disable=protected-access
    self.assertTrue(
        isinstance(
            errors_impl._make_specific_exception(None, None, None, error_code),
            exc_type))
def create_inference_graph(input_graph_def,
                           outputs,
                           max_batch_size=1,
                           max_workspace_size_bytes=2 << 20):
  """Python wrapper for the TRT transformation.

  Args:
    input_graph_def: GraphDef object containing a model to be transformed.
    outputs: list of tensors or node names for the model outputs.
    max_batch_size: max size for the input batch
    max_workspace_size_bytes: parameter to control memory allocation (in Bytes)

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing subgraphs.

  Raises:
    RuntimeError: if the returned status message is malformed.
  """

  def py2bytes(inp):
    return inp

  def py3bytes(inp):
    return inp.encode("utf-8", errors="surrogateescape")

  def py2string(inp):
    return inp

  def py3string(inp):
    return inp.decode("utf-8")

  if _six.PY2:
    to_bytes = py2bytes
    to_string = py2string
  else:
    to_bytes = py3bytes
    to_string = py3string

  out_names = []
  for i in outputs:
    if isinstance(i, ops.Tensor):
      out_names.append(to_bytes(i.name))
    else:
      out_names.append(to_bytes(i))

  input_graph_def_str = input_graph_def.SerializeToString()

  # TODO(sami): Fix this when we can return status from C++ library
  # There is a problem with the TF internal library setup that doesn't
  # allow us to return a status object from C++. Thus we return a
  # pair of strings where the first one is the encoded status and the
  # second one is the transformed graph's protobuf string.
  out = trt_convert(input_graph_def_str, out_names, max_batch_size,
                    max_workspace_size_bytes)
  status = to_string(out[0])
  output_graph_def_string = out[1]
  del input_graph_def_str  # Save some memory
  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    msg = status.split(";")
    if len(msg) == 1:
      raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                         int(msg[0]))
    # pylint: enable=protected-access
  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def
def create_inference_graph(input_graph_def,
                           outputs,
                           max_batch_size=1,
                           max_workspace_size_bytes=2 << 20,
                           precision_mode="FP32",
                           minimum_segment_size=3):
  """Python wrapper for the TRT transformation.

  Args:
    input_graph_def: GraphDef object containing a model to be transformed.
    outputs: list of tensors or node names for the model outputs.
    max_batch_size: max size for the input batch
    max_workspace_size_bytes: parameter to control memory allocation (in Bytes)
    precision_mode: one of 'FP32', 'FP16' and 'INT8'
    minimum_segment_size: the minimum number of nodes required for a subgraph
      to be replaced by TRTEngineOp.

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing subgraphs.

  Raises:
    ValueError: if the provided precision mode is invalid.
    RuntimeError: if the returned status message is malformed.
  """
  supported_precision_modes = {"FP32": 0, "FP16": 1, "INT8": 2}
  if precision_mode.upper() not in supported_precision_modes:
    raise ValueError(("precision mode '{}' is not supported. "
                      "It should be one of {}").format(
                          precision_mode, "{'FP32', 'FP16', 'INT8'}"))
  mode = supported_precision_modes[precision_mode.upper()]

  def py2bytes(inp):
    return inp

  def py3bytes(inp):
    return inp.encode("utf-8", errors="surrogateescape")

  def py2string(inp):
    return inp

  def py3string(inp):
    return inp.decode("utf-8")

  if _six.PY2:
    to_bytes = py2bytes
    to_string = py2string
  else:
    to_bytes = py3bytes
    to_string = py3string

  out_names = []
  for i in outputs:
    if isinstance(i, ops.Tensor):
      out_names.append(to_bytes(i.name))
    else:
      out_names.append(to_bytes(i))

  input_graph_def_str = input_graph_def.SerializeToString()

  # TODO(sami): Fix this when we can return status from C++ library
  # There is a problem with the TF internal library setup that doesn't
  # allow us to return a status object from C++. Thus we return a
  # pair of strings where the first one is the encoded status and the
  # second one is the transformed graph's protobuf string.
  out = trt_convert(input_graph_def_str, out_names, max_batch_size,
                    max_workspace_size_bytes, mode, minimum_segment_size)
  status = to_string(out[0])
  output_graph_def_string = out[1]
  del input_graph_def_str  # Save some memory
  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    msg = status.split(";")
    if len(msg) == 1:
      raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                         int(msg[0]))
    # pylint: enable=protected-access
  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def
def create_inference_graph(input_graph_def,
                           outputs,
                           max_batch_size=1,
                           max_workspace_size_bytes=2 << 20,
                           precision_mode="FP32",
                           minimum_segment_size=3,
                           is_dynamic_op=False,
                           maximum_cached_engines=1,
                           cached_engine_batches=[]):
  """Python wrapper for the TRT transformation.

  Args:
    input_graph_def: GraphDef object containing a model to be transformed.
    outputs: list of tensors or node names for the model outputs.
    max_batch_size: max size for the input batch
    max_workspace_size_bytes: parameter to control memory allocation (in Bytes)
    precision_mode: one of 'FP32', 'FP16' and 'INT8'
    minimum_segment_size: the minimum number of nodes required for a subgraph
      to be replaced by TRTEngineOp.
    is_dynamic_op: whether to generate dynamic TRT ops which will build the
      TRT network and engine at run time.
    maximum_cached_engines: max number of cached TRT engines in dynamic TRT
      ops.
    cached_engine_batches: batch sizes used to pre-create cached engines.

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing subgraphs.

  Raises:
    ValueError: if the provided precision mode is invalid.
    RuntimeError: if the returned status message is malformed.
  """
  supported_precision_modes = {"FP32": 0, "FP16": 1, "INT8": 2}
  if precision_mode.upper() not in supported_precision_modes:
    raise ValueError(("precision mode '{}' is not supported. "
                      "It should be one of {}").format(
                          precision_mode, "{'FP32', 'FP16', 'INT8'}"))
  mode = supported_precision_modes[precision_mode.upper()]

  compiled_version = get_linked_tensorrt_version()
  loaded_version = get_loaded_tensorrt_version()
  version_mismatch = False
  if loaded_version[0] < compiled_version[0]:
    tf_logging.error(
        "TensorRT version mismatch. Tensorflow was compiled against " +
        "TensorRT %s but library loaded from environment is TensorRT %s" %
        (".".join([str(x) for x in compiled_version]),
         ".".join([str(x) for x in loaded_version])) +
        ". Please make sure that the correct version of TensorRT " +
        "is available in the system and added to ldconfig or LD_LIBRARY_PATH")
    raise RuntimeError("Incompatible TensorRT library version")
  for i in zip(loaded_version, compiled_version):
    if i[0] != i[1]:
      tf_logging.warn("TensorRT mismatch. Compiled against version " +
                      "%s, but loaded %s. Things may not work" %
                      (".".join([str(x) for x in compiled_version]),
                       ".".join([str(x) for x in loaded_version])))
      version_mismatch = True
      break
  if not version_mismatch:
    tf_logging.info("Running against TensorRT version %s" %
                    ".".join([str(x) for x in loaded_version]))

  def py2bytes(inp):
    return inp

  def py3bytes(inp):
    return inp.encode("utf-8", errors="surrogateescape")

  def py2string(inp):
    return inp

  def py3string(inp):
    return inp.decode("utf-8")

  if _six.PY2:
    to_bytes = py2bytes
    to_string = py2string
  else:
    to_bytes = py3bytes
    to_string = py3string

  out_names = []
  for i in outputs:
    if isinstance(i, ops.Tensor):
      out_names.append(to_bytes(i.name))
    else:
      out_names.append(to_bytes(i))

  input_graph_def_str = input_graph_def.SerializeToString()

  # TODO(sami): Fix this when we can return status from C++ library
  # There is a problem with the TF internal library setup that doesn't
  # allow us to return a status object from C++. Thus we return a
  # pair of strings where the first one is the encoded status and the
  # second one is the transformed graph's protobuf string.
  out = trt_convert(input_graph_def_str, out_names, max_batch_size,
                    max_workspace_size_bytes, mode, minimum_segment_size,
                    is_dynamic_op, maximum_cached_engines,
                    cached_engine_batches)
  status = to_string(out[0])
  output_graph_def_string = out[1]
  del input_graph_def_str  # Save some memory
  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    msg = status.split(";")
    if len(msg) == 1:
      raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                         int(msg[0]))
    # pylint: enable=protected-access
  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def
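# Usage sketch (not from the original source). `frozen_graph_def` and the
# output name "logits" are hypothetical; any frozen GraphDef and its output
# node names would do. Conversion errors raised by the C++ side surface as the
# specific OpError subclass reconstructed from the "code;message" status
# string handled above.
trt_graph = create_inference_graph(
    input_graph_def=frozen_graph_def,
    outputs=["logits"],
    max_batch_size=8,
    max_workspace_size_bytes=1 << 30,  # 1 GiB of TRT workspace
    precision_mode="FP16",
    is_dynamic_op=True,
    maximum_cached_engines=3)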