def sample(
    self,
    distr: Distribution,
    parents: [storch.Tensor],
    plates: [Plate],
    requires_grad: bool,
) -> (storch.StochasticTensor, Plate):
    # Reuse the plate for this sample statement if the inputs already carry it
    plate = None
    for _plate in plates:
        if _plate.name == self.plate_name:
            plate = _plate
            break
    # If the plate already exists, draw a single sample within it;
    # otherwise draw n_samples to create it
    n_samples = 1 if plate else self.n_samples
    with storch.ignore_wrapping():
        tensor = self.mc_sample(distr, parents, plates, n_samples)
    plate_size = tensor.shape[0]
    # Drop the sample dimension if only a single sample was drawn
    if tensor.shape[0] == 1:
        tensor = tensor.squeeze(0)
    if not plate:
        plate = Plate(self.plate_name, plate_size, plates.copy())
        plates.insert(0, plate)
    # Unwrap in case the sampling routine returned a wrapped tensor
    if isinstance(tensor, storch.Tensor):
        tensor = tensor._tensor
    s_tensor = storch.StochasticTensor(
        tensor,
        parents,
        plates,
        self.plate_name,
        plate_size,
        distr,
        requires_grad,
    )
    return s_tensor, plate
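# --- Usage sketch (illustrative, not library source) ------------------------
# The Monte Carlo sampler above backs the plain gradient-estimation methods.
# Assuming the public API shown in the storch README (`storch.method.ScoreFunction`
# here; constructor defaults may differ), drawing n_samples under a new plate
# looks roughly like:

import torch
import storch
from torch.distributions import Bernoulli
from storch.method import ScoreFunction

p = torch.tensor(0.6, requires_grad=True)
method = ScoreFunction("b", n_samples=4)  # samples land on plate "b" of size 4
b = method(Bernoulli(p))                  # StochasticTensor with plate "b"
storch.add_cost(b * 2.0, "cost")          # cost is scalar per plate index
storch.backward()                         # MC estimate of d E[2b]/dp = 2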
def sample(
    self,
    distr: Distribution,
    parents: [storch.Tensor],
    plates: [Plate],
    requires_grad: bool,
) -> (storch.StochasticTensor, Plate):
    # TODO: Currently very inefficient as it isn't batched
    # TODO: What if the expectation has a parent
    if not distr.has_enumerate_support:
        raise ValueError(
            "Can only calculate the expected value for distributions with enumerable support."
        )
    support: torch.Tensor = distr.enumerate_support(expand=True)
    support_non_expanded: torch.Tensor = distr.enumerate_support(expand=False)
    expect_size = support.shape[0]

    batch_len = len(plates)
    # The conditionally independent event dimensions sit between the plate
    # dimensions and the distribution's event shape
    sizes = support.shape[
        batch_len + 1 : len(support.shape) - len(distr.event_shape)
    ]
    # Enumerating jointly over all event dimensions requires the cross
    # product of the support
    amt_samples_used = expect_size
    cross_products = 1 if not sizes else None
    for dim in sizes:
        amt_samples_used = amt_samples_used ** dim
        if not cross_products:
            cross_products = dim
        else:
            cross_products = cross_products ** dim

    if amt_samples_used > self.budget:
        raise ValueError(
            "Computing the expectation on this distribution would exceed the computation budget."
        )

    enumerate_tensor = support.new_zeros(
        [amt_samples_used] + list(support.shape[1:])
    )
    support_non_expanded = support_non_expanded.squeeze().unsqueeze(1)
    # Fill each row with one combination of support values over the event dimensions
    for i, t in enumerate(
        itertools.product(support_non_expanded, repeat=cross_products)
    ):
        enumerate_tensor[i] = torch.cat(t, dim=0)

    # The enumerated values themselves carry no gradient
    enumerate_tensor = enumerate_tensor.detach()

    plate_size = enumerate_tensor.shape[0]

    plate = Plate(self.plate_name, plate_size, plates.copy())
    plates.insert(0, plate)

    s_tensor = storch.StochasticTensor(
        enumerate_tensor,
        parents,
        plates,
        self.plate_name,
        plate_size,
        distr,
        requires_grad,
    )
    return s_tensor, plate
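# --- Usage sketch (illustrative, not library source) ------------------------
# The enumeration sampler above is what `storch.method.Expect` builds on: the
# "sample" is the full support, so downstream expectations and gradients are
# exact rather than estimated. Assuming the README-level API (constructor
# details may differ):

import torch
import storch
from torch.distributions import Bernoulli
from storch.method import Expect

p = torch.tensor(0.4, requires_grad=True)
method = Expect("b")
b = method(Bernoulli(p))          # plate "b" of size 2: the support {0, 1}
storch.add_cost(b * 3.0, "cost")
storch.backward()                 # exact gradient: d E[3b]/dp = 3
print(p.grad)                     # expected: tensor(3.)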
def sample(
    self,
    distr: Distribution,
    parents: [storch.Tensor],
    orig_distr_plates: [storch.Plate],
    requires_grad: bool,
) -> (torch.Tensor, storch.Plate):
    # Draw hard samples without replacement via the parent sampler
    hard_sample, plate = super().sample(
        distr, parents, orig_distr_plates, requires_grad
    )
    from storch import conditional_gumbel_rsample

    # Reparameterized Gumbel-Softmax relaxation, conditioned on the hard sample
    gumbel_wor = conditional_gumbel_rsample(
        hard_sample,
        distr.probs,
        isinstance(distr, torch.distributions.Bernoulli),
        self.temperature,
    )
    # Rewrap as a stochastic tensor that shares the hard sample's parents and plates
    gumbel_wor = storch.StochasticTensor(
        gumbel_wor._tensor,
        list(zip(*hard_sample._parents))[0],
        hard_sample.plates,
        hard_sample.name,
        self.k,
        distr,
        requires_grad,
    )
    return gumbel_wor, plate
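# --- Background sketch (plain PyTorch, independent of storch) ----------------
# `conditional_gumbel_rsample` above relaxes hard without-replacement samples.
# The underlying trick is Gumbel-top-k: perturbing log-probabilities with
# i.i.d. Gumbel(0, 1) noise and taking the k largest yields k distinct indices,
# i.e. a sample without replacement. A minimal illustration:

import torch

def gumbel_top_k(log_probs: torch.Tensor, k: int) -> torch.Tensor:
    # Gumbel(0, 1) noise via inverse transform sampling; clamp guards log(0)
    u = torch.rand_like(log_probs).clamp_min(1e-20)
    gumbels = -torch.log(-torch.log(u))
    # The k largest perturbed log-probs form a without-replacement sample
    return torch.topk(log_probs + gumbels, k).indices

log_probs = torch.log_softmax(torch.randn(10), dim=-1)
print(gumbel_top_k(log_probs, k=3))  # three distinct indices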
def sample(
    self,
    distr: Distribution,
    parents: [storch.Tensor],
    orig_distr_plates: [storch.Plate],
    requires_grad: bool,
) -> (torch.Tensor, storch.Plate):
    """
    Sample from the distribution given the sequence so far.
    :param distr: The distribution to sample from
    :return: The sampled stochastic tensor and the updated ancestral plate
    """
    # This code has three parts.
    # The first prepares all necessary tensors to make sure they can be easily indexed.
    # This part is quite long as there are many cases:
    # 1) No variables have been sampled so far.
    # 2) Variables have been sampled, but their results are NOT used to compute the input distribution.
    #    In other words, the variable to sample is independent of the other sampled variables. However,
    #    we should still keep track of the other sampled variables to make sure that it still samples without
    #    replacement properly. In this case, the ancestral plate is not in the plates attribute.
    #    We also have to make sure that we know in the future what samples are chosen for the _other_ samples.
    # 3) Parents have been sampled, and this variable is dependent on at least some of them.
    #    The plates list then contains the ancestral plate. We need to make sure we compute the joint log probs
    #    for the conditional samples (i.e., based on the different sampled variables in the ancestral dimension).
    # The second part is a loop over all options for the event dimensions, sampling these conditionally
    # independent samples in sequence. It samples indices, not events.
    # The third part, after the loop, takes the sampled indices and matches them to the events to be used.

    # LEGEND FOR SHAPE COMMENTS
    # =========================
    # To make this code generalize to every Bayesian network, complicated shape management is necessary.
    # The following are references to the shapes that are used within this method:
    #
    # distr_plates:      the plates on the parameters of the distribution. Does *not* include
    #                    the k? ancestral plate (possibly empty).
    # orig_distr_plates: the plates on the parameters of the distribution; *can* include
    #                    the k? ancestral plate (possibly empty).
    # prev_plates:       the plates of the previously sampled variable in this sample-without-replacement
    #                    (SWOR) sequence (possibly empty).
    # plates:            all plates except this ancestral plate, of which there are amt_plates. The first plates
    #                    are the distr_plates, followed by the prev_plates that are _not_ in distr_plates.
    #                    It is composed of distr_plates x (ancestral_plates - distr_plates).
    # events:            the conditionally independent dimensions of the distribution (the distribution's
    #                    batch shape minus the plates).
    # k:                 self.k.
    # k?:                an optional plate dimension for this ancestral plate. It either doesn't exist, or is
    #                    the sample dimension. If it exists, this sample is conditionally dependent on ancestors.
    # |D_yv|:            the *size* of the domain.
    # amt_samples:       the current amount of sampled sequences. amt_samples <= k, but it can be lower if there
    #                    are not enough events to sample from (e.g. |D_yv| < k).
    # event_shape:       the *shape* of the domain elements
    #                    (can be 0, e.g. Categorical, or equal to |D_yv| for OneHotCategorical).

    # Do the decoding step given the prepared tensors
    samples, self.joint_log_probs, self.parent_indexing = self.decode(
        distr, self.joint_log_probs, parents, orig_distr_plates
    )

    # Find out which sequences have reached the EOS token, and make sure to always sample EOS after that.
    # Does not contain the ancestral plate, as this uses samples instead of s_tensor.
    if self.eos:
        self.finished_samples = samples.eq(self.eos)

    k_index = 0
    plates = orig_distr_plates
    if isinstance(samples, storch.Tensor):
        k_index = samples.plate_dims
        plates = samples.plates
        samples = samples._tensor
    plate_size = samples.shape[k_index]

    # Remove the ancestral plate if it already happens to be present
    to_remove = None
    for plate in plates:
        if plate.name == self.plate_name:
            to_remove = plate
            break
    if to_remove:
        plates.remove(to_remove)

    # Create the newly updated plate
    self.new_plate = self.create_plate(plate_size, plates.copy())
    plates.append(self.new_plate)

    if self.parent_indexing is not None:
        self.parent_indexing.plates.append(self.new_plate)

    # Re-register the earlier tensors in the sequence on the updated plate
    self.seq = list(map(lambda t: self.new_plate.on_unwrap_tensor(t), self.seq))

    # Construct the stochastic tensor
    s_tensor = storch.StochasticTensor(
        samples,
        parents,
        plates,
        self.plate_name,
        plate_size,
        distr,
        requires_grad or self.joint_log_probs.requires_grad,
    )
    self.seq.append(s_tensor)

    # Increase variable index
    self.variable_index += 1
    return s_tensor, self.new_plate
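# --- Usage sketch (illustrative, not library source) ------------------------
# A plausible driver for the sequence sampler above. Assumptions: the unordered
# set estimator from the storch paper (`storch.method.UnorderedSetEstimator`)
# delegates to this sampling-without-replacement scheme, a single instance may
# be reused across sequence steps (as `self.variable_index`/`self.seq` suggest),
# and a storch tensor may parameterize a later distribution (case 3 in the
# comments above); exact names and constructor arguments may differ.

import torch
import storch
from torch.distributions import OneHotCategorical
from storch.method import UnorderedSetEstimator

method = UnorderedSetEstimator("z", k=3)        # keep 3 distinct sequences
logits = torch.randn(4, requires_grad=True)
z1 = method(OneHotCategorical(logits=logits))   # case 1: no ancestors yet

w = torch.randn(4, requires_grad=True)
z2 = method(OneHotCategorical(logits=z1 * w))   # case 3: depends on z1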