Exemplo n.º 1
0
    def sample(
        self,
        distr: Distribution,
        parents: [storch.Tensor],
        plates: [Plate],
        requires_grad: bool,
    ) -> (storch.StochasticTensor, Plate):
        """
        Enumerate every joint value of ``distr`` so that an expectation can be computed exactly.

        Each non-plate batch dimension position of ``distr`` is enumerated independently, so the
        enumeration contains ``|support| ** (product of the non-plate batch dimension sizes)``
        values. These are stacked along a new plate named ``self.plate_name``.

        :param distr: Distribution to enumerate. Must have enumerable support.
        :param parents: Parents of the resulting stochastic tensor.
        :param plates: Plates of ``distr``; the leading ``len(plates)`` batch dimensions of ``distr``
            are treated as plate dimensions. The new plate is inserted at position 0 of this list
            (the list is mutated in place).
        :param requires_grad: Forwarded to ``storch.StochasticTensor``.
        :return: The enumerated ``storch.StochasticTensor`` (detached) and the newly created ``Plate``.
        :raises ValueError: If ``distr`` has no enumerable support, or if the enumeration would
            contain more than ``self.budget`` values.
        """
        # TODO: Currently very inefficient as it isn't batched
        # TODO: What if the expectation has a parent
        if not distr.has_enumerate_support:
            raise ValueError(
                "Can only calculate the expected value for distributions with enumerable support."
            )
        # |support| x plates x sizes x event_shape
        support: torch.Tensor = distr.enumerate_support(expand=True)
        # |support| x (1,) * batch dims x event_shape
        support_non_expanded: torch.Tensor = distr.enumerate_support(expand=False)
        expect_size = support.shape[0]

        batch_len = len(plates)
        # Non-plate batch dimensions: every position in them is enumerated independently.
        sizes = support.shape[batch_len + 1 : len(support.shape) - len(distr.event_shape)]
        # Number of independent positions (product of the sizes, 1 when there are none).
        # Each position takes any of the expect_size support values, so the joint enumeration
        # has expect_size ** cross_products entries.
        cross_products = 1
        for dim in sizes:
            cross_products *= dim
        amt_samples_used = expect_size**cross_products

        if amt_samples_used > self.budget:
            raise ValueError(
                "Computing the expectation on this distribution would exceed the computation budget."
            )

        # amt_samples_used x plates x sizes x event_shape
        enumerate_tensor = support.new_zeros(
            [amt_samples_used] + list(support.shape[1:])
        )
        # One row per support value, flattened over the event dims. Using reshape instead of
        # squeeze keeps size-1 event dimensions and a single-valued support intact.
        flat_support = support_non_expanded.reshape(expect_size, -1)
        # Shape of one joint value; it broadcasts over the leading plate dimensions on assignment.
        value_shape = tuple(sizes) + tuple(distr.event_shape)
        for i, combination in enumerate(
            itertools.product(flat_support, repeat=cross_products)
        ):
            enumerate_tensor[i] = torch.stack(combination, dim=0).reshape(value_shape)

        enumerate_tensor = enumerate_tensor.detach()

        plate_size = enumerate_tensor.shape[0]

        plate = Plate(self.plate_name, plate_size, plates.copy())
        plates.insert(0, plate)

        s_tensor = storch.StochasticTensor(
            enumerate_tensor,
            parents,
            plates,
            self.plate_name,
            plate_size,
            distr,
            requires_grad,
        )
        return s_tensor, plate
Exemplo n.º 2
0
    def decode(
        self,
        distr: Distribution,
        joint_log_probs: Optional[storch.Tensor],
        parents: [storch.Tensor],
        orig_distr_plates: [storch.Plate],
    ) -> (storch.Tensor, storch.Tensor, storch.Tensor):
        """
        Decode given the input arguments.

        Enumerates the support of ``distr`` and computes its log-probabilities. It then visits every
        (conditionally) independent event position in sequence, delegating the choice of samples
        for that position to ``self.decode_step``. Finally it gathers the chosen support values.

        :param distr: The distribution to decode. Must have enumerable support.
        :param joint_log_probs: The log probabilities of the samples so far. prev_plates x amt_samples.
            May be None; then amt_samples starts at 0 and parent_indexing starts as None (both are
            subsequently whatever self.decode_step returns).
        :param parents: List of parents of this tensor (not used directly in this method).
        :param orig_distr_plates: List of plates from the distribution. Can include the self plate k.
        :return: 3-tuple of `storch.Tensor`. 1: The sampled value. 2: The new joint log probabilities of the samples.
        3: How the samples index the parent samples. Can just be a range if there is no choosing happening.
        :raises ValueError: If ``distr`` does not have enumerable support.
        """
        # Position of the self plate (k) among the multi-dimensional (n > 1) distribution plates; -1 if absent.
        ancestral_distrplate_index = -1
        # True when orig_distr_plates contains a multi-dimensional plate named self.plate_name.
        is_conditional_sample = False

        # Multi-dimensional distribution plates, excluding the self plate k.
        multi_dim_distr_plates = []
        # Running dimension index over the multi-dimensional plates (k included).
        multi_dim_index = 0
        for plate in orig_distr_plates:
            if plate.n > 1:
                if plate.name == self.plate_name:
                    ancestral_distrplate_index = multi_dim_index
                    is_conditional_sample = True
                else:
                    multi_dim_distr_plates.append(plate)
                multi_dim_index += 1
        # plates? x k x events x

        # TODO: This doesn't properly combine two ancestral plates with the same name but different variable index
        #  (they should merge).
        all_multi_dim_plates = multi_dim_distr_plates.copy()
        if self.variable_index > 0:
            # Previous variables have been sampled. add the prev_plates to all_plates
            # NOTE(review): this reads self.joint_log_probs rather than the joint_log_probs argument;
            #  presumably the caller passes self.joint_log_probs so both are the same object — confirm.
            for plate in self.joint_log_probs.multi_dim_plates():
                if plate not in multi_dim_distr_plates:
                    all_multi_dim_plates.append(plate)

        amt_multi_dim_plates = len(all_multi_dim_plates)
        amt_multi_dim_distr_plates = len(multi_dim_distr_plates)
        # Like amt_multi_dim_distr_plates, but also counting the k plate when it is present.
        amt_multi_dim_orig_distr_plates = amt_multi_dim_distr_plates + (
            1 if is_conditional_sample else 0
        )
        # Plates that come only from the previous samples, not from the distribution.
        amt_multi_dim_prev_plates = amt_multi_dim_plates - amt_multi_dim_distr_plates
        if not distr.has_enumerate_support:
            raise ValueError("Can only decode distributions with enumerable support.")

        # NOTE(review): presumably disables storch's automatic wrapping of torch results inside this
        #  block — confirm against storch.ignore_wrapping.
        with storch.ignore_wrapping():
            # |D_yv| x (|distr_plates| + |k?| + |event_dims|) * (1,) x |D_yv|
            support_non_expanded: torch.Tensor = distr.enumerate_support(expand=False)
            # Compute the log-probability of the different events
            # |D_yv| x distr_plate[0] x ... k? ... x distr_plate[n-1] x events
            d_log_probs = distr.log_prob(support_non_expanded)

            # Note: Use amt_orig_distr_plates here because it might include k? dimension. amt_distr_plates filters this one.
            # distr_plate[0] x ... k? ... x distr_plate[n-1] x |D_yv| x events
            d_log_probs = storch.Tensor(
                d_log_probs.permute(
                    tuple(range(1, amt_multi_dim_orig_distr_plates + 1))
                    + (0,)
                    + tuple(
                        range(
                            amt_multi_dim_orig_distr_plates + 1, len(d_log_probs.shape)
                        )
                    )
                ),
                [],
                orig_distr_plates,
            )

        # |D_yv| x distr_plate[0] x ... x k? x ... x distr_plate[n-1] x events x event_shape
        support = distr.enumerate_support(expand=True)

        if is_conditional_sample:
            # Reduce ancestral dimension in the support. As the dimension is just an expanded version, this should
            # not change the underlying data.
            # |D_yv| x distr_plates x events x event_shape
            support = support[(slice(None),) * (ancestral_distrplate_index + 1) + (0,)]

            # Gather the correct log probabilities
            # distr_plate[0] x ... k ... x distr_plate[n-1] x |D_yv| x events
            # TODO: Move this down below to the other scary TODO
            d_log_probs = self.new_plate.on_unwrap_tensor(d_log_probs)
            # Permute the dimensions of d_log_probs st the k dimension is after the plates.
            # Only the first plate named self.plate_name is handled: it is removed from the plate list,
            # its dimension is moved behind the other plates, and the loop stops.
            for i, plate in enumerate(d_log_probs.multi_dim_plates()):
                if plate.name == self.plate_name:
                    d_log_probs.plates.remove(plate)
                    # distr_plates x k x |D_yv| x events
                    d_log_probs._tensor = d_log_probs._tensor.permute(
                        tuple(range(0, i))
                        + tuple(range(i + 1, amt_multi_dim_orig_distr_plates))
                        + (i,)
                        + tuple(
                            range(
                                amt_multi_dim_orig_distr_plates, len(d_log_probs.shape)
                            )
                        )
                    )
                    break

        # Equal to event_shape
        element_shape = distr.event_shape
        # Move the |D_yv| dimension from the front to just behind the distribution plates.
        support_permutation = (
            tuple(range(1, amt_multi_dim_distr_plates + 1))
            + (0,)
            + tuple(range(amt_multi_dim_distr_plates + 1, len(support.shape)))
        )
        # distr_plates x |D_yv| x events x event_shape
        support = support.permute(support_permutation)

        if amt_multi_dim_plates != amt_multi_dim_distr_plates:
            # If previous samples had plate dimensions that are not in the distribution plates, add these to the support.
            support = support[
                (slice(None),) * amt_multi_dim_distr_plates
                + (None,) * amt_multi_dim_prev_plates
            ]
            all_plate_dims = tuple(map(lambda _p: _p.n, all_multi_dim_plates))
            # plates x |D_yv| x events x event_shape (where plates = distr_plates x prev_plates)
            support = support.expand(
                all_plate_dims + (-1,) * (len(support.shape) - amt_multi_dim_plates)
            )
        # plates x |D_yv| x events x event_shape
        support = storch.Tensor(support, [], all_multi_dim_plates)

        # Equal to events: Shape for the different conditional independent dimensions
        event_shape = support.shape[
            amt_multi_dim_plates + 1 : -len(element_shape)
            if len(element_shape) > 0
            else None
        ]

        # One index range per event dimension; their cartesian product enumerates every event position.
        ranges = []
        for size in event_shape:
            ranges.append(list(range(size)))

        # Without previous samples, amt_samples starts at 0 and parent_indexing at None;
        # decode_step returns the updated values of both.
        amt_samples = 0
        parent_indexing = None
        if joint_log_probs is not None:
            # Initialize a tensor (self.parent_indexing) that keeps track of what samples link to previous choices of samples
            # Note that joint_log_probs.shape[-1] is amt_samples, not k. It's possible that amt_samples < k!
            amt_samples = joint_log_probs.shape[-1]
            # plates x k
            parent_indexing = support.new_zeros(
                size=support.shape[:amt_multi_dim_plates] + (self.k,), dtype=torch.long
            )

            # probably can go wrong if plates are missing.
            parent_indexing[..., :amt_samples] = left_expand_as(
                torch.arange(amt_samples), parent_indexing
            )
        # plates x k x events
        sampled_support_indices = support.new_zeros(
            size=support.shape[:amt_multi_dim_plates]  # plates
            + (self.k,)
            + support.shape[
                amt_multi_dim_plates + 1 : -len(element_shape)
                if len(element_shape) > 0
                else None
            ],  # events
            dtype=torch.long,
        )
        # Sample independent tensors in sequence
        # Iterate over the different (conditionally) independent samples being taken (the events)
        for indices in itertools.product(*ranges):
            # Log probabilities of the different options for this sample step (event)
            # distr_plates x k? x |D_yv|
            yv_log_probs = d_log_probs[(...,) + indices]
            (
                sampled_support_indices,
                joint_log_probs,
                parent_indexing,
                amt_samples,
            ) = self.decode_step(
                indices,
                yv_log_probs,
                joint_log_probs,
                sampled_support_indices,
                parent_indexing,
                is_conditional_sample,
                amt_multi_dim_plates,
                amt_samples,
            )
        # Finally, index the support using the sampled indices to get the sample!
        if amt_samples < self.k:
            # Fewer than k samples were produced: drop the unused sample slots.
            # plates x amt_samples x events
            sampled_support_indices = sampled_support_indices[
                (...,) + (slice(amt_samples),) + (slice(None),) * len(ranges)
            ]
        # NOTE(review): presumably right_expand_as broadcasts the indices over the trailing event_shape
        #  dims so gather can pick whole elements along the |D_yv| dimension — confirm.
        expanded_indices = right_expand_as(sampled_support_indices, support)
        sample = support.gather(dim=amt_multi_dim_plates, index=expanded_indices)
        return sample, joint_log_probs, parent_indexing