Пример #1
0
    def model(y=None):
        """Sample ``x``, record its square, and observe ``obs`` around the square.

        ``with_plate`` and ``event_shape`` are read from the enclosing scope.

        Args:
            y: Optional observed value for the ``"obs"`` site.

        Returns:
            The value of the ``"obs"`` site.
        """
        n_event_dims = len(event_shape)
        maybe_plate = pyro.util.optional(pyro.plate("plate", 3), with_plate)
        with maybe_plate:
            prior = dist.Normal(0, 1).expand(event_shape).to_event()
            x = pyro.sample("x", prior)
            squared = pyro.deterministic("x2", x**2, event_dim=n_event_dims)

        # "x3" records the same value again, this time outside the plate.
        pyro.deterministic("x3", squared)
        likelihood = dist.Normal(squared, 0.1).to_event()
        return pyro.sample("obs", likelihood, obs=y)
Пример #2
0
def backward_model(data,
                   *args,
                   encoder=None,
                   use_deltas=False,
                   kl_beta=1.0,
                   **kwargs):
    """Guide: encode ``data`` and sample the latent ``z`` from the encoder output.

    Args:
        data: Batch of inputs; ``data.shape[0]`` is the batch size.
        encoder: Module registered with Pyro under the name ``"encoder"``.
        use_deltas: If true, ``z`` is drawn from a ``Delta`` at ``z_mu``
            instead of a ``Normal``.
        kl_beta: Extra scale applied to the ``"z"`` site only.
    """
    encoder = pyro.module("encoder", encoder)
    batch_size = data.shape[0]
    Scale = poutine.scale_messenger.ScaleMessenger
    # Every site is scaled by 1 / batch_size inside the batch plate.
    with Scale(1 / batch_size), pyro.plate("batch", batch_size):
        encoded = encoder(data)
        delta_sample_transformer_params(encoder.transformer.transformers,
                                        encoded["transform_params"])

        pyro.deterministic("attention_input", encoded["view"])

        with Scale(kl_beta):
            if use_deltas:
                posterior = D.Delta(encoded["z_mu"])
            else:
                # The encoder's "z_std" is exponentiated; 1e-6 keeps the
                # scale strictly positive.
                posterior = D.Normal(encoded["z_mu"],
                                     torch.exp(encoded["z_std"]) + 1e-6)
            pyro.sample("z", posterior.to_event(1))
Пример #3
0
def forward_model(
    data,
    transforms=None,
    cond=True,
    decoder=None,
    output_size=40,
    device=torch.device("cpu"),
    **kwargs
):
    """Generative model: latent code -> canonical view -> posed view -> pixels.

    Args:
        data: Batch of observed images; ``data.shape[0]`` is the batch size.
        transforms: Transform specification passed to
            ``random_pose_transform``.
        cond: If true, the ``"pixels"`` site is observed at ``data``;
            otherwise it is sampled.
        decoder: Module registered with Pyro as ``"view_decoder"``; it must
            expose ``latent_dim``.
        output_size: Side length of the square sampling grid.
        device: Device on which the prior parameters and grid are created.
        **kwargs: Accepted and ignored.
    """
    decoder = pyro.module("view_decoder", decoder)
    batch_size = data.shape[0]
    # The first positional argument of pyro.plate is the plate *name*.
    # Passing only the batch size left the plate without a declared size,
    # so the prior below was not expanded to one latent per data point.
    with pyro.plate("batch", batch_size):

        z = pyro.sample(
            "z",
            D.Normal(
                torch.zeros(decoder.latent_dim, device=device),
                torch.ones(decoder.latent_dim, device=device),
            ).to_event(1),
        )

        view = decoder(z)

        pyro.deterministic("canonical_view", view)

        grid = coordinates.identity_grid([output_size, output_size], device=device)
        grid = grid.expand(batch_size, *grid.shape)

        transform = random_pose_transform(transforms)

        transform_grid = transform(grid)

        transformed_view = T.broadcasting_grid_sample(view, transform_grid)
        obs = data if cond else None
        pyro.sample("pixels", D.Bernoulli(transformed_view).to_event(3), obs=obs)
Пример #4
0
    def model(self, xs, ys=None):
        """Generative model conditioned on the (partially masked) input ``xs``.

        Args:
            xs: Input batch; entries equal to ``-1`` mark the masked region.
            ys: Optional targets. When given, only the masked entries are
                scored; when ``None`` the decoder output is recorded as a
                deterministic site instead of being sampled.

        Returns:
            The decoder output ``loc``.
        """
        # Register this module and all of its sub-modules with Pyro.
        pyro.module("generation_net", self)
        n_batch = xs.shape[0]
        with pyro.plate("data"):

            # The baseline prediction is an initial guess for the prior
            # network. no_grad keeps training from updating the baseline
            # network through this path.
            with torch.no_grad():
                y_hat = self.baseline_net(xs).view(xs.shape)

            # Prior over the latent code, modulated by xs and the baseline.
            prior_loc, prior_scale = self.prior_net(xs, y_hat)
            prior = dist.Normal(prior_loc, prior_scale).to_event(1)
            zs = pyro.sample('z', prior)

            # Decode the latent code into per-pixel values.
            loc = self.generation_net(zs)

            if ys is None:
                # Testing: record the output directly. Sampling would force
                # every pixel to 0 or 1 and lose the grayscale information.
                pyro.deterministic('y', loc.detach())
            else:
                # Training: score only the masked entries.
                masked = xs == -1
                mask_loc = loc[masked.view(-1, 784)].view(n_batch, -1)
                mask_ys = ys[masked].view(n_batch, -1)
                pyro.sample('y', dist.Bernoulli(mask_loc).to_event(1), obs=mask_ys)

            # Returned so callers can visualise it.
            return loc
Пример #5
0
    def model(self,
              data,
              transforms=None):
        """Generative model: latent code -> canonical view -> posed view -> pixels.

        Args:
            data: Batch of observed images. ``data.shape[0]`` sizes the batch
                plate and ``data.device`` is used for newly created tensors.
            transforms: Optional transform sequence handed to
                ``random_pose_transform``. When ``None``, a translation
                followed by a rotation is used.
        """
        output_size = self.encoder.insize
        decoder = pyro.module("decoder", self.decoder)
        batch_size = data.shape[0]
        # The first positional argument of pyro.plate is the plate *name*.
        # Passing only the batch size left the plate without a declared size.
        with pyro.plate("batch", batch_size):
            # Standard-normal prior over the latent code.
            z = pyro.sample(
                "z",
                D.Normal(
                    torch.zeros(decoder.z_dim, device=data.device),
                    torch.ones(decoder.z_dim, device=data.device),
                ).to_event(1),
            )

            # The decoder maps z to an image in its own canonical frame.
            # It is recorded as a deterministic site so it appears in traces.
            view = decoder(z)

            pyro.deterministic("canonical_view", view)

            # Nothing up to here depends on the pixel values of `data`.
            # The pose sites are created by random_pose_transform below.
            # NOTE(review): presumably they are replayed from the guide
            # during inference -- confirm against the guide.
            grid = coordinates.identity_grid(
                [output_size, output_size], device=data.device)
            grid = grid.expand(batch_size, *grid.shape)
            # Honour a caller-supplied transform sequence. Previously the
            # argument was overwritten unconditionally.
            if transforms is None:
                transforms = T.TransformSequence(T.Translation(), T.Rotation())
            transform = random_pose_transform(transforms)

            transform_grid = transform(grid)

            # Resample the canonical view on the transformed grid.
            transformed_view = T.broadcasting_grid_sample(view, transform_grid)

            # Score the observed pixels against the posed view.
            pyro.sample(
                "pixels", D.Bernoulli(transformed_view).to_event(3), obs=data)
Пример #6
0
def backward_model(data, *args, encoder=None, **kwargs):
    """Guide: encode ``data`` and sample the latent ``z`` from a diagonal Normal.

    Args:
        data: Batch of inputs; ``data.shape[0]`` is the batch size.
        encoder: Module registered with Pyro under the name ``"encoder"``.
    """
    encoder = pyro.module("encoder", encoder)
    batch_size = data.shape[0]
    Scale = poutine.scale_messenger.ScaleMessenger
    # Every site is scaled by 1 / batch_size inside the batch plate.
    with Scale(1 / batch_size), pyro.plate("batch", batch_size):
        encoded = encoder(data)
        delta_sample_transformer_params(encoder.transformer.transformers,
                                        encoded["transform_params"])

        pyro.deterministic("attention_input", encoded["view"])

        z_loc = encoded["z_mu"]
        # The encoder's "z_std" is exponentiated; 1e-3 keeps the scale
        # strictly positive.
        z_scale = torch.exp(encoded["z_std"]) + 1e-3
        pyro.sample("z", D.Normal(z_loc, z_scale).to_event(1))
def state_space_model(data, N=1, T=2, prior_drift=0., verbose=False):
    """Random-walk-with-drift state space model over ``N`` series and ``T`` steps.

    Args:
        data: Observations indexed as ``data[t - 1, n]`` for ``t`` in
            ``1..T-1``, or ``None`` to sample the ``y_t`` sites.
        N: Size of the ``data_plate`` plate.
        T: Number of latent time steps, including the initial state.
        prior_drift: Mean of the Normal prior on ``drift``.
        verbose: If true, print the sampled global variables.

    Returns:
        The ``(T, N)`` latent tensor, registered as the ``"latent"`` site.
    """
    # Global random variables shared by every series.
    drift = pyro.sample('drift', dist.Normal(prior_drift, 1))
    vol = pyro.sample('vol', dist.LogNormal(0, 1))
    uncert = pyro.sample('uncert', dist.LogNormal(-5, 1))

    if verbose:
        print(f"Using drift = {drift}, vol = {vol}, uncert = {uncert}")

    # Buffer that collects the latent path; row 0 holds the initial state.
    latent = torch.empty((T, N))

    with pyro.plate('data_plate', N) as idx:
        latent[0, idx] = pyro.sample('x0', dist.Normal(drift, vol))

        # Markov chain over time: each state depends only on the previous one.
        for t in pyro.markov(range(1, T)):
            previous = latent[t - 1, idx]
            x_t = pyro.sample(f"x_{t}", dist.Normal(previous + drift, vol))
            observed = None if data is None else data[t - 1, idx]
            pyro.sample(f"y_{t}", dist.Normal(x_t, uncert), obs=observed)
            latent[t, idx] = x_t

    return pyro.deterministic('latent', latent)
Пример #8
0
    def model(self, home_team, away_team):
        """Bivariate team-ability model producing Poisson goal counts.

        Args:
            home_team: Iterable of home team identifiers, one per match;
                each must be a key of ``self.team_to_index``.
            away_team: Iterable of away team identifiers, one per match.
        """
        sigma_a = pyro.sample("sigma_a", dist.HalfNormal(1.0))
        sigma_b = pyro.sample("sigma_b", dist.HalfNormal(1.0))
        mu_b = pyro.sample("mu_b", dist.Normal(0.0, 1.0))
        rho_raw = pyro.sample("rho_raw", dist.Beta(2, 2))
        # Map the Beta draw from (0, 1) onto a correlation in (-1, 1).
        rho = pyro.deterministic("rho", 2.0 * rho_raw - 1.0)

        log_gamma = pyro.sample("log_gamma", dist.Normal(0, 1))

        # Assemble the mean and covariance with torch.stack. torch.tensor([...])
        # copies the sampled values into a fresh leaf tensor, which cuts the
        # autograd graph back to sigma_a, sigma_b, mu_b and rho.
        loc = torch.stack([torch.zeros_like(mu_b), mu_b], dim=-1)
        cross_cov = rho * sigma_a * sigma_b
        covariance = torch.stack(
            [
                torch.stack([sigma_a ** 2.0, cross_cov], dim=-1),
                torch.stack([cross_cov, sigma_b ** 2.0], dim=-1),
            ],
            dim=-2,
        )

        with pyro.plate("teams", self.n_teams):
            abilities = pyro.sample(
                "abilities",
                dist.MultivariateNormal(loc, covariance_matrix=covariance),
            )

        log_a = abilities[:, 0]
        log_b = abilities[:, 1]
        home_inds = torch.tensor([self.team_to_index[team] for team in home_team])
        away_inds = torch.tensor([self.team_to_index[team] for team in away_team])
        # Only the home rate includes the log_gamma term.
        home_rate = torch.exp(log_a[home_inds] + log_b[away_inds] + log_gamma)
        away_rate = torch.exp(log_a[away_inds] + log_b[home_inds])

        pyro.sample("home_goals", dist.Poisson(home_rate))
        pyro.sample("away_goals", dist.Poisson(away_rate))
Пример #9
0
def quantize(name, x_real, min, max):
    """Randomly round ``x_real`` to a nearby integer, preserving probability mass.

    The offset from ``floor(x_real)`` is drawn from a four-way categorical
    whose weights come from an order-3 piecewise polynomial spline.

    Args:
        name: Name of the resulting deterministic site; the auxiliary
            categorical site is named ``"Q_" + name``.
        x_real: Real-valued tensor to quantize.
        min: Lower bound used to reflect the result.
        max: Upper bound used to reflect the result.

    Returns:
        The quantized tensor, registered as the site ``name``.
    """
    assert min < max
    floor = x_real.detach().floor()

    # Cubic spline weights over the four nearest integers; their gradients
    # are piecewise quadratic.
    frac = x_real - floor
    frac_sq = frac * frac
    comp = 1 - frac
    comp_sq = comp * comp
    weights = [
        comp * comp_sq,
        4 + frac_sq * (3 * frac - 6),
        4 + comp_sq * (3 * comp - 6),
        frac * frac_sq,
    ]
    probs = torch.stack(weights, dim=-1) * (1 / 6)
    offset = pyro.sample("Q_" + name, dist.Categorical(probs)).type_as(x_real)

    x = floor + offset - 1
    # Reflect values that fall outside the bounds.
    x = torch.max(x, 2 * min - 1 - x)
    x = torch.min(x, 2 * max + 1 - x)

    return pyro.deterministic(name, x)
Пример #10
0
def discrete_model(args, data):
    """Discrete-time compartmental model with one set of sites per time step.

    Args:
        args: Object exposing ``population``.
        data: Iterable of observations, one per time step.
    """
    # Global parameters.
    rate_s, prob_i, rho = global_model(args.population)

    # Initial state: one infected individual, everyone else susceptible.
    susceptible = torch.tensor(args.population - 1.0)
    infected = torch.tensor(1.0)

    # Time-local variables are sampled sequentially.
    for t, datum in enumerate(data):
        infection_prob = -(rate_s * infected).expm1()
        S2I = pyro.sample("S2I_{}".format(t),
                          dist.Binomial(susceptible, infection_prob))
        I2R = pyro.sample("I2R_{}".format(t),
                          dist.Binomial(infected, prob_i))
        susceptible = pyro.deterministic("S_{}".format(t), susceptible - S2I)
        infected = pyro.deterministic("I_{}".format(t), infected + S2I - I2R)
        pyro.sample("obs_{}".format(t),
                    dist.ExtendedBinomial(S2I, rho),
                    obs=datum)
Пример #11
0
    def forward(self,
                diameters: Tensor,
                resp_model=configs.respiratorA,
                debit=_DEBIT):
        """
        Sample a noisy log-penetration profile.

        Parameters
        ----------
        diameters (Tensor): particle diameters
        resp_model (function):
            Callable returning a tuple of surface area and list of
            `~MaskLayers`.
        debit:
            Value divided by the surface area to obtain the face velocity.

        Returns
        -------
        The value of the ``obs_log`` sample site.
        """
        surface_area, layers = resp_model()
        face_velocity = debit / surface_area

        with pyro.plate("diameters"):
            log_penetration = penetration.compute_penetration_profile(
                diameters,
                layers,
                face_velocity,
                configs.temperature,
                configs.viscosity,
                return_log=True)
            phi = pyro.deterministic("phi", log_penetration)

            # Normal observation noise around the computed profile.
            noise_model = dist.Normal(loc=phi, scale=self.obs_scale)
            obs = pyro.sample('obs_log', noise_model)
        return obs
Пример #12
0
def forward_model(data,
                  label,
                  N=-1,
                  transforms=None,
                  instantiate_label=False,
                  cond_label=True,
                  cond=True,
                  decoder=None,
                  latent_decoder=None,
                  output_size=40,
                  device=torch.device("cpu"),
                  **kwargs):
    """Generative model with an optional label head on the latent code.

    Args:
        data: Batch of observed images; used as the observation for
            ``"pixels"`` when ``cond`` is true.
        label: Labels used as the observation for ``"y"`` when both
            ``instantiate_label`` and ``cond_label`` are true.
        N: Batch size. When negative (the default) or ``None`` it is taken
            from ``data.shape[0]``.
        transforms: Transform specification passed to
            ``random_pose_transform``.
        instantiate_label: If true, add the ``"y"`` categorical site.
        cond_label: If true, ``"y"`` is observed at ``label``.
        cond: If true, ``"pixels"`` is observed at ``data``.
        decoder: Module registered as ``"view_decoder"``; it must expose
            ``latent_dim``.
        latent_decoder: Module registered as ``"latent_decoder"`` when
            ``instantiate_label`` is true.
        output_size: Side length of the square sampling grid.
        device: Device on which the prior parameters and grid are created.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``N`` is not given and ``data`` is ``None``.
    """
    if N is None or N < 0:
        # The -1 default cannot size the prior tensors below, so fall back
        # to the batch dimension of `data`.
        if data is None:
            raise ValueError("N must be given explicitly when data is None")
        N = data.shape[0]

    decoder = pyro.module("view_decoder", decoder)
    with pyro.plate("batch", N):

        z = pyro.sample(
            "z",
            D.Normal(
                torch.zeros(N, decoder.latent_dim, device=device),
                torch.ones(N, decoder.latent_dim, device=device),
            ).to_event(1),
        )

        # Optional supervision: predict the label from the latent code.
        if instantiate_label:
            latent_decoder = pyro.module("latent_decoder", latent_decoder)
            label_logits = latent_decoder(z)
            obs_label = label if cond_label else None
            pyro.sample("y", D.Categorical(logits=label_logits), obs=obs_label)

        view = decoder(z)

        pyro.deterministic("canonical_view", view)

        grid = coordinates.identity_grid([output_size, output_size],
                                         device=device)
        grid = grid.expand(N, *grid.shape)

        transform = random_pose_transform(transforms, device=device)

        transform_grid = transform(grid)

        transformed_view = T.broadcasting_grid_sample(view, transform_grid)
        obs = data if cond else None
        pyro.sample("pixels",
                    D.Bernoulli(transformed_view).to_event(3),
                    obs=obs)
Пример #13
0
def forward_model(
        data,
        transforms=None,
        instantiate_label=False,
        cond=True,
        decoder=None,
        output_size=128,
        device=torch.device("cpu"),
        kl_beta=1.0,
        **kwargs,
):
    """Generative model with a Laplace pixel likelihood and a rescaled grid.

    Args:
        data: Batch of observed images; ``data.shape[0]`` is the batch size.
        transforms: Transform specification passed to
            ``random_pose_transform``.
        instantiate_label: Accepted but not used in this function.
        cond: If true, ``"pixels"`` is observed at ``data``.
        decoder: Module registered as ``"view_decoder"``; it must expose
            ``latent_dim``.
        output_size: Side length of the square sampling grid.
        device: Device on which the prior parameters and grid are created.
        kl_beta: Extra scale applied to the ``"z"`` site only.
        **kwargs: Accepted and ignored.
    """
    decoder = pyro.module("view_decoder", decoder)
    batch_size = data.shape[0]
    Scale = poutine.scale_messenger.ScaleMessenger
    # Every site is scaled by 1 / batch_size inside the batch plate.
    with Scale(1 / batch_size), pyro.plate("batch", batch_size):
        with Scale(kl_beta):
            prior = D.Normal(
                torch.zeros(batch_size, decoder.latent_dim, device=device),
                torch.ones(batch_size, decoder.latent_dim, device=device),
            ).to_event(1)
            z = pyro.sample("z", prior)

        view = decoder(z)

        pyro.deterministic("canonical_view", view)

        base_grid = coordinates.identity_grid([output_size, output_size],
                                              device=device)
        grid = base_grid.expand(batch_size, *base_grid.shape)
        # Rescale the grid coordinates so that one pixel of the
        # reconstruction corresponds to one pixel of the view.
        scale = view.shape[-1] / output_size
        grid = grid * (1 / scale)

        pose = random_pose_transform(transforms, device=device)
        transformed_view = T.broadcasting_grid_sample(view, pose(grid))

        observed = data if cond else None
        pyro.sample("pixels",
                    D.Laplace(transformed_view, 0.5).to_event(3),
                    obs=observed)
Пример #14
0
def pyro_sample_saas_lengthscales(dim: int,
                                  alpha: float = 0.1,
                                  pyro: Any = None,
                                  **tkwargs: Any) -> Tensor:
    """Sample lengthscales from half-Cauchy shrinkage priors.

    Args:
        dim: Number of lengthscales to sample.
        alpha: Scale of the half-Cauchy prior on the global term
            ``kernel_tausq``.
        pyro: Object providing ``sample``, ``deterministic`` and
            ``distributions.HalfCauchy``.
        **tkwargs: Keyword arguments forwarded to ``torch.tensor`` and
            ``torch.ones``.

    Returns:
        The lengthscales, registered as the ``"lengthscale"`` site.
    """
    # One global shrinkage term shared by all dimensions.
    global_term = pyro.sample(
        "kernel_tausq",
        pyro.distributions.HalfCauchy(torch.tensor(alpha, **tkwargs)),
    )
    # One local term per dimension.
    local_terms = pyro.sample(
        "_kernel_inv_length_sq",
        pyro.distributions.HalfCauchy(torch.ones(dim, **tkwargs)),
    )
    inv_length_sq = pyro.deterministic("kernel_inv_length_sq",
                                       global_term * local_terms)
    # lengthscale = 1 / sqrt(inverse squared lengthscale)
    return pyro.deterministic(
        "lengthscale",
        (1.0 / inv_length_sq).sqrt(),  # pyre-ignore [16]
    )
Пример #15
0
 def sample_lengthscale(
     self, dim: int, alpha: float = 0.1, **tkwargs: Any
 ) -> Tensor:
     r"""Sample the lengthscale from half-Cauchy shrinkage priors.

     Args:
         dim: Number of lengthscales to sample.
         alpha: Scale of the half-Cauchy prior on ``kernel_tausq``.
         **tkwargs: Keyword arguments forwarded to ``torch.tensor`` and
             ``torch.ones``.

     Returns:
         The lengthscales, registered as the ``"lengthscale"`` site.
     """
     # One global shrinkage term shared by all dimensions.
     global_term = pyro.sample(
         "kernel_tausq",
         pyro.distributions.HalfCauchy(torch.tensor(alpha, **tkwargs)),
     )
     # One local term per dimension.
     local_terms = pyro.sample(
         "_kernel_inv_length_sq",
         pyro.distributions.HalfCauchy(torch.ones(dim, **tkwargs)),
     )
     inv_length_sq = pyro.deterministic(
         "kernel_inv_length_sq", global_term * local_terms
     )
     # lengthscale = 1 / sqrt(inverse squared lengthscale)
     return pyro.deterministic(
         "lengthscale",
         (1.0 / inv_length_sq).sqrt(),
     )
Пример #16
0
def spectral_matrix_gp(n_points, plate, suffix=''):
    """Sample spectra from a Gaussian process with a squared-exponential covariance.

    Args:
        n_points: Number of points along the spectral axis.
        plate: Context manager entered around all sites in this function.
        suffix: Normalised to ``'_<suffix>'`` below.
            NOTE(review): the normalised suffix is never used afterwards;
            the site names ``'cov'``, ``'gp'`` and ``'S'`` are fixed.
            Confirm whether they were meant to carry the suffix.

    Returns:
        The value of the ``'S'`` deterministic site.
    """
    if suffix != '':
        suffix = f'_{suffix}'

    with plate:
        # Divisor applied to the squared distance in the kernel exponent.
        l = 150
        # Column vector of point indices; pts - pts.T gives all pairwise
        # differences.
        pts = torch.arange(n_points, dtype=torch.float64).unsqueeze(-1)
        distance_squared = torch.pow(pts - pts.T, 2).unsqueeze(-1)
        # .T reverses the dimension order, so the trailing singleton axis
        # added above becomes the leading one.
        cov = pyro.deterministic('cov',
                                 torch.exp(-0.5 * distance_squared / l).T,
                                 event_dim=2).contiguous()
        diag_idx = np.diag_indices(n_points, ndim=1)
        # Add a small random value to the diagonal, in place.
        # NOTE(review): presumably jitter for numerical stability. It is
        # redrawn on every call, so the covariance is not deterministic --
        # confirm this is intended.
        cov[:, diag_idx, diag_idx] += torch.rand(1, 1, n_points) / 1000
        # NOTE(review): pts is float64 while loc uses the default dtype --
        # confirm the dtypes are compatible at runtime.
        gp = pyro.sample(
            'gp',
            dist.MultivariateNormal(loc=torch.tensor([0.] * n_points),
                                    covariance_matrix=cov).to_event(0))
        # softplus keeps the result non-negative; 50 and 20 are fixed
        # input and output scales.
        S = pyro.deterministic('S', F.softplus(gp * 50) * 20, event_dim=1)

    return S
Пример #17
0
def quantize(name, x_real, min, max, num_quant_bins=4):
    """Randomly round ``x_real`` to a nearby integer, preserving probability mass.

    Args:
        name: Name of the resulting deterministic site; the auxiliary
            categorical site is named ``"Q_" + name``.
        x_real: Real-valued tensor to quantize.
        min: Lower bound used to reflect the result.
        max: Upper bound used to reflect the result.
        num_quant_bins: Number of bins; ``1`` rounds deterministically.

    Returns:
        The quantized tensor, registered as the site ``name``.
    """
    assert _all(min < max)
    detached = x_real.detach()
    if num_quant_bins == 1:
        # Single bin: plain rounding, no auxiliary sample site.
        return pyro.deterministic(name, detached.round(), event_dim=0)

    floor = detached.floor()
    bin_probs = compute_bin_probs(x_real - floor,
                                  num_quant_bins=num_quant_bins)

    bin_index = pyro.sample("Q_" + name,
                            dist.Categorical(bin_probs),
                            infer={"enumerate": "parallel"})
    # Shift the bin index so it becomes an offset relative to the floor.
    offset = bin_index.type_as(x_real) - (num_quant_bins // 2 - 1)

    x = floor + offset
    # Reflect values that fall outside the bounds.
    x = torch.max(x, 2 * min - 1 - x)
    x = torch.min(x, 2 * max + 1 - x)

    return pyro.deterministic(name, x, event_dim=0)
Пример #18
0
def time_matrix(time, irf, t0, plate, suffix='', scattering=False):
    """Build the normalised, IRF-convolved decay matrix (or the scattering trace).

    Args:
        time: Time axis; ``time[-1]`` is used for the circular correction.
        irf: Instrument response passed to ``conv1d``.
        t0: Accepted but not used in this function.
        plate: Context manager entered around the ``tau`` sample site.
        suffix: Appended as ``'_<suffix>'`` to the suffixed site names.
        scattering: If true, only the ``'T_sc'`` site (``irf`` divided by
            its maximum) is registered and returned.

    Returns:
        The ``'T_sc'`` value when ``scattering`` is true, otherwise the
        ``T_scaled`` site value (each row divided by its maximum).
    """
    if suffix != '':
        suffix = f'_{suffix}'

    if scattering:
        return pyro.deterministic('T_sc', irf / irf.max())

    # From here on `scattering` is always false. The earlier code kept an
    # unreachable Uniform(1e-5, 5e-3) prior for the scattering case and
    # redundant `if not scattering` guards; both are removed.
    with plate:
        tau = pyro.sample(f'tau{suffix}', dist.Gamma(5, 10))[:, np.newaxis]

    T_unbound = pyro.deterministic(f'T_unbound{suffix}',
                                   torch.exp(-time / tau))
    T = pyro.deterministic(f'T{suffix}', T_unbound)

    T_convolved = conv1d(T, irf, mode='fft_circular', cut=False)

    # NOTE(review): presumably compensates for the circular convolution
    # wrap-around of the decay -- confirm. The site name carries no suffix.
    circular_multiplier = pyro.deterministic(
        'circular_multiplier', 1 / (1 - torch.exp(-time[-1] / tau)))
    T_convolved = pyro.deterministic(f'T_convolved{suffix}',
                                     T_convolved * circular_multiplier)

    # Normalise every row by its own maximum.
    row_max = T_convolved.max(dim=1)[0][:, np.newaxis]
    return pyro.deterministic(f'T_scaled{suffix}', T_convolved / row_max)
Пример #19
0
    def model_full(self, data=None, time=None, fix_time=False):
        """Full model: spectra times time traces plus background, Poisson counts.

        Args:
            data: Optional observed counts; sliced with ``self.time_slice``
                and ``self.wavelength_slice`` before being used as the
                observation of ``'I_obs'``.
            time: Time axis, shifted by ``t0`` for the IRF interpolation and
                passed to ``time_matrix``.
            fix_time: If true, ``t0`` is registered as a deterministic site
                at ``self.t0``; otherwise it is sampled around ``self.t0``.
        """
        components_plate = pyro.plate('components', self.n_components)

        if fix_time:
            t0 = pyro.deterministic('t0', self.t0)
        else:
            t0 = pyro.sample('t0', dist.Normal(loc=self.t0, scale=0.01))
        # Additive per-wavelength term.
        # NOTE(review): the Exponential rates come from torch.rand and are
        # redrawn on every call -- confirm this is intended.
        xi = pyro.sample(
            'xi',
            dist.Exponential(torch.rand(self.n_wavelenghs) / 10).to_event(1))

        # Interpolate the padded IRF onto the t0-shifted time axis.
        irf = Interp1d()(self.time_padded, self.irf_padded, time - t0, None)

        T_scaled = time_matrix(time, irf, t0, components_plate, suffix='')
        S = spectral_matrix_gp(self.n_wavelenghs, components_plate, suffix='')
        # Combine spectra and time traces into one matrix.
        ST = pyro.deterministic('ST', S.T @ T_scaled, event_dim=2)

        if self.scattering:
            # Optional scattering contribution, added to ST.
            scattering_plate = pyro.plate('scattering', 1)
            T_sc = time_matrix(time,
                               irf,
                               t0,
                               scattering_plate,
                               suffix='sc',
                               scattering=True)
            S_sc = spectral_matrix_unscaled(
                self.n_wavelenghs, scattering_plate, suffix='sc') * 20
            ST_sc = pyro.deterministic('ST_sc', S_sc.T @ T_sc, event_dim=2)
            ST = ST + ST_sc

        if data is not None:
            # NOTE(review): chained indexing applies the second index to the
            # result of the first. Whether this selects along a second axis
            # depends on what wavelength_slice holds -- confirm.
            data = data[self.time_slice][self.wavelength_slice]
        pyro.sample(
            'I_obs',
            dist.Poisson(
                (ST.T +
                 xi)[self.time_slice][self.wavelength_slice]).to_event(2),
            obs=data)
        # Unsliced expected intensity, recorded for inspection.
        pyro.deterministic('I', ST.T + xi, event_dim=2)
Пример #20
0
def spectral_matrix(n_points, plate, suffix=''):
    """Sample non-negative spectra and normalise each row to a maximum of one.

    Args:
        n_points: Number of points along the spectral axis.
        plate: Context manager entered around both sites.
        suffix: Appended as ``'_<suffix>'`` to the site names when non-empty.

    Returns:
        The normalised spectra, registered as the ``S<suffix>`` site.
    """
    suffix = f'_{suffix}' if suffix != '' else suffix

    with plate:
        # The Exponential rates are drawn uniformly at random on every call.
        prior = dist.Exponential(torch.rand(n_points)).to_event(1)
        S_unscaled = pyro.sample(f'S_unscaled{suffix}', prior)
        # Row-wise maximum, kept as a column so it broadcasts over each row.
        row_max = S_unscaled.max(dim=1)[0][:, np.newaxis]
        S = pyro.deterministic(f'S{suffix}',
                               S_unscaled / row_max,
                               event_dim=1)
    return S
Пример #21
0
    def forward(self, data):
        """Simulate the ODE and score ``data`` with Poisson likelihoods.

        Args:
            data: Sequence of observations; ``data[i]`` is the observation
                for the site ``obs_<i>``.

        Returns:
            The simulation output of ``self._ode_op``.
        """
        N_pop = 300

        # Flatten each parameter tensor to one dimension.
        p1, p2, p3 = (
            param.view((-1, ))
            for param in (self.ode_params1, self.ode_params2,
                          self.ode_params3)
        )
        R0 = pyro.deterministic('R0', torch.zeros_like(p1))
        ode_params = torch.stack([p1, p2, p3, 1 - p3, R0], dim=1)
        SIR_sim = self._ode_op.apply(ode_params, (self._ode_model, ))

        for i in range(len(data)):
            # Component 1 of the simulation at step i, scaled by N_pop,
            # is the Poisson rate.
            expected_count = SIR_sim[..., i, 1] * N_pop
            pyro.sample("obs_{}".format(i),
                        dist.Poisson(expected_count),
                        obs=data[i])
        return SIR_sim
Пример #22
0
    def model(self, raw_expr, encoded_expr, read_depth):
        """Topic model with a zero-inflated negative binomial count likelihood.

        Args:
            raw_expr: Observed counts used as the observation of ``'obs'``.
            encoded_expr: Tensor whose first dimension is the number of
                cells; also used to create the prior parameter tensors.
            read_depth: Multiplied with the decoded expression rate to get
                the mean count.
        """
        pyro.module("decoder", self.decoder)

        def on_device(value):
            # Scalar hyperparameter as a tensor on the model's device.
            return torch.tensor(value).to(self.device)

        with pyro.plate("genes", self.num_genes):
            dispersion = pyro.sample(
                "dispersion", dist.Gamma(on_device(2.), on_device(0.5)))
            psi = pyro.sample(
                "dropout", dist.Beta(on_device(1.), on_device(10.)))

        n_cells = encoded_expr.shape[0]
        with pyro.plate("cells", n_cells):
            # A log-normal prior stands in for the Dirichlet prior on the
            # topic proportions.
            ones = encoded_expr.new_ones((n_cells, self.num_topics))
            theta_prior = dist.LogNormal(self.prior_mu * ones,
                                         self.prior_std * ones).to_event(1)
            theta = pyro.sample("theta", theta_prior)
            # Normalise so the proportions sum to one per cell.
            theta = theta / theta.sum(-1, keepdim=True)

            expr_rate = pyro.deterministic("expr_rate", self.decoder(theta))

            mu = torch.multiply(read_depth, expr_rate)
            # Success probability, capped at self.max_prob.
            p = torch.minimum(mu / (mu + dispersion), self.max_prob)

            likelihood = dist.ZeroInflatedNegativeBinomial(
                total_count=dispersion, probs=p, gate=psi).to_event(1)
            pyro.sample('obs', likelihood, obs=raw_expr)
Пример #23
0
def pyrocov_model(dataset):
    """Multinomial regression of weekly strain counts on mutation features.

    Args:
        dataset: Mapping with ``"features"`` of shape [S, F],
            ``"local_time"`` of shape [T, P] and ``"weekly_strains"`` of
            shape [T, P, S].
    """
    # Tensor shapes are noted at the end of some lines.
    features = dataset["features"]
    local_time = dataset["local_time"][..., None]  # [T, P, 1]
    n_time, n_place, _ = local_time.shape
    n_strain, n_feature = features.shape
    weekly_strains = dataset["weekly_strains"]
    assert weekly_strains.shape == (n_time, n_place, n_strain)

    def sample_scale(site, prior):
        # Global scalar with a trailing axis added for broadcasting.
        return pyro.sample(site, prior)[..., None]

    # Global random variables.
    coef_scale = sample_scale("coef_scale", dist.InverseGamma(5e3, 1e2))
    rate_scale = sample_scale("rate_scale", dist.LogNormal(-4, 2))
    init_loc_scale = sample_scale("init_loc_scale", dist.LogNormal(0, 2))
    init_scale = sample_scale("init_scale", dist.LogNormal(0, 2))

    # Relative growth rate depends strongly on mutations, weakly on place.
    coef_prior = dist.Logistic(torch.zeros(n_feature), coef_scale).to_event(1)
    coef = pyro.sample("coef", coef_prior)  # [F]
    rate_loc = pyro.deterministic(
        "rate_loc", 0.01 * coef @ features.T, event_dim=1
    )  # [S]

    # Initial infections depend strongly on strain and place.
    init_loc_prior = dist.Normal(torch.zeros(n_strain), init_loc_scale)
    init_loc = pyro.sample("init_loc", init_loc_prior.to_event(1))  # [S]

    with pyro.plate("place", n_place, dim=-1):
        rate = pyro.sample(
            "rate", dist.Normal(rate_loc, rate_scale).to_event(1)
        )  # [P, S]
        init = pyro.sample(
            "init", dist.Normal(init_loc, init_scale).to_event(1)
        )  # [P, S]

        # Finally observe the counts.
        with pyro.plate("time", n_time, dim=-2):
            logits = init + rate * local_time  # [T, P, S]
            likelihood = dist.Multinomial(logits=logits, validate_args=False)
            pyro.sample("obs", likelihood, obs=weekly_strains)
Пример #24
0
    def model(data):
        """Fixture model mixing several kinds of sites and dependencies.

        Args:
            data: Observations for site ``"k"``; ``len(data)`` sizes the
                plate ``"p"``.

        Returns:
            The values of sites ``a`` to ``k`` as a list, in that order.
        """
        # Chain of dependent latent sites: a -> b -> c, and d depends on
        # both a and c.
        a = pyro.sample("a", dist.Normal(0, 1))
        b = pyro.sample("b", NonreparameterizedNormal(a, 0))
        c = pyro.sample("c", dist.Normal(b, 1))
        d = pyro.sample("d", dist.Normal(a, c.exp()))

        # Two independent latents feeding one observed Bernoulli site.
        e = pyro.sample("e", dist.Normal(0, 1))
        f = pyro.sample("f", dist.Normal(0, 1))
        g = pyro.sample("g",
                        dist.Bernoulli(logits=e + f),
                        obs=torch.tensor(0.0))

        with pyro.plate("p", len(data)):
            d_ = d.detach()  # this results in a known failure
            h = pyro.sample("h", dist.Normal(c, d_.exp()))
            # i is a deterministic function of h.
            i = pyro.deterministic("i", h + 1)
            # j is a Delta site observed at its own value.
            j = pyro.sample("j", dist.Delta(h + 1), obs=h + 1)
            k = pyro.sample("k", dist.Normal(a, j.exp()), obs=data)

        return [a, b, c, d, e, f, g, h, i, j, k]
Пример #25
0
    def forward(self):
        """Topic-by-gene model with a zero-inflated negative binomial likelihood.

        Per-gene parameters and topic-by-gene weights are sampled first. A
        subsample of 64 cells is then scored against ``self.gene_expr``.
        """
        with pyro.plate("gene_weights", self.G):
            b = pyro.sample("b", dist.Normal(-10., 3.))
            theta = pyro.sample("theta", dist.Gamma(2., 0.5))
            psi = pyro.sample("dropout", dist.Beta(1., 10.))

            with pyro.plate("topic-gene_weights", self.K):
                beta = pyro.sample("beta", dist.Gamma(1., 5.))

        with pyro.plate("gene", self.G), \
                pyro.plate("data", self.N, subsample_size=64) as ind:
            # Rows of the subsampled cells only.
            batch_topics = self.cell_topics.index_select(0, ind)
            expr_rate = pyro.deterministic(
                "rate", torch.matmul(batch_topics, beta) + b)

            batch_depth = torch.reshape(
                self.read_depth, (-1, 1)).index_select(0, ind)
            mu = batch_depth * torch.exp(expr_rate)
            # Success probability, capped just below one.
            p = torch.minimum(mu / (mu + theta), torch.tensor([0.99999]))

            batch_counts = self.gene_expr.index_select(0, ind)
            likelihood = dist.ZeroInflatedNegativeBinomial(
                total_count=theta, probs=p, gate=psi)
            pyro.sample("obs", likelihood, obs=batch_counts)
Пример #26
0
    def model(self, x_data, idx=None):
        """Hierarchical Gamma model of counts with a NegativeBinomial likelihood.

        Every sampled site is also stored on ``self`` under the attribute
        assigned below.

        Args:
            x_data: Observed counts used as the observation of
                ``'data_target'``.
            idx: Index applied to ``self.spot_factors`` and ``self.spot_add``
                when building the expected expression.
                NOTE(review): with the default ``None`` this indexing adds a
                leading axis rather than selecting rows -- confirm callers
                always pass ``idx``.
        """

        # =====================Gene expression level scaling======================= #
        # Explains difference in expression between genes and
        # how it differs in single cell and spatial technology
        # compute Gamma hyperparameters (shape, rate) from mean and variance:
        # shape = mean**2 / var, rate = mean / var
        gl_alpha_shape = self.gl_shape**2 / self.gl_shape_var
        gl_alpha_rate = self.gl_shape / self.gl_shape_var
        gl_beta_shape = self.gl_rate**2 / self.gl_rate_var
        gl_beta_rate = self.gl_rate / self.gl_rate_var

        self.gene_level_alpha_hyp = pyro.sample(
            'gene_level_alpha_hyp',
            dist.Gamma(
                torch.ones([1, 1]) * torch.tensor(gl_alpha_shape),
                torch.ones([1, 1]) * torch.tensor(gl_alpha_rate)))
        self.gene_level_beta_hyp = pyro.sample(
            'gene_level_beta_hyp',
            dist.Gamma(
                torch.ones([1, 1]) * torch.tensor(gl_beta_shape),
                torch.ones([1, 1]) * torch.tensor(gl_beta_rate)))

        # one level per gene, shape [n_var, 1]
        self.gene_level = pyro.sample(
            'gene_level',
            dist.Gamma(
                torch.ones([self.n_var, 1]) * self.gene_level_alpha_hyp,
                torch.ones([self.n_var, 1]) * self.gene_level_beta_hyp))

        # register the cell state signatures as 'gene_factors'; no scaling is
        # applied here, gene_level is multiplied in where mu_biol and nUMI
        # are computed below
        self.gene_factors = pyro.deterministic('gene_factors', self.cell_state)

        # =====================Spot factors======================= #
        # prior on spot factors reflects the number of cells, fraction of their cytoplasm captured,
        # times heterogeneity in the total number of mRNA between individual cells with each cell type
        # shape = mean**2 / (mean / ratio), rate = mean / (mean / ratio)
        cps_shape = self.cell_number_prior['cells_per_spot'] ** 2 \
                    / (self.cell_number_prior['cells_per_spot'] / self.cell_number_prior['cells_mean_var_ratio'])
        cps_rate = self.cell_number_prior['cells_per_spot'] \
                   / (self.cell_number_prior['cells_per_spot'] / self.cell_number_prior['cells_mean_var_ratio'])
        self.cells_per_spot = pyro.sample(
            'cells_per_spot',
            dist.Gamma(
                torch.ones([self.n_obs, 1]) * torch.tensor(cps_shape),
                torch.ones([self.n_obs, 1]) * torch.tensor(cps_rate)))

        fps_shape = self.cell_number_prior['factors_per_spot'] ** 2 \
                    / (self.cell_number_prior['factors_per_spot'] / self.cell_number_prior['factors_mean_var_ratio'])
        fps_rate = self.cell_number_prior['factors_per_spot'] \
                   / (self.cell_number_prior['factors_per_spot'] / self.cell_number_prior['factors_mean_var_ratio'])
        self.factors_per_spot = pyro.sample(
            'factors_per_spot',
            dist.Gamma(
                torch.ones([self.n_obs, 1]) * torch.tensor(fps_shape),
                torch.ones([self.n_obs, 1]) * torch.tensor(fps_rate)))

        # per-spot Gamma parameters; the matmul with a row of ones below
        # repeats each [n_obs, 1] column across the n_fact factors
        shape = self.factors_per_spot / torch.tensor(
            np.array(self.n_fact).reshape((1, 1)))
        rate = torch.ones([1, 1]) / self.cells_per_spot * self.factors_per_spot
        self.spot_factors = pyro.sample(
            'spot_factors',
            dist.Gamma(torch.matmul(shape, torch.ones([1, self.n_fact])),
                       torch.matmul(rate, torch.ones([1, self.n_fact]))))

        # =====================Spot-specific additive component======================= #
        # molecule contribution that cannot be explained by cell state signatures
        # these counts are distributed between all genes not just expressed genes
        # row 0 of the hyperprior is used as the Gamma shape, row 1 as the rate
        self.spot_add_hyp = pyro.sample(
            'spot_add_hyp',
            dist.Gamma(
                torch.ones([2, 1]) * torch.tensor(1.),
                torch.ones([2, 1]) * torch.tensor(0.1)))
        self.spot_add = pyro.sample(
            'spot_add',
            dist.Gamma(
                torch.ones([self.n_obs, 1]) * self.spot_add_hyp[0, 0],
                torch.ones([self.n_obs, 1]) * self.spot_add_hyp[1, 0]))

        # =====================Gene-specific additive component ======================= #
        # per gene molecule contribution that cannot be explained by cell state signatures
        # these counts are distributed equally between all spots (e.g. background, free-floating RNA)
        # row 0 of the hyperprior is used as the Gamma shape, row 1 as the rate
        self.gene_add_hyp = pyro.sample(
            'gene_add_hyp',
            dist.Gamma(
                torch.ones([2, 1]) * torch.tensor(1.),
                torch.ones([2, 1]) * torch.tensor(1.)))
        self.gene_add = pyro.sample(
            'gene_add',
            dist.Gamma(
                torch.ones([self.n_var, 1]) * self.gene_add_hyp[0, 0],
                torch.ones([self.n_var, 1]) * self.gene_add_hyp[1, 0]))

        # =====================Gene-specific overdispersion ======================= #
        # NOTE(review): the 'mean' and 'sd' entries are passed directly as the
        # two Gamma parameters, unlike the mean/variance conversion used
        # above -- confirm this is intended.
        self.phi_hyp = pyro.sample(
            'phi_hyp',
            dist.Gamma(
                torch.ones([1, 1]) * torch.tensor(self.phi_hyp_prior['mean']),
                torch.ones([1, 1]) * torch.tensor(self.phi_hyp_prior['sd'])))
        self.gene_E = pyro.sample(
            'gene_E',
            dist.Exponential(torch.ones([self.n_var, 1]) * self.phi_hyp[0, 0]))

        # =====================Expected expression ======================= #
        # expected expression: factor contribution scaled by gene level,
        # plus the gene-specific and spot-specific additive terms
        self.mu_biol = torch.matmul(self.spot_factors[idx], self.gene_factors.T) * self.gene_level.T \
                       + self.gene_add.T + self.spot_add[idx]

        # =====================DATA likelihood ======================= #
        # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial
        # theta is 1 / gene_E**2, one value per gene
        self.data_target = pyro.sample(
            'data_target',
            NegativeBinomial(mu=self.mu_biol,
                             theta=torch.ones([1, 1]) /
                             (self.gene_E.T * self.gene_E.T)),
            obs=x_data)

        # =====================Compute nUMI from each factor in spots  ======================= #
        # spot factors times the per-factor sum over genes of
        # (gene_factors * gene_level)
        nUMI = (self.spot_factors *
                (self.gene_factors * self.gene_level).sum(0))
        self.nUMI_factors = pyro.deterministic('nUMI_factors', nUMI)
Пример #27
0
def model(
    s,
    m,
    y=None,
    gamma_hyper=1.0,
    pi0=1.0,
    rho0=1.0,
    epsilon0=0.01,
    alpha0=1000.0,
    dtype=torch.float32,
    device="cpu",
):
    """Pyro model mixing ``s`` strains to explain per-sample, per-position counts.

    Sites sampled, in order: ``gamma`` (strain x position, RelaxedBernoulli),
    ``rho_hyper``, ``rho`` (relaxed one-hot over strains), ``epsilon_hyper``,
    ``alpha_hyper``, ``pi_hyper``, then per sample ``pi``, ``alpha`` and
    ``epsilon``, and finally the BetaBinomial observation site ``y``.

    Args:
        s: Number of strains (size of the "strain" plate and of the
            ``rho`` / ``pi`` simplex).
        m: Array-like of shape ``(n, g)``; used as ``total_count`` of the
            BetaBinomial likelihood.
        y: Optional observations for the ``y`` site. When ``None`` the site
            is sampled instead of conditioned on.
        gamma_hyper: Temperature of the RelaxedBernoulli prior on ``gamma``.
        pi0: Concentration of the Gamma prior on ``pi_hyper``.
        rho0: Concentration of the Gamma prior on ``rho_hyper``.
        epsilon0: ``epsilon_hyper`` is drawn from ``Beta(1, 1 / epsilon0)``.
        alpha0: Concentration of the Gamma prior on ``alpha_hyper``.
        dtype: Torch dtype that every numeric input is cast to.
        device: Torch device that every numeric input is moved to.

    Returns:
        The value of the ``y`` site (the cast observations when ``y`` is
        given, a fresh sample otherwise).
    """

    # Cast inputs and set device
    m, gamma_hyper, pi0, rho0, epsilon0, alpha0 = [
        torch.tensor(v, dtype=dtype, device=device)
        for v in [m, gamma_hyper, pi0, rho0, epsilon0, alpha0]
    ]
    if y is not None:
        # Observations must share dtype/device with ``m`` and the likelihood
        # parameters; a bare ``torch.tensor(y)`` left them on the CPU and
        # broke the log-prob computation on any other device.
        y = torch.tensor(y, dtype=dtype, device=device)

    n, g = m.shape

    with pyro.plate("position", g, dim=-1):
        with pyro.plate("strain", s, dim=-2):
            gamma = pyro.sample(
                "gamma",
                dist.RelaxedBernoulli(temperature=gamma_hyper, logits=0.0),
            )
    # gamma.shape == (s, g)

    rho_hyper = pyro.sample("rho_hyper", dist.Gamma(rho0, 1.0))
    rho = pyro.sample(
        "rho",
        dist.RelaxedOneHotCategorical(
            temperature=rho_hyper,
            logits=torch.zeros(s, dtype=dtype, device=device),
        ),
    )

    epsilon_hyper = pyro.sample("epsilon_hyper", dist.Beta(1.0, 1 / epsilon0))
    alpha_hyper = pyro.sample("alpha_hyper", dist.Gamma(alpha0, 1.0))
    pi_hyper = pyro.sample("pi_hyper", dist.Gamma(pi0, 1.0))

    with pyro.plate("sample", n, dim=-1):
        pi = pyro.sample(
            "pi",
            dist.RelaxedOneHotCategorical(temperature=pi_hyper, probs=rho),
        )
        # A trailing singleton axis lets alpha/epsilon broadcast against the
        # (n, g) probability matrix below.
        alpha = pyro.sample("alpha", dist.Gamma(alpha_hyper, 1.0)).unsqueeze(
            -1
        )
        epsilon = pyro.sample(
            "epsilon", dist.Beta(1.0, 1 / epsilon_hyper)
        ).unsqueeze(-1)
    # pi.shape == (n, s)
    # alpha.shape == epsilon.shape == (n, 1)

    p_noerr = pyro.deterministic("p_noerr", pi @ gamma)
    # With probability epsilon / 2 the error-free probability is flipped.
    p = pyro.deterministic(
        "p", (1 - epsilon / 2) * (p_noerr) + (epsilon / 2) * (1 - p_noerr)
    )
    # p.shape == (n, g)

    y = pyro.sample(
        "y",
        dist.BetaBinomial(
            concentration1=alpha * p,
            concentration0=alpha * (1 - p),
            total_count=m,
        ),
        obs=y,
    )
    # y.shape == (n, g)
    return y
Пример #28
0
    def __call__(self):
        """Run the Pyro model once and return the derived quantities.

        Samples the level knots and the regression-coefficient knots, maps
        them onto the observation axis by multiplying with the transposed
        ``k_lev`` / ``k_coef`` matrices, optionally conditions on user-supplied
        coefficient priors, and scores the de-meaned, de-seasonalised response
        with a Student-t likelihood.

        Returns
        -------
        dict
            Keys ``yhat`` (with ``seas_term`` and ``mean_y`` added back),
            ``lev``, ``lev_knot``, ``coef``, ``coef_knot``,
            ``coef_init_knot`` and ``obs_scale``.

        Notes
        -----
        Labeling system:
        1. for kernel level of parameters such as rho, span, nknots, kernel etc.,
        use suffix _lev and _coef for levels and regression to partition
        2. for knots level of parameters such as coef, loc and scale priors,
        use prefix _lev and _rr _pr for levels, regular and positive regressors to partition
        3. reduce ambiguity by replacing all Greek letters with more intuitive labels:
        use _coef, _weight etc. instead of _beta, use _scale instead of _sigma
        """

        response = self.response
        which_valid = self.which_valid_res

        n_obs = self.n_obs
        # n_valid = self.n_valid_res
        sdy = self.sdy
        meany = self.mean_y
        dof = self.dof
        lev_knot_loc = self.lev_knot_loc
        seas_term = self.seas_term

        pr = self.pr
        rr = self.rr
        n_pr = self.n_pr
        n_rr = self.n_rr

        k_lev = self.k_lev
        k_coef = self.k_coef
        n_knots_lev = self.n_knots_lev
        n_knots_coef = self.n_knots_coef

        lev_knot_scale = self.lev_knot_scale
        # mult var norm stuff
        mvn = self.mvn
        geometric_walk = self.geometric_walk
        # clamp the residual-scale floor into [0, 1]; it is used below as a
        # fraction of sdy
        min_residuals_sd = self.min_residuals_sd
        if min_residuals_sd > 1.0:
            min_residuals_sd = torch.tensor(1.0)
        if min_residuals_sd < 0:
            min_residuals_sd = torch.tensor(0.0)
        # expand dim to n_rr x n_knots_coef
        rr_init_knot_loc = self.rr_init_knot_loc
        rr_init_knot_scale = self.rr_init_knot_scale
        rr_knot_scale = self.rr_knot_scale

        # this does not need to expand dim since it is used as latent grand mean
        pr_init_knot_loc = self.pr_init_knot_loc
        pr_init_knot_scale = self.pr_init_knot_scale
        pr_knot_scale = self.pr_knot_scale

        # transformation of data
        # regular regressors come first, positive regressors second; the same
        # order is used when the coefficients are concatenated further down
        regressors = torch.zeros(n_obs)
        if n_pr > 0 and n_rr > 0:
            regressors = torch.cat([rr, pr], dim=-1)
        elif n_pr > 0:
            regressors = pr
        elif n_rr > 0:
            regressors = rr

        response_tran = response - meany - seas_term

        # sampling begins here
        extra_out = {}

        # levels sampling
        # the knots are sampled on the de-meaned scale; meany is added back
        # before they are returned
        lev_knot_tran = pyro.sample(
            "lev_knot_tran",
            dist.Normal(lev_knot_loc - meany,
                        lev_knot_scale).expand([n_knots_lev]).to_event(1))
        lev = (lev_knot_tran @ k_lev.transpose(-2, -1))

        # using hierarchical priors vs. multivariate priors
        if mvn == 0:
            # regular regressor sampling
            if n_rr > 0:
                # pooling latent variables
                rr_init_knot = pyro.sample(
                    "rr_init_knot",
                    dist.Normal(rr_init_knot_loc,
                                rr_init_knot_scale).to_event(1))
                # every knot of a regressor is centred on that regressor's
                # pooled initial knot
                rr_knot = pyro.sample(
                    "rr_knot",
                    dist.Normal(
                        rr_init_knot.unsqueeze(-1) *
                        torch.ones(n_rr, n_knots_coef),
                        rr_knot_scale).to_event(2))
                rr_coef = (rr_knot @ k_coef.transpose(-2, -1)).transpose(
                    -2, -1)

            # positive regressor sampling
            if n_pr > 0:
                if geometric_walk:
                    # TODO: development method
                    pr_init_knot = pyro.sample(
                        "pr_init_knot",
                        dist.FoldedDistribution(
                            dist.Normal(pr_init_knot_loc,
                                        pr_init_knot_scale)).to_event(1))
                    pr_knot_step = pyro.sample(
                        "pr_knot_step",
                        # note that unlike rr_knot, the first one is ignored as we use the initial scale
                        # to sample the first knot
                        dist.Normal(torch.zeros(n_pr, n_knots_coef),
                                    pr_knot_scale).to_event(2))
                    # multiplicative walk: initial knot times exp of the
                    # cumulative steps, which keeps the knots non-negative
                    pr_knot = pr_init_knot.unsqueeze(-1) * pr_knot_step.cumsum(
                        -1).exp()
                    pr_coef = (pr_knot @ k_coef.transpose(-2, -1)).transpose(
                        -2, -1)
                else:
                    # TODO: original method
                    # pooling latent variables
                    # NOTE(review): this site is registered as "pr_knot_loc"
                    # while the other branches name it "pr_init_knot" -
                    # confirm downstream consumers expect this name
                    pr_init_knot = pyro.sample(
                        "pr_knot_loc",
                        dist.FoldedDistribution(
                            dist.Normal(pr_init_knot_loc,
                                        pr_init_knot_scale)).to_event(1))

                    pr_knot = pyro.sample(
                        "pr_knot",
                        dist.FoldedDistribution(
                            dist.Normal(
                                pr_init_knot.unsqueeze(-1) *
                                torch.ones(n_pr, n_knots_coef),
                                pr_knot_scale)).to_event(2))
                    pr_coef = (pr_knot @ k_coef.transpose(-2, -1)).transpose(
                        -2, -1)
        else:
            # regular regressor sampling
            if n_rr > 0:
                # zero-valued deterministic site; only reported through
                # coef_init_knot, it does not feed the knot prior below
                rr_init_knot = pyro.deterministic(
                    "rr_init_knot", torch.zeros(rr_init_knot_loc.shape))

                # updated mod
                loc_temp = rr_init_knot_loc.unsqueeze(-1) * torch.ones(
                    n_rr, n_knots_coef)
                # NOTE(review): the *_scale values are placed on the diagonal
                # of covariance_matrix without squaring - confirm intended
                scale_temp = torch.diag_embed(
                    rr_init_knot_scale.unsqueeze(-1) *
                    torch.ones(n_rr, n_knots_coef))

                # the sampling
                rr_knot = pyro.sample(
                    "rr_knot",
                    dist.MultivariateNormal(
                        loc=loc_temp,
                        covariance_matrix=scale_temp).to_event(1))
                rr_coef = (rr_knot @ k_coef.transpose(-2, -1)).transpose(
                    -2, -1)

            # positive regressor sampling
            if n_pr > 0:
                # this part is junk just so that the pr_init_knot has a prior; but it does not connect to anything else
                # pooling latent variables
                pr_init_knot = pyro.sample(
                    "pr_init_knot",
                    dist.FoldedDistribution(
                        dist.Normal(pr_init_knot_loc,
                                    pr_init_knot_scale)).to_event(1))
                # updated mod
                loc_temp = pr_init_knot_loc.unsqueeze(-1) * torch.ones(
                    n_pr, n_knots_coef)
                scale_temp = torch.diag_embed(
                    pr_init_knot_scale.unsqueeze(-1) *
                    torch.ones(n_pr, n_knots_coef))

                pr_knot = pyro.sample(
                    "pr_knot",
                    dist.MultivariateNormal(
                        loc=loc_temp,
                        covariance_matrix=scale_temp).to_event(1))
                # exponentiate so the positive-regressor knots are > 0
                pr_knot = torch.exp(pr_knot)
                pr_coef = (pr_knot @ k_coef.transpose(-2, -1)).transpose(
                    -2, -1)

        # concatenating all latent variables
        # zero placeholders are kept when there are no regressors at all
        coef_init_knot = torch.zeros(n_rr + n_pr)
        coef_knot = torch.zeros((n_rr + n_pr, n_knots_coef))

        coef = torch.zeros(n_obs)
        if n_pr > 0 and n_rr > 0:
            coef_knot = torch.cat([rr_knot, pr_knot], dim=-2)
            coef_init_knot = torch.cat([rr_init_knot, pr_init_knot], dim=-1)
            coef = torch.cat([rr_coef, pr_coef], dim=-1)
        elif n_pr > 0:
            coef_knot = pr_knot
            coef_init_knot = pr_init_knot
            coef = pr_coef
        elif n_rr > 0:
            coef_knot = rr_knot
            coef_init_knot = rr_init_knot
            coef = rr_coef

        # coefficients likelihood/priors
        # each entry conditions a time-slice of selected coefficient columns
        # on a Normal(prior_mean, prior_sd) via an observed site
        coef_prior_list = self.coef_prior_list
        if coef_prior_list:
            for x in coef_prior_list:
                name = x['name']
                # TODO: we can move torch conversion to init to enhance speed
                m = torch.tensor(x['prior_mean'])
                sd = torch.tensor(x['prior_sd'])
                # tp = torch.tensor(x['prior_tp_idx'])
                # idx = torch.tensor(x['prior_regressor_col_idx'])
                start_tp_idx = x['prior_start_tp_idx']
                end_tp_idx = x['prior_end_tp_idx']
                idx = x['prior_regressor_col_idx']
                pyro.sample("prior_{}".format(name),
                            dist.Normal(m, sd).to_event(2),
                            obs=coef[..., start_tp_idx:end_tp_idx, idx])

        # observation likelihood
        yhat = lev + (regressors * coef).sum(-1)
        obs_scale_base = pyro.sample("obs_scale_base",
                                     dist.Beta(2, 2)).unsqueeze(-1)
        # rescales the Beta draw in (0, 1) to (min_residuals_sd * sdy, sdy)
        obs_scale = ((obs_scale_base *
                      (1.0 - min_residuals_sd)) + min_residuals_sd) * sdy

        # with pyro.plate("response_plate", n_valid):
        #     pyro.sample("response",
        #                 dist.StudentT(dof, yhat[..., which_valid], obs_scale),
        #                 obs=response_tran[which_valid])

        # only the entries selected by which_valid enter the likelihood
        pyro.sample("response",
                    dist.StudentT(dof, yhat[..., which_valid],
                                  obs_scale).to_event(1),
                    obs=response_tran[which_valid])

        # back to the original (non de-meaned) scale
        lev_knot = lev_knot_tran + meany

        extra_out.update({
            'yhat': yhat + seas_term + meany,
            'lev': lev + meany,
            'lev_knot': lev_knot,
            'coef': coef,
            'coef_knot': coef_knot,
            'coef_init_knot': coef_init_knot,
            'obs_scale': obs_scale,
        })
        return extra_out
    def forward(self, x_data, idx, batch_index):
        """Pyro model: register every sample site and score ``x_data``.

        Builds cell abundances ``w_sf`` from a hierarchical Gamma prior
        (cells per location -> group loadings -> factors), a per-location
        detection efficiency, a per-batch additive gene component and a
        gene-specific overdispersion, then conditions a GammaPoisson
        likelihood on ``x_data``. Nothing is returned; the per-factor mRNA
        totals are recorded as the deterministic site ``u_sf_mRNA_factors``.

        :param x_data: observed counts, used as ``obs`` of ``data_target``
        :param idx: only forwarded to ``self.create_plates``
        :param batch_index: batch labels, one-hot encoded over ``self.n_batch``
        """
        batch_onehot = one_hot(batch_index, self.n_batch)
        location_plate = self.create_plates(x_data, idx, batch_index)

        # ---------------- Cell abundances w_sf ---------------- #
        # The factorised prior on w_sf models similarity of locations across
        # cell types and carries the absolute scale of w_sf.
        with location_plate:
            cells_per_location = pyro.sample(
                "n_s_cells_per_location",
                dist.Gamma(
                    self.N_cells_per_location * self.N_cells_mean_var_ratio,
                    self.N_cells_mean_var_ratio,
                ),
            )
            groups_per_location = pyro.sample(
                "y_s_groups_per_location",
                dist.Gamma(self.Y_groups_per_location, self.ones),
            )

        # Gamma parameters of the cell-group loadings
        group_concentration = (self.ones_1_n_groups * groups_per_location /
                               self.n_groups_tensor)
        group_rate = self.ones_1_n_groups / (cells_per_location /
                                             groups_per_location)
        with location_plate:
            # (n_obs, n_groups)
            group_loadings = pyro.sample(
                "z_sr_groups_factors",
                dist.Gamma(group_concentration, group_rate),
            )

        # (self.n_groups, 1)
        factors_per_group_prior = dist.Gamma(self.factors_per_groups,
                                             self.ones)
        factors_per_group = pyro.sample(
            "k_r_factors_per_groups",
            factors_per_group_prior.expand([self.n_groups, 1]).to_event(2),
        )

        # (self.n_groups, self.n_factors)
        group2fact_prior = dist.Gamma(
            factors_per_group / self.n_factors_tensor, factors_per_group)
        group_to_factor = pyro.sample(
            "x_fr_group2fact",
            group2fact_prior.expand([self.n_groups,
                                     self.n_factors]).to_event(2),
        )

        with location_plate:
            abundance_mean = group_loadings @ group_to_factor
            # (self.n_obs, self.n_factors)
            w_sf = pyro.sample(
                "w_sf",
                dist.Gamma(
                    abundance_mean * self.w_sf_mean_var_ratio_tensor,
                    self.w_sf_mean_var_ratio_tensor,
                ),
            )

        # ---------------- Location-specific detection efficiency ---------------- #
        # y_s with a hierarchical prior on its mean
        detection_mean_prior = dist.Gamma(
            self.ones * self.detection_mean_hyp_prior_alpha,
            self.ones * self.detection_mean_hyp_prior_beta,
        )
        detection_mean = pyro.sample(
            "detection_mean_y_e",
            detection_mean_prior.expand([self.n_batch, 1]).to_event(2),
        )
        detection_alpha = pyro.deterministic(
            "detection_hyp_prior_alpha",
            self.ones_n_batch_1 * self.detection_hyp_prior_alpha,
        )

        # pick each observation's batch-level parameters via the one-hot matrix
        detection_alpha_obs = batch_onehot @ detection_alpha
        detection_rate_obs = detection_alpha_obs / (batch_onehot @
                                                    detection_mean)
        with location_plate:
            # (self.n_obs, 1)
            detection_y_s = pyro.sample(
                "detection_y_s",
                dist.Gamma(detection_alpha_obs, detection_rate_obs),
            )

        # ---------------- Gene-specific additive component ---------------- #
        # Per-gene contribution not explained by cell state signatures
        # (e.g. background, free-floating RNA).
        gene_add_alpha_hyp = pyro.sample(
            "s_g_gene_add_alpha_hyp",
            dist.Gamma(self.gene_add_alpha_hyp_prior_alpha,
                       self.gene_add_alpha_hyp_prior_beta),
        )
        gene_add_mean_prior = dist.Gamma(
            self.gene_add_mean_hyp_prior_alpha,
            self.gene_add_mean_hyp_prior_beta,
        )
        gene_add_mean = pyro.sample(
            "s_g_gene_add_mean",
            gene_add_mean_prior.expand([self.n_batch, 1]).to_event(2),
        )
        gene_add_alpha_inv = pyro.sample(
            "s_g_gene_add_alpha_e_inv",
            dist.Exponential(gene_add_alpha_hyp).expand(
                [self.n_batch, 1]).to_event(2),
        )
        gene_add_alpha = self.ones / gene_add_alpha_inv.pow(2)

        # (self.n_batch, n_vars)
        gene_add_prior = dist.Gamma(gene_add_alpha,
                                    gene_add_alpha / gene_add_mean)
        gene_add = pyro.sample(
            "s_g_gene_add",
            gene_add_prior.expand([self.n_batch, self.n_vars]).to_event(2),
        )

        # ---------------- Gene-specific overdispersion ---------------- #
        phi_hyp = pyro.sample(
            "alpha_g_phi_hyp",
            dist.Gamma(self.alpha_g_phi_hyp_prior_alpha,
                       self.alpha_g_phi_hyp_prior_beta),
        )
        # (self.n_batch, self.n_vars)
        alpha_inverse = pyro.sample(
            "alpha_g_inverse",
            dist.Exponential(phi_hyp).expand(
                [self.n_batch, self.n_vars]).to_event(2),
        )

        # ---------------- Expected expression ---------------- #
        signature_counts = w_sf @ self.cell_state
        additive_counts = batch_onehot @ gene_add
        mu = (signature_counts + additive_counts) * detection_y_s
        alpha = batch_onehot @ (self.ones / alpha_inverse.pow(2))

        # ---------------- Data likelihood ---------------- #
        # Overdispersed counts via the Gamma-Poisson parameterisation
        # (concentration, rate) rather than NegativeBinomial(total_count, logits).
        likelihood = dist.GammaPoisson(concentration=alpha, rate=alpha / mu)
        with location_plate:
            pyro.sample("data_target", likelihood, obs=x_data)

        # ---------------- mRNA count from each factor in locations ---------------- #
        with location_plate:
            mrna_per_factor = w_sf * (self.cell_state).sum(-1)
            pyro.deterministic("u_sf_mRNA_factors", mrna_per_factor)
Пример #30
0
    def __call__(self):
        """Run the Pyro model once and return the smoothed state sequences.

        Samples (or fixes) the smoothing parameters, the residual parameters,
        the three groups of regression coefficients and the trend parameters,
        runs the level / slope / seasonality recursion over the response, and
        scores observations 1..num_of_obs-1 with a Student-t likelihood
        centred on the one-step-ahead prediction.

        Returns a dict with keys ``beta``, ``b``, ``l``, ``s`` and
        ``lgt_sum``, plus ``lev_sm`` / ``slp_sm`` / ``sea_sm`` for every
        smoothing parameter that was fixed by the caller rather than sampled.
        """
        response = self.response
        num_of_obs = self.num_of_obs
        extra_out = {}

        # smoothing params
        # a negative *_sm_input means "sample it"; otherwise the supplied
        # value is used as a constant and reported in extra_out
        if self.lev_sm_input < 0:
            lev_sm = pyro.sample("lev_sm", dist.Uniform(0, 1))
        else:
            lev_sm = torch.tensor(self.lev_sm_input, dtype=torch.double)
            extra_out['lev_sm'] = lev_sm
        if self.slp_sm_input < 0:
            slp_sm = pyro.sample("slp_sm", dist.Uniform(0, 1))
        else:
            slp_sm = torch.tensor(self.slp_sm_input, dtype=torch.double)
            extra_out['slp_sm'] = slp_sm

        # residual tuning parameters
        nu = pyro.sample("nu", dist.Uniform(self.min_nu, self.max_nu))

        # prior for residuals
        obs_sigma = pyro.sample("obs_sigma", dist.HalfCauchy(self.cauchy_sd))

        # regression parameters
        # reg_penalty_type: 0 = fixed-scale ridge, 1 = lasso, 2 = auto-scale ridge
        # NOTE(review): any other value leaves *_sigma unassigned and the
        # Normal prior below raises UnboundLocalError - confirm the value is
        # validated by the caller
        if self.num_of_pr == 0:
            pr = torch.zeros(num_of_obs)
            pr_beta = pyro.deterministic("pr_beta", torch.zeros(0))
        else:
            with pyro.plate("pr", self.num_of_pr):
                # fixed scale ridge
                if self.reg_penalty_type == 0:
                    pr_sigma = self.pr_sigma_prior
                # auto scale ridge
                elif self.reg_penalty_type == 2:
                    # weak prior for sigma
                    pr_sigma = pyro.sample(
                        "pr_sigma", dist.HalfCauchy(self.auto_ridge_scale))
                # case when it is not lasso
                if self.reg_penalty_type != 1:
                    # weak prior for betas
                    pr_beta = pyro.sample(
                        "pr_beta",
                        dist.FoldedDistribution(
                            dist.Normal(self.pr_beta_prior, pr_sigma)))
                else:
                    pr_beta = pyro.sample(
                        "pr_beta",
                        dist.FoldedDistribution(
                            dist.Laplace(self.pr_beta_prior,
                                         self.lasso_scale)))
            pr = pr_beta @ self.pr_mat.transpose(-1, -2)

        # NOTE(review): nr_beta uses the same folded (non-negative) prior as
        # pr_beta; if "nr" stands for negative regressors, confirm the sign
        # is applied elsewhere (e.g. in nr_mat)
        if self.num_of_nr == 0:
            nr = torch.zeros(num_of_obs)
            nr_beta = pyro.deterministic("nr_beta", torch.zeros(0))
        else:
            with pyro.plate("nr", self.num_of_nr):
                # fixed scale ridge
                if self.reg_penalty_type == 0:
                    nr_sigma = self.nr_sigma_prior
                # auto scale ridge
                elif self.reg_penalty_type == 2:
                    # weak prior for sigma
                    nr_sigma = pyro.sample(
                        "nr_sigma", dist.HalfCauchy(self.auto_ridge_scale))
                # case when it is not lasso
                if self.reg_penalty_type != 1:
                    # weak prior for betas
                    nr_beta = pyro.sample(
                        "nr_beta",
                        dist.FoldedDistribution(
                            dist.Normal(self.nr_beta_prior, nr_sigma)))
                else:
                    nr_beta = pyro.sample(
                        "nr_beta",
                        dist.FoldedDistribution(
                            dist.Laplace(self.nr_beta_prior,
                                         self.lasso_scale)))
            nr = nr_beta @ self.nr_mat.transpose(-1, -2)

        # regular regressors: unconstrained sign, so no folding
        if self.num_of_rr == 0:
            rr = torch.zeros(num_of_obs)
            rr_beta = pyro.deterministic("rr_beta", torch.zeros(0))
        else:
            with pyro.plate("rr", self.num_of_rr):
                # fixed scale ridge
                if self.reg_penalty_type == 0:
                    rr_sigma = self.rr_sigma_prior
                # auto scale ridge
                elif self.reg_penalty_type == 2:
                    # weak prior for sigma
                    rr_sigma = pyro.sample(
                        "rr_sigma", dist.HalfCauchy(self.auto_ridge_scale))
                # case when it is not lasso
                if self.reg_penalty_type != 1:
                    # weak prior for betas
                    rr_beta = pyro.sample(
                        "rr_beta", dist.Normal(self.rr_beta_prior, rr_sigma))
                else:
                    rr_beta = pyro.sample(
                        "rr_beta",
                        dist.Laplace(self.rr_beta_prior, self.lasso_scale))
            rr = rr_beta @ self.rr_mat.transpose(-1, -2)

        # a hack to make sure we don't use a dimension "1" due to rr_beta and pr_beta sampling
        r = pr + nr + rr
        if r.dim() > 1:
            r = r.unsqueeze(-2)

        # trend parameters
        # local trend proportion
        lt_coef = pyro.sample("lt_coef", dist.Uniform(0, 1))
        # global trend proportion
        gt_coef = pyro.sample("gt_coef", dist.Uniform(-0.5, 0.5))
        # global trend parameter
        gt_pow = pyro.sample("gt_pow", dist.Uniform(0, 1))

        # seasonal parameters
        if self.is_seasonal:
            # seasonality smoothing parameter
            if self.sea_sm_input < 0:
                sea_sm = pyro.sample("sea_sm", dist.Uniform(0, 1))
            else:
                sea_sm = torch.tensor(self.sea_sm_input, dtype=torch.double)
                extra_out['sea_sm'] = sea_sm

            # initial seasonality
            # 33% lift is with 1 sd prob.
            init_sea = pyro.sample(
                "init_sea",
                dist.Normal(0, 0.33).expand([self.seasonality]).to_event(1))
            # centre the initial seasonal effects so they sum to zero
            init_sea = init_sea - init_sea.mean(-1, keepdim=True)

        b = [None] * num_of_obs  # slope
        l = [None] * num_of_obs  # level
        if self.is_seasonal:
            # s holds one full extra period: the first `seasonality` slots are
            # the initial effects, later slots are filled by the recursion
            s = [None] * (self.num_of_obs + self.seasonality)
            for t in range(self.seasonality):
                s[t] = init_sea[..., t]
            # the recursion below starts at t = 1, so slot `seasonality`
            # (t = 0 shifted by one period) is seeded from the initial effect
            s[self.seasonality] = init_sea[..., 0]
        else:
            s = [torch.tensor(0.)] * num_of_obs

        # states initial condition
        b[0] = torch.zeros_like(slp_sm)
        if self.is_seasonal:
            l[0] = response[0] - r[..., 0] - s[0]
        else:
            l[0] = response[0] - r[..., 0]

        # update process
        for t in range(1, num_of_obs):
            # this update equation with l[t-1] ONLY.
            # intentionally different from the Holt-Winter form
            # this change is suggested from Slawek's original SLGT model
            l[t] = lev_sm * (response[t] - s[t] -
                             r[..., t]) + (1 - lev_sm) * l[t - 1]
            b[t] = slp_sm * (l[t] - l[t - 1]) + (1 - slp_sm) * b[t - 1]
            if self.is_seasonal:
                s[t + self.seasonality] = \
                    sea_sm * (response[t] - l[t] - r[..., t]) + (1 - sea_sm) * s[t]

        # evaluation process
        # vectorize as much math as possible
        for lst in [b, l, s]:
            # torch.stack requires all items to have the same shape, but the
            # initial items of our lists may not have batch_shape, so we expand.
            lst[0] = lst[0].expand_as(lst[-1])
        b = torch.stack(b, dim=-1).reshape(b[0].shape[:-1] + (-1, ))
        l = torch.stack(l, dim=-1).reshape(l[0].shape[:-1] + (-1, ))
        s = torch.stack(s, dim=-1).reshape(s[0].shape[:-1] + (-1, ))

        # level + global trend + local trend, then lagged by one step so that
        # position t holds the prediction made from the state at t - 1
        lgt_sum = l + gt_coef * l.abs()**gt_pow + lt_coef * b
        lgt_sum = torch.cat([l[..., :1], lgt_sum[..., :-1]],
                            dim=-1)  # shift by 1
        # a hack here as well to get rid of the extra "1" in r.shape
        if r.dim() >= 2:
            r = r.squeeze(-2)
        yhat = lgt_sum + s[..., :num_of_obs] + r

        # the first observation only initialises the level, so it is left out
        # of the likelihood
        with pyro.plate("response_plate", num_of_obs - 1):
            pyro.sample("response",
                        dist.StudentT(nu, yhat[..., 1:], obs_sigma),
                        obs=response[1:])

        # we care beta not the pr_beta, nr_beta, ...
        extra_out['beta'] = torch.cat([pr_beta, nr_beta, rr_beta], dim=-1)

        extra_out.update({'b': b, 'l': l, 's': s, 'lgt_sum': lgt_sum})
        return extra_out