Example #1
    def select_action(self, obs):
        if self.is_continuous:
            if self._share_net:
                mu, log_std, value = self.net(obs, rnncs=self.rnncs)  # [B, A]
                self.rnncs_ = self.net.get_rnncs()
            else:
                mu, log_std = self.actor(obs, rnncs=self.rnncs)  # [B, A]
                self.rnncs_ = self.actor.get_rnncs()
                value = self.critic(obs, rnncs=self.rnncs)  # [B, 1]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            action = dist.sample().clamp(-1, 1)  # [B, A]
            log_prob = dist.log_prob(action).unsqueeze(-1)  # [B, 1]
        else:
            if self._share_net:
                logits, value = self.net(obs, rnncs=self.rnncs)  # [B, A], [B, 1]
                self.rnncs_ = self.net.get_rnncs()
            else:
                logits = self.actor(obs, rnncs=self.rnncs)  # [B, A]
                self.rnncs_ = self.actor.get_rnncs()
                value = self.critic(obs, rnncs=self.rnncs)  # [B, 1]
            norm_dist = td.Categorical(logits=logits)
            action = norm_dist.sample()  # [B,]
            log_prob = norm_dist.log_prob(action).unsqueeze(-1)  # [B, 1]

        acts_info = Data(action=action,
                         value=value,
                         log_prob=log_prob + th.finfo().eps)
        if self.use_rnn:
            acts_info.update(rnncs=self.rnncs)
        return action, acts_info
Example #2
    def reset(self, **kwargs) -> Dict[str, Data]:
        obss = dict()
        obs = self._envs.run('reset')  # list(dict)

        for k in self._agents:
            obss[k] = np.stack([obs[i][k] for i in range(self._n_copies)], 0)

        rets = {}
        for k in self._agents:
            if self._is_obs_visual[k]:
                rets[k] = Data(visual={'visual_0': obss[k]})
            else:
                rets[k] = Data(vector={'vector_0': obss[k]})
        rets['global'] = Data(begin_mask=np.full((self._n_copies, 1), True))

        if self._has_global_state:
            state = self._envs.run('state')
            state = np.stack(state, 0)  # [B, *]
            if self._is_state_visual:
                _state = Data(visual={'visual_0': state})
            else:
                _state = Data(vector={'vector_0': state})
            rets['global'].update(obs=_state)

        return rets
Example #3
    def episode_step(self,
                     obs,
                     env_rets: Dict[str, Data]):
        super().episode_step()
        if self._store:
            expss = {}
            for id in self.agent_ids:
                expss[id] = Data(obs=obs[id],
                                 # [B, ] => [B, 1]
                                 reward=env_rets[id].reward[:, np.newaxis],
                                 obs_=env_rets[id].obs_fs,
                                 done=env_rets[id].done[:, np.newaxis])
                expss[id].update(self._acts_info[id])
            expss['global'] = Data(begin_mask=obs['global'].begin_mask)
            if self._has_global_state:
                expss['global'].update(obs=obs['global'].obs,
                                       obs_=env_rets['global'].obs)
            self._buffer.add(expss)

        for id in self.agent_ids:
            idxs = np.where(env_rets[id].done)[0]
            self._pre_acts[id][idxs] = 0.
            self.rnncs[id] = self.rnncs_[id]
            if self.rnncs[id] is not None:
                for k in self.rnncs[id].keys():
                    self.rnncs[id][k][idxs] = 0.
Example #4
 def reset(self, **kwargs) -> Dict[str, Data]:
     obs = self._envs.run('reset')
     obs = np.stack(obs, 0)
     if self._use_visual:
         ret = Data(visual={'visual_0': obs})
     else:
         ret = Data(vector={'vector_0': obs})
     return {
         'single': ret,
         'global': Data(begin_mask=np.full((self._n_copies, 1), True))
     }
Example #5
 def select_action(self, obs):
     output = self.actor(obs, rnncs=self.rnncs)  # [B, A]
     self.rnncs_ = self.actor.get_rnncs()
     value = self.critic(obs, rnncs=self.rnncs)  # [B, 1]
     if self.is_continuous:
         mu, log_std = output  # [B, A]
         dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
         action = dist.sample().clamp(-1, 1)  # [B, A]
         log_prob = dist.log_prob(action).unsqueeze(-1)  # [B, 1]
     else:
         logits = output  # [B, A]
         logp_all = logits.log_softmax(-1)  # [B, A]
         norm_dist = td.Categorical(logits=logp_all)
         action = norm_dist.sample()  # [B,]
         log_prob = norm_dist.log_prob(action).unsqueeze(-1)  # [B, 1]
     acts_info = Data(action=action,
                      value=value,
                      log_prob=log_prob + th.finfo().eps)
     if self.use_rnn:
         acts_info.update(rnncs=self.rnncs)
     if self.is_continuous:
         acts_info.update(mu=mu, log_std=log_std)
     else:
         acts_info.update(logp_all=logp_all)
     return action, acts_info
Example #6
 def select_action(self, obs):
     q = self.q_net(obs, rnncs=self.rnncs)  # [B, P]
     self.rnncs_ = self.q_net.get_rnncs()
     pi = self.intra_option_net(obs, rnncs=self.rnncs)  # [B, P, A]
     beta = self.termination_net(obs, rnncs=self.rnncs)  # [B, P]
     options_onehot = F.one_hot(self.options,
                                self.options_num).float()  # [B, P]
     options_onehot_expanded = options_onehot.unsqueeze(-1)  # [B, P, 1]
     pi = (pi * options_onehot_expanded).sum(-2)  # [B, A]
     if self.is_continuous:
         mu = pi.tanh()  # [B, A]
         log_std = self.log_std[self.options]  # [B, A]
         dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
         actions = dist.sample().clamp(-1, 1)  # [B, A]
     else:
         pi = pi / self.boltzmann_temperature  # [B, A]
         dist = td.Categorical(logits=pi)
         actions = dist.sample()  # [B, ]
     max_options = q.argmax(-1).long()  # [B, P] => [B, ]
     if self.use_eps_greedy:
         # epsilon greedy
         if self._is_train_mode and self.expl_expt_mng.is_random(
                 self._cur_train_step):
             self.new_options = self._generate_random_options()
         else:
             self.new_options = max_options
     else:
         beta_probs = (beta * options_onehot).sum(-1)  # [B, P] => [B,]
         beta_dist = td.Bernoulli(probs=beta_probs)
         self.new_options = th.where(beta_dist.sample() < 1, self.options,
                                     max_options)
     return actions, Data(action=actions,
                          last_options=self.options,
                          options=self.new_options)
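The masked sum `(pi * options_onehot_expanded).sum(-2)` above simply picks out each batch row's currently selected intra-option policy head. A minimal standalone sketch of that equivalence (plain PyTorch, hypothetical sizes):

import torch as th
import torch.nn.functional as F

B, P, A = 2, 3, 4                                              # hypothetical batch/option/action sizes
pi = th.arange(B * P * A, dtype=th.float32).reshape(B, P, A)   # [B, P, A]
options = th.tensor([1, 2])                                    # [B,] current option per copy

onehot = F.one_hot(options, P).float().unsqueeze(-1)           # [B, P, 1]
selected = (pi * onehot).sum(-2)                               # [B, A]
print(th.equal(selected, pi[th.arange(B), options]))           # True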
Example #7
    def sample(self, batchsize=None, chunk_length=None):
        if batchsize == 0:
            return self.all_data()

        B = batchsize or self.batch_size
        T = chunk_length or self._chunk_length
        assert T <= self._horizon_length

        if self._horizon_length == self.max_horizon:
            start = self._pointer - self.max_horizon
        else:
            start = 0
        end = self._pointer - T + 1

        x = np.random.randint(start, end, B)  # [B, ]
        y = np.random.randint(0, self.n_copies, B)  # (B, )
        # (T, B) + (B, ) = (T, B)
        xs = (np.tile(np.arange(T)[:, np.newaxis],
                      B) + x) % self._horizon_length
        sample_idxs = (xs, y)
        samples = {}
        for k, v in self._buffer.items():
            samples[k] = Data.from_nested_dict(
                {_k: _v[sample_idxs] for _k, _v in v.items()}
            )
        return samples  # [T, B, *]
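A small standalone numpy sketch of the same index arithmetic (hypothetical buffer sizes; the class attributes and the Data wrapper are omitted), showing how the tiled time offsets wrap around the circular horizon:

import numpy as np

horizon_length, n_copies = 10, 5   # filled slots and parallel env copies (assumed)
T, B = 4, 3                        # chunk length and batch size (assumed)
data = np.random.randn(horizon_length, n_copies, 7)   # stand-in for one buffer entry

x = np.random.randint(0, horizon_length - T + 1, B)   # [B,] chunk start per sample
y = np.random.randint(0, n_copies, B)                 # [B,] environment-copy index
# (T, 1) tiled to (T, B), shifted by the starts, wrapped around the circular buffer
xs = (np.tile(np.arange(T)[:, np.newaxis], B) + x) % horizon_length  # [T, B]
chunk = data[xs, y]                                    # [T, B, 7]
print(chunk.shape)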
Example #8
    def select_action(self, obs):
        output = self.actor(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.actor.get_rnncs()
        if self.is_continuous:
            mu, log_std = output  # [B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            action = dist.sample().clamp(-1, 1)  # [B, A]
        else:
            logits = output  # [B, A]
            norm_dist = td.Categorical(logits=logits)
            action = norm_dist.sample()  # [B,]

        acts_info = Data(action=action)
        if self.use_rnn:
            acts_info.update(rnncs=self.rnncs)
        return action, acts_info
Example #9
 def select_action(self, obs):
     q_values = self.q_net(obs, rnncs=self.rnncs)  # [B, A]
     self.rnncs_ = self.q_net.get_rnncs()
      probs = ((q_values - self._get_v(q_values)) / self.alpha).exp()  # > 0   # [B, A]
      probs /= probs.sum(-1, keepdim=True)  # normalized probabilities  # [B, A]
      cate_dist = td.Categorical(probs=probs)
     actions = cate_dist.sample()  # [B,]
     return actions, Data(action=actions)
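The exponentiated, normalized values above are probabilities rather than logits, which is why the `probs=` keyword is used. A minimal sketch of the difference (plain PyTorch, made-up q-values):

import torch as th
import torch.distributions as td

q = th.tensor([[1.0, 2.0, 3.0]])               # hypothetical soft-Q values, [B=1, A=3]
alpha = 1.0
probs = (q / alpha).exp()
probs = probs / probs.sum(-1, keepdim=True)    # rows sum to 1

print(td.Categorical(probs=probs).probs)       # keeps the intended distribution
print(td.Categorical(logits=probs).probs)      # would re-apply softmax -> nearly uniform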
Example #10
 def random_action(self):
     if self.is_continuous:
         actions = np.random.uniform(-1.0, 1.0, (self.n_copies, self.a_dim))
     else:
         actions = np.random.randint(0, self.a_dim, self.n_copies)
     self._pre_act = actions
     self._acts_info = Data(action=actions)
     return actions
Example #11
    def select_action(self, obs):
        q_values = self.q_net(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.q_net.get_rnncs()

        if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            actions = q_values.argmax(-1)  # [B,]
        return actions, Data(action=actions)
Example #12
    def episode_step(self, obs: Data, env_rets: Data, begin_mask: np.ndarray):
        super().episode_step()
        if self._store:
            exps = Data(
                obs=obs,
                # [B, ] => [B, 1]
                reward=env_rets.reward[:, np.newaxis],
                obs_=env_rets.obs_fs,
                done=env_rets.done[:, np.newaxis],
                begin_mask=begin_mask)
            exps.update(self._acts_info)
            self._buffer.add({self._agent_id: exps})

        idxs = np.where(env_rets.done)[0]
        self._pre_act[idxs] = 0.
        self.rnncs = self.rnncs_
        if self.rnncs is not None:
            for k in self.rnncs.keys():
                self.rnncs[k][idxs] = 0.
Example #13
 def learn(self, BATCH: Data):
     BATCH = self._preprocess_BATCH(BATCH)  # [T, B, *]
     for _ in range(self._epochs):
         for _BATCH in BATCH.sample(self._chunk_length,
                                    self.batch_size,
                                    repeat=self._sample_allow_repeat):
             _BATCH = self._before_train(_BATCH)
             summaries = self._train(_BATCH)
             self.summaries.update(summaries)
             self._after_train()
Example #14
    def select_action(self, obs):
        q = self.critic(obs, rnncs=self.rnncs)  # [B, A]
        self.rnncs_ = self.critic.get_rnncs()

        if self.use_epsilon and self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            cate_dist = td.Categorical(logits=(q / self.alpha))
            mu = q.argmax(-1)  # [B,]
            actions = pi = cate_dist.sample()  # [B,]
        return actions, Data(action=actions)
Example #15
    def select_action(self, obs):
        feat = self.q_net(obs, rnncs=self.rnncs)  # [B, A, N]
        self.rnncs_ = self.q_net.get_rnncs()

        if self._is_train_mode and self.expl_expt_mng.is_random(
                self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            q = (self._z * feat).sum(-1)  # [B, A, N] * [N,] => [B, A]
            actions = q.argmax(-1)  # [B,]
        return actions, Data(action=actions)
Example #16
    def select_action(self, obs):
        if self._is_visual:
            obs = get_first_visual(obs)
        else:
            obs = get_first_vector(obs)
        # Compute starting state for planning
        # while taking information from current observation (posterior)
        embedded_obs = self.obs_encoder(obs)  # [B, *]
        state_posterior = self.rssm.posterior(self.rnncs['hx'], embedded_obs)  # dist # [B, *]

        # Initialize action distribution
        mean = th.zeros((self.cem_horizon, 1, self.n_copies, self.a_dim))  # [H, 1, B, A]
        stddev = th.ones((self.cem_horizon, 1, self.n_copies, self.a_dim))  # [H, 1, B, A]

        # Iteratively improve action distribution with CEM
        for itr in range(self.cem_iter_nums):
            action_candidates = mean + stddev * \
                                th.randn(self.cem_horizon, self.cem_candidates, self.n_copies,
                                         self.a_dim)  # [H, N, B, A]
            action_candidates = action_candidates.reshape(self.cem_horizon, -1, self.a_dim)  # [H, N*B, A]

            # Initialize reward, state, and rnn hidden state
            # These are for parallel exploration
            total_predicted_reward = th.zeros((self.cem_candidates * self.n_copies, 1))  # [N*B, 1]

            state = state_posterior.sample((self.cem_candidates,))  # [N, B, *]
            state = state.view(-1, state.shape[-1])  # [N*B, *]
            rnn_hidden = self.rnncs['hx'].repeat((self.cem_candidates, 1))  # [B, *] => [N*B, *]

            # Compute total predicted reward by open-loop prediction using the prior
            for t in range(self.cem_horizon):
                next_state_prior, rnn_hidden = self.rssm.prior(state, th.tanh(action_candidates[t]), rnn_hidden)
                state = next_state_prior.sample()  # [N*B, *]
                post_feat = th.cat([state, rnn_hidden], -1)  # [N*B, *]
                total_predicted_reward += self.reward_predictor(post_feat).mean  # [N*B, 1]

            # update action distribution using top-k samples
            total_predicted_reward = total_predicted_reward.view(self.cem_candidates, self.n_copies, 1)  # [N, B, 1]
            _, top_indexes = total_predicted_reward.topk(self.cem_tops, dim=0, largest=True, sorted=False)  # [N', B, 1]
            action_candidates = action_candidates.view(self.cem_horizon, self.cem_candidates, self.n_copies,
                                                       -1)  # [H, N, B, A]
            top_action_candidates = action_candidates[:, top_indexes,
                                    th.arange(self.n_copies).reshape(self.n_copies, 1),
                                    th.arange(self.a_dim)]  # [H, N', B, A]
            mean = top_action_candidates.mean(dim=1, keepdim=True)  # [H, 1, B, A]
            stddev = top_action_candidates.std(dim=1, unbiased=False, keepdim=True)  # [H, 1, B, A]

        # Return only the first action (replan at every step based on the new observation)
        actions = th.tanh(mean[0].squeeze(0))  # [B, A]
        actions = self._exploration(actions)
        _, self.rnncs_['hx'] = self.rssm.prior(state_posterior.sample(),
                                               actions,
                                               self.rnncs['hx'])
        return actions, Data(action=actions)
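The top-k re-fit at the end of each CEM iteration uses fairly dense fancy indexing. A standalone sketch of just that step (hypothetical sizes; the indexing here is a simpler equivalent of the original's three-index form):

import torch as th

H, N, B, A = 3, 8, 2, 2                # horizon, candidates, env copies, action dim (assumed)
rewards = th.randn(N, B, 1)            # total predicted reward per candidate and copy
cands = th.randn(H, N, B, A)           # sampled action sequences

k = 4                                  # number of elite candidates to keep (assumed)
_, top = rewards.topk(k, dim=0)                          # [k, B, 1] indices of the best candidates
top = top.squeeze(-1)                                    # [k, B]
best = cands[:, top, th.arange(B)]                       # [H, k, B, A] elite action sequences
mean = best.mean(dim=1, keepdim=True)                    # [H, 1, B, A] refitted mean
stddev = best.std(dim=1, unbiased=False, keepdim=True)   # [H, 1, B, A] refitted stddev
print(mean.shape, stddev.shape)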
Example #17
 def random_action(self):
     actions = {}
     self._acts_info = {}
     for id in self.agent_ids:
         if self.is_continuouss[id]:
             actions[id] = np.random.uniform(-1.0, 1.0, (self.n_copies, self.a_dims[id]))
         else:
             actions[id] = np.random.randint(0, self.a_dims[id], self.n_copies)
         self._acts_info[id] = Data(action=actions[id])
     self._pre_acts = actions
     return actions
Example #18
 def select_action(self, obs):
     # [B, P], [B, P, A], [B, P]
     (q, pi, beta) = self.net(obs, rnncs=self.rnncs)
     self.rnncs_ = self.net.get_rnncs()
     options_onehot = F.one_hot(self.options,
                                self.options_num).float()  # [B, P]
     options_onehot_expanded = options_onehot.unsqueeze(-1)  # [B, P, 1]
     pi = (pi * options_onehot_expanded).sum(-2)  # [B, A]
     if self.is_continuous:
         mu = pi  # [B, A]
         log_std = self.log_std[self.options]  # [B, A]
         dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
         action = dist.sample().clamp(-1, 1)  # [B, A]
         log_prob = dist.log_prob(action).unsqueeze(-1)  # [B, 1]
     else:
         logits = pi  # [B, A]
         norm_dist = td.Categorical(logits=logits)
         action = norm_dist.sample()  # [B,]
         log_prob = norm_dist.log_prob(action).unsqueeze(-1)  # [B, 1]
     value = q_o = (q * options_onehot).sum(-1, keepdim=True)  # [B, 1]
     beta_adv = q_o - ((1 - self.eps) * q.max(-1, keepdim=True)[0] +
                       self.eps * q.mean(-1, keepdim=True))  # [B, 1]
     max_options = q.argmax(-1)  # [B, P] => [B, ]
     beta_probs = (beta * options_onehot).sum(-1)  # [B, P] => [B,]
     beta_dist = td.Bernoulli(probs=beta_probs)
      # sample < 1: keep the current option; sample == 1: switch to the greedy option
     new_options = th.where(beta_dist.sample() < 1, self.options,
                            max_options)
     self.new_options = th.where(self._done_mask, max_options, new_options)
     self.oc_mask = (self.new_options == self.options).float()
     acts_info = Data(
         action=action,
         value=value,
         log_prob=log_prob + th.finfo().eps,
         beta_advantage=beta_adv + self.dc,
         last_options=self.options,
         options=self.new_options,
         reward_offset=-((1 - self.oc_mask) * self.dc).unsqueeze(-1))
     if self.use_rnn:
         acts_info.update(rnncs=self.rnncs)
     return action, acts_info
Example #19
 def learn(self, BATCH: Data):
     BATCH = self._preprocess_BATCH(BATCH)  # [T, B, *]
     for _ in range(self._epochs):
         kls = []
         for _BATCH in BATCH.sample(self._chunk_length, self.batch_size, repeat=self._sample_allow_repeat):
             _BATCH = self._before_train(_BATCH)
             summaries, kl = self._train(_BATCH)
             kls.append(kl)
             self.summaries.update(summaries)
             self._after_train()
         if self._use_early_stop and sum(kls) / len(kls) > self._kl_stop:
             break
Example #20
    def select_action(self, obs):
        q_values = self.q_net(obs, rnncs=self.rnncs)  # [B, *]
        self.rnncs_ = self.q_net.get_rnncs()

        if self._is_train_mode and self.expl_expt_mng.is_random(
                self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            for i in range(self.target_k):
                target_q_values = self.target_nets[i](obs, rnncs=self.rnncs)
                q_values += target_q_values
            actions = q_values.argmax(-1)  # no need to average; argmax is unaffected  [B, ]
        return actions, Data(action=actions)
Example #21
 def select_action(self, obs):
     if self.is_continuous:
         mu, log_std = self.actor(obs, rnncs=self.rnncs)  # [B, A]
         pi = td.Normal(mu, log_std.exp()).sample().tanh()  # [B, A]
         mu.tanh_()  # squash mu     # [B, A]
     else:
         logits = self.actor(obs, rnncs=self.rnncs)  # [B, A]
         mu = logits.argmax(-1)  # [B,]
         cate_dist = td.Categorical(logits=logits)
         pi = cate_dist.sample()  # [B,]
     self.rnncs_ = self.actor.get_rnncs()
     actions = pi if self._is_train_mode else mu
     return actions, Data(action=actions)
Example #22
 def select_action(self, obs):
     output = self.actor(obs, rnncs=self.rnncs)  # [B, A]
     self.rnncs_ = self.actor.get_rnncs()
     if self.is_continuous:
         mu = output  # [B, A]
         pi = self.noised_action(mu)  # [B, A]
     else:
         logits = output  # [B, A]
         mu = logits.argmax(-1)  # [B,]
         cate_dist = td.Categorical(logits=logits)
         pi = cate_dist.sample()  # [B,]
     actions = pi if self._is_train_mode else mu
     return actions, Data(action=actions)
Example #23
 def select_action(self, obs):
     output = self.actor(obs, rnncs=self.rnncs)  # [B, *]
     self.rnncs_ = self.actor.get_rnncs()
     if self.is_continuous:
         mu, log_std = output  # [B, *]
          dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
         action = dist.sample().clamp(-1, 1)  # [B, *]
         log_prob = dist.log_prob(action)  # [B,]
     else:
         logits = output  # [B, *]
         norm_dist = td.Categorical(logits=logits)
         action = norm_dist.sample()  # [B,]
         log_prob = norm_dist.log_prob(action)  # [B,]
     return action, Data(action=action, log_prob=log_prob)
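With `reinterpreted_batch_ndims=1`, `td.Independent` folds the action dimension into the event shape so `log_prob` sums over it, which is what the `# [B,]` annotation above expects. A minimal standalone sketch (made-up shapes):

import torch as th
import torch.distributions as td

B, A = 4, 2
mu, std = th.zeros(B, A), th.ones(B, A)

base = td.Normal(mu, std)        # batch_shape=[B, A], event_shape=[]
joint = td.Independent(base, 1)  # batch_shape=[B],    event_shape=[A]

a = joint.sample()               # [B, A]
print(base.log_prob(a).shape)    # torch.Size([4, 2]) - per-dimension log-probs
print(joint.log_prob(a).shape)   # torch.Size([4])    - summed over the action dim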
Example #24
    def select_action(self, obs):
        _, select_quantiles_tiled = self._generate_quantiles(  # [N*B, X]
            batch_size=self.n_copies,
            quantiles_num=self.select_quantiles
        )
        q_values = self.q_net(obs, select_quantiles_tiled, rnncs=self.rnncs)  # [N, B, A]
        self.rnncs_ = self.q_net.get_rnncs()

        if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            # [N, B, A] => [B, A] => [B,]
            actions = q_values.mean(0).argmax(-1)
        return actions, Data(action=actions)
Example #25
    def select_action(self, obs):
        acts_info = {}
        actions = {}
        for aid, mid in zip(self.agent_ids, self.model_ids):
            q_values = self.q_nets[mid](obs[aid],
                                        rnncs=self.rnncs[aid])  # [B, A]
            self.rnncs_[aid] = self.q_nets[mid].get_rnncs()

            if self._is_train_mode and self.expl_expt_mng.is_random(
                    self._cur_train_step):
                action = np.random.randint(0, self.a_dims[aid], self.n_copies)
            else:
                action = q_values.argmax(-1)  # [B,]

            actions[aid] = action
            acts_info[aid] = Data(action=action)
        return actions, acts_info
Example #26
 def select_action(self, obs):
     if self._is_visual:
         obs = get_first_visual(obs)
     else:
         obs = get_first_vector(obs)
     embedded_obs = self.obs_encoder(obs)  # [B, *]
     state_posterior = self.rssm.posterior(self.rnncs['hx'], embedded_obs)
     state = state_posterior.sample()  # [B, *]
     actions = self.actor.sample_actions(th.cat((state, self.rnncs['hx']),
                                                -1),
                                         is_train=self._is_train_mode)
     actions = self._exploration(actions)
     _, self.rnncs_['hx'] = self.rssm.prior(state, actions,
                                            self.rnncs['hx'])
     if not self.is_continuous:
         actions = actions.argmax(-1)  # [B,]
     return actions, Data(action=actions)
Example #27
 def select_action(self, obs: Dict):
     acts_info = {}
     actions = {}
     for aid, mid in zip(self.agent_ids, self.model_ids):
         output = self.actors[mid](obs[aid],
                                   rnncs=self.rnncs[aid])  # [B, A]
         self.rnncs_[aid] = self.actors[mid].get_rnncs()
         if self.is_continuouss[aid]:
             mu = output  # [B, A]
             pi = self.noised_actions[mid](mu)  # [B, A]
         else:
             logits = output  # [B, A]
             mu = logits.argmax(-1)  # [B,]
             cate_dist = td.Categorical(logits=logits)
             pi = cate_dist.sample()  # [B,]
         action = pi if self._is_train_mode else mu
         acts_info[aid] = Data(action=action)
         actions[aid] = action
     return actions, acts_info
Example #28
 def select_action(self, obs):
     q = self.q_net(obs, rnncs=self.rnncs)  # [B, P]
     self.rnncs_ = self.q_net.get_rnncs()
     pi = self.intra_option_net(obs, rnncs=self.rnncs)  # [B, P, A]
     options_onehot = F.one_hot(self.options,
                                self.options_num).float()  # [B, P]
     options_onehot_expanded = options_onehot.unsqueeze(-1)  # [B, P, 1]
     pi = (pi * options_onehot_expanded).sum(-2)  # [B, A]
     if self.is_continuous:
         mu = pi.tanh()  # [B, A]
         log_std = self.log_std[self.options]  # [B, A]
         dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
         actions = dist.sample().clamp(-1, 1)  # [B, A]
     else:
         pi = pi / self.boltzmann_temperature  # [B, A]
         dist = td.Categorical(logits=pi)
         actions = dist.sample()  # [B, ]
     interests = self.interest_net(obs, rnncs=self.rnncs)  # [B, P]
      op_logits = interests * q  # [B, P]; alternatively interests * q.softmax(-1)
     self.new_options = td.Categorical(logits=op_logits).sample()  # [B, ]
     return actions, Data(action=actions,
                          last_options=self.options,
                          options=self.new_options)
Example #29
    def select_action(self, obs):
        if self.is_continuous:
            _actions = []
            for _ in range(self._select_samples):
                _actions.append(
                    self.actor(obs, self.vae.decode(obs),
                               rnncs=self.rnncs))  # [B, A]
            self.rnncs_ = self.actor.get_rnncs()  # TODO: calculate corrected hidden state
            _actions = th.stack(_actions, dim=0)  # [N, B, A]
            q1s = []
            for i in range(self._select_samples):
                q1s.append(self.critic(obs, _actions[i])[0])
            q1s = th.stack(q1s, dim=0)  # [N, B, 1]
            max_idxs = q1s.argmax(dim=0, keepdim=True)[-1]  # [B, 1]
            actions = _actions[
                max_idxs,
                th.arange(self.n_copies).reshape(self.n_copies, 1),
                th.arange(self.a_dim)]
        else:
            q_values, i_values = self.q_net(obs, rnncs=self.rnncs)  # [B, *]
            q_values = q_values - q_values.min(dim=-1,
                                               keepdim=True)[0]  # [B, *]
            i_values = F.log_softmax(i_values, dim=-1)  # [B, *]
            i_values = i_values.exp()  # [B, *]
            i_values = (i_values / i_values.max(-1, keepdim=True)[0] >
                        self._threshold).float()  # [B, *]

            self.rnncs_ = self.q_net.get_rnncs()

            if self._is_train_mode and self.expl_expt_mng.is_random(
                    self._cur_train_step):
                actions = np.random.randint(0, self.a_dim, self.n_copies)
            else:
                actions = (i_values * q_values).argmax(-1)  # [B,]
        return actions, Data(action=actions)
Example #30
    def step(self, actions: Dict[str, np.ndarray],
             **kwargs) -> Dict[str, Data]:
        # take the single agent's actions
        actions = deepcopy(actions['single'])
        params = []
        for i in range(self._n_copies):
            params.append(dict(args=(actions[i], )))
        rets = self._envs.run('step', params)
        obs_fs, reward, done, info = zip(*rets)
        obs_fs = np.stack(obs_fs, 0)
        reward = np.stack(reward, 0)
        done = np.stack(done, 0)
        # TODO: info

        obs_fa = deepcopy(obs_fs)  # obs for next action choosing.

        idxs = np.where(done)[0]
        if len(idxs) > 0:
            reset_obs = self._envs.run('reset', idxs=idxs)
            obs_fa[idxs] = np.stack(reset_obs, 0)

        if self._use_visual:
            obs_fs = Data(visual={'visual_0': obs_fs})
            obs_fa = Data(visual={'visual_0': obs_fa})
        else:
            obs_fs = Data(vector={'vector_0': obs_fs})
            obs_fa = Data(vector={'vector_0': obs_fa})
        return {
            'single':
            Data(obs_fs=obs_fs,
                 obs_fa=obs_fa,
                 reward=reward,
                 done=done,
                 info=info),
            'global':
            Data(begin_mask=done[:, np.newaxis])
        }