Beispiel #1
0
def load_tf_checkpoint(sess, checkpoint_path, saver=None):
    """
    Load tensorflow checkpoint from checkpoint path without using native tensorflow accessing
    remote method.

    For a non-local path, every file in the checkpoint's directory whose base name starts
    with the checkpoint name is copied into a temporary local directory, the checkpoint is
    restored from there, and the temporary directory is always removed afterwards.

    :param sess: tensorflow session to be loaded to.
    :param checkpoint_path: tensorflow checkpoint path. Could be local, hdfs, s3 filesystems.
    :param saver: tensorflow saver to load checkpoint. A new ``tf.train.Saver()`` is created
           when it is None.
    """
    # Build the default saver once, before any remote transfer starts.
    if saver is None:
        saver = tf.train.Saver()

    if is_local_path(checkpoint_path):
        saver.restore(sess, checkpoint_path)
        return

    ckpt_name = basename(checkpoint_path)
    checkpoint_dir = dirname(checkpoint_path)
    # get remote file lists and keep only the files belonging to this checkpoint
    checkpoint_files = [
        remote_file for remote_file in get_file_list(checkpoint_dir)
        if basename(remote_file).startswith(ckpt_name)
    ]

    temp = tempfile.mkdtemp()
    try:
        # The downloads happen inside the try block so that a failed transfer
        # does not leave the temporary directory behind.
        for remote_file in checkpoint_files:
            get_remote_file_to_local(remote_file, join(temp, basename(remote_file)))
        saver.restore(sess, join(temp, ckpt_name))
    finally:
        shutil.rmtree(temp)
Beispiel #2
0
def load_tf_checkpoint_from_remote(sess, checkpoint_path, saver=None):
    """
    Load tf checkpoint from remote checkpoint path

    Every file in the checkpoint's directory whose base name starts with the checkpoint
    name is copied into a temporary local directory, the checkpoint is restored from
    there, and the temporary directory is always removed afterwards.

    :param sess: tf session to be loaded to.
    :param checkpoint_path: remote tf checkpoint path
    :param saver: tf saver to load checkpoint. A new ``tf.train.Saver()`` is created when
           it is None.
    """
    # Build the default saver before any remote transfer starts.
    if saver is None:
        saver = tf.train.Saver()

    ckpt_name = basename(checkpoint_path)
    checkpoint_dir = dirname(checkpoint_path)
    # get remote file lists and keep only the files belonging to this checkpoint
    checkpoint_files = [
        remote_file for remote_file in get_file_list(checkpoint_dir)
        if basename(remote_file).startswith(ckpt_name)
    ]

    temp = tempfile.mkdtemp()
    try:
        # The downloads happen inside the try block so that a failed transfer
        # does not leave the temporary directory behind.
        for remote_file in checkpoint_files:
            get_remote_file_to_local(remote_file, join(temp, basename(remote_file)))
        saver.restore(sess, join(temp, ckpt_name))
    finally:
        shutil.rmtree(temp)
Beispiel #3
0
def get_checkpoint_state(checkpoint_dir):
    """
    Get tf checkpoint state from checkpoint directory without using native tensorflow accessing
    remote method.

    For a non-local directory, the remote ``checkpoint`` index file and the files of the
    checkpoint it names are copied into a temporary local directory, the state is read
    from there, and its paths are rewritten to point back into ``checkpoint_dir``. The
    temporary directory is removed on every exit path, including errors.

    :param checkpoint_dir: tensorflow checkpoint directory. Could be local, hdfs, s3 filesystems.
    :return: tf checkpoint protobuf, or None when no usable checkpoint is found (no
             ``checkpoint`` file, no ``model_checkpoint_path`` entry in it, no matching
             checkpoint files, or tensorflow returns no state).
    """
    if is_local_path(checkpoint_dir):
        return tf.train.get_checkpoint_state(checkpoint_dir)

    # check if checkpoint file exists
    file_list = get_file_list(checkpoint_dir)
    if not any(basename(remote_file) == 'checkpoint' for remote_file in file_list):
        return None

    temp = tempfile.mkdtemp()
    try:
        # get checkpoint file
        local_index = join(temp, "checkpoint")
        get_remote_file_to_local(join(checkpoint_dir, "checkpoint"), local_index)

        # get checkpoint name from 'checkpoint' file (first matching line wins)
        name_pattern = re.compile(r'^model_checkpoint_path: "(.*)"$')
        ckpt_name = None
        with open(local_index) as f:
            for line in f:
                m = name_pattern.match(line)
                if m:
                    ckpt_name = m.group(1)
                    break
        if ckpt_name is None:
            return None

        # filter checkpoint files
        checkpoint_files = [
            remote_file for remote_file in file_list
            if basename(remote_file).startswith(ckpt_name)
        ]
        if not checkpoint_files:
            return None

        # get checkpoint files to local
        for remote_file in checkpoint_files:
            get_remote_file_to_local(remote_file, join(temp, basename(remote_file)))

        # get checkpoint state
        ckpt = tf.train.get_checkpoint_state(temp)
        if not ckpt:
            return None
        # Report the original remote location rather than the temporary copy.
        ckpt.model_checkpoint_path = join(checkpoint_dir, ckpt_name)
        ckpt.all_model_checkpoint_paths[:] = [join(checkpoint_dir, ckpt_name)]
        return ckpt
    finally:
        # Runs on every return and on any exception, so the directory never leaks.
        shutil.rmtree(temp)
Beispiel #4
0
def find_latest_checkpoint(model_dir, model_type="bigdl"):
    """
    Find the latest optimMethod checkpoint file under a model directory.

    Paths returned by ``get_file_list(model_dir, recursive=True)`` are matched against
    ``<...><YYYY-mm-dd_HH-MM-SS><...>optimMethod-<name>.<version>``, where ``<name>``
    depends on ``model_type``. The greatest timestamp string is selected, and within it
    the entry with the highest version number.

    :param model_dir: directory that is searched recursively for checkpoint files.
    :param model_type: one of "bigdl", "pytorch" or "tf".
    :return: tuple ``(ckpt_path, optim_prefix, latest_version)`` where ``ckpt_path`` is the
             directory of the selected file, ``optim_prefix`` is its base name up to the
             first ``.`` and ``latest_version`` is the integer version suffix. All three
             are None when no file matches.
    :raises ValueError: if ``model_type`` is not one of the supported values.
    """
    import os
    import re
    import datetime

    optim_regexes = {
        "bigdl": r".*\.([0-9]+)$",
        "pytorch": r"TorchModel[0-9a-z]*\.([0-9]+)$",
        "tf": r"TFParkTraining\.([0-9]+)$",
    }
    if model_type not in optim_regexes:
        # Reject unsupported types up front, before the directory is listed.
        raise ValueError("Only bigdl, pytorch and tf are supported for now.")

    ckpt_path = None
    latest_version = None
    optim_prefix = None

    file_list = get_file_list(model_dir, recursive=True)
    pattern_re = re.compile(
        r'(.*)(\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2})(.*)optimMethod-' +
        optim_regexes[model_type])

    # timestamp -> list of (version, directory, file name prefix)
    optim_dict = {}
    for file_path in file_list:
        matched = pattern_re.match(file_path)
        if matched is None:
            continue
        timestamp = matched.group(2)
        try:
            # check if dir name is date time; digit groups such as month 13 are skipped
            datetime.datetime.strptime(timestamp, '%Y-%m-%d_%H-%M-%S')
        except ValueError:
            continue
        optim_dict.setdefault(timestamp, []).append(
            (int(matched.group(4)), os.path.dirname(file_path),
             os.path.basename(file_path).split('.')[0]))

    if optim_dict:
        # Zero-padded timestamps sort chronologically as plain strings.
        latest_timestamp = max(optim_dict)
        latest_version, ckpt_path, optim_prefix = max(
            optim_dict[latest_timestamp],
            key=lambda version_path: version_path[0])

    return ckpt_path, optim_prefix, latest_version