Example 1
def get_norm1d(name: str,
               num_features: int,
               affine: bool = True,
               params: Optional[dict] = None) -> nn.Module:
    """
    get a 1d normalization module from pytorch
    must be one of: instance, batch, none

    Args:
        name (str): name of normalization function desired
        num_features (int): number of channels in the normalization layer
        affine (bool): learn affine transform after normalization
        params (dict): dictionary of optional other parameters for the normalization layer
            as specified by the pytorch documentation

    Returns:
        norm: instance of normalization layer
    """
    if name.lower() == 'instance':
        norm = nn.InstanceNorm1d(num_features, affine=affine) if params is None else \
            nn.InstanceNorm1d(num_features, affine=affine, **params)
    elif name.lower() == 'batch':
        norm = nn.BatchNorm1d(num_features, affine=affine) if params is None else \
            nn.BatchNorm1d(num_features, affine=affine, **params)
    elif name.lower() == 'layer':
        # layer norm implemented as group norm with one group over all channels
        norm = nn.GroupNorm(1, num_features, affine=affine)
    elif name.lower() == 'none':
        norm = None
    else:
        raise SynthtorchError(
            f'Normalization: "{name}" not a valid normalization routine or not supported.'
        )
    return norm
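A short usage sketch follows; the layer size and the `params` values are hypothetical, and `params` is forwarded as keyword arguments per the PyTorch docs:

# hypothetical example: a BatchNorm1d layer with a custom epsilon,
# passed through the optional params dict
norm = get_norm1d('batch', num_features=64, params={'eps': 1e-3})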
Example 2
def get_optim(name: str):
    """ get an optimizer by name """
    if name.lower() == 'adam':
        optimizer = torch.optim.Adam
    elif name.lower() == 'adamw':
        optimizer = torch.optim.AdamW
    elif name.lower() == 'sgd':
        optimizer = torch.optim.SGD
    elif name.lower() == 'sgdw':
        from ..learn.optim import SGDW
        optimizer = SGDW
    elif name.lower() == 'nsgd':
        from ..learn.optim import NesterovSGD
        optimizer = NesterovSGD
    elif name.lower() == 'nsgdw':
        from ..learn.optim import NesterovSGDW
        optimizer = NesterovSGDW
    elif name.lower() == 'rmsprop':
        optimizer = torch.optim.RMSprop
    elif name.lower() == 'adagrad':
        optimizer = torch.optim.Adagrad
    elif name.lower() == 'amsgrad':
        from ..learn.optim import AMSGrad
        optimizer = AMSGrad
    else:
        raise SynthtorchError(
            f'Optimizer: "{name}" not a valid optimizer routine or not supported.'
        )
    return optimizer
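Note the design choice: the factory returns the optimizer class, not an instance, so the caller supplies the model parameters and hyperparameters. A minimal sketch (the model and hyperparameter values below are hypothetical):

import torch.nn as nn

model = nn.Linear(10, 1)        # hypothetical model
Optimizer = get_optim('adamw')  # returns torch.optim.AdamW (the class)
optimizer = Optimizer(model.parameters(), lr=1e-3, weight_decay=1e-2)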
Example 3
def get_dataloader(config: ExperimentConfig, tfms: Optional[Tuple[List, List]] = None):
    """ get the dataloaders for training and validation """
    if config.dim > 1:
        # get data augmentation if not defined
        train_tfms, valid_tfms = get_data_augmentation(config) if tfms is None else tfms

        # check number of jobs requested and CPUs available
        num_cpus = os.cpu_count()
        if num_cpus < config.n_jobs:
            logger.warning(f'Requested more workers than available (n_jobs={config.n_jobs}, # cpus={num_cpus}). '
                           f'Setting n_jobs={num_cpus}.')
            config.n_jobs = num_cpus

        # define dataset and split into training/validation set
        use_nii_ds = config.ext is None or 'nii' in config.ext
        dataset = MultimodalNiftiDataset.setup_from_dir(config.source_dir, config.target_dir, Compose(train_tfms),
                                                        preload=config.preload) if use_nii_ds else \
            MultimodalImageDataset.setup_from_dir(config.source_dir, config.target_dir, Compose(train_tfms),
                                                  ext='*.' + config.ext, color=config.color, preload=config.preload)
        logger.info(f'Number of training images: {len(dataset)}')

        if config.valid_source_dir is not None and config.valid_target_dir is not None:
            valid_dataset = MultimodalNiftiDataset.setup_from_dir(config.valid_source_dir, config.valid_target_dir,
                                                                  Compose(valid_tfms),
                                                                  preload=config.preload) if use_nii_ds else \
                MultimodalImageDataset.setup_from_dir(config.valid_source_dir, config.valid_target_dir,
                                                      Compose(valid_tfms),
                                                      ext='*.' + config.ext, color=config.color, preload=config.preload)
            logger.info(f'Number of validation images: {len(valid_dataset)}')
            train_loader = DataLoader(dataset, batch_size=config.batch_size, num_workers=config.n_jobs, shuffle=True,
                                      pin_memory=config.pin_memory, worker_init_fn=init_fn)
            valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, num_workers=config.n_jobs,
                                      pin_memory=config.pin_memory, worker_init_fn=init_fn)
        else:
            # setup training and validation set
            num_train = len(dataset)
            indices = list(range(num_train))
            split = int(config.valid_split * num_train)
            valid_idx = np.random.choice(indices, size=split, replace=False)
            train_idx = list(set(indices) - set(valid_idx))
            train_sampler = SubsetRandomSampler(train_idx)
            valid_sampler = SubsetRandomSampler(valid_idx)
            # set up data loader for nifti images
            train_loader = DataLoader(dataset, sampler=train_sampler, batch_size=config.batch_size,
                                      num_workers=config.n_jobs, pin_memory=config.pin_memory, worker_init_fn=init_fn)
            valid_loader = DataLoader(dataset, sampler=valid_sampler, batch_size=config.batch_size,
                                      num_workers=config.n_jobs, pin_memory=config.pin_memory, worker_init_fn=init_fn)
    else:
        try:
            from altdataset import CSVDataset
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use 1D ConvNet in CLI without the altdataset toolbox.')
        train_dataset, valid_dataset = CSVDataset(config.source_dir[0]), CSVDataset(config.valid_source_dir[0])
        train_loader = DataLoader(train_dataset, batch_size=config.batch_size, num_workers=config.n_jobs, shuffle=True,
                                  pin_memory=config.pin_memory)
        valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, num_workers=config.n_jobs,
                                  pin_memory=config.pin_memory)

    return train_loader, valid_loader
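A self-contained sketch of the sampler-based split used in the no-validation-directory branch above, with a hypothetical dataset and sizes; note that DataLoader's `sampler` argument is mutually exclusive with `shuffle=True`, which is why that branch does not pass `shuffle`:

import numpy as np
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

dataset = TensorDataset(torch.randn(100, 3), torch.randn(100, 3))  # hypothetical data
indices = np.arange(len(dataset))
valid_idx = np.random.choice(indices, size=20, replace=False)      # 20% validation
train_idx = np.setdiff1d(indices, valid_idx)                       # remaining 80% training
train_loader = DataLoader(dataset, sampler=SubsetRandomSampler(train_idx.tolist()), batch_size=8)
valid_loader = DataLoader(dataset, sampler=SubsetRandomSampler(valid_idx.tolist()), batch_size=8)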
Example 4
    def predict(self, fn: List[str], nsyn: int = 1, calc_var: bool = False):
        """ predict an output image (or images) from the input filename(s) """
        self.model.eval()
        f = fn[0].lower()
        if f.endswith('.nii') or f.endswith('.nii.gz'):
            # stack all source modalities into one array and synthesize
            img_nib = nib.load(fn[0])
            img = np.stack([nib.load(f).get_fdata(dtype=np.float32) for f in fn])
            out = self.predictor.predict(img, nsyn, calc_var)
            out_img = [nib.Nifti1Image(o, img_nib.affine, img_nib.header) for o in out]
        elif f.split('.')[-1] in ('tif', 'tiff', 'png', 'jpg', 'jpeg'):
            out_img = self._img_predict(fn, nsyn, calc_var)
        else:
            raise SynthtorchError(f'File: {fn[0]}, not supported.')
        return out_img
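A hypothetical call, assuming a trained learner object that exposes this method; `fn` is a list of input filenames (one per source modality), and the NIfTI branch returns nibabel images that can be written out directly:

out_imgs = learner.predict(['subject1_t1.nii.gz'], nsyn=1)  # `learner` and filenames are hypothetical
out_imgs[0].to_filename('subject1_synth.nii.gz')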
Example 5
def get_act(name: str,
            inplace: bool = True,
            params: Optional[dict] = None) -> nn.Module:
    """
    get activation module from pytorch
    must be one of: relu, lrelu, linear, tanh, sigmoid

    Args:
        name (str): name of activation function desired
        inplace (bool): flag activation to do operations in-place (if option available)
        params (dict): dictionary of parameters (as per pytorch documentation)

    Returns:
        act (activation): instance of activation class
    """
    if name.lower() == 'relu':
        act = nn.ReLU(inplace=inplace)
    elif name.lower() == 'lrelu':
        act = nn.LeakyReLU(
            inplace=inplace) if params is None else nn.LeakyReLU(
                inplace=inplace, **params)
    elif name.lower() == 'prelu':
        act = nn.PReLU() if params is None else nn.PReLU(**params)
    elif name.lower() == 'elu':
        act = nn.ELU(inplace=inplace) if params is None else nn.ELU(
            inplace=inplace, **params)
    elif name.lower() == 'celu':
        act = nn.CELU(inplace=inplace) if params is None else nn.CELU(
            inplace=inplace, **params)
    elif name.lower() == 'selu':
        act = nn.SELU(inplace=inplace)
    elif name.lower() == 'linear':
        act = nn.LeakyReLU(1, inplace=inplace)  # a slope of 1 makes LeakyReLU the identity
    elif name.lower() == 'tanh':
        act = nn.Tanh()
    elif name.lower() == 'sigmoid':
        act = nn.Sigmoid()
    elif name.lower() == 'softmax':
        act = nn.Softmax(dim=1)
    elif name.lower() == 'swish':
        act = Swish(inplace=inplace)
    else:
        raise SynthtorchError(
            f'Activation: "{name}" not a valid activation function or not supported.'
        )
    return act
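A short usage sketch; the slope value is hypothetical and is forwarded to nn.LeakyReLU through `params`:

# hypothetical example: a leaky ReLU with a custom negative slope
act = get_act('lrelu', params={'negative_slope': 0.2})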
Example 6
    def lr_scheduler(self, n_epochs, lr_scheduler='cyclic', restart_period=None, t_mult=None,
                     num_cycles=1, cycle_mode='triangular', momentum_range=(0.85, 0.95), div_factor=25,
                     pct_start=0.3, **kwargs):
        lr = self.config.learning_rate
        if lr_scheduler == 'cyclic':
            logger.info(f'Enabling cyclic LR scheduler with {num_cycles} cycle(s)')
            ss = int((n_epochs * len(self.train_loader)) / num_cycles)  # iterations per cycle
            ssu = int(pct_start * ss)  # iterations in the increasing phase
            ssd = ss - ssu             # iterations in the decreasing phase
            cycle_momentum = self.config.optimizer in ('sgd', 'sgdw', 'nsgd', 'nsgdw', 'rmsprop')
            momentum_kwargs = {'cycle_momentum': cycle_momentum}
            if not cycle_momentum and momentum_range is not None:
                logger.warning(f'{self.config.optimizer} not compatible with momentum cycling; disabling it.')
            elif momentum_range is not None:
                momentum_kwargs.update({'base_momentum': momentum_range[0], 'max_momentum': momentum_range[1]})
            self.scheduler = CyclicLR(self.optimizer, lr / div_factor, lr, step_size_up=ssu, step_size_down=ssd,
                                      mode=cycle_mode, **momentum_kwargs)
        elif lr_scheduler == 'cosinerestarts':
            logger.info('Enabling cosine annealing with restarts LR scheduler')
            self.scheduler = CosineAnnealingWarmRestarts(self.optimizer, restart_period, T_mult=t_mult,
                                                         eta_min=lr / div_factor)
        else:
            raise SynthtorchError(f'Invalid type {lr_scheduler} for scheduler.')
        logger.info(f'Max LR: {lr:.2e}, Min LR: {lr / div_factor:.2e}')
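A standalone sketch of the cyclic step-size arithmetic with hypothetical numbers (10 epochs, 50 batches per epoch, 2 cycles, pct_start=0.3), wired into PyTorch's CyclicLR:

import torch
from torch.optim.lr_scheduler import CyclicLR

n_epochs, n_batches, num_cycles, pct_start, lr, div_factor = 10, 50, 2, 0.3, 0.1, 25
ss = int((n_epochs * n_batches) / num_cycles)  # iterations per cycle: 250
ssu = int(pct_start * ss)                      # increasing phase: 75 iterations
ssd = ss - ssu                                 # decreasing phase: 175 iterations
model = torch.nn.Linear(4, 4)                  # hypothetical model
opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
sched = CyclicLR(opt, lr / div_factor, lr, step_size_up=ssu, step_size_down=ssd, mode='triangular')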
Example 7
def get_model(config: ExperimentConfig, enable_dropout: bool = True, inplace: bool = False):
    """
    instantiate a model based on an ExperimentConfig class instance

    Args:
        config (ExperimentConfig): instance of the ExperimentConfig class
        enable_dropout (bool): enable dropout in the model (usually for training)
        inplace (bool): use in-place operations in the model where supported

    Returns:
        model: instance of one of the available models in the synthtorch package
    """
    if config.nn_arch == 'nconv':
        from ..models.nconvnet import SimpleConvNet
        logger.warning('The nconv network is for basic testing.')
        model = SimpleConvNet(**config)
    elif config.nn_arch == 'unet':
        from ..models.unet import Unet
        model = Unet(enable_dropout=enable_dropout, inplace=inplace, **config)
    elif config.nn_arch == 'vae':
        from ..models.vae import VAE
        model = VAE(**config)
    elif config.nn_arch == 'densenet':
        from ..models.densenet import DenseNet
        model = DenseNet(**config)
    elif config.nn_arch == 'ordnet':
        try:
            from annom.models import OrdNet
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the OrdNet without the annom toolbox.')
        model = OrdNet(enable_dropout=enable_dropout, inplace=inplace, **config)
    elif config.nn_arch == 'hotnet':
        try:
            from annom.models import HotNet
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the HotNet without the annom toolbox.')
        model = HotNet(inplace=inplace, **config)
    elif config.nn_arch == 'burnnet':
        try:
            from annom.models import BurnNet
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the BurnNet without the annom toolbox.')
        model = BurnNet(inplace=inplace, **config)
    elif config.nn_arch == 'burn2net':
        try:
            from annom.models import Burn2Net
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the Burn2Net without the annom toolbox.')
        model = Burn2Net(inplace=inplace, **config)
    elif config.nn_arch == 'burn2netp12':
        try:
            from annom.models import Burn2NetP12
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the Burn2NetP12 without the annom toolbox.')
        model = Burn2NetP12(inplace=inplace, **config)
    elif config.nn_arch == 'burn2netp21':
        try:
            from annom.models import Burn2NetP21
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the Burn2NetP21 without the annom toolbox.')
        model = Burn2NetP21(inplace=inplace, **config)
    elif config.nn_arch == 'unburnnet':
        try:
            from annom.models import UnburnNet
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the UnburnNet without the annom toolbox.')
        model = UnburnNet(inplace=inplace, **config)
    elif config.nn_arch == 'unburn2net':
        try:
            from annom.models import Unburn2Net
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the Unburn2Net without the annom toolbox.')
        model = Unburn2Net(inplace=inplace, **config)
    elif config.nn_arch == 'lavanet':
        try:
            from annom.models import LavaNet
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the LavaNet without the annom toolbox.')
        model = LavaNet(inplace=inplace, **config)
    elif config.nn_arch == 'lava2net':
        try:
            from annom.models import Lava2Net
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the Lava2Net without the annom toolbox.')
        model = Lava2Net(inplace=inplace, **config)
    elif config.nn_arch == 'lautonet':
        try:
            from annom.models import LAutoNet
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the LAutoNet without the annom toolbox.')
        model = LAutoNet(enable_dropout=enable_dropout, inplace=inplace, **config)
    elif config.nn_arch == 'ocnet1':
        try:
            from annom.models import OCNet1
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the OCNet without the annom toolbox.')
        model = OCNet1(enable_dropout=enable_dropout, inplace=inplace if config.dropout_prob == 0 else False, **config)
    elif config.nn_arch == 'ocnet2':
        try:
            from annom.models import OCNet2
        except (ImportError, ModuleNotFoundError):
            raise SynthtorchError('Cannot use the OCNet without the annom toolbox.')
        model = OCNet2(enable_dropout=enable_dropout, inplace=inplace if config.dropout_prob == 0 else False, **config)
    else:
        raise SynthtorchError(f'Invalid NN type: {config.nn_arch}. '
                              f'{{nconv,unet,vae,densenet,ordnet,hotnet,burnnet,burn2net,burn2netp12,burn2netp21,'
                              f'unburnnet,unburn2net,lavanet,lava2net,lautonet,ocnet1,ocnet2}} '
                              f'are the only supported options.')
    return model
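Two notes on this factory: the per-architecture imports are deliberately lazy, so optional dependencies (e.g. the annom toolbox) are only required for the architectures that need them, and `**config` implies that ExperimentConfig supports the mapping protocol. A hypothetical call, assuming `config` is an already-populated ExperimentConfig instance:

import torch

model = get_model(config, enable_dropout=True, inplace=False)  # `config` assumed populated
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')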
Example 8
    def fit(self, n_epochs, clip: Optional[float] = None, checkpoint: Optional[int] = None,
            trained_model: Optional[str] = None):
        """ training loop for the neural network """
        self.model.train()
        use_tb = self.config.tensorboard and SummaryWriter is not None
        if use_tb: writer = SummaryWriter()
        use_valid = self.valid_loader is not None
        use_scheduler = hasattr(self, 'scheduler')
        use_restarts = self.config.lr_scheduler == 'cosinerestarts'
        train_losses, valid_losses = [], []
        n_batches = len(self.train_loader)
        for t in range(1, n_epochs + 1):
            # training
            t_losses = []
            if use_valid: self.model.train(True)
            for i, (src, tgt) in enumerate(self.train_loader):
                logger.debug(f'Epoch {t} - training iteration {i} - '
                             f'Src. shape: {src.shape}; Tgt. shape: {tgt.shape}')
                src, tgt = src.to(self.device), tgt.to(self.device)
                self.optimizer.zero_grad()
                out = self.model(src)
                loss = self._criterion(out, tgt)
                t_losses.append(loss.item())
                if self.use_fp16:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                if clip is not None: nn.utils.clip_grad_norm_(self.model.parameters(), clip)
                self.optimizer.step()
                if use_scheduler: self.scheduler.step(((t - 1) + (i / n_batches)) if use_restarts else None)
                if use_tb:
                    if i % 20 == 0: writer.add_scalar('Loss/train', loss.item(), ((t - 1) * n_batches) + i)

                del loss  # save memory by removing ref to gradient tree
            train_losses.append(t_losses)

            if checkpoint is not None:
                if t % checkpoint == 0:
                    path, base, ext = split_filename(trained_model)
                    fn = os.path.join(path, base + f'_chk_{t}' + ext)
                    self.save(fn, t)

            # validation
            v_losses = []
            if use_valid:
                self.model.train(False)
                with torch.no_grad():
                    for i, (src, tgt) in enumerate(self.valid_loader):
                        logger.debug(f'Epoch {t} - validation iteration {i} - '
                                     f'Src. shape: {src.shape}; Tgt. shape: {tgt.shape}')
                        src, tgt = src.to(self.device), tgt.to(self.device)
                        out = self.model(src)
                        loss = self._criterion(out, tgt)
                        if use_tb:
                            if i % 20 == 0: writer.add_scalar('Loss/valid', loss.item(), ((t - 1) * n_batches) + i)
                            do_plot = i == 0 and ((t - 1) % 5) == 0
                            if do_plot and self.model.dim == 2:
                                writer.add_images('source', src[:8], t, dataformats='NCHW')
                                outimg = out[0][:8] if isinstance(out, tuple) else out[:8]
                                if self.config.color: outimg = torch.round(outimg)
                                writer.add_images('target', outimg, t, dataformats='NCHW')
                            if do_plot: self._histogram_weights(writer, t)
                        v_losses.append(loss.item())
                    valid_losses.append(v_losses)

            if not np.all(np.isfinite(t_losses)):
                raise SynthtorchError('NaN or Inf in training loss, cannot recover. Exiting.')
            if logger is not None:
                log = f'Epoch: {t} - Training Loss: {np.mean(t_losses):.2e}'
                if use_valid: log += f', Validation Loss: {np.mean(v_losses):.2e}'
                if use_scheduler: log += f', LR: {self.scheduler.get_last_lr()[0]:.2e}'
                logger.info(log)

        self.record = Record(train_losses, valid_losses)
        if use_tb: writer.close()
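For the 'cosinerestarts' case, the fractional argument passed to scheduler.step above advances the schedule once per batch; PyTorch's CosineAnnealingWarmRestarts accepts a float epoch for exactly this purpose. A standalone sketch with hypothetical sizes:

import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

model = torch.nn.Linear(4, 4)                      # hypothetical model
opt = torch.optim.SGD(model.parameters(), lr=0.1)
sched = CosineAnnealingWarmRestarts(opt, T_0=10)   # restart every 10 epochs
n_epochs, n_batches = 3, 50
for t in range(1, n_epochs + 1):                   # epochs, 1-indexed as in fit
    for i in range(n_batches):
        opt.step()
        sched.step((t - 1) + i / n_batches)        # fractional epoch, once per batch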