    def sample(self, _):
        if len(self) < self.n_steps * self.n_envs:
            raise Exception("Not enough states received!")

        sample_n = self.n_envs * self.n_steps
        sample_states = [None] * sample_n
        sample_actions = [None] * sample_n
        sample_next_states = [None] * sample_n

        # compute the n-step returns with a vectorized reverse scan over time
        sample_returns = torch.zeros(
            (self.n_steps, self.n_envs),
            device=self._rewards[0].device
        )
        sample_lengths = torch.zeros(
            (self.n_steps, self.n_envs),
            device=self._rewards[0].device
        )
        current_returns = torch.zeros_like(self._rewards[0])
        current_lengths = torch.zeros_like(current_returns)
        for t in reversed(range(self.n_steps)):
            mask = self._states[t + 1].mask.float()
            current_returns = (
                self._rewards[t] + self.gamma * current_returns * mask
            )
            current_lengths = 1 + current_lengths * mask
            sample_returns[t] = current_returns
            sample_lengths[t] = current_lengths

        for e in range(self.n_envs):
            next_state = self._states[self.n_steps][e]
            for t in reversed(range(self.n_steps)):
                idx = t * self.n_envs + e
                state = self._states[t][e]
                action = self._actions[t][e]

                sample_states[idx] = state
                sample_actions[idx] = action
                sample_next_states[idx] = next_state

                if not state.mask:
                    next_state = state

        self._states = self._states[self.n_steps:]
        self._actions = self._actions[self.n_steps:]
        self._rewards = self._rewards[self.n_steps:]

        return (
            State.from_list(sample_states),
            sample_actions,
            sample_returns.view(-1),
            State.from_list(sample_next_states),
            sample_lengths.view(-1),
        )
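The reverse scan above builds the discounted n-step returns and rollout lengths in a single pass over time. A minimal standalone sketch of the same recursion with plain torch (shapes and values are made up; this is not the library's buffer):

import torch

n_steps, n_envs, gamma = 3, 2, 0.99
rewards = torch.ones(n_steps, n_envs)   # reward received at each step
mask = torch.ones(n_steps, n_envs)      # 0 where the next state is terminal
returns = torch.zeros(n_steps, n_envs)
lengths = torch.zeros(n_steps, n_envs)
current_returns = torch.zeros(n_envs)
current_lengths = torch.zeros(n_envs)
for t in reversed(range(n_steps)):
    # the mask truncates the return and the length at episode boundaries
    current_returns = rewards[t] + gamma * current_returns * mask[t]
    current_lengths = 1 + current_lengths * mask[t]
    returns[t] = current_returns
    lengths[t] = current_lengths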
    def _reshape(self, minibatch, weights):
        states = State.from_list([sample[0] for sample in minibatch])
        actions = [sample[1] for sample in minibatch]
        rewards = torch.tensor([sample[2] for sample in minibatch],
                               device=self.device).float()
        next_states = State.from_list([sample[3] for sample in minibatch])
        return (states, actions, rewards, next_states, weights)
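_reshape unpacks a minibatch of (state, action, reward, next_state) tuples column by column. A toy sketch of the same unpacking with plain Python values (zip-based, without the library's State objects):

minibatch = [("s0", 0, 1.0, "s1"), ("s1", 1, 0.5, "s2")]
states, actions, rewards, next_states = map(list, zip(*minibatch))
print(rewards)  # [1.0, 0.5]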
Example #3
    def sample(self, batch_size):
        if batch_size > len(self):
            raise Exception("Not enough states for batch size!")

        states = self._states[0:batch_size]
        actions = self._actions[0:batch_size]
        actions = torch.tensor(actions, device=actions[0].device)
        next_states = self._next_states[0:batch_size]
        rewards = self._rewards[0:batch_size]
        rewards = torch.tensor(rewards,
                               device=rewards[0].device,
                               dtype=torch.float)
        lengths = self._lengths[0:batch_size]
        lengths = torch.tensor(lengths,
                               device=rewards[0].device,
                               dtype=torch.float)

        self._states = self._states[batch_size:]
        self._actions = self._actions[batch_size:]
        self._next_states = self._next_states[batch_size:]
        self._rewards = self._rewards[batch_size:]
        self._lengths = self._lengths[batch_size:]

        states = State.from_list(states)
        next_states = State.from_list(next_states)
        return states, actions, rewards, next_states, lengths
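This buffer is consumed front-to-back: a batch is sliced off the head of each list and the remainder is kept for the next call. A toy illustration of that draw-and-trim pattern with plain lists (hypothetical data):

buffer = list(range(10))
batch_size = 4
batch, buffer = buffer[:batch_size], buffer[batch_size:]
print(batch)   # [0, 1, 2, 3]
print(buffer)  # [4, 5, 6, 7, 8, 9]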
Example #4
    def sample(self, _):
        if len(self) < self.n_steps * self.n_envs:
            raise Exception("Not enough states received!")

        sample_n = self.n_envs * self.n_steps
        sample_states = [None] * sample_n
        sample_actions = [None] * sample_n
        sample_returns = [0] * sample_n
        sample_next_states = [None] * sample_n
        sample_lengths = [0] * sample_n

        # compute the N-step returns the slow way
        for e in range(self.n_envs):
            for t in range(self.n_steps):
                i = t * self.n_envs + e
                state = self._states[t][e]
                action = self._actions[t][e]
                returns = 0.0
                next_state = state
                sample_length = 0
                if state.mask:
                    for k in range(1, self.n_steps + 1):
                        sample_length += 1
                        next_state = self._states[t + k][e]
                        returns += (self.gamma
                                    **(k - 1)) * self._rewards[t + k][e]
                        if not next_state.mask or t + k == self.n_steps:
                            break
                sample_states[i] = state
                sample_actions[i] = action
                sample_returns[i] = returns
                sample_next_states[i] = next_state
                sample_lengths[i] = sample_length

        self._states = [self._states[-1]]
        self._actions = [self._actions[-1]]
        self._rewards = [self._rewards[-1]]
        sample_returns = torch.tensor(sample_returns,
                                      device=self._rewards[0].device,
                                      dtype=torch.float)
        sample_lengths = torch.tensor(sample_lengths,
                                      device=self._rewards[0].device,
                                      dtype=torch.float)

        return (
            State.from_list(sample_states),
            sample_actions,
            sample_returns,
            State.from_list(sample_next_states),
            sample_lengths,
        )
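The inner loop above accumulates gamma ** (k - 1) * r_{t+k} until the horizon or an episode boundary is reached. A worked check with assumed numbers (gamma = 0.5 and rewards 1, 2, 4 over three steps):

gamma = 0.5
rewards = [1.0, 2.0, 4.0]
returns = sum(gamma ** (k - 1) * r for k, r in enumerate(rewards, start=1))
assert returns == 1.0 + 0.5 * 2.0 + 0.25 * 4.0 == 3.0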
    def _run_multi(self, make_agent, n_envs):
        envs = self.env.duplicate(n_envs)
        agent = make_agent(envs, writer=self._writer)
        for env in envs:
            env.reset()
        returns = torch.zeros(n_envs).float().to(self.env.device)
        start = timer()
        while not self._done():
            states = State.from_list([env.state for env in envs])
            rewards = torch.tensor([env.reward for env in envs]).float().to(
                self.env.device)
            actions = agent.act(states, rewards)
            for i, env in enumerate(envs):
                if env.done:
                    end = timer()
                    fps = self._frames / (end - start)
                    returns[i] += rewards[i]
                    self._log(returns[i], fps)
                    env.reset()
                    returns[i] = 0
                    self._episode += 1
                    self._writer.episodes = self._episode
                elif actions[i] is not None:
                    returns[i] += rewards[i]
                    env.step(actions[i])
                    self._frames += 1
                    self._writer.frames = self._frames
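The fps figure logged above is simply cumulative frames divided by elapsed wall-clock time; timer is assumed here to be timeit's default_timer. A small standalone sketch:

from timeit import default_timer as timer

start = timer()
frames = 0
for _ in range(1000):
    frames += 1          # stand-in for env.step(...)
fps = frames / (timer() - start)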
Example #6
    def _step(self):
        states = State.from_list([env.state for env in self._env])
        rewards = torch.tensor([env.reward for env in self._env],
                               dtype=torch.float,
                               device=self._env[0].device)
        actions = self._agent.act(states, rewards)

        for i, env in enumerate(self._env):
            self._step_env(i, env, actions[i])
Example #7
    def _train(self):
        if len(self._buffer) >= self._batch_size:
            states = State.from_list(self._features)
            _, _, returns, next_states, rollout_lengths = self._buffer.sample(
                self._batch_size)
            td_errors = (returns + (self.discount_factor ** rollout_lengths) *
                         self.v.eval(self.features.eval(next_states)) -
                         self.v(states))
            self.v.reinforce(td_errors)
            self.policy.reinforce(td_errors)
            self.features.reinforce()
            self._features = []
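The TD error in _train is the n-step target minus the current value estimate, delta = G_t + gamma ** L * V(s_{t+L}) - V(s_t), where L is the rollout length returned by the buffer. A sketch with plain torch tensors and made-up numbers:

import torch

gamma = 0.99
returns = torch.tensor([3.0, 1.0])   # n-step returns G_t
lengths = torch.tensor([3.0, 2.0])   # rollout lengths L
v_next = torch.tensor([0.5, 0.2])    # V(s_{t+L})
v_curr = torch.tensor([2.0, 0.8])    # V(s_t)
td_errors = returns + gamma ** lengths * v_next - v_curr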
Example #8
    def _summarize_transitions(self):
        sample_n = self.n_envs * self.n_steps
        sample_states = [None] * sample_n
        sample_actions = [None] * sample_n
        sample_next_states = [None] * sample_n

        for e in range(self.n_envs):
            next_state = self._states[self.n_steps][e]
            for t in reversed(range(self.n_steps)):
                idx = t * self.n_envs + e
                state = self._states[t][e]
                action = self._actions[t][e]

                sample_states[idx] = state
                sample_actions[idx] = action
                sample_next_states[idx] = next_state

                if not state.mask:
                    next_state = state

        return (State.from_list(sample_states), torch.stack(sample_actions),
                State.from_list(sample_next_states))
    def advantages(self, states):
        if len(self) < self._batch_size:
            raise Exception("Not enough states received!")

        self._states.append(states)
        states = State.from_list(self._states[0:self.n_steps + 1])
        actions = torch.cat(self._actions[:self.n_steps], dim=0)
        rewards = torch.stack(self._rewards[:self.n_steps]).view(
            self.n_steps, -1)
        _values = self.v.target(self.features.target(states)).view(
            (self.n_steps + 1, -1))
        values = _values[0:self.n_steps]
        next_values = _values[1:self.n_steps + 1]
        td_errors = rewards + self.gamma * next_values - values
        advantages = self._compute_advantages(td_errors)
        self._clear_buffers()

        return (states[0:self._batch_size], actions, advantages)
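In advantages above, the value head is evaluated once on the n_steps + 1 stacked states and the result is split into V(s_t) (rows 0..n_steps-1) and V(s_{t+1}) (rows 1..n_steps). A small sketch of that slicing, reusing the values asserted in the test below with gamma = 0.5 (assumed numbers, not a library call):

import torch

n_steps = 2
values_all = torch.tensor([[0.183, -1.408],
                           [-0.348, -1.938],
                           [-0.878, -2.468]])   # (n_steps + 1, n_envs)
values = values_all[0:n_steps]                  # V(s_t)
next_values = values_all[1:n_steps + 1]         # V(s_{t+1})
rewards = torch.tensor([[1., 1.], [2., 1.]])
td_errors = rewards + 0.5 * next_values - values
# td_errors is approximately [[0.6436, 1.439], [1.909, 1.704]]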
def validate_multi_env_agent(make_agent, base_env):
    make, n_env = make_agent
    envs = base_env.duplicate(n_env)
    agent = make(envs, writer=DummyWriter())

    for env in envs:
        env.reset()

    for _ in range(10):
        states = State.from_list([env.state for env in envs])
        rewards = [env.reward for env in envs]
        rewards = torch.tensor(rewards, device=base_env.device).float()
        actions = agent.act(states, rewards)
        for (action, env) in zip(actions, envs):
            if env.done:
                env.reset()
            elif action is not None:
                env.step(action)
    def test_parallel(self):
        buffer = GeneralizedAdvantageBuffer(self.v,
                                            self.features,
                                            2,
                                            2,
                                            discount_factor=0.5,
                                            lam=0.5)
        actions = torch.ones((2))
        states = [
            State(torch.tensor([[0], [3]])),
            State(torch.tensor([[1], [4]])),
            State(torch.tensor([[2], [5]])),
        ]
        rewards = torch.tensor([[1., 1], [2, 1], [4, 1]])
        buffer.store(states[0], actions, rewards[0])
        buffer.store(states[1], actions, rewards[1])

        values = self.v.eval(self.features.eval(State.from_list(states))).view(
            3, -1)
        tt.assert_almost_equal(values,
                               torch.tensor([[0.183, -1.408], [-0.348, -1.938],
                                             [-0.878, -2.468]]),
                               decimal=3)

        td_errors = torch.zeros(2, 2)
        td_errors[0] = rewards[0] + 0.5 * values[1] - values[0]
        td_errors[1] = rewards[1] + 0.5 * values[2] - values[1]
        tt.assert_almost_equal(td_errors,
                               torch.tensor([[0.6436, 1.439], [1.909, 1.704]]),
                               decimal=3)

        advantages = torch.zeros(2, 2)
        advantages[0] = td_errors[0] + 0.25 * td_errors[1]
        advantages[1] = td_errors[1]
        tt.assert_almost_equal(advantages,
                               torch.tensor([[1.121, 1.865], [1.909, 1.704]]),
                               decimal=3)

        _states, _actions, _advantages = buffer.advantages(states[2])
        tt.assert_almost_equal(_advantages, advantages.view(-1))
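The 0.25 weight in the test comes from gamma * lam = 0.5 * 0.5: with two stored steps, the generalized advantage estimate reduces to A_0 = delta_0 + (gamma * lam) * delta_1 and A_1 = delta_1. A quick standalone check of that arithmetic:

import torch

gamma, lam = 0.5, 0.5
td = torch.tensor([[0.6436, 1.439], [1.909, 1.704]])
adv = torch.stack([td[0] + gamma * lam * td[1], td[1]])
# adv is approximately [[1.121, 1.865], [1.909, 1.704]]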
    def _aggregate_states(self):
        return State.from_list([env.state for env in self._envs])