Exemplo n.º 1
0
    def reduce(self, tensor: storch.Tensor, detach_weights=True):
        """Reduce ``tensor`` over this plate, weighting each sample by ``self.weight``.

        Args:
            tensor: The tensor to reduce over this plate.
            detach_weights: When True, ``self.weight`` is detached first so no
                gradient flows through the weights.

        Returns:
            The weighted reduction of ``tensor`` over this plate.

        Raises:
            ValueError: If the weight has two or more dimensions and ``tensor``
                lacks one of this plate's parent plates.
        """
        weights = self.weight.detach() if detach_weights else self.weight

        # Single-sample plate: scale by the weight while removing the plate.
        if self.n == 1:
            return storch.reduce(lambda x: x * weights, self.name)(tensor)

        # Scalar weight: sum first, then scale once (usually this takes the mean).
        if weights.ndim == 0:
            return storch.sum(tensor, self) * weights

        if weights.ndim == 1:
            # One weight per sample, independent of other batch dimensions.
            # Append singleton axes so the weight lines up with this plate's
            # dimension in `tensor` and broadcasts over all axes after it.
            plate_index = tensor.get_plate_dim_index(self.name)
            trailing_axes = tensor.ndim - plate_index - 1
            weights = weights[(Ellipsis,) + trailing_axes * (None,)]
        else:
            # Weights span the batch dimensions (assumed to be a storch.Tensor);
            # every parent plate must therefore be present on `tensor`.
            for parent in self.parents:
                if parent not in tensor.plates:
                    raise ValueError(
                        "Plate missing when reducing tensor: " + parent.name
                    )

        return storch.sum(tensor * weights, self)
Exemplo n.º 2
0
 def multiplicative_estimator(
         self, tensor: storch.StochasticTensor,
         cost_node: storch.CostTensor) -> Optional[storch.Tensor]:
     """Return the importance-weighted estimator term for ``cost_node``.

     All sums run over the plate of ``cost_node`` named ``self.plate_name``.
     Importance weights come from ``self.sampling_method.compute_iw``. The
     equation numbers in the comments refer to the paper this estimator was
     implemented from.

     Args:
         tensor: The stochastic tensor being estimated (not used here).
         cost_node: The cost whose gradient is estimated. It must carry a
             plate named ``self.plate_name``.

     Returns:
         The estimator term, summed over the sampling plate.

     Raises:
         ValueError: If ``cost_node`` has no plate named ``self.plate_name``.
     """
     cost_plate = next(
         (_p for _p in cost_node.plates if _p.name == self.plate_name), None)
     if cost_plate is None:
         # Fail here with a clear message rather than handing ``None`` to
         # ``compute_iw``, where it would only fail later and less clearly.
         raise ValueError(
             "Cost node has no plate named '{}'".format(self.plate_name))

     if self.use_baseline:
         iw = self.sampling_method.compute_iw(cost_plate, biased=False)
         # Importance-weighted sum of the costs, used as the baseline.
         BS = storch.sum(iw * cost_node, cost_plate)
         probs = cost_plate.log_probs.exp()
         if self.biased:
             # Equation 11
             WS = storch.sum(iw, cost_plate)
             WiS = (WS - iw + probs).detach()
             diff_cost = cost_node - BS / WS
             return storch.sum(iw / WiS * diff_cost.detach(), cost_plate)
         # Equation 10
         weighted_cost = cost_node * (1 - probs + iw)
         diff_cost = weighted_cost - BS
         return storch.sum(iw * diff_cost.detach(), cost_plate)

     # Equation 9
     # TODO: This seems inefficient... The plate should already contain the IW, right? Same for above if not self.biased
     iw = self.sampling_method.compute_iw(cost_plate, self.biased)
     return storch.sum(cost_node.detach() * iw, self.plate_name)
Exemplo n.º 3
0
 def weighting_function(self, tensor: storch.StochasticTensor,
                        plate: storch.Plate) -> Optional[storch.Tensor]:
     """Weight samples by ``1 - finished_samples``, normalized over ``plate``.

     When ``self.eos`` is false, the parent class's weighting is used instead.
     Otherwise ``active = 1 - self.finished_samples`` is divided by its sum
     over ``plate``.
     """
     if not self.eos:
         return super().weighting_function(tensor, plate)
     # NOTE(review): presumably finished_samples is a 0/1 indicator, making this
     # a mask of still-active samples and the result a uniform weight over them.
     still_active = 1 - self.finished_samples
     return still_active / storch.sum(still_active, plate)
Exemplo n.º 4
0
 def compute_iw(self, plate: AncestralPlate, biased: bool):
     """Compute importance weights ``p / q`` for the samples in ``plate``.

     ``q = 1 - exp(-exp(log_probs - kappa))``, where kappa is the perturbed
     log-probability of the k-th sample. The k-th sample has 0 weight and is
     only used to compute the importance weights of the others.

     Args:
         plate: Plate providing ``log_probs`` and ``perturb_log_probs``.
         biased: If True, divide the weights by their detached sum over
             ``plate``.

     Returns:
         The importance weights (normalized when ``biased``).
     """
     # Perturbed log-probability of the k-th sample (kappa), kept as a
     # trailing axis of size 1 so it broadcasts against log_probs.
     kappa = plate.perturb_log_probs._tensor[..., self.k - 1].unsqueeze(-1)
     # q = 1 - exp(-exp(log_ratio)), written as -expm1(-exp(log_ratio)).
     # The naive form loses all precision when exp(log_ratio) is tiny (1 - x
     # with x ~ 1 rounds to 0), which would blow the weight up to p / EPS.
     log_ratio = plate.log_probs - kappa
     q = (-(-log_ratio.exp()).expm1()).detach()
     iw = plate.log_probs.exp() / (q + self.EPS)
     # Set the weight of the kth sample (kappa) to 0.
     iw[..., self.k - 1] = 0.0
     if biased:
         WS = storch.sum(iw, plate).detach()
         return iw / WS
     return iw
Exemplo n.º 5
0
 def compute_baseline(self, tensor: StochasticTensor,
                      costs: CostTensor) -> torch.Tensor:
     """Leave-one-out batch-average baseline for the samples of ``tensor``.

     For each sample, the detached costs are summed over the dimension named
     ``tensor.name``. The sample's own cost is subtracted and the result is
     divided by ``tensor.n - 1``. Assuming that dimension holds ``tensor.n``
     samples, this is the mean cost of the other samples.

     Raises:
         ValueError: If ``tensor.n == 1``, since no other samples exist.
     """
     sample_count = tensor.n
     if sample_count == 1:
         raise ValueError(
             "Can only use the batch average baseline if multiple samples are used."
         )
     fixed_costs = costs.detach()
     # TODO: Should reduce correctly
     total = storch.sum(fixed_costs, tensor.name)
     return (total - fixed_costs) / (sample_count - 1)
Exemplo n.º 6
0
def b_binary_cross_entropy(
    input: storch.Tensor,
    target: storch.Tensor,
    dims: Union[str, List[str]] = None,
    weight=None,
    reduction: str = "mean",
):
    r"""Function that measures the Binary Cross Entropy in a batched way
    between the target and the output.

    The loss is first computed element-wise with
    :func:`torch.nn.functional.binary_cross_entropy` (``reduction="none"``).
    It is then reduced over all event dimensions of the result, plus any
    extra dimensions given in ``dims``. See :class:`~torch.nn.BCELoss` for details.

    Args:
        input: Tensor of arbitrary shape
        target: Tensor that can be expanded to the shape of ``input``
        dims (str or list of str, optional): additional dimensions, given by
            name, to reduce over besides the event dimensions. Default: none.
        weight (Tensor, optional): a manual rescaling weight
                if provided it's repeated to match input tensor shape
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the mean over the reduced dimensions is taken,
            ``'sum'``: the output will be summed over the reduced dimensions.
            Default: ``'mean'``

    Returns:
        The element-wise loss for ``'none'``, otherwise the reduced loss.

    Raises:
        ValueError: If ``reduction`` is not ``'none'``, ``'mean'`` or ``'sum'``.

    Examples::

        >>> input = torch.randn((3, 2), requires_grad=True)
        >>> target = torch.rand((3, 2), requires_grad=False)
        >>> loss = b_binary_cross_entropy(torch.sigmoid(input), target)
        >>> loss.backward()
    """
    # Reject unknown values up front so a typo cannot silently yield None.
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(
            "{} is not a valid value for reduction".format(reduction)
        )

    if not dims:
        dims = []
    elif isinstance(dims, str):
        dims = [dims]
    else:
        # Copy, and accept any iterable of names (e.g. a tuple), which the
        # list concatenation below would otherwise reject.
        dims = list(dims)

    target = target.expand_as(input)
    unreduced = deterministic(F.binary_cross_entropy)(
        input, target, weight, reduction="none"
    )
    if reduction == "none":
        return unreduced

    # Reduce over every event dimension of the loss plus the requested dims.
    indices = list(unreduced.event_dim_indices) + dims
    if reduction == "mean":
        return storch.mean(unreduced, indices)
    return storch.sum(unreduced, indices)
Exemplo n.º 7
0
    def estimator(
        self, tensor: storch.StochasticTensor, cost_node: storch.CostTensor
    ) -> Tuple[Optional[storch.Tensor], Optional[storch.Tensor]]:
        """Compute the gradient estimator terms for ``cost_node``.

        Note: We automatically multiply with leave-one-out ratio in the plate
        reduction.

        Args:
            tensor: The stochastic tensor being estimated (not used here).
            cost_node: The cost; it must carry a plate named ``self.plate_name``.

        Returns:
            Without a baseline: the single tensor ``log_probs * cost_node.detach()``.
            With a baseline: the pair ``(log_probs, additive)``, where
            ``additive = (1 - magic_box(log_probs)) * baseline.detach()``.

        Raises:
            ValueError: If ``cost_node`` has no plate named ``self.plate_name``.
        """
        plate = next(
            (_p for _p in cost_node.plates if _p.name == self.plate_name), None
        )
        if plate is None:
            # Fail clearly instead of dereferencing None below.
            raise ValueError(
                "Cost node has no plate named '{}'".format(self.plate_name)
            )

        if not self.use_baseline:
            # NOTE(review): this branch returns a single tensor although the
            # annotation promises a tuple; confirm the caller accepts both forms.
            return plate.log_probs * cost_node.detach()

        # Subtract the 'average' cost of other samples, keeping in mind that
        # samples are not independent. `log_snd_leave_one_out` is presumably
        # set by the plate weighting when leave-two-out ratios are computed.
        # plates x k
        baseline = storch.sum(
            (plate.log_probs + plate.log_snd_leave_one_out).exp() * cost_node,
            plate,
        )
        # Make sure the k dimension is recognized as batch dimension
        baseline = storch.Tensor(
            baseline._tensor, [baseline], baseline.plates + [plate]
        )
        return plate.log_probs, (
            1 - magic_box(plate.log_probs)) * baseline.detach()
Exemplo n.º 8
0
    def plate_weighting(self, tensor: storch.StochasticTensor,
                        plate: storch.Plate) -> Optional[storch.Tensor]:
        """Return the unordered set estimator weighting ``p(s) * R(S^k, s)``.

        This is the probability of each sample times its leave-one-out ratio.
        It is computed in log-space by numerical integration with the trapezoid
        rule over ``self.num_int_points - 1`` interior points. ``self.a`` is an
        offset added inside the integrand (presumably for numerical stability,
        as in the referenced code).

        When ``self.comp_leave_two_out`` is set, the second-order log ratios
        are also stored on ``plate.log_snd_leave_one_out``.

        For details, see https://openreview.net/pdf?id=rklEj2EFvB
        Code based on https://github.com/wouterkool/estimating-gradients-without-replacement/blob/master/bernoulli/gumbel.py

        Args:
            tensor: The stochastic tensor (not used here).
            plate: The ancestral plate holding the samples' ``log_probs``;
                ``plate.n`` is the number of samples k.

        Returns:
            The detached weights ``exp(log R + log p)``.
        """
        # plates_w_k is a sequence of plates of which one is the input ancestral plate
        log_probs = plate.log_probs.detach()

        # Integration points for the trapezoid rule: v ranges over (0, 1). The
        # end points v=0 and v=1 give a value of 0, so only the interior points
        # 1/num_int_points, ..., (num_int_points-1)/num_int_points are used.
        # The computation happens in log-space, so take the logarithm.
        # N (= num_int_points - 1)
        v = (
            torch.arange(1, self.num_int_points, out=log_probs._tensor.new()) /
            self.num_int_points)
        log_v = v.log()

        # log(1 - v^{exp(log_probs + a)}) = log(1 - exp(-exp(g_bound))),
        # with g_bound = log_probs + a + log(-log v).
        # plates_w_k x N
        g_bound = (log_probs[..., None] + self.a +
                   torch.log(-log_v)[log_probs.plate_dims * (None, ) +
                                     (slice(None), )])

        # log(1 - exp(-y)) with y = exp(g_bound) (Gumbel log survival).
        # For small y, i.e. g_bound <= -10 (y <= e^-10), use the series
        # log(y) - y/2 + y^2/24 - y^4/2880 with error O(y^6) <= (e^-10)^6
        # (=8.7E-27); here log(y) is simply g_bound. This stays finite even
        # where y underflows to 0.
        # See https://www.wolframalpha.com/input/?i=log%281+-+exp%28-y%29%29
        # Otherwise use log1mexp(y).
        y = torch.exp(g_bound)
        # plates_w_k x N
        terms = torch.where(g_bound <= -10,
                            g_bound - y / 2 + y**2 / 24 - y**4 / 2880,
                            log1mexp(y))

        # Compute integrands (without subtracting the special value s)
        # plates x N
        sum_of_terms = storch.sum(terms, plate)
        phi_S = storch.logsumexp(log_probs, plate)
        phi_D_min_S = log1mexp(phi_S)

        # plates x N
        integrand = (sum_of_terms +
                     torch.expm1(self.a + phi_D_min_S)[..., None] *
                     log_v[phi_D_min_S.plate_dims * (None, ) +
                           (slice(None), )])

        # Subtract the term of the element that is left out in R.
        # Automatically unsqueezes correctly using plate dimensions.
        # plates_w_k x N
        integrand_without_s = integrand - terms

        # plates
        log_p_S = integrand.logsumexp(dim=-1)
        # plates_w_k
        log_p_S_without_s = integrand_without_s.logsumexp(dim=-1)

        # plates_w_k
        log_leave_one_out = log_p_S_without_s - log_p_S

        if self.comp_leave_two_out:
            # Compute the integrands for the 2nd order leave one out ratio.
            # Make sure to properly choose the indices: We shouldn't subtract the same term twice on the diagonals.
            # k x k
            skip_diag = storch.Tensor(
                1 - torch.eye(plate.n, out=log_probs._tensor.new()), [],
                [plate])
            # NOTE(review): `terms` carries `plate` as a plate dimension, so
            # `terms[..., None, :]` looks indexed by the same sample as
            # `integrand_without_s`, not by the second left-out sample.
            # Confirm against the referenced implementation.
            # plates_w_k x k x N
            integrand_without_ss = (integrand_without_s[..., None, :] -
                                    terms[..., None, :] * skip_diag[..., None])
            # plates_w_k x k
            log_p_S_without_ss = integrand_without_ss.logsumexp(dim=-1)

            plate.log_snd_leave_one_out = log_p_S_without_ss - log_p_S_without_s

        # Return the unordered set estimator weighting
        return (log_leave_one_out + log_probs).exp().detach()