def load_tf_checkpoint(sess, checkpoint_path, saver=None):
    """
    Load tensorflow checkpoint from checkpoint path without using native tensorflow accessing
    remote method.

    For a non-local path, every file in the checkpoint's directory whose base name starts with
    the checkpoint name is copied into a temporary local directory, the checkpoint is restored
    from there, and the temporary directory is removed afterwards (also when an error occurs).

    :param sess: tensorflow session to be loaded to.
    :param checkpoint_path: tensorflow checkpoint path. Could be local, hdfs, s3 filesystems.
    :param saver: tensorflow saver to load checkpoint. When None, a ``tf.train.Saver()`` is
           created.
    """
    if is_local_path(checkpoint_path):
        if saver is None:
            saver = tf.train.Saver()
        saver.restore(sess, checkpoint_path)
        return

    ckpt_name = basename(checkpoint_path)
    checkpoint_dir = dirname(checkpoint_path)
    # get remote file lists
    file_list = get_file_list(checkpoint_dir)
    # filter checkpoint files
    checkpoint_files = [remote_file for remote_file in file_list
                        if basename(remote_file).startswith(ckpt_name)]

    temp = tempfile.mkdtemp()
    # Everything that touches the temporary directory runs inside the try block so that a
    # failed download or restore does not leave the directory behind.
    try:
        # get checkpoint files to local
        for remote_file in checkpoint_files:
            get_remote_file_to_local(remote_file, join(temp, basename(remote_file)))
        if saver is None:
            saver = tf.train.Saver()
        saver.restore(sess, join(temp, ckpt_name))
    finally:
        shutil.rmtree(temp)
def save_tf_checkpoint(sess, checkpoint_path, saver=None):
    """
    Save tf checkpoint without using native tensorflow remote access method.

    For a non-local path, the checkpoint is first saved into a temporary local directory, the
    generated "checkpoint" file is passed to ``change_path_in_tf_checkpoint`` together with the
    checkpoint name, and then every file in the temporary directory is uploaded to the remote
    directory (overwriting existing files). The temporary directory is removed afterwards, also
    when saving or uploading fails.

    :param sess: tf session to be saved.
    :param checkpoint_path: checkpoint path. Could be local, hdfs, s3 filesystems.
    :param saver: tf saver to save checkpoint. When None, a ``tf.train.Saver()`` is created.
    """
    import tensorflow as tf

    if is_local_path(checkpoint_path):
        if saver is None:
            saver = tf.train.Saver()
        saver.save(sess, checkpoint_path)
        return

    ckpt_name = os.path.basename(checkpoint_path)
    remote_dir = os.path.dirname(checkpoint_path)

    temp = tempfile.mkdtemp()
    # try/finally guarantees the temporary directory is deleted even if the save or one of the
    # uploads raises.
    try:
        # save to local checkpoint
        if saver is None:
            saver = tf.train.Saver()
        saver.save(sess, os.path.join(temp, ckpt_name))
        change_path_in_tf_checkpoint(os.path.join(temp, "checkpoint"), ckpt_name)
        # move to remote
        for file_name in os.listdir(temp):
            put_local_file_to_remote(os.path.join(temp, file_name),
                                     os.path.join(remote_dir, file_name),
                                     over_write=True)
    finally:
        shutil.rmtree(temp)
def get_checkpoint_state(checkpoint_dir):
    """
    Get tf checkpoint state from checkpoint directory without using native tensorflow accessing
    remote method.

    For a non-local directory, the remote "checkpoint" file and the files of the checkpoint it
    names are copied into a temporary local directory, the state is read from there, and its
    paths are rewritten to point back at ``checkpoint_dir``. The temporary directory is always
    removed before returning, including when an error is raised.

    :param checkpoint_dir: tensorflow checkpoint directory. Could be local, hdfs, s3 filesystems.
    :return: tf checkpoint protobuf, or None when the remote directory has no "checkpoint" file,
             the file names no checkpoint, no matching checkpoint files exist, or tensorflow
             finds no checkpoint state in the downloaded files.
    """
    if is_local_path(checkpoint_dir):
        return tf.train.get_checkpoint_state(checkpoint_dir)

    # check if checkpoint file exists
    file_list = get_file_list(checkpoint_dir)
    if not any(basename(remote_file) == 'checkpoint' for remote_file in file_list):
        return None

    # Compiled once instead of once per line of the "checkpoint" file.
    name_pattern = re.compile("^model_checkpoint_path: \"(.*)\"$")

    temp = tempfile.mkdtemp()
    # A single finally replaces the per-return cleanup and also covers exceptions raised by
    # the downloads, the file read or tensorflow.
    try:
        # get checkpoint file
        local_checkpoint = join(temp, "checkpoint")
        get_remote_file_to_local(join(checkpoint_dir, "checkpoint"), local_checkpoint)

        # get checkpoint name from 'checkpoint' file
        ckpt_name = None
        with open(local_checkpoint) as f:
            for line in f:
                matched = name_pattern.match(line)
                if matched:
                    ckpt_name = matched.group(1)
                    break
        if ckpt_name is None:
            return None

        # filter checkpoint files
        checkpoint_files = [remote_file for remote_file in file_list
                            if basename(remote_file).startswith(ckpt_name)]
        if not checkpoint_files:
            return None

        # get checkpoint files to local
        for remote_file in checkpoint_files:
            get_remote_file_to_local(remote_file, join(temp, basename(remote_file)))

        # get checkpoint state
        ckpt = tf.train.get_checkpoint_state(temp)
        if not ckpt:
            return None

        # Report the remote location instead of the temporary local one.
        ckpt.model_checkpoint_path = join(checkpoint_dir, ckpt_name)
        ckpt.all_model_checkpoint_paths[:] = [join(checkpoint_dir, ckpt_name)]
        return ckpt
    finally:
        shutil.rmtree(temp)
def save_tf_checkpoint(sess, checkpoint_path, saver=None):
    """
    Save tf checkpoint without using native tensorflow remote access method.

    For a non-local path, the checkpoint is first saved into a temporary local directory. The
    generated "checkpoint" file is then rewritten so that its ``model_checkpoint_path`` and
    ``all_model_checkpoint_paths`` entries hold the bare checkpoint name instead of the
    temporary local path, and every file in the temporary directory is uploaded to the remote
    directory (overwriting existing files). The temporary directory is removed afterwards, also
    when saving or uploading fails.

    NOTE(review): another function named ``save_tf_checkpoint`` appears earlier in this source;
    if both live in the same module, this later definition is the one that takes effect.

    :param sess: tf session to be saved.
    :param checkpoint_path: checkpoint path. Could be local, hdfs, s3 filesystems.
    :param saver: tf saver to save checkpoint. When None, a ``tf.train.Saver()`` is created.
    """
    if is_local_path(checkpoint_path):
        if saver is None:
            saver = tf.train.Saver()
        saver.save(sess, checkpoint_path)
        return

    ckpt_name = basename(checkpoint_path)
    remote_dir = dirname(checkpoint_path)

    # Compiled once instead of once per line of the "checkpoint" file.
    latest_pattern = re.compile("^model_checkpoint_path: \"(.*)\"$")
    all_pattern = re.compile("^all_model_checkpoint_paths: \"(.*)\"$")

    temp = tempfile.mkdtemp()
    # try/finally guarantees the temporary directory is deleted even if the save, the rewrite
    # or one of the uploads raises.
    try:
        # save to local checkpoint
        if saver is None:
            saver = tf.train.Saver()
        saver.save(sess, join(temp, ckpt_name))

        # change checkpoint file: replace model_checkpoint_path and
        # all_model_checkpoint_paths with the checkpoint name instead of the absolute
        # checkpoint path
        state_file = join(temp, "checkpoint")
        new_lines = []
        with open(state_file) as f:
            for line in f:
                if latest_pattern.match(line):
                    new_lines.append("model_checkpoint_path: \"{}\"\n".format(ckpt_name))
                elif all_pattern.match(line):
                    new_lines.append("all_model_checkpoint_paths: \"{}\"\n".format(ckpt_name))
                else:
                    new_lines.append(line)
        with open(state_file, 'w') as f:
            f.writelines(new_lines)

        # move to remote
        for file_name in os.listdir(temp):
            put_local_file_to_remote(join(temp, file_name),
                                     join(remote_dir, file_name),
                                     over_write=True)
    finally:
        shutil.rmtree(temp)