def backward(ctx, grad_output):
    """Backward pass producing the gradient with respect to ``scale``.

    Reads ``scale``, ``c``, ``dim`` and ``log_sum_term`` from ``ctx`` and
    combines them with ``grad_output``.

    Returns:
        A 3-tuple ``(grad_scale, None, None)``; only the first slot carries
        a gradient.
    """
    upstream = grad_output.clone()
    dev = upstream.device
    sigma = ctx.scale
    curv = ctx.c
    n = int(ctx.dim)
    dim = torch.tensor(n).to(dev).double()

    # Summation index k = 0 .. n-1, laid out along a new leading axis.
    k = rexpand(torch.arange(n), *sigma.size()).double().to(dev)

    # Alternating +1, -1, ... pattern truncated to n entries.
    alternating = torch.tensor([1., -1.]).double().to(dev).repeat(
        ((n + 1) // 2) * 2)[:n]
    alternating = rexpand(alternating, *sigma.size())

    # (dim - 1 - 2k) appears in every term below; name it once.
    m = dim - 1 - 2 * k
    m_sq = m.pow(2)
    sqrt_c = curv.sqrt()

    erf_part = 1 + torch.erf(m * sqrt_c * sigma / math.sqrt(2))
    gauss_part = torch.exp(-m_sq * curv * sigma.pow(2) / 2)
    log_arg = m_sq * curv * sigma * erf_part + \
        gauss_part * 2 / math.sqrt(math.pi) * m * sqrt_c / math.sqrt(2)

    # log_arg may be negative: take the log of its magnitude and carry the
    # sign separately into the signed log-sum-exp.
    arg_sign = torch.sign(log_arg)
    log_terms = torch.lgamma(dim) - torch.lgamma(k + 1) - torch.lgamma(dim - k) \
        + m_sq * curv * sigma.pow(2) / 2 \
        + torch.log(arg_sign * log_arg)
    log_grad_sum = log_sum_exp_signs(log_terms, arg_sign * alternating, dim=0)

    ratio = torch.exp(log_grad_sum - ctx.log_sum_term)
    d_scale = 1 / ctx.scale + ratio
    # Fold any extra leading axes into one and sum them out so the result
    # has the shape of the incoming gradient.
    d_scale = (upstream * d_scale.float()).view(-1, *upstream.shape).sum(0)
    return (d_scale, None, None)
def grad_cdf_value_scale(value, scale, c, dim):
    """Compute two gradient terms for the CDF at ``value``.

    Args:
        value: tensor of evaluation points.
        scale: scale tensor; it must be differentiable when cast to float,
            because ``grad`` is taken with respect to it below.
        c: curvature tensor (``sqrt`` and ``log`` are taken of it).
        dim: dimension; converted with ``int(dim)``.

    Returns:
        ``(grad_value, grad_scale)`` where ``grad_value`` is
        ``exp(log_unnormalised_prob - logZ)`` and ``grad_scale`` is
        ``cdf * d(log cdf)/d(scale)`` as assembled below.
    """
    device = value.device
    dim = torch.tensor(int(dim)).to(device).double()

    # Alternating +1, -1, ... pattern truncated to ``dim`` entries.
    signs = torch.tensor([1., -1.]).double().to(device).repeat(
        ((int(dim) + 1) // 2) * 2)[:int(dim)]
    signs = rexpand(signs, *value.size())
    k_float = rexpand(torch.arange(dim), *value.size()).double().to(device)

    # Sub-expressions shared by every term of the sum.
    m = dim - 1 - 2 * k_float
    m_sq = m.pow(2)
    sqrt_c = c.sqrt()
    erf_sum = torch.erf((value - m * sqrt_c * scale.pow(2)) / scale / math.sqrt(2)) \
        + torch.erf(m * sqrt_c * scale / math.sqrt(2))

    log_arg1 = m_sq * c * scale * erf_sum
    log_arg2 = math.sqrt(2 / math.pi) * (
        m * sqrt_c * torch.exp(-m_sq * c * scale.pow(2) / 2)
        - ((value / scale.pow(2) + m * sqrt_c)
           * torch.exp(-(value - m * sqrt_c * scale.pow(2)).pow(2) / (2 * scale.pow(2))))
    )
    log_arg = log_arg1 + log_arg2
    # log_arg may be negative: log its magnitude, track the sign separately.
    sign_log_arg = torch.sign(log_arg)

    # Log-binomial coefficient plus the quadratic exponent; common to the
    # numerator terms ``s`` and the denominator terms ``s1``.
    log_base = torch.lgamma(dim) - torch.lgamma(k_float + 1) - torch.lgamma(dim - k_float) \
        + m_sq * c * scale.pow(2) / 2
    s = log_base + torch.log(sign_log_arg * log_arg)
    s1 = log_base + torch.log(erf_sum)

    # The previous code also ran a signed log-sum-exp over ``s`` whose
    # result was never read; only the ratio against S1 is needed.
    S1 = log_sum_exp_signs(s1, signs, dim=0)
    grad_sum_sigma = torch.sum(signs * sign_log_arg * torch.exp(s - S1), dim=0)

    log_unormalised_prob = -value.pow(2) / (2 * scale.pow(2)) \
        + (dim - 1) * logsinh(c.sqrt() * value) - (dim - 1) / 2 * c.log()

    # Autograd is re-enabled only to differentiate the log-normaliser with
    # respect to scale; the rest is plain arithmetic on the results.
    with torch.autograd.enable_grad():
        scale = scale.float()
        logZ = _log_normalizer_closed_grad.apply(scale, c, dim)
        grad_logZ_scale = grad(logZ, scale, grad_outputs=torch.ones_like(scale))

    grad_log_cdf_scale = -grad_logZ_scale[0] + 1 / scale + grad_sum_sigma.float()
    cdf = cdf_r(value.double(), scale.double(), c.double(), int(dim)).float().squeeze(0)
    grad_scale = cdf * grad_log_cdf_scale
    grad_value = (log_unormalised_prob.float() - logZ).exp()
    return grad_value, grad_scale
def variance(self):
    """Return ``exp(S1 - S2) - mean**2`` as a float tensor.

    ``S1`` and ``S2`` are signed log-sum-exp reductions over the index
    ``k = 0 .. dim-1``. The ratio is presumably the second moment of the
    distribution (confirm against the derivation), from which the squared
    mean is subtracted.

    Reads ``self.c``, ``self.scale``, ``self.dim``, ``self.device`` and
    ``self.mean``.
    """
    c = self.c.double()
    scale = self.scale.double()
    n = int(self.dim)
    dim = torch.tensor(n).double().to(self.device)

    # Alternating +1, -1, ... pattern truncated to ``dim`` entries, broadcast
    # with the same helper used for ``k_float`` so that ``scale`` of any rank
    # is supported (previously two trailing axes were hard-coded).
    signs = torch.tensor([1., -1.]).double().to(self.device).repeat(
        ((n + 1) // 2) * 2)[:n]
    signs = rexpand(signs, *self.scale.size())
    k_float = rexpand(torch.arange(self.dim), *self.scale.size()).double().to(self.device)

    # Sub-expressions shared by numerator and denominator.
    m = dim - 1 - 2 * k_float
    m_sq = m.pow(2)
    sqrt_c = c.sqrt()
    erf_term = torch.erf(m * sqrt_c * scale / math.sqrt(2))
    log_base = torch.lgamma(dim) - torch.lgamma(k_float + 1) - torch.lgamma(dim - k_float) \
        + m_sq * c * scale.pow(2) / 2

    # Denominator.
    s2 = log_base + torch.log1p(erf_term)
    S2 = log_sum_exp_signs(s2, signs, dim=0)

    # Numerator; its per-term factor can be negative, so the sign is split
    # off before taking the log.
    log_arg = (1 + m_sq * c * scale.pow(2)) * (1 + erf_term) + \
        m * sqrt_c * torch.exp(-m_sq * c * scale.pow(2) / 2) * scale * math.sqrt(2 / math.pi)
    log_arg_signs = torch.sign(log_arg)
    s1 = log_base \
        + 2 * scale.log() \
        + torch.log(log_arg_signs * log_arg)
    S1 = log_sum_exp_signs(s1, signs * log_arg_signs, dim=0)

    output = torch.exp(S1 - S2)
    output = output.float() - self.mean.pow(2)
    return output
def test_sample(self):
    """Compare the empirical CDF of drawn samples with ``cdf`` on a grid.

    Builds a ``HyperbolicRadius`` with scales 0.5 and 1.0, draws 100000
    samples, and asserts that the largest absolute gap between the
    empirical and analytic CDF over 100 grid points in [0, 3] is below 5e-3.
    """
    n_samples = 100000
    scales = torch.tensor([.5, 1.]).unsqueeze(-1)
    self.d = HyperbolicRadius(self.dim, self.c, scales)

    draws = self.d.sample(torch.Size([n_samples]))
    # log_prob is evaluated on the draws; its result is not inspected here.
    self.d.log_prob(draws)

    # Kolmogorov–Smirnov statistic
    points = torch.linspace(0, 3, steps=100)
    empirical = self.ecdf(draws, points)
    analytic = self.d.cdf(rexpand(points, *self.d.scale.size())).squeeze(-1).t()
    max_gap = (empirical - analytic).abs().max()
    assert max_gap < 5e-3
def cdf_r(value, scale, c, dim):
    """Evaluate the CDF at ``value`` for the given ``scale``, ``c`` and ``dim``.

    All tensor inputs are promoted to double for the computation.

    Args:
        value: tensor of evaluation points.
        scale: scale tensor.
        c: curvature tensor (its square root is taken).
        dim: dimension; ``dim == 2`` uses a direct closed form, anything
            else uses a ratio of signed log-sum-exp reductions.

    Returns:
        For ``dim == 2`` a double tensor; otherwise a float tensor in which
        entries where ``value == 0`` are set to exactly 0.
    """
    value = value.double()
    scale = scale.double()
    c = c.double()
    root_c = c.sqrt()

    if dim == 2:
        erf_a = torch.erf(root_c * scale / math.sqrt(2))
        erf_lo = torch.erf((value - root_c * scale.pow(2)) / math.sqrt(2) / scale)
        erf_hi = torch.erf((root_c * scale.pow(2) + value) / math.sqrt(2) / scale)
        return 1 / erf_a * .5 * (2 * erf_a + erf_lo - erf_hi)

    dev = value.device
    # Summation index k = 0 .. dim-1 along a new leading axis.
    k = rexpand(torch.arange(dim), *value.size()).double().to(dev)
    n = torch.tensor(dim).to(dev).double()

    m = n - 1 - 2 * k
    # Log-binomial coefficient plus quadratic exponent, shared by both sums.
    log_base = torch.lgamma(n) - torch.lgamma(k + 1) - torch.lgamma(n - k) \
        + m.pow(2) * c * scale.pow(2) / 2
    erf_centre = torch.erf(m * root_c * scale / math.sqrt(2))

    log_num = log_base + torch.log(
        torch.erf((value - m * root_c * scale.pow(2)) / scale / math.sqrt(2))
        + erf_centre
    )
    log_den = log_base + torch.log1p(erf_centre)

    # Alternating +1, -1, ... pattern truncated to ``dim`` entries.
    alternating = torch.tensor([1., -1.]).double().to(dev).repeat(
        ((int(n) + 1) // 2) * 2)[:int(n)]
    alternating = rexpand(alternating, *value.size())

    numerator = log_sum_exp_signs(log_num, alternating, dim=0)
    denominator = log_sum_exp_signs(log_den, alternating, dim=0)
    result = torch.exp(numerator - denominator)

    # Force an exact zero wherever the evaluation point is zero.
    result[value == 0.] = 0.
    return result.float()
def forward(ctx, scale, c, dim):
    """Forward pass: closed-form value built from a signed log-sum-exp.

    Stores detached double copies of ``scale`` and ``c``, the raw ``dim``,
    and the reduction result ``log_sum_term`` on ``ctx`` for the backward
    pass.

    Args:
        ctx: autograd context object used as attribute storage.
        scale: scale tensor.
        c: curvature tensor (``log`` and ``sqrt`` are taken of it).
        dim: dimension; converted with ``int(dim)``.

    Returns:
        The result cast to float.
    """
    scale = scale.double()
    c = c.double()

    # Cache inputs for backward.
    ctx.scale = scale.clone().detach()
    ctx.c = c.clone().detach()
    ctx.dim = dim

    dev = scale.device
    n = int(dim)

    # Part that does not depend on the summation index.
    prefactor = .5 * (Constants.logpi - Constants.log2) + scale.log() - (
        n - 1) * (c.log() / 2 + Constants.log2)

    dim_t = torch.tensor(n).to(dev).double()
    k = rexpand(torch.arange(n), *scale.size()).double().to(dev)
    m = dim_t - 1 - 2 * k

    log_terms = torch.lgamma(dim_t) - torch.lgamma(k + 1) - torch.lgamma(dim_t - k) \
        + m.pow(2) * c * scale.pow(2) / 2 \
        + torch.log1p(torch.erf(m * c.sqrt() * scale / math.sqrt(2)))

    # Alternating +1, -1, ... pattern truncated to ``dim`` entries.
    alternating = torch.tensor([1., -1.]).double().to(dev).repeat(
        ((n + 1) // 2) * 2)[:n]
    alternating = rexpand(alternating, *scale.size())

    ctx.log_sum_term = log_sum_exp_signs(log_terms, alternating, dim=0)
    return (prefactor + ctx.log_sum_term).float()