Example #1
    def learn(self, experiences):
        """Update critics and actors"""
        rewards = to_tensor(experiences['reward']).float().to(self.device)
        dones = to_tensor(experiences['done']).type(torch.int).to(self.device)
        states = to_tensor(experiences['state']).float().to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        next_states = to_tensor(experiences['next_state']).float().to(self.device)
        assert rewards.shape == dones.shape == (self.batch_size, 1)
        assert states.shape == next_states.shape == (self.batch_size, self.state_size)
        assert actions.shape == (self.batch_size, self.action_size)

        indices = None
        if hasattr(self.buffer, 'priority_update'):  # When using PER buffer
            indices = experiences['index']
        loss_critic = self.compute_value_loss(states, actions, next_states, rewards, dones, indices)

        # Value (critic) optimization
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm_critic)
        self.critic_optimizer.step()
        self._loss_critic = float(loss_critic.item())

        # Policy (actor) optimization
        loss_actor = self.compute_policy_loss(states)
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm_actor)
        self.actor_optimizer.step()
        self._loss_actor = float(loss_actor.item())

        # Networks gradual sync
        soft_update(self.target_actor, self.actor, self.tau)
        soft_update(self.target_critic, self.critic, self.tau)
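
All of these snippets rely on a `soft_update` helper that gradually blends the online network into its target. A minimal sketch of such a helper, assuming both networks share the same architecture and that the blend is standard Polyak averaging (the library's own version may differ in details):

import torch
import torch.nn as nn

def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    """Polyak averaging: target <- tau * source + (1 - tau) * target."""
    with torch.no_grad():
        for target_param, source_param in zip(target.parameters(), source.parameters()):
            target_param.data.mul_(1.0 - tau)
            target_param.data.add_(tau * source_param.data)
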
Example #2
    def step(self, state, action, reward, next_state, done) -> None:
        """Letting the agent to take a step.

        On some steps the agent will initiate learning step. This is dependent on
        the `update_freq` value.

        Parameters:
            state: S(t)
            action: A(t)
            reward: R(t)
            nexxt_state: S(t+1)
            done: (bool) Whether the state is terminal.

        """
        self.iteration += 1
        state = to_tensor(self.state_transform(state)).float().to("cpu")
        next_state = to_tensor(self.state_transform(next_state)).float().to("cpu")
        reward = self.reward_transform(reward)

        # Delay adding to buffer to account for n_steps (particularly the reward)
        self.n_buffer.add(state=state.numpy(), action=[int(action)], reward=[reward], done=[done], next_state=next_state.numpy())
        if not self.n_buffer.available:
            return

        self.buffer.add(**self.n_buffer.get().get_dict())

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) >= self.batch_size and (self.iteration % self.update_freq) == 0:
            for _ in range(self.number_updates):
                self.learn(self.buffer.sample())

            # Update networks only once - sync local & target
            soft_update(self.target_net, self.net, self.tau)
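
The `self.n_buffer` above is an n-step transition aggregator whose `add`/`available`/`get().get_dict()` API comes from the surrounding library. A simplified, hypothetical stand-in (class name, constructor arguments, and the plain-dict return value are all assumptions) that illustrates the idea of discounting the rewards inside the window and emitting one combined transition:

from collections import deque

class NStepAggregator:
    """Hypothetical n-step aggregator: keeps the last `n_steps` transitions and
    combines them into a single transition with a discounted reward sum."""

    def __init__(self, n_steps: int, gamma: float):
        self.n_steps = n_steps
        self.gamma = gamma
        self.window = deque(maxlen=n_steps)

    def add(self, **transition) -> None:
        self.window.append(transition)

    @property
    def available(self) -> bool:
        return len(self.window) == self.n_steps

    def get(self) -> dict:
        # Discounted sum of the rewards inside the window.
        reward = sum((self.gamma ** i) * t['reward'][0] for i, t in enumerate(self.window))
        first, last = self.window[0], self.window[-1]
        return dict(state=first['state'], action=first['action'], reward=[reward],
                    done=last['done'], next_state=last['next_state'])
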
Example #3
    def learn(self, experiences: Dict[str, List]) -> None:
        """
        Parameters:
            experiences: Contains all experiences for the agent. Typically sampled from the memory buffer.
                Five keys are expected, i.e. `state`, `action`, `reward`, `next_state`, `done`.
                Each key contains an array, and all arrays must have the same length.

        """
        rewards = to_tensor(experiences['reward']).float().to(self.device)
        dones = to_tensor(experiences['done']).type(torch.int).to(self.device)
        states = to_tensor(experiences['state']).float().to(self.device)
        next_states = to_tensor(experiences['next_state']).float().to(self.device)
        actions = to_tensor(experiences['action']).type(torch.long).to(self.device)
        assert rewards.shape == dones.shape == (self.batch_size, 1)
        assert states.shape == next_states.shape == (self.batch_size, self.state_size)
        assert actions.shape == (self.batch_size, 1)  # Discrete domain

        with torch.no_grad():
            prob_next = self.target_net.act(next_states)
            q_next = (prob_next * self.z_atoms).sum(-1) * self.z_delta
            if self.using_double_q:
                duel_prob_next = self.net.act(next_states)
                a_next = torch.argmax((duel_prob_next * self.z_atoms).sum(-1), dim=-1)
            else:
                a_next = torch.argmax(q_next, dim=-1)

            prob_next = prob_next[self.__batch_indices, a_next, :]

        m = self.net.dist_projection(rewards, 1 - dones, self.gamma ** self.n_steps, prob_next)
        assert m.shape == (self.batch_size, self.num_atoms)

        log_prob = self.net(states, log_prob=True)
        assert log_prob.shape == (self.batch_size, self.action_size, self.num_atoms)
        log_prob = log_prob[self.__batch_indices, actions.squeeze(), :]
        assert log_prob.shape == m.shape == (self.batch_size, self.num_atoms)

        # Per-sample cross-entropy error; the final loss is the batch mean
        error = -torch.sum(m * log_prob, 1)
        assert error.shape == (self.batch_size,)
        loss = error.mean()
        assert loss >= 0

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
        self.optimizer.step()
        self._loss = float(loss.item())

        if hasattr(self.buffer, 'priority_update'):
            assert (~torch.isnan(error)).any()
            self.buffer.priority_update(experiences['index'], error.detach().cpu().numpy())

        # Update networks - sync local & target
        soft_update(self.target_net, self.net, self.tau)
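
The heavy lifting in this example is `self.net.dist_projection`, which projects the discounted target distribution back onto the fixed support `z_atoms` (a C51-style categorical projection). A hedged, free-standing sketch of that projection, written against the shapes asserted above (`rewards` and `masks = 1 - dones` are `(batch, 1)`, `prob_next` is `(batch, num_atoms)`); the explicit `z_atoms`, `v_min`, `v_max` parameters are assumptions standing in for the network's own attributes:

import torch

def dist_projection(rewards, masks, discount, prob_next, z_atoms, v_min, v_max):
    """C51-style categorical projection sketch; the library's method may differ.

    Applies the Bellman update to every atom of the support, clamps it to
    [v_min, v_max], and splits each atom's probability mass between the two
    nearest atoms of the fixed support.
    """
    batch_size, num_atoms = prob_next.shape
    z_delta = (v_max - v_min) / (num_atoms - 1)

    # Bellman-updated atom positions, clamped to the support.
    tz = (rewards + discount * masks * z_atoms.unsqueeze(0)).clamp(v_min, v_max)
    b = (tz - v_min) / z_delta        # fractional index into the support
    lower, upper = b.floor().long(), b.ceil().long()

    # When b lands exactly on an atom the two neighbours coincide and the mass
    # would vanish, so nudge the indices apart (standard C51 trick).
    lower[(upper > 0) & (lower == upper)] -= 1
    upper[(lower < num_atoms - 1) & (lower == upper)] += 1

    # Distribute each atom's mass to its lower/upper neighbours.
    m = torch.zeros_like(prob_next)
    offset = torch.arange(batch_size, device=prob_next.device).unsqueeze(1) * num_atoms
    m.view(-1).index_add_(0, (lower + offset).view(-1), (prob_next * (upper.float() - b)).view(-1))
    m.view(-1).index_add_(0, (upper + offset).view(-1), (prob_next * (b - lower.float())).view(-1))
    return m
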
Example #4
    def learn(self, experiences):
        """Update critics and actors"""
        rewards = to_tensor(experiences['reward']).float().to(
            self.device).unsqueeze(1)
        dones = to_tensor(experiences['done']).type(torch.int).to(
            self.device).unsqueeze(1)
        states = to_tensor(experiences['state']).float().to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        next_states = to_tensor(experiences['next_state']).float().to(
            self.device)

        if (self.iteration % self.update_freq) == 0:
            self._update_value_function(states, actions, rewards, next_states,
                                        dones)

        if (self.iteration % self.update_policy_freq) == 0:
            self._update_policy(states)

            soft_update(self.target_actor, self.actor, self.tau)
            soft_update(self.target_critic, self.critic, self.tau)
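
`_update_value_function` is not shown in this snippet; a hedged sketch of what a DDPG/TD3-style critic update typically does with these arguments. `target_actor`, `target_critic`, `critic`, and `actor` appear in the example above; `critic_optimizer`, `gamma`, and `_loss_critic` are assumed attribute names:

    def _update_value_function(self, states, actions, rewards, next_states, dones):
        """Sketch: regress the online critic onto a bootstrapped TD target that
        is computed with the *target* actor and critic."""
        with torch.no_grad():
            next_actions = self.target_actor(next_states)
            q_next = self.target_critic(next_states, next_actions)
            q_target = rewards + self.gamma * (1 - dones) * q_next

        q_expected = self.critic(states, actions)
        loss = torch.nn.functional.mse_loss(q_expected, q_target)

        self.critic_optimizer.zero_grad()
        loss.backward()
        self.critic_optimizer.step()
        self._loss_critic = float(loss.item())
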
Example #5
    def learn(self, samples):
        """update the critics and actors of all the agents """

        rewards = to_tensor(samples['reward']).float().to(self.device).view(
            self.batch_size, 1)
        dones = to_tensor(samples['done']).int().to(self.device).view(
            self.batch_size, 1)
        states = to_tensor(samples['state']).float().to(self.device).view(
            self.batch_size, self.state_size)
        next_states = to_tensor(samples['next_state']).float().to(
            self.device).view(self.batch_size, self.state_size)
        actions = to_tensor(samples['action']).to(self.device).view(
            self.batch_size, self.action_size)

        # Critic (value) update
        for _ in range(self.critic_number_updates):
            value_loss, error = self.compute_value_loss(
                states, actions, rewards, next_states, dones)
            self.critic_optimizer.zero_grad()
            value_loss.backward()
            nn.utils.clip_grad_norm_(self.critic_params,
                                     self.max_grad_norm_critic)
            self.critic_optimizer.step()
            self._loss_critic = value_loss.item()

        # Actor (policy) update
        for _ in range(self.actor_number_updates):
            policy_loss = self.compute_policy_loss(states)
            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            nn.utils.clip_grad_norm_(self.actor_params,
                                     self.max_grad_norm_actor)
            self.actor_optimizer.step()
            self._loss_actor = policy_loss.item()

        if hasattr(self.memory, 'priority_update'):
            assert (~torch.isnan(error)).any()
            self.memory.priority_update(samples['index'], error.abs())

        soft_update(self.target_double_critic, self.double_critic, self.tau)
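
Here `compute_value_loss` returns both the scalar loss and a per-sample error that is fed back into the prioritized buffer. A hedged sketch of a clipped double-Q version consistent with the `double_critic`/`target_double_critic` pair synced above; the tuple-returning critic forward, `target_actor`, and `gamma` are assumptions, not confirmed by the snippet:

    def compute_value_loss(self, states, actions, rewards, next_states, dones):
        """Sketch: clipped double-Q TD target plus a per-sample error for PER."""
        with torch.no_grad():
            next_actions = self.target_actor(next_states)
            q1_next, q2_next = self.target_double_critic(next_states, next_actions)
            q_target = rewards + self.gamma * (1 - dones) * torch.min(q1_next, q2_next)

        q1, q2 = self.double_critic(states, actions)
        error = (q_target - q1).detach()   # per-sample TD error for priority updates
        loss = torch.nn.functional.mse_loss(q1, q_target) + torch.nn.functional.mse_loss(q2, q_target)
        return loss, error
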
Example #6
    def learn(self, experiences) -> None:
        """Update critics and actors"""
        rewards = to_tensor(experiences['reward']).float().to(
            self.device).unsqueeze(1)
        dones = to_tensor(experiences['done']).type(torch.int).to(
            self.device).unsqueeze(1)
        states = to_tensor(experiences['state']).float().to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        next_states = to_tensor(experiences['next_state']).float().to(
            self.device)
        assert rewards.shape == dones.shape == (self.batch_size, 1)
        assert states.shape == next_states.shape == (self.batch_size,
                                                     self.state_size)
        assert actions.shape == (self.batch_size, self.action_size)

        # Value (critic) optimization
        loss_critic = self.compute_value_loss(states, actions, next_states,
                                              rewards, dones)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(),
                                 self.max_grad_norm_critic)
        self.critic_optimizer.step()
        self._loss_critic = float(loss_critic.item())

        # Policy (actor) optimization
        loss_actor = self.compute_policy_loss(states)
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(),
                                 self.max_grad_norm_actor)
        self.actor_optimizer.step()
        self._loss_actor = loss_actor.item()

        # Soft update target weights
        soft_update(self.target_actor, self.actor, self.tau)
        soft_update(self.target_critic, self.critic, self.tau)
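
The actor objective `compute_policy_loss` is also left out of the snippet; for a deterministic-policy agent like this one it is usually the negated critic value of the actor's own actions. A minimal sketch under that assumption, reusing the `actor` and `critic` attributes from the example:

    def compute_policy_loss(self, states):
        """Sketch: ascend Q(s, actor(s)) by minimizing its negation."""
        return -self.critic(states, self.actor(states)).mean()
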
Example #7
    def update_targets(self):
        """Soft update the target networks."""
        for agent in self.agents.values():
            soft_update(agent.target_actor, agent.actor, self.tau)
        soft_update(self.target_critic, self.critic, self.tau)
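
In a multi-agent setup like this one (per-agent actors, one shared critic), `update_targets` is typically called once per learning iteration after every agent has been updated. A hypothetical driver loop illustrating where it fits; `learn_all`, `samples_per_agent`, and the per-agent `learn` call are illustrative names only:

    def learn_all(self, samples_per_agent):
        """Hypothetical: update each agent against the shared critic, then
        soft-sync all target networks in one place."""
        for name, agent in self.agents.items():
            agent.learn(samples_per_agent[name])
        self.update_targets()
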