def compute_value(self, vf_in, h_critic=None, target=False, truncate_steps=-1):
    """ Training-time critic forward pass.
    Arguments:
        vf_in: critic input, (B,T,K)
        h_critic: optional initial critic hidden state, (B,H)
        target: if True, use the target critic network
        truncate_steps: number of BPTT steps after which to truncate, -1 for no truncation
    Returns:
        q: (B*T,1)
    """
    bs, ts, _ = vf_in.shape
    critic = self.target_critic if target else self.critic

    if self.rnn_critic:
        if h_critic is None:
            h_t = self.critic_hidden_states.clone()  # (B,H)
        else:
            h_t = h_critic

        # rollout: unroll the recurrent critic over the sequence,
        # truncating backprop-through-time after `truncate_steps` steps
        q = rnn_forward_sequence(
            critic, vf_in, h_t, truncate_steps=truncate_steps)  # list of (B,1), length T
        q = torch.stack(q, 0).permute(1, 0, 2)  # (T,B,1) -> (B,T,1)
        q = q.reshape(bs * ts, -1)  # (B*T,1)
    else:
        # feed-forward critic: (B,T,K) -> (B*T,1)
        q, _ = critic(vf_in.reshape(bs * ts, -1))
    return q
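# A minimal sketch of the `rnn_forward_sequence` helper used above. Its actual
# implementation lives elsewhere in this repo; the signature, return type, and the
# truncation behavior shown here are inferred from how it is called, so treat this
# as an assumption, not the repo's code. It unrolls a recurrent module over the time
# axis and returns a list of per-step outputs, detaching the hidden state every
# `truncate_steps` steps to cut the BPTT graph. `torch` and `defaultdict` are assumed
# to be imported at module level.
#
#   def rnn_forward_sequence(model, seq_in, h0, truncate_steps=-1):
#       outputs, h_t = [], h0
#       for t in range(seq_in.shape[1]):
#           if truncate_steps > 0 and t > 0 and t % truncate_steps == 0:
#               h_t = h_t.detach()            # stop gradients flowing further back in time
#           out_t, h_t = model(seq_in[:, t], h_t)
#           outputs.append(out_t)             # (B,1) for the critic, dict of (B,A) for policies
#       return outputs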
def compute_action(self, obs, h_actor=None, target=False, requires_grad=True, truncate_steps=-1):
    """ Training-time actor forward pass with the specified policy.
    The returned per-key actions can be concatenated and fed into the critics.
    Arguments:
        obs: observations, (B,T,O)
        h_actor: optional initial policy hidden state, (B,H)
        target: if True, use the target policy network
        requires_grad: if True, use _soft_act to keep discrete actions differentiable
        truncate_steps: number of BPTT steps after which to truncate, -1 for no truncation
    Returns:
        act: dict of (B,T,A)
    """
    bs, ts, _ = obs.shape
    pi = self.target_policy if target else self.policy

    if self.rnn_policy:
        if h_actor is None:
            h_t = self.policy_hidden_states.clone()  # (B,H)
        else:
            h_t = h_actor

        # rollout: unroll the recurrent policy, one dict of (B,A) logits per step
        seq_logits = rnn_forward_sequence(pi, obs, h_t, truncate_steps=truncate_steps)

        # soften the deterministic output so gradients can flow through discrete actions
        act = defaultdict(list)
        for act_t in seq_logits:
            for k, a in act_t.items():
                act[k].append(self._soft_act(a, requires_grad))
        act = {
            k: torch.stack(ac, 0).permute(1, 0, 2)
            for k, ac in act.items()
        }  # dict of [(B,A)]*T -> dict of (B,T,A)
    else:
        stacked_obs = obs.reshape(bs * ts, -1)  # (B*T,O)
        act, _ = pi(stacked_obs)  # act is dict of (B*T,A)
        act = {
            k: self._soft_act(ac, requires_grad).reshape(bs, ts, -1)
            for k, ac in act.items()
        }  # dict of (B,T,A)
    return act
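# `_soft_act` is defined elsewhere in this class. A common pattern for such a helper
# (shown purely as an illustrative assumption, not this repo's actual code) is a
# straight-through softening: the forward value is a hard one-hot action, while the
# gradient flows through the softmax probabilities when `requires_grad` is True.
# `F` stands for torch.nn.functional and `onehot_from_logits` is a hypothetical helper.
#
#   def _soft_act(self, logits, requires_grad=True):   # logits: (B,A)
#       hard = onehot_from_logits(logits)               # (B,A) one-hot
#       if not requires_grad:
#           return hard
#       probs = F.softmax(logits, dim=-1)
#       return (hard - probs).detach() + probs          # straight-through estimator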
def compute_moa_action(self, agent_j, obs, h_actor=None, target=False, requires_grad=True, truncate_steps=-1, return_logits=True):
    """ Training-time forward pass of the model-of-other-agents (MOA) policy for agent j.
    The returned per-key actions can be concatenated and fed into the critics.
    Arguments:
        agent_j: index of the modeled agent
        obs: observations, (B,T,O)
        h_actor: optional initial policy hidden state, (B,H)
        target: if True, use the target MOA policy network
        requires_grad: if True, use reparameterized (gumbel-softmax) action selection
        truncate_steps: number of BPTT steps after which to truncate, -1 for no truncation
        return_logits: if True, return raw logits instead of selected one-hot actions
    Returns:
        act: dict of (B,T,A)
    """
    bs, ts, _ = obs.shape
    if target:
        pi = self.moa_target_policies[agent_j]
    else:
        pi = self.moa_policies[agent_j]

    if self.rnn_policy:
        if h_actor is None:
            h_t = self.policy_hidden_states.clone()  # (B,H)
        else:
            h_t = h_actor

        # rollout: unroll the recurrent MOA policy, one dict of (B,A) logits per step
        seq_logits = rnn_forward_sequence(pi, obs, h_t, truncate_steps=truncate_steps)

        act = defaultdict(list)
        for act_t in seq_logits:
            for k, logits in act_t.items():
                if return_logits:
                    action = logits
                else:
                    # if requires_grad, use gumbel-softmax (as in exploration) with a low temperature
                    action, _ = self.selector.select_action(
                        logits, explore=False, hard=True,
                        reparameterize=requires_grad, temperature=0.5)
                act[k].append(action)
        act = {
            k: torch.stack(ac, 0).permute(1, 0, 2)
            for k, ac in act.items()
        }  # dict of [(B,A)]*T -> dict of (B,T,A)
    else:
        stacked_obs = obs.reshape(bs * ts, -1)  # (B*T,O)
        act, _ = pi(stacked_obs)  # act is dict of (B*T,A)
        for k, logits in act.items():
            if return_logits:
                action = logits
            else:
                # if requires_grad, use gumbel-softmax (as in exploration) with a low temperature
                action, _ = self.selector.select_action(
                    logits, explore=False, hard=True,
                    reparameterize=requires_grad, temperature=0.5)
            act[k] = action.reshape(bs, ts, -1)  # (B,T,A)
    return act
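# Illustrative usage (hedged: `agent`, `obs_batch`, and the agent index are hypothetical,
# not taken from this repo): predict another agent's actions with the MOA policy, either
# as raw logits for a likelihood/KL term or as one-hot actions for the critic input.
#
#   pred_logits = agent.compute_moa_action(agent_j=1, obs=obs_batch,
#                                          return_logits=True)    # dict of (B,T,A) logits
#   pred_acts = agent.compute_moa_action(agent_j=1, obs=obs_batch,
#                                        requires_grad=False,
#                                        return_logits=False)     # dict of (B,T,A) one-hot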
def evaluate_moa_action(self, agent_j, act_samples, obs, h_actor=None, requires_grad=True, contract_keys=None, truncate_steps=-1):
    """ Evaluate sampled actions under the model-of-other-agents (MOA) policy for agent j.
    Arguments:
        agent_j: index of the modeled agent
        act_samples: dict of (B,T,A), sampled actions to evaluate
        obs: observations, (B,T,O)
        h_actor: optional initial policy hidden state, (B,H)
        requires_grad: kept for interface consistency (not used here)
        contract_keys: list of action keys to contract the dict on,
            i.e. log_prob and entropy are summed over these keys
        truncate_steps: number of BPTT steps after which to truncate, -1 for no truncation
    Returns:
        log_prob: action log probs, (B,T,1)
        entropy: action entropy, (B,T,1)
    """
    bs, ts, _ = obs.shape
    pi = self.moa_policies[agent_j]

    # get logits from the current MOA policy
    if self.rnn_policy:
        if h_actor is None:
            h_t = self.moa_hidden_states[agent_j].clone()  # (B,H)
        else:
            h_t = h_actor

        # rollout: unroll the recurrent MOA policy, one dict of (B,A) logits per step
        seq_logits = rnn_forward_sequence(pi, obs, h_t, truncate_steps=truncate_steps)

        # collect per-step logits
        act = defaultdict(list)
        for act_t in seq_logits:
            for k, logits in act_t.items():
                act[k].append(logits)
        act = {
            k: torch.stack(ac, 0).permute(1, 0, 2)
            for k, ac in act.items()
        }  # dict of [(B,A)]*T -> dict of (B,T,A)
    else:
        stacked_obs = obs.reshape(bs * ts, -1)  # (B*T,O)
        act, _ = pi(stacked_obs)  # act is dict of (B*T,A)
        act = {k: ac.reshape(bs, ts, -1) for k, ac in act.items()}  # dict of (B,T,A)

    if contract_keys is None:
        contract_keys = sorted(list(act.keys()))
    log_prob, entropy = 0.0, 0.0

    for k, seq_logits in act.items():
        if k not in contract_keys:
            continue
        action = act_samples[k]
        _, dist = self.selector.select_action(
            seq_logits, explore=False, hard=False, reparameterize=False)
        # evaluate log prob on a detached copy of the sampled action, (B,T) -> (B,T,1)
        # NOTE: taking log_prob of a reparameterized (rsample) action would backprop through it twice
        log_prob += dist.log_prob(action.clone().detach()).unsqueeze(-1)
        # current action distribution entropy
        entropy += dist.entropy().unsqueeze(-1)
    return log_prob, entropy
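# Illustrative usage (hedged sketch; variable names, the agent index, and the loss form
# are assumptions, not taken from this repo): fit the MOA policy to observed actions by
# maximizing their log-likelihood, optionally with a small entropy bonus.
#
#   log_prob, entropy = agent.evaluate_moa_action(agent_j=1,
#                                                 act_samples=other_acts,  # dict of (B,T,A)
#                                                 obs=obs_batch)           # (B,T,O)
#   moa_loss = -(log_prob + 1e-3 * entropy).mean()
#   moa_loss.backward()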