예제 #1
0
    def sample(
        self,
        distr: Distribution,
        parents: [storch.Tensor],
        plates: [Plate],
        requires_grad: bool,
    ) -> (storch.StochasticTensor, Plate):
        """Draw samples from ``distr`` and wrap them in a stochastic tensor.

        If ``plates`` already holds a plate named ``self.plate_name``, a single
        sample is requested from ``self.mc_sample`` and that plate is reused.
        Otherwise ``self.n_samples`` samples are requested, a new plate is
        created and it is inserted at the front of ``plates`` (the list passed
        by the caller is modified in place).

        :param distr: The distribution to sample from.
        :param parents: Forwarded to ``self.mc_sample`` and to the resulting
            ``storch.StochasticTensor``.
        :param plates: Plates of the distribution; may gain a new first entry.
        :param requires_grad: Forwarded to ``storch.StochasticTensor``.
        :return: The stochastic tensor and the plate (reused or newly created).
        """
        # Look for a plate of this method that is already present.
        found = next(
            (candidate for candidate in plates if candidate.name == self.plate_name),
            None,
        )

        if found:
            amount = 1
        else:
            amount = self.n_samples

        with storch.ignore_wrapping():
            drawn = self.mc_sample(distr, parents, plates, amount)

        # The leading dimension is the sample dimension; drop it when it is trivial.
        plate_size = drawn.shape[0]
        if plate_size == 1:
            drawn = drawn.squeeze(0)

        if found:
            plate = found
        else:
            # The new plate receives a snapshot of the plates *before* it is added itself.
            plate = Plate(self.plate_name, plate_size, plates.copy())
            plates.insert(0, plate)

        # StochasticTensor is built from the raw underlying tensor.
        raw = drawn._tensor if isinstance(drawn, storch.Tensor) else drawn

        wrapped = storch.StochasticTensor(
            raw,
            parents,
            plates,
            self.plate_name,
            plate_size,
            distr,
            requires_grad,
        )
        return wrapped, plate
예제 #2
0
    def sample(
        self,
        distr: Distribution,
        parents: [storch.Tensor],
        plates: [Plate],
        requires_grad: bool,
    ) -> (storch.StochasticTensor, Plate):
        """Enumerate every joint assignment of ``distr`` along a new plate.

        The support of ``distr`` is enumerated, and the cross product over all
        batch dimensions that are not plates is materialised row by row. The
        result is wrapped in a ``storch.StochasticTensor`` whose new plate is
        inserted at the front of ``plates`` (the caller's list is modified in
        place).

        :param distr: The distribution to enumerate; must have enumerable support.
        :param parents: Forwarded to the resulting ``storch.StochasticTensor``.
        :param plates: Plates of the distribution; gains the new plate as first entry.
        :param requires_grad: Forwarded to ``storch.StochasticTensor``.
        :return: The stochastic tensor holding all assignments, and the new plate.
        :raises ValueError: If ``distr`` has no enumerable support, or if the
            number of joint assignments exceeds ``self.budget``.
        """
        # TODO: Currently very inefficient as it isn't batched
        # TODO: What if the expectation has a parent
        if not distr.has_enumerate_support:
            raise ValueError(
                "Can only calculate the expected value for distributions with enumerable support."
            )
        support: torch.Tensor = distr.enumerate_support(expand=True)
        support_non_expanded: torch.Tensor = distr.enumerate_support(
            expand=False)
        expect_size = support.shape[0]

        # Batch dimensions that are neither the enumeration dimension (index 0),
        # a plate dimension, nor part of the event shape.
        batch_len = len(plates)
        sizes = support.shape[batch_len + 1:len(support.shape) -
                              len(distr.event_shape)]

        # Number of independent variables that are enumerated jointly. This is
        # the *product* of the sizes: the previous repeated exponentiation gave
        # a ** b for sizes (a, b), which disagreed with the row count below.
        cross_products = 1
        for dim in sizes:
            cross_products *= dim
        # Every variable takes expect_size values, so there are
        # expect_size ** cross_products joint assignments.
        amt_samples_used = expect_size ** cross_products

        if amt_samples_used > self.budget:
            raise ValueError(
                "Computing the expectation on this distribution would exceed the computation budget."
            )

        enumerate_tensor = support.new_zeros([amt_samples_used] +
                                             list(support.shape[1:]))
        support_non_expanded = support_non_expanded.squeeze().unsqueeze(1)
        # Layout of one joint assignment, without the plate dimensions; the
        # assignment below broadcasts over the plate dimensions.
        assignment_shape = tuple(sizes) + tuple(distr.event_shape)
        for i, t in enumerate(
                itertools.product(support_non_expanded,
                                  repeat=cross_products)):
            # t holds one value per independent variable; lay the values out
            # over the batch dimensions so multi-dimensional batches fit too.
            enumerate_tensor[i] = torch.cat(t, dim=0).reshape(assignment_shape)

        enumerate_tensor = enumerate_tensor.detach()

        plate_size = enumerate_tensor.shape[0]

        # The new plate receives a snapshot of the plates before it is added itself.
        plate = Plate(self.plate_name, plate_size, plates.copy())
        plates.insert(0, plate)

        s_tensor = storch.StochasticTensor(
            enumerate_tensor,
            parents,
            plates,
            self.plate_name,
            plate_size,
            distr,
            requires_grad,
        )
        return s_tensor, plate
예제 #3
0
    def sample(
        self,
        distr: Distribution,
        parents: [storch.Tensor],
        orig_distr_plates: [storch.Plate],
        requires_grad: bool,
    ) -> (torch.Tensor, storch.Plate):
        """Sample with the parent class, then replace the result by a Gumbel-based sample.

        The hard sample returned by ``super().sample`` is passed to
        ``conditional_gumbel_rsample`` together with ``distr.probs``, a flag
        telling whether ``distr`` is a Bernoulli, and ``self.temperature``.
        The outcome is wrapped in a new ``storch.StochasticTensor`` that reuses
        the hard sample's parents, plates and name.

        :param distr: The distribution to sample from; must expose ``probs``.
        :param parents: Forwarded to ``super().sample``.
        :param orig_distr_plates: Forwarded to ``super().sample``.
        :param requires_grad: Forwarded to ``super().sample`` and to the new
            stochastic tensor.
        :return: The new stochastic tensor and the plate returned by
            ``super().sample``.
        """
        hard_sample, plate = super().sample(
            distr, parents, orig_distr_plates, requires_grad
        )
        # Imported here rather than at module level, as in the original code.
        from storch import conditional_gumbel_rsample

        gumbel_wor = conditional_gumbel_rsample(
            hard_sample,
            distr.probs,
            isinstance(distr, torch.distributions.Bernoulli),
            self.temperature,
        )

        # Take the first element of every parent entry. Unlike
        # list(zip(*entries))[0], this yields an empty tuple instead of raising
        # IndexError when the hard sample has no parents.
        # Assumes every entry is indexable with the parent at position 0 — TODO confirm.
        hard_parents = tuple(entry[0] for entry in hard_sample._parents)

        gumbel_wor = storch.StochasticTensor(
            gumbel_wor._tensor,
            hard_parents,
            hard_sample.plates,
            hard_sample.name,
            self.k,
            distr,
            requires_grad,
        )
        return gumbel_wor, plate
예제 #4
0
파일: seq.py 프로젝트: dudugang/storchastic
    def sample(
        self,
        distr: Distribution,
        parents: [storch.Tensor],
        orig_distr_plates: [storch.Plate],
        requires_grad: bool,
    ) -> (torch.Tensor, storch.Plate):
        """
        Sample the next variable from ``distr``, given the sequence sampled so far.

        Decoding is delegated to ``self.decode``. Afterwards the bookkeeping on ``self`` is
        updated (``joint_log_probs``, ``parent_indexing``, ``finished_samples``, ``new_plate``,
        ``seq`` and ``variable_index``) and the samples are wrapped in a stochastic tensor.

        :param distr: The distribution to sample from
        :param parents: Passed to ``self.decode`` and to the resulting stochastic tensor
        :param orig_distr_plates: Plates on the parameters of the distribution. When the decoded
          samples are not a ``storch.Tensor`` this list itself is modified in place.
        :param requires_grad: The stochastic tensor requires grad if this is set or if
          ``self.joint_log_probs`` requires grad
        :return: The new stochastic tensor and the newly created plate
        """

        # OVERVIEW (NOTE(review): parts one to three presumably live in self.decode — confirm)
        # ========
        # The sampling procedure consists of three parts.
        # Part one prepares all necessary tensors so that they can be indexed easily.
        # It is long because several cases have to be distinguished:
        # 1) No variables have been sampled so far.
        # 2) Variables have been sampled, but their results are NOT used to compute the input distribution.
        #    In other words, the variable to sample is independent of the other sampled variables. They still
        #    have to be tracked so that sampling without replacement stays correct. In this case the
        #    ancestral plate is not in the plates attribute. We also have to make sure that we know in the
        #    future what samples are chosen for the _other_ samples.
        # 3) Parents have been sampled, and this variable depends on at least some of them.
        #    The plates list then contains the ancestral plate. The joint log probs have to be computed for
        #    the conditional samples (ie, based on the different sampled variables in the ancestral dimension).
        # Part two loops over all options for the event dimensions, sampling these conditionally
        # independent samples in sequence. It samples indexes, not events.
        # Part three, after the loop, matches the sampled indexes to the events to be used.

        # LEGEND FOR SHAPE COMMENTS
        # =========================
        # Generalizing to every bayesian network requires complicated shape management.
        # These are the shape names used within the method:
        #
        # distr_plates: the plates on the parameters of the distribution. Does *not* include
        #  the k? ancestral plate (possibly empty)
        # orig_distr_plates: the plates on the parameters of the distribution, and *can* include
        #  the k? ancestral plate (possibly empty)
        # prev_plates: the plates of the previous sampled variable in this swr sample (possibly empty)
        # plates: all plates except this ancestral plate, of which there are amt_plates. The first plates are
        #  the distr_plates, after that the prev_plates that are _not_ in distr_plates.
        #  It is composed of distr_plate x (ancstr_plates - distr_plates)
        # events: the conditionally independent dimensions of the distribution (the distributions batch shape
        #  minus the plates)
        # k: self.k
        # k?: an optional plate dimension of this ancestral plate. It either doesn't exist, or is the sample
        #  dimension. If it exists, this sample is conditionally dependent on ancestors.
        # |D_yv|: the *size* of the domain
        # amt_samples: the current amount of sampled sequences. amt_samples <= k, but it can be lower if there
        #  are not enough events to sample from (eg |D_yv| < k)
        # event_shape: the *shape* of the domain elements
        #  (can be 0, eg Categorical, or equal to |D_yv| for OneHotCategorical)

        # Run the decoding step and store the updated state.
        decoded = self.decode(distr, self.joint_log_probs, parents, orig_distr_plates)
        samples, self.joint_log_probs, self.parent_indexing = decoded

        # Record which sequences have reached the EOS token, so that EOS is always sampled after that.
        # Uses samples rather than the final stochastic tensor, so the ancestral plate is not included.
        # NOTE(review): this is a truthiness test, so an EOS value of 0 would skip the update — confirm
        # that 0 is never a valid EOS token.
        if self.eos:
            self.finished_samples = samples.eq(self.eos)

        # Unwrap the samples and find the index of the sample dimension.
        if isinstance(samples, storch.Tensor):
            k_index, plates, samples = samples.plate_dims, samples.plates, samples._tensor
        else:
            k_index, plates = 0, orig_distr_plates

        plate_size = samples.shape[k_index]

        # Drop the first plate carrying this method's name, if there is one.
        stale_plate = next(
            (candidate for candidate in plates if candidate.name == self.plate_name),
            None,
        )
        if stale_plate:
            plates.remove(stale_plate)

        # Create the updated plate from a snapshot of the remaining plates, then register it.
        self.new_plate = self.create_plate(plate_size, plates.copy())
        plates.append(self.new_plate)

        if self.parent_indexing is not None:
            self.parent_indexing.plates.append(self.new_plate)

        # Pass every previously sampled tensor through the new plate.
        self.seq = [self.new_plate.on_unwrap_tensor(previous) for previous in self.seq]

        # Construct the stochastic tensor and add it to the sequence.
        s_tensor = storch.StochasticTensor(
            samples,
            parents,
            plates,
            self.plate_name,
            plate_size,
            distr,
            requires_grad or self.joint_log_probs.requires_grad,
        )
        self.seq.append(s_tensor)

        # Move on to the next variable.
        self.variable_index += 1
        return s_tensor, self.new_plate