def grad_cdf_value_scale(value, scale, c, dim):
    """Gradients of the radius CDF with respect to ``value`` and ``scale``.

    NOTE(review): this appears to implement a closed-form alternating-series
    expansion of the hyperbolic-radius CDF gradient (Poincare-VAE style) --
    confirm against the reference derivation.

    Args:
        value: evaluation points (tensor).
        scale: scale parameter of the distribution (tensor).
        c: curvature-like constant (tensor; ``c.sqrt()`` is taken).
        dim: dimensionality; converted to ``int`` then to a double tensor.

    Returns:
        (grad_value, grad_scale): ``grad_value`` is
        ``exp(log_unnormalised_prob - logZ)`` (the normalized density, i.e.
        d(cdf)/d(value)); ``grad_scale`` is ``cdf * d(log cdf)/d(scale)``,
        i.e. d(cdf)/d(scale) via the logarithmic derivative.
    """
    device = value.device
    # The series is evaluated in double precision; dim becomes a 0-dim double tensor.
    dim = torch.tensor(int(dim)).to(device).double()
    # Alternating signs (+1, -1, +1, ...): repeat the pair enough times to
    # cover dim entries, truncate to dim, then broadcast to value's shape.
    signs = torch.tensor([1., -1.]).double().to(device).repeat(((int(dim) + 1) // 2) * 2)[:int(dim)]
    signs = rexpand(signs, *value.size())
    # Series index k = 0 .. dim-1, broadcast to value's shape.
    k_float = rexpand(torch.arange(dim), *value.size()).double().to(device)
    # First part of the per-term factor: (dim-1-2k)^2 * c * scale times a
    # difference/sum of error functions.
    log_arg1 = (dim - 1 - 2 * k_float).pow(2) * c * scale * \
        (\
        torch.erf((value - (dim - 1 - 2 * k_float) * c.sqrt() * scale.pow(2)) / scale / math.sqrt(2)) \
        + torch.erf((dim - 1 - 2 * k_float) * c.sqrt() * scale / math.sqrt(2)) \
        )
    # Second part: Gaussian-tail corrections, sqrt(2/pi) * (...).
    log_arg2 = math.sqrt(2 / math.pi) * ( \
        (dim - 1 - 2 * k_float) * c.sqrt() * torch.exp(-(dim - 1 - 2 * k_float).pow(2) * c * scale.pow(2) / 2) \
        - ((value / scale.pow(2) + (dim - 1 - 2 * k_float) * c.sqrt()) * torch.exp(-(value - (dim - 1 - 2 * k_float) * c.sqrt() * scale.pow(2)).pow(2) / (2 * scale.pow(2)))) \
        )
    log_arg = log_arg1 + log_arg2
    # The per-term factor can be negative; split it into sign and magnitude so
    # its log is taken on a non-negative argument (sign_log_arg * log_arg == |log_arg|).
    sign_log_arg = torch.sign(log_arg)
    # Per-term log-magnitude: the three lgammas form log C(dim-1, k) (the
    # binomial coefficient), plus a Gaussian exponent and log|log_arg|.
    s = torch.lgamma(dim) - torch.lgamma(k_float + 1) - torch.lgamma(dim - k_float) \
        + (dim - 1 - 2 * k_float).pow(2) * c * scale.pow(2) / 2 \
        + torch.log(sign_log_arg * log_arg)
    # NOTE(review): log_grad_sum_sigma is computed but never used below.
    log_grad_sum_sigma = log_sum_exp_signs(s, signs * sign_log_arg, dim=0)
    # Same binomial/exponent prefactor, but with only the erf sum inside the
    # log: this is the (log of the) series for the CDF itself.
    s1 = torch.lgamma(dim) - torch.lgamma(k_float + 1) - torch.lgamma(dim - k_float) \
        + (dim - 1 - 2 * k_float).pow(2) * c * scale.pow(2) / 2 \
        + torch.log( \
            torch.erf((value - (dim - 1 - 2 * k_float) * c.sqrt() * scale.pow(2)) / scale / math.sqrt(2)) \
            + torch.erf((dim - 1 - 2 * k_float) * c.sqrt() * scale / math.sqrt(2)) \
        )
    S1 = log_sum_exp_signs(s1, signs, dim=0)
    # Ratio of the two signed series, summed over the series axis:
    # contribution of the unnormalized CDF to d(log cdf)/d(scale).
    grad_sum_sigma = torch.sum(signs * sign_log_arg * torch.exp(s - S1), dim=0)
    grad_log_cdf_scale = grad_sum_sigma
    # Unnormalized log-density: Gaussian term + (dim-1)*log sinh(sqrt(c) r)
    # - (dim-1)/2 * log c.
    log_unormalised_prob = -value.pow(2) / (2 * scale.pow(2)) + (dim - 1) * logsinh(c.sqrt() * value) - (dim - 1) / 2 * c.log()
    # Differentiate the closed-form log-normalizer w.r.t. scale via autograd
    # (enable_grad in case we are inside a no_grad context).
    with torch.autograd.enable_grad():
        scale = scale.float()
        logZ = _log_normalizer_closed_grad.apply(scale, c, dim)
        grad_logZ_scale = grad(logZ, scale, grad_outputs=torch.ones_like(scale))
    # Assemble d(log cdf)/d(scale): normalizer term + 1/scale + series term.
    grad_log_cdf_scale = -grad_logZ_scale[0] + 1 / scale + grad_log_cdf_scale.float()
    cdf = cdf_r(value.double(), scale.double(), c.double(), int(dim)).float().squeeze(0)
    # d(cdf)/d(scale) = cdf * d(log cdf)/d(scale).
    grad_scale = cdf * grad_log_cdf_scale
    # d(cdf)/d(value) is the normalized density exp(log p - log Z).
    grad_value = (log_unormalised_prob.float() - logZ).exp()
    return grad_value, grad_scale
def log_prob(self, value):
    """Log-density at ``value``.

    Combines a Gaussian quadratic term, a ``(dim - 1) * log sinh(sqrt(c) * r)``
    volume term, a curvature correction, and the precomputed log-normalizer.
    Asserts the result contains no NaNs.
    """
    quadratic = value.pow(2) / (2 * self.scale.pow(2))
    sinh_term = (self.dim - 1) * logsinh(self.c.sqrt() * value)
    curvature = (self.dim - 1) / 2 * self.c.log()
    res = -quadratic + sinh_term - curvature - self.log_normalizer
    assert not torch.isnan(res).any()
    return res
def log_x_div_sinh(x, c):
    """Stable evaluation of ``log(c.sqrt() * x / torch.sinh(c.sqrt() * x))``.

    (The previous docstring claimed this computed ``sinh(...).log()``; it is in
    fact the log of the ratio, as the function name indicates.)

    Computed as ``log(sqrt(c)) + log(x) - logsinh(sqrt(c) * x)``, delegating
    the stable log-sinh to ``logsinh``. Entries with ``x == 0`` — where the
    direct formula produces NaN (``-inf - -inf``) — are overwritten with the
    analytic limit ``log(1) = 0``.

    Args:
        x: non-negative tensor of evaluation points.
        c: curvature-like constant tensor (``c.sqrt()`` is taken).

    Returns:
        Tensor of the same shape as ``x``.
    """
    res = c.sqrt().log() + x.log() - logsinh(c.sqrt() * x)
    zero_value_idx = x == 0.
    res[zero_value_idx] = 0.
    return res