Example #1
    def _train_save_load(self, tmpdir, loss, val_metric, model='UNet3D', max_num_epochs=1, log_after_iters=2,
                         validate_after_iters=2, max_num_iterations=4, weight_map=False):
        binary_loss = loss in ['BCEWithLogitsLoss', 'DiceLoss', 'GeneralizedDiceLoss']

        device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

        test_config = copy.deepcopy(CONFIG_BASE)
        test_config['model']['name'] = model
        test_config.update({
            # get device to train on
            'device': device,
            'loss': {'name': loss, 'weight': np.random.rand(2).astype(np.float32), 'pos_weight': 3.},
            'eval_metric': {'name': val_metric}
        })
        test_config['model']['final_sigmoid'] = binary_loss

        if weight_map:
            test_config['loaders']['weight_internal_path'] = 'weight_map'

        loss_criterion = get_loss_criterion(test_config)
        eval_criterion = get_evaluation_metric(test_config)
        model = get_model(test_config)
        model = model.to(device)

        # BCEWithLogitsLoss expects float targets; the remaining losses expect class indices
        label_dtype = 'float32' if loss == 'BCEWithLogitsLoss' else 'long'
        test_config['loaders']['train']['transformer']['label'][0]['dtype'] = label_dtype
        test_config['loaders']['val']['transformer']['label'][0]['dtype'] = label_dtype

        train, val = TestUNet3DTrainer._create_random_dataset((3, 128, 128, 128), (3, 64, 64, 64), binary_loss)
        test_config['loaders']['train']['file_paths'] = [train]
        test_config['loaders']['val']['file_paths'] = [val]

        loaders = get_train_loaders(test_config)

        optimizer = _create_optimizer(test_config, model)

        test_config['lr_scheduler']['name'] = 'MultiStepLR'
        lr_scheduler = _create_lr_scheduler(test_config, optimizer)

        logger = get_logger('UNet3DTrainer', logging.DEBUG)

        formatter = DefaultTensorboardFormatter()
        trainer = UNet3DTrainer(model, optimizer, lr_scheduler,
                                loss_criterion, eval_criterion,
                                device, loaders, tmpdir,
                                max_num_epochs=max_num_epochs,
                                log_after_iters=log_after_iters,
                                validate_after_iters=validate_after_iters,
                                max_num_iterations=max_num_iterations,
                                tensorboard_formatter=formatter)
        trainer.fit()
        # test loading the trainer from the checkpoint
        trainer = UNet3DTrainer.from_checkpoint(os.path.join(tmpdir, 'last_checkpoint.pytorch'),
                                                model, optimizer, lr_scheduler,
                                                loss_criterion, eval_criterion,
                                                loaders, tensorboard_formatter=formatter)
        return trainer
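
For reference, a minimal sketch of a test method that drives this helper via pytest's tmpdir fixture; the loss/metric pairs are illustrative, not taken from the original suite:

import pytest

class TestUNet3DTrainer:
    # ... _train_save_load and _create_random_dataset as defined above ...

    @pytest.mark.parametrize('loss, val_metric', [
        ('BCEWithLogitsLoss', 'DiceCoefficient'),
        ('CrossEntropyLoss', 'MeanIoU'),
    ])
    def test_train_save_load(self, tmpdir, loss, val_metric):
        # trains for a few iterations, then reloads from last_checkpoint.pytorch
        trainer = self._train_save_load(tmpdir, loss, val_metric)
        assert trainer is not None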
Example #2
    def __call__(self):
        logger = utils.get_logger('UNet3DPredictor')

        if not self.state:
            # skip network predictions and return input_paths
            gui_logger.info(
                f"Skipping '{self.__class__.__name__}'. Disabled by the user.")
            return self.paths
        else:
            # create config/download models only when cnn_prediction enabled
            config = create_predict_config(self.paths, self.cnn_config)

            # Create the model
            model = get_model(config)

            # Load model state
            model_path = config['model_path']
            model_name = config["model_name"]

            logger.info(f"Loading model '{model_name}' from {model_path}")
            utils.load_checkpoint(model_path, model)
            logger.info(f"Sending the model to '{config['device']}'")
            model = model.to(config['device'])

            logger.info('Loading HDF5 datasets...')

            # Run prediction
            output_paths = []
            for test_loader in get_test_loaders(config):
                gui_logger.info(
                    f"Running network prediction on {test_loader.dataset.file_path}..."
                )
                runtime = time.time()

                logger.info(f"Processing '{test_loader.dataset.file_path}'...")

                output_file = _get_output_file(test_loader.dataset, model_name)

                predictor = _get_predictor(model, test_loader, output_file,
                                           config)

                # run the model prediction on the entire dataset and save to the 'output_file' H5
                predictor.predict()

                # save resulting output path
                output_paths.append(output_file)

                runtime = time.time() - runtime
                gui_logger.info(f"Network prediction took {runtime:.2f} s")

            self._update_voxel_size(self.paths, output_paths)

            # free GPU memory after the inference is finished
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            return output_paths
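
_get_output_file is not shown in this snippet; a plausible sketch, assuming the output H5 path is derived from the input file and the model name (the real helper may differ):

import os

def _get_output_file(dataset, model_name, suffix='_predictions'):
    # e.g. /data/foo.h5 with model 'unet' -> /data/foo_unet_predictions.h5
    root, _ = os.path.splitext(dataset.file_path)
    return f'{root}_{model_name}{suffix}.h5'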
Example #3
def main():
    # Create main logger
    logger = get_logger('UNet3DTrainer')

    # Load and log experiment configuration
    config = load_config()
    logger.info(config)

    manual_seed = config.get('manual_seed', None)
    if manual_seed is not None:
        logger.info(f'Seed the RNG for all devices with {manual_seed}')
        torch.manual_seed(manual_seed)
        # see https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Create the model
    model = get_model(config)
    # put the model on GPUs
    logger.info(f"Sending the model to '{config['device']}'")
    model = model.to(config['device'])
    # Log the number of learnable parameters
    logger.info(
        f'Number of learnable params {get_number_of_learnable_parameters(model)}'
    )

    # Create loss criterion
    loss_criterion = get_loss_criterion(config)
    # Create evaluation metric
    eval_criterion = get_evaluation_metric(config)

    # Create data loaders
    loaders = get_train_loaders(config)

    # Create the optimizer
    optimizer = _create_optimizer(config, model)

    # Create learning rate adjustment strategy
    lr_scheduler = _create_lr_scheduler(config, optimizer)

    # Create model trainer
    trainer = _create_trainer(config,
                              model=model,
                              optimizer=optimizer,
                              lr_scheduler=lr_scheduler,
                              loss_criterion=loss_criterion,
                              eval_criterion=eval_criterion,
                              loaders=loaders,
                              logger=logger)
    # Start training
    trainer.fit()
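
For orientation, a hypothetical minimal config covering the keys main() reads, written as the Python dict load_config would return (the real YAML schema has many more options):

config = {
    'manual_seed': 0,
    'device': 'cuda:0',
    'model': {'name': 'UNet3D', 'in_channels': 1, 'out_channels': 2,
              'f_maps': 32, 'final_sigmoid': True},
    'loss': {'name': 'BCEWithLogitsLoss'},
    'eval_metric': {'name': 'DiceCoefficient'},
    'optimizer': {'learning_rate': 2e-4, 'weight_decay': 1e-5},
    'lr_scheduler': {'name': 'MultiStepLR', 'milestones': [10, 30], 'gamma': 0.2},
    # 'loaders' and 'trainer' sections omitted for brevity
}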
Example #4
import glob
import os
from itertools import chain

import h5py
import numpy as np

import pytorch3dunet.augment.transforms as transforms
from pytorch3dunet.datasets.utils import get_slice_builder, ConfigDataset, calculate_stats
from pytorch3dunet.unet3d.utils import get_logger

logger = get_logger('HDF5Dataset')


class AbstractHDF5Dataset(ConfigDataset):
    """
    Implementation of torch.utils.data.Dataset backed by HDF5 files, which iterates over the raw and label datasets
    patch by patch with a given stride.
    """
    def __init__(self,
                 file_path,
                 phase,
                 slice_builder_config,
                 transformer_config,
                 mirror_padding=(16, 32, 32),
                 raw_internal_path='raw',
                 label_internal_path='label',
                 weight_internal_path=None,
                 global_normalization=True):
        """
        :param file_path: path to H5 file containing raw data as well as labels and per pixel weights (optional)
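
A toy illustration of the patch-by-patch iteration with a stride that the docstring describes; this is not the repo's SliceBuilder, just the underlying idea:

import itertools

def grid_slices(shape, patch, stride):
    # yield tuple-of-slices windows covering a 3D volume on a regular grid
    ranges = [range(0, dim - p + 1, s) for dim, p, s in zip(shape, patch, stride)]
    for corner in itertools.product(*ranges):
        yield tuple(slice(c, c + p) for c, p in zip(corner, patch))

# a (64, 64, 64) volume, 32^3 patches, stride 16 -> 3 * 3 * 3 = 27 patches
assert sum(1 for _ in grid_slices((64,) * 3, (32,) * 3, (16,) * 3)) == 27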
Example #5
import random

import torch

from pytorch3dunet.unet3d.config import load_config
from pytorch3dunet.unet3d.trainer import create_trainer
from pytorch3dunet.unet3d.utils import get_logger

logger = get_logger('TrainingSetup')


def main():
    # Load and log experiment configuration
    config = load_config()
    logger.info(config)

    manual_seed = config.get('manual_seed', None)
    if manual_seed is not None:
        logger.info(f'Seed the RNG for all devices with {manual_seed}')
        logger.warning('Using CuDNN deterministic setting. This may slow down the training!')
        random.seed(manual_seed)
        torch.manual_seed(manual_seed)
        # see https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.deterministic = True

    # create trainer
    trainer = create_trainer(config)
    # Start training
    trainer.fit()

Example #6
from skimage import feature
from skimage import measure
from skimage.filters import threshold_otsu
from skimage.metrics import adapted_rand_error, peak_signal_noise_ratio, mean_squared_error
from skimage.metrics import normalized_root_mse
from skimage.segmentation import watershed

from sklearn.cluster import MeanShift
from scipy.spatial import distance
from scipy import ndimage

from pytorch3dunet.unet3d.losses import compute_per_channel_dice
from pytorch3dunet.unet3d.seg_metrics import AveragePrecision, Accuracy
from pytorch3dunet.unet3d.utils import get_logger, expand_as_one_hot, plot_segm, convert_to_numpy

logger = get_logger('EvalMetric')


class DiceCoefficient:
    """Computes Dice Coefficient.
    Generalized to multiple channels by computing per-channel Dice Score
    (as described in https://arxiv.org/pdf/1707.03237.pdf) and then simply taking the average.
    Input is expected to be probabilities instead of logits.
    This metric is mostly useful when channels contain the same semantic class (e.g. affinities computed with different offsets).
    DO NOT USE this metric when training with DiceLoss, otherwise the results will be biased towards the loss.
    """

    def __init__(self, epsilon=1e-6, **kwargs):
        self.epsilon = epsilon

    def __call__(self, input, target):
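
The __call__ body is truncated and relies on the imported compute_per_channel_dice; a self-contained sketch of the per-channel computation the docstring describes (the real implementation may differ in detail):

import torch

def per_channel_dice(probs, target, epsilon=1e-6):
    # (N, C, ...) -> (C, N * spatial), so each channel is scored independently
    assert probs.size() == target.size()
    probs = probs.transpose(0, 1).flatten(start_dim=1)
    target = target.transpose(0, 1).flatten(start_dim=1).float()
    intersect = (probs * target).sum(-1)
    denominator = (probs * probs).sum(-1) + (target * target).sum(-1)
    return 2 * intersect / denominator.clamp(min=epsilon)

# __call__ then plausibly reduces to the channel average:
# torch.mean(per_channel_dice(input, target, epsilon=self.epsilon))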
Example #7
import importlib
import os
import torch
import torch.nn as nn
from pytorch3dunet.datasets.utils import get_test_loaders
from pytorch3dunet.unet3d import utils
from pytorch3dunet.unet3d.config import load_config
from pytorch3dunet.unet3d.model import get_model
from argparse import ArgumentParser
import yaml
from pathlib import Path

logger = utils.get_logger('UNet3DPredict')

checkpointname = "checkpoint"

predname = 'predictions'


def load_config(runconfig, nworkers, device):
    with open(runconfig, 'r') as f:
        runconfig = yaml.safe_load(f)

    train_config = Path(runconfig['runFolder']) / 'train_config.yml'
    test_config = Path(runconfig['runFolder']) / 'test_config.yml'

    with open(test_config, 'r') as f:
        config = yaml.safe_load(f)
    with open(train_config, 'r') as f:
        train_config = yaml.safe_load(f)

    dataFolder = Path(runconfig['dataFolder'])
    runFolder = Path(runconfig['runFolder'])
Example #8
import time

import h5py
import hdbscan
import numpy as np
import torch
from sklearn.cluster import MeanShift

from pytorch3dunet.datasets.utils import SliceBuilder
from pytorch3dunet.unet3d.utils import get_logger
from pytorch3dunet.unet3d.utils import remove_halo

logger = get_logger('UNet3DPredictor')


class _AbstractPredictor:
    def __init__(self, model, loader, output_file, config, **kwargs):
        self.model = model
        self.loader = loader
        self.output_file = output_file
        self.config = config
        self.predictor_config = kwargs

    @staticmethod
    def _volume_shape(dataset):
        # TODO: support multiple internal datasets
        raw = dataset.raws[0]
        if raw.ndim == 3:
            # single-channel volume: (D, H, W)
            return raw.shape
        else:
            # multi-channel volume: (C, D, H, W) -> drop the channel axis
            return raw.shape[1:]
Example #9
import os

import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau

from pytorch3dunet.unet3d.utils import get_logger
from . import utils

logger = get_logger('UNet3DTrainer')


class UNet3DTrainer:
    """3D UNet trainer.

    Args:
        model (Unet3D): UNet 3D model to be trained
        optimizer (nn.optim.Optimizer): optimizer used for training
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler): learning rate scheduler
            WARN: bear in mind that lr_scheduler.step() is invoked after every validation step
            (i.e. every validate_after_iters iterations), not after every epoch. So if one uses e.g. StepLR with step_size=30,
            the learning rate will be adjusted after every 30 * validate_after_iters iterations.
        loss_criterion (callable): loss function
        eval_criterion (callable): used to compute training/validation metric (such as Dice, IoU, AP or Rand score)
            saving the best checkpoint is based on the result of this function on the validation set
        device (torch.device): device to train on
        loaders (dict): 'train' and 'val' loaders
        checkpoint_dir (string): dir for saving checkpoints and tensorboard logs
        max_num_epochs (int): maximum number of epochs
        max_num_iterations (int): maximum number of iterations
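
To make the WARN above concrete, a tiny sketch with hypothetical numbers:

validate_after_iters = 100
step_size = 30  # StepLR counts scheduler.step() calls, i.e. validation rounds
# the LR therefore drops only after step_size * validate_after_iters iterations
assert step_size * validate_after_iters == 3000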
Example #10
import os

import imageio
import numpy as np

from pytorch3dunet.augment import transforms
from pytorch3dunet.datasets.utils import ConfigDataset, calculate_stats
from pytorch3dunet.unet3d.utils import get_logger

logger = get_logger('DSB2018Dataset')


class DSB2018Dataset(ConfigDataset):
    def __init__(self, root_dir, phase, transformer_config, mirror_padding=(0, 32, 32), expand_dims=True):
        assert os.path.isdir(root_dir), 'root_dir is not a directory'
        assert phase in ['train', 'val', 'test']

        # use mirror padding only during the 'test' phase
        if phase in ['train', 'val']:
            mirror_padding = None
        if mirror_padding is not None:
            assert len(mirror_padding) == 3, f"Invalid mirror_padding: {mirror_padding}"
        self.mirror_padding = mirror_padding

        self.phase = phase

        # load raw images
        images_dir = os.path.join(root_dir, 'images')
        assert os.path.isdir(images_dir)
        self.images, self.paths = self._load_files(images_dir, expand_dims)
        self.file_path = images_dir
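
The asserts above imply an on-disk layout like the following; the 'masks' subfolder for the train/val phases is an assumption based on the DSB 2018 convention, not shown in this snippet:

# root_dir/
#   images/   raw 2D images, one file per sample
#   masks/    label masks matching the images (train/val only)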
Example #11
import collections
import importlib

import numpy as np
import torch
from torch.utils.data import DataLoader, ConcatDataset, Dataset

from pytorch3dunet.unet3d.utils import get_logger

logger = get_logger('Dataset')


class ConfigDataset(Dataset):
    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    @classmethod
    def create_datasets(cls, dataset_config, phase):
        """
        Factory method for creating a list of datasets based on the provided config.

        Args:
            dataset_config (dict): dataset configuration
            phase (str): one of ['train', 'val', 'test']

        Returns:
            list of `Dataset` instances
        """
Example #12
import numpy as np
import SimpleITK as sitk
import h5py
from pytorch3dunet.augment import transforms
from pytorch3dunet.datasets.utils import ConfigDataset, calculate_stats
from pytorch3dunet.unet3d.utils import get_logger, expand_as_one_hot 
from pytorch3dunet.datasets.hdf5 import AbstractHDF5Dataset  
from pytorch3dunet.datasets.visualize import visualizer
from scipy import ndimage

import re
import yaml
import ipdb
import torch

logger = get_logger('SXTHDataset')


class SXTHDataset(AbstractHDF5Dataset):
    def __init__(self, file_path, phase, slice_builder_config, transformer_config, mirror_padding=(16, 32, 32),
                 raw_internal_path='raw', label_internal_path='label', weight_internal_path=None):
        super().__init__(file_path=file_path,
                         phase=phase,
                         slice_builder_config=slice_builder_config,
                         transformer_config=transformer_config,
                         mirror_padding=mirror_padding,
                         raw_internal_path=raw_internal_path,
                         label_internal_path=label_internal_path,
                         weight_internal_path=weight_internal_path)
                
    @classmethod
Example #13
import importlib
import os
import time

import hdbscan
import numpy as np
import torch
import torch.nn.functional as F
from skimage import measure
from sklearn.cluster import MeanShift

from pytorch3dunet.unet3d.losses import compute_per_channel_dice
from pytorch3dunet.unet3d.utils import get_logger, adapted_rand, expand_as_one_hot, plot_segm

LOGGER = get_logger('EvalMetric')


class DiceCoefficient:
    """Computes Dice Coefficient.
    Generalized to multiple channels by computing per-channel Dice Score
    (as described in https://arxiv.org/pdf/1707.03237.pdf) and then simply taking the average.
    Input is expected to be probabilities instead of logits.
    This metric is mostly useful when channels contain the same semantic class (e.g. affinities computed with different offsets).
    DO NOT USE this metric when training with DiceLoss, otherwise the results will be biased towards the loss.
    """

    def __init__(self, epsilon=1e-5, ignore_index=None, **kwargs):
        self.epsilon = epsilon
        self.ignore_index = ignore_index

    def __call__(self, input, target):
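
Unlike the earlier variant, this one also accepts ignore_index; a sketch of how such a mask is commonly applied before the Dice computation (the truncated __call__ may handle it differently):

import torch

def apply_ignore_mask(input, target, ignore_index=None):
    # zero out positions labelled ignore_index so they contribute nothing
    if ignore_index is not None:
        mask = target.ne(ignore_index)
        input = input * mask
        target = target * mask
    return input, target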
Example #14
import os
import tempfile

import h5py
import numpy as np

from pytorch3dunet.datasets.utils import ConfigDataset
from pytorch3dunet.unet3d.utils import get_logger

logger = get_logger('EGFPDataset')


class EGFPDataset(ConfigDataset):
    def __init__(self, file_path, internal_path, z_slice_count,
                 target_slice_index, **kwargs):
        with h5py.File(file_path, 'r') as f:
            self.raw = f[internal_path][...]

        self.z_slice_count = z_slice_count
        self.target_slice_index = target_slice_index

        assert self.raw.ndim == 4
        # assumes ZYXC axis order
        assert self.raw.shape[0] >= z_slice_count

        assert target_slice_index < z_slice_count

    def __getitem__(self, index):
        if self.target_slice_index == 0:
            # 1st z-slice is the target
            # return rank 4 tensor always: ZYXC
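
The __getitem__ body is truncated; a sketch consistent with the comments above, assuming each item is a window of z_slice_count consecutive z-slices in which the slice at target_slice_index is the prediction target:

import numpy as np

def split_window(raw, index, z_slice_count, target_slice_index):
    window = raw[index:index + z_slice_count]               # (Z, Y, X, C)
    target = window[target_slice_index:target_slice_index + 1]
    inputs = np.delete(window, target_slice_index, axis=0)  # remaining slices
    return inputs, target                                   # both rank 4, ZYXC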
Example #15
import argparse

import torch
import yaml

from pytorch3dunet.unet3d import utils

logger = utils.get_logger('ConfigLoader')


def load_config():
    parser = argparse.ArgumentParser(description='UNet3D')
    parser.add_argument('--config',
                        type=str,
                        help='Path to the YAML config file')
    args = parser.parse_args()
    #    args.config = '../resources/train_config_ce.yaml'
    config = _load_config_yaml(args.config)
    # Get a device to train on
    device_str = config.get('device', None)
    if device_str is not None:
        logger.info(f"Device specified in config: '{device_str}'")
        if device_str.startswith('cuda') and not torch.cuda.is_available():
            logger.warning('CUDA not available, using CPU')
            device_str = 'cpu'
    else:
        device_str = "cuda:0" if torch.cuda.is_available() else 'cpu'
        logger.info(f"Using '{device_str}' device")

    device = torch.device(device_str)
    config['device'] = device
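
_load_config_yaml is referenced but not defined in this snippet; given the yaml import above, a plausible sketch:

def _load_config_yaml(config_file):
    with open(config_file, 'r') as f:
        return yaml.safe_load(f)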
Example #16
import imageio
import numpy as np
import nibabel as nib
import h5py
from pytorch3dunet.augment import transforms
from pytorch3dunet.datasets.utils import ConfigDataset, calculate_stats
from pytorch3dunet.unet3d.utils import get_logger, expand_as_one_hot
from pytorch3dunet.datasets.hdf5 import AbstractHDF5Dataset
from pytorch3dunet.datasets.visualize import visualizer
from scipy import ndimage

import yaml
import ipdb
import torch

logger = get_logger('Kits19Dataset')


class Kits19Dataset(AbstractHDF5Dataset):
    def __init__(self,
                 file_path,
                 phase,
                 slice_builder_config,
                 transformer_config,
                 mirror_padding=(16, 32, 32),
                 raw_internal_path='raw',
                 label_internal_path='label',
                 weight_internal_path=None):
        super().__init__(file_path=file_path,
                         phase=phase,
                         slice_builder_config=slice_builder_config,