Example #1
    def compute_value(self, vf_in, h_critic=None, target=False, truncate_steps=-1):
        """ training critic forward with specified policy 
        Arguments:
            vf_in: (B,T,K)
            target: if use target network
            truncate_steps: number of BPTT steps to truncate if used
        Returns:
            q: (B*T,1)
        """
        bs, ts, _ = vf_in.shape
        critic = self.target_critic if target else self.critic

        if self.rnn_critic:
            if h_critic is None:
                h_t = self.critic_hidden_states.clone()  # stored hidden state (B,H)
            else:
                h_t = h_critic  # caller-provided hidden state (B,H)

            # roll the critic over the sequence step by step;
            # rnn_forward_sequence returns a list of per-step values,
            # equivalent to q_t, h_t = critic(vf_in[:, t], h_t) for each t
            q = rnn_forward_sequence(
                critic, vf_in, h_t, truncate_steps=truncate_steps)
            q = torch.stack(q, 0).permute(1, 0, 2)  # [(B,1)]*T -> (T,B,1) -> (B,T,1)
            q = q.reshape(bs * ts, -1)  # (B*T,1)
        else:
            # (B,T,K) -> (B*T,1)
            q, _ = critic(vf_in.reshape(bs*ts, -1))
        return q 
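
The helper rnn_forward_sequence used above is not shown in these examples. Below is a minimal sketch of what it might look like, assuming the model returns (output, next_hidden) at each step and that truncate_steps limits backpropagation-through-time to the last truncate_steps transitions; the truncation scheme and argument semantics here are assumptions, not taken from the source.

def rnn_forward_sequence(model, seq_in, h_0, truncate_steps=-1):
    """Hypothetical re-implementation: roll an RNN model over a (B,T,*) batch.

    Returns a list with one output per time step (a tensor or a dict,
    whatever the model produces). When truncate_steps > 0, the hidden state
    is detached at every step that lies more than truncate_steps before the
    end of the sequence, so gradients only flow through the last
    truncate_steps transitions (truncated BPTT).
    """
    outputs = []
    h_t = h_0
    ts = seq_in.shape[1]
    for t in range(ts):
        if truncate_steps > 0 and t < ts - truncate_steps:
            h_t = h_t.detach()  # cut the gradient path outside the BPTT window
        out_t, h_t = model(seq_in[:, t], h_t)
        outputs.append(out_t)
    return outputs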
Example #2
    def compute_action(self,
                       obs,
                       h_actor=None,
                       target=False,
                       requires_grad=True,
                       truncate_steps=-1):
        """ traininsg actor forward with specified policy 
        concat all actions to be fed in critics
        Arguments:
            obs: (B,T,O)
            target: if use target network
            requires_grad: if use _soft_act to differentiate discrete action
        Returns:
            act: dict of (B,T,A) 
        """
        bs, ts, _ = obs.shape
        pi = self.target_policy if target else self.policy

        if self.rnn_policy:
            if h_actor is None:
                h_t = self.policy_hidden_states.clone()  # stored hidden state (B,H)
            else:
                h_t = h_actor  # caller-provided hidden state (B,H)

            # roll the policy over the sequence step by step;
            # rnn_forward_sequence returns a list of per-step action dicts,
            # equivalent to act_t, h_t = pi(obs[:, t], h_t) for each t
            seq_logits = rnn_forward_sequence(pi,
                                              obs,
                                              h_t,
                                              truncate_steps=truncate_steps)

            # soften deterministic output for backprop
            act = defaultdict(list)
            for act_t in seq_logits:
                for k, a in act_t.items():
                    act[k].append(self._soft_act(a, requires_grad))
            act = {
                k: torch.stack(ac, 0).permute(1, 0, 2)
                for k, ac in act.items()
            }  # dict [(B,A)]*T -> dict (B,T,A)
        else:
            stacked_obs = obs.reshape(bs * ts, -1)  # (B*T,O)
            act, _ = pi(stacked_obs)  # act is dict of (B*T,A)
            act = {
                k: self._soft_act(ac, requires_grad).reshape(bs, ts, -1)
                for k, ac in act.items()
            }  # dict of (B,T,A)
        return act
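
The _soft_act helper referenced above is also not included in these snippets. A plausible stand-alone sketch is given below, assuming discrete action heads whose policy outputs are logits: when gradients are required it uses the straight-through Gumbel-Softmax estimator, otherwise a plain greedy one-hot. The function name, temperature default, and the assumption of discrete actions are illustrative only.

import torch
import torch.nn.functional as F

def soft_act(logits, requires_grad=True, temperature=1.0):
    """Hypothetical equivalent of self._soft_act for discrete action heads."""
    if requires_grad:
        # straight-through Gumbel-Softmax: hard one-hot in the forward pass,
        # differentiable soft probabilities in the backward pass
        return F.gumbel_softmax(logits, tau=temperature, hard=True)
    # no gradient needed: plain greedy one-hot from the logits
    idx = logits.argmax(dim=-1, keepdim=True)
    return torch.zeros_like(logits).scatter_(-1, idx, 1.0)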
Example #3
    def compute_moa_action(self,
                           agent_j,
                           obs,
                           h_actor=None,
                           target=False,
                           requires_grad=True,
                           truncate_steps=-1,
                           return_logits=True):
        """ traininsg actor forward with specified policy 
        concat all actions to be fed in critics
        Arguments:
            obs: (B,T,O)
            target: if use target network
            requires_grad: if use _soft_act to differentiate discrete action
        Returns:
            act: dict of (B,T,A) 
        """
        bs, ts, _ = obs.shape
        if target:
            pi = self.moa_target_policies[agent_j]
        else:
            pi = self.moa_policies[agent_j]

        if self.rnn_policy:
            if h_actor is None:
                h_t = self.policy_hidden_states.clone()  # stored hidden state (B,H)
            else:
                h_t = h_actor  # caller-provided hidden state (B,H)

            # rollout
            seq_logits = rnn_forward_sequence(pi,
                                              obs,
                                              h_t,
                                              truncate_steps=truncate_steps)

            # per time step: either keep the raw logits or select an action from them
            act = defaultdict(list)
            for act_t in seq_logits:
                for k, logits in act_t.items():
                    if return_logits:
                        action = logits
                    else:
                        # if requires_grad, reparameterize via gumbel-softmax,
                        # as in exploration but with a low temperature
                        action, _ = self.selector.select_action(
                            logits,
                            explore=False,
                            hard=True,
                            reparameterize=requires_grad,
                            temperature=0.5)
                    act[k].append(action)
            act = {
                k: torch.stack(ac, 0).permute(1, 0, 2)
                for k, ac in act.items()
            }  # dict [(B,A)]*T -> dict (B,T,A)
        else:
            stacked_obs = obs.reshape(bs * ts, -1)  # (B*T,O)
            act, _ = pi(stacked_obs)  # act is dict of (B*T,A)
            # select per-head actions (or keep raw logits) and reshape back to (B,T,A)
            for k, logits in act.items():
                if return_logits:
                    action = logits
                else:
                    action, _ = self.selector.select_action(
                        logits,
                        explore=False,
                        hard=True,
                        reparameterize=requires_grad,
                        temperature=0.5)
                act[k] = action.reshape(bs, ts, -1)
        return act
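
Both compute_moa_action and evaluate_moa_action rely on self.selector.select_action, which is not included in these snippets. The sketch below is a hypothetical categorical selector with the same call signature, returning an action together with a distribution that exposes log_prob and entropy as used later; the class name and internals are assumptions.

import torch
import torch.nn.functional as F
from torch.distributions import OneHotCategorical

class CategoricalSelector:
    """Hypothetical stand-in for self.selector over one-hot discrete actions."""

    def select_action(self, logits, explore=False, hard=True,
                      reparameterize=False, temperature=1.0):
        dist = OneHotCategorical(logits=logits)
        if explore:
            action = dist.sample()  # stochastic sample, no gradient
        elif reparameterize:
            # differentiable sample via straight-through Gumbel-Softmax
            action = F.gumbel_softmax(logits, tau=temperature, hard=hard)
        else:
            # deterministic greedy one-hot
            idx = logits.argmax(dim=-1, keepdim=True)
            action = torch.zeros_like(logits).scatter_(-1, idx, 1.0)
        return action, dist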
Example #4
    def evaluate_moa_action(self,
                            agent_j,
                            act_samples,
                            obs,
                            h_actor=None,
                            requires_grad=True,
                            contract_keys=None,
                            truncate_steps=-1):
        """ traininsg actor forward with specified policy 
        concat all actions to be fed in critics
        Arguments:
            agent_j: use j-th moa agent 
            act_samples: dict of (B,T,A), actions in sample
            obs: (B,T,O)
            requires_grad: if use _soft_act to differentiate discrete action
            contract_keys: 
                list of keys to contract dict on
                i.e. sum up all fields in log_prob, entropy, kl on given keys
        Returns:
            log_prob: action log probs (B,T,1)
            entropy: action entropy (B,T,1)
        """
        bs, ts, _ = obs.shape
        pi = self.moa_policies[agent_j]

        # get logits for current policy
        if self.rnn_policy:
            if h_actor is None:
                h_t = self.moa_hidden_states[agent_j].clone()  # stored hidden state (B,H)
            else:
                h_t = h_actor  # caller-provided hidden state (B,H)

            # rollout
            seq_logits = rnn_forward_sequence(pi,
                                              obs,
                                              h_t,
                                              truncate_steps=truncate_steps)

            # collect per-step logits for each action head
            act = defaultdict(list)
            for act_t in seq_logits:
                for k, logits in act_t.items():
                    act[k].append(logits)
            act = {
                k: torch.stack(ac, 0).permute(1, 0, 2)
                for k, ac in act.items()
            }  # dict [(B,A)]*T -> dict (B,T,A)
        else:
            stacked_obs = obs.reshape(bs * ts, -1)  # (B*T,O)
            act, _ = pi(stacked_obs)  # act is dict of (B*T,A)
            act = {k: ac.reshape(bs, ts, -1)
                   for k, ac in act.items()}  # dict of (B,T,A)

        if contract_keys is None:
            contract_keys = sorted(list(act.keys()))
        log_prob, entropy = 0.0, 0.0

        for k, seq_logits in act.items():
            if k not in contract_keys:
                continue
            action = act_samples[k]
            _, dist = self.selector.select_action(seq_logits,
                                                  explore=False,
                                                  hard=False,
                                                  reparameterize=False)
            # evaluate log prob (B,T) -> (B,T,1)
            # NOTE: detach the sampled action here; taking log_prob of an
            # rsample'd action would backprop through the sample a second time
            log_prob += dist.log_prob(action.clone().detach()).unsqueeze(-1)
            # get current action distrib entropy
            entropy += dist.entropy().unsqueeze(-1)

        return log_prob, entropy
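
As a usage sketch, the log probabilities and entropies returned by evaluate_moa_action can drive a maximum-likelihood loss that fits the moa policy to the actions other agents actually took. All names below (agent, j, the batch tensors, the optimizer, and the entropy coefficient) are illustrative assumptions, not taken from the source.

# fit the moa policy for agent j on a sampled trajectory batch
log_prob, entropy = agent.evaluate_moa_action(
    agent_j=j,
    act_samples=batch_act_j,   # dict of (B,T,A) one-hot actions from the buffer
    obs=batch_obs_j)           # (B,T,O) observations available to agent j

# negative log-likelihood with a small entropy bonus (coefficient is illustrative)
moa_loss = -log_prob.mean() - 1e-3 * entropy.mean()
moa_opt.zero_grad()
moa_loss.backward()
moa_opt.step()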