def recurrent_policy(
    params_list: List[networks_lib.Params],
    random_key: networks_lib.PRNGKey,
    observation: networks_lib.Observation,
    previous_trajectory: Trajectory,
) -> Tuple[networks_lib.Action, Trajectory]:
  """Plans with MPPI and returns the first planned action and the full plan.

  Note that splitting the random key is handled by GenericActor.

  Args:
    params_list: [world_model_params, policy_prior_params,
      n_step_return_params], in that order.
    random_key: PRNG key for the planner.
    observation: Current (unnormalized) environment observation.
    previous_trajectory: Trajectory returned by the previous planning step,
      used to warm-start the planner.

  Returns:
    A tuple of (action to execute, planned trajectory).
  """
  should_normalize = mean_std is not None
  if should_normalize:
    observation = running_statistics.normalize(
        observation, mean_std=mean_std_observation)
  planned_trajectory = mppi.mppi_planner(
      config=mppi_config,
      world_model=planner_world_model,
      policy_prior=policy_prior,
      n_step_return=planner_n_step_return,
      world_model_params=params_list[0],
      policy_prior_params=params_list[1],
      n_step_return_params=params_list[2],
      random_key=random_key,
      observation=observation,
      previous_trajectory=previous_trajectory)
  # The first step of the planned trajectory is the action to execute now.
  chosen_action = planned_trajectory[0, ...]
  if should_normalize:
    chosen_action = running_statistics.denormalize(
        chosen_action, mean_std=mean_std_action)
  return (chosen_action, planned_trajectory)
def denormalized_n_step_return(
    params: networks_lib.Params,
    observation_t: networks_lib.Observation,
    action_t: networks_lib.Action) -> networks_lib.Value:
  """Denormalize the n-step return for proper weighting in the planner."""
  # Evaluate the (normalized) n-step return, then map it back to the
  # original reward scale using the stored running statistics.
  return running_statistics.denormalize(
      n_step_return(params, observation_t, action_t),
      mean_std_n_step_return)
def denormalized_world_model(
    params: networks_lib.Params,
    observation_t: networks_lib.Observation,
    action_t: networks_lib.Action
) -> Tuple[networks_lib.Observation, networks_lib.Value]:
  """Denormalizes the reward for proper weighting in the planner.

  Args:
    params: World-model network parameters.
    observation_t: Observation at time t.
    action_t: Action at time t.

  Returns:
    A tuple of (next observation, denormalized reward).
  """
  next_observation, reward_normalized = world_model(
      params, observation_t, action_t)
  # Only the reward is denormalized; the predicted observation stays in the
  # model's (normalized) observation space.
  denormalized_reward = running_statistics.denormalize(
      reward_normalized, mean_std_reward)
  return next_observation, denormalized_reward
def test_denormalize(self):
  """Checks that denormalize is the exact inverse of normalize.

  Builds running statistics from a sequence fed in two batches, verifies the
  normalized data has ~zero mean and ~unit std, then round-trips it.
  """
  state = running_statistics.init_state(specs.Array((5,), jnp.float32))
  data = jnp.arange(100, dtype=jnp.float32).reshape(10, 2, 5)
  # Feed the statistics in two halves to exercise incremental updates.
  first_half, second_half = jnp.split(data, 2, axis=0)
  state = update_and_validate(state, first_half)
  state = update_and_validate(state, second_half)
  normalized = running_statistics.normalize(data, state)
  sample_mean = jnp.mean(normalized)
  sample_std = jnp.std(normalized)
  self.assert_allclose(sample_mean, jnp.zeros_like(sample_mean))
  self.assert_allclose(sample_std, jnp.ones_like(sample_std))
  recovered = running_statistics.denormalize(normalized, state)
  self.assert_allclose(recovered, data)