Example #1
class DDPG(SarlOffPolicy):
    """
    Deep Deterministic Policy Gradient, https://arxiv.org/abs/1509.02971
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 polyak=0.995,
                 noise_action='ou',
                 noise_params={'sigma': 0.2},
                 use_target_action_noise=False,
                 actor_lr=5.0e-4,
                 critic_lr=1.0e-3,
                 discrete_tau=1.0,
                 network_settings={
                     'actor_continuous': [32, 32],
                     'actor_discrete': [32, 32],
                     'q': [32, 32]
                 },
                 **kwargs):
        super().__init__(**kwargs)
        self.polyak = polyak
        self.discrete_tau = discrete_tau
        self.use_target_action_noise = use_target_action_noise

        if self.is_continuous:
            actor = ActorDPG(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_continuous'])
            self.target_noised_action = ClippedNormalNoisedAction(
                sigma=0.2, noise_bound=0.2)
            if noise_action in ['ou', 'clip_normal']:
                self.noised_action = Noise_action_REGISTER[noise_action](
                    **noise_params)
            elif noise_action == 'normal':
                self.noised_action = self.target_noised_action
            else:
                raise Exception(
                    f'unsupported noise action type: {noise_action}')
        else:
            actor = ActorDct(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_discrete'])
        self.actor = TargetTwin(actor, self.polyak).to(self.device)
        self.critic = TargetTwin(
            CriticQvalueOne(self.obs_spec,
                            rep_net_params=self._rep_net_params,
                            action_dim=self.a_dim,
                            network_settings=network_settings['q']),
            self.polyak).to(self.device)

        self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
        self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params)
        self._trainer_modules.update(actor=self.actor,
                                     critic=self.critic,
                                     actor_oplr=self.actor_oplr,
                                     critic_oplr=self.critic_oplr)

    def episode_reset(self):
        super().episode_reset()
        if self.is_continuous:
            self.noised_action.reset()

    @iton
    def select_action(self, obs):
        output = self.actor(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.actor.get_rnncs()
        if self.is_continuous:
            mu = output  # [B, A]
            pi = self.noised_action(mu)  # [B, A]
        else:
            logits = output  # [B, A]
            mu = logits.argmax(-1)  # [B, ]
            cate_dist = td.Categorical(logits=logits)
            pi = cate_dist.sample()  # [B,]
        actions = pi if self._is_train_mode else mu
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        if self.is_continuous:
            action_target = self.actor.t(
                BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            if self.use_target_action_noise:
                action_target = self.target_noised_action(
                    action_target)  # [T, B, A]
        else:
            target_logits = self.actor.t(
                BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            target_cate_dist = td.Categorical(logits=target_logits)
            target_pi = target_cate_dist.sample()  # [T, B]
            action_target = F.one_hot(target_pi,
                                      self.a_dim).float()  # [T, B, A]
        q = self.critic(BATCH.obs, BATCH.action,
                        begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q_target = self.critic.t(BATCH.obs_,
                                 action_target,
                                 begin_mask=BATCH.begin_mask)  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, q_target,
                             BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = dc_r - q  # [T, B, 1]
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
        self.critic_oplr.optimize(q_loss)

        if self.is_continuous:
            mu = self.actor(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                -1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            mu = _pi_diff + _pi  # [T, B, A]
        q_actor = self.critic(BATCH.obs, mu,
                              begin_mask=BATCH.begin_mask)  # [T, B, 1]
        actor_loss = -q_actor.mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        return td_error, {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': q_loss,
            'Statistics/q_min': q.min(),
            'Statistics/q_mean': q.mean(),
            'Statistics/q_max': q.max()
        }

    def _after_train(self):
        super()._after_train()
        self.actor.sync()
        self.critic.sync()
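The DDPG example above leans on two reusable tricks: Polyak-averaged target networks (wrapped by TargetTwin and refreshed via sync()) and, for the discrete branch, a straight-through Gumbel-softmax so the actor's gradient can pass through the argmax. Below is a minimal standalone sketch of both, assuming the usual target <- polyak * target + (1 - polyak) * online rule; it is an illustration, not the TargetTwin implementation itself.

import torch
import torch.nn.functional as F

def soft_update(target_net, online_net, polyak=0.995):
    # target <- polyak * target + (1 - polyak) * online, applied parameter-wise
    with torch.no_grad():
        for t_p, o_p in zip(target_net.parameters(), online_net.parameters()):
            t_p.mul_(polyak).add_((1.0 - polyak) * o_p)

def straight_through_gumbel(logits, tau=1.0):
    # forward pass: hard one-hot action; backward pass: gradient of the soft relaxation
    gumbel = torch.distributions.Gumbel(0., 1.).sample(logits.shape)
    soft = ((logits.log_softmax(-1) + gumbel) / tau).softmax(-1)
    hard = F.one_hot(soft.argmax(-1), logits.shape[-1]).float()
    return (hard - soft).detach() + soft

online, target = torch.nn.Linear(4, 2), torch.nn.Linear(4, 2)
target.load_state_dict(online.state_dict())
soft_update(target, online)                      # slow-moving target parameters

logits = torch.randn(8, 3, requires_grad=True)
action = straight_through_gumbel(logits)         # one-hot forward, soft gradient
action.sum().backward()                          # gradients reach `logits`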
Example #2
class OC(SarlOffPolicy):
    """
    The Option-Critic Architecture. http://arxiv.org/abs/1609.05140
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 q_lr=5.0e-3,
                 intra_option_lr=5.0e-4,
                 termination_lr=5.0e-4,
                 use_eps_greedy=False,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 boltzmann_temperature=1.0,
                 options_num=4,
                 ent_coff=0.01,
                 double_q=False,
                 use_baseline=True,
                 terminal_mask=True,
                 termination_regularizer=0.01,
                 assign_interval=1000,
                 network_settings={
                     'q': [32, 32],
                     'intra_option': [32, 32],
                     'termination': [32, 32]
                 },
                 **kwargs):
        super().__init__(**kwargs)
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self._max_train_step)
        self.assign_interval = assign_interval
        self.options_num = options_num
        self.termination_regularizer = termination_regularizer
        self.ent_coff = ent_coff
        self.use_baseline = use_baseline
        self.terminal_mask = terminal_mask
        self.double_q = double_q
        self.boltzmann_temperature = boltzmann_temperature
        self.use_eps_greedy = use_eps_greedy

        self.q_net = TargetTwin(
            CriticQvalueAll(self.obs_spec,
                            rep_net_params=self._rep_net_params,
                            output_shape=self.options_num,
                            network_settings=network_settings['q'])).to(
                                self.device)

        self.intra_option_net = OcIntraOption(
            self.obs_spec,
            rep_net_params=self._rep_net_params,
            output_shape=self.a_dim,
            options_num=self.options_num,
            network_settings=network_settings['intra_option']).to(self.device)
        self.termination_net = CriticQvalueAll(
            self.obs_spec,
            rep_net_params=self._rep_net_params,
            output_shape=self.options_num,
            network_settings=network_settings['termination'],
            out_act='sigmoid').to(self.device)

        if self.is_continuous:
            # https://discuss.pytorch.org/t/valueerror-cant-optimize-a-non-leaf-tensor/21751
            # https://blog.csdn.net/nkhgl/article/details/100047276
            self.log_std = th.as_tensor(
                np.full((self.options_num, self.a_dim),
                        -0.5)).requires_grad_().to(self.device)  # [P, A]
            self.intra_option_oplr = OPLR(
                [self.intra_option_net, self.log_std], intra_option_lr,
                **self._oplr_params)
        else:
            self.intra_option_oplr = OPLR(self.intra_option_net,
                                          intra_option_lr, **self._oplr_params)
        self.q_oplr = OPLR(self.q_net, q_lr, **self._oplr_params)
        self.termination_oplr = OPLR(self.termination_net, termination_lr,
                                     **self._oplr_params)

        self._trainer_modules.update(q_net=self.q_net,
                                     intra_option_net=self.intra_option_net,
                                     termination_net=self.termination_net,
                                     q_oplr=self.q_oplr,
                                     intra_option_oplr=self.intra_option_oplr,
                                     termination_oplr=self.termination_oplr)
        self.options = self.new_options = self._generate_random_options()

    def _generate_random_options(self):
        # [B,]
        return th.tensor(np.random.randint(0, self.options_num,
                                           self.n_copies)).to(self.device)

    def episode_step(self, obs: Data, env_rets: Data, begin_mask: np.ndarray):
        super().episode_step(obs, env_rets, begin_mask)
        self.options = self.new_options

    @iton
    def select_action(self, obs):
        q = self.q_net(obs, rnncs=self.rnncs)  # [B, P]
        self.rnncs_ = self.q_net.get_rnncs()
        pi = self.intra_option_net(obs, rnncs=self.rnncs)  # [B, P, A]
        beta = self.termination_net(obs, rnncs=self.rnncs)  # [B, P]
        options_onehot = F.one_hot(self.options,
                                   self.options_num).float()  # [B, P]
        options_onehot_expanded = options_onehot.unsqueeze(-1)  # [B, P, 1]
        pi = (pi * options_onehot_expanded).sum(-2)  # [B, A]
        if self.is_continuous:
            mu = pi.tanh()  # [B, A]
            log_std = self.log_std[self.options]  # [B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            actions = dist.sample().clamp(-1, 1)  # [B, A]
        else:
            pi = pi / self.boltzmann_temperature  # [B, A]
            dist = td.Categorical(logits=pi)
            actions = dist.sample()  # [B, ]
        max_options = q.argmax(-1).long()  # [B, P] => [B, ]
        if self.use_eps_greedy:
            # epsilon greedy
            if self._is_train_mode and self.expl_expt_mng.is_random(
                    self._cur_train_step):
                self.new_options = self._generate_random_options()
            else:
                self.new_options = max_options
        else:
            beta_probs = (beta * options_onehot).sum(-1)  # [B, P] => [B,]
            beta_dist = td.Bernoulli(probs=beta_probs)
            self.new_options = th.where(beta_dist.sample() < 1, self.options,
                                        max_options)
        return actions, Data(action=actions,
                             last_options=self.options,
                             options=self.new_options)

    def random_action(self):
        actions = super().random_action()
        self._acts_info.update(
            last_options=np.random.randint(0, self.options_num, self.n_copies),
            options=np.random.randint(0, self.options_num, self.n_copies))
        return actions

    def _preprocess_BATCH(self, BATCH):  # [T, B, *]
        BATCH = super()._preprocess_BATCH(BATCH)
        BATCH.last_options = int2one_hot(BATCH.last_options, self.options_num)
        BATCH.options = int2one_hot(BATCH.options, self.options_num)
        return BATCH

    @iton
    def _train(self, BATCH):
        q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, P]
        q_next = self.q_net.t(BATCH.obs_,
                              begin_mask=BATCH.begin_mask)  # [T, B, P]
        beta_next = self.termination_net(
            BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, P]

        qu_eval = (q * BATCH.options).sum(-1, keepdim=True)  # [T, B, 1]
        beta_s_ = (beta_next * BATCH.options).sum(-1,
                                                  keepdim=True)  # [T, B, 1]
        q_s_ = (q_next * BATCH.options).sum(-1, keepdim=True)  # [T, B, 1]
        # https://github.com/jeanharb/option_critic/blob/5d6c81a650a8f452bc8ad3250f1f211d317fde8c/neural_net.py#L94
        if self.double_q:
            q_ = self.q_net(BATCH.obs_,
                            begin_mask=BATCH.begin_mask)  # [T, B, P]
            # [T, B, P] => [T, B] => [T, B, P]
            max_a_idx = F.one_hot(q_.argmax(-1), self.options_num).float()
            q_s_max = (q_next * max_a_idx).sum(-1, keepdim=True)  # [T, B, 1]
        else:
            q_s_max = q_next.max(-1, keepdim=True)[0]  # [T, B, 1]
        u_target = (1 - beta_s_) * q_s_ + beta_s_ * q_s_max  # [T, B, 1]
        qu_target = n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                  u_target,
                                  BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = qu_target - qu_eval  # gradient : q   [T, B, 1]
        q_loss = (td_error.square() *
                  BATCH.get('isw', 1.0)).mean()  # [T, B, 1] => 1
        self.q_oplr.optimize(q_loss)

        q_s = qu_eval.detach()  # [T, B, 1]
        # https://github.com/jeanharb/option_critic/blob/5d6c81a650a8f452bc8ad3250f1f211d317fde8c/neural_net.py#L130
        if self.use_baseline:
            adv = (qu_target - q_s).detach()  # [T, B, 1]
        else:
            adv = qu_target.detach()  # [T, B, 1]
        # [T, B, P] => [T, B, P, 1]
        options_onehot_expanded = BATCH.options.unsqueeze(-1)
        pi = self.intra_option_net(BATCH.obs,
                                   begin_mask=BATCH.begin_mask)  # [T, B, P, A]
        # [T, B, P, A] => [T, B, A]
        pi = (pi * options_onehot_expanded).sum(-2)
        if self.is_continuous:
            mu = pi.tanh()  # [T, B, A]
            log_std = self.log_std[BATCH.options.argmax(-1)]  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            log_p = dist.log_prob(BATCH.action).unsqueeze(-1)  # [T, B, 1]
            entropy = dist.entropy().unsqueeze(-1)  # [T, B, 1]
        else:
            pi = pi / self.boltzmann_temperature  # [T, B, A]
            log_pi = pi.log_softmax(-1)  # [T, B, A]
            entropy = -(log_pi.exp() * log_pi).sum(-1,
                                                   keepdim=True)  # [T, B, 1]
            log_p = (BATCH.action * log_pi).sum(-1, keepdim=True)  # [T, B, 1]
        pi_loss = -(log_p * adv + self.ent_coff * entropy).mean()  # 1

        beta = self.termination_net(BATCH.obs,
                                    begin_mask=BATCH.begin_mask)  # [T, B, P]
        beta_s = (beta * BATCH.last_options).sum(-1, keepdim=True)  # [T, B, 1]
        if self.use_eps_greedy:
            v_s = q.max(
                -1,
                keepdim=True)[0] - self.termination_regularizer  # [T, B, 1]
        else:
            v_s = (1 - beta_s) * q_s + beta_s * q.max(
                -1, keepdim=True)[0]  # [T, B, 1]
            # v_s = q.mean(-1, keepdim=True)  # [T, B, 1]
        beta_loss = beta_s * (q_s - v_s).detach()  # [T, B, 1]
        # https://github.com/lweitkamp/option-critic-pytorch/blob/0c57da7686f8903ed2d8dded3fae832ee9defd1a/option_critic.py#L238
        if self.terminal_mask:
            beta_loss *= (1 - BATCH.done)  # [T, B, 1]
        beta_loss = beta_loss.mean()  # 1

        self.intra_option_oplr.optimize(pi_loss)
        self.termination_oplr.optimize(beta_loss)

        return td_error, {
            'LEARNING_RATE/q_lr': self.q_oplr.lr,
            'LEARNING_RATE/intra_option_lr': self.intra_option_oplr.lr,
            'LEARNING_RATE/termination_lr': self.termination_oplr.lr,
            # 'Statistics/option': self.options[0],
            'LOSS/q_loss': q_loss,
            'LOSS/pi_loss': pi_loss,
            'LOSS/beta_loss': beta_loss,
            'Statistics/q_option_max': q_s.max(),
            'Statistics/q_option_min': q_s.min(),
            'Statistics/q_option_mean': q_s.mean()
        }

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            self.q_net.sync()
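Here is a tiny standalone sketch of the option-value target built in _train above; shapes follow the [T, B, P] comments and the tensors are random stand-ins. The agent keeps the current option's value with probability 1 - beta and switches to the greedy option otherwise.

import torch

T, B, P = 5, 3, 4                       # time steps, batch, number of options
q_next = torch.randn(T, B, P)           # Q(s', .) from the target network
beta_next = torch.rand(T, B, P)         # termination probabilities beta(s', .)
options = torch.nn.functional.one_hot(
    torch.randint(0, P, (T, B)), P).float()            # active option, one-hot

q_s_ = (q_next * options).sum(-1, keepdim=True)        # Q(s', w)        [T, B, 1]
beta_s_ = (beta_next * options).sum(-1, keepdim=True)  # beta(s', w)     [T, B, 1]
q_s_max = q_next.max(-1, keepdim=True)[0]              # max_w' Q(s',w') [T, B, 1]
u_target = (1 - beta_s_) * q_s_ + beta_s_ * q_s_max    # U(s', w)        [T, B, 1]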
Example #3
class C51(SarlOffPolicy):
    """
    Category 51, https://arxiv.org/abs/1707.06887
    No double, no dueling, no noisy net.
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 v_min=-10,
                 v_max=10,
                 atoms=51,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=1000,
                 network_settings=[128, 128],
                 **kwargs):
        super().__init__(**kwargs)
        assert not self.is_continuous, 'C51 only supports discrete action spaces'
        self._v_min = v_min
        self._v_max = v_max
        self._atoms = atoms
        self._delta_z = (self._v_max - self._v_min) / (self._atoms - 1)
        self._z = th.linspace(self._v_min, self._v_max,
                              self._atoms).float().to(self.device)  # [N,]
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self._max_train_step)
        self.assign_interval = assign_interval
        self.q_net = TargetTwin(
            C51Distributional(self.obs_spec,
                              rep_net_params=self._rep_net_params,
                              action_dim=self.a_dim,
                              atoms=self._atoms,
                              network_settings=network_settings)).to(
                                  self.device)
        self.oplr = OPLR(self.q_net, lr, **self._oplr_params)
        self._trainer_modules.update(model=self.q_net, oplr=self.oplr)

    @iton
    def select_action(self, obs):
        feat = self.q_net(obs, rnncs=self.rnncs)  # [B, A, N]
        self.rnncs_ = self.q_net.get_rnncs()

        if self._is_train_mode and self.expl_expt_mng.is_random(
                self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            q = (self._z * feat).sum(-1)  # [B, A, N] * [N,] => [B, A]
            actions = q.argmax(-1)  # [B,]
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        q_dist = self.q_net(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A, N]
        # [T, B, A, N] * [T, B, A, 1] => [T, B, A, N] => [T, B, N]
        q_dist = (q_dist * BATCH.action.unsqueeze(-1)).sum(-2)

        q_eval = (q_dist * self._z).sum(-1)  # [T, B, N] * [N,] => [T, B]

        target_q_dist = self.q_net.t(
            BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A, N]
        # [T, B, A, N] * [1, N] => [T, B, A]
        target_q = (target_q_dist * self._z).sum(-1)
        a_ = target_q.argmax(-1)  # [T, B]
        a_onehot = F.one_hot(a_, self.a_dim).float()  # [T, B, A]
        # [T, B, A, N] * [T, B, A, 1] => [T, B, A, N] => [T, B, N]
        target_q_dist = (target_q_dist * a_onehot.unsqueeze(-1)).sum(-2)

        target = n_step_return(
            BATCH.reward.repeat(1, 1, self._atoms), self.gamma,
            BATCH.done.repeat(1, 1, self._atoms), target_q_dist,
            BATCH.begin_mask.repeat(1, 1, self._atoms)).detach()  # [T, B, N]
        target = target.clamp(self._v_min, self._v_max)  # [T, B, N]
        # An amazing trick for calculating the projection gracefully.
        # ref: https://github.com/ShangtongZhang/DeepRL
        target_dist = (
            1 - (target.unsqueeze(-1) - self._z.view(1, 1, -1, 1)).abs() /
            self._delta_z).clamp(0, 1) * target_q_dist.unsqueeze(
                -1)  # [T, B, N, 1]
        target_dist = target_dist.sum(-1)  # [T, B, N]

        _cross_entropy = -(target_dist * th.log(q_dist + th.finfo().eps)).sum(
            -1, keepdim=True)  # [T, B, 1]
        loss = (_cross_entropy * BATCH.get('isw', 1.0)).mean()  # 1

        self.oplr.optimize(loss)
        return _cross_entropy, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            self.q_net.sync()
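The one-liner that builds target_dist above is dense, so here is a minimal standalone sketch of the same projection trick (ref: https://github.com/ShangtongZhang/DeepRL) on a single distribution with illustrative reward 1.0 and gamma 0.99: each clipped target support point Tz_i spreads its mass onto the neighbouring fixed atoms z_j with weight clamp(1 - |Tz_i - z_j| / delta_z, 0, 1).

import torch

v_min, v_max, atoms = -10.0, 10.0, 51
z = torch.linspace(v_min, v_max, atoms)                 # fixed support [N]
delta_z = (v_max - v_min) / (atoms - 1)

p = torch.softmax(torch.randn(atoms), -1)               # target probabilities [N]
tz = (1.0 + 0.99 * z).clamp(v_min, v_max)               # shifted support r + gamma*z [N]

# weight[i, j] = how much of atom i's mass lands on fixed atom j
weight = (1 - (tz.unsqueeze(-1) - z.unsqueeze(0)).abs() / delta_z).clamp(0, 1)  # [N, N]
projected = (weight * p.unsqueeze(-1)).sum(0)           # projected distribution [N]
assert torch.allclose(projected.sum(), p.sum())         # total mass is preserved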
Example #4
class CuriosityModel(nn.Module):
    """
    Model of Intrinsic Curiosity Module (ICM).
    Curiosity-driven Exploration by Self-supervised Prediction, https://arxiv.org/abs/1705.05363
    """
    def __init__(self,
                 obs_spec,
                 rep_net_params,
                 is_continuous,
                 action_dim,
                 *,
                 eta=0.2,
                 lr=1.0e-3,
                 beta=0.2):
        """
        params:
            is_continuous: sepecify whether action space is continuous(True) or discrete(False)
            action_dim: dimension of action

            eta: weight of intrinsic reward
            lr: the learning rate of curiosity model
            beta: weight factor of loss between inverse_dynamic_net and forward_net
        """
        super().__init__()
        self.eta = eta
        self.beta = beta
        self.is_continuous = is_continuous
        self.action_dim = action_dim

        self.rep_net = RepresentationNetwork(obs_spec=obs_spec,
                                             rep_net_params=rep_net_params)

        self.feat_dim = self.rep_net.h_dim

        # S, S' => A
        self.inverse_dynamic_net = nn.Sequential(
            nn.Linear(self.feat_dim * 2, self.feat_dim * 2),
            Act_REGISTER[default_act](),
            nn.Linear(self.feat_dim * 2, action_dim))
        if self.is_continuous:
            self.inverse_dynamic_net.add_module('tanh', nn.Tanh())

        # S, A => S'
        self.forward_net = nn.Sequential(
            nn.Linear(self.feat_dim + action_dim,
                      self.feat_dim), Act_REGISTER[default_act](),
            nn.Linear(self.feat_dim, self.feat_dim))

        self.oplr = OPLR(
            models=[self.rep_net, self.inverse_dynamic_net, self.forward_net],
            lr=lr)

    def forward(self, BATCH):
        fs, _ = self.rep_net(BATCH.obs,
                             begin_mask=BATCH.begin_mask)  # [T, B, *]
        fs_, _ = self.rep_net(BATCH.obs_,
                              begin_mask=BATCH.begin_mask)  # [T, B, *]

        # [T, B, *] <S, A> => S'
        s_eval = self.forward_net(th.cat((fs, BATCH.action), -1))
        LF = 0.5 * (fs_ - s_eval).square().sum(-1, keepdim=True)  # [T, B, 1]
        intrinsic_reward = self.eta * LF
        loss_forward = LF.mean()  # 1

        a_eval = self.inverse_dynamic_net(th.cat((fs, fs_), -1))  # [T, B, *]
        if self.is_continuous:
            loss_inverse = 0.5 * \
                           (a_eval - BATCH.action).square().sum(-1).mean()
        else:
            idx = BATCH.action.argmax(-1)  # [T, B]
            loss_inverse = F.cross_entropy(a_eval.view(-1, self.action_dim),
                                           idx.view(-1))  # 1

        loss = (1 - self.beta) * loss_inverse + self.beta * loss_forward
        self.oplr.optimize(loss)
        summaries = {
            'LOSS/curiosity_loss': loss,
            'LOSS/forward_loss': loss_forward,
            'LOSS/inverse_loss': loss_inverse
        }
        return intrinsic_reward, summaries
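A small standalone sketch of the two losses assembled in forward(), with random stand-in tensors in place of the encoder and the two heads (the stand-ins are hypothetical, only the loss arithmetic mirrors the class): the forward-model error doubles as the intrinsic reward, and beta trades it off against the inverse-dynamics loss.

import torch
import torch.nn.functional as F

T, B, feat_dim, action_dim = 5, 3, 8, 4
eta, beta = 0.2, 0.2

fs = torch.randn(T, B, feat_dim)            # phi(s)
fs_ = torch.randn(T, B, feat_dim)           # phi(s')
s_pred = torch.randn(T, B, feat_dim)        # stand-in for forward_net(phi(s), a)
a_logits = torch.randn(T, B, action_dim)    # stand-in for inverse_dynamic_net(phi(s), phi(s'))
a_taken = torch.randint(0, action_dim, (T, B))

lf = 0.5 * (fs_ - s_pred).square().sum(-1, keepdim=True)   # [T, B, 1]
intrinsic_reward = eta * lf                                 # prediction error as bonus
loss_forward = lf.mean()
loss_inverse = F.cross_entropy(a_logits.view(-1, action_dim), a_taken.view(-1))
loss = (1 - beta) * loss_inverse + beta * loss_forward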
Example #5
class IQN(SarlOffPolicy):
    """
    Implicit Quantile Networks, https://arxiv.org/abs/1806.06923
    Double DQN
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 online_quantiles=8,
                 target_quantiles=8,
                 select_quantiles=32,
                 quantiles_idx=64,
                 huber_delta=1.,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=2,
                 network_settings={
                     'q_net': [128, 64],
                     'quantile': [128, 64],
                     'tile': [64]
                 },
                 **kwargs):
        super().__init__(**kwargs)
        assert not self.is_continuous, 'IQN only supports discrete action spaces'
        self.online_quantiles = online_quantiles
        self.target_quantiles = target_quantiles
        self.select_quantiles = select_quantiles
        self.quantiles_idx = quantiles_idx
        self.huber_delta = huber_delta
        self.assign_interval = assign_interval
        self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init,
                                                          eps_mid=eps_mid,
                                                          eps_final=eps_final,
                                                          init2mid_annealing_step=init2mid_annealing_step,
                                                          max_step=self._max_train_step)
        self.q_net = TargetTwin(IqnNet(self.obs_spec,
                                       rep_net_params=self._rep_net_params,
                                       action_dim=self.a_dim,
                                       quantiles_idx=self.quantiles_idx,
                                       network_settings=network_settings)).to(self.device)
        self.oplr = OPLR(self.q_net, lr, **self._oplr_params)
        self._trainer_modules.update(model=self.q_net,
                                     oplr=self.oplr)

    @iton
    def select_action(self, obs):
        _, select_quantiles_tiled = self._generate_quantiles(  # [N*B, X]
            batch_size=self.n_copies,
            quantiles_num=self.select_quantiles
        )
        q_values = self.q_net(obs, select_quantiles_tiled, rnncs=self.rnncs)  # [N, B, A]
        self.rnncs_ = self.q_net.get_rnncs()

        if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            # [N, B, A] => [B, A] => [B,]
            actions = q_values.mean(0).argmax(-1)
        return actions, Data(action=actions)

    def _generate_quantiles(self, batch_size, quantiles_num):
        _quantiles = th.rand([quantiles_num * batch_size, 1])  # [N*B, 1]
        _quantiles_tiled = _quantiles.repeat(1, self.quantiles_idx)  # [N*B, 1] => [N*B, X]

        # pi * i * tau [N*B, X] * [X, ] => [N*B, X]
        _quantiles_tiled = th.arange(self.quantiles_idx) * np.pi * _quantiles_tiled
        _quantiles_tiled.cos_()  # [N*B, X]

        _quantiles = _quantiles.view(batch_size, quantiles_num, 1)  # [N*B, 1] => [B, N, 1]
        return _quantiles, _quantiles_tiled  # [B, N, 1], [N*B, X]

    @iton
    def _train(self, BATCH):
        time_step = BATCH.reward.shape[0]
        batch_size = BATCH.reward.shape[1]

        quantiles, quantiles_tiled = self._generate_quantiles(  # [T*B, N, 1], [N*T*B, X]
            batch_size=time_step * batch_size,
            quantiles_num=self.online_quantiles)
        # [T*B, N, 1] => [T, B, N, 1]
        quantiles = quantiles.view(time_step, batch_size, -1, 1)
        quantiles_tiled = quantiles_tiled.view(time_step, -1, self.quantiles_idx)  # [N*T*B, X] => [T, N*B, X]

        quantiles_value = self.q_net(BATCH.obs, quantiles_tiled, begin_mask=BATCH.begin_mask)  # [T, N, B, A]
        # [T, N, B, A] => [N, T, B, A] * [T, B, A] => [N, T, B, 1]
        quantiles_value = (quantiles_value.swapaxes(0, 1) * BATCH.action).sum(-1, keepdim=True)
        q_eval = quantiles_value.mean(0)  # [N, T, B, 1] => [T, B, 1]

        _, select_quantiles_tiled = self._generate_quantiles(  # [N*T*B, X]
            batch_size=time_step * batch_size,
            quantiles_num=self.select_quantiles)
        select_quantiles_tiled = select_quantiles_tiled.view(
            time_step, -1, self.quantiles_idx)  # [N*T*B, X] => [T, N*B, X]

        q_values = self.q_net(
            BATCH.obs_, select_quantiles_tiled, begin_mask=BATCH.begin_mask)  # [T, N, B, A]
        q_values = q_values.mean(1)  # [T, N, B, A] => [T, B, A]
        next_max_action = q_values.argmax(-1)  # [T, B]
        next_max_action = F.one_hot(
            next_max_action, self.a_dim).float()  # [T, B, A]

        _, target_quantiles_tiled = self._generate_quantiles(  # [N'*T*B, X]
            batch_size=time_step * batch_size,
            quantiles_num=self.target_quantiles)
        target_quantiles_tiled = target_quantiles_tiled.view(
            time_step, -1, self.quantiles_idx)  # [N'*T*B, X] => [T, N'*B, X]
        target_quantiles_value = self.q_net.t(BATCH.obs_, target_quantiles_tiled,
                                              begin_mask=BATCH.begin_mask)  # [T, N', B, A]
        target_quantiles_value = target_quantiles_value.swapaxes(0, 1)  # [T, N', B, A] => [N', T, B, A]
        target_quantiles_value = (target_quantiles_value * next_max_action).sum(-1, keepdim=True)  # [N', T, B, 1]

        target_q = target_quantiles_value.mean(0)  # [T, B, 1]
        q_target = n_step_return(BATCH.reward,  # [T, B, 1]
                                 self.gamma,
                                 BATCH.done,  # [T, B, 1]
                                 target_q,  # [T, B, 1]
                                 BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1]

        # [N', T, B, 1] => [N', T, B]
        target_quantiles_value = target_quantiles_value.squeeze(-1)
        target_quantiles_value = target_quantiles_value.permute(
            1, 2, 0)  # [N', T, B] => [T, B, N']
        quantiles_value_target = n_step_return(BATCH.reward.repeat(1, 1, self.target_quantiles),
                                               self.gamma,
                                               BATCH.done.repeat(1, 1, self.target_quantiles),
                                               target_quantiles_value,
                                               BATCH.begin_mask.repeat(1, 1,
                                                                       self.target_quantiles)).detach()  # [T, B, N']
        # [T, B, N'] => [T, B, 1, N']
        quantiles_value_target = quantiles_value_target.unsqueeze(-2)
        quantiles_value_online = quantiles_value.permute(1, 2, 0, 3)  # [N, T, B, 1] => [T, B, N, 1]
        # [T, B, N, 1] - [T, B, 1, N'] => [T, B, N, N']
        quantile_error = quantiles_value_online - quantiles_value_target
        huber = F.huber_loss(quantiles_value_online, quantiles_value_target,
                             reduction="none", delta=self.huber_delta)  # [T, B, N, N]
        # [T, B, N, 1] - [T, B, N, N'] => [T, B, N, N']
        huber_abs = (quantiles - quantile_error.detach().le(0.).float()).abs()
        loss = (huber_abs * huber).mean(-1)  # [T, B, N, N'] => [T, B, N]
        loss = loss.sum(-1, keepdim=True)  # [T, B, N] => [T, B, 1]

        loss = (loss * BATCH.get('isw', 1.0)).mean()  # 1
        self.oplr.optimize(loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            self.q_net.sync()
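The loss assembled at the end of _train is the pairwise quantile Huber loss; below is a compact standalone sketch with random stand-ins, dropping the time axis for brevity: online quantile values theta_i at sampled fractions tau_i against target quantile values theta'_j, weighted by |tau_i - 1{delta_ij < 0}|.

import torch
import torch.nn.functional as F

B, N, N_ = 3, 8, 8                               # batch, online / target quantiles
tau = torch.rand(B, N, 1)                        # sampled fractions        [B, N, 1]
theta = torch.randn(B, N, 1)                     # online quantile values   [B, N, 1]
theta_target = torch.randn(B, 1, N_)             # target quantile values   [B, 1, N']

delta = theta - theta_target                                         # [B, N, N']
huber = F.huber_loss(theta.expand(-1, -1, N_),
                     theta_target.expand(-1, N, -1),
                     reduction='none', delta=1.0)                     # [B, N, N']
rho = (tau - (delta.detach() < 0).float()).abs() * huber              # [B, N, N']
loss = rho.mean(-1).sum(-1).mean()               # avg over j, sum over i, avg batch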
Example #6
class IOC(SarlOffPolicy):
    """
    Learning Options with Interest Functions, https://www.aaai.org/ojs/index.php/AAAI/article/view/5114/4987 
    Options of Interest: Temporal Abstraction with Interest Functions, http://arxiv.org/abs/2001.00271
    """
    policy_mode = 'off-policy'

    def __init__(
            self,
            q_lr=5.0e-3,
            intra_option_lr=5.0e-4,
            termination_lr=5.0e-4,
            interest_lr=5.0e-4,
            boltzmann_temperature=1.0,
            options_num=4,
            ent_coff=0.01,
            double_q=False,
            use_baseline=True,
            terminal_mask=True,
            termination_regularizer=0.01,
            assign_interval=1000,
            network_settings={
                'q': [32, 32],
                'intra_option': [32, 32],
                'termination': [32, 32],
                'interest': [32, 32]
            },
            **kwargs):
        super().__init__(**kwargs)
        self.assign_interval = assign_interval
        self.options_num = options_num
        self.termination_regularizer = termination_regularizer
        self.ent_coff = ent_coff
        self.use_baseline = use_baseline
        self.terminal_mask = terminal_mask
        self.double_q = double_q
        self.boltzmann_temperature = boltzmann_temperature

        self.q_net = TargetTwin(
            CriticQvalueAll(self.obs_spec,
                            rep_net_params=self._rep_net_params,
                            output_shape=self.options_num,
                            network_settings=network_settings['q'])).to(
                                self.device)

        self.intra_option_net = OcIntraOption(
            self.obs_spec,
            rep_net_params=self._rep_net_params,
            output_shape=self.a_dim,
            options_num=self.options_num,
            network_settings=network_settings['intra_option']).to(self.device)
        self.termination_net = CriticQvalueAll(
            self.obs_spec,
            rep_net_params=self._rep_net_params,
            output_shape=self.options_num,
            network_settings=network_settings['termination'],
            out_act='sigmoid').to(self.device)
        self.interest_net = CriticQvalueAll(
            self.obs_spec,
            rep_net_params=self._rep_net_params,
            output_shape=self.options_num,
            network_settings=network_settings['interest'],
            out_act='sigmoid').to(self.device)

        if self.is_continuous:
            self.log_std = th.as_tensor(
                np.full((self.options_num, self.a_dim),
                        -0.5)).requires_grad_().to(self.device)  # [P, A]
            self.intra_option_oplr = OPLR(
                [self.intra_option_net, self.log_std], intra_option_lr,
                **self._oplr_params)
        else:
            self.intra_option_oplr = OPLR(self.intra_option_net,
                                          intra_option_lr, **self._oplr_params)

        self.q_oplr = OPLR(self.q_net, q_lr, **self._oplr_params)
        self.termination_oplr = OPLR(self.termination_net, termination_lr,
                                     **self._oplr_params)
        self.interest_oplr = OPLR(self.interest_net, interest_lr,
                                  **self._oplr_params)

        self._trainer_modules.update(q_net=self.q_net,
                                     intra_option_net=self.intra_option_net,
                                     termination_net=self.termination_net,
                                     interest_net=self.interest_net,
                                     q_oplr=self.q_oplr,
                                     intra_option_oplr=self.intra_option_oplr,
                                     termination_oplr=self.termination_oplr,
                                     interest_oplr=self.interest_oplr)

        self.options = self.new_options = th.tensor(
            np.random.randint(0, self.options_num,
                              self.n_copies)).to(self.device)

    def episode_step(self, obs: Data, env_rets: Data, begin_mask: np.ndarray):
        super().episode_step(obs, env_rets, begin_mask)
        self.options = self.new_options

    @iton
    def select_action(self, obs):
        q = self.q_net(obs, rnncs=self.rnncs)  # [B, P]
        self.rnncs_ = self.q_net.get_rnncs()
        pi = self.intra_option_net(obs, rnncs=self.rnncs)  # [B, P, A]
        options_onehot = F.one_hot(self.options,
                                   self.options_num).float()  # [B, P]
        options_onehot_expanded = options_onehot.unsqueeze(-1)  # [B, P, 1]
        pi = (pi * options_onehot_expanded).sum(-2)  # [B, A]
        if self.is_continuous:
            mu = pi.tanh()  # [B, A]
            log_std = self.log_std[self.options]  # [B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            actions = dist.sample().clamp(-1, 1)  # [B, A]
        else:
            pi = pi / self.boltzmann_temperature  # [B, A]
            dist = td.Categorical(logits=pi)
            actions = dist.sample()  # [B, ]
        interests = self.interest_net(obs, rnncs=self.rnncs)  # [B, P]
        op_logits = interests * q  # [B, P] or q.softmax(-1)
        self.new_options = td.Categorical(logits=op_logits).sample()  # [B, ]
        return actions, Data(action=actions,
                             last_options=self.options,
                             options=self.new_options)

    def random_action(self):
        actions = super().random_action()
        self._acts_info.update(
            last_options=np.random.randint(0, self.options_num, self.n_copies),
            options=np.random.randint(0, self.options_num, self.n_copies))
        return actions

    def _preprocess_BATCH(self, BATCH):  # [T, B, *]
        BATCH = super()._preprocess_BATCH(BATCH)
        BATCH.last_options = int2one_hot(BATCH.last_options, self.options_num)
        BATCH.options = int2one_hot(BATCH.options, self.options_num)
        return BATCH

    @iton
    def _train(self, BATCH):
        q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, P]
        q_next = self.q_net.t(BATCH.obs_,
                              begin_mask=BATCH.begin_mask)  # [T, B, P]
        beta_next = self.termination_net(
            BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, P]

        qu_eval = (q * BATCH.options).sum(-1, keepdim=True)  # [T, B, 1]
        beta_s_ = (beta_next * BATCH.options).sum(-1,
                                                  keepdim=True)  # [T, B, 1]
        q_s_ = (q_next * BATCH.options).sum(-1, keepdim=True)  # [T, B, 1]
        if self.double_q:
            q_ = self.q_net(BATCH.obs_,
                            begin_mask=BATCH.begin_mask)  # [T, B, P]
            max_a_idx = F.one_hot(q_.argmax(-1),
                                  self.options_num).float()  # [T, B, P]
            q_s_max = (q_next * max_a_idx).sum(-1, keepdim=True)  # [T, B, 1]
        else:
            q_s_max = q_next.max(-1, keepdim=True)[0]  # [T, B, 1]
        u_target = (1 - beta_s_) * q_s_ + beta_s_ * q_s_max  # [T, B, 1]
        qu_target = n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                  u_target,
                                  BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = qu_target - qu_eval  # [T, B, 1] gradient : q
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
        self.q_oplr.optimize(q_loss)

        q_s = qu_eval.detach()  # [T, B, 1]
        pi = self.intra_option_net(BATCH.obs,
                                   begin_mask=BATCH.begin_mask)  # [T, B, P, A]

        if self.use_baseline:
            adv = (qu_target - q_s).detach()  # [T, B, 1]
        else:
            adv = qu_target.detach()  # [T, B, 1]
        # [T, B, P] => [T, B, P, 1]
        options_onehot_expanded = BATCH.options.unsqueeze(-1)
        # [T, B, P, A] => [T, B, A]
        pi = (pi * options_onehot_expanded).sum(-2)
        if self.is_continuous:
            mu = pi.tanh()  # [T, B, A]
            log_std = self.log_std[BATCH.options.argmax(-1)]  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            log_p = dist.log_prob(BATCH.action).unsqueeze(-1)  # [T, B, 1]
            entropy = dist.entropy().unsqueeze(-1)  # [T, B, 1]
        else:
            pi = pi / self.boltzmann_temperature  # [T, B, A]
            log_pi = pi.log_softmax(-1)  # [T, B, A]
            entropy = -(log_pi.exp() * log_pi).sum(-1,
                                                   keepdim=True)  # [T, B, 1]
            log_p = (log_pi * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        pi_loss = -(log_p * adv + self.ent_coff * entropy).mean()  # 1
        self.intra_option_oplr.optimize(pi_loss)

        beta = self.termination_net(BATCH.obs,
                                    begin_mask=BATCH.begin_mask)  # [T, B, P]
        beta_s = (beta * BATCH.last_options).sum(-1, keepdim=True)  # [T, B, 1]

        interests = self.interest_net(BATCH.obs,
                                      begin_mask=BATCH.begin_mask)  # [T, B, P]
        # [T, B, P] or q.softmax(-1)
        pi_op = (interests * q.detach()).softmax(-1)
        interest_loss = -(beta_s.detach() *
                          (pi_op * BATCH.options).sum(-1, keepdim=True) *
                          q_s).mean()  # 1
        self.interest_oplr.optimize(interest_loss)

        v_s = (q * pi_op).sum(-1, keepdim=True)  # [T, B, 1]
        beta_loss = beta_s * (q_s - v_s).detach()  # [T, B, 1]
        if self.terminal_mask:
            beta_loss *= (1 - BATCH.done)  # [T, B, 1]
        beta_loss = beta_loss.mean()  # 1
        self.termination_oplr.optimize(beta_loss)

        return td_error, {
            'LEARNING_RATE/q_lr': self.q_oplr.lr,
            'LEARNING_RATE/intra_option_lr': self.intra_option_oplr.lr,
            'LEARNING_RATE/termination_lr': self.termination_oplr.lr,
            # 'Statistics/option': self.options[0],
            'LOSS/q_loss': q_loss,
            'LOSS/pi_loss': pi_loss,
            'LOSS/beta_loss': beta_loss,
            'LOSS/interest_loss': interest_loss,
            'Statistics/q_option_max': q_s.max(),
            'Statistics/q_option_min': q_s.min(),
            'Statistics/q_option_mean': q_s.mean()
        }

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            self.q_net.sync()
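What separates IOC from plain option-critic is the interest-weighted policy over options used in select_action and _train; a tiny standalone sketch of that step, with random stand-ins for the network outputs:

import torch
import torch.distributions as td

B, P = 4, 4
q = torch.randn(B, P)                 # option values Q(s, .)
interests = torch.rand(B, P)          # interest function I(s, .) in (0, 1)

op_logits = interests * q             # interest-modulated option preferences
new_options = td.Categorical(logits=op_logits).sample()   # [B,], as in select_action
pi_op = op_logits.softmax(-1)         # policy over options pi_Omega(.|s), as in _train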
Example #7
class DreamerV1(SarlOffPolicy):
    """
    Dream to Control: Learning Behaviors by Latent Imagination, http://arxiv.org/abs/1912.01603
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 eps_init: float = 1,
                 eps_mid: float = 0.2,
                 eps_final: float = 0.01,
                 init2mid_annealing_step: int = 1000,
                 stoch_dim=30,
                 deter_dim=200,
                 model_lr=6e-4,
                 actor_lr=8e-5,
                 critic_lr=8e-5,
                 kl_free_nats=3,
                 action_sigma=0.3,
                 imagination_horizon=15,
                 lambda_=0.95,
                 kl_scale=1.0,
                 reward_scale=1.0,
                 use_pcont=False,
                 pcont_scale=10.0,
                 network_settings=dict(),
                 **kwargs):
        super().__init__(**kwargs)

        assert not self.use_rnn, 'DreamerV1 does not support use_rnn=True'

        if self.obs_spec.has_visual_observation \
                and len(self.obs_spec.visual_dims) == 1 \
                and not self.obs_spec.has_vector_observation:
            visual_dim = self.obs_spec.visual_dims[0]
            # TODO: optimize this
            assert visual_dim[0] == visual_dim[
                1] == 64, 'visual dimension must be [64, 64, *]'
            self._is_visual = True
        elif self.obs_spec.has_vector_observation \
                and len(self.obs_spec.vector_dims) == 1 \
                and not self.obs_spec.has_visual_observation:
            self._is_visual = False
        else:
            raise ValueError("please check the observation type")

        self.stoch_dim = stoch_dim
        self.deter_dim = deter_dim
        self.kl_free_nats = kl_free_nats
        self.imagination_horizon = imagination_horizon
        self.lambda_ = lambda_
        self.kl_scale = kl_scale
        self.reward_scale = reward_scale
        # https://github.com/danijar/dreamer/issues/2
        self.use_pcont = use_pcont  # probability of continuing
        self.pcont_scale = pcont_scale
        self._action_sigma = action_sigma
        self._network_settings = network_settings

        if not self.is_continuous:
            self.expl_expt_mng = ExplorationExploitationClass(
                eps_init=eps_init,
                eps_mid=eps_mid,
                eps_final=eps_final,
                init2mid_annealing_step=init2mid_annealing_step,
                max_step=self._max_train_step)

        if self.obs_spec.has_visual_observation:
            from rls.nn.dreamer import VisualDecoder, VisualEncoder
            self.obs_encoder = VisualEncoder(
                self.obs_spec.visual_dims[0],
                **network_settings['obs_encoder']['visual']).to(self.device)
            self.obs_decoder = VisualDecoder(
                self.decoder_input_dim, self.obs_spec.visual_dims[0],
                **network_settings['obs_decoder']['visual']).to(self.device)
        else:
            from rls.nn.dreamer import VectorEncoder
            self.obs_encoder = VectorEncoder(
                self.obs_spec.vector_dims[0],
                **network_settings['obs_encoder']['vector']).to(self.device)
            self.obs_decoder = DenseModel(
                self.decoder_input_dim, self.obs_spec.vector_dims[0],
                **network_settings['obs_decoder']['vector']).to(self.device)

        self.rssm = self._dreamer_build_rssm()
        """
        p(r_t | s_t, h_t)
        Reward model to predict reward from state and rnn hidden state
        """
        self.reward_predictor = DenseModel(self.decoder_input_dim, 1,
                                           **network_settings['reward']).to(
                                               self.device)

        self.actor = ActionDecoder(self.a_dim,
                                   self.decoder_input_dim,
                                   dist=self._action_dist,
                                   **network_settings['actor']).to(self.device)
        self.critic = self._dreamer_build_critic()

        _modules = [
            self.obs_encoder, self.rssm, self.obs_decoder,
            self.reward_predictor
        ]
        if self.use_pcont:
            self.pcont_decoder = DenseModel(self.decoder_input_dim, 1,
                                            **network_settings['pcont']).to(
                                                self.device)
            _modules.append(self.pcont_decoder)

        self.model_oplr = OPLR(_modules, model_lr, **self._oplr_params)
        self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
        self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params)
        self._trainer_modules.update(obs_encoder=self.obs_encoder,
                                     obs_decoder=self.obs_decoder,
                                     reward_predictor=self.reward_predictor,
                                     rssm=self.rssm,
                                     actor=self.actor,
                                     critic=self.critic,
                                     model_oplr=self.model_oplr,
                                     actor_oplr=self.actor_oplr,
                                     critic_oplr=self.critic_oplr)
        if self.use_pcont:
            self._trainer_modules.update(pcont_decoder=self.pcont_decoder)

    @property
    def _action_dist(self):
        return 'tanh_normal' if self.is_continuous else 'one_hot'  # 'relaxed_one_hot'

    @property
    def decoder_input_dim(self):
        return self.stoch_dim + self.deter_dim

    def _dreamer_build_rssm(self):
        return RecurrentStateSpaceModel(
            self.stoch_dim, self.deter_dim, self.a_dim, self.obs_encoder.h_dim,
            **self._network_settings['rssm']).to(self.device)

    def _dreamer_build_critic(self):
        return DenseModel(self.decoder_input_dim, 1,
                          **self._network_settings['critic']).to(self.device)

    @iton
    def select_action(self, obs):
        if self._is_visual:
            obs = get_first_visual(obs)
        else:
            obs = get_first_vector(obs)
        embedded_obs = self.obs_encoder(obs)  # [B, *]
        state_posterior = self.rssm.posterior(self.rnncs['hx'], embedded_obs)
        state = state_posterior.sample()  # [B, *]
        actions = self.actor.sample_actions(th.cat((state, self.rnncs['hx']),
                                                   -1),
                                            is_train=self._is_train_mode)
        actions = self._exploration(actions)
        _, self.rnncs_['hx'] = self.rssm.prior(state, actions,
                                               self.rnncs['hx'])
        if not self.is_continuous:
            actions = actions.argmax(-1)  # [B,]
        return actions, Data(action=actions)

    def _exploration(self, action: th.Tensor) -> th.Tensor:
        """
        :param action: action to take, shape (1,) (if categorical), or (action dim,) (if continuous)
        :return: action of the same shape passed in, augmented with some noise
        """
        if self.is_continuous:
            sigma = self._action_sigma if self._is_train_mode else 0.
            noise = th.randn(*action.shape) * sigma
            return th.clamp(action + noise, -1, 1)
        else:
            if self._is_train_mode and self.expl_expt_mng.is_random(
                    self._cur_train_step):
                index = th.randint(0, self.a_dim, (self.n_copies, ))
                action = th.zeros_like(action)
                action[..., index] = 1
            return action

    @iton
    def _train(self, BATCH):
        T, B = BATCH.action.shape[:2]
        if self._is_visual:
            obs_ = get_first_visual(BATCH.obs_)
        else:
            obs_ = get_first_vector(BATCH.obs_)

        # embed observations with CNN
        embedded_observations = self.obs_encoder(obs_)  # [T, B, *]

        # initialize state and rnn hidden state with 0 vector
        state, rnn_hidden = self.rssm.init_state(shape=B)  # [B, S], [B, D]

        # compute state and rnn hidden sequences and kl loss
        kl_loss = 0
        states, rnn_hiddens = [], []
        for l in range(T):
            # if this is the beginning of a new episode, reset to 0,
            # no matter whether the last episode was truncated or not.
            state = state * (1. - BATCH.begin_mask[l])  # [B, S]
            rnn_hidden = rnn_hidden * (1. - BATCH.begin_mask[l])  # [B, D]

            next_state_prior, next_state_posterior, rnn_hidden = self.rssm(
                state, BATCH.action[l], rnn_hidden,
                embedded_observations[l])  # a, s_
            state = next_state_posterior.rsample()  # [B, S] posterior of s_
            states.append(state)  # [B, S]
            rnn_hiddens.append(rnn_hidden)  # [B, D]
            kl_loss += self._kl_loss(next_state_prior, next_state_posterior)
        kl_loss /= T  # 1

        # compute reconstructed observations and predicted rewards
        post_feat = th.cat([th.stack(states, 0),
                            th.stack(rnn_hiddens, 0)], -1)  # [T, B, *]

        obs_pred = self.obs_decoder(post_feat)  # [T, B, C, H, W] or [T, B, *]
        reward_pred = self.reward_predictor(post_feat)  # [T, B, 1], s_ => r

        # compute loss for observation and reward
        obs_loss = -th.mean(obs_pred.log_prob(obs_))  # [T, B] => 1
        # [T, B, 1]=>1
        reward_loss = -th.mean(
            reward_pred.log_prob(BATCH.reward).unsqueeze(-1))

        # add all losses and update model parameters with gradient descent
        model_loss = self.kl_scale * kl_loss + obs_loss + self.reward_scale * reward_loss  # 1

        if self.use_pcont:
            pcont_pred = self.pcont_decoder(post_feat)  # [T, B, 1], s_ => done
            # https://github.com/danijar/dreamer/issues/2#issuecomment-605392659
            pcont_target = self.gamma * (1. - BATCH.done)
            # [T, B, 1]=>1
            pcont_loss = -th.mean(
                pcont_pred.log_prob(pcont_target).unsqueeze(-1))
            model_loss += self.pcont_scale * pcont_loss

        self.model_oplr.optimize(model_loss)

        # remove gradients from previously calculated tensors
        with th.no_grad():
            # [T, B, S] => [T*B, S]
            flatten_states = th.cat(states, 0).detach()
            # [T, B, D] => [T*B, D]
            flatten_rnn_hiddens = th.cat(rnn_hiddens, 0).detach()

        with FreezeParameters(self.model_oplr.parameters):
            # compute target values
            imaginated_states = []
            imaginated_rnn_hiddens = []
            log_probs = []
            entropies = []

            for h in range(self.imagination_horizon):
                imaginated_states.append(flatten_states)  # [T*B, S]
                imaginated_rnn_hiddens.append(flatten_rnn_hiddens)  # [T*B, D]

                flatten_feat = th.cat([flatten_states, flatten_rnn_hiddens],
                                      -1).detach()
                action_dist = self.actor(flatten_feat)
                actions = action_dist.rsample()  # [T*B, A]
                log_probs.append(
                    action_dist.log_prob(
                        actions.detach()).unsqueeze(-1))  # [T*B, 1]
                entropies.append(
                    action_dist.entropy().unsqueeze(-1))  # [T*B, 1]
                flatten_states_prior, flatten_rnn_hiddens = self.rssm.prior(
                    flatten_states, actions, flatten_rnn_hiddens)
                flatten_states = flatten_states_prior.rsample()  # [T*B, S]

            imaginated_states = th.stack(imaginated_states, 0)  # [H, T*B, S]
            imaginated_rnn_hiddens = th.stack(imaginated_rnn_hiddens,
                                              0)  # [H, T*B, D]
            log_probs = th.stack(log_probs, 0)  # [H, T*B, 1]
            entropies = th.stack(entropies, 0)  # [H, T*B, 1]

        imaginated_feats = th.cat([imaginated_states, imaginated_rnn_hiddens],
                                  -1)  # [H, T*B, *]

        with FreezeParameters(self.model_oplr.parameters +
                              self.critic_oplr.parameters):
            imaginated_rewards = self.reward_predictor(
                imaginated_feats).mean  # [H, T*B, 1]
            imaginated_values = self._dreamer_target_img_value(
                imaginated_feats)  # [H, T*B, 1]

        # Compute the exponential discounted sum of rewards
        if self.use_pcont:
            with FreezeParameters(self.pcont_decoder.parameters()):
                discount_arr = self.pcont_decoder(
                    imaginated_feats).mean  # [H, T*B, 1]
        else:
            discount_arr = self.gamma * th.ones_like(
                imaginated_rewards)  # [H, T*B, 1]

        returns = compute_return(imaginated_rewards[:-1],
                                 imaginated_values[:-1],
                                 discount_arr[:-1],
                                 bootstrap=imaginated_values[-1],
                                 lambda_=self.lambda_)  # [H-1, T*B, 1]
        # Make the top row 1 so the cumulative product starts with discount^0
        discount_arr = th.cat(
            [th.ones_like(discount_arr[:1]), discount_arr[:-1]],
            0)  # [H, T*B, 1]
        discount = th.cumprod(discount_arr, 0).detach()[:-1]  # [H-1, T*B, 1]

        # discount_arr = th.cat([th.ones_like(discount_arr[:1]), discount_arr[1:]])
        # discount = th.cumprod(discount_arr[:-1], 0)

        actor_loss = self._dreamer_build_actor_loss(imaginated_feats,
                                                    log_probs, entropies,
                                                    discount, returns)  # 1

        # Don't let gradients pass through to prevent overwriting gradients.
        # Value Loss
        with th.no_grad():
            value_feat = imaginated_feats[:-1].detach()  # [H-1, T*B, *]
            value_target = returns.detach()  # [H-1, T*B, 1]

        value_pred = self.critic(value_feat)  # [H-1, T*B, 1]
        log_prob = value_pred.log_prob(value_target).unsqueeze(
            -1)  # [H-1, T*B, 1]
        critic_loss = -th.mean(discount * log_prob)  # 1

        self.actor_oplr.zero_grad()
        self.critic_oplr.zero_grad()

        self.actor_oplr.backward(actor_loss)
        self.critic_oplr.backward(critic_loss)

        self.actor_oplr.step()
        self.critic_oplr.step()

        td_error = (value_pred.mean - value_target).mean(0).detach()  # [T*B,]
        td_error = td_error.view(T, B, 1)

        summaries = {
            'LEARNING_RATE/model_lr': self.model_oplr.lr,
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/model_loss': model_loss,
            'LOSS/kl_loss': kl_loss,
            'LOSS/obs_loss': obs_loss,
            'LOSS/reward_loss': reward_loss,
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss
        }
        if self.use_pcont:
            summaries.update({'LOSS/pcont_loss': pcont_loss})

        return td_error, summaries

    def _initial_rnncs(self, batch: int) -> Dict[str, np.ndarray]:
        return {'hx': np.zeros((batch, self.deter_dim))}

    def _kl_loss(self, prior_dist, post_dist):
        # 1
        return td.kl_divergence(prior_dist,
                                post_dist).clamp(min=self.kl_free_nats).mean()

    def _dreamer_target_img_value(self, imaginated_feats):
        imaginated_values = self.critic(imaginated_feats).mean  # [H, T*B, 1]
        return imaginated_values

    def _dreamer_build_actor_loss(self, imaginated_feats, log_probs, entropies,
                                  discount, returns):
        actor_loss = -th.mean(discount * returns)  # [H-1, T*B, 1] => 1
        return actor_loss
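
The λ-return behind the actor/critic targets above comes from `compute_return`, which is imported from the library's utilities and not shown here. A minimal sketch of the Dreamer-style λ-return it is assumed to compute (the name and the backward recursion below are illustrative, not the library's actual implementation):

import torch as th


def lambda_return_sketch(rewards, values, discounts, bootstrap, lambda_=0.95):
    # rewards, values, discounts: [H-1, T*B, 1]; bootstrap: [T*B, 1]
    # Backward recursion of the lambda-return:
    #   R_t = r_t + d_t * ((1 - lambda) * V_{t+1} + lambda * R_{t+1})
    next_values = th.cat([values[1:], bootstrap.unsqueeze(0)], 0)
    inputs = rewards + discounts * (1. - lambda_) * next_values
    last = bootstrap
    outputs = []
    for t in reversed(range(rewards.shape[0])):
        last = inputs[t] + discounts[t] * lambda_ * last
        outputs.insert(0, last)
    return th.stack(outputs, 0)  # [H-1, T*B, 1]
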
Exemple #8
class VDN(MultiAgentOffPolicy):
    """
    Value-Decomposition Networks For Cooperative Multi-Agent Learning, http://arxiv.org/abs/1706.05296
    QMIX: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning, http://arxiv.org/abs/1803.11485
    Qatten: A General Framework for Cooperative Multiagent Reinforcement Learning, http://arxiv.org/abs/2002.03939
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 mixer='vdn',
                 mixer_settings={},
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 use_double=True,
                 init2mid_annealing_step=1000,
                 assign_interval=1000,
                 network_settings={
                     'share': [128],
                     'v': [128],
                     'adv': [128]
                 },
                 **kwargs):
        super().__init__(**kwargs)
        assert not any(self.is_continuouss.values()), \
            'VDN only supports discrete action spaces'
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self._max_train_step)
        self.assign_interval = assign_interval
        self._use_double = use_double
        self._mixer_type = mixer
        self._mixer_settings = mixer_settings

        self.q_nets = {}
        for id in set(self.model_ids):
            self.q_nets[id] = TargetTwin(
                CriticDueling(self.obs_specs[id],
                              rep_net_params=self._rep_net_params,
                              output_shape=self.a_dims[id],
                              network_settings=network_settings)).to(
                                  self.device)

        self.mixer = self._build_mixer()

        self.oplr = OPLR(
            tuple(self.q_nets.values()) + (self.mixer, ), lr,
            **self._oplr_params)
        self._trainer_modules.update(
            {f"model_{id}": self.q_nets[id]
             for id in set(self.model_ids)})
        self._trainer_modules.update(mixer=self.mixer, oplr=self.oplr)

    def _build_mixer(self):
        assert self._mixer_type in [
            'vdn', 'qmix', 'qatten'
        ], "mixer must be one of 'vdn', 'qmix' or 'qatten'"
        if self._mixer_type in ['qmix', 'qatten']:
            assert self._has_global_state, \
                "'qmix' and 'qatten' mixers require a global state"
        return TargetTwin(Mixer_REGISTER[self._mixer_type](
            n_agents=self.n_agents_percopy,
            state_spec=self.state_spec,
            rep_net_params=self._rep_net_params,
            **self._mixer_settings)).to(self.device)

    @iton  # TODO: optimization
    def select_action(self, obs):
        acts_info = {}
        actions = {}
        for aid, mid in zip(self.agent_ids, self.model_ids):
            q_values = self.q_nets[mid](obs[aid],
                                        rnncs=self.rnncs[aid])  # [B, A]
            self.rnncs_[aid] = self.q_nets[mid].get_rnncs()

            if self._is_train_mode and self.expl_expt_mng.is_random(
                    self._cur_train_step):
                action = np.random.randint(0, self.a_dims[aid], self.n_copies)
            else:
                action = q_values.argmax(-1)  # [B,]

            actions[aid] = action
            acts_info[aid] = Data(action=action)
        return actions, acts_info

    @iton
    def _train(self, BATCH_DICT):
        summaries = {}
        reward = BATCH_DICT[self.agent_ids[0]].reward  # [T, B, 1]
        done = 0.
        q_evals = []
        q_target_next_choose_maxs = []
        for aid, mid in zip(self.agent_ids, self.model_ids):
            done += BATCH_DICT[aid].done  # [T, B, 1]

            q = self.q_nets[mid](
                BATCH_DICT[aid].obs,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            q_eval = (q * BATCH_DICT[aid].action).sum(
                -1, keepdim=True)  # [T, B, 1]
            q_evals.append(q_eval)  # N * [T, B, 1]

            q_target = self.q_nets[mid].t(
                BATCH_DICT[aid].obs_,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            if self._use_double:
                next_q = self.q_nets[mid](
                    BATCH_DICT[aid].obs_,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]

                next_max_action = next_q.argmax(-1)  # [T, B]
                next_max_action_one_hot = F.one_hot(
                    next_max_action, self.a_dims[aid]).float()  # [T, B, A]

                q_target_next_max = (q_target * next_max_action_one_hot).sum(
                    -1, keepdim=True)  # [T, B, 1]
            else:
                # [T, B, 1]
                q_target_next_max = q_target.max(-1, keepdim=True)[0]

            q_target_next_choose_maxs.append(
                q_target_next_max)  # N * [T, B, 1]

        q_evals = th.stack(q_evals, -1)  # [T, B, 1, N]
        q_target_next_choose_maxs = th.stack(q_target_next_choose_maxs,
                                             -1)  # [T, B, 1, N]
        q_eval_tot = self.mixer(
            q_evals,
            BATCH_DICT['global'].obs,
            begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]
        q_target_next_max_tot = self.mixer.t(
            q_target_next_choose_maxs,
            BATCH_DICT['global'].obs_,
            begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]

        q_target_tot = n_step_return(
            reward, self.gamma, (done > 0.).float(), q_target_next_max_tot,
            BATCH_DICT['global'].begin_mask).detach()  # [T, B, 1]
        td_error = q_target_tot - q_eval_tot  # [T, B, 1]
        q_loss = td_error.square().mean()  # 1
        self.oplr.optimize(q_loss)

        summaries['model'] = {
            'LOSS/q_loss': q_loss,
            'Statistics/q_max': q_eval_tot.max(),
            'Statistics/q_min': q_eval_tot.min(),
            'Statistics/q_mean': q_eval_tot.mean()
        }
        return td_error, summaries

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            for q_net in self.q_nets.values():
                q_net.sync()
            self.mixer.sync()
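
The mixers themselves come from `Mixer_REGISTER` and are not shown here. For the default `'vdn'` option, the decomposition in the VDN paper is a plain sum of per-agent Q-values; a minimal sketch of such an additive mixer (class name and constructor interface below are illustrative, not the registered implementation):

import torch.nn as nn


class VDNMixerSketch(nn.Module):
    """Additive mixing: Q_tot = sum_i Q_i (illustrative sketch only)."""

    def __init__(self, n_agents, state_spec=None, rep_net_params=None, **kwargs):
        super().__init__()
        self.n_agents = n_agents  # kept for interface parity; unused by the sum

    def forward(self, q_values, state=None, **kwargs):
        # q_values: [T, B, 1, N] => Q_tot: [T, B, 1]
        return q_values.sum(-1)
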
Exemple #9
class DDDQN(SarlOffPolicy):
    """
    Dueling Double DQN, https://arxiv.org/abs/1511.06581
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=2,
                 network_settings={
                     'share': [128],
                     'v': [128],
                     'adv': [128]
                 },
                 **kwargs):
        super().__init__(**kwargs)
        assert not self.is_continuous, 'Dueling Double DQN only supports discrete action spaces'
        self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init,
                                                          eps_mid=eps_mid,
                                                          eps_final=eps_final,
                                                          init2mid_annealing_step=init2mid_annealing_step,
                                                          max_step=self._max_train_step)
        self.assign_interval = assign_interval

        self.q_net = TargetTwin(CriticDueling(self.obs_spec,
                                              rep_net_params=self._rep_net_params,
                                              output_shape=self.a_dim,
                                              network_settings=network_settings)).to(self.device)

        self.oplr = OPLR(self.q_net, lr, **self._oplr_params)
        self._trainer_modules.update(model=self.q_net,
                                     oplr=self.oplr)

    @iton
    def select_action(self, obs):
        q_values = self.q_net(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.q_net.get_rnncs()

        if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            actions = q_values.argmax(-1)  # [B,]
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        next_q = self.q_net(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q_target = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]

        q_eval = (q * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        next_max_action = next_q.argmax(-1)  # [T, B]
        next_max_action_one_hot = F.one_hot(next_max_action, self.a_dim).float()  # [T, B, A]

        q_target_next_max = (q_target * next_max_action_one_hot).sum(-1, keepdim=True)  # [T, B, 1]
        q_target = n_step_return(BATCH.reward,
                                 self.gamma,
                                 BATCH.done,
                                 q_target_next_max,
                                 BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1]
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
        self.oplr.optimize(q_loss)

        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': q_loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            self.q_net.sync()
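
`n_step_return` is imported from the library's return utilities; for a single-step horizon it is assumed to reduce to the usual TD(0) target. An illustrative sketch (the real helper also handles multi-step returns and `begin_mask`, which this sketch ignores):

def one_step_return_sketch(reward, gamma, done, q_next, begin_mask=None):
    # reward, done, q_next: [T, B, 1] => target: [T, B, 1]
    # begin_mask is ignored here; the library helper uses it for episode boundaries.
    return reward + gamma * (1. - done) * q_next
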
Exemple #10
class NPG(SarlOnPolicy):
    """
    Natural Policy Gradient, NPG
    https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf
    """
    policy_mode = 'on-policy'

    def __init__(
            self,
            agent_spec,
            actor_step_size=0.5,
            beta=1.0e-3,
            lambda_=0.95,
            cg_iters=10,
            damping_coeff=0.1,
            epsilon=0.2,
            critic_lr=1e-3,
            train_critic_iters=10,
            network_settings={
                'actor_continuous': {
                    'hidden_units': [64, 64],
                    'condition_sigma': False,
                    'log_std_bound': [-20, 2]
                },
                'actor_discrete': [32, 32],
                'critic': [32, 32]
            },
            **kwargs):
        super().__init__(agent_spec=agent_spec, **kwargs)
        self.actor_step_size = actor_step_size
        self.beta = beta
        self.lambda_ = lambda_
        self._epsilon = epsilon
        self._cg_iters = cg_iters
        self._damping_coeff = damping_coeff
        self._train_critic_iters = train_critic_iters

        if self.is_continuous:
            self.actor = ActorMuLogstd(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_continuous']).to(
                    self.device)
        else:
            self.actor = ActorDct(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_discrete']).to(
                    self.device)
        self.critic = CriticValue(
            self.obs_spec,
            rep_net_params=self._rep_net_params,
            network_settings=network_settings['critic']).to(self.device)

        self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params)
        self._trainer_modules.update(actor=self.actor,
                                     critic=self.critic,
                                     critic_oplr=self.critic_oplr)

    @iton
    def select_action(self, obs):
        output = self.actor(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.actor.get_rnncs()
        value = self.critic(obs, rnncs=self.rnncs)  # [B, 1]
        if self.is_continuous:
            mu, log_std = output  # [B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            action = dist.sample().clamp(-1, 1)  # [B, A]
            log_prob = dist.log_prob(action).unsqueeze(-1)  # [B, 1]
        else:
            logits = output  # [B, A]
            logp_all = logits.log_softmax(-1)  # [B, A]
            norm_dist = td.Categorical(logits=logp_all)
            action = norm_dist.sample()  # [B,]
            log_prob = norm_dist.log_prob(action).unsqueeze(-1)  # [B, 1]
        acts_info = Data(action=action,
                         value=value,
                         log_prob=log_prob + th.finfo().eps)
        if self.use_rnn:
            acts_info.update(rnncs=self.rnncs)
        if self.is_continuous:
            acts_info.update(mu=mu, log_std=log_std)
        else:
            acts_info.update(logp_all=logp_all)
        return action, acts_info

    @iton
    def _get_value(self, obs, rnncs=None):
        value = self.critic(obs, rnncs=rnncs)  # [B, 1]
        return value

    def _preprocess_BATCH(self, BATCH):  # [T, B, *]
        BATCH = super()._preprocess_BATCH(BATCH)
        value = self._get_value(BATCH.obs_[-1], rnncs=self.rnncs)
        BATCH.discounted_reward = discounted_sum(BATCH.reward,
                                                 self.gamma,
                                                 BATCH.done,
                                                 BATCH.begin_mask,
                                                 init_value=value)
        td_error = calculate_td_error(
            BATCH.reward,
            self.gamma,
            BATCH.done,
            value=BATCH.value,
            next_value=np.concatenate((BATCH.value[1:], value[np.newaxis, :]),
                                      0))
        BATCH.gae_adv = discounted_sum(td_error,
                                       self.lambda_ * self.gamma,
                                       BATCH.done,
                                       BATCH.begin_mask,
                                       init_value=0.,
                                       normalize=True)
        return BATCH

    @iton
    def _train(self, BATCH):
        output = self.actor(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        if self.is_continuous:
            mu, log_std = output  # [T, B, A], [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            new_log_prob = dist.log_prob(BATCH.action).unsqueeze(
                -1)  # [T, B, 1]
            entropy = dist.entropy().mean()  # 1
        else:
            logits = output  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            new_log_prob = (BATCH.action * logp_all).sum(
                -1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
        ratio = (new_log_prob - BATCH.log_prob).exp()  # [T, B, 1]
        actor_loss = -(ratio * BATCH.gae_adv).mean()  # 1

        flat_grads = grads_flatten(actor_loss, self.actor,
                                   retain_graph=True).detach()  # [1,]

        if self.is_continuous:
            kl = td.kl_divergence(
                td.Independent(td.Normal(BATCH.mu, BATCH.log_std.exp()), 1),
                td.Independent(td.Normal(mu, log_std.exp()), 1)).mean()
        else:
            kl = (BATCH.logp_all.exp() *
                  (BATCH.logp_all - logp_all)).sum(-1).mean()  # 1

        flat_kl_grad = grads_flatten(kl, self.actor, create_graph=True)
        search_direction = -self._conjugate_gradients(
            flat_grads, flat_kl_grad, cg_iters=self._cg_iters)  # [1,]

        with th.no_grad():
            flat_params = th.cat(
                [param.data.view(-1) for param in self.actor.parameters()])
            new_flat_params = flat_params + self.actor_step_size * search_direction
            set_from_flat_params(self.actor, new_flat_params)

        for _ in range(self._train_critic_iters):
            value = self.critic(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, 1]
            td_error = BATCH.discounted_reward - value  # [T, B, 1]
            critic_loss = td_error.square().mean()  # 1
            self.critic_oplr.optimize(critic_loss)

        return {
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/entropy': entropy.mean(),
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr
        }

    def _conjugate_gradients(self,
                             flat_grads,
                             flat_kl_grad,
                             cg_iters: int = 10,
                             residual_tol: float = 1e-10):
        """
        Conjugate gradient algorithm
        (see https://en.wikipedia.org/wiki/Conjugate_gradient_method)
        """
        x = th.zeros_like(flat_grads)
        r, p = flat_grads.clone(), flat_grads.clone()
        # Note: should be 'r, p = b - MVP(x)', but for x=0, MVP(x)=0.
        # Change if doing warm start.
        rdotr = r.dot(r)
        for i in range(cg_iters):
            z = self._MVP(p, flat_kl_grad)
            alpha = rdotr / (p.dot(z) + th.finfo().eps)
            x += alpha * p
            r -= alpha * z
            new_rdotr = r.dot(r)
            if new_rdotr < residual_tol:
                break
            p = r + new_rdotr / rdotr * p
            rdotr = new_rdotr
        return x

    def _MVP(self, v, flat_kl_grad):
        """Matrix vector product."""
        # calculate the second-order gradient of the KL w.r.t. theta
        kl_v = (flat_kl_grad * v).sum()
        mvp = grads_flatten(kl_v, self.actor, retain_graph=True).detach()
        mvp += max(0, self._damping_coeff) * v
        return mvp
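
The natural-gradient step relies on `grads_flatten` and `set_from_flat_params`, imported from the library and not shown. Minimal sketches of what they are assumed to do (illustrative names and signatures):

import torch as th


def grads_flatten_sketch(loss, model, retain_graph=False, create_graph=False):
    # Flatten d(loss)/d(theta) into one 1-D tensor, in parameter order.
    grads = th.autograd.grad(loss, list(model.parameters()),
                             retain_graph=retain_graph,
                             create_graph=create_graph)
    return th.cat([g.reshape(-1) for g in grads])


def set_from_flat_params_sketch(model, flat_params):
    # Copy a flat parameter vector back into the model's parameters in-place.
    offset = 0
    for p in model.parameters():
        n = p.numel()
        p.data.copy_(flat_params[offset:offset + n].view_as(p))
        offset += n
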
Exemple #11
class BootstrappedDQN(SarlOffPolicy):
    """
    Deep Exploration via Bootstrapped DQN, http://arxiv.org/abs/1602.04621
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=1000,
                 head_num=4,
                 network_settings=[32, 32],
                 **kwargs):
        super().__init__(**kwargs)
        assert not self.is_continuous, 'Bootstrapped DQN only supports discrete action spaces'
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self._max_train_step)
        self.assign_interval = assign_interval
        self.head_num = head_num
        self._probs = th.FloatTensor([1. / head_num for _ in range(head_num)])
        self.now_head = 0

        self.q_net = TargetTwin(
            CriticQvalueBootstrap(self.obs_spec,
                                  rep_net_params=self._rep_net_params,
                                  output_shape=self.a_dim,
                                  head_num=self.head_num,
                                  network_settings=network_settings)).to(
                                      self.device)

        self.oplr = OPLR(self.q_net, lr, **self._oplr_params)
        self._trainer_modules.update(model=self.q_net, oplr=self.oplr)

    def episode_reset(self):
        super().episode_reset()
        self.now_head = np.random.randint(self.head_num)

    @iton
    def select_action(self, obs):
        q_values = self.q_net(obs, rnncs=self.rnncs)  # [H, B, A]
        self.rnncs_ = self.q_net.get_rnncs()

        if self._is_train_mode and self.expl_expt_mng.is_random(
                self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            # [H, B, A] => [B, A] => [B, ]
            actions = q_values[self.now_head].argmax(-1)
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask).mean(
            0)  # [H, T, B, A] => [T, B, A]
        q_next = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask).mean(
            0)  # [H, T, B, A] => [T, B, A]
        # [T, B, A] * [T, B, A] => [T, B, 1]
        q_eval = (q * BATCH.action).sum(-1, keepdim=True)
        q_target = n_step_return(
            BATCH.reward,
            self.gamma,
            BATCH.done,
            # [T, B, A] => [T, B, 1]
            q_next.max(-1, keepdim=True)[0],
            BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1]
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1

        # mask_dist = td.Bernoulli(probs=self._probs)  # TODO:
        # mask = mask_dist.sample([batch_size]).T   # [H, B]
        self.oplr.optimize(q_loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': q_loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            self.q_net.sync()
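
The Bernoulli head mask in `_train` is left as a TODO above, so the loss currently averages over all heads. A hypothetical completion that masks per-head TD losses with a shared Bernoulli mask, in the spirit of the Bootstrapped DQN paper (all names below are illustrative):

import torch as th
import torch.distributions as td


def masked_head_loss_sketch(per_head_td_error, probs):
    # per_head_td_error: [H, T, B, 1]; probs: [H] sharing probability per head.
    B = per_head_td_error.shape[2]
    per_head_loss = per_head_td_error.square().mean(1)             # [H, B, 1]
    mask = td.Bernoulli(probs=probs).sample((B,)).T.unsqueeze(-1)  # [H, B, 1]
    return (per_head_loss * mask).sum() / mask.sum().clamp(min=1.)
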
Exemple #12
class PG(SarlOnPolicy):
    policy_mode = 'on-policy'

    def __init__(
            self,
            agent_spec,
            lr=5.0e-4,
            network_settings={
                'actor_continuous': {
                    'hidden_units': [32, 32],
                    'condition_sigma': False,
                    'log_std_bound': [-20, 2]
                },
                'actor_discrete': [32, 32]
            },
            **kwargs):
        super().__init__(agent_spec=agent_spec, **kwargs)
        if self.is_continuous:
            self.net = ActorMuLogstd(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_continuous']).to(
                    self.device)
        else:
            self.net = ActorDct(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_discrete']).to(
                    self.device)
        self.oplr = OPLR(self.net, lr, **self._oplr_params)

        self._trainer_modules.update(model=self.net, oplr=self.oplr)

    @iton
    def select_action(self, obs):
        output = self.net(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.net.get_rnncs()
        if self.is_continuous:
            mu, log_std = output  # [B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            action = dist.sample().clamp(-1, 1)  # [B, A]
        else:
            logits = output  # [B, A]
            norm_dist = td.Categorical(logits=logits)
            action = norm_dist.sample()  # [B,]

        acts_info = Data(action=action)
        if self.use_rnn:
            acts_info.update(rnncs=self.rnncs)
        return action, acts_info

    def _preprocess_BATCH(self, BATCH):  # [T, B, *]
        BATCH = super()._preprocess_BATCH(BATCH)
        BATCH.discounted_reward = discounted_sum(BATCH.reward,
                                                 self.gamma,
                                                 BATCH.done,
                                                 BATCH.begin_mask,
                                                 init_value=0.,
                                                 normalize=True)
        return BATCH

    @iton
    def _train(self, BATCH):  # [T, B, *]
        output = self.net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        if self.is_continuous:
            mu, log_std = output  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            log_act_prob = dist.log_prob(BATCH.action).unsqueeze(
                -1)  # [T, B, 1]
            entropy = dist.entropy().unsqueeze(-1)  # [T, B, 1]
        else:
            logits = output  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            log_act_prob = (logp_all * BATCH.action).sum(
                -1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(
                -1, keepdim=True)  # [T, B, 1]
        loss = -(log_act_prob * BATCH.discounted_reward).mean()
        self.oplr.optimize(loss)
        return {
            'LOSS/loss': loss,
            'Statistics/entropy': entropy.mean(),
            'LEARNING_RATE/lr': self.oplr.lr
        }
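
`discounted_sum` comes from the library's return utilities. A minimal sketch of the discounted reward-to-go it is assumed to compute for this use (the real helper also handles `begin_mask` for episode boundaries, which this sketch ignores):

import numpy as np


def discounted_sum_sketch(rewards, gamma, dones, begin_mask=None,
                          init_value=0., normalize=False):
    # rewards, dones: [T, B, 1]. Backward pass of the reward-to-go:
    #   G_t = r_t + gamma * (1 - done_t) * G_{t+1}
    returns = np.zeros_like(rewards)
    running = np.ones_like(rewards[0]) * init_value
    for t in reversed(range(rewards.shape[0])):
        running = rewards[t] + gamma * (1. - dones[t]) * running
        returns[t] = running
    if normalize:
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
    return returns
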
Exemple #13
class MASAC(MultiAgentOffPolicy):
    policy_mode = 'off-policy'

    def __init__(
            self,
            alpha=0.2,
            annealing=True,
            last_alpha=0.01,
            polyak=0.995,
            discrete_tau=1.0,
            network_settings={
                'actor_continuous': {
                    'share': [128, 128],
                    'mu': [64],
                    'log_std': [64],
                    'soft_clip': False,
                    'log_std_bound': [-20, 2]
                },
                'actor_discrete': [64, 32],
                'q': [128, 128]
            },
            auto_adaption=True,
            actor_lr=5.0e-4,
            critic_lr=1.0e-3,
            alpha_lr=5.0e-4,
            **kwargs):
        """
        TODO: Annotation
        """
        super().__init__(**kwargs)
        self.polyak = polyak
        self.discrete_tau = discrete_tau
        self.auto_adaption = auto_adaption
        self.annealing = annealing

        self.target_entropy = 0.98
        for id in self.agent_ids:
            if self.is_continuouss[id]:
                self.target_entropy *= (-self.a_dims[id])
            else:
                self.target_entropy *= np.log(self.a_dims[id])

        self.actors, self.critics, self.critics2 = {}, {}, {}
        for id in set(self.model_ids):
            if self.is_continuouss[id]:
                self.actors[id] = ActorCts(
                    self.obs_specs[id],
                    rep_net_params=self._rep_net_params,
                    output_shape=self.a_dims[id],
                    network_settings=network_settings['actor_continuous']).to(
                        self.device)
            else:
                self.actors[id] = ActorDct(
                    self.obs_specs[id],
                    rep_net_params=self._rep_net_params,
                    output_shape=self.a_dims[id],
                    network_settings=network_settings['actor_discrete']).to(
                        self.device)
            self.critics[id] = TargetTwin(
                MACriticQvalueOne(list(self.obs_specs.values()),
                                  rep_net_params=self._rep_net_params,
                                  action_dim=sum(self.a_dims.values()),
                                  network_settings=network_settings['q']),
                self.polyak).to(self.device)
            self.critics2[id] = deepcopy(self.critics[id])
        self.actor_oplr = OPLR(list(self.actors.values()), actor_lr,
                               **self._oplr_params)
        self.critic_oplr = OPLR(
            list(self.critics.values()) + list(self.critics2.values()),
            critic_lr, **self._oplr_params)

        if self.auto_adaption:
            self.log_alpha = th.tensor(0., requires_grad=True, device=self.device)
            self.alpha_oplr = OPLR(self.log_alpha, alpha_lr,
                                   **self._oplr_params)
            self._trainer_modules.update(alpha_oplr=self.alpha_oplr)
        else:
            self.log_alpha = th.tensor(alpha).log().to(self.device)
            if self.annealing:
                self.alpha_annealing = LinearAnnealing(alpha, last_alpha,
                                                       int(1e6))

        self._trainer_modules.update(
            {f"actor_{id}": self.actors[id]
             for id in set(self.model_ids)})
        self._trainer_modules.update(
            {f"critic_{id}": self.critics[id]
             for id in set(self.model_ids)})
        self._trainer_modules.update(
            {f"critic2_{id}": self.critics2[id]
             for id in set(self.model_ids)})
        self._trainer_modules.update(log_alpha=self.log_alpha,
                                     actor_oplr=self.actor_oplr,
                                     critic_oplr=self.critic_oplr)

    @property
    def alpha(self):
        return self.log_alpha.exp()

    @iton
    def select_action(self, obs: Dict):
        acts_info = {}
        actions = {}
        for aid, mid in zip(self.agent_ids, self.model_ids):
            output = self.actors[mid](obs[aid],
                                      rnncs=self.rnncs[aid])  # [B, A]
            self.rnncs_[aid] = self.actors[mid].get_rnncs()
            if self.is_continuouss[aid]:
                mu, log_std = output  # [B, A]
                pi = td.Normal(mu, log_std.exp()).sample().tanh()
                mu.tanh_()  # squash mu  # [B, A]
            else:
                logits = output  # [B, A]
                mu = logits.argmax(-1)  # [B,]
                cate_dist = td.Categorical(logits=logits)
                pi = cate_dist.sample()  # [B,]
            action = pi if self._is_train_mode else mu
            acts_info[aid] = Data(action=action)
            actions[aid] = action
        return actions, acts_info

    @iton
    def _train(self, BATCH_DICT):
        """
        TODO: Annotation
        """
        summaries = defaultdict(dict)
        target_actions = {}
        target_log_pis = 1.
        for aid, mid in zip(self.agent_ids, self.model_ids):
            if self.is_continuouss[aid]:
                target_mu, target_log_std = self.actors[mid](
                    BATCH_DICT[aid].obs_,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
                dist = td.Independent(
                    td.Normal(target_mu, target_log_std.exp()), 1)
                target_pi = dist.sample()  # [T, B, A]
                target_pi, target_log_pi = squash_action(
                    target_pi,
                    dist.log_prob(target_pi).unsqueeze(
                        -1))  # [T, B, A], [T, B, 1]
            else:
                target_logits = self.actors[mid](
                    BATCH_DICT[aid].obs_,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
                target_cate_dist = td.Categorical(logits=target_logits)
                target_pi = target_cate_dist.sample()  # [T, B]
                target_log_pi = target_cate_dist.log_prob(target_pi).unsqueeze(
                    -1)  # [T, B, 1]
                target_pi = F.one_hot(target_pi,
                                      self.a_dims[aid]).float()  # [T, B, A]
            target_actions[aid] = target_pi
            target_log_pis *= target_log_pi

        target_log_pis += th.finfo().eps
        target_actions = th.cat(list(target_actions.values()),
                                -1)  # [T, B, N*A]

        qs1, qs2, q_targets1, q_targets2 = {}, {}, {}, {}
        for mid in self.model_ids:
            qs1[mid] = self.critics[mid](
                [BATCH_DICT[id].obs for id in self.agent_ids],
                th.cat([BATCH_DICT[id].action for id in self.agent_ids],
                       -1))  # [T, B, 1]
            qs2[mid] = self.critics2[mid](
                [BATCH_DICT[id].obs for id in self.agent_ids],
                th.cat([BATCH_DICT[id].action for id in self.agent_ids],
                       -1))  # [T, B, 1]
            q_targets1[mid] = self.critics[mid].t(
                [BATCH_DICT[id].obs_ for id in self.agent_ids],
                target_actions)  # [T, B, 1]
            q_targets2[mid] = self.critics2[mid].t(
                [BATCH_DICT[id].obs_ for id in self.agent_ids],
                target_actions)  # [T, B, 1]

        q_loss = {}
        td_errors = 0.
        for aid, mid in zip(self.agent_ids, self.model_ids):
            q_target = th.minimum(q_targets1[mid],
                                  q_targets2[mid])  # [T, B, 1]
            dc_r = n_step_return(
                BATCH_DICT[aid].reward, self.gamma, BATCH_DICT[aid].done,
                q_target - self.alpha * target_log_pis,
                BATCH_DICT['global'].begin_mask).detach()  # [T, B, 1]
            td_error1 = qs1[mid] - dc_r  # [T, B, 1]
            td_error2 = qs2[mid] - dc_r  # [T, B, 1]
            td_errors += (td_error1 + td_error2) / 2
            q1_loss = td_error1.square().mean()  # 1
            q2_loss = td_error2.square().mean()  # 1
            q_loss[aid] = 0.5 * q1_loss + 0.5 * q2_loss
            summaries[aid].update({
                'Statistics/q_min': qs1[mid].min(),
                'Statistics/q_mean': qs1[mid].mean(),
                'Statistics/q_max': qs1[mid].max()
            })
        self.critic_oplr.optimize(sum(q_loss.values()))

        log_pi_actions = {}
        log_pis = {}
        sample_pis = {}
        for aid, mid in zip(self.agent_ids, self.model_ids):
            if self.is_continuouss[aid]:
                mu, log_std = self.actors[mid](
                    BATCH_DICT[aid].obs,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
                dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
                pi = dist.rsample()  # [T, B, A]
                pi, log_pi = squash_action(
                    pi,
                    dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
                pi_action = BATCH_DICT[aid].action.arctanh()
                _, log_pi_action = squash_action(
                    pi_action,
                    dist.log_prob(pi_action).unsqueeze(
                        -1))  # [T, B, A], [T, B, 1]
            else:
                logits = self.actors[mid](
                    BATCH_DICT[aid].obs,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
                logp_all = logits.log_softmax(-1)  # [T, B, A]
                gumbel_noise = td.Gumbel(0,
                                         1).sample(logp_all.shape)  # [T, B, A]
                _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                    -1)  # [T, B, A]
                _pi_true_one_hot = F.one_hot(
                    _pi.argmax(-1), self.a_dims[aid]).float()  # [T, B, A]
                _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
                pi = _pi_diff + _pi  # [T, B, A]
                log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
                log_pi_action = (logp_all * BATCH_DICT[aid].action).sum(
                    -1, keepdim=True)  # [T, B, 1]
            log_pi_actions[aid] = log_pi_action
            log_pis[aid] = log_pi
            sample_pis[aid] = pi

        actor_loss = {}
        for aid, mid in zip(self.agent_ids, self.model_ids):
            all_actions = {id: BATCH_DICT[id].action for id in self.agent_ids}
            all_actions[aid] = sample_pis[aid]
            all_log_pis = {id: log_pi_actions[id] for id in self.agent_ids}
            all_log_pis[aid] = log_pis[aid]

            q_s_pi = th.minimum(
                self.critics[mid](
                    [BATCH_DICT[id].obs for id in self.agent_ids],
                    th.cat(list(all_actions.values()), -1),
                    begin_mask=BATCH_DICT['global'].begin_mask),
                self.critics2[mid](
                    [BATCH_DICT[id].obs for id in self.agent_ids],
                    th.cat(list(all_actions.values()), -1),
                    begin_mask=BATCH_DICT['global'].begin_mask))  # [T, B, 1]

            _log_pis = 1.
            for _log_pi in all_log_pis.values():
                _log_pis *= _log_pi
            _log_pis += th.finfo().eps
            actor_loss[aid] = -(q_s_pi - self.alpha * _log_pis).mean()  # 1

        self.actor_oplr.optimize(sum(actor_loss.values()))

        for aid in self.agent_ids:
            summaries[aid].update({
                'LOSS/actor_loss': actor_loss[aid],
                'LOSS/critic_loss': q_loss[aid]
            })
        summaries['model'].update({
            'LOSS/actor_loss': sum(actor_loss.values()),
            'LOSS/critic_loss': sum(q_loss.values())
        })

        if self.auto_adaption:
            _log_pis = 1.
            for _log_pi in log_pis.values():
                _log_pis *= _log_pi
            _log_pis += th.finfo().eps

            alpha_loss = -(
                self.alpha *
                (_log_pis + self.target_entropy).detach()).mean()  # 1

            self.alpha_oplr.optimize(alpha_loss)
            summaries['model'].update({
                'LOSS/alpha_loss':
                alpha_loss,
                'LEARNING_RATE/alpha_lr':
                self.alpha_oplr.lr
            })
        return td_errors / self.n_agents_percopy, summaries

    def _after_train(self):
        super()._after_train()
        if self.annealing and not self.auto_adaption:
            self.log_alpha.copy_(
                self.alpha_annealing(self._cur_train_step).log())
        for critic in self.critics.values():
            critic.sync()
        for critic2 in self.critics2.values():
            critic2.sync()
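
`squash_action` is imported from the library's distribution utilities. A sketch of the tanh squashing and log-probability correction it is assumed to apply (illustrative only):

import torch as th


def squash_action_sketch(pi, log_pi):
    # pi: [T, B, A] pre-tanh action sample; log_pi: [T, B, 1] its log-probability.
    # Change of variables for a = tanh(u):
    #   log pi(a) = log pi(u) - sum_i log(1 - tanh(u_i)^2)
    squashed = pi.tanh()  # [T, B, A]
    correction = th.log(1. - squashed.pow(2) + th.finfo().eps)  # [T, B, A]
    return squashed, log_pi - correction.sum(-1, keepdim=True)
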
Exemple #14
class MVE(DDPG):
    """
    Model-Based Value Estimation for Efficient Model-Free Reinforcement Learning, http://arxiv.org/abs/1803.00101
    """
    policy_mode = 'off-policy'

    def __init__(self, wm_lr=1e-3, roll_out_horizon=15, **kwargs):
        super().__init__(**kwargs)
        network_settings = kwargs.get('network_settings', {})
        assert not self.obs_spec.has_visual_observation, \
            'MVE does not support visual observations'
        assert self.obs_spec.has_vector_observation, \
            'MVE requires vector observations'

        self._wm_lr = wm_lr
        self._roll_out_horizon = roll_out_horizon
        self._forward_dynamic_model = VectorSA2S(
            self.obs_spec.vector_dims[0],
            self.a_dim,
            hidden_units=network_settings['forward_model'])
        self._reward_model = VectorSA2R(
            self.obs_spec.vector_dims[0],
            self.a_dim,
            hidden_units=network_settings['reward_model'])
        self._done_model = VectorSA2D(
            self.obs_spec.vector_dims[0],
            self.a_dim,
            hidden_units=network_settings['done_model'])
        self._wm_oplr = OPLR([
            self._forward_dynamic_model, self._reward_model, self._done_model
        ], self._wm_lr, **self._oplr_params)
        self._trainer_modules.update(
            _forward_dynamic_model=self._forward_dynamic_model,
            _reward_model=self._reward_model,
            _done_model=self._done_model,
            _wm_oplr=self._wm_oplr)

    @iton
    def _train(self, BATCH):

        obs = get_first_vector(BATCH.obs)  # [T, B, S]
        obs_ = get_first_vector(BATCH.obs_)  # [T, B, S]
        _timestep = obs.shape[0]
        _batchsize = obs.shape[1]
        predicted_obs_ = self._forward_dynamic_model(obs,
                                                     BATCH.action)  # [T, B, S]
        predicted_reward = self._reward_model(obs, BATCH.action)  # [T, B, 1]
        predicted_done_dist = self._done_model(obs, BATCH.action)  # [T, B, 1]
        _obs_loss = F.mse_loss(obs_, predicted_obs_)  # todo
        _reward_loss = F.mse_loss(BATCH.reward, predicted_reward)
        _done_loss = -predicted_done_dist.log_prob(BATCH.done).mean()
        wm_loss = _obs_loss + _reward_loss + _done_loss
        self._wm_oplr.optimize(wm_loss)

        obs = th.reshape(obs, (_timestep * _batchsize, -1))  # [T*B, S]
        obs_ = th.reshape(obs_, (_timestep * _batchsize, -1))  # [T*B, S]
        actions = th.reshape(BATCH.action,
                             (_timestep * _batchsize, -1))  # [T*B, A]
        rewards = th.reshape(BATCH.reward,
                             (_timestep * _batchsize, -1))  # [T*B, 1]
        dones = th.reshape(BATCH.done,
                           (_timestep * _batchsize, -1))  # [T*B, 1]

        rollout_rewards = [rewards]
        rollout_dones = [dones]

        r_obs_ = obs_
        _r_obs = deepcopy(BATCH.obs_)
        r_done = (1. - dones)

        for _ in range(self._roll_out_horizon):
            r_obs = r_obs_
            _r_obs.vector.vector_0 = r_obs
            if self.is_continuous:
                r_action = self.actor.t(_r_obs)  # [T*B, A]
                if self.use_target_action_noise:
                    r_action = self.target_noised_action(r_action)  # [T*B, A]
            else:
                target_logits = self.actor.t(_r_obs)  # [T*B, A]
                target_cate_dist = td.Categorical(logits=target_logits)
                target_pi = target_cate_dist.sample()  # [T*B,]
                r_action = F.one_hot(target_pi, self.a_dim).float()  # [T*B, A]
            r_obs_ = self._forward_dynamic_model(r_obs, r_action)  # [T*B, S]
            r_reward = self._reward_model(r_obs, r_action)  # [T*B, 1]
            r_done = r_done * (1. - self._done_model(r_obs, r_action).sample()
                               )  # [T*B, 1]

            rollout_rewards.append(r_reward)  # [H+1, T*B, 1]
            rollout_dones.append(r_done)  # [H+1, T*B, 1]

        _r_obs.vector.vector_0 = obs
        q = self.critic(_r_obs, actions)  # [T*B, 1]
        _r_obs.vector.vector_0 = r_obs_
        q_target = self.critic.t(_r_obs, r_action)  # [T*B, 1]
        dc_r = rewards.clone()  # avoid in-place mutation of the reshaped rewards
        for t in range(1, self._roll_out_horizon):
            dc_r += (self.gamma**t) * (rollout_rewards[t] * rollout_dones[t])
        dc_r += (self.gamma**self._roll_out_horizon) * rollout_dones[
            self._roll_out_horizon] * q_target  # [T*B, 1]

        td_error = dc_r - q  # [T*B, 1]
        q_loss = td_error.square().mean()  # 1
        self.critic_oplr.optimize(q_loss)

        # train actor
        if self.is_continuous:
            mu = self.actor(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                -1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            mu = _pi_diff + _pi  # [T, B, A]
        q_actor = self.critic(BATCH.obs, mu,
                              begin_mask=BATCH.begin_mask)  # [T, B, 1]
        actor_loss = -q_actor.mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        return th.ones_like(BATCH.reward), {
            'LEARNING_RATE/wm_lr': self._wm_oplr.lr,
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/wm_loss': wm_loss,
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': q_loss,
            'Statistics/q_min': q.min(),
            'Statistics/q_mean': q.mean(),
            'Statistics/q_max': q.max()
        }
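
MVE pulls its world-model layer sizes out of the same `network_settings` dict that DDPG receives through `**kwargs`, so the caller must supply the extra keys. A hypothetical configuration sketch (all sizes are illustrative, and the remaining environment-specific kwargs are omitted):

mve_network_settings = {
    # DDPG's own networks
    'actor_continuous': [32, 32],
    'actor_discrete': [32, 32],
    'q': [32, 32],
    # extra keys read by MVE's world model (assumed sizes)
    'forward_model': [64, 64],
    'reward_model': [64, 64],
    'done_model': [64, 64],
}
# agent = MVE(wm_lr=1e-3, roll_out_horizon=15,
#             network_settings=mve_network_settings, ...)
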
Exemple #15
class PPO(SarlOnPolicy):
    """
    Proximal Policy Optimization, https://arxiv.org/abs/1707.06347
    Emergence of Locomotion Behaviours in Rich Environments, http://arxiv.org/abs/1707.02286, DPPO
    """
    policy_mode = 'on-policy'

    def __init__(self,
                 agent_spec,
                 ent_coef: float = 1.0e-2,
                 vf_coef: float = 0.5,
                 lr: float = 5.0e-4,
                 lambda_: float = 0.95,
                 epsilon: float = 0.2,
                 use_duel_clip: bool = False,
                 duel_epsilon: float = 0.,
                 use_vclip: bool = False,
                 value_epsilon: float = 0.2,
                 share_net: bool = True,
                 actor_lr: float = 3e-4,
                 critic_lr: float = 1e-3,
                 kl_reverse: bool = False,
                 kl_target: float = 0.02,
                 kl_target_cutoff: float = 2,
                 kl_target_earlystop: float = 4,
                 kl_beta: List[float] = [0.7, 1.3],
                 kl_alpha: float = 1.5,
                 kl_coef: float = 1.0,
                 extra_coef: float = 1000.0,
                 use_kl_loss: bool = False,
                 use_extra_loss: bool = False,
                 use_early_stop: bool = False,
                 network_settings: Dict = {
                     'share': {
                         'continuous': {
                             'condition_sigma': False,
                             'log_std_bound': [-20, 2],
                             'share': [32, 32],
                             'mu': [32, 32],
                             'v': [32, 32]
                         },
                         'discrete': {
                             'share': [32, 32],
                             'logits': [32, 32],
                             'v': [32, 32]
                         }
                     },
                     'actor_continuous': {
                         'hidden_units': [64, 64],
                         'condition_sigma': False,
                         'log_std_bound': [-20, 2]
                     },
                     'actor_discrete': [32, 32],
                     'critic': [32, 32]
                 },
                 **kwargs):
        super().__init__(agent_spec=agent_spec, **kwargs)
        self._ent_coef = ent_coef
        self.lambda_ = lambda_
        assert 0.0 <= lambda_ <= 1.0, "GAE lambda should be in [0, 1]."
        self._epsilon = epsilon
        self._use_vclip = use_vclip
        self._value_epsilon = value_epsilon
        self._share_net = share_net
        self._kl_reverse = kl_reverse
        self._kl_target = kl_target
        self._kl_alpha = kl_alpha
        self._kl_coef = kl_coef
        self._extra_coef = extra_coef
        self._vf_coef = vf_coef

        self._use_duel_clip = use_duel_clip
        self._duel_epsilon = duel_epsilon
        if self._use_duel_clip:
            assert -self._epsilon < self._duel_epsilon < self._epsilon, \
                "duel_epsilon should be set in the range of (-epsilon, epsilon)."

        self._kl_cutoff = kl_target * kl_target_cutoff
        self._kl_stop = kl_target * kl_target_earlystop
        self._kl_low = kl_target * kl_beta[0]
        self._kl_high = kl_target * kl_beta[-1]

        self._use_kl_loss = use_kl_loss
        self._use_extra_loss = use_extra_loss
        self._use_early_stop = use_early_stop

        if self._share_net:
            if self.is_continuous:
                self.net = ActorCriticValueCts(self.obs_spec,
                                               rep_net_params=self._rep_net_params,
                                               output_shape=self.a_dim,
                                               network_settings=network_settings['share']['continuous']).to(self.device)
            else:
                self.net = ActorCriticValueDct(self.obs_spec,
                                               rep_net_params=self._rep_net_params,
                                               output_shape=self.a_dim,
                                               network_settings=network_settings['share']['discrete']).to(self.device)
            self.oplr = OPLR(self.net, lr, **self._oplr_params)
            self._trainer_modules.update(model=self.net,
                                         oplr=self.oplr)
        else:
            if self.is_continuous:
                self.actor = ActorMuLogstd(self.obs_spec,
                                           rep_net_params=self._rep_net_params,
                                           output_shape=self.a_dim,
                                           network_settings=network_settings['actor_continuous']).to(self.device)
            else:
                self.actor = ActorDct(self.obs_spec,
                                      rep_net_params=self._rep_net_params,
                                      output_shape=self.a_dim,
                                      network_settings=network_settings['actor_discrete']).to(self.device)
            self.critic = CriticValue(self.obs_spec,
                                      rep_net_params=self._rep_net_params,
                                      network_settings=network_settings['critic']).to(self.device)
            self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
            self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params)
            self._trainer_modules.update(actor=self.actor,
                                         critic=self.critic,
                                         actor_oplr=self.actor_oplr,
                                         critic_oplr=self.critic_oplr)

    @iton
    def select_action(self, obs):
        if self.is_continuous:
            if self._share_net:
                mu, log_std, value = self.net(obs, rnncs=self.rnncs)  # [B, A]
                self.rnncs_ = self.net.get_rnncs()
            else:
                mu, log_std = self.actor(obs, rnncs=self.rnncs)  # [B, A]
                self.rnncs_ = self.actor.get_rnncs()
                value = self.critic(obs, rnncs=self.rnncs)  # [B, 1]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            action = dist.sample().clamp(-1, 1)  # [B, A]
            log_prob = dist.log_prob(action).unsqueeze(-1)  # [B, 1]
        else:
            if self._share_net:
                logits, value = self.net(obs, rnncs=self.rnncs)  # [B, A], [B, 1]
                self.rnncs_ = self.net.get_rnncs()
            else:
                logits = self.actor(obs, rnncs=self.rnncs)  # [B, A]
                self.rnncs_ = self.actor.get_rnncs()
                value = self.critic(obs, rnncs=self.rnncs)  # [B, 1]
            norm_dist = td.Categorical(logits=logits)
            action = norm_dist.sample()  # [B,]
            log_prob = norm_dist.log_prob(action).unsqueeze(-1)  # [B, 1]

        acts_info = Data(action=action,
                         value=value,
                         log_prob=log_prob + th.finfo().eps)
        if self.use_rnn:
            acts_info.update(rnncs=self.rnncs)
        return action, acts_info

    @iton
    def _get_value(self, obs, rnncs=None):
        if self._share_net:
            if self.is_continuous:
                _, _, value = self.net(obs, rnncs=rnncs)  # [B, 1]
            else:
                _, value = self.net(obs, rnncs=rnncs)  # [B, 1]
        else:
            value = self.critic(obs, rnncs=rnncs)  # [B, 1]
        return value

    def _preprocess_BATCH(self, BATCH):  # [T, B, *]
        BATCH = super()._preprocess_BATCH(BATCH)
        value = self._get_value(BATCH.obs_[-1], rnncs=self.rnncs)
        BATCH.discounted_reward = discounted_sum(BATCH.reward,
                                                 self.gamma,
                                                 BATCH.done,
                                                 BATCH.begin_mask,
                                                 init_value=value)
        td_error = calculate_td_error(BATCH.reward,
                                      self.gamma,
                                      BATCH.done,
                                      value=BATCH.value,
                                      next_value=np.concatenate((BATCH.value[1:], value[np.newaxis, :]), 0))
        BATCH.gae_adv = discounted_sum(td_error,
                                       self.lambda_ * self.gamma,
                                       BATCH.done,
                                       BATCH.begin_mask,
                                       init_value=0.,
                                       normalize=True)
        return BATCH

    def learn(self, BATCH: Data):
        BATCH = self._preprocess_BATCH(BATCH)  # [T, B, *]
        for _ in range(self._epochs):
            kls = []
            for _BATCH in BATCH.sample(self._chunk_length, self.batch_size, repeat=self._sample_allow_repeat):
                _BATCH = self._before_train(_BATCH)
                summaries, kl = self._train(_BATCH)
                kls.append(kl)
                self.summaries.update(summaries)
                self._after_train()
            if self._use_early_stop and sum(kls) / len(kls) > self._kl_stop:
                break

    def _train(self, BATCH):
        if self._share_net:
            summaries, kl = self.train_share(BATCH)
        else:
            summaries = dict()
            actor_summaries, kl = self.train_actor(BATCH)
            critic_summaries = self.train_critic(BATCH)
            summaries.update(actor_summaries)
            summaries.update(critic_summaries)

        if self._use_kl_loss:
            # ref: https://github.com/joschu/modular_rl/blob/6970cde3da265cf2a98537250fea5e0c0d9a7639/modular_rl/ppo.py#L93
            if kl > self._kl_high:
                self._kl_coef *= self._kl_alpha
            elif kl < self._kl_low:
                self._kl_coef /= self._kl_alpha
            summaries.update({
                'Statistics/kl_coef': self._kl_coef
            })

        return summaries, kl

    @iton
    def train_share(self, BATCH):
        if self.is_continuous:
            # [T, B, A], [T, B, A], [T, B, 1]
            mu, log_std, value = self.net(BATCH.obs, begin_mask=BATCH.begin_mask)
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            new_log_prob = dist.log_prob(BATCH.action).unsqueeze(-1)  # [T, B, 1]
            entropy = dist.entropy().unsqueeze(-1)  # [T, B, 1]
        else:
            # [T, B, A], [T, B, 1]
            logits, value = self.net(BATCH.obs, begin_mask=BATCH.begin_mask)
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            new_log_prob = (BATCH.action * logp_all).sum(-1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(-1, keepdim=True)  # [T, B, 1]
        ratio = (new_log_prob - BATCH.log_prob).exp()  # [T, B, 1]
        surrogate = ratio * BATCH.gae_adv  # [T, B, 1]
        clipped_surrogate = th.minimum(
            surrogate,
            ratio.clamp(1.0 - self._epsilon, 1.0 + self._epsilon) * BATCH.gae_adv
        )  # [T, B, 1]
        # ref: https://github.com/thu-ml/tianshou/blob/c97aa4065ee8464bd5897bb86f1f81abd8e2cff9/tianshou/policy/modelfree/ppo.py#L159
        if self._use_duel_clip:
            clipped_surrogate2 = th.maximum(
                clipped_surrogate,
                (1.0 + self._duel_epsilon) * BATCH.gae_adv
            )  # [T, B, 1]
            clipped_surrogate = th.where(BATCH.gae_adv < 0, clipped_surrogate2, clipped_surrogate)  # [T, B, 1]
        actor_loss = -(clipped_surrogate + self._ent_coef * entropy).mean()  # 1

        # ref: https://github.com/joschu/modular_rl/blob/6970cde3da265cf2a98537250fea5e0c0d9a7639/modular_rl/ppo.py#L40
        # ref: https://github.com/hill-a/stable-baselines/blob/b3f414f4f2900403107357a2206f80868af16da3/stable_baselines/ppo2/ppo2.py#L185
        if self._kl_reverse:  # TODO:
            kl = .5 * (new_log_prob - BATCH.log_prob).square().mean()  # 1
        else:
            # a sample estimate for KL-divergence, easy to compute
            kl = .5 * (BATCH.log_prob - new_log_prob).square().mean()

        if self._use_kl_loss:
            kl_loss = self._kl_coef * kl  # 1
            actor_loss += kl_loss

        if self._use_extra_loss:
            extra_loss = self._extra_coef * th.maximum(th.zeros_like(kl), kl - self._kl_cutoff).square().mean()  # 1
            actor_loss += extra_loss

        td_error = BATCH.discounted_reward - value  # [T, B, 1]
        if self._use_vclip:
            # ref: https://github.com/llSourcell/OpenAI_Five_vs_Dota2_Explained/blob/c5def7e57aa70785c2394ea2eeb3e5f66ad59a53/train.py#L154
            # ref: https://github.com/hill-a/stable-baselines/blob/b3f414f4f2900403107357a2206f80868af16da3/stable_baselines/ppo2/ppo2.py#L172
            value_clip = BATCH.value + (value - BATCH.value).clamp(-self._value_epsilon,
                                                                   self._value_epsilon)  # [T, B, 1]
            td_error_clip = BATCH.discounted_reward - value_clip  # [T, B, 1]
            td_square = th.maximum(td_error.square(), td_error_clip.square())  # [T, B, 1]
        else:
            td_square = td_error.square()  # [T, B, 1]

        critic_loss = 0.5 * td_square.mean()  # 1
        loss = actor_loss + self._vf_coef * critic_loss  # 1
        self.oplr.optimize(loss)
        return {
                   'LOSS/actor_loss': actor_loss,
                   'LOSS/critic_loss': critic_loss,
                   'Statistics/kl': kl,
                   'Statistics/entropy': entropy.mean(),
                   'LEARNING_RATE/lr': self.oplr.lr
               }, kl

    @iton
    def train_actor(self, BATCH):
        if self.is_continuous:
            # [T, B, A], [T, B, A]
            mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            new_log_prob = dist.log_prob(BATCH.action).unsqueeze(-1)  # [T, B, 1]
            entropy = dist.entropy().unsqueeze(-1)  # [T, B, 1]
        else:
            logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            new_log_prob = (BATCH.action * logp_all).sum(-1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(-1, keepdim=True)  # [T, B, 1]
        ratio = (new_log_prob - BATCH.log_prob).exp()  # [T, B, 1]
        kl = (BATCH.log_prob - new_log_prob).square().mean()  # 1
        surrogate = ratio * BATCH.gae_adv  # [T, B, 1]
        clipped_surrogate = th.minimum(
            surrogate,
            th.where(BATCH.gae_adv > 0, (1 + self._epsilon) *
                     BATCH.gae_adv, (1 - self._epsilon) * BATCH.gae_adv)
        )  # [T, B, 1]
        if self._use_duel_clip:
            clipped_surrogate = th.maximum(
                clipped_surrogate,
                (1.0 + self._duel_epsilon) * BATCH.gae_adv
            )  # [T, B, 1]

        actor_loss = -(clipped_surrogate + self._ent_coef * entropy).mean()  # 1

        if self._use_kl_loss:
            kl_loss = self._kl_coef * kl  # 1
            actor_loss += kl_loss
        if self._use_extra_loss:
            extra_loss = self._extra_coef * th.maximum(th.zeros_like(kl), kl - self._kl_cutoff).square().mean()  # 1
            actor_loss += extra_loss

        self.actor_oplr.optimize(actor_loss)
        return {
                   'LOSS/actor_loss': actor_loss,
                   'Statistics/kl': kl,
                   'Statistics/entropy': entropy.mean(),
                   'LEARNING_RATE/actor_lr': self.actor_oplr.lr
               }, kl

    @iton
    def train_critic(self, BATCH):
        value = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, 1]

        td_error = BATCH.discounted_reward - value  # [T, B, 1]
        if self._use_vclip:
            value_clip = BATCH.value + (value - BATCH.value).clamp(-self._value_epsilon,
                                                                   self._value_epsilon)  # [T, B, 1]
            td_error_clip = BATCH.discounted_reward - value_clip  # [T, B, 1]
            td_square = th.maximum(td_error.square(), td_error_clip.square())  # [T, B, 1]
        else:
            td_square = td_error.square()  # [T, B, 1]

        critic_loss = 0.5 * td_square.mean()  # 1
        self.critic_oplr.optimize(critic_loss)
        return {
            'LOSS/critic_loss': critic_loss,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr
        }
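
A minimal, framework-free sketch of the two quantities the class above relies on: the normalized GAE advantages built in `_preprocess_BATCH` and the clipped surrogate objective of `train_share`/`train_actor`. Function names and the flat [T]-shaped inputs are illustrative assumptions, not the library's API.

import numpy as np

def gae(rewards, values, next_value, dones, gamma=0.99, lam=0.95):
    # Generalized Advantage Estimation over a single [T]-length trajectory.
    T = len(rewards)
    values = np.append(values, next_value)              # [T + 1]
    adv = np.zeros(T)
    last = 0.0
    for t in reversed(range(T)):
        nonterminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * values[t + 1] * nonterminal - values[t]
        last = delta + gamma * lam * nonterminal * last
        adv[t] = last
    return (adv - adv.mean()) / (adv.std() + 1e-8)      # normalized, as above

def clipped_surrogate(new_logp, old_logp, adv, epsilon=0.2):
    # PPO surrogate objective (negate it to obtain a loss to minimize).
    ratio = np.exp(new_logp - old_logp)
    return np.minimum(ratio * adv,
                      np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * adv).mean()
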
Example #16
class AveragedDQN(SarlOffPolicy):
    """
    Averaged-DQN, http://arxiv.org/abs/1611.01929
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 target_k: int = 4,
                 lr: float = 5.0e-4,
                 eps_init: float = 1,
                 eps_mid: float = 0.2,
                 eps_final: float = 0.01,
                 init2mid_annealing_step: int = 1000,
                 assign_interval: int = 1000,
                 network_settings: List[int] = [32, 32],
                 **kwargs):
        super().__init__(**kwargs)
        assert not self.is_continuous, 'DQN only supports discrete action spaces'
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self._max_train_step)
        self.assign_interval = assign_interval
        self.target_k = target_k
        assert self.target_k > 0, "assert self.target_k > 0"
        self.current_target_idx = 0

        self.q_net = CriticQvalueAll(self.obs_spec,
                                     rep_net_params=self._rep_net_params,
                                     output_shape=self.a_dim,
                                     network_settings=network_settings).to(
                                         self.device)
        self.target_nets = []
        for i in range(self.target_k):
            target_q_net = deepcopy(self.q_net)
            target_q_net.eval()
            sync_params(target_q_net, self.q_net)
            self.target_nets.append(target_q_net)

        self.oplr = OPLR(self.q_net, lr, **self._oplr_params)
        self._trainer_modules.update(model=self.q_net, oplr=self.oplr)

    @iton
    def select_action(self, obs):
        q_values = self.q_net(obs, rnncs=self.rnncs)  # [B, *]
        self.rnncs_ = self.q_net.get_rnncs()

        if self._is_train_mode and self.expl_expt_mng.is_random(
                self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            for i in range(self.target_k):
                target_q_values = self.target_nets[i](obs, rnncs=self.rnncs)
                q_values += target_q_values
                actions = q_values.argmax(-1)  # averaging is optional before the argmax  [B, ]
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, *]
        q_next = 0
        for i in range(self.target_k):
            q_next += self.target_nets[i](BATCH.obs_,
                                          begin_mask=BATCH.begin_mask)
        q_next /= self.target_k  # [T, B, *]
        q_eval = (q * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        q_target = n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                 q_next.max(-1, keepdim=True)[0],
                                 BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1]
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1

        self.oplr.optimize(q_loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': q_loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            sync_params(self.target_nets[self.current_target_idx], self.q_net)
            self.current_target_idx = (self.current_target_idx +
                                       1) % self.target_k
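
For reference, the Averaged-DQN bootstrap used in `_train` above can be written as a small standalone sketch: the TD target takes the max over the *average* of K frozen target networks, which lowers the variance of the target and the resulting over-estimation. `target_nets`, `obs_next`, `reward` and `done` are assumed shapes/objects for illustration only.

import torch as th

def averaged_dqn_target(target_nets, obs_next, reward, done, gamma=0.99):
    # target_nets: list of K modules mapping obs -> [B, A] Q-values
    # reward, done: [B, 1] float tensors
    with th.no_grad():
        q_next = sum(net(obs_next) for net in target_nets) / len(target_nets)  # [B, A]
        return reward + gamma * (1.0 - done) * q_next.max(-1, keepdim=True)[0]  # [B, 1]
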
Example #17
class SAC_V(SarlOffPolicy):
    """
        Soft Actor Critic with Value neural network. https://arxiv.org/abs/1812.05905
        Soft Actor-Critic for Discrete Action Settings. https://arxiv.org/abs/1910.07207
    """
    policy_mode = 'off-policy'

    def __init__(
            self,
            alpha=0.2,
            annealing=True,
            last_alpha=0.01,
            polyak=0.995,
            use_gumbel=True,
            discrete_tau=1.0,
            network_settings={
                'actor_continuous': {
                    'share': [128, 128],
                    'mu': [64],
                    'log_std': [64],
                    'soft_clip': False,
                    'log_std_bound': [-20, 2]
                },
                'actor_discrete': [64, 32],
                'q': [128, 128],
                'v': [128, 128]
            },
            actor_lr=5.0e-4,
            critic_lr=1.0e-3,
            alpha_lr=5.0e-4,
            auto_adaption=True,
            **kwargs):
        super().__init__(**kwargs)
        self.polyak = polyak
        self.use_gumbel = use_gumbel
        self.discrete_tau = discrete_tau
        self.auto_adaption = auto_adaption
        self.annealing = annealing

        self.v_net = TargetTwin(
            CriticValue(self.obs_spec,
                        rep_net_params=self._rep_net_params,
                        network_settings=network_settings['v']),
            self.polyak).to(self.device)

        if self.is_continuous:
            self.actor = ActorCts(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_continuous']).to(
                    self.device)
        else:
            self.actor = ActorDct(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_discrete']).to(
                    self.device)

        # entropy = -log(1/|A|) = log |A|
        self.target_entropy = 0.98 * (-self.a_dim if self.is_continuous else
                                      np.log(self.a_dim))

        if self.is_continuous or self.use_gumbel:
            self.q_net = CriticQvalueOne(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                action_dim=self.a_dim,
                network_settings=network_settings['q']).to(self.device)
        else:
            self.q_net = CriticQvalueAll(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['q']).to(self.device)
        self.q_net2 = deepcopy(self.q_net)

        self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
        self.critic_oplr = OPLR([self.q_net, self.q_net2, self.v_net],
                                critic_lr, **self._oplr_params)

        if self.auto_adaption:
            # create as a leaf tensor directly on the target device so the optimizer updates it
            self.log_alpha = th.tensor(0., requires_grad=True, device=self.device)
            self.alpha_oplr = OPLR(self.log_alpha, alpha_lr,
                                   **self._oplr_params)
            self._trainer_modules.update(alpha_oplr=self.alpha_oplr)
        else:
            self.log_alpha = th.tensor(alpha).log().to(self.device)
            if self.annealing:
                self.alpha_annealing = LinearAnnealing(alpha, last_alpha,
                                                       int(1e6))

        self._trainer_modules.update(actor=self.actor,
                                     v_net=self.v_net,
                                     q_net=self.q_net,
                                     q_net2=self.q_net2,
                                     log_alpha=self.log_alpha,
                                     actor_oplr=self.actor_oplr,
                                     critic_oplr=self.critic_oplr)

    @property
    def alpha(self):
        return self.log_alpha.exp()

    @iton
    def select_action(self, obs):
        if self.is_continuous:
            mu, log_std = self.actor(obs, rnncs=self.rnncs)  # [B, A]
            pi = td.Normal(mu, log_std.exp()).sample().tanh()  # [B, A]
            mu.tanh_()  # squash mu   # [B, A]
        else:
            logits = self.actor(obs, rnncs=self.rnncs)  # [B, A]
            mu = logits.argmax(-1)  # [B,]
            cate_dist = td.Categorical(logits=logits)
            pi = cate_dist.sample()  # [B,]
        self.rnncs_ = self.actor.get_rnncs()
        actions = pi if self._is_train_mode else mu
        return actions, Data(action=actions)

    def _train(self, BATCH):
        if self.is_continuous or self.use_gumbel:
            td_error, summaries = self._train_continuous(BATCH)
        else:
            td_error, summaries = self._train_discrete(BATCH)
        return td_error, summaries

    @iton
    def _train_continuous(self, BATCH):
        v = self.v_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        v_target = self.v_net.t(BATCH.obs_,
                                begin_mask=BATCH.begin_mask)  # [T, B, 1]

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs,
                                     begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            pi = dist.rsample()  # [T, B, A]
            pi, log_pi = squash_action(
                pi,
                dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                -1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            pi = _pi_diff + _pi  # [T, B, A]
            log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
        q1 = self.q_net(BATCH.obs, BATCH.action,
                        begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2 = self.q_net2(BATCH.obs, BATCH.action,
                         begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q1_pi = self.q_net(BATCH.obs, pi,
                           begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2_pi = self.q_net2(BATCH.obs, pi,
                            begin_mask=BATCH.begin_mask)  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, v_target,
                             BATCH.begin_mask).detach()  # [T, B, 1]
        v_from_q_stop = (th.minimum(q1_pi, q2_pi) -
                         self.alpha * log_pi).detach()  # [T, B, 1]
        td_v = v - v_from_q_stop  # [T, B, 1]
        td_error1 = q1 - dc_r  # [T, B, 1]
        td_error2 = q2 - dc_r  # [T, B, 1]
        q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
        q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
        v_loss_stop = (td_v.square() * BATCH.get('isw', 1.0)).mean()  # 1

        critic_loss = 0.5 * q1_loss + 0.5 * q2_loss + 0.5 * v_loss_stop
        self.critic_oplr.optimize(critic_loss)

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs,
                                     begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            pi = dist.rsample()  # [T, B, A]
            pi, log_pi = squash_action(
                pi,
                dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
            entropy = dist.entropy().mean()  # 1
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                -1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            pi = _pi_diff + _pi  # [T, B, A]
            log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
        q1_pi = self.q_net(BATCH.obs, pi,
                           begin_mask=BATCH.begin_mask)  # [T, B, 1]
        actor_loss = -(q1_pi - self.alpha * log_pi).mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        summaries = {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/q1_loss': q1_loss,
            'LOSS/q2_loss': q2_loss,
            'LOSS/v_loss': v_loss_stop,
            'LOSS/critic_loss': critic_loss,
            'Statistics/log_alpha': self.log_alpha,
            'Statistics/alpha': self.alpha,
            'Statistics/entropy': entropy,
            'Statistics/q_min': th.minimum(q1, q2).min(),
            'Statistics/q_mean': th.minimum(q1, q2).mean(),
            'Statistics/q_max': th.maximum(q1, q2).max(),
            'Statistics/v_mean': v.mean()
        }
        if self.auto_adaption:
            alpha_loss = -(self.alpha *
                           (log_pi.detach() + self.target_entropy)).mean()
            self.alpha_oplr.optimize(alpha_loss)
            summaries.update({
                'LOSS/alpha_loss': alpha_loss,
                'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
            })
        return (td_error1 + td_error2) / 2, summaries

    @iton
    def _train_discrete(self, BATCH):
        v = self.v_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        v_target = self.v_net.t(BATCH.obs_,
                                begin_mask=BATCH.begin_mask)  # [T, B, 1]

        q1_all = self.q_net(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        q2_all = self.q_net2(BATCH.obs,
                             begin_mask=BATCH.begin_mask)  # [T, B, A]
        q1 = (q1_all * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        q2 = (q2_all * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        logits = self.actor(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]

        dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done, v_target,
                             BATCH.begin_mask).detach()  # [T, B, 1]
        td_v = v - (th.minimum((logp_all.exp() * q1_all).sum(-1, keepdim=True),
                               (logp_all.exp() * q2_all).sum(
                                   -1, keepdim=True))).detach()  # [T, B, 1]
        td_error1 = q1 - dc_r  # [T, B, 1]
        td_error2 = q2 - dc_r  # [T, B, 1]

        q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
        q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
        v_loss_stop = (td_v.square() * BATCH.get('isw', 1.0)).mean()  # 1
        critic_loss = 0.5 * q1_loss + 0.5 * q2_loss + 0.5 * v_loss_stop
        self.critic_oplr.optimize(critic_loss)

        q1_all = self.q_net(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        q2_all = self.q_net2(BATCH.obs,
                             begin_mask=BATCH.begin_mask)  # [T, B, A]
        logits = self.actor(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]

        entropy = -(logp_all.exp() * logp_all).sum(-1,
                                                   keepdim=True)  # [T, B, 1]
        q_all = th.minimum(q1_all, q2_all)  # [T, B, A]
        actor_loss = -((q_all - self.alpha * logp_all) * logp_all.exp()).sum(
            -1)  # [T, B, A] => [T, B]
        actor_loss = actor_loss.mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        summaries = {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/q1_loss': q1_loss,
            'LOSS/q2_loss': q2_loss,
            'LOSS/v_loss': v_loss_stop,
            'LOSS/critic_loss': critic_loss,
            'Statistics/log_alpha': self.log_alpha,
            'Statistics/alpha': self.alpha,
            'Statistics/entropy': entropy.mean(),
            'Statistics/v_mean': v.mean()
        }
        if self.auto_adaption:
            corr = (self.target_entropy - entropy).detach()  # [T, B, 1]
            # corr = ((logp_all - self.a_dim) * logp_all.exp()).sum(-1).detach()
            alpha_loss = -(self.alpha * corr)  # [T, B, 1]
            alpha_loss = alpha_loss.mean()  # 1
            self.alpha_oplr.optimize(alpha_loss)
            summaries.update({
                'LOSS/alpha_loss': alpha_loss,
                'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
            })
        return (td_error1 + td_error2) / 2, summaries

    def _after_train(self):
        super()._after_train()
        if self.annealing and not self.auto_adaption:
            self.log_alpha.copy_(
                self.alpha_annealing(self._cur_train_step).log())
        self.v_net.sync()
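
The two quantities that drive the SAC-with-V updates above are the soft value target min(Q1, Q2) - alpha * log_pi and, when `auto_adaption` is enabled, the temperature loss that keeps the policy entropy near `target_entropy`. A minimal sketch with assumed tensor shapes (not the library API):

import torch as th

def sac_v_targets(q1_pi, q2_pi, log_pi, log_alpha, target_entropy):
    # q1_pi, q2_pi, log_pi: [B, 1]; log_alpha: scalar tensor with requires_grad
    alpha = log_alpha.exp()
    v_target = (th.minimum(q1_pi, q2_pi) - alpha * log_pi).detach()    # regression target for V
    alpha_loss = -(alpha * (log_pi.detach() + target_entropy)).mean()  # temperature update
    return v_target, alpha_loss
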
Example #18
class BCQ(SarlOffPolicy):
    """
    Benchmarking Batch Deep Reinforcement Learning Algorithms, http://arxiv.org/abs/1910.01708
    Off-Policy Deep Reinforcement Learning without Exploration, http://arxiv.org/abs/1812.02900
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 polyak=0.995,
                 discrete=dict(threshold=0.3,
                               lr=5.0e-4,
                               eps_init=1,
                               eps_mid=0.2,
                               eps_final=0.01,
                               init2mid_annealing_step=1000,
                               assign_interval=1000,
                               network_settings=[32, 32]),
                 continuous=dict(phi=0.05,
                                 lmbda=0.75,
                                 select_samples=100,
                                 train_samples=10,
                                 actor_lr=1e-3,
                                 critic_lr=1e-3,
                                 vae_lr=1e-3,
                                 network_settings=dict(
                                     actor=[32, 32],
                                     critic=[32, 32],
                                     vae=dict(encoder=[750, 750],
                                              decoder=[750, 750]))),
                 **kwargs):
        super().__init__(**kwargs)
        self._polyak = polyak

        if self.is_continuous:
            self._lmbda = continuous['lmbda']
            self._select_samples = continuous['select_samples']
            self._train_samples = continuous['train_samples']
            self.actor = TargetTwin(BCQ_Act_Cts(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                action_dim=self.a_dim,
                phi=continuous['phi'],
                network_settings=continuous['network_settings']['actor']),
                                    polyak=self._polyak).to(self.device)
            self.critic = TargetTwin(BCQ_CriticQvalueOne(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                action_dim=self.a_dim,
                network_settings=continuous['network_settings']['critic']),
                                     polyak=self._polyak).to(self.device)
            self.vae = VAE(self.obs_spec,
                           rep_net_params=self._rep_net_params,
                           a_dim=self.a_dim,
                           z_dim=self.a_dim * 2,
                           hiddens=continuous['network_settings']['vae']).to(
                               self.device)

            self.actor_oplr = OPLR(self.actor, continuous['actor_lr'],
                                   **self._oplr_params)
            self.critic_oplr = OPLR(self.critic, continuous['critic_lr'],
                                    **self._oplr_params)
            self.vae_oplr = OPLR(self.vae, continuous['vae_lr'],
                                 **self._oplr_params)
            self._trainer_modules.update(actor=self.actor,
                                         critic=self.critic,
                                         vae=self.vae,
                                         actor_oplr=self.actor_oplr,
                                         critic_oplr=self.critic_oplr,
                                         vae_oplr=self.vae_oplr)
        else:
            self.expl_expt_mng = ExplorationExploitationClass(
                eps_init=discrete['eps_init'],
                eps_mid=discrete['eps_mid'],
                eps_final=discrete['eps_final'],
                init2mid_annealing_step=discrete['init2mid_annealing_step'],
                max_step=self._max_train_step)
            self.assign_interval = discrete['assign_interval']
            self._threshold = discrete['threshold']
            self.q_net = TargetTwin(BCQ_DCT(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=discrete['network_settings']),
                                    polyak=self._polyak).to(self.device)
            self.oplr = OPLR(self.q_net, discrete['lr'], **self._oplr_params)
            self._trainer_modules.update(model=self.q_net, oplr=self.oplr)

    @iton
    def select_action(self, obs):
        if self.is_continuous:
            _actions = []
            for _ in range(self._select_samples):
                _actions.append(
                    self.actor(obs, self.vae.decode(obs),
                               rnncs=self.rnncs))  # [B, A]
            self.rnncs_ = self.actor.get_rnncs(
            )  # TODO: calculate corrected hidden state
            _actions = th.stack(_actions, dim=0)  # [N, B, A]
            q1s = []
            for i in range(self._select_samples):
                q1s.append(self.critic(obs, _actions[i])[0])
            q1s = th.stack(q1s, dim=0)  # [N, B, 1]
            max_idxs = q1s.argmax(dim=0, keepdim=True)[-1]  # [B, 1]
            actions = _actions[
                max_idxs,
                th.arange(self.n_copies).reshape(self.n_copies, 1),
                th.arange(self.a_dim)]
        else:
            q_values, i_values = self.q_net(obs, rnncs=self.rnncs)  # [B, *]
            q_values = q_values - q_values.min(dim=-1,
                                               keepdim=True)[0]  # [B, *]
            i_values = F.log_softmax(i_values, dim=-1)  # [B, *]
            i_values = i_values.exp()  # [B, *]
            i_values = (i_values / i_values.max(-1, keepdim=True)[0] >
                        self._threshold).float()  # [B, *]

            self.rnncs_ = self.q_net.get_rnncs()

            if self._is_train_mode and self.expl_expt_mng.is_random(
                    self._cur_train_step):
                actions = np.random.randint(0, self.a_dim, self.n_copies)
            else:
                actions = (i_values * q_values).argmax(-1)  # [B,]
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        if self.is_continuous:
            # Variational Auto-Encoder Training
            recon, mean, std = self.vae(BATCH.obs,
                                        BATCH.action,
                                        begin_mask=BATCH.begin_mask)
            recon_loss = F.mse_loss(recon, BATCH.action)

            KL_loss = -0.5 * (1 + th.log(std.pow(2)) - mean.pow(2) -
                              std.pow(2)).mean()
            vae_loss = recon_loss + 0.5 * KL_loss

            self.vae_oplr.optimize(vae_loss)

            target_Qs = []
            for _ in range(self._train_samples):
                # Compute value of perturbed actions sampled from the VAE
                _vae_actions = self.vae.decode(BATCH.obs_,
                                               begin_mask=BATCH.begin_mask)
                _actor_actions = self.actor.t(BATCH.obs_,
                                              _vae_actions,
                                              begin_mask=BATCH.begin_mask)
                target_Q1, target_Q2 = self.critic.t(
                    BATCH.obs_, _actor_actions, begin_mask=BATCH.begin_mask)

                # Soft Clipped Double Q-learning
                target_Q = self._lmbda * th.min(target_Q1, target_Q2) + \
                           (1. - self._lmbda) * th.max(target_Q1, target_Q2)
                target_Qs.append(target_Q)
            target_Qs = th.stack(target_Qs, dim=0)  # [N, T, B, 1]
            # Take max over each BATCH.action sampled from the VAE
            target_Q = target_Qs.max(dim=0)[0]  # [T, B, 1]

            target_Q = n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                     target_Q,
                                     BATCH.begin_mask).detach()  # [T, B, 1]

            current_Q1, current_Q2 = self.critic(BATCH.obs,
                                                 BATCH.action,
                                                 begin_mask=BATCH.begin_mask)
            td_error = ((current_Q1 - target_Q) + (current_Q2 - target_Q)) / 2
            critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
                current_Q2, target_Q)

            self.critic_oplr.optimize(critic_loss)

            # Perturbation Model / Action Training
            sampled_actions = self.vae.decode(BATCH.obs,
                                              begin_mask=BATCH.begin_mask)
            perturbed_actions = self.actor(BATCH.obs,
                                           sampled_actions,
                                           begin_mask=BATCH.begin_mask)

            # Update through DPG
            q1, _ = self.critic(BATCH.obs,
                                perturbed_actions,
                                begin_mask=BATCH.begin_mask)
            actor_loss = -q1.mean()

            self.actor_oplr.optimize(actor_loss)

            return td_error, {
                'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
                'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
                'LEARNING_RATE/vae_lr': self.vae_oplr.lr,
                'LOSS/actor_loss': actor_loss,
                'LOSS/critic_loss': critic_loss,
                'LOSS/vae_loss': vae_loss,
                'Statistics/q_min': q1.min(),
                'Statistics/q_mean': q1.mean(),
                'Statistics/q_max': q1.max()
            }

        else:
            q_next, i_next = self.q_net(
                BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            q_next = q_next - q_next.min(dim=-1, keepdim=True)[0]  # [T, B, A]
            i_next = F.log_softmax(i_next, dim=-1)  # [T, B, A]
            i_next = i_next.exp()  # [T, B, A]
            i_next = (i_next / i_next.max(-1, keepdim=True)[0] >
                      self._threshold).float()  # [T, B, A]
            q_next = i_next * q_next  # [T, B, A]
            next_max_action = q_next.argmax(-1)  # [T, B]
            next_max_action_one_hot = F.one_hot(
                next_max_action.squeeze(), self.a_dim).float()  # [T, B, A]

            q_target_next, _ = self.q_net.t(
                BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            q_target_next_max = (q_target_next * next_max_action_one_hot).sum(
                -1, keepdim=True)  # [T, B, 1]
            q_target = n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                     q_target_next_max,
                                     BATCH.begin_mask).detach()  # [T, B, 1]

            q, i = self.q_net(BATCH.obs,
                              begin_mask=BATCH.begin_mask)  # [T, B, A]
            q_eval = (q * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]

            td_error = q_target - q_eval  # [T, B, 1]
            q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1

            imt = F.log_softmax(i, dim=-1)  # [T, B, A]
            imt = imt.reshape(-1, self.a_dim)  # [T*B, A]
            action = BATCH.action.reshape(-1, self.a_dim)  # [T*B, A]
            i_loss = F.nll_loss(imt, action.argmax(-1))  # 1

            loss = q_loss + i_loss + 1e-2 * i.pow(2).mean()

            self.oplr.optimize(loss)
            return td_error, {
                'LEARNING_RATE/lr': self.oplr.lr,
                'LOSS/q_loss': q_loss,
                'LOSS/i_loss': i_loss,
                'LOSS/loss': loss,
                'Statistics/q_max': q_eval.max(),
                'Statistics/q_min': q_eval.min(),
                'Statistics/q_mean': q_eval.mean()
            }

    def _after_train(self):
        super()._after_train()
        if self.is_continuous:
            self.actor.sync()
            self.critic.sync()
        else:
            if self._polyak != 0:
                self.q_net.sync()
            else:
                if self._cur_train_step % self.assign_interval == 0:
                    self.q_net.sync()
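
In the discrete branch above, BCQ masks out actions whose imitation probability is below `threshold` times that of the most likely action before taking the argmax over Q. A compact sketch of that filter with assumed [B, A] inputs:

import torch as th
import torch.nn.functional as F

def bcq_filtered_argmax(q_values, imitation_logits, threshold=0.3):
    probs = F.softmax(imitation_logits, dim=-1)                          # [B, A]
    mask = (probs / probs.max(-1, keepdim=True)[0] > threshold).float()  # keep plausible actions
    q_values = q_values - q_values.min(-1, keepdim=True)[0]              # shift so Q >= 0
    return (mask * q_values).argmax(-1)                                  # [B]
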
Example #19
class SQL(SarlOffPolicy):
    """
        Soft Q-Learning. ref: https://github.com/Bigpig4396/PyTorch-Soft-Q-Learning/blob/master/SoftQ.py
        NOTE: this is not the original algorithm from the paper; SVGD is not used.
        Reinforcement Learning with Deep Energy-Based Policies: https://arxiv.org/abs/1702.08165
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 lr=5.0e-4,
                 alpha=2,
                 polyak=0.995,
                 network_settings=[32, 32],
                 **kwargs):
        super().__init__(**kwargs)
        assert not self.is_continuous, 'SQL only supports discrete action spaces'
        self.alpha = alpha
        self.polyak = polyak

        self.q_net = TargetTwin(CriticQvalueAll(self.obs_spec,
                                                rep_net_params=self._rep_net_params,
                                                output_shape=self.a_dim,
                                                network_settings=network_settings),
                                self.polyak).to(self.device)

        self.oplr = OPLR(self.q_net, lr, **self._oplr_params)
        self._trainer_modules.update(model=self.q_net,
                                     oplr=self.oplr)

    @iton
    def select_action(self, obs):
        q_values = self.q_net(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.q_net.get_rnncs()
        probs = ((q_values - self._get_v(q_values)) / self.alpha).exp()  # > 0   # [B, A]
        probs /= probs.sum(-1, keepdim=True)  # [B, A]
        cate_dist = td.Categorical(probs=probs)
        actions = cate_dist.sample()  # [B,]
        return actions, Data(action=actions)

    def _get_v(self, q):
        v = self.alpha * (q / self.alpha).exp().mean(-1, keepdim=True).log()  # [B, 1] or [T, B, 1]
        return v

    @iton
    def _train(self, BATCH):
        q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q_next = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        v_next = self._get_v(q_next)  # [T, B, 1]
        q_eval = (q * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        q_target = n_step_return(BATCH.reward,
                                 self.gamma,
                                 BATCH.done,
                                 v_next,
                                 BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1]

        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
        self.oplr.optimize(q_loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': q_loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }

    def _after_train(self):
        super()._after_train()
        self.q_net.sync()
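
The soft value and Boltzmann policy that `select_action` and `_train` rely on amount to a log-mean-exp and a softmax over Q / alpha; shapes here are illustrative assumptions:

import torch as th

def soft_value(q, alpha=2.0):
    # V(s) = alpha * log E_a[exp(Q(s, a) / alpha)], q: [B, A] -> [B, 1]
    return alpha * (q / alpha).exp().mean(-1, keepdim=True).log()

def boltzmann_probs(q, alpha=2.0):
    # pi(a|s) proportional to exp(Q(s, a) / alpha), q: [B, A] -> [B, A]
    return th.softmax(q / alpha, dim=-1)
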
Example #20
class QRDQN(SarlOffPolicy):
    """
    Quantile Regression DQN
    Distributional Reinforcement Learning with Quantile Regression, https://arxiv.org/abs/1710.10044
    No double, no dueling, no noisy net.
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 nums=20,
                 huber_delta=1.,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=1000,
                 network_settings=[128, 128],
                 **kwargs):
        assert nums > 0, 'assert nums > 0'
        super().__init__(**kwargs)
        assert not self.is_continuous, 'QRDQN only supports discrete action spaces'
        self.nums = nums
        self.huber_delta = huber_delta
        self.quantiles = th.tensor((2 * np.arange(self.nums) + 1) / (2.0 * self.nums)).float().to(self.device)  # [N,]
        self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init,
                                                          eps_mid=eps_mid,
                                                          eps_final=eps_final,
                                                          init2mid_annealing_step=init2mid_annealing_step,
                                                          max_step=self._max_train_step)
        self.assign_interval = assign_interval
        self.q_net = TargetTwin(QrdqnDistributional(self.obs_spec,
                                                    rep_net_params=self._rep_net_params,
                                                    action_dim=self.a_dim,
                                                    nums=self.nums,
                                                    network_settings=network_settings)).to(self.device)
        self.oplr = OPLR(self.q_net, lr, **self._oplr_params)
        self._trainer_modules.update(model=self.q_net,
                                     oplr=self.oplr)

    @iton
    def select_action(self, obs):
        q_values = self.q_net(obs, rnncs=self.rnncs)  # [B, A, N]
        self.rnncs_ = self.q_net.get_rnncs()

        if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            q = q_values.mean(-1)  # [B, A, N] => [B, A]
            actions = q.argmax(-1)  # [B,]
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        q_dist = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A, N]
        q_dist = (q_dist * BATCH.action.unsqueeze(-1)).sum(-2)  # [T, B, A, N] => [T, B, N]

        target_q_dist = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A, N]
        target_q = target_q_dist.mean(-1)  # [T, B, A, N] => [T, B, A]
        _a = target_q.argmax(-1)  # [T, B]
        next_max_action = F.one_hot(_a, self.a_dim).float().unsqueeze(-1)  # [T, B, A, 1]
        # [T, B, A, N] => [T, B, N]
        target_q_dist = (target_q_dist * next_max_action).sum(-2)

        target = n_step_return(BATCH.reward.repeat(1, 1, self.nums),
                               self.gamma,
                               BATCH.done.repeat(1, 1, self.nums),
                               target_q_dist,
                               BATCH.begin_mask.repeat(1, 1, self.nums)).detach()  # [T, B, N]

        q_eval = q_dist.mean(-1, keepdim=True)  # [T, B, 1]
        q_target = target.mean(-1, keepdim=True)  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1], used for PER

        target = target.unsqueeze(-2)  # [T, B, 1, N]
        q_dist = q_dist.unsqueeze(-1)  # [T, B, N, 1]

        # [T, B, 1, N] - [T, B, N, 1] => [T, B, N, N]
        quantile_error = target - q_dist
        huber = F.huber_loss(target, q_dist, reduction="none", delta=self.huber_delta)  # [T, B, N, N]
        # [N, 1] - [T, B, N, N] => [T, B, N, N]; tau must align with the predicted-quantile dim (-2)
        huber_abs = (self.quantiles.view(-1, 1) - quantile_error.detach().le(0.).float()).abs()
        loss = (huber_abs * huber).mean(-1)  # [T, B, N, N] => [T, B, N]
        loss = loss.sum(-1, keepdim=True)  # [T, B, N] => [T, B, 1]
        loss = (loss * BATCH.get('isw', 1.0)).mean()  # 1

        self.oplr.optimize(loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            self.q_net.sync()
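
The quantile-Huber loss computed in `_train` pairs every predicted quantile with every target sample and weights the Huber term by |tau - 1{error < 0}|, where tau belongs to the predicted quantile. A standalone sketch under assumed [B, N] shapes:

import torch as th

def quantile_huber_loss(pred, target, taus, delta=1.0):
    # pred, target: [B, N]; taus: [N] quantile midpoints (2i + 1) / (2N)
    u = target.unsqueeze(-2) - pred.unsqueeze(-1)          # [B, N_pred, N_target]
    abs_u = u.abs()
    huber = th.where(abs_u <= delta,
                     0.5 * u.pow(2),
                     delta * (abs_u - 0.5 * delta))        # element-wise Huber
    weight = (taus.view(1, -1, 1) - (u.detach() < 0).float()).abs()  # tau on the pred dim
    return (weight * huber).mean(-1).sum(-1).mean()        # mean over targets, sum over quantiles
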
Example #21
class TD3(SarlOffPolicy):
    """
    Twin Delayed Deep Deterministic Policy Gradient, https://arxiv.org/abs/1802.09477
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 polyak=0.995,
                 delay_num=2,
                 noise_action='clip_normal',
                 noise_params={
                     'sigma': 0.2,
                     'noise_bound': 0.2
                 },
                 actor_lr=5.0e-4,
                 critic_lr=1.0e-3,
                 discrete_tau=1.0,
                 network_settings={
                     'actor_continuous': [32, 32],
                     'actor_discrete': [32, 32],
                     'q': [32, 32]
                 },
                 **kwargs):
        super().__init__(**kwargs)
        self.polyak = polyak
        self.delay_num = delay_num
        self.discrete_tau = discrete_tau

        if self.is_continuous:
            actor = ActorDPG(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_continuous'])
            self.noised_action = self.target_noised_action = Noise_action_REGISTER[
                noise_action](**noise_params)
        else:
            actor = ActorDct(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_discrete'])
        self.actor = TargetTwin(actor, self.polyak).to(self.device)

        self.critic = TargetTwin(
            CriticQvalueOne(self.obs_spec,
                            rep_net_params=self._rep_net_params,
                            action_dim=self.a_dim,
                            network_settings=network_settings['q']),
            self.polyak).to(self.device)
        self.critic2 = deepcopy(self.critic)

        self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
        self.critic_oplr = OPLR([self.critic, self.critic2], critic_lr,
                                **self._oplr_params)
        self._trainer_modules.update(actor=self.actor,
                                     critic=self.critic,
                                     critic2=self.critic2,
                                     actor_oplr=self.actor_oplr,
                                     critic_oplr=self.critic_oplr)

    def episode_reset(self):
        super().episode_reset()
        if self.is_continuous:
            self.noised_action.reset()

    @iton
    def select_action(self, obs):
        output = self.actor(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.actor.get_rnncs()
        if self.is_continuous:
            mu = output  # [B, A]
            pi = self.noised_action(mu)  # [B, A]
        else:
            logits = output  # [B, A]
            mu = logits.argmax(-1)  # [B,]
            cate_dist = td.Categorical(logits=logits)
            pi = cate_dist.sample()  # [B,]
        actions = pi if self._is_train_mode else mu
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        for _ in range(self.delay_num):
            if self.is_continuous:
                action_target = self.target_noised_action(
                    self.actor.t(BATCH.obs_,
                                 begin_mask=BATCH.begin_mask))  # [T, B, A]
            else:
                target_logits = self.actor.t(
                    BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
                target_cate_dist = td.Categorical(logits=target_logits)
                target_pi = target_cate_dist.sample()  # [T, B]
                action_target = F.one_hot(target_pi,
                                          self.a_dim).float()  # [T, B, A]
            q1 = self.critic(BATCH.obs,
                             BATCH.action,
                             begin_mask=BATCH.begin_mask)  # [T, B, 1]
            q2 = self.critic2(BATCH.obs,
                              BATCH.action,
                              begin_mask=BATCH.begin_mask)  # [T, B, 1]
            q_target = th.minimum(
                self.critic.t(BATCH.obs_,
                              action_target,
                              begin_mask=BATCH.begin_mask),
                self.critic2.t(BATCH.obs_,
                               action_target,
                               begin_mask=BATCH.begin_mask))  # [T, B, 1]
            dc_r = n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                 q_target,
                                 BATCH.begin_mask).detach()  # [T, B, 1]
            td_error1 = q1 - dc_r  # [T, B, 1]
            td_error2 = q2 - dc_r  # [T, B, 1]

            q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
            q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
            critic_loss = 0.5 * (q1_loss + q2_loss)
            self.critic_oplr.optimize(critic_loss)

        if self.is_continuous:
            mu = self.actor(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A]
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                -1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            mu = _pi_diff + _pi  # [T, B, A]
        q1_actor = self.critic(BATCH.obs, mu,
                               begin_mask=BATCH.begin_mask)  # [T, B, 1]

        actor_loss = -q1_actor.mean()  # 1
        self.actor_oplr.optimize(actor_loss)
        return (td_error1 + td_error2) / 2, {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/q_min': th.minimum(q1, q2).min(),
            'Statistics/q_mean': th.minimum(q1, q2).mean(),
            'Statistics/q_max': th.maximum(q1, q2).max()
        }

    def _after_train(self):
        super()._after_train()
        self.actor.sync()
        self.critic.sync()
        self.critic2.sync()
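
Two details distinguish the `_train` above from plain DDPG: clipped Gaussian noise on the target action (target policy smoothing) and the minimum over two target critics (clipped double Q-learning). A minimal sketch with assumed callables and shapes:

import torch as th

def td3_target(reward, done, obs_next, target_actor, target_q1, target_q2,
               gamma=0.99, sigma=0.2, noise_clip=0.5):
    # reward, done: [B, 1]; target_actor(obs) -> [B, A]; target_q*(obs, a) -> [B, 1]
    with th.no_grad():
        a = target_actor(obs_next)                                         # [B, A]
        noise = (th.randn_like(a) * sigma).clamp(-noise_clip, noise_clip)  # smoothing noise
        a_next = (a + noise).clamp(-1.0, 1.0)                              # smoothed target action
        q_next = th.minimum(target_q1(obs_next, a_next),
                            target_q2(obs_next, a_next))                   # clipped double Q
        return reward + gamma * (1.0 - done) * q_next                      # [B, 1]
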
Example #22
class PlaNet(SarlOffPolicy):
    """
    Learning Latent Dynamics for Planning from Pixels, http://arxiv.org/abs/1811.04551
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 stoch_dim=30,
                 deter_dim=200,
                 model_lr=6e-4,
                 kl_free_nats=3,
                 kl_scale=1.0,
                 reward_scale=1.0,
                 cem_horizon=12,
                 cem_iter_nums=10,
                 cem_candidates=1000,
                 cem_tops=100,
                 action_sigma=0.3,
                 network_settings=dict(),
                 **kwargs):
        super().__init__(**kwargs)

        assert self.is_continuous, 'PlaNet only supports continuous action spaces'

        self.cem_horizon = cem_horizon
        self.cem_iter_nums = cem_iter_nums
        self.cem_candidates = cem_candidates
        self.cem_tops = cem_tops

        assert not self.use_rnn, 'PlaNet does not support use_rnn=True'

        if self.obs_spec.has_visual_observation \
                and len(self.obs_spec.visual_dims) == 1 \
                and not self.obs_spec.has_vector_observation:
            visual_dim = self.obs_spec.visual_dims[0]
            # TODO: optimize this
            assert visual_dim[0] == visual_dim[1] == 64, 'visual dimension must be [64, 64, *]'
            self._is_visual = True
        elif self.obs_spec.has_vector_observation \
                and len(self.obs_spec.vector_dims) == 1 \
                and not self.obs_spec.has_visual_observation:
            self._is_visual = False
        else:
            raise ValueError("please check the observation type")

        self.stoch_dim = stoch_dim
        self.deter_dim = deter_dim
        self.kl_free_nats = kl_free_nats
        self.kl_scale = kl_scale
        self.reward_scale = reward_scale
        self._action_sigma = action_sigma
        self._network_settings = network_settings

        if self.obs_spec.has_visual_observation:
            from rls.nn.dreamer import VisualDecoder, VisualEncoder
            self.obs_encoder = VisualEncoder(self.obs_spec.visual_dims[0],
                                             **network_settings['obs_encoder']['visual']).to(self.device)
            self.obs_decoder = VisualDecoder(self.decoder_input_dim,
                                             self.obs_spec.visual_dims[0],
                                             **network_settings['obs_decoder']['visual']).to(self.device)
        else:
            from rls.nn.dreamer import VectorEncoder
            self.obs_encoder = VectorEncoder(self.obs_spec.vector_dims[0],
                                             **network_settings['obs_encoder']['vector']).to(self.device)
            self.obs_decoder = DenseModel(self.decoder_input_dim,
                                          self.obs_spec.vector_dims[0],
                                          **network_settings['obs_decoder']['vector']).to(self.device)

        self.rssm = self._dreamer_build_rssm()

        """
        p(r_t | s_t, h_t)
        Reward model to predict reward from state and rnn hidden state
        """
        self.reward_predictor = DenseModel(self.decoder_input_dim,
                                           1,
                                           **network_settings['reward']).to(self.device)

        self.model_oplr = OPLR([self.obs_encoder, self.rssm, self.obs_decoder, self.reward_predictor],
                               model_lr, **self._oplr_params)
        self._trainer_modules.update(obs_encoder=self.obs_encoder,
                                     obs_decoder=self.obs_decoder,
                                     reward_predictor=self.reward_predictor,
                                     rssm=self.rssm,
                                     model_oplr=self.model_oplr)

    @property
    def decoder_input_dim(self):
        return self.stoch_dim + self.deter_dim

    def _dreamer_build_rssm(self):
        return RecurrentStateSpaceModel(self.stoch_dim,
                                        self.deter_dim,
                                        self.a_dim,
                                        self.obs_encoder.h_dim,
                                        **self._network_settings['rssm']).to(self.device)

    @iton
    def select_action(self, obs):
        if self._is_visual:
            obs = get_first_visual(obs)
        else:
            obs = get_first_vector(obs)
        # Compute starting state for planning
        # while taking information from current observation (posterior)
        embedded_obs = self.obs_encoder(obs)  # [B, *]
        state_posterior = self.rssm.posterior(self.rnncs['hx'], embedded_obs)  # dist # [B, *]

        # Initialize action distribution
        mean = th.zeros((self.cem_horizon, 1, self.n_copies, self.a_dim))  # [H, 1, B, A]
        stddev = th.ones((self.cem_horizon, 1, self.n_copies, self.a_dim))  # [H, 1, B, A]

        # Iteratively improve action distribution with CEM
        for itr in range(self.cem_iter_nums):
            action_candidates = mean + stddev * \
                                th.randn(self.cem_horizon, self.cem_candidates, self.n_copies,
                                         self.a_dim)  # [H, N, B, A]
            action_candidates = action_candidates.reshape(self.cem_horizon, -1, self.a_dim)  # [H, N*B, A]

            # Initialize reward, state, and rnn hidden state
            # These are for parallel exploration
            total_predicted_reward = th.zeros((self.cem_candidates * self.n_copies, 1))  # [N*B, 1]

            state = state_posterior.sample((self.cem_candidates,))  # [N, B, *]
            state = state.view(-1, state.shape[-1])  # [N*B, *]
            rnn_hidden = self.rnncs['hx'].repeat((self.cem_candidates, 1))  # [B, *] => [N*B, *]

            # Compute total predicted reward by open-loop prediction using the prior
            for t in range(self.cem_horizon):
                next_state_prior, rnn_hidden = self.rssm.prior(state, th.tanh(action_candidates[t]), rnn_hidden)
                state = next_state_prior.sample()  # [N*B, *]
                post_feat = th.cat([state, rnn_hidden], -1)  # [N*B, *]
                total_predicted_reward += self.reward_predictor(post_feat).mean  # [N*B, 1]

            # update action distribution using top-k samples
            total_predicted_reward = total_predicted_reward.view(self.cem_candidates, self.n_copies, 1)  # [N, B, 1]
            _, top_indexes = total_predicted_reward.topk(self.cem_tops, dim=0, largest=True, sorted=False)  # [N', B, 1]
            action_candidates = action_candidates.view(self.cem_horizon, self.cem_candidates, self.n_copies,
                                                       -1)  # [H, N, B, A]
            top_action_candidates = action_candidates[:, top_indexes,
                                    th.arange(self.n_copies).reshape(self.n_copies, 1),
                                    th.arange(self.a_dim)]  # [H, N', B, A]
            mean = top_action_candidates.mean(dim=1, keepdim=True)  # [H, 1, B, A]
            stddev = top_action_candidates.std(dim=1, unbiased=False, keepdim=True)  # [H, 1, B, A]

        # Return only first action (replan each state based on new observation)
        actions = th.tanh(mean[0].squeeze(0))  # [B, A]
        actions = self._exploration(actions)
        _, self.rnncs_['hx'] = self.rssm.prior(state_posterior.sample(),
                                               actions,
                                               self.rnncs['hx'])
        return actions, Data(action=actions)

    def _exploration(self, action: th.Tensor) -> th.Tensor:
        """
        :param action: action to take, shape (1,) (if categorical), or (action dim,) (if continuous)
        :return: action of the same shape passed in, augmented with some noise
        """
        sigma = self._action_sigma if self._is_train_mode else 0.
        noise = th.randn(*action.shape) * sigma
        return th.clamp(action + noise, -1, 1)

    @iton
    def _train(self, BATCH):
        T, B = BATCH.action.shape[:2]
        if self._is_visual:
            obs_ = get_first_visual(BATCH.obs_)
        else:
            obs_ = get_first_vector(BATCH.obs_)

        # embed observations with CNN
        embedded_observations = self.obs_encoder(obs_)  # [T, B, *]

        # initialize state and rnn hidden state with 0 vector
        state, rnn_hidden = self.rssm.init_state(shape=B)  # [B, S], [B, D]

        # compute state and rnn hidden sequences and kl loss
        kl_loss = 0
        states, rnn_hiddens = [], []
        for l in range(T):
            # If this is the beginning of an episode, reset the state to 0,
            # no matter whether the previous episode was truncated or not.
            state = state * (1. - BATCH.begin_mask[l])  # [B, S]
            rnn_hidden = rnn_hidden * (1. - BATCH.begin_mask[l])  # [B, D]

            next_state_prior, next_state_posterior, rnn_hidden = self.rssm(state,
                                                                           BATCH.action[l],
                                                                           rnn_hidden,
                                                                           embedded_observations[l])  # a, s_
            state = next_state_posterior.rsample()  # [B, S] posterior of s_
            states.append(state)  # [B, S]
            rnn_hiddens.append(rnn_hidden)  # [B, D]
            kl_loss += self._kl_loss(next_state_prior, next_state_posterior)
        kl_loss /= T  # 1

        # compute reconstructed observations and predicted rewards
        post_feat = th.cat([th.stack(states, 0), th.stack(rnn_hiddens, 0)], -1)  # [T, B, *]

        obs_pred = self.obs_decoder(post_feat)  # [T, B, C, H, W] or [T, B, *]
        reward_pred = self.reward_predictor(post_feat)  # [T, B, 1], s_ => r

        # compute loss for observation and reward
        obs_loss = -th.mean(obs_pred.log_prob(obs_))  # [T, B] => 1
        # [T, B, 1]=>1
        reward_loss = -th.mean(reward_pred.log_prob(BATCH.reward).unsqueeze(-1))

        # add all losses and update model parameters with gradient descent
        model_loss = self.kl_scale * kl_loss + obs_loss + self.reward_scale * reward_loss  # 1

        self.model_oplr.optimize(model_loss)

        summaries = {
            'LEARNING_RATE/model_lr': self.model_oplr.lr,
            'LOSS/model_loss': model_loss,
            'LOSS/kl_loss': kl_loss,
            'LOSS/obs_loss': obs_loss,
            'LOSS/reward_loss': reward_loss
        }

        return th.ones_like(BATCH.reward), summaries

    def _initial_rnncs(self, batch: int) -> Dict[str, np.ndarray]:
        return {'hx': np.zeros((batch, self.deter_dim))}

    def _kl_loss(self, prior_dist, post_dist):
        # 1
        return td.kl_divergence(prior_dist, post_dist).clamp(min=self.kl_free_nats).mean()
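As a standalone aside: the planning loop in select_action above is the cross-entropy method (CEM). The toy sketch below reproduces the same sample / score / top-k / refit cycle on a made-up reward function; cem_plan, reward_fn and all hyper-parameters here are illustrative placeholders, not part of this library.

import torch as th

def cem_plan(reward_fn, a_dim, horizon=12, candidates=1000, tops=100, iters=10):
    # Start from a broad Gaussian over action sequences.
    mean = th.zeros(horizon, a_dim)
    stddev = th.ones(horizon, a_dim)
    for _ in range(iters):
        # Sample candidate action sequences.  [N, H, A]
        samples = mean + stddev * th.randn(candidates, horizon, a_dim)
        # Score each sequence with the (toy) reward model.  [N]
        returns = reward_fn(th.tanh(samples))
        # Keep the top-k sequences and refit the Gaussian to them.
        elite = samples[returns.topk(tops).indices]  # [K, H, A]
        mean = elite.mean(0)
        stddev = elite.std(0, unbiased=False)
    # Only the first action is executed; planning is repeated every step.
    return th.tanh(mean[0])

# toy usage: reward is highest when every action is close to +0.5
toy_reward = lambda seq: -((seq - 0.5) ** 2).sum((-1, -2))
print(cem_plan(toy_reward, a_dim=2))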
Example #23
class DQN(SarlOffPolicy):
    """
    Deep Q-learning Network, DQN, [2013](https://arxiv.org/pdf/1312.5602.pdf), [2015](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
    DQN + LSTM, https://arxiv.org/abs/1507.06527
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 lr: float = 5.0e-4,
                 eps_init: float = 1,
                 eps_mid: float = 0.2,
                 eps_final: float = 0.01,
                 init2mid_annealing_step: int = 1000,
                 assign_interval: int = 1000,
                 network_settings: List[int] = [32, 32],
                 **kwargs):
        super().__init__(**kwargs)
        assert not self.is_continuous, 'dqn only support discrete action space'
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self._max_train_step)
        self.assign_interval = assign_interval
        self.q_net = TargetTwin(
            CriticQvalueAll(self.obs_spec,
                            rep_net_params=self._rep_net_params,
                            output_shape=self.a_dim,
                            network_settings=network_settings)).to(self.device)
        self.oplr = OPLR(self.q_net, lr, **self._oplr_params)
        self._trainer_modules.update(model=self.q_net)
        self._trainer_modules.update(oplr=self.oplr)

    @iton
    def select_action(self, obs):
        q_values = self.q_net(obs, rnncs=self.rnncs)  # [B, *]
        self.rnncs_ = self.q_net.get_rnncs()

        if self._is_train_mode and self.expl_expt_mng.is_random(
                self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            actions = q_values.argmax(-1)  # [B,]
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q_next = self.q_net.t(BATCH.obs_,
                              begin_mask=BATCH.begin_mask)  # [T, B, A]
        q_eval = (q * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        q_target = n_step_return(
            BATCH.reward,
            self.gamma,
            BATCH.done,
            q_next.max(-1, keepdim=True)[0],
            BATCH.begin_mask,
            nstep=self._n_step_value).detach()  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1]
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
        self.oplr.optimize(q_loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': q_loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            self.q_net.sync()
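n_step_return above is a library helper; assuming the plain 1-step case (n=1, no begin_mask handling), the target it produces reduces to the familiar y = r + gamma * (1 - done) * max_a' Q_target(s', a'). A minimal sketch:

import torch as th

def one_step_dqn_target(reward, done, q_next, gamma=0.99):
    # reward, done: [T, B, 1]; q_next: [T, B, A] from the target network
    max_q_next = q_next.max(-1, keepdim=True).values   # [T, B, 1]
    return reward + gamma * (1.0 - done) * max_q_next  # [T, B, 1]

# toy check with T=1, B=2, A=3
r = th.tensor([[[1.0], [0.0]]])
d = th.tensor([[[0.0], [1.0]]])
qn = th.tensor([[[0.1, 0.5, 0.2], [1.0, 2.0, 3.0]]])
print(one_step_dqn_target(r, d, qn))  # -> 1 + 0.99 * 0.5 = 1.495, and 0.0 (episode ended)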
Example #24
class A2C(SarlOnPolicy):
    """
    Synchronous Advantage Actor-Critic, A2C, http://arxiv.org/abs/1602.01783
    """
    policy_mode = 'on-policy'

    def __init__(
            self,
            agent_spec,
            beta=1.0e-3,
            actor_lr=5.0e-4,
            critic_lr=1.0e-3,
            network_settings={
                'actor_continuous': {
                    'hidden_units': [64, 64],
                    'condition_sigma': False,
                    'log_std_bound': [-20, 2]
                },
                'actor_discrete': [32, 32],
                'critic': [32, 32]
            },
            **kwargs):
        super().__init__(agent_spec=agent_spec, **kwargs)
        self.beta = beta

        if self.is_continuous:
            self.actor = ActorMuLogstd(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_continuous']).to(
                    self.device)
        else:
            self.actor = ActorDct(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_discrete']).to(
                    self.device)
        self.critic = CriticValue(
            self.obs_spec,
            rep_net_params=self._rep_net_params,
            network_settings=network_settings['critic']).to(self.device)

        self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
        self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params)

        self._trainer_modules.update(actor=self.actor,
                                     critic=self.critic,
                                     actor_oplr=self.actor_oplr,
                                     critic_oplr=self.critic_oplr)

    @iton
    def select_action(self, obs):
        output = self.actor(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.actor.get_rnncs()
        if self.is_continuous:
            mu, log_std = output  # [B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            action = dist.sample().clamp(-1, 1)  # [B, A]
        else:
            logits = output  # [B, A]
            norm_dist = td.Categorical(logits=logits)
            action = norm_dist.sample()  # [B,]

        acts_info = Data(action=action)
        if self.use_rnn:
            acts_info.update(rnncs=self.rnncs)
        return action, acts_info

    @iton
    def _get_value(self, obs, rnncs=None):
        value = self.critic(obs, rnncs=self.rnncs)
        return value

    def _preprocess_BATCH(self, BATCH):  # [T, B, *]
        BATCH = super()._preprocess_BATCH(BATCH)
        value = self._get_value(BATCH.obs_[-1], rnncs=self.rnncs)
        BATCH.discounted_reward = discounted_sum(BATCH.reward,
                                                 self.gamma,
                                                 BATCH.done,
                                                 BATCH.begin_mask,
                                                 init_value=value)
        td_error = calculate_td_error(
            BATCH.reward,
            self.gamma,
            BATCH.done,
            value=BATCH.value,
            next_value=np.concatenate((BATCH.value[1:], value[np.newaxis, :]),
                                      0))
        BATCH.gae_adv = discounted_sum(td_error,
                                       self.lambda_ * self.gamma,
                                       BATCH.done,
                                       BATCH.begin_mask,
                                       init_value=0.,
                                       normalize=True)
        return BATCH

    @iton
    def _train(self, BATCH):
        v = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        td_error = BATCH.discounted_reward - v  # [T, B, 1]
        critic_loss = td_error.square().mean()  # 1
        self.critic_oplr.optimize(critic_loss)

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs,
                                     begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            log_act_prob = dist.log_prob(BATCH.action).unsqueeze(
                -1)  # [T, B, 1]
            entropy = dist.entropy().unsqueeze(-1)  # [T, B, 1]
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            log_act_prob = (BATCH.action * logp_all).sum(
                -1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(
                -1, keepdim=True)  # [T, B, 1]
        # advantage = BATCH.discounted_reward - v.detach()  # [T, B, 1]
        actor_loss = -(log_act_prob * BATCH.gae_adv +
                       self.beta * entropy).mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        return {
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/entropy': entropy.mean(),
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr
        }
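discounted_sum and calculate_td_error above are library helpers; for intuition, here is a minimal, hedged sketch of generalized advantage estimation (GAE) over a single trajectory, ignoring begin_mask handling. The function name gae and its arguments are illustrative only.

import numpy as np

def gae(rewards, values, next_values, dones, gamma=0.99, lambda_=0.95):
    # rewards, values, next_values, dones: arrays of shape [T]
    deltas = rewards + gamma * (1.0 - dones) * next_values - values  # per-step TD errors
    adv = np.zeros_like(deltas)
    running = 0.0
    for t in reversed(range(len(deltas))):                 # backward recursion
        running = deltas[t] + gamma * lambda_ * (1.0 - dones[t]) * running
        adv[t] = running
    return (adv - adv.mean()) / (adv.std() + 1e-8)         # normalized, as in A2C above

print(gae(np.ones(4), np.zeros(4), np.zeros(4), np.array([0., 0., 0., 1.])))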
Example #25
class TAC(SarlOffPolicy):
    """Tsallis Actor Critic, TAC with V neural Network. https://arxiv.org/abs/1902.00137
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 alpha=0.2,
                 annealing=True,
                 last_alpha=0.01,
                 polyak=0.995,
                 entropic_index=1.5,
                 discrete_tau=1.0,
                 network_settings={
                     'actor_continuous': {
                         'share': [128, 128],
                         'mu': [64],
                         'log_std': [64],
                         'soft_clip': False,
                         'log_std_bound': [-20, 2]
                     },
                     'actor_discrete': [64, 32],
                     'q': [128, 128]
                 },
                 auto_adaption=True,
                 actor_lr=5.0e-4,
                 critic_lr=1.0e-3,
                 alpha_lr=5.0e-4,
                 **kwargs):
        super().__init__(**kwargs)
        self.polyak = polyak
        self.discrete_tau = discrete_tau
        self.entropic_index = 2 - entropic_index
        self.auto_adaption = auto_adaption
        self.annealing = annealing

        self.critic = TargetTwin(CriticQvalueOne(self.obs_spec,
                                                 rep_net_params=self._rep_net_params,
                                                 action_dim=self.a_dim,
                                                 network_settings=network_settings['q']),
                                 self.polyak).to(self.device)
        self.critic2 = deepcopy(self.critic)

        if self.is_continuous:
            self.actor = ActorCts(self.obs_spec,
                                  rep_net_params=self._rep_net_params,
                                  output_shape=self.a_dim,
                                  network_settings=network_settings['actor_continuous']).to(self.device)
        else:
            self.actor = ActorDct(self.obs_spec,
                                  rep_net_params=self._rep_net_params,
                                  output_shape=self.a_dim,
                                  network_settings=network_settings['actor_discrete']).to(self.device)

        # entropy = -log(1/|A|) = log |A|
        self.target_entropy = 0.98 * (-self.a_dim if self.is_continuous else np.log(self.a_dim))

        self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
        self.critic_oplr = OPLR([self.critic, self.critic2], critic_lr, **self._oplr_params)

        if self.auto_adaption:
            self.log_alpha = th.tensor(0., requires_grad=True).to(self.device)
            self.alpha_oplr = OPLR(self.log_alpha, alpha_lr, **self._oplr_params)
            self._trainer_modules.update(alpha_oplr=self.alpha_oplr)
        else:
            self.log_alpha = th.tensor(alpha).log().to(self.device)
            if self.annealing:
                self.alpha_annealing = LinearAnnealing(alpha, last_alpha, int(1e6))

        self._trainer_modules.update(actor=self.actor,
                                     critic=self.critic,
                                     critic2=self.critic2,
                                     log_alpha=self.log_alpha,
                                     actor_oplr=self.actor_oplr,
                                     critic_oplr=self.critic_oplr)

    @property
    def alpha(self):
        return self.log_alpha.exp()

    @iton
    def select_action(self, obs):
        if self.is_continuous:
            mu, log_std = self.actor(obs, rnncs=self.rnncs)  # [B, A]
            pi = td.Normal(mu, log_std.exp()).sample().tanh()  # [B, A]
            mu.tanh_()  # squash mu     # [B, A]
        else:
            logits = self.actor(obs, rnncs=self.rnncs)  # [B, A]
            mu = logits.argmax(-1)  # [B,]
            cate_dist = td.Categorical(logits=logits)
            pi = cate_dist.sample()  # [B,]
        self.rnncs_ = self.actor.get_rnncs()
        actions = pi if self._is_train_mode else mu
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        if self.is_continuous:
            target_mu, target_log_std = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(target_mu, target_log_std.exp()), 1)
            target_pi = dist.sample()  # [T, B, A]
            target_pi, target_log_pi = squash_action(target_pi, dist.log_prob(
                target_pi).unsqueeze(-1), is_independent=False)  # [T, B, A]
            target_log_pi = tsallis_entropy_log_q(target_log_pi, self.entropic_index)  # [T, B, 1]
        else:
            target_logits = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            target_cate_dist = td.Categorical(logits=target_logits)
            target_pi = target_cate_dist.sample()  # [T, B]
            target_log_pi = target_cate_dist.log_prob(target_pi).unsqueeze(-1)  # [T, B, 1]
            target_pi = F.one_hot(target_pi, self.a_dim).float()  # [T, B, A]
        q1 = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2 = self.critic2(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]

        q1_target = self.critic.t(BATCH.obs_, target_pi, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2_target = self.critic2.t(BATCH.obs_, target_pi, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q_target = th.minimum(q1_target, q2_target)  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward,
                             self.gamma,
                             BATCH.done,
                             (q_target - self.alpha * target_log_pi),
                             BATCH.begin_mask).detach()  # [T, B, 1]
        td_error1 = q1 - dc_r  # [T, B, 1]
        td_error2 = q2 - dc_r  # [T, B, 1]

        q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
        q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
        critic_loss = 0.5 * q1_loss + 0.5 * q2_loss
        self.critic_oplr.optimize(critic_loss)

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            pi = dist.rsample()  # [T, B, A]
            pi, log_pi = squash_action(pi, dist.log_prob(pi).unsqueeze(-1), is_independent=False)  # [T, B, A]
            log_pi = tsallis_entropy_log_q(log_pi, self.entropic_index)  # [T, B, 1]
            entropy = dist.entropy().mean()  # 1
        else:
            logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            pi = _pi_diff + _pi  # [T, B, A]
            log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
        q_s_pi = th.minimum(self.critic(BATCH.obs, pi, begin_mask=BATCH.begin_mask),
                            self.critic2(BATCH.obs, pi, begin_mask=BATCH.begin_mask))  # [T, B, 1]
        actor_loss = -(q_s_pi - self.alpha * log_pi).mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        summaries = {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/q1_loss': q1_loss,
            'LOSS/q2_loss': q2_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/log_alpha': self.log_alpha,
            'Statistics/alpha': self.alpha,
            'Statistics/entropy': entropy,
            'Statistics/q_min': th.minimum(q1, q2).min(),
            'Statistics/q_mean': th.minimum(q1, q2).mean(),
            'Statistics/q_max': th.maximum(q1, q2).max()
        }
        if self.auto_adaption:
            alpha_loss = -(self.alpha * (log_pi + self.target_entropy).detach()).mean()  # 1
            self.alpha_oplr.optimize(alpha_loss)
            summaries.update({
                'LOSS/alpha_loss': alpha_loss,
                'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
            })
        return (td_error1 + td_error2) / 2, summaries

    def _after_train(self):
        super()._after_train()
        self.critic.sync()
        self.critic2.sync()

        if self.annealing and not self.auto_adaption:
            self.log_alpha.copy_(self.alpha_annealing(self._cur_train_step).log())
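tsallis_entropy_log_q above is a library helper; assuming it implements the standard q-logarithm from the TAC paper, log_q(x) = (x^(1-q) - 1) / (1 - q), which reduces to ln(x) as q -> 1, a minimal sketch would be:

import torch as th

def tsallis_log_q(log_pi: th.Tensor, q: float) -> th.Tensor:
    """q-logarithm applied to probabilities given as log-probabilities (assumed form)."""
    if abs(q - 1.0) < 1e-6:
        return log_pi                          # ordinary log in the q -> 1 limit
    pi = log_pi.exp()
    return (pi.pow(1.0 - q) - 1.0) / (1.0 - q)

# sanity check: close to the natural log when q is near 1
x = th.tensor([0.1, 0.5, 0.9])
print(tsallis_log_q(x.log(), 1.0001), x.log())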
Example #26
class DPG(SarlOffPolicy):
    """
    Deterministic Policy Gradient, https://hal.inria.fr/file/index/docid/938992/filename/dpg-icml2014.pdf
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 actor_lr=5.0e-4,
                 critic_lr=1.0e-3,
                 use_target_action_noise=False,
                 noise_action='ou',
                 noise_params={
                     'sigma': 0.2
                 },
                 discrete_tau=1.0,
                 network_settings={
                     'actor_continuous': [32, 32],
                     'actor_discrete': [32, 32],
                     'q': [32, 32]
                 },
                 **kwargs):
        super().__init__(**kwargs)
        self.discrete_tau = discrete_tau
        self.use_target_action_noise = use_target_action_noise

        if self.is_continuous:
            self.target_noised_action = ClippedNormalNoisedAction(sigma=0.2, noise_bound=0.2)
            self.noised_action = Noise_action_REGISTER[noise_action](**noise_params)
            self.actor = ActorDPG(self.obs_spec,
                                  rep_net_params=self._rep_net_params,
                                  output_shape=self.a_dim,
                                  network_settings=network_settings['actor_continuous']).to(self.device)
        else:
            self.actor = ActorDct(self.obs_spec,
                                  rep_net_params=self._rep_net_params,
                                  output_shape=self.a_dim,
                                  network_settings=network_settings['actor_discrete']).to(self.device)

        self.critic = CriticQvalueOne(self.obs_spec,
                                      rep_net_params=self._rep_net_params,
                                      action_dim=self.a_dim,
                                      network_settings=network_settings['q']).to(self.device)

        self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
        self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params)
        self._trainer_modules.update(actor=self.actor,
                                     critic=self.critic,
                                     actor_oplr=self.actor_oplr,
                                     critic_oplr=self.critic_oplr)

    def episode_reset(self):
        super().episode_reset()
        if self.is_continuous:
            self.noised_action.reset()

    @iton
    def select_action(self, obs):
        output = self.actor(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.actor.get_rnncs()
        if self.is_continuous:
            mu = output  # [B, A]
            pi = self.noised_action(mu)  # [B, A]
        else:
            logits = output  # [B, A]
            mu = logits.argmax(-1)  # [B,]
            cate_dist = td.Categorical(logits=logits)
            pi = cate_dist.sample()  # [B,]
        actions = pi if self._is_train_mode else mu
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        if self.is_continuous:
            action_target = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            if self.use_target_action_noise:
                action_target = self.target_noised_action(action_target)  # [T, B, A]
        else:
            target_logits = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            target_cate_dist = td.Categorical(logits=target_logits)
            target_pi = target_cate_dist.sample()  # [T, B]
            action_target = F.one_hot(target_pi, self.a_dim).float()  # [T, B, A]
        q_target = self.critic(BATCH.obs_, action_target, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward,
                             self.gamma,
                             BATCH.done,
                             q_target,
                             BATCH.begin_mask).detach()  # [T, B, 1]
        q = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        td_error = dc_r - q  # [T, B, 1]
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
        self.critic_oplr.optimize(q_loss)

        if self.is_continuous:
            mu = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        else:
            logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
            _pi = logits.softmax(-1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(
                logits.argmax(-1), self.a_dim).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            mu = _pi_diff + _pi  # [T, B, A]
        q_actor = self.critic(BATCH.obs, mu, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        actor_loss = -q_actor.mean()  # 1
        self.actor_oplr.optimize(actor_loss)

        return td_error, {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': q_loss,
            'Statistics/q_min': q.min(),
            'Statistics/q_mean': q.mean(),
            'Statistics/q_max': q.max()
        }
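Noise_action_REGISTER['ou'] is resolved inside the library; the sketch below is a hedged guess at a minimal Ornstein-Uhlenbeck exploration noise of the kind DPG/DDPG typically use. The class and parameter names here are illustrative, not the library's.

import numpy as np

class OUNoise:
    """dx = theta * (mu - x) dt + sigma * sqrt(dt) * N(0, 1)  (minimal sketch)."""

    def __init__(self, shape, mu=0.0, theta=0.15, sigma=0.2, dt=1e-2):
        self.mu, self.theta, self.sigma, self.dt = mu, theta, sigma, dt
        self.shape = shape
        self.reset()

    def reset(self):
        # Restart the process at its long-run mean at every episode reset.
        self.x = np.full(self.shape, self.mu, dtype=np.float64)

    def __call__(self, action):
        # One Euler step of the OU process, added to the deterministic action.
        self.x += self.theta * (self.mu - self.x) * self.dt \
                  + self.sigma * np.sqrt(self.dt) * np.random.randn(*self.shape)
        return np.clip(action + self.x, -1.0, 1.0)

noise = OUNoise(shape=(2,))
print(noise(np.zeros(2)))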
Example #27
class MADDPG(MultiAgentOffPolicy):
    """
    Multi-Agent Deep Deterministic Policy Gradient, https://arxiv.org/abs/1706.02275
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 polyak=0.995,
                 noise_action='ou',
                 noise_params={'sigma': 0.2},
                 actor_lr=5.0e-4,
                 critic_lr=1.0e-3,
                 discrete_tau=1.0,
                 network_settings={
                     'actor_continuous': [32, 32],
                     'actor_discrete': [32, 32],
                     'q': [32, 32]
                 },
                 **kwargs):
        """
        TODO: Annotation
        """
        super().__init__(**kwargs)
        self.polyak = polyak
        self.discrete_tau = discrete_tau

        self.actors, self.critics = {}, {}
        for id in set(self.model_ids):
            if self.is_continuouss[id]:
                self.actors[id] = TargetTwin(
                    ActorDPG(
                        self.obs_specs[id],
                        rep_net_params=self._rep_net_params,
                        output_shape=self.a_dims[id],
                        network_settings=network_settings['actor_continuous']),
                    self.polyak).to(self.device)
            else:
                self.actors[id] = TargetTwin(
                    ActorDct(
                        self.obs_specs[id],
                        rep_net_params=self._rep_net_params,
                        output_shape=self.a_dims[id],
                        network_settings=network_settings['actor_discrete']),
                    self.polyak).to(self.device)
            self.critics[id] = TargetTwin(
                MACriticQvalueOne(list(self.obs_specs.values()),
                                  rep_net_params=self._rep_net_params,
                                  action_dim=sum(self.a_dims.values()),
                                  network_settings=network_settings['q']),
                self.polyak).to(self.device)
        self.actor_oplr = OPLR(list(self.actors.values()), actor_lr,
                               **self._oplr_params)
        self.critic_oplr = OPLR(list(self.critics.values()), critic_lr,
                                **self._oplr_params)

        # TODO: add a check for the action type
        self.noised_actions = {
            id: Noise_action_REGISTER[noise_action](**noise_params)
            for id in set(self.model_ids) if self.is_continuouss[id]
        }

        self._trainer_modules.update(
            {f"actor_{id}": self.actors[id]
             for id in set(self.model_ids)})
        self._trainer_modules.update(
            {f"critic_{id}": self.critics[id]
             for id in set(self.model_ids)})
        self._trainer_modules.update(actor_oplr=self.actor_oplr,
                                     critic_oplr=self.critic_oplr)

    def episode_reset(self):
        super().episode_reset()
        for noised_action in self.noised_actions.values():
            noised_action.reset()

    @iton
    def select_action(self, obs: Dict):
        acts_info = {}
        actions = {}
        for aid, mid in zip(self.agent_ids, self.model_ids):
            output = self.actors[mid](obs[aid],
                                      rnncs=self.rnncs[aid])  # [B, A]
            self.rnncs_[aid] = self.actors[mid].get_rnncs()
            if self.is_continuouss[aid]:
                mu = output  # [B, A]
                pi = self.noised_actions[mid](mu)  # [B, A]
            else:
                logits = output  # [B, A]
                mu = logits.argmax(-1)  # [B,]
                cate_dist = td.Categorical(logits=logits)
                pi = cate_dist.sample()  # [B,]
            action = pi if self._is_train_mode else mu
            acts_info[aid] = Data(action=action)
            actions[aid] = action
        return actions, acts_info

    @iton
    def _train(self, BATCH_DICT):
        """
        TODO: Annotation
        """
        summaries = defaultdict(dict)
        target_actions = {}
        for aid, mid in zip(self.agent_ids, self.model_ids):
            if self.is_continuouss[aid]:
                target_actions[aid] = self.actors[mid].t(
                    BATCH_DICT[aid].obs_,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            else:
                target_logits = self.actors[mid].t(
                    BATCH_DICT[aid].obs_,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
                target_cate_dist = td.Categorical(logits=target_logits)
                target_pi = target_cate_dist.sample()  # [T, B]
                action_target = F.one_hot(
                    target_pi, self.a_dims[aid]).float()  # [T, B, A]
                target_actions[aid] = action_target  # [T, B, A]
        target_actions = th.cat(list(target_actions.values()),
                                -1)  # [T, B, N*A]

        qs, q_targets = {}, {}
        for mid in self.model_ids:
            qs[mid] = self.critics[mid](
                [BATCH_DICT[id].obs for id in self.agent_ids],
                th.cat([BATCH_DICT[id].action for id in self.agent_ids],
                       -1))  # [T, B, 1]
            q_targets[mid] = self.critics[mid].t(
                [BATCH_DICT[id].obs_ for id in self.agent_ids],
                target_actions)  # [T, B, 1]

        q_loss = {}
        td_errors = 0.
        for aid, mid in zip(self.agent_ids, self.model_ids):
            dc_r = n_step_return(
                BATCH_DICT[aid].reward, self.gamma, BATCH_DICT[aid].done,
                q_targets[mid],
                BATCH_DICT['global'].begin_mask).detach()  # [T, B, 1]
            td_error = dc_r - qs[mid]  # [T, B, 1]
            td_errors += td_error
            q_loss[aid] = 0.5 * td_error.square().mean()  # 1
            summaries[aid].update({
                'Statistics/q_min': qs[mid].min(),
                'Statistics/q_mean': qs[mid].mean(),
                'Statistics/q_max': qs[mid].max()
            })
        self.critic_oplr.optimize(sum(q_loss.values()))

        actor_loss = {}
        for aid, mid in zip(self.agent_ids, self.model_ids):
            if self.is_continuouss[aid]:
                mu = self.actors[mid](
                    BATCH_DICT[aid].obs,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            else:
                logits = self.actors[mid](
                    BATCH_DICT[aid].obs,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
                logp_all = logits.log_softmax(-1)  # [T, B, A]
                gumbel_noise = td.Gumbel(0,
                                         1).sample(logp_all.shape)  # [T, B, A]
                _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(
                    -1)  # [T, B, A]
                _pi_true_one_hot = F.one_hot(
                    _pi.argmax(-1), self.a_dims[aid]).float()  # [T, B, A]
                _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
                mu = _pi_diff + _pi  # [T, B, A]

            all_actions = {id: BATCH_DICT[id].action for id in self.agent_ids}
            all_actions[aid] = mu
            q_actor = self.critics[mid](
                [BATCH_DICT[id].obs for id in self.agent_ids],
                th.cat(list(all_actions.values()), -1),
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]
            actor_loss[aid] = -q_actor.mean()  # 1

        self.actor_oplr.optimize(sum(actor_loss.values()))

        for aid in self.agent_ids:
            summaries[aid].update({
                'LOSS/actor_loss': actor_loss[aid],
                'LOSS/critic_loss': q_loss[aid]
            })
        summaries['model'].update({
            'LOSS/actor_loss': sum(actor_loss.values()),
            'LOSS/critic_loss': sum(q_loss.values())
        })
        return td_errors / self.n_agents_percopy, summaries

    def _after_train(self):
        super()._after_train()
        for actor in self.actors.values():
            actor.sync()
        for critic in self.critics.values():
            critic.sync()
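TargetTwin(..., polyak) and .sync() are library constructs; assuming sync() performs the usual polyak (soft) target update, a minimal sketch of that update is:

import torch as th
import torch.nn as nn

@th.no_grad()
def polyak_update(online: nn.Module, target: nn.Module, polyak: float = 0.995):
    # target <- polyak * target + (1 - polyak) * online, parameter by parameter
    for p, p_t in zip(online.parameters(), target.parameters()):
        p_t.mul_(polyak).add_((1.0 - polyak) * p)

online, target = nn.Linear(4, 2), nn.Linear(4, 2)
polyak_update(online, target)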
Example #28
class AOC(SarlOnPolicy):
    """
    Asynchronous Advantage Option-Critic with Deliberation Cost, A2OC
    When Waiting is not an Option : Learning Options with a Deliberation Cost, A2OC, http://arxiv.org/abs/1709.04571
    """
    policy_mode = 'on-policy'

    def __init__(
            self,
            agent_spec,
            options_num=4,
            dc=0.01,
            terminal_mask=False,
            eps=0.1,
            pi_beta=1.0e-3,
            lr=5.0e-4,
            lambda_=0.95,
            epsilon=0.2,
            value_epsilon=0.2,
            kl_reverse=False,
            kl_target=0.02,
            kl_target_cutoff=2,
            kl_target_earlystop=4,
            kl_beta=[0.7, 1.3],
            kl_alpha=1.5,
            kl_coef=1.0,
            network_settings={
                'share': [32, 32],
                'q': [32, 32],
                'intra_option': [32, 32],
                'termination': [32, 32]
            },
            **kwargs):
        super().__init__(agent_spec=agent_spec, **kwargs)
        self.pi_beta = pi_beta
        self.lambda_ = lambda_
        self._epsilon = epsilon
        self._value_epsilon = value_epsilon
        self._kl_reverse = kl_reverse
        self._kl_target = kl_target
        self._kl_alpha = kl_alpha
        self._kl_coef = kl_coef

        self._kl_cutoff = kl_target * kl_target_cutoff
        self._kl_stop = kl_target * kl_target_earlystop
        self._kl_low = kl_target * kl_beta[0]
        self._kl_high = kl_target * kl_beta[-1]

        self.options_num = options_num
        self.dc = dc
        self.terminal_mask = terminal_mask
        self.eps = eps

        self.net = AocShare(self.obs_spec,
                            rep_net_params=self._rep_net_params,
                            action_dim=self.a_dim,
                            options_num=self.options_num,
                            network_settings=network_settings,
                            is_continuous=self.is_continuous).to(self.device)
        if self.is_continuous:
            self.log_std = th.as_tensor(
                np.full((self.options_num, self.a_dim),
                        -0.5)).requires_grad_().to(self.device)  # [P, A]
            self.oplr = OPLR([self.net, self.log_std], lr, **self._oplr_params)
        else:
            self.oplr = OPLR(self.net, lr, **self._oplr_params)

        self._trainer_modules.update(model=self.net, oplr=self.oplr)
        self.oc_mask = th.tensor(np.zeros(self.n_copies)).to(self.device)
        self.options = th.tensor(
            np.random.randint(0, self.options_num,
                              self.n_copies)).to(self.device)

    def episode_reset(self):
        super().episode_reset()
        self._done_mask = th.tensor(np.full(self.n_copies,
                                            True)).to(self.device)

    def episode_step(self, obs: Data, env_rets: Data, begin_mask: np.ndarray):
        super().episode_step(obs, env_rets, begin_mask)
        self._done_mask = th.tensor(env_rets.done).to(self.device)
        self.options = self.new_options
        self.oc_mask = th.zeros_like(self.oc_mask)

    @iton
    def select_action(self, obs):
        # [B, P], [B, P, A], [B, P]
        (q, pi, beta) = self.net(obs, rnncs=self.rnncs)
        self.rnncs_ = self.net.get_rnncs()
        options_onehot = F.one_hot(self.options,
                                   self.options_num).float()  # [B, P]
        options_onehot_expanded = options_onehot.unsqueeze(-1)  # [B, P, 1]
        pi = (pi * options_onehot_expanded).sum(-2)  # [B, A]
        if self.is_continuous:
            mu = pi  # [B, A]
            log_std = self.log_std[self.options]  # [B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            action = dist.sample().clamp(-1, 1)  # [B, A]
            log_prob = dist.log_prob(action).unsqueeze(-1)  # [B, 1]
        else:
            logits = pi  # [B, A]
            norm_dist = td.Categorical(logits=logits)
            action = norm_dist.sample()  # [B,]
            log_prob = norm_dist.log_prob(action).unsqueeze(-1)  # [B, 1]
        value = q_o = (q * options_onehot).sum(-1, keepdim=True)  # [B, 1]
        beta_adv = q_o - ((1 - self.eps) * q.max(-1, keepdim=True)[0] +
                          self.eps * q.mean(-1, keepdim=True))  # [B, 1]
        max_options = q.argmax(-1)  # [B, P] => [B, ]
        beta_probs = (beta * options_onehot).sum(-1)  # [B, P] => [B,]
        beta_dist = td.Bernoulli(probs=beta_probs)
        # sample < 1: keep the current option; sample == 1: switch option
        new_options = th.where(beta_dist.sample() < 1, self.options,
                               max_options)
        self.new_options = th.where(self._done_mask, max_options, new_options)
        self.oc_mask = (self.new_options == self.options).float()
        acts_info = Data(
            action=action,
            value=value,
            log_prob=log_prob + th.finfo().eps,
            beta_advantage=beta_adv + self.dc,
            last_options=self.options,
            options=self.new_options,
            reward_offset=-((1 - self.oc_mask) * self.dc).unsqueeze(-1))
        if self.use_rnn:
            acts_info.update(rnncs=self.rnncs)
        return action, acts_info

    @iton
    def _get_value(self, obs, options, rnncs=None):
        (q, _, _) = self.net(obs, rnncs=rnncs)  # [B, P]
        value = (q * options).sum(-1, keepdim=True)  # [B, 1]
        return value

    def _preprocess_BATCH(self, BATCH):  # [T, B, *]
        BATCH = super()._preprocess_BATCH(BATCH)
        BATCH.reward += BATCH.reward_offset

        BATCH.last_options = int2one_hot(BATCH.last_options, self.options_num)
        BATCH.options = int2one_hot(BATCH.options, self.options_num)
        value = self._get_value(BATCH.obs_[-1],
                                BATCH.options[-1],
                                rnncs=self.rnncs)
        BATCH.discounted_reward = discounted_sum(BATCH.reward,
                                                 self.gamma,
                                                 BATCH.done,
                                                 BATCH.begin_mask,
                                                 init_value=value)
        td_error = calculate_td_error(
            BATCH.reward,
            self.gamma,
            BATCH.done,
            value=BATCH.value,
            next_value=np.concatenate((BATCH.value[1:], value[np.newaxis, :]),
                                      0))
        BATCH.gae_adv = discounted_sum(td_error,
                                       self.lambda_ * self.gamma,
                                       BATCH.done,
                                       BATCH.begin_mask,
                                       init_value=0.,
                                       normalize=True)
        return BATCH

    def learn(self, BATCH: Data):
        BATCH = self._preprocess_BATCH(BATCH)  # [T, B, *]
        for _ in range(self._epochs):
            kls = []
            for _BATCH in BATCH.sample(self._chunk_length,
                                       self.batch_size,
                                       repeat=self._sample_allow_repeat):
                _BATCH = self._before_train(_BATCH)
                summaries, kl = self._train(_BATCH)
                kls.append(kl)
                self.summaries.update(summaries)
                self._after_train()
            if sum(kls) / len(kls) > self._kl_stop:
                break

    @iton
    def _train(self, BATCH):
        # [T, B, P], [T, B, P, A], [T, B, P]
        (q, pi, beta) = self.net(BATCH.obs, begin_mask=BATCH.begin_mask)
        options_onehot_expanded = BATCH.options.unsqueeze(-1)  # [T, B, P, 1]
        # [T, B, P, A] => [T, B, A]
        pi = (pi * options_onehot_expanded).sum(-2)
        value = (q * BATCH.options).sum(-1, keepdim=True)  # [T, B, 1]

        if self.is_continuous:
            mu = pi  # [T, B, A]
            log_std = self.log_std[BATCH.options.argmax(-1)]  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            new_log_prob = dist.log_prob(BATCH.action).unsqueeze(
                -1)  # [T, B, 1]
            entropy = dist.entropy().mean()  # 1
        else:
            logits = pi  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            new_log_prob = (BATCH.action * logp_all).sum(
                -1, keepdim=True)  # [T, B, 1]
            entropy = -(logp_all.exp() * logp_all).sum(
                -1, keepdim=True).mean()  # 1
        ratio = (new_log_prob - BATCH.log_prob).exp()  # [T, B, 1]

        if self._kl_reverse:
            kl = (new_log_prob - BATCH.log_prob).mean()  # 1
        else:
            # a sample estimate for KL-divergence, easy to compute
            kl = (BATCH.log_prob - new_log_prob).mean()
        surrogate = ratio * BATCH.gae_adv  # [T, B, 1]

        value_clip = BATCH.value + (value - BATCH.value).clamp(
            -self._value_epsilon, self._value_epsilon)  # [T, B, 1]
        td_error = BATCH.discounted_reward - value  # [T, B, 1]
        td_error_clip = BATCH.discounted_reward - value_clip  # [T, B, 1]
        td_square = th.maximum(td_error.square(),
                               td_error_clip.square())  # [T, B, 1]

        pi_loss = -th.minimum(
            surrogate,
            ratio.clamp(1.0 - self._epsilon, 1.0 + self._epsilon) *
            BATCH.gae_adv).mean()  # 1
        kl_loss = self._kl_coef * kl
        extra_loss = 1000.0 * th.maximum(th.zeros_like(kl),
                                         kl - self._kl_cutoff).square().mean()
        pi_loss = pi_loss + kl_loss + extra_loss  # 1
        q_loss = 0.5 * td_square.mean()  # 1

        beta_s = (beta * BATCH.last_options).sum(-1, keepdim=True)  # [T, B, 1]
        beta_loss = (beta_s * BATCH.beta_advantage)  # [T, B, 1]
        if self.terminal_mask:
            beta_loss *= (1 - BATCH.done)  # [T, B, 1]
        beta_loss = beta_loss.mean()  # 1

        loss = pi_loss + 1.0 * q_loss + beta_loss - self.pi_beta * entropy
        self.oplr.optimize(loss)

        if kl > self._kl_high:
            self._kl_coef *= self._kl_alpha
        elif kl < self._kl_low:
            self._kl_coef /= self._kl_alpha

        return {
            'LOSS/loss': loss,
            'LOSS/pi_loss': pi_loss,
            'LOSS/q_loss': q_loss,
            'LOSS/beta_loss': beta_loss,
            'Statistics/kl': kl,
            'Statistics/kl_coef': self._kl_coef,
            'Statistics/entropy': entropy,
            'LEARNING_RATE/lr': self.oplr.lr
        }, kl
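The option-policy update above combines a PPO-style clipped surrogate with an adaptive KL penalty coefficient. As a hedged, standalone sketch of those two pieces (function names are illustrative):

import torch as th

def clipped_surrogate_loss(new_log_prob, old_log_prob, adv, epsilon=0.2):
    # PPO-style clipped objective, as used per option in AOC._train above
    ratio = (new_log_prob - old_log_prob).exp()
    unclipped = ratio * adv
    clipped = ratio.clamp(1.0 - epsilon, 1.0 + epsilon) * adv
    return -th.minimum(unclipped, clipped).mean()

def adapt_kl_coef(kl_coef, kl, kl_target, alpha=1.5, beta=(0.7, 1.3)):
    # grow the penalty when the measured KL is too large, shrink it when too small
    if kl > kl_target * beta[1]:
        kl_coef *= alpha
    elif kl < kl_target * beta[0]:
        kl_coef /= alpha
    return kl_coef

print(adapt_kl_coef(1.0, kl=0.05, kl_target=0.02))  # -> 1.5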
Example #29
class MAXSQN(SarlOffPolicy):
    """
    https://github.com/createamind/DRL/blob/master/spinup/algos/maxsqn/maxsqn.py
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 alpha=0.2,
                 beta=0.1,
                 polyak=0.995,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 use_epsilon=False,
                 q_lr=5.0e-4,
                 alpha_lr=5.0e-4,
                 auto_adaption=True,
                 network_settings=[32, 32],
                 **kwargs):
        super().__init__(**kwargs)
        assert not self.is_continuous, 'maxsqn only support discrete action space'
        self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init,
                                                          eps_mid=eps_mid,
                                                          eps_final=eps_final,
                                                          init2mid_annealing_step=init2mid_annealing_step,
                                                          max_step=self._max_train_step)
        self.use_epsilon = use_epsilon
        self.polyak = polyak
        self.auto_adaption = auto_adaption
        self.target_entropy = beta * np.log(self.a_dim)

        self.critic = TargetTwin(CriticQvalueAll(self.obs_spec,
                                                 rep_net_params=self._rep_net_params,
                                                 output_shape=self.a_dim,
                                                 network_settings=network_settings),
                                 self.polyak).to(self.device)
        self.critic2 = deepcopy(self.critic)

        self.critic_oplr = OPLR([self.critic, self.critic2], q_lr, **self._oplr_params)

        if self.auto_adaption:
            self.log_alpha = th.tensor(0., requires_grad=True).to(self.device)
            self.alpha_oplr = OPLR(self.log_alpha, alpha_lr, **self._oplr_params)
            self._trainer_modules.update(alpha_oplr=self.alpha_oplr)
        else:
            self.log_alpha = th.tensor(alpha).log().to(self.device)

        self._trainer_modules.update(critic=self.critic,
                                     critic2=self.critic2,
                                     log_alpha=self.log_alpha,
                                     critic_oplr=self.critic_oplr)

    @property
    def alpha(self):
        return self.log_alpha.exp()

    @iton
    def select_action(self, obs):
        q = self.critic(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.critic.get_rnncs()

        if self.use_epsilon and self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            cate_dist = td.Categorical(logits=(q / self.alpha))
            mu = q.argmax(-1)  # [B,]
            actions = pi = cate_dist.sample()  # [B,]
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        q1 = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q2 = self.critic2(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q1_eval = (q1 * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        q2_eval = (q2 * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]

        q1_log_probs = (q1 / (self.alpha + th.finfo().eps)).log_softmax(-1)  # [T, B, A]
        q1_entropy = -(q1_log_probs.exp() * q1_log_probs).sum(-1, keepdim=True).mean()  # 1

        q1_target = self.critic.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q2_target = self.critic2.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q1_target_max = q1_target.max(-1, keepdim=True)[0]  # [T, B, 1]
        q1_target_log_probs = (q1_target / (self.alpha + th.finfo().eps)).log_softmax(-1)  # [T, B, A]
        q1_target_entropy = -(q1_target_log_probs.exp() * q1_target_log_probs).sum(-1, keepdim=True)  # [T, B, 1]

        q2_target_max = q2_target.max(-1, keepdim=True)[0]  # [T, B, 1]
        # q2_target_log_probs = q2_target.log_softmax(-1)
        # q2_target_log_max = q2_target_log_probs.max(1, keepdim=True)[0]

        q_target = th.minimum(q1_target_max, q2_target_max) + self.alpha * q1_target_entropy  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward,
                             self.gamma,
                             BATCH.done,
                             q_target,
                             BATCH.begin_mask).detach()  # [T, B, 1]
        td_error1 = q1_eval - dc_r  # [T, B, 1]
        td_error2 = q2_eval - dc_r  # [T, B, 1]
        q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
        q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
        loss = 0.5 * (q1_loss + q2_loss)
        self.critic_oplr.optimize(loss)
        summaries = {
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/loss': loss,
            'Statistics/log_alpha': self.log_alpha,
            'Statistics/alpha': self.alpha,
            'Statistics/q1_entropy': q1_entropy,
            'Statistics/q_min': th.minimum(q1, q2).mean(),
            'Statistics/q_mean': q1.mean(),
            'Statistics/q_max': th.maximum(q1, q2).mean()
        }
        if self.auto_adaption:
            alpha_loss = -(self.alpha * (self.target_entropy - q1_entropy).detach()).mean()
            self.alpha_oplr.optimize(alpha_loss)
            summaries.update({
                'LOSS/alpha_loss': alpha_loss,
                'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
            })
        return (td_error1 + td_error2) / 2, summaries

    def _after_train(self):
        super()._after_train()
        self.critic.sync()
        self.critic2.sync()
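MAXSQN's entropy terms above are the entropy of the Boltzmann policy pi(a|s) = softmax(Q(s, .) / alpha). A minimal sketch of that quantity, for intuition:

import torch as th

def boltzmann_entropy(q: th.Tensor, alpha: float) -> th.Tensor:
    # entropy of the softmax policy over Q-values scaled by the temperature alpha
    log_probs = (q / alpha).log_softmax(-1)          # [..., A]
    return -(log_probs.exp() * log_probs).sum(-1)    # [...]

q = th.tensor([[1.0, 1.0, 1.0], [10.0, 0.0, 0.0]])
print(boltzmann_entropy(q, alpha=0.2))  # uniform Q -> log(3); peaked Q -> near 0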
Example #30
class AC(SarlOffPolicy):
    policy_mode = 'off-policy'

    # off-policy actor-critic
    def __init__(
            self,
            actor_lr=5.0e-4,
            critic_lr=1.0e-3,
            network_settings={
                'actor_continuous': {
                    'hidden_units': [64, 64],
                    'condition_sigma': False,
                    'log_std_bound': [-20, 2]
                },
                'actor_discrete': [32, 32],
                'critic': [32, 32]
            },
            **kwargs):
        super().__init__(**kwargs)

        if self.is_continuous:
            self.actor = ActorMuLogstd(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_continuous']).to(
                    self.device)
        else:
            self.actor = ActorDct(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=network_settings['actor_discrete']).to(
                    self.device)
        self.critic = CriticQvalueOne(
            self.obs_spec,
            rep_net_params=self._rep_net_params,
            action_dim=self.a_dim,
            network_settings=network_settings['critic']).to(self.device)

        self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
        self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params)

        self._trainer_modules.update(actor=self.actor,
                                     critic=self.critic,
                                     actor_oplr=self.actor_oplr,
                                     critic_oplr=self.critic_oplr)

    @iton
    def select_action(self, obs):
        output = self.actor(obs, rnncs=self.rnncs)  # [B, *]
        self.rnncs_ = self.actor.get_rnncs()
        if self.is_continuous:
            mu, log_std = output  # [B, *]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            action = dist.sample().clamp(-1, 1)  # [B, *]
            log_prob = dist.log_prob(action)  # [B,]
        else:
            logits = output  # [B, *]
            norm_dist = td.Categorical(logits=logits)
            action = norm_dist.sample()  # [B,]
            log_prob = norm_dist.log_prob(action)  # [B,]
        return action, Data(action=action, log_prob=log_prob)

    def random_action(self):
        actions = super().random_action()
        if self.is_continuous:
            self._acts_info.update(log_prob=np.full(self.n_copies,
                                                    np.log(0.5)))  # [B,]
        else:
            self._acts_info.update(log_prob=np.full(self.n_copies,
                                                    np.log(1. / self.a_dim)))  # [B,]
        return actions

    @iton
    def _train(self, BATCH):
        q = self.critic(BATCH.obs, BATCH.action,
                        begin_mask=BATCH.begin_mask)  # [T, B, 1]
        if self.is_continuous:
            next_mu, _ = self.actor(BATCH.obs_,
                                    begin_mask=BATCH.begin_mask)  # [T, B, *]
            max_q_next = self.critic(
                BATCH.obs_, next_mu,
                begin_mask=BATCH.begin_mask).detach()  # [T, B, 1]
        else:
            logits = self.actor(BATCH.obs_,
                                begin_mask=BATCH.begin_mask)  # [T, B, *]
            max_a = logits.argmax(-1)  # [T, B]
            max_a_one_hot = F.one_hot(max_a, self.a_dim).float()  # [T, B, N]
            max_q_next = self.critic(BATCH.obs_,
                                     max_a_one_hot).detach()  # [T, B, 1]
        td_error = q - n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                     max_q_next,
                                     BATCH.begin_mask).detach()  # [T, B, 1]
        critic_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
        self.critic_oplr.optimize(critic_loss)

        if self.is_continuous:
            mu, log_std = self.actor(BATCH.obs,
                                     begin_mask=BATCH.begin_mask)  # [T, B, *]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            log_prob = dist.log_prob(BATCH.action)  # [T, B]
            entropy = dist.entropy().mean()  # 1
        else:
            logits = self.actor(BATCH.obs,
                                begin_mask=BATCH.begin_mask)  # [T, B, *]
            logp_all = logits.log_softmax(-1)  # [T, B, *]
            log_prob = (logp_all * BATCH.action).sum(-1)  # [T, B]
            entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
        ratio = (log_prob - BATCH.log_prob).exp().detach()  # [T, B]
        actor_loss = -(ratio * log_prob *
                       q.squeeze(-1).detach()).mean()  # [T, B] => 1
        self.actor_oplr.optimize(actor_loss)

        return td_error, {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss,
            'Statistics/q_max': q.max(),
            'Statistics/q_min': q.min(),
            'Statistics/q_mean': q.mean(),
            'Statistics/ratio': ratio.mean(),
            'Statistics/entropy': entropy
        }
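The actor update in this off-policy AC weights the policy gradient by a detached importance ratio between the current policy and the behaviour policy that generated the batch. A minimal, hedged sketch of that surrogate (names are illustrative):

import torch as th

def off_policy_actor_loss(new_log_prob, behaviour_log_prob, q_value):
    # importance-weighted policy-gradient surrogate, as in AC._train above
    ratio = (new_log_prob - behaviour_log_prob).exp().detach()   # stop-gradient on the weight
    return -(ratio * new_log_prob * q_value.detach()).mean()

lp_new = th.tensor([0.6, 0.3]).log()
lp_old = th.tensor([0.5, 0.5]).log()
print(off_policy_actor_loss(lp_new, lp_old, th.tensor([1.0, 2.0])))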