def expand(self, batch_shape):
    return MultivariateNormal.expand(self, batch_shape, _instance=self)
def train(self, iteration_num=100, log_dir='log', model_save_period=10, render=False, debug=False):
    """
    :param iteration_num: number of training iterations to run
    :param log_dir: directory for model checkpoints and TensorBoard logs
    :param model_save_period: save a numbered checkpoint every this many iterations
    :param render: render the environment while sampling trajectories
    :param debug: drop into an interactive shell (embed) when a KL estimate becomes NaN
    """
    self.render = render

    model_save_dir = os.path.join(log_dir, 'model')
    if not os.path.exists(model_save_dir):
        os.makedirs(model_save_dir)
    writer = SummaryWriter(os.path.join(log_dir, 'tb'))

    for it in range(self.iteration, iteration_num):
        self.__sample_trajectory(self.sample_episode_num)
        buff_sz = len(self.replaybuffer)

        mean_reward = self.replaybuffer.mean_reward()
        mean_return = self.replaybuffer.mean_return()
        mean_loss_q = []
        mean_loss_p = []
        mean_loss_l = []
        mean_est_q = []
        max_kl_μ = []
        max_kl_Σ = []
        max_kl = []

        # Find a better policy by gradient descent
        for r in range(self.episode_rerun_num):
            for indices in tqdm(
                    BatchSampler(SubsetRandomSampler(range(buff_sz)), self.batch_size, False),
                    desc='training {}/{}'.format(r + 1, self.episode_rerun_num)):
                B = len(indices)
                M = self.sample_action_num
                ds = self.ds
                da = self.da

                state_batch, action_batch, next_state_batch, reward_batch = zip(
                    *[self.replaybuffer[index] for index in indices])

                state_batch = torch.from_numpy(np.stack(state_batch)).type(torch.float32).to(self.device)  # (B, ds)
                action_batch = torch.from_numpy(np.stack(action_batch)).type(torch.float32).to(self.device)  # (B, da) or (B,)
                next_state_batch = torch.from_numpy(np.stack(next_state_batch)).type(torch.float32).to(self.device)  # (B, ds)
                reward_batch = torch.from_numpy(np.stack(reward_batch)).type(torch.float32).to(self.device)  # (B,)

                # Policy Evaluation
                loss_q, q = self.__update_critic_td(
                    state_batch=state_batch,
                    action_batch=action_batch,
                    next_state_batch=next_state_batch,
                    reward_batch=reward_batch,
                    sample_num=self.sample_action_num
                )
                if loss_q is None:
                    raise RuntimeError('invalid policy evaluation')
                mean_loss_q.append(loss_q.item())
                mean_est_q.append(q.abs().mean().item())

                # sample M additional actions for each state
                with torch.no_grad():
                    if self.continuous_action_space:
                        b_μ, b_A = self.target_actor.forward(state_batch)  # (B,)
                        b = MultivariateNormal(b_μ, scale_tril=b_A)  # (B,)
                        sampled_actions = b.sample((M,))  # (M, B, da)
                        expanded_states = state_batch[None, ...].expand(M, -1, -1)  # (M, B, ds)
                        target_q = self.target_critic.forward(
                            expanded_states.reshape(-1, ds),  # (M * B, ds)
                            sampled_actions.reshape(-1, da)  # (M * B, da)
                        ).reshape(M, B)  # (M, B)
                        target_q_np = target_q.cpu().numpy()  # (M, B)
                    else:
                        actions = torch.arange(da)[..., None].expand(da, B).to(self.device)  # (da, B)
                        b_p = self.target_actor.forward(state_batch)  # (B, da)
                        b = Categorical(probs=b_p)  # (B,)
                        b_prob = b.expand((da, B)).log_prob(actions).exp()  # (da, B)
                        expanded_actions = self.A_eye[None, ...].expand(B, -1, -1)  # (B, da, da)
                        expanded_states = state_batch.reshape(B, 1, ds).expand((B, da, ds))  # (B, da, ds)
                        target_q = (
                            self.target_critic.forward(
                                expanded_states.reshape(-1, ds),  # (B * da, ds)
                                expanded_actions.reshape(-1, da)  # (B * da, da)
                            ).reshape(B, da)  # (B, da)
                        ).transpose(0, 1)  # (da, B)
                        b_prob_np = b_prob.cpu().numpy()  # (da, B)
                        target_q_np = target_q.cpu().numpy()  # (da, B)

                # E-step
                if self.continuous_action_space:
                    def dual(η):
                        """
                        Dual function of the non-parametric variational distribution:
                        g(η) = η * ε + η * mean_s log( mean_a exp(Q(s, a) / η) )
                        """
                        max_q = np.max(target_q_np, 0)
                        return η * self.ε_dual + np.mean(max_q) \
                            + η * np.mean(np.log(np.mean(np.exp((target_q_np - max_q) / η), axis=0)))
                else:
                    def dual(η):
                        """
                        Dual function of the non-parametric variational distribution:
                        g(η) = η * ε + η * mean_s log( sum_a b(a|s) exp(Q(s, a) / η) )
                        """
                        max_q = np.max(target_q_np, 0)
                        return η * self.ε_dual + np.mean(max_q) \
                            + η * np.mean(np.log(np.sum(b_prob_np * np.exp((target_q_np - max_q) / η), axis=0)))

                # (a self-contained sketch of this η search appears after this method)
                bounds = [(1e-6, None)]
                res = minimize(dual, np.array([self.η]), method='SLSQP', bounds=bounds)
                self.η = res.x[0]

                qij = torch.softmax(target_q / self.η, dim=0)  # (M, B) or (da, B)

                # M-step
                # update the parametric policy by maximizing the KL-regularized Lagrangian
                for _ in range(self.lagrange_iteration_num):
                    if self.continuous_action_space:
                        μ, A = self.actor.forward(state_batch)
                        π1 = MultivariateNormal(loc=μ, scale_tril=b_A)  # (B,)
                        π2 = MultivariateNormal(loc=b_μ, scale_tril=A)  # (B,)
                        loss_p = torch.mean(
                            qij * (
                                π1.expand((M, B)).log_prob(sampled_actions)  # (M, B)
                                + π2.expand((M, B)).log_prob(sampled_actions)  # (M, B)
                            )
                        )
                        mean_loss_p.append((-loss_p).item())

                        kl_μ, kl_Σ = gaussian_kl(μi=b_μ, μ=μ, Ai=b_A, A=A)
                        max_kl_μ.append(kl_μ.item())
                        max_kl_Σ.append(kl_Σ.item())

                        if debug and np.isnan(kl_μ.item()):
                            print('kl_μ is nan')
                            embed()
                        if debug and np.isnan(kl_Σ.item()):
                            print('kl_Σ is nan')
                            embed()

                        # Update Lagrange multipliers by gradient descent
                        self.η_kl_μ -= self.α * (self.ε_kl_μ - kl_μ).detach().item()
                        self.η_kl_Σ -= self.α * (self.ε_kl_Σ - kl_Σ).detach().item()
                        if self.η_kl_μ < 0.0:
                            self.η_kl_μ = 0.0
                        if self.η_kl_Σ < 0.0:
                            self.η_kl_Σ = 0.0

                        self.actor_optimizer.zero_grad()
                        loss_l = -(
                            loss_p
                            + self.η_kl_μ * (self.ε_kl_μ - kl_μ)
                            + self.η_kl_Σ * (self.ε_kl_Σ - kl_Σ)
                        )
                        mean_loss_l.append(loss_l.item())
                        loss_l.backward()
                        clip_grad_norm_(self.actor.parameters(), 0.1)
                        self.actor_optimizer.step()
                    else:
                        π_p = self.actor.forward(state_batch)  # (B, da)
                        π = Categorical(probs=π_p)  # (B,)
                        loss_p = torch.mean(
                            qij * π.expand((da, B)).log_prob(actions)
                        )
                        mean_loss_p.append((-loss_p).item())

                        kl = categorical_kl(p1=π_p, p2=b_p)
                        max_kl.append(kl.item())

                        if debug and np.isnan(kl.item()):
                            print('kl is nan')
                            embed()

                        # Update Lagrange multiplier by gradient descent
                        self.η_kl -= self.α * (self.ε_kl - kl).detach().item()
                        if self.η_kl < 0.0:
                            self.η_kl = 0.0

                        self.actor_optimizer.zero_grad()
                        loss_l = -(loss_p + self.η_kl * (self.ε_kl - kl))
                        mean_loss_l.append(loss_l.item())
                        loss_l.backward()
                        clip_grad_norm_(self.actor.parameters(), 0.1)
                        self.actor_optimizer.step()

        self.__update_param()

        self.η_kl_μ = 0.0
        self.η_kl_Σ = 0.0
        self.η_kl = 0.0

        it = it + 1
        mean_loss_q = np.mean(mean_loss_q)
        mean_loss_p = np.mean(mean_loss_p)
        mean_loss_l = np.mean(mean_loss_l)
        mean_est_q = np.mean(mean_est_q)
        if self.continuous_action_space:
            max_kl_μ = np.max(max_kl_μ)
            max_kl_Σ = np.max(max_kl_Σ)
        else:
            max_kl = np.max(max_kl)

        print('iteration :', it)
        print('  mean return :', mean_return)
        print('  mean reward :', mean_reward)
        print('  mean loss_q :', mean_loss_q)
        print('  mean loss_p :', mean_loss_p)
        print('  mean loss_l :', mean_loss_l)
        print('  mean est_q :', mean_est_q)
        print('  η :', self.η)
        if self.continuous_action_space:
            print('  max_kl_μ :', max_kl_μ)
            print('  max_kl_Σ :', max_kl_Σ)
        else:
            print('  max_kl :', max_kl)

        # saving and logging
        self.save_model(os.path.join(model_save_dir, 'model_latest.pt'))
        if it % model_save_period == 0:
            self.save_model(os.path.join(model_save_dir, 'model_{}.pt'.format(it)))

        writer.add_scalar('return', mean_return, it)
        writer.add_scalar('reward', mean_reward, it)
        writer.add_scalar('loss_q', mean_loss_q, it)
        writer.add_scalar('loss_p', mean_loss_p, it)
        writer.add_scalar('loss_l', mean_loss_l, it)
        writer.add_scalar('mean_q', mean_est_q, it)
        writer.add_scalar('η', self.η, it)
        if self.continuous_action_space:
            writer.add_scalar('max_kl_μ', max_kl_μ, it)
            writer.add_scalar('max_kl_Σ', max_kl_Σ, it)
        else:
            writer.add_scalar('max_kl', max_kl, it)
        writer.flush()

    # end training
    if writer is not None:
        writer.close()
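

# Illustrative sketch only (not part of the training code above): the E-step in
# train() picks the temperature η by minimizing the dual of the non-parametric
# variational problem, g(η) = η·ε + η·mean_s log( mean_a exp(Q(s, a) / η) ).
# The helper below reproduces that search in isolation on a toy Q matrix.
# The name dual_sketch, the default ε value, and the array sizes are assumptions
# made for this example; train() uses self.ε_dual and the target critic's estimates.
import numpy as np
from scipy.optimize import minimize


def dual_sketch(target_q_np, ε_dual=0.1, η_init=1.0):
    # target_q_np: (M, B) Q-value estimates for M sampled actions in each of B states.
    def dual(η):
        η = np.asarray(η).item()  # SLSQP passes η as a length-1 array
        max_q = np.max(target_q_np, axis=0)  # per-state max, for numerical stability
        return η * ε_dual + np.mean(max_q) + η * np.mean(
            np.log(np.mean(np.exp((target_q_np - max_q) / η), axis=0)))

    res = minimize(dual, np.array([η_init]), method='SLSQP', bounds=[(1e-6, None)])
    return res.x[0]  # optimized temperature; train() then weights actions by softmax(Q / η)


# Example: η for a random (M=64, B=32) Q matrix.
# print(dual_sketch(np.random.randn(64, 32)))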