import pytest
import torch
from torch.autograd import grad

# The straight-through relaxed distribution is assumed to come from pyro.distributions;
# torch.distributions only ships RelaxedOneHotCategorical.
from pyro.distributions import RelaxedOneHotCategoricalStraightThrough


# Example parametrization (assumed); the original decorator/fixture supplying `probs`
# is not shown in this excerpt.
@pytest.mark.parametrize("probs", [[0.1, 0.2, 0.3, 0.4]])
def test_onehot_shapes(probs):
    temperature = torch.tensor(0.5)
    probs = torch.tensor(probs, requires_grad=True)
    d = RelaxedOneHotCategoricalStraightThrough(temperature, probs=probs)
    sample = d.rsample()
    log_prob = d.log_prob(sample)

    # The gradient of the log-probability w.r.t. the parameters must have
    # the same shape as the parameters themselves.
    grad_probs = grad(log_prob.sum(), [probs])[0]
    assert grad_probs.shape == probs.shape
import torch
from torch.distributions import (
    Bernoulli,
    Categorical,
    Distribution,
    OneHotCategorical,
    RelaxedBernoulli,
    RelaxedOneHotCategorical,
)

# Straight-through variants are assumed to come from pyro.distributions;
# torch.distributions does not provide them.
from pyro.distributions import (
    RelaxedBernoulliStraightThrough,
    RelaxedOneHotCategoricalStraightThrough,
)


def rsample_gumbel_softmax(
    distr: Distribution,
    n: int,
    temperature: torch.Tensor,
    straight_through: bool = False,
) -> torch.Tensor:
    """Draw `n` reparameterized Gumbel-Softmax samples from a discrete distribution."""
    if isinstance(distr, (Categorical, OneHotCategorical)):
        if straight_through:
            gumbel_distr = RelaxedOneHotCategoricalStraightThrough(
                temperature, probs=distr.probs
            )
        else:
            gumbel_distr = RelaxedOneHotCategorical(temperature, probs=distr.probs)
    elif isinstance(distr, Bernoulli):
        if straight_through:
            gumbel_distr = RelaxedBernoulliStraightThrough(
                temperature, probs=distr.probs
            )
        else:
            gumbel_distr = RelaxedBernoulli(temperature, probs=distr.probs)
    else:
        raise ValueError("Using Gumbel Softmax with a non-discrete distribution")
    return gumbel_distr.rsample((n,))
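

# Minimal usage sketch (assumed setup, not part of the original module): draw
# straight-through Gumbel-Softmax samples from a Categorical and check that
# gradients flow back to the logits through the reparameterized samples.
if __name__ == "__main__":
    logits = torch.randn(4, requires_grad=True)
    base = Categorical(logits=logits)
    samples = rsample_gumbel_softmax(
        base, n=8, temperature=torch.tensor(0.5), straight_through=True
    )
    # Straight-through samples are one-hot in the forward pass but remain
    # differentiable w.r.t. the underlying probabilities in the backward pass.
    assert samples.shape == (8, 4)
    samples.sum().backward()
    assert logits.grad is not None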