def train_model(self, observations_tensor, ext_returns_tensor,
                int_returns_tensor, actions_tensor, advantages_tensor,
                one_channel_observations_tensor, old_log_prob):
    if flag.DEBUG:
        print("input observations shape", observations_tensor.shape)
        print("ext returns shape", ext_returns_tensor.shape)
        print("int returns shape", int_returns_tensor.shape)
        print("input actions shape", actions_tensor.shape)
        print("input advantages shape", advantages_tensor.shape)
        print("one channel observations", one_channel_observations_tensor.shape)

    self.new_model.train()
    self.predictor_model.train()

    # RND: the predictor is trained to match the frozen, randomly initialized target
    target_value = self.target_model(one_channel_observations_tensor)
    predictor_value = self.predictor_model(one_channel_observations_tensor)
    predictor_loss = self.predictor_mse_loss(predictor_value,
                                             target_value.detach()).mean(-1)

    # Train the predictor on a random proportion of the batch
    mask = torch.rand(len(predictor_loss)).to(self.device)
    mask = (mask < self.predictor_update_proportion).type(
        torch.FloatTensor).to(self.device)
    predictor_loss = (predictor_loss * mask).sum() / torch.max(
        mask.sum(), torch.Tensor([1]).to(self.device))

    # Separate value heads for extrinsic and intrinsic returns
    new_policy, ext_new_values, int_new_values = self.new_model(
        observations_tensor)
    ext_value_loss = self.mse_loss(ext_new_values, ext_returns_tensor)
    int_value_loss = self.mse_loss(int_new_values, int_returns_tensor)
    value_loss = ext_value_loss + int_value_loss

    softmax_policy = F.softmax(new_policy, dim=1)
    new_dist = Categorical(softmax_policy)
    new_log_prob = new_dist.log_prob(actions_tensor)

    # PPO clipped surrogate objective
    ratio = torch.exp(new_log_prob - old_log_prob)
    clipped_policy_loss = torch.clamp(ratio, 1.0 - self.clip_range,
                                      1.0 + self.clip_range) * advantages_tensor
    policy_loss = ratio * advantages_tensor
    selected_policy_loss = -torch.min(clipped_policy_loss, policy_loss).mean()
    entropy = new_dist.entropy().mean()

    self.optimizer.zero_grad()
    loss = selected_policy_loss + (self.value_coef * value_loss) \
        - (self.entropy_coef * entropy) + predictor_loss
    loss.backward()
    global_grad_norm_(list(self.new_model.parameters()) +
                      list(self.predictor_model.parameters()))
    self.optimizer.step()

    return loss, selected_policy_loss, value_loss, predictor_loss, entropy
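# `global_grad_norm_` is called after every backward pass above but is not
# defined in this section. The sketch below is an assumption about what such a
# helper could look like: it computes the global L2 norm of the gradients over
# the combined parameter list. The actual helper may differ (for example it
# could also clip, as torch.nn.utils.clip_grad_norm_ does); the name
# `global_grad_norm_sketch` is hypothetical and only illustrates the call site.
import torch


def global_grad_norm_sketch(parameters, norm_type=2.0):
    """Sketch: global L2 norm of gradients over a list of parameters."""
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return 0.0
    total_norm = torch.norm(
        torch.stack([torch.norm(g.detach(), norm_type) for g in grads]),
        norm_type)
    return total_norm.item()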
def train_just_vae(self, s_batch, next_obs_batch):
    s_batch = torch.FloatTensor(s_batch).to(self.device)
    next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

    sample_range = np.arange(len(s_batch))
    reconstruction_loss = nn.MSELoss(reduction='none')

    recon_losses = np.array([])
    kld_losses = np.array([])

    for i in range(self.epoch):
        np.random.shuffle(sample_range)
        for j in range(int(len(s_batch) / self.batch_size)):
            sample_idx = sample_range[self.batch_size * j:self.batch_size *
                                      (j + 1)]

            # --------------------------------------------------------------------------------
            # for generative curiosity (VAE loss)
            gen_next_state, mu, logvar = self.vae(next_obs_batch[sample_idx])

            # Negative SSIM as the per-sample reconstruction loss;
            # d is only needed by the commented-out MSE alternative below
            d = len(gen_next_state.shape)
            recon_loss = -1 * pytorch_ssim.ssim(gen_next_state,
                                                next_obs_batch[sample_idx],
                                                size_average=False)
            # recon_loss = reconstruction_loss(
            #     gen_next_state,
            #     next_obs_batch[sample_idx]).mean(axis=list(range(1, d)))

            # KL divergence of the diagonal Gaussian posterior from N(0, I)
            kld_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(axis=1)

            # TODO: keep this proportion of experience used for VAE update?
            # Proportion of experience used for VAE update
            mask = torch.rand(len(recon_loss)).to(self.device)
            mask = (mask < self.update_proportion).type(
                torch.FloatTensor).to(self.device)
            recon_loss = (recon_loss * mask).sum() / torch.max(
                mask.sum(), torch.Tensor([1]).to(self.device))
            kld_loss = (kld_loss * mask).sum() / torch.max(
                mask.sum(), torch.Tensor([1]).to(self.device))

            recon_losses = np.append(recon_losses,
                                     recon_loss.detach().cpu().numpy())
            kld_losses = np.append(kld_losses,
                                   kld_loss.detach().cpu().numpy())
            # ---------------------------------------------------------------------------------

            self.optimizer.zero_grad()
            loss = recon_loss + kld_loss
            loss.backward()
            global_grad_norm_(list(self.vae.parameters()))
            self.optimizer.step()

    return recon_losses, kld_losses
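# `self.vae` is expected to return (reconstruction, mu, logvar); the closed-form
# KLD used above assumes a diagonal Gaussian posterior against a standard normal
# prior. The module below is a minimal, hypothetical sketch that only pins down
# that interface and the reparameterization trick; the real architecture and
# sizes (here 84x84 inputs, 32 latents) are assumptions, not the project's.
import torch
import torch.nn as nn


class TinyVAESketch(nn.Module):
    """Hypothetical VAE with the (reconstruction, mu, logvar) interface used above."""

    def __init__(self, in_dim=84 * 84, latent_dim=32):
        super().__init__()
        self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(in_dim, 256), nn.ReLU())
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)
        self.decoder = nn.Sequential(nn.Linear(latent_dim, in_dim), nn.Sigmoid())

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        recon = self.decoder(z).view_as(x)
        return recon, mu, logvar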
def train_model(self, s_batch, target_ext_batch, target_int_batch, y_batch,
                adv_batch, next_obs_batch, old_policy):
    s_batch = torch.FloatTensor(s_batch).to(self.device)
    target_ext_batch = torch.FloatTensor(target_ext_batch).to(self.device)
    target_int_batch = torch.FloatTensor(target_int_batch).to(self.device)
    y_batch = torch.LongTensor(y_batch).to(self.device)
    adv_batch = torch.FloatTensor(adv_batch).to(self.device)
    next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

    sample_range = np.arange(len(s_batch))
    forward_mse = nn.MSELoss(reduction='none')

    # Get the old policy's log-probabilities (held fixed during this update)
    with torch.no_grad():
        policy_old_list = torch.stack(old_policy).permute(
            1, 0, 2).contiguous().view(-1, self.output_size).to(self.device)
        m_old = Categorical(F.softmax(policy_old_list, dim=-1))
        log_prob_old = m_old.log_prob(y_batch)
    # ------------------------------------------------------------

    for i in range(self.epoch):
        # Train on shuffled minibatches for several epochs
        np.random.shuffle(sample_range)
        for j in range(int(len(s_batch) / self.batch_size)):
            sample_idx = sample_range[self.batch_size * j:self.batch_size *
                                      (j + 1)]

            # --------------------------------------------------------------------------------
            # Curiosity-driven exploration (Random Network Distillation)
            predict_next_state_feature, target_next_state_feature = self.rnd(
                next_obs_batch[sample_idx])
            forward_loss = forward_mse(
                predict_next_state_feature,
                target_next_state_feature.detach()).mean(-1)

            # Proportion of experience used for the predictor update
            mask = torch.rand(len(forward_loss)).to(self.device)
            mask = (mask < self.update_proportion).type(
                torch.FloatTensor).to(self.device)
            forward_loss = (forward_loss * mask).sum() / torch.max(
                mask.sum(), torch.Tensor([1]).to(self.device))
            # ---------------------------------------------------------------------------------

            policy, value_ext, value_int = self.model(s_batch[sample_idx])
            m = Categorical(F.softmax(policy, dim=-1))
            log_prob = m.log_prob(y_batch[sample_idx])

            ratio = torch.exp(log_prob - log_prob_old[sample_idx])
            surr1 = ratio * adv_batch[sample_idx]
            surr2 = torch.clamp(ratio, 1.0 - self.ppo_eps,
                                1.0 + self.ppo_eps) * adv_batch[sample_idx]

            # Actor loss: minimizing -J is equivalent to maximizing the
            # clipped surrogate objective J, hence the minus sign
            actor_loss = -torch.min(surr1, surr2).mean()

            # Critic loss: extrinsic and intrinsic value heads
            critic_ext_loss = F.mse_loss(value_ext.sum(1),
                                         target_ext_batch[sample_idx])
            critic_int_loss = F.mse_loss(value_int.sum(1),
                                         target_int_batch[sample_idx])
            critic_loss = critic_ext_loss + critic_int_loss

            # Entropy bonus improves exploration by discouraging premature
            # convergence to a suboptimal deterministic policy
            entropy = m.entropy().mean()

            # Reset the gradients
            self.optimizer.zero_grad()

            # Total loss = policy loss + value coefficient * value loss
            #              - entropy coefficient * entropy + RND forward loss
            loss = actor_loss + 0.5 * critic_loss - self.ent_coef * entropy \
                + forward_loss

            # Backpropagation
            loss.backward()
            global_grad_norm_(list(self.model.parameters()) +
                              list(self.rnd.predictor.parameters()))
            self.optimizer.step()
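# `self.rnd` bundles the frozen, randomly initialized target network and the
# trained predictor; only `self.rnd.predictor`'s parameters are included in the
# gradient-norm call above. The class below is a minimal interface sketch under
# the assumption that forward returns (predictor_feature, target_feature) and
# that the target is never trained; the layer sizes are hypothetical.
import torch.nn as nn


class RNDSketch(nn.Module):
    """Hypothetical RND wrapper: forward returns (predictor_feature, target_feature)."""

    def __init__(self, in_dim=84 * 84, feature_dim=512):
        super().__init__()
        self.predictor = nn.Sequential(nn.Flatten(), nn.Linear(in_dim, 512),
                                       nn.ReLU(), nn.Linear(512, feature_dim))
        self.target = nn.Sequential(nn.Flatten(), nn.Linear(in_dim, 512),
                                    nn.ReLU(), nn.Linear(512, feature_dim))
        # The target network stays at its random initialization
        for p in self.target.parameters():
            p.requires_grad = False

    def forward(self, next_obs):
        return self.predictor(next_obs), self.target(next_obs)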
def train_model(self, s_batch, target_ext_batch, target_int_batch, y_batch,
                adv_batch, next_obs_batch, old_policy):
    s_batch = torch.FloatTensor(s_batch).to(self.device)
    target_ext_batch = torch.FloatTensor(target_ext_batch).to(self.device)
    target_int_batch = torch.FloatTensor(target_int_batch).to(self.device)
    y_batch = torch.LongTensor(y_batch).to(self.device)
    adv_batch = torch.FloatTensor(adv_batch).to(self.device)
    next_obs_batch = torch.FloatTensor(next_obs_batch).to(self.device)

    sample_range = np.arange(len(s_batch))
    reconstruction_loss = nn.MSELoss(reduction='none')

    with torch.no_grad():
        policy_old_list = torch.stack(old_policy).permute(
            1, 0, 2).contiguous().view(-1, self.output_size).to(self.device)
        m_old = Categorical(F.softmax(policy_old_list, dim=-1))
        log_prob_old = m_old.log_prob(y_batch)
    # ------------------------------------------------------------

    recon_losses = np.array([])
    kld_losses = np.array([])

    for i in range(self.epoch):
        np.random.shuffle(sample_range)
        for j in range(int(len(s_batch) / self.batch_size)):
            sample_idx = sample_range[self.batch_size * j:self.batch_size *
                                      (j + 1)]

            # --------------------------------------------------------------------------------
            # for generative curiosity (VAE loss)
            gen_next_state, mu, logvar = self.vae(next_obs_batch[sample_idx])

            d = len(gen_next_state.shape)
            recon_loss = reconstruction_loss(
                gen_next_state,
                next_obs_batch[sample_idx]).mean(axis=list(range(1, d)))

            kld_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(axis=1)

            # TODO: keep this proportion of experience used for VAE update?
            # Proportion of experience used for VAE update
            mask = torch.rand(len(recon_loss)).to(self.device)
            mask = (mask < self.update_proportion).type(
                torch.FloatTensor).to(self.device)
            recon_loss = (recon_loss * mask).sum() / torch.max(
                mask.sum(), torch.Tensor([1]).to(self.device))
            kld_loss = (kld_loss * mask).sum() / torch.max(
                mask.sum(), torch.Tensor([1]).to(self.device))

            recon_losses = np.append(recon_losses,
                                     recon_loss.detach().cpu().numpy())
            kld_losses = np.append(kld_losses,
                                   kld_loss.detach().cpu().numpy())
            # ---------------------------------------------------------------------------------

            policy, value_ext, value_int = self.model(s_batch[sample_idx])
            m = Categorical(F.softmax(policy, dim=-1))
            log_prob = m.log_prob(y_batch[sample_idx])

            ratio = torch.exp(log_prob - log_prob_old[sample_idx])
            surr1 = ratio * adv_batch[sample_idx]
            surr2 = torch.clamp(ratio, 1.0 - self.ppo_eps,
                                1.0 + self.ppo_eps) * adv_batch[sample_idx]
            actor_loss = -torch.min(surr1, surr2).mean()

            critic_ext_loss = F.mse_loss(value_ext.sum(1),
                                         target_ext_batch[sample_idx])
            critic_int_loss = F.mse_loss(value_int.sum(1),
                                         target_int_batch[sample_idx])
            critic_loss = critic_ext_loss + critic_int_loss

            entropy = m.entropy().mean()

            self.optimizer.zero_grad()
            loss = actor_loss + 0.5 * critic_loss - self.ent_coef * entropy \
                + recon_loss + kld_loss
            loss.backward()
            global_grad_norm_(list(self.model.parameters()) +
                              list(self.vae.parameters()))
            self.optimizer.step()

    return recon_losses, kld_losses
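# Every variant above trains its predictor/VAE on only a random
# `update_proportion` of each minibatch: draw a uniform mask, zero out the
# rest, and divide by the number of kept samples (at least one). A small
# standalone check of that behaviour with dummy values, CPU only:
import torch

per_sample_loss = torch.tensor([1.0, 2.0, 3.0, 4.0])  # dummy per-sample losses
update_proportion = 0.5                                # keep roughly 50% of samples

mask = (torch.rand(len(per_sample_loss)) < update_proportion).float()
masked_loss = (per_sample_loss * mask).sum() / torch.max(mask.sum(),
                                                         torch.tensor(1.0))
print(mask, masked_loss)  # mean over the randomly kept entries (0 if none kept)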
def update(self, o, a, r_i, r_e, mask, o_, log_prob):
    # Update running statistics, then normalize intrinsic rewards and next observations
    self.normalizer_obs.update(o_.reshape(-1, 4, 84, 84).copy())
    self.normalizer_ri.update(r_i.reshape(-1).copy())
    r_i = self.normalizer_ri.normalize(r_i)
    o_ = self.normalizer_obs.normalize(o_)

    o = torch.from_numpy(o).to(self.device).float() / 255.

    returns_ex = np.zeros_like(r_e)
    returns_in = np.zeros_like(r_e)
    advantage_ex = np.zeros_like(r_e)
    advantage_in = np.zeros_like(r_e)

    # Compute returns and advantages per environment with GAE
    for i in range(r_e.shape[0]):
        action_logits, value_ex, value_in = self.actor_critic(o[i])
        value_ex, value_in = value_ex.cpu().detach().numpy(), value_in.cpu(
        ).detach().numpy()
        returns_ex[i], _, advantage_ex[i] = self.GAE_caculate(
            r_e[i], mask[i], value_ex, self.gamma_e, self.lamda)  # episodic
        returns_in[i], _, advantage_in[i] = self.GAE_caculate(
            r_i[i], np.ones_like(mask[i]), value_in, self.gamma_i,
            self.lamda)  # non-episodic

    # Flatten (num_envs, num_steps, ...) rollouts into a single batch;
    # the RND networks only see the latest frame of each next observation
    o = o.reshape((-1, 4, 84, 84))
    a = np.reshape(a, -1)
    o_ = np.reshape(o_[:, :, 3, :, :], (-1, 1, 84, 84))
    log_prob = np.reshape(log_prob, -1)
    returns_ex = np.reshape(returns_ex, -1)
    returns_in = np.reshape(returns_in, -1)
    advantage_ex = np.reshape(advantage_ex, -1)
    advantage_in = np.reshape(advantage_in, -1)

    a = torch.from_numpy(a).float().to(self.device)
    o_ = torch.from_numpy(o_).float().to(self.device)
    log_prob = torch.from_numpy(log_prob).float().to(self.device)
    returns_ex = torch.from_numpy(returns_ex).float().to(
        self.device).unsqueeze(dim=1)
    returns_in = torch.from_numpy(returns_in).float().to(
        self.device).unsqueeze(dim=1)
    advantage_ex = torch.from_numpy(advantage_ex).float().to(self.device)
    advantage_in = torch.from_numpy(advantage_in).float().to(self.device)

    sample_range = list(range(len(o)))
    for i_update in range(self.update_epoch):
        np.random.shuffle(sample_range)
        for j in range(int(len(o) / self.batch_size)):
            idx = sample_range[self.batch_size * j:self.batch_size * (j + 1)]

            # Update RND predictor on a random proportion of the minibatch
            pred_RND, tar_RND = self.RND(o_[idx])
            loss_RND = F.mse_loss(pred_RND, tar_RND.detach(),
                                  reduction='none').mean(-1)
            update_mask = torch.rand(len(loss_RND)).to(self.device)
            update_mask = (update_mask < self.update_proportion).type(
                torch.FloatTensor).to(self.device)
            loss_RND = (loss_RND * update_mask).sum() / torch.max(
                update_mask.sum(), torch.Tensor([1]).to(self.device))

            # Update actor-critic (PPO with combined extrinsic/intrinsic advantage)
            action_logits, value_ex, value_in = self.actor_critic(o[idx])
            advantage = self.ex_coef * advantage_ex[
                idx] + self.in_coef * advantage_in[idx]

            # Categorical's first positional argument is interpreted as
            # probabilities, so actor_critic is assumed to output a softmax
            dist = Categorical(action_logits)
            new_log_prob = dist.log_prob(a[idx])
            ratio = torch.exp(new_log_prob - log_prob[idx])

            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - self.clip_eps,
                                1 + self.clip_eps) * advantage
            # Minimize the negative clipped surrogate (i.e. maximize it)
            # and add an entropy bonus
            loss_actor = -torch.min(surr1, surr2).mean() \
                - self.entropy_coef * dist.entropy().mean()
            loss_critic = F.mse_loss(value_ex, returns_ex[idx]) + F.mse_loss(
                value_in, returns_in[idx])
            loss_ac = loss_actor + 0.5 * loss_critic

            loss = loss_RND + loss_ac
            self.optimizer.zero_grad()
            loss.backward()
            global_grad_norm_(list(self.actor_critic.parameters()) +
                              list(self.RND.predictor.parameters()))
            self.optimizer.step()

    return loss_RND.cpu().detach().numpy(), loss_actor.cpu().detach(
    ).numpy(), loss_critic.cpu().detach().numpy()
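# `self.GAE_caculate` is not shown. Given how its three return values are
# unpacked above (returns, an unused middle value, advantages), the function
# below sketches a standard GAE(lambda) computation with that output shape.
# This is an assumption about the helper, not its actual implementation; here
# the middle return value is taken to be the per-step TD errors, and the last
# value is bootstrapped with zero.
import numpy as np


def gae_sketch(rewards, masks, values, gamma, lam):
    """Hypothetical stand-in for GAE_caculate: returns (returns, deltas, advantages)."""
    T = len(rewards)
    values = np.asarray(values, dtype=np.float32).reshape(-1)
    deltas = np.zeros(T, dtype=np.float32)
    advantages = np.zeros(T, dtype=np.float32)
    gae = 0.0
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else 0.0
        # TD error; masks[t] == 0 cuts the bootstrap at episode boundaries
        deltas[t] = rewards[t] + gamma * masks[t] * next_value - values[t]
        gae = deltas[t] + gamma * lam * masks[t] * gae
        advantages[t] = gae
    returns = advantages + values[:T]
    return returns, deltas, advantages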