Esempio n. 1
0
    def test_get_policy_model_files(self):
        """get_policy_model_files lists model files by descending epoch."""
        output_dir = self.get_temp_dir()

        def write_policy_model_file(epoch):
            # Create a placeholder model file for `epoch`; content is irrelevant.
            with gfile.GFile(
                    ppo.get_policy_model_file_from_epoch(output_dir, epoch),
                    "w") as f:
                f.write("some data")

        try:
            # Written deliberately out of order, so the result's ordering is
            # produced by get_policy_model_files, not by write order.
            epochs = [200, 100, 300]
            for epoch in epochs:
                write_policy_model_file(epoch)

            # Expected order: 300, 200, 100 (epochs zero-padded to 6 digits).
            expected_policy_model_files = [
                output_dir + "/model-000300.pkl",
                output_dir + "/model-000200.pkl",
                output_dir + "/model-000100.pkl",
            ]

            policy_model_files = ppo.get_policy_model_files(output_dir)

            self.assertEqual(expected_policy_model_files, policy_model_files)
        finally:
            # Remove the files even if a write or the assertion above fails,
            # so no stale model files are left behind.
            gfile.rmtree(output_dir)
Esempio n. 2
0
 def save(self):
   """Write the agent's training state to disk and drop stale model files.

   The tuple (policy/value optimizer state, model state, total optimizer
   step count) is pickled into `model-%06d.pkl`, named after the current
   epoch, inside the output directory. Every model file that was listed
   before the write is then removed, except one whose path equals the file
   just written. Finally the trajectory counter is reset and the epoch of
   this save is recorded in `_last_saved_at`.
   """
   epoch = self._epoch
   logging.vlog(1, "PPO epoch [% 6d]: saving model.", epoch)
   # Listed before writing, so the set of deletion candidates is fixed
   # up front.
   stale_candidates = ppo.get_policy_model_files(self._output_dir)
   target_path = os.path.join(self._output_dir, "model-%06d.pkl" % epoch)
   with gfile.GFile(target_path, "wb") as out_file:
     state = (self._policy_and_value_opt_state, self._model_state,
              self._total_opt_step)
     pickle.dump(state, out_file)
   # If a file for this same epoch already existed, it was just overwritten
   # and must not be deleted.
   for stale_path in [p for p in stale_candidates if p != target_path]:
     gfile.remove(stale_path)
   self._last_saved_at = epoch
   self._n_trajectories_done = 0
Esempio n. 3
0
def get_newer_policy_model_file(
    output_dir,
    min_epoch=-1,
    sleep_time_secs=0.1,
    max_sleep_time_secs=1.0,
    max_tries=1,
    wait_forever=False,
):
    """Gets a policy model file subject to availability and wait time.

    Polls `output_dir` and considers only the first file returned by
    `ppo.get_policy_model_files`. That file is returned if its epoch is
    strictly greater than `min_epoch`. After every unsuccessful attempt
    (including the last one), it sleeps; the sleep time doubles each time,
    capped at `max_sleep_time_secs`.

    Args:
      output_dir: directory to look for policy model files in.
      min_epoch: only a file whose epoch is > min_epoch is returned.
      sleep_time_secs: sleep time after the first unsuccessful attempt.
      max_sleep_time_secs: upper bound for the doubled sleep times.
      max_tries: number of attempts; a value <= 0 means no attempt is made
        (unless `wait_forever` is set).
      wait_forever: if True, keep polling regardless of `max_tries`.

    Returns:
      A `(policy_file, epoch)` tuple, or None if no suitable file was found
      within the allowed number of attempts.
    """

    def do_wait(t):
        # Sleep for `t` seconds, then return the next (doubled, capped) delay.
        time.sleep(t)
        return min(t * 2, max_sleep_time_secs)

    # Compare with `> 0` rather than relying on truthiness: max_tries only
    # ever decreases, so a negative value would otherwise loop forever.
    while wait_forever or max_tries > 0:
        max_tries -= 1
        policy_files = ppo.get_policy_model_files(output_dir)

        # No policy files at all.
        if not policy_files:
            logging.info(
                "There are no policy files in [%s], waiting for %s secs.",
                output_dir, sleep_time_secs)
            sleep_time_secs = do_wait(sleep_time_secs)
            continue

        # Only the first listed file is checked; get_policy_model_files is
        # expected to list files by descending epoch.
        policy_file = policy_files[0]
        epoch = ppo.get_epoch_from_policy_model_file(policy_file)

        # Not newer than what the caller already has - wait.
        if epoch <= min_epoch:
            logging.info("epoch [%s] <= min_epoch [%s], waiting for %s secs.",
                         epoch, min_epoch, sleep_time_secs)
            sleep_time_secs = do_wait(sleep_time_secs)
            continue

        # We do have a new file, return it.
        logging.info("Found epoch [%s] and policy file [%s]", epoch,
                     policy_file)
        return policy_file, epoch

    # Exhausted our waiting limit.
    return None