def get_newer_policy_model_file(
    output_dir,
    min_epoch=-1,
    sleep_time_secs=0.1,
    max_sleep_time_secs=1.0,
    max_tries=1,
    wait_forever=False,
):
    """Gets a policy model file subject to availability and wait time.

    Polls ``output_dir`` via ``ppo.get_policy_model_files`` and accepts the
    first file in the returned list if its epoch is strictly greater than
    ``min_epoch``. After every unsuccessful attempt it sleeps, doubling the
    sleep each time and capping it at ``max_sleep_time_secs``.

    Args:
      output_dir: Directory passed to ``ppo.get_policy_model_files``.
      min_epoch: A file is accepted only if its epoch is > this value.
      sleep_time_secs: Initial sleep after a failed attempt, in seconds.
      max_sleep_time_secs: Upper bound applied after each doubling of the
        sleep time.
      max_tries: Number of attempts when ``wait_forever`` is False. Values
        <= 0 mean no attempt is made.
      wait_forever: If True, keep polling regardless of ``max_tries``.

    Returns:
      ``(policy_file, epoch)`` on success, or ``None`` if the attempts were
      exhausted without finding a file whose epoch exceeds ``min_epoch``.
    """

    def _backoff(delay):
        # Sleep for `delay`, then return the next delay (doubled, capped).
        time.sleep(delay)
        return min(delay * 2, max_sleep_time_secs)

    # Use `> 0` instead of plain truthiness: a negative max_tries would never
    # decrement to 0, so the loop would otherwise spin forever.
    while wait_forever or max_tries > 0:
        max_tries -= 1
        policy_files = ppo.get_policy_model_files(output_dir)

        # No policy files at all.
        if not policy_files:
            logging.info(
                "There are no policy files in [%s], waiting for %s secs.",
                output_dir, sleep_time_secs)
            # NOTE: we also sleep after the final failed attempt (unchanged
            # behavior); callers polling with max_tries=1 may rely on it.
            sleep_time_secs = _backoff(sleep_time_secs)
            continue

        # Only the first listed file is considered -- presumably the newest;
        # TODO confirm the ordering of ppo.get_policy_model_files.
        policy_file = policy_files[0]
        epoch = ppo.get_epoch_from_policy_model_file(policy_file)

        # Not newer than what the caller already has - wait.
        if epoch <= min_epoch:
            logging.info("epoch [%s] <= min_epoch [%s], waiting for %s secs.",
                         epoch, min_epoch, sleep_time_secs)
            sleep_time_secs = _backoff(sleep_time_secs)
            continue

        # We do have a new file, return it.
        logging.info("Found epoch [%s] and policy file [%s]", epoch,
                     policy_file)
        return policy_file, epoch

    # Exhausted our waiting limit.
    return None
def test_get_epoch_from_policy_model_file(self):
    """Checks that the epoch is extracted from policy model file names."""
    # Each case pairs a file name with the epoch its numeric part encodes.
    cases = (
        ("model-000000.pkl", 0),
        ("model-123456.pkl", 123456),
    )
    for filename, expected_epoch in cases:
        actual_epoch = ppo.get_epoch_from_policy_model_file(filename)
        self.assertEqual(expected_epoch, actual_epoch)