Esempio n. 1
0
def read_image_and_map_and_apply_map(image_filename,map_filename):
    """
    Reads an image and a map and applies the map to the image.

    :param image_filename: input image filename
    :param map_filename: input map filename
    :return: the warped image and its image header as a tuple (im,hdr);
        (None,None) if either reader returned None for the image or the map
    """

    # local is called phi (not `map`) so the builtin is not shadowed
    phi,_ = fileio.MapIO().read(map_filename)
    im,hdr,_,_ = fileio.ImageIO().read_to_map_compatible_format(image_filename,phi)

    if (im is None) or (phi is None):
        print('Could not read map or image')
        return None,None

    # The header is only accessed after the None-check above: reading
    # hdr['spacing'] first would raise on a failed read before the
    # error branch could be reached.
    spacing = hdr['spacing']
    #TODO: check that the spacing is compatible with the map

    # make pytorch arrays for subsequent processing
    im_t = AdaptVal(torch.from_numpy(im))
    map_t = AdaptVal(torch.from_numpy(phi))
    im_warped = utils.t2np( utils.compute_warped_image_multiNC(im_t,map_t,spacing) )

    return im_warped,hdr
    def createImage(self, ex_len=64):
        """
        Create a 2D source/target image pair of squares and smooth both images.

        Sets ``self.spacing``, ``self.sz``, ``self.ISource`` and ``self.ITarget``
        and registers the image-smoothing settings in ``self.params``.

        :param ex_len: number of pixels along each dimension of the created images
        """
        nr_of_dims = 2
        # size of the desired images: (ex_len)^nr_of_dims
        requested_size = np.tile(ex_len, nr_of_dims)

        # create a default image pair with two sample squares
        I0, I1, self.spacing = eg.CreateSquares(nr_of_dims).create_image_pair(
            requested_size, self.params)
        self.sz = np.array(I0.shape)

        # wrap source and target as pyTorch variables (only the source array is copied)
        self.ISource = AdaptVal(torch.from_numpy(I0.copy()))
        self.ITarget = AdaptVal(torch.from_numpy(I1))

        # register the smoothing settings as (key, default, comment) tuples
        self.params[('image_smoothing', {}, 'image smoothing settings')]
        self.params['image_smoothing'][
            ('smooth_images', True,
             '[True|False]; smoothes the images before registration')]
        self.params['image_smoothing'][
            ('smoother', {}, 'settings for the image smoothing')]
        self.params['image_smoothing']['smoother'][
            ('gaussian_std', 0.05, 'how much smoothing is done')]
        self.params['image_smoothing']['smoother'][
            ('type', 'gaussian', "['gaussianSpatial'|'gaussian'|'diffusion']")]

        # smooth both images a little bit with the configured smoother
        smoothing_params = self.params['image_smoothing']
        smoother = SF.SmootherFactory(
            self.sz[2:], self.spacing).create_smoother(smoothing_params)
        self.ISource = smoother.smooth(self.ISource)
        self.ITarget = smoother.smooth(self.ITarget)
Esempio n. 3
0
def invert_map(map, spacing):
    """
    Inverts the map numerically and returns its inverse. Assumes standard map parameterization [-1,1]^d

    The inverse is initialized with the identity map and optimized so that
    warping ``map`` with it reproduces the identity map (sum of squared
    differences). Iteration stops as soon as the energy no longer decreases.

    :param map: Input map to be inverted (numpy array)
    :param spacing: spacing handed to the identity-map and warping utilities
    :return: inverted map (converted back via utils.t2np)
    """
    # some optimizer settings, probably too strict
    max_nr_of_steps = 1000
    rel_ftol = 1e-6

    # make pytorch arrays for subsequent processing
    map_t = AdaptVal(torch.from_numpy(map))

    # identity map: optimization target and initial guess for the inverse
    identity_np = utils.identity_map_multiN(map_t.data.shape, spacing)
    identity_t = AdaptVal(torch.from_numpy(identity_np))

    # parameter to store the inverse map (its own copy of the identity)
    invmap_t = AdaptVal(Parameter(torch.from_numpy(identity_np.copy())))

    optimizer = CO.LBFGS_LS(
        [invmap_t],
        lr=1.0,
        max_iter=1,
        tolerance_grad=rel_ftol * 10,
        tolerance_change=rel_ftol,
        max_eval=5,
        history_size=5,
        line_search_fn='backtracking',
    )

    def _energy():
        # warps map_t with the current inverse; for the true inverse this is the identity map
        warped_map = utils.compute_warped_image_multiNC(map_t, invmap_t, spacing)
        residual = warped_map - identity_t
        return (residual ** 2).sum()

    def _closure():
        optimizer.zero_grad()
        energy = _energy()
        energy.backward()
        return energy

    best_loss = utils.t2np(_energy())

    for step in range(max_nr_of_steps):
        optimizer.step(_closure)
        current_loss = utils.t2np(_energy())
        print('Iter = ' + str(step) + '; E = ' + str(current_loss))
        if current_loss >= best_loss:
            # no further improvement: stop
            break
        best_loss = current_loss

    return utils.t2np(invmap_t)
Esempio n. 4
0
def compute_average_image(images):
    """
    Compute the element-wise mean of a collection of images read from disk.

    :param images: sequence of image filenames (must support ``len()``)
    :return: tuple (Iavg, spacing): the average image and the spacing
        returned by the reader for the last image
    :raises ValueError: if ``images`` is empty
    """
    nr_of_images = len(images)
    if nr_of_images == 0:
        # Without this guard an empty input ended in ``None / 0`` (TypeError).
        raise ValueError('compute_average_image requires at least one image')

    im_io = FIO.ImageIO()
    Iavg = None
    spacing = None
    for im_name in images:
        Ic, _, spacing, _ = im_io.read_to_nc_format(filename=im_name)
        Ic_t = AdaptVal(torch.from_numpy(Ic))
        if Iavg is None:
            # first image starts the running sum
            Iavg = Ic_t
        else:
            Iavg += Ic_t
    return Iavg / nr_of_images, spacing
Esempio n. 5
0
def add_texture_on_img(im_orig,
                       texture_gaussian_smoothness=0.1,
                       texture_magnitude=0.3):
    """
    Add smoothed random texture to an image, one integer intensity level at a time.

    For every integer level found in ``floor(im_orig)`` new uniform noise is
    drawn, smoothed with a Gaussian smoother, divided by its maximum,
    multiplied by ``texture_magnitude`` and added to all pixels with
    ``im_orig >= level - 0.5``. Levels are visited in increasing order, so
    pixels selected by several levels keep the texture of the last one.

    :param im_orig: numpy image; noise is drawn for ``shape[2:]`` and reshaped to the full shape
    :param texture_gaussian_smoothness: value stored as 'gaussian_std' of the smoother
    :param texture_magnitude: factor applied to the max-normalized smoothed noise
    :return: the textured image as a torch.Tensor
    """
    full_shape = im_orig.shape
    spatial_shape = full_shape[2:]
    textured = np.zeros_like(im_orig)

    # do this separately for each integer intensity level
    for level in np.unique(np.floor(im_orig).astype('int')):
        # zero-centered uniform noise, brought to the shape of the image
        noise = np.random.random(spatial_shape).astype('float32') - 0.5
        noise = noise.view().reshape(full_shape)

        smoother_params = pars.ParameterDict()
        smoother_params['smoother']['type'] = 'gaussian'
        smoother_params['smoother']['gaussian_std'] = texture_gaussian_smoothness
        spacing = 1.0 / (np.array(spatial_shape).astype('float32') - 1)
        smoother = sf.SmootherFactory(spatial_shape,
                                      spacing).create_smoother(smoother_params)

        smoothed = smoother.smooth(AdaptVal(torch.from_numpy(noise)))
        smoothed = smoothed.detach().cpu().numpy()
        smoothed /= smoothed.max()
        smoothed *= texture_magnitude

        selected = im_orig >= level - 0.5
        textured[selected] = im_orig[selected] + smoothed[selected]

    return torch.Tensor(textured)
Esempio n. 6
0
def get_labeled_image(img_pair):
    """
    Load a (moving, target) pair and relabel both to consecutive label indices.

    The sorted unique values of the moving image are mapped to 0, 1, 2, ...;
    the same mapping is applied to the target image. Target entries whose
    value does not occur in the moving image stay 0.

    :param img_pair: tuple (moving_path, target_path) of files readable by torch.load
    :return: tuple (relabeled moving, relabeled target), each wrapped by AdaptVal
    """
    moving_path, target_path = img_pair
    moving = torch.load(moving_path)
    target = torch.load(target_path)

    moving_values = np.unique(moving.cpu().numpy())
    target_values = np.unique(target.cpu().numpy())
    # every value of the moving image has to be present in the target image
    assert len(set(moving_values) - set(target_values)) == 0

    relabeled_moving = torch.zeros_like(moving)
    relabeled_target = torch.zeros_like(target)
    moving_values.sort()
    for new_label, old_value in enumerate(moving_values):
        relabeled_moving[moving == old_value] = new_label
        relabeled_target[target == old_value] = new_label

    return AdaptVal(relabeled_moving), AdaptVal(relabeled_target)
    def __init__(self):
        """
        Create the three 3x3x3 derivative stencils and stack them in ``self.spatial_filter``.

        Each stencil is antisymmetric along one kernel axis (first, second and
        third axis for dx, dy and dz respectively) and weighted with
        1-3-1 / 3-6-3 / 1-3-1 across the other two. The stacked filter has
        shape (3, 1, 3, 3, 3).
        """
        def _as_kernel(stencil):
            # nested 3x3x3 list -> tensor of shape (1, 1, 3, 3, 3)
            return AdaptVal(torch.Tensor(stencil)).view(1, 1, 3, 3, 3)

        stencil_x = [
            [[-1., -3., -1.], [-3., -6., -3.], [-1., -3., -1.]],
            [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
            [[1., 3., 1.], [3., 6., 3.], [1., 3., 1.]],
        ]
        stencil_y = [
            [[1., 3., 1.], [0., 0., 0.], [-1., -3., -1.]],
            [[3., 6., 3.], [0., 0., 0.], [-3., -6., -3.]],
            [[1., 3., 1.], [0., 0., 0.], [-1., -3., -1.]],
        ]
        stencil_z = [
            [[-1., 0., 1.], [-3., 0., 3.], [-1., 0., 1.]],
            [[-3., 0., 3.], [-6., 0., 6.], [-3., 0., 3.]],
            [[-1., 0., 1.], [-3., 0., 3.], [-1., 0., 1.]],
        ]

        dx = _as_kernel(stencil_x)
        dy = _as_kernel(stencil_y)
        dz = _as_kernel(stencil_z)

        # stack along the first axis: one output channel per derivative direction
        self.spatial_filter = torch.cat((dx, dy, dz), 0)
        # repeat with all-ones factors keeps the shape unchanged
        self.spatial_filter = self.spatial_filter.repeat(1, 1, 1, 1, 1)
 def __get_smoothed_target(self, I0):
     """
     Convert ``I0`` to a tensor, pass it through ``self.fourier_smoother`` and detach the result.

     :param I0: numpy array holding the image
     :return: detached output of ``self.fourier_smoother``
     """
     # work on a copy so the tensor does not share memory with the caller's array
     target = AdaptVal(torch.from_numpy(I0.copy()))
     # detach: the result is returned without autograd history
     return self.fourier_smoother(target).detach()
Esempio n. 9
0
def upsample_to_compatible_size_single_image(gt_weight,
                                             weight,
                                             interpolation_order=1):
    """
    Upsample ``weight`` to the shape of ``gt_weight`` if the shapes differ.

    :param gt_weight: numpy array whose shape is the desired size
    :param weight: numpy array to be upsampled
    :param interpolation_order: interpolation order handed to the sampler
    :return: ``weight`` itself if the shapes already agree, otherwise the
        sampler output converted to a numpy array
    """
    # nothing to do if the sizes are already compatible
    if gt_weight.shape == weight.shape:
        return weight

    resampler = IS.ResampleImage()

    # prepend singleton batch and channel axes before handing it to the sampler
    batched_shape = [1, 1] + list(weight.shape)
    weight_t = AdaptVal(
        torch.from_numpy(weight.view().reshape(batched_shape)).float())
    unit_spacing = np.array([1., 1.])
    target_size = gt_weight.shape

    upsampled_t, _ = resampler.upsample_image_to_size(
        weight_t, unit_spacing, target_size, interpolation_order)

    return upsampled_t.detach().cpu().numpy()
Esempio n. 10
0
def downsample_to_compatible_size_single_image(gt_weight,
                                               weight,
                                               interpolation_order=3):
    """
    Downsample ``gt_weight`` to the shape of ``weight`` if the shapes differ.

    :param gt_weight: numpy array (ground truth weights) to be downsampled
    :param weight: numpy array whose shape is the desired size
    :param interpolation_order: interpolation order handed to the sampler
    :return: ``gt_weight`` itself if the shapes already agree, otherwise the
        sampler output converted to a numpy array
    """
    # nothing to do if the sizes are already compatible
    if gt_weight.shape == weight.shape:
        return gt_weight

    resampler = IS.ResampleImage()

    # prepend singleton batch and channel axes before handing it to the sampler
    batched_shape = [1, 1] + list(gt_weight.shape)
    gt_weight_t = AdaptVal(
        torch.from_numpy(gt_weight.view().reshape(batched_shape)).float())
    unit_spacing = np.array([1., 1.])
    target_size = weight.shape

    downsampled_t, _ = resampler.downsample_image_to_size(
        gt_weight_t, unit_spacing, target_size, interpolation_order)

    return downsampled_t.detach().cpu().numpy()
Esempio n. 11
0
def resample_image(I,
                   spacing,
                   desiredSize,
                   spline_order=1,
                   zero_boundary=False,
                   identity_map=None):
    """
    Resample an image to a given desired size

    :param I: Input image (expected to be of BxCxXxYxZ format)
    :param spacing: array describing the spatial spacing
    :param desiredSize: array for the desired size; its first two entries are
        dropped, the remaining entries are used as the spatial target size
        (1 entry for 1D, 2 for 2D, and 3 for 3D)
    :param spline_order: forwarded to compute_warped_image_multiNC
    :param zero_boundary: forwarded to compute_warped_image_multiNC
    :param identity_map: optional map to sample the image at; if None an
        identity map for the desired size and new spacing is created
    :return: returns a tuple: the resampled image, the new spacing after resampling
    """
    target_spatial_size = desiredSize[2:]
    input_size = np.array(list(I.size()))

    # batch size and number of channels are taken over from the input image
    target_size_nc = np.array([input_size[0], input_size[1]] +
                              list(target_spatial_size))

    # the spacing scales with the ratio of grid intervals (n - 1) per dimension
    input_intervals = input_size[2::].astype('float') - 1.
    target_intervals = target_size_nc[2::].astype('float') - 1.
    newspacing = spacing * (input_intervals / target_intervals)

    if identity_map is None:
        sampling_map = AdaptVal(
            torch.from_numpy(
                py_utils.identity_map_multiN(target_size_nc, newspacing)))
    else:
        sampling_map = identity_map

    # now use this map for resampling
    resampled = py_utils.compute_warped_image_multiNC(
        I, sampling_map, newspacing, spline_order, zero_boundary)

    return resampled, newspacing
Esempio n. 12
0
def compare_det_of_jac_from_map(map,
                                gt_map,
                                label_image,
                                visualize=False,
                                print_output_directory=None,
                                clean_publication_directory=None,
                                pair_nr=None):
    """
    Compare the determinant of the Jacobian of an estimated map with that of a ground truth map.

    :param map: estimated map (numpy array; spatial size is taken from shape[2:])
    :param gt_map: ground truth map (numpy array)
    :param label_image: label image forwarded to compute_image_stats
    :param visualize: if True the determinants and their difference are plotted
    :param print_output_directory: if set (and no clean publication directory
        is given) the three-panel figure is saved there instead of being shown
    :param clean_publication_directory: if set, three separate axis-free
        figures are saved there instead of the three-panel figure
    :param pair_nr: integer used in the output file names
    :return: result of compute_image_stats for (det_est - det_gt)
    """
    spatial_size = np.array(map.shape[2:])
    # synthetic spacing
    spacing = np.array(1. / (spatial_size - 1))

    map_t = AdaptVal(torch.from_numpy(map).float())
    gt_map_t = AdaptVal(torch.from_numpy(gt_map).float())

    det_est = eu.compute_determinant_of_jacobian(map_t, spacing)
    det_gt = eu.compute_determinant_of_jacobian(gt_map_t, spacing)

    det_diff = det_est - det_gt

    if visualize:
        if clean_publication_directory is None:
            # one figure with three panels side by side
            plt.clf()
            panels = (
                (131, det_gt, 'det_gt'),
                (132, det_est, 'det_est'),
                (133, det_diff, 'det_est - det_gt'),
            )
            for position, panel_image, panel_title in panels:
                plt.subplot(position)
                plt.imshow(panel_image)
                plt.colorbar()
                plt.title(panel_title)

            if print_output_directory is None:
                plt.show()
            else:
                figure_name = '{:0>3d}'.format(pair_nr) + '_det_jac_validation.pdf'
                plt.savefig(os.path.join(print_output_directory, figure_name))
        else:
            # one tightly cropped, axis-free figure per quantity
            figures = (
                (det_gt, 'det_gt_{:0>3d}'),
                (det_est, 'det_est_{:0>3d}'),
                (det_diff, 'det_est_m_det_gt_{:0>3d}'),
            )
            for figure_image, name_template in figures:
                plt.clf()
                plt.imshow(figure_image)
                plt.colorbar()
                plt.axis('image')
                plt.axis('off')
                figure_name = name_template.format(pair_nr) + '_det_jac_validation.pdf'
                plt.savefig(os.path.join(clean_publication_directory, figure_name),
                            bbox_inches='tight',
                            pad_inches=0)

    return compute_image_stats(det_diff, label_image)
Esempio n. 13
0
def build_atlas(images, nr_of_cycles, warped_images, temp_folder, visualize):
    """
    Build an average (atlas) image by repeatedly registering all images to the current average.

    Starting from the plain average of all images, every cycle registers each
    image to the current average and replaces the average by the mean of the
    warped images. Model parameters are kept between cycles and used to
    initialize the next registration of the same image.

    :param images: list of image filenames
    :param nr_of_cycles: number of register-and-average cycles
    :param warped_images: directory the warped images of the last cycle are written to
    :param temp_folder: not used in this function
    :param visualize: if True the initial and every intermediate average are plotted
    :return: the average image after the last cycle
    """
    si = SI.RegisterImagePair()
    im_io = FIO.ImageIO()

    # compute first average image
    Iavg, sp = compute_average_image(images)
    Iavg = Iavg.data

    if visualize:
        plt.imshow(AdaptVal(Iavg[0, 0, ...]).detach().cpu().numpy(),
                   cmap='gray')
        plt.title('Initial average based on ' + str(len(images)) + ' images')
        plt.colorbar()
        plt.show()

    # initialize list to save model parameters in between cycles
    mp = []

    # register all images to the average image and while doing so compute a new average image
    for c in range(nr_of_cycles):
        print('Starting cycle ' + str(c + 1) + '/' + str(nr_of_cycles))
        for i, im_name in enumerate(images):
            print('Registering image ' + str(i) + '/' + str(len(images)))
            Ic, hdrc, spacing, _ = im_io.read_to_nc_format(filename=im_name)

            # set former model parameters if available
            if c != 0:
                si.set_model_parameters(mp[i])

            # register current image to average image
            si.register_images(Ic,
                               AdaptVal(Iavg).detach().cpu().numpy(),
                               spacing,
                               model_name='svf_scalar_momentum_map',
                               map_low_res_factor=0.5,
                               nr_of_iterations=5,
                               visualize_step=None,
                               similarity_measure_sigma=0.5)
            wi = si.get_warped_image()

            # save current model parameters for the next cycle
            # (first cycle: grow the list; later cycles except the last: overwrite)
            if c == 0:
                mp.append(si.get_model_parameters())
            elif c != nr_of_cycles - 1:
                mp[i] = si.get_model_parameters()

            if c == nr_of_cycles - 1:  # last time this is run, so let's save the image
                current_filename = warped_images + '/atlas_reg_Image' + str(
                    i + 1).zfill(4) + '.nrrd'
                print("writing image " + str(i + 1))
                im_io.write(current_filename, wi, hdrc)

            # accumulate the warped images of this cycle
            # NOTE(review): newAvg aliases the data of the first warped image, so the
            # in-place += below also modifies that tensor - confirm this is harmless
            if i == 0:
                newAvg = wi.data
            else:
                newAvg += wi.data

        # the mean of the warped images becomes the target of the next cycle
        Iavg = newAvg / len(images)

        if visualize:
            plt.imshow(AdaptVal(Iavg[0, 0, ...]).detach().cpu().numpy(),
                       cmap='gray')
            plt.title('Average ' + str(c + 1) + '/' + str(nr_of_cycles))
            plt.colorbar()
            plt.show()
    return Iavg
Esempio n. 14
0
def warp_data(data_list):
    """
    Wrap every element of ``data_list`` with AdaptVal.

    :param data_list: iterable of items accepted by AdaptVal
    :return: list with the wrapped items, in input order
    """
    wrapped = []
    for item in data_list:
        wrapped.append(AdaptVal(item))
    return wrapped
Esempio n. 15
0
import mermaid.smoother_factory as SF

import matplotlib.pyplot as plt

# Example script: smooth a synthetic image with a Gaussian Fourier smoother
# and print the gradient that back-propagation leaves on its standard deviation.
params = MP.ParameterDict()

dim = 2

# requested image size: 64 entries along each of the `dim` dimensions
szEx = np.tile(64, dim)
I0, I1, spacing = eg.CreateSquares(dim).create_image_pair(
    szEx, params)  # create a default image size with two sample squares
sz = np.array(I0.shape)

# create the source image as a pyTorch variable (the target I1 is not used here)
ISource = AdaptVal(torch.from_numpy(I0.copy()))

# the smoother is built for the spatial size sz[2:] only
smoother = SF.MySingleGaussianFourierSmoother(sz[2:], spacing, params)

# presumably a tensor that requires grad, since its .grad is printed below - TODO confirm
g_std = smoother.get_gaussian_std()

ISmooth = smoother.smooth_scalar_field(ISource)
# back-propagate a gradient of ones from the smoothed image
ISmooth.backward(torch.ones_like(ISmooth))
#ISmooth.backward(torch.zeros_like(ISmooth))

print('g_std.grad')
print(g_std.grad)

# first panel of a 1x2 figure: the unsmoothed source image
plt.subplot(121)
plt.imshow(utils.t2np(ISource[0, 0, ...]))
    # NOTE(review): the statement that opens this indented block is not visible here.
    # Register the settings category for the square example images (value, comment).
    params['square_example_images'] = ({},
                                       'Settings for example image generation')
    # presumably the side lengths of the small and the large square - TODO confirm
    params['square_example_images']['len_s'] = int(szEx.min() // 6)
    params['square_example_images']['len_l'] = int(szEx.max() // 4)

    # create a default image size with two sample squares
    I0, I1, spacing = eg.CreateSquares(ds.dim).create_image_pair(szEx, params)

# full size of the image array, including the two leading non-spatial axes
sz = np.array(I0.shape)

# the image must have ds.dim spatial axes plus two leading axes
# (presumably batch and channel - TODO confirm)
assert (len(sz) == ds.dim + 2)

print('Spacing = ' + str(spacing))

# create the source and target image as pyTorch variables
# (only the source array is copied before wrapping)
ISource = AdaptVal(torch.from_numpy(I0.copy()))
ITarget = AdaptVal(torch.from_numpy(I1))

# if desired we smooth them a little bit
if ds.smooth_images:
    # smooth both a little bit, using the smoothing settings of the algorithm configuration
    params['image_smoothing'] = ds.par_algconf['image_smoothing']
    cparams = params['image_smoothing']
    # the smoother is created for the spatial size sz[2:] only
    s = SF.SmootherFactory(sz[2::], spacing).create_smoother(cparams)
    ISource = s.smooth(ISource)
    ITarget = s.smooth(ITarget)

##############################3
# Setting up the optimizer
# ^^^^^^^^^^^^^^^^^^^^^^^^
#
Esempio n. 17
0
def do_registration(I0_name,
                    I1_name,
                    visualize,
                    visualize_step,
                    use_multi_scale,
                    normalize_spacing,
                    normalize_intensities,
                    squeeze_image,
                    par_algconf):
    """
    Register two images with the multi-scale registration optimizer.

    :param I0_name: filename of the source image
    :param I1_name: filename of the target image
    :param visualize: forwarded to the optimizer's set_visualization
    :param visualize_step: forwarded to the optimizer's set_visualize_step
    :param use_multi_scale: if True the scale factors and iterations are read
        from the 'multi_scale' optimizer settings, otherwise a single scale
        (factor 1.0) with the 'single_scale' number of iterations is used
    :param normalize_spacing: forwarded to read_images
    :param normalize_intensities: forwarded to read_images
    :param squeeze_image: forwarded to read_images
    :param par_algconf: algorithm configuration; its 'algconf' entry provides
        the 'image_smoothing', 'model' and 'optimizer' settings
    :return: tuple (warped_image, optimized_map, optimized_reg_parameters,
        optimized_energy, params, md_I0)
    """
    from mermaid.data_wrapper import AdaptVal
    import mermaid.smoother_factory as SF
    import mermaid.multiscale_optimizer as MO
    from mermaid.config_parser import nr_of_threads

    params = pars.ParameterDict()

    smoothing_conf = par_algconf['algconf']['image_smoothing']
    model_conf = par_algconf['algconf']['model']
    optimizer_conf = par_algconf['algconf']['optimizer']

    use_map = model_conf['deformation']['use_map']
    map_low_res_factor = model_conf['deformation']['map_low_res_factor']
    model_name = model_conf['deformation']['name']

    # the model name carries a suffix for map-based vs. image-based models
    model_name = model_name + ('_map' if use_map else '_image')

    # general parameters
    params['model']['registration_model'] = par_algconf['algconf']['model']['registration_model']

    torch.set_num_threads(nr_of_threads)
    print('Number of pytorch threads set to: ' + str(torch.get_num_threads()))

    I0, I1, spacing, md_I0, _ = read_images(I0_name, I1_name,
                                            normalize_spacing,
                                            normalize_intensities,
                                            squeeze_image)
    image_size = I0.shape

    # create the source and target image as pyTorch variables
    ISource = AdaptVal(torch.from_numpy(I0.copy()))
    ITarget = AdaptVal(torch.from_numpy(I1))

    if smoothing_conf['smooth_images']:
        # smooth both a little bit
        params['image_smoothing'] = par_algconf['algconf']['image_smoothing']
        smoother_conf = params['image_smoothing']
        smoother = SF.SmootherFactory(image_size[2::],
                                      spacing).create_smoother(smoother_conf)
        ISource = smoother.smooth_scalar_field(ISource)
        ITarget = smoother.smooth_scalar_field(ITarget)

    if use_multi_scale:
        scale_factors = optimizer_conf['multi_scale']['scale_factors']
        iterations_per_scale = optimizer_conf['multi_scale']['scale_iterations']
    else:
        # express the single-scale solution as multi-scale settings with one scale
        scale_factors = [1.0]
        iterations_per_scale = [optimizer_conf['single_scale']['nr_of_iterations']]

    mo = MO.MultiScaleRegistrationOptimizer(image_size, spacing, use_map,
                                            map_low_res_factor, params)

    mo.set_optimizer_by_name(optimizer_conf['name'])
    mo.set_visualization(visualize)
    mo.set_visualize_step(visualize_step)

    mo.set_model(model_name)

    mo.set_source_image(ISource)
    mo.set_target_image(ITarget)

    mo.set_scale_factors(scale_factors)
    mo.set_number_of_iterations_per_scale(iterations_per_scale)

    # and now do the optimization
    mo.optimize()

    # collect the results
    optimized_energy = mo.get_energy()
    warped_image = mo.get_warped_image()
    optimized_map = mo.get_map()
    optimized_reg_parameters = mo.get_model_parameters()

    return (warped_image, optimized_map, optimized_reg_parameters,
            optimized_energy, params, md_I0)