Example #1
import re
import shutil
import tempfile
from os.path import basename, dirname, join

import tensorflow as tf

# Note: is_local_path, get_file_list and get_remote_file_to_local are helper
# utilities assumed to be defined elsewhere in the same module.


def get_checkpoint_state(checkpoint_dir):
    """
    Get the TensorFlow checkpoint state from a checkpoint directory without using
    TensorFlow's native remote file access.
    :param checkpoint_dir: TensorFlow checkpoint directory. Can be a local, HDFS or S3 path.
    :return: tf.train.CheckpointState protobuf, or None if no checkpoint is found.
    """
    if is_local_path(checkpoint_dir):
        return tf.train.get_checkpoint_state(checkpoint_dir)
    else:
        # check whether a 'checkpoint' index file exists in the remote directory
        file_list = get_file_list(checkpoint_dir)
        has_checkpoint = any(basename(file) == 'checkpoint' for file in file_list)
        if not has_checkpoint:
            return None
        # get checkpoint file
        temp = tempfile.mkdtemp()
        get_remote_file_to_local(join(checkpoint_dir, "checkpoint"),
                                 join(temp, "checkpoint"))
        ckpt_name = None
        # parse the latest checkpoint name out of the 'checkpoint' index file
        pattern = re.compile(r'^model_checkpoint_path: "(.*)"$')
        with open(join(temp, "checkpoint")) as f:
            for line in f:
                m = pattern.match(line)
                if m:
                    ckpt_name = m.group(1)
                    break
        if ckpt_name is None:
            shutil.rmtree(temp)
            return None
        # filter checkpoint files
        checkpoint_files = [
            file for file in file_list if basename(file).startswith(ckpt_name)
        ]
        if not checkpoint_files:
            shutil.rmtree(temp)
            return None
        # download the checkpoint data files to the local temp directory
        for file in checkpoint_files:
            get_remote_file_to_local(file, join(temp, basename(file)))
        # get checkpoint state
        ckpt = tf.train.get_checkpoint_state(temp)
        if not ckpt:
            shutil.rmtree(temp)
            return None
        # point the checkpoint state back at the original remote paths
        ckpt.model_checkpoint_path = join(checkpoint_dir, ckpt_name)
        ckpt.all_model_checkpoint_paths[:] = [join(checkpoint_dir, ckpt_name)]
        shutil.rmtree(temp)
        return ckpt
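
A minimal usage sketch, assuming a hypothetical remote HDFS directory; the returned
CheckpointState keeps its paths pointing at the remote directory, so it can be handed
to the loaders in the examples below.

# Sketch only: the HDFS path is a hypothetical placeholder.
ckpt = get_checkpoint_state("hdfs://namenode:9000/models/my_model")
if ckpt is None:
    print("no checkpoint found")
else:
    # model_checkpoint_path still points at the remote directory, so it can be
    # passed to load_tf_checkpoint (Example #2 below).
    print(ckpt.model_checkpoint_path)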
Example #2
def load_tf_checkpoint(sess, checkpoint_path, saver=None):
    """
    Load a TensorFlow checkpoint from a checkpoint path without using TensorFlow's
    native remote file access.
    :param sess: TensorFlow session to restore into.
    :param checkpoint_path: TensorFlow checkpoint path. Can be a local, HDFS or S3 path.
    :param saver: optional tf.train.Saver used to restore the checkpoint;
                  a new one is created if not provided.
    """
    if is_local_path(checkpoint_path):
        if saver is None:
            saver = tf.train.Saver()
        saver.restore(sess, checkpoint_path)
    else:
        ckpt_name = basename(checkpoint_path)
        checkpoint_dir = dirname(checkpoint_path)
        # get remote file lists
        file_list = get_file_list(checkpoint_dir)
        # filter checkpoint files
        checkpoint_files = [file for file in file_list if basename(file).startswith(ckpt_name)]
        # download checkpoint files to a local temp directory
        temp = tempfile.mkdtemp()
        for file in checkpoint_files:
            get_remote_file_to_local(file, join(temp, basename(file)))
        if saver is None:
            saver = tf.train.Saver()
        try:
            saver.restore(sess, join(temp, ckpt_name))
        finally:
            shutil.rmtree(temp)
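
A hedged usage sketch combining the two functions above; build_model() and the HDFS
path are hypothetical placeholders standing in for your own graph definition and
checkpoint location.

# Sketch only: build_model() and the HDFS path are assumptions, not part of the API.
with tf.Session() as sess:
    build_model()  # define the graph before creating a Saver
    sess.run(tf.global_variables_initializer())
    ckpt = get_checkpoint_state("hdfs://namenode:9000/models/my_model")
    if ckpt:
        load_tf_checkpoint(sess, ckpt.model_checkpoint_path)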
Example #3
def load_tf_checkpoint_from_remote(sess, checkpoint_path, saver=None):
    """
    Load a TensorFlow checkpoint from a remote checkpoint path.
    :param sess: TensorFlow session to restore into.
    :param checkpoint_path: remote TensorFlow checkpoint path.
    :param saver: optional tf.train.Saver used to restore the checkpoint;
                  a new one is created if not provided.
    """
    ckpt_name = basename(checkpoint_path)
    checkpoint_dir = dirname(checkpoint_path)
    # get remote file lists
    file_list = get_file_list(checkpoint_dir)
    # filter checkpoint files
    checkpoint_files = [
        file for file in file_list if basename(file).startswith(ckpt_name)
    ]
    # download checkpoint files to a local temp directory
    temp = tempfile.mkdtemp()
    for file in checkpoint_files:
        get_remote_file_to_local(file, join(temp, basename(file)))
    if saver is None:
        saver = tf.train.Saver()
    try:
        saver.restore(sess, join(temp, ckpt_name))
    finally:
        shutil.rmtree(temp)
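
A short sketch of the remote-only variant with an explicit saver restricted to a
subset of variables; the S3 path and the "encoder" variable scope are assumptions
used only for illustration.

# Sketch only: the S3 path and the "encoder" scope are hypothetical.
encoder_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="encoder")
saver = tf.train.Saver(var_list=encoder_vars)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    load_tf_checkpoint_from_remote(
        sess, "s3://my-bucket/models/my_model/model.ckpt-10000", saver=saver)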