def perform_rollout(env: gym.Env, model: torch.nn.Module, gamma: float) -> Memory:
    """Run ``model`` in ``env`` until the episode ends and return the recorded Memory.

    Args:
        env: Environment to roll out in. It is reset at the start. Its ``step`` is
            called with ``int(action)``, so a scalar/discrete action is assumed.
        model: Policy whose ``act(obs)`` returns ``(action, action_logprobs,
            state_value)`` for a batch of one observation.
        gamma: Passed to ``Memory`` together with the module-level ``BATCH_SIZE``
            (presumably the discount factor — confirm against ``Memory``).

    Returns:
        A new ``Memory`` holding, per step, the action, its log-probability, the
        state value, the reward (as a tensor) and the done flag (as a uint8 tensor).
    """
    memory = Memory(gamma, BATCH_SIZE)
    obs = env.reset()
    done = False
    while not done:
        # Use torch.tensor rather than the legacy torch.FloatTensor constructor:
        # FloatTensor(n) with a Python int n allocates an *uninitialised* tensor
        # of size n, so integer observations (discrete spaces) became garbage.
        # This also matches how collect_batch converts observations.
        obs_tensor = torch.unsqueeze(torch.tensor(obs, dtype=torch.float32), dim=0)
        # NOTE(review): model.act is not wrapped in torch.no_grad(); if it returns
        # graph-attached tensors, Memory keeps those graphs alive — confirm.
        action, action_logprobs, state_value = model.act(obs_tensor)
        obs, rew, done, _ = env.step(int(action))
        memory.update_actions(action)
        memory.update_action_logprobs(action_logprobs)
        memory.update_state_values(state_value)
        memory.update_rewards(torch.tensor(rew))
        memory.update_is_terminals(torch.tensor(done, dtype=torch.uint8))
    return memory
def collect_batch(env: gym.Env, actor: torch.nn.Module, buffer: Memory, batch_size: int, device: torch.device):
    """Append whole episodes to ``buffer`` until ``len(buffer) >= batch_size``.

    The size check only runs between episodes, and episodes are never cut
    short, so the buffer can end up larger than ``batch_size``.

    Each observation is stored once via ``buffer.add_obs``. Each transition is
    then recorded with ``buffer.add_timestep``, which refers to its two
    observations by the indices ``add_obs`` returned. ``buffer.end_rollout()``
    is called after every finished episode.

    Args:
        env: Environment to collect from. It is reset at the start of every episode.
        actor: Policy whose ``act(obs)`` returns ``(action, action_logprobs)``
            for a batch of one observation.
        buffer: Storage to append to. It must support ``len``, ``add_obs``,
            ``add_timestep`` and ``end_rollout``.
        batch_size: Minimum number of entries (as reported by ``len(buffer)``)
            to collect.
        device: Device on which observation tensors are created.
    """

    def as_obs_tensor(raw):
        # Every observation, initial or stepped, becomes a float32 tensor on `device`.
        return torch.tensor(raw, dtype=torch.float32, device=device)

    while len(buffer) < batch_size:
        current = as_obs_tensor(env.reset())
        current_idx = buffer.add_obs(current)
        episode_over = False
        while not episode_over:
            action_tensor, logprob = actor.act(current.unsqueeze(0))
            # First (only) element of the batch-of-one action, as a NumPy value.
            env_action = action_tensor.cpu().numpy()[0]
            raw_next, reward, episode_over, _info = env.step(env_action)
            current = as_obs_tensor(raw_next)
            following_idx = buffer.add_obs(current)
            buffer.add_timestep(current_idx, following_idx, env_action, logprob, reward, episode_over)
            current_idx = following_idx
        buffer.end_rollout()
def _write_rollout_slot(buffers, index: int, slot: int, outputs) -> None:
    """Copy every entry of ``outputs`` into time slot ``slot`` of rollout buffer ``index``."""
    for key in outputs:
        buffers[key][index][slot, ...] = outputs[key]


def act(
    flags,
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    buffers: Buffers,
    initial_agent_state_buffers,
):
    """Actor worker loop: fill shared rollout buffers with environment experience.

    The worker repeatedly takes a buffer index from ``free_queue``; a ``None``
    index stops the loop. For each index it:

    1. writes the last env/agent outputs of the previous rollout into slot 0;
    2. stores the agent's current ``core_state`` tensors into
       ``initial_agent_state_buffers[index]``;
    3. steps the environment ``flags.unroll_length`` times, writing each
       step's env and agent outputs into slots ``1..unroll_length``;
    4. puts the index on ``full_queue``.

    Args:
        flags: Run configuration. It is passed to ``create_env``, and
            ``flags.unroll_length`` is read from it.
        actor_index: Id of this actor. It is used in log messages and mixed into
            the env seed. Actor 0 also logs a timing summary on exit.
        free_queue: Source of buffer indices that are free to fill (``None`` = stop).
        full_queue: Receives buffer indices once their rollout is written.
        model: Agent providing ``initialize(env_output, batch_size=...)`` and
            ``act(env_output, agent_state)``.
        buffers: Mapping ``key -> sequence of tensors``, indexed as
            ``buffers[key][index][t, ...]``.
        initial_agent_state_buffers: Per-index storage for the agent's
            ``core_state`` at the start of each rollout.

    Raises:
        Any exception other than ``KeyboardInterrupt`` is logged, its traceback
        is printed, and it is re-raised. ``KeyboardInterrupt`` returns silently.
    """
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        gym_env = create_env(flags)
        # XOR the actor id with fresh OS entropy so seeds differ per actor and per run.
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)
        env = environment.Environment(gym_env)
        env_output = env.initial()
        agent_state = model.initialize(env_output, batch_size=1)
        # The actor never backpropagates. Without no_grad, this first call builds
        # an autograd graph, and its outputs are later copied into the shared
        # rollout buffers (slot 0), dragging grad tracking along with them.
        with torch.no_grad():
            agent_output, _ = model.act(env_output, agent_state)

        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            _write_rollout_slot(buffers, index, 0, env_output)
            _write_rollout_slot(buffers, index, 0, agent_output)
            for i, tensor in enumerate(agent_state["core_state"]):
                initial_agent_state_buffers[index][i][...] = tensor

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output, agent_state = model.act(env_output, agent_state)

                timings.time("model")

                env_output = env.step(agent_output["action"])

                timings.time("step")

                _write_rollout_slot(buffers, index, t + 1, env_output)
                _write_rollout_slot(buffers, index, t + 1, agent_output)

                timings.time("write")
            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        # Bare `raise` re-raises the active exception with its original traceback.
        raise