def test(flags, num_episodes: int = 10):
    """Loads the latest checkpoint and evaluates the model for `num_episodes` episodes."""
    if flags.xpid is None:
        checkpointpath = "./latest/model.tar"
    else:
        checkpointpath = os.path.expandvars(
            os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
        )
    gym_env = create_env(flags)
    env = environment.Environment(gym_env)
    model = Net(gym_env.observation_space.shape, gym_env.action_space.n, flags.use_lstm)
    model.eval()
    checkpoint = torch.load(checkpointpath, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])

    observation = env.initial()
    returns = []
    while len(returns) < num_episodes:
        if flags.mode == "test_render":
            env.gym_env.render()
        agent_outputs = model(observation)
        policy_outputs, _ = agent_outputs
        observation = env.step(policy_outputs["action"])
        if observation["done"].item():
            returns.append(observation["episode_return"].item())
            logging.info(
                "Episode ended after %d steps. Return: %.1f",
                observation["episode_step"].item(),
                observation["episode_return"].item(),
            )
    env.close()
    logging.info(
        "Average returns over %i episodes: %.1f",
        num_episodes,
        sum(returns) / len(returns),
    )
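# --- Usage sketch (added for illustration, not part of the original file) ---
# A minimal example of how `test` might be driven from a flag parser. The flag
# names (--xpid, --savedir, --mode, --use_lstm) are the ones `test` reads above;
# the parser defaults and `_example_test_run` itself are hypothetical.
def _example_test_run():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--xpid", default=None, help="Experiment id; None loads ./latest/model.tar.")
    parser.add_argument("--savedir", default="~/logs", help="Root directory for checkpoints.")
    parser.add_argument("--mode", default="test", choices=["test", "test_render"])
    parser.add_argument("--use_lstm", action="store_true")
    flags = parser.parse_args()
    test(flags, num_episodes=10)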
def act(
    flags,
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    buffers: Buffers,
    initial_agent_state_buffers,
):
    """Runs rollouts on one actor process and writes them into the shared buffers."""
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        gym_env = create_env(flags)
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)
        env = environment.Environment(gym_env)
        env_output = env.initial()
        agent_state = model.initial_state(batch_size=1)
        agent_output, unused_state = model(env_output, agent_state)
        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            for i, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][i][...] = tensor

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output, agent_state = model(env_output, agent_state)

                timings.time("model")

                env_output = env.step(agent_output["action"])

                timings.time("step")

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                timings.time("write")
            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise e
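# --- Launch sketch (added for illustration, not part of the original file) ---
# How the actor loop above is typically wired up in the torchbeast design: the
# trainer starts one process per actor, seeds the free queue with every buffer
# index, and later drains the full queue to assemble learner batches.
# `flags.num_actors`, `flags.num_buffers`, and `_example_launch_actors` are
# assumptions here, not part of the original file.
def _example_launch_actors(flags, model, buffers, initial_agent_state_buffers):
    ctx = mp.get_context("fork")
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    actor_processes = []
    for i in range(flags.num_actors):
        actor = ctx.Process(
            target=act,
            args=(flags, i, free_queue, full_queue, model, buffers, initial_agent_state_buffers),
        )
        actor.start()
        actor_processes.append(actor)

    # Every rollout buffer starts out free; actors recycle indices through the queues.
    for m in range(flags.num_buffers):
        free_queue.put(m)
    return actor_processes, free_queue, full_queue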
def act(
    flags,
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    buffers: Buffers,
    initial_agent_state_buffers,
):
    """Variant of `act` that threads transformer memory (`mems`) through the model
    and ends rollouts at episode boundaries instead of running past them."""
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        gym_env = create_env(flags)
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)
        env = environment.Environment(gym_env)
        env_output = env.initial()
        agent_state = model.initial_state(batch_size=1)
        mems, mem_padding = None, None
        agent_output, unused_state, mems, mem_padding, _ = model(
            env_output, agent_state, mems, mem_padding
        )
        while True:
            index = free_queue.get()
            if index is None:
                break

            # No need to force env_output["done"] to False here: when an episode
            # ends, we now step out of the done state at the end of the rollout
            # below, so the loop condition can be evaluated as-is.

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            for i, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][i][...] = tensor

            # Do one new rollout, up to flags.unroll_length steps or until the
            # episode ends, whichever comes first.
            t = 0
            while t < flags.unroll_length and not env_output["done"].item():
                timings.reset()

                with torch.no_grad():
                    agent_output, agent_state, mems, mem_padding, _ = model(
                        env_output, agent_state, mems, mem_padding
                    )

                timings.time("model")

                # TODO(Shakti): add action repeat?
                env_output = env.step(agent_output["action"])

                timings.time("step")

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                timings.time("write")
                t += 1

            if env_output["done"].item():
                # Reset the transformer memory at the episode boundary (this check
                # used to live inside the loop, where it could never fire) and take
                # an arbitrary step so the environment resets.
                mems = None
                env_output = env.step(torch.tensor([2]))

            if t != flags.unroll_length:
                # The rollout ended early; mark the remaining steps as done so the
                # learner can mask them out.
                # TODO(Shakti): double-check this padding logic.
                buffers["done"][index][t + 1:] = torch.tensor([True]).repeat(
                    flags.unroll_length - t
                )

            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        raise e
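# --- Padding check (added for illustration, not part of the original file) ---
# A standalone sketch of the 'done'-padding logic in the second `act` variant:
# when an episode ends at step t < unroll_length, entries t+1 .. unroll_length of
# the rollout's 'done' column are set True so the learner treats the tail as
# padding. `_check_done_padding` is a hypothetical helper that mirrors the buffer
# write above on a fresh tensor.
def _check_done_padding(unroll_length: int = 8, t: int = 3):
    # One rollout's 'done' column: T + 1 entries, all initially False.
    done = torch.zeros(unroll_length + 1, dtype=torch.bool)
    if t != unroll_length:
        # Same assignment as in `act`: the slice [t + 1:] has exactly
        # unroll_length - t entries, matching the repeated tensor.
        done[t + 1:] = torch.tensor([True]).repeat(unroll_length - t)
    assert bool(done[t + 1:].all()) and not bool(done[: t + 1].any())
    return done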