def main(config_path):
    """
    Train and evaluate a 2D UNet on the public ICH dataset using the parameters specified in the JSON at the
    config_path. The evaluation is performed by k-fold cross-validation.
    """
    # load config file
    cfg = Config(settings=None)
    cfg.load_config(config_path)

    # Make Output directories
    out_path = os.path.join(cfg.settings['path']['OUTPUT'],
                            cfg.settings['exp_name'])  # + '/'
    os.makedirs(out_path, exist_ok=True)
    for k in range(cfg.settings['split']['n_fold']):
        os.makedirs(os.path.join(out_path, f'Fold_{k+1}/pred/'), exist_ok=True)

    # Initialize random seed to given seed
    seed = cfg.settings['seed']
    if seed != -1:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

    # Load data csv
    data_info_df = pd.read_csv(
        os.path.join(cfg.settings['path']['DATA'], 'ct_info.csv'))
    data_info_df = data_info_df.drop(data_info_df.columns[0], axis=1)
    patient_df = pd.read_csv(
        os.path.join(cfg.settings['path']['DATA'], 'patient_info.csv'))
    patient_df = patient_df.drop(patient_df.columns[0], axis=1)

    # Generate Cross-Val indices at the patient level
    skf = StratifiedKFold(n_splits=cfg.settings['split']['n_fold'],
                          shuffle=cfg.settings['split']['shuffle'],
                          random_state=seed)
    # iterate over folds and ensure that there is the same number of ICH-positive patients per fold --> Stratified CrossVal
    for k, (train_idx, test_idx) in enumerate(
            skf.split(patient_df.PatientNumber, patient_df.Hemorrhage)):
        # if fold results not already there
        if not os.path.exists(
                os.path.join(out_path, f'Fold_{k+1}/outputs.json')):
            # initialize logger
            logging.basicConfig(level=logging.INFO)
            logger = logging.getLogger()
            try:
                logger.handlers[1].stream.close()
                logger.removeHandler(logger.handlers[1])
            except IndexError:
                pass
            logger.setLevel(logging.INFO)
            file_handler = logging.FileHandler(
                os.path.join(out_path, f'Fold_{k+1}/log.txt'))
            file_handler.setLevel(logging.INFO)
            file_handler.setFormatter(
                logging.Formatter('%(asctime)s | %(levelname)s | %(message)s'))
            logger.addHandler(file_handler)

            if os.path.exists(
                    os.path.join(out_path, f'Fold_{k+1}/checkpoint.pt')):
                logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' +
                            '#' * 30)

            logger.info(f"Experiment : {cfg.settings['exp_name']}")
            logger.info(
                f"Cross-Validation fold {k+1:02}/{cfg.settings['split']['n_fold']:02}"
            )

            # initialize number of threads
            if cfg.settings['n_thread'] > 0:
                torch.set_num_threads(cfg.settings['n_thread'])
            logger.info(f"Number of threads : {cfg.settings['n_thread']}")
            # check if GPU available
            #cfg.settings['device'] = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
            if cfg.settings['device'] is not None:
                cfg.settings['device'] = torch.device(cfg.settings['device'])
            else:
                if torch.cuda.is_available():
                    free_mem, device_idx = 0.0, 0
                    for d in range(torch.cuda.device_count()):
                        mem = torch.cuda.get_device_properties(
                            d).total_memory - torch.cuda.memory_allocated(d)
                        if mem > free_mem:
                            device_idx = d
                            free_mem = mem
                    cfg.settings['device'] = torch.device(f'cuda:{device_idx}')
                else:
                    cfg.settings['device'] = torch.device('cpu')
            logger.info(f"Device : {cfg.settings['device']}")

            # extract train and test DataFrames + print summary (number of positive and negative samples)
            train_df = data_info_df[data_info_df.PatientNumber.isin(
                patient_df.loc[train_idx, 'PatientNumber'].values)]
            test_df = data_info_df[data_info_df.PatientNumber.isin(
                patient_df.loc[test_idx, 'PatientNumber'].values)]
            # subsample the normal slices so that there are at most frac_negative times as many as the ICH slices
            n_remove = int(
                max(
                    0,
                    len(train_df[train_df.Hemorrhage == 0]) -
                    cfg.settings['dataset']['frac_negative'] *
                    len(train_df[train_df.Hemorrhage == 1])))
            df_remove = train_df[train_df.Hemorrhage == 0].sample(
                n=n_remove, random_state=seed)
            train_df = train_df[~train_df.index.isin(df_remove.index)]
            logger.info(
                '\n' +
                str(get_split_summary_table(data_info_df, train_df, test_df)))

            # Make Dataset + print online augmentation summary
            train_dataset = public_SegICH_Dataset2D(
                train_df,
                cfg.settings['path']['DATA'],
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg.settings['data']['augmentation']['train'].items()
                ],
                window=(cfg.settings['data']['win_center'],
                        cfg.settings['data']['win_width']),
                output_size=cfg.settings['data']['size'])
            test_dataset = public_SegICH_Dataset2D(
                test_df,
                cfg.settings['path']['DATA'],
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg.settings['data']['augmentation']['eval'].items()
                ],
                window=(cfg.settings['data']['win_center'],
                        cfg.settings['data']['win_width']),
                output_size=cfg.settings['data']['size'])
            logger.info(
                f"Data will be loaded from {cfg.settings['path']['DATA']}.")
            logger.info(
                f"CT scans will be windowed on [{cfg.settings['data']['win_center']-cfg.settings['data']['win_width']/2} ; {cfg.settings['data']['win_center'] + cfg.settings['data']['win_width']/2}]"
            )
            logger.info(
                f"Training online data transformation: \n\n {str(train_dataset.transform)}\n"
            )
            logger.info(
                f"Evaluation online data transformation: \n\n {str(test_dataset.transform)}\n"
            )

            # Make architecture (and print summary)
            unet_arch = UNet(
                depth=cfg.settings['net']['depth'],
                top_filter=cfg.settings['net']['top_filter'],
                use_3D=cfg.settings['net']['3D'],
                in_channels=cfg.settings['net']['in_channels'],
                out_channels=cfg.settings['net']['out_channels'],
                bilinear=cfg.settings['net']['bilinear'],
                midchannels_factor=cfg.settings['net']['midchannels_factor'],
                p_dropout=cfg.settings['net']['p_dropout'])
            unet_arch.to(cfg.settings['device'])
            logger.info(
                f"U-Net2D initialized with a depth of {cfg.settings['net']['depth']}"
                f" and {cfg.settings['net']['top_filter']} initial filters.")
            logger.info(
                f"Reconstruction performed with {'Upsample + Conv' if cfg.settings['net']['bilinear'] else 'ConvTranspose'}."
            )
            logger.info(
                f"U-Net2D takes {cfg.settings['net']['in_channels']} input channels and returns {cfg.settings['net']['out_channels']} output channels."
            )
            logger.info(
                f"The U-Net2D has {sum(p.numel() for p in unet_arch.parameters())} parameters."
            )

            # Make model
            unet2D = UNet2D(
                unet_arch,
                n_epoch=cfg.settings['train']['n_epoch'],
                batch_size=cfg.settings['train']['batch_size'],
                lr=cfg.settings['train']['lr'],
                lr_scheduler=getattr(torch.optim.lr_scheduler,
                                     cfg.settings['train']['lr_scheduler']),
                lr_scheduler_kwargs=cfg.settings['train']
                ['lr_scheduler_kwargs'],
                loss_fn=getattr(src.models.optim.LossFunctions,
                                cfg.settings['train']['loss_fn']),
                loss_fn_kwargs=cfg.settings['train']['loss_fn_kwargs'],
                weight_decay=cfg.settings['train']['weight_decay'],
                num_workers=cfg.settings['train']['num_workers'],
                device=cfg.settings['device'],
                print_progress=cfg.settings['print_progress'])

            # Load model if required
            if cfg.settings['train']['model_path_to_load']:
                if isinstance(cfg.settings['train']['model_path_to_load'],
                              str):
                    model_path = cfg.settings['train']['model_path_to_load']
                    unet2D.load_model(model_path,
                                      map_location=cfg.settings['device'])
                elif isinstance(cfg.settings['train']['model_path_to_load'],
                                list):
                    model_path = cfg.settings['train']['model_path_to_load'][k]
                    unet2D.load_model(model_path,
                                      map_location=cfg.settings['device'])
                else:
                    raise ValueError(
                        f'Model path to load type not understood.')
                logger.info(f"2D U-Net model loaded from {model_path}")

            # print Training hyper-parameters
            train_params = []
            for key, value in cfg.settings['train'].items():
                train_params.append(f"--> {key} : {value}")
            logger.info('Training settings:\n\t' + '\n\t'.join(train_params))

            # Train model
            eval_dataset = test_dataset if cfg.settings['train'][
                'validate_epoch'] else None
            unet2D.train(train_dataset,
                         valid_dataset=eval_dataset,
                         checkpoint_path=os.path.join(
                             out_path, f'Fold_{k+1}/checkpoint.pt'))

            # Evaluate model
            unet2D.evaluate(test_dataset,
                            save_path=os.path.join(out_path,
                                                   f'Fold_{k+1}/pred/'))

            # Save models & outputs
            unet2D.save_model(
                os.path.join(out_path, f'Fold_{k+1}/trained_unet.pt'))
            logger.info("Trained U-Net saved at " +
                        os.path.join(out_path, f'Fold_{k+1}/trained_unet.pt'))
            unet2D.save_outputs(
                os.path.join(out_path, f'Fold_{k+1}/outputs.json'))
            logger.info("Trained statistics saved at " +
                        os.path.join(out_path, f'Fold_{k+1}/outputs.json'))

            # delete checkpoint if exists
            if os.path.exists(
                    os.path.join(out_path, f'Fold_{k+1}/checkpoint.pt')):
                os.remove(os.path.join(out_path, f'Fold_{k+1}/checkpoint.pt'))
                logger.info('Checkpoint deleted.')

    # save mean +/- 1.96 std Dice in .txt file
    scores_list = []
    for k in range(cfg.settings['split']['n_fold']):
        with open(os.path.join(out_path, f'Fold_{k+1}/outputs.json'),
                  'r') as f:
            out = json.load(f)
        scores_list.append(
            [out['eval']['dice']['all'], out['eval']['dice']['positive']])
    means = np.array(scores_list).mean(axis=0)
    CI95 = 1.96 * np.array(scores_list).std(axis=0)
    with open(os.path.join(out_path, 'average_scores.txt'), 'w') as f:
        f.write(f'Dice = {means[0]} +/- {CI95[0]}\n')
        f.write(f'Dice (Positive) = {means[1]} +/- {CI95[1]}\n')
    logger.info('Average Scores saved at ' +
                os.path.join(out_path, 'average_scores.txt'))

    # generate dataframe of all predictions
    df_list = [
        pd.read_csv(
            os.path.join(out_path,
                         f'Fold_{i+1}/pred/volume_prediction_scores.csv'))
        for i in range(cfg.settings['split']['n_fold'])
    ]
    all_df = pd.concat(df_list, axis=0).reset_index(drop=True)
    all_df.to_csv(os.path.join(out_path, 'all_volume_prediction.csv'))
    logger.info('CSV of all volumes prediction saved at ' +
                os.path.join(out_path, 'all_volume_prediction.csv'))

    # Save config file
    cfg.settings['device'] = str(cfg.settings['device'])
    cfg.save_config(os.path.join(out_path, 'config.json'))
    logger.info("Config file saved at " +
                os.path.join(out_path, 'config.json'))

    # Analyse results
    analyse_supervised_exp(out_path,
                           cfg.settings['path']['DATA'],
                           cfg.settings['split']['n_fold'],
                           save_fn=os.path.join(out_path,
                                                'results_overview.pdf'))
    logger.info('Results overview figure saved at ' +
                os.path.join(out_path, 'results_overview.pdf'))
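

# The later variants in this file call initialize_logger(...) instead of repeating the
# handler setup used above. A minimal sketch of such a helper, assuming it simply wraps
# the same logic (swap the root logger's file handler for a fresh one pointing at the
# fold's log file); the actual helper in the repo may differ.
import logging

def initialize_logger(logfile_path):
    """Return the root logger with a fresh FileHandler writing to logfile_path."""
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    # drop the previous file handler if one is still attached
    try:
        logger.handlers[1].stream.close()
        logger.removeHandler(logger.handlers[1])
    except IndexError:
        pass
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(logfile_path)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s | %(levelname)s | %(message)s'))
    logger.addHandler(file_handler)
    return logger
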
def main(config_path):
    """
    UNet2D pretrained on binary classification with the RSNA dataset and fine-tuned on the public data.
    """
    # load the config file
    cfg = AttrDict.from_json_path(config_path)

    # Make Outputs directories
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    out_path_selfsup = os.path.join(out_path, 'classification_pretrain/')
    out_path_sup = os.path.join(out_path, 'supervised_train/')
    os.makedirs(out_path_selfsup, exist_ok=True)
    for k in range(cfg.Sup.split.n_fold):
        os.makedirs(os.path.join(out_path_sup, f'Fold_{k+1}/pred/'),
                    exist_ok=True)

    # Initialize random seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # Set number of threads
    if cfg.n_thread > 0: torch.set_num_threads(cfg.n_thread)
    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = get_available_device()

    ####################################################
    # Self-supervised training on Multi Classification #
    ####################################################
    # Initialize Logger
    logger = initialize_logger(os.path.join(out_path_selfsup, 'log.txt'))
    if os.path.exists(os.path.join(out_path_selfsup, f'checkpoint.pt')):
        logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' + '#' * 30)
    logger.info(f"Experiment : {cfg.exp_name}")

    # Load RSNA data csv
    df_rsna = pd.read_csv(os.path.join(cfg.path.data.SSL, 'slice_info.csv'),
                          index_col=0)

    # Keep only a fraction of the samples
    if cfg.SSL.dataset.n_data_1 >= 0:
        df_rsna_ICH = df_rsna[df_rsna.Hemorrhage == 1].sample(
            n=cfg.SSL.dataset.n_data_1, random_state=cfg.seed)
    else:
        df_rsna_ICH = df_rsna[df_rsna.Hemorrhage == 1]
    if cfg.SSL.dataset.f_data_0 >= 0:  # number of normal samples as a fraction of the number of ICH samples
        df_rsna_noICH = df_rsna[df_rsna.Hemorrhage == 0].sample(
            n=int(cfg.SSL.dataset.f_data_0 * len(df_rsna_ICH)),
            random_state=cfg.seed)
    else:
        df_rsna_noICH = df_rsna[df_rsna.Hemorrhage == 0]
    df_rsna = pd.concat([df_rsna_ICH, df_rsna_noICH], axis=0)

    # Split the data, keeping a small stratified subset for evaluation
    train_df, test_df = train_test_split(df_rsna,
                                         test_size=cfg.SSL.dataset.frac_eval,
                                         stratify=df_rsna.Hemorrhage,
                                         random_state=cfg.seed)
    logger.info('\n' +
                str(get_split_summary_table(df_rsna, train_df, test_df)))

    # Make datasets : Train --> multi_classification, Test --> multi_classification
    train_RSNA_dataset = RSNA_dataset(
        train_df,
        cfg.path.data.SSL,
        augmentation_transform=[
            getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
            cfg.SSL.dataset.augmentation.train.items()
        ],
        window=(cfg.data.win_center, cfg.data.win_width),
        output_size=cfg.data.size,
        mode='multi_classification')
    test_RSNA_dataset = RSNA_dataset(
        test_df,
        cfg.path.data.SSL,
        augmentation_transform=[
            getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
            cfg.SSL.dataset.augmentation.eval.items()
        ],
        window=(cfg.data.win_center, cfg.data.win_width),
        output_size=cfg.data.size,
        mode='multi_classification')

    logger.info(f"Data will be loaded from {cfg.path.data.SSL}.")
    logger.info(
        f"CT scans will be windowed on [{cfg.data.win_center-cfg.data.win_width/2} ; {cfg.data.win_center + cfg.data.win_width/2}]"
    )
    logger.info(f"CT scans will be resized to {cfg.data.size}x{cfg.data.size}")
    logger.info(
        f"Training online data transformation: \n\n {str(train_RSNA_dataset.transform)}\n"
    )
    logger.info(
        f"Evaluation online data transformation: \n\n {str(test_RSNA_dataset.transform)}\n"
    )

    # Make U-Net-Encoder architecture
    net_ssl = UNet_Encoder(**cfg.SSL.net).to(cfg.device)
    net_params = [f"--> {k} : {v}" for k, v in cfg.SSL.net.items()]
    logger.info("UNet like Multi Classifier \n\t" + "\n\t".join(net_params))
    logger.info(
        f"The Multi Classifier has {sum(p.numel() for p in net_ssl.parameters())} parameters."
    )

    # Make Model
    cfg.SSL.train.model_param.lr_scheduler = getattr(
        torch.optim.lr_scheduler, cfg.SSL.train.model_param.lr_scheduler
    )  # convert scheduler name to scheduler class object
    if cfg.SSL.train.model_param.loss_fn == 'BCEWithLogitsLoss':
        df_rsna['no_Hemorrhage'] = 1 - df_rsna.Hemorrhage
        class_weight_list = (
            (len(df_rsna) - df_rsna[train_RSNA_dataset.class_name].sum()) /
            df_rsna[train_RSNA_dataset.class_name].sum()
        ).values  # define CE weighting from train dataset
        cfg.SSL.train.model_param.loss_fn_kwargs['pos_weight'] = torch.tensor(
            class_weight_list, device=cfg.device)  # add weighting to CE kwargs
    try:
        cfg.SSL.train.model_param.loss_fn = getattr(
            torch.nn, cfg.SSL.train.model_param.loss_fn
        )  # convert loss_fn name to nn.Module class object
    except AttributeError:
        cfg.SSL.train.model_param.loss_fn = getattr(
            src.models.optim.LossFunctions, cfg.SSL.train.model_param.loss_fn)

    classifier = MultiClassifier(net_ssl,
                                 device=cfg.device,
                                 print_progress=cfg.print_progress,
                                 **cfg.SSL.train.model_param)

    train_params = [
        f"--> {k} : {v}" for k, v in cfg.SSL.train.model_param.items()
    ]
    logger.info("Classifer Training Parameters \n\t" +
                "\n\t".join(train_params))

    # Load weights if specified
    if cfg.SSL.train.model_path_to_load:
        model_path = cfg.SSL.train.model_path_to_load
        classifier.load_model(model_path, map_location=cfg.device)
        logger.info(
            f"Classifier model successfully loaded from {cfg.SSL.train.model_path_to_load}"
        )

    # train if needed
    if cfg.SSL.train.model_param.n_epoch > 0:
        classifier.train(train_RSNA_dataset,
                         valid_dataset=test_RSNA_dataset,
                         checkpoint_path=os.path.join(out_path_selfsup,
                                                      f'checkpoint.pt'))

    # evaluate
    auc, acc, sub_acc, recall, precision, f1 = classifier.evaluate(
        test_RSNA_dataset, save_tsne=True, return_scores=True)
    logger.info(f"Classifier Test AUC : {auc:.2%}")
    logger.info(f"Classifier Test Accuracy : {acc:.2%}")
    logger.info(f"Classifier Test Subset Accuracy : {sub_acc:.2%}")
    logger.info(f"Classifier Test Recall : {recall:.2%}")
    logger.info(f"Classifier Test Precision : {precision:.2%}")
    logger.info(f"Classifier Test F1-score : {f1:.2%}")

    # save model, outputs
    classifier.save_model_state_dict(
        os.path.join(out_path_selfsup, 'pretrained_unet_enc.pt'))
    logger.info(
        "Pre-trained U-Net encoder on binary classification saved at " +
        os.path.join(out_path_selfsup, 'pretrained_unet_enc.pt'))
    classifier.save_outputs(os.path.join(out_path_selfsup, 'outputs.json'))
    logger.info("Classifier outputs saved at " +
                os.path.join(out_path_selfsup, 'outputs.json'))
    test_df.reset_index(drop=True).to_csv(
        os.path.join(out_path_selfsup, 'eval_data_info.csv'))
    logger.info("Evaluation data info saved at " +
                os.path.join(out_path_selfsup, 'eval_data_info.csv'))

    # delete any checkpoints
    if os.path.exists(os.path.join(out_path_selfsup, f'checkpoint.pt')):
        os.remove(os.path.join(out_path_selfsup, f'checkpoint.pt'))
        logger.info('Checkpoint deleted.')

    # get weights state dictionary
    pretrained_unet_weights = classifier.get_state_dict()

    ###################################################################
    # Supervised fine-training of U-Net  with K-Fold Cross-Validation #
    ###################################################################
    # load annotated data csv
    data_info_df = pd.read_csv(os.path.join(cfg.path.data.Sup, 'ct_info.csv'),
                               index_col=0)
    patient_df = pd.read_csv(os.path.join(cfg.path.data.Sup,
                                          'patient_info.csv'),
                             index_col=0)

    # Make K-Fold split at the patient level
    skf = StratifiedKFold(n_splits=cfg.Sup.split.n_fold,
                          shuffle=cfg.Sup.split.shuffle,
                          random_state=cfg.seed)

    # define scheduler and loss_fn as object
    cfg.Sup.train.model_param.lr_scheduler = getattr(
        torch.optim.lr_scheduler, cfg.Sup.train.model_param.lr_scheduler
    )  # convert scheduler name to scheduler class object
    cfg.Sup.train.model_param.loss_fn = getattr(
        src.models.optim.LossFunctions, cfg.Sup.train.model_param.loss_fn
    )  # convert loss_fn name to nn.Module class object

    # iterate over folds
    for k, (train_idx, test_idx) in enumerate(
            skf.split(patient_df.PatientNumber, patient_df.Hemorrhage)):
        # check if the fold's results already exist
        if not os.path.exists(
                os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json')):
            # initialize logger
            logger = initialize_logger(
                os.path.join(out_path_sup, f'Fold_{k+1}/log.txt'))
            if os.path.exists(
                    os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt')):
                logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' +
                            '#' * 30)
            logger.info(f"Experiment : {cfg['exp_name']}")
            logger.info(
                f"Cross-Validation fold {k+1:02}/{cfg['Sup']['split']['n_fold']:02}"
            )

            # extract train/test slice dataframe
            train_df = data_info_df[data_info_df.PatientNumber.isin(
                patient_df.loc[train_idx, 'PatientNumber'].values)]
            test_df = data_info_df[data_info_df.PatientNumber.isin(
                patient_df.loc[test_idx, 'PatientNumber'].values)]
            # subsample the train dataframe to adjust the negative/positive fraction
            n_remove = int(
                max(
                    0,
                    len(train_df[train_df.Hemorrhage == 0]) -
                    cfg.Sup.dataset.frac_negative *
                    len(train_df[train_df.Hemorrhage == 1])))
            df_remove = train_df[train_df.Hemorrhage == 0].sample(
                n=n_remove, random_state=cfg.seed)
            train_df = train_df[~train_df.index.isin(df_remove.index)]
            logger.info(
                '\n' +
                str(get_split_summary_table(data_info_df, train_df, test_df)))

            # Make datasets
            train_dataset = public_SegICH_Dataset2D(
                train_df,
                cfg.path.data.Sup,
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg.Sup.dataset.augmentation.train.items()
                ],
                window=(cfg.data.win_center, cfg.data.win_width),
                output_size=cfg.data.size)
            test_dataset = public_SegICH_Dataset2D(
                test_df,
                cfg.path.data.Sup,
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg.Sup.dataset.augmentation.eval.items()
                ],
                window=(cfg.data.win_center, cfg.data.win_width),
                output_size=cfg.data.size)
            logger.info(f"Data will be loaded from {cfg.path.data.Sup}.")
            logger.info(
                f"CT scans will be windowed on [{cfg.data.win_center-cfg.data.win_width/2} ; {cfg.data.win_center + cfg.data.win_width/2}]"
            )
            logger.info(
                f"CT scans will be resized to {cfg.data.size}x{cfg.data.size}")
            logger.info(
                f"Training online data transformation: \n\n {str(train_dataset.transform)}\n"
            )
            logger.info(
                f"Evaluation online data transformation: \n\n {str(test_dataset.transform)}\n"
            )

            # Make U-Net architecture
            unet_sup = UNet(**cfg.Sup.net).to(cfg.device)
            net_params = [f"--> {k} : {v}" for k, v in cfg.Sup.net.items()]
            logger.info("UNet-2D params \n\t" + "\n\t".join(net_params))
            logger.info(
                f"The U-Net2D has {sum(p.numel() for p in unet_sup.parameters())} parameters."
            )

            # Make Model
            unet2D = UNet2D(unet_sup,
                            device=cfg.device,
                            print_progress=cfg.print_progress,
                            **cfg.Sup.train.model_param)

            train_params = [
                f"--> {k} : {v}" for k, v in cfg.Sup.train.model_param.items()
            ]
            logger.info("UNet-2D Training Parameters \n\t" +
                        "\n\t".join(train_params))

            # TODO: load a model here if specified

            # transfer weights learned during classification pretraining
            logger.info(
                'Initialize U-Net2D with weights learned by classification pretraining on RSNA.'
            )
            unet2D.transfer_weights(pretrained_unet_weights, verbose=True)

            # Train U-net
            eval_dataset = test_dataset if cfg.Sup.train.validate_epoch else None
            unet2D.train(train_dataset,
                         valid_dataset=eval_dataset,
                         checkpoint_path=os.path.join(
                             out_path_sup, f'Fold_{k+1}/checkpoint.pt'))

            # Evaluate U-Net
            unet2D.evaluate(test_dataset,
                            save_path=os.path.join(out_path_sup,
                                                   f'Fold_{k+1}/pred/'))

            # Save models and outputs
            unet2D.save_model(
                os.path.join(out_path_sup, f'Fold_{k+1}/trained_unet.pt'))
            logger.info(
                "Trained U-Net saved at " +
                os.path.join(out_path_sup, f'Fold_{k+1}/trained_unet.pt'))
            unet2D.save_outputs(
                os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json'))
            logger.info("Trained statistics saved at " +
                        os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json'))

            # delete checkpoint if exists
            if os.path.exists(
                    os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt')):
                os.remove(
                    os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt'))
                logger.info('Checkpoint deleted.')

    # save mean +/- 1.96 std Dice over Folds
    save_mean_fold_dice(out_path_sup, cfg.Sup.split.n_fold)
    logger.info('Average Scores saved at ' +
                os.path.join(out_path_sup, 'average_scores.txt'))

    # Save all volumes prediction csv
    df_list = [
        pd.read_csv(
            os.path.join(out_path_sup,
                         f'Fold_{i+1}/pred/volume_prediction_scores.csv'))
        for i in range(cfg.Sup.split.n_fold)
    ]
    all_df = pd.concat(df_list, axis=0).reset_index(drop=True)
    all_df.to_csv(os.path.join(out_path_sup, 'all_volume_prediction.csv'))
    logger.info('CSV of all volumes prediction saved at ' +
                os.path.join(out_path_sup, 'all_volume_prediction.csv'))

    # Save config file
    cfg.device = str(cfg.device)
    cfg.SSL.train.model_param.lr_scheduler = str(
        cfg.SSL.train.model_param.lr_scheduler)
    cfg.Sup.train.model_param.lr_scheduler = str(
        cfg.Sup.train.model_param.lr_scheduler)
    cfg.SSL.train.model_param.loss_fn = str(cfg.SSL.train.model_param.loss_fn)
    cfg.Sup.train.model_param.loss_fn = str(cfg.Sup.train.model_param.loss_fn)
    if 'pos_weight' in cfg.SSL.train.model_param.loss_fn_kwargs:
        cfg.SSL.train.model_param.loss_fn_kwargs.pos_weight = \
            cfg.SSL.train.model_param.loss_fn_kwargs.pos_weight.cpu().data.tolist()
    with open(os.path.join(out_path, 'config.json'), 'w') as fp:
        json.dump(cfg, fp)
    logger.info('Config file saved at ' +
                os.path.join(out_path, 'config.json'))

    # Analyse results
    analyse_supervised_exp(out_path_sup,
                           cfg.path.data.Sup,
                           n_fold=cfg.Sup.split.n_fold,
                           config_folder=out_path,
                           save_fn=os.path.join(
                               out_path, 'results_supervised_overview.pdf'))
    logger.info('Results overview figure saved at ' +
                os.path.join(out_path, 'results_supervised_overview.pdf'))
    analyse_representation_exp(out_path_selfsup,
                               save_fn=os.path.join(
                                   out_path,
                                   'results_self-supervised_overview.pdf'))
    logger.info('Results overview figure saved at ' +
                os.path.join(out_path, 'results_self-supervised_overview.pdf'))
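

# get_available_device() is called above but not defined in this snippet. A minimal
# sketch, assuming it reproduces the explicit device-selection loop of the first
# example (pick the CUDA device with the most free memory, else fall back to CPU);
# the actual helper in the repo may differ.
import torch

def get_available_device():
    """Return the CUDA device with the most free memory, or the CPU if CUDA is unavailable."""
    if not torch.cuda.is_available():
        return torch.device('cpu')
    free_mem, device_idx = 0.0, 0
    for d in range(torch.cuda.device_count()):
        # free memory estimated as total memory minus currently allocated memory
        mem = torch.cuda.get_device_properties(d).total_memory - torch.cuda.memory_allocated(d)
        if mem > free_mem:
            device_idx, free_mem = d, mem
    return torch.device(f'cuda:{device_idx}')
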
def main(config_path):
    """
    Train and evaluate a 2D UNet on the public ICH dataset with the anomaly attention map, using the parameters
    specified in the JSON at the config_path. The evaluation is performed by k-fold cross-validation.
    """
    # load config file
    cfg = AttrDict.from_json_path(config_path)

    # Make Output directories
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)
    for k in range(cfg.split.n_fold):
        os.makedirs(os.path.join(out_path, f'Fold_{k+1}/pred/'), exist_ok=True)

    # Initialize random seed to given seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # Load data csv
    data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'info.csv'),
                               index_col=0)
    patient_df = pd.read_csv(os.path.join(cfg.path.data, 'patient_info.csv'),
                             index_col=0)

    # Generate Cross-Val indices at the patient level
    skf = StratifiedKFold(n_splits=cfg.split.n_fold,
                          shuffle=cfg.split.shuffle,
                          random_state=cfg.seed)
    # iterate over folds and ensure that there is the same number of ICH-positive patients per fold --> Stratified CrossVal
    for k, (train_idx, test_idx) in enumerate(
            skf.split(patient_df.PatientNumber, patient_df.Hemorrhage)):
        # if fold results not already there
        if not os.path.exists(
                os.path.join(out_path, f'Fold_{k+1}/outputs.json')):
            # initialize logger
            logger = initialize_logger(os.path.join(out_path, 'log.txt'))
            if os.path.exists(
                    os.path.join(out_path, f'Fold_{k+1}/checkpoint.pt')):
                logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' +
                            '#' * 30)
            logger.info(f"Experiment : {cfg.exp_name}")
            logger.info(
                f"Cross-Validation fold {k+1:02}/{cfg.split.n_fold:02}")

            # check if GPU available
            if cfg.device is not None:
                cfg.device = torch.device(cfg.device)
            else:
                cfg.device = torch.device('cuda') if torch.cuda.is_available(
                ) else torch.device('cpu')
            logger.info(f"Device : {cfg.device}")

            # extract train and test DataFrames + print summary (n samples positive and negatives)
            train_df = data_info_df[data_info_df.id.isin(
                patient_df.loc[train_idx, 'PatientNumber'].values)]
            test_df = data_info_df[data_info_df.id.isin(
                patient_df.loc[test_idx, 'PatientNumber'].values)]
            # subsample the normal slices so that there are at most frac_negative times as many as the ICH slices
            n_remove = int(
                max(
                    0,
                    len(train_df[train_df.Hemorrhage == 0]) -
                    cfg.dataset.frac_negative *
                    len(train_df[train_df.Hemorrhage == 1])))
            df_remove = train_df[train_df.Hemorrhage == 0].sample(
                n=n_remove, random_state=cfg.seed)
            train_df = train_df[~train_df.index.isin(df_remove.index)]
            logger.info(
                '\n' +
                str(get_split_summary_table(data_info_df, train_df, test_df)))

            # Make Dataset + print online augmentation summary
            train_dataset = public_SegICH_AttentionDataset2D(
                train_df,
                cfg.path.data,
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg.data.augmentation.train.items()
                ],
                window=(cfg.data.win_center, cfg.data.win_width),
                output_size=cfg.data.size)
            test_dataset = public_SegICH_AttentionDataset2D(
                test_df,
                cfg.path.data,
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg.data.augmentation.eval.items()
                ],
                window=(cfg.data.win_center, cfg.data.win_width),
                output_size=cfg.data.size)
            logger.info(f"Data will be loaded from {cfg.path.data}.")
            logger.info(
                f"CT scans will be windowed on [{cfg.data.win_center-cfg.data.win_width/2} ; {cfg.data.win_center + cfg.data.win_width/2}]"
            )
            logger.info(
                f"Training online data transformation: \n\n {str(train_dataset.transform)}\n"
            )
            logger.info(
                f"Evaluation online data transformation: \n\n {str(test_dataset.transform)}\n"
            )

            # Make architecture (and print summary)
            unet_arch = UNet(**cfg.net)
            unet_arch.to(cfg.device)
            net_params = [f"--> {k} : {v}" for k, v in cfg.net.items()]
            logger.info("U-Net2D architecture \n\t" + "\n\t".join(net_params))
            logger.info(
                f"The U-Net2D has {sum(p.numel() for p in unet_arch.parameters())} parameters."
            )

            # Make model
            cfg_train = AttrDict(cfg.train.params)
            cfg_train.lr_scheduler = getattr(torch.optim.lr_scheduler,
                                             cfg_train.lr_scheduler)
            cfg_train.loss_fn = getattr(src.models.optim.LossFunctions,
                                        cfg_train.loss_fn)
            unet2D = UNet2D(unet_arch,
                            device=cfg.device,
                            print_progress=cfg.print_progress,
                            **cfg_train)
            # print Training hyper-parameters
            train_params = [f"--> {k} : {v}" for k, v in cfg_train.items()]
            logger.info("U-Net2D Training Parameters \n\t" +
                        "\n\t".join(train_params))

            # Load model if required
            if cfg.train.model_path_to_load:
                if isinstance(cfg.train.model_path_to_load, str):
                    model_path = cfg.train.model_path_to_load
                    unet2D.load_model(model_path, map_location=cfg.device)
                elif isinstance(cfg.train.model_path_to_load, list):
                    model_path = cfg.train.model_path_to_load[k]
                    unet2D.load_model(model_path, map_location=cfg.device)
                else:
                    raise ValueError(
                        f'Model path to load type not understood.')
                logger.info(f"2D U-Net model loaded from {model_path}")

            # Train model
            eval_dataset = test_dataset if cfg.train.validate_epoch else None
            unet2D.train(train_dataset,
                         valid_dataset=eval_dataset,
                         checkpoint_path=os.path.join(
                             out_path, f'Fold_{k+1}/checkpoint.pt'))

            # Evaluate model
            unet2D.evaluate(test_dataset,
                            save_path=os.path.join(out_path,
                                                   f'Fold_{k+1}/pred/'))

            # Save models & outputs
            unet2D.save_model(
                os.path.join(out_path, f'Fold_{k+1}/trained_unet.pt'))
            logger.info("Trained U-Net saved at " +
                        os.path.join(out_path, f'Fold_{k+1}/trained_unet.pt'))
            unet2D.save_outputs(
                os.path.join(out_path, f'Fold_{k+1}/outputs.json'))
            logger.info("Trained statistics saved at " +
                        os.path.join(out_path, f'Fold_{k+1}/outputs.json'))

            # delete checkpoint if exists
            if os.path.exists(
                    os.path.join(out_path, f'Fold_{k+1}/checkpoint.pt')):
                os.remove(os.path.join(out_path, f'Fold_{k+1}/checkpoint.pt'))
                logger.info('Checkpoint deleted.')

    # save mean +/- 1.96 std Dice in .txt file
    scores_list = []
    for k in range(cfg.split.n_fold):
        with open(os.path.join(out_path, f'Fold_{k+1}/outputs.json'),
                  'r') as f:
            out = json.load(f)
        scores_list.append(
            [out['eval']['dice']['all'], out['eval']['dice']['positive']])
    means = np.array(scores_list).mean(axis=0)
    CI95 = 1.96 * np.array(scores_list).std(axis=0)
    with open(os.path.join(out_path, 'average_scores.txt'), 'w') as f:
        f.write(f'Dice = {means[0]} +/- {CI95[0]}\n')
        f.write(f'Dice (Positive) = {means[1]} +/- {CI95[1]}\n')
    logger.info('Average Scores saved at ' +
                os.path.join(out_path, 'average_scores.txt'))

    # generate dataframe of all predictions
    df_list = [
        pd.read_csv(
            os.path.join(out_path,
                         f'Fold_{i+1}/pred/volume_prediction_scores.csv'))
        for i in range(cfg.split.n_fold)
    ]
    all_df = pd.concat(df_list, axis=0).reset_index(drop=True)
    all_df.to_csv(os.path.join(out_path, 'all_volume_prediction.csv'))
    logger.info('CSV of all volumes prediction saved at ' +
                os.path.join(out_path, 'all_volume_prediction.csv'))

    # Save config file
    cfg.device = str(cfg.device)
    #cfg.train.params.lr_scheduler = str(cfg.train.params.lr_scheduler)
    #cfg.train.params.loss_fn = str(cfg.train.params.loss_fn)
    with open(os.path.join(out_path, 'config.json'), 'w') as fp:
        json.dump(cfg, fp)
    logger.info("Config file saved at " +
                os.path.join(out_path, 'config.json'))

    # Analyse results
    analyse_supervised_exp(out_path,
                           cfg.path.data,
                           cfg.split.n_fold,
                           save_fn=os.path.join(out_path,
                                                'results_overview.pdf'))
    logger.info('Results overview figure saved at ' +
                os.path.join(out_path, 'results_overview.pdf'))
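

# The two variants above load their settings through AttrDict.from_json_path. A minimal
# sketch of such a class, assuming AttrDict is just a dict whose keys are also reachable
# as attributes (nested dicts converted recursively), which is why json.dump(cfg, fp)
# works on it later; the actual implementation in the repo may differ.
import json

class AttrDict(dict):
    """Dictionary with attribute-style access; nested dicts are converted as well."""
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__

    @classmethod
    def from_nested_dicts(cls, data):
        """Recursively convert a plain dict (e.g. parsed JSON) into AttrDicts."""
        if not isinstance(data, dict):
            return data
        return cls({key: cls.from_nested_dicts(value) for key, value in data.items()})

    @classmethod
    def from_json_path(cls, path):
        """Load a JSON config file and return it as a nested AttrDict."""
        with open(path, 'r') as fp:
            return cls.from_nested_dicts(json.load(fp))
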
def main(config_path):
    """
    UNet2D pretrained on context restoration with the RSNA dataset and fine-tuned on the public data.
    """
    # load the config file
    with open(config_path, 'r') as fp:
        cfg = json.load(fp)

    # Make Outputs directories
    out_path = os.path.join(
        cfg['path']['output'],
        cfg['exp_name'])  # + datetime.now().strftime('_%Y-%m-%d'))
    out_path_selfsup = os.path.join(out_path, 'context_restoration_pretrain/')
    out_path_sup = os.path.join(out_path, 'supervised_train/')
    os.makedirs(out_path_selfsup, exist_ok=True)
    for k in range(cfg['Sup']['split']['n_fold']):
        os.makedirs(os.path.join(out_path_sup, f'Fold_{k+1}/pred/'),
                    exist_ok=True)

    # Initialize random seed
    seed = cfg['seed']
    if seed != -1:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

    # Set number of threads
    if cfg['n_thread'] > 0: torch.set_num_threads(cfg['n_thread'])
    # check if GPU available
    #cfg['device'] = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    if cfg['device'] is not None:
        cfg['device'] = torch.device(cfg['device'])
    else:
        if torch.cuda.is_available():
            free_mem, device_idx = 0.0, 0
            for d in range(torch.cuda.device_count()):
                mem = torch.cuda.get_device_properties(
                    d).total_memory - torch.cuda.memory_allocated(d)
                if mem > free_mem:
                    device_idx = d
                    free_mem = mem
            cfg['device'] = torch.device(f'cuda:{device_idx}')
        else:
            cfg['device'] = torch.device('cpu')

    ###################################################
    # Self-supervised training on Context Restoration #
    ###################################################
    # Initialize Logger
    logger = initialize_logger(os.path.join(out_path_selfsup, 'log.txt'))
    if os.path.exists(os.path.join(out_path_selfsup, f'checkpoint.pt')):
        logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' + '#' * 30)
    logger.info(f"Experiment : {cfg['exp_name']}")

    # Load RSNA data csv
    df_rsna = pd.read_csv(os.path.join(cfg['path']['data']['SSL'],
                                       'slice_info.csv'),
                          index_col=0)

    # Keep only a fraction of the negative/positive samples
    df_rsna_pos = df_rsna[df_rsna.Hemorrhage == 1]
    if cfg['SSL']['dataset']['n_positive'] >= 0:
        df_rsna_pos = df_rsna_pos.sample(n=cfg['SSL']['dataset']['n_positive'],
                                         random_state=seed)
    n_neg = int(cfg['SSL']['dataset']['frac_negative'] * len(df_rsna_pos))
    df_rsna_neg = df_rsna[df_rsna.Hemorrhage == 0].sample(n=n_neg,
                                                          random_state=seed)
    df_rsna_samp = pd.concat([df_rsna_pos, df_rsna_neg], axis=0)

    # Split the data, keeping a few samples for evaluation
    test_df = df_rsna_samp.sample(n=cfg['SSL']['dataset']['n_eval'],
                                  random_state=seed)
    train_df = df_rsna_samp.drop(test_df.index)
    df_rsna_neg = df_rsna[(df_rsna.Hemorrhage == 0)]
    test_df = pd.concat([
        test_df,
        df_rsna_neg[~df_rsna_neg.index.isin(train_df.index)].sample(
            n=cfg['SSL']['dataset']['n_eval_neg'], random_state=seed)
    ])
    test_df, train_df = test_df.reset_index(), train_df.reset_index()
    logger.info('\n' +
                str(get_split_summary_table(df_rsna, train_df, test_df)))

    # Make dataset : Train --> ContextRestoration, Test --> Standard
    train_RSNA_dataset = RSNA_dataset(
        train_df,
        cfg['path']['data']['SSL'],
        augmentation_transform=[
            getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
            cfg['SSL']['dataset']['augmentation']['train'].items()
        ],
        window=(cfg['data']['win_center'], cfg['data']['win_width']),
        output_size=cfg['data']['size'],
        mode='context_restoration',
        n_swap=cfg['SSL']['dataset']['n_swap'],
        swap_w=cfg['SSL']['dataset']['swap_w'],
        swap_h=cfg['SSL']['dataset']['swap_h'],
        swap_rot=cfg['SSL']['dataset']['swap_rotate'])
    test_RSNA_dataset = RSNA_dataset(
        test_df,
        cfg['path']['data']['SSL'],
        augmentation_transform=[
            getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
            cfg['SSL']['dataset']['augmentation']['eval'].items()
        ],
        window=(cfg['data']['win_center'], cfg['data']['win_width']),
        output_size=cfg['data']['size'],
        mode='standard')
    logger.info(f"Data will be loaded from {cfg['path']['data']['SSL']}.")
    logger.info(
        f"CT scans will be windowed on [{cfg['data']['win_center']-cfg['data']['win_width']/2} ; {cfg['data']['win_center'] + cfg['data']['win_width']/2}]"
    )
    logger.info(
        f"CT scans will be resized to {cfg['data']['size']}x{cfg['data']['size']}"
    )
    logger.info(
        f"Training online data transformation: \n\n {str(train_RSNA_dataset.transform)}\n"
    )
    logger.info(
        f"Evaluation online data transformation: \n\n {str(test_RSNA_dataset.transform)}\n"
    )
    logger.info(
        f"Input images will be corrupted by {cfg['SSL']['dataset']['n_swap']} swap of dimension {cfg['SSL']['dataset']['swap_h']}x{cfg['SSL']['dataset']['swap_w']} (h x w)"
    )

    # Make U-Net architecture
    net_ssl = UNet(depth=cfg['SSL']['net']['depth'],
                   top_filter=cfg['SSL']['net']['top_filter'],
                   use_3D=cfg['SSL']['net']['3D'],
                   in_channels=cfg['SSL']['net']['in_channels'],
                   out_channels=cfg['SSL']['net']['out_channels'],
                   bilinear=cfg['SSL']['net']['bilinear'],
                   use_final_activation=cfg['SSL']['net']['final_activation'],
                   midchannels_factor=cfg['SSL']['net']['midchannels_factor'],
                   p_dropout=cfg['SSL']['net']['p_dropout'])
    net_ssl.to(cfg['device'])
    logger.info(
        f"U-Net2D initialized with a depth of {cfg['SSL']['net']['depth']}"
        f" and {cfg['SSL']['net']['top_filter']} initial filters.")
    logger.info(
        f"Reconstruction performed with {'Upsample + Conv' if cfg['SSL']['net']['bilinear'] else 'ConvTranspose'}."
    )
    logger.info(
        f"U-Net2D takes {cfg['SSL']['net']['in_channels']} input channels and returns {cfg['SSL']['net']['out_channels']} output channels."
    )
    logger.info(
        f"The U-Net2D has {sum(p.numel() for p in net_ssl.parameters())} parameters."
    )

    # Make Model
    ctx_restor = ContextRestoration(
        net_ssl,
        n_epoch=cfg['SSL']['train']['n_epoch'],
        batch_size=cfg['SSL']['train']['batch_size'],
        lr=cfg['SSL']['train']['lr'],
        lr_scheduler=getattr(torch.optim.lr_scheduler,
                             cfg['SSL']['train']['lr_scheduler']),
        lr_scheduler_kwargs=cfg['SSL']['train']['lr_scheduler_kwargs'],
        loss_fn=getattr(torch.nn, cfg['SSL']['train']['loss_fn']),
        loss_fn_kwargs=cfg['SSL']['train']['loss_fn_kwargs'],
        weight_decay=cfg['SSL']['train']['weight_decay'],
        num_workers=cfg['SSL']['train']['num_workers'],
        device=cfg['device'],
        print_progress=cfg['print_progress'])

    # Load weights if specified
    if cfg['SSL']['train']['model_path_to_load']:
        model_path = cfg['SSL']['train']['model_path_to_load']
        ctx_restor.load_model(model_path, map_location=cfg['device'])

    # train if needed
    if cfg['SSL']['train']['n_epoch'] > 0:
        train_params = []
        for key, value in cfg['SSL']['train'].items():
            train_params.append(f"--> {key} : {value}")
        logger.info('Training settings:\n\t' + '\n\t'.join(train_params))

        ctx_restor.train(train_RSNA_dataset,
                         checkpoint_path=os.path.join(out_path_selfsup,
                                                      f'checkpoint.pt'))

    # evaluate
    ctx_restor.evaluate(test_RSNA_dataset)

    # save model, outputs and evaluation data info (test_df)
    ctx_restor.save_model(os.path.join(out_path_selfsup, 'pretrained_unet.pt'))
    logger.info("Pre-trained U-Net on context restoration saved at " +
                os.path.join(out_path_selfsup, 'pretrained_unet.pt'))
    ctx_restor.save_outputs(os.path.join(out_path_selfsup, 'outputs.json'))
    logger.info("Context restoration outputs saved at " +
                os.path.join(out_path_selfsup, 'outputs.json'))
    test_df.to_csv(os.path.join(out_path_selfsup, 'eval_data_info.csv'))
    logger.info("Evaluation data info saved at " +
                os.path.join(out_path_selfsup, 'eval_data_info.csv'))

    # delete any checkpoints
    if os.path.exists(os.path.join(out_path_selfsup, f'checkpoint.pt')):
        os.remove(os.path.join(out_path_selfsup, f'checkpoint.pt'))
        logger.info('Checkpoint deleted.')

    # get weights state dictionary
    pretrained_unet_weights = ctx_restor.get_state_dict()

    ###################################################################
    # Supervised fine-training of U-Net  with K-Fold Cross-Validation #
    ###################################################################
    # load annotated data csv
    data_info_df = pd.read_csv(os.path.join(cfg['path']['data']['Sup'],
                                            'ct_info.csv'),
                               index_col=0)
    patient_df = pd.read_csv(os.path.join(cfg['path']['data']['Sup'],
                                          'patient_info.csv'),
                             index_col=0)

    # Make K-Fold split at the patient level
    skf = StratifiedKFold(n_splits=cfg['Sup']['split']['n_fold'],
                          shuffle=cfg['Sup']['split']['shuffle'],
                          random_state=seed)

    # iterate over folds
    for k, (train_idx, test_idx) in enumerate(
            skf.split(patient_df.PatientNumber, patient_df.Hemorrhage)):
        # check if the fold's results already exist
        if not os.path.exists(
                os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json')):
            # initialize logger
            logger = initialize_logger(
                os.path.join(out_path_sup, f'Fold_{k+1}/log.txt'))
            if os.path.exists(
                    os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt')):
                logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' +
                            '#' * 30)
            logger.info(f"Experiment : {cfg['exp_name']}")
            logger.info(
                f"Cross-Validation fold {k+1:02}/{cfg['Sup']['split']['n_fold']:02}"
            )

            # extract train/test slice dataframe
            train_df = data_info_df[data_info_df.PatientNumber.isin(
                patient_df.loc[train_idx, 'PatientNumber'].values)]
            test_df = data_info_df[data_info_df.PatientNumber.isin(
                patient_df.loc[test_idx, 'PatientNumber'].values)]
            # subsample the train dataframe to adjust the negative/positive fraction
            n_remove = int(
                max(
                    0,
                    len(train_df[train_df.Hemorrhage == 0]) -
                    cfg['Sup']['dataset']['frac_negative'] *
                    len(train_df[train_df.Hemorrhage == 1])))
            df_remove = train_df[train_df.Hemorrhage == 0].sample(
                n=n_remove, random_state=seed)
            train_df = train_df[~train_df.index.isin(df_remove.index)]
            logger.info(
                '\n' +
                str(get_split_summary_table(data_info_df, train_df, test_df)))

            # Make datasets
            train_dataset = public_SegICH_Dataset2D(
                train_df,
                cfg['path']['data']['Sup'],
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg['Sup']['dataset']['augmentation']['train'].items()
                ],
                window=(cfg['data']['win_center'], cfg['data']['win_width']),
                output_size=cfg['data']['size'])
            test_dataset = public_SegICH_Dataset2D(
                test_df,
                cfg['path']['data']['Sup'],
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg['Sup']['dataset']['augmentation']['eval'].items()
                ],
                window=(cfg['data']['win_center'], cfg['data']['win_width']),
                output_size=cfg['data']['size'])
            logger.info(
                f"Data will be loaded from {cfg['path']['data']['Sup']}.")
            logger.info(
                f"CT scans will be windowed on [{cfg['data']['win_center']-cfg['data']['win_width']/2} ; {cfg['data']['win_center'] + cfg['data']['win_width']/2}]"
            )
            logger.info(
                f"CT scans will be resized to {cfg['data']['size']}x{cfg['data']['size']}"
            )
            logger.info(
                f"Training online data transformation: \n\n {str(train_dataset.transform)}\n"
            )
            logger.info(
                f"Evaluation online data transformation: \n\n {str(test_dataset.transform)}\n"
            )

            # Make U-Net architecture
            unet_sup = UNet(
                depth=cfg['Sup']['net']['depth'],
                top_filter=cfg['Sup']['net']['top_filter'],
                use_3D=cfg['Sup']['net']['3D'],
                in_channels=cfg['Sup']['net']['in_channels'],
                out_channels=cfg['Sup']['net']['out_channels'],
                bilinear=cfg['Sup']['net']['bilinear'],
                midchannels_factor=cfg['Sup']['net']['midchannels_factor'],
                p_dropout=cfg['Sup']['net']['p_dropout'])
            unet_sup.to(cfg['device'])
            logger.info(
                f"U-Net2D initialized with a depth of {cfg['Sup']['net']['depth']}"
                f" and {cfg['Sup']['net']['top_filter']} initial filters.")
            logger.info(
                f"Reconstruction performed with {'Upsample + Conv' if cfg['Sup']['net']['bilinear'] else 'ConvTranspose'}."
            )
            logger.info(
                f"U-Net2D takes {cfg['Sup']['net']['in_channels']} input channels and returns {cfg['Sup']['net']['out_channels']} output channels."
            )
            logger.info(
                f"The U-Net2D has {sum(p.numel() for p in unet_sup.parameters())} parameters."
            )

            # Make Model
            unet2D = UNet2D(
                unet_sup,
                n_epoch=cfg['Sup']['train']['n_epoch'],
                batch_size=cfg['Sup']['train']['batch_size'],
                lr=cfg['Sup']['train']['lr'],
                lr_scheduler=getattr(torch.optim.lr_scheduler,
                                     cfg['Sup']['train']['lr_scheduler']),
                lr_scheduler_kwargs=cfg['Sup']['train']['lr_scheduler_kwargs'],
                loss_fn=getattr(src.models.optim.LossFunctions,
                                cfg['Sup']['train']['loss_fn']),
                loss_fn_kwargs=cfg['Sup']['train']['loss_fn_kwargs'],
                weight_decay=cfg['Sup']['train']['weight_decay'],
                num_workers=cfg['Sup']['train']['num_workers'],
                device=cfg['device'],
                print_progress=cfg['print_progress'])

            # TODO: load a model here if specified

            # transfer weights learned with context restoration
            logger.info(
                'Initialize U-Net2D with weights learned with context_restoration on RSNA.'
            )
            unet2D.transfer_weights(pretrained_unet_weights, verbose=True)

            # Print training parameters
            train_params = []
            for key, value in cfg['Sup']['train'].items():
                train_params.append(f"--> {key} : {value}")
            logger.info('Training settings:\n\t' + '\n\t'.join(train_params))

            # Train U-net
            eval_dataset = test_dataset if cfg['Sup']['train'][
                'validate_epoch'] else None
            unet2D.train(train_dataset,
                         valid_dataset=eval_dataset,
                         checkpoint_path=os.path.join(
                             out_path_sup, f'Fold_{k+1}/checkpoint.pt'))

            # Evaluate U-Net
            unet2D.evaluate(test_dataset,
                            save_path=os.path.join(out_path_sup,
                                                   f'Fold_{k+1}/pred/'))

            # Save models and outputs
            unet2D.save_model(
                os.path.join(out_path_sup, f'Fold_{k+1}/trained_unet.pt'))
            logger.info(
                "Trained U-Net saved at " +
                os.path.join(out_path_sup, f'Fold_{k+1}/trained_unet.pt'))
            unet2D.save_outputs(
                os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json'))
            logger.info("Trained statistics saved at " +
                        os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json'))

            # delete checkpoint if exists
            if os.path.exists(
                    os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt')):
                os.remove(
                    os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt'))
                logger.info('Checkpoint deleted.')

    # save mean +/- 1.96 std Dice over Folds
    save_mean_fold_dice(out_path_sup, cfg['Sup']['split']['n_fold'])
    logger.info('Average Scores saved at ' +
                os.path.join(out_path_sup, 'average_scores.txt'))

    # Save all volumes prediction csv
    df_list = [
        pd.read_csv(
            os.path.join(out_path_sup,
                         f'Fold_{i+1}/pred/volume_prediction_scores.csv'))
        for i in range(cfg['Sup']['split']['n_fold'])
    ]
    all_df = pd.concat(df_list, axis=0).reset_index(drop=True)
    all_df.to_csv(os.path.join(out_path_sup, 'all_volume_prediction.csv'))
    logger.info('CSV of all volumes prediction saved at ' +
                os.path.join(out_path_sup, 'all_volume_prediction.csv'))

    # Save config file
    cfg['device'] = str(cfg['device'])
    with open(os.path.join(out_path, 'config.json'), 'w') as fp:
        json.dump(cfg, fp)
    logger.info('Config file saved at ' +
                os.path.join(out_path, 'config.json'))

    # Analyse results
    analyse_supervised_exp(out_path_sup,
                           cfg['path']['data']['Sup'],
                           n_fold=cfg['Sup']['split']['n_fold'],
                           config_folder=out_path,
                           save_fn=os.path.join(
                               out_path, 'results_supervised_overview.pdf'))
    logger.info('Results overview figure saved at ' +
                os.path.join(out_path, 'results_supervised_overview.pdf'))
    analyse_representation_exp(out_path_selfsup,
                               save_fn=os.path.join(
                                   out_path,
                                   'results_self-supervised_overview.pdf'))
    logger.info('Results overview figure saved at ' +
                os.path.join(out_path, 'results_self-supervised_overview.pdf'))
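

# save_mean_fold_dice(...) is called by the last two examples but not shown. A minimal
# sketch, assuming it mirrors the inline fold aggregation of the first example (mean
# Dice +/- 1.96 std over the folds' outputs.json files); the actual helper may differ.
import json
import os

import numpy as np

def save_mean_fold_dice(out_path, n_fold):
    """Aggregate per-fold Dice scores and write mean +/- 1.96 std to average_scores.txt."""
    scores_list = []
    for k in range(n_fold):
        # each fold's evaluation scores are stored in Fold_{k+1}/outputs.json
        with open(os.path.join(out_path, f'Fold_{k+1}/outputs.json'), 'r') as f:
            out = json.load(f)
        scores_list.append(
            [out['eval']['dice']['all'], out['eval']['dice']['positive']])
    means = np.array(scores_list).mean(axis=0)
    ci95 = 1.96 * np.array(scores_list).std(axis=0)
    with open(os.path.join(out_path, 'average_scores.txt'), 'w') as f:
        f.write(f'Dice = {means[0]} +/- {ci95[0]}\n')
        f.write(f'Dice (Positive) = {means[1]} +/- {ci95[1]}\n')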