def read_image_and_map_and_apply_map(image_filename, map_filename):
    """
    Reads an image and a map and applies the map to the image.

    :param image_filename: input image filename
    :param map_filename: input map filename
    :return: the warped image and its image header as a tuple (im, hdr);
        (None, None) if the map or the image could not be read
    """
    map_np, _map_hdr = fileio.MapIO().read(map_filename)
    if map_np is None:
        print('Could not read map or image')
        return None, None

    # the image is read in a format compatible with the map it will be warped with
    im, hdr, _, _ = fileio.ImageIO().read_to_map_compatible_format(image_filename, map_np)
    if im is None:
        print('Could not read map or image')
        return None, None

    # only access the header once we know the image was actually read
    spacing = hdr['spacing']
    # TODO: check that the spacing is compatible with the map

    # make pytorch arrays for subsequent processing
    im_t = AdaptVal(torch.from_numpy(im))
    map_t = AdaptVal(torch.from_numpy(map_np))
    im_warped = utils.t2np(utils.compute_warped_image_multiNC(im_t, map_t, spacing))
    return im_warped, hdr
def read_images(source_image_name, target_image_name, normalize_spacing=True,
                normalize_intensities=True, squeeze_image=True):
    """
    Reads a source and a target image in NC format and checks that their spacings agree.

    :param source_image_name: filename of the source image
    :param target_image_name: filename of the target image
    :param normalize_spacing: if True, return the normalized spacing of the source image,
        otherwise its original spacing
    :param normalize_intensities: passed as ``intensity_normalize`` to the reader
    :param squeeze_image: passed through to the reader
    :return: tuple (I0, I1, spacing, hdr0, hdr1)
    :raises ValueError: if source and target spacing differ
    """
    I0, hdr0, spacing0, normalized_spacing0 = fileio.ImageIO().read_to_nc_format(
        source_image_name, intensity_normalize=normalize_intensities, squeeze_image=squeeze_image)
    I1, hdr1, spacing1, _ = fileio.ImageIO().read_to_nc_format(
        target_image_name, intensity_normalize=normalize_intensities, squeeze_image=squeeze_image)

    # Explicit check instead of ``assert`` so that it is not stripped under ``python -O``;
    # np.array_equal also copes with spacings of different dimensionality.
    # TODO: do a better test for equality for the images here
    if not np.array_equal(spacing0, spacing1):
        raise ValueError('Source spacing {} and target spacing {} differ'.format(spacing0, spacing1))

    spacing = normalized_spacing0 if normalize_spacing else spacing0
    print('Spacing = ' + str(spacing))
    return I0, I1, spacing, hdr0, hdr1
def file_io_read_img(path, is_label, normalize_spacing=True, normalize_intensities=True, squeeze_image=True, adaptive_padding=4):
    """
    Read an image (or label map) and return it together with a small info dictionary.

    Intensity normalization is always switched off for label maps.

    :param path: image filename
    :param is_label: if True, the file is read without intensity normalization
    :param normalize_spacing: if True, report the normalized spacing instead of the original one
    :param normalize_intensities: intensity normalization flag used for non-label images
    :param squeeze_image: forwarded to the reader
    :param adaptive_padding: forwarded to the reader
    :return: tuple (im, info) where info has the keys 'spacing' and 'img_size' (= im.shape)
    """
    use_intensity_normalization = normalize_intensities if not is_label else False
    reader = fileio.ImageIO()
    im, _hdr, raw_spacing, normalized_spacing = reader.read(
        path, use_intensity_normalization, squeeze_image, adaptive_padding)
    chosen_spacing = normalized_spacing if normalize_spacing else raw_spacing
    return im, {'spacing': chosen_spacing, 'img_size': im.shape}
def compute_average_image(images):
    """
    Compute the voxel-wise average of a list of images.

    :param images: non-empty sequence of image filenames (read via ``read_to_nc_format``)
    :return: tuple (Iavg, spacing) with the average image as a tensor (wrapped by ``AdaptVal``)
        and the spacing of the last image read
    :raises ValueError: if ``images`` is empty
    """
    if len(images) == 0:
        raise ValueError('compute_average_image: need at least one image to average')

    im_io = FIO.ImageIO()
    Iavg = None
    spacing = None
    for im_name in images:
        Ic, _hdrc, spacing, _ = im_io.read_to_nc_format(filename=im_name)
        Ic_t = AdaptVal(torch.from_numpy(Ic))
        if Iavg is None:
            Iavg = Ic_t
        else:
            Iavg += Ic_t
    Iavg = Iavg / len(images)
    return Iavg, spacing
def normalize_image_intensity(source_filename, target_filename, target_cdf, target_cdf_bins,
                              remove_background=True, background_value=0):
    """
    Histogram-match an image to a target CDF and write the result to a new file.

    :param source_filename: image to be normalized (read without intensity or spacing normalization)
    :param target_filename: file the histogram-matched image is written to; must not be the source file
    :param target_cdf: target cumulative distribution, forwarded to ``histogram_match``
    :param target_cdf_bins: bins of the target CDF, forwarded to ``histogram_match``
    :param remove_background: forwarded to ``histogram_match``
    :param background_value: forwarded to ``histogram_match``
    :return: tuple (hist_matched_im, hdr) with the matched image and the source header
    :raises ValueError: if the target file exists and is the same file as the source
    """
    # refuse to overwrite the input in place
    if os.path.exists(target_filename) and os.path.samefile(source_filename, target_filename):
        raise ValueError('ERROR: Source file {} is the same as target file {}. Refusing conversion.'.format(source_filename, target_filename))

    reader = fio.ImageIO()
    im_orig, hdr, _spacing, _squeezed_spacing = reader.read(
        native_str(source_filename), intensity_normalize=False, normalize_spacing=False)

    print('Histogram matching: {}'.format(source_filename))
    print('Image: {:s}: min={:4.0f}; max={:4.0f}'.format(source_filename, im_orig.min(), im_orig.max()))

    hist_matched_im = histogram_match(
        im_orig,
        target_cdf=target_cdf,
        target_bins=target_cdf_bins,
        remove_background=remove_background,
        background_value=background_value,
    )

    print('Writing results to {}'.format(target_filename))
    writer = fio.ImageIO()
    writer.write(native_str(target_filename), hist_matched_im, hdr)
    return hist_matched_im, hdr
def __init__(self, reg_models):
    """
    Initialize with a sequence of registration models.

    :param reg_models: a list of tuples (model_name: string, model_setting: json_file_name (string)
        or a ParameterDict object); only the first two entries are used
    """
    self.si_ = SI.RegisterImagePair()
    first_model = reg_models[0]
    self.model_0_name, self.model_0_setting = first_model
    second_model = reg_models[1]
    self.model_1_name, self.model_1_setting = second_model
    self.im_io = FIO.ImageIO()
    # images, masks and registration results start out unset
    self.target_image_np = None
    self.moving_image_np = None
    self.target_mask = self.moving_mask = None
    self.Ab = None
    self.map = self.inverse_map = None
def compute_average_quantile_function(filenames, nr_of_bins, remove_background=True, background_value=0, save_results_to_pdf=False):
    """
    Compute the average quantile function over a set of images.

    :param filenames: non-empty iterable of image filenames (read without intensity/spacing normalization)
    :param nr_of_bins: number of quantiles, forwarded to ``compute_quantile_function``
    :param remove_background: if True, only voxels with a value > ``background_value`` are used
    :param background_value: background threshold used when ``remove_background`` is True
    :param save_results_to_pdf: if True, plot all quantile functions and their average to
        'quantile_averaging.pdf' in the current directory
    :return: tuple (avg_quant, perc) with the averaged quantile function and the percentiles
        returned by ``compute_quantile_function`` for the last image
    :raises ValueError: if ``filenames`` is empty
    """
    filenames = list(filenames)
    if len(filenames) == 0:
        raise ValueError('compute_average_quantile_function: need at least one file')

    im_io = fio.ImageIO()
    all_quants = []

    print('Computing the average quantile function (from the following files):')
    for f in filenames:
        im_orig, _hdr, _spacing, _squeezed_spacing = im_io.read(native_str(f), intensity_normalize=False, normalize_spacing=False)
        print('Image: {:s}: min={:4.0f}; max={:4.0f}'.format(f, im_orig.min(), im_orig.max()))
        if remove_background:
            im = im_orig[im_orig > background_value]
        else:
            im = im_orig
        imquant, perc = compute_quantile_function(im.flatten(), nr_of_quantiles=nr_of_bins)
        all_quants.append(imquant)

    # equally weighted average of all quantile functions
    avg_quant = np.zeros_like(all_quants[0])
    weight = 1.0 / len(all_quants)
    for cquant in all_quants:
        avg_quant += weight * cquant

    if save_results_to_pdf:
        plt.clf()
        for cquant in all_quants:
            plt.plot(perc, cquant)
        plt.plot(perc, avg_quant, color='k', linewidth=3.0)
        plt.xlabel('P')
        plt.ylabel('I')
        print('Saving: quantile_averaging.pdf')
        plt.savefig('quantile_averaging.pdf')

    return avg_quant, perc
similarityE=optimized_energy[1], regE=optimized_energy[2])) if write_map is not None: if optimized_map is not None: #om_data = optimized_map.data.numpy() #nrrd.write( write_map, om_data, md_I ) fileio.MapIO().write(write_map,optimized_map,md_I) else: print('Warning: Map cannot be written as it was not computed -- maybe you are using an image-based algorithm?') if write_warped_image is not None: if warped_image is not None: #wi_data = warped_image.data.numpy() #nrrd.write(write_warped_image, wi_data, md_I) fileio.ImageIO().write(write_warped_image,warped_image,md_I) else: print('Warning: Warped image cannot be written as it was not computed -- maybe you are using a map-based algorithm?') if write_reg_params is not None: if optimized_reg_parameters is not None: #rp_data = optimized_reg_parameters.data.numpy() #nrrd.write(write_reg_params, rp_data, md_I) fileio.GenericIO().write(write_reg_params,optimized_reg_parameters,md_I) else: print('Warning: optimized parameters were not computed and hence cannot be saved.') if used_config is not None: print('Writing the used configuration to file.') params.write_JSON( used_config + '_settings_clean.json') params.write_JSON_comments( used_config + '_settings_comments.json')
return im_warped,hdr else: print('Could not read map or image') return None,None if __name__ == "__main__": # execute this as a script import argparse print('WARNING: TODO: need to add support for different spline orders for image warping!! (I.e., support for params for compute_warped_image_multiNC') parser = argparse.ArgumentParser(description='Apply map to warp image') required = parser.add_argument_group('required arguments') required.add_argument('--map', required=True, help='Map that should be applied [need to be in [-1,1]^d format currently]') required.add_argument('--image', required=True, help='Image to which the map should be applied') required.add_argument('--warped_image', required=True, help='Warped image after applying the map') args = parser.parse_args() image_filename = args.image map_filename = args.map im_warped_filename = args.warped_image im_warped,hdr = read_image_and_map_and_apply_map(image_filename, map_filename) # now write it out print( 'Writing warped image to file: ' + im_warped_filename ) fileio.ImageIO().write(im_warped_filename, im_warped, hdr)
def calculate_image_overlap(dataset_info, phi_path, source_labelmap_path, target_labelmap_path, warped_labelmap_path, moving_id, target_id, use_sym_links=True):
    """
    Calculate the per-label overlap of a warped source label map with the target label map.

    The source label map is warped with the map read from ``phi_path`` (via ``warp_image_nn``).
    For every label, the overlap is (#voxels labeled with it in both target and warped source) /
    (#voxels labeled with it in the target). Labels absent from the target and label 0 are ignored.

    :param dataset_info: dictionary with the validation dataset information; uses the keys
        'label_name' (matlab label file or None), 'nr_of_labels' (required if 'label_name' is None),
        'label_files_dir' and 'label_prefix'
    :param phi_path: deformation field path
    :param source_labelmap_path: if not None, location of a symlink to / copy of the source label map
    :param target_labelmap_path: if not None, location of a symlink to / copy of the target label map
    :param warped_labelmap_path: where the warped source label map is written
    :param moving_id: moving image id (label file '<label_prefix><moving_id>.nii')
    :param target_id: target image id (label file '<label_prefix><target_id>.nii')
    :param use_sym_links: if True, create symlinks for source/target label maps instead of writing copies
    :return: tuple (result_mean, single_result): mean overlap (NaN if no label is valid) and the
        array of overlaps of all valid labels
    :raises ValueError: if neither a matlab label file nor 'nr_of_labels' is given
    """
    Labels = None
    if dataset_info['label_name'] is not None:
        Labels = sio.loadmat(dataset_info['label_name'])
        nr_of_labels = len(Labels['Labels'])
    elif 'nr_of_labels' in dataset_info:
        nr_of_labels = dataset_info['nr_of_labels']
    else:
        raise ValueError('If matlab label file not given, nr_of_labels needs to be specified')
    # todo: not sure why these are floats, but okay for now
    result = np.zeros(nr_of_labels)

    im_io = fio.ImageIO()

    label_from_filename = dataset_info['label_files_dir'] + dataset_info['label_prefix'] + '{:d}.nii'.format(moving_id)
    label_from, hdr_from, _, _ = im_io.read(label_from_filename, silent_mode=True, squeeze_image=True)

    label_to_filename = dataset_info['label_files_dir'] + dataset_info['label_prefix'] + '{:d}.nii'.format(target_id)
    label_to, hdr_to, _, _ = im_io.read(label_to_filename, silent_mode=True, squeeze_image=True)

    map_io = fio.MapIO()
    phi, _, _, _ = map_io.read_from_validation_map_format(phi_path)
    warp_result = warp_image_nn(label_from, phi)

    # the warped labels are compared against the target, so they get the target header
    im_io.write(warped_labelmap_path, warp_result, hdr_to)

    if source_labelmap_path is not None:
        if use_sym_links:
            utils.create_symlink_with_correct_ext(label_from_filename, source_labelmap_path)
        else:
            # each label map keeps its own header
            im_io.write(source_labelmap_path, label_from, hdr_from)

    if target_labelmap_path is not None:
        if use_sym_links:
            utils.create_symlink_with_correct_ext(label_to_filename, target_labelmap_path)
        else:
            im_io.write(target_labelmap_path, label_to, hdr_to)

    for label_idx in range(nr_of_labels):
        if Labels is not None:
            current_id = Labels['Labels'][label_idx][0]
        else:
            current_id = label_idx

        target_mask = (label_to == current_id)
        target_vol = float(target_mask.sum())
        if target_vol == 0 or current_id == 0:
            result[label_idx] = np.nan
        else:
            intersection = float(np.logical_and(target_mask, warp_result == current_id).sum())
            result[label_idx] = intersection / target_vol

    single_result = result[~np.isnan(result)]
    result_mean = np.mean(single_result) if len(single_result) > 0 else np.nan
    return result_mean, single_result
def _show_atlas_image(img, title):
    """Display the first batch/channel entry of ``img`` as a grayscale image with a colorbar."""
    plt.imshow(AdaptVal(img[0, 0, ...]).detach().cpu().numpy(), cmap='gray')
    plt.title(title)
    plt.colorbar()
    plt.show()


def build_atlas(images, nr_of_cycles, warped_images, temp_folder, visualize):
    """
    Build an atlas by repeatedly registering all images to the current average image.

    In each cycle every image is registered (model 'svf_scalar_momentum_map') to the current
    average and the warped images are averaged to form the next average. From the second cycle
    on, each image's registration starts from its model parameters of the previous cycle.

    :param images: list of image filenames
    :param nr_of_cycles: number of register-and-average cycles
    :param warped_images: directory in which the warped images of the last cycle are written
        as 'atlas_reg_ImageXXXX.nrrd'; if None, they are not written
    :param temp_folder: currently unused
    :param visualize: if True, show the initial and every intermediate average image
    :return: the final average image
    """
    si = SI.RegisterImagePair()
    im_io = FIO.ImageIO()

    # compute first average image
    Iavg, _ = compute_average_image(images)
    Iavg = Iavg.data

    if visualize:
        _show_atlas_image(Iavg, 'Initial average based on ' + str(len(images)) + ' images')

    # model parameters of every image, carried over between cycles
    mp = []
    last_cycle = nr_of_cycles - 1

    # register all images to the average image and while doing so compute a new average image
    for c in range(nr_of_cycles):
        print('Starting cycle ' + str(c + 1) + '/' + str(nr_of_cycles))
        newAvg = None
        for i, im_name in enumerate(images):
            print('Registering image ' + str(i + 1) + '/' + str(len(images)))
            Ic, hdrc, spacing, _ = im_io.read_to_nc_format(filename=im_name)

            # set former model parameters if available
            if c != 0:
                si.set_model_parameters(mp[i])

            # register current image to average image
            si.register_images(Ic, AdaptVal(Iavg).detach().cpu().numpy(), spacing,
                               model_name='svf_scalar_momentum_map',
                               map_low_res_factor=0.5,
                               nr_of_iterations=5,
                               visualize_step=None,
                               similarity_measure_sigma=0.5)
            wi = si.get_warped_image()

            # save current model parameters for the next cycle
            if c == 0:
                mp.append(si.get_model_parameters())
            elif c != last_cycle:
                mp[i] = si.get_model_parameters()

            if c == last_cycle and warped_images is not None:
                # last time this is run, so let's save the image
                current_filename = warped_images + '/atlas_reg_Image' + str(i + 1).zfill(4) + '.nrrd'
                print("writing image " + str(i + 1))
                im_io.write(current_filename, wi, hdrc)

            # clone so that the in-place accumulation does not modify the first warped image
            if newAvg is None:
                newAvg = wi.data.clone()
            else:
                newAvg += wi.data

        Iavg = newAvg / len(images)

        if visualize:
            _show_atlas_image(Iavg, 'Average ' + str(c + 1) + '/' + str(nr_of_cycles))

    return Iavg
parser.add_argument( '--visualize', required=True, help='Visualize the first, intermediate and final atlas images', type=bool) parser.add_argument('--final_atlas', required=True, help='Path to where to save final atlas image') args = parser.parse_args() input_image_folder = args.input_image_folder input_image_pattern = args.input_image_pattern number_input_images = args.number_input_images number_of_cycles = args.number_of_cycles warped_images = args.warped_images temp_folder = args.temp_folder visualize = args.visualize final_atlas = args.final_atlas # get files to build the atlas from images_list = find(input_image_pattern, input_image_folder) images = images_list[0:number_input_images] # build the atlas Iatlas = build_atlas(images, number_of_cycles, None, None, visualize) # save final atlas image atlas_filename = final_atlas + '/atlas_reg_finalAtlasImage.nrrd' print("Writing final average image") FIO.ImageIO().write(atlas_filename, Iatlas)
def compute_and_visualize_validation_result(multi_gaussian_stds_synth, multi_gaussian_stds, compare_global_weights, image_and_map_output_directory, misc_output_directory, label_output_directory, print_output_directory, pair_nr, current_source_id, visualize=False, print_images=False, clean_publication_print=False, printing_single_pair=False):
    """
    Compare the computed registration results of one pair against the synthetic ground truth.

    Loads the computed map, weights and momentum of pair ``pair_nr`` ('map_/weights_/momentum_XXXXX.pt'
    in ``image_and_map_output_directory``), the ground truth map, weights and momentum of source
    ``current_source_id`` ('gt_map_/gt_weights_/gt_momentum_XXXXX.pt' in ``misc_output_directory``)
    and the label image 'm<current_source_id>.nii' from ``label_output_directory``, and compares them.

    :param multi_gaussian_stds_synth: forwarded to ``compare_weights``
    :param multi_gaussian_stds: forwarded to ``compare_weights``
    :param compare_global_weights: if True, compare the stored global multi-Gaussian weights
        (replicated to every voxel); otherwise compare the stored local weights
    :param image_and_map_output_directory: directory with the computed results
    :param misc_output_directory: directory with the ground truth results
    :param label_output_directory: directory with the label images
    :param print_output_directory: directory for printed figures; clean publication prints go into
        its 'clean_publication_prints' subdirectory
    :param pair_nr: number of the registration pair
    :param current_source_id: id of the source image the ground truth belongs to
    :param visualize: visualize the comparisons; forced to True if ``print_images`` is set and
        for pair 0 or when ``printing_single_pair`` is set
    :param print_images: if True, figures are printed to ``print_output_directory``
    :param clean_publication_print: if True, clean publication prints are created (only for pair 0
        or when ``printing_single_pair`` is set)
    :param printing_single_pair: treat this pair like pair 0 for visualization/publication prints
    :return: dict with 'map_stats', 'det_jac_stats', 'weight_stats' and, if the computed and the
        ground truth momentum have the same shape, 'momentum_stats'
    :raises ValueError: if local weights are requested but not stored
    """
    # load the computed results for map, weights, and momentum
    map_output_filename_pt = os.path.join(image_and_map_output_directory, 'map_{:05d}.pt'.format(pair_nr))
    weights_output_filename_pt = os.path.join(image_and_map_output_directory, 'weights_{:05d}.pt'.format(pair_nr))
    momentum_output_filename_pt = os.path.join(image_and_map_output_directory, 'momentum_{:05d}.pt'.format(pair_nr))

    computed_map = torch.load(map_output_filename_pt).detach().cpu().numpy()
    momentum = torch.load(momentum_output_filename_pt).detach().cpu().numpy()

    if print_images:
        visualize = True

    if clean_publication_print:
        clean_publication_dir = os.path.join(print_output_directory, 'clean_publication_prints')
        if not os.path.exists(clean_publication_dir):
            print('INFO: creating directory {:s}'.format(clean_publication_dir))
            # makedirs also creates print_output_directory if it does not exist yet
            os.makedirs(clean_publication_dir)
    else:
        clean_publication_dir = None

    weights_dict = torch.load(weights_output_filename_pt)
    if compare_global_weights:
        # there are only global weights;
        # let's make them "local" so that we can use the same code for comparison everywhere
        global_weights = weights_dict['default_multi_gaussian_weights'].detach().cpu().numpy()
        nr_of_weights = len(global_weights)
        sz_m = list(momentum.shape)
        # same size as the momentum, but with one channel per weight
        desired_sz = [sz_m[0]] + [nr_of_weights] + sz_m[2:]
        weights = np.zeros(desired_sz, dtype='float32')
        for n in range(nr_of_weights):
            weights[:, n, ...] = global_weights[n]
    else:
        if 'local_weights' not in weights_dict:
            raise ValueError('requested comparison of local weights, but local weights are not available')
        weights = weights_dict['local_weights'].detach().cpu().numpy()

    # now get the corresponding ground truth values
    # (these are based on the source image and hence based on the source id)
    gt_map_filename = os.path.join(misc_output_directory, 'gt_map_{:05d}.pt'.format(current_source_id))
    gt_weights_filename = os.path.join(misc_output_directory, 'gt_weights_{:05d}.pt'.format(current_source_id))
    gt_momentum_filename = os.path.join(misc_output_directory, 'gt_momentum_{:05d}.pt'.format(current_source_id))
    label_output_filename = os.path.join(label_output_directory, 'm{:d}.nii'.format(current_source_id))

    gt_map = torch.load(gt_map_filename)
    gt_weights = torch.load(gt_weights_filename)
    gt_momentum = torch.load(gt_momentum_filename)

    im_io = FIO.ImageIO()
    label_image, _, _, _ = im_io.read_to_nc_format(label_output_filename, silent_mode=True)
    # keep a 2D slice: first batch/channel entry and first index of the last dimension
    label_image = label_image[0, 0, :, :, 0]

    if print_images:
        print_output_directory_eff = print_output_directory
    else:
        print_output_directory_eff = None

    if printing_single_pair or pair_nr == 0:
        clean_publication_dir_eff = clean_publication_dir
        visualize = True
    else:
        # we don't want to print them for all
        clean_publication_dir_eff = None

    # now we can compare them
    d = dict()
    d['map_stats'] = compare_map(
        computed_map, gt_map, label_image,
        visualize=visualize,
        print_output_directory=print_output_directory_eff,
        clean_publication_directory=clean_publication_dir_eff,
        pair_nr=pair_nr)

    d['det_jac_stats'] = compare_det_of_jac_from_map(
        computed_map, gt_map, label_image,
        visualize=visualize,
        print_output_directory=print_output_directory_eff,
        clean_publication_directory=clean_publication_dir_eff,
        pair_nr=pair_nr)

    d['weight_stats'] = compare_weights(
        weights, gt_weights, multi_gaussian_stds_synth, multi_gaussian_stds,
        label_image=label_image,
        visualize=visualize,
        print_output_directory=print_output_directory_eff,
        clean_publication_directory=clean_publication_dir_eff,
        pair_nr=pair_nr)

    if momentum.shape == gt_momentum.shape:
        d['momentum_stats'] = compare_momentum(
            momentum, gt_momentum, label_image,
            visualize=visualize,
            print_output_directory=print_output_directory_eff,
            clean_publication_directory=clean_publication_dir_eff,
            pair_nr=pair_nr)
    else:
        print('WARNING: momentum shape {} does not match ground truth momentum shape {}; skipping momentum comparison'.format(
            tuple(momentum.shape), tuple(gt_momentum.shape)))

    return d