Example 1
    def model(self, x, y=None):
        # Register various nn.Modules with Pyro
        pyro.module("scanvi", self)

        # This gene-level parameter modulates the variance of the observation distribution
        theta = pyro.param("inverse_dispersion", 10.0 * x.new_ones(self.num_genes),
                           constraint=constraints.positive)

        # We scale all sample statements by scale_factor so that the ELBO is normalized
        # wrt the number of datapoints and genes
        with pyro.plate("batch", len(x)), poutine.scale(scale=self.scale_factor):
            z1 = pyro.sample("z1", dist.Normal(0, x.new_ones(self.latent_dim)).to_event(1))
            # Note that if y is None (i.e. y is unobserved) then y will be sampled;
            # otherwise y will be treated as observed.
            y = pyro.sample("y", dist.OneHotCategorical(logits=x.new_zeros(self.num_labels)),
                            obs=y)

            z2_loc, z2_scale = self.z2_decoder(z1, y)
            z2 = pyro.sample("z2", dist.Normal(z2_loc, z2_scale).to_event(1))

            l_scale = self.l_scale * x.new_ones(1)
            l = pyro.sample("l", dist.LogNormal(self.l_loc, l_scale).to_event(1))

            # Note that by construction mu is normalized (i.e. mu.sum(-1) == 1) and the
            # total scale of counts for each cell is determined by `l`
            gate_logits, mu = self.x_decoder(z2)
            # TODO revisit this parameterization if torch.distributions.NegativeBinomial changes
            # from failure to success parametrization;
            # see https://github.com/pytorch/pytorch/issues/42449
            nb_logits = (l * mu + self.epsilon).log() - (theta + self.epsilon).log()
            x_dist = dist.ZeroInflatedNegativeBinomial(gate_logits=gate_logits, total_count=theta,
                                                       logits=nb_logits)
            # Observe the datapoint x using the observation distribution x_dist
            pyro.sample("x", x_dist.to_event(1), obs=x)
Example 2
    def forward(self):

        def RP(weights, distances, d):
            # Regulatory-potential score: each region's weight is halved for
            # every 1e3 * d base pairs of distance.
            return 1e4 * (weights * torch.pow(0.5, distances / (1e3 * d))).sum(-1)

        with pyro.plate(self.name + "_regions", 3):
            a = pyro.sample(self.name + "_a", dist.HalfNormal(12.))

        with pyro.plate(self.name + "_upstream-downstream", 2):
            # Decay distance is sampled on the log scale, one value per direction
            d = torch.exp(pyro.sample(self.name + "_logdistance", dist.Normal(np.e, 2.)))

        b = pyro.sample(self.name + "_b", dist.Normal(-10., 3.))
        theta = pyro.sample(self.name + "_theta", dist.Gamma(2., 0.5))
        psi = pyro.sample(self.name + "_dropout", dist.Beta(1., 10.))

        with pyro.plate(self.name + "_data", self.N, subsample_size=64) as ind:
            # Combine upstream, downstream, and promoter contributions
            expr_rate = (a[0] * RP(self.upstream_weights.index_select(0, ind), self.upstream_distances, d[0])
                         + a[1] * RP(self.downstream_weights.index_select(0, ind), self.downstream_distances, d[1])
                         + a[2] * 1e4 * self.promoter_weights.index_select(0, ind).sum(-1)
                         + b)

            mu = self.read_depth.index_select(0, ind) * torch.exp(expr_rate)
            # Clamp the success probability away from 1 for numerical stability
            p = torch.minimum(mu / (mu + theta), torch.tensor([0.99999]))

            pyro.sample(self.name + "_obs",
                        dist.ZeroInflatedNegativeBinomial(total_count=theta, probs=p, gate=psi),
                        obs=self.gene_expr.index_select(0, ind))
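These snippets only define the generative model; a minimal sketch of the inference side, assuming a hypothetical instance rp_model of this module and an autoguide (the names here are illustrative, not from the original code):

    import pyro
    from pyro.infer import SVI, Trace_ELBO
    from pyro.infer.autoguide import AutoDiagonalNormal
    from pyro.optim import Adam

    guide = AutoDiagonalNormal(rp_model)  # rp_model: assumed module instance
    svi = SVI(rp_model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
    for step in range(1000):
        loss = svi.step()  # minibatching happens inside the subsampled plate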
Example 3
    def model(self, raw_expr, encoded_expr, read_depth):

        pyro.module("decoder", self.decoder)

        with pyro.plate("genes", self.num_genes):

            dispersion = pyro.sample(
                "dispersion",
                dist.Gamma(
                    torch.tensor(2.).to(self.device),
                    torch.tensor(0.5).to(self.device)))
            psi = pyro.sample(
                "dropout",
                dist.Beta(
                    torch.tensor(1.).to(self.device),
                    torch.tensor(10.).to(self.device)))

        with pyro.plate("cells", encoded_expr.shape[0]):
            # The Dirichlet prior p(θ|α) is replaced by a log-normal distribution

            theta_loc = self.prior_mu * encoded_expr.new_ones(
                (encoded_expr.shape[0], self.num_topics))
            theta_scale = self.prior_std * encoded_expr.new_ones(
                (encoded_expr.shape[0], self.num_topics))
            theta = pyro.sample(
                "theta",
                dist.LogNormal(theta_loc, theta_scale).to_event(1))
            theta = theta / theta.sum(-1, keepdim=True)
            # The conditional distribution of w_n is defined as
            # w_n | β, θ ~ Categorical(σ(βθ))
            expr_rate = pyro.deterministic("expr_rate", self.decoder(theta))

            mu = torch.multiply(read_depth, expr_rate)
            p = torch.minimum(mu / (mu + dispersion), self.max_prob)

            pyro.sample(
                'obs',
                dist.ZeroInflatedNegativeBinomial(total_count=dispersion,
                                                  probs=p,
                                                  gate=psi).to_event(1),
                obs=raw_expr)
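The division by theta.sum(-1, keepdim=True) is what maps the log-normal draw onto the probability simplex, approximating a Dirichlet sample. The same trick in isolation, with toy shapes:

    import torch
    import pyro.distributions as dist

    theta = dist.LogNormal(torch.zeros(4), torch.ones(4)).to_event(1).sample()
    theta = theta / theta.sum(-1, keepdim=True)  # non-negative, sums to 1
    print(theta.sum())                           # tensor(1.)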
Example 4
    def model(self, x, log_library):
        # register PyTorch module `decoder` with Pyro
        pyro.module("scvi", self)
        with pyro.plate("data", x.shape[0]):
            # setup hyperparameters for prior p(z)
            z_loc = x.new_zeros(torch.Size((x.shape[0], self.n_latent)))
            z_scale = x.new_ones(torch.Size((x.shape[0], self.n_latent)))
            # sample from prior (value will be sampled by guide when computing the ELBO)
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z
            px_scale, _, px_rate, px_dropout = self.decoder("gene", z, log_library)
            # build count distribution
            nb_logits = (px_rate + self.epsilon).log() - (
                self.px_r.exp() + self.epsilon
            ).log()
            x_dist = dist.ZeroInflatedNegativeBinomial(
                gate_logits=px_dropout, total_count=self.px_r.exp(), logits=nb_logits
            )
            # score against actual counts
            pyro.sample("obs", x_dist.to_event(1), obs=x)
Example 5
    def model(self, raw_expr, encoded_expr, read_depth):

        pyro.module("decoder", self.decoder)

        dispersion = pyro.param("dispersion",
                                torch.tensor(5.).to(self.device) *
                                torch.ones(self.num_genes).to(self.device),
                                constraint=constraints.positive)

        with pyro.plate("cells", encoded_expr.shape[0]):

            # The Dirichlet prior p(θ|α) is replaced by a log-normal distribution
            theta_loc = self.prior_mu * encoded_expr.new_ones(
                (encoded_expr.shape[0], self.num_topics))
            theta_scale = self.prior_std * encoded_expr.new_ones(
                (encoded_expr.shape[0], self.num_topics))
            theta = pyro.sample(
                "theta",
                dist.LogNormal(theta_loc, theta_scale).to_event(1))
            theta = theta / theta.sum(-1, keepdim=True)

            read_scale = pyro.sample(
                'read_depth',
                dist.LogNormal(torch.log(read_depth), 1.).to_event(1))

            # The conditional distribution of w_n is defined as
            # w_n | β, θ ~ Categorical(σ(βθ))
            expr_rate, dropout = self.decoder(theta)

            mu = torch.multiply(read_scale, expr_rate)
            p = torch.minimum(mu / (mu + dispersion), self.max_prob)

            pyro.sample('obs',
                        dist.ZeroInflatedNegativeBinomial(
                            total_count=dispersion,
                            probs=p,
                            gate_logits=dropout).to_event(1),
                        obs=raw_expr)
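Unlike Examples 2, 3, and 6, which pass the probability-scale gate, this variant feeds the decoder's dropout output in as gate_logits; the two are related by a sigmoid. A quick check of the equivalence, with arbitrary values:

    import torch
    import pyro.distributions as dist

    gate_logits = torch.tensor(-2.0)
    kwargs = dict(total_count=torch.tensor(5.), probs=torch.tensor(0.3))
    d1 = dist.ZeroInflatedNegativeBinomial(gate_logits=gate_logits, **kwargs)
    d2 = dist.ZeroInflatedNegativeBinomial(gate=torch.sigmoid(gate_logits), **kwargs)
    print(d1.log_prob(torch.tensor(0.)), d2.log_prob(torch.tensor(0.)))  # equal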
Example 6
    def forward(self):

        with pyro.plate("gene_weights", self.G):

            b = pyro.sample("b", dist.Normal(-10., 3.))
            theta = pyro.sample("theta", dist.Gamma(2., 0.5))
            psi = pyro.sample("dropout", dist.Beta(1., 10.))

            with pyro.plate("topic-gene_weights", self.K):
                beta = pyro.sample("beta", dist.Gamma(1., 5.))

        with pyro.plate("gene", self.G):
            with pyro.plate("data", self.N, subsample_size=64) as ind:

                expr_rate = pyro.deterministic(
                    "rate", torch.matmul(self.cell_topics.index_select(0, ind), beta) + b)

                mu = torch.reshape(self.read_depth, (-1, 1)).index_select(0, ind) * torch.exp(expr_rate)
                p = torch.minimum(mu / (mu + theta), torch.tensor([0.99999]))

                pyro.sample("obs",
                            dist.ZeroInflatedNegativeBinomial(total_count=theta,
                                                              probs=p, gate=psi),
                            obs=self.gene_expr.index_select(0, ind))
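Because the data plate declares subsample_size=64, Pyro rescales the minibatch log-likelihood by N / 64, keeping the ELBO an unbiased estimate for the full dataset. A toy illustration of the same pattern, with made-up sizes:

    import torch
    import pyro
    import pyro.distributions as dist

    data = torch.randn(1000)

    def model():
        loc = pyro.sample("loc", dist.Normal(0., 1.))
        with pyro.plate("data", 1000, subsample_size=64) as ind:
            # log-prob of these 64 points is implicitly scaled by 1000 / 64
            pyro.sample("obs", dist.Normal(loc, 1.), obs=data.index_select(0, ind))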