def test_e_z_update(self):
        # The manual e_z update should be an optimum of the KL: the gradient
        # of the KL with respect to the free e_z parameters must be zero.
        np.random.seed(234324)
        regs = generate_test_fit.load_test_regs()
        comb_params, _ = generate_test_fit.load_test_fit()
        gmm = generate_test_fit.load_test_gmm(regs, comb_params['reg'])

        # Add a small random perturbation to the mixture parameters, working
        # in the free parameterization, then fold back to a parameter dict.
        params_pattern = gmm.gmm_params_pattern
        noise = 0.01 * np.random.random(params_pattern.flat_length(free=True))
        free_params = params_pattern.flatten(comb_params['mix'], free=True)
        gmm_params = params_pattern.fold(free_params + noise, free=True)

        log_lik_by_nk, e_z = rm_lib.wrap_get_loglik_terms(
            gmm_params, gmm.transformed_reg_params)
        log_prior = rm_lib.get_log_prior(
            gmm_params['centroids'], gmm_params['probs'], gmm.prior_params)

        # The returned e_z must agree with get_e_z, and its rows sum to one.
        expected_e_z = rm_lib.get_e_z(log_lik_by_nk)
        assert_array_almost_equal(e_z, expected_e_z)
        row_totals = np.sum(e_z, axis=1)
        assert_array_almost_equal(np.ones(e_z.shape[0]), row_totals)

        simplex_pattern = paragami.SimplexArrayPattern(
            simplex_size=e_z.shape[1], array_shape=(e_z.shape[0], ))

        # log_lik_by_nk and log_prior are captured as constants here; neither
        # depends on e_z, so only the e_z argument varies.
        def kl_of_e_z(e_z_value):
            return rm_lib.get_kl(log_lik_by_nk, e_z_value, log_prior)

        kl_free = paragami.FlattenFunctionInput(
            kl_of_e_z, patterns=simplex_pattern, free=True)

        e_z_free = simplex_pattern.flatten(e_z, free=True)
        kl_grad = autograd.grad(kl_free)(e_z_free)
        assert_array_almost_equal(np.zeros_like(kl_grad), kl_grad)
Exemple #2
0
    def test_simplex_array_patterns(self):
        # Exercises paragami.SimplexArrayPattern: generic pattern checks on
        # valid values, inequality of differing patterns, shape accessors,
        # validation errors, and flat indices.
        def check_shape_and_size(simplex_size, array_shape):
            # Strictly positive entries, normalized along the last axis so
            # that every simplex sums to one.
            shape = array_shape + (simplex_size, )
            valid_value = np.random.random(shape) + 0.1
            valid_value = \
                valid_value / np.sum(valid_value, axis=-1, keepdims=True)

            pattern = paragami.SimplexArrayPattern(simplex_size, array_shape)
            _test_pattern(self, pattern, valid_value)

        for simplex_size, array_shape in (
                (4, (2, 3)), (2, (2, 3)), (2, (2, ))):
            check_shape_and_size(simplex_size, array_shape)

        # Patterns that differ in array shape or in simplex size must not
        # compare equal.  assertNotEqual reports both operands on failure.
        self.assertNotEqual(
            paragami.SimplexArrayPattern(3, (2, 3)),
            paragami.SimplexArrayPattern(3, (2, 4)))
        self.assertNotEqual(
            paragami.SimplexArrayPattern(4, (2, 3)),
            paragami.SimplexArrayPattern(3, (2, 3)))

        # One pattern instance is shared by all remaining checks.
        pattern = paragami.SimplexArrayPattern(5, (2, 3))
        self.assertEqual((2, 3), pattern.array_shape())
        self.assertEqual(5, pattern.simplex_size())
        self.assertEqual((2, 3, 5), pattern.shape())

        # Test bad values.
        with self.assertRaisesRegex(ValueError, 'simplex_size'):
            paragami.SimplexArrayPattern(1, (2, 3))

        with self.assertRaisesRegex(ValueError, 'wrong shape'):
            pattern.flatten(np.full((2, 3, 4), 0.2), free=False)

        # Build the invalid input outside the context manager so that only
        # the call under test can raise inside it.  The first simplex still
        # sums to one but contains a negative entry.
        bad_folded = np.full((2, 3, 5), 0.2)
        bad_folded[0, 0, 0] = -0.1
        bad_folded[0, 0, 1] = 0.5
        with self.assertRaisesRegex(ValueError, 'Some values are negative'):
            pattern.flatten(bad_folded, free=False)

        with self.assertRaisesRegex(ValueError, 'sum to one'):
            pattern.flatten(np.full((2, 3, 5), 0.1), free=False)

        with self.assertRaisesRegex(ValueError, 'wrong length'):
            pattern.fold(np.full(5, 0.2), free=False)

        with self.assertRaisesRegex(ValueError, 'wrong length'):
            pattern.fold(np.full(5, 0.2), free=True)

        with self.assertRaisesRegex(ValueError, 'sum to one'):
            pattern.fold(np.full(2 * 3 * 5, 0.1), free=False)

        # Test flat indices.
        _test_array_flat_indices(self, pattern)
Exemple #3
0
        def test_shape_and_size(simplex_size, array_shape):
            """Run ``_test_pattern`` on a random valid simplex array.

            The value has shape ``array_shape + (simplex_size,)``, strictly
            positive entries, and is normalized along its last axis.
            """
            full_shape = array_shape + (simplex_size, )
            raw = np.random.random(full_shape) + 0.1
            simplex_value = raw / np.sum(raw, axis=-1, keepdims=True)

            _test_pattern(
                self,
                paragami.SimplexArrayPattern(simplex_size, array_shape),
                simplex_value)
Exemple #4
0
def get_gmm_params_pattern(obs_dim, num_components):
    """Build the ``paragami`` pattern describing mixture model parameters.

    The returned ``PatternDict`` has two entries:

    ``centroids``
        Locations of the clusters, as a numeric array pattern of shape
        ``(num_components, obs_dim)``.
    ``probs``
        A priori probabilities of each cluster, as a simplex array pattern
        with ``array_shape=(1,)`` and ``simplex_size=num_components``.
    """
    centroids_pattern = paragami.NumericArrayPattern(
        (num_components, obs_dim))
    probs_pattern = paragami.SimplexArrayPattern(
        simplex_size=num_components, array_shape=(1,))

    # Insertion order is kept as centroids, then probs.
    pattern = paragami.PatternDict()
    pattern['centroids'] = centroids_pattern
    pattern['probs'] = probs_pattern
    return pattern
Exemple #5
0
def get_test_pattern():
    """Return a nested ``paragami.PatternDict`` used as a test fixture.

    The dictionary holds a bounded numeric array (``'array'``), a PSD
    symmetric matrix (``'mat'``), a simplex array (``'simplex'``) and a
    sub-dictionary (``'dict'``) containing a bounded numeric array
    (``'array2'``).

    autograd will pass invalid values, so every leaf pattern is created
    with ``default_validate=False`` to turn off value checking.
    """
    no_validation = {'default_validate': False}

    nested = paragami.PatternDict()
    nested['array2'] = paragami.NumericArrayPattern(
        (2, ), lb=-3, ub=10, **no_validation)

    result = paragami.PatternDict()
    result['array'] = paragami.NumericArrayPattern(
        (2, 3, 4), lb=-1, ub=20, **no_validation)
    result['mat'] = paragami.PSDSymmetricMatrixPattern(3, **no_validation)
    result['simplex'] = paragami.SimplexArrayPattern(
        2, (3, ), **no_validation)
    result['dict'] = nested

    return result