def norm_posterior(dim, std0):
    """Build a diagonal Normal posterior with trainable mean and scale.

    Parameters
    ----------
    dim : tuple or list
        the shape of the distribution's parameters.
    std0 : float
        the initial (unoptimized) standard deviation. It is used both as the
        standard deviation of the Normal that initialises the mean, and as
        the ``alpha`` of the gamma that initialises the standard deviation.

    Returns
    -------
    Q : tf.distributions.Normal
        the initialised posterior Normal object.

    Note
    ----
    Two ``tf.Variable`` objects are created: ``"W_mu_q"`` for the mean and
    ``"W_std_q"`` for the standard deviation. The mean starts from a draw of
    a zero-mean Normal with standard deviation ``std0``; the standard
    deviation starts from a draw of a gamma distribution with an alpha of
    ``std0`` (beta is left at its default). The standard deviation variable
    is passed through ``pos`` before it is used as the scale.
    """
    # The seed for the mean is drawn from ``seedgen`` before the seed for
    # the standard deviation; keep this order when editing.
    mean_init = tf.random_normal(dim, stddev=std0, seed=next(seedgen))
    mean_var = tf.Variable(mean_init, name="W_mu_q")

    scale_init = tf.random_gamma(alpha=std0, shape=dim, seed=next(seedgen))
    scale_var = tf.Variable(scale_init, name="W_std_q")

    return tf.distributions.Normal(loc=mean_var, scale=pos(scale_var))
def norm_prior(dim, std):
    """Build a zero-mean, isotropic Normal prior with a trainable scale.

    Parameters
    ----------
    dim : tuple or list
        the shape of the (all-zero) mean of this distribution.
    std : float
        the initial prior standard deviation of this distribution.

    Returns
    -------
    P : tf.distributions.Normal
        the initialised prior Normal object.

    Note
    ----
    A single ``tf.Variable`` is created, initialised with ``std``, and passed
    through ``pos`` before being used as the scale (standard deviation) of
    the distribution.
    """
    zero_mean = tf.zeros(dim)

    # NOTE(review): the graph name is "W_mu_p" even though this variable
    # stores the standard deviation, not a mean. It is left unchanged here
    # because renaming it would alter variable names in existing graphs and
    # checkpoints; consider "W_std_p" if that is acceptable.
    scale_var = tf.Variable(std, name="W_mu_p")

    return tf.distributions.Normal(loc=zero_mean, scale=pos(scale_var))
def _initialise_variables(self, X):
    """Create the trainable imputation parameters and store their Normal.

    Reads the size of axis 2 of ``X`` and creates two ``(1, size)``
    variables: one for the imputation means (initialised from a standard
    Normal) and one for the imputation variances (initialised from a gamma
    with ``alpha=1``). The resulting distribution is stored on
    ``self.normal``; nothing is returned.
    """
    # presumably X is rank >= 3 with features on axis 2 -- confirm with caller
    n_features = int(X.shape[2])
    param_shape = (1, n_features)

    # NOTE(review): both variables request the name "impute_scalars"; the
    # names are kept as-is so the constructed graph is unchanged.
    mean_init = tf.random_normal(shape=param_shape, seed=next(seedgen))
    impute_means = tf.Variable(mean_init, name="impute_scalars")

    var_init = tf.random_gamma(alpha=1., shape=param_shape,
                               seed=next(seedgen))
    impute_var = tf.Variable(var_init, name="impute_scalars")

    # The second variable is square-rooted after ``pos``, i.e. it is treated
    # as a variance and converted to a standard deviation for the scale.
    impute_scale = tf.sqrt(pos(impute_var))
    self.normal = tf.distributions.Normal(impute_means, impute_scale)
def _chollogdet(L):
    """Log det of a cholesky, where L is (..., D, D).

    Evaluates ``2 * sum(log(diag(L)))``. The sum is taken without an axis
    argument, so it runs over every diagonal entry including any leading
    batch dimensions, and a single scalar is returned.
    """
    # ``pos`` keeps the diagonal > 0 and avoids a vanishing gradient before
    # the log is taken.
    diag = pos(tf.matrix_diag_part(L))
    log_diag = tf.log(diag)
    return 2. * tf.reduce_sum(log_diag)