def __init__(self, save_path: str):
    """Set up an empty saved-model description rooted at ``save_path``.

    Args:
        save_path: Directory the saved model is associated with; stored
            as ``saved_model_dir_``.

    Raises:
        ValueError: If ``save_path`` is not a ``str``.
    """
    if not isinstance(save_path, str):
        raise ValueError(
            "param 'save_path' must be str, but got {}".format(save_path))

    # No version has been chosen yet.
    self.version_ = None
    self.checkpoint_dir_ = self.DEFAULT_CHECKPOINT_DIR
    self.saved_model_dir_ = save_path

    # The binary and text meta files share the same basename and differ
    # only in their extension.
    meta_basename = self.DEFAULT_SAVED_MODEL_FILE_BASENAME
    self.saved_model_pb_filename_ = "{}.pb".format(meta_basename)
    self.saved_model_pbtxt_filename_ = "{}.prototxt".format(meta_basename)

    self.saved_model_proto_ = saved_model_pb.SavedModel()
    self.graph_builders_ = {}
def load_saved_model(
    self,
    saved_model_dir,
    model_version=ModelVersionPolicy.LATEST,
    saved_model_meta_file_basename="saved_model",
    graph_name=None,
    signature_name=None,
):
    """Load a saved model from ``saved_model_dir`` and compile one graph.

    The model is expected at ``<saved_model_dir>/<model_version>/`` and to
    contain a meta file named ``<basename>.pb`` (binary protobuf) or
    ``<basename>.prototxt`` (text protobuf); the binary file is preferred
    when both exist.

    Args:
        saved_model_dir: Directory holding one sub-directory per version.
        model_version: An ``int`` version number, or
            ``ModelVersionPolicy.LATEST`` to let
            ``_find_model_latest_version`` choose the version.
        saved_model_meta_file_basename: Meta file name without extension.
        graph_name: Graph to load; when ``None`` the proto's
            ``default_graph_name`` is used.
        signature_name: Signature to open the graph with; when ``None``
            the graph's ``default_signature_name`` is used if it is set,
            otherwise no signature is passed.

    Raises:
        ValueError: If the model directory, the version directory, the
            meta file, the graph or the signature does not exist.
        NotImplementedError: If ``model_version`` is neither an ``int``
            nor ``ModelVersionPolicy.LATEST``.
    """
    if not os.path.isdir(saved_model_dir):
        raise ValueError(
            "{} is not a valid directory".format(saved_model_dir))

    # An explicit int version is used as is; LATEST is resolved on disk.
    if isinstance(model_version, int):
        pass
    elif model_version == ModelVersionPolicy.LATEST:
        model_version = _find_model_latest_version(saved_model_dir)
    else:
        raise NotImplementedError(
            "unsupported model version policy: {!r}".format(model_version))

    saved_model_path = os.path.join(saved_model_dir, str(model_version))
    if not os.path.isdir(saved_model_path):
        raise ValueError(
            "version {} of saved model in dir {} do not exist".format(
                model_version, saved_model_dir))

    # os.listdir already returns a list, so no extra copy is needed.
    subfiles = os.listdir(saved_model_path)
    saved_model_meta_pb_filename = saved_model_meta_file_basename + ".pb"
    saved_model_meta_prototxt_filename = (
        saved_model_meta_file_basename + ".prototxt")

    saved_model_proto = saved_model_pb.SavedModel()
    if saved_model_meta_pb_filename in subfiles:
        saved_model_meta_file_path = os.path.join(
            saved_model_path, saved_model_meta_pb_filename)
        with open(saved_model_meta_file_path, "rb") as f:
            saved_model_proto.ParseFromString(f.read())
    elif saved_model_meta_prototxt_filename in subfiles:
        saved_model_meta_file_path = os.path.join(
            saved_model_path, saved_model_meta_prototxt_filename)
        with open(saved_model_meta_file_path, "rt") as f:
            text_format.Merge(f.read(), saved_model_proto)
    else:
        raise ValueError(
            "saved model meta file {} do not exist in {}".format(
                saved_model_meta_file_basename, saved_model_path))

    # set checkpoint
    self.set_checkpoint_path(
        os.path.join(saved_model_path, saved_model_proto.checkpoint_dir))

    # get graph
    if graph_name is None:
        graph_name = saved_model_proto.default_graph_name
    # Validate the default graph name as well as an explicit one.
    # NOTE(review): assuming `graphs` is a protobuf map field, indexing a
    # missing key would insert and return an empty entry instead of
    # raising, so an unknown default name must be rejected here.
    if graph_name not in saved_model_proto.graphs:
        raise ValueError("graph {} do not exist".format(graph_name))
    graph_def = saved_model_proto.graphs[graph_name]

    # get signature
    signature = None
    if signature_name is None and graph_def.HasField(
            "default_signature_name"):
        signature_name = graph_def.default_signature_name
    if signature_name is not None:
        if signature_name not in graph_def.signatures:
            raise ValueError(
                "signature {} do not exist".format(signature_name))
        signature = graph_def.signatures[signature_name]

    # compile job
    with self.open(graph_name, signature):
        self.compile(graph_def.op_list)
def load_saved_model(model_meta_file_path):
    """Read a text-format SavedModel meta file into a new proto message.

    Args:
        model_meta_file_path: Path of the meta file to read.

    Returns:
        A ``saved_model_pb.SavedModel`` message with the file's content
        merged into it.
    """
    # The file is read as raw bytes and handed to the text-format parser.
    with open(model_meta_file_path, "rb") as meta_file:
        raw_content = meta_file.read()

    proto = saved_model_pb.SavedModel()
    text_format.Merge(raw_content, proto)
    return proto