Example #1
    def train(self):

        # Sample
        batch = self.replay_buffer.sample()
        obs = batch["obs"].to(self.device)
        acts = batch["acts"].to(self.device)
        rews = batch["rews"].to(self.device)
        next_obs = batch["next_obs"].to(self.device)
        done = batch["done"].to(self.device)

        # Compute target Q value
        with torch.no_grad():
            next_act = self.target_actor_net(next_obs)
            next_Q = self.target_critic_net(next_obs, next_act).squeeze(1)
            target_Q = rews + (1. - done) * self.gamma * next_Q

        # Compute current Q
        current_Q = self.critic_net(obs, acts).squeeze(1)

        # Compute critic loss
        critic_loss = F.mse_loss(current_Q, target_Q)

        # Compute actor loss
        actor_loss = -self.critic_net(obs, self.actor_net(obs)).mean()

        # Optimize actor net
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Optimize critic net
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        soft_target_update(self.actor_net, self.target_actor_net, tau=self.tau)
        soft_target_update(self.critic_net,
                           self.target_critic_net,
                           tau=self.tau)

        self.train_step += 1
        return actor_loss.cpu().item(), critic_loss.cpu().item()
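
Every example here calls a soft_target_update helper whose implementation is not shown. Below is a minimal Polyak-averaging sketch with the same name and call signature; the body itself is an assumption, not the repository's code.

import torch

def soft_target_update(net: torch.nn.Module,
                       target_net: torch.nn.Module,
                       tau: float = 0.005) -> None:
    """Polyak averaging: target <- tau * online + (1 - tau) * target."""
    with torch.no_grad():
        for param, target_param in zip(net.parameters(),
                                       target_net.parameters()):
            target_param.data.mul_(1.0 - tau)
            target_param.data.add_(tau * param.data)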
Example #2
    def train(self):
        # Sample
        batch = self.data_buffer.sample()
        obs = batch["obs"].to(self.device)
        acts = batch["acts"].to(self.device)
        rews = batch["rews"].to(self.device)
        next_obs = batch["next_obs"].to(self.device)
        done = batch["done"].to(self.device)
        """
        Train Critic
        """
        with torch.no_grad():
            decode_action_next = self.target_actor_net(next_obs,
                                                       self.cvae_net.decode)

            target_q1 = self.target_critic_net1(next_obs, decode_action_next)
            target_q2 = self.target_critic_net2(next_obs, decode_action_next)

            target_q = (
                self.lmbda * torch.min(target_q1, target_q2) +
                (1. - self.lmbda) * torch.max(target_q1, target_q2)).squeeze(1)
            target_q = rews + self.gamma * (1. - done) * target_q

        current_q1 = self.critic_net1(obs, acts).squeeze(1)
        current_q2 = self.critic_net2(obs, acts).squeeze(1)

        critic_loss = (F.mse_loss(current_q1, target_q) +
                       F.mse_loss(current_q2, target_q))

        self.critic_optimizer1.zero_grad()
        self.critic_optimizer2.zero_grad()
        critic_loss.backward()
        self.critic_optimizer1.step()
        self.critic_optimizer2.step()
        """
        Train Actor
        """
        decode_action = self.actor_net(obs, self.cvae_net.decode)
        actor_loss = -self.critic_net1(obs, decode_action).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        """
        Update target networks
        """
        soft_target_update(self.critic_net1,
                           self.target_critic_net1,
                           tau=self.tau)
        soft_target_update(self.critic_net2,
                           self.target_critic_net2,
                           tau=self.tau)
        soft_target_update(self.actor_net, self.target_actor_net, tau=self.tau)

        self.train_step += 1

        return critic_loss.cpu().item(), actor_loss.cpu().item()
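
The target in this example is a λ-weighted blend of the two target critics (soft clipped double Q-learning). Written as a standalone helper it looks like the sketch below; the function name and the default lmbda are assumptions.

import torch

def soft_clipped_double_q(q1: torch.Tensor,
                          q2: torch.Tensor,
                          lmbda: float = 0.75) -> torch.Tensor:
    """Convex combination of the element-wise min and max of two Q estimates.

    lmbda = 1.0 recovers clipped double Q-learning (pure min); smaller values
    keep part of the max and make the target less pessimistic.
    """
    return lmbda * torch.min(q1, q2) + (1.0 - lmbda) * torch.max(q1, q2)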
Example #3
    def train(self):

        # Sample
        batch = self.replay_buffer.sample()
        obs = batch["obs"].to(self.device)
        acts = batch["acts"].to(self.device)
        rews = batch["rews"].to(self.device)
        next_obs = batch["next_obs"].to(self.device)
        done = batch["done"].to(self.device)

        # Target Policy Smoothing. Add clipped noise to next actions when computing target Q.
        with torch.no_grad():
            noise = torch.normal(mean=0, std=self.policy_noise, size=acts.size()).to(self.device)
            noise = noise.clamp(-self.noise_clip, self.noise_clip)
            next_act = self.target_actor_net(next_obs) + noise
            next_act = next_act.clamp(-self.action_bound, self.action_bound)

            # Clipped Double Q-Learning. Compute the min of target Q1 and target Q2
            min_target_q = torch.min(self.target_critic_net1(next_obs, next_act),
                                     self.target_critic_net2(next_obs, next_act)).squeeze(1)
            y = rews + self.gamma * (1. - done) * min_target_q

        current_q1 = self.critic_net1(obs, acts).squeeze(1)
        current_q2 = self.critic_net2(obs, acts).squeeze(1)

        # TD3 Loss
        critic_loss1 = F.mse_loss(current_q1, y)
        critic_loss2 = F.mse_loss(current_q2, y)

        # Optimize critic net
        self.critic_optimizer1.zero_grad()
        critic_loss1.backward()
        self.critic_optimizer1.step()

        self.critic_optimizer2.zero_grad()
        critic_loss2.backward()
        self.critic_optimizer2.step()

        if (self.train_step+1) % self.policy_delay == 0:
            # Compute actor loss
            actor_loss = -self.critic_net1(obs, self.actor_net(obs)).mean()
            # Optimize actor net
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            soft_target_update(self.actor_net, self.target_actor_net, tau=self.tau)
            soft_target_update(self.critic_net1, self.target_critic_net1, tau=self.tau)
            soft_target_update(self.critic_net2, self.target_critic_net2, tau=self.tau)
        else:
            actor_loss = torch.tensor(0)

        self.train_step += 1

        return actor_loss.cpu().item(), critic_loss1.cpu().item(), critic_loss2.cpu().item()
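
Each train() starts by drawing a dict of tensors with keys "obs", "acts", "rews", "next_obs" and "done" from a buffer. A minimal replay-buffer sketch that produces exactly that layout (the class name, storage scheme and fixed batch size are assumptions):

import numpy as np
import torch

class SimpleReplayBuffer:
    """Fixed-size FIFO buffer whose sample() returns the dict used above."""

    def __init__(self, obs_dim, act_dim, capacity, batch_size):
        self.obs = np.zeros((capacity, obs_dim), dtype=np.float32)
        self.acts = np.zeros((capacity, act_dim), dtype=np.float32)
        self.rews = np.zeros(capacity, dtype=np.float32)
        self.next_obs = np.zeros((capacity, obs_dim), dtype=np.float32)
        self.done = np.zeros(capacity, dtype=np.float32)
        self.capacity, self.batch_size = capacity, batch_size
        self.ptr, self.size = 0, 0

    def add(self, obs, act, rew, next_obs, done):
        self.obs[self.ptr] = obs
        self.acts[self.ptr] = act
        self.rews[self.ptr] = rew
        self.next_obs[self.ptr] = next_obs
        self.done[self.ptr] = float(done)
        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self):
        idx = np.random.randint(0, self.size, size=self.batch_size)
        return {"obs": torch.as_tensor(self.obs[idx]),
                "acts": torch.as_tensor(self.acts[idx]),
                "rews": torch.as_tensor(self.rews[idx]),
                "next_obs": torch.as_tensor(self.next_obs[idx]),
                "done": torch.as_tensor(self.done[idx])}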
Example #4
    def train(self):
        # Sample
        batch = self.data_buffer.sample()
        obs = batch["obs"].to(self.device)
        acts = batch["acts"].to(self.device)
        rews = batch["rews"].to(self.device)
        next_obs = batch["next_obs"].to(self.device)
        done = batch["done"].to(self.device)
        """
        Train the behaviour-cloning policy so that more than one action sample
        can be drawn for the MMD penalty. A conditional VAE serves as the
        behaviour-cloning policy in BEAR.
        """
        recon_action, mu, log_std = self.cvae_net(obs, acts)
        cvae_loss = self.cvae_net.loss_function(recon_action, acts, mu,
                                                log_std)

        self.cvae_optimizer.zero_grad()
        cvae_loss.backward()
        self.cvae_optimizer.step()
        """
        Critic Training
        """
        with torch.no_grad():
            # generate n_target_samples actions for every next_obs (same as BCQ)
            next_obs = torch.repeat_interleave(next_obs,
                                               repeats=self.n_target_samples,
                                               dim=0).to(self.device)
            # compute target Q values of actions sampled from the current policy
            target_q1 = self.target_q_net1(next_obs,
                                           self.policy_net(next_obs)[0])
            target_q2 = self.target_q_net2(next_obs,
                                           self.policy_net(next_obs)[0])
            # soft clipped double q-learning
            target_q = self.lmbda * torch.min(target_q1, target_q2) + (
                1. - self.lmbda) * torch.max(target_q1, target_q2)
            # take the max over the sampled actions for each next_obs
            target_q = target_q.reshape(obs.shape[0], self.n_target_samples,
                                        1).max(1)[0].squeeze(1)
            target_q = rews + self.gamma * (1. - done) * target_q

        # compute current Q
        current_q1 = self.q_net1(obs, acts).squeeze(1)
        current_q2 = self.q_net2(obs, acts).squeeze(1)
        # compute critic loss
        critic_loss1 = F.mse_loss(current_q1, target_q)
        critic_loss2 = F.mse_loss(current_q2, target_q)

        self.q_optimizer1.zero_grad()
        critic_loss1.backward()
        self.q_optimizer1.step()

        self.q_optimizer2.zero_grad()
        critic_loss2.backward()
        self.q_optimizer2.step()

        # MMD Loss
        # sample raw (pre-squash) actions from the behaviour policy (CVAE) and the current policy, shape (B x N x D)
        raw_sampled_actions = self.cvae_net.decode_multiple_without_squash(
            obs, decode_num=self.n_mmd_action_samples, z_device=self.device)
        raw_actor_actions = self.policy_net.sample_multiple_without_squash(
            obs, sample_num=self.n_mmd_action_samples)
        if self.kernel_type == 'gaussian':
            mmd_loss = self.mmd_loss_gaussian(raw_sampled_actions,
                                              raw_actor_actions,
                                              sigma=self.mmd_sigma)
        else:
            mmd_loss = self.mmd_loss_laplacian(raw_sampled_actions,
                                               raw_actor_actions,
                                               sigma=self.mmd_sigma)
        """
        Alpha prime training (Lagrangian multiplier update for the MMD loss weight)
        """
        alpha_prime_loss = -(self.log_alpha_prime.exp() *
                             (mmd_loss - self.lagrange_thresh)).mean()
        self.alpha_prime_optimizer.zero_grad()
        alpha_prime_loss.backward(retain_graph=True)
        self.alpha_prime_optimizer.step()

        self.log_alpha_prime.data.clamp_(min=-5.0,
                                         max=10.0)  # clip for stability
        """
        Actor Training
        Actor Loss = alpha_prime * MMD Loss - min Q(s, a)
        """
        a, log_prob, _ = self.policy_net(obs)
        min_q = torch.min(self.q_net1(obs, a), self.q_net2(obs, a)).squeeze(1)
        # policy_loss = (self.alpha * log_prob - min_q).mean()  # SAC Type
        policy_loss = -(min_q.mean())

        # BEAR Actor Loss
        actor_loss = (self.log_alpha_prime.exp() * mmd_loss).mean()
        if self.train_step > self.warmup_step:
            actor_loss = policy_loss + actor_loss
        self.policy_optimizer.zero_grad()
        # mmd_loss shares its autograd graph with alpha_prime_loss, hence retain_graph=True above.
        actor_loss.backward()
        self.policy_optimizer.step()

        soft_target_update(self.q_net1, self.target_q_net1, tau=self.tau)
        soft_target_update(self.q_net2, self.target_q_net2, tau=self.tau)

        self.train_step += 1

        return (critic_loss1.cpu().item(), critic_loss2.cpu().item(),
                policy_loss.cpu().item(), alpha_prime_loss.cpu().item())
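
The BEAR update penalises the policy with an MMD between actions decoded from the CVAE and actions sampled from the policy, both of shape (B, N, D). The mmd_loss_gaussian / mmd_loss_laplacian methods are not shown; the sketch below is one plausible Gaussian-kernel MMD estimator with that interface, written as a free-function stand-in rather than the repository's implementation.

import torch

def mmd_loss_gaussian(samples_a: torch.Tensor,
                      samples_b: torch.Tensor,
                      sigma: float = 20.0) -> torch.Tensor:
    """Per-state biased MMD estimate with a Gaussian kernel.

    samples_a, samples_b: (B, N, D) action samples; returns a (B, 1) tensor.
    """
    # pairwise squared distances between samples, shapes (B, N, N)
    d_aa = (samples_a.unsqueeze(2) - samples_a.unsqueeze(1)).pow(2).sum(-1)
    d_ab = (samples_a.unsqueeze(2) - samples_b.unsqueeze(1)).pow(2).sum(-1)
    d_bb = (samples_b.unsqueeze(2) - samples_b.unsqueeze(1)).pow(2).sum(-1)
    k_aa = torch.exp(-d_aa / (2.0 * sigma))
    k_ab = torch.exp(-d_ab / (2.0 * sigma))
    k_bb = torch.exp(-d_bb / (2.0 * sigma))
    mmd_sq = (k_aa.mean(dim=(1, 2)) + k_bb.mean(dim=(1, 2))
              - 2.0 * k_ab.mean(dim=(1, 2)))
    return torch.sqrt(mmd_sq.clamp(min=1e-6)).unsqueeze(1)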
Example #5
    def train(self):

        # Sample
        batch = self.replay_buffer.sample()
        obs = batch["obs"]
        acts = batch["acts"]
        rews = batch["rews"]
        next_obs = batch["next_obs"]
        done = batch["done"]

        # compute policy Loss
        a, log_prob = self.policy_net(obs)
        min_q = torch.min(self.q_net1(obs, a), self.q_net2(obs, a)).squeeze(1)
        policy_loss = (self.alpha * log_prob - min_q).mean()

        # compute Q Loss
        q1 = self.q_net1(obs, acts).squeeze(1)
        q2 = self.q_net2(obs, acts).squeeze(1)
        with torch.no_grad():
            next_a, next_log_prob = self.policy_net(next_obs)
            min_target_next_q = torch.min(
                self.target_q_net1(next_obs, next_a),
                self.target_q_net2(next_obs, next_a)).squeeze(1)
            y = rews + self.gamma * (1. - done) * (min_target_next_q -
                                                   self.alpha * next_log_prob)

        q_loss1 = F.mse_loss(q1, y)
        q_loss2 = F.mse_loss(q2, y)

        # Update policy network parameters (the policy network should be updated first, otherwise its gradient would be wrong)
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Update q network1 parameter
        self.q_optimizer1.zero_grad()
        q_loss1.backward()
        self.q_optimizer1.step()

        # Update q network2 parameter
        self.q_optimizer2.zero_grad()
        q_loss2.backward()
        self.q_optimizer2.step()

        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_prob + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()

            self.alpha = self.log_alpha.exp()
        else:
            alpha_loss = torch.tensor(0)

        self.train_step += 1

        soft_target_update(self.q_net1, self.target_q_net1, tau=self.tau)
        soft_target_update(self.q_net2, self.target_q_net2, tau=self.tau)

        return (q_loss1.item(), q_loss2.item(), policy_loss.item(),
                alpha_loss.item())
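
The automatic_entropy_tuning branch relies on log_alpha, target_entropy and alpha_optimizer being created elsewhere. A sketch of a typical initialisation follows; the -|A| heuristic for the target entropy, the action dimension and the learning rate are assumptions.

import torch

act_dim = 6                                      # hypothetical action dimension
device = torch.device("cpu")
target_entropy = -float(act_dim)                 # common heuristic: -|A|
log_alpha = torch.zeros(1, requires_grad=True, device=device)
alpha = log_alpha.exp()                          # current temperature
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)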
Example #6
    def train(self):

        # Sample
        batch = self.data_buffer.sample()
        obs = batch["obs"].to(self.device)
        acts = batch["acts"].to(self.device)
        rews = batch["rews"].to(self.device)
        next_obs = batch["next_obs"].to(self.device)
        done = batch["done"].to(self.device)

        """
        SAC Loss
        """
        # compute policy Loss
        a, log_prob, _ = self.policy_net(obs)
        min_q = torch.min(self.q_net1(obs, a), self.q_net2(obs, a)).squeeze(1)
        policy_loss = (self.alpha * log_prob - min_q).mean()

        # compute Q Loss
        q1 = self.q_net1(obs, acts).squeeze(1)
        q2 = self.q_net2(obs, acts).squeeze(1)
        with torch.no_grad():
            if not self.max_q_backup:
                next_a, next_log_prob, _ = self.policy_net(next_obs)
                min_target_next_q = torch.min(self.target_q_net1(next_obs, next_a),
                                              self.target_q_net2(next_obs, next_a)).squeeze(1)
                if self.entropy_backup:
                    # y = rews + self.gamma * (1. - done) * (min_target_next_q - self.alpha * next_log_prob)
                    min_target_next_q = min_target_next_q - self.alpha * next_log_prob
            else:
                """when using max q backup"""
                next_a_temp, _ = self.get_policy_actions(next_obs, n_action_samples=self.n_action_samples)
                target_qf1_values = self.get_actions_values(next_obs, next_a_temp, self.n_action_samples, self.q_net1).max(1)[0]
                target_qf2_values = self.get_actions_values(next_obs, next_a_temp, self.n_action_samples, self.q_net2).max(1)[0]
                min_target_next_q = torch.min(target_qf1_values, target_qf2_values).squeeze(1)

            y = rews + self.gamma * (1. - done) * min_target_next_q

        q_loss1 = F.mse_loss(q1, y)
        q_loss2 = F.mse_loss(q2, y)

        """
        CQL Loss
        Total Loss = SAC loss + min_q_weight * CQL loss
        """
        # Use importance sampling to compute the log-sum-exp of Q(s, a), as described in Appendix F of the paper.
        random_sampled_actions = torch.FloatTensor(obs.shape[0] * self.n_action_samples, acts.shape[-1]).uniform_(-1, 1).to(self.device)
        curr_sampled_actions, curr_log_probs = self.get_policy_actions(obs, self.n_action_samples)
        # Unlike the paper, actions are sampled not only from the current state but also from the next state.
        next_sampled_actions, next_log_probs = self.get_policy_actions(next_obs, self.n_action_samples)
        q1_rand = self.get_actions_values(obs, random_sampled_actions, self.n_action_samples, self.q_net1)
        q2_rand = self.get_actions_values(obs, random_sampled_actions, self.n_action_samples, self.q_net2)
        q1_curr = self.get_actions_values(obs, curr_sampled_actions, self.n_action_samples, self.q_net1)
        q2_curr = self.get_actions_values(obs, curr_sampled_actions, self.n_action_samples, self.q_net2)
        q1_next = self.get_actions_values(obs, next_sampled_actions, self.n_action_samples, self.q_net1)
        q2_next = self.get_actions_values(obs, next_sampled_actions, self.n_action_samples, self.q_net2)

        random_density = np.log(0.5 ** acts.shape[-1])

        cat_q1 = torch.cat([q1_rand - random_density, q1_next - next_log_probs, q1_curr - curr_log_probs], dim=1)
        cat_q2 = torch.cat([q2_rand - random_density, q2_next - next_log_probs, q2_curr - curr_log_probs], dim=1)

        min_qf1_loss = torch.logsumexp(cat_q1, dim=1).mean()
        min_qf2_loss = torch.logsumexp(cat_q2, dim=1).mean()

        min_qf1_loss = self.min_q_weight * (min_qf1_loss - q1.mean())
        min_qf2_loss = self.min_q_weight * (min_qf2_loss - q2.mean())

        if self.with_lagrange:
            alpha_prime = torch.clamp(self.log_alpha_prime.exp(), min=0.0, max=1e6)
            # the lagrange_thresh has no effect on the gradient of policy,
            # but it has an effect on the gradient of alpha_prime
            min_qf1_loss = alpha_prime * (min_qf1_loss - self.lagrange_thresh)
            min_qf2_loss = alpha_prime * (min_qf2_loss - self.lagrange_thresh)

            alpha_prime_loss = -(min_qf1_loss + min_qf2_loss) * 0.5

            self.alpha_prime_optimizer.zero_grad()
            alpha_prime_loss.backward(retain_graph=True)  # the min_qf losses are backpropagated again later via q_loss1/q_loss2, so retain the graph.
            self.alpha_prime_optimizer.step()
        else:
            alpha_prime_loss = torch.tensor(0)

        q_loss1 = q_loss1 + min_qf1_loss
        q_loss2 = q_loss2 + min_qf2_loss

        """
        Update networks
        """
        # Update policy network parameters.
        # The policy network must be updated before the Q networks; otherwise the Q optimizers' in-place parameter updates would invalidate the graph needed by policy_loss.backward().
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Update q network1 parameter
        self.q_optimizer1.zero_grad()
        q_loss1.backward(retain_graph=True)
        self.q_optimizer1.step()

        # Update q network2 parameter
        self.q_optimizer2.zero_grad()
        q_loss2.backward(retain_graph=True)
        self.q_optimizer2.step()

        if self.auto_alpha_tuning:
            alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()

            self.alpha = self.log_alpha.exp()
        else:
            alpha_loss = torch.tensor(0)

        soft_target_update(self.q_net1, self.target_q_net1, tau=self.tau)
        soft_target_update(self.q_net2, self.target_q_net2, tau=self.tau)

        self.train_step += 1

        return q_loss1.cpu().item(), q_loss2.cpu().item(), policy_loss.cpu().item(), alpha_loss.cpu().item(), alpha_prime_loss.cpu().item()
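
The CQL terms above call two helpers that are not shown: get_policy_actions, which samples N actions per state, and get_actions_values, which evaluates a critic on every tiled (state, action) pair. The sketch below shows free-function equivalents of those instance methods; the argument order mirrors the calls above, while the shapes (log-probs and Q values reshaped to (B, N, 1)) are an assumption inferred from how the results are concatenated.

import torch

def get_policy_actions(policy_net, obs: torch.Tensor, n_action_samples: int):
    """Sample N actions per state; returns actions (B*N, D) and log-probs (B, N, 1)."""
    obs_rep = obs.repeat_interleave(n_action_samples, dim=0)
    acts, log_prob, _ = policy_net(obs_rep)
    return acts, log_prob.view(obs.shape[0], n_action_samples, 1)

def get_actions_values(obs: torch.Tensor, actions: torch.Tensor,
                       n_action_samples: int, q_net) -> torch.Tensor:
    """Q(s, a) for every sampled action, reshaped to (B, N, 1)."""
    obs_rep = obs.repeat_interleave(n_action_samples, dim=0)
    return q_net(obs_rep, actions).view(obs.shape[0], n_action_samples, 1)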
Example #7
    def train(self):

        # Sample
        batch = self.data_buffer.sample()
        obs = batch["obs"].to(self.device)
        acts = batch["acts"].to(self.device)
        rews = batch["rews"].to(self.device)
        next_obs = batch["next_obs"].to(self.device)
        done = batch["done"].to(self.device)
        """
        CVAE Loss (the generation model)
        """
        recon_action, mu, log_std = self.cvae_net(obs, acts)
        cvae_loss = self.cvae_net.loss_function(recon_action, acts, mu,
                                                log_std)

        self.cvae_optimizer.zero_grad()
        cvae_loss.backward()
        self.cvae_optimizer.step()
        """
        Critic Loss
        """
        with torch.no_grad():
            # generate 10 actions for every next_obs
            next_obs = torch.repeat_interleave(next_obs, repeats=10,
                                               dim=0).to(self.device)
            generated_action = self.cvae_net.decode(next_obs,
                                                    z_device=self.device)
            # perturb the generated action
            perturbed_action = self.target_perturbation_net(
                next_obs, generated_action)
            # compute target Q value of perturbed action
            target_q1 = self.target_critic_net1(next_obs, perturbed_action)
            target_q2 = self.target_critic_net2(next_obs, perturbed_action)
            # soft clipped double q-learning
            target_q = self.lmbda * torch.min(target_q1, target_q2) + (
                1. - self.lmbda) * torch.max(target_q1, target_q2)
            # take max over each action sampled from the generation and perturbation model
            target_q = target_q.reshape(obs.shape[0], 10,
                                        1).max(1)[0].squeeze(1)
            target_q = rews + self.gamma * (1. - done) * target_q

        # compute current Q
        current_q1 = self.critic_net1(obs, acts).squeeze(1)
        current_q2 = self.critic_net2(obs, acts).squeeze(1)
        # compute critic loss
        critic_loss1 = F.mse_loss(current_q1, target_q)
        critic_loss2 = F.mse_loss(current_q2, target_q)

        self.critic_optimizer1.zero_grad()
        critic_loss1.backward()
        self.critic_optimizer1.step()

        self.critic_optimizer2.zero_grad()
        critic_loss2.backward()
        self.critic_optimizer2.step()
        """
        Perturbation Loss
        """
        generated_action_ = self.cvae_net.decode(obs, z_device=self.device)
        perturbed_action_ = self.perturbation_net(obs, generated_action_)
        perturbation_loss = -self.critic_net1(obs, perturbed_action_).mean()

        self.perturbation_optimizer.zero_grad()
        perturbation_loss.backward()
        self.perturbation_optimizer.step()
        """
        Update target networks
        """
        soft_target_update(self.critic_net1,
                           self.target_critic_net1,
                           tau=self.tau)
        soft_target_update(self.critic_net2,
                           self.target_critic_net2,
                           tau=self.tau)
        soft_target_update(self.perturbation_net,
                           self.target_perturbation_net,
                           tau=self.tau)

        self.train_step += 1

        return (cvae_loss.cpu().item(),
                (critic_loss1 + critic_loss2).cpu().item(),
                perturbation_loss.cpu().item())
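
The perturbation network ξ(s, a) in BCQ adds a small bounded correction to each CVAE-generated action before it is scored by the critics. A sketch of such a network follows; the hidden sizes, the perturbation scale phi and the action bound are assumptions.

import torch
import torch.nn as nn

class PerturbationNet(nn.Module):
    """BCQ-style perturbation model: returns a + phi * max_action * xi(s, a), clipped."""

    def __init__(self, obs_dim: int, act_dim: int,
                 max_action: float = 1.0, phi: float = 0.05):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, 400), nn.ReLU(),
            nn.Linear(400, 300), nn.ReLU(),
            nn.Linear(300, act_dim), nn.Tanh())
        self.max_action, self.phi = max_action, phi

    def forward(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        # bounded residual in [-phi * max_action, phi * max_action]
        delta = self.phi * self.max_action * self.net(torch.cat([obs, action], dim=-1))
        return (action + delta).clamp(-self.max_action, self.max_action)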
Example #8
    def train(self):

        # Sample
        batch = self.data_buffer.sample()
        obs = batch["obs"].to(self.device)
        acts = batch["acts"].to(self.device)
        rews = batch["rews"].to(self.device)
        next_obs = batch["next_obs"].to(self.device)
        done = batch["done"].to(self.device)

        # compute policy Loss
        a, log_prob, _ = self.policy_net(obs)
        min_q = torch.min(self.q_net1(obs, a), self.q_net2(obs, a)).squeeze(1)
        policy_loss = (self.alpha * log_prob - min_q).mean()

        # compute Q Loss
        q1 = self.q_net1(obs, acts).squeeze(1)
        q2 = self.q_net2(obs, acts).squeeze(1)
        with torch.no_grad():
            next_a, next_log_prob, _ = self.policy_net(next_obs)
            min_target_next_q = torch.min(
                self.target_q_net1(next_obs, next_a),
                self.target_q_net2(next_obs, next_a)).squeeze(1)
            y = rews + self.gamma * (1. - done) * (min_target_next_q -
                                                   self.alpha * next_log_prob)

        q_loss1 = F.mse_loss(q1, y)
        q_loss2 = F.mse_loss(q2, y)

        # Update policy network parameters.
        # The policy network must be updated before the Q networks; otherwise the Q optimizers' in-place parameter updates would invalidate the graph needed by policy_loss.backward().
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Update q network1 parameter
        self.q_optimizer1.zero_grad()
        q_loss1.backward()
        self.q_optimizer1.step()

        # Update q network2 parameter
        self.q_optimizer2.zero_grad()
        q_loss2.backward()
        self.q_optimizer2.step()

        if self.auto_alpha_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_prob + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()

            self.alpha = self.log_alpha.exp()
        else:
            alpha_loss = torch.tensor(0)

        self.train_step += 1

        soft_target_update(self.q_net1, self.target_q_net1, tau=self.tau)
        soft_target_update(self.q_net2, self.target_q_net2, tau=self.tau)

        return (q_loss1.cpu().item(), q_loss2.cpu().item(),
                policy_loss.cpu().item(), alpha_loss.cpu().item())
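
The SAC-style policies used above return a sampled action, its log-probability and a third value (ignored here, typically the deterministic action). A sketch of a tanh-squashed Gaussian policy with that interface follows; the layer sizes, the log-std clamp range and the returned mean action are assumptions.

import torch
import torch.nn as nn

class SquashedGaussianPolicy(nn.Module):
    """Returns (sampled action, per-sample log-prob, deterministic action)."""

    def __init__(self, obs_dim: int, act_dim: int, hidden: int = 256):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU(),
                                  nn.Linear(hidden, hidden), nn.ReLU())
        self.mu_head = nn.Linear(hidden, act_dim)
        self.log_std_head = nn.Linear(hidden, act_dim)

    def forward(self, obs: torch.Tensor):
        h = self.body(obs)
        mu = self.mu_head(h)
        log_std = self.log_std_head(h).clamp(-20.0, 2.0)
        dist = torch.distributions.Normal(mu, log_std.exp())
        u = dist.rsample()                      # reparameterised sample
        a = torch.tanh(u)
        # log-prob with the tanh change-of-variables correction, shape (B,)
        log_prob = dist.log_prob(u).sum(-1)
        log_prob = log_prob - torch.log(1.0 - a.pow(2) + 1e-6).sum(-1)
        return a, log_prob, torch.tanh(mu)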