def _train_save_load(self, tmpdir, loss, val_metric, model='UNet3D', max_num_epochs=1,
                     log_after_iters=2, validate_after_iters=2, max_num_iterations=4,
                     weight_map=False):
    """Run a short end-to-end training on random data, checkpoint, then reload the trainer.

    Builds a full training setup (model, loss, metric, loaders, optimizer, scheduler)
    from a deep copy of CONFIG_BASE, trains for at most `max_num_iterations` iterations,
    and finally reconstructs a trainer from the saved 'last_checkpoint.pytorch'.

    :param tmpdir: directory used for checkpoints and tensorboard logs
    :param loss: name of the loss criterion (e.g. 'BCEWithLogitsLoss', 'DiceLoss')
    :param val_metric: name of the evaluation metric
    :param model: model name to instantiate (string); NOTE: this parameter is later
        rebound to the instantiated network object below
    :param max_num_epochs: epoch cap for the short training run
    :param log_after_iters: log frequency (iterations)
    :param validate_after_iters: validation frequency (iterations)
    :param max_num_iterations: hard iteration cap so the test stays fast
    :param weight_map: if True, configure loaders to read a per-pixel weight map
    :return: trainer instance restored from the last checkpoint
    """
    # losses that expect sigmoid-normalized (binary/per-channel) outputs
    binary_loss = loss in ['BCEWithLogitsLoss', 'DiceLoss', 'GeneralizedDiceLoss']

    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

    # start from a deep copy so the shared base config is never mutated
    test_config = copy.deepcopy(CONFIG_BASE)
    test_config['model']['name'] = model
    test_config.update({
        # get device to train on
        'device': device,
        'loss': {'name': loss, 'weight': np.random.rand(2).astype(np.float32), 'pos_weight': 3.},
        'eval_metric': {'name': val_metric}
    })
    # binary losses require a final sigmoid instead of softmax
    test_config['model']['final_sigmoid'] = binary_loss

    if weight_map:
        test_config['loaders']['weight_internal_path'] = 'weight_map'

    loss_criterion = get_loss_criterion(test_config)
    eval_criterion = get_evaluation_metric(test_config)
    model = get_model(test_config)
    model = model.to(device)

    # BCEWithLogitsLoss needs float targets; the other criteria expect integer class labels
    if loss in ['BCEWithLogitsLoss']:
        label_dtype = 'float32'
    else:
        label_dtype = 'long'
    test_config['loaders']['train']['transformer']['label'][0]['dtype'] = label_dtype
    test_config['loaders']['val']['transformer']['label'][0]['dtype'] = label_dtype

    # synthesize small random train/val H5 datasets and point the loaders at them
    train, val = TestUNet3DTrainer._create_random_dataset((3, 128, 128, 128), (3, 64, 64, 64), binary_loss)
    test_config['loaders']['train']['file_paths'] = [train]
    test_config['loaders']['val']['file_paths'] = [val]

    loaders = get_train_loaders(test_config)

    optimizer = _create_optimizer(test_config, model)

    test_config['lr_scheduler']['name'] = 'MultiStepLR'
    lr_scheduler = _create_lr_scheduler(test_config, optimizer)

    # NOTE(review): the return value is unused; presumably kept for the side effect of
    # configuring the 'UNet3DTrainer' logger at DEBUG level — confirm before removing
    logger = get_logger('UNet3DTrainer', logging.DEBUG)

    formatter = DefaultTensorboardFormatter()
    trainer = UNet3DTrainer(model, optimizer, lr_scheduler, loss_criterion, eval_criterion,
                            device, loaders, tmpdir,
                            max_num_epochs=max_num_epochs,
                            log_after_iters=log_after_iters,
                            validate_after_iters=validate_after_iters,
                            max_num_iterations=max_num_iterations,
                            tensorboard_formatter=formatter)
    trainer.fit()
    # test loading the trainer from the checkpoint
    trainer = UNet3DTrainer.from_checkpoint(os.path.join(tmpdir, 'last_checkpoint.pytorch'),
                                            model, optimizer, lr_scheduler,
                                            loss_criterion, eval_criterion, loaders,
                                            tensorboard_formatter=formatter)
    return trainer
def __call__(self):
    """Run CNN prediction over all configured test datasets.

    If `self.state` is falsy the step is disabled and the unmodified input
    `self.paths` is returned. Otherwise each test loader's dataset is predicted
    into its own H5 output file and the list of output paths is returned.

    :return: `self.paths` when disabled, otherwise list of output H5 file paths
    """
    logger = utils.get_logger('UNet3DPredictor')

    if not self.state:
        # skip network predictions and return input_paths
        gui_logger.info(f"Skipping '{self.__class__.__name__}'. Disabled by the user.")
        return self.paths
    else:
        # create config/download models only when cnn_prediction enabled
        config = create_predict_config(self.paths, self.cnn_config)

        # Create the model
        model = get_model(config)

        # Load model state
        model_path = config['model_path']
        model_name = config["model_name"]
        logger.info(f"Loading model '{model_name}' from {model_path}")
        utils.load_checkpoint(model_path, model)

        logger.info(f"Sending the model to '{config['device']}'")
        model = model.to(config['device'])

        logger.info('Loading HDF5 datasets...')

        # Run prediction
        output_paths = []
        for test_loader in get_test_loaders(config):
            gui_logger.info(f"Running network prediction on {test_loader.dataset.file_path}...")
            # wall-clock timing of a single dataset prediction
            runtime = time.time()

            logger.info(f"Processing '{test_loader.dataset.file_path}'...")
            output_file = _get_output_file(test_loader.dataset, model_name)

            predictor = _get_predictor(model, test_loader, output_file, config)
            # run the model prediction on the entire dataset and save to the 'output_file' H5
            predictor.predict()

            # save resulting output path
            output_paths.append(output_file)

            runtime = time.time() - runtime
            gui_logger.info(f"Network prediction took {runtime:.2f} s")

        # propagate voxel-size metadata from the inputs to the prediction outputs
        self._update_voxel_size(self.paths, output_paths)
        # free GPU memory after the inference is finished
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return output_paths
def main():
    """Entry point for training: build every component from the YAML config and run the fit loop.

    Loads the experiment configuration via ``load_config()``, optionally seeds the
    RNGs for reproducibility, instantiates model/loss/metric/loaders/optimizer/scheduler,
    and starts ``trainer.fit()``.
    """
    # Create main logger
    logger = get_logger('UNet3DTrainer')

    # Load and log experiment configuration
    config = load_config()
    logger.info(config)

    manual_seed = config.get('manual_seed', None)
    if manual_seed is not None:
        logger.info(f'Seed the RNG for all devices with {manual_seed}')
        # Seed Python's stdlib RNG as well, for consistency with the newer
        # training entry point which seeds both `random` and `torch`.
        import random
        random.seed(manual_seed)
        torch.manual_seed(manual_seed)
        # see https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Create the model
    model = get_model(config)
    # put the model on GPUs
    logger.info(f"Sending the model to '{config['device']}'")
    model = model.to(config['device'])

    # Log the number of learnable parameters
    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

    # Create loss criterion
    loss_criterion = get_loss_criterion(config)
    # Create evaluation metric
    eval_criterion = get_evaluation_metric(config)

    # Create data loaders
    loaders = get_train_loaders(config)

    # Create the optimizer
    optimizer = _create_optimizer(config, model)

    # Create learning rate adjustment strategy
    lr_scheduler = _create_lr_scheduler(config, optimizer)

    # Create model trainer
    trainer = _create_trainer(config, model=model, optimizer=optimizer,
                              lr_scheduler=lr_scheduler, loss_criterion=loss_criterion,
                              eval_criterion=eval_criterion, loaders=loaders, logger=logger)
    # Start training
    trainer.fit()
import glob import os from itertools import chain import h5py import numpy as np import pytorch3dunet.augment.transforms as transforms from pytorch3dunet.datasets.utils import get_slice_builder, ConfigDataset, calculate_stats from pytorch3dunet.unet3d.utils import get_logger logger = get_logger('HDF5Dataset') class AbstractHDF5Dataset(ConfigDataset): """ Implementation of torch.utils.data.Dataset backed by the HDF5 files, which iterates over the raw and label datasets patch by patch with a given stride. """ def __init__(self, file_path, phase, slice_builder_config, transformer_config, mirror_padding=(16, 32, 32), raw_internal_path='raw', label_internal_path='label', weight_internal_path=None, global_normalization=True): """ :param file_path: path to H5 file containing raw data as well as labels and per pixel weights (optional)
import random

import torch

from pytorch3dunet.unet3d.config import load_config
from pytorch3dunet.unet3d.trainer import create_trainer
from pytorch3dunet.unet3d.utils import get_logger

logger = get_logger('TrainingSetup')


def main():
    """Load the experiment configuration, optionally seed the RNGs, and run training."""
    cfg = load_config()
    logger.info(cfg)

    seed = cfg.get('manual_seed', None)
    if seed is not None:
        logger.info(f'Seed the RNG for all devices with {seed}')
        logger.warning('Using CuDNN deterministic setting. This may slow down the training!')
        random.seed(seed)
        torch.manual_seed(seed)
        # see https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.deterministic = True

    # Build the trainer from the config and start the optimization loop.
    create_trainer(cfg).fit()
from skimage import feature from skimage import measure from skimage.filters import threshold_otsu from skimage.metrics import adapted_rand_error, peak_signal_noise_ratio, mean_squared_error from skimage.metrics import normalized_root_mse from skimage.segmentation import watershed from sklearn.cluster import MeanShift from scipy.spatial import distance from scipy import ndimage from pytorch3dunet.unet3d.losses import compute_per_channel_dice from pytorch3dunet.unet3d.seg_metrics import AveragePrecision, Accuracy from pytorch3dunet.unet3d.utils import get_logger, expand_as_one_hot, plot_segm, convert_to_numpy logger = get_logger('EvalMetric') class DiceCoefficient: """Computes Dice Coefficient. Generalized to multiple channels by computing per-channel Dice Score (as described in https://arxiv.org/pdf/1707.03237.pdf) and theTn simply taking the average. Input is expected to be probabilities instead of logits. This metric is mostly useful when channels contain the same semantic class (e.g. affinities computed with different offsets). DO NOT USE this metric when training with DiceLoss, otherwise the results will be biased towards the loss. """ def __init__(self, epsilon=1e-6, **kwargs): self.epsilon = epsilon def __call__(self, input, target):
import importlib import os import torch import torch.nn as nn from pytorch3dunet.datasets.utils import get_test_loaders from pytorch3dunet.unet3d import utils from pytorch3dunet.unet3d.config import load_config from pytorch3dunet.unet3d.model import get_model from argparse import ArgumentParser import yaml from pathlib import Path logger = utils.get_logger('UNet3DPredict') checkpointname = "checkpoint" predname = 'predictions' def load_config(runconfig, nworkers, device): runconfig = yaml.safe_load(open(runconfig, 'r')) train_config = Path(runconfig['runFolder']) / 'train_config.yml' test_config = Path(runconfig['runFolder']) / 'test_config.yml' config = yaml.safe_load(open(test_config, 'r')) train_config = yaml.safe_load(open(train_config, 'r')) dataFolder = Path(runconfig['dataFolder']) runFolder = Path(runconfig['runFolder'])
import time

import h5py
import hdbscan
import numpy as np
import torch
from sklearn.cluster import MeanShift

from pytorch3dunet.datasets.utils import SliceBuilder
from pytorch3dunet.unet3d.utils import get_logger
from pytorch3dunet.unet3d.utils import remove_halo

logger = get_logger('UNet3DPredictor')


class _AbstractPredictor:
    """Base class for predictors that run a model over a dataset and write results to a file."""

    def __init__(self, model, loader, output_file, config, **kwargs):
        """
        :param model: trained network used for inference
        :param loader: data loader providing the dataset patches
        :param output_file: destination file for the predictions
        :param config: prediction configuration dictionary
        :param kwargs: predictor-specific options, kept in `self.predictor_config`
        """
        self.model = model
        self.loader = loader
        self.output_file = output_file
        self.config = config
        # subclass-specific configuration passed through as keyword arguments
        self.predictor_config = kwargs

    @staticmethod
    def _volume_shape(dataset):
        """Return the spatial (D, H, W) shape of the dataset's first raw volume.

        For a 4D raw array the first axis is dropped — presumably the channel
        axis (TODO confirm axis convention against the dataset implementation).
        """
        # TODO: support multiple internal datasets
        raw = dataset.raws[0]
        if raw.ndim == 3:
            return raw.shape
        else:
            return raw.shape[1:]
import os import torch import torch.nn as nn from tensorboardX import SummaryWriter from torch.optim.lr_scheduler import ReduceLROnPlateau from pytorch3dunet.unet3d.utils import get_logger from . import utils logger = get_logger('UNet3DTrainer') class UNet3DTrainer: """3D UNet trainer. Args: model (Unet3D): UNet 3D model to be trained optimizer (nn.optim.Optimizer): optimizer used for training lr_scheduler (torch.optim.lr_scheduler._LRScheduler): learning rate scheduler WARN: bear in mind that lr_scheduler.step() is invoked after every validation step (i.e. validate_after_iters) not after every epoch. So e.g. if one uses StepLR with step_size=30 the learning rate will be adjusted after every 30 * validate_after_iters iterations. loss_criterion (callable): loss function eval_criterion (callable): used to compute training/validation metric (such as Dice, IoU, AP or Rand score) saving the best checkpoint is based on the result of this function on the validation set device (torch.device): device to train on loaders (dict): 'train' and 'val' loaders checkpoint_dir (string): dir for saving checkpoints and tensorboard logs max_num_epochs (int): maximum number of epochs max_num_iterations (int): maximum number of iterations
import os import imageio import numpy as np from pytorch3dunet.augment import transforms from pytorch3dunet.datasets.utils import ConfigDataset, calculate_stats from pytorch3dunet.unet3d.utils import get_logger logger = get_logger('DSB2018Dataset') class DSB2018Dataset(ConfigDataset): def __init__(self, root_dir, phase, transformer_config, mirror_padding=(0, 32, 32), expand_dims=True): assert os.path.isdir(root_dir), 'root_dir is not a directory' assert phase in ['train', 'val', 'test'] # use mirror padding only during the 'test' phase if phase in ['train', 'val']: mirror_padding = None if mirror_padding is not None: assert len(mirror_padding) == 3, f"Invalid mirror_padding: {mirror_padding}" self.mirror_padding = mirror_padding self.phase = phase # load raw images images_dir = os.path.join(root_dir, 'images') assert os.path.isdir(images_dir) self.images, self.paths = self._load_files(images_dir, expand_dims) self.file_path = images_dir
import collections import importlib import numpy as np import torch from torch.utils.data import DataLoader, ConcatDataset, Dataset from pytorch3dunet.unet3d.utils import get_logger logger = get_logger('Dataset') class ConfigDataset(Dataset): def __getitem__(self, index): raise NotImplementedError def __len__(self): raise NotImplementedError @classmethod def create_datasets(cls, dataset_config, phase): """ Factory method for creating a list of datasets based on the provided config. Args: dataset_config (dict): dataset configuration phase (str): one of ['train', 'val', 'test'] Returns: list of `Dataset` instances """
import numpy as np import SimpleITK as sitk import h5py from pytorch3dunet.augment import transforms from pytorch3dunet.datasets.utils import ConfigDataset, calculate_stats from pytorch3dunet.unet3d.utils import get_logger, expand_as_one_hot from pytorch3dunet.datasets.hdf5 import AbstractHDF5Dataset from pytorch3dunet.datasets.visualize import visualizer from scipy import ndimage import re import yaml import ipdb import torch logger = get_logger('SXTHDataset') class SXTHDataset(AbstractHDF5Dataset): def __init__(self, file_path, phase, slice_builder_config, transformer_config, mirror_padding=(16, 32, 32), raw_internal_path='raw', label_internal_path='label', weight_internal_path=None): super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config, transformer_config=transformer_config, mirror_padding=mirror_padding, raw_internal_path=raw_internal_path, label_internal_path=label_internal_path, weight_internal_path=weight_internal_path) @classmethod
import importlib import os import time import hdbscan import numpy as np import torch import torch.nn.functional as F from skimage import measure from sklearn.cluster import MeanShift from pytorch3dunet.unet3d.losses import compute_per_channel_dice from pytorch3dunet.unet3d.utils import get_logger, adapted_rand, expand_as_one_hot, plot_segm LOGGER = get_logger('EvalMetric') class DiceCoefficient: """Computes Dice Coefficient. Generalized to multiple channels by computing per-channel Dice Score (as described in https://arxiv.org/pdf/1707.03237.pdf) and theTn simply taking the average. Input is expected to be probabilities instead of logits. This metric is mostly useful when channels contain the same semantic class (e.g. affinities computed with different offsets). DO NOT USE this metric when training with DiceLoss, otherwise the results will be biased towards the loss. """ def __init__(self, epsilon=1e-5, ignore_index=None, **kwargs): self.epsilon = epsilon self.ignore_index = ignore_index def __call__(self, input, target):
import os import tempfile import h5py import numpy as np from pytorch3dunet.datasets.utils import ConfigDataset from pytorch3dunet.unet3d.utils import get_logger logger = get_logger('EGFPDataset') class EGFPDataset(ConfigDataset): def __init__(self, file_path, internal_path, z_slice_count, target_slice_index, **kwargs): with h5py.File(file_path, 'r') as f: self.raw = f[internal_path][...] self.z_slice_count = z_slice_count self.target_slice_index = target_slice_index assert self.raw.ndim == 4 # assumes ZYXC axis order assert self.raw.shape[0] >= z_slice_count assert target_slice_index < z_slice_count def __getitem__(self, index): if self.target_slice_index == 0: # 1st z-slice is the target # return rank 4 tensor always: ZYXC
import argparse import torch import yaml from pytorch3dunet.unet3d import utils logger = utils.get_logger('ConfigLoader') def load_config(): parser = argparse.ArgumentParser(description='UNet3D') parser.add_argument('--config', type=str, help='Path to the YAML config file') args = parser.parse_args() # args.config = '../resources/train_config_ce.yaml' config = _load_config_yaml(args.config) # Get a device to train on device_str = config.get('device', None) if device_str is not None: logger.info(f"Device specified in config: '{device_str}'") if device_str.startswith('cuda') and not torch.cuda.is_available(): logger.warn('CUDA not available, using CPU') device_str = 'cpu' else: device_str = "cuda:0" if torch.cuda.is_available() else 'cpu' logger.info(f"Using '{device_str}' device") device = torch.device(device_str) config['device'] = device
import imageio import numpy as np import nibabel as nib import h5py from pytorch3dunet.augment import transforms from pytorch3dunet.datasets.utils import ConfigDataset, calculate_stats from pytorch3dunet.unet3d.utils import get_logger, expand_as_one_hot from pytorch3dunet.datasets.hdf5 import AbstractHDF5Dataset from pytorch3dunet.datasets.visualize import visualizer from scipy import ndimage import yaml import ipdb import torch logger = get_logger('Kits19Dataset') class Kits19Dataset(AbstractHDF5Dataset): def __init__(self, file_path, phase, slice_builder_config, transformer_config, mirror_padding=(16, 32, 32), raw_internal_path='raw', label_internal_path='label', weight_internal_path=None): super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config,