def main(config_path):
    """
    Train and evaluate a 2D U-Net on the public ICH dataset using the parameters specified in the JSON
    file at config_path. The evaluation is performed with k-fold cross-validation.
    """
    # load config file
    cfg = Config(settings=None)
    cfg.load_config(config_path)

    # Make Output directories
    out_path = os.path.join(cfg.settings['path']['OUTPUT'],
                            cfg.settings['exp_name'])  # + '/'
    os.makedirs(out_path, exist_ok=True)
    for k in range(cfg.settings['split']['n_fold']):
        os.makedirs(os.path.join(out_path, f'Fold_{k+1}/pred/'), exist_ok=True)

    # Initialize random seed to given seed
    seed = cfg.settings['seed']
    if seed != -1:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
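        # note: cudnn.deterministic only covers cuDNN kernels; for fully repeatable runs one may
        # also want torch.backends.cudnn.benchmark = False (left unchanged here).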

    # Load data csv
    data_info_df = pd.read_csv(
        os.path.join(cfg.settings['path']['DATA'], 'ct_info.csv'))
    data_info_df = data_info_df.drop(data_info_df.columns[0], axis=1)
    patient_df = pd.read_csv(
        os.path.join(cfg.settings['path']['DATA'], 'patient_info.csv'))
    patient_df = patient_df.drop(patient_df.columns[0], axis=1)
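    # the dropped first column presumably holds the saved pandas index of the CSV files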

    # Generate Cross-Val indices at the patient level
    skf = StratifiedKFold(n_splits=cfg.settings['split']['n_fold'],
                          shuffle=cfg.settings['split']['shuffle'],
                          random_state=seed)
    # iterate over folds, ensuring the same proportion of ICH-positive patients per fold --> Stratified Cross-Validation
    for k, (train_idx, test_idx) in enumerate(
            skf.split(patient_df.PatientNumber, patient_df.Hemorrhage)):
        # skip the fold if its results already exist
        if not os.path.exists(
                os.path.join(out_path, f'Fold_{k+1}/outputs.json')):
            # initialize logger
            logging.basicConfig(level=logging.INFO)
            logger = logging.getLogger()
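            # drop any file handler left over from a previous fold so that log messages do not
            # accumulate across folds (the root logger is reused between iterations)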
            try:
                logger.handlers[1].stream.close()
                logger.removeHandler(logger.handlers[1])
            except IndexError:
                pass
            logger.setLevel(logging.INFO)
            file_handler = logging.FileHandler(
                os.path.join(out_path, f'Fold_{k+1}/log.txt'))
            file_handler.setLevel(logging.INFO)
            file_handler.setFormatter(
                logging.Formatter('%(asctime)s | %(levelname)s | %(message)s'))
            logger.addHandler(file_handler)

            if os.path.exists(
                    os.path.join(out_path, f'Fold_{k+1}/checkpoint.pt')):
                logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' +
                            '#' * 30)

            logger.info(f"Experiment : {cfg.settings['exp_name']}")
            logger.info(
                f"Cross-Validation fold {k+1:02}/{cfg.settings['split']['n_fold']:02}"
            )

            # set the number of CPU threads
            if cfg.settings['n_thread'] > 0:
                torch.set_num_threads(cfg.settings['n_thread'])
            logger.info(f"Number of threads : {cfg.settings['n_thread']}")
            # check if GPU available
            #cfg.settings['device'] = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
            if cfg.settings['device'] is not None:
                cfg.settings['device'] = torch.device(cfg.settings['device'])
            else:
                if torch.cuda.is_available():
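                    # pick the CUDA device with the largest amount of free memory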
                    free_mem, device_idx = 0.0, 0
                    for d in range(torch.cuda.device_count()):
                        mem = torch.cuda.get_device_properties(
                            d).total_memory - torch.cuda.memory_allocated(d)
                        if mem > free_mem:
                            device_idx = d
                            free_mem = mem
                    cfg.settings['device'] = torch.device(f'cuda:{device_idx}')
                else:
                    cfg.settings['device'] = torch.device('cpu')
            logger.info(f"Device : {cfg.settings['device']}")

            # extract train and test DataFrames + log a summary (number of positive and negative samples)
            train_df = data_info_df[data_info_df.PatientNumber.isin(
                patient_df.loc[train_idx, 'PatientNumber'].values)]
            test_df = data_info_df[data_info_df.PatientNumber.isin(
                patient_df.loc[test_idx, 'PatientNumber'].values)]
            # subsample the negative (no hemorrhage) slices so that there are at most
            # frac_negative times as many negative as positive slices
            n_remove = int(
                max(
                    0,
                    len(train_df[train_df.Hemorrhage == 0]) -
                    cfg.settings['dataset']['frac_negative'] *
                    len(train_df[train_df.Hemorrhage == 1])))
            df_remove = train_df[train_df.Hemorrhage == 0].sample(
                n=n_remove, random_state=seed)
            train_df = train_df[~train_df.index.isin(df_remove.index)]
            logger.info(
                '\n' +
                str(get_split_summary_table(data_info_df, train_df, test_df)))

            # Make Dataset + print online augmentation summary
            train_dataset = public_SegICH_Dataset2D(
                train_df,
                cfg.settings['path']['DATA'],
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg.settings['data']['augmentation']['train'].items()
                ],
                window=(cfg.settings['data']['win_center'],
                        cfg.settings['data']['win_width']),
                output_size=cfg.settings['data']['size'])
            test_dataset = public_SegICH_Dataset2D(
                test_df,
                cfg.settings['path']['DATA'],
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg.settings['data']['augmentation']['eval'].items()
                ],
                window=(cfg.settings['data']['win_center'],
                        cfg.settings['data']['win_width']),
                output_size=cfg.settings['data']['size'])
            logger.info(
                f"Data will be loaded from {cfg.settings['path']['DATA']}.")
            logger.info(
                f"CT scans will be windowed on [{cfg.settings['data']['win_center']-cfg.settings['data']['win_width']/2} ; {cfg.settings['data']['win_center'] + cfg.settings['data']['win_width']/2}]"
            )
            logger.info(
                f"Training online data transformation: \n\n {str(train_dataset.transform)}\n"
            )
            logger.info(
                f"Evaluation online data transformation: \n\n {str(test_dataset.transform)}\n"
            )

            # Make architecture and log a summary
            unet_arch = UNet(
                depth=cfg.settings['net']['depth'],
                top_filter=cfg.settings['net']['top_filter'],
                use_3D=cfg.settings['net']['3D'],
                in_channels=cfg.settings['net']['in_channels'],
                out_channels=cfg.settings['net']['out_channels'],
                bilinear=cfg.settings['net']['bilinear'],
                midchannels_factor=cfg.settings['net']['midchannels_factor'],
                p_dropout=cfg.settings['net']['p_dropout'])
            unet_arch.to(cfg.settings['device'])
            logger.info(
                f"U-Net2D initialized with a depth of {cfg.settings['net']['depth']}"
                f" and {cfg.settings['net']['top_filter']} initial filters."
            )
            logger.info(
                f"Reconstruction performed with {'Upsample + Conv' if cfg.settings['net']['bilinear'] else 'ConvTranspose'}."
            )
            logger.info(
                f"U-Net2D takes {cfg.settings['net']['in_channels']} input channels and produces {cfg.settings['net']['out_channels']} output channels."
            )
            logger.info(
                f"The U-Net2D has {sum(p.numel() for p in unet_arch.parameters())} parameters."
            )

            # Make model
            unet2D = UNet2D(
                unet_arch,
                n_epoch=cfg.settings['train']['n_epoch'],
                batch_size=cfg.settings['train']['batch_size'],
                lr=cfg.settings['train']['lr'],
                lr_scheduler=getattr(torch.optim.lr_scheduler,
                                     cfg.settings['train']['lr_scheduler']),
                lr_scheduler_kwargs=cfg.settings['train']
                ['lr_scheduler_kwargs'],
                loss_fn=getattr(src.models.optim.LossFunctions,
                                cfg.settings['train']['loss_fn']),
                loss_fn_kwargs=cfg.settings['train']['loss_fn_kwargs'],
                weight_decay=cfg.settings['train']['weight_decay'],
                num_workers=cfg.settings['train']['num_workers'],
                device=cfg.settings['device'],
                print_progress=cfg.settings['print_progress'])

            # Load model if required
            if cfg.settings['train']['model_path_to_load']:
                if isinstance(cfg.settings['train']['model_path_to_load'],
                              str):
                    model_path = cfg.settings['train']['model_path_to_load']
                    unet2D.load_model(model_path,
                                      map_location=cfg.settings['device'])
                elif isinstance(cfg.settings['train']['model_path_to_load'],
                                list):
                    model_path = cfg.settings['train']['model_path_to_load'][k]
                    unet2D.load_model(model_path,
                                      map_location=cfg.settings['device'])
                else:
                    raise ValueError(
                        f"Model path to load of type {type(cfg.settings['train']['model_path_to_load'])} not understood."
                    )
                logger.info(f"2D U-Net model loaded from {model_path}")

            # print Training hyper-parameters
            train_params = []
            for key, value in cfg.settings['train'].items():
                train_params.append(f"--> {key} : {value}")
            logger.info('Training settings:\n\t' + '\n\t'.join(train_params))

            # Train model
            eval_dataset = test_dataset if cfg.settings['train'][
                'validate_epoch'] else None
            unet2D.train(train_dataset,
                         valid_dataset=eval_dataset,
                         checkpoint_path=os.path.join(
                             out_path, f'Fold_{k+1}/checkpoint.pt'))

            # Evaluate model
            unet2D.evaluate(test_dataset,
                            save_path=os.path.join(out_path,
                                                   f'Fold_{k+1}/pred/'))

            # Save models & outputs
            unet2D.save_model(
                os.path.join(out_path, f'Fold_{k+1}/trained_unet.pt'))
            logger.info("Trained U-Net saved at " +
                        os.path.join(out_path, f'Fold_{k+1}/trained_unet.pt'))
            unet2D.save_outputs(
                os.path.join(out_path, f'Fold_{k+1}/outputs.json'))
            logger.info("Trained statistics saved at " +
                        os.path.join(out_path, f'Fold_{k+1}/outputs.json'))

            # delete checkpoint if exists
            if os.path.exists(
                    os.path.join(out_path, f'Fold_{k+1}/checkpoint.pt')):
                os.remove(os.path.join(out_path, f'Fold_{k+1}/checkpoint.pt'))
                logger.info('Checkpoint deleted.')

    # save mean +/- 1.96 std Dice across folds in a .txt file
    logger = logging.getLogger()  # re-fetch the root logger in case every fold was already computed and skipped above
    scores_list = []
    for k in range(cfg.settings['split']['n_fold']):
        with open(os.path.join(out_path, f'Fold_{k+1}/outputs.json'),
                  'r') as f:
            out = json.load(f)
        scores_list.append(
            [out['eval']['dice']['all'], out['eval']['dice']['positive']])
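    # 1.96 * std over the folds ~ half-width of a normal-approximation 95% interval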
    means = np.array(scores_list).mean(axis=0)
    CI95 = 1.96 * np.array(scores_list).std(axis=0)
    with open(os.path.join(out_path, 'average_scores.txt'), 'w') as f:
        f.write(f'Dice = {means[0]} +/- {CI95[0]}\n')
        f.write(f'Dice (Positive) = {means[1]} +/- {CI95[1]}\n')
    logger.info('Average Scores saved at ' +
                os.path.join(out_path, 'average_scores.txt'))

    # generate a dataframe of all predictions
    df_list = [
        pd.read_csv(
            os.path.join(out_path,
                         f'Fold_{i+1}/pred/volume_prediction_scores.csv'))
        for i in range(cfg.settings['split']['n_fold'])
    ]
    all_df = pd.concat(df_list, axis=0).reset_index(drop=True)
    all_df.to_csv(os.path.join(out_path, 'all_volume_prediction.csv'))
    logger.info('CSV of all volumes prediction saved at ' +
                os.path.join(out_path, 'all_volume_prediction.csv'))

    # Save config file
    cfg.settings['device'] = str(cfg.settings['device'])
    cfg.save_config(os.path.join(out_path, 'config.json'))
    logger.info("Config file saved at " +
                os.path.join(out_path, 'config.json'))

    # Analyse results
    analyse_supervised_exp(out_path,
                           cfg.settings['path']['DATA'],
                           cfg.settings['split']['n_fold'],
                           save_fn=os.path.join(out_path,
                                                'results_overview.pdf'))
    logger.info('Results overview figure saved at ' +
                os.path.join(out_path, 'results_overview.pdf'))
def main(config_path):
    """ """
    # load config
    cfg = AttrDict.from_json_path(config_path)

    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)

    # initialize seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    logger.info(f"Experiment : {cfg.exp_name}")

    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device(
            f'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(f"Device set to {cfg.device}.")

    # get Dataset
    data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'ct_info.csv'),
                               index_col=0)
    dataset = public_SegICH_Dataset2D(
        data_info_df,
        cfg.path.data,
        augmentation_transform=[
            getattr(tf, tf_name)(**tf_kwargs)
            for tf_name, tf_kwargs in cfg.data.augmentation.items()
        ],
        output_size=cfg.data.size,
        window=(cfg.data.win_center, cfg.data.win_width))

    # load inpainting model
    cfg_ae = AttrDict.from_json_path(cfg.ae_cfg_path)
    ae_net = AE_net(**cfg_ae.net)
    loaded_state_dict = torch.load(cfg.ae_model_path, map_location=cfg.device)
    ae_net.load_state_dict(loaded_state_dict)
    ae_net = ae_net.to(cfg.device).eval()
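    # eval() disables dropout and freezes batch-norm statistics so reconstructions are deterministic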
    logger.info(f"AE model succesfully loaded from {cfg.ae_model_path}")

    # Load Classifier
    if cfg.classifier_model_path is not None:
        cfg_classifier = AttrDict.from_json_path(
            os.path.join(cfg.classifier_model_path, 'config.json'))
        classifier = getattr(rn, cfg_classifier.net.resnet)(
            num_classes=cfg_classifier.net.num_classes,
            input_channels=cfg_classifier.net.input_channels)
        classifier_state_dict = torch.load(os.path.join(
            cfg.classifier_model_path, 'resnet_state_dict.pt'),
                                           map_location=cfg.device)
        classifier.load_state_dict(classifier_state_dict)
        classifier = classifier.to(cfg.device)
        classifier.eval()
        logger.info(
            f"ResNet classifier model successfully loaded from {os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt')}"
        )

    # iterate over dataset
    all_pred = []
    for i, sample in enumerate(dataset):
        # unpack data
        image, target, id, slice = sample
        logger.info(
            "=" * 25 +
            f" SAMPLE {i+1:04}/{len(dataset):04} - Volume {id:03} Slice {slice:03} "
            + "=" * 25)

        # Classify sample
        if cfg.classifier_model_path is not None:
            with torch.no_grad():
                input_clss = image.unsqueeze(0).to(cfg.device).float()
                pred_score = nn.functional.softmax(
                    classifier(input_clss), dim=1
                )[:, 1]  # take the softmax column of the positive class as the score
                pred = 1 if pred_score >= cfg.classification_threshold else 0
        else:
            pred = 1  # if no classifier is given, all slices are processed

        # process the slice if the classifier detected a hemorrhage
        if pred == 1:
            logger.info(
                "ICH detected. Computing anomaly mask through AE reconstruction."
            )
            # Detect anomalies using the robust approach
            ad_map, ad_mask = compute_anomaly(ae_net,
                                              image,
                                              alpha_low=cfg.alpha_low,
                                              alpha_high=cfg.alpha_high,
                                              device=cfg.device)
            logger.info(f"{ad_mask.sum()} anomalous pixels detected.")
            # save ad_mask
            ad_mask_fn = f"{id}/{slice}_anomalies.bmp"
            save_path = os.path.join(out_path, 'pred/', ad_mask_fn)
            if not os.path.isdir(os.path.dirname(save_path)):
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
            io.imsave(save_path, img_as_ubyte(ad_mask), check_contrast=False)
            # save anomaly map
            ad_map_fn = f"{id}/{slice}_map_anomalies.png"
            save_path_map = os.path.join(out_path, 'pred/', ad_map_fn)
            io.imsave(save_path_map,
                      img_as_ubyte(
                          rescale_intensity(ad_map, out_range=(0.0, 1.0))),
                      check_contrast=False)
        else:
            logger.info(f"No ICH detected. Set the anomaly mask to zeros.")
            ad_mask = np.zeros_like(target[0].numpy())
            ad_mask_fn, ad_map_fn = 'None', 'None'

        # compute confusion matrix with target ICH mask
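        # (labels=[0, 1] forces a full 2x2 matrix even when only one class is present,
        #  so the 4-value unpacking below never fails)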
        tn, fp, fn, tp = confusion_matrix(target[0].numpy().ravel(),
                                          ad_mask.ravel(),
                                          labels=[0, 1]).ravel()
        auc = roc_auc_score(target[0].numpy().ravel(),
                            ad_map.ravel()) if torch.any(target[0]) else 'None'
        # append to all_pred list
        all_pred.append({
            'id': id.item(),
            'slice': slice.item(),
            'label': target.max().item(),
            'TP': tp,
            'TN': tn,
            'FP': fp,
            'FN': fn,
            'AUC': auc,
            'ad_mask_fn': ad_mask_fn,
            'ad_map_fn': ad_map_fn
        })

    # make a dataframe of all predictions
    slice_df = pd.DataFrame(all_pred)
    volume_df = slice_df[['id', 'label', 'TP', 'TN', 'FP',
                          'FN']].groupby('id').agg({
                              'label': 'max',
                              'TP': 'sum',
                              'TN': 'sum',
                              'FP': 'sum',
                              'FN': 'sum'
                          })

    # Compute Dice and Volume Dice
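    # (the +1 in numerator and denominator is a smoothing term: it avoids division by zero and
    #  yields a Dice of 1 when both the prediction and the target are empty)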
    slice_df['Dice'] = (2 * slice_df.TP + 1) / (2 * slice_df.TP + slice_df.FP +
                                                slice_df.FN + 1)
    volume_df['Dice'] = (2 * volume_df.TP + 1) / (
        2 * volume_df.TP + volume_df.FP + volume_df.FN + 1)
    logger.info(f"Mean slice dice : {slice_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean volume dice : {volume_df.Dice.mean(axis=0):.3f}")
    logger.info(
        f"Mean positive slice AUC : {slice_df[slice_df.label == 1].AUC.mean(axis=0):.3f}"
    )

    # Save Scores and Config
    slice_df.to_csv(os.path.join(out_path, 'slice_predictions.csv'))
    logger.info(
        f"Slice prediction csv saved at {os.path.join(out_path, 'slice_predictions.csv')}"
    )
    volume_df.to_csv(os.path.join(out_path, 'volume_predictions.csv'))
    logger.info(
        f"Volume prediction csv saved at {os.path.join(out_path, 'volume_predictions.csv')}"
    )
    cfg.device = str(cfg.device)
    with open(os.path.join(out_path, 'config.json'), 'w') as f:
        json.dump(cfg, f)
    logger.info(
        f"Config file saved at {os.path.join(out_path, 'config.json')}")
def main(config_path):
    """
    A 2D U-Net whose encoder is pretrained on ICH classification with the RSNA dataset and then
    fine-tuned on the public annotated data with k-fold cross-validation.
    """
    # load the config file
    cfg = AttrDict.from_json_path(config_path)

    # Make Outputs directories
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    out_path_selfsup = os.path.join(out_path, 'classification_pretrain/')
    out_path_sup = os.path.join(out_path, 'supervised_train/')
    os.makedirs(out_path_selfsup, exist_ok=True)
    for k in range(cfg.Sup.split.n_fold):
        os.makedirs(os.path.join(out_path_sup, f'Fold_{k+1}/pred/'),
                    exist_ok=True)

    # Initialize random seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # Set the number of CPU threads
    if cfg.n_thread > 0: torch.set_num_threads(cfg.n_thread)
    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = get_available_device()

    ####################################################
    # Self-supervised training on Multi Classification #
    ####################################################
    # Initialize Logger
    logger = initialize_logger(os.path.join(out_path_selfsup, 'log.txt'))
    if os.path.exists(os.path.join(out_path_selfsup, f'checkpoint.pt')):
        logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' + '#' * 30)
    logger.info(f"Experiment : {cfg.exp_name}")

    # Load RSNA data csv
    df_rsna = pd.read_csv(os.path.join(cfg.path.data.SSL, 'slice_info.csv'),
                          index_col=0)

    # Keep only a fraction of the samples
    if cfg.SSL.dataset.n_data_1 >= 0:
        df_rsna_ICH = df_rsna[df_rsna.Hemorrhage == 1].sample(
            n=cfg.SSL.dataset.n_data_1, random_state=cfg.seed)
    else:
        df_rsna_ICH = df_rsna[df_rsna.Hemorrhage == 1]
    if cfg.SSL.dataset.f_data_0 >= 0:  # number of normal slices given as a fraction of the number of ICH slices
        df_rsna_noICH = df_rsna[df_rsna.Hemorrhage == 0].sample(
            n=int(cfg.SSL.dataset.f_data_0 * len(df_rsna_ICH)),
            random_state=cfg.seed)
    else:
        df_rsna_noICH = df_rsna[df_rsna.Hemorrhage == 0]
    df_rsna = pd.concat([df_rsna_ICH, df_rsna_noICH], axis=0)

    # Split the data in a stratified way, keeping a fraction for evaluation
    train_df, test_df = train_test_split(df_rsna,
                                         test_size=cfg.SSL.dataset.frac_eval,
                                         stratify=df_rsna.Hemorrhage,
                                         random_state=cfg.seed)
    logger.info('\n' +
                str(get_split_summary_table(df_rsna, train_df, test_df)))

    # Make dataset : Train --> multi_classification, Test --> multi_classification
    train_RSNA_dataset = RSNA_dataset(
        train_df,
        cfg.path.data.SSL,
        augmentation_transform=[
            getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
            cfg.SSL.dataset.augmentation.train.items()
        ],
        window=(cfg.data.win_center, cfg.data.win_width),
        output_size=cfg.data.size,
        mode='multi_classification')
    test_RSNA_dataset = RSNA_dataset(
        test_df,
        cfg.path.data.SSL,
        augmentation_transform=[
            getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
            cfg.SSL.dataset.augmentation.eval.items()
        ],
        window=(cfg.data.win_center, cfg.data.win_width),
        output_size=cfg.data.size,
        mode='multi_classification')

    logger.info(f"Data will be loaded from {cfg.path.data.SSL}.")
    logger.info(
        f"CT scans will be windowed on [{cfg.data.win_center-cfg.data.win_width/2} ; {cfg.data.win_center + cfg.data.win_width/2}]"
    )
    logger.info(f"CT scans will be resized to {cfg.data.size}x{cfg.data.size}")
    logger.info(
        f"Training online data transformation: \n\n {str(train_RSNA_dataset.transform)}\n"
    )
    logger.info(
        f"Evaluation online data transformation: \n\n {str(test_RSNA_dataset.transform)}\n"
    )

    # Make U-Net-Encoder architecture
    net_ssl = UNet_Encoder(**cfg.SSL.net).to(cfg.device)
    net_params = [f"--> {k} : {v}" for k, v in cfg.SSL.net.items()]
    logger.info("UNet like Multi Classifier \n\t" + "\n\t".join(net_params))
    logger.info(
        f"The Multi Classifier has {sum(p.numel() for p in net_ssl.parameters())} parameters."
    )

    # Make Model
    cfg.SSL.train.model_param.lr_scheduler = getattr(
        torch.optim.lr_scheduler, cfg.SSL.train.model_param.lr_scheduler
    )  # convert scheduler name to scheduler class object
    if cfg.SSL.train.model_param.loss_fn == 'BCEWithLogitsLoss':
        df_rsna['no_Hemorrhage'] = 1 - df_rsna.Hemorrhage
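        # each pos_weight entry is the usual N_negative / N_positive ratio of its class, the factor by
        # which nn.BCEWithLogitsLoss up-weights positive samples to counteract class imbalance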
        class_weight_list = (
            (len(df_rsna) - df_rsna[train_RSNA_dataset.class_name].sum()) /
            df_rsna[train_RSNA_dataset.class_name].sum()
        ).values  # per-class weighting computed from the sampled RSNA dataframe
        cfg.SSL.train.model_param.loss_fn_kwargs['pos_weight'] = torch.tensor(
            class_weight_list, device=cfg.device)  # add the weighting to the BCE kwargs
    try:
        cfg.SSL.train.model_param.loss_fn = getattr(
            torch.nn, cfg.SSL.train.model_param.loss_fn
        )  # convert loss_fn name to nn.Module class object
    except AttributeError:
        cfg.SSL.train.model_param.loss_fn = getattr(
            src.models.optim.LossFunctions, cfg.SSL.train.model_param.loss_fn)

    classifier = MultiClassifier(net_ssl,
                                 device=cfg.device,
                                 print_progress=cfg.print_progress,
                                 **cfg.SSL.train.model_param)

    train_params = [
        f"--> {k} : {v}" for k, v in cfg.SSL.train.model_param.items()
    ]
    logger.info("Classifer Training Parameters \n\t" +
                "\n\t".join(train_params))

    # Load weights if specified
    if cfg.SSL.train.model_path_to_load:
        model_path = cfg.SSL.train.model_path_to_load
        classifier.load_model(model_path, map_location=cfg.device)
        logger.info(
            f"Classifier model successfully loaded from {cfg.SSL.train.model_path_to_load}"
        )

    # train if needed
    if cfg.SSL.train.model_param.n_epoch > 0:
        classifier.train(train_RSNA_dataset,
                         valid_dataset=test_RSNA_dataset,
                         checkpoint_path=os.path.join(out_path_selfsup,
                                                      f'checkpoint.pt'))

    # evaluate
    auc, acc, sub_acc, recall, precision, f1 = classifier.evaluate(
        test_RSNA_dataset, save_tsne=True, return_scores=True)
    logger.info(f"Classifier Test AUC : {auc:.2%}")
    logger.info(f"Classifier Test Accuracy : {acc:.2%}")
    logger.info(f"Classifier Test Subset Accuracy : {sub_acc:.2%}")
    logger.info(f"Classifier Test Recall : {recall:.2%}")
    logger.info(f"Classifier Test Precision : {precision:.2%}")
    logger.info(f"Classifier Test F1-score : {f1:.2%}")

    # save model, outputs
    classifier.save_model_state_dict(
        os.path.join(out_path_selfsup, 'pretrained_unet_enc.pt'))
    logger.info(
        "Pre-trained U-Net encoder (classification pre-training) saved at " +
        os.path.join(out_path_selfsup, 'pretrained_unet_enc.pt'))
    classifier.save_outputs(os.path.join(out_path_selfsup, 'outputs.json'))
    logger.info("Classifier outputs saved at " +
                os.path.join(out_path_selfsup, 'outputs.json'))
    test_df.reset_index(drop=True).to_csv(
        os.path.join(out_path_selfsup, 'eval_data_info.csv'))
    logger.info("Evaluation data info saved at " +
                os.path.join(out_path_selfsup, 'eval_data_info.csv'))

    # delete any checkpoints
    if os.path.exists(os.path.join(out_path_selfsup, f'checkpoint.pt')):
        os.remove(os.path.join(out_path_selfsup, f'checkpoint.pt'))
        logger.info('Checkpoint deleted.')

    # get the pretrained encoder's state dictionary
    pretrained_unet_weights = classifier.get_state_dict()
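    # these encoder weights are transferred into the segmentation U-Net below via unet2D.transfer_weights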

    ###################################################################
    # Supervised fine-training of U-Net  with K-Fold Cross-Validation #
    ###################################################################
    # load annotated data csv
    data_info_df = pd.read_csv(os.path.join(cfg.path.data.Sup, 'ct_info.csv'),
                               index_col=0)
    patient_df = pd.read_csv(os.path.join(cfg.path.data.Sup,
                                          'patient_info.csv'),
                             index_col=0)

    # Make K-Fold split at the patient level
    skf = StratifiedKFold(n_splits=cfg.Sup.split.n_fold,
                          shuffle=cfg.Sup.split.shuffle,
                          random_state=cfg.seed)

    # define scheduler and loss_fn as object
    cfg.Sup.train.model_param.lr_scheduler = getattr(
        torch.optim.lr_scheduler, cfg.Sup.train.model_param.lr_scheduler
    )  # convert scheduler name to scheduler class object
    cfg.Sup.train.model_param.loss_fn = getattr(
        src.models.optim.LossFunctions, cfg.Sup.train.model_param.loss_fn
    )  # convert loss_fn name to nn.Module class object

    # iterate over folds
    for k, (train_idx, test_idx) in enumerate(
            skf.split(patient_df.PatientNumber, patient_df.Hemorrhage)):
        # skip the fold if its results already exist
        if not os.path.exists(
                os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json')):
            # initialize logger
            logger = initialize_logger(
                os.path.join(out_path_sup, f'Fold_{k+1}/log.txt'))
            if os.path.exists(
                    os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt')):
                logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' +
                            '#' * 30)
            logger.info(f"Experiment : {cfg['exp_name']}")
            logger.info(
                f"Cross-Validation fold {k+1:02}/{cfg['Sup']['split']['n_fold']:02}"
            )

            # extract train/test slice dataframe
            train_df = data_info_df[data_info_df.PatientNumber.isin(
                patient_df.loc[train_idx, 'PatientNumber'].values)]
            test_df = data_info_df[data_info_df.PatientNumber.isin(
                patient_df.loc[test_idx, 'PatientNumber'].values)]
            # subsample the train dataframe to adjust the negative/positive fraction
            n_remove = int(
                max(
                    0,
                    len(train_df[train_df.Hemorrhage == 0]) -
                    cfg.Sup.dataset.frac_negative *
                    len(train_df[train_df.Hemorrhage == 1])))
            df_remove = train_df[train_df.Hemorrhage == 0].sample(
                n=n_remove, random_state=cfg.seed)
            train_df = train_df[~train_df.index.isin(df_remove.index)]
            logger.info(
                '\n' +
                str(get_split_summary_table(data_info_df, train_df, test_df)))

            # Make datasets
            train_dataset = public_SegICH_Dataset2D(
                train_df,
                cfg.path.data.Sup,
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg.Sup.dataset.augmentation.train.items()
                ],
                window=(cfg.data.win_center, cfg.data.win_width),
                output_size=cfg.data.size)
            test_dataset = public_SegICH_Dataset2D(
                test_df,
                cfg.path.data.Sup,
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg.Sup.dataset.augmentation.eval.items()
                ],
                window=(cfg.data.win_center, cfg.data.win_width),
                output_size=cfg.data.size)
            logger.info(f"Data will be loaded from {cfg.path.data.Sup}.")
            logger.info(
                f"CT scans will be windowed on [{cfg.data.win_center-cfg.data.win_width/2} ; {cfg.data.win_center + cfg.data.win_width/2}]"
            )
            logger.info(
                f"CT scans will be resized to {cfg.data.size}x{cfg.data.size}")
            logger.info(
                f"Training online data transformation: \n\n {str(train_dataset.transform)}\n"
            )
            logger.info(
                f"Evaluation online data transformation: \n\n {str(test_dataset.transform)}\n"
            )

            # Make U-Net architecture
            unet_sup = UNet(**cfg.Sup.net).to(cfg.device)
            net_params = [f"--> {k} : {v}" for k, v in cfg.Sup.net.items()]
            logger.info("UNet-2D params \n\t" + "\n\t".join(net_params))
            logger.info(
                f"The U-Net2D has {sum(p.numel() for p in unet_sup.parameters())} parameters."
            )

            # Make Model
            unet2D = UNet2D(unet_sup,
                            device=cfg.device,
                            print_progress=cfg.print_progress,
                            **cfg.Sup.train.model_param)

            train_params = [
                f"--> {k} : {v}" for k, v in cfg.Sup.train.model_param.items()
            ]
            logger.info("UNet-2D Training Parameters \n\t" +
                        "\n\t".join(train_params))

            # TODO: optionally load a pre-trained model here if specified in the config

            # transfer the encoder weights learned during the classification pre-training
            logger.info(
                'Initialize U-Net2D with weights learned by classification pre-training on RSNA.'
            )
            unet2D.transfer_weights(pretrained_unet_weights, verbose=True)

            # Train U-net
            eval_dataset = test_dataset if cfg.Sup.train.validate_epoch else None
            unet2D.train(train_dataset,
                         valid_dataset=eval_dataset,
                         checkpoint_path=os.path.join(
                             out_path_sup, f'Fold_{k+1}/checkpoint.pt'))

            # Evaluate U-Net
            unet2D.evaluate(test_dataset,
                            save_path=os.path.join(out_path_sup,
                                                   f'Fold_{k+1}/pred/'))

            # Save models and outputs
            unet2D.save_model(
                os.path.join(out_path_sup, f'Fold_{k+1}/trained_unet.pt'))
            logger.info(
                "Trained U-Net saved at " +
                os.path.join(out_path_sup, f'Fold_{k+1}/trained_unet.pt'))
            unet2D.save_outputs(
                os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json'))
            logger.info("Trained statistics saved at " +
                        os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json'))

            # delete checkpoint if exists
            if os.path.exists(
                    os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt')):
                os.remove(
                    os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt'))
                logger.info('Checkpoint deleted.')

    # save mean +/- 1.96 std Dice over Folds
    save_mean_fold_dice(out_path_sup, cfg.Sup.split.n_fold)
    logger.info('Average Scores saved at ' +
                os.path.join(out_path_sup, 'average_scores.txt'))

    # Save all volumes prediction csv
    df_list = [
        pd.read_csv(
            os.path.join(out_path_sup,
                         f'Fold_{i+1}/pred/volume_prediction_scores.csv'))
        for i in range(cfg.Sup.split.n_fold)
    ]
    all_df = pd.concat(df_list, axis=0).reset_index(drop=True)
    all_df.to_csv(os.path.join(out_path_sup, 'all_volume_prediction.csv'))
    logger.info('CSV of all volumes prediction saved at ' +
                os.path.join(out_path_sup, 'all_volume_prediction.csv'))

    # Save config file
    cfg.device = str(cfg.device)
    cfg.SSL.train.model_param.lr_scheduler = str(
        cfg.SSL.train.model_param.lr_scheduler)
    cfg.Sup.train.model_param.lr_scheduler = str(
        cfg.Sup.train.model_param.lr_scheduler)
    cfg.SSL.train.model_param.loss_fn = str(cfg.SSL.train.model_param.loss_fn)
    cfg.Sup.train.model_param.loss_fn = str(cfg.Sup.train.model_param.loss_fn)
    if 'pos_weight' in cfg.SSL.train.model_param.loss_fn_kwargs:  # only set when BCEWithLogitsLoss is used
        cfg.SSL.train.model_param.loss_fn_kwargs.pos_weight = \
            cfg.SSL.train.model_param.loss_fn_kwargs.pos_weight.cpu().tolist()
    with open(os.path.join(out_path, 'config.json'), 'w') as fp:
        json.dump(cfg, fp)
    logger.info('Config file saved at ' +
                os.path.join(out_path, 'config.json'))

    # Analyse results
    analyse_supervised_exp(out_path_sup,
                           cfg.path.data.Sup,
                           n_fold=cfg.Sup.split.n_fold,
                           config_folder=out_path,
                           save_fn=os.path.join(
                               out_path, 'results_supervised_overview.pdf'))
    logger.info('Results overview figure saved at ' +
                os.path.join(out_path, 'results_supervised_overview.pdf'))
    analyse_representation_exp(out_path_selfsup,
                               save_fn=os.path.join(
                                   out_path,
                                   'results_self-supervised_overview.pdf'))
    logger.info('Results overview figure saved at ' +
                os.path.join(out_path, 'results_self-supervised_overview.pdf'))
def main(config_path):
    """
    A 2D U-Net pretrained on context restoration with the RSNA dataset and then fine-tuned on the
    public annotated data with k-fold cross-validation.
    """
    # load the config file
    with open(config_path, 'r') as fp:
        cfg = json.load(fp)

    # Make Outputs directories
    out_path = os.path.join(
        cfg['path']['output'],
        cfg['exp_name'])  # + datetime.now().strftime('_%Y-%m-%d'))
    out_path_selfsup = os.path.join(out_path, 'context_restoration_pretrain/')
    out_path_sup = os.path.join(out_path, 'supervised_train/')
    os.makedirs(out_path_selfsup, exist_ok=True)
    for k in range(cfg['Sup']['split']['n_fold']):
        os.makedirs(os.path.join(out_path_sup, f'Fold_{k+1}/pred/'),
                    exist_ok=True)

    # Initialize random seed
    seed = cfg['seed']
    if seed != -1:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

    # Set the number of CPU threads
    if cfg['n_thread'] > 0: torch.set_num_threads(cfg['n_thread'])
    # check if GPU available
    #cfg['device'] = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    if cfg['device'] is not None:
        cfg['device'] = torch.device(cfg['device'])
    else:
        if torch.cuda.is_available():
            free_mem, device_idx = 0.0, 0
            for d in range(torch.cuda.device_count()):
                mem = torch.cuda.get_device_properties(
                    d).total_memory - torch.cuda.memory_allocated(d)
                if mem > free_mem:
                    device_idx = d
                    free_mem = mem
            cfg['device'] = torch.device(f'cuda:{device_idx}')
        else:
            cfg['device'] = torch.device('cpu')

    ###################################################
    # Self-supervised training on Context Restoration #
    ###################################################
    # Initialize Logger
    logger = initialize_logger(os.path.join(out_path_selfsup, 'log.txt'))
    if os.path.exists(os.path.join(out_path_selfsup, f'checkpoint.pt')):
        logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' + '#' * 30)
    logger.info(f"Experiment : {cfg['exp_name']}")

    # Load RSNA data csv
    df_rsna = pd.read_csv(os.path.join(cfg['path']['data']['SSL'],
                                       'slice_info.csv'),
                          index_col=0)

    # Keep only a fraction of the negative/positive samples
    df_rsna_pos = df_rsna[df_rsna.Hemorrhage == 1]
    if cfg['SSL']['dataset']['n_positive'] >= 0:
        df_rsna_pos = df_rsna_pos.sample(n=cfg['SSL']['dataset']['n_positive'],
                                         random_state=seed)
    n_neg = int(cfg['SSL']['dataset']['frac_negative'] * len(df_rsna_pos))
    df_rsna_neg = df_rsna[df_rsna.Hemorrhage == 0].sample(n=n_neg,
                                                          random_state=seed)
    df_rsna_samp = pd.concat([df_rsna_pos, df_rsna_neg], axis=0)

    # Split the data, keeping a few samples aside for evaluation
    test_df = df_rsna_samp.sample(n=cfg['SSL']['dataset']['n_eval'],
                                  random_state=seed)
    train_df = df_rsna_samp.drop(test_df.index)
    df_rsna_neg = df_rsna[(df_rsna.Hemorrhage == 0)]
    # add extra negative slices (not used for training) to the evaluation set
    test_df = pd.concat([
        test_df,
        df_rsna_neg[~df_rsna_neg.index.isin(train_df.index)].sample(
            n=cfg['SSL']['dataset']['n_eval_neg'], random_state=seed)
    ])
    test_df, train_df = test_df.reset_index(), train_df.reset_index()
    logger.info('\n' +
                str(get_split_summary_table(df_rsna, train_df, test_df)))

    # Make dataset : Train --> ContextRestoration, Test --> Standard
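    # context restoration: training images are corrupted by swapping n_swap random patches of size
    # swap_h x swap_w, and the network is trained to restore the original image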
    train_RSNA_dataset = RSNA_dataset(
        train_df,
        cfg['path']['data']['SSL'],
        augmentation_transform=[
            getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
            cfg['SSL']['dataset']['augmentation']['train'].items()
        ],
        window=(cfg['data']['win_center'], cfg['data']['win_width']),
        output_size=cfg['data']['size'],
        mode='context_restoration',
        n_swap=cfg['SSL']['dataset']['n_swap'],
        swap_w=cfg['SSL']['dataset']['swap_w'],
        swap_h=cfg['SSL']['dataset']['swap_h'],
        swap_rot=cfg['SSL']['dataset']['swap_rotate'])
    test_RSNA_dataset = RSNA_dataset(
        test_df,
        cfg['path']['data']['SSL'],
        augmentation_transform=[
            getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
            cfg['SSL']['dataset']['augmentation']['eval'].items()
        ],
        window=(cfg['data']['win_center'], cfg['data']['win_width']),
        output_size=cfg['data']['size'],
        mode='standard')
    logger.info(f"Data will be loaded from {cfg['path']['data']['SSL']}.")
    logger.info(
        f"CT scans will be windowed on [{cfg['data']['win_center']-cfg['data']['win_width']/2} ; {cfg['data']['win_center'] + cfg['data']['win_width']/2}]"
    )
    logger.info(
        f"CT scans will be resized to {cfg['data']['size']}x{cfg['data']['size']}"
    )
    logger.info(
        f"Training online data transformation: \n\n {str(train_RSNA_dataset.transform)}\n"
    )
    logger.info(
        f"Evaluation online data transformation: \n\n {str(test_RSNA_dataset.transform)}\n"
    )
    logger.info(
        f"Input images will be corrupted by {cfg['SSL']['dataset']['n_swap']} swap of dimension {cfg['SSL']['dataset']['swap_h']}x{cfg['SSL']['dataset']['swap_w']} (h x w)"
    )

    # Make U-Net architecture
    net_ssl = UNet(depth=cfg['SSL']['net']['depth'],
                   top_filter=cfg['SSL']['net']['top_filter'],
                   use_3D=cfg['SSL']['net']['3D'],
                   in_channels=cfg['SSL']['net']['in_channels'],
                   out_channels=cfg['SSL']['net']['out_channels'],
                   bilinear=cfg['SSL']['net']['bilinear'],
                   use_final_activation=cfg['SSL']['net']['final_activation'],
                   midchannels_factor=cfg['SSL']['net']['midchannels_factor'],
                   p_dropout=cfg['SSL']['net']['p_dropout'])
    net_ssl.to(cfg['device'])
    logger.info(
        f"U-Net2D initialized with a depth of {cfg['SSL']['net']['depth']}"
        f" and {cfg['SSL']['net']['top_filter']} initial filters."
    )
    logger.info(
        f"Reconstruction performed with {'Upsample + Conv' if cfg['SSL']['net']['bilinear'] else 'ConvTranspose'}."
    )
    logger.info(
        f"U-Net2D takes {cfg['SSL']['net']['in_channels']} input channels and produces {cfg['SSL']['net']['out_channels']} output channels."
    )
    logger.info(
        f"The U-Net2D has {sum(p.numel() for p in net_ssl.parameters())} parameters."
    )

    # Make Model
    ctx_restor = ContextRestoration(
        net_ssl,
        n_epoch=cfg['SSL']['train']['n_epoch'],
        batch_size=cfg['SSL']['train']['batch_size'],
        lr=cfg['SSL']['train']['lr'],
        lr_scheduler=getattr(torch.optim.lr_scheduler,
                             cfg['SSL']['train']['lr_scheduler']),
        lr_scheduler_kwargs=cfg['SSL']['train']['lr_scheduler_kwargs'],
        loss_fn=getattr(torch.nn, cfg['SSL']['train']['loss_fn']),
        loss_fn_kwargs=cfg['SSL']['train']['loss_fn_kwargs'],
        weight_decay=cfg['SSL']['train']['weight_decay'],
        num_workers=cfg['SSL']['train']['num_workers'],
        device=cfg['device'],
        print_progress=cfg['print_progress'])

    # Load weights if specified
    if cfg['SSL']['train']['model_path_to_load']:
        model_path = cfg['SSL']['train']['model_path_to_load']
        ctx_restor.load_model(model_path, map_location=cfg['device'])

    # train if needed
    if cfg['SSL']['train']['n_epoch'] > 0:
        train_params = []
        for key, value in cfg['SSL']['train'].items():
            train_params.append(f"--> {key} : {value}")
        logger.info('Training settings:\n\t' + '\n\t'.join(train_params))

        ctx_restor.train(train_RSNA_dataset,
                         checkpoint_path=os.path.join(out_path_selfsup,
                                                      f'checkpoint.pt'))

    # evaluate
    ctx_restor.evaluate(test_RSNA_dataset)

    # save model, outputs and evaluation data info (test_df)
    ctx_restor.save_model(os.path.join(out_path_selfsup, 'pretrained_unet.pt'))
    logger.info("Pre-trained U-Net on context restoration saved at " +
                os.path.join(out_path_selfsup, 'pretrained_unet.pt'))
    ctx_restor.save_outputs(os.path.join(out_path_selfsup, 'outputs.json'))
    logger.info("Context restoration outputs saved at " +
                os.path.join(out_path_selfsup, 'outputs.json'))
    test_df.to_csv(os.path.join(out_path_selfsup, 'eval_data_info.csv'))
    logger.info("Evaluation data info saved at " +
                os.path.join(out_path_selfsup, 'eval_data_info.csv'))

    # delete any checkpoints
    if os.path.exists(os.path.join(out_path_selfsup, f'checkpoint.pt')):
        os.remove(os.path.join(out_path_selfsup, f'checkpoint.pt'))
        logger.info('Checkpoint deleted.')

    # get the pretrained network's state dictionary
    pretrained_unet_weights = ctx_restor.get_state_dict()

    ###################################################################
    # Supervised fine-training of U-Net  with K-Fold Cross-Validation #
    ###################################################################
    # load annotated data csv
    data_info_df = pd.read_csv(os.path.join(cfg['path']['data']['Sup'],
                                            'ct_info.csv'),
                               index_col=0)
    patient_df = pd.read_csv(os.path.join(cfg['path']['data']['Sup'],
                                          'patient_info.csv'),
                             index_col=0)

    # Make K-Fold split at the patient level
    skf = StratifiedKFold(n_splits=cfg['Sup']['split']['n_fold'],
                          shuffle=cfg['Sup']['split']['shuffle'],
                          random_state=seed)

    # iterate over folds
    for k, (train_idx, test_idx) in enumerate(
            skf.split(patient_df.PatientNumber, patient_df.Hemorrhage)):
        # skip the fold if its results already exist
        if not os.path.exists(
                os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json')):
            # initialize logger
            logger = initialize_logger(
                os.path.join(out_path_sup, f'Fold_{k+1}/log.txt'))
            if os.path.exists(
                    os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt')):
                logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' +
                            '#' * 30)
            logger.info(f"Experiment : {cfg['exp_name']}")
            logger.info(
                f"Cross-Validation fold {k+1:02}/{cfg['Sup']['split']['n_fold']:02}"
            )

            # extract train/test slice dataframe
            train_df = data_info_df[data_info_df.PatientNumber.isin(
                patient_df.loc[train_idx, 'PatientNumber'].values)]
            test_df = data_info_df[data_info_df.PatientNumber.isin(
                patient_df.loc[test_idx, 'PatientNumber'].values)]
            # subsample the train dataframe to adjust the negative/positive fraction
            n_remove = int(
                max(
                    0,
                    len(train_df[train_df.Hemorrhage == 0]) -
                    cfg['Sup']['dataset']['frac_negative'] *
                    len(train_df[train_df.Hemorrhage == 1])))
            df_remove = train_df[train_df.Hemorrhage == 0].sample(
                n=n_remove, random_state=seed)
            train_df = train_df[~train_df.index.isin(df_remove.index)]
            logger.info(
                '\n' +
                str(get_split_summary_table(data_info_df, train_df, test_df)))

            # Make datasets
            train_dataset = public_SegICH_Dataset2D(
                train_df,
                cfg['path']['data']['Sup'],
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg['Sup']['dataset']['augmentation']['train'].items()
                ],
                window=(cfg['data']['win_center'], cfg['data']['win_width']),
                output_size=cfg['data']['size'])
            test_dataset = public_SegICH_Dataset2D(
                test_df,
                cfg['path']['data']['Sup'],
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg['Sup']['dataset']['augmentation']['eval'].items()
                ],
                window=(cfg['data']['win_center'], cfg['data']['win_width']),
                output_size=cfg['data']['size'])
            logger.info(
                f"Data will be loaded from {cfg['path']['data']['Sup']}.")
            logger.info(
                f"CT scans will be windowed on [{cfg['data']['win_center']-cfg['data']['win_width']/2} ; {cfg['data']['win_center'] + cfg['data']['win_width']/2}]"
            )
            logger.info(
                f"CT scans will be resized to {cfg['data']['size']}x{cfg['data']['size']}"
            )
            logger.info(
                f"Training online data transformation: \n\n {str(train_dataset.transform)}\n"
            )
            logger.info(
                f"Evaluation online data transformation: \n\n {str(test_dataset.transform)}\n"
            )

            # Make U-Net architecture
            unet_sup = UNet(
                depth=cfg['Sup']['net']['depth'],
                top_filter=cfg['Sup']['net']['top_filter'],
                use_3D=cfg['Sup']['net']['3D'],
                in_channels=cfg['Sup']['net']['in_channels'],
                out_channels=cfg['Sup']['net']['out_channels'],
                bilinear=cfg['Sup']['net']['bilinear'],
                midchannels_factor=cfg['Sup']['net']['midchannels_factor'],
                p_dropout=cfg['Sup']['net']['p_dropout'])
            unet_sup.to(cfg['device'])
            logger.info(
                f"U-Net2D initialized with a depth of {cfg['Sup']['net']['depth']}"
                f" and {cfg['Sup']['net']['top_filter']} initial filters."
            )
            logger.info(
                f"Reconstruction performed with {'Upsample + Conv' if cfg['Sup']['net']['bilinear'] else 'ConvTranspose'}."
            )
            logger.info(
                f"U-Net2D takes {cfg['Sup']['net']['in_channels']} input channels and produces {cfg['Sup']['net']['out_channels']} output channels."
            )
            logger.info(
                f"The U-Net2D has {sum(p.numel() for p in unet_sup.parameters())} parameters."
            )

            # Make Model
            unet2D = UNet2D(
                unet_sup,
                n_epoch=cfg['Sup']['train']['n_epoch'],
                batch_size=cfg['Sup']['train']['batch_size'],
                lr=cfg['Sup']['train']['lr'],
                lr_scheduler=getattr(torch.optim.lr_scheduler,
                                     cfg['Sup']['train']['lr_scheduler']),
                lr_scheduler_kwargs=cfg['Sup']['train']['lr_scheduler_kwargs'],
                loss_fn=getattr(src.models.optim.LossFunctions,
                                cfg['Sup']['train']['loss_fn']),
                loss_fn_kwargs=cfg['Sup']['train']['loss_fn_kwargs'],
                weight_decay=cfg['Sup']['train']['weight_decay'],
                num_workers=cfg['Sup']['train']['num_workers'],
                device=cfg['device'],
                print_progress=cfg['print_progress'])

            # TODO: optionally load a pre-trained model here if specified in the config

            # transfer weights learn with context restoration
            logger.info(
                'Initialize U-Net2D with weights learned with context_restoration on RSNA.'
            )
            unet2D.transfer_weights(pretrained_unet_weights, verbose=True)

            # Print training parameters
            train_params = []
            for key, value in cfg['Sup']['train'].items():
                train_params.append(f"--> {key} : {value}")
            logger.info('Training settings:\n\t' + '\n\t'.join(train_params))

            # Train U-net
            eval_dataset = test_dataset if cfg['Sup']['train'][
                'validate_epoch'] else None
            unet2D.train(train_dataset,
                         valid_dataset=eval_dataset,
                         checkpoint_path=os.path.join(
                             out_path_sup, f'Fold_{k+1}/checkpoint.pt'))

            # Evaluate U-Net
            unet2D.evaluate(test_dataset,
                            save_path=os.path.join(out_path_sup,
                                                   f'Fold_{k+1}/pred/'))

            # Save models and outputs
            unet2D.save_model(
                os.path.join(out_path_sup, f'Fold_{k+1}/trained_unet.pt'))
            logger.info(
                "Trained U-Net saved at " +
                os.path.join(out_path_sup, f'Fold_{k+1}/trained_unet.pt'))
            unet2D.save_outputs(
                os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json'))
            logger.info("Trained statistics saved at " +
                        os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json'))

            # delete checkpoint if exists
            if os.path.exists(
                    os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt')):
                os.remove(
                    os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt'))
                logger.info('Checkpoint deleted.')

    # save mean +/- 1.96 std Dice over Folds
    save_mean_fold_dice(out_path_sup, cfg['Sup']['split']['n_fold'])
    logger.info('Average Scores saved at ' +
                os.path.join(out_path_sup, 'average_scores.txt'))

    # Save all volumes prediction csv
    df_list = [
        pd.read_csv(
            os.path.join(out_path_sup,
                         f'Fold_{i+1}/pred/volume_prediction_scores.csv'))
        for i in range(cfg['Sup']['split']['n_fold'])
    ]
    all_df = pd.concat(df_list, axis=0).reset_index(drop=True)
    all_df.to_csv(os.path.join(out_path_sup, 'all_volume_prediction.csv'))
    logger.info('CSV of all volumes prediction saved at ' +
                os.path.join(out_path_sup, 'all_volume_prediction.csv'))

    # Save config file
    cfg['device'] = str(cfg['device'])
    with open(os.path.join(out_path, 'config.json'), 'w') as fp:
        json.dump(cfg, fp)
    logger.info('Config file saved at ' +
                os.path.join(out_path, 'config.json'))

    # Analyse results
    analyse_supervised_exp(out_path_sup,
                           cfg['path']['data']['Sup'],
                           n_fold=cfg['Sup']['split']['n_fold'],
                           config_folder=out_path,
                           save_fn=os.path.join(
                               out_path, 'results_supervised_overview.pdf'))
    logger.info('Results overview figure saved at ' +
                os.path.join(out_path, 'results_supervised_overview.pdf'))
    analyse_representation_exp(out_path_selfsup,
                               save_fn=os.path.join(
                                   out_path,
                                   'results_self-supervised_overview.pdf'))
    logger.info('Results overview figure saved at ' +
                os.path.join(out_path, 'results_self-supervised_overview.pdf'))
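The `save_mean_fold_dice` helper called above is not part of this listing. A minimal sketch of what such a helper could look like follows; the layout of each fold's outputs.json (a Dice value under ['eval']['dice']) is an assumption, and the 1.96 factor turns the fold standard deviation into an approximate 95% normal interval.

import json
import os
import numpy as np

def save_mean_fold_dice_sketch(out_path, n_fold):
    """Hypothetical sketch: aggregate per-fold Dice into mean +/- 1.96*std."""
    dices = []
    for k in range(n_fold):
        # ASSUMPTION: each fold's outputs.json stores its test Dice under ['eval']['dice']
        with open(os.path.join(out_path, f'Fold_{k+1}/outputs.json'), 'r') as fp:
            dices.append(json.load(fp)['eval']['dice'])
    mean, std = np.mean(dices), np.std(dices)
    with open(os.path.join(out_path, 'average_scores.txt'), 'w') as f:
        # 1.96 * std is the half-width of a 95% normal-approximation interval over folds
        f.write(f'Dice: {mean:.3f} +/- {1.96 * std:.3f}\n')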
Example #5
def main(config_path):
    """  """
    # load config
    cfg = AttrDict.from_json_path(config_path)

    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)

    # initialize seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    logger.info(f"Experiment : {cfg.exp_name}")

    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(f"Device set to {cfg.device}.")

    #-------------------------------------------
    #       Make Dataset
    #-------------------------------------------

    data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'ct_info.csv'), index_col=0)
    dataset = public_SegICH_Dataset2D(data_info_df, cfg.path.data,
                    augmentation_transform=[getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg.data.augmentation.items()],
                    output_size=cfg.data.size, window=(cfg.data.win_center, cfg.data.win_width))
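    # each entry of cfg.data.augmentation maps a transform name in the `tf` module to its
    # keyword arguments, so the augmentation pipeline is entirely defined by the config file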

    #-------------------------------------------
    #       Load FCDD Model
    #-------------------------------------------

    cfg_fcdd = AttrDict.from_json_path(cfg.fcdd_cfg_path)
    fcdd_net = FCDD_CNN_VGG(in_shape=(cfg_fcdd.net.in_channels, 256, 256), bias=cfg_fcdd.net.bias)
    loaded_state_dict = torch.load(cfg.fcdd_model_path, map_location=cfg.device)
    fcdd_net.load_state_dict(loaded_state_dict)
    fcdd_net = fcdd_net.to(cfg.device).eval()
    logger.info(f"FCDD model succesfully loaded from {cfg.fcdd_model_path}")

    # make FCDD object
    fcdd = FCDD(fcdd_net, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
                device=cfg.device, print_progress=cfg.print_progress)

    #-------------------------------------------
    #       Load Classifier Model
    #-------------------------------------------

    # Load Classifier
    if cfg.classifier_model_path is not None:
        cfg_classifier = AttrDict.from_json_path(os.path.join(cfg.classifier_model_path, 'config.json'))
        classifier = getattr(rn, cfg_classifier.net.resnet)(num_classes=cfg_classifier.net.num_classes, input_channels=cfg_classifier.net.input_channels)
        classifier_state_dict = torch.load(os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt'), map_location=cfg.device)
        classifier.load_state_dict(classifier_state_dict)
        classifier = classifier.to(cfg.device)
        classifier.eval()
        logger.info(f"ResNet classifier model succesfully loaded from {os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt')}")

    #-------------------------------------------
    #       Generate Heat-Map for each slice
    #-------------------------------------------

    with torch.no_grad():
        # make loader
        loader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
                                             shuffle=False, worker_init_fn=lambda _: np.random.seed())
        fcdd_net.eval()

        min_val, max_val = fcdd.get_min_max(loader, **cfg.heatmap_param)

        # computing and saving heatmaps
        out = dict(id=[], slice=[], label=[], ad_map_fn=[], ad_mask_fn=[],
                   TP=[], TN=[], FP=[], FN=[], AUC=[], classifier_pred=[])
        for b, data in enumerate(loader):
            im, mask, id, slice = data
            im = im.to(cfg.device).float()
            mask = mask.to(cfg.device).float()

            # get heatmap
            heatmap = fcdd.generate_heatmap(im, reception=cfg.heatmap_param.reception, std=cfg.heatmap_param.std,
                                            cpu=cfg.heatmap_param.cpu)
            # scaling
            heatmap = ((heatmap - min_val) / (max_val - min_val)).clamp(0,1)
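            # min_val/max_val were computed once over the whole loader, so every heatmap is
            # mapped to a common [0, 1] range before the fixed threshold below is applied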

            # Threshold
            ad_mask = torch.where(heatmap >= cfg.heatmap_threshold, torch.ones_like(heatmap, device=heatmap.device),
                                                                    torch.zeros_like(heatmap, device=heatmap.device))

            # Compute CM
            tn, fp, fn, tp  = batch_binary_confusion_matrix(ad_mask, mask.to(heatmap.device))

            # Save heatmaps/mask
            map_fn, mask_fn = [], []
            for i in range(im.shape[0]):
                # Save AD Map
                ad_map_fn = f"{id[i]}/{slice[i]}_map_anomalies.png"
                save_path = os.path.join(out_path, 'pred/', ad_map_fn)
                if not os.path.isdir(os.path.dirname(save_path)):
                    os.makedirs(os.path.dirname(save_path), exist_ok=True)

                ad_map = heatmap[i].squeeze().cpu().numpy()
                io.imsave(save_path, img_as_ubyte(ad_map), check_contrast=False)
                # save ad_mask
                ad_mask_fn = f"{id[i]}/{slice[i]}_anomalies.bmp"
                save_path = os.path.join(out_path, 'pred/', ad_mask_fn)
                io.imsave(save_path, img_as_ubyte(ad_mask[i].squeeze().cpu().numpy()), check_contrast=False)

                map_fn.append(ad_map_fn)
                mask_fn.append(ad_mask_fn)

            # apply the ResNet classifier
            if cfg.classifier_model_path is not None:
                pred_score = nn.functional.softmax(classifier(im), dim=1)[:,1] # take the positive-class column of the softmax as the score
                clss_pred = torch.where(pred_score >= cfg.classification_threshold, torch.ones_like(pred_score, device=pred_score.device),
                                                                                    torch.zeros_like(pred_score, device=pred_score.device))
            else:
                clss_pred = [None]*im.shape[0]

            # Save Values
            out['id'] += id.cpu().tolist()
            out['slice'] += slice.cpu().tolist()
            out['label'] += mask.reshape(mask.shape[0], -1).max(dim=1)[0].cpu().tolist()
            out['ad_map_fn'] += map_fn
            out['ad_mask_fn'] += mask_fn
            out['TN'] += tn.cpu().tolist()
            out['FP'] += fp.cpu().tolist()
            out['FN'] += fn.cpu().tolist()
            out['TP'] += tp.cpu().tolist()
            out['AUC'] += [roc_auc_score(mask[i].cpu().numpy().ravel(), heatmap[i].cpu().numpy().ravel()) if torch.any(mask[i]>0) else 'None' for i in range(im.shape[0])]
            out['classifier_pred'] += (clss_pred.cpu().tolist() if torch.is_tensor(clss_pred) else clss_pred)  # clss_pred is a plain list of None when no classifier is given

            if cfg.print_progress:
                print_progessbar(b, len(loader), Name='Heatmap Generation Batch', Size=100, erase=True)

    # make df and save as csv
    slice_df = pd.DataFrame(out)
    volume_df = slice_df[['id', 'label', 'TP', 'TN', 'FP', 'FN']].groupby('id').agg({'label':'max', 'TP':'sum', 'TN':'sum', 'FP':'sum', 'FN':'sum'})

    slice_df['Dice'] = (2*slice_df.TP + 1) / (2*slice_df.TP + slice_df.FP + slice_df.FN + 1)
    volume_df['Dice'] = (2*volume_df.TP + 1) / (2*volume_df.TP + volume_df.FP + volume_df.FN + 1)
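    # the +1 smoothing above avoids division by zero and yields a Dice of 1 when both the
    # prediction and the ground-truth mask are empty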
    logger.info(f"Mean slice dice : {slice_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean volume dice : {volume_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean posiitve slice AUC {slice_df[slice_df.label == 1].AUC.mean(axis=0):.3f}")

    # Save Scores and Config
    slice_df.to_csv(os.path.join(out_path, 'slice_predictions.csv'))
    logger.info(f"Slice prediction csv saved at {os.path.join(out_path, 'slice_predictions.csv')}")
    volume_df.to_csv(os.path.join(out_path, 'volume_predictions.csv'))
    logger.info(f"Volume prediction csv saved at {os.path.join(out_path, 'volume_predictions.csv')}")
    cfg.device = str(cfg.device)
    with open(os.path.join(out_path, 'config.json'), 'w') as f:
        json.dump(cfg, f)
    logger.info(f"Config file saved at {os.path.join(out_path, 'config.json')}")
Example #6
def main(config_path):
    """
    Segment ICH using the anomaly-inpainting approach on a whole dataset and compute slice/volume Dice.
    """
    # load config
    cfg = AttrDict.from_json_path(config_path)

    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)

    # initialize seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # set device
    cfg.device = torch.device(
        cfg.device) if cfg.device else get_available_device()

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    logger.info(f"Experiment : {cfg.exp_name}")

    # get Dataset
    data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'ct_info.csv'),
                               index_col=0)
    #data_info_df = data_info_df[sum([data_info_df.CT_fn.str.contains(s) for s in ['49/16', '51/39', '71/15', '71/22', '75/22']]) > 0]
    dataset = public_SegICH_Dataset2D(
        data_info_df,
        cfg.path.data,
        augmentation_transform=[
            getattr(tf, tf_name)(**tf_kwargs)
            for tf_name, tf_kwargs in cfg.data.augmentation.items()
        ],
        output_size=cfg.data.size,
        window=(cfg.data.win_center, cfg.data.win_width))

    # load inpainting model
    cfg_inpaint = AttrDict.from_json_path(cfg.inpainter_cfg_path)
    cfg_inpaint.net.gen.return_coarse = False
    inpaint_net = SAGatedGenerator(**cfg_inpaint.net.gen)
    loaded_state_dict = torch.load(cfg.inpainter_model_path,
                                   map_location=cfg.device)
    inpaint_net.load_state_dict(loaded_state_dict)
    inpaint_net = inpaint_net.to(
        cfg.device
    )  # inpainter kept out of eval mode because its batch-norm statistics are not stabilized (GAN optimization)
    logger.info(
        f"Inpainter model successfully loaded from {cfg.inpainter_model_path}")

    # make AD inpainter Module
    ad_inpainter = InpaintAnomalyDetector(inpaint_net,
                                          device=cfg.device,
                                          **cfg.model_param)

    # Load Classifier
    if cfg.classifier_model_path is not None:
        cfg_classifier = AttrDict.from_json_path(
            os.path.join(cfg.classifier_model_path, 'config.json'))
        classifier = getattr(rn, cfg_classifier.net.resnet)(
            num_classes=cfg_classifier.net.num_classes,
            input_channels=cfg_classifier.net.input_channels)
        classifier_state_dict = torch.load(os.path.join(
            cfg.classifier_model_path, 'resnet_state_dict.pt'),
                                           map_location=cfg.device)
        classifier.load_state_dict(classifier_state_dict)
        classifier = classifier.to(cfg.device)
        classifier.eval()
        logger.info(
            f"ResNet classifier model successfully loaded from {os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt')}"
        )

    # iterate over dataset
    all_pred = []
    for i, sample in enumerate(dataset):
        # unpack data
        image, target, id, slice = sample
        logger.info(
            "=" * 25 +
            f" SAMPLE {i+1:04}/{len(dataset):04} - Volume {id:03} Slice {slice:03} "
            + "=" * 25)

        # Classify sample
        if cfg.classifier_model_path is not None:
            with torch.no_grad():
                input_clss = image.unsqueeze(0).to(cfg.device).float()
                pred_score = nn.functional.softmax(
                    classifier(input_clss), dim=1
                )[:, 1]  # take the positive-class column of the softmax as the score
                pred = 1 if pred_score >= cfg.classification_threshold else 0
        else:
            pred = 1  # if no classifier is given, all slices are processed

        # process slice if classifier has detected Hemorrhage
        if pred == 1:
            logger.info(
                "ICH detected. Computing anomaly mask through inpainting.")
            # Detect anomalies using the robust approach
            ad_mask, ano_map, intermediate_masks = robust_anomaly_detect(
                image,
                ad_inpainter,
                save_dir=None,
                verbose=True,
                return_intermediate=True,
                **cfg.robust_param)
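            # return_intermediate=True also yields the masks from each intermediate step of the
            # robust detection; they are saved under pred/<id>/intermediate_masks/ further below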
            # save ad_mask
            ad_mask_fn = f"{id}/{slice}_anomalies.bmp"
            save_path = os.path.join(out_path, 'pred/', ad_mask_fn)
            if not os.path.isdir(os.path.dirname(save_path)):
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
            io.imsave(save_path, img_as_ubyte(ad_mask), check_contrast=False)
            # save anomaly map
            ad_map_fn = f"{id}/{slice}_map_anomalies.png"
            save_path_map = os.path.join(out_path, 'pred/', ad_map_fn)
            io.imsave(save_path_map,
                      img_as_ubyte(ano_map),
                      check_contrast=False)
            # save intermediate masks
            for j, m in enumerate(intermediate_masks):
                if not os.path.isdir(
                        os.path.join(out_path,
                                     f"pred/{id}/intermediate_masks/")):
                    os.makedirs(os.path.join(out_path,
                                             f"pred/{id}/intermediate_masks/"),
                                exist_ok=True)
                io.imsave(os.path.join(
                    out_path,
                    f"pred/{id}/intermediate_masks/{slice}_anomalies_{j+1}.bmp"
                ),
                          img_as_ubyte(m),
                          check_contrast=False)
        else:
            logger.info(f"No ICH detected. Set the anomaly mask to zeros.")
            ad_mask = np.zeros_like(target[0].numpy())
            ad_mask_fn, ad_map_fn = 'None', 'None'

        # compute confusion matrix with target ICH mask
        tn, fp, fn, tp = confusion_matrix(target[0].numpy().ravel(),
                                          ad_mask.ravel(),
                                          labels=[0, 1]).ravel()
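        # labels=[0, 1] forces a full 2x2 matrix even when a slice contains a single class;
        # for binary labels, ravel() returns the counts in the order tn, fp, fn, tp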
        # append to all_pred list
        all_pred.append({
            'id': id.item(),
            'slice': slice.item(),
            'label': target.max().item(),
            'TP': tp,
            'TN': tn,
            'FP': fp,
            'FN': fn,
            'ad_mask_fn': ad_mask_fn,
            'ad_map_fn': ad_map_fn
        })

    # make a dataframe of all predictions
    slice_df = pd.DataFrame(all_pred)
    volume_df = slice_df[['id', 'label', 'TP', 'TN', 'FP',
                          'FN']].groupby('id').agg({
                              'label': 'max',
                              'TP': 'sum',
                              'TN': 'sum',
                              'FP': 'sum',
                              'FN': 'sum'
                          })

    # Compute Dice and Volume Dice
    slice_df['Dice'] = (2 * slice_df.TP + 1) / (2 * slice_df.TP + slice_df.FP +
                                                slice_df.FN + 1)
    volume_df['Dice'] = (2 * volume_df.TP + 1) / (
        2 * volume_df.TP + volume_df.FP + volume_df.FN + 1)
    logger.info(f"Mean slice dice : {slice_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean volume dice : {volume_df.Dice.mean(axis=0):.3f}")

    # Save Scores and Config
    slice_df.to_csv(os.path.join(out_path, 'slice_predictions.csv'))
    logger.info(
        f"Slice prediction csv saved at {os.path.join(out_path, 'slice_predictions.csv')}"
    )
    volume_df.to_csv(os.path.join(out_path, 'volume_predictions.csv'))
    logger.info(
        f"Volume prediction csv saved at {os.path.join(out_path, 'volume_predictions.csv')}"
    )
    cfg.device = str(cfg.device)
    with open(os.path.join(out_path, 'config.json'), 'w') as f:
        json.dump(cfg, f)
    logger.info(
        f"Config file saved at {os.path.join(out_path, 'config.json')}")