Exemplo n.º 1
0
    def backward(ctx, grad_output):
        """Back-propagate through the closed-form log-normalizer w.r.t. ``scale``.

        Assumes ``ctx`` was populated by the matching ``forward`` with float64
        ``scale`` and ``c``, the integer-like ``dim``, and ``log_sum_term``,
        i.e. log sum_k (-1)^k C(d-1, k) exp(a_k^2 s^2 / 2)(1 + erf(a_k s / sqrt 2))
        with a_k = (d - 1 - 2k) sqrt(c) (TODO confirm against forward).

        The gradient is 1 / scale + (d/dscale of that sum) / exp(log_sum_term).
        The derivative of the alternating sum is evaluated in log space with
        ``log_sum_exp_signs``, carrying each term's sign separately.

        Returns:
            ``(grad_scale, None, None)``. No gradients are produced for ``c``
            or ``dim``.
        """
        grad_input = grad_output.clone()

        device = grad_input.device
        scale = ctx.scale
        c = ctx.c
        dim = torch.tensor(int(ctx.dim)).to(device).double()
        n_terms = int(dim)

        # k = 0 .. d-1 along dim 0 (the axis reduced by log_sum_exp_signs).
        k_float = rexpand(torch.arange(n_terms), *scale.size()).double().to(device)
        # Alternating binomial signs +1, -1, +1, ... truncated to d entries.
        alternating = torch.tensor([1., -1.]).double().to(device).repeat(
            ((n_terms + 1) // 2) * 2)[:n_terms]
        alternating = rexpand(alternating, *scale.size())

        a = dim - 1 - 2 * k_float
        half_sq = a.pow(2) * c * scale.pow(2) / 2
        # d/dscale [exp(half_sq) * (1 + erf(a sqrt(c) scale / sqrt 2))] with the
        # exp(half_sq) factor divided out; it is added back in log space below.
        log_arg = a.pow(2) * c * scale * (1 + torch.erf(a * c.sqrt() * scale / math.sqrt(2))) + \
            torch.exp(-half_sq) * 2 / math.sqrt(math.pi) * a * c.sqrt() / math.sqrt(2)
        log_arg_sign = torch.sign(log_arg)

        log_binom = torch.lgamma(dim) - torch.lgamma(k_float + 1) - torch.lgamma(dim - k_float)
        # log|term_k|; the sign of each term travels separately.
        s = log_binom + half_sq + torch.log(log_arg_sign * log_arg)
        log_dsum = log_sum_exp_signs(s, log_arg_sign * alternating, dim=0)

        grad_scale = 1 / ctx.scale + torch.exp(log_dsum - ctx.log_sum_term)

        # Chain rule, then fold any leading broadcast axes back onto grad_input's shape.
        grad_scale = (grad_input * grad_scale.float()).view(-1, *grad_input.shape).sum(0)
        return grad_scale, None, None
Exemplo n.º 2
0
def grad_cdf_value_scale(value, scale, c, dim):
    """Return the gradients of the radius CDF with respect to ``value`` and ``scale``.

    With a_k = (d - 1 - 2k) sqrt(c), the CDF computed by ``cdf_r`` is a ratio
    N / D of alternating binomial sums over k = 0 .. d-1, where N's k-th term is
    C(d-1, k) exp(a_k^2 s^2 / 2) [erf((r - a_k s^2) / (s sqrt 2)) + erf(a_k s / sqrt 2)].
    This function evaluates d log N / d scale in log space, then subtracts the
    scale derivative of ``logZ = _log_normalizer_closed_grad.apply(...)`` (via
    autograd) and adds ``1 / scale``. The result is multiplied by the CDF.

    Args:
        value: radius values at which the CDF is differentiated.
        scale: scale parameter. NOTE(review): ``grad(logZ, scale)`` requires
            ``scale.float()`` to be part of an autograd graph; this presumably
            holds at the call site. Confirm.
        c: curvature tensor.
        dim: dimension (converted with ``int``).

    Returns:
        ``(grad_value, grad_scale)``:

        * ``grad_value`` is ``exp(log_unormalised_prob - logZ)``, presumably the
          density at ``value``.
        * ``grad_scale`` is ``cdf * d log cdf / d scale``.
    """
    device = value.device

    dim = torch.tensor(int(dim)).to(device).double()
    n_terms = int(dim)

    signs = torch.tensor([1., -1.]).double().to(device).repeat(
        ((n_terms + 1) // 2) * 2)[:n_terms]
    signs = rexpand(signs, *value.size())
    # Integer arange (as in the sibling functions); values equal the old float arange.
    k_float = rexpand(torch.arange(n_terms), *value.size()).double().to(device)

    # Sub-expressions shared by every term below. Evaluation order matches the
    # original inline expressions, so results are bit-identical.
    a = dim - 1 - 2 * k_float
    a_sqrt_c = a * c.sqrt()
    a2c = a.pow(2) * c
    shift = a2c * scale.pow(2) / 2
    log_common = torch.lgamma(dim) - torch.lgamma(k_float + 1) - torch.lgamma(dim - k_float) \
        + shift
    resid = value - a_sqrt_c * scale.pow(2)
    erf_sum = torch.erf(resid / scale / math.sqrt(2)) \
        + torch.erf(a_sqrt_c * scale / math.sqrt(2))

    # d/dscale of the k-th numerator bracket, with exp(shift) divided out.
    log_arg1 = a2c * scale * erf_sum
    log_arg2 = math.sqrt(2 / math.pi) * (
        a_sqrt_c * torch.exp(-shift)
        - ((value / scale.pow(2) + a_sqrt_c) * torch.exp(-resid.pow(2) / (2 * scale.pow(2))))
    )
    log_arg = log_arg1 + log_arg2
    sign_log_arg = torch.sign(log_arg)

    # log|d term_k / d scale| and log(term_k) of the CDF numerator.
    s = log_common + torch.log(sign_log_arg * log_arg)
    s1 = log_common + torch.log(erf_sum)

    # d log N / d scale = sum_k sign_k * dterm_k / N, evaluated relative to log N.
    S1 = log_sum_exp_signs(s1, signs, dim=0)
    dlog_num_dscale = torch.sum(signs * sign_log_arg * torch.exp(s - S1), dim=0)

    log_unormalised_prob = -value.pow(2) / (2 * scale.pow(2)) + (
        dim - 1) * logsinh(c.sqrt() * value) - (dim - 1) / 2 * c.log()

    with torch.autograd.enable_grad():
        scale = scale.float()
        logZ = _log_normalizer_closed_grad.apply(scale, c, dim)
        grad_logZ_scale = grad(logZ,
                               scale,
                               grad_outputs=torch.ones_like(scale))

    # d log cdf / d scale = d log N / d scale - (d logZ / d scale - 1 / scale)
    grad_log_cdf_scale = -grad_logZ_scale[0] + 1 / scale + dlog_num_dscale.float()
    cdf = cdf_r(value.double(), scale.double(), c.double(),
                int(dim)).float().squeeze(0)
    grad_scale = cdf * grad_log_cdf_scale

    grad_value = (log_unormalised_prob.float() - logZ).exp()
    return grad_value, grad_scale
Exemplo n.º 3
0
    def variance(self):
        """Return the variance of the radius distribution, ``E[r^2] - mean^2``.

        The second moment is a ratio of two alternating binomial sums over
        k = 0 .. d-1 (sign (-1)^k). Both are computed in float64 and in log
        space with ``log_sum_exp_signs``, which reduces over dim 0 (the k axis).
        Here a_k = (d - 1 - 2k) sqrt(c) and s = scale:

        * ``S2`` = log sum_k (-1)^k C(d-1, k) exp(a_k^2 s^2 / 2)(1 + erf(a_k s / sqrt 2))
        * ``S1`` = log sum_k (-1)^k C(d-1, k) exp(a_k^2 s^2 / 2) s^2
          [(1 + a_k^2 s^2)(1 + erf(a_k s / sqrt 2)) + sqrt(2/pi) a_k s exp(-a_k^2 s^2 / 2)]

        This makes ``exp(S1 - S2)`` the second moment. ``self.mean`` is assumed
        to be the first moment of the same distribution (TODO confirm).

        Returns:
            A float32 tensor.
        """
        c = self.c.double()
        scale = self.scale.double()
        dim = torch.tensor(int(self.dim)).double().to(self.device)
        n_terms = int(dim)

        # Broadcast the signs exactly like k_float so they line up with s1/s2
        # for any rank of ``scale``. The previous unsqueeze(-1).unsqueeze(-1)
        # .expand(...) only worked for a 2-D scale.
        signs = torch.tensor([1., -1.]).double().to(self.device).repeat(
            ((n_terms + 1) // 2) * 2)[:n_terms]
        signs = rexpand(signs, *self.scale.size())
        k_float = rexpand(torch.arange(n_terms),
                          *self.scale.size()).double().to(self.device)

        a = dim - 1 - 2 * k_float
        a_sqrt_c = a * c.sqrt()
        a2c_s2 = a.pow(2) * c * scale.pow(2)
        log_common = torch.lgamma(dim) - torch.lgamma(k_float + 1) - torch.lgamma(dim - k_float) \
            + a2c_s2 / 2
        erf_term = torch.erf(a_sqrt_c * scale / math.sqrt(2))

        # Log-normalizer sum (zeroth moment up to constant factors).
        s2 = log_common + torch.log1p(erf_term)
        S2 = log_sum_exp_signs(s2, signs, dim=0)

        # Second-moment bracket; it may be negative, so keep its sign separately.
        log_arg = (1 + a2c_s2) * (1 + erf_term) + \
            a_sqrt_c * torch.exp(-a2c_s2 / 2) * scale * math.sqrt(2 / math.pi)
        log_arg_signs = torch.sign(log_arg)
        s1 = log_common + 2 * scale.log() + torch.log(log_arg_signs * log_arg)
        S1 = log_sum_exp_signs(s1, signs * log_arg_signs, dim=0)

        output = torch.exp(S1 - S2)
        return output.float() - self.mean.pow(2)
Exemplo n.º 4
0
 def test_sample(self):
     """Check sampled radii against the analytic CDF (Kolmogorov-Smirnov style).

     Builds a two-scale ``HyperbolicRadius`` (scales 0.5 and 1.0), draws
     100000 samples, and asserts that the maximum absolute gap between
     ``self.ecdf`` and ``self.d.cdf`` on a 100-point grid over [0, 3] is
     below 5e-3. NOTE(review): no RNG seed is set, so the test is stochastic.
     """
     n_samples = 100000
     scales = torch.tensor([.5, 1.]).unsqueeze(-1)
     self.d = HyperbolicRadius(self.dim, self.c, scales)
     samples = self.d.sample(torch.Size([n_samples]))
     # The value is unused. Presumably this only checks that log_prob runs on fresh samples.
     _logp = self.d.log_prob(samples)

     grid = torch.linspace(0, 3, steps=100)
     empirical = self.ecdf(samples, grid)
     analytic = self.d.cdf(rexpand(grid, *self.d.scale.size())).squeeze(-1).t()
     ks_stat = (empirical - analytic).abs().max()
     assert ks_stat < 5e-3
Exemplo n.º 5
0
def cdf_r(value, scale, c, dim):
    """Evaluate the CDF of the hyperbolic radius at ``value``.

    All inputs are promoted to float64 first.

    * ``dim == 2``: uses a closed form built from three ``erf`` terms.
      NOTE(review): this branch returns float64, while the general branch
      returns float32. Confirm whether that difference is intended.
    * Otherwise: with a_k = (d - 1 - 2k) sqrt(c), the CDF is the ratio of two
      alternating binomial sums over k = 0 .. d-1. Both are evaluated in log
      space via ``log_sum_exp_signs`` (reduced over dim 0). Entries where
      ``value == 0`` are forced to exactly 0, because the numerator's log
      argument vanishes there. The result is returned as float32.
    """
    value, scale, c = value.double(), scale.double(), c.double()

    if dim == 2:
        sqrt_c = c.sqrt()
        erf_center = torch.erf(sqrt_c * scale / math.sqrt(2))
        erf_minus = torch.erf((value - sqrt_c * scale.pow(2)) / math.sqrt(2) / scale)
        erf_plus = torch.erf((sqrt_c * scale.pow(2) + value) / math.sqrt(2) / scale)
        return 1 / erf_center * .5 * (2 * erf_center + erf_minus - erf_plus)

    device = value.device
    k_float = rexpand(torch.arange(dim), *value.size()).double().to(device)
    dim_t = torch.tensor(dim).to(device).double()
    n_terms = int(dim_t)

    a = dim_t - 1 - 2 * k_float
    a_sqrt_c = a * c.sqrt()
    # log C(d-1, k) + a_k^2 scale^2 / 2, shared by numerator and denominator.
    log_common = torch.lgamma(dim_t) - torch.lgamma(k_float + 1) - torch.lgamma(dim_t - k_float) \
        + a.pow(2) * c * scale.pow(2) / 2
    erf_at_zero = torch.erf(a_sqrt_c * scale / math.sqrt(2))
    erf_at_value = torch.erf((value - a_sqrt_c * scale.pow(2)) / scale / math.sqrt(2))

    log_numerator_terms = log_common + torch.log(erf_at_value + erf_at_zero)
    log_denominator_terms = log_common + torch.log1p(erf_at_zero)

    # Alternating signs +1, -1, ... truncated to d entries, broadcast like k_float.
    signs = rexpand(
        torch.tensor([1., -1.]).double().to(device).repeat(((n_terms + 1) // 2) * 2)[:n_terms],
        *value.size())

    output = torch.exp(log_sum_exp_signs(log_numerator_terms, signs, dim=0)
                       - log_sum_exp_signs(log_denominator_terms, signs, dim=0))
    output[value == 0.] = 0.
    return output.float()
Exemplo n.º 6
0
    def forward(ctx, scale, c, dim):
        """Evaluate the closed-form log-normalizer and stash inputs on ``ctx``.

        Computed in float64 and returned as float32. With a_k = (d - 1 - 2k) sqrt(c)
        and s = scale, the result is

            0.5 (Constants.logpi - Constants.log2) + log s - (d - 1)(log(c) / 2 + Constants.log2)
            + log sum_k (-1)^k C(d-1, k) exp(a_k^2 s^2 / 2)(1 + erf(a_k s / sqrt 2)).

        ``Constants.logpi`` / ``Constants.log2`` are presumably log(pi) and log(2).
        Detached float64 copies of ``scale`` and ``c``, the raw ``dim``, and the
        log of the alternating sum (``ctx.log_sum_term``) are kept on ``ctx`` for
        the backward pass.
        """
        scale = scale.double()
        c = c.double()
        ctx.scale = scale.clone().detach()
        ctx.c = c.clone().detach()
        ctx.dim = dim

        device = scale.device
        log_prefactor = .5 * (Constants.logpi - Constants.log2) + scale.log() - (
            int(dim) - 1) * (c.log() / 2 + Constants.log2)

        dim_t = torch.tensor(int(dim)).to(device).double()
        n_terms = int(dim_t)
        # k = 0 .. d-1 along dim 0, the axis reduced by log_sum_exp_signs.
        k_float = rexpand(torch.arange(n_terms), *scale.size()).double().to(device)

        a = dim_t - 1 - 2 * k_float
        # log of |k-th term|: log C(d-1, k) + a_k^2 s^2 / 2 + log(1 + erf(a_k s / sqrt 2))
        s = torch.lgamma(dim_t) - torch.lgamma(k_float + 1) - torch.lgamma(dim_t - k_float) \
            + a.pow(2) * c * scale.pow(2) / 2 \
            + torch.log1p(torch.erf(a * c.sqrt() * scale / math.sqrt(2)))
        signs = rexpand(
            torch.tensor([1., -1.]).double().to(device).repeat(((n_terms + 1) // 2) * 2)[:n_terms],
            *scale.size())
        ctx.log_sum_term = log_sum_exp_signs(s, signs, dim=0)

        return (log_prefactor + ctx.log_sum_term).float()