Example #1
class TAC(SarlOffPolicy):
    """Tsallis Actor Critic, TAC with V neural Network. https://arxiv.org/abs/1902.00137
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 alpha=0.2,
                 annealing=True,
                 last_alpha=0.01,
                 polyak=0.995,
                 entropic_index=1.5,
                 discrete_tau=1.0,
                 network_settings={
                     'actor_continuous': {
                         'share': [128, 128],
                         'mu': [64],
                         'log_std': [64],
                         'soft_clip': False,
                         'log_std_bound': [-20, 2]
                     },
                     'actor_discrete': [64, 32],
                     'q': [128, 128]
                 },
                 auto_adaption=True,
                 actor_lr=5.0e-4,
                 critic_lr=1.0e-3,
                 alpha_lr=5.0e-4,
                 **kwargs):
        super().__init__(**kwargs)
        self.polyak = polyak
        self.discrete_tau = discrete_tau
        self.entropic_index = 2 - entropic_index
        self.auto_adaption = auto_adaption
        self.annealing = annealing

        self.critic = TargetTwin(CriticQvalueOne(self.obs_spec,
                                                 rep_net_params=self._rep_net_params,
                                                 action_dim=self.a_dim,
                                                 network_settings=network_settings['q']),
                                 self.polyak).to(self.device)
        self.critic2 = deepcopy(self.critic)

        if self.is_continuous:
            self.actor = ActorCts(self.obs_spec,
                                  rep_net_params=self._rep_net_params,
                                  output_shape=self.a_dim,
                                  network_settings=network_settings['actor_continuous']).to(self.device)
        else:
            self.actor = ActorDct(self.obs_spec,
                                  rep_net_params=self._rep_net_params,
                                  output_shape=self.a_dim,
                                  network_settings=network_settings['actor_discrete']).to(self.device)

        # target entropy, scaled by 0.98: -dim(A) for continuous actions,
        # log|A| = -log(1/|A|) (the maximum categorical entropy) for discrete ones
        self.target_entropy = 0.98 * (-self.a_dim if self.is_continuous else np.log(self.a_dim))

        self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
        self.critic_oplr = OPLR([self.critic, self.critic2], critic_lr, **self._oplr_params)

        if self.auto_adaption:
            # create log_alpha directly on the target device so it stays a leaf tensor the optimizer can update
            self.log_alpha = th.tensor(0., requires_grad=True, device=self.device)
            self.alpha_oplr = OPLR(self.log_alpha, alpha_lr, **self._oplr_params)
            self._trainer_modules.update(alpha_oplr=self.alpha_oplr)
        else:
            self.log_alpha = th.tensor(alpha).log().to(self.device)
            if self.annealing:
                self.alpha_annealing = LinearAnnealing(alpha, last_alpha, int(1e6))

        self._trainer_modules.update(actor=self.actor,
                                     critic=self.critic,
                                     critic2=self.critic2,
                                     log_alpha=self.log_alpha,
                                     actor_oplr=self.actor_oplr,
                                     critic_oplr=self.critic_oplr)

    @property
    def alpha(self):
        return self.log_alpha.exp()

    @iton
    def select_action(self, obs):
        if self.is_continuous:
            mu, log_std = self.actor(obs, rnncs=self.rnncs)  # [B, A]
            pi = td.Normal(mu, log_std.exp()).sample().tanh()  # [B, A]
            mu.tanh_()  # squash mu     # [B, A]
        else:
            logits = self.actor(obs, rnncs=self.rnncs)  # [B, A]
            mu = logits.argmax(-1)  # [B,]
            cate_dist = td.Categorical(logits=logits)
            pi = cate_dist.sample()  # [B,]
        self.rnncs_ = self.actor.get_rnncs()
        actions = pi if self._is_train_mode else mu
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        if self.is_continuous:
            target_mu, target_log_std = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(target_mu, target_log_std.exp()), 1)
            target_pi = dist.sample()  # [T, B, A]
            target_pi, target_log_pi = squash_action(target_pi, dist.log_prob(
                target_pi).unsqueeze(-1), is_independent=False)  # [T, B, A]
            target_log_pi = tsallis_entropy_log_q(target_log_pi, self.entropic_index)  # [T, B, 1]
        else:
            target_logits = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            target_cate_dist = td.Categorical(logits=target_logits)
            target_pi = target_cate_dist.sample()  # [T, B]
            target_log_pi = target_cate_dist.log_prob(target_pi).unsqueeze(-1)  # [T, B, 1]
            target_pi = F.one_hot(target_pi, self.a_dim).float()  # [T, B, A]
        q1 = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2 = self.critic2(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]

        q1_target = self.critic.t(BATCH.obs_, target_pi, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2_target = self.critic2.t(BATCH.obs_, target_pi, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q_target = th.minimum(q1_target, q2_target)  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward,
                             self.gamma,
                             BATCH.done,
                             (q_target - self.alpha * target_log_pi),
                             BATCH.begin_mask).detach()  # [T, B, 1]
        td_error1 = q1 - dc_r  # [T, B, 1]
        td_error2 = q2 - dc_r  # [T, B, 1]

        q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
        q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
        critic_loss = 0.5 * q1_loss + 0.5 * q2_loss
        self.critic_oplr.optimize(critic_loss)

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            pi = dist.rsample()  # [T, B, A]
            pi, log_pi = squash_action(pi, dist.log_prob(pi).unsqueeze(-1), is_independent=False)  # [T, B, A]
            log_pi = tsallis_entropy_log_q(log_pi, self.entropic_index)  # [T, B, 1]
            entropy = dist.entropy().mean()  # 1
        else:
            logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            pi = _pi_diff + _pi  # [T, B, A]
            log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
        q_s_pi = th.minimum(self.critic(BATCH.obs, pi, begin_mask=BATCH.begin_mask),
                            self.critic2(BATCH.obs, pi, begin_mask=BATCH.begin_mask))  # [T, B, 1]
        actor_loss = -(q_s_pi - self.alpha * log_pi).mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        summaries = {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/q1_loss': q1_loss,
            'LOSS/q2_loss': q2_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/log_alpha': self.log_alpha,
            'Statistics/alpha': self.alpha,
            'Statistics/entropy': entropy,
            'Statistics/q_min': th.minimum(q1, q2).min(),
            'Statistics/q_mean': th.minimum(q1, q2).mean(),
            'Statistics/q_max': th.maximum(q1, q2).max()
        }
        if self.auto_adaption:
            alpha_loss = -(self.alpha * (log_pi + self.target_entropy).detach()).mean()  # 1
            self.alpha_oplr.optimize(alpha_loss)
            summaries.update({
                'LOSS/alpha_loss': alpha_loss,
                'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
            })
        return (td_error1 + td_error2) / 2, summaries

    def _after_train(self):
        super()._after_train()
        self.critic.sync()
        self.critic2.sync()

        if self.annealing and not self.auto_adaption:
            self.log_alpha.copy_(self.alpha_annealing(self._cur_train_step).log())
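The only places TAC departs from a twin-Q SAC are the `entropic_index` remapping in `__init__` and the `tsallis_entropy_log_q` transform applied to the sampled log-probabilities. That helper lives elsewhere in the repository; a minimal sketch of the q-logarithm it is assumed to compute, log_q(x) = (x^(1-q) - 1) / (1 - q), which reduces to log(x) (and hence to standard SAC) as q -> 1:

import torch as th


def tsallis_log_q_sketch(log_pi: th.Tensor, q: float) -> th.Tensor:
    # Illustrative stand-in for `tsallis_entropy_log_q` (name and signature are
    # assumptions, not the repository API). Computes log_q(pi) from log(pi):
    # log_q(x) = (x**(1 - q) - 1) / (1 - q), using x**(1 - q) = exp((1 - q) * log x).
    if abs(q - 1.0) < 1e-8:
        return log_pi  # q = 1 recovers the ordinary logarithm, i.e. standard SAC
    return ((log_pi * (1.0 - q)).exp() - 1.0) / (1.0 - q)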
Example #2
class PPO(SarlOnPolicy):
    """
    Proximal Policy Optimization, https://arxiv.org/abs/1707.06347
    Emergence of Locomotion Behaviours in Rich Environments, http://arxiv.org/abs/1707.02286, DPPO
    """
    policy_mode = 'on-policy'

    def __init__(self,
                 agent_spec,

                 ent_coef: float = 1.0e-2,
                 vf_coef: float = 0.5,
                 lr: float = 5.0e-4,
                 lambda_: float = 0.95,
                 epsilon: float = 0.2,
                 use_duel_clip: bool = False,
                 duel_epsilon: float = 0.,
                 use_vclip: bool = False,
                 value_epsilon: float = 0.2,
                 share_net: bool = True,
                 actor_lr: float = 3e-4,
                 critic_lr: float = 1e-3,
                 kl_reverse: bool = False,
                 kl_target: float = 0.02,
                 kl_target_cutoff: float = 2,
                 kl_target_earlystop: float = 4,
                 kl_beta: List[float] = [0.7, 1.3],
                 kl_alpha: float = 1.5,
                 kl_coef: float = 1.0,
                 extra_coef: float = 1000.0,
                 use_kl_loss: bool = False,
                 use_extra_loss: bool = False,
                 use_early_stop: bool = False,
                 network_settings: Dict = {
                     'share': {
                         'continuous': {
                             'condition_sigma': False,
                             'log_std_bound': [-20, 2],
                             'share': [32, 32],
                             'mu': [32, 32],
                             'v': [32, 32]
                         },
                         'discrete': {
                             'share': [32, 32],
                             'logits': [32, 32],
                             'v': [32, 32]
                         }
                     },
                     'actor_continuous': {
                         'hidden_units': [64, 64],
                         'condition_sigma': False,
                         'log_std_bound': [-20, 2]
                     },
                     'actor_discrete': [32, 32],
                     'critic': [32, 32]
                 },
                 **kwargs):
        super().__init__(agent_spec=agent_spec, **kwargs)
        self._ent_coef = ent_coef
        self.lambda_ = lambda_
        assert 0.0 <= lambda_ <= 1.0, "GAE lambda should be in [0, 1]."
        self._epsilon = epsilon
        self._use_vclip = use_vclip
        self._value_epsilon = value_epsilon
        self._share_net = share_net
        self._kl_reverse = kl_reverse
        self._kl_target = kl_target
        self._kl_alpha = kl_alpha
        self._kl_coef = kl_coef
        self._extra_coef = extra_coef
        self._vf_coef = vf_coef

        self._use_duel_clip = use_duel_clip
        self._duel_epsilon = duel_epsilon
        if self._use_duel_clip:
            assert -self._epsilon < self._duel_epsilon < self._epsilon, \
                "duel_epsilon should be set in the range of (-epsilon, epsilon)."

        self._kl_cutoff = kl_target * kl_target_cutoff
        self._kl_stop = kl_target * kl_target_earlystop
        self._kl_low = kl_target * kl_beta[0]
        self._kl_high = kl_target * kl_beta[-1]

        self._use_kl_loss = use_kl_loss
        self._use_extra_loss = use_extra_loss
        self._use_early_stop = use_early_stop

        if self._share_net:
            if self.is_continuous:
                self.net = ActorCriticValueCts(self.obs_spec,
                                               rep_net_params=self._rep_net_params,
                                               output_shape=self.a_dim,
                                               network_settings=network_settings['share']['continuous']).to(self.device)
            else:
                self.net = ActorCriticValueDct(self.obs_spec,
                                               rep_net_params=self._rep_net_params,
                                               output_shape=self.a_dim,
                                               network_settings=network_settings['share']['discrete']).to(self.device)
            self.oplr = OPLR(self.net, lr, **self._oplr_params)
            self._trainer_modules.update(model=self.net,
                                         oplr=self.oplr)
        else:
            if self.is_continuous:
                self.actor = ActorMuLogstd(self.obs_spec,
                                           rep_net_params=self._rep_net_params,
                                           output_shape=self.a_dim,
                                           network_settings=network_settings['actor_continuous']).to(self.device)
            else:
                self.actor = ActorDct(self.obs_spec,
                                      rep_net_params=self._rep_net_params,
                                      output_shape=self.a_dim,
                                      network_settings=network_settings['actor_discrete']).to(self.device)
            self.critic = CriticValue(self.obs_spec,
                                      rep_net_params=self._rep_net_params,
                                      network_settings=network_settings['critic']).to(self.device)
            self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
            self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params)
            self._trainer_modules.update(actor=self.actor,
                                         critic=self.critic,
                                         actor_oplr=self.actor_oplr,
                                         critic_oplr=self.critic_oplr)

    @iton
    def select_action(self, obs):
        if self.is_continuous:
            if self._share_net:
                mu, log_std, value = self.net(obs, rnncs=self.rnncs)  # [B, A]
                self.rnncs_ = self.net.get_rnncs()
            else:
                mu, log_std = self.actor(obs, rnncs=self.rnncs)  # [B, A]
                self.rnncs_ = self.actor.get_rnncs()
                value = self.critic(obs, rnncs=self.rnncs)  # [B, 1]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            action = dist.sample().clamp(-1, 1)  # [B, A]
            log_prob = dist.log_prob(action).unsqueeze(-1)  # [B, 1]
        else:
            if self._share_net:
                logits, value = self.net(obs, rnncs=self.rnncs)  # [B, A], [B, 1]
                self.rnncs_ = self.net.get_rnncs()
            else:
                logits = self.actor(obs, rnncs=self.rnncs)  # [B, A]
                self.rnncs_ = self.actor.get_rnncs()
                value = self.critic(obs, rnncs=self.rnncs)  # [B, 1]
            norm_dist = td.Categorical(logits=logits)
            action = norm_dist.sample()  # [B,]
            log_prob = norm_dist.log_prob(action).unsqueeze(-1)  # [B, 1]

        acts_info = Data(action=action,
                         value=value,
                         log_prob=log_prob + th.finfo().eps)
        if self.use_rnn:
            acts_info.update(rnncs=self.rnncs)
        return action, acts_info

    @iton
    def _get_value(self, obs, rnncs=None):
        if self._share_net:
            if self.is_continuous:
                _, _, value = self.net(obs, rnncs=rnncs)  # [B, 1]
            else:
                _, value = self.net(obs, rnncs=rnncs)  # [B, 1]
        else:
            value = self.critic(obs, rnncs=rnncs)  # [B, 1]
        return value

    def _preprocess_BATCH(self, BATCH):  # [T, B, *]
        BATCH = super()._preprocess_BATCH(BATCH)
        value = self._get_value(BATCH.obs_[-1], rnncs=self.rnncs)
        BATCH.discounted_reward = discounted_sum(BATCH.reward,
                                                 self.gamma,
                                                 BATCH.done,
                                                 BATCH.begin_mask,
                                                 init_value=value)
        td_error = calculate_td_error(BATCH.reward,
                                      self.gamma,
                                      BATCH.done,
                                      value=BATCH.value,
                                      next_value=np.concatenate((BATCH.value[1:], value[np.newaxis, :]), 0))
        BATCH.gae_adv = discounted_sum(td_error,
                                       self.lambda_ * self.gamma,
                                       BATCH.done,
                                       BATCH.begin_mask,
                                       init_value=0.,
                                       normalize=True)
        return BATCH

    def learn(self, BATCH: Data):
        BATCH = self._preprocess_BATCH(BATCH)  # [T, B, *]
        for _ in range(self._epochs):
            kls = []
            for _BATCH in BATCH.sample(self._chunk_length, self.batch_size, repeat=self._sample_allow_repeat):
                _BATCH = self._before_train(_BATCH)
                summaries, kl = self._train(_BATCH)
                kls.append(kl)
                self.summaries.update(summaries)
                self._after_train()
            if self._use_early_stop and sum(kls) / len(kls) > self._kl_stop:
                break

    def _train(self, BATCH):
        if self._share_net:
            summaries, kl = self.train_share(BATCH)
        else:
            summaries = dict()
            actor_summaries, kl = self.train_actor(BATCH)
            critic_summaries = self.train_critic(BATCH)
            summaries.update(actor_summaries)
            summaries.update(critic_summaries)

        if self._use_kl_loss:
            # ref: https://github.com/joschu/modular_rl/blob/6970cde3da265cf2a98537250fea5e0c0d9a7639/modular_rl/ppo.py#L93
            if kl > self._kl_high:
                self._kl_coef *= self._kl_alpha
            elif kl < self._kl_low:
                self._kl_coef /= self._kl_alpha
            summaries.update({
                'Statistics/kl_coef': self._kl_coef
            })

        return summaries, kl

    @iton
    def train_share(self, BATCH):
        if self.is_continuous:
            # [T, B, A], [T, B, A], [T, B, 1]
            mu, log_std, value = self.net(BATCH.obs, begin_mask=BATCH.begin_mask)
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            new_log_prob = dist.log_prob(BATCH.action).unsqueeze(-1)  # [T, B, 1]
            entropy = dist.entropy().unsqueeze(-1)  # [T, B, 1]
        else:
            # [T, B, A], [T, B, 1]
            logits, value = self.net(BATCH.obs, begin_mask=BATCH.begin_mask)
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            new_log_prob = (BATCH.action * logp_all).sum(-1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(-1, keepdim=True)  # [T, B, 1]
        ratio = (new_log_prob - BATCH.log_prob).exp()  # [T, B, 1]
        surrogate = ratio * BATCH.gae_adv  # [T, B, 1]
        clipped_surrogate = th.minimum(
            surrogate,
            ratio.clamp(1.0 - self._epsilon, 1.0 + self._epsilon) * BATCH.gae_adv
        )  # [T, B, 1]
        # ref: https://github.com/thu-ml/tianshou/blob/c97aa4065ee8464bd5897bb86f1f81abd8e2cff9/tianshou/policy/modelfree/ppo.py#L159
        if self._use_duel_clip:
            clipped_surrogate2 = th.maximum(
                clipped_surrogate,
                (1.0 + self._duel_epsilon) * BATCH.gae_adv
            )  # [T, B, 1]
            clipped_surrogate = th.where(BATCH.gae_adv < 0, clipped_surrogate2, clipped_surrogate)  # [T, B, 1]
        actor_loss = -(clipped_surrogate + self._ent_coef * entropy).mean()  # 1

        # ref: https://github.com/joschu/modular_rl/blob/6970cde3da265cf2a98537250fea5e0c0d9a7639/modular_rl/ppo.py#L40
        # ref: https://github.com/hill-a/stable-baselines/blob/b3f414f4f2900403107357a2206f80868af16da3/stable_baselines/ppo2/ppo2.py#L185
        if self._kl_reverse:  # TODO:
            kl = .5 * (new_log_prob - BATCH.log_prob).square().mean()  # 1
        else:
            # a sample estimate for KL-divergence, easy to compute
            kl = .5 * (BATCH.log_prob - new_log_prob).square().mean()

        if self._use_kl_loss:
            kl_loss = self._kl_coef * kl  # 1
            actor_loss += kl_loss

        if self._use_extra_loss:
            extra_loss = self._extra_coef * th.maximum(th.zeros_like(kl), kl - self._kl_cutoff).square().mean()  # 1
            actor_loss += extra_loss

        td_error = BATCH.discounted_reward - value  # [T, B, 1]
        if self._use_vclip:
            # ref: https://github.com/llSourcell/OpenAI_Five_vs_Dota2_Explained/blob/c5def7e57aa70785c2394ea2eeb3e5f66ad59a53/train.py#L154
            # ref: https://github.com/hill-a/stable-baselines/blob/b3f414f4f2900403107357a2206f80868af16da3/stable_baselines/ppo2/ppo2.py#L172
            value_clip = BATCH.value + (value - BATCH.value).clamp(-self._value_epsilon,
                                                                   self._value_epsilon)  # [T, B, 1]
            td_error_clip = BATCH.discounted_reward - value_clip  # [T, B, 1]
            td_square = th.maximum(td_error.square(), td_error_clip.square())  # [T, B, 1]
        else:
            td_square = td_error.square()  # [T, B, 1]

        critic_loss = 0.5 * td_square.mean()  # 1
        loss = actor_loss + self._vf_coef * critic_loss  # 1
        self.oplr.optimize(loss)
        return {
                   'LOSS/actor_loss': actor_loss,
                   'LOSS/critic_loss': critic_loss,
                   'Statistics/kl': kl,
                   'Statistics/entropy': entropy.mean(),
                   'LEARNING_RATE/lr': self.oplr.lr
               }, kl

    @iton
    def train_actor(self, BATCH):
        if self.is_continuous:
            # [T, B, A], [T, B, A]
            mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            new_log_prob = dist.log_prob(BATCH.action).unsqueeze(-1)  # [T, B, 1]
            entropy = dist.entropy().unsqueeze(-1)  # [T, B, 1]
        else:
            logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            new_log_prob = (BATCH.action * logp_all).sum(-1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(-1, keepdim=True)  # [T, B, 1]
        ratio = (new_log_prob - BATCH.log_prob).exp()  # [T, B, 1]
        kl = (BATCH.log_prob - new_log_prob).square().mean()  # 1
        surrogate = ratio * BATCH.gae_adv  # [T, B, 1]
        clipped_surrogate = th.minimum(
            surrogate,
            th.where(BATCH.gae_adv > 0, (1 + self._epsilon) *
                     BATCH.gae_adv, (1 - self._epsilon) * BATCH.gae_adv)
        )  # [T, B, 1]
        if self._use_duel_clip:
            clipped_surrogate = th.maximum(
                clipped_surrogate,
                (1.0 + self._duel_epsilon) * BATCH.gae_adv
            )  # [T, B, 1]

        actor_loss = -(clipped_surrogate + self._ent_coef * entropy).mean()  # 1

        if self._use_kl_loss:
            kl_loss = self._kl_coef * kl  # 1
            actor_loss += kl_loss
        if self._use_extra_loss:
            extra_loss = self._extra_coef * th.maximum(th.zeros_like(kl), kl - self._kl_cutoff).square().mean()  # 1
            actor_loss += extra_loss

        self.actor_oplr.optimize(actor_loss)
        return {
                   'LOSS/actor_loss': actor_loss,
                   'Statistics/kl': kl,
                   'Statistics/entropy': entropy.mean(),
                   'LEARNING_RATE/actor_lr': self.actor_oplr.lr
               }, kl

    @iton
    def train_critic(self, BATCH):
        value = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, 1]

        td_error = BATCH.discounted_reward - value  # [T, B, 1]
        if self._use_vclip:
            value_clip = BATCH.value + (value - BATCH.value).clamp(-self._value_epsilon,
                                                                   self._value_epsilon)  # [T, B, 1]
            td_error_clip = BATCH.discounted_reward - value_clip  # [T, B, 1]
            td_square = th.maximum(td_error.square(), td_error_clip.square())  # [T, B, 1]
        else:
            td_square = td_error.square()  # [T, B, 1]

        critic_loss = 0.5 * td_square.mean()  # 1
        self.critic_oplr.optimize(critic_loss)
        return {
            'LOSS/critic_loss': critic_loss,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr
        }
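The heart of both `train_share` and `train_actor` above is the PPO-Clip surrogate; pulled out of the class, it amounts to the following few lines (the function name is illustrative, not part of the repository API). The duel-clip and value-clip branches in the class add further bounds on top of this term.

import torch as th


def ppo_clip_surrogate_sketch(new_log_prob, old_log_prob, gae_adv, epsilon=0.2):
    # Per-sample clipped objective: min(r * A, clip(r, 1 - eps, 1 + eps) * A),
    # where r = pi_new(a|s) / pi_old(a|s) is recovered from the stored log-probs.
    ratio = (new_log_prob - old_log_prob).exp()
    return th.minimum(ratio * gae_adv,
                      ratio.clamp(1.0 - epsilon, 1.0 + epsilon) * gae_adv)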
Example #3
class SAC_V(SarlOffPolicy):
    """
        Soft Actor Critic with Value neural network. https://arxiv.org/abs/1812.05905
        Soft Actor-Critic for Discrete Action Settings. https://arxiv.org/abs/1910.07207
    """
    policy_mode = 'off-policy'

    def __init__(
            self,
            alpha=0.2,
            annealing=True,
            last_alpha=0.01,
            polyak=0.995,
            use_gumbel=True,
            discrete_tau=1.0,
            network_settings={
                'actor_continuous': {
                    'share': [128, 128],
                    'mu': [64],
                    'log_std': [64],
                    'soft_clip': False,
                    'log_std_bound': [-20, 2]
                },
                'actor_discrete': [64, 32],
                'q': [128, 128],
                'v': [128, 128]
            },
            actor_lr=5.0e-4,
            critic_lr=1.0e-3,
            alpha_lr=5.0e-4,
            auto_adaption=True,
            **kwargs):
        super().__init__(**kwargs)
        self.polyak = polyak
        self.use_gumbel = use_gumbel
        self.discrete_tau = discrete_tau
        self.auto_adaption = auto_adaption
        self.annealing = annealing

        self.v_net = TargetTwin(
            CriticValue(self.obs_spec,
                        rep_net_params=self._rep_net_params,
                        network_settings=network_settings['v']),
            self.polyak).to(self.device)

        if self.is_continuous:
            self.actor = ActorCts(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_continuous']).to(
                    self.device)
        else:
            self.actor = ActorDct(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_discrete']).to(
                    self.device)

        # target entropy, scaled by 0.98: -dim(A) for continuous actions,
        # log|A| = -log(1/|A|) (the maximum categorical entropy) for discrete ones
        self.target_entropy = 0.98 * (-self.a_dim if self.is_continuous else
                                      np.log(self.a_dim))

        if self.is_continuous or self.use_gumbel:
            self.q_net = CriticQvalueOne(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                action_dim=self.a_dim,
                network_settings=network_settings['q']).to(self.device)
        else:
            self.q_net = CriticQvalueAll(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['q']).to(self.device)
        self.q_net2 = deepcopy(self.q_net)

        self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
        self.critic_oplr = OPLR([self.q_net, self.q_net2, self.v_net],
                                critic_lr, **self._oplr_params)

        if self.auto_adaption:
            # create log_alpha directly on the target device so it stays a leaf tensor the optimizer can update
            self.log_alpha = th.tensor(0., requires_grad=True, device=self.device)
            self.alpha_oplr = OPLR(self.log_alpha, alpha_lr,
                                   **self._oplr_params)
            self._trainer_modules.update(alpha_oplr=self.alpha_oplr)
        else:
            self.log_alpha = th.tensor(alpha).log().to(self.device)
            if self.annealing:
                self.alpha_annealing = LinearAnnealing(alpha, last_alpha,
                                                       int(1e6))

        self._trainer_modules.update(actor=self.actor,
                                     v_net=self.v_net,
                                     q_net=self.q_net,
                                     q_net2=self.q_net2,
                                     log_alpha=self.log_alpha,
                                     actor_oplr=self.actor_oplr,
                                     critic_oplr=self.critic_oplr)

    @property
    def alpha(self):
        return self.log_alpha.exp()

    @iton
    def select_action(self, obs):
        if self.is_continuous:
            mu, log_std = self.actor(obs, rnncs=self.rnncs)  # [B, A]
            pi = td.Normal(mu, log_std.exp()).sample().tanh()  # [B, A]
            mu.tanh_()  # squash mu   # [B, A]
        else:
            logits = self.actor(obs, rnncs=self.rnncs)  # [B, A]
            mu = logits.argmax(-1)  # [B,]
            cate_dist = td.Categorical(logits=logits)
            pi = cate_dist.sample()  # [B,]
        self.rnncs_ = self.actor.get_rnncs()
        actions = pi if self._is_train_mode else mu
        return actions, Data(action=actions)

    def _train(self, BATCH):
        if self.is_continuous or self.use_gumbel:
            td_error, summaries = self._train_continuous(BATCH)
        else:
            td_error, summaries = self._train_discrete(BATCH)
        return td_error, summaries

    @iton
    def _train_continuous(self, BATCH):
        v = self.v_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        v_target = self.v_net.t(BATCH.obs_,
                                begin_mask=BATCH.begin_mask)  # [T, B, 1]

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs,
                                     begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            pi = dist.rsample()  # [T, B, A]
            pi, log_pi = squash_action(
                pi,
                dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                -1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            pi = _pi_diff + _pi  # [T, B, A]
            log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
        q1 = self.q_net(BATCH.obs, BATCH.action,
                        begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2 = self.q_net2(BATCH.obs, BATCH.action,
                         begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q1_pi = self.q_net(BATCH.obs, pi,
                           begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2_pi = self.q_net2(BATCH.obs, pi,
                            begin_mask=BATCH.begin_mask)  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, v_target,
                             BATCH.begin_mask).detach()  # [T, B, 1]
        v_from_q_stop = (th.minimum(q1_pi, q2_pi) -
                         self.alpha * log_pi).detach()  # [T, B, 1]
        td_v = v - v_from_q_stop  # [T, B, 1]
        td_error1 = q1 - dc_r  # [T, B, 1]
        td_error2 = q2 - dc_r  # [T, B, 1]
        q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
        q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
        v_loss_stop = (td_v.square() * BATCH.get('isw', 1.0)).mean()  # 1

        critic_loss = 0.5 * q1_loss + 0.5 * q2_loss + 0.5 * v_loss_stop
        self.critic_oplr.optimize(critic_loss)

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs,
                                     begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            pi = dist.rsample()  # [T, B, A]
            pi, log_pi = squash_action(
                pi,
                dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
            entropy = dist.entropy().mean()  # 1
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                -1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            pi = _pi_diff + _pi  # [T, B, A]
            log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
        q1_pi = self.q_net(BATCH.obs, pi,
                           begin_mask=BATCH.begin_mask)  # [T, B, 1]
        actor_loss = -(q1_pi - self.alpha * log_pi).mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        summaries = {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/q1_loss': q1_loss,
            'LOSS/q2_loss': q2_loss,
            'LOSS/v_loss': v_loss_stop,
            'LOSS/critic_loss': critic_loss,
            'Statistics/log_alpha': self.log_alpha,
            'Statistics/alpha': self.alpha,
            'Statistics/entropy': entropy,
            'Statistics/q_min': th.minimum(q1, q2).min(),
            'Statistics/q_mean': th.minimum(q1, q2).mean(),
            'Statistics/q_max': th.maximum(q1, q2).max(),
            'Statistics/v_mean': v.mean()
        }
        if self.auto_adaption:
            alpha_loss = -(self.alpha *
                           (log_pi.detach() + self.target_entropy)).mean()
            self.alpha_oplr.optimize(alpha_loss)
            summaries.update({
                'LOSS/alpha_loss': alpha_loss,
                'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
            })
        return (td_error1 + td_error2) / 2, summaries

    @iton
    def _train_discrete(self, BATCH):
        v = self.v_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        v_target = self.v_net.t(BATCH.obs_,
                                begin_mask=BATCH.begin_mask)  # [T, B, 1]

        q1_all = self.q_net(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        q2_all = self.q_net2(BATCH.obs,
                             begin_mask=BATCH.begin_mask)  # [T, B, A]
        q1 = (q1_all * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        q2 = (q2_all * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        logits = self.actor(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]

        dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, v_target,
                             BATCH.begin_mask).detach()  # [T, B, 1]
        td_v = v - (th.minimum((logp_all.exp() * q1_all).sum(-1, keepdim=True),
                               (logp_all.exp() * q2_all).sum(
                                   -1, keepdim=True))).detach()  # [T, B, 1]
        td_error1 = q1 - dc_r  # [T, B, 1]
        td_error2 = q2 - dc_r  # [T, B, 1]

        q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
        q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
        v_loss_stop = (td_v.square() * BATCH.get('isw', 1.0)).mean()  # 1
        critic_loss = 0.5 * q1_loss + 0.5 * q2_loss + 0.5 * v_loss_stop
        self.critic_oplr.optimize(critic_loss)

        q1_all = self.q_net(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        q2_all = self.q_net2(BATCH.obs,
                             begin_mask=BATCH.begin_mask)  # [T, B, A]
        logits = self.actor(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]

        entropy = -(logp_all.exp() * logp_all).sum(-1,
                                                   keepdim=True)  # [T, B, 1]
        q_all = th.minimum(q1_all, q2_all)  # [T, B, A]
        actor_loss = -((q_all - self.alpha * logp_all) * logp_all.exp()).sum(
            -1)  # [T, B, A] => [T, B]
        actor_loss = actor_loss.mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        summaries = {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/q1_loss': q1_loss,
            'LOSS/q2_loss': q2_loss,
            'LOSS/v_loss': v_loss_stop,
            'LOSS/critic_loss': critic_loss,
            'Statistics/log_alpha': self.log_alpha,
            'Statistics/alpha': self.alpha,
            'Statistics/entropy': entropy.mean(),
            'Statistics/v_mean': v.mean()
        }
        if self.auto_adaption:
            corr = (self.target_entropy - entropy).detach()  # [T, B, 1]
            # corr = ((logp_all - self.a_dim) * logp_all.exp()).sum(-1).detach()
            alpha_loss = -(self.alpha * corr)  # [T, B, 1]
            alpha_loss = alpha_loss.mean()  # 1
            self.alpha_oplr.optimize(alpha_loss)
            summaries.update({
                'LOSS/alpha_loss': alpha_loss,
                'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
            })
        return (td_error1 + td_error2) / 2, summaries

    def _after_train(self):
        super()._after_train()
        if self.annealing and not self.auto_adaption:
            self.log_alpha.copy_(
                self.alpha_annealing(self._cur_train_step).log())
        self.v_net.sync()
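Both SAC_V and TAC rely on `squash_action` to squash Gaussian samples through tanh and correct the log-probability accordingly. A minimal sketch of that correction under the usual change-of-variables formula (the real helper also takes an `is_independent` flag for per-dimension log-probs; here the input log-prob is assumed to be already summed over action dimensions):

import torch as th


def squash_action_sketch(u: th.Tensor, log_prob_u: th.Tensor):
    # a = tanh(u); log pi(a) = log mu(u) - sum_i log(1 - tanh(u_i)^2).
    # `log_prob_u` is assumed to have shape [..., 1] (already summed over A).
    a = u.tanh()
    correction = th.log(1.0 - a.pow(2) + th.finfo(a.dtype).eps).sum(-1, keepdim=True)
    return a, log_prob_u - correction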
Example #4
class DPG(SarlOffPolicy):
    """
    Deterministic Policy Gradient, https://hal.inria.fr/file/index/docid/938992/filename/dpg-icml2014.pdf
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 actor_lr=5.0e-4,
                 critic_lr=1.0e-3,
                 use_target_action_noise=False,
                 noise_action='ou',
                 noise_params={
                     'sigma': 0.2
                 },
                 discrete_tau=1.0,
                 network_settings={
                     'actor_continuous': [32, 32],
                     'actor_discrete': [32, 32],
                     'q': [32, 32]
                 },
                 **kwargs):
        super().__init__(**kwargs)
        self.discrete_tau = discrete_tau
        self.use_target_action_noise = use_target_action_noise

        if self.is_continuous:
            self.target_noised_action = ClippedNormalNoisedAction(sigma=0.2, noise_bound=0.2)
            self.noised_action = Noise_action_REGISTER[noise_action](**noise_params)
            self.actor = ActorDPG(self.obs_spec,
                                  rep_net_params=self._rep_net_params,
                                  output_shape=self.a_dim,
                                  network_settings=network_settings['actor_continuous']).to(self.device)
        else:
            self.actor = ActorDct(self.obs_spec,
                                  rep_net_params=self._rep_net_params,
                                  output_shape=self.a_dim,
                                  network_settings=network_settings['actor_discrete']).to(self.device)

        self.critic = CriticQvalueOne(self.obs_spec,
                                      rep_net_params=self._rep_net_params,
                                      action_dim=self.a_dim,
                                      network_settings=network_settings['q']).to(self.device)

        self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
        self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params)
        self._trainer_modules.update(actor=self.actor,
                                     critic=self.critic,
                                     actor_oplr=self.actor_oplr,
                                     critic_oplr=self.critic_oplr)

    def episode_reset(self):
        super().episode_reset()
        if self.is_continuous:
            self.noised_action.reset()

    @iton
    def select_action(self, obs):
        output = self.actor(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.actor.get_rnncs()
        if self.is_continuous:
            mu = output  # [B, A]
            pi = self.noised_action(mu)  # [B, A]
        else:
            logits = output  # [B, A]
            mu = logits.argmax(-1)  # [B,]
            cate_dist = td.Categorical(logits=logits)
            pi = cate_dist.sample()  # [B,]
        actions = pi if self._is_train_mode else mu
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        if self.is_continuous:
            action_target = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            if self.use_target_action_noise:
                action_target = self.target_noised_action(action_target)  # [T, B, A]
        else:
            target_logits = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            target_cate_dist = td.Categorical(logits=target_logits)
            target_pi = target_cate_dist.sample()  # [T, B]
            action_target = F.one_hot(target_pi, self.a_dim).float()  # [T, B, A]
        q_target = self.critic(BATCH.obs_, action_target, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward,
                             self.gamma,
                             BATCH.done,
                             q_target,
                             BATCH.begin_mask).detach()  # [T, B, 1]
        q = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        td_error = dc_r - q  # [T, B, 1]
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
        self.critic_oplr.optimize(q_loss)

        if self.is_continuous:
            mu = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        else:
            logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
            _pi = logits.softmax(-1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(
                logits.argmax(-1), self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            mu = _pi_diff + _pi  # [T, B, A]
        q_actor = self.critic(BATCH.obs, mu, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        actor_loss = -q_actor.mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        return td_error, {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': q_loss,
            'Statistics/q_min': q.min(),
            'Statistics/q_mean': q.mean(),
            'Statistics/q_max': q.max()
        }
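With `noise_action='ou'`, the exploration noise drawn from `Noise_action_REGISTER` is typically an Ornstein-Uhlenbeck process. A minimal sketch of such a process with assumed default parameters (the repository's actual implementation may differ):

import numpy as np


class OUNoiseSketch:
    # x_{t+1} = x_t + theta * (mu - x_t) * dt + sigma * sqrt(dt) * N(0, I):
    # temporally correlated noise added to the deterministic action.
    def __init__(self, size, mu=0.0, theta=0.15, sigma=0.2, dt=1e-2):
        self.size, self.mu, self.theta, self.sigma, self.dt = size, mu, theta, sigma, dt
        self.reset()

    def reset(self):
        self.x = np.full(self.size, self.mu, dtype=np.float64)

    def __call__(self, action):
        dx = self.theta * (self.mu - self.x) * self.dt \
             + self.sigma * np.sqrt(self.dt) * np.random.randn(self.size)
        self.x = self.x + dx
        return action + self.x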
Example #5
class A2C(SarlOnPolicy):
    """
    Synchronous Advantage Actor-Critic, A2C, http://arxiv.org/abs/1602.01783
    """
    policy_mode = 'on-policy'

    def __init__(
            self,
            agent_spec,
            beta=1.0e-3,
            lambda_=0.95,
            actor_lr=5.0e-4,
            critic_lr=1.0e-3,
            network_settings={
                'actor_continuous': {
                    'hidden_units': [64, 64],
                    'condition_sigma': False,
                    'log_std_bound': [-20, 2]
                },
                'actor_discrete': [32, 32],
                'critic': [32, 32]
            },
            **kwargs):
        super().__init__(agent_spec=agent_spec, **kwargs)
        self.beta = beta
        self.lambda_ = lambda_  # GAE lambda used in _preprocess_BATCH

        if self.is_continuous:
            self.actor = ActorMuLogstd(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_continuous']).to(
                    self.device)
        else:
            self.actor = ActorDct(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_discrete']).to(
                    self.device)
        self.critic = CriticValue(
            self.obs_spec,
            rep_net_params=self._rep_net_params,
            network_settings=network_settings['critic']).to(self.device)

        self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
        self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params)

        self._trainer_modules.update(actor=self.actor,
                                     critic=self.critic,
                                     actor_oplr=self.actor_oplr,
                                     critic_oplr=self.critic_oplr)

    @iton
    def select_action(self, obs):
        output = self.actor(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.actor.get_rnncs()
        if self.is_continuous:
            mu, log_std = output  # [B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            action = dist.sample().clamp(-1, 1)  # [B, A]
        else:
            logits = output  # [B, A]
            norm_dist = td.Categorical(logits=logits)
            action = norm_dist.sample()  # [B,]

        acts_info = Data(action=action)
        if self.use_rnn:
            acts_info.update(rnncs=self.rnncs)
        return action, acts_info

    @iton
    def _get_value(self, obs, rnncs=None):
        value = self.critic(obs, rnncs=rnncs)  # [B, 1]
        return value

    def _preprocess_BATCH(self, BATCH):  # [T, B, *]
        BATCH = super()._preprocess_BATCH(BATCH)
        value = self._get_value(BATCH.obs_[-1], rnncs=self.rnncs)
        BATCH.discounted_reward = discounted_sum(BATCH.reward,
                                                 self.gamma,
                                                 BATCH.done,
                                                 BATCH.begin_mask,
                                                 init_value=value)
        td_error = calculate_td_error(
            BATCH.reward,
            self.gamma,
            BATCH.done,
            value=BATCH.value,
            next_value=np.concatenate((BATCH.value[1:], value[np.newaxis, :]),
                                      0))
        BATCH.gae_adv = discounted_sum(td_error,
                                       self.lambda_ * self.gamma,
                                       BATCH.done,
                                       BATCH.begin_mask,
                                       init_value=0.,
                                       normalize=True)
        return BATCH

    @iton
    def _train(self, BATCH):
        v = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        td_error = BATCH.discounted_reward - v  # [T, B, 1]
        critic_loss = td_error.square().mean()  # 1
        self.critic_oplr.optimize(critic_loss)

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs,
                                     begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            log_act_prob = dist.log_prob(BATCH.action).unsqueeze(
                -1)  # [T, B, 1]
            entropy = dist.entropy().unsqueeze(-1)  # [T, B, 1]
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            log_act_prob = (BATCH.action * logp_all).sum(
                -1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(
                -1, keepdim=True)  # [T, B, 1]
        # advantage = BATCH.discounted_reward - v.detach()  # [T, B, 1]
        actor_loss = -(log_act_prob * BATCH.gae_adv +
                       self.beta * entropy).mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        return {
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/entropy': entropy.mean(),
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr
        }
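`_preprocess_BATCH` in A2C, PPO and NPG delegates to the repository's `discounted_sum` and `calculate_td_error` helpers. Ignoring `begin_mask` handling and normalization, the core recursions they are assumed to implement look like this (illustrative names and signatures):

import numpy as np


def discounted_sum_sketch(x, rate, dones, init_value=0.0):
    # Backward recursion over a [T, B] array:
    # y_t = x_t + rate * (1 - done_t) * y_{t+1}, bootstrapped from init_value.
    y = np.zeros_like(x)
    carry = init_value
    for t in reversed(range(x.shape[0])):
        carry = x[t] + rate * (1.0 - dones[t]) * carry
        y[t] = carry
    return y


def gae_sketch(rewards, values, next_values, dones, gamma, lambda_):
    # One-step TD residuals, then a discounted sum with rate gamma * lambda_.
    td = rewards + gamma * (1.0 - dones) * next_values - values
    return discounted_sum_sketch(td, gamma * lambda_, dones)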
Example #6
class NPG(SarlOnPolicy):
    """
    Natural Policy Gradient, NPG
    https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf
    """
    policy_mode = 'on-policy'

    def __init__(
            self,
            agent_spec,
            actor_step_size=0.5,
            beta=1.0e-3,
            lambda_=0.95,
            cg_iters=10,
            damping_coeff=0.1,
            epsilon=0.2,
            critic_lr=1e-3,
            train_critic_iters=10,
            network_settings={
                'actor_continuous': {
                    'hidden_units': [64, 64],
                    'condition_sigma': False,
                    'log_std_bound': [-20, 2]
                },
                'actor_discrete': [32, 32],
                'critic': [32, 32]
            },
            **kwargs):
        super().__init__(agent_spec=agent_spec, **kwargs)
        self.actor_step_size = actor_step_size
        self.beta = beta
        self.lambda_ = lambda_
        self._epsilon = epsilon
        self._cg_iters = cg_iters
        self._damping_coeff = damping_coeff
        self._train_critic_iters = train_critic_iters

        if self.is_continuous:
            self.actor = ActorMuLogstd(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_continuous']).to(
                    self.device)
        else:
            self.actor = ActorDct(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_discrete']).to(
                    self.device)
        self.critic = CriticValue(
            self.obs_spec,
            rep_net_params=self._rep_net_params,
            network_settings=network_settings['critic']).to(self.device)

        self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params)
        self._trainer_modules.update(actor=self.actor,
                                     critic=self.critic,
                                     critic_oplr=self.critic_oplr)

    @iton
    def select_action(self, obs):
        output = self.actor(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.actor.get_rnncs()
        value = self.critic(obs, rnncs=self.rnncs)  # [B, 1]
        if self.is_continuous:
            mu, log_std = output  # [B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            action = dist.sample().clamp(-1, 1)  # [B, A]
            log_prob = dist.log_prob(action).unsqueeze(-1)  # [B, 1]
        else:
            logits = output  # [B, A]
            logp_all = logits.log_softmax(-1)  # [B, A]
            norm_dist = td.Categorical(logits=logp_all)
            action = norm_dist.sample()  # [B,]
            log_prob = norm_dist.log_prob(action).unsqueeze(-1)  # [B, 1]
        acts_info = Data(action=action,
                         value=value,
                         log_prob=log_prob + th.finfo().eps)
        if self.use_rnn:
            acts_info.update(rnncs=self.rnncs)
        if self.is_continuous:
            acts_info.update(mu=mu, log_std=log_std)
        else:
            acts_info.update(logp_all=logp_all)
        return action, acts_info

    @iton
    def _get_value(self, obs, rnncs=None):
        value = self.critic(obs, rnncs=rnncs)  # [B, 1]
        return value

    def _preprocess_BATCH(self, BATCH):  # [T, B, *]
        BATCH = super()._preprocess_BATCH(BATCH)
        value = self._get_value(BATCH.obs_[-1], rnncs=self.rnncs)
        BATCH.discounted_reward = discounted_sum(BATCH.reward,
                                                 self.gamma,
                                                 BATCH.done,
                                                 BATCH.begin_mask,
                                                 init_value=value)
        td_error = calculate_td_error(
            BATCH.reward,
            self.gamma,
            BATCH.done,
            value=BATCH.value,
            next_value=np.concatenate((BATCH.value[1:], value[np.newaxis, :]),
                                      0))
        BATCH.gae_adv = discounted_sum(td_error,
                                       self.lambda_ * self.gamma,
                                       BATCH.done,
                                       BATCH.begin_mask,
                                       init_value=0.,
                                       normalize=True)
        return BATCH

    @iton
    def _train(self, BATCH):
        output = self.actor(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        if self.is_continuous:
            mu, log_std = output  # [T, B, A], [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            new_log_prob = dist.log_prob(BATCH.action).unsqueeze(
                -1)  # [T, B, 1]
            entropy = dist.entropy().mean()  # 1
        else:
            logits = output  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            new_log_prob = (BATCH.action * logp_all).sum(
                -1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
        ratio = (new_log_prob - BATCH.log_prob).exp()  # [T, B, 1]
        actor_loss = -(ratio * BATCH.gae_adv).mean()  # 1

        flat_grads = grads_flatten(actor_loss, self.actor,
                                   retain_graph=True).detach()  # [N,] flattened policy gradient

        if self.is_continuous:
            kl = td.kl_divergence(
                td.Independent(td.Normal(BATCH.mu, BATCH.log_std.exp()), 1),
                td.Independent(td.Normal(mu, log_std.exp()), 1)).mean()
        else:
            kl = (BATCH.logp_all.exp() *
                  (BATCH.logp_all - logp_all)).sum(-1).mean()  # 1

        flat_kl_grad = grads_flatten(kl, self.actor, create_graph=True)
        search_direction = -self._conjugate_gradients(
            flat_grads, flat_kl_grad, cg_iters=self._cg_iters)  # [N,] natural-gradient direction

        with th.no_grad():
            flat_params = th.cat(
                [param.data.view(-1) for param in self.actor.parameters()])
            new_flat_params = flat_params + self.actor_step_size * search_direction
            set_from_flat_params(self.actor, new_flat_params)

        for _ in range(self._train_critic_iters):
            value = self.critic(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, 1]
            td_error = BATCH.discounted_reward - value  # [T, B, 1]
            critic_loss = td_error.square().mean()  # 1
            self.critic_oplr.optimize(critic_loss)

        return {
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/entropy': entropy.mean(),
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr
        }

    def _conjugate_gradients(self,
                             flat_grads,
                             flat_kl_grad,
                             cg_iters: int = 10,
                             residual_tol: float = 1e-10):
        """
        Conjugate gradient algorithm
        (see https://en.wikipedia.org/wiki/Conjugate_gradient_method)
        """
        x = th.zeros_like(flat_grads)
        r, p = flat_grads.clone(), flat_grads.clone()
        # Note: in general r = b - A @ x and p = r; with x = 0 the matrix-vector
        # product vanishes, so r = p = b. Change this if warm-starting x.
        rdotr = r.dot(r)
        for i in range(cg_iters):
            z = self._MVP(p, flat_kl_grad)
            alpha = rdotr / (p.dot(z) + th.finfo().eps)
            x += alpha * p
            r -= alpha * z
            new_rdotr = r.dot(r)
            if new_rdotr < residual_tol:
                break
            p = r + new_rdotr / rdotr * p
            rdotr = new_rdotr
        return x

    def _MVP(self, v, flat_kl_grad):
        """Matrix-vector product of the KL Hessian (Fisher matrix) with v."""
        # calculate the second-order gradient of the KL w.r.t. theta
        kl_v = (flat_kl_grad * v).sum()
        mvp = grads_flatten(kl_v, self.actor, retain_graph=True).detach()
        mvp += max(0, self._damping_coeff) * v  # damping term for numerical stability
        return mvp
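
The `_conjugate_gradients` / `_MVP` pair above solves H x = g using only Hessian-vector products, so the Fisher matrix is never built explicitly. Below is a standalone sketch (not part of the library; the helper name and the random SPD test system are made up for illustration) that runs the same loop against an explicit matrix and checks the result against a direct solve.

import torch as th

def conjugate_gradients(mvp, b, cg_iters=10, residual_tol=1e-10):
    """Solve A x = b for symmetric positive-definite A, given only a
    matrix-vector product mvp(v) = A @ v (same loop as the method above)."""
    x = th.zeros_like(b)
    r, p = b.clone(), b.clone()
    rdotr = r.dot(r)
    for _ in range(cg_iters):
        z = mvp(p)
        alpha = rdotr / (p.dot(z) + th.finfo().eps)
        x += alpha * p
        r -= alpha * z
        new_rdotr = r.dot(r)
        if new_rdotr < residual_tol:
            break
        p = r + new_rdotr / rdotr * p
        rdotr = new_rdotr
    return x

# Quick check against a direct solve on a random SPD system.
th.manual_seed(0)
M = th.randn(5, 5)
A = M @ M.T + 5 * th.eye(5)          # symmetric positive-definite
b = th.randn(5)
x_cg = conjugate_gradients(lambda v: A @ v, b, cg_iters=50)
x_direct = th.linalg.solve(A, b)
print(th.allclose(x_cg, x_direct, atol=1e-4))  # expected: True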
Example #7
class PG(SarlOnPolicy):
    policy_mode = 'on-policy'

    def __init__(
            self,
            agent_spec,
            lr=5.0e-4,
            network_settings={
                'actor_continuous': {
                    'hidden_units': [32, 32],
                    'condition_sigma': False,
                    'log_std_bound': [-20, 2]
                },
                'actor_discrete': [32, 32]
            },
            **kwargs):
        super().__init__(agent_spec=agent_spec, **kwargs)
        if self.is_continuous:
            self.net = ActorMuLogstd(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_continuous']).to(
                    self.device)
        else:
            self.net = ActorDct(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_discrete']).to(
                    self.device)
        self.oplr = OPLR(self.net, lr, **self._oplr_params)

        self._trainer_modules.update(model=self.net, oplr=self.oplr)

    @iton
    def select_action(self, obs):
        output = self.net(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.net.get_rnncs()
        if self.is_continuous:
            mu, log_std = output  # [B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            action = dist.sample().clamp(-1, 1)  # [B, A]
        else:
            logits = output  # [B, A]
            norm_dist = td.Categorical(logits=logits)
            action = norm_dist.sample()  # [B,]

        acts_info = Data(action=action)
        if self.use_rnn:
            acts_info.update(rnncs=self.rnncs)
        return action, acts_info

    def _preprocess_BATCH(self, BATCH):  # [T, B, *]
        BATCH = super()._preprocess_BATCH(BATCH)
        BATCH.discounted_reward = discounted_sum(BATCH.reward,
                                                 self.gamma,
                                                 BATCH.done,
                                                 BATCH.begin_mask,
                                                 init_value=0.,
                                                 normalize=True)
        return BATCH

    @iton
    def _train(self, BATCH):  # [T, B, *]
        output = self.net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        if self.is_continuous:
            mu, log_std = output  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            log_act_prob = dist.log_prob(BATCH.action).unsqueeze(
                -1)  # [T, B, 1]
            entropy = dist.entropy().unsqueeze(-1)  # [T, B, 1]
        else:
            logits = output  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            log_act_prob = (logp_all * BATCH.action).sum(
                -1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(
                -1, keepdim=True)  # [T, B, 1]
        loss = -(log_act_prob * BATCH.discounted_reward).mean()
        self.oplr.optimize(loss)
        return {
            'LOSS/loss': loss,
            'Statistics/entropy': entropy.mean(),
            'LEARNING_RATE/lr': self.oplr.lr
        }
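
PG above is plain REINFORCE: the loss is the negative of log pi(a|s) weighted by the per-step discounted return that `_preprocess_BATCH` stores via `discounted_sum(..., normalize=True)`. The sketch below is a simplified, single-trajectory stand-in for that helper (assumed behaviour only; the real function also handles the [T, B] layout and `begin_mask`), illustrating the return-to-go computation plus standardization.

import numpy as np

def discounted_returns_to_go(rewards, gamma, dones, normalize=True):
    """Hypothetical reference for what discounted_sum is assumed to do here:
    per-step discounted return-to-go, reset at episode boundaries, then
    standardized to zero mean / unit variance when normalize=True."""
    returns = np.zeros_like(rewards, dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * (1.0 - dones[t]) * running
        returns[t] = running
    if normalize:
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    return returns

# e.g. a 4-step trajectory that terminates on the last step, gamma = 0.99
print(discounted_returns_to_go(np.array([1., 0., 0., 1.]),
                               0.99,
                               np.array([0., 0., 0., 1.])))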
Example #8
class AC(SarlOffPolicy):
    policy_mode = 'off-policy'

    # off-policy actor-critic
    def __init__(
            self,
            actor_lr=5.0e-4,
            critic_lr=1.0e-3,
            network_settings={
                'actor_continuous': {
                    'hidden_units': [64, 64],
                    'condition_sigma': False,
                    'log_std_bound': [-20, 2]
                },
                'actor_discrete': [32, 32],
                'critic': [32, 32]
            },
            **kwargs):
        super().__init__(**kwargs)

        if self.is_continuous:
            self.actor = ActorMuLogstd(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_continuous']).to(
                    self.device)
        else:
            self.actor = ActorDct(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_discrete']).to(
                    self.device)
        self.critic = CriticQvalueOne(
            self.obs_spec,
            rep_net_params=self._rep_net_params,
            action_dim=self.a_dim,
            network_settings=network_settings['critic']).to(self.device)

        self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
        self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params)

        self._trainer_modules.update(actor=self.actor,
                                     critic=self.critic,
                                     actor_oplr=self.actor_oplr,
                                     critic_oplr=self.critic_oplr)

    @iton
    def select_action(self, obs):
        output = self.actor(obs, rnncs=self.rnncs)  # [B, *]
        self.rnncs_ = self.actor.get_rnncs()
        if self.is_continuous:
            mu, log_std = output  # [B, *]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            action = dist.sample().clamp(-1, 1)  # [B, *]
            log_prob = dist.log_prob(action)  # [B,]
        else:
            logits = output  # [B, *]
            norm_dist = td.Categorical(logits=logits)
            action = norm_dist.sample()  # [B,]
            log_prob = norm_dist.log_prob(action)  # [B,]
        return action, Data(action=action, log_prob=log_prob)

    def random_action(self):
        actions = super().random_action()
        if self.is_continuous:
            self._acts_info.update(log_prob=np.full(self.n_copies,
                                                    np.log(0.5)))  # [B,]
        else:
            # uniform random over |A| discrete actions -> log-prob is log(1/|A|)
            self._acts_info.update(log_prob=np.full(self.n_copies,
                                                    np.log(1. / self.a_dim)))  # [B,]
        return actions

    @iton
    def _train(self, BATCH):
        q = self.critic(BATCH.obs, BATCH.action,
                        begin_mask=BATCH.begin_mask)  # [T, B, 1]
        if self.is_continuous:
            next_mu, _ = self.actor(BATCH.obs_,
                                    begin_mask=BATCH.begin_mask)  # [T, B, *]
            max_q_next = self.critic(
                BATCH.obs_, next_mu,
                begin_mask=BATCH.begin_mask).detach()  # [T, B, 1]
        else:
            logits = self.actor(BATCH.obs_,
                                begin_mask=BATCH.begin_mask)  # [T, B, *]
            max_a = logits.argmax(-1)  # [T, B]
            max_a_one_hot = F.one_hot(max_a, self.a_dim).float()  # [T, B, N]
            max_q_next = self.critic(BATCH.obs_, max_a_one_hot,
                                     begin_mask=BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = q - n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                     max_q_next,
                                     BATCH.begin_mask).detach()  # [T, B, 1]
        critic_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
        self.critic_oplr.optimize(critic_loss)

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs,
                                     begin_mask=BATCH.begin_mask)  # [T, B, *]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            log_prob = dist.log_prob(BATCH.action)  # [T, B]
            entropy = dist.entropy().mean()  # 1
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, *]
            logp_all = logits.log_softmax(-1)  # [T, B, *]
            log_prob = (logp_all * BATCH.action).sum(-1)  # [T, B]
            entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
        ratio = (log_prob - BATCH.log_prob).exp().detach()  # [T, B]
        actor_loss = -(ratio * log_prob *
                       q.squeeze(-1).detach()).mean()  # [T, B] => 1
        self.actor_oplr.optimize(actor_loss)

        return td_error, {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/q_max': q.max(),
            'Statistics/q_min': q.min(),
            'Statistics/q_mean': q.mean(),
            'Statistics/ratio': ratio.mean(),
            'Statistics/entropy': entropy
        }
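
AC trains its actor off-policy: `BATCH.log_prob` holds the behaviour policy's log-probability recorded at collection time, and the current policy's term log pi(a|s) * Q(s, a) is re-weighted by the detached importance ratio. A minimal sketch of that weighting on made-up discrete data (shapes and values are illustrative only, not library code):

import torch as th

# Toy illustration of the off-policy correction used in AC._train.
T, B, A = 3, 2, 4
logits_now = th.randn(T, B, A)                     # current policy logits
action_onehot = th.eye(A)[th.randint(A, (T, B))]   # replayed one-hot actions
old_log_prob = th.full((T, B), -1.4)               # stored behaviour log-probs
q = th.randn(T, B)                                 # critic estimate Q(s, a)

logp_all = logits_now.log_softmax(-1)              # [T, B, A]
log_prob = (logp_all * action_onehot).sum(-1)      # [T, B]
ratio = (log_prob - old_log_prob).exp().detach()   # importance weight, [T, B]
actor_loss = -(ratio * log_prob * q.detach()).mean()
print(actor_loss)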