Esempio n. 1
0
def test_onehot_shapes(probs):
    """Gradient of a straight-through sample's log-density keeps probs' shape.

    ``probs`` is anything ``torch.tensor`` accepts (presumably supplied by a
    pytest parametrization). One sample is drawn at temperature 0.5, its
    ``log_prob`` is summed, and the gradient w.r.t. ``probs`` must have the
    same shape as ``probs``.
    """
    probs = torch.tensor(probs, requires_grad=True)
    dist = RelaxedOneHotCategoricalStraightThrough(
        torch.tensor(0.5), probs=probs
    )
    total_log_prob = dist.log_prob(dist.rsample()).sum()
    # ``grad`` returns one gradient per input; there is exactly one input.
    (probs_grad,) = grad(total_log_prob, [probs])
    assert probs_grad.shape == probs.shape
Esempio n. 2
0
    def rsample(self, n_samples, ret_z, temperature=0.1):
        """Draw ``n_samples`` reparameterised samples.

        A straight-through relaxed one-hot vector (from ``self.logits``)
        selects, per sample, a weighted row of ``self.mus`` and
        ``self.sigmas``. A Normal sample is then drawn around those values.

        Args:
            n_samples: number of samples to draw.
            ret_z: when true, also return the selected component index of
                each sample (``argmax`` of the one-hot vector along dim 1).
            temperature: relaxation temperature of the one-hot sampler.

        Returns:
            The Normal samples, or ``(samples, component_indices)`` when
            ``ret_z`` is true.
        """
        onehot = RelaxedOneHotCategoricalStraightThrough(
            logits=self.logits, temperature=temperature
        ).rsample((n_samples,))
        # The one-hot is presumably (n_samples, K); a trailing axis lets it
        # broadcast against the per-component parameters, and summing dim 1
        # picks out one row per sample. TODO confirm shapes.
        weights = onehot.unsqueeze(2)
        mus = (weights * self.mus).sum(dim=1)
        # NOTE(review): the selected sigma is squared and then used as the
        # Normal scale; presumably to keep it positive -- confirm intent.
        sigmas = (weights * self.sigmas).sum(dim=1) ** 2

        samples = Normal(mus, sigmas).rsample((1,)).squeeze(0)
        if not ret_z:
            return samples
        return samples, onehot.argmax(dim=1)
Esempio n. 3
0
 def guide():
     """Pyro guide: straight-through relaxed one-hot sample at site ``"z"``.

     The probabilities come from the parameter ``"q"``. It is constrained to
     the simplex, with initial value ``[0.1, 0.2, 0.3, 0.4]``. The relaxation
     temperature is fixed at 0.1.
     """
     temperature = torch.tensor(0.10)
     q_probs = pyro.param(
         "q",
         torch.tensor([0.1, 0.2, 0.3, 0.4]),
         constraint=constraints.simplex,
     )
     z_distr = RelaxedOneHotCategoricalStraightThrough(
         temperature=temperature, probs=q_probs
     )
     pyro.sample("z", z_distr)
Esempio n. 4
0
def test_onehot_entropy_grad(temp):
    """Compare averaged log-prob gradients of the relaxed one-hot variants.

    For both ``RelaxedOneHotCategorical`` and its straight-through variant,
    ``num_samples`` draws are taken and ``log_prob`` is summed. The result
    is differentiated w.r.t. ``q`` and divided by ``num_samples``. The two
    averages must agree to ``prec=0.08``.
    """
    num_samples = 2000000
    q = torch.tensor([0.1, 0.2, 0.3, 0.4], requires_grad=True)
    temp = torch.tensor(temp)

    def mean_log_prob_grad(dist_cls):
        # Sample, score, and differentiate the summed log-density w.r.t. q.
        dist_q = dist_cls(temperature=temp, probs=q)
        z = dist_q.rsample(sample_shape=(num_samples, ))
        return grad(dist_q.log_prob(z).sum(), [q])[0] / num_samples

    expected = mean_log_prob_grad(RelaxedOneHotCategorical)
    actual = mean_log_prob_grad(RelaxedOneHotCategoricalStraightThrough)

    message = (
        'bad grad for RelaxedOneHotCategoricalStraightThrough (expected {}, got {})'
        .format(expected, actual))
    assert_equal(expected, actual, prec=0.08, msg=message)
Esempio n. 5
0
File: st.py Progetto: l1uw3n/xfuse
 def guide(self, x):
     """Guide that samples global latents, then one effect per covariate.

     For every covariate yielded by ``require("covariates")``, one
     straight-through relaxed one-hot value (temperature 0.1) is sampled
     per row of ``x["effects"][covariate]``. The per-row logits are
     obtained from ``get_param`` (initial value zeros). Rows whose observed
     values contain any truthy entry are masked out of that sample site and
     overwritten in place with the observed values. The result is then
     registered as a ``Delta`` site ``effect-{covariate}``. Finally the
     method delegates to the parent class's ``guide``.

     Args:
         x: batch dict. From the usage here it needs ``x["data"]`` (only
             ``len`` is taken) and ``x["effects"]``, a mapping from
             covariate name to a table exposing ``.values`` and
             ``.iterrows()``. Presumably a pandas DataFrame with one row per
             sample and one column per covariate value; confirm.

     Returns:
         Whatever ``super().guide(x)`` returns.
     """
     # NOTE(review): scales the global sites by len(x["data"]) / self.n;
     # presumably a minibatch-to-dataset correction -- confirm.
     with p.poutine.scale(scale=len(x["data"]) / self.n):
         self._sample_globals()
     for covariate, _ in require("covariates"):
         # A row counts as observed if any of its entries is truthy.
         is_observed = x["effects"][covariate].values.any(1)
         effect_distr = RelaxedOneHotCategoricalStraightThrough(
             temperature=to_device(torch.as_tensor(0.1)),
             # One logit vector per row, stacked along dim 0; each vector
             # has len(vals) entries (one per column of that row).
             logits=torch.stack([
                 get_param(
                     f"effect-{covariate}-{sample}-logits",
                     torch.zeros(len(vals)),
                 ) for sample, vals in x["effects"][covariate].iterrows()
             ]),
         )
         # Mask out the log-prob contribution of rows that are observed.
         with p.poutine.mask(mask=~to_device(torch.as_tensor(is_observed))):
             effect = p.sample(f"effect-{covariate}-all", effect_distr)
         # Overwrite observed rows in place with the observed values,
         # converted to effect's dtype/device via .to(effect).
         # NOTE(review): if p.sample returns the same tensor it records for
         # the "effect-{covariate}-all" site, this write also changes that
         # site's value -- confirm intended.
         effect[is_observed] = torch.as_tensor(
             x["effects"][covariate].values[is_observed]).to(effect)
         p.sample(f"effect-{covariate}", Delta(effect))
     return super().guide(x)
Esempio n. 6
0
def rsample_gumbel_softmax(
    distr: Distribution,
    n: int,
    temperature: torch.Tensor,
    straight_through: bool = False,
) -> torch.Tensor:
    """Draw ``n`` Gumbel-softmax relaxed samples matching a discrete ``distr``.

    ``Categorical`` and ``OneHotCategorical`` inputs are relaxed to a
    (relaxed) one-hot distribution, so samples are vectors rather than
    integer indices. ``Bernoulli`` inputs are relaxed to a relaxed
    Bernoulli. Either way the relaxed distribution is built from
    ``distr.probs``. With ``straight_through`` the straight-through variant
    of the relaxed distribution is used.

    Args:
        distr: the discrete distribution to relax.
        n: number of samples (leading sample dimension of the result).
        temperature: relaxation temperature.
        straight_through: use the straight-through relaxed variant.

    Returns:
        Reparameterised samples with leading dimension ``n``.

    Raises:
        ValueError: if ``distr`` is not one of the supported discrete types.
    """
    if isinstance(distr, (Categorical, OneHotCategorical)):
        relaxed_cls = (
            RelaxedOneHotCategoricalStraightThrough
            if straight_through
            else RelaxedOneHotCategorical
        )
    elif isinstance(distr, Bernoulli):
        relaxed_cls = (
            RelaxedBernoulliStraightThrough
            if straight_through
            else RelaxedBernoulli
        )
    else:
        raise ValueError("Using Gumbel Softmax with non-discrete distribution")

    relaxed = relaxed_cls(temperature, probs=distr.probs)
    return relaxed.rsample((n, ))
Esempio n. 7
0
File: st.py Progetto: ludvb/xfuse
        def _sample_condition(batch_idx, slide, covariate, condition):
            """Sample per-condition gene effects for ``covariate``.

            If ``covariate`` is not a key of ``get("covariates")`` nothing
            happens. Otherwise a one-hot condition vector is obtained in one
            of two ways:

            * ``condition`` missing (``pd.isna``): a straight-through relaxed
              one-hot over the covariate's conditions is sampled from
              ``logits-{slide}-{covariate}``. The downstream sites then get
              a vanishing scale (1e-99).
            * otherwise: the matching row of an identity matrix (first match)
              is used. The downstream sites are scaled by
              ``1 / dataset.size(covariate=..., condition=...)``.

            The vector then selects rows of the ``mu_*``/``sd_*`` parameter
            matrices (conditions x allocated genes). Two Normal sites are
            sampled from them: ``rate_g_condition-*`` and
            ``logits_g_condition-*``. Returns ``None``.
            """
            try:
                conditions = get("covariates")[covariate]
            except KeyError:
                return

            if pd.isna(condition):
                # Unknown condition: infer it with a relaxed one-hot sample.
                temperature = to_device(torch.as_tensor(0.1))
                condition_logits = get_param(
                    f"logits-{slide}-{covariate}",
                    lambda: torch.zeros(len(conditions)),
                    lr_multiplier=2.0,
                )
                condition_distr = RelaxedOneHotCategoricalStraightThrough(
                    temperature=temperature,
                    logits=condition_logits,
                )
                with pyro.poutine.scale(scale=1.0 / dataset.size(slide=slide)):
                    condition_onehot = pyro.sample(
                        f"condition-{covariate}-{batch_idx}",
                        condition_distr,
                        infer={"is_global": True},
                    )
                # HACK: Pyro requires scale > 0, so a vanishing scale is used
                # (presumably to make the downstream sites negligible).
                condition_scale = 1e-99
            else:
                # Known condition: pick the first matching identity row.
                identity = to_device(torch.eye(len(conditions)))
                condition_onehot = identity[np.isin(conditions, condition)][0]
                condition_scale = 1.0 / dataset.size(
                    covariate=covariate, condition=condition
                )

            def sample_gene_condition(mu_name, sd_name, site_name):
                # Select this condition's row of the mean / std parameters
                # and sample a Normal site under ``condition_scale``.
                mu = condition_onehot @ get_param(
                    mu_name,
                    lambda: torch.zeros(
                        len(conditions), len(self._allocated_genes)
                    ),
                    lr_multiplier=2.0,
                )
                sd = condition_onehot @ get_param(
                    sd_name,
                    lambda: 1e-2
                    * torch.ones(len(conditions), len(self._allocated_genes)),
                    constraint=constraints.positive,
                    lr_multiplier=2.0,
                )
                with pyro.poutine.scale(scale=condition_scale):
                    pyro.sample(
                        site_name,
                        Normal(mu, 1e-8 + sd),
                        infer={"is_global": True},
                    )

            sample_gene_condition(
                f"mu_rate_g_condition-{covariate}",
                f"sd_rate_g_condition-{covariate}",
                f"rate_g_condition-{covariate}-{batch_idx}",
            )
            sample_gene_condition(
                f"mu_logits_g_condition-{covariate}",
                f"sd_logits_g_condition-{covariate}",
                f"logits_g_condition-{covariate}-{batch_idx}",
            )
Esempio n. 8
0
    def forward(self, s_t):
        """Sample a straight-through one-hot action for state ``s_t``.

        ``s_t`` is passed through ``self.s_to_hidden`` and then
        ``self.hidden_to_scores``. This produces unnormalised scores over
        the last dimension. An action is drawn from a
        ``RelaxedOneHotCategoricalStraightThrough`` at temperature 0.5.

        Returns:
            The reparameterised straight-through one-hot sample ``a_t``.
        """
        hidden = self.s_to_hidden(s_t)
        scores = self.hidden_to_scores(hidden)
        # Pass the scores as ``logits`` rather than ``probs=softmax(scores)``.
        # Given ``probs``, the distribution recovers logits as
        # ``log(clamp(probs, eps, 1 - eps))``. Saturated softmax entries then
        # get clamped, which caps the logit gap and zeroes their gradients.
        # ``logits`` are normalised internally (log-softmax over the last
        # dim), so the sampled distribution is otherwise unchanged.
        a_t = RelaxedOneHotCategoricalStraightThrough(
            0.5, logits=scores
        ).rsample()

        return a_t