Example #1
def test_categorical_gradient_with_logits(init_tensor_type):
    p = Variable(init_tensor_type([-float('inf'), 0]), requires_grad=True)
    categorical = OneHotCategorical(logits=p)
    log_pdf = categorical.batch_log_pdf(Variable(init_tensor_type([0, 1])))
    log_pdf.sum().backward()
    assert_equal(log_pdf.data[0], 0)
    assert_equal(p.grad.data[0], 0)
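The test above uses the legacy Pyro/PyTorch API (Variable, batch_log_pdf). A minimal sketch of the same gradient check with the current torch.distributions API, assuming a recent PyTorch, might look like this:

import torch
from torch.distributions import OneHotCategorical

# A -inf logit assigns zero probability to the first category.
p = torch.tensor([-float('inf'), 0.0], requires_grad=True)
categorical = OneHotCategorical(logits=p)

# Score the one-hot observation [0, 1] (second category chosen).
log_pdf = categorical.log_prob(torch.tensor([0.0, 1.0]))
log_pdf.sum().backward()

assert log_pdf.item() == 0.0      # the chosen category has probability 1
assert p.grad[0].item() == 0.0    # no gradient flows through the -inf logit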
Example #2
def main(xs, ys=None):
    """
    The model corresponds to the following generative process:
    p(z) = normal(0,I)              # handwriting style (latent)
    p(y|x) = categorical(I/10.)     # which digit (semi-supervised)
    p(x|y,z) = bernoulli(loc(y,z))   # an image
    loc is given by a neural network  `decoder`

    :param xs: a batch of scaled vectors of pixels from an image
    :param ys: (optional) a batch of the class labels i.e.
               the digit corresponding to the image(s)
    :return: None
    """
    xs = torch.reshape(xs, [200, 784])
    # WL: ok for analyser? =====
    # ys = ...
    # ==========================
    
    # register this pytorch module and all of its sub-modules with pyro
    pyro.module("decoder_fst", decoder_fst)
    pyro.module("decoder_snd", decoder_snd)

    # batch_size = xs.size(0)
    # batch_size = 200
    # z_dim = 50
    # output_size = 10
    with pyro.plate("data"):
        # sample the handwriting style from the constant prior distribution
        prior_loc = torch.zeros([200, 50])
        prior_scale = torch.ones([200, 50])
        zs = pyro.sample("z", Normal(prior_loc, prior_scale).to_event(1))

        # if the label y (which digit to write) is supervised, sample from the
        # constant prior, otherwise, observe the value (i.e. score it against the constant prior)
        alpha_prior = torch.ones([200, 10]) / (1.0 * 10)
        # WL: edited. =====
        # ys = pyro.sample("y", OneHotCategorical(alpha_prior), obs=ys)
        if ys is None:
            ys = pyro.sample("y", OneHotCategorical(alpha_prior))
        else:
            ys = pyro.sample("y", OneHotCategorical(alpha_prior), obs=ys)
        # ================

        # finally, score the image (x) using the handwriting style (z) and
        # the class label y (which digit to write) against the
        # parametrized distribution p(x|y,z) = bernoulli(decoder(y,z))
        # where `decoder` is a neural network
        hidden = softplus(decoder_fst(torch.cat([zs, ys], -1)))
        loc = sigmoid(decoder_snd(hidden))
        pyro.sample("x", Bernoulli(loc).to_event(1), obs=xs)
Example #3
    def label_variable(self, label):
        new_label = []
        options = {'device': label.device, 'dtype': label.dtype}
        for i, length in enumerate(self.latents_sizes):
            prior = torch.ones(label.shape[0], length, **options) / (1.0 * length)
            new_label.append(
                pyro.sample("label_" + str(self.latents_names[i]),
                            OneHotCategorical(prior),
                            obs=one_hot(tensor(label[:, i], dtype=torch.int64),
                                        int(length))))
        new_label = torch.cat(new_label, -1)
        return new_label.to(torch.float32).to(label.device)
def main(xs, ys=None):
    """
    The guide corresponds to the following:
    q(y|x) = categorical(alpha(x))              # infer digit from an image
    q(z|x,y) = normal(loc(x,y),scale(x,y))       # infer handwriting style from an image and the digit
    loc, scale are given by a neural network `encoder_z`
    alpha is given by a neural network `encoder_y`

    :param xs: a batch of scaled vectors of pixels from an image
    :param ys: (optional) a batch of the class labels i.e.
               the digit corresponding to the image(s)
    :return: None
    """
    xs = torch.reshape(xs, [200, 784])
    # WL: ok for analyser? =====
    # ys = ...
    # ==========================

    pyro.module("encoder_y_fst", encoder_y_fst)
    pyro.module("encoder_y_snd", encoder_y_snd)
    pyro.module("encoder_z_fst", encoder_z_fst)
    pyro.module("encoder_z_out1", encoder_z_out1)
    pyro.module("encoder_z_out2", encoder_z_out2)

    # inform Pyro that the variables in the batch of xs, ys are conditionally independent
    with pyro.plate("data"):
        # if the class label (the digit) is not supervised, sample
        # (and score) the digit with the variational distribution
        # q(y|x) = categorical(alpha(x))
        if ys is None:
            hidden = softplus(encoder_y_fst(xs))
            alpha = softmax(encoder_y_snd(hidden))
            ys = pyro.sample("y", OneHotCategorical(alpha))

        # sample (and score) the latent handwriting-style with the variational
        # distribution q(z|x,y) = normal(loc(x,y),scale(x,y))
        # shape = broadcast_shape(torch.Size([200]), ys.shape[:-1]) + (-1,)
        # WL: ok for analyser? =====
        shape = ys.shape[:-1] + (-1, )
        hidden_z = softplus(
            encoder_z_fst(
                torch.cat([
                    torch.Tensor.expand(xs, shape),
                    torch.Tensor.expand(ys, shape)
                ], -1)))
        # ==========================
        loc = encoder_z_out1(hidden_z)
        scale = torch.exp(encoder_z_out2(hidden_z))
        pyro.sample("z", Normal(loc, scale).to_event(1))
Example #5
def model():
    p = torch.tensor([0.25] * 4)
    pyro.sample("z", OneHotCategorical(probs=p))
Example #6
File: st.py Project: ludvb/xfuse
    def model(self, x, zs):
        # pylint: disable=too-many-locals, too-many-statements

        dataset = require("dataloader").dataset

        def _compute_rim(decoded):
            shared_representation = get_module(
                "metagene_shared",
                lambda: torch.nn.Sequential(
                    torch.nn.Conv2d(
                        decoded.shape[1], decoded.shape[1], kernel_size=1
                    ),
                    torch.nn.BatchNorm2d(decoded.shape[1], momentum=0.05),
                    torch.nn.LeakyReLU(0.2, inplace=True),
                ),
            )(decoded)
            rim = torch.cat(
                [
                    get_module(
                        f"decoder_{_encode_metagene_name(n)}",
                        partial(
                            self._create_metagene_decoder, decoded.shape[1], n
                        ),
                    )(shared_representation)
                    for n in self.metagenes
                ],
                dim=1,
            )
            rim = torch.nn.functional.softmax(rim, dim=1)
            return rim

        decoded = self._decode(zs)
        label = center_crop(x["label"], [None, *decoded.shape[-2:]])

        rim = checkpoint(_compute_rim, decoded)
        rim = center_crop(rim, [None, None, *label.shape[-2:]])
        rim = pyro.sample("rim", Delta(rim))

        scale = pyro.sample(
            "scale",
            Delta(
                center_crop(
                    self._get_scale_decoder(decoded.shape[1])(decoded),
                    [None, None, *label.shape[-2:]],
                )
            ),
        )
        rim = scale * rim

        rate_mg_prior = Normal(
            0.0,
            1e-8
            + get_param(
                "rate_mg_prior_sd",
                lambda: torch.ones(len(self._allocated_genes)),
                constraint=constraints.positive,
            ),
        )
        with pyro.poutine.scale(scale=len(x["data"]) / dataset.size()):
            rate_mg = torch.stack(
                [
                    pyro.sample(
                        _encode_metagene_name(n),
                        rate_mg_prior,
                        infer={"is_global": True},
                    )
                    for n in self.metagenes
                ]
            )
        rate_mg = pyro.sample("rate_mg", Delta(rate_mg))

        rate_g_conditions_prior = Normal(
            0.0,
            1e-8
            + get_param(
                "rate_g_conditions_prior_sd",
                lambda: torch.ones(len(self._allocated_genes)),
                constraint=constraints.positive,
            ),
        )
        logits_g_conditions_prior = Normal(
            0.0,
            1e-8
            + get_param(
                "logits_g_conditions_prior_sd",
                lambda: torch.ones(len(self._allocated_genes)),
                constraint=constraints.positive,
            ),
        )

        rate_g, logits_g = [], []

        for batch_idx, (slide, covariates) in enumerate(
            zip(x["slide"], x["covariates"])
        ):
            rate_g_slide = get_param(
                "rate_g_condition_baseline",
                lambda: self.__init_rate_baseline().log(),
                lr_multiplier=5.0,
            )
            logits_g_slide = get_param(
                "logits_g_condition_baseline",
                self.__init_logits_baseline,
                lr_multiplier=5.0,
            )

            for covariate, condition in covariates.items():
                try:
                    conditions = get("covariates")[covariate]
                except KeyError:
                    continue

                if pd.isna(condition):
                    with pyro.poutine.scale(
                        scale=1.0 / dataset.size(slide=slide)
                    ):
                        pyro.sample(
                            f"condition-{covariate}-{batch_idx}",
                            OneHotCategorical(
                                to_device(torch.ones(len(conditions)))
                                / len(conditions)
                            ),
                            infer={"is_global": True},
                        )
                        # ^ NOTE 1: This statement affects the ELBO but not its
                        #           gradient. The pmf is non-differentiable but
                        #           it doesn't matter---our prior over the
                        #           conditions is uniform; even if a gradient
                        #           existed, it would always be zero.
                        # ^ NOTE 2: The result is used to index the effect of
                        #           the condition. However, this takes place in
                        #           the guide to avoid sampling effects that are
                        #           not used in the current minibatch,
                        #           potentially (?) reducing noise in the
                        #           learning signal. Therefore, the result here
                        #           is discarded.
                    condition_scale = 1e-99
                    # ^ HACK: Pyro requires scale > 0
                else:
                    condition_scale = 1.0 / dataset.size(
                        covariate=covariate, condition=condition
                    )

                with pyro.poutine.scale(scale=condition_scale):
                    rate_g_slide = rate_g_slide + pyro.sample(
                        f"rate_g_condition-{covariate}-{batch_idx}",
                        rate_g_conditions_prior,
                        infer={"is_global": True},
                    )
                    logits_g_slide = logits_g_slide + pyro.sample(
                        f"logits_g_condition-{covariate}-{batch_idx}",
                        logits_g_conditions_prior,
                        infer={"is_global": True},
                    )

            rate_g.append(rate_g_slide)
            logits_g.append(logits_g_slide)

        logits_g = torch.stack(logits_g)[:, self._gene_indices]
        rate_g = torch.stack(rate_g)[:, self._gene_indices]
        rate_mg = rate_g.unsqueeze(1) + rate_mg[:, self._gene_indices]

        with scope(prefix=self.tag):
            self._sample_image(x, decoded)

            for i, (data, label, rim, rate_mg, logits_g) in enumerate(
                zip(x["data"], label, rim, rate_mg, logits_g)
            ):
                zero_count_idxs = 1 + torch.where(data.sum(1) == 0)[0]
                partial_idxs = np.unique(
                    torch.cat([label[0], label[-1], label[:, 0], label[:, -1]])
                    .cpu()
                    .numpy()
                )
                partial_idxs = np.setdiff1d(
                    partial_idxs, zero_count_idxs.cpu().numpy()
                )
                mask = np.invert(
                    np.isin(label.cpu().numpy(), [0, *partial_idxs])
                )
                mask = torch.as_tensor(mask, device=label.device)

                if not mask.any():
                    continue

                label = label[mask]
                idxs, label = torch.unique(label, return_inverse=True)
                data = data[idxs - 1]
                pyro.sample(f"idx-{i}", Delta(idxs.float()))

                rim = rim[:, mask]
                labelonehot = sparseonehot(label)
                rim = torch.sparse.mm(labelonehot.t().float(), rim.t())
                rsg = rim @ rate_mg.exp()

                expression_distr = NegativeBinomial(
                    total_count=1e-8 + rsg, logits=logits_g
                )
                pyro.sample(f"xsg-{i}", expression_distr, obs=data)
Example #7
File: st.py Project: l1uw3n/xfuse
    def model(self, x, zs):
        # pylint: disable=too-many-locals
        def _compute_rim(decoded):
            shared_representation = get_module(
                "metagene_shared",
                lambda: torch.nn.Sequential(
                    torch.nn.Conv2d(
                        decoded.shape[1], decoded.shape[1], kernel_size=1),
                    torch.nn.BatchNorm2d(decoded.shape[1], momentum=0.05),
                    torch.nn.LeakyReLU(0.2, inplace=True),
                ),
            )(decoded)
            rim = torch.cat(
                [
                    get_module(
                        f"decoder_{_encode_metagene_name(n)}",
                        partial(self._create_metagene_decoder,
                                decoded.shape[1], n),
                    )(shared_representation) for n in self.metagenes
                ],
                dim=1,
            )
            rim = torch.nn.functional.softmax(rim, dim=1)
            return rim

        num_genes = x["data"][0].shape[1]
        decoded = self._decode(zs)
        label = center_crop(x["label"], [None, *decoded.shape[-2:]])

        rim = checkpoint(_compute_rim, decoded)
        rim = center_crop(rim, [None, None, *label.shape[-2:]])
        rim = p.sample("rim", Delta(rim))

        scale = p.sample(
            "scale",
            Delta(
                center_crop(
                    self._get_scale_decoder(decoded.shape[1])(decoded),
                    [None, None, *label.shape[-2:]],
                )),
        )
        rim = scale * rim

        with p.poutine.scale(scale=len(x["data"]) / self.n):
            rate_mg_prior = Normal(
                0.0,
                1e-8 + get_param(
                    "rate_mg_prior_sd",
                    lambda: torch.ones(num_genes),
                    constraint=constraints.positive,
                ),
            )
            rate_mg = torch.stack([
                p.sample(_encode_metagene_name(n), rate_mg_prior)
                for n in self.metagenes
            ])
            rate_mg = p.sample("rate_mg", Delta(rate_mg))

            rate_g_effects_baseline = get_param(
                "rate_g_effects_baseline",
                lambda: self.__init_rate_baseline().log(),
                lr_multiplier=5.0,
            )
            logits_g_effects_baseline = get_param(
                "logits_g_effects_baseline",
                # pylint: disable=unnecessary-lambda
                self.__init_logits_baseline,
                lr_multiplier=5.0,
            )
            rate_g_effects_prior = Normal(
                0.0,
                1e-8 + get_param(
                    "rate_g_effects_prior_sd",
                    lambda: torch.ones(num_genes),
                    constraint=constraints.positive,
                ),
            )
            rate_g_effects = p.sample("rate_g_effects", rate_g_effects_prior)
            rate_g_effects = torch.cat(
                [rate_g_effects_baseline.unsqueeze(0), rate_g_effects])
            logits_g_effects_prior = Normal(
                0.0,
                1e-8 + get_param(
                    "logits_g_effects_prior_sd",
                    lambda: torch.ones(num_genes),
                    constraint=constraints.positive,
                ),
            )
            logits_g_effects = p.sample(
                "logits_g_effects",
                logits_g_effects_prior,
            )
            logits_g_effects = torch.cat(
                [logits_g_effects_baseline.unsqueeze(0), logits_g_effects])

        effects = []
        for covariate, vals in require("covariates"):
            effect = p.sample(
                f"effect-{covariate}",
                OneHotCategorical(
                    to_device(torch.ones(len(vals))) / len(vals)),
            )
            effects.append(effect)
        effects = torch.cat(
            [
                to_device(torch.ones(x["effects"].shape[0], 1)),
                *effects,
            ],
            1,
        ).float()

        logits_g = effects @ logits_g_effects
        rate_g = effects @ rate_g_effects
        rate_mg = rate_g[:, None] + rate_mg

        with scope(prefix=self.tag):
            image_distr = self._sample_image(x, decoded)

            def _compute_sample_params(data, label, rim, rate_mg, logits_g):
                zero_count_idxs = 1 + torch.where(data.sum(1) == 0)[0]
                partial_idxs = np.unique(
                    torch.cat([label[0], label[-1], label[:, 0],
                               label[:, -1]]).cpu().numpy())
                partial_idxs = np.setdiff1d(partial_idxs,
                                            zero_count_idxs.cpu().numpy())
                mask = np.invert(
                    np.isin(label.cpu().numpy(), [0, *partial_idxs]))
                mask = torch.as_tensor(mask, device=label.device)

                if not mask.any():
                    return (
                        data[[]],
                        torch.zeros(0, num_genes).to(rim),
                        logits_g.expand(0, -1),
                    )

                label = label[mask] - 1
                idxs, label = torch.unique(label, return_inverse=True)
                data = data[idxs]

                rim = rim[:, mask]
                labelonehot = sparseonehot(label)
                rim = torch.sparse.mm(labelonehot.t().float(), rim.t())

                rgs = rim @ rate_mg.exp()

                return data, rgs, logits_g.expand(len(rgs), -1)

            data, rgs, logits_g = zip(*it.starmap(
                _compute_sample_params,
                zip(x["data"], label, rim, rate_mg, logits_g),
            ))

            expression_distr = NegativeBinomial(
                total_count=1e-8 + torch.cat(rgs),
                logits=torch.cat(logits_g),
            )
            p.sample("xsg", expression_distr, obs=torch.cat(data))

        return image_distr, expression_distr
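Both models plug the aggregated rates into NegativeBinomial through total_count with per-gene logits, so the mean is total_count * exp(logits) and logits_g acts as a learned per-gene scaling term. A minimal sketch of that parameterization with torch.distributions, using made-up numbers:

import torch
from torch.distributions import NegativeBinomial

rgs = torch.tensor([[5.0, 20.0]])        # per-spot, per-gene rates (as in 1e-8 + torch.cat(rgs))
logits_g = torch.tensor([[0.0, -1.0]])   # per-gene logits

expression_distr = NegativeBinomial(total_count=1e-8 + rgs, logits=logits_g)
print(expression_distr.mean)             # equals total_count * exp(logits)

counts = torch.tensor([[4.0, 7.0]])      # observed counts, as passed via obs=
print(expression_distr.log_prob(counts))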