コード例 #1
0
def _get_accuracy_criterion(should_normalize):
    """
    Build the accuracy metric used to evaluate a segmentation.

    The flag is forwarded unchanged as the only positional argument of ``DiceCoefficient``. It is meant to
    say whether the metric's input (the model's output) still has to be normalized with nn.Softmax to form
    a valid probability distribution.
    :param should_normalize: whether or not the metric should normalize its input
    :return: the constructed ``DiceCoefficient`` callable
    """
    accuracy_metric = DiceCoefficient(should_normalize)
    return accuracy_metric
コード例 #2
0
 def _train_save_load(self,
                      tmpdir,
                      loss,
                      max_num_epochs=1,
                      log_after_iters=2,
                      validate_after_iters=2,
                      max_num_iterations=4):
     """
     Train a model briefly with the given loss, then rebuild the trainer from the last checkpoint.

     :param tmpdir: directory passed to the trainer as checkpoint directory; 'last_checkpoint.pytorch'
         is read back from it afterwards
     :param loss: loss name handed to ``get_loss_criterion`` ('bce', 'dice' and 'pce' get special handling)
     :param max_num_epochs: forwarded to ``UNet3DTrainer``
     :param log_after_iters: forwarded to ``UNet3DTrainer``
     :param validate_after_iters: forwarded to ``UNet3DTrainer``
     :param max_num_iterations: forwarded to ``UNet3DTrainer``
     :return: the trainer restored via ``UNet3DTrainer.from_checkpoint``
     """
     # train on the first GPU when CUDA is usable, otherwise on the CPU
     cuda_available = torch.cuda.is_available()
     device = torch.device("cuda:0" if cuda_available else 'cpu')

     # 'bce' is the only loss that switches on the final sigmoid and one label channel per class
     is_bce = loss == 'bce'
     loss_criterion = get_loss_criterion(loss,
                                         is_bce,
                                         weight=torch.rand(2).to(device))
     # layer order 'crg' stands for conv-relu-groupnorm
     model = self._create_model(is_bce, 'crg')
     accuracy_criterion = DiceCoefficient()

     # 'bce' and 'dice' take float labels, every other loss takes integer labels
     loaders = self._get_loaders(
         channel_per_class=is_bce,
         label_dtype='float32' if loss in ('bce', 'dice') else 'long',
         pixel_wise_weight=(loss == 'pce'))

     optimizer = optim.Adam(model.parameters(), lr=2e-4, weight_decay=0.0001)
     logger = get_logger('UNet3DTrainer', logging.DEBUG)

     # first pass: train from scratch, writing checkpoints into tmpdir
     UNet3DTrainer(model,
                   optimizer,
                   loss_criterion,
                   accuracy_criterion,
                   device,
                   loaders,
                   tmpdir,
                   max_num_epochs=max_num_epochs,
                   log_after_iters=log_after_iters,
                   validate_after_iters=validate_after_iters,
                   max_num_iterations=max_num_iterations,
                   logger=logger).fit()

     # second pass: make sure the trainer can be restored from what was just written
     checkpoint_path = os.path.join(tmpdir, 'last_checkpoint.pytorch')
     return UNet3DTrainer.from_checkpoint(checkpoint_path,
                                          model,
                                          optimizer,
                                          loss_criterion,
                                          accuracy_criterion,
                                          loaders,
                                          logger=logger)
コード例 #3
0
    def test_ignore_index_loss_with_dice_coeff(self):
        """
        Check that target entries equal to ``ignore_index`` do not change the Dice coefficient.

        The score against a target filled with -1 everywhere except the centre must equal
        the score of the prediction against an exact copy of itself.
        """
        metric = DiceCoefficient(ignore_index=-1)

        # prediction: all zeros except a single 1 in the centre
        prediction = torch.zeros((3, 3))
        prediction[1, 1] = 1.

        # target: -1 (the ignored value) everywhere except the centre
        masked_target = -1. * torch.ones((3, 3))
        masked_target[1, 1] = 1.
        actual = metric(prediction, masked_target)

        # reference: the prediction compared with a clone of itself
        expected = metric(prediction, prediction.clone())

        assert expected == actual
コード例 #4
0
    def _compute_criterion(criterion):
        shape = [1, 0, 30, 30, 30]
        # channel size varies between 1 and 4
        results = []
        for C in range(1, 5):
            batch_shape = list(shape)
            batch_shape[1] = C
            batch_shape = tuple(batch_shape)
            # compute Dice Coefficient 100 times
            for i in range(100):
                dice = DiceCoefficient()
                input = torch.rand(batch_shape)
                target = torch.zeros(batch_shape).random_(0, 2)
                results.append(dice(input, target))

        return results
コード例 #5
0
ファイル: train.py プロジェクト: xiaochengcike/pytorch-3dunet
def main():
    """
    Command-line entry point: build loss, model, metric, data loaders and optimizer, then train.

    All settings come from the parser returned by ``_arg_parser()``. When ``args.resume`` is set the
    trainer is restored from that checkpoint, otherwise a new ``UNet3DTrainer`` is created.
    """
    parser = _arg_parser()
    logger = get_logger('UNet3DTrainer')
    # Train on the first GPU when CUDA is usable, otherwise on the CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    args = parser.parse_args()
    logger.info(args)

    # Optional loss weight, moved onto the training device
    loss_weight = None
    if args.loss_weight is not None:
        loss_weight = torch.tensor(args.loss_weight).to(device)

    loss_criterion = get_loss_criterion(args.loss, loss_weight,
                                        args.ignore_index)

    model = UNet3D(args.in_channels,
                   args.out_channels,
                   init_channel_number=args.init_channel_number,
                   conv_layer_order=args.layer_order,
                   interpolate=args.interpolate,
                   final_sigmoid=args.final_sigmoid).to(device)

    # Report the model size
    num_learnable = get_number_of_learnable_parameters(model)
    logger.info(f'Number of learnable params {num_learnable}')

    # Accuracy metric
    accuracy_criterion = DiceCoefficient(ignore_index=args.ignore_index)

    # Only the 'bce' loss gets float labels here; every other loss gets integer labels
    label_dtype = 'float32' if args.loss in ('bce',) else 'long'

    train_patch, train_stride = tuple(args.train_patch), tuple(args.train_stride)
    val_patch, val_stride = tuple(args.val_patch), tuple(args.val_stride)

    logger.info(f'Train patch/stride: {train_patch}/{train_stride}')
    logger.info(f'Val patch/stride: {val_patch}/{val_stride}')

    loaders = get_loaders(args.train_path,
                          args.val_path,
                          label_dtype=label_dtype,
                          raw_internal_path=args.raw_internal_path,
                          label_internal_path=args.label_internal_path,
                          train_patch=train_patch,
                          train_stride=train_stride,
                          val_patch=val_patch,
                          val_stride=val_stride,
                          transformer=args.transformer,
                          pixel_wise_weight=(args.loss == 'pce'),
                          curriculum_learning=args.curriculum,
                          ignore_index=args.ignore_index)

    optimizer = _create_optimizer(args, model)

    if not args.resume:
        # Fresh training run
        trainer = UNet3DTrainer(model,
                                optimizer,
                                loss_criterion,
                                accuracy_criterion,
                                device,
                                loaders,
                                args.checkpoint_dir,
                                max_num_epochs=args.epochs,
                                max_num_iterations=args.iters,
                                max_patience=args.patience,
                                validate_after_iters=args.validate_after_iters,
                                log_after_iters=args.log_after_iters,
                                logger=logger)
    else:
        # Continue from the checkpoint given on the command line
        trainer = UNet3DTrainer.from_checkpoint(args.resume,
                                                model,
                                                optimizer,
                                                loss_criterion,
                                                accuracy_criterion,
                                                loaders,
                                                logger=logger)

    trainer.fit()
コード例 #6
0
 def test_dice_coefficient(self):
     """Every Dice coefficient computed on random data must lie strictly between 0 and 1."""
     coefficients = np.array(_compute_criterion(DiceCoefficient()))
     # strict bounds on both sides
     assert np.all(coefficients > 0)
     assert np.all(coefficients < 1)
コード例 #7
0
def main():
    """
    Command-line entry point for prediction with a coordinate detector plus a 3D U-Net.

    For every path given via ``--test-path`` the function
      1. converts a '.nii.gz' input into a temporary '<path>.h5' file (removed again after prediction),
      2. predicts centre coordinates with the coordinate detector and then the segmentation,
      3. splits the recovered mask at the midpoint of its first axis into an 'LH' part (upper indices)
         and an 'RH' part (lower indices) and writes both as NIfTI files into ``--output-path``,
      4. if ``--report-metrics`` is set, logs the Dice score against the 'label' dataset and the
         lengths returned by ``get_total_dist`` for both halves.

    Relies on the module-level names ``logger`` and ``AFF``.

    :raises ValueError: if ``--report-metrics`` is combined with a '.nii.gz' input
    """
    parser = argparse.ArgumentParser(description='3D U-Net predictions')
    parser.add_argument('--cdmodel-path', required=True, type=str,
                        help='path to the coordinate detector model.')
    parser.add_argument('--model-path', required=True, type=str,
                        help='path to the segmentation model')
    parser.add_argument('--in-channels', type=int, default=1,
                        help='number of input channels (default: 1)')
    parser.add_argument('--out-channels', type=int, default=2,
                        help='number of output channels (default: 2)')
    parser.add_argument('--init-channel-number', type=int, default=64,
                        help='Initial number of feature maps in the encoder path which gets doubled on every stage (default: 64)')
    parser.add_argument('--layer-order', type=str,
                        help="Conv layer ordering, e.g. 'crg' -> Conv3D+ReLU+GroupNorm",
                        default='crg')
    parser.add_argument('--final-sigmoid',
                        action='store_true',
                        help='if True apply element-wise nn.Sigmoid after the last layer otherwise apply nn.Softmax')
    parser.add_argument('--test-path', type=str, nargs='+', required=True, help='path to the test dataset')
    parser.add_argument('--raw-internal-path', type=str, default='raw')
    parser.add_argument('--patch', type=int, nargs='+', default=None,
                        help='Patch shape for used for prediction on the test set')
    parser.add_argument('--stride', type=int, nargs='+', default=None,
                        help='Patch stride for used for prediction on the test set')
    parser.add_argument('--report-metrics', action='store_true',
                        help='Whether to print metrics for each prediction')
    parser.add_argument('--output-path', type=str, default='./output/',
                        help='The output path to generate the nifti file')

    args = parser.parse_args()

    # --patch and --stride are only meaningful as a pair. Giving just one used to leave
    # `patch`/`stride` undefined further down, so reject that combination up front.
    if bool(args.patch) != bool(args.stride):
        parser.error('--patch and --stride must be specified together')

    # Make sure the output directory exists (creates missing parent directories as well)
    os.makedirs(args.output_path, exist_ok=True)

    # make sure those values correspond to the ones used during training
    in_channels = args.in_channels
    out_channels = args.out_channels
    # use F.interpolate for upsampling
    interpolate = True
    layer_order = args.layer_order
    final_sigmoid = args.final_sigmoid

    # Define model
    UNet_model = UNet3D(in_channels, out_channels,
                        init_channel_number=args.init_channel_number,
                        final_sigmoid=final_sigmoid,
                        interpolate=interpolate,
                        conv_layer_order=layer_order)
    Coor_model = CoorNet(in_channels)

    # Define metrics (`loss` is only needed by the disabled coordinate evaluation below)
    loss = nn.MSELoss(reduction='sum')
    acc = DiceCoefficient()

    logger.info('Loading trained coordinate detector model from ' + args.cdmodel_path)
    utils.load_checkpoint(args.cdmodel_path, Coor_model)

    logger.info('Loading trained segmentation model from ' + args.model_path)
    utils.load_checkpoint(args.model_path, UNet_model)

    # Load the model to the device
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        logger.warning('No CUDA device available. Using CPU for predictions')
        device = torch.device('cpu')
    UNet_model = UNet_model.to(device)
    Coor_model = Coor_model.to(device)

    # Fixed patch/stride from the command line; None means "use the full volume of each file"
    fixed_patch = tuple(args.patch) if args.patch else None
    fixed_stride = tuple(args.stride) if args.stride else None

    # Initialise counters
    total_dice = 0
    total_loss = 0
    count = 0

    for test_path in args.test_path:
        # Must be reset for every input: otherwise, after one '.nii.gz' input, every
        # following user-supplied h5 file would be deleted by the cleanup below.
        tmp_created = False

        if test_path.endswith('.nii.gz'):
            if args.report_metrics:
                raise ValueError("Cannot report metrics on original files.")
            # Temporary save as h5 file
            # Preprocess if dim != 192 x 224 x 192
            data = preprocess_nifti(test_path, args.output_path)
            logger.info('Preprocessing complete.')
            # Write under the same internal path the dataset is read from below
            with h5py.File(test_path + '.h5', 'w') as hf:
                hf.create_dataset(args.raw_internal_path, data=data)
            test_path += '.h5'
            tmp_created = True

        if fixed_patch is None:
            # No patching requested: one patch covering the whole volume.
            # Only the shape is read, and the file handle is closed again right away.
            with h5py.File(test_path, 'r') as raw_file:
                curr_shape = raw_file[args.raw_internal_path].shape
            patch = curr_shape
            stride = curr_shape
        else:
            patch = fixed_patch
            stride = fixed_stride

        # Initialise dataset
        dataset = HDF5Dataset(test_path, patch, stride, phase='test', raw_internal_path=args.raw_internal_path)

        file_name = test_path.split('/')[-1].split('.')[0]
        # Predict the centre coordinates
        x, y, z = predict(Coor_model, dataset, out_channels, device)

        # Perform segmentation
        probability_maps = predict(UNet_model, dataset, out_channels, device, x, y, z)
        res = np.argmax(probability_maps, axis=0)

        # Put the image batch back to mask with the original size
        res = recover_patch(res, x, y, z, dataset.raw.shape)

        # Extract LH and RH segmentations and write as file
        half = res.shape[0] // 2
        LH = np.zeros(res.shape)
        LH[half:, :, :] = res[half:, :, :]
        RH = np.zeros(res.shape)
        RH[:half, :, :] = res[:half, :, :]

        # os.path.join keeps this working when --output-path has no trailing separator
        lh_path = os.path.join(args.output_path, file_name + '_LH.nii.gz')
        rh_path = os.path.join(args.output_path, file_name + '_RH.nii.gz')
        nib.save(nib.Nifti1Image(LH, AFF), lh_path)
        nib.save(nib.Nifti1Image(RH, AFF), rh_path)
        logger.info('File saved to ' + lh_path)
        logger.info('File saved to ' + rh_path)

        # Only ever remove the h5 file this iteration created itself
        if tmp_created:
            os.remove(test_path)

        if args.report_metrics:
            count += 1

            # Compute coordinate accuracy
            # Coordinate evaluation disabled by default, since not all data have coordinate information
            # coor_dataset = HDF5Dataset(test_path, patch, stride, phase='val', raw_internal_path=args.raw_internal_path, label_internal_path='coor')
            # coor_target = coor_dataset[0][1].to(device)
            # coor_pred_tensor = torch.from_numpy(np.array([x, y, z])).to(device)
            # curr_coor_loss = loss(coor_pred_tensor, coor_target)
            # total_loss += curr_coor_loss
            # logger.info('Current coordinate loss: %f' % (curr_coor_loss))

            # Compute segmentation Dice score
            label_dataset = HDF5Dataset(test_path, patch, stride, phase='val', raw_internal_path=args.raw_internal_path, label_internal_path='label')
            label_target = label_dataset[0][1].to(device)
            res_dice = probability_maps
            # Target shape for the recovered maps: (first dim of the probability maps, *label shape)
            new_shape = np.append(res_dice.shape[0], np.array(label_target.size()))
            res_dice = recover_patch_4d(res_dice, x, y, z, new_shape)
            pred_tensor = torch.from_numpy(res_dice).to(device).float()
            # Prepend a leading dimension of size 1 to the label
            label_target = label_target.view((1,) + label_target.shape)
            curr_dice_score = acc(pred_tensor, label_target.long())
            total_dice += curr_dice_score
            logger.info('Current Dice score: %f' % (curr_dice_score))

            # Compute length estimation
            logger.info('RH length: ' + str(get_total_dist(res[:half, :, :])))
            logger.info('LH length: ' + str(get_total_dist(res[half:, :, :])))

    if args.report_metrics:
        # logger.info('Average loss: %f.' % (total_loss/count))
        logger.info('Average Dice score: %f.' % (total_dice/count))
コード例 #8
0
def main():
    """
    Command-line entry point for training either the coordinate detector or the segmentation network.

    ``--network cd`` trains a ``CoorNet`` with the 'mse' loss (``args.loss`` is overwritten accordingly)
    and ``PrecisionBasedAccuracy(30)`` as metric. ``--network seg`` trains a ``UNet3D`` with the loss
    named by ``args.loss`` and ``DiceCoefficient`` as metric. When ``args.resume`` is set the trainer is
    restored from that checkpoint, otherwise a new ``UNet3DTrainer`` is created.

    :raises ValueError: if ``--network`` is neither 'cd' nor 'seg', or if 'seg' is chosen without a loss
    """
    parser = _arg_parser()
    logger = get_logger('Trainer')
    # Train on the first GPU when CUDA is usable, otherwise on the CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    args = parser.parse_args()

    # Optional loss weight, moved onto the training device
    loss_weight = None
    if args.loss_weight is not None:
        loss_weight = torch.tensor(args.loss_weight).to(device)

    network = args.network
    if network not in ('cd', 'seg'):
        raise ValueError(
            "Incorrect network type defined by the --network argument, either cd or seg."
        )

    if network == 'cd':
        # The coordinate detector is always trained with the 'mse' loss
        args.loss = 'mse'
        loss_criterion = get_loss_criterion('mse', loss_weight,
                                            args.ignore_index)
        model = CoorNet(args.in_channels).to(device)
        accuracy_criterion = PrecisionBasedAccuracy(30)
    else:
        # Segmentation network: the loss has to be chosen explicitly
        if not args.loss:
            raise ValueError("Invalid loss assigned.")
        loss_criterion = get_loss_criterion(args.loss, loss_weight,
                                            args.ignore_index)
        model = UNet3D(args.in_channels,
                       args.out_channels,
                       init_channel_number=args.init_channel_number,
                       conv_layer_order=args.layer_order,
                       interpolate=True,
                       final_sigmoid=args.final_sigmoid).to(device)
        accuracy_criterion = DiceCoefficient(ignore_index=args.ignore_index)

    # 'bce' and 'mse' get float labels; every other loss gets integer labels
    label_dtype = 'float32' if args.loss in ('bce', 'mse') else 'long'

    loaders = get_loaders(args.train_path,
                          label_dtype=label_dtype,
                          raw_internal_path=args.raw_internal_path,
                          label_internal_path=args.label_internal_path,
                          train_patch=tuple(args.train_patch),
                          train_stride=tuple(args.train_stride),
                          transformer=args.transformer,
                          pixel_wise_weight=(args.loss == 'pce'),
                          curriculum_learning=args.curriculum,
                          ignore_index=args.ignore_index)

    optimizer = _create_optimizer(args, model)

    if not args.resume:
        # Fresh training run
        trainer = UNet3DTrainer(model,
                                optimizer,
                                loss_criterion,
                                accuracy_criterion,
                                device,
                                loaders,
                                args.checkpoint_dir,
                                max_num_epochs=args.epochs,
                                max_num_iterations=args.iters,
                                max_patience=args.patience,
                                validate_after_iters=args.validate_after_iters,
                                log_after_iters=args.log_after_iters,
                                logger=logger)
    else:
        # Continue from the checkpoint given on the command line
        trainer = UNet3DTrainer.from_checkpoint(args.resume,
                                                model,
                                                optimizer,
                                                loss_criterion,
                                                accuracy_criterion,
                                                loaders,
                                                logger=logger)

    trainer.fit()