Example #1
def main():
    # Create main logger
    logger = get_logger('UNet3DTrainer')

    # Load and log experiment configuration
    config = load_config()
    logger.info(config)

    manual_seed = config.get('manual_seed', None)
    if manual_seed is not None:
        logger.info(f'Seed the RNG for all devices with {manual_seed}')
        torch.manual_seed(manual_seed)
        # see https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Create the model
    model = get_model(config)
    # put the model on GPUs
    logger.info(f"Sending the model to '{config['device']}'")

    model = model.to(config['device'])
    # Log the number of learnable parameters
    logger.info(
        f'Number of learnable params {get_number_of_learnable_parameters(model)}'
    )

    # Create loss criterion
    loss_criterion = get_loss_criterion(config)
    # Create evaluation metric
    eval_criterion = get_evaluation_metric(config)

    # Cross validation
    path_to_folder = config['loaders']['all_data_path'][0]
    cross_validation = CrossValidation(path_to_folder, 1, 3, 2)
    train_set = cross_validation.train_filepaths
    val_set = cross_validation.validation_filepaths
    config['loaders']['train_path'] = train_set
    config['loaders']['val_path'] = val_set

    # Create data loaders
    loaders = get_train_loaders(config)

    # Create the optimizer
    optimizer = _create_optimizer(config, model)

    # Create learning rate adjustment strategy
    lr_scheduler = _create_lr_scheduler(config, optimizer)

    # Create model trainer
    trainer = _create_trainer(config,
                              model=model,
                              optimizer=optimizer,
                              lr_scheduler=lr_scheduler,
                              loss_criterion=loss_criterion,
                              eval_criterion=eval_criterion,
                              loaders=loaders,
                              logger=logger)
    # Start training
    trainer.fit()
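
A note on the configuration consumed above: load_config() returns a nested dictionary, typically parsed from a YAML file. The following is a hypothetical minimal sketch covering only the keys this example reads; the concrete values and the 'optimizer' layout are assumptions, not the project's actual schema.

# Hypothetical minimal config for Example #1 (values are placeholders)
config = {
    'manual_seed': 42,
    'device': 'cuda:0',                          # or 'cpu'
    'loaders': {
        'all_data_path': ['data/all_volumes'],   # consumed by CrossValidation
        # 'train_path' / 'val_path' are filled in by the cross-validation split
    },
    'loss': {'name': 'DiceLoss'},                # key layout as in Example #4
    'eval_metric': {'name': 'DiceCoefficient'},  # key layout as in Example #4
    'lr_scheduler': {'name': 'MultiStepLR'},     # as set in Example #4
    'optimizer': {'learning_rate': 1e-4},        # assumed layout
    'trainer': {'checkpoint_dir': 'checkpoints/'},  # key used in Example #7
}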
Example #2
def main():
    # Load configuration
    config = load_config()

    # create logger
    logfile = config.get('logfile', None)
    logger = utils.get_logger('UNet3DPredictor', logfile=logfile)

    # Create the model
    model = get_model(config)

    # multiple GPUs
    if torch.cuda.device_count() > 1:
        logger.info(f"There are {torch.cuda.device_count()} GPUs available")
        model = nn.DataParallel(model)

    # Load model state
    model_path = config['model_path']
    logger.info(f'Loading model from {model_path}...')
    utils.load_checkpoint(model_path, model)
    logger.info(f"Sending the model to '{config['device']}'")
    model = model.to(config['device'])

    logger.info('Loading HDF5 datasets...')
    for test_loader in get_test_loaders(config):
        logger.info(f"Processing '{test_loader.dataset.file_path}'...")

        #output_file = _get_output_file(test_loader.dataset)
        output_file = _get_output_file(config['output_folder'],
                                       test_loader.dataset)
        logger.info(output_file)
        predictor = _get_predictor(model, test_loader, output_file, config)
        # run the model prediction on the entire dataset and save to the 'output_file' H5
        predictor.predict()
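
_get_output_file is not shown in this snippet. A plausible stand-in, assuming it only derives an output HDF5 path inside output_folder from the dataset's input file name (the helper body and the '_predictions' suffix are assumptions):

import os

def _get_output_file(output_folder, dataset, suffix='_predictions'):
    # hypothetical sketch: <output_folder>/<input basename><suffix>.h5
    name, _ = os.path.splitext(os.path.basename(dataset.file_path))
    return os.path.join(output_folder, name + suffix + '.h5')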
Example #3
def main():
    logger = get_logger('UNet3DTrainer')

    config = load_config()
    logger.info(config)

    manual_seed = config.get('manual_seed', None)
    if manual_seed is not None:
        logger.info(f'Seed the RNG for all devices with {manual_seed}')
        torch.manual_seed(manual_seed)
        # see https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    #loaders = get_train_loaders(config)

    #print(loaders['test'].__len__())

    #for i, t in enumerate(loaders['test']):
    #    for tt in t:
    #        print(i, tt.shape)

    for loader in get_test_loaders(config):
        print(loader.dataset.__getid__())
        print(loader.dataset.__len__())
Example #4
    def _train_save_load(self, tmpdir, loss, val_metric, model='UNet3D', max_num_epochs=1, log_after_iters=2,
                         validate_after_iters=2, max_num_iterations=4, weight_map=False):
        binary_loss = loss in ['BCEWithLogitsLoss', 'DiceLoss', 'GeneralizedDiceLoss']

        device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

        test_config = copy.deepcopy(CONFIG_BASE)
        test_config['model']['name'] = model
        test_config.update({
            # get device to train on
            'device': device,
            'loss': {'name': loss, 'weight': np.random.rand(2).astype(np.float32)},
            'eval_metric': {'name': val_metric}
        })
        test_config['model']['final_sigmoid'] = binary_loss

        if weight_map:
            test_config['loaders']['weight_internal_path'] = 'weight_map'

        loss_criterion = get_loss_criterion(test_config)
        eval_criterion = get_evaluation_metric(test_config)
        model = get_model(test_config)
        model = model.to(device)

        if loss in ['BCEWithLogitsLoss']:
            label_dtype = 'float32'
        else:
            label_dtype = 'long'
        test_config['loaders']['transformer']['train']['label'][0]['dtype'] = label_dtype
        test_config['loaders']['transformer']['test']['label'][0]['dtype'] = label_dtype

        train, val = TestUNet3DTrainer._create_random_dataset((128, 128, 128), (64, 64, 64), binary_loss)
        test_config['loaders']['train_path'] = [train]
        test_config['loaders']['val_path'] = [val]

        loaders = get_train_loaders(test_config)

        optimizer = _create_optimizer(test_config, model)

        test_config['lr_scheduler']['name'] = 'MultiStepLR'
        lr_scheduler = _create_lr_scheduler(test_config, optimizer)

        logger = get_logger('UNet3DTrainer', logging.DEBUG)

        formatter = DefaultTensorboardFormatter()
        trainer = UNet3DTrainer(model, optimizer, lr_scheduler,
                                loss_criterion, eval_criterion,
                                device, loaders, tmpdir,
                                max_num_epochs=max_num_epochs,
                                log_after_iters=log_after_iters,
                                validate_after_iters=validate_after_iters,
                                max_num_iterations=max_num_iterations,
                                logger=logger, tensorboard_formatter=formatter)
        trainer.fit()
        # test loading the trainer from the checkpoint
        trainer = UNet3DTrainer.from_checkpoint(os.path.join(tmpdir, 'last_checkpoint.pytorch'),
                                                model, optimizer, lr_scheduler,
                                                loss_criterion, eval_criterion,
                                                loaders, logger=logger, tensorboard_formatter=formatter)
        return trainer
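
A sketch of how such a helper is usually driven from individual pytest cases via the tmpdir fixture; the loss and metric names below are merely plausible values, not taken from the project's test suite.

    def test_dice_loss(self, tmpdir):
        # 'DiceLoss' appears in this helper's binary_loss list; the metric name is an assumption
        trainer = self._train_save_load(tmpdir, loss='DiceLoss', val_metric='DiceCoefficient')
        assert trainer is not None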
Example #5
def main():
    parser = _arg_parser()
    logger = get_logger('UNet3DTrainer')
    # Get device to train on
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    args = parser.parse_args()

    logger.info(args)

    # Create loss criterion
    loss_criterion, final_sigmoid = _get_loss_criterion(args.loss)

    model = _create_model(args.in_channels,
                          args.out_channels,
                          layer_order=args.layer_order,
                          interpolate=args.interpolate,
                          final_sigmoid=final_sigmoid)

    model = model.to(device)

    # Log the number of learnable parameters
    logger.info(
        f'Number of learnable params {get_number_of_learnable_parameters(model)}'
    )

    # Create error criterion
    error_criterion = DiceCoefficient()

    # Get data loaders
    loaders = _get_loaders(args.config_dir, logger)

    # Create the optimizer
    optimizer = _create_optimizer(args, model)

    if args.resume:
        trainer = UNet3DTrainer.from_checkpoint(
            args.resume,
            model,
            optimizer,
            loss_criterion,
            error_criterion,
            loaders,
            validate_after_iters=args.validate_after_iters,
            log_after_iters=args.log_after_iters,
            logger=logger)
    else:
        trainer = UNet3DTrainer(model,
                                optimizer,
                                loss_criterion,
                                error_criterion,
                                device,
                                loaders,
                                args.checkpoint_dir,
                                validate_after_iters=args.validate_after_iters,
                                log_after_iters=args.log_after_iters,
                                logger=logger)

    trainer.fit()
Example #6
 def __init__(self, model, loader, output_file, config, **kwargs):
     self.model = model
     self.loader = loader
     self.output_file = output_file
     self.config = config
     self.predictor_config = kwargs
     self.logfile = self.config.get('logfile', None)
     self.logger = get_logger('UNet3DTrainer', logfile=self.logfile)
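
The matching predict() method is not shown. A minimal sketch of what such a loop could look like; the batch structure yielded by the loader and the patch-reassembly step are assumptions, not the project's actual implementation.

 def predict(self):
     """Hypothetical sketch: run inference patch by patch under no_grad."""
     device = self.config['device']
     self.model.eval()
     with torch.no_grad():
         for batch, indices in self.loader:       # assumed (patches, slice indices) batches
             batch = batch.to(device)
             predictions = self.model(batch)
             # reassembling 'predictions' into the output volume via 'indices'
             # and writing to self.output_file (e.g. with h5py) is omitted here
             self.logger.info(f'Predicted batch of shape {tuple(predictions.shape)}')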
Example #7
def main():
    # Load and log experiment configuration
    config = load_config()

    # Create main logger
    logger = get_logger('UNet3DTrainer',
                        file_name=config['trainer']['checkpoint_dir'])
    logger.info(config)

    os.environ['CUDA_VISIBLE_DEVICES'] = config['default_device']
    assert torch.cuda.is_available(), "Currently, only the CUDA version is supported"

    manual_seed = config.get('manual_seed', None)
    if manual_seed is not None:
        logger.info(f'Seed the RNG for all devices with {manual_seed}')
        torch.manual_seed(manual_seed)
        # see https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Create the model
    model = get_model(config)
    # model, parameters = generate_model(MedConfig)

    # put the model on GPUs
    logger.info(f"Sending the model to '{config['default_device']}'")
    model = torch.nn.DataParallel(model).cuda()

    # Log the number of learnable parameters
    logger.info(
        f'Number of learnable params {get_number_of_learnable_parameters(model)}'
    )

    # Create loss criterion
    loss_criterion = get_loss_criterion(config)
    # Create evaluation metric
    eval_criterion = get_evaluation_metric(config)

    # Create data loaders
    loaders = get_brats_train_loaders(config)

    # Create the optimizer
    optimizer = _create_optimizer(config, model)

    # Create learning rate adjustment strategy
    lr_scheduler = _create_lr_scheduler(config, optimizer)

    # Create model trainer
    trainer = _create_trainer(config,
                              model=model,
                              optimizer=optimizer,
                              lr_scheduler=lr_scheduler,
                              loss_criterion=loss_criterion,
                              eval_criterion=eval_criterion,
                              loaders=loaders,
                              logger=logger)
    # Start training
    trainer.fit()
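
CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is initialized, which is why it is written right after loading the configuration. A tiny standalone illustration (the device list is an example):

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'   # must be set before the first CUDA call
import torch
# torch now sees only the two listed GPUs; DataParallel replicates across them
print(torch.cuda.device_count())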
Example #8
def main():
    # Load and log experiment configuration
    config = load_config()
    
    # Create main logger
    logfile = config.get('logfile', None)
    logger = get_logger('UNet3DTrainer', logfile=logfile)

    logger.info(config)

    manual_seed = config.get('manual_seed', None)
    if manual_seed is not None:
        logger.info(f'Seed the RNG for all devices with {manual_seed}')
        torch.manual_seed(manual_seed)
        # see https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Create the model
    model = get_model(config)

    # multiple GPUs
    if torch.cuda.device_count() > 1:
        logger.info(f"There are {torch.cuda.device_count()} GPUs available")
        model = nn.DataParallel(model)

    # put the model on GPUs
    logger.info(f"Sending the model to '{config['device']}'")
    model = model.to(config['device'])
    
    # Log the number of learnable parameters
    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

    # Create loss criterion
    loss_criterion = get_loss_criterion(config)
    logger.info(f"Created loss criterion: {config['loss']['name']}")
    
    # Create evaluation metric
    eval_criterion = get_evaluation_metric(config)
    logger.info(f"Created eval criterion: {config['eval_metric']['name']}")

    # Create data loaders
    loaders = get_train_loaders(config)

    # Create the optimizer
    optimizer = _create_optimizer(config, model)

    # Create learning rate adjustment strategy
    lr_scheduler = _create_lr_scheduler(config, optimizer)

    # Create model trainer
    trainer = _create_trainer(config, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler,
                              loss_criterion=loss_criterion, eval_criterion=eval_criterion, loaders=loaders,
                              logger=logger)
    
    # Start training
    trainer.fit()
Example #9
 def _train_save_load(self,
                      tmpdir,
                      loss,
                      max_num_epochs=1,
                      log_after_iters=2,
                      validate_after_iters=2,
                      max_num_iterations=4):
     # get device to train on
     device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
     # conv-relu-groupnorm
     conv_layer_order = 'crg'
     final_sigmoid = loss == 'bce'
     loss_criterion = get_loss_criterion(loss,
                                         final_sigmoid,
                                         weight=torch.rand(2).to(device))
     model = self._create_model(final_sigmoid, conv_layer_order)
     accuracy_criterion = DiceCoefficient()
     channel_per_class = loss == 'bce'
     if loss in ['bce', 'dice']:
         label_dtype = 'float32'
     else:
         label_dtype = 'long'
     pixel_wise_weight = loss == 'pce'
     loaders = self._get_loaders(channel_per_class=channel_per_class,
                                 label_dtype=label_dtype,
                                 pixel_wise_weight=pixel_wise_weight)
     learning_rate = 2e-4
     weight_decay = 0.0001
     optimizer = optim.Adam(model.parameters(),
                            lr=learning_rate,
                            weight_decay=weight_decay)
     logger = get_logger('UNet3DTrainer', logging.DEBUG)
     trainer = UNet3DTrainer(model,
                             optimizer,
                             loss_criterion,
                             accuracy_criterion,
                             device,
                             loaders,
                             tmpdir,
                             max_num_epochs=max_num_epochs,
                             log_after_iters=log_after_iters,
                             validate_after_iters=validate_after_iters,
                             max_num_iterations=max_num_iterations,
                             logger=logger)
     trainer.fit()
     # test loading the trainer from the checkpoint
     trainer = UNet3DTrainer.from_checkpoint(os.path.join(
         tmpdir, 'last_checkpoint.pytorch'),
                                             model,
                                             optimizer,
                                             loss_criterion,
                                             accuracy_criterion,
                                             loaders,
                                             logger=logger)
     return trainer
Example #10
def get_brats_train_loaders(config):
    """
    Returns dictionary containing the training and validation loaders
    (torch.utils.data.DataLoader) backed by the datasets.hdf5.HDF5Dataset.

    :param config: a top level configuration object containing the 'loaders' key
    :return: dict {
        'train': <train_loader>
        'val': <val_loader>
    }
    """
    assert 'loaders' in config, 'Could not find data loaders configuration'
    loaders_config = config['loaders']

    logger = get_logger('BraTS_Dataset')
    logger.info('Creating training and validation set loaders...')

    # get train and validation files
    data_paths = loaders_config['dataset_path']
    assert isinstance(data_paths, list)

    brats = BraTS.DataSet(brats_root=data_paths[0], year=2019).train
    train_ids, test_ids, validation_ids = get_all_partition_ids()
    # loss_file_num = 0
    # for i in train_ids:
    #     name = i + ".tfrecord.gzip"
    #     answer = search('/home/server/data/TFRecord/val', name)
    #     if answer == -1:
    #         print("File not found:", name)
    #         loss_file_num += 1
    # print(f'loss file num is {loss_file_num}')

    logger.info(f'Loading training set from: {data_paths}...')
    train_datasets = BraTSDataset(brats, train_ids)

    logger.info(f'Loading validation set from: {data_paths}...')
    brats = BraTS.DataSet(brats_root=data_paths[0], year=2019).train
    val_datasets = BraTSDataset(brats, validation_ids)

    num_workers = loaders_config.get('num_workers', 1)
    logger.info(f'Number of workers for train/val datasets: {num_workers}')
    # when training with volumetric data use batch_size of 1 due to GPU memory constraints
    return {
        'train':
        DataLoader(train_datasets,
                   batch_size=1,
                   shuffle=True,
                   num_workers=num_workers),
        'val':
        DataLoader(val_datasets,
                   batch_size=1,
                   shuffle=True,
                   num_workers=num_workers)
    }
Example #11
def get_test_loaders(config):
    """
    Returns a list of DataLoader, one per each test file.

    :param config: a top level configuration object containing the 'datasets' key
    :return: generator of DataLoader objects
    """

    assert 'datasets' in config, 'Could not find data sets configuration'
    datasets_config = config['datasets']

    logger = get_logger('HDF5Dataset', logfile=config['logfile'])

    # get train and validation files
    test_paths = datasets_config['test_path']
    #print(test_paths)
    assert isinstance(test_paths, list)
    # get h5 internal path
    raw_internal_path = datasets_config['raw_internal_path']
    # get train/validation patch size and stride
    patch = tuple(datasets_config['patch'])
    stride = tuple(datasets_config['stride'])

    mirror_padding = datasets_config.get('mirror_padding', False)
    pad_width = datasets_config.get('pad_width', 20)

    if mirror_padding:
        logger.info(f'Using mirror padding. Pad width: {pad_width}')

    num_workers = datasets_config.get('num_workers', 1)
    logger.info(f'Number of workers for the dataloader: {num_workers}')

    batch_size = datasets_config.get('batch_size', 1)
    logger.info(f'Batch size for dataloader: {batch_size}')

    # construct datasets lazily
    datasets = (HDF5Dataset(test_path,
                            patch,
                            stride,
                            phase='test',
                            raw_internal_path=raw_internal_path,
                            transformer_config=datasets_config['transformer'],
                            mirror_padding=mirror_padding,
                            pad_width=pad_width) for test_path in test_paths)

    # use generator in order to create data loaders lazily one by one
    for dataset in datasets:
        logger.info(f'Loading test set from: {dataset.file_path}...')
        yield DataLoader(dataset,
                         batch_size=batch_size,
                         num_workers=num_workers,
                         collate_fn=prediction_collate)
Example #12
def get_test_loaders(config):
    """
    Returns a list of DataLoader, one per each test file.

    :param config: a top level configuration object containing the 'datasets' key
    :return: generator of DataLoader objects
    """
    def my_collate(batch):
        error_msg = "batch must contain tensors or slice; found {}"
        if isinstance(batch[0], torch.Tensor):
            return torch.stack(batch, 0)
        elif isinstance(batch[0], slice):
            return batch[0]
        elif isinstance(batch[0], collections.abc.Sequence):  # collections.Sequence was removed in Python 3.10
            transposed = zip(*batch)
            return [my_collate(samples) for samples in transposed]

        raise TypeError((error_msg.format(type(batch[0]))))

    logger = get_logger('HDF5Dataset')

    assert 'datasets' in config, 'Could not find data sets configuration'
    datasets_config = config['datasets']

    # get train and validation files
    test_paths = datasets_config['test_path']
    assert isinstance(test_paths, list)
    # get h5 internal path
    raw_internal_path = datasets_config['raw_internal_path']
    # get train/validation patch size and stride
    patch = tuple(datasets_config['patch'])
    stride = tuple(datasets_config['stride'])
    num_workers = datasets_config.get('num_workers', 1)

    # construct datasets lazily
    datasets = (HDF5Dataset(test_path,
                            patch,
                            stride,
                            phase='test',
                            raw_internal_path=raw_internal_path,
                            transformer_config=datasets_config['transformer'])
                for test_path in test_paths)

    # use generator in order to create data loaders lazily one by one
    for dataset in datasets:
        logger.info(f'Loading test set from: {dataset.file_path}...')
        yield DataLoader(dataset,
                         batch_size=1,
                         num_workers=num_workers,
                         collate_fn=my_collate)
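
For illustration, a standalone sketch that mirrors the my_collate logic above on a made-up batch of (patch, slice) pairs; shapes and indices are arbitrary.

import collections.abc
import torch

def demo_collate(batch):
    # mirrors my_collate: stack tensors, keep the first slice, recurse into sequences
    if isinstance(batch[0], torch.Tensor):
        return torch.stack(batch, 0)
    elif isinstance(batch[0], slice):
        return batch[0]
    elif isinstance(batch[0], collections.abc.Sequence):
        return [demo_collate(samples) for samples in zip(*batch)]
    raise TypeError(f"batch must contain tensors or slice; found {type(batch[0])}")

batch = [
    (torch.zeros(1, 32, 64, 64), slice(0, 32)),
    (torch.zeros(1, 32, 64, 64), slice(32, 64)),
]
patches, first_slice = demo_collate(batch)
print(patches.shape)   # torch.Size([2, 1, 32, 64, 64])
print(first_slice)     # slice(0, 32, None)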
Example #13
def main():
    logger = get_logger('UNet3DTrainer')

    config = load_config()

    logger.info(config)

    # Create loss criterion
    loss_criterion = get_loss_criterion(config)

    # Create the model
    model = UNet3D(config['in_channels'], config['out_channels'],
                   final_sigmoid=config['final_sigmoid'],
                   init_channel_number=config['init_channel_number'],
                   conv_layer_order=config['layer_order'],
                   interpolate=config['interpolate'])

    model = model.to(config['device'])

    # Log the number of learnable parameters
    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

    # Create evaluation metric
    eval_criterion = get_evaluation_metric(config)

    loaders = get_train_loaders(config)

    # Create the optimizer
    optimizer = _create_optimizer(config, model)

    # Create learning rate adjustment strategy
    lr_scheduler = _create_lr_scheduler(config, optimizer)

    if config['resume'] is not None:
        trainer = UNet3DTrainer.from_checkpoint(config['resume'], model,
                                                optimizer, lr_scheduler, loss_criterion,
                                                eval_criterion, loaders,
                                                logger=logger)
    else:
        trainer = UNet3DTrainer(model, optimizer, lr_scheduler, loss_criterion, eval_criterion,
                                config['device'], loaders, config['checkpoint_dir'],
                                max_num_epochs=config['epochs'],
                                max_num_iterations=config['iters'],
                                validate_after_iters=config['validate_after_iters'],
                                log_after_iters=config['log_after_iters'],
                                logger=logger)

    trainer.fit()
Example #14
    def test_single_epoch(self, tmpdir, capsys):
        with capsys.disabled():
            # get device to train on
            device = torch.device(
                "cuda:0" if torch.cuda.is_available() else 'cpu')

            conv_layer_order = 'crg'

            loss_criterion, final_sigmoid = DiceLoss(), True

            model = self._load_model(final_sigmoid, conv_layer_order)

            error_criterion = DiceCoefficient()

            loaders = self._get_loaders()

            learning_rate = 1e-4
            weight_decay = 0.0005
            optimizer = optim.Adam(model.parameters(),
                                   lr=learning_rate,
                                   weight_decay=weight_decay)

            logger = get_logger('UNet3DTrainer', logging.DEBUG)
            trainer = UNet3DTrainer(model,
                                    optimizer,
                                    loss_criterion,
                                    error_criterion,
                                    device,
                                    loaders,
                                    tmpdir,
                                    max_num_epochs=1,
                                    log_after_iters=2,
                                    validate_after_iters=2,
                                    logger=logger)

            trainer.fit()

            # test loading the trainer from the checkpoint
            UNet3DTrainer.from_checkpoint(os.path.join(
                tmpdir, 'last_checkpoint.pytorch'),
                                          model,
                                          optimizer,
                                          loss_criterion,
                                          error_criterion,
                                          loaders,
                                          logger=logger)
Example #15
def main():
    # Create main logger
    logger = get_logger('UNet3DTrainer')

    # Load and log experiment configuration
    config = load_config()  # sets DEFAULT_DEVICE and loads the config file
    logger.info(config)     # log the configuration loaded from train_config_4d_input.yaml

    manual_seed = config.get('manual_seed', None)
    if manual_seed is not None:
        logger.info(f'Seed the RNG for all devices with {manual_seed}')
        torch.manual_seed(manual_seed)
        torch.backends.cudnn.deterministic = True  # Ensure the repeatability of the experiment
        torch.backends.cudnn.benchmark = False     # benchmark mode speeds up computation but can make results slightly non-deterministic

    # Create the model
    model = get_model(config)
    # put the model on GPUs
    logger.info(f"Sending the model to '{config['device']}'")
    model = model.to(config['device'])
    # Log the number of learnable parameters
    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

    # Create loss criterion
    loss_criterion = get_loss_criterion(config)
    # Create evaluation metric
    eval_criterion = get_evaluation_metric(config)

    # Create data loaders
    # loaders: {'train': train_loader, 'val': val_loader}
    loaders = get_train_loaders(config)

    # Create the optimizer
    optimizer = _create_optimizer(config, model)

    # Create learning rate adjustment strategy
    lr_scheduler = _create_lr_scheduler(config, optimizer)

    # Create model trainer
    trainer = _create_trainer(config, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler,
                              loss_criterion=loss_criterion, eval_criterion=eval_criterion, loaders=loaders,
                              logger=logger)
    # Start training
    trainer.fit()
Example #16
    def _train_save_load(self, tmpdir, loss, val_metric, max_num_epochs=1, log_after_iters=2, validate_after_iters=2,
                         max_num_iterations=4):
        # get device to train on
        device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
        # conv-relu-groupnorm
        conv_layer_order = 'crg'
        final_sigmoid = loss == 'bce'
        loss_criterion = get_loss_criterion(loss, weight=torch.rand(2).to(device))
        eval_criterion = get_evaluation_metric(val_metric)
        model = self._create_model(final_sigmoid, conv_layer_order)
        channel_per_class = loss == 'bce'
        if loss in ['bce']:
            label_dtype = 'float32'
        else:
            label_dtype = 'long'
        pixel_wise_weight = loss == 'pce'

        patch = (32, 64, 64)
        stride = (32, 64, 64)
        train, val = TestUNet3DTrainer._create_random_dataset((128, 128, 128), (64, 64, 64), channel_per_class)
        loaders = get_loaders([train], [val], 'raw', 'label', label_dtype=label_dtype, train_patch=patch,
                              train_stride=stride, val_patch=patch, val_stride=stride, transformer='BaseTransformer',
                              pixel_wise_weight=pixel_wise_weight)

        learning_rate = 2e-4
        weight_decay = 0.0001
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        logger = get_logger('UNet3DTrainer', logging.DEBUG)
        trainer = UNet3DTrainer(model, optimizer, loss_criterion,
                                eval_criterion,
                                device, loaders, tmpdir,
                                max_num_epochs=max_num_epochs,
                                log_after_iters=log_after_iters,
                                validate_after_iters=validate_after_iters,
                                max_num_iterations=max_num_iterations,
                                logger=logger)
        trainer.fit()
        # test loading the trainer from the checkpoint
        trainer = UNet3DTrainer.from_checkpoint(
            os.path.join(tmpdir, 'last_checkpoint.pytorch'),
            model, optimizer, loss_criterion, eval_criterion, loaders,
            logger=logger)
        return trainer
Example #17
import importlib

import os
import numpy as np
import time
import torch
import torch.nn.functional as F
from skimage import measure
import hdbscan
from sklearn.cluster import MeanShift

from unet3d.losses import compute_per_channel_dice
from unet3d.utils import get_logger, adapted_rand, expand_as_one_hot, plot_segm

LOGGER = get_logger('EvalMetric')


class DiceCoefficient:
    """Computes Dice Coefficient.
    Generalized to multiple channels by computing per-channel Dice Score
    (as described in https://arxiv.org/pdf/1707.03237.pdf) and then simply taking the average.
    Input is expected to be probabilities instead of logits.
    This metric is mostly useful when channels contain the same semantic class (e.g. affinities computed with different offsets).
    DO NOT USE this metric when training with DiceLoss, otherwise the results will be biased towards the loss.
    """
    def __init__(self, epsilon=1e-5, ignore_index=None, **kwargs):
        self.epsilon = epsilon
        self.ignore_index = ignore_index

    def __call__(self, input, target):
        """
Example #18
    def __init__(self,
                 file_path,
                 patch_shape,
                 stride_shape,
                 phase,
                 transformer_config,
                 raw_internal_path='raw',
                 label_internal_path='label',
                 weight_internal_path=None,
                 slice_builder_cls=SliceBuilder,
                 mirror_padding=False,
                 pad_width=20,
                 logfile=None):
        """
        :param file_path: path to H5 file containing raw data as well as labels and per pixel weights (optional)
        :param patch_shape: the shape of the patch DxHxW
        :param stride_shape: the shape of the stride DxHxW
        :param phase: 'train' for training, 'val' for validation, 'test' for testing; data augmentation is performed
            only during the 'train' phase
        :param transformer_config: data augmentation configuration
        :param raw_internal_path (str or list): H5 internal path to the raw dataset
        :param label_internal_path (str or list): H5 internal path to the label dataset
        :param weight_internal_path (str or list): H5 internal path to the per pixel weights
        :param slice_builder_cls: defines how to sample the patches from the volume
        :param mirror_padding (bool): pad with the reflection of the vector mirrored on the first and last values
            along each axis. Only applicable during the 'test' phase
        :param pad_width: number of voxels padded to the edges of each axis (only if `mirror_padding=True`)
        """
        assert phase in ['train', 'val', 'test']
        self._check_patch_shape(patch_shape)
        self.phase = phase
        self.file_path = file_path

        self.mirror_padding = mirror_padding
        self.pad_width = pad_width
        self.logfile = logfile

        self.logger = get_logger('HDF5Dataset', logfile=self.logfile)

        # convert raw_internal_path, label_internal_path and weight_internal_path to list for ease of computation
        if isinstance(raw_internal_path, str):
            raw_internal_path = [raw_internal_path]
        if isinstance(label_internal_path, str):
            label_internal_path = [label_internal_path]
        if isinstance(weight_internal_path, str):
            weight_internal_path = [weight_internal_path]

        with h5py.File(file_path, 'r') as input_file:
            # WARN: we load everything into memory due to hdf5 bug when reading H5 from multiple subprocesses, i.e.
            # File "h5py/_proxy.pyx", line 84, in h5py._proxy.H5PY_H5Dread
            # OSError: Can't read data (inflate() failed)
            self.raws = [
                input_file[internal_path][...]
                for internal_path in raw_internal_path
            ]
            # calculate global mean and std for Normalization augmentation
            mean, std = self._calculate_mean_std(self.raws[0])

            self.transformer = transforms.get_transformer(
                transformer_config, mean, std, phase)
            self.raw_transform = self.transformer.raw_transform()

            if phase != 'test':
                # create label/weight transform only in train/val phase
                self.label_transform = self.transformer.label_transform()
                self.labels = [
                    input_file[internal_path][...]
                    for internal_path in label_internal_path
                ]

                if weight_internal_path is not None:
                    # look for the weight map in the raw file
                    self.weight_maps = [
                        input_file[internal_path][...]
                        for internal_path in weight_internal_path
                    ]
                    self.weight_transform = self.transformer.weight_transform()
                else:
                    self.weight_maps = None

                self._check_dimensionality(self.raws, self.labels)
            else:
                # 'test' phase used only for predictions so ignore the label dataset
                self.labels = None
                self.weight_maps = None

                # add mirror padding if needed
                if self.mirror_padding:
                    padded_volumes = []
                    for raw in self.raws:
                        if raw.ndim == 4:
                            channels = [
                                np.pad(r,
                                       pad_width=self.pad_width,
                                       mode='reflect') for r in raw
                            ]
                            padded_volume = np.stack(channels)
                        else:
                            padded_volume = np.pad(raw,
                                                   pad_width=self.pad_width,
                                                   mode='reflect')

                        padded_volumes.append(padded_volume)

                    self.raws = padded_volumes

            # build slice indices for raw and label data sets
            slice_builder = slice_builder_cls(self.raws, self.labels,
                                              self.weight_maps, patch_shape,
                                              stride_shape)
            self.raw_slices = slice_builder.raw_slices
            self.label_slices = slice_builder.label_slices
            self.weight_slices = slice_builder.weight_slices

            self.patch_count = len(self.raw_slices)
            self.logger.info(f'Number of patches: {self.patch_count}')
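
For reference, a tiny script that creates an HDF5 file with the default internal layout this dataset expects ('raw' in every phase, 'label' only for train/val); the shapes and dtypes are arbitrary.

import h5py
import numpy as np

with h5py.File('toy_volume.h5', 'w') as f:
    f.create_dataset('raw', data=np.random.rand(64, 128, 128).astype('float32'))
    f.create_dataset('label', data=np.random.randint(0, 2, size=(64, 128, 128), dtype='int64'))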
Example #19
import os

import h5py
import numpy as np
import torch

from datasets.hdf5 import get_test_datasets
from unet3d import utils
from unet3d.config import load_config
from unet3d.model import get_model

logger = utils.get_logger('UNet3DPredictor')


def predict(model, hdf5_dataset, config):
    """
    Return prediction masks by applying the model on the given dataset

    Args:
        model (Unet3D): trained 3D UNet model used for prediction
        hdf5_dataset (torch.utils.data.Dataset): input dataset
        out_channels (int): number of channels in the network output
        device (torch.Device): device to run the prediction on

    Returns:
         prediction_maps (numpy array): prediction masks for given dataset
    """

    def _volume_shape(hdf5_dataset):
        #TODO: support multiple internal datasets
        raw = hdf5_dataset.raws[0]
Example #20
import hdbscan
import numpy as np
import torch
from skimage import measure
from skimage.metrics import adapted_rand_error, peak_signal_noise_ratio
from sklearn.cluster import MeanShift

# from pytorch3dunet.unet3d.losses import compute_per_channel_dice
# from pytorch3dunet.unet3d.seg_metrics import AveragePrecision, Accuracy
# from pytorch3dunet.unet3d.utils import get_logger, expand_as_one_hot, plot_segm, convert_to_numpy
from unet3d.losses import compute_per_channel_dice
from unet3d.seg_metrics import AveragePrecision, Accuracy
from unet3d.utils import get_logger, expand_as_one_hot, plot_segm, convert_to_numpy

logger = get_logger('EvalMetric')


class DiceCoefficient:
    """Computes Dice Coefficient.
    Generalized to multiple channels by computing per-channel Dice Score
    (as described in https://arxiv.org/pdf/1707.03237.pdf) and then simply taking the average.
    Input is expected to be probabilities instead of logits.
    This metric is mostly useful when channels contain the same semantic class (e.g. affinities computed with different offsets).
    DO NOT USE this metric when training with DiceLoss, otherwise the results will be biased towards the loss.
    """
    def __init__(self, epsilon=1e-6, **kwargs):
        self.epsilon = epsilon

    def __call__(self, input, target):
        # Average across channels in order to get the final score
Example #21
import collections
import importlib

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset

import augment.transforms as transforms
from unet3d.utils import get_logger

logger = get_logger('HDF5Dataset')


class SliceBuilder:
    def __init__(self, raw_datasets, label_datasets, weight_dataset, patch_shape, stride_shape):
        self._raw_slices = self._build_slices(raw_datasets[0], patch_shape, stride_shape)
        if label_datasets is None:
            self._label_slices = None
        else:
            # take the first element in the label_datasets to build slices
            self._label_slices = self._build_slices(label_datasets[0], patch_shape, stride_shape)
            assert len(self._raw_slices) == len(self._label_slices)
        if weight_dataset is None:
            self._weight_slices = None
        else:
            self._weight_slices = self._build_slices(weight_dataset[0], patch_shape, stride_shape)
            assert len(self.raw_slices) == len(self._weight_slices)

    @property
    def raw_slices(self):
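
The snippet ends before _build_slices is shown. A simplified standalone sketch of the sliding-window enumeration it is assumed to perform over a 3D volume (border handling in the real class may differ):

def build_slices_3d(volume_shape, patch_shape, stride_shape):
    """Enumerate (z, y, x) slice tuples covering a 3D volume with a sliding window."""
    slices = []
    i_d, i_h, i_w = volume_shape
    p_d, p_h, p_w = patch_shape
    s_d, s_h, s_w = stride_shape
    for z in range(0, i_d - p_d + 1, s_d):
        for y in range(0, i_h - p_h + 1, s_h):
            for x in range(0, i_w - p_w + 1, s_w):
                slices.append((slice(z, z + p_d), slice(y, y + p_h), slice(x, x + p_w)))
    return slices

print(len(build_slices_3d((128, 128, 128), (32, 64, 64), (16, 32, 32))))  # 63 patches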
Example #22
def get_test_loaders(config):
    """
    Returns a list of DataLoader, one per each test file.

    :param config: a top level configuration object containing the 'datasets' key
    :return: generator of DataLoader objects
    """

    assert 'datasets' in config, 'Could not find data sets configuration'
    datasets_config = config['datasets']

    logfile = config.get('logfile', None)
    logger = get_logger('CloudVolumeDataset', logfile=logfile)

    # get test data information
    image_cv_path = datasets_config['image_cv_path']
    seg_cv_path = datasets_config['seg_cv_path']
    #cutout_pkl_file = datasets_config['cutout_pkl_file']
    mip_level = datasets_config['mip_level']
    volume_start = datasets_config['volume_start']
    volume_end = datasets_config['volume_end']
    id = datasets_config['id']

    cutout_bounds = [
        volume_start[0], volume_end[0], volume_start[1], volume_end[1],
        volume_start[2], volume_end[2]
    ]

    image_cv = cloudvolume.CloudVolume(image_cv_path,
                                       mip=mip_level,
                                       bounded=False,
                                       autocrop=True,
                                       fill_missing=True)

    seg_cv = None
    if seg_cv_path is not None:
        seg_cv = cloudvolume.CloudVolume(seg_cv_path,
                                         mip=mip_level,
                                         bounded=False,
                                         autocrop=True,
                                         fill_missing=True)

    patch = tuple(datasets_config['patch'])
    stride = tuple(datasets_config['stride'])

    mirror_padding = datasets_config.get('mirror_padding', False)
    pad_width = datasets_config.get('pad_width', 20)

    if mirror_padding:
        logger.info(f'Using mirror padding. Pad width: {pad_width}')

    num_workers = datasets_config.get('num_workers', 1)
    logger.info(f'Number of workers for the dataloader: {num_workers}')

    batch_size = datasets_config.get('batch_size', 1)
    logger.info(f'Batch size for dataloader: {batch_size}')

    #with open(cutout_pkl_file, 'rb') as f:
    #    df = pickle.load(f)

    test_datasets = []

    test_dataset = CloudVolumeDataset(
        image_cv,
        seg_cv,
        id,
        cutout_bounds,
        mip_level,
        'test',
        patch,
        stride,
        transformer_config=datasets_config['transformer'],
        mirror_padding=mirror_padding,
        pad_width=pad_width)
    test_datasets.append(test_dataset)

    #for indx, row in df.iterrows():
    #    if row['id'] == 0 or row['phase'] != 'test':
    #        continue
    #    test_dataset = CloudVolumeDataset(image_cv, seg_cv, row['id'], list(row['cutout_bounds']),
    #                                      mip_level, 'test', patch, stride,
    #                                      transformer_config=datasets_config['transformer'],
    #                                      mirror_padding=mirror_padding, pad_width=pad_width)
    #    test_datasets.append(test_dataset)

    # use generator in order to create data loaders lazily one by one
    for indx, dataset in enumerate(test_datasets):
        logger.info(f'Loading test set no: {indx}')
        yield DataLoader(dataset,
                         batch_size=batch_size,
                         num_workers=num_workers,
                         collate_fn=prediction_collate)
Example #23
    def __init__(self,
                 image_cv,
                 seg_cv,
                 id,
                 bounds,
                 mip_level,
                 phase,
                 patch_shape,
                 stride_shape,
                 transformer_config,
                 slice_builder_cls=SliceBuilder,
                 mirror_padding=False,
                 pad_width=20,
                 logfile=None):

        assert phase in ['train', 'val', 'test']
        self.phase = phase
        self._check_patch_shape(patch_shape)
        self.image_cv = image_cv
        self.seg_cv = seg_cv
        self.id = id
        assert isinstance(bounds, list)
        self.bounds = bounds
        self.mip_level = mip_level

        self.mirror_padding = mirror_padding
        self.pad_width = pad_width

        self.logger = get_logger('CloudVolumeDataset', logfile=logfile)

        minx, maxx, miny, maxy, minz, maxz = self.bounds
        self.raws = []
        img = np.squeeze(self.image_cv[minx:maxx, miny:maxy, minz:maxz, 0])
        # transpose (the data is always CDHW)
        img = np.transpose(img, (2, 0, 1))
        self.raws.append(img)

        mean, std = self._calculate_mean_std(self.raws[0])

        self.transformer = transforms.get_transformer(transformer_config, mean,
                                                      std, phase)
        self.raw_transform = self.transformer.raw_transform()

        if phase != 'test':
            # create label/weight transform only in train/val phase
            self.label_transform = self.transformer.label_transform()
            label = np.squeeze(self.seg_cv[minx:maxx, miny:maxy, minz:maxz, 0])
            label = np.where(label / np.ndarray.max(label) >= 0.2, 1, 0)
            # transpose the label (the data is always CDHW)
            label = np.transpose(label, (2, 0, 1))
            self.labels = [label]

            self.weight_maps = None

            self._check_dimensionality(self.raws, self.labels)
        else:
            self.labels = None
            self.weight_maps = None

            # add mirror padding if needed
            if self.mirror_padding:
                padded_volumes = [
                    np.pad(raw, pad_width=self.pad_width, mode='reflect')
                    for raw in self.raws
                ]
                self.raws = padded_volumes

        #print(self.raws[0].shape, self.labels[0].shape)
        #print(np.min(self.labels[0][:,:,200]), np.max(self.labels[0][:,:,200]))
        #plt.imshow(self.labels[0][:,:,200])
        #plt.show()
        #plt.imshow(self.raws[0][:,:,200])
        #plt.show()

        # build slice indices for raw and label data sets
        slice_builder = slice_builder_cls(self.raws, self.labels,
                                          self.weight_maps, patch_shape,
                                          stride_shape)
        self.raw_slices = slice_builder.raw_slices
        self.label_slices = slice_builder.label_slices
        self.weight_slices = slice_builder.weight_slices

        self.patch_count = len(self.raw_slices)
        self.logger.info(f'Number of patches: {self.patch_count}')
Example #24
def get_train_loaders(config):
    """
    Returns dictionary containing the training and validation loaders

    :param config: a top level configuration object containing the 'loaders' key
    :return: dict {
        'train': <train_loader>
        'val': <val_loader>
    }
    """

    assert 'loaders' in config, 'Could not find data loaders configuration'
    loaders_config = config['loaders']

    logger = get_logger('CloudVolumeDataset', logfile=config['logfile'])

    logger.info('Creating training and validation set loaders...')

    # get image cloudvolume path and segmentation mask cv path
    image_cv_path = loaders_config['image_cv_path']
    seg_cv_path = loaders_config['seg_cv_path']
    cutout_pkl_file = loaders_config['cutout_pkl_file']
    mip_level = loaders_config['mip_level']
    #volume_start = dataset_config['volume_start']
    #volume_end = dataset_config['volume_end']

    # get train/validation patch size and stride
    train_patch = tuple(loaders_config['train_patch'])
    train_stride = tuple(loaders_config['train_stride'])
    val_patch = tuple(loaders_config['val_patch'])
    val_stride = tuple(loaders_config['val_stride'])

    # get train slice_builder_cls
    train_slice_builder_str = loaders_config.get('train_slice_builder',
                                                 'SliceBuilder')
    logger.info(f'Train slice builder class: {train_slice_builder_str}')
    train_slice_builder_cls = _get_slice_builder_cls(train_slice_builder_str)

    image_cv = cloudvolume.CloudVolume(image_cv_path,
                                       mip=mip_level,
                                       bounded=False,
                                       autocrop=True,
                                       fill_missing=True)
    seg_cv = cloudvolume.CloudVolume(seg_cv_path,
                                     mip=mip_level,
                                     bounded=False,
                                     autocrop=True,
                                     fill_missing=True)

    df = []
    with open(cutout_pkl_file, 'rb') as f:
        df = pickle.load(f)

    train_datasets = []
    val_datasets = []

    val_slice_builder_str = loaders_config.get('val_slice_builder',
                                               'SliceBuilder')
    logger.info(f'Val slice builder class: {val_slice_builder_str}')
    val_slice_builder = _get_slice_builder_cls(val_slice_builder_str)

    phase = 'train'
    for indx, row in df.iterrows():
        if row['id'] == 0 or row['phase'] == 'test':
            continue
        row['cutout_bounds'] = list(row['cutout_bounds'])
        if row['phase'] == 'train':
            try:
                logger.info(
                    "Loading image and segmentation mask for {}".format(
                        row['id']))
                train_dataset = CloudVolumeDataset(
                    image_cv,
                    seg_cv,
                    row['id'],
                    row['cutout_bounds'],
                    mip_level,
                    'train',
                    train_patch,
                    train_stride,
                    transformer_config=loaders_config['transformer'],
                    slice_builder_cls=train_slice_builder_cls)
                train_datasets.append(train_dataset)
            except Exception:
                logger.info("Skipping training data for: {}".format(row['id']),
                            exc_info=True)
        if row['phase'] == 'val':
            try:
                logger.info(
                    f"Loading image and segmentation mask for {row['id']}")
                val_dataset = CloudVolumeDataset(
                    image_cv,
                    seg_cv,
                    row['id'],
                    row['cutout_bounds'],
                    mip_level,
                    'val',
                    val_patch,
                    val_stride,
                    transformer_config=loaders_config['transformer'],
                    slice_builder_cls=val_slice_builder)
                val_datasets.append(val_dataset)
            except Exception:
                logger.info(f"Skipping validation data for: {row['id']}",
                            exc_info=True)

    num_workers = loaders_config.get('num_workers', 1)
    logger.info(f'Number of workers for train/val dataloader: {num_workers}')
    batch_size = loaders_config.get('batch_size', 1)
    logger.info(f'Batch size for train/val loader: {batch_size}')

    # when training with volumetric data use batch_size of 1 due to GPU memory constraints
    return {
        'train':
        DataLoader(ConcatDataset(train_datasets),
                   batch_size=batch_size,
                   shuffle=True,
                   num_workers=num_workers),
        'val':
        DataLoader(ConcatDataset(val_datasets),
                   batch_size=batch_size,
                   shuffle=True,
                   num_workers=num_workers)
    }
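
The cutout_pkl_file is iterated with df.iterrows() and indexed by 'id', 'phase' and 'cutout_bounds', so it is presumably a pickled pandas DataFrame. A sketch of how such a file could be produced; the ids and bounds are made up.

import pickle
import pandas as pd

df = pd.DataFrame([
    {'id': 1, 'phase': 'train', 'cutout_bounds': (0, 128, 0, 128, 0, 64)},
    {'id': 2, 'phase': 'val',   'cutout_bounds': (128, 256, 0, 128, 0, 64)},
])
with open('cutouts.pkl', 'wb') as f:
    pickle.dump(df, f)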
Example #25
def main():
    logger = get_logger('UNet3DTrainer')
    # Get device to train on
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    config = parse_train_config()

    logger.info(config)

    # Create loss criterion
    if config.loss_weight is not None:
        loss_weight = torch.tensor(config.loss_weight)
        loss_weight = loss_weight.to(device)
    else:
        loss_weight = None

    loss_criterion = get_loss_criterion(config.loss, loss_weight,
                                        config.ignore_index)

    model = UNet3D(config.in_channels,
                   config.out_channels,
                   init_channel_number=config.init_channel_number,
                   conv_layer_order=config.layer_order,
                   interpolate=config.interpolate,
                   final_sigmoid=config.final_sigmoid)

    model = model.to(device)

    # Log the number of learnable parameters
    logger.info(
        f'Number of learnable params {get_number_of_learnable_parameters(model)}'
    )

    # Create evaluation metric
    eval_criterion = get_evaluation_metric(config.eval_metric,
                                           ignore_index=config.ignore_index)

    # Get data loaders. If 'bce' or 'dice' loss is used, convert labels to float
    train_path, val_path = config.train_path, config.val_path
    if config.loss in ['bce']:
        label_dtype = 'float32'
    else:
        label_dtype = 'long'

    train_patch = tuple(config.train_patch)
    train_stride = tuple(config.train_stride)
    val_patch = tuple(config.val_patch)
    val_stride = tuple(config.val_stride)

    logger.info(f'Train patch/stride: {train_patch}/{train_stride}')
    logger.info(f'Val patch/stride: {val_patch}/{val_stride}')

    pixel_wise_weight = config.loss == 'pce'
    loaders = get_loaders(train_path,
                          val_path,
                          label_dtype=label_dtype,
                          raw_internal_path=config.raw_internal_path,
                          label_internal_path=config.label_internal_path,
                          train_patch=train_patch,
                          train_stride=train_stride,
                          val_patch=val_patch,
                          val_stride=val_stride,
                          transformer=config.transformer,
                          pixel_wise_weight=pixel_wise_weight,
                          curriculum_learning=config.curriculum,
                          ignore_index=config.ignore_index)

    # Create the optimizer
    optimizer = _create_optimizer(config, model)

    if config.resume:
        trainer = UNet3DTrainer.from_checkpoint(config.resume,
                                                model,
                                                optimizer,
                                                loss_criterion,
                                                eval_criterion,
                                                loaders,
                                                logger=logger)
    else:
        trainer = UNet3DTrainer(
            model,
            optimizer,
            loss_criterion,
            eval_criterion,
            device,
            loaders,
            config.checkpoint_dir,
            max_num_epochs=config.epochs,
            max_num_iterations=config.iters,
            max_patience=config.patience,
            validate_after_iters=config.validate_after_iters,
            log_after_iters=config.log_after_iters,
            logger=logger)

    trainer.fit()
Example #26
def main():
    parser = _arg_parser()
    logger = get_logger('UNet3DTrainer')
    # Get device to train on
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    args = parser.parse_args()

    logger.info(args)

    # Create loss criterion
    if args.loss_weight is not None:
        loss_weight = torch.tensor(args.loss_weight)
        loss_weight = loss_weight.to(device)
    else:
        loss_weight = None

    loss_criterion, final_sigmoid = _get_loss_criterion(args.loss, loss_weight)

    model = _create_model(args.in_channels, args.out_channels,
                          layer_order=args.layer_order,
                          interpolate=args.interpolate,
                          final_sigmoid=final_sigmoid)

    model = model.to(device)

    # Log the number of learnable parameters
    #logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

    # Create accuracy metric
    accuracy_criterion = _get_accuracy_criterion(not final_sigmoid)

    # Get data loaders. If 'bce' or 'dice' loss is used, convert labels to float
    train_path, val_path = args.train_path, args.val_path
    if args.loss in ['bce', 'dice']:
        label_dtype = 'float32'
    else:
        label_dtype = 'long'

    train_patch = tuple(args.train_patch)
    train_stride = tuple(args.train_stride)
    val_patch = tuple(args.val_patch)
    val_stride = tuple(args.val_stride)

    #logger.info(f'Train patch/stride: {train_patch}/{train_stride}')
    #logger.info(f'Val patch/stride: {val_patch}/{val_stride}')

    loaders = _get_loaders(train_path, val_path, label_dtype=label_dtype, train_patch=train_patch,
                           train_stride=train_stride, val_patch=val_patch, val_stride=val_stride)

    # Create the optimizer
    optimizer = _create_optimizer(args, model)

    if args.resume:
        trainer = UNet3DTrainer.from_checkpoint(args.resume, model,
                                                optimizer, loss_criterion,
                                                accuracy_criterion, loaders,
                                                logger=logger)
    else:
        trainer = UNet3DTrainer(model, optimizer, loss_criterion,
                                accuracy_criterion,
                                device, loaders, args.checkpoint_dir,
                                max_num_epochs=args.epochs,
                                max_num_iterations=args.iters,
                                max_patience=args.patience,
                                validate_after_iters=args.validate_after_iters,
                                log_after_iters=args.log_after_iters,
                                logger=logger)

    trainer.fit()
Exemplo n.º 27
0
# Authors:
# Christian F. Baumgartner ([email protected])
# Lisa M. Koch ([email protected])
# Source: https://git.ee.ethz.ch/baumgach/discriminative_learning_toolbox/blob/master/utils.py

import nibabel as nib
import numpy as np
import os
import logging
from unet3d.utils import get_logger
from skimage import measure, transform
logger = get_logger('utils')

try:
    import cv2
except ImportError:
    logger.warning(
        'Could not import opencv. Augmentation functions will be unavailable.')
else:

    def rotate_image(img, angle, interp=cv2.INTER_LINEAR):
        # rotate the image about its centre by angle degrees, keeping the
        # original shape and replicating border pixels
        rows, cols = img.shape[:2]
        rotation_matrix = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
        out = cv2.warpAffine(img,
                             rotation_matrix, (cols, rows),
                             flags=interp,
                             borderMode=cv2.BORDER_REPLICATE)
        return np.reshape(out, img.shape)

    def rotate_image_as_onehot(img, angle, nlabels, interp=cv2.INTER_LINEAR):
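        # A minimal sketch of a possible body (an assumption, not taken from the
        # original source): one-hot encode the integer label map, rotate each
        # channel with rotate_image() above, then map back to labels via argmax.
        onehot = np.eye(nlabels, dtype=np.float32)[img.astype(np.int64)]
        rotated = np.stack(
            [rotate_image(onehot[..., c], angle, interp=interp) for c in range(nlabels)],
            axis=-1)
        return np.argmax(rotated, axis=-1).astype(img.dtype)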
Exemplo n.º 28
0
def get_train_loaders(config):
    """
    Returns a dictionary containing the training and validation loaders
    (torch.utils.data.DataLoader) backed by the datasets.hdf5.HDF5Dataset
    (see the usage sketch after this function).

    :param config: a top level configuration object containing the 'loaders' key
    :return: dict {
        'train': <train_loader>
        'val': <val_loader>
    }
    """
    assert 'loaders' in config, 'Could not find data loaders configuration'
    loaders_config = config['loaders']

    logger = get_logger('HDF5Dataset')
    logger.info('Creating training and validation set loaders...')

    # get train and validation files
    train_paths = loaders_config['train_path']
    val_paths = loaders_config['val_path']
    assert isinstance(train_paths, list)
    assert isinstance(val_paths, list)
    # get h5 internal paths for raw and label
    raw_internal_path = loaders_config['raw_internal_path']
    label_internal_path = loaders_config['label_internal_path']
    weight_internal_path = loaders_config.get('weight_internal_path', None)
    # get train/validation patch size and stride
    train_patch = tuple(loaders_config['train_patch'])
    train_stride = tuple(loaders_config['train_stride'])
    val_patch = tuple(loaders_config['val_patch'])
    val_stride = tuple(loaders_config['val_stride'])

    # get slice_builder_cls
    slice_builder_str = loaders_config.get('slice_builder', 'SliceBuilder')
    logger.info(f'Slice builder class: {slice_builder_str}')
    slice_builder_cls = _get_slice_builder_cls(slice_builder_str)

    train_datasets = []
    for train_path in train_paths:
        try:
            logger.info(f'Loading training set from: {train_path}...')
            # create H5 backed training and validation dataset with data augmentation
            train_dataset = HDF5Dataset(
                train_path,
                train_patch,
                train_stride,
                phase='train',
                transformer_config=loaders_config['transformer'],
                raw_internal_path=raw_internal_path,
                label_internal_path=label_internal_path,
                weight_internal_path=weight_internal_path,
                slice_builder_cls=slice_builder_cls)
            train_datasets.append(train_dataset)
        except Exception:
            logger.info(f'Skipping training set: {train_path}', exc_info=True)

    val_datasets = []
    for val_path in val_paths:
        try:
            logger.info(f'Loading validation set from: {val_path}...')
            val_dataset = HDF5Dataset(
                val_path,
                val_patch,
                val_stride,
                phase='val',
                transformer_config=loaders_config['transformer'],
                raw_internal_path=raw_internal_path,
                label_internal_path=label_internal_path,
                weight_internal_path=weight_internal_path)
            val_datasets.append(val_dataset)
        except Exception:
            logger.info(f'Skipping validation set: {val_path}', exc_info=True)

    num_workers = loaders_config.get('num_workers', 1)
    logger.info(f'Number of workers for train/val datasets: {num_workers}')
    # when training with volumetric data use batch_size of 1 due to GPU memory constraints
    return {
        'train':
        DataLoader(ConcatDataset(train_datasets),
                   batch_size=1,
                   shuffle=True,
                   num_workers=num_workers),
        'val':
        DataLoader(ConcatDataset(val_datasets),
                   batch_size=1,
                   shuffle=True,
                   num_workers=num_workers)
    }
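A minimal usage sketch of how the returned dictionary is consumed (a sketch under assumptions, not part of the original snippet: it presumes that in the 'train' phase each batch is a (raw, label) pair and that model, loss_criterion and config exist as in the earlier examples):

loaders = get_train_loaders(config)
for raw, label in loaders['train']:
    # batch_size=1, so the leading dimension of every tensor is 1
    prediction = model(raw)
    loss = loss_criterion(prediction, label)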
Exemplo n.º 29
0
import os

import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau

# from pytorch3dunet.unet3d.utils import get_logger
from unet3d.utils import get_logger
from . import utils

logger = get_logger('UNet3DTrainer')


class UNet3DTrainer:
    """3D UNet trainer.

    Args:
        model (Unet3D): UNet 3D model to be trained
        optimizer (nn.optim.Optimizer): optimizer used for training
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler): learning rate scheduler
            WARN: bear in mind that lr_scheduler.step() is invoked after every validation step
            (i.e. every validate_after_iters iterations), not after every epoch. So, e.g., if one uses
            StepLR with step_size=30, the learning rate will be adjusted after every
            30 * validate_after_iters iterations (see the short sketch after this snippet).
        loss_criterion (callable): loss function
        eval_criterion (callable): used to compute the training/validation metric (such as Dice, IoU, AP or Rand score);
            saving the best checkpoint is based on the result of this function on the validation set
        device (torch.device): device to train on
        loaders (dict): 'train' and 'val' loaders
        checkpoint_dir (string): dir for saving checkpoints and tensorboard logs
        max_num_epochs (int): maximum number of epochs
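A short sketch making the WARN about lr_scheduler concrete (the numbers are hypothetical and optimizer is assumed to exist; none of this is taken from the snippet above):

from torch.optim.lr_scheduler import StepLR

validate_after_iters = 100                      # hypothetical: scheduler.step() runs once per validation
lr_scheduler = StepLR(optimizer, step_size=30)
# the learning rate therefore drops after 30 validations,
# i.e. 30 * validate_after_iters = 3000 training iterations, not after 30 epochs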
Exemplo n.º 30
0
import collections
import importlib

import numpy as np
import torch
from torch.utils.data import DataLoader, ConcatDataset, Dataset

# from pytorch3dunet.unet3d.utils import get_logger
# from pytorch3dunet.datasets import lung
from unet3d.utils import get_logger
from datasets import lung

import monai.transforms as mn_tf

logger = get_logger('Dataset')


class ConfigDataset(Dataset):
    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    @classmethod
    def create_datasets(cls, dataset_config, phase):
        """
        Factory method for creating a list of datasets based on the provided config.

        Args:
            dataset_config (dict): dataset configuration