  # Context assumed by these snippets: `jax`, `jax.numpy as jnp`,
  # `numpy as np`, plus the `control_variates` and `utils` modules under test.
  def testLinearFunctionZeroDebias(
      self, effective_mean, effective_log_scale, decay):
    weights = jnp.array([1., 2., 3.], dtype=jnp.float32)
    num_samples = 10**5
    data_dims = len(weights)

    mean = effective_mean * jnp.ones(shape=(data_dims,), dtype=jnp.float32)
    log_scale = effective_log_scale * jnp.ones(
        shape=(data_dims,), dtype=jnp.float32)

    params = [mean, log_scale]
    function = lambda x: jnp.sum(weights * x)

    rng = jax.random.PRNGKey(1)
    dist = utils.multi_normal(*params)
    dist_samples = dist.sample((num_samples,), rng)

    update_state = control_variates.moving_avg_baseline(
        function, decay=decay, zero_debias=False,
        use_decay_early_training_heuristic=False)[-1]

    update_state_zero_debias = control_variates.moving_avg_baseline(
        function, decay=decay, zero_debias=True,
        use_decay_early_training_heuristic=False)[-1]

    # Starting from a zero state, one EMA update gives
    # decay * 0 + (1 - decay) * E[f(x)], and E[f(x)] = f(mean) for linear f.
    updated_state = update_state(params, dist_samples, (jnp.array(0.), 0))[0]
    _assert_equal(updated_state, (1 - decay) * function(mean))

    # Zero-debiasing divides out the (1 - decay) weight left by the zero
    # initialisation, recovering f(mean) directly.
    updated_state_zero_debias = update_state_zero_debias(
        params, dist_samples, (jnp.array(0.), 0))[0]
    _assert_equal(updated_state_zero_debias, function(mean))
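
The two assertions above pin down the moving-average recurrence being tested.
Below is a minimal sketch of that recurrence, grounded only in what the
assertions state; `ema_step` and `debias` are illustrative names, not the
library's API.

import jax.numpy as jnp

def ema_step(state, value, decay):
  # One exponential-moving-average update of the baseline.
  return decay * state + (1 - decay) * value

def debias(state, decay, num_updates):
  # Divide out the weight still assigned to the zero initialisation.
  return state / (1 - decay**num_updates)

decay = 0.9                     # illustrative; the test parameterises decay
f_mean = jnp.array(6.)          # f(mean) = 1 + 2 + 3 with effective_mean = 1
state = ema_step(jnp.array(0.), f_mean, decay)
assert jnp.isclose(state, (1 - decay) * f_mean)      # first assertion above
assert jnp.isclose(debias(state, decay, 1), f_mean)  # zero-debiased assertion
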
Example #2
    def testNonPolynomialFunction(self):
        data_dims = 10
        num_samples = 10**3

        mean = jnp.ones(shape=(data_dims,), dtype=jnp.float32)
        log_scale = jnp.ones(shape=(data_dims,), dtype=jnp.float32)
        params = [mean, log_scale]

        rng = jax.random.PRNGKey(1)
        dist = utils.multi_normal(*params)
        dist_samples = dist.sample((num_samples,), rng)
        function = lambda x: jnp.sum(jnp.log(x**2))

        cv, expected_cv, _ = control_variates.control_delta_method(function)
        avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, dist_samples))

        # Check that the average value of the control variate is close to the
        # expected value.
        _assert_equal(avg_cv, expected_cv(params, None), rtol=1e-1, atol=1e-3)

        # The second-order expansion is
        # log(mu**2) + 0.5 * sigma**2 * (-2 / mu**2); with mu = 1 and
        # sigma = e it equals -e**2 per dimension.
        expected_cv_val = -np.exp(1.)**2 * data_dims
        _assert_equal(expected_cv(params, None),
                      expected_cv_val,
                      rtol=1e-1,
                      atol=1e-3)
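
The expected value asserted above comes from the delta method: E[f(x)] is
approximated by the second-order expansion f(mu) + 0.5 * sigma**2 * f''(mu)
per dimension. A quick arithmetic check of `expected_cv_val` in plain numpy
(variable names here are illustrative):

import numpy as np

mu, sigma, data_dims = 1.0, np.exp(1.0), 10  # mean = 1, log_scale = 1
f_second = -2.0 / mu**2                      # d^2/dx^2 of log(x**2)
per_dim = np.log(mu**2) + 0.5 * sigma**2 * f_second
print(per_dim * data_dims)                   # -e**2 * 10, as asserted above
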
Example #3
    def testLinearFunctionWithHeuristic(self, effective_mean,
                                        effective_log_scale, decay):
        weights = jnp.array([1., 2., 3.], dtype=jnp.float32)
        num_samples = 10**5
        data_dims = len(weights)

        mean = effective_mean * jnp.ones(
            shape=(data_dims,), dtype=jnp.float32)
        log_scale = effective_log_scale * jnp.ones(
            shape=(data_dims,), dtype=jnp.float32)

        params = [mean, log_scale]
        function = lambda x: jnp.sum(weights * x)

        rng = jax.random.PRNGKey(1)
        dist = utils.multi_normal(*params)
        dist_samples = dist.sample((num_samples,), rng)

        cv, expected_cv, update_state = control_variates.moving_avg_baseline(
            function,
            decay=decay,
            zero_debias=False,
            use_decay_early_training_heuristic=True)

        state_1 = jnp.array(1.)
        avg_cv = jnp.mean(
            _map_variant(self.variant)(cv, params, dist_samples, (state_1, 0)))
        _assert_equal(avg_cv, state_1)
        _assert_equal(expected_cv(params, (state_1, 0)), state_1)

        state_2 = jnp.array(2.)
        avg_cv = jnp.mean(
            _map_variant(self.variant)(cv, params, dist_samples, (state_2, 0)))
        _assert_equal(avg_cv, state_2)
        _assert_equal(expected_cv(params, (state_2, 0)), state_2)

        first_step_decay = 0.1
        update_state_1 = update_state(params, dist_samples, (state_1, 0))[0]
        _assert_equal(
            update_state_1, first_step_decay * state_1 +
            (1 - first_step_decay) * function(mean))

        second_step_decay = 2. / 11
        update_state_2 = update_state(params, dist_samples, (state_2, 1))[0]
        _assert_equal(
            update_state_2, second_step_decay * state_2 +
            (1 - second_step_decay) * function(mean))
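
The expected decays above (0.1 on the first update, 2/11 on the second) are
consistent with a heuristic that caps the decay at (i + 1) / (i + 10) for
update counter i. A sketch inferred from those asserted values, not taken
from the library source:

def effective_decay(decay, i):
  # Keep the decay small early in training so the baseline adapts quickly.
  return min(decay, (i + 1.0) / (i + 10.0))

assert effective_decay(0.9, 0) == 0.1         # first_step_decay in the test
assert effective_decay(0.9, 1) == 2.0 / 11.0  # second_step_decay in the test
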
  def testPolynomialFunction(self, effective_mean, effective_log_scale):
    data_dims = 10
    num_samples = 10**3

    mean = effective_mean * jnp.ones(shape=(data_dims,), dtype=jnp.float32)
    log_scale = effective_log_scale * jnp.ones(
        shape=(data_dims,), dtype=jnp.float32)
    params = [mean, log_scale]

    dist = utils.multi_normal(*params)
    rng = jax.random.PRNGKey(1)
    dist_samples = dist.sample((num_samples,), rng)
    function = lambda x: jnp.sum(x**5)

    cv, expected_cv, _ = control_variates.control_delta_method(function)
    avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, dist_samples))

    # Check that the average value of the control variate is close to the
    # expected value.
    _assert_equal(avg_cv, expected_cv(params, None), rtol=1e-1, atol=1e-3)
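
For f(x) = sum(x**5) the expected control variate also has a closed form
under the same second-order expansion (assuming, as Example #2's comment
indicates, that control_delta_method expands to second order); a hedged
sketch with an illustrative helper name:

import numpy as np

def quintic_delta_expectation(mu, sigma, data_dims):
  # f(mu) + 0.5 * sigma**2 * f''(mu) per dimension, with f''(x) = 20 * x**3.
  return data_dims * (mu**5 + 0.5 * sigma**2 * 20.0 * mu**3)

print(quintic_delta_expectation(1.0, np.exp(1.0), 10))  # illustrative values
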
  def testQuadraticFunction(self, effective_mean, effective_log_scale):
    data_dims = 20
    num_samples = 10**6
    rng = jax.random.PRNGKey(1)

    mean = effective_mean * jnp.ones(shape=(data_dims,), dtype=jnp.float32)
    log_scale = effective_log_scale * jnp.ones(
        shape=(data_dims,), dtype=jnp.float32)
    params = [mean, log_scale]

    dist = utils.multi_normal(*params)
    dist_samples = dist.sample((num_samples,), rng)
    function = lambda x: jnp.sum(x**2)

    cv, expected_cv, _ = control_variates.control_delta_method(function)
    avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, dist_samples))
    expected_cv_value = jnp.sum(dist_samples**2) / num_samples

    # The second-order expansion is exact for a quadratic, so the expected
    # control variate is analytical and the result needs to be accurate.
    _assert_equal(avg_cv, expected_cv_value, rtol=1e-1, atol=1e-3)
    _assert_equal(expected_cv(params, None), expected_cv_value, atol=1e-1)
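
The tight tolerance is justified because the second-order expansion is exact
for quadratics: E[sum(x**2)] = sum(mu**2 + sigma**2), leaving only Monte
Carlo noise in avg_cv. A check with illustrative parameter values:

import numpy as np

mu, log_scale, data_dims = 1.0, 1.0, 20       # illustrative values
analytic = data_dims * (mu**2 + np.exp(log_scale)**2)
print(analytic)  # both avg_cv and expected_cv should approach this value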