def _get_loaders(train_path, val_path, label_dtype, train_patch, train_stride, val_patch, val_stride):
    """
    Build the training and validation loaders (torch.utils.data.DataLoader), each backed by
    a datasets.hdf5.HDF5Dataset.

    :param train_path: path to the H5 file containing the training set
    :param val_path: path to the H5 file containing the validation set
    :param label_dtype: target type of the label dataset
    :param train_patch: patch argument forwarded to the training HDF5Dataset
    :param train_stride: stride argument forwarded to the training HDF5Dataset
    :param val_patch: patch argument forwarded to the validation HDF5Dataset
    :param val_stride: stride argument forwarded to the validation HDF5Dataset
    :return: dict {
        'train': <train_loader>
        'val': <val_loader>
    }
    """
    # only the training dataset is given a transformer (ExtendedTransformer)
    training_set = HDF5Dataset(
        train_path,
        train_patch,
        train_stride,
        phase='train',
        label_dtype=label_dtype,
        transformer=ExtendedTransformer,
    )
    validation_set = HDF5Dataset(
        val_path,
        val_patch,
        val_stride,
        phase='val',
        label_dtype=label_dtype,
    )

    # both loaders yield one (shuffled) patch per batch
    loaders = {}
    loaders['train'] = DataLoader(training_set, batch_size=1, shuffle=True)
    loaders['val'] = DataLoader(validation_set, batch_size=1, shuffle=True)
    return loaders
def _get_loaders(channel_per_class, label_dtype, pixel_wise_weight=False):
    """
    Create train/val loaders on top of a freshly generated random dataset.

    :param channel_per_class: forwarded to TestUNet3DTrainer._create_random_dataset
    :param label_dtype: target type of the label dataset
    :param pixel_wise_weight: forwarded to HDF5Dataset as `weighted`
    :return: dict {
        'train': <train_loader>
        'val': <val_loader>
    }
    """
    train_file, val_file = TestUNet3DTrainer._create_random_dataset(
        (128, 128, 128), (64, 64, 64), channel_per_class)

    # options shared by both datasets
    shared_options = dict(label_dtype=label_dtype, weighted=pixel_wise_weight)

    training_set = HDF5Dataset(train_file, patch_shape=(32, 64, 64), stride_shape=(16, 32, 32),
                               phase='train', **shared_options)
    validation_set = HDF5Dataset(val_file, patch_shape=(64, 64, 64), stride_shape=(64, 64, 64),
                                 phase='val', **shared_options)

    train_loader = DataLoader(training_set, batch_size=1, shuffle=True)
    val_loader = DataLoader(validation_set, batch_size=1, shuffle=True)
    return {'train': train_loader, 'val': val_loader}
def test_hdf5_dataset(self):
    """
    Check that iterating an HDF5Dataset in 'test' phase touches every voxel of the volume
    for several patch/stride combinations.
    """
    path = create_random_dataset((128, 128, 128))

    # (patch_shape, stride_shape) pairs under test
    configurations = [
        ((127, 127, 127), (1, 1, 1)),
        ((69, 70, 70), (17, 23, 23)),
        ((32, 64, 64), (32, 64, 64)),
    ]

    for patch_shape, stride_shape in configurations:
        with h5py.File(path, 'r') as h5_file:
            raw = h5_file['raw'][...]
            label = h5_file['label'][...]

            dataset = HDF5Dataset(path, patch_shape, stride_shape, 'test')

            # zero-filled masks with the shapes of the stored arrays; every index yielded
            # by the dataset gets marked with 1
            seen_raw = np.zeros_like(raw)
            seen_label = np.zeros_like(label)

            for _, index in dataset:
                seen_raw[index] = 1
                seen_label[index] = 1

            # no zero may be left, i.e. every element was visited at least once
            assert np.all(seen_raw)
            assert np.all(seen_label)
def test_augmentation(self):
    """
    Check that raw and label patches stay aligned after the 'train' phase transformations
    when the dataset is consumed through a multi-worker DataLoader.

    The label volume consists of 3 copies of the raw volume, so every label channel of a
    yielded sample is expected to be (numerically) equal to the raw patch.
    """
    raw = np.random.rand(32, 96, 96)
    # assign raw to label's channels for ease of comparison;
    # np.stack needs a real sequence - a bare generator is rejected by current NumPy
    label = np.stack([raw for _ in range(3)])

    # the context manager guarantees the temporary file is closed and removed
    # even when an assertion below fails
    with NamedTemporaryFile() as tmp_file:
        tmp_path = tmp_file.name
        with h5py.File(tmp_path, 'w') as f:
            f.create_dataset('raw', data=raw)
            f.create_dataset('label', data=label)

        dataset = HDF5Dataset(tmp_path, patch_shape=(16, 64, 64), stride_shape=(8, 32, 32), phase='train',
                              transformer_config=transformer_config)

        # test augmentations using DataLoader with 4 workers
        data_loader = DataLoader(dataset, batch_size=1, num_workers=4, shuffle=True)
        # distinct loop names so the full `label` volume above is not shadowed
        for img, label_batch in data_loader:
            for i in range(label_batch.shape[0]):
                assert np.allclose(img, label_batch[i])
def test_hdf5_with_multiple_label_datasets(self):
    """
    Check that a dataset configured with two label internal paths yields two labels per sample.
    """
    path = create_random_dataset((128, 128, 128), label_datasets=['label1', 'label2'])

    # patch and stride are identical here
    patch_shape = stride_shape = (32, 64, 64)

    dataset = HDF5Dataset(
        path,
        patch_shape,
        stride_shape,
        phase='train',
        transformer_config=transformer_config,
        raw_internal_path='raw',
        label_internal_path=['label1', 'label2'],
    )

    for _, labels in dataset:
        assert len(labels) == 2
def test_embeddings_predictor(self, tmpdir):
    """
    Run FakePredictor (with FakeModel) over the sample cells volume and verify that the
    mean-shift segmentation it writes is close to the ground truth.

    The 'label' dataset of the sample file is fed to the predictor as its raw input; the
    adapted rand error between the saved segmentation and that label must be below 0.1.
    """
    gt_file = 'resources/sample_cells.h5'
    output_file = os.path.join(tmpdir, 'output_segmentation.h5')

    config = {'model': {'output_heads': 1}, 'device': torch.device('cpu')}

    raw_transforms = [{'name': 'ToTensor', 'expand_dims': False, 'dtype': 'long'}]
    transformer_config = {'raw': raw_transforms}

    slice_builder_config = {
        'name': 'SliceBuilder',
        'patch_shape': (100, 200, 200),
        'stride_shape': (60, 150, 150)
    }

    dataset = HDF5Dataset(gt_file, phase='test',
                          slice_builder_config=slice_builder_config,
                          transformer_config=transformer_config,
                          raw_internal_path='label')
    loader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=False, collate_fn=prediction_collate)

    FakePredictor(FakeModel(), loader, output_file, config, clustering='meanshift', bandwidth=0.5).predict()

    # compare the stored segmentation against the ground truth labels
    with h5py.File(gt_file, 'r') as gt_h5, h5py.File(output_file, 'r') as out_h5:
        gt = gt_h5['label'][...]
        segm = out_h5['segmentation/meanshift'][...]
        arand_error = adapted_rand(segm, gt)

        assert arand_error < 0.1
def main():
    """
    Load a trained UNet3D, run it on every configured test file and save the resulting
    probability maps next to each input as `<input-without-extension>_probabilities.h5`.

    All settings come from `parse_test_config()`; the model hyper-parameters have to
    correspond to the ones used during training.
    """
    config = parse_test_config()

    out_channels = config.out_channels
    # config.interpolate: use F.interpolate for upsampling
    model = UNet3D(config.in_channels, out_channels,
                   init_channel_number=config.init_channel_number,
                   final_sigmoid=config.final_sigmoid,
                   interpolate=config.interpolate,
                   conv_layer_order=config.layer_order)

    logger.info(f'Loading model from {config.model_path}...')
    utils.load_checkpoint(config.model_path, model)

    logger.info('Loading datasets...')

    # prefer the first CUDA device, fall back to CPU with a warning
    cuda_available = torch.cuda.is_available()
    if not cuda_available:
        logger.warning('No CUDA device available. Using CPU for prediction...')
    device = torch.device('cuda:0' if cuda_available else 'cpu')

    model = model.to(device)

    patch, stride = tuple(config.patch), tuple(config.stride)

    for test_path in config.test_path:
        # one dataset per test file
        dataset = HDF5Dataset(test_path, patch, stride, phase='test',
                              raw_internal_path=config.raw_internal_path)
        # predict over the entire dataset
        probability_maps = predict(model, dataset, out_channels, device)
        # store the probability maps alongside the input file
        base_name = os.path.splitext(test_path)[0]
        output_file = f'{base_name}_probabilities.h5'
        save_predictions(probability_maps, output_file)
def test_augmentation(self):
    """
    Check that raw and label patches stay aligned after the 'train' phase transformations.

    Every one of the 3 label channels is a copy of the raw volume, so each label channel of
    a yielded sample is expected to be (numerically) equal to the raw patch.
    """
    raw = np.random.rand(32, 96, 96)
    # assign raw to label's channels for ease of comparison (broadcast over the channel axis)
    label = np.zeros((3, 32, 96, 96))
    label[:] = raw

    tmp_file = NamedTemporaryFile()
    tmp_path = tmp_file.name

    with h5py.File(tmp_path, 'w') as h5_file:
        h5_file.create_dataset('raw', data=raw)
        h5_file.create_dataset('label', data=label)

    dataset = HDF5Dataset(
        tmp_path,
        patch_shape=(16, 64, 64),
        stride_shape=(8, 32, 32),
        phase='train',
        transformer_config=transformer_config,
    )

    for img, label_patch in dataset:
        channel_count = label_patch.shape[0]
        for channel in range(channel_count):
            assert np.allclose(img, label_patch[channel])
def test_cl_slice_builder(self):
    """
    Check that CurriculumLearningSliceBuilder orders the patches so that the number of
    ignore-index (-1) voxels in the label patch never decreases during iteration.
    """
    path = create_random_dataset((128, 128, 128), ignore_index=True)

    patch_shape = (32, 64, 64)
    stride_shape = (32, 64, 64)

    # HDF5Dataset is given the file path and does its own reading, so no extra
    # h5py.File handle is opened here (the previous one was never used)
    dataset = HDF5Dataset(path, patch_shape, stride_shape, 'train',
                          slice_builder_cls=CurriculumLearningSliceBuilder)

    # number of ignore-index voxels per label patch, in iteration order
    ignore_label_volumes = [np.count_nonzero(label == -1) for _, label in dataset]

    # every count must be <= its successor
    assert all(current <= following
               for current, following in zip(ignore_label_volumes, ignore_label_volumes[1:]))
def test_cl_slice_builder(self):
    """
    Check that CurriculumLearningSliceBuilder yields label patches sorted by their number
    of ignore-index (-1) voxels, lowest first.
    """
    path = create_random_dataset((128, 128, 128), ignore_index=True)

    # non-overlapping patches: stride equals the patch shape
    patch_shape = stride_shape = (32, 64, 64)

    dataset = HDF5Dataset(
        path,
        patch_shape,
        stride_shape,
        'test',
        transformer_config=transformer_config,
        slice_builder_cls=CurriculumLearningSliceBuilder,
    )

    ignore_counts = [np.count_nonzero(label == -1) for _, label in dataset]

    # make sure that label patches are sorted by the number of ignore index voxels
    assert all(lower <= upper for lower, upper in zip(ignore_counts, ignore_counts[1:]))
def main():
    """
    Command-line entry point for two-stage prediction: a coordinate detector (CoorNet) predicts
    centre coordinates, a 3D U-Net segments the volume using them, and the resulting mask is
    split in half along its first axis and written as `<name>_LH.nii.gz` / `<name>_RH.nii.gz`
    into the output directory.

    Inputs ending in `.nii.gz` are preprocessed and converted to a temporary H5 file (removed
    after prediction); metrics cannot be reported for such inputs. With `--report-metrics` the
    Dice score against the `label` dataset of each input is logged, followed by the average.

    :raises ValueError: if only one of --patch / --stride is given, or if --report-metrics is
        combined with a `.nii.gz` input
    """
    parser = argparse.ArgumentParser(description='3D U-Net predictions')
    parser.add_argument('--cdmodel-path', required=True, type=str,
                        help='path to the coordinate detector model.')
    parser.add_argument('--model-path', required=True, type=str,
                        help='path to the segmentation model')
    parser.add_argument('--in-channels', type=int, default=1,
                        help='number of input channels (default: 1)')
    parser.add_argument('--out-channels', type=int, default=2,
                        help='number of output channels (default: 2)')
    parser.add_argument('--init-channel-number', type=int, default=64,
                        help='Initial number of feature maps in the encoder path which gets doubled on every stage (default: 64)')
    parser.add_argument('--layer-order', type=str,
                        help="Conv layer ordering, e.g. 'crg' -> Conv3D+ReLU+GroupNorm",
                        default='crg')
    parser.add_argument('--final-sigmoid', action='store_true',
                        help='if True apply element-wise nn.Sigmoid after the last layer otherwise apply nn.Softmax')
    parser.add_argument('--test-path', type=str, nargs='+', required=True,
                        help='path to the test dataset')
    parser.add_argument('--raw-internal-path', type=str, default='raw')
    parser.add_argument('--patch', type=int, nargs='+', default=None,
                        help='Patch shape for used for prediction on the test set')
    parser.add_argument('--stride', type=int, nargs='+', default=None,
                        help='Patch stride for used for prediction on the test set')
    parser.add_argument('--report-metrics', action='store_true',
                        help='Whether to print metrics for each prediction')
    parser.add_argument('--output-path', type=str, default='./output/',
                        help='The output path to generate the nifti file')

    args = parser.parse_args()

    # patch and stride are only meaningful together; giving just one used to end in a
    # NameError once the dataset was created
    if bool(args.patch) != bool(args.stride):
        raise ValueError('--patch and --stride must be provided together')

    # create the output directory (including missing parents) if it does not exist yet
    os.makedirs(args.output_path, exist_ok=True)

    # make sure those values correspond to the ones used during training
    in_channels = args.in_channels
    out_channels = args.out_channels
    # use F.interpolate for upsampling
    interpolate = True
    layer_order = args.layer_order
    final_sigmoid = args.final_sigmoid

    # Define model
    UNet_model = UNet3D(in_channels, out_channels,
                        init_channel_number=args.init_channel_number,
                        final_sigmoid=final_sigmoid,
                        interpolate=interpolate,
                        conv_layer_order=layer_order)
    Coor_model = CoorNet(in_channels)

    # Define metrics (`loss` is only needed by the disabled coordinate evaluation below)
    loss = nn.MSELoss(reduction='sum')
    acc = DiceCoefficient()

    logger.info('Loading trained coordinate detector model from %s', args.cdmodel_path)
    utils.load_checkpoint(args.cdmodel_path, Coor_model)
    logger.info('Loading trained segmentation model from %s', args.model_path)
    utils.load_checkpoint(args.model_path, UNet_model)

    # Load the model to the device
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        logger.warning('No CUDA device available. Using CPU for predictions')
        device = torch.device('cpu')
    UNet_model = UNet_model.to(device)
    Coor_model = Coor_model.to(device)

    # Explicit patch/stride if assigned; otherwise the whole volume of each file is used
    fixed_patch = tuple(args.patch) if args.patch else None
    fixed_stride = tuple(args.stride) if args.stride else None

    # Initialise counters
    total_dice = 0
    total_loss = 0
    count = 0

    for test_path in args.test_path:
        # must be reset for every input, otherwise a user-supplied .h5 file following a
        # .nii.gz input would be deleted as if it were temporary
        tmp_created = False

        if test_path.endswith('.nii.gz'):
            if args.report_metrics:
                raise ValueError("Cannot report metrics on original files.")
            # Temporary save as h5 file
            # Preprocess if dim != 192 x 224 x 192
            data = preprocess_nifti(test_path, args.output_path)
            logger.info('Preprocessing complete.')
            test_path += '.h5'
            with h5py.File(test_path, 'w') as hf:
                hf.create_dataset('raw', data=data)
            tmp_created = True

        if fixed_patch is None:
            # a single patch covering the whole volume; read only the shape and close the file
            with h5py.File(test_path, 'r') as f:
                curr_shape = f[args.raw_internal_path].shape
            patch = curr_shape
            stride = curr_shape
        else:
            patch = fixed_patch
            stride = fixed_stride

        # Initialise dataset
        dataset = HDF5Dataset(test_path, patch, stride, phase='test',
                              raw_internal_path=args.raw_internal_path)
        file_name = os.path.basename(test_path).split('.')[0]

        # Predict the centre coordinates
        x, y, z = predict(Coor_model, dataset, out_channels, device)

        # Perform segmentation
        probability_maps = predict(UNet_model, dataset, out_channels, device, x, y, z)
        res = np.argmax(probability_maps, axis=0)

        # Put the image batch back to mask with the original size
        res = recover_patch(res, x, y, z, dataset.raw.shape)

        # Extract LH and RH segmentations and write as file:
        # LH keeps the upper half of the first axis, RH the lower half
        half = res.shape[0] // 2
        LH = np.zeros(res.shape)
        LH[half:, :, :] = res[half:, :, :]
        RH = np.zeros(res.shape)
        RH[:half, :, :] = res[:half, :, :]

        LH_img = nib.Nifti1Image(LH, AFF)
        RH_img = nib.Nifti1Image(RH, AFF)

        # os.path.join keeps the paths valid when --output-path has no trailing separator
        lh_path = os.path.join(args.output_path, file_name + '_LH.nii.gz')
        rh_path = os.path.join(args.output_path, file_name + '_RH.nii.gz')
        nib.save(LH_img, lh_path)
        nib.save(RH_img, rh_path)

        logger.info('File saved to %s', lh_path)
        logger.info('File saved to %s', rh_path)

        if tmp_created:
            os.remove(test_path)

        if args.report_metrics:
            count += 1

            # Compute coordinate accuracy
            # Coordinate evaluation disabled by default, since not all data have coordinate information
            # coor_dataset = HDF5Dataset(test_path, patch, stride, phase='val', raw_internal_path=args.raw_internal_path, label_internal_path='coor')
            # coor_target = coor_dataset[0][1].to(device)
            # coor_pred_tensor = torch.from_numpy(np.array([x, y, z])).to(device)
            # curr_coor_loss = loss(coor_pred_tensor, coor_target)
            # total_loss += curr_coor_loss
            # logger.info('Current coordinate loss: %f' % (curr_coor_loss))

            # Compute segmentation Dice score
            label_dataset = HDF5Dataset(test_path, patch, stride, phase='val',
                                        raw_internal_path=args.raw_internal_path,
                                        label_internal_path='label')
            label_target = label_dataset[0][1].to(device)
            res_dice = probability_maps
            # channel count of the prediction followed by the label dimensions
            new_shape = np.append(res_dice.shape[0], np.array(label_target.size()))
            res_dice = recover_patch_4d(res_dice, x, y, z, new_shape)
            pred_tensor = torch.from_numpy(res_dice).to(device).float()
            # add a leading axis of size 1 to the target before scoring
            label_target = label_target.view((1,) + label_target.shape)

            curr_dice_score = acc(pred_tensor, label_target.long())
            total_dice += curr_dice_score
            logger.info('Current Dice score: %f', curr_dice_score)

            # Compute length estimation
            # NOTE(review): kept inside the report-metrics branch - confirm this matches the
            # intended placement
            logger.info('RH length: ' + str(get_total_dist(res[:half, :, :])))
            logger.info('LH length: ' + str(get_total_dist(res[half:, :, :])))

    if args.report_metrics:
        # logger.info('Average loss: %f.' % (total_loss/count))
        logger.info('Average Dice score: %f.', total_dice / count)
def main():
    """
    Command-line entry point: load a trained UNet3D, predict probability maps for a single
    test file and save them as `<test-path-without-extension>_probabilities.h5`.

    The model related arguments have to correspond to the values used during training.
    """
    arg_parser = argparse.ArgumentParser(description='3D U-Net predictions')

    # model definition
    arg_parser.add_argument('--model-path', type=str, required=True, help='path to the model')
    arg_parser.add_argument('--in-channels', type=int, required=True, help='number of input channels')
    arg_parser.add_argument('--out-channels', type=int, required=True, help='number of output channels')
    arg_parser.add_argument('--init-channel-number', type=int, default=64,
                            help='Initial number of feature maps in the encoder path which gets doubled on every stage (default: 64)')
    arg_parser.add_argument('--interpolate', action='store_true',
                            help='use F.interpolate instead of ConvTranspose3d')
    arg_parser.add_argument('--layer-order', type=str, default='crg',
                            help="Conv layer ordering, e.g. 'crg' -> Conv3D+ReLU+GroupNorm")
    arg_parser.add_argument('--final-sigmoid', action='store_true',
                            help='if True apply element-wise nn.Sigmoid after the last layer otherwise apply nn.Softmax')

    # test data
    arg_parser.add_argument('--test-path', type=str, required=True, help='path to the test dataset')
    arg_parser.add_argument('--raw-internal-path', type=str, default='raw')
    arg_parser.add_argument('--patch', type=int, nargs='+', required=True, default=None,
                            help='Patch shape for used for prediction on the test set')
    arg_parser.add_argument('--stride', type=int, nargs='+', required=True, default=None,
                            help='Patch stride for used for prediction on the test set')

    args = arg_parser.parse_args()

    out_channels = args.out_channels
    final_sigmoid = args.final_sigmoid

    model = UNet3D(args.in_channels, out_channels,
                   init_channel_number=args.init_channel_number,
                   final_sigmoid=final_sigmoid,
                   interpolate=args.interpolate,
                   conv_layer_order=args.layer_order)

    logger.info(f'Loading model from {args.model_path}...')
    utils.load_checkpoint(args.model_path, model)

    logger.info('Loading datasets...')

    # prefer the first CUDA device, fall back to CPU with a warning
    has_cuda = torch.cuda.is_available()
    if not has_cuda:
        logger.warning('No CUDA device available. Using CPU for predictions')
    device = torch.device('cuda:0' if has_cuda else 'cpu')

    model = model.to(device)

    dataset = HDF5Dataset(args.test_path, tuple(args.patch), tuple(args.stride), phase='test',
                          raw_internal_path=args.raw_internal_path)
    probability_maps = predict(model, dataset, out_channels, device)

    base_name = os.path.splitext(args.test_path)[0]
    output_file = f'{base_name}_probabilities.h5'

    # average channels only in case of final_sigmoid
    save_predictions(probability_maps, output_file, final_sigmoid)
def _get_loaders(train_path, val_path, raw_internal_path, label_internal_path, label_dtype, train_patch, train_stride,
                 val_patch, val_stride, transformer, pixel_wise_weight=False, curriculum_learning=False,
                 ignore_index=None):
    """
    Build the training and validation loaders (torch.utils.data.DataLoader), each backed by
    a datasets.hdf5.HDF5Dataset.

    :param train_path: path to the H5 file containing the training set
    :param val_path: path to the H5 file containing the validation set
    :param raw_internal_path: forwarded to both datasets as `raw_internal_path`
    :param label_internal_path: forwarded to both datasets as `label_internal_path`
    :param label_dtype: target type of the label dataset
    :param train_patch: patch argument forwarded to the training dataset
    :param train_stride: stride argument forwarded to the training dataset
    :param val_patch: patch argument forwarded to the validation dataset
    :param val_stride: stride argument forwarded to the validation dataset
    :param transformer: name of the transformer class; must be one of the keys listed below
    :param pixel_wise_weight: forwarded to both datasets as `weighted`
    :param curriculum_learning: if True the training dataset uses CurriculumLearningSliceBuilder
        and neither loader shuffles
    :param ignore_index: forwarded to both datasets as `ignore_index`
    :return: dict {
        'train': <train_loader>
        'val': <val_loader>
    }
    """
    known_transformers = {
        'LabelToBoundaryTransformer': LabelToBoundaryTransformer,
        'RandomLabelToBoundaryTransformer': RandomLabelToBoundaryTransformer,
        'AnisotropicRotationTransformer': AnisotropicRotationTransformer,
        'IsotropicRotationTransformer': IsotropicRotationTransformer,
        'StandardTransformer': StandardTransformer,
        'BaseTransformer': BaseTransformer
    }

    assert transformer in known_transformers
    transformer_cls = known_transformers[transformer]

    # only the training dataset gets an explicit slice builder
    train_slice_builder_cls = CurriculumLearningSliceBuilder if curriculum_learning else SliceBuilder

    # keyword arguments common to the training and the validation dataset
    common_options = dict(
        label_dtype=label_dtype,
        raw_internal_path=raw_internal_path,
        label_internal_path=label_internal_path,
        transformer=transformer_cls,
        weighted=pixel_wise_weight,
        ignore_index=ignore_index,
    )

    train_dataset = HDF5Dataset(train_path, train_patch, train_stride, phase='train',
                                slice_builder_cls=train_slice_builder_cls, **common_options)
    val_dataset = HDF5Dataset(val_path, val_patch, val_stride, phase='val', **common_options)

    # shuffle only if curriculum_learning scheme is not used
    shuffle = not curriculum_learning
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=shuffle)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=shuffle)
    return {'train': train_loader, 'val': val_loader}
def main():
    """
    Command-line entry point: load a trained UNet3D, predict probability maps for a single
    test file and save them as `<test-path-without-extension>_probabilities.h5`.

    The final activation of the model is derived from the `--loss` argument, so the loss has
    to be the one used for training, as do the other model related arguments.
    """
    arg_parser = argparse.ArgumentParser(description='3D U-Net predictions')

    # model definition
    arg_parser.add_argument('--model-path', type=str, required=True, help='path to the model')
    arg_parser.add_argument('--in-channels', type=int, required=True, help='number of input channels')
    arg_parser.add_argument('--out-channels', type=int, required=True, help='number of output channels')
    arg_parser.add_argument('--interpolate', action='store_true',
                            help='use F.interpolate instead of ConvTranspose3d')
    arg_parser.add_argument('--average-channels', action='store_true',
                            help='average the probability_maps across the the channel axis (use only if your channels refer to the same semantic class)')
    arg_parser.add_argument('--layer-order', type=str, default='crg',
                            help="Conv layer ordering, e.g. 'crg' -> Conv3D+ReLU+GroupNorm")
    arg_parser.add_argument('--loss', type=str, required=True,
                            help='Loss function used for training. Possible values: [ce, bce, wce, dice]. Has to be provided cause loss determines the final activation of the model.')

    # test data
    arg_parser.add_argument('--test-path', type=str, required=True, help='path to the test dataset')
    arg_parser.add_argument('--patch', type=int, nargs='+', required=True, default=None,
                            help='Patch shape for used for prediction on the test set')
    arg_parser.add_argument('--stride', type=int, nargs='+', required=True, default=None,
                            help='Patch stride for used for prediction on the test set')

    args = arg_parser.parse_args()

    model = UNet3D(args.in_channels, args.out_channels,
                   final_sigmoid=_final_sigmoid(args.loss),
                   interpolate=args.interpolate,
                   conv_layer_order=args.layer_order)

    logger.info(f'Loading model from {args.model_path}...')
    utils.load_checkpoint(args.model_path, model)

    logger.info('Loading datasets...')

    # prefer the first CUDA device, fall back to CPU with a warning
    has_cuda = torch.cuda.is_available()
    if not has_cuda:
        logger.warning('No CUDA device available. Using CPU for predictions')
    device = torch.device('cuda:0' if has_cuda else 'cpu')

    model = model.to(device)

    dataset = HDF5Dataset(args.test_path, tuple(args.patch), tuple(args.stride), phase='test')
    # the shape of the raw volume is handed to predict
    probability_maps = predict(model, dataset, dataset.raw.shape, device)

    base_name = os.path.splitext(args.test_path)[0]
    output_file = f'{base_name}_probabilities.h5'

    save_predictions(probability_maps, output_file, args.average_channels)