def get_action(self, observation):
    """Sample one action from the recurrent policy for *observation*.

    Runs a single RNN step via ``self.f_step_prob`` and draws an action
    from the resulting probability vector. When ``state_include_action``
    is set, the flattened previous action is appended to the flattened
    observation before the step, and is also reported in ``agent_info``.

    Returns:
        A ``(action, agent_info)`` pair where ``agent_info`` always holds
        the probability vector under the key ``"prob"``.
    """
    flat_obs = self.observation_space.flatten(observation)
    if self.state_include_action:
        # At the start of a rollout there is no previous action yet;
        # feed a zero vector of the action's flat dimension instead.
        if self.prev_action is None:
            prev_action = np.zeros((self.action_space.flat_dim, ))
        else:
            prev_action = self.action_space.flatten(self.prev_action)
        all_input = np.concatenate([flat_obs, prev_action])
    else:
        all_input = flat_obs
        # should not be used
        prev_action = np.nan
    # f_step_prob consumes/produces batches; take the single row back out.
    batch_probs, batch_hidden = self.f_step_prob([all_input],
                                                 [self.prev_hidden])
    probs = batch_probs[0]
    new_hidden = batch_hidden[0]
    action = special.weighted_sample(probs, range(self.action_space.n))
    self.prev_action = action
    self.prev_hidden = new_hidden
    agent_info = dict(prob=probs)
    if self.state_include_action:
        agent_info["prev_action"] = prev_action
    return action, agent_info
def weighted_sample(space, weights):
    """Draw one element of a discrete gym *space* according to *weights*.

    Raises:
        NotImplementedError: if *space* is not ``gym.spaces.Discrete``.
    """
    # Only discrete spaces are supported; fail fast on anything else.
    if not isinstance(space, gym.spaces.Discrete):
        raise NotImplementedError
    return special.weighted_sample(weights, range(space.n))
def weighted_sample(self, weights):
    """Sample an index in ``range(self.n)`` using *weights*.

    Delegates to ``special.weighted_sample`` over this space's indices.
    """
    candidates = range(self.n)
    return special.weighted_sample(weights, candidates)