예제 #1
0
def main():
    """Load a trained UNet3D per the config file and predict every test dataset.

    Reads the model hyper-parameters from the config (they must match the
    training run), restores the checkpoint, and writes one probability-map
    file per test dataset.
    """
    config = load_config()

    # Network hyper-parameters must mirror the ones used during training.
    net = UNet3D(config['in_channels'],
                 config['out_channels'],
                 init_channel_number=config['init_channel_number'],
                 final_sigmoid=config['final_sigmoid'],
                 interpolate=config['interpolate'],
                 conv_layer_order=config['layer_order'])

    model_path = config['model_path']
    logger.info(f'Loading model from {model_path}...')
    utils.load_checkpoint(model_path, net)

    device = config['device']
    net = net.to(device)

    logger.info('Loading datasets...')
    for dataset in get_test_datasets(config):
        # predict over the whole dataset, then persist the probability maps
        maps = predict(net, dataset, config['out_channels'], device)
        save_predictions(maps, _get_output_file(dataset))
예제 #2
0
 def _load_model(self, final_sigmoid, layer_order):
     """Return a UNet3D with 1 input / 2 output channels, upsampled via F.interpolate."""
     # fixed channel configuration; upsampling is done with F.interpolate
     n_in, n_out, use_interp = 1, 2, True
     return UNet3D(n_in, n_out, use_interp, final_sigmoid,
                   layer_order)
예제 #3
0
def predict(
    model: UNet3D,
    brain_loader: BrainLoaders,
    out_files,
    num_workers,
    verbose=True,
):
    """Run patch-wise inference over every test image and save stitched volumes.

    For each test loader (one per image) the model is applied patch by patch;
    each output patch is post-processed and stitched back into a full-size
    volume, which is then saved to the entry of ``out_files`` selected by the
    batch's ``file_idx``.

    Args:
        model: trained UNet3D; switched to eval mode here.
        brain_loader: supplies the per-image test loaders plus the dataset's
            input shape, slice list and nib files used when saving.
        out_files: output paths indexed by ``file_idx``.
        num_workers: worker count forwarded to ``test_loader``.
        verbose: when True, log progress for every patch.
    """
    model.eval()

    loaders = brain_loader.test_loader(num_workers=num_workers)

    # inference only — no gradients needed
    with torch.no_grad():
        for loader in loaders:  # For every image
            epoch_start_time = timeit.default_timer()
            # Sticher reassembles per-patch outputs into a full-size volume
            result = Sticher(brain_loader.dataset.input_shape,
                             brain_loader.dataset.slices)

            for batch in loader:  # For every patch in image
                img_patch = batch['image']
                # .item() implies these index tensors hold a single element
                file_idx = batch['file_idx'].item()
                slice_idx = batch['slice_idx'].item()

                if verbose:
                    logging.info(
                        f"Predicting image {file_idx} slice id: {slice_idx} "
                        f"of shape {list(img_patch.size())} starting...")

                output = model(img_patch)

                if verbose:
                    logging.info(f"Slice id: {slice_idx} outputted.")

                output_patch = BrainDataset.post_process(output)

                if verbose:
                    logging.info(f"Slice id: {slice_idx} post-processed.")

                # place the processed patch at its slice position
                result.update(output_patch, slice_idx)

                if verbose:
                    logging.info(f"Slice id: {slice_idx} saved.")

            # Once all the patches are done, save the image
            result.save(out_files[file_idx],
                        brain_loader.dataset.get_nib_file(file_idx))
            elapsed = timeit.default_timer() - epoch_start_time
            logging.info(f'Predicted in {elapsed} seconds')
예제 #4
0
 def _create_model(final_sigmoid, layer_order):
     """Construct a 1-in / 2-out UNet3D using F.interpolate for upsampling."""
     n_in, n_out = 1, 2
     # keyword arguments make the fixed configuration explicit
     return UNet3D(n_in,
                   n_out,
                   final_sigmoid=final_sigmoid,
                   interpolate=True,
                   conv_layer_order=layer_order)
예제 #5
0
    def __init__(self,
                 imgs_dir,
                 img_postfix='T1',
                 masks_dir=None,
                 stack_size=16,
                 stride=14,
                 mask_net=False,
                 verbose=False):
        """Index image (and optional mask) volumes and precompute patch slices.

        Args:
            imgs_dir: directory holding the image files.
            img_postfix: substring used to derive a per-file tag (the text
                before the postfix, or before the first '.' when absent).
            masks_dir: optional directory with matching mask files; when None
                no masks are loaded.
            stack_size: patch depth along the first axis.
            stride: patch stride along the first axis.
            mask_net: falsy to disable, otherwise a checkpoint path of a
                pre-processing UNet3D loaded onto the detected device.
            verbose: flag stored for use elsewhere in the class.
        """
        self.img_filenames = listdir(imgs_dir)

        # One tag per non-hidden file: text before img_postfix when present,
        # otherwise text before the first dot.
        tags = [
            file[:file.find(img_postfix)]
            if file.find(img_postfix) != -1 else file[:file.find('.')]
            for file in self.img_filenames if not file.startswith('.')
        ]

        self.img_nibs = [get_nii_files(imgs_dir, tag) for tag in tags]

        # dataobj gives lazy array access without loading full volumes upfront
        self.img_files = [img.dataobj for img in self.img_nibs]
        self.mask_files = None if masks_dir is None else [
            get_nii_files(masks_dir, tag).dataobj for tag in tags
        ]

        # assumes every volume shares the first file's shape — TODO confirm
        self.input_shape = self.img_files[0].shape
        # patches span the full extent of axes 1 and 2; only axis 0 is tiled
        patch_shape = (stack_size, self.input_shape[1], self.input_shape[2])
        stride_shape = (stride, self.input_shape[1], self.input_shape[2])
        self.train_indices = []

        logging.info(f'Input shape: {self.input_shape}')
        logging.info(f'Patch shape: {patch_shape}')

        self.slices = build_slices(self.input_shape,
                                   patch_shape=patch_shape,
                                   stride_shape=stride_shape)

        logging.info(
            f'Creating dataset from {imgs_dir} and {masks_dir}'
            f'\nWith {len(self.img_files)} examples and {len(self.slices)} slices each'
        )
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        # total number of (file, slice) combinations served by this dataset
        self.length = len(self.img_files) * len(self.slices)
        self.verbose = verbose
        if mask_net:
            # optional pre-processing network restored from the checkpoint path
            logging.info(f'Loading Pre-process Network from {mask_net}')
            pre_net = UNet3D(in_channels=1,
                             out_channels=2,
                             final_sigmoid=False,
                             testing=True)
            pre_net.load_state_dict(
                torch.load(mask_net, map_location=self.device))
            pre_net.to(device=self.device)
            logging.info(f'Pre-process Network Loaded')
            self.mask_net = pre_net
        else:
            self.mask_net = None
예제 #6
0
 def _create_model(final_sigmoid, layer_order):
     """Build a small 1-in / 2-out UNet3D for fast tests.

     F.interpolate upsampling and only 16 initial feature maps keep the
     network light enough for quick test runs.
     """
     n_in, n_out = 1, 2
     return UNet3D(n_in,
                   n_out,
                   init_channel_number=16,
                   final_sigmoid=final_sigmoid,
                   interpolate=True,
                   conv_layer_order=layer_order)
예제 #7
0
def _create_model(in_channels,
                  out_channels,
                  layer_order,
                  interpolate=False,
                  final_sigmoid=True):
    """Build a UNet3D with the given channel counts and conv layer ordering.

    NOTE(review): ``interpolate`` and ``final_sigmoid`` are forwarded
    positionally, matching the original call site — verify against the
    UNet3D constructor's parameter order.
    """
    return UNet3D(
        in_channels, out_channels, interpolate, final_sigmoid,
        conv_layer_order=layer_order)
예제 #8
0
def main():
    """CLI entry point: restore a UNet3D and save probability maps per raw volume."""

    def _get_output_file(dataset):
        # '<path>.<ext>' -> '<path>_probabilities.<ext>'
        root, ext = os.path.splitext(dataset.path)
        return f'{root}_probabilities{ext}'

    parser = argparse.ArgumentParser()
    parser.add_argument('--model-path', type=str, help='path to the model')
    parser.add_argument('--config-path', type=str,
                        help='path to the dataset config')
    parser.add_argument('--in-channels', default=1, type=int,
                        help='number of input channels')
    parser.add_argument('--out-channels', default=6, type=int,
                        help='number of output channels')
    parser.add_argument('--layer-order', type=str,
                        help="Conv layer ordering, e.g. 'brc' -> BatchNorm3d+ReLU+Conv3D",
                        default='brc')
    parser.add_argument('--interpolate',
                        help='use F.interpolate instead of ConvTranspose3d',
                        action='store_true')
    args = parser.parse_args()

    # network hyper-parameters must match the ones used during training
    model = UNet3D(args.in_channels, args.out_channels,
                   interpolate=args.interpolate,
                   final_sigmoid=True,
                   conv_layer_order=args.layer_order)

    logger.info(f'Loading model from {args.model_path}...')
    utils.load_checkpoint(args.model_path, model)

    logger.info('Loading datasets...')
    raw_volumes = get_raw_volumes(args.config_path)

    # prefer GPU when available, otherwise warn and fall back to CPU
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        logger.warning(
            'No CUDA device available. Using CPU for predictions')
        device = torch.device('cpu')

    model = model.to(device)

    for volume in raw_volumes:
        save_predictions(predict(model, volume, device),
                         _get_output_file(volume))
예제 #9
0
def validate(model: UNet3D, brain_loader: BrainLoaders, loss_fnc=None, is_validation=True, quiet=True):
    """Score the model image by image on the validation or test split.

    Args:
        model: network to evaluate; its ``final_activation`` is applied to the
            raw output when present and not already applied (testing mode).
        brain_loader: supplies the loaders, device, and the dataset's slices.
        loss_fnc: optional loss; when given, the average loss is accumulated
            and logged at the end.
        is_validation: use the validation loader (True) or the test loader.
        quiet: when False, log per-image metric scores.

    Returns:
        dict mapping each metric in METRICS to its average score.
    """
    val_losses = RunningAverage()
    val_scores = {fnc: RunningAverage() for fnc in METRICS}

    loader = brain_loader.validation_loader() if is_validation else brain_loader.test_loader()

    with torch.no_grad():
        for batch in loader:
            img_batch = batch['image']
            mask_batch = batch['mask']
            file_idx = batch['file_idx'][0]

            # one validator per image; it aggregates patch results at full size
            patch_validator = PatchValidator(device=brain_loader.device, image_shape=brain_loader.dataset.input_shape)
            for slice_no, (img_patch, mask_patch) in enumerate(zip(img_batch, mask_batch)):
                # add a leading batch dimension of 1 for the network
                img_patch = img_patch.unsqueeze(0)
                mask_patch = mask_patch.unsqueeze(0)

                # forward pass
                output = model(img_patch)

                # compute the loss
                if loss_fnc is not None:
                    loss = loss_fnc(output, mask_patch)
                    val_losses.update(loss.item(), n=1)

                # if model contains final_activation layer for normalizing logits apply it, otherwise
                # the evaluation metric will be incorrectly computed
                if hasattr(model, 'final_activation') and model.final_activation is not None and not model.testing:
                    output = model.final_activation(output)

                # map the running patch index back onto the dataset's slice list
                slices = brain_loader.dataset.slices
                patch_validator.update(output, mask_patch, patch_slice=slices[slice_no % len(slices)])

            # When we finish an image, calculate the score for it and update mean
            if not quiet:
                logging.info(f'Image:{file_idx}:')
            for fnc in METRICS:
                score = patch_validator.calculate_fnc(fnc)
                val_scores[fnc].update(score, n=1)

                if not quiet:
                    logging.info(f"\t\t{fnc+' Score:':<30}{score}")

    if loss_fnc is not None:
        logging.info(f'Validation: Avg. Loss: {val_losses.avg}.')
    scores = '\n\t'.join([f"{fnc+' Score:':<30}{val_scores[fnc].avg}" for fnc in METRICS])
    logging.info(f'Evaluation Scores: \n\t{scores}')

    return {fnc: v.avg for fnc, v in val_scores.items()}
예제 #10
0
def main():
    """Train a 3D U-Net according to the loaded configuration."""
    logger = get_logger('UNet3DTrainer')

    config = load_config()
    logger.info(config)

    loss_criterion = get_loss_criterion(config)

    # build the network exactly as described by the config and move it to the device
    model = UNet3D(config['in_channels'], config['out_channels'],
                   final_sigmoid=config['final_sigmoid'],
                   init_channel_number=config['init_channel_number'],
                   conv_layer_order=config['layer_order'],
                   interpolate=config['interpolate']).to(config['device'])

    # Log the number of learnable parameters
    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

    eval_criterion = get_evaluation_metric(config)
    loaders = get_train_loaders(config)
    optimizer = _create_optimizer(config, model)
    lr_scheduler = _create_lr_scheduler(config, optimizer)

    # resume from a checkpoint when one is configured, else start fresh
    resume = config['resume']
    if resume is None:
        trainer = UNet3DTrainer(model, optimizer, lr_scheduler, loss_criterion, eval_criterion,
                                config['device'], loaders, config['checkpoint_dir'],
                                max_num_epochs=config['epochs'],
                                max_num_iterations=config['iters'],
                                validate_after_iters=config['validate_after_iters'],
                                log_after_iters=config['log_after_iters'],
                                logger=logger)
    else:
        trainer = UNet3DTrainer.from_checkpoint(resume, model,
                                                optimizer, lr_scheduler, loss_criterion,
                                                eval_criterion, loaders,
                                                logger=logger)

    trainer.fit()
예제 #11
0
def main():
    """Predict probability maps with a trained 3D U-Net and save them beside the model."""
    parser = argparse.ArgumentParser(description='3D U-Net predictions')
    parser.add_argument('--model-path', required=True, type=str,
                        help='path to the model')
    parser.add_argument('--in-channels', required=True, type=int,
                        help='number of input channels')
    parser.add_argument('--out-channels', required=True, type=int,
                        help='number of output channels')
    parser.add_argument('--interpolate',
                        help='use F.interpolate instead of ConvTranspose3d',
                        action='store_true')
    parser.add_argument('--layer-order', type=str,
                        help="Conv layer ordering, e.g. 'brc' -> BatchNorm3d+ReLU+Conv3D",
                        default='brc')
    args = parser.parse_args()

    # network hyper-parameters must mirror the training run
    model = UNet3D(args.in_channels, args.out_channels,
                   interpolate=args.interpolate,
                   final_sigmoid=True,
                   conv_layer_order=args.layer_order)

    logger.info(f'Loading model from {args.model_path}...')
    utils.load_checkpoint(args.model_path, model)

    logger.info('Loading datasets...')

    # prefer GPU when available, otherwise warn and fall back to CPU
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        logger.warning(
            'No CUDA device available. Using CPU for predictions')
        device = torch.device('cpu')

    model = model.to(device)

    dataset, dataset_size = _get_dataset(1)
    probability_maps = predict(model, dataset, dataset_size, device)

    # write next to the model checkpoint
    output_file = os.path.join(os.path.split(args.model_path)[0],
                               'probabilities.h5')
    save_predictions(probability_maps, output_file)
예제 #12
0
def main():
    """Run test-time prediction for every file listed in the parsed test config."""
    config = parse_test_config()

    # these hyper-parameters must match the training configuration
    model = UNet3D(config.in_channels,
                   config.out_channels,
                   init_channel_number=config.init_channel_number,
                   final_sigmoid=config.final_sigmoid,
                   interpolate=config.interpolate,
                   conv_layer_order=config.layer_order)

    logger.info(f'Loading model from {config.model_path}...')
    utils.load_checkpoint(config.model_path, model)

    logger.info('Loading datasets...')

    # prefer GPU when available
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        logger.warning('No CUDA device available. Using CPU for prediction...')
        device = torch.device('cpu')

    model = model.to(device)

    patch, stride = tuple(config.patch), tuple(config.stride)

    for test_path in config.test_path:
        # one HDF5 test dataset per input file
        dataset = HDF5Dataset(test_path,
                              patch,
                              stride,
                              phase='test',
                              raw_internal_path=config.raw_internal_path)
        probability_maps = predict(model, dataset, config.out_channels, device)
        output_file = f'{os.path.splitext(test_path)[0]}_probabilities.h5'
        save_predictions(probability_maps, output_file)
예제 #13
0
        "%b-%d-%Y_%I-%M-%S_%p") + '_' + args.name + args.output
    handlers = [logging.FileHandler(output_filename)]
    if not args.quiet:
        handlers.append(logging.StreamHandler(sys.stdout))

    logging.basicConfig(level=logging.INFO,
                        format='%(levelname)s: %(message)s',
                        handlers=handlers)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    input_channels = 1
    output_channels = 2
    net = UNet3D(in_channels=input_channels,
                 out_channels=output_channels,
                 final_sigmoid=False)
    logging.info(f'Network:\n'
                 f'\t{input_channels} input channels\n'
                 f'\t{output_channels} output channels (classes)\n')

    if args.load:
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    # faster convolutions, but more memory
    cudnn.benchmark = True

    try:
        if args.memory:
예제 #14
0
def train_net(model: UNet3D,
              epochs=5,
              learning_rate=0.0002,
              val_percent=0.1,
              test_percent=0.1,
              name='U-Net',
              tests=None,
              patch_size=16,
              testing_memory=False,
              mask_model=False):
    """Train the model on BrainDataset patches, validating after every epoch.

    Args:
        model: UNet3D trained in place.
        epochs: number of passes over the training loader.
        learning_rate: Adam learning rate.
        val_percent: fraction of files reserved for validation.
        test_percent: fraction of files reserved for testing.
        name: run name used in logs, checkpoint filenames and plots.
        tests: optional explicit test-file selection forwarded to BrainLoaders.
        patch_size: patch depth (stack size) for the dataset.
        testing_memory: when True, return after a single training iteration
            (used to probe memory usage of a patch size).
        mask_model: forwarded to BrainDataset as ``mask_net``.

    Side effects: saves a checkpoint per epoch under ``dir_checkpoint``,
    plots loss/validation curves, and finishes with a test-split validation.
    Relies on module-level ``dir_img``, ``dir_mask``, ``dir_checkpoint``,
    ``loss_fnc`` and ``METRICS``.
    """

    data_set = BrainDataset(dir_img,
                            'T1',
                            dir_mask,
                            stack_size=patch_size,
                            mask_net=mask_model)
    loader = BrainLoaders(data_set,
                          ratios=[val_percent, test_percent],
                          files=[None, tests])

    train_loader = loader.train_loader()
    val_loader = loader.validation_loader()
    test_loader = loader.test_loader()

    # log once per epoch for small datasets, roughly 10 times per epoch otherwise
    num_images = data_set.num_files()
    log_interval = len(train_loader) if num_images < 10 else len(
        data_set.slices) * (num_images // 10)
    global_step = 0
    logging.info(f'''Starting {name} training:
        Epochs:          {epochs}
        Learning rate:   {learning_rate}
        Training size:   {len(train_loader)} slices
        Validation size: {len(val_loader)} images
        Testing size:    {len(test_loader)} images
        Log Interval     {log_interval}
    ''')

    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           weight_decay=0.00001)
    losses = []
    val_scores = {}
    for fnc in METRICS:
        val_scores[fnc] = []

    for epoch in range(epochs):

        epoch_loss = 0
        epoch_start_time = timeit.default_timer()
        log_start_time = timeit.default_timer()
        log_loss = RunningAverage()
        for batch in train_loader:
            model.train()

            img = batch['image']
            mask = batch['mask']

            masks_pred = model(img)

            # loss_fnc is resolved from module scope, not a parameter here
            loss = loss_fnc(masks_pred, mask)

            epoch_loss += loss.item()
            log_loss.update(loss.item(), n=1)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()

            if testing_memory:  # When testing patch sizes only one iteration is enough
                return

            global_step += 1
            if global_step % log_interval == 0:
                elapsed = timeit.default_timer() - log_start_time
                losses.append(log_loss.avg)
                logging.info(
                    f'I: {global_step}, Avg. Loss: {log_loss.avg} in {elapsed} seconds'
                )
                log_start_time = timeit.default_timer()
                log_loss = RunningAverage()

        # end-of-epoch validation on the validation split
        scores = validate(model, loader, is_validation=True, loss_fnc=loss_fnc)
        for fnc in METRICS:
            val_scores[fnc].append(scores[fnc])

        # checkpoint every epoch
        make_dir(dir_checkpoint)
        torch.save(model.state_dict(),
                   dir_checkpoint + f'{name}_epoch{epoch + 1}.pth')
        elapsed = timeit.default_timer() - epoch_start_time
        logging.info(
            f'Epoch: {epoch + 1} Total Loss: {epoch_loss} in {elapsed} seconds'
        )
        logging.info(f'Checkpoint {epoch + 1} saved !')
        plot_cost(losses, name='Loss', model_name=name + str(epoch) + '_')
        for fnc in METRICS:
            plot_cost(val_scores[fnc],
                      name='Validation_' + type(fnc).__name__,
                      model_name=name + str(epoch) + '_')

    # final evaluation on the held-out test split
    logging.info('Starting Testing')
    validate(model,
             loader,
             is_validation=False,
             loss_fnc=loss_fnc,
             quiet=False)
예제 #15
0
    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()

    output_filename = args.output + '_' + args.name + datetime.now().strftime("_%b-%d-%Y_%I-%M-%S_%p")
    handlers = [logging.FileHandler(output_filename)]
    if not args.quiet:
        handlers.append(logging.StreamHandler(sys.stdout))

    logging.basicConfig(level=logging.INFO,
                        format='%(levelname)s: %(message)s',
                        handlers=handlers)

    net = UNet3D(in_channels=1, out_channels=2, testing=True, final_sigmoid=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device=device)
    logging.info(f'Using device {device}')

    logging.info("Testing model {}".format(args.model))
    net.load_state_dict(torch.load(args.model, map_location=device))

    logging.info(f"On dataset in {args.input} and {args.label}")
    data_set = BrainDataset(args.input, 'T1', args.label)

    if args.tests:
        data_loader = BrainLoaders(data_set, files=[None, args.tests])
    else:
        data_loader = BrainLoaders(data_set, ratios=[0, 1])
예제 #16
0
def main():
    """CLI prediction entry point for a single HDF5 test file.

    Builds a UNet3D from command-line hyper-parameters (which must match the
    training run), restores the checkpoint, predicts probability maps over
    the test dataset, and writes them next to the input file as
    ``<name>_probabilities.h5``.
    """
    parser = argparse.ArgumentParser(description='3D U-Net predictions')
    parser.add_argument('--model-path',
                        required=True,
                        type=str,
                        help='path to the model')
    parser.add_argument('--in-channels',
                        required=True,
                        type=int,
                        help='number of input channels')
    parser.add_argument('--out-channels',
                        required=True,
                        type=int,
                        help='number of output channels')
    parser.add_argument('--interpolate',
                        help='use F.interpolate instead of ConvTranspose3d',
                        action='store_true')
    parser.add_argument(
        '--average-channels',
        help=
        'average the probability_maps across the the channel axis (use only if your channels refer to the same semantic class)',
        action='store_true')
    parser.add_argument(
        '--layer-order',
        type=str,
        help="Conv layer ordering, e.g. 'crg' -> Conv3D+ReLU+GroupNorm",
        default='crg')
    parser.add_argument(
        '--loss',
        type=str,
        required=True,
        help=
        'Loss function used for training. Possible values: [ce, bce, wce, dice]. Has to be provided cause loss determines the final activation of the model.'
    )
    parser.add_argument('--test-path',
                        type=str,
                        required=True,
                        help='path to the test dataset')
    parser.add_argument(
        '--patch',
        required=True,
        type=int,
        nargs='+',
        default=None,
        help='Patch shape for used for prediction on the test set')
    parser.add_argument(
        '--stride',
        required=True,
        type=int,
        nargs='+',
        default=None,
        help='Patch stride for used for prediction on the test set')

    args = parser.parse_args()

    # make sure those values correspond to the ones used during training
    in_channels = args.in_channels
    out_channels = args.out_channels
    # use F.interpolate for upsampling
    interpolate = args.interpolate
    layer_order = args.layer_order
    # the training loss determines the final activation (sigmoid vs softmax)
    final_sigmoid = _final_sigmoid(args.loss)
    model = UNet3D(in_channels,
                   out_channels,
                   final_sigmoid=final_sigmoid,
                   interpolate=interpolate,
                   conv_layer_order=layer_order)

    logger.info(f'Loading model from {args.model_path}...')
    utils.load_checkpoint(args.model_path, model)

    logger.info('Loading datasets...')

    # prefer GPU when available, otherwise warn and fall back to CPU
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        logger.warning('No CUDA device available. Using CPU for predictions')
        device = torch.device('cpu')

    model = model.to(device)

    patch = tuple(args.patch)
    stride = tuple(args.stride)

    dataset = HDF5Dataset(args.test_path, patch, stride, phase='test')
    probability_maps = predict(model, dataset, dataset.raw.shape, device)

    # write next to the input file
    output_file = f'{os.path.splitext(args.test_path)[0]}_probabilities.h5'

    save_predictions(probability_maps, output_file, args.average_channels)
예제 #17
0
def main():
    """Two-stage prediction: detect centre coordinates, then segment.

    Loads a CoorNet coordinate detector and a UNet3D segmentation model,
    optionally converts ``.nii.gz`` inputs to temporary HDF5, predicts each
    test file, writes LH/RH hemisphere segmentations as NIfTI files into the
    output path, and optionally reports Dice scores and length estimates.
    """
    parser = argparse.ArgumentParser(description='3D U-Net predictions')
    parser.add_argument('--cdmodel-path', required=True, type=str,
                        help='path to the coordinate detector model.')
    parser.add_argument('--model-path', required=True, type=str,
                        help='path to the segmentation model')
    parser.add_argument('--in-channels', type=int, default=1,
                        help='number of input channels (default: 1)')
    parser.add_argument('--out-channels', type=int, default=2,
                        help='number of output channels (default: 2)')
    parser.add_argument('--init-channel-number', type=int, default=64,
                        help='Initial number of feature maps in the encoder path which gets doubled on every stage (default: 64)')
    parser.add_argument('--layer-order', type=str,
                        help="Conv layer ordering, e.g. 'crg' -> Conv3D+ReLU+GroupNorm",
                        default='crg')
    parser.add_argument('--final-sigmoid',
                        action='store_true',
                        help='if True apply element-wise nn.Sigmoid after the last layer otherwise apply nn.Softmax')
    parser.add_argument('--test-path', type=str, nargs='+', required=True, help='path to the test dataset')
    parser.add_argument('--raw-internal-path', type=str, default='raw')
    parser.add_argument('--patch', type=int, nargs='+', default=None,
                        help='Patch shape for used for prediction on the test set')
    parser.add_argument('--stride', type=int, nargs='+', default=None,
                        help='Patch stride for used for prediction on the test set')
    parser.add_argument('--report-metrics', action='store_true',
                        help='Whether to print metrics for each prediction')
    parser.add_argument('--output-path', type=str, default='./output/',
                        help='The output path to generate the nifti file')

    args = parser.parse_args()

    # Check if output path exists
    if not os.path.isdir(args.output_path):
        os.mkdir(args.output_path)

    # make sure those values correspond to the ones used during training
    in_channels = args.in_channels
    out_channels = args.out_channels
    # use F.interpolate for upsampling
    interpolate = True
    layer_order = args.layer_order
    final_sigmoid = args.final_sigmoid

    # Define model
    UNet_model = UNet3D(in_channels, out_channels,
                       init_channel_number=args.init_channel_number,
                       final_sigmoid=final_sigmoid,
                       interpolate=interpolate,
                       conv_layer_order=layer_order)
    Coor_model = CoorNet(in_channels)

    # Define metrics
    loss = nn.MSELoss(reduction='sum')
    acc = DiceCoefficient()

    logger.info('Loading trained coordinate detector model from ' + args.cdmodel_path)
    utils.load_checkpoint(args.cdmodel_path, Coor_model)

    logger.info('Loading trained segmentation model from ' + args.model_path)
    utils.load_checkpoint(args.model_path, UNet_model)

    # Load the model to the device
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        logger.warning('No CUDA device available. Using CPU for predictions')
        device = torch.device('cpu')
    UNet_model = UNet_model.to(device)
    Coor_model = Coor_model.to(device)

    # Apply patch training if assigned
    if args.patch and args.stride:
        patch = tuple(args.patch)
        stride = tuple(args.stride)

    # Initialise counters
    total_dice = 0
    total_loss = 0
    count = 0
    tmp_created = False

    for test_path in args.test_path:
        if test_path.endswith('.nii.gz'):
            if args.report_metrics:
                raise ValueError("Cannot report metrics on original files.")
            # Temporary save as h5 file
            # Preprocess if dim != 192 x 224 x 192
            data = preprocess_nifti(test_path, args.output_path)
            logger.info('Preprocessing complete.')
            hf = h5py.File(test_path + '.h5', 'w')
            hf.create_dataset('raw', data=data)
            hf.close()
            test_path += '.h5'
            tmp_created = True
        if not args.patch and not args.stride:
            # no patching requested: use the whole volume as one patch
            curr_shape = np.array(h5py.File(test_path, 'r')[args.raw_internal_path]).shape
            patch = curr_shape
            stride = curr_shape

        # Initialise dataset
        dataset = HDF5Dataset(test_path, patch, stride, phase='test', raw_internal_path=args.raw_internal_path)

        file_name = test_path.split('/')[-1].split('.')[0]
        # Predict the centre coordinates
        x, y, z = predict(Coor_model, dataset, out_channels, device)

        # Perform segmentation
        probability_maps = predict(UNet_model, dataset, out_channels, device, x, y, z)
        # channel with the highest probability becomes the predicted class
        res = np.argmax(probability_maps, axis=0)

        # Put the image batch back to mask with the original size
        res = recover_patch(res, x, y, z, dataset.raw.shape)

        # Extract LH and RH segmentations and write as file
        # (hemispheres are split at the midpoint of axis 0)
        LH = np.zeros(res.shape)
        LH[int(res.shape[0]/2):,:,:] = res[int(res.shape[0]/2):,:,:]
        RH = np.zeros(res.shape)
        RH[:int(res.shape[0]/2),:,:] = res[:int(res.shape[0]/2),:,:]

        LH_img = nib.Nifti1Image(LH, AFF)
        RH_img = nib.Nifti1Image(RH, AFF)
        nib.save(LH_img, args.output_path + file_name + '_LH.nii.gz')
        nib.save(RH_img, args.output_path + file_name + '_RH.nii.gz')
        logger.info('File saved to ' + args.output_path + file_name + '_LH.nii.gz')
        logger.info('File saved to ' + args.output_path + file_name + '_RH.nii.gz')

        # remove the temporary h5 conversion of a .nii.gz input
        if tmp_created:
            os.remove(test_path)

        if args.report_metrics:
            count += 1

            # Compute coordinate accuracy
            # Coordinate evaluation disabled by default, since not all data have coordinate information
            # coor_dataset = HDF5Dataset(test_path, patch, stride, phase='val', raw_internal_path=args.raw_internal_path, label_internal_path='coor')
            # coor_target = coor_dataset[0][1].to(device)
            # coor_pred_tensor = torch.from_numpy(np.array([x, y, z])).to(device)
            # curr_coor_loss = loss(coor_pred_tensor, coor_target)
            # total_loss += curr_coor_loss
            # logger.info('Current coordinate loss: %f' % (curr_coor_loss))

            # Compute segmentation Dice score
            label_dataset = HDF5Dataset(test_path, patch, stride, phase='val', raw_internal_path=args.raw_internal_path, label_internal_path='label')
            label_target = label_dataset[0][1].to(device)
            res_dice = probability_maps
            new_shape = np.append(res_dice.shape[0], np.array(label_target.size()))
            res_dice = recover_patch_4d(res_dice, x, y, z, new_shape)
            pred_tensor = torch.from_numpy(res_dice).to(device).float()
            # add a leading batch dimension before scoring
            label_target = label_target.view((1,) + label_target.shape)
            curr_dice_score = acc(pred_tensor, label_target.long())
            total_dice += curr_dice_score
            logger.info('Current Dice score: %f' % (curr_dice_score))

            # Compute length estimation
            logger.info('RH length: ' + str(get_total_dist(res[:int(res.shape[0]/2),:,:])))
            logger.info('LH length: ' + str(get_total_dist(res[int(res.shape[0]/2):,:,:])))

    if args.report_metrics:
        # logger.info('Average loss: %f.' % (total_loss/count))
        logger.info('Average Dice score: %f.' % (total_dice/count))
예제 #18
0
def main():
    """Build the model, loss/metric, data loaders and trainer from the parsed
    training config, then run training to completion."""
    logger = get_logger('UNet3DTrainer')
    # Prefer the first CUDA device; fall back to CPU when none is present.
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    config = parse_train_config()
    logger.info(config)

    # Optional per-class loss weighting, moved onto the training device.
    loss_weight = None
    if config.loss_weight is not None:
        loss_weight = torch.tensor(config.loss_weight).to(device)

    loss_criterion = get_loss_criterion(config.loss, loss_weight,
                                        config.ignore_index)

    model = UNet3D(config.in_channels,
                   config.out_channels,
                   init_channel_number=config.init_channel_number,
                   conv_layer_order=config.layer_order,
                   interpolate=config.interpolate,
                   final_sigmoid=config.final_sigmoid).to(device)

    # Log the number of learnable parameters
    logger.info(
        f'Number of learnable params {get_number_of_learnable_parameters(model)}'
    )

    # Create evaluation metric
    eval_criterion = get_evaluation_metric(config.eval_metric,
                                           ignore_index=config.ignore_index)

    # 'bce' needs float targets; every other loss expects integer class ids.
    label_dtype = 'float32' if config.loss in ['bce'] else 'long'

    train_patch = tuple(config.train_patch)
    train_stride = tuple(config.train_stride)
    val_patch = tuple(config.val_patch)
    val_stride = tuple(config.val_stride)

    logger.info(f'Train patch/stride: {train_patch}/{train_stride}')
    logger.info(f'Val patch/stride: {val_patch}/{val_stride}')

    loaders = get_loaders(config.train_path,
                          config.val_path,
                          label_dtype=label_dtype,
                          raw_internal_path=config.raw_internal_path,
                          label_internal_path=config.label_internal_path,
                          train_patch=train_patch,
                          train_stride=train_stride,
                          val_patch=val_patch,
                          val_stride=val_stride,
                          transformer=config.transformer,
                          pixel_wise_weight=(config.loss == 'pce'),
                          curriculum_learning=config.curriculum,
                          ignore_index=config.ignore_index)

    # Create the optimizer
    optimizer = _create_optimizer(config, model)

    # Resume from a checkpoint when requested, otherwise start fresh.
    if config.resume:
        trainer = UNet3DTrainer.from_checkpoint(config.resume, model,
                                                optimizer, loss_criterion,
                                                eval_criterion, loaders,
                                                logger=logger)
    else:
        trainer = UNet3DTrainer(model,
                                optimizer,
                                loss_criterion,
                                eval_criterion,
                                device,
                                loaders,
                                config.checkpoint_dir,
                                max_num_epochs=config.epochs,
                                max_num_iterations=config.iters,
                                max_patience=config.patience,
                                validate_after_iters=config.validate_after_iters,
                                log_after_iters=config.log_after_iters,
                                logger=logger)

    trainer.fit()
예제 #19
0
                        nargs='+',
                        default=None,
                        help='Files to use as test cases')
    return parser.parse_args()


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO,
                        format='%(levelname)s: %(message)s')
    args = get_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    input_channels = 1
    output_channels = 2
    net = UNet3D(in_channels=input_channels, out_channels=output_channels)
    logging.info(f'Network:\n'
                 f'\t{input_channels} input channels\n'
                 f'\t{output_channels} output channels (classes)\n')

    if args.load:
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    # faster convolutions, but more memory
    # cudnn.benchmark = True

    try:
        train_net(model=net,
                  epochs=args.epochs,
예제 #20
0
                        "--verbose",
                        default=False,
                        action="store_true",
                        help="Log more detail")

    return parser.parse_args()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
    args = get_args()

    input_channels = 1
    output_channels = 2
    net = UNet3D(in_channels=input_channels,
                 out_channels=output_channels,
                 testing=True,
                 final_sigmoid=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device=device)
    logging.info(f'Using device {device}')

    base_path = Path(__file__).parent

    file_path = (base_path / "../models/test.csv").resolve()
    pre_net = False
    if args.type == 'mask':
        model = (base_path / "../models/BrainMask.pth").resolve()
    elif args.type == 'wm':
        model = (base_path / "../models/WhiteMatter.pth").resolve()
        pre_net = (base_path / "../models/BrainMask.pth").resolve()
예제 #21
0
File: predict.py  Project: aronza/UNet
                        required=True)

    return parser.parse_args()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO,
                        format='%(levelname)s: %(message)s')
    args = get_args()
    in_files = listdir(args.input)
    out_files = [f'OUT_{file}' for file in in_files]

    input_channels = 1
    output_channels = 2
    net = UNet3D(in_channels=input_channels,
                 out_channels=output_channels,
                 testing=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device=device)
    logging.info(f'Using device {device}')

    logging.info("Loading model {}".format(args.model))
    net.load_state_dict(torch.load(args.model, map_location=device))

    logging.info("Model loaded !")

    for i, fn in enumerate(in_files):
        logging.info("\nPredicting image {} ...".format(fn))

        nib_img = nib.load(join(args.input, fn))
예제 #22
0
def train_net(model: UNet3D,
              device,
              loss_fnc=None,
              eval_criterion=None,
              epochs=5,
              batch_size=1,
              learning_rate=0.0002,
              val_percent=0.04,
              test_percent=0.1,
              name='U-Net',
              save_cp=True,
              tests=None):
    """Train `model` on the BasicDataset built from the module-level
    `dir_img`/`dir_mask` paths, validating ~5 times per epoch and optionally
    saving a checkpoint after each epoch.

    Args:
        model: the 3D U-Net to train (trained in place).
        device: torch device the dataset loads tensors onto.
        loss_fnc: training loss; defaults to DiceLoss(sigmoid_normalization=False).
        eval_criterion: validation metric; defaults to MeanIoU().
        epochs, batch_size, learning_rate: usual optimization knobs.
        val_percent, test_percent: dataset split fractions.
        name: run label used in logs, checkpoints and plots.
        save_cp: when True, save `model.state_dict()` each epoch to dir_checkpoint.
        tests: explicit list of files to reserve as the test split.
    """
    # Instantiate the default criteria per call rather than in the `def` line:
    # defaults are evaluated once at import time, so a single (potentially
    # stateful) DiceLoss/MeanIoU instance would be shared by every invocation.
    if loss_fnc is None:
        loss_fnc = DiceLoss(sigmoid_normalization=False)
    if eval_criterion is None:
        eval_criterion = MeanIoU()

    data_set = BasicDataset(dir_img, dir_mask, 'T1', device)
    train_loader, val_loader, test_loader = data_set.split_to_loaders(
        val_percent, test_percent, batch_size, test_files=tests)

    writer = SummaryWriter(comment=f'LR_{learning_rate}_BS_{batch_size}')
    global_step = 0
    logging.info(f'''Starting {name} training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Training size:   {len(train_loader)}
        Validation size: {len(val_loader)}
        Testing size:    {len(test_loader)}
        Checkpoints:     {save_cp}
        Device:          {device.type}
    ''')

    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           weight_decay=0.00001)
    losses = []
    val_scores = []

    # Validate roughly 5 times per epoch. Clamp the interval to >= 1: on a
    # small dataset `len(train_loader) // (5 * batch_size)` is 0 and the
    # original modulo raised ZeroDivisionError.
    val_interval = max(1, len(train_loader) // (5 * batch_size))

    if save_cp:
        # Create the checkpoint directory once, up front, instead of retrying
        # inside every epoch. Best-effort: it may already exist.
        try:
            os.mkdir(dir_checkpoint)
            logging.info('Created checkpoint directory')
        except OSError:
            pass

    for epoch in range(epochs):

        epoch_loss = 0
        for batch in train_loader:
            model.train()
            start_time = timeit.default_timer()

            img = batch['image']
            mask = batch['mask']

            masks_pred = model(img)

            loss = loss_fnc(masks_pred, mask)

            epoch_loss += loss.item()
            losses.append(loss.item())

            writer.add_scalar('Loss/train', loss.item(), global_step)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()

            global_step += 1
            elapsed = timeit.default_timer() - start_time
            logging.info(
                f'I: {global_step}, Loss: {loss.item()} in {elapsed} seconds')

            if global_step % val_interval == 0:
                val_score = validate(model, val_loader, loss_fnc,
                                     eval_criterion)
                val_scores.append(val_score)

                writer.add_scalar('Validation/test', val_score, global_step)

        if save_cp:
            torch.save(model.state_dict(),
                       dir_checkpoint + f'{name}_epoch{epoch + 1}.pth')
            logging.info(f'Epoch: {epoch + 1} Loss: {epoch_loss}')
            logging.info(f'Checkpoint {epoch + 1} saved !')
            plot_cost(losses, name='Loss' + str(epoch), model_name=name)
            plot_cost(val_scores,
                      name='Validation' + str(epoch),
                      model_name=name)

    writer.close()
예제 #23
0
def main():
    """CLI entry point: rebuild the trained UNet3D from command-line
    hyper-parameters, run prediction over an HDF5 test set and save the
    resulting probability maps next to the input file."""
    parser = argparse.ArgumentParser(description='3D U-Net predictions')
    parser.add_argument('--model-path', required=True, type=str,
                        help='path to the model')
    parser.add_argument('--in-channels', required=True, type=int,
                        help='number of input channels')
    parser.add_argument('--out-channels', required=True, type=int,
                        help='number of output channels')
    parser.add_argument('--init-channel-number', type=int, default=64,
                        help='Initial number of feature maps in the encoder path which gets doubled on every stage (default: 64)')
    parser.add_argument('--interpolate', action='store_true',
                        help='use F.interpolate instead of ConvTranspose3d')
    parser.add_argument('--layer-order', type=str, default='crg',
                        help="Conv layer ordering, e.g. 'crg' -> Conv3D+ReLU+GroupNorm")
    parser.add_argument('--final-sigmoid', action='store_true',
                        help='if True apply element-wise nn.Sigmoid after the last layer otherwise apply nn.Softmax')
    parser.add_argument('--test-path', type=str, required=True,
                        help='path to the test dataset')
    parser.add_argument('--raw-internal-path', type=str, default='raw')
    parser.add_argument('--patch', required=True, type=int, nargs='+',
                        default=None,
                        help='Patch shape for used for prediction on the test set')
    parser.add_argument('--stride', required=True, type=int, nargs='+',
                        default=None,
                        help='Patch stride for used for prediction on the test set')

    args = parser.parse_args()

    # Network hyper-parameters must match the ones used during training.
    out_channels = args.out_channels
    final_sigmoid = args.final_sigmoid
    model = UNet3D(args.in_channels,
                   out_channels,
                   init_channel_number=args.init_channel_number,
                   final_sigmoid=final_sigmoid,
                   interpolate=args.interpolate,
                   conv_layer_order=args.layer_order)

    logger.info(f'Loading model from {args.model_path}...')
    utils.load_checkpoint(args.model_path, model)

    logger.info('Loading datasets...')

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        logger.warning('No CUDA device available. Using CPU for predictions')
        device = torch.device('cpu')

    model = model.to(device)

    dataset = HDF5Dataset(args.test_path,
                          tuple(args.patch),
                          tuple(args.stride),
                          phase='test',
                          raw_internal_path=args.raw_internal_path)
    probability_maps = predict(model, dataset, out_channels, device)

    output_file = f'{os.path.splitext(args.test_path)[0]}_probabilities.h5'

    # average channels only in case of final_sigmoid
    save_predictions(probability_maps, output_file, final_sigmoid)
예제 #24
0
def main():
    """Entry point for the combined coordinate-detection ('cd') /
    segmentation ('seg') trainer: pick the network and criteria from
    --network, build the loaders and fit."""
    parser = _arg_parser()
    logger = get_logger('Trainer')
    # Train on the first CUDA device when available, otherwise on CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    args = parser.parse_args()

    # Optional per-class loss weights, moved onto the training device.
    loss_weight = None
    if args.loss_weight is not None:
        loss_weight = torch.tensor(args.loss_weight).to(device)

    if args.network == 'cd':
        # Coordinate detection is a regression task: force MSE regardless of
        # the --loss flag (the label-dtype/weighting logic below reads args.loss).
        args.loss = 'mse'
        loss_criterion = get_loss_criterion('mse', loss_weight,
                                            args.ignore_index)

        model = CoorNet(args.in_channels).to(device)

        accuracy_criterion = PrecisionBasedAccuracy(30)

    elif args.network == 'seg':
        if not args.loss:
            raise ValueError("Invalid loss assigned.")
        loss_criterion = get_loss_criterion(args.loss, loss_weight,
                                            args.ignore_index)

        model = UNet3D(args.in_channels,
                       args.out_channels,
                       init_channel_number=args.init_channel_number,
                       conv_layer_order=args.layer_order,
                       interpolate=True,
                       final_sigmoid=args.final_sigmoid).to(device)

        accuracy_criterion = DiceCoefficient(ignore_index=args.ignore_index)

    else:
        raise ValueError(
            "Incorrect network type defined by the --network argument, either cd or seg."
        )

    # 'bce'/'mse' losses need float targets; everything else expects class ids.
    if args.loss in ['bce', 'mse']:
        label_dtype = 'float32'
    else:
        label_dtype = 'long'

    loaders = get_loaders(args.train_path,
                          label_dtype=label_dtype,
                          raw_internal_path=args.raw_internal_path,
                          label_internal_path=args.label_internal_path,
                          train_patch=tuple(args.train_patch),
                          train_stride=tuple(args.train_stride),
                          transformer=args.transformer,
                          pixel_wise_weight=(args.loss == 'pce'),
                          curriculum_learning=args.curriculum,
                          ignore_index=args.ignore_index)

    # Create the optimizer
    optimizer = _create_optimizer(args, model)

    # Resume from a checkpoint when requested, otherwise start fresh.
    if args.resume:
        trainer = UNet3DTrainer.from_checkpoint(args.resume, model, optimizer,
                                                loss_criterion,
                                                accuracy_criterion, loaders,
                                                logger=logger)
    else:
        trainer = UNet3DTrainer(model, optimizer, loss_criterion,
                                accuracy_criterion, device, loaders,
                                args.checkpoint_dir,
                                max_num_epochs=args.epochs,
                                max_num_iterations=args.iters,
                                max_patience=args.patience,
                                validate_after_iters=args.validate_after_iters,
                                log_after_iters=args.log_after_iters,
                                logger=logger)

    trainer.fit()