Exemplo n.º 1
0
    data[i, 1] = 3. * (np.sin(x) + npr.normal(0, .1))

# Scatter-plot the data set: column 0 on the x-axis, column 1 on the y-axis.
plt.figure()
plt.plot(data[:, 0], data[:, 1], 'kx')
plt.title('data')

# Number of mixture components handed to the model below.
nb_models = 25

# Dirichlet prior for the gating, with every concentration set to one.
gating_hypparams = dict(K=nb_models, alphas=np.ones((nb_models, )))
gating_prior = Dirichlet(**gating_hypparams)

# Normal-Wishart prior for the 2-D components; the same prior object is
# passed to every component constructed below.
components_hypparams = dict(mu=np.zeros((2, )), kappa=0.01,
                            psi=np.eye(2), nu=3)
components_prior = NormalWishart(**components_hypparams)

# Mixture with `nb_models` components and a Dirichlet-gated categorical.
gmm = BayesianMixtureOfGaussians(gating=CategoricalWithDirichlet(gating_prior),
                                 components=[GaussianWithNormalWishart(components_prior)
                                             for _ in range(nb_models)])

# NOTE(review): labels_from_prior=True presumably initializes the component
# labels by sampling from the prior -- confirm against the library.
gmm.add_data(data, labels_from_prior=True)

# Five rounds of "Gibbs, then mean field"; the scores and a snapshot of the
# model are recorded after each round.
allscores = []
allmodels = []
for superitr in range(5):
    # Gibbs sampling to wander around the posterior
    gmm.resample(maxiter=25)
    # mean field to lock onto a mode
    scores = gmm.meanfield_coordinate_descent(maxiter=100)

    allscores.append(scores)
    # deep copy so later rounds do not mutate the stored model
    allmodels.append(copy.deepcopy(gmm))
Exemplo n.º 2
0
def _job(kwargs):
    """Build and fit one Bayesian mixture of linear Gaussians.

    ``kwargs`` is consumed with ``pop`` (so the caller's dict is mutated)
    and must contain the following keys:

    - ``'arguments'``: settings object. The attributes read here are
      ``affine``, ``nb_models``, ``prior``, ``alpha``, ``gibbs_iters``,
      ``verbose``, ``super_iters``, ``stochastic``, ``svi_iters``,
      ``svi_stepsize``, ``svi_batchsize``, ``deterministic``,
      ``earlystop`` and ``meanfield_iters``.
    - ``'seed'``: value handed to ``np.random.seed``.
    - ``'train_input'`` / ``'train_target'``: arrays whose last axis gives
      the input / target dimension.

    Returns the ``BayesianMixtureOfLinearGaussians`` instance after Gibbs
    sampling and the requested mean-field rounds.
    """
    args = kwargs.pop('arguments')
    seed = kwargs.pop('seed')
    train_input = kwargs.pop('train_input')
    train_target = kwargs.pop('train_target')

    input_dim = train_input.shape[-1]
    target_dim = train_target.shape[-1]

    # Seed the global NumPy RNG before any model object is constructed.
    np.random.seed(seed)

    # One additional parameter per output when args.affine is set.
    nb_params = input_dim + 1 if args.affine else input_dim

    # Hyperparameter scales.
    psi_nw, kappa = 1e0, 1e-2    # Normal-Wishart (basis)
    psi_mnw, K = 1e0, 1e-3       # Matrix-Normal-Wishart (models)

    # One independent prior pair per mixture component; fresh arrays are
    # allocated on every iteration so no two priors share storage.
    basis_prior, models_prior = [], []
    for _ in range(args.nb_models):
        basis_prior.append(
            NormalWishart(mu=np.zeros((input_dim, )),
                          psi=np.eye(input_dim) * psi_nw,
                          kappa=kappa,
                          nu=input_dim + 1))
        models_prior.append(
            MatrixNormalWishart(M=np.zeros((target_dim, nb_params)),
                                K=K * np.eye(nb_params),
                                nu=target_dim + 1,
                                psi=np.eye(target_dim) * psi_mnw))

    # The gating is the only part that depends on args.prior; it is built
    # before the basis/model wrappers, as in the keyword order of the
    # mixture constructor call.
    if args.prior == 'stick-breaking':
        gating_prior = TruncatedStickBreaking(
            K=args.nb_models,
            gammas=np.ones((args.nb_models, )),
            deltas=np.ones((args.nb_models, )) * args.alpha)
        gating = CategoricalWithStickBreaking(gating_prior)
    else:
        gating_prior = Dirichlet(
            K=args.nb_models,
            alphas=np.ones((args.nb_models, )) * args.alpha)
        gating = CategoricalWithDirichlet(gating_prior)

    ilr = BayesianMixtureOfLinearGaussians(
        gating=gating,
        basis=[GaussianWithNormalWishart(prior) for prior in basis_prior],
        models=[LinearGaussianWithMatrixNormalWishart(prior,
                                                      affine=args.affine)
                for prior in models_prior])

    # Note the argument order: targets first, then inputs.
    ilr.add_data(train_target, train_input, whiten=True)

    # Gibbs sampling
    ilr.resample(maxiter=args.gibbs_iters, progprint=args.verbose)

    for _ in range(args.super_iters):
        # Stochastic and deterministic mean-field VI are independent
        # switches; both may run within the same round.
        if args.stochastic:
            ilr.meanfield_stochastic_descent(maxiter=args.svi_iters,
                                             stepsize=args.svi_stepsize,
                                             batchsize=args.svi_batchsize)
        if args.deterministic:
            ilr.meanfield_coordinate_descent(tol=args.earlystop,
                                             maxiter=args.meanfield_iters,
                                             progprint=args.verbose)

        # Current posteriors become the priors of the next round.
        ilr.gating.prior = ilr.gating.posterior
        for k in range(ilr.likelihood.size):
            ilr.basis[k].prior = ilr.basis[k].posterior
            ilr.models[k].prior = ilr.models[k].posterior

    return ilr
Exemplo n.º 3
0
# Ground-truth model: two Gaussians with fixed means that are given a single
# precision matrix `lmbda`, drawn here from a Wishart distribution.
# NOTE(review): assumes `stats` is scipy.stats (df=3, scale=I) -- confirm.
lmbda = stats.wishart(3, np.eye(2)).rvs()
ensemble = TiedGaussiansWithPrecision(
    mus=[np.array([1., 1.]), np.array([-1., -1.])], lmbda=lmbda)

# NOTE(review): `gating` is not defined anywhere in this excerpt; it must be
# created before this point for the script to run.
gmm = MixtureOfTiedGaussians(gating=gating, ensemble=ensemble)

# Five independent draws of size 100; only the first element of each rvs()
# return value is kept (presumably the observations -- confirm).
obs = [gmm.rvs(100)[0] for _ in range(5)]
gmm.plot(obs)

# Dirichlet prior for a two-way gating, all concentrations equal to one.
gating_hypparams = dict(K=2, alphas=np.ones((2, )))
gating_prior = Dirichlet(**gating_hypparams)

# Prior for the tied ensemble: one mean and one kappa per component, with a
# single psi / nu pair.
ensemble_hypparams = dict(mus=[np.zeros((2, )) for _ in range(2)],
                          kappas=[1. for _ in range(2)],
                          psi=np.eye(2),
                          nu=3)
ensemble_prior = TiedNormalWisharts(**ensemble_hypparams)

# Bayesian counterpart of the generating model, to be fitted to `obs`.
model = BayesianMixtureOfTiedGaussians(
    gating=CategoricalWithDirichlet(gating_prior),
    ensemble=TiedGaussiansWithNormalWishart(ensemble_prior))

model.add_data(obs)

# One resample() call with default settings, followed by a
# maximum-a-posteriori fit capped at 1000 iterations.
model.resample()
model.max_aposteriori(maxiter=1000)

# Plot the fitted model with the same observations in a new figure.
plt.figure()
model.plot(obs)
Exemplo n.º 4
0
# Ground-truth gating over two components.
gating = Categorical(K=2)

# Two Gaussians with fixed means that are given a single precision matrix
# `lmbda`, drawn here from a Wishart distribution.
# NOTE(review): assumes `stats` is scipy.stats (df=3, scale=I) -- confirm.
lmbda = stats.wishart(3, np.eye(2)).rvs()
ensemble = TiedGaussiansWithPrecision(mus=[np.array([1., 1.]),
                                           np.array([-1., -1.])],
                                      lmbda=lmbda)

gmm = MixtureOfTiedGaussians(gating=gating, ensemble=ensemble)

# Five independent draws of size 100; only the first element of each rvs()
# return value is kept (presumably the observations -- confirm).
obs = [gmm.rvs(100)[0] for _ in range(5)]
gmm.plot(obs)

# Dirichlet prior for a two-way gating, all concentrations equal to one.
gating_hypparams = dict(K=2, alphas=np.ones((2, )))
gating_prior = Dirichlet(**gating_hypparams)

# Prior for the tied ensemble: one mean and one kappa per component, with a
# single psi / nu pair.
ensemble_hypparams = dict(mus=[np.zeros((2, )) for _ in range(2)],
                          kappas=[1. for _ in range(2)],
                          psi=np.eye(2), nu=3)
ensemble_prior = TiedNormalWisharts(**ensemble_hypparams)

# Bayesian counterpart of the generating model, to be fitted to `obs`.
model = BayesianMixtureOfTiedGaussians(gating=CategoricalWithDirichlet(gating_prior),
                                       ensemble=TiedGaussiansWithNormalWishart(ensemble_prior))

model.add_data(obs)

# Fit by resampling only, capped at 1000 iterations.
model.resample(maxiter=1000)

# Plot the fitted model with the same observations in a new figure.
plt.figure()
model.plot(obs)