def testLinearFunctionZeroDebias(
    self, effective_mean, effective_log_scale, decay):
  """Checks zero-debiasing of the moving-average baseline.

  Starting from a zero state, a plain EMA update keeps a (1 - decay)
  bias factor, while the zero-debiased variant recovers the function
  value at the mean exactly.
  """
  coeffs = jnp.array([1., 2., 3.], dtype=jnp.float32)
  num_samples = 10**5
  dims = len(coeffs)

  mean = effective_mean * jnp.ones(shape=(dims), dtype=jnp.float32)
  log_scale = effective_log_scale * jnp.ones(
      shape=(dims), dtype=jnp.float32)
  params = [mean, log_scale]
  function = lambda x: jnp.sum(coeffs * x)

  rng = jax.random.PRNGKey(1)
  samples = utils.multi_normal(*params).sample((num_samples,), rng)

  # Only the state-update function (last element of the tuple) is needed.
  biased_update = control_variates.moving_avg_baseline(
      function,
      decay=decay,
      zero_debias=False,
      use_decay_early_training_heuristic=False)[-1]
  debiased_update = control_variates.moving_avg_baseline(
      function,
      decay=decay,
      zero_debias=True,
      use_decay_early_training_heuristic=False)[-1]

  initial_state = (jnp.array(0.), 0)
  biased_state = biased_update(params, samples, initial_state)[0]
  _assert_equal(biased_state, (1 - decay) * function(mean))

  debiased_state = debiased_update(params, samples, initial_state)[0]
  _assert_equal(debiased_state, function(mean))
def testNonPolynomialFunction(self):
  """Delta-method control variate on a non-polynomial function.

  Uses f(x) = sum(log(x**2)); the control variate's empirical mean over
  samples should agree with its analytical expectation, which in turn
  should match the second-order Taylor expansion around the mean.
  """
  dims = 10
  num_samples = 10**3

  params = [
      jnp.ones(shape=(dims), dtype=jnp.float32),   # mean
      jnp.ones(shape=(dims), dtype=jnp.float32),   # log_scale
  ]
  rng = jax.random.PRNGKey(1)
  samples = utils.multi_normal(*params).sample((num_samples,), rng)

  function = lambda x: jnp.sum(jnp.log(x**2))
  cv, expected_cv, _ = control_variates.control_delta_method(function)

  avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, samples))
  # Empirical mean of the control variate vs. its analytical expectation.
  _assert_equal(avg_cv, expected_cv(params, None), rtol=1e-1, atol=1e-3)

  # Second-order expansion: log(mu**2) + 1/2 * sigma**2 * (-2 / mu**2).
  expected_cv_val = -np.exp(1.)**2 * dims
  _assert_equal(
      expected_cv(params, None), expected_cv_val, rtol=1e-1, atol=1e-3)
def testLinearFunctionWithHeuristic(
    self, effective_mean, effective_log_scale, decay):
  """Checks the early-training decay heuristic of the moving-avg baseline.

  With the heuristic enabled, the effective decay at step i is
  min(decay, (1 + i) / (10 + i)); this test verifies the baseline value
  returned by the CV and the state updates at steps 0 and 1.
  """
  coeffs = jnp.array([1., 2., 3.], dtype=jnp.float32)
  num_samples = 10**5
  dims = len(coeffs)

  mean = effective_mean * jnp.ones(shape=(dims), dtype=jnp.float32)
  log_scale = effective_log_scale * jnp.ones(
      shape=(dims), dtype=jnp.float32)
  params = [mean, log_scale]
  function = lambda x: jnp.sum(coeffs * x)

  rng = jax.random.PRNGKey(1)
  samples = utils.multi_normal(*params).sample((num_samples,), rng)

  cv, expected_cv, update_state = control_variates.moving_avg_baseline(
      function,
      decay=decay,
      zero_debias=False,
      use_decay_early_training_heuristic=True)

  # The CV itself just returns the stored baseline, whatever the samples.
  for state_value in (jnp.array(1.), jnp.array(2.)):
    avg_cv = jnp.mean(
        _map_variant(self.variant)(cv, params, samples, (state_value, 0)))
    _assert_equal(avg_cv, state_value)
    _assert_equal(expected_cv(params, (state_value, 0)), state_value)

  # At step 0 the heuristic decay is 1/10.
  first_step_decay = 0.1
  state_after_first = update_state(
      params, samples, (jnp.array(1.), 0))[0]
  _assert_equal(
      state_after_first,
      first_step_decay * jnp.array(1.)
      + (1 - first_step_decay) * function(mean))

  # At step 1 the heuristic decay is 2/11.
  second_step_decay = 2. / 11
  state_after_second = update_state(
      params, samples, (jnp.array(2.), 1))[0]
  _assert_equal(
      state_after_second,
      second_step_decay * jnp.array(2.)
      + (1 - second_step_decay) * function(mean))
# NOTE(review): "Polinomial" is a spelling slip for "Polynomial"; the name
# is kept as-is since renaming would change test discovery.
def testPolinomialFunction(self, effective_mean, effective_log_scale):
  """Delta-method control variate on a polynomial function, f(x) = sum(x**5).

  Verifies that the empirical mean of the control variate over samples is
  close to its analytical expectation.
  """
  dims = 10
  num_samples = 10**3

  mean = effective_mean * jnp.ones(shape=(dims), dtype=jnp.float32)
  log_scale = effective_log_scale * jnp.ones(
      shape=(dims), dtype=jnp.float32)
  params = [mean, log_scale]

  rng = jax.random.PRNGKey(1)
  samples = utils.multi_normal(*params).sample((num_samples,), rng)

  function = lambda x: jnp.sum(x**5)
  cv, expected_cv, _ = control_variates.control_delta_method(function)

  avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, samples))
  # Empirical mean of the control variate vs. its analytical expectation.
  _assert_equal(avg_cv, expected_cv(params, None), rtol=1e-1, atol=1e-3)
def testQuadraticFunction(self, effective_mean, effective_log_scale):
  """Delta-method control variate on f(x) = sum(x**2).

  For a quadratic, the delta-method expansion is exact, so both the
  per-sample CV average and the analytical expectation should match the
  Monte-Carlo estimate of E[sum(x**2)].
  """
  dims = 20
  num_samples = 10**6
  rng = jax.random.PRNGKey(1)

  mean = effective_mean * jnp.ones(shape=(dims), dtype=jnp.float32)
  log_scale = effective_log_scale * jnp.ones(
      shape=(dims), dtype=jnp.float32)
  params = [mean, log_scale]
  samples = utils.multi_normal(*params).sample((num_samples,), rng)

  function = lambda x: jnp.sum(x**2)
  cv, expected_cv, _ = control_variates.control_delta_method(function)

  avg_cv = jnp.mean(_map_variant(self.variant)(cv, params, samples))
  mc_estimate = jnp.sum(samples**2) / num_samples

  # The expansion is analytical for a quadratic, so results must be close.
  _assert_equal(avg_cv, mc_estimate, rtol=1e-1, atol=1e-3)
  _assert_equal(expected_cv(params, None), mc_estimate, atol=1e-1)