def hook(*args):
    # *args appears to be needed because hooks registered on a .grad_fn receive
    # tuple arguments (grad_inputs, grad_outputs), while hooks registered on a
    # Tensor receive a single gradient.
    # TODO: I'm sure there could be something wrong here
    grad = args[-1]
    if isinstance(grad, tuple):
        grad = grad[0]
    if name in accum_grads:
        accum_grads[name] = storch.Tensor(
            accum_grads[name]._tensor + grad, [], plates, name + "_grad"
        )
    else:
        accum_grads[name] = storch.Tensor(grad, [], plates, name + "_grad")
def _handle_inputs(
    tensor: AnyTensor, plates: Optional[_plates],
) -> (storch.Tensor, List[storch.Plate]):
    if isinstance(plates, storch.Plate):
        plates = [plates]
    if not isinstance(tensor, storch.Tensor):
        if not plates:
            raise ValueError("Make sure to pass plates when passing a torch.Tensor.")
        index_tensor = 0
        for plate in plates:
            if not isinstance(plate, storch.Plate):
                raise ValueError(
                    "Cannot handle plate names when passing a torch.Tensor."
                )
            if plate.n > 1:
                if tensor.shape[index_tensor] != plate.n:
                    raise ValueError(
                        "Received a tensor that does not align with the given plates."
                    )
                index_tensor += 1
        return storch.Tensor(tensor, [], plates), plates
    if not plates:
        return tensor, tensor.plates
    if isinstance(plates, str):
        return tensor, [tensor.get_plate(plates)]
    r_plates = []
    for plate in plates:
        if isinstance(plate, storch.Plate):
            r_plates.append(plate)
        else:
            r_plates.append(tensor.get_plate(plate))
    return tensor, r_plates
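def _example_handle_inputs():
    # Usage sketch, not from the source. A plain torch.Tensor must come with
    # storch.Plate objects whose sizes match its leading dimensions; a
    # storch.Tensor may refer to its plates by name instead. Assumes the
    # storch.Plate(name, n) constructor.
    batch = storch.Plate("batch", 32)
    wrapped, plates = _handle_inputs(torch.randn(32, 10), [batch])
    same, r_plates = _handle_inputs(wrapped, ["batch"])
    return same, r_plates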
def unique(tensor: storch.Tensor, event_dim: Optional[int] = 0) -> storch.Tensor:
    with storch.ignore_wrapping():
        fl_tensor = torch.flatten(tensor, tensor.plate_dims)
        uniq, inverse_indexing = torch.unique(
            fl_tensor, return_inverse=True, dim=event_dim
        )
    inverse_indexing = storch.Tensor(
        inverse_indexing, [tensor], tensor.plates, "inv_index_" + tensor.name
    )
    uq_plate = UniquePlate(
        "uq_plate_" + tensor.name,
        uniq.shape[0],
        tensor.multi_dim_plates(),
        inverse_indexing,
    )
    return storch.Tensor(uniq, [tensor], [uq_plate], "unique_" + tensor.name)
def undo_unique(self, unique_tensor: storch.Tensor) -> torch.Tensor:
    """
    Convert the unique tensor back to the non-unique format, then add the old plates back in.
    # TODO: Make sure self.shrunken_plates is added
    # TODO: What if unique_tensor contains new plates after the unique?
    :param unique_tensor: The tensor indexed by this unique plate to convert back.
    :return: The converted tensor with the shrunken plates expanded again.
    """
    plate_idx = unique_tensor.get_plate_dim_index(self.name)
    with storch.ignore_wrapping():
        dim_swapped = unique_tensor.transpose(
            plate_idx, unique_tensor.plate_dims - 1
        )
        fl_selected = torch.index_select(
            dim_swapped, dim=0, index=self.inv_indexing
        )
    selected = fl_selected.reshape(
        tuple(map(lambda p: p.n, self.shrunken_plates)) + fl_selected.shape[1:]
    )
    return storch.Tensor(
        selected,
        [unique_tensor],
        self.shrunken_plates + unique_tensor.plates,
        "undo_unique_" + unique_tensor.name,
    )
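def _example_unique_roundtrip():
    # Usage sketch, not from the source: deduplicate along a plate, compute on
    # the smaller tensor, then map results back per sample. Constructs the
    # plated tensor directly; assumes the storch.Plate(name, n) constructor and
    # that torch ops broadcast over storch Tensors as elsewhere in the library.
    plate = storch.Plate("samples", 4)
    x = storch.Tensor(
        torch.tensor([[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]),
        [], [plate], "x",
    )
    ux = unique(x)            # two unique rows -> UniquePlate of size 2
    y = torch.sin(ux)         # computed once per unique sample
    uq_plate = ux.plates[0]
    return uq_plate.undo_unique(y)  # expanded back: one row per original sample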
def create_plate(self, plate_size: int, plates: [storch.Plate]) -> AncestralPlate:
    plate = super().create_plate(plate_size, plates)
    plate.perturb_log_probs = storch.Tensor(
        self.perturbed_log_probs._tensor,
        [self.perturbed_log_probs],
        self.perturbed_log_probs.plates + [plate],
    )
    return plate
def grad(
    outputs,
    inputs,
    grad_outputs=None,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    only_inputs: bool = True,
    allow_unused: bool = False,
) -> Tuple[storch.Tensor, ...]:
    """
    Helper method for computing torch.autograd.grad on storch tensors.
    Returns storch Tensors as well.
    """
    args, _, _, _ = storch.wrappers._prepare_args(
        [outputs, inputs, grad_outputs], {}, unwrap=True, align_tensors=False,
    )
    _outputs, _inputs, _grad_outputs = tuple(args)
    grads = torch.autograd.grad(
        _outputs,
        _inputs,
        grad_outputs=_grad_outputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        only_inputs=only_inputs,
        allow_unused=allow_unused,
    )
    storch_grad = []
    for i, grad in enumerate(grads):
        input = inputs[i]
        if isinstance(input, storch.Tensor):
            storch_grad.append(
                storch.Tensor(
                    grad, outputs + [input], input.plates, input.name + "_grad"
                )
            )
        else:
            storch_grad.append(storch.Tensor(grad, outputs, [], "grad"))
    return tuple(storch_grad)
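def _example_grad():
    # Usage sketch, not from the source (names are illustrative): differentiate
    # a plated output with respect to a plated input and get storch Tensors
    # back, with plates and names propagated from the inputs.
    plate = storch.Plate("batch", 3)
    x = storch.Tensor(torch.randn(3, requires_grad=True), [], [plate], "x")
    y = x * x
    (dx,) = grad([y], [x], grad_outputs=[torch.ones_like(y._tensor)])
    return dx  # named "x_grad" and carrying the same "batch" plate as x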
def _process_stochastic(
    output: torch.Tensor, parents: [storch.Tensor], plates: [storch.Plate]
):
    if isinstance(output, storch.Tensor):
        if not output.stochastic:
            # TODO: Calls _add_parents so something is going wrong here
            # The Tensor was created by calling @deterministic within a stochastic
            # context. This means we have to conservatively assume it is dependent
            # on the parents.
            output._add_parents(storch.wrappers._stochastic_parents)
        return output
    if isinstance(output, torch.Tensor):
        t = storch.Tensor(output, parents, plates)
        return t
    else:
        raise TypeError(
            "All outputs of functions wrapped in @storch.stochastic "
            "should be Tensors. Got: " + str(output)
        )
def _prepare_outputs_det(
    o: Any,
    parents: [storch.Tensor],
    plates: [storch.Plate],
    name: str,
    index: int,
    unflatten_plates,
):
    if o is None:
        return None, index
    if isinstance(o, storch.Tensor):
        if o.stochastic:
            raise RuntimeError(
                "Creation of stochastic storch Tensor within deterministic context"
            )
        # TODO: Does this require shape checking? Parent/Plate checking?
        return o, index + 1
    if isinstance(o, torch.Tensor):  # Explicitly _not_ a storch.Tensor
        if unflatten_plates:
            plate_dims = tuple([plate.n for plate in plates if plate.n > 1])
            o = o.reshape(plate_dims + o.shape[1:])
        t = storch.Tensor(o, parents, plates, name=name + str(index))
        return t, index + 1
    if is_iterable(o):
        outputs = []
        for _o in o:
            t, index = _prepare_outputs_det(
                _o, parents, plates, name, index, unflatten_plates=unflatten_plates
            )
            outputs.append(t)
        if isinstance(o, tuple):
            return tuple(outputs), index
        return outputs, index
    raise NotImplementedError(
        "Handling of other types of return values is currently not implemented: "
        + str(o)
    )
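def _example_deterministic_outputs():
    # Context sketch, not from the source: @storch.deterministic (assumed public
    # API) wraps a function so its torch.Tensor outputs are re-wrapped as storch
    # Tensors via _prepare_outputs_det, including nested tuple/list outputs.
    @storch.deterministic
    def f(x):
        return x * 2, x + 1  # tuple outputs are unpacked and wrapped recursively

    plate = storch.Plate("batch", 3)
    x = storch.Tensor(torch.randn(3), [], [plate], "x")
    doubled, incremented = f(x)
    return doubled, incremented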
def estimator(
    self, tensor: storch.StochasticTensor, cost_node: storch.CostTensor
) -> Tuple[Optional[storch.Tensor], Optional[storch.Tensor]]:
    # Note: We automatically multiply with the leave-one-out ratio in the plate reduction
    plate = None
    for _p in cost_node.plates:
        if _p.name == self.plate_name:
            plate = _p
            break
    if not self.use_baseline:
        # Without a baseline, only the multiplicative (score function) term is needed.
        return plate.log_probs, None
    # Subtract the 'average' cost of other samples, keeping in mind that samples
    # are not independent.
    # plates x k
    baseline = storch.sum(
        (plate.log_probs + plate.log_snd_leave_one_out).exp() * cost_node, plate
    )
    # Make sure the k dimension is recognized as a batch dimension
    baseline = storch.Tensor(
        baseline._tensor, [baseline], baseline.plates + [plate]
    )
    return plate.log_probs, (1 - magic_box(plate.log_probs)) * baseline.detach()
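# For reference, a sketch of the DiCE magic-box operator used above (Foerster et
# al., 2018); storch's own magic_box may differ in details. It evaluates to 1 in
# the forward pass, while its gradient is the score-function term:
# d/dtheta exp(l - stop_grad(l)) = 1 * dl/dtheta.
def _magic_box_sketch(l: torch.Tensor) -> torch.Tensor:
    return (l - l.detach()).exp()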
def __init__(
    self,
    name: str,
    n: int,
    parents: List[storch.Plate],
    variable_index: int,
    parent_plate: AncestralPlate,
    selected_samples: storch.Tensor,
    log_probs: storch.Tensor,
    weight: Optional[storch.Tensor] = None,
):
    super().__init__(name, n, parents, weight)
    assert (not parent_plate and variable_index == 0) or (
        parent_plate.n <= self.n and parent_plate.variable_index < variable_index
    )
    self.parent_plate = parent_plate
    self.selected_samples = selected_samples
    self.log_probs = storch.Tensor(
        log_probs._tensor, [log_probs], log_probs.plates + [self]
    )
    self.variable_index = variable_index
    self._in_recursion = False
    self._override_equality = False
def decode(
    self,
    distr: Distribution,
    joint_log_probs: Optional[storch.Tensor],
    parents: [storch.Tensor],
    orig_distr_plates: [storch.Plate],
) -> (storch.Tensor, storch.Tensor, storch.Tensor):
    """
    Decode given the input arguments.
    :param distr: The distribution to decode.
    :param joint_log_probs: The log probabilities of the samples so far. prev_plates x amt_samples
    :param parents: List of parents of this tensor.
    :param orig_distr_plates: List of plates from the distribution. Can include the self plate k.
    :return: 3-tuple of `storch.Tensor`. 1: The sampled value. 2: The new joint log probabilities
        of the samples. 3: How the samples index the parent samples. Can just be a range if there
        is no choosing happening.
    """
    ancestral_distrplate_index = -1
    is_conditional_sample = False

    multi_dim_distr_plates = []
    multi_dim_index = 0
    for plate in orig_distr_plates:
        if plate.n > 1:
            if plate.name == self.plate_name:
                ancestral_distrplate_index = multi_dim_index
                is_conditional_sample = True
            else:
                multi_dim_distr_plates.append(plate)
            multi_dim_index += 1
    # plates? x k x events x
    # TODO: This doesn't properly combine two ancestral plates with the same name
    #  but different variable index (they should merge).
    all_multi_dim_plates = multi_dim_distr_plates.copy()
    if self.variable_index > 0:
        # Previous variables have been sampled. Add the prev_plates to all_plates.
        for plate in self.joint_log_probs.multi_dim_plates():
            if plate not in multi_dim_distr_plates:
                all_multi_dim_plates.append(plate)

    amt_multi_dim_plates = len(all_multi_dim_plates)
    amt_multi_dim_distr_plates = len(multi_dim_distr_plates)
    amt_multi_dim_orig_distr_plates = amt_multi_dim_distr_plates + (
        1 if is_conditional_sample else 0
    )
    amt_multi_dim_prev_plates = amt_multi_dim_plates - amt_multi_dim_distr_plates

    if not distr.has_enumerate_support:
        raise ValueError("Can only decode distributions with enumerable support.")

    with storch.ignore_wrapping():
        # |D_yv| x (|distr_plates| + |k?| + |event_dims|) * (1,) x |D_yv|
        support_non_expanded: torch.Tensor = distr.enumerate_support(expand=False)
        # Compute the log-probability of the different events
        # |D_yv| x distr_plate[0] x ... k? ... x distr_plate[n-1] x events
        d_log_probs = distr.log_prob(support_non_expanded)

    # Note: Use amt_multi_dim_orig_distr_plates here because it might include the
    # k? dimension; amt_multi_dim_distr_plates filters this one out.
    # distr_plate[0] x ... k? ... x distr_plate[n-1] x |D_yv| x events
    d_log_probs = storch.Tensor(
        d_log_probs.permute(
            tuple(range(1, amt_multi_dim_orig_distr_plates + 1))
            + (0,)
            + tuple(
                range(amt_multi_dim_orig_distr_plates + 1, len(d_log_probs.shape))
            )
        ),
        [],
        orig_distr_plates,
    )

    # |D_yv| x distr_plate[0] x ... x k? x ... x distr_plate[n-1] x events x event_shape
    support = distr.enumerate_support(expand=True)

    if is_conditional_sample:
        # Reduce the ancestral dimension in the support. As the dimension is just
        # an expanded version, this should not change the underlying data.
        # |D_yv| x distr_plates x events x event_shape
        support = support[(slice(None),) * (ancestral_distrplate_index + 1) + (0,)]

        # Gather the correct log probabilities
        # distr_plate[0] x ... k ... x distr_plate[n-1] x |D_yv| x events
        # TODO: Move this down below to the other scary TODO
        d_log_probs = self.new_plate.on_unwrap_tensor(d_log_probs)
        # Permute the dimensions of d_log_probs such that the k dimension comes
        # after the plates.
        for i, plate in enumerate(d_log_probs.multi_dim_plates()):
            if plate.name == self.plate_name:
                d_log_probs.plates.remove(plate)
                # distr_plates x k x |D_yv| x events
                d_log_probs._tensor = d_log_probs._tensor.permute(
                    tuple(range(0, i))
                    + tuple(range(i + 1, amt_multi_dim_orig_distr_plates))
                    + (i,)
                    + tuple(
                        range(
                            amt_multi_dim_orig_distr_plates, len(d_log_probs.shape)
                        )
                    )
                )
                break

    # Equal to event_shape
    element_shape = distr.event_shape
    support_permutation = (
        tuple(range(1, amt_multi_dim_distr_plates + 1))
        + (0,)
        + tuple(range(amt_multi_dim_distr_plates + 1, len(support.shape)))
    )
    # distr_plates x |D_yv| x events x event_shape
    support = support.permute(support_permutation)

    if amt_multi_dim_plates != amt_multi_dim_distr_plates:
        # If previous samples had plate dimensions that are not in the
        # distribution plates, add these to the support.
        support = support[
            (slice(None),) * amt_multi_dim_distr_plates
            + (None,) * amt_multi_dim_prev_plates
        ]

        all_plate_dims = tuple(map(lambda _p: _p.n, all_multi_dim_plates))
        # plates x |D_yv| x events x event_shape (where plates = distr_plates x prev_plates)
        support = support.expand(
            all_plate_dims + (-1,) * (len(support.shape) - amt_multi_dim_plates)
        )
    # plates x |D_yv| x events x event_shape
    support = storch.Tensor(support, [], all_multi_dim_plates)

    # Equal to events: shape for the different conditionally independent dimensions
    event_shape = support.shape[
        amt_multi_dim_plates
        + 1 : -len(element_shape) if len(element_shape) > 0 else None
    ]

    ranges = []
    for size in event_shape:
        ranges.append(list(range(size)))

    amt_samples = 0
    parent_indexing = None
    if joint_log_probs is not None:
        # Initialize a tensor (parent_indexing) that keeps track of what samples
        # link to previous choices of samples.
        # Note that joint_log_probs.shape[-1] is amt_samples, not k. It's possible
        # that amt_samples < k!
        amt_samples = joint_log_probs.shape[-1]
        # plates x k
        parent_indexing = support.new_zeros(
            size=support.shape[:amt_multi_dim_plates] + (self.k,), dtype=torch.long
        )
        # TODO: This can probably go wrong if plates are missing.
        parent_indexing[..., :amt_samples] = left_expand_as(
            torch.arange(amt_samples), parent_indexing
        )
    # plates x k x events
    sampled_support_indices = support.new_zeros(
        size=support.shape[:amt_multi_dim_plates]  # plates
        + (self.k,)
        + support.shape[
            amt_multi_dim_plates
            + 1 : -len(element_shape) if len(element_shape) > 0 else None
        ],  # events
        dtype=torch.long,
    )
    # Sample independent tensors in sequence: iterate over the different
    # (conditionally) independent samples being taken (the events).
    for indices in itertools.product(*ranges):
        # Log probabilities of the different options for this sample step (event)
        # distr_plates x k? x |D_yv|
        yv_log_probs = d_log_probs[(...,) + indices]
        (
            sampled_support_indices,
            joint_log_probs,
            parent_indexing,
            amt_samples,
        ) = self.decode_step(
            indices,
            yv_log_probs,
            joint_log_probs,
            sampled_support_indices,
            parent_indexing,
            is_conditional_sample,
            amt_multi_dim_plates,
            amt_samples,
        )
    # Finally, index the support using the sampled indices to get the sample!
    if amt_samples < self.k:
        # plates x amt_samples x events
        sampled_support_indices = sampled_support_indices[
            (...,) + (slice(amt_samples),) + (slice(None),) * len(ranges)
        ]
    expanded_indices = right_expand_as(sampled_support_indices, support)
    sample = support.gather(dim=amt_multi_dim_plates, index=expanded_indices)
    return sample, joint_log_probs, parent_indexing
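def _example_enumerate_support_shapes():
    # Shape intuition for the support tensors used in decode. These are plain
    # PyTorch facts, independent of storch: with a batched Bernoulli of
    # batch_shape (3,), enumerate_support(expand=False) keeps singleton batch
    # dims, while expand=True materializes the full |D_yv| x batch support.
    import torch.distributions as td

    d = td.Bernoulli(probs=torch.rand(3))
    assert d.enumerate_support(expand=False).shape == (2, 1)  # |D_yv| x 1
    assert d.enumerate_support(expand=True).shape == (2, 3)   # |D_yv| x batch
    # log_prob of the non-expanded support broadcasts to |D_yv| x batch:
    assert d.log_prob(d.enumerate_support(expand=False)).shape == (2, 3)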
def plate_weighting(
    self, tensor: storch.StochasticTensor, plate: storch.Plate
) -> Optional[storch.Tensor]:
    # plates_w_k is a sequence of plates of which one is the input ancestral plate

    # Computes p(s) * R(S^k, s), i.e., the probability of the sample times the
    # leave-one-out ratio.
    # For details, see https://openreview.net/pdf?id=rklEj2EFvB
    # Code based on https://github.com/wouterkool/estimating-gradients-without-replacement/blob/master/bernoulli/gumbel.py
    log_probs = plate.log_probs.detach()

    # Compute integration points for the trapezoid rule: v should range from 0 to 1,
    # where both v=0 and v=1 give a value of 0. As the computation happens in
    # log-space, take the logarithm of the result.
    # N
    v = (
        torch.arange(1, self.num_int_points, out=log_probs._tensor.new())
        / self.num_int_points
    )
    log_v = v.log()
    # Compute log(1 - v^{exp(log_probs + a)}) in a numerically stable way in log-space.
    # Uses the gumbel_log_survival function from
    # https://github.com/wouterkool/estimating-gradients-without-replacement/blob/master/bernoulli/gumbel.py
    # plates_w_k x N
    g_bound = (
        log_probs[..., None]
        + self.a
        + torch.log(-log_v)[log_probs.plate_dims * (None,) + (slice(None),)]
    )

    # Gumbel log survival evaluated at -g_bound: log P(g > -g_bound)
    # = log(1 - exp(-exp(g_bound))) for a standard Gumbel g.
    # If g_bound <= -10, y = exp(g_bound) <= e^-10, so the series expansion of
    # log(1 - exp(-y)) around y=0 is accurate with error O(y^6) (=8.7E-27).
    # See https://www.wolframalpha.com/input/?i=log%281+-+exp%28-y%29%29
    y = torch.exp(g_bound)
    # plates_w_k x N
    terms = torch.where(
        g_bound <= -10, g_bound - y / 2 + y ** 2 / 24 - y ** 4 / 2880, log1mexp(y)
    )

    # Compute the integrands (without subtracting the special value s)
    # plates x N
    sum_of_terms = storch.sum(terms, plate)
    phi_S = storch.logsumexp(log_probs, plate)
    phi_D_min_S = log1mexp(phi_S)
    # plates x N
    integrand = (
        sum_of_terms
        + torch.expm1(self.a + phi_D_min_S)[..., None]
        * log_v[phi_D_min_S.plate_dims * (None,) + (slice(None),)]
    )

    # Subtract one term for the element that is left out in R.
    # Automatically unsqueezes correctly using plate dimensions.
    # plates_w_k x N
    integrand_without_s = integrand - terms

    # plates
    log_p_S = integrand.logsumexp(dim=-1)
    # plates_w_k
    log_p_S_without_s = integrand_without_s.logsumexp(dim=-1)
    # plates_w_k
    log_leave_one_out = log_p_S_without_s - log_p_S

    if self.comp_leave_two_out:
        # Compute the integrands for the second-order leave-one-out ratio.
        # Make sure to properly choose the indices: we shouldn't subtract the same
        # term twice on the diagonals.
        # k x k
        skip_diag = storch.Tensor(
            1 - torch.eye(plate.n, out=log_probs._tensor.new()), [], [plate]
        )
        # plates_w_k x k x N
        integrand_without_ss = (
            integrand_without_s[..., None, :]
            - terms[..., None, :] * skip_diag[..., None]
        )
        # plates_w_k x k
        log_p_S_without_ss = integrand_without_ss.logsumexp(dim=-1)
        plate.log_snd_leave_one_out = log_p_S_without_ss - log_p_S_without_s

    # Return the unordered set estimator weighting
    return (log_leave_one_out + log_probs).exp().detach()
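# Reference sketch of the log1mexp helper assumed above; the repo's own version
# may differ. This follows the gumbel.py referenced in the comments: it computes
# log(1 - exp(-y)) for y > 0, switching formulas at log(2) for numerical
# stability (https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf).
def _log1mexp_sketch(y: torch.Tensor) -> torch.Tensor:
    import math

    return torch.where(
        y > math.log(2), torch.log1p(-torch.exp(-y)), torch.log(-torch.expm1(-y))
    )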
def to_storch(tensor: torch.Tensor) -> storch.Tensor:
    return storch.Tensor(tensor, [], [], "test")