Ejemplo n.º 1
0
    def test_learn_scale_shift(self):
        """PopArt statistics converge towards per-index target statistics.

        The six targets -3..2 are split by ``indices`` into two groups of
        three ([-3, -2, -1] and [0, 1, 2]). After ten updates with a large
        step size, the tracked shift should be close to each group's mean
        (-2 and 1). The tracked scale should be close to each group's
        population standard deviation. Both groups share the same spread,
        so one value is used for both entries.
        """
        n_heads = 2
        make_state, popart_update = pop_art.popart(n_heads,
                                                   step_size=1e-1,
                                                   scale_lb=1e-6,
                                                   scale_ub=1e6)
        state = make_state()
        params = get_constant_linear_params(n_heads)
        targets = np.arange(6) - 3
        indices = np.asarray([0, 0, 0, 1, 1, 1])

        # Run several updates so the statistics have time to settle.
        for _ in range(10):
            _, state = popart_update(params, state, targets, indices)

        group_std = np.std(targets[:3])
        expected_scale = np.full(n_heads, group_std)
        expected_shift = np.asarray([-2., 1.])

        # Deliberately loose: only approximate convergence is checked.
        tolerance = dict(atol=1e-1, rtol=1e-1)
        np.testing.assert_allclose(state.scale, expected_scale, **tolerance)
        np.testing.assert_allclose(state.shift, expected_shift, **tolerance)
Ejemplo n.º 2
0
    def test_slow_update(self):
        """A larger PopArt step size moves the statistics further per update.

        Two updaters share every setting except ``step_size`` (1e-2 for the
        slow one, 1e-1 for the fast one). Each is applied once to the same
        starting state and the same batch. The fast updater's resulting
        ``shift`` and ``scale`` must both be strictly greater than the slow
        one's.
        """
        num_outputs = 2
        # Two step sizes: 0.01 (slow) and 0.1 (fast); all other settings shared.
        kwargs = dict(
            num_outputs=num_outputs,
            scale_lb=1e-6,
            scale_ub=1e6,
        )
        initial_state, slow_update = pop_art.popart(step_size=1e-2, **kwargs)
        _, fast_update = pop_art.popart(step_size=1e-1, **kwargs)
        # Both updaters start from this one state so only the step size differs.
        state = initial_state()
        params = get_constant_linear_params(num_outputs)
        targets = np.arange(6) * 3  # standard deviation > 1 and mean > 0
        indices = np.asarray([0, 0, 0, 1, 1, 1])
        _, slow_state = slow_update(params, state, targets, indices)
        _, fast_state = fast_update(params, state, targets, indices)

        # Faster step size means faster adjustment.
        # NOTE(review): the strict inequalities assume the initial shift and
        # scale lie below the targets' mean and spread (e.g. shift 0, scale 1);
        # confirm against pop_art's initial state.
        np.testing.assert_array_less(slow_state.shift, fast_state.shift)
        np.testing.assert_array_less(slow_state.scale, fast_state.scale)
Ejemplo n.º 3
0
 def test_scale_bounded(self):
     """Check that the PopArt scale respects ``scale_lb`` / ``scale_ub``.

     Both bounds are pinned to 1, so after every one of four updates the
     single output's scale must still equal 1.
     """
     num_outputs = 1
     pinned_scale = 1.
     make_state, popart_update = pop_art.popart(num_outputs,
                                                step_size=1e-1,
                                                scale_lb=pinned_scale,
                                                scale_ub=pinned_scale)
     state = make_state()
     params = get_constant_linear_params(num_outputs)
     batch_shape = (4, 2)
     targets = np.ones(batch_shape)
     # Every target is routed to output 0, the only head.
     indices = np.zeros(batch_shape, dtype=np.int32)
     for _ in range(4):
         _, state = popart_update(params, state, targets, indices)
         # Checked after each step, not only at the end.
         self.assertAlmostEqual(float(state.scale[0]), pinned_scale)
Ejemplo n.º 4
0
    def test_outputs_preserved(self):
        """PopArt's parameter rewrite keeps the unnormalized outputs unchanged.

        A Haiku linear layer named ``head`` is initialised on fixed inputs and
        its outputs are recorded. The PopArt state and the head's params are
        then updated 30 times on random targets. After each update, the new
        head's outputs are mapped back through ``scale * out + shift``. The
        result must match the recorded outputs to within ``atol=1e-2``.
        """
        num_outputs = 2
        initial_state, update = pop_art.popart(num_outputs,
                                               step_size=1e-3,
                                               scale_lb=1e-6,
                                               scale_ub=1e6)
        state = initial_state()
        key = jax.random.PRNGKey(428)

        def net(x):
            linear = hk.Linear(num_outputs,
                               b_init=initializers.RandomUniform(),
                               name='head')
            return linear(x)

        init_fn, apply_fn = hk.without_apply_rng(hk.transform(net))
        key, subkey1, subkey2 = jax.random.split(key, 3)
        fixed_data = jax.random.uniform(subkey1, (4, 3))
        params = init_fn(subkey2, fixed_data)
        initial_result = apply_fn(params, fixed_data)
        # Alternate the 8 targets between the two outputs.
        indices = np.asarray([0, 1, 0, 1, 0, 1, 0, 1])
        # Repeatedly update state and verify that params still preserve outputs.
        for _ in range(30):
            # The third key is unused. Splitting three ways keeps the same
            # random stream (and therefore the same targets) as before.
            key, target_key, _unused_key = jax.random.split(key, 3)
            targets = jax.random.uniform(target_key, (8,))
            linear_params, state = update(params['head'], state, targets,
                                          indices)
            # Presumably init_fn returns an immutable mapping; copy it to a
            # mutable dict so the 'head' entry can be replaced.
            params = data_structures.to_mutable_dict(params)
            params['head'] = linear_params

            # Apply updated linear transformation and unnormalize outputs.
            # Standard broadcasting replaces the explicit broadcast_to calls
            # and gives the same result.
            transform = apply_fn(params, fixed_data)
            out = state.scale * transform + state.shift
            np.testing.assert_allclose(initial_result, out, atol=1e-2)