Example #1
def get_action(self, observation):
    # `np` is numpy; `special` is the framework's misc.special module
    # (e.g. rllab.misc.special), whose weighted_sample draws one item with
    # probability proportional to its weight.
    if self.state_include_action:
        # Feed the previous action back in as part of the policy input;
        # use a zero vector on the first step of an episode.
        if self.prev_action is None:
            prev_action = np.zeros((self.action_space.flat_dim,))
        else:
            prev_action = self.action_space.flatten(self.prev_action)
        all_input = np.concatenate(
            [self.observation_space.flatten(observation), prev_action])
    else:
        all_input = self.observation_space.flatten(observation)
        # Placeholder only; should not be used in this branch.
        prev_action = np.nan
    # Run one step of the recurrent policy on a batch of size one, getting
    # the action probabilities and the updated hidden state.
    probs, hidden_vec = [
        x[0] for x in self.f_step_prob([all_input], [self.prev_hidden])
    ]
    # Sample a discrete action index according to those probabilities.
    action = special.weighted_sample(probs, range(self.action_space.n))
    self.prev_action = action
    self.prev_hidden = hidden_vec
    agent_info = dict(prob=probs)
    if self.state_include_action:
        agent_info["prev_action"] = prev_action
    return action, agent_info
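For orientation, here is a minimal usage sketch of how a rollout loop might call this method. `policy` and `env` are hypothetical stand-ins for the recurrent policy instance and a gym-style environment, and `policy.reset()` is assumed to clear prev_action and prev_hidden between episodes:

obs = env.reset()
policy.reset()  # assumed to reset prev_action / prev_hidden for a new episode
done = False
while not done:
    # get_action returns the sampled action and an agent_info dict with the
    # action probabilities (plus prev_action when state_include_action is set).
    action, agent_info = policy.get_action(obs)
    obs, reward, done, _ = env.step(action)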
Example #2
import gym
from rllab.misc import special  # assumed source of the weighted_sample helper

def weighted_sample(space, weights):
    # Draw one element of a discrete gym space, with the probability of
    # index i proportional to weights[i].
    if isinstance(space, gym.spaces.Discrete):
        return special.weighted_sample(weights, range(space.n))
    else:
        raise NotImplementedError
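A quick usage sketch of the helper above, assuming the imports shown; the weights here are illustrative probabilities over a 4-way discrete space:

space = gym.spaces.Discrete(4)
probs = [0.1, 0.2, 0.3, 0.4]
action = weighted_sample(space, probs)  # returns one of 0..3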
Example #3
def weighted_sample(self, weights):
    # Sample one of this space's n discrete values, with probability of
    # value i proportional to weights[i].
    return special.weighted_sample(weights, range(self.n))
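All three examples delegate to the same primitive: drawing one object from a sequence with probability proportional to its weight. The following is a minimal, self-contained NumPy sketch of that behaviour, not the framework's actual implementation:

import numpy as np

def weighted_sample(weights, objects):
    # Normalize the weights into a probability distribution and draw one
    # object; mirrors the special.weighted_sample calls in the examples above.
    probs = np.asarray(weights, dtype=float)
    probs = probs / probs.sum()
    idx = np.random.choice(len(probs), p=probs)
    return list(objects)[idx]

# Example: sample an action index from a 3-way categorical distribution.
action = weighted_sample([0.2, 0.5, 0.3], range(3))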