Example #1
class ModelTrainerK2CI:
    """
    Model trainer for k-space or complex-image learning,
    with losses computed in the complex-image and real-valued image domains.
    All learning occurs in the k-space or complex-image domains,
    while all losses are obtained from complex or real-valued images.
    The image-domain loss is expected to have multiple components,
    while the complex-image loss is expected to have a single component
    (in practice MSE, though this is not set explicitly).
    If the image-domain loss has multiple components, the loss module is expected to return
    a dictionary with 'img_loss' as the key for the total weighted loss.
    """

    def __init__(self, args, model, optimizer, train_loader, val_loader,
                 input_train_transform, input_val_transform, output_transform, losses, scheduler=None):

        # Allow multiple processes to access tensors on the GPU.
        # The guard prevents a RuntimeError when the start method has already been set (e.g., on consecutive runs).
        if multiprocessing.get_start_method(allow_none=True) is None:
            multiprocessing.set_start_method(method='spawn')

        self.logger = get_logger(name=__name__, save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.'
        assert isinstance(optimizer, optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'

        assert callable(input_train_transform) and callable(input_val_transform), \
            'input_transforms must be callable functions.'
        # I think this would be best practice.
        assert isinstance(output_transform, nn.Module), '`output_transform` must be a Pytorch Module.'

        # 'losses' is expected to be a dictionary.
        # Even composite losses should be a single loss module with a dictionary output.
        losses = nn.ModuleDict(losses)

        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_scheduler = True
            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_scheduler = False
            else:
                raise TypeError('`scheduler` must be a Pytorch Learning Rate Scheduler.')

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images <= 0:
            self.display_interval = 0
        else:
            self.display_interval = int(len(val_loader.dataset) // (args.max_images * args.batch_size))
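        # e.g. 1000 validation slices, max_images=10, batch_size=2 -> display_interval = 50, i.e. every 50th step.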

        self.checkpointer = CheckpointManager(model, optimizer, mode='min', save_best_only=args.save_best_only,
                                              ckpt_dir=args.ckpt_path, max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified.
        if vars(args).get('prev_model_ckpt'):
            self.checkpointer.load(load_dir=args.prev_model_ckpt, load_optimizer=False)

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.input_train_transform = input_train_transform
        self.input_val_transform = input_val_transform
        self.output_transform = output_transform
        self.losses = losses
        self.scheduler = scheduler

        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.smoothing_factor = args.smoothing_factor
        self.use_slice_metrics = args.use_slice_metrics
        self.img_lambda = torch.tensor(args.img_lambda, dtype=torch.float32, device=args.device)
        self.writer = SummaryWriter(str(args.log_path))

    def train_model(self):
        tic_tic = time()
        self.logger.info('Beginning Training Loop.')
        for epoch in range(1, self.num_epochs + 1):  # 1 based indexing of epochs.
            tic = time()  # Training
            train_epoch_loss, train_epoch_metrics = self._train_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch, train_epoch_loss, train_epoch_metrics, elapsed_secs=toc, training=True)

            tic = time()  # Validation
            val_epoch_loss, val_epoch_metrics = self._val_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch, val_epoch_loss, val_epoch_metrics, elapsed_secs=toc, training=False)

            self.checkpointer.save(metric=val_epoch_loss, verbose=True)

            if self.scheduler is not None:
                if self.metric_scheduler:  # If the scheduler is a metric based scheduler, include metrics.
                    self.scheduler.step(metrics=val_epoch_loss)
                else:
                    self.scheduler.step()

        self.writer.close()  # Flushes remaining data to TensorBoard.
        toc_toc = int(time() - tic_tic)
        self.logger.info(f'Finishing Training Loop. Total elapsed time: '
                         f'{toc_toc // 3600} hr {(toc_toc // 60) % 60} min {toc_toc % 60} sec.')

    def _train_epoch(self, epoch):
        self.model.train()
        torch.autograd.set_grad_enabled(True)

        epoch_loss = list()  # Appending values to a list due to numerical underflow.
        epoch_metrics = defaultdict(list)

        data_loader = enumerate(self.train_loader, start=1)
        if not self.verbose:  # tqdm has to be on the outermost iterator to function properly.
            data_loader = tqdm(data_loader, total=len(self.train_loader))  # Total is the number of batches.

        # 'targets' is a dictionary containing k-space targets, cmg_targets, and img_targets.
        for step, data in data_loader:
            with torch.no_grad():  # Data pre-processing should be done without gradients.
                inputs, targets, extra_params = self.input_train_transform(*data)  # Maybe pass as tuple later?

            # 'recons' is a dictionary containing k-space, complex image, and real image reconstructions.
            recons, step_loss, step_metrics = self._train_step(inputs, targets, extra_params)
            epoch_loss.append(step_loss.detach())  # Perhaps not elegant, but underflow makes this necessary.

            # Gradients are not needed here; disabling them boosts speed and avoids autograd side effects.
            with torch.no_grad():  # Update epoch loss and metrics.
                if self.use_slice_metrics:
                    slice_metrics = self._get_slice_metrics(recons['img_recons'], targets['img_targets'])
                    step_metrics.update(slice_metrics)

                for key, value in step_metrics.items():
                    epoch_metrics[key].append(value.detach())

                if self.verbose:
                    self._log_step_outputs(epoch, step, step_loss, step_metrics, training=True)

        # Converted to a scalar and a dict of scalars, respectively.
        return self._get_epoch_outputs(epoch, epoch_loss, epoch_metrics, training=True)

    def _train_step(self, inputs, targets, extra_params):
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        recons = self.output_transform(outputs, targets, extra_params)

        # Expects a single loss. No loss decomposition within domain implemented yet.
        cmg_loss = self.losses['cmg_loss'](recons['cmg_recons'], targets['cmg_targets'])
        img_loss = self.losses['img_loss'](recons['img_recons'], targets['img_targets'])

        # If img_loss is a dict, it is expected to contain all its component losses as key, value pairs.
        if isinstance(img_loss, dict):
            step_metrics = img_loss
            img_loss = step_metrics['img_loss']
        else:  # img_loss is expected to be a Tensor if not a dict.
            step_metrics = {'img_loss': img_loss}

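        # Total loss: the complex-image loss plus the lambda-weighted image-domain loss.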
        step_loss = cmg_loss + self.img_lambda * img_loss
        step_loss.backward()
        self.optimizer.step()
        step_metrics['cmg_loss'] = cmg_loss
        return recons, step_loss, step_metrics

    def _val_epoch(self, epoch):
        self.model.eval()
        torch.autograd.set_grad_enabled(False)

        epoch_loss = list()
        epoch_metrics = defaultdict(list)

        # 1 based indexing for steps.
        data_loader = enumerate(self.val_loader, start=1)
        if not self.verbose:
            data_loader = tqdm(data_loader, total=len(self.val_loader))  # Total is the number of batches.

        # 'targets' is a dictionary containing k-space targets, cmg_targets, and img_targets.
        for step, data in data_loader:
            inputs, targets, extra_params = self.input_val_transform(*data)
            # 'recons' is a dictionary containing k-space, complex image, and real image reconstructions.
            recons, step_loss, step_metrics = self._val_step(inputs, targets, extra_params)
            epoch_loss.append(step_loss.detach())

            if self.use_slice_metrics:
                slice_metrics = self._get_slice_metrics(recons['img_recons'], targets['img_targets'])
                step_metrics.update(slice_metrics)

            for key, value in step_metrics.items():
                epoch_metrics[key].append(value.detach())

            if self.verbose:
                self._log_step_outputs(epoch, step, step_loss, step_metrics, training=False)

            # Save images to TensorBoard.
            # Condition ensures that self.display_interval != 0 and that the step is right for display.

            # Note: this step-based numbering can misbehave for certain values of display_interval;
            # check the edge case where the step count divides evenly with no remainder.
            if self.display_interval and (step % self.display_interval == 0):
                img_recon_grid, img_target_grid, img_delta_grid = \
                    make_grid_triplet(recons['img_recons'], targets['img_targets'])
                kspace_recon_grid = make_k_grid(recons['kspace_recons'], self.smoothing_factor)
                kspace_target_grid = make_k_grid(targets['kspace_targets'], self.smoothing_factor)

                self.writer.add_image(f'k-space_Recons/{step}', kspace_recon_grid, epoch, dataformats='HW')
                self.writer.add_image(f'k-space_Targets/{step}', kspace_target_grid, epoch, dataformats='HW')
                self.writer.add_image(f'Image_Recons/{step}', img_recon_grid, epoch, dataformats='HW')
                self.writer.add_image(f'Image_Targets/{step}', img_target_grid, epoch, dataformats='HW')
                self.writer.add_image(f'Image_Deltas/{step}', img_delta_grid, epoch, dataformats='HW')

        epoch_loss, epoch_metrics = self._get_epoch_outputs(epoch, epoch_loss, epoch_metrics, training=False)
        return epoch_loss, epoch_metrics

    def _val_step(self, inputs, targets, extra_params):
        outputs = self.model(inputs)
        recons = self.output_transform(outputs, targets, extra_params)
        cmg_loss = self.losses['cmg_loss'](recons['cmg_recons'], targets['cmg_targets'])
        img_loss = self.losses['img_loss'](recons['img_recons'], targets['img_targets'])

        # If img_loss is a dict, it is expected to contain all its component losses as key, value pairs.
        if isinstance(img_loss, dict):
            step_metrics = img_loss
            img_loss = step_metrics['img_loss']
        else:  # img_loss is expected to be a Tensor if not a dict.
            step_metrics = {'img_loss': img_loss}

        step_loss = cmg_loss + self.img_lambda * img_loss
        step_metrics['cmg_loss'] = cmg_loss
        return recons, step_loss, step_metrics

    @staticmethod
    def _get_slice_metrics(img_recons, img_targets):

        img_recons = img_recons.detach()  # Just in case.
        img_targets = img_targets.detach()

        max_range = img_targets.max() - img_targets.min()
        slice_ssim = ssim_loss(img_recons, img_targets, max_val=max_range)
        slice_psnr = psnr(img_recons, img_targets, data_range=max_range)
        slice_nmse = nmse(img_recons, img_targets)

        return {'slice_ssim': slice_ssim, 'slice_nmse': slice_nmse, 'slice_psnr': slice_psnr}
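
    # Note: `ssim_loss`, `psnr`, and `nmse` are external metric helpers (not shown here). `nmse` is assumed
    # to compute ||recons - targets||^2 / ||targets||^2; SSIM and PSNR take the target dynamic range explicitly.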

    def _get_epoch_outputs(self, epoch, epoch_loss, epoch_metrics, training=True):
        mode = 'Training' if training else 'Validation'
        num_slices = len(self.train_loader.dataset) if training else len(self.val_loader.dataset)

        # Checking for NaN/Inf values.
        epoch_loss = torch.stack(epoch_loss)
        is_finite = torch.isfinite(epoch_loss)
        num_nans = (is_finite.size(0) - is_finite.sum()).item()

        if num_nans > 0:
            self.logger.warning(f'Epoch {epoch} {mode}: {num_nans} non-finite values present in {num_slices} slices')
            # Turn on anomaly detection for finding where the nan values are.
            torch.autograd.set_detect_anomaly(True)
            epoch_loss = torch.mean(epoch_loss[is_finite]).item()
        else:
            epoch_loss = torch.mean(epoch_loss).item()

        for key, value in epoch_metrics.items():
            epoch_metric = torch.stack(value)
            is_finite = torch.isfinite(epoch_metric)
            num_nans = (is_finite.size(0) - is_finite.sum()).item()

            if num_nans > 0:
                self.logger.warning(f'Epoch {epoch} {mode} {key}: {num_nans} non-finite values present in {num_slices} slices')
                epoch_metrics[key] = torch.mean(epoch_metric[is_finite]).item()
            else:
                epoch_metrics[key] = torch.mean(epoch_metric).item()

        return epoch_loss, epoch_metrics

    def _log_step_outputs(self, epoch, step, step_loss, step_metrics, training=True):
        mode = 'Training' if training else 'Validation'
        self.logger.info(f'Epoch {epoch:03d} Step {step:03d} {mode} loss: {step_loss.item():.4e}')
        for key, value in step_metrics.items():
            self.logger.info(f'Epoch {epoch:03d} Step {step:03d}: {mode} {key}: {value.item():.4e}')

    def _log_epoch_outputs(self, epoch, epoch_loss, epoch_metrics, elapsed_secs, training=True):
        mode = 'Training' if training else 'Validation'
        self.logger.info(f'Epoch {epoch:03d} {mode}. loss: {epoch_loss:.4e}, '
                         f'Time: {elapsed_secs // 60} min {elapsed_secs % 60} sec')
        self.writer.add_scalar(f'{mode}_epoch_loss', scalar_value=epoch_loss, global_step=epoch)

        for key, value in epoch_metrics.items():
            self.logger.info(f'Epoch {epoch:03d} {mode}. {key}: {value:.4e}')
            self.writer.add_scalar(f'{mode}_epoch_{key}', scalar_value=value, global_step=epoch)
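The docstring of ModelTrainerK2CI above requires any composite image-domain loss to be a single nn.Module whose output is a dictionary with 'img_loss' holding the total weighted loss. A minimal sketch of such a module, assuming hypothetical L1 and L2 components with illustrative weights (the names are not from the original code):

from torch import nn


class CompositeImageLoss(nn.Module):
    """Illustrative composite image-domain loss returning a dict with 'img_loss' as the total."""

    def __init__(self, l1_weight=1.0, l2_weight=1.0):
        super().__init__()
        self.l1_weight = l1_weight
        self.l2_weight = l2_weight
        self.l1 = nn.L1Loss()
        self.l2 = nn.MSELoss()

    def forward(self, img_recons, img_targets):
        l1_loss = self.l1(img_recons, img_targets)
        l2_loss = self.l2(img_recons, img_targets)
        total = self.l1_weight * l1_loss + self.l2_weight * l2_loss
        # 'img_loss' holds the total weighted loss; the other entries double as step metrics.
        return {'img_loss': total, 'l1_loss': l1_loss, 'l2_loss': l2_loss}

Passing losses={'cmg_loss': nn.MSELoss(), 'img_loss': CompositeImageLoss()} would then satisfy the contract checked in _train_step and _val_step.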
Example #2
class ModelTrainerWGANGP:
    def __init__(self,
                 args,
                 generator,
                 discriminator,
                 gen_optim,
                 disc_optim,
                 train_loader,
                 val_loader,
                 loss_funcs,
                 gen_scheduler=None,
                 disc_scheduler=None):

        self.logger = get_logger(name=__name__,
                                 save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(generator, nn.Module) and isinstance(discriminator, nn.Module), \
            '`generator` and `discriminator` must be Pytorch Modules.'
        assert isinstance(gen_optim, optim.Optimizer) and isinstance(disc_optim, optim.Optimizer), \
            '`gen_optim` and `disc_optim` must be Pytorch Optimizers.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'

        loss_funcs = nn.ModuleDict(
            loss_funcs
        )  # Expected to be a dictionary with names and loss functions.

        if gen_scheduler is not None:
            if isinstance(gen_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_gen_scheduler = True
            elif isinstance(gen_scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_gen_scheduler = False
            else:
                raise TypeError(
                    '`gen_scheduler` must be a Pytorch Learning Rate Scheduler.'
                )

        if disc_scheduler is not None:
            if isinstance(disc_scheduler,
                          optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_disc_scheduler = True
            elif isinstance(disc_scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_disc_scheduler = False
            else:
                raise TypeError(
                    '`disc_scheduler` must be a Pytorch Learning Rate Scheduler.'
                )

        self.generator = generator
        self.discriminator = discriminator
        self.gen_optim = gen_optim
        self.disc_optim = disc_optim
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.loss_funcs = loss_funcs
        self.gen_scheduler = gen_scheduler
        self.disc_scheduler = disc_scheduler
        self.device = args.device
        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.writer = SummaryWriter(str(args.log_path))

        self.recon_lambda = torch.tensor(args.recon_lambda,
                                         dtype=torch.float32,
                                         device=args.device)
        self.lambda_gp = torch.tensor(args.lambda_gp,
                                      dtype=torch.float32,
                                      device=args.device)

        # This will work best if batch size is 1, as is recommended. I don't know whether this generalizes.
        self.target_real = torch.tensor(1,
                                        dtype=torch.float32,
                                        device=args.device)
        self.target_fake = torch.tensor(0,
                                        dtype=torch.float32,
                                        device=args.device)
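        # Note: `target_real` and `target_fake` are not used by the WGAN-GP losses shown below;
        # they appear to be kept for BCE-style adversarial objectives.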

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images <= 0:
            self.display_interval = 0
        else:
            self.display_interval = int(
                len(self.val_loader.dataset) //
                (args.max_images * args.batch_size))

        self.gen_checkpoint_manager = CheckpointManager(
            model=self.generator,
            optimizer=self.gen_optim,
            mode='min',
            save_best_only=args.save_best_only,
            ckpt_dir=args.ckpt_path / 'Generator',
            max_to_keep=args.max_to_keep)

        self.disc_checkpoint_manager = CheckpointManager(
            model=self.discriminator,
            optimizer=self.disc_optim,
            mode='min',
            save_best_only=args.save_best_only,
            ckpt_dir=args.ckpt_path / 'Discriminator',
            max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified.
        if vars(args).get('gen_prev_model_ckpt'):
            self.gen_checkpoint_manager.load(load_dir=args.gen_prev_model_ckpt,
                                             load_optimizer=False)

        if vars(args).get('disc_prev_model_ckpt'):
            self.disc_checkpoint_manager.load(
                load_dir=args.disc_prev_model_ckpt, load_optimizer=False)

    def train_model(self):
        self.logger.info('Beginning Training Loop.')
        tic_tic = time()
        for epoch in range(1, self.num_epochs + 1):  # 1 based indexing
            # Training
            tic = time()
            train_epoch_loss, train_epoch_loss_components = self._train_epoch(
                epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(
                epoch=epoch,
                epoch_loss=train_epoch_loss,
                epoch_loss_components=train_epoch_loss_components,
                elapsed_secs=toc,
                training=True)

            # Validation
            tic = time()
            val_epoch_loss, val_epoch_loss_components = self._val_epoch(
                epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(
                epoch=epoch,
                epoch_loss=val_epoch_loss,
                epoch_loss_components=val_epoch_loss_components,
                elapsed_secs=toc,
                training=False)

            self.gen_checkpoint_manager.save(metric=val_epoch_loss,
                                             verbose=True)
            self.disc_checkpoint_manager.save(metric=val_epoch_loss,
                                              verbose=True)

            if self.gen_scheduler is not None:
                if self.metric_gen_scheduler:  # If the scheduler is a metric based scheduler, include metrics.
                    self.gen_scheduler.step(metrics=val_epoch_loss)
                else:
                    self.gen_scheduler.step()

            if self.disc_scheduler is not None:
                if self.metric_disc_scheduler:
                    self.disc_scheduler.step(metrics=val_epoch_loss)
                else:
                    self.disc_scheduler.step()

        # Finishing Training Loop
        self.writer.close()  # Flushes remaining data to TensorBoard.
        toc_toc = int(time() - tic_tic)
        self.logger.info(
            f'Finishing Training Loop. Total elapsed time: '
            f'{toc_toc // 3600} hr {(toc_toc // 60) % 60} min {toc_toc % 60} sec.'
        )

    def _train_step(self, inputs, targets, extra_params):

        assert inputs.size() == targets.size(), 'Input and target sizes do not match.'

        inputs = inputs.to(self.device)
        targets = targets.to(self.device, non_blocking=True)
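
        # WGAN-GP critic objective: minimize E[D(fake)] - E[D(real)] + lambda_gp * GP,
        # where GP penalizes (||grad_x D(x)||_2 - 1)^2 at random interpolates of real and fake samples.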

        # Train discriminator
        self.disc_optim.zero_grad()
        recons = self.generator(inputs)
        pred_fake = self.discriminator(recons.detach())  # Generated fake image going through the discriminator.
        pred_real = self.discriminator(targets)  # Real image going through the discriminator.
        gradient_penalty = compute_gradient_penalty(self.discriminator,
                                                    targets, recons.detach())
        disc_loss = pred_fake.mean() - pred_real.mean() + self.lambda_gp * gradient_penalty
        disc_loss.backward()
        self.disc_optim.step()

        # Train Generator
        self.gen_optim.zero_grad()
        pred_fake = self.discriminator(recons)
        gen_loss = -pred_fake.mean()
        recon_loss = self.loss_funcs['recon_loss_func'](recons, targets)
        total_gen_loss = gen_loss + self.recon_lambda * recon_loss
        total_gen_loss.backward()
        self.gen_optim.step()

        # Just using reconstruction loss since it is the most meaningful.
        step_loss = recon_loss
        step_loss_components = {
            'pred_fake': pred_fake.mean(),
            'pred_real': pred_real.mean(),
            'gradient_penalty': gradient_penalty,
            'disc_loss': disc_loss,
            'gen_loss': gen_loss,
            'recon_loss': recon_loss,
            'total_gen_loss': total_gen_loss
        }

        return step_loss, step_loss_components

    def _train_epoch(self, epoch):
        self.generator.train()
        self.discriminator.train()
        torch.autograd.set_grad_enabled(True)

        epoch_loss = list()  # Appending values to a list due to numerical underflow.
        epoch_loss_components = defaultdict(list)

        # labels are fully sampled coil-wise images, not rss or esc.
        train_len = len(self.train_loader)  # Number of batches; used as the progress bar total.
        for step, (inputs, targets, extra_params) in tqdm(
                enumerate(self.train_loader, start=1), total=train_len):
            step_loss, step_loss_components = self._train_step(
                inputs, targets, extra_params)

            # Perhaps not elegant, but NaN values make this necessary.
            epoch_loss.append(step_loss.detach())
            for key, value in step_loss_components.items():
                epoch_loss_components[key].append(value.detach())

            if self.verbose:
                self._log_step_outputs(epoch,
                                       step,
                                       step_loss,
                                       step_loss_components,
                                       training=True)

        # Return as scalar value and dict respectively. Remove the inner lists.
        return self._get_epoch_outputs(epoch,
                                       epoch_loss,
                                       epoch_loss_components,
                                       training=True)

    def _val_step(self, inputs, targets, extra_params):
        """
        All extra parameters are to be placed in extra_params.
        This makes the system more flexible.
        """

        inputs = inputs.to(self.device)
        targets = targets.to(self.device, non_blocking=True)

        recons = self.generator(inputs)

        # Discriminator part.
        # pred_fake = self.discriminator(recons)  # Generated fake image going through discriminator.
        # pred_real = self.discriminator(targets)  # Real image going through discriminator.
        # gradient_penalty = compute_gradient_penalty(self.discriminator, targets, recons)
        # disc_loss = pred_fake.mean() - pred_real.mean() + self.lambda_gp * gradient_penalty

        # Generator part.
        # gen_loss = -pred_fake.mean()
        recon_loss = self.loss_funcs['recon_loss_func'](recons, targets)
        # total_gen_loss = gen_loss + self.recon_lambda * recon_loss

        step_loss = recon_loss
        step_loss_components = {'recon_loss': recon_loss}

        return recons, step_loss, step_loss_components

    def _val_epoch(self, epoch):
        self.generator.eval()
        self.discriminator.eval()
        torch.autograd.set_grad_enabled(False)

        epoch_loss = list()  # Appending values to a list due to numerical underflow.
        epoch_loss_components = defaultdict(list)

        val_len = len(self.val_loader)  # Number of batches; used as the progress bar total.
        for step, (inputs, targets, extra_params) in tqdm(
                enumerate(self.val_loader, start=1), total=val_len):
            recons, step_loss, step_loss_components = self._val_step(
                inputs, targets, extra_params)

            # Append to list to prevent errors from NaN and Inf values.
            epoch_loss.append(step_loss)
            for key, value in step_loss_components.items():
                epoch_loss_components[key].append(value)

            if self.verbose:
                self._log_step_outputs(epoch,
                                       step,
                                       step_loss,
                                       step_loss_components,
                                       training=False)

            # Save images to TensorBoard.
            # Condition ensures that self.display_interval != 0 and that the step is right for display.
            if self.display_interval and (step % self.display_interval == 0):
                recon_grid, target_grid, delta_grid = make_grid_triplet(
                    recons, targets)

                self.writer.add_image(f'Recons/{step}',
                                      recon_grid,
                                      global_step=epoch,
                                      dataformats='HW')
                self.writer.add_image(f'Targets/{step}',
                                      target_grid,
                                      global_step=epoch,
                                      dataformats='HW')
                self.writer.add_image(f'Deltas/{step}',
                                      delta_grid,
                                      global_step=epoch,
                                      dataformats='HW')

        return self._get_epoch_outputs(epoch,
                                       epoch_loss,
                                       epoch_loss_components,
                                       training=False)

    def _get_epoch_outputs(self,
                           epoch,
                           epoch_loss,
                           epoch_loss_components,
                           training=True):
        mode = 'Training' if training else 'Validation'
        num_slices = len(self.train_loader.dataset) if training else len(
            self.val_loader.dataset)

        # Checking for NaN/Inf values.
        epoch_loss = torch.stack(epoch_loss)
        is_finite = torch.isfinite(epoch_loss)
        num_nans = (is_finite.size(0) - is_finite.sum()).item()
        if num_nans > 0:
            self.logger.warning(f'Epoch {epoch} {mode}: {num_nans} non-finite values present in {num_slices} slices')
            epoch_loss = torch.mean(epoch_loss[is_finite]).item()
        else:
            epoch_loss = torch.mean(epoch_loss).item()

        for key, value in epoch_loss_components.items():
            epoch_loss_component = torch.stack(value)
            is_finite = torch.isfinite(epoch_loss_component)
            num_nans = (is_finite.size(0) - is_finite.sum()).item()

            if num_nans > 0:
                self.logger.warning(f'Epoch {epoch} {mode} {key}: {num_nans} non-finite values present in {num_slices} slices')
                epoch_loss_components[key] = torch.mean(epoch_loss_component[is_finite]).item()
            else:
                epoch_loss_components[key] = torch.mean(epoch_loss_component).item()

        return epoch_loss, epoch_loss_components

    def _log_step_outputs(self,
                          epoch,
                          step,
                          step_loss,
                          step_loss_components,
                          training=True):
        mode = 'Training' if training else 'Validation'
        self.logger.info(
            f'Epoch {epoch:03d} Step {step:03d} {mode} loss: {step_loss.item():.4e}'
        )
        for key, value in step_loss_components.items():
            self.logger.info(
                f'Epoch {epoch:03d} Step {step:03d}: {mode} {key}: {value.item():.4e}'
            )

    def _log_epoch_outputs(self,
                           epoch,
                           epoch_loss,
                           epoch_loss_components,
                           elapsed_secs,
                           training=True):
        mode = 'Training' if training else 'Validation'
        self.logger.info(
            f'Epoch {epoch:03d} {mode}. loss: {epoch_loss:.4e}, '
            f'Time: {elapsed_secs // 60} min {elapsed_secs % 60} sec')
        self.writer.add_scalar(f'{mode}_epoch_loss',
                               scalar_value=epoch_loss,
                               global_step=epoch)

        for key, value in epoch_loss_components.items():
            self.logger.info(f'Epoch {epoch:03d} {mode}. {key}: {value}')
            self.writer.add_scalar(f'{mode}_epoch_{key}',
                                   scalar_value=value,
                                   global_step=epoch)
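compute_gradient_penalty is referenced in _train_step above but not shown. Below is a minimal sketch under the standard WGAN-GP formulation (a gradient penalty on random interpolates between real and fake samples), assuming the discriminator returns one score per sample; the original helper may differ in detail.

import torch


def compute_gradient_penalty(discriminator, real_samples, fake_samples):
    """Penalize deviations of the discriminator's gradient norm from 1 on interpolates."""
    batch_size = real_samples.size(0)
    # One random mixing coefficient per sample, broadcast over the remaining dimensions.
    alpha = torch.rand(batch_size, *([1] * (real_samples.dim() - 1)), device=real_samples.device)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    scores = discriminator(interpolates)
    gradients = torch.autograd.grad(outputs=scores, inputs=interpolates,
                                    grad_outputs=torch.ones_like(scores),
                                    create_graph=True)[0]
    gradients = gradients.view(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()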
Example #3
class ModelTrainerC2C:
    """
    Model trainer for Complex Image Learning.
    """
    def __init__(self,
                 args,
                 model,
                 optimizer,
                 train_loader,
                 val_loader,
                 post_processing,
                 c_loss,
                 metrics=None,
                 scheduler=None):

        # The guard prevents a RuntimeError when the start method has already been set (e.g., on consecutive runs).
        if multiprocessing.get_start_method(allow_none=True) is None:
            multiprocessing.set_start_method(method='spawn')

        self.logger = get_logger(name=__name__,
                                 save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.'
        assert isinstance(optimizer, optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'

        # I think this would be best practice.
        assert isinstance(post_processing, nn.Module), '`post_processing` must be a Pytorch Module.'

        # This is not a mistake. Pytorch implements loss functions as modules.
        assert isinstance(c_loss, nn.Module), '`c_loss` must be a callable Pytorch Module.'

        if metrics is not None:
            assert isinstance(metrics, Iterable), '`metrics` must be an iterable, preferably a list or tuple.'
            for metric in metrics:
                assert callable(metric), 'All metrics must be callable functions.'

        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_scheduler = True
            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_scheduler = False
            else:
                raise TypeError(
                    '`scheduler` must be a Pytorch Learning Rate Scheduler.')

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.post_processing_func = post_processing
        self.c_loss_func = c_loss
        self.metrics = metrics
        self.scheduler = scheduler

        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.writer = SummaryWriter(str(args.log_path))

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images <= 0:
            self.display_interval = 0
        else:
            self.display_interval = int(
                len(self.val_loader.dataset) //
                (args.max_images * args.batch_size))

        self.checkpointer = CheckpointManager(
            model=self.model,
            optimizer=self.optimizer,
            mode='min',
            save_best_only=args.save_best_only,
            ckpt_dir=args.ckpt_path,
            max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified.
        if vars(args).get('prev_model_ckpt'):
            self.checkpointer.load(load_dir=args.prev_model_ckpt,
                                   load_optimizer=False)

    def train_model(self):
        tic_tic = time()
        self.logger.info('Beginning Training Loop.')
        for epoch in range(1, self.num_epochs + 1):  # 1 based indexing
            # Training
            tic = time()
            train_epoch_loss, train_epoch_metrics = self._train_epoch(
                epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch=epoch,
                                    epoch_loss=train_epoch_loss,
                                    epoch_metrics=train_epoch_metrics,
                                    elapsed_secs=toc,
                                    training=True)

            # Validation
            tic = time()
            val_epoch_loss, val_epoch_metrics = self._val_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch=epoch,
                                    epoch_loss=val_epoch_loss,
                                    epoch_metrics=val_epoch_metrics,
                                    elapsed_secs=toc,
                                    training=False)

            self.checkpointer.save(metric=val_epoch_loss, verbose=True)

            if self.scheduler is not None:
                if self.metric_scheduler:  # If the scheduler is a metric based scheduler, include metrics.
                    self.scheduler.step(metrics=val_epoch_loss)
                else:
                    self.scheduler.step()

        # Finishing Training Loop
        self.writer.close()  # Flushes remaining data to TensorBoard.
        toc_toc = int(time() - tic_tic)
        self.logger.info(
            f'Finishing Training Loop. Total elapsed time: '
            f'{toc_toc // 3600} hr {(toc_toc // 60) % 60} min {toc_toc % 60} sec.'
        )

    def _train_epoch(self, epoch):
        self.model.train()
        torch.autograd.set_grad_enabled(True)

        epoch_loss_list = list()  # Appending values to a list due to numerical underflow.
        epoch_metrics_list = [list() for _ in self.metrics] if self.metrics else None

        # labels are fully sampled coil-wise images, not rss or esc.
        for step, (inputs, c_img_targets,
                   extra_params) in enumerate(self.train_loader, start=1):
            step_loss, c_img_recons = self._train_step(inputs, c_img_targets,
                                                       extra_params)

            # Gradients are not needed here; disabling them boosts speed and avoids autograd side effects.
            with torch.no_grad():  # Update epoch loss and metrics.
                epoch_loss_list.append(step_loss.detach())  # Perhaps not elegant, but underflow makes this necessary.
                step_metrics = self._get_step_metrics(c_img_recons,
                                                      c_img_targets,
                                                      epoch_metrics_list)
                self._log_step_outputs(epoch,
                                       step,
                                       step_loss,
                                       step_metrics,
                                       training=True)

        epoch_loss, epoch_metrics = self._get_epoch_outputs(epoch,
                                                            epoch_loss_list,
                                                            epoch_metrics_list,
                                                            training=True)
        return epoch_loss, epoch_metrics

    def _train_step(self, inputs, c_img_targets, extra_params):
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        c_img_recons = self.post_processing_func(outputs, c_img_targets,
                                                 extra_params)
        step_loss = self.c_loss_func(c_img_recons, c_img_targets)
        step_loss.backward()
        self.optimizer.step()
        return step_loss, c_img_recons

    def _val_epoch(self, epoch):
        self.model.eval()
        torch.autograd.set_grad_enabled(False)

        epoch_loss_list = list()
        epoch_metrics_list = [list() for _ in self.metrics] if self.metrics else None

        for step, (inputs, c_img_targets,
                   extra_params) in enumerate(self.val_loader, start=1):
            step_loss, c_img_recons = self._val_step(inputs, c_img_targets,
                                                     extra_params)

            epoch_loss_list.append(step_loss.detach())
            # Step functions have internalized conditional statements deciding whether to execute or not.
            step_metrics = self._get_step_metrics(c_img_recons, c_img_targets,
                                                  epoch_metrics_list)
            self._log_step_outputs(epoch,
                                   step,
                                   step_loss,
                                   step_metrics,
                                   training=False)

            # Save images to TensorBoard.
            # Condition ensures that self.display_interval != 0 and that the step is right for display.
            if self.display_interval and (step % self.display_interval == 0):
                kspace_recons, kspace_targets, image_recons, image_targets, image_deltas \
                    = self._visualize_outputs(c_img_recons, c_img_targets, smoothing_factor=8)

                self.writer.add_image(f'k-space_Recons/{step}',
                                      kspace_recons,
                                      epoch,
                                      dataformats='HW')
                self.writer.add_image(f'k-space_Targets/{step}',
                                      kspace_targets,
                                      epoch,
                                      dataformats='HW')
                self.writer.add_image(f'Image_Recons/{step}',
                                      image_recons,
                                      epoch,
                                      dataformats='HW')
                self.writer.add_image(f'Image_Targets/{step}',
                                      image_targets,
                                      epoch,
                                      dataformats='HW')
                self.writer.add_image(f'Image_Deltas/{step}',
                                      image_deltas,
                                      epoch,
                                      dataformats='HW')

        epoch_loss, epoch_metrics = self._get_epoch_outputs(epoch,
                                                            epoch_loss_list,
                                                            epoch_metrics_list,
                                                            training=False)
        return epoch_loss, epoch_metrics

    def _val_step(self, inputs, c_img_targets, extra_params):
        """
        All extra parameters are to be placed in extra_params.
        This makes the system more flexible.
        """
        outputs = self.model(inputs)
        c_img_recons = self.post_processing_func(outputs, c_img_targets,
                                                 extra_params)
        step_loss = self.c_loss_func(c_img_recons, c_img_targets)
        return step_loss, c_img_recons

    def _get_step_metrics(self, c_img_recons, c_img_targets,
                          epoch_metrics_list):
        if self.metrics is not None:
            step_metrics = [
                metric(c_img_recons, c_img_targets) for metric in self.metrics
            ]
            for step_metric, epoch_metric_list in zip(step_metrics,
                                                      epoch_metrics_list):
                epoch_metric_list.append(step_metric.detach())
            return step_metrics
        return None  # Explicitly return None for step_metrics if self.metrics is None. Not necessary but more readable.

    def _get_epoch_outputs(self,
                           epoch,
                           epoch_loss_list,
                           epoch_metrics_list,
                           training=True):
        mode = 'Training' if training else 'Validation'
        num_slices = len(self.train_loader.dataset) if training else len(
            self.val_loader.dataset)

        # Checking for nan or inf values.
        epoch_loss_tensor = torch.stack(epoch_loss_list)
        finite_values = torch.isfinite(epoch_loss_tensor)
        num_nans = len(epoch_loss_list) - int(finite_values.sum().item())
        if num_nans > 0:
            self.logger.warning(f'Epoch {epoch} {mode}: {num_nans} non-finite values present in {num_slices} slices')
            epoch_loss = torch.mean(epoch_loss_tensor[finite_values]).item()
        else:
            epoch_loss = torch.mean(epoch_loss_tensor).item()

        if self.metrics:
            epoch_metrics = list()
            for idx, epoch_metric_list in enumerate(epoch_metrics_list,
                                                    start=1):
                epoch_metric_tensor = torch.stack(epoch_metric_list)
                finite_values = torch.isfinite(epoch_metric_tensor)
                num_nans = len(epoch_metric_list) - int(
                    finite_values.sum().item())

                if num_nans > 0:
                    self.logger.warning(f'Epoch {epoch} {mode}: Metric {idx} has {num_nans} non-finite values in {num_slices} slices')
                    epoch_metric = torch.mean(epoch_metric_tensor[finite_values]).item()
                else:
                    epoch_metric = torch.mean(epoch_metric_tensor).item()

                epoch_metrics.append(epoch_metric)
        else:
            epoch_metrics = None

        return epoch_loss, epoch_metrics

    def _log_step_outputs(self,
                          epoch,
                          step,
                          step_loss,
                          step_metrics,
                          training=True):
        if self.verbose:
            mode = 'Training' if training else 'Validation'
            self.logger.info(
                f'Epoch {epoch:03d} Step {step:03d} {mode} loss: {step_loss.item():.4e}'
            )
            if self.metrics:
                for idx, step_metric in enumerate(step_metrics):
                    self.logger.info(
                        f'Epoch {epoch:03d} Step {step:03d}: {mode} metric {idx}: {step_metric.item():.4e}'
                    )

    def _log_epoch_outputs(self,
                           epoch,
                           epoch_loss,
                           epoch_metrics,
                           elapsed_secs,
                           training=True):
        mode = 'Training' if training else 'Validation'
        self.logger.info(
            f'Epoch {epoch:03d} {mode}. loss: {epoch_loss:.4e}, Time: {elapsed_secs // 60} min {elapsed_secs % 60} sec'
        )
        self.writer.add_scalar(f'{mode}_epoch_loss',
                               scalar_value=epoch_loss,
                               global_step=epoch)
        if isinstance(
                epoch_metrics, list
        ):  # The metrics being returned are either 'None' or a list of values.
            for idx, epoch_metric in enumerate(epoch_metrics, start=1):
                self.logger.info(
                    f'Epoch {epoch:03d} {mode}. Metric {idx}: {epoch_metric}')
                self.writer.add_scalar(f'{mode}_epoch_metric_{idx}',
                                       scalar_value=epoch_metric,
                                       global_step=epoch)

    @staticmethod
    def _visualize_outputs(c_img_recons, c_img_targets, smoothing_factor=8):
        image_recons = complex_abs(c_img_recons)
        image_targets = complex_abs(c_img_targets)
        kspace_recons = make_k_grid(fft2(c_img_recons), smoothing_factor)
        kspace_targets = make_k_grid(fft2(c_img_targets), smoothing_factor)
        image_recons, image_targets, image_deltas = make_grid_triplet(
            image_recons, image_targets)
        return kspace_recons, kspace_targets, image_recons, image_targets, image_deltas
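Each entry of `metrics` in ModelTrainerC2C must be a callable taking (c_img_recons, c_img_targets) and returning a scalar tensor, since the trainer calls .detach() and .item() on the results. A minimal sketch of one such metric, reusing the complex_abs helper seen in _visualize_outputs (the metric itself is illustrative, not from the original code):

import torch


def magnitude_l1_metric(c_img_recons, c_img_targets):
    """Mean absolute error between the magnitudes of the complex reconstructions and targets."""
    return torch.mean(torch.abs(complex_abs(c_img_recons) - complex_abs(c_img_targets)))

A trainer could then be constructed with metrics=[magnitude_l1_metric].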
Example #4
class ModelTrainerIMG:
    """
    Model trainer for real-valued image domain losses.
    This model trainer can accept k-space and semi-k-space data, regardless of weighting.
    Both complex and real-valued image domain losses can be calculated.
    """
    def __init__(self,
                 args,
                 model,
                 optimizer,
                 train_loader,
                 val_loader,
                 input_train_transform,
                 input_val_transform,
                 output_train_transform,
                 output_val_transform,
                 losses,
                 scheduler=None):

        # Allow multiple processes to access tensors on the GPU.
        # The guard prevents a RuntimeError when the start method has already been set (e.g., on consecutive runs).
        if multiprocessing.get_start_method(allow_none=True) is None:
            multiprocessing.set_start_method(method='spawn')

        self.logger = get_logger(name=__name__,
                                 save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.'
        assert isinstance(optimizer, optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'

        assert callable(input_train_transform) and callable(input_val_transform), \
            'input_transforms must be callable functions.'
        # I think this would be best practice.
        assert isinstance(output_train_transform, nn.Module) and isinstance(output_val_transform, nn.Module), \
            '`output_train_transform` and `output_val_transform` must be Pytorch Modules.'

        # 'losses' is expected to be a dictionary.
        # Even composite losses should be a single loss module with a tuple as its output.
        losses = nn.ModuleDict(losses)

        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_scheduler = True
            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_scheduler = False
            else:
                raise TypeError(
                    '`scheduler` must be a Pytorch Learning Rate Scheduler.')

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images <= 0:
            self.display_interval = 0
        else:
            self.display_interval = int(
                len(val_loader.dataset) // (args.max_images * args.batch_size))

        self.manager = CheckpointManager(model,
                                         optimizer,
                                         mode='min',
                                         save_best_only=args.save_best_only,
                                         ckpt_dir=args.ckpt_path,
                                         max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified.
        if vars(args).get('prev_model_ckpt'):
            self.manager.load(load_dir=args.prev_model_ckpt,
                              load_optimizer=False)

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.input_train_transform = input_train_transform
        self.input_val_transform = input_val_transform
        self.output_train_transform = output_train_transform
        self.output_val_transform = output_val_transform
        self.losses = losses
        self.scheduler = scheduler
        self.writer = SummaryWriter(str(args.log_path))

        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.smoothing_factor = args.smoothing_factor
        self.shrink_scale = args.shrink_scale
        self.use_slice_metrics = args.use_slice_metrics

        # This should compute SSIM itself, not 1 - SSIM (the loss form).
        self.ssim = SSIM(filter_size=7).to(device=args.device)  # Moved to the device to cache the kernel.

        # Logging all components of the Model Trainer.
        # Train and Val input and output transforms are assumed to use the same input transform class.
        self.logger.info(f'''
        Summary of Model Trainer Components:
        Model: {get_class_name(model)}.
        Optimizer: {get_class_name(optimizer)}.
        Input Transforms: {get_class_name(input_val_transform)}.
        Output Transform: {get_class_name(output_val_transform)}.
        Image Domain Loss: {get_class_name(losses['img_loss'])}.
        Learning-Rate Scheduler: {get_class_name(scheduler)}.
        ''')  # NB: this summary block differs between the IMG-loss and CMG-loss trainers.

    def train_model(self):
        tic_tic = time()
        self.logger.info('Beginning Training Loop.')
        for epoch in range(1,
                           self.num_epochs + 1):  # 1 based indexing of epochs.
            tic = time()  # Training
            train_epoch_loss, train_epoch_metrics = self._train_epoch(
                epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch,
                                    train_epoch_loss,
                                    train_epoch_metrics,
                                    elapsed_secs=toc,
                                    training=True)

            tic = time()  # Validation
            val_epoch_loss, val_epoch_metrics = self._val_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch,
                                    val_epoch_loss,
                                    val_epoch_metrics,
                                    elapsed_secs=toc,
                                    training=False)

            self.manager.save(metric=val_epoch_loss, verbose=True)

            if self.scheduler is not None:
                if self.metric_scheduler:  # If the scheduler is a metric based scheduler, include metrics.
                    self.scheduler.step(metrics=val_epoch_loss)
                else:
                    self.scheduler.step()

        self.writer.close()  # Flushes remaining data to TensorBoard.
        toc_toc = int(time() - tic_tic)
        self.logger.info(
            f'Finishing Training Loop. Total elapsed time: '
            f'{toc_toc // 3600} hr {(toc_toc // 60) % 60} min {toc_toc % 60} sec.'
        )

    def _train_epoch(self, epoch):
        self.model.train()
        torch.autograd.set_grad_enabled(True)

        epoch_loss = list()  # Appending values to a list due to numerical underflow and NaN values.
        epoch_metrics = defaultdict(list)

        data_loader = enumerate(self.train_loader, start=1)
        if not self.verbose:  # tqdm has to be on the outermost iterator to function properly.
            data_loader = tqdm(data_loader, total=len(self.train_loader))  # Total is the number of batches.

        for step, data in data_loader:
            # Data pre-processing is expected to have gradient calculations removed inside already.
            inputs, targets, extra_params = self.input_train_transform(*data)

            # 'recons' is a dictionary containing k-space, complex image, and real image reconstructions.
            recons, step_loss, step_metrics = self._train_step(
                inputs, targets, extra_params)
            epoch_loss.append(step_loss.detach())  # Perhaps not elegant, but underflow makes this necessary.

            # Gradients are not needed here; disabling them boosts speed and avoids autograd side effects.
            with torch.no_grad():  # Update epoch loss and metrics.
                if self.use_slice_metrics:
                    slice_metrics = self._get_slice_metrics(
                        recons, targets, extra_params)
                    step_metrics.update(slice_metrics)

                for key, value in step_metrics.items():
                    epoch_metrics[key].append(value.detach())

                if self.verbose:
                    self._log_step_outputs(epoch,
                                           step,
                                           step_loss,
                                           step_metrics,
                                           training=True)

        # Converted to scalar and dict with scalar values respectively.
        return self._get_epoch_outputs(epoch,
                                       epoch_loss,
                                       epoch_metrics,
                                       training=True)

    def _train_step(self, inputs, targets, extra_params):
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        recons = self.output_train_transform(outputs, targets, extra_params)
        step_loss, step_metrics = self._step(recons, targets, extra_params)
        step_loss.backward()
        self.optimizer.step()
        return recons, step_loss, step_metrics

    def _val_epoch(self, epoch):
        self.model.eval()
        torch.autograd.set_grad_enabled(False)

        epoch_loss = list()
        epoch_metrics = defaultdict(list)

        # 1 based indexing for steps.
        data_loader = enumerate(self.val_loader, start=1)
        if not self.verbose:
            data_loader = tqdm(data_loader, total=len(self.val_loader))  # Total is the number of batches.

        for step, data in data_loader:
            inputs, targets, extra_params = self.input_val_transform(*data)
            recons, step_loss, step_metrics = self._val_step(
                inputs, targets, extra_params)
            epoch_loss.append(step_loss.detach())

            if self.use_slice_metrics:
                slice_metrics = self._get_slice_metrics(
                    recons, targets, extra_params)
                step_metrics.update(slice_metrics)

            for key, value in step_metrics.items():
                epoch_metrics[key].append(value.detach())

            if self.verbose:
                self._log_step_outputs(epoch,
                                       step,
                                       step_loss,
                                       step_metrics,
                                       training=False)

            # Visualize images on TensorBoard.
            self._visualize_images(recons,
                                   targets,
                                   extra_params,
                                   epoch,
                                   step,
                                   training=False)

        # Converted to scalar and dict with scalar values respectively.
        return self._get_epoch_outputs(epoch,
                                       epoch_loss,
                                       epoch_metrics,
                                       training=False)

    def _val_step(self, inputs, targets, extra_params):
        outputs = self.model(inputs)
        recons = self.output_val_transform(outputs, targets, extra_params)
        step_loss, step_metrics = self._step(recons, targets, extra_params)
        return recons, step_loss, step_metrics

    def _step(self, recons, targets, extra_params):
        step_loss = self.losses['img_loss'](recons['img_recons'],
                                            targets['img_targets'])

        # If img_loss is a tuple, it is expected to contain all its component losses as a dict in its second element.
        step_metrics = dict()
        if isinstance(step_loss, tuple):
            step_loss, step_metrics = step_loss

        acc = extra_params['acceleration']
        if step_metrics:  # This has to be checked before anything is added to step_metrics.
            # Iterate over a copy: inserting keys while iterating the dict itself raises a RuntimeError.
            for key, value in list(step_metrics.items()):
                step_metrics[f'acc_{acc}_{key}'] = value
        step_metrics[f'acc_{acc}_loss'] = step_loss
        return step_loss, step_metrics
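
    # A hedged sketch (not this project's actual loss module) of a composite image
    # loss compatible with the tuple unpacking in `_step` above. It assumes
    # `from torch.nn import functional as F`; names and weights are illustrative.
    #
    #     class CompositeImageLoss(nn.Module):
    #         def __init__(self, l1_weight=1., mse_weight=1.):
    #             super().__init__()
    #             self.l1_weight, self.mse_weight = l1_weight, mse_weight
    #
    #         def forward(self, img_recons, img_targets):
    #             l1 = F.l1_loss(img_recons, img_targets)
    #             mse = F.mse_loss(img_recons, img_targets)
    #             img_loss = self.l1_weight * l1 + self.mse_weight * mse
    #             # Weighted total under the 'img_loss' key, plus the components.
    #             return img_loss, {'img_loss': img_loss, 'l1_loss': l1, 'mse_loss': mse}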

    def _visualize_images(self,
                          recons,
                          targets,
                          extra_params,
                          epoch,
                          step,
                          training=False):
        mode = 'Training' if training else 'Validation'

        # N.B. This interval scheme can display the wrong number of images for some
        # dataset sizes; check the edge case where the division leaves no remainder.
        if self.display_interval and (step % self.display_interval == 0):
            img_recon_grid = make_img_grid(recons['img_recons'],
                                           self.shrink_scale)

            # The delta image is obtained by subtracting in the complex image domain, not from the real-valued images.
            delta_image = complex_abs(targets['cmg_targets'] -
                                      recons['cmg_recons'])
            delta_img_grid = make_img_grid(delta_image, self.shrink_scale)

            acc = extra_params['acceleration']
            kwargs = dict(global_step=epoch, dataformats='HW')

            self.writer.add_image(f'{mode} Image Recons/{acc}/{step}',
                                  img_recon_grid, **kwargs)
            self.writer.add_image(f'{mode} Delta Image/{acc}/{step}',
                                  delta_img_grid, **kwargs)

            if 'kspace_recons' in recons:
                kspace_recon_grid = make_k_grid(recons['kspace_recons'],
                                                self.smoothing_factor,
                                                self.shrink_scale)
                self.writer.add_image(f'{mode} k-space Recons/{acc}/{step}',
                                      kspace_recon_grid, **kwargs)

            # Adding RSS images of reconstructions and targets.
            if 'rss_recons' in recons:
                recon_rss = standardize_image(recons['rss_recons'])
                delta_rss = standardize_image(make_rss_slice(delta_image))
                self.writer.add_image(f'{mode} RSS Recons/{acc}/{step}',
                                      recon_rss, **kwargs)
                self.writer.add_image(f'{mode} RSS Delta/{acc}/{step}',
                                      delta_rss, **kwargs)

            if 'semi_kspace_recons' in recons:
                semi_kspace_recon_grid = make_k_grid(
                    recons['semi_kspace_recons'], self.smoothing_factor,
                    self.shrink_scale)

                self.writer.add_image(
                    f'{mode} semi-k-space Recons/{acc}/{step}',
                    semi_kspace_recon_grid, **kwargs)

            if epoch == 1:  # Maybe add input images too later on.
                img_target_grid = make_img_grid(targets['img_targets'],
                                                self.shrink_scale)
                self.writer.add_image(f'{mode} Image Targets/{acc}/{step}',
                                      img_target_grid, **kwargs)

                if 'kspace_targets' in targets:
                    kspace_target_grid = make_k_grid(
                        targets['kspace_targets'], self.smoothing_factor, self.shrink_scale)
                    self.writer.add_image(
                        f'{mode} k-space Targets/{acc}/{step}',
                        kspace_target_grid, **kwargs)

                if 'img_inputs' in targets:
                    # Not actually the input but what the input looks like as an image.
                    img_grid = make_img_grid(targets['img_inputs'],
                                             self.shrink_scale)
                    self.writer.add_image(
                        f'{mode} Inputs as Images/{acc}/{step}', img_grid,
                        **kwargs)

                if 'rss_targets' in targets:
                    target_rss = standardize_image(targets['rss_targets'])
                    self.writer.add_image(f'{mode} RSS Targets/{acc}/{step}',
                                          target_rss, **kwargs)

                if 'semi_kspace_targets' in targets:
                    semi_kspace_target_grid = make_k_grid(
                        targets['semi_kspace_targets'], self.smoothing_factor,
                        self.shrink_scale)

                    self.writer.add_image(
                        f'{mode} semi-k-space Targets/{acc}/{step}',
                        semi_kspace_target_grid, **kwargs)

    def _get_slice_metrics(self, recons, targets, extra_params):
        img_recons = recons['img_recons'].detach()  # Just in case.
        img_targets = targets['img_targets'].detach()
        max_range = img_targets.max() - img_targets.min()

        slice_ssim = self.ssim(img_recons, img_targets)
        slice_psnr = psnr(img_recons, img_targets, data_range=max_range)
        slice_nmse = nmse(img_recons, img_targets)

        slice_metrics = {
            'slice/ssim': slice_ssim,
            'slice/nmse': slice_nmse,
            'slice/psnr': slice_psnr
        }

        if 'rss_recons' in recons:
            rss_recons = recons['rss_recons'].detach()
            rss_targets = targets['rss_targets'].detach()
            max_range = rss_targets.max() - rss_targets.min()

            rss_ssim = self.ssim(rss_recons, rss_targets)
            rss_psnr = psnr(rss_recons, rss_targets, data_range=max_range)
            rss_nmse = nmse(rss_recons, rss_targets)

            slice_metrics['rss/ssim'] = rss_ssim
            slice_metrics['rss/psnr'] = rss_psnr
            slice_metrics['rss/nmse'] = rss_nmse
        else:
            rss_ssim = rss_psnr = rss_nmse = 0

        # Additional metrics for separating between acceleration factors.
        if 'acceleration' in extra_params:
            acc = extra_params['acceleration']
            slice_metrics[f'slice_acc_{acc}/ssim'] = slice_ssim
            slice_metrics[f'slice_acc_{acc}/psnr'] = slice_psnr
            slice_metrics[f'slice_acc_{acc}/nmse'] = slice_nmse

            if 'rss_recons' in recons:
                slice_metrics[f'rss_acc_{acc}/ssim'] = rss_ssim
                slice_metrics[f'rss_acc_{acc}/psnr'] = rss_psnr
                slice_metrics[f'rss_acc_{acc}/nmse'] = rss_nmse

        return slice_metrics

    def _get_epoch_outputs(self,
                           epoch,
                           epoch_loss,
                           epoch_metrics,
                           training=True):
        mode = 'Training' if training else 'Validation'
        num_slices = len(self.train_loader.dataset) if training else len(self.val_loader.dataset)

        # Checking for non-finite (NaN or Inf) values.
        epoch_loss = torch.stack(epoch_loss)
        is_finite = torch.isfinite(epoch_loss)
        num_nans = (is_finite.size(0) - is_finite.sum()).item()

        if num_nans > 0:
            self.logger.warning(
                f'Epoch {epoch} {mode}: {num_nans} non-finite values present in {num_slices} slices. '
                f'Turning on anomaly detection.')
            # Turn on anomaly detection for finding where the nan values are.
            torch.autograd.set_detect_anomaly(True)
            epoch_loss = torch.mean(epoch_loss[is_finite]).item()
        else:
            epoch_loss = torch.mean(epoch_loss).item()

        for key, value in epoch_metrics.items():
            epoch_metric = torch.stack(value)
            is_finite = torch.isfinite(epoch_metric)
            num_nans = (is_finite.size(0) - is_finite.sum()).item()

            if num_nans > 0:
                self.logger.warning(
                    f'Epoch {epoch} {mode} {key}: {num_nans} non-finite values present in {num_slices} slices. '
                    f'Turning on anomaly detection.')
                epoch_metrics[key] = torch.mean(epoch_metric[is_finite]).item()
            else:
                epoch_metrics[key] = torch.mean(epoch_metric).item()

        return epoch_loss, epoch_metrics
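
    # A standalone illustration, with made-up values, of the finite-filtered mean
    # used in `_get_epoch_outputs` above. Non-finite steps are dropped so that a
    # single NaN or Inf does not poison the epoch average:
    #
    #     losses = torch.tensor([0.5, float('nan'), 0.25, float('inf')])
    #     is_finite = torch.isfinite(losses)
    #     (is_finite.size(0) - is_finite.sum()).item()  # --> 2 non-finite values.
    #     torch.mean(losses[is_finite]).item()          # --> 0.375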

    def _log_step_outputs(self,
                          epoch,
                          step,
                          step_loss,
                          step_metrics,
                          training=True):
        mode = 'Training' if training else 'Validation'
        self.logger.info(
            f'Epoch {epoch:03d} Step {step:03d} {mode} loss: {step_loss.item():.4e}'
        )
        for key, value in step_metrics.items():
            self.logger.info(
                f'Epoch {epoch:03d} Step {step:03d}: {mode} {key}: {value.item():.4e}'
            )

    def _log_epoch_outputs(self,
                           epoch,
                           epoch_loss,
                           epoch_metrics,
                           elapsed_secs,
                           training=True):
        mode = 'Training' if training else 'Validation'
        self.logger.info(
            f'Epoch {epoch:03d} {mode}. loss: {epoch_loss:.4e}, '
            f'Time: {elapsed_secs // 60} min {elapsed_secs % 60} sec')
        self.writer.add_scalar(f'{mode} epoch_loss',
                               scalar_value=epoch_loss,
                               global_step=epoch)

        for key, value in epoch_metrics.items():
            self.logger.info(f'Epoch {epoch:03d} {mode}. {key}: {value:.4e}')
            # The separator matters: '{mode}/epoch_{key}' groups the loss curves under
            # one TensorBoard tab, while '{mode}_epoch_{key}' keeps each metric as its
            # own top-level tag.
            if 'loss' in key:
                self.writer.add_scalar(f'{mode}/epoch_{key}',
                                       scalar_value=value,
                                       global_step=epoch)
            else:
                self.writer.add_scalar(f'{mode}_epoch_{key}',
                                       scalar_value=value,
                                       global_step=epoch)

        if not training:  # Record learning rate.
            for idx, group in enumerate(self.optimizer.param_groups, start=1):
                self.writer.add_scalar(f'learning_rate_{idx}',
                                       group['lr'],
                                       global_step=epoch)


class ModelTrainerK2I:
    """
    Model Trainer for K-space learning.
    Please note a bit of terminology.
    In this file, 'recons' indicates coil-wise reconstructions,
    not final reconstructions for submissions.
    Also, 'targets' indicates coil-wise targets, not the 320x320 ground-truth labels.
    k-slice means a slice of k-space, i.e. only 1 slice of k-space.
    """
    def __init__(self,
                 args,
                 model,
                 optimizer,
                 train_loader,
                 val_loader,
                 post_processing,
                 loss_func,
                 metrics=None,
                 scheduler=None):
        self.logger = get_logger(name=__name__,
                                 save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.'
        assert isinstance(optimizer, optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'

        # I think this would be best practice.
        assert isinstance(post_processing, nn.Module), '`post_processing` must be a Pytorch Module.'

        # This is not a mistake. Pytorch implements loss functions as modules.
        assert isinstance(loss_func, nn.Module), '`loss_func` must be a callable Pytorch Module.'

        if metrics is not None:
            assert isinstance(metrics, (list, tuple)), '`metrics` must be a list or tuple.'
            for metric in metrics:
                assert callable(metric), 'All metrics must be callable functions.'

        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_scheduler = True
            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_scheduler = False
            else:
                raise TypeError('`scheduler` must be a Pytorch Learning Rate Scheduler.')

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.post_processing_func = post_processing
        self.loss_func = loss_func  # I don't think it is necessary to send loss_func or metrics to device.
        self.metrics = metrics
        self.scheduler = scheduler

        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.writer = SummaryWriter(logdir=str(args.log_path))

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images > 0:
            self.display_interval = int(len(self.val_loader.dataset) // args.max_images)
        else:
            self.display_interval = 0
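        # Worked example with illustrative numbers: 1000 validation slices and
        # args.max_images == 8 give display_interval == 125, so images are written
        # at steps 125, 250, ..., 1000, i.e. about 8 times per validation epoch.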

        # Writing model graph to TensorBoard. Results might not be very good.
        if args.add_graph:
            num_chans = 30 if args.challenge == 'multicoil' else 2
            example_inputs = torch.ones(size=(1, num_chans, 640, 328),
                                        device=args.device)
            self.writer.add_graph(model=model, input_to_model=example_inputs)
            del example_inputs  # Remove unnecessary tensor taking up memory.

        self.checkpointer = CheckpointManager(
            model=self.model,
            optimizer=self.optimizer,
            mode='min',
            save_best_only=args.save_best_only,
            ckpt_dir=args.ckpt_path,
            max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified.
        if vars(args).get('prev_model_ckpt'):
            self.checkpointer.load(load_dir=args.prev_model_ckpt,
                                   load_optimizer=False)

    def train_model(self):
        # Guard against a second call raising RuntimeError on continuous runs.
        if multiprocessing.get_start_method(allow_none=True) is None:
            multiprocessing.set_start_method(method='spawn')
        self.logger.info('Beginning Training Loop.')
        tic_tic = time()
        for epoch in range(1, self.num_epochs + 1):  # 1 based indexing
            # Training
            tic = time()
            train_epoch_loss, train_epoch_metrics = self._train_epoch(
                epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch=epoch,
                                    epoch_loss=train_epoch_loss,
                                    epoch_metrics=train_epoch_metrics,
                                    elapsed_secs=toc,
                                    training=True)

            # Validation
            tic = time()
            val_epoch_loss, val_epoch_metrics = self._val_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch=epoch,
                                    epoch_loss=val_epoch_loss,
                                    epoch_metrics=val_epoch_metrics,
                                    elapsed_secs=toc,
                                    training=False)

            self.checkpointer.save(metric=val_epoch_loss, verbose=True)

            if self.scheduler is not None:
                if self.metric_scheduler:  # If the scheduler is a metric based scheduler, include metrics.
                    self.scheduler.step(metrics=val_epoch_loss)
                else:
                    self.scheduler.step()

        # Finishing Training Loop
        self.writer.close()  # Flushes remaining data to TensorBoard.
        toc_toc = int(time() - tic_tic)
        self.logger.info(
            f'Finishing Training Loop. Total elapsed time: '
            f'{toc_toc // 3600} hr {(toc_toc // 60) % 60} min {toc_toc % 60} sec.'
        )
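
    # A hypothetical invocation (names are illustrative; only the signature shown
    # in __init__ above is assumed):
    #
    #     trainer = ModelTrainerK2I(args, model, optimizer, train_loader,
    #                               val_loader, post_processing, loss_func,
    #                               metrics=(nmse, psnr), scheduler=scheduler)
    #     trainer.train_model()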

    def _train_step(self, inputs, targets, extra_params):
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        image_recons, kspace_recons = self.post_processing_func(
            outputs, targets, extra_params)
        step_loss = self.loss_func(image_recons, targets)
        step_loss.backward()
        self.optimizer.step()
        return step_loss, image_recons, kspace_recons

    def _train_epoch(self, epoch):
        self.model.train()
        torch.autograd.set_grad_enabled(True)

        epoch_loss_lst = list()  # Appending values to list due to numerical underflow.
        epoch_metrics_lst = [list() for _ in self.metrics] if self.metrics else None

        # labels are fully sampled coil-wise images, not rss or esc.
        for step, (inputs, targets,
                   extra_params) in enumerate(self.train_loader, start=1):
            step_loss, image_recons, kspace_recons = self._train_step(
                inputs, targets, extra_params)

            # Gradients are not calculated so as to boost speed and remove weird errors.
            with torch.no_grad():  # Update epoch loss and metrics
                epoch_loss_lst.append(step_loss.item())  # Perhaps not elegant, but underflow makes this necessary.

                # The step functions here have all necessary conditionals internally.
                # There is no need to externally specify whether to use them or not.
                step_metrics = self._get_step_metrics(image_recons, targets,
                                                      epoch_metrics_lst)
                self._log_step_outputs(epoch,
                                       step,
                                       step_loss,
                                       step_metrics,
                                       training=True)

        epoch_loss, epoch_metrics = self._get_epoch_outputs(epoch,
                                                            epoch_loss_lst,
                                                            epoch_metrics_lst,
                                                            training=True)
        return epoch_loss, epoch_metrics

    def _val_step(self, inputs, targets, extra_params):
        """
        All extra parameters are to be placed in extra_params.
        This makes the system more flexible.
        """
        outputs = self.model(inputs)
        image_recons, kspace_recons = self.post_processing_func(
            outputs, targets, extra_params)
        step_loss = self.loss_func(image_recons, targets)
        return step_loss, image_recons, kspace_recons

    def _val_epoch(self, epoch):
        self.model.eval()
        torch.autograd.set_grad_enabled(False)

        epoch_loss_lst = list()
        epoch_metrics_lst = [list() for _ in self.metrics] if self.metrics else None

        for step, (inputs, targets, extra_params) in enumerate(self.val_loader,
                                                               start=1):
            step_loss, image_recons, kspace_recons = self._val_step(
                inputs, targets, extra_params)

            epoch_loss_lst.append(step_loss.item())
            # Step functions have internalized conditional statements deciding whether to execute or not.
            step_metrics = self._get_step_metrics(image_recons, targets,
                                                  epoch_metrics_lst)
            self._log_step_outputs(epoch,
                                   step,
                                   step_loss,
                                   step_metrics,
                                   training=False)

            # Save images to TensorBoard. Send this to a separate function later on.
            # Condition ensures that self.display_interval != 0 and that the step is right for display.
            if self.display_interval and (step % self.display_interval == 0):
                recons_grid, targets_grid, deltas_grid = make_grid_triplet(
                    image_recons, targets)
                kspace_grid = make_k_grid(kspace_recons)

                self.writer.add_image(f'k-space_Recons/{step}',
                                      kspace_grid,
                                      epoch,
                                      dataformats='HW')
                self.writer.add_image(f'Image_Recons/{step}',
                                      recons_grid,
                                      epoch,
                                      dataformats='HW')
                self.writer.add_image(f'Targets/{step}',
                                      targets_grid,
                                      epoch,
                                      dataformats='HW')
                self.writer.add_image(f'Deltas/{step}',
                                      deltas_grid,
                                      epoch,
                                      dataformats='HW')

        epoch_loss, epoch_metrics = self._get_epoch_outputs(epoch,
                                                            epoch_loss_lst,
                                                            epoch_metrics_lst,
                                                            training=False)
        return epoch_loss, epoch_metrics

    def _get_step_metrics(self, image_recons, targets, epoch_metrics_lst):
        if self.metrics is not None:
            step_metrics = [
                metric(image_recons, targets) for metric in self.metrics
            ]
            for step_metric, epoch_metric_lst in zip(step_metrics,
                                                     epoch_metrics_lst):
                epoch_metric_lst.append(step_metric.item())
            return step_metrics
        return None  # Explicitly return None for step_metrics if self.metrics is None. Not necessary but more readable.

    def _get_epoch_outputs(self,
                           epoch,
                           epoch_loss_lst,
                           epoch_metrics_lst,
                           training=True):
        mode = 'training' if training else 'validation'
        num_slices = len(self.train_loader.dataset) if training else len(
            self.val_loader.dataset)

        # Checking for nan values.
        num_nans = np.isnan(epoch_loss_lst).sum()
        if num_nans > 0:
            self.logger.warning(
                f'Epoch {epoch} {mode}: {num_nans} NaN values present in {num_slices} slices'
            )

        epoch_loss = float(np.nanmean(epoch_loss_lst))  # Remove nan values just in case.
        epoch_metrics = [float(np.nanmean(epoch_metric_lst))
                         for epoch_metric_lst in epoch_metrics_lst] if self.metrics else None

        return epoch_loss, epoch_metrics

    def _log_step_outputs(self,
                          epoch,
                          step,
                          step_loss,
                          step_metrics,
                          training=True):
        if self.verbose:
            mode = 'Training' if training else 'Validation'
            self.logger.info(
                f'Epoch {epoch:03d} Step {step:03d} {mode} loss: {step_loss.item():.4e}'
            )
            if self.metrics:
                for idx, step_metric in enumerate(step_metrics):
                    self.logger.info(
                        f'Epoch {epoch:03d} Step {step:03d}: {mode} metric {idx}: {step_metric.item():.4e}'
                    )

    def _log_epoch_outputs(self,
                           epoch,
                           epoch_loss,
                           epoch_metrics,
                           elapsed_secs,
                           training=True):
        mode = 'Training' if training else 'Validation'
        self.logger.info(
            f'Epoch {epoch:03d} {mode}. loss: {epoch_loss:.4e}, Time: {elapsed_secs // 60} min {elapsed_secs % 60} sec'
        )
        self.writer.add_scalar(f'{mode}_epoch_loss',
                               scalar_value=epoch_loss,
                               global_step=epoch)
        # The metrics being returned are either None or a list of values.
        if isinstance(epoch_metrics, list):
            for idx, epoch_metric in enumerate(epoch_metrics, start=1):
                self.logger.info(
                    f'Epoch {epoch:03d} {mode}. Metric {idx}: {epoch_metric}')
                self.writer.add_scalar(f'{mode}_epoch_metric_{idx}',
                                       scalar_value=epoch_metric,
                                       global_step=epoch)
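
A standalone sketch of the per-metric bookkeeping that ModelTrainerK2I performs in
`_get_step_metrics` and `_get_epoch_outputs`: one list per metric, appended at every
step and reduced with a NaN-tolerant mean at the end of the epoch. The two stand-in
metric callables and the sample values are illustrative only.

import numpy as np

# Stand-in metric callables; the real trainer receives these via `metrics`.
metrics = [lambda recon, target: abs(recon - target),
           lambda recon, target: (recon - target) ** 2]
epoch_metrics_lst = [list() for _ in metrics]

for recon, target in [(1.0, 0.5), (2.0, 2.5), (float('nan'), 1.0)]:
    step_metrics = [metric(recon, target) for metric in metrics]
    # Mirror of the zip-and-append loop in `_get_step_metrics`.
    for step_metric, epoch_metric_lst in zip(step_metrics, epoch_metrics_lst):
        epoch_metric_lst.append(step_metric)

# Mirror of the nanmean reduction in `_get_epoch_outputs`.
epoch_metrics = [float(np.nanmean(lst)) for lst in epoch_metrics_lst]
print(epoch_metrics)  # [0.5, 0.25] -- the NaN step is excluded from each mean.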