def _test(distribution, bijector, n):
    x = TransformedDistribution(distribution=distribution, bijector=bijector,
                                validate_args=True)
    val_est = get_dims(x.sample(n))
    val_true = n + get_dims(distribution.mean())
    assert val_est == val_true
def __init__(self, M, C, theta_prior, delta_prior, a_prior):
    self.M = M
    self.C = C
    self.theta_prior = theta_prior  # prior of ability
    self.delta_prior = delta_prior  # prior of difficulty
    self.a_prior = a_prior  # prior of discrimination

    if isinstance(a_prior, ed.RandomVariable):
        # variational posterior of discrimination
        self.qa = Normal(loc=tf.Variable(tf.ones([M])),
                         scale=tf.nn.softplus(tf.Variable(tf.ones([M]) * .5)),
                         name='qa')
    else:
        self.qa = a_prior

    with tf.variable_scope('local'):
        # variational posterior of ability
        if isinstance(self.theta_prior, RandomVariable):
            self.qtheta = TransformedDistribution(
                distribution=Normal(loc=tf.Variable(tf.random_normal([C])),
                                    scale=tf.nn.softplus(tf.Variable(tf.random_normal([C])))),
                bijector=ds.bijectors.Sigmoid(),
                sample_shape=[M], name='qtheta')
        else:
            self.qtheta = self.theta_prior

        # variational posterior of difficulty
        self.qdelta = TransformedDistribution(
            distribution=Normal(loc=tf.Variable(tf.random_normal([M])),
                                scale=tf.nn.softplus(tf.Variable(tf.random_normal([M])))),
            bijector=ds.bijectors.Sigmoid(),
            sample_shape=[C], name='qdelta')

    alpha = (tf.transpose(self.qtheta) / self.qdelta) ** self.qa
    beta = ((1. - tf.transpose(self.qtheta)) / (1. - self.qdelta)) ** self.qa

    # observed variable
    self.x = Beta(tf.transpose(alpha), tf.transpose(beta))
def test_hmc_default(self):
    with self.test_session() as sess:
        x = TransformedDistribution(
            distribution=Normal(1.0, 1.0),
            bijector=tf.contrib.distributions.bijectors.Softplus())
        x.support = 'nonnegative'

        inference = ed.HMC([x])
        inference.initialize(auto_transform=True, step_size=0.8)
        tf.global_variables_initializer().run()
        for _ in range(inference.n_iter):
            info_dict = inference.update()
            inference.print_progress(info_dict)

        # Check approximation on constrained space has same moments as
        # target distribution.
        n_samples = 10000
        x_unconstrained = inference.transformations[x]
        qx = inference.latent_vars[x_unconstrained]
        qx_constrained = Empirical(x_unconstrained.bijector.inverse(qx.params))
        x_mean, x_var = tf.nn.moments(x.sample(n_samples), 0)
        qx_mean, qx_var = tf.nn.moments(qx_constrained.params[500:], 0)
        stats = sess.run([x_mean, qx_mean, x_var, qx_var])
        self.assertAllClose(stats[0], stats[1], rtol=1e-1, atol=1e-1)
        self.assertAllClose(stats[2], stats[3], rtol=1e-1, atol=1e-1)
def test_auto_transform_true(self):
    with self.test_session() as sess:
        # Match normal || softplus-inverse-normal distribution with
        # automated transformation on latter (assuming it is softplus).
        x = TransformedDistribution(
            distribution=Normal(0.0, 0.5),
            bijector=tf.contrib.distributions.bijectors.Softplus())
        x.support = 'nonnegative'
        qx = Normal(loc=tf.Variable(tf.random_normal([])),
                    scale=tf.nn.softplus(tf.Variable(tf.random_normal([]))))

        inference = ed.KLqp({x: qx})
        inference.initialize(auto_transform=True, n_samples=5, n_iter=1000)
        tf.global_variables_initializer().run()
        for _ in range(inference.n_iter):
            info_dict = inference.update()

        # Check approximation on constrained space has same moments as
        # target distribution.
        n_samples = 10000
        x_mean, x_var = tf.nn.moments(x.sample(n_samples), 0)
        x_unconstrained = inference.transformations[x]
        qx_constrained = transform(qx, bijectors.Invert(x_unconstrained.bijector))
        qx_mean, qx_var = tf.nn.moments(qx_constrained.sample(n_samples), 0)
        stats = sess.run([x_mean, qx_mean, x_var, qx_var])
        self.assertAllClose(info_dict['loss'], 0.0, rtol=0.2, atol=0.2)
        self.assertAllClose(stats[0], stats[1], rtol=1e-1, atol=1e-1)
        self.assertAllClose(stats[2], stats[3], rtol=1e-1, atol=1e-1)
def _test(base_dist_cls, transform, inverse, log_det_jacobian, n, **base_dist_args):
    x = TransformedDistribution(base_dist_cls=base_dist_cls,
                                transform=transform,
                                inverse=inverse,
                                log_det_jacobian=log_det_jacobian,
                                **base_dist_args)
    val_est = get_dims(x.sample(n))
    val_true = n + get_dims(base_dist_args['mu'])
    assert val_est == val_true
def test_auto_transform_false(self):
    with self.test_session():
        # Match normal || softplus-inverse-normal distribution without
        # automated transformation; it should fail.
        x = TransformedDistribution(
            distribution=Normal(0.0, 0.5),
            bijector=tf.contrib.distributions.bijectors.Softplus())
        x.support = 'nonnegative'
        qx = Normal(loc=tf.Variable(tf.random_normal([])),
                    scale=tf.nn.softplus(tf.Variable(tf.random_normal([]))))

        inference = ed.KLqp({x: qx})
        inference.initialize(auto_transform=False, n_samples=5, n_iter=150)
        tf.global_variables_initializer().run()
        for _ in range(inference.n_iter):
            info_dict = inference.update()

        self.assertAllEqual(info_dict['loss'], np.nan)
def lognormal_q(shape):
    min_scale = 1e-5
    loc_init = tf.random_normal(shape)
    scale_init = 0.1 * tf.random_normal(shape)
    rv = TransformedDistribution(
        distribution=Normal(
            tf.Variable(loc_init),
            tf.maximum(tf.nn.softplus(tf.Variable(scale_init)), min_scale)),
        bijector=tf.contrib.distributions.bijectors.Exp())
    return rv
def lognormal_q(shape, name=None):
    with tf.variable_scope(name, default_name="lognormal_q"):
        min_scale = 1e-5
        loc = tf.get_variable("loc", shape)
        scale = tf.get_variable(
            "scale", shape, initializer=tf.random_normal_initializer(stddev=0.1))
        rv = TransformedDistribution(
            distribution=Normal(loc, tf.maximum(tf.nn.softplus(scale), min_scale)),
            bijector=tf.contrib.distributions.bijectors.Exp())
        return rv
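# Usage sketch for lognormal_q (this block is an illustration, not taken from the snippets
# above): the Exp-transformed normal it returns is a log-normal factor, so it can serve as the
# variational approximation for any positive-support latent variable in edward's KLqp. The
# prior `z`, the toy likelihood `x`, and the data below are made-up assumptions.
import edward as ed
import tensorflow as tf
from edward.models import Gamma, Poisson

z = Gamma(2.0 * tf.ones([10, 3]), 1.0 * tf.ones([10, 3]))    # positive-support prior
qz = lognormal_q(z.shape)                                     # log-normal approximation of z
x = Poisson(rate=tf.reduce_sum(z, axis=1, keep_dims=True))    # toy count likelihood
x_data = tf.ones([10, 1])                                     # fake observed counts

inference = ed.KLqp({z: qz}, data={x: x_data})
inference.run(n_iter=200)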
def test_hmc_default(self):
    with self.test_session() as sess:
        x = TransformedDistribution(
            distribution=Normal(1.0, 1.0),
            bijector=tf.contrib.distributions.bijectors.Softplus())
        x.support = 'nonnegative'

        inference = ed.HMC([x])
        inference.initialize(auto_transform=True, step_size=0.8)
        tf.global_variables_initializer().run()
        for _ in range(inference.n_iter):
            info_dict = inference.update()
            inference.print_progress(info_dict)

        # Check approximation on constrained space has same moments as
        # target distribution.
        n_samples = 1000
        qx_constrained = inference.latent_vars[x]
        x_mean, x_var = tf.nn.moments(x.sample(n_samples), 0)
        qx_mean, qx_var = tf.nn.moments(qx_constrained.params[500:], 0)
        stats = sess.run([x_mean, qx_mean, x_var, qx_var])
        self.assertAllClose(stats[0], stats[1], rtol=1e-1, atol=1e-1)
        self.assertAllClose(stats[2], stats[3], rtol=1e-1, atol=1e-1)
def _initialize_output_model(self):
    with self.sess.as_default():
        with self.sess.graph.as_default():
            if self.output_scale is None:
                output_scale = self.decoder_scale
            else:
                output_scale = self.output_scale

            if self.mv:
                self.out = MultivariateNormalTriL(
                    loc=tf.layers.Flatten()(self.decoder),
                    scale_tril=output_scale)
            else:
                self.out = Normal(
                    loc=self.decoder,
                    scale=output_scale)

            if self.normalize_data and self.constrain_output:
                self.out = TransformedDistribution(
                    self.out,
                    bijector=tf.contrib.distributions.bijectors.Sigmoid())
def test_hmc_custom(self):
    with self.test_session() as sess:
        x = TransformedDistribution(
            distribution=Normal(1.0, 1.0),
            bijector=tf.contrib.distributions.bijectors.Softplus())
        x.support = 'nonnegative'
        qx = Empirical(tf.Variable(tf.random_normal([1000])))

        inference = ed.HMC({x: qx})
        inference.initialize(auto_transform=True, step_size=0.8)
        tf.global_variables_initializer().run()
        for _ in range(inference.n_iter):
            info_dict = inference.update()

        # Check approximation on constrained space has same moments as
        # target distribution.
        n_samples = 10000
        x_unconstrained = inference.transformations[x]
        qx_constrained_params = x_unconstrained.bijector.inverse(qx.params)
        x_mean, x_var = tf.nn.moments(x.sample(n_samples), 0)
        qx_mean, qx_var = tf.nn.moments(qx_constrained_params[500:], 0)
        stats = sess.run([x_mean, qx_mean, x_var, qx_var])
        self.assertAllClose(stats[0], stats[1], rtol=1e-1, atol=1e-1)
        self.assertAllClose(stats[2], stats[3], rtol=1e-1, atol=1e-1)
def construct_model():
    nku = len(Ku)
    nkv = len(Kv)

    obs = tf.placeholder(tf.float32, R_.shape)

    Ug = TransformedDistribution(
        distribution=Normal(tf.zeros([nku]), tf.ones([nku])),
        bijector=tf.contrib.distributions.bijectors.Exp())
    Vg = TransformedDistribution(
        distribution=Normal(tf.zeros([nkv]), tf.ones([nkv])),
        bijector=tf.contrib.distributions.bijectors.Exp())
    Ua = TransformedDistribution(
        distribution=Normal(tf.zeros([1]), tf.ones([1])),
        bijector=tf.contrib.distributions.bijectors.Exp())
    Va = TransformedDistribution(
        distribution=Normal(tf.zeros([1]), tf.ones([1])),
        bijector=tf.contrib.distributions.bijectors.Exp())

    cKu = tf.cholesky(Ku + tf.eye(I) / Ua)  # TODO: rank 1 chol update
    cKv = tf.cholesky(Kv + tf.eye(J) / Va)
    Uw1 = MultivariateNormalTriL(
        tf.zeros([L, I]),
        tf.reduce_sum(cKu / tf.reshape(tf.sqrt(Ug), [nku, 1, 1]), axis=0))
    Vw1 = MultivariateNormalTriL(
        tf.zeros([L, J]),
        tf.reduce_sum(cKv / tf.reshape(tf.sqrt(Vg), [nkv, 1, 1]), axis=0))
    logits = nn(Uw1, Vw1)
    R = AugmentedBernoulli(logits=logits, c=c, obs=obs,
                           value=tf.cast(logits > 0, tf.int32))

    qUg = TransformedDistribution(
        distribution=NormalWithSoftplusScale(tf.Variable(tf.zeros([nku])),
                                             tf.Variable(tf.ones([nku]))),
        bijector=tf.contrib.distributions.bijectors.Exp())
    qVg = TransformedDistribution(
        distribution=NormalWithSoftplusScale(tf.Variable(tf.zeros([nkv])),
                                             tf.Variable(tf.ones([nkv]))),
        bijector=tf.contrib.distributions.bijectors.Exp())
    qUa = TransformedDistribution(
        distribution=NormalWithSoftplusScale(tf.Variable(tf.zeros([1])),
                                             tf.Variable(tf.ones([1]))),
        bijector=tf.contrib.distributions.bijectors.Exp())
    qVa = TransformedDistribution(
        distribution=NormalWithSoftplusScale(tf.Variable(tf.zeros([1])),
                                             tf.Variable(tf.ones([1]))),
        bijector=tf.contrib.distributions.bijectors.Exp())
    qUw1 = MultivariateNormalTriL(tf.Variable(tf.zeros([L, I])), tf.Variable(tf.eye(I)))
    qVw1 = MultivariateNormalTriL(tf.Variable(tf.zeros([L, J])), tf.Variable(tf.eye(J)))

    return obs, Ug, Vg, Ua, Va, cKu, cKv, Uw1, Vw1, R, qUg, qVg, qUa, qVa, qUw1, qVw1
def __init__(self, hdims, zdim, xdim, gen_scale=1.):
    x_ph = tf.placeholder(tf.float32, [None, xdim])
    batch_size = tf.shape(x_ph)[0]
    sample_size = tf.placeholder(tf.int32, [])

    # Define the generative network (p(x | z))
    with tf.variable_scope('generative', reuse=tf.AUTO_REUSE):
        z = Normal(loc=tf.zeros([batch_size, zdim]),
                   scale=tf.ones([batch_size, zdim]))
        hidden = tf.layers.dense(z, hdims[0], activation=tf.nn.relu, name="dense1")
        loc = tf.layers.dense(hidden, xdim, name="dense2")
        x_gen = TransformedDistribution(
            distribution=tfd.Normal(loc=loc, scale=gen_scale),
            bijector=tfd.bijectors.Exp(),
            name="LogNormalTransformedDistribution")
        # x_gen = Bernoulli(logits=loc)

    # Define the inference network (q(z | x))
    with tf.variable_scope('inference', reuse=tf.AUTO_REUSE):
        hidden = tf.layers.dense(x_ph, hdims[0], activation=tf.nn.relu)
        qloc = tf.layers.dense(hidden, zdim)
        qscale = tf.layers.dense(hidden, zdim, activation=tf.nn.softplus)
        qz = Normal(loc=qloc, scale=qscale)
        qz_sample = qz.sample(sample_size)

    # Define the generative network using posterior samples from q(z | x)
    with tf.variable_scope('generative'):
        qz_sample = tf.reshape(qz_sample, [-1, zdim])
        hidden = tf.layers.dense(qz_sample, hdims[0], activation=tf.nn.relu,
                                 reuse=True, name="dense1")
        loc = tf.layers.dense(hidden, xdim, reuse=True, name="dense2")
        x_gen_post = tf.exp(loc)

    self.x_ph = x_ph
    self.x_data = self.x_ph
    self.batch_size = batch_size
    self.sample_size = sample_size
    self.ops = {
        'generative': x_gen,
        'inference': qz_sample,
        'generative_post': x_gen_post
    }

    self.kl_coef = tf.placeholder(tf.float32, ())
    with tf.variable_scope('inference', reuse=tf.AUTO_REUSE):
        self.inference = ed.KLqp({z: qz}, data={x_gen: self.x_data})
        self.lr = tf.placeholder(tf.float32, shape=())
        optimizer = tf.train.RMSPropOptimizer(self.lr, epsilon=0.9)
        self.inference.initialize(
            optimizer=optimizer,
            n_samples=10,
            kl_scaling={z: self.kl_coef})

    # Build elbo loss to evaluate on validation data
    self.eval_loss, _ = self.inference.build_loss_and_gradients([])
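# Hypothetical training-loop sketch for the VAE constructor above. The class name
# `LogNormalVAE`, the toy data, and the annealing schedule are assumptions; only `x_ph`,
# `lr`, `kl_coef`, and `inference` come from the snippet itself (tf/ed imports are assumed
# from the surrounding file).
import numpy as np

model = LogNormalVAE(hdims=[128], zdim=10, xdim=50)
x_batch = np.random.rand(32, 50).astype('float32') + 1e-3  # strictly positive toy data

tf.global_variables_initializer().run(session=ed.get_session())
for step in range(200):
    info = model.inference.update(feed_dict={
        model.x_ph: x_batch,                    # minibatch of observations
        model.lr: 1e-3,                         # RMSProp learning rate
        model.kl_coef: min(1.0, step / 100.0)   # simple KL annealing
    })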
class DNNSegBayes(DNNSeg): _INITIALIZATION_KWARGS = UNSUPERVISED_WORD_CLASSIFIER_BAYES_INITIALIZATION_KWARGS _doc_header = """ Bayesian implementation of unsupervised word classifier. """ _doc_args = DNNSeg._doc_args _doc_kwargs = DNNSeg._doc_kwargs _doc_kwargs += '\n' + '\n'.join([' ' * 8 + ':param %s' % x.key + ': ' + '; '.join( [x.dtypes_str(), x.descr]) + ' **Default**: ``%s``.' % (x.default_value if not isinstance(x.default_value, str) else "'%s'" % x.default_value) for x in _INITIALIZATION_KWARGS]) __doc__ = _doc_header + _doc_args + _doc_kwargs def __init__(self, k, train_data, **kwargs): super(DNNSegBayes, self).__init__( k, train_data, **kwargs ) for kwarg in DNNSegBayes._INITIALIZATION_KWARGS: setattr(self, kwarg.key, kwargs.pop(kwarg.key, kwarg.default_value)) kwarg_keys = [x.key for x in DNNSeg._INITIALIZATION_KWARGS] for kwarg_key in kwargs: if kwarg_key not in kwarg_keys: raise TypeError('__init__() got an unexpected keyword argument %s' % kwarg_key) assert self.declare_priors or self.relaxed, 'Priors must be explicitly declared unless relaxed==True' self._initialize_metadata() def _initialize_metadata(self): super(DNNSegBayes, self)._initialize_metadata() self.inference_map = {} def _pack_metadata(self): md = super(DNNSegBayes, self)._pack_metadata() for kwarg in DNNSegBayes._INITIALIZATION_KWARGS: md[kwarg.key] = getattr(self, kwarg.key) return md def _unpack_metadata(self, md): super(DNNSegBayes, self)._unpack_metadata(md) for kwarg in DNNSegBayes._INITIALIZATION_KWARGS: setattr(self, kwarg.key, md.pop(kwarg.key, kwarg.default_value)) if len(md) > 0: sys.stderr.write( 'Saved model contained unrecognized attributes %s which are being ignored\n' % sorted(list(md.keys()))) def _initialize_classifier(self): with self.sess.as_default(): with self.sess.graph.as_default(): if self.trainable_temp: temp = tf.Variable(self.temp, name='temp') else: temp = self.temp if self.binary_classifier: if self.relaxed: self.encoding_q = RelaxedBernoulli(temp, logits = self.encoder[:,:self.k]) else: self.encoding_q = Bernoulli(logits = self.encoder[:,:self.k], dtype=self.FLOAT_TF) if self.declare_priors: if self.relaxed: self.encoding = RelaxedBernoulli(temp, probs =tf.ones([tf.shape(self.y_bwd)[0], self.k]) * 0.5) else: self.encoding = Bernoulli(probs =tf.ones([tf.shape(self.y_bwd)[0], self.k]) * 0.5, dtype=self.FLOAT_TF) if self.k: self.inference_map[self.encoding] = self.encoding_q else: self.encoding = self.encoding_q else: if self.relaxed: self.encoding_q = RelaxedOneHotCategorical(temp, logits = self.encoder[:,:self.k]) else: self.encoding_q = OneHotCategorical(logits = self.encoder[:,:self.k], dtype=self.FLOAT_TF) if self.declare_priors: if self.relaxed: self.encoding = RelaxedOneHotCategorical(temp, probs =tf.ones([tf.shape(self.y_bwd)[0], self.k]) / self.k) else: self.encoding = OneHotCategorical(probs =tf.ones([tf.shape(self.y_bwd)[0], self.k]) / self.k, dtype=self.FLOAT_TF) if self.k: self.inference_map[self.encoding] = self.encoding_q else: self.encoding = self.encoding_q def _initialize_decoder_scale(self): with self.sess.as_default(): with self.sess.graph.as_default(): dim = self.n_timesteps_output_bwd * self.frame_dim if self.decoder_type == 'rnn': decoder_scale = tf.nn.softplus( tf.keras.layers.LSTM( self.frame_dim, recurrent_activation='sigmoid', return_sequences=True, unroll=self.unroll )(self.decoder_in) ) elif self.decoder_type == 'cnn': assert self.n_timesteps_output_bwd is not None, 'n_timesteps_output must be defined when decoder_type == "cnn"' decoder_scale = 
tf.layers.dense(self.decoder_in, self.n_timesteps * self.frame_dim)[..., None] decoder_scale = tf.reshape(decoder_scale, (self.batch_len, self.n_timesteps_output_bwd, self.frame_dim, 1)) decoder_scale = tf.keras.layers.Conv2D(self.conv_n_filters, self.conv_kernel_size, padding='same', activation='elu')(decoder_scale) decoder_scale = tf.keras.layers.Conv2D(1, self.conv_kernel_size, padding='same', activation='linear')(decoder_scale) decoder_scale = tf.squeeze(decoder_scale, axis=-1) elif self.decoder_type in ['dense', 'dense_resnet']: n_classes = int(2 ** self.k) if self.binary_classifier else int(self.k) # First layer decoder_scale = DenseLayer( self.decoder_in, self.n_timesteps_output_bwd * self.frame_dim, self.training, activation=tf.nn.elu, batch_normalize=self.batch_normalize, session=self.sess ) # Intermediate layers if self.decoder_type == 'dense': for i in range(1, self.n_layers_decoder - 1): decoder_scale = DenseLayer( decoder_scale, self.n_timesteps_output_bwd * self.frame_dim, self.training, activation=tf.nn.elu, batch_normalize=self.batch_normalize, session=self.sess ) else: # self.decoder_type = 'dense_resnet' for i in range(1, self.n_layers_decoder - 1): decoder_scale = DenseResidualLayer( decoder_scale, self.training, units=self.n_timesteps_output_bwd * self.frame_dim, layers_inner=self.resnet_n_layers_inner, activation_inner=tf.nn.elu, activation=None, batch_normalize=self.batch_normalize, session=self.sess ) # Last layer if self.n_layers_decoder > 1: decoder_scale = DenseLayer( decoder_scale, self.n_timesteps_output_bwd * self.frame_dim, self.training, activation=None, batch_normalize=False, session=self.sess ) # Reshape decoder_scale = tf.reshape(decoder_scale, (self.batch_len, self.n_timesteps_output_bwd, self.frame_dim)) else: raise ValueError('Decoder type "%s" not supported at this time' %self.decoder_type) self.decoder_scale = tf.nn.softplus(decoder_scale) + self.epsilon # Override this method to include scale params for output distribution def _initialize_decoder(self): super(DNNSegBayes, self)._initialize_decoder() self._initialize_decoder_scale() def _initialize_output_model(self): with self.sess.as_default(): with self.sess.graph.as_default(): if self.output_scale is None: output_scale = self.decoder_scale else: output_scale = self.output_scale if self.mv: self.out = MultivariateNormalTriL( loc=tf.layers.Flatten()(self.decoder), scale_tril= output_scale ) else: self.out = Normal( loc=self.decoder, scale = output_scale ) if self.normalize_data and self.constrain_output: self.out = TransformedDistribution( self.out, bijector=tf.contrib.distributions.bijectors.Sigmoid() ) def _initialize_objective(self, n_train): n_train_minibatch = self.n_minibatch(n_train) minibatch_scale = self.minibatch_scale(n_train) with self.sess.as_default(): with self.sess.graph.as_default(): if self.mv: y = tf.layers.Flatten()(self.y_bwd) y_mask = tf.layers.Flatten()(self.y_bwd_mask[..., None] * tf.ones_like(self.y_bwd)) else: y = self.y_bwd y_mask = self.y_bwd_mask # Define access points to important layers if len(self.inference_map) > 0: self.out_post = ed.copy(self.out, self.inference_map) else: self.out_post = self.out if self.mv: self.reconst = tf.reshape(self.out_post * y_mask, [-1, self.n_timesteps_output_bwd, self.frame_dim]) # self.reconst_mean = tf.reshape(self.out_post.mean() * y_mask, [-1, self.n_timesteps_output, self.frame_dim]) else: self.reconst = self.out_post * y_mask[..., None] # self.reconst_mean = self.out_post.mean() * y_mask[..., None] if len(self.inference_map) > 0: 
self.encoding_post = ed.copy(self.encoding, self.inference_map) self.labels_post = ed.copy(self.labels, self.inference_map) self.label_probs_post = ed.copy(self.label_probs, self.inference_map) else: self.encoding_post = self.encoding self.labels_post = self.labels self.label_probs_post = self.label_probs self.llprior = self.out.log_prob(y) self.ll_post = self.out_post.log_prob(y) self.optim = self._initialize_optimizer(self.optim_name) if self.variational(): self.inference = getattr(ed, self.inference_name)(self.inference_map, data={self.out: y}) if self.mask_padding: self.inference.initialize( n_samples=self.n_samples, n_iter=n_train_minibatch * self.n_iter, # n_print=n_train_minibatch * self.log_freq, n_print=0, logdir=self.outdir + '/tensorboard/edward', log_timestamp=False, scale={self.out: y_mask[...,None] * minibatch_scale}, optimizer=self.optim ) else: self.inference.initialize( n_samples=self.n_samples, n_iter=n_train_minibatch * self.n_iter, # n_print=n_train_minibatch * self.log_freq, n_print=0, logdir=self.outdir + '/tensorboard/edward', log_timestamp=False, scale={self.out: minibatch_scale}, optimizer=self.optim ) else: raise ValueError('Only variational inferences are supported at this time') def set_output_scale(self, output_scale, trainable=False): with self.sess.as_default(): with self.sess.graph.as_default(): if trainable: self.output_scale = tf.Variable(output_scale, dtype=self.FLOAT_TF) self.sess.run(tf.variables_initializer(self.output_scale)) else: self.output_scale = tf.constant(output_scale, dtype=self.FLOAT_TF) def run_train_step( self, feed_dict, return_loss=True, return_reconstructions=False, return_labels=False, return_label_probs=False, return_encoding_entropy=False, return_segmentation_probs=False ): info_dict = self.inference.update(feed_dict) out_dict = {} if return_loss: out_dict['loss'] = info_dict['loss'] if return_reconstructions or return_labels or return_label_probs: to_run = [] to_run_names = [] if return_reconstructions: to_run.append(self.out_post) to_run_names.append('reconst') if return_labels: to_run.append(self.labels_post) to_run_names.append('labels') if return_label_probs: to_run.append(self.label_probs_post) to_run_names.append('label_probs') if return_encoding_entropy: to_run.append(self.encoding_entropy_mean) to_run_names.append('encoding_entropy') if self.encoder_type.lower() == 'softhmlstm' and return_segmentation_probs: to_run.append(self.segmentation_probs) to_run_names.append('segmentation_probs') output = self.sess.run(to_run, feed_dict=feed_dict) for i, x in enumerate(output): out_dict[to_run_names[i]] = x return out_dict def variational(self): """ Check whether the DTSR model uses variational Bayes. :return: ``True`` if the model is variational, ``False`` otherwise. """ return self.inference_name in [ 'KLpq', 'KLqp', 'ImplicitKLqp', 'ReparameterizationEntropyKLqp', 'ReparameterizationKLKLqp', 'ReparameterizationKLqp', 'ScoreEntropyKLqp', 'ScoreKLKLqp', 'ScoreKLqp', 'ScoreRBKLqp', 'WakeSleep' ] def report_settings(self, indent=0): out = super(DNNSegBayes, self).report_settings(indent=indent) for kwarg in UNSUPERVISED_WORD_CLASSIFIER_BAYES_INITIALIZATION_KWARGS: val = getattr(self, kwarg.key) out += ' ' * indent + ' %s: %s\n' %(kwarg.key, "\"%s\"" %val if isinstance(val, str) else val) out += '\n' return out
# TODO: beta is actually the expm of C: correct this!
import edward as ed
import tensorflow as tf
from edward.models import Normal, InverseGamma, PointMass, Uniform, TransformedDistribution

# simulate data
d = 50
T = 300
X, C, S = MOU_sim(N=d, Sigma=None, mu=0, T=T, connectivity_strength=8.)

# the model
mu = tf.constant(0.)  # Normal(loc=tf.zeros([d]), scale=1.*tf.ones([d]))
beta = Normal(loc=tf.ones([d, d]), scale=2. * tf.ones([d, d]))
ds = tf.contrib.distributions
C = TransformedDistribution(distribution=beta,
                            bijector=ds.bijectors.Exp(),
                            name="LogNormalTransformedDistribution")
noise_proc = InverseGamma(concentration=tf.ones([d]), rate=tf.ones([d]))  # tf.constant(0.1)
noise_obs = tf.constant(0.1)  # InverseGamma(alpha=1.0, beta=1.0)

x = [0] * T
x[0] = Normal(loc=mu, scale=10. * tf.ones([d]))
for n in range(1, T):
    x[n] = Normal(loc=mu + tf.tensordot(C, x[n - 1], axes=[[1], [0]]),
                  scale=noise_proc * tf.ones([d]))

## map inference
# print("setting up distributions")
# qmu = PointMass(params=tf.Variable(tf.zeros([d])))
# qbeta = PointMass(params=tf.Variable(tf.zeros([d,d])))
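# A sketch (not part of the original script) of how the commented-out MAP inference above
# could be completed. It assumes MOU_sim returns X with shape [T, d], and it constrains the
# point estimate of the positive-support process noise via softplus.
qbeta = PointMass(params=tf.Variable(tf.zeros([d, d])))
qnoise_proc = PointMass(params=tf.nn.softplus(tf.Variable(tf.zeros([d]))))

inference = ed.MAP({beta: qbeta, noise_proc: qnoise_proc},
                   data={x[n]: tf.cast(X[n], tf.float32) for n in range(T)})
inference.run(n_iter=500)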
class ZINBayes(BaseEstimator, TransformerMixin):
    def __init__(self, n_components=10, n_mc_samples=1, gene_dispersion=True,
                 zero_inflation=True, scalings=True, batch_correction=False,
                 test_iterations=100, optimizer=None, minibatch_size=None,
                 validation=False, X_test=None):
        self.n_components = n_components
        self.est_X = None
        self.est_L = None
        self.est_Z = None

        self.zero_inflation = zero_inflation
        if zero_inflation:
            print('Considering zero-inflation.')

        self.batch_correction = batch_correction
        if batch_correction:
            print('Performing batch correction.')

        self.scalings = scalings
        if scalings:
            print('Considering cell-specific scalings.')

        self.gene_dispersion = gene_dispersion
        if gene_dispersion:
            print('Considering gene-specific dispersion.')

        self.n_mc_samples = n_mc_samples
        self.test_iterations = test_iterations
        self.optimizer = optimizer
        self.minibatch_size = minibatch_size

        # if validation, use X_test to assess convergence
        self.validation = validation and X_test is not None
        self.X_test = X_test

        self.loss_dict = {'t_loss': [], 'v_loss': []}

        sess = ed.get_session()
        sess.close()
        tf.reset_default_graph()

    def close_session(self):
        return self.sess.close()

    def fit(self, X, batch_idx=None, max_iter=100, max_time=60):
        tf.reset_default_graph()

        # Data size
        N = X.shape[0]
        P = X.shape[1]

        if not self.batch_correction:
            batch_idx = None

        # Number of experimental batches
        if batch_idx is not None:
            self.n_batches = np.unique(batch_idx[:, 0]).size
        else:
            self.n_batches = 0

        # Prior for cell scalings
        log_library_size = np.log(np.sum(X, axis=1))
        self.mean_llib, self.std_llib = np.mean(log_library_size), np.std(log_library_size)

        if self.minibatch_size is not None:
            # Create ZINBayes computation graph
            self.define_stochastic_model(P, self.n_components)
            inference_local, inference_global = self.define_stochastic_inference(
                N, P, self.n_components)
            self.run_stochastic_inference(X, inference_local, inference_global,
                                          n_iterations=max_iter)
        else:
            # Create ZINBayes computation graph
            self.define_model(N, P, self.n_components, batch_idx=batch_idx)
            self.inference = self.define_inference(X)

            # If we want to assess convergence during inference on held-out data
            inference_val = None
            if self.validation and self.X_test is not None:
                self.define_val_model(self.X_test.shape[0], P, self.n_components)
                inference_val = self.define_val_inference(self.X_test)

            # Run inference
            self.loss = self.run_inference(self.inference, inference_val=inference_val,
                                           n_iterations=max_iter)

        # Get estimated variational distributions of global latent variables
        self.est_qW0 = TransformedDistribution(
            distribution=Normal(self.qW0.distribution.loc.eval(),
                                self.qW0.distribution.scale.eval()),
            bijector=tf.contrib.distributions.bijectors.Exp())
        self.est_qr = TransformedDistribution(
            distribution=Normal(self.qr.distribution.loc.eval(),
                                self.qr.distribution.scale.eval()),
            bijector=tf.contrib.distributions.bijectors.Exp())
        if self.zero_inflation:
            self.est_qW1 = Normal(self.qW1.loc.eval(), self.qW1.scale.eval())

    def transform(self):
        if self.minibatch_size is None:
            self.est_Z = self.sess.run(tf.exp(self.qz.distribution.loc))
            if self.scalings:
                self.est_L = self.sess.run(tf.exp(self.ql.distribution.loc))
            self.est_X = self.posterior_nb_mean()
        return self.est_Z

    def fit_transform(self, X, batch_idx=None, max_iter=100, max_time=60):
        self.fit(X, batch_idx=batch_idx, max_iter=max_iter, max_time=max_time)
        return self.transform()

    def get_est_X(self):
        return self.est_X

    def get_est_l(self):
        return self.est_L

    def score(self, X, batch_idx=None):
        return self.evaluate_loglikelihood(X, batch_idx=batch_idx)

    # def define_model(self, N, P, K, batch_idx=None):
    #     self.W0 = Gamma(.1 * tf.ones([K + self.n_batches, P]), .3 * tf.ones([K + self.n_batches, P]))
    #     self.z = Gamma(16. * tf.ones([N, K]), 4. * tf.ones([N, K]))
    #     self.a = Gamma(2. * tf.ones([1, P]), 1. * tf.ones([1, P]))
    #     self.r = Gamma(self.a, self.a)
    #     self.l = Gamma(self.mean_llib**2 / self.std_llib**2 * tf.ones([N, 1]),
    #                    self.mean_llib / self.std_llib**2 * tf.ones([N, 1]))
    #     rho = tf.matmul(self.z, self.W0)
    #     self.likelihood = Poisson(rate=self.r * rho)

    def define_model(self, N, P, K, batch_idx=None):
        self.W0 = Gamma(.1 * tf.ones([K + self.n_batches, P]),
                        .3 * tf.ones([K + self.n_batches, P]))
        if self.zero_inflation:
            self.W1 = Normal(tf.zeros([K + self.n_batches, P]),
                             tf.ones([K + self.n_batches, P]))

        self.z = Gamma(2. * tf.ones([N, K]), 1. * tf.ones([N, K]))

        disp_size = 1
        if self.gene_dispersion:
            disp_size = P
        self.r = Gamma(2. * tf.ones([disp_size]), 1. * tf.ones([disp_size]))

        self.l = TransformedDistribution(
            distribution=Normal(self.mean_llib * tf.ones([N, 1]),
                                np.sqrt(self.std_llib) * tf.ones([N, 1])),
            bijector=tf.contrib.distributions.bijectors.Exp())

        if batch_idx is not None and self.n_batches > 0:
            self.rho = tf.matmul(
                tf.concat([
                    self.z,
                    tf.cast(tf.one_hot(batch_idx[:, 0], self.n_batches), tf.float32)
                ], axis=1), self.W0)
        else:
            self.rho = tf.matmul(self.z, self.W0)

        if self.scalings:
            self.rho = self.rho / tf.reshape(tf.reduce_sum(self.rho, axis=1), (-1, 1))  # NxP
            self.lam = Gamma(self.r, self.r / (self.rho * self.l))
        else:
            self.lam = Gamma(self.r, self.r / self.rho)

        if self.zero_inflation:
            if batch_idx is not None and self.n_batches > 0:
                self.logit_pi = tf.matmul(
                    tf.concat([
                        self.z,
                        tf.cast(tf.one_hot(batch_idx[:, 0], self.n_batches), tf.float32)
                    ], axis=1), self.W1)
            else:
                self.logit_pi = tf.matmul(self.z, self.W1)

            self.pi = tf.minimum(tf.maximum(tf.nn.sigmoid(self.logit_pi), 1e-7), 1. - 1e-7)
            self.cat = Categorical(probs=tf.stack([self.pi, 1. - self.pi], axis=2))
            self.components = [
                Poisson(rate=1e-30 * tf.ones([N, P])),
                Poisson(rate=self.lam)
            ]
            self.likelihood = Mixture(cat=self.cat, components=self.components)
        else:
            self.likelihood = Poisson(rate=self.lam)

    # def define_inference(self, X):
    #     # Local latent variables
    #     # self.qz = lognormal_q(self.z.shape)
    #     self.qz = TransformedDistribution(
    #         distribution=Normal(tf.Variable(tf.ones(self.z.shape)),
    #                             tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.z.shape)))),
    #         bijector=tf.contrib.distributions.bijectors.Exp())
    #     self.ql = TransformedDistribution(
    #         distribution=Normal(tf.Variable(self.mean_llib * tf.ones(self.l.shape)),
    #                             tf.nn.softplus(tf.Variable(self.std_llib * tf.ones(self.l.shape)))),
    #         bijector=tf.contrib.distributions.bijectors.Exp())
    #     self.qr = TransformedDistribution(
    #         distribution=Normal(tf.Variable(tf.ones(self.r.shape)),
    #                             tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.r.shape)))),
    #         bijector=tf.contrib.distributions.bijectors.Exp())
    #     self.qa = TransformedDistribution(
    #         distribution=Normal(tf.Variable(tf.ones(self.a.shape)),
    #                             tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.a.shape)))),
    #         bijector=tf.contrib.distributions.bijectors.Exp())
    #     self.qW0 = TransformedDistribution(
    #         distribution=Normal(tf.Variable(tf.ones(self.W0.shape)),
    #                             tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.W0.shape)))),
    #         bijector=tf.contrib.distributions.bijectors.Exp())
    #     latent_vars_dict = {self.z: self.qz, self.a: self.qa, self.r: self.qr, self.W0: self.qW0}
    #     inference = ed.ReparameterizationKLqp(latent_vars_dict,
    #                                           data={self.likelihood: tf.cast(X, tf.float32)})
    #     return inference

    def define_inference(self, X):
        # Local latent variables
        # self.qz = lognormal_q(self.z.shape)
        self.qz = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.z.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.z.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        # self.qlam = lognormal_q(self.lam.shape)
        self.qlam = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.lam.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.lam.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        # Global latent variables
        # self.qr = lognormal_q(self.r.shape)
        self.qr = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.r.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.r.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        # self.qW0 = lognormal_q(self.W0.shape)
        self.qW0 = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.W0.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.W0.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        latent_vars_dict = {
            self.z: self.qz,
            self.lam: self.qlam,
            self.r: self.qr,
            self.W0: self.qW0
        }

        if self.zero_inflation:
            self.qW1 = Normal(
                tf.Variable(tf.zeros(self.W1.shape)),
                tf.nn.softplus(tf.Variable(0.1 * tf.ones(self.W1.shape))))
            latent_vars_dict[self.W1] = self.qW1

        if self.scalings:
            # self.ql = lognormal_q(self.l.shape)
            self.ql = TransformedDistribution(
                distribution=Normal(
                    tf.Variable(self.mean_llib * tf.ones(self.l.shape)),
                    tf.nn.softplus(tf.Variable(self.std_llib * tf.ones(self.l.shape)))),
                bijector=tf.contrib.distributions.bijectors.Exp())
            latent_vars_dict[self.l] = self.ql

        inference = ed.ReparameterizationKLqp(
            latent_vars_dict, data={self.likelihood: tf.cast(X, tf.float32)})
        return inference

    def define_val_model(self, N, P, K):
        # Define new graph
        self.z_test = Gamma(2. * tf.ones([N, K]), 1. * tf.ones([N, K]))
        self.l_test = TransformedDistribution(
            distribution=Normal(self.mean_llib * tf.ones([N, 1]),
                                np.sqrt(self.std_llib) * tf.ones([N, 1])),
            bijector=tf.contrib.distributions.bijectors.Exp())

        rho_test = tf.matmul(self.z_test, self.W0)
        rho_test = rho_test / tf.reshape(tf.reduce_sum(rho_test, axis=1), (-1, 1))  # NxP

        self.lam_test = Gamma(self.r, self.r / (rho_test * self.l_test))

        if self.zero_inflation:
            logit_pi_test = tf.matmul(self.z_test, self.W1)
            pi_test = tf.minimum(tf.maximum(tf.nn.sigmoid(logit_pi_test), 1e-7), 1. - 1e-7)
            cat_test = Categorical(probs=tf.stack([pi_test, 1. - pi_test], axis=2))
            components_test = [
                Poisson(rate=1e-30 * tf.ones([N, P])),
                Poisson(rate=self.lam_test)
            ]
            self.likelihood_test = Mixture(cat=cat_test, components=components_test)
        else:
            self.likelihood_test = Poisson(rate=self.lam_test)

    def define_val_inference(self, X):
        self.qz_test = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.z_test.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.z_test.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        # self.qlam = lognormal_q(self.lam.shape)
        self.qlam_test = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.lam_test.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.lam_test.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        # self.ql = lognormal_q(self.l.shape)
        self.ql_test = TransformedDistribution(
            distribution=Normal(
                tf.Variable(self.mean_llib * tf.ones(self.l_test.shape)),
                tf.nn.softplus(tf.Variable(np.sqrt(self.std_llib) * tf.ones(self.l_test.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        if self.zero_inflation:
            inference = ed.ReparameterizationKLqp(
                {
                    self.z_test: self.qz_test,
                    self.lam_test: self.qlam_test,
                    self.l_test: self.ql_test
                },
                data={
                    self.likelihood_test: tf.cast(X, tf.float32),
                    self.W0: self.qW0,
                    self.W1: self.qW1,
                    self.r: self.qr
                })
        else:
            inference = ed.ReparameterizationKLqp(
                {
                    self.z_test: self.qz_test,
                    self.lam_test: self.qlam_test,
                    self.l_test: self.ql_test
                },
                data={
                    self.likelihood_test: tf.cast(X, tf.float32),
                    self.W0: self.qW0,
                    self.r: self.qr
                })
        return inference

    # def run_inference(self, inference, inference_val=None, n_iterations=1000):
    #     N = self.l.shape.as_list()[0]
    #     self.inference_e.initialize(n_iter=n_iterations, n_samples=self.n_mc_samples,
    #                                 optimizer=self.optimizer)
    #     self.inference_m.initialize(optimizer=self.optimizer)
    #     self.sess = ed.get_session()
    #     tf.global_variables_initializer().run()
    #     for i in range(self.inference_e.n_iter):
    #         info_dict = self.inference_e.update()
    #         self.inference_m.update()
    #         info_dict['loss'] = info_dict['loss'] / N
    #         self.inference_e.print_progress(info_dict)
    #         self.loss_dict['t_loss'].append(info_dict["loss"])
    #     self.inference_e.finalize()
    #     self.inference_m.finalize()

    def run_inference(self, inference, inference_val=None, n_iterations=1000):
        N = self.l.shape.as_list()[0]
        inference.initialize(n_iter=n_iterations, n_samples=self.n_mc_samples,
                             optimizer=self.optimizer)
        if inference_val is not None:
            N_val = self.l_test.shape.as_list()[0]
            inference_val.initialize(n_samples=self.n_mc_samples, optimizer=self.optimizer)

        self.sess = ed.get_session()
        tf.global_variables_initializer().run()
        for i in range(inference.n_iter):
            info_dict = inference.update()
            info_dict['loss'] = info_dict['loss'] / N
            inference.print_progress(info_dict)
            self.loss_dict['t_loss'].append(info_dict["loss"])

            if inference_val is not None:
                self.sess.run(inference_val.reset)
                self.sess.run(tf.variables_initializer(self.ql_test.get_variables()))
                self.sess.run(tf.variables_initializer(self.qz_test.get_variables()))
                self.sess.run(tf.variables_initializer(self.qlam_test.get_variables()))
                for _ in range(5):
                    val_info_dict = inference_val.update()
                self.loss_dict['v_loss'].append(val_info_dict['loss'] / N_val)

        inference.finalize()
        if inference_val is not None:
            inference_val.finalize()

    def evaluate_loglikelihood(self, X, batch_idx=None):
        """
        This is the ELBO, which is a lower bound on the marginal log-likelihood.
        We perform some local optimization on the new data points to obtain the
        ELBO of the new data.
        """
        N = X.shape[0]
        P = X.shape[1]
        K = self.n_components

        # Define new graph conditioned on the posterior global factors
        z_test = Gamma(2. * tf.ones([N, K]), 1. * tf.ones([N, K]))
        l_test = TransformedDistribution(
            distribution=Normal(self.mean_llib * tf.ones([N, 1]),
                                np.sqrt(self.std_llib) * tf.ones([N, 1])),
            bijector=tf.contrib.distributions.bijectors.Exp())

        if batch_idx is not None and self.n_batches > 0:
            rho_test = tf.matmul(
                tf.concat([
                    z_test,
                    tf.cast(tf.one_hot(batch_idx[:, 0], self.n_batches), tf.float32)
                ], axis=1), self.W0)
        else:
            rho_test = tf.matmul(z_test, self.W0)

        rho_test = rho_test / tf.reshape(tf.reduce_sum(rho_test, axis=1), (-1, 1))  # NxP

        lam_test = Gamma(self.r, self.r / (rho_test * l_test))

        if self.zero_inflation:
            if batch_idx is not None and self.n_batches > 0:
                logit_pi_test = tf.matmul(
                    tf.concat([
                        z_test,
                        tf.cast(tf.one_hot(batch_idx[:, 0], self.n_batches), tf.float32)
                    ], axis=1), self.W1)
            else:
                logit_pi_test = tf.matmul(z_test, self.W1)

            pi_test = tf.minimum(tf.maximum(tf.nn.sigmoid(logit_pi_test), 1e-7), 1. - 1e-7)
            cat_test = Categorical(probs=tf.stack([pi_test, 1. - pi_test], axis=2))
            components_test = [
                Poisson(rate=1e-30 * tf.ones([N, P])),
                Poisson(rate=lam_test)
            ]
            likelihood_test = Mixture(cat=cat_test, components=components_test)
        else:
            likelihood_test = Poisson(rate=lam_test)

        qz_test = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(z_test.shape)),
                tf.nn.softplus(tf.Variable(1. * tf.ones(z_test.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        qlam_test = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(lam_test.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(lam_test.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        ql_test = TransformedDistribution(
            distribution=Normal(
                tf.Variable(self.mean_llib * tf.ones(l_test.shape)),
                tf.nn.softplus(tf.Variable(np.sqrt(self.std_llib) * tf.ones(l_test.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        if self.zero_inflation:
            inference_local = ed.ReparameterizationKLqp(
                {
                    z_test: qz_test,
                    lam_test: qlam_test,
                    l_test: ql_test
                },
                data={
                    likelihood_test: tf.cast(X, tf.float32),
                    self.W0: self.est_qW0,
                    self.W1: self.est_qW1,
                    self.r: self.est_qr
                })
        else:
            inference_local = ed.ReparameterizationKLqp(
                {
                    z_test: qz_test,
                    lam_test: qlam_test,
                    l_test: ql_test
                },
                data={
                    likelihood_test: tf.cast(X, tf.float32),
                    self.W0: self.est_qW0,
                    self.r: self.est_qr
                })

        inference_local.run(n_iter=self.test_iterations, n_samples=self.n_mc_samples)

        return -self.sess.run(inference_local.loss,
                              feed_dict={likelihood_test: X.astype('float32')}) / N

    def posterior_nb_mean(self):
        est_rho = self.sess.run(
            self.rho,
            feed_dict={
                self.z: np.exp(self.qz.distribution.loc.eval()),
                self.W0: np.exp(self.qW0.distribution.loc.eval())
            })
        est_l = 1
        if self.scalings:
            est_l = np.exp(self.ql.distribution.loc.eval())
        est_mean = est_rho * est_l
        return est_mean

    def define_stochastic_model(self, P, K):
        M = self.minibatch_size

        self.W0 = Gamma(0.1 * tf.ones([K, P]), 0.3 * tf.ones([K, P]))
        if self.zero_inflation:
            self.W1 = Normal(tf.zeros([K, P]), tf.ones([K, P]))

        self.z = Gamma(2. * tf.ones([M, K]), 1. * tf.ones([M, K]))

        self.r = Gamma(2. * tf.ones([P]), 1. * tf.ones([P]))

        self.l = TransformedDistribution(
            distribution=Normal(self.mean_llib * tf.ones([M, 1]),
                                self.std_llib * tf.ones([M, 1])),
            bijector=tf.contrib.distributions.bijectors.Exp())

        self.rho = tf.matmul(self.z, self.W0)
        self.rho = self.rho / tf.reshape(tf.reduce_sum(self.rho, axis=1), (-1, 1))  # NxP

        self.lam = Gamma(self.r, self.r / (self.rho * self.l))

        if self.zero_inflation:
            self.logit_pi = tf.matmul(self.z, self.W1)
            self.pi = tf.minimum(tf.maximum(tf.nn.sigmoid(self.logit_pi), 1e-7), 1. - 1e-7)
            self.cat = Categorical(probs=tf.stack([self.pi, 1. - self.pi], axis=2))
            self.components = [
                Poisson(rate=1e-30 * tf.ones([M, P])),
                Poisson(rate=self.lam)
            ]
            self.likelihood = Mixture(cat=self.cat, components=self.components)
        else:
            self.likelihood = Poisson(rate=self.lam)

    def define_stochastic_inference(self, N, P, K):
        M = self.minibatch_size

        qz_vars = [
            tf.Variable(tf.ones([N, K]), name='qz_loc'),
            tf.Variable(0.01 * tf.ones([N, K]), name='qz_scale')
        ]
        qlam_vars = [
            tf.Variable(tf.ones([N, P]), name='qlam_loc'),
            tf.Variable(0.01 * tf.ones([N, P]), name='qlam_scale')
        ]
        ql_vars = [
            tf.Variable(self.mean_llib * tf.ones([N, 1]), name='ql_loc'),
            tf.Variable(self.std_llib * tf.ones([N, 1]), name='ql_scale')
        ]
        qlocal_vars = [qz_vars, qlam_vars, ql_vars]
        qlocal_vars = [item for sublist in qlocal_vars for item in sublist]

        self.idx_ph = tf.placeholder(tf.int32, M)

        self.qz = TransformedDistribution(
            distribution=Normal(
                tf.gather(qlocal_vars[0], self.idx_ph),
                tf.nn.softplus(tf.gather(qlocal_vars[1], self.idx_ph))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        self.qlam = TransformedDistribution(
            distribution=Normal(
                tf.gather(qlocal_vars[2], self.idx_ph),
                tf.nn.softplus(tf.gather(qlocal_vars[3], self.idx_ph))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        self.ql = TransformedDistribution(
            distribution=Normal(
                tf.gather(qlocal_vars[4], self.idx_ph),
                tf.nn.softplus(tf.gather(qlocal_vars[5], self.idx_ph))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        self.qW0 = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.W0.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.W0.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())
        self.qW1 = Normal(tf.Variable(tf.zeros(self.W1.shape)),
                          tf.nn.softplus(tf.Variable(tf.ones(self.W1.shape))))
        self.qr = TransformedDistribution(
            distribution=Normal(
                tf.Variable(tf.ones(self.r.shape)),
                tf.nn.softplus(tf.Variable(0.01 * tf.ones(self.r.shape)))),
            bijector=tf.contrib.distributions.bijectors.Exp())

        self.x_ph = tf.placeholder(tf.float32, [M, P])

        inference_global = ed.ReparameterizationKLqp(
            {
                self.r: self.qr,
                self.W0: self.qW0,
                self.W1: self.qW1
            },
            data={
                self.likelihood: self.x_ph,
                self.z: self.qz,
                self.lam: self.qlam,
                self.l: self.ql
            })
        inference_local = ed.ReparameterizationKLqp(
            {
                self.z: self.qz,
                self.lam: self.qlam,
                self.l: self.ql
            },
            data={
                self.likelihood: self.x_ph,
                self.r: self.qr,
                self.W0: self.qW0,
                self.W1: self.qW1
            })

        return inference_local, inference_global

    def run_stochastic_inference(self, X, inference_local, inference_global, n_iterations=100):
        M = self.minibatch_size
        N = X.shape[0]

        # Run inference
        inference_global.initialize(
            scale={
                self.likelihood: float(N) / M,
                self.z: float(N) / M,
                self.lam: float(N) / M,
                self.l: float(N) / M
            })
        inference_local.initialize(
            scale={
                self.likelihood: float(N) / M,
                self.z: float(N) / M,
                self.lam: float(N) / M,
                self.l: float(N) / M
            })

        self.sess = ed.get_session()
        tf.global_variables_initializer().run()

        self.loss = []
        n_iter_per_epoch = N // M
        pbar = ed.Progbar(n_iterations)
        for epoch in range(n_iterations):
            # print("Epoch: {0}".format(epoch))
            avg_loss = 0.0
            for t in range(1, n_iter_per_epoch + 1):
                x_batch, idx_batch = next_batch(X, M)
                # inference_local.update(feed_dict={x_ph: x_batch})
                for _ in range(5):  # make local inferences
                    info_dict = inference_local.update(feed_dict={
                        self.x_ph: x_batch,
                        self.idx_ph: idx_batch
                    })
                info_dict = inference_global.update(feed_dict={
                    self.x_ph: x_batch,
                    self.idx_ph: idx_batch
                })
                avg_loss += info_dict['loss']

            # Print a lower bound to the average marginal likelihood for a cell
            avg_loss /= n_iter_per_epoch
            avg_loss /= M
            # print("-log p(x) <= {:0.3f}\n".format(avg_loss), end='\r')
            self.loss.append(avg_loss)
            pbar.update(epoch, values={'Loss': avg_loss})

        inference_global.finalize()
        inference_local.finalize()
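# Hypothetical end-to-end usage of the ZINBayes estimator above. The toy count matrix
# `counts` and the argument values are illustrative assumptions; fit_transform, get_est_X,
# and score are the methods defined in the class.
import numpy as np

counts = np.random.poisson(2.0, size=(100, 50))               # cells x genes toy data
model = ZINBayes(n_components=10, zero_inflation=True, scalings=True)
latent = model.fit_transform(counts, max_iter=100)            # [100, 10] latent factors
denoised = model.get_est_X()                                  # posterior NB mean per cell/gene
elbo = model.score(counts)                                    # ELBO per cell on (held-out) data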
x_ph = tf.placeholder(tf.float32, [M, D])

U = Gamma(0.1, 0.5, sample_shape=[M, K])
V = Gamma(0.1, 0.3, sample_shape=[D, K])
x = Poisson(tf.matmul(U, V, transpose_b=True))

min_scale = 1e-5

qV_variables = [
    tf.Variable(tf.random_uniform([D, K])),
    tf.Variable(tf.random_uniform([D, K]))
]
qV = TransformedDistribution(
    distribution=Normal(qV_variables[0],
                        tf.maximum(tf.nn.softplus(qV_variables[1]), min_scale)),
    bijector=tf.contrib.distributions.bijectors.Exp())

qU_variables = [
    tf.Variable(tf.random_uniform([N, K])),
    tf.Variable(tf.random_uniform([N, K]))
]
qU = TransformedDistribution(
    distribution=Normal(tf.gather(qU_variables[0], idx_ph),
                        tf.maximum(tf.nn.softplus(tf.gather(qU_variables[1], idx_ph)), min_scale)),
    bijector=tf.contrib.distributions.bijectors.Exp())
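# Sketch of wiring the minibatched factors above into KLqp with data subsampling (this part
# is an assumption, not from the original snippet). It presumes idx_ph was created as
# tf.placeholder(tf.int32, [M]) before qU, that N is the full dataset size, and that x_train
# is an [N, D] count matrix; the N/M scale factors follow the usual Edward batch-training pattern.
import numpy as np

inference = ed.KLqp({U: qU, V: qV}, data={x: x_ph})
inference.initialize(scale={x: float(N) / M, U: float(N) / M})

tf.global_variables_initializer().run(session=ed.get_session())
for _ in range(1000):
    idx = np.random.choice(N, M, replace=False)       # random minibatch of rows
    inference.update(feed_dict={x_ph: x_train[idx], idx_ph: idx})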