Example 1
class MAXSQN(SarlOffPolicy):
    """
    https://github.com/createamind/DRL/blob/master/spinup/algos/maxsqn/maxsqn.py
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 alpha=0.2,
                 beta=0.1,
                 polyak=0.995,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 use_epsilon=False,
                 q_lr=5.0e-4,
                 alpha_lr=5.0e-4,
                 auto_adaption=True,
                 network_settings=[32, 32],
                 **kwargs):
        super().__init__(**kwargs)
        assert not self.is_continuous, 'maxsqn only support discrete action space'
        self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init,
                                                          eps_mid=eps_mid,
                                                          eps_final=eps_final,
                                                          init2mid_annealing_step=init2mid_annealing_step,
                                                          max_step=self._max_train_step)
        self.use_epsilon = use_epsilon
        self.polyak = polyak
        self.auto_adaption = auto_adaption
        self.target_entropy = beta * np.log(self.a_dim)

        self.critic = TargetTwin(CriticQvalueAll(self.obs_spec,
                                                 rep_net_params=self._rep_net_params,
                                                 output_shape=self.a_dim,
                                                 network_settings=network_settings),
                                 self.polyak).to(self.device)
        self.critic2 = deepcopy(self.critic)

        self.critic_oplr = OPLR([self.critic, self.critic2], q_lr, **self._oplr_params)

        if self.auto_adaption:
            self.log_alpha = th.tensor(0., requires_grad=True).to(self.device)
            self.alpha_oplr = OPLR(self.log_alpha, alpha_lr, **self._oplr_params)
            self._trainer_modules.update(alpha_oplr=self.alpha_oplr)
        else:
            self.log_alpha = th.tensor(alpha).log().to(self.device)

        self._trainer_modules.update(critic=self.critic,
                                     critic2=self.critic2,
                                     log_alpha=self.log_alpha,
                                     critic_oplr=self.critic_oplr)

    @property
    def alpha(self):
        return self.log_alpha.exp()

    @iton
    def select_action(self, obs):
        q = self.critic(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.critic.get_rnncs()

        if self.use_epsilon and self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            cate_dist = td.Categorical(logits=(q / self.alpha))
            mu = q.argmax(-1)  # [B,]
            actions = pi = cate_dist.sample()  # [B,]
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        q1 = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q2 = self.critic2(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q1_eval = (q1 * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        q2_eval = (q2 * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]

        q1_log_probs = (q1 / (self.alpha + th.finfo().eps)).log_softmax(-1)  # [T, B, A]
        q1_entropy = -(q1_log_probs.exp() * q1_log_probs).sum(-1, keepdim=True).mean()  # 1

        q1_target = self.critic.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q2_target = self.critic2.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q1_target_max = q1_target.max(-1, keepdim=True)[0]  # [T, B, 1]
        q1_target_log_probs = (q1_target / (self.alpha + th.finfo().eps)).log_softmax(-1)  # [T, B, A]
        q1_target_entropy = -(q1_target_log_probs.exp() * q1_target_log_probs).sum(-1, keepdim=True)  # [T, B, 1]

        q2_target_max = q2_target.max(-1, keepdim=True)[0]  # [T, B, 1]
        # q2_target_log_probs = q2_target.log_softmax(-1)
        # q2_target_log_max = q2_target_log_probs.max(1, keepdim=True)[0]

        q_target = th.minimum(q1_target_max, q2_target_max) + self.alpha * q1_target_entropy  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward,
                             self.gamma,
                             BATCH.done,
                             q_target,
                             BATCH.begin_mask).detach()  # [T, B, 1]
        td_error1 = q1_eval - dc_r  # [T, B, 1]
        td_error2 = q2_eval - dc_r  # [T, B, 1]
        q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
        q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
        loss = 0.5 * (q1_loss + q2_loss)
        self.critic_oplr.optimize(loss)
        summaries = {
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/loss': loss,
            'Statistics/log_alpha': self.log_alpha,
            'Statistics/alpha': self.alpha,
            'Statistics/q1_entropy': q1_entropy,
            'Statistics/q_min': th.minimum(q1, q2).mean(),
            'Statistics/q_mean': q1.mean(),
            'Statistics/q_max': th.maximum(q1, q2).mean()
        }
        if self.auto_adaption:
            alpha_loss = -(self.alpha * (self.target_entropy - q1_entropy).detach()).mean()
            self.alpha_oplr.optimize(alpha_loss)
            summaries.update({
                'LOSS/alpha_loss': alpha_loss,
                'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
            })
        return (td_error1 + td_error2) / 2, summaries

    def _after_train(self):
        super()._after_train()
        self.critic.sync()
        self.critic2.sync()
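
The target in _train above combines clipped double-Q with a softmax-entropy bonus. Below is a minimal sketch of that target on plain tensors, assuming only torch; the function name and shapes are illustrative and not part of the repository.

import torch

def maxsqn_soft_target(q1_t, q2_t, alpha):
    """q1_t, q2_t: target-network Q values, shape [B, A]; returns a [B, 1] soft target."""
    eps = torch.finfo(q1_t.dtype).eps
    log_probs = (q1_t / (alpha + eps)).log_softmax(-1)               # [B, A]
    entropy = -(log_probs.exp() * log_probs).sum(-1, keepdim=True)   # [B, 1]
    q_max = torch.minimum(q1_t.max(-1, keepdim=True)[0],
                          q2_t.max(-1, keepdim=True)[0])             # clipped double-Q, [B, 1]
    return q_max + alpha * entropy                                   # fed to n_step_return as q_target

q1, q2 = torch.randn(4, 3), torch.randn(4, 3)
print(maxsqn_soft_target(q1, q2, alpha=0.2).shape)  # torch.Size([4, 1])
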
Example 2
class OC(SarlOffPolicy):
    """
    The Option-Critic Architecture. http://arxiv.org/abs/1609.05140
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 q_lr=5.0e-3,
                 intra_option_lr=5.0e-4,
                 termination_lr=5.0e-4,
                 use_eps_greedy=False,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 boltzmann_temperature=1.0,
                 options_num=4,
                 ent_coff=0.01,
                 double_q=False,
                 use_baseline=True,
                 terminal_mask=True,
                 termination_regularizer=0.01,
                 assign_interval=1000,
                 network_settings={
                     'q': [32, 32],
                     'intra_option': [32, 32],
                     'termination': [32, 32]
                 },
                 **kwargs):
        super().__init__(**kwargs)
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self._max_train_step)
        self.assign_interval = assign_interval
        self.options_num = options_num
        self.termination_regularizer = termination_regularizer
        self.ent_coff = ent_coff
        self.use_baseline = use_baseline
        self.terminal_mask = terminal_mask
        self.double_q = double_q
        self.boltzmann_temperature = boltzmann_temperature
        self.use_eps_greedy = use_eps_greedy

        self.q_net = TargetTwin(
            CriticQvalueAll(self.obs_spec,
                            rep_net_params=self._rep_net_params,
                            output_shape=self.options_num,
                            network_settings=network_settings['q'])).to(
                                self.device)

        self.intra_option_net = OcIntraOption(
            self.obs_spec,
            rep_net_params=self._rep_net_params,
            output_shape=self.a_dim,
            options_num=self.options_num,
            network_settings=network_settings['intra_option']).to(self.device)
        self.termination_net = CriticQvalueAll(
            self.obs_spec,
            rep_net_params=self._rep_net_params,
            output_shape=self.options_num,
            network_settings=network_settings['termination'],
            out_act='sigmoid').to(self.device)

        if self.is_continuous:
            # https://discuss.pytorch.org/t/valueerror-cant-optimize-a-non-leaf-tensor/21751
            # https://blog.csdn.net/nkhgl/article/details/100047276
            self.log_std = th.as_tensor(
                np.full((self.options_num, self.a_dim),
                        -0.5)).requires_grad_().to(self.device)  # [P, A]
            self.intra_option_oplr = OPLR(
                [self.intra_option_net, self.log_std], intra_option_lr,
                **self._oplr_params)
        else:
            self.intra_option_oplr = OPLR(self.intra_option_net,
                                          intra_option_lr, **self._oplr_params)
        self.q_oplr = OPLR(self.q_net, q_lr, **self._oplr_params)
        self.termination_oplr = OPLR(self.termination_net, termination_lr,
                                     **self._oplr_params)

        self._trainer_modules.update(q_net=self.q_net,
                                     intra_option_net=self.intra_option_net,
                                     termination_net=self.termination_net,
                                     q_oplr=self.q_oplr,
                                     intra_option_oplr=self.intra_option_oplr,
                                     termination_oplr=self.termination_oplr)
        self.options = self.new_options = self._generate_random_options()

    def _generate_random_options(self):
        # [B,]
        return th.tensor(np.random.randint(0, self.options_num,
                                           self.n_copies)).to(self.device)

    def episode_step(self, obs: Data, env_rets: Data, begin_mask: np.ndarray):
        super().episode_step(obs, env_rets, begin_mask)
        self.options = self.new_options

    @iton
    def select_action(self, obs):
        q = self.q_net(obs, rnncs=self.rnncs)  # [B, P]
        self.rnncs_ = self.q_net.get_rnncs()
        pi = self.intra_option_net(obs, rnncs=self.rnncs)  # [B, P, A]
        beta = self.termination_net(obs, rnncs=self.rnncs)  # [B, P]
        options_onehot = F.one_hot(self.options,
                                   self.options_num).float()  # [B, P]
        options_onehot_expanded = options_onehot.unsqueeze(-1)  # [B, P, 1]
        pi = (pi * options_onehot_expanded).sum(-2)  # [B, A]
        if self.is_continuous:
            mu = pi.tanh()  # [B, A]
            log_std = self.log_std[self.options]  # [B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            actions = dist.sample().clamp(-1, 1)  # [B, A]
        else:
            pi = pi / self.boltzmann_temperature  # [B, A]
            dist = td.Categorical(logits=pi)
            actions = dist.sample()  # [B, ]
        max_options = q.argmax(-1).long()  # [B, P] => [B, ]
        if self.use_eps_greedy:
            # epsilon greedy
            if self._is_train_mode and self.expl_expt_mng.is_random(
                    self._cur_train_step):
                self.new_options = self._generate_random_options()
            else:
                self.new_options = max_options
        else:
            beta_probs = (beta * options_onehot).sum(-1)  # [B, P] => [B,]
            beta_dist = td.Bernoulli(probs=beta_probs)
            self.new_options = th.where(beta_dist.sample() < 1, self.options,
                                        max_options)
        return actions, Data(action=actions,
                             last_options=self.options,
                             options=self.new_options)

    def random_action(self):
        actions = super().random_action()
        self._acts_info.update(
            last_options=np.random.randint(0, self.options_num, self.n_copies),
            options=np.random.randint(0, self.options_num, self.n_copies))
        return actions

    def _preprocess_BATCH(self, BATCH):  # [T, B, *]
        BATCH = super()._preprocess_BATCH(BATCH)
        BATCH.last_options = int2one_hot(BATCH.last_options, self.options_num)
        BATCH.options = int2one_hot(BATCH.options, self.options_num)
        return BATCH

    @iton
    def _train(self, BATCH):
        q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, P]
        q_next = self.q_net.t(BATCH.obs_,
                              begin_mask=BATCH.begin_mask)  # [T, B, P]
        beta_next = self.termination_net(
            BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, P]

        qu_eval = (q * BATCH.options).sum(-1, keepdim=True)  # [T, B, 1]
        beta_s_ = (beta_next * BATCH.options).sum(-1,
                                                  keepdim=True)  # [T, B, 1]
        q_s_ = (q_next * BATCH.options).sum(-1, keepdim=True)  # [T, B, 1]
        # https://github.com/jeanharb/option_critic/blob/5d6c81a650a8f452bc8ad3250f1f211d317fde8c/neural_net.py#L94
        if self.double_q:
            q_ = self.q_net(BATCH.obs_,
                            begin_mask=BATCH.begin_mask)  # [T, B, P]
            # [T, B, P] => [T, B] => [T, B, P]
            max_a_idx = F.one_hot(q_.argmax(-1), self.options_num).float()
            q_s_max = (q_next * max_a_idx).sum(-1, keepdim=True)  # [T, B, 1]
        else:
            q_s_max = q_next.max(-1, keepdim=True)[0]  # [T, B, 1]
        u_target = (1 - beta_s_) * q_s_ + beta_s_ * q_s_max  # [T, B, 1]
        qu_target = n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                  u_target,
                                  BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = qu_target - qu_eval  # gradient : q   [T, B, 1]
        q_loss = (td_error.square() *
                  BATCH.get('isw', 1.0)).mean()  # [T, B, 1] => 1
        self.q_oplr.optimize(q_loss)

        q_s = qu_eval.detach()  # [T, B, 1]
        # https://github.com/jeanharb/option_critic/blob/5d6c81a650a8f452bc8ad3250f1f211d317fde8c/neural_net.py#L130
        if self.use_baseline:
            adv = (qu_target - q_s).detach()  # [T, B, 1]
        else:
            adv = qu_target.detach()  # [T, B, 1]
        # [T, B, P] => [T, B, P, 1]
        options_onehot_expanded = BATCH.options.unsqueeze(-1)
        pi = self.intra_option_net(BATCH.obs,
                                   begin_mask=BATCH.begin_mask)  # [T, B, P, A]
        # [T, B, P, A] => [T, B, A]
        pi = (pi * options_onehot_expanded).sum(-2)
        if self.is_continuous:
            mu = pi.tanh()  # [T, B, A]
            log_std = self.log_std[BATCH.options.argmax(-1)]  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            log_p = dist.log_prob(BATCH.action).unsqueeze(-1)  # [T, B, 1]
            entropy = dist.entropy().unsqueeze(-1)  # [T, B, 1]
        else:
            pi = pi / self.boltzmann_temperature  # [T, B, A]
            log_pi = pi.log_softmax(-1)  # [T, B, A]
            entropy = -(log_pi.exp() * log_pi).sum(-1,
                                                   keepdim=True)  # [T, B, 1]
            log_p = (BATCH.action * log_pi).sum(-1, keepdim=True)  # [T, B, 1]
        pi_loss = -(log_p * adv + self.ent_coff * entropy).mean()  # 1

        beta = self.termination_net(BATCH.obs,
                                    begin_mask=BATCH.begin_mask)  # [T, B, P]
        beta_s = (beta * BATCH.last_options).sum(-1, keepdim=True)  # [T, B, 1]
        if self.use_eps_greedy:
            v_s = q.max(
                -1,
                keepdim=True)[0] - self.termination_regularizer  # [T, B, 1]
        else:
            v_s = (1 - beta_s) * q_s + beta_s * q.max(
                -1, keepdim=True)[0]  # [T, B, 1]
            # v_s = q.mean(-1, keepdim=True)  # [T, B, 1]
        beta_loss = beta_s * (q_s - v_s).detach()  # [T, B, 1]
        # https://github.com/lweitkamp/option-critic-pytorch/blob/0c57da7686f8903ed2d8dded3fae832ee9defd1a/option_critic.py#L238
        if self.terminal_mask:
            beta_loss *= (1 - BATCH.done)  # [T, B, 1]
        beta_loss = beta_loss.mean()  # 1

        self.intra_option_oplr.optimize(pi_loss)
        self.termination_oplr.optimize(beta_loss)

        return td_error, {
            'LEARNING_RATE/q_lr': self.q_oplr.lr,
            'LEARNING_RATE/intra_option_lr': self.intra_option_oplr.lr,
            'LEARNING_RATE/termination_lr': self.termination_oplr.lr,
            # 'Statistics/option': self.options[0],
            'LOSS/q_loss': q_loss,
            'LOSS/pi_loss': pi_loss,
            'LOSS/beta_loss': beta_loss,
            'Statistics/q_option_max': q_s.max(),
            'Statistics/q_option_min': q_s.min(),
            'Statistics/q_option_mean': q_s.mean()
        }

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            self.q_net.sync()
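
The q_loss above regresses Q(s, w) towards U(s', w) = (1 - beta(s', w)) * Q(s', w) + beta(s', w) * max_w' Q(s', w'). A minimal sketch of that option-value target on plain tensors; names and shapes here are illustrative, not the repository's API.

import torch
import torch.nn.functional as F

def option_u_target(q_next, beta_next, options_onehot):
    """q_next, beta_next, options_onehot: [B, P]; returns U(s', w) with shape [B, 1]."""
    q_sw = (q_next * options_onehot).sum(-1, keepdim=True)        # Q(s', w)
    beta_sw = (beta_next * options_onehot).sum(-1, keepdim=True)  # termination prob beta(s', w)
    q_max = q_next.max(-1, keepdim=True)[0]                       # max_w' Q(s', w')
    # keep the current option with prob (1 - beta), otherwise switch to the greedy option
    return (1 - beta_sw) * q_sw + beta_sw * q_max

q_next, beta_next = torch.randn(5, 4), torch.rand(5, 4)
w = F.one_hot(torch.randint(0, 4, (5,)), 4).float()
print(option_u_target(q_next, beta_next, w).shape)  # torch.Size([5, 1])
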
Example 3
class C51(SarlOffPolicy):
    """
    Category 51, https://arxiv.org/abs/1707.06887
    No double, no dueling, no noisy net.
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 v_min=-10,
                 v_max=10,
                 atoms=51,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=1000,
                 network_settings=[128, 128],
                 **kwargs):
        super().__init__(**kwargs)
        assert not self.is_continuous, 'c51 only support discrete action space'
        self._v_min = v_min
        self._v_max = v_max
        self._atoms = atoms
        self._delta_z = (self._v_max - self._v_min) / (self._atoms - 1)
        self._z = th.linspace(self._v_min, self._v_max,
                              self._atoms).float().to(self.device)  # [N,]
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self._max_train_step)
        self.assign_interval = assign_interval
        self.q_net = TargetTwin(
            C51Distributional(self.obs_spec,
                              rep_net_params=self._rep_net_params,
                              action_dim=self.a_dim,
                              atoms=self._atoms,
                              network_settings=network_settings)).to(
                                  self.device)
        self.oplr = OPLR(self.q_net, lr, **self._oplr_params)
        self._trainer_modules.update(model=self.q_net, oplr=self.oplr)

    @iton
    def select_action(self, obs):
        feat = self.q_net(obs, rnncs=self.rnncs)  # [B, A, N]
        self.rnncs_ = self.q_net.get_rnncs()

        if self._is_train_mode and self.expl_expt_mng.is_random(
                self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            q = (self._z * feat).sum(-1)  # [B, A, N] * [N,] => [B, A]
            actions = q.argmax(-1)  # [B,]
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        q_dist = self.q_net(BATCH.obs,
                            begin_mask=BATCH.begin_mask)  # [T, B, A, N]
        # [T, B, A, N] * [T, B, A, 1] => [T, B, A, N] => [T, B, N]
        q_dist = (q_dist * BATCH.action.unsqueeze(-1)).sum(-2)

        q_eval = (q_dist * self._z).sum(-1)  # [T, B, N] * [N,] => [T, B]

        target_q_dist = self.q_net.t(
            BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A, N]
        # [T, B, A, N] * [1, N] => [T, B, A]
        target_q = (target_q_dist * self._z).sum(-1)
        a_ = target_q.argmax(-1)  # [T, B]
        a_onehot = F.one_hot(a_, self.a_dim).float()  # [T, B, A]
        # [T, B, A, N] * [T, B, A, 1] => [T, B, A, N] => [T, B, N]
        target_q_dist = (target_q_dist * a_onehot.unsqueeze(-1)).sum(-2)

        target = n_step_return(
            BATCH.reward.repeat(1, 1, self._atoms), self.gamma,
            BATCH.done.repeat(1, 1, self._atoms), target_q_dist,
            BATCH.begin_mask.repeat(1, 1, self._atoms)).detach()  # [T, B, N]
        target = target.clamp(self._v_min, self._v_max)  # [T, B, N]
        # An amazing trick for calculating the projection gracefully.
        # ref: https://github.com/ShangtongZhang/DeepRL
        target_dist = (
            1 - (target.unsqueeze(-1) - self._z.view(1, 1, -1, 1)).abs() /
            self._delta_z).clamp(0, 1) * target_q_dist.unsqueeze(
                -1)  # [T, B, N, 1]
        target_dist = target_dist.sum(-1)  # [T, B, N]

        _cross_entropy = -(target_dist * th.log(q_dist + th.finfo().eps)).sum(
            -1, keepdim=True)  # [T, B, 1]
        loss = (_cross_entropy * BATCH.get('isw', 1.0)).mean()  # 1

        self.oplr.optimize(loss)
        return _cross_entropy, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            self.q_net.sync()
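
The comment in _train refers to the projection step from the ShangtongZhang/DeepRL reference: each Bellman-updated atom's mass is split linearly between its two nearest support atoms via clamp(1 - |Tz - z| / dz, 0, 1). For reference, a minimal sketch of the standard C51 projection on a [B, N] batch, without the repository's [T, B, ...] layout; names and shapes are illustrative.

import torch

def c51_project(tz, probs, v_min, v_max, atoms):
    """tz: Bellman-updated atoms r + gamma*z, shape [B, N]; probs: target distribution, [B, N]."""
    z = torch.linspace(v_min, v_max, atoms)                      # fixed support, [N]
    delta_z = (v_max - v_min) / (atoms - 1)
    tz = tz.clamp(v_min, v_max)                                  # [B, N]
    # weight of target atom j on support atom i: clamp(1 - |Tz_j - z_i| / dz, 0, 1)
    w = (1 - (tz.unsqueeze(-1) - z.view(1, 1, -1)).abs() / delta_z).clamp(0, 1)  # [B, N, N]
    return (w * probs.unsqueeze(-1)).sum(1)                      # projected distribution, [B, N]

probs = torch.softmax(torch.randn(2, 51), -1)
tz = 1.0 + 0.99 * torch.linspace(-10, 10, 51).repeat(2, 1)      # r + gamma * z with r = 1
out = c51_project(tz, probs, v_min=-10, v_max=10, atoms=51)
print(out.shape, out.sum(-1))                                   # [2, 51], each row sums to ~1
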
Example 4
class DQN(SarlOffPolicy):
    """
    Deep Q-learning Network, DQN, [2013](https://arxiv.org/pdf/1312.5602.pdf), [2015](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
    DQN + LSTM, https://arxiv.org/abs/1507.06527
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 lr: float = 5.0e-4,
                 eps_init: float = 1,
                 eps_mid: float = 0.2,
                 eps_final: float = 0.01,
                 init2mid_annealing_step: int = 1000,
                 assign_interval: int = 1000,
                 network_settings: List[int] = [32, 32],
                 **kwargs):
        super().__init__(**kwargs)
        assert not self.is_continuous, 'dqn only support discrete action space'
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self._max_train_step)
        self.assign_interval = assign_interval
        self.q_net = TargetTwin(
            CriticQvalueAll(self.obs_spec,
                            rep_net_params=self._rep_net_params,
                            output_shape=self.a_dim,
                            network_settings=network_settings)).to(self.device)
        self.oplr = OPLR(self.q_net, lr, **self._oplr_params)
        self._trainer_modules.update(model=self.q_net)
        self._trainer_modules.update(oplr=self.oplr)

    @iton
    def select_action(self, obs):
        q_values = self.q_net(obs, rnncs=self.rnncs)  # [B, *]
        self.rnncs_ = self.q_net.get_rnncs()

        if self._is_train_mode and self.expl_expt_mng.is_random(
                self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            actions = q_values.argmax(-1)  # [B,]
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q_next = self.q_net.t(BATCH.obs_,
                              begin_mask=BATCH.begin_mask)  # [T, B, A]
        q_eval = (q * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        q_target = n_step_return(
            BATCH.reward,
            self.gamma,
            BATCH.done,
            q_next.max(-1, keepdim=True)[0],
            BATCH.begin_mask,
            nstep=self._n_step_value).detach()  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1]
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
        self.oplr.optimize(q_loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': q_loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            self.q_net.sync()
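
n_step_return above generalizes the one-step TD target r + gamma * (1 - done) * max_a Q_target(s', a). A minimal one-step sketch, ignoring the n-step bookkeeping and begin_mask handling of the repository helper; the function name is illustrative.

import torch

def one_step_dqn_target(reward, done, q_next_max, gamma=0.99):
    """reward, done, q_next_max: [T, B, 1]; returns the TD target with the same shape."""
    return reward + gamma * (1.0 - done) * q_next_max

r = torch.ones(3, 2, 1)
d = torch.zeros(3, 2, 1)
qn = torch.full((3, 2, 1), 5.0)
print(one_step_dqn_target(r, d, qn)[0, 0])  # tensor([5.9500]) = 1 + 0.99 * 5
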
Example 5
class IQN(SarlOffPolicy):
    """
    Implicit Quantile Networks, https://arxiv.org/abs/1806.06923
    Double DQN
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 online_quantiles=8,
                 target_quantiles=8,
                 select_quantiles=32,
                 quantiles_idx=64,
                 huber_delta=1.,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=2,
                 network_settings={
                     'q_net': [128, 64],
                     'quantile': [128, 64],
                     'tile': [64]
                 },
                 **kwargs):
        super().__init__(**kwargs)
        assert not self.is_continuous, 'iqn only support discrete action space'
        self.online_quantiles = online_quantiles
        self.target_quantiles = target_quantiles
        self.select_quantiles = select_quantiles
        self.quantiles_idx = quantiles_idx
        self.huber_delta = huber_delta
        self.assign_interval = assign_interval
        self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init,
                                                          eps_mid=eps_mid,
                                                          eps_final=eps_final,
                                                          init2mid_annealing_step=init2mid_annealing_step,
                                                          max_step=self._max_train_step)
        self.q_net = TargetTwin(IqnNet(self.obs_spec,
                                       rep_net_params=self._rep_net_params,
                                       action_dim=self.a_dim,
                                       quantiles_idx=self.quantiles_idx,
                                       network_settings=network_settings)).to(self.device)
        self.oplr = OPLR(self.q_net, lr, **self._oplr_params)
        self._trainer_modules.update(model=self.q_net,
                                     oplr=self.oplr)

    @iton
    def select_action(self, obs):
        _, select_quantiles_tiled = self._generate_quantiles(  # [N*B, X]
            batch_size=self.n_copies,
            quantiles_num=self.select_quantiles
        )
        q_values = self.q_net(obs, select_quantiles_tiled, rnncs=self.rnncs)  # [N, B, A]
        self.rnncs_ = self.q_net.get_rnncs()

        if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            # [N, B, A] => [B, A] => [B,]
            actions = q_values.mean(0).argmax(-1)
        return actions, Data(action=actions)

    def _generate_quantiles(self, batch_size, quantiles_num):
        _quantiles = th.rand([quantiles_num * batch_size, 1])  # [N*B, 1]
        _quantiles_tiled = _quantiles.repeat(1, self.quantiles_idx)  # [N*B, 1] => [N*B, X]

        # pi * i * tau [N*B, X] * [X, ] => [N*B, X]
        _quantiles_tiled = th.arange(self.quantiles_idx) * np.pi * _quantiles_tiled
        _quantiles_tiled.cos_()  # [N*B, X]

        _quantiles = _quantiles.view(batch_size, quantiles_num, 1)  # [N*B, 1] => [B, N, 1]
        return _quantiles, _quantiles_tiled  # [B, N, 1], [N*B, X]

    @iton
    def _train(self, BATCH):
        time_step = BATCH.reward.shape[0]
        batch_size = BATCH.reward.shape[1]

        quantiles, quantiles_tiled = self._generate_quantiles(  # [T*B, N, 1], [N*T*B, X]
            batch_size=time_step * batch_size,
            quantiles_num=self.online_quantiles)
        # [T*B, N, 1] => [T, B, N, 1]
        quantiles = quantiles.view(time_step, batch_size, -1, 1)
        quantiles_tiled = quantiles_tiled.view(time_step, -1, self.quantiles_idx)  # [N*T*B, X] => [T, N*B, X]

        quantiles_value = self.q_net(BATCH.obs, quantiles_tiled, begin_mask=BATCH.begin_mask)  # [T, N, B, A]
        # [T, N, B, A] => [N, T, B, A] * [T, B, A] => [N, T, B, 1]
        quantiles_value = (quantiles_value.swapaxes(0, 1) * BATCH.action).sum(-1, keepdim=True)
        q_eval = quantiles_value.mean(0)  # [N, T, B, 1] => [T, B, 1]

        _, select_quantiles_tiled = self._generate_quantiles(  # [N*T*B, X]
            batch_size=time_step * batch_size,
            quantiles_num=self.select_quantiles)
        select_quantiles_tiled = select_quantiles_tiled.view(
            time_step, -1, self.quantiles_idx)  # [N*T*B, X] => [T, N*B, X]

        q_values = self.q_net(
            BATCH.obs_, select_quantiles_tiled, begin_mask=BATCH.begin_mask)  # [T, N, B, A]
        q_values = q_values.mean(1)  # [T, N, B, A] => [T, B, A]
        next_max_action = q_values.argmax(-1)  # [T, B]
        next_max_action = F.one_hot(
            next_max_action, self.a_dim).float()  # [T, B, A]

        _, target_quantiles_tiled = self._generate_quantiles(  # [N'*T*B, X]
            batch_size=time_step * batch_size,
            quantiles_num=self.target_quantiles)
        target_quantiles_tiled = target_quantiles_tiled.view(
            time_step, -1, self.quantiles_idx)  # [N'*T*B, X] => [T, N'*B, X]
        target_quantiles_value = self.q_net.t(BATCH.obs_, target_quantiles_tiled,
                                              begin_mask=BATCH.begin_mask)  # [T, N', B, A]
        target_quantiles_value = target_quantiles_value.swapaxes(0, 1)  # [T, N', B, A] => [N', T, B, A]
        target_quantiles_value = (target_quantiles_value * next_max_action).sum(-1, keepdim=True)  # [N', T, B, 1]

        target_q = target_quantiles_value.mean(0)  # [T, B, 1]
        q_target = n_step_return(BATCH.reward,  # [T, B, 1]
                                 self.gamma,
                                 BATCH.done,  # [T, B, 1]
                                 target_q,  # [T, B, 1]
                                 BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1]

        # [N', T, B, 1] => [N', T, B]
        target_quantiles_value = target_quantiles_value.squeeze(-1)
        target_quantiles_value = target_quantiles_value.permute(
            1, 2, 0)  # [N', T, B] => [T, B, N']
        quantiles_value_target = n_step_return(BATCH.reward.repeat(1, 1, self.target_quantiles),
                                               self.gamma,
                                               BATCH.done.repeat(1, 1, self.target_quantiles),
                                               target_quantiles_value,
                                               BATCH.begin_mask.repeat(1, 1,
                                                                       self.target_quantiles)).detach()  # [T, B, N']
        # [T, B, N'] => [T, B, 1, N']
        quantiles_value_target = quantiles_value_target.unsqueeze(-2)
        quantiles_value_online = quantiles_value.permute(1, 2, 0, 3)  # [N, T, B, 1] => [T, B, N, 1]
        # [T, B, N, 1] - [T, B, 1, N'] => [T, B, N, N']
        quantile_error = quantiles_value_online - quantiles_value_target
        huber = F.huber_loss(quantiles_value_online, quantiles_value_target,
                             reduction="none", delta=self.huber_delta)  # [T, B, N, N]
        # [T, B, N, 1] - [T, B, N, N'] => [T, B, N, N']
        huber_abs = (quantiles - quantile_error.detach().le(0.).float()).abs()
        loss = (huber_abs * huber).mean(-1)  # [T, B, N, N'] => [T, B, N]
        loss = loss.sum(-1, keepdim=True)  # [T, B, N] => [T, B, 1]

        loss = (loss * BATCH.get('isw', 1.0)).mean()  # 1
        self.oplr.optimize(loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            self.q_net.sync()
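
_generate_quantiles above builds the cosine embedding phi_j(tau) = cos(pi * j * tau), j = 0..X-1, for each sampled tau. A standalone sketch of the same construction; names are illustrative.

import numpy as np
import torch

def cosine_quantile_embedding(batch_size, quantiles_num, quantiles_idx):
    taus = torch.rand(quantiles_num * batch_size, 1)           # tau ~ U(0, 1), [N*B, 1]
    j = torch.arange(quantiles_idx).float()                    # embedding indices, [X,]
    embedding = torch.cos(np.pi * j * taus)                    # cos(pi * j * tau), [N*B, X]
    return taus.view(batch_size, quantiles_num, 1), embedding  # [B, N, 1], [N*B, X]

taus, emb = cosine_quantile_embedding(batch_size=4, quantiles_num=8, quantiles_idx=64)
print(taus.shape, emb.shape)  # torch.Size([4, 8, 1]) torch.Size([32, 64])
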
Example 6
class QRDQN(SarlOffPolicy):
    """
    Quantile Regression DQN
    Distributional Reinforcement Learning with Quantile Regression, https://arxiv.org/abs/1710.10044
    No double, no dueling, no noisy net.
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 nums=20,
                 huber_delta=1.,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=1000,
                 network_settings=[128, 128],
                 **kwargs):
        assert nums > 0, 'assert nums > 0'
        super().__init__(**kwargs)
        assert not self.is_continuous, 'qrdqn only support discrete action space'
        self.nums = nums
        self.huber_delta = huber_delta
        self.quantiles = th.tensor((2 * np.arange(self.nums) + 1) / (2.0 * self.nums)).float().to(self.device)  # [N,]
        self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init,
                                                          eps_mid=eps_mid,
                                                          eps_final=eps_final,
                                                          init2mid_annealing_step=init2mid_annealing_step,
                                                          max_step=self._max_train_step)
        self.assign_interval = assign_interval
        self.q_net = TargetTwin(QrdqnDistributional(self.obs_spec,
                                                    rep_net_params=self._rep_net_params,
                                                    action_dim=self.a_dim,
                                                    nums=self.nums,
                                                    network_settings=network_settings)).to(self.device)
        self.oplr = OPLR(self.q_net, lr, **self._oplr_params)
        self._trainer_modules.update(model=self.q_net,
                                     oplr=self.oplr)

    @iton
    def select_action(self, obs):
        q_values = self.q_net(obs, rnncs=self.rnncs)  # [B, A, N]
        self.rnncs_ = self.q_net.get_rnncs()

        if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            q = q_values.mean(-1)  # [B, A, N] => [B, A]
            actions = q.argmax(-1)  # [B,]
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        q_dist = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A, N]
        q_dist = (q_dist * BATCH.action.unsqueeze(-1)).sum(-2)  # [T, B, A, N] => [T, B, N]

        target_q_dist = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A, N]
        target_q = target_q_dist.mean(-1)  # [T, B, A, N] => [T, B, A]
        _a = target_q.argmax(-1)  # [T, B]
        next_max_action = F.one_hot(_a, self.a_dim).float().unsqueeze(-1)  # [T, B, A, 1]
        # [T, B, A, N] => [T, B, N]
        target_q_dist = (target_q_dist * next_max_action).sum(-2)

        target = n_step_return(BATCH.reward.repeat(1, 1, self.nums),
                               self.gamma,
                               BATCH.done.repeat(1, 1, self.nums),
                               target_q_dist,
                               BATCH.begin_mask.repeat(1, 1, self.nums)).detach()  # [T, B, N]

        q_eval = q_dist.mean(-1, keepdim=True)  # [T, B, 1]
        q_target = target.mean(-1, keepdim=True)  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1], used for PER

        target = target.unsqueeze(-2)  # [T, B, 1, N]
        q_dist = q_dist.unsqueeze(-1)  # [T, B, N, 1]

        # [T, B, 1, N] - [T, B, N, 1] => [T, B, N, N]
        quantile_error = target - q_dist
        huber = F.huber_loss(target, q_dist, reduction="none", delta=self.huber_delta)  # [T, B, N, N]
        # [N,] - [T, B, N, N] => [T, B, N, N]
        huber_abs = (self.quantiles - quantile_error.detach().le(0.).float()).abs()
        loss = (huber_abs * huber).mean(-1)  # [T, B, N, N] => [T, B, N]
        loss = loss.sum(-1, keepdim=True)  # [T, B, N] => [T, B, 1]
        loss = (loss * BATCH.get('isw', 1.0)).mean()  # 1

        self.oplr.optimize(loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            self.q_net.sync()
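
The loss above is the quantile Huber (pinball) loss rho^kappa_tau(u) = |tau - 1{u < 0}| * L_kappa(u) with u = target - prediction. A minimal sketch on plain tensors, mirroring the broadcasting used in _train; names and shapes are illustrative.

import torch
import torch.nn.functional as F

def quantile_huber_loss(pred, target, taus, delta=1.0):
    """pred: [B, N, 1] predicted quantiles; target: [B, 1, N'] target quantiles; taus: [N,]."""
    u = target - pred                                                 # [B, N, N']
    huber = F.huber_loss(pred.expand_as(u), target.expand_as(u),
                         reduction='none', delta=delta)               # [B, N, N']
    weight = (taus.view(1, -1, 1) - u.detach().le(0.).float()).abs()  # |tau - 1{u < 0}|
    return (weight * huber).mean(-1).sum(-1).mean()                   # scalar

taus = (2 * torch.arange(20) + 1).float() / 40.0   # quantile midpoints, as in __init__ above
pred = torch.randn(8, 20, 1)
target = torch.randn(8, 1, 20)
print(quantile_huber_loss(pred, target, taus))
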
Example 7
class DreamerV1(SarlOffPolicy):
    """
    Dream to Control: Learning Behaviors by Latent Imagination, http://arxiv.org/abs/1912.01603
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 eps_init: float = 1,
                 eps_mid: float = 0.2,
                 eps_final: float = 0.01,
                 init2mid_annealing_step: int = 1000,
                 stoch_dim=30,
                 deter_dim=200,
                 model_lr=6e-4,
                 actor_lr=8e-5,
                 critic_lr=8e-5,
                 kl_free_nats=3,
                 action_sigma=0.3,
                 imagination_horizon=15,
                 lambda_=0.95,
                 kl_scale=1.0,
                 reward_scale=1.0,
                 use_pcont=False,
                 pcont_scale=10.0,
                 network_settings=dict(),
                 **kwargs):
        super().__init__(**kwargs)

        assert self.use_rnn == False, 'assert self.use_rnn == False'

        if self.obs_spec.has_visual_observation \
                and len(self.obs_spec.visual_dims) == 1 \
                and not self.obs_spec.has_vector_observation:
            visual_dim = self.obs_spec.visual_dims[0]
            # TODO: optimize this
            assert visual_dim[0] == visual_dim[
                1] == 64, 'visual dimension must be [64, 64, *]'
            self._is_visual = True
        elif self.obs_spec.has_vector_observation \
                and len(self.obs_spec.vector_dims) == 1 \
                and not self.obs_spec.has_visual_observation:
            self._is_visual = False
        else:
            raise ValueError("please check the observation type")

        self.stoch_dim = stoch_dim
        self.deter_dim = deter_dim
        self.kl_free_nats = kl_free_nats
        self.imagination_horizon = imagination_horizon
        self.lambda_ = lambda_
        self.kl_scale = kl_scale
        self.reward_scale = reward_scale
        # https://github.com/danijar/dreamer/issues/2
        self.use_pcont = use_pcont  # probability of continuing
        self.pcont_scale = pcont_scale
        self._action_sigma = action_sigma
        self._network_settings = network_settings

        if not self.is_continuous:
            self.expl_expt_mng = ExplorationExploitationClass(
                eps_init=eps_init,
                eps_mid=eps_mid,
                eps_final=eps_final,
                init2mid_annealing_step=init2mid_annealing_step,
                max_step=self._max_train_step)

        if self.obs_spec.has_visual_observation:
            from rls.nn.dreamer import VisualDecoder, VisualEncoder
            self.obs_encoder = VisualEncoder(
                self.obs_spec.visual_dims[0],
                **network_settings['obs_encoder']['visual']).to(self.device)
            self.obs_decoder = VisualDecoder(
                self.decoder_input_dim, self.obs_spec.visual_dims[0],
                **network_settings['obs_decoder']['visual']).to(self.device)
        else:
            from rls.nn.dreamer import VectorEncoder
            self.obs_encoder = VectorEncoder(
                self.obs_spec.vector_dims[0],
                **network_settings['obs_encoder']['vector']).to(self.device)
            self.obs_decoder = DenseModel(
                self.decoder_input_dim, self.obs_spec.vector_dims[0],
                **network_settings['obs_decoder']['vector']).to(self.device)

        self.rssm = self._dreamer_build_rssm()
        """
        p(r_t | s_t, h_t)
        Reward model to predict reward from state and rnn hidden state
        """
        self.reward_predictor = DenseModel(self.decoder_input_dim, 1,
                                           **network_settings['reward']).to(
                                               self.device)

        self.actor = ActionDecoder(self.a_dim,
                                   self.decoder_input_dim,
                                   dist=self._action_dist,
                                   **network_settings['actor']).to(self.device)
        self.critic = self._dreamer_build_critic()

        _modules = [
            self.obs_encoder, self.rssm, self.obs_decoder,
            self.reward_predictor
        ]
        if self.use_pcont:
            self.pcont_decoder = DenseModel(self.decoder_input_dim, 1,
                                            **network_settings['pcont']).to(
                                                self.device)
            _modules.append(self.pcont_decoder)

        self.model_oplr = OPLR(_modules, model_lr, **self._oplr_params)
        self.actor_oplr = OPLR(self.actor, actor_lr, **self._oplr_params)
        self.critic_oplr = OPLR(self.critic, critic_lr, **self._oplr_params)
        self._trainer_modules.update(obs_encoder=self.obs_encoder,
                                     obs_decoder=self.obs_decoder,
                                     reward_predictor=self.reward_predictor,
                                     rssm=self.rssm,
                                     actor=self.actor,
                                     critic=self.critic,
                                     model_oplr=self.model_oplr,
                                     actor_oplr=self.actor_oplr,
                                     critic_oplr=self.critic_oplr)
        if self.use_pcont:
            self._trainer_modules.update(pcont_decoder=self.pcont_decoder)

    @property
    def _action_dist(self):
        return 'tanh_normal' if self.is_continuous else 'one_hot'  # 'relaxed_one_hot'

    @property
    def decoder_input_dim(self):
        return self.stoch_dim + self.deter_dim

    def _dreamer_build_rssm(self):
        return RecurrentStateSpaceModel(
            self.stoch_dim, self.deter_dim, self.a_dim, self.obs_encoder.h_dim,
            **self._network_settings['rssm']).to(self.device)

    def _dreamer_build_critic(self):
        return DenseModel(self.decoder_input_dim, 1,
                          **self._network_settings['critic']).to(self.device)

    @iton
    def select_action(self, obs):
        if self._is_visual:
            obs = get_first_visual(obs)
        else:
            obs = get_first_vector(obs)
        embedded_obs = self.obs_encoder(obs)  # [B, *]
        state_posterior = self.rssm.posterior(self.rnncs['hx'], embedded_obs)
        state = state_posterior.sample()  # [B, *]
        actions = self.actor.sample_actions(th.cat((state, self.rnncs['hx']),
                                                   -1),
                                            is_train=self._is_train_mode)
        actions = self._exploration(actions)
        _, self.rnncs_['hx'] = self.rssm.prior(state, actions,
                                               self.rnncs['hx'])
        if not self.is_continuous:
            actions = actions.argmax(-1)  # [B,]
        return actions, Data(action=actions)

    def _exploration(self, action: th.Tensor) -> th.Tensor:
        """
        :param action: action to take, shape (1,) (if categorical), or (action dim,) (if continuous)
        :return: action of the same shape passed in, augmented with some noise
        """
        if self.is_continuous:
            sigma = self._action_sigma if self._is_train_mode else 0.
            noise = th.randn(*action.shape) * sigma
            return th.clamp(action + noise, -1, 1)
        else:
            if self._is_train_mode and self.expl_expt_mng.is_random(
                    self._cur_train_step):
                index = th.randint(0, self.a_dim, (self.n_copies, ))
                action = th.zeros_like(action)
                action[..., index] = 1
            return action

    @iton
    def _train(self, BATCH):
        T, B = BATCH.action.shape[:2]
        if self._is_visual:
            obs_ = get_first_visual(BATCH.obs_)
        else:
            obs_ = get_first_vector(BATCH.obs_)

        # embed observations with CNN
        embedded_observations = self.obs_encoder(obs_)  # [T, B, *]

        # initialize state and rnn hidden state with 0 vector
        state, rnn_hidden = self.rssm.init_state(shape=B)  # [B, S], [B, D]

        # compute state and rnn hidden sequences and kl loss
        kl_loss = 0
        states, rnn_hiddens = [], []
        for l in range(T):
            # At the beginning of an episode, reset state and hidden to 0,
            # no matter whether the previous episode was truncated or not.
            state = state * (1. - BATCH.begin_mask[l])  # [B, S]
            rnn_hidden = rnn_hidden * (1. - BATCH.begin_mask[l])  # [B, D]

            next_state_prior, next_state_posterior, rnn_hidden = self.rssm(
                state, BATCH.action[l], rnn_hidden,
                embedded_observations[l])  # a, s_
            state = next_state_posterior.rsample()  # [B, S] posterior of s_
            states.append(state)  # [B, S]
            rnn_hiddens.append(rnn_hidden)  # [B, D]
            kl_loss += self._kl_loss(next_state_prior, next_state_posterior)
        kl_loss /= T  # 1

        # compute reconstructed observations and predicted rewards
        post_feat = th.cat([th.stack(states, 0),
                            th.stack(rnn_hiddens, 0)], -1)  # [T, B, *]

        obs_pred = self.obs_decoder(post_feat)  # [T, B, C, H, W] or [T, B, *]
        reward_pred = self.reward_predictor(post_feat)  # [T, B, 1], s_ => r

        # compute loss for observation and reward
        obs_loss = -th.mean(obs_pred.log_prob(obs_))  # [T, B] => 1
        # [T, B, 1]=>1
        reward_loss = -th.mean(
            reward_pred.log_prob(BATCH.reward).unsqueeze(-1))

        # add all losses and update model parameters with gradient descent
        model_loss = self.kl_scale * kl_loss + obs_loss + self.reward_scale * reward_loss  # 1

        if self.use_pcont:
            pcont_pred = self.pcont_decoder(post_feat)  # [T, B, 1], s_ => done
            # https://github.com/danijar/dreamer/issues/2#issuecomment-605392659
            pcont_target = self.gamma * (1. - BATCH.done)
            # [T, B, 1]=>1
            pcont_loss = -th.mean(
                pcont_pred.log_prob(pcont_target).unsqueeze(-1))
            model_loss += self.pcont_scale * pcont_loss

        self.model_oplr.optimize(model_loss)

        # remove gradients from previously calculated tensors
        with th.no_grad():
            # [T, B, S] => [T*B, S]
            flatten_states = th.cat(states, 0).detach()
            # [T, B, D] => [T*B, D]
            flatten_rnn_hiddens = th.cat(rnn_hiddens, 0).detach()

        with FreezeParameters(self.model_oplr.parameters):
            # compute target values
            imaginated_states = []
            imaginated_rnn_hiddens = []
            log_probs = []
            entropies = []

            for h in range(self.imagination_horizon):
                imaginated_states.append(flatten_states)  # [T*B, S]
                imaginated_rnn_hiddens.append(flatten_rnn_hiddens)  # [T*B, D]

                flatten_feat = th.cat([flatten_states, flatten_rnn_hiddens],
                                      -1).detach()
                action_dist = self.actor(flatten_feat)
                actions = action_dist.rsample()  # [T*B, A]
                log_probs.append(
                    action_dist.log_prob(
                        actions.detach()).unsqueeze(-1))  # [T*B, 1]
                entropies.append(
                    action_dist.entropy().unsqueeze(-1))  # [T*B, 1]
                flatten_states_prior, flatten_rnn_hiddens = self.rssm.prior(
                    flatten_states, actions, flatten_rnn_hiddens)
                flatten_states = flatten_states_prior.rsample()  # [T*B, S]

            imaginated_states = th.stack(imaginated_states, 0)  # [H, T*B, S]
            imaginated_rnn_hiddens = th.stack(imaginated_rnn_hiddens,
                                              0)  # [H, T*B, D]
            log_probs = th.stack(log_probs, 0)  # [H, T*B, 1]
            entropies = th.stack(entropies, 0)  # [H, T*B, 1]

        imaginated_feats = th.cat([imaginated_states, imaginated_rnn_hiddens],
                                  -1)  # [H, T*B, *]

        with FreezeParameters(self.model_oplr.parameters +
                              self.critic_oplr.parameters):
            imaginated_rewards = self.reward_predictor(
                imaginated_feats).mean  # [H, T*B, 1]
            imaginated_values = self._dreamer_target_img_value(
                imaginated_feats)  # [H, T*B, 1]

        # Compute the exponential discounted sum of rewards
        if self.use_pcont:
            with FreezeParameters(self.pcont_decoder.parameters()):
                discount_arr = self.pcont_decoder(
                    imaginated_feats).mean  # [H, T*B, 1]
        else:
            discount_arr = self.gamma * th.ones_like(
                imaginated_rewards)  # [H, T*B, 1]

        returns = compute_return(imaginated_rewards[:-1],
                                 imaginated_values[:-1],
                                 discount_arr[:-1],
                                 bootstrap=imaginated_values[-1],
                                 lambda_=self.lambda_)  # [H-1, T*B, 1]
        # Make the top row 1 so the cumulative product starts with discount^0
        discount_arr = th.cat(
            [th.ones_like(discount_arr[:1]), discount_arr[:-1]],
            0)  # [H, T*B, 1]
        discount = th.cumprod(discount_arr, 0).detach()[:-1]  # [H-1, T*B, 1]

        # discount_arr = th.cat([th.ones_like(discount_arr[:1]), discount_arr[1:]])
        # discount = th.cumprod(discount_arr[:-1], 0)

        actor_loss = self._dreamer_build_actor_loss(imaginated_feats,
                                                    log_probs, entropies,
                                                    discount, returns)  # 1

        # Don't let gradients pass through to prevent overwriting gradients.
        # Value Loss
        with th.no_grad():
            value_feat = imaginated_feats[:-1].detach()  # [H-1, T*B, *]
            value_target = returns.detach()  # [H-1, T*B, 1]

        value_pred = self.critic(value_feat)  # [H-1, T*B, 1]
        log_prob = value_pred.log_prob(value_target).unsqueeze(
            -1)  # [H-1, T*B, 1]
        critic_loss = -th.mean(discount * log_prob)  # 1

        self.actor_oplr.zero_grad()
        self.critic_oplr.zero_grad()

        self.actor_oplr.backward(actor_loss)
        self.critic_oplr.backward(critic_loss)

        self.actor_oplr.step()
        self.critic_oplr.step()

        td_error = (value_pred.mean - value_target).mean(0).detach()  # [T*B,]
        td_error = td_error.view(T, B, 1)

        summaries = {
            'LEARNING_RATE/model_lr': self.model_oplr.lr,
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LOSS/model_loss': model_loss,
            'LOSS/kl_loss': kl_loss,
            'LOSS/obs_loss': obs_loss,
            'LOSS/reward_loss': reward_loss,
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss
        }
        if self.use_pcont:
            summaries.update({'LOSS/pcont_loss': pcont_loss})

        return td_error, summaries

    def _initial_rnncs(self, batch: int) -> Dict[str, np.ndarray]:
        return {'hx': np.zeros((batch, self.deter_dim))}

    def _kl_loss(self, prior_dist, post_dist):
        # 1
        return td.kl_divergence(prior_dist,
                                post_dist).clamp(min=self.kl_free_nats).mean()

    def _dreamer_target_img_value(self, imaginated_feats):
        imaginated_values = self.critic(imaginated_feats).mean  # [H, T*B, 1]
        return imaginated_values

    def _dreamer_build_actor_loss(self, imaginated_feats, log_probs, entropies,
                                  discount, returns):
        actor_loss = -th.mean(discount * returns)  # [H-1, T*B, 1] => 1
        return actor_loss
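
compute_return above is the lambda-return from the Dreamer paper, V^lambda_t = r_t + gamma_t * ((1 - lambda) * v_{t+1} + lambda * V^lambda_{t+1}), computed backwards from a bootstrap value. A minimal sketch of that recursion; argument names and shapes are assumptions, and the repository helper may differ in details such as masking.

import torch

def lambda_return(rewards, values, discounts, bootstrap, lambda_=0.95):
    """rewards, values, discounts: [H, B, 1]; bootstrap: [B, 1]; returns [H, B, 1]."""
    next_values = torch.cat([values[1:], bootstrap.unsqueeze(0)], 0)   # v_{t+1}, [H, B, 1]
    inputs = rewards + discounts * next_values * (1 - lambda_)         # r_t + gamma_t * (1 - lambda) * v_{t+1}
    last = bootstrap
    outputs = []
    for t in reversed(range(rewards.shape[0])):
        last = inputs[t] + discounts[t] * lambda_ * last               # backward recursion
        outputs.append(last)
    return torch.stack(outputs[::-1], 0)

H, B = 15, 6
r, v = torch.rand(H, B, 1), torch.rand(H, B, 1)
g = 0.99 * torch.ones(H, B, 1)
print(lambda_return(r, v, g, bootstrap=v[-1]).shape)  # torch.Size([15, 6, 1])
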
Example 8
class BCQ(SarlOffPolicy):
    """
    Benchmarking Batch Deep Reinforcement Learning Algorithms, http://arxiv.org/abs/1910.01708
    Off-Policy Deep Reinforcement Learning without Exploration, http://arxiv.org/abs/1812.02900
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 polyak=0.995,
                 discrete=dict(threshold=0.3,
                               lr=5.0e-4,
                               eps_init=1,
                               eps_mid=0.2,
                               eps_final=0.01,
                               init2mid_annealing_step=1000,
                               assign_interval=1000,
                               network_settings=[32, 32]),
                 continuous=dict(phi=0.05,
                                 lmbda=0.75,
                                 select_samples=100,
                                 train_samples=10,
                                 actor_lr=1e-3,
                                 critic_lr=1e-3,
                                 vae_lr=1e-3,
                                 network_settings=dict(
                                     actor=[32, 32],
                                     critic=[32, 32],
                                     vae=dict(encoder=[750, 750],
                                              decoder=[750, 750]))),
                 **kwargs):
        super().__init__(**kwargs)
        self._polyak = polyak

        if self.is_continuous:
            self._lmbda = continuous['lmbda']
            self._select_samples = continuous['select_samples']
            self._train_samples = continuous['train_samples']
            self.actor = TargetTwin(BCQ_Act_Cts(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                action_dim=self.a_dim,
                phi=continuous['phi'],
                network_settings=continuous['network_settings']['actor']),
                                    polyak=self._polyak).to(self.device)
            self.critic = TargetTwin(BCQ_CriticQvalueOne(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                action_dim=self.a_dim,
                network_settings=continuous['network_settings']['critic']),
                                     polyak=self._polyak).to(self.device)
            self.vae = VAE(self.obs_spec,
                           rep_net_params=self._rep_net_params,
                           a_dim=self.a_dim,
                           z_dim=self.a_dim * 2,
                           hiddens=continuous['network_settings']['vae']).to(
                               self.device)

            self.actor_oplr = OPLR(self.actor, continuous['actor_lr'],
                                   **self._oplr_params)
            self.critic_oplr = OPLR(self.critic, continuous['critic_lr'],
                                    **self._oplr_params)
            self.vae_oplr = OPLR(self.vae, continuous['vae_lr'],
                                 **self._oplr_params)
            self._trainer_modules.update(actor=self.actor,
                                         critic=self.critic,
                                         vae=self.vae,
                                         actor_oplr=self.actor_oplr,
                                         critic_oplr=self.critic_oplr,
                                         vae_oplr=self.vae_oplr)
        else:
            self.expl_expt_mng = ExplorationExploitationClass(
                eps_init=discrete['eps_init'],
                eps_mid=discrete['eps_mid'],
                eps_final=discrete['eps_final'],
                init2mid_annealing_step=discrete['init2mid_annealing_step'],
                max_step=self._max_train_step)
            self.assign_interval = discrete['assign_interval']
            self._threshold = discrete['threshold']
            self.q_net = TargetTwin(BCQ_DCT(
                self.obs_spec,
                rep_net_params=self._rep_net_params,
                output_shape=self.a_dim,
                network_settings=discrete['network_settings']),
                                    polyak=self._polyak).to(self.device)
            self.oplr = OPLR(self.q_net, discrete['lr'], **self._oplr_params)
            self._trainer_modules.update(model=self.q_net, oplr=self.oplr)

    @iton
    def select_action(self, obs):
        if self.is_continuous:
            _actions = []
            for _ in range(self._select_samples):
                _actions.append(
                    self.actor(obs, self.vae.decode(obs),
                               rnncs=self.rnncs))  # [B, A]
            self.rnncs_ = self.actor.get_rnncs()  # TODO: calculate corrected hidden state
            _actions = th.stack(_actions, dim=0)  # [N, B, A]
            q1s = []
            for i in range(self._select_samples):
                q1s.append(self.critic(obs, _actions[i])[0])
            q1s = th.stack(q1s, dim=0)  # [N, B, 1]
            max_idxs = q1s.argmax(dim=0, keepdim=True)[-1]  # [B, 1]
            actions = _actions[
                max_idxs,
                th.arange(self.n_copies).reshape(self.n_copies, 1),
                th.arange(self.a_dim)]
        else:
            q_values, i_values = self.q_net(obs, rnncs=self.rnncs)  # [B, *]
            q_values = q_values - q_values.min(dim=-1,
                                               keepdim=True)[0]  # [B, *]
            i_values = F.log_softmax(i_values, dim=-1)  # [B, *]
            i_values = i_values.exp()  # [B, *]
            i_values = (i_values / i_values.max(-1, keepdim=True)[0] >
                        self._threshold).float()  # [B, *]

            self.rnncs_ = self.q_net.get_rnncs()

            if self._is_train_mode and self.expl_expt_mng.is_random(
                    self._cur_train_step):
                actions = np.random.randint(0, self.a_dim, self.n_copies)
            else:
                actions = (i_values * q_values).argmax(-1)  # [B,]
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        if self.is_continuous:
            # Variational Auto-Encoder Training
            recon, mean, std = self.vae(BATCH.obs,
                                        BATCH.action,
                                        begin_mask=BATCH.begin_mask)
            recon_loss = F.mse_loss(recon, BATCH.action)

            KL_loss = -0.5 * (1 + th.log(std.pow(2)) - mean.pow(2) -
                              std.pow(2)).mean()
            vae_loss = recon_loss + 0.5 * KL_loss

            self.vae_oplr.optimize(vae_loss)

            target_Qs = []
            for _ in range(self._train_samples):
                # Compute value of perturbed actions sampled from the VAE
                _vae_actions = self.vae.decode(BATCH.obs_,
                                               begin_mask=BATCH.begin_mask)
                _actor_actions = self.actor.t(BATCH.obs_,
                                              _vae_actions,
                                              begin_mask=BATCH.begin_mask)
                target_Q1, target_Q2 = self.critic.t(
                    BATCH.obs_, _actor_actions, begin_mask=BATCH.begin_mask)

                # Soft Clipped Double Q-learning
                target_Q = self._lmbda * th.min(target_Q1, target_Q2) + \
                           (1. - self._lmbda) * th.max(target_Q1, target_Q2)
                target_Qs.append(target_Q)
            target_Qs = th.stack(target_Qs, dim=0)  # [N, T, B, 1]
            # Take max over each BATCH.action sampled from the VAE
            target_Q = target_Qs.max(dim=0)[0]  # [T, B, 1]

            target_Q = n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                     target_Q,
                                     BATCH.begin_mask).detach()  # [T, B, 1]

            current_Q1, current_Q2 = self.critic(BATCH.obs,
                                                 BATCH.action,
                                                 begin_mask=BATCH.begin_mask)
            td_error = ((current_Q1 - target_Q) + (current_Q2 - target_Q)) / 2
            critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
                current_Q2, target_Q)

            self.critic_oplr.optimize(critic_loss)

            # Pertubation Model / Action Training
            sampled_actions = self.vae.decode(BATCH.obs,
                                              begin_mask=BATCH.begin_mask)
            perturbed_actions = self.actor(BATCH.obs,
                                           sampled_actions,
                                           begin_mask=BATCH.begin_mask)

            # Update through DPG
            q1, _ = self.critic(BATCH.obs,
                                perturbed_actions,
                                begin_mask=BATCH.begin_mask)
            actor_loss = -q1.mean()

            self.actor_oplr.optimize(actor_loss)

            return td_error, {
                'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
                'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
                'LEARNING_RATE/vae_lr': self.vae_oplr.lr,
                'LOSS/actor_loss': actor_loss,
                'LOSS/critic_loss': critic_loss,
                'LOSS/vae_loss': vae_loss,
                'Statistics/q_min': q1.min(),
                'Statistics/q_mean': q1.mean(),
                'Statistics/q_max': q1.max()
            }

        else:
            q_next, i_next = self.q_net(
                BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            q_next = q_next - q_next.min(dim=-1, keepdim=True)[0]  # [T, B, A]
            i_next = F.log_softmax(i_next, dim=-1)  # [T, B, A]
            i_next = i_next.exp()  # [T, B, A]
            i_next = (i_next / i_next.max(-1, keepdim=True)[0] >
                      self._threshold).float()  # [T, B, A]
            q_next = i_next * q_next  # [T, B, A]
            next_max_action = q_next.argmax(-1)  # [T, B]
            next_max_action_one_hot = F.one_hot(
                next_max_action.squeeze(), self.a_dim).float()  # [T, B, A]

            q_target_next, _ = self.q_net.t(
                BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            q_target_next_max = (q_target_next * next_max_action_one_hot).sum(
                -1, keepdim=True)  # [T, B, 1]
            q_target = n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                     q_target_next_max,
                                     BATCH.begin_mask).detach()  # [T, B, 1]

            q, i = self.q_net(BATCH.obs,
                              begin_mask=BATCH.begin_mask)  # [T, B, A]
            q_eval = (q * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]

            td_error = q_target - q_eval  # [T, B, 1]
            q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1

            imt = F.log_softmax(i, dim=-1)  # [T, B, A]
            imt = imt.reshape(-1, self.a_dim)  # [T*B, A]
            action = BATCH.action.reshape(-1, self.a_dim)  # [T*B, A]
            i_loss = F.nll_loss(imt, action.argmax(-1))  # 1

            loss = q_loss + i_loss + 1e-2 * i.pow(2).mean()

            self.oplr.optimize(loss)
            return td_error, {
                'LEARNING_RATE/lr': self.oplr.lr,
                'LOSS/q_loss': q_loss,
                'LOSS/i_loss': i_loss,
                'LOSS/loss': loss,
                'Statistics/q_max': q_eval.max(),
                'Statistics/q_min': q_eval.min(),
                'Statistics/q_mean': q_eval.mean()
            }

    def _after_train(self):
        super()._after_train()
        if self.is_continuous:
            self.actor.sync()
            self.critic.sync()
        else:
            if self._polyak != 0:
                self.q_net.sync()
            else:
                if self._cur_train_step % self.assign_interval == 0:
                    self.q_net.sync()
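
The heart of the continuous branch above is the soft-clipped double-Q target. The stand-alone sketch below reproduces just that computation with placeholder linear critics and random tensors (none of it is part of the library): twin critic estimates are blended with weight lmbda, maximised over candidate actions sampled per state, and fed into an ordinary one-step backup standing in for the library's n_step_return.

import torch as th
import torch.nn as nn

B, A, train_samples, lmbda, gamma = 8, 2, 10, 0.75, 0.99

critic1 = nn.Linear(A, 1)  # stand-in for the first target critic head
critic2 = nn.Linear(A, 1)  # stand-in for the second target critic head

target_qs = []
for _ in range(train_samples):
    # candidate next actions; in BCQ these come from the VAE plus the perturbation actor
    candidate_actions = th.rand(B, A)
    q1 = critic1(candidate_actions)  # [B, 1]
    q2 = critic2(candidate_actions)  # [B, 1]
    # soft clipped double Q-learning: weight the pessimistic estimate by lmbda
    target_qs.append(lmbda * th.min(q1, q2) + (1. - lmbda) * th.max(q1, q2))

# max over the sampled candidate actions, then a plain one-step TD backup
target_q = th.stack(target_qs, dim=0).max(dim=0)[0]  # [B, 1]
reward, done = th.rand(B, 1), th.zeros(B, 1)
td_target = reward + gamma * (1. - done) * target_q  # [B, 1]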
Example 9
class AveragedDQN(SarlOffPolicy):
    """
    Averaged-DQN, http://arxiv.org/abs/1611.01929
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 target_k: int = 4,
                 lr: float = 5.0e-4,
                 eps_init: float = 1,
                 eps_mid: float = 0.2,
                 eps_final: float = 0.01,
                 init2mid_annealing_step: int = 1000,
                 assign_interval: int = 1000,
                 network_settings: List[int] = [32, 32],
                 **kwargs):
        super().__init__(**kwargs)
        assert not self.is_continuous, 'Averaged-DQN only supports discrete action spaces'
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self._max_train_step)
        self.assign_interval = assign_interval
        self.target_k = target_k
        assert self.target_k > 0, 'target_k must be a positive integer'
        self.current_target_idx = 0

        self.q_net = CriticQvalueAll(self.obs_spec,
                                     rep_net_params=self._rep_net_params,
                                     output_shape=self.a_dim,
                                     network_settings=network_settings).to(
                                         self.device)
        self.target_nets = []
        for i in range(self.target_k):
            target_q_net = deepcopy(self.q_net)
            target_q_net.eval()
            sync_params(target_q_net, self.q_net)
            self.target_nets.append(target_q_net)

        self.oplr = OPLR(self.q_net, lr, **self._oplr_params)
        self._trainer_modules.update(model=self.q_net, oplr=self.oplr)

    @iton
    def select_action(self, obs):
        q_values = self.q_net(obs, rnncs=self.rnncs)  # [B, *]
        self.rnncs_ = self.q_net.get_rnncs()

        if self._is_train_mode and self.expl_expt_mng.is_random(
                self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            for i in range(self.target_k):
                target_q_values = self.target_nets[i](obs, rnncs=self.rnncs)
                q_values += target_q_values
                actions = q_values.argmax(-1)  # no need to average: argmax of the sum equals argmax of the mean  # [B,]
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, *]
        q_next = 0
        for i in range(self.target_k):
            q_next += self.target_nets[i](BATCH.obs_,
                                          begin_mask=BATCH.begin_mask)
        q_next /= self.target_k  # [T, B, *]
        q_eval = (q * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        q_target = n_step_return(BATCH.reward, self.gamma, BATCH.done,
                                 q_next.max(-1, keepdim=True)[0],
                                 BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1]
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1

        self.oplr.optimize(q_loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': q_loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            sync_params(self.target_nets[self.current_target_idx], self.q_net)
            self.current_target_idx = (self.current_target_idx +
                                       1) % self.target_k
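
The defining step of Averaged-DQN above is the bootstrap target built from the mean of the K frozen snapshots cycled by _after_train. A minimal sketch of just that target, using stand-in linear networks and random tensors rather than the library's modules:

import torch as th
import torch.nn as nn

obs_dim, a_dim, K, gamma, B = 4, 3, 4, 0.99, 8
target_nets = [nn.Linear(obs_dim, a_dim) for _ in range(K)]  # frozen Q-network snapshots

obs_, reward, done = th.rand(B, obs_dim), th.rand(B, 1), th.zeros(B, 1)

with th.no_grad():
    # Averaged-DQN: bootstrap from the mean of the K most recent target networks
    q_next = sum(net(obs_) for net in target_nets) / K  # [B, A]
    q_target = reward + gamma * (1. - done) * q_next.max(-1, keepdim=True)[0]  # [B, 1]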
Example 10
class DDDQN(SarlOffPolicy):
    """
    Dueling Double DQN, https://arxiv.org/abs/1511.06581
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=2,
                 network_settings={
                     'share': [128],
                     'v': [128],
                     'adv': [128]
                 },
                 **kwargs):
        super().__init__(**kwargs)
        assert not self.is_continuous, 'Dueling Double DQN only supports discrete action spaces'
        self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init,
                                                          eps_mid=eps_mid,
                                                          eps_final=eps_final,
                                                          init2mid_annealing_step=init2mid_annealing_step,
                                                          max_step=self._max_train_step)
        self.assign_interval = assign_interval

        self.q_net = TargetTwin(CriticDueling(self.obs_spec,
                                              rep_net_params=self._rep_net_params,
                                              output_shape=self.a_dim,
                                              network_settings=network_settings)).to(self.device)

        self.oplr = OPLR(self.q_net, lr, **self._oplr_params)
        self._trainer_modules.update(model=self.q_net,
                                     oplr=self.oplr)

    @iton
    def select_action(self, obs):
        q_values = self.q_net(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.q_net.get_rnncs()

        if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            actions = q_values.argmax(-1)  # [B,]
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        next_q = self.q_net(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q_target = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]

        q_eval = (q * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        next_max_action = next_q.argmax(-1)  # [T, B]
        next_max_action_one_hot = F.one_hot(next_max_action.squeeze(), self.a_dim).float()  # [T, B, A]

        q_target_next_max = (q_target * next_max_action_one_hot).sum(-1, keepdim=True)  # [T, B, 1]
        q_target = n_step_return(BATCH.reward,
                                 self.gamma,
                                 BATCH.done,
                                 q_target_next_max,
                                 BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1]
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
        self.oplr.optimize(q_loss)

        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': q_loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            self.q_net.sync()
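
DDDQN above combines two ideas: a dueling head inside CriticDueling and double-DQN action selection in _train. The sketch below shows both on random tensors; the aggregation Q = V + A - mean(A) is the standard dueling formula and is assumed, not verified, to match CriticDueling's internals.

import torch as th

B, A, gamma = 8, 4, 0.99

# (1) standard dueling aggregation (assumed for CriticDueling): Q = V + A - mean(A)
v, adv = th.rand(B, 1), th.rand(B, A)
q = v + adv - adv.mean(-1, keepdim=True)  # [B, A]

# (2) double DQN target: the online net picks the next action, the target net evaluates it
q_online_next, q_target_next = th.rand(B, A), th.rand(B, A)
next_a = q_online_next.argmax(-1, keepdim=True)  # [B, 1]
q_next = q_target_next.gather(-1, next_a)  # [B, 1]
reward, done = th.rand(B, 1), th.zeros(B, 1)
q_backup = reward + gamma * (1. - done) * q_next  # [B, 1]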
Example 11
class VDN(MultiAgentOffPolicy):
    """
    Value-Decomposition Networks For Cooperative Multi-Agent Learning, http://arxiv.org/abs/1706.05296
    QMIX: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning, http://arxiv.org/abs/1803.11485
    Qatten: A General Framework for Cooperative Multiagent Reinforcement Learning, http://arxiv.org/abs/2002.03939
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 mixer='vdn',
                 mixer_settings={},
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 use_double=True,
                 init2mid_annealing_step=1000,
                 assign_interval=1000,
                 network_settings={
                     'share': [128],
                     'v': [128],
                     'adv': [128]
                 },
                 **kwargs):
        super().__init__(**kwargs)
        assert not any(list(self.is_continuouss.values())
                       ), 'VDN only supports discrete action spaces'
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self._max_train_step)
        self.assign_interval = assign_interval
        self._use_double = use_double
        self._mixer_type = mixer
        self._mixer_settings = mixer_settings

        self.q_nets = {}
        for id in set(self.model_ids):
            self.q_nets[id] = TargetTwin(
                CriticDueling(self.obs_specs[id],
                              rep_net_params=self._rep_net_params,
                              output_shape=self.a_dims[id],
                              network_settings=network_settings)).to(
                                  self.device)

        self.mixer = self._build_mixer()

        self.oplr = OPLR(
            tuple(self.q_nets.values()) + (self.mixer, ), lr,
            **self._oplr_params)
        self._trainer_modules.update(
            {f"model_{id}": self.q_nets[id]
             for id in set(self.model_ids)})
        self._trainer_modules.update(mixer=self.mixer, oplr=self.oplr)

    def _build_mixer(self):
        assert self._mixer_type in [
            'vdn', 'qmix', 'qatten'
        ], "assert self._mixer_type in ['vdn', 'qmix', 'qatten']"
        if self._mixer_type in ['qmix', 'qatten']:
            assert self._has_global_state, 'the qmix/qatten mixers require a global state'
        return TargetTwin(Mixer_REGISTER[self._mixer_type](
            n_agents=self.n_agents_percopy,
            state_spec=self.state_spec,
            rep_net_params=self._rep_net_params,
            **self._mixer_settings)).to(self.device)

    @iton  # TODO: optimization
    def select_action(self, obs):
        acts_info = {}
        actions = {}
        for aid, mid in zip(self.agent_ids, self.model_ids):
            q_values = self.q_nets[mid](obs[aid],
                                        rnncs=self.rnncs[aid])  # [B, A]
            self.rnncs_[aid] = self.q_nets[mid].get_rnncs()

            if self._is_train_mode and self.expl_expt_mng.is_random(
                    self._cur_train_step):
                action = np.random.randint(0, self.a_dims[aid], self.n_copies)
            else:
                action = q_values.argmax(-1)  # [B,]

            actions[aid] = action
            acts_info[aid] = Data(action=action)
        return actions, acts_info

    @iton
    def _train(self, BATCH_DICT):
        summaries = {}
        reward = BATCH_DICT[self.agent_ids[0]].reward  # [T, B, 1]
        done = 0.
        q_evals = []
        q_target_next_choose_maxs = []
        for aid, mid in zip(self.agent_ids, self.model_ids):
            done += BATCH_DICT[aid].done  # [T, B, 1]

            q = self.q_nets[mid](
                BATCH_DICT[aid].obs,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            q_eval = (q * BATCH_DICT[aid].action).sum(
                -1, keepdim=True)  # [T, B, 1]
            q_evals.append(q_eval)  # N * [T, B, 1]

            q_target = self.q_nets[mid].t(
                BATCH_DICT[aid].obs_,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            if self._use_double:
                next_q = self.q_nets[mid](
                    BATCH_DICT[aid].obs_,
                    begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]

                next_max_action = next_q.argmax(-1)  # [T, B]
                next_max_action_one_hot = F.one_hot(
                    next_max_action, self.a_dims[aid]).float()  # [T, B, A]

                q_target_next_max = (q_target * next_max_action_one_hot).sum(
                    -1, keepdim=True)  # [T, B, 1]
            else:
                # [T, B, 1]
                q_target_next_max = q_target.max(-1, keepdim=True)[0]

            q_target_next_choose_maxs.append(
                q_target_next_max)  # N * [T, B, 1]

        q_evals = th.stack(q_evals, -1)  # [T, B, 1, N]
        q_target_next_choose_maxs = th.stack(q_target_next_choose_maxs,
                                             -1)  # [T, B, 1, N]
        q_eval_tot = self.mixer(
            q_evals,
            BATCH_DICT['global'].obs,
            begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]
        q_target_next_max_tot = self.mixer.t(
            q_target_next_choose_maxs,
            BATCH_DICT['global'].obs_,
            begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]

        q_target_tot = n_step_return(
            reward, self.gamma, (done > 0.).float(), q_target_next_max_tot,
            BATCH_DICT['global'].begin_mask).detach()  # [T, B, 1]
        td_error = q_target_tot - q_eval_tot  # [T, B, 1]
        q_loss = td_error.square().mean()  # 1
        self.oplr.optimize(q_loss)

        summaries['model'] = {
            'LOSS/q_loss': q_loss,
            'Statistics/q_max': q_eval_tot.max(),
            'Statistics/q_min': q_eval_tot.min(),
            'Statistics/q_mean': q_eval_tot.mean()
        }
        return td_error, summaries

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            for q_net in self.q_nets.values():
                q_net.sync()
            self.mixer.sync()
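
With mixer='vdn', the mixer applied in _train above reduces to a sum of the per-agent chosen-action values; QMIX and Qatten replace that sum with a state-conditioned monotonic network. A minimal sketch of the VDN case on random tensors (the one-step backup stands in for the library's n_step_return):

import torch as th

T, B, n_agents, gamma = 5, 8, 3, 0.99
q_evals = th.rand(T, B, 1, n_agents)  # per-agent Q-values of the chosen actions

# VDN mixing: Q_tot is the plain sum of the agents' utilities
q_eval_tot = q_evals.sum(-1)  # [T, B, 1]

# team TD error against a mixed target value
q_target_next_tot = th.rand(T, B, 1)
reward, done = th.rand(T, B, 1), th.zeros(T, B, 1)
q_target_tot = reward + gamma * (1. - done) * q_target_next_tot  # [T, B, 1]
td_error = q_target_tot - q_eval_tot  # [T, B, 1]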
Example 12
class BootstrappedDQN(SarlOffPolicy):
    """
    Deep Exploration via Bootstrapped DQN, http://arxiv.org/abs/1602.04621
    """
    policy_mode = 'off-policy'

    def __init__(self,
                 lr=5.0e-4,
                 eps_init=1,
                 eps_mid=0.2,
                 eps_final=0.01,
                 init2mid_annealing_step=1000,
                 assign_interval=1000,
                 head_num=4,
                 network_settings=[32, 32],
                 **kwargs):
        super().__init__(**kwargs)
        assert not self.is_continuous, 'Bootstrapped DQN only supports discrete action spaces'
        self.expl_expt_mng = ExplorationExploitationClass(
            eps_init=eps_init,
            eps_mid=eps_mid,
            eps_final=eps_final,
            init2mid_annealing_step=init2mid_annealing_step,
            max_step=self._max_train_step)
        self.assign_interval = assign_interval
        self.head_num = head_num
        self._probs = th.FloatTensor([1. / head_num for _ in range(head_num)])
        self.now_head = 0

        self.q_net = TargetTwin(
            CriticQvalueBootstrap(self.obs_spec,
                                  rep_net_params=self._rep_net_params,
                                  output_shape=self.a_dim,
                                  head_num=self.head_num,
                                  network_settings=network_settings)).to(
                                      self.device)

        self.oplr = OPLR(self.q_net, lr, **self._oplr_params)
        self._trainer_modules.update(model=self.q_net, oplr=self.oplr)

    def episode_reset(self):
        super().episode_reset()
        self.now_head = np.random.randint(self.head_num)

    @iton
    def select_action(self, obs):
        q_values = self.q_net(obs, rnncs=self.rnncs)  # [H, B, A]
        self.rnncs_ = self.q_net.get_rnncs()

        if self._is_train_mode and self.expl_expt_mng.is_random(
                self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            # [H, B, A] => [B, A] => [B, ]
            actions = q_values[self.now_head].argmax(-1)
        return actions, Data(action=actions)

    @iton
    def _train(self, BATCH):
        q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask).mean(
            0)  # [H, T, B, A] => [T, B, A]
        q_next = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask).mean(
            0)  # [H, T, B, A] => [T, B, A]
        # [T, B, A] * [T, B, A] => [T, B, 1]
        q_eval = (q * BATCH.action).sum(-1, keepdim=True)
        q_target = n_step_return(
            BATCH.reward,
            self.gamma,
            BATCH.done,
            # [T, B, A] => [T, B, 1]
            q_next.max(-1, keepdim=True)[0],
            BATCH.begin_mask).detach()  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1]
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1

        # mask_dist = td.Bernoulli(probs=self._probs)  # TODO:
        # mask = mask_dist.sample([batch_size]).T   # [H, B]
        self.oplr.optimize(q_loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/loss': q_loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }

    def _after_train(self):
        super()._after_train()
        if self._cur_train_step % self.assign_interval == 0:
            self.q_net.sync()
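
The commented-out TODO in _train above hints at the bootstrap masks from the original paper: each transition updates only a random subset of heads, which keeps the heads' training data partially decorrelated. A sketch of how such masks could be drawn and applied, mirroring self._probs but otherwise hypothetical:

import torch as th
import torch.distributions as td

head_num, T, B = 4, 5, 8
probs = th.full((head_num,), 1. / head_num)  # mirrors self._probs above
mask_dist = td.Bernoulli(probs=probs)
mask = mask_dist.sample((T, B)).permute(2, 0, 1).unsqueeze(-1)  # [H, T, B, 1]

per_head_td_error = th.rand(head_num, T, B, 1)  # stand-in for per-head TD errors
# each transition only contributes to the heads whose mask is 1
masked_loss = (mask * per_head_td_error.square()).sum() / mask.sum().clamp(min=1.)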