Beispiel #1
0
def get_data_augmentation(config:ExperimentConfig):
    """ get all data augmentation transforms for training and validation

    Args:
        config: experiment configuration. Reads ``prob``, ``mean``, ``std``, ``color``,
            ``dim``, ``ext``, ``patch_size``, ``threshold``, ``sample_pct`` and ``sample_axis``.
            Also reads the augmentation settings forwarded to ``niftitfms.get_transforms``.

    Returns:
        tuple ``(train_tfms, valid_tfms)`` of transform lists
    """
    train_tfms, valid_tfms = [], []
    is_3d = config.dim == 3

    # add data augmentation if desired
    if config.prob is not None:
        logger.info('Adding data augmentation transforms')
        train_tfms.extend(niftitfms.get_transforms(config.prob, config.tfm_x, config.tfm_y, config.rotate, config.translate,
                                                   config.scale, config.vflip, config.hflip, config.gamma, config.gain,
                                                   config.noise_pwr, config.block, config.threshold, is_3d,
                                                   config.mean, config.std, config.color))
        # validation data is only tensorized/normalized here when normalization statistics are given
        # NOTE(review): with prob set but mean/std missing, valid_tfms gets no ToTensor -- confirm intended
        if config.mean is not None and config.std is not None:
            valid_tfms.extend([niftitfms.ToTensor(config.color),
                               niftitfms.Normalize(config.mean, config.std, config.tfm_x, config.tfm_y, is_3d)])
    else:
        logger.info('No data augmentation will be used')
        train_tfms.append(niftitfms.ToTensor(config.color))
        valid_tfms.append(niftitfms.ToTensor(config.color))

    # control random cropping patch size (or if used at all)
    nifti_input = config.ext is None or config.ext == 'nii'
    if nifti_input and config.patch_size is not None:
        if is_3d:
            # one cropper instance is shared by the training and validation pipelines
            cropper = niftitfms.RandomCrop3D(config.patch_size, config.threshold, config.sample_pct, config.sample_axis)
            train_tfms.append(cropper)
            valid_tfms.append(cropper)
        else:
            # NOTE(review): 2D nifti input only takes a random slice along sample_axis here;
            # patch_size is not used for a 2D crop -- confirm this is intended
            train_tfms.append(niftitfms.RandomSlice(config.sample_axis))
            valid_tfms.append(niftitfms.RandomSlice(config.sample_axis))
    elif config.patch_size is not None:
        # non-nifti input (e.g., another image extension): plain random crop
        train_tfms.append(niftitfms.RandomCrop(config.patch_size, config.threshold))
        valid_tfms.append(niftitfms.RandomCrop(config.patch_size, config.threshold))

    logger.debug(f'Training transforms: {train_tfms}')
    logger.debug(f'Validation transforms: {valid_tfms}')
    return train_tfms, valid_tfms
Beispiel #2
0
def get_data_augmentation(config: ExperimentConfig):
    """ get all data augmentation transforms for training

    Builds a single transform list. An optional random crop/slice comes first. It is
    followed either by the transforms from ``niftitfms.get_transforms`` (when
    ``config.prob`` is set) or by a plain ``ToTensor``.

    Args:
        config: experiment configuration. Reads ``ext``, ``patch_size``, ``net3d``,
            ``sample_axis`` and ``prob``. Also reads the augmentation settings forwarded
            to ``niftitfms.get_transforms``.

    Returns:
        list of transforms
    """
    # a patch size of None is treated like 0, i.e., no random cropping
    use_crop = config.patch_size is not None and config.patch_size > 0

    # control random cropping patch size (or if used at all)
    if config.ext is None:
        if use_crop:
            # only construct the cropper that is actually used
            cropper = (niftitfms.RandomCrop3D(config.patch_size) if config.net3d else
                       niftitfms.RandomCrop2D(config.patch_size, config.sample_axis))
            tfms = [cropper]
        elif config.net3d:
            # 3D network without a patch size: no crop/slice transform
            tfms = []
        else:
            # 2D network without a patch size: take a random slice along sample_axis
            tfms = [niftitfms.RandomSlice(config.sample_axis)]
    else:
        tfms = [niftitfms.RandomCrop(config.patch_size)] if use_crop else []

    # add data augmentation if desired
    if config.prob is not None:
        logger.info('Adding data augmentation transforms')
        tfms.extend(
            niftitfms.get_transforms(config.prob, config.tfm_x, config.tfm_y,
                                     config.rotate, config.translate,
                                     config.scale, config.vflip, config.hflip,
                                     config.gamma, config.gain,
                                     config.noise_pwr, config.block,
                                     config.mean, config.std))
    else:
        logger.info(
            'No data augmentation will be used (except random cropping if patch_size > 0)'
        )
        tfms.append(niftitfms.ToTensor())

    return tfms
Beispiel #3
0
def main(args=None):
    """ train a synthesis network as specified by command-line/config-file arguments

    The function proceeds in these steps:
        1. Build the requested architecture (``nconv``, ``unet`` or ``vae``).
        2. Optionally wrap the model for multi-GPU use.
        3. Train it with Adam, running the validation loader each epoch.
        4. Save the trained model: only the weights when a config file is used,
           otherwise the whole model object.
        5. Optionally write a config file and a loss plot.

    Args:
        args: argument list forwarded to ``get_args`` (None presumably means the
            process command line -- depends on ``get_args``, not visible here)

    Returns:
        0 on success, 1 if any exception was raised (the exception is logged)
    """
    args, no_config_file = get_args(args, arg_parser)
    setup_log(args.verbosity)
    logger = logging.getLogger(__name__)
    try:
        # set random seeds for reproducibility
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        # define device to put tensors on
        device, use_cuda, n_gpus = get_device(args, logger)

        # import and initialize mixed precision training package (optional)
        amp_handle = None
        if args.fp16:
            try:
                from apex import amp
                amp_handle = amp.init()
            except ImportError:
                logger.info(
                    'Mixed precision training (i.e., the package `apex`) not available.'
                )

        # a 3D network is never built for TIFF input; `use_3d` is the actual network dimensionality
        use_3d = args.net3d and not args.tiff
        if args.net3d and args.tiff:
            logger.warning(
                'Cannot train a 3D network with TIFF images, creating a 2D network.'
            )
        n_input, n_output = len(args.source_dir), len(args.target_dir)

        if args.ord_params is not None and n_output > 1:
            raise SynthNNError(
                'Ordinal regression does not support multiple outputs.')

        # get the desired neural network architecture
        if args.nn_arch == 'nconv':
            from synthnn.models.nconvnet import SimpleConvNet
            logger.warning('The nconv network is for basic testing.')
            model = SimpleConvNet(args.n_layers,
                                  kernel_size=args.kernel_size,
                                  dropout_p=args.dropout_prob,
                                  n_input=n_input,
                                  n_output=n_output,
                                  is_3d=use_3d)
        elif args.nn_arch == 'unet':
            from synthnn.models.unet import Unet
            model = Unet(args.n_layers,
                         kernel_size=args.kernel_size,
                         dropout_p=args.dropout_prob,
                         channel_base_power=args.channel_base_power,
                         add_two_up=args.add_two_up,
                         normalization=args.normalization,
                         activation=args.activation,
                         output_activation=args.out_activation,
                         interp_mode=args.interp_mode,
                         enable_dropout=True,
                         enable_bias=args.enable_bias,
                         is_3d=use_3d,
                         n_input=n_input,
                         n_output=n_output,
                         no_skip=args.no_skip,
                         ord_params=(args.ord_params + [device]
                                     if args.ord_params is not None else None))
        elif args.nn_arch == 'vae':
            from synthnn.models.vae import VAE
            model = VAE(args.n_layers,
                        args.img_dim,
                        channel_base_power=args.channel_base_power,
                        activation=args.activation,
                        is_3d=use_3d,
                        n_input=n_input,
                        n_output=n_output,
                        latent_size=args.latent_size)
        else:
            raise SynthNNError(
                f'Invalid NN type: {args.nn_arch}. {{nconv, unet, vae}} are the only supported options.'
            )
        model.train(True)
        logger.debug(model)

        # put the model on the GPU if available and desired
        if use_cuda:
            model.cuda(device=device)
        use_multi = args.multi_gpu and n_gpus > 1 and use_cuda
        if args.multi_gpu and n_gpus <= 1:
            logger.warning(
                'Multi-GPU functionality is not available on your system.')
        if use_multi:
            if args.gpu_selector is not None:
                n_gpus = len(args.gpu_selector)
            logger.debug(f'Enabling use of {n_gpus} gpus')
            model = torch.nn.DataParallel(model, device_ids=args.gpu_selector)

        # initialize the weights with user-defined initialization routine
        logger.debug(f'Initializing weights with {args.init}')
        init_weights(model, args.init, args.init_gain)

        # check number of jobs requested and CPUs available
        # (os.cpu_count() returns None if the count cannot be determined; only cap when known)
        num_cpus = os.cpu_count()
        if num_cpus is not None and num_cpus < args.n_jobs:
            logger.warning(
                f'Requested more workers than available (n_jobs={args.n_jobs}, # cpus={num_cpus}). '
                f'Setting n_jobs={num_cpus}.')
            args.n_jobs = num_cpus

        # control random cropping patch size (or if used at all); only build the transform that is used
        if args.tiff:
            tfm = []
        elif args.patch_size > 0:
            tfm = [tfms.RandomCrop3D(args.patch_size) if use_3d else
                   tfms.RandomCrop2D(args.patch_size, args.sample_axis)]
        elif use_3d:
            tfm = []
        else:
            tfm = [tfms.RandomSlice(args.sample_axis)]

        # add data augmentation if desired
        if args.prob is not None:  # currently only support transforms on tiff images
            logger.debug('Adding data augmentation transforms')
            # check the network actually built (TIFF input always yields a 2D network)
            if use_3d and (args.prob[0] > 0 or args.prob[1] > 0):
                logger.warning(
                    'Cannot do affine or flipping data augmentation with 3d networks'
                )
                # slice assignment needs an iterable; assigning a scalar raises TypeError for lists
                args.prob[:2] = [0., 0.]
                args.rotate, args.translate, args.scale, args.hflip, args.vflip = 0, None, None, False, False
            tfm.extend(
                tfms.get_transforms(args.prob, args.tfm_x, args.tfm_y,
                                    args.rotate, args.translate, args.scale,
                                    args.vflip, args.hflip, args.gamma,
                                    args.gain, args.noise_std))
        else:
            logger.debug(
                'No data augmentation will be used (except random cropping if patch_size > 0)'
            )
            tfm.append(tfms.ToTensor())

        # define dataset and split into training/validation set
        dataset_cls = MultimodalTiffDataset if args.tiff else MultimodalNiftiDataset
        dataset = dataset_cls(args.source_dir, args.target_dir, Compose(tfm))
        logger.debug(f'Number of training images: {len(dataset)}')

        if args.valid_source_dir is not None and args.valid_target_dir is not None:
            valid_dataset = dataset_cls(args.valid_source_dir, args.valid_target_dir, Compose(tfm))
            logger.debug(f'Number of validation images: {len(valid_dataset)}')
            train_loader = DataLoader(dataset,
                                      batch_size=args.batch_size,
                                      num_workers=args.n_jobs,
                                      shuffle=True,
                                      pin_memory=args.pin_memory)
            validation_loader = DataLoader(valid_dataset,
                                           batch_size=args.batch_size,
                                           num_workers=args.n_jobs,
                                           pin_memory=args.pin_memory)
        else:
            # setup training and validation set by randomly splitting off `valid_split` of the data
            num_train = len(dataset)
            indices = list(range(num_train))
            split = int(args.valid_split * num_train)
            validation_idx = np.random.choice(indices,
                                              size=split,
                                              replace=False)
            train_idx = list(set(indices) - set(validation_idx))

            train_sampler = SubsetRandomSampler(train_idx)
            validation_sampler = SubsetRandomSampler(validation_idx)

            # set up data loader for nifti images
            train_loader = DataLoader(dataset,
                                      sampler=train_sampler,
                                      batch_size=args.batch_size,
                                      num_workers=args.n_jobs,
                                      pin_memory=args.pin_memory)
            validation_loader = DataLoader(dataset,
                                           sampler=validation_sampler,
                                           batch_size=args.batch_size,
                                           num_workers=args.n_jobs,
                                           pin_memory=args.pin_memory)

        # train the model
        logger.info(f'LR: {args.learning_rate:.5f}')
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
        if args.lr_scheduler:
            logger.debug('Enabling burn-in cosine annealing LR scheduler')
            scheduler = BurnCosineLR(optimizer, args.n_epochs)
        use_valid = args.valid_split > 0 or (args.valid_source_dir is not None
                                             and args.valid_target_dir
                                             is not None)
        # NOTE(review): `criterion` is not defined in this function; presumably a module-level
        # loss taking (output, target, model) -- confirm
        train_losses, validation_losses = [], []
        for t in range(args.n_epochs):
            # training
            t_losses = []
            if use_valid:
                model.train(True)
            for src, tgt in train_loader:
                src, tgt = src.to(device), tgt.to(device)
                out = model(src)
                loss = criterion(out, tgt, model)
                t_losses.append(loss.item())
                optimizer.zero_grad()
                if args.fp16 and amp_handle is not None:
                    with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                if args.clip is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), args.clip)
                optimizer.step()
            train_losses.append(t_losses)
            if args.lr_scheduler:
                scheduler.step()

            # validation (no gradients needed)
            v_losses = []
            if use_valid:
                model.train(False)
            with torch.no_grad():
                for src, tgt in validation_loader:
                    src, tgt = src.to(device), tgt.to(device)
                    out = model(src)
                    loss = criterion(out, tgt, model)
                    v_losses.append(loss.item())
                validation_losses.append(v_losses)

            if np.any(np.isnan(t_losses)):
                raise SynthNNError(
                    'NaN in training loss, cannot recover. Exiting.')
            log = f'Epoch: {t+1} - Training Loss: {np.mean(t_losses):.2e}'
            if use_valid:
                log += f', Validation Loss: {np.mean(v_losses):.2e}'
            if args.lr_scheduler:
                log += f', LR: {scheduler.get_lr()[0]:.2e}'
            logger.info(log)

        # output a config file if desired
        if args.out_config_file is not None:
            write_out_config(args, n_gpus, n_input, n_output, use_3d)

        # save the trained model
        use_config_file = not no_config_file or args.out_config_file is not None
        if use_config_file:
            torch.save(model.state_dict(), args.trained_model)
        else:
            # save the whole model (if changes occur to pytorch, then this model will probably not be loadable)
            logger.warning(
                'Saving the entire model. Preferred to create a config file and only save model weights'
            )
            torch.save(model, args.trained_model)

        # strip multi-gpu specific attributes from saved model (so that it can be loaded easily)
        if use_multi and use_config_file:
            from collections import OrderedDict
            prefix = 'module.'  # DataParallel prefixes the wrapped model's parameter names with this
            state_dict = torch.load(args.trained_model, map_location='cpu')
            new_state_dict = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in state_dict.items())
            torch.save(new_state_dict, args.trained_model)

        # plot the loss vs epoch (if desired)
        if args.plot_loss is not None:
            plot_error = args.n_epochs <= 50  # plot_error only enabled for runs of <= 50 epochs
            from synthnn import plot_loss
            if matplotlib.get_backend() != 'agg':
                import matplotlib.pyplot as plt
                plt.switch_backend('agg')
            ax = plot_loss(train_losses,
                           ecolor='maroon',
                           label='Train',
                           plot_error=plot_error)
            _ = plot_loss(validation_losses,
                          filename=args.plot_loss,
                          ecolor='firebrick',
                          ax=ax,
                          label='Validation',
                          plot_error=plot_error)

        return 0
    except Exception as e:
        logger.exception(e)
        return 1