def sample(self):
    """
    Draws a single sample from the 2-D Gaussian mixture (no sample_shape).

    Every component is sampled as ``mus + L1 * z1 + L2 * z2`` with two
    independent standard-normal draws ``z1``/``z2`` per component, which is a
    hand-written 2x2 matrix-vector product (``L1``/``L2`` presumably hold the
    two columns of each component's covariance factor -- confirm against
    ``__init__``). A component index is then drawn from ``self.cat`` and used
    to pick one of the per-component draws.

    :return: Tensor of shape [..., 2]: the per-component draws
        ([..., GMM_c, 2]) reduced over the component axis by the one-hot mask.
    """
    # Noise draws, taken in the same order as before (first for L1, then L2);
    # the trailing unit axis broadcasts each scalar across both coordinates.
    z_first = torch.randn_like(self.corrs, device=self.device).unsqueeze(-1)
    z_second = torch.randn_like(self.corrs, device=self.device).unsqueeze(-1)

    # (manual 2x2 matmul) -> [..., GMM_c, 2]
    component_draws = self.mus + self.L1 * z_first + self.L2 * z_second

    chosen = self.cat.sample()  # [...]
    # NOTE(review): assumes to_one_hot returns a [..., GMM_c] one-hot tensor.
    mask = to_one_hot(chosen, self.GMM_c, self.device).unsqueeze(-1)
    # Zero out the non-selected components and collapse the component axis.
    return (component_draws * mask).sum(dim=-2)
def rsample(self, sample_shape=torch.Size()):
    """
    Generates a sample_shape shaped reparameterized sample or sample_shape
    shaped batch of reparameterized samples if the distribution parameters
    are batched.

    The Gaussian part of each draw is reparameterized as ``mus + L @ eps``
    with ``eps ~ N(0, I)``, so gradients can flow to ``self.mus`` and
    ``self.L``. The mixture component, however, is chosen by a discrete draw
    from ``self.pis_cat_dist``, so no gradient reaches the mixture weights
    through this method.

    :param sample_shape: Shape of the samples. Any sequence of ints accepted
        by ``torch.Size`` (``torch.Size``, tuple or list).
    :return: Samples from the GMM. Assuming ``self.mus`` is
        [..., components, D], the result has shape
        ``sample_shape + self.mus.shape[:-2] + self.mus.shape[-1:]``.
    """
    # Normalize so a list behaves like torch.Size / tuple; without this,
    # `list + torch.Size` below raises TypeError.
    sample_shape = torch.Size(sample_shape)

    # Standard-normal noise with the full parameter shape for every sample.
    eps = torch.randn(size=sample_shape + self.mus.shape, device=self.device)
    # Treat eps as column vectors for the batched matmul, then drop that axis:
    # mus + L @ eps -> sample_shape + mus.shape.
    # NOTE(review): L presumably is each component's covariance factor
    # (scale_tril), broadcastable to [..., components, D, D] -- confirm.
    mvn_samples = self.mus + torch.squeeze(
        torch.matmul(self.L, torch.unsqueeze(eps, dim=-1)), dim=-1
    )

    # Component indices are drawn after the Gaussian noise, preserving the
    # original RNG consumption order.
    component_cat_samples = self.pis_cat_dist.sample(sample_shape)
    # NOTE(review): assumes to_one_hot returns a one-hot tensor whose last
    # axis has size self.components.
    selector = torch.unsqueeze(
        to_one_hot(component_cat_samples, self.components), dim=-1
    )
    # Keep only the selected component's draw and collapse the component axis.
    return torch.sum(mvn_samples * selector, dim=-2)