def lrp_UNet3D_overfit_visualizer(config_data):
    """Load the saved LRP output of one training case and display it.

    The file ``yR_training_<case_number>.lrp`` is read from
    ``<working_dir>/<relative_checkpoint_dir>/<model_label_name>/output_lrp_<normalization>``
    and unpickled into the pair ``(y, Rc)``. The matching ISLES2017 case is
    loaded and ``SlidingVisualizer.vis_lrp_training`` is run twice through a
    two-worker pool, once with trailing argument 1 and once with 2.

    Args:
        config_data: nested configuration dict. Keys read here:
            ``['LRP']['relprop_config']['normalization']``,
            ``['misc']['case_number']``, ``['data_modalities']``,
            ``['working_dir']``, ``['relative_checkpoint_dir']``,
            ``['model_label_name']`` and
            ``['data_directory']['dir_ISLES2017']``.
    """
    print("lrp_UNet3D_overfit_visualizer()")

    normalization_mode = str(config_data['LRP']['relprop_config']['normalization'])
    case_type = 'training'
    case_number = config_data['misc']['case_number']
    data_modalities = config_data['data_modalities']

    model_dir = os.path.join(config_data['working_dir'],
                             config_data['relative_checkpoint_dir'],
                             config_data['model_label_name'])
    lrp_dir = os.path.join(model_dir, 'output_lrp_' + normalization_mode)
    lrp_fullpath = os.path.join(lrp_dir, 'yR_' + case_type + '_' + str(case_number) + '.lrp')

    print(" model_dir:%s"%(str(model_dir)))
    print(" data_modalities:%s"%(str(data_modalities)))
    print(" case_number:%s"%(str(case_number)))
    print(" normalization_mode:%s"%(str(normalization_mode)))

    # NOTE(security): pickle.load can execute arbitrary code; only point this
    # at .lrp files produced by this project.
    # The context manager closes the handle even when unpickling fails.
    with open(lrp_fullpath, 'rb') as pkl_file:
        y, Rc = pickle.load(pkl_file)
    print(" Rc.shape = %s [%s]\n y.shape = %s [%s]"%(str(Rc.shape),str(type(Rc)),str(y.shape),str(type(y))))

    ISLESDATA = ISLES2017mass()
    ISLESDATA.directory_path = config_data['data_directory']['dir_ISLES2017']
    one_case = ISLESDATA.load_one_case(case_type, str(case_number), data_modalities)

    vis = vi.SlidingVisualizer()
    vis.do_show = True
    # Single-process alternative: vis.vis2_lrp_training(one_case, y, Rc, data_modalities)

    # Each worker receives one tuple: the shared arguments plus a trailing 1 or 2.
    args = (one_case, y, Rc, data_modalities)
    # NOTE(review): assumes Pool follows the multiprocessing.Pool API
    # (close/join) - confirm against the file's import.
    p = Pool(2)
    try:
        p.map(vis.vis_lrp_training, [args + (1,), args + (2,)])
    finally:
        # Release the workers explicitly instead of relying on garbage collection.
        p.close()
        p.join()
def lrp_UNet3D_filter_sweep_visualizer(config_data):
    """Show a stored LRP filter-sweep dictionary for one training case.

    Reads ``lrp_dict_training_<case_number>.lrpd`` from the
    ``lrp_filter_sweep_mode_<submode>`` folder under the model directory via
    ``vh.load_lrp_sweep``, rebuilds the raw network input for the same case,
    and passes both to ``SlidingVisualizerMultiPanelLRP.vis_filter_sweeper``.

    Args:
        config_data: nested configuration dict (paths, ``['LRP']``,
            ``['misc']['case_number']``, ``['data_modalities']`` ...).
    """
    print('lrp_UNet3D_filter_sweep_visualizer()')

    submode = config_data['LRP']['filter_sweeper']['submode']
    split = 'training'
    case_id = config_data['misc']['case_number']
    _ = config_data['data_modalities']  # looked up as before; the value itself is unused here

    checkpoint_root = os.path.join(
        config_data['working_dir'],
        config_data['relative_checkpoint_dir'],
        config_data['model_label_name'],
    )
    sweep_dir = os.path.join(checkpoint_root, 'lrp_filter_sweep_mode_' + submode)
    print(" lrp_dir:%s"%(str(sweep_dir)))

    from pipeline.lrp import get_modalities_0001
    all_modalities = ['ADC','MTT','rCBF','rCBV' ,'Tmax','TTP','OT']
    modalities_dict, _channel_count = get_modalities_0001(config_data)

    sweep_file = os.path.join(
        sweep_dir, 'lrp_dict_' + split + '_' + str(case_id) + '.lrpd')
    sweep_results = vh.load_lrp_sweep(sweep_file)

    from dataio.dataISLES2017 import ISLES2017mass
    isles = ISLES2017mass()
    isles.directory_path = config_data['data_directory']['dir_ISLES2017']
    loaded_case = isles.load_one_case(split, str(case_id), all_modalities)
    raw_input, _ = isles.load_type0001_get_raw_input(loaded_case, modalities_dict)

    from utils.vis0001 import SlidingVisualizerMultiPanelLRP
    panel = SlidingVisualizerMultiPanelLRP()
    panel.vis_filter_sweeper(config_data, sweep_results, raw_input, case_id)
def training_UNet3D_load_isles2017(case_type, case_numbers, config_data):
    """Load ISLES2017 cases and wrap them in a shuffling DataLoader.

    Args:
        case_type: forwarded to ``ISLES2017mass.load_many_cases``.
        case_numbers: forwarded to ``ISLES2017mass.load_many_cases``.
        config_data: configuration dict; ``['data_directory']['dir_ISLES2017']``
            and ``['basic']['batch_size']`` are read here.

    Returns:
        ``(dataset, trainloader, terminate_signal)``. When the module flag
        ``DEBUG_TRAINING_DATA_LOADING`` is truthy no loader is built:
        ``trainloader`` is ``None`` and ``terminate_signal`` is ``True``.
        Otherwise ``terminate_signal`` is ``False``.
    """
    dataset = ISLES2017mass()
    dataset.directory_path = config_data['data_directory']['dir_ISLES2017']
    dataset.load_many_cases(case_type, case_numbers, config_data)

    # Debug mode: only the raw loading above is exercised.
    if DEBUG_TRAINING_DATA_LOADING:
        return dataset, None, True

    loader = DataLoader(
        dataset=dataset,
        num_workers=0,
        batch_size=config_data['basic']['batch_size'],
        shuffle=True,
    )
    print(" trainloader loaded")
    return dataset, loader, False
def test_data_augmentation(config_data):
    """Compare training case 1 with a rescaled-and-clipped copy side by side.

    Every modality except 'OT' is min-max rescaled to [0, 1] with
    ``normalize_numpy_array`` and then clipped to [0.0, 0.5]; 'OT' is passed
    through untouched. The original and modified cases are handed to
    ``SlidingVisualizer.vis2`` and the figure is shown.

    Args:
        config_data: configuration dict; only
            ``['data_directory']['dir_ISLES2017']`` is read here.
    """
    print("test_data_augmentation")
    case_number = 1
    case_type = 'training'
    canonical_modalities_label = [
        'ADC', 'MTT', 'rCBF', 'rCBV', 'Tmax', 'TTP', 'OT'
    ]

    loader = ISLES2017mass()
    loader.directory_path = config_data['data_directory']['dir_ISLES2017']
    one_case = loader.load_one_case(case_type, str(case_number),
                                    canonical_modalities_label)

    # Reference layout of a loaded case (sample values recorded by the author):
    #   one_case['imgobj']['ADC'].shape:  (192, 192, 19)
    #   one_case['imgobj']['MTT'].shape:  (192, 192, 19)
    #   one_case['imgobj']['rCBF'].shape: (192, 192, 19)
    #   one_case['imgobj']['rCBV'].shape: (192, 192, 19)
    #   one_case['imgobj']['Tmax'].shape: (192, 192, 19)
    #   one_case['imgobj']['TTP'].shape:  (192, 192, 19)
    #   one_case['imgobj']['OT'].shape:   (192, 192, 19)
    #   one_case['case_number']: 4
    #   type(one_case['header']): <class 'nibabel.nifti1.Nifti1Header'>
    #   one_case['affine'].shape: (4, 4)
    #   one_case['MRS']: 4
    #   one_case['ttMRS']: 90
    #   one_case['TICI']: 3
    #   one_case['TSS']: 116
    #   one_case['TTT']: 93
    #   one_case['smir_id']: 127206

    modified_volumes = {}
    for modality in canonical_modalities_label:
        if modality != 'OT':
            volume = one_case['imgobj'][modality]
            rescaled = normalize_numpy_array(
                volume,
                target_min=0,
                target_max=1,
                source_min=np.min(volume),
                source_max=np.max(volume),
            )
            modified_volumes[modality] = np.clip(rescaled, 0.0, 0.5)
    # The lesion mask is carried over unchanged.
    modified_volumes['OT'] = one_case['imgobj']['OT']
    one_case_modified = {'imgobj': modified_volumes}

    viewer = vi.SlidingVisualizer()
    viewer.vis2(one_case, one_case_modified)
    plt.show()
def test_load_many_ISLES2017(config_data):
    """Smoke-test bulk loading of training cases 1..48 into a DataLoader.

    Sets ``config_data['batch_size']`` to 4 (the caller's dict is modified),
    loads the cases with ``load_many_cases_type0003(..., normalize=True)`` and
    builds a shuffling single-worker ``DataLoader`` over the dataset.

    Args:
        config_data: configuration dict; ``['data_directory']['dir_ISLES2017']``
            is read and ``['batch_size']`` is overwritten.
    """
    print("test_load_many_ISLES2017")
    config_data['batch_size'] = 4

    dataset = ISLES2017mass()
    dataset.directory_path = config_data['data_directory']['dir_ISLES2017']
    dataset.load_many_cases_type0003(
        'training', range(1, 49), config_data, normalize=True)

    trainloader = DataLoader(
        dataset=dataset,
        num_workers=1,
        batch_size=config_data['batch_size'],
        shuffle=True,
    )
    print(" trainloader loaded!")
def visual_diff_gen_0001(config_data, visual_config=None):
    """Plot an ISLES2017 slice, a DG3D synthetic defect, and their mixture.

    Draws a 3x2 figure: the left column shows one channel of the ISLES volume,
    of the generated ``x_unhealthy`` volume and of their clamped sum; the
    right column shows the corresponding lesion masks.

    Args:
        config_data: configuration dict. ``['dataloader']['resize']`` is
            overwritten with ``[192, 192, 19]`` (the caller's dict is
            modified) and ``['data_directory']['dir_ISLES2017']`` is read.
        visual_config: optional dict overriding any of ``'case_numbers'``,
            ``'DEPTH_INDEX'``, ``'NOW_SHOW'``, ``'CHANNEL_SELECTION'`` and
            ``'defect_fraction'``. Missing keys fall back to the defaults
            below.
    """
    print('visual_diff_gen_0001(). HARDCODED VARIABLES? YES')
    case_type = 'training'
    canonical_modalities_dict = {
        0: 'ADC',
        1: 'MTT',
        2: 'rCBF',
        3: 'rCBV',
        4: 'Tmax',
        5: 'TTP',
        6: 'OT'
    }

    default_visual_config = {
        'case_numbers': [16],  # [1,4,8,],
        'DEPTH_INDEX': 10,
        'NOW_SHOW': 0,
        'CHANNEL_SELECTION': 3,
        'defect_fraction': 1.2
    }
    # Merge instead of replace so a partial override does not raise KeyError.
    visual_config = {**default_visual_config, **(visual_config or {})}

    case_numbers = visual_config['case_numbers']
    DEPTH_INDEX = visual_config['DEPTH_INDEX']
    NOW_SHOW = visual_config['NOW_SHOW']
    CHANNEL_SELECTION = visual_config['CHANNEL_SELECTION']
    defect_fraction = visual_config['defect_fraction']

    config_data['dataloader']['resize'] = [192, 192, 19]

    ISLESDATA = ISLES2017mass()
    ISLESDATA.directory_path = config_data['data_directory']['dir_ISLES2017']
    ISLESDATA.load_many_cases(case_type, case_numbers, config_data)

    # Fetch the sample once so the image and its label always come from the
    # same __getitem__ call (and the work is not done twice).
    sample = ISLESDATA[NOW_SHOW]
    isles_mri = torch.tensor([sample[0]]).permute(0, 1, 4, 3, 2)
    isles_lesion = torch.tensor([sample[1]]).permute(0, 3, 2, 1).to(torch.float)
    print('isles_mri.shape:', isles_mri.shape)
    print('isles_lesion.shape:', isles_lesion.shape)

    from dataio.data_diffgen import DG3D
    dg = DG3D(unit_size=(192, 192), depth=19)
    # this size will be interpolated to 3D (19,192,192)
    _, x_unhealthy, y_lesion, _ = dg.generate_data_batches_in_torch(
        channel_size=6, batch_size=1, resize=config_data['dataloader']['resize'])
    print('x_unhealthy.shape:', x_unhealthy.shape)
    print('y_lesion.shape:', y_lesion.shape)

    isles_mri = isles_mri.cpu()
    isles_lesion = isles_lesion.cpu()
    x_unhealthy = x_unhealthy.cpu()
    y_lesion = y_lesion.cpu()

    # Overlay the synthetic defect on the real volume; a voxel is lesion in
    # the mixture if it is lesion in either source.
    mixed = isles_mri + defect_fraction * x_unhealthy
    mixed = torch.clamp(mixed, 0, 1)
    mixed_y = (y_lesion.to(torch.float) + isles_lesion.to(torch.float)) > 0
    print('mixed.shape:', mixed.shape)
    print('mixed_y.shape:', mixed_y.shape)

    # Reverse the last three axes again and drop singleton dims for plotting.
    isles_mri = isles_mri.permute(0, 1, 4, 3, 2).squeeze()
    isles_lesion = isles_lesion.permute(0, 3, 2, 1).squeeze()
    x_unhealthy = x_unhealthy.permute(0, 1, 4, 3, 2).squeeze()
    y_lesion = y_lesion.permute(0, 3, 2, 1).squeeze()
    mixed = mixed.permute(0, 1, 4, 3, 2).squeeze()
    mixed_y = mixed_y.permute(0, 3, 2, 1).squeeze()

    cmap = 'inferno'
    fig = plt.figure()

    ax1 = fig.add_subplot(321)
    im1 = ax1.imshow(isles_mri[CHANNEL_SELECTION][:, :, DEPTH_INDEX], cmap=cmap)
    plt.colorbar(im1)
    ax2 = fig.add_subplot(322)
    ax2.imshow(isles_lesion[:, :, DEPTH_INDEX], cmap='Greys')

    ax3 = fig.add_subplot(323)
    im3 = ax3.imshow(x_unhealthy[CHANNEL_SELECTION][:, :, DEPTH_INDEX], cmap=cmap)
    plt.colorbar(im3)
    ax4 = fig.add_subplot(324)
    ax4.imshow(y_lesion[:, :, DEPTH_INDEX], cmap='Greys')

    ax5 = fig.add_subplot(325)
    im5 = ax5.imshow(mixed[CHANNEL_SELECTION][:, :, DEPTH_INDEX], cmap=cmap)
    plt.colorbar(im5)
    ax6 = fig.add_subplot(326)
    ax6.imshow(mixed_y[:, :, DEPTH_INDEX], cmap='Greys')

    this_title = '%s:%s' % (str(case_numbers[NOW_SHOW]),
                            str(canonical_modalities_dict[CHANNEL_SELECTION]))
    ax1.set_title(this_title)
    plt.tight_layout()
    plt.show()
def generic_data_loading(config_data, case_numbers_manual=None, case_type='training', verbose=0):
    '''
    Load ISLES2017 cases as network-ready tensors for evaluation.

    - Assume OT is included
    - Assume PWI is not included

    Each configured modality other than 'OT' is min-max normalised to the
    target range given in config_data['normalization'], resized with
    interp3d to the configured shape and stacked into one channel.

    Args:
        config_data: configuration dict. Keys read: 'data_modalities',
            ['data_directory']['dir_ISLES2017'], ['dataloader']['resize']
            and ['normalization'][<modality> + '_target_min_max'].
        case_numbers_manual: optional iterable of case numbers; when given it
            overrides both the default range and the DEBUG_* overrides.
        case_type: 'training' (default cases 1..48) or 'test' (1..40).
        verbose: >= 50 prints the original shape of every case.

    Returns:
        dict mapping case_number to [x, labels] for 'training' or [x] for
        'test'. x is a float tensor on this_device built from a
        (1, C) + resize_shape array and permuted with (0, 1, 4, 3, 2);
        labels is the 'OT' volume as an int64 tensor. Cases for which
        load_one_case returns None are skipped.

    Raises:
        ValueError: case_type is neither 'training' nor 'test' and
            case_numbers_manual is None.
    '''
    print(" evaluation_header.py.generic_data_loading()")
    canonical_modalities_label = ['ADC','MTT','rCBF','rCBV' ,'Tmax','TTP','OT']

    # Number the channels AFTER dropping 'OT', so the indices are always
    # 0..C-1 no matter where 'OT' appears in config_data['data_modalities'].
    input_modalities = [mod for mod in config_data['data_modalities'] if mod != 'OT']
    modalities_dict = dict(enumerate(input_modalities))
    no_of_input_channels = len(modalities_dict)

    case_numbers = None
    if case_type == 'training':
        case_numbers = range(1, 49)
        if DEBUG_EVAL_TRAINING_CASE_NUMBERS is not None:
            case_numbers = DEBUG_EVAL_TRAINING_CASE_NUMBERS
    elif case_type == 'test':
        case_numbers = range(1, 41)
        if DEBUG_EVAL_TEST_CASE_NUMBERS is not None:
            case_numbers = DEBUG_EVAL_TEST_CASE_NUMBERS
    if case_numbers_manual is not None:
        case_numbers = case_numbers_manual
    if case_numbers is None:
        raise ValueError(
            "case_type must be 'training' or 'test' when case_numbers_manual "
            "is not given; got %r" % (case_type,))

    ISLESDATA = ISLES2017mass()
    ISLESDATA.directory_path = config_data['data_directory']['dir_ISLES2017']

    resize_shape = tuple(config_data['dataloader']['resize'])
    if DEBUG_dataISLES2017_RESIZE_SHAPE is not None:
        resize_shape = DEBUG_dataISLES2017_RESIZE_SHAPE

    for_evaluation = {}
    case_processed = []
    for case_number in case_numbers:
        one_case = ISLESDATA.load_one_case(case_type, str(case_number), canonical_modalities_label)
        if one_case is None:
            continue

        s = None  # original volume shape, used only for the verbose report
        if case_type == 'training':
            s = one_case['imgobj']['OT'].shape
            labels = one_case['imgobj']['OT']
        elif case_type == 'test':
            # No 'OT' is read for test data: report the first volume's shape.
            s = next((vol.shape for vol in one_case['imgobj'].values()), None)

        x1 = np.zeros(shape=(no_of_input_channels,) + resize_shape)
        if verbose >= 50:
            print(" case_number:%4s | original shape:%s"%(str(case_number),str(s)))

        for modality_key, modality_name in modalities_dict.items():
            target_range = config_data['normalization'][modality_name + "_target_min_max"]
            x1_component = one_case['imgobj'][modality_name]
            # The source range is taken from the volume itself (per-case min/max).
            x1_component = normalize_numpy_array(x1_component,
                target_min=target_range[0],
                target_max=target_range[1],
                source_min=np.min(x1_component),
                source_max=np.max(x1_component),
                verbose = 0)
            x1_component = torch.tensor(x1_component)
            x1[modality_key,:,:,:] = interp3d(x1_component, resize_shape)

        # x1 is now C,W,H,D
        # Add the batch axis on the tensor side; building a tensor from a
        # Python list of ndarrays is a known slow path.
        x = torch.from_numpy(x1).unsqueeze(0).to(torch.float).to(device=this_device)
        x = x.permute(0, 1, 4, 3, 2)

        if case_type == 'training':
            labels = torch.tensor(labels).to(torch.int64).to(device=this_device)
            for_evaluation[case_number] = [x, labels]
        elif case_type == 'test':
            for_evaluation[case_number] = [x]
        case_processed.append(case_number)

    print(" data loaded for evaluation!%s"%(str(case_processed)))
    return for_evaluation
def test_save_one_for_submission(config_data):
    """Exercise saving and evaluating predictions for two training cases.

    Case 1: the prediction is the ground-truth 'OT' mask plus sparse random
    noise (about 0.4% of voxels get +1).
    Case 10: the prediction is the ground-truth 'OT' mask itself.
    Both predictions are written with ``ISLESDATA.save_one_case`` and scored
    with ``EvalObj.save_one_case_evaluation(..., dice=True)``.

    Args:
        config_data: configuration dict; ``['data_directory']['dir_ISLES2017']``
            is read here and the whole dict is forwarded to the save and
            evaluation calls.
    """
    print("test_save_one_for_submission()")
    ev = EvalObj()

    case_type = 'training'
    canonical_modalities_label = [
        'ADC', 'MTT', 'rCBF', 'rCBV', 'Tmax', 'TTP', 'OT'
    ]
    ISLESDATA = ISLES2017mass()
    ISLESDATA.verbose = 20
    ISLESDATA.directory_path = config_data['data_directory']['dir_ISLES2017']

    # Note: at the point one_case is loaded, the shape is still (h,w,d)

    #
    # 1.
    case_number = 1
    one_case = ISLESDATA.load_one_case(case_type, str(case_number),
                                       canonical_modalities_label)
    print("case_number:%s" % (str(case_number)))
    print(" one_case['imgobj']['OT'].shape:%s" %
          (str(one_case['imgobj']['OT'].shape)))
    for xkey in one_case:
        print(xkey)  # just for observation

    original_shape = one_case['imgobj']['OT'].shape
    # assume at this point proper permutation has been done
    dummy_output_shape = (1, ) + one_case['imgobj']['OT'].shape
    # Builtin float replaces the np.float alias, which NumPy 1.24 removed;
    # both denote float64, so the values are unchanged.
    noise = np.array(
        np.random.randint(0, 1000, size=dummy_output_shape) > 995).astype(float)
    dummy_output = one_case['imgobj']['OT'].reshape(dummy_output_shape) + noise
    dummy_output = torch.tensor(dummy_output)
    dummy_output_numpy = dummy_output.detach().cpu().numpy()
    y_ot_tensor = torch.tensor(
        one_case['imgobj']['OT'].reshape(dummy_output_shape))

    ISLESDATA.save_one_case(dummy_output_numpy.reshape(original_shape),
                            one_case,
                            case_type,
                            case_number,
                            config_data,
                            desc='etjoa001_' + str(case_number))
    ev.save_one_case_evaluation(case_number,
                                dummy_output,
                                y_ot_tensor,
                                config_data,
                                dice=True)

    #
    # 2.
    case_number2 = 10
    one_case = ISLESDATA.load_one_case(case_type, str(case_number2),
                                       canonical_modalities_label)
    print("case_number2:%s" % (str(case_number2)))
    print(" one_case['imgobj']['OT'].shape:%s" %
          (str(one_case['imgobj']['OT'].shape)))

    original_shape = one_case['imgobj']['OT'].shape
    dummy_output2_shape = (1, ) + one_case['imgobj']['OT'].shape
    dummy_output2 = one_case['imgobj']['OT']
    dummy_output2 = torch.tensor(dummy_output2.reshape(dummy_output2_shape))
    dummy_output2_numpy = dummy_output2.detach().cpu().numpy()

    ISLESDATA.save_one_case(dummy_output2_numpy.reshape(original_shape),
                            one_case,
                            case_type,
                            case_number2,
                            config_data,
                            desc='etjoa001_' + str(case_number2))
    y_ot2_tensor = torch.tensor(
        one_case['imgobj']['OT'].reshape(dummy_output2_shape))
    ev.save_one_case_evaluation(case_number2,
                                dummy_output2,
                                y_ot2_tensor,
                                config_data,
                                dice=True)
def test_load_ISLES2017(config_data):
    """Print per-case modality shapes and global intensity statistics.

    Loads training cases 1..48 one at a time. For every case that loads, one
    table row with the shape of each modality is printed ('NA' when that
    modality is absent from the case). After all cases, the overall minimum,
    the overall maximum and the mean of the per-case means are printed for
    each modality.

    Args:
        config_data: configuration dict; only
            ``['data_directory']['dir_ISLES2017']`` is read here.
    """
    # Reference output recorded by the author (all seven modality columns of
    # a row held the same shape, so one shape is listed per case):
    #
    #   c_no | smir_id | shape (ADC, MTT, rCBF, rCBV, Tmax, TTP, OT)
    #   1    | 127014  | (192, 192, 19)
    #   2    | 127094  | (192, 192, 19)
    #   (!) D:/Desktop@D/meim2venv/meim3/data/isles2017/training_3 does not exist. Ignoring.
    #   4    | 127206  | (192, 192, 19)
    #   5    | 127214  | (192, 192, 19)
    #   6    | 127222  | (256, 256, 24)
    #   7    | 127230  | (128, 128, 25)
    #   ...
    #   47   | 188994  | (128, 128, 25)
    #   48   | 189002  | (192, 192, 19)
    #
    #        | ADC       | MTT       | rCBF      | rCBV      | Tmax      | TTP        | OT
    #   min  | 0         | -2.288327 | -455.3503 | -58.81981 | -1.993951 | -1.2876743 | 0
    #   max  | 4095      | 89.48501  | 207.20389 | 40.0      | 123.06382 | 121.76149  | 1
    #   mean | 195.55737 | 1.07344   | 3.85694   | 0.37295   | 1.37612   | 6.11271    | 0.00549
    #
    # (!!!) Important: the following normalization appears to be suitable
    # after observing the data. The intensities are not very compatible with
    # each other. From the result, make the following naive normalization
    # specification, converting to [0,1] from these intervals:
    #   ADC 0, 5000 | MTT -5, 100 | rCBF -500, 300 | rCBV -60, 40
    #   | Tmax -5, 150 | TTP -2, 150 | OT NONE
    print("test_load_ISLES2017()")
    case_numbers = range(1, 49)
    ISLESDATA = ISLES2017mass()
    ISLESDATA.verbose = 20
    ISLESDATA.directory_path = config_data['data_directory']['dir_ISLES2017']
    case_type = 'training'
    canonical_modalities_label = [
        'ADC', 'MTT', 'rCBF', 'rCBV', 'Tmax', 'TTP', 'OT'
    ]
    row_format = " %-4s|%-7s| %-16s | %-16s | %-16s | %-16s | %-16s | %-16s | %-16s |"

    print("\nPrinting data shapes.")
    print(row_format % (('c_no', 'smir_id') + tuple(canonical_modalities_label)))

    this_stats = {
        modality_label: {'min': np.inf, 'max': -np.inf, 'mean': []}
        for modality_label in canonical_modalities_label
    }

    # Every case is loaded exactly once; statistics accumulate as we go.
    for case_number in case_numbers:
        one_case = ISLESDATA.load_one_case(case_type, str(case_number),
                                           canonical_modalities_label)
        if one_case is None:
            continue

        # Shapes are tracked per case so that a missing modality is reported
        # as 'NA' rather than showing a previous case's shape.
        case_shapes = []
        for modality_label in canonical_modalities_label:
            if modality_label not in one_case['imgobj']:
                case_shapes.append('NA')
                continue
            volume = one_case['imgobj'][modality_label]
            stats = this_stats[modality_label]
            case_shapes.append(volume.shape)

            volume_min = np.min(volume)
            if stats['min'] > volume_min:
                stats['min'] = volume_min
            volume_max = np.max(volume)
            if stats['max'] < volume_max:
                stats['max'] = volume_max
            stats['mean'].append(np.mean(volume))

        print(row_format % ((str(one_case['case_number']),
                             str(one_case['smir_id'])) + tuple(case_shapes)))

    # Collapse the per-case means into one number per modality.
    for modality_label in canonical_modalities_label:
        case_means = this_stats[modality_label]['mean']
        if len(case_means) > 0:
            this_stats[modality_label]['mean'] = round(np.mean(case_means), 5)
        else:
            this_stats[modality_label]['mean'] = "NA"

    print("-" * 160)
    for stat_name in ('min', 'max', 'mean'):
        stat_values = tuple(this_stats[modality_label][stat_name]
                            for modality_label in canonical_modalities_label)
        print(row_format % ((stat_name, '') + stat_values))