def main(config_path):
    """
    Train an FCDD on the RSNA dataset.

    The JSON config at config_path drives the whole run: output folder, seed, device, dataset
    sampling, network and training hyper-parameters. The function samples the train (and,
    if cfg.train.validate_epoch is set, validation) slices from 'slice_info.csv', trains the
    FCDD model, optionally saves heatmaps for up to 100 validation samples, and finally writes
    the model, the outputs and the config to the experiment folder.

    Args:
        config_path: path to the JSON configuration file of the experiment.
    """
    # Load config file
    cfg = AttrDict.from_json_path(config_path)

    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)
    if cfg.train.validate_epoch:
        os.makedirs(os.path.join(out_path, 'valid_results/'), exist_ok=True)
    # One single checkpoint path used both to detect a recovered session and for training.
    # (The detection used to look for 'checkpoint.pt' while training wrote 'Checkpoint.pt'.)
    checkpoint_path = os.path.join(out_path, 'Checkpoint.pt')

    # initialize seed (a seed of None means: do not seed anything)
    if cfg.seed is not None:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    if os.path.exists(checkpoint_path):
        logger.info('\n' + '#'*30 + f'\n Recovering Session \n' + '#'*30)
    logger.info(f"Experiment : {cfg.exp_name}")

    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device(f'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(f"Device set to {cfg.device}.")

    #--------------------------------------------------------------------
    #                           Make Datasets
    #--------------------------------------------------------------------
    # load RSNA data, sample the required number of normal slices (Hemorrhage == 0)
    # and, if requested, of abnormal slices (Hemorrhage == 1)
    df_rsna = pd.read_csv(os.path.join(cfg.path.data, 'slice_info.csv'), index_col=0)
    df_train = df_rsna[df_rsna.Hemorrhage == 0].sample(n=cfg.dataset.n_normal,
                                                       random_state=cfg.seed)
    if cfg.dataset.n_abnormal > 0:
        df_abnormal = df_rsna[df_rsna.Hemorrhage == 1].sample(n=cfg.dataset.n_abnormal,
                                                              random_state=cfg.seed)
        df_train = pd.concat([df_train, df_abnormal], axis=0)

    # df for validation : sampled among the slices not used for training
    if cfg.train.validate_epoch:
        df_rsna_remain = df_rsna[~df_rsna.index.isin(df_train.index)]
        df_valid = df_rsna_remain[df_rsna_remain.Hemorrhage == 0].sample(
            n=cfg.dataset.n_normal_valid, random_state=cfg.seed)
        if cfg.dataset.n_abnormal_valid > 0:
            df_abnormal = df_rsna_remain[df_rsna_remain.Hemorrhage == 1].sample(
                n=cfg.dataset.n_abnormal_valid, random_state=cfg.seed)
            df_valid = pd.concat([df_valid, df_abnormal], axis=0)

    # Make FCDD dataset
    train_transforms = [getattr(tf, tf_name)(**tf_kwargs)
                        for tf_name, tf_kwargs in cfg.dataset.augmentation.train.items()]
    train_dataset = RSNA_FCDD_dataset(df_train, cfg.path.data,
                                      artificial_anomaly=cfg.dataset.artificial_anomaly,
                                      anomaly_proba=cfg.dataset.anomaly_proba,
                                      augmentation_transform=train_transforms,
                                      window=(cfg.dataset.win_center, cfg.dataset.win_width),
                                      output_size=cfg.dataset.size,
                                      drawing_params=cfg.dataset.drawing_params)
    if cfg.train.validate_epoch:
        eval_transforms = [getattr(tf, tf_name)(**tf_kwargs)
                           for tf_name, tf_kwargs in cfg.dataset.augmentation.eval.items()]
        valid_dataset = RSNA_FCDD_dataset(df_valid, cfg.path.data,
                                          artificial_anomaly=cfg.dataset.artificial_anomaly_valid,
                                          anomaly_proba=cfg.dataset.anomaly_proba,
                                          augmentation_transform=eval_transforms,
                                          window=(cfg.dataset.win_center, cfg.dataset.win_width),
                                          output_size=cfg.dataset.size,
                                          drawing_params=cfg.dataset.drawing_params)
    else:
        valid_dataset = None

    logger.info(f"Data loaded from {cfg.path.data}.")
    logger.info(f"Train set contains {len(train_dataset)} samples.")
    # compare to None explicitly: truth-testing a dataset would silently hide an empty one
    if valid_dataset is not None:
        logger.info(f"Valid set contains {len(valid_dataset)} samples.")
    logger.info(f"CT scans will be windowed on [{cfg.dataset.win_center-cfg.dataset.win_width/2} ; {cfg.dataset.win_center + cfg.dataset.win_width/2}]")
    logger.info(f"CT scans will be resized to {cfg.dataset.size}x{cfg.dataset.size}")
    logger.info(f"Training online data transformation: \n\n {str(train_dataset.transform)}\n")
    if valid_dataset is not None:
        logger.info(f"Evaluation online data transformation: \n\n {str(valid_dataset.transform)}\n")
    if cfg.dataset.artificial_anomaly:
        draw_params = [f"--> {k} : {v}" for k, v in cfg.dataset.drawing_params.items()]
        logger.info("Artificial Anomaly drawing parameters \n\t" + "\n\t".join(draw_params))

    #--------------------------------------------------------------------
    #                           Make Networks
    #--------------------------------------------------------------------
    net = FCDD_CNN_VGG(in_shape=[cfg.net.in_channels, cfg.dataset.size, cfg.dataset.size],
                       bias=cfg.net.bias)

    #--------------------------------------------------------------------
    #                          Make FCDD model
    #--------------------------------------------------------------------
    # convert scheduler name to scheduler class object
    cfg.train.model_param.lr_scheduler = getattr(torch.optim.lr_scheduler,
                                                 cfg.train.model_param.lr_scheduler)
    model = FCDD(net, print_progress=cfg.print_progress, device=cfg.device,
                 **cfg.train.model_param)
    train_params = [f"--> {k} : {v}" for k, v in cfg.train.model_param.items()]
    logger.info("FCDD Training Parameters \n\t" + "\n\t".join(train_params))

    # load models if provided
    if cfg.train.model_path_to_load:
        model.load_model(cfg.train.model_path_to_load, map_location=cfg.device)
        logger.info(f"FCDD Model loaded from {cfg.train.model_path_to_load}")

    #--------------------------------------------------------------------
    #                          Train FCDD model
    #--------------------------------------------------------------------
    if cfg.train.model_param.n_epoch > 0:
        model.train(train_dataset, checkpoint_path=checkpoint_path, valid_dataset=valid_dataset)

    #--------------------------------------------------------------------
    #           Generate and save few Heatmap with FCDD model
    #--------------------------------------------------------------------
    if cfg.train.validate_epoch:
        if len(valid_dataset) > 100:
            # keep a random subset of 100 samples; the split is reproducible only if a seed
            # is given (manual_seed(None) would raise, so draw a fresh seed in that case)
            generator = torch.Generator()
            if cfg.seed is not None:
                generator.manual_seed(cfg.seed)
            else:
                generator.seed()
            valid_subset = torch.utils.data.random_split(
                valid_dataset, [100, len(valid_dataset) - 100], generator=generator)[0]
        else:
            valid_subset = valid_dataset
        model.localize_anomalies(valid_subset,
                                 save_path=os.path.join(out_path, 'valid_results/'),
                                 **cfg.train.heatmap_param)

    #--------------------------------------------------------------------
    #                   Save outputs, models and config
    #--------------------------------------------------------------------
    # save models
    model.save_model(export_fn=os.path.join(out_path, 'FCDD.pt'))
    logger.info("FCDD model saved at " + os.path.join(out_path, 'FCDD.pt'))
    # save outputs
    model.save_outputs(export_fn=os.path.join(out_path, 'outputs.json'))
    logger.info("Outputs file saved at " + os.path.join(out_path, 'outputs.json'))
    # save config file
    cfg.device = str(cfg.device)  # set device as string to be JSON serializable
    cfg.train.model_param.lr_scheduler = str(cfg.train.model_param.lr_scheduler)
    with open(os.path.join(out_path, 'config.json'), 'w') as fp:
        json.dump(cfg, fp)
    logger.info("Config file saved at " + os.path.join(out_path, 'config.json'))
def main(config_path):
    """
    Evaluate the auto-encoder based anomaly detection on the slices listed in 'ct_info.csv'.

    For every slice of the dataset, an optional ResNet classifier first decides whether the
    slice should be processed (all slices are processed if cfg.classifier_model_path is None).
    Processed slices get an anomaly map and a binary anomaly mask from compute_anomaly, both
    saved under '<out_path>/pred/'. Rejected slices get an all-zero mask and map. The mask is
    compared with the target mask (confusion matrix, AUC on slices with a non-empty target) and
    the slice-level and volume-level scores are saved as CSV together with the config.

    Args:
        config_path: path to the JSON configuration file of the experiment.
    """
    # load config
    cfg = AttrDict.from_json_path(config_path)

    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)

    # initialize seed (-1 means: do not seed)
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    logger.info(f"Experiment : {cfg.exp_name}")

    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device(f'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(f"Device set to {cfg.device}.")

    # get Dataset
    data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'ct_info.csv'), index_col=0)
    transforms = [getattr(tf, tf_name)(**tf_kwargs)
                  for tf_name, tf_kwargs in cfg.data.augmentation.items()]
    dataset = public_SegICH_Dataset2D(data_info_df, cfg.path.data,
                                      augmentation_transform=transforms,
                                      output_size=cfg.data.size,
                                      window=(cfg.data.win_center, cfg.data.win_width))

    # load the AE model
    cfg_ae = AttrDict.from_json_path(cfg.ae_cfg_path)
    ae_net = AE_net(**cfg_ae.net)
    loaded_state_dict = torch.load(cfg.ae_model_path, map_location=cfg.device)
    ae_net.load_state_dict(loaded_state_dict)
    ae_net = ae_net.to(cfg.device).eval()
    logger.info(f"AE model succesfully loaded from {cfg.ae_model_path}")

    # Load Classifier (optional)
    use_classifier = cfg.classifier_model_path is not None
    if use_classifier:
        cfg_classifier = AttrDict.from_json_path(
            os.path.join(cfg.classifier_model_path, 'config.json'))
        classifier = getattr(rn, cfg_classifier.net.resnet)(
            num_classes=cfg_classifier.net.num_classes,
            input_channels=cfg_classifier.net.input_channels)
        classifier_state_dict = torch.load(
            os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt'),
            map_location=cfg.device)
        classifier.load_state_dict(classifier_state_dict)
        classifier = classifier.to(cfg.device)
        classifier.eval()
        logger.info(
            f"ResNet classifier model succesfully loaded from {os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt')}"
        )

    # iterate over dataset
    all_pred = []
    for i, sample in enumerate(dataset):
        # unpack data (names chosen not to shadow the builtins id and slice)
        image, target, vol_id, slice_idx = sample
        logger.info(
            "=" * 25 +
            f" SAMPLE {i+1:04}/{len(dataset):04} - Volume {vol_id:03} Slice {slice_idx:03} " +
            "=" * 25)

        # Classify sample
        if use_classifier:
            with torch.no_grad():
                input_clss = image.unsqueeze(0).to(cfg.device).float()
                # take columns of softmax of positive class as score
                pred_score = nn.functional.softmax(classifier(input_clss), dim=1)[:, 1]
                pred = 1 if pred_score >= cfg.classification_threshold else 0
        else:
            pred = 1  # if no classifier given, all slices are processed

        # process slice if classifier has detected Hemorrhage
        if pred == 1:
            logger.info(
                f"ICH detected. Compute anomaly mask through AE reconstruction."
            )
            # Detect anomalies with the AE
            ad_map, ad_mask = compute_anomaly(ae_net, image,
                                              alpha_low=cfg.alpha_low,
                                              alpha_high=cfg.alpha_high,
                                              device=cfg.device)
            logger.info(f"{ad_mask.sum()} anomalous pixels detected.")
            # save ad_mask
            ad_mask_fn = f"{vol_id}/{slice_idx}_anomalies.bmp"
            save_path = os.path.join(out_path, 'pred/', ad_mask_fn)
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            io.imsave(save_path, img_as_ubyte(ad_mask), check_contrast=False)
            # save anomaly map
            ad_map_fn = f"{vol_id}/{slice_idx}_map_anomalies.png"
            save_path_map = os.path.join(out_path, 'pred/', ad_map_fn)
            io.imsave(save_path_map,
                      img_as_ubyte(rescale_intensity(ad_map, out_range=(0.0, 1.0))),
                      check_contrast=False)
        else:
            logger.info(f"No ICH detected. Set the anomaly mask to zeros.")
            ad_mask = np.zeros_like(target[0].numpy())
            # The anomaly map must be reset as well: otherwise the AUC below would be computed
            # with the map of a previously processed slice (or fail with a NameError when the
            # very first slice is rejected). A constant map carries no ranking information.
            ad_map = np.zeros_like(target[0].numpy(), dtype=float)
            ad_mask_fn, ad_map_fn = 'None', 'None'

        # compute confusion matrix with target ICH mask
        target_flat = target[0].numpy().ravel()
        tn, fp, fn, tp = confusion_matrix(target_flat, ad_mask.ravel(), labels=[0, 1]).ravel()
        # the AUC is only computed on slices whose target mask is not empty
        auc = roc_auc_score(target_flat, ad_map.ravel()) if torch.any(target[0]) else 'None'

        # append to all_pred list
        all_pred.append({
            'id': vol_id.item(),
            'slice': slice_idx.item(),
            'label': target.max().item(),
            'TP': tp,
            'TN': tn,
            'FP': fp,
            'FN': fn,
            'AUC': auc,
            'ad_mask_fn': ad_mask_fn,
            'ad_map_fn': ad_map_fn
        })

    # make a dataframe of all predictions and aggregate the counts per volume
    slice_df = pd.DataFrame(all_pred)
    volume_df = slice_df[['id', 'label', 'TP', 'TN', 'FP', 'FN']].groupby('id').agg({
        'label': 'max',
        'TP': 'sum',
        'TN': 'sum',
        'FP': 'sum',
        'FN': 'sum'
    })

    # Compute Dice and Volume Dice (the +1 keeps the ratio defined when TP = FP = FN = 0)
    slice_df['Dice'] = (2 * slice_df.TP + 1) / (2 * slice_df.TP + slice_df.FP + slice_df.FN + 1)
    volume_df['Dice'] = (2 * volume_df.TP + 1) / (2 * volume_df.TP + volume_df.FP + volume_df.FN + 1)
    logger.info(f"Mean slice dice : {slice_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean volume dice : {volume_df.Dice.mean(axis=0):.3f}")
    logger.info(
        f"Mean posiitve slice AUC {slice_df[slice_df.label == 1].AUC.mean(axis=0):.3f}"
    )

    # Save Scores and Config
    slice_df.to_csv(os.path.join(out_path, 'slice_predictions.csv'))
    logger.info(
        f"Slice prediction csv saved at {os.path.join(out_path, 'slice_predictions.csv')}"
    )
    volume_df.to_csv(os.path.join(out_path, 'volume_predictions.csv'))
    logger.info(
        f"Volume prediction csv saved at {os.path.join(out_path, 'volume_predictions.csv')}"
    )
    cfg.device = str(cfg.device)  # set device as string to be JSON serializable
    with open(os.path.join(out_path, 'config.json'), 'w') as f:
        json.dump(cfg, f)
    logger.info(
        f"Config file saved at {os.path.join(out_path, 'config.json')}")
def main(config_path):
    """
    UNet2D pretrained on binary classification with the RSNA dataset and finetuned with the
    public data.

    The run has two stages, both driven by the JSON config at config_path:
      1. a U-Net encoder is trained as a classifier on slices sampled from the RSNA
         'slice_info.csv' table and evaluated on a stratified hold-out split;
      2. a U-Net initialized with the weights of stage 1 is trained and evaluated with a
         stratified K-fold cross-validation at the patient level on the annotated data
         ('ct_info.csv' / 'patient_info.csv'). Folds whose 'outputs.json' already exists
         are skipped.
    Models, outputs, aggregated scores, the config and the result overview figures are saved
    in the experiment output folder.

    Args:
        config_path: path to the JSON configuration file of the experiment.
    """
    # load the config file
    cfg = AttrDict.from_json_path(config_path)

    # Make Outputs directories
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    out_path_selfsup = os.path.join(out_path, 'classification_pretrain/')
    out_path_sup = os.path.join(out_path, 'supervised_train/')
    os.makedirs(out_path_selfsup, exist_ok=True)
    for k in range(cfg.Sup.split.n_fold):
        os.makedirs(os.path.join(out_path_sup, f'Fold_{k+1}/pred/'), exist_ok=True)

    # Initialize random seed (-1 means: do not seed)
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True
    # -1 is not a valid random_state for pandas / scikit-learn: use None (unseeded) instead
    sampling_seed = cfg.seed if cfg.seed != -1 else None

    # Set number of thread
    if cfg.n_thread > 0:
        torch.set_num_threads(cfg.n_thread)

    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = get_available_device()

    ####################################################
    # Self-supervised training on Multi Classification #
    ####################################################
    ssl_checkpoint = os.path.join(out_path_selfsup, f'checkpoint.pt')

    # Initialize Logger
    logger = initialize_logger(os.path.join(out_path_selfsup, 'log.txt'))
    if os.path.exists(ssl_checkpoint):
        logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' + '#' * 30)
    logger.info(f"Experiment : {cfg.exp_name}")

    # Load RSNA data csv
    df_rsna = pd.read_csv(os.path.join(cfg.path.data.SSL, 'slice_info.csv'), index_col=0)

    # Keep only fractions sample (a negative value means: keep everything)
    if cfg.SSL.dataset.n_data_1 >= 0:
        df_rsna_ICH = df_rsna[df_rsna.Hemorrhage == 1].sample(n=cfg.SSL.dataset.n_data_1,
                                                              random_state=sampling_seed)
    else:
        df_rsna_ICH = df_rsna[df_rsna.Hemorrhage == 1]
    if cfg.SSL.dataset.f_data_0 >= 0:
        # nbr of normal as fraction of nbr of ICH samples (cast to int: the fraction may be
        # a float, and sample() only accepts an integer number of rows)
        df_rsna_noICH = df_rsna[df_rsna.Hemorrhage == 0].sample(
            n=int(cfg.SSL.dataset.f_data_0 * len(df_rsna_ICH)), random_state=sampling_seed)
    else:
        df_rsna_noICH = df_rsna[df_rsna.Hemorrhage == 0]
    df_rsna = pd.concat([df_rsna_ICH, df_rsna_noICH], axis=0)

    # Split data to keep few for evaluation in a stratified way
    train_df, test_df = train_test_split(df_rsna,
                                         test_size=cfg.SSL.dataset.frac_eval,
                                         stratify=df_rsna.Hemorrhage,
                                         random_state=sampling_seed)
    logger.info('\n' + str(get_split_summary_table(df_rsna, train_df, test_df)))

    # Make datasets (both in 'multi_classification' mode)
    train_RSNA_dataset = RSNA_dataset(
        train_df, cfg.path.data.SSL,
        augmentation_transform=[
            getattr(tf, tf_name)(**tf_kwargs)
            for tf_name, tf_kwargs in cfg.SSL.dataset.augmentation.train.items()
        ],
        window=(cfg.data.win_center, cfg.data.win_width),
        output_size=cfg.data.size,
        mode='multi_classification')
    test_RSNA_dataset = RSNA_dataset(
        test_df, cfg.path.data.SSL,
        augmentation_transform=[
            getattr(tf, tf_name)(**tf_kwargs)
            for tf_name, tf_kwargs in cfg.SSL.dataset.augmentation.eval.items()
        ],
        window=(cfg.data.win_center, cfg.data.win_width),
        output_size=cfg.data.size,
        mode='multi_classification')

    logger.info(f"Data will be loaded from {cfg.path.data.SSL}.")
    logger.info(
        f"CT scans will be windowed on [{cfg.data.win_center-cfg.data.win_width/2} ; {cfg.data.win_center + cfg.data.win_width/2}]"
    )
    logger.info(f"CT scans will be resized to {cfg.data.size}x{cfg.data.size}")
    logger.info(
        f"Training online data transformation: \n\n {str(train_RSNA_dataset.transform)}\n"
    )
    logger.info(
        f"Evaluation online data transformation: \n\n {str(test_RSNA_dataset.transform)}\n"
    )

    # Make U-Net-Encoder architecture
    net_ssl = UNet_Encoder(**cfg.SSL.net).to(cfg.device)
    net_params = [f"--> {k} : {v}" for k, v in cfg.SSL.net.items()]
    logger.info("UNet like Multi Classifier \n\t" + "\n\t".join(net_params))
    logger.info(
        f"The Multi Classifier has {sum(p.numel() for p in net_ssl.parameters())} parameters."
    )

    # Make Model
    ssl_param = cfg.SSL.train.model_param
    # convert scheduler name to scheduler class object
    ssl_param.lr_scheduler = getattr(torch.optim.lr_scheduler, ssl_param.lr_scheduler)
    if ssl_param.loss_fn == 'BCEWithLogitsLoss':
        # per-class pos_weight = (#samples - #positives) / #positives, computed on the whole
        # sampled table (train + evaluation rows)
        df_rsna['no_Hemorrhage'] = 1 - df_rsna.Hemorrhage
        n_positive = df_rsna[train_RSNA_dataset.class_name].sum()
        class_weight_list = ((len(df_rsna) - n_positive) / n_positive).values
        # add weighting to the loss kwargs
        ssl_param.loss_fn_kwargs['pos_weight'] = torch.tensor(class_weight_list,
                                                              device=cfg.device)
    # convert loss_fn name to class object: torch.nn first, project losses as fallback
    try:
        ssl_param.loss_fn = getattr(torch.nn, ssl_param.loss_fn)
    except AttributeError:
        ssl_param.loss_fn = getattr(src.models.optim.LossFunctions, ssl_param.loss_fn)

    classifier = MultiClassifier(net_ssl,
                                 device=cfg.device,
                                 print_progress=cfg.print_progress,
                                 **ssl_param)

    train_params = [f"--> {k} : {v}" for k, v in ssl_param.items()]
    logger.info("Classifer Training Parameters \n\t" + "\n\t".join(train_params))

    # Load weights if specified
    if cfg.SSL.train.model_path_to_load:
        model_path = cfg.SSL.train.model_path_to_load
        classifier.load_model(model_path, map_location=cfg.device)
        logger.info(
            f"Classifer Model succesfully loaded from {cfg.SSL.train.model_path_to_load}"
        )

    # train if needed
    if ssl_param.n_epoch > 0:
        classifier.train(train_RSNA_dataset,
                         valid_dataset=test_RSNA_dataset,
                         checkpoint_path=ssl_checkpoint)

    # evaluate
    auc, acc, sub_acc, recall, precision, f1 = classifier.evaluate(test_RSNA_dataset,
                                                                   save_tsne=True,
                                                                   return_scores=True)
    logger.info(f"Classifier Test AUC : {auc:.2%}")
    logger.info(f"Classifier Test Accuracy : {acc:.2%}")
    logger.info(f"Classifier Test Subset Accuracy : {sub_acc:.2%}")
    logger.info(f"Classifier Test Recall : {recall:.2%}")
    logger.info(f"Classifier Test Precision : {precision:.2%}")
    logger.info(f"Classifier Test F1-score : {f1:.2%}")

    # save model, outputs
    classifier.save_model_state_dict(os.path.join(out_path_selfsup, 'pretrained_unet_enc.pt'))
    logger.info("Pre-trained U-Net encoder on binary classification saved at " +
                os.path.join(out_path_selfsup, 'pretrained_unet_enc.pt'))
    classifier.save_outputs(os.path.join(out_path_selfsup, 'outputs.json'))
    logger.info("Classifier outputs saved at " +
                os.path.join(out_path_selfsup, 'outputs.json'))
    test_df.reset_index(drop=True).to_csv(os.path.join(out_path_selfsup, 'eval_data_info.csv'))
    logger.info("Evaluation data info saved at " +
                os.path.join(out_path_selfsup, 'eval_data_info.csv'))

    # delete any checkpoints
    if os.path.exists(ssl_checkpoint):
        os.remove(ssl_checkpoint)
        logger.info('Checkpoint deleted.')

    # get weights state dictionary
    pretrained_unet_weights = classifier.get_state_dict()

    ###################################################################
    # Supervised fine-training of U-Net with K-Fold Cross-Validation  #
    ###################################################################
    # load annotated data csv
    data_info_df = pd.read_csv(os.path.join(cfg.path.data.Sup, 'ct_info.csv'), index_col=0)
    patient_df = pd.read_csv(os.path.join(cfg.path.data.Sup, 'patient_info.csv'), index_col=0)

    # Make K-Fold split at patient level. A random_state is only meaningful (and only
    # accepted by recent scikit-learn versions) when shuffling is enabled.
    skf = StratifiedKFold(n_splits=cfg.Sup.split.n_fold,
                          shuffle=cfg.Sup.split.shuffle,
                          random_state=sampling_seed if cfg.Sup.split.shuffle else None)

    # define scheduler and loss_fn as object
    sup_param = cfg.Sup.train.model_param
    # convert scheduler name to scheduler class object
    sup_param.lr_scheduler = getattr(torch.optim.lr_scheduler, sup_param.lr_scheduler)
    # convert loss_fn name to nn.Module class object
    sup_param.loss_fn = getattr(src.models.optim.LossFunctions, sup_param.loss_fn)

    # iterate over folds
    for k, (train_idx, test_idx) in enumerate(
            skf.split(patient_df.PatientNumber, patient_df.Hemorrhage)):
        fold_path = os.path.join(out_path_sup, f'Fold_{k+1}/')
        fold_outputs = os.path.join(fold_path, 'outputs.json')
        fold_checkpoint = os.path.join(fold_path, 'checkpoint.pt')

        # skip the fold if its results already exist
        if os.path.exists(fold_outputs):
            continue

        # initialize logger
        logger = initialize_logger(os.path.join(fold_path, 'log.txt'))
        if os.path.exists(fold_checkpoint):
            logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' + '#' * 30)
        logger.info(f"Experiment : {cfg.exp_name}")
        logger.info(f"Cross-Validation fold {k+1:02}/{cfg.Sup.split.n_fold:02}")

        # extract train/test slice dataframe. skf.split yields positional indices, hence
        # iloc (loc would only be correct if patient_df had a 0..n-1 index)
        train_patients = patient_df.PatientNumber.iloc[train_idx].values
        test_patients = patient_df.PatientNumber.iloc[test_idx].values
        train_df = data_info_df[data_info_df.PatientNumber.isin(train_patients)]
        test_df = data_info_df[data_info_df.PatientNumber.isin(test_patients)]

        # sample the train dataframe to adjust the negative/positive fraction: at most
        # frac_negative negative slices per positive slice are kept
        n_negative = len(train_df[train_df.Hemorrhage == 0])
        n_positive = len(train_df[train_df.Hemorrhage == 1])
        n_remove = int(max(0, n_negative - cfg.Sup.dataset.frac_negative * n_positive))
        df_remove = train_df[train_df.Hemorrhage == 0].sample(n=n_remove,
                                                              random_state=sampling_seed)
        train_df = train_df[~train_df.index.isin(df_remove.index)]
        logger.info('\n' + str(get_split_summary_table(data_info_df, train_df, test_df)))

        # Make datasets
        train_dataset = public_SegICH_Dataset2D(
            train_df, cfg.path.data.Sup,
            augmentation_transform=[
                getattr(tf, tf_name)(**tf_kwargs)
                for tf_name, tf_kwargs in cfg.Sup.dataset.augmentation.train.items()
            ],
            window=(cfg.data.win_center, cfg.data.win_width),
            output_size=cfg.data.size)
        test_dataset = public_SegICH_Dataset2D(
            test_df, cfg.path.data.Sup,
            augmentation_transform=[
                getattr(tf, tf_name)(**tf_kwargs)
                for tf_name, tf_kwargs in cfg.Sup.dataset.augmentation.eval.items()
            ],
            window=(cfg.data.win_center, cfg.data.win_width),
            output_size=cfg.data.size)
        logger.info(f"Data will be loaded from {cfg.path.data.Sup}.")
        logger.info(
            f"CT scans will be windowed on [{cfg.data.win_center-cfg.data.win_width/2} ; {cfg.data.win_center + cfg.data.win_width/2}]"
        )
        logger.info(f"CT scans will be resized to {cfg.data.size}x{cfg.data.size}")
        logger.info(
            f"Training online data transformation: \n\n {str(train_dataset.transform)}\n"
        )
        logger.info(
            f"Evaluation online data transformation: \n\n {str(test_dataset.transform)}\n"
        )

        # Make U-Net architecture
        unet_sup = UNet(**cfg.Sup.net).to(cfg.device)
        net_params = [f"--> {name} : {value}" for name, value in cfg.Sup.net.items()]
        logger.info("UNet-2D params \n\t" + "\n\t".join(net_params))
        logger.info(
            f"The U-Net2D has {sum(p.numel() for p in unet_sup.parameters())} parameters."
        )

        # Make Model
        unet2D = UNet2D(unet_sup,
                        device=cfg.device,
                        print_progress=cfg.print_progress,
                        **sup_param)
        train_params = [f"--> {name} : {value}" for name, value in sup_param.items()]
        logger.info("UNet-2D Training Parameters \n\t" + "\n\t".join(train_params))

        # transfer the weights obtained in the pre-training stage above
        # NOTE(review): the log text mentions context_restoration although the weights come
        # from the classification pre-training -- confirm and update the message if needed
        logger.info(
            'Initialize U-Net2D with weights learned with context_restoration on RSNA.'
        )
        unet2D.transfer_weights(pretrained_unet_weights, verbose=True)

        # Train U-net
        eval_dataset = test_dataset if cfg.Sup.train.validate_epoch else None
        unet2D.train(train_dataset,
                     valid_dataset=eval_dataset,
                     checkpoint_path=fold_checkpoint)

        # Evaluate U-Net
        unet2D.evaluate(test_dataset, save_path=os.path.join(fold_path, 'pred/'))

        # Save models and outputs
        unet2D.save_model(os.path.join(fold_path, 'trained_unet.pt'))
        logger.info("Trained U-Net saved at " + os.path.join(fold_path, 'trained_unet.pt'))
        unet2D.save_outputs(fold_outputs)
        logger.info("Trained statistics saved at " + fold_outputs)

        # delete checkpoint if exists
        if os.path.exists(fold_checkpoint):
            os.remove(fold_checkpoint)
            logger.info('Checkpoint deleted.')

    # save mean +/- 1.96 std Dice over Folds
    save_mean_fold_dice(out_path_sup, cfg.Sup.split.n_fold)
    logger.info('Average Scores saved at ' + os.path.join(out_path_sup, 'average_scores.txt'))

    # Save all volumes prediction csv
    df_list = [
        pd.read_csv(os.path.join(out_path_sup, f'Fold_{i+1}/pred/volume_prediction_scores.csv'))
        for i in range(cfg.Sup.split.n_fold)
    ]
    all_df = pd.concat(df_list, axis=0).reset_index(drop=True)
    all_df.to_csv(os.path.join(out_path_sup, 'all_volume_prediction.csv'))
    logger.info('CSV of all volumes prediction saved at ' +
                os.path.join(out_path_sup, 'all_volume_prediction.csv'))

    # Save config file : every non JSON-serializable entry is converted first
    cfg.device = str(cfg.device)
    cfg.SSL.train.model_param.lr_scheduler = str(cfg.SSL.train.model_param.lr_scheduler)
    cfg.Sup.train.model_param.lr_scheduler = str(cfg.Sup.train.model_param.lr_scheduler)
    cfg.SSL.train.model_param.loss_fn = str(cfg.SSL.train.model_param.loss_fn)
    cfg.Sup.train.model_param.loss_fn = str(cfg.Sup.train.model_param.loss_fn)
    # pos_weight is only a tensor when it was added above for BCEWithLogitsLoss; with any
    # other loss the key may be missing, which must not crash the run after training
    loss_kwargs = cfg.SSL.train.model_param.loss_fn_kwargs
    if 'pos_weight' in loss_kwargs and isinstance(loss_kwargs['pos_weight'], torch.Tensor):
        loss_kwargs['pos_weight'] = loss_kwargs['pos_weight'].cpu().data.tolist()
    with open(os.path.join(out_path, 'config.json'), 'w') as fp:
        json.dump(cfg, fp)
    logger.info('Config file saved at ' + os.path.join(out_path, 'config.json'))

    # Analyse results
    analyse_supervised_exp(out_path_sup,
                           cfg.path.data.Sup,
                           n_fold=cfg.Sup.split.n_fold,
                           config_folder=out_path,
                           save_fn=os.path.join(out_path, 'results_supervised_overview.pdf'))
    logger.info('Results overview figure saved at ' +
                os.path.join(out_path, 'results_supervised_overview.pdf'))
    analyse_representation_exp(out_path_selfsup,
                               save_fn=os.path.join(out_path,
                                                    'results_self-supervised_overview.pdf'))
    logger.info('Results overview figure saved at ' +
                os.path.join(out_path, 'results_self-supervised_overview.pdf'))
def main(config_path):
    """
    Train and evaluate a 2D UNet on the public ICH dataset with the anomaly attention map
    using the parameters on the JSON at the config_path. The evaluation is performed by
    k-fold cross-validation.

    Folds are built at the patient level with a stratified K-fold. A fold whose
    'outputs.json' already exists is skipped, so an interrupted experiment can be resumed.
    Once all folds are available, the mean Dice +/- 1.96 std over folds, the concatenated
    volume predictions, the config and a result overview figure are saved in the output
    folder.

    Args:
        config_path: path to the JSON configuration file of the experiment.
    """
    # load config file
    cfg = AttrDict.from_json_path(config_path)

    # Make Output directories
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)
    for k in range(cfg.split.n_fold):
        os.makedirs(os.path.join(out_path, f'Fold_{k+1}/pred/'), exist_ok=True)

    # Initialize random seed to given seed (-1 means: do not seed)
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True
    # -1 is not a valid random_state for pandas / scikit-learn: use None (unseeded) instead
    sampling_seed = cfg.seed if cfg.seed != -1 else None

    # The logger and the device are set up once, before the fold loop: the code after the
    # loop needs both even when every fold is skipped because its results already exist.
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    if cfg.device is not None:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Load data csv
    data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'info.csv'), index_col=0)
    patient_df = pd.read_csv(os.path.join(cfg.path.data, 'patient_info.csv'), index_col=0)

    # Generate Cross-Val indices at the patient level. A random_state is only meaningful
    # (and only accepted by recent scikit-learn versions) when shuffling is enabled.
    skf = StratifiedKFold(n_splits=cfg.split.n_fold,
                          shuffle=cfg.split.shuffle,
                          random_state=sampling_seed if cfg.split.shuffle else None)

    # iterate over folds; the stratification is done on the patient Hemorrhage label
    for k, (train_idx, test_idx) in enumerate(
            skf.split(patient_df.PatientNumber, patient_df.Hemorrhage)):
        fold_path = os.path.join(out_path, f'Fold_{k+1}/')
        fold_outputs = os.path.join(fold_path, 'outputs.json')
        fold_checkpoint = os.path.join(fold_path, 'checkpoint.pt')

        # skip the fold if its results are already there
        if os.path.exists(fold_outputs):
            continue

        if os.path.exists(fold_checkpoint):
            logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' + '#' * 30)
        logger.info(f"Experiment : {cfg.exp_name}")
        logger.info(f"Cross-Validation fold {k+1:02}/{cfg.split.n_fold:02}")
        logger.info(f"Device : {cfg.device}")

        # extract train and test DataFrames + print summary. skf.split yields positional
        # indices, hence iloc (loc would only be correct if patient_df had a 0..n-1 index)
        train_patients = patient_df.PatientNumber.iloc[train_idx].values
        test_patients = patient_df.PatientNumber.iloc[test_idx].values
        train_df = data_info_df[data_info_df.id.isin(train_patients)]
        test_df = data_info_df[data_info_df.id.isin(test_patients)]

        # sample the train dataframe: at most frac_negative negative slices per positive
        # slice are kept
        n_negative = len(train_df[train_df.Hemorrhage == 0])
        n_positive = len(train_df[train_df.Hemorrhage == 1])
        n_remove = int(max(0, n_negative - cfg.dataset.frac_negative * n_positive))
        df_remove = train_df[train_df.Hemorrhage == 0].sample(n=n_remove,
                                                              random_state=sampling_seed)
        train_df = train_df[~train_df.index.isin(df_remove.index)]
        logger.info('\n' + str(get_split_summary_table(data_info_df, train_df, test_df)))

        # Make Dataset + print online augmentation summary
        train_dataset = public_SegICH_AttentionDataset2D(
            train_df, cfg.path.data,
            augmentation_transform=[
                getattr(tf, tf_name)(**tf_kwargs)
                for tf_name, tf_kwargs in cfg.data.augmentation.train.items()
            ],
            window=(cfg.data.win_center, cfg.data.win_width),
            output_size=cfg.data.size)
        test_dataset = public_SegICH_AttentionDataset2D(
            test_df, cfg.path.data,
            augmentation_transform=[
                getattr(tf, tf_name)(**tf_kwargs)
                for tf_name, tf_kwargs in cfg.data.augmentation.eval.items()
            ],
            window=(cfg.data.win_center, cfg.data.win_width),
            output_size=cfg.data.size)
        logger.info(f"Data will be loaded from {cfg.path.data}.")
        logger.info(
            f"CT scans will be windowed on [{cfg.data.win_center-cfg.data.win_width/2} ; {cfg.data.win_center + cfg.data.win_width/2}]"
        )
        logger.info(
            f"Training online data transformation: \n\n {str(train_dataset.transform)}\n"
        )
        logger.info(
            f"Evaluation online data transformation: \n\n {str(test_dataset.transform)}\n"
        )

        # Make architecture
        unet_arch = UNet(**cfg.net)
        unet_arch.to(cfg.device)
        net_params = [f"--> {name} : {value}" for name, value in cfg.net.items()]
        logger.info("U-Net2D architecture \n\t" + "\n\t".join(net_params))
        logger.info(
            f"The U-Net2D has {sum(p.numel() for p in unet_arch.parameters())} parameters."
        )

        # Make model. The training params are copied so that the class objects put in the
        # copy do not end up in cfg, which is dumped as JSON at the end.
        cfg_train = AttrDict(cfg.train.params)
        cfg_train.lr_scheduler = getattr(torch.optim.lr_scheduler, cfg_train.lr_scheduler)
        cfg_train.loss_fn = getattr(src.models.optim.LossFunctions, cfg_train.loss_fn)
        unet2D = UNet2D(unet_arch,
                        device=cfg.device,
                        print_progress=cfg.print_progress,
                        **cfg_train)
        # print Training hyper-parameters
        train_params = [f"--> {name} : {value}" for name, value in cfg_train.items()]
        logger.info("U-Net2D Training Parameters \n\t" + "\n\t".join(train_params))

        # Load model if required: either one path for all folds or one path per fold
        if cfg.train.model_path_to_load:
            if isinstance(cfg.train.model_path_to_load, str):
                model_path = cfg.train.model_path_to_load
            elif isinstance(cfg.train.model_path_to_load, list):
                model_path = cfg.train.model_path_to_load[k]
            else:
                raise ValueError(f'Model path to load type not understood.')
            unet2D.load_model(model_path, map_location=cfg.device)
            logger.info(f"2D U-Net model loaded from {model_path}")

        # Train model
        eval_dataset = test_dataset if cfg.train.validate_epoch else None
        unet2D.train(train_dataset,
                     valid_dataset=eval_dataset,
                     checkpoint_path=fold_checkpoint)

        # Evaluate model
        unet2D.evaluate(test_dataset, save_path=os.path.join(fold_path, 'pred/'))

        # Save models & outputs
        unet2D.save_model(os.path.join(fold_path, 'trained_unet.pt'))
        logger.info("Trained U-Net saved at " + os.path.join(fold_path, 'trained_unet.pt'))
        unet2D.save_outputs(fold_outputs)
        logger.info("Trained statistics saved at " + fold_outputs)

        # delete checkpoint if exists
        if os.path.exists(fold_checkpoint):
            os.remove(fold_checkpoint)
            logger.info('Checkpoint deleted.')

    # save mean +/- 1.96 std Dice in .txt file
    scores_list = []
    for fold in range(cfg.split.n_fold):
        with open(os.path.join(out_path, f'Fold_{fold+1}/outputs.json'), 'r') as f:
            out = json.load(f)
        scores_list.append([out['eval']['dice']['all'], out['eval']['dice']['positive']])
    scores = np.array(scores_list)
    means = scores.mean(axis=0)
    CI95 = 1.96 * scores.std(axis=0)
    with open(os.path.join(out_path, 'average_scores.txt'), 'w') as f:
        f.write(f'Dice = {means[0]} +/- {CI95[0]}\n')
        f.write(f'Dice (Positive) = {means[1]} +/- {CI95[1]}\n')
    logger.info('Average Scores saved at ' + os.path.join(out_path, 'average_scores.txt'))

    # generate dataframe of all prediction
    df_list = [
        pd.read_csv(os.path.join(out_path, f'Fold_{i+1}/pred/volume_prediction_scores.csv'))
        for i in range(cfg.split.n_fold)
    ]
    all_df = pd.concat(df_list, axis=0).reset_index(drop=True)
    all_df.to_csv(os.path.join(out_path, 'all_volume_prediction.csv'))
    logger.info('CSV of all volumes prediction saved at ' +
                os.path.join(out_path, 'all_volume_prediction.csv'))

    # Save config file
    cfg.device = str(cfg.device)  # set device as string to be JSON serializable
    with open(os.path.join(out_path, 'config.json'), 'w') as fp:
        json.dump(cfg, fp)
    logger.info("Config file saved at " + os.path.join(out_path, 'config.json'))

    # Analyse results
    analyse_supervised_exp(out_path,
                           cfg.path.data,
                           cfg.split.n_fold,
                           save_fn=os.path.join(out_path, 'results_overview.pdf'))
    logger.info('Results overview figure saved at ' +
                os.path.join(out_path, 'results_overview.pdf'))
def main(config_path):
    """
    Segment the public ICH dataset with a pretrained FCDD and score the masks.

    For every slice the FCDD heat-map is min-max scaled with dataset-wide
    bounds, thresholded at ``cfg.heatmap_threshold`` into a binary anomaly
    mask and compared with the ground-truth mask. Heat-maps and masks are
    written under ``<out_path>/pred/``; slice-wise and volume-wise confusion
    counts and Dice are saved as CSV, and the config is dumped as JSON. When
    ``cfg.classifier_model_path`` is set, the thresholded prediction of a
    ResNet classifier is stored for each slice as well (``None`` otherwise).

    Args:
        config_path (str): path to the JSON configuration file.
    """
    # load config
    cfg = AttrDict.from_json_path(config_path)

    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)

    # initialize seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    logger.info(f"Experiment : {cfg.exp_name}")

    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(f"Device set to {cfg.device}.")

    #-------------------------------------------
    #       Make Dataset
    #-------------------------------------------
    data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'ct_info.csv'), index_col=0)
    dataset = public_SegICH_Dataset2D(
        data_info_df, cfg.path.data,
        augmentation_transform=[getattr(tf, tf_name)(**tf_kwargs)
                                for tf_name, tf_kwargs in cfg.data.augmentation.items()],
        output_size=cfg.data.size,
        window=(cfg.data.win_center, cfg.data.win_width))

    #-------------------------------------------
    #       Load FCDD Model
    #-------------------------------------------
    cfg_fcdd = AttrDict.from_json_path(cfg.fcdd_cfg_path)
    fcdd_net = FCDD_CNN_VGG(in_shape=(cfg_fcdd.net.in_channels, 256, 256), bias=cfg_fcdd.net.bias)
    loaded_state_dict = torch.load(cfg.fcdd_model_path, map_location=cfg.device)
    fcdd_net.load_state_dict(loaded_state_dict)
    fcdd_net = fcdd_net.to(cfg.device).eval()
    logger.info(f"FCDD model succesfully loaded from {cfg.fcdd_model_path}")

    # make FCDD object
    fcdd = FCDD(fcdd_net, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
                device=cfg.device, print_progress=cfg.print_progress)

    #-------------------------------------------
    #       Load Classifier Model
    #-------------------------------------------
    use_classifier = cfg.classifier_model_path is not None
    if use_classifier:
        classifier_weights_fn = os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt')
        cfg_classifier = AttrDict.from_json_path(os.path.join(cfg.classifier_model_path, 'config.json'))
        classifier = getattr(rn, cfg_classifier.net.resnet)(num_classes=cfg_classifier.net.num_classes,
                                                            input_channels=cfg_classifier.net.input_channels)
        classifier_state_dict = torch.load(classifier_weights_fn, map_location=cfg.device)
        classifier.load_state_dict(classifier_state_dict)
        classifier = classifier.to(cfg.device)
        classifier.eval()
        logger.info(f"ResNet classifier model succesfully loaded from {classifier_weights_fn}")

    #-------------------------------------------
    #       Generate Heat-Map for each slice
    #-------------------------------------------
    pred_dir = os.path.join(out_path, 'pred/')
    with torch.no_grad():
        # make loader
        loader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
                                             shuffle=False, worker_init_fn=lambda _: np.random.seed())
        fcdd_net.eval()
        # dataset-wide bounds used to rescale every heat-map to [0, 1]
        min_val, max_val = fcdd.get_min_max(loader, **cfg.heatmap_param)

        # computing and saving heatmaps
        out = dict(id=[], slice=[], label=[], ad_map_fn=[], ad_mask_fn=[],
                   TP=[], TN=[], FP=[], FN=[], AUC=[], classifier_pred=[])
        for b, data in enumerate(loader):
            # vol_id / slice_idx rather than id / slice to avoid shadowing builtins
            im, mask, vol_id, slice_idx = data
            im = im.to(cfg.device).float()
            mask = mask.to(cfg.device).float()
            n_im = im.shape[0]

            # get heatmap
            heatmap = fcdd.generate_heatmap(im, reception=cfg.heatmap_param.reception,
                                            std=cfg.heatmap_param.std, cpu=cfg.heatmap_param.cpu)
            # scaling
            heatmap = ((heatmap - min_val) / (max_val - min_val)).clamp(0, 1)

            # Threshold
            ad_mask = torch.where(heatmap >= cfg.heatmap_threshold,
                                  torch.ones_like(heatmap, device=heatmap.device),
                                  torch.zeros_like(heatmap, device=heatmap.device))

            # Compute CM
            tn, fp, fn, tp = batch_binary_confusion_matrix(ad_mask, mask.to(heatmap.device))

            # Save heatmaps/mask
            map_fn, mask_fn = [], []
            for i in range(n_im):
                # Save AD Map
                ad_map_fn = f"{vol_id[i]}/{slice_idx[i]}_map_anomalies.png"
                save_path = os.path.join(pred_dir, ad_map_fn)
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                ad_map = heatmap[i].squeeze().cpu().numpy()
                io.imsave(save_path, img_as_ubyte(ad_map), check_contrast=False)
                # save ad_mask (same per-volume directory as the map)
                ad_mask_fn = f"{vol_id[i]}/{slice_idx[i]}_anomalies.bmp"
                save_path = os.path.join(pred_dir, ad_mask_fn)
                io.imsave(save_path, img_as_ubyte(ad_mask[i].squeeze().cpu().numpy()), check_contrast=False)
                map_fn.append(ad_map_fn)
                mask_fn.append(ad_mask_fn)

            # apply classifier ResNet-18
            if use_classifier:
                # take columns of softmax of positive class as score
                pred_score = nn.functional.softmax(classifier(im), dim=1)[:, 1]
                clss_pred = torch.where(pred_score >= cfg.classification_threshold,
                                        torch.ones_like(pred_score, device=pred_score.device),
                                        torch.zeros_like(pred_score, device=pred_score.device)).cpu().tolist()
            else:
                # Plain Python list: the tensor -> list conversion is done in the
                # branch above, because a list has no .cpu() method.
                clss_pred = [None] * n_im

            # Save Values
            out['id'] += vol_id.cpu().tolist()
            out['slice'] += slice_idx.cpu().tolist()
            out['label'] += mask.reshape(mask.shape[0], -1).max(dim=1)[0].cpu().tolist()
            out['ad_map_fn'] += map_fn
            out['ad_mask_fn'] += mask_fn
            out['TN'] += tn.cpu().tolist()
            out['FP'] += fp.cpu().tolist()
            out['FN'] += fn.cpu().tolist()
            out['TP'] += tp.cpu().tolist()
            # AUC is only defined when the slice holds at least one positive pixel
            out['AUC'] += [roc_auc_score(mask[i].cpu().numpy().ravel(), heatmap[i].cpu().numpy().ravel())
                           if torch.any(mask[i] > 0) else 'None'
                           for i in range(n_im)]
            out['classifier_pred'] += clss_pred

            if cfg.print_progress:
                print_progessbar(b, len(loader), Name='Heatmap Generation Batch', Size=100, erase=True)

    # make df and save as csv
    slice_df = pd.DataFrame(out)
    volume_df = slice_df[['id', 'label', 'TP', 'TN', 'FP', 'FN']].groupby('id').agg(
        {'label': 'max', 'TP': 'sum', 'TN': 'sum', 'FP': 'sum', 'FN': 'sum'})

    # Dice with +1 smoothing on numerator and denominator
    slice_df['Dice'] = (2*slice_df.TP + 1) / (2*slice_df.TP + slice_df.FP + slice_df.FN + 1)
    volume_df['Dice'] = (2*volume_df.TP + 1) / (2*volume_df.TP + volume_df.FP + volume_df.FN + 1)
    logger.info(f"Mean slice dice : {slice_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean volume dice : {volume_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean posiitve slice AUC {slice_df[slice_df.label == 1].AUC.mean(axis=0):.3f}")

    # Save Scores and Config
    slice_df.to_csv(os.path.join(out_path, 'slice_predictions.csv'))
    logger.info(f"Slice prediction csv saved at {os.path.join(out_path, 'slice_predictions.csv')}")
    volume_df.to_csv(os.path.join(out_path, 'volume_predictions.csv'))
    logger.info(f"Volume prediction csv saved at {os.path.join(out_path, 'volume_predictions.csv')}")
    # torch.device is not JSON serializable
    cfg.device = str(cfg.device)
    with open(os.path.join(out_path, 'config.json'), 'w') as f:
        json.dump(cfg, f)
    logger.info(f"Config file saved at {os.path.join(out_path, 'config.json')}")
def main(config_path):
    """
    ResNet binary classification with the RSNA dataset.

    Optionally sub-samples the slices of each class, splits them into a
    stratified train/evaluation pair, trains a ResNet through
    ``BinaryClassifier`` with a weighted loss, evaluates it, and saves the
    model, its state dict, the outputs, the evaluation-set info and the config
    in the experiment output folder.

    Args:
        config_path (str): path to the JSON configuration file.
    """
    # load the config file
    cfg = AttrDict.from_json_path(config_path)

    # Make Outputs directories
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)
    checkpoint_path = os.path.join(out_path, 'checkpoint.pt')

    # Initialize random seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True
    # -1 means "unseeded" above; pandas/scikit-learn reject a negative
    # random_state, so pass None in that case.
    random_state = cfg.seed if cfg.seed != -1 else None

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    if os.path.exists(checkpoint_path):
        logger.info('\n' + '#'*30 + '\n Recovering Session \n' + '#'*30)
    logger.info(f"Experiment : {cfg.exp_name}")

    # Set number of thread
    if cfg.n_thread > 0:
        torch.set_num_threads(cfg.n_thread)

    # set device, if None use the first one
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(f"Device set to {cfg.device}. {torch.cuda.device_count()} GPU available, "
                f"{len(cfg.multi_gpu_id) if cfg.multi_gpu_id else 1} used.")

    # Load RSNA data csv
    df_rsna = pd.read_csv(os.path.join(cfg.path.data, 'slice_info.csv'), index_col=0)

    # Keep only fractions sample (a negative count keeps the whole class)
    df_rsna_noICH = df_rsna[df_rsna.Hemorrhage == 0]
    if cfg.dataset.n_data_0 >= 0:
        df_rsna_noICH = df_rsna_noICH.sample(n=cfg.dataset.n_data_0, random_state=random_state)
    df_rsna_ICH = df_rsna[df_rsna.Hemorrhage == 1]
    if cfg.dataset.n_data_1 >= 0:
        df_rsna_ICH = df_rsna_ICH.sample(n=cfg.dataset.n_data_1, random_state=random_state)
    df_rsna = pd.concat([df_rsna_ICH, df_rsna_noICH], axis=0)

    # Split data to keep few for evaluation in a stratified way
    train_df, test_df = train_test_split(df_rsna, test_size=cfg.dataset.frac_eval,
                                         stratify=df_rsna.Hemorrhage, random_state=random_state)
    logger.info('\n' + str(get_split_summary_table(df_rsna, train_df, test_df)))

    # Make dataset : Train --> BinaryClassification, Test --> BinaryClassification
    # NOTE(review): augmentation is read from cfg.dataset but window/size from
    # cfg.data — confirm both sections exist in the config schema.
    train_RSNA_dataset = RSNA_dataset(
        train_df, cfg.path.data,
        augmentation_transform=[getattr(tf, tf_name)(**tf_kwargs)
                                for tf_name, tf_kwargs in cfg.dataset.augmentation.train.items()],
        window=(cfg.data.win_center, cfg.data.win_width),
        output_size=cfg.data.size, mode='binary_classification')
    test_RSNA_dataset = RSNA_dataset(
        test_df, cfg.path.data,
        augmentation_transform=[getattr(tf, tf_name)(**tf_kwargs)
                                for tf_name, tf_kwargs in cfg.dataset.augmentation.eval.items()],
        window=(cfg.data.win_center, cfg.data.win_width),
        output_size=cfg.data.size, mode='binary_classification')
    logger.info(f"Data will be loaded from {cfg.path.data}.")
    logger.info(f"CT scans will be windowed on [{cfg.data.win_center-cfg.data.win_width/2} ; {cfg.data.win_center + cfg.data.win_width/2}]")
    logger.info(f"CT scans will be resized to {cfg.data.size}x{cfg.data.size}")
    logger.info(f"Training online data transformation: \n\n {str(train_RSNA_dataset.transform)}\n")
    logger.info(f"Evaluation online data transformation: \n\n {str(test_RSNA_dataset.transform)}\n")

    # Make Resnet Architecture
    resnet_network = getattr(resnet, cfg.net.resnet)(num_classes=cfg.net.num_classes,
                                                     input_channels=cfg.net.input_channels)
    logger.info(f"Using a {cfg.net.resnet} architecture.")
    if cfg.multi_gpu_id is not None and len(cfg.multi_gpu_id) > 1:
        # set network for multi-GPU
        resnet_network = torch.nn.DataParallel(resnet_network, device_ids=cfg.multi_gpu_id)
        logger.info("Enabling the resnet for multi-GPU computation.")
    resnet_network = resnet_network.to(cfg.device)
    logger.info(f"The {cfg.net.resnet} has {sum(p.numel() for p in resnet_network.parameters())} parameters.")

    # Make model
    model_param = cfg.train.model_param
    # convert scheduler name to scheduler class object
    model_param.lr_scheduler = getattr(torch.optim.lr_scheduler, model_param.lr_scheduler)
    # convert loss_fn name to nn.Module class object
    model_param.loss_fn = getattr(torch.nn, model_param.loss_fn)
    # define CE weighting from train dataset: w_ICH is the positive fraction.
    # NOTE(review): [1 - w_ICH, w_ICH] gives the more frequent class the larger
    # weight — confirm this is the intended direction.
    w_ICH = train_df.Hemorrhage.sum() / len(train_df)
    model_param.loss_fn_kwargs['weight'] = torch.tensor([1 - w_ICH, w_ICH], device=cfg.device).float()
    classifier = BinaryClassifier(resnet_network, device=cfg.device,
                                  print_progress=cfg.print_progress, **model_param)

    train_params = [f"--> {k} : {v}" for k, v in model_param.items()]
    logger.info("Classifer Training Parameters \n\t" + "\n\t".join(train_params))

    # Load weights if specified
    if cfg.train.model_path_to_load:
        model_path = cfg.train.model_path_to_load
        classifier.load_model(model_path, map_location=cfg.device)
        logger.info(f"Classifer Model succesfully loaded from {cfg.train.model_path_to_load}")

    # Train
    if model_param.n_epoch > 0:
        classifier.train(train_RSNA_dataset, valid_dataset=test_RSNA_dataset,
                         checkpoint_path=checkpoint_path)

    # Evaluate
    auc, acc, recall, precision, f1 = classifier.evaluate(test_RSNA_dataset, save_tsne=False,
                                                          return_scores=True)
    logger.info(f"Classifier Test AUC : {auc:.2%}")
    logger.info(f"Classifier Test Accuracy : {acc:.2%}")
    logger.info(f"Classifier Test Recall : {recall:.2%}")
    logger.info(f"Classifier Test Precision : {precision:.2%}")
    logger.info(f"Classifier Test F1-score : {f1:.2%}")

    # save model, outputs
    classifier.save_model(os.path.join(out_path, 'resnet.pt'))
    logger.info(f"{cfg.net.resnet} saved at " + os.path.join(out_path, 'resnet.pt'))
    classifier.save_model_state_dict(os.path.join(out_path, 'resnet_state_dict.pt'))
    logger.info(f"{cfg.net.resnet} saved at " + os.path.join(out_path, 'resnet_state_dict.pt'))
    classifier.save_outputs(os.path.join(out_path, 'outputs.json'))
    logger.info("Classifier outputs saved at " + os.path.join(out_path, 'outputs.json'))
    test_df.reset_index(drop=True).to_csv(os.path.join(out_path, 'eval_data_info.csv'))
    logger.info("Evaluation data info saved at " + os.path.join(out_path, 'eval_data_info.csv'))

    # delete any checkpoints
    if os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)
        logger.info('Checkpoint deleted.')

    # make every non-JSON value serializable before dumping the config
    cfg.device = str(cfg.device)
    model_param.lr_scheduler = str(model_param.lr_scheduler)
    model_param.loss_fn = str(model_param.loss_fn)
    # list(tensor) would yield 0-d tensors, which json cannot encode;
    # tolist() returns plain Python floats.
    model_param.loss_fn_kwargs['weight'] = model_param.loss_fn_kwargs['weight'].cpu().tolist()
    with open(os.path.join(out_path, 'config.json'), 'w') as f:
        json.dump(cfg, f)
    logger.info(f"Config file saved at {os.path.join(out_path, 'config.json')}")
def main(config_path):
    """
    Train an Inpainting generator with gated convolution through a SN-PatchGAN
    training scheme. The generator is trained to inpaint non-ICH CT scans from
    the RSNA dataset.

    The generator type is selected by the keys present in ``cfg.net.gen``
    (``context_attention`` or ``self_attention``). Models, outputs and the
    config are saved in the experiment output folder, and the training
    checkpoint is removed once everything is saved.

    Args:
        config_path (str): path to the JSON configuration file.

    Raises:
        ValueError: if ``cfg.net.gen`` contains neither attention key.
    """
    # Load config file
    cfg = AttrDict.from_json_path(config_path)

    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)
    if cfg.train.validate_epoch:
        os.makedirs(os.path.join(out_path, 'valid_results/'), exist_ok=True)
    # Single definition of the checkpoint file, so the recovery check, the
    # training call and the final cleanup all refer to the same name.
    checkpoint_path = os.path.join(out_path, 'Checkpoint.pt')

    # initialize seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True
    # -1 means "unseeded" above; pandas rejects a negative random_state.
    random_state = cfg.seed if cfg.seed != -1 else None

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    if os.path.exists(checkpoint_path):
        logger.info('\n' + '#' * 30 + '\n Recovering Session \n' + '#' * 30)
    logger.info(f"Experiment : {cfg.exp_name}")

    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(f"Device set to {cfg.device}. {torch.cuda.device_count()} GPU available, "
                f"{len(cfg.multi_gpu_id) if cfg.multi_gpu_id else 1} used.")

    # set n_thread
    if cfg.n_thread > 0:
        torch.set_num_threads(cfg.n_thread)

    #--------------------------------------------------------------------
    #                           Make Datasets
    #--------------------------------------------------------------------
    # load RSNA data & keep normal only & and sample the required number
    df_rsna = pd.read_csv(os.path.join(cfg.path.data, 'slice_info.csv'), index_col=0)
    df_rsna_pos = df_rsna[df_rsna.Hemorrhage == 0]
    if cfg.dataset.n_sample >= 0:
        df_rsna_pos = df_rsna_pos.sample(n=cfg.dataset.n_sample, random_state=random_state)

    # make dataset
    train_dataset = RSNA_Inpaint_dataset(
        df_rsna_pos, cfg.path.data,
        augmentation_transform=[getattr(tf, tf_name)(**tf_kwargs)
                                for tf_name, tf_kwargs in cfg.dataset.augmentation.train.items()],
        window=(cfg.dataset.win_center, cfg.dataset.win_width),
        output_size=cfg.dataset.size,
        **cfg.dataset.mask)

    # load small valid subset and make dataset
    if cfg.train.validate_epoch:
        df_valid = pd.read_csv(os.path.join(cfg.path.data_valid, 'info.csv'), index_col=0)
        valid_dataset = ImgMaskDataset(
            df_valid, cfg.path.data_valid,
            augmentation_transform=[getattr(tf, tf_name)(**tf_kwargs)
                                    for tf_name, tf_kwargs in cfg.dataset.augmentation.eval.items()],
            window=(cfg.dataset.win_center, cfg.dataset.win_width),
            output_size=cfg.dataset.size)
    else:
        valid_dataset = None

    logger.info(f"Train Data will be loaded from {cfg.path.data}.")
    logger.info(f"Train contains {len(train_dataset)} samples.")
    logger.info(f"Valid Data will be loaded from {cfg.path.data_valid}.")
    if valid_dataset:
        logger.info(f"Valid contains {len(valid_dataset)} samples.")
    logger.info(f"CT scans will be windowed on [{cfg.dataset.win_center-cfg.dataset.win_width/2} ; {cfg.dataset.win_center + cfg.dataset.win_width/2}]")
    logger.info(f"CT scans will be resized to {cfg.dataset.size}x{cfg.dataset.size}")
    logger.info(f"Training online data transformation: \n\n {str(train_dataset.transform)}\n")
    if valid_dataset:
        logger.info(f"Evaluation online data transformation: \n\n {str(valid_dataset.transform)}\n")
    mask_params = [f"--> {k} : {v}" for k, v in cfg.dataset.mask.items()]
    logger.info("Train inpainting masks generated with \n\t" + "\n\t".join(mask_params))

    #--------------------------------------------------------------------
    #                           Make Networks
    #--------------------------------------------------------------------
    if 'context_attention' in cfg.net.gen:
        # add device to kwargs of contextual attention module
        cfg.net.gen.context_attention_kwargs['device'] = cfg.device
        generator_net = GatedGenerator(**cfg.net.gen)
    elif 'self_attention' in cfg.net.gen:
        generator_net = SAGatedGenerator(**cfg.net.gen)
    else:
        # fail here with a clear message rather than with a NameError below
        raise ValueError("cfg.net.gen must contain either a 'context_attention' "
                         "or a 'self_attention' entry to select the generator.")
    discriminator_net = PatchDiscriminator(**cfg.net.dis)

    if cfg.multi_gpu_id is not None and len(cfg.multi_gpu_id) > 1:
        # set network for multi-GPU
        generator_net = torch.nn.DataParallel(generator_net, device_ids=cfg.multi_gpu_id)
        discriminator_net = torch.nn.DataParallel(discriminator_net, device_ids=cfg.multi_gpu_id)
        logger.info("Enabling multi-GPU computation.")

    gen_params = [f"--> {k} : {v}" for k, v in cfg.net.gen.items()]
    logger.info("Gated Generator Parameters \n\t" + "\n\t".join(gen_params))
    dis_params = [f"--> {k} : {v}" for k, v in cfg.net.dis.items()]
    logger.info("Gated Discriminator Parameters \n\t" + "\n\t".join(dis_params))

    #--------------------------------------------------------------------
    #                    Make Inpainting GAN model
    #--------------------------------------------------------------------
    # convert scheduler name to scheduler class object
    cfg.train.model_param.lr_scheduler = getattr(torch.optim.lr_scheduler,
                                                 cfg.train.model_param.lr_scheduler)
    gan_model = SNPatchGAN(generator_net, discriminator_net, print_progress=cfg.print_progress,
                           device=cfg.device, **cfg.train.model_param)
    train_params = [f"--> {k} : {v}" for k, v in cfg.train.model_param.items()]
    logger.info("GAN Training Parameters \n\t" + "\n\t".join(train_params))

    # load models if provided
    if cfg.train.model_path_to_load.gen:
        gan_model.load_Generator(cfg.train.model_path_to_load.gen, map_location=cfg.device)
    if cfg.train.model_path_to_load.dis:
        gan_model.load_Discriminator(cfg.train.model_path_to_load.dis, map_location=cfg.device)

    #--------------------------------------------------------------------
    #                       Train SN-PatchGAN model
    #--------------------------------------------------------------------
    if cfg.train.model_param.n_epoch > 0:
        gan_model.train(train_dataset, checkpoint_path=checkpoint_path,
                        valid_dataset=valid_dataset,
                        valid_path=os.path.join(out_path, 'valid_results/'),
                        save_freq=cfg.train.valid_save_freq)

    #--------------------------------------------------------------------
    #                   Save outputs, models and config
    #--------------------------------------------------------------------
    # save models
    gan_model.save_models(export_fn=(os.path.join(out_path, 'generator.pt'),
                                     os.path.join(out_path, 'discriminator.pt')),
                          which='both')
    logger.info("Generator model saved at " + os.path.join(out_path, 'generator.pt'))
    logger.info("Discriminator model saved at " + os.path.join(out_path, 'discriminator.pt'))
    # save outputs
    gan_model.save_outputs(export_fn=os.path.join(out_path, 'outputs.json'))
    logger.info("Outputs file saved at " + os.path.join(out_path, 'outputs.json'))
    # save config file: devices and the scheduler class are turned into
    # strings to be JSON serializable
    cfg.device = str(cfg.device)
    if 'context_attention' in cfg.net.gen:
        cfg.net.gen.context_attention_kwargs.device = str(cfg.net.gen.context_attention_kwargs.device)
    cfg.train.model_param.lr_scheduler = str(cfg.train.model_param.lr_scheduler)
    with open(os.path.join(out_path, 'config.json'), 'w') as fp:
        json.dump(cfg, fp)
    logger.info("Config file saved at " + os.path.join(out_path, 'config.json'))

    # delete any checkpoints
    if os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)
        logger.info('Checkpoint deleted.')
def main(config_path):
    """
    Run the inpainting-based anomaly segmentation of ICH over a whole dataset
    and report the slice-wise and volume-wise Dice.

    Each slice is optionally screened by a ResNet classifier; slices flagged
    as ICH (or all slices when no classifier is configured) go through the
    robust inpainting anomaly detection, the others get an all-zero mask.
    Masks, anomaly maps and intermediate masks are written under
    ``<out_path>/pred/``, the scores as CSV and the config as JSON.

    Args:
        config_path (str): path to the JSON configuration file.
    """
    # load config
    cfg = AttrDict.from_json_path(config_path)

    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)
    slice_csv_fn = os.path.join(out_path, 'slice_predictions.csv')
    volume_csv_fn = os.path.join(out_path, 'volume_predictions.csv')
    config_fn = os.path.join(out_path, 'config.json')

    # initialize seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = get_available_device()

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    logger.info(f"Experiment : {cfg.exp_name}")

    # get Dataset
    data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'ct_info.csv'), index_col=0)
    transforms = [getattr(tf, tf_name)(**tf_kwargs)
                  for tf_name, tf_kwargs in cfg.data.augmentation.items()]
    dataset = public_SegICH_Dataset2D(data_info_df, cfg.path.data,
                                      augmentation_transform=transforms,
                                      output_size=cfg.data.size,
                                      window=(cfg.data.win_center, cfg.data.win_width))

    # load inpainting model
    cfg_inpaint = AttrDict.from_json_path(cfg.inpainter_cfg_path)
    cfg_inpaint.net.gen.return_coarse = False
    inpaint_net = SAGatedGenerator(**cfg_inpaint.net.gen)
    inpaint_net.load_state_dict(torch.load(cfg.inpainter_model_path, map_location=cfg.device))
    # inpainter is deliberately left out of eval mode: per the original author,
    # its batch norm layers are not stabilized after GAN optimization
    inpaint_net = inpaint_net.to(cfg.device)
    logger.info(f"Inpainter model succesfully loaded from {cfg.inpainter_model_path}")

    # make AD inpainter Module
    ad_inpainter = InpaintAnomalyDetector(inpaint_net, device=cfg.device, **cfg.model_param)

    # Load Classifier
    use_classifier = cfg.classifier_model_path is not None
    if use_classifier:
        weights_fn = os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt')
        cfg_classifier = AttrDict.from_json_path(os.path.join(cfg.classifier_model_path, 'config.json'))
        resnet_builder = getattr(rn, cfg_classifier.net.resnet)
        classifier = resnet_builder(num_classes=cfg_classifier.net.num_classes,
                                    input_channels=cfg_classifier.net.input_channels)
        classifier.load_state_dict(torch.load(weights_fn, map_location=cfg.device))
        classifier = classifier.to(cfg.device)
        classifier.eval()
        logger.info(f"ResNet classifier model succesfully loaded from {weights_fn}")

    # iterate over dataset
    pred_dir = os.path.join(out_path, 'pred/')
    all_pred = []
    for i, (image, target, vol_id, slice_idx) in enumerate(dataset):
        banner = "=" * 25
        logger.info(banner + f" SAMPLE {i+1:04}/{len(dataset):04} - Volume {vol_id:03} Slice {slice_idx:03} " + banner)

        # Classify sample; if no classifier is given, all slices are processed
        ich_detected = True
        if use_classifier:
            with torch.no_grad():
                input_clss = image.unsqueeze(0).to(cfg.device).float()
                # take columns of softmax of positive class as score
                pred_score = nn.functional.softmax(classifier(input_clss), dim=1)[:, 1]
                ich_detected = bool(pred_score >= cfg.classification_threshold)

        if not ich_detected:
            logger.info(f"No ICH detected. Set the anomaly mask to zeros.")
            ad_mask = np.zeros_like(target[0].numpy())
            ad_mask_fn, ad_map_fn = 'None', 'None'
        else:
            logger.info(f"ICH detected. Compute anomaly mask through inpainting.")
            # Detect anomalies using the robust approach
            ad_mask, ano_map, intermediate_masks = robust_anomaly_detect(
                image, ad_inpainter, save_dir=None, verbose=True,
                return_intermediate=True, **cfg.robust_param)

            # save ad_mask
            ad_mask_fn = f"{vol_id}/{slice_idx}_anomalies.bmp"
            mask_path = os.path.join(pred_dir, ad_mask_fn)
            os.makedirs(os.path.dirname(mask_path), exist_ok=True)
            io.imsave(mask_path, img_as_ubyte(ad_mask), check_contrast=False)

            # save anomaly map
            ad_map_fn = f"{vol_id}/{slice_idx}_map_anomalies.png"
            io.imsave(os.path.join(pred_dir, ad_map_fn), img_as_ubyte(ano_map), check_contrast=False)

            # save intermediate mask
            for j, m in enumerate(intermediate_masks):
                os.makedirs(os.path.join(out_path, f"pred/{vol_id}/intermediate_masks/"), exist_ok=True)
                inter_fn = os.path.join(out_path, f"pred/{vol_id}/intermediate_masks/{slice_idx}_anomalies_{j+1}.bmp")
                io.imsave(inter_fn, img_as_ubyte(m), check_contrast=False)

        # compute confusion matrix with target ICH mask
        cm = confusion_matrix(target[0].numpy().ravel(), ad_mask.ravel(), labels=[0, 1])
        tn, fp, fn, tp = cm.ravel()

        # append to all_pred list
        record = {
            'id': vol_id.item(),
            'slice': slice_idx.item(),
            'label': target.max().item(),
            'TP': tp,
            'TN': tn,
            'FP': fp,
            'FN': fn,
            'ad_mask_fn': ad_mask_fn,
            'ad_map_fn': ad_map_fn,
        }
        all_pred.append(record)

    # make a dataframe of all predictions
    slice_df = pd.DataFrame(all_pred)
    count_agg = {'label': 'max', 'TP': 'sum', 'TN': 'sum', 'FP': 'sum', 'FN': 'sum'}
    volume_df = slice_df[['id', 'label', 'TP', 'TN', 'FP', 'FN']].groupby('id').agg(count_agg)

    # Compute Dice and Volume Dice (+1 smoothing on numerator and denominator)
    for df in (slice_df, volume_df):
        df['Dice'] = (2 * df.TP + 1) / (2 * df.TP + df.FP + df.FN + 1)
    logger.info(f"Mean slice dice : {slice_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean volume dice : {volume_df.Dice.mean(axis=0):.3f}")

    # Save Scores and Config
    slice_df.to_csv(slice_csv_fn)
    logger.info(f"Slice prediction csv saved at {slice_csv_fn}")
    volume_df.to_csv(volume_csv_fn)
    logger.info(f"Volume prediction csv saved at {volume_csv_fn}")
    # torch.device is not JSON serializable
    cfg.device = str(cfg.device)
    with open(config_fn, 'w') as f:
        json.dump(cfg, f)
    logger.info(f"Config file saved at {config_fn}")
def main(config_path):
    """
    Train an Auto-Encoder to reconstruct CT-scans from the RSNA dataset.

    Only slices with ``Hemorrhage == 0`` are used. When validation is enabled
    and ``cfg.dataset.n_sample_valid`` is positive, that many slices are held
    out as a validation set; otherwise training runs without validation data.
    The trained model, its outputs and the config are saved in the experiment
    output folder.

    Args:
        config_path (str): path to the JSON configuration file.
    """
    # Load config file
    cfg = AttrDict.from_json_path(config_path)

    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)
    if cfg.train.validate_epoch:
        os.makedirs(os.path.join(out_path, 'valid_results/'), exist_ok=True)
    # Same file name for the recovery check and for the training call.
    checkpoint_path = os.path.join(out_path, 'Checkpoint.pt')

    # initialize seed
    if cfg.seed is not None:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    if os.path.exists(checkpoint_path):
        logger.info('\n' + '#' * 30 + '\n Recovering Session \n' + '#' * 30)
    logger.info(f"Experiment : {cfg.exp_name}")

    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(f"Device set to {cfg.device}.")

    #--------------------------------------------------------------------
    #                           Make Datasets
    #--------------------------------------------------------------------
    # load RSNA data & keep normal only & and sample the required number
    df_rsna = pd.read_csv(os.path.join(cfg.path.data, 'slice_info.csv'), index_col=0)
    df_rsna_pos = df_rsna[df_rsna.Hemorrhage == 0]
    if cfg.dataset.n_sample > 0:
        df_rsna_pos = df_rsna_pos.sample(n=cfg.dataset.n_sample, random_state=cfg.seed)

    # split df to keep n_sample_valid for validation.
    # Logical `and` (not bitwise `&`), so a non-boolean validate_epoch flag is
    # tested for truthiness exactly as in the other `if cfg.train.validate_epoch`.
    df_valid = None
    if cfg.train.validate_epoch and cfg.dataset.n_sample_valid > 0:
        df_train, df_valid = train_test_split(df_rsna_pos,
                                              test_size=cfg.dataset.n_sample_valid,
                                              random_state=cfg.seed)
    else:
        df_train = df_rsna_pos

    # make dataset
    train_dataset = RSNA_dataset(
        df_train, cfg.path.data, mode='standard',
        augmentation_transform=[getattr(tf, tf_name)(**tf_kwargs)
                                for tf_name, tf_kwargs in cfg.dataset.augmentation.train.items()],
        window=(cfg.dataset.win_center, cfg.dataset.win_width),
        output_size=cfg.dataset.size)

    # make the valid dataset only when a validation split was actually made
    # (validate_epoch set but n_sample_valid <= 0 used to hit an unbound df_valid)
    if df_valid is not None:
        valid_dataset = RSNA_dataset(
            df_valid, cfg.path.data, mode='standard',
            augmentation_transform=[getattr(tf, tf_name)(**tf_kwargs)
                                    for tf_name, tf_kwargs in cfg.dataset.augmentation.eval.items()],
            window=(cfg.dataset.win_center, cfg.dataset.win_width),
            output_size=cfg.dataset.size)
    else:
        valid_dataset = None

    logger.info(f"Data will be loaded from {cfg.path.data}.")
    logger.info(f"Train contains {len(train_dataset)} samples.")
    if valid_dataset:
        logger.info(f"Valid contains {len(valid_dataset)} samples.")
    logger.info(f"CT scans will be windowed on [{cfg.dataset.win_center-cfg.dataset.win_width/2} ; {cfg.dataset.win_center + cfg.dataset.win_width/2}]")
    logger.info(f"CT scans will be resized to {cfg.dataset.size}x{cfg.dataset.size}")
    logger.info(f"Training online data transformation: \n\n {str(train_dataset.transform)}\n")
    if valid_dataset:
        logger.info(f"Evaluation online data transformation: \n\n {str(valid_dataset.transform)}\n")

    #--------------------------------------------------------------------
    #                           Make Networks
    #--------------------------------------------------------------------
    ae_net = AE_net(**cfg.net)
    ae_params = [f"--> {k} : {v}" for k, v in cfg.net.items()]
    logger.info("AE Parameters \n\t" + "\n\t".join(ae_params))

    #--------------------------------------------------------------------
    #                           Make AE model
    #--------------------------------------------------------------------
    # convert scheduler name to scheduler class object
    cfg.train.model_param.lr_scheduler = getattr(torch.optim.lr_scheduler,
                                                 cfg.train.model_param.lr_scheduler)
    ae_model = AE(ae_net, print_progress=cfg.print_progress, device=cfg.device,
                  **cfg.train.model_param)
    train_params = [f"--> {k} : {v}" for k, v in cfg.train.model_param.items()]
    logger.info("AE Training Parameters \n\t" + "\n\t".join(train_params))

    # load models if provided
    if cfg.train.model_path_to_load:
        ae_model.load_model(cfg.train.model_path_to_load, map_location=cfg.device)

    #--------------------------------------------------------------------
    #                           Train AE model
    #--------------------------------------------------------------------
    if cfg.train.model_param.n_epoch > 0:
        ae_model.train(train_dataset, checkpoint_path=checkpoint_path,
                       valid_dataset=valid_dataset,
                       valid_path=os.path.join(out_path, 'valid_results/'),
                       valid_freq=cfg.train.valid_save_freq)

    #--------------------------------------------------------------------
    #                   Save outputs, models and config
    #--------------------------------------------------------------------
    # save models
    ae_model.save_model(export_fn=os.path.join(out_path, 'AE.pt'))
    logger.info("AE model saved at " + os.path.join(out_path, 'AE.pt'))
    # save outputs
    ae_model.save_outputs(export_fn=os.path.join(out_path, 'outputs.json'))
    logger.info("Outputs file saved at " + os.path.join(out_path, 'outputs.json'))
    # save config file: device and scheduler class are turned into strings to
    # be JSON serializable
    cfg.device = str(cfg.device)
    cfg.train.model_param.lr_scheduler = str(cfg.train.model_param.lr_scheduler)
    with open(os.path.join(out_path, 'config.json'), 'w') as fp:
        json.dump(cfg, fp)
    logger.info("Config file saved at " + os.path.join(out_path, 'config.json'))