def test_e_z_update(self):
    """Check that the manual e_z update sits at an optimum of the KL.

    The stored mixture parameters are perturbed in their free
    parameterization. The e_z returned by ``rm_lib.wrap_get_loglik_terms``
    is then checked in three ways:

    - it must match ``rm_lib.get_e_z``;
    - each row must sum to one;
    - the gradient of ``rm_lib.get_kl`` with respect to the free
      parameterization of e_z must be zero.
    """
    np.random.seed(234324)

    regs = generate_test_fit.load_test_regs()
    comb_params, _ = generate_test_fit.load_test_fit()
    gmm = generate_test_fit.load_test_gmm(regs, comb_params['reg'])

    # Perturb the stored mixture parameters in the free parameterization.
    params_pattern = gmm.gmm_params_pattern
    free_dim = params_pattern.flat_length(free=True)
    perturbation = 0.01 * np.random.random(free_dim)
    base_free = params_pattern.flatten(comb_params['mix'], free=True)
    gmm_params = params_pattern.fold(base_free + perturbation, free=True)

    log_lik_by_nk, e_z = rm_lib.wrap_get_loglik_terms(
        gmm_params, gmm.transformed_reg_params)
    log_prior = rm_lib.get_log_prior(
        gmm_params['centroids'], gmm_params['probs'], gmm.prior_params)

    # The returned e_z must agree with get_e_z, and its rows sum to one.
    assert_array_almost_equal(e_z, rm_lib.get_e_z(log_lik_by_nk))
    expected_row_sums = np.ones(e_z.shape[0])
    assert_array_almost_equal(expected_row_sums, np.sum(e_z, axis=1))

    # Neither log_lik_by_nk nor log_prior depend on e_z, so they are held
    # fixed and the KL is treated as a function of e_z alone.
    def kl_given_e_z(e_z_value):
        return rm_lib.get_kl(log_lik_by_nk, e_z_value, log_prior)

    e_z_pattern = paragami.SimplexArrayPattern(
        array_shape=(e_z.shape[0], ), simplex_size=e_z.shape[1])
    kl_of_free_e_z = paragami.FlattenFunctionInput(
        kl_given_e_z, patterns=e_z_pattern, free=True)

    grad_fn = autograd.grad(kl_of_free_e_z)
    e_z_grad = grad_fn(e_z_pattern.flatten(e_z, free=True))

    # At an optimum, the gradient in the free parameterization vanishes.
    assert_array_almost_equal(np.zeros_like(e_z_grad), e_z_grad)
def test_simplex_array_patterns(self):
    """Exercise ``paragami.SimplexArrayPattern``.

    Covers round trips, inequality, the shape accessors, rejection of
    invalid inputs, and flat indices.
    """

    def check_shape_and_size(simplex_size, array_shape):
        # Strictly positive random entries, normalized along the last axis,
        # give a valid simplex point at every leading index.
        full_shape = array_shape + (simplex_size, )
        raw = np.random.random(full_shape) + 0.1
        simplex_value = raw / np.sum(raw, axis=-1, keepdims=True)
        simplex_pattern = paragami.SimplexArrayPattern(
            simplex_size, array_shape)
        _test_pattern(self, simplex_pattern, simplex_value)

    for size, leading_shape in [(4, (2, 3)), (2, (2, 3)), (2, (2, ))]:
        check_shape_and_size(size, leading_shape)

    # Patterns that differ in array shape or in simplex size are unequal.
    self.assertNotEqual(
        paragami.SimplexArrayPattern(3, (2, 3)),
        paragami.SimplexArrayPattern(3, (2, 4)))
    self.assertNotEqual(
        paragami.SimplexArrayPattern(4, (2, 3)),
        paragami.SimplexArrayPattern(3, (2, 3)))

    # The full shape is the array shape followed by the simplex size.
    pattern = paragami.SimplexArrayPattern(5, (2, 3))
    self.assertEqual((2, 3), pattern.array_shape())
    self.assertEqual(5, pattern.simplex_size())
    self.assertEqual((2, 3, 5), pattern.shape())

    # Invalid construction.
    with self.assertRaisesRegex(ValueError, 'simplex_size'):
        paragami.SimplexArrayPattern(1, (2, 3))

    # Invalid values passed to flatten / fold.
    pattern = paragami.SimplexArrayPattern(5, (2, 3))
    with self.assertRaisesRegex(ValueError, 'wrong shape'):
        pattern.flatten(np.full((2, 3, 4), 0.2), free=False)

    # Row [0, 0] still sums to one (3 * 0.2 - 0.1 + 0.5), so only the
    # negativity check can fire.
    has_negative = np.full((2, 3, 5), 0.2)
    has_negative[0, 0, 0] = -0.1
    has_negative[0, 0, 1] = 0.5
    with self.assertRaisesRegex(ValueError, 'Some values are negative'):
        pattern.flatten(has_negative, free=False)

    with self.assertRaisesRegex(ValueError, 'sum to one'):
        pattern.flatten(np.full((2, 3, 5), 0.1), free=False)

    for free in (False, True):
        with self.assertRaisesRegex(ValueError, 'wrong length'):
            pattern.fold(np.full(5, 0.2), free=free)

    with self.assertRaisesRegex(ValueError, 'sum to one'):
        pattern.fold(np.full(2 * 3 * 5, 0.1), free=False)

    # Flat indices.
    pattern = paragami.SimplexArrayPattern(5, (2, 3))
    _test_array_flat_indices(self, pattern)
def test_shape_and_size(simplex_size, array_shape):
    """Round-trip a random valid value through a SimplexArrayPattern.

    NOTE(review): ``self`` is not a parameter of this function. It is
    resolved from an enclosing scope, so this only runs where a ``self`` is
    visible, as it is when nested inside a test method.
    """
    # Add 0.1 so every entry is strictly positive before normalizing along
    # the last (simplex) axis.
    weights = np.random.random(array_shape + (simplex_size, )) + 0.1
    totals = np.sum(weights, axis=-1, keepdims=True)
    normalized = weights / totals
    simplex_pattern = paragami.SimplexArrayPattern(simplex_size, array_shape)
    _test_pattern(self, simplex_pattern, normalized)
def get_gmm_params_pattern(obs_dim, num_components):
    """Build a ``paragami`` pattern for a mixture model's parameters.

    Args:
        obs_dim: Length of each centroid vector.
        num_components: Number of mixture components (clusters).

    Returns:
        A ``paragami.PatternDict`` with two entries:

        - ``centroids``: the locations of the clusters. It is a
          ``NumericArrayPattern`` of shape ``(num_components, obs_dim)``,
          built with no bounds passed.
        - ``probs``: the a priori probabilities of each cluster. It is a
          single simplex of size ``num_components`` with array shape
          ``(1,)``, so its full shape is ``(1, num_components)``.
    """
    centroids_pattern = paragami.NumericArrayPattern((num_components, obs_dim))
    probs_pattern = paragami.SimplexArrayPattern(
        simplex_size=num_components, array_shape=(1,))

    params_pattern = paragami.PatternDict()
    params_pattern['centroids'] = centroids_pattern
    params_pattern['probs'] = probs_pattern
    return params_pattern
def get_test_pattern():
    """Return a nested ``paragami.PatternDict`` used as a test fixture.

    Every leaf pattern is built with ``default_validate=False``, because
    autograd will pass invalid values and value checking must be off.

    Keys, in insertion order:

    - ``array``: NumericArrayPattern, shape (2, 3, 4), lb=-1, ub=20.
    - ``mat``: PSDSymmetricMatrixPattern of size 3.
    - ``simplex``: SimplexArrayPattern, simplex size 2, array shape (3,).
    - ``dict``: PatternDict holding ``array2``, a NumericArrayPattern of
      shape (2,) with lb=-3, ub=10.
    """
    nested = paragami.PatternDict()
    nested['array2'] = paragami.NumericArrayPattern(
        (2, ), lb=-3, ub=10, default_validate=False)

    # Keys are inserted in a fixed order; keep it stable.
    pattern = paragami.PatternDict()
    pattern['array'] = paragami.NumericArrayPattern(
        (2, 3, 4), lb=-1, ub=20, default_validate=False)
    pattern['mat'] = paragami.PSDSymmetricMatrixPattern(
        3, default_validate=False)
    pattern['simplex'] = paragami.SimplexArrayPattern(
        2, (3, ), default_validate=False)
    pattern['dict'] = nested
    return pattern