Code example #1
 def __call__(self, batch):
     not_terminal = 1.0 - batch.terminal.float()
     # normalize actions
     action = rescale_actions(
         batch.action,
         new_min=self.train_low,
         new_max=self.train_high,
         prev_min=self.action_low,
         prev_max=self.action_high,
     )
     # only normalize non-terminal
     non_terminal_indices = (batch.terminal == 0).squeeze(1)
     next_action = torch.zeros_like(action)
     next_action[non_terminal_indices] = rescale_actions(
         batch.next_action[non_terminal_indices],
         new_min=self.train_low,
         new_max=self.train_high,
         prev_min=self.action_low,
         prev_max=self.action_high,
     )
     dict_batch = {
         InputColumn.STATE_FEATURES: batch.state,
         InputColumn.NEXT_STATE_FEATURES: batch.next_state,
         InputColumn.ACTION: action,
         InputColumn.NEXT_ACTION: next_action,
         InputColumn.REWARD: batch.reward,
         InputColumn.NOT_TERMINAL: not_terminal,
         InputColumn.STEP: None,
         InputColumn.TIME_DIFF: None,
         InputColumn.EXTRAS: rlt.ExtraData(
             mdp_id=None,
             sequence_number=None,
             action_probability=batch.log_prob.exp(),
             max_num_actions=None,
             metrics=None,
         ),
     }
     has_candidate_features = False
     try:
         dict_batch.update(
             {
                 InputColumn.CANDIDATE_FEATURES: batch.doc,
                 InputColumn.NEXT_CANDIDATE_FEATURES: batch.next_doc,
             }
         )
         has_candidate_features = True
     except AttributeError:
         pass
     output = rlt.PolicyNetworkInput.from_dict(dict_batch)
     if has_candidate_features:
         output.state = rlt._embed_states(output.state)
         output.next_state = rlt._embed_states(output.next_state)
     return output
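
Every example in this section relies on the same rescale_actions helper, whose body is not shown. The following is a minimal sketch of a linear min-max rescaling function with that signature, written so it works on either torch tensors or numpy arrays (only broadcasting arithmetic is used); the actual ReAgent implementation may differ in details such as input-range assertions.

    def rescale_actions(actions, new_min, new_max, prev_min, prev_max):
        """Linearly map `actions` from [prev_min, prev_max] to [new_min, new_max].

        Elementwise: normalize to [0, 1] within the previous range,
        then stretch and shift into the new range.
        """
        prev_range = prev_max - prev_min
        new_range = new_max - new_min
        return ((actions - prev_min) / prev_range) * new_range + new_min

With this definition, the mapping is inverted by swapping the two ranges, which is how the trainer preprocessors (env range to training range) and the action extractors (training range back to env range) mirror each other.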
Code example #2
 def action_extractor(self, actor_output: rlt.ActorOutput) -> torch.Tensor:
     action = actor_output.action
     action_space = self.action_space
     # Canonical rule to return one-hot encoded actions for discrete
     assert (len(action.shape) == 2 and action.shape[0] == 1
             ), f"{action} (shape: {action.shape}) is not a single action!"
     if isinstance(action_space, spaces.Discrete):
         # pyre-fixme[16]: `Tensor` has no attribute `argmax`.
         return action.squeeze(0).argmax()
     elif isinstance(action_space, spaces.MultiDiscrete):
         return action.squeeze(0)
     # Canonical rule to scale actions to CONTINUOUS_TRAINING_ACTION_RANGE
     elif isinstance(action_space, spaces.Box):
         assert len(
             action_space.shape) == 1, f"{action_space} not supported."
         return rescale_actions(
             action.squeeze(0),
             new_min=torch.tensor(action_space.low),
             new_max=torch.tensor(action_space.high),
             prev_min=CONTINUOUS_MODEL_LOW,
             prev_max=CONTINUOUS_MODEL_HIGH,
         )
     else:
         raise NotImplementedError(
             f"Unsupported action space: {action_space}")
Code example #3
 def box_action_extractor(actor_output: rlt.ActorOutput) -> np.ndarray:
     # NOTE: `action_space`, `model_low`, and `model_high` are not defined in
     # this excerpt; they are captured from an enclosing scope (see the sketch
     # after this example).
     action = actor_output.action
     assert (len(action.shape) == 2 and action.shape[0] == 1
             ), f"{action} (shape: {action.shape}) is not a single action!"
     return rescale_actions(
         action.squeeze(0).cpu().numpy(),
         new_min=action_space.low,
         new_max=action_space.high,
         prev_min=model_low,
         prev_max=model_high,
     )
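
The excerpt above is not self-contained: action_space, model_low, and model_high are free variables bound by the surrounding code. A plausible shape for that surrounding code is a small factory closure; the sketch below uses the hypothetical name make_box_action_extractor and assumes the training range is the (-1, 1) mentioned in code example #5.

    import numpy as np
    from gym import spaces

    def make_box_action_extractor(action_space: spaces.Box):
        # Hypothetical factory: binds the names that the box_action_extractor
        # excerpt reads from its enclosing scope.
        assert len(action_space.shape) == 1, f"{action_space} not supported."
        model_low, model_high = (-1.0, 1.0)  # assumed CONTINUOUS_TRAINING_ACTION_RANGE

        def box_action_extractor(actor_output) -> np.ndarray:
            action = actor_output.action
            assert len(action.shape) == 2 and action.shape[0] == 1
            return rescale_actions(  # e.g. the sketch after code example #1
                action.squeeze(0).cpu().numpy(),
                new_min=action_space.low,
                new_max=action_space.high,
                prev_min=model_low,
                prev_max=model_high,
            )

        return box_action_extractor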
Code example #4
 def __call__(self, batch):
     not_terminal = 1.0 - batch.terminal.float()
     # normalize actions
     (train_low, train_high) = CONTINUOUS_TRAINING_ACTION_RANGE
     action = torch.tensor(
         rescale_actions(
             batch.action.numpy(),
             new_min=train_low,
             new_max=train_high,
             prev_min=self.action_low,
             prev_max=self.action_high,
         ))
     # only normalize non-terminal
     non_terminal_indices = (batch.terminal == 0).squeeze(1)
     next_action = torch.zeros_like(action)
     next_action[non_terminal_indices] = torch.tensor(
         rescale_actions(
             batch.next_action[non_terminal_indices].numpy(),
             new_min=train_low,
             new_max=train_high,
             prev_min=self.action_low,
             prev_max=self.action_high,
         ))
     return rlt.PolicyNetworkInput(
         state=rlt.FeatureData(float_features=batch.state),
         action=rlt.FeatureData(float_features=action),
         next_state=rlt.FeatureData(float_features=batch.next_state),
         next_action=rlt.FeatureData(float_features=next_action),
         reward=batch.reward,
         not_terminal=not_terminal,
         step=None,
         time_diff=None,
         extras=rlt.ExtraData(
             mdp_id=None,
             sequence_number=None,
             action_probability=batch.log_prob.exp(),
             max_num_actions=None,
             metrics=None,
         ),
     )
Code example #5
File: cem_planner.py Project: phillip1029/ReAgent
    def continuous_planning(self, state: rlt.FeatureData) -> torch.Tensor:
        # TODO: Warmstarts means and vars using previous solutions (T48841404)
        mean = (self.action_upper_bounds + self.action_lower_bounds) / 2
        var = (self.action_upper_bounds - self.action_lower_bounds) ** 2 / 16
        # pyre-fixme[29]: `truncnorm_gen` is not a function.
        normal_sampler = stats.truncnorm(
            -2, 2, loc=np.zeros_like(mean), scale=np.ones_like(mean)
        )

        for i in range(self.cem_num_iterations):
            logger.debug(f"{i}-th cem iteration.")
            const_var = self.constrained_variance(mean, var)
            solutions = (
                normal_sampler.rvs(
                    size=[self.cem_pop_size, self.action_dim * self.plan_horizon_length]
                )
                * np.sqrt(const_var)
                + mean
            )
            action_solutions = torch.from_numpy(
                solutions.reshape(
                    (self.cem_pop_size, self.plan_horizon_length, self.action_dim)
                )
            ).float()
            acc_rewards = self.acc_rewards_of_all_solutions(state, action_solutions)
            elites = solutions[np.argsort(acc_rewards)][-self.num_elites :]
            new_mean = np.mean(elites, axis=0)
            new_var = np.var(elites, axis=0)
            mean = self.alpha * mean + (1 - self.alpha) * new_mean
            var = self.alpha * var + (1 - self.alpha) * new_var

            if np.max(var) <= self.epsilon:
                break

        # Pick the first action of the optimal solution
        solution = mean[: self.action_dim]
        raw_action = solution.reshape(-1)
        low = torch.tensor(CONTINUOUS_TRAINING_ACTION_RANGE[0])
        high = torch.tensor(CONTINUOUS_TRAINING_ACTION_RANGE[1])
        # rescale to range (-1, 1) as per canonical output range of continuous agents
        return rescale_actions(
            torch.tensor(raw_action),
            new_min=low,
            new_max=high,
            prev_min=self.orig_action_lower,
            prev_max=self.orig_action_upper,
        )
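
Taken together, the examples form a round trip: the __call__ preprocessors (examples #1 and #4) map logged env actions into the canonical training range, while the extractors (examples #2 and #3) and the CEM planner map model-side actions back into the env's Box range. A small self-contained check of that round trip, using the rescale_actions sketch above and a hypothetical 1-D Box range of [0, 4]:

    import torch

    env_low, env_high = torch.tensor([0.0]), torch.tensor([4.0])       # hypothetical Box bounds
    train_low, train_high = torch.tensor([-1.0]), torch.tensor([1.0])  # canonical (-1, 1) training range

    env_action = torch.tensor([3.0])
    # env range -> training range (as in the __call__ preprocessors)
    model_action = rescale_actions(env_action, new_min=train_low, new_max=train_high,
                                   prev_min=env_low, prev_max=env_high)
    # training range -> env range (as in the action extractors)
    recovered = rescale_actions(model_action, new_min=env_low, new_max=env_high,
                                prev_min=train_low, prev_max=train_high)
    assert torch.allclose(recovered, env_action)  # 3.0 -> 0.5 -> 3.0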