def train(self):
    # Sample
    batch = self.replay_buffer.sample()
    obs = batch["obs"].to(self.device)
    acts = batch["acts"].to(self.device)
    rews = batch["rews"].to(self.device)
    next_obs = batch["next_obs"].to(self.device)
    done = batch["done"].to(self.device)

    # Compute target Q value
    with torch.no_grad():
        next_act = self.target_actor_net(next_obs)
        next_Q = self.target_critic_net(next_obs, next_act).squeeze(1)
        target_Q = rews + (1. - done) * self.gamma * next_Q

    # Compute current Q
    current_Q = self.critic_net(obs, acts).squeeze(1)

    # Compute critic loss
    critic_loss = F.mse_loss(current_Q, target_Q)

    # Compute actor loss
    actor_loss = -self.critic_net(obs, self.actor_net(obs)).mean()

    # Optimize actor net
    self.actor_optimizer.zero_grad()
    actor_loss.backward()
    self.actor_optimizer.step()

    # Optimize critic net
    self.critic_optimizer.zero_grad()
    critic_loss.backward()
    self.critic_optimizer.step()

    soft_target_update(self.actor_net, self.target_actor_net, tau=self.tau)
    soft_target_update(self.critic_net, self.target_critic_net, tau=self.tau)

    self.train_step += 1

    return actor_loss.cpu().item(), critic_loss.cpu().item()
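# NOTE: soft_target_update is called throughout this file but not defined in this
# excerpt. Below is a minimal sketch, assuming the standard Polyak-averaging form
# implied by the call sites (target <- tau * source + (1 - tau) * target); the
# actual helper in the repository may differ.
import torch

def soft_target_update(source_net, target_net, tau=0.005):
    """Polyak-average the source network's parameters into the target network, in place."""
    with torch.no_grad():
        for src_param, tgt_param in zip(source_net.parameters(), target_net.parameters()):
            tgt_param.data.mul_(1.0 - tau)
            tgt_param.data.add_(tau * src_param.data)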
def train(self):
    # Sample
    batch = self.data_buffer.sample()
    obs = batch["obs"].to(self.device)
    acts = batch["acts"].to(self.device)
    rews = batch["rews"].to(self.device)
    next_obs = batch["next_obs"].to(self.device)
    done = batch["done"].to(self.device)

    """
    Train Critic
    """
    with torch.no_grad():
        decode_action_next = self.target_actor_net(next_obs, self.cvae_net.decode)

        target_q1 = self.target_critic_net1(next_obs, decode_action_next)
        target_q2 = self.target_critic_net2(next_obs, decode_action_next)

        target_q = (self.lmbda * torch.min(target_q1, target_q2)
                    + (1. - self.lmbda) * torch.max(target_q1, target_q2)).squeeze(1)
        target_q = rews + self.gamma * (1. - done) * target_q

    current_q1 = self.critic_net1(obs, acts).squeeze(1)
    current_q2 = self.critic_net2(obs, acts).squeeze(1)

    critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

    self.critic_optimizer1.zero_grad()
    self.critic_optimizer2.zero_grad()
    critic_loss.backward()
    self.critic_optimizer1.step()
    self.critic_optimizer2.step()

    """
    Train Actor
    """
    decode_action = self.actor_net(obs, self.cvae_net.decode)
    actor_loss = -self.critic_net1(obs, decode_action).mean()

    self.actor_optimizer.zero_grad()
    actor_loss.backward()
    self.actor_optimizer.step()

    """
    Update target networks
    """
    soft_target_update(self.critic_net1, self.target_critic_net1, tau=self.tau)
    soft_target_update(self.critic_net2, self.target_critic_net2, tau=self.tau)
    soft_target_update(self.actor_net, self.target_actor_net, tau=self.tau)

    self.train_step += 1

    return critic_loss.cpu().item(), actor_loss.cpu().item()
def train(self):
    # Sample
    batch = self.replay_buffer.sample()
    obs = batch["obs"].to(self.device)
    acts = batch["acts"].to(self.device)
    rews = batch["rews"].to(self.device)
    next_obs = batch["next_obs"].to(self.device)
    done = batch["done"].to(self.device)

    # Target Policy Smoothing. Add clipped noise to next actions when computing target Q.
    with torch.no_grad():
        noise = torch.normal(mean=0, std=self.policy_noise, size=acts.size()).to(self.device)
        noise = noise.clamp(-self.noise_clip, self.noise_clip)
        next_act = self.target_actor_net(next_obs) + noise
        next_act = next_act.clamp(-self.action_bound, self.action_bound)

        # Clipped Double Q-Learning. Compute the min of target Q1 and target Q2
        min_target_q = torch.min(self.target_critic_net1(next_obs, next_act),
                                 self.target_critic_net2(next_obs, next_act)).squeeze(1)
        y = rews + self.gamma * (1. - done) * min_target_q

    current_q1 = self.critic_net1(obs, acts).squeeze(1)
    current_q2 = self.critic_net2(obs, acts).squeeze(1)

    # TD3 Loss
    critic_loss1 = F.mse_loss(current_q1, y)
    critic_loss2 = F.mse_loss(current_q2, y)

    # Optimize critic net
    self.critic_optimizer1.zero_grad()
    critic_loss1.backward()
    self.critic_optimizer1.step()

    self.critic_optimizer2.zero_grad()
    critic_loss2.backward()
    self.critic_optimizer2.step()

    if (self.train_step + 1) % self.policy_delay == 0:
        # Compute actor loss
        actor_loss = -self.critic_net1(obs, self.actor_net(obs)).mean()

        # Optimize actor net
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        soft_target_update(self.actor_net, self.target_actor_net, tau=self.tau)
        soft_target_update(self.critic_net1, self.target_critic_net1, tau=self.tau)
        soft_target_update(self.critic_net2, self.target_critic_net2, tau=self.tau)
    else:
        actor_loss = torch.tensor(0)

    self.train_step += 1

    return actor_loss.cpu().item(), critic_loss1.cpu().item(), critic_loss2.cpu().item()
def train(self):
    # Sample
    batch = self.data_buffer.sample()
    obs = batch["obs"].to(self.device)
    acts = batch["acts"].to(self.device)
    rews = batch["rews"].to(self.device)
    next_obs = batch["next_obs"].to(self.device)
    done = batch["done"].to(self.device)

    """
    Train the behaviour-cloning policy so that more than one action sample can be drawn for the MMD term.
    A conditional VAE is used as the behaviour-cloning policy in BEAR.
    """
    recon_action, mu, log_std = self.cvae_net(obs, acts)
    cvae_loss = self.cvae_net.loss_function(recon_action, acts, mu, log_std)

    self.cvae_optimizer.zero_grad()
    cvae_loss.backward()
    self.cvae_optimizer.step()

    """
    Critic Training
    """
    with torch.no_grad():
        # generate n_target_samples actions for every next_obs (same as BCQ)
        next_obs = torch.repeat_interleave(next_obs, repeats=self.n_target_samples, dim=0).to(self.device)

        # compute target Q value of the generated actions
        target_q1 = self.target_q_net1(next_obs, self.policy_net(next_obs)[0])
        target_q2 = self.target_q_net2(next_obs, self.policy_net(next_obs)[0])

        # soft clipped double q-learning
        target_q = self.lmbda * torch.min(target_q1, target_q2) + (1. - self.lmbda) * torch.max(target_q1, target_q2)
        # take the max over the actions sampled for each next_obs
        target_q = target_q.reshape(obs.shape[0], self.n_target_samples, 1).max(1)[0].squeeze(1)
        target_q = rews + self.gamma * (1. - done) * target_q

    # compute current Q
    current_q1 = self.q_net1(obs, acts).squeeze(1)
    current_q2 = self.q_net2(obs, acts).squeeze(1)

    # compute critic loss
    critic_loss1 = F.mse_loss(current_q1, target_q)
    critic_loss2 = F.mse_loss(current_q2, target_q)

    self.q_optimizer1.zero_grad()
    critic_loss1.backward()
    self.q_optimizer1.step()

    self.q_optimizer2.zero_grad()
    critic_loss2.backward()
    self.q_optimizer2.step()

    # MMD Loss
    # sample actions from the dataset policy (CVAE) and the current policy (B x N x D)
    raw_sampled_actions = self.cvae_net.decode_multiple_without_squash(obs, decode_num=self.n_mmd_action_samples, z_device=self.device)
    raw_actor_actions = self.policy_net.sample_multiple_without_squash(obs, sample_num=self.n_mmd_action_samples)
    if self.kernel_type == 'gaussian':
        mmd_loss = self.mmd_loss_gaussian(raw_sampled_actions, raw_actor_actions, sigma=self.mmd_sigma)
    else:
        mmd_loss = self.mmd_loss_laplacian(raw_sampled_actions, raw_actor_actions, sigma=self.mmd_sigma)

    """
    Alpha prime training (Lagrange-multiplier update for the MMD loss weight)
    """
    alpha_prime_loss = -(self.log_alpha_prime.exp() * (mmd_loss - self.lagrange_thresh)).mean()
    self.alpha_prime_optimizer.zero_grad()
    alpha_prime_loss.backward(retain_graph=True)
    self.alpha_prime_optimizer.step()

    self.log_alpha_prime.data.clamp_(min=-5.0, max=10.0)  # clip for stability

    """
    Actor Training
    Actor Loss = alpha_prime * MMD Loss - min Q(s, a)
    """
    a, log_prob, _ = self.policy_net(obs)
    min_q = torch.min(self.q_net1(obs, a), self.q_net2(obs, a)).squeeze(1)
    # policy_loss = (self.alpha * log_prob - min_q).mean()  # SAC-style policy loss
    policy_loss = -(min_q.mean())  # BEAR actor loss
    actor_loss = (self.log_alpha_prime.exp() * mmd_loss).mean()
    if self.train_step > self.warmup_step:
        actor_loss = policy_loss + actor_loss

    self.policy_optimizer.zero_grad()
    # mmd_loss is part of this graph and was already backpropagated for alpha_prime_loss,
    # which is why retain_graph=True was used there.
    actor_loss.backward()
    self.policy_optimizer.step()

    soft_target_update(self.q_net1, self.target_q_net1, tau=self.tau)
    soft_target_update(self.q_net2, self.target_q_net2, tau=self.tau)

    self.train_step += 1

    return critic_loss1.cpu().item(), critic_loss2.cpu().item(), policy_loss.cpu().item(), alpha_prime_loss.cpu().item()
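# NOTE: mmd_loss_gaussian / mmd_loss_laplacian are called above but not defined in this
# excerpt. Below is a minimal sketch of a Gaussian-kernel MMD estimator of the kind used
# in BEAR, operating on (B, N, D) batches of pre-tanh action samples. It is written as a
# free function; the repository's actual method (a method on the agent), the 1e-6
# stabilizer, and the per-sample square root are assumptions and may differ.
def mmd_loss_gaussian(raw_sampled_actions, raw_actor_actions, sigma=20.0):
    """Per-batch-element MMD between two sets of N action samples, Gaussian kernel."""
    # pairwise differences within and across the two sample sets: (B, N, N, D)
    diff_x_x = raw_sampled_actions.unsqueeze(2) - raw_sampled_actions.unsqueeze(1)
    diff_x_y = raw_sampled_actions.unsqueeze(2) - raw_actor_actions.unsqueeze(1)
    diff_y_y = raw_actor_actions.unsqueeze(2) - raw_actor_actions.unsqueeze(1)
    # Gaussian kernel, averaged over the N x N pairs -> (B,)
    k_x_x = torch.exp(-(diff_x_x ** 2).sum(-1) / (2.0 * sigma)).mean(dim=(1, 2))
    k_x_y = torch.exp(-(diff_x_y ** 2).sum(-1) / (2.0 * sigma)).mean(dim=(1, 2))
    k_y_y = torch.exp(-(diff_y_y ** 2).sum(-1) / (2.0 * sigma)).mean(dim=(1, 2))
    # biased MMD^2 estimate, square-rooted with a small constant for numerical stability
    return torch.sqrt(k_x_x + k_y_y - 2.0 * k_x_y + 1e-6)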
def train(self):
    # Sample
    batch = self.replay_buffer.sample()
    obs = batch["obs"]
    acts = batch["acts"]
    rews = batch["rews"]
    next_obs = batch["next_obs"]
    done = batch["done"]

    # compute policy loss
    a, log_prob = self.policy_net(obs)
    min_q = torch.min(self.q_net1(obs, a), self.q_net2(obs, a)).squeeze(1)
    policy_loss = (self.alpha * log_prob - min_q).mean()

    # compute Q loss
    q1 = self.q_net1(obs, acts).squeeze(1)
    q2 = self.q_net2(obs, acts).squeeze(1)
    with torch.no_grad():
        next_a, next_log_prob = self.policy_net(next_obs)
        min_target_next_q = torch.min(self.target_q_net1(next_obs, next_a),
                                      self.target_q_net2(next_obs, next_a)).squeeze(1)
        y = rews + self.gamma * (1. - done) * (min_target_next_q - self.alpha * next_log_prob)

    q_loss1 = F.mse_loss(q1, y)
    q_loss2 = F.mse_loss(q2, y)

    # Update the policy network parameters
    # (the policy network should be updated before the Q networks, otherwise the gradients are wrong)
    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    # Update q network1 parameters
    self.q_optimizer1.zero_grad()
    q_loss1.backward()
    self.q_optimizer1.step()

    # Update q network2 parameters
    self.q_optimizer2.zero_grad()
    q_loss2.backward()
    self.q_optimizer2.step()

    if self.automatic_entropy_tuning:
        alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        self.alpha = self.log_alpha.exp()
    else:
        alpha_loss = torch.tensor(0)

    self.train_step += 1

    soft_target_update(self.q_net1, self.target_q_net1, tau=self.tau)
    soft_target_update(self.q_net2, self.target_q_net2, tau=self.tau)

    return q_loss1.item(), q_loss2.item(), policy_loss.item(), alpha_loss.item()
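# NOTE: self.log_alpha, self.target_entropy and self.alpha_optimizer are used in the
# automatic entropy tuning branch above but are not defined in this excerpt. Below is a
# minimal sketch of a common setup; the -action_dim target-entropy heuristic, the zero
# initial value and the learning rate are assumptions, not taken from this repository.
def init_entropy_tuning(action_dim, device, lr=3e-4):
    """Create the learnable temperature used for automatic entropy tuning in SAC."""
    target_entropy = -float(action_dim)  # common heuristic: -|A|
    log_alpha = torch.zeros(1, requires_grad=True, device=device)
    alpha_optimizer = torch.optim.Adam([log_alpha], lr=lr)
    return target_entropy, log_alpha, alpha_optimizer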
def train(self):
    # Sample
    batch = self.data_buffer.sample()
    obs = batch["obs"].to(self.device)
    acts = batch["acts"].to(self.device)
    rews = batch["rews"].to(self.device)
    next_obs = batch["next_obs"].to(self.device)
    done = batch["done"].to(self.device)

    """
    SAC Loss
    """
    # compute policy loss
    a, log_prob, _ = self.policy_net(obs)
    min_q = torch.min(self.q_net1(obs, a), self.q_net2(obs, a)).squeeze(1)
    policy_loss = (self.alpha * log_prob - min_q).mean()

    # compute Q loss
    q1 = self.q_net1(obs, acts).squeeze(1)
    q2 = self.q_net2(obs, acts).squeeze(1)
    with torch.no_grad():
        if not self.max_q_backup:
            next_a, next_log_prob, _ = self.policy_net(next_obs)
            min_target_next_q = torch.min(self.target_q_net1(next_obs, next_a),
                                          self.target_q_net2(next_obs, next_a)).squeeze(1)
            if self.entropy_backup:
                # y = rews + self.gamma * (1. - done) * (min_target_next_q - self.alpha * next_log_prob)
                min_target_next_q = min_target_next_q - self.alpha * next_log_prob
        else:
            """when using max q backup"""
            next_a_temp, _ = self.get_policy_actions(next_obs, n_action_samples=10)
            target_qf1_values = self.get_actions_values(next_obs, next_a_temp, self.n_action_samples, self.q_net1).max(1)[0]
            target_qf2_values = self.get_actions_values(next_obs, next_a_temp, self.n_action_samples, self.q_net2).max(1)[0]
            min_target_next_q = torch.min(target_qf1_values, target_qf2_values).squeeze(1)
        y = rews + self.gamma * (1. - done) * min_target_next_q

    q_loss1 = F.mse_loss(q1, y)
    q_loss2 = F.mse_loss(q2, y)

    """
    CQL Loss
    Total Loss = SAC loss + min_q_weight * CQL loss
    """
    # Use importance sampling to compute the log-sum-exp of Q(s, a), as shown in Appendix F of the paper.
    random_sampled_actions = torch.FloatTensor(obs.shape[0] * self.n_action_samples, acts.shape[-1]).uniform_(-1, 1).to(self.device)
    curr_sampled_actions, curr_log_probs = self.get_policy_actions(obs, self.n_action_samples)
    # This differs from the paper: actions are sampled not only at the current state, but also at the next state.
    next_sampled_actions, next_log_probs = self.get_policy_actions(next_obs, self.n_action_samples)
    q1_rand = self.get_actions_values(obs, random_sampled_actions, self.n_action_samples, self.q_net1)
    q2_rand = self.get_actions_values(obs, random_sampled_actions, self.n_action_samples, self.q_net2)
    q1_curr = self.get_actions_values(obs, curr_sampled_actions, self.n_action_samples, self.q_net1)
    q2_curr = self.get_actions_values(obs, curr_sampled_actions, self.n_action_samples, self.q_net2)
    q1_next = self.get_actions_values(obs, next_sampled_actions, self.n_action_samples, self.q_net1)
    q2_next = self.get_actions_values(obs, next_sampled_actions, self.n_action_samples, self.q_net2)

    random_density = np.log(0.5 ** acts.shape[-1])

    cat_q1 = torch.cat([q1_rand - random_density, q1_next - next_log_probs, q1_curr - curr_log_probs], dim=1)
    cat_q2 = torch.cat([q2_rand - random_density, q2_next - next_log_probs, q2_curr - curr_log_probs], dim=1)

    min_qf1_loss = torch.logsumexp(cat_q1, dim=1).mean()
    min_qf2_loss = torch.logsumexp(cat_q2, dim=1).mean()

    min_qf1_loss = self.min_q_weight * (min_qf1_loss - q1.mean())
    min_qf2_loss = self.min_q_weight * (min_qf2_loss - q2.mean())

    if self.with_lagrange:
        alpha_prime = torch.clamp(self.log_alpha_prime.exp(), min=0.0, max=1e6)
        # the lagrange_thresh has no effect on the policy gradient,
        # but it does affect the gradient of alpha_prime
        min_qf1_loss = alpha_prime * (min_qf1_loss - self.lagrange_thresh)
        min_qf2_loss = alpha_prime * (min_qf2_loss - self.lagrange_thresh)

        alpha_prime_loss = -(min_qf1_loss + min_qf2_loss) * 0.5

        self.alpha_prime_optimizer.zero_grad()
        # min_qf1_loss and min_qf2_loss are backpropagated again later, so retain the graph
        alpha_prime_loss.backward(retain_graph=True)
        self.alpha_prime_optimizer.step()
    else:
        alpha_prime_loss = torch.tensor(0)

    q_loss1 = q_loss1 + min_qf1_loss
    q_loss2 = q_loss2 + min_qf2_loss

    """
    Update networks
    """
    # Update the policy network parameters
    # (the policy network must be updated before the Q networks, otherwise the backward pass raises errors)
    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    # Update q network1 parameters
    self.q_optimizer1.zero_grad()
    q_loss1.backward(retain_graph=True)
    self.q_optimizer1.step()

    # Update q network2 parameters
    self.q_optimizer2.zero_grad()
    q_loss2.backward(retain_graph=True)
    self.q_optimizer2.step()

    if self.auto_alpha_tuning:
        alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        self.alpha = self.log_alpha.exp()
    else:
        alpha_loss = torch.tensor(0)

    soft_target_update(self.q_net1, self.target_q_net1, tau=self.tau)
    soft_target_update(self.q_net2, self.target_q_net2, tau=self.tau)

    self.train_step += 1

    return q_loss1.cpu().item(), q_loss2.cpu().item(), policy_loss.cpu().item(), alpha_loss.cpu().item(), alpha_prime_loss.cpu().item()
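# NOTE: get_policy_actions and get_actions_values are used in the CQL update above but
# are not defined in this excerpt. The sketches below show one plausible implementation
# consistent with the call sites (repeat each state N times, query the policy or a Q
# network, reshape back to per-state samples). The shape convention (keeping a trailing
# singleton dimension) and the detaching of policy samples are assumptions; the actual
# helpers in the repository may differ.
def get_policy_actions(self, obs, n_action_samples):
    """Sample n_action_samples actions (and their log-probs) from the current policy for every state."""
    obs_rep = torch.repeat_interleave(obs, repeats=n_action_samples, dim=0)
    with torch.no_grad():
        actions, log_probs, _ = self.policy_net(obs_rep)
    return actions, log_probs.view(obs.shape[0], n_action_samples, 1)

def get_actions_values(self, obs, actions, n_action_samples, q_net):
    """Evaluate Q(s, a) for every sampled action, returning a (B, N, 1) tensor."""
    obs_rep = torch.repeat_interleave(obs, repeats=n_action_samples, dim=0)
    q_values = q_net(obs_rep, actions)
    return q_values.view(obs.shape[0], n_action_samples, 1)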
def train(self):
    # Sample
    batch = self.data_buffer.sample()
    obs = batch["obs"].to(self.device)
    acts = batch["acts"].to(self.device)
    rews = batch["rews"].to(self.device)
    next_obs = batch["next_obs"].to(self.device)
    done = batch["done"].to(self.device)

    """
    CVAE Loss (the generation model)
    """
    recon_action, mu, log_std = self.cvae_net(obs, acts)
    cvae_loss = self.cvae_net.loss_function(recon_action, acts, mu, log_std)

    self.cvae_optimizer.zero_grad()
    cvae_loss.backward()
    self.cvae_optimizer.step()

    """
    Critic Loss
    """
    with torch.no_grad():
        # generate 10 actions for every next_obs
        next_obs = torch.repeat_interleave(next_obs, repeats=10, dim=0).to(self.device)
        generated_action = self.cvae_net.decode(next_obs, z_device=self.device)
        # perturb the generated action
        perturbed_action = self.target_perturbation_net(next_obs, generated_action)
        # compute target Q value of perturbed action
        target_q1 = self.target_critic_net1(next_obs, perturbed_action)
        target_q2 = self.target_critic_net2(next_obs, perturbed_action)
        # soft clipped double q-learning
        target_q = self.lmbda * torch.min(target_q1, target_q2) + (1. - self.lmbda) * torch.max(target_q1, target_q2)
        # take max over each action sampled from the generation and perturbation model
        target_q = target_q.reshape(obs.shape[0], 10, 1).max(1)[0].squeeze(1)
        target_q = rews + self.gamma * (1. - done) * target_q

    # compute current Q
    current_q1 = self.critic_net1(obs, acts).squeeze(1)
    current_q2 = self.critic_net2(obs, acts).squeeze(1)

    # compute critic loss
    critic_loss1 = F.mse_loss(current_q1, target_q)
    critic_loss2 = F.mse_loss(current_q2, target_q)

    self.critic_optimizer1.zero_grad()
    critic_loss1.backward()
    self.critic_optimizer1.step()

    self.critic_optimizer2.zero_grad()
    critic_loss2.backward()
    self.critic_optimizer2.step()

    """
    Perturbation Loss
    """
    generated_action_ = self.cvae_net.decode(obs, z_device=self.device)
    perturbed_action_ = self.perturbation_net(obs, generated_action_)
    perturbation_loss = -self.critic_net1(obs, perturbed_action_).mean()

    self.perturbation_optimizer.zero_grad()
    perturbation_loss.backward()
    self.perturbation_optimizer.step()

    """
    Update target networks
    """
    soft_target_update(self.critic_net1, self.target_critic_net1, tau=self.tau)
    soft_target_update(self.critic_net2, self.target_critic_net2, tau=self.tau)
    soft_target_update(self.perturbation_net, self.target_perturbation_net, tau=self.tau)

    self.train_step += 1

    return cvae_loss.cpu().item(), (critic_loss1 + critic_loss2).cpu().item(), perturbation_loss.cpu().item()
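# NOTE: cvae_net.decode is called above with a z_device keyword but is not defined in
# this excerpt. The sketch below follows the BCQ convention (sample z ~ N(0, I), clip
# it to [-0.5, 0.5], decode, squash with tanh); self.latent_dim, self.decoder and
# self.max_action are assumed attribute names, not taken from this repository.
def decode(self, obs, z=None, z_device=None):
    """Decode a latent sample into an action conditioned on obs."""
    if z is None:
        # clipping the prior sample keeps decoded actions close to the dataset support
        z = torch.randn(obs.shape[0], self.latent_dim, device=z_device).clamp(-0.5, 0.5)
    return self.max_action * torch.tanh(self.decoder(torch.cat([obs, z], dim=1)))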
def train(self):
    # Sample
    batch = self.data_buffer.sample()
    obs = batch["obs"].to(self.device)
    acts = batch["acts"].to(self.device)
    rews = batch["rews"].to(self.device)
    next_obs = batch["next_obs"].to(self.device)
    done = batch["done"].to(self.device)

    # compute policy loss
    a, log_prob, _ = self.policy_net(obs)
    min_q = torch.min(self.q_net1(obs, a), self.q_net2(obs, a)).squeeze(1)
    policy_loss = (self.alpha * log_prob - min_q).mean()

    # compute Q loss
    q1 = self.q_net1(obs, acts).squeeze(1)
    q2 = self.q_net2(obs, acts).squeeze(1)
    with torch.no_grad():
        next_a, next_log_prob, _ = self.policy_net(next_obs)
        min_target_next_q = torch.min(self.target_q_net1(next_obs, next_a),
                                      self.target_q_net2(next_obs, next_a)).squeeze(1)
        y = rews + self.gamma * (1. - done) * (min_target_next_q - self.alpha * next_log_prob)

    q_loss1 = F.mse_loss(q1, y)
    q_loss2 = F.mse_loss(q2, y)

    # Update the policy network parameters
    # (the policy network must be updated before the Q networks, otherwise the backward pass raises errors)
    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    # Update q network1 parameters
    self.q_optimizer1.zero_grad()
    q_loss1.backward()
    self.q_optimizer1.step()

    # Update q network2 parameters
    self.q_optimizer2.zero_grad()
    q_loss2.backward()
    self.q_optimizer2.step()

    if self.auto_alpha_tuning:
        alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        self.alpha = self.log_alpha.exp()
    else:
        alpha_loss = torch.tensor(0)

    self.train_step += 1

    soft_target_update(self.q_net1, self.target_q_net1, tau=self.tau)
    soft_target_update(self.q_net2, self.target_q_net2, tau=self.tau)

    return q_loss1.cpu().item(), q_loss2.cpu().item(), policy_loss.cpu().item(), alpha_loss.cpu().item()