Example #1
0
def get_mask(json_opts, in_file_data, ori_shape, use_cuda: bool, model=None):
    """Run the mlebe segmentation model on one volume and return its mask.

    Parameters
    ----------
    json_opts : object
        Configuration object; ``json_opts.model`` (incl. ``output_nc``) and
        ``json_opts.augmentation`` are read.
    in_file_data : array-like
        Volume to segment. A trailing axis is appended before it is handed
        to the ``'bids'`` transformation.
    ori_shape : sequence
        ``ori_shape[0]`` is the length the leading axis of the returned
        mask and model input is cropped back to.
    use_cuda : bool
        Only consulted when no ``model`` is passed: if falsy,
        ``CUDA_VISIBLE_DEVICES`` is set to ``'-1'`` before the model is
        built.
    model : optional
        Already constructed model. When falsy, one is created from
        ``json_opts.model``.

    Returns
    -------
    tuple
        ``(in_file_data, mask_pred, model_input)``, each with axis 2 moved
        to the front; ``mask_pred`` and ``model_input`` are additionally
        cropped along the leading axis.
    """
    if not model:
        from mlebe.training.models import get_model
        # Hide the GPUs from the prediction (might be unnecessary).
        if not use_cuda:
            os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
        model = get_model(json_opts.model)

    transforms = get_dataset_transformation(
        'mlebe',
        opts=json_opts.augmentation,
        max_output_channels=json_opts.model.output_nc)
    preprocess = transforms['bids']()

    # Preprocess for the model, then prepend the batch dimension.
    batch = preprocess(np.expand_dims(in_file_data, -1)).unsqueeze(0)
    model.set_input(batch)
    model.test()

    # Prediction as an int16 numpy array without singleton axes.
    prediction = np.squeeze(model.pred_seg.cpu().byte().numpy())
    prediction = prediction.astype(np.int16)

    # Move axis 2 to the front (the original code calls this "z,x,y").
    mask_pred = np.moveaxis(prediction, 2, 0)
    in_file_data = np.moveaxis(in_file_data, 2, 0)
    model_input = np.moveaxis(np.squeeze(batch.cpu().numpy()), 2, 0)

    # Un-pad the leading axis back to the original length: drop
    # ceil(surplus / 2) entries at the start and keep ori_shape[0] entries.
    surplus = int(np.ceil(mask_pred.shape[0] - ori_shape[0]))
    first = int(np.ceil(surplus / 2.))
    window = slice(first, ori_shape[0] + first)

    return in_file_data, mask_pred[window, :, :], model_input[window, :, :]
Example #2
0
def get_traindata(json_opts, save_directory):
    """Export the labelled training slices as ``x_train.npy``/``y_train.npy``.

    The training split is iterated once (unshuffled, no worker processes).
    Every sample is passed through ``remove_black_images`` and then split
    along its last axis; slices whose label maximum is ``<= 0`` are
    skipped, the remaining image slices are min-max normalised to [0, 1]
    with OpenCV and stored together with their label slice.

    Parameters
    ----------
    json_opts : object
        Only ``json_opts.training.batchSize`` is read from this argument.
        NOTE(review): everything else (``model_json_opts``, ``ds_class``,
        ``template_dir``, ``ds_path``, ``split_opts``) is taken from
        module-level names -- confirm that this is intended.
    save_directory : str
        Directory the two ``.npy`` files are written to.
    """
    train_opts = json_opts.training
    ds_transform = get_dataset_transformation(
        'mlebe',
        opts=model_json_opts.augmentation,
        max_output_channels=model_json_opts.model.output_nc)

    train_dataset = ds_class(
        template_dir,
        ds_path,
        model_json_opts.data,
        split='train',
        save_dir=None,
        transform=ds_transform['train'],
        train_size=split_opts.train_size,
        test_size=split_opts.test_size,
        valid_size=split_opts.validation_size,
        split_seed=split_opts.seed,
        training_shape=model_json_opts.augmentation.mlebe.scale_size)
    train_loader = DataLoader(dataset=train_dataset,
                              num_workers=0,
                              batch_size=train_opts.batchSize,
                              shuffle=False)

    x_train, y_train = [], []
    for images, labels, indices in tqdm(train_loader,
                                        total=len(train_loader)):
        ids = train_dataset.get_ids(indices)
        for sample in range(len(ids)):
            image_vol = np.squeeze(images[sample].cpu().numpy())
            label_vol = np.squeeze(labels[sample].cpu().numpy())
            image_vol, label_vol = remove_black_images(
                image_vol.astype(np.float32), label_vol.astype(np.int16))

            for k in range(image_vol.shape[-1]):
                label_slice = label_vol[..., k]
                # slices without any positive label are not exported
                if np.max(label_slice) <= 0:
                    continue
                normalised = cv2.normalize(image_vol[..., k],
                                           None,
                                           alpha=0,
                                           beta=1,
                                           norm_type=cv2.NORM_MINMAX,
                                           dtype=cv2.CV_32F)
                x_train.append(normalised)
                y_train.append(label_slice)

    for file_name, payload in (('x_train.npy', x_train),
                               ('y_train.npy', y_train)):
        with open(os.path.join(save_directory, file_name), 'wb') as handle:
            np.save(handle, payload)
Example #3
0
def evaluate(config_path):
    """Evaluate the configured model on the test split and plot the extremes.

    For every test volume the model prediction is compared slice by slice
    with the (transformed) template mask using ``dice``. The
    ``sum(IMG_NBRs)`` non-black slices with the lowest / highest dice score
    are rendered into ``irsabi_test_<data_type>.pdf`` inside
    ``<model.save_dir>/irsabi_test``: left column target overlay, right
    column prediction overlay.

    Parameters
    ----------
    config_path : str
        Path to the json configuration file of the model.

    Returns
    -------
    tuple
        ``(mean, std)`` of the dice score over all slices of all volumes.
    """
    json_opts = json_file_to_pyobj(config_path)
    template_dir = json_opts.data.template_dir
    model = get_model(json_opts.model)
    save_path = os.path.join(model.save_dir, 'irsabi_test')
    mkdir(save_path)
    data_type = json_opts.data.data_type
    print(save_path)
    # shape of the images on which the classifier was trained:
    training_shape = json_opts.augmentation.mlebe.scale_size[:3]
    ds_class = get_dataset('mlebe_dataset')
    # define preprocessing transformations for model
    ds_transform = get_dataset_transformation(
        'mlebe',
        opts=json_opts.augmentation,
        max_output_channels=json_opts.model.output_nc)

    test_dataset = ds_class(template_dir,
                            json_opts.data.data_dir,
                            json_opts.data,
                            split='test',
                            transform=ds_transform['valid'],
                            train_size=None,
                            training_shape=training_shape)
    data_selection = test_dataset.data_selection
    transformer = ds_transform['valid']()

    temp = load_mask(template_dir)
    # every volume works on its own copy of the template mask
    mask_data = [copy.deepcopy(temp) for _ in range(len(data_selection))]

    # One dict per slice; the DataFrame is built once after the loop
    # (DataFrame.append is quadratic and no longer exists in pandas >= 2.0).
    rows = []
    # Per-volume arrays, kept so the plots below can show the slice that
    # belongs to the selected row instead of the last processed volume.
    images, targets, predictions = [], [], []
    for volume in tqdm(range(len(data_selection))):  # volume is an index
        entry = data_selection.iloc[volume]
        volume_name = entry['uid']
        # NOTE(review): get_data() is deprecated in recent nibabel
        # releases -- confirm the pinned version before upgrading.
        img = nib.load(entry['path']).get_data()
        target = mask_data[volume].get_data()

        if json_opts.data.with_arranged_mask:
            # set the mask to zero where the image is zero
            target = arrange_mask(img, target)

        # preprocess data for compatibility with model
        network_input = transformer(np.expand_dims(img, -1))
        target = np.squeeze(
            transformer(np.expand_dims(target, -1)).cpu().byte().numpy()
        ).astype(np.int16)
        # add dimension for batches
        network_input = network_input.unsqueeze(0)
        model.set_input(network_input)
        model.test()
        # predict
        mask_pred = np.squeeze(model.pred_seg.cpu().numpy())
        img = np.squeeze(network_input.numpy())
        # set image shape to z,x,y
        mask_pred = np.moveaxis(mask_pred, 2, 0)
        img = np.moveaxis(img, 2, 0)
        target = np.moveaxis(target, 2, 0)

        for slice_idx in range(img.shape[0]):
            rows.append({
                'volume_name': volume_name,
                'slice': slice_idx,
                'dice_score': dice(target[slice_idx], mask_pred[slice_idx]),
                'idx': volume,
                # black slices are skipped for the visualisation
                'black_slice': bool(np.max(img[slice_idx]) <= 0),
            })
        images.append(img)
        targets.append(target)
        predictions.append(mask_pred)

    dice_scores_df = pd.DataFrame(
        rows,
        columns=['volume_name', 'slice', 'dice_score', 'idx', 'black_slice'])

    # Worst half followed by best half of the non-black slices.
    n_plots = sum(IMG_NBRs)
    n_worst = n_plots // 2
    ranked = dice_scores_df.loc[
        ~dice_scores_df['black_slice'].astype(bool)
    ].sort_values(by=['dice_score'])
    min_df = pd.concat([ranked.head(n_worst),
                        ranked.tail(n_plots - n_worst)],
                       ignore_index=True)

    mean_dice = dice_scores_df['dice_score'].mean()
    pdf_path = os.path.join(save_path,
                            'irsabi_test_{}.pdf'.format(data_type))
    with PdfPages(pdf_path) as pdf:
        df_idx = 0

        for IMG_NBR in IMG_NBRs:
            plt.figure(figsize=(40, IMG_NBR * 10))
            plt.figtext(.5,
                        .9,
                        'Mean dice score of {}'.format(np.round(mean_dice, 4)),
                        fontsize=100,
                        ha='center')
            i = 1
            while i <= IMG_NBR * 2 and df_idx < len(min_df):
                row = min_df.iloc[df_idx]
                volume = int(row['idx'])
                slice_idx = int(row['slice'])
                dice_score = row['dice_score']
                background = images[volume][slice_idx]

                # left: ground truth overlay
                plt.subplot(IMG_NBR, 2, i)
                plt.imshow(background, cmap='gray')
                plt.imshow(targets[volume][slice_idx], cmap='Blues', alpha=0.6)
                plt.axis('off')
                i += 1
                # right: prediction overlay
                plt.subplot(IMG_NBR, 2, i)
                plt.imshow(background, cmap='gray')
                plt.imshow(predictions[volume][slice_idx],
                           cmap='Blues',
                           alpha=0.6)
                plt.title('Volume: {}, slice {}, dice {}'.format(
                    row['volume_name'], slice_idx, dice_score))
                plt.axis('off')
                i += 1
                df_idx += 1
            pdf.savefig()
            plt.close()

    # Summary figure written next to the save directory.
    plt.title('Dice score = {}'.format(mean_dice))
    plt.savefig('{}.pdf'.format(save_path), format='pdf')
    plt.close()

    return mean_dice, dice_scores_df['dice_score'].std()
Example #4
0
def train(json_filename, network_debug=False, experiment_config=None):
    """
    Main training function for the model.

    Trains until ``training.n_epochs`` is reached or the early stopper
    fires, then calls ``finalize`` and reports the best validation row.

    Parameters
    ----------
    json_filename : str
        Path to the json configuration file
    network_debug : bool (optional)
        If true, only the parameter count and forward/backward timings are
        printed and the interpreter exits.
    experiment_config : class used for logging (optional)
        Passed through to ``finalize``.

    Returns
    -------
    pandas.DataFrame
        Row(s) of the ``validation`` sheet of ``loss_log.xlsx`` with the
        minimal ``Seg_Loss``, extended by the columns ``irsabi_dice_mean``
        and ``irsabi_dice_std`` returned by ``finalize``.

    Raises
    ------
    Exception
        If the number of data channels does not match the model input
        channels and the last entry of the mlebe ``scale_size``.
    """

    # Load options
    json_opts = json_file_to_pyobj(json_filename)
    bigprint(f'New try with parameters: {json_opts}')
    train_opts = json_opts.training

    # Setup Dataset and Augmentation
    ds_class = get_dataset('mlebe_dataset')
    ds_path = json_opts.data.data_dir
    template_path = json_opts.data.template_dir
    ds_transform = get_dataset_transformation(
        'mlebe',
        opts=json_opts.augmentation,
        max_output_channels=json_opts.model.output_nc)

    # Setup channels
    channels = json_opts.data_opts.channels
    if len(channels) != json_opts.model.input_nc \
            or len(channels) != json_opts.augmentation.mlebe.scale_size[-1]:
        raise Exception(
            'Number of data channels must match number of model channels, and patch and scale size dimensions'
        )

    # Setup the NN Model
    model = get_model(json_opts.model)
    if json_filename == 'configs/test_config.json':
        # the test configuration always starts from an empty save directory
        print('removing dir ', model.save_dir)
        shutil.rmtree(model.save_dir)
        os.mkdir(model.save_dir)

    if network_debug:
        print('# of pars: ', model.get_number_parameters())
        print('fp time: {0:.3f} sec\tbp time: {1:.3f} sec per sample'.format(
            *model.get_fp_bp_time()))
        exit()

    # Setup Data Loader
    split_opts = json_opts.data_split
    data_opts = json_opts.data
    # keyword arguments shared by the three splits
    dataset_kwargs = dict(
        save_dir=model.save_dir,
        train_size=split_opts.train_size,
        test_size=split_opts.test_size,
        valid_size=split_opts.validation_size,
        split_seed=split_opts.seed,
        training_shape=json_opts.augmentation.mlebe.scale_size[:3])
    train_dataset = ds_class(template_path,
                             ds_path,
                             data_opts,
                             split='train',
                             transform=ds_transform['train'],
                             **dataset_kwargs)
    valid_dataset = ds_class(template_path,
                             ds_path,
                             data_opts,
                             split='validation',
                             transform=ds_transform['valid'],
                             **dataset_kwargs)
    test_dataset = ds_class(template_path,
                            ds_path,
                            data_opts,
                            split='test',
                            transform=ds_transform['valid'],
                            **dataset_kwargs)
    train_loader = DataLoader(dataset=train_dataset,
                              num_workers=1,
                              batch_size=train_opts.batchSize,
                              shuffle=True)
    valid_loader = DataLoader(dataset=valid_dataset,
                              num_workers=1,
                              batch_size=train_opts.batchSize,
                              shuffle=False)
    test_loader = DataLoader(dataset=test_dataset,
                             num_workers=1,
                             batch_size=train_opts.batchSize,
                             shuffle=False)

    # Visualisation Parameters
    visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)
    error_logger = ErrorLogger()

    # Training Function
    model.set_scheduler(train_opts)
    # Setup Early Stopping
    early_stopper = EarlyStopper(train_opts.early_stopping_patience)

    evaluation_splits = (
        (valid_loader, 'validation', valid_dataset),
        (test_loader, 'test', test_dataset),
    )

    for epoch in range(model.which_epoch, train_opts.n_epochs):
        print('(epoch: %d, total # iters: %d)' % (epoch, len(train_loader)))

        # Training Iterations
        for images, labels, indices in tqdm(train_loader,
                                            total=len(train_loader)):
            # Make a training update
            model.set_input(images, labels)
            model.optimize_parameters()

            # Error visualisation
            errors = model.get_current_errors()
            error_logger.update(errors, split='train')

            ids = train_dataset.get_ids(indices)
            volumes = model.get_current_volumes()
            visualizer.display_current_volumes(volumes, ids, 'train', epoch)

        # Validation and Testing Iterations
        for loader, split, dataset in evaluation_splits:
            for images, labels, indices in tqdm(loader, total=len(loader)):
                ids = dataset.get_ids(indices)

                # Make a forward pass with the model
                model.set_input(images, labels)
                model.validate()

                # Error visualisation
                errors = model.get_current_errors()
                stats = model.get_segmentation_stats()
                error_logger.update({**errors, **stats}, split=split)

                # Visualise predictions
                if split == 'validation':  # do not look at testing
                    volumes = model.get_current_volumes()
                    visualizer.display_current_volumes(volumes, ids, split,
                                                       epoch)

        current_loss = error_logger.get_errors('validation')['Seg_Loss']
        # Update best validation loss/epoch values
        model.update_validation_state(epoch, current_loss)
        early_stopper.update(model, epoch, current_loss)
        # Update the plots
        for split in ['train', 'validation', 'test']:
            split_errors = error_logger.get_errors(split)
            visualizer.plot_current_errors(epoch,
                                           split_errors,
                                           split_name=split)
            visualizer.print_current_errors(epoch,
                                            split_errors,
                                            split_name=split)
        visualizer.save_plots(epoch, save_frequency=5)
        error_logger.reset()

        # saving checkpoint
        if model.is_improving:
            print('saving model')
            # replacing old model with new model
            model.save(json_opts.model.model_type, epoch)

        # Update the model learning rate
        model.update_learning_rate(metric=current_loss)

        if early_stopper.should_stop_early:
            print('early stopping')
            break

    # get validation metrics; the log is looked up in the configured
    # checkpoints directory both after early stopping and after the last
    # epoch
    val_loss_log = pd.read_excel(os.path.join(json_opts.model.checkpoints_dir,
                                              json_opts.model.experiment_name,
                                              'loss_log.xlsx'),
                                 sheet_name='validation').iloc[:, 1:]

    irsabi_dice_mean, irsabi_dice_std = finalize(json_opts, json_filename,
                                                 model, experiment_config)

    val_loss_log['irsabi_dice_mean'] = irsabi_dice_mean
    val_loss_log['irsabi_dice_std'] = irsabi_dice_std
    return val_loss_log.loc[val_loss_log['Seg_Loss'] ==
                            val_loss_log['Seg_Loss'].min()]
Example #5
0
# Script fragment: builds the 'test' split of the mlebe dataset from the
# configuration of the trained anatomical masking model.
# `save_dir` and `config_path` are expected to be defined before this point.
mkdir(save_dir)

# The workflow config points at the model folder; the model's own training
# config is stored there as 'trained_mlebe_config_anat.json'.
workflow_json_opts = json_file_to_pyobj(config_path)
model_config_path = Path(workflow_json_opts.masking_config.masking_config_anat.
                         model_folder_path) / 'trained_mlebe_config_anat.json'
model_json_opts = json_file_to_pyobj(model_config_path)
data_dir = model_json_opts.data.data_dir
# hard-coded template location (not read from the config)
template_dir = '/usr/share/mouse-brain-atlases/'

ds_class = get_dataset('mlebe_dataset')
# same config value as `data_dir` above
ds_path = model_json_opts.data.data_dir
# `channels` and `train_opts` are not used in the visible part of this
# fragment -- presumably consumed further down or by functions that read
# module-level names; verify before removing.
channels = model_json_opts.data_opts.channels
split_opts = model_json_opts.data_split
train_opts = model_json_opts.training
ds_transform = get_dataset_transformation(
    'mlebe',
    opts=model_json_opts.augmentation,
    max_output_channels=model_json_opts.model.output_nc)

# Test split with the 'valid' transformation; nothing is saved to disk
# (save_dir=None) and the split sizes/seed come from the model config.
test_dataset = ds_class(
    template_dir,
    ds_path,
    model_json_opts.data,
    split='test',
    save_dir=None,
    transform=ds_transform['valid'],
    train_size=split_opts.train_size,
    test_size=split_opts.test_size,
    valid_size=split_opts.validation_size,
    split_seed=split_opts.seed,
    training_shape=model_json_opts.augmentation.mlebe.scale_size[:3])