def test_get_policy_model_files(self):
    """Checks that `ppo.get_policy_model_files` lists files by descending epoch.

    Dummy policy model files are written for epochs 200, 100 and 300 (in that
    order) into a temporary directory; the listing is expected to come back
    as 300, 200, 100. The temporary directory is removed at the end.
    """
    tmp_dir = self.get_temp_dir()

    # Create one placeholder model file per epoch, deliberately out of order
    # so that the ordering of the listing is actually exercised.
    for epoch in (200, 100, 300):
        model_path = ppo.get_policy_model_file_from_epoch(tmp_dir, epoch)
        with gfile.GFile(model_path, "w") as model_file:
            model_file.write("some data")

    # Expected order: 300, 200, 100.
    want = [
        tmp_dir + "/model-000300.pkl",
        tmp_dir + "/model-000200.pkl",
        tmp_dir + "/model-000100.pkl",
    ]
    got = ppo.get_policy_model_files(tmp_dir)
    self.assertEqual(want, got)

    gfile.rmtree(tmp_dir)
def save(self):
    """Save the agent parameters.

    Pickles the tuple `(policy_and_value_opt_state, model_state,
    total_opt_step)` into `model-<epoch, zero-padded to 6 digits>.pkl` inside
    the output directory, then deletes every policy model file that was
    listed before the write, except one whose path equals the file just
    written. Finally resets the trajectory counter and records the current
    epoch as the last saved one.
    """
    logging.vlog(1, "PPO epoch [% 6d]: saving model.", self._epoch)

    # Take the listing *before* writing, so the loop below only ever sees
    # files that pre-date this save.
    previous_files = ppo.get_policy_model_files(self._output_dir)

    checkpoint_name = "model-%06d.pkl" % self._epoch
    checkpoint_path = os.path.join(self._output_dir, checkpoint_name)
    with gfile.GFile(checkpoint_path, "wb") as checkpoint:
        state = (
            self._policy_and_value_opt_state,
            self._model_state,
            self._total_opt_step,
        )
        pickle.dump(state, checkpoint)

    # Drop the older model files, keeping the one we have just written
    # (relevant when saving twice at the same epoch).
    for stale_path in previous_files:
        if stale_path == checkpoint_path:
            continue
        gfile.remove(stale_path)

    # Bookkeeping: start counting trajectories afresh from this save.
    self._n_trajectories_done = 0
    self._last_saved_at = self._epoch
def get_newer_policy_model_file(
    output_dir,
    min_epoch=-1,
    sleep_time_secs=0.1,
    max_sleep_time_secs=1.0,
    max_tries=1,
    wait_forever=False,
):
    """Gets a policy model file subject to availability and wait time.

    Repeatedly lists the policy model files in `output_dir` and looks at the
    first entry of the listing only, so this relies on
    `ppo.get_policy_model_files` returning the newest epoch first. If that
    file's epoch is strictly greater than `min_epoch` it is returned;
    otherwise the function sleeps and tries again. The sleep time doubles
    after every wait and is capped at `max_sleep_time_secs`.

    Args:
      output_dir: Directory that is searched for policy model files.
      min_epoch: Only a file whose epoch is strictly greater than this counts
        as "newer".
      sleep_time_secs: Duration of the first wait, in seconds. It is used
        as-is for the first sleep; the cap only applies to later waits.
      max_sleep_time_secs: Upper bound for the doubled sleep time.
      max_tries: Number of listing attempts when `wait_forever` is false.
        Zero or negative values mean no attempt is made.
      wait_forever: If true, `max_tries` is ignored and polling continues
        until a newer file shows up.

    Returns:
      A `(policy_file, epoch)` tuple when a newer file is found, or a bare
      `None` (not a tuple) once the attempts are exhausted.
    """

    def do_wait(t):
        """Sleeps `t` seconds and returns the next (doubled, capped) wait."""
        time.sleep(t)
        return min(t * 2, max_sleep_time_secs)

    # Count on a local with an explicit `> 0` check: testing the truthiness
    # of the counter would treat a negative `max_tries` as "keep going" and
    # never terminate, since decrementing moves it further away from zero.
    tries_left = max_tries
    while wait_forever or tries_left > 0:
        tries_left -= 1
        policy_files = ppo.get_policy_model_files(output_dir)

        # No policy files at all.
        # NOTE: the wait also happens after the final attempt; this is kept
        # so that callers polling in their own loop still get the back-off.
        if not policy_files:
            logging.info(
                "There are no policy files in [%s], waiting for %s secs.",
                output_dir, sleep_time_secs)
            sleep_time_secs = do_wait(sleep_time_secs)
            continue

        # Check if we have a newer epoch.
        policy_file = policy_files[0]
        epoch = ppo.get_epoch_from_policy_model_file(policy_file)

        # We don't - wait.
        if epoch <= min_epoch:
            logging.info("epoch [%s] <= min_epoch [%s], waiting for %s secs.",
                         epoch, min_epoch, sleep_time_secs)
            sleep_time_secs = do_wait(sleep_time_secs)
            continue

        # We do have a new file, return it.
        logging.info("Found epoch [%s] and policy file [%s]", epoch, policy_file)
        return policy_file, epoch

    # Exhausted our waiting limit.
    return None