def main(config_path):
    """
    Train an FCDD on the RSNA dataset.
    """
    # Load config file
    cfg = AttrDict.from_json_path(config_path)
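    # Hypothetical sketch of the expected JSON layout (keys inferred from the attributes
    # accessed below; the actual values are not part of this excerpt):
    # {"exp_name": "...", "seed": 42, "device": "cuda:0", "print_progress": true,
    #  "path": {"data": "...", "output": "..."},
    #  "dataset": {...}, "net": {...}, "train": {...}}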

    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)
    if cfg.train.validate_epoch:
        os.makedirs(os.path.join(out_path, 'valid_results/'), exist_ok=True)

    # initialize seed
    if cfg.seed is not None:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    if os.path.exists(os.path.join(out_path, f'checkpoint.pt')):
        logger.info('\n' + '#'*30 + f'\n Recovering Session \n' + '#'*30)
    logger.info(f"Experiment : {cfg.exp_name}")

    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device(f'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(f"Device set to {cfg.device}.")

    #--------------------------------------------------------------------
    #                           Make Datasets
    #--------------------------------------------------------------------
    # load RSNA data, keep only normal slices, and sample the required number
    df_rsna = pd.read_csv(os.path.join(cfg.path.data, 'slice_info.csv'), index_col=0)

    df_train = df_rsna[df_rsna.Hemorrhage == 0].sample(n=cfg.dataset.n_normal, random_state=cfg.seed)
    if cfg.dataset.n_abnormal > 0:
        df_rsna_neg = df_rsna[df_rsna.Hemorrhage == 1].sample(n=cfg.dataset.n_abnormal, random_state=cfg.seed)
        df_train = pd.concat([df_train, df_rsna_neg], axis=0)

    # df for validation
    if cfg.train.validate_epoch:
        df_rsna_remain = df_rsna[~df_rsna.index.isin(df_train.index)]
        df_valid = df_rsna_remain[df_rsna_remain.Hemorrhage == 0].sample(n=cfg.dataset.n_normal_valid, random_state=cfg.seed)
        if cfg.dataset.n_abnormal_valid > 0:
            df_rsna_neg = df_rsna_remain[df_rsna_remain.Hemorrhage == 1].sample(n=cfg.dataset.n_abnormal_valid, random_state=cfg.seed)
            df_valid = pd.concat([df_valid, df_rsna_neg], axis=0)

    # Make FCDD dataset
    train_dataset = RSNA_FCDD_dataset(df_train, cfg.path.data, artificial_anomaly=cfg.dataset.artificial_anomaly,
                                      anomaly_proba=cfg.dataset.anomaly_proba,
                                      augmentation_transform=[getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg.dataset.augmentation.train.items()],
                                      window=(cfg.dataset.win_center, cfg.dataset.win_width), output_size=cfg.dataset.size,
                                      drawing_params=cfg.dataset.drawing_params)
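    # NOTE: each augmentation entry is expected to map a transform class name available in the
    # local `tf` module to its keyword arguments, e.g. {"RandomRotation": {"degrees": 10}}
    # (hypothetical example; the available transform names depend on `tf`).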
    if cfg.train.validate_epoch:
        valid_dataset = RSNA_FCDD_dataset(df_valid, cfg.path.data, artificial_anomaly=cfg.dataset.artificial_anomaly_valid,
                                          anomaly_proba=cfg.dataset.anomaly_proba,
                                          augmentation_transform=[getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg.dataset.augmentation.eval.items()],
                                          window=(cfg.dataset.win_center, cfg.dataset.win_width), output_size=cfg.dataset.size,
                                          drawing_params=cfg.dataset.drawing_params)
    else:
        valid_dataset = None

    logger.info(f"Data loaded from {cfg.path.data}.")
    logger.info(f"Train set contains {len(train_dataset)} samples.")
    if valid_dataset: logger.info(f"Valid set contains {len(valid_dataset)} samples.")
    logger.info(f"CT scans will be windowed on [{cfg.dataset.win_center-cfg.dataset.win_width/2} ; {cfg.dataset.win_center + cfg.dataset.win_width/2}]")
    logger.info(f"CT scans will be resized to {cfg.dataset.size}x{cfg.dataset.size}")
    logger.info(f"Training online data transformation: \n\n {str(train_dataset.transform)}\n")
    if valid_dataset: logger.info(f"Evaluation online data transformation: \n\n {str(valid_dataset.transform)}\n")
    if cfg.dataset.artificial_anomaly:
        draw_params = [f"--> {k} : {v}" for k, v in cfg.dataset.drawing_params.items()]
        logger.info("Artificial Anomaly drawing parameters \n\t" + "\n\t".join(draw_params))

    #--------------------------------------------------------------------
    #                           Make Networks
    #--------------------------------------------------------------------
    net = FCDD_CNN_VGG(in_shape=[cfg.net.in_channels, cfg.dataset.size, cfg.dataset.size], bias=cfg.net.bias)

    #--------------------------------------------------------------------
    #                          Make FCDD model
    #--------------------------------------------------------------------
    cfg.train.model_param.lr_scheduler = getattr(torch.optim.lr_scheduler, cfg.train.model_param.lr_scheduler) # convert scheduler name to scheduler class object
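    # e.g. a config value of "ExponentialLR" would resolve to torch.optim.lr_scheduler.ExponentialLR
    # (hypothetical value; any class name defined in torch.optim.lr_scheduler works here)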
    model = FCDD(net, print_progress=cfg.print_progress,
                 device=cfg.device,  **cfg.train.model_param)
    train_params = [f"--> {k} : {v}" for k, v in cfg.train.model_param.items()]
    logger.info("FCDD Training Parameters \n\t" + "\n\t".join(train_params))

    # load models if provided
    if cfg.train.model_path_to_load:
        model.load_model(cfg.train.model_path_to_load, map_location=cfg.device)
        logger.info(f"FCDD Model loaded from {cfg.train.model_path_to_load}")

    #--------------------------------------------------------------------
    #                          Train FCDD model
    #--------------------------------------------------------------------
    if cfg.train.model_param.n_epoch > 0:
        model.train(train_dataset, checkpoint_path=os.path.join(out_path, 'checkpoint.pt'),
                    valid_dataset=valid_dataset)

    #--------------------------------------------------------------------
    #               Generate and save few Heatmap with FCDD model
    #--------------------------------------------------------------------
    if cfg.train.validate_epoch:
        if len(valid_dataset) > 100:
            valid_subset = torch.utils.data.random_split(valid_dataset, [100, len(valid_dataset)-100],
                                                        generator=torch.Generator().manual_seed(cfg.seed))[0]
        else:
            valid_subset = valid_dataset
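        # heatmaps are generated for at most 100 validation slices to keep the outputs small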
        model.localize_anomalies(valid_subset, save_path=os.path.join(out_path, 'valid_results/'),
                                 **cfg.train.heatmap_param)

    #--------------------------------------------------------------------
    #                   Save outputs, models and config
    #--------------------------------------------------------------------
    # save models
    model.save_model(export_fn=os.path.join(out_path, 'FCDD.pt'))
    logger.info("FCDD model saved at " + os.path.join(out_path, 'FCDD.pt'))
    # save outputs
    model.save_outputs(export_fn=os.path.join(out_path, 'outputs.json'))
    logger.info("Outputs file saved at " + os.path.join(out_path, 'outputs.json'))
    # save config file
    cfg.device = str(cfg.device) # set device as string to be JSON serializable
    cfg.train.model_param.lr_scheduler = str(cfg.train.model_param.lr_scheduler)
    with open(os.path.join(out_path, 'config.json'), 'w') as fp:
        json.dump(cfg, fp)
    logger.info("Config file saved at " + os.path.join(out_path, 'config.json'))
Code example #2
def main(config_path):
    """ """
    # load config
    cfg = AttrDict.from_json_path(config_path)

    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)

    # initialize seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    logger.info(f"Experiment : {cfg.exp_name}")

    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device(
            f'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(f"Device set to {cfg.device}.")

    # get Dataset
    data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'ct_info.csv'),
                               index_col=0)
    dataset = public_SegICH_Dataset2D(
        data_info_df,
        cfg.path.data,
        augmentation_transform=[
            getattr(tf, tf_name)(**tf_kwargs)
            for tf_name, tf_kwargs in cfg.data.augmentation.items()
        ],
        output_size=cfg.data.size,
        window=(cfg.data.win_center, cfg.data.win_width))

    # load inpainting model
    cfg_ae = AttrDict.from_json_path(cfg.ae_cfg_path)
    ae_net = AE_net(**cfg_ae.net)
    loaded_state_dict = torch.load(cfg.ae_model_path, map_location=cfg.device)
    ae_net.load_state_dict(loaded_state_dict)
    ae_net = ae_net.to(cfg.device).eval()
    logger.info(f"AE model succesfully loaded from {cfg.ae_model_path}")

    # Load Classifier
    if cfg.classifier_model_path is not None:
        cfg_classifier = AttrDict.from_json_path(
            os.path.join(cfg.classifier_model_path, 'config.json'))
        classifier = getattr(rn, cfg_classifier.net.resnet)(
            num_classes=cfg_classifier.net.num_classes,
            input_channels=cfg_classifier.net.input_channels)
        classifier_state_dict = torch.load(os.path.join(
            cfg.classifier_model_path, 'resnet_state_dict.pt'),
                                           map_location=cfg.device)
        classifier.load_state_dict(classifier_state_dict)
        classifier = classifier.to(cfg.device)
        classifier.eval()
        logger.info(
            f"ResNet classifier model succesfully loaded from {os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt')}"
        )

    # iterate over dataset
    all_pred = []
    for i, sample in enumerate(dataset):
        # unpack data
        image, target, id, slice = sample
        logger.info(
            "=" * 25 +
            f" SAMPLE {i+1:04}/{len(dataset):04} - Volume {id:03} Slice {slice:03} "
            + "=" * 25)

        # Classify sample
        if cfg.classifier_model_path is not None:
            with torch.no_grad():
                input_clss = image.unsqueeze(0).to(cfg.device).float()
                pred_score = nn.functional.softmax(
                    classifier(input_clss), dim=1
                )[:, 1]  # take columns of softmax of positive class as score
                pred = 1 if pred_score >= cfg.classification_threshold else 0
        else:
            pred = 1  # if no classifier is given, all slices are processed

        # process slice if classifier has detected Hemorrhage
        if pred == 1:
            logger.info(
                f"ICH detected. Compute anomaly mask through AE reconstruction."
            )
            # Detect anomalies using the robust approach
            ad_map, ad_mask = compute_anomaly(ae_net,
                                              image,
                                              alpha_low=cfg.alpha_low,
                                              alpha_high=cfg.alpha_high,
                                              device=cfg.device)
            logger.info(f"{ad_mask.sum()} anomalous pixels detected.")
            # save ad_mask
            ad_mask_fn = f"{id}/{slice}_anomalies.bmp"
            save_path = os.path.join(out_path, 'pred/', ad_mask_fn)
            if not os.path.isdir(os.path.dirname(save_path)):
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
            io.imsave(save_path, img_as_ubyte(ad_mask), check_contrast=False)
            # save anomaly map
            ad_map_fn = f"{id}/{slice}_map_anomalies.png"
            save_path_map = os.path.join(out_path, 'pred/', ad_map_fn)
            io.imsave(save_path_map,
                      img_as_ubyte(
                          rescale_intensity(ad_map, out_range=(0.0, 1.0))),
                      check_contrast=False)
        else:
            logger.info("No ICH detected. Set the anomaly mask to zeros.")
            ad_mask = np.zeros_like(target[0].numpy())
            ad_map = np.zeros_like(target[0].numpy())  # zero the map too, so the AUC below never uses a stale value
            ad_mask_fn, ad_map_fn = 'None', 'None'

        # compute confusion matrix with target ICH mask
        tn, fp, fn, tp = confusion_matrix(target[0].numpy().ravel(),
                                          ad_mask.ravel(),
                                          labels=[0, 1]).ravel()
        auc = roc_auc_score(target[0].numpy().ravel(),
                            ad_map.ravel()) if torch.any(target[0]) else 'None'
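        # (pixel-level AUC is only defined when the slice contains positive pixels, hence 'None' otherwise)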
        # append to all_pred list
        all_pred.append({
            'id': id.item(),
            'slice': slice.item(),
            'label': target.max().item(),
            'TP': tp,
            'TN': tn,
            'FP': fp,
            'FN': fn,
            'AUC': auc,
            'ad_mask_fn': ad_mask_fn,
            'ad_map_fn': ad_map_fn
        })

    # make a dataframe of all predictions
    slice_df = pd.DataFrame(all_pred)
    volume_df = slice_df[['id', 'label', 'TP', 'TN', 'FP',
                          'FN']].groupby('id').agg({
                              'label': 'max',
                              'TP': 'sum',
                              'TN': 'sum',
                              'FP': 'sum',
                              'FN': 'sum'
                          })

    # Compute Dice and Volume Dice
    slice_df['Dice'] = (2 * slice_df.TP + 1) / (2 * slice_df.TP + slice_df.FP +
                                                slice_df.FN + 1)
    volume_df['Dice'] = (2 * volume_df.TP + 1) / (
        2 * volume_df.TP + volume_df.FP + volume_df.FN + 1)
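    # NOTE: Dice = (2*TP + 1) / (2*TP + FP + FN + 1); the +1 smoothing makes an empty prediction
    # on an empty ground truth score 1 instead of producing a 0/0 division.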
    logger.info(f"Mean slice dice : {slice_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean volume dice : {volume_df.Dice.mean(axis=0):.3f}")
    logger.info(
        f"Mean posiitve slice AUC {slice_df[slice_df.label == 1].AUC.mean(axis=0):.3f}"
    )

    # Save Scores and Config
    slice_df.to_csv(os.path.join(out_path, 'slice_predictions.csv'))
    logger.info(
        f"Slice prediction csv saved at {os.path.join(out_path, 'slice_predictions.csv')}"
    )
    volume_df.to_csv(os.path.join(out_path, 'volume_predictions.csv'))
    logger.info(
        f"Volume prediction csv saved at {os.path.join(out_path, 'volume_predictions.csv')}"
    )
    cfg.device = str(cfg.device)
    with open(os.path.join(out_path, 'config.json'), 'w') as f:
        json.dump(cfg, f)
    logger.info(
        f"Config file saved at {os.path.join(out_path, 'config.json')}")
def main(config_path):
    """
    UNet2D whose encoder is pretrained on ICH classification with the RSNA dataset and then fine-tuned on the public data.
    """
    # load the config file
    cfg = AttrDict.from_json_path(config_path)

    # Make Outputs directories
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    out_path_selfsup = os.path.join(out_path, 'classification_pretrain/')
    out_path_sup = os.path.join(out_path, 'supervised_train/')
    os.makedirs(out_path_selfsup, exist_ok=True)
    for k in range(cfg.Sup.split.n_fold):
        os.makedirs(os.path.join(out_path_sup, f'Fold_{k+1}/pred/'),
                    exist_ok=True)

    # Initialize random seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # Set number of thread
    if cfg.n_thread > 0: torch.set_num_threads(cfg.n_thread)
    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = get_available_device()

    ####################################################
    # Self-supervised training on Multi Classification #
    ####################################################
    # Initialize Logger
    logger = initialize_logger(os.path.join(out_path_selfsup, 'log.txt'))
    if os.path.exists(os.path.join(out_path_selfsup, f'checkpoint.pt')):
        logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' + '#' * 30)
    logger.info(f"Experiment : {cfg.exp_name}")

    # Load RSNA data csv
    df_rsna = pd.read_csv(os.path.join(cfg.path.data.SSL, 'slice_info.csv'),
                          index_col=0)

    # Keep only the requested number/fraction of samples per class
    if cfg.SSL.dataset.n_data_1 >= 0:
        df_rsna_ICH = df_rsna[df_rsna.Hemorrhage == 1].sample(
            n=cfg.SSL.dataset.n_data_1, random_state=cfg.seed)
    else:
        df_rsna_ICH = df_rsna[df_rsna.Hemorrhage == 1]
    if cfg.SSL.dataset.f_data_0 >= 0:  # number of normal samples given as a fraction of the number of ICH samples
        df_rsna_noICH = df_rsna[df_rsna.Hemorrhage == 0].sample(
            n=int(cfg.SSL.dataset.f_data_0 * len(df_rsna_ICH)),
            random_state=cfg.seed)
    else:
        df_rsna_noICH = df_rsna[df_rsna.Hemorrhage == 0]
    df_rsna = pd.concat([df_rsna_ICH, df_rsna_noICH], axis=0)

    # Split the data, keeping a few samples for evaluation in a stratified way
    train_df, test_df = train_test_split(df_rsna,
                                         test_size=cfg.SSL.dataset.frac_eval,
                                         stratify=df_rsna.Hemorrhage,
                                         random_state=cfg.seed)
    logger.info('\n' +
                str(get_split_summary_table(df_rsna, train_df, test_df)))

    # Make datasets : Train --> MultiClassification, Test --> MultiClassification
    train_RSNA_dataset = RSNA_dataset(
        train_df,
        cfg.path.data.SSL,
        augmentation_transform=[
            getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
            cfg.SSL.dataset.augmentation.train.items()
        ],
        window=(cfg.data.win_center, cfg.data.win_width),
        output_size=cfg.data.size,
        mode='multi_classification')
    test_RSNA_dataset = RSNA_dataset(
        test_df,
        cfg.path.data.SSL,
        augmentation_transform=[
            getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
            cfg.SSL.dataset.augmentation.eval.items()
        ],
        window=(cfg.data.win_center, cfg.data.win_width),
        output_size=cfg.data.size,
        mode='multi_classification')

    logger.info(f"Data will be loaded from {cfg.path.data.SSL}.")
    logger.info(
        f"CT scans will be windowed on [{cfg.data.win_center-cfg.data.win_width/2} ; {cfg.data.win_center + cfg.data.win_width/2}]"
    )
    logger.info(f"CT scans will be resized to {cfg.data.size}x{cfg.data.size}")
    logger.info(
        f"Training online data transformation: \n\n {str(train_RSNA_dataset.transform)}\n"
    )
    logger.info(
        f"Evaluation online data transformation: \n\n {str(test_RSNA_dataset.transform)}\n"
    )

    # Make U-Net-Encoder architecture
    net_ssl = UNet_Encoder(**cfg.SSL.net).to(cfg.device)
    net_params = [f"--> {k} : {v}" for k, v in cfg.SSL.net.items()]
    logger.info("UNet like Multi Classifier \n\t" + "\n\t".join(net_params))
    logger.info(
        f"The Multi Classifier has {sum(p.numel() for p in net_ssl.parameters())} parameters."
    )

    # Make Model
    cfg.SSL.train.model_param.lr_scheduler = getattr(
        torch.optim.lr_scheduler, cfg.SSL.train.model_param.lr_scheduler
    )  # convert scheduler name to scheduler class object
    if cfg.SSL.train.model_param.loss_fn == 'BCEWithLogitsLoss':
        df_rsna['no_Hemorrhage'] = 1 - df_rsna.Hemorrhage
        class_weight_list = (
            (len(df_rsna) - df_rsna[train_RSNA_dataset.class_name].sum()) /
            df_rsna[train_RSNA_dataset.class_name].sum()
        ).values  # per-class pos_weight = #negatives / #positives, computed from the sampled data
        cfg.SSL.train.model_param.loss_fn_kwargs['pos_weight'] = torch.tensor(
            class_weight_list, device=cfg.device)  # pass the weighting to BCEWithLogitsLoss as pos_weight
    try:
        cfg.SSL.train.model_param.loss_fn = getattr(
            torch.nn, cfg.SSL.train.model_param.loss_fn
        )  # convert loss_fn name to nn.Module class object
    except AttributeError:
        cfg.SSL.train.model_param.loss_fn = getattr(
            src.models.optim.LossFunctions, cfg.SSL.train.model_param.loss_fn)

    classifier = MultiClassifier(net_ssl,
                                 device=cfg.device,
                                 print_progress=cfg.print_progress,
                                 **cfg.SSL.train.model_param)

    train_params = [
        f"--> {k} : {v}" for k, v in cfg.SSL.train.model_param.items()
    ]
    logger.info("Classifer Training Parameters \n\t" +
                "\n\t".join(train_params))

    # Load weights if specified
    if cfg.SSL.train.model_path_to_load:
        model_path = cfg.SSL.train.model_path_to_load
        classifier.load_model(model_path, map_location=cfg.device)
        logger.info(
            f"Classifer Model succesfully loaded from {cfg.SSL.train.model_path_to_load}"
        )

    # train if needed
    if cfg.SSL.train.model_param.n_epoch > 0:
        classifier.train(train_RSNA_dataset,
                         valid_dataset=test_RSNA_dataset,
                         checkpoint_path=os.path.join(out_path_selfsup,
                                                      f'checkpoint.pt'))

    # evaluate
    auc, acc, sub_acc, recall, precision, f1 = classifier.evaluate(
        test_RSNA_dataset, save_tsne=True, return_scores=True)
    logger.info(f"Classifier Test AUC : {auc:.2%}")
    logger.info(f"Classifier Test Accuracy : {acc:.2%}")
    logger.info(f"Classifier Test Subset Accuracy : {sub_acc:.2%}")
    logger.info(f"Classifier Test Recall : {recall:.2%}")
    logger.info(f"Classifier Test Precision : {precision:.2%}")
    logger.info(f"Classifier Test F1-score : {f1:.2%}")

    # save model, outputs
    classifier.save_model_state_dict(
        os.path.join(out_path_selfsup, 'pretrained_unet_enc.pt'))
    logger.info(
        "Pre-trained U-Net encoder on binary classification saved at " +
        os.path.join(out_path_selfsup, 'pretrained_unet_enc.pt'))
    classifier.save_outputs(os.path.join(out_path_selfsup, 'outputs.json'))
    logger.info("Classifier outputs saved at " +
                os.path.join(out_path_selfsup, 'outputs.json'))
    test_df.reset_index(drop=True).to_csv(
        os.path.join(out_path_selfsup, 'eval_data_info.csv'))
    logger.info("Evaluation data info saved at " +
                os.path.join(out_path_selfsup, 'eval_data_info.csv'))

    # delete any checkpoints
    if os.path.exists(os.path.join(out_path_selfsup, f'checkpoint.pt')):
        os.remove(os.path.join(out_path_selfsup, f'checkpoint.pt'))
        logger.info('Checkpoint deleted.')

    # get the weights state dictionary
    pretrained_unet_weights = classifier.get_state_dict()

    ################################################################
    # Supervised fine-tuning of U-Net with K-Fold Cross-Validation #
    ################################################################
    # load annotated data csv
    data_info_df = pd.read_csv(os.path.join(cfg.path.data.Sup, 'ct_info.csv'),
                               index_col=0)
    patient_df = pd.read_csv(os.path.join(cfg.path.data.Sup,
                                          'patient_info.csv'),
                             index_col=0)

    # Make K-Fold split at the patient level
    skf = StratifiedKFold(n_splits=cfg.Sup.split.n_fold,
                          shuffle=cfg.Sup.split.shuffle,
                          random_state=cfg.seed)

    # define scheduler and loss_fn as object
    cfg.Sup.train.model_param.lr_scheduler = getattr(
        torch.optim.lr_scheduler, cfg.Sup.train.model_param.lr_scheduler
    )  # convert scheduler name to scheduler class object
    cfg.Sup.train.model_param.loss_fn = getattr(
        src.models.optim.LossFunctions, cfg.Sup.train.model_param.loss_fn
    )  # convert loss_fn name to nn.Module class object

    # iterate over folds
    for k, (train_idx, test_idx) in enumerate(
            skf.split(patient_df.PatientNumber, patient_df.Hemorrhage)):
        # check if fold's results already exists
        if not os.path.exists(
                os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json')):
            # initialize logger
            logger = initialize_logger(
                os.path.join(out_path_sup, f'Fold_{k+1}/log.txt'))
            if os.path.exists(
                    os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt')):
                logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' +
                            '#' * 30)
            logger.info(f"Experiment : {cfg['exp_name']}")
            logger.info(
                f"Cross-Validation fold {k+1:02}/{cfg['Sup']['split']['n_fold']:02}"
            )

            # extract train/test slice dataframe
            train_df = data_info_df[data_info_df.PatientNumber.isin(
                patient_df.loc[train_idx, 'PatientNumber'].values)]
            test_df = data_info_df[data_info_df.PatientNumber.isin(
                patient_df.loc[test_idx, 'PatientNumber'].values)]
            # sample the train dataframe to adjust the negative/positive fraction
            n_remove = int(
                max(
                    0,
                    len(train_df[train_df.Hemorrhage == 0]) -
                    cfg.Sup.dataset.frac_negative *
                    len(train_df[train_df.Hemorrhage == 1])))
            df_remove = train_df[train_df.Hemorrhage == 0].sample(
                n=n_remove, random_state=cfg.seed)
            train_df = train_df[~train_df.index.isin(df_remove.index)]
            logger.info(
                '\n' +
                str(get_split_summary_table(data_info_df, train_df, test_df)))

            # Make datasets
            train_dataset = public_SegICH_Dataset2D(
                train_df,
                cfg.path.data.Sup,
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg.Sup.dataset.augmentation.train.items()
                ],
                window=(cfg.data.win_center, cfg.data.win_width),
                output_size=cfg.data.size)
            test_dataset = public_SegICH_Dataset2D(
                test_df,
                cfg.path.data.Sup,
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg.Sup.dataset.augmentation.eval.items()
                ],
                window=(cfg.data.win_center, cfg.data.win_width),
                output_size=cfg.data.size)
            logger.info(f"Data will be loaded from {cfg.path.data.Sup}.")
            logger.info(
                f"CT scans will be windowed on [{cfg.data.win_center-cfg.data.win_width/2} ; {cfg.data.win_center + cfg.data.win_width/2}]"
            )
            logger.info(
                f"CT scans will be resized to {cfg.data.size}x{cfg.data.size}")
            logger.info(
                f"Training online data transformation: \n\n {str(train_dataset.transform)}\n"
            )
            logger.info(
                f"Evaluation online data transformation: \n\n {str(test_dataset.transform)}\n"
            )

            # Make U-Net architecture
            unet_sup = UNet(**cfg.Sup.net).to(cfg.device)
            net_params = [f"--> {k} : {v}" for k, v in cfg.Sup.net.items()]
            logger.info("UNet-2D params \n\t" + "\n\t".join(net_params))
            logger.info(
                f"The U-Net2D has {sum(p.numel() for p in unet_sup.parameters())} parameters."
            )

            # Make Model
            unet2D = UNet2D(unet_sup,
                            device=cfg.device,
                            print_progress=cfg.print_progress,
                            **cfg.Sup.train.model_param)

            train_params = [
                f"--> {k} : {v}" for k, v in cfg.Sup.train.model_param.items()
            ]
            logger.info("UNet-2D Training Parameters \n\t" +
                        "\n\t".join(train_params))

            # TODO: load a pretrained model here if specified

            # transfer the weights learned during the classification pre-training
            logger.info(
                'Initialize U-Net2D with weights learned by classification pre-training on RSNA.'
            )
            unet2D.transfer_weights(pretrained_unet_weights, verbose=True)

            # Train U-net
            eval_dataset = test_dataset if cfg.Sup.train.validate_epoch else None
            unet2D.train(train_dataset,
                         valid_dataset=eval_dataset,
                         checkpoint_path=os.path.join(
                             out_path_sup, f'Fold_{k+1}/checkpoint.pt'))

            # Evaluate U-Net
            unet2D.evaluate(test_dataset,
                            save_path=os.path.join(out_path_sup,
                                                   f'Fold_{k+1}/pred/'))

            # Save models and outputs
            unet2D.save_model(
                os.path.join(out_path_sup, f'Fold_{k+1}/trained_unet.pt'))
            logger.info(
                "Trained U-Net saved at " +
                os.path.join(out_path_sup, f'Fold_{k+1}/trained_unet.pt'))
            unet2D.save_outputs(
                os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json'))
            logger.info("Trained statistics saved at " +
                        os.path.join(out_path_sup, f'Fold_{k+1}/outputs.json'))

            # delete checkpoint if exists
            if os.path.exists(
                    os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt')):
                os.remove(
                    os.path.join(out_path_sup, f'Fold_{k+1}/checkpoint.pt'))
                logger.info('Checkpoint deleted.')

    # save mean +/- 1.96 std Dice over Folds
    save_mean_fold_dice(out_path_sup, cfg.Sup.split.n_fold)
    logger.info('Average Scores saved at ' +
                os.path.join(out_path_sup, 'average_scores.txt'))

    # Save all volumes prediction csv
    df_list = [
        pd.read_csv(
            os.path.join(out_path_sup,
                         f'Fold_{i+1}/pred/volume_prediction_scores.csv'))
        for i in range(cfg.Sup.split.n_fold)
    ]
    all_df = pd.concat(df_list, axis=0).reset_index(drop=True)
    all_df.to_csv(os.path.join(out_path_sup, 'all_volume_prediction.csv'))
    logger.info('CSV of all volumes prediction saved at ' +
                os.path.join(out_path_sup, 'all_volume_prediction.csv'))

    # Save config file
    cfg.device = str(cfg.device)
    cfg.SSL.train.model_param.lr_scheduler = str(
        cfg.SSL.train.model_param.lr_scheduler)
    cfg.Sup.train.model_param.lr_scheduler = str(
        cfg.Sup.train.model_param.lr_scheduler)
    cfg.SSL.train.model_param.loss_fn = str(cfg.SSL.train.model_param.loss_fn)
    cfg.Sup.train.model_param.loss_fn = str(cfg.Sup.train.model_param.loss_fn)
    cfg.SSL.train.model_param.loss_fn_kwargs.pos_weight = cfg.SSL.train.model_param.loss_fn_kwargs.pos_weight.cpu(
    ).data.tolist()
    with open(os.path.join(out_path, 'config.json'), 'w') as fp:
        json.dump(cfg, fp)
    logger.info('Config file saved at ' +
                os.path.join(out_path, 'config.json'))

    # Analyse results
    analyse_supervised_exp(out_path_sup,
                           cfg.path.data.Sup,
                           n_fold=cfg.Sup.split.n_fold,
                           config_folder=out_path,
                           save_fn=os.path.join(
                               out_path, 'results_supervised_overview.pdf'))
    logger.info('Results overview figure saved at ' +
                os.path.join(out_path, 'results_supervised_overview.pdf'))
    analyse_representation_exp(out_path_selfsup,
                               save_fn=os.path.join(
                                   out_path,
                                   'results_self-supervised_overview.pdf'))
    logger.info('Results overview figure saved at ' +
                os.path.join(out_path, 'results_self-supervised_overview.pdf'))
Code example #4
def main(config_path):
    """
    Train and evaluate a 2D UNet on the public ICH dataset with the anomaly attention map, using the parameters
    from the JSON config at config_path. The evaluation is performed by k-fold cross-validation.
    """
    # load config file
    cfg = AttrDict.from_json_path(config_path)

    # Make Output directories
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)
    for k in range(cfg.split.n_fold):
        os.makedirs(os.path.join(out_path, f'Fold_{k+1}/pred/'), exist_ok=True)

    # Initialize random seed to given seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # Load data csv
    data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'info.csv'),
                               index_col=0)
    patient_df = pd.read_csv(os.path.join(cfg.path.data, 'patient_info.csv'),
                             index_col=0)

    # Generate Cross-Val indices at the patient level
    skf = StratifiedKFold(n_splits=cfg.split.n_fold,
                          shuffle=cfg.split.shuffle,
                          random_state=cfg.seed)
    # iterate over folds; the stratified split ensures the same number of ICH-positive patients per fold
    for k, (train_idx, test_idx) in enumerate(
            skf.split(patient_df.PatientNumber, patient_df.Hemorrhage)):
        # if fold results not already there
        if not os.path.exists(
                os.path.join(out_path, f'Fold_{k+1}/outputs.json')):
            # initialize logger
            logger = initialize_logger(os.path.join(out_path, 'log.txt'))
            if os.path.exists(
                    os.path.join(out_path, f'Fold_{k+1}/checkpoint.pt')):
                logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' +
                            '#' * 30)
            logger.info(f"Experiment : {cfg.exp_name}")
            logger.info(
                f"Cross-Validation fold {k+1:02}/{cfg.split.n_fold:02}")

            # check if GPU available
            if cfg.device is not None:
                cfg.device = torch.device(cfg.device)
            else:
                cfg.device = torch.device('cuda') if torch.cuda.is_available(
                ) else torch.device('cpu')
            logger.info(f"Device : {cfg.device}")

            # extract train and test DataFrames + print summary (n samples positive and negatives)
            train_df = data_info_df[data_info_df.id.isin(
                patient_df.loc[train_idx, 'PatientNumber'].values)]
            test_df = data_info_df[data_info_df.id.isin(
                patient_df.loc[test_idx, 'PatientNumber'].values)]
            # subsample the dataframe to adjust the number of normal slices
            n_remove = int(
                max(
                    0,
                    len(train_df[train_df.Hemorrhage == 0]) -
                    cfg.dataset.frac_negative *
                    len(train_df[train_df.Hemorrhage == 1])))
            df_remove = train_df[train_df.Hemorrhage == 0].sample(
                n=n_remove, random_state=cfg.seed)
            train_df = train_df[~train_df.index.isin(df_remove.index)]
            logger.info(
                '\n' +
                str(get_split_summary_table(data_info_df, train_df, test_df)))

            # Make Dataset + print online augmentation summary
            train_dataset = public_SegICH_AttentionDataset2D(
                train_df,
                cfg.path.data,
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg.data.augmentation.train.items()
                ],
                window=(cfg.data.win_center, cfg.data.win_width),
                output_size=cfg.data.size)
            test_dataset = public_SegICH_AttentionDataset2D(
                test_df,
                cfg.path.data,
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg.data.augmentation.eval.items()
                ],
                window=(cfg.data.win_center, cfg.data.win_width),
                output_size=cfg.data.size)
            logger.info(f"Data will be loaded from {cfg.path.data}.")
            logger.info(
                f"CT scans will be windowed on [{cfg.data.win_center-cfg.data.win_width/2} ; {cfg.data.win_center + cfg.data.win_width/2}]"
            )
            logger.info(
                f"Training online data transformation: \n\n {str(train_dataset.transform)}\n"
            )
            logger.info(
                f"Evaluation online data transformation: \n\n {str(test_dataset.transform)}\n"
            )

            # Make architecture and print its summary
            unet_arch = UNet(**cfg.net)
            unet_arch.to(cfg.device)
            net_params = [f"--> {k} : {v}" for k, v in cfg.net.items()]
            logger.info("U-Net2D architecture \n\t" + "\n\t".join(net_params))
            logger.info(
                f"The U-Net2D has {sum(p.numel() for p in unet_arch.parameters())} parameters."
            )

            # Make model
            cfg_train = AttrDict(cfg.train.params)
            cfg_train.lr_scheduler = getattr(torch.optim.lr_scheduler,
                                             cfg_train.lr_scheduler)
            cfg_train.loss_fn = getattr(src.models.optim.LossFunctions,
                                        cfg_train.loss_fn)
            unet2D = UNet2D(unet_arch,
                            device=cfg.device,
                            print_progress=cfg.print_progress,
                            **cfg_train)
            # print Training hyper-parameters
            train_params = [f"--> {k} : {v}" for k, v in cfg_train.items()]
            logger.info("U-Net2D Training Parameters \n\t" +
                        "\n\t".join(train_params))

            # Load model if required
            if cfg.train.model_path_to_load:
                if isinstance(cfg.train.model_path_to_load, str):
                    model_path = cfg.train.model_path_to_load
                    unet2D.load_model(model_path, map_location=cfg.device)
                elif isinstance(cfg.train.model_path_to_load, list):
                    model_path = cfg.train.model_path_to_load[k]
                    unet2D.load_model(model_path, map_location=cfg.device)
                else:
                    raise ValueError(
                        f'Model path to load type not understood.')
                logger.info(f"2D U-Net model loaded from {model_path}")

            # Train model
            eval_dataset = test_dataset if cfg.train.validate_epoch else None
            unet2D.train(train_dataset,
                         valid_dataset=eval_dataset,
                         checkpoint_path=os.path.join(
                             out_path, f'Fold_{k+1}/checkpoint.pt'))

            # Evaluate model
            unet2D.evaluate(test_dataset,
                            save_path=os.path.join(out_path,
                                                   f'Fold_{k+1}/pred/'))

            # Save models & outputs
            unet2D.save_model(
                os.path.join(out_path, f'Fold_{k+1}/trained_unet.pt'))
            logger.info("Trained U-Net saved at " +
                        os.path.join(out_path, f'Fold_{k+1}/trained_unet.pt'))
            unet2D.save_outputs(
                os.path.join(out_path, f'Fold_{k+1}/outputs.json'))
            logger.info("Trained statistics saved at " +
                        os.path.join(out_path, f'Fold_{k+1}/outputs.json'))

            # delete checkpoint if exists
            if os.path.exists(
                    os.path.join(out_path, f'Fold_{k+1}/checkpoint.pt')):
                os.remove(os.path.join(out_path, f'Fold_{k+1}/checkpoint.pt'))
                logger.info('Checkpoint deleted.')

    # save mean +/- 1.96 std Dice in .txt file
    scores_list = []
    for k in range(cfg.split.n_fold):
        with open(os.path.join(out_path, f'Fold_{k+1}/outputs.json'),
                  'r') as f:
            out = json.load(f)
        scores_list.append(
            [out['eval']['dice']['all'], out['eval']['dice']['positive']])
    means = np.array(scores_list).mean(axis=0)
    CI95 = 1.96 * np.array(scores_list).std(axis=0)
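    # 1.96 * std approximates a 95% interval, assuming the per-fold Dice scores are roughly normal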
    with open(os.path.join(out_path, 'average_scores.txt'), 'w') as f:
        f.write(f'Dice = {means[0]} +/- {CI95[0]}\n')
        f.write(f'Dice (Positive) = {means[1]} +/- {CI95[1]}\n')
    logger.info('Average Scores saved at ' +
                os.path.join(out_path, 'average_scores.txt'))

    # generate dataframe of all prediction
    df_list = [
        pd.read_csv(
            os.path.join(out_path,
                         f'Fold_{i+1}/pred/volume_prediction_scores.csv'))
        for i in range(cfg.split.n_fold)
    ]
    all_df = pd.concat(df_list, axis=0).reset_index(drop=True)
    all_df.to_csv(os.path.join(out_path, 'all_volume_prediction.csv'))
    logger.info('CSV of all volumes prediction saved at ' +
                os.path.join(out_path, 'all_volume_prediction.csv'))

    # Save config file
    cfg.device = str(cfg.device)
    #cfg.train.params.lr_scheduler = str(cfg.train.params.lr_scheduler)
    #cfg.train.params.loss_fn = str(cfg.train.params.loss_fn)
    with open(os.path.join(out_path, 'config.json'), 'w') as fp:
        json.dump(cfg, fp)
    logger.info("Config file saved at " +
                os.path.join(out_path, 'config.json'))

    # Analyse results
    analyse_supervised_exp(out_path,
                           cfg.path.data,
                           cfg.split.n_fold,
                           save_fn=os.path.join(out_path,
                                                'results_overview.pdf'))
    logger.info('Results overview figure saved at ' +
                os.path.join(out_path, 'results_overview.pdf'))
Code example #5
def main(config_path):
    """  """
    # load config
    cfg = AttrDict.from_json_path(config_path)

    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)

    # initialize seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    logger.info(f"Experiment : {cfg.exp_name}")

    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device(f'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(f"Device set to {cfg.device}.")

    #-------------------------------------------
    #       Make Dataset
    #-------------------------------------------

    data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'ct_info.csv'), index_col=0)
    dataset = public_SegICH_Dataset2D(data_info_df, cfg.path.data,
                    augmentation_transform=[getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg.data.augmentation.items()],
                    output_size=cfg.data.size, window=(cfg.data.win_center, cfg.data.win_width))

    #-------------------------------------------
    #       Load FCDD Model
    #-------------------------------------------

    cfg_fcdd = AttrDict.from_json_path(cfg.fcdd_cfg_path)
    fcdd_net = FCDD_CNN_VGG(in_shape=(cfg_fcdd.net.in_channels, 256, 256), bias=cfg_fcdd.net.bias)
    loaded_state_dict = torch.load(cfg.fcdd_model_path, map_location=cfg.device)
    fcdd_net.load_state_dict(loaded_state_dict)
    fcdd_net = fcdd_net.to(cfg.device).eval()
    logger.info(f"FCDD model succesfully loaded from {cfg.fcdd_model_path}")

    # make FCDD object
    fcdd = FCDD(fcdd_net, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
                device=cfg.device, print_progress=cfg.print_progress)

    #-------------------------------------------
    #       Load Classifier Model
    #-------------------------------------------

    # Load Classifier
    if cfg.classifier_model_path is not None:
        cfg_classifier = AttrDict.from_json_path(os.path.join(cfg.classifier_model_path, 'config.json'))
        classifier = getattr(rn, cfg_classifier.net.resnet)(num_classes=cfg_classifier.net.num_classes, input_channels=cfg_classifier.net.input_channels)
        classifier_state_dict = torch.load(os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt'), map_location=cfg.device)
        classifier.load_state_dict(classifier_state_dict)
        classifier = classifier.to(cfg.device)
        classifier.eval()
        logger.info(f"ResNet classifier model succesfully loaded from {os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt')}")

    #-------------------------------------------
    #       Generate Heat-Map for each slice
    #-------------------------------------------

    with torch.no_grad():
        # make loader
        loader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers,
                                             shuffle=False, worker_init_fn=lambda _: np.random.seed())
        fcdd_net.eval()

        min_val, max_val = fcdd.get_min_max(loader, **cfg.heatmap_param)

        # computing and saving heatmaps
        out = dict(id=[], slice=[], label=[], ad_map_fn=[], ad_mask_fn=[],
                   TP=[], TN=[], FP=[], FN=[], AUC=[], classifier_pred=[])
        for b, data in enumerate(loader):
            im, mask, id, slice = data
            im = im.to(cfg.device).float()
            mask = mask.to(cfg.device).float()

            # get heatmap
            heatmap = fcdd.generate_heatmap(im, reception=cfg.heatmap_param.reception, std=cfg.heatmap_param.std,
                                            cpu=cfg.heatmap_param.cpu)
            # scaling
            heatmap = ((heatmap - min_val) / (max_val - min_val)).clamp(0,1)
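            # min_val/max_val presumably come from the whole loader (fcdd.get_min_max above), so all
            # heatmaps share the same scale before being clamped to [0, 1]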

            # Threshold
            ad_mask = torch.where(heatmap >= cfg.heatmap_threshold, torch.ones_like(heatmap, device=heatmap.device),
                                                                    torch.zeros_like(heatmap, device=heatmap.device))

            # Compute CM
            tn, fp, fn, tp  = batch_binary_confusion_matrix(ad_mask, mask.to(heatmap.device))

            # Save heatmaps/mask
            map_fn, mask_fn = [], []
            for i in range(im.shape[0]):
                # Save AD Map
                ad_map_fn = f"{id[i]}/{slice[i]}_map_anomalies.png"
                save_path = os.path.join(out_path, 'pred/', ad_map_fn)
                if not os.path.isdir(os.path.dirname(save_path)):
                    os.makedirs(os.path.dirname(save_path), exist_ok=True)

                ad_map = heatmap[i].squeeze().cpu().numpy()
                io.imsave(save_path, img_as_ubyte(ad_map), check_contrast=False)
                # save ad_mask
                ad_mask_fn = f"{id[i]}/{slice[i]}_anomalies.bmp"
                save_path = os.path.join(out_path, 'pred/', ad_mask_fn)
                io.imsave(save_path, img_as_ubyte(ad_mask[i].squeeze().cpu().numpy()), check_contrast=False)

                map_fn.append(ad_map_fn)
                mask_fn.append(ad_mask_fn)

            # apply classifier ResNet-18
            if cfg.classifier_model_path is not None:
                pred_score = nn.functional.softmax(classifier(im), dim=1)[:,1] # take columns of softmax of positive class as score
                clss_pred = torch.where(pred_score >= cfg.classification_threshold, torch.ones_like(pred_score, device=pred_score.device),
                                                                                    torch.zeros_like(pred_score, device=pred_score.device))
            else:
                clss_pred = [None]*im.shape[0]

            # Save Values
            out['id'] += id.cpu().tolist()
            out['slice'] += slice.cpu().tolist()
            out['label'] += mask.reshape(mask.shape[0], -1).max(dim=1)[0].cpu().tolist()
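            # (a slice's label is 1 as soon as any pixel of its ground-truth mask is positive)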
            out['ad_map_fn'] += map_fn
            out['ad_mask_fn'] += mask_fn
            out['TN'] += tn.cpu().tolist()
            out['FP'] += fp.cpu().tolist()
            out['FN'] += fn.cpu().tolist()
            out['TP'] += tp.cpu().tolist()
            out['AUC'] += [roc_auc_score(mask[i].cpu().numpy().ravel(), heatmap[i].cpu().numpy().ravel()) if torch.any(mask[i]>0) else 'None' for i in range(im.shape[0])]
            out['classifier_pred'] += clss_pred.cpu().tolist() if torch.is_tensor(clss_pred) else clss_pred  # clss_pred is a plain list when no classifier is given

            if cfg.print_progress:
                print_progessbar(b, len(loader), Name='Heatmap Generation Batch', Size=100, erase=True)

    # make df and save as csv
    slice_df = pd.DataFrame(out)
    volume_df = slice_df[['id', 'label', 'TP', 'TN', 'FP', 'FN']].groupby('id').agg({'label':'max', 'TP':'sum', 'TN':'sum', 'FP':'sum', 'FN':'sum'})

    slice_df['Dice'] = (2*slice_df.TP + 1) / (2*slice_df.TP + slice_df.FP + slice_df.FN + 1)
    volume_df['Dice'] = (2*volume_df.TP + 1) / (2*volume_df.TP + volume_df.FP + volume_df.FN + 1)
    logger.info(f"Mean slice dice : {slice_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean volume dice : {volume_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean posiitve slice AUC {slice_df[slice_df.label == 1].AUC.mean(axis=0):.3f}")

    # Save Scores and Config
    slice_df.to_csv(os.path.join(out_path, 'slice_predictions.csv'))
    logger.info(f"Slice prediction csv saved at {os.path.join(out_path, 'slice_predictions.csv')}")
    volume_df.to_csv(os.path.join(out_path, 'volume_predictions.csv'))
    logger.info(f"Volume prediction csv saved at {os.path.join(out_path, 'volume_predictions.csv')}")
    cfg.device = str(cfg.device)
    with open(os.path.join(out_path, 'config.json'), 'w') as f:
        json.dump(cfg, f)
    logger.info(f"Config file saved at {os.path.join(out_path, 'config.json')}")
Code example #6
def main(config_path):
    """
    Train and evaluate a ResNet binary ICH classifier on the RSNA dataset.
    """
    # load the config file
    cfg = AttrDict.from_json_path(config_path)

    # Make Outputs directories
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)

    # Initialize random seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    if os.path.exists(os.path.join(out_path, f'checkpoint.pt')):
        logger.info('\n' + '#'*30 + f'\n Recovering Session \n' + '#'*30)
    logger.info(f"Experiment : {cfg.exp_name}")

    # Set number of thread
    if cfg.n_thread > 0: torch.set_num_threads(cfg.n_thread)
    # set device; if None, use the first available GPU (or the CPU)
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device(f'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(f"Device set to {cfg.device}. {torch.cuda.device_count()} GPU available, "
                f"{len(cfg.multi_gpu_id) if cfg.multi_gpu_id else 1} used.")

    # Load RSNA data csv
    df_rsna = pd.read_csv(os.path.join(cfg.path.data, 'slice_info.csv'), index_col=0)

    # Keep only the requested number of samples per class
    if cfg.dataset.n_data_0 >= 0:
        df_rsna_noICH = df_rsna[df_rsna.Hemorrhage == 0].sample(n=cfg.dataset.n_data_0, random_state=cfg.seed)
    else:
        df_rsna_noICH = df_rsna[df_rsna.Hemorrhage == 0]
    if cfg.dataset.n_data_1 >= 0:
        df_rsna_ICH = df_rsna[df_rsna.Hemorrhage == 1].sample(n=cfg.dataset.n_data_1, random_state=cfg.seed)
    else:
        df_rsna_ICH = df_rsna[df_rsna.Hemorrhage == 1]
    df_rsna = pd.concat([df_rsna_ICH, df_rsna_noICH], axis=0)

    # Split the data, keeping a few samples for evaluation in a stratified way
    train_df, test_df = train_test_split(df_rsna, test_size=cfg.dataset.frac_eval, stratify=df_rsna.Hemorrhage, random_state=cfg.seed)
    logger.info('\n' + str(get_split_summary_table(df_rsna, train_df, test_df)))

    # Make dataset : Train --> BinaryClassification, Test --> BinaryClassification
    train_RSNA_dataset = RSNA_dataset(train_df, cfg.path.data,
                                 augmentation_transform=[getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg.dataset.augmentation.train.items()],
                                 window=(cfg.data.win_center, cfg.data.win_width), output_size=cfg.data.size,
                                 mode='binary_classification')
    test_RSNA_dataset = RSNA_dataset(test_df, cfg.path.data,
                                 augmentation_transform=[getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in cfg.dataset.augmentation.eval.items()],
                                 window=(cfg.data.win_center, cfg.data.win_width), output_size=cfg.data.size,
                                 mode='binary_classification')

    logger.info(f"Data will be loaded from {cfg.path.data}.")
    logger.info(f"CT scans will be windowed on [{cfg.data.win_center-cfg.data.win_width/2} ; {cfg.data.win_center + cfg.data.win_width/2}]")
    logger.info(f"CT scans will be resized to {cfg.data.size}x{cfg.data.size}")
    logger.info(f"Training online data transformation: \n\n {str(train_RSNA_dataset.transform)}\n")
    logger.info(f"Evaluation online data transformation: \n\n {str(test_RSNA_dataset.transform)}\n")

    # Make Resnet Architecture
    resnet_network = getattr(resnet, cfg.net.resnet)(num_classes=cfg.net.num_classes, input_channels=cfg.net.input_channels)
    logger.info(f"Using a {cfg.net.resnet} architecture.")
    if cfg.multi_gpu_id is not None and len(cfg.multi_gpu_id) > 1: # set network for multi-GPU
        resnet_network = torch.nn.DataParallel(resnet_network, device_ids=cfg.multi_gpu_id)
        logger.info("Enabling the resnet for multi-GPU computation.")
    resnet_network = resnet_network.to(cfg.device)
    logger.info(f"The {cfg.net.resnet} has {sum(p.numel() for p in resnet_network.parameters())} parameters.")

    # Make model
    cfg.train.model_param.lr_scheduler = getattr(torch.optim.lr_scheduler, cfg.train.model_param.lr_scheduler) # convert scheduler name to scheduler class object
    cfg.train.model_param.loss_fn = getattr(torch.nn, cfg.train.model_param.loss_fn) # convert loss_fn name to nn.Module class object
    w_ICH = train_df.Hemorrhage.sum() / len(train_df) # define CE weighting from train dataset
    cfg.train.model_param.loss_fn_kwargs['weight'] = torch.tensor([1 - w_ICH, w_ICH], device=cfg.device).float() # add weighting to CE kwargs

    classifier = BinaryClassifier(resnet_network, device=cfg.device, print_progress=cfg.print_progress, **cfg.train.model_param)

    train_params = [f"--> {k} : {v}" for k, v in cfg.train.model_param.items()]
    logger.info("Classifer Training Parameters \n\t" + "\n\t".join(train_params))

    # Load weights if specified
    if cfg.train.model_path_to_load:
        model_path = cfg.train.model_path_to_load
        classifier.load_model(model_path, map_location=cfg.device)
        logger.info(f"Classifer Model succesfully loaded from {cfg.train.model_path_to_load}")

    # Train
    if cfg.train.model_param.n_epoch > 0:
        classifier.train(train_RSNA_dataset, valid_dataset=test_RSNA_dataset,
                        checkpoint_path=os.path.join(out_path, f'checkpoint.pt'))

    # Evaluate
    auc, acc, recall, precision, f1 = classifier.evaluate(test_RSNA_dataset, save_tsne=False, return_scores=True)
    logger.info(f"Classifier Test AUC : {auc:.2%}")
    logger.info(f"Classifier Test Accuracy : {acc:.2%}")
    logger.info(f"Classifier Test Recall : {recall:.2%}")
    logger.info(f"Classifier Test Precision : {precision:.2%}")
    logger.info(f"Classifier Test F1-score : {f1:.2%}")

    # save model, outputs
    classifier.save_model(os.path.join(out_path, 'resnet.pt'))
    logger.info(f"{cfg.net.resnet} saved at " + os.path.join(out_path, 'resnet.pt'))
    classifier.save_model_state_dict(os.path.join(out_path, 'resnet_state_dict.pt'))
    logger.info(f"{cfg.net.resnet} state dict saved at " + os.path.join(out_path, 'resnet_state_dict.pt'))
    classifier.save_outputs(os.path.join(out_path, 'outputs.json'))
    logger.info("Classifier outputs saved at " + os.path.join(out_path, 'outputs.json'))
    test_df.reset_index(drop=True).to_csv(os.path.join(out_path, 'eval_data_info.csv'))
    logger.info("Evaluation data info saved at " + os.path.join(out_path, 'eval_data_info.csv'))

    # delete any checkpoints
    if os.path.exists(os.path.join(out_path, f'checkpoint.pt')):
        os.remove(os.path.join(out_path, f'checkpoint.pt'))
        logger.info('Checkpoint deleted.')

    cfg.device = str(cfg.device)
    cfg.train.model_param.lr_scheduler = str(cfg.train.model_param.lr_scheduler)
    cfg.train.model_param.loss_fn = str(cfg.train.model_param.loss_fn)
    cfg.train.model_param.loss_fn_kwargs.weight = cfg.train.model_param.loss_fn_kwargs.weight.tolist()  # tensor -> plain list so the config stays JSON serializable
    with open(os.path.join(out_path, 'config.json'), 'w') as f:
        json.dump(cfg, f)
    logger.info(f"Config file saved at {os.path.join(out_path, 'config.json')}")
Code example #7
def main(config_path):
    """
    Train an Inpainting generator with gated convolution through a SN-PatchGAN training scheme. The generator is trained
    to inpaint non-ICH CT scans from the RSNA dataset.
    """
    # Load config file
    cfg = AttrDict.from_json_path(config_path)

    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)
    if cfg.train.validate_epoch:
        os.makedirs(os.path.join(out_path, 'valid_results/'), exist_ok=True)

    # initialize seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    if os.path.exists(os.path.join(out_path, f'checkpoint.pt')):
        logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' + '#' * 30)
    logger.info(f"Experiment : {cfg.exp_name}")

    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device(
            f'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(
        f"Device set to {cfg.device}. {torch.cuda.device_count()} GPU available, "
        f"{len(cfg.multi_gpu_id) if cfg.multi_gpu_id else 1} used.")

    # set n_thread
    if cfg.n_thread > 0: torch.set_num_threads(cfg.n_thread)

    #--------------------------------------------------------------------
    #                           Make Datasets
    #--------------------------------------------------------------------
    # load RSNA data & keep normal only & and sample the required number
    df_rsna = pd.read_csv(os.path.join(cfg.path.data, 'slice_info.csv'),
                          index_col=0)
    df_rsna_pos = df_rsna[df_rsna.Hemorrhage == 0]
    if cfg.dataset.n_sample >= 0:
        df_rsna_pos = df_rsna_pos.sample(n=cfg.dataset.n_sample,
                                         random_state=cfg.seed)
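    # (a negative cfg.dataset.n_sample keeps every normal slice, otherwise the set is subsampled)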
    # make dataset
    train_dataset = RSNA_Inpaint_dataset(
        df_rsna_pos,
        cfg.path.data,
        augmentation_transform=[
            getattr(tf, tf_name)(**tf_kwargs)
            for tf_name, tf_kwargs in cfg.dataset.augmentation.train.items()
        ],
        window=(cfg.dataset.win_center, cfg.dataset.win_width),
        output_size=cfg.dataset.size,
        **cfg.dataset.mask)

    # load small valid subset and make dataset
    if cfg.train.validate_epoch:
        df_valid = pd.read_csv(os.path.join(cfg.path.data_valid, 'info.csv'),
                               index_col=0)
        valid_dataset = ImgMaskDataset(
            df_valid,
            cfg.path.data_valid,
            augmentation_transform=[
                getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                cfg.dataset.augmentation.eval.items()
            ],
            window=(cfg.dataset.win_center, cfg.dataset.win_width),
            output_size=cfg.dataset.size)
    else:
        valid_dataset = None

    logger.info(f"Train Data will be loaded from {cfg.path.data}.")
    logger.info(f"Train contains {len(train_dataset)} samples.")
    logger.info(f"Valid Data will be loaded from {cfg.path.data_valid}.")
    if valid_dataset:
        logger.info(f"Valid contains {len(valid_dataset)} samples.")
    logger.info(
        f"CT scans will be windowed on [{cfg.dataset.win_center-cfg.dataset.win_width/2} ; {cfg.dataset.win_center + cfg.dataset.win_width/2}]"
    )
    logger.info(
        f"CT scans will be resized to {cfg.dataset.size}x{cfg.dataset.size}")
    logger.info(
        f"Training online data transformation: \n\n {str(train_dataset.transform)}\n"
    )
    if valid_dataset:
        logger.info(
            f"Evaluation online data transformation: \n\n {str(valid_dataset.transform)}\n"
        )
    mask_params = [f"--> {k} : {v}" for k, v in cfg.dataset.mask.items()]
    logger.info("Train inpainting masks generated with \n\t" +
                "\n\t".join(mask_params))

    #--------------------------------------------------------------------
    #                           Make Networks
    #--------------------------------------------------------------------
    if 'context_attention' in cfg.net.gen:
        cfg.net.gen.context_attention_kwargs['device'] = cfg.device  # add device to kwargs of the contextual attention module
        generator_net = GatedGenerator(**cfg.net.gen)
    elif 'self_attention' in cfg.net.gen:
        generator_net = SAGatedGenerator(**cfg.net.gen)
    else:
        raise ValueError("cfg.net.gen must contain either 'context_attention' or 'self_attention' to select the generator variant.")

    discriminator_net = PatchDiscriminator(**cfg.net.dis)

    if cfg.multi_gpu_id is not None and len(
            cfg.multi_gpu_id) > 1:  # set network for multi-GPU
        #generator_net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(generator_net)
        generator_net = torch.nn.DataParallel(generator_net,
                                              device_ids=cfg.multi_gpu_id)
        #discriminator_net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator_net)
        discriminator_net = torch.nn.DataParallel(discriminator_net,
                                                  device_ids=cfg.multi_gpu_id)
        logger.info("Enabling multi-GPU computation.")

    gen_params = [f"--> {k} : {v}" for k, v in cfg.net.gen.items()]
    logger.info("Gated Generator Parameters \n\t" + "\n\t".join(gen_params))
    dis_params = [f"--> {k} : {v}" for k, v in cfg.net.dis.items()]
    logger.info("Gated Discriminator Parameters \n\t" +
                "\n\t".join(dis_params))

    #--------------------------------------------------------------------
    #                      Make Inpainting GAN model
    #--------------------------------------------------------------------
    cfg.train.model_param.lr_scheduler = getattr(
        torch.optim.lr_scheduler, cfg.train.model_param.lr_scheduler
    )  # convert scheduler name to scheduler class object
    gan_model = SNPatchGAN(generator_net,
                           discriminator_net,
                           print_progress=cfg.print_progress,
                           device=cfg.device,
                           **cfg.train.model_param)
    train_params = [f"--> {k} : {v}" for k, v in cfg.train.model_param.items()]
    logger.info("GAN Training Parameters \n\t" + "\n\t".join(train_params))

    # load models if provided
    if cfg.train.model_path_to_load.gen:
        gan_model.load_Generator(cfg.train.model_path_to_load.gen,
                                 map_location=cfg.device)
    if cfg.train.model_path_to_load.dis:
        gan_model.load_Discriminator(cfg.train.model_path_to_load.dis,
                                     map_location=cfg.device)

    #--------------------------------------------------------------------
    #                       Train SN-PatchGAN model
    #--------------------------------------------------------------------
    if cfg.train.model_param.n_epoch > 0:
        gan_model.train(train_dataset,
                        checkpoint_path=os.path.join(out_path,
                                                     'Checkpoint.pt'),
                        valid_dataset=valid_dataset,
                        valid_path=os.path.join(out_path, 'valid_results/'),
                        save_freq=cfg.train.valid_save_freq)

    #--------------------------------------------------------------------
    #                   Save outputs, models and config
    #--------------------------------------------------------------------
    # save models
    gan_model.save_models(export_fn=(os.path.join(out_path, 'generator.pt'),
                                     os.path.join(out_path,
                                                  'discriminator.pt')),
                          which='both')
    logger.info("Generator model saved at " +
                os.path.join(out_path, 'generator.pt'))
    logger.info("Discriminator model saved at " +
                os.path.join(out_path, 'discriminator.pt'))
    # save outputs
    gan_model.save_outputs(export_fn=os.path.join(out_path, 'outputs.json'))
    logger.info("Outputs file saved at " +
                os.path.join(out_path, 'outputs.json'))
    # save config file
    cfg.device = str(
        cfg.device)  # set device as string to be JSON serializable
    if 'context_attention' in cfg.net.gen:
        cfg.net.gen.context_attention_kwargs.device = str(
            cfg.net.gen.context_attention_kwargs.device)
    cfg.train.model_param.lr_scheduler = str(
        cfg.train.model_param.lr_scheduler)
    with open(os.path.join(out_path, 'config.json'), 'w') as fp:
        json.dump(cfg, fp)
    logger.info("Config file saved at " +
                os.path.join(out_path, 'config.json'))

    # delete any checkpoints
    if os.path.exists(os.path.join(out_path, f'Checkpoint.pt')):
        os.remove(os.path.join(out_path, f'Checkpoint.pt'))
        logger.info('Checkpoint deleted.')
Code example #8
def main(config_path):
    """
    Segment ICH using the anomaly inpainting approach on a whole dataset and compute slice- and volume-level Dice scores.
    """
    # load config
    cfg = AttrDict.from_json_path(config_path)

    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)

    # initialize seed
    if cfg.seed != -1:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # set device
    cfg.device = torch.device(
        cfg.device) if cfg.device else get_available_device()

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    logger.info(f"Experiment : {cfg.exp_name}")

    # get Dataset
    data_info_df = pd.read_csv(os.path.join(cfg.path.data, 'ct_info.csv'),
                               index_col=0)
    #data_info_df = data_info_df[sum([data_info_df.CT_fn.str.contains(s) for s in ['49/16', '51/39', '71/15', '71/22', '75/22']]) > 0]
    dataset = public_SegICH_Dataset2D(
        data_info_df,
        cfg.path.data,
        augmentation_transform=[
            getattr(tf, tf_name)(**tf_kwargs)
            for tf_name, tf_kwargs in cfg.data.augmentation.items()
        ],
        output_size=cfg.data.size,
        window=(cfg.data.win_center, cfg.data.win_width))

    # load inpainting model
    cfg_inpaint = AttrDict.from_json_path(cfg.inpainter_cfg_path)
    cfg_inpaint.net.gen.return_coarse = False
    inpaint_net = SAGatedGenerator(**cfg_inpaint.net.gen)
    loaded_state_dict = torch.load(cfg.inpainter_model_path,
                                   map_location=cfg.device)
    inpaint_net.load_state_dict(loaded_state_dict)
    inpaint_net = inpaint_net.to(cfg.device)  # inpainter is kept out of eval mode because its batch-norm statistics are not stabilized (a consequence of the GAN optimization)
    logger.info(f"Inpainter model successfully loaded from {cfg.inpainter_model_path}")

    # make AD inpainter Module
    ad_inpainter = InpaintAnomalyDetector(inpaint_net,
                                          device=cfg.device,
                                          **cfg.model_param)

    # Load Classifier
    if cfg.classifier_model_path is not None:
        cfg_classifier = AttrDict.from_json_path(
            os.path.join(cfg.classifier_model_path, 'config.json'))
        classifier = getattr(rn, cfg_classifier.net.resnet)(
            num_classes=cfg_classifier.net.num_classes,
            input_channels=cfg_classifier.net.input_channels)
        classifier_state_dict = torch.load(os.path.join(
            cfg.classifier_model_path, 'resnet_state_dict.pt'),
                                           map_location=cfg.device)
        classifier.load_state_dict(classifier_state_dict)
        classifier = classifier.to(cfg.device)
        classifier.eval()
        logger.info(f"ResNet classifier model successfully loaded from {os.path.join(cfg.classifier_model_path, 'resnet_state_dict.pt')}")

    # iterate over dataset
    all_pred = []
    for i, sample in enumerate(dataset):
        # unpack data
        image, target, id, slice = sample
        logger.info(
            "=" * 25 +
            f" SAMPLE {i+1:04}/{len(dataset):04} - Volume {id:03} Slice {slice:03} "
            + "=" * 25)

        # Classify sample
        if cfg.classifier_model_path is not None:
            with torch.no_grad():
                input_clss = image.unsqueeze(0).to(cfg.device).float()
                pred_score = nn.functional.softmax(
                    classifier(input_clss), dim=1
                )[:, 1]  # take columns of softmax of positive class as score
                pred = 1 if pred_score >= cfg.classification_threshold else 0
        else:
            pred = 1  # if no classifier is given, every slice is processed

        # process slice if classifier has detected Hemorrhage
        if pred == 1:
            logger.info("ICH detected. Computing anomaly mask through inpainting.")
            # Detect anomalies using the robust approach
            ad_mask, ano_map, intermediate_masks = robust_anomaly_detect(
                image,
                ad_inpainter,
                save_dir=None,
                verbose=True,
                return_intermediate=True,
                **cfg.robust_param)
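            # ad_mask: final binary anomaly mask ; ano_map: continuous anomaly map ;
            # intermediate_masks: per-pass masks of the robust detection (saved below)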
            # save ad_mask
            ad_mask_fn = f"{id}/{slice}_anomalies.bmp"
            save_path = os.path.join(out_path, 'pred/', ad_mask_fn)
            if not os.path.isdir(os.path.dirname(save_path)):
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
            io.imsave(save_path, img_as_ubyte(ad_mask), check_contrast=False)
            # save anomaly map
            ad_map_fn = f"{id}/{slice}_map_anomalies.png"
            save_path_map = os.path.join(out_path, 'pred/', ad_map_fn)
            io.imsave(save_path_map,
                      img_as_ubyte(ano_map),
                      check_contrast=False)
            # save intermediate masks
            inter_dir = os.path.join(out_path, f"pred/{id}/intermediate_masks/")
            os.makedirs(inter_dir, exist_ok=True)
            for j, m in enumerate(intermediate_masks):
                io.imsave(os.path.join(inter_dir, f"{slice}_anomalies_{j+1}.bmp"),
                          img_as_ubyte(m), check_contrast=False)
        else:
            logger.info(f"No ICH detected. Set the anomaly mask to zeros.")
            ad_mask = np.zeros_like(target[0].numpy())
            ad_mask_fn, ad_map_fn = 'None', 'None'

        # compute confusion matrix with target ICH mask
        tn, fp, fn, tp = confusion_matrix(target[0].numpy().ravel(),
                                          ad_mask.ravel(),
                                          labels=[0, 1]).ravel()
        # append to all_pred list
        all_pred.append({
            'id': id.item(),
            'slice': slice.item(),
            'label': target.max().item(),
            'TP': tp,
            'TN': tn,
            'FP': fp,
            'FN': fn,
            'ad_mask_fn': ad_mask_fn,
            'ad_map_fn': ad_map_fn
        })

    # make a dataframe of all predictions
    slice_df = pd.DataFrame(all_pred)
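    # pool the per-slice confusion counts per volume id so that the volume Dice below is computed on the summed pixel counts of each CT volume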
    volume_df = slice_df[['id', 'label', 'TP', 'TN', 'FP',
                          'FN']].groupby('id').agg({
                              'label': 'max',
                              'TP': 'sum',
                              'TN': 'sum',
                              'FP': 'sum',
                              'FN': 'sum'
                          })

    # Compute Dice and Volume Dice
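    # Dice = (2*TP + 1) / (2*TP + FP + FN + 1); the +1 (Laplace) smoothing avoids division by zero and scores an empty prediction on an empty ground truth as 1.0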
    slice_df['Dice'] = (2 * slice_df.TP + 1) / (2 * slice_df.TP + slice_df.FP +
                                                slice_df.FN + 1)
    volume_df['Dice'] = (2 * volume_df.TP + 1) / (
        2 * volume_df.TP + volume_df.FP + volume_df.FN + 1)
    logger.info(f"Mean slice dice : {slice_df.Dice.mean(axis=0):.3f}")
    logger.info(f"Mean volume dice : {volume_df.Dice.mean(axis=0):.3f}")

    # Save Scores and Config
    slice_df.to_csv(os.path.join(out_path, 'slice_predictions.csv'))
    logger.info(
        f"Slice prediction csv saved at {os.path.join(out_path, 'slice_predictions.csv')}"
    )
    volume_df.to_csv(os.path.join(out_path, 'volume_predictions.csv'))
    logger.info(
        f"Volume prediction csv saved at {os.path.join(out_path, 'volume_predictions.csv')}"
    )
    cfg.device = str(cfg.device)
    with open(os.path.join(out_path, 'config.json'), 'w') as f:
        json.dump(cfg, f)
    logger.info(
        f"Config file saved at {os.path.join(out_path, 'config.json')}")
def main(config_path):
    """
    Train an Auto-Encoder to reconstruct CT-scans from the RSNA dataset.
    """
    # Load config file
    cfg = AttrDict.from_json_path(config_path)

    # make outputs dir
    out_path = os.path.join(cfg.path.output, cfg.exp_name)
    os.makedirs(out_path, exist_ok=True)
    if cfg.train.validate_epoch:
        os.makedirs(os.path.join(out_path, 'valid_results/'), exist_ok=True)

    # initialize seed
    if cfg.seed != None:
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)
        torch.backends.cudnn.deterministic = True

    # initialize logger
    logger = initialize_logger(os.path.join(out_path, 'log.txt'))
    if os.path.exists(os.path.join(out_path, f'checkpoint.pt')):
        logger.info('\n' + '#' * 30 + f'\n Recovering Session \n' + '#' * 30)
    logger.info(f"Experiment : {cfg.exp_name}")

    # set device
    if cfg.device:
        cfg.device = torch.device(cfg.device)
    else:
        cfg.device = torch.device(
            f'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    logger.info(f"Device set to {cfg.device}.")

    #--------------------------------------------------------------------
    #                           Make Datasets
    #--------------------------------------------------------------------
    # load RSNA data & keep normal only & and sample the required number
    df_rsna = pd.read_csv(os.path.join(cfg.path.data, 'slice_info.csv'),
                          index_col=0)
    df_rsna_pos = df_rsna[df_rsna.Hemorrhage == 0]
    if cfg.dataset.n_sample > 0:
        df_rsna_pos = df_rsna_pos.sample(n=cfg.dataset.n_sample,
                                         random_state=cfg.seed)
    # split df to keep n_sample_valid for validation
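    # note: train_test_split with an integer test_size reserves exactly that many slices for validation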
    if cfg.train.validate_epoch and (cfg.dataset.n_sample_valid > 0):
        df_train, df_valid = train_test_split(
            df_rsna_pos,
            test_size=cfg.dataset.n_sample_valid,
            random_state=cfg.seed)
    else:
        df_train = df_rsna_pos

    # make dataset
    train_dataset = RSNA_dataset(
        df_train,
        cfg.path.data,
        mode='standard',
        augmentation_transform=[
            getattr(tf, tf_name)(**tf_kwargs)
            for tf_name, tf_kwargs in cfg.dataset.augmentation.train.items()
        ],
        window=(cfg.dataset.win_center, cfg.dataset.win_width),
        output_size=cfg.dataset.size)

    # load small valid subset and make dataset
    if cfg.train.validate_epoch and (cfg.dataset.n_sample_valid > 0):  # guard against an undefined df_valid when no validation samples were split off
        valid_dataset = RSNA_dataset(df_valid,
                                     cfg.path.data,
                                     mode='standard',
                                     augmentation_transform=[
                                         getattr(tf, tf_name)(**tf_kwargs)
                                         for tf_name, tf_kwargs in
                                         cfg.dataset.augmentation.eval.items()
                                     ],
                                     window=(cfg.dataset.win_center,
                                             cfg.dataset.win_width),
                                     output_size=cfg.dataset.size)
    else:
        valid_dataset = None

    logger.info(f"Data will be loaded from {cfg.path.data}.")
    logger.info(f"Train contains {len(train_dataset)} samples.")
    if valid_dataset:
        logger.info(f"Valid contains {len(valid_dataset)} samples.")
    logger.info(
        f"CT scans will be windowed on [{cfg.dataset.win_center-cfg.dataset.win_width/2} ; {cfg.dataset.win_center + cfg.dataset.win_width/2}]"
    )
    logger.info(
        f"CT scans will be resized to {cfg.dataset.size}x{cfg.dataset.size}")
    logger.info(
        f"Training online data transformation: \n\n {str(train_dataset.transform)}\n"
    )
    if valid_dataset:
        logger.info(
            f"Evaluation online data transformation: \n\n {str(valid_dataset.transform)}\n"
        )

    #--------------------------------------------------------------------
    #                           Make Networks
    #--------------------------------------------------------------------
    ae_net = AE_net(**cfg.net)

    ae_params = [f"--> {k} : {v}" for k, v in cfg.net.items()]
    logger.info("AE Parameters \n\t" + "\n\t".join(ae_params))

    #--------------------------------------------------------------------
    #                          Make AE model
    #--------------------------------------------------------------------
    cfg.train.model_param.lr_scheduler = getattr(
        torch.optim.lr_scheduler, cfg.train.model_param.lr_scheduler
    )  # convert scheduler name to scheduler class object
    ae_model = AE(ae_net,
                  print_progress=cfg.print_progress,
                  device=cfg.device,
                  **cfg.train.model_param)
    train_params = [f"--> {k} : {v}" for k, v in cfg.train.model_param.items()]
    logger.info("AE Training Parameters \n\t" + "\n\t".join(train_params))

    # load models if provided
    if cfg.train.model_path_to_load:
        ae_model.load_model(cfg.train.model_path_to_load,
                            map_location=cfg.device)

    #--------------------------------------------------------------------
    #                          Train AE model
    #--------------------------------------------------------------------
    if cfg.train.model_param.n_epoch > 0:
        ae_model.train(train_dataset,
                       checkpoint_path=os.path.join(out_path, 'Checkpoint.pt'),
                       valid_dataset=valid_dataset,
                       valid_path=os.path.join(out_path, 'valid_results/'),
                       valid_freq=cfg.train.valid_save_freq)

    #--------------------------------------------------------------------
    #                   Save outputs, models and config
    #--------------------------------------------------------------------
    # save models
    ae_model.save_model(export_fn=os.path.join(out_path, 'AE.pt'))
    logger.info("AE model saved at " + os.path.join(out_path, 'AE.pt'))
    # save outputs
    ae_model.save_outputs(export_fn=os.path.join(out_path, 'outputs.json'))
    logger.info("Outputs file saved at " +
                os.path.join(out_path, 'outputs.json'))
    # save config file
    cfg.device = str(
        cfg.device)  # set device as string to be JSON serializable
    cfg.train.model_param.lr_scheduler = str(
        cfg.train.model_param.lr_scheduler)
    with open(os.path.join(out_path, 'config.json'), 'w') as fp:
        json.dump(cfg, fp)
    logger.info("Config file saved at " +
                os.path.join(out_path, 'config.json'))