Example #1
 def update_targets(self):
     """soft update targets"""
     for ddpg_agent in self.maddpg_agent:
         soft_update(ddpg_agent.target_actor, ddpg_agent.actor,
                     self.config.tau)
         soft_update(ddpg_agent.target_critic, ddpg_agent.critic,
                     self.config.tau)
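
Every example in this listing delegates to a soft_update helper that is never shown. The sketch below is a minimal, assumed implementation of the usual Polyak-averaging update in PyTorch, following the soft_update(target, source, tau) argument order used in Example #1; note that some examples (e.g. Example #3) pass the local network first, so the argument order is repo-specific.

import torch

def soft_update(target, source, tau):
    # Assumed helper, not taken from any one of the repositories above.
    # Polyak averaging: theta_target <- tau * theta_source + (1 - tau) * theta_target
    with torch.no_grad():
        for target_param, source_param in zip(target.parameters(),
                                              source.parameters()):
            target_param.data.copy_(tau * source_param.data +
                                    (1.0 - tau) * target_param.data)

With tau = 1.0 the same update degenerates to a hard copy, which is how the target networks are commonly initialised.
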
Example #2
    def update(self, pol_opt, q_opt, v_opt, batch_size):
        # `iters`, `gamma`, and `alpha` are assumed to be hyperparameters
        # defined in the enclosing scope (module-level constants or attributes)
        for _ in range(iters):
            states, batch_actions, rewards, next_states, masks = self.replay_memory.sample_and_split(batch_size)
            actions, log_probs, _ = self.select_action(states)

            # Value-function update: the target is the minimum Q estimate
            # minus the policy log-probability
            state_actions = torch.cat([states, actions], dim=-1)
            v_targ = self.q_fn.get_min_value(state_actions) - log_probs
            v_values = self.v_fn(states)
            v_fn_loss = torch.mean(torch.cat([(v - v_targ.detach()) ** 2 for v in v_values], dim=-1))
            v_opt.zero_grad()
            v_fn_loss.backward()
            v_opt.step()
            
            # Q-function update: one-step TD target bootstrapped from the
            # target value network; masks zero out terminal transitions
            q_targ = rewards + masks * gamma * self.v_fn_targ.get_min_value(next_states)
            batch_state_actions = torch.cat([states, batch_actions], dim=-1)
            q_values = self.q_fn(batch_state_actions)
            q_fn_loss = torch.mean(torch.cat([(q - q_targ.detach()) ** 2 for q in q_values], dim=-1))
            q_opt.zero_grad()
            q_fn_loss.backward()
            q_opt.step()

        # Policy update on a fresh batch: minimise alpha * log_pi - Q
        states, _, _, _, _ = self.replay_memory.sample_and_split(batch_size)
        actions, log_probs, _ = self.select_action(states)
        state_actions = torch.cat([states, actions], dim=-1)
        policy_loss = torch.mean(alpha * log_probs - self.q_fn(state_actions)[0])
        pol_opt.zero_grad()
        policy_loss.backward()
        pol_opt.step()
        # Polyak-average the target value network towards the online one
        utils.soft_update(self.v_fn_targ, self.v_fn)
        return policy_loss.item(), q_fn_loss.item()
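
Example #2 calls self.q_fn.get_min_value(...) and iterates over the outputs of self.v_fn and self.q_fn, which suggests twin (clipped double) critic heads. The class below is a hypothetical sketch of such a wrapper that is consistent with those calls; it is not taken from the original repository.

import torch
import torch.nn as nn

class TwinCritic(nn.Module):
    # Hypothetical twin-head critic: forward() returns both estimates,
    # get_min_value() their element-wise minimum (clipped double Q).
    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        self.head1 = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU(),
                                   nn.Linear(hidden_dim, 1))
        self.head2 = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU(),
                                   nn.Linear(hidden_dim, 1))

    def forward(self, x):
        return self.head1(x), self.head2(x)

    def get_min_value(self, x):
        return torch.min(self.head1(x), self.head2(x))
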
Example #3
    def learn(self, samples):
        states, actions, rewards, next_states, dones = samples

        # Train the critic:
        # get next-state actions from the target actor, then their Q-values
        # from the target critic
        actions_next = self.target_actor(next_states)
        Q_targets_next = self.target_critic(next_states, actions_next)

        # Compute the one-step TD targets
        Q_targets = rewards + (self.gamma * Q_targets_next * (1 - dones))

        # Compute critic loss, perform backward pass and training step
        Q_expected = self.critic(states, actions)
        critic_loss = F.mse_loss(Q_expected, Q_targets)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), 1)
        self.critic_optimizer.step()

        # Update Actor
        # Compute Actor loss
        actions_pred = self.actor(states)
        # negative sign because we want to maximise the expected Q-value
        actor_loss = -self.critic(states, actions_pred).mean()

        # minimizing the loss
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # update target networks
        soft_update(self.critic, self.target_critic, self.tau)
        soft_update(self.actor, self.target_actor, self.tau)
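
A learn() method like the one in Example #3 is normally driven from an environment-step hook once the replay buffer holds a full batch. The snippet below is only an assumed driver for illustration; agent, buffer, and BATCH_SIZE are hypothetical names, not part of the original code.

BATCH_SIZE = 128  # hypothetical batch size

def on_step(agent, buffer, state, action, reward, next_state, done):
    # Store the transition, then learn from a random batch once enough
    # transitions are available.
    buffer.add(state, action, reward, next_state, done)
    if len(buffer) >= BATCH_SIZE:
        agent.learn(buffer.sample(BATCH_SIZE))  # tuple of batched tensors
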
Example #4
 def update_targets(self):
     """soft update targets"""
     self.iter += 1
     for ddpg_agent in self.maddpg_agent:
         soft_update(ddpg_agent.actor_target,
                     ddpg_agent.actor_local, self.params["tau"])
         soft_update(ddpg_agent.critic_target,
                     ddpg_agent.critic_local, self.params["tau"])
Example #5
 def update_targets(self):
     """soft update of critic and actor target networks for all agents"""
     self.iter += 1
     for ddpg_agent in self.maddpg_agent:
         soft_update(ddpg_agent.target_actor, ddpg_agent.local_actor,
                     self.tau)
         soft_update(ddpg_agent.target_critic, ddpg_agent.local_critic,
                     self.tau)
Example #6
 def update_targets(self, update_critic):
     """soft update targets"""
     self.iter += 1
     for ddpg_agent in self.maddpg_agent:
         utilities.soft_update(ddpg_agent.actor_target,
                               ddpg_agent.actor_local, self.tau)
         if update_critic:
             utilities.soft_update(ddpg_agent.critic_target,
                                   ddpg_agent.critic_local, self.tau)
Example #7
    def update(self,
               buffer: ReplayBuffer,
               batchsize: int = 1000,
               tau: float = 0.005,
               discount: float = 0.98):

        states, actions, rewards, states_next, dones = buffer.sample(
            batchsize=batchsize)

        # TD target: r + discount * (1 - done) * Q'(s', mu'(s')), detached so
        # that no gradient flows through the target networks
        actions_next = self.target_actor(torch.stack(states_next).float())
        input_target_critic = torch.cat(
            [torch.stack(states_next).float(),
             actions_next.float()], axis=1)
        q_next = self.target_critic(input_target_critic)
        state_value = (torch.tensor(rewards).float().unsqueeze(1) +
                       discount * (1 - torch.tensor(dones).float().unsqueeze(1)) *
                       q_next)
        state_value = state_value.detach()

        input_critic = torch.cat(
            [torch.stack(states).float(),
             torch.stack(actions).float()],
            axis=1)
        state_value_local = self.critic(input_critic)

        critic_loss = (state_value -
                       state_value_local).pow(2).mul(0.5).sum(-1).mean()
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # update actor
        actions_new = self.actor(torch.stack(states).float())
        value_critic = self.critic(
            torch.cat([torch.stack(states).float(), actions_new], axis=1))
        loss_actor = -value_critic.mean()

        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        self.actor_optimizer.step()
        soft_update(self.target_actor, self.actor, tau)
        soft_update(self.target_critic, self.critic, tau)
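
Example #7 stacks the sampled states and actions with torch.stack and wraps rewards and dones in torch.tensor, so its buffer evidently returns parallel Python lists rather than pre-batched tensors. Below is a minimal, hypothetical ReplayBuffer sketch that would satisfy that interface; the field layout is an assumption, not the original class.

import random
from collections import deque

class ReplayBuffer:
    # Hypothetical buffer whose sample() returns five parallel lists
    # (tensor states/actions, scalar rewards/dones), matching Example #7.
    def __init__(self, maxlen=100000):
        self.memory = deque(maxlen=maxlen)

    def add(self, state, action, reward, state_next, done):
        self.memory.append((state, action, reward, state_next, done))

    def sample(self, batchsize):
        batch = random.sample(self.memory, batchsize)
        states, actions, rewards, states_next, dones = zip(*batch)
        return (list(states), list(actions), list(rewards),
                list(states_next), list(dones))

    def __len__(self):
        return len(self.memory)
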
Example #8
 def update_targets(self, agent, agent_critic):
     """soft update targets"""
     self.iter += 1
     soft_update(agent.target_actor, agent.actor, self.tau)
     soft_update(agent_critic.target_critic, agent_critic.critic, self.tau)
     agent.noise.reset()
Example #9
 def update_targets(self, agent_num):
     """soft update targets"""
     self.iter += 1
     ddpg_agent = self.maddpg_agent[agent_num]
     soft_update(ddpg_agent.target_actor, ddpg_agent.actor, self.tau)
     soft_update(ddpg_agent.target_critic, ddpg_agent.critic, self.tau)
Example #10
 def soft_update(self):
     """soft update targets"""
     for agent in self.agents_list:
         soft_update(agent.actor_target, agent.actor_local, TAU)
         soft_update(agent.critic_target, agent.critic_local, TAU)
Example #11
 def update_targets(self, i):
     """soft update targets"""
     soft_update(self.agents[i].target_actor, self.agents[i].actor, TAU)
     soft_update(self.agents[i].target_critic, self.agents[i].critic, TAU)
Example #12
 def update_targets(self, tau=0.005):
     soft_update(self.target_actor, self.actor, tau)
     soft_update(self.target_critic, self.critic, tau)
Example #13
 def update_targets(self):
     """soft update targets"""
     self.update_count += 1
     for ddpg_agent in self.maddpg_agent:
         soft_update(ddpg_agent.target_actor, ddpg_agent.actor, self.tau)
         soft_update(ddpg_agent.target_critic, ddpg_agent.critic, self.tau)
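
The MADDPG-style variants above are usually invoked on a fixed schedule after each learning step. The sketch below is a hypothetical driver for illustration only; UPDATE_EVERY and the maddpg object are assumptions, not part of any example.

UPDATE_EVERY = 2  # hypothetical update period (in learning steps)

def maybe_update_targets(maddpg, step_count):
    # Blend each agent's target networks towards its local networks
    # every UPDATE_EVERY learning steps.
    if step_count % UPDATE_EVERY == 0:
        maddpg.update_targets()
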