def reduce(self, tensor: storch.Tensor, detach_weights=True):
    plate_weighting = self.weight
    if detach_weights:
        plate_weighting = self.weight.detach()
    if self.n == 1:
        return storch.reduce(lambda x: x * plate_weighting, self.name)(tensor)
    # Case: The weight is a single number. First sum, then multiply with the weight (usually taking the mean)
    elif plate_weighting.ndim == 0:
        return storch.sum(tensor, self) * plate_weighting
    # Case: There is a weight for each element of the plate, independent of the other batch dimensions
    elif plate_weighting.ndim == 1:
        index = tensor.get_plate_dim_index(self.name)
        plate_weighting = plate_weighting[
            (...,) + (None,) * (tensor.ndim - index - 1)
        ]
        weighted_tensor = tensor * plate_weighting
        return storch.sum(weighted_tensor, self)
    # Case: The weight also varies over the other batch dimensions. Assumes it is a storch.Tensor
    else:
        for parent_plate in self.parents:
            if parent_plate not in tensor.plates:
                raise ValueError(
                    "Plate missing when reducing tensor: " + parent_plate.name
                )
        weighted_tensor = tensor * plate_weighting
        return storch.sum(weighted_tensor, self)
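As a sketch of what this reduction computes (my reading of the code above, not an official specification): writing $x_i$ for the entries of the tensor along this plate and $w_i$ for the (possibly broadcast) plate weights, every branch evaluates the same weighted sum

$$
\operatorname{reduce}(x) \;=\; \sum_{i=1}^{n} w_i\, x_i,
$$

where a scalar weight $w$ corresponds to $w_i = w$ for all $i$ (so $w = 1/n$ recovers the plain mean), and the higher-dimensional cases only differ in how $w_i$ is broadcast against the remaining batch dimensions.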
def multiplicative_estimator(
    self, tensor: storch.StochasticTensor, cost_node: storch.CostTensor
) -> Optional[storch.Tensor]:
    cost_plate = None
    for _p in cost_node.plates:
        if _p.name == self.plate_name:
            cost_plate = _p
            break
    if self.use_baseline:
        iw = self.sampling_method.compute_iw(cost_plate, biased=False)
        BS = storch.sum(iw * cost_node, cost_plate)
        probs = cost_plate.log_probs.exp()
        if self.biased:
            # Equation 11
            WS = storch.sum(iw, cost_plate)
            WiS = (WS - iw + probs).detach()
            diff_cost = cost_node - BS / WS
            return storch.sum(iw / WiS * diff_cost.detach(), cost_plate)
        else:
            # Equation 10
            weighted_cost = cost_node * (1 - probs + iw)
            diff_cost = weighted_cost - BS
            return storch.sum(iw * diff_cost.detach(), cost_plate)
    else:
        # Equation 9
        # TODO: This seems inefficient... The plate should already contain the IW, right?
        # Same for above if not self.biased
        iw = self.sampling_method.compute_iw(cost_plate, self.biased)
        return storch.sum(cost_node.detach() * iw, self.plate_name)
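Transcribing the three branches of multiplicative_estimator into formulas (my own reading of the code; $f(s)$ is the cost, $w_s$ the importance weights from compute_iw, $p_s = \exp(\texttt{log\_probs}_s)$, and the equation numbers are those referenced in the comments):

$$
\begin{aligned}
\text{Eq. 9 (no baseline):}\quad & \sum_s w_s\, f(s) \\
\text{Eq. 10 (unbiased):}\quad & \sum_s w_s \big[ f(s)\,(1 - p_s + w_s) - B_S \big], \qquad B_S = \sum_{s'} w_{s'} f(s') \\
\text{Eq. 11 (biased):}\quad & \sum_s \frac{w_s}{W_S - w_s + p_s} \Big[ f(s) - \tfrac{B_S}{W_S} \Big], \qquad W_S = \sum_{s'} w_{s'}
\end{aligned}
$$

with the bracketed cost terms detached, so gradients flow only through the importance weights.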
def weighting_function(
    self, tensor: storch.StochasticTensor, plate: storch.Plate
) -> Optional[storch.Tensor]:
    if self.eos:
        # Only weight samples that have not yet finished (reached EOS),
        # normalizing by the number of still-active samples.
        active = 1 - self.finished_samples
        amt_active: storch.Tensor = storch.sum(active, plate)
        return active / amt_active
    return super().weighting_function(tensor, plate)
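In the EOS branch this simply renormalizes over the sequences that have not yet terminated (my summary of the code; $\text{active}_i$ is the indicator that sample $i$ has not finished):

$$
w_i \;=\; \frac{\text{active}_i}{\sum_j \text{active}_j},
$$

so finished sequences receive zero weight and the remaining mass is spread uniformly over the still-active ones.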
def compute_iw(self, plate: AncestralPlate, biased: bool):
    # Compute importance weights. The kth sample has 0 weight, and is only used to compute the importance weights
    q = (
        1
        - (
            -(
                plate.log_probs
                - plate.perturb_log_probs._tensor[..., self.k - 1].unsqueeze(-1)
            ).exp()
        ).exp()
    ).detach()
    iw = plate.log_probs.exp() / (q + self.EPS)
    # Set the weight of the kth sample (kappa) to 0.
    iw[..., self.k - 1] = 0.0
    if biased:
        # Normalize the importance weights so they sum to 1 over the plate.
        WS = storch.sum(iw, plate).detach()
        return iw / WS
    return iw
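Reading the code, with $\phi_s = \texttt{log\_probs}_s$ and $\kappa$ the perturbed log-probability of the $k$-th sample, the weight of each sample is its probability divided by the probability that its perturbed value exceeds $\kappa$:

$$
q_{\kappa}(s) \;=\; 1 - \exp\!\big(-\exp(\phi_s - \kappa)\big),
\qquad
w_s \;=\; \frac{\exp(\phi_s)}{q_{\kappa}(s) + \epsilon},
\qquad
w_{s_k} = 0,
$$

and in the biased variant the weights are additionally normalized to sum to one over the plate.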
def compute_baseline(
    self, tensor: StochasticTensor, costs: CostTensor
) -> torch.Tensor:
    if tensor.n == 1:
        raise ValueError(
            "Can only use the batch average baseline if multiple samples are used."
        )
    costs = costs.detach()
    sum_costs = storch.sum(costs, tensor.name)
    # TODO: Should reduce correctly
    baseline = (sum_costs - costs) / (tensor.n - 1)
    return baseline
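In formula form, the baseline for sample $i$ is the leave-one-out average of the detached costs of the other $n - 1$ samples, which is exactly what the subtraction computes:

$$
b_i \;=\; \frac{1}{n-1} \sum_{j \neq i} c_j \;=\; \frac{\big(\sum_{j} c_j\big) - c_i}{\,n - 1\,}.
$$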
def b_binary_cross_entropy(
    input: storch.Tensor,
    target: storch.Tensor,
    dims: Union[str, List[str]] = None,
    weight=None,
    reduction: str = "mean",
):
    r"""Function that measures the Binary Cross Entropy between the target and the output in a batched (plate-aware) way.

    See :class:`~torch.nn.BCELoss` for details.

    Args:
        input: Tensor of arbitrary shape
        target: Tensor of the same shape as input
        dims (str or list of str, optional): event dimensions to reduce over, in addition to the
            event dimensions of the unreduced loss
        weight (Tensor, optional): a manual rescaling weight;
            if provided it's repeated to match the input tensor shape
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``

    Examples::

        >>> input = torch.randn((3, 2), requires_grad=True)
        >>> target = torch.rand((3, 2), requires_grad=False)
        >>> loss = b_binary_cross_entropy(torch.sigmoid(input), target)
        >>> loss.backward()
    """
    if not dims:
        dims = []
    if isinstance(dims, str):
        dims = [dims]

    target = target.expand_as(input)
    unreduced = deterministic(F.binary_cross_entropy)(
        input, target, weight, reduction="none"
    )

    indices = list(unreduced.event_dim_indices) + dims
    if reduction == "mean":
        return storch.mean(unreduced, indices)
    elif reduction == "sum":
        return storch.sum(unreduced, indices)
    elif reduction == "none":
        return unreduced
    raise ValueError("Unknown reduction: " + reduction)
def estimator(
    self, tensor: storch.StochasticTensor, cost_node: storch.CostTensor
) -> Tuple[Optional[storch.Tensor], Optional[storch.Tensor]]:
    # Note: We automatically multiply with the leave-one-out ratio in the plate reduction
    plate = None
    for _p in cost_node.plates:
        if _p.name == self.plate_name:
            plate = _p
            break
    if not self.use_baseline:
        # Plain score-function term, no additive baseline correction.
        return plate.log_probs, None
    # Subtract the 'average' cost of other samples, keeping in mind that samples are not independent.
    # plates x k
    baseline = storch.sum(
        (plate.log_probs + plate.log_snd_leave_one_out).exp() * cost_node, plate
    )
    # Make sure the k dimension is recognized as batch dimension
    baseline = storch.Tensor(
        baseline._tensor, [baseline], baseline.plates + [plate]
    )
    return plate.log_probs, (1 - magic_box(plate.log_probs)) * baseline.detach()
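For the baseline branch, my transcription of the returned pair: the first element is the score-function term $\log p_\theta(s)$, and the second is an additive correction built from the pairwise second-order leave-one-out ratio $R_2$ stored on the plate,

$$
b_s \;=\; \sum_{s'} p_\theta(s')\, R_2(s, s')\, f(s'),
\qquad
\text{additive term} \;=\; \big(1 - \operatorname{MagicBox}(\log p_\theta(s))\big)\, b_s,
$$

which is zero in the forward pass (MagicBox evaluates to 1) and contributes $-b_s\, \nabla_\theta \log p_\theta(s)$ to the gradient, i.e. it subtracts the baseline from the cost inside the score-function term.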
def plate_weighting(
    self, tensor: storch.StochasticTensor, plate: storch.Plate
) -> Optional[storch.Tensor]:
    # plates_w_k is a sequence of plates of which one is the input ancestral plate

    # Computes p(s) * R(S^k, s), or the probability of the sample times the leave-one-out ratio.
    # For details, see https://openreview.net/pdf?id=rklEj2EFvB
    # Code based on https://github.com/wouterkool/estimating-gradients-without-replacement/blob/master/bernoulli/gumbel.py
    log_probs = plate.log_probs.detach()

    # Compute integration points for the trapezoid rule: v should range from 0 to 1,
    # where both v=0 and v=1 give a value of 0.
    # As the computation happens in log-space, take the logarithm of the result.
    # N
    v = (
        torch.arange(1, self.num_int_points, out=log_probs._tensor.new())
        / self.num_int_points
    )
    log_v = v.log()

    # Compute log(1 - v^{exp(log_probs + a)}) in a numerically stable way in log-space.
    # Uses the gumbel_log_survival function from
    # https://github.com/wouterkool/estimating-gradients-without-replacement/blob/master/bernoulli/gumbel.py
    # plates_w_k x N
    g_bound = (
        log_probs[..., None]
        + self.a
        + torch.log(-log_v)[log_probs.plate_dims * (None,) + (slice(None),)]
    )

    # Gumbel log survival: log P(g > g_bound) = log(1 - exp(-exp(-g_bound))) for standard gumbel g
    # If g_bound >= 10, use the series expansion for stability with error O((e^-10)^6) (=8.7E-27)
    # See https://www.wolframalpha.com/input/?i=log%281+-+exp%28-y%29%29
    y = torch.exp(g_bound)
    # plates_w_k x N
    terms = torch.where(
        g_bound >= 10, -g_bound - y / 2 + y ** 2 / 24 - y ** 4 / 2880, log1mexp(y)
    )

    # Compute integrands (without subtracting the special value s)
    # plates x N
    sum_of_terms = storch.sum(terms, plate)
    phi_S = storch.logsumexp(log_probs, plate)
    phi_D_min_S = log1mexp(phi_S)

    # plates x N
    integrand = (
        sum_of_terms
        + torch.expm1(self.a + phi_D_min_S)[..., None]
        * log_v[phi_D_min_S.plate_dims * (None,) + (slice(None),)]
    )

    # Subtract one term for the element that is left out in R
    # Automatically unsqueezes correctly using plate dimensions
    # plates_w_k x N
    integrand_without_s = integrand - terms

    # plates
    log_p_S = integrand.logsumexp(dim=-1)
    # plates_w_k
    log_p_S_without_s = integrand_without_s.logsumexp(dim=-1)
    # plates_w_k
    log_leave_one_out = log_p_S_without_s - log_p_S

    if self.comp_leave_two_out:
        # Compute the integrands for the 2nd order leave one out ratio.
        # Make sure to properly choose the indices: We shouldn't subtract the same term twice on the diagonals.
        # k x k
        skip_diag = storch.Tensor(
            1 - torch.eye(plate.n, out=log_probs._tensor.new()), [], [plate]
        )
        # plates_w_k x k x N
        integrand_without_ss = (
            integrand_without_s[..., None, :]
            - terms[..., None, :] * skip_diag[..., None]
        )
        # plates_w_k x k
        log_p_S_without_ss = integrand_without_ss.logsumexp(dim=-1)
        plate.log_snd_leave_one_out = log_p_S_without_ss - log_p_S_without_s

    # Return the unordered set estimator weighting
    return (log_leave_one_out + log_probs).exp().detach()
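The returned weighting is therefore the probability of each sample multiplied by its leave-one-out ratio, with the ratio obtained from the numerical integrals above (my summary; the log-ratio is exactly the difference computed in the code):

$$
w(s) \;=\; p_\theta(s)\, R(S^k, s),
\qquad
\log R(S^k, s) \;=\; \texttt{log\_p\_S\_without\_s} - \texttt{log\_p\_S}.
$$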