Example No. 1
        def hook(*args):
            # This args unpacking is required for compatibility with registering on a .grad_fn:
            # hooks on a grad_fn receive tuples of gradients, while Tensor hooks receive a single Tensor.
            grad = args[-1]
            if isinstance(grad, tuple):
                grad = grad[0]

            if name in accum_grads:
                accum_grads[name] = storch.Tensor(
                    accum_grads[name]._tensor + grad, [], plates,
                    name + "_grad")
            else:
                accum_grads[name] = storch.Tensor(grad, [], plates,
                                                  name + "_grad")
Example No. 2
def _handle_inputs(
    tensor: AnyTensor, plates: Optional[_plates],
) -> Tuple[storch.Tensor, List[storch.Plate]]:
    if isinstance(plates, storch.Plate):
        plates = [plates]
    if not isinstance(tensor, storch.Tensor):
        if not plates:
            raise ValueError("Make sure to pass plates when passing a torch.Tensor.")
        index_tensor = 0

        for plate in plates:
            if not isinstance(plate, storch.Plate):
                raise ValueError(
                    "Cannot handle plate names when passing a torch.Tensor"
                )
            if plate.n > 1:
                if tensor.shape[index_tensor] != plate.n:
                    raise ValueError(
                        "Received a tensor that does not align with the given plates."
                    )
                index_tensor += 1
        return storch.Tensor(tensor, [], plates), plates

    if not plates:
        return tensor, tensor.plates
    if isinstance(plates, str):
        return tensor, [tensor.get_plate(plates)]
    r_plates = []
    for plate in plates:
        if isinstance(plate, storch.Plate):
            r_plates.append(plate)
        else:
            r_plates.append(tensor.get_plate(plate))
    return tensor, r_plates
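
A minimal usage sketch; the storch.Plate constructor signature is inferred from Example No. 10 and may differ. A raw torch.Tensor must be accompanied by the plates describing its leading dimensions.

import torch
import storch

batch = storch.Plate("batch", 4, [])  # assumed signature: (name, n, parents)
t, plates = _handle_inputs(torch.randn(4, 3), [batch])
# t is a storch.Tensor whose leading dimension was checked against batch.n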
Example No. 3
def unique(tensor: storch.Tensor, event_dim: Optional[int] = 0) -> storch.Tensor:
    with storch.ignore_wrapping():
        fl_tensor = torch.flatten(tensor, tensor.plate_dims)
        uniq, inverse_indexing = torch.unique(
            fl_tensor, return_inverse=True, dim=event_dim
        )
    inverse_indexing = storch.Tensor(
        inverse_indexing, [tensor], tensor.plates, "inv_index_" + tensor.name
    )
    uq_plate = UniquePlate(
        "uq_plate_" + tensor.name,
        uniq.shape[0],
        tensor.multi_dim_plates(),
        inverse_indexing,
    )
    return storch.Tensor(uniq, [tensor], [uq_plate], "unique_" + tensor.name)
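
For reference, the torch primitive this wraps: with dim=0 and return_inverse=True, torch.unique deduplicates rows and returns the inverse map that UniquePlate stores so the operation can be undone later.

import torch

x = torch.tensor([[0., 1.], [0., 1.], [2., 3.]])
uniq, inv = torch.unique(x, return_inverse=True, dim=0)
# uniq: [[0., 1.], [2., 3.]]; inv: [0, 0, 1]; uniq[inv] reconstructs x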
Example No. 4
    def undo_unique(self, unique_tensor: storch.Tensor) -> storch.Tensor:
        """
        Convert the unique tensor back to the non-unique format, then add the old plates back in.
        # TODO: Make sure self.shrunken_plates is added
        # TODO: What if unique_tensor contains new plates after the unique?
        :param unique_tensor: Tensor indexed by this unique plate.
        :return: The tensor with duplicates restored and the shrunken plates added back.
        """
        plate_idx = unique_tensor.get_plate_dim_index(self.name)
        with storch.ignore_wrapping():
            dim_swapped = unique_tensor.transpose(
                plate_idx, unique_tensor.plate_dims - 1
            )
            fl_selected = torch.index_select(
                dim_swapped, dim=0, index=self.inv_indexing
            )
            selected = fl_selected.reshape(
                tuple(map(lambda p: p.n, self.shrunken_plates)) + fl_selected.shape[1:]
            )
            return storch.Tensor(
                selected,
                [unique_tensor],
                self.shrunken_plates + unique_tensor.plates,
                "undo_unique_" + unique_tensor.name,
            )
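
At the torch level, the core of this round trip is index_select with the stored inverse indices, continuing the small example from Example No. 3:

import torch

restored = torch.index_select(uniq, dim=0, index=inv)
# restored equals the original x: each row is recovered from its unique representative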
Example No. 5
    def create_plate(self, plate_size: int,
                     plates: List[storch.Plate]) -> AncestralPlate:
        plate = super().create_plate(plate_size, plates)
        plate.perturb_log_probs = storch.Tensor(
            self.perturbed_log_probs._tensor,
            [self.perturbed_log_probs],
            self.perturbed_log_probs.plates + [plate],
        )
        return plate
Example No. 6
def grad(
    outputs,
    inputs,
    grad_outputs=None,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    only_inputs: bool = True,
    allow_unused: bool = False,
) -> Tuple[storch.Tensor, ...]:
    """
    Helper method for computing torch.autograd.grad on storch tensors. Returns storch Tensors as well.
    """
    args, _, _, _ = storch.wrappers._prepare_args(
        [outputs, inputs, grad_outputs], {}, unwrap=True, align_tensors=False,
    )
    _outputs, _inputs, _grad_outputs = tuple(args)
    grads = torch.autograd.grad(
        _outputs,
        _inputs,
        grad_outputs=_grad_outputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        only_inputs=only_inputs,
        allow_unused=allow_unused,
    )
    storch_grad = []
    for i, g in enumerate(grads):
        inp = inputs[i]  # renamed from `input` to avoid shadowing the builtin
        if isinstance(inp, storch.Tensor):
            storch_grad.append(
                storch.Tensor(
                    g, outputs + [inp], inp.plates, inp.name + "_grad"
                )
            )
        else:
            storch_grad.append(storch.Tensor(g, outputs, [], "grad"))
    return tuple(storch_grad)
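
A minimal usage sketch, assuming a storch.Tensor can be constructed directly as in Example No. 13. The helper unwraps the inputs, defers to torch.autograd.grad, and re-wraps each gradient with its input's plates and a derived name.

import torch
import storch

x = storch.Tensor(torch.randn(3, requires_grad=True), [], [], "x")
y = storch.Tensor((x._tensor ** 2).sum(), [x], [], "y")
(dx,) = grad([y], [x])  # dx is a storch.Tensor named "x_grad"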
Example No. 7
def _process_stochastic(output: torch.Tensor, parents: List[storch.Tensor],
                        plates: List[storch.Plate]):
    if isinstance(output, storch.Tensor):
        if not output.stochastic:
            # TODO: This calls _add_parents, so something may be going wrong here
            # The Tensor was created by calling @deterministic within a stochastic context.
            # This means that we have to conservatively assume it is dependent on the parents
            output._add_parents(storch.wrappers._stochastic_parents)
        return output
    if isinstance(output, torch.Tensor):
        t = storch.Tensor(output, parents, plates)
        return t
    else:
        raise TypeError(
            "All outputs of functions wrapped in @storch.stochastic "
            "should be Tensors. At " + str(output))
Example No. 8
def _prepare_outputs_det(
    o: Any,
    parents: List[storch.Tensor],
    plates: List[storch.Plate],
    name: str,
    index: int,
    unflatten_plates: bool,
):
    if o is None:
        return None, index
    if isinstance(o, storch.Tensor):
        if o.stochastic:
            raise RuntimeError(
                "Creation of stochastic storch Tensor within deterministic context"
            )
        # TODO: Does this require shape checking? Parent/Plate checking?
        return o, index + 1
    if isinstance(o, torch.Tensor):  # Explicitly _not_ a storch.Tensor
        if unflatten_plates:
            plate_dims = tuple([plate.n for plate in plates if plate.n > 1])
            o = o.reshape(plate_dims + o.shape[1:])
        t = storch.Tensor(o, parents, plates, name=name + str(index))
        return t, index + 1
    if is_iterable(o):
        outputs = []
        for _o in o:
            t, index = _prepare_outputs_det(_o,
                                            parents,
                                            plates,
                                            name,
                                            index,
                                            unflatten_plates=unflatten_plates)
            outputs.append(t)
        if isinstance(o, tuple):
            return tuple(outputs), index
        return outputs, index
    raise NotImplementedError(
        "Handling of other types of return values is currently not implemented: ",
        o)
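
A hedged sketch of what the recursion supports (@storch.deterministic is referenced in Example No. 7): wrapped functions may return nested tuples or lists, and each tensor leaf is wrapped and given an indexed name.

import torch
import storch

@storch.deterministic
def stats(x):
    # The tuple is unpacked element by element by _prepare_outputs_det
    return x.mean(), x.std()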
Example No. 9
    def estimator(
        self, tensor: storch.StochasticTensor, cost_node: storch.CostTensor
    ) -> Tuple[Optional[storch.Tensor], Optional[storch.Tensor]]:
        # Note: We automatically multiply with leave-one-out ratio in the plate reduction
        plate = None
        for _p in cost_node.plates:
            if _p.name == self.plate_name:
                plate = _p
                break
        if not self.use_baseline:
            # No baseline: return just the score-function surrogate (made consistent with the return type)
            return plate.log_probs * cost_node.detach(), None

        # Subtract the 'average' cost of other samples, keeping in mind that samples are not independent.
        # plates x k
        baseline = storch.sum(
            (plate.log_probs + plate.log_snd_leave_one_out).exp() * cost_node,
            plate)
        # Make sure the k dimension is recognized as batch dimension
        baseline = storch.Tensor(baseline._tensor, [baseline],
                                 baseline.plates + [plate])
        return plate.log_probs, (
            1 - magic_box(plate.log_probs)) * baseline.detach()
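
For reference, magic_box is presumably the DiCE operator of Foerster et al. (2018): it evaluates to exactly 1 in the forward pass but carries the score-function gradient, so the (1 - magic_box(...)) * baseline term is zero in value yet still reduces gradient variance. A minimal sketch:

import torch

def magic_box(l: torch.Tensor) -> torch.Tensor:
    # Forward value is exactly 1; the gradient of exp(l - l.detach()) w.r.t. l's parameters is dl
    return torch.exp(l - l.detach())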
Example No. 10
    def __init__(
        self,
        name: str,
        n: int,
        parents: List[storch.Plate],
        variable_index: int,
        parent_plate: AncestralPlate,
        selected_samples: storch.Tensor,
        log_probs: storch.Tensor,
        weight: Optional[storch.Tensor] = None,
    ):
        super().__init__(name, n, parents, weight)
        assert (not parent_plate and variable_index == 0) or (
            parent_plate.n <= self.n
            and parent_plate.variable_index < variable_index
        )
        self.parent_plate = parent_plate
        self.selected_samples = selected_samples
        self.log_probs = storch.Tensor(
            log_probs._tensor, [log_probs], log_probs.plates + [self]
        )
        self.variable_index = variable_index
        self._in_recursion = False
        self._override_equality = False
Example No. 11
    def decode(
        self,
        distr: Distribution,
        joint_log_probs: Optional[storch.Tensor],
        parents: List[storch.Tensor],
        orig_distr_plates: List[storch.Plate],
    ) -> Tuple[storch.Tensor, storch.Tensor, storch.Tensor]:
        """
        Decode given the input arguments.
        :param distr: The distribution to decode.
        :param joint_log_probs: The log probabilities of the samples so far. prev_plates x amt_samples
        :param parents: List of parents of this tensor.
        :param orig_distr_plates: List of plates from the distribution. Can include the self plate k.
        :return: 3-tuple of `storch.Tensor`. 1: The sampled value. 2: The new joint log probabilities of the samples.
        3: How the samples index the parent samples. Can just be a range if there is no choosing happening.
        """
        ancestral_distrplate_index = -1
        is_conditional_sample = False

        multi_dim_distr_plates = []
        multi_dim_index = 0
        for plate in orig_distr_plates:
            if plate.n > 1:
                if plate.name == self.plate_name:
                    ancestral_distrplate_index = multi_dim_index
                    is_conditional_sample = True
                else:
                    multi_dim_distr_plates.append(plate)
                multi_dim_index += 1
        # plates? x k x events

        # TODO: This doesn't properly combine two ancestral plates with the same name but different variable index
        #  (they should merge).
        all_multi_dim_plates = multi_dim_distr_plates.copy()
        if self.variable_index > 0:
            # Previous variables have been sampled. add the prev_plates to all_plates
            for plate in self.joint_log_probs.multi_dim_plates():
                if plate not in multi_dim_distr_plates:
                    all_multi_dim_plates.append(plate)

        amt_multi_dim_plates = len(all_multi_dim_plates)
        amt_multi_dim_distr_plates = len(multi_dim_distr_plates)
        amt_multi_dim_orig_distr_plates = amt_multi_dim_distr_plates + (
            1 if is_conditional_sample else 0
        )
        amt_multi_dim_prev_plates = amt_multi_dim_plates - amt_multi_dim_distr_plates
        if not distr.has_enumerate_support:
            raise ValueError("Can only decode distributions with enumerable support.")

        with storch.ignore_wrapping():
            # |D_yv| x (1,) * (|distr_plates| + |k?| + |events|) x event_shape
            support_non_expanded: torch.Tensor = distr.enumerate_support(expand=False)
            # Compute the log-probability of the different events
            # |D_yv| x distr_plate[0] x ... k? ... x distr_plate[n-1] x events
            d_log_probs = distr.log_prob(support_non_expanded)

            # Note: Use amt_multi_dim_orig_distr_plates here because it may include the k? dimension; amt_multi_dim_distr_plates filters that one out.
            # distr_plate[0] x ... k? ... x distr_plate[n-1] x |D_yv| x events
            d_log_probs = storch.Tensor(
                d_log_probs.permute(
                    tuple(range(1, amt_multi_dim_orig_distr_plates + 1))
                    + (0,)
                    + tuple(
                        range(
                            amt_multi_dim_orig_distr_plates + 1, len(d_log_probs.shape)
                        )
                    )
                ),
                [],
                orig_distr_plates,
            )

        # |D_yv| x distr_plate[0] x ... x k? x ... x distr_plate[n-1] x events x event_shape
        support = distr.enumerate_support(expand=True)

        if is_conditional_sample:
            # Reduce ancestral dimension in the support. As the dimension is just an expanded version, this should
            # not change the underlying data.
            # |D_yv| x distr_plates x events x event_shape
            support = support[(slice(None),) * (ancestral_distrplate_index + 1) + (0,)]

            # Gather the correct log probabilities
            # distr_plate[0] x ... k ... x distr_plate[n-1] x |D_yv| x events
            # TODO: Move this down below to the other scary TODO
            d_log_probs = self.new_plate.on_unwrap_tensor(d_log_probs)
            # Permute the dimensions of d_log_probs so that the k dimension comes after the plates.
            for i, plate in enumerate(d_log_probs.multi_dim_plates()):
                if plate.name == self.plate_name:
                    d_log_probs.plates.remove(plate)
                    # distr_plates x k x |D_yv| x events
                    d_log_probs._tensor = d_log_probs._tensor.permute(
                        tuple(range(0, i))
                        + tuple(range(i + 1, amt_multi_dim_orig_distr_plates))
                        + (i,)
                        + tuple(
                            range(
                                amt_multi_dim_orig_distr_plates, len(d_log_probs.shape)
                            )
                        )
                    )
                    break

        # Equal to event_shape
        element_shape = distr.event_shape
        support_permutation = (
            tuple(range(1, amt_multi_dim_distr_plates + 1))
            + (0,)
            + tuple(range(amt_multi_dim_distr_plates + 1, len(support.shape)))
        )
        # distr_plates x |D_yv| x events x event_shape
        support = support.permute(support_permutation)

        if amt_multi_dim_plates != amt_multi_dim_distr_plates:
            # If previous samples had plate dimensions that are not in the distribution plates, add these to the support.
            support = support[
                (slice(None),) * amt_multi_dim_distr_plates
                + (None,) * amt_multi_dim_prev_plates
            ]
            all_plate_dims = tuple(map(lambda _p: _p.n, all_multi_dim_plates))
            # plates x |D_yv| x events x event_shape (where plates = distr_plates x prev_plates)
            support = support.expand(
                all_plate_dims + (-1,) * (len(support.shape) - amt_multi_dim_plates)
            )
        # plates x |D_yv| x events x event_shape
        support = storch.Tensor(support, [], all_multi_dim_plates)

        # Equal to events: shape of the different conditionally independent dimensions
        event_shape = support.shape[
            amt_multi_dim_plates + 1 : -len(element_shape)
            if len(element_shape) > 0
            else None
        ]

        ranges = []
        for size in event_shape:
            ranges.append(list(range(size)))

        amt_samples = 0
        parent_indexing = None
        if joint_log_probs is not None:
            # Initialize a tensor (parent_indexing) that keeps track of which samples link to previous choices of samples
            # Note that joint_log_probs.shape[-1] is amt_samples, not k. It's possible that amt_samples < k!
            amt_samples = joint_log_probs.shape[-1]
            # plates x k
            parent_indexing = support.new_zeros(
                size=support.shape[:amt_multi_dim_plates] + (self.k,), dtype=torch.long
            )

            # TODO: This can probably go wrong if plates are missing.
            parent_indexing[..., :amt_samples] = left_expand_as(
                torch.arange(amt_samples), parent_indexing
            )
        # plates x k x events
        sampled_support_indices = support.new_zeros(
            size=support.shape[:amt_multi_dim_plates]  # plates
            + (self.k,)
            + support.shape[
                amt_multi_dim_plates + 1 : -len(element_shape)
                if len(element_shape) > 0
                else None
            ],  # events
            dtype=torch.long,
        )
        # Sample independent tensors in sequence
        # Iterate over the different (conditionally) independent samples being taken (the events)
        for indices in itertools.product(*ranges):
            # Log probabilities of the different options for this sample step (event)
            # distr_plates x k? x |D_yv|
            yv_log_probs = d_log_probs[(...,) + indices]
            (
                sampled_support_indices,
                joint_log_probs,
                parent_indexing,
                amt_samples,
            ) = self.decode_step(
                indices,
                yv_log_probs,
                joint_log_probs,
                sampled_support_indices,
                parent_indexing,
                is_conditional_sample,
                amt_multi_dim_plates,
                amt_samples,
            )
        # Finally, index the support using the sampled indices to get the sample!
        if amt_samples < self.k:
            # plates x amt_samples x events
            sampled_support_indices = sampled_support_indices[
                (...,) + (slice(amt_samples),) + (slice(None),) * len(ranges)
            ]
        expanded_indices = right_expand_as(sampled_support_indices, support)
        sample = support.gather(dim=amt_multi_dim_plates, index=expanded_indices)
        return sample, joint_log_probs, parent_indexing
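
For reference, the torch.distributions API the decoder relies on: enumerate_support(expand=False) returns one entry per support element with singleton batch dims, which log_prob then broadcasts over the batch, while expand=True materializes the full |D_yv| x batch layout.

import torch
from torch.distributions import Categorical

d = Categorical(probs=torch.tensor([[0.3, 0.7], [0.5, 0.5]]))  # batch_shape (2,)
d.enumerate_support(expand=False).shape              # torch.Size([2, 1]): |D_yv| x 1
d.enumerate_support(expand=True).shape               # torch.Size([2, 2]): |D_yv| x batch
d.log_prob(d.enumerate_support(expand=False)).shape  # torch.Size([2, 2])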
Example No. 12
    def plate_weighting(self, tensor: storch.StochasticTensor,
                        plate: storch.Plate) -> Optional[storch.Tensor]:
        # plates_w_k is a sequence of plates of which one is the input ancestral plate

        # Computes p(s) * R(S^k, s), or the probability of the sample times the leave-one-out ratio.
        # For details, see https://openreview.net/pdf?id=rklEj2EFvB
        # Code based on https://github.com/wouterkool/estimating-gradients-without-replacement/blob/master/bernoulli/gumbel.py
        log_probs = plate.log_probs.detach()
        # print("--------------------")

        # Compute integration points for the trapezoid rule: v should range from 0 to 1, where both v=0 and v=1 give a value of 0.
        # As the computation happens in log-space, take the logarithm of the result.
        # N
        v = (
            torch.arange(1, self.num_int_points, out=log_probs._tensor.new()) /
            self.num_int_points)
        log_v = v.log()

        # Compute log(1-v^{exp(log_probs+a)}) in a numerically stable way in log-space
        # Uses the gumbel_log_survival function from
        # https://github.com/wouterkool/estimating-gradients-without-replacement/blob/master/bernoulli/gumbel.py
        # plates_w_k x N
        g_bound = (log_probs[..., None] + self.a +
                   torch.log(-log_v)[log_probs.plate_dims * (None, ) +
                                     (slice(None), )])

        # Gumbel log survival: log P(g > g_bound) = log(1 - exp(-exp(-g_bound))) for standard gumbel g
        # If g_bound >= 10, use the series expansion for stability with error O((e^-10)^6) (=8.7E-27)
        # See https://www.wolframalpha.com/input/?i=log%281+-+exp%28-y%29%29
        y = torch.exp(-g_bound)
        # plates_w_k x N
        terms = torch.where(g_bound >= 10,
                            -g_bound - y / 2 + y**2 / 24 - y**4 / 2880,
                            log1mexp(y))
        # print("terms", terms._tensor[0])

        # Compute integrands (without subtracting the special value s)
        # plates x N
        sum_of_terms = storch.sum(terms, plate)
        phi_S = storch.logsumexp(log_probs, plate)
        phi_D_min_S = log1mexp(phi_S)

        # plates x N
        integrand = (sum_of_terms +
                     torch.expm1(self.a + phi_D_min_S)[..., None] *
                     log_v[phi_D_min_S.plate_dims * (None, ) +
                           (slice(None), )])

        # Subtract the term for the element that is left out in R
        # Automatically unsqueezes correctly using plate dimensions
        # plates_w_k x N
        integrand_without_s = integrand - terms

        # plates
        log_p_S = integrand.logsumexp(dim=-1)
        # plates_w_k
        log_p_S_without_s = integrand_without_s.logsumexp(dim=-1)

        # plates_w_k
        log_leave_one_out = log_p_S_without_s - log_p_S

        if self.comp_leave_two_out:
            # Compute the integrands for the 2nd order leave one out ratio.
            # Make sure to properly choose the indices: We shouldn't subtract the same term twice on the diagonals.
            # k x k
            skip_diag = storch.Tensor(
                1 - torch.eye(plate.n, out=log_probs._tensor.new()), [],
                [plate])
            # plates_w_k x k x N
            integrand_without_ss = (integrand_without_s[..., None, :] -
                                    terms[..., None, :] * skip_diag[..., None])
            # plates_w_k x k
            log_p_S_without_ss = integrand_without_ss.logsumexp(dim=-1)

            plate.log_snd_leave_one_out = log_p_S_without_ss - log_p_S_without_s

        # print("lloo", log_leave_one_out._tensor[0])
        # print("log_probs", log_probs._tensor[0])
        # print("weighting", (log_leave_one_out + log_probs).exp()._tensor[0])

        # Return the unordered set estimator weighting
        return (log_leave_one_out + log_probs).exp().detach()
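
The snippet calls log1mexp without defining it; a standard numerically stable implementation (matching the helper in the gumbel.py the comments cite) is:

import torch

def log1mexp(x: torch.Tensor) -> torch.Tensor:
    # Computes log(1 - exp(-|x|)) stably; see Maechler (2012)
    x = -x.abs()
    return torch.where(x > -0.693, torch.log(-torch.expm1(x)), torch.log1p(-torch.exp(x)))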
Example No. 13
def to_storch(tensor: torch.Tensor) -> storch.Tensor:
    return storch.Tensor(tensor, [], [], "test")