Esempio n. 1
0
    def __init__(self, save_path: str):
        """Initialize builder state for a saved model rooted at ``save_path``.

        Args:
            save_path: Directory the saved model is associated with; stored
                as ``saved_model_dir_``. Must be a ``str``.

        Raises:
            ValueError: If ``save_path`` is not a ``str``.
        """
        # Reject non-string paths up front, before any state is assigned.
        if not isinstance(save_path, str):
            message = "param 'save_path' must be str, but got {}".format(
                save_path)
            raise ValueError(message)

        # No version chosen yet.
        self.version_ = None
        self.checkpoint_dir_ = self.DEFAULT_CHECKPOINT_DIR
        self.saved_model_dir_ = save_path

        # Binary and text meta file names share one class-level basename.
        meta_basename = self.DEFAULT_SAVED_MODEL_FILE_BASENAME
        self.saved_model_pb_filename_ = "{}.pb".format(meta_basename)
        self.saved_model_pbtxt_filename_ = "{}.prototxt".format(meta_basename)

        # Empty SavedModel proto to be populated later.
        self.saved_model_proto_ = saved_model_pb.SavedModel()
        # NOTE(review): presumably keyed by graph name — confirm against
        # the methods that populate it.
        self.graph_builders_ = dict()
Esempio n. 2
0
    def load_saved_model(
        self,
        saved_model_dir,
        model_version=ModelVersionPolicy.LATEST,
        saved_model_meta_file_basename="saved_model",
        graph_name=None,
        signature_name=None,
    ):
        """Load a saved model's meta file and compile one of its graphs.

        Expected on-disk layout::

            <saved_model_dir>/<model_version>/<basename>.pb        (binary)
            <saved_model_dir>/<model_version>/<basename>.prototxt  (text)

        When both meta files exist, the binary ``.pb`` one is used.

        Args:
            saved_model_dir: Directory containing one subdirectory per
                model version.
            model_version: An ``int`` version, or ``ModelVersionPolicy.LATEST``
                to use the version found by ``_find_model_latest_version``.
            saved_model_meta_file_basename: Meta file name without extension.
            graph_name: Graph to compile. Defaults to the meta proto's
                ``default_graph_name``.
            signature_name: Signature to open the graph with. Defaults to the
                graph's ``default_signature_name`` when that field is set;
                otherwise no signature is used.

        Raises:
            ValueError: If the directory, version, meta file, graph or
                signature does not exist.
            NotImplementedError: If ``model_version`` is neither an ``int``
                nor ``ModelVersionPolicy.LATEST``.
        """
        if not os.path.isdir(saved_model_dir):
            raise ValueError(
                "{} is not a valid directory".format(saved_model_dir))

        # Resolve the version policy to a concrete version.
        if isinstance(model_version, int):
            pass
        elif model_version == ModelVersionPolicy.LATEST:
            model_version = _find_model_latest_version(saved_model_dir)
        else:
            raise NotImplementedError(
                "unsupported model_version: {!r}".format(model_version))

        saved_model_path = os.path.join(saved_model_dir, str(model_version))
        if not os.path.isdir(saved_model_path):
            raise ValueError(
                "version {} of saved model in dir {} do not exist".format(
                    model_version, saved_model_dir))

        # A set gives O(1) membership tests for the file-name checks below.
        subfiles = set(os.listdir(saved_model_path))
        saved_model_meta_pb_filename = saved_model_meta_file_basename + ".pb"
        saved_model_meta_prototxt_filename = (saved_model_meta_file_basename +
                                              ".prototxt")
        saved_model_proto = saved_model_pb.SavedModel()
        if saved_model_meta_pb_filename in subfiles:
            saved_model_meta_file_path = os.path.join(
                saved_model_path, saved_model_meta_pb_filename)
            with open(saved_model_meta_file_path, "rb") as f:
                saved_model_proto.ParseFromString(f.read())
        elif saved_model_meta_prototxt_filename in subfiles:
            saved_model_meta_file_path = os.path.join(
                saved_model_path, saved_model_meta_prototxt_filename)
            # Decode explicitly as UTF-8 instead of relying on the platform's
            # locale-dependent default text encoding.
            with open(saved_model_meta_file_path, "rt",
                      encoding="utf-8") as f:
                text_format.Merge(f.read(), saved_model_proto)
        else:
            raise ValueError(
                "saved model meta file {} do not exist in {}".format(
                    saved_model_meta_file_basename, saved_model_path))

        # set checkpoint (relative to the versioned model directory)
        self.set_checkpoint_path(
            os.path.join(saved_model_path, saved_model_proto.checkpoint_dir))

        # get graph
        if graph_name is None:
            graph_name = saved_model_proto.default_graph_name
        # Validate the resolved name in both cases. Indexing a protobuf
        # message map with a missing key inserts an empty entry rather than
        # raising, which would silently compile an empty graph.
        if graph_name not in saved_model_proto.graphs:
            raise ValueError("graph {} do not exist".format(graph_name))
        graph_def = saved_model_proto.graphs[graph_name]

        # get signature
        signature = None
        if signature_name is None and graph_def.HasField(
                "default_signature_name"):
            signature_name = graph_def.default_signature_name
        if signature_name is not None:
            if signature_name not in graph_def.signatures:
                raise ValueError(
                    "signature {} do not exist".format(signature_name))
            signature = graph_def.signatures[signature_name]

        # compile job
        with self.open(graph_name, signature):
            self.compile(graph_def.op_list)
Esempio n. 3
0
def load_saved_model(model_meta_file_path):
    """Parse a text-format SavedModel meta file into a new proto.

    The file is read in binary mode, and its raw bytes are passed to
    ``text_format.Merge``.

    Args:
        model_meta_file_path: Path to the text-format meta file.

    Returns:
        A new ``saved_model_pb.SavedModel`` populated from the file.
    """
    proto = saved_model_pb.SavedModel()
    with open(model_meta_file_path, "rb") as meta_file:
        raw_text = meta_file.read()
    text_format.Merge(raw_text, proto)
    return proto