コード例 #1
0
ファイル: train.py プロジェクト: junkilee/simple_baselines
def maybe_load_model(savedir):
    """Load the latest training checkpoint from ``savedir`` if one exists.

    Looks for ``<savedir>/training_state.pkl.zip``. When it is present, the
    state is unpickled and the weights are restored from
    ``<savedir>/model-<num_iters>/saved``.

    Args:
        savedir: Checkpoint directory, or ``None`` to skip loading entirely.

    Returns:
        The loaded training state (indexed by ``"num_iters"``) when a
        checkpoint was found and restored; ``None`` otherwise.
    """
    if savedir is None:
        return None

    state_path = os.path.join(savedir, 'training_state.pkl.zip')
    if not os.path.exists(state_path):
        return None

    # NOTE(review): pickle_load presumably unpickles this file, which can
    # execute arbitrary code -- only point savedir at trusted checkpoints.
    state = pickle_load(state_path, compression=True)
    num_iters = state["num_iters"]
    # Weights live in a sub-directory named after the iteration count.
    model_dir = "model-{}".format(num_iters)
    U.load_state(os.path.join(savedir, model_dir, "saved"))
    logger.log("Loaded models checkpoint at {} iterations".format(num_iters))
    return state
コード例 #2
0
ファイル: workers.py プロジェクト: akashin/baselines
    def load(self, path, session):
        """Restore a checkpoint from ``path`` into ``session`` if one exists.

        Looks for ``<path>/training_state.pkl.zip``. When it is present, the
        weights are restored from ``<path>/model-<num_iters>/saved`` into
        ``session``, and ``self.num_iters`` is set from the loaded state.

        Args:
            path: Checkpoint directory, or ``None`` to skip loading entirely.
            session: Session passed through to ``U.load_state``.

        Returns:
            None.
        """
        if path is None:
            return

        state_path = os.path.join(path, 'training_state.pkl.zip')
        if not os.path.exists(state_path):
            return

        # NOTE(review): pickle_load presumably unpickles this file, which can
        # execute arbitrary code -- only load checkpoints you trust.
        state = pickle_load(state_path, compression=True)
        # Subscripting here already requires a real state object, so the
        # former "if state is not None" check further down could never fail.
        num_iters = state["num_iters"]
        model_dir = "model-{}".format(num_iters)
        U.load_state(os.path.join(path, model_dir, "saved"), session=session)
        self.logger.log(
            "Loaded models checkpoint at {} iterations".format(num_iters))
        self.num_iters = num_iters
コード例 #3
0
ファイル: train.py プロジェクト: musiclicn/baselines
def maybe_load_model(savedir, container):
    """Load the latest training checkpoint from ``savedir`` if one exists.

    When ``container`` is given, ``training_state.pkl.zip`` and the matching
    ``model-<num_iters>`` directory are first requested via ``container.get``.
    The truthiness of the first ``get`` result decides whether a checkpoint
    was found. Without a container, the local file's existence decides it.

    Args:
        savedir: Checkpoint directory, or ``None`` to skip loading entirely.
        container: Remote storage helper (logged as Azure), or ``None`` to
            use only the local filesystem.

    Returns:
        The loaded training state (indexed by ``"num_iters"``) when a
        checkpoint was found and restored; ``None`` otherwise.
    """
    if savedir is None:
        return None

    state_path = os.path.join(savedir, 'training_state.pkl.zip')
    if container is not None:
        logger.log("Attempting to download model from Azure")
        # NOTE(review): the state is read from the local state_path below, so
        # container.get presumably downloads into savedir -- confirm.
        found_model = container.get(savedir, 'training_state.pkl.zip')
    else:
        found_model = os.path.exists(state_path)
    if not found_model:
        return None

    # NOTE(review): pickle_load presumably unpickles this file, which can
    # execute arbitrary code -- only load checkpoints you trust.
    state = pickle_load(state_path, compression=True)
    num_iters = state["num_iters"]
    model_dir = "model-{}".format(num_iters)
    if container is not None:
        # Fetch the weight directory matching the downloaded state as well.
        container.get(savedir, model_dir)
    U.load_state(os.path.join(savedir, model_dir, "saved"))
    logger.log("Loaded models checkpoint at {} iterations".format(num_iters))
    return state
コード例 #4
0
ファイル: train.py プロジェクト: shakenes/baselines
def maybe_load_model(savedir, container):
    """Load the latest training checkpoint from ``savedir`` if one exists.

    When ``container`` is given, ``training_state.pkl.zip`` and the matching
    ``model-<num_iters>`` directory are first requested via ``container.get``.
    The truthiness of the first ``get`` result decides whether a checkpoint
    was found. Without a container, the local file's existence decides it.

    Args:
        savedir: Checkpoint directory, or ``None`` to skip loading entirely.
        container: Remote storage helper (logged as Azure), or ``None`` to
            use only the local filesystem.

    Returns:
        The loaded training state (indexed by ``"num_iters"``) when a
        checkpoint was found and restored; ``None`` otherwise.
    """
    if savedir is None:
        return None

    state_path = os.path.join(savedir, 'training_state.pkl.zip')
    if container is not None:
        logger.log("Attempting to download model from Azure")
        # NOTE(review): the state is read from the local state_path below, so
        # container.get presumably downloads into savedir -- confirm.
        found_model = container.get(savedir, 'training_state.pkl.zip')
    else:
        found_model = os.path.exists(state_path)
    if not found_model:
        return None

    # NOTE(review): pickle_load presumably unpickles this file, which can
    # execute arbitrary code -- only load checkpoints you trust.
    state = pickle_load(state_path, compression=True)
    num_iters = state["num_iters"]
    model_dir = "model-{}".format(num_iters)
    if container is not None:
        # Fetch the weight directory matching the downloaded state as well.
        container.get(savedir, model_dir)
    load_state(os.path.join(savedir, model_dir, "saved"))
    logger.log("Loaded models checkpoint at {} iterations".format(
        num_iters))
    return state
コード例 #5
0
def load_model(checkpoint_path="saved_model/model.ckpt",
               state_path="saved_model/model_state.pkl.zip"):
    """Restore model weights and return the accompanying saved state.

    Calling with no arguments reproduces the previous fixed-location
    behaviour. Both paths can now be overridden to load other checkpoints.

    Args:
        checkpoint_path: Path handed to ``load_state`` to restore the model.
        state_path: Compressed pickle file read with ``pickle_load``.

    Returns:
        Whatever ``pickle_load`` returns for ``state_path``.
    """
    # Weights are restored first, then the state is read (original order).
    load_state(checkpoint_path)
    # NOTE(review): pickle_load presumably unpickles this file, which can
    # execute arbitrary code -- only load files you trust.
    return pickle_load(state_path, compression=True)