def _sample(self, distribution: dist.Distribution, sample_shape: Union[torch.Size, tuple] = torch.Size()): if self.training: return distribution.rsample(sample_shape=sample_shape) else: return distribution.sample(sample_shape=sample_shape)
def reparam_sample(
    self,
    distr: Distribution,
    parents: [storch.Tensor],
    plates: [Plate],
    amt_samples: int,
):
    """Draw ``amt_samples`` reparameterized samples from ``distr``.

    Args:
        distr: The distribution to sample from. Its ``has_rsample`` must be
            truthy.
        parents: Not used by this implementation; kept for interface
            compatibility.
        plates: Not used by this implementation; kept for interface
            compatibility.
        amt_samples: Number of samples, passed to ``rsample`` as a 1-element
            sample shape.

    Returns:
        The result of ``distr.rsample((amt_samples,))``.

    Raises:
        ValueError: If ``distr.has_rsample`` is falsy.
    """
    if distr.has_rsample:
        # A one-element shape tuple adds a single leading sample dimension.
        return distr.rsample((amt_samples,))
    # Without rsample there is no reparameterized path to differentiate through.
    raise ValueError(
        "The input distribution has not implemented rsample. If you use a discrete "
        "distribution, make sure to use eg GumbelSoftmax."
    )