def learn(self, experiences): """Update critics and actors""" rewards = to_tensor(experiences['reward']).float().to(self.device) dones = to_tensor(experiences['done']).type(torch.int).to(self.device) states = to_tensor(experiences['state']).float().to(self.device) actions = to_tensor(experiences['action']).to(self.device) next_states = to_tensor(experiences['next_state']).float().to(self.device) assert rewards.shape == dones.shape == (self.batch_size, 1) assert states.shape == next_states.shape == (self.batch_size, self.state_size) assert actions.shape == (self.batch_size, self.action_size) indices = None if hasattr(self.buffer, 'priority_update'): # When using PER buffer indices = experiences['index'] loss_critic = self.compute_value_loss(states, actions, next_states, rewards, dones, indices) # Value (critic) optimization self.critic_optimizer.zero_grad() loss_critic.backward() nn.utils.clip_grad_norm_(self.actor_params, self.max_grad_norm_critic) self.critic_optimizer.step() self._loss_critic = float(loss_critic.item()) # Policy (actor) optimization loss_actor = self.compute_policy_loss(states) self.actor_optimizer.zero_grad() loss_actor.backward() nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm_actor) self.actor_optimizer.step() self._loss_actor = float(loss_actor.item()) # Networks gradual sync soft_update(self.target_actor, self.actor, self.tau) soft_update(self.target_critic, self.critic, self.tau)
def step(self, state, action, reward, next_state, done) -> None:
    """Lets the agent take a step.

    On some steps the agent will also initiate a learning step.
    This depends on the `update_freq` value.

    Parameters:
        state: S(t)
        action: A(t)
        reward: R(t)
        next_state: S(t+1)
        done: (bool) Whether the state is terminal.

    """
    self.iteration += 1
    state = to_tensor(self.state_transform(state)).float().to("cpu")
    next_state = to_tensor(self.state_transform(next_state)).float().to("cpu")
    reward = self.reward_transform(reward)

    # Delay adding to buffer to account for n_steps (particularly the reward)
    self.n_buffer.add(state=state.numpy(), action=[int(action)], reward=[reward],
                      done=[done], next_state=next_state.numpy())
    if not self.n_buffer.available:
        return

    self.buffer.add(**self.n_buffer.get().get_dict())

    if self.iteration < self.warm_up:
        return

    if len(self.buffer) >= self.batch_size and (self.iteration % self.update_freq) == 0:
        for _ in range(self.number_updates):
            self.learn(self.buffer.sample())

        # Update networks only once - sync local & target
        soft_update(self.target_net, self.net, self.tau)
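
# `self.n_buffer` above exposes `add`, `available`, and `get().get_dict()`; its
# implementation is not part of this listing. The class below is a rough sketch
# of one way such an n-step accumulator could work: it keeps the last `n_steps`
# transitions and, once full, emits the oldest one with its reward replaced by
# the discounted n-step return. Names and the plain-dict return value are
# illustrative, not the library's API.
from collections import deque


class NStepAccumulator:
    def __init__(self, n_steps: int, gamma: float):
        self.n_steps = n_steps
        self.gamma = gamma
        self.queue = deque(maxlen=n_steps)

    @property
    def available(self) -> bool:
        return len(self.queue) == self.n_steps

    def add(self, **transition) -> None:
        self.queue.append(transition)

    def get(self) -> dict:
        """Return the oldest transition with an n-step discounted reward."""
        transition = dict(self.queue[0])
        reward, discount, last = 0.0, 1.0, self.queue[0]
        for t in self.queue:
            reward += discount * t['reward'][0]
            discount *= self.gamma
            last = t
            if t['done'][0]:
                break
        transition['reward'] = [reward]
        # Bootstrap from the last state actually reached within the window
        transition['next_state'] = last['next_state']
        transition['done'] = last['done']
        return transition
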
def learn(self, experiences: Dict[str, List]) -> None:
    """
    Parameters:
        experiences: Contains all experiences for the agent. Typically sampled
            from the memory buffer. Five keys are expected, i.e. `state`, `action`,
            `reward`, `next_state`, `done`. Each key contains an array and all
            arrays have to have the same length.

    """
    rewards = to_tensor(experiences['reward']).float().to(self.device)
    dones = to_tensor(experiences['done']).type(torch.int).to(self.device)
    states = to_tensor(experiences['state']).float().to(self.device)
    next_states = to_tensor(experiences['next_state']).float().to(self.device)
    actions = to_tensor(experiences['action']).type(torch.long).to(self.device)

    assert rewards.shape == dones.shape == (self.batch_size, 1)
    assert states.shape == next_states.shape == (self.batch_size, self.state_size)
    assert actions.shape == (self.batch_size, 1)  # Discrete domain

    with torch.no_grad():
        prob_next = self.target_net.act(next_states)
        q_next = (prob_next * self.z_atoms).sum(-1) * self.z_delta
        if self.using_double_q:
            duel_prob_next = self.net.act(next_states)
            a_next = torch.argmax((duel_prob_next * self.z_atoms).sum(-1), dim=-1)
        else:
            a_next = torch.argmax(q_next, dim=-1)

        prob_next = prob_next[self.__batch_indices, a_next, :]

    m = self.net.dist_projection(rewards, 1 - dones, self.gamma ** self.n_steps, prob_next)
    assert m.shape == (self.batch_size, self.num_atoms)

    log_prob = self.net(states, log_prob=True)
    assert log_prob.shape == (self.batch_size, self.action_size, self.num_atoms)
    log_prob = log_prob[self.__batch_indices, actions.squeeze(), :]
    assert log_prob.shape == m.shape == (self.batch_size, self.num_atoms)

    # Cross-entropy error per sample; the loss is the batch mean
    error = -torch.sum(m * log_prob, 1)
    assert error.shape == (self.batch_size,)
    loss = error.mean()
    assert loss >= 0

    self.optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
    self.optimizer.step()
    self._loss = float(loss.item())

    if hasattr(self.buffer, 'priority_update'):
        assert (~torch.isnan(error)).all()
        self.buffer.priority_update(experiences['index'], error.detach().cpu().numpy())

    # Update networks - sync local & target
    soft_update(self.target_net, self.net, self.tau)
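
# The categorical projection `dist_projection` is a method on the network and is
# not shown here. The sketch below is the standard C51 projection of
# r + gamma^n * z onto the fixed support; `v_min`, `v_max`, and `num_atoms` are
# placeholder defaults rather than the agent's actual configuration.
import torch


def dist_projection(rewards, masks, discount, prob_next,
                    v_min=-10.0, v_max=10.0, num_atoms=51):
    """Project the Bellman-updated atoms back onto the fixed support (sketch)."""
    batch_size = rewards.shape[0]
    z_delta = (v_max - v_min) / (num_atoms - 1)
    z_atoms = torch.linspace(v_min, v_max, num_atoms, device=rewards.device)

    # Bellman update of each atom position, clamped to the support range
    tz = (rewards + masks * discount * z_atoms.view(1, -1)).clamp(v_min, v_max)
    b = (tz - v_min) / z_delta
    l, u = b.floor().long(), b.ceil().long()
    # When b lands exactly on an atom, shift one bound so no mass is dropped
    l[(u > 0) & (l == u)] -= 1
    u[(l < num_atoms - 1) & (l == u)] += 1

    # Distribute each atom's probability mass to its two neighbouring support atoms
    m = torch.zeros(batch_size, num_atoms, device=rewards.device)
    offset = (torch.arange(batch_size, device=rewards.device) * num_atoms).view(-1, 1)
    m.view(-1).index_add_(0, (l + offset).view(-1), (prob_next * (u.float() - b)).view(-1))
    m.view(-1).index_add_(0, (u + offset).view(-1), (prob_next * (b - l.float())).view(-1))
    return m
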
def learn(self, experiences): """Update critics and actors""" rewards = to_tensor(experiences['reward']).float().to( self.device).unsqueeze(1) dones = to_tensor(experiences['done']).type(torch.int).to( self.device).unsqueeze(1) states = to_tensor(experiences['state']).float().to(self.device) actions = to_tensor(experiences['action']).to(self.device) next_states = to_tensor(experiences['next_state']).float().to( self.device) if (self.iteration % self.update_freq) == 0: self._update_value_function(states, actions, rewards, next_states, dones) if (self.iteration % self.update_policy_freq) == 0: self._update_policy(states) soft_update(self.target_actor, self.actor, self.tau) soft_update(self.target_critic, self.critic, self.tau)
def learn(self, samples):
    """Update the critics and actors of all the agents."""
    rewards = to_tensor(samples['reward']).float().to(self.device).view(self.batch_size, 1)
    dones = to_tensor(samples['done']).int().to(self.device).view(self.batch_size, 1)
    states = to_tensor(samples['state']).float().to(self.device).view(self.batch_size, self.state_size)
    next_states = to_tensor(samples['next_state']).float().to(self.device).view(self.batch_size, self.state_size)
    actions = to_tensor(samples['action']).to(self.device).view(self.batch_size, self.action_size)

    # Critic (value) update
    for _ in range(self.critic_number_updates):
        value_loss, error = self.compute_value_loss(states, actions, rewards, next_states, dones)
        self.critic_optimizer.zero_grad()
        value_loss.backward()
        nn.utils.clip_grad_norm_(self.critic_params, self.max_grad_norm_critic)
        self.critic_optimizer.step()
        self._loss_critic = value_loss.item()

    # Actor (policy) update
    for _ in range(self.actor_number_updates):
        policy_loss = self.compute_policy_loss(states)
        self.actor_optimizer.zero_grad()
        policy_loss.backward()
        nn.utils.clip_grad_norm_(self.actor_params, self.max_grad_norm_actor)
        self.actor_optimizer.step()
        self._loss_actor = policy_loss.item()

    if hasattr(self.memory, 'priority_update'):
        assert (~torch.isnan(error)).all()
        self.memory.priority_update(samples['index'], error.abs())

    soft_update(self.target_double_critic, self.double_critic, self.tau)
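
# `compute_value_loss` here returns both a scalar loss and a per-sample `error`
# that later feeds `priority_update`. The function below sketches that contract
# under the assumption of a single critic and a one-step TD target; the actual
# agent may differ (the final `soft_update` suggests a double critic).
import torch


def value_loss_with_error(critic, target_critic, target_actor,
                          states, actions, rewards, next_states, dones, gamma=0.99):
    """Return (scalar loss, per-sample TD error) so PER priorities can be updated."""
    with torch.no_grad():
        q_next = target_critic(next_states, target_actor(next_states))
        q_target = rewards + gamma * (1 - dones) * q_next
    q_expected = critic(states, actions)
    error = q_expected - q_target     # per-sample TD error, shape (batch_size, 1)
    loss = error.pow(2).mean()        # MSE over the batch
    return loss, error.detach()
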
def learn(self, experiences) -> None:
    """Update critics and actors"""
    rewards = to_tensor(experiences['reward']).float().to(self.device).unsqueeze(1)
    dones = to_tensor(experiences['done']).type(torch.int).to(self.device).unsqueeze(1)
    states = to_tensor(experiences['state']).float().to(self.device)
    actions = to_tensor(experiences['action']).to(self.device)
    next_states = to_tensor(experiences['next_state']).float().to(self.device)

    assert rewards.shape == dones.shape == (self.batch_size, 1)
    assert states.shape == next_states.shape == (self.batch_size, self.state_size)
    assert actions.shape == (self.batch_size, self.action_size)

    # Value (critic) optimization
    loss_critic = self.compute_value_loss(states, actions, next_states, rewards, dones)
    self.critic_optimizer.zero_grad()
    loss_critic.backward()
    nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm_critic)
    self.critic_optimizer.step()
    self._loss_critic = float(loss_critic.item())

    # Policy (actor) optimization
    loss_actor = self.compute_policy_loss(states)
    self.actor_optimizer.zero_grad()
    loss_actor.backward()
    nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm_actor)
    self.actor_optimizer.step()
    self._loss_actor = float(loss_actor.item())

    # Soft update target weights
    soft_update(self.target_actor, self.actor, self.tau)
    soft_update(self.target_critic, self.critic, self.tau)
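
# `compute_policy_loss` above is the actor-side counterpart of the critic loss.
# For a deterministic-policy-gradient agent it typically reduces to maximizing
# the critic's estimate of the actor's own actions; the free function below is
# an illustration of that idea, not the agent's actual method.
def dpg_policy_loss(actor, critic, states):
    """Deterministic policy gradient actor loss (sketch): maximize Q(s, mu(s))."""
    pred_actions = actor(states)
    # Negative sign because the optimizer minimizes the returned loss
    return -critic(states, pred_actions).mean()
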
def update_targets(self):
    """Soft update targets"""
    for agent in self.agents.values():
        soft_update(agent.target_actor, agent.actor, self.tau)
    soft_update(self.target_critic, self.critic, self.tau)
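
# `soft_update` is used by every method above but is not defined in this
# listing. The standard Polyak-averaging form it most likely takes:
import torch.nn as nn


def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    """Blend source weights into target: theta_target <- tau * theta + (1 - tau) * theta_target."""
    for target_param, source_param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(tau * source_param.data + (1.0 - tau) * target_param.data)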