コード例 #1
0
ファイル: test_special.py プロジェクト: zeta1999/pyro
def test_safe_log():
    """Check ``safe_log`` against ``Tensor.log`` in value and in gradient."""
    # On strictly positive inputs, values and first derivatives must both
    # agree with the plain logarithm.
    positive = torch.randn(1000).exp().requires_grad_()
    reference = positive.log()
    candidate = safe_log(positive)
    assert_equal(candidate, reference)
    candidate_grad = grad(candidate.sum(), [positive])[0]
    reference_grad = grad(reference.sum(), [positive])[0]
    assert_equal(candidate_grad, reference_grad)

    # At exactly zero the plain log has a non-finite gradient, whereas the
    # gradient of safe_log is required to stay finite.
    zero = torch.tensor(0., requires_grad=True)
    plain_grad = grad(zero.log(), [zero])[0]
    assert not torch.isfinite(plain_grad)
    safe_grad = grad(safe_log(zero), [zero])[0]
    assert torch.isfinite(safe_grad)
コード例 #2
0
def quantize_enumerate(x_real, min, max):
    """
    Randomly quantize in a way that preserves probability mass.
    We use a piecewise polynomial spline of order 3.

    Every entry of ``x_real`` is spread over four consecutive integers
    ``floor(x) - 1, ..., floor(x) + 2``; the four spline weights sum to one.

    :param torch.Tensor x_real: Real-valued tensor to quantize.
    :param min: Lower bound of the quantized support. (The name shadows the
        builtin but is part of the public signature.)
    :param max: Upper bound of the quantized support; must satisfy
        ``min < max``.
    :returns: A pair ``(x, logits)``. ``x`` holds the candidate integer
        values with shape ``x_real.shape + (4,)`` and ``logits`` is
        ``safe_log`` of the matching spline weights.
    :rtype: tuple
    :raises AssertionError: If ``min < max`` does not hold.
    """
    assert min < max
    # Detach so that gradients flow only through the fractional part below.
    lb = x_real.detach().floor()

    # This cubic spline interpolates over the nearest four integers, ensuring
    # piecewise quadratic gradients.
    s = x_real - lb
    ss = s * s
    t = 1 - s
    tt = t * t
    probs = torch.stack(
        [
            t * tt,
            4 + ss * (3 * s - 6),
            4 + tt * (3 * t - 6),
            s * ss,
        ],
        dim=-1,
    ) * (1 / 6)
    logits = safe_log(probs)
    # Build the offsets on the same device as the input; a default-device
    # tensor would make the addition below fail for inputs living elsewhere.
    q = torch.arange(-1.0, 3.0, device=lb.device)

    x = lb.unsqueeze(-1) + q
    # Reflect candidates that fall below ``min`` about ``min - 1/2`` and those
    # above ``max`` about ``max + 1/2`` so they land back inside the bounds.
    x = torch.max(x, 2 * min - 1 - x)
    x = torch.min(x, 2 * max + 1 - x)
    return x, logits
コード例 #3
0
def quantize_enumerate(x_real, min, max, num_quant_bins=4):
    """
    Quantize, then manually enumerate.

    Each entry of ``x_real`` is spread over a window of consecutive integers
    around ``floor(x_real)``; the window covers the offsets
    ``1 - num_quant_bins // 2`` up to and including ``num_quant_bins // 2``.

    :param torch.Tensor x_real: Real-valued tensor to quantize.
    :param min: Lower bound of the quantized support. (The name shadows the
        builtin but is part of the public signature.)
    :param max: Upper bound of the quantized support; ``min < max`` must
        hold (checked through the ``_all`` helper).
    :param int num_quant_bins: Number of quantization bins, forwarded to
        ``compute_bin_probs`` and used to size the integer window.
    :returns: A pair ``(x, logits)`` where ``x`` holds the candidate integer
        values (one trailing dimension added to ``x_real``) and ``logits``
        is ``safe_log`` of the bin probabilities.
    :rtype: tuple
    :raises AssertionError: If the ``min < max`` check fails.
    """
    assert _all(min < max)
    # Detach so that gradients flow only through the fractional part.
    lb = x_real.detach().floor()

    probs = compute_bin_probs(x_real - lb, num_quant_bins=num_quant_bins)
    logits = safe_log(probs)

    arange_min = 1 - num_quant_bins // 2
    arange_max = 1 + num_quant_bins // 2
    # Build the offsets on the same device as the input; a default-device
    # tensor would make the addition below fail for inputs living elsewhere.
    q = torch.arange(arange_min, arange_max, device=lb.device)

    x = lb.unsqueeze(-1) + q
    # Reflect candidates that fall below ``min`` about ``min - 1/2`` and those
    # above ``max`` about ``max + 1/2``.
    # NOTE(review): ``_unsqueeze`` presumably adds a trailing dim to tensor
    # bounds so they broadcast against ``x`` -- confirm against its definition.
    x = torch.max(x, 2 * _unsqueeze(min) - 1 - x)
    x = torch.min(x, 2 * _unsqueeze(max) + 1 - x)

    return x, logits
コード例 #4
0
ファイル: coalescent.py プロジェクト: zeta1999/pyro
    def log_prob(self, value):
        """
        Evaluate the coalescent likelihood following equations 7-8 of [3].

        Time complexity is ``O(T + S N log(N))``, with ``T`` the number of
        time steps, ``N`` the number of leaves, and ``S =
        sample_shape.numel()`` the number of samples in ``value``.

        The result is differentiable with respect to ``rate_grid`` only; it
        is not differentiable with respect to ``leaf_times`` or to
        ``value = coal_times``.

        :param torch.Tensor value: A tensor of coalescent times. These denote
            sets of size ``leaf_times.size(-1) - 1`` along the trailing
            dimension and should be sorted along that dimension.
        :returns: Likelihood ``p(coal_times | leaf_times, rate_grid)``
        :rtype: torch.Tensor
        """
        if self._validate_args:
            self._validate_sample(value)
        phylogeny = _make_phylogeny(self.leaf_times, value)

        # Survival factors over closed intervals: cumulative rate with a
        # leading zero, read off at the event times.
        padded_cumsum = torch.nn.functional.pad(
            self.rate_grid.cumsum(-1), (1, 0), value=0)
        # The final lonely leaf is ignored.
        at_events = _interpolate_gather(padded_cumsum, phylogeny.times[..., 1:])
        interval_mass = at_events[..., :-1] - at_events[..., 1:]
        # Clamp to the smallest positive normal float to avoid nan.
        smallest = torch.finfo(interval_mass.dtype).tiny
        interval_mass = interval_mass.clamp(min=smallest)
        survival = -(phylogeny.binomial[..., 1:-1] * interval_mass).sum(-1)

        # Density of the coalescent events themselves.
        grid_index = value.floor().clamp(min=0, max=self.duration - 1).long()
        event_rates = phylogeny.coal_binomial * _gather(
            self.rate_grid, -1, grid_index)
        total = survival + safe_log(event_rates).sum(-1)

        target_shape = broadcast_shape(self.batch_shape, value.shape[:-1])
        return total.expand(target_shape)
コード例 #5
0
ファイル: torch_log.py プロジェクト: pyro-ppl/pyro
def einsum(equation, *operands):
    """
    Einsum evaluated in log space via the log-sum-exp trick.

    Each operand is shifted by its detached maximum over the dimensions that
    are summed out, exponentiated, contracted with ``torch.einsum``, and the
    shifts are added back onto the log of the contraction.

    :param str equation: An einsum equation of the form ``"inputs->output"``.
    :param operands: Tensors of log-space values, one per input term.
    :returns: The log-space result of the contraction.
    """
    # Rename symbols to support PyTorch 0.4.1 and earlier, which allow only
    # the symbols a-z.
    alphabet = "abcdefghijklmnopqrstuvwxyz"
    used = sorted(set(equation) - set(",->"))
    mapping = dict(zip(used, alphabet))
    equation = "".join(mapping.get(ch, ch) for ch in equation)

    lhs, output = equation.split("->")
    if lhs == output:
        # Identity equation: hand back a new object over the same data.
        return operands[0][...]

    shifts = []
    exp_operands = []
    for dims, operand in zip(lhs.split(","), operands):
        # Maximum over every dimension that does not survive into the output.
        shift = operand.detach()
        for axis, dim in enumerate(dims):
            if dim in output:
                continue
            shift = shift.max(axis, keepdim=True)[0]
        # A -inf shift would give nan from (-inf) - (-inf); clamp it finite.
        lowest = torch.finfo(shift.dtype).min
        shift = shift.clamp(min=lowest)
        exp_operands.append((operand - shift).exp())

        # Drop the reduced dimensions, then line the rest up with the output.
        kept = [dim for dim in dims if dim in output]
        kept_shape = torch.Size(
            size for size, dim in zip(operand.shape, dims) if dim in output
        )
        shift = shift.reshape(kept_shape)
        if shift.dim():
            absent = [dim for dim in output if dim not in kept]
            shift = shift.reshape((1,) * (len(output) - shift.dim()) + shift.shape)
            layout = absent + kept
            shift = shift.permute(*(layout.index(dim) for dim in output))
        shifts.append(shift)

    result = safe_log(torch.einsum(equation, exp_operands))
    return sum(shifts + [result])