Example #1
    def get_log_likelihood(self, x, batch):
        # Estimate the log-likelihood of x by averaging over several
        # Monte Carlo draws of the latent variable z.
        import numpy as np

        if self.observation_dist == 'nb':
            log_likelihood = log_likelihood_nb(
                self.x,
                self.mu,
                self.sigma_square,
            )
        else:
            # non-NB models use a Student's t observation likelihood
            dof = 2.0
            log_likelihood = log_likelihood_student(self.x,
                                                    self.mu,
                                                    self.sigma_square,
                                                    df=dof)
        num_samples = 5  # number of Monte Carlo draws to average over

        feed_dict = {self.x: x, self.batch_id: batch}
        log_likelihood_value = 0

        for i in range(num_samples):
            log_likelihood_value += self.session.run(log_likelihood,
                                                     feed_dict=feed_dict)

        log_likelihood_value /= np.float32(num_samples)

        return log_likelihood_value
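
A minimal usage sketch for the method above, assuming a model instance has already been constructed and trained. The names model, counts, and batch_ids are placeholders introduced here for illustration; they do not appear in the source.

# Illustrative only: `model` is an already-trained instance of the class whose
# constructor appears in Example #2; `counts` and `batch_ids` are made-up inputs.
import numpy as np

counts = np.random.poisson(1.0, size=(128, 2000)).astype(np.float32)  # cells x genes
batch_ids = np.zeros(128, dtype=np.int32)                             # a single batch

ll = model.get_log_likelihood(counts, batch_ids)
print('Monte Carlo averaged log-likelihood:', ll)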
Example #2
    def __init__(self,
                 n_gene,
                 n_batch=None,
                 z_dim=2,
                 encoder_layer=None,
                 decoder_layer=None,
                 activation=tf.nn.elu,
                 latent_dist='vmf',
                 observation_dist='nb',
                 seed=0):
        # n_batch should be an integer specifying the number of batches

        tf.compat.v1.set_random_seed(seed)

        if encoder_layer is None:
            encoder_layer = [128, 64, 32]
        if decoder_layer is None:
            decoder_layer = [32, 128]

        self.n_input_feature = n_gene
        # placeholder for gene expression data
        self.x = tf.compat.v1.placeholder(tf.float32,
                                          shape=[None, n_gene],
                                          name='x')

        self.z_dim = z_dim
        self.encoder_layer = encoder_layer
        self.decoder_layer = decoder_layer
        self.activation = activation
        self.latent_dist = latent_dist
        self.observation_dist = observation_dist

        if self.latent_dist == 'vmf':
            # the vMF mean lives on a sphere embedded in one extra dimension
            self.z_dim += 1

        if not isinstance(n_batch, list):
            n_batch = [n_batch]

        # placeholder for batch id of x
        self.n_batch = n_batch
        if len(self.n_batch) > 1:
            self.batch_id = tf.compat.v1.placeholder(tf.int32,
                                                     shape=[None, None],
                                                     name='batch')
            self.batch = self.multi_one_hot(self.batch_id, self.n_batch)
        else:
            self.batch_id = tf.compat.v1.placeholder(tf.int32,
                                                     shape=[None],
                                                     name='batch')
            self.batch = tf.one_hot(self.batch_id, self.n_batch[0])

        # per-cell total count ("library size")
        self.library_size = tf.reduce_sum(self.x,
                                          axis=1,
                                          keepdims=True,
                                          name='library-size')

        self.z_mu, self.z_sigma_square = self._encoder(self.x, self.batch)
        with tf.name_scope('latent-variable'):
            if self.latent_dist == 'normal':
                self.q_z = tf.distributions.Normal(self.z_mu,
                                                   self.z_sigma_square)
            elif self.latent_dist == 'vmf':
                self.q_z = VonMisesFisher(self.z_mu, self.z_sigma_square)
            elif self.latent_dist == 'wn':
                self.q_z = HyperbolicWrappedNorm(self.z_mu,
                                                 self.z_sigma_square)
            else:
                raise NotImplementedError
            self.z = self.q_z.sample()

        self.mu, self.sigma_square = self._decoder(self.z, self.batch)
        self.depth_loss = self._depth_regularizer(self.batch)

        with tf.name_scope('ELBO'):
            if self.observation_dist == 'student':
                self.log_likelihood = tf.reduce_mean(log_likelihood_student(
                    self.x, self.mu, self.sigma_square, df=5.0),
                                                     name="log_likelihood")
            elif self.observation_dist == 'nb':
                self.log_likelihood = tf.reduce_mean(log_likelihood_nb(
                    self.x, self.mu, self.sigma_square, eps=1e-10),
                                                     name="log_likelihood")

            if self.latent_dist == 'normal':
                self.p_z = tf.distributions.Normal(tf.zeros_like(self.z),
                                                   tf.ones_like(self.z))
                kl = self.q_z.kl_divergence(self.p_z)
                self.kl = tf.reduce_mean(tf.reduce_sum(kl, axis=1))
            elif self.latent_dist == 'vmf':
                self.p_z = HypersphericalUniform(self.z_dim - 1,
                                                 dtype=self.x.dtype)
                kl = self.q_z.kl_divergence(self.p_z)
                self.kl = tf.reduce_mean(kl)
            elif self.latent_dist == 'wn':
                tmp = self._polar_project(tf.zeros_like(self.z_sigma_square))
                self.p_z = HyperbolicWrappedNorm(
                    tmp, tf.ones_like(self.z_sigma_square))

                kl = self.q_z.log_prob(self.z)
                # clip the log-density so the exponentiation below cannot overflow
                kl = tf.clip_by_value(kl, -1000, 80)

                kl = tf.exp(kl) * kl - self.p_z.log_prob(self.z)

                self.kl = tf.reduce_mean(kl)
            else:
                raise NotImplementedError

            self.ELBO = self.log_likelihood - self.kl

        self.session = tf.compat.v1.Session()
        self.saver = tf.compat.v1.train.Saver()
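
As a rough sketch of how this constructor might be used (assumptions: the enclosing class is called SCPHERE purely for illustration, the full class definition including the _encoder/_decoder helpers not shown here is available, and counts and batch_ids follow the sketch in Example #1):

# Illustrative only; SCPHERE is a stand-in name for the class this __init__ belongs to.
import tensorflow as tf

model = SCPHERE(n_gene=2000,           # number of genes / input features
                n_batch=3,             # one batch variable with 3 levels
                z_dim=2,               # becomes 3 internally for the vMF latent
                latent_dist='vmf',
                observation_dist='nb')

model.session.run(tf.compat.v1.global_variables_initializer())
elbo = model.session.run(model.ELBO,
                         feed_dict={model.x: counts, model.batch_id: batch_ids})

The variant below extends the same constructor with spatial coordinates sp and a gene-by-location mask sp_mask, and adds a distance-based loss on top of the ELBO:
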
    def __init__(self, n_gene, sp, sp_mask, n_batch=None, z_dim=2,
                 encoder_layer=None, decoder_layer=None, activation=tf.nn.elu,
                 latent_dist='vmf', observation_dist='nb',
                 batch_invariant=False, seed=0):
        # n_batch should be an integer specifying the number of batches

        tf.compat.v1.set_random_seed(seed)

        if encoder_layer is None:
            encoder_layer = [128, 64, 32]
        if decoder_layer is None:
            decoder_layer = [32, 128]

        self.batch_invariant = batch_invariant

        self.n_input_feature = n_gene
        # placeholder for gene expression data
        self.x = tf.compat.v1.placeholder(tf.float32,
                                          shape=[None, n_gene], name='x')

        self.z_dim = z_dim
        self.encoder_layer = encoder_layer
        self.decoder_layer = decoder_layer
        self.activation = activation
        self.latent_dist = latent_dist
        self.observation_dist = observation_dist

        if self.latent_dist == 'vmf':
            # the vMF mean lives on a sphere embedded in one extra dimension
            self.z_dim += 1

        if not isinstance(n_batch, list):
            n_batch = [n_batch]

        # placeholder for batch id of x
        self.n_batch = n_batch
        if len(self.n_batch) > 1:
            self.batch_id = tf.compat.v1.placeholder(tf.int32,
                                                     shape=[None, None],
                                                     name='batch')
            self.batch = self.multi_one_hot(self.batch_id, self.n_batch)
        else:
            self.batch_id = tf.compat.v1.placeholder(tf.int32,
                                                     shape=[None],
                                                     name='batch')
            self.batch = tf.one_hot(self.batch_id, self.n_batch[0])

        # per-cell total count ("library size")
        self.library_size = tf.reduce_sum(self.x, axis=1, keepdims=True,
                                          name='library-size')

        self.z_mu, self.z_sigma_square = self._encoder(self.x, self.batch)
        self.make_encoder = tf.compat.v1.make_template('encoder', self._encoder)

        with tf.name_scope('latent-variable'):
            if self.latent_dist == 'normal':
                self.q_z = tf.distributions.Normal(self.z_mu, self.z_sigma_square)
            elif self.latent_dist == 'vmf':
                self.q_z = VonMisesFisher(self.z_mu, self.z_sigma_square)
            elif self.latent_dist == 'wn':
                self.q_z = HyperbolicWrappedNorm(self.z_mu, self.z_sigma_square)
            else:
                raise NotImplementedError
            self.z = self.q_z.sample()

        self.mu, self.sigma_square = self._decoder(self.z, self.batch)
        self.depth_loss = self._depth_regularizer(self.batch)

        with tf.name_scope('ELBO'):
            if self.observation_dist == 'student':
                self.log_likelihood = tf.reduce_mean(
                    log_likelihood_student(self.x,
                                           self.mu,
                                           self.sigma_square,
                                           df=5.0), name="log_likelihood")
            elif self.observation_dist == 'nb':
                self.log_likelihood = tf.reduce_mean(
                    log_likelihood_nb(self.x,
                                      self.mu,
                                      self.sigma_square,
                                      eps=1e-10), name="log_likelihood")

            if self.latent_dist == 'normal':
                self.p_z = tf.distributions.Normal(tf.zeros_like(self.z),
                                                   tf.ones_like(self.z))
                kl = self.q_z.kl_divergence(self.p_z)
                self.kl = tf.reduce_mean(tf.reduce_sum(kl, axis=1))
            elif self.latent_dist == 'vmf':
                self.p_z = HypersphericalUniform(self.z_dim - 1,
                                                 dtype=self.x.dtype)
                kl = self.q_z.kl_divergence(self.p_z)
                self.kl = tf.reduce_mean(kl)
            elif self.latent_dist == 'wn':
                tmp = self._polar_project(tf.zeros_like(self.z_sigma_square))
                self.p_z = HyperbolicWrappedNorm(tmp,
                                                 tf.ones_like(self.z_sigma_square))

                kl = self.q_z.log_prob(self.z) - self.p_z.log_prob(self.z)

                self.kl = tf.reduce_mean(kl)
            else:
                raise NotImplementedError

            self.ELBO = self.log_likelihood - self.kl

        # split off the first 11 columns of x: the genes used for the spatial constraint
        self.sp_gene, _ = tf.split(self.x, [11, self.x.shape.as_list()[-1] - 11], axis=-1)

        # threshold the counts at 3 and move to a log10(1 + x) scale
        self.sp_gene = tf.nn.relu(self.sp_gene - 3)
        self.sp_gene = tf.math.log1p(self.sp_gene) / tf.math.log(tf.constant(10.0))

        # spatial coordinates
        self.sp = sp
        self.sp = tf.Variable(self.sp, trainable=False, dtype='float32')

        M1 = 128               # (hard-coded) mini-batch size, i.e. number of cells
        M2 = self.sp.shape[0]  # number of spatial locations

        # squared norms of the cell embeddings, tiled across spatial locations
        p1 = tf.matmul(
            tf.expand_dims(tf.reduce_sum(tf.square(self.z_mu), 1), 1),
            tf.ones(shape=(1, M2))
        )
        # squared norms of the spatial locations, tiled across cells
        p2 = tf.transpose(tf.matmul(
            tf.reshape(tf.reduce_sum(tf.square(self.sp), 1), shape=[-1, 1]),
            tf.ones(shape=(M1, 1)),
            transpose_b=True
        ))

        # pairwise Euclidean distances between cells and spatial locations:
        # ||a - b|| = sqrt(||a||^2 + ||b||^2 - 2 a.b)
        self.grid_dist = tf.sqrt(tf.add(p1, p2) -
                                 2 * tf.matmul(self.z_mu, self.sp, transpose_b=True))

        self.sp_mask = sp_mask

        self.sp_mask_marginal = tf.Variable(self.sp_mask, trainable=False, dtype='float32')
        B = tf.expand_dims(tf.transpose(self.sp_mask_marginal), 0)
        self.cell_sp_mask_marginal = tf.transpose(B, perm=[0, 2, 1])

        # weighting the distances by gene expression
        AA = tf.broadcast_to(tf.expand_dims(self.grid_dist, 2),
                             shape=[128, 64, 11])
        self.cell_sp_mask_marginal = AA * self.cell_sp_mask_marginal

        ##
        self.sp_mask_neg = 100000.0 * (1.0 - self.sp_mask)
        self.sp_mask_neg = tf.Variable(self.sp_mask_neg, trainable=False, dtype='float32')
        self.sp_mask_neg = tf.broadcast_to(tf.expand_dims(self.sp_mask_neg, 0),
                                           shape=[128, 64, 11])

        ##
        self.cell_sp_mask_marginal = self.cell_sp_mask_marginal + self.sp_mask_neg

        # hinge on each cell's minimum distance to an allowed spatial location,
        # weighted by that cell's expression of the corresponding gene
        tmp = tf.nn.relu(tf.reduce_min(self.cell_sp_mask_marginal, 1) - 0.15)
        tmp *= self.sp_gene
        self.dist_loss = tf.reduce_mean(tf.reduce_sum(tmp, 1))

        # =====
        # Normalize by elements
        # keep only a subset of the 11 genes (indices 0-5 and 9) in the weighting
        aa = np.zeros(11)
        aa[:6] = 1
        aa[9] = 1

        self.sp_mask_weight = self.sp_mask / np.sum(self.sp_mask, axis=0)
        self.sp_mask_weight = tf.Variable(self.sp_mask_weight * aa, trainable=False, dtype='float32')

        A = tf.expand_dims(self.sp_gene, 2)
        B = tf.expand_dims(tf.transpose(self.sp_mask_weight), 0)
        C = A * B
        self.cell_sp_mask = tf.transpose(C, perm=[0, 2, 1])

        AA = tf.broadcast_to(tf.expand_dims(tf.nn.relu(self.grid_dist - 0.15), 2),
                             shape=[128, 64, 11])
        self.cell_dist_mask = AA * self.cell_sp_mask

        # total hinge penalty summed over all cells, locations, and genes
        self.dist_loss += tf.reduce_sum(self.cell_dist_mask)

        ##
        self.session = tf.compat.v1.Session()
        self.saver = tf.compat.v1.train.Saver()
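
The grid_dist tensor above is the full matrix of Euclidean distances between cell embeddings (z_mu) and spatial locations (sp), computed with the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b rather than an explicit pairwise loop. A small NumPy sketch (illustrative, with made-up shapes) that checks the same identity against a direct computation:

# Illustrative only: stand-in arrays with the shapes used in the constructor above.
import numpy as np

z_mu = np.random.randn(128, 2)   # cell embeddings
sp = np.random.randn(64, 2)      # spatial locations

# ||a||^2 + ||b||^2 - 2 a.b, vectorised over all pairs
p1 = np.sum(z_mu ** 2, axis=1, keepdims=True)             # (128, 1)
p2 = np.sum(sp ** 2, axis=1, keepdims=True).T             # (1, 64)
sq = np.maximum(p1 + p2 - 2.0 * z_mu @ sp.T, 0.0)         # guard against round-off
grid_dist = np.sqrt(sq)                                    # (128, 64)

# direct pairwise computation for comparison
direct = np.linalg.norm(z_mu[:, None, :] - sp[None, :, :], axis=-1)
assert np.allclose(grid_dist, direct)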