def test_safe_log():
    # On strictly positive inputs, safe_log must agree with torch.log both in
    # value and in gradient.
    positive = torch.randn(1000).exp().requires_grad_()
    reference = positive.log()
    result = safe_log(positive)
    assert_equal(result, reference)
    (grad_result,) = grad(result.sum(), [positive])
    (grad_reference,) = grad(reference.sum(), [positive])
    assert_equal(grad_result, grad_reference)

    # At zero, the gradient of torch.log is not finite, whereas the gradient
    # of safe_log is.
    zero = torch.tensor(0.0, requires_grad=True)
    (grad_log,) = grad(zero.log(), [zero])
    assert not torch.isfinite(grad_log)
    (grad_safe,) = grad(safe_log(zero), [zero])
    assert torch.isfinite(grad_safe)
def quantize_enumerate(x_real, min, max):
    """
    Randomly quantize in a way that preserves probability mass.

    We use a piecewise polynomial spline of order 3. Each entry of ``x_real``
    is split into its floor ``lb`` and fractional part ``s``; the four
    candidate integers ``lb - 1, lb, lb + 1, lb + 2`` receive the cubic
    B-spline weights ``(1/6) * [t^3, 3s^3 - 6s^2 + 4, 3t^3 - 6t^2 + 4, s^3]``
    with ``t = 1 - s``, which sum to one.

    :param torch.Tensor x_real: Real-valued tensor to quantize.
    :param min: Lower bound. Used in ``assert min < max`` and in broadcast
        arithmetic against a tensor with an extra trailing dimension, so
        presumably a Python number — TODO confirm for tensor bounds.
    :param max: Upper bound (same caveat as ``min``).
    :returns: A pair ``(x, logits)``, each of shape ``x_real.shape + (4,)``:
        the candidate values and the log of their spline weights.
    """
    assert min < max
    lb = x_real.detach().floor()

    # This cubic spline interpolates over the nearest four integers, ensuring
    # piecewise quadratic gradients.
    s = x_real - lb
    ss = s * s
    t = 1 - s
    tt = t * t
    probs = torch.stack(
        [
            t * tt,
            4 + ss * (3 * s - 6),
            4 + tt * (3 * t - 6),
            s * ss,
        ],
        dim=-1,
    ) * (1 / 6)
    logits = safe_log(probs)

    # Allocate the offsets on the input's device so CUDA inputs do not hit a
    # device mismatch; the default dtype is kept so type promotion is unchanged.
    q = torch.arange(-1.0, 3.0, device=lb.device)
    x = lb.unsqueeze(-1) + q

    # Reflect candidates below ``min`` about ``min - 1/2`` and candidates above
    # ``max`` about ``max + 1/2`` (a single reflection each).
    x = torch.max(x, 2 * min - 1 - x)
    x = torch.min(x, 2 * max + 1 - x)

    return x, logits
def quantize_enumerate(x_real, min, max, num_quant_bins=4):
    """
    Quantize, then manually enumerate.

    ``x_real`` is split into its floor ``lb`` and fractional part. The
    fractional part is converted into bin probabilities by
    ``compute_bin_probs``. The candidate values are ``lb + q`` for integer
    offsets ``q`` in ``[1 - num_quant_bins // 2, 1 + num_quant_bins // 2)``,
    and they are returned together with the log of those probabilities.

    :param torch.Tensor x_real: Real-valued tensor to quantize.
    :param min: Lower bound. The use of ``_all`` / ``_unsqueeze`` suggests
        either numbers or tensors broadcastable with ``x_real`` are accepted —
        TODO confirm.
    :param max: Upper bound (same caveat as ``min``).
    :param int num_quant_bins: Number of quantization bins passed to
        ``compute_bin_probs``; also determines the candidate offsets.
    :returns: A pair ``(x, logits)`` with candidate values along a new
        trailing dimension and their log-probabilities.
    """
    assert _all(min < max)
    lb = x_real.detach().floor()

    probs = compute_bin_probs(x_real - lb, num_quant_bins=num_quant_bins)
    logits = safe_log(probs)

    arange_min = 1 - num_quant_bins // 2
    arange_max = 1 + num_quant_bins // 2
    # Allocate the offsets on the input's device so CUDA inputs do not hit a
    # device mismatch when added to ``lb``.
    q = torch.arange(arange_min, arange_max, device=lb.device)

    x = lb.unsqueeze(-1) + q
    # Reflect candidates below ``min`` about ``min - 1/2`` and candidates above
    # ``max`` about ``max + 1/2`` (a single reflection each).
    x = torch.max(x, 2 * _unsqueeze(min) - 1 - x)
    x = torch.min(x, 2 * _unsqueeze(max) + 1 - x)

    return x, logits
def log_prob(self, value):
    """
    Computes likelihood as in equations 7-8 of [3].

    This has time complexity ``O(T + S N log(N))`` where ``T`` is the number
    of time steps, ``N`` is the number of leaves, and
    ``S = sample_shape.numel()`` is the number of samples of ``value``.

    This is differentiable wrt ``rate_grid`` but neither ``leaf_times`` nor
    ``value = coal_times``.

    :param torch.Tensor value: A tensor of coalescent times. These denote sets
        of size ``leaf_times.size(-1) - 1`` along the trailing dimension and
        should be sorted along that dimension.
    :returns: Likelihood ``p(coal_times | leaf_times, rate_grid)``
    :rtype: torch.Tensor
    """
    if self._validate_args:
        self._validate_sample(value)

    coal_times = value
    phylogeny = _make_phylogeny(self.leaf_times, coal_times)

    # Survival factors for closed intervals: integrate rate_grid over each
    # interval using its cumulative sum, left-padded with a zero.
    padded_cumsum = torch.nn.functional.pad(
        self.rate_grid.cumsum(-1), (1, 0), value=0
    )
    # Dropping times[..., 0] ignores the final lonely leaf.
    cum_rate = _interpolate_gather(padded_cumsum, phylogeny.times[..., 1:])
    interval_rate = cum_rate[..., :-1] - cum_rate[..., 1:]
    # Clamp to the smallest positive normal float to avoid nan.
    smallest = torch.finfo(interval_rate.dtype).tiny
    interval_rate = interval_rate.clamp(min=smallest)
    weighted = phylogeny.binomial[..., 1:-1] * interval_rate
    log_prob = -weighted.sum(-1)

    # Density of coalescent events: the rate in the grid cell containing each
    # coalescent time, scaled by the coalescent binomial factor.
    bin_index = coal_times.floor().clamp(min=0, max=self.duration - 1).long()
    event_rate = phylogeny.coal_binomial * _gather(self.rate_grid, -1, bin_index)
    log_prob = log_prob + safe_log(event_rate).sum(-1)

    out_shape = broadcast_shape(self.batch_shape, value.shape[:-1])
    return log_prob.expand(out_shape)
def einsum(equation, *operands):
    """
    Log-sum-exp implementation of einsum.

    Operands are treated as log-space tensors. For each operand, a detached
    max over the dimensions that are summed out ("shift") is subtracted before
    exponentiating. A regular ``torch.einsum`` is applied to the exponentials,
    its log is taken, and the shifts are added back (permuted/broadcast to the
    output layout).

    :param str equation: Einsum equation in explicit ``"inputs->output"`` form.
    :param operands: Tensors, one per comma-separated input term.
    :returns: The log-space contraction result.
    :raises ValueError: If the equation uses more than 52 distinct symbols,
        which cannot all be renamed onto distinct letters.
    """
    # Rename symbols to support PyTorch 0.4.1 and earlier, which allow only
    # symbols a-z. Up to 26 symbols map onto a-z exactly as before; further
    # symbols map onto A-Z so every symbol is renamed and renamed symbols can
    # never collide with ones left untouched.
    symbols = sorted(set(equation) - set(",->"))
    letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
    if len(symbols) > len(letters):
        raise ValueError(
            "einsum supports at most {} distinct symbols, got {}".format(
                len(letters), len(symbols)
            )
        )
    rename = dict(zip(symbols, letters))
    equation = "".join(rename.get(s, s) for s in equation)

    inputs, output = equation.split("->")
    if inputs == output:
        return operands[0][...]  # create a new object
    inputs = inputs.split(",")

    shifts = []
    exp_operands = []
    for dims, operand in zip(inputs, operands):
        shift = operand.detach()
        for i, dim in enumerate(dims):
            if dim not in output:
                # keepdim=True keeps positions aligned with ``dims``.
                shift = shift.max(i, keepdim=True)[0]
        # avoid nan due to -inf - -inf
        shift = shift.clamp(min=torch.finfo(shift.dtype).min)
        exp_operands.append((operand - shift).exp())

        # permute shift to match output: drop the reduced (size-1) dims,
        # left-pad with size-1 dims for output symbols this operand lacks,
        # then reorder into output order so it broadcasts against the result.
        shift = shift.reshape(
            torch.Size(size for size, dim in zip(operand.shape, dims) if dim in output)
        )
        if shift.dim():
            shift = shift.reshape((1,) * (len(output) - shift.dim()) + shift.shape)
            dims = [dim for dim in dims if dim in output]
            dims = [dim for dim in output if dim not in dims] + dims
            shift = shift.permute(*(dims.index(dim) for dim in output))
        shifts.append(shift)

    result = safe_log(torch.einsum(equation, exp_operands))
    return sum(shifts + [result])