Example #1
        # `is_sequence_based` and `max_abs_observation` are captured from the
        # enclosing scope of this nested function.
        def normalize_sample(
            observation_statistics: running_statistics.RunningStatisticsState,
            sample: reverb.ReplaySample
        ) -> Tuple[running_statistics.RunningStatisticsState,
                   reverb.ReplaySample]:
            observation = sample.data.observation
            observation_statistics = running_statistics.update(
                observation_statistics, observation)
            observation = running_statistics.normalize(
                observation,
                observation_statistics,
                max_abs_value=max_abs_observation)
            if is_sequence_based:
                assert not hasattr(sample.data, 'next_observation')
                sample = reverb.ReplaySample(
                    sample.info, sample.data._replace(observation=observation))
            else:
                next_observation = running_statistics.normalize(
                    sample.data.next_observation,
                    observation_statistics,
                    max_abs_value=max_abs_observation)
                sample = reverb.ReplaySample(
                    sample.info,
                    sample.data._replace(observation=observation,
                                         next_observation=next_observation))

            return observation_statistics, sample
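
normalize_sample above is meant to be folded over a stream of replay samples, threading the updated observation statistics through each call. A minimal, hypothetical usage sketch (observation_spec and replay_samples are assumed names, not part of the source):

    observation_statistics = running_statistics.init_state(observation_spec)
    for sample in replay_samples:  # an iterator of reverb.ReplaySample objects
        observation_statistics, sample = normalize_sample(
            observation_statistics, sample)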
Example #2
  def test_init_normalize(self):
    state = running_statistics.init_state(specs.Array((5,), jnp.float32))

    x = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5)
    # A freshly initialized state has zero mean and unit std, so
    # normalization should leave the data unchanged.
    normalized = running_statistics.normalize(x, state)

    self.assert_allclose(normalized, x)
Example #3
  def test_nested_normalize(self):
    state = running_statistics.init_state({
        'a': specs.Array((5,), jnp.float32),
        'b': specs.Array((2,), jnp.float32)
    })

    x1 = {
        'a': jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5),
        'b': jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2)
    }
    x2 = {
        'a': jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5) + 20,
        'b': jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2) + 8
    }
    x3 = {
        'a': jnp.arange(40, dtype=jnp.float32).reshape(4, 2, 5),
        'b': jnp.arange(16, dtype=jnp.float32).reshape(4, 2, 2)
    }

    state = update_and_validate(state, x1)
    state = update_and_validate(state, x2)
    state = update_and_validate(state, x3)
    normalized = running_statistics.normalize(x3, state)

    mean = tree.map_structure(lambda x: jnp.mean(x, axis=(0, 1)), normalized)
    std = tree.map_structure(lambda x: jnp.std(x, axis=(0, 1)), normalized)
    tree.map_structure(
        lambda x: self.assert_allclose(x, jnp.zeros_like(x)),
        mean)
    tree.map_structure(
        lambda x: self.assert_allclose(x, jnp.ones_like(x)),
        std)
Example #4
def recurrent_policy(
    params_list: List[networks_lib.Params],
    random_key: networks_lib.PRNGKey,
    observation: networks_lib.Observation,
    previous_trajectory: Trajectory,
) -> Tuple[networks_lib.Action, Trajectory]:
    # Note that splitting the random key is handled by GenericActor.
    # `mean_std`, `mean_std_observation`, `mean_std_action`, `mppi_config`,
    # `planner_world_model`, `policy_prior` and `planner_n_step_return` are
    # captured from the enclosing scope of this nested function.
    if mean_std is not None:
        observation = running_statistics.normalize(
            observation, mean_std=mean_std_observation)
    trajectory = mppi.mppi_planner(
        config=mppi_config,
        world_model=planner_world_model,
        policy_prior=policy_prior,
        n_step_return=planner_n_step_return,
        world_model_params=params_list[0],
        policy_prior_params=params_list[1],
        n_step_return_params=params_list[2],
        random_key=random_key,
        observation=observation,
        previous_trajectory=previous_trajectory)
    # The first action of the planned trajectory is the one returned.
    action = trajectory[0, ...]
    if mean_std is not None:
        action = running_statistics.denormalize(action,
                                                mean_std=mean_std_action)
    return (action, trajectory)
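
As the comment above notes, mean_std, mean_std_observation and mean_std_action come from the enclosing scope. A hedged reconstruction (my assumption, not part of the snippet) of how the per-field statistics might be built from a single running_statistics.NestedMeanStd computed over whole transitions:

    if mean_std is not None:
        # Split the transition-level statistics into observation- and
        # action-specific NestedMeanStd containers.
        mean_std_observation = running_statistics.NestedMeanStd(
            mean=mean_std.mean.observation, std=mean_std.std.observation)
        mean_std_action = running_statistics.NestedMeanStd(
            mean=mean_std.mean.action, std=mean_std.std.action)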
Example #5
  def test_int_not_normalized(self):
    state = running_statistics.init_state(specs.Array((), jnp.int32))

    x = jnp.arange(5, dtype=jnp.int32)

    state = update_and_validate(state, x)
    # normalize() leaves integer arrays unchanged, as asserted below.
    normalized = running_statistics.normalize(x, state)

    np.testing.assert_array_equal(normalized, x)
Example #6
  def test_clip(self):
    state = running_statistics.init_state(specs.Array((), jnp.float32))

    x = jnp.arange(5, dtype=jnp.float32)

    state = update_and_validate(state, x)
    # For x = [0, 1, 2, 3, 4] the normalized values are
    # [-sqrt(2), -1/sqrt(2), 0, 1/sqrt(2), sqrt(2)]; clipping to [-1, 1]
    # turns the two extremes into -1 and 1.
    normalized = running_statistics.normalize(x, state, max_abs_value=1.0)

    mean = jnp.mean(normalized)
    std = jnp.std(normalized)
    self.assert_allclose(mean, jnp.zeros_like(mean))
    # Variance of the clipped values is (1 + 0.5 + 0 + 0.5 + 1) / 5 = 0.6.
    self.assert_allclose(std, jnp.ones_like(std) * math.sqrt(0.6))
Example #7
  def test_one_batch_dim(self):
    state = running_statistics.init_state(specs.Array((5,), jnp.float32))

    x = jnp.arange(10, dtype=jnp.float32).reshape(2, 5)

    state = update_and_validate(state, x)
    normalized = running_statistics.normalize(x, state)

    mean = jnp.mean(normalized, axis=0)
    std = jnp.std(normalized, axis=0)
    self.assert_allclose(mean, jnp.zeros_like(mean))
    self.assert_allclose(std, jnp.ones_like(std))
Example #8
  def test_normalize(self):
    state = running_statistics.init_state(specs.Array((5,), jnp.float32))

    x = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5)
    x1, x2, x3, x4 = jnp.split(x, 4, axis=0)

    state = update_and_validate(state, x1)
    state = update_and_validate(state, x2)
    state = update_and_validate(state, x3)
    state = update_and_validate(state, x4)
    normalized = running_statistics.normalize(x, state)

    mean = jnp.mean(normalized)
    std = jnp.std(normalized)
    self.assert_allclose(mean, jnp.zeros_like(mean))
    self.assert_allclose(std, jnp.ones_like(std))
Example #9
  def test_weights(self):
    state = running_statistics.init_state(specs.Array((), jnp.float32))

    x = jnp.arange(5, dtype=jnp.float32)
    x_weights = jnp.ones_like(x)
    y = 2 * x + 5
    y_weights = 2 * x_weights
    z = jnp.concatenate([x, y])
    weights = jnp.concatenate([x_weights, y_weights])

    state = update_and_validate(state, z, weights=weights)

    # With weight 1 on x and weight 2 on y, the weighted mean is
    # (5 * mean(x) + 10 * mean(y)) / 15 = (mean(x) + 2 * mean(y)) / 3.
    self.assertEqual(state.mean, (jnp.mean(x) + 2 * jnp.mean(y)) / 3)
    # Repeating y twice in big_z reproduces the weighting, so normalizing it
    # with the weighted statistics yields zero mean and unit std.
    big_z = jnp.concatenate([x, y, y])
    normalized = running_statistics.normalize(big_z, state)
    self.assertAlmostEqual(jnp.mean(normalized), 0., places=6)
    self.assertAlmostEqual(jnp.std(normalized), 1., places=6)
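
All of the examples share the same flow: initialize statistics from a spec, update them with batches of data, then normalize (and optionally denormalize). A self-contained sketch, assuming the dm-acme package layout (acme.specs and acme.jax.running_statistics); shapes and values are made up for illustration:

    import jax.numpy as jnp

    from acme import specs
    from acme.jax import running_statistics

    # Describe the per-element layout of the data with a spec.
    state = running_statistics.init_state(specs.Array((3,), jnp.float32))

    # Accumulate running mean/std over a batch; leading axes are batch axes.
    batch = jnp.arange(12, dtype=jnp.float32).reshape(4, 3)
    state = running_statistics.update(state, batch)

    # Normalize with the accumulated statistics; max_abs_value optionally
    # clips the result to [-max_abs_value, max_abs_value].
    normalized = running_statistics.normalize(batch, state, max_abs_value=10.0)

    # Map the normalized values back to the original scale.
    restored = running_statistics.denormalize(normalized, state)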