def load_tf_checkpoint(sess, checkpoint_path, saver=None):
    """
    Load a TensorFlow checkpoint from a checkpoint path without using
    native TensorFlow remote-filesystem access.

    :param sess: TensorFlow session to restore into.
    :param checkpoint_path: TensorFlow checkpoint path. Could be local,
           hdfs or s3 filesystems.
    :param saver: TensorFlow saver used to load the checkpoint; a new
           ``tf.train.Saver()`` is created when None.
    """
    if is_local_path(checkpoint_path):
        if saver is None:
            saver = tf.train.Saver()
        saver.restore(sess, checkpoint_path)
    else:
        # Delegate to the remote loader instead of duplicating the
        # download-to-temp-then-restore logic inline.
        load_tf_checkpoint_from_remote(sess, checkpoint_path, saver)
def load_tf_checkpoint_from_remote(sess, checkpoint_path, saver=None):
    """
    Load a TensorFlow checkpoint from a remote checkpoint path.

    :param sess: TensorFlow session to restore into.
    :param checkpoint_path: remote TensorFlow checkpoint path.
    :param saver: TensorFlow saver used to load the checkpoint; a new
           ``tf.train.Saver()`` is created when None.
    """
    ckpt_name = basename(checkpoint_path)
    checkpoint_dir = dirname(checkpoint_path)
    # List the remote directory and keep only the files that belong to
    # this checkpoint (its shards share the checkpoint name as prefix).
    file_list = get_file_list(checkpoint_dir)
    checkpoint_files = [file for file in file_list
                        if basename(file).startswith(ckpt_name)]
    # Download the checkpoint files into a local temp directory; the
    # finally block guarantees cleanup even if a download or the
    # restore itself fails.
    temp = tempfile.mkdtemp()
    try:
        for file in checkpoint_files:
            get_remote_file_to_local(file, join(temp, basename(file)))
        if saver is None:
            saver = tf.train.Saver()
        saver.restore(sess, join(temp, ckpt_name))
    finally:
        shutil.rmtree(temp)
def get_checkpoint_state(checkpoint_dir):
    """
    Get TensorFlow checkpoint state from a checkpoint directory without
    using native TensorFlow remote-filesystem access.

    :param checkpoint_dir: TensorFlow checkpoint directory. Could be
           local, hdfs or s3 filesystems.
    :return: tf checkpoint protobuf (with paths rewritten to point at
             the remote directory), or None if no checkpoint is found.
    """
    if is_local_path(checkpoint_dir):
        return tf.train.get_checkpoint_state(checkpoint_dir)
    # Remote path: the 'checkpoint' state file must exist.
    file_list = get_file_list(checkpoint_dir)
    if not any(basename(file) == 'checkpoint' for file in file_list):
        return None
    # Download the state file, then the checkpoint shards, into a temp
    # dir; a single finally guarantees cleanup on every exit path
    # (the original repeated rmtree before each return and leaked the
    # temp dir on exceptions).
    temp = tempfile.mkdtemp()
    try:
        get_remote_file_to_local(join(checkpoint_dir, "checkpoint"),
                                 join(temp, "checkpoint"))
        # Extract the checkpoint name from the 'checkpoint' state file.
        ckpt_pattern = re.compile(r'^model_checkpoint_path: "(.*)"$')
        ckpt_name = None
        with open(join(temp, "checkpoint")) as f:
            for line in f:
                m = ckpt_pattern.match(line)
                if m:
                    ckpt_name = m.group(1)
                    break
        if ckpt_name is None:
            return None
        # Keep only the files belonging to the named checkpoint.
        checkpoint_files = [file for file in file_list
                            if basename(file).startswith(ckpt_name)]
        if not checkpoint_files:
            return None
        for file in checkpoint_files:
            get_remote_file_to_local(file, join(temp, basename(file)))
        ckpt = tf.train.get_checkpoint_state(temp)
        if not ckpt:
            return None
        # Rewrite the local temp paths back to the remote paths so the
        # caller sees the real checkpoint location.
        ckpt.model_checkpoint_path = join(checkpoint_dir, ckpt_name)
        ckpt.all_model_checkpoint_paths[:] = [join(checkpoint_dir, ckpt_name)]
        return ckpt
    finally:
        shutil.rmtree(temp)
def find_latest_checkpoint(model_dir, model_type="bigdl"):
    """
    Find the most recent optimMethod checkpoint under ``model_dir``.

    :param model_dir: directory to search (recursively) for checkpoints.
    :param model_type: one of "bigdl", "pytorch" or "tf".
    :return: tuple of (checkpoint directory, optimMethod file prefix,
             latest version number), or (None, None, None) when no
             checkpoint is found.
    :raises ValueError: if ``model_type`` is not supported.
    """
    import os
    import re
    import datetime
    ckpt_path = None
    latest_version = None
    optim_prefix = None
    if model_type == "bigdl":
        optim_regex = r".*\.([0-9]+)$"
    elif model_type == "pytorch":
        optim_regex = r"TorchModel[0-9a-z]*\.([0-9]+)$"
    elif model_type == "tf":
        optim_regex = r"TFParkTraining\.([0-9]+)$"
    else:
        # BUG FIX: the exception was previously constructed but never
        # raised, letting an unsupported model_type fall through with
        # optim_regex = None and crash later with a confusing TypeError.
        raise ValueError("Only bigdl, pytorch and tf are supported for now.")
    file_list = get_file_list(model_dir, recursive=True)
    optim_dict = {}
    pattern_re = re.compile(
        r'(.*)(\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2})(.*)optimMethod-'
        + optim_regex)
    for file_path in file_list:
        matched = pattern_re.match(file_path)
        if matched is None:
            continue
        try:
            # Accept the entry only if the directory component parses
            # as a timestamp; strptime/int raise ValueError otherwise.
            timestamp = matched.group(2)
            datetime.datetime.strptime(timestamp, '%Y-%m-%d_%H-%M-%S')
            entry = (int(matched.group(4)),
                     os.path.dirname(file_path),
                     os.path.basename(file_path).split('.')[0])
            optim_dict.setdefault(timestamp, []).append(entry)
        except ValueError:
            # Narrowed from a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit; skip malformed entries only.
            continue
    if optim_dict:
        # Newest timestamp wins; within it, the highest version wins.
        latest_timestamp = max(optim_dict)
        latest_version, ckpt_path, optim_prefix = max(
            optim_dict[latest_timestamp],
            key=lambda version_path: version_path[0])
    return ckpt_path, optim_prefix, latest_version