def normalize_sample(
    observation_statistics: running_statistics.RunningStatisticsState,
    sample: reverb.ReplaySample
) -> Tuple[running_statistics.RunningStatisticsState, reverb.ReplaySample]:
    """Update the running statistics with a sample, then normalize it.

    The statistics are first updated with ``sample.data.observation``; the
    updated statistics are then used to normalize the observation fields,
    passing the enclosing scope's ``max_abs_observation`` as
    ``max_abs_value``.

    When the enclosing scope's ``is_sequence_based`` flag is set, only
    ``observation`` is normalized and the sample data is asserted not to
    carry a ``next_observation`` attribute. Otherwise ``next_observation``
    is normalized with the same (already updated) statistics.

    Args:
        observation_statistics: Statistics state to be updated.
        sample: Replay sample whose ``data`` holds the observation(s).

    Returns:
        A ``(updated_statistics, normalized_sample)`` pair; the sample keeps
        its original ``info``.
    """
    updated_statistics = running_statistics.update(
        observation_statistics, sample.data.observation)

    def _normalize(value):
        # Both observation fields share the same statistics and bound.
        return running_statistics.normalize(
            value, updated_statistics, max_abs_value=max_abs_observation)

    replacements = {'observation': _normalize(sample.data.observation)}
    if is_sequence_based:
        assert not hasattr(sample.data, 'next_observation')
    else:
        replacements['next_observation'] = _normalize(
            sample.data.next_observation)

    normalized_sample = reverb.ReplaySample(
        sample.info, sample.data._replace(**replacements))
    return updated_statistics, normalized_sample
def test_init_normalize(self):
    """Normalizing with a freshly initialized state returns the input."""
    spec = specs.Array((5,), jnp.float32)
    initial_state = running_statistics.init_state(spec)

    values = jnp.arange(200, dtype=jnp.float32)
    batch = values.reshape(20, 2, 5)

    result = running_statistics.normalize(batch, initial_state)

    self.assert_allclose(result, batch)
def test_nested_normalize(self):
    """Nested inputs normalize to zero mean / unit std per feature.

    The state is updated with three dict-structured batches and the last
    batch is then normalized; mean and std are checked over the two
    leading axes of every leaf.
    """
    nested_spec = {
        'a': specs.Array((5,), jnp.float32),
        'b': specs.Array((2,), jnp.float32)
    }
    state = running_statistics.init_state(nested_spec)

    first_batch = {
        'a': jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5),
        'b': jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2)
    }
    second_batch = {
        'a': jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5) + 20,
        'b': jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2) + 8
    }
    third_batch = {
        'a': jnp.arange(40, dtype=jnp.float32).reshape(4, 2, 5),
        'b': jnp.arange(16, dtype=jnp.float32).reshape(4, 2, 2)
    }

    for batch in (first_batch, second_batch, third_batch):
        state = update_and_validate(state, batch)

    normalized = running_statistics.normalize(third_batch, state)

    def leaf_mean(leaf):
        return jnp.mean(leaf, axis=(0, 1))

    def leaf_std(leaf):
        return jnp.std(leaf, axis=(0, 1))

    def expect_zeros(leaf):
        return self.assert_allclose(leaf, jnp.zeros_like(leaf))

    def expect_ones(leaf):
        return self.assert_allclose(leaf, jnp.ones_like(leaf))

    means = tree.map_structure(leaf_mean, normalized)
    stds = tree.map_structure(leaf_std, normalized)
    tree.map_structure(expect_zeros, means)
    tree.map_structure(expect_ones, stds)
def recurrent_policy(
    params_list: List[networks_lib.Params],
    random_key: networks_lib.PRNGKey,
    observation: networks_lib.Observation,
    previous_trajectory: Trajectory,
) -> Tuple[networks_lib.Action, Trajectory]:
    """Plan a trajectory with the MPPI planner and return its first action.

    When the enclosing scope's ``mean_std`` is not ``None``, the observation
    is normalized with ``mean_std_observation`` before planning and the
    selected action is denormalized with ``mean_std_action`` afterwards.

    Args:
        params_list: Parameters indexed as ``[0]`` world model, ``[1]``
            policy prior, ``[2]`` n-step return.
        random_key: Random key forwarded unchanged to the planner.
        observation: Current observation.
        previous_trajectory: Trajectory forwarded to the planner.

    Returns:
        An ``(action, trajectory)`` pair, where ``action`` is the entry at
        index 0 along the first axis of the planned trajectory.
    """
    # Note that splitting the random key is handled by GenericActor.
    normalization_enabled = mean_std is not None

    planner_observation = observation
    if normalization_enabled:
        planner_observation = running_statistics.normalize(
            observation, mean_std=mean_std_observation)

    world_model_params = params_list[0]
    policy_prior_params = params_list[1]
    n_step_return_params = params_list[2]

    planned = mppi.mppi_planner(
        config=mppi_config,
        world_model=planner_world_model,
        policy_prior=policy_prior,
        n_step_return=planner_n_step_return,
        world_model_params=world_model_params,
        policy_prior_params=policy_prior_params,
        n_step_return_params=n_step_return_params,
        random_key=random_key,
        observation=planner_observation,
        previous_trajectory=previous_trajectory)

    first_action = planned[0, ...]
    if not normalization_enabled:
        return (first_action, planned)

    restored_action = running_statistics.denormalize(
        first_action, mean_std=mean_std_action)
    return (restored_action, planned)
def test_int_not_normalized(self):
    """Integer inputs come back from normalize unchanged after an update."""
    int_spec = specs.Array((), jnp.int32)
    state = running_statistics.init_state(int_spec)

    samples = jnp.arange(5, dtype=jnp.int32)
    state = update_and_validate(state, samples)

    result = running_statistics.normalize(samples, state)

    np.testing.assert_array_equal(result, samples)
def test_clip(self):
    """With ``max_abs_value=1.0`` the output has mean 0 and std sqrt(0.6)."""
    scalar_spec = specs.Array((), jnp.float32)
    state = running_statistics.init_state(scalar_spec)

    samples = jnp.arange(5, dtype=jnp.float32)
    state = update_and_validate(state, samples)

    clipped = running_statistics.normalize(samples, state, max_abs_value=1.0)

    observed_mean = jnp.mean(clipped)
    observed_std = jnp.std(clipped)
    expected_mean = jnp.zeros_like(observed_mean)
    self.assert_allclose(observed_mean, expected_mean)
    expected_std = jnp.ones_like(observed_std) * math.sqrt(0.6)
    self.assert_allclose(observed_std, expected_std)
def test_one_batch_dim(self):
    """A (2, 5) batch normalizes to zero mean / unit std along axis 0."""
    feature_spec = specs.Array((5,), jnp.float32)
    state = running_statistics.init_state(feature_spec)

    batch = jnp.arange(10, dtype=jnp.float32).reshape(2, 5)
    state = update_and_validate(state, batch)

    result = running_statistics.normalize(batch, state)

    per_feature_mean = jnp.mean(result, axis=0)
    per_feature_std = jnp.std(result, axis=0)
    expected_mean = jnp.zeros_like(per_feature_mean)
    self.assert_allclose(per_feature_mean, expected_mean)
    expected_std = jnp.ones_like(per_feature_std)
    self.assert_allclose(per_feature_std, expected_std)
def test_normalize(self):
    """Updating in four chunks, then normalizing all data, gives mean 0, std 1."""
    feature_spec = specs.Array((5,), jnp.float32)
    state = running_statistics.init_state(feature_spec)

    data = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5)

    # Feed the data in four equal pieces split along the leading axis.
    for chunk in jnp.split(data, 4, axis=0):
        state = update_and_validate(state, chunk)

    result = running_statistics.normalize(data, state)

    overall_mean = jnp.mean(result)
    overall_std = jnp.std(result)
    expected_mean = jnp.zeros_like(overall_mean)
    self.assert_allclose(overall_mean, expected_mean)
    expected_std = jnp.ones_like(overall_std)
    self.assert_allclose(overall_std, expected_std)
def test_weights(self):
    """A weighted update matches the weighted mean of its inputs.

    Samples ``y`` carry weight 2 and samples ``x`` weight 1; the resulting
    mean is compared with ``(mean(x) + 2 * mean(y)) / 3``. A batch holding
    ``x`` once and ``y`` twice is then expected to normalize to mean 0 and
    std 1 (to 6 decimal places).
    """
    scalar_spec = specs.Array((), jnp.float32)
    state = running_statistics.init_state(scalar_spec)

    x = jnp.arange(5, dtype=jnp.float32)
    unit_weights = jnp.ones_like(x)
    y = 2 * x + 5
    double_weights = 2 * unit_weights

    combined = jnp.concatenate([x, y])
    combined_weights = jnp.concatenate([unit_weights, double_weights])

    state = update_and_validate(state, combined, weights=combined_weights)

    expected_mean = (jnp.mean(x) + 2 * jnp.mean(y)) / 3
    self.assertEqual(state.mean, expected_mean)

    # Repeating y twice mirrors its weight of 2 in the update above.
    replicated = jnp.concatenate([x, y, y])
    result = running_statistics.normalize(replicated, state)
    self.assertAlmostEqual(jnp.mean(result), 0., places=6)
    self.assertAlmostEqual(jnp.std(result), 1., places=6)