Example #1
def test_zinb_0_gate(total_count, probs):
    # if gate is 0, ZINB reduces to NegativeBinomial
    zinb_ = ZeroInflatedNegativeBinomial(
        torch.zeros(1), total_count=torch.tensor(total_count), probs=torch.tensor(probs)
    )
    neg_bin = NegativeBinomial(torch.tensor(total_count), probs=torch.tensor(probs))
    s = neg_bin.sample((20,))
    zinb_prob = zinb_.log_prob(s)
    neg_bin_prob = neg_bin.log_prob(s)
    assert_close(zinb_prob, neg_bin_prob)
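The total_count and probs arguments are supplied by the test harness. Below is a minimal, self-contained sketch of how such a test could be parametrized with pytest, written against the keyword-only gate interface shown in Example #3; the concrete parameter values are illustrative, and torch.testing.assert_close stands in for the project's own assert_close helper.

import pytest
import torch
from pyro.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial
from torch.testing import assert_close

@pytest.mark.parametrize("total_count", [1.0, 10.0])
@pytest.mark.parametrize("probs", [0.25, 0.75])
def test_zinb_0_gate_sketch(total_count, probs):
    # With gate == 0 the zero-inflated mixture degenerates to its base distribution.
    zinb = ZeroInflatedNegativeBinomial(
        total_count=torch.tensor(total_count),
        gate=torch.zeros(1),
        probs=torch.tensor(probs),
    )
    neg_bin = NegativeBinomial(torch.tensor(total_count), probs=torch.tensor(probs))
    s = neg_bin.sample((20,))
    assert_close(zinb.log_prob(s), neg_bin.log_prob(s))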
Example #2
    def __init__(self, total_count, *, probs=None, logits=None, gate=None, gate_logits=None, validate_args=None):
        base_dist = NegativeBinomial(
            total_count=total_count,
            probs=probs,
            logits=logits,
            validate_args=False,
        )
        base_dist._validate_args = validate_args

        super().__init__(
            base_dist, gate=gate, gate_logits=gate_logits, validate_args=validate_args
        )
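This constructor simply builds a NegativeBinomial base distribution and hands it, together with gate or gate_logits, to the zero-inflated base class. A minimal usage sketch of this keyword-only interface (the parameter values are illustrative):

import torch
from pyro.distributions import ZeroInflatedNegativeBinomial

d = ZeroInflatedNegativeBinomial(
    total_count=torch.tensor(10.0),
    probs=torch.tensor(0.3),
    gate=torch.tensor(0.2),  # probability of an excess (structural) zero
)
x = d.sample((5,))
print(x, d.log_prob(x))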
Example #3
def test_zinb_0_gate(total_count, probs):
    # if gate is 0, ZINB reduces to NegativeBinomial
    zinb1 = ZeroInflatedNegativeBinomial(
        total_count=torch.tensor(total_count),
        gate=torch.zeros(1),
        probs=torch.tensor(probs),
    )
    zinb2 = ZeroInflatedNegativeBinomial(
        total_count=torch.tensor(total_count),
        gate_logits=torch.tensor(-99.9),
        probs=torch.tensor(probs),
    )
    neg_bin = NegativeBinomial(torch.tensor(total_count),
                               probs=torch.tensor(probs))
    s = neg_bin.sample((20, ))
    zinb1_prob = zinb1.log_prob(s)
    zinb2_prob = zinb2.log_prob(s)
    neg_bin_prob = neg_bin.log_prob(s)
    assert_close(zinb1_prob, neg_bin_prob)
    assert_close(zinb2_prob, neg_bin_prob)
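The complementary boundary case is gate equal to one, in which the zero-inflated distribution puts all of its mass on zero. A short sketch of such a check, not taken verbatim from the test suite:

import torch
from pyro.distributions import ZeroInflatedNegativeBinomial

zinb = ZeroInflatedNegativeBinomial(
    total_count=torch.tensor(10.0),
    gate=torch.ones(1),
    probs=torch.tensor(0.3),
)
# Every draw is zero, and the log-probability of zero is log(1) = 0.
assert torch.all(zinb.sample((20,)) == 0.0)
assert torch.allclose(zinb.log_prob(torch.zeros(1)), torch.zeros(1))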
Example #4
    def __init__(self, gate, total_count, probs=None, logits=None, validate_args=None):
        base_dist = NegativeBinomial(
            total_count=total_count,
            probs=probs,
            logits=logits,
            validate_args=validate_args,
        )

        super(ZeroInflatedNegativeBinomial, self).__init__(
            gate, base_dist, validate_args=validate_args
        )
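In this older interface the gate is the first positional argument, which is why Example #1 passes torch.zeros(1) positionally. Under that signature a construction would look roughly like this (values illustrative):

zinb = ZeroInflatedNegativeBinomial(
    torch.zeros(1),                    # gate
    total_count=torch.tensor(10.0),
    probs=torch.tensor(0.3),
)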
Example #5
    def model(self, x, zs):
        # pylint: disable=too-many-locals
        def _compute_rim(decoded):
            shared_representation = get_module(
                "metagene_shared",
                lambda: torch.nn.Sequential(
                    torch.nn.Conv2d(
                        decoded.shape[1], decoded.shape[1], kernel_size=1),
                    torch.nn.BatchNorm2d(decoded.shape[1], momentum=0.05),
                    torch.nn.LeakyReLU(0.2, inplace=True),
                ),
            )(decoded)
            rim = torch.cat(
                [
                    get_module(
                        f"decoder_{_encode_metagene_name(n)}",
                        partial(self._create_metagene_decoder,
                                decoded.shape[1], n),
                    )(shared_representation) for n in self.metagenes
                ],
                dim=1,
            )
            rim = torch.nn.functional.softmax(rim, dim=1)
            return rim

        num_genes = x["data"][0].shape[1]
        decoded = self._decode(zs)
        label = center_crop(x["label"], [None, *decoded.shape[-2:]])

        rim = checkpoint(_compute_rim, decoded)
        rim = center_crop(rim, [None, None, *label.shape[-2:]])
        rim = p.sample("rim", Delta(rim))

        scale = p.sample(
            "scale",
            Delta(
                center_crop(
                    self._get_scale_decoder(decoded.shape[1])(decoded),
                    [None, None, *label.shape[-2:]],
                )),
        )
        rim = scale * rim

        with p.poutine.scale(scale=len(x["data"]) / self.n):
            rate_mg_prior = Normal(
                0.0,
                1e-8 + get_param(
                    "rate_mg_sd",
                    lambda: torch.ones(num_genes),
                    constraint=constraints.positive,
                ),
            )
            rate_mg = torch.stack([
                p.sample(_encode_metagene_name(n), rate_mg_prior)
                for n in self.metagenes
            ])
            rate_mg = p.sample("rate_mg", Delta(rate_mg))

            rate_g_effects_baseline = get_param(
                "rate_g_effects_baseline",
                lambda: self.__init_rate_baseline().log(),
                lr_multiplier=5.0,
            )
            logits_g_effects_baseline = get_param(
                "logits_g_effects_baseline",
                # pylint: disable=unnecessary-lambda
                self.__init_logits_baseline,
                lr_multiplier=5.0,
            )
            rate_g_effects_prior = Normal(
                0.0,
                1e-8 + get_param(
                    "rate_g_effects_sd",
                    lambda: torch.ones(num_genes),
                    constraint=constraints.positive,
                ),
            )
            rate_g_effects = p.sample("rate_g_effects", rate_g_effects_prior)
            rate_g_effects = torch.cat(
                [rate_g_effects_baseline.unsqueeze(0), rate_g_effects])
            logits_g_effects_prior = Normal(
                0.0,
                1e-8 + get_param(
                    "logits_g_effects_sd",
                    lambda: torch.ones(num_genes),
                    constraint=constraints.positive,
                ),
            )
            logits_g_effects = p.sample(
                "logits_g_effects",
                logits_g_effects_prior,
            )
            logits_g_effects = torch.cat(
                [logits_g_effects_baseline.unsqueeze(0), logits_g_effects])

            effects = torch.cat(
                [
                    torch.ones(x["effects"].shape[0], 1).to(x["effects"]),
                    x["effects"],
                ],
                1,
            ).float()

            logits_g = effects @ logits_g_effects
            rate_g = effects @ rate_g_effects
            rate_mg = rate_g[:, None] + rate_mg

        with scope(prefix=self.tag):
            image_distr = self._sample_image(x, decoded)

            def _compute_sample_params(data, label, rim, rate_mg, logits_g):
                nonmissing = label != 0
                zero_count_spots = 1 + torch.where(data.sum(1) == 0)[0]
                nonpartial = binary_fill_holes(
                    np.isin(label.cpu(), [0, *zero_count_spots.cpu()]))
                nonpartial = torch.as_tensor(nonpartial).to(nonmissing)
                mask = nonpartial & nonmissing

                if not mask.any():
                    return (
                        data[[]],
                        torch.zeros(0, num_genes).to(rim),
                        logits_g.expand(0, -1),
                    )

                label = label[mask] - 1
                idxs, label = torch.unique(label, return_inverse=True)
                data = data[idxs]

                rim = rim[:, mask]
                labelonehot = sparseonehot(label)
                rim = torch.sparse.mm(labelonehot.t().float(), rim.t())

                rgs = rim @ rate_mg.exp()

                return data, rgs, logits_g.expand(len(rgs), -1)

            data, rgs, logits_g = zip(*it.starmap(
                _compute_sample_params,
                zip(x["data"], label, rim, rate_mg, logits_g),
            ))

            expression_distr = NegativeBinomial(
                total_count=1e-8 + torch.cat(rgs),
                logits=torch.cat(logits_g),
            )
            p.sample("xsg", expression_distr, obs=torch.cat(data))

        return image_distr, expression_distr
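Note how the model registers deterministic intermediates (rim, scale, rate_mg) as Delta-distributed sample sites; this records them in the Pyro trace without changing the model density, since a Delta's log-probability at its own support point is zero. A minimal, self-contained sketch of that pattern (names are illustrative):

import pyro
import torch
from pyro.distributions import Delta

def toy_model():
    value = torch.rand(3)
    # Register the deterministic tensor as a sample site so it appears in traces.
    value = pyro.sample("value_recorded", Delta(value).to_event(1))
    return value

trace = pyro.poutine.trace(toy_model).get_trace()
print(trace.nodes["value_recorded"]["value"])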
Example #6
    def __init_globals(self):
        dataloader = require("dataloader")
        device = get("default_device")

        dataloader = make_dataloader(
            Dataset(
                Data(
                    slides={
                        k: Slide(
                            data=v.data,
                            # pylint: disable=unnecessary-lambda
                            # ^ Necessary for type checking to pass
                            iterator=lambda x: DataSlide(x),
                        )
                        for k, v in dataloader.dataset.data.slides.items()
                        if v.data.type == "ST"
                    },
                    design=dataloader.dataset.data.design,
                )),
            num_workers=0,
            batch_size=100,
        )

        r2rp = transform_to(constraints.positive)

        scale = torch.zeros(1, requires_grad=True, device=device)
        rate = torch.zeros(len(dataloader.dataset.genes),
                           requires_grad=True,
                           device=device)
        logits = torch.zeros(len(dataloader.dataset.genes),
                             requires_grad=True,
                             device=device)

        optim = torch.optim.Adam((scale, rate, logits), lr=0.01)

        with Progressbar(it.count(1), leave=False, position=0) as iterator:
            running_rmse = None
            for epoch in iterator:
                previous_rmse = running_rmse
                for x in (torch.cat(x["ST"]["data"]).to(device)
                          for x in dataloader):
                    distr = NegativeBinomial(r2rp(scale) * r2rp(rate),
                                             logits=logits)
                    rmse = (((distr.mean -
                              x)**2).mean(1).sqrt().mean().detach().cpu())
                    try:
                        running_rmse = running_rmse + 1e-2 * (rmse -
                                                              running_rmse)
                    except TypeError:
                        running_rmse = rmse
                    iterator.set_description(
                        "Initializing global coefficients, please wait..." +
                        f" (RMSE: {running_rmse:.3f})")
                    optim.zero_grad()
                    nll = -distr.log_prob(x).sum()
                    nll.backward()
                    optim.step()
                if (epoch > 100) and (previous_rmse - running_rmse < 1e-4):
                    break

        self.__init_scale = r2rp(scale).detach().cpu()
        self.__init_rate = r2rp(rate).detach().cpu()
        self.__init_logits = logits.detach().cpu()
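The RMSE readout above compares the distribution mean to the observed counts. For a NegativeBinomial parametrized with total_count and logits the mean is total_count * exp(logits), so the fitted scale, rate and logits jointly determine the per-gene means. A small sketch verifying that relationship (values are arbitrary):

import torch
from torch.distributions import NegativeBinomial

total_count = torch.tensor([2.0, 5.0])
logits = torch.tensor([-0.5, 1.0])
d = NegativeBinomial(total_count, logits=logits)
# mean = total_count * probs / (1 - probs) = total_count * exp(logits)
assert torch.allclose(d.mean, total_count * logits.exp())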
Example #7
File: st.py Project: ludvb/xfuse
    def model(self, x, zs):
        # pylint: disable=too-many-locals, too-many-statements

        dataset = require("dataloader").dataset

        def _compute_rim(decoded):
            shared_representation = get_module(
                "metagene_shared",
                lambda: torch.nn.Sequential(
                    torch.nn.Conv2d(
                        decoded.shape[1], decoded.shape[1], kernel_size=1
                    ),
                    torch.nn.BatchNorm2d(decoded.shape[1], momentum=0.05),
                    torch.nn.LeakyReLU(0.2, inplace=True),
                ),
            )(decoded)
            rim = torch.cat(
                [
                    get_module(
                        f"decoder_{_encode_metagene_name(n)}",
                        partial(
                            self._create_metagene_decoder, decoded.shape[1], n
                        ),
                    )(shared_representation)
                    for n in self.metagenes
                ],
                dim=1,
            )
            rim = torch.nn.functional.softmax(rim, dim=1)
            return rim

        decoded = self._decode(zs)
        label = center_crop(x["label"], [None, *decoded.shape[-2:]])

        rim = checkpoint(_compute_rim, decoded)
        rim = center_crop(rim, [None, None, *label.shape[-2:]])
        rim = pyro.sample("rim", Delta(rim))

        scale = pyro.sample(
            "scale",
            Delta(
                center_crop(
                    self._get_scale_decoder(decoded.shape[1])(decoded),
                    [None, None, *label.shape[-2:]],
                )
            ),
        )
        rim = scale * rim

        rate_mg_prior = Normal(
            0.0,
            1e-8
            + get_param(
                "rate_mg_prior_sd",
                lambda: torch.ones(len(self._allocated_genes)),
                constraint=constraints.positive,
            ),
        )
        with pyro.poutine.scale(scale=len(x["data"]) / dataset.size()):
            rate_mg = torch.stack(
                [
                    pyro.sample(
                        _encode_metagene_name(n),
                        rate_mg_prior,
                        infer={"is_global": True},
                    )
                    for n in self.metagenes
                ]
            )
        rate_mg = pyro.sample("rate_mg", Delta(rate_mg))

        rate_g_conditions_prior = Normal(
            0.0,
            1e-8
            + get_param(
                "rate_g_conditions_prior_sd",
                lambda: torch.ones(len(self._allocated_genes)),
                constraint=constraints.positive,
            ),
        )
        logits_g_conditions_prior = Normal(
            0.0,
            1e-8
            + get_param(
                "logits_g_conditions_prior_sd",
                lambda: torch.ones(len(self._allocated_genes)),
                constraint=constraints.positive,
            ),
        )

        rate_g, logits_g = [], []

        for batch_idx, (slide, covariates) in enumerate(
            zip(x["slide"], x["covariates"])
        ):
            rate_g_slide = get_param(
                "rate_g_condition_baseline",
                lambda: self.__init_rate_baseline().log(),
                lr_multiplier=5.0,
            )
            logits_g_slide = get_param(
                "logits_g_condition_baseline",
                self.__init_logits_baseline,
                lr_multiplier=5.0,
            )

            for covariate, condition in covariates.items():
                try:
                    conditions = get("covariates")[covariate]
                except KeyError:
                    continue

                if pd.isna(condition):
                    with pyro.poutine.scale(
                        scale=1.0 / dataset.size(slide=slide)
                    ):
                        pyro.sample(
                            f"condition-{covariate}-{batch_idx}",
                            OneHotCategorical(
                                to_device(torch.ones(len(conditions)))
                                / len(conditions)
                            ),
                            infer={"is_global": True},
                        )
                        # ^ NOTE 1: This statement affects the ELBO but not its
                        #           gradient. The pmf is non-differentiable but
                        #           it doesn't matter---our prior over the
                        #           conditions is uniform; even if a gradient
                        #           existed, it would always be zero.
                        # ^ NOTE 2: The result is used to index the effect of
                        #           the condition. However, this takes place in
                        #           the guide to avoid sampling effects that are
                        #           not used in the current minibatch,
                        #           potentially (?) reducing noise in the
                        #           learning signal. Therefore, the result here
                        #           is discarded.
                    condition_scale = 1e-99
                    # ^ HACK: Pyro requires scale > 0
                else:
                    condition_scale = 1.0 / dataset.size(
                        covariate=covariate, condition=condition
                    )

                with pyro.poutine.scale(scale=condition_scale):
                    rate_g_slide = rate_g_slide + pyro.sample(
                        f"rate_g_condition-{covariate}-{batch_idx}",
                        rate_g_conditions_prior,
                        infer={"is_global": True},
                    )
                    logits_g_slide = logits_g_slide + pyro.sample(
                        f"logits_g_condition-{covariate}-{batch_idx}",
                        logits_g_conditions_prior,
                        infer={"is_global": True},
                    )

            rate_g.append(rate_g_slide)
            logits_g.append(logits_g_slide)

        logits_g = torch.stack(logits_g)[:, self._gene_indices]
        rate_g = torch.stack(rate_g)[:, self._gene_indices]
        rate_mg = rate_g.unsqueeze(1) + rate_mg[:, self._gene_indices]

        with scope(prefix=self.tag):
            self._sample_image(x, decoded)

            for i, (data, label, rim, rate_mg, logits_g) in enumerate(
                zip(x["data"], label, rim, rate_mg, logits_g)
            ):
                zero_count_idxs = 1 + torch.where(data.sum(1) == 0)[0]
                partial_idxs = np.unique(
                    torch.cat([label[0], label[-1], label[:, 0], label[:, -1]])
                    .cpu()
                    .numpy()
                )
                partial_idxs = np.setdiff1d(
                    partial_idxs, zero_count_idxs.cpu().numpy()
                )
                mask = np.invert(
                    np.isin(label.cpu().numpy(), [0, *partial_idxs])
                )
                mask = torch.as_tensor(mask, device=label.device)

                if not mask.any():
                    continue

                label = label[mask]
                idxs, label = torch.unique(label, return_inverse=True)
                data = data[idxs - 1]
                pyro.sample(f"idx-{i}", Delta(idxs.float()))

                rim = rim[:, mask]
                labelonehot = sparseonehot(label)
                rim = torch.sparse.mm(labelonehot.t().float(), rim.t())
                rsg = rim @ rate_mg.exp()

                expression_distr = NegativeBinomial(
                    total_count=1e-8 + rsg, logits=logits_g
                )
                pyro.sample(f"xsg-{i}", expression_distr, obs=data)