Example 1
0
    def testNonPolynomialFunction(self):
        """Delta-method control variate for f(x) = sum(log(x**2)).

        Uses a diagonal normal with mean 1 and log-scale 1 in every dimension,
        then checks (1) the sample average of the control variate against its
        expected value, and (2) that expected value against the closed-form
        second-order Taylor expansion of f around the mean.
        """
        data_dims = 10
        num_samples = 10**3

        mean = jnp.ones(shape=(data_dims,), dtype=jnp.float32)
        log_scale = jnp.ones(shape=(data_dims,), dtype=jnp.float32)
        params = [mean, log_scale]

        dist = utils.multi_normal(*params)
        rng = jax.random.PRNGKey(1)
        dist_samples = dist.sample((num_samples,), rng)

        def function(x):
            return jnp.sum(jnp.log(x**2))

        cv, expected_cv, _ = control_variates.control_delta_method(function)
        mapped_cv = _map_variant(self.variant)
        avg_cv = jnp.mean(mapped_cv(cv, params, dist_samples))
        cv_expectation = expected_cv(params, None)

        # The average of the control variate over the samples should be close
        # to its expected value.
        _assert_equal(avg_cv, cv_expectation, rtol=1e-1, atol=1e-3)

        # Per dimension, the second-order expansion around mu is
        #   log(mu**2) + 1/2 * sigma**2 * (-2 / mu**2).
        # With mu = 1 the log term vanishes and the rest is -sigma**2; the
        # expected value below takes sigma = exp(log_scale) = exp(1), summed
        # over all dimensions.
        closed_form = -np.exp(1.)**2 * data_dims
        _assert_equal(cv_expectation,
                      closed_form,
                      rtol=1e-1,
                      atol=1e-3)
Example 2
0
  def testPolinomialFunction(self, effective_mean, effective_log_scale):
    """Delta-method control variate for the polynomial f(x) = sum(x**5).

    Args:
      effective_mean: multiplier applied to a vector of ones to build the
        mean of the sampling distribution.
      effective_log_scale: multiplier applied to a vector of ones to build
        the log-scale of the sampling distribution.
    """
    data_dims = 10
    num_samples = 10**3

    ones = jnp.ones(shape=(data_dims,), dtype=jnp.float32)
    params = [effective_mean * ones, effective_log_scale * ones]

    rng = jax.random.PRNGKey(1)
    dist = utils.multi_normal(*params)
    dist_samples = dist.sample((num_samples,), rng)

    def function(x):
      return jnp.sum(x**5)

    cv, expected_cv, _ = control_variates.control_delta_method(function)
    mapped_cv = _map_variant(self.variant)
    avg_cv = jnp.mean(mapped_cv(cv, params, dist_samples))

    # The sample average of the control variate should be close to the
    # expected value reported by the delta method.
    _assert_equal(avg_cv, expected_cv(params, None), rtol=1e-1, atol=1e-3)
Example 3
0
  def testQuadraticFunction(self, effective_mean, effective_log_scale):
    """Delta-method control variate for the quadratic f(x) = sum(x**2).

    For a quadratic, E[f(x)] has the closed form sum_i (mu_i**2 + sigma_i**2).
    This test takes sigma_i = exp(log_scale_i), matching the convention the
    non-polynomial test's expected value relies on.

    Args:
      effective_mean: multiplier applied to a vector of ones to build the
        mean of the sampling distribution.
      effective_log_scale: multiplier applied to a vector of ones to build
        the log-scale of the sampling distribution.
    """
    data_dims = 20
    num_samples = 10**6
    rng = jax.random.PRNGKey(1)

    mean = effective_mean * jnp.ones(shape=(data_dims,), dtype=jnp.float32)
    log_scale = effective_log_scale * jnp.ones(
        shape=(data_dims,), dtype=jnp.float32)
    params = [mean, log_scale]

    dist = utils.multi_normal(*params)
    dist_samples = dist.sample((num_samples,), rng)
    function = lambda x: jnp.sum(x**2)

    cv, expected_cv, _ = control_variates.control_delta_method(function)
    avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, dist_samples))
    cv_expectation = expected_cv(params, None)

    # Monte Carlo estimate of E[f(x)] from the drawn samples.
    expected_cv_value = jnp.sum(dist_samples**2) / num_samples
    # Closed form E[f(x)] = sum_i (mu_i**2 + sigma_i**2); every dimension
    # shares the same mean and log-scale here.
    analytic_value = data_dims * (
        effective_mean**2 + np.exp(effective_log_scale)**2)

    # The sample average of the control variate should agree with the
    # Monte Carlo estimate of E[f(x)].
    _assert_equal(avg_cv, expected_cv_value, rtol=1e-1, atol=1e-3)
    _assert_equal(cv_expectation, expected_cv_value, atol=1e-1)
    # The delta method's expected value is analytic; assuming it is the
    # second-order expansion of f (exact for a quadratic), it must match the
    # closed form tightly, not just the noisy Monte Carlo estimate above.
    _assert_equal(cv_expectation, analytic_value, rtol=1e-3, atol=1e-3)