コード例 #1
0
def get_newer_policy_model_file(
    output_dir,
    min_epoch=-1,
    sleep_time_secs=0.1,
    max_sleep_time_secs=1.0,
    max_tries=1,
    wait_forever=False,
):
    """Polls `output_dir` for a policy model file newer than `min_epoch`.

    Each attempt lists the policy model files via
    `ppo.get_policy_model_files` and inspects only the first entry of that
    list (presumably the helper returns the newest file first -- confirm
    against `ppo`). If the epoch parsed from that file is strictly greater
    than `min_epoch`, the file is returned; otherwise the function sleeps
    and, if attempts remain, tries again.

    Args:
      output_dir: directory handed to `ppo.get_policy_model_files`.
      min_epoch: only a file whose epoch is strictly greater than this value
        is accepted.
      sleep_time_secs: seconds to sleep after the first unsuccessful attempt.
        The delay doubles after every unsuccessful attempt.
      max_sleep_time_secs: upper bound applied to the doubled delay. The
        initial `sleep_time_secs` itself is not clamped.
      max_tries: number of attempts to make when `wait_forever` is False.
        Zero or a negative value means no attempt is made at all.
      wait_forever: if True, keep polling until a newer file appears,
        regardless of `max_tries`.

    Returns:
      A `(policy_file, epoch)` tuple when a newer file is found, or None when
      the attempts are exhausted first.
    """
    # `max_tries > 0` (rather than plain truthiness) keeps a negative
    # `max_tries` from turning into an unintended infinite loop.
    while wait_forever or max_tries > 0:
        max_tries -= 1
        policy_files = ppo.get_policy_model_files(output_dir)

        if policy_files:
            policy_file = policy_files[0]
            epoch = ppo.get_epoch_from_policy_model_file(policy_file)

            # We do have a new file, return it.
            if epoch > min_epoch:
                logging.info("Found epoch [%s] and policy file [%s]", epoch,
                             policy_file)
                return policy_file, epoch

            # We don't - wait.
            logging.info("epoch [%s] <= min_epoch [%s], waiting for %s secs.",
                         epoch, min_epoch, sleep_time_secs)
        else:
            # No policy files at all.
            logging.info(
                "There are no policy files in [%s], waiting for %s secs.",
                output_dir, sleep_time_secs)

        # Exponential backoff. The sleep is kept even after the final attempt
        # so that callers polling with `max_tries=1` are still throttled.
        time.sleep(sleep_time_secs)
        sleep_time_secs = min(sleep_time_secs * 2, max_sleep_time_secs)

    # Exhausted our waiting limit.
    return None
コード例 #2
0
ファイル: ppo_test.py プロジェクト: levskaya/tensor2tensor
 def test_get_epoch_from_policy_model_file(self):
     """Checks the epoch parsed from policy model file names."""
     # (expected epoch, file name) pairs, checked in order.
     expectations = (
         (0, "model-000000.pkl"),
         (123456, "model-123456.pkl"),
     )
     for expected_epoch, file_name in expectations:
         parsed_epoch = ppo.get_epoch_from_policy_model_file(file_name)
         self.assertEqual(expected_epoch, parsed_epoch)