def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
  """Convert an existing calibration graph to inference graph.

  Args:
    calibration_graph_def: the calibration GraphDef object with calibration
      data.
    is_dynamic_op: flag forwarded unchanged to `calib_convert`; presumably
      selects whether the resulting TRTEngineOps are dynamic ops rather than
      static engines (TODO confirm against `calib_convert`).

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing calibration
    nodes. Returns None (after logging an error) when no node in
    `calibration_graph_def` is a TRTEngineOp with an absent or empty
    `calibration_data` attribute.

  Raises:
    RuntimeError: if the returned status message is malformed (no ";"
      separator, or a non-integer error code).
    _impl.UnknownError: if the returned status string is shorter than two
      characters.
    Other TensorFlow errors built by `_impl._make_specific_exception` from
      the error code and message in a non-OK status.
  """

  def py2string(inp):
    return inp

  def py3string(inp):
    return inp.decode("utf-8")

  # Under Python 3 the status is decoded from UTF-8 bytes; under Python 2 it
  # is used as-is.
  if _six.PY2:
    to_string = py2string
  else:
    to_string = py3string

  # A calibration node is a TRTEngineOp whose calibration_data attribute is
  # absent or empty. Test membership before indexing: subscripting a protobuf
  # map field with a missing key inserts a default entry, which would
  # silently mutate the caller's GraphDef.
  is_calib_graph = any(
      n.op == "TRTEngineOp" and
      ("calibration_data" not in n.attr or
       not n.attr["calibration_data"].s)
      for n in calibration_graph_def.node)
  if not is_calib_graph:
    tf_logging.error(
        "Not a calib graph. Doesn't seem to contain any calibration nodes.")
    return None

  graph_str = calibration_graph_def.SerializeToString()
  out = calib_convert(graph_str, is_dynamic_op)
  status = to_string(out[0])
  output_graph_def_string = out[1]
  del graph_str  # Save some memory

  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    # A non-OK status is parsed as "<integer error code>;<message>".
    msg = status.split(";")
    if len(msg) == 1:
      raise RuntimeError("Status message is malformed {}".format(status))
    try:
      error_code = int(msg[0])
    except ValueError:
      # A non-numeric code is also malformed; report it as documented instead
      # of leaking a bare ValueError.
      raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                         error_code)
    # pylint: enable=protected-access

  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def
def calib_graph_to_infer_graph(calibration_graph_def):
  """Convert an existing calibration graph to inference graph.

  Args:
    calibration_graph_def: the calibration GraphDef object with calibration
      data.

  Returns:
    New GraphDef with TRTEngineOps placed in graph replacing calibration
    nodes.

  Raises:
    RuntimeError: if the returned status message is malformed (no ";"
      separator, or a non-integer error code).
    _impl.UnknownError: if the returned status string is shorter than two
      characters.
    Other TensorFlow errors built by `_impl._make_specific_exception` from
      the error code and message in a non-OK status.
  """

  def py2string(inp):
    return inp

  def py3string(inp):
    return inp.decode("utf-8")

  # Under Python 3 the status is decoded from UTF-8 bytes; under Python 2 it
  # is used as-is.
  if _six.PY2:
    to_string = py2string
  else:
    to_string = py3string

  graph_str = calibration_graph_def.SerializeToString()
  out = calib_convert(graph_str)
  status = to_string(out[0])
  output_graph_def_string = out[1]
  del graph_str  # Save some memory

  if len(status) < 2:
    raise _impl.UnknownError(None, None, status)
  if status[:2] != "OK":
    # A non-OK status is parsed as "<integer error code>;<message>".
    msg = status.split(";")
    if len(msg) == 1:
      raise RuntimeError("Status message is malformed {}".format(status))
    try:
      error_code = int(msg[0])
    except ValueError:
      # A non-numeric code is also malformed; report it as documented instead
      # of leaking a bare ValueError.
      raise RuntimeError("Status message is malformed {}".format(status))
    # pylint: disable=protected-access
    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
                                         error_code)
    # pylint: enable=protected-access

  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.ParseFromString(output_graph_def_string)
  del output_graph_def_string  # Save some memory
  return output_graph_def