def _test(shape, n):
    """Check ``InvGamma.log_prob_idx`` against SciPy's inverse-gamma logpdf.

    Builds an ``InvGamma`` whose ``alpha`` and ``beta`` tensors are
    ``tf.zeros(shape) + 0.5``, samples from it, and for every index along
    the first axis of ``shape`` asserts that ``log_prob_idx`` matches
    ``scipy.stats.invgamma.logpdf`` on the corresponding sample column.

    Args:
        shape: Shape passed to ``InvGamma`` and ``tf.zeros``. Only
            ``shape[0]`` is iterated, and samples are indexed as
            ``x[:, idx]``.
        n: Argument forwarded to ``rv.sample`` (presumably the number of
            draws -- confirm against ``InvGamma.sample``).

    Raises:
        AssertionError: If the two log-densities differ for any index
            under ``np.allclose``.
    """
    rv = InvGamma(shape,
                  alpha=tf.zeros(shape) + 0.5,
                  beta=tf.zeros(shape) + 0.5)
    rv_sample = rv.sample(n)
    # ``Tensor.eval()`` without an explicit ``session`` argument needs a
    # default session.  Open one for the duration of the checks so the test
    # does not depend on a session having been registered elsewhere; it is
    # closed automatically when the block exits.
    with tf.Session():
        x = rv_sample.eval()
        # Feed the concrete draws back in as a float32 constant so that
        # TensorFlow and SciPy score exactly the same values.
        x_tf = tf.constant(x, dtype=tf.float32)
        alpha = rv.alpha.eval()
        beta = rv.beta.eval()
        for idx in range(shape[0]):
            assert np.allclose(
                rv.log_prob_idx((idx, ), x_tf).eval(),
                stats.invgamma.logpdf(x[:, idx], alpha[idx],
                                      scale=beta[idx]))
def _test(shape, n):
    """Compare ``InvGamma.log_prob_idx`` with ``scipy.stats.invgamma.logpdf``.

    An ``InvGamma`` is created with ``alpha`` and ``beta`` both set to
    ``tf.zeros(shape) + 0.5``.  Its samples are evaluated under ``sess`` and,
    for each index in ``range(shape[0])``, the TensorFlow log-density for that
    index is asserted to be close to SciPy's value for the matching column.

    Args:
        shape: Shape passed to ``InvGamma`` and ``tf.zeros``; only
            ``shape[0]`` is looped over.
        n: Argument forwarded to ``sample``.

    Raises:
        AssertionError: If ``np.allclose`` fails for any index.
    """
    dist = InvGamma(
        shape,
        alpha=tf.zeros(shape) + 0.5,
        beta=tf.zeros(shape) + 0.5,
    )
    draws_op = dist.sample(n)

    # All ``.eval()`` calls below resolve against ``sess``.
    with sess.as_default():
        draws = draws_op.eval()
        draws_tf = tf.constant(draws, dtype=tf.float32)
        alpha_val = dist.alpha.eval()
        beta_val = dist.beta.eval()

        for k in range(shape[0]):
            actual = dist.log_prob_idx((k, ), draws_tf).eval()
            expected = stats.invgamma.logpdf(
                draws[:, k], alpha_val[k], scale=beta_val[k])
            assert np.allclose(actual, expected)
log_prior = dirichlet.logpdf(pi, self.alpha) log_prior += tf.reduce_sum(norm.logpdf(mus, 0, np.sqrt(self.c)), 1) log_prior += tf.reduce_sum(invgamma.logpdf(sigmas, self.a, self.b), 1) # Loop over each mini-batch zs[b,:] log_lik = [] n_minibatch = get_dims(zs[0])[0] for s in range(n_minibatch): log_lik_z = N*tf.reduce_sum(tf.log(pi), 1) for k in range(self.K): log_lik_z += tf.reduce_sum(multivariate_normal.logpdf(xs, mus[s, (k*self.D):((k+1)*self.D)], sigmas[s, (k*self.D):((k+1)*self.D)])) log_lik += [log_lik_z] return log_prior + tf.pack(log_lik) ed.set_seed(42) x = np.loadtxt('data/mixture_data.txt', dtype='float32', delimiter=',') data = ed.Data(tf.constant(x, dtype=tf.float32)) model = MixtureGaussian(K=2, D=2) variational = Variational() variational.add(Dirichlet(model.K)) variational.add(Normal(model.K*model.D)) variational.add(InvGamma(model.K*model.D)) inference = ed.MFVI(model, variational, data) inference.run(n_iter=500, n_minibatch=5, n_data=5)
def _test(shape, a, scale, n):
    """Assert that ``InvGamma.sample(n)`` has dimensions ``(n,) + shape``.

    Args:
        shape: Tuple passed as the first argument to ``InvGamma`` and
            appended to ``(n,)`` to form the expected dimensions.
        a: Second positional argument to ``InvGamma``.
        scale: Third positional argument to ``InvGamma``.
        n: Argument passed to ``sample``; also the expected leading dimension.

    Raises:
        AssertionError: If the sampled dimensions differ from the expectation.
    """
    samples = InvGamma(shape, a, scale).sample(n)
    observed_dims = tuple(get_dims(samples))
    expected_dims = (n, ) + shape
    assert observed_dims == expected_dims
def _test(shape, a, scale, size):
    """Assert that ``InvGamma.sample(size=size)`` has dims ``(size,) + shape``.

    Args:
        shape: Tuple passed as the first argument to ``InvGamma`` and
            appended to ``(size,)`` to form the expected dimensions.
        a: Second positional argument to ``InvGamma``.
        scale: Third positional argument to ``InvGamma``.
        size: Value passed as the ``size`` keyword of ``sample``; also the
            expected leading dimension.

    Raises:
        AssertionError: If the sampled dimensions differ from the expectation.
    """
    samples = InvGamma(shape, a, scale).sample(size=size)
    observed_dims = tuple(get_dims(samples))
    expected_dims = (size, ) + shape
    assert observed_dims == expected_dims