Example #1
    def get_means_variances(self):
        """ Return means and variances.
        Returns:
            (torch.tensor): Means (1, K).
            (torch.tensor): Variances (1, 1, K).
        """
        K = self.K
        dtype = torch.float64
        pi = torch.tensor(math.pi, dtype=dtype, device=self.dev)

        # Laguerre polynomial for n=1/2
        Laguerre = lambda x: torch.exp(x/2) * \
            ((1 - x) * besseli(0, -x/2) - x * besseli(1, -x/2))

        # Compute means and variances
        mean = torch.zeros((1, K), dtype=dtype, device=self.dev)
        var = torch.zeros((1, 1, K), dtype=dtype, device=self.dev)
        for k in range(K):
            nu_k = self.nu[k]
            sig_k = self.sig[k]

            x = -nu_k**2 / (2 * sig_k**2)
            x = x.flatten()
            if x > -20:
                mean[:, k] = torch.sqrt(pi * sig_k**2 / 2) * Laguerre(x)
                var[:, :, k] = 2 * sig_k**2 + nu_k**2 - (pi * sig_k**2 /
                                                         2) * Laguerre(x)**2
            else:
                # High SNR: fall back to the Gaussian limit N(nu, sig^2)
                mean[:, k] = nu_k
                var[:, :, k] = sig_k ** 2

        return mean, var
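
For reference, the loop above evaluates the standard moments of the Rice distribution, expressed through the Laguerre polynomial of order 1/2 (a sketch of the maths, per mixture component with parameters nu and sigma):

    L_{1/2}(x) = e^{x/2} [ (1 - x) I_0(-x/2) - x I_1(-x/2) ]
    E[X]   = \sigma \sqrt{\pi/2} \, L_{1/2}(-\nu^2 / (2\sigma^2))
    Var[X] = 2\sigma^2 + \nu^2 - \frac{\pi \sigma^2}{2} L_{1/2}(-\nu^2 / (2\sigma^2))^2

When nu^2 / (2 sigma^2) is large the Bessel terms overflow, which is presumably why the code falls back to the Gaussian limit for x <= -20.
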
Example #2
    def _update(self, ss0, ss1, ss2):
        """ Update RMM parameters.
        Args:
            ss0 (torch.tensor): 0th moment (K).
            ss1 (torch.tensor): 1st moment (C, K).
            ss2 (torch.tensor): 2nd moment (C, C, K).
        See also:
            Koay, C.G. and Basser, P.J., Analytically exact correction scheme
            for signal extraction from noisy magnitude MR signals,
            Journal of Magnetic Resonance, 179(2), pp. 317-322, 2006.
        """
        K = ss1.shape[1]
        dtype = torch.float64

        # Compute means and variances
        mu1 = torch.zeros(K, dtype=dtype, device=self.dev)
        mu2 = torch.zeros(K, dtype=dtype, device=self.dev)
        for k in range(K):
            # Update mean
            mu1[k] = 1 / ss0[k] * ss1[:, k]

            # Update covariance
            mu2[k] = (ss2[:, :, k] - ss1[:, k] * ss1[:, k] / ss0[k] +
                      self.lam * 1e-3) / (ss0[k] + 1e-3)

        # Update parameters (using means and variances)
        for k in range(K):
            r = mu1[k] / mu2[k].sqrt()
            theta = math.sqrt(math.pi / (4 - math.pi))
            theta = torch.as_tensor(theta, dtype=dtype,
                                    device=self.dev).flatten()
            if r > theta:
                theta2 = theta * theta
                for i in range(256):
                    xi = besseli(0, theta2/4) * (2 + theta2) + \
                         besseli(1, theta2/4) * theta2
                    xi = xi.square() * (math.pi / 8 * math.exp(-theta2 / 2))
                    xi = (2 + theta2) - xi
                    g = (xi * (1 + r**2) - 2).sqrt()
                    if torch.abs(theta - g) < 1e-6:
                        break
                    theta = g
                if not torch.isfinite(xi):
                    xi.fill_(1)
                self.sig[k] = mu2[k].sqrt() / xi.sqrt()
                self.nu[k] = (mu1[k].square() + mu2[k] * (xi - 2) / xi).sqrt()
            else:
                self.nu[k] = 0
                self.sig[k] = 0.5 * math.sqrt(2) * (mu1[k].square() +
                                                    mu2[k]).sqrt()
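
The fixed-point loop above follows the Koay & Basser (2006) scheme cited in the docstring. In symbols, with r = mu1 / sqrt(mu2), the underlying SNR theta is refined by iterating

    \xi(\theta) = 2 + \theta^2 - \frac{\pi}{8} e^{-\theta^2/2}
                  [ (2 + \theta^2) I_0(\theta^2/4) + \theta^2 I_1(\theta^2/4) ]^2
    \theta \leftarrow \sqrt{ \xi(\theta) (1 + r^2) - 2 }

after which the Rician parameters are recovered as sigma^2 = mu2 / xi and nu^2 = mu1^2 + mu2 (xi - 2) / xi, exactly as in the two assignments above. The threshold sqrt(pi / (4 - pi)) is the mean-to-standard-deviation ratio of a pure Rayleigh (nu = 0); below it, nu is clamped to zero.
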
Example #3
def nll_chi(dat, fit, msk, lam, df, return_grad=True, out=None):
    """Negative log-likelihood of the noncentral Chi distribution

    Parameters
    ----------
    dat : tensor
        Observed data
    fit : tensor
        Signal fit
    msk : tensor
        Mask of observed values
    lam : float
        Noise precision
    df : float
        Degrees of freedom
    return_grad : bool
        Return the gradient along with the nll
    out : tensor, optional
        Pre-allocated tensor in which the gradient is written

    Returns
    -------
    nll : () tensor
        Negative log-likelihood
    grad : tensor, if `return_grad`
        Gradient

    """
    fitm = fit[msk]
    datm = dat[msk]

    # components of the log-likelihood
    sumlogfit = fitm.clamp_min(1e-32).log_().sum(dtype=torch.double)
    sumfit2 = ssq(fitm)
    sumlogdat = datm.clamp_min(1e-32).log_().sum(dtype=torch.double)
    sumdat2 = ssq(datm)

    # reweighting
    z = (fitm * datm).mul_(lam).clamp_min_(1e-32)
    xi = math.besseli_ratio(df / 2 - 1, z)
    logbes = math.besseli(df / 2 - 1, z, 'log')
    logbes = logbes.sum(dtype=torch.double)

    # sum parts
    crit = (df / 2 - 1) * sumlogfit - (df / 2) * sumlogdat - logbes
    crit += 0.5 * lam * (sumfit2 + sumdat2)
    if not return_grad:
        return crit

    # compute residuals
    grad = out.zero_() if out is not None else torch.zeros_like(dat)
    grad[msk] = datm.mul_(xi).neg_().add_(fitm).mul_(lam)
    return crit, grad
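
For orientation, writing x = dat, nu = fit, lambda = lam and n = df, the per-voxel term accumulated into crit is (up to terms constant in the fit)

    -\log p(x \mid \nu) = (n/2 - 1) \log \nu - (n/2) \log x
                          + \frac{\lambda}{2} (x^2 + \nu^2)
                          - \log I_{n/2-1}(\lambda x \nu) + \mathrm{const}

Assuming besseli_ratio(n/2 - 1, z) returns I_{n/2}(z) / I_{n/2-1}(z), the derivative of this term with respect to nu simplifies to lambda * (nu - x * xi), which is what grad stores inside the mask.
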
Example #4
def nll_chi(dat, fit, msk, lam, df, return_residuals=True):
    """Negative log-likelihood of the noncentral Chi distribution

    Parameters
    ----------
    dat : tensor
        Observed data (should be zero where not observed)
    fit : tensor
        Signal fit (should be zero where not observed)
    msk : tensor
        Mask of observed values
    lam : float
        Noise precision
    df : float
        Degrees of freedom
    return_residuals : bool
        Return residuals (gradient) on top of nll

    Returns
    -------
    nll : () tensor
        Negative log-likelihood
    res : tensor, if `return_residuals`
        Residuals

    """
    z = (dat * fit * lam).clamp_min_(1e-32)
    xi = besseli_ratio(df / 2 - 1, z)
    logbes = besseli(df / 2 - 1, z, 'log')
    logbes = logbes[msk].sum(dtype=torch.double)

    # chi log-likelihood
    fitm = fit[msk]
    sumlogfit = fitm.clamp_min(1e-32).log_().sum(dtype=torch.double)
    sumfit2 = fitm.flatten().dot(fitm.flatten())
    del fitm
    datm = dat[msk]
    sumlogdat = datm.clamp_min(1e-32).log_().sum(dtype=torch.double)
    sumdat2 = datm.flatten().dot(datm.flatten())
    del datm

    crit = (df / 2 - 1) * sumlogfit - (df / 2) * sumlogdat - logbes
    crit += 0.5 * lam * (sumfit2 + sumdat2)
    if not return_residuals:
        return crit
    res = dat.mul_(xi).neg_().add_(fit)
    return crit, res
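
A minimal usage sketch, assuming nll_chi and the besseli helpers are imported from the same module as above; the tensors are hypothetical placeholders:

    import torch

    dat = torch.rand(8, 8, 8) + 0.1                 # observed magnitudes (hypothetical)
    fit = torch.rand(8, 8, 8) + 0.1                 # current signal fit (hypothetical)
    msk = torch.ones_like(dat, dtype=torch.bool)    # every voxel observed

    # df=2 reduces the noncentral chi to the Rician case. Note that this
    # variant overwrites `dat` in place when forming the residuals, so a
    # clone is passed to preserve the original data.
    crit, res = nll_chi(dat.clone(), fit, msk, lam=1.0, df=2, return_residuals=True)
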
Example #5
    def _log_likelihood(self, X, k=0, c=None):
        """
        Log-probability density function (pdf) of the Rician
        distribution, evaluated at the values in X.
        Args:
            X (torch.tensor): Observed data (N, C).
            k (int, optional): Index of mixture component. Defaults to 0.
            c (optional): Unused by this implementation. Defaults to None.
        Returns:
            log_pdf (torch.tensor): (N, 1).
        See also:
            https://en.wikipedia.org/wiki/Rice_distribution#Characterization
        """
        backend = dict(dtype=X.dtype, device=X.device)
        tiny = 1e-32

        # Get Rice parameters
        nu = self.nu[k].to(**backend)
        sig2 = self.sig[k].to(**backend).square()

        log_pdf = (X + tiny).log() - sig2.log() - (X.square() +
                                                   nu.square()) / (2 * sig2)
        log_pdf = log_pdf + besseli(0, X * (nu / sig2), 'log')
        return log_pdf.flatten()
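
The quantity assembled above is the elementwise log of the Rice pdf linked in the docstring (with sig2 = sigma^2):

    \log p(x \mid \nu, \sigma) = \log x - \log \sigma^2
                                 - \frac{x^2 + \nu^2}{2 \sigma^2}
                                 + \log I_0( x \nu / \sigma^2 )

The tiny constant only guards against log(0) for zero-valued observations.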