def select_action(self, obs):
    if self.is_continuous:
        if self._share_net:
            mu, log_std, value = self.net(obs, rnncs=self.rnncs)  # [B, A]
            self.rnncs_ = self.net.get_rnncs()
        else:
            mu, log_std = self.actor(obs, rnncs=self.rnncs)  # [B, A]
            self.rnncs_ = self.actor.get_rnncs()
            value = self.critic(obs, rnncs=self.rnncs)  # [B, 1]
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        action = dist.sample().clamp(-1, 1)  # [B, A]
        log_prob = dist.log_prob(action).unsqueeze(-1)  # [B, 1]
    else:
        if self._share_net:
            logits, value = self.net(obs, rnncs=self.rnncs)  # [B, A], [B, 1]
            self.rnncs_ = self.net.get_rnncs()
        else:
            logits = self.actor(obs, rnncs=self.rnncs)  # [B, A]
            self.rnncs_ = self.actor.get_rnncs()
            value = self.critic(obs, rnncs=self.rnncs)  # [B, 1]
        norm_dist = td.Categorical(logits=logits)
        action = norm_dist.sample()  # [B,]
        log_prob = norm_dist.log_prob(action).unsqueeze(-1)  # [B, 1]
    acts_info = Data(action=action,
                     value=value,
                     log_prob=log_prob + th.finfo().eps)
    if self.use_rnn:
        acts_info.update(rnncs=self.rnncs)
    return action, acts_info
def reset(self, **kwargs) -> Dict[str, Data]:
    obss = dict()
    obs = self._envs.run('reset')  # list(dict)
    for k in self._agents:
        obss[k] = np.stack([obs[i][k] for i in range(self._n_copies)], 0)
    rets = {}
    for k in self._agents:
        if self._is_obs_visual[k]:
            rets[k] = Data(visual={'visual_0': obss[k]})
        else:
            rets[k] = Data(vector={'vector_0': obss[k]})
    rets['global'] = Data(begin_mask=np.full((self._n_copies, 1), True))
    if self._has_global_state:
        state = self._envs.run('state')
        state = np.stack(state, 0)  # [B, *]
        if self._is_state_visual:
            _state = Data(visual={'visual_0': state})
        else:
            _state = Data(vector={'vector_0': state})
        rets['global'].update(obs=_state)
    return rets
def episode_step(self, obs, env_rets: Dict[str, Data]):
    super().episode_step()
    if self._store:
        expss = {}
        for id in self.agent_ids:
            expss[id] = Data(obs=obs[id],
                             reward=env_rets[id].reward[:, np.newaxis],  # [B,] => [B, 1]
                             obs_=env_rets[id].obs_fs,
                             done=env_rets[id].done[:, np.newaxis])  # [B,] => [B, 1]
            expss[id].update(self._acts_info[id])
        expss['global'] = Data(begin_mask=obs['global'].begin_mask)
        if self._has_global_state:
            expss['global'].update(obs=obs['global'].obs,
                                   obs_=env_rets['global'].obs)
        self._buffer.add(expss)
    for id in self.agent_ids:
        idxs = np.where(env_rets[id].done)[0]
        self._pre_acts[id][idxs] = 0.
        self.rnncs[id] = self.rnncs_[id]
        if self.rnncs[id] is not None:
            for k in self.rnncs[id].keys():
                self.rnncs[id][k][idxs] = 0.
def reset(self, **kwargs) -> Dict[str, Data]:
    obs = self._envs.run('reset')
    obs = np.stack(obs, 0)
    if self._use_visual:
        ret = Data(visual={'visual_0': obs})
    else:
        ret = Data(vector={'vector_0': obs})
    return {
        'single': ret,
        'global': Data(begin_mask=np.full((self._n_copies, 1), True))
    }
def select_action(self, obs):
    output = self.actor(obs, rnncs=self.rnncs)  # [B, A]
    self.rnncs_ = self.actor.get_rnncs()
    value = self.critic(obs, rnncs=self.rnncs)  # [B, 1]
    if self.is_continuous:
        mu, log_std = output  # [B, A]
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        action = dist.sample().clamp(-1, 1)  # [B, A]
        log_prob = dist.log_prob(action).unsqueeze(-1)  # [B, 1]
    else:
        logits = output  # [B, A]
        logp_all = logits.log_softmax(-1)  # [B, A]
        norm_dist = td.Categorical(logits=logp_all)
        action = norm_dist.sample()  # [B,]
        log_prob = norm_dist.log_prob(action).unsqueeze(-1)  # [B, 1]
    acts_info = Data(action=action,
                     value=value,
                     log_prob=log_prob + th.finfo().eps)
    if self.use_rnn:
        acts_info.update(rnncs=self.rnncs)
    if self.is_continuous:
        acts_info.update(mu=mu, log_std=log_std)
    else:
        acts_info.update(logp_all=logp_all)
    return action, acts_info
def select_action(self, obs):
    q = self.q_net(obs, rnncs=self.rnncs)  # [B, P]
    self.rnncs_ = self.q_net.get_rnncs()
    pi = self.intra_option_net(obs, rnncs=self.rnncs)  # [B, P, A]
    beta = self.termination_net(obs, rnncs=self.rnncs)  # [B, P]
    options_onehot = F.one_hot(self.options, self.options_num).float()  # [B, P]
    options_onehot_expanded = options_onehot.unsqueeze(-1)  # [B, P, 1]
    pi = (pi * options_onehot_expanded).sum(-2)  # [B, A]
    if self.is_continuous:
        mu = pi.tanh()  # [B, A]
        log_std = self.log_std[self.options]  # [B, A]
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        actions = dist.sample().clamp(-1, 1)  # [B, A]
    else:
        pi = pi / self.boltzmann_temperature  # [B, A]
        dist = td.Categorical(logits=pi)
        actions = dist.sample()  # [B,]
    max_options = q.argmax(-1).long()  # [B, P] => [B,]
    if self.use_eps_greedy:
        # epsilon greedy
        if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
            self.new_options = self._generate_random_options()
        else:
            self.new_options = max_options
    else:
        beta_probs = (beta * options_onehot).sum(-1)  # [B, P] => [B,]
        beta_dist = td.Bernoulli(probs=beta_probs)
        self.new_options = th.where(beta_dist.sample() < 1, self.options, max_options)
    return actions, Data(action=actions,
                         last_options=self.options,
                         options=self.new_options)
def sample(self, batchsize=None, chunk_length=None):
    if batchsize == 0:
        return self.all_data()
    B = batchsize or self.batch_size
    T = chunk_length or self._chunk_length
    assert T <= self._horizon_length
    if self._horizon_length == self.max_horizon:
        start = self._pointer - self.max_horizon
    else:
        start = 0
    end = self._pointer - T + 1
    x = np.random.randint(start, end, B)  # [B,]
    y = np.random.randint(0, self.n_copies, B)  # [B,]
    # (T, B) + (B,) = (T, B)
    xs = (np.tile(np.arange(T)[:, np.newaxis], B) + x) % self._horizon_length
    sample_idxs = (xs, y)
    samples = {}
    for k, v in self._buffer.items():
        samples[k] = Data.from_nested_dict(
            {_k: _v[sample_idxs] for _k, _v in v.items()}
        )
    return samples  # [T, B, *]
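# Usage sketch (an assumption for illustration; `buffer` and the concrete numbers are
# not part of the original code): draw 32 chunks of 8 consecutive steps each. Every
# leaf array in the returned Data objects comes back time-major, i.e. shaped [T, B, *].
samples = buffer.sample(batchsize=32, chunk_length=8)  # dict of Data, leaves are [8, 32, *]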
def select_action(self, obs):
    output = self.actor(obs, rnncs=self.rnncs)  # [B, A]
    self.rnncs_ = self.actor.get_rnncs()
    if self.is_continuous:
        mu, log_std = output  # [B, A]
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        action = dist.sample().clamp(-1, 1)  # [B, A]
    else:
        logits = output  # [B, A]
        norm_dist = td.Categorical(logits=logits)
        action = norm_dist.sample()  # [B,]
    acts_info = Data(action=action)
    if self.use_rnn:
        acts_info.update(rnncs=self.rnncs)
    return action, acts_info
def select_action(self, obs):
    q_values = self.q_net(obs, rnncs=self.rnncs)  # [B, A]
    self.rnncs_ = self.q_net.get_rnncs()
    # Boltzmann policy: pi(a|s) proportional to exp((Q(s, a) - V(s)) / alpha)
    probs = ((q_values - self._get_v(q_values)) / self.alpha).exp()  # > 0, [B, A]
    probs /= probs.sum(-1, keepdim=True)  # [B, A]
    cate_dist = td.Categorical(probs=probs)
    actions = cate_dist.sample()  # [B,]
    return actions, Data(action=actions)
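# Note (a sketch, not from the original code): td.Categorical applies a softmax to its
# `logits` argument, so the explicit exp-and-normalize above could equivalently be
#     cate_dist = td.Categorical(logits=(q_values - self._get_v(q_values)) / self.alpha)
# which skips the intermediate exponentiation and is slightly more numerically stable.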
def random_action(self):
    if self.is_continuous:
        actions = np.random.uniform(-1.0, 1.0, (self.n_copies, self.a_dim))
    else:
        actions = np.random.randint(0, self.a_dim, self.n_copies)
    self._pre_act = actions
    self._acts_info = Data(action=actions)
    return actions
def select_action(self, obs):
    q_values = self.q_net(obs, rnncs=self.rnncs)  # [B, A]
    self.rnncs_ = self.q_net.get_rnncs()
    if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
        actions = np.random.randint(0, self.a_dim, self.n_copies)
    else:
        actions = q_values.argmax(-1)  # [B,]
    return actions, Data(action=actions)
def episode_step(self, obs: Data, env_rets: Data, begin_mask: np.ndarray):
    super().episode_step()
    if self._store:
        exps = Data(obs=obs,
                    reward=env_rets.reward[:, np.newaxis],  # [B,] => [B, 1]
                    obs_=env_rets.obs_fs,
                    done=env_rets.done[:, np.newaxis],  # [B,] => [B, 1]
                    begin_mask=begin_mask)
        exps.update(self._acts_info)
        self._buffer.add({self._agent_id: exps})
    idxs = np.where(env_rets.done)[0]
    self._pre_act[idxs] = 0.
    self.rnncs = self.rnncs_
    if self.rnncs is not None:
        for k in self.rnncs.keys():
            self.rnncs[k][idxs] = 0.
def learn(self, BATCH: Data):
    BATCH = self._preprocess_BATCH(BATCH)  # [T, B, *]
    for _ in range(self._epochs):
        for _BATCH in BATCH.sample(self._chunk_length,
                                   self.batch_size,
                                   repeat=self._sample_allow_repeat):
            _BATCH = self._before_train(_BATCH)
            summaries = self._train(_BATCH)
            self.summaries.update(summaries)
            self._after_train()
def select_action(self, obs):
    q = self.critic(obs, rnncs=self.rnncs)  # [B, A]
    self.rnncs_ = self.critic.get_rnncs()
    if self.use_epsilon and self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
        actions = np.random.randint(0, self.a_dim, self.n_copies)
    else:
        cate_dist = td.Categorical(logits=(q / self.alpha))
        mu = q.argmax(-1)  # [B,]
        actions = pi = cate_dist.sample()  # [B,]
    return actions, Data(action=actions)
def select_action(self, obs):
    feat = self.q_net(obs, rnncs=self.rnncs)  # [B, A, N]
    self.rnncs_ = self.q_net.get_rnncs()
    if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
        actions = np.random.randint(0, self.a_dim, self.n_copies)
    else:
        q = (self._z * feat).sum(-1)  # [B, A, N] * [N,] => [B, A]
        actions = q.argmax(-1)  # [B,]
    return actions, Data(action=actions)
def select_action(self, obs):
    if self._is_visual:
        obs = get_first_visual(obs)
    else:
        obs = get_first_vector(obs)
    # Compute the starting state for planning
    # while taking information from the current observation (posterior)
    embedded_obs = self.obs_encoder(obs)  # [B, *]
    state_posterior = self.rssm.posterior(self.rnncs['hx'], embedded_obs)  # dist, [B, *]
    # Initialize the action distribution
    mean = th.zeros((self.cem_horizon, 1, self.n_copies, self.a_dim))  # [H, 1, B, A]
    stddev = th.ones((self.cem_horizon, 1, self.n_copies, self.a_dim))  # [H, 1, B, A]
    # Iteratively improve the action distribution with CEM
    for itr in range(self.cem_iter_nums):
        action_candidates = mean + stddev * th.randn(self.cem_horizon,
                                                     self.cem_candidates,
                                                     self.n_copies,
                                                     self.a_dim)  # [H, N, B, A]
        action_candidates = action_candidates.reshape(self.cem_horizon, -1, self.a_dim)  # [H, N*B, A]
        # Initialize reward, state, and rnn hidden state
        # These are for parallel exploration
        total_predicted_reward = th.zeros((self.cem_candidates * self.n_copies, 1))  # [N*B, 1]
        state = state_posterior.sample((self.cem_candidates,))  # [N, B, *]
        state = state.view(-1, state.shape[-1])  # [N*B, *]
        rnn_hidden = self.rnncs['hx'].repeat((self.cem_candidates, 1))  # [B, *] => [N*B, *]
        # Compute total predicted reward by open-loop prediction using the prior
        for t in range(self.cem_horizon):
            next_state_prior, rnn_hidden = self.rssm.prior(state,
                                                           th.tanh(action_candidates[t]),
                                                           rnn_hidden)
            state = next_state_prior.sample()  # [N*B, *]
            post_feat = th.cat([state, rnn_hidden], -1)  # [N*B, *]
            total_predicted_reward += self.reward_predictor(post_feat).mean  # [N*B, 1]
        # Update the action distribution using the top-k samples
        total_predicted_reward = total_predicted_reward.view(self.cem_candidates, self.n_copies, 1)  # [N, B, 1]
        _, top_indexes = total_predicted_reward.topk(self.cem_tops, dim=0, largest=True, sorted=False)  # [N', B, 1]
        action_candidates = action_candidates.view(self.cem_horizon, self.cem_candidates, self.n_copies, -1)  # [H, N, B, A]
        top_action_candidates = action_candidates[:, top_indexes,
                                                  th.arange(self.n_copies).reshape(self.n_copies, 1),
                                                  th.arange(self.a_dim)]  # [H, N', B, A]
        mean = top_action_candidates.mean(dim=1, keepdim=True)  # [H, 1, B, A]
        stddev = top_action_candidates.std(dim=1, unbiased=False, keepdim=True)  # [H, 1, B, A]
    # Return only the first action (replan at every step based on the new observation)
    actions = th.tanh(mean[0].squeeze(0))  # [B, A]
    actions = self._exploration(actions)
    _, self.rnncs_['hx'] = self.rssm.prior(state_posterior.sample(), actions, self.rnncs['hx'])
    return actions, Data(action=actions)
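# Reading the planner above as plain CEM (a summary of the code, no new behavior):
# sample N candidate action sequences of length H from a Gaussian, roll each one out
# through the learned prior and accumulate the predicted rewards, keep the top-k
# sequences, refit the Gaussian mean/stddev to them, repeat for a fixed number of
# iterations, and finally execute only the first action of the resulting mean.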
def random_action(self):
    actions = {}
    self._acts_info = {}
    for id in self.agent_ids:
        if self.is_continuouss[id]:
            actions[id] = np.random.uniform(-1.0, 1.0, (self.n_copies, self.a_dims[id]))
        else:
            actions[id] = np.random.randint(0, self.a_dims[id], self.n_copies)
        self._acts_info[id] = Data(action=actions[id])
    self._pre_acts = actions
    return actions
def select_action(self, obs):
    # [B, P], [B, P, A], [B, P]
    (q, pi, beta) = self.net(obs, rnncs=self.rnncs)
    self.rnncs_ = self.net.get_rnncs()
    options_onehot = F.one_hot(self.options, self.options_num).float()  # [B, P]
    options_onehot_expanded = options_onehot.unsqueeze(-1)  # [B, P, 1]
    pi = (pi * options_onehot_expanded).sum(-2)  # [B, A]
    if self.is_continuous:
        mu = pi  # [B, A]
        log_std = self.log_std[self.options]  # [B, A]
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        action = dist.sample().clamp(-1, 1)  # [B, A]
        log_prob = dist.log_prob(action).unsqueeze(-1)  # [B, 1]
    else:
        logits = pi  # [B, A]
        norm_dist = td.Categorical(logits=logits)
        action = norm_dist.sample()  # [B,]
        log_prob = norm_dist.log_prob(action).unsqueeze(-1)  # [B, 1]
    value = q_o = (q * options_onehot).sum(-1, keepdim=True)  # [B, 1]
    beta_adv = q_o - ((1 - self.eps) * q.max(-1, keepdim=True)[0]
                      + self.eps * q.mean(-1, keepdim=True))  # [B, 1]
    max_options = q.argmax(-1)  # [B, P] => [B,]
    beta_probs = (beta * options_onehot).sum(-1)  # [B, P] => [B,]
    beta_dist = td.Bernoulli(probs=beta_probs)
    # termination sample < 1: keep the current option; == 1: switch to the greedy option
    new_options = th.where(beta_dist.sample() < 1, self.options, max_options)
    self.new_options = th.where(self._done_mask, max_options, new_options)
    self.oc_mask = (self.new_options == self.options).float()
    acts_info = Data(action=action,
                     value=value,
                     log_prob=log_prob + th.finfo().eps,
                     beta_advantage=beta_adv + self.dc,
                     last_options=self.options,
                     options=self.new_options,
                     reward_offset=-((1 - self.oc_mask) * self.dc).unsqueeze(-1))
    if self.use_rnn:
        acts_info.update(rnncs=self.rnncs)
    return action, acts_info
def learn(self, BATCH: Data):
    BATCH = self._preprocess_BATCH(BATCH)  # [T, B, *]
    for _ in range(self._epochs):
        kls = []
        for _BATCH in BATCH.sample(self._chunk_length,
                                   self.batch_size,
                                   repeat=self._sample_allow_repeat):
            _BATCH = self._before_train(_BATCH)
            summaries, kl = self._train(_BATCH)
            kls.append(kl)
            self.summaries.update(summaries)
            self._after_train()
        if self._use_early_stop and sum(kls) / len(kls) > self._kl_stop:
            break
def select_action(self, obs):
    q_values = self.q_net(obs, rnncs=self.rnncs)  # [B, *]
    self.rnncs_ = self.q_net.get_rnncs()
    if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
        actions = np.random.randint(0, self.a_dim, self.n_copies)
    else:
        for i in range(self.target_k):
            target_q_values = self.target_nets[i](obs, rnncs=self.rnncs)
            q_values += target_q_values
        # averaging over the networks is unnecessary: the argmax is unaffected  [B,]
        actions = q_values.argmax(-1)
    return actions, Data(action=actions)
def select_action(self, obs):
    if self.is_continuous:
        mu, log_std = self.actor(obs, rnncs=self.rnncs)  # [B, A]
        pi = td.Normal(mu, log_std.exp()).sample().tanh()  # [B, A]
        mu.tanh_()  # squash mu, [B, A]
    else:
        logits = self.actor(obs, rnncs=self.rnncs)  # [B, A]
        mu = logits.argmax(-1)  # [B,]
        cate_dist = td.Categorical(logits=logits)
        pi = cate_dist.sample()  # [B,]
    self.rnncs_ = self.actor.get_rnncs()
    actions = pi if self._is_train_mode else mu
    return actions, Data(action=actions)
def select_action(self, obs):
    output = self.actor(obs, rnncs=self.rnncs)  # [B, A]
    self.rnncs_ = self.actor.get_rnncs()
    if self.is_continuous:
        mu = output  # [B, A]
        pi = self.noised_action(mu)  # [B, A]
    else:
        logits = output  # [B, A]
        mu = logits.argmax(-1)  # [B,]
        cate_dist = td.Categorical(logits=logits)
        pi = cate_dist.sample()  # [B,]
    actions = pi if self._is_train_mode else mu
    return actions, Data(action=actions)
def select_action(self, obs):
    output = self.actor(obs, rnncs=self.rnncs)  # [B, *]
    self.rnncs_ = self.actor.get_rnncs()
    if self.is_continuous:
        mu, log_std = output  # [B, *]
        # reinterpret the action dimensions as a single event so log_prob sums over them
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        action = dist.sample().clamp(-1, 1)  # [B, *]
        log_prob = dist.log_prob(action)  # [B,]
    else:
        logits = output  # [B, *]
        norm_dist = td.Categorical(logits=logits)
        action = norm_dist.sample()  # [B,]
        log_prob = norm_dist.log_prob(action)  # [B,]
    return action, Data(action=action, log_prob=log_prob)
def select_action(self, obs):
    _, select_quantiles_tiled = self._generate_quantiles(  # [N*B, X]
        batch_size=self.n_copies,
        quantiles_num=self.select_quantiles)
    q_values = self.q_net(obs, select_quantiles_tiled, rnncs=self.rnncs)  # [N, B, A]
    self.rnncs_ = self.q_net.get_rnncs()
    if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
        actions = np.random.randint(0, self.a_dim, self.n_copies)
    else:
        actions = q_values.mean(0).argmax(-1)  # [N, B, A] => [B, A] => [B,]
    return actions, Data(action=actions)
def select_action(self, obs):
    acts_info = {}
    actions = {}
    for aid, mid in zip(self.agent_ids, self.model_ids):
        q_values = self.q_nets[mid](obs[aid], rnncs=self.rnncs[aid])  # [B, A]
        self.rnncs_[aid] = self.q_nets[mid].get_rnncs()
        if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
            action = np.random.randint(0, self.a_dims[aid], self.n_copies)
        else:
            action = q_values.argmax(-1)  # [B,]
        actions[aid] = action
        acts_info[aid] = Data(action=action)
    return actions, acts_info
def select_action(self, obs):
    if self._is_visual:
        obs = get_first_visual(obs)
    else:
        obs = get_first_vector(obs)
    embedded_obs = self.obs_encoder(obs)  # [B, *]
    state_posterior = self.rssm.posterior(self.rnncs['hx'], embedded_obs)
    state = state_posterior.sample()  # [B, *]
    actions = self.actor.sample_actions(th.cat((state, self.rnncs['hx']), -1),
                                        is_train=self._is_train_mode)
    actions = self._exploration(actions)
    _, self.rnncs_['hx'] = self.rssm.prior(state, actions, self.rnncs['hx'])
    if not self.is_continuous:
        actions = actions.argmax(-1)  # [B,]
    return actions, Data(action=actions)
def select_action(self, obs: Dict):
    acts_info = {}
    actions = {}
    for aid, mid in zip(self.agent_ids, self.model_ids):
        output = self.actors[mid](obs[aid], rnncs=self.rnncs[aid])  # [B, A]
        self.rnncs_[aid] = self.actors[mid].get_rnncs()
        if self.is_continuouss[aid]:
            mu = output  # [B, A]
            pi = self.noised_actions[mid](mu)  # [B, A]
        else:
            logits = output  # [B, A]
            mu = logits.argmax(-1)  # [B,]
            cate_dist = td.Categorical(logits=logits)
            pi = cate_dist.sample()  # [B,]
        action = pi if self._is_train_mode else mu
        acts_info[aid] = Data(action=action)
        actions[aid] = action
    return actions, acts_info
def select_action(self, obs):
    q = self.q_net(obs, rnncs=self.rnncs)  # [B, P]
    self.rnncs_ = self.q_net.get_rnncs()
    pi = self.intra_option_net(obs, rnncs=self.rnncs)  # [B, P, A]
    options_onehot = F.one_hot(self.options, self.options_num).float()  # [B, P]
    options_onehot_expanded = options_onehot.unsqueeze(-1)  # [B, P, 1]
    pi = (pi * options_onehot_expanded).sum(-2)  # [B, A]
    if self.is_continuous:
        mu = pi.tanh()  # [B, A]
        log_std = self.log_std[self.options]  # [B, A]
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        actions = dist.sample().clamp(-1, 1)  # [B, A]
    else:
        pi = pi / self.boltzmann_temperature  # [B, A]
        dist = td.Categorical(logits=pi)
        actions = dist.sample()  # [B,]
    interests = self.interest_net(obs, rnncs=self.rnncs)  # [B, P]
    op_logits = interests * q  # [B, P], or alternatively interests * q.softmax(-1)
    self.new_options = td.Categorical(logits=op_logits).sample()  # [B,]
    return actions, Data(action=actions,
                         last_options=self.options,
                         options=self.new_options)
def select_action(self, obs):
    if self.is_continuous:
        _actions = []
        for _ in range(self._select_samples):
            _actions.append(self.actor(obs, self.vae.decode(obs), rnncs=self.rnncs))  # [B, A]
            self.rnncs_ = self.actor.get_rnncs()  # TODO: calculate corrected hidden state
        _actions = th.stack(_actions, dim=0)  # [N, B, A]
        q1s = []
        for i in range(self._select_samples):
            q1s.append(self.critic(obs, _actions[i])[0])
        q1s = th.stack(q1s, dim=0)  # [N, B, 1]
        max_idxs = q1s.argmax(dim=0, keepdim=True)[-1]  # [1, B, 1]
        actions = _actions[max_idxs,
                           th.arange(self.n_copies).reshape(self.n_copies, 1),
                           th.arange(self.a_dim)]
    else:
        q_values, i_values = self.q_net(obs, rnncs=self.rnncs)  # [B, *]
        q_values = q_values - q_values.min(dim=-1, keepdim=True)[0]  # [B, *]
        i_values = F.log_softmax(i_values, dim=-1)  # [B, *]
        i_values = i_values.exp()  # [B, *]
        i_values = (i_values / i_values.max(-1, keepdim=True)[0] > self._threshold).float()  # [B, *]
        self.rnncs_ = self.q_net.get_rnncs()
        if self._is_train_mode and self.expl_expt_mng.is_random(self._cur_train_step):
            actions = np.random.randint(0, self.a_dim, self.n_copies)
        else:
            actions = (i_values * q_values).argmax(-1)  # [B,]
    return actions, Data(action=actions)
def step(self, actions: Dict[str, np.ndarray], **kwargs) -> Dict[str, Data]:
    # take the single agent's actions
    actions = deepcopy(actions['single'])
    params = []
    for i in range(self._n_copies):
        params.append(dict(args=(actions[i],)))
    rets = self._envs.run('step', params)
    obs_fs, reward, done, info = zip(*rets)
    obs_fs = np.stack(obs_fs, 0)
    reward = np.stack(reward, 0)
    done = np.stack(done, 0)
    # TODO: info
    obs_fa = deepcopy(obs_fs)  # obs used for choosing the next action
    idxs = np.where(done)[0]
    if len(idxs) > 0:
        reset_obs = self._envs.run('reset', idxs=idxs)
        obs_fa[idxs] = np.stack(reset_obs, 0)
    if self._use_visual:
        obs_fs = Data(visual={'visual_0': obs_fs})
        obs_fa = Data(visual={'visual_0': obs_fa})
    else:
        obs_fs = Data(vector={'vector_0': obs_fs})
        obs_fa = Data(vector={'vector_0': obs_fa})
    return {
        'single': Data(obs_fs=obs_fs,
                       obs_fa=obs_fa,
                       reward=reward,
                       done=done,
                       info=info),
        'global': Data(begin_mask=done[:, np.newaxis])
    }
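# A minimal rollout sketch (an assumption for illustration; `env`, `agent`, `max_steps`,
# and the exact wiring are not part of the original code and may differ in the real
# trainer). It only shows how the pieces above compose:
# reset -> select_action -> step -> episode_step.
rets = env.reset()                                 # {'single': Data, 'global': Data}
obs = rets['single']
for _ in range(max_steps):
    action, acts_info = agent.select_action(obs)   # acts_info is cached for storage
    env_rets = env.step({'single': action})        # {'single': Data, 'global': Data}
    agent.episode_step(obs, env_rets['single'], env_rets['global'].begin_mask)
    obs = env_rets['single'].obs_fa                # observation used to pick the next action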