def testNonPolynomialFunction(self):
    """Checks the delta-method control variate for f(x) = sum(log(x**2)).

    The test checks two things:
      1. The Monte Carlo average of the control variate agrees with the
         expectation reported by `control_delta_method`.
      2. That expectation agrees with the second-order expansion of f,
         evaluated from the distribution parameters.
    """
    data_dims = 10
    num_samples = 10**3

    mean = jnp.ones(shape=(data_dims), dtype=jnp.float32)
    log_scale = jnp.ones(shape=(data_dims), dtype=jnp.float32)
    params = [mean, log_scale]

    rng = jax.random.PRNGKey(1)
    dist = utils.multi_normal(*params)
    dist_samples = dist.sample((num_samples,), rng)
    function = lambda x: jnp.sum(jnp.log(x**2))

    cv, expected_cv, _ = control_variates.control_delta_method(function)
    avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, dist_samples))

    # Check that the average value of the control variate is close to the
    # expected value.
    _assert_equal(avg_cv, expected_cv(params, None), rtol=1e-1, atol=1e-3)

    # Per dimension, the second-order expansion of log(x**2) around mu is
    #   log(mu**2) + 1/2 * sigma**2 * (-2 / mu**2)
    #     = log(mu**2) - sigma**2 / mu**2,
    # with sigma = exp(log_scale). The expected value is derived from the
    # parameters rather than hard-coded for the mu = 1, sigma = e case, so
    # the check stays valid if the fixture above changes.
    mu = np.asarray(mean, dtype=np.float64)
    sigma = np.exp(np.asarray(log_scale, dtype=np.float64))
    expected_cv_val = np.sum(np.log(mu**2) - sigma**2 / mu**2)
    _assert_equal(
        expected_cv(params, None), expected_cv_val, rtol=1e-1, atol=1e-3)
def testPolinomialFunction(self, effective_mean, effective_log_scale):
    """Checks the delta-method control variate for f(x) = sum(x**5).

    Samples are drawn from the distribution built by `utils.multi_normal`.
    The empirical mean of the control variate over those samples must match
    the expectation reported by `control_delta_method`.

    Args:
      effective_mean: Multiplier applied to a ones vector to form the mean.
      effective_log_scale: Multiplier applied to a ones vector to form the
        log-scale.
    """
    num_dims = 10
    sample_count = 10**3

    params = [
        effective_mean * jnp.ones(shape=(num_dims), dtype=jnp.float32),
        effective_log_scale * jnp.ones(shape=(num_dims), dtype=jnp.float32),
    ]
    key = jax.random.PRNGKey(1)
    samples = utils.multi_normal(*params).sample((sample_count,), key)

    def quintic_sum(x):
        return jnp.sum(x**5)

    cv, expected_cv, _ = control_variates.control_delta_method(quintic_sum)
    cv_values = _map_variant(self.variant)(cv, params, samples)

    # Empirical mean of the control variate vs. its reported expectation.
    _assert_equal(
        jnp.mean(cv_values), expected_cv(params, None), rtol=1e-1, atol=1e-3)
def testQuadraticFunction(self, effective_mean, effective_log_scale):
    """Checks the delta-method control variate for f(x) = sum(x**2).

    A second-order expansion is exact for a quadratic. The expectation
    reported by `control_delta_method` is therefore checked tightly against
    the closed form E[sum(x**2)] = sum(mu**2 + sigma**2), with
    sigma = exp(log_scale). It is also checked loosely against Monte Carlo
    estimates from samples.

    Args:
      effective_mean: Multiplier applied to a ones vector to form the mean.
      effective_log_scale: Multiplier applied to a ones vector to form the
        log-scale.
    """
    data_dims = 20
    num_samples = 10**6
    rng = jax.random.PRNGKey(1)

    mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32)
    log_scale = effective_log_scale * jnp.ones(
        shape=(data_dims), dtype=jnp.float32)
    params = [mean, log_scale]

    dist = utils.multi_normal(*params)
    dist_samples = dist.sample((num_samples,), rng)
    function = lambda x: jnp.sum(x**2)

    cv, expected_cv, _ = control_variates.control_delta_method(function)
    avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, dist_samples))
    # Monte Carlo estimate of E[f(x)] from the same samples.
    expected_cv_value = jnp.sum(dist_samples**2) / num_samples

    # Sampling-based checks carry Monte Carlo error, hence loose tolerances.
    _assert_equal(avg_cv, expected_cv_value, rtol=1e-1, atol=1e-3)
    _assert_equal(expected_cv(params, None), expected_cv_value, atol=1e-1)

    # The expectation is an analytical computation, so the result needs to
    # be accurate: compare it with the closed form, not only with sampling.
    analytic_value = jnp.sum(mean**2 + jnp.exp(2. * log_scale))
    _assert_equal(
        expected_cv(params, None), analytic_value, rtol=1e-4, atol=1e-3)