Example #1
0
def get_normalization_stats(
    iterator: Iterator[types.Transition],
    num_normalization_batches: int = 50
) -> running_statistics.RunningStatisticsState:
    """Precomputes normalization statistics over a fixed number of batches.

    The iterator should contain batches of 3-transitions, i.e. with two leading
    dimensions, the first one denoting the batch dimension and the second one
    the previous, current and next timesteps. The statistics are calculated
    using the data of the previous timestep.

    Args:
      iterator: Iterator of batches of 3-transitions. At most
        `num_normalization_batches` batches are consumed from it.
      num_normalization_batches: Number of batches to calculate the statistics
        from. Every consumed batch, including the first one, contributes to
        the statistics. Must be positive.

    Returns:
      RunningStatisticsState containing the normalization statistics.

    Raises:
      ValueError: If `num_normalization_batches` is not positive.
      StopIteration: If `iterator` yields no batch at all.
    """
    if num_normalization_batches < 1:
        # Fail before pulling anything from `iterator`; itertools.islice would
        # only reject the negative stop after the first batch was consumed.
        raise ValueError(
            'num_normalization_batches must be positive, got '
            f'{num_normalization_batches}.')

    first_batch = next(iterator)
    # One unbatched element of the previous timestep defines the nest structure
    # and per-leaf shapes of the statistics state. Indexing without a trailing
    # `:` also supports leaves that have no feature dimension.
    unbatched_single_example = jax.tree_util.tree_map(
        lambda x: x[0, PREVIOUS], first_batch)
    mean_std = running_statistics.init_state(unbatched_single_example)

    # Feed the first batch into the statistics as well, so that exactly
    # `num_normalization_batches` batches contribute (not one fewer).
    batches = itertools.chain(
        [first_batch],
        itertools.islice(iterator, num_normalization_batches - 1))
    for batch in batches:
        previous_step = jax.tree_util.tree_map(lambda x: x[:, PREVIOUS], batch)
        mean_std = running_statistics.update(mean_std, previous_step)

    return mean_std
  def test_nested_normalize(self):
    """Normalizing a dict nest after three updates gives zero mean, unit std.

    The last batch spans exactly the union of the values of the first two, so
    its per-feature statistics coincide with the accumulated ones.
    """
    nested_spec = {
        'a': specs.Array((5,), jnp.float32),
        'b': specs.Array((2,), jnp.float32)
    }
    state = running_statistics.init_state(nested_spec)

    a_first = jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5)
    b_first = jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2)
    batches = [
        {'a': a_first, 'b': b_first},
        {'a': a_first + 20, 'b': b_first + 8},
        {
            'a': jnp.arange(40, dtype=jnp.float32).reshape(4, 2, 5),
            'b': jnp.arange(16, dtype=jnp.float32).reshape(4, 2, 2)
        },
    ]
    for batch in batches:
      state = update_and_validate(state, batch)
    normalized = running_statistics.normalize(batches[-1], state)

    # Reduce over both leading (batch) axes and compare each leaf.
    checks = ((jnp.mean, jnp.zeros_like), (jnp.std, jnp.ones_like))
    for reducer, expected_like in checks:
      moments = tree.map_structure(
          lambda leaf, fn=reducer: fn(leaf, axis=(0, 1)), normalized)
      tree.map_structure(
          lambda m, like=expected_like: self.assert_allclose(m, like(m)),
          moments)
  def test_init_normalize(self):
    """A freshly initialized state leaves the input unchanged when normalizing."""
    fresh_state = running_statistics.init_state(
        specs.Array((5,), jnp.float32))
    values = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5)

    self.assert_allclose(
        running_statistics.normalize(values, fresh_state), values)
  def test_pmap_update_nested(self):
    """Nested statistics updated under pmap normalize to zero mean, unit std.

    `update_and_validate` is run twice under `jax.pmap` with a named axis; the
    input has one leading slice per local device.
    """
    num_devices = jax.local_device_count()
    state = running_statistics.init_state({
        'a': specs.Array((5,), jnp.float32),
        'b': specs.Array((2,), jnp.float32)
    })

    # Leading axis: one slice per device; then a per-device batch of 3.
    x = {
        'a': jnp.arange(15 * num_devices,
                        dtype=jnp.float32).reshape(num_devices, 3, 5),
        'b': jnp.arange(6 * num_devices,
                        dtype=jnp.float32).reshape(num_devices, 3, 2),
    }

    state = jax.device_put_replicated(state, jax.local_devices())
    axis_name = 'i'
    for _ in range(2):
      pmapped_update = jax.pmap(
          functools.partial(update_and_validate, pmap_axis_name=axis_name),
          axis_name)
      state = pmapped_update(state, x)
    normalized = jax.pmap(running_statistics.normalize)(x, state)

    means = tree.map_structure(
        lambda leaf: jnp.mean(leaf, axis=(0, 1)), normalized)
    stds = tree.map_structure(
        lambda leaf: jnp.std(leaf, axis=(0, 1)), normalized)
    tree.map_structure(
        lambda m: self.assert_allclose(m, jnp.zeros_like(m)), means)
    tree.map_structure(
        lambda s: self.assert_allclose(s, jnp.ones_like(s)), stds)
  def test_int_not_normalized(self):
    """Integer-typed data comes back from normalize() exactly unchanged."""
    int_state = running_statistics.init_state(specs.Array((), jnp.int32))
    ints = jnp.arange(5, dtype=jnp.int32)

    int_state = update_and_validate(int_state, ints)

    np.testing.assert_array_equal(
        running_statistics.normalize(ints, int_state), ints)
  def test_validation(self):
    """Batches whose shapes don't fit the (1, 2, 3) spec raise AssertionError."""
    state = running_statistics.init_state(specs.Array((1, 2, 3), jnp.float32))

    mismatched_inputs = (
        jnp.arange(12, dtype=jnp.float32).reshape(2, 2, 3),
        jnp.arange(3, dtype=jnp.float32).reshape(1, 1, 3),
    )
    for bad_input in mismatched_inputs:
      with self.assertRaises(AssertionError):
        update_and_validate(state, bad_input)
  def test_clip(self):
    """normalize() with max_abs_value=1.0 yields zero mean and std sqrt(0.6).

    Standardizing 0..4 gives about [-1.41, -0.71, 0, 0.71, 1.41]. With the
    two outer values limited to +-1 the mean stays 0 and the variance becomes
    (1 + 0.5 + 0 + 0.5 + 1) / 5 = 0.6.
    """
    state = running_statistics.init_state(specs.Array((), jnp.float32))
    values = jnp.arange(5, dtype=jnp.float32)
    state = update_and_validate(state, values)

    clipped = running_statistics.normalize(values, state, max_abs_value=1.0)
    expected_std = math.sqrt(0.6)

    clipped_mean, clipped_std = jnp.mean(clipped), jnp.std(clipped)
    self.assert_allclose(clipped_mean, jnp.zeros_like(clipped_mean))
    self.assert_allclose(clipped_std, jnp.ones_like(clipped_std) * expected_std)
  def test_one_batch_dim(self):
    """Data with a single leading batch axis normalizes per feature."""
    state = running_statistics.init_state(specs.Array((5,), jnp.float32))
    batch = jnp.arange(10, dtype=jnp.float32).reshape(2, 5)

    state = update_and_validate(state, batch)
    normalized = running_statistics.normalize(batch, state)

    for reducer, expected_like in ((jnp.mean, jnp.zeros_like),
                                   (jnp.std, jnp.ones_like)):
      moment = reducer(normalized, axis=0)
      self.assert_allclose(moment, expected_like(moment))
  def test_different_structure_normalize(self):
    """A dict batch against a TestNestedSpec-shaped state raises TypeError."""
    state = running_statistics.init_state(
        TestNestedSpec(
            a=specs.Array((5,), jnp.float32),
            b=specs.Array((2,), jnp.float32)))

    dict_batch = {
        'a': jnp.arange(20, dtype=jnp.float32).reshape(2, 2, 5),
        'b': jnp.arange(8, dtype=jnp.float32).reshape(2, 2, 2)
    }

    with self.assertRaises(TypeError):
      update_and_validate(state, dict_batch)
Example #10
0
    def test_normalize_config(self):
        """Only leaves selected by NestStatisticsConfig are expected to change.

        Statistics are updated with a config that selects key 'a'. Afterwards
        'a' should normalize to zero mean / unit std and 'b' should come back
        unchanged.
        """
        x = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5)
        y = jnp.arange(160, dtype=jnp.float32).reshape(20, 2, 4)
        z = {'a': x, 'b': y}
        z_split = [
            {'a': x_chunk, 'b': y_chunk}
            for x_chunk, y_chunk in zip(jnp.split(x, 5, axis=0),
                                        jnp.split(y, 5, axis=0))
        ]

        jitted_update = jax.jit(running_statistics.update,
                                static_argnames=('config', ))
        config = running_statistics.NestStatisticsConfig((('a', ), ))
        state = running_statistics.init_state({
            'a': specs.Array((5, ), jnp.float32),
            'b': specs.Array((4, ), jnp.float32),
        })
        # Feed the five chunks one by one, starting from the first element.
        for chunk in z_split:
            state = jitted_update(state, chunk, config=config)

        normalized = jax.jit(running_statistics.normalize)(z, state)

        for key in normalized:
            if key != 'a':
                assert key == 'b'
                np.testing.assert_array_equal(
                    normalized[key],
                    z[key],
                    err_msg=f'z:{z[key]} normalized:{normalized[key]}')
                continue
            mean = jnp.mean(normalized[key], axis=(0, 1))
            std = jnp.std(normalized[key], axis=(0, 1))
            self.assert_allclose(
                mean,
                jnp.zeros_like(mean),
                err_msg=f'key:{key} mean:{mean} normalized:{normalized[key]}')
            self.assert_allclose(
                std,
                jnp.ones_like(std),
                err_msg=f'key:{key} std:{std} normalized:{normalized[key]}')
Example #11
0
    def __init__(self,
                 spec: specs.EnvironmentSpec,
                 networks: networks_lib.FeedForwardNetwork,
                 rng: networks_lib.PRNGKey,
                 config: ars_config.ARSConfig,
                 iterator: Iterator[reverb.ReplaySample],
                 counter: Optional[counting.Counter] = None,
                 logger: Optional[loggers.Logger] = None):
        """Initializes bookkeeping, the observation normalizer and states.

        Args:
          spec: Environment spec. `spec.observations` initializes the
            normalizer state when `config.normalize_observations` is set.
          networks: Network whose `init` produces the initial policy params.
          rng: Random key, split three ways: training-state key,
            evaluation-state key, and the key for `networks.init`.
          config: ARS configuration.
          iterator: Replay-sample iterator, stored on `self._iterator`.
          counter: Step counter; a fresh `counting.Counter()` when omitted.
          logger: Logger; when omitted, an asynchronous default 'learner'
            logger using the counter's steps key.
        """
        self._config = config
        self._lock = threading.Lock()

        # Bookkeeping: fall back to a fresh counter / default logger.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.make_default_logger(
            'learner',
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            steps_key=self._counter.get_steps_key())

        # Iterator on demonstration transitions.
        self._iterator = iterator

        if not self._config.normalize_observations:
            # Normalization disabled: empty params and an update function that
            # returns its first argument unchanged.
            normalizer_params = ()
            self._normalizer_update_fn = lambda a, b: a
        else:
            normalizer_params = running_statistics.init_state(
                spec.observations)
            self._normalizer_update_fn = running_statistics.update

        training_key, evaluation_key, init_key = jax.random.split(rng, 3)
        self._training_state = TrainingState(
            key=training_key,
            policy_params=networks.init(init_key),
            normalizer_params=normalizer_params,
            training_iteration=0)
        self._evaluation_state = EvaluationState(
            key=evaluation_key,
            evaluation_queue=collections.deque(),
            received_results={},
            noises=[])

        # Timestamps start being recorded only after the first learning step,
        # so the time actors need to come online and fill the replay buffer is
        # not counted.
        self._timestamp = None
  def test_normalize(self):
    """Four incremental updates give stats that standardize the full data."""
    state = running_statistics.init_state(specs.Array((5,), jnp.float32))
    data = jnp.arange(200, dtype=jnp.float32).reshape(20, 2, 5)

    for chunk in jnp.split(data, 4, axis=0):
      state = update_and_validate(state, chunk)
    normalized = running_statistics.normalize(data, state)

    overall_mean, overall_std = jnp.mean(normalized), jnp.std(normalized)
    self.assert_allclose(overall_mean, jnp.zeros_like(overall_mean))
    self.assert_allclose(overall_std, jnp.ones_like(overall_std))
  def test_weights(self):
    """Weighted updates produce the weighted mean and standardize matching data.

    Samples of `y` carry weight 2 and samples of `x` weight 1. The running
    mean must therefore equal (mean(x) + 2 * mean(y)) / 3, and normalizing
    [x, y, y] (each `y` counted twice) should give zero mean and unit std.
    """
    state = running_statistics.init_state(specs.Array((), jnp.float32))

    x = jnp.arange(5, dtype=jnp.float32)
    x_weights = jnp.ones_like(x)
    y = 2 * x + 5
    y_weights = 2 * x_weights
    z = jnp.concatenate([x, y])
    weights = jnp.concatenate([x_weights, y_weights])

    state = update_and_validate(state, z, weights=weights)

    # Compare with a tolerance rather than exact float equality. The running
    # state's dtype and summation order are implementation details of
    # running_statistics and need not match this float32 reference bit-for-bit.
    self.assertAlmostEqual(
        float(state.mean),
        float((jnp.mean(x) + 2 * jnp.mean(y)) / 3),
        places=5)
    big_z = jnp.concatenate([x, y, y])
    normalized = running_statistics.normalize(big_z, state)
    self.assertAlmostEqual(jnp.mean(normalized), 0., places=6)
    self.assertAlmostEqual(jnp.std(normalized), 1., places=6)
Example #14
0
    def __init__(self,
                 learner_factory: Callable[[Iterator[reverb.ReplaySample]],
                                           acme.Learner],
                 iterator: Iterator[reverb.ReplaySample],
                 environment_spec: specs.EnvironmentSpec,
                 is_sequence_based: bool,
                 batch_dims: Optional[Tuple[int, ...]],
                 max_abs_observation: Optional[float]):
        """Builds a learner that consumes observation-normalized samples.

        Args:
          learner_factory: Builds the wrapped learner from an iterator that
            yields normalized replay samples.
          iterator: Source of raw replay samples.
          environment_spec: Its `observations` spec initializes the running
            observation statistics.
          is_sequence_based: If True, only `observation` is normalized and the
            sample data must not have a `next_observation` field; otherwise
            `next_observation` is normalized as well.
          batch_dims: Not referenced by this constructor.
          max_abs_observation: Forwarded to `running_statistics.normalize` as
            `max_abs_value`.
        """
        def normalize_sample(
            observation_statistics: running_statistics.RunningStatisticsState,
            sample: reverb.ReplaySample
        ) -> Tuple[running_statistics.RunningStatisticsState,
                   reverb.ReplaySample]:
            data = sample.data
            # Fold the new observations into the statistics first, then
            # normalize using the updated statistics.
            updated_statistics = running_statistics.update(
                observation_statistics, data.observation)

            def _normalize(value):
                return running_statistics.normalize(
                    value,
                    updated_statistics,
                    max_abs_value=max_abs_observation)

            normalized_observation = _normalize(data.observation)
            if is_sequence_based:
                assert not hasattr(data, 'next_observation')
                new_data = data._replace(observation=normalized_observation)
            else:
                new_data = data._replace(
                    observation=normalized_observation,
                    next_observation=_normalize(data.next_observation))

            return updated_statistics, reverb.ReplaySample(
                sample.info, new_data)

        self._observation_running_statistics = running_statistics.init_state(
            environment_spec.observations)
        self._normalize_sample = jax.jit(normalize_sample)

        # Lazy: each sample is processed only when the wrapped learner pulls
        # it. `_normalize_sample_and_update` is defined outside this view.
        normalizing_iterator = (self._normalize_sample_and_update(sample)
                                for sample in iterator)
        self._wrapped_learner = learner_factory(normalizing_iterator)