def _get_accuracy_criterion(should_normalize):
    """
    Returns the segmentation's accuracy metric. Specify whether the criterion's
    input (the model's output) should be normalized with nn.Softmax in order to
    obtain a valid probability distribution.

    :param should_normalize: whether or not to normalize the input
    :return: Dice coefficient callable
    """
    return DiceCoefficient(should_normalize)
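# For reference, a minimal sketch of the Dice coefficient computed by the
# callable above (assuming a binary input/target pair of equal shape; the
# actual DiceCoefficient class may differ, e.g. in per-channel averaging,
# normalization, and ignore_index handling):
def _dice_sketch(input, target, epsilon=1e-5):
    # flatten both tensors and compute 2*|A∩B| / (|A| + |B|)
    input = input.flatten().float()
    target = target.flatten().float()
    intersection = (input * target).sum()
    return (2. * intersection + epsilon) / (input.sum() + target.sum() + epsilon)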
def _train_save_load(self, tmpdir, loss, max_num_epochs=1, log_after_iters=2,
                     validate_after_iters=2, max_num_iterations=4):
    # get device to train on
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    # conv-relu-groupnorm
    conv_layer_order = 'crg'
    final_sigmoid = loss == 'bce'
    loss_criterion = get_loss_criterion(loss, final_sigmoid,
                                        weight=torch.rand(2).to(device))
    model = self._create_model(final_sigmoid, conv_layer_order)
    accuracy_criterion = DiceCoefficient()

    channel_per_class = loss == 'bce'
    if loss in ['bce', 'dice']:
        label_dtype = 'float32'
    else:
        label_dtype = 'long'
    pixel_wise_weight = loss == 'pce'

    loaders = self._get_loaders(channel_per_class=channel_per_class,
                                label_dtype=label_dtype,
                                pixel_wise_weight=pixel_wise_weight)

    learning_rate = 2e-4
    weight_decay = 0.0001
    optimizer = optim.Adam(model.parameters(), lr=learning_rate,
                           weight_decay=weight_decay)

    logger = get_logger('UNet3DTrainer', logging.DEBUG)
    trainer = UNet3DTrainer(model, optimizer, loss_criterion,
                            accuracy_criterion, device, loaders, tmpdir,
                            max_num_epochs=max_num_epochs,
                            log_after_iters=log_after_iters,
                            validate_after_iters=validate_after_iters,
                            max_num_iterations=max_num_iterations,
                            logger=logger)
    trainer.fit()

    # test loading the trainer from the checkpoint
    trainer = UNet3DTrainer.from_checkpoint(
        os.path.join(tmpdir, 'last_checkpoint.pytorch'),
        model, optimizer, loss_criterion, accuracy_criterion, loaders,
        logger=logger)
    return trainer
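# Example of how the helper above might be driven from an individual test
# (hypothetical test method; assumes pytest's tmpdir fixture and that the
# trainer exposes a max_num_iterations attribute):
def test_ce_loss(self, tmpdir):
    trainer = self._train_save_load(tmpdir, 'ce')
    assert trainer.max_num_iterations == 4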
def test_ignore_index_loss_with_dice_coeff(self):
    loss = DiceCoefficient(ignore_index=-1)
    input = torch.zeros((3, 3))
    input[1, 1] = 1.
    # everything except (1, 1) is marked as ignored in the target
    target = -1. * torch.ones((3, 3))
    target[1, 1] = 1.
    actual = loss(input, target)

    # with the ignored entries masked out, the target is identical to the input
    target = input.clone()
    expected = loss(input, target)
    assert expected == actual
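# A minimal sketch of how ignore_index masking is typically applied before the
# Dice computation (assumed behavior; the actual DiceCoefficient implementation
# may differ in detail):
def _mask_ignored(input, target, ignore_index):
    # zero out both prediction and target wherever the target equals ignore_index
    mask = target != ignore_index
    return input * mask, target * mask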
def _compute_criterion(criterion):
    shape = [1, 0, 30, 30, 30]
    results = []
    # channel size varies between 1 and 4
    for C in range(1, 5):
        batch_shape = list(shape)
        batch_shape[1] = C
        batch_shape = tuple(batch_shape)
        # evaluate the criterion 100 times on random input/target pairs
        for i in range(100):
            input = torch.rand(batch_shape)
            target = torch.zeros(batch_shape).random_(0, 2)
            results.append(criterion(input, target))
    return results
def main():
    parser = _arg_parser()
    logger = get_logger('UNet3DTrainer')

    # Get device to train on
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    args = parser.parse_args()
    logger.info(args)

    # Create loss criterion
    if args.loss_weight is not None:
        loss_weight = torch.tensor(args.loss_weight)
        loss_weight = loss_weight.to(device)
    else:
        loss_weight = None
    loss_criterion = get_loss_criterion(args.loss, loss_weight, args.ignore_index)

    model = UNet3D(args.in_channels, args.out_channels,
                   init_channel_number=args.init_channel_number,
                   conv_layer_order=args.layer_order,
                   interpolate=args.interpolate,
                   final_sigmoid=args.final_sigmoid)
    model = model.to(device)

    # Log the number of learnable parameters
    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

    # Create accuracy metric
    accuracy_criterion = DiceCoefficient(ignore_index=args.ignore_index)

    # Get data loaders. If 'bce' loss is used, convert labels to float
    train_path, val_path = args.train_path, args.val_path
    if args.loss in ['bce']:
        label_dtype = 'float32'
    else:
        label_dtype = 'long'

    train_patch = tuple(args.train_patch)
    train_stride = tuple(args.train_stride)
    val_patch = tuple(args.val_patch)
    val_stride = tuple(args.val_stride)

    logger.info(f'Train patch/stride: {train_patch}/{train_stride}')
    logger.info(f'Val patch/stride: {val_patch}/{val_stride}')

    pixel_wise_weight = args.loss == 'pce'
    loaders = get_loaders(train_path, val_path, label_dtype=label_dtype,
                          raw_internal_path=args.raw_internal_path,
                          label_internal_path=args.label_internal_path,
                          train_patch=train_patch, train_stride=train_stride,
                          val_patch=val_patch, val_stride=val_stride,
                          transformer=args.transformer,
                          pixel_wise_weight=pixel_wise_weight,
                          curriculum_learning=args.curriculum,
                          ignore_index=args.ignore_index)

    # Create the optimizer
    optimizer = _create_optimizer(args, model)

    if args.resume:
        trainer = UNet3DTrainer.from_checkpoint(args.resume, model, optimizer,
                                                loss_criterion, accuracy_criterion,
                                                loaders, logger=logger)
    else:
        trainer = UNet3DTrainer(model, optimizer, loss_criterion,
                                accuracy_criterion, device, loaders,
                                args.checkpoint_dir,
                                max_num_epochs=args.epochs,
                                max_num_iterations=args.iters,
                                max_patience=args.patience,
                                validate_after_iters=args.validate_after_iters,
                                log_after_iters=args.log_after_iters,
                                logger=logger)

    trainer.fit()
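# A minimal sketch of the _create_optimizer helper referenced above
# (hypothetical; assumes the parser defines --learning-rate and --weight-decay
# arguments, and that Adam is the intended optimizer as in the tests):
def _create_optimizer_sketch(args, model):
    return optim.Adam(model.parameters(), lr=args.learning_rate,
                      weight_decay=args.weight_decay)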
def test_dice_coefficient(self):
    results = _compute_criterion(DiceCoefficient())
    # check that all of the coefficients lie in the open interval (0, 1)
    results = np.array(results)
    assert np.all(results > 0)
    assert np.all(results < 1)
def main():
    parser = argparse.ArgumentParser(description='3D U-Net predictions')
    parser.add_argument('--cdmodel-path', required=True, type=str,
                        help='path to the coordinate detector model')
    parser.add_argument('--model-path', required=True, type=str,
                        help='path to the segmentation model')
    parser.add_argument('--in-channels', type=int, default=1,
                        help='number of input channels (default: 1)')
    parser.add_argument('--out-channels', type=int, default=2,
                        help='number of output channels (default: 2)')
    parser.add_argument('--init-channel-number', type=int, default=64,
                        help='initial number of feature maps in the encoder path, '
                             'doubled on every stage (default: 64)')
    parser.add_argument('--layer-order', type=str, default='crg',
                        help="Conv layer ordering, e.g. 'crg' -> Conv3D+ReLU+GroupNorm")
    parser.add_argument('--final-sigmoid', action='store_true',
                        help='if set, apply element-wise nn.Sigmoid after the last layer; '
                             'otherwise apply nn.Softmax')
    parser.add_argument('--test-path', type=str, nargs='+', required=True,
                        help='path to the test dataset')
    parser.add_argument('--raw-internal-path', type=str, default='raw')
    parser.add_argument('--patch', type=int, nargs='+', default=None,
                        help='patch shape used for prediction on the test set')
    parser.add_argument('--stride', type=int, nargs='+', default=None,
                        help='patch stride used for prediction on the test set')
    parser.add_argument('--report-metrics', action='store_true',
                        help='whether to print metrics for each prediction')
    parser.add_argument('--output-path', type=str, default='./output/',
                        help='the output path for the generated NIfTI files')
    args = parser.parse_args()

    # Check if the output path exists
    if not os.path.isdir(args.output_path):
        os.mkdir(args.output_path)

    # make sure these values correspond to the ones used during training
    in_channels = args.in_channels
    out_channels = args.out_channels
    # use F.interpolate for upsampling
    interpolate = True
    layer_order = args.layer_order
    final_sigmoid = args.final_sigmoid

    # Define models
    UNet_model = UNet3D(in_channels, out_channels,
                        init_channel_number=args.init_channel_number,
                        final_sigmoid=final_sigmoid,
                        interpolate=interpolate,
                        conv_layer_order=layer_order)
    Coor_model = CoorNet(in_channels)

    # Define metrics
    loss = nn.MSELoss(reduction='sum')
    acc = DiceCoefficient()

    logger.info('Loading trained coordinate detector model from ' + args.cdmodel_path)
    utils.load_checkpoint(args.cdmodel_path, Coor_model)

    logger.info('Loading trained segmentation model from ' + args.model_path)
    utils.load_checkpoint(args.model_path, UNet_model)

    # Load the models to the device
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        logger.warning('No CUDA device available. Using CPU for predictions')
        device = torch.device('cpu')
    UNet_model = UNet_model.to(device)
    Coor_model = Coor_model.to(device)

    # Apply patch-wise prediction if patch/stride are given
    if args.patch and args.stride:
        patch = tuple(args.patch)
        stride = tuple(args.stride)

    # Initialise counters
    total_dice = 0
    total_loss = 0
    count = 0
    tmp_created = False

    for test_path in args.test_path:
        if test_path.endswith('.nii.gz'):
            if args.report_metrics:
                raise ValueError('Cannot report metrics on original files.')
            # Temporarily save as an h5 file;
            # preprocess if dim != 192 x 224 x 192
            data = preprocess_nifti(test_path, args.output_path)
            logger.info('Preprocessing complete.')
            hf = h5py.File(test_path + '.h5', 'w')
            hf.create_dataset('raw', data=data)
            hf.close()
            test_path += '.h5'
            tmp_created = True

        if not args.patch and not args.stride:
            curr_shape = np.array(h5py.File(test_path, 'r')[args.raw_internal_path]).shape
            patch = curr_shape
            stride = curr_shape

        # Initialise dataset
        dataset = HDF5Dataset(test_path, patch, stride, phase='test',
                              raw_internal_path=args.raw_internal_path)
        file_name = test_path.split('/')[-1].split('.')[0]

        # Predict the centre coordinates
        x, y, z = predict(Coor_model, dataset, out_channels, device)

        # Perform segmentation
        probability_maps = predict(UNet_model, dataset, out_channels, device, x, y, z)
        res = np.argmax(probability_maps, axis=0)

        # Put the image patch back into a mask of the original size
        res = recover_patch(res, x, y, z, dataset.raw.shape)

        # Extract LH and RH segmentations and write them to files
        LH = np.zeros(res.shape)
        LH[int(res.shape[0] / 2):, :, :] = res[int(res.shape[0] / 2):, :, :]
        RH = np.zeros(res.shape)
        RH[:int(res.shape[0] / 2), :, :] = res[:int(res.shape[0] / 2), :, :]

        LH_img = nib.Nifti1Image(LH, AFF)
        RH_img = nib.Nifti1Image(RH, AFF)
        nib.save(LH_img, args.output_path + file_name + '_LH.nii.gz')
        nib.save(RH_img, args.output_path + file_name + '_RH.nii.gz')
        logger.info('File saved to ' + args.output_path + file_name + '_LH.nii.gz')
        logger.info('File saved to ' + args.output_path + file_name + '_RH.nii.gz')

        if tmp_created:
            os.remove(test_path)

        if args.report_metrics:
            count += 1

            # Coordinate evaluation is disabled by default, since not all data
            # have coordinate information:
            # coor_dataset = HDF5Dataset(test_path, patch, stride, phase='val',
            #                            raw_internal_path=args.raw_internal_path,
            #                            label_internal_path='coor')
            # coor_target = coor_dataset[0][1].to(device)
            # coor_pred_tensor = torch.from_numpy(np.array([x, y, z])).to(device)
            # curr_coor_loss = loss(coor_pred_tensor, coor_target)
            # total_loss += curr_coor_loss
            # logger.info('Current coordinate loss: %f' % curr_coor_loss)

            # Compute segmentation Dice score
            label_dataset = HDF5Dataset(test_path, patch, stride, phase='val',
                                        raw_internal_path=args.raw_internal_path,
                                        label_internal_path='label')
            label_target = label_dataset[0][1].to(device)
            res_dice = probability_maps
            new_shape = np.append(res_dice.shape[0], np.array(label_target.size()))
            res_dice = recover_patch_4d(res_dice, x, y, z, new_shape)
            pred_tensor = torch.from_numpy(res_dice).to(device).float()
            label_target = label_target.view((1,) + label_target.shape)
            curr_dice_score = acc(pred_tensor, label_target.long())
            total_dice += curr_dice_score
            logger.info('Current Dice score: %f' % curr_dice_score)

            # Compute length estimation
            logger.info('RH length: ' + str(get_total_dist(res[:int(res.shape[0] / 2), :, :])))
            logger.info('LH length: ' + str(get_total_dist(res[int(res.shape[0] / 2):, :, :])))

    if args.report_metrics:
        # logger.info('Average loss: %f.' % (total_loss / count))
        logger.info('Average Dice score: %f.' % (total_dice / count))
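# A minimal sketch of the recover_patch helper used above (assumed behavior:
# paste a predicted patch back into a zero-filled volume of the original shape,
# centred at the predicted (x, y, z) coordinates; the real implementation may
# handle boundary clipping differently):
def recover_patch_sketch(patch, x, y, z, full_shape):
    out = np.zeros(full_shape, dtype=patch.dtype)
    dx, dy, dz = patch.shape
    x0, y0, z0 = int(x - dx // 2), int(y - dy // 2), int(z - dz // 2)
    out[x0:x0 + dx, y0:y0 + dy, z0:z0 + dz] = patch
    return out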
def main():
    parser = _arg_parser()
    logger = get_logger('Trainer')

    # Get device to train on
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    args = parser.parse_args()

    if args.loss_weight is not None:
        loss_weight = torch.tensor(args.loss_weight)
        loss_weight = loss_weight.to(device)
    else:
        loss_weight = None

    if args.network == 'cd':
        args.loss = 'mse'
        loss_criterion = get_loss_criterion('mse', loss_weight, args.ignore_index)
        model = CoorNet(args.in_channels)
        model = model.to(device)
        accuracy_criterion = PrecisionBasedAccuracy(30)
    elif args.network == 'seg':
        if not args.loss:
            raise ValueError("Invalid loss assigned.")
        loss_criterion = get_loss_criterion(args.loss, loss_weight, args.ignore_index)
        model = UNet3D(args.in_channels, args.out_channels,
                       init_channel_number=args.init_channel_number,
                       conv_layer_order=args.layer_order,
                       interpolate=True,
                       final_sigmoid=args.final_sigmoid)
        model = model.to(device)
        accuracy_criterion = DiceCoefficient(ignore_index=args.ignore_index)
    else:
        raise ValueError("Incorrect network type defined by the --network argument: use either 'cd' or 'seg'.")

    # Get data loaders. If 'bce' or 'mse' loss is used, convert labels to float
    train_path = args.train_path
    if args.loss in ['bce', 'mse']:
        label_dtype = 'float32'
    else:
        label_dtype = 'long'

    train_patch = tuple(args.train_patch)
    train_stride = tuple(args.train_stride)

    pixel_wise_weight = args.loss == 'pce'
    loaders = get_loaders(train_path, label_dtype=label_dtype,
                          raw_internal_path=args.raw_internal_path,
                          label_internal_path=args.label_internal_path,
                          train_patch=train_patch, train_stride=train_stride,
                          transformer=args.transformer,
                          pixel_wise_weight=pixel_wise_weight,
                          curriculum_learning=args.curriculum,
                          ignore_index=args.ignore_index)

    # Create the optimizer
    optimizer = _create_optimizer(args, model)

    if args.resume:
        trainer = UNet3DTrainer.from_checkpoint(args.resume, model, optimizer,
                                                loss_criterion, accuracy_criterion,
                                                loaders, logger=logger)
    else:
        trainer = UNet3DTrainer(model, optimizer, loss_criterion,
                                accuracy_criterion, device, loaders,
                                args.checkpoint_dir,
                                max_num_epochs=args.epochs,
                                max_num_iterations=args.iters,
                                max_patience=args.patience,
                                validate_after_iters=args.validate_after_iters,
                                log_after_iters=args.log_after_iters,
                                logger=logger)

    trainer.fit()
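# A minimal sketch of the get_loss_criterion dispatch assumed throughout these
# scripts (hypothetical; the real helper also handles pixel-wise weighting for
# 'pce' and may wrap losses with ignore_index masking):
def get_loss_criterion_sketch(loss, weight=None, ignore_index=None):
    if loss == 'bce':
        return nn.BCEWithLogitsLoss(pos_weight=weight)
    if loss == 'mse':
        return nn.MSELoss()
    if loss in ('ce', 'pce'):
        ii = -100 if ignore_index is None else ignore_index
        return nn.CrossEntropyLoss(weight=weight, ignore_index=ii)
    raise ValueError(f'Unsupported loss: {loss}')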