def sample(
    self,
    distr: Distribution,
    parents: [storch.Tensor],
    plates: [Plate],
    requires_grad: bool,
) -> (storch.StochasticTensor, Plate):
    """
    Enumerate the support of `distr` and return it as a stochastic tensor on a new plate.

    Every combination of support values over the (conditionally independent) event dimensions becomes
    one entry along the new plate, so the amount of entries is `|support| ** (number of event elements)`.

    :param distr: The distribution to enumerate. Must have enumerable support.
    :param parents: Parents that are passed on to the created `storch.StochasticTensor`.
    :param plates: Plates of the distribution. Note: this list is modified in place, the newly created
    plate is inserted at index 0.
    :param requires_grad: Passed on to the created `storch.StochasticTensor`.
    :return: 2-tuple. 1: The stochastic tensor containing the (detached) enumerated values. 2: The new plate.
    :raises ValueError: If `distr` has no enumerable support, or if the amount of enumerated combinations
    is larger than `self.budget`.
    """
    # TODO: Currently very inefficient as it isn't batched
    # TODO: What if the expectation has a parent
    if not distr.has_enumerate_support:
        raise ValueError(
            "Can only calculate the expected value for distributions with enumerable support."
        )
    support: torch.Tensor = distr.enumerate_support(expand=True)
    support_non_expanded: torch.Tensor = distr.enumerate_support(expand=False)
    expect_size = support.shape[0]

    batch_len = len(plates)
    # Dimensions between the plate dimensions and the event_shape of the distribution
    sizes = support.shape[
        batch_len + 1 : len(support.shape) - len(distr.event_shape)
    ]

    # Amount of independent variables to take the cross product over. This is the *product* of the
    # sizes: it has to agree with the exponent used for amt_samples_used below, otherwise
    # itertools.product yields a different amount of tuples than enumerate_tensor has rows.
    cross_products = 1
    for dim in sizes:
        cross_products *= dim
    amt_samples_used = expect_size ** cross_products

    if amt_samples_used > self.budget:
        raise ValueError(
            "Computing the expectation on this distribution would exceed the computation budget."
        )

    enumerate_tensor = support.new_zeros(
        [amt_samples_used] + list(support.shape[1:])
    )
    support_non_expanded = support_non_expanded.squeeze().unsqueeze(1)
    # NOTE(review): torch.cat stacks the chosen values flat along dim 0. With more than one event
    # dimension, confirm that this still broadcasts into a row of shape support.shape[1:].
    for i, t in enumerate(
        itertools.product(support_non_expanded, repeat=cross_products)
    ):
        enumerate_tensor[i] = torch.cat(t, dim=0)

    enumerate_tensor = enumerate_tensor.detach()

    plate_size = enumerate_tensor.shape[0]

    # The new plate gets a copy of the plates as they were before it is inserted itself
    plate = Plate(self.plate_name, plate_size, plates.copy())
    plates.insert(0, plate)

    s_tensor = storch.StochasticTensor(
        enumerate_tensor,
        parents,
        plates,
        self.plate_name,
        plate_size,
        distr,
        requires_grad,
    )
    return s_tensor, plate
def decode(
    self,
    distr: Distribution,
    joint_log_probs: Optional[storch.Tensor],
    parents: [storch.Tensor],
    orig_distr_plates: [storch.Plate],
) -> (storch.Tensor, storch.Tensor, storch.Tensor):
    """
    Decode given the input arguments.

    Enumerates the support of `distr`, computes the log-probability of every support element, and then
    calls `self.decode_step` once for every index of the event dimensions. The indices chosen by those
    steps are finally used to gather the sample from the support.

    :param distr: The distribution to decode
    :param joint_log_probs: The log probabilities of the samples so far. prev_plates x amt_samples
    :param parents: List of parents of this tensor. Not used in the body of this method.
    :param orig_distr_plates: List of plates from the distribution. Can include the self plate k.
    :return: 3-tuple of `storch.Tensor`. 1: The sampled value. 2: The new joint log probabilities of the samples.
    3: How the samples index the parent samples. Can just be a range if there is no choosing happening.
    :raises ValueError: If `distr` does not have enumerable support.
    """
    ancestral_distrplate_index = -1
    is_conditional_sample = False

    # Split the plates with more than one element into "the plate of this method" (by name) and the rest.
    # multi_dim_index only counts plates with n > 1, so ancestral_distrplate_index is the position of
    # this method's plate among those plates.
    multi_dim_distr_plates = []
    multi_dim_index = 0
    for plate in orig_distr_plates:
        if plate.n > 1:
            if plate.name == self.plate_name:
                ancestral_distrplate_index = multi_dim_index
                is_conditional_sample = True
            else:
                multi_dim_distr_plates.append(plate)
            multi_dim_index += 1
    # plates? x k x events x

    # TODO: This doesn't properly combine two ancestral plates with the same name but different variable index
    #  (they should merge).
    all_multi_dim_plates = multi_dim_distr_plates.copy()
    if self.variable_index > 0:
        # Previous variables have been sampled. add the prev_plates to all_plates
        # Note: this reads self.joint_log_probs, not the joint_log_probs argument.
        for plate in self.joint_log_probs.multi_dim_plates():
            if plate not in multi_dim_distr_plates:
                all_multi_dim_plates.append(plate)

    amt_multi_dim_plates = len(all_multi_dim_plates)
    amt_multi_dim_distr_plates = len(multi_dim_distr_plates)
    # Same as amt_multi_dim_distr_plates, but also counts the k plate if it is present.
    amt_multi_dim_orig_distr_plates = amt_multi_dim_distr_plates + (
        1 if is_conditional_sample else 0
    )
    amt_multi_dim_prev_plates = amt_multi_dim_plates - amt_multi_dim_distr_plates
    if not distr.has_enumerate_support:
        raise ValueError("Can only decode distributions with enumerable support.")

    with storch.ignore_wrapping():
        # |D_yv| x (|distr_plates| + |k?| + |event_dims|) * (1,) x |D_yv|
        support_non_expanded: torch.Tensor = distr.enumerate_support(expand=False)
        # Compute the log-probability of the different events
        # |D_yv| x distr_plate[0] x ... k? ... x distr_plate[n-1] x events
        d_log_probs = distr.log_prob(support_non_expanded)

        # Note: Use amt_orig_distr_plates here because it might include k? dimension. amt_distr_plates filters this one.
        # The permutation moves the leading |D_yv| dimension to directly after the plate dimensions.
        # distr_plate[0] x ... k? ... x distr_plate[n-1] x |D_yv| x events
        d_log_probs = storch.Tensor(
            d_log_probs.permute(
                tuple(range(1, amt_multi_dim_orig_distr_plates + 1))
                + (0,)
                + tuple(
                    range(
                        amt_multi_dim_orig_distr_plates + 1, len(d_log_probs.shape)
                    )
                )
            ),
            [],
            orig_distr_plates,
        )

    # |D_yv| x distr_plate[0] x ... x k? x ... x distr_plate[n-1] x events x event_shape
    support = distr.enumerate_support(expand=True)

    if is_conditional_sample:
        # Reduce ancestral dimension in the support. As the dimension is just an expanded version, this should
        # not change the underlying data.
        # The "+ 1" skips the leading |D_yv| dimension; index 0 is taken along the k dimension.
        # |D_yv| x distr_plates x events x event_shape
        support = support[(slice(None),) * (ancestral_distrplate_index + 1) + (0,)]

        # Gather the correct log probabilities
        # distr_plate[0] x ... k ... x distr_plate[n-1] x |D_yv| x events
        # TODO: Move this down below to the other scary TODO
        d_log_probs = self.new_plate.on_unwrap_tensor(d_log_probs)
        # Permute the dimensions of d_log_probs st the k dimension is after the plates.
        for i, plate in enumerate(d_log_probs.multi_dim_plates()):
            if plate.name == self.plate_name:
                # The k plate is removed from the plate list; its dimension is kept in the raw tensor
                # and moved behind the remaining plate dimensions.
                d_log_probs.plates.remove(plate)
                # distr_plates x k x |D_yv| x events
                d_log_probs._tensor = d_log_probs._tensor.permute(
                    tuple(range(0, i))
                    + tuple(range(i + 1, amt_multi_dim_orig_distr_plates))
                    + (i,)
                    + tuple(
                        range(
                            amt_multi_dim_orig_distr_plates, len(d_log_probs.shape)
                        )
                    )
                )
                # Only the first plate with this name is handled.
                break

    # Equal to event_shape
    element_shape = distr.event_shape
    # Moves the leading |D_yv| dimension of the support to directly after the distribution plates.
    support_permutation = (
        tuple(range(1, amt_multi_dim_distr_plates + 1))
        + (0,)
        + tuple(range(amt_multi_dim_distr_plates + 1, len(support.shape)))
    )
    # distr_plates x |D_yv| x events x event_shape
    support = support.permute(support_permutation)

    if amt_multi_dim_plates != amt_multi_dim_distr_plates:
        # If previous samples had plate dimensions that are not in the distribution plates, add these to the support.
        # Indexing with None inserts one new axis per previous plate, directly after the distribution plates.
        support = support[
            (slice(None),) * amt_multi_dim_distr_plates
            + (None,) * amt_multi_dim_prev_plates
        ]
        all_plate_dims = tuple(map(lambda _p: _p.n, all_multi_dim_plates))
        # The -1 entries keep the size of the non-plate dimensions unchanged.
        # plates x |D_yv| x events x event_shape (where plates = distr_plates x prev_plates)
        support = support.expand(
            all_plate_dims + (-1,) * (len(support.shape) - amt_multi_dim_plates)
        )
    # plates x |D_yv| x events x event_shape
    support = storch.Tensor(support, [], all_multi_dim_plates)

    # Equal to events: Shape for the different conditional independent dimensions
    # The upper bound is None when element_shape is empty, because a bound of -0 would give an empty slice.
    event_shape = support.shape[
        amt_multi_dim_plates + 1 : -len(element_shape)
        if len(element_shape) > 0
        else None
    ]

    # One list of valid indices per event dimension; their cross product is iterated over below.
    ranges = []
    for size in event_shape:
        ranges.append(list(range(size)))

    amt_samples = 0
    parent_indexing = None
    if joint_log_probs is not None:
        # Initialize a tensor (self.parent_indexing) that keeps track of what samples link to previous choices of samples
        # Note that joint_log_probs.shape[-1] is amt_samples, not k. It's possible that amt_samples < k!
        amt_samples = joint_log_probs.shape[-1]
        # plates x k
        parent_indexing = support.new_zeros(
            size=support.shape[:amt_multi_dim_plates] + (self.k,), dtype=torch.long
        )

        # probably can go wrong if plates are missing.
        # The first amt_samples entries start out as 0..amt_samples-1; the remaining entries stay 0.
        parent_indexing[..., :amt_samples] = left_expand_as(
            torch.arange(amt_samples), parent_indexing
        )
    # plates x k x events
    sampled_support_indices = support.new_zeros(
        size=support.shape[:amt_multi_dim_plates]  # plates
        + (self.k,)
        + support.shape[
            amt_multi_dim_plates + 1 : -len(element_shape)
            if len(element_shape) > 0
            else None
        ],  # events
        dtype=torch.long,
    )
    # Sample independent tensors in sequence
    # Iterate over the different (conditionally) independent samples being taken (the events)
    for indices in itertools.product(*ranges):
        # Log probabilities of the different options for this sample step (event)
        # distr_plates x k? x |D_yv|
        yv_log_probs = d_log_probs[(...,) + indices]
        # Every step returns the updated state, which is fed into the next step.
        (
            sampled_support_indices,
            joint_log_probs,
            parent_indexing,
            amt_samples,
        ) = self.decode_step(
            indices,
            yv_log_probs,
            joint_log_probs,
            sampled_support_indices,
            parent_indexing,
            is_conditional_sample,
            amt_multi_dim_plates,
            amt_samples,
        )
    # Finally, index the support using the sampled indices to get the sample!
    if amt_samples < self.k:
        # Fewer than k samples were taken: drop the unused entries of the k dimension.
        # plates x amt_samples x events
        sampled_support_indices = sampled_support_indices[
            (...,) + (slice(amt_samples),) + (slice(None),) * len(ranges)
        ]
    expanded_indices = right_expand_as(sampled_support_indices, support)
    # Gather along the |D_yv| dimension, which directly follows the plate dimensions.
    sample = support.gather(dim=amt_multi_dim_plates, index=expanded_indices)
    return sample, joint_log_probs, parent_indexing