Exemple #1
0
def perform_rollout(env: gym.Env, model: torch.nn.Module,
                    gamma: float) -> Memory:
    """Run a single episode with ``model`` and record each step.

    The environment is reset once, then stepped with the actions chosen by
    ``model.act`` until it reports ``done``. At least one step is always
    taken.

    Args:
        env: Environment to roll out in.
        model: Module exposing ``act(obs)``, which returns an action, its
            log-probability and a state-value estimate.
        gamma: Forwarded to the ``Memory`` constructor together with the
            module-level ``BATCH_SIZE``.

    Returns:
        The ``Memory`` holding the actions, action log-probabilities, state
        values, rewards and terminal flags of every step taken.
    """
    rollout = Memory(gamma, BATCH_SIZE)
    raw_obs = env.reset()

    while True:
        # The model is queried with a leading dimension of size one.
        batched_obs = torch.FloatTensor(raw_obs).unsqueeze(0)
        action, action_logprobs, state_value = model.act(batched_obs)

        raw_obs, reward, finished, _ = env.step(int(action))

        rollout.update_actions(action)
        rollout.update_action_logprobs(action_logprobs)
        rollout.update_state_values(state_value)
        rollout.update_rewards(torch.tensor(reward))
        rollout.update_is_terminals(
            torch.tensor(finished, dtype=torch.uint8))

        if finished:
            break

    return rollout
Exemple #2
0
def collect_batch(env: gym.Env, actor: torch.nn.Module, buffer: Memory,
                  batch_size: int, device: torch.device):
    """Play whole episodes until ``buffer`` holds at least ``batch_size`` items.

    Episodes are never cut short: the size check happens only between
    episodes, so the buffer may end up larger than ``batch_size``.

    Args:
        env: Environment to collect experience from.
        actor: Module exposing ``act(obs)``, which returns an action and its
            log-probability.
        buffer: Storage that receives observations (``add_obs``), transitions
            (``add_timestep``) and an ``end_rollout`` call after each episode.
        batch_size: Minimum ``len(buffer)`` required before returning.
        device: Device on which observation tensors are created.
    """

    def as_obs_tensor(raw):
        # Every observation is stored as a float32 tensor on ``device``.
        return torch.tensor(raw, dtype=torch.float32, device=device)

    while len(buffer) < batch_size:
        current_obs = as_obs_tensor(env.reset())
        current_idx = buffer.add_obs(current_obs)

        episode_over = False
        while not episode_over:
            # The actor is queried with a leading dimension of size one;
            # the first element of its action output is what the env gets.
            action, action_logprobs = actor.act(current_obs.unsqueeze(0))
            action = action.cpu().numpy()[0]

            raw_next, reward, episode_over, _ = env.step(action)
            current_obs = as_obs_tensor(raw_next)
            following_idx = buffer.add_obs(current_obs)

            # A transition links the indices of its two observations.
            buffer.add_timestep(current_idx, following_idx, action,
                                action_logprobs, reward, episode_over)
            current_idx = following_idx

        buffer.end_rollout()
Exemple #3
0
def act(
    flags,
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    buffers: Buffers,
    initial_agent_state_buffers,
):
    """Actor worker loop: fill rollout buffers with environment experience.

    The worker builds its own randomly seeded environment, then repeatedly
    takes a buffer index from ``free_queue``, writes one rollout of
    ``flags.unroll_length`` steps into that slot and hands the index back
    through ``full_queue``. Receiving ``None`` from ``free_queue`` stops the
    loop.

    Slot layout for each index: position ``0`` holds the environment and agent
    outputs that ended the previous rollout (or the initial ones), positions
    ``1..unroll_length`` hold the outputs of the new steps. The agent's
    ``"core_state"`` at the start of the rollout is copied into
    ``initial_agent_state_buffers[index]``.

    Args:
        flags: Run configuration; ``unroll_length`` is read here and the whole
            object is forwarded to ``create_env``.
        actor_index: Identifier of this worker, used for logging and mixed
            into the environment seed. Only actor ``0`` logs timings on exit.
        free_queue: Source of buffer indices that may be overwritten.
        full_queue: Sink for buffer indices whose rollout is complete.
        model: Policy exposing ``initialize(env_output, batch_size=...)`` and
            ``act(env_output, agent_state)``.
        buffers: Mapping from output key to per-index rollout tensors.
        initial_agent_state_buffers: Per-index storage for the agent state at
            the start of each rollout.

    Raises:
        Exception: Any non-``KeyboardInterrupt`` exception is logged, its
            traceback printed, and then re-raised. ``KeyboardInterrupt`` is
            swallowed so the worker exits quietly.
    """
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        gym_env = create_env(flags)
        # Mix OS entropy with the actor index so each worker gets its own seed.
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)
        env = environment.Environment(gym_env)
        env_output = env.initial()

        agent_state = model.initialize(env_output, batch_size=1)
        # Prime ``agent_output`` for slot 0 of the first rollout. Like the
        # calls inside the rollout loop, this is pure inference, so it must
        # not record an autograd graph that the buffer copies would keep.
        with torch.no_grad():
            agent_output, _ = model.act(env_output, agent_state)

        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0, ...] = agent_output[key]
            for i, tensor in enumerate(agent_state["core_state"]):
                initial_agent_state_buffers[index][i][...] = tensor

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output, agent_state = model.act(
                        env_output, agent_state)

                timings.time("model")

                env_output = env.step(agent_output["action"])

                timings.time("step")

                for key in env_output:
                    buffers[key][index][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][t + 1, ...] = agent_output[key]

                timings.time("write")
            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        # Bare ``raise`` re-raises the active exception unchanged.
        raise