def main(config_path, type, n_cluster, n_neighbors):
    """
    Ablation study on Contrastive (SimCLR) or AE representation.

    For every seed listed in the config: load the pretrained network,
    fit the NearestNeighbors ablation model on the MURA train split and
    evaluate it on the validation and test splits.  Logs, results and
    models are written to a timestamped output folder.

    Parameters
    ----------
    config_path : str
        Path to the JSON config file.
    type : str
        Representation to ablate: 'SimCLR' or 'AE'.  (The name shadows
        the builtin but is kept for backward compatibility with callers.)
    n_cluster : int
        Overrides cfg.settings['n_cluster'] when > 0.
    n_neighbors : int
        Overrides cfg.settings['n_neighbors'] when > 0.

    Raises
    ------
    ValueError
        If `type` is neither 'SimCLR' nor 'AE'.
    """
    # Load config file
    cfg = Config(settings=None)
    cfg.load_config(config_path)

    # Update n_cluster and n_neighbors if provided
    if n_cluster > 0:
        cfg.settings['n_cluster'] = n_cluster
    if n_neighbors > 0:
        cfg.settings['n_neighbors'] = n_neighbors

    # Get path to output
    OUTPUT_PATH = cfg.settings['PATH']['OUTPUT'] + cfg.settings[
        'Experiment_Name'] + datetime.today().strftime('%Y_%m_%d_%Hh%M') + '/'
    # make output dirs
    # BUG FIX: the existence check used 'models/' while the directory created
    # here and used by save_model below is 'model/'
    if not os.path.isdir(OUTPUT_PATH + 'model/'):
        os.makedirs(OUTPUT_PATH + 'model/', exist_ok=True)
    if not os.path.isdir(OUTPUT_PATH + 'results/'):
        os.makedirs(OUTPUT_PATH + 'results/', exist_ok=True)
    if not os.path.isdir(OUTPUT_PATH + 'logs/'):
        os.makedirs(OUTPUT_PATH + 'logs/', exist_ok=True)

    for seed_i, seed in enumerate(cfg.settings['seeds']):
        ############################### Set Up #################################
        # initialize logger; close/remove the previous seed's file handler so
        # each seed logs to its own file
        logging.basicConfig(level=logging.INFO)
        logger = logging.getLogger()
        try:
            logger.handlers[1].stream.close()
            logger.removeHandler(logger.handlers[1])
        except IndexError:
            pass
        logger.setLevel(logging.INFO)
        formatter = logging.Formatter(
            '%(asctime)s | %(levelname)s | %(message)s')
        log_file = OUTPUT_PATH + 'logs/' + f'log_{seed_i+1}.txt'
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

        # print path
        logger.info(f"Log file : {log_file}")
        logger.info(f"Data path : {cfg.settings['PATH']['DATA']}")
        logger.info(f"Outputs path : {OUTPUT_PATH}" + "\n")

        # Set seed (-1 disables deterministic seeding)
        if seed != -1:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            logger.info(
                f"Set seed {seed_i+1:02}/{len(cfg.settings['seeds']):02} to {seed}"
            )

        # set number of thread
        if cfg.settings['n_thread'] > 0:
            torch.set_num_threads(cfg.settings['n_thread'])

        # check if GPU available
        cfg.settings['device'] = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')
        # Print technical info in logger
        logger.info(f"Device : {cfg.settings['device']}")
        logger.info(f"Number of thread : {cfg.settings['n_thread']}")

        ############################### Split Data #############################
        # Load data informations (first unnamed index column is dropped)
        df_info = pd.read_csv(cfg.settings['PATH']['DATA_INFO'])
        df_info = df_info.drop(df_info.columns[0], axis=1)
        # remove low contrast images (all black)
        df_info = df_info[df_info.low_contrast == 0]

        # Train Validation Test Split (fixed random_state so the split is
        # identical across seeds)
        spliter = MURA_TrainValidTestSplitter(
            df_info,
            train_frac=cfg.settings['Split']['train_frac'],
            ratio_known_normal=cfg.settings['Split']['known_normal'],
            ratio_known_abnormal=cfg.settings['Split']['known_abnormal'],
            random_state=42)
        spliter.split_data(verbose=False)
        train_df = spliter.get_subset('train')
        valid_df = spliter.get_subset('valid')
        test_df = spliter.get_subset('test')

        # print info to logger
        for key, value in cfg.settings['Split'].items():
            logger.info(f"Split param {key} : {value}")
        logger.info("Split Summary \n" +
                    str(spliter.print_stat(returnTable=True)))

        ############################### Load model #############################
        if type == 'SimCLR':
            net = Encoder(MLP_Neurons_layer=[512, 256, 128])
            init_key = 'repr_net_dict'
        elif type == 'AE':
            net = AE_net(MLP_Neurons_layer_enc=[512, 256, 128],
                         MLP_Neurons_layer_dec=[128, 256, 512],
                         output_channels=1)
            init_key = 'ae_net_dict'
        else:
            # ROBUSTNESS FIX: fail fast instead of a NameError on `net` below
            raise ValueError(
                f"Unknown representation type '{type}'; expected 'SimCLR' or 'AE'."
            )

        net = net.to(cfg.settings['device'])
        pretrain_state_dict = torch.load(cfg.settings['model_path_name'] +
                                         f'{seed_i+1}.pt',
                                         map_location=cfg.settings['device'])
        net.load_state_dict(pretrain_state_dict[init_key])
        logger.info('Model weights successfully loaded from ' +
                    cfg.settings['model_path_name'] + f'{seed_i+1}.pt')

        ablation_model = NearestNeighbors(
            net,
            kmeans_reduction=cfg.settings['kmeans_reduction'],
            batch_size=cfg.settings['batch_size'],
            n_job_dataloader=cfg.settings['num_worker'],
            print_batch_progress=cfg.settings['print_batch_progress'],
            device=cfg.settings['device'])

        ############################### Training ###############################
        # make datasets (no augmentation for the ablation study)
        train_dataset = MURA_Dataset(
            train_df,
            data_path=cfg.settings['PATH']['DATA'],
            load_mask=True,
            load_semilabels=True,
            output_size=cfg.settings['Split']['img_size'],
            data_augmentation=False)
        valid_dataset = MURA_Dataset(
            valid_df,
            data_path=cfg.settings['PATH']['DATA'],
            load_mask=True,
            load_semilabels=True,
            output_size=cfg.settings['Split']['img_size'],
            data_augmentation=False)
        test_dataset = MURA_Dataset(
            test_df,
            data_path=cfg.settings['PATH']['DATA'],
            load_mask=True,
            load_semilabels=True,
            output_size=cfg.settings['Split']['img_size'],
            data_augmentation=False)

        logger.info("Online preprocessing pipeline : \n" +
                    str(train_dataset.transform) + "\n")

        # Train
        ablation_model.train(train_dataset,
                             n_cluster=cfg.settings['n_cluster'])

        # Evaluate
        logger.info(
            f'--- Validation with {cfg.settings["n_neighbors"]} neighbors')
        ablation_model.evaluate(valid_dataset,
                                n_neighbors=cfg.settings['n_neighbors'],
                                mode='valid')
        logger.info(f'--- Test with {cfg.settings["n_neighbors"]} neighbors')
        ablation_model.evaluate(test_dataset,
                                n_neighbors=cfg.settings['n_neighbors'],
                                mode='test')

        ############################## Save Results ############################
        ablation_model.save_results(OUTPUT_PATH +
                                    f'results/results_{seed_i+1}.json')
        logger.info("Results saved at " + OUTPUT_PATH +
                    f"results/results_{seed_i+1}.json")

        ablation_model.save_model(OUTPUT_PATH +
                                  f'model/ablation_{seed_i+1}.pt')
        # BUG FIX: the log message said 'abaltion_*.pt' while the file is
        # actually saved as 'ablation_*.pt'
        logger.info("model saved at " + OUTPUT_PATH +
                    f"model/ablation_{seed_i+1}.pt")

    # save config file (device must be stringified to be JSON-serializable)
    cfg.settings['device'] = str(cfg.settings['device'])
    cfg.save_config(OUTPUT_PATH + 'config.json')
    logger.info("Config saved at " + OUTPUT_PATH + "config.json")
Example #2
0
import numpy as np
from transformers import *

from src.data.Batcher import Batcher
from src.utils.Config import Config
from src.utils.util import device, ParseKwargs
from src.adapet import adapet
from src.eval.eval_model import dev_eval

if __name__ == "__main__":
    # NOTE(review): argparse, os and torch are used below but do not appear
    # in this file's visible import block — import them locally so the
    # script is runnable either way; confirm against the full file.
    import argparse
    import os
    import torch

    # CLI: trained-model directory + config file (+ optional key=value
    # overrides parsed by ParseKwargs into a dict)
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', "--model_dir", required=True)
    parser.add_argument('-c', "--config_file", required=True)
    parser.add_argument('-k',
                        '--kwargs',
                        nargs='*',
                        action=ParseKwargs,
                        default={})
    args = parser.parse_args()

    # Build config (mkdir=True creates the experiment directory)
    config = Config(args.config_file, args.kwargs, mkdir=True)

    # Tokenizer / batching pipeline for the configured dataset
    tokenizer = AutoTokenizer.from_pretrained(config.pretrained_weight)
    batcher = Batcher(config, tokenizer, config.dataset)
    dataset_reader = batcher.get_dataset_reader()

    # Restore the best checkpoint and run evaluation on the dev split
    model = adapet(config, tokenizer, dataset_reader).to(device)
    model.load_state_dict(
        torch.load(os.path.join(args.model_dir, "best_model.pt")))
    dev_eval(config, model, batcher, 0)
Example #3
0
import argparse

import matplotlib.pyplot as plt
import numpy as np

from src.data.Data import Data
from src.utils.Config import Config

# Parse the single positional CLI argument: path to the config file.
parser = argparse.ArgumentParser(description="")
parser.add_argument('config')
args = parser.parse_args()

# Load the config and build the data wrapper from its "data.*" section.
config = Config.from_file(args.config)
data = Data(config.get_with_prefix("data"))

# Validation dataset; its items are unpacked by the visualization loop below.
dataset = data.build_val_dataset()

for reference_images, reference_cam_poses, query_images, query_cam_poses, iou, room_ids, pose_transform, full_matching in dataset:
    fig = plt.figure()
    plt.imshow(np.concatenate((reference_images[0], query_images[0]), axis=1),
               extent=[0, data.image_size * 2, data.image_size, 0])

    lines = []

    def onclick(event):
        print('button=%d, x=%d, y=%d, xdata=%f, ydata=%f' %
              (event.button, event.x, event.y, event.xdata, event.ydata))
        x = int(event.xdata)
        y = int(event.ydata)
        if 0 <= x < 128 and 0 <= y < 128:
            for line in lines:
def main(**params):
    """
    Train and evaluate a LeNet5 classifier over several random seeds.

    The dataset (selected through ``dataset_name``; FashionMNIST and
    KMNIST get human-readable class lookup tables) is trained on once
    per seed.  Per-seed models, results and classification-sample
    figures are written to a timestamped sub-folder of ``exp_folder``,
    and the mean +/- 1.96 std train/test accuracies over all seeds are
    logged at the end.

    Keys read from ``params``: exp_folder, dataset_name, net_name,
    load_config, seeds, data_path, n_epochs, batch_size, num_workers,
    lr, lr_decay, device, optimizer_name, load_model.
    """
    # make output dir
    OUTPUT_PATH = params['exp_folder'] + params['dataset_name'] + '_' + \
                  params['net_name'] + '_' + datetime.today().strftime('%Y_%m_%d_%Hh%M')+'/'
    if not os.path.isdir(OUTPUT_PATH + 'model/'):
        os.makedirs(OUTPUT_PATH + 'model/', exist_ok=True)
    if not os.path.isdir(OUTPUT_PATH + 'results/'):
        os.makedirs(OUTPUT_PATH + 'results/', exist_ok=True)

    # create the config file
    cfg = Config(params)

    # set up the logging: console via basicConfig plus a LOG.txt file handler
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s')
    log_file = OUTPUT_PATH + 'LOG.txt'
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    logger.info(f'Export path : {OUTPUT_PATH}')

    # Load config if required (overrides the params-built settings)
    if params['load_config']:
        cfg.load_config(params['load_config'])
        logger.info(f'Config loaded from {params["load_config"]}')

    # fall back to CPU when no GPU is available
    if not torch.cuda.is_available():
        cfg.settings['device'] = 'cpu'

    logger.info('Config Parameters:')
    for key, value in cfg.settings.items():
        logger.info(f'|---- {key} : {value}')

    # loop over seeds; 'seeds' is stored as a string literal, e.g. "[1, 2]"
    train_acc_list, test_acc_list = [], []
    seeds = ast.literal_eval(cfg.settings['seeds'])

    for i, seed in enumerate(seeds):
        logger.info('-' * 25 + f' Training n°{i+1} ' + '-' * 25)
        # set the seed (-1 means leave RNGs unseeded)
        if seed != -1:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            logger.info(f'Set seed {i+1:02}/{len(seeds):02} to {seed}')

        # get dataset: train split with augmentation, test split without
        train_dataset = MNISTDataset(cfg.settings['dataset_name'],
                                     cfg.settings['data_path'],
                                     train=True,
                                     data_augmentation=True)

        test_dataset = MNISTDataset(cfg.settings['dataset_name'],
                                    cfg.settings['data_path'],
                                    train=False,
                                    data_augmentation=False)

        # define the LookUpTable for KMNIST and FashionMNIST
        LUT = None
        if cfg.settings['dataset_name'] == 'FashionMNIST':
            LUT = {
                0: 'T-Shirt/Top',
                1: 'Trouser',
                2: 'Pullover',
                3: 'Dress',
                4: 'Coat',
                5: 'Sandal',
                6: 'Shirt',
                7: 'Sneaker',
                8: 'Bag',
                9: 'Ankle boot'
            }
        elif cfg.settings['dataset_name'] == 'KMNIST':
            # chr(...) yields the hiragana character for each class label
            LUT = {
                0: chr(12362) + ' (a)',
                1: chr(12365) + ' (ki)',
                2: chr(12377) + ' (su)',
                3: chr(12388) + ' (tu)',
                4: chr(12394) + ' (na)',
                5: chr(12399) + ' (ha)',
                6: chr(12414) + ' (ma)',
                7: chr(12420) + ' (ya)',
                8: chr(12428) + ' (re)',
                9: chr(12434) + ' (wo)'
            }

        # get model and wrap it in its trainer
        net = LeNet5()
        LeNet = LeNet5_trainer(net,
                               n_epoch=cfg.settings['n_epochs'],
                               batch_size=cfg.settings['batch_size'],
                               num_workers=cfg.settings['num_workers'],
                               lr=cfg.settings['lr'],
                               lr_decay=cfg.settings['lr_decay'],
                               device=cfg.settings['device'],
                               optimizer=cfg.settings['optimizer_name'],
                               seed=seed)

        # Load model if required
        if cfg.settings['load_model']:
            LeNet.load_model(cfg.settings['load_model'])
            logger.info(f'Model loaded from {cfg.settings["load_model"]}')

        # train model, then evaluate it on the test set
        LeNet.train(train_dataset, test_dataset)
        LeNet.evaluate(test_dataset, last=True)

        # Save model and results
        LeNet.save_model(OUTPUT_PATH + f'model/model_{i+1}.pt')
        logger.info('Model saved at ' + OUTPUT_PATH + f'model/model_{i+1}.pt')
        LeNet.save_results(OUTPUT_PATH + f'results/results_{i+1}.json')
        logger.info('Results saved at ' + OUTPUT_PATH +
                    f'results/results_{i+1}.json')
        cfg.save_config(OUTPUT_PATH + 'config.json')
        logger.info('Config saved at ' + OUTPUT_PATH + 'config.json')

        train_acc_list.append(LeNet.train_acc)
        test_acc_list.append(LeNet.test_acc)

        # show results: a 5x10 grid of classified test samples saved as PDF
        show_samples(LeNet.test_pred,
                     test_dataset,
                     n=(5, 10),
                     save_path=OUTPUT_PATH +
                     f'results/classification_sample_{i+1}.pdf',
                     lut=LUT)

    # aggregate accuracies over replicates: mean +/- 1.96 std
    train_acc, test_acc = np.array(train_acc_list), np.array(test_acc_list)
    logger.info('\n' + '-' * 60)
    logger.info(
        f"Performance of {cfg.settings['net_name']} on {cfg.settings['dataset_name']} over {len(seeds)} replicates"
    )
    logger.info(
        f"|---- Train accuracy {train_acc.mean():.3%} +/- {1.96*train_acc.std():.3%}"
    )
    logger.info(
        f"|---- Test accuracy {test_acc.mean():.3%} +/- {1.96*test_acc.std():.3%}"
    )
def main(config_path):
    """
    Train and evaluate a 2D UNet on the public ICH dataset using the parameters specified on the JSON at the
    config_path. The evaluation is performed by k-fold cross-validation.

    Parameters
    ----------
    config_path : str
        Path to the JSON configuration file loaded into a Config object.

    Raises
    ------
    ValueError
        If cfg.settings['train']['model_path_to_load'] is neither a str
        nor a list.
    """
    # load config file
    cfg = Config(settings=None)
    cfg.load_config(config_path)

    # Make Output directories
    out_path = os.path.join(cfg.settings['path']['OUTPUT'],
                            cfg.settings['exp_name'])  # + '/'
    os.makedirs(out_path, exist_ok=True)
    for k in range(cfg.settings['split']['n_fold']):
        os.makedirs(os.path.join(out_path, f'Fold_{k+1}/pred/'), exist_ok=True)

    # Initialize random seed to given seed (-1 disables seeding)
    seed = cfg.settings['seed']
    if seed != -1:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

    # Load data csv (first unnamed index column of each csv is dropped)
    data_info_df = pd.read_csv(
        os.path.join(cfg.settings['path']['DATA'], 'ct_info.csv'))
    data_info_df = data_info_df.drop(data_info_df.columns[0], axis=1)
    patient_df = pd.read_csv(
        os.path.join(cfg.settings['path']['DATA'], 'patient_info.csv'))
    patient_df = patient_df.drop(patient_df.columns[0], axis=1)

    # Generate Cross-Val indices at the patient level
    skf = StratifiedKFold(n_splits=cfg.settings['split']['n_fold'],
                          shuffle=cfg.settings['split']['shuffle'],
                          random_state=seed)
    # iterate over folds and ensure that there are the same amount of ICH positive patient per fold --> Stratiffied CrossVal
    for k, (train_idx, test_idx) in enumerate(
            skf.split(patient_df.PatientNumber, patient_df.Hemorrhage)):
        # if fold results not already there (makes reruns resumable per fold)
        if not os.path.exists(
                os.path.join(out_path, f'Fold_{k+1}/outputs.json')):
            # initialize logger; drop a previous fold's file handler if any
            logging.basicConfig(level=logging.INFO)
            logger = logging.getLogger()
            try:
                logger.handlers[1].stream.close()
                logger.removeHandler(logger.handlers[1])
            except IndexError:
                pass
            logger.setLevel(logging.INFO)
            file_handler = logging.FileHandler(
                os.path.join(out_path, f'Fold_{k+1}/log.txt'))
            file_handler.setLevel(logging.INFO)
            file_handler.setFormatter(
                logging.Formatter('%(asctime)s | %(levelname)s | %(message)s'))
            logger.addHandler(file_handler)

            # a leftover checkpoint means a previous run stopped mid-fold
            if os.path.exists(
                    os.path.join(out_path, f'Fold_{k+1}/checkpoint.pt')):
                # IDIOM FIX: was an f-string with no placeholders
                logger.info('\n' + '#' * 30 + '\n Recovering Session \n' +
                            '#' * 30)

            logger.info(f"Experiment : {cfg.settings['exp_name']}")
            logger.info(
                f"Cross-Validation fold {k+1:02}/{cfg.settings['split']['n_fold']:02}"
            )

            # initialize nbr of thread
            if cfg.settings['n_thread'] > 0:
                torch.set_num_threads(cfg.settings['n_thread'])
            logger.info(f"Number of thread : {cfg.settings['n_thread']}")
            # check if GPU available
            #cfg.settings['device'] = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
            if cfg.settings['device'] is not None:
                cfg.settings['device'] = torch.device(cfg.settings['device'])
            else:
                # pick the CUDA device with the most free memory, else CPU
                if torch.cuda.is_available():
                    free_mem, device_idx = 0.0, 0
                    for d in range(torch.cuda.device_count()):
                        mem = torch.cuda.get_device_properties(
                            d).total_memory - torch.cuda.memory_allocated(d)
                        if mem > free_mem:
                            device_idx = d
                            free_mem = mem
                    cfg.settings['device'] = torch.device(f'cuda:{device_idx}')
                else:
                    cfg.settings['device'] = torch.device('cpu')
            logger.info(f"Device : {cfg.settings['device']}")

            # extract train and test DataFrames + print summary (n samples positive and negatives)
            train_df = data_info_df[data_info_df.PatientNumber.isin(
                patient_df.loc[train_idx, 'PatientNumber'].values)]
            test_df = data_info_df[data_info_df.PatientNumber.isin(
                patient_df.loc[test_idx, 'PatientNumber'].values)]
            # sample the dataframe to have more or less normal slices
            n_remove = int(
                max(
                    0,
                    len(train_df[train_df.Hemorrhage == 0]) -
                    cfg.settings['dataset']['frac_negative'] *
                    len(train_df[train_df.Hemorrhage == 1])))
            df_remove = train_df[train_df.Hemorrhage == 0].sample(
                n=n_remove, random_state=seed)
            train_df = train_df[~train_df.index.isin(df_remove.index)]
            logger.info(
                '\n' +
                str(get_split_summary_table(data_info_df, train_df, test_df)))

            # Make Dataset + print online augmentation summary
            train_dataset = public_SegICH_Dataset2D(
                train_df,
                cfg.settings['path']['DATA'],
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg.settings['data']['augmentation']['train'].items()
                ],
                window=(cfg.settings['data']['win_center'],
                        cfg.settings['data']['win_width']),
                output_size=cfg.settings['data']['size'])
            test_dataset = public_SegICH_Dataset2D(
                test_df,
                cfg.settings['path']['DATA'],
                augmentation_transform=[
                    getattr(tf, tf_name)(**tf_kwargs) for tf_name, tf_kwargs in
                    cfg.settings['data']['augmentation']['eval'].items()
                ],
                window=(cfg.settings['data']['win_center'],
                        cfg.settings['data']['win_width']),
                output_size=cfg.settings['data']['size'])
            logger.info(
                f"Data will be loaded from {cfg.settings['path']['DATA']}.")
            logger.info(
                f"CT scans will be windowed on [{cfg.settings['data']['win_center']-cfg.settings['data']['win_width']/2} ; {cfg.settings['data']['win_center'] + cfg.settings['data']['win_width']/2}]"
            )
            logger.info(
                f"Training online data transformation: \n\n {str(train_dataset.transform)}\n"
            )
            logger.info(
                f"Evaluation online data transformation: \n\n {str(test_dataset.transform)}\n"
            )

            # Make architecture (and print summmary ??)
            unet_arch = UNet(
                depth=cfg.settings['net']['depth'],
                top_filter=cfg.settings['net']['top_filter'],
                use_3D=cfg.settings['net']['3D'],
                in_channels=cfg.settings['net']['in_channels'],
                out_channels=cfg.settings['net']['out_channels'],
                bilinear=cfg.settings['net']['bilinear'],
                midchannels_factor=cfg.settings['net']['midchannels_factor'],
                p_dropout=cfg.settings['net']['p_dropout'])
            unet_arch.to(cfg.settings['device'])
            logger.info(
                f"U-Net2D initialized with a depth of {cfg.settings['net']['depth']}"
                f" and a number of initial filter of {cfg.settings['net']['top_filter']},"
            )
            logger.info(
                f"Reconstruction performed with {'Upsample + Conv' if cfg.settings['net']['bilinear'] else 'ConvTranspose'}."
            )
            logger.info(
                f"U-Net2D takes {cfg.settings['net']['in_channels']} as input channels and {cfg.settings['net']['out_channels']} as output channels."
            )
            logger.info(
                f"The U-Net2D has {sum(p.numel() for p in unet_arch.parameters())} parameters."
            )

            # Make model (scheduler and loss are resolved by name via getattr)
            unet2D = UNet2D(
                unet_arch,
                n_epoch=cfg.settings['train']['n_epoch'],
                batch_size=cfg.settings['train']['batch_size'],
                lr=cfg.settings['train']['lr'],
                lr_scheduler=getattr(torch.optim.lr_scheduler,
                                     cfg.settings['train']['lr_scheduler']),
                lr_scheduler_kwargs=cfg.settings['train']
                ['lr_scheduler_kwargs'],
                loss_fn=getattr(src.models.optim.LossFunctions,
                                cfg.settings['train']['loss_fn']),
                loss_fn_kwargs=cfg.settings['train']['loss_fn_kwargs'],
                weight_decay=cfg.settings['train']['weight_decay'],
                num_workers=cfg.settings['train']['num_workers'],
                device=cfg.settings['device'],
                print_progress=cfg.settings['print_progress'])

            # Load model if required: a str is one path for all folds,
            # a list is one path per fold
            if cfg.settings['train']['model_path_to_load']:
                if isinstance(cfg.settings['train']['model_path_to_load'],
                              str):
                    model_path = cfg.settings['train']['model_path_to_load']
                    unet2D.load_model(model_path,
                                      map_location=cfg.settings['device'])
                elif isinstance(cfg.settings['train']['model_path_to_load'],
                                list):
                    model_path = cfg.settings['train']['model_path_to_load'][k]
                    unet2D.load_model(model_path,
                                      map_location=cfg.settings['device'])
                else:
                    # IDIOM FIX: was an f-string with no placeholders
                    raise ValueError(
                        'Model path to load type not understood.')
                logger.info(f"2D U-Net model loaded from {model_path}")

            # print Training hyper-parameters
            train_params = []
            for key, value in cfg.settings['train'].items():
                train_params.append(f"--> {key} : {value}")
            logger.info('Training settings:\n\t' + '\n\t'.join(train_params))

            # Train model (per-epoch validation only when requested)
            eval_dataset = test_dataset if cfg.settings['train'][
                'validate_epoch'] else None
            unet2D.train(train_dataset,
                         valid_dataset=eval_dataset,
                         checkpoint_path=os.path.join(
                             out_path, f'Fold_{k+1}/checkpoint.pt'))

            # Evaluate model
            unet2D.evaluate(test_dataset,
                            save_path=os.path.join(out_path,
                                                   f'Fold_{k+1}/pred/'))

            # Save models & outputs
            unet2D.save_model(
                os.path.join(out_path, f'Fold_{k+1}/trained_unet.pt'))
            logger.info("Trained U-Net saved at " +
                        os.path.join(out_path, f'Fold_{k+1}/trained_unet.pt'))
            unet2D.save_outputs(
                os.path.join(out_path, f'Fold_{k+1}/outputs.json'))
            logger.info("Trained statistics saved at " +
                        os.path.join(out_path, f'Fold_{k+1}/outputs.json'))

            # delete checkpoint if exists (fold finished successfully)
            if os.path.exists(
                    os.path.join(out_path, f'Fold_{k+1}/checkpoint.pt')):
                os.remove(os.path.join(out_path, f'Fold_{k+1}/checkpoint.pt'))
                logger.info('Checkpoint deleted.')

    # save mean +/- 1.96 std Dice in .txt file
    scores_list = []
    for k in range(cfg.settings['split']['n_fold']):
        with open(os.path.join(out_path, f'Fold_{k+1}/outputs.json'),
                  'r') as f:
            out = json.load(f)
        scores_list.append(
            [out['eval']['dice']['all'], out['eval']['dice']['positive']])
    means = np.array(scores_list).mean(axis=0)
    CI95 = 1.96 * np.array(scores_list).std(axis=0)
    with open(os.path.join(out_path, 'average_scores.txt'), 'w') as f:
        f.write(f'Dice = {means[0]} +/- {CI95[0]}\n')
        f.write(f'Dice (Positive) = {means[1]} +/- {CI95[1]}\n')
    logger.info('Average Scores saved at ' +
                os.path.join(out_path, 'average_scores.txt'))

    # generate dataframe of all prediction
    df_list = [
        pd.read_csv(
            os.path.join(out_path,
                         f'Fold_{i+1}/pred/volume_prediction_scores.csv'))
        for i in range(cfg.settings['split']['n_fold'])
    ]
    all_df = pd.concat(df_list, axis=0).reset_index(drop=True)
    all_df.to_csv(os.path.join(out_path, 'all_volume_prediction.csv'))
    logger.info('CSV of all volumes prediction saved at ' +
                os.path.join(out_path, 'all_volume_prediction.csv'))

    # Save config file (device stringified to be JSON-serializable)
    cfg.settings['device'] = str(cfg.settings['device'])
    cfg.save_config(os.path.join(out_path, 'config.json'))
    logger.info("Config file saved at " +
                os.path.join(out_path, 'config.json'))

    # Analyse results
    analyse_supervised_exp(out_path,
                           cfg.settings['path']['DATA'],
                           cfg.settings['split']['n_fold'],
                           save_fn=os.path.join(out_path,
                                                'results_overview.pdf'))
    logger.info('Results overview figure saved at ' +
                os.path.join(out_path, 'results_overview.pdf'))
Example #6
0
    def setUpClass(self):
        # NOTE(review): unittest convention is `@classmethod def setUpClass(cls)`;
        # this is written as an instance method taking `self` — confirm how the
        # enclosing framework invokes it before changing.
        # Shared fixtures: config, the sodocu project path, and a file handler.
        self.config = Config()
        self.project_sodocu_path = self.config.get_sodocu_path()
#         print 'sodocu path: ' + self.project_sodocu_path
        self.fileHandler = FileHandler(self.config)
Example #7
0
def main(config_path):
    """
    Train a DSAD on the MURA dataset using a SimCLR pretraining.

    For every seed listed in the config: train (or load) a SimCLR network,
    evaluate it to get embeddings, transfer its encoder to a DSAD network,
    then train and evaluate the DSAD. Models and results are saved per seed
    under a timestamped output directory.

    Parameters
    ----------
    config_path : str
        Path to the JSON configuration file describing the experiment.
    """
    # Load config file
    cfg = Config(settings=None)
    cfg.load_config(config_path)

    # Get path to output (timestamped so repeated runs do not collide)
    OUTPUT_PATH = cfg.settings['PATH']['OUTPUT'] + cfg.settings[
        'Experiment_Name'] + datetime.today().strftime('%Y_%m_%d_%Hh%M') + '/'
    # Make output dirs. The save calls below write into 'model/' (singular),
    # so that is the directory that must exist; the original code tested for
    # 'models/' but created 'model/'. exist_ok makes the isdir checks redundant.
    os.makedirs(OUTPUT_PATH + 'model/', exist_ok=True)
    os.makedirs(OUTPUT_PATH + 'results/', exist_ok=True)
    os.makedirs(OUTPUT_PATH + 'logs/', exist_ok=True)

    for seed_i, seed in enumerate(cfg.settings['seeds']):
        ############################### Set Up #################################
        # initialize logger
        logging.basicConfig(level=logging.INFO)
        logger = logging.getLogger()
        # Drop the file handler left over from the previous seed iteration
        # (handler 0 is the console handler installed by basicConfig).
        try:
            logger.handlers[1].stream.close()
            logger.removeHandler(logger.handlers[1])
        except IndexError:
            pass
        logger.setLevel(logging.INFO)
        formatter = logging.Formatter(
            '%(asctime)s | %(levelname)s | %(message)s')
        log_file = OUTPUT_PATH + 'logs/' + f'log_{seed_i+1}.txt'
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

        # print path
        logger.info(f"Log file : {log_file}")
        logger.info(f"Data path : {cfg.settings['PATH']['DATA']}")
        logger.info(f"Outputs path : {OUTPUT_PATH}" + "\n")

        # Set seed for reproducibility (-1 disables seeding)
        if seed != -1:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            logger.info(
                f"Set seed {seed_i+1:02}/{len(cfg.settings['seeds']):02} to {seed}"
            )

        # set number of thread
        if cfg.settings['n_thread'] > 0:
            torch.set_num_threads(cfg.settings['n_thread'])

        # check if GPU available
        cfg.settings['device'] = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')
        # Print technical info in logger
        logger.info(f"Device : {cfg.settings['device']}")
        logger.info(f"Number of thread : {cfg.settings['n_thread']}")

        ############################### Split Data #############################
        # Load data informations
        df_info = pd.read_csv(cfg.settings['PATH']['DATA_INFO'])
        # first column is an unnamed index column from the CSV export
        df_info = df_info.drop(df_info.columns[0], axis=1)
        # remove low contrast images (all black)
        df_info = df_info[df_info.low_contrast == 0]

        # Train Validation Test Split
        spliter = MURA_TrainValidTestSplitter(
            df_info,
            train_frac=cfg.settings['Split']['train_frac'],
            ratio_known_normal=cfg.settings['Split']['known_normal'],
            ratio_known_abnormal=cfg.settings['Split']['known_abnormal'],
            random_state=42)
        spliter.split_data(verbose=False)
        train_df = spliter.get_subset('train')
        valid_df = spliter.get_subset('valid')
        test_df = spliter.get_subset('test')

        # print info to logger
        for key, value in cfg.settings['Split'].items():
            logger.info(f"Split param {key} : {value}")
        logger.info("Split Summary \n" +
                    str(spliter.print_stat(returnTable=True)))

        ############################# Build Model  #############################
        # make networks
        net_CLR = SimCLR_net(
            MLP_Neurons_layer=cfg.settings['SimCLR']['MLP_head'])
        net_CLR = net_CLR.to(cfg.settings['device'])
        net_DSAD = SimCLR_net(
            MLP_Neurons_layer=cfg.settings['DSAD']['MLP_head'])
        net_DSAD = net_DSAD.to(cfg.settings['device'])
        # print network architecture
        net_architecture = summary_string(
            net_CLR, (1, cfg.settings['Split']['img_size'],
                      cfg.settings['Split']['img_size']),
            batch_size=cfg.settings['SimCLR']['batch_size'],
            device=str(cfg.settings['device']))
        logger.info("SimCLR net architecture: \n" + net_architecture + '\n')
        net_architecture = summary_string(
            net_DSAD, (1, cfg.settings['Split']['img_size'],
                       cfg.settings['Split']['img_size']),
            batch_size=cfg.settings['DSAD']['batch_size'],
            device=str(cfg.settings['device']))
        logger.info("DSAD net architecture: \n" + net_architecture + '\n')

        # make model combining both networks
        clr_DSAD = SimCLR_DSAD(net_CLR,
                               net_DSAD,
                               tau=cfg.settings['SimCLR']['tau'],
                               eta=cfg.settings['DSAD']['eta'])

        ############################# Train SimCLR #############################
        # make datasets
        train_dataset_CLR = MURADataset_SimCLR(
            train_df,
            data_path=cfg.settings['PATH']['DATA'],
            output_size=cfg.settings['Split']['img_size'],
            mask_img=True)
        valid_dataset_CLR = MURADataset_SimCLR(
            valid_df,
            data_path=cfg.settings['PATH']['DATA'],
            output_size=cfg.settings['Split']['img_size'],
            mask_img=True)
        test_dataset_CLR = MURADataset_SimCLR(
            test_df,
            data_path=cfg.settings['PATH']['DATA'],
            output_size=cfg.settings['Split']['img_size'],
            mask_img=True)

        logger.info("SimCLR Online preprocessing pipeline : \n" +
                    str(train_dataset_CLR.transform) + "\n")

        # Load model if required
        if cfg.settings['SimCLR']['model_path_to_load']:
            clr_DSAD.load_repr_net(
                cfg.settings['SimCLR']['model_path_to_load'],
                map_location=cfg.settings['device'])
            logger.info(
                f"SimCLR Model Loaded from {cfg.settings['SimCLR']['model_path_to_load']}"
                + "\n")

        # print Train parameters
        for key, value in cfg.settings['SimCLR'].items():
            logger.info(f"SimCLR {key} : {value}")

        # Train SimCLR
        # NOTE(review): train_SimCLR takes 'lr_milestones' while train_AD
        # below takes 'lr_milestone' — verify against the model API.
        clr_DSAD.train_SimCLR(
            train_dataset_CLR,
            valid_dataset=None,
            n_epoch=cfg.settings['SimCLR']['n_epoch'],
            batch_size=cfg.settings['SimCLR']['batch_size'],
            lr=cfg.settings['SimCLR']['lr'],
            weight_decay=cfg.settings['SimCLR']['weight_decay'],
            lr_milestones=cfg.settings['SimCLR']['lr_milestone'],
            n_job_dataloader=cfg.settings['SimCLR']['num_worker'],
            device=cfg.settings['device'],
            print_batch_progress=cfg.settings['print_batch_progress'])

        # Evaluate SimCLR to get embeddings
        clr_DSAD.evaluate_SimCLR(
            valid_dataset_CLR,
            batch_size=cfg.settings['SimCLR']['batch_size'],
            n_job_dataloader=cfg.settings['SimCLR']['num_worker'],
            device=cfg.settings['device'],
            print_batch_progress=cfg.settings['print_batch_progress'],
            set='valid')

        clr_DSAD.evaluate_SimCLR(
            test_dataset_CLR,
            batch_size=cfg.settings['SimCLR']['batch_size'],
            n_job_dataloader=cfg.settings['SimCLR']['num_worker'],
            device=cfg.settings['device'],
            print_batch_progress=cfg.settings['print_batch_progress'],
            set='test')

        # save repr net
        clr_DSAD.save_repr_net(OUTPUT_PATH + f'model/SimCLR_net_{seed_i+1}.pt')
        logger.info("SimCLR model saved at " + OUTPUT_PATH +
                    f"model/SimCLR_net_{seed_i+1}.pt")

        # save Results
        clr_DSAD.save_results(OUTPUT_PATH + f'results/results_{seed_i+1}.json')
        logger.info("Results saved at " + OUTPUT_PATH +
                    f"results/results_{seed_i+1}.json")

        ######################## Transfer Encoder Weight #######################
        # reuse the pretrained SimCLR encoder as the DSAD encoder
        clr_DSAD.transfer_encoder()

        ############################## Train DSAD ##############################
        # make dataset
        train_dataset_AD = MURA_Dataset(
            train_df,
            data_path=cfg.settings['PATH']['DATA'],
            load_mask=True,
            load_semilabels=True,
            output_size=cfg.settings['Split']['img_size'])
        valid_dataset_AD = MURA_Dataset(
            valid_df,
            data_path=cfg.settings['PATH']['DATA'],
            load_mask=True,
            load_semilabels=True,
            output_size=cfg.settings['Split']['img_size'])
        test_dataset_AD = MURA_Dataset(
            test_df,
            data_path=cfg.settings['PATH']['DATA'],
            load_mask=True,
            load_semilabels=True,
            output_size=cfg.settings['Split']['img_size'])

        logger.info("DSAD Online preprocessing pipeline : \n" +
                    str(train_dataset_AD.transform) + "\n")

        # Load model if required
        if cfg.settings['DSAD']['model_path_to_load']:
            clr_DSAD.load_AD(cfg.settings['DSAD']['model_path_to_load'],
                             map_location=cfg.settings['device'])
            logger.info(
                f"DSAD Model Loaded from {cfg.settings['DSAD']['model_path_to_load']} \n"
            )

        # print Train parameters
        for key, value in cfg.settings['DSAD'].items():
            logger.info(f"DSAD {key} : {value}")

        # Train DSAD
        clr_DSAD.train_AD(
            train_dataset_AD,
            valid_dataset=valid_dataset_AD,
            n_epoch=cfg.settings['DSAD']['n_epoch'],
            batch_size=cfg.settings['DSAD']['batch_size'],
            lr=cfg.settings['DSAD']['lr'],
            weight_decay=cfg.settings['DSAD']['weight_decay'],
            lr_milestone=cfg.settings['DSAD']['lr_milestone'],
            n_job_dataloader=cfg.settings['DSAD']['num_worker'],
            device=cfg.settings['device'],
            print_batch_progress=cfg.settings['print_batch_progress'])
        logger.info('--- Validation')
        clr_DSAD.evaluate_AD(
            valid_dataset_AD,
            batch_size=cfg.settings['DSAD']['batch_size'],
            n_job_dataloader=cfg.settings['DSAD']['num_worker'],
            device=cfg.settings['device'],
            print_batch_progress=cfg.settings['print_batch_progress'],
            set='valid')
        logger.info('--- Test')
        clr_DSAD.evaluate_AD(
            test_dataset_AD,
            batch_size=cfg.settings['DSAD']['batch_size'],
            n_job_dataloader=cfg.settings['DSAD']['num_worker'],
            device=cfg.settings['device'],
            print_batch_progress=cfg.settings['print_batch_progress'],
            set='test')

        # save DSAD
        clr_DSAD.save_AD(OUTPUT_PATH + f'model/DSAD_{seed_i+1}.pt')
        logger.info("model saved at " + OUTPUT_PATH +
                    f"model/DSAD_{seed_i+1}.pt")

        ########################## Save Results ################################
        # save Results
        clr_DSAD.save_results(OUTPUT_PATH + f'results/results_{seed_i+1}.json')
        logger.info("Results saved at " + OUTPUT_PATH +
                    f"results/results_{seed_i+1}.json")

    # save config file (device is a torch.device; stringify it for JSON)
    cfg.settings['device'] = str(cfg.settings['device'])
    cfg.save_config(OUTPUT_PATH + 'config.json')
    logger.info("Config saved at " + OUTPUT_PATH + "config.json")
Example #8
0
class TestFileHandler(unittest.TestCase):
    """Integration tests for FileHandler against the real sodocu directory.

    NOTE(review): some tests depend on files created by other tests (e.g.
    test_read_file reads the file written by test_create_file), so they
    rely on unittest's default alphabetical execution order.
    """

    @classmethod
    def setUpClass(self):
        # Build the shared fixtures once: project config and handler under test.
        self.config = Config()
        self.project_sodocu_path = self.config.get_sodocu_path()
#         print 'sodocu path: ' + self.project_sodocu_path
        self.fileHandler = FileHandler(self.config)


    @classmethod
    def tearDownClass(self):
        # Release the handler after the whole class has run.
        self.fileHandler = None


    def test_create_directory(self):
        """Creating a directory for an item must create <path>/idea."""
        idea = Idea(ItemType('idea', ''), 'idea-99', 'this is a file writer test')
        self.fileHandler.create_directory(idea)
        assert os.path.exists(self.project_sodocu_path + '/idea')


    def test_create_file(self):
        """Writing an item and reading it back must preserve its id."""
        idea1 = Idea(ItemType('idea', ''), 'idea-99', 'this is a file writer test')
#         print "hasattr(idea1, 'description'): " + str(hasattr(idea1, 'description'))
#         print "hasattr(idea1, 'inventedBy'): " + str(hasattr(idea1, 'inventedBy'))
#         print 'idea1.get_description(): ' + str(idea1.get_description())
#         print 'idea1.get_invented_by(): ' + str(idea1.get_invented_by())
        self.fileHandler.create_file(idea1)
#         print self.project_sodocu_path + '/sodocu/idea/ThisIsAFileWriterTest.txt'
        item_config = read_file(self.project_sodocu_path + '/idea/ThisIsAFileWriterTest.txt')
        idea2 = create_item(self.config, item_config, self.project_sodocu_path + '/idea/ThisIsAFileWriterTest.txt')
        assert idea1.get_id() == idea2.get_id()


    def test_read_file(self):
        """A previously written item file must parse with a 'meta' section."""
#         print self.project_sodocu_path + '/idea/useVCSforRequirements.txt'
        config = read_file(self.project_sodocu_path + '/idea/ThisIsAFileWriterTest.txt')
        assert 'meta' in config.sections()


    def test_read_file_failure(self):
        """Reading a non-existing file must raise ValueError."""
        with self.assertRaises(ValueError):
            read_file(self.project_sodocu_path + '/idea/ThisFileDoesNotExists.txt')


    def test_write_file_failure(self):
        """Writing into a non-existing directory must return False, not raise."""
        idea = Idea(ItemType('idea', ''), 'idea-66', 'idea-66')
        idea.set_filename('/no/directory/idea/Idea-66.txt')
#         with self.assertRaises(IOError):
        assert not self.fileHandler.write_file(idea)


    def test_update_file_new_item(self):
        """Updating an item with no existing file must create one."""
        idea = Idea(ItemType('idea', ''), 'idea-88', 'idea-88')
        assert self.fileHandler.update_file(idea)
        config = read_file(self.project_sodocu_path + '/idea/Idea-88.txt')
        assert 'idea' in config.sections()


    def test_update_file_item_non_existing_file(self):
        """Updating must fail when the item's filename points to a missing file."""
        idea = Idea(ItemType('idea', ''), 'idea-77', 'idea-77')
        idea.relations.add_invented_by('stakeholder-1')
        idea.relations.add_invented_by('stakeholder-2')
        assert self.fileHandler.update_file(idea)
        # point the item at a file that was never written, then rename it
        idea.set_filename(self.project_sodocu_path + '/idea/Idea-66.txt')
        idea.set_name('idea-55')
        assert not self.fileHandler.update_file(idea)


    def test_update_file_changed_item_name(self):
        """Renaming an item must move it to a new file while keeping its id."""
        item_config = read_file(self.project_sodocu_path + '/idea/ThisIsAFileWriterTest.txt')
        idea1 = create_item(self.config, item_config, self.project_sodocu_path + '/idea/ThisIsAFileWriterTest.txt')
        # NOTE(review): get_filename is not called here — this asserts on the
        # bound method object, which is always truthy; likely meant get_filename()
        assert idea1.get_filename is not None
        idea1.set_name('this is a file update test')
        self.fileHandler.update_file(idea1)
        item_config = read_file(self.project_sodocu_path + '/idea/ThisIsAFileUpdateTest.txt')
        idea2 = create_item(self.config, item_config, self.project_sodocu_path + '/idea/ThisIsAFileUpdateTest.txt')
        assert idea1.get_id() == idea2.get_id()
        assert idea1.get_name() == idea2.get_name()
        assert idea1.get_filename() != idea2.get_filename()
Example #9
0
import argparse
import os
import torch
import numpy as np
from transformers import *

from src.data.Batcher import Batcher
from src.utils.Config import Config
from src.utils.util import device
from src.adapet import adapet
from src.eval.eval_model import test_eval

if __name__ == "__main__":
    # CLI: -e/--exp_dir points at a finished experiment directory.
    parser = argparse.ArgumentParser()
    parser.add_argument('-e', "--exp_dir", required=True)
    args = parser.parse_args()

    # Re-load the experiment config saved alongside the checkpoint.
    config_file = os.path.join(args.exp_dir, "config.json")
    config = Config(config_file, mkdir=False)

    tokenizer = AutoTokenizer.from_pretrained(config.pretrained_weight)
    batcher = Batcher(config, tokenizer, config.dataset)
    dataset_reader = batcher.get_dataset_reader()

    model = adapet(config, tokenizer, dataset_reader).to(device)
    # map_location keeps the load working on CPU-only hosts even when the
    # checkpoint was written from a GPU run.
    model.load_state_dict(
        torch.load(os.path.join(args.exp_dir, "best_model.pt"),
                   map_location=device))
    test_eval(config, model, batcher)
Example #10
0
import argparse
import os

import torch
import numpy as np
from transformers import *

from src.data.Batcher import Batcher
from src.utils.get_model import get_model
from src.utils.Config import Config
from src.eval.eval_model import dev_eval

if __name__ == "__main__":
    # Parse the experiment directory from the command line.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('-e', "--exp_dir", required=True)
    cli_args = arg_parser.parse_args()

    # Load the saved experiment configuration and switch to dev evaluation.
    config = Config(os.path.join(cli_args.exp_dir, "config.json"), mkdir=False)
    config.eval_dev = True

    # Rebuild tokenizer, batcher and model exactly as during training.
    tokenizer = AutoTokenizer.from_pretrained(config.pretrained_weight)
    batcher = Batcher(config, tokenizer, config.dataset)
    model = get_model(config, tokenizer, batcher.get_dataset_reader())

    # Restore the latest checkpoint and run the dev-set evaluation.
    checkpoint = torch.load(os.path.join(cli_args.exp_dir, "cur_model.pt"))
    model.load_state_dict(checkpoint["model_state_dict"])
    dev_acc, dev_logits = dev_eval(config, model, batcher, 0)

    # Persist the dev-set logits next to the experiment artefacts.
    with open(os.path.join(config.exp_dir, "dev_logits.npy"), 'wb') as f:
        np.save(f, dev_logits)
Example #11
0
 def setUpClass(self):
     """Load the project configuration once for the whole test case."""
     self.config = Config()
Example #12
0
class Test(unittest.TestCase):

    @classmethod
    def setUpClass(self):
        """Load the shared Config fixture once before any test runs."""
        self.config = Config()
 
 
    @classmethod
    def tearDownClass(self):
        """Drop the shared Config fixture after all tests have run."""
        self.config = None


    def test_get_sodocu_path(self):
        """The configured sodocu path must point inside the SoDocu project."""
#         print self.config.get_sodocu_path()
        assert 'SoDocu' + os.sep + './sodocu' in self.config.get_sodocu_path()
        

    def test_is_valid_item_type_valid(self):
        """'idea' is a known item type and must be accepted."""
        assert self.config.is_valid_item_type('idea')
         
 
    def test_is_valid_item_type_invalid(self):
        """An unknown item type must be rejected."""
        # PEP 8: don't compare booleans with ==; assert the negation directly.
        assert not self.config.is_valid_item_type('invalid')
         
 
    def test_get_item_type_valid(self):
        """Looking up a known type by name must return that type."""
        assert self.config.get_item_type_by_name('idea').get_name() == 'idea'
         
 
    def test_get_item_type_exception(self):
#         with self.assertRaises(Exception):
#             self.config.get_item_type_by_name('invalid')
        assert self.config.get_item_type_by_name('invalid') == None
         
 
    def test_read_config(self):
        """Re-reading the config must (re)load the known item types."""
        self.config.read_config()
        assert self.config.is_valid_item_type('document')
    
    
    def test_get_item_types_as_string(self):
        """The string rendering of item types must mention 'idea'."""
#         print self.config.get_item_types_as_string()
        assert 'idea' in self.config.get_item_types_as_string()
        
 
    def test_get_item_types(self):
#         print self.config.get_item_types()
        assert 'idea' == self.config.get_item_types()[1].get_name()
        assert 'stakeholder' == self.config.get_item_types()[2].get_name()
        assert 'document' == self.config.get_item_types()[3].get_name()
        assert 'invented_by' in str(self.config.get_item_types()[0].get_valid_relations())