def undo_unique(self, unique_tensor: storch.Tensor) -> storch.Tensor:
    """
    Convert the unique tensor back to the non-unique format, then add the old plates back in.

    # TODO: Make sure self.shrunken_plates is added
    # TODO: What if unique_tensor contains new plates after the unique?

    :param unique_tensor: Tensor that contains this plate (looked up through ``self.name``).
    :return: A new ``storch.Tensor`` named ``"undo_unique_" + unique_tensor.name`` that has
        ``unique_tensor`` as its only parent and ``self.shrunken_plates + unique_tensor.plates``
        as its plates.
    """
    plate_idx = unique_tensor.get_plate_dim_index(self.name)
    # NOTE(review): the source layout was collapsed; the tensor manipulation is assumed to be the
    # part that runs with wrapping disabled - confirm the extent of this block.
    with storch.ignore_wrapping():
        # Swap the dimension of this plate with the last plate dimension.
        dim_swapped = unique_tensor.transpose(plate_idx, unique_tensor.plate_dims - 1)
        # Use the stored inverse indexing to expand the unique entries along dim 0.
        fl_selected = torch.index_select(dim_swapped, dim=0, index=self.inv_indexing)
        # Replace the leading dimension by one dimension per shrunken plate.
        plate_sizes = tuple(plate.n for plate in self.shrunken_plates)
        selected = fl_selected.reshape(plate_sizes + fl_selected.shape[1:])
    return storch.Tensor(
        selected,
        [unique_tensor],
        self.shrunken_plates + unique_tensor.plates,
        "undo_unique_" + unique_tensor.name,
    )
def reduce(self, tensor: storch.Tensor, detach_weights=True):
    """
    Reduce this plate out of ``tensor``, weighting the entries by ``self.weight``.

    :param tensor: The tensor to reduce.
    :param detach_weights: If true, ``self.weight.detach()`` is used instead of ``self.weight``.
    :return: The result of the weighted reduction over this plate.
    :raises ValueError: In the general (batched weight) case, when one of ``self.parents`` is not
        in ``tensor.plates``.
    """
    weights = self.weight.detach() if detach_weights else self.weight

    # Plate of size one: only multiply by the weight inside storch.reduce.
    if self.n == 1:
        return storch.reduce(lambda x: x * weights, self.name)(tensor)

    # Case: The weight is a single number. First sum, then multiply with the weight
    # (usually taking the mean).
    if weights.ndim == 0:
        return storch.sum(tensor, self) * weights

    # Case: There is a weight for each plate which is not dependent on the other batch dimensions.
    if weights.ndim == 1:
        plate_dim = tensor.get_plate_dim_index(self.name)
        # Append singleton dims so the weight lines up with the plate dimension of the tensor.
        trailing_dims = tensor.ndim - plate_dim - 1
        broadcast_weights = weights[(...,) + (None,) * trailing_dims]
        return storch.sum(tensor * broadcast_weights, self)

    # Case: The weight is a vector of numbers equal to batch dimension. Assumes it is a
    # storch.Tensor.
    for ancestor in self.parents:
        if ancestor not in tensor.plates:
            raise ValueError(
                "Plate missing when reducing tensor: " + ancestor.name
            )
    return storch.sum(tensor * weights, self)
def b_binary_cross_entropy(
    input: storch.Tensor,
    target: storch.Tensor,
    dims: Union[str, List[str], None] = None,
    weight=None,
    reduction: str = "mean",
):
    r"""Function that measures the Binary Cross Entropy in a batched way
    between the target and the output.

    See :class:`~torch.nn.BCELoss` for details.

    Args:
        input: Tensor of arbitrary shape
        target: Tensor that is expanded to the shape of ``input``
        dims: Name, or sequence of names, of additional dimensions to reduce over on top of
            the event dimensions of the unreduced loss. ``None`` or empty means no extra ones.
        weight (Tensor, optional): a manual rescaling weight, passed on unchanged to
            :func:`torch.nn.functional.binary_cross_entropy`
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: ``storch.mean`` over the event dimensions and ``dims``,
            ``'sum'``: ``storch.sum`` over the event dimensions and ``dims``.
            Default: ``'mean'``

    Raises:
        ValueError: if ``reduction`` is not one of ``'none'``, ``'mean'`` or ``'sum'``.

    Examples (carried over from the original docstring - TODO confirm it still runs)::

        >>> input = torch.randn((3, 2), requires_grad=True)
        >>> target = torch.rand((3, 2), requires_grad=False)
        >>> loss = b_binary_cross_entropy(F.sigmoid(input), target)
        >>> loss.backward()
    """
    # Fail fast: previously an unknown reduction silently returned None.
    if reduction not in ("none", "mean", "sum"):
        raise ValueError("{} is not a valid value for reduction".format(reduction))

    # Normalise dims to a list so it can be concatenated with the event dim indices.
    if not dims:
        dims = []
    elif isinstance(dims, str):
        dims = [dims]
    else:
        dims = list(dims)

    target = target.expand_as(input)
    # Always compute the element-wise loss; the reduction is applied below.
    unreduced = deterministic(F.binary_cross_entropy)(
        input, target, weight, reduction="none"
    )

    indices = list(unreduced.event_dim_indices) + dims

    if reduction == "mean":
        return storch.mean(unreduced, indices)
    if reduction == "sum":
        return storch.sum(unreduced, indices)
    return unreduced
def unique(tensor: storch.Tensor, event_dim: Optional[int] = 0) -> storch.Tensor:
    """
    Collapse ``tensor`` to its unique entries and attach a ``UniquePlate`` describing the result.

    :param tensor: The tensor to take unique entries of.
    :param event_dim: Passed as ``dim`` to ``torch.unique`` on the flattened tensor.
    :return: A ``storch.Tensor`` named ``"unique_" + tensor.name`` that has ``tensor`` as parent
        and a single ``UniquePlate``. The plate is given the inverse indexing returned by
        ``torch.unique``.
    """
    base_name = tensor.name
    # NOTE(review): the source layout was collapsed; only the raw torch calls are assumed to run
    # with wrapping disabled - confirm the extent of this block.
    with storch.ignore_wrapping():
        # Flatten every dimension from tensor.plate_dims onwards into one.
        flattened = torch.flatten(tensor, tensor.plate_dims)
        unique_values, inverse = torch.unique(
            flattened, return_inverse=True, dim=event_dim
        )

    wrapped_inverse = storch.Tensor(
        inverse, [tensor], tensor.plates, "inv_index_" + base_name
    )
    unique_plate = UniquePlate(
        "uq_plate_" + base_name,
        unique_values.shape[0],
        tensor.multi_dim_plates(),
        wrapped_inverse,
    )
    return storch.Tensor(
        unique_values, [tensor], [unique_plate], "unique_" + base_name
    )
def on_unwrap_tensor(self, tensor: storch.Tensor) -> storch.Tensor:
    """
    Gets called whenever the given tensor is being unwrapped and unsqueezed for batch use.
    This method should not be called on tensors whose variable index is higher than this plates.

    The first multi-dim plate of ``tensor`` that has the same name as this plate is looked up.
    If its variable index is lower than this plate's, the chain of ``parent_plate`` links from
    this plate back to it is walked, and the tensor is gathered with the ``selected_samples`` of
    every plate on that chain, oldest first.

    :param tensor: The input tensor that is being unwrapped
    :return: The tensor that will be unwrapped and unsqueezed in the future. Can be a
        modification of the input tensor.
    """
    if self._in_recursion:
        # Required when calling storch.gather in this method. It will call on_unwrap_tensor again.
        return tensor

    # Only the first plate with a matching name is considered.
    plate = next(
        (p for p in tensor.multi_dim_plates() if p.name == self.name), None
    )
    if plate is None:
        return tensor

    assert isinstance(plate, AncestralPlate)
    if plate.variable_index == self.variable_index:
        return tensor
    # This is true by the filtering at on_collecting_args
    assert plate.variable_index < self.variable_index

    # Collect the list of plates from the tensors variable index to this plates variable index
    parent_plates = []
    current_plate = self
    while current_plate.variable_index != plate.variable_index:
        parent_plates.append(current_plate)
        current_plate = current_plate.parent_plate
    assert current_plate == plate

    # Go over all parent plates and gather their respective choices.
    for parent_plate in reversed(parent_plates):
        self._in_recursion = True
        try:
            expanded_selected_samples = expand_with_ignore_as(
                parent_plate.selected_samples, tensor, self.name
            )
            self._override_equality = False
            # Gather what samples of the tensor are chosen by this plate (parent_plate)
            tensor = storch.gather(
                tensor, parent_plate.name, expanded_selected_samples
            )
        finally:
            # Always clear the guard. Otherwise an exception above would leave this plate
            # returning tensors unmodified on every later call.
            self._in_recursion = False
        self._override_equality = False
    return tensor
def _convert_indices(tensor: storch.Tensor, dims: _indices) -> (tuple, List[str]):
    """
    Convert a mix of integer event-dimension indices and plate references to dimension indices.

    :param tensor: The tensor the indices refer to.
    :param dims: A single index or a list/tuple of indices. Each index is either an ``int``
        (must point at an event dimension), a plate name, or a ``storch.Plate``.
    :return: A pair ``(conv_indices, red_batches)``. ``conv_indices`` is a tuple with one
        dimension index per entry of ``dims``. ``red_batches`` is the list of plate names that
        were referenced.
    :raises IndexError: If an integer index does not point at an event dimension.
    """
    # Accept tuples as well as lists; anything else is treated as a single index.
    if not isinstance(dims, (list, tuple)):
        dims = [dims]

    conv_indices = []
    red_batches = []
    for index in dims:
        if isinstance(index, int):
            # Valid: a non-negative index at or past the plate dims, or a negative index that
            # stays within the event dims.
            is_event_dim = index >= tensor.plate_dims or (
                -tensor.event_dims <= index < 0
            )
            if not is_event_dim:
                raise IndexError(
                    "Can only pass indexes for event dimensions."
                    + str(tensor)
                    + ". Index: "
                    + str(index)
                )
            conv_indices.append(index)
        else:
            if isinstance(index, storch.Plate):
                index = index.name
            conv_indices.append(tensor.get_plate_dim_index(index))
            red_batches.append(index)
    return tuple(conv_indices), red_batches
def decode_step(
    self,
    indices: Tuple[int],
    yv_log_probs: storch.Tensor,
    joint_log_probs: Optional[storch.Tensor],
    sampled_support_indices: Optional[storch.Tensor],
    parent_indexing: Optional[storch.Tensor],
    is_conditional_sample: bool,
    amt_plates: int,
    amt_samples: int,
) -> Tuple[storch.Tensor, storch.Tensor, Optional[storch.Tensor], int]:
    """
    Decode given the input arguments for a specific event using stochastic beam search.

    Besides returning the values below, this method overwrites ``self.perturbed_log_probs`` and
    writes into ``sampled_support_indices``.

    :param indices: Tuple of integers indexing the current event to sample.
    :param yv_log_probs: Log probabilities of the different options for this event. distr_plates x k? x |D_yv|
    :param joint_log_probs: The log probabilities of the samples so far. None if `not is_conditional_sample`. prev_plates x amt_samples
    :param sampled_support_indices: Tensor of samples so far. None if this is the first set of indices. plates x k x events
    :param parent_indexing: Tensor indexing the parent sample. None if `not is_conditional_sample`.
    :param is_conditional_sample: True if a parent has already been sampled. This means the plates are more complex!
    :param amt_plates: The total amount of plates in both the distribution and the previously sampled variables
    :param amt_samples: The amount of active samples.
    :return: 4-tuple.
    1: sampled_support_indices, with `:, indices` referring to the indices for the support.
    2: The updated `joint_log_probs` of the samples.
    3: The updated `parent_indexing`. How the samples index the parent samples. Can just return parent_indexing if nothing happens.
    4: The amount of active samples after this step.
    """
    first_sample = False
    if joint_log_probs is None:
        # We also know that k? is not present, so distr_plates x |D_yv|
        all_joint_log_probs = yv_log_probs
        # First condition on max being 0:
        self.perturbed_log_probs = 0.0
        first_sample = True
    elif is_conditional_sample > 0:
        # Make sure we are selecting the correct log-probabilities. As parents have been selected, this might change!
        # plates x amt_samples x |D_yv|
        yv_log_probs = yv_log_probs.gather(
            dim=-2,
            index=right_expand_as(
                # Use the parent_indexing to select the correct plate samples. Make sure to limit to amt_samples!
                parent_indexing[..., :amt_samples],
                yv_log_probs,
            ),
        )
        # self.joint_log_probs: prev_plates x amt_samples
        # plates x amt_samples x |D_yv|
        all_joint_log_probs = joint_log_probs.unsqueeze(-1) + yv_log_probs
    else:
        # self.joint_log_probs: plates x amt_samples
        # plates x amt_samples x |D_yv|
        all_joint_log_probs = joint_log_probs.unsqueeze(
            -1) + yv_log_probs.unsqueeze(-2)

    # Sample plates x k? x |D_yv| conditional Gumbel variables
    cond_G_yv = cond_gumbel_sample(all_joint_log_probs, self.perturbed_log_probs)

    # If there are finished samples, ensure eos is always sampled.
    if self.finished_samples is not None:
        # TODO: Is this the correct way of ensuring self.eos is always sampled for finished sequences?
        #  Could it bias things in any way?
        # Set the probability of continuing on finished sequences to -infinity so that they are filtered out during topk.
        # amt_finished
        finished_perturb_log_probs = self.perturbed_log_probs._tensor[
            self.finished_samples._tensor]
        # amt_finished x |D_yv|
        finished_vec = finished_perturb_log_probs.new_full(
            (
                finished_perturb_log_probs.shape[0],
                cond_G_yv.shape[-1],
            ),
            -float("inf"),
        )
        # Then make sure the log probability of the eos token is equal to the last perturbed log prob.
        finished_vec[:, self.eos] = finished_perturb_log_probs
        cond_G_yv[self.finished_samples] = finished_vec

    if not first_sample:
        # Merge the sample and domain dimensions so the selection runs over all (parent, option) pairs.
        # plates x (k * |D_yv|) (k == prev_amt_samples, in this case)
        cond_G_yv = cond_G_yv.reshape(cond_G_yv.shape[:-2] + (-1, ))

    # Select the samples given the perturbed log probabilities
    self.perturbed_log_probs, arg_top = self.select_samples(
        cond_G_yv, all_joint_log_probs)
    # The amount of selected samples is the size of the last dimension of arg_top.
    amt_samples = arg_top.shape[-1]

    if first_sample:
        # plates x amt_samples
        joint_log_probs = all_joint_log_probs.gather(dim=-1, index=arg_top)
        # Index for the selected samples. Uses slice(amt_samples) for the first index in case k > |D_yv|
        # (:) * amt_plates + slice over amt_samples + (indices for events)
        indexing = (slice(None), ) * amt_plates + (slice(
            0, amt_samples), ) + indices
        sampled_support_indices[indexing] = arg_top
    else:
        # Gather corresponding joint log probabilities. First reshape like previous to plates x (k * |D_yv|).
        joint_log_probs = all_joint_log_probs.reshape(
            cond_G_yv.shape).gather(dim=-1, index=arg_top)

        # |D_yv|
        size_domain = yv_log_probs.shape[-1]

        # Keep track of what parents were sampled for the arg top.
        # arg_top indexes the flattened (k * |D_yv|) dimension, so the quotient is the parent sample.
        # plates x amt_samples
        chosen_parents = arg_top // size_domain
        sampled_support_indices = sampled_support_indices.gather(
            dim=amt_plates,
            index=right_expand_as(chosen_parents, sampled_support_indices),
        )
        if parent_indexing is not None:
            parent_indexing = parent_indexing.gather(dim=-1, index=chosen_parents)
        # Index for the selected samples. Uses slice(amt_samples) for the first index in case k > |D_yv|
        # The remainder is the chosen option within the domain.
        # plates x amt_samples
        chosen_samples = arg_top.remainder(size_domain)
        indexing = (slice(None), ) * amt_plates + (slice(
            0, amt_samples), ) + indices
        sampled_support_indices[indexing] = chosen_samples
    return sampled_support_indices, joint_log_probs, parent_indexing, amt_samples
def forward(self, x: storch.Tensor):
    """
    Apply ``fc1``, a ReLU and ``fc2`` to ``x`` and squeeze the last dimension of the result.

    :param x: Input tensor. When ``self.reshape`` is set, every dimension after the first
        ``x.plate_dims`` dimensions is flattened into a single one before the layers are applied.
    :return: The output of ``fc2`` with ``squeeze(-1)`` applied.
    """
    if self.reshape:
        # Keep the plate dimensions, merge everything else.
        plate_shape = x.shape[: x.plate_dims]
        x = x.reshape(plate_shape + (-1,))
    hidden = F.relu(self.fc1(x))
    output = self.fc2(hidden)
    return output.squeeze(-1)