def get_checkpoint_state(checkpoint_dir):
    """
    Get tf checkpoint state from checkpoint directory without using native
    tensorflow accessing remote method.

    For a local directory this simply delegates to
    ``tf.train.get_checkpoint_state``. For a remote directory it:

    1. copies the ``checkpoint`` index file into a temporary local directory,
    2. reads the first ``model_checkpoint_path: "..."`` entry from it,
    3. copies the remote files whose basename starts with that name,
    4. parses the state locally with ``tf.train.get_checkpoint_state``, and
    5. rewrites ``model_checkpoint_path`` and ``all_model_checkpoint_paths``
       to point at ``join(checkpoint_dir, <name>)``.

    The temporary directory is always removed, including when a download or
    parse step raises.

    :param checkpoint_dir: tensorflow checkpoint directory. Could be local, hdfs, s3 filesystems.
    :return: tf checkpoint protobuf, or None if the directory has no
             'checkpoint' file, the file names no checkpoint, no matching
             checkpoint files exist, or tensorflow cannot parse the state.
    """
    if is_local_path(checkpoint_dir):
        return tf.train.get_checkpoint_state(checkpoint_dir)

    # NOTE(review): get_file_list presumably returns full remote paths; they are
    # passed straight to get_remote_file_to_local below — confirm.
    file_list = get_file_list(checkpoint_dir)

    # Bail out early (before creating any temp dir) if there is no index file.
    if not any(basename(path) == 'checkpoint' for path in file_list):
        return None

    temp = tempfile.mkdtemp()
    try:
        local_index = join(temp, "checkpoint")
        get_remote_file_to_local(join(checkpoint_dir, "checkpoint"), local_index)

        # Get checkpoint name from the 'checkpoint' file; the first match wins.
        pattern = re.compile(r'^model_checkpoint_path: "(.*)"$')
        ckpt_name = None
        with open(local_index) as f:
            for line in f:
                match = pattern.match(line)
                if match:
                    ckpt_name = match.group(1)
                    break
        if ckpt_name is None:
            return None

        # NOTE: prefix matching may also pick up other checkpoints sharing the
        # prefix (e.g. "model.ckpt-1" matches "model.ckpt-10.index"); those
        # extra files are downloaded but otherwise unused.
        checkpoint_files = [
            path for path in file_list if basename(path).startswith(ckpt_name)
        ]
        if not checkpoint_files:
            return None

        for path in checkpoint_files:
            get_remote_file_to_local(path, join(temp, basename(path)))

        ckpt = tf.train.get_checkpoint_state(temp)
        if not ckpt:
            return None

        # Point the returned state back at the remote location instead of the
        # temporary directory, which is deleted below.
        remote_path = join(checkpoint_dir, ckpt_name)
        ckpt.model_checkpoint_path = remote_path
        ckpt.all_model_checkpoint_paths[:] = [remote_path]
        return ckpt
    finally:
        shutil.rmtree(temp)
def load_tf_checkpoint(sess, checkpoint_path, saver=None):
    """
    Load tensorflow checkpoint from checkpoint path without using native
    tensorflow accessing remote method.

    Local paths are restored directly. For remote paths, every file in the
    parent directory whose basename starts with the checkpoint name is copied
    into a temporary local directory, and the checkpoint is restored from
    there. The temporary directory is always removed, including when a
    download, saver construction, or restore raises.

    :param sess: tensorflow session to be loaded to.
    :param checkpoint_path: tensorflow checkpoint path. Could be local, hdfs, s3 filesystems.
    :param saver: tensorflow saver to load checkpoint; a default
                  ``tf.train.Saver()`` is created when None.
    """
    if is_local_path(checkpoint_path):
        if saver is None:
            saver = tf.train.Saver()
        saver.restore(sess, checkpoint_path)
        return

    ckpt_name = basename(checkpoint_path)
    checkpoint_dir = dirname(checkpoint_path)

    # get remote file lists and keep those belonging to this checkpoint
    # (prefix match on the basename, e.g. "<name>.index", "<name>.data-...").
    file_list = get_file_list(checkpoint_dir)
    checkpoint_files = [
        path for path in file_list if basename(path).startswith(ckpt_name)
    ]

    temp = tempfile.mkdtemp()
    try:
        # Downloads and saver construction are inside the try so the temp
        # directory is cleaned up even if they fail.
        for path in checkpoint_files:
            get_remote_file_to_local(path, join(temp, basename(path)))
        if saver is None:
            saver = tf.train.Saver()
        saver.restore(sess, join(temp, ckpt_name))
    finally:
        shutil.rmtree(temp)
def load_tf_checkpoint_from_remote(sess, checkpoint_path, saver=None):
    """
    Load tf checkpoint from remote checkpoint path.

    Every file in the parent directory of ``checkpoint_path`` whose basename
    starts with the checkpoint name is copied into a temporary local
    directory, and the checkpoint is restored from there. The temporary
    directory is always removed, including when a download, saver
    construction, or restore raises.

    :param sess: tf session to be loaded to.
    :param checkpoint_path: remote tf checkpoint path
    :param saver: tf saver to load checkpoint; a default ``tf.train.Saver()``
                  is created when None.
    """
    ckpt_name = basename(checkpoint_path)
    checkpoint_dir = dirname(checkpoint_path)

    # get remote file lists and keep those belonging to this checkpoint
    file_list = get_file_list(checkpoint_dir)
    checkpoint_files = [
        path for path in file_list if basename(path).startswith(ckpt_name)
    ]

    temp = tempfile.mkdtemp()
    try:
        # Everything after mkdtemp is guarded so the temp dir never leaks.
        for path in checkpoint_files:
            get_remote_file_to_local(path, join(temp, basename(path)))
        if saver is None:
            saver = tf.train.Saver()
        saver.restore(sess, join(temp, ckpt_name))
    finally:
        shutil.rmtree(temp)