Code example #1
File: torch_priors.py    Project: ibrahmd/gpytorch
 def expand(self, batch_shape):
     return MultivariateNormal.expand(self, batch_shape, _instance=self)
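
A minimal usage sketch for the snippet above (not part of the original listing). It assumes gpytorch's MultivariateNormalPrior, whose constructor takes a mean and a covariance matrix; because expand passes _instance=self, the prior is expanded in place and the same object is returned.

import torch
from gpytorch.priors import MultivariateNormalPrior

# a 2-dimensional prior with no batch dimensions
prior = MultivariateNormalPrior(torch.zeros(2), covariance_matrix=torch.eye(2))

# broadcast the prior over a batch of size 3
batched = prior.expand(torch.Size([3]))
print(batched.batch_shape, batched.event_shape)  # torch.Size([3]) torch.Size([2])
print(batched.sample().shape)                    # torch.Size([3, 2])
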
Code example #2
File: mpo.py    Project: yj-Tang/mpo
    def train(self,
              iteration_num=100,
              log_dir='log',
              model_save_period=10,
              render=False,
              debug=False):
        """
        :param iteration_num:
        :param log_dir:
        :param model_save_period:
        :param render:
        :param debug:
        """

        self.render = render

        model_save_dir = os.path.join(log_dir, 'model')
        if not os.path.exists(model_save_dir):
            os.makedirs(model_save_dir)
        writer = SummaryWriter(os.path.join(log_dir, 'tb'))

        for it in range(self.iteration, iteration_num):
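            # collect self.sample_episode_num fresh episodes into the replay buffer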
            self.__sample_trajectory(self.sample_episode_num)
            buff_sz = len(self.replaybuffer)

            mean_reward = self.replaybuffer.mean_reward()
            mean_return = self.replaybuffer.mean_return()
            mean_loss_q = []
            mean_loss_p = []
            mean_loss_l = []
            mean_est_q = []
            max_kl_μ = []
            max_kl_Σ = []
            max_kl = []

            # Find better policy by gradient descent
            for r in range(self.episode_rerun_num):
                for indices in tqdm(
                        BatchSampler(
                            SubsetRandomSampler(range(buff_sz)), self.batch_size, False),
                        desc='training {}/{}'.format(r+1, self.episode_rerun_num)):
                    B = len(indices)
                    M = self.sample_action_num
                    ds = self.ds
                    da = self.da

                    state_batch, action_batch, next_state_batch, reward_batch = zip(
                        *[self.replaybuffer[index] for index in indices])

                    state_batch = torch.from_numpy(np.stack(state_batch)).type(torch.float32).to(self.device)  # (B, ds)
                    action_batch = torch.from_numpy(np.stack(action_batch)).type(torch.float32).to(self.device)  # (B, da) or (B,)
                    next_state_batch = torch.from_numpy(np.stack(next_state_batch)).type(torch.float32).to(self.device)  # (B, ds)
                    reward_batch = torch.from_numpy(np.stack(reward_batch)).type(torch.float32).to(self.device)  # (B,)

                    # Policy Evaluation
                    loss_q, q = self.__update_critic_td(
                        state_batch=state_batch,
                        action_batch=action_batch,
                        next_state_batch=next_state_batch,
                        reward_batch=reward_batch,
                        sample_num=self.sample_action_num
                    )
                    if loss_q is None:
                        raise RuntimeError('invalid policy evaluation')
                    mean_loss_q.append(loss_q.item())
                    mean_est_q.append(q.abs().mean().item())

                    # sample M additional action for each state
                    with torch.no_grad():
                        if self.continuous_action_space:
                            b_μ, b_A = self.target_actor.forward(state_batch)  # (B, da), (B, da, da)
                            b = MultivariateNormal(b_μ, scale_tril=b_A)  # (B,)
                            sampled_actions = b.sample((M,))  # (M, B, da)
                            expanded_states = state_batch[None, ...].expand(M, -1, -1)  # (M, B, ds)
                            target_q = self.target_critic.forward(
                                expanded_states.reshape(-1, ds),  # (M * B, ds)
                                sampled_actions.reshape(-1, da)  # (M * B, da)
                            ).reshape(M, B)  # (M, B)
                            target_q_np = target_q.cpu().numpy()  # (M, B)
                        else:
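                        # Discrete actions: instead of sampling M actions per state,
                        # enumerate all da actions (rows of self.A_eye, assumed to be an
                        # identity over actions) and weight them by the target policy
                        # probabilities b(a|s).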
                            actions = torch.arange(da)[..., None].expand(da, B).to(self.device)  # (da, B)
                            b_p = self.target_actor.forward(state_batch)  # (B, da)
                            b = Categorical(probs=b_p)  # (B,)
                            b_prob = b.expand((da, B)).log_prob(actions).exp()  # (da, B)
                            expanded_actions = self.A_eye[None, ...].expand(B, -1, -1)  # (B, da, da)
                            expanded_states = state_batch.reshape(B, 1, ds).expand((B, da, ds))  # (B, da, ds)
                            target_q = (
                                self.target_critic.forward(
                                    expanded_states.reshape(-1, ds),  # (B * da, ds)
                                    expanded_actions.reshape(-1, da)  # (B * da, da)
                                ).reshape(B, da)  # (B, da)
                            ).transpose(0, 1)  # (da, B)
                            b_prob_np = b_prob.cpu().numpy()  # (da, B)
                            target_q_np = target_q.cpu().numpy()  # (da, B)

                    # E-step
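                    # Solve the E-step dual: find the temperature η defining the
                    # non-parametric variational distribution q(a|s) ∝ b(a|s)·exp(Q(s, a)/η)
                    # under the constraint KL(q || b) ≤ ε_dual.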
                    if self.continuous_action_space:
                        def dual(η):
                            """
                            dual function of the non-parametric variational
                            g(η) = η*ε + η \sum \log (\sum \exp(Q(a, s)/η))
                            """
                            max_q = np.max(target_q_np, 0)
                            return η * self.ε_dual + np.mean(max_q) \
                                + η * np.mean(np.log(np.mean(np.exp((target_q_np - max_q) / η), axis=0)))
                    else:
                        def dual(η):
                            """
                            dual function of the non-parametric variational
                            g(η) = η*ε + η \sum \log (\sum \exp(Q(a, s)/η))
                            """
                            max_q = np.max(target_q_np, 0)
                            return η * self.ε_dual + np.mean(max_q) \
                                + η * np.mean(np.log(np.sum(b_prob_np * np.exp((target_q_np - max_q) / η), axis=0)))

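                    # minimize the dual over the temperature η (constrained to η > 0)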
                    bounds = [(1e-6, None)]
                    res = minimize(dual, np.array([self.η]), method='SLSQP', bounds=bounds)
                    self.η = res.x[0]

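                    # weights of the non-parametric variational distribution,
                    # q_ij ∝ exp(Q(s_i, a_j)/η), normalized over the action dimension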
                    qij = torch.softmax(target_q / self.η, dim=0)  # (M, B) or (da, B)

                    # M-step
                    # update policy based on lagrangian
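                    # maximize E_q[log π(a|s)] subject to a KL constraint between the
                    # target policy and the new policy, alternating gradient steps on the
                    # actor and on the Lagrange multipliers η_kl_*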
                    for _ in range(self.lagrange_iteration_num):
                        if self.continuous_action_space:
                            μ, A = self.actor.forward(state_batch)
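                            # Decoupled update: π1 varies the mean with the target policy's
                            # Cholesky factor fixed, π2 varies the Cholesky factor with the
                            # target policy's mean fixed, so the mean and covariance KL
                            # constraints below can be enforced separately.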
                            π1 = MultivariateNormal(loc=μ, scale_tril=b_A)  # (B,)
                            π2 = MultivariateNormal(loc=b_μ, scale_tril=A)  # (B,)
                            loss_p = torch.mean(
                                qij * (
                                    π1.expand((M, B)).log_prob(sampled_actions)  # (M, B)
                                    + π2.expand((M, B)).log_prob(sampled_actions)  # (M, B)
                                )
                            )
                            mean_loss_p.append((-loss_p).item())

                            kl_μ, kl_Σ = gaussian_kl(
                                μi=b_μ, μ=μ,
                                Ai=b_A, A=A)
                            max_kl_μ.append(kl_μ.item())
                            max_kl_Σ.append(kl_Σ.item())

                            if debug and np.isnan(kl_μ.item()):
                                print('kl_μ is nan')
                                embed()
                            if debug and np.isnan(kl_Σ.item()):
                                print('kl_Σ is nan')
                                embed()

                            # Update lagrange multipliers by gradient descent
                            self.η_kl_μ -= self.α * (self.ε_kl_μ - kl_μ).detach().item()
                            self.η_kl_Σ -= self.α * (self.ε_kl_Σ - kl_Σ).detach().item()

                            if self.η_kl_μ < 0.0:
                                self.η_kl_μ = 0.0
                            if self.η_kl_Σ < 0.0:
                                self.η_kl_Σ = 0.0

                            self.actor_optimizer.zero_grad()
                            loss_l = -(
                                    loss_p
                                    + self.η_kl_μ * (self.ε_kl_μ - kl_μ)
                                    + self.η_kl_Σ * (self.ε_kl_Σ - kl_Σ)
                            )
                            mean_loss_l.append(loss_l.item())
                            loss_l.backward()
                            clip_grad_norm_(self.actor.parameters(), 0.1)
                            self.actor_optimizer.step()
                        else:
                            π_p = self.actor.forward(state_batch)  # (B, da)
                            π = Categorical(probs=π_p)  # (B,)
                            loss_p = torch.mean(
                                qij * π.expand((da, B)).log_prob(actions)
                            )
                            mean_loss_p.append((-loss_p).item())

                            kl = categorical_kl(p1=π_p, p2=b_p)
                            max_kl.append(kl.item())

                            if debug and np.isnan(kl.item()):
                                print('kl is nan')
                                embed()

                            # Update lagrange multipliers by gradient descent
                            self.η_kl -= self.α * (self.ε_kl - kl).detach().item()

                            if self.η_kl < 0.0:
                                self.η_kl = 0.0

                            self.actor_optimizer.zero_grad()
                            loss_l = -(loss_p + self.η_kl * (self.ε_kl - kl))
                            mean_loss_l.append(loss_l.item())
                            loss_l.backward()
                            clip_grad_norm_(self.actor.parameters(), 0.1)
                            self.actor_optimizer.step()

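            # sync the target actor/critic with the newly trained networks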
            self.__update_param()

            self.η_kl_μ = 0.0
            self.η_kl_Σ = 0.0
            self.η_kl = 0.0

            it = it + 1
            mean_loss_q = np.mean(mean_loss_q)
            mean_loss_p = np.mean(mean_loss_p)
            mean_loss_l = np.mean(mean_loss_l)
            mean_est_q = np.mean(mean_est_q)
            if self.continuous_action_space:
                max_kl_μ = np.max(max_kl_μ)
                max_kl_Σ = np.max(max_kl_Σ)
            else:
                max_kl = np.max(max_kl)

            print('iteration :', it)
            print('  mean return :', mean_return)
            print('  mean reward :', mean_reward)
            print('  mean loss_q :', mean_loss_q)
            print('  mean loss_p :', mean_loss_p)
            print('  mean loss_l :', mean_loss_l)
            print('  mean est_q :', mean_est_q)
            print('  η :', self.η)
            if self.continuous_action_space:
                print('  max_kl_μ :', max_kl_μ)
                print('  max_kl_Σ :', max_kl_Σ)
            else:
                print('  max_kl :', max_kl)

            # saving and logging
            self.save_model(os.path.join(model_save_dir, 'model_latest.pt'))
            if it % model_save_period == 0:
                self.save_model(os.path.join(model_save_dir, 'model_{}.pt'.format(it)))
            writer.add_scalar('return', mean_return, it)
            writer.add_scalar('reward', mean_reward, it)
            writer.add_scalar('loss_q', mean_loss_q, it)
            writer.add_scalar('loss_p', mean_loss_p, it)
            writer.add_scalar('loss_l', mean_loss_l, it)
            writer.add_scalar('mean_q', mean_est_q, it)
            writer.add_scalar('η', self.η, it)
            if self.continuous_action_space:
                writer.add_scalar('max_kl_μ', max_kl_μ, it)
                writer.add_scalar('max_kl_Σ', max_kl_Σ, it)
            else:
                writer.add_scalar('max_kl', max_kl, it)
            writer.flush()

        # end training
        if writer is not None:
            writer.close()
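
A hypothetical driver for the train() method above (not from the original listing). Only the keyword arguments of train() come from the signature shown; the class name MPO, its constructor arguments, and the environment id are assumptions and may differ in yj-Tang/mpo.

import gym
import torch

from mpo import MPO  # assumed: the agent class defined in mpo.py

# assumed constructor: a torch device plus a Gym environment
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env = gym.make('Pendulum-v1')
agent = MPO(device, env)

agent.train(iteration_num=200,
            log_dir='log_pendulum',
            model_save_period=10,
            render=False,
            debug=False)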