Example 1
    def undo_unique(self, unique_tensor: storch.Tensor) -> storch.Tensor:
        """
        Convert the unique tensor back to the non-unique format, then add the old plates back in.
        # TODO: Make sure self.shrunken_plates is added
        # TODO: What if unique_tensor contains new plates after the unique?
        :param unique_tensor: Tensor in the deduplicated (unique) format.
        :return: The tensor mapped back to its original, non-unique format.
        """
        plate_idx = unique_tensor.get_plate_dim_index(self.name)
        with storch.ignore_wrapping():
            # Swap this plate's dimension with the last plate dimension.
            dim_swapped = unique_tensor.transpose(
                plate_idx, unique_tensor.plate_dims - 1
            )
            # Undo the deduplication: repeat each unique entry for every original
            # element that mapped to it, using the stored inverse indexing.
            fl_selected = torch.index_select(
                dim_swapped, dim=0, index=self.inv_indexing
            )
            # Unflatten the leading dimension back into the plates that were shrunk.
            selected = fl_selected.reshape(
                tuple(map(lambda p: p.n, self.shrunken_plates)) + fl_selected.shape[1:]
            )
            return storch.Tensor(
                selected,
                [unique_tensor],
                self.shrunken_plates + unique_tensor.plates,
                "undo_unique_" + unique_tensor.name,
            )
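At its core, undo_unique is an index_select with the inverse indices that torch.unique returns; a minimal plain-PyTorch sketch of that mapping, with made-up data:

import torch

x = torch.tensor([[1., 2.], [3., 4.], [1., 2.]])  # row 2 duplicates row 0
# uniq == [[1., 2.], [3., 4.]], inv == [0, 1, 0]
uniq, inv = torch.unique(x, return_inverse=True, dim=0)
restored = torch.index_select(uniq, dim=0, index=inv)
assert torch.equal(restored, x)  # the original layout is recovered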
Example 2
    def reduce(self, tensor: storch.Tensor, detach_weights=True):
        plate_weighting = self.weight
        if detach_weights:
            plate_weighting = self.weight.detach()
        # Case: Only one sample. Reduce the plate by multiplying with its weight.
        if self.n == 1:
            return storch.reduce(lambda x: x * plate_weighting, self.name)(tensor)
        # Case: The weight is a single number. First sum, then multiply with the weight (usually taking the mean).
        elif plate_weighting.ndim == 0:
            return storch.sum(tensor, self) * plate_weighting

        # Case: There is a weight for each sample in this plate, independent of the other batch dimensions.
        elif plate_weighting.ndim == 1:
            index = tensor.get_plate_dim_index(self.name)
            # Add trailing singleton dimensions so the weights broadcast over the dims after the plate.
            plate_weighting = plate_weighting[
                (...,) + (None,) * (tensor.ndim - index - 1)
            ]
            weighted_tensor = tensor * plate_weighting
            return storch.sum(weighted_tensor, self)

        # Case: The weight is a tensor with one entry per sample, possibly dependent on parent plates. Assumes it is a storch.Tensor.
        else:
            for parent_plate in self.parents:
                if parent_plate not in tensor.plates:
                    raise ValueError(
                        "Plate missing when reducing tensor: " + parent_plate.name
                    )
            weighted_tensor = tensor * plate_weighting
            return storch.sum(weighted_tensor, self)
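The ndim == 1 branch aligns a per-sample weight vector with the plate dimension by appending trailing singleton dimensions; a minimal plain-PyTorch sketch with assumed shapes:

import torch

tensor = torch.randn(2, 4, 3)      # plate of 4 samples at dim index 1
weights = torch.full((4,), 0.25)   # one weight per sample, summing to 1

index = 1  # position of the plate dimension
# (4,) -> (4, 1): trailing None dims make the weights broadcast over dim 2.
aligned = weights[(...,) + (None,) * (tensor.ndim - index - 1)]
weighted = tensor * aligned        # broadcasts to (2, 4, 3)
reduced = weighted.sum(dim=index)  # weighted mean over the plate: (2, 3)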
Example 3
from typing import List, Union

import torch
import torch.nn.functional as F

import storch
from storch import deterministic


def b_binary_cross_entropy(
    input: storch.Tensor,
    target: storch.Tensor,
    dims: Union[str, List[str]] = None,
    weight=None,
    reduction: str = "mean",
):
    r"""Function that measures the Binary Cross Entropy in a batched way
    between the target and the output.

    See :class:`~torch.nn.BCELoss` for details.

    Args:
        input: Tensor of arbitrary shape
        target: Tensor of the same shape as input
        dims (str or List[str], optional): names of plate dimensions to reduce
                over in addition to the event dimensions. Default: ``None``
        weight (Tensor, optional): a manual rescaling weight
                if provided it's repeated to match input tensor shape
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``

    Examples::

        >>> input = torch.randn((3, 2), requires_grad=True)
        >>> target = torch.rand((3, 2), requires_grad=False)
        >>> loss = b_binary_cross_entropy(torch.sigmoid(input), target)
        >>> loss.backward()
    """

    if not dims:
        dims = []
    if isinstance(dims, str):
        dims = [dims]

    target = target.expand_as(input)
    unreduced = deterministic(F.binary_cross_entropy)(input,
                                                      target,
                                                      weight,
                                                      reduction="none")

    # Reduce over all event dimensions plus any requested plate dimensions.
    indices = list(unreduced.event_dim_indices) + dims

    if reduction == "mean":
        return storch.mean(unreduced, indices)
    elif reduction == "sum":
        return storch.sum(unreduced, indices)
    elif reduction == "none":
        return unreduced
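Setting reduction="none" and reducing afterwards matches plain element-wise BCE when no plates are involved; a small sketch with made-up shapes:

import torch
import torch.nn.functional as F

input = torch.sigmoid(torch.randn(3, 2))
target = torch.rand(3, 2)

unreduced = F.binary_cross_entropy(input, target, reduction="none")  # (3, 2)
loss = unreduced.mean()  # matches reduction="mean" over all dimensions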
Example 4
def unique(tensor: storch.Tensor, event_dim: Optional[int] = 0) -> storch.Tensor:
    """
    Deduplicate the samples of the input tensor. Returns a tensor indexed by a new
    UniquePlate that contains only the unique samples; the inverse indexing needed
    to restore the original format is stored on that plate.
    """
    with storch.ignore_wrapping():
        fl_tensor = torch.flatten(tensor, tensor.plate_dims)
        # Compute the unique samples and, for each original sample, the index of
        # its unique representative.
        uniq, inverse_indexing = torch.unique(
            fl_tensor, return_inverse=True, dim=event_dim
        )
    inverse_indexing = storch.Tensor(
        inverse_indexing, [tensor], tensor.plates, "inv_index_" + tensor.name
    )
    # The UniquePlate remembers the plates that were flattened away and the inverse
    # indexing, so that undo_unique can map results back to the original layout.
    uq_plate = UniquePlate(
        "uq_plate_" + tensor.name,
        uniq.shape[0],
        tensor.multi_dim_plates(),
        inverse_indexing,
    )
    return storch.Tensor(uniq, [tensor], [uq_plate], "unique_" + tensor.name)
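A sketch of how unique and UniquePlate.undo_unique (Example 1) are meant to compose; f is a hypothetical per-sample computation, and the remaining names come from the code above:

uniq_x = unique(x)                # x: a storch.Tensor with duplicate samples
uq_plate = uniq_x.plates[0]       # the UniquePlate constructed above
y = f(uniq_x)                     # run the computation on unique samples only
y_full = uq_plate.undo_unique(y)  # map the results back to the original layout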
Example 5
    def on_unwrap_tensor(self, tensor: storch.Tensor) -> storch.Tensor:
        """
        Gets called whenever the given tensor is being unwrapped and unsqueezed for batch use.
        This method should not be called on tensors whose variable index is higher than this plate's.

        :param tensor: The input tensor that is being unwrapped
        :return: The tensor that will be unwrapped and unsqueezed in the future. Can be a modification of the input tensor.
        """
        if self._in_recursion:
            # Required when calling storch.gather in this method. It will call on_unwrap_tensor again.
            return tensor
        for i, plate in enumerate(tensor.multi_dim_plates()):
            if plate.name != self.name:
                continue
            assert isinstance(plate, AncestralPlate)
            if plate.variable_index == self.variable_index:
                return tensor
            # This is true by the filtering at on_collecting_args
            assert plate.variable_index < self.variable_index

            parent_plates = []
            current_plate = self

            # Collect the list of plates from the tensor's variable index up to this plate's variable index
            while current_plate.variable_index != plate.variable_index:
                parent_plates.append(current_plate)
                current_plate = current_plate.parent_plate
            assert current_plate == plate

            # Go over all parent plates and gather their respective choices.
            for parent_plate in reversed(parent_plates):
                self._in_recursion = True
                expanded_selected_samples = expand_with_ignore_as(
                    parent_plate.selected_samples, tensor, self.name
                )
                self._override_equality = False
                # Gather the samples of the tensor that were chosen by this parent plate
                tensor = storch.gather(
                    tensor, parent_plate.name, expanded_selected_samples
                )
                self._in_recursion = False
                self._override_equality = False
            break
        return tensor
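Each gather in the loop above selects, along the parent plate's sample dimension, the samples that the current plate kept; a plain-PyTorch sketch of that selection with made-up shapes:

import torch

# values: 3 parent samples, each with a 2-element event.
values = torch.tensor([[10., 11.], [20., 21.], [30., 31.]])
# selected: for each of 2 new samples, the index of the parent sample it extends.
selected = torch.tensor([2, 0])

# Expand the selection over the event dim, then gather along the sample dim.
index = selected.unsqueeze(-1).expand(-1, values.shape[-1])  # shape (2, 2)
gathered = torch.gather(values, dim=0, index=index)
# gathered == [[30., 31.], [10., 11.]]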
Example 6
def _convert_indices(tensor: storch.Tensor, dims: _indices) -> Tuple[List[int], List[str]]:
    conv_indices = []
    red_batches = []
    if not isinstance(dims, List):
        dims = [dims]
    for index in dims:
        if isinstance(index, int):
            # Accept only event dimensions: positive indices past the plate dims,
            # or negative indices that stay within the event dims.
            if index >= tensor.plate_dims or (index < 0 and index >= -tensor.event_dims):
                conv_indices.append(index)
            else:
                raise IndexError(
                    "Can only pass indexes for event dimensions. Tensor: "
                    + str(tensor)
                    + ", shape: "
                    + str(tensor.shape)
                    + ", index: "
                    + str(index)
                )
        else:
            if isinstance(index, storch.Plate):
                index = index.name
            conv_indices.append(tensor.get_plate_dim_index(index))
            red_batches.append(index)
    return tuple(conv_indices), red_batches
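The integer branch accepts exactly the event dimensions. A self-contained restatement of the check, with a hypothetical helper name:

def is_event_index(index: int, plate_dims: int, event_dims: int) -> bool:
    # Positive indices must lie past the plate dims; negative indices must
    # stay within the event dims.
    return index >= plate_dims or (index < 0 and index >= -event_dims)

# With 2 plate dims and 3 event dims (ndim = 5):
assert is_event_index(2, 2, 3) and is_event_index(-3, 2, 3)
assert not is_event_index(1, 2, 3) and not is_event_index(-4, 2, 3)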
Example 7
    def decode_step(
        self,
        indices: Tuple[int, ...],
        yv_log_probs: storch.Tensor,
        joint_log_probs: Optional[storch.Tensor],
        sampled_support_indices: Optional[storch.Tensor],
        parent_indexing: Optional[storch.Tensor],
        is_conditional_sample: bool,
        amt_plates: int,
        amt_samples: int,
    ) -> Tuple[storch.Tensor, storch.Tensor, storch.Tensor, int]:
        """
        Decode given the input arguments for a specific event using stochastic beam search.
        :param indices: Tuple of integers indexing the current event to sample.
        :param yv_log_probs:  Log probabilities of the different options for this event. distr_plates x k? x |D_yv|
        :param joint_log_probs: The log probabilities of the samples so far. None if `not is_conditional_sample`. prev_plates x amt_samples
        :param sampled_support_indices: Tensor of samples so far. None if this is the first set of indices. plates x k x events
        :param parent_indexing: Tensor indexing the parent sample. None if `not is_conditional_sample`.
        :param is_conditional_sample: True if a parent has already been sampled. This means the plates are more complex!
        :param amt_plates: The total amount of plates in both the distribution and the previously sampled variables
        :param amt_samples: The amount of active samples.
        :return: 4-tuple. 1: `sampled_support_indices`, with `:, indices` referring to the indices for the support.
        2: The updated `joint_log_probs` of the samples.
        3: The updated `parent_indexing`. How the samples index the parent samples. Can just return parent_indexing if nothing happens.
        4: The amount of active samples after this step.
        """

        first_sample = False
        if joint_log_probs is None:
            # We also know that k? is not present, so distr_plates x |D_yv|
            all_joint_log_probs = yv_log_probs
            # First condition on max being 0:
            self.perturbed_log_probs = 0.0
            first_sample = True
        elif is_conditional_sample:
            # Make sure we are selecting the correct log-probabilities. As parents have been selected, this might change!
            # plates x amt_samples x |D_yv|
            yv_log_probs = yv_log_probs.gather(
                dim=-2,
                index=right_expand_as(
                    # Use the parent_indexing to select the correct plate samples. Make sure to limit to amt_samples!
                    parent_indexing[..., :amt_samples],
                    yv_log_probs,
                ),
            )
            # self.joint_log_probs: prev_plates x amt_samples
            # plates x amt_samples x |D_yv|
            all_joint_log_probs = joint_log_probs.unsqueeze(-1) + yv_log_probs
        else:
            # self.joint_log_probs: plates x amt_samples
            # plates x amt_samples x |D_yv|
            all_joint_log_probs = joint_log_probs.unsqueeze(
                -1) + yv_log_probs.unsqueeze(-2)

        # Sample plates x k? x |D_yv| conditional Gumbel variables
        cond_G_yv = cond_gumbel_sample(all_joint_log_probs,
                                       self.perturbed_log_probs)

        # If there are finished samples, ensure eos is always sampled.
        if self.finished_samples is not None:
            # TODO: Is this the correct way of ensuring self.eos is always sampled for finished sequences?
            #  Could it bias things in any way?
            # Set the probability of continuing on finished sequences to -infinity so that they are filtered out during topk.
            # amt_finished
            finished_perturb_log_probs = self.perturbed_log_probs._tensor[
                self.finished_samples._tensor]
            # amt_finished x |D_yv|
            finished_vec = finished_perturb_log_probs.new_full(
                (
                    finished_perturb_log_probs.shape[0],
                    cond_G_yv.shape[-1],
                ),
                -float("inf"),
            )
            # Then make sure the log probability of the eos token is equal to the last perturbed log prob.
            finished_vec[:, self.eos] = finished_perturb_log_probs
            cond_G_yv[self.finished_samples] = finished_vec

        if not first_sample:
            # plates x (k * |D_yv|) (k == prev_amt_samples, in this case)
            cond_G_yv = cond_G_yv.reshape(cond_G_yv.shape[:-2] + (-1, ))

        # Select the samples given the perturbed log probabilities
        self.perturbed_log_probs, arg_top = self.select_samples(
            cond_G_yv, all_joint_log_probs)
        amt_samples = arg_top.shape[-1]

        if first_sample:
            # plates x amt_samples
            joint_log_probs = all_joint_log_probs.gather(dim=-1, index=arg_top)
            # Index for the selected samples. Uses slice(amt_samples) for the first index in case k > |D_yv|
            # (:) * amt_plates + (indices for events) + amt_samples
            indexing = (slice(None), ) * amt_plates + (slice(
                0, amt_samples), ) + indices
            sampled_support_indices[indexing] = arg_top
        else:
            # Gather corresponding joint log probabilities. First reshape like previous to plates x (k * |D_yv|).
            joint_log_probs = all_joint_log_probs.reshape(
                cond_G_yv.shape).gather(dim=-1, index=arg_top)

            # |D_yv|
            size_domain = yv_log_probs.shape[-1]

            # Keep track of what parents were sampled for the arg top
            # plates x amt_samples
            chosen_parents = arg_top // size_domain
            sampled_support_indices = sampled_support_indices.gather(
                dim=amt_plates,
                index=right_expand_as(chosen_parents, sampled_support_indices),
            )
            if parent_indexing is not None:
                parent_indexing = parent_indexing.gather(dim=-1,
                                                         index=chosen_parents)
            # Index for the selected samples. Uses slice(amt_samples) for the first index in case k > |D_yv|
            # plates x amt_samples
            chosen_samples = arg_top.remainder(size_domain)
            indexing = (slice(None), ) * amt_plates + (slice(
                0, amt_samples), ) + indices
            sampled_support_indices[indexing] = chosen_samples
        return sampled_support_indices, joint_log_probs, parent_indexing, amt_samples
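After the reshape to plates x (k * |D_yv|), every selected index encodes a (parent sample, token) pair, which the floor division and remainder above unpack; a plain-PyTorch sketch with made-up sizes:

import torch

k, size_domain = 3, 5                 # 3 parent samples, 5 tokens
scores = torch.randn(k, size_domain)  # a score per (parent, token) pair

flat = scores.reshape(-1)             # (k * size_domain,)
_, arg_top = flat.topk(3)             # top 3 flattened indices

chosen_parents = arg_top // size_domain          # parent each winner extends
chosen_tokens = arg_top.remainder(size_domain)   # token each winner appends
assert torch.equal(scores[chosen_parents, chosen_tokens], flat[arg_top])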
Example 8
    def forward(self, x: storch.Tensor):
        if self.reshape:
            # Flatten all event dimensions into a single feature dimension, keeping the plate dimensions.
            x = x.reshape(x.shape[: x.plate_dims] + (-1,))
        return self.fc2(F.relu(self.fc1(x))).squeeze(-1)