示例#1
0
def _get_loaders(train_path, val_path, label_dtype, train_patch, train_stride, val_patch, val_stride):
    """
    Build the training and validation DataLoaders over H5-backed HDF5Dataset objects.

    :param train_path: path to the H5 file holding the training set
    :param val_path: path to the H5 file holding the validation set
    :param label_dtype: target type of the label dataset
    :param train_patch: patch shape passed to the training HDF5Dataset
    :param train_stride: stride shape passed to the training HDF5Dataset
    :param val_patch: patch shape passed to the validation HDF5Dataset
    :param val_stride: stride shape passed to the validation HDF5Dataset
    :return: dict with keys 'train' and 'val', each mapped to a
        torch.utils.data.DataLoader with batch_size=1 and shuffle=True
    """

    def _loader(dataset):
        # both phases share the same loader settings
        return DataLoader(dataset, batch_size=1, shuffle=True)

    # only the training set is given a transformer (ExtendedTransformer) for augmentation
    datasets = [
        ('train', HDF5Dataset(train_path, train_patch, train_stride, phase='train', label_dtype=label_dtype,
                              transformer=ExtendedTransformer)),
        ('val', HDF5Dataset(val_path, val_patch, val_stride, phase='val', label_dtype=label_dtype)),
    ]

    return {phase: _loader(dataset) for phase, dataset in datasets}
    def _get_loaders(channel_per_class, label_dtype, pixel_wise_weight=False):
        """Create train/val DataLoaders over a freshly generated random H5 dataset.

        :param channel_per_class: forwarded to TestUNet3DTrainer._create_random_dataset
        :param label_dtype: target type of the label dataset
        :param pixel_wise_weight: forwarded to HDF5Dataset as ``weighted``
        :return: dict {'train': DataLoader, 'val': DataLoader}, both batch_size=1, shuffled
        """
        # presumably the two shape tuples are the train and val volume shapes -- TODO confirm
        train, val = TestUNet3DTrainer._create_random_dataset((128, 128, 128), (64, 64, 64), channel_per_class)

        shared_kwargs = dict(label_dtype=label_dtype, weighted=pixel_wise_weight)
        train_dataset = HDF5Dataset(train, patch_shape=(32, 64, 64), stride_shape=(16, 32, 32), phase='train',
                                    **shared_kwargs)
        val_dataset = HDF5Dataset(val, patch_shape=(64, 64, 64), stride_shape=(64, 64, 64), phase='val',
                                  **shared_kwargs)

        loaders = {}
        for phase, dataset in (('train', train_dataset), ('val', val_dataset)):
            loaders[phase] = DataLoader(dataset, batch_size=1, shuffle=True)
        return loaders
示例#3
0
    def test_hdf5_dataset(self):
        """Patches of a 'test'-phase HDF5Dataset must cover the whole volume.

        For each (patch_shape, stride_shape) pair the dataset is iterated and every
        index it yields is marked in zero-initialised copies of the raw and label
        volumes; afterwards every element of both must have been marked at least once.
        """
        path = create_random_dataset((128, 128, 128))

        patch_shapes = [(127, 127, 127), (69, 70, 70), (32, 64, 64)]
        stride_shapes = [(1, 1, 1), (17, 23, 23), (32, 64, 64)]

        # the source volumes do not depend on the patch/stride configuration,
        # so read them once instead of re-opening the file on every iteration
        with h5py.File(path, 'r') as f:
            raw = f['raw'][...]
            label = f['label'][...]

        for patch_shape, stride_shape in zip(patch_shapes, stride_shapes):
            dataset = HDF5Dataset(path, patch_shape, stride_shape, 'test')

            # create zero-arrays of the same shape as the original dataset in order to verify if every element
            # was visited during the iteration
            visit_raw = np.zeros_like(raw)
            visit_label = np.zeros_like(label)

            # in the 'test' phase the dataset yields (patch, index) pairs
            for _, idx in dataset:
                visit_raw[idx] = 1
                visit_label[idx] = 1

            # verify that every element was visited at least once
            assert np.all(visit_raw), f'unvisited raw voxels for patch {patch_shape}, stride {stride_shape}'
            assert np.all(visit_label), f'unvisited label voxels for patch {patch_shape}, stride {stride_shape}'
示例#4
0
    def test_augmentation(self):
        """Raw and label patches must stay aligned under 'train'-phase augmentation.

        Every label channel is a copy of the raw volume, so after the dataset's
        transformations each label channel must still match the raw patch. The
        dataset is consumed through a DataLoader with 4 worker processes.
        """
        raw = np.random.rand(32, 96, 96)
        # assign raw to label's channels for ease of comparison; np.stack needs a
        # sequence -- passing a generator is deprecated since NumPy 1.16 and rejected
        # with TypeError by newer releases
        label = np.stack([raw for _ in range(3)])

        # keep the temporary file alive (and clean it up deterministically) for as
        # long as the dataset and its worker processes may read it
        with NamedTemporaryFile() as tmp_file:
            tmp_path = tmp_file.name
            with h5py.File(tmp_path, 'w') as f:
                f.create_dataset('raw', data=raw)
                f.create_dataset('label', data=label)

            dataset = HDF5Dataset(tmp_path,
                                  patch_shape=(16, 64, 64),
                                  stride_shape=(8, 32, 32),
                                  phase='train',
                                  transformer_config=transformer_config)

            # test augmentations using DataLoader with 4 worker processes
            data_loader = DataLoader(dataset,
                                     batch_size=1,
                                     num_workers=4,
                                     shuffle=True)
            for img, label_batch in data_loader:
                for i in range(label_batch.shape[0]):
                    assert np.allclose(img, label_batch[i])
示例#5
0
    def test_hdf5_with_multiple_label_datasets(self):
        """A dataset configured with two label datasets yields two labels per sample."""
        path = create_random_dataset((128, 128, 128), label_datasets=['label1', 'label2'])

        # patches tile the volume without overlap (stride == patch)
        shape = (32, 64, 64)
        dataset = HDF5Dataset(path, shape, shape,
                              phase='train',
                              transformer_config=transformer_config,
                              raw_internal_path='raw',
                              label_internal_path=['label1', 'label2'])

        assert all(len(labels) == 2 for _, labels in dataset)
示例#6
0
    def test_embeddings_predictor(self, tmpdir):
        """Meanshift clustering of FakeModel output must reproduce the ground truth.

        The 'label' dataset of resources/sample_cells.h5 is fed as the raw input,
        FakePredictor writes 'segmentation/meanshift' into a temporary H5 file, and
        the adapted Rand error of that segmentation against the labels must be < 0.1.
        """
        gt_file = 'resources/sample_cells.h5'
        output_file = os.path.join(tmpdir, 'output_segmentation.h5')

        dataset = HDF5Dataset(gt_file,
                              phase='test',
                              slice_builder_config={
                                  'name': 'SliceBuilder',
                                  'patch_shape': (100, 200, 200),
                                  'stride_shape': (60, 150, 150),
                              },
                              transformer_config={
                                  'raw': [{'name': 'ToTensor', 'expand_dims': False, 'dtype': 'long'}],
                              },
                              raw_internal_path='label')

        loader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=False,
                            collate_fn=prediction_collate)

        config = {'model': {'output_heads': 1}, 'device': torch.device('cpu')}
        predictor = FakePredictor(FakeModel(), loader, output_file, config,
                                  clustering='meanshift', bandwidth=0.5)
        predictor.predict()

        # load both volumes into memory, then compare once the files are closed
        with h5py.File(gt_file, 'r') as gt_h5, h5py.File(output_file, 'r') as out_h5:
            gt = gt_h5['label'][...]
            segm = out_h5['segmentation/meanshift'][...]

        assert adapted_rand(segm, gt) < 0.1
示例#7
0
def main():
    """Predict probability maps for every configured test file with a trained UNet3D.

    Settings come from parse_test_config(). The checkpoint at config.model_path is
    loaded, the model is moved to 'cuda:0' when CUDA is available (CPU otherwise),
    and for each path in config.test_path the predictions are saved to
    '<path without extension>_probabilities.h5'.
    """
    config = parse_test_config()

    # these values must correspond to the ones used during training
    out_channels = config.out_channels
    model = UNet3D(config.in_channels,
                   out_channels,
                   init_channel_number=config.init_channel_number,
                   final_sigmoid=config.final_sigmoid,
                   interpolate=config.interpolate,  # F.interpolate for upsampling
                   conv_layer_order=config.layer_order)

    logger.info(f'Loading model from {config.model_path}...')
    utils.load_checkpoint(config.model_path, model)

    logger.info('Loading datasets...')

    if not torch.cuda.is_available():
        logger.warning('No CUDA device available. Using CPU for prediction...')
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:0')
    model = model.to(device)

    patch, stride = tuple(config.patch), tuple(config.stride)

    for test_path in config.test_path:
        dataset = HDF5Dataset(test_path, patch, stride, phase='test',
                              raw_internal_path=config.raw_internal_path)
        probability_maps = predict(model, dataset, out_channels, device)
        stem = os.path.splitext(test_path)[0]
        save_predictions(probability_maps, f'{stem}_probabilities.h5')
示例#8
0
    def test_augmentation(self):
        """Raw and label patches must stay aligned under 'train'-phase augmentation.

        Each of the 3 label channels is a copy of the raw volume, so every channel of
        every augmented label patch must still match its raw patch.
        """
        raw = np.random.rand(32, 96, 96)
        # broadcast raw into each of the label's channels for ease of comparison
        label = np.zeros((3, 32, 96, 96))
        label[:] = raw

        with NamedTemporaryFile() as tmp_file:
            with h5py.File(tmp_file.name, 'w') as h5:
                h5.create_dataset('raw', data=raw)
                h5.create_dataset('label', data=label)

            dataset = HDF5Dataset(tmp_file.name,
                                  patch_shape=(16, 64, 64),
                                  stride_shape=(8, 32, 32),
                                  phase='train',
                                  transformer_config=transformer_config)

            for img, lbl in dataset:
                for channel in range(lbl.shape[0]):
                    assert np.allclose(img, lbl[channel])
示例#9
0
    def test_cl_slice_builder(self):
        """Label patches from CurriculumLearningSliceBuilder are ordered by ignore-index count.

        The number of voxels equal to -1 in each yielded label patch must be
        non-decreasing over the iteration order.
        """
        path = create_random_dataset((128, 128, 128), ignore_index=True)

        patch_shape = (32, 64, 64)
        stride_shape = (32, 64, 64)

        # HDF5Dataset opens the file itself; no separate h5py handle is needed
        dataset = HDF5Dataset(path,
                              patch_shape,
                              stride_shape,
                              'train',
                              slice_builder_cls=CurriculumLearningSliceBuilder)

        ignore_label_volumes = [np.count_nonzero(label == -1) for _, label in dataset]

        assert all(prev <= curr for prev, curr in zip(ignore_label_volumes, ignore_label_volumes[1:]))
示例#10
0
    def test_cl_slice_builder(self):
        """Label patches from CurriculumLearningSliceBuilder are ordered by ignore-index count.

        The 'val' phase is used so that the dataset yields (raw, label) pairs. In the
        'test' phase HDF5Dataset yields (raw, index) pairs, which would make ``label``
        a slice index rather than label data; ``label == -1`` would then never hold
        and the ordering assertion would pass vacuously.
        """
        path = create_random_dataset((128, 128, 128), ignore_index=True)

        patch_shape = (32, 64, 64)
        stride_shape = (32, 64, 64)

        dataset = HDF5Dataset(path,
                              patch_shape,
                              stride_shape,
                              'val',
                              transformer_config=transformer_config,
                              slice_builder_cls=CurriculumLearningSliceBuilder)

        ignore_label_volumes = [np.count_nonzero(label == -1) for _, label in dataset]

        # make sure that label patches are sorted by the number of ignore index voxels
        assert all(prev <= curr for prev, curr in zip(ignore_label_volumes, ignore_label_volumes[1:]))
示例#11
0
def main():
    """Run coordinate detection followed by 3D U-Net segmentation on each test file.

    For every path given via --test-path:
      * ``.nii.gz`` inputs are preprocessed with ``preprocess_nifti`` and written to a
        temporary ``<path>.h5`` file, which is deleted once that input is processed;
      * ``CoorNet`` predicts centre coordinates ``x, y, z``, which are passed on to the
        segmentation ``UNet3D``;
      * the argmax segmentation is restored to the full volume with ``recover_patch``
        and split along axis 0: indices ``>= shape[0] // 2`` are written to
        ``<name>_LH.nii.gz``, the rest to ``<name>_RH.nii.gz`` inside --output-path.

    Without --patch/--stride the whole raw volume of each file is used as a single
    patch. With --report-metrics (HDF5 inputs only) the Dice score against the
    ``label`` dataset and the ``get_total_dist`` length of each half are logged,
    followed by the average Dice score over all files.

    Raises:
        ValueError: if --report-metrics is combined with a ``.nii.gz`` input.
    """
    parser = argparse.ArgumentParser(description='3D U-Net predictions')
    parser.add_argument('--cdmodel-path', required=True, type=str,
                        help='path to the coordinate detector model.')
    parser.add_argument('--model-path', required=True, type=str,
                        help='path to the segmentation model')
    parser.add_argument('--in-channels', type=int, default=1,
                        help='number of input channels (default: 1)')
    parser.add_argument('--out-channels', type=int, default=2,
                        help='number of output channels (default: 2)')
    parser.add_argument('--init-channel-number', type=int, default=64,
                        help='Initial number of feature maps in the encoder path which gets doubled on every stage (default: 64)')
    parser.add_argument('--layer-order', type=str,
                        help="Conv layer ordering, e.g. 'crg' -> Conv3D+ReLU+GroupNorm",
                        default='crg')
    parser.add_argument('--final-sigmoid',
                        action='store_true',
                        help='if True apply element-wise nn.Sigmoid after the last layer otherwise apply nn.Softmax')
    parser.add_argument('--test-path', type=str, nargs='+', required=True, help='path to the test dataset')
    parser.add_argument('--raw-internal-path', type=str, default='raw')
    parser.add_argument('--patch', type=int, nargs='+', default=None,
                        help='Patch shape for used for prediction on the test set')
    parser.add_argument('--stride', type=int, nargs='+', default=None,
                        help='Patch stride for used for prediction on the test set')
    parser.add_argument('--report-metrics', action='store_true',
                        help='Whether to print metrics for each prediction')
    parser.add_argument('--output-path', type=str, default='./output/',
                        help='The output path to generate the nifti file')

    args = parser.parse_args()

    # --patch and --stride must be given together; with only one of them neither
    # the explicit nor the whole-volume branch below would define patch/stride
    if bool(args.patch) != bool(args.stride):
        parser.error('--patch and --stride must be given together')

    # Make sure the output directory (including missing parents) exists
    os.makedirs(args.output_path, exist_ok=True)

    # make sure those values correspond to the ones used during training
    in_channels = args.in_channels
    out_channels = args.out_channels
    # use F.interpolate for upsampling
    interpolate = True
    layer_order = args.layer_order
    final_sigmoid = args.final_sigmoid

    # Define model
    UNet_model = UNet3D(in_channels, out_channels,
                       init_channel_number=args.init_channel_number,
                       final_sigmoid=final_sigmoid,
                       interpolate=interpolate,
                       conv_layer_order=layer_order)
    Coor_model = CoorNet(in_channels)

    # Define metrics (`loss`/`total_loss` are only used by the disabled coordinate evaluation below)
    loss = nn.MSELoss(reduction='sum')
    acc = DiceCoefficient()

    logger.info('Loading trained coordinate detector model from ' + args.cdmodel_path)
    utils.load_checkpoint(args.cdmodel_path, Coor_model)

    logger.info('Loading trained segmentation model from ' + args.model_path)
    utils.load_checkpoint(args.model_path, UNet_model)

    # Load the model to the device
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        logger.warning('No CUDA device available. Using CPU for predictions')
        device = torch.device('cpu')
    UNet_model = UNet_model.to(device)
    Coor_model = Coor_model.to(device)

    # Apply patch-based prediction if assigned; otherwise each file's whole volume
    # is used as a single patch (resolved per file below)
    use_full_volume = not (args.patch and args.stride)
    if not use_full_volume:
        patch = tuple(args.patch)
        stride = tuple(args.stride)

    # Initialise counters
    total_dice = 0
    total_loss = 0
    count = 0

    for test_path in args.test_path:
        # reset for every input so that only a temporary created for THIS file is removed
        tmp_created = False
        if test_path.endswith('.nii.gz'):
            if args.report_metrics:
                raise ValueError("Cannot report metrics on original files.")
            # Temporary save as h5 file
            # Preprocess if dim != 192 x 224 x 192
            data = preprocess_nifti(test_path, args.output_path)
            logger.info('Preprocessing complete.')
            with h5py.File(test_path + '.h5', 'w') as hf:
                hf.create_dataset('raw', data=data)
            test_path += '.h5'
            tmp_created = True
        if use_full_volume:
            # only the shape metadata is needed: don't load the volume, and close the handle
            with h5py.File(test_path, 'r') as f:
                curr_shape = f[args.raw_internal_path].shape
            patch = curr_shape
            stride = curr_shape

        # Initialise dataset
        dataset = HDF5Dataset(test_path, patch, stride, phase='test', raw_internal_path=args.raw_internal_path)

        file_name = os.path.basename(test_path).split('.')[0]
        # Predict the centre coordinates
        x, y, z = predict(Coor_model, dataset, out_channels, device)

        # Perform segmentation
        probability_maps = predict(UNet_model, dataset, out_channels, device, x, y, z)
        res = np.argmax(probability_maps, axis=0)

        # Put the image batch back to mask with the original size
        res = recover_patch(res, x, y, z, dataset.raw.shape)

        # Extract LH and RH segmentations (split at the middle of axis 0) and write as files
        half = res.shape[0] // 2
        LH = np.zeros(res.shape)
        LH[half:, :, :] = res[half:, :, :]
        RH = np.zeros(res.shape)
        RH[:half, :, :] = res[:half, :, :]

        lh_path = os.path.join(args.output_path, file_name + '_LH.nii.gz')
        rh_path = os.path.join(args.output_path, file_name + '_RH.nii.gz')
        nib.save(nib.Nifti1Image(LH, AFF), lh_path)
        nib.save(nib.Nifti1Image(RH, AFF), rh_path)
        logger.info('File saved to ' + lh_path)
        logger.info('File saved to ' + rh_path)

        # metrics are never requested for temporaries (rejected above), so removal here is safe
        if tmp_created:
            os.remove(test_path)

        if args.report_metrics:
            count += 1

            # Compute coordinate accuracy
            # Coordinate evaluation disabled by default, since not all data have coordinate information
            # coor_dataset = HDF5Dataset(test_path, patch, stride, phase='val', raw_internal_path=args.raw_internal_path, label_internal_path='coor')
            # coor_target = coor_dataset[0][1].to(device)
            # coor_pred_tensor = torch.from_numpy(np.array([x, y, z])).to(device)
            # curr_coor_loss = loss(coor_pred_tensor, coor_target)
            # total_loss += curr_coor_loss
            # logger.info('Current coordinate loss: %f' % (curr_coor_loss))

            # Compute segmentation Dice score
            label_dataset = HDF5Dataset(test_path, patch, stride, phase='val', raw_internal_path=args.raw_internal_path, label_internal_path='label')
            label_target = label_dataset[0][1].to(device)
            res_dice = probability_maps
            new_shape = np.append(res_dice.shape[0], np.array(label_target.size()))
            res_dice = recover_patch_4d(res_dice, x, y, z, new_shape)
            pred_tensor = torch.from_numpy(res_dice).to(device).float()
            label_target = label_target.view((1,) + label_target.shape)
            curr_dice_score = acc(pred_tensor, label_target.long())
            total_dice += curr_dice_score
            logger.info('Current Dice score: %f' % (curr_dice_score))

            # Compute length estimation
            logger.info('RH length: ' + str(get_total_dist(res[:half, :, :])))
            logger.info('LH length: ' + str(get_total_dist(res[half:, :, :])))

    if args.report_metrics:
        # logger.info('Average loss: %f.' % (total_loss/count))
        logger.info('Average Dice score: %f.' % (total_dice/count))
示例#12
0
def main():
    """Predict probability maps for a single H5 test file with a trained UNet3D.

    Parses the command line and rebuilds the network from the given
    hyper-parameters, which must match the ones used during training. It then
    loads the checkpoint from --model-path, runs ``predict`` on 'cuda:0' when CUDA
    is available (CPU otherwise) and writes the result to
    ``<test-path without extension>_probabilities.h5``.
    """
    parser = argparse.ArgumentParser(description='3D U-Net predictions')
    parser.add_argument('--model-path', required=True, type=str, help='path to the model')
    parser.add_argument('--in-channels', required=True, type=int, help='number of input channels')
    parser.add_argument('--out-channels', required=True, type=int, help='number of output channels')
    parser.add_argument('--init-channel-number', type=int, default=64,
                        help='Initial number of feature maps in the encoder path which gets doubled on every stage (default: 64)')
    parser.add_argument('--interpolate', action='store_true',
                        help='use F.interpolate instead of ConvTranspose3d')
    parser.add_argument('--layer-order', type=str, default='crg',
                        help="Conv layer ordering, e.g. 'crg' -> Conv3D+ReLU+GroupNorm")
    parser.add_argument('--final-sigmoid', action='store_true',
                        help='if True apply element-wise nn.Sigmoid after the last layer otherwise apply nn.Softmax')
    parser.add_argument('--test-path', type=str, required=True, help='path to the test dataset')
    parser.add_argument('--raw-internal-path', type=str, default='raw')
    parser.add_argument('--patch', required=True, type=int, nargs='+', default=None,
                        help='Patch shape for used for prediction on the test set')
    parser.add_argument('--stride', required=True, type=int, nargs='+', default=None,
                        help='Patch stride for used for prediction on the test set')
    args = parser.parse_args()

    # network hyper-parameters must correspond to the ones used during training
    model = UNet3D(args.in_channels,
                   args.out_channels,
                   init_channel_number=args.init_channel_number,
                   final_sigmoid=args.final_sigmoid,
                   interpolate=args.interpolate,  # F.interpolate for upsampling
                   conv_layer_order=args.layer_order)

    logger.info(f'Loading model from {args.model_path}...')
    utils.load_checkpoint(args.model_path, model)

    logger.info('Loading datasets...')

    use_cuda = torch.cuda.is_available()
    if not use_cuda:
        logger.warning('No CUDA device available. Using CPU for predictions')
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    model = model.to(device)

    dataset = HDF5Dataset(args.test_path,
                          tuple(args.patch),
                          tuple(args.stride),
                          phase='test',
                          raw_internal_path=args.raw_internal_path)
    probability_maps = predict(model, dataset, args.out_channels, device)

    stem = os.path.splitext(args.test_path)[0]
    # average channels only in case of final_sigmoid
    save_predictions(probability_maps, f'{stem}_probabilities.h5', args.final_sigmoid)
示例#13
0
def _get_loaders(train_path,
                 val_path,
                 raw_internal_path,
                 label_internal_path,
                 label_dtype,
                 train_patch,
                 train_stride,
                 val_patch,
                 val_stride,
                 transformer,
                 pixel_wise_weight=False,
                 curriculum_learning=False,
                 ignore_index=None):
    """
    Returns dictionary containing the training and validation loaders
    (torch.utils.data.DataLoader) backed by the datasets.hdf5.HDF5Dataset

    :param train_path: path to the H5 file containing the training set
    :param val_path: path to the H5 file containing the validation set
    :param raw_internal_path: internal H5 path of the raw dataset (used for both sets)
    :param label_internal_path: internal H5 path of the label dataset (used for both sets)
    :param label_dtype: target type of the label dataset
    :param train_patch: patch shape for the training set
    :param train_stride: stride shape for the training set
    :param val_patch: patch shape for the validation set
    :param val_stride: stride shape for the validation set
    :param transformer: name of the transformer class applied to both sets; one of
        'LabelToBoundaryTransformer', 'RandomLabelToBoundaryTransformer',
        'AnisotropicRotationTransformer', 'IsotropicRotationTransformer',
        'StandardTransformer', 'BaseTransformer'
    :param pixel_wise_weight: forwarded to both datasets as ``weighted``
    :param curriculum_learning: if True the training set is sliced with
        CurriculumLearningSliceBuilder and neither loader shuffles, so patches are
        served in slice-builder order; otherwise SliceBuilder is used and both
        loaders shuffle
    :param ignore_index: forwarded to both datasets
    :raises ValueError: if ``transformer`` is not one of the names listed above
    :return: dict {
        'train': <train_loader>
        'val': <val_loader>
    }
    """
    transformers = {
        'LabelToBoundaryTransformer': LabelToBoundaryTransformer,
        'RandomLabelToBoundaryTransformer': RandomLabelToBoundaryTransformer,
        'AnisotropicRotationTransformer': AnisotropicRotationTransformer,
        'IsotropicRotationTransformer': IsotropicRotationTransformer,
        'StandardTransformer': StandardTransformer,
        'BaseTransformer': BaseTransformer
    }

    # explicit check instead of `assert`, which is stripped when running with -O
    if transformer not in transformers:
        raise ValueError(f'Unknown transformer {transformer!r}; expected one of {sorted(transformers)}')
    transformer_cls = transformers[transformer]

    if curriculum_learning:
        slice_builder_cls = CurriculumLearningSliceBuilder
    else:
        slice_builder_cls = SliceBuilder

    # create H5 backed training and validation dataset with data augmentation
    train_dataset = HDF5Dataset(train_path,
                                train_patch,
                                train_stride,
                                phase='train',
                                label_dtype=label_dtype,
                                raw_internal_path=raw_internal_path,
                                label_internal_path=label_internal_path,
                                transformer=transformer_cls,
                                weighted=pixel_wise_weight,
                                ignore_index=ignore_index,
                                slice_builder_cls=slice_builder_cls)

    # the validation set keeps HDF5Dataset's default slice builder
    val_dataset = HDF5Dataset(val_path,
                              val_patch,
                              val_stride,
                              phase='val',
                              label_dtype=label_dtype,
                              raw_internal_path=raw_internal_path,
                              label_internal_path=label_internal_path,
                              transformer=transformer_cls,
                              weighted=pixel_wise_weight,
                              ignore_index=ignore_index)

    # shuffle only if curriculum_learning scheme is not used
    shuffle = not curriculum_learning
    return {
        'train': DataLoader(train_dataset, batch_size=1, shuffle=shuffle),
        'val': DataLoader(val_dataset, batch_size=1, shuffle=shuffle)
    }
示例#14
0
def main():
    """Predict probability maps for a single H5 test file with a trained UNet3D.

    The final activation is derived from the --loss used during training via
    ``_final_sigmoid``. The checkpoint from --model-path is loaded and the model is
    run on 'cuda:0' when CUDA is available (CPU otherwise). The result is written
    to ``<test-path without extension>_probabilities.h5``, optionally averaged
    across channels (--average-channels).
    """
    parser = argparse.ArgumentParser(description='3D U-Net predictions')
    parser.add_argument('--model-path', required=True, type=str, help='path to the model')
    parser.add_argument('--in-channels', required=True, type=int, help='number of input channels')
    parser.add_argument('--out-channels', required=True, type=int, help='number of output channels')
    parser.add_argument('--interpolate', action='store_true',
                        help='use F.interpolate instead of ConvTranspose3d')
    parser.add_argument('--average-channels', action='store_true',
                        help='average the probability_maps across the the channel axis (use only if your channels refer to the same semantic class)')
    parser.add_argument('--layer-order', type=str, default='crg',
                        help="Conv layer ordering, e.g. 'crg' -> Conv3D+ReLU+GroupNorm")
    parser.add_argument('--loss', type=str, required=True,
                        help='Loss function used for training. Possible values: [ce, bce, wce, dice]. Has to be provided cause loss determines the final activation of the model.')
    parser.add_argument('--test-path', type=str, required=True, help='path to the test dataset')
    parser.add_argument('--patch', required=True, type=int, nargs='+', default=None,
                        help='Patch shape for used for prediction on the test set')
    parser.add_argument('--stride', required=True, type=int, nargs='+', default=None,
                        help='Patch stride for used for prediction on the test set')
    args = parser.parse_args()

    # network hyper-parameters must correspond to the ones used during training
    model = UNet3D(args.in_channels,
                   args.out_channels,
                   final_sigmoid=_final_sigmoid(args.loss),
                   interpolate=args.interpolate,  # F.interpolate for upsampling
                   conv_layer_order=args.layer_order)

    logger.info(f'Loading model from {args.model_path}...')
    utils.load_checkpoint(args.model_path, model)

    logger.info('Loading datasets...')

    use_cuda = torch.cuda.is_available()
    if not use_cuda:
        logger.warning('No CUDA device available. Using CPU for predictions')
    device = torch.device('cuda:0' if use_cuda else 'cpu')
    model = model.to(device)

    dataset = HDF5Dataset(args.test_path, tuple(args.patch), tuple(args.stride), phase='test')
    # NOTE(review): unlike the other entry points, predict() receives the raw
    # volume shape here rather than out_channels -- confirm against predict()'s signature
    probability_maps = predict(model, dataset, dataset.raw.shape, device)

    stem = os.path.splitext(args.test_path)[0]
    save_predictions(probability_maps, f'{stem}_probabilities.h5', args.average_channels)