def test_get_policy_model_files(self):
    """Checks that policy model files are listed in descending epoch order."""
    tmp_dir = self.get_temp_dir()

    # Epochs are written out of order on purpose, so the expected ordering
    # below cannot be a by-product of the order the files were created in.
    for epoch_to_write in [200, 100, 300]:
        model_path = ppo.get_policy_model_file_from_epoch(
            tmp_dir, epoch_to_write)
        with gfile.GFile(model_path, 'w') as model_file:
            model_file.write('some data')

    # Highest epoch first: 300, 200, 100.
    expected = [
        tmp_dir + '/model-000300.pkl',
        tmp_dir + '/model-000200.pkl',
        tmp_dir + '/model-000100.pkl',
    ]
    actual = ppo.get_policy_model_files(tmp_dir)
    self.assertEqual(expected, actual)

    gfile.rmtree(tmp_dir)
def get_newer_policy_model_file(
    output_dir,
    min_epoch=-1,
    sleep_time_secs=0.1,
    max_sleep_time_secs=1.0,
    max_tries=1,
    wait_forever=False,
):
    """Gets a policy model file subject to availability and wait time.

    Repeatedly lists the policy model files in `output_dir` and inspects the
    first one returned by `ppo.get_policy_model_files` (which is expected to
    list the newest epoch first -- confirm against that helper). After every
    unsuccessful attempt the function sleeps, doubling the sleep time each
    round up to `max_sleep_time_secs`.

    Args:
      output_dir: directory that is searched for policy model files.
      min_epoch: only a file whose epoch is strictly greater than this value
        is accepted.
      sleep_time_secs: sleep duration after the first unsuccessful attempt.
      max_sleep_time_secs: upper bound for the doubled sleep duration.
      max_tries: number of attempts to make when `wait_forever` is False.
        Zero or a negative value means no attempt is made.
      wait_forever: if True, `max_tries` is ignored and the function keeps
        polling until a suitable file shows up.

    Returns:
      A `(policy_file, epoch)` tuple for the file that was found, or None if
      the attempts were exhausted without finding a newer file.
    """
    tries_left = max_tries

    # Compare against zero explicitly: a plain truthiness check would treat a
    # negative `max_tries` as "keep going" and never terminate.
    while wait_forever or tries_left > 0:
        tries_left -= 1
        policy_files = ppo.get_policy_model_files(output_dir)

        if not policy_files:
            # No policy files at all.
            logging.info(
                "There are no policy files in [%s], waiting for %s secs.",
                output_dir, sleep_time_secs)
        else:
            # Check if we have a newer epoch.
            policy_file = policy_files[0]
            epoch = ppo.get_epoch_from_policy_model_file(policy_file)
            if epoch > min_epoch:
                # We do have a new file, return it.
                logging.info("Found epoch [%s] and policy file [%s]",
                             epoch, policy_file)
                return policy_file, epoch
            # We don't - wait.
            logging.info(
                "epoch [%s] <= min_epoch [%s], waiting for %s secs.",
                epoch, min_epoch, sleep_time_secs)

        # Back off exponentially, capped at `max_sleep_time_secs`. The sleep
        # also happens after the last attempt, which throttles callers that
        # invoke this function in their own polling loop.
        time.sleep(sleep_time_secs)
        sleep_time_secs = min(sleep_time_secs * 2, max_sleep_time_secs)

    # Exhausted our waiting limit.
    return None