Ejemplo n.º 1
0
    def run_masking(self, data_type: str) -> None:
        """
        Run the mlebe masking model over every scan of one data type.

        Every row of ``self.data_selection`` whose ``datatype`` equals
        ``data_type`` is passed to ``self.mask_one``. The mask path and
        masked-image path it returns are written back into the ``mask_path``
        and ``masked_path`` columns of all rows sharing that row's ``path``.

        Parameters
        ----------
        data_type : either anat or func
        """
        opts = get_masking_opts(self.masking_config_path, data_type)

        # Fall back to the default models when the config leaves the model
        # folder missing or empty.
        has_model_folder = 'model_folder_path' in opts and opts['model_folder_path']
        if not has_model_folder:
            opts['model_folder_path'] = get_mlebe_models(data_type)
        cfg = get_model_config(opts)

        # build the network once and reuse it for every scan
        net = get_model(cfg.model)
        selection = self.data_selection
        subset = selection.loc[selection.datatype == data_type]
        for _, row in tqdm(subset.iterrows(), total=len(subset)):
            # 'anat' rows are masked from their own path; any other data type
            # uses the row's tmean_path (presumably a temporal-mean image).
            source = row.path if data_type == 'anat' else row.tmean_path
            mask_path, masked_path = self.mask_one(str(source), row, opts, net, cfg)

            hits = self.data_selection.path == row.path
            self.data_selection.loc[hits, 'mask_path'] = mask_path
            self.data_selection.loc[hits, 'masked_path'] = masked_path
Ejemplo n.º 2
0
def get_mask(json_opts, in_file_data, ori_shape, use_cuda: bool, model=None):
    """Predict segmentation mask on in_file_data with mlebe model.

    Parameters
    ----------
    json_opts : config object
        Parsed mlebe configuration. ``json_opts.model`` is used to build the
        model (when ``model`` is None), and ``json_opts.augmentation`` and
        ``json_opts.model.output_nc`` configure the preprocessing transform.
    in_file_data : numpy.ndarray
        Volume to segment. Axis 2 is treated as the z-axis: it is moved to the
        front of every returned array.
    ori_shape : sequence of int
        Original shape. ``ori_shape[0]`` is the number of z-slices the
        prediction is cropped back to.
    use_cuda : bool
        If False and the model has to be built here, hide all GPUs via
        ``CUDA_VISIBLE_DEVICES`` before building it.
    model : optional
        Pre-built mlebe model. It is built from ``json_opts.model`` when None.

    Returns
    -------
    tuple
        ``(in_file_data, mask_pred, model_input)``, each with the z-axis first.
        ``mask_pred`` (int16) and ``model_input`` are cropped along z to
        ``ori_shape[0]`` slices; ``in_file_data`` is only re-ordered.
    """
    # Compare against None explicitly: a model object may define
    # __len__/__bool__, which would make a perfectly valid model falsy.
    if model is None:
        from mlebe.training.models import get_model
        # To make sure that the GPU is not used for the predictions (might be
        # unnecessary).
        # NOTE(review): this modifies the process-wide environment and
        # presumably only takes effect if CUDA was not initialised yet — confirm.
        if not use_cuda:
            os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
        model = get_model(json_opts.model)
    ds_transform = get_dataset_transformation(
        'mlebe',
        opts=json_opts.augmentation,
        max_output_channels=json_opts.model.output_nc)
    transformer = ds_transform['bids']()
    # preprocess data for compatibility with model (a trailing channel axis is
    # added before the transform)
    model_input = transformer(np.expand_dims(in_file_data, -1))
    # add dimension for batches
    model_input = model_input.unsqueeze(0)
    model.set_input(model_input)
    model.test()
    # predict
    mask_pred = np.squeeze(model.pred_seg.cpu().byte().numpy()).astype(
        np.int16)
    # switching to z,x,y
    mask_pred = np.moveaxis(mask_pred, 2, 0)
    in_file_data = np.moveaxis(in_file_data, 2, 0)
    model_input = np.moveaxis(np.squeeze(model_input.cpu().numpy()), 2, 0)

    # Un-pad on the z-axis back to the original number of slices, dropping
    # ceil(diff / 2) slices at the start.
    diff = int(mask_pred.shape[0] - ori_shape[0])
    start = int(np.ceil(diff / 2.))
    stop = start + ori_shape[0]
    mask_pred = mask_pred[start:stop, :, :]
    model_input = model_input[start:stop, :, :]

    return in_file_data, mask_pred, model_input
Ejemplo n.º 3
0
def tester(json_opts, test_dataset, save_directory):
    """Run the model over the test set and dump the labelled slices to ``.npy`` files.

    For every sample, each 2D slice along the last axis whose label contains
    a positive value is kept:

    - the min-max normalised input slice is collected in ``x_test``;
    - the label slice is collected in ``y_test``;
    - the predicted slice is collected in ``y_pred``.

    The three lists are saved as ``x_test.npy``, ``y_test.npy`` and
    ``y_pred.npy`` in ``save_directory``.

    Segmentation stats are fed into an ``ErrorLogger`` per batch, but the
    logger is neither returned nor saved by this function.
    """
    model = get_model(json_opts.model)
    train_opts = json_opts.training

    test_loader = DataLoader(dataset=test_dataset,
                             num_workers=0,
                             batch_size=train_opts.batchSize,
                             shuffle=False)
    error_logger = ErrorLogger()
    # test
    x_test = []
    y_test = []
    y_pred = []
    for images, labels, indices in tqdm(test_loader, total=len(test_loader)):
        model.set_input(images, labels)
        model.test()
        ids = test_dataset.get_ids(indices)

        stats = model.get_segmentation_stats()
        error_logger.update(stats, split='test')

        # Convert the predictions of the whole batch once, not once per sample.
        pred_batch = model.pred_seg.cpu().byte().numpy()

        for batch_iter in range(len(ids)):
            input_arr = np.squeeze(images[batch_iter].cpu().numpy()).astype(
                np.float32)
            label_arr = np.squeeze(labels[batch_iter].cpu().numpy()).astype(
                np.int16)
            # Index the sample before squeezing (as for images/labels), so a
            # batch holding a single sample needs no special case.
            output_arr = np.squeeze(pred_batch[batch_iter]).astype(np.int16)

            for z in range(input_arr.shape[2]):
                # only keep slices with at least one labelled voxel
                if np.max(label_arr[..., z]) <= 0:
                    continue
                x_img = cv2.normalize(input_arr[..., z],
                                      None,
                                      alpha=0,
                                      beta=1,
                                      norm_type=cv2.NORM_MINMAX,
                                      dtype=cv2.CV_32F)
                assert x_img.shape == output_arr[
                    ..., z].shape == label_arr[..., z].shape == (
                        json_opts.augmentation.mlebe.scale_size[0],
                        json_opts.augmentation.mlebe.scale_size[1])
                x_test.append(x_img)
                y_test.append(label_arr[..., z])
                y_pred.append(output_arr[..., z])

    with open(os.path.join(save_directory, 'x_test.npy'), 'wb') as file1:
        np.save(file1, x_test)
    with open(os.path.join(save_directory, 'y_test.npy'), 'wb') as file2:
        np.save(file2, y_test)
    with open(os.path.join(save_directory, 'y_pred.npy'), 'wb') as file3:
        np.save(file3, y_pred)
Ejemplo n.º 4
0
def tester(json_opts, test_dataset, save_directory):
    """Run the model over the test set and dump the non-black slices to ``.npy`` files.

    For each sample, ``remove_black_images`` is applied to (input, label) and
    to (input, prediction); presumably it drops the slices whose input is
    black (TODO confirm). Every remaining 2D slice along the last axis is then
    collected:

    - the input slice into ``x_test``;
    - the label slice into ``y_test``;
    - the predicted slice into ``y_pred``.

    The lists are saved as ``x_test.npy``, ``y_test.npy`` and ``y_pred.npy``
    in ``save_directory``.
    """
    model = get_model(json_opts.model)
    train_opts = json_opts.training

    test_loader = DataLoader(dataset=test_dataset,
                             num_workers=16,
                             batch_size=train_opts.batchSize,
                             shuffle=False)

    # test
    x_test = []
    y_test = []
    y_pred = []
    for images, labels, indices in tqdm(test_loader, total=len(test_loader)):
        model.set_input(images, labels)
        model.test()
        ids = test_dataset.get_ids(indices)

        # Convert the predictions of the whole batch once, not once per sample.
        pred_batch = model.pred_seg.cpu().byte().numpy()

        for batch_iter in range(len(ids)):
            input_arr = np.squeeze(images[batch_iter].cpu().numpy()).astype(
                np.float32)
            label_arr = np.squeeze(labels[batch_iter].cpu().numpy()).astype(
                np.int16)
            # Index the sample BEFORE squeezing. Squeezing the whole batch
            # first drops the batch axis when a batch holds a single sample
            # (e.g. the last partial batch), and [batch_iter] would then pick
            # a spatial slice instead of the sample's volume.
            output_arr = np.squeeze(pred_batch[batch_iter]).astype(np.int16)

            input_img, target = remove_black_images(input_arr, label_arr)
            _, output_img = remove_black_images(input_arr, output_arr)

            for z in range(input_img.shape[2]):
                x_test.append(input_img[..., z])
                y_test.append(target[..., z])
                y_pred.append(output_img[..., z])

    with open(os.path.join(save_directory, 'x_test.npy'), 'wb') as file1:
        np.save(file1, x_test)
    with open(os.path.join(save_directory, 'y_test.npy'), 'wb') as file2:
        np.save(file2, y_test)
    with open(os.path.join(save_directory, 'y_pred.npy'), 'wb') as file3:
        np.save(file3, y_pred)
Ejemplo n.º 5
0
def evaluate(config_path):
    """Evaluate a trained model on the test split against the template mask.

    Every test volume is segmented, and each z-slice is scored with ``dice``
    against the template mask (optionally passed through ``arrange_mask``).

    A PDF ``irsabi_test_<data_type>.pdf`` is written to
    ``<model.save_dir>/irsabi_test``. For each entry of ``IMG_NBRs`` it holds
    one page of (target, prediction) overlays, drawn from the lowest- then
    highest-scoring non-black slices. A second figure titled with the mean
    dice is saved to ``<save_path>.pdf``.

    Parameters
    ----------
    config_path : str
        Path to the json configuration file.

    Returns
    -------
    tuple
        (mean, std) of the per-slice dice scores over all test slices,
        black slices included.
    """
    json_opts = json_file_to_pyobj(config_path)
    template_dir = json_opts.data.template_dir
    model = get_model(json_opts.model)
    save_path = os.path.join(model.save_dir, 'irsabi_test')
    mkdir(save_path)
    data_type = json_opts.data.data_type
    print(save_path)
    # shape of the images on which the classifier was trained:
    training_shape = json_opts.augmentation.mlebe.scale_size[:3]
    ds_class = get_dataset('mlebe_dataset')
    # define preprocessing transformations for model
    ds_transform = get_dataset_transformation('mlebe', opts=json_opts.augmentation,
                                              max_output_channels=json_opts.model.output_nc)

    test_dataset = ds_class(template_dir, json_opts.data.data_dir, json_opts.data, split='test',
                            transform=ds_transform['valid'],
                            train_size=None, training_shape=training_shape)
    data_selection = test_dataset.data_selection
    transformer = ds_transform['valid']()

    template_mask = load_mask(template_dir)
    records = []
    # Per-volume arrays in z,x,y order, kept so the plots below can show the
    # slice of the volume a score belongs to.
    images = []
    targets = []
    predictions = []
    for volume in tqdm(range(len(data_selection))):  # volume is an index
        row = data_selection.iloc[volume]
        volume_name = row['uid']
        # np.asanyarray(img.dataobj) is nibabel's dtype-preserving replacement
        # for the removed get_data().
        img = np.asanyarray(nib.load(row['path']).dataobj)
        # np.array copies, so arrange_mask can never alter the shared template
        # (replaces deep-copying the template once per volume up front).
        # NOTE(review): assumes load_mask returns a nibabel image — confirm.
        target = np.array(template_mask.dataobj)

        if json_opts.data.with_arranged_mask:
            # set the mask to zero where the image is zero
            target = arrange_mask(img, target)

        # preprocess data for compatibility with model
        network_input = transformer(np.expand_dims(img, -1))
        target = np.squeeze(transformer(np.expand_dims(target, -1)).cpu().byte().numpy()).astype(np.int16)
        # add dimension for batches
        network_input = network_input.unsqueeze(0)
        model.set_input(network_input)
        model.test()
        # predict
        mask_pred = np.squeeze(model.pred_seg.cpu().numpy())
        img = np.squeeze(network_input.numpy())
        # set image shape to z,x,y
        mask_pred = np.moveaxis(mask_pred, 2, 0)
        img = np.moveaxis(img, 2, 0)
        target = np.moveaxis(target, 2, 0)

        for z in range(img.shape[0]):
            records.append({
                'volume_name': volume_name,
                'slice': z,
                'dice_score': dice(target[z], mask_pred[z]),
                'idx': volume,
                # black slices are skipped when choosing slices to visualise
                'black_slice': np.max(img[z]) <= 0,
            })
        images.append(img)
        targets.append(target)
        predictions.append(mask_pred)

    # Build the frame once: DataFrame.append was removed in pandas 2.0 and is
    # quadratic. Explicit columns keep the empty case well-formed.
    dice_scores_df = pd.DataFrame(records,
                                  columns=['volume_name', 'slice', 'dice_score', 'idx', 'black_slice'])
    mean_dice = dice_scores_df['dice_score'].mean()

    ranked = dice_scores_df.loc[~dice_scores_df['black_slice'].astype(bool)].sort_values(by=['dice_score'])
    n_worst = sum(IMG_NBRs) // 2
    min_df = pd.concat([ranked.head(n_worst), ranked.tail(sum(IMG_NBRs) - n_worst)],
                       ignore_index=True)

    with PdfPages(os.path.join(save_path, 'irsabi_test_{}.pdf'.format(data_type))) as pdf:
        df_idx = 0

        for n_rows in IMG_NBRs:
            plt.figure(figsize=(40, n_rows * 10))
            plt.figtext(.5, .9, 'Mean dice score of {}'.format(np.round(mean_dice, 4)),
                        fontsize=100, ha='center')
            i = 1
            while i <= n_rows * 2 and df_idx < len(min_df):
                entry = min_df.iloc[df_idx]
                # cast: values read back from a mixed-dtype row may be non-int
                volume = int(entry['idx'])
                z = int(entry['slice'])
                dice_score = entry['dice_score']
                plt.subplot(n_rows, 2, i)
                plt.imshow(images[volume][z], cmap='gray')
                plt.imshow(targets[volume][z], cmap='Blues', alpha=0.6)
                plt.axis('off')
                i += 1
                plt.subplot(n_rows, 2, i)
                plt.imshow(images[volume][z], cmap='gray')
                plt.imshow(predictions[volume][z], cmap='Blues', alpha=0.6)
                plt.title('Volume: {}, slice {}, dice {}'.format(entry['volume_name'], z, dice_score))
                plt.axis('off')
                i += 1
                df_idx += 1
            pdf.savefig()
            plt.close()

    plt.title('Dice score = {}'.format(mean_dice))
    plt.savefig('{}.pdf'.format(save_path), format='pdf')
    plt.close()  # release the figure opened implicitly by plt.title

    return mean_dice, dice_scores_df['dice_score'].std()
Ejemplo n.º 6
0
def _best_validation_row(json_opts, json_filename, model, experiment_config):
    """Return the lowest-``Seg_Loss`` validation row(s), extended with IRSABI dice statistics.

    The validation sheet is read from
    ``<checkpoints_dir>/<experiment_name>/loss_log.xlsx`` (first column
    dropped). ``finalize`` is then run, and its (mean, std) dice is added as
    ``irsabi_dice_mean`` / ``irsabi_dice_std`` columns.
    """
    val_loss_log = pd.read_excel(os.path.join(json_opts.model.checkpoints_dir,
                                              json_opts.model.experiment_name,
                                              'loss_log.xlsx'),
                                 sheet_name='validation').iloc[:, 1:]

    irsabi_dice_mean, irsabi_dice_std = finalize(json_opts, json_filename,
                                                 model, experiment_config)

    val_loss_log['irsabi_dice_mean'] = irsabi_dice_mean
    val_loss_log['irsabi_dice_std'] = irsabi_dice_std
    return val_loss_log.loc[val_loss_log['Seg_Loss'] ==
                            val_loss_log['Seg_Loss'].min()]


def train(json_filename, network_debug=False, experiment_config=None):
    """
    Main training function for the model.

    Parameters
    ----------
    json_filename : str
        Path to the json configuration file
    network_debug : bool (optional)
        If True, print the parameter count and forward/backward timings, then
        exit the process.
    experiment_config : class used for logging (optional)

    Returns
    -------
    pandas.DataFrame
        The validation-log row(s) with minimal ``Seg_Loss``, plus
        ``irsabi_dice_mean`` and ``irsabi_dice_std`` columns. The same summary
        is produced whether training runs all epochs or stops early.

    Raises
    ------
    ValueError
        If the number of data channels does not match the model's input
        channels or the last scale-size dimension.
    """

    # Load options
    json_opts = json_file_to_pyobj(json_filename)
    bigprint(f'New try with parameters: {json_opts}')
    train_opts = json_opts.training

    # Setup Dataset and Augmentation
    ds_class = get_dataset('mlebe_dataset')
    ds_path = json_opts.data.data_dir
    template_path = json_opts.data.template_dir
    ds_transform = get_dataset_transformation(
        'mlebe',
        opts=json_opts.augmentation,
        max_output_channels=json_opts.model.output_nc)

    # Setup channels
    channels = json_opts.data_opts.channels
    if len(channels) != json_opts.model.input_nc \
            or len(channels) != json_opts.augmentation.mlebe.scale_size[-1]:
        raise ValueError(
            'Number of data channels must match number of model channels, and patch and scale size dimensions'
        )

    # Setup the NN Model
    model = get_model(json_opts.model)
    if json_filename == 'configs/test_config.json':
        print('removing dir ', model.save_dir)
        shutil.rmtree(model.save_dir)
        os.mkdir(model.save_dir)

    if network_debug:
        import sys
        print('# of pars: ', model.get_number_parameters())
        print('fp time: {0:.3f} sec\tbp time: {1:.3f} sec per sample'.format(
            *model.get_fp_bp_time()))
        sys.exit()

    # Setup Data Loader
    split_opts = json_opts.data_split
    data_opts = json_opts.data

    def make_dataset(split, transform):
        # All three splits share the same split options, so the seeded split
        # is consistent between them.
        return ds_class(
            template_path,
            ds_path,
            data_opts,
            split=split,
            save_dir=model.save_dir,
            transform=transform,
            train_size=split_opts.train_size,
            test_size=split_opts.test_size,
            valid_size=split_opts.validation_size,
            split_seed=split_opts.seed,
            training_shape=json_opts.augmentation.mlebe.scale_size[:3])

    train_dataset = make_dataset('train', ds_transform['train'])
    valid_dataset = make_dataset('validation', ds_transform['valid'])
    test_dataset = make_dataset('test', ds_transform['valid'])
    train_loader = DataLoader(dataset=train_dataset,
                              num_workers=1,
                              batch_size=train_opts.batchSize,
                              shuffle=True)
    valid_loader = DataLoader(dataset=valid_dataset,
                              num_workers=1,
                              batch_size=train_opts.batchSize,
                              shuffle=False)
    test_loader = DataLoader(dataset=test_dataset,
                             num_workers=1,
                             batch_size=train_opts.batchSize,
                             shuffle=False)

    # Visualisation Parameters
    visualizer = Visualiser(json_opts.visualisation, save_dir=model.save_dir)
    error_logger = ErrorLogger()

    # Training Function
    model.set_scheduler(train_opts)
    # Setup Early Stopping
    early_stopper = EarlyStopper(json_opts.training.early_stopping_patience)

    for epoch in range(model.which_epoch, train_opts.n_epochs):
        print('(epoch: %d, total # iters: %d)' % (epoch, len(train_loader)))

        # Training Iterations
        for images, labels, indices in tqdm(train_loader,
                                            total=len(train_loader)):
            # Make a training update
            model.set_input(images, labels)
            model.optimize_parameters()

            # Error visualisation
            errors = model.get_current_errors()
            error_logger.update(errors, split='train')

            # Volumes are only displayed, not accumulated, so they do not
            # stay in memory for the whole epoch.
            ids = train_dataset.get_ids(indices)
            volumes = model.get_current_volumes()
            visualizer.display_current_volumes(volumes, ids, 'train', epoch)

        # Validation and Testing Iterations
        for loader, split, dataset in zip([valid_loader, test_loader],
                                          ['validation', 'test'],
                                          [valid_dataset, test_dataset]):
            for images, labels, indices in tqdm(loader, total=len(loader)):
                ids = dataset.get_ids(indices)

                # Make a forward pass with the model
                model.set_input(images, labels)
                model.validate()

                # Error visualisation
                errors = model.get_current_errors()
                stats = model.get_segmentation_stats()
                error_logger.update({**errors, **stats}, split=split)

                # Visualise predictions
                if split == 'validation':  # do not look at testing
                    volumes = model.get_current_volumes()
                    visualizer.display_current_volumes(volumes, ids, split,
                                                       epoch)

        current_loss = error_logger.get_errors('validation')['Seg_Loss']
        # Update best validation loss/epoch values
        model.update_validation_state(epoch, current_loss)
        early_stopper.update(model, epoch, current_loss)
        # Update the plots
        for split in ['train', 'validation', 'test']:
            visualizer.plot_current_errors(epoch,
                                           error_logger.get_errors(split),
                                           split_name=split)
            visualizer.print_current_errors(epoch,
                                            error_logger.get_errors(split),
                                            split_name=split)
        visualizer.save_plots(epoch, save_frequency=5)
        error_logger.reset()

        # saving checkpoint
        if model.is_improving:
            print('saving model')
            # replacing old model with new model
            model.save(json_opts.model.model_type, epoch)

        # Update the model learning rate
        model.update_learning_rate(metric=current_loss)

        if early_stopper.should_stop_early:
            print('early stopping')
            break

    # Single exit: early stopping and normal completion read the loss log
    # from the same configured checkpoints directory.
    return _best_validation_row(json_opts, json_filename, model,
                                experiment_config)