def get_norm1d(name: str, num_features: int, affine: bool = True,
               params: Optional[dict] = None) -> nn.Module:
    """ get a 1d normalization module from pytorch
    must be one of: instance, batch, layer, none

    Args:
        name (str): name of normalization function desired
        num_features (int): number of channels in the normalization layer
        affine (bool): learn an affine transform after normalization
        params (dict): dictionary of optional other parameters for the
            normalization layer as specified by the pytorch documentation

    Returns:
        norm: instance of normalization layer
    """
    if name.lower() == 'instance':
        norm = nn.InstanceNorm1d(num_features, affine=affine) if params is None else \
               nn.InstanceNorm1d(num_features, affine=affine, **params)
    elif name.lower() == 'batch':
        norm = nn.BatchNorm1d(num_features, affine=affine) if params is None else \
               nn.BatchNorm1d(num_features, affine=affine, **params)
    elif name.lower() == 'layer':
        norm = nn.GroupNorm(1, num_features, affine=affine)
    elif name.lower() == 'none':
        norm = None
    else:
        raise SynthtorchError(f'Normalization: "{name}" not a valid normalization routine or not supported.')
    return norm
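# Minimal usage sketch (the feature count and params here are hypothetical):
#
#     >>> norm = get_norm1d('batch', num_features=32, params={'momentum': 0.01})
#     >>> isinstance(norm, nn.BatchNorm1d)
#     True
#
# Note that 'layer' is realized as GroupNorm with a single group, a common
# stand-in for layer normalization over channel-first inputs.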
def get_optim(name: str):
    """ get an optimizer by name """
    if name.lower() == 'adam':
        optimizer = torch.optim.Adam
    elif name.lower() == 'adamw':
        optimizer = torch.optim.AdamW
    elif name.lower() == 'sgd':
        optimizer = torch.optim.SGD
    elif name.lower() == 'sgdw':
        from ..learn.optim import SGDW
        optimizer = SGDW
    elif name.lower() == 'nsgd':
        from ..learn.optim import NesterovSGD
        optimizer = NesterovSGD
    elif name.lower() == 'nsgdw':
        from ..learn.optim import NesterovSGDW
        optimizer = NesterovSGDW
    elif name.lower() == 'rmsprop':
        optimizer = torch.optim.RMSprop
    elif name.lower() == 'adagrad':
        optimizer = torch.optim.Adagrad
    elif name.lower() == 'amsgrad':
        from ..learn.optim import AMSGrad
        optimizer = AMSGrad
    else:
        raise SynthtorchError(f'Optimizer: "{name}" not a valid optimizer routine or not supported.')
    return optimizer
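# Minimal usage sketch: get_optim returns the optimizer *class*, not an
# instance, so it is instantiated with the model parameters afterward
# (`model` and the learning rate here are hypothetical):
#
#     >>> opt_cls = get_optim('adamw')
#     >>> optimizer = opt_cls(model.parameters(), lr=1e-3)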
def get_dataloader(config: ExperimentConfig, tfms: Tuple[List, List] = None):
    """ get the dataloaders for training/validation """
    if config.dim > 1:
        # get data augmentation if not defined
        train_tfms, valid_tfms = get_data_augmentation(config) if tfms is None else tfms

        # check number of jobs requested and CPUs available
        num_cpus = os.cpu_count()
        if num_cpus < config.n_jobs:
            logger.warning(f'Requested more workers than available (n_jobs={config.n_jobs}, '
                           f'# cpus={num_cpus}). Setting n_jobs={num_cpus}.')
            config.n_jobs = num_cpus

        # define dataset and split into training/validation set
        use_nii_ds = config.ext is None or 'nii' in config.ext
        dataset = MultimodalNiftiDataset.setup_from_dir(config.source_dir, config.target_dir,
                                                        Compose(train_tfms),
                                                        preload=config.preload) if use_nii_ds else \
                  MultimodalImageDataset.setup_from_dir(config.source_dir, config.target_dir,
                                                        Compose(train_tfms), ext='*.' + config.ext,
                                                        color=config.color, preload=config.preload)
        logger.info(f'Number of training images: {len(dataset)}')

        if config.valid_source_dir is not None and config.valid_target_dir is not None:
            valid_dataset = MultimodalNiftiDataset.setup_from_dir(config.valid_source_dir,
                                                                  config.valid_target_dir,
                                                                  Compose(valid_tfms),
                                                                  preload=config.preload) if use_nii_ds else \
                            MultimodalImageDataset.setup_from_dir(config.valid_source_dir,
                                                                  config.valid_target_dir,
                                                                  Compose(valid_tfms), ext='*.' + config.ext,
                                                                  color=config.color, preload=config.preload)
            logger.info(f'Number of validation images: {len(valid_dataset)}')
            train_loader = DataLoader(dataset, batch_size=config.batch_size, num_workers=config.n_jobs,
                                      shuffle=True, pin_memory=config.pin_memory, worker_init_fn=init_fn)
            valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, num_workers=config.n_jobs,
                                      pin_memory=config.pin_memory, worker_init_fn=init_fn)
        else:
            # setup training and validation set
            num_train = len(dataset)
            indices = list(range(num_train))
            split = int(config.valid_split * num_train)
            valid_idx = np.random.choice(indices, size=split, replace=False)
            train_idx = list(set(indices) - set(valid_idx))
            train_sampler = SubsetRandomSampler(train_idx)
            valid_sampler = SubsetRandomSampler(valid_idx)
            # set up data loaders for nifti images
            train_loader = DataLoader(dataset, sampler=train_sampler, batch_size=config.batch_size,
                                      num_workers=config.n_jobs, pin_memory=config.pin_memory,
                                      worker_init_fn=init_fn)
            valid_loader = DataLoader(dataset, sampler=valid_sampler, batch_size=config.batch_size,
                                      num_workers=config.n_jobs, pin_memory=config.pin_memory,
                                      worker_init_fn=init_fn)
    else:
        try:
            from altdataset import CSVDataset
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use 1D ConvNet in CLI without the altdataset toolbox.')
        train_dataset, valid_dataset = CSVDataset(config.source_dir[0]), CSVDataset(config.valid_source_dir[0])
        train_loader = DataLoader(train_dataset, batch_size=config.batch_size, num_workers=config.n_jobs,
                                  shuffle=True, pin_memory=config.pin_memory)
        valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, num_workers=config.n_jobs,
                                  pin_memory=config.pin_memory)
    return train_loader, valid_loader
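# Usage sketch (hypothetical ExperimentConfig instance):
#
#     >>> train_loader, valid_loader = get_dataloader(config)
#     >>> src, tgt = next(iter(train_loader))
#
# When valid_source_dir/valid_target_dir are not given, both loaders wrap the
# *same* dataset and the split is done purely via SubsetRandomSampler indices,
# so no images are duplicated in memory.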
def predict(self, fn: List[str], nsyn: int = 1, calc_var: bool = False):
    """ predict output images from a list of source filenames (one per modality) """
    self.model.eval()
    f = fn[0].lower()
    if f.endswith('.nii') or f.endswith('.nii.gz'):
        img_nib = nib.load(fn[0])
        img = np.stack([nib.load(f).get_fdata(dtype=np.float32) for f in fn])
        out = self.predictor.predict(img, nsyn, calc_var)
        out_img = [nib.Nifti1Image(o, img_nib.affine, img_nib.header) for o in out]
    elif f.split('.')[-1] in ('tif', 'tiff', 'png', 'jpg', 'jpeg'):
        out_img = self._img_predict(fn, nsyn, calc_var)
    else:
        raise SynthtorchError(f'File: {fn[0]}, not supported.')
    return out_img
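# Usage sketch (hypothetical learner and file list, one path per input modality;
# the exact semantics of nsyn/calc_var depend on the predictor implementation):
#
#     >>> out = learner.predict(['subj_t1.nii.gz'], nsyn=4, calc_var=True)
#     >>> out[0].to_filename('subj_synth.nii.gz')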
def get_act(name: str, inplace: bool = True, params: Optional[dict] = None) -> nn.Module:
    """ get activation module from pytorch
    must be one of: relu, lrelu, prelu, elu, celu, selu, linear, tanh, sigmoid, softmax, swish

    Args:
        name (str): name of activation function desired
        inplace (bool): flag activation to do operations in-place (if option available)
        params (dict): dictionary of parameters (as per pytorch documentation)

    Returns:
        act (activation): instance of activation class
    """
    if name.lower() == 'relu':
        act = nn.ReLU(inplace=inplace)
    elif name.lower() == 'lrelu':
        act = nn.LeakyReLU(inplace=inplace) if params is None else nn.LeakyReLU(inplace=inplace, **params)
    elif name.lower() == 'prelu':
        act = nn.PReLU() if params is None else nn.PReLU(**params)
    elif name.lower() == 'elu':
        act = nn.ELU(inplace=inplace) if params is None else nn.ELU(inplace=inplace, **params)
    elif name.lower() == 'celu':
        act = nn.CELU(inplace=inplace) if params is None else nn.CELU(inplace=inplace, **params)
    elif name.lower() == 'selu':
        act = nn.SELU(inplace=inplace)
    elif name.lower() == 'linear':
        act = nn.LeakyReLU(1, inplace=inplace)  # hack to get linear output
    elif name.lower() == 'tanh':
        act = nn.Tanh()
    elif name.lower() == 'sigmoid':
        act = nn.Sigmoid()
    elif name.lower() == 'softmax':
        act = nn.Softmax(dim=1)
    elif name.lower() == 'swish':
        act = Swish(inplace=inplace)
    else:
        raise SynthtorchError(f'Activation: "{name}" not a valid activation function or not supported.')
    return act
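# Behavior sketch: 'linear' is implemented as LeakyReLU with slope 1 (i.e. the
# identity), which keeps the module interface uniform. Example with params:
#
#     >>> act = get_act('lrelu', params={'negative_slope': 0.2})
#     >>> act(torch.tensor([-1.0]))
#     tensor([-0.2000])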
def lr_scheduler(self, n_epochs, lr_scheduler='cyclic', restart_period=None, t_mult=None,
                 num_cycles=1, cycle_mode='triangular', momentum_range=(0.85, 0.95),
                 div_factor=25, pct_start=0.3, **kwargs):
    lr = self.config.learning_rate
    if lr_scheduler == 'cyclic':
        logger.info(f'Enabling cyclic LR scheduler with {num_cycles} cycle(s)')
        ss = int((n_epochs * len(self.train_loader)) / num_cycles)
        ssu = int(pct_start * ss)
        ssd = ss - ssu
        cycle_momentum = self.config.optimizer in ('sgd', 'sgdw', 'nsgd', 'nsgdw', 'rmsprop')
        momentum_kwargs = {'cycle_momentum': cycle_momentum}
        if not cycle_momentum and momentum_range is not None:
            logger.warning(f'{self.config.optimizer} not compatible with momentum cycling, disabling.')
        elif momentum_range is not None:
            momentum_kwargs.update({'base_momentum': momentum_range[0], 'max_momentum': momentum_range[1]})
        self.scheduler = CyclicLR(self.optimizer, lr / div_factor, lr, step_size_up=ssu,
                                  step_size_down=ssd, mode=cycle_mode, **momentum_kwargs)
    elif lr_scheduler == 'cosinerestarts':
        logger.info('Enabling cosine annealing with restarts LR scheduler')
        self.scheduler = CosineAnnealingWarmRestarts(self.optimizer, restart_period,
                                                     T_mult=t_mult, eta_min=lr / div_factor)
    else:
        raise SynthtorchError(f'Invalid scheduler type: {lr_scheduler}.')
    logger.info(f'Max LR: {lr:.2e}, Min LR: {lr / div_factor:.2e}')
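# Step-size arithmetic sketch (hypothetical values): with n_epochs=100, 50
# batches per epoch, and num_cycles=2, each cycle spans
# ss = (100 * 50) / 2 = 2500 optimizer steps; with pct_start=0.3 the learning
# rate rises for ssu = 750 steps and falls for ssd = 2500 - 750 = 1750.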
def get_model(config: ExperimentConfig, enable_dropout: bool = True, inplace: bool = False):
    """ instantiate a model based on an ExperimentConfig class instance

    Args:
        config (ExperimentConfig): instance of the ExperimentConfig class
        enable_dropout (bool): enable dropout in the model (usually for training)
        inplace (bool): use in-place operations in the model where supported

    Returns:
        model: instance of one of the available models in the synthtorch package
    """
    if config.nn_arch == 'nconv':
        from ..models.nconvnet import SimpleConvNet
        logger.warning('The nconv network is for basic testing.')
        model = SimpleConvNet(**config)
    elif config.nn_arch == 'unet':
        from ..models.unet import Unet
        model = Unet(enable_dropout=enable_dropout, inplace=inplace, **config)
    elif config.nn_arch == 'vae':
        from ..models.vae import VAE
        model = VAE(**config)
    elif config.nn_arch == 'densenet':
        from ..models.densenet import DenseNet
        model = DenseNet(**config)
    elif config.nn_arch == 'ordnet':
        try:
            from annom.models import OrdNet
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the OrdNet without the annom toolbox.')
        model = OrdNet(enable_dropout=enable_dropout, inplace=inplace, **config)
    elif config.nn_arch == 'hotnet':
        try:
            from annom.models import HotNet
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the HotNet without the annom toolbox.')
        model = HotNet(inplace=inplace, **config)
    elif config.nn_arch == 'burnnet':
        try:
            from annom.models import BurnNet
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the BurnNet without the annom toolbox.')
        model = BurnNet(inplace=inplace, **config)
    elif config.nn_arch == 'burn2net':
        try:
            from annom.models import Burn2Net
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the Burn2Net without the annom toolbox.')
        model = Burn2Net(inplace=inplace, **config)
    elif config.nn_arch == 'burn2netp12':
        try:
            from annom.models import Burn2NetP12
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the Burn2NetP12 without the annom toolbox.')
        model = Burn2NetP12(inplace=inplace, **config)
    elif config.nn_arch == 'burn2netp21':
        try:
            from annom.models import Burn2NetP21
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the Burn2NetP21 without the annom toolbox.')
        model = Burn2NetP21(inplace=inplace, **config)
    elif config.nn_arch == 'unburnnet':
        try:
            from annom.models import UnburnNet
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the UnburnNet without the annom toolbox.')
        model = UnburnNet(inplace=inplace, **config)
    elif config.nn_arch == 'unburn2net':
        try:
            from annom.models import Unburn2Net
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the Unburn2Net without the annom toolbox.')
        model = Unburn2Net(inplace=inplace, **config)
    elif config.nn_arch == 'lavanet':
        try:
            from annom.models import LavaNet
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the LavaNet without the annom toolbox.')
        model = LavaNet(inplace=inplace, **config)
    elif config.nn_arch == 'lava2net':
        try:
            from annom.models import Lava2Net
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the Lava2Net without the annom toolbox.')
        model = Lava2Net(inplace=inplace, **config)
    elif config.nn_arch == 'lautonet':
        try:
            from annom.models import LAutoNet
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the LAutoNet without the annom toolbox.')
        model = LAutoNet(enable_dropout=enable_dropout, inplace=inplace, **config)
    elif config.nn_arch == 'ocnet1':
        try:
            from annom.models import OCNet1
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the OCNet1 without the annom toolbox.')
        model = OCNet1(enable_dropout=enable_dropout,
                       inplace=inplace if config.dropout_prob == 0 else False, **config)
    elif config.nn_arch == 'ocnet2':
        try:
            from annom.models import OCNet2
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the OCNet2 without the annom toolbox.')
        model = OCNet2(enable_dropout=enable_dropout,
                       inplace=inplace if config.dropout_prob == 0 else False, **config)
    else:
        raise SynthtorchError(f'Invalid NN type: {config.nn_arch}. '
                              f'{{nconv,unet,vae,densenet,ordnet,hotnet,burnnet,burn2net,burn2netp12,'
                              f'burn2netp21,unburnnet,unburn2net,lavanet,lava2net,lautonet,ocnet1,ocnet2}} '
                              f'are the only supported options.')
    return model
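# Usage sketch (hypothetical config values; nn_arch must be one of the names
# handled above):
#
#     >>> config.nn_arch = 'unet'
#     >>> model = get_model(config, enable_dropout=True)
#
# Note that the annom-based architectures are imported lazily inside their
# branches, so the optional annom dependency is only needed when one of those
# networks is requested.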
def fit(self, n_epochs, clip: float = None, checkpoint: int = None, trained_model: str = None):
    """ training loop for neural network """
    self.model.train()
    use_tb = self.config.tensorboard and SummaryWriter is not None
    if use_tb:
        writer = SummaryWriter()
    use_valid = self.valid_loader is not None
    use_scheduler = hasattr(self, 'scheduler')
    use_restarts = self.config.lr_scheduler == 'cosinerestarts'
    train_losses, valid_losses = [], []
    n_batches = len(self.train_loader)
    for t in range(1, n_epochs + 1):
        # training
        t_losses = []
        if use_valid:
            self.model.train(True)
        for i, (src, tgt) in enumerate(self.train_loader):
            logger.debug(f'Epoch {t} - training iteration {i} - '
                         f'Src. shape: {src.shape}; Tgt. shape: {tgt.shape}')
            src, tgt = src.to(self.device), tgt.to(self.device)
            self.optimizer.zero_grad()
            out = self.model(src)
            loss = self._criterion(out, tgt)
            t_losses.append(loss.item())
            if self.use_fp16:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            if clip is not None:
                nn.utils.clip_grad_norm_(self.model.parameters(), clip)
            self.optimizer.step()
            if use_scheduler:
                self.scheduler.step(((t - 1) + (i / n_batches)) if use_restarts else None)
            if use_tb and i % 20 == 0:
                writer.add_scalar('Loss/train', loss.item(), ((t - 1) * n_batches) + i)
            del loss  # save memory by removing ref to gradient tree
        train_losses.append(t_losses)

        if checkpoint is not None and t % checkpoint == 0:
            path, base, ext = split_filename(trained_model)
            fn = os.path.join(path, base + f'_chk_{t}' + ext)
            self.save(fn, t)

        # validation
        v_losses = []
        if use_valid:
            self.model.train(False)
            with torch.no_grad():
                for i, (src, tgt) in enumerate(self.valid_loader):
                    logger.debug(f'Epoch {t} - validation iteration {i} - '
                                 f'Src. shape: {src.shape}; Tgt. shape: {tgt.shape}')
                    src, tgt = src.to(self.device), tgt.to(self.device)
                    out = self.model(src)
                    loss = self._criterion(out, tgt)
                    if use_tb:
                        if i % 20 == 0:
                            writer.add_scalar('Loss/valid', loss.item(), ((t - 1) * n_batches) + i)
                        do_plot = i == 0 and ((t - 1) % 5) == 0
                        if do_plot and self.model.dim == 2:
                            writer.add_images('source', src[:8], t, dataformats='NCHW')
                            outimg = out[0][:8] if isinstance(out, tuple) else out[:8]
                            if self.config.color:
                                outimg = torch.round(outimg)
                            writer.add_images('target', outimg, t, dataformats='NCHW')
                        if do_plot:
                            self._histogram_weights(writer, t)
                    v_losses.append(loss.item())
            valid_losses.append(v_losses)

        if not np.all(np.isfinite(t_losses)):
            raise SynthtorchError('NaN or Inf in training loss, cannot recover. Exiting.')

        if logger is not None:
            log = f'Epoch: {t} - Training Loss: {np.mean(t_losses):.2e}'
            if use_valid:
                log += f', Validation Loss: {np.mean(v_losses):.2e}'
            if use_scheduler:
                log += f', LR: {self.scheduler.get_last_lr()[0]:.2e}'
            logger.info(log)

    self.record = Record(train_losses, valid_losses)
    if use_tb:
        writer.close()
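# Usage sketch (hypothetical trainer instance; fit is normally driven from the
# synthtorch CLI):
#
#     >>> trainer.fit(n_epochs=100, clip=1.0, checkpoint=25, trained_model='model.pth')
#
# Assuming split_filename('model.pth') yields ('', 'model', '.pth'), checkpoints
# are written every 25 epochs as model_chk_25.pth, model_chk_50.pth, and so on,
# alongside the final saved model.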