def main():
    """Predict and save the segmentation of a single HCP T1 test slice.

    HCP T1 subjects with indices [50, 52) are loaded (pre-processed to
    ``exp_config.image_size`` at ``exp_config.target_resolution_brain``).
    The 2D image ``images[1, :, :, 100]`` is segmented with
    ``predict_segmentation`` at the pre-processed resolution, and the image
    together with its predicted label map is written to ``test_result.png``
    in the i2l experiment's log directory.
    """
    # ===================================
    # load the test data and pick one slice
    # ===================================
    test_data = data_hcp.load_and_maybe_process_data(
        input_folder=sys_config.orig_data_root_hcp,
        preprocessing_folder=sys_config.preproc_folder_hcp,
        idx_start=50,
        idx_end=52,
        protocol='T1',
        size=exp_config.image_size,
        target_resolution=exp_config.target_resolution_brain)
    test_slice = test_data['images'][1, :, :, 100]

    # ===================================
    # segment the slice at the pre-processed resolution
    # ===================================
    prediction = predict_segmentation(test_slice)

    # ===================================
    # store image + predicted labels side by side
    # ===================================
    output_path = (sys_config.log_root + exp_config.expname_i2l +
                   '/test_result.png')
    utils_vis.save_single_image_and_label(test_slice,
                                          prediction,
                                          savepath=output_path)
def main():
    """Load labelled SD data and unlabelled TD data, then run UDA training.

    The source-domain (SD) dataset is selected by ``exp_config.train_dataset``
    ('HCPT1' or 'NCI'); its training and validation images and labels are
    loaded. The target-domain (TD) dataset is selected by
    ``exp_config.test_dataset`` ('HCPT2', 'CALTECH', 'PIRAD_ERC' or
    'PROMISE'); only its training and validation images are loaded.

    The log directory ``os.path.join(sys_config.log_root,
    exp_config.expname_uda)`` and its ``models`` sub-directory are created,
    the experiment config file is copied into it, and ``run_uda_training``
    is called with all loaded arrays.

    Raises:
        ValueError: if ``exp_config.train_dataset`` or
            ``exp_config.test_dataset`` is not one of the names listed above.
    """
    # Dataset names are compared with ``==``. ``is`` tests object identity,
    # which for strings only works by accident of interning (and triggers a
    # SyntaxWarning on Python >= 3.8).

    # ============================
    # Load SD data
    # ============================
    logging.info(
        '============================================================')
    logging.info('Loading SD data...')

    if exp_config.train_dataset == 'HCPT1':
        logging.info('Reading HCPT1 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        data_brain_train_sd = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=0,
            idx_end=20,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_hcp,
            target_resolution=exp_config.target_resolution_brain)
        imtr_sd = data_brain_train_sd['images']
        gttr_sd = data_brain_train_sd['labels']

        data_brain_val_sd = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=20,
            idx_end=25,
            protocol='T1',
            size=exp_config.image_size,
            depth=exp_config.image_depth_hcp,
            target_resolution=exp_config.target_resolution_brain)
        imvl_sd = data_brain_val_sd['images']
        gtvl_sd = data_brain_val_sd['labels']

    # PROSTATE
    elif exp_config.train_dataset == 'NCI':
        logging.info('Reading NCI images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_nci)
        data_pros = data_nci.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_nci,
            preprocessing_folder=sys_config.preproc_folder_nci,
            size=exp_config.image_size,
            target_resolution=exp_config.target_resolution_prostate,
            force_overwrite=False,
            cv_fold_num=1)
        imtr_sd = data_pros['images_train']
        gttr_sd = data_pros['masks_train']
        imvl_sd = data_pros['images_validation']
        gtvl_sd = data_pros['masks_validation']

    else:
        # Previously an unknown name surfaced much later as an unrelated
        # UnboundLocalError; fail early with a clear message instead.
        raise ValueError(
            "Unsupported train_dataset %r for UDA: expected 'HCPT1' or 'NCI'."
            % (exp_config.train_dataset,))

    # ============================
    # Load TD unlabelled images
    # ============================
    logging.info(
        '============================================================')
    logging.info('Loading TD unlabelled images...')

    if exp_config.test_dataset == 'HCPT2':
        logging.info('Reading HCPT2 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        image_depth = exp_config.image_depth_hcp
        data_brain_train_td = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=0,
            idx_end=20,
            protocol='T2',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)
        imtr_td = data_brain_train_td['images']

        data_brain_val_td = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=20,
            idx_end=25,
            protocol='T2',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)
        imvl_td = data_brain_val_td['images']

    elif exp_config.test_dataset == 'CALTECH':
        logging.info('Reading CALTECH images...')
        logging.info('Data root directory: ' +
                     sys_config.orig_data_root_abide + 'CALTECH/')
        image_depth = exp_config.image_depth_caltech
        data_brain_train_td = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='CALTECH',
            idx_start=0,
            idx_end=10,
            protocol='T1',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)
        imtr_td = data_brain_train_td['images']

        data_brain_val_td = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='CALTECH',
            idx_start=10,
            idx_end=15,
            protocol='T1',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)
        imvl_td = data_brain_val_td['images']

    elif exp_config.test_dataset == 'PIRAD_ERC':
        logging.info('Reading PIRAD_ERC images...')
        logging.info('Data root directory: ' +
                     sys_config.orig_data_root_pirad_erc)
        # NOTE: training uses subjects [40, 68) and validation [20, 40).
        data_pros_train = data_pirad_erc.load_data(
            input_folder=sys_config.orig_data_root_pirad_erc,
            preproc_folder=sys_config.preproc_folder_pirad_erc,
            idx_start=40,
            idx_end=68,
            size=exp_config.image_size,
            target_resolution=exp_config.target_resolution_prostate,
            labeller='ek',
            force_overwrite=False)
        data_pros_val = data_pirad_erc.load_data(
            input_folder=sys_config.orig_data_root_pirad_erc,
            preproc_folder=sys_config.preproc_folder_pirad_erc,
            idx_start=20,
            idx_end=40,
            size=exp_config.image_size,
            target_resolution=exp_config.target_resolution_prostate,
            labeller='ek',
            force_overwrite=False)
        imtr_td = data_pros_train['images']
        imvl_td = data_pros_val['images']

    elif exp_config.test_dataset == 'PROMISE':
        logging.info('Reading PROMISE images...')
        logging.info('Data root directory: ' +
                     sys_config.orig_data_root_promise)
        data_pros = data_promise.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_promise,
            preprocessing_folder=sys_config.preproc_folder_promise,
            size=exp_config.image_size,
            target_resolution=exp_config.target_resolution_prostate,
            force_overwrite=False,
            cv_fold_num=2)
        imtr_td = data_pros['images_train']
        imvl_td = data_pros['images_validation']

    else:
        raise ValueError(
            "Unsupported test_dataset %r for UDA: expected 'HCPT2', "
            "'CALTECH', 'PIRAD_ERC' or 'PROMISE'." %
            (exp_config.test_dataset,))

    # ================================================================
    # create the log directory for this UDA experiment
    # ================================================================
    log_dir_uda = os.path.join(sys_config.log_root, exp_config.expname_uda)
    if not tf.gfile.Exists(log_dir_uda):
        tf.gfile.MakeDirs(log_dir_uda)
    # MakeDirs succeeds if the directory already exists, so the models
    # sub-directory is ensured even when the log directory pre-existed.
    tf.gfile.MakeDirs(log_dir_uda + '/models')

    # ===========================
    # Copy experiment config file
    # ===========================
    shutil.copy(exp_config.__file__, log_dir_uda)

    # ================================================================
    # run uda training
    # ================================================================
    run_uda_training(log_dir_uda, imtr_sd, gttr_sd, imvl_sd, gtvl_sd, imtr_td,
                     imvl_td)
def main(argv):
    """Train the test-time normalization module for one target-domain subject.

    Args:
        argv: argument list whose first element is parsed with ``int`` as the
            index of the test subject to process (0-based, within the loaded
            test set). NOTE(review): confirm the caller passes the subject
            index, not the program name, as ``argv[0]``.

    The test set selected by ``exp_config.test_dataset`` ('HCPT2', 'CALTECH'
    or 'STANFORD') and the atlas ``hcp_atlas.npy`` are loaded. The selected
    subject's image and label volumes are passed through
    ``modify_image_and_label`` (together with the atlas and the subject's
    'pz' value), a few downsampled visualizations are saved, and
    ``run_training`` is called for that subject.

    Raises:
        ValueError: if ``exp_config.test_dataset`` is not supported, or if the
            subject index is outside the loaded test set.
    """
    # ============================
    # Load test images
    # ============================
    logging.info(
        '============================================================')
    logging.info('Loading data...')

    # Compare with ``==``: ``is`` on string literals tests identity, not
    # equality (SyntaxWarning on Python >= 3.8).
    test_dataset = exp_config.test_dataset
    if test_dataset == 'HCPT2':
        logging.info('Reading HCPT2 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        image_depth = exp_config.image_depth_hcp
        data_brain_test = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=50,
            idx_end=70,
            protocol='T2',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)
    elif test_dataset in ('CALTECH', 'STANFORD'):
        # Both ABIDE sites are loaded identically apart from site name and depth.
        site_name = test_dataset
        logging.info('Reading ' + site_name + ' images...')
        logging.info('Data root directory: ' +
                     sys_config.orig_data_root_abide + site_name + '/')
        if site_name == 'CALTECH':
            image_depth = exp_config.image_depth_caltech
        else:
            image_depth = exp_config.image_depth_stanford
        data_brain_test = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name=site_name,
            idx_start=16,
            idx_end=36,
            protocol='T1',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)
    else:
        raise ValueError(
            "Unsupported test_dataset %r: expected 'HCPT2', 'CALTECH' or "
            "'STANFORD'." % (test_dataset,))

    imts = data_brain_test['images']
    gtts = data_brain_test['labels']
    # every subject occupies image_depth consecutive slices (see slicing below)
    num_test_subjects = imts.shape[0] // image_depth
    name_test_subjects = data_brain_test['patnames']
    # per-subject 'pz' values, used below as the subject's slice thickness
    slice_thickness_in_test_subjects = data_brain_test['pz'][:]

    # ================================================================
    # read the atlas
    # ================================================================
    atlas = np.load(sys_config.preproc_folder_hcp + 'hcp_atlas.npy')

    # ================================================================
    # create the base log directory; each subject gets a sub-directory
    # ================================================================
    log_dir_base = os.path.join(sys_config.log_root,
                                exp_config.expname_normalizer)
    if not tf.gfile.Exists(log_dir_base):
        tf.gfile.MakeDirs(log_dir_base)

    # basicConfig has no effect once the root logger has handlers, so one
    # call outside the per-subject loop is sufficient.
    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')

    # ================================================================
    # run the training for the requested test subject
    # ================================================================
    subject_num = int(argv[0])
    if not 0 <= subject_num < num_test_subjects:
        # Out-of-range indices would otherwise yield empty (or wrapped)
        # slices and fail later in a confusing way.
        raise ValueError(
            'Subject index %d is out of range: the test set has %d subjects.'
            % (subject_num, num_test_subjects))

    for subject_id in range(subject_num, subject_num + 1):

        subject_id_start_slice = subject_id * image_depth
        subject_id_end_slice = (subject_id + 1) * image_depth
        image = imts[subject_id_start_slice:subject_id_end_slice, :, :]
        label = gtts[subject_id_start_slice:subject_id_end_slice, :, :]
        slice_thickness_this_subject = slice_thickness_in_test_subjects[
            subject_id]

        # str(...)[2:-1] drops the first two and the last character of the
        # name's string form -- presumably the b'...' of a bytes repr (TODO
        # confirm). Kept as-is so existing log directory names stay valid.
        subject_name = str(name_test_subjects[subject_id])[2:-1]
        log_dir = log_dir_base + '/subject_' + subject_name
        logging.info(
            '============================================================')
        logging.info('Logging directory: %s' % log_dir)
        logging.info('Subject ID: %d' % subject_id)
        logging.info('Subject name: %s' % subject_name)

        # ===========================
        # create dirs if they do not exist. MakeDirs creates missing parents
        # and succeeds if the directory already exists.
        # ===========================
        tf.gfile.MakeDirs(log_dir)
        tf.gfile.MakeDirs(log_dir + '/models')
        tf.gfile.MakeDirs(log_dir + '/results')
        tf.gfile.MakeDirs(log_dir + '/results/visualize_images')

        # ===========================
        # Copy experiment config file
        # ===========================
        shutil.copy(exp_config.__file__, log_dir)

        # ===========================
        # Change the resolution of the current image so that it matches the
        # atlas, and pad and crop.
        # ===========================
        image_rescaled_cropped, label_rescaled_cropped = modify_image_and_label(
            image, label, atlas, slice_thickness_this_subject)

        # visualize every 8th slice of the image and ground truth label,
        # before and after rescaling
        utils_vis.save_samples_downsampled(
            utils.crop_or_pad_volume_to_size_along_x(image, 256)[::8, :, :],
            savepath=log_dir + '/orig_image.png',
            add_pixel_each_label=False,
            cmap='gray')
        utils_vis.save_samples_downsampled(
            utils.crop_or_pad_volume_to_size_along_x(label, 256)[::8, :, :],
            savepath=log_dir + '/gt_label.png',
            cmap='tab20')
        utils_vis.save_samples_downsampled(
            image_rescaled_cropped[::8, :, :],
            savepath=log_dir + '/orig_image_rescaled.png',
            add_pixel_each_label=False,
            cmap='gray')
        utils_vis.save_samples_downsampled(
            label_rescaled_cropped[::8, :, :],
            savepath=log_dir + '/gt_label_rescaled.png',
            cmap='tab20')

        # ===========================
        # run training. For every subject other than the first, pass the log
        # dir of the 1st TD subject (per the original comment: for use if that
        # training has been completed successfully).
        # ===========================
        first_subject = 0
        log_dir_first_TD_subject = ''
        if subject_id != first_subject:
            log_dir_first_TD_subject = log_dir_base + '/subject_' + str(
                name_test_subjects[first_subject])[2:-1]

        run_training(log_dir,
                     image_rescaled_cropped,
                     label_rescaled_cropped,
                     atlas,
                     continue_run=exp_config.continue_run,
                     log_dir_first_TD_subject=log_dir_first_TD_subject)

        # explicitly release memory between subjects
        gc.collect()
def main():
    """Evaluate segmentation predictions on the configured test set.

    The test set selected by ``exp_config.test_dataset`` ('HCPT1', 'HCPT2',
    'CALTECH' or 'STANFORD') is loaded. For every subject,
    ``predict_segmentation`` is called with ``exp_config.normalize`` and
    ``exp_config.post_process``. The predicted labels and normalized image
    are mapped back to the original resolution/size with
    ``rescale_and_crop`` and compared to the original labels.

    Per-subject foreground Dice and Hausdorff statistics, per-label
    statistics over all subjects, and overall means are written to
    ``<log_dir>/<dataset>_test<suffix>.txt``. A visualization PNG is saved
    for every subject.

    Raises:
        ValueError: if ``exp_config.test_dataset`` is not supported.
    """
    # ===================================
    # read the test images
    # ===================================
    # Compare with ``==``: ``is`` on string literals tests identity, not
    # equality (SyntaxWarning on Python >= 3.8).
    test_dataset_name = exp_config.test_dataset

    if test_dataset_name in ('HCPT1', 'HCPT2'):
        hcp_protocol = 'T1' if test_dataset_name == 'HCPT1' else 'T2'
        logging.info('Reading ' + test_dataset_name + ' images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        image_depth = exp_config.image_depth_hcp
        idx_start = 50
        idx_end = 70
        data_brain_test = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=idx_start,
            idx_end=idx_end,
            protocol=hcp_protocol,
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)
    elif test_dataset_name in ('CALTECH', 'STANFORD'):
        abide_site = test_dataset_name
        logging.info('Reading ' + abide_site + ' images...')
        logging.info('Data root directory: ' +
                     sys_config.orig_data_root_abide + abide_site + '/')
        if abide_site == 'CALTECH':
            image_depth = exp_config.image_depth_caltech
        else:
            image_depth = exp_config.image_depth_stanford
        idx_start = 16
        idx_end = 36
        data_brain_test = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name=abide_site,
            idx_start=idx_start,
            idx_end=idx_end,
            protocol='T1',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)
    else:
        raise ValueError(
            "Unsupported test_dataset %r: expected 'HCPT1', 'HCPT2', "
            "'CALTECH' or 'STANFORD'." % (test_dataset_name,))

    imts = data_brain_test['images']
    name_test_subjects = data_brain_test['patnames']
    num_test_subjects = imts.shape[0] // image_depth
    ids = np.arange(idx_start, idx_end)

    # original per-subject pixel spacing and image size
    orig_data_res_x = data_brain_test['px'][:]
    orig_data_res_y = data_brain_test['py'][:]
    orig_data_res_z = data_brain_test['pz'][:]
    orig_data_siz_x = data_brain_test['nx'][:]
    orig_data_siz_y = data_brain_test['ny'][:]
    orig_data_siz_z = data_brain_test['nz'][:]

    # every supported dataset is evaluated without extra rotations
    num_rotations = 0

    # ================================
    # set the log directory
    # ================================
    if exp_config.normalize is True:
        log_dir = os.path.join(sys_config.log_root,
                               exp_config.expname_normalizer)
    else:
        log_dir = sys_config.log_root + 'i2l_mapper/' + exp_config.expname_i2l

    if exp_config.post_process is True:
        file_suffix = '_with_post_process_with_dae_runs' + str(
            exp_config.dae_post_process_runs)
    else:
        file_suffix = ''

    logging.info(log_dir)

    # basicConfig has no effect once the root logger has handlers, so a
    # single call outside the per-subject loop is sufficient.
    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')

    # ================================
    # open a text file for writing the mean dice scores for each subject that
    # is evaluated; ``with`` guarantees it is closed even if evaluation fails
    # ================================
    results_path = (log_dir + '/' + test_dataset_name + '_' + 'test' +
                    file_suffix + '.txt')
    with open(results_path, "w") as results_file:
        results_file.write("================================== \n")
        results_file.write("Test results \n")

        # ============================================================
        # For each test image, predict with the best model and compute the
        # dice / Hausdorff distance
        # ============================================================
        dice_per_label_per_subject = []
        hsd_per_label_per_subject = []

        for sub_num in range(num_test_subjects):

            # subjects are stored consecutively; this subject's slices start
            # after the cumulative 'nz' of all previous subjects
            subject_id_start_slice = np.sum(orig_data_siz_z[:sub_num])
            subject_id_end_slice = np.sum(orig_data_siz_z[:sub_num + 1])
            image = imts[subject_id_start_slice:subject_id_end_slice, :, :]

            subject_name = str(name_test_subjects[sub_num])[2:-1]
            logging.info(
                '============================================================')
            logging.info('Subject id: %s' % sub_num)

            # ==========================================================
            # predict segmentation at the pre-processed resolution
            # ==========================================================
            predicted_labels, normalized_image = predict_segmentation(
                subject_name, image, exp_config.normalize,
                exp_config.post_process)

            # ==========================================================
            # read the original image and segmentation mask
            # (per the original note, the image is normalized to [0,1])
            # ==========================================================
            if test_dataset_name in ('HCPT1', 'HCPT2'):
                image_orig, labels_orig = data_hcp.load_without_size_preprocessing(
                    input_folder=sys_config.orig_data_root_hcp,
                    idx=ids[sub_num],
                    protocol=hcp_protocol,
                    preprocessing_folder=sys_config.preproc_folder_hcp,
                    depth=image_depth)
            else:
                image_orig, labels_orig = data_abide.load_without_size_preprocessing(
                    input_folder=sys_config.orig_data_root_abide,
                    site_name=abide_site,
                    idx=ids[sub_num],
                    depth=image_depth)

            # ==========================================================
            # convert the predictions back to original resolution
            # (nearest-neighbour for labels, linear for the image)
            # ==========================================================
            predicted_labels_orig_res_and_size = rescale_and_crop(
                predicted_labels,
                orig_data_res_x[sub_num],
                orig_data_res_y[sub_num],
                orig_data_siz_x[sub_num],
                orig_data_siz_y[sub_num],
                order_interpolation=0,
                num_rotations=num_rotations)

            normalized_image_orig_res_and_size = rescale_and_crop(
                normalized_image,
                orig_data_res_x[sub_num],
                orig_data_res_y[sub_num],
                orig_data_siz_x[sub_num],
                orig_data_siz_y[sub_num],
                order_interpolation=1,
                num_rotations=num_rotations)

            # ==========================================================
            # compute dice (per label) at the original resolution
            # ==========================================================
            dice_per_label_this_subject = met.f1_score(
                labels_orig.flatten(),
                predicted_labels_orig_res_and_size.flatten(),
                average=None)

            # ==========================================================
            # compute Hausdorff distance at the original resolution
            # ==========================================================
            hsd_per_label_this_subject = utils.compute_surface_distance(
                y1=labels_orig,
                y2=predicted_labels_orig_res_and_size,
                nlabels=exp_config.nlabels)

            # ==========================================================
            # save sample results
            # ==========================================================
            utils_vis.save_sample_prediction_results(
                x=utils.crop_or_pad_volume_to_size_along_z(image_orig, 256),
                x_norm=utils.crop_or_pad_volume_to_size_along_z(
                    normalized_image_orig_res_and_size, 256),
                y_pred=utils.crop_or_pad_volume_to_size_along_z(
                    predicted_labels_orig_res_and_size, 256),
                gt=utils.crop_or_pad_volume_to_size_along_z(labels_orig, 256),
                # rotate for consistent visualization across datasets
                num_rotations=-num_rotations,
                savepath=log_dir + '/' + test_dataset_name + '_' + 'test' +
                '_' + subject_name + file_suffix + '.png')

            # ==========================================================
            # write the mean fg dice / hausdorff of this subject to the file
            # (dice index 0 is excluded from the foreground statistics)
            # ==========================================================
            results_file.write(subject_name +
                               ":: dice (mean, std over all FG labels): ")
            results_file.write(
                str(np.round(np.mean(dice_per_label_this_subject[1:]), 3)) +
                ", " +
                str(np.round(np.std(dice_per_label_this_subject[1:]), 3)))
            dice_per_label_per_subject.append(dice_per_label_this_subject)

            results_file.write(
                ", hausdorff distance (mean, std over all FG labels): ")
            # Bug fix: the std must be taken over the Hausdorff distances,
            # not over the dice scores.
            results_file.write(
                str(np.round(np.mean(hsd_per_label_this_subject), 3)) + ", " +
                str(np.round(np.std(hsd_per_label_this_subject), 3)))
            hsd_per_label_per_subject.append(hsd_per_label_this_subject)
            results_file.write("\n")

        # ============================================================
        # write per label statistics over all subjects
        # rows: subjects, columns: labels
        # ============================================================
        dice_per_label_per_subject = np.array(dice_per_label_per_subject)
        hsd_per_label_per_subject = np.array(hsd_per_label_per_subject)

        results_file.write("================================== \n")
        results_file.write(
            "Label: dice mean, std. deviation over all subjects\n")
        for i in range(dice_per_label_per_subject.shape[1]):
            results_file.write(
                str(i) + ": " +
                str(np.round(np.mean(dice_per_label_per_subject[:, i]), 3)) +
                ", " +
                str(np.round(np.std(dice_per_label_per_subject[:, i]), 3)) +
                "\n")

        results_file.write("================================== \n")
        results_file.write(
            "Label: hausdorff distance mean, std. deviation over all subjects\n"
        )
        # Hausdorff columns are reported starting at label 1
        for i in range(hsd_per_label_per_subject.shape[1]):
            results_file.write(
                str(i + 1) + ": " +
                str(np.round(np.mean(hsd_per_label_per_subject[:, i]), 3)) +
                ", " +
                str(np.round(np.std(hsd_per_label_per_subject[:, i]), 3)) +
                "\n")

        # ==================
        # write the mean dice / hausdorff over all subjects and all labels
        # ==================
        results_file.write("================================== \n")
        results_file.write(
            "DICE Mean, std. deviation over foreground labels over all subjects: "
            + str(np.round(np.mean(dice_per_label_per_subject[:, 1:]), 3)) +
            ", " +
            str(np.round(np.std(dice_per_label_per_subject[:, 1:]), 3)) + "\n")
        results_file.write(
            "HSD Mean, std. deviation over labels over all subjects: " +
            str(np.round(np.mean(hsd_per_label_per_subject), 3)) + ", " +
            str(np.round(np.std(hsd_per_label_per_subject), 3)) + "\n")
        results_file.write("================================== \n")
def run_training(continue_run): # ============================ # log experiment details # ============================ logging.info( '============================================================') logging.info('EXPERIMENT NAME: %s' % exp_config.experiment_name_i2l) # ============================ # Initialize step number - this is number of mini-batch runs # ============================ init_step = 0 # ============================ # if continue_run is set to True, load the model parameters saved earlier # else start training from scratch # ============================ if continue_run: logging.info( '============================================================') logging.info('Continuing previous run') try: init_checkpoint_path = utils.get_latest_model_checkpoint_path( log_dir, 'models/model.ckpt') logging.info('Checkpoint path: %s' % init_checkpoint_path) init_step = int( init_checkpoint_path.split('/')[-1].split('-') [-1]) + 1 # plus 1 as otherwise starts with eval logging.info('Latest step was: %d' % init_step) except: logging.warning( 'Did not find init checkpoint. Maybe first run failed. Disabling continue mode...' 
) continue_run = False init_step = 0 logging.info( '============================================================') # ============================ # Load data # ============================ logging.info( '============================================================') logging.info('Loading data...') if exp_config.train_dataset is 'HCPT1': logging.info('Reading HCPT1 images...') logging.info('Data root directory: ' + sys_config.orig_data_root_hcp) data_brain_train = data_hcp.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_hcp, preprocessing_folder=sys_config.preproc_folder_hcp, idx_start=0, idx_end=1040, protocol='T1', size=exp_config.image_size, depth=exp_config.image_depth_hcp, target_resolution=exp_config.target_resolution_brain) imtr, gttr = [data_brain_train['images'], data_brain_train['labels']] data_brain_val = data_hcp.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_hcp, preprocessing_folder=sys_config.preproc_folder_hcp, idx_start=20, idx_end=25, protocol='T1', size=exp_config.image_size, depth=exp_config.image_depth_hcp, target_resolution=exp_config.target_resolution_brain) imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']] if exp_config.train_dataset is 'HCPT2': logging.info('Reading HCPT2 images...') logging.info('Data root directory: ' + sys_config.orig_data_root_hcp) data_brain_train = data_hcp.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_hcp, preprocessing_folder=sys_config.preproc_folder_hcp, idx_start=0, idx_end=20, protocol='T2', size=exp_config.image_size, depth=exp_config.image_depth_hcp, target_resolution=exp_config.target_resolution_brain) imtr, gttr = [data_brain_train['images'], data_brain_train['labels']] data_brain_val = data_hcp.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_hcp, preprocessing_folder=sys_config.preproc_folder_hcp, idx_start=20, idx_end=25, protocol='T2', size=exp_config.image_size, depth=exp_config.image_depth_hcp, 
target_resolution=exp_config.target_resolution_brain) imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']] elif exp_config.train_dataset is 'CALTECH': logging.info('Reading CALTECH images...') logging.info('Data root directory: ' + sys_config.orig_data_root_abide + 'CALTECH/') data_brain_train = data_abide.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_abide, preprocessing_folder=sys_config.preproc_folder_abide, site_name='CALTECH', idx_start=0, idx_end=10, protocol='T1', size=exp_config.image_size, depth=exp_config.image_depth_caltech, target_resolution=exp_config.target_resolution_brain) imtr, gttr = [data_brain_train['images'], data_brain_train['labels']] data_brain_val = data_abide.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_abide, preprocessing_folder=sys_config.preproc_folder_abide, site_name='CALTECH', idx_start=10, idx_end=15, protocol='T1', size=exp_config.image_size, depth=exp_config.image_depth_caltech, target_resolution=exp_config.target_resolution_brain) imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']] elif exp_config.train_dataset is 'STANFORD': logging.info('Reading STANFORD images...') logging.info('Data root directory: ' + sys_config.orig_data_root_abide + 'STANFORD/') data_brain_train = data_abide.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_abide, preprocessing_folder=sys_config.preproc_folder_abide, site_name='STANFORD', idx_start=0, idx_end=10, protocol='T1', size=exp_config.image_size, depth=exp_config.image_depth_stanford, target_resolution=exp_config.target_resolution_brain) imtr, gttr = [data_brain_train['images'], data_brain_train['labels']] data_brain_val = data_abide.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_abide, preprocessing_folder=sys_config.preproc_folder_abide, site_name='STANFORD', idx_start=10, idx_end=15, protocol='T1', size=exp_config.image_size, depth=exp_config.image_depth_stanford, 
target_resolution=exp_config.target_resolution_brain) imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']] elif exp_config.train_dataset is 'IXI': logging.info('Reading IXI images...') logging.info('Data root directory: ' + sys_config.orig_data_root_ixi) data_brain_train = data_ixi.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_ixi, preprocessing_folder=sys_config.preproc_folder_ixi, idx_start=0, idx_end=12, protocol='T2', size=exp_config.image_size, depth=exp_config.image_depth_ixi, target_resolution=exp_config.target_resolution_brain) imtr, gttr = [data_brain_train['images'], data_brain_train['labels']] data_brain_val = data_ixi.load_and_maybe_process_data( input_folder=sys_config.orig_data_root_ixi, preprocessing_folder=sys_config.preproc_folder_ixi, idx_start=12, idx_end=17, protocol='T2', size=exp_config.image_size, depth=exp_config.image_depth_ixi, target_resolution=exp_config.target_resolution_brain) imvl, gtvl = [data_brain_val['images'], data_brain_val['labels']] logging.info( 'Training Images: %s' % str(imtr.shape)) # expected: [num_slices, img_size_x, img_size_y] logging.info( 'Training Labels: %s' % str(gttr.shape)) # expected: [num_slices, img_size_x, img_size_y] logging.info('Validation Images: %s' % str(imvl.shape)) logging.info('Validation Labels: %s' % str(gtvl.shape)) logging.info( '============================================================') # ================================================================ # build the TF graph # ================================================================ with tf.Graph().as_default(): # ============================ # set random seed for reproducibility # ============================ tf.random.set_random_seed(exp_config.run_number) np.random.seed(exp_config.run_number) # ================================================================ # create placeholders # ================================================================ logging.info('Creating placeholders...') 
image_tensor_shape = [exp_config.batch_size] + list( exp_config.image_size) + [1] mask_tensor_shape = [exp_config.batch_size] + list( exp_config.image_size) images_pl = tf.placeholder(tf.float32, shape=image_tensor_shape, name='images') labels_pl = tf.placeholder(tf.uint8, shape=mask_tensor_shape, name='labels') learning_rate_pl = tf.placeholder(tf.float32, shape=[], name='learning_rate') training_pl = tf.placeholder(tf.bool, shape=[], name='training_or_testing') # ================================================================ # insert a normalization module in front of the segmentation network # the normalization module will be adapted for each test image # ================================================================ images_normalized, _ = model.normalize(images_pl, exp_config, training_pl) # ================================================================ # build the graph that computes predictions from the inference model # ================================================================ logits, _, _ = model.predict_i2l(images_normalized, exp_config, training_pl=training_pl) print('shape of inputs: ', images_pl.shape) # (batch_size, 256, 256, 1) print('shape of logits: ', logits.shape) # (batch_size, 256, 256, nlabels) # ================================================================ # create a list of all vars that must be optimized wrt # ================================================================ i2l_vars = [] for v in tf.trainable_variables(): i2l_vars.append(v) # ================================================================ # add ops for calculation of the supervised training loss # ================================================================ loss_op = model.loss(logits, labels_pl, nlabels=exp_config.nlabels, loss_type=exp_config.loss_type_i2l) tf.summary.scalar('loss', loss_op) # ================================================================ # add optimization ops. 
# Create different ops according to the variables that must be trained # ================================================================ print('creating training op...') train_op = model.training_step(loss_op, i2l_vars, exp_config.optimizer_handle, learning_rate_pl, update_bn_nontrainable_vars=True) # ================================================================ # add ops for model evaluation # ================================================================ print('creating eval op...') eval_loss = model.evaluation_i2l(logits, labels_pl, images_pl, nlabels=exp_config.nlabels, loss_type=exp_config.loss_type_i2l) # ================================================================ # build the summary Tensor based on the TF collection of Summaries. # ================================================================ print('creating summary op...') summary = tf.summary.merge_all() # ================================================================ # add init ops # ================================================================ init_ops = tf.global_variables_initializer() # ================================================================ # find if any vars are uninitialized # ================================================================ logging.info('Adding the op to get a list of initialized variables...') uninit_vars = tf.report_uninitialized_variables() # ================================================================ # create saver # ================================================================ saver = tf.train.Saver(max_to_keep=10) saver_best_dice = tf.train.Saver(max_to_keep=3) # ================================================================ # create session # ================================================================ sess = tf.Session() # ================================================================ # create a summary writer # ================================================================ summary_writer = tf.summary.FileWriter(log_dir, 
sess.graph) # ================================================================ # summaries of the validation errors # ================================================================ vl_error = tf.placeholder(tf.float32, shape=[], name='vl_error') vl_error_summary = tf.summary.scalar('validation/loss', vl_error) vl_dice = tf.placeholder(tf.float32, shape=[], name='vl_dice') vl_dice_summary = tf.summary.scalar('validation/dice', vl_dice) vl_summary = tf.summary.merge([vl_error_summary, vl_dice_summary]) # ================================================================ # summaries of the training errors # ================================================================ tr_error = tf.placeholder(tf.float32, shape=[], name='tr_error') tr_error_summary = tf.summary.scalar('training/loss', tr_error) tr_dice = tf.placeholder(tf.float32, shape=[], name='tr_dice') tr_dice_summary = tf.summary.scalar('training/dice', tr_dice) tr_summary = tf.summary.merge([tr_error_summary, tr_dice_summary]) # ================================================================ # freeze the graph before execution # ================================================================ logging.info('Freezing the graph now!') tf.get_default_graph().finalize() # ================================================================ # Run the Op to initialize the variables. 
# ================================================================ logging.info( '============================================================') logging.info('initializing all variables...') sess.run(init_ops) # ================================================================ # print names of all variables # ================================================================ logging.info( '============================================================') logging.info('This is the list of all variables:') for v in tf.trainable_variables(): print(v.name) # ================================================================ # print names of uninitialized variables # ================================================================ logging.info( '============================================================') logging.info('This is the list of uninitialized variables:') uninit_variables = sess.run(uninit_vars) for v in uninit_variables: print(v) # ================================================================ # continue run from a saved checkpoint # ================================================================ if continue_run: # Restore session logging.info( '============================================================') logging.info('Restroring session from: %s' % init_checkpoint_path) saver.restore(sess, init_checkpoint_path) # ================================================================ # ================================================================ step = init_step curr_lr = exp_config.learning_rate best_dice = 0 # ================================================================ # run training epochs # ================================================================ while (step < exp_config.max_steps): if step % 1000 is 0: logging.info( '============================================================' ) logging.info('step %d' % step) # ================================================ # batches # ================================================ for batch in 
iterate_minibatches(imtr, gttr, batch_size=exp_config.batch_size, train_or_eval='train'): curr_lr = exp_config.learning_rate start_time = time.time() x, y = batch # =========================== # avoid incomplete batches # =========================== if y.shape[0] < exp_config.batch_size: step += 1 continue # =========================== # create the feed dict for this training iteration # =========================== feed_dict = { images_pl: x, labels_pl: y, learning_rate_pl: curr_lr, training_pl: True } # =========================== # opt step # =========================== _, loss = sess.run([train_op, loss_op], feed_dict=feed_dict) # =========================== # compute the time for this mini-batch computation # =========================== duration = time.time() - start_time # =========================== # write the summaries and print an overview fairly often # =========================== if (step + 1) % exp_config.summary_writing_frequency == 0: logging.info( 'Step %d: loss = %.3f (%.3f sec for the last step)' % (step + 1, loss, duration)) # =========================== # Update the events file # =========================== summary_str = sess.run(summary, feed_dict=feed_dict) summary_writer.add_summary(summary_str, step) summary_writer.flush() # =========================== # Compute the loss on the entire training set # =========================== if step % exp_config.train_eval_frequency == 0: logging.info('Training Data Eval:') train_loss, train_dice = do_eval(sess, eval_loss, images_pl, labels_pl, training_pl, imtr, gttr, exp_config.batch_size) tr_summary_msg = sess.run(tr_summary, feed_dict={ tr_error: train_loss, tr_dice: train_dice }) summary_writer.add_summary(tr_summary_msg, step) # =========================== # Save a checkpoint periodically # =========================== if step % exp_config.save_frequency == 0: checkpoint_file = os.path.join(log_dir, 'models/model.ckpt') saver.save(sess, checkpoint_file, global_step=step) # =========================== # 
Evaluate the model periodically on a validation set # =========================== if step % exp_config.val_eval_frequency == 0: logging.info('Validation Data Eval:') val_loss, val_dice = do_eval(sess, eval_loss, images_pl, labels_pl, training_pl, imvl, gtvl, exp_config.batch_size) vl_summary_msg = sess.run(vl_summary, feed_dict={ vl_error: val_loss, vl_dice: val_dice }) summary_writer.add_summary(vl_summary_msg, step) # =========================== # save model if the val dice is the best yet # =========================== if val_dice > best_dice: best_dice = val_dice best_file = os.path.join(log_dir, 'models/best_dice.ckpt') saver_best_dice.save(sess, best_file, global_step=step) logging.info( 'Found new average best dice on validation sets! - %f - Saving model.' % val_dice) step += 1 sess.close()
def main():
    """Predict segmentations for a few test subjects and save sample slices as PNGs.

    The test dataset is ``exp_config.test_dataset``. For each processed
    subject, the prediction made at the pre-processed resolution (and the
    normalized image returned alongside it) is mapped back to the original
    in-plane resolution and size with ``rescale_and_crop``. All volumes are
    then cropped/padded to 256 slices along z, and slice index 120 along the
    last axis of each is saved: the original image, the normalized image, the
    ground-truth labels and the predicted labels. The files are written to
    ``<log_dir>/<dataset>_test_<subject>slice120_*.png``.

    Raises:
        ValueError: if ``exp_config.test_dataset`` is not one of 'HCPT1',
            'HCPT2', 'CALTECH' or 'STANFORD'.
    """
    # ===================================
    # read the test images
    # ===================================
    test_dataset_name = exp_config.test_dataset

    # Dataset names are compared with `==`: `is` checks object identity and
    # only "works" when both strings happen to be interned.
    if test_dataset_name == 'HCPT1':
        logging.info('Reading HCPT1 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        image_depth = exp_config.image_depth_hcp
        idx_start = 50
        idx_end = 70
        data_brain_test = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=idx_start,
            idx_end=idx_end,
            protocol='T1',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)

    elif test_dataset_name == 'HCPT2':
        logging.info('Reading HCPT2 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        image_depth = exp_config.image_depth_hcp
        idx_start = 50
        idx_end = 70
        data_brain_test = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=idx_start,
            idx_end=idx_end,
            protocol='T2',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)

    elif test_dataset_name == 'CALTECH':
        logging.info('Reading CALTECH images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_abide + 'CALTECH/')
        image_depth = exp_config.image_depth_caltech
        idx_start = 16
        idx_end = 36
        data_brain_test = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='CALTECH',
            idx_start=idx_start,
            idx_end=idx_end,
            protocol='T1',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)

    elif test_dataset_name == 'STANFORD':
        logging.info('Reading STANFORD images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_abide + 'STANFORD/')
        image_depth = exp_config.image_depth_stanford
        idx_start = 16
        idx_end = 36
        data_brain_test = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='STANFORD',
            idx_start=idx_start,
            idx_end=idx_end,
            protocol='T1',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)

    else:
        # Fail loudly here instead of with an UnboundLocalError further down.
        raise ValueError('Unknown test dataset: %s' % test_dataset_name)

    imts = data_brain_test['images']
    name_test_subjects = data_brain_test['patnames']
    ids = np.arange(idx_start, idx_end)
    orig_data_res_x = data_brain_test['px'][:]
    orig_data_res_y = data_brain_test['py'][:]
    orig_data_siz_x = data_brain_test['nx'][:]
    orig_data_siz_y = data_brain_test['ny'][:]
    orig_data_siz_z = data_brain_test['nz'][:]
    # The header arrays hold one entry per loaded subject (they are indexed
    # with sub_num below), so their length is the number of test subjects.
    num_test_subjects = orig_data_siz_z.shape[0]

    # ================================
    # set the log directory
    # ================================
    if exp_config.normalize is True:
        log_dir = os.path.join(sys_config.log_root, exp_config.expname_normalizer)
    else:
        log_dir = sys_config.log_root + 'i2l_mapper/' + exp_config.expname_i2l

    # ================================================================
    # For each test subject, predict a segmentation and save sample slices.
    # Only the first 5 subjects are processed; the limit is clamped so that a
    # smaller test set does not raise an IndexError.
    # ================================================================
    for sub_num in range(min(5, num_test_subjects)):

        # The subject's slices are stored contiguously in imts; its slice
        # range starts after the slices of all preceding subjects.
        subject_id_start_slice = np.sum(orig_data_siz_z[:sub_num])
        subject_id_end_slice = np.sum(orig_data_siz_z[:sub_num + 1])
        image = imts[subject_id_start_slice:subject_id_end_slice, :, :]

        # ==================================================================
        # setup logging
        # ==================================================================
        logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
        # presumably patnames are stored as bytes: [2:-1] strips the "b'...'"
        # wrapper that str() adds -- TODO confirm against the data loader
        subject_name = str(name_test_subjects[sub_num])[2:-1]
        logging.info('============================================================')
        logging.info('Subject id: %s' % sub_num)

        # ==================================================================
        # predict segmentation at the pre-processed resolution
        # ==================================================================
        predicted_labels, normalized_image = predict_segmentation(
            subject_name, image, exp_config.normalize)

        # ==================================================================
        # read the original image and segmentation mask
        # (the dataset name was validated when loading the test images)
        # ==================================================================
        num_rotations = 0
        if test_dataset_name == 'HCPT1':
            # image will be normalized to [0,1]
            image_orig, labels_orig = data_hcp.load_without_size_preprocessing(
                input_folder=sys_config.orig_data_root_hcp,
                idx=ids[sub_num],
                protocol='T1',
                preprocessing_folder=sys_config.preproc_folder_hcp,
                depth=image_depth)

        elif test_dataset_name == 'HCPT2':
            # image will be normalized to [0,1]
            image_orig, labels_orig = data_hcp.load_without_size_preprocessing(
                input_folder=sys_config.orig_data_root_hcp,
                idx=ids[sub_num],
                protocol='T2',
                preprocessing_folder=sys_config.preproc_folder_hcp,
                depth=image_depth)

        elif test_dataset_name == 'CALTECH':
            # image will be normalized to [0,1]
            image_orig, labels_orig = data_abide.load_without_size_preprocessing(
                input_folder=sys_config.orig_data_root_abide,
                site_name='CALTECH',
                idx=ids[sub_num],
                depth=image_depth)

        elif test_dataset_name == 'STANFORD':
            # image will be normalized to [0,1]
            image_orig, labels_orig = data_abide.load_without_size_preprocessing(
                input_folder=sys_config.orig_data_root_abide,
                site_name='STANFORD',
                idx=ids[sub_num],
                depth=image_depth)

        # ==================================================================
        # convert the predictions back to the original resolution and size
        # (order 0 for the labels -- presumably nearest-neighbour so that
        # label values stay integral; order 1 for the image)
        # ==================================================================
        predicted_labels_orig_res_and_size = rescale_and_crop(
            predicted_labels,
            orig_data_res_x[sub_num],
            orig_data_res_y[sub_num],
            orig_data_siz_x[sub_num],
            orig_data_siz_y[sub_num],
            order_interpolation=0,
            num_rotations=num_rotations)

        normalized_image_orig_res_and_size = rescale_and_crop(
            normalized_image,
            orig_data_res_x[sub_num],
            orig_data_res_y[sub_num],
            orig_data_siz_x[sub_num],
            orig_data_siz_y[sub_num],
            order_interpolation=1,
            num_rotations=num_rotations)

        # ================================================================
        # save sample results
        # ================================================================
        x_true = utils.crop_or_pad_volume_to_size_along_z(image_orig, 256)
        z_true = utils.crop_or_pad_volume_to_size_along_z(labels_orig, 256)
        x_norm = utils.crop_or_pad_volume_to_size_along_z(
            normalized_image_orig_res_and_size, 256)
        z_pred = utils.crop_or_pad_volume_to_size_along_z(
            predicted_labels_orig_res_and_size, 256)

        basepath = log_dir + '/' + test_dataset_name + '_' + 'test' + '_' + subject_name
        # np.arange(120, 130, 10) yields the single slice index 120.
        for zz in np.arange(120, 130, 10):
            utils_vis.save_single_image(
                x_true[:, :, zz], basepath + 'slice' + str(zz) + '_x_true.png',
                15, False, 'gray', False)
            utils_vis.save_single_image(
                x_norm[:, :, zz], basepath + 'slice' + str(zz) + '_x_norm.png',
                15, False, 'gray', False, climits=[-1.0, 2.0])
            utils_vis.save_single_image(
                z_true[:, :, zz], basepath + 'slice' + str(zz) + '_z_true.png',
                15, True, 'tab20', False)
            utils_vis.save_single_image(
                z_pred[:, :, zz], basepath + 'slice' + str(zz) + '_z_pred.png',
                15, True, 'tab20', False)
def main():
    """Evaluate predicted segmentations against the original ground truth.

    The evaluated dataset is ``exp_config.test_dataset`` when
    ``exp_config.evaluate_td`` is True, otherwise ``exp_config.train_dataset``.

    For each evaluated subject:

    - the prediction made at the pre-processed resolution is mapped back to
      the original in-plane resolution and size;
    - per-label Dice scores (and, if enabled, Hausdorff distances) are
      computed against the original labels;
    - a sample-prediction figure is saved;
    - the scores are written to ``<log_dir>/<dataset>_test.txt``.

    Per-label and overall statistics over all evaluated subjects are appended
    to the same file at the end.

    Raises:
        ValueError: if the dataset name is not one of 'HCPT1', 'HCPT2',
            'CALTECH', 'STANFORD', 'NCI' or 'PIRAD_ERC'.
    """

    def _mean_std(values):
        # Format the mean and std of `values` as "mean, std", each rounded to
        # 3 decimals (the format used throughout the results file).
        return (str(np.round(np.mean(values), 3)) + ", "
                + str(np.round(np.std(values), 3)))

    # ===================================
    # choose the dataset to evaluate on
    # ===================================
    if exp_config.evaluate_td is True:
        test_dataset_name = exp_config.test_dataset
    else:
        test_dataset_name = exp_config.train_dataset

    # ===================================
    # read the test images
    # ===================================
    # The NCI loader returns several splits and stores the test split under
    # keys ending in '_test'; the other loaders use the plain key names.
    key_suffix = ''
    # Indices of the loaded subjects within their dataset. For NCI these are
    # 0 .. num_test_subjects - 1 and are filled in once that count is known.
    ids = None

    # Dataset names are compared with `==`: `is` checks object identity and
    # only "works" when both strings happen to be interned.
    if test_dataset_name == 'HCPT1':
        logging.info('Reading HCPT1 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        image_depth = exp_config.image_depth_hcp
        idx_start = 50
        idx_end = 70
        ids = np.arange(idx_start, idx_end)
        data_test = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=idx_start,
            idx_end=idx_end,
            protocol='T1',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution)

    elif test_dataset_name == 'HCPT2':
        logging.info('Reading HCPT2 images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        image_depth = exp_config.image_depth_hcp
        idx_start = 50
        idx_end = 70
        ids = np.arange(idx_start, idx_end)
        data_test = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=idx_start,
            idx_end=idx_end,
            protocol='T2',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution)

    elif test_dataset_name == 'CALTECH':
        logging.info('Reading CALTECH images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_abide + 'CALTECH/')
        image_depth = exp_config.image_depth_caltech
        idx_start = 16
        idx_end = 36
        ids = np.arange(idx_start, idx_end)
        data_test = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='CALTECH',
            idx_start=idx_start,
            idx_end=idx_end,
            protocol='T1',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution)

    elif test_dataset_name == 'STANFORD':
        # The ground-truth loading below already handled STANFORD, but the
        # test images were never loaded for it; mirror the CALTECH branch.
        logging.info('Reading STANFORD images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_abide + 'STANFORD/')
        image_depth = exp_config.image_depth_stanford
        idx_start = 16
        idx_end = 36
        ids = np.arange(idx_start, idx_end)
        data_test = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='STANFORD',
            idx_start=idx_start,
            idx_end=idx_end,
            protocol='T1',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution)

    elif test_dataset_name == 'NCI':
        key_suffix = '_test'
        data_test = data_nci.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_nci,
            preprocessing_folder=sys_config.preproc_folder_nci,
            size=exp_config.image_size,
            target_resolution=exp_config.target_resolution,
            force_overwrite=False,
            cv_fold_num=1)

    elif test_dataset_name == 'PIRAD_ERC':
        idx_start = 0
        idx_end = 20
        ids = np.arange(idx_start, idx_end)
        data_test = data_pirad_erc.load_data(
            input_folder=sys_config.orig_data_root_pirad_erc,
            preproc_folder=sys_config.preproc_folder_pirad_erc,
            idx_start=idx_start,
            idx_end=idx_end,
            size=exp_config.image_size,
            target_resolution=exp_config.target_resolution,
            labeller='ek')

    else:
        # Fail loudly here instead of with an UnboundLocalError further down.
        raise ValueError('Unknown test dataset: %s' % test_dataset_name)

    imts = data_test['images' + key_suffix]
    name_test_subjects = data_test['patnames' + key_suffix]
    orig_data_res_x = data_test['px' + key_suffix][:]
    orig_data_res_y = data_test['py' + key_suffix][:]
    orig_data_siz_x = data_test['nx' + key_suffix][:]
    orig_data_siz_y = data_test['ny' + key_suffix][:]
    orig_data_siz_z = data_test['nz' + key_suffix][:]
    # The header arrays hold one entry per loaded subject (they are indexed
    # with sub_num below), so their length is the number of test subjects.
    num_test_subjects = orig_data_siz_z.shape[0]
    if ids is None:
        ids = np.arange(num_test_subjects)

    # ================================
    # set the log directory
    # ================================
    if exp_config.normalize is True:
        log_dir = os.path.join(sys_config.log_root, exp_config.expname_normalizer)
    elif exp_config.uda is False:
        log_dir = sys_config.log_root + exp_config.expname_i2l
    else:
        log_dir = sys_config.log_root + exp_config.expname_uda

    # ================================
    # Open a text file for the per-subject scores and the summary statistics.
    # `with` guarantees the file is closed even if evaluation fails midway.
    # ================================
    with open(log_dir + '/' + test_dataset_name + '_' + 'test' + '.txt', "w") as results_file:
        results_file.write("================================== \n")
        results_file.write("Test results \n")

        dice_per_label_per_subject = []
        hsd_per_label_per_subject = []

        # ================================================================
        # For each test subject, predict, map back to the original space and
        # score. Only the first 5 subjects are evaluated; the limit is clamped
        # so that a smaller test set does not raise an IndexError.
        # ================================================================
        for sub_num in range(min(5, num_test_subjects)):

            # The subject's slices are stored contiguously in imts; its slice
            # range starts after the slices of all preceding subjects.
            subject_id_start_slice = np.sum(orig_data_siz_z[:sub_num])
            subject_id_end_slice = np.sum(orig_data_siz_z[:sub_num + 1])
            image = imts[subject_id_start_slice:subject_id_end_slice, :, :]

            # ==================================================================
            # setup logging
            # ==================================================================
            logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
            # presumably patnames are stored as bytes: [2:-1] strips the "b'...'"
            # wrapper that str() adds -- TODO confirm against the data loaders
            subject_name = str(name_test_subjects[sub_num])[2:-1]
            logging.info('============================================================')
            logging.info('Subject id: %s' % sub_num)

            # ==================================================================
            # predict segmentation at the pre-processed resolution
            # ==================================================================
            predicted_labels, normalized_image = predict_segmentation(
                subject_name, image, exp_config.normalize)

            # ==================================================================
            # read the original image and segmentation mask
            # (the dataset name was validated when loading the test images)
            # ==================================================================
            num_rotations = 0
            if test_dataset_name == 'HCPT1':
                # image will be normalized to [0,1]
                image_orig, labels_orig = data_hcp.load_without_size_preprocessing(
                    input_folder=sys_config.orig_data_root_hcp,
                    idx=ids[sub_num],
                    protocol='T1',
                    preprocessing_folder=sys_config.preproc_folder_hcp,
                    depth=image_depth)

            elif test_dataset_name == 'HCPT2':
                # image will be normalized to [0,1]
                image_orig, labels_orig = data_hcp.load_without_size_preprocessing(
                    input_folder=sys_config.orig_data_root_hcp,
                    idx=ids[sub_num],
                    protocol='T2',
                    preprocessing_folder=sys_config.preproc_folder_hcp,
                    depth=image_depth)

            elif test_dataset_name == 'CALTECH':
                # image will be normalized to [0,1]
                image_orig, labels_orig = data_abide.load_without_size_preprocessing(
                    input_folder=sys_config.orig_data_root_abide,
                    site_name='CALTECH',
                    idx=ids[sub_num],
                    depth=image_depth)

            elif test_dataset_name == 'STANFORD':
                # image will be normalized to [0,1]
                image_orig, labels_orig = data_abide.load_without_size_preprocessing(
                    input_folder=sys_config.orig_data_root_abide,
                    site_name='STANFORD',
                    idx=ids[sub_num],
                    depth=image_depth)

            elif test_dataset_name == 'NCI':
                # image will be normalized to [0,1]
                image_orig, labels_orig = data_nci.load_without_size_preprocessing(
                    sys_config.orig_data_root_nci,
                    cv_fold_num=1,
                    train_test='test',
                    idx=ids[sub_num])

            elif test_dataset_name == 'PIRAD_ERC':
                # image will be normalized to [0,1]
                image_orig, labels_orig = data_pirad_erc.load_without_size_preprocessing(
                    sys_config.orig_data_root_pirad_erc,
                    ids[sub_num],
                    labeller='ek')
                num_rotations = -3

            # ==================================================================
            # convert the predictions back to the original resolution and size
            # (order 0 for the labels -- presumably nearest-neighbour so that
            # label values stay integral; order 1 for the image)
            # ==================================================================
            predicted_labels_orig_res_and_size = rescale_and_crop(
                predicted_labels,
                orig_data_res_x[sub_num],
                orig_data_res_y[sub_num],
                orig_data_siz_x[sub_num],
                orig_data_siz_y[sub_num],
                order_interpolation=0,
                num_rotations=num_rotations)

            normalized_image_orig_res_and_size = rescale_and_crop(
                normalized_image,
                orig_data_res_x[sub_num],
                orig_data_res_y[sub_num],
                orig_data_siz_x[sub_num],
                orig_data_siz_y[sub_num],
                order_interpolation=1,
                num_rotations=num_rotations)

            # ==================================================================
            # If only whole-gland comparisons are desired, merge all foreground
            # labels into label 1, both in the ground truth and the prediction.
            # ==================================================================
            if exp_config.whole_gland_results is True:
                predicted_labels_orig_res_and_size[predicted_labels_orig_res_and_size != 0] = 1
                labels_orig[labels_orig != 0] = 1
                nl = 2
                savepath = log_dir + '/' + test_dataset_name + '_test_' + subject_name + '_whole_gland.png'
            else:
                nl = exp_config.nlabels
                savepath = log_dir + '/' + test_dataset_name + '_test_' + subject_name + '.png'

            # ==================================================================
            # compute dice at the original resolution
            # labels=np.arange(nl) pins one column per label 0..nl-1. Without it
            # f1_score only reports labels present in either volume, so a
            # missing label would shift the columns (breaking the "label i"
            # rows of the report) and make the per-subject arrays ragged.
            # A label absent from both volumes scores 0 (sklearn's default).
            # ==================================================================
            dice_per_label_this_subject = met.f1_score(
                labels_orig.flatten(),
                predicted_labels_orig_res_and_size.flatten(),
                labels=np.arange(nl),
                average=None)

            # ==================================================================
            # compute Hausdorff distance at the original resolution
            # (switched off here; zeros are reported instead)
            # ==================================================================
            compute_hsd = False
            if compute_hsd is True:
                hsd_per_label_this_subject = utils.compute_surface_distance(
                    y1=labels_orig,
                    y2=predicted_labels_orig_res_and_size,
                    nlabels=exp_config.nlabels)
            else:
                hsd_per_label_this_subject = np.zeros((exp_config.nlabels))

            # ================================================================
            # save sample results
            # ================================================================
            d_vis = 32  # volumes are cropped/padded to this many slices along z
            ids_vis = np.arange(0, 32, 4)  # presumably the slices drawn in the figure
            utils_vis.save_sample_prediction_results(
                x=utils.crop_or_pad_volume_to_size_along_z(image_orig, d_vis),
                x_norm=utils.crop_or_pad_volume_to_size_along_z(normalized_image_orig_res_and_size, d_vis),
                y_pred=utils.crop_or_pad_volume_to_size_along_z(predicted_labels_orig_res_and_size, d_vis),
                gt=utils.crop_or_pad_volume_to_size_along_z(labels_orig, d_vis),
                num_rotations=-num_rotations,  # rotate for consistent visualization across datasets
                savepath=savepath,
                nlabels=nl,
                ids=ids_vis)

            # ================================
            # write the mean fg dice (and hausdorff distance) of this subject
            # ================================
            results_file.write(subject_name + ":: dice (mean, std over all FG labels): ")
            results_file.write(_mean_std(dice_per_label_this_subject[1:]))
            results_file.write(", hausdorff distance (mean, std over all FG labels): ")
            # Bug fix: the std used to be computed from the dice scores.
            results_file.write(_mean_std(hsd_per_label_this_subject) + "\n")

            dice_per_label_per_subject.append(dice_per_label_this_subject)
            hsd_per_label_per_subject.append(hsd_per_label_this_subject)

        # ================================================================
        # write per label statistics over all subjects
        # rows are subjects, columns are labels
        # ================================================================
        dice_per_label_per_subject = np.array(dice_per_label_per_subject)
        hsd_per_label_per_subject = np.array(hsd_per_label_per_subject)

        results_file.write("================================== \n")
        results_file.write("Label: dice mean, std. deviation over all subjects\n")
        for i in range(dice_per_label_per_subject.shape[1]):
            results_file.write(str(i) + ": " + _mean_std(dice_per_label_per_subject[:, i]) + "\n")

        results_file.write("================================== \n")
        results_file.write("Label: hausdorff distance mean, std. deviation over all subjects\n")
        # NOTE(review): hausdorff rows are numbered from 1 while dice rows start
        # at 0; confirm which labels compute_surface_distance returns.
        for i in range(hsd_per_label_per_subject.shape[1]):
            results_file.write(str(i + 1) + ": " + _mean_std(hsd_per_label_per_subject[:, i]) + "\n")

        # ==================
        # write the mean dice over all subjects and all labels
        # ==================
        results_file.write("================================== \n")
        results_file.write("DICE Mean, std. deviation over foreground labels over all subjects: "
                           + _mean_std(dice_per_label_per_subject[:, 1:]) + "\n")
        results_file.write("HSD Mean, std. deviation over labels over all subjects: "
                           + _mean_std(hsd_per_label_per_subject) + "\n")
        results_file.write("================================== \n")
def main():
    """Segment every test subject and save per-slice visualisations.

    The test dataset is selected by ``exp_config.test_dataset`` and must be
    one of 'HCPT1', 'HCPT2' or 'CALTECH'. For each subject:

    1. the pre-processed volume is passed to ``predict_segmentation``;
    2. its five outputs (prediction, normalised image, DAE-denoised labels,
       TTA prediction, TTA normalised image) are mapped back to the original
       in-plane resolution/size with ``rescale_and_crop``;
    3. a few slices of the original image/labels and of each output are
       written as PNGs to
       ``<log_root>/<expname_normalizer>/subject_<name>/results/``.

    Raises:
        ValueError: if ``exp_config.test_dataset`` is not supported.
    """
    # Configure logging before the first logging.info call: the module-level
    # logging.info() installs a default (WARNING-level) handler when none
    # exists, after which basicConfig() no longer has any effect.
    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')

    test_dataset_name = exp_config.test_dataset

    # ===================================
    # read the test images (pre-processed to the training resolution)
    # ===================================
    if test_dataset_name in ('HCPT1', 'HCPT2'):
        protocol = 'T1' if test_dataset_name == 'HCPT1' else 'T2'
        logging.info('Reading ' + test_dataset_name + ' images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_hcp)
        image_depth = exp_config.image_depth_hcp
        idx_start = 50
        idx_end = 70
        data_brain_test = data_hcp.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_hcp,
            preprocessing_folder=sys_config.preproc_folder_hcp,
            idx_start=idx_start,
            idx_end=idx_end,
            protocol=protocol,
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)
    elif test_dataset_name == 'CALTECH':
        logging.info('Reading CALTECH images...')
        logging.info('Data root directory: ' + sys_config.orig_data_root_abide + 'CALTECH/')
        image_depth = exp_config.image_depth_caltech
        idx_start = 16
        idx_end = 36
        data_brain_test = data_abide.load_and_maybe_process_data(
            input_folder=sys_config.orig_data_root_abide,
            preprocessing_folder=sys_config.preproc_folder_abide,
            site_name='CALTECH',
            idx_start=idx_start,
            idx_end=idx_end,
            protocol='T1',
            size=exp_config.image_size,
            depth=image_depth,
            target_resolution=exp_config.target_resolution_brain)
    else:
        raise ValueError('Unsupported test dataset: %s' % test_dataset_name)

    imts = data_brain_test['images']
    name_test_subjects = data_brain_test['patnames']
    ids = np.arange(idx_start, idx_end)

    # per-subject original in-plane resolution and size, and number of slices
    orig_data_res_x = data_brain_test['px'][:]
    orig_data_res_y = data_brain_test['py'][:]
    orig_data_siz_x = data_brain_test['nx'][:]
    orig_data_siz_y = data_brain_test['ny'][:]
    orig_data_siz_z = data_brain_test['nz'][:]

    # None of the supported datasets needs to be rotated back.
    num_rotations = 0

    # indices along the last axis of the original-resolution volumes to save
    vis_slices = np.arange(80, 120, 10)

    def to_orig_res_and_size(volume, sub_num, order_interpolation):
        """Map one subject's volume back to its original in-plane resolution and size."""
        return rescale_and_crop(volume,
                                orig_data_res_x[sub_num],
                                orig_data_res_y[sub_num],
                                orig_data_siz_x[sub_num],
                                orig_data_siz_y[sub_num],
                                order_interpolation=order_interpolation,
                                num_rotations=num_rotations)

    # ================================================================
    # loop over all test subjects
    # ================================================================
    for sub_num in range(len(ids)):

        # The pre-processed volumes of all subjects are stacked along axis 0;
        # 'nz' gives the number of slices belonging to each subject.
        start_slice = int(np.sum(orig_data_siz_z[:sub_num]))
        end_slice = int(np.sum(orig_data_siz_z[:sub_num + 1]))
        image = imts[start_slice:end_slice, :, :]

        # NOTE(review): patnames entries are presumably bytes (e.g. b'100307');
        # [2:-1] strips the b'...' wrapper from their str() representation.
        subject_name = str(name_test_subjects[sub_num])[2:-1]
        logging.info('============================================================')
        logging.info('Subject id: %s', sub_num)

        # ==================================================================
        # predict segmentation at the pre-processed resolution
        # ==================================================================
        (predicted_labels,
         normalized_image,
         denoised_labels,
         predicted_labels_tta,
         normalized_image_tta) = predict_segmentation(subject_name, image, exp_config.normalize)

        # ==================================================================
        # read the original image and segmentation mask
        # (the image is normalized to [0,1] by the loaders)
        # ==================================================================
        if test_dataset_name in ('HCPT1', 'HCPT2'):
            image_orig, labels_orig = data_hcp.load_without_size_preprocessing(
                input_folder=sys_config.orig_data_root_hcp,
                idx=ids[sub_num],
                protocol=protocol,
                preprocessing_folder=sys_config.preproc_folder_hcp,
                depth=image_depth)
        else:  # 'CALTECH' (validated above)
            image_orig, labels_orig = data_abide.load_without_size_preprocessing(
                input_folder=sys_config.orig_data_root_abide,
                site_name='CALTECH',
                idx=ids[sub_num],
                depth=image_depth)

        # ==================================================================
        # convert the predictions back to original resolution
        # order_interpolation 0 is used for label maps, 1 for intensity images
        # ==================================================================
        y_pred = to_orig_res_and_size(predicted_labels, sub_num, 0)
        x_norm = to_orig_res_and_size(normalized_image, sub_num, 1)
        y_denoised = to_orig_res_and_size(denoised_labels, sub_num, 0)
        y_pred_tta = to_orig_res_and_size(predicted_labels_tta, sub_num, 0)
        x_norm_tta = to_orig_res_and_size(normalized_image_tta, sub_num, 1)

        # ================================================================
        # save sample results
        # ================================================================
        basepath = os.path.join(sys_config.log_root, exp_config.expname_normalizer) + '/subject_' + subject_name + '/results/'
        # make sure the output folder exists before writing the PNGs
        os.makedirs(basepath, exist_ok=True)

        # (volume, file suffix, 4th save_single_image argument, colormap):
        # the 4th argument is True for label maps and False for images here.
        panels = [
            (image_orig, '_x_true.png', False, 'gray'),
            (x_norm, '_x_norm.png', False, 'gray'),
            (x_norm_tta, '_x_norm_tta.png', False, 'gray'),
            (labels_orig, '_y_true.png', True, 'tab20'),
            (y_pred, '_y_pred.png', True, 'tab20'),
            (y_pred_tta, '_y_pred_tta.png', True, 'tab20'),
            (y_denoised, '_y_pred_dae.png', True, 'tab20'),
        ]
        for zz in vis_slices:
            for volume, suffix, label_flag, cmap in panels:
                utils_vis.save_single_image(volume[:, :, zz],
                                            basepath + 'slice' + str(zz) + suffix,
                                            15, label_flag, cmap, False)