Example #1
 def recurrent_policy(
     params_list: List[networks_lib.Params],
     random_key: networks_lib.PRNGKey,
     observation: networks_lib.Observation,
     previous_trajectory: Trajectory,
 ) -> Tuple[networks_lib.Action, Trajectory]:
     # Note that splitting the random key is handled by GenericActor.
     # The planner configuration, networks and `mean_std_*` statistics used
     # below are closed over from the enclosing actor factory.
     if mean_std is not None:
         observation = running_statistics.normalize(
             observation, mean_std=mean_std_observation)
     trajectory = mppi.mppi_planner(config=mppi_config,
                                    world_model=planner_world_model,
                                    policy_prior=policy_prior,
                                    n_step_return=planner_n_step_return,
                                    world_model_params=params_list[0],
                                    policy_prior_params=params_list[1],
                                    n_step_return_params=params_list[2],
                                    random_key=random_key,
                                    observation=observation,
                                    previous_trajectory=previous_trajectory)
     action = trajectory[0, ...]
     if mean_std is not None:
         action = running_statistics.denormalize(action,
                                                 mean_std=mean_std_action)
     return (action, trajectory)
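A minimal sketch of how the closed-over `mean_std_observation` and `mean_std_action` could be derived, assuming `mean_std` is a `running_statistics.NestedMeanStd` whose `mean` and `std` carry `observation` and `action` fields; the enclosing factory is an assumption here, not part of the listing:

 from acme.jax import running_statistics

 # Slice per-field statistics out of the transition-level NestedMeanStd so
 # observations and actions can be (de)normalized independently.
 if mean_std is not None:
     mean_std_observation = running_statistics.NestedMeanStd(
         mean=mean_std.mean.observation, std=mean_std.std.observation)
     mean_std_action = running_statistics.NestedMeanStd(
         mean=mean_std.mean.action, std=mean_std.std.action)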
Example #2
 def denormalized_n_step_return(
         params: networks_lib.Params,
         observation_t: networks_lib.Observation,
         action_t: networks_lib.Action) -> networks_lib.Value:
     """Denormalize the n-step return for proper weighting in the planner."""
     normalized_n_step_return_t = n_step_return(params, observation_t,
                                                action_t)
     return running_statistics.denormalize(normalized_n_step_return_t,
                                           mean_std_n_step_return)
Example #3
 def denormalized_world_model(
     params: networks_lib.Params,
     observation_t: networks_lib.Observation,
     action_t: networks_lib.Action
 ) -> Tuple[networks_lib.Observation, networks_lib.Value]:
     """Denormalizes the reward for proper weighting in the planner."""
     observation_tp1, normalized_reward_t = world_model(
         params, observation_t, action_t)
     reward_t = running_statistics.denormalize(normalized_reward_t,
                                               mean_std_reward)
     return observation_tp1, reward_t
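One plausible way these denormalizing wrappers feed back into the planner call of Example #1 is sketched below: when dataset statistics are available, the planner is pointed at the wrappers instead of the raw networks. This wiring (including the else branch) is an assumption about the enclosing factory, not part of the listed examples:

 # Hypothetical wiring inside the enclosing factory: plan through the
 # denormalizing wrappers only when normalization statistics were provided.
 if mean_std is not None:
     planner_world_model = denormalized_world_model
     planner_n_step_return = denormalized_n_step_return
 else:
     planner_world_model = world_model
     planner_n_step_return = n_step_return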
Example #4
  def test_denormalize(self):
    # `update_and_validate` and `assert_allclose` are helpers defined on the
    # surrounding test module / test class.
    state = running_statistics.init_state(specs.Array((5,), jnp.float32))

    x = jnp.arange(100, dtype=jnp.float32).reshape(10, 2, 5)
    x1, x2 = jnp.split(x, 2, axis=0)

    state = update_and_validate(state, x1)
    state = update_and_validate(state, x2)
    normalized = running_statistics.normalize(x, state)

    mean = jnp.mean(normalized)
    std = jnp.std(normalized)
    self.assert_allclose(mean, jnp.zeros_like(mean))
    self.assert_allclose(std, jnp.ones_like(std))

    denormalized = running_statistics.denormalize(normalized, state)
    self.assert_allclose(denormalized, x)
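The test relies on `denormalize` being the exact inverse of `normalize` for a fixed statistics state. Outside the test, the state would be accumulated with `running_statistics.update` rather than the test-only `update_and_validate` helper; a minimal round trip, with shapes chosen arbitrarily for illustration:

 import jax.numpy as jnp
 from acme import specs
 from acme.jax import running_statistics

 state = running_statistics.init_state(specs.Array((5,), jnp.float32))
 batch = jnp.arange(20, dtype=jnp.float32).reshape(4, 5)
 state = running_statistics.update(state, batch)   # accumulate mean/std
 normalized = running_statistics.normalize(batch, state)
 restored = running_statistics.denormalize(normalized, state)  # ~= batch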