Esempio n. 1
0
def test_onehot_shapes(probs):
    """Check that the gradient of the summed log-prob matches ``probs`` in shape.

    ``probs`` is converted to a tensor that tracks gradients, a straight-through
    relaxed one-hot categorical is built from it at temperature 0.5, one
    reparameterized sample is drawn, and the summed log-probability of that
    sample is differentiated with respect to the probabilities.
    """
    # NOTE(review): `grad` is presumably torch.autograd.grad -- confirm import.
    probs_tensor = torch.tensor(probs, requires_grad=True)
    dist = RelaxedOneHotCategoricalStraightThrough(
        torch.tensor(0.5), probs=probs_tensor
    )
    total_log_prob = dist.log_prob(dist.rsample()).sum()
    grad_probs = grad(total_log_prob, [probs_tensor])[0]
    assert grad_probs.shape == probs_tensor.shape
Esempio n. 2
0
def rsample_gumbel_softmax(
    distr: Distribution,
    n: int,
    temperature: torch.Tensor,
    straight_through: bool = False,
) -> torch.Tensor:
    """Draw ``n`` reparameterized samples from a relaxed version of ``distr``.

    The discrete distribution is replaced by its relaxed counterpart, built
    from ``distr.probs`` and ``temperature``, and sampled with ``rsample``.

    Args:
        distr: A ``Categorical``, ``OneHotCategorical`` or ``Bernoulli``
            distribution.
        n: Number of samples; used as the sample shape ``(n,)``.
        temperature: Temperature passed to the relaxed distribution.
        straight_through: When true, use the straight-through variant of the
            relaxed distribution instead of the plain one.

    Returns:
        The tensor returned by ``rsample((n,))`` on the relaxed distribution.

    Raises:
        ValueError: If ``distr`` is none of the supported discrete types.
    """
    # Select the relaxed class first; the conditional expression only looks up
    # the class that is actually needed.
    if isinstance(distr, (Categorical, OneHotCategorical)):
        relaxed_cls = (
            RelaxedOneHotCategoricalStraightThrough
            if straight_through
            else RelaxedOneHotCategorical
        )
    elif isinstance(distr, Bernoulli):
        relaxed_cls = (
            RelaxedBernoulliStraightThrough
            if straight_through
            else RelaxedBernoulli
        )
    else:
        raise ValueError("Using Gumbel Softmax with non-discrete distribution")

    relaxed = relaxed_cls(temperature, probs=distr.probs)
    sample_shape = (n, )
    return relaxed.rsample(sample_shape)