def get_log_likelihood(self, x, batch):
    # Monte-Carlo estimate of the per-entry log-likelihood, averaged over
    # several samples of the latent variable.
    import numpy as np

    if self.observation_dist == 'nb':
        log_likelihood = log_likelihood_nb(self.x, self.mu, self.sigma_square)
    else:
        dof = 2.0
        log_likelihood = log_likelihood_student(self.x, self.mu,
                                                self.sigma_square, df=dof)

    num_samples = 5
    feed_dict = {self.x: x, self.batch_id: batch}

    log_likelihood_value = 0
    for i in range(num_samples):
        log_likelihood_value += self.session.run(log_likelihood,
                                                 feed_dict=feed_dict)

    log_likelihood_value /= np.float32(num_samples)

    return log_likelihood_value
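# Minimal usage sketch (not part of the original source): evaluating the
# Monte-Carlo averaged log-likelihood on held-out cells. `model` is assumed to
# be an already-trained instance of the class that defines get_log_likelihood.
import numpy as np

x_test = np.random.poisson(1.0, size=(64, 2000)).astype(np.float32)  # toy held-out counts
batch_test = np.zeros(64, dtype=np.int32)                            # toy batch labels
ll = model.get_log_likelihood(x_test, batch_test)                    # per-entry log-likelihoods
print('mean held-out log-likelihood:', float(np.mean(ll)))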
def __init__(self, n_gene, n_batch=None, z_dim=2,
             encoder_layer=None, decoder_layer=None,
             activation=tf.nn.elu, latent_dist='vmf',
             observation_dist='nb', seed=0):
    # n_batch should be an integer (or a list of integers) specifying the number of batches
    tf.compat.v1.set_random_seed(seed)

    if encoder_layer is None:
        encoder_layer = [128, 64, 32]
    if decoder_layer is None:
        decoder_layer = [32, 128]

    self.n_input_feature = n_gene

    # placeholder for gene expression data
    self.x = tf.compat.v1.placeholder(tf.float32, shape=[None, n_gene], name='x')

    self.z_dim, self.encoder_layer, self.decoder_layer, self.activation, \
        self.latent_dist, self.observation_dist = \
        z_dim, encoder_layer, decoder_layer, activation, \
        latent_dist, observation_dist

    # the vMF latent lives on a hypersphere, so one extra coordinate is needed
    if self.latent_dist == 'vmf':
        self.z_dim += 1

    if not isinstance(n_batch, list):
        n_batch = [n_batch]

    # placeholder for batch id of x
    self.n_batch = n_batch
    if len(self.n_batch) > 1:
        self.batch_id = tf.compat.v1.placeholder(tf.int32, shape=[None, None], name='batch')
        self.batch = self.multi_one_hot(self.batch_id, self.n_batch)
    else:
        self.batch_id = tf.compat.v1.placeholder(tf.int32, shape=[None], name='batch')
        self.batch = tf.one_hot(self.batch_id, self.n_batch[0])

    self.library_size = tf.reduce_sum(self.x, axis=1, keepdims=True, name='library-size')

    self.z_mu, self.z_sigma_square = self._encoder(self.x, self.batch)

    with tf.name_scope('latent-variable'):
        if self.latent_dist == 'normal':
            self.q_z = tf.compat.v1.distributions.Normal(self.z_mu, self.z_sigma_square)
        elif self.latent_dist == 'vmf':
            self.q_z = VonMisesFisher(self.z_mu, self.z_sigma_square)
        elif self.latent_dist == 'wn':
            self.q_z = HyperbolicWrappedNorm(self.z_mu, self.z_sigma_square)
        else:
            raise NotImplementedError

    self.z = self.q_z.sample()

    self.mu, self.sigma_square = self._decoder(self.z, self.batch)
    self.depth_loss = self._depth_regularizer(self.batch)

    with tf.name_scope('ELBO'):
        if self.observation_dist == 'student':
            self.log_likelihood = tf.reduce_mean(
                log_likelihood_student(self.x, self.mu, self.sigma_square, df=5.0),
                name="log_likelihood")
        elif self.observation_dist == 'nb':
            self.log_likelihood = tf.reduce_mean(
                log_likelihood_nb(self.x, self.mu, self.sigma_square, eps=1e-10),
                name="log_likelihood")

        if self.latent_dist == 'normal':
            self.p_z = tf.compat.v1.distributions.Normal(tf.zeros_like(self.z),
                                                         tf.ones_like(self.z))
            kl = self.q_z.kl_divergence(self.p_z)
            self.kl = tf.reduce_mean(tf.reduce_sum(kl, axis=1))
        elif self.latent_dist == 'vmf':
            self.p_z = HypersphericalUniform(self.z_dim - 1, dtype=self.x.dtype)
            kl = self.q_z.kl_divergence(self.p_z)
            self.kl = tf.reduce_mean(kl)
        elif self.latent_dist == 'wn':
            tmp = self._polar_project(tf.zeros_like(self.z_sigma_square))
            self.p_z = HyperbolicWrappedNorm(tmp, tf.ones_like(self.z_sigma_square))
            kl = self.q_z.log_prob(self.z)
            # clip log q(z) so the following tf.exp cannot overflow
            kl = tf.clip_by_value(kl, -1000, 80)
            kl = tf.exp(kl) * kl - self.p_z.log_prob(self.z)
            self.kl = tf.reduce_mean(kl)
        else:
            raise NotImplementedError

    self.ELBO = self.log_likelihood - self.kl

    self.session = tf.compat.v1.Session()
    self.saver = tf.compat.v1.train.Saver()
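# Minimal usage sketch (not part of the original source). It assumes the
# constructor above belongs to a class named `SCVAE` (hypothetical name), that
# the helper functions and distributions it references are importable, and that
# TF1 graph mode is enabled before the model is built.
import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # the placeholders above require graph mode under TF2

model = SCVAE(n_gene=2000, n_batch=3, z_dim=2,
              latent_dist='vmf', observation_dist='nb')
model.session.run(tf.compat.v1.global_variables_initializer())

x_toy = np.random.poisson(1.0, size=(128, 2000)).astype(np.float32)  # toy count matrix
b_toy = np.random.randint(0, 3, size=128).astype(np.int32)           # toy batch labels
elbo = model.session.run(model.ELBO,
                         feed_dict={model.x: x_toy, model.batch_id: b_toy})
print('ELBO on toy data:', elbo)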
def __init__(self, n_gene, sp, sp_mask, n_batch=None, z_dim=2,
             encoder_layer=None, decoder_layer=None,
             activation=tf.nn.elu, latent_dist='vmf',
             observation_dist='nb', batch_invariant=False, seed=0):
    # n_batch should be an integer (or a list of integers) specifying the number of batches
    tf.compat.v1.set_random_seed(seed)

    if encoder_layer is None:
        encoder_layer = [128, 64, 32]
    if decoder_layer is None:
        decoder_layer = [32, 128]

    self.batch_invariant = batch_invariant
    self.n_input_feature = n_gene

    # placeholder for gene expression data
    self.x = tf.compat.v1.placeholder(tf.float32, shape=[None, n_gene], name='x')

    self.z_dim, self.encoder_layer, self.decoder_layer, self.activation, \
        self.latent_dist, self.observation_dist = \
        z_dim, encoder_layer, decoder_layer, activation, \
        latent_dist, observation_dist

    # the vMF latent lives on a hypersphere, so one extra coordinate is needed
    if self.latent_dist == 'vmf':
        self.z_dim += 1

    if not isinstance(n_batch, list):
        n_batch = [n_batch]

    # placeholder for batch id of x
    self.n_batch = n_batch
    if len(self.n_batch) > 1:
        self.batch_id = tf.compat.v1.placeholder(tf.int32, shape=[None, None], name='batch')
        self.batch = self.multi_one_hot(self.batch_id, self.n_batch)
    else:
        self.batch_id = tf.compat.v1.placeholder(tf.int32, shape=[None], name='batch')
        self.batch = tf.one_hot(self.batch_id, self.n_batch[0])

    self.library_size = tf.reduce_sum(self.x, axis=1, keepdims=True, name='library-size')

    self.z_mu, self.z_sigma_square = self._encoder(self.x, self.batch)
    self.make_encoder = tf.compat.v1.make_template('encoder', self._encoder)

    with tf.name_scope('latent-variable'):
        if self.latent_dist == 'normal':
            self.q_z = tf.compat.v1.distributions.Normal(self.z_mu, self.z_sigma_square)
        elif self.latent_dist == 'vmf':
            self.q_z = VonMisesFisher(self.z_mu, self.z_sigma_square)
        elif self.latent_dist == 'wn':
            self.q_z = HyperbolicWrappedNorm(self.z_mu, self.z_sigma_square)
        else:
            raise NotImplementedError

    self.z = self.q_z.sample()

    self.mu, self.sigma_square = self._decoder(self.z, self.batch)
    self.depth_loss = self._depth_regularizer(self.batch)

    with tf.name_scope('ELBO'):
        if self.observation_dist == 'student':
            self.log_likelihood = tf.reduce_mean(
                log_likelihood_student(self.x, self.mu, self.sigma_square, df=5.0),
                name="log_likelihood")
        elif self.observation_dist == 'nb':
            self.log_likelihood = tf.reduce_mean(
                log_likelihood_nb(self.x, self.mu, self.sigma_square, eps=1e-10),
                name="log_likelihood")

        if self.latent_dist == 'normal':
            self.p_z = tf.compat.v1.distributions.Normal(tf.zeros_like(self.z),
                                                         tf.ones_like(self.z))
            kl = self.q_z.kl_divergence(self.p_z)
            self.kl = tf.reduce_mean(tf.reduce_sum(kl, axis=1))
        elif self.latent_dist == 'vmf':
            self.p_z = HypersphericalUniform(self.z_dim - 1, dtype=self.x.dtype)
            kl = self.q_z.kl_divergence(self.p_z)
            self.kl = tf.reduce_mean(kl)
        elif self.latent_dist == 'wn':
            tmp = self._polar_project(tf.zeros_like(self.z_sigma_square))
            self.p_z = HyperbolicWrappedNorm(tmp, tf.ones_like(self.z_sigma_square))
            # single-sample Monte-Carlo estimate of KL(q(z|x) || p(z))
            kl = self.q_z.log_prob(self.z) - self.p_z.log_prob(self.z)
            self.kl = tf.reduce_mean(kl)
        else:
            raise NotImplementedError

    self.ELBO = self.log_likelihood - self.kl

    # split off the first 11 genes (the spatial marker genes),
    # threshold at 3 counts, and log10-transform
    self.sp_gene, _ = tf.split(self.x, [11, self.x.shape.as_list()[-1] - 11], axis=-1)
    self.sp_gene = tf.nn.relu(self.sp_gene - 3)
    self.sp_gene = tf.math.log1p(self.sp_gene) / tf.math.log(tf.constant(10.0))

    # spatial coordinates
    self.sp = sp
    self.sp = tf.Variable(self.sp, trainable=False, dtype='float32')

    # pairwise Euclidean distances between the latent means and the spatial
    # positions, via ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
    M1 = 128
    M2 = self.sp.shape[0]
    p1 = tf.matmul(
        tf.expand_dims(tf.reduce_sum(tf.square(self.z_mu), 1), 1),
        tf.ones(shape=(1, M2))
    )
    p2 = tf.transpose(tf.matmul(
        tf.reshape(tf.reduce_sum(tf.square(self.sp), 1), shape=[-1, 1]),
        tf.ones(shape=(M1, 1)),
        transpose_b=True
    ))
    self.grid_dist = tf.sqrt(tf.add(p1, p2) -
                             2 * tf.matmul(self.z_mu, self.sp, transpose_b=True))

    self.sp_mask = sp_mask
    self.sp_mask_marginal = tf.Variable(self.sp_mask, trainable=False, dtype='float32')

    B = tf.expand_dims(tf.transpose(self.sp_mask_marginal), 0)
    self.cell_sp_mask_marginal = tf.transpose(B, perm=[0, 2, 1])

    # weighting the distances by gene expression;
    # broadcast shape is [minibatch=128, spatial positions=64, marker genes=11]
    AA = tf.broadcast_to(tf.expand_dims(self.grid_dist, 2), shape=[128, 64, 11])
    self.cell_sp_mask_marginal = AA * self.cell_sp_mask_marginal

    # large penalty for spatial positions where a marker is not allowed
    self.sp_mask_neg = 100000.0 * (1.0 - self.sp_mask)
    self.sp_mask_neg = tf.Variable(self.sp_mask_neg, trainable=False, dtype='float32')
    self.sp_mask_neg = tf.broadcast_to(tf.expand_dims(self.sp_mask_neg, 0),
                                       shape=[128, 64, 11])

    # first term: distance from each cell to its nearest allowed position,
    # hinged at 0.15 and weighted by marker expression
    self.cell_sp_mask_marginal = self.cell_sp_mask_marginal + self.sp_mask_neg
    tmp = tf.nn.relu(tf.reduce_min(self.cell_sp_mask_marginal, 1) - 0.15)
    tmp *= self.sp_gene
    self.dist_loss = tf.reduce_mean(tf.reduce_sum(tmp, 1))

    # =====
    # Normalize by elements: keep markers 0-5 and 9, and normalize each
    # marker's mask over spatial positions
    aa = np.zeros(11)
    aa[:6] = 1
    aa[9] = 1

    self.sp_mask_weight = self.sp_mask / np.sum(self.sp_mask, axis=0)
    self.sp_mask_weight = tf.Variable(self.sp_mask_weight * aa,
                                      trainable=False, dtype='float32')

    A = tf.expand_dims(self.sp_gene, 2)
    B = tf.expand_dims(tf.transpose(self.sp_mask_weight), 0)
    C = A * B
    self.cell_sp_mask = tf.transpose(C, perm=[0, 2, 1])

    # second term: expression-weighted, hinged distances to all allowed positions
    AA = tf.broadcast_to(tf.expand_dims(tf.nn.relu(self.grid_dist - 0.15), 2),
                         shape=[128, 64, 11])
    self.cell_dist_mask = AA * self.cell_sp_mask
    self.dist_loss += tf.reduce_sum(tf.reduce_sum(tf.reduce_sum(self.cell_dist_mask, 1), 1))

    self.session = tf.compat.v1.Session()
    self.saver = tf.compat.v1.train.Saver()
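# Sketch (not part of the original source) of the pairwise-distance expansion
# used for `grid_dist` above, ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, checked
# against a direct NumPy computation on random stand-in data with the same
# shapes (128 latent means, 64 spatial positions).
import numpy as np

z_mu = np.random.randn(128, 2).astype(np.float32)  # stand-in for latent means
sp = np.random.randn(64, 2).astype(np.float32)     # stand-in for spatial coordinates

p1 = np.sum(np.square(z_mu), axis=1, keepdims=True)      # [128, 1]
p2 = np.sum(np.square(sp), axis=1, keepdims=True).T      # [1, 64]
sq = np.maximum(p1 + p2 - 2.0 * z_mu @ sp.T, 0.0)        # clamp tiny negatives from rounding
grid_dist = np.sqrt(sq)                                   # [128, 64]

direct = np.linalg.norm(z_mu[:, None, :] - sp[None, :, :], axis=-1)
assert np.allclose(grid_dist, direct, atol=1e-4)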