# NOTE: assumed imports for the methods in this section (exact module paths for the
# project-local helpers `to_torch`, `to_array_as` and `logger` are not shown here).
import copy
import numpy as np
import torch
import torch.nn.functional as F


def _train_policy_latent(self, replay_buffer, eval_fn):
    for it in range(self.args["actor_iterations"]):
        batch = replay_buffer.sample(self.args["actor_batch_size"])
        batch = to_torch(batch, torch.float, device=self.args["device"])
        rew = batch.rew
        done = batch.done
        obs = batch.obs
        act = batch.act
        obs_next = batch.obs_next

        # Critic training
        with torch.no_grad():
            _, _, next_action = self.actor_target(obs_next, self.vae.decode)
            target_q1 = self.critic1_target(obs_next, next_action)
            target_q2 = self.critic2_target(obs_next, next_action)
            # Clipped double-Q: convex combination of the min and max estimates
            target_q = self.args["lmbda"] * torch.min(target_q1, target_q2) \
                + (1 - self.args["lmbda"]) * torch.max(target_q1, target_q2)
            target_q = rew + (1 - done) * self.args["discount"] * target_q

        current_q1 = self.critic1(obs, act)
        current_q2 = self.critic2(obs, act)
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

        self.critic1_opt.zero_grad()
        self.critic2_opt.zero_grad()
        critic_loss.backward()
        self.critic1_opt.step()
        self.critic2_opt.step()

        # Actor training
        latent_actions, mid_actions, actions = self.actor(obs, self.vae.decode)
        actor_loss = -self.critic1(obs, actions).mean()

        self.actor.zero_grad()
        actor_loss.backward()
        self.actor_opt.step()

        # Update target networks
        self._sync_weight(self.actor_target, self.actor)
        self._sync_weight(self.critic1_target, self.critic1)
        self._sync_weight(self.critic2_target, self.critic2)

        if (it + 1) % 1000 == 0:
            print("mid_actions :", torch.abs(actions - mid_actions).mean())
            if eval_fn is None:
                self.eval_policy()
            else:
                self.vae._actor = copy.deepcopy(self.actor)
                res = eval_fn(self.get_policy())
                self.log_res((it + 1) // 1000, res)
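# Hedged sketch: `_sync_weight` is called above but not defined in this section.
# A minimal implementation, assuming it performs a standard polyak (soft) target
# update; the signature and the default tau value are assumptions inferred from
# the call sites, not taken verbatim from the repo.
def _sync_weight(self, net_target, net, soft_target_tau=5e-3):
    # target <- (1 - tau) * target + tau * online
    for target_param, param in zip(net_target.parameters(), net.parameters()):
        target_param.data.copy_(
            target_param.data * (1.0 - soft_target_tau) + param.data * soft_target_tau)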
def _train_policy(self, train_buffer, callback_fn):
    for it in range(self.args["actor_iterations"]):
        batch = train_buffer.sample(self.args["actor_batch_size"])
        batch = to_torch(batch, torch.float, device=self.args["device"])
        rew = batch.rew
        done = batch.done
        obs = batch.obs
        act = batch.act
        obs_next = batch.obs_next

        # Critic training
        with torch.no_grad():
            action_next_actor, _ = self.actor_target(obs_next)
            action_next_vae = self.vae.decode(obs_next, z=action_next_actor)
            target_q1 = self.critic1_target(obs_next, action_next_vae)
            target_q2 = self.critic2_target(obs_next, action_next_vae)
            target_q = self.args["lmbda"] * torch.min(target_q1, target_q2) \
                + (1 - self.args["lmbda"]) * torch.max(target_q1, target_q2)
            target_q = rew + (1 - done) * self.args["discount"] * target_q

        current_q1 = self.critic1(obs, act)
        current_q2 = self.critic2(obs, act)
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

        self.critic1_opt.zero_grad()
        self.critic2_opt.zero_grad()
        critic_loss.backward()
        self.critic1_opt.step()
        self.critic2_opt.step()

        # Actor training: the actor outputs a latent that the VAE decodes into an action
        action_actor, _ = self.actor(obs)
        action_vae = self.vae.decode(obs, z=action_actor)
        actor_loss = -self.critic1(obs, action_vae).mean()

        self.actor.zero_grad()
        actor_loss.backward()
        self.actor_opt.step()

        # Update target networks
        self._sync_weight(self.actor_target, self.actor)
        self._sync_weight(self.critic1_target, self.critic1)
        self._sync_weight(self.critic2_target, self.critic2)

        if (it + 1) % 1000 == 0:
            if callback_fn is None:
                self.eval_policy()
            else:
                res = callback_fn(self.get_policy())
                self.log_res((it + 1) // 1000, res)
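# Hedged sketch: at evaluation time, the PLAS-style policy trained above is
# assumed to act by mapping the observation to a latent with the actor and
# decoding it through the VAE, mirroring the actor-training path
# `self.vae.decode(obs, z=action_actor)`. The method name `policy_infer` is
# taken from `get_action` further below; this body is an assumption, not the
# repo's actual implementation.
@torch.no_grad()
def policy_infer(self, obs):
    latent_action, _ = self.actor(obs)
    return self.vae.decode(obs, z=latent_action)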
def train(self, train_buffer, val_buffer, callback_fn):
    training_iters = 0
    while training_iters < self.args["max_timesteps"]:
        # Sample replay buffer
        batch = train_buffer.sample(self.args["batch_size"])
        batch = to_torch(batch, torch.float, device=self.args["device"])
        reward = batch.rew
        done = batch.done
        state = batch.obs
        action = batch.act.to(torch.int64)
        next_state = batch.obs_next

        # Compute the target Q value
        with torch.no_grad():
            q, imt, i = self.Q(next_state)
            imt = imt.exp()
            # Keep only actions whose behaviour probability is within `threshold`
            # of the most likely action
            imt = (imt / imt.max(1, keepdim=True)[0] > self.threshold).float()

            # Use a large negative number to mask filtered actions from the argmax
            next_action = (imt * q + (1 - imt) * -1e8).argmax(1, keepdim=True)

            q, imt, i = self.Q_target(next_state)
            # (1 - done) masks the bootstrap term at terminal transitions
            target_Q = reward + (1 - done) * self.discount * q.gather(
                1, next_action).reshape(-1, 1)

        # Get current Q estimate
        current_Q, imt, i = self.Q(state)
        current_Q = current_Q.gather(1, action)

        # Compute Q loss: TD error + behaviour-cloning (NLL) term + logit regularisation
        q_loss = F.smooth_l1_loss(current_Q, target_Q)
        i_loss = F.nll_loss(imt, action.reshape(-1))
        Q_loss = q_loss + i_loss + 1e-2 * i.pow(2).mean()

        # Optimize the Q network
        self.Q_optimizer.zero_grad()
        Q_loss.backward()
        self.Q_optimizer.step()

        # Update target network by polyak averaging or a full copy every X iterations
        self.maybe_update_target()

        training_iters += 1
        if training_iters % self.args["eval_freq"] == 0:
            res = callback_fn(self.get_policy())
            self.log_res(training_iters // self.args["eval_freq"], res)
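# Hedged sketch: the network behind `self.Q` is not shown in this section. From
# the calls `q, imt, i = self.Q(state)`, `imt.exp()` and `F.nll_loss(imt, ...)`,
# it is assumed to return (Q-values, log action probabilities, raw logits), as in
# discrete BCQ. The class name and hidden size below are illustrative assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F

class DiscreteBCQNet(nn.Module):
    def __init__(self, obs_dim, num_actions, hidden=256):
        super().__init__()
        self.q_head = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(), nn.Linear(hidden, num_actions))
        self.i_head = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(), nn.Linear(hidden, num_actions))

    def forward(self, obs):
        q = self.q_head(obs)
        i = self.i_head(obs)
        # log-probabilities, so `imt.exp()` recovers probabilities and
        # `F.nll_loss(imt, act)` acts as the imitation (behaviour-cloning) term
        return q, F.log_softmax(i, dim=1), i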
def _train_vae_step(self, batch):
    batch = to_torch(batch, torch.float, device=self.args["device"])
    obs = batch.obs
    act = batch.act

    recon, mean, std = self.vae(obs, act)
    recon_loss = F.mse_loss(recon, act)
    KL_loss = -self.args["vae_kl_weight"] * (
        1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
    vae_loss = recon_loss + 0.5 * KL_loss

    self.vae_opt.zero_grad()
    vae_loss.backward()
    self.vae_opt.step()

    return vae_loss.cpu().data.numpy(), recon_loss.cpu().data.numpy(), KL_loss.cpu().data.numpy()
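# Hedged sketch: `self.vae` is used in this section as a BCQ-style conditional
# VAE -- `vae(obs, act)` returns (reconstruction, mean, std) and
# `vae.decode(obs, z=...)` decodes a latent, sampling a clipped prior latent when
# z is None. Layer sizes, the latent clipping range and `max_action` are
# illustrative assumptions, not values taken from the repo.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConditionalVAE(nn.Module):
    def __init__(self, obs_dim, act_dim, latent_dim, max_action=1.0, hidden=750):
        super().__init__()
        self.e1 = nn.Linear(obs_dim + act_dim, hidden)
        self.e2 = nn.Linear(hidden, hidden)
        self.mean = nn.Linear(hidden, latent_dim)
        self.log_std = nn.Linear(hidden, latent_dim)
        self.d1 = nn.Linear(obs_dim + latent_dim, hidden)
        self.d2 = nn.Linear(hidden, hidden)
        self.d3 = nn.Linear(hidden, act_dim)
        self.latent_dim = latent_dim
        self.max_action = max_action

    def forward(self, obs, act):
        h = F.relu(self.e2(F.relu(self.e1(torch.cat([obs, act], dim=1)))))
        mean = self.mean(h)
        std = torch.exp(self.log_std(h).clamp(-4, 15))
        z = mean + std * torch.randn_like(std)  # reparameterisation trick
        return self.decode(obs, z=z), mean, std

    def decode(self, obs, z=None):
        if z is None:
            # clipped sample from the prior, as in BCQ
            z = torch.randn(obs.shape[0], self.latent_dim,
                            device=obs.device).clamp(-0.5, 0.5)
        h = F.relu(self.d2(F.relu(self.d1(torch.cat([obs, z], dim=1)))))
        return self.max_action * torch.tanh(self.d3(h))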
def _train_bc(self, buffer):
    bc_loss_list = []
    rew_loss_list = []
    for step in range(100000):
        for i in range(len(self.bcs)):
            batch = buffer.sample(256)
            batch = to_torch(batch, torch.float, device=self.args["device"])
            rew = batch.rew
            obs = batch.obs
            act = batch.act
            obs_next = batch.obs_next

            # Each ensemble member predicts the next observation and the reward
            obs_act = torch.cat([obs, act], axis=1)
            obs_next_pre, _ = self.bcs[i](obs_act)
            rew_pre, _ = self.rews[i](torch.cat([obs, act, obs_next], axis=1))

            bc_loss = F.mse_loss(obs_next_pre, obs_next)
            rew_loss = F.mse_loss(rew_pre, rew)

            self.bcs_opt[i].zero_grad()
            self.rews_opt[i].zero_grad()
            bc_loss.backward()
            rew_loss.backward()
            self.bcs_opt[i].step()
            self.rews_opt[i].step()

            bc_loss_list.append(bc_loss.item())
            rew_loss_list.append(rew_loss.item())

        if (step + 1) % 1000 == 0:
            logger.info('BC Epoch : {}, bc_loss : {:.4}', (step + 1) // 1000, np.mean(bc_loss_list))
            logger.info('BC Epoch : {}, rew_loss : {:.4}', (step + 1) // 1000, np.mean(rew_loss_list))
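# Hedged sketch: the transition models in `self.bcs` and the reward models in
# `self.rews` are only used through their first return value above, so they are
# assumed to be simple MLPs returning a (mean, extra) tuple; the Gaussian-style
# second head below is an illustrative assumption, not the repo's definition.
import torch.nn as nn

class EnsembleMemberMLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden=256):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU())
        self.mean = nn.Linear(hidden, out_dim)
        self.log_std = nn.Linear(hidden, out_dim)

    def forward(self, x):
        h = self.body(x)
        return self.mean(h), self.log_std(h)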
def _train(self, batch):
    self._current_epoch += 1
    batch = to_torch(batch, torch.float, device=self.args["device"])
    rewards = batch.rew
    terminals = batch.done
    obs = batch.obs
    actions = batch.act
    next_obs = batch.obs_next

    """
    Policy and Alpha Loss
    """
    new_obs_actions, log_pi = self.forward(obs)

    if self.args["use_automatic_entropy_tuning"]:
        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
        self.alpha_opt.zero_grad()
        alpha_loss.backward()
        self.alpha_opt.step()
        alpha = self.log_alpha.exp()
    else:
        alpha_loss = 0
        alpha = 1

    if self._current_epoch < self.args["policy_bc_steps"]:
        """
        For the initial few epochs, do behavioural cloning if needed.
        Conventionally, there is not much difference in performance between
        having these 20k gradient steps and not having them.
        """
        policy_log_prob = self.actor.log_prob(obs, actions)
        policy_loss = (alpha * log_pi - policy_log_prob).mean()
    else:
        q_new_actions = torch.min(
            self.critic1(obs, new_obs_actions),
            self.critic2(obs, new_obs_actions),
        )
        policy_loss = (alpha * log_pi - q_new_actions).mean()

    self.actor_opt.zero_grad()
    policy_loss.backward()
    self.actor_opt.step()

    """
    QF Loss
    """
    q1_pred = self.critic1(obs, actions)
    q2_pred = self.critic2(obs, actions)

    new_next_actions, new_log_pi = self.forward(
        next_obs,
        reparameterize=True,
        return_log_prob=True,
    )
    new_curr_actions, new_curr_log_pi = self.forward(
        obs,
        reparameterize=True,
        return_log_prob=True,
    )

    if self.args["type_q_backup"] == "max":
        target_q_values = torch.max(
            self.critic1_target(next_obs, new_next_actions),
            self.critic2_target(next_obs, new_next_actions),
        )
        target_q_values = target_q_values - alpha * new_log_pi
    elif self.args["type_q_backup"] == "min":
        target_q_values = torch.min(
            self.critic1_target(next_obs, new_next_actions),
            self.critic2_target(next_obs, new_next_actions),
        )
        target_q_values = target_q_values - alpha * new_log_pi
    elif self.args["type_q_backup"] == "medium":
        target_q1_next = self.critic1_target(next_obs, new_next_actions)
        target_q2_next = self.critic2_target(next_obs, new_next_actions)
        target_q_values = self.args["q_backup_lmbda"] * torch.min(target_q1_next, target_q2_next) \
            + (1 - self.args["q_backup_lmbda"]) * torch.max(target_q1_next, target_q2_next)
        target_q_values = target_q_values - alpha * new_log_pi
    else:
        """max-Q backup: evaluate several sampled actions and take the best"""
        next_actions_temp, _ = self._get_policy_actions(
            next_obs, num_actions=10, network=self.forward)
        target_qf1_values = self._get_tensor_values(
            next_obs, next_actions_temp, network=self.critic1).max(1)[0].view(-1, 1)
        target_qf2_values = self._get_tensor_values(
            next_obs, next_actions_temp, network=self.critic2).max(1)[0].view(-1, 1)
        target_q_values = torch.min(target_qf1_values, target_qf2_values)

    q_target = self.args["reward_scale"] * rewards + (
        1. - terminals) * self.args["discount"] * target_q_values.detach()
    qf1_loss = self.critic_criterion(q1_pred, q_target)
    qf2_loss = self.critic_criterion(q2_pred, q_target)

    ## add CQL regulariser
    random_actions_tensor = torch.FloatTensor(
        q2_pred.shape[0] * self.args["num_random"],
        actions.shape[-1]).uniform_(-1, 1).to(self.args["device"])
    curr_actions_tensor, curr_log_pis = self._get_policy_actions(
        obs, num_actions=self.args["num_random"], network=self.forward)
    new_curr_actions_tensor, new_log_pis = self._get_policy_actions(
        next_obs, num_actions=self.args["num_random"], network=self.forward)
    q1_rand = self._get_tensor_values(obs, random_actions_tensor, network=self.critic1)
    q2_rand = self._get_tensor_values(obs, random_actions_tensor, network=self.critic2)
    q1_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.critic1)
    q2_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.critic2)
    q1_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.critic1)
    q2_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.critic2)

    cat_q1 = torch.cat(
        [q1_rand, q1_pred.unsqueeze(1), q1_next_actions, q1_curr_actions], 1)
    cat_q2 = torch.cat(
        [q2_rand, q2_pred.unsqueeze(1), q2_next_actions, q2_curr_actions], 1)

    if self.args["min_q_version"] == 3:
        # importance-sampled version
        random_density = np.log(0.5 ** curr_actions_tensor.shape[-1])
        cat_q1 = torch.cat([
            q1_rand - random_density,
            q1_next_actions - new_log_pis.detach(),
            q1_curr_actions - curr_log_pis.detach()
        ], 1)
        cat_q2 = torch.cat([
            q2_rand - random_density,
            q2_next_actions - new_log_pis.detach(),
            q2_curr_actions - curr_log_pis.detach()
        ], 1)

    min_qf1_loss = torch.logsumexp(
        cat_q1 / self.args["temp"],
        dim=1,
    ).mean() * self.args["min_q_weight"] * self.args["temp"]
    min_qf2_loss = torch.logsumexp(
        cat_q2 / self.args["temp"],
        dim=1,
    ).mean() * self.args["min_q_weight"] * self.args["temp"]

    """Subtract the log likelihood of the data"""
    min_qf1_loss = min_qf1_loss - q1_pred.mean() * self.args["min_q_weight"]
    min_qf2_loss = min_qf2_loss - q2_pred.mean() * self.args["min_q_weight"]

    if self.args["lagrange_thresh"] >= 0:
        alpha_prime = torch.clamp(self.log_alpha_prime.exp(), min=0.0, max=1000000.0)
        min_qf1_loss = alpha_prime * (min_qf1_loss - self.args["lagrange_thresh"])
        min_qf2_loss = alpha_prime * (min_qf2_loss - self.args["lagrange_thresh"])

        self.alpha_prime_opt.zero_grad()
        alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5
        alpha_prime_loss.backward(retain_graph=True)
        self.alpha_prime_opt.step()

    qf1_loss = self.args["explore"] * qf1_loss + (
        2 - self.args["explore"]) * min_qf1_loss
    qf2_loss = self.args["explore"] * qf2_loss + (
        2 - self.args["explore"]) * min_qf2_loss

    """
    Update critic networks
    """
    self.critic1_opt.zero_grad()
    qf1_loss.backward(retain_graph=True)
    self.critic1_opt.step()

    self.critic2_opt.zero_grad()
    qf2_loss.backward()
    self.critic2_opt.step()

    """
    Soft update of the target networks
    """
    self._sync_weight(self.critic1_target, self.critic1, self.args["soft_target_tau"])
    self._sync_weight(self.critic2_target, self.critic2, self.args["soft_target_tau"])

    self._n_train_steps_total += 1
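# Hedged sketch: `_get_policy_actions` and `_get_tensor_values` are used above
# but not defined in this section. Minimal implementations in the style of the
# reference CQL code, assuming the policy network returns (actions, log_prob)
# and Q predictions have shape (batch, 1); the exact signatures here are
# assumptions inferred from the call sites.
def _get_policy_actions(self, obs, num_actions, network=None):
    # Repeat each observation `num_actions` times and sample that many actions
    obs_temp = obs.unsqueeze(1).repeat(1, num_actions, 1).view(
        obs.shape[0] * num_actions, obs.shape[1])
    new_actions, new_log_pi = network(
        obs_temp, reparameterize=True, return_log_prob=True)
    return new_actions, new_log_pi.view(obs.shape[0], num_actions, 1)

def _get_tensor_values(self, obs, actions, network=None):
    # Evaluate Q(s, a) for every repeated action and reshape to (batch, num_repeat, 1)
    num_repeat = int(actions.shape[0] / obs.shape[0])
    obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view(
        obs.shape[0] * num_repeat, obs.shape[1])
    preds = network(obs_temp, actions)
    return preds.view(obs.shape[0], num_repeat, 1)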
def _train_policy(self, replay_buffer, eval_fn):
    for it in range(self.args["actor_iterations"]):
        batch = replay_buffer.sample(self.args["actor_batch_size"])
        batch = to_torch(batch, torch.float, device=self.args["device"])
        rew = batch.rew
        done = batch.done
        obs = batch.obs
        act = batch.act
        obs_next = batch.obs_next

        # Short imagined rollout through the learned dynamics / reward ensembles
        rew_list = []
        obs = [obs for _ in self.bcs]
        for i in range(5):
            act = [self.actor_target(o)[0] for o in obs]
            obs_act = [torch.cat([o, a], axis=1) for o, a in zip(obs, act)]
            obs_next = [net(oa)[0] for net, oa in zip(self.bcs, obs_act)]
            obs_act_obs = [
                torch.cat([oa, on], axis=1)
                for oa, on in zip(obs_act, obs_next)
            ]
            rew = [net(oao)[0] for net, oao in zip(self.rews, obs_act_obs)]
            if i == 0:
                r_p = [
                    torch.mean(torch.abs(self.vae(o, a)[0] - oa).detach(), axis=1)
                    for o, a, oa in zip(obs, act, obs_act)
                ]
                r_p = torch.mean(torch.cat(rew, axis=1), axis=1)
            rew = torch.cat(rew, axis=1)
            rew_list.append(rew)
            obs = obs_next

        # Discounted sum of the imagined ensemble rewards
        r_e = None
        for index in range(len(rew_list)):
            if r_e is None:
                r_e = rew_list[index]
            else:
                r_e += rew_list[index] * (0.99 ** index)

        # Pessimistic reward: mix of ensemble min and mean, penalised by r_p
        rew_min, _ = torch.min(r_e, axis=1)
        rew_mean = torch.mean(r_e, axis=1)
        rew = (0.5 * rew_min) + (0.5 * rew_mean) - (0.5 * r_p)

        done = batch.done
        obs = batch.obs
        act = batch.act
        obs_next = batch.obs_next

        # Critic training
        with torch.no_grad():
            action_next, _ = self.actor_target(obs_next)
            target_q1 = self.critic1_target(obs_next, action_next)
            target_q2 = self.critic2_target(obs_next, action_next)
            target_q = self.args["lmbda"] * torch.min(target_q1, target_q2) \
                + (1 - self.args["lmbda"]) * torch.max(target_q1, target_q2)
            target_q = rew + (1 - done) * self.args["discount"] * target_q

        current_q1 = self.critic1(obs, act)
        current_q2 = self.critic2(obs, act)
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

        self.critic1_opt.zero_grad()
        self.critic2_opt.zero_grad()
        critic_loss.backward()
        self.critic1_opt.step()
        self.critic2_opt.step()

        # Actor training
        action, _ = self.actor(obs)
        actor_loss = -self.critic1(obs, action).mean()

        self.actor.zero_grad()
        actor_loss.backward()
        self.actor_opt.step()

        # Update target networks
        self._sync_weight(self.actor_target, self.actor)
        self._sync_weight(self.critic1_target, self.critic1)
        self._sync_weight(self.critic2_target, self.critic2)

        if (it + 1) % 1000 == 0:
            if eval_fn is None:
                self.eval_policy()
            else:
                self.vae._actor = copy.deepcopy(self.actor)
                res = eval_fn(self.get_policy())
                self.log_res((it + 1) // 1000, res)
def get_action(self, obs):
    obs_tensor = to_torch(obs, device=next(self.parameters()).device, dtype=torch.float32)
    act = to_array_as(self.policy_infer(obs_tensor), obs)
    return act
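# Hedged usage sketch: `get_action` takes an observation as a numpy array and
# returns the inferred action converted back to the same array type. `policy`
# and the observation size below are illustrative placeholders, not names from
# the repo.
import numpy as np

obs = np.zeros(11, dtype=np.float32)  # e.g. an 11-dimensional observation
act = policy.get_action(obs)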