示例#1
0
def jackknife2(filein, fileout, angles, low, recon, recon_args, sdn=None):
    """
    plots twice the sum of the leave-one-out errors of reconstructions

    Plots and saves twice the sum of the leave-one-out differences between
    the original image in filein and the reconstruction via recon applied to
    the k-space subsampling specified by angles fed into radialines, while
    including all low frequencies (see the description of low below),
    corrupting the k-space values with independent and identically
    distributed centered complex Gaussian noise whose standard deviation is
    sdn*sqrt(2) (sdn=0 if not provided explicitly). The "one" left out in the
    leave-one-out is a radial "line" at the angles specified by angles.

    The calling sequence of recon must be  (m, n, f, mask, **recon_args),
    where filein contains an m x n image, f is the image in k-space subsampled
    to the mask, mask is the return from calls to radialines (with angles),
    supplemented by all low frequencies, and **recon_args is the unpacking of
    recon_args. The function recon must return a torch.Tensor
    (the reconstruction) and a float (the corresponding loss).

    Four images are saved, named by inserting '_original', '_recon',
    '_error', and '_jackknife' in front of the final '.'-suffix of fileout.
    As a side effect, this function sets global matplotlib.rcParams entries
    that hide ticks, tick labels, and axis spines.

    Parameters
    ----------
    filein : str
        path to the file containing the image to be processed (the path may be
        relative or absolute)
    fileout : str
        path to the file to which the plots will be saved (the path may be
        relative or absolute)
    angles : list of float or 1-D numpy.ndarray of float
        angles of the radial "lines" in the mask that radialines will construct
    low : int
        bandwidth of low frequencies to include in the mask: every entry whose
        row index lies in {0, ..., low-1, m-low, ..., m-1} and whose column
        index lies in {0, ..., low-1, n-low, ..., n-1} is set to True
        (nothing is added when low <= 0)
    recon : function
        returns the reconstructed image
    recon_args : dict
        keyword arguments for recon
    sdn : float, optional
        standard deviation of the noise to add (defaults to 0)

    Returns
    -------
    float
        loss for the reconstruction using all angles
    list of float
        losses for the reconstructions using all angles except for one
    """
    # Set default parameters.
    if sdn is None:
        sdn = 0
    # Read the image from disk.
    with Image.open(filein) as img:
        f_orig = np.array(img).astype(np.float64) / 255.
    m = f_orig.shape[0]
    n = f_orig.shape[1]
    # Fourier transform the image.
    ff_orig = np.fft.fft2(f_orig) / np.sqrt(m * n)
    # Add noise.
    ff_noisy = ff_orig.copy()
    ff_noisy += sdn * (np.random.randn(m, n) + 1j * np.random.randn(m, n))

    def _mask_for(angs):
        """Build the radialines mask for angs, plus all low frequencies."""
        mask_a = radialines.randradialineset(m, n, angs)
        if low > 0:
            # Same entries as setting [km, kn], [m-1-km, kn], [km, n-1-kn],
            # and [m-1-km, n-1-kn] for every 0 <= km, kn < low.
            mask_a[:low, :low] = True
            mask_a[m - low:, :low] = True
            mask_a[:low, n - low:] = True
            mask_a[m - low:, n - low:] = True
        return mask_a

    # Select which frequencies to retain (including all low frequencies).
    mask = _mask_for(angles)
    # Subsample the noisy Fourier transform of the original image.
    f = ctorch.from_numpy(ff_noisy[mask]).cuda()
    # len() works for both a list and a 1-D array of angles (.size does not
    # exist on a list).
    n_angles = len(angles)
    logging.info('computing jackknife2 differences -- all {}'.format(
        n_angles))
    # Perform the reconstruction using the entire mask.
    reconf, lossf = recon(m, n, f, mask, **recon_args)
    reconf = reconf.cpu().numpy()
    # Perform the reconstructions omitting one angle at a time, accumulating
    # the differences from reconf in float64 instead of storing every
    # reconstruction.
    sumloo = np.zeros((m, n))
    loss = []
    for k in range(n_angles):
        # Drop an angle.
        langles = list(angles)
        del langles[k]
        mask1 = _mask_for(langles)
        # Subsample the noisy Fourier transform of the original image.
        f1 = ctorch.from_numpy(ff_noisy[mask1]).cuda()
        # Reconstruct the image from the subsampled data.
        recon1, loss1 = recon(m, n, f1, mask1, **recon_args)
        # Record the results.
        sumloo += recon1.cpu().numpy().astype(np.float64) - reconf
        loss.append(loss1)
    scaled = sumloo * 2

    # Plot errors.
    # Remove the ticks and spines on the axes.
    matplotlib.rcParams['xtick.top'] = False
    matplotlib.rcParams['xtick.bottom'] = False
    matplotlib.rcParams['ytick.left'] = False
    matplotlib.rcParams['ytick.right'] = False
    matplotlib.rcParams['xtick.labeltop'] = False
    matplotlib.rcParams['xtick.labelbottom'] = False
    matplotlib.rcParams['ytick.labelleft'] = False
    matplotlib.rcParams['ytick.labelright'] = False
    matplotlib.rcParams['axes.spines.top'] = False
    matplotlib.rcParams['axes.spines.bottom'] = False
    matplotlib.rcParams['axes.spines.left'] = False
    matplotlib.rcParams['axes.spines.right'] = False
    # Configure the colormaps.
    kwargs01 = dict(cmap='gray',
                    norm=matplotlib.colors.Normalize(vmin=0, vmax=1))
    kwargs11 = dict(cmap='gray',
                    norm=matplotlib.colors.Normalize(vmin=-1, vmax=1))
    # Separate the suffix (filetype) from the rest of the filename.
    suffix = '.' + fileout.split('.')[-1]
    rest = fileout[:-len(suffix)]
    assert fileout == rest + suffix

    def _save(title, image, kwargs, tag):
        """Save image in its own figure as rest + tag + suffix; free it."""
        fig = plt.figure(figsize=(5.5, 5.5))
        plt.title(title)
        plt.imshow(image, **kwargs)
        plt.savefig(rest + tag + suffix, bbox_inches='tight')
        # Close explicitly so repeated calls do not accumulate open figures.
        plt.close(fig)

    # Plot the original.
    _save('Original', np.clip(f_orig, 0, 1), kwargs01, '_original')
    # Plot the reconstruction from the original mask provided.
    _save('Reconstruction', np.clip(reconf, 0, 1), kwargs01, '_recon')
    # Plot the difference from the original.
    _save('Error of Reconstruction', np.clip(reconf - f_orig, -1, 1),
          kwargs11, '_error')
    # Plot twice the sum of the leave-one-out differences.
    _save('Jackknife', np.clip(scaled, -1, 1), kwargs11, '_jackknife')

    return lossf, loss
示例#2
0
def jackknife(filein, fileout, mask, low, recon, recon_args, sdn=None):
    """
    plots twice the sum of the leave-one-out errors of reconstructions

    Plots and saves twice the sum of the leave-one-out differences between
    the original image in filein and the reconstruction via recon applied to
    the k-space subsampling specified by mask (well, assuming mask includes
    all frequencies between -low and low), corrupting the k-space values with
    independent and identically distributed centered complex Gaussian noise
    whose standard deviation is sdn*sqrt(2) (sdn=0 if not provided explicitly).
    The "one" left out in the leave-one-out is a full row of k-space; only
    rows k with low < k < m - low and mask[k] True are left out.

    The calling sequence of recon must be  (m, n, f, mask_th, **recon_args),
    where filein contains an m x n image, f is the image in k-space subsampled
    to the mask, mask_th = torch.from_numpy(mask.astype(np.uint8)).cuda(), and
    **recon_args is the unpacking of recon_args. The function recon must return
    a torch.Tensor (the reconstruction) and a float (the corresponding loss).

    Four images are saved, named by inserting '_original', '_recon',
    '_error', and '_jackknife' in front of the final '.'-suffix of fileout.
    As a side effect, this function sets global matplotlib.rcParams entries
    that hide ticks, tick labels, and axis spines.

    _N.B._: mask[-low+1], mask[-low+2], ..., mask[low-1] must be True.

    Parameters
    ----------
    filein : str
        path to the file containing the image to be processed (the path may be
        relative or absolute)
    fileout : str
        path to the file to which the plots will be saved (the path may be
        relative or absolute)
    mask : ndarray of bool
        indicators of whether to include (True) or exclude (False)
        the corresponding rows in k-space of the image from filein
    low : int
        bandwidth of low frequencies included in mask (between -low to low)
    recon : function
        returns the reconstructed image
    recon_args : dict
        keyword arguments for recon
    sdn : float, optional
        standard deviation of the noise to add (defaults to 0)

    Returns
    -------
    float
        loss for the reconstruction using all mask
    list of float
        losses for the reconstructions using all mask except for one row

    Raises
    ------
    AssertionError
        if any of mask[-low+1], ..., mask[low-1] is False
    """
    # Set default parameters.
    if sdn is None:
        sdn = 0
    # Check that the mask includes all low frequencies.
    for k in range(low):
        assert mask[k]
        assert mask[-k]
    # Read the image from disk.
    with Image.open(filein) as img:
        f_orig = np.array(img).astype(np.float64) / 255.
    m = f_orig.shape[0]
    n = f_orig.shape[1]
    # Fourier transform the image.
    ff_orig = np.fft.fft2(f_orig) / np.sqrt(m * n)
    # Add noise.
    ff_noisy = ff_orig.copy()
    ff_noisy += sdn * (np.random.randn(m, n) + 1j * np.random.randn(m, n))
    # Subsample the noisy Fourier transform of the original image.
    f = ctorch.from_numpy(ff_noisy[mask]).cuda()
    # Index the True values in mask (aside from the low frequencies);
    # make the inequality strict to allow for low = 0.
    trues = [k for k in range(mask.size) if low < k < m - low and mask[k]]
    logging.info('computing jackknife differences -- all {}'.format(
        len(trues)))
    # Perform the reconstruction using the entire mask.
    mask_th = torch.from_numpy(mask.astype(np.uint8)).cuda()
    reconf, lossf = recon(m, n, f, mask_th, **recon_args)
    reconf = reconf.cpu().numpy()
    # Perform the reconstructions omitting one row at a time, accumulating
    # the differences from reconf in float64 instead of storing every
    # reconstruction.
    sumloo = np.zeros((m, n))
    loss = []
    for row in trues:
        # Drop a row.
        mask1 = mask.copy()
        mask1[row] = False
        f1 = ctorch.from_numpy(ff_noisy[mask1]).cuda()
        # Reconstruct the image from the subsampled data.
        mask1_th = torch.from_numpy(mask1.astype(np.uint8)).cuda()
        recon1, loss1 = recon(m, n, f1, mask1_th, **recon_args)
        # Record the results.
        sumloo += recon1.cpu().numpy().astype(np.float64) - reconf
        loss.append(loss1)
    scaled = sumloo * 2

    # Plot errors.
    # Remove the ticks and spines on the axes.
    matplotlib.rcParams['xtick.top'] = False
    matplotlib.rcParams['xtick.bottom'] = False
    matplotlib.rcParams['ytick.left'] = False
    matplotlib.rcParams['ytick.right'] = False
    matplotlib.rcParams['xtick.labeltop'] = False
    matplotlib.rcParams['xtick.labelbottom'] = False
    matplotlib.rcParams['ytick.labelleft'] = False
    matplotlib.rcParams['ytick.labelright'] = False
    matplotlib.rcParams['axes.spines.top'] = False
    matplotlib.rcParams['axes.spines.bottom'] = False
    matplotlib.rcParams['axes.spines.left'] = False
    matplotlib.rcParams['axes.spines.right'] = False
    # Configure the colormaps.
    kwargs01 = dict(cmap='gray',
                    norm=matplotlib.colors.Normalize(vmin=0, vmax=1))
    kwargs11 = dict(cmap='gray',
                    norm=matplotlib.colors.Normalize(vmin=-1, vmax=1))
    # Separate the suffix (filetype) from the rest of the filename.
    suffix = '.' + fileout.split('.')[-1]
    rest = fileout[:-len(suffix)]
    assert fileout == rest + suffix

    def _save(title, image, kwargs, tag):
        """Save image in its own figure as rest + tag + suffix; free it."""
        fig = plt.figure(figsize=(5.5, 5.5))
        plt.title(title)
        plt.imshow(image, **kwargs)
        plt.savefig(rest + tag + suffix, bbox_inches='tight')
        # Close explicitly so repeated calls do not accumulate open figures.
        plt.close(fig)

    # Plot the original.
    _save('Original', f_orig, kwargs01, '_original')
    # Plot the reconstruction from the original mask provided.
    _save('Reconstruction', reconf, kwargs01, '_recon')
    # Plot the difference from the original.
    _save('Error of Reconstruction', reconf - f_orig, kwargs11, '_error')
    # Plot twice the sum of the leave-one-out differences.
    _save('Jackknife', scaled, kwargs11, '_jackknife')

    return lossf, loss
示例#3
0
def cs_fft(m, n, f, mask, mu, beta, n_iter):
    """
    Recovers an image from a subset of its frequencies using FFTs.

    Reconstructs an m x n image from the subset f of its frequencies specified
    by mask, using ADMM with regularization parameter mu, coupling parameter
    beta, and number of iterations n_iter. Unlike function cs_baseline,
    this cs_fft uses FFTs. The computations take place on the CPU(s) in numpy
    when f is a numpy.ndarray and take place on the GPU(s) in ctorch when f is
    a ctorch.ComplexTensor.

    _N.B._: mask[0, 0] must be True to make the optimization well-posed.

    Parameters
    ----------
    m : int
        number of rows in the image being reconstructed
    n : int
        number of columns in the image being reconstructed
    f : numpy.ndarray or ctorch.ComplexTensor
        potentially nonzero entries (prior to the inverse Fourier transform)
    mask : numpy.ndarray
        boolean indicators of the potential nonzeros in the full m x n array
        -- note that the zero frequency entry must be True in order to make the
        optimization well-posed
    mu : float
        regularization parameter
    beta : float
        coupling parameter for the ADMM iterations
    n_iter : int
        number of ADMM iterations to conduct

    Returns
    -------
    numpy.ndarray or ctorch.ComplexTensor
        real part of the reconstructed m x n image
    float
        objective value at the end of the ADMM iterations: the sum over pixels
        of the Euclidean norm of the finite-difference gradient, plus mu / 2
        times the squared norm of the difference between the masked,
        1/sqrt(m*n)-scaled Fourier transform of the reconstruction and f

    Raises
    ------
    TypeError
        if f is neither a numpy.ndarray nor a ctorch.ComplexTensor
    AssertionError
        if mask[0, 0] is False
    """

    def image_gradient(x):
        """
        First-order finite-differencing both horizontally and vertically.

        Computes a first-order finite-difference approximation to the gradient.

        Parameters
        ----------
        x : numpy.ndarray or ctorch.ComplexTensor
            image (that is, two-dimensional array)

        Returns
        -------
        numpy.ndarray or ctorch.ComplexTensor
            horizontal finite differences of x stacked on top of the vertical
            finite differences (separating horizontal from vertical via the
            initial dimension)
        """
        if isinstance(x, np.ndarray):
            # Wrap the last column of x around to the beginning.
            x_h = np.hstack((x[:, -1:], x))
            # Wrap the last row of x around to the beginning.
            x_v = np.vstack((x[-1:], x))
            # Apply forward differences to the columns of x.
            d_x = (x_h[:, 1:] - x_h[:, :-1])
            # Apply forward differences to the rows of x.
            d_y = (x_v[1:] - x_v[:-1])
            return np.vstack((d_x.ravel(), d_y.ravel()))
        elif isinstance(x, ctorch.ComplexTensor):
            # Wrap the last column of x around to the beginning.
            x_h = ctorch.cat((x[:, -1:], x), dim=1)
            # Wrap the last row of x around to the beginning.
            x_v = ctorch.cat((x[-1:], x), dim=0)
            # Apply forward differences to the columns of x.
            d_x = (x_h[:, 1:] - x_h[:, :-1])
            # Apply forward differences to the rows of x.
            d_y = (x_v[1:] - x_v[:-1])
            return ctorch.cat((d_x, d_y)).view(2, -1)
        else:
            raise TypeError('Input must be a numpy.ndarray ' +
                            'or a ctorch.ComplexTensor.')

    def image_gradient_T(x):
        """
        Transpose of the operator that function image_gradient implements.

        Computes the transpose of the matrix given by function image_gradient.

        Parameters
        ----------
        x : numpy.ndarray or ctorch.ComplexTensor
            stack of two identically shaped arrays

        Returns
        -------
        numpy.ndarray or ctorch.ComplexTensor
            result of applying to x the transpose of function image_gradient
        """
        if isinstance(x, np.ndarray):
            x_h = x[0]
            x_v = x[1]
            # Wrap the first column of x_h around to the end.
            x_h_ext = np.hstack((x_h, x_h[:, :1]))
            # Wrap the first row of x_v around to the end.
            x_v_ext = np.vstack((x_v, x_v[:1]))
            # Apply forward differences to the columns of x.
            d_x = x_h_ext[:, :-1] - x_h_ext[:, 1:]
            # Apply forward differences to the rows of x.
            d_y = x_v_ext[:-1] - x_v_ext[1:]
            return d_x + d_y
        elif isinstance(x, ctorch.ComplexTensor):
            x_h = x[0]
            x_v = x[1]
            # Wrap the first column of x_h around to the end.
            x_h_ext = ctorch.cat((x_h, x_h[:, :1]), dim=1)
            # Wrap the first row of x_v around to the end.
            x_v_ext = ctorch.cat((x_v, x_v[:1]), dim=0)
            # Apply forward differences to the columns of x.
            d_x = x_h_ext[:, :-1] - x_h_ext[:, 1:]
            # Apply forward differences to the rows of x.
            d_y = x_v_ext[:-1] - x_v_ext[1:]
            return d_x + d_y
        else:
            raise TypeError('Input must be a numpy.ndarray ' +
                            'or a ctorch.ComplexTensor.')

    # Dispatch on the type of f before touching mask, so that an unsupported
    # f always yields TypeError.
    if isinstance(f, np.ndarray):
        on_gpu = False
    elif isinstance(f, ctorch.ComplexTensor):
        on_gpu = True
    else:
        raise TypeError('Input must be a numpy.ndarray ' +
                        'or a ctorch.ComplexTensor.')
    assert mask[0, 0]
    # Calculate the Fourier transform of the convolutional kernels
    # for finite differences (the same for both backends).
    tx = np.abs(np.fft.fft([1, -1] + [0] * (m - 2)))**2
    ty = np.abs(np.fft.fft([1, -1] + [0] * (n - 2)))**2
    # Compute the multipliers required to solve formula (2.8) from Tao-Yang
    # in the Fourier domain. The calculation involves broadcasting the
    # Fourier transform of the convolutional kernel for horizontal finite
    # differences over the vertical directions, and broadcasting the
    # Fourier transform of the convolutional kernel for vertical finite
    # differences over the horizontal directions.
    multipliers = 1. / (ty + tx[:, None] + (mu / beta) * mask)

    if not on_gpu:
        # Rescale f and pad with zeros between the mask samples.
        Ktf = (mu / beta) * zero_padded(m, n, f, mask)
        # Initialize the primal (x) and dual (la) solutions to zeros.
        x = np.zeros((m, n))
        la = np.zeros((2, m * n))
        # Calculate iterations of alternating minimization.
        for i in range(n_iter):
            if i > 0:
                # Apply shrinkage via formula (2.7) from Tao-Yang, dividing
                # both arguments of the "max" operator in formula (2.7) by the
                # denominator of the rightmost factor in formula (2.7).
                a = image_gradient(x) + la / beta
                b = scipy.linalg.norm(a, axis=0, keepdims=True)
                # Where b == 0, 1 / (beta * b) is inf and the factor clamps to
                # 0 (so y is 0 there, as intended); silence the resulting
                # divide-by-zero warning.
                with np.errstate(divide='ignore'):
                    y = a * np.maximum(1 - 1 / (beta * b), 0)
            else:
                y = np.zeros((2, m * n))
            # Solve formula (2.8) from Tao-Yang in the Fourier domain.
            c = image_gradient_T((y - la / beta).reshape((2, m, n))) + Ktf
            x = np.fft.ifft2(np.fft.fft2(c) * multipliers)
            # Update the Lagrange multipliers via formula (2.9) from Tao-Yang.
            la = la - beta * (y - image_gradient(x))
        # Calculate the loss in formula (1.4) from Tao-Yang...
        loss = np.linalg.norm(image_gradient(x), axis=0).sum()
        # ... adding in the term for the fidelity of the reconstruction.
        loss += np.linalg.norm(
            np.fft.fft2(x)[mask] / np.sqrt(m * n) - f)**2 * (mu / 2)
        # Discard the imaginary part of the primal solution,
        # returning only the real part and the loss.
        return x.real, loss

    # Convert the mask from boolean indicators to long integer indices.
    mask_nnz = torch.nonzero(torch.from_numpy(mask.astype(np.uint8)))
    mask_nnz = mask_nnz.cuda()
    # Rescale f and pad with zeros between the mask samples.
    Ktf = zero_padded(m, n, f, mask_nnz) * (mu / beta)
    multipliers_th = ctorch.from_numpy(multipliers).cuda()
    # Initialize the primal (x) and dual (la) solutions to zeros,
    # creating new ctorch tensors of the same type as f.
    x = f.new(m, n).zero_()
    la = f.new(2, m * n).zero_()
    # Calculate iterations of alternating minimization.
    for i in range(n_iter):
        if i > 0:
            # Apply shrinkage via formula (2.7) from Tao-Yang, dividing both
            # arguments of the "max" operator in formula (2.7) by the
            # denominator of the rightmost factor in formula (2.7).
            a = image_gradient(x) + la / beta
            b = ctorch.norm(a, p=2, dim=0, keepdim=True)
            y = a * torch.clamp(1 - 1 / (beta * b), min=0)
        else:
            y = f.new(2, m * n).zero_()
        # Solve formula (2.8) from Tao-Yang in the Fourier domain.
        c = image_gradient_T((y - la / beta).view(2, m, n)) + Ktf
        x = ctorch.ifft2(ctorch.fft2(c) * multipliers_th)
        # Update the Lagrange multipliers via formula (2.9) from Tao-Yang.
        la = la - (y - image_gradient(x)) * beta
    # Calculate the loss in formula (1.4) from Tao-Yang...
    loss = ctorch.norm(image_gradient(x), p=2, dim=0).sum()
    # ... adding in the term for the fidelity of the reconstruction.
    ftx = ctorch.fft2(x) / math.sqrt(m * n)
    # Flatten the (row, column) indices of the mask to match ftx.view(-1).
    mask_flat = mask_nnz[:, 0] * n + mask_nnz[:, 1]
    loss += ctorch.norm(ftx.view(-1)[mask_flat] - f)**2 * (mu / 2)
    # Discard the imaginary part of the primal solution,
    # returning only the real part and the loss.
    return x.real, loss.cpu().item()
示例#4
0
def runtestadmm(method, cpu, filename, mu=1e12, beta=1, subsampling_factor=0.7,
                n_iter=100, seed=None, fileout='recon2.png'):
    """Run tests as specified.

    Use the specified method (on CPUs if cpu is True), reading the image from
    filename, with the regularization parameter mu and the ADMM coupling
    parameter beta, subsampling by subsampling_factor, for n_iter iterations
    of ADMM, seeding the random number generator with the provided seed.
    Prints the elapsed time and final loss, and saves a 2 x 2 figure (the
    original, the reconstruction, the naive zero-padded reconstruction, and
    the sampling mask) to fileout.

    Parameters
    ----------
    method : str
        which algorithm to use ('cs_baseline' or 'cs_fft')
    cpu : boolean
        set to true to perform all computations on the CPU(s)
    filename : str
        name of the file containing the image to be processed; prepend a path
        if the file resides outside the working directory
    mu : float
        regularization parameter
    beta : float
        coupling parameter for the ADMM iterations
    subsampling_factor : float
        probability of retaining an entry in k-space
    n_iter : int
        number of ADMM iterations to conduct
    seed : int
        seed value for numpy's random number generator
    fileout : str, optional
        path to which the summary figure is saved (defaults to 'recon2.png'
        in the working directory)

    Returns
    -------
    float
        objective value at the end of the ADMM iterations (see function adm)

    Raises
    ------
    NotImplementedError
        if method is unknown, or if method is 'cs_baseline' and cpu is False
    """

    def _synchronize():
        """Wait for queued CUDA work, but only if CUDA is available."""
        # torch.cuda.synchronize raises on CPU-only machines/builds, which
        # would otherwise make the cpu=True path unusable there.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def tic():
        """
        Timing starting.

        Records the current time.

        Returns
        -------
        float
            present time in fractional seconds
        """
        _synchronize()
        return time.perf_counter()

    def toc(t):
        """
        Timing stopping.

        Reports the difference of the current time from the reference provided.

        Parameters
        ----------
        t : float
            reference time in fractional seconds

        Returns
        -------
        float
            difference of the present time from the reference t
        """
        _synchronize()
        return time.perf_counter() - t

    # Fix the random seed if appropriate.
    np.random.seed(seed=seed)
    # Read the image from disk, closing the file afterwards.
    with Image.open(filename) as img:
        f_orig = np.array(img).astype(np.float64) / 255.
    m = f_orig.shape[0]
    n = f_orig.shape[1]
    # Select which k-space frequencies to retain.
    mask = np.random.uniform(size=(m, n)) < subsampling_factor
    # Make the optimization well-posed by including the zero frequency.
    mask[0, 0] = True
    # Subsample the Fourier transform of the original image.
    f = np.fft.fft2(f_orig)[mask] / np.sqrt(m * n)
    # Start timing.
    t = tic()
    # Reconstruct the image from the undersampled Fourier data.
    print('Running {}(cpu={}, mu={}, beta={}, n_iter={})'.format(
        method, cpu, mu, beta, n_iter))
    if method == 'cs_baseline':
        if cpu:
            x, loss = cs_baseline(m, n, f, mask, mu=mu, beta=beta,
                                  n_iter=n_iter)
        else:
            raise NotImplementedError('A baseline on GPUs is not implemented' +
                                      '; use \'--cpu\'')
    elif method == 'cs_fft':
        if cpu:
            x, loss = cs_fft(m, n, f, mask, mu=mu, beta=beta, n_iter=n_iter)
        else:
            # Move the Fourier data to the GPUs.
            f_th = ctorch.from_numpy(f).cuda()
            # The first call to `ctorch.fft2` is slow;
            # run a dummy fft2 and restart the timer to get accurate timings.
            ctorch.fft2(ctorch.from_numpy(np.fft.fft2(f_orig)).cuda())
            t = tic()
            x, loss = cs_fft(m, n, f_th, mask, mu=mu, beta=beta, n_iter=n_iter)
            x = x.cpu().numpy()
    else:
        raise NotImplementedError('method must be either \'cs_baseline\' ' +
                                  'or \'cs_fft\'')
    # Stop timing.
    tt = toc(t)
    # Print the time taken and final loss.
    print('time={}s'.format(tt))
    print('loss={}'.format(loss))
    # Plot the original image, its reconstruction, and the sampling pattern.
    fig = plt.figure(figsize=(12, 12))
    plt.subplot(221)
    plt.title('Original')
    plt.imshow(f_orig, cmap='gray')
    plt.subplot(222)
    plt.title('Compressed sensing reconstruction')
    plt.imshow(x.reshape(m, n), cmap='gray')
    plt.subplot(223)
    plt.title('Naive (zero-padded ifft2) reconstruction')
    plt.imshow(np.abs(zero_padded(m, n, f, mask)), cmap='gray')
    plt.subplot(224)
    plt.title('Sampling mask')
    plt.imshow(mask, cmap='gray')
    plt.savefig(fileout, bbox_inches='tight')
    # Release the figure so repeated test runs do not accumulate figures.
    plt.close(fig)
    return loss
示例#5
0
def bootstrap(filein, fileout, subsampling_factor, mask, low, recon,
              recon_args, n_resamps=None, sdn=None, viz=None):
    """
    plots thrice the average of bootstrapped errors in reconstruction

    Plots and saves thrice the average of the differences between
    a reconstruction of the original image in filein and n_resamps bootstrap
    reconstructions via recon applied to the k-space subsamplings specified
    by mask and by other masks generated similarly, corrupting the k-space
    values of the original image with independent and identically distributed
    centered complex Gaussian noise whose standard deviation is sdn*sqrt(2)
    (sdn=0 if not provided explicitly).

    Each bootstrap mask retains the rows indexed by round(m * subsampling_factor)
    draws made uniformly at random with replacement. Duplicate draws collapse
    onto one row, so somewhat fewer distinct rows than that may be retained.
    The zero frequency and all frequencies between -low and low are also
    retained. The bootstrap reconstructions subsample the Fourier transform of
    the reconstruction obtained from mask, not the noisy data itself.

    Setting viz to be True yields colorized visualizations, too, including
    the error estimates overlaid over the reconstruction, the error estimates
    blurred overlaid over the reconstruction, the error estimates blurred,
    the error estimates subtracted from the reconstruction, the error estimates
    saturating the reconstruction in hue-saturation-value (HSV) color space,
    and the error estimates interpolating the reconstruction in HSV space.

    The calling sequence of recon must be  (m, n, f, mask_th, **recon_args),
    where filein contains an m x n image, f is the image in k-space subsampled
    to the mask, mask_th = torch.from_numpy(mask.astype(np.uint8)).cuda(), and
    **recon_args is the unpacking of recon_args. The function recon must return
    a torch.Tensor (the reconstruction) and a float (the corresponding loss).

    _N.B._: mask[-low+1], mask[-low+2], ..., mask[low-1] must be True.

    The plots are saved to files named after fileout with a suffix inserted
    before the extension: _original, _recon, _error, _bootstrap, and, when viz
    is True, also _corrected, _overlaid, _blurred_overlaid, _saturated,
    _interpolated, and _blurred. Every figure is closed once saved.

    Side effect: this function globally disables axis ticks, tick labels, and
    spines in matplotlib.rcParams; the setting persists after it returns.

    Parameters
    ----------
    filein : str
        path to the file containing the image to be processed (the path may be
        relative or absolute)
    fileout : str
        path to the file to which the plots will be saved (the path may be
        relative or absolute)
    subsampling_factor : float
        the bootstrap masks draw round(m * subsampling_factor) row indices
        (with replacement) to retain
    mask : ndarray of bool
        indicators of whether to include (True) or exclude (False)
        the corresponding rows in k-space of the image from filein
    low : int
        bandwidth of low frequencies included in mask (between -low to low)
    recon : function
        returns the reconstructed image
    recon_args : dict
        keyword arguments for recon
    n_resamps : int, optional
        number of bootstrap resampled reconstructions (defaults to 100)
    sdn : float, optional
        standard deviation of the noise to add (defaults to 0)
    viz : bool, optional
        indicator of whether to generate colorized visualizations
        (defaults to False)

    Returns
    -------
    float
        loss for the reconstruction using the original mask
    list of float
        losses for the reconstructions using other, randomly generated masks
    float
        square root of the sum of the square of the estimated errors
    float
        square root of the sum of the square of the estimated errors blurred

    Raises
    ------
    ValueError
        if n_resamps is less than 1 (the average would be 0/0)
    AssertionError
        if mask omits any of the required low frequencies
    """
    import os

    # Set default parameters.
    if n_resamps is None:
        n_resamps = 100
    if sdn is None:
        sdn = 0
    if viz is None:
        viz = False
    if n_resamps < 1:
        raise ValueError('n_resamps must be at least 1')
    # Check that the mask includes all low frequencies
    # (mask[-0] is mask[0], so this covers mask[-low+1], ..., mask[low-1]).
    for k in range(low):
        assert mask[k]
        assert mask[-k]
    # Read the image from disk.
    with Image.open(filein) as img:
        f_orig = np.array(img).astype(np.float64) / 255.
    m = f_orig.shape[0]
    n = f_orig.shape[1]
    # Fourier transform the image, normalized by sqrt(m * n).
    ff_orig = np.fft.fft2(f_orig) / np.sqrt(m * n)
    # Add noise (the real draws happen before the imaginary draws).
    ff_noisy = ff_orig + sdn * (
        np.random.randn(m, n) + 1j * np.random.randn(m, n))
    # Subsample the noisy Fourier transform of the original image.
    f = ctorch.from_numpy(ff_noisy[mask]).cuda()
    logging.info('computing bootstrap resamplings -- all {}'.format(n_resamps))
    # Perform the reconstruction using the mask.
    mask_th = torch.from_numpy(mask.astype(np.uint8)).cuda()
    reconf, lossf = recon(m, n, f, mask_th, **recon_args)
    reconf = reconf.cpu().numpy()
    # Fourier transform the reconstruction (same normalization as above).
    freconf = np.fft.fft2(reconf) / np.sqrt(m * n)
    # Perform the reconstruction resampling new masks and samples in k-space.
    recons = np.ndarray((n_resamps, m, n))
    loss = []
    for k in range(n_resamps):
        # Select which frequencies to retain: uniform draws with replacement,
        # so duplicate draws simply mark the same row again.
        rows = np.floor(
            m * np.random.uniform(size=round(m * subsampling_factor)))
        mask1 = np.zeros(m, dtype=bool)
        mask1[rows.astype(int)] = True
        # Make the optimization well-posed by including the zero frequency.
        mask1[0] = True
        # Include all low frequencies.
        for j in range(low):
            mask1[j] = True
            mask1[-j] = True
        # Subsample the Fourier transform of the reconstruction.
        f1 = ctorch.from_numpy(freconf[mask1]).cuda()
        # Reconstruct the image from the subsampled data.
        mask1_th = torch.from_numpy(mask1.astype(np.uint8)).cuda()
        recon1, loss1 = recon(m, n, f1, mask1_th, **recon_args)
        # Record the results.
        recons[k, :, :] = recon1.cpu().numpy()
        loss.append(loss1)
    # Calculate thrice the average of the bootstrap differences.
    sumboo = np.sum(recons - reconf, axis=0)
    scaled = sumboo * 3 / n_resamps
    # Blur the error estimates.
    blurred = skimage.filters.gaussian(scaled, sigma=1)
    rsse_estimated = np.linalg.norm(scaled, ord='fro')
    rsse_blurred = np.linalg.norm(blurred, ord='fro')

    # Plot errors.
    # Remove the ticks and spines on the axes (a global, persistent change).
    matplotlib.rcParams.update({
        'xtick.top': False,
        'xtick.bottom': False,
        'ytick.left': False,
        'ytick.right': False,
        'xtick.labeltop': False,
        'xtick.labelbottom': False,
        'ytick.labelleft': False,
        'ytick.labelright': False,
        'axes.spines.top': False,
        'axes.spines.bottom': False,
        'axes.spines.left': False,
        'axes.spines.right': False,
    })
    # Configure the colormaps.
    kwargs01 = dict(cmap='gray',
                    norm=matplotlib.colors.Normalize(vmin=0, vmax=1))
    kwargs11 = dict(cmap='gray',
                    norm=matplotlib.colors.Normalize(vmin=-1, vmax=1))
    # Separate the extension (filetype) from the rest of the filename;
    # splitext ignores dots in directory names and tolerates no extension.
    rest, suffix = os.path.splitext(fileout)

    def path_for(tag):
        return rest + tag + suffix

    # Plot the original.
    _bootstrap_save_image(path_for('_original'), 'Original',
                          np.clip(f_orig, 0, 1), **kwargs01)
    # Plot the reconstruction from the original mask provided.
    _bootstrap_save_image(path_for('_recon'), 'Reconstruction',
                          np.clip(reconf, 0, 1), **kwargs01)
    # Plot the difference from the original.
    _bootstrap_save_image(path_for('_error'), 'Error of Reconstruction',
                          np.clip(reconf - f_orig, -1, 1), **kwargs11)
    # Plot thrice the average of the bootstrap differences.
    _bootstrap_save_image(path_for('_bootstrap'), 'Bootstrap',
                          np.clip(scaled, -1, 1), **kwargs11)

    if viz:
        # Plot the reconstruction minus the bootstrap difference.
        _bootstrap_save_image(path_for('_corrected'),
                              'Reconstruction \u2013 Bootstrap',
                              np.clip(reconf - scaled, 0, 1), **kwargs01)
        # Overlay the error estimates on the reconstruction.
        _bootstrap_save_image(path_for('_overlaid'),
                              'Errors Over a Threshold Overlaid',
                              _bootstrap_overlay_rgb(scaled, reconf))
        # Overlay the blurred error estimates on the reconstruction.
        _bootstrap_save_image(path_for('_blurred_overlaid'),
                              'Blurred Errors Over a Threshold Overlaid',
                              _bootstrap_overlay_rgb(blurred, reconf))
        # Plot a bootstrap-saturated reconstruction.
        hue = (1 - np.sign(scaled)) / 4 * 2 / 3
        _bootstrap_save_image(path_for('_saturated'),
                              'Bootstrap-Saturated Reconstruction',
                              _bootstrap_hue_rgb(hue, scaled, reconf))
        # Plot a bootstrap-interpolated reconstruction.
        hue = 7. / 12 + np.sign(scaled) * 3 / 12
        _bootstrap_save_image(path_for('_interpolated'),
                              'Bootstrap-Interpolated Reconstruction',
                              _bootstrap_hue_rgb(hue, scaled, reconf))
        # Plot the blurred bootstrap.
        _bootstrap_save_image(path_for('_blurred'), 'Blurred Bootstrap',
                              np.clip(blurred, -1, 1), **kwargs11)

    return lossf, loss, rsse_estimated, rsse_blurred


def _bootstrap_save_image(path, title, image, **kwargs):
    """Draw image in a fresh 5.5 x 5.5 figure, save it to path, and close it.

    kwargs are passed through to plt.imshow. The figure is closed after
    saving so that pyplot does not retain it (and its memory) indefinitely.
    """
    fig = plt.figure(figsize=(5.5, 5.5))
    plt.title(title)
    plt.imshow(image, **kwargs)
    plt.savefig(path, bbox_inches='tight')
    plt.close(fig)


def _bootstrap_overlay_rgb(errors, reconf, quantile=0.98):
    """Return RGB of reconf with the largest-magnitude errors colorized.

    Pixels whose |errors| exceed the given quantile of |errors| are drawn at
    full saturation and value, with a hue encoding the sign and relative size
    of the error; all other pixels show reconf in gray.
    """
    magnitudes = np.sort(np.abs(errors), axis=None)
    maxthresh = magnitudes[-1]
    # Clamp the index: for tiny images round(quantile * size) can equal size.
    index = min(round(quantile * magnitudes.size), magnitudes.size - 1)
    threshold = magnitudes[index]
    # Avoid 0/0 (NaN hues) when every error estimate vanishes.
    scale = maxthresh if maxthresh > 0 else 1.
    hue = 2. / 3 + (errors / scale) / 4 * 2 / 3
    saturation = np.abs(errors) > threshold
    value = reconf * (1 - saturation) + saturation
    rgb = hsv_to_rgb(np.dstack((hue, saturation, value)))
    return np.clip(rgb, 0, 1)


def _bootstrap_hue_rgb(hue, errors, reconf):
    """Return RGB with the given hue, saturation |errors| / max|errors|,
    and value reconf clipped to [0, 1].

    When every error is zero the saturation stays zero (instead of NaN).
    """
    saturation = np.abs(errors)
    peak = np.max(saturation)
    if peak > 0:
        saturation = saturation / peak
    value = np.clip(reconf, 0, 1)
    rgb = hsv_to_rgb(np.dstack((hue, saturation, value)))
    return np.clip(rgb, 0, 1)