class DeltaMethodAnalyticalExpectedGrads(chex.TestCase):
  """Checks delta-method control-variate gradients against analytical values.

  Each test estimates the jacobians of E_{N(mu, sigma)}[f(x)] with a Monte
  Carlo gradient estimator plus the delta-method control variate, averages
  them over samples, and compares against gradients derived in closed form
  from the uncentered moments of the Gaussian distribution.
  """

  @chex.all_variants
  @parameterized.named_parameters(
      chex.params_product([
          ('_score_function_jacobians', 1.0, 1.0,
           sge.score_function_jacobians),
          ('_pathwise_jacobians', 1.0, 1.0, sge.pathwise_jacobians),
          ('_measure_valued_jacobians', 1.0, 1.0,
           sge.measure_valued_jacobians),
      ], [
          ('estimate_cv_coeffs', True),
          ('no_estimate_cv_coeffs', False),
      ],
                          named=True))
  def testQuadraticFunction(self, effective_mean, effective_log_scale,
                            grad_estimator, estimate_cv_coeffs):
    """f(x) = sum(x**2): grads are 2*mu (mean) and 2*sigma**2 (log scale)."""
    data_dims = 3
    num_samples = 10**3
    mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32)
    log_scale = effective_log_scale * jnp.ones(
        shape=(data_dims), dtype=jnp.float32)
    params = [mean, log_scale]
    function = lambda x: jnp.sum(x**2)
    rng = jax.random.PRNGKey(1)
    jacobians = _cv_jac_variant(self.variant)(
        function,
        control_variates.control_delta_method,
        grad_estimator,
        params,
        utils.multi_normal,  # dist_builder
        rng,
        num_samples,
        None,  # No cv state.
        estimate_cv_coeffs)[0]

    # E[x**2] = mu**2 + sigma**2, so d/dmu = 2 mu; via the chain rule through
    # sigma = exp(log_scale), d/dlog_scale = 2 sigma**2 = 2 exp(2 log_scale).
    expected_mean_grads = 2 * effective_mean * np.ones(
        data_dims, dtype=np.float32)
    expected_log_scale_grads = 2 * np.exp(2 * effective_log_scale) * np.ones(
        data_dims, dtype=np.float32)

    mean_jacobians = jacobians[0]
    chex.assert_shape(mean_jacobians, (num_samples, data_dims))
    mean_grads_from_jacobian = jnp.mean(mean_jacobians, axis=0)

    log_scale_jacobians = jacobians[1]
    chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
    log_scale_grads_from_jacobian = jnp.mean(log_scale_jacobians, axis=0)

    _assert_equal(mean_grads_from_jacobian, expected_mean_grads,
                  rtol=1e-1, atol=1e-3)
    _assert_equal(log_scale_grads_from_jacobian, expected_log_scale_grads,
                  rtol=1e-1, atol=1e-3)

  @chex.all_variants
  @parameterized.named_parameters(
      chex.params_product([
          ('_score_function_jacobians', 1.0, 1.0,
           sge.score_function_jacobians),
          ('_pathwise_jacobians', 1.0, 1.0, sge.pathwise_jacobians),
          ('_measure_valued_jacobians', 1.0, 1.0,
           sge.measure_valued_jacobians),
      ], [
          ('estimate_cv_coeffs', True),
          ('no_estimate_cv_coeffs', False),
      ],
                          named=True))
  def testCubicFunction(self, effective_mean, effective_log_scale,
                        grad_estimator, estimate_cv_coeffs):
    """f(x) = sum(x**3): grads follow the third Gaussian moment."""
    data_dims = 1
    num_samples = 10**5
    mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32)
    log_scale = effective_log_scale * jnp.ones(
        shape=(data_dims), dtype=jnp.float32)
    params = [mean, log_scale]
    function = lambda x: jnp.sum(x**3)
    rng = jax.random.PRNGKey(1)
    jacobians = _cv_jac_variant(self.variant)(
        function,
        control_variates.control_delta_method,
        grad_estimator,
        params,
        utils.multi_normal,
        rng,
        num_samples,
        None,  # No cv state.
        estimate_cv_coeffs)[0]

    # The third order uncentered moment of the Gaussian distribution is
    # mu**3 + 3 mu * sigma**2 (fixed: not 2 mu sigma**2). We use that to
    # compute the expected value of the gradients:
    #   d/dmu = 3 mu**2 + 3 sigma**2; d/dsigma = 6 mu sigma.
    # Note: for the log scale we need to use the chain rule
    # (d/dlog_scale = sigma * d/dsigma = 6 mu sigma**2).
    expected_mean_grads = (
        3 * effective_mean**2 + 3 * np.exp(effective_log_scale)**2)
    expected_mean_grads *= np.ones(data_dims, dtype=np.float32)
    expected_log_scale_grads = (
        6 * effective_mean * np.exp(effective_log_scale)**2)
    expected_log_scale_grads *= np.ones(data_dims, dtype=np.float32)

    mean_jacobians = jacobians[0]
    chex.assert_shape(mean_jacobians, (num_samples, data_dims))
    mean_grads_from_jacobian = jnp.mean(mean_jacobians, axis=0)

    log_scale_jacobians = jacobians[1]
    chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
    log_scale_grads_from_jacobian = jnp.mean(log_scale_jacobians, axis=0)

    _assert_equal(mean_grads_from_jacobian, expected_mean_grads,
                  rtol=1e-1, atol=1e-3)
    _assert_equal(log_scale_grads_from_jacobian, expected_log_scale_grads,
                  rtol=1e-1, atol=1e-3)

  @chex.all_variants
  @parameterized.named_parameters(
      chex.params_product([
          ('_score_function_jacobians', 1.0, 1.0,
           sge.score_function_jacobians),
          ('_pathwise_jacobians', 1.0, 1.0, sge.pathwise_jacobians),
          ('_measure_valued_jacobians', 1.0, 1.0,
           sge.measure_valued_jacobians),
      ], [
          ('estimate_cv_coeffs', True),
          ('no_estimate_cv_coeffs', False),
      ],
                          named=True))
  def testForthPowerFunction(self, effective_mean, effective_log_scale,
                             grad_estimator, estimate_cv_coeffs):
    """f(x) = sum(x**4): grads follow the fourth Gaussian moment.

    NOTE: "Forth" in the method name is a typo for "Fourth"; it is kept so
    the test id stays stable for existing filters/runs.
    """
    data_dims = 1
    num_samples = 10**5
    mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32)
    log_scale = effective_log_scale * jnp.ones(
        shape=(data_dims), dtype=jnp.float32)
    params = [mean, log_scale]
    function = lambda x: jnp.sum(x**4)
    rng = jax.random.PRNGKey(1)
    jacobians = _cv_jac_variant(self.variant)(
        function,
        control_variates.control_delta_method,
        grad_estimator,
        params,
        utils.multi_normal,
        rng,
        num_samples,
        None,  # No cv state
        estimate_cv_coeffs)[0]

    # The fourth order uncentered moment of the Gaussian distribution is
    # mu**4 + 6 mu**2 sigma**2 + 3 sigma**4. We use that to compute the
    # expected value of the gradients.
    # Note: for the log scale we need to use the chain rule.
    # d/dmu E[x**4] = 4 mu**3 + 12 mu sigma**2 (fixed: the leading
    # coefficient is 4, not 3; the previous value passed only because of the
    # loose rtol=1e-1 tolerance).
    expected_mean_grads = (
        4 * effective_mean**3 +
        12 * effective_mean * np.exp(effective_log_scale)**2)
    expected_mean_grads *= np.ones(data_dims, dtype=np.float32)
    # d/dsigma = 12 mu**2 sigma + 12 sigma**3; chain rule multiplies by sigma.
    expected_log_scale_grads = 12 * (
        effective_mean**2 * np.exp(effective_log_scale) +
        np.exp(effective_log_scale)**3) * np.exp(effective_log_scale)
    expected_log_scale_grads *= np.ones(data_dims, dtype=np.float32)

    mean_jacobians = jacobians[0]
    chex.assert_shape(mean_jacobians, (num_samples, data_dims))
    mean_grads_from_jacobian = jnp.mean(mean_jacobians, axis=0)

    log_scale_jacobians = jacobians[1]
    chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
    log_scale_grads_from_jacobian = jnp.mean(log_scale_jacobians, axis=0)

    _assert_equal(mean_grads_from_jacobian, expected_mean_grads,
                  rtol=1e-1, atol=1e-3)
    _assert_equal(log_scale_grads_from_jacobian, expected_log_scale_grads,
                  rtol=1e-1, atol=1e-3)
class ConsistencyWithStandardEstimators(chex.TestCase):
  """Checks control-variate estimators agree with plain gradient estimators.

  Each test runs the same gradient estimator twice — once through the
  control-variate wrapper and once directly — on independent RNG streams,
  and asserts the sample-averaged gradients match within tolerance.
  """

  @chex.all_variants
  @parameterized.named_parameters(
      chex.params_product([
          ('_score_function_jacobians', 1, 1, sge.score_function_jacobians,
           10**6),
          ('_pathwise_jacobians', 1, 1, sge.pathwise_jacobians, 10**5),
          ('_measure_valued_jacobians', 1, 1, sge.measure_valued_jacobians,
           10**5),
      ], [
          ('control_delta_method', control_variates.control_delta_method),
          ('moving_avg_baseline', control_variates.moving_avg_baseline),
      ],
                          named=True))
  def testWeightedLinearFunction(self, effective_mean, effective_log_scale,
                                 grad_estimator, num_samples,
                                 control_variate_from_function):
    """Check that the gradients are consistent between estimators."""
    weights = jnp.array([1., 2., 3.], dtype=jnp.float32)
    data_dims = len(weights)
    mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32)
    log_scale = effective_log_scale * jnp.ones(
        shape=(data_dims), dtype=jnp.float32)
    params = [mean, log_scale]
    # Linear objective: f(x) = <weights, x>.
    function = lambda x: jnp.sum(weights * x)
    rng = jax.random.PRNGKey(1)
    cv_rng, ge_rng = jax.random.split(rng)
    # Jacobians with the control variate applied.
    jacobians = _cv_jac_variant(self.variant)(
        function,
        control_variate_from_function,
        grad_estimator,
        params,
        utils.multi_normal,  # dist_builder
        cv_rng,  # rng
        num_samples,
        (0., 0),  # control_variate_state
        False)[0]
    mean_jacobians = jacobians[0]
    chex.assert_shape(mean_jacobians, (num_samples, data_dims))
    mean_grads = jnp.mean(mean_jacobians, axis=0)
    log_scale_jacobians = jacobians[1]
    chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
    log_scale_grads = jnp.mean(log_scale_jacobians, axis=0)
    # We use a different random number generator for the gradient estimator
    # without the control variate.
    no_cv_jacobians = grad_estimator(
        function, [mean, log_scale],
        utils.multi_normal, ge_rng,
        num_samples=num_samples)
    no_cv_mean_jacobians = no_cv_jacobians[0]
    chex.assert_shape(no_cv_mean_jacobians, (num_samples, data_dims))
    no_cv_mean_grads = jnp.mean(no_cv_mean_jacobians, axis=0)
    no_cv_log_scale_jacobians = no_cv_jacobians[1]
    chex.assert_shape(no_cv_log_scale_jacobians, (num_samples, data_dims))
    no_cv_log_scale_grads = jnp.mean(no_cv_log_scale_jacobians, axis=0)
    _assert_equal(mean_grads, no_cv_mean_grads, rtol=1e-1, atol=5e-2)
    # NOTE(review): rtol=1 is much looser than the rtol=1e-1 used elsewhere —
    # presumably to absorb estimator variance in the log-scale gradients;
    # confirm this is intentional.
    _assert_equal(log_scale_grads, no_cv_log_scale_grads, rtol=1, atol=5e-2)

  @chex.all_variants
  @parameterized.named_parameters(
      chex.params_product([
          ('_score_function_jacobians', 1, 1, sge.score_function_jacobians,
           10**5),
          ('_pathwise_jacobians', 1, 1, sge.pathwise_jacobians, 10**5),
          ('_measure_valued_jacobians', 1, 1, sge.measure_valued_jacobians,
           10**5),
      ], [
          ('control_delta_method', control_variates.control_delta_method),
          ('moving_avg_baseline', control_variates.moving_avg_baseline),
      ],
                          named=True))
  def testNonPolynomialFunction(self, effective_mean, effective_log_scale,
                                grad_estimator, num_samples,
                                control_variate_from_function):
    """Check that the gradients are consistent between estimators."""
    data_dims = 3
    mean = effective_mean * jnp.ones(shape=(data_dims), dtype=jnp.float32)
    log_scale = effective_log_scale * jnp.ones(
        shape=(data_dims), dtype=jnp.float32)
    params = [mean, log_scale]
    # Non-polynomial objective: the delta-method control variate is only an
    # approximation here, but the estimate must remain unbiased.
    function = lambda x: jnp.log(jnp.sum(x**2))
    rng = jax.random.PRNGKey(1)
    cv_rng, ge_rng = jax.random.split(rng)
    jacobians = _cv_jac_variant(self.variant)(
        function,
        control_variate_from_function,
        grad_estimator,
        params,
        utils.multi_normal,
        cv_rng,
        num_samples,
        (0., 0),  # control_variate_state
        False)[0]
    mean_jacobians = jacobians[0]
    chex.assert_shape(mean_jacobians, (num_samples, data_dims))
    mean_grads = jnp.mean(mean_jacobians, axis=0)
    log_scale_jacobians = jacobians[1]
    chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
    log_scale_grads = jnp.mean(log_scale_jacobians, axis=0)
    # We use a different random number generator for the gradient estimator
    # without the control variate.
    no_cv_jacobians = grad_estimator(
        function, [mean, log_scale],
        utils.multi_normal, ge_rng,
        num_samples=num_samples)
    no_cv_mean_jacobians = no_cv_jacobians[0]
    chex.assert_shape(no_cv_mean_jacobians, (num_samples, data_dims))
    no_cv_mean_grads = jnp.mean(no_cv_mean_jacobians, axis=0)
    no_cv_log_scale_jacobians = no_cv_jacobians[1]
    chex.assert_shape(no_cv_log_scale_jacobians, (num_samples, data_dims))
    no_cv_log_scale_grads = jnp.mean(no_cv_log_scale_jacobians, axis=0)
    _assert_equal(mean_grads, no_cv_mean_grads, rtol=1e-1, atol=5e-2)
    _assert_equal(log_scale_grads, no_cv_log_scale_grads, rtol=1e-1, atol=5e-2)
class GradientEstimatorsTest(chex.TestCase):
  """Checks raw gradient estimators (no control variate) on simple objectives.

  Expected gradients are derived analytically from E_{N(mu, sigma)}[f(x)];
  the measure-valued estimator is additionally cross-checked against the
  pathwise estimator on non-polynomial functions.
  """

  @chex.all_variants
  @parameterized.named_parameters(
      chex.params_product([
          ('_score_function_jacobians', sge.score_function_jacobians),
          ('_pathwise_jacobians', sge.pathwise_jacobians),
          ('_measure_valued_jacobians', sge.measure_valued_jacobians),
      ], [
          ('0.1', 0.1),
          ('0.5', 0.5),
          ('0.9', 0.9),
      ],
                          named=True))
  def testConstantFunction(self, estimator, constant):
    """A constant objective must produce zero gradients for all params."""
    data_dims = 3
    num_samples = _estimator_to_num_samples[estimator]
    effective_mean = 1.5
    mean = effective_mean * _ones(data_dims)
    effective_log_scale = 0.0
    log_scale = effective_log_scale * _ones(data_dims)
    rng = jax.random.PRNGKey(1)
    jacobians = _estimator_variant(self.variant, estimator)(
        lambda x: jnp.array(constant), [mean, log_scale],
        utils.multi_normal, rng, num_samples)
    # Average over the number of samples.
    mean_jacobians = jacobians[0]
    chex.assert_shape(mean_jacobians, (num_samples, data_dims))
    mean_grads = np.mean(mean_jacobians, axis=0)
    expected_mean_grads = np.zeros(data_dims, dtype=np.float32)
    log_scale_jacobians = jacobians[1]
    chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
    log_scale_grads = np.mean(log_scale_jacobians, axis=0)
    expected_log_scale_grads = np.zeros(data_dims, dtype=np.float32)
    _assert_equal(mean_grads, expected_mean_grads, atol=5e-3)
    _assert_equal(log_scale_grads, expected_log_scale_grads, atol=5e-3)

  @chex.all_variants
  @parameterized.named_parameters(
      chex.params_product([
          ('_score_function_jacobians', sge.score_function_jacobians),
          ('_pathwise_jacobians', sge.pathwise_jacobians),
          ('_measure_valued_jacobians', sge.measure_valued_jacobians),
      ], [
          ('0.5_-1.', 0.5, -1.),
          ('0.7_0.0)', 0.7, 0.0),
          ('0.8_0.1', 0.8, 0.1),
      ],
                          named=True))
  def testLinearFunction(self, estimator, effective_mean, effective_log_scale):
    """f(x) = sum(x): mean grads are ones, log-scale grads are zeros."""
    data_dims = 3
    num_samples = _estimator_to_num_samples[estimator]
    rng = jax.random.PRNGKey(1)
    mean = effective_mean * _ones(data_dims)
    log_scale = effective_log_scale * _ones(data_dims)
    jacobians = _estimator_variant(self.variant, estimator)(
        np.sum, [mean, log_scale], utils.multi_normal, rng, num_samples)
    mean_jacobians = jacobians[0]
    chex.assert_shape(mean_jacobians, (num_samples, data_dims))
    mean_grads = np.mean(mean_jacobians, axis=0)
    expected_mean_grads = np.ones(data_dims, dtype=np.float32)
    log_scale_jacobians = jacobians[1]
    chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
    log_scale_grads = np.mean(log_scale_jacobians, axis=0)
    expected_log_scale_grads = np.zeros(data_dims, dtype=np.float32)
    _assert_equal(mean_grads, expected_mean_grads)
    _assert_equal(log_scale_grads, expected_log_scale_grads)

  @chex.all_variants
  @parameterized.named_parameters(
      chex.params_product([
          ('_score_function_jacobians', sge.score_function_jacobians),
          ('_pathwise_jacobians', sge.pathwise_jacobians),
          ('_measure_valued_jacobians', sge.measure_valued_jacobians),
      ], [
          ('1.0_0.3', 1.0, 0.3),
      ],
                          named=True))
  def testQuadraticFunction(self, estimator, effective_mean,
                            effective_log_scale):
    """f(x) = sum(x**2)/2: grads are mu (mean) and sigma**2 (log scale)."""
    data_dims = 3
    num_samples = _estimator_to_num_samples[estimator]
    rng = jax.random.PRNGKey(1)
    mean = effective_mean * _ones(data_dims)
    log_scale = effective_log_scale * _ones(data_dims)
    jacobians = _estimator_variant(self.variant, estimator)(
        lambda x: np.sum(x**2) / 2, [mean, log_scale],
        utils.multi_normal, rng, num_samples)
    mean_jacobians = jacobians[0]
    chex.assert_shape(mean_jacobians, (num_samples, data_dims))
    mean_grads = np.mean(mean_jacobians, axis=0)
    expected_mean_grads = effective_mean * np.ones(data_dims, dtype=np.float32)
    log_scale_jacobians = jacobians[1]
    chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
    log_scale_grads = np.mean(log_scale_jacobians, axis=0)
    expected_log_scale_grads = np.exp(2 * effective_log_scale) * np.ones(
        data_dims, dtype=np.float32)
    _assert_equal(mean_grads, expected_mean_grads, atol=5e-2)
    _assert_equal(log_scale_grads, expected_log_scale_grads, atol=5e-2)

  @chex.all_variants
  @parameterized.named_parameters(
      chex.params_product([
          ('_score_function_jacobians', sge.score_function_jacobians),
          ('_pathwise_jacobians', sge.pathwise_jacobians),
          ('_measure_valued_jacobians', sge.measure_valued_jacobians),
      ], [
          ('case_1', [1.0, 2.0, 3.], [-1., 0.3, -2.], [1., 1., 1.]),
          ('case_2', [1.0, 2.0, 3.], [-1., 0.3, -2.], [4., 2., 3.]),
          ('case_3', [1.0, 2.0, 3.], [0.1, 0.2, 0.1], [10., 5., 1.]),
      ],
                          named=True))
  def testWeightedLinear(self, estimator, effective_mean, effective_log_scale,
                         weights):
    """f(x) = <weights, x>: mean grads are the weights; log-scale grads 0."""
    num_samples = _weighted_estimator_to_num_samples[estimator]
    rng = jax.random.PRNGKey(1)
    mean = jnp.array(effective_mean)
    log_scale = jnp.array(effective_log_scale)
    weights = jnp.array(weights)
    data_dims = len(effective_mean)
    function = lambda x: jnp.sum(x * weights)
    jacobians = _estimator_variant(self.variant, estimator)(
        function, [mean, log_scale], utils.multi_normal, rng, num_samples)
    mean_jacobians = jacobians[0]
    chex.assert_shape(mean_jacobians, (num_samples, data_dims))
    mean_grads = np.mean(mean_jacobians, axis=0)
    log_scale_jacobians = jacobians[1]
    chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
    log_scale_grads = np.mean(log_scale_jacobians, axis=0)
    expected_mean_grads = weights
    expected_log_scale_grads = np.zeros(data_dims, dtype=np.float32)
    _assert_equal(mean_grads, expected_mean_grads, atol=5e-2)
    _assert_equal(log_scale_grads, expected_log_scale_grads, atol=5e-2)

  @chex.all_variants
  @parameterized.named_parameters(
      chex.params_product([
          ('_score_function_jacobians', sge.score_function_jacobians),
          ('_pathwise_jacobians', sge.pathwise_jacobians),
          ('_measure_valued_jacobians', sge.measure_valued_jacobians),
      ], [
          ('case_1', [1.0, 2.0, 3.], [-1., 0.3, -2.], [1., 1., 1.]),
          ('case_2', [1.0, 2.0, 3.], [-1., 0.3, -2.], [4., 2., 3.]),
          ('case_3', [1.0, 2.0, 3.], [0.1, 0.2, 0.1], [3., 5., 1.]),
      ],
                          named=True))
  def testWeightedQuadratic(self, estimator, effective_mean,
                            effective_log_scale, weights):
    """f(x) = <weights, x>**2: grads derived from the Gaussian covariance."""
    num_samples = _weighted_estimator_to_num_samples[estimator]
    rng = jax.random.PRNGKey(1)
    mean = jnp.array(effective_mean, dtype=jnp.float32)
    log_scale = jnp.array(effective_log_scale, dtype=jnp.float32)
    weights = jnp.array(weights, dtype=jnp.float32)
    data_dims = len(effective_mean)
    function = lambda x: jnp.sum(x * weights)**2
    jacobians = _estimator_variant(self.variant, estimator)(
        function, [mean, log_scale], utils.multi_normal, rng, num_samples)
    mean_jacobians = jacobians[0]
    chex.assert_shape(mean_jacobians, (num_samples, data_dims))
    mean_grads = np.mean(mean_jacobians, axis=0)
    log_scale_jacobians = jacobians[1]
    chex.assert_shape(log_scale_jacobians, (num_samples, data_dims))
    log_scale_grads = np.mean(log_scale_jacobians, axis=0)
    expected_mean_grads = 2 * weights * np.sum(weights * mean)
    effective_scale = np.exp(log_scale)
    expected_scale_grads = 2 * weights**2 * effective_scale
    # Chain rule through sigma = exp(log_scale): multiply by sigma again.
    expected_log_scale_grads = expected_scale_grads * effective_scale
    _assert_equal(mean_grads, expected_mean_grads, atol=1e-1, rtol=1e-1)
    _assert_equal(log_scale_grads, expected_log_scale_grads,
                  atol=1e-1, rtol=1e-1)

  @chex.all_variants
  @parameterized.named_parameters(
      chex.params_product(
          [
              ('_sum_cos_x', [1.0], [1.0], lambda x: jnp.sum(jnp.cos(x))),
              # Need to ensure that the mean is not too close to 0.
              ('_sum_log_x', [10.0], [0.0], lambda x: jnp.sum(jnp.log(x))),
              ('_sum_cos_2x', [1.0, 2.0], [1.0, -2],
               lambda x: jnp.sum(jnp.cos(2 * x))),
              ('_cos_sum_2x', [1.0, 2.0], [1.0, -2],
               lambda x: jnp.cos(jnp.sum(2 * x))),
          ],
          [
              ('coupling', True),
              ('nocoupling', False),
          ],
          named=True))
  def testNonPolynomialFunctionConsistencyWithPathwise(
      self, effective_mean, effective_log_scale, function, coupling):
    """Measure-valued grads must agree with pathwise grads on smooth f."""
    num_samples = 10**5
    rng = jax.random.PRNGKey(1)
    # Independent RNG streams for the two estimators being compared.
    measure_rng, pathwise_rng = jax.random.split(rng)
    mean = jnp.array(effective_mean, dtype=jnp.float32)
    log_scale = jnp.array(effective_log_scale, dtype=jnp.float32)
    data_dims = len(effective_mean)
    measure_valued_jacobians = _measure_valued_variant(self.variant)(
        function, [mean, log_scale], utils.multi_normal, measure_rng,
        num_samples, coupling)
    measure_valued_mean_jacobians = measure_valued_jacobians[0]
    chex.assert_shape(measure_valued_mean_jacobians, (num_samples, data_dims))
    measure_valued_mean_grads = np.mean(measure_valued_mean_jacobians, axis=0)
    measure_valued_log_scale_jacobians = measure_valued_jacobians[1]
    chex.assert_shape(measure_valued_log_scale_jacobians,
                      (num_samples, data_dims))
    measure_valued_log_scale_grads = np.mean(
        measure_valued_log_scale_jacobians, axis=0)
    pathwise_jacobians = _estimator_variant(
        self.variant, sge.pathwise_jacobians)(
            function, [mean, log_scale], utils.multi_normal, pathwise_rng,
            num_samples)
    pathwise_mean_jacobians = pathwise_jacobians[0]
    chex.assert_shape(pathwise_mean_jacobians, (num_samples, data_dims))
    pathwise_mean_grads = np.mean(pathwise_mean_jacobians, axis=0)
    pathwise_log_scale_jacobians = pathwise_jacobians[1]
    chex.assert_shape(pathwise_log_scale_jacobians, (num_samples, data_dims))
    pathwise_log_scale_grads = np.mean(pathwise_log_scale_jacobians, axis=0)
    _assert_equal(pathwise_mean_grads, measure_valued_mean_grads,
                  rtol=5e-1, atol=1e-1)
    _assert_equal(pathwise_log_scale_grads, measure_valued_log_scale_grads,
                  rtol=5e-1, atol=1e-1)