Exemplo n.º 1
0
def load(sess: tf.Session, saver: tf.train.Saver, checkpoint_dir: str) -> bool:
    '''Loads the most recent checkpoint from checkpoint_dir.

    Args
    - sess: tf.Session, session into which the variables are restored
    - saver: tf.train.Saver, used to perform the restore
    - checkpoint_dir: str, path to directory containing checkpoint(s)

    Returns: bool, True if a checkpoint was restored; False if
        tf.train.get_checkpoint_state found no checkpoint state (or no
        model_checkpoint_path) in checkpoint_dir

    Raises
    - ValueError: if checkpoint_dir is None or is not an existing directory
    - LoadNoFileError: if the checkpoint meta-data names a checkpoint that
        checkpoint_path_exists() reports as missing
    '''
    # Validate before logging so we never print "Reading from ... None".
    if checkpoint_dir is None:
        raise ValueError('No checkpoint path given, cannot load checkpoint')
    if not os.path.isdir(checkpoint_dir):
        raise ValueError(
            f'Given path is not a valid directory: {checkpoint_dir}')
    print(f'Reading from checkpoint dir: {checkpoint_dir}')

    # read the CheckpointState proto from 'checkpoint' file in checkpoint_dir
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
        return False

    ckpt_path = ckpt.model_checkpoint_path
    print(f'Loading checkpoint: {os.path.basename(ckpt_path)}')
    if not checkpoint_path_exists(ckpt_path):
        raise LoadNoFileError(
            'Checkpoint could not be loaded because it does not exist,'
            ' but its information is in the checkpoint meta-data file.')
    saver.restore(sess, ckpt_path)
    return True
Exemplo n.º 2
0
def restore_model(sess: tf.Session, saver: tf.train.Saver, c: TrainConfig) -> str:
    '''Restores the latest checkpoint for the configured model, if any exists.

    Args
    - sess: tf.Session, session into which variables are restored
    - saver: tf.train.Saver, used to perform the restore
    - c: TrainConfig, supplies save_dir; get_model_name(c) names the model's
        sub-directory inside it

    Returns: str, the checkpoint path prefix '<c.save_dir>/<model_name>/model'
        (presumably where the caller saves future checkpoints -- confirm)
    '''
    save_dir = os.path.join(c.save_dir, get_model_name(c))
    model_path = os.path.join(save_dir, 'model')
    latest_ckpt = tf.train.latest_checkpoint(save_dir)
    if latest_ckpt:
        print(f'💽 restoring latest checkpoint from: {latest_ckpt}')
        saver.restore(sess, latest_ckpt)
    else:
        # No checkpoint in save_dir: nothing is restored.
        print('💽 training new model')
    return model_path
Exemplo n.º 3
0
def load(saver: tf.train.Saver, sess: tf.Session, checkpoint_dir: str):
    """Restore the latest checkpoint recorded in checkpoint_dir.

    Args:
        saver: tf.train.Saver used to perform the restore.
        sess: tf.Session into which the variables are restored.
        checkpoint_dir: directory holding the 'checkpoint' state file.

    Returns:
        (True, counter) if a checkpoint was restored, where counter is the last
        run of digits in the checkpoint file name (0 if the name contains no
        digits); (False, 0) if no checkpoint state was found.
    """
    print(" [*] Reading checkpoints from %s..." % checkpoint_dir)

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
        print(" [*] Failed to find checkpoints")
        return False, 0

    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    # Restore relative to checkpoint_dir rather than using the recorded path
    # as-is (presumably so a moved/copied checkpoint dir still loads -- confirm).
    saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
    # (\d+)(?!.*\d): a run of digits with no other digit after it, i.e. the
    # last number in the name. A raw string avoids invalid '\d' escapes, and a
    # name without digits yields 0 instead of raising StopIteration after the
    # restore has already succeeded.
    match = re.search(r"(\d+)(?!.*\d)", ckpt_name)
    counter = int(match.group(0)) if match else 0
    print(" [*] Success to read {}".format(ckpt_name))
    return True, counter
Exemplo n.º 4
0
 def restore_model_and_get_inital_epoch(self, session: tf.Session,
                                        saver: tf.train.Saver,
                                        load_path: str):
     """Restore the newest checkpoint found for load_path and return its epoch.

     The directory searched is self.get_checkpoint_path(load_path). Each
     intermediate value is printed (flushed immediately) for debugging.

     Returns:
         0 when tf.train.latest_checkpoint finds nothing; otherwise
         int(<file name of the latest checkpoint>). NOTE(review): this assumes
         checkpoints are saved under a purely numeric name (presumably the
         epoch) -- int() raises ValueError otherwise; confirm against the save
         call.
     """
     print("load_path", load_path, flush=True)
     ckpt_dir = self.get_checkpoint_path(load_path)
     print("checkpoint_path", ckpt_dir, flush=True)
     newest = tf.train.latest_checkpoint(ckpt_dir)
     print("latest_checkpoint", newest, flush=True)

     if newest is None:
         return 0

     saver.restore(session, newest)
     epoch_name = pathlib.Path(newest).name
     return int(epoch_name)
Exemplo n.º 5
0
def restore_model(sess: tf.Session, path: str, saver: tf.train.Saver = None) -> tf.train.Saver:
    """
    Restores variables into a tensorflow session from the checkpoint at the given path.
    NOTE: Unless a custom Saver object is passed in, this restores *all* global variables
    of the current graph from the saved file.
    :param sess: The tensorflow session to restore the variables into
    :param path: The path to the saved data (the checkpoint path prefix passed to Saver.restore)
    :param saver: A custom saver object to use. This can be used to only load certain variables. If None,
    creates a saver object over all global variables.
    :return: The saver object used.
    """
    if saver is None:
        # tf.global_variables() returns the same collection as the deprecated
        # tf.all_variables(), which it replaces.
        saver = tf.train.Saver(tf.global_variables())
    saver.restore(sess, path)
    return saver
Exemplo n.º 6
0
    def _restore_checkpoint(saver: tf.train.Saver,
                            sess: tf.Session,
                            path: Optional[str] = None):
        """Best-effort restore of a checkpoint into sess.

        Does nothing when path or saver is falsy. A directory path is resolved
        to the newest checkpoint inside it via tf.train.latest_checkpoint; any
        other path is used as the checkpoint itself. If no existing checkpoint
        is found, this is logged with tf.logging.info and otherwise ignored.
        """
        if not (path and saver):
            return

        if os.path.isdir(path):
            # a directory was given: look up the most recent checkpoint in it
            candidate = tf.train.latest_checkpoint(path)
        else:
            candidate = path

        if not (candidate and tf.train.checkpoint_exists(candidate)):
            tf.logging.info(
                "No valid checkpoint has been found at {}. Ignoring.".format(path))
            return

        saver.restore(sess, candidate)
        tf.logging.info("Model loaded from {}".format(candidate))
Exemplo n.º 7
0
 def reconstruct_saved_model_variables(cls, sess: tf.Session, saver: tf.train.Saver, model_dir: str) -> None:
     """Restore variables into sess from the newest checkpoint in model_dir.

     Args:
         sess: session into which the variables are restored.
         saver: tf.train.Saver used to perform the restore.
         model_dir: directory searched with tf.train.latest_checkpoint.

     Raises:
         ValueError: if model_dir contains no checkpoint. This is checked
             explicitly so the error names model_dir, instead of forwarding
             None to saver.restore.
     """
     print('Loading saved model variables from', model_dir)
     latest_checkpoint = tf.train.latest_checkpoint(model_dir)
     if latest_checkpoint is None:
         raise ValueError(f'No checkpoint found in {model_dir}')
     saver.restore(sess, latest_checkpoint)
Exemplo n.º 8
0
 def restore_from_checkpoint(self, sess: tf.Session, loader: tf.train.Saver,
                             path_to_checkpoint_dir: str):
     """Restore variables into sess from a checkpoint directory.

     Saver.restore expects a checkpoint path prefix, not a directory. When
     path_to_checkpoint_dir is a directory, the newest checkpoint inside it
     (tf.train.latest_checkpoint) is restored. Any other value is handed to
     loader.restore unchanged, so explicit checkpoint prefixes still work.

     Raises:
         ValueError: if path_to_checkpoint_dir is a directory with no checkpoint.
     """
     import os  # function-scope import keeps this method self-contained

     ckpt_path = path_to_checkpoint_dir
     if os.path.isdir(path_to_checkpoint_dir):
         ckpt_path = tf.train.latest_checkpoint(path_to_checkpoint_dir)
         if ckpt_path is None:
             raise ValueError(
                 f'No checkpoint found in {path_to_checkpoint_dir}')
     loader.restore(sess, ckpt_path)