コード例 #1
0
ファイル: ppo_test.py プロジェクト: wangleiphy/trax
  def test_get_policy_model_files(self):
    """Checks that model files are listed from highest to lowest epoch."""
    tmp_dir = self.get_temp_dir()

    # The files are written in a non-sorted epoch order (200, 100, 300).
    for epoch in (200, 100, 300):
      model_path = ppo.get_policy_model_file_from_epoch(tmp_dir, epoch)
      with gfile.GFile(model_path, 'w') as model_file:
        model_file.write('some data')

    # Expected listing: epochs 300, 200, 100.
    descending_names = (
        '/model-000300.pkl',
        '/model-000200.pkl',
        '/model-000100.pkl',
    )
    want = [tmp_dir + name for name in descending_names]

    got = ppo.get_policy_model_files(tmp_dir)
    self.assertEqual(want, got)

    # Remove everything this test wrote.
    gfile.rmtree(tmp_dir)
コード例 #2
0
def get_newer_policy_model_file(
    output_dir,
    min_epoch=-1,
    sleep_time_secs=0.1,
    max_sleep_time_secs=1.0,
    max_tries=1,
    wait_forever=False,
):
    """Gets a policy model file subject to availability and wait time.

    Polls ``output_dir`` until the first file reported by
    ``ppo.get_policy_model_files`` has an epoch strictly greater than
    ``min_epoch``, sleeping between polls with a doubling delay.

    Args:
        output_dir: Directory that is searched for policy model files.
        min_epoch: Only a file whose epoch is strictly greater than this value
            is returned.
        sleep_time_secs: Delay, in seconds, used after the first unsuccessful
            poll; it is doubled after every sleep.
        max_sleep_time_secs: Upper bound applied to the doubled delay.
        max_tries: Maximum number of polls when ``wait_forever`` is false.
            Zero or a negative value means no poll is made at all.
        wait_forever: If true, keep polling regardless of ``max_tries``.

    Returns:
        A ``(policy_file, epoch)`` tuple for the file that was found, or
        ``None`` if the allowed number of tries ran out first.
    """

    def sleep_and_back_off(delay):
        """Sleeps for `delay` seconds and returns the next (capped) delay."""
        time.sleep(delay)
        return min(delay * 2, max_sleep_time_secs)

    # `max_tries > 0` (rather than plain truthiness) keeps a negative budget
    # from being treated as "still has tries left", which never terminated.
    while wait_forever or max_tries > 0:
        max_tries -= 1
        policy_files = ppo.get_policy_model_files(output_dir)

        # No policy files at all.
        if not policy_files:
            logging.info(
                "There are no policy files in [%s], waiting for %s secs.",
                output_dir, sleep_time_secs)
            sleep_time_secs = sleep_and_back_off(sleep_time_secs)
            continue

        # Only the first entry is inspected; this presumes the listing puts
        # the most recent epoch first -- confirm against
        # ppo.get_policy_model_files.
        policy_file = policy_files[0]
        epoch = ppo.get_epoch_from_policy_model_file(policy_file)

        # Nothing newer than what the caller already has - wait.
        if epoch <= min_epoch:
            logging.info("epoch [%s] <= min_epoch [%s], waiting for %s secs.",
                         epoch, min_epoch, sleep_time_secs)
            sleep_time_secs = sleep_and_back_off(sleep_time_secs)
            continue

        # We do have a new file, return it.
        logging.info("Found epoch [%s] and policy file [%s]", epoch,
                     policy_file)
        return policy_file, epoch

    # Exhausted our waiting limit.
    return None