def testLinearFunctionZeroDebias(
        self, effective_mean, effective_log_scale, decay):
    """Compare one moving-average baseline update with and without debiasing.

    A linear function `sum(weights * x)` is evaluated on samples drawn from
    the distribution built by `utils.multi_normal(mean, log_scale)`. Starting
    from the state `(0., 0)`, a single `update_state` call is expected to
    give:

    * `(1 - decay) * function(mean)` when `zero_debias=False`;
    * `function(mean)` when `zero_debias=True`.

    The early-training decay heuristic is disabled for both variants.

    Args:
      effective_mean: scalar used for every entry of the mean vector.
      effective_log_scale: scalar used for every entry of the log-scale
        vector.
      decay: decay passed to `control_variates.moving_avg_baseline`.
    """
    weights = jnp.array([1., 2., 3.], dtype=jnp.float32)
    sample_count = 100_000

    # Both parameter vectors share the shape and dtype of `weights`.
    unit_vector = jnp.ones_like(weights)
    mean = effective_mean * unit_vector
    log_scale = effective_log_scale * unit_vector
    params = [mean, log_scale]

    def function(x):
        return jnp.sum(weights * x)

    rng = jax.random.PRNGKey(1)
    dist_samples = utils.multi_normal(*params).sample((sample_count,), rng)

    def build_update_fn(zero_debias):
        # Only the last element returned by the factory (the state update
        # function) is needed by this test.
        baseline_fns = control_variates.moving_avg_baseline(
            function,
            decay=decay,
            zero_debias=zero_debias,
            use_decay_early_training_heuristic=False)
        return baseline_fns[-1]

    plain_update = build_update_fn(False)
    debiased_update = build_update_fn(True)

    initial_state = (jnp.array(0.), 0)

    # Without debiasing, the zero initial value contributes nothing, leaving
    # only the (1 - decay) weighted function value.
    plain_value = plain_update(params, dist_samples, initial_state)[0]
    _assert_equal(plain_value, (1 - decay) * function(mean))

    # With zero debiasing the update is expected to match the function value
    # at the mean.
    debiased_value = debiased_update(params, dist_samples, initial_state)[0]
    _assert_equal(debiased_value, function(mean))
def testLinearFunctionWithHeuristic(
        self, effective_mean, effective_log_scale, decay):
    """Check the moving-average baseline with the early-training heuristic on.

    A linear function `sum(weights * x)` is evaluated on samples drawn from
    the distribution built by `utils.multi_normal(mean, log_scale)`. The test
    verifies that:

    * the control variate, averaged over the samples, equals the value
      stored in the state, and so does `expected_cv`;
    * `update_state` mixes the stored value with `function(mean)` using the
      fixed per-step decays 0.1 (step 0) and 2/11 (step 1).

    Args:
      effective_mean: scalar used for every entry of the mean vector.
      effective_log_scale: scalar used for every entry of the log-scale
        vector.
      decay: decay passed to `control_variates.moving_avg_baseline`; the
        expected values below do not reference it directly.
    """
    weights = jnp.array([1., 2., 3.], dtype=jnp.float32)
    sample_count = 100_000

    # Both parameter vectors share the shape and dtype of `weights`.
    unit_vector = jnp.ones_like(weights)
    mean = effective_mean * unit_vector
    log_scale = effective_log_scale * unit_vector
    params = [mean, log_scale]

    def function(x):
        return jnp.sum(weights * x)

    rng = jax.random.PRNGKey(1)
    dist_samples = utils.multi_normal(*params).sample((sample_count,), rng)

    cv, expected_cv, update_state = control_variates.moving_avg_baseline(
        function,
        decay=decay,
        zero_debias=False,
        use_decay_early_training_heuristic=True)

    state_1 = jnp.array(1.)
    state_2 = jnp.array(2.)

    # The baseline value is whatever the state holds, both per-sample (on
    # average) and in expectation.
    for state_value in (state_1, state_2):
        cv_state = (state_value, 0)
        mapped_cv = _map_variant(self.variant)
        avg_cv = jnp.mean(mapped_cv(cv, params, dist_samples, cv_state))
        _assert_equal(avg_cv, state_value)
        _assert_equal(expected_cv(params, cv_state), state_value)

    # (stored value, step index, expected effective decay at that step).
    # NOTE(review): 0.1 and 2/11 presumably come from the heuristic's
    # (1 + step) / (10 + step) schedule -- confirm against
    # control_variates.moving_avg_baseline.
    update_cases = (
        (state_1, 0, 0.1),
        (state_2, 1, 2. / 11),
    )
    for state_value, step, step_decay in update_cases:
        new_value = update_state(
            params, dist_samples, (state_value, step))[0]
        _assert_equal(
            new_value,
            step_decay * state_value + (1 - step_decay) * function(mean))