def update_model(self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Train the model after each episode."""
    experiences = self.memory.sample()
    states, actions, rewards, next_states, dones = experiences

    # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
    #       = r                       otherwise
    masks = 1 - dones
    next_actions = self.actor_target(next_states)
    next_values = self.critic_target(torch.cat((next_states, next_actions), dim=-1))
    curr_returns = rewards + self.hyper_params["GAMMA"] * next_values * masks
    curr_returns = curr_returns.to(device)

    # train critic
    values = self.critic(torch.cat((states, actions), dim=-1))
    critic_loss = F.mse_loss(values, curr_returns)
    self.critic_optimizer.zero_grad()
    critic_loss.backward()
    self.critic_optimizer.step()

    # train actor
    actions = self.actor(states)
    actor_loss = -self.critic(torch.cat((states, actions), dim=-1)).mean()
    self.actor_optimizer.zero_grad()
    actor_loss.backward()
    self.actor_optimizer.step()

    # update target networks
    tau = self.hyper_params["TAU"]
    common_utils.soft_update(self.actor, self.actor_target, tau)
    common_utils.soft_update(self.critic, self.critic_target, tau)

    return actor_loss.data, critic_loss.data
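# The target-network updates above call common_utils.soft_update. A minimal
# sketch of such a Polyak-averaging helper is shown below; the exact signature
# in common_utils may differ, so treat this as an illustration only.
import torch.nn as nn


def soft_update(local: nn.Module, target: nn.Module, tau: float) -> None:
    """Blend target parameters toward local ones: theta' <- tau*theta + (1-tau)*theta'."""
    for t_param, l_param in zip(target.parameters(), local.parameters()):
        t_param.data.copy_(tau * l_param.data + (1.0 - tau) * t_param.data)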
def update_model(self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Train the model after each episode."""
    experiences_1 = self.memory.sample(self.beta)
    states, actions = experiences_1[:2]
    weights, indices, eps_d = experiences_1[-3:]
    gamma = self.hyper_params["GAMMA"]

    # train critic
    gradient_clip_cr = self.hyper_params["GRADIENT_CLIP_CR"]
    critic_loss_element_wise = self._get_critic_loss(experiences_1, gamma)
    critic_loss = torch.mean(critic_loss_element_wise * weights)

    if self.use_n_step:
        experiences_n = self.memory_n.sample(indices)
        gamma = gamma ** self.hyper_params["N_STEP"]
        critic_loss_n_element_wise = self._get_critic_loss(experiences_n, gamma)

        # to update loss and priorities
        lambda1 = self.hyper_params["LAMBDA1"]
        critic_loss_element_wise += critic_loss_n_element_wise * lambda1
        critic_loss = torch.mean(critic_loss_element_wise * weights)

    self.critic_optimizer.zero_grad()
    critic_loss.backward()
    nn.utils.clip_grad_norm_(self.critic.parameters(), gradient_clip_cr)
    self.critic_optimizer.step()

    # train actor
    gradient_clip_ac = self.hyper_params["GRADIENT_CLIP_AC"]
    actions = self.actor(states)
    actor_loss_element_wise = -self.critic(torch.cat((states, actions), dim=-1))
    actor_loss = torch.mean(actor_loss_element_wise * weights)
    self.actor_optimizer.zero_grad()
    actor_loss.backward()
    nn.utils.clip_grad_norm_(self.actor.parameters(), gradient_clip_ac)
    self.actor_optimizer.step()

    # update target networks
    tau = self.hyper_params["TAU"]
    common_utils.soft_update(self.actor, self.actor_target, tau)
    common_utils.soft_update(self.critic, self.critic_target, tau)

    # update priorities
    new_priorities = critic_loss_element_wise
    new_priorities += self.hyper_params["LAMBDA3"] * actor_loss_element_wise.pow(2)
    new_priorities += self.hyper_params["PER_EPS"]
    new_priorities = new_priorities.data.cpu().numpy().squeeze()
    new_priorities += eps_d
    self.memory.update_priorities(indices, new_priorities)

    # increase beta
    fraction = min(float(self.i_episode) / self.args.episode_num, 1.0)
    self.beta = self.beta + fraction * (1.0 - self.beta)

    return actor_loss.item(), critic_loss.item()
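# _get_critic_loss is used above but not defined here. Judging from its use, it
# returns an element-wise TD error that can feed both the weighted loss and the
# PER priorities. The method below is a hypothetical single-critic sketch,
# assuming the (states, actions, rewards, next_states, dones, ...) layout used
# elsewhere in this file; the real helper may differ.
def _get_critic_loss(self, experiences, gamma):
    """Return the element-wise squared TD error for the critic."""
    states, actions, rewards, next_states, dones = experiences[:5]
    masks = 1 - dones

    next_actions = self.actor_target(next_states)
    next_values = self.critic_target(torch.cat((next_states, next_actions), dim=-1))
    curr_returns = (rewards + gamma * next_values * masks).detach()

    values = self.critic(torch.cat((states, actions), dim=-1))
    return (values - curr_returns).pow(2)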
def update_model(self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Train the model after each episode."""
    # 1 step loss
    experiences_1 = self.memory.sample(self.beta)
    weights, indices = experiences_1[-2:]
    gamma = self.hyper_params["GAMMA"]
    dq_loss_element_wise, q_values = self._get_dqn_loss(experiences_1, gamma)
    dq_loss = torch.mean(dq_loss_element_wise * weights)

    # n step loss
    if self.use_n_step:
        experiences_n = self.memory_n.sample(indices)
        gamma = self.hyper_params["GAMMA"] ** self.hyper_params["N_STEP"]
        dq_loss_n_element_wise, q_values_n = self._get_dqn_loss(experiences_n, gamma)

        # to update loss and priorities
        q_values = 0.5 * (q_values + q_values_n)
        dq_loss_element_wise += dq_loss_n_element_wise * self.hyper_params["W_N_STEP"]
        dq_loss = torch.mean(dq_loss_element_wise * weights)

    # q_value regularization
    q_regular = torch.norm(q_values, 2).mean() * self.hyper_params["W_Q_REG"]

    # total loss
    loss = dq_loss + q_regular

    self.dqn_optimizer.zero_grad()
    loss.backward()
    clip_grad_norm_(self.dqn.parameters(), self.hyper_params["GRADIENT_CLIP"])
    self.dqn_optimizer.step()

    # update target networks
    tau = self.hyper_params["TAU"]
    common_utils.soft_update(self.dqn, self.dqn_target, tau)

    # update priorities in PER
    loss_for_prior = dq_loss_element_wise.detach().cpu().numpy()
    new_priorities = loss_for_prior + self.hyper_params["PER_EPS"]
    self.memory.update_priorities(indices, new_priorities)

    # increase beta
    fraction = min(float(self.i_episode) / self.args.episode_num, 1.0)
    self.beta = self.beta + fraction * (1.0 - self.beta)

    if self.hyper_params["USE_NOISY_NET"]:
        self.dqn.reset_noise()
        self.dqn_target.reset_noise()

    return loss.data, q_values.mean().data
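# _get_dqn_loss is also left undefined here. From its use it returns both the
# element-wise TD error (for the PER priorities) and the current Q-values (for
# the regularizer). A hypothetical Double-DQN style sketch under those
# assumptions follows; the call signature of self.dqn is an assumption.
def _get_dqn_loss(self, experiences, gamma):
    """Return (element-wise TD error, current Q-values)."""
    states, actions, rewards, next_states, dones = experiences[:5]
    masks = 1 - dones

    q_values = self.dqn(states)
    next_q_values = self.dqn(next_states)
    next_target_q_values = self.dqn_target(next_states)

    curr_q_value = q_values.gather(1, actions.long().unsqueeze(1))
    # Double DQN: select the action with the online net, evaluate with the target net
    next_q_value = next_target_q_values.gather(1, next_q_values.argmax(1).unsqueeze(1))

    target = (rewards + gamma * next_q_value * masks).detach()
    dq_loss_element_wise = (target - curr_q_value).pow(2)
    return dq_loss_element_wise, q_values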
def update_model(self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Train the model after each episode."""
    experiences = self.memory.sample(self.beta)
    states, actions, rewards, next_states, dones, weights, indexes, eps_d = experiences

    # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
    #       = r                       otherwise
    masks = 1 - dones
    next_actions = self.actor_target(next_states)
    next_states_actions = torch.cat((next_states, next_actions), dim=-1)
    next_values = self.critic_target(next_states_actions)
    curr_returns = rewards + self.hyper_params["GAMMA"] * next_values * masks
    curr_returns = curr_returns.to(device).detach()

    # train critic
    values = self.critic(torch.cat((states, actions), dim=-1))
    critic_loss_element_wise = (values - curr_returns).pow(2)
    critic_loss = torch.mean(critic_loss_element_wise * weights)
    self.critic_optimizer.zero_grad()
    critic_loss.backward()
    self.critic_optimizer.step()

    # train actor
    actions = self.actor(states)
    actor_loss_element_wise = -self.critic(torch.cat((states, actions), dim=-1))
    actor_loss = torch.mean(actor_loss_element_wise * weights)
    self.actor_optimizer.zero_grad()
    actor_loss.backward()
    self.actor_optimizer.step()

    # update target networks
    tau = self.hyper_params["TAU"]
    common_utils.soft_update(self.actor, self.actor_target, tau)
    common_utils.soft_update(self.critic, self.critic_target, tau)

    # update priorities
    new_priorities = critic_loss_element_wise
    new_priorities += self.hyper_params["LAMBDA3"] * actor_loss_element_wise.pow(2)
    new_priorities += self.hyper_params["PER_EPS"]
    new_priorities = new_priorities.data.cpu().numpy().squeeze()
    new_priorities += eps_d
    self.memory.update_priorities(indexes, new_priorities)

    # increase beta
    fraction = min(float(self.i_episode) / self.args.max_episode_steps, 1.0)
    self.beta = self.beta + fraction * (1.0 - self.beta)

    return actor_loss.data, critic_loss.data
def update_model(self) -> Tuple[torch.Tensor, ...]:
    """Train the model after each episode."""
    experiences = self.memory.sample(self.beta)
    states, actions, rewards, next_states, dones, weights, indexes = experiences

    q_values = self.dqn(states, self.epsilon)
    next_q_values = self.dqn(next_states, self.epsilon)
    next_target_q_values = self.dqn_target(next_states, self.epsilon)

    curr_q_value = q_values.gather(1, actions.long().unsqueeze(1))
    next_q_value = next_target_q_values.gather(  # Double DQN
        1, next_q_values.argmax(1).unsqueeze(1))

    # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
    #       = r                       otherwise
    masks = 1 - dones
    target = rewards + self.hyper_params["GAMMA"] * next_q_value * masks
    target = target.to(device)

    # calculate dq loss
    dq_loss_element_wise = (target - curr_q_value).pow(2)
    dq_loss = torch.mean(dq_loss_element_wise * weights)

    # q_value regularization
    q_regular = torch.norm(q_values, 2).mean() * self.hyper_params["W_Q_REG"]

    # total loss
    loss = dq_loss + q_regular

    self.dqn_optimizer.zero_grad()
    loss.backward()
    clip_grad_norm_(self.dqn.parameters(), self.hyper_params["GRADIENT_CLIP"])
    self.dqn_optimizer.step()

    # update target networks
    tau = self.hyper_params["TAU"]
    common_utils.soft_update(self.dqn, self.dqn_target, tau)

    # update priorities in PER
    loss_for_prior = dq_loss_element_wise.detach().cpu().numpy().squeeze()
    new_priorities = loss_for_prior + self.hyper_params["PER_EPS"]
    self.memory.update_priorities(indexes, new_priorities)

    # increase beta
    fraction = min(float(self.i_episode) / self.args.max_episode_steps, 1.0)
    self.beta = self.beta + fraction * (1.0 - self.beta)

    return loss.data
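# The `weights` and `beta` used above come from prioritized experience replay:
# transition i is sampled with probability P(i) proportional to p_i^alpha and
# corrected by an importance-sampling weight w_i = (N * P(i))^(-beta),
# normalized by the largest weight. The helper below is a small standalone
# sketch of that weight computation, assuming `priorities` already stores
# p_i^alpha; it is not the replay buffer's actual implementation.
import numpy as np


def per_is_weights(priorities: np.ndarray, batch_indices: np.ndarray, beta: float) -> np.ndarray:
    """Compute normalized importance-sampling weights for the sampled indices."""
    probs = priorities / priorities.sum()
    n = len(priorities)
    weights = (n * probs[batch_indices]) ** (-beta)
    return weights / weights.max()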
def update_model(self) -> Tuple[torch.Tensor, ...]:
    """Train the model after each episode."""
    experiences = self.memory.sample(self.beta)
    states, actions, rewards, next_states, dones, weights, indexes = experiences

    # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
    #       = r                       otherwise
    masks = 1 - dones
    next_actions = self.actor_target(next_states)
    next_values = self.critic_target(torch.cat((next_states, next_actions), dim=-1))
    curr_returns = rewards + self.hyper_params["GAMMA"] * next_values * masks
    curr_returns = curr_returns.to(device).detach()

    # train critic
    gradient_clip_cr = self.hyper_params["GRADIENT_CLIP_CR"]
    values = self.critic(torch.cat((states, actions), dim=-1))
    critic_loss_element_wise = (values - curr_returns).pow(2)
    critic_loss = torch.mean(critic_loss_element_wise * weights)
    self.critic_optimizer.zero_grad()
    critic_loss.backward()
    nn.utils.clip_grad_norm_(self.critic.parameters(), gradient_clip_cr)
    self.critic_optimizer.step()

    # train actor
    gradient_clip_ac = self.hyper_params["GRADIENT_CLIP_AC"]
    actions = self.actor(states)
    actor_loss_element_wise = -self.critic(torch.cat((states, actions), dim=-1))
    actor_loss = torch.mean(actor_loss_element_wise * weights)
    self.actor_optimizer.zero_grad()
    actor_loss.backward()
    nn.utils.clip_grad_norm_(self.actor.parameters(), gradient_clip_ac)
    self.actor_optimizer.step()

    # update target networks
    tau = self.hyper_params["TAU"]
    common_utils.soft_update(self.actor, self.actor_target, tau)
    common_utils.soft_update(self.critic, self.critic_target, tau)

    # update priorities in PER
    new_priorities = critic_loss_element_wise
    new_priorities = new_priorities.data.cpu().numpy() + self.hyper_params["PER_EPS"]
    self.memory.update_priorities(indexes, new_priorities)

    # increase beta
    fraction = min(float(self.i_episode) / self.args.episode_num, 1.0)
    self.beta = self.beta + fraction * (1.0 - self.beta)

    return actor_loss.item(), critic_loss.item()
def update_model(
    self,
    experiences: Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                       torch.Tensor, torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Train the model after each episode."""
    obs, acts, rews, obs_nxt, goals, dones = experiences

    # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
    #       = r                       otherwise
    masks = 1 - dones
    next_actions = self.actor_target(obs_nxt)  # a'
    next_values = self.critic_target(torch.cat((obs_nxt, next_actions), dim=-1))  # target Q
    curr_returns = rews + self.hyper_params["GAMMA"] * next_values * masks
    curr_returns = curr_returns.to(device)

    # train critic
    values = self.critic(torch.cat((obs, acts), dim=-1))
    critic_loss = F.mse_loss(values, curr_returns)
    self.critic_optimizer.zero_grad()  # set all parameter grads to zero
    critic_loss.backward()             # compute gradients
    self.critic_optimizer.step()       # apply gradients

    # train actor
    actions = self.actor(obs)
    actor_loss = -self.critic(torch.cat((obs, actions), dim=-1)).mean()
    self.actor_optimizer.zero_grad()
    actor_loss.backward()
    self.actor_optimizer.step()

    # update target networks
    tau = self.hyper_params["TAU"]
    common_utils.soft_update(self.actor, self.actor_target, tau)
    common_utils.soft_update(self.critic, self.critic_target, tau)

    return actor_loss.data, critic_loss.data
def update_model(self, experiences):
    """Train the model after each episode."""
    states, actions, rewards, next_states, dones, weights, indices, eps_d = experiences

    gamma = self.hyper_params["GAMMA"]
    critic1_loss_element_wise, critic2_loss_element_wise = self._get_critic_loss(
        experiences, gamma)
    critic_loss_element_wise = critic1_loss_element_wise + critic2_loss_element_wise
    critic1_loss = torch.mean(critic1_loss_element_wise * weights)
    critic2_loss = torch.mean(critic2_loss_element_wise * weights)
    critic_loss = critic1_loss + critic2_loss

    if self.use_n_step:
        experiences_n = self.memory_n.sample(indices)
        gamma = self.hyper_params["GAMMA"] ** self.hyper_params["N_STEP"]
        critic1_loss_n_element_wise, critic2_loss_n_element_wise = self._get_critic_loss(
            experiences_n, gamma)
        critic_loss_n_element_wise = (critic1_loss_n_element_wise +
                                      critic2_loss_n_element_wise)
        critic1_loss_n = torch.mean(critic1_loss_n_element_wise * weights)
        critic2_loss_n = torch.mean(critic2_loss_n_element_wise * weights)
        critic_loss_n = critic1_loss_n + critic2_loss_n

        lambda1 = self.hyper_params["LAMBDA1"]
        critic_loss_element_wise += lambda1 * critic_loss_n_element_wise
        critic_loss += lambda1 * critic_loss_n

    self.critic_optim.zero_grad()
    critic_loss.backward()
    self.critic_optim.step()

    if self.episode_steps % self.hyper_params["POLICY_UPDATE_FREQ"] == 0:
        # train actor
        actions = self.actor(states)
        actor_loss_element_wise = -self.critic1(torch.cat((states, actions), dim=-1))
        actor_loss = torch.mean(actor_loss_element_wise * weights)
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        # update target networks
        tau = self.hyper_params["TAU"]
        common_utils.soft_update(self.actor, self.actor_target, tau)
        common_utils.soft_update(self.critic1, self.critic1_target, tau)
        common_utils.soft_update(self.critic2, self.critic2_target, tau)

        # update priorities
        new_priorities = critic_loss_element_wise
        new_priorities += self.hyper_params["LAMBDA3"] * actor_loss_element_wise.pow(2)
        new_priorities += self.hyper_params["PER_EPS"]
        new_priorities = new_priorities.data.cpu().numpy().squeeze()
        new_priorities += eps_d
        self.memory.update_priorities(indices, new_priorities)
    else:
        actor_loss = torch.zeros(1)

    return actor_loss.data, critic1_loss.data, critic2_loss.data
def update_model(
    self, experiences: Tuple[torch.Tensor, ...]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Train the model after each episode."""
    self.update_step += 1

    states, actions, rewards, next_states, dones = experiences
    masks = 1 - dones

    # get actions with noise
    noise = torch.FloatTensor(self.target_policy_noise.sample()).to(device)
    clipped_noise = torch.clamp(
        noise,
        -self.hyper_params["TARGET_POLICY_NOISE_CLIP"],
        self.hyper_params["TARGET_POLICY_NOISE_CLIP"],
    )
    next_actions = (self.actor_target(next_states) + clipped_noise).clamp(-1.0, 1.0)

    # min (Q_1', Q_2')
    next_values1 = self.critic_target1(next_states, next_actions)
    next_values2 = self.critic_target2(next_states, next_actions)
    next_values = torch.min(next_values1, next_values2)

    # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
    #       = r                       otherwise
    curr_returns = rewards + self.hyper_params["GAMMA"] * next_values * masks
    curr_returns = curr_returns.detach()

    # critic loss
    values1 = self.critic1(states, actions)
    values2 = self.critic2(states, actions)
    critic1_loss = F.mse_loss(values1, curr_returns)
    critic2_loss = F.mse_loss(values2, curr_returns)

    # train critic
    critic_loss = critic1_loss + critic2_loss
    self.critic_optim.zero_grad()
    critic_loss.backward()
    self.critic_optim.step()

    if self.update_step % self.hyper_params["POLICY_UPDATE_FREQ"] == 0:
        # policy loss
        actions = self.actor(states)
        actor_loss = -self.critic1(states, actions).mean()

        # train actor
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        # update target networks
        tau = self.hyper_params["TAU"]
        common_utils.soft_update(self.critic1, self.critic_target1, tau)
        common_utils.soft_update(self.critic2, self.critic_target2, tau)
        common_utils.soft_update(self.actor, self.actor_target, tau)
    else:
        actor_loss = torch.zeros(1)

    return actor_loss.item(), critic1_loss.item(), critic2_loss.item()
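# `target_policy_noise` is assumed to be a zero-mean Gaussian noise generator
# used for TD3 target policy smoothing. The class below is a minimal sketch of
# such a helper, assuming one noise value per action dimension; the repo's
# actual noise class and its constructor arguments may differ.
import numpy as np


class GaussianNoise:
    """Zero-mean Gaussian noise for target policy smoothing."""

    def __init__(self, action_dim: int, std: float):
        self.action_dim = action_dim
        self.std = std

    def sample(self) -> np.ndarray:
        # one sample per action dimension, drawn from N(0, std^2)
        return np.random.normal(0.0, self.std, size=self.action_dim)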
def update_model(self, experiences):
    """Train the model after each episode."""
    states, actions, rewards, next_states, dones = experiences

    # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
    #       = r                       otherwise
    masks = 1 - dones

    noise = torch.FloatTensor(self.target_policy_noise.sample()).to(device)
    clipped_noise = torch.clamp(
        noise,
        -self.hyper_params["TARGET_POLICY_NOISE_CLIP"],
        self.hyper_params["TARGET_POLICY_NOISE_CLIP"],
    )
    next_actions = (self.actor_target(next_states) + clipped_noise).clamp(-1.0, 1.0)

    target_values1 = self.critic1_target(torch.cat((next_states, next_actions), dim=-1))
    target_values2 = self.critic2_target(torch.cat((next_states, next_actions), dim=-1))
    target_values = torch.min(target_values1, target_values2)
    target_values = rewards + (self.hyper_params["GAMMA"] * target_values * masks).detach()

    # train critic
    values1 = self.critic1(torch.cat((states, actions), dim=-1))
    critic1_loss = F.mse_loss(values1, target_values)
    values2 = self.critic2(torch.cat((states, actions), dim=-1))
    critic2_loss = F.mse_loss(values2, target_values)

    critic_loss = critic1_loss + critic2_loss
    self.critic_optim.zero_grad()
    critic_loss.backward()
    self.critic_optim.step()

    if self.episode_steps % self.hyper_params["POLICY_UPDATE_FREQ"] == 0:
        # train actor
        actions = self.actor(states)
        actor_loss = -self.critic1(torch.cat((states, actions), dim=-1)).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        # update target networks
        tau = self.hyper_params["TAU"]
        common_utils.soft_update(self.actor, self.actor_target, tau)
        common_utils.soft_update(self.critic1, self.critic1_target, tau)
        common_utils.soft_update(self.critic2, self.critic2_target, tau)
    else:
        actor_loss = torch.zeros(1)

    return actor_loss.data, critic1_loss.data, critic2_loss.data
def update_model(self) -> Tuple[torch.Tensor, ...]:
    """Train the model after each episode."""
    experiences_1 = self.memory.sample()
    weights, indices, eps_d = experiences_1[-3:]
    actions = experiences_1[1]

    # 1 step loss
    gamma = self.hyper_params["GAMMA"]
    dq_loss_element_wise, q_values = self._get_dqn_loss(experiences_1, gamma)
    dq_loss = torch.mean(dq_loss_element_wise * weights)

    # n step loss
    if self.use_n_step:
        experiences_n = self.memory_n.sample(indices)
        gamma = self.hyper_params["GAMMA"] ** self.hyper_params["N_STEP"]
        dq_loss_n_element_wise, q_values_n = self._get_dqn_loss(experiences_n, gamma)

        # to update loss and priorities
        q_values = 0.5 * (q_values + q_values_n)
        dq_loss_element_wise += dq_loss_n_element_wise * self.hyper_params["LAMBDA1"]
        dq_loss = torch.mean(dq_loss_element_wise * weights)

    # supervised loss using demo for only demo transitions
    demo_idxs = np.where(eps_d != 0.0)
    n_demo = demo_idxs[0].size
    if n_demo != 0:  # if 1 or more demos are sampled
        # get margin for each demo transition
        action_idxs = actions[demo_idxs].long()
        margin = torch.ones(q_values.size()) * self.hyper_params["MARGIN"]
        margin[demo_idxs, action_idxs] = 0.0  # demo actions have 0 margins
        margin = margin.to(device)

        # calculate supervised loss
        demo_q_values = q_values[demo_idxs, action_idxs].squeeze()
        supervised_loss = torch.max(q_values + margin, dim=-1)[0]
        supervised_loss = supervised_loss[demo_idxs] - demo_q_values
        supervised_loss = torch.mean(supervised_loss) * self.hyper_params["LAMBDA2"]
    else:  # no demo sampled
        supervised_loss = torch.zeros(1, device=device)

    # q_value regularization
    q_regular = torch.norm(q_values, 2).mean() * self.hyper_params["W_Q_REG"]

    # total loss
    loss = dq_loss + supervised_loss + q_regular

    # train dqn
    self.dqn_optimizer.zero_grad()
    loss.backward()
    clip_grad_norm_(self.dqn.parameters(), self.hyper_params["GRADIENT_CLIP"])
    self.dqn_optimizer.step()

    # update target networks
    tau = self.hyper_params["TAU"]
    common_utils.soft_update(self.dqn, self.dqn_target, tau)

    # update priorities in PER
    loss_for_prior = dq_loss_element_wise.detach().cpu().numpy().squeeze()
    new_priorities = loss_for_prior + self.hyper_params["PER_EPS"]
    new_priorities += eps_d
    self.memory.update_priorities(indices, new_priorities)

    # increase beta
    fraction = min(float(self.i_episode) / self.args.episode_num, 1.0)
    self.beta = self.beta + fraction * (1.0 - self.beta)

    if self.hyper_params["USE_NOISY_NET"]:
        self.dqn.reset_noise()
        self.dqn_target.reset_noise()

    return (
        loss.item(),
        dq_loss.item(),
        supervised_loss.item(),
        q_values.mean().item(),
        n_demo,
    )
def update_model(self) -> Tuple[torch.Tensor, ...]:
    """Train the model after each episode."""
    self.update_step += 1

    experiences = self.memory.sample(self.beta)
    states, actions, rewards, next_states, dones, weights, indices, eps_d = experiences
    new_actions, log_prob, pre_tanh_value, mu, std = self.actor(states)

    # train alpha
    if self.hyper_params["AUTO_ENTROPY_TUNING"]:
        alpha_loss = torch.mean(
            (-self.log_alpha * (log_prob + self.target_entropy).detach()) * weights
        )

        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        alpha = self.log_alpha.exp()
    else:
        alpha_loss = torch.zeros(1)
        alpha = self.hyper_params["W_ENTROPY"]

    # Q function loss
    masks = 1 - dones
    gamma = self.hyper_params["GAMMA"]
    q_1_pred = self.qf_1(states, actions)
    q_2_pred = self.qf_2(states, actions)
    v_target = self.vf_target(next_states)
    q_target = rewards + self.hyper_params["GAMMA"] * v_target * masks
    qf_1_loss = torch.mean((q_1_pred - q_target.detach()).pow(2) * weights)
    qf_2_loss = torch.mean((q_2_pred - q_target.detach()).pow(2) * weights)

    if self.use_n_step:
        experiences_n = self.memory_n.sample(indices)
        _, _, rewards, next_states, dones = experiences_n
        gamma = gamma ** self.hyper_params["N_STEP"]
        lambda1 = self.hyper_params["LAMBDA1"]
        masks = 1 - dones

        v_target = self.vf_target(next_states)
        q_target = rewards + gamma * v_target * masks
        qf_1_loss_n = torch.mean((q_1_pred - q_target.detach()).pow(2) * weights)
        qf_2_loss_n = torch.mean((q_2_pred - q_target.detach()).pow(2) * weights)

        # to update loss and priorities
        qf_1_loss = qf_1_loss + qf_1_loss_n * lambda1
        qf_2_loss = qf_2_loss + qf_2_loss_n * lambda1

    # V function loss
    v_pred = self.vf(states)
    q_pred = torch.min(self.qf_1(states, new_actions), self.qf_2(states, new_actions))
    v_target = (q_pred - alpha * log_prob).detach()
    vf_loss_element_wise = (v_pred - v_target).pow(2)
    vf_loss = torch.mean(vf_loss_element_wise * weights)

    # train Q functions
    self.qf_1_optimizer.zero_grad()
    qf_1_loss.backward()
    self.qf_1_optimizer.step()

    self.qf_2_optimizer.zero_grad()
    qf_2_loss.backward()
    self.qf_2_optimizer.step()

    # train V function
    self.vf_optimizer.zero_grad()
    vf_loss.backward()
    self.vf_optimizer.step()

    if self.update_step % self.hyper_params["POLICY_UPDATE_FREQ"] == 0:
        # actor loss
        advantage = q_pred - v_pred.detach()
        actor_loss_element_wise = alpha * log_prob - advantage
        actor_loss = torch.mean(actor_loss_element_wise * weights)

        # regularization
        if not self.is_discrete:  # iff the action is continuous
            mean_reg = self.hyper_params["W_MEAN_REG"] * mu.pow(2).mean()
            std_reg = self.hyper_params["W_STD_REG"] * std.pow(2).mean()
            pre_activation_reg = self.hyper_params["W_PRE_ACTIVATION_REG"] * (
                pre_tanh_value.pow(2).sum(dim=-1).mean()
            )
            actor_reg = mean_reg + std_reg + pre_activation_reg

            # actor loss + regularization
            actor_loss += actor_reg

        # train actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # update target networks
        common_utils.soft_update(self.vf, self.vf_target, self.hyper_params["TAU"])

        # update priorities
        new_priorities = vf_loss_element_wise
        new_priorities += self.hyper_params["LAMBDA3"] * actor_loss_element_wise.pow(2)
        new_priorities += self.hyper_params["PER_EPS"]
        new_priorities = new_priorities.data.cpu().numpy().squeeze()
        new_priorities += eps_d
        self.memory.update_priorities(indices, new_priorities)

        # increase beta
        fraction = min(float(self.i_episode) / self.args.episode_num, 1.0)
        self.beta = self.beta + fraction * (1.0 - self.beta)
    else:
        actor_loss = torch.zeros(1)

    return (
        actor_loss.item(),
        qf_1_loss.item(),
        qf_2_loss.item(),
        vf_loss.item(),
        alpha_loss.item(),
    )
def update_model(self) -> Tuple[torch.Tensor, ...]:
    """Train the model after each episode."""
    experiences = self.memory.sample()
    states, actions, rewards, next_states, dones, weights, indexes, eps_d = experiences

    q_values = self.dqn(states, self.epsilon)
    next_q_values = self.dqn(next_states, self.epsilon)
    next_target_q_values = self.dqn_target(next_states, self.epsilon)

    curr_q_values = q_values.gather(1, actions.long().unsqueeze(1))
    next_q_values = next_target_q_values.gather(  # Double DQN
        1, next_q_values.argmax(1).unsqueeze(1))

    # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
    #       = r                       otherwise
    masks = 1 - dones
    target = rewards + self.hyper_params["GAMMA"] * next_q_values * masks
    target = target.to(device)

    # calculate dq loss
    dq_loss_element_wise = (target - curr_q_values).pow(2)
    dq_loss = torch.mean(dq_loss_element_wise * weights)

    # supervised loss using demo for only demo transitions
    demo_idxs = np.where(eps_d != 0.0)
    if demo_idxs[0].size != 0:  # if 1 or more demos are sampled
        # get margin for each demo transition
        action_idxs = actions[demo_idxs].long()
        margin = torch.ones(q_values.size()) * self.hyper_params["MARGIN"]
        margin[demo_idxs, action_idxs] = 0.0  # demo actions have 0 margins
        margin = margin.to(device)

        # calculate supervised loss
        demo_q_values = q_values[demo_idxs, action_idxs].squeeze()
        supervised_loss = torch.max(q_values + margin, dim=-1)[0]
        supervised_loss = supervised_loss[demo_idxs] - demo_q_values
        supervised_loss = torch.mean(supervised_loss) * self.hyper_params["LAMBDA2"]
    else:  # no demo sampled
        supervised_loss = torch.zeros(1, device=device)

    # q_value regularization
    q_regular = torch.norm(q_values, 2).mean() * self.hyper_params["W_Q_REG"]

    # total loss
    loss = dq_loss + supervised_loss + q_regular

    # train dqn
    self.dqn_optimizer.zero_grad()
    loss.backward()
    clip_grad_norm_(self.dqn.parameters(), self.hyper_params["GRADIENT_CLIP"])
    self.dqn_optimizer.step()

    # update target networks
    tau = self.hyper_params["TAU"]
    common_utils.soft_update(self.dqn, self.dqn_target, tau)

    # update priorities in PER
    loss_for_prior = dq_loss_element_wise.detach().cpu().numpy().squeeze()
    new_priorities = loss_for_prior + self.hyper_params["PER_EPS"]
    new_priorities += eps_d
    self.memory.update_priorities(indexes, new_priorities)

    # increase beta
    fraction = min(float(self.i_episode) / self.args.max_episode_steps, 1.0)
    self.beta = self.beta + fraction * (1.0 - self.beta)

    return loss.data, dq_loss.data, supervised_loss.data
def update_model(self) -> Tuple[torch.Tensor, ...]:
    """Train the model after each episode."""
    experiences = self.memory.sample()
    demos = self.demo_memory.sample()

    states, actions, rewards, next_states, dones = experiences
    demo_states, demo_actions, _, _, _ = demos
    new_actions, log_prob, pre_tanh_value, mu, std = self.actor(states)
    pred_actions, _, _, _, _ = self.actor(demo_states)

    # train alpha
    if self.hyper_params["AUTO_ENTROPY_TUNING"]:
        alpha_loss = (-self.log_alpha *
                      (log_prob + self.target_entropy).detach()).mean()

        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        alpha = self.log_alpha.exp()
    else:
        alpha_loss = torch.zeros(1)
        alpha = self.hyper_params["W_ENTROPY"]

    # Q function loss
    masks = 1 - dones
    q_1_pred = self.qf_1(states, actions)
    q_2_pred = self.qf_2(states, actions)
    v_target = self.vf_target(next_states)
    q_target = rewards + self.hyper_params["GAMMA"] * v_target * masks
    qf_1_loss = F.mse_loss(q_1_pred, q_target.detach())
    qf_2_loss = F.mse_loss(q_2_pred, q_target.detach())

    # V function loss
    v_pred = self.vf(states)
    q_pred = torch.min(self.qf_1(states, new_actions), self.qf_2(states, new_actions))
    v_target = q_pred - alpha * log_prob
    vf_loss = F.mse_loss(v_pred, v_target.detach())

    # train Q functions
    self.qf_1_optimizer.zero_grad()
    qf_1_loss.backward()
    self.qf_1_optimizer.step()

    self.qf_2_optimizer.zero_grad()
    qf_2_loss.backward()
    self.qf_2_optimizer.step()

    # train V function
    self.vf_optimizer.zero_grad()
    vf_loss.backward()
    self.vf_optimizer.step()

    if self.total_step % self.hyper_params["DELAYED_UPDATE"] == 0:
        # bc loss
        qf_mask = torch.gt(
            self.qf_1(demo_states, demo_actions),
            self.qf_1(demo_states, pred_actions),
        ).to(device)
        qf_mask = qf_mask.float()
        n_qf_mask = int(qf_mask.sum().item())

        if n_qf_mask == 0:
            bc_loss = torch.zeros(1, device=device)
        else:
            bc_loss = (
                torch.mul(pred_actions, qf_mask) - torch.mul(demo_actions, qf_mask)
            ).pow(2).sum() / n_qf_mask

        # actor loss
        advantage = q_pred - v_pred.detach()
        actor_loss = (alpha * log_prob - advantage).mean()
        actor_loss = self.lambda1 * actor_loss + self.lambda2 * bc_loss

        # regularization
        if not self.is_discrete:  # iff the action is continuous
            mean_reg = self.hyper_params["W_MEAN_REG"] * mu.pow(2).mean()
            std_reg = self.hyper_params["W_STD_REG"] * std.pow(2).mean()
            pre_activation_reg = self.hyper_params["W_PRE_ACTIVATION_REG"] * (
                pre_tanh_value.pow(2).sum(dim=-1).mean())
            actor_reg = mean_reg + std_reg + pre_activation_reg

            # actor loss + regularization
            actor_loss += actor_reg

        # train actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # update target networks
        common_utils.soft_update(self.vf, self.vf_target, self.hyper_params["TAU"])
    else:
        actor_loss = torch.zeros(1)
        n_qf_mask = 0

    return (
        actor_loss.data,
        qf_1_loss.data,
        qf_2_loss.data,
        vf_loss.data,
        alpha_loss.data,
        n_qf_mask,
    )
def update_model(self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Train the model after each episode."""
    experiences = self.memory.sample()
    demos = self.demo_memory.sample()

    exp_states, exp_actions, exp_rewards, exp_next_states, exp_dones = experiences
    demo_states, demo_actions, demo_rewards, demo_next_states, demo_dones = demos

    states = torch.cat((exp_states, demo_states), dim=0)
    actions = torch.cat((exp_actions, demo_actions), dim=0)
    rewards = torch.cat((exp_rewards, demo_rewards), dim=0)
    next_states = torch.cat((exp_next_states, demo_next_states), dim=0)
    dones = torch.cat((exp_dones, demo_dones), dim=0)

    # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
    #       = r                       otherwise
    masks = 1 - dones
    next_actions = self.actor_target(next_states)
    next_values = self.critic_target(torch.cat((next_states, next_actions), dim=-1))
    curr_returns = rewards + (self.hyper_params["GAMMA"] * next_values * masks)
    curr_returns = curr_returns.to(device)

    # critic loss
    values = self.critic(torch.cat((states, actions), dim=-1))
    critic_loss = F.mse_loss(values, curr_returns)

    # train critic
    self.critic_optimizer.zero_grad()
    critic_loss.backward()
    self.critic_optimizer.step()

    # policy loss
    actions = self.actor(states)
    policy_loss = -self.critic(torch.cat((states, actions), dim=-1)).mean()

    # bc loss
    pred_actions = self.actor(demo_states)
    qf_mask = torch.gt(
        self.critic(torch.cat((demo_states, demo_actions), dim=-1)),
        self.critic(torch.cat((demo_states, pred_actions), dim=-1)),
    ).to(device)
    qf_mask = qf_mask.float()
    n_qf_mask = int(qf_mask.sum().item())

    if n_qf_mask == 0:
        bc_loss = torch.zeros(1, device=device)
    else:
        bc_loss = (
            torch.mul(pred_actions, qf_mask) - torch.mul(demo_actions, qf_mask)
        ).pow(2).sum() / n_qf_mask

    # train actor: pg loss + BC loss
    actor_loss = self.lambda1 * policy_loss + self.lambda2 * bc_loss
    self.actor_optimizer.zero_grad()
    actor_loss.backward()
    self.actor_optimizer.step()

    # update target networks
    tau = self.hyper_params["TAU"]
    common_utils.soft_update(self.actor, self.actor_target, tau)
    common_utils.soft_update(self.critic, self.critic_target, tau)

    return actor_loss.data, critic_loss.data
def update_model(self) -> Tuple[torch.Tensor, ...]:
    """Train the model after each episode."""
    self.update_step += 1

    experiences = self.memory.sample()
    states, actions, rewards, next_states, dones = experiences
    new_actions, log_prob, pre_tanh_value, mu, std = self.actor(states)

    # train alpha
    if self.hyper_params["AUTO_ENTROPY_TUNING"]:
        alpha_loss = (
            -self.log_alpha * (log_prob + self.target_entropy).detach()
        ).mean()

        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        alpha = self.log_alpha.exp()
    else:
        alpha_loss = torch.zeros(1)
        alpha = self.hyper_params["W_ENTROPY"]

    # Q function loss
    masks = 1 - dones
    q_1_pred = self.qf_1(states, actions)
    q_2_pred = self.qf_2(states, actions)
    v_target = self.vf_target(next_states)
    q_target = rewards + self.hyper_params["GAMMA"] * v_target * masks
    qf_1_loss = F.mse_loss(q_1_pred, q_target.detach())
    qf_2_loss = F.mse_loss(q_2_pred, q_target.detach())

    # V function loss
    v_pred = self.vf(states)
    q_pred = torch.min(self.qf_1(states, new_actions), self.qf_2(states, new_actions))
    v_target = q_pred - alpha * log_prob
    vf_loss = F.mse_loss(v_pred, v_target.detach())

    # train Q functions
    self.qf_1_optimizer.zero_grad()
    qf_1_loss.backward()
    self.qf_1_optimizer.step()

    self.qf_2_optimizer.zero_grad()
    qf_2_loss.backward()
    self.qf_2_optimizer.step()

    # train V function
    self.vf_optimizer.zero_grad()
    vf_loss.backward()
    self.vf_optimizer.step()

    if self.update_step % self.hyper_params["POLICY_UPDATE_FREQ"] == 0:
        # actor loss
        advantage = q_pred - v_pred.detach()
        actor_loss = (alpha * log_prob - advantage).mean()

        # regularization
        if not self.is_discrete:  # iff the action is continuous
            mean_reg = self.hyper_params["W_MEAN_REG"] * mu.pow(2).mean()
            std_reg = self.hyper_params["W_STD_REG"] * std.pow(2).mean()
            pre_activation_reg = self.hyper_params["W_PRE_ACTIVATION_REG"] * (
                pre_tanh_value.pow(2).sum(dim=-1).mean()
            )
            actor_reg = mean_reg + std_reg + pre_activation_reg

            # actor loss + regularization
            actor_loss += actor_reg

        # train actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # update target networks
        common_utils.soft_update(self.vf, self.vf_target, self.hyper_params["TAU"])
    else:
        actor_loss = torch.zeros(1)

    return (
        actor_loss.item(),
        qf_1_loss.item(),
        qf_2_loss.item(),
        vf_loss.item(),
        alpha_loss.item(),
    )
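# The AUTO_ENTROPY_TUNING branch above assumes that `log_alpha`,
# `target_entropy`, and `alpha_optimizer` were prepared when the agent was
# initialized. A hedged sketch of that setup is given below; the action
# dimensionality, learning rate, and heuristic target_entropy = -|A| are
# assumptions for illustration, not values taken from this repository.
import numpy as np
import torch
import torch.optim as optim

action_dim = 4  # hypothetical action dimensionality
target_entropy = -float(np.prod((action_dim,)))          # common SAC heuristic: -|A|
log_alpha = torch.zeros(1, requires_grad=True)           # learn log(alpha) for positivity
alpha_optimizer = optim.Adam([log_alpha], lr=3e-4)       # learning rate is an assumption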
def update_model(
    self,
    experiences: Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                       torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Train the model after each episode."""
    self.update_step += 1

    states, actions, rewards, next_states, dones = experiences
    masks = 1 - dones

    # get actions with noise
    noise_std, noise_clip = (
        self.hyper_params["TARGET_SMOOTHING_NOISE_STD"],
        self.hyper_params["TARGET_SMOOTHING_NOISE_CLIP"],
    )
    next_actions = self.actor_target(next_states)
    # sample fresh noise instead of filling next_actions in place
    noise = torch.zeros_like(next_actions).normal_(0, noise_std).to(device)
    noise = noise.clamp(-noise_clip, noise_clip)
    next_actions = (next_actions + noise).clamp(-1.0, 1.0)

    # min (Q_1', Q_2')
    next_states_actions = torch.cat((next_states, next_actions), dim=-1)
    next_values1 = self.critic_target1(next_states_actions)
    next_values2 = self.critic_target2(next_states_actions)
    next_values = torch.min(next_values1, next_values2)

    # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
    #       = r                       otherwise
    curr_returns = rewards + self.hyper_params["GAMMA"] * next_values * masks
    curr_returns = curr_returns.to(device).detach()

    # critic loss
    states_actions = torch.cat((states, actions), dim=-1)
    values1 = self.critic_1(states_actions)
    values2 = self.critic_2(states_actions)
    critic_loss1 = F.mse_loss(values1, curr_returns)
    critic_loss2 = F.mse_loss(values2, curr_returns)
    critic_loss = critic_loss1 + critic_loss2

    # train critic
    self.critic_optimizer.zero_grad()
    critic_loss.backward()
    self.critic_optimizer.step()

    if self.update_step % self.hyper_params["DELAYED_UPDATE"] == 0:
        # train actor
        actions = self.actor(states)
        states_actions = torch.cat((states, actions), dim=-1)
        actor_loss = -self.critic_1(states_actions).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # update target networks
        tau = self.hyper_params["TAU"]
        common_utils.soft_update(self.critic_1, self.critic_target1, tau)
        common_utils.soft_update(self.critic_2, self.critic_target2, tau)
        common_utils.soft_update(self.actor, self.actor_target, tau)
    else:
        actor_loss = torch.zeros(1)

    return actor_loss.data, critic_loss1.data, critic_loss2.data