Ejemplo n.º 1
0
 def act(self, state):
     '''Pick an action for `state` and cache it on the body.

     When `self.normalize_state` is set, the state is first passed through
     `policy_util.update_online_stats_and_normalize_state`. The action tensor
     and its distribution come from `self.action_policy` and are stored on the
     body as `action_tensor` / `action_pd` (used for body.action_pd_update later).

     Returns a Python scalar (cast through `body.action_space.dtype`) when the
     action is 0-dimensional, otherwise the action as a numpy array.
     '''
     agent_body = self.body
     if self.normalize_state:
         state = policy_util.update_online_stats_and_normalize_state(agent_body, state)
     action_tensor, pd = self.action_policy(state, self, agent_body)
     agent_body.action_tensor, agent_body.action_pd = action_tensor, pd
     action_np = action_tensor.cpu().numpy()
     if action_tensor.shape:  # at least one dimension -> hand back the array
         return action_np
     # 0-dim action: unwrap to a plain Python scalar of the action-space dtype
     return action_np.astype(agent_body.action_space.dtype).item()
Ejemplo n.º 2
0
 def act(self, state):
     '''Pick an action for `state`, recording its entropy and log-probability.

     When `self.normalize_state` is set, the state is first passed through
     `policy_util.update_online_stats_and_normalize_state`. The distribution's
     entropy and the action's log-probability, each summed over dim 0, are
     appended to `body.entropies` and `body.log_probs`.

     Returns a Python scalar (cast through `body.action_space.dtype`) when the
     action is 0-dimensional, otherwise the action as a numpy array.
     Raises AssertionError if the recorded log-probability is NaN.
     '''
     agent_body = self.body
     if self.normalize_state:
         state = policy_util.update_online_stats_and_normalize_state(agent_body, state)
     action_tensor, pd = self.action_policy(state, self, agent_body)
     # summing over dim 0 handles both single- and multi-action spaces
     agent_body.entropies.append(pd.entropy().sum(dim=0))
     log_prob = pd.log_prob(action_tensor.float()).sum(dim=0)
     agent_body.log_probs.append(log_prob)
     assert not torch.isnan(log_prob)
     action_np = action_tensor.cpu().numpy()
     if action_tensor.shape:  # at least one dimension -> hand back the array
         return action_np
     # 0-dim action: unwrap to a plain Python scalar of the action-space dtype
     return action_np.astype(agent_body.action_space.dtype).item()
Ejemplo n.º 3
0
 def space_act(self, state_a):
     '''Act for every body in a single pass over `state_a`.

     Overrides agent.act() instead of composing per-body act() calls. The state
     of each body yielded by `util.ndenumerate_nonan(self.agent.body_a)` is
     gathered and, when `self.normalize_state` is set, normalised via
     `policy_util.update_online_stats_and_normalize_state`. The states go to
     `self.calc_pdparam` as a list of float tensors. The multi-policy
     `self.action_policy` then receives all states, all bodies and the
     pdparam. Each body gets its own `action_tensor` / `action_pd`
     (used for body.action_pd_update later).

     Returns the actions as a numpy array.
     '''
     states = []
     for eb, body in util.ndenumerate_nonan(self.agent.body_a):
         body_state = state_a[eb]
         if self.normalize_state:
             body_state = policy_util.update_online_stats_and_normalize_state(body, body_state)
         states.append(body_state)
     pdparam = self.calc_pdparam(
         [torch.from_numpy(s).float() for s in states], evaluate=False)
     # multi-policy call: note the extra pdparam argument and the list of bodies
     action_a, action_pd_a = self.action_policy(
         states, self, self.agent.nanflat_body_a, pdparam)
     for idx, body in enumerate(self.agent.nanflat_body_a):
         body.action_tensor, body.action_pd = action_a[idx], action_pd_a[idx]
     return action_a.cpu().numpy()
Ejemplo n.º 4
0
 def space_act(self, state_a):
     '''Act for every body in a single pass over `state_a`, recording entropies and log-probs.

     Overrides agent.act() instead of composing per-body act() calls. The state
     of each body yielded by `util.ndenumerate_nonan(self.agent.body_a)` is
     gathered and, when `self.normalize_state` is set, normalised via
     `policy_util.update_online_stats_and_normalize_state`. All states are
     flattened, concatenated and converted to float32 on `self.net.device`,
     giving one row of shape (1, N), where N is the total number of state
     elements. That row is fed to `self.calc_pdparam`. The multi-policy
     `self.action_policy` then receives all states, all bodies and the
     pdparam.

     For each body, the distribution's entropy and the action's log-probability
     are summed over dim 0. That way both single- and multi-dimensional actions
     record one value per body. The results are appended to `body.entropies`
     and `body.log_probs`.

     Returns the actions as a numpy array.
     Raises AssertionError if a recorded log-probability contains NaN.
     '''
     states = []
     for eb, body in util.ndenumerate_nonan(self.agent.body_a):
         state = state_a[eb]
         if self.normalize_state:
             state = policy_util.update_online_stats_and_normalize_state(
                 body, state)
         states.append(state)
     # Flatten each body's state and concatenate. For equally shaped states this
     # equals torch.tensor(states).view(-1). It also works when bodies have
     # different state shapes (or tensor states), and avoids torch's slow
     # list-of-ndarrays conversion.
     flat_states = [
         torch.as_tensor(s, device=self.net.device).reshape(-1).float()
         for s in states
     ]
     if flat_states:
         state = torch.cat(flat_states)
     else:  # no bodies: keep yielding an empty float row as before
         state = torch.zeros(0, device=self.net.device)
     state = state.unsqueeze(0)
     pdparam = self.calc_pdparam(state, evaluate=False)
     # use multi-policy. note arg change
     action_a, action_pd_a = self.action_policy(states, self,
                                                self.agent.nanflat_body_a,
                                                pdparam)
     for idx, body in enumerate(self.agent.nanflat_body_a):
         action_pd = action_pd_a[idx]
         # sum over dim 0 so single- and multi-dimensional actions both record
         # one value per body (a no-op for 0-dim values)
         body.entropies.append(action_pd.entropy().sum(dim=0))
         body.log_probs.append(
             action_pd.log_prob(action_a[idx].float()).sum(dim=0))
         # .any() keeps the NaN check well-defined for non-scalar log-probs
         assert not torch.isnan(body.log_probs[-1]).any()
     return action_a.cpu().numpy()