def main(config_path): """ Train and evaluate a 2D UNet on the public ICH dataset using the parameters sepcified on the JSON at the config_path. The evaluation is performed by k-fold cross-validation. """ # load config file cfg = Config(settings=None) cfg.load_config(config_path) # Make Output directories out_path = os.path.join(cfg.settings['path']['OUTPUT'], cfg.settings['exp_name']) # + '/' os.makedirs(out_path, exist_ok=True) for k in range(cfg.settings['split']['n_fold']): os.makedirs(os.path.join(out_path, f'Fold_{k+1}/pred/'), exist_ok=True) # Initialize random seed to given seed seed = cfg.settings['seed'] if seed != -1: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True # Load data csv data_info_df = pd.read_csv( os.path.join(cfg.settings['path']['DATA'], 'ct_info.csv')) data_info_df = data_info_df.drop(data_info_df.columns[0], axis=1) patient_df = pd.read_csv( os.path.join(cfg.settings['path']['DATA'], 'patient_info.csv')) patient_df = patient_df.drop(patient_df.columns[0], axis=1) # Generate Cross-Val indices at the patient level skf = StratifiedKFold(n_splits=cfg.settings['split']['n_fold'], shuffle=cfg.settings['split']['shuffle'], random_state=seed) # iterate over folds and ensure that there are the same amount of ICH positive patient per fold --> Stratiffied CrossVal for k, (train_idx, test_idx) in enumerate( skf.split(patient_df.PatientNumber, patient_df.Hemorrhage)): # if fold results not already there if not os.path.exists( os.path.join(out_path, f'Fold_{k+1}/outputs.json')): # initialize logger logging.basicConfig(level=logging.INFO) logger = logging.getLogger() try: logger.handlers[1].stream.close() logger.removeHandler(logger.handlers[1]) except IndexError: pass logger.setLevel(logging.INFO) file_handler = logging.FileHandler( os.path.join(out_path, f'Fold_{k+1}/log.txt')) file_handler.setLevel(logging.INFO) file_handler.setFormatter( logging.Formatter('%(asctime)s | %(levelname)s | %(message)s')) logger.addHandler(file_handler) if os.path.exists( os.path.join(out_path, f'Fold_{k+1}/checkpoint.pt')): logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' + '#' * 30) logger.info(f"Experiment : {cfg.settings['exp_name']}") logger.info( f"Cross-Validation fold {k+1:02}/{cfg.settings['split']['n_fold']:02}" ) # initialize nbr of thread if cfg.settings['n_thread'] > 0: torch.set_num_threads(cfg.settings['n_thread']) logger.info(f"Number of thread : {cfg.settings['n_thread']}") # check if GPU available #cfg.settings['device'] = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') if cfg.settings['device'] is not None: cfg.settings['device'] = torch.device(cfg.settings['device']) else: if torch.cuda.is_available(): free_mem, device_idx = 0.0, 0 for d in range(torch.cuda.device_count()): mem = torch.cuda.get_device_properties( d).total_memory - torch.cuda.memory_allocated(d) if mem > free_mem: device_idx = d free_mem = mem cfg.settings['device'] = torch.device(f'cuda:{device_idx}') else: cfg.settings['device'] = torch.device('cpu') logger.info(f"Device : {cfg.settings['device']}") # extract train and test DataFrames + print summary (n samples positive and negatives) train_df = data_info_df[data_info_df.PatientNumber.isin( patient_df.loc[train_idx, 'PatientNumber'].values)] test_df = data_info_df[data_info_df.PatientNumber.isin( patient_df.loc[test_idx, 'PatientNumber'].values)] # sample the dataframe to have more or less normal slices 
n_remove = int( max( 0, len(train_df[train_df.Hemorrhage == 0]) - cfg.settings['dataset']['frac_negative'] * len(train_df[train_df.Hemorrhage == 1]))) df_remove = train_df[train_df.Hemorrhage == 0].sample( n=n_remove, random_state=seed) train_df = train_df[~train_df.index.isin(df_remove.index)] logger.info( '\n' + str(get_split_summary_table(data_info_df, train_df, test_df))) # Make Dataset + print online augmentation summary train_dataset = public_SegICH_Dataset2D( train_df, cfg.settings['path']['DATA'], augmentation_transform=[ getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg.settings['data']['augmentation']['train'].items() ], window=(cfg.settings['data']['win_center'], cfg.settings['data']['win_width']), output_size=cfg.settings['data']['size']) test_dataset = public_SegICH_Dataset2D( test_df, cfg.settings['path']['DATA'], augmentation_transform=[ getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg.settings['data']['augmentation']['eval'].items() ], window=(cfg.settings['data']['win_center'], cfg.settings['data']['win_width']), output_size=cfg.settings['data']['size']) logger.info( f"Data will be loaded from {cfg.settings['path']['DATA']}.") logger.info( f"CT scans will be windowed on [{cfg.settings['data']['win_center']-cfg.settings['data']['win_width']/2} ; {cfg.settings['data']['win_center'] + cfg.settings['data']['win_width']/2}]" ) logger.info( f"Training online data transformation: \n\n {str(train_dataset.transform)}\n" ) logger.info( f"Evaluation online data transformation: \n\n {str(test_dataset.transform)}\n" ) # Make architecture (and print summmary ??) unet_arch = UNet( depth=cfg.settings['net']['depth'], top_filter=cfg.settings['net']['top_filter'], use_3D=cfg.settings['net']['3D'], in_channels=cfg.settings['net']['in_channels'], out_channels=cfg.settings['net']['out_channels'], bilinear=cfg.settings['net']['bilinear'], midchannels_factor=cfg.settings['net']['midchannels_factor'], p_dropout=cfg.settings['net']['p_dropout']) unet_arch.to(cfg.settings['device']) logger.info( f"U-Net2D initialized with a depth of {cfg.settings['net']['depth']}" f" and a number of initial filter of {cfg.settings['net']['top_filter']}," ) logger.info( f"Reconstruction performed with {'Upsample + Conv' if cfg.settings['net']['bilinear'] else 'ConvTranspose'}." ) logger.info( f"U-Net2D takes {cfg.settings['net']['in_channels']} as input channels and {cfg.settings['net']['out_channels']} as output channels." ) logger.info( f"The U-Net2D has {sum(p.numel() for p in unet_arch.parameters())} parameters." 
) # Make model unet2D = UNet2D( unet_arch, n_epoch=cfg.settings['train']['n_epoch'], batch_size=cfg.settings['train']['batch_size'], lr=cfg.settings['train']['lr'], lr_scheduler=getattr(torch.optim.lr_scheduler, cfg.settings['train']['lr_scheduler']), lr_scheduler_kwargs=cfg.settings['train'] ['lr_scheduler_kwargs'], loss_fn=getattr(src.models.optim.LossFunctions, cfg.settings['train']['loss_fn']), loss_fn_kwargs=cfg.settings['train']['loss_fn_kwargs'], weight_decay=cfg.settings['train']['weight_decay'], num_workers=cfg.settings['train']['num_workers'], device=cfg.settings['device'], print_progress=cfg.settings['print_progress']) # Load model if required if cfg.settings['train']['model_path_to_load']: if isinstance(cfg.settings['train']['model_path_to_load'], str): model_path = cfg.settings['train']['model_path_to_load'] unet2D.load_model(model_path, map_location=cfg.settings['device']) elif isinstance(cfg.settings['train']['model_path_to_load'], list): model_path = cfg.settings['train']['model_path_to_load'][k] unet2D.load_model(model_path, map_location=cfg.settings['device']) else: raise ValueError( f'Model path to load type not understood.') logger.info(f"2D U-Net model loaded from {model_path}") # print Training hyper-parameters train_params = [] for key, value in cfg.settings['train'].items(): train_params.append(f"--> {key} : {value}") logger.info('Training settings:\n\t' + '\n\t'.join(train_params)) # Train model eval_dataset = test_dataset if cfg.settings['train'][ 'validate_epoch'] else None unet2D.train(train_dataset, valid_dataset=eval_dataset, checkpoint_path=os.path.join( out_path, f'Fold_{k+1}/checkpoint.pt')) # Evaluate model unet2D.evaluate(test_dataset, save_path=os.path.join(out_path, f'Fold_{k+1}/pred/')) # Save models & outputs unet2D.save_model( os.path.join(out_path, f'Fold_{k+1}/trained_unet.pt')) logger.info("Trained U-Net saved at " + os.path.join(out_path, f'Fold_{k+1}/trained_unet.pt')) unet2D.save_outputs( os.path.join(out_path, f'Fold_{k+1}/outputs.json')) logger.info("Trained statistics saved at " + os.path.join(out_path, f'Fold_{k+1}/outputs.json')) # delete checkpoint if exists if os.path.exists( os.path.join(out_path, f'Fold_{k+1}/checkpoint.pt')): os.remove(os.path.join(out_path, f'Fold_{k+1}/checkpoint.pt')) logger.info('Checkpoint deleted.') # save mean +/- 1.96 std Dice in .txt file scores_list = [] for k in range(cfg.settings['split']['n_fold']): with open(os.path.join(out_path, f'Fold_{k+1}/outputs.json'), 'r') as f: out = json.load(f) scores_list.append( [out['eval']['dice']['all'], out['eval']['dice']['positive']]) means = np.array(scores_list).mean(axis=0) CI95 = 1.96 * np.array(scores_list).std(axis=0) with open(os.path.join(out_path, 'average_scores.txt'), 'w') as f: f.write(f'Dice = {means[0]} +/- {CI95[0]}\n') f.write(f'Dice (Positive) = {means[1]} +/- {CI95[1]}\n') logger.info('Average Scores saved at ' + os.path.join(out_path, 'average_scores.txt')) # generate dataframe of all prediction df_list = [ pd.read_csv( os.path.join(out_path, f'Fold_{i+1}/pred/volume_prediction_scores.csv')) for i in range(cfg.settings['split']['n_fold']) ] all_df = pd.concat(df_list, axis=0).reset_index(drop=True) all_df.to_csv(os.path.join(out_path, 'all_volume_prediction.csv')) logger.info('CSV of all volumes prediction saved at ' + os.path.join(out_path, 'all_volume_prediction.csv')) # Save config file cfg.settings['device'] = str(cfg.settings['device']) cfg.save_config(os.path.join(out_path, 'config.json')) logger.info("Config file saved at " + 
os.path.join(out_path, 'config.json')) # Analyse results analyse_supervised_exp(out_path, cfg.settings['path']['DATA'], cfg.settings['split']['n_fold'], save_fn=os.path.join(out_path, 'results_overview.pdf')) logger.info('Results overview figure saved at ' + os.path.join(out_path, 'results_overview.pdf'))
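# The scripts below call an `initialize_logger` helper instead of repeating the inline
# logging setup used in the fold loop above. A minimal sketch of such a helper is given
# here, assuming it simply mirrors that inline setup (stdout handler plus a file handler
# with the same '%(asctime)s | %(levelname)s | %(message)s' format); the repository's
# actual helper may differ in details.
import logging
import os

def initialize_logger(logfile_path):
    """Return the root logger writing INFO records to stdout and to `logfile_path`."""
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # drop any file handler left over from a previous fold/run
    for handler in list(logger.handlers[1:]):
        handler.close()
        logger.removeHandler(handler)
    log_dir = os.path.dirname(logfile_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    file_handler = logging.FileHandler(logfile_path)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter('%(asctime)s | %(levelname)s | %(message)s'))
    logger.addHandler(file_handler)
    return logger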
def main(config_path): """ """ # load config cfg = AttrDict.from_json_path(config_path) # make outputs dir out_path = os.path.join(cfg.path.output, cfg.exp_name) os.makedirs(out_path, exist_ok=True) # initialize seed if cfg.seed != -1: random.seed(cfg.seed) np.random.seed(cfg.seed) torch.manual_seed(cfg.seed) torch.cuda.manual_seed(cfg.seed) torch.cuda.manual_seed_all(cfg.seed) torch.backends.cudnn.deterministic = True # initialize logger logger = initialize_logger(os.path.join(out_path, 'log.txt')) logger.info(f"Experiment : {cfg.exp_name}") # set device if cfg.device: cfg.device = torch.device(cfg.device) else: cfg.device = torch.device( f'cuda:0') if torch.cuda.is_available() else torch.device('cpu') logger.info(f"Device set to {cfg.device}.") # get Dataset data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'ct_info.csv'), index_col=0) dataset = public_SegICH_Dataset2D( data_info_df, cfg.path.data, augmentation_transform=[ getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg.data.augmentation.items() ], output_size=cfg.data.size, window=(cfg.data.win_center, cfg.data.win_width)) # load inpainting model cfg_ae = AttrDict.from_json_path(cfg.ae_cfg_path) ae_net = AE_net(**cfg_ae.net) loaded_state_dict = torch.load(cfg.ae_model_path, map_location=cfg.device) ae_net.load_state_dict(loaded_state_dict) ae_net = ae_net.to(cfg.device).eval() logger.info(f"AE model succesfully loaded from {cfg.ae_model_path}") # Load Classifier if cfg.classifier_model_path is not None: cfg_classifier = AttrDict.from_json_path( os.path.join(cfg.classifier_model_path, 'config.json')) classifier = getattr(rn, cfg_classifier.net.resnet)( num_classes=cfg_classifier.net.num_classes, input_channels=cfg_classifier.net.input_channels) classifier_state_dict = torch.load(os.path.join( cfg.classifier_model_path, 'resnet_state_dict.pt'), map_location=cfg.device) classifier.load_state_dict(classifier_state_dict) classifier = classifier.to(cfg.device) classifier.eval() logger.info( f"ResNet classifier model succesfully loaded from {os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt')}" ) # iterate over dataset all_pred = [] for i, sample in enumerate(dataset): # unpack data image, target, id, slice = sample logger.info( "=" * 25 + f" SAMPLE {i+1:04}/{len(dataset):04} - Volume {id:03} Slice {slice:03} " + "=" * 25) # Classify sample if cfg.classifier_model_path is not None: with torch.no_grad(): input_clss = image.unsqueeze(0).to(cfg.device).float() pred_score = nn.functional.softmax( classifier(input_clss), dim=1 )[:, 1] # take columns of softmax of positive class as score pred = 1 if pred_score >= cfg.classification_threshold else 0 else: pred = 1 # if not classifier given, all slices are processed # process slice if classifier has detected Hemorrhage if pred == 1: logger.info( f"ICH detected. Compute anomaly mask through AE reconstruction." 
) # Detect anomalies using the robuste approach ad_map, ad_mask = compute_anomaly(ae_net, image, alpha_low=cfg.alpha_low, alpha_high=cfg.alpha_high, device=cfg.device) logger.info(f"{ad_mask.sum()} anomalous pixels detected.") # save ad_mask ad_mask_fn = f"{id}/{slice}_anomalies.bmp" save_path = os.path.join(out_path, 'pred/', ad_mask_fn) if not os.path.isdir(os.path.dirname(save_path)): os.makedirs(os.path.dirname(save_path), exist_ok=True) io.imsave(save_path, img_as_ubyte(ad_mask), check_contrast=False) # save anomaly map ad_map_fn = f"{id}/{slice}_map_anomalies.png" save_path_map = os.path.join(out_path, 'pred/', ad_map_fn) io.imsave(save_path_map, img_as_ubyte( rescale_intensity(ad_map, out_range=(0.0, 1.0))), check_contrast=False) else: logger.info(f"No ICH detected. Set the anomaly mask to zeros.") ad_mask = np.zeros_like(target[0].numpy()) ad_mask_fn, ad_map_fn = 'None', 'None' # compute confusion matrix with target ICH mask tn, fp, fn, tp = confusion_matrix(target[0].numpy().ravel(), ad_mask.ravel(), labels=[0, 1]).ravel() auc = roc_auc_score(target[0].numpy().ravel(), ad_map.ravel()) if torch.any(target[0]) else 'None' # append to all_pred list all_pred.append({ 'id': id.item(), 'slice': slice.item(), 'label': target.max().item(), 'TP': tp, 'TN': tn, 'FP': fp, 'FN': fn, 'AUC': auc, 'ad_mask_fn': ad_mask_fn, 'ad_map_fn': ad_map_fn }) # make a dataframe of all predictions slice_df = pd.DataFrame(all_pred) volume_df = slice_df[['id', 'label', 'TP', 'TN', 'FP', 'FN']].groupby('id').agg({ 'label': 'max', 'TP': 'sum', 'TN': 'sum', 'FP': 'sum', 'FN': 'sum' }) # Compute Dice and Volume Dice slice_df['Dice'] = (2 * slice_df.TP + 1) / (2 * slice_df.TP + slice_df.FP + slice_df.FN + 1) volume_df['Dice'] = (2 * volume_df.TP + 1) / ( 2 * volume_df.TP + volume_df.FP + volume_df.FN + 1) logger.info(f"Mean slice dice : {slice_df.Dice.mean(axis=0):.3f}") logger.info(f"Mean volume dice : {volume_df.Dice.mean(axis=0):.3f}") logger.info( f"Mean posiitve slice AUC {slice_df[slice_df.label == 1].AUC.mean(axis=0):.3f}" ) # Save Scores and Config slice_df.to_csv(os.path.join(out_path, 'slice_predictions.csv')) logger.info( f"Slice prediction csv saved at {os.path.join(out_path, 'slice_predictions.csv')}" ) volume_df.to_csv(os.path.join(out_path, 'volume_predictions.csv')) logger.info( f"Volume prediction csv saved at {os.path.join(out_path, 'volume_predictions.csv')}" ) cfg.device = str(cfg.device) with open(os.path.join(out_path, 'config.json'), 'w') as f: json.dump(cfg, f) logger.info( f"Config file saved at {os.path.join(out_path, 'config.json')}")
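# The Dice scores above are derived from confusion-matrix counts with a +1 additive
# smoothing, so that an empty prediction on an empty target scores 1 instead of 0/0.
# The snippet below is a small self-contained check (not part of the pipeline) that
# this formulation matches the usual 2*|A∩B| / (|A| + |B|) Dice up to the smoothing term.
import numpy as np
from sklearn.metrics import confusion_matrix

rng = np.random.default_rng(0)
pred = rng.integers(0, 2, size=(64, 64))   # toy binary prediction mask
gt = rng.integers(0, 2, size=(64, 64))     # toy binary ground-truth mask

tn, fp, fn, tp = confusion_matrix(gt.ravel(), pred.ravel(), labels=[0, 1]).ravel()
dice_from_cm = (2 * tp + 1) / (2 * tp + fp + fn + 1)
dice_direct = (2 * np.logical_and(pred, gt).sum() + 1) / (pred.sum() + gt.sum() + 1)
assert np.isclose(dice_from_cm, dice_direct)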
def main(config_path): """ UNet2D pretrained on binary classification with the RSNA dataset and finetuned with the public data. """ # load the config file cfg = AttrDict.from_json_path(config_path) # Make Outputs directories out_path = os.path.join(cfg.path.output, cfg.exp_name) out_path_selfsup = os.path.join(out_path, 'classification_pretrain/') out_path_sup = os.path.join(out_path, 'supervised_train/') os.makedirs(out_path_selfsup, exist_ok=True) for k in range(cfg.Sup.split.n_fold): os.makedirs(os.path.join(out_path_sup, f'Fold_{k+1}/pred/'), exist_ok=True) # Initialize random seed if cfg.seed != -1: random.seed(cfg.seed) np.random.seed(cfg.seed) torch.manual_seed(cfg.seed) torch.cuda.manual_seed(cfg.seed) torch.cuda.manual_seed_all(cfg.seed) torch.backends.cudnn.deterministic = True # Set number of thread if cfg.n_thread > 0: torch.set_num_threads(cfg.n_thread) # set device if cfg.device: cfg.device = torch.device(cfg.device) else: cfg.device = get_available_device() #################################################### # Self-supervised training on Multi Classification # #################################################### # Initialize Logger logger = initialize_logger(os.path.join(out_path_selfsup, 'log.txt')) if os.path.exists(os.path.join(out_path_selfsup, f'checkpoint.pt')): logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' + '#' * 30) logger.info(f"Experiment : {cfg.exp_name}") # Load RSNA data csv df_rsna = pd.read_csv(os.path.join(cfg.path.data.SSL, 'slice_info.csv'), index_col=0) # Keep only fractions sample if cfg.SSL.dataset.n_data_1 >= 0: df_rsna_ICH = df_rsna[df_rsna.Hemorrhage == 1].sample( n=cfg.SSL.dataset.n_data_1, random_state=cfg.seed) else: df_rsna_ICH = df_rsna[df_rsna.Hemorrhage == 1] if cfg.SSL.dataset.f_data_0 >= 0: # nbr of normal as fraction of nbr of ICH samples df_rsna_noICH = df_rsna[df_rsna.Hemorrhage == 0].sample( n=cfg.SSL.dataset.f_data_0 * len(df_rsna_ICH), random_state=cfg.seed) else: df_rsna_noICH = df_rsna[df_rsna.Hemorrhage == 0] df_rsna = pd.concat([df_rsna_ICH, df_rsna_noICH], axis=0) # Split data to keep few for evaluation in a strafied way train_df, test_df = train_test_split(df_rsna, test_size=cfg.SSL.dataset.frac_eval, stratify=df_rsna.Hemorrhage, random_state=cfg.seed) logger.info('\n' + str(get_split_summary_table(df_rsna, train_df, test_df))) # Make dataset : Train --> BinaryClassification, Test --> BinaryClassification train_RSNA_dataset = RSNA_dataset( train_df, cfg.path.data.SSL, augmentation_transform=[ getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg.SSL.dataset.augmentation.train.items() ], window=(cfg.data.win_center, cfg.data.win_width), output_size=cfg.data.size, mode='multi_classification') test_RSNA_dataset = RSNA_dataset( test_df, cfg.path.data.SSL, augmentation_transform=[ getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg.SSL.dataset.augmentation.eval.items() ], window=(cfg.data.win_center, cfg.data.win_width), output_size=cfg.data.size, mode='multi_classification') logger.info(f"Data will be loaded from {cfg.path.data.SSL}.") logger.info( f"CT scans will be windowed on [{cfg.data.win_center-cfg.data.win_width/2} ; {cfg.data.win_center + cfg.data.win_width/2}]" ) logger.info(f"CT scans will be resized to {cfg.data.size}x{cfg.data.size}") logger.info( f"Training online data transformation: \n\n {str(train_RSNA_dataset.transform)}\n" ) logger.info( f"Evaluation online data transformation: \n\n {str(test_RSNA_dataset.transform)}\n" ) # Make U-Net-Encoder architecture net_ssl = 
UNet_Encoder(**cfg.SSL.net).to(cfg.device) net_params = [f"--> {k} : {v}" for k, v in cfg.SSL.net.items()] logger.info("UNet like Multi Classifier \n\t" + "\n\t".join(net_params)) logger.info( f"The Multi Classifier has {sum(p.numel() for p in net_ssl.parameters())} parameters." ) # Make Model cfg.SSL.train.model_param.lr_scheduler = getattr( torch.optim.lr_scheduler, cfg.SSL.train.model_param.lr_scheduler ) # convert scheduler name to scheduler class object if cfg.SSL.train.model_param.loss_fn == 'BCEWithLogitsLoss': df_rsna['no_Hemorrhage'] = 1 - df_rsna.Hemorrhage class_weight_list = ( (len(df_rsna) - df_rsna[train_RSNA_dataset.class_name].sum()) / df_rsna[train_RSNA_dataset.class_name].sum() ).values # define CE weighting from train dataset cfg.SSL.train.model_param.loss_fn_kwargs['pos_weight'] = torch.tensor( class_weight_list, device=cfg.device) # add weighting to CE kwargs try: cfg.SSL.train.model_param.loss_fn = getattr( torch.nn, cfg.SSL.train.model_param.loss_fn ) # convert loss_fn name to nn.Module class object except AttributeError: cfg.SSL.train.model_param.loss_fn = getattr( src.models.optim.LossFunctions, cfg.SSL.train.model_param.loss_fn) #torch.tensor([1 - w_ICH, w_ICH], device=cfg.device).float() classifier = MultiClassifier(net_ssl, device=cfg.device, print_progress=cfg.print_progress, **cfg.SSL.train.model_param) train_params = [ f"--> {k} : {v}" for k, v in cfg.SSL.train.model_param.items() ] logger.info("Classifer Training Parameters \n\t" + "\n\t".join(train_params)) # Load weights if specified if cfg.SSL.train.model_path_to_load: model_path = cfg.SSL.train.model_path_to_load classifier.load_model(model_path, map_location=cfg.device) logger.info( f"Classifer Model succesfully loaded from {cfg.SSL.train.model_path_to_load}" ) # train if needed if cfg.SSL.train.model_param.n_epoch > 0: classifier.train(train_RSNA_dataset, valid_dataset=test_RSNA_dataset, checkpoint_path=os.path.join(out_path_selfsup, f'checkpoint.pt')) # evaluate auc, acc, sub_acc, recall, precision, f1 = classifier.evaluate( test_RSNA_dataset, save_tsne=True, return_scores=True) logger.info(f"Classifier Test AUC : {auc:.2%}") logger.info(f"Classifier Test Accuracy : {acc:.2%}") logger.info(f"Classifier Test Subset Accuracy : {sub_acc:.2%}") logger.info(f"Classifier Test Recall : {recall:.2%}") logger.info(f"Classifier Test Precision : {precision:.2%}") logger.info(f"Classifier Test F1-score : {f1:.2%}") # save model, outputs classifier.save_model_state_dict( os.path.join(out_path_selfsup, 'pretrained_unet_enc.pt')) logger.info( "Pre-trained U-Net encoder on binary classification saved at " + os.path.join(out_path_selfsup, 'pretrained_unet_enc.pt')) classifier.save_outputs(os.path.join(out_path_selfsup, 'outputs.json')) logger.info("Classifier outputs saved at " + os.path.join(out_path_selfsup, 'outputs.json')) test_df.reset_index(drop=True).to_csv( os.path.join(out_path_selfsup, 'eval_data_info.csv')) logger.info("Evaluation data info saved at " + os.path.join(out_path_selfsup, 'eval_data_info.csv')) # delete any checkpoints if os.path.exists(os.path.join(out_path_selfsup, f'checkpoint.pt')): os.remove(os.path.join(out_path_selfsup, f'checkpoint.pt')) logger.info('Checkpoint deleted.') # get weights state dictionnary pretrained_unet_weights = classifier.get_state_dict() ################################################################### # Supervised fine-training of U-Net with K-Fold Cross-Validation # ################################################################### # load annotated data csv 
data_info_df = pd.read_csv(os.path.join(cfg.path.data.Sup, 'ct_info.csv'), index_col=0) patient_df = pd.read_csv(os.path.join(cfg.path.data.Sup, 'patient_info.csv'), index_col=0) # Make K-Fold spolit at patient level skf = StratifiedKFold(n_splits=cfg.Sup.split.n_fold, shuffle=cfg.Sup.split.shuffle, random_state=cfg.seed) # define scheduler and loss_fn as object cfg.Sup.train.model_param.lr_scheduler = getattr( torch.optim.lr_scheduler, cfg.Sup.train.model_param.lr_scheduler ) # convert scheduler name to scheduler class object cfg.Sup.train.model_param.loss_fn = getattr( src.models.optim.LossFunctions, cfg.Sup.train.model_param.loss_fn ) # convert loss_fn name to nn.Module class object # iterate over folds for k, (train_idx, test_idx) in enumerate( skf.split(patient_df.PatientNumber, patient_df.Hemorrhage)): # check if fold's results already exists if not os.path.exists( os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json')): # initialize logger logger = initialize_logger( os.path.join(out_path_sup, f'Fold_{k+1}/log.txt')) if os.path.exists( os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt')): logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' + '#' * 30) logger.info(f"Experiment : {cfg['exp_name']}") logger.info( f"Cross-Validation fold {k+1:02}/{cfg['Sup']['split']['n_fold']:02}" ) # extract train/test slice dataframe train_df = data_info_df[data_info_df.PatientNumber.isin( patient_df.loc[train_idx, 'PatientNumber'].values)] test_df = data_info_df[data_info_df.PatientNumber.isin( patient_df.loc[test_idx, 'PatientNumber'].values)] # samples train dataframe to adjuste negative/positive fractions n_remove = int( max( 0, len(train_df[train_df.Hemorrhage == 0]) - cfg.Sup.dataset.frac_negative * len(train_df[train_df.Hemorrhage == 1]))) df_remove = train_df[train_df.Hemorrhage == 0].sample( n=n_remove, random_state=cfg.seed) train_df = train_df[~train_df.index.isin(df_remove.index)] logger.info( '\n' + str(get_split_summary_table(data_info_df, train_df, test_df))) # Make datasets train_dataset = public_SegICH_Dataset2D( train_df, cfg.path.data.Sup, augmentation_transform=[ getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg.Sup.dataset.augmentation.train.items() ], window=(cfg.data.win_center, cfg.data.win_width), output_size=cfg.data.size) test_dataset = public_SegICH_Dataset2D( test_df, cfg.path.data.Sup, augmentation_transform=[ getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg.Sup.dataset.augmentation.eval.items() ], window=(cfg.data.win_center, cfg.data.win_width), output_size=cfg.data.size) logger.info(f"Data will be loaded from {cfg.path.data.Sup}.") logger.info( f"CT scans will be windowed on [{cfg.data.win_center-cfg.data.win_width/2} ; {cfg.data.win_center + cfg.data.win_width/2}]" ) logger.info( f"CT scans will be resized to {cfg.data.size}x{cfg.data.size}") logger.info( f"Training online data transformation: \n\n {str(train_dataset.transform)}\n" ) logger.info( f"Evaluation online data transformation: \n\n {str(test_dataset.transform)}\n" ) # Make U-Net architecture unet_sup = UNet(**cfg.Sup.net).to(cfg.device) net_params = [f"--> {k} : {v}" for k, v in cfg.Sup.net.items()] logger.info("UNet-2D params \n\t" + "\n\t".join(net_params)) logger.info( f"The U-Net2D has {sum(p.numel() for p in unet_sup.parameters())} parameters." 
) # Make Model unet2D = UNet2D(unet_sup, device=cfg.device, print_progress=cfg.print_progress, **cfg.Sup.train.model_param) train_params = [ f"--> {k} : {v}" for k, v in cfg.Sup.train.model_param.items() ] logger.info("UNet-2D Training Parameters \n\t" + "\n\t".join(train_params)) # ????? load model if specified ????? # transfer weights learn with context restoration logger.info( 'Initialize U-Net2D with weights learned with context_restoration on RSNA.' ) unet2D.transfer_weights(pretrained_unet_weights, verbose=True) # Train U-net eval_dataset = test_dataset if cfg.Sup.train.validate_epoch else None unet2D.train(train_dataset, valid_dataset=eval_dataset, checkpoint_path=os.path.join( out_path_sup, f'Fold_{k+1}/checkpoint.pt')) # Evaluate U-Net unet2D.evaluate(test_dataset, save_path=os.path.join(out_path_sup, f'Fold_{k+1}/pred/')) # Save models and outputs unet2D.save_model( os.path.join(out_path_sup, f'Fold_{k+1}/trained_unet.pt')) logger.info( "Trained U-Net saved at " + os.path.join(out_path_sup, f'Fold_{k+1}/trained_unet.pt')) unet2D.save_outputs( os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json')) logger.info("Trained statistics saved at " + os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json')) # delete checkpoint if exists if os.path.exists( os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt')): os.remove( os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt')) logger.info('Checkpoint deleted.') # save mean +/- 1.96 std Dice over Folds save_mean_fold_dice(out_path_sup, cfg.Sup.split.n_fold) logger.info('Average Scores saved at ' + os.path.join(out_path_sup, 'average_scores.txt')) # Save all volumes prediction csv df_list = [ pd.read_csv( os.path.join(out_path_sup, f'Fold_{i+1}/pred/volume_prediction_scores.csv')) for i in range(cfg.Sup.split.n_fold) ] all_df = pd.concat(df_list, axis=0).reset_index(drop=True) all_df.to_csv(os.path.join(out_path_sup, 'all_volume_prediction.csv')) logger.info('CSV of all volumes prediction saved at ' + os.path.join(out_path_sup, 'all_volume_prediction.csv')) # Save config file cfg.device = str(cfg.device) cfg.SSL.train.model_param.lr_scheduler = str( cfg.SSL.train.model_param.lr_scheduler) cfg.Sup.train.model_param.lr_scheduler = str( cfg.Sup.train.model_param.lr_scheduler) cfg.SSL.train.model_param.loss_fn = str(cfg.SSL.train.model_param.loss_fn) cfg.Sup.train.model_param.loss_fn = str(cfg.Sup.train.model_param.loss_fn) cfg.SSL.train.model_param.loss_fn_kwargs.pos_weight = cfg.SSL.train.model_param.loss_fn_kwargs.pos_weight.cpu( ).data.tolist() with open(os.path.join(out_path, 'config.json'), 'w') as fp: json.dump(cfg, fp) logger.info('Config file saved at ' + os.path.join(out_path, 'config.json')) # Analyse results analyse_supervised_exp(out_path_sup, cfg.path.data.Sup, n_fold=cfg.Sup.split.n_fold, config_folder=out_path, save_fn=os.path.join( out_path, 'results_supervised_overview.pdf')) logger.info('Results overview figure saved at ' + os.path.join(out_path, 'results_supervised_overview.pdf')) analyse_representation_exp(out_path_selfsup, save_fn=os.path.join( out_path, 'results_self-supervised_overview.pdf')) logger.info('Results overview figure saved at ' + os.path.join(out_path, 'results_self-supervised_overview.pdf'))
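# The script above calls `get_available_device()` where the first and fourth scripts
# inline the same device-selection logic. A sketch of such a helper is given below,
# assuming it simply picks the CUDA device with the most free memory and falls back to
# the CPU; the repository's own helper may differ in details.
import torch

def get_available_device():
    """Return the CUDA device with the most free memory, or the CPU if CUDA is unavailable."""
    if not torch.cuda.is_available():
        return torch.device('cpu')
    free_mem, device_idx = 0.0, 0
    for d in range(torch.cuda.device_count()):
        mem = torch.cuda.get_device_properties(d).total_memory - torch.cuda.memory_allocated(d)
        if mem > free_mem:
            free_mem, device_idx = mem, d
    return torch.device(f'cuda:{device_idx}')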
def main(config_path): """ UNet2D pretrained on context retoration with the RSNA dataset and finetuned with the public data. """ # load the config file with open(config_path, 'r') as fp: cfg = json.load(fp) # Make Outputs directories out_path = os.path.join( cfg['path']['output'], cfg['exp_name']) # + datetime.now().strftime('_%Y-%m-%d')) out_path_selfsup = os.path.join(out_path, 'context_restoration_pretrain/') out_path_sup = os.path.join(out_path, 'supervised_train/') os.makedirs(out_path_selfsup, exist_ok=True) for k in range(cfg['Sup']['split']['n_fold']): os.makedirs(os.path.join(out_path_sup, f'Fold_{k+1}/pred/'), exist_ok=True) # Initialize random seed seed = cfg['seed'] if seed != -1: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True # Set number of thread if cfg['n_thread'] > 0: torch.set_num_threads(cfg['n_thread']) # check if GPU available #cfg['device'] = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') if cfg['device'] is not None: cfg['device'] = torch.device(cfg['device']) else: if torch.cuda.is_available(): free_mem, device_idx = 0.0, 0 for d in range(torch.cuda.device_count()): mem = torch.cuda.get_device_properties( d).total_memory - torch.cuda.memory_allocated(d) if mem > free_mem: device_idx = d free_mem = mem cfg['device'] = torch.device(f'cuda:{device_idx}') else: cfg['device'] = torch.device('cpu') ################################################### # Self-supervised training on Context Restoration # ################################################### # Initialize Logger logger = initialize_logger(os.path.join(out_path_selfsup, 'log.txt')) if os.path.exists(os.path.join(out_path_selfsup, f'checkpoint.pt')): logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' + '#' * 30) logger.info(f"Experiment : {cfg['exp_name']}") # Load RSNA data csv df_rsna = pd.read_csv(os.path.join(cfg['path']['data']['SSL'], 'slice_info.csv'), index_col=0) # Keep only fractions of negative/positive df_rsna_pos = df_rsna[df_rsna.Hemorrhage == 1] if cfg['SSL']['dataset']['n_positive'] >= 0: df_rsna_pos = df_rsna_pos.sample(n=cfg['SSL']['dataset']['n_positive'], random_state=seed) n_neg = int(cfg['SSL']['dataset']['frac_negative'] * len(df_rsna_pos)) df_rsna_neg = df_rsna[df_rsna.Hemorrhage == 0].sample(n=n_neg, random_state=seed) df_rsna_samp = pd.concat([df_rsna_pos, df_rsna_neg], axis=0) # Split data to keep few for evaluation test_df = df_rsna_samp.sample(n=cfg['SSL']['dataset']['n_eval'], random_state=seed) train_df = df_rsna_samp.drop(test_df.index) df_rsna_neg = df_rsna[(df_rsna.Hemorrhage == 0)] test_df = test_df.append( df_rsna_neg[~df_rsna_neg.isin(train_df.index)].sample( n=cfg['SSL']['dataset']['n_eval_neg'], random_state=seed)) test_df, train_df = test_df.reset_index(), train_df.reset_index() logger.info('\n' + str(get_split_summary_table(df_rsna, train_df, test_df))) # Make dataset : Train --> ContextRestoration, Test --> Standard train_RSNA_dataset = RSNA_dataset( train_df, cfg['path']['data']['SSL'], augmentation_transform=[ getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg['SSL']['dataset']['augmentation']['train'].items() ], window=(cfg['data']['win_center'], cfg['data']['win_width']), output_size=cfg['data']['size'], mode='context_restoration', n_swap=cfg['SSL']['dataset']['n_swap'], swap_w=cfg['SSL']['dataset']['swap_w'], swap_h=cfg['SSL']['dataset']['swap_h'], swap_rot=cfg['SSL']['dataset']['swap_rotate']) 
test_RSNA_dataset = RSNA_dataset( test_df, cfg['path']['data']['SSL'], augmentation_transform=[ getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg['SSL']['dataset']['augmentation']['eval'].items() ], window=(cfg['data']['win_center'], cfg['data']['win_width']), output_size=cfg['data']['size'], mode='standard') logger.info(f"Data will be loaded from {cfg['path']['data']['SSL']}.") logger.info( f"CT scans will be windowed on [{cfg['data']['win_center']-cfg['data']['win_width']/2} ; {cfg['data']['win_center'] + cfg['data']['win_width']/2}]" ) logger.info( f"CT scans will be resized to {cfg['data']['size']}x{cfg['data']['size']}" ) logger.info( f"Training online data transformation: \n\n {str(train_RSNA_dataset.transform)}\n" ) logger.info( f"Evaluation online data transformation: \n\n {str(test_RSNA_dataset.transform)}\n" ) logger.info( f"Input images will be corrupted by {cfg['SSL']['dataset']['n_swap']} swap of dimension {cfg['SSL']['dataset']['swap_h']}x{cfg['SSL']['dataset']['swap_w']} (h x w)" ) # Make U-Net architecture net_ssl = UNet(depth=cfg['SSL']['net']['depth'], top_filter=cfg['SSL']['net']['top_filter'], use_3D=cfg['SSL']['net']['3D'], in_channels=cfg['SSL']['net']['in_channels'], out_channels=cfg['SSL']['net']['out_channels'], bilinear=cfg['SSL']['net']['bilinear'], use_final_activation=cfg['SSL']['net']['final_activation'], midchannels_factor=cfg['SSL']['net']['midchannels_factor'], p_dropout=cfg['SSL']['net']['p_dropout']) net_ssl.to(cfg['device']) logger.info( f"U-Net2D initialized with a depth of {cfg['SSL']['net']['depth']}" f" and a number of initial filter of {cfg['SSL']['net']['top_filter']}," ) logger.info( f"Reconstruction performed with {'Upsample + Conv' if cfg['SSL']['net']['bilinear'] else 'ConvTranspose'}." ) logger.info( f"U-Net2D takes {cfg['SSL']['net']['in_channels']} as input channels and {cfg['SSL']['net']['out_channels']} as output channels." ) logger.info( f"The U-Net2D has {sum(p.numel() for p in net_ssl.parameters())} parameters." 
) # Make Model ctx_restor = ContextRestoration( net_ssl, n_epoch=cfg['SSL']['train']['n_epoch'], batch_size=cfg['SSL']['train']['batch_size'], lr=cfg['SSL']['train']['lr'], lr_scheduler=getattr(torch.optim.lr_scheduler, cfg['SSL']['train']['lr_scheduler']), lr_scheduler_kwargs=cfg['SSL']['train']['lr_scheduler_kwargs'], loss_fn=getattr(torch.nn, cfg['SSL']['train']['loss_fn']), loss_fn_kwargs=cfg['SSL']['train']['loss_fn_kwargs'], weight_decay=cfg['SSL']['train']['weight_decay'], num_workers=cfg['SSL']['train']['num_workers'], device=cfg['device'], print_progress=cfg['print_progress']) # Load weights if specified if cfg['SSL']['train']['model_path_to_load']: model_path = cfg['SSL']['train']['model_path_to_load'] ctx_restor.load_model(model_path, map_location=cfg['device']) # train if needed if cfg['SSL']['train']['n_epoch'] > 0: train_params = [] for key, value in cfg['SSL']['train'].items(): train_params.append(f"--> {key} : {value}") logger.info('Training settings:\n\t' + '\n\t'.join(train_params)) ctx_restor.train(train_RSNA_dataset, checkpoint_path=os.path.join(out_path_selfsup, f'checkpoint.pt')) # evaluate ctx_restor.evaluate(test_RSNA_dataset) # save model, outputs and evaluation data info (test_df) ctx_restor.save_model(os.path.join(out_path_selfsup, 'pretrained_unet.pt')) logger.info("Pre-trained U-Net on context restoration saved at " + os.path.join(out_path_selfsup, 'pretrained_unet.pt')) ctx_restor.save_outputs(os.path.join(out_path_selfsup, 'outputs.json')) logger.info("Context restoration outputs saved at " + os.path.join(out_path_selfsup, 'outputs.json')) test_df.to_csv(os.path.join(out_path_selfsup, 'eval_data_info.csv')) logger.info("Evaluation data info saved at " + os.path.join(out_path_selfsup, 'eval_data_info.csv')) # delete any checkpoints if os.path.exists(os.path.join(out_path_selfsup, f'checkpoint.pt')): os.remove(os.path.join(out_path_selfsup, f'checkpoint.pt')) logger.info('Checkpoint deleted.') # get weights state dictionnary pretrained_unet_weights = ctx_restor.get_state_dict() ################################################################### # Supervised fine-training of U-Net with K-Fold Cross-Validation # ################################################################### # load annotated data csv data_info_df = pd.read_csv(os.path.join(cfg['path']['data']['Sup'], 'ct_info.csv'), index_col=0) patient_df = pd.read_csv(os.path.join(cfg['path']['data']['Sup'], 'patient_info.csv'), index_col=0) # Make K-Fold spolit at patient level skf = StratifiedKFold(n_splits=cfg['Sup']['split']['n_fold'], shuffle=cfg['Sup']['split']['shuffle'], random_state=seed) # iterate over folds for k, (train_idx, test_idx) in enumerate( skf.split(patient_df.PatientNumber, patient_df.Hemorrhage)): # check if fold's results already exists if not os.path.exists( os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json')): # initialize logger logger = initialize_logger( os.path.join(out_path_sup, f'Fold_{k+1}/log.txt')) if os.path.exists( os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt')): logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' + '#' * 30) logger.info(f"Experiment : {cfg['exp_name']}") logger.info( f"Cross-Validation fold {k+1:02}/{cfg['Sup']['split']['n_fold']:02}" ) # extract train/test slice dataframe train_df = data_info_df[data_info_df.PatientNumber.isin( patient_df.loc[train_idx, 'PatientNumber'].values)] test_df = data_info_df[data_info_df.PatientNumber.isin( patient_df.loc[test_idx, 'PatientNumber'].values)] # samples train dataframe to adjuste 
negative/positive fractions n_remove = int( max( 0, len(train_df[train_df.Hemorrhage == 0]) - cfg['Sup']['dataset']['frac_negative'] * len(train_df[train_df.Hemorrhage == 1]))) df_remove = train_df[train_df.Hemorrhage == 0].sample( n=n_remove, random_state=seed) train_df = train_df[~train_df.index.isin(df_remove.index)] logger.info( '\n' + str(get_split_summary_table(data_info_df, train_df, test_df))) # Make datasets train_dataset = public_SegICH_Dataset2D( train_df, cfg['path']['data']['Sup'], augmentation_transform=[ getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg['Sup']['dataset']['augmentation']['train'].items() ], window=(cfg['data']['win_center'], cfg['data']['win_width']), output_size=cfg['data']['size']) test_dataset = public_SegICH_Dataset2D( test_df, cfg['path']['data']['Sup'], augmentation_transform=[ getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg['Sup']['dataset']['augmentation']['eval'].items() ], window=(cfg['data']['win_center'], cfg['data']['win_width']), output_size=cfg['data']['size']) logger.info( f"Data will be loaded from {cfg['path']['data']['Sup']}.") logger.info( f"CT scans will be windowed on [{cfg['data']['win_center']-cfg['data']['win_width']/2} ; {cfg['data']['win_center'] + cfg['data']['win_width']/2}]" ) logger.info( f"CT scans will be resized to {cfg['data']['size']}x{cfg['data']['size']}" ) logger.info( f"Training online data transformation: \n\n {str(train_dataset.transform)}\n" ) logger.info( f"Evaluation online data transformation: \n\n {str(test_dataset.transform)}\n" ) # Make U-Net architecture unet_sup = UNet( depth=cfg['Sup']['net']['depth'], top_filter=cfg['Sup']['net']['top_filter'], use_3D=cfg['Sup']['net']['3D'], in_channels=cfg['Sup']['net']['in_channels'], out_channels=cfg['Sup']['net']['out_channels'], bilinear=cfg['Sup']['net']['bilinear'], midchannels_factor=cfg['Sup']['net']['midchannels_factor'], p_dropout=cfg['Sup']['net']['p_dropout']) unet_sup.to(cfg['device']) logger.info( f"U-Net2D initialized with a depth of {cfg['Sup']['net']['depth']}" f" and a number of initial filter of {cfg['Sup']['net']['top_filter']}," ) logger.info( f"Reconstruction performed with {'Upsample + Conv' if cfg['Sup']['net']['bilinear'] else 'ConvTranspose'}." ) logger.info( f"U-Net2D takes {cfg['Sup']['net']['in_channels']} as input channels and {cfg['Sup']['net']['out_channels']} as output channels." ) logger.info( f"The U-Net2D has {sum(p.numel() for p in unet_sup.parameters())} parameters." ) # Make Model unet2D = UNet2D( unet_sup, n_epoch=cfg['Sup']['train']['n_epoch'], batch_size=cfg['Sup']['train']['batch_size'], lr=cfg['Sup']['train']['lr'], lr_scheduler=getattr(torch.optim.lr_scheduler, cfg['Sup']['train']['lr_scheduler']), lr_scheduler_kwargs=cfg['Sup']['train']['lr_scheduler_kwargs'], loss_fn=getattr(src.models.optim.LossFunctions, cfg['Sup']['train']['loss_fn']), loss_fn_kwargs=cfg['Sup']['train']['loss_fn_kwargs'], weight_decay=cfg['Sup']['train']['weight_decay'], num_workers=cfg['Sup']['train']['num_workers'], device=cfg['device'], print_progress=cfg['print_progress']) # ????? load model if specified ????? # transfer weights learn with context restoration logger.info( 'Initialize U-Net2D with weights learned with context_restoration on RSNA.' 
) unet2D.transfer_weights(pretrained_unet_weights, verbose=True) # Print training parameters train_params = [] for key, value in cfg['Sup']['train'].items(): train_params.append(f"--> {key} : {value}") logger.info('Training settings:\n\t' + '\n\t'.join(train_params)) # Train U-net eval_dataset = test_dataset if cfg['Sup']['train'][ 'validate_epoch'] else None unet2D.train(train_dataset, valid_dataset=eval_dataset, checkpoint_path=os.path.join( out_path_sup, f'Fold_{k+1}/checkpoint.pt')) # Evaluate U-Net unet2D.evaluate(test_dataset, save_path=os.path.join(out_path_sup, f'Fold_{k+1}/pred/')) # Save models and outputs unet2D.save_model( os.path.join(out_path_sup, f'Fold_{k+1}/trained_unet.pt')) logger.info( "Trained U-Net saved at " + os.path.join(out_path_sup, f'Fold_{k+1}/trained_unet.pt')) unet2D.save_outputs( os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json')) logger.info("Trained statistics saved at " + os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json')) # delete checkpoint if exists if os.path.exists( os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt')): os.remove( os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt')) logger.info('Checkpoint deleted.') # save mean +/- 1.96 std Dice over Folds save_mean_fold_dice(out_path_sup, cfg['Sup']['split']['n_fold']) logger.info('Average Scores saved at ' + os.path.join(out_path_sup, 'average_scores.txt')) # Save all volumes prediction csv df_list = [ pd.read_csv( os.path.join(out_path_sup, f'Fold_{i+1}/pred/volume_prediction_scores.csv')) for i in range(cfg['Sup']['split']['n_fold']) ] all_df = pd.concat(df_list, axis=0).reset_index(drop=True) all_df.to_csv(os.path.join(out_path_sup, 'all_volume_prediction.csv')) logger.info('CSV of all volumes prediction saved at ' + os.path.join(out_path_sup, 'all_volume_prediction.csv')) # Save config file cfg['device'] = str(cfg['device']) with open(os.path.join(out_path, 'config.json'), 'w') as fp: json.dump(cfg, fp) logger.info('Config file saved at ' + os.path.join(out_path, 'config.json')) # Analyse results analyse_supervised_exp(out_path_sup, cfg['path']['data']['Sup'], n_fold=cfg['Sup']['split']['n_fold'], config_folder=out_path, save_fn=os.path.join( out_path, 'results_supervised_overview.pdf')) logger.info('Results overview figure saved at ' + os.path.join(out_path, 'results_supervised_overview.pdf')) analyse_representation_exp(out_path_selfsup, save_fn=os.path.join( out_path, 'results_self-supervised_overview.pdf')) logger.info('Results overview figure saved at ' + os.path.join(out_path, 'results_self-supervised_overview.pdf'))
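# `save_mean_fold_dice` is called by the fine-tuning scripts above but defined elsewhere.
# A sketch is given below, assuming it reproduces the inline aggregation of the first
# script: read every fold's outputs.json, take the 'all' and 'positive' Dice entries, and
# write their mean +/- 1.96*std to average_scores.txt. The actual helper may differ.
import json
import os
import numpy as np

def save_mean_fold_dice(out_path, n_fold):
    """Aggregate per-fold Dice scores and save them as mean +/- 1.96*std."""
    scores = []
    for k in range(n_fold):
        with open(os.path.join(out_path, f'Fold_{k+1}/outputs.json'), 'r') as f:
            out = json.load(f)
        scores.append([out['eval']['dice']['all'], out['eval']['dice']['positive']])
    scores = np.array(scores)
    means, ci95 = scores.mean(axis=0), 1.96 * scores.std(axis=0)
    with open(os.path.join(out_path, 'average_scores.txt'), 'w') as f:
        f.write(f'Dice = {means[0]} +/- {ci95[0]}\n')
        f.write(f'Dice (Positive) = {means[1]} +/- {ci95[1]}\n')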
def main(config_path): """ """ # load config cfg = AttrDict.from_json_path(config_path) # make outputs dir out_path = os.path.join(cfg.path.output, cfg.exp_name) os.makedirs(out_path, exist_ok=True) # initialize seed if cfg.seed != -1: random.seed(cfg.seed) np.random.seed(cfg.seed) torch.manual_seed(cfg.seed) torch.cuda.manual_seed(cfg.seed) torch.cuda.manual_seed_all(cfg.seed) torch.backends.cudnn.deterministic = True # initialize logger logger = initialize_logger(os.path.join(out_path, 'log.txt')) logger.info(f"Experiment : {cfg.exp_name}") # set device if cfg.device: cfg.device = torch.device(cfg.device) else: cfg.device = torch.device(f'cuda:0') if torch.cuda.is_available() else torch.device('cpu') logger.info(f"Device set to {cfg.device}.") #------------------------------------------- # Make Dataset #------------------------------------------- data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'ct_info.csv'), index_col=0) dataset = public_SegICH_Dataset2D(data_info_df, cfg.path.data, augmentation_transform=[getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg.data.augmentation.items()], output_size=cfg.data.size, window=(cfg.data.win_center, cfg.data.win_width)) #------------------------------------------- # Load FCDD Model #------------------------------------------- cfg_fcdd = AttrDict.from_json_path(cfg.fcdd_cfg_path) fcdd_net = FCDD_CNN_VGG(in_shape=(cfg_fcdd.net.in_channels, 256, 256), bias=cfg_fcdd.net.bias) loaded_state_dict = torch.load(cfg.fcdd_model_path, map_location=cfg.device) fcdd_net.load_state_dict(loaded_state_dict) fcdd_net = fcdd_net.to(cfg.device).eval() logger.info(f"FCDD model succesfully loaded from {cfg.fcdd_model_path}") # make FCDD object fcdd = FCDD(fcdd_net, batch_size=cfg.batch_size, num_workers=cfg.num_workers, device=cfg.device, print_progress=cfg.print_progress) #------------------------------------------- # Load Classifier Model #------------------------------------------- # Load Classifier if cfg.classifier_model_path is not None: cfg_classifier = AttrDict.from_json_path(os.path.join(cfg.classifier_model_path, 'config.json')) classifier = getattr(rn, cfg_classifier.net.resnet)(num_classes=cfg_classifier.net.num_classes, input_channels=cfg_classifier.net.input_channels) classifier_state_dict = torch.load(os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt'), map_location=cfg.device) classifier.load_state_dict(classifier_state_dict) classifier = classifier.to(cfg.device) classifier.eval() logger.info(f"ResNet classifier model succesfully loaded from {os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt')}") #------------------------------------------- # Generate Heat-Map for each slice #------------------------------------------- with torch.no_grad(): # make loader loader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=False, worker_init_fn=lambda _: np.random.seed()) fcdd_net.eval() min_val, max_val = fcdd.get_min_max(loader, **cfg.heatmap_param) # computing and saving heatmaps out = dict(id=[], slice=[], label=[], ad_map_fn=[], ad_mask_fn=[], TP=[], TN=[], FP=[], FN=[], AUC=[], classifier_pred=[]) for b, data in enumerate(loader): im, mask, id, slice = data im = im.to(cfg.device).float() mask = mask.to(cfg.device).float() # get heatmap heatmap = fcdd.generate_heatmap(im, reception=cfg.heatmap_param.reception, std=cfg.heatmap_param.std, cpu=cfg.heatmap_param.cpu) # scaling heatmap = ((heatmap - min_val) / (max_val - min_val)).clamp(0,1) # Threshold ad_mask = 
torch.where(heatmap >= cfg.heatmap_threshold, torch.ones_like(heatmap, device=heatmap.device), torch.zeros_like(heatmap, device=heatmap.device)) # Compute CM tn, fp, fn, tp = batch_binary_confusion_matrix(ad_mask, mask.to(heatmap.device)) # Save heatmaps/mask map_fn, mask_fn = [], [] for i in range(im.shape[0]): # Save AD Map ad_map_fn = f"{id[i]}/{slice[i]}_map_anomalies.png" save_path = os.path.join(out_path, 'pred/', ad_map_fn) if not os.path.isdir(os.path.dirname(save_path)): os.makedirs(os.path.dirname(save_path), exist_ok=True) ad_map = heatmap[i].squeeze().cpu().numpy() io.imsave(save_path, img_as_ubyte(ad_map), check_contrast=False) # save ad_mask ad_mask_fn = f"{id[i]}/{slice[i]}_anomalies.bmp" save_path = os.path.join(out_path, 'pred/', ad_mask_fn) io.imsave(save_path, img_as_ubyte(ad_mask[i].squeeze().cpu().numpy()), check_contrast=False) map_fn.append(ad_map_fn) mask_fn.append(ad_mask_fn) # apply classifier ResNet-18 if cfg.classifier_model_path is not None: pred_score = nn.functional.softmax(classifier(im), dim=1)[:,1] # take columns of softmax of positive class as score clss_pred = torch.where(pred_score >= cfg.classification_threshold, torch.ones_like(pred_score, device=pred_score.device), torch.zeros_like(pred_score, device=pred_score.device)) else: clss_pred = [None]*im.shape[0] # Save Values out['id'] += id.cpu().tolist() out['slice'] += slice.cpu().tolist() out['label'] += mask.reshape(mask.shape[0], -1).max(dim=1)[0].cpu().tolist() out['ad_map_fn'] += map_fn out['ad_mask_fn'] += mask_fn out['TN'] += tn.cpu().tolist() out['FP'] += fp.cpu().tolist() out['FN'] += fn.cpu().tolist() out['TP'] += tp.cpu().tolist() out['AUC'] += [roc_auc_score(mask[i].cpu().numpy().ravel(), heatmap[i].cpu().numpy().ravel()) if torch.any(mask[i]>0) else 'None' for i in range(im.shape[0])] out['classifier_pred'] += clss_pred.cpu().tolist() if cfg.print_progress: print_progessbar(b, len(loader), Name='Heatmap Generation Batch', Size=100, erase=True) # make df and save as csv slice_df = pd.DataFrame(out) volume_df = slice_df[['id', 'label', 'TP', 'TN', 'FP', 'FN']].groupby('id').agg({'label':'max', 'TP':'sum', 'TN':'sum', 'FP':'sum', 'FN':'sum'}) slice_df['Dice'] = (2*slice_df.TP + 1) / (2*slice_df.TP + slice_df.FP + slice_df.FN + 1) volume_df['Dice'] = (2*volume_df.TP + 1) / (2*volume_df.TP + volume_df.FP + volume_df.FN + 1) logger.info(f"Mean slice dice : {slice_df.Dice.mean(axis=0):.3f}") logger.info(f"Mean volume dice : {volume_df.Dice.mean(axis=0):.3f}") logger.info(f"Mean posiitve slice AUC {slice_df[slice_df.label == 1].AUC.mean(axis=0):.3f}") # Save Scores and Config slice_df.to_csv(os.path.join(out_path, 'slice_predictions.csv')) logger.info(f"Slice prediction csv saved at {os.path.join(out_path, 'slice_predictions.csv')}") volume_df.to_csv(os.path.join(out_path, 'volume_predictions.csv')) logger.info(f"Volume prediction csv saved at {os.path.join(out_path, 'volume_predictions.csv')}") cfg.device = str(cfg.device) with open(os.path.join(out_path, 'config.json'), 'w') as f: json.dump(cfg, f) logger.info(f"Config file saved at {os.path.join(out_path, 'config.json')}")
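# `batch_binary_confusion_matrix` is used in the FCDD script above but not defined in
# this file. A sketch is given below, assuming it returns per-sample TN/FP/FN/TP counts
# for batches of binary prediction and target masks of shape (B, ...); the repository's
# actual implementation may differ.
import torch

def batch_binary_confusion_matrix(pred, target):
    """Return per-sample (tn, fp, fn, tp) count tensors for batches of binary masks."""
    pred = pred.reshape(pred.shape[0], -1).bool()
    target = target.reshape(target.shape[0], -1).bool()
    tp = (pred & target).sum(dim=1)
    tn = (~pred & ~target).sum(dim=1)
    fp = (pred & ~target).sum(dim=1)
    fn = (~pred & target).sum(dim=1)
    return tn, fp, fn, tp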
def main(config_path): """ Segmente ICH using the anomaly inpainting approach on a whole dataset and compute slice/volume dice. """ # load config cfg = AttrDict.from_json_path(config_path) # make outputs dir out_path = os.path.join(cfg.path.output, cfg.exp_name) os.makedirs(out_path, exist_ok=True) # initialize seed if cfg.seed != -1: random.seed(cfg.seed) np.random.seed(cfg.seed) torch.manual_seed(cfg.seed) torch.cuda.manual_seed(cfg.seed) torch.cuda.manual_seed_all(cfg.seed) torch.backends.cudnn.deterministic = True # set device cfg.device = torch.device( cfg.device) if cfg.device else get_available_device() # initialize logger logger = initialize_logger(os.path.join(out_path, 'log.txt')) logger.info(f"Experiment : {cfg.exp_name}") # get Dataset data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'ct_info.csv'), index_col=0) #data_info_df = data_info_df[sum([data_info_df.CT_fn.str.contains(s) for s in ['49/16', '51/39', '71/15', '71/22', '75/22']]) > 0] dataset = public_SegICH_Dataset2D( data_info_df, cfg.path.data, augmentation_transform=[ getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg.data.augmentation.items() ], output_size=cfg.data.size, window=(cfg.data.win_center, cfg.data.win_width)) # load inpainting model cfg_inpaint = AttrDict.from_json_path(cfg.inpainter_cfg_path) cfg_inpaint.net.gen.return_coarse = False inpaint_net = SAGatedGenerator(**cfg_inpaint.net.gen) loaded_state_dict = torch.load(cfg.inpainter_model_path, map_location=cfg.device) inpaint_net.load_state_dict(loaded_state_dict) inpaint_net = inpaint_net.to( cfg.device ) # inpainter not in eval mode beacuse batch norm layers are not stabilized (because GAN optimization) logger.info( f"Inpainter model succesfully loaded from {cfg.inpainter_model_path}") # make AD inpainter Module ad_inpainter = InpaintAnomalyDetector(inpaint_net, device=cfg.device, **cfg.model_param) # Load Classifier if cfg.classifier_model_path is not None: cfg_classifier = AttrDict.from_json_path( os.path.join(cfg.classifier_model_path, 'config.json')) classifier = getattr(rn, cfg_classifier.net.resnet)( num_classes=cfg_classifier.net.num_classes, input_channels=cfg_classifier.net.input_channels) classifier_state_dict = torch.load(os.path.join( cfg.classifier_model_path, 'resnet_state_dict.pt'), map_location=cfg.device) classifier.load_state_dict(classifier_state_dict) classifier = classifier.to(cfg.device) classifier.eval() logger.info( f"ResNet classifier model succesfully loaded from {os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt')}" ) # iterate over dataset all_pred = [] for i, sample in enumerate(dataset): # unpack data image, target, id, slice = sample logger.info( "=" * 25 + f" SAMPLE {i+1:04}/{len(dataset):04} - Volume {id:03} Slice {slice:03} " + "=" * 25) # Classify sample if cfg.classifier_model_path is not None: with torch.no_grad(): input_clss = image.unsqueeze(0).to(cfg.device).float() pred_score = nn.functional.softmax( classifier(input_clss), dim=1 )[:, 1] # take columns of softmax of positive class as score pred = 1 if pred_score >= cfg.classification_threshold else 0 else: pred = 1 # if not classifier given, all slices are processed # process slice if classifier has detected Hemorrhage if pred == 1: logger.info( f"ICH detected. 
Compute anomaly mask through inpainting.") # Detect anomalies using the robuste approach ad_mask, ano_map, intermediate_masks = robust_anomaly_detect( image, ad_inpainter, save_dir=None, verbose=True, return_intermediate=True, **cfg.robust_param) # save ad_mask ad_mask_fn = f"{id}/{slice}_anomalies.bmp" save_path = os.path.join(out_path, 'pred/', ad_mask_fn) if not os.path.isdir(os.path.dirname(save_path)): os.makedirs(os.path.dirname(save_path), exist_ok=True) io.imsave(save_path, img_as_ubyte(ad_mask), check_contrast=False) # save anomaly map ad_map_fn = f"{id}/{slice}_map_anomalies.png" save_path_map = os.path.join(out_path, 'pred/', ad_map_fn) io.imsave(save_path_map, img_as_ubyte(ano_map), check_contrast=False) # save intermediate mask for j, m in enumerate(intermediate_masks): if not os.path.isdir( os.path.join(out_path, f"pred/{id}/intermediate_masks/")): os.makedirs(os.path.join(out_path, f"pred/{id}/intermediate_masks/"), exist_ok=True) io.imsave(os.path.join( out_path, f"pred/{id}/intermediate_masks/{slice}_anomalies_{j+1}.bmp" ), img_as_ubyte(m), check_contrast=False) else: logger.info(f"No ICH detected. Set the anomaly mask to zeros.") ad_mask = np.zeros_like(target[0].numpy()) ad_mask_fn, ad_map_fn = 'None', 'None' # compute confusion matrix with target ICH mask tn, fp, fn, tp = confusion_matrix(target[0].numpy().ravel(), ad_mask.ravel(), labels=[0, 1]).ravel() # append to all_pred list all_pred.append({ 'id': id.item(), 'slice': slice.item(), 'label': target.max().item(), 'TP': tp, 'TN': tn, 'FP': fp, 'FN': fn, 'ad_mask_fn': ad_mask_fn, 'ad_map_fn': ad_map_fn }) # make a dataframe of all predictions slice_df = pd.DataFrame(all_pred) volume_df = slice_df[['id', 'label', 'TP', 'TN', 'FP', 'FN']].groupby('id').agg({ 'label': 'max', 'TP': 'sum', 'TN': 'sum', 'FP': 'sum', 'FN': 'sum' }) # Compute Dice and Volume Dice slice_df['Dice'] = (2 * slice_df.TP + 1) / (2 * slice_df.TP + slice_df.FP + slice_df.FN + 1) volume_df['Dice'] = (2 * volume_df.TP + 1) / ( 2 * volume_df.TP + volume_df.FP + volume_df.FN + 1) logger.info(f"Mean slice dice : {slice_df.Dice.mean(axis=0):.3f}") logger.info(f"Mean volume dice : {volume_df.Dice.mean(axis=0):.3f}") # Save Scores and Config slice_df.to_csv(os.path.join(out_path, 'slice_predictions.csv')) logger.info( f"Slice prediction csv saved at {os.path.join(out_path, 'slice_predictions.csv')}" ) volume_df.to_csv(os.path.join(out_path, 'volume_predictions.csv')) logger.info( f"Volume prediction csv saved at {os.path.join(out_path, 'volume_predictions.csv')}" ) cfg.device = str(cfg.device) with open(os.path.join(out_path, 'config.json'), 'w') as f: json.dump(cfg, f) logger.info( f"Config file saved at {os.path.join(out_path, 'config.json')}")
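# `get_split_summary_table` is logged by several of the training scripts above but not
# defined in this file. A sketch is given below, assuming it only tabulates the number of
# positive (Hemorrhage == 1) and negative slices in the full, train and test DataFrames;
# the repository's helper may format this table differently.
import pandas as pd

def get_split_summary_table(full_df, train_df, test_df):
    """Return a DataFrame summarising positive/negative slice counts per split."""
    rows = {}
    for name, df in [('all', full_df), ('train', train_df), ('test', test_df)]:
        pos = int((df.Hemorrhage == 1).sum())
        rows[name] = {'n_slices': len(df), 'positive': pos, 'negative': len(df) - pos}
    return pd.DataFrame(rows).T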