def read_image_and_map_and_apply_map(image_filename, map_filename):
    """
    Reads an image and a map and applies the map to the image.

    :param image_filename: input image filename
    :param map_filename: input map filename
    :return: the warped image and its image header as a tuple (im, hdr);
        (None, None) if the map or the image could not be read
    """
    map_np, _ = fileio.MapIO().read(map_filename)
    im, hdr, _, _ = fileio.ImageIO().read_to_map_compatible_format(image_filename, map_np)

    if im is None or map_np is None:
        print('Could not read map or image')
        return None, None

    # Only read the header once we know the reads succeeded (previously the
    # spacing was looked up before this check).
    spacing = hdr['spacing']
    # TODO: check that the spacing is compatible with the map

    # make pytorch arrays for subsequent processing
    im_t = AdaptVal(torch.from_numpy(im))
    map_t = AdaptVal(torch.from_numpy(map_np))
    im_warped = utils.t2np(utils.compute_warped_image_multiNC(im_t, map_t, spacing))
    return im_warped, hdr
def createImage(self, ex_len=64):
    """
    Create a synthetic 2D source/target image pair (two squares) and smooth both.

    Sets ``self.spacing``, ``self.sz``, ``self.ISource`` and ``self.ITarget`` and
    registers default image-smoothing settings in ``self.params``.

    :param ex_len: size of each spatial dimension of the generated images
    """
    nr_of_dims = 2
    image_size = np.tile(ex_len, nr_of_dims)  # (ex_len, ex_len)

    # default image pair with two sample squares
    source_np, target_np, self.spacing = eg.CreateSquares(nr_of_dims).create_image_pair(
        image_size, self.params)
    self.sz = np.array(source_np.shape)

    # source (copied) and target as pytorch tensors
    self.ISource = AdaptVal(torch.from_numpy(source_np.copy()))
    self.ITarget = AdaptVal(torch.from_numpy(target_np))

    # register the smoothing settings via the (key, default, comment) tuple form
    # (presumably fills in defaults if missing -- see ParameterDict)
    self.params[('image_smoothing', {}, 'image smoothing settings')]
    self.params['image_smoothing'][(
        'smooth_images', True, '[True|False]; smoothes the images before registration')]
    self.params['image_smoothing'][('smoother', {}, 'settings for the image smoothing')]
    self.params['image_smoothing']['smoother'][(
        'gaussian_std', 0.05, 'how much smoothing is done')]
    self.params['image_smoothing']['smoother'][(
        'type', 'gaussian', "['gaussianSpatial'|'gaussian'|'diffusion']")]

    # smooth both images a little bit with the configured smoother
    smoothing_params = self.params['image_smoothing']
    smoother = SF.SmootherFactory(self.sz[2::], self.spacing).create_smoother(smoothing_params)
    self.ISource, self.ITarget = smoother.smooth(self.ISource), smoother.smooth(self.ITarget)
def invert_map(map, spacing):
    """
    Inverts the map and returns its inverse. Assumes standard map parameterization [-1,1]^d

    The inverse is found numerically. Starting from the identity map, a map ``invmap``
    is optimized (CO.LBFGS_LS with a backtracking line search) so that warping
    ``map`` with ``invmap`` reproduces the identity map in the least-squares sense.

    Optimization stops after ``nr_of_iterations`` steps, or as soon as a step no
    longer decreases the loss. In the latter case the last improving iterate is
    restored.

    :param map: Input map to be inverted (numpy array)
    :param spacing: spatial spacing, used for the identity map and for warping
    :return: inverted map (numpy array)
    """
    # make pytorch arrays for subsequent processing
    map_t = AdaptVal(torch.from_numpy(map))

    # identity map
    id_np = utils.identity_map_multiN(map_t.data.shape, spacing)
    id_t = AdaptVal(torch.from_numpy(id_np))

    # Parameter to store the inverse map. Wrap AFTER AdaptVal: if AdaptVal moves
    # data to another device, doing that to a Parameter yields a non-leaf tensor,
    # which optimizers refuse to optimize.
    invmap_t = Parameter(AdaptVal(torch.from_numpy(id_np.copy())))

    # some optimizer settings, probably too strict
    nr_of_iterations = 1000
    rel_ftol = 1e-6

    optimizer = CO.LBFGS_LS([invmap_t], lr=1.0, max_iter=1, tolerance_grad=rel_ftol * 10,
                            tolerance_change=rel_ftol, max_eval=5, history_size=5,
                            line_search_fn='backtracking')

    def compute_loss():
        # warps map_t with invmap_t; if invmap_t is the inverse this gives the identity map
        wmap = utils.compute_warped_image_multiNC(map_t, invmap_t, spacing)
        return ((wmap - id_t) ** 2).sum()

    def _closure():
        optimizer.zero_grad()
        loss = compute_loss()
        loss.backward()
        return loss

    last_loss = utils.t2np(compute_loss())

    for it in range(nr_of_iterations):
        previous_invmap = invmap_t.data.clone()
        optimizer.step(_closure)
        current_loss = utils.t2np(compute_loss())
        print('Iter = ' + str(it) + '; E = ' + str(current_loss))
        if current_loss >= last_loss:
            # this step did not improve: revert to the best map found so far
            invmap_t.data.copy_(previous_invmap)
            break
        last_loss = current_loss

    return utils.t2np(invmap_t)
def compute_average_image(images):
    """
    Compute the voxel-wise average of a list of images.

    :param images: non-empty sequence of image filenames, read with FIO.ImageIO().read_to_nc_format
    :return: tuple (average image tensor, spacing); the spacing is the one of the last image read
    :raises ValueError: if ``images`` is empty
    """
    if len(images) == 0:
        raise ValueError('compute_average_image requires at least one image filename')

    im_io = FIO.ImageIO()
    Iavg = None
    spacing = None
    for im_name in images:
        Ic, _, spacing, _ = im_io.read_to_nc_format(filename=im_name)
        Ic_t = AdaptVal(torch.from_numpy(Ic))
        Iavg = Ic_t if Iavg is None else Iavg + Ic_t

    return Iavg / len(images), spacing
def add_texture_on_img(im_orig, texture_gaussian_smoothness=0.1, texture_magnitude=0.3):
    """
    Add smoothed random texture to an image, separately for each integer intensity level.

    For every level (ascending), fresh uniform noise is Gaussian-smoothed, scaled so its
    maximum equals ``texture_magnitude``, and added to all pixels with value
    ``>= level - 0.5``. Higher levels overwrite lower ones.

    :param im_orig: numpy image; the noise of shape ``im_orig.shape[2:]`` is reshaped to
        ``im_orig.shape``, so the leading two dims presumably must be 1 (TODO confirm)
    :param texture_gaussian_smoothness: gaussian_std of the smoother applied to the noise
    :param texture_magnitude: maximum value of the added texture
    :return: textured image as a torch.Tensor
    """
    sz = im_orig.shape

    # the smoother and its spacing do not depend on the level: build them once
    r_params = pars.ParameterDict()
    r_params['smoother']['type'] = 'gaussian'
    r_params['smoother']['gaussian_std'] = texture_gaussian_smoothness
    spacing = 1.0 / (np.array(sz[2:]).astype('float32') - 1)
    s_r = sf.SmootherFactory(sz[2::], spacing).create_smoother(r_params)

    # do this separately for each integer intensity level
    levels = np.unique(np.floor(im_orig).astype('int'))
    im = np.zeros_like(im_orig)
    for current_level in levels:
        rand_noise = np.random.random(sz[2:]).astype('float32') - 0.5
        rand_noise = rand_noise.reshape(sz)

        rand_noise_smoothed = s_r.smooth(AdaptVal(
            torch.from_numpy(rand_noise))).detach().cpu().numpy()
        rand_noise_smoothed /= rand_noise_smoothed.max()
        rand_noise_smoothed *= texture_magnitude

        c_indx = (im_orig >= current_level - 0.5)
        im[c_indx] = im_orig[c_indx] + rand_noise_smoothed[c_indx]

    return torch.Tensor(im)
def get_labeled_image(img_pair):
    """
    Load a (moving, target) pair of saved tensors and relabel their values as 0..K-1.

    The K distinct values of the moving tensor are mapped, in ascending order, to the
    labels 0..K-1. The same mapping is applied to the target. Target values that do
    not occur in the moving tensor keep the label 0 from the zero initialization.

    :param img_pair: tuple (moving_path, target_path) of files readable by torch.load
    :return: tuple (labeled moving, labeled target), each passed through AdaptVal
    :raises ValueError: if the moving tensor contains values absent from the target
    """
    s_path, t_path = img_pair
    moving = torch.load(s_path)
    target = torch.load(t_path)

    # np.unique already returns the values sorted
    ind_value_list = np.unique(moving.cpu().numpy())
    target_values = set(np.unique(target.cpu().numpy()))
    missing = set(ind_value_list) - target_values
    if missing:
        raise ValueError(
            'moving image contains values not present in target: {}'.format(sorted(missing)))

    lmoving = torch.zeros_like(moving)
    ltarget = torch.zeros_like(target)
    for i, value in enumerate(ind_value_list):
        lmoving[moving == value] = i
        ltarget[target == value] = i
    return AdaptVal(lmoving), AdaptVal(ltarget)
def __init__(self):
    """
    Build a bank of three fixed 3x3x3 derivative-style kernels (dx, dy, dz).

    Each kernel is antisymmetric along one axis of the 3x3x3 stencil, with weights
    1/3/6. The kernels are stacked along dim 0 into ``self.spatial_filter``, of
    shape (3, 1, 3, 3, 3).
    """
    kernel_values = (
        # dx: negative first slab, positive last slab
        [[[-1., -3., -1.], [-3., -6., -3.], [-1., -3., -1.]],
         [[0., 0., 0.], [0., 0, 0.], [0., 0., 0.]],
         [[1., 3., 1.], [3., 6., 3.], [1., 3., 1.]]],
        # dy: positive first row, negative last row within each slab
        [[[1., 3., 1.], [0., 0., 0.], [-1., -3., -1.]],
         [[3., 6., 3.], [0., 0, 0.], [-3., -6., -3.]],
         [[1., 3., 1.], [0., 0., 0.], [-1., -3., -1.]]],
        # dz: negative first column, positive last column within each row
        [[[-1., 0., 1.], [-3., 0., 3.], [-1., 0., 1.]],
         [[-3., 0., 3.], [-6., 0, 6.], [-3., 0., 3.]],
         [[-1., 0., 1.], [-3., 0., 3.], [-1., 0., 1.]]],
    )
    kernels = [AdaptVal(torch.Tensor(values)).view(1, 1, 3, 3, 3) for values in kernel_values]
    # repeat(1, 1, 1, 1, 1) does not tile anything; it only returns a copy
    self.spatial_filter = torch.cat(kernels, 0).repeat(1, 1, 1, 1, 1)
def __get_smoothed_target(self, I0):
    """
    Return a smoothed, gradient-detached tensor copy of the numpy image ``I0``.

    Smoothing is done by ``self.fourier_smoother``.
    """
    target = AdaptVal(torch.from_numpy(I0.copy()))
    return self.fourier_smoother(target).detach()
def upsample_to_compatible_size_single_image(gt_weight, weight, interpolation_order=1):
    """
    Upsample ``weight`` to the shape of ``gt_weight`` if their shapes differ.

    :param gt_weight: reference numpy array whose shape is the target size
    :param weight: numpy array to upsample (no batch/channel dims)
    :param interpolation_order: spline order passed to the sampler
    :return: ``weight`` itself if the shapes already match; otherwise the upsampled
        numpy array as returned by IS.ResampleImage (the input is given two leading
        singleton dims before resampling)
    """
    if gt_weight.shape == weight.shape:
        return weight

    sampler = IS.ResampleImage()
    weight_sz = weight.shape
    weight_reshaped = AdaptVal(
        torch.from_numpy(weight.reshape([1, 1] + list(weight_sz))).float())
    # unit spacing in every spatial dimension (previously hard-coded to 2D)
    spacing = np.ones(len(weight_sz))
    desired_size = gt_weight.shape
    weight_upsampled_t, _ = sampler.upsample_image_to_size(
        weight_reshaped, spacing, desired_size, interpolation_order)
    return weight_upsampled_t.detach().cpu().numpy()
def downsample_to_compatible_size_single_image(gt_weight, weight, interpolation_order=3):
    """
    Downsample the ground-truth ``gt_weight`` to the shape of ``weight`` if they differ.

    :param gt_weight: numpy array to downsample (no batch/channel dims)
    :param weight: reference numpy array whose shape is the target size
    :param interpolation_order: spline order passed to the sampler
    :return: ``gt_weight`` itself if the shapes already match; otherwise the
        downsampled numpy array as returned by IS.ResampleImage (the input is given
        two leading singleton dims before resampling)
    """
    if gt_weight.shape == weight.shape:
        return gt_weight

    sampler = IS.ResampleImage()
    gt_weight_sz = gt_weight.shape
    gt_weight_reshaped = AdaptVal(
        torch.from_numpy(gt_weight.reshape([1, 1] + list(gt_weight_sz))).float())
    # unit spacing in every spatial dimension (previously hard-coded to 2D)
    spacing = np.ones(len(gt_weight_sz))
    desired_size = weight.shape
    gt_weight_downsampled_t, _ = sampler.downsample_image_to_size(
        gt_weight_reshaped, spacing, desired_size, interpolation_order)
    return gt_weight_downsampled_t.detach().cpu().numpy()
def resample_image(I, spacing, desiredSize, spline_order=1, zero_boundary=False, identity_map=None):
    """
    Resample an image to a given desired size.

    :param I: Input image (expected to be of BxCxXxYxZ format, i.e. 2 + spatial dims)
    :param spacing: array describing the spatial spacing of I
    :param desiredSize: desired size. Either the full size including B and C
        (``len(desiredSize) == I.dim()``, the first two entries are ignored), or only
        the spatial size (1 entry for 1D, 2 for 2D, 3 for 3D)
    :param spline_order: spline order passed to compute_warped_image_multiNC
    :param zero_boundary: passed to compute_warped_image_multiNC
    :param identity_map: optional precomputed identity map for the desired size;
        computed from the new spacing if None
    :return: returns a tuple: the resampled image, the new spacing after resampling
    :raises ValueError: if desiredSize has neither the full nor the spatial length
    """
    sz = np.array(list(I.size()))
    nr_of_spatial_dims = len(sz) - 2

    desiredSize = list(desiredSize)
    if len(desiredSize) == len(sz):
        # full BxCx... size given: batch and channel entries are taken from I
        desiredSize = desiredSize[2:]
    elif len(desiredSize) != nr_of_spatial_dims:
        raise ValueError(
            'desiredSize must have {} (full) or {} (spatial) entries, got {}'.format(
                len(sz), nr_of_spatial_dims, len(desiredSize)))

    # batch size and number of channels are kept from the input
    nrOfI = sz[0]
    nrOfC = sz[1]
    desiredSizeNC = np.array([nrOfI, nrOfC] + desiredSize)

    newspacing = spacing * ((sz[2::].astype('float') - 1.) /
                            (desiredSizeNC[2::].astype('float') - 1.))

    if identity_map is not None:
        idDes = identity_map
    else:
        idDes = AdaptVal(
            torch.from_numpy(py_utils.identity_map_multiN(desiredSizeNC, newspacing)))

    # now use this map for resampling
    ID = py_utils.compute_warped_image_multiNC(I, idDes, newspacing, spline_order, zero_boundary)

    return ID, newspacing
def _save_clean_det_jac_figure(image, filename):
    """Save ``image`` with a colorbar and no axes to ``filename`` (tight bounding box)."""
    plt.clf()
    plt.imshow(image)
    plt.colorbar()
    plt.axis('image')
    plt.axis('off')
    plt.savefig(filename, bbox_inches='tight', pad_inches=0)


def compare_det_of_jac_from_map(map, gt_map, label_image, visualize=False,
                                print_output_directory=None,
                                clean_publication_directory=None, pair_nr=None):
    """
    Compare the determinant of the Jacobian of an estimated map with a ground-truth map.

    Both determinants are computed with a synthetic spacing of 1/(sz-1), where sz is
    ``map.shape[2:]``.

    :param map: estimated map (numpy array; spatial dims start at index 2)
    :param gt_map: ground-truth map with the same layout as ``map``
    :param label_image: label image passed to compute_image_stats
    :param visualize: if True, plot det_gt, det_est and det_est - det_gt
    :param print_output_directory: if given (and no clean directory), save the
        overview figure there instead of showing it
    :param clean_publication_directory: if given, save the three images as separate
        axis-free PDFs there
    :param pair_nr: number used in output filenames; required when saving figures
    :return: result of compute_image_stats(det_est - det_gt, label_image)
    :raises ValueError: if figures are to be saved but pair_nr is None
    """
    sz = np.array(map.shape[2:])
    # synthetic spacing
    spacing = np.array(1. / (sz - 1))

    map_torch = AdaptVal(torch.from_numpy(map).float())
    gt_map_torch = AdaptVal(torch.from_numpy(gt_map).float())

    det_est = eu.compute_determinant_of_jacobian(map_torch, spacing)
    det_gt = eu.compute_determinant_of_jacobian(gt_map_torch, spacing)
    det_diff = det_est - det_gt

    if visualize:
        saves_to_disk = (print_output_directory is not None
                         or clean_publication_directory is not None)
        if saves_to_disk and pair_nr is None:
            # '{:0>3d}'.format(None) would otherwise fail with an obscure TypeError
            raise ValueError('pair_nr must be specified when saving figures to a directory')

        if clean_publication_directory is None:
            plt.clf()

            plt.subplot(131)
            plt.imshow(det_gt)
            plt.colorbar()
            plt.title('det_gt')

            plt.subplot(132)
            plt.imshow(det_est)
            plt.colorbar()
            plt.title('det_est')

            plt.subplot(133)
            plt.imshow(det_diff)
            plt.colorbar()
            plt.title('det_est - det_gt')

            if print_output_directory is None:
                plt.show()
            else:
                plt.savefig(os.path.join(
                    print_output_directory,
                    '{:0>3d}'.format(pair_nr) + '_det_jac_validation.pdf'))
        else:
            pair_str = '{:0>3d}'.format(pair_nr)
            for prefix, image in (('det_gt_', det_gt),
                                  ('det_est_', det_est),
                                  ('det_est_m_det_gt_', det_diff)):
                _save_clean_det_jac_figure(image, os.path.join(
                    clean_publication_directory,
                    prefix + pair_str + '_det_jac_validation.pdf'))

    ds = compute_image_stats(det_diff, label_image)
    return ds
def build_atlas(images, nr_of_cycles, warped_images, temp_folder, visualize):
    """
    Iteratively build an atlas (average image) by registering every image to the current average.

    In each cycle every image is registered to the current average, using the model
    parameters from the previous cycle as the starting point. The new average is the
    mean of the warped images. In the last cycle each warped image is written to
    ``warped_images + '/atlas_reg_ImageNNNN.nrrd'``.

    :param images: list of image filenames
    :param nr_of_cycles: number of register-then-average cycles
    :param warped_images: output directory for the warped images of the last cycle
    :param temp_folder: currently unused
    :param visualize: if True, show the average after initialization and after each cycle
    :return: the final average image (tensor)
    """
    si = SI.RegisterImagePair()
    im_io = FIO.ImageIO()

    # compute first average image
    Iavg, sp = compute_average_image(images)
    Iavg = Iavg.data

    if visualize:
        plt.imshow(AdaptVal(Iavg[0, 0, ...]).detach().cpu().numpy(), cmap='gray')
        plt.title('Initial average based on ' + str(len(images)) + ' images')
        plt.colorbar()
        plt.show()

    # model parameters per image, carried over between cycles
    mp = []

    # register all images to the average image and while doing so compute a new average image
    for c in range(nr_of_cycles):
        print('Starting cycle ' + str(c + 1) + '/' + str(nr_of_cycles))
        is_last_cycle = c == nr_of_cycles - 1
        for i, im_name in enumerate(images):
            print('Registering image ' + str(i + 1) + '/' + str(len(images)))
            Ic, hdrc, spacing, _ = im_io.read_to_nc_format(filename=im_name)

            # set former model parameters if available
            if c != 0:
                si.set_model_parameters(mp[i])

            # register current image to average image
            si.register_images(Ic, AdaptVal(Iavg).detach().cpu().numpy(), spacing,
                               model_name='svf_scalar_momentum_map',
                               map_low_res_factor=0.5,
                               nr_of_iterations=5,
                               visualize_step=None,
                               similarity_measure_sigma=0.5)
            wi = si.get_warped_image()

            # save current model parameters for the next cycle
            if c == 0:
                mp.append(si.get_model_parameters())
            elif not is_last_cycle:
                mp[i] = si.get_model_parameters()

            if is_last_cycle:
                # last time this is run, so let's save the image
                current_filename = warped_images + '/atlas_reg_Image' + str(
                    i + 1).zfill(4) + '.nrrd'
                print("writing image " + str(i + 1))
                im_io.write(current_filename, wi, hdrc)

            if i == 0:
                # clone: the in-place += below must not modify the first warped image
                newAvg = wi.data.clone()
            else:
                newAvg += wi.data

        Iavg = newAvg / len(images)

        if visualize:
            plt.imshow(AdaptVal(Iavg[0, 0, ...]).detach().cpu().numpy(), cmap='gray')
            plt.title('Average ' + str(c + 1) + '/' + str(nr_of_cycles))
            plt.colorbar()
            plt.show()

    return Iavg
def warp_data(data_list):
    """Return a new list holding ``AdaptVal(item)`` for every item of ``data_list``."""
    return list(map(AdaptVal, data_list))
# Script: smooth a synthetic image with a single-Gaussian Fourier smoother,
# backpropagate through the smoothing and print the gradient of the smoother's
# Gaussian std.
# NOTE(review): MP, np, eg, AdaptVal, torch and utils are not imported in this
# snippet and are expected to be imported elsewhere.
import mermaid.smoother_factory as SF
import matplotlib.pyplot as plt

params = MP.ParameterDict()

dim = 2
szEx = np.tile(64, dim)
I0, I1, spacing = eg.CreateSquares(dim).create_image_pair(
    szEx, params)  # create a default image size with two sample squares
sz = np.array(I0.shape)

# create the source image as a pyTorch variable (I1 is not used here)
ISource = AdaptVal(torch.from_numpy(I0.copy()))

smoother = SF.MySingleGaussianFourierSmoother(sz[2:], spacing, params)

g_std = smoother.get_gaussian_std()

ISmooth = smoother.smooth_scalar_field(ISource)
# backpropagate an all-ones output gradient through the smoothing
ISmooth.backward(torch.ones_like(ISmooth))
#ISmooth.backward(torch.zeros_like(ISmooth))

print('g_std.grad')
print(g_std.grad)

plt.subplot(121)
plt.imshow(utils.t2np(ISource[0, 0, ...]))
# NOTE(review): no plt.show() here; the plotting presumably continues after this chunk
# Settings for the generated example squares; szEx and ds are defined earlier in the script.
# len_s / len_l are presumably the side lengths of the small and large square.
params['square_example_images'] = ({}, 'Settings for example image generation')
params['square_example_images']['len_s'] = int(szEx.min() // 6)
params['square_example_images']['len_l'] = int(szEx.max() // 4)

# create a default image size with two sample squares
I0, I1, spacing = eg.CreateSquares(ds.dim).create_image_pair(szEx, params)
sz = np.array(I0.shape)

# images come with leading batch and channel dimensions
assert (len(sz) == ds.dim + 2)

print('Spacing = ' + str(spacing))

# create the source and target image as pyTorch variables
ISource = AdaptVal(torch.from_numpy(I0.copy()))
ITarget = AdaptVal(torch.from_numpy(I1))

# if desired we smooth them a little bit
if ds.smooth_images:
    # smooth both a little bit
    params['image_smoothing'] = ds.par_algconf['image_smoothing']
    cparams = params['image_smoothing']
    s = SF.SmootherFactory(sz[2::], spacing).create_smoother(cparams)
    ISource = s.smooth(ISource)
    ITarget = s.smooth(ITarget)

##############################
# Setting up the optimizer
# ^^^^^^^^^^^^^^^^^^^^^^^^
#
def do_registration( I0_name, I1_name, visualize, visualize_step, use_multi_scale, normalize_spacing, normalize_intensities, squeeze_image, par_algconf ):
    """
    Register image I0 (source) to image I1 (target) with a multi-scale optimizer.

    :param I0_name: source image, passed to read_images
    :param I1_name: target image, passed to read_images
    :param visualize: passed to the optimizer's set_visualization
    :param visualize_step: passed to the optimizer's set_visualize_step
    :param use_multi_scale: if False, a single scale (factor 1.0) with
        ``optimizer.single_scale.nr_of_iterations`` iterations is used
    :param normalize_spacing: passed to read_images
    :param normalize_intensities: passed to read_images
    :param squeeze_image: passed to read_images
    :param par_algconf: settings; its ``'algconf'`` entry provides the
        ``image_smoothing``, ``model`` and ``optimizer`` settings
    :return: tuple (warped_image, optimized_map, optimized_reg_parameters,
        optimized_energy, params, md_I0)
    """
    from mermaid.data_wrapper import AdaptVal
    import mermaid.smoother_factory as SF
    import mermaid.multiscale_optimizer as MO
    from mermaid.config_parser import nr_of_threads

    params = pars.ParameterDict()

    par_image_smoothing = par_algconf['algconf']['image_smoothing']
    par_model = par_algconf['algconf']['model']
    par_optimizer = par_algconf['algconf']['optimizer']

    use_map = par_model['deformation']['use_map']
    map_low_res_factor = par_model['deformation']['map_low_res_factor']
    # the model name encodes whether a map or the image itself is deformed
    model_name = par_model['deformation']['name'] + ('_map' if use_map else '_image')

    # general parameters
    params['model']['registration_model'] = par_algconf['algconf']['model']['registration_model']

    torch.set_num_threads( nr_of_threads )
    print('Number of pytorch threads set to: ' + str(torch.get_num_threads()))

    I0, I1, spacing, md_I0, _md_I1 = read_images( I0_name, I1_name, normalize_spacing, normalize_intensities,squeeze_image )
    image_size = I0.shape

    # source (copied) and target as pytorch tensors
    ISource = AdaptVal(torch.from_numpy(I0.copy()))
    ITarget = AdaptVal(torch.from_numpy(I1))

    if par_image_smoothing['smooth_images']:
        # smooth both a little bit
        params['image_smoothing'] = par_algconf['algconf']['image_smoothing']
        smoothing_params = params['image_smoothing']
        smoother = SF.SmootherFactory(image_size[2::], spacing).create_smoother(smoothing_params)
        ISource = smoother.smooth_scalar_field(ISource)
        ITarget = smoother.smooth_scalar_field(ITarget)

    if use_multi_scale:
        multi_scale_scale_factors = par_optimizer['multi_scale']['scale_factors']
        multi_scale_iterations_per_scale = par_optimizer['multi_scale']['scale_iterations']
    else:
        # single-scale solution expressed as a one-level multi-scale schedule
        multi_scale_scale_factors = [1.0]
        multi_scale_iterations_per_scale = [par_optimizer['single_scale']['nr_of_iterations']]

    ms_optimizer = MO.MultiScaleRegistrationOptimizer(image_size, spacing, use_map, map_low_res_factor, params)

    ms_optimizer.set_optimizer_by_name(par_optimizer['name'])
    ms_optimizer.set_visualization(visualize)
    ms_optimizer.set_visualize_step(visualize_step)

    ms_optimizer.set_model(model_name)

    ms_optimizer.set_source_image(ISource)
    ms_optimizer.set_target_image(ITarget)

    ms_optimizer.set_scale_factors(multi_scale_scale_factors)
    ms_optimizer.set_number_of_iterations_per_scale(multi_scale_iterations_per_scale)

    # and now do the optimization
    ms_optimizer.optimize()

    optimized_energy = ms_optimizer.get_energy()
    warped_image = ms_optimizer.get_warped_image()
    optimized_map = ms_optimizer.get_map()
    optimized_reg_parameters = ms_optimizer.get_model_parameters()

    return warped_image, optimized_map, optimized_reg_parameters, optimized_energy, params, md_I0