def get_means_variances(self):
    """ Return means and variances.

    Returns:
        (torch.tensor): Means (1, K).
        (torch.tensor): Variances (1, 1, K).

    """
    K = self.K
    dtype = torch.float64
    pi = torch.tensor(math.pi, dtype=dtype, device=self.dev)
    # Laguerre polynomial for n=1/2
    Laguerre = lambda x: torch.exp(x / 2) * \
        ((1 - x) * besseli(0, -x / 2) - x * besseli(1, -x / 2))
    # Compute means and variances
    mean = torch.zeros((1, K), dtype=dtype, device=self.dev)
    var = torch.zeros((1, 1, K), dtype=dtype, device=self.dev)
    for k in range(K):
        nu_k = self.nu[k]
        sig_k = self.sig[k]
        x = -nu_k ** 2 / (2 * sig_k ** 2)
        x = x.flatten()
        if x > -20:
            # Exact Rician moments via the Laguerre polynomial
            mean[:, k] = torch.sqrt(pi * sig_k ** 2 / 2) * Laguerre(x)
            var[:, :, k] = 2 * sig_k ** 2 + nu_k ** 2 \
                - (pi * sig_k ** 2 / 2) * Laguerre(x) ** 2
        else:
            # High-SNR limit: the Rician tends to a Gaussian with
            # mean nu and variance sig^2
            mean[:, k] = nu_k
            var[:, :, k] = sig_k ** 2
    return mean, var
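# --- Hedged standalone sketch (not part of the class above): sanity-check
# the closed-form Rician moments against Monte Carlo draws.
# `torch.special.i0/i1` stand in for the repo's `besseli`, and
# `rice_moments` is a hypothetical helper name.
import math
import torch

def rice_moments(nu, sig):
    """Mean/variance of Rice(nu, sig) via the Laguerre polynomial L_{1/2}."""
    x = torch.tensor(-nu ** 2 / (2 * sig ** 2), dtype=torch.float64)
    L = torch.exp(x / 2) * ((1 - x) * torch.special.i0(-x / 2)
                            - x * torch.special.i1(-x / 2))
    mean = math.sqrt(math.pi * sig ** 2 / 2) * L
    var = 2 * sig ** 2 + nu ** 2 - (math.pi * sig ** 2 / 2) * L ** 2
    return mean, var

nu, sig = 3.0, 1.0
mean, var = rice_moments(nu, sig)
# A Rician sample is the modulus of a complex Gaussian centred at nu
n1, n2 = torch.randn(2, 1_000_000, dtype=torch.float64)
samp = ((nu + sig * n1) ** 2 + (sig * n2) ** 2).sqrt()
print(float(mean), float(samp.mean()))  # should agree closely
print(float(var), float(samp.var()))    # should agree closely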
def _update(self, ss0, ss1, ss2):
    """ Update RMM parameters.

    Args:
        ss0 (torch.tensor): 0th moment (K).
        ss1 (torch.tensor): 1st moment (C, K).
        ss2 (torch.tensor): 2nd moment (C, C, K).

    See also:
        Koay, C.G. and Basser, P.J., Analytically exact correction
        scheme for signal extraction from noisy magnitude MR signals,
        Journal of Magnetic Resonance, Volume 179, Issue 2,
        p. 317-322 (2006)

    """
    K = ss1.shape[1]
    dtype = torch.float64
    # Compute means and variances from the sufficient statistics
    mu1 = torch.zeros(K, dtype=dtype, device=self.dev)
    mu2 = torch.zeros(K, dtype=dtype, device=self.dev)
    for k in range(K):
        # Update mean
        mu1[k] = ss1[:, k] / ss0[k]
        # Update (regularised) variance
        mu2[k] = (ss2[:, :, k] - ss1[:, k] * ss1[:, k] / ss0[k]
                  + self.lam * 1e-3) / (ss0[k] + 1e-3)
    # Update parameters (using means and variances)
    for k in range(K):
        r = mu1[k] / mu2[k].sqrt()
        theta = math.sqrt(math.pi / (4 - math.pi))
        theta = torch.as_tensor(theta, dtype=dtype, device=self.dev).flatten()
        if r > theta:
            # Fixed-point iteration of Koay and Basser (2006):
            # theta <- sqrt(xi(theta) * (1 + r^2) - 2).
            # Note that theta2 must be recomputed each iteration,
            # otherwise xi and g never change and the loop stalls
            # after a single update.
            for i in range(256):
                theta2 = theta * theta
                xi = besseli(0, theta2 / 4) * (2 + theta2) + \
                     besseli(1, theta2 / 4) * theta2
                xi = xi.square() * (math.pi / 8) * torch.exp(-theta2 / 2)
                xi = (2 + theta2) - xi
                g = (xi * (1 + r ** 2) - 2).sqrt()
                if torch.abs(theta - g) < 1e-6:
                    break
                theta = g
            if not torch.isfinite(xi):
                xi.fill_(1)
            self.sig[k] = mu2[k].sqrt() / xi.sqrt()
            self.nu[k] = (mu1[k].square() + mu2[k] * (xi - 2) / xi).sqrt()
        else:
            # SNR below the lower bound sqrt(pi/(4-pi)):
            # treat as a noise-only component
            self.nu[k] = 0
            self.sig[k] = 0.5 * math.sqrt(2) * (mu1[k].square() + mu2[k]).sqrt()
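# --- Hedged standalone sketch of the Koay & Basser (2006) fixed point used
# in `_update` above: given sample mean m and standard deviation s of
# magnitude data, iterate theta <- sqrt(xi(theta) * (1 + m^2/s^2) - 2) and
# map back to (nu, sigma). `torch.special.i0/i1` replace the repo's
# `besseli`; `koay_correction` is a hypothetical name.
import math
import torch

def koay_correction(m, s, max_iter=256, tol=1e-6):
    r = m / s
    theta = math.sqrt(math.pi / (4 - math.pi))  # SNR lower bound (nu = 0)
    if r <= theta:
        return 0.0, math.sqrt((m ** 2 + s ** 2) / 2)  # noise-only solution
    for _ in range(max_iter):
        t2 = theta ** 2
        b = torch.tensor(t2 / 4, dtype=torch.float64)
        bes = float((2 + t2) * torch.special.i0(b) + t2 * torch.special.i1(b))
        xi = (2 + t2) - math.pi / 8 * math.exp(-t2 / 2) * bes ** 2
        g = math.sqrt(xi * (1 + r ** 2) - 2)
        if abs(theta - g) < tol:
            break
        theta = g
    sig = s / math.sqrt(xi)                          # sigma^2 = s^2 / xi
    nu = math.sqrt(m ** 2 + s ** 2 * (xi - 2) / xi)  # nu^2 = m^2 + s^2 (xi-2)/xi
    return nu, sig

print(koay_correction(3.16, 1.0))  # roughly recovers nu ~ 3, sigma ~ 1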
def nll_chi(dat, fit, msk, lam, df, return_grad=True, out=None):
    """Negative log-likelihood of the noncentral Chi distribution

    Parameters
    ----------
    dat : tensor
        Observed data
    fit : tensor
        Signal fit
    msk : tensor
        Mask of observed values
    lam : float
        Noise precision
    df : float
        Degrees of freedom
    return_grad : bool
        Return gradient on top of nll
    out : tensor, optional
        Output buffer for the gradient

    Returns
    -------
    nll : () tensor
        Negative log-likelihood
    grad : tensor, if `return_grad`
        Gradient

    """
    fitm = fit[msk]
    datm = dat[msk]

    # components of the log-likelihood
    sumlogfit = fitm.clamp_min(1e-32).log_().sum(dtype=torch.double)
    sumfit2 = ssq(fitm)
    sumlogdat = datm.clamp_min(1e-32).log_().sum(dtype=torch.double)
    sumdat2 = ssq(datm)

    # reweighting: xi = I_{df/2}(z) / I_{df/2-1}(z)
    # (`math` here is the package's math module, which provides the
    # Bessel helpers, not the standard library)
    z = (fitm * datm).mul_(lam).clamp_min_(1e-32)
    xi = math.besseli_ratio(df / 2 - 1, z)
    logbes = math.besseli(df / 2 - 1, z, 'log')
    logbes = logbes.sum(dtype=torch.double)

    # sum parts
    crit = (df / 2 - 1) * sumlogfit - (df / 2) * sumlogdat - logbes
    crit += 0.5 * lam * (sumfit2 + sumdat2)
    if not return_grad:
        return crit

    # compute gradient: lam * (fit - xi * dat) within the mask
    grad = out.zero_() if out is not None else torch.zeros_like(dat)
    grad[msk] = datm.mul_(xi).neg_().add_(fitm).mul_(lam)
    return crit, grad
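# --- Hedged self-contained check for df = 2 (the Rician case), where the
# Bessel order is 0 and torch.special provides stable scaled Bessels:
# log I0(z) = log i0e(z) + z. This replaces the repo's `ssq`, `besseli` and
# `besseli_ratio` helpers with plain torch calls; `nll_rice` is hypothetical.
import torch

def nll_rice(dat, fit, msk, lam):
    fitm, datm = fit[msk], dat[msk]
    z = (fitm * datm * lam).clamp_min(1e-32)
    logbes = (torch.special.i0e(z).log() + z).sum(dtype=torch.double)
    # df = 2 makes the (df/2 - 1) * sum(log fit) term vanish
    crit = -datm.clamp_min(1e-32).log().sum(dtype=torch.double) - logbes
    crit += 0.5 * lam * (fitm.square().sum(dtype=torch.double)
                         + datm.square().sum(dtype=torch.double))
    return crit

dat = torch.rand(10, dtype=torch.double) + 0.5
fit = torch.rand(10, dtype=torch.double) + 0.5
msk = torch.ones(10, dtype=torch.bool)
print(nll_rice(dat, fit, msk, lam=1.0))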
def nll_chi(dat, fit, msk, lam, df, return_residuals=True):
    """Negative log-likelihood of the noncentral Chi distribution

    Parameters
    ----------
    dat : tensor
        Observed data (should be zero where not observed)
    fit : tensor
        Signal fit (should be zero where not observed)
    msk : tensor
        Mask of observed values
    lam : float
        Noise precision
    df : float
        Degrees of freedom
    return_residuals : bool
        Return residuals (gradient) on top of nll

    Returns
    -------
    nll : () tensor
        Negative log-likelihood
    res : tensor, if `return_residuals`
        Residuals

    """
    z = (dat * fit * lam).clamp_min_(1e-32)
    xi = besseli_ratio(df / 2 - 1, z)
    logbes = besseli(df / 2 - 1, z, 'log')
    logbes = logbes[msk].sum(dtype=torch.double)

    # chi log-likelihood
    fitm = fit[msk]
    sumlogfit = fitm.clamp_min(1e-32).log_().sum(dtype=torch.double)
    sumfit2 = fitm.flatten().dot(fitm.flatten())
    del fitm
    datm = dat[msk]
    sumlogdat = datm.clamp_min(1e-32).log_().sum(dtype=torch.double)
    sumdat2 = datm.flatten().dot(datm.flatten())
    del datm

    crit = (df / 2 - 1) * sumlogfit - (df / 2) * sumlogdat - logbes
    crit += 0.5 * lam * (sumfit2 + sumdat2)
    if not return_residuals:
        return crit

    # residuals fit - xi * dat (note: overwrites `dat` in place)
    res = dat.mul_(xi).neg_().add_(fit)
    return crit, res
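# --- Hedged autograd check of the residual formula res = fit - xi * dat
# (so grad = lam * res), using the Rician `nll_rice` sketch above and
# xi = I1(z)/I0(z) = i1e(z)/i0e(z). Assumes torch.special.i0e is
# differentiable, as in recent PyTorch releases.
import torch

torch.manual_seed(0)
dat = torch.rand(5, dtype=torch.double) + 0.5
fit = (torch.rand(5, dtype=torch.double) + 0.5).requires_grad_()
msk = torch.ones(5, dtype=torch.bool)
lam = 2.0

nll_rice(dat, fit, msk, lam).backward()
z = (fit * dat * lam).detach()
xi = torch.special.i1e(z) / torch.special.i0e(z)
manual = lam * (fit.detach() - xi * dat)
print(torch.allclose(fit.grad, manual))  # expected: True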
def _log_likelihood(self, X, k=0, c=None):
    """ Log-probability density function (pdf) of the Rician
    distribution, evaluated at the values in X.

    Args:
        X (torch.tensor): Observed data (N, C).
        k (int, optional): Index of mixture component. Defaults to 0.
        c (int, optional): Unused.

    Returns:
        log_pdf (torch.tensor): (N, 1).

    See also:
        https://en.wikipedia.org/wiki/Rice_distribution#Characterization

    """
    backend = dict(dtype=X.dtype, device=X.device)
    tiny = 1e-32
    # Get Rice parameters
    nu = self.nu[k].to(**backend)
    sig2 = self.sig[k].to(**backend).square()
    # log p(x) = log x - log sig^2 - (x^2 + nu^2) / (2 sig^2)
    #            + log I0(x nu / sig^2)
    log_pdf = (X + tiny).log() - sig2.log() \
        - (X.square() + nu.square()) / (2 * sig2)
    log_pdf = log_pdf + besseli(0, X * (nu / sig2), 'log')
    return log_pdf.flatten()
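# --- Hedged standalone version of the Rician log-pdf above, using
# torch.special.i0e for a stable log I0 (log I0(z) = log i0e(z) + z) and,
# if SciPy is available, scipy.stats.rice (shape b = nu/sig, scale = sig)
# as an external reference. `rice_logpdf` is a hypothetical name.
import math
import torch

def rice_logpdf(x, nu, sig):
    sig2 = sig ** 2
    z = x * nu / sig2
    return (x.clamp_min(1e-32).log() - math.log(sig2)
            - (x.square() + nu ** 2) / (2 * sig2)
            + torch.special.i0e(z).log() + z)

x = torch.linspace(0.1, 8, 5, dtype=torch.float64)
nu, sig = 2.0, 1.0
print(rice_logpdf(x, nu, sig).exp())
from scipy.stats import rice
print(rice.pdf(x.numpy(), b=nu / sig, scale=sig))  # should match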