예제 #1
0
def restore_ckpt(sess: tf.Session, depth_multiplier: float, var_list: list, ckptdir: str):
    """Restore variables into *sess* from a pickle file or a TF checkpoint directory.

    Args:
        sess: Session whose variables are restored.
        depth_multiplier: Not used by this function; kept for call-site
            compatibility.
        var_list: Variables handed to ``tf.train.Saver`` when restoring from a
            TensorFlow checkpoint directory. Ignored on the pickle path, which
            restores into ``tf.global_variables()`` instead.
        ckptdir: ``''`` or ``None`` to skip restoring entirely; any path
            containing the substring ``'pkl'`` to restore via
            ``restore_from_pkl``; otherwise a directory holding a TensorFlow
            ``checkpoint`` state file.

    Raises:
        FileNotFoundError: If *ckptdir* is treated as a checkpoint directory
            but no checkpoint state / checkpoint path can be found in it.
    """
    if ckptdir is None or ckptdir == '':
        # Nothing to restore; leave the session's variables untouched.
        return

    # NOTE(review): substring match, so any path containing "pkl" (including a
    # directory name) takes the pickle branch. Kept as-is for compatibility.
    if 'pkl' in ckptdir:
        restore_from_pkl(sess, tf.global_variables(), ckptdir)
        return

    ckpt = tf.train.get_checkpoint_state(ckptdir)
    # get_checkpoint_state returns None when the directory has no checkpoint
    # file. Fail with a clear message instead of an AttributeError further down.
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise FileNotFoundError(f"No TensorFlow checkpoint found in {ckptdir!r}")

    loader = tf.train.Saver(var_list=var_list)
    loader.restore(sess, ckpt.model_checkpoint_path)
예제 #2
0
 def _load_parameters(self, directory, unknown_environments):
     """Restore graph variables from the latest checkpoint in *directory*.

     Only global variables whose ``var.name`` appears in the checkpoint are
     restored. Every other global variable is skipped and keeps its current
     value.

     For each skipped variable, the first name in ``self._actor_critics`` that
     satisfies both conditions below is appended to *unknown_environments*:

     - it occurs inside the variable's name;
     - it is not already listed in *unknown_environments*.

     A warning is emitted whenever *unknown_environments* ends up non-empty.

     Args:
         directory: Directory containing the ``checkpoint`` state file.
         unknown_environments: List that is both read (to avoid duplicates)
             and extended **in place** with environment names that have no
             saved parameters.

     Raises:
         FileNotFoundError: If no checkpoint can be located in *directory*.
     """
     # Build the path with the same f"{directory}/..." pattern used when saving,
     # so unusual inputs (e.g. an empty string) resolve identically on both sides.
     checkpoint_dir = os.path.dirname(f"{directory}/checkpoint")
     filename = tf.train.latest_checkpoint(checkpoint_dir)
     if filename is None:
         # latest_checkpoint returns None when nothing is found. Passing that on
         # to NewCheckpointReader would fail with an unrelated type error.
         raise FileNotFoundError(f"No checkpoint found in {checkpoint_dir!r}")
     saved_params = tf.train.NewCheckpointReader(filename).get_variable_to_shape_map()

     params_to_load = {}
     for parameter in tf.global_variables():
         if parameter.name in saved_params:
             params_to_load[parameter.name] = parameter
             continue
         # No saved counterpart: attribute the variable to an environment
         # (if any matches) so the caller knows which layers are fresh.
         new_env_name = [
             env_name
             for env_name in self._actor_critics
             if env_name in parameter.name and env_name not in unknown_environments
         ]
         if new_env_name:
             unknown_environments.append(new_env_name[0])

     if unknown_environments:
         warning(f"No Actor-Critics for {unknown_environments} found, spawned new policy-value layers")
     saver = tf.train.Saver(params_to_load, max_to_keep=0)
     saver.restore(self._session, filename)
예제 #3
0
 def _save_parameters(self, directory):
     """Save every global variable of the graph under the prefix ``<directory>/params``.

     Each variable is stored under its ``var.name`` key. ``self._session``
     supplies the values that are written.

     Args:
         directory: Directory prefix for the checkpoint files.
     """
     name_to_var = dict((v.name, v) for v in tf.global_variables())
     checkpoint_prefix = f"{directory}/params"
     tf.train.Saver(name_to_var, max_to_keep=0).save(self._session, checkpoint_prefix)