class ModelTrainerWGANGP:
    """Training driver for a WGAN-GP reconstruction model.

    Trains a generator/discriminator pair with the WGAN-GP objective
    (critic loss plus gradient penalty) combined with a weighted
    reconstruction loss for the generator. Handles epoch looping,
    validation, NaN-tolerant loss aggregation, TensorBoard logging,
    checkpointing, and optional LR scheduling for both networks.
    """

    def __init__(self, args, generator, discriminator, gen_optim, disc_optim,
                 train_loader, val_loader, loss_funcs, gen_scheduler=None, disc_scheduler=None):
        """Validate inputs and store all training state.

        Args:
            args: Namespace of run settings (device, paths, num_epochs,
                recon_lambda, lambda_gp, max_images, batch_size, etc.).
            generator / discriminator: Pytorch Modules to be trained.
            gen_optim / disc_optim: Pytorch Optimizers for each network.
            train_loader / val_loader: Pytorch DataLoaders.
            loss_funcs: dict of named loss Modules; must contain 'recon_loss_func'.
            gen_scheduler / disc_scheduler: optional LR schedulers; metric-based
                (ReduceLROnPlateau) and epoch-based schedulers are both accepted.
        """
        self.logger = get_logger(name=__name__, save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(generator, nn.Module) and isinstance(discriminator, nn.Module), \
            '`generator` and `discriminator` must be Pytorch Modules.'
        assert isinstance(gen_optim, optim.Optimizer) and isinstance(disc_optim, optim.Optimizer), \
            '`gen_optim` and `disc_optim` must be Pytorch Optimizers.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'

        # Expected to be a dictionary with names and loss functions.
        loss_funcs = nn.ModuleDict(loss_funcs)

        # `metric_*_scheduler` records whether step() needs the validation metric.
        if gen_scheduler is not None:
            if isinstance(gen_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_gen_scheduler = True
            elif isinstance(gen_scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_gen_scheduler = False
            else:
                raise TypeError('`gen_scheduler` must be a Pytorch Learning Rate Scheduler.')

        if disc_scheduler is not None:
            if isinstance(disc_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_disc_scheduler = True
            elif isinstance(disc_scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_disc_scheduler = False
            else:
                raise TypeError('`disc_scheduler` must be a Pytorch Learning Rate Scheduler.')

        self.generator = generator
        self.discriminator = discriminator
        self.gen_optim = gen_optim
        self.disc_optim = disc_optim
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.loss_funcs = loss_funcs
        self.gen_scheduler = gen_scheduler
        self.disc_scheduler = disc_scheduler
        self.device = args.device
        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.writer = SummaryWriter(str(args.log_path))

        # Loss weights kept as tensors on the training device.
        self.recon_lambda = torch.tensor(args.recon_lambda, dtype=torch.float32, device=args.device)
        self.lambda_gp = torch.tensor(args.lambda_gp, dtype=torch.float32, device=args.device)

        # This will work best if batch size is 1, as is recommended. I don't know whether this generalizes.
        self.target_real = torch.tensor(1, dtype=torch.float32, device=args.device)
        self.target_fake = torch.tensor(0, dtype=torch.float32, device=args.device)

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images <= 0:
            self.display_interval = 0
        else:
            self.display_interval = int(len(self.val_loader.dataset) // (args.max_images * args.batch_size))

        self.gen_checkpoint_manager = CheckpointManager(
            model=self.generator, optimizer=self.gen_optim, mode='min',
            save_best_only=args.save_best_only, ckpt_dir=args.ckpt_path / 'Generator',
            max_to_keep=args.max_to_keep)

        self.disc_checkpoint_manager = CheckpointManager(
            model=self.discriminator, optimizer=self.disc_optim, mode='min',
            save_best_only=args.save_best_only, ckpt_dir=args.ckpt_path / 'Discriminator',
            max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified.
        if vars(args).get('gen_prev_model_ckpt'):
            self.gen_checkpoint_manager.load(load_dir=args.gen_prev_model_ckpt, load_optimizer=False)

        if vars(args).get('disc_prev_model_ckpt'):
            self.disc_checkpoint_manager.load(load_dir=args.disc_prev_model_ckpt, load_optimizer=False)

    def train_model(self):
        """Run the full train/validate loop for `num_epochs` epochs.

        Each epoch trains, validates, logs both phases, checkpoints both
        networks on the validation loss, and steps any LR schedulers.
        """
        self.logger.info('Beginning Training Loop.')
        tic_tic = time()
        for epoch in range(1, self.num_epochs + 1):  # 1 based indexing
            # Training
            tic = time()
            train_epoch_loss, train_epoch_loss_components = self._train_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch=epoch, epoch_loss=train_epoch_loss,
                                    epoch_loss_components=train_epoch_loss_components,
                                    elapsed_secs=toc, training=True)

            # Validation
            tic = time()
            val_epoch_loss, val_epoch_loss_components = self._val_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch=epoch, epoch_loss=val_epoch_loss,
                                    epoch_loss_components=val_epoch_loss_components,
                                    elapsed_secs=toc, training=False)

            self.gen_checkpoint_manager.save(metric=val_epoch_loss, verbose=True)
            self.disc_checkpoint_manager.save(metric=val_epoch_loss, verbose=True)

            if self.gen_scheduler is not None:
                if self.metric_gen_scheduler:  # If the scheduler is a metric based scheduler, include metrics.
                    self.gen_scheduler.step(metrics=val_epoch_loss)
                else:
                    self.gen_scheduler.step()

            if self.disc_scheduler is not None:
                if self.metric_disc_scheduler:
                    self.disc_scheduler.step(metrics=val_epoch_loss)
                else:
                    self.disc_scheduler.step()

        # Finishing Training Loop
        self.writer.close()  # Flushes remaining data to TensorBoard.
        toc_toc = int(time() - tic_tic)
        self.logger.info(f'Finishing Training Loop. Total elapsed time: '
                         f'{toc_toc // 3600} hr {(toc_toc // 60) % 60} min {toc_toc % 60} sec.')

    def _train_step(self, inputs, targets, extra_params):
        """One WGAN-GP training step: critic update, then generator update.

        Returns the reconstruction loss as the representative step loss,
        plus a dict of all loss components (still attached tensors; callers
        are expected to detach).
        """
        assert inputs.size() == targets.size(), 'input and target sizes do not match'
        inputs = inputs.to(self.device)
        targets = targets.to(self.device, non_blocking=True)

        # Train discriminator. `recons` is detached here so that no generator
        # gradients flow through the critic update.
        self.disc_optim.zero_grad()
        recons = self.generator(inputs)
        pred_fake = self.discriminator(recons.detach())  # Generated fake image going through discriminator.
        pred_real = self.discriminator(targets)  # Real image going through discriminator.
        gradient_penalty = compute_gradient_penalty(self.discriminator, targets, recons.detach())
        disc_loss = pred_fake.mean() - pred_real.mean() + self.lambda_gp * gradient_penalty
        disc_loss.backward()
        self.disc_optim.step()

        # Train Generator. Adversarial term plus weighted reconstruction term.
        self.gen_optim.zero_grad()
        pred_fake = self.discriminator(recons)
        gen_loss = -pred_fake.mean()
        recon_loss = self.loss_funcs['recon_loss_func'](recons, targets)
        total_gen_loss = gen_loss + self.recon_lambda * recon_loss
        total_gen_loss.backward()
        self.gen_optim.step()

        # Just using reconstruction loss since it is the most meaningful.
        step_loss = recon_loss
        step_loss_components = {'pred_fake': pred_fake.mean(), 'pred_real': pred_real.mean(),
                                'gradient_penalty': gradient_penalty, 'disc_loss': disc_loss,
                                'gen_loss': gen_loss, 'recon_loss': recon_loss,
                                'total_gen_loss': total_gen_loss}

        return step_loss, step_loss_components

    def _train_epoch(self, epoch):
        """Train both networks for one epoch and return aggregated outputs."""
        self.generator.train()
        self.discriminator.train()
        torch.autograd.set_grad_enabled(True)

        epoch_loss = list()  # Appending values to list due to numerical underflow.
        epoch_loss_components = defaultdict(list)

        # labels are fully sampled coil-wise images, not rss or esc.
        # tqdm total is the number of batches, not the number of slices.
        for step, (inputs, targets, extra_params) in tqdm(enumerate(self.train_loader, start=1),
                                                          total=len(self.train_loader)):
            step_loss, step_loss_components = self._train_step(inputs, targets, extra_params)

            # Perhaps not elegant, but NaN values make this necessary.
            epoch_loss.append(step_loss.detach())
            for key, value in step_loss_components.items():
                epoch_loss_components[key].append(value.detach())

            if self.verbose:
                self._log_step_outputs(epoch, step, step_loss, step_loss_components, training=True)

        # Return as scalar value and dict respectively. Remove the inner lists.
        return self._get_epoch_outputs(epoch, epoch_loss, epoch_loss_components, training=True)

    def _val_step(self, inputs, targets, extra_params):
        """All extra parameters are to be placed in extra_params.

        This makes the system more flexible. Only the reconstruction loss is
        computed during validation; the adversarial terms are deliberately
        skipped (see the commented-out code for what a full evaluation
        would look like).
        """
        inputs = inputs.to(self.device)
        targets = targets.to(self.device, non_blocking=True)
        recons = self.generator(inputs)

        # Discriminator part.
        # pred_fake = self.discriminator(recons)  # Generated fake image going through discriminator.
        # pred_real = self.discriminator(targets)  # Real image going through discriminator.
        # gradient_penalty = compute_gradient_penalty(self.discriminator, targets, recons)
        # disc_loss = pred_fake.mean() - pred_real.mean() + self.lambda_gp * gradient_penalty

        # Generator part.
        # gen_loss = -pred_fake.mean()
        recon_loss = self.loss_funcs['recon_loss_func'](recons, targets)
        # total_gen_loss = gen_loss + self.recon_lambda * recon_loss

        step_loss = recon_loss
        step_loss_components = {'recon_loss': recon_loss}

        return recons, step_loss, step_loss_components

    def _val_epoch(self, epoch):
        """Validate for one epoch, optionally sending image grids to TensorBoard."""
        self.generator.eval()
        self.discriminator.eval()
        torch.autograd.set_grad_enabled(False)

        epoch_loss = list()  # Appending values to list due to numerical underflow.
        epoch_loss_components = defaultdict(list)

        # tqdm total is the number of batches, not the number of slices.
        for step, (inputs, targets, extra_params) in tqdm(enumerate(self.val_loader, start=1),
                                                          total=len(self.val_loader)):
            recons, step_loss, step_loss_components = self._val_step(inputs, targets, extra_params)

            # Append to list to prevent errors from NaN and Inf values.
            # detach() is a no-op under no-grad but kept for consistency with training.
            epoch_loss.append(step_loss.detach())
            for key, value in step_loss_components.items():
                epoch_loss_components[key].append(value.detach())

            if self.verbose:
                self._log_step_outputs(epoch, step, step_loss, step_loss_components, training=False)

            # Save images to TensorBoard.
            # Condition ensures that self.display_interval != 0 and that the step is right for display.
            if self.display_interval and (step % self.display_interval == 0):
                recon_grid, target_grid, delta_grid = make_grid_triplet(recons, targets)
                self.writer.add_image(f'Recons/{step}', recon_grid, global_step=epoch, dataformats='HW')
                self.writer.add_image(f'Targets/{step}', target_grid, global_step=epoch, dataformats='HW')
                self.writer.add_image(f'Deltas/{step}', delta_grid, global_step=epoch, dataformats='HW')

        return self._get_epoch_outputs(epoch, epoch_loss, epoch_loss_components, training=False)

    def _get_epoch_outputs(self, epoch, epoch_loss, epoch_loss_components, training=True):
        """Reduce per-step losses to scalar epoch means, skipping non-finite values.

        Returns (epoch_loss, epoch_loss_components) as a float and a dict of
        floats respectively. Non-finite entries are excluded from the mean and
        reported via a warning.
        """
        mode = 'Training' if training else 'Validation'
        num_slices = len(self.train_loader.dataset) if training else len(self.val_loader.dataset)

        # Checking for nan values.
        epoch_loss = torch.stack(epoch_loss)
        is_finite = torch.isfinite(epoch_loss)
        num_nans = (is_finite.size(0) - is_finite.sum()).item()
        if num_nans > 0:
            self.logger.warning(f'Epoch {epoch} {mode}: {num_nans} NaN values present in {num_slices} slices')
            epoch_loss = torch.mean(epoch_loss[is_finite]).item()
        else:
            epoch_loss = torch.mean(epoch_loss).item()

        for key, value in epoch_loss_components.items():
            epoch_loss_component = torch.stack(value)
            is_finite = torch.isfinite(epoch_loss_component)
            num_nans = (is_finite.size(0) - is_finite.sum()).item()
            if num_nans > 0:
                self.logger.warning(
                    f'Epoch {epoch} {mode} {key}: {num_nans} NaN values present in {num_slices} slices')
                epoch_loss_components[key] = torch.mean(epoch_loss_component[is_finite]).item()
            else:
                epoch_loss_components[key] = torch.mean(epoch_loss_component).item()

        return epoch_loss, epoch_loss_components

    def _log_step_outputs(self, epoch, step, step_loss, step_loss_components, training=True):
        """Log a single step's loss and all its components (verbose mode only)."""
        mode = 'Training' if training else 'Validation'
        self.logger.info(f'Epoch {epoch:03d} Step {step:03d} {mode} loss: {step_loss.item():.4e}')
        for key, value in step_loss_components.items():
            self.logger.info(f'Epoch {epoch:03d} Step {step:03d}: {mode} {key}: {value.item():.4e}')

    def _log_epoch_outputs(self, epoch, epoch_loss, epoch_loss_components, elapsed_secs, training=True):
        """Log epoch-level losses and write them as TensorBoard scalars."""
        mode = 'Training' if training else 'Validation'
        self.logger.info(f'Epoch {epoch:03d} {mode}. loss: {epoch_loss:.4e}, '
                         f'Time: {elapsed_secs // 60} min {elapsed_secs % 60} sec')
        self.writer.add_scalar(f'{mode}_epoch_loss', scalar_value=epoch_loss, global_step=epoch)
        for key, value in epoch_loss_components.items():
            self.logger.info(f'Epoch {epoch:03d} {mode}. {key}: {value}')
            self.writer.add_scalar(f'{mode}_epoch_{key}', scalar_value=value, global_step=epoch)
class ModelTrainerK2CI:
    """
    Model Trainer for k-space learning or complex image learning with losses in
    complex image domains and real valued image domains.

    All learning occurs in k-space or complex image domains while all losses are
    obtained from either complex images or real-valued images.

    Expects multiple losses to be present for the image domain loss. Expects only
    1 component for complex image loss, which will be MSE, though this not set
    explicitly. Also, if the image domain loss has multiple components, the total
    weighted image loss is expected to return a dictionary with 'img_loss' as the
    key for the weighted loss.
    """

    def __init__(self, args, model, optimizer, train_loader, val_loader, input_train_transform,
                 input_val_transform, output_transform, losses, scheduler=None):
        """Validate inputs and store all training state.

        Args:
            args: Namespace of run settings (device, paths, num_epochs,
                img_lambda, smoothing_factor, max_images, batch_size, etc.).
            model: Pytorch Module to be trained.
            optimizer: Pytorch Optimizer for `model`.
            train_loader / val_loader: Pytorch DataLoaders.
            input_train_transform / input_val_transform: callables that convert
                raw loader output into (inputs, targets, extra_params).
            output_transform: Module converting model outputs into the 'recons' dict.
            losses: dict of loss Modules; must contain 'cmg_loss' and 'img_loss'.
            scheduler: optional LR scheduler; metric-based (ReduceLROnPlateau)
                and epoch-based schedulers are both accepted.
        """
        # Allow multiple processes to access tensors on GPU. Add checking for multiple continuous runs.
        if multiprocessing.get_start_method(allow_none=True) is None:
            multiprocessing.set_start_method(method='spawn')

        self.logger = get_logger(name=__name__, save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.'
        assert isinstance(optimizer, optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'
        assert callable(input_train_transform) and callable(input_val_transform), \
            'input_transforms must be callable functions.'
        # I think this would be best practice.
        assert isinstance(output_transform, nn.Module), '`output_transform` must be a Pytorch Module.'

        # 'losses' is expected to be a dictionary.
        # Even composite losses should be a single loss module with a dictionary output.
        losses = nn.ModuleDict(losses)

        # `metric_scheduler` records whether step() needs the validation metric.
        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_scheduler = True
            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_scheduler = False
            else:
                raise TypeError('`scheduler` must be a Pytorch Learning Rate Scheduler.')

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images <= 0:
            self.display_interval = 0
        else:
            self.display_interval = int(len(val_loader.dataset) // (args.max_images * args.batch_size))

        self.checkpointer = CheckpointManager(model, optimizer, mode='min',
                                              save_best_only=args.save_best_only,
                                              ckpt_dir=args.ckpt_path, max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified.
        if vars(args).get('prev_model_ckpt'):
            self.checkpointer.load(load_dir=args.prev_model_ckpt, load_optimizer=False)

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.input_train_transform = input_train_transform
        self.input_val_transform = input_val_transform
        self.output_transform = output_transform
        self.losses = losses
        self.scheduler = scheduler

        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.smoothing_factor = args.smoothing_factor
        self.use_slice_metrics = args.use_slice_metrics
        self.img_lambda = torch.tensor(args.img_lambda, dtype=torch.float32, device=args.device)
        self.writer = SummaryWriter(str(args.log_path))

    def train_model(self):
        """Run the full train/validate loop for `num_epochs` epochs.

        Each epoch trains, validates, logs both phases, checkpoints on the
        validation loss, and steps the LR scheduler if one is present.
        """
        tic_tic = time()
        self.logger.info('Beginning Training Loop.')
        for epoch in range(1, self.num_epochs + 1):  # 1 based indexing of epochs.
            tic = time()  # Training
            train_epoch_loss, train_epoch_metrics = self._train_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch, train_epoch_loss, train_epoch_metrics,
                                    elapsed_secs=toc, training=True)

            tic = time()  # Validation
            val_epoch_loss, val_epoch_metrics = self._val_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch, val_epoch_loss, val_epoch_metrics,
                                    elapsed_secs=toc, training=False)

            self.checkpointer.save(metric=val_epoch_loss, verbose=True)

            if self.scheduler is not None:
                if self.metric_scheduler:  # If the scheduler is a metric based scheduler, include metrics.
                    self.scheduler.step(metrics=val_epoch_loss)
                else:
                    self.scheduler.step()

        self.writer.close()  # Flushes remaining data to TensorBoard.
        toc_toc = int(time() - tic_tic)
        self.logger.info(f'Finishing Training Loop. Total elapsed time: '
                         f'{toc_toc // 3600} hr {(toc_toc // 60) % 60} min {toc_toc % 60} sec.')

    def _train_epoch(self, epoch):
        """Train the model for one epoch and return aggregated loss and metrics."""
        self.model.train()
        torch.autograd.set_grad_enabled(True)

        epoch_loss = list()  # Appending values to list due to numerical underflow.
        epoch_metrics = defaultdict(list)

        data_loader = enumerate(self.train_loader, start=1)
        if not self.verbose:  # tqdm has to be on the outermost iterator to function properly.
            # total is the number of batches, not the number of slices.
            data_loader = tqdm(data_loader, total=len(self.train_loader))

        # 'targets' is a dictionary containing k-space targets, cmg_targets, and img_targets.
        for step, data in data_loader:
            with torch.no_grad():  # Data pre-processing should be done without gradients.
                inputs, targets, extra_params = self.input_train_transform(*data)  # Maybe pass as tuple later?

            # 'recons' is a dictionary containing k-space, complex image, and real image reconstructions.
            recons, step_loss, step_metrics = self._train_step(inputs, targets, extra_params)
            epoch_loss.append(step_loss.detach())  # Perhaps not elegant, but underflow makes this necessary.

            # Gradients are not calculated so as to boost speed and remove weird errors.
            with torch.no_grad():  # Update epoch loss and metrics
                if self.use_slice_metrics:
                    slice_metrics = self._get_slice_metrics(recons['img_recons'], targets['img_targets'])
                    step_metrics.update(slice_metrics)

                for key, value in step_metrics.items():
                    epoch_metrics[key].append(value.detach())

                if self.verbose:
                    self._log_step_outputs(epoch, step, step_loss, step_metrics, training=True)

        # Converted to scalar and dict with scalar forms.
        return self._get_epoch_outputs(epoch, epoch_loss, epoch_metrics, training=True)

    def _train_step(self, inputs, targets, extra_params):
        """One optimization step: forward, combined loss, backward, update.

        Returns (recons, step_loss, step_metrics); tensors are still attached.
        """
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        recons = self.output_transform(outputs, targets, extra_params)

        # Expects a single loss. No loss decomposition within domain implemented yet.
        cmg_loss = self.losses['cmg_loss'](recons['cmg_recons'], targets['cmg_targets'])
        img_loss = self.losses['img_loss'](recons['img_recons'], targets['img_targets'])

        # If img_loss is a dict, it is expected to contain all its component losses as key, value pairs.
        if isinstance(img_loss, dict):
            step_metrics = img_loss
            img_loss = step_metrics['img_loss']
        else:  # img_loss is expected to be a Tensor if not a dict.
            step_metrics = {'img_loss': img_loss}

        step_loss = cmg_loss + self.img_lambda * img_loss
        step_loss.backward()
        self.optimizer.step()

        step_metrics['cmg_loss'] = cmg_loss
        return recons, step_loss, step_metrics

    def _val_epoch(self, epoch):
        """Validate for one epoch, optionally sending image grids to TensorBoard."""
        self.model.eval()
        torch.autograd.set_grad_enabled(False)

        epoch_loss = list()
        epoch_metrics = defaultdict(list)

        # 1 based indexing for steps.
        data_loader = enumerate(self.val_loader, start=1)
        if not self.verbose:
            # total is the number of batches, not the number of slices.
            data_loader = tqdm(data_loader, total=len(self.val_loader))

        # 'targets' is a dictionary containing k-space targets, cmg_targets, and img_targets.
        for step, data in data_loader:
            inputs, targets, extra_params = self.input_val_transform(*data)

            # 'recons' is a dictionary containing k-space, complex image, and real image reconstructions.
            recons, step_loss, step_metrics = self._val_step(inputs, targets, extra_params)
            epoch_loss.append(step_loss.detach())

            if self.use_slice_metrics:
                slice_metrics = self._get_slice_metrics(recons['img_recons'], targets['img_targets'])
                step_metrics.update(slice_metrics)

            for key, value in step_metrics.items():
                epoch_metrics[key].append(value.detach())

            if self.verbose:
                self._log_step_outputs(epoch, step, step_loss, step_metrics, training=False)

            # Save images to TensorBoard.
            # Condition ensures that self.display_interval != 0 and that the step is right for display.
            # This numbering scheme seems to have issues for certain numbers.
            # Please check cases when there is no remainder.
            if self.display_interval and (step % self.display_interval == 0):
                img_recon_grid, img_target_grid, img_delta_grid = \
                    make_grid_triplet(recons['img_recons'], targets['img_targets'])
                kspace_recon_grid = make_k_grid(recons['kspace_recons'], self.smoothing_factor)
                kspace_target_grid = make_k_grid(targets['kspace_targets'], self.smoothing_factor)
                self.writer.add_image(f'k-space_Recons/{step}', kspace_recon_grid, epoch, dataformats='HW')
                self.writer.add_image(f'k-space_Targets/{step}', kspace_target_grid, epoch, dataformats='HW')
                self.writer.add_image(f'Image_Recons/{step}', img_recon_grid, epoch, dataformats='HW')
                self.writer.add_image(f'Image_Targets/{step}', img_target_grid, epoch, dataformats='HW')
                self.writer.add_image(f'Image_Deltas/{step}', img_delta_grid, epoch, dataformats='HW')

        epoch_loss, epoch_metrics = self._get_epoch_outputs(epoch, epoch_loss, epoch_metrics, training=False)
        return epoch_loss, epoch_metrics

    def _val_step(self, inputs, targets, extra_params):
        """Forward pass and loss computation without any parameter updates."""
        outputs = self.model(inputs)
        recons = self.output_transform(outputs, targets, extra_params)

        cmg_loss = self.losses['cmg_loss'](recons['cmg_recons'], targets['cmg_targets'])
        img_loss = self.losses['img_loss'](recons['img_recons'], targets['img_targets'])

        # If img_loss is a dict, it is expected to contain all its component losses as key, value pairs.
        if isinstance(img_loss, dict):
            step_metrics = img_loss
            img_loss = step_metrics['img_loss']
        else:  # img_loss is expected to be a Tensor if not a dict.
            step_metrics = {'img_loss': img_loss}

        step_loss = cmg_loss + self.img_lambda * img_loss
        step_metrics['cmg_loss'] = cmg_loss
        return recons, step_loss, step_metrics

    @staticmethod
    def _get_slice_metrics(img_recons, img_targets):
        """Compute SSIM, PSNR, and NMSE between reconstructed and target images."""
        img_recons = img_recons.detach()  # Just in case.
        img_targets = img_targets.detach()

        max_range = img_targets.max() - img_targets.min()
        slice_ssim = ssim_loss(img_recons, img_targets, max_val=max_range)
        slice_psnr = psnr(img_recons, img_targets, data_range=max_range)
        slice_nmse = nmse(img_recons, img_targets)

        return {'slice_ssim': slice_ssim, 'slice_nmse': slice_nmse, 'slice_psnr': slice_psnr}

    def _get_epoch_outputs(self, epoch, epoch_loss, epoch_metrics, training=True):
        """Reduce per-step losses/metrics to scalar epoch means, skipping non-finite values.

        Returns (epoch_loss, epoch_metrics) as a float and a dict of floats.
        Non-finite entries are excluded from the mean and reported via a warning;
        anomaly detection is enabled to help locate the source of NaN losses.
        """
        mode = 'Training' if training else 'Validation'
        num_slices = len(self.train_loader.dataset) if training else len(self.val_loader.dataset)

        # Checking for nan values.
        epoch_loss = torch.stack(epoch_loss)
        is_finite = torch.isfinite(epoch_loss)
        num_nans = (is_finite.size(0) - is_finite.sum()).item()
        if num_nans > 0:
            self.logger.warning(f'Epoch {epoch} {mode}: {num_nans} NaN values present in {num_slices} slices')
            # Turn on anomaly detection for finding where the nan values are.
            torch.autograd.set_detect_anomaly(True)
            epoch_loss = torch.mean(epoch_loss[is_finite]).item()
        else:
            epoch_loss = torch.mean(epoch_loss).item()

        for key, value in epoch_metrics.items():
            epoch_metric = torch.stack(value)
            is_finite = torch.isfinite(epoch_metric)
            num_nans = (is_finite.size(0) - is_finite.sum()).item()
            if num_nans > 0:
                self.logger.warning(
                    f'Epoch {epoch} {mode} {key}: {num_nans} NaN values present in {num_slices} slices')
                epoch_metrics[key] = torch.mean(epoch_metric[is_finite]).item()
            else:
                epoch_metrics[key] = torch.mean(epoch_metric).item()

        return epoch_loss, epoch_metrics

    def _log_step_outputs(self, epoch, step, step_loss, step_metrics, training=True):
        """Log a single step's loss and all its metrics (verbose mode only)."""
        mode = 'Training' if training else 'Validation'
        self.logger.info(f'Epoch {epoch:03d} Step {step:03d} {mode} loss: {step_loss.item():.4e}')
        for key, value in step_metrics.items():
            self.logger.info(f'Epoch {epoch:03d} Step {step:03d}: {mode} {key}: {value.item():.4e}')

    def _log_epoch_outputs(self, epoch, epoch_loss, epoch_metrics, elapsed_secs, training=True):
        """Log epoch-level losses/metrics and write them as TensorBoard scalars."""
        mode = 'Training' if training else 'Validation'
        self.logger.info(f'Epoch {epoch:03d} {mode}. loss: {epoch_loss:.4e}, '
                         f'Time: {elapsed_secs // 60} min {elapsed_secs % 60} sec')
        self.writer.add_scalar(f'{mode}_epoch_loss', scalar_value=epoch_loss, global_step=epoch)
        for key, value in epoch_metrics.items():
            self.logger.info(f'Epoch {epoch:03d} {mode}. {key}: {value:.4e}')
            self.writer.add_scalar(f'{mode}_epoch_{key}', scalar_value=value, global_step=epoch)
class ModelTrainerIMG: """ Model trainer for real-valued image domain losses. This model trainer can accept k-space an semi-k-space, regardless of weighting. Both complex and real-valued image domain losses can be calculated. """ def __init__(self, args, model, optimizer, train_loader, val_loader, input_train_transform, input_val_transform, output_train_transform, output_val_transform, losses, scheduler=None): # Allow multiple processes to access tensors on GPU. Add checking for multiple continuous runs. if multiprocessing.get_start_method(allow_none=True) is None: multiprocessing.set_start_method(method='spawn') self.logger = get_logger(name=__name__, save_file=args.log_path / args.run_name) # Checking whether inputs are correct. assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.' assert isinstance( optimizer, optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.' assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \ '`train_loader` and `val_loader` must be Pytorch DataLoader objects.' assert callable(input_train_transform) and callable(input_val_transform), \ 'input_transforms must be callable functions.' # I think this would be best practice. assert isinstance(output_train_transform, nn.Module) and isinstance(output_val_transform, nn.Module), \ '`output_train_transform` and `output_val_transform` must be Pytorch Modules.' # 'losses' is expected to be a dictionary. # Even composite losses should be a single loss module with a tuple as its output. losses = nn.ModuleDict(losses) if scheduler is not None: if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau): self.metric_scheduler = True elif isinstance(scheduler, optim.lr_scheduler._LRScheduler): self.metric_scheduler = False else: raise TypeError( '`scheduler` must be a Pytorch Learning Rate Scheduler.') # Display interval of 0 means no display of validation images on TensorBoard. 
if args.max_images <= 0: self.display_interval = 0 else: self.display_interval = int( len(val_loader.dataset) // (args.max_images * args.batch_size)) self.manager = CheckpointManager(model, optimizer, mode='min', save_best_only=args.save_best_only, ckpt_dir=args.ckpt_path, max_to_keep=args.max_to_keep) # loading from checkpoint if specified. if vars(args).get('prev_model_ckpt'): self.manager.load(load_dir=args.prev_model_ckpt, load_optimizer=False) self.model = model self.optimizer = optimizer self.train_loader = train_loader self.val_loader = val_loader self.input_train_transform = input_train_transform self.input_val_transform = input_val_transform self.output_train_transform = output_train_transform self.output_val_transform = output_val_transform self.losses = losses self.scheduler = scheduler self.writer = SummaryWriter(str(args.log_path)) self.verbose = args.verbose self.num_epochs = args.num_epochs self.smoothing_factor = args.smoothing_factor self.shrink_scale = args.shrink_scale self.use_slice_metrics = args.use_slice_metrics # This part should get SSIM, not 1 - SSIM. self.ssim = SSIM(filter_size=7).to( device=args.device) # Needed to cache the kernel. # Logging all components of the Model Trainer. # Train and Val input and output transforms are assumed to use the same input transform class. self.logger.info(f''' Summary of Model Trainer Components: Model: {get_class_name(model)}. Optimizer: {get_class_name(optimizer)}. Input Transforms: {get_class_name(input_val_transform)}. Output Transform: {get_class_name(output_val_transform)}. Image Domain Loss: {get_class_name(losses['img_loss'])}. Learning-Rate Scheduler: {get_class_name(scheduler)}. ''') # This part has parts different for IMG and CMG losses!! def train_model(self): tic_tic = time() self.logger.info('Beginning Training Loop.') for epoch in range(1, self.num_epochs + 1): # 1 based indexing of epochs. 
tic = time() # Training train_epoch_loss, train_epoch_metrics = self._train_epoch( epoch=epoch) toc = int(time() - tic) self._log_epoch_outputs(epoch, train_epoch_loss, train_epoch_metrics, elapsed_secs=toc, training=True) tic = time() # Validation val_epoch_loss, val_epoch_metrics = self._val_epoch(epoch=epoch) toc = int(time() - tic) self._log_epoch_outputs(epoch, val_epoch_loss, val_epoch_metrics, elapsed_secs=toc, training=False) self.manager.save(metric=val_epoch_loss, verbose=True) if self.scheduler is not None: if self.metric_scheduler: # If the scheduler is a metric based scheduler, include metrics. self.scheduler.step(metrics=val_epoch_loss) else: self.scheduler.step() self.writer.close() # Flushes remaining data to TensorBoard. toc_toc = int(time() - tic_tic) self.logger.info( f'Finishing Training Loop. Total elapsed time: ' f'{toc_toc // 3600} hr {(toc_toc // 60) % 60} min {toc_toc % 60} sec.' ) def _train_epoch(self, epoch): self.model.train() torch.autograd.set_grad_enabled(True) epoch_loss = list( ) # Appending values to list due to numerical underflow and NaN values. epoch_metrics = defaultdict(list) data_loader = enumerate(self.train_loader, start=1) if not self.verbose: # tqdm has to be on the outermost iterator to function properly. data_loader = tqdm( data_loader, total=len( self.train_loader.dataset)) # Should divide by batch size. for step, data in data_loader: # Data pre-processing is expected to have gradient calculations removed inside already. inputs, targets, extra_params = self.input_train_transform(*data) # 'recons' is a dictionary containing k-space, complex image, and real image reconstructions. recons, step_loss, step_metrics = self._train_step( inputs, targets, extra_params) epoch_loss.append(step_loss.detach( )) # Perhaps not elegant, but underflow makes this necessary. # Gradients are not calculated so as to boost speed and remove weird errors. 
            # NOTE(review): this is the tail of `_train_epoch` — its `def` line is above this chunk.
            with torch.no_grad():  # Update epoch loss and metrics
                if self.use_slice_metrics:
                    slice_metrics = self._get_slice_metrics(recons, targets, extra_params)
                    step_metrics.update(slice_metrics)
                # Side-effect list comprehension: appends each step metric tensor to the epoch accumulators.
                [epoch_metrics[key].append(value.detach()) for key, value in step_metrics.items()]
                if self.verbose:
                    self._log_step_outputs(epoch, step, step_loss, step_metrics, training=True)
        # Converted to scalar and dict with scalar values respectively.
        return self._get_epoch_outputs(epoch, epoch_loss, epoch_metrics, training=True)

    def _train_step(self, inputs, targets, extra_params):
        """Run one optimizer step: forward pass, output transform, loss, backward, update.

        Returns (recons, step_loss, step_metrics) as produced by the transform and `_step`.
        """
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        recons = self.output_train_transform(outputs, targets, extra_params)
        step_loss, step_metrics = self._step(recons, targets, extra_params)
        step_loss.backward()
        self.optimizer.step()
        return recons, step_loss, step_metrics

    def _val_epoch(self, epoch):
        """Run one full validation epoch; returns (epoch_loss, epoch_metrics) as scalars."""
        self.model.eval()
        torch.autograd.set_grad_enabled(False)
        epoch_loss = list()
        epoch_metrics = defaultdict(list)

        # 1 based indexing for steps.
        data_loader = enumerate(self.val_loader, start=1)
        if not self.verbose:
            # NOTE(review): `total` is the number of slices in the dataset, but the loop
            # iterates over batches — the progress bar is only exact if batch_size == 1; confirm.
            data_loader = tqdm(data_loader, total=len(self.val_loader.dataset))

        for step, data in data_loader:
            inputs, targets, extra_params = self.input_val_transform(*data)
            recons, step_loss, step_metrics = self._val_step(inputs, targets, extra_params)
            epoch_loss.append(step_loss.detach())

            if self.use_slice_metrics:
                slice_metrics = self._get_slice_metrics(recons, targets, extra_params)
                step_metrics.update(slice_metrics)

            # Side-effect list comprehension: appends each step metric tensor to the epoch accumulators.
            [epoch_metrics[key].append(value.detach()) for key, value in step_metrics.items()]

            if self.verbose:
                self._log_step_outputs(epoch, step, step_loss, step_metrics, training=False)

            # Visualize images on TensorBoard.
            self._visualize_images(recons, targets, extra_params, epoch, step, training=False)

        # Converted to scalar and dict with scalar values respectively.
        return self._get_epoch_outputs(epoch, epoch_loss, epoch_metrics, training=False)

    def _val_step(self, inputs, targets, extra_params):
        """Forward pass + output transform + loss/metrics, without any gradient update."""
        outputs = self.model(inputs)
        recons = self.output_val_transform(outputs, targets, extra_params)
        step_loss, step_metrics = self._step(recons, targets, extra_params)
        return recons, step_loss, step_metrics

    def _step(self, recons, targets, extra_params):
        """Compute the image loss and per-acceleration metric entries for one step."""
        step_loss = self.losses['img_loss'](recons['img_recons'], targets['img_targets'])

        # If img_loss is a tuple, it is expected to contain all its component losses as a dict in its second element.
        step_metrics = dict()
        if isinstance(step_loss, tuple):
            step_loss, step_metrics = step_loss

        acc = extra_params["acceleration"]
        if step_metrics:  # This has to be checked before anything is added to step_metrics.
            # NOTE(review): adding keys to `step_metrics` while iterating step_metrics.items()
            # raises RuntimeError ("dictionary changed size during iteration") on Python 3
            # whenever component losses are present — iterate over a copy (list(...)) to fix.
            for key, value in step_metrics.items():
                step_metrics[f'acc_{acc}_{key}'] = value
        step_metrics[f'acc_{acc}_loss'] = step_loss
        return step_loss, step_metrics

    def _visualize_images(self, recons, targets, extra_params, epoch, step, training=False):
        """Write reconstruction/target/delta image grids to TensorBoard at display steps.

        Targets are only written on epoch 1 since they do not change between epochs.
        """
        mode = 'Training' if training else 'Validation'

        # This numbering scheme seems to have issues for certain numbers.
        # Please check cases when there is no remainder.
        if self.display_interval and (step % self.display_interval == 0):
            img_recon_grid = make_img_grid(recons['img_recons'], self.shrink_scale)

            # The delta image is obtained by subtracting at the complex image, not the real valued image.
            delta_image = complex_abs(targets['cmg_targets'] - recons['cmg_recons'])
            delta_img_grid = make_img_grid(delta_image, self.shrink_scale)

            acc = extra_params['acceleration']
            kwargs = dict(global_step=epoch, dataformats='HW')

            self.writer.add_image(f'{mode} Image Recons/{acc}/{step}', img_recon_grid, **kwargs)
            self.writer.add_image(f'{mode} Delta Image/{acc}/{step}', delta_img_grid, **kwargs)

            if 'kspace_recons' in recons:
                kspace_recon_grid = make_k_grid(recons['kspace_recons'], self.smoothing_factor, self.shrink_scale)
                self.writer.add_image(f'{mode} k-space Recons/{acc}/{step}', kspace_recon_grid, **kwargs)

            # Adding RSS images of reconstructions and targets.
            if 'rss_recons' in recons:
                recon_rss = standardize_image(recons['rss_recons'])
                delta_rss = standardize_image(make_rss_slice(delta_image))
                self.writer.add_image(f'{mode} RSS Recons/{acc}/{step}', recon_rss, **kwargs)
                self.writer.add_image(f'{mode} RSS Delta/{acc}/{step}', delta_rss, **kwargs)

            if 'semi_kspace_recons' in recons:
                semi_kspace_recon_grid = make_k_grid(
                    recons['semi_kspace_recons'], self.smoothing_factor, self.shrink_scale)
                self.writer.add_image(
                    f'{mode} semi-k-space Recons/{acc}/{step}', semi_kspace_recon_grid, **kwargs)

            if epoch == 1:  # Maybe add input images too later on.
                img_target_grid = make_img_grid(targets['img_targets'], self.shrink_scale)
                self.writer.add_image(f'{mode} Image Targets/{acc}/{step}', img_target_grid, **kwargs)

                if 'kspace_targets' in targets:
                    kspace_target_grid = \
                        make_k_grid(targets['kspace_targets'], self.smoothing_factor, self.shrink_scale)
                    self.writer.add_image(
                        f'{mode} k-space Targets/{acc}/{step}', kspace_target_grid, **kwargs)

                if 'img_inputs' in targets:
                    # Not actually the input but what the input looks like as an image.
                    img_grid = make_img_grid(targets['img_inputs'], self.shrink_scale)
                    self.writer.add_image(
                        f'{mode} Inputs as Images/{acc}/{step}', img_grid, **kwargs)

                if 'rss_targets' in targets:
                    target_rss = standardize_image(targets['rss_targets'])
                    self.writer.add_image(f'{mode} RSS Targets/{acc}/{step}', target_rss, **kwargs)

                if 'semi_kspace_targets' in targets:
                    semi_kspace_target_grid = make_k_grid(
                        targets['semi_kspace_targets'], self.smoothing_factor, self.shrink_scale)
                    self.writer.add_image(
                        f'{mode} semi-k-space Targets/{acc}/{step}', semi_kspace_target_grid, **kwargs)

    def _get_slice_metrics(self, recons, targets, extra_params):
        """Compute SSIM/PSNR/NMSE on slice images (and RSS images when available).

        Also duplicates each metric under an acceleration-specific key when
        'acceleration' is present in extra_params.
        """
        img_recons = recons['img_recons'].detach()  # Just in case.
        img_targets = targets['img_targets'].detach()
        max_range = img_targets.max() - img_targets.min()

        slice_ssim = self.ssim(img_recons, img_targets)
        slice_psnr = psnr(img_recons, img_targets, data_range=max_range)
        slice_nmse = nmse(img_recons, img_targets)

        slice_metrics = {
            'slice/ssim': slice_ssim,
            'slice/nmse': slice_nmse,
            'slice/psnr': slice_psnr
        }

        if 'rss_recons' in recons:
            rss_recons = recons['rss_recons'].detach()
            rss_targets = targets['rss_targets'].detach()
            max_range = rss_targets.max() - rss_targets.min()

            rss_ssim = self.ssim(rss_recons, rss_targets)
            rss_psnr = psnr(rss_recons, rss_targets, data_range=max_range)
            rss_nmse = nmse(rss_recons, rss_targets)

            slice_metrics['rss/ssim'] = rss_ssim
            slice_metrics['rss/psnr'] = rss_psnr
            slice_metrics['rss/nmse'] = rss_nmse
        else:
            rss_ssim = rss_psnr = rss_nmse = 0

        # Additional metrics for separating between acceleration factors.
        if 'acceleration' in extra_params:
            acc = extra_params["acceleration"]
            slice_metrics[f'slice_acc_{acc}/ssim'] = slice_ssim
            slice_metrics[f'slice_acc_{acc}/psnr'] = slice_psnr
            slice_metrics[f'slice_acc_{acc}/nmse'] = slice_nmse

            if 'rss_recons' in recons:
                slice_metrics[f'rss_acc_{acc}/ssim'] = rss_ssim
                slice_metrics[f'rss_acc_{acc}/psnr'] = rss_psnr
                slice_metrics[f'rss_acc_{acc}/nmse'] = rss_nmse

        return slice_metrics

    def _get_epoch_outputs(self, epoch, epoch_loss, epoch_metrics, training=True):
        """Reduce per-step tensors to epoch-mean scalars, skipping non-finite values.

        Mutates `epoch_metrics` in place (existing keys only, so iteration is safe)
        and returns (epoch_loss, epoch_metrics).
        """
        mode = 'Training' if training else 'Validation'
        num_slices = len(self.train_loader.dataset) if training else len(self.val_loader.dataset)

        # Checking for nan values.
        epoch_loss = torch.stack(epoch_loss)
        is_finite = torch.isfinite(epoch_loss)
        num_nans = (is_finite.size(0) - is_finite.sum()).item()

        if num_nans > 0:
            self.logger.warning(
                f'Epoch {epoch} {mode}: {num_nans} NaN values present in {num_slices} slices.'
                f'Turning on anomaly detection.')
            # Turn on anomaly detection for finding where the nan values are.
            torch.autograd.set_detect_anomaly(True)
            epoch_loss = torch.mean(epoch_loss[is_finite]).item()
        else:
            epoch_loss = torch.mean(epoch_loss).item()

        for key, value in epoch_metrics.items():
            epoch_metric = torch.stack(value)
            is_finite = torch.isfinite(epoch_metric)
            num_nans = (is_finite.size(0) - is_finite.sum()).item()

            if num_nans > 0:
                self.logger.warning(
                    f'Epoch {epoch} {mode} {key}: {num_nans} NaN values present in {num_slices} slices.'
                    f'Turning on anomaly detection.')
                epoch_metrics[key] = torch.mean(epoch_metric[is_finite]).item()
            else:
                epoch_metrics[key] = torch.mean(epoch_metric).item()

        return epoch_loss, epoch_metrics

    def _log_step_outputs(self, epoch, step, step_loss, step_metrics, training=True):
        """Log the per-step loss and every step metric at INFO level."""
        mode = 'Training' if training else 'Validation'
        self.logger.info(
            f'Epoch {epoch:03d} Step {step:03d} {mode} loss: {step_loss.item():.4e}'
        )
        for key, value in step_metrics.items():
            self.logger.info(
                f'Epoch {epoch:03d} Step {step:03d}: {mode} {key}: {value.item():.4e}'
            )

    def _log_epoch_outputs(self, epoch, epoch_loss, epoch_metrics, elapsed_secs, training=True):
        """Log epoch results and write them (plus the learning rate after validation) to TensorBoard."""
        mode = 'Training' if training else 'Validation'
        self.logger.info(
            f'Epoch {epoch:03d} {mode}. loss: {epoch_loss:.4e}, '
            f'Time: {elapsed_secs // 60} min {elapsed_secs % 60} sec')
        self.writer.add_scalar(f'{mode} epoch_loss', scalar_value=epoch_loss, global_step=epoch)

        for key, value in epoch_metrics.items():
            self.logger.info(f'Epoch {epoch:03d} {mode}. {key}: {value:.4e}')
            # Very important whether it is mode_~~ or mode/~~.
            if 'loss' in key:
                self.writer.add_scalar(f'{mode}/epoch_{key}', scalar_value=value, global_step=epoch)
            else:
                self.writer.add_scalar(f'{mode}_epoch_{key}', scalar_value=value, global_step=epoch)

        if not training:
            # Record learning rate.
            for idx, group in enumerate(self.optimizer.param_groups, start=1):
                self.writer.add_scalar(f'learning_rate_{idx}', group['lr'], global_step=epoch)
class ModelTrainerC2C:
    """
    Model trainer for Complex Image Learning.

    Trains `model` on complex (coil-wise) images: the loss `c_loss` is computed
    between post-processed model outputs and complex image targets. Optional
    `metrics` are callables evaluated each step; results go to the logger and
    TensorBoard, and checkpoints are managed by a CheckpointManager keyed on
    validation loss (mode='min').
    """
    def __init__(self, args, model, optimizer, train_loader, val_loader,
                 post_processing, c_loss, metrics=None, scheduler=None):
        # NOTE(review): set_start_method raises RuntimeError if the start method was
        # already set elsewhere (e.g. a second trainer instance) — confirm intended.
        multiprocessing.set_start_method(method='spawn')

        self.logger = get_logger(name=__name__, save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.'
        assert isinstance(
            optimizer,
            optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'
        # I think this would be best practice.
        assert isinstance(
            post_processing,
            nn.Module), '`post_processing` must be a Pytorch Module.'
        # This is not a mistake. Pytorch implements loss functions as modules.
        assert isinstance(
            c_loss, nn.Module), '`c_loss` must be a callable Pytorch Module.'

        if metrics is not None:
            assert isinstance(
                metrics, Iterable
            ), '`metrics` must be an iterable, preferably a list or tuple.'
            for metric in metrics:
                assert callable(
                    metric), 'All metrics must be callable functions.'

        if scheduler is not None:
            # ReduceLROnPlateau needs the metric passed to step(); other schedulers do not.
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_scheduler = True
            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_scheduler = False
            else:
                raise TypeError(
                    '`scheduler` must be a Pytorch Learning Rate Scheduler.')

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.post_processing_func = post_processing
        self.c_loss_func = c_loss
        self.metrics = metrics
        self.scheduler = scheduler

        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        # NOTE(review): `logdir` keyword suggests the tensorboardX SummaryWriter;
        # torch.utils.tensorboard uses `log_dir` — verify against the import.
        self.writer = SummaryWriter(logdir=str(args.log_path))

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images <= 0:
            self.display_interval = 0
        else:
            self.display_interval = int(
                len(self.val_loader.dataset) //
                (args.max_images * args.batch_size))

        self.checkpointer = CheckpointManager(
            model=self.model,
            optimizer=self.optimizer,
            mode='min',
            save_best_only=args.save_best_only,
            ckpt_dir=args.ckpt_path,
            max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified.
        if vars(args).get('prev_model_ckpt'):
            self.checkpointer.load(load_dir=args.prev_model_ckpt,
                                   load_optimizer=False)

    def train_model(self):
        """Run the full train/validate loop for `num_epochs` epochs, checkpointing on validation loss."""
        tic_tic = time()
        self.logger.info('Beginning Training Loop.')
        for epoch in range(1, self.num_epochs + 1):  # 1 based indexing
            # Training
            tic = time()
            train_epoch_loss, train_epoch_metrics = self._train_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch=epoch,
                                    epoch_loss=train_epoch_loss,
                                    epoch_metrics=train_epoch_metrics,
                                    elapsed_secs=toc,
                                    training=True)

            # Validation
            tic = time()
            val_epoch_loss, val_epoch_metrics = self._val_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch=epoch,
                                    epoch_loss=val_epoch_loss,
                                    epoch_metrics=val_epoch_metrics,
                                    elapsed_secs=toc,
                                    training=False)

            self.checkpointer.save(metric=val_epoch_loss, verbose=True)

            if self.scheduler is not None:
                if self.metric_scheduler:  # If the scheduler is a metric based scheduler, include metrics.
                    self.scheduler.step(metrics=val_epoch_loss)
                else:
                    self.scheduler.step()

        # Finishing Training Loop
        self.writer.close()  # Flushes remaining data to TensorBoard.
        toc_toc = int(time() - tic_tic)
        self.logger.info(
            f'Finishing Training Loop. Total elapsed time: '
            f'{toc_toc // 3600} hr {(toc_toc // 60) % 60} min {toc_toc % 60} sec.'
        )

    def _train_epoch(self, epoch):
        """Train for one epoch; returns (epoch_loss, epoch_metrics) as scalars."""
        self.model.train()
        torch.autograd.set_grad_enabled(True)

        epoch_loss_list = list()  # Appending values to list due to numerical underflow.
        epoch_metrics_list = [list() for _ in self.metrics] if self.metrics else None

        # labels are fully sampled coil-wise images, not rss or esc.
        for step, (inputs, c_img_targets, extra_params) in enumerate(self.train_loader, start=1):
            step_loss, c_img_recons = self._train_step(inputs, c_img_targets, extra_params)

            # Gradients are not calculated so as to boost speed and remove weird errors.
            with torch.no_grad():  # Update epoch loss and metrics
                epoch_loss_list.append(step_loss.detach())  # Perhaps not elegant, but underflow makes this necessary.
                step_metrics = self._get_step_metrics(c_img_recons, c_img_targets, epoch_metrics_list)
                self._log_step_outputs(epoch, step, step_loss, step_metrics, training=True)

        epoch_loss, epoch_metrics = self._get_epoch_outputs(epoch,
                                                            epoch_loss_list,
                                                            epoch_metrics_list,
                                                            training=True)
        return epoch_loss, epoch_metrics

    def _train_step(self, inputs, c_img_targets, extra_params):
        """One optimizer step: forward, post-process, complex-image loss, backward, update."""
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        c_img_recons = self.post_processing_func(outputs, c_img_targets, extra_params)
        step_loss = self.c_loss_func(c_img_recons, c_img_targets)
        step_loss.backward()
        self.optimizer.step()
        return step_loss, c_img_recons

    def _val_epoch(self, epoch):
        """Validate for one epoch; writes image panels to TensorBoard at display steps."""
        self.model.eval()
        torch.autograd.set_grad_enabled(False)

        epoch_loss_list = list()
        epoch_metrics_list = [list() for _ in self.metrics] if self.metrics else None

        for step, (inputs, c_img_targets, extra_params) in enumerate(self.val_loader, start=1):
            step_loss, c_img_recons = self._val_step(inputs, c_img_targets, extra_params)
            epoch_loss_list.append(step_loss.detach())

            # Step functions have internalized conditional statements deciding whether to execute or not.
            step_metrics = self._get_step_metrics(c_img_recons, c_img_targets, epoch_metrics_list)
            self._log_step_outputs(epoch, step, step_loss, step_metrics, training=False)

            # Save images to TensorBoard.
            # Condition ensures that self.display_interval != 0 and that the step is right for display.
            if self.display_interval and (step % self.display_interval == 0):
                kspace_recons, kspace_targets, image_recons, image_targets, image_deltas \
                    = self._visualize_outputs(c_img_recons, c_img_targets, smoothing_factor=8)

                self.writer.add_image(f'k-space_Recons/{step}', kspace_recons, epoch, dataformats='HW')
                self.writer.add_image(f'k-space_Targets/{step}', kspace_targets, epoch, dataformats='HW')
                self.writer.add_image(f'Image_Recons/{step}', image_recons, epoch, dataformats='HW')
                self.writer.add_image(f'Image_Targets/{step}', image_targets, epoch, dataformats='HW')
                self.writer.add_image(f'Image_Deltas/{step}', image_deltas, epoch, dataformats='HW')

        epoch_loss, epoch_metrics = self._get_epoch_outputs(epoch,
                                                            epoch_loss_list,
                                                            epoch_metrics_list,
                                                            training=False)
        return epoch_loss, epoch_metrics

    def _val_step(self, inputs, c_img_targets, extra_params):
        """
        All extra parameters are to be placed in extra_params.
        This makes the system more flexible.
        """
        outputs = self.model(inputs)
        c_img_recons = self.post_processing_func(outputs, c_img_targets, extra_params)
        step_loss = self.c_loss_func(c_img_recons, c_img_targets)
        return step_loss, c_img_recons

    def _get_step_metrics(self, c_img_recons, c_img_targets, epoch_metrics_list):
        """Evaluate all metrics for one step and append them to the epoch accumulators."""
        if self.metrics is not None:
            step_metrics = [
                metric(c_img_recons, c_img_targets) for metric in self.metrics
            ]
            for step_metric, epoch_metric_list in zip(step_metrics, epoch_metrics_list):
                epoch_metric_list.append(step_metric.detach())
            return step_metrics
        return None  # Explicitly return None for step_metrics if self.metrics is None. Not necessary but more readable.

    def _get_epoch_outputs(self, epoch, epoch_loss_list, epoch_metrics_list, training=True):
        """Reduce per-step tensors to epoch-mean scalars, excluding non-finite values."""
        mode = 'Training' if training else 'Validation'
        num_slices = len(self.train_loader.dataset) if training else len(self.val_loader.dataset)

        # Checking for nan or inf values.
        epoch_loss_tensor = torch.stack(epoch_loss_list)
        finite_values = torch.isfinite(epoch_loss_tensor)
        num_nans = len(epoch_loss_list) - int(finite_values.sum().item())

        if num_nans > 0:
            self.logger.warning(
                f'Epoch {epoch} {mode}: {num_nans} NaN values present in {num_slices} slices'
            )
            epoch_loss = torch.mean(epoch_loss_tensor[finite_values]).item()
        else:
            epoch_loss = torch.mean(epoch_loss_tensor).item()

        if self.metrics:
            epoch_metrics = list()
            for idx, epoch_metric_list in enumerate(epoch_metrics_list, start=1):
                epoch_metric_tensor = torch.stack(epoch_metric_list)
                finite_values = torch.isfinite(epoch_metric_tensor)
                num_nans = len(epoch_metric_list) - int(finite_values.sum().item())

                if num_nans > 0:
                    self.logger.warning(
                        f'Epoch {epoch} {mode}: Metric {idx} has {num_nans} NaN values in {num_slices} slices'
                    )
                    epoch_metric = torch.mean(epoch_metric_tensor[finite_values]).item()
                else:
                    epoch_metric = torch.mean(epoch_metric_tensor).item()

                epoch_metrics.append(epoch_metric)
        else:
            epoch_metrics = None

        return epoch_loss, epoch_metrics

    def _log_step_outputs(self, epoch, step, step_loss, step_metrics, training=True):
        """Log the per-step loss and metrics at INFO level (only when verbose)."""
        if self.verbose:
            mode = 'Training' if training else 'Validation'
            self.logger.info(
                f'Epoch {epoch:03d} Step {step:03d} {mode} loss: {step_loss.item():.4e}'
            )
            if self.metrics:
                for idx, step_metric in enumerate(step_metrics):
                    self.logger.info(
                        f'Epoch {epoch:03d} Step {step:03d}: {mode} metric {idx}: {step_metric.item():.4e}'
                    )

    def _log_epoch_outputs(self, epoch, epoch_loss, epoch_metrics, elapsed_secs, training=True):
        """Log epoch loss/metrics and write them to TensorBoard."""
        mode = 'Training' if training else 'Validation'
        self.logger.info(
            f'Epoch {epoch:03d} {mode}. loss: {epoch_loss:.4e}, Time: {elapsed_secs // 60} min {elapsed_secs % 60} sec'
        )
        self.writer.add_scalar(f'{mode}_epoch_loss', scalar_value=epoch_loss, global_step=epoch)

        if isinstance(
                epoch_metrics, list
        ):  # The metrics being returned are either 'None' or a list of values.
            for idx, epoch_metric in enumerate(epoch_metrics, start=1):
                self.logger.info(
                    f'Epoch {epoch:03d} {mode}. Metric {idx}: {epoch_metric}')
                self.writer.add_scalar(f'{mode}_epoch_metric_{idx}',
                                       scalar_value=epoch_metric,
                                       global_step=epoch)

    @staticmethod
    def _visualize_outputs(c_img_recons, c_img_targets, smoothing_factor=8):
        """Build display grids (k-space recon/target, image recon/target/delta) from complex images."""
        image_recons = complex_abs(c_img_recons)
        image_targets = complex_abs(c_img_targets)
        kspace_recons = make_k_grid(fft2(c_img_recons), smoothing_factor)
        kspace_targets = make_k_grid(fft2(c_img_targets), smoothing_factor)
        image_recons, image_targets, image_deltas = make_grid_triplet(image_recons, image_targets)
        return kspace_recons, kspace_targets, image_recons, image_targets, image_deltas
class ModelTrainerK2I:
    """
    Model Trainer for K-space learning.

    Please note a bit of terminology.
    In this file, 'recons' indicates coil-wise reconstructions,
    not final reconstructions for submissions.
    Also, 'targets' indicates coil-wise targets, not the 320x320 ground-truth labels.
    k-slice means a slice of k-space, i.e. only 1 slice of k-space.
    """
    def __init__(self, args, model, optimizer, train_loader, val_loader,
                 post_processing, loss_func, metrics=None, scheduler=None):
        self.logger = get_logger(name=__name__, save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.'
        assert isinstance(
            optimizer,
            optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'
        # I think this would be best practice.
        assert isinstance(
            post_processing,
            nn.Module), '`post_processing_func` must be a Pytorch Module.'
        # This is not a mistake. Pytorch implements loss functions as modules.
        assert isinstance(
            loss_func,
            nn.Module), '`loss_func` must be a callable Pytorch Module.'

        if metrics is not None:
            assert isinstance(
                metrics, (list, tuple)), '`metrics` must be a list or tuple.'
            for metric in metrics:
                assert callable(
                    metric), 'All metrics must be callable functions.'

        if scheduler is not None:
            # ReduceLROnPlateau needs the metric passed to step(); other schedulers do not.
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_scheduler = True
            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_scheduler = False
            else:
                raise TypeError(
                    '`scheduler` must be a Pytorch Learning Rate Scheduler.')

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.post_processing_func = post_processing
        self.loss_func = loss_func  # I don't think it is necessary to send loss_func or metrics to device.
        self.metrics = metrics
        self.scheduler = scheduler

        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        # NOTE(review): `logdir` keyword suggests the tensorboardX SummaryWriter;
        # torch.utils.tensorboard uses `log_dir` — verify against the import.
        self.writer = SummaryWriter(logdir=str(args.log_path))

        # Display interval of 0 means no display of validation images on TensorBoard.
        # NOTE(review): unlike the sibling trainers, this formula does not divide by
        # batch_size — confirm whether that asymmetry is intentional.
        self.display_interval = int(
            len(self.val_loader.dataset) //
            args.max_images) if (args.max_images > 0) else 0

        # # Writing model graph to TensorBoard. Results might not be very good.
        if args.add_graph:
            # NOTE(review): example-input shape (1, num_chans, 640, 328) is hard-coded;
            # presumably matches the fastMRI k-space slice geometry — confirm.
            num_chans = 30 if args.challenge == 'multicoil' else 2
            example_inputs = torch.ones(size=(1, num_chans, 640, 328), device=args.device)
            self.writer.add_graph(model=model, input_to_model=example_inputs)
            del example_inputs  # Remove unnecessary tensor taking up memory.

        self.checkpointer = CheckpointManager(
            model=self.model,
            optimizer=self.optimizer,
            mode='min',
            save_best_only=args.save_best_only,
            ckpt_dir=args.ckpt_path,
            max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified.
        if vars(args).get('prev_model_ckpt'):
            self.checkpointer.load(load_dir=args.prev_model_ckpt,
                                   load_optimizer=False)

    def train_model(self):
        """Run the full train/validate loop for `num_epochs` epochs, checkpointing on validation loss."""
        # NOTE(review): set_start_method raises RuntimeError if the start method was
        # already set (e.g. by ModelTrainerC2C.__init__) — confirm intended.
        multiprocessing.set_start_method(method='spawn')
        self.logger.info('Beginning Training Loop.')
        tic_tic = time()
        for epoch in range(1, self.num_epochs + 1):  # 1 based indexing
            # Training
            tic = time()
            train_epoch_loss, train_epoch_metrics = self._train_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch=epoch,
                                    epoch_loss=train_epoch_loss,
                                    epoch_metrics=train_epoch_metrics,
                                    elapsed_secs=toc,
                                    training=True)

            # Validation
            tic = time()
            val_epoch_loss, val_epoch_metrics = self._val_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch=epoch,
                                    epoch_loss=val_epoch_loss,
                                    epoch_metrics=val_epoch_metrics,
                                    elapsed_secs=toc,
                                    training=False)

            self.checkpointer.save(metric=val_epoch_loss, verbose=True)

            if self.scheduler is not None:
                if self.metric_scheduler:  # If the scheduler is a metric based scheduler, include metrics.
                    self.scheduler.step(metrics=val_epoch_loss)
                else:
                    self.scheduler.step()

        # Finishing Training Loop
        self.writer.close()  # Flushes remaining data to TensorBoard.
        toc_toc = int(time() - tic_tic)
        self.logger.info(
            f'Finishing Training Loop. Total elapsed time: '
            f'{toc_toc // 3600} hr {(toc_toc // 60) % 60} min {toc_toc % 60} sec.'
        )

    def _train_step(self, inputs, targets, extra_params):
        """One optimizer step: forward, post-process into (image, k-space) recons, loss, backward, update."""
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        image_recons, kspace_recons = self.post_processing_func(outputs, targets, extra_params)
        step_loss = self.loss_func(image_recons, targets)
        step_loss.backward()
        self.optimizer.step()
        return step_loss, image_recons, kspace_recons

    def _train_epoch(self, epoch):
        """Train for one epoch; returns (epoch_loss, epoch_metrics) as Python floats."""
        self.model.train()
        torch.autograd.set_grad_enabled(True)

        epoch_loss_lst = list()  # Appending values to list due to numerical underflow.
        epoch_metrics_lst = [list() for _ in self.metrics] if self.metrics else None

        # labels are fully sampled coil-wise images, not rss or esc.
        for step, (inputs, targets, extra_params) in enumerate(self.train_loader, start=1):
            step_loss, image_recons, kspace_recons = self._train_step(inputs, targets, extra_params)

            # Gradients are not calculated so as to boost speed and remove weird errors.
            with torch.no_grad():  # Update epoch loss and metrics
                epoch_loss_lst.append(step_loss.item())  # Perhaps not elegant, but underflow makes this necessary.
                # The step functions here have all necessary conditionals internally.
                # There is no need to externally specify whether to use them or not.
                step_metrics = self._get_step_metrics(image_recons, targets, epoch_metrics_lst)
                self._log_step_outputs(epoch, step, step_loss, step_metrics, training=True)

        epoch_loss, epoch_metrics = self._get_epoch_outputs(epoch,
                                                            epoch_loss_lst,
                                                            epoch_metrics_lst,
                                                            training=True)
        return epoch_loss, epoch_metrics

    def _val_step(self, inputs, targets, extra_params):
        """
        All extra parameters are to be placed in extra_params.
        This makes the system more flexible.
        """
        outputs = self.model(inputs)
        image_recons, kspace_recons = self.post_processing_func(outputs, targets, extra_params)
        step_loss = self.loss_func(image_recons, targets)
        return step_loss, image_recons, kspace_recons

    def _val_epoch(self, epoch):
        """Validate for one epoch; writes image panels to TensorBoard at display steps."""
        self.model.eval()
        torch.autograd.set_grad_enabled(False)

        epoch_loss_lst = list()
        epoch_metrics_lst = [list() for _ in self.metrics] if self.metrics else None

        for step, (inputs, targets, extra_params) in enumerate(self.val_loader, start=1):
            step_loss, image_recons, kspace_recons = self._val_step(inputs, targets, extra_params)
            epoch_loss_lst.append(step_loss.item())

            # Step functions have internalized conditional statements deciding whether to execute or not.
            step_metrics = self._get_step_metrics(image_recons, targets, epoch_metrics_lst)
            self._log_step_outputs(epoch, step, step_loss, step_metrics, training=False)

            # Save images to TensorBoard. Send this to a separate function later on.
            # Condition ensures that self.display_interval != 0 and that the step is right for display.
            if self.display_interval and (step % self.display_interval == 0):
                recons_grid, targets_grid, deltas_grid = make_grid_triplet(image_recons, targets)
                kspace_grid = make_k_grid(kspace_recons)

                self.writer.add_image(f'k-space_Recons/{step}', kspace_grid, epoch, dataformats='HW')
                self.writer.add_image(f'Image_Recons/{step}', recons_grid, epoch, dataformats='HW')
                self.writer.add_image(f'Targets/{step}', targets_grid, epoch, dataformats='HW')
                self.writer.add_image(f'Deltas/{step}', deltas_grid, epoch, dataformats='HW')

        epoch_loss, epoch_metrics = self._get_epoch_outputs(epoch,
                                                            epoch_loss_lst,
                                                            epoch_metrics_lst,
                                                            training=False)
        return epoch_loss, epoch_metrics

    def _get_step_metrics(self, image_recons, targets, epoch_metrics_lst):
        """Evaluate all metrics for one step and append their scalar values to the epoch accumulators."""
        if self.metrics is not None:
            step_metrics = [
                metric(image_recons, targets) for metric in self.metrics
            ]
            for step_metric, epoch_metric_lst in zip(step_metrics, epoch_metrics_lst):
                epoch_metric_lst.append(step_metric.item())
            return step_metrics
        return None  # Explicitly return None for step_metrics if self.metrics is None. Not necessary but more readable.

    def _get_epoch_outputs(self, epoch, epoch_loss_lst, epoch_metrics_lst, training=True):
        """Reduce per-step scalars to epoch means via np.nanmean, warning on NaN entries."""
        mode = 'training' if training else 'validation'
        num_slices = len(self.train_loader.dataset) if training else len(self.val_loader.dataset)

        # Checking for nan values.
        num_nans = np.isnan(epoch_loss_lst).sum()
        if num_nans > 0:
            self.logger.warning(
                f'Epoch {epoch} {mode}: {num_nans} NaN values present in {num_slices} slices'
            )

        epoch_loss = float(np.nanmean(epoch_loss_lst))  # Remove nan values just in case.
        epoch_metrics = [
            float(np.nanmean(epoch_metric_lst))
            for epoch_metric_lst in epoch_metrics_lst
        ] if self.metrics else None

        return epoch_loss, epoch_metrics

    def _log_step_outputs(self, epoch, step, step_loss, step_metrics, training=True):
        """Log the per-step loss and metrics at INFO level (only when verbose)."""
        if self.verbose:
            mode = 'Training' if training else 'Validation'
            self.logger.info(
                f'Epoch {epoch:03d} Step {step:03d} {mode} loss: {step_loss.item():.4e}'
            )
            if self.metrics:
                for idx, step_metric in enumerate(step_metrics):
                    self.logger.info(
                        f'Epoch {epoch:03d} Step {step:03d}: {mode} metric {idx}: {step_metric.item():.4e}'
                    )

    def _log_epoch_outputs(self, epoch, epoch_loss, epoch_metrics, elapsed_secs, training=True):
        """Log epoch loss/metrics and write them to TensorBoard."""
        mode = 'Training' if training else 'Validation'
        self.logger.info(
            f'Epoch {epoch:03d} {mode}. loss: {epoch_loss:.4e}, Time: {elapsed_secs // 60} min {elapsed_secs % 60} sec'
        )
        self.writer.add_scalar(f'{mode}_epoch_loss', scalar_value=epoch_loss, global_step=epoch)

        if isinstance(
                epoch_metrics, list
        ):  # The metrics being returned are either 'None' or a list of values.
            for idx, epoch_metric in enumerate(epoch_metrics, start=1):
                self.logger.info(
                    f'Epoch {epoch:03d} {mode}. Metric {idx}: {epoch_metric}')
                self.writer.add_scalar(f'{mode}_epoch_metric_{idx}',
                                       scalar_value=epoch_metric,
                                       global_step=epoch)
def train_model(model, args):
    """Train a classifier on CIFAR-100 and log/checkpoint the results.

    Sets up run directories, data loaders, optimizer, StepLR scheduler, and a
    TensorBoard writer, then runs `args.num_epochs` epochs of train/eval,
    recording losses, top-1/top-5 accuracies, and learning rates. Checkpoints
    are saved after each epoch keyed on top-5 validation accuracy (mode='max').

    Args:
        model: an nn.Module classifier producing CIFAR-100 logits.
        args: namespace with ckpt_root, log_root, data_root, gpu, batch_size,
            num_workers, init_lr, step_size, gamma, num_epochs, verbose,
            save_best_only, max_to_keep.

    Returns:
        None. Side effects only (checkpoints, logs, TensorBoard events).
    """
    assert isinstance(model, nn.Module)

    # Beginning session.
    run_number, run_name = initialize(args.ckpt_root)

    ckpt_path = Path(args.ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(args.log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    if (args.gpu is not None) and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
    else:
        device = torch.device('cpu')

    # Saving args for later use.
    save_dict_as_json(vars(args), log_dir=log_path, save_name=run_name)

    dataset_kwargs = dict(root=args.data_root, download=True)
    train_dataset = torchvision.datasets.CIFAR100(train=True,
                                                  transform=train_transform(),
                                                  **dataset_kwargs)
    val_dataset = torchvision.datasets.CIFAR100(train=False,
                                                transform=val_transform(),
                                                **dataset_kwargs)

    loader_kwargs = dict(batch_size=args.batch_size,
                         num_workers=args.num_workers,
                         pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Hoisted loop invariants: dataset sizes do not change between epochs.
    num_train_slices = len(train_loader.dataset)
    num_val_slices = len(val_loader.dataset)

    # Define model, optimizer, etc.
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.init_lr)

    # No softmax layer at the end necessary. Just need logits.
    loss_func = nn.CrossEntropyLoss(reduction='mean').to(device)

    # LR scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.step_size,
                                          gamma=args.gamma)

    # Create checkpoint manager
    checkpointer = CheckpointManager(model,
                                     optimizer,
                                     mode='max',
                                     save_best_only=args.save_best_only,
                                     ckpt_dir=ckpt_path,
                                     max_to_keep=args.max_to_keep)

    # Tensorboard Writer
    writer = SummaryWriter(log_dir=str(log_path))

    # Training loop. Please excuse my use of 1 based indexing here.
    logger.info('Beginning Training loop')
    try:
        for epoch in range(1, args.num_epochs + 1):
            # Start of training
            tic = time()
            train_loss_sum, train_top1_correct, train_top5_correct = \
                train_epoch(model, optimizer, loss_func, train_loader, device, epoch, args.verbose)
            toc = int(time() - tic)

            # Last step with small batch causes some inaccuracy but that is tolerable.
            train_epoch_loss = train_loss_sum.item() * args.batch_size / num_train_slices
            train_epoch_top1_acc = train_top1_correct.item() / num_train_slices * 100
            train_epoch_top5_acc = train_top5_correct.item() / num_train_slices * 100

            logger.info(
                f'Epoch {epoch:03d} Training. loss: {train_epoch_loss:.4e}, top1 accuracy: '
                f'{train_epoch_top1_acc:.2f}%, top5 accuracy: {train_epoch_top5_acc:.2f}% Time: {toc}s'
            )

            # Writing to Tensorboard
            writer.add_scalar('train_epoch_loss', train_epoch_loss, epoch)
            writer.add_scalar('train_epoch_top1_acc', train_epoch_top1_acc, epoch)
            writer.add_scalar('train_epoch_top5_acc', train_epoch_top5_acc, epoch)

            # Start of evaluation
            tic = time()
            val_loss_sum, val_top1_correct, val_top5_correct = \
                eval_epoch(model, loss_func, val_loader, device, epoch, args.verbose)
            toc = int(time() - tic)

            val_epoch_loss = val_loss_sum.item() * args.batch_size / num_val_slices
            val_epoch_top1_acc = val_top1_correct.item() / num_val_slices * 100
            val_epoch_top5_acc = val_top5_correct.item() / num_val_slices * 100

            logger.info(
                f'Epoch {epoch:03d} Validation. loss: {val_epoch_loss:.4e}, top1 accuracy: '
                f'{val_epoch_top1_acc:.2f}%, top5 accuracy: {val_epoch_top5_acc:.2f}% Time: {toc}s'
            )

            # Writing to Tensorboard
            writer.add_scalar('val_epoch_loss', val_epoch_loss, epoch)
            writer.add_scalar('val_epoch_top1_acc', val_epoch_top1_acc, epoch)
            writer.add_scalar('val_epoch_top5_acc', val_epoch_top5_acc, epoch)

            for idx, group in enumerate(optimizer.param_groups, start=1):
                writer.add_scalar(f'learning_rate_{idx}', group['lr'], epoch)

            # Things to do after each epoch.
            scheduler.step()  # Reduces LR at the designated times. Probably does not use 1 indexing like me.
            checkpointer.save(metric=val_epoch_top5_acc)
    finally:
        # Fix: the writer was previously never closed, so the final TensorBoard
        # events could be lost; the trainer classes in this file close theirs.
        writer.close()
def main():
    """Script entry point: trains SE-ResNet50 on CIFAR-100 with hard-coded settings.

    Builds run directories, data loaders, model, optimizer, and checkpoint
    manager, then runs the train/eval loop, checkpointing on top-5 validation
    accuracy. All configuration is local to this function for now.
    """
    # Put these in args later.
    batch_size = 12
    num_workers = 8
    init_lr = 2E-4
    gpu = 0  # Set to None for CPU mode.
    num_epochs = 500
    verbose = False
    save_best_only = True
    max_to_keep = 100

    data_root = '/home/veritas/PycharmProjects/PA1/data'
    ckpt_root = '/home/veritas/PycharmProjects/PA1/checkpoints'
    log_root = '/home/veritas/PycharmProjects/PA1/logs'

    # Beginning session: allocate a fresh run name and its directories.
    run_number, run_name = initialize(ckpt_root)

    ckpt_path = Path(ckpt_root)
    ckpt_path.mkdir(exist_ok=True)
    ckpt_path = ckpt_path / run_name
    ckpt_path.mkdir(exist_ok=True)

    log_path = Path(log_root)
    log_path.mkdir(exist_ok=True)
    log_path = log_path / run_name
    log_path.mkdir(exist_ok=True)

    logger = get_logger(name=__name__, save_file=log_path / run_name)

    use_cuda = (gpu is not None) and torch.cuda.is_available()
    device = torch.device(f'cuda:{gpu}') if use_cuda else torch.device('cpu')

    # Do more fancy transforms later.
    transform = torchvision.transforms.ToTensor()
    dataset_kwargs = dict(transform=transform, download=True)
    train_dataset = torchvision.datasets.CIFAR100(data_root, train=True, **dataset_kwargs)
    val_dataset = torchvision.datasets.CIFAR100(data_root, train=False, **dataset_kwargs)

    loader_kwargs = dict(num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, batch_size, shuffle=False, **loader_kwargs)

    # Define model, optimizer, etc.
    model = se_resnet50_cifar100().to(device, non_blocking=True)
    optimizer = optim.Adam(model.parameters(), lr=init_lr)

    # No softmax layer at the end necessary. Just need logits.
    loss_func = nn.CrossEntropyLoss().to(device, non_blocking=True)

    # Create checkpoint manager
    checkpointer = CheckpointManager(model, optimizer, ckpt_path, save_best_only, max_to_keep)

    # For recording data. Accuracy should improve, so start the best at zero.
    previous_best = 0.

    # Training loop. Please excuse my use of 1 based indexing here.
    logger.info('Beginning Training loop')
    for epoch in range(1, num_epochs + 1):
        # --- Training phase ---
        phase_start = time()
        train_loss_sum, train_top1_correct, train_top5_correct = \
            train_epoch(model, optimizer, loss_func, train_loader, device, epoch, verbose)
        toc = int(time() - phase_start)

        # Last step with small batch causes some inaccuracy but that is tolerable.
        num_train = len(train_loader.dataset)
        train_epoch_loss = train_loss_sum.item() * batch_size / num_train
        train_epoch_top1_acc = train_top1_correct.item() / num_train * 100
        train_epoch_top5_acc = train_top5_correct.item() / num_train * 100

        msg = f'Epoch {epoch:03d} Training. loss: {train_epoch_loss:.4f}, ' \
              f'top1 accuracy: {train_epoch_top1_acc:.2f}%, top5 accuracy: {train_epoch_top5_acc:.2f}% Time: {toc}s'
        logger.info(msg)

        # --- Evaluation phase ---
        phase_start = time()
        val_loss_sum, val_top1_correct, val_top5_correct = \
            eval_epoch(model, loss_func, val_loader, device, epoch, verbose)
        toc = int(time() - phase_start)

        num_val = len(val_loader.dataset)
        val_epoch_loss = val_loss_sum.item() * batch_size / num_val
        val_epoch_top1_acc = val_top1_correct.item() / num_val * 100
        val_epoch_top5_acc = val_top5_correct.item() / num_val * 100

        msg = f'Epoch {epoch:03d} Validation. loss: {val_epoch_loss:.4f}, ' \
              f'top1 accuracy: {val_epoch_top1_acc:.2f}%, top5 accuracy: {val_epoch_top5_acc:.2f}% Time: {toc}s'
        logger.info(msg)

        # Checkpoint on top-5 validation accuracy (larger is better).
        if val_epoch_top5_acc > previous_best:  # Assumes larger metric is better.
            logger.info(
                f'Top 5 Validation Accuracy in Epoch {epoch} has improved from '
                f'{previous_best:.2f}% to {val_epoch_top5_acc:.2f}%')
            previous_best = val_epoch_top5_acc
            checkpointer.save(is_best=True)
        else:
            logger.info(
                f'Top 5 Validation Accuracy in Epoch {epoch} has not improved from the previous best epoch'
            )
            checkpointer.save(is_best=False)