def get_data_augmentation(config: ExperimentConfig):
    """Build the training and validation transform lists from *config*.

    Returns
    -------
    tuple[list, list]
        ``(train_tfms, valid_tfms)`` ready to be composed into the dataset
        transform pipelines.
    """
    train_tfms, valid_tfms = [], []

    # add data augmentation if desired
    if config.prob is not None:
        logger.info('Adding data augmentation transforms')
        train_tfms.extend(niftitfms.get_transforms(config.prob, config.tfm_x, config.tfm_y, config.rotate,
                                                   config.translate, config.scale, config.vflip, config.hflip,
                                                   config.gamma, config.gain, config.noise_pwr, config.block,
                                                   config.threshold, config.dim == 3, config.mean, config.std,
                                                   config.color))
        # validation data is not augmented; it only gets tensor conversion
        # plus normalization when normalization statistics are provided
        if config.mean is not None and config.std is not None:
            valid_tfms.extend([niftitfms.ToTensor(config.color),
                               niftitfms.Normalize(config.mean, config.std, config.tfm_x, config.tfm_y,
                                                   config.dim == 3)])
    else:
        logger.info('No data augmentation will be used')
        train_tfms.append(niftitfms.ToTensor(config.color))
        valid_tfms.append(niftitfms.ToTensor(config.color))

    # control random cropping patch size (or if used at all)
    if (config.ext is None or config.ext == 'nii') and config.patch_size is not None:
        # FIX: the original code always constructed a RandomCrop2D for
        # dim == 2 and then unconditionally discarded it in favor of
        # RandomSlice — the inner `config.patch_size is not None` guard was
        # always true inside this branch, so the 2D cropper was dead code.
        # Only build the transform that is actually appended.
        if config.dim == 3:
            cropper = niftitfms.RandomCrop3D(config.patch_size, config.threshold,
                                             config.sample_pct, config.sample_axis)
            # same cropper instance intentionally shared by both pipelines
            train_tfms.append(cropper)
            valid_tfms.append(cropper)
        else:
            # NOTE(review): 2D nii data is random-sliced rather than
            # random-cropped here; confirm RandomCrop2D should stay unused.
            train_tfms.append(niftitfms.RandomSlice(config.sample_axis))
            valid_tfms.append(niftitfms.RandomSlice(config.sample_axis))
    elif config.patch_size is not None:
        # non-nii data (e.g., tiff/png) uses the generic random crop
        train_tfms.append(niftitfms.RandomCrop(config.patch_size, config.threshold))
        valid_tfms.append(niftitfms.RandomCrop(config.patch_size, config.threshold))

    logger.debug(f'Training transforms: {train_tfms}')
    return train_tfms, valid_tfms
def get_data_augmentation(config: ExperimentConfig):
    """Assemble the list of data transforms for training.

    Random cropping/slicing is configured first, then the optional
    augmentation transforms (or a bare tensor conversion) are appended.

    Returns
    -------
    list
        The ordered transform list for the training dataset.
    """
    # control random cropping patch size (or if used at all)
    if config.ext is None:
        # pick the dimensionality-appropriate cropper up front
        if config.net3d:
            cropper = niftitfms.RandomCrop3D(config.patch_size)
        else:
            cropper = niftitfms.RandomCrop2D(config.patch_size, config.sample_axis)
        if config.patch_size > 0:
            transforms = [cropper]
        elif config.net3d:
            transforms = []
        else:
            # 2D network without cropping: pull random slices instead
            transforms = [niftitfms.RandomSlice(config.sample_axis)]
    else:
        transforms = []
        if config.patch_size > 0:
            transforms.append(niftitfms.RandomCrop(config.patch_size))

    # add data augmentation if desired
    if config.prob is not None:
        logger.info('Adding data augmentation transforms')
        transforms.extend(
            niftitfms.get_transforms(config.prob, config.tfm_x, config.tfm_y,
                                     config.rotate, config.translate, config.scale,
                                     config.vflip, config.hflip, config.gamma,
                                     config.gain, config.noise_pwr, config.block,
                                     config.mean, config.std))
    else:
        logger.info(
            'No data augmentation will be used (except random cropping if patch_size > 0)'
        )
        transforms.append(niftitfms.ToTensor())
    return transforms
def main(args=None):
    """CLI entry point: parse args, build and train the network, save it.

    Parameters
    ----------
    args : list or None
        Optional argument list forwarded to the argument parser; ``None``
        means parse ``sys.argv``.

    Returns
    -------
    int
        0 on success, 1 on any caught exception (logged, not re-raised).
    """
    args, no_config_file = get_args(args, arg_parser)
    setup_log(args.verbosity)
    # NOTE(review): this local logger shadows any module-level `logger`
    logger = logging.getLogger(__name__)
    try:
        # set random seeds for reproducibility
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        # define device to put tensors on
        device, use_cuda, n_gpus = get_device(args, logger)

        # import and initialize mixed precision training package
        amp_handle = None
        if args.fp16:
            try:
                from apex import amp
                amp_handle = amp.init()
            except ImportError:
                logger.info(
                    'Mixed precision training (i.e., the package `apex`) not available.'
                )

        # 3D networks are incompatible with TIFF input; silently degrade to 2D
        use_3d = args.net3d and not args.tiff
        if args.net3d and args.tiff:
            logger.warning(
                'Cannot train a 3D network with TIFF images, creating a 2D network.'
            )
        n_input, n_output = len(args.source_dir), len(args.target_dir)
        if args.ord_params is not None and n_output > 1:
            raise SynthNNError(
                'Ordinal regression does not support multiple outputs.')

        # get the desired neural network architecture
        if args.nn_arch == 'nconv':
            from synthnn.models.nconvnet import SimpleConvNet
            logger.warning('The nconv network is for basic testing.')
            model = SimpleConvNet(args.n_layers,
                                  kernel_size=args.kernel_size,
                                  dropout_p=args.dropout_prob,
                                  n_input=n_input,
                                  n_output=n_output,
                                  is_3d=use_3d)
        elif args.nn_arch == 'unet':
            from synthnn.models.unet import Unet
            model = Unet(args.n_layers,
                         kernel_size=args.kernel_size,
                         dropout_p=args.dropout_prob,
                         channel_base_power=args.channel_base_power,
                         add_two_up=args.add_two_up,
                         normalization=args.normalization,
                         activation=args.activation,
                         output_activation=args.out_activation,
                         interp_mode=args.interp_mode,
                         enable_dropout=True,
                         enable_bias=args.enable_bias,
                         is_3d=use_3d,
                         n_input=n_input,
                         n_output=n_output,
                         no_skip=args.no_skip,
                         # ordinal regression needs the device appended to its params
                         ord_params=args.ord_params +
                         [device] if args.ord_params is not None else None)
        elif args.nn_arch == 'vae':
            from synthnn.models.vae import VAE
            model = VAE(args.n_layers,
                        args.img_dim,
                        channel_base_power=args.channel_base_power,
                        activation=args.activation,
                        is_3d=use_3d,
                        n_input=n_input,
                        n_output=n_output,
                        latent_size=args.latent_size)
        else:
            raise SynthNNError(
                f'Invalid NN type: {args.nn_arch}. {{nconv, unet, vae}} are the only supported options.'
            )
        model.train(True)
        logger.debug(model)

        # put the model on the GPU if available and desired
        if use_cuda:
            model.cuda(device=device)
        use_multi = args.multi_gpu and n_gpus > 1 and use_cuda
        if args.multi_gpu and n_gpus <= 1:
            logger.warning(
                'Multi-GPU functionality is not available on your system.')
        if use_multi:
            n_gpus = len(
                args.gpu_selector) if args.gpu_selector is not None else n_gpus
            logger.debug(f'Enabling use of {n_gpus} gpus')
            model = torch.nn.DataParallel(model, device_ids=args.gpu_selector)

        # initialize the weights with user-defined initialization routine
        logger.debug(f'Initializing weights with {args.init}')
        init_weights(model, args.init, args.init_gain)

        # check number of jobs requested and CPUs available
        num_cpus = os.cpu_count()
        if num_cpus < args.n_jobs:
            logger.warning(
                f'Requested more workers than available (n_jobs={args.n_jobs}, # cpus={num_cpus}). '
                f'Setting n_jobs={num_cpus}.')
            args.n_jobs = num_cpus

        # control random cropping patch size (or if used at all)
        if not args.tiff:
            cropper = tfms.RandomCrop3D(
                args.patch_size) if args.net3d else tfms.RandomCrop2D(
                    args.patch_size, args.sample_axis)
            tfm = [cropper] if args.patch_size > 0 else [] if args.net3d else [
                tfms.RandomSlice(args.sample_axis)
            ]
        else:
            tfm = []

        # add data augmentation if desired
        if args.prob is not None:
            # currently only support transforms on tiff images
            logger.debug('Adding data augmentation transforms')
            if args.net3d and (args.prob[0] > 0 or args.prob[1] > 0):
                logger.warning(
                    'Cannot do affine or flipping data augmentation with 3d networks'
                )
                # NOTE(review): slice-assigning a scalar (`[:2] = 0`) only works
                # if args.prob is a numpy array; a plain list would raise — confirm
                args.prob[:2] = 0
                args.rotate, args.translate, args.scale, args.hflip, args.vflip = 0, None, None, False, False
            tfm.extend(
                tfms.get_transforms(args.prob, args.tfm_x, args.tfm_y,
                                    args.rotate, args.translate, args.scale,
                                    args.vflip, args.hflip, args.gamma,
                                    args.gain, args.noise_std))
        else:
            logger.debug(
                'No data augmentation will be used (except random cropping if patch_size > 0)'
            )
            tfm.append(tfms.ToTensor())

        # define dataset and split into training/validation set
        dataset = MultimodalNiftiDataset(args.source_dir, args.target_dir, Compose(tfm)) if not args.tiff else \
            MultimodalTiffDataset(args.source_dir, args.target_dir, Compose(tfm))
        logger.debug(f'Number of training images: {len(dataset)}')

        if args.valid_source_dir is not None and args.valid_target_dir is not None:
            # explicit validation directories provided: build a separate dataset
            valid_dataset = MultimodalNiftiDataset(args.valid_source_dir, args.valid_target_dir, Compose(tfm)) if not args.tiff else \
                MultimodalTiffDataset(args.valid_source_dir, args.valid_target_dir, Compose(tfm))
            logger.debug(f'Number of validation images: {len(valid_dataset)}')
            train_loader = DataLoader(dataset,
                                      batch_size=args.batch_size,
                                      num_workers=args.n_jobs,
                                      shuffle=True,
                                      pin_memory=args.pin_memory)
            validation_loader = DataLoader(valid_dataset,
                                           batch_size=args.batch_size,
                                           num_workers=args.n_jobs,
                                           pin_memory=args.pin_memory)
        else:
            # setup training and validation set
            # no validation dirs: carve a random split out of the training set
            num_train = len(dataset)
            indices = list(range(num_train))
            split = int(args.valid_split * num_train)
            validation_idx = np.random.choice(indices,
                                              size=split,
                                              replace=False)
            train_idx = list(set(indices) - set(validation_idx))
            train_sampler = SubsetRandomSampler(train_idx)
            validation_sampler = SubsetRandomSampler(validation_idx)
            # set up data loader for nifti images
            train_loader = DataLoader(dataset,
                                      sampler=train_sampler,
                                      batch_size=args.batch_size,
                                      num_workers=args.n_jobs,
                                      pin_memory=args.pin_memory)
            validation_loader = DataLoader(dataset,
                                           sampler=validation_sampler,
                                           batch_size=args.batch_size,
                                           num_workers=args.n_jobs,
                                           pin_memory=args.pin_memory)

        # train the model
        logger.info(f'LR: {args.learning_rate:.5f}')
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.learning_rate)
        if args.lr_scheduler:
            logger.debug('Enabling burn-in cosine annealing LR scheduler')
            scheduler = BurnCosineLR(optimizer, args.n_epochs)

        use_valid = args.valid_split > 0 or (args.valid_source_dir is not None
                                             and args.valid_target_dir is not None)
        train_losses, validation_losses = [], []
        for t in range(args.n_epochs):
            # training
            t_losses = []
            if use_valid:
                # toggle training mode each epoch (validation flips it off below)
                model.train(True)
            for src, tgt in train_loader:
                src, tgt = src.to(device), tgt.to(device)
                out = model(src)
                # NOTE(review): `criterion` is defined outside this view — it
                # takes (output, target, model), not the usual two arguments
                loss = criterion(out, tgt, model)
                t_losses.append(loss.item())
                optimizer.zero_grad()
                if args.fp16 and amp_handle is not None:
                    # apex amp scales the loss to avoid fp16 underflow
                    with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()
                if args.clip is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), args.clip)
                optimizer.step()
            train_losses.append(t_losses)
            if args.lr_scheduler:
                scheduler.step()

            # validation
            v_losses = []
            if use_valid:
                model.train(False)
                with torch.set_grad_enabled(False):
                    for src, tgt in validation_loader:
                        src, tgt = src.to(device), tgt.to(device)
                        out = model(src)
                        loss = criterion(out, tgt, model)
                        v_losses.append(loss.item())
                validation_losses.append(v_losses)

            if np.any(np.isnan(t_losses)):
                raise SynthNNError(
                    'NaN in training loss, cannot recover. Exiting.')
            log = f'Epoch: {t+1} - Training Loss: {np.mean(t_losses):.2e}'
            if use_valid:
                log += f', Validation Loss: {np.mean(v_losses):.2e}'
            if args.lr_scheduler:
                log += f', LR: {scheduler.get_lr()[0]:.2e}'
            logger.info(log)

        # output a config file if desired
        if args.out_config_file is not None:
            write_out_config(args, n_gpus, n_input, n_output, use_3d)

        # save the trained model
        use_config_file = not no_config_file or args.out_config_file is not None
        if use_config_file:
            torch.save(model.state_dict(), args.trained_model)
        else:
            # save the whole model (if changes occur to pytorch, then this model will probably not be loadable)
            logger.warning(
                'Saving the entire model. Preferred to create a config file and only save model weights'
            )
            torch.save(model, args.trained_model)

        # strip multi-gpu specific attributes from saved model (so that it can be loaded easily)
        if use_multi and use_config_file:
            from collections import OrderedDict
            state_dict = torch.load(args.trained_model, map_location='cpu')
            # create new OrderedDict that does not contain `module.`
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # remove `module.`
                new_state_dict[name] = v
            torch.save(new_state_dict, args.trained_model)

        # plot the loss vs epoch (if desired)
        if args.plot_loss is not None:
            # error bars become unreadable past ~50 epochs
            plot_error = True if args.n_epochs <= 50 else False
            from synthnn import plot_loss
            if matplotlib.get_backend() != 'agg':
                # force a non-interactive backend for headless environments
                import matplotlib.pyplot as plt
                plt.switch_backend('agg')
            ax = plot_loss(train_losses,
                           ecolor='maroon',
                           label='Train',
                           plot_error=plot_error)
            _ = plot_loss(validation_losses,
                          filename=args.plot_loss,
                          ecolor='firebrick',
                          ax=ax,
                          label='Validation',
                          plot_error=plot_error)
        return 0
    except Exception as e:
        # top-level CLI boundary: log the full traceback and signal failure
        logger.exception(e)
        return 1