def evaluate_qualitatively_on_dataset(
    tag: str,
    dataset: TorchDataset,
    model: Module,
    tb_writer: SummaryWriter,
    max_number_of_images: int = 15,
):
    save_dir = os.path.join(tb_writer.get_logdir(), "imgs")
    os.makedirs(save_dir, exist_ok=True)
    images = []
    for idx in range(min(len(dataset), max_number_of_images)):
        data = dataset[idx]
        prediction = model(data["observation"].unsqueeze(0),
                           intermediate_outputs=False)
        mask = prediction.detach().cpu().squeeze().numpy()
        obs = data["observation"].detach().cpu().squeeze().permute(1, 2,
                                                                   0).numpy()
        combined = combine_mask_observation(mask, obs)
        images.append(torch.from_numpy(combined).permute(2, 0, 1))
        fig, ax = plt.subplots(1, 3, figsize=(9, 3))
        ax[0].imshow(obs)
        ax[0].axis("off")
        ax[1].imshow(mask)
        ax[1].axis("off")
        ax[2].imshow(combined)
        ax[2].axis("off")
        fig.tight_layout()
        plt.savefig(os.path.join(save_dir, f"{tag}_{idx}.jpg"))
        plt.close(fig=fig)
    grid = torchvision.utils.make_grid(torch.stack(images), nrow=5)
    tb_writer.add_image(tag, grid, dataformats="CHW")
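A minimal usage sketch for the function above; `val_dataset` and `seg_model` are hypothetical names, and the log directory is an assumption:

# Hypothetical usage: dataset items are assumed to carry an "observation" tensor
# and the model is assumed to return a single-channel mask.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/qualitative")  # assumed log directory
evaluate_qualitatively_on_dataset("validation", val_dataset, seg_model, writer)
writer.close()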
Example #2
def addColorGrid(
    inputImages: torch.Tensor,
    tag: str,
    step: int,
    writer: SummaryWriter = None,
    mlflowFile=None,
):
    images = list(inputImages.cpu().unbind(dim=0))
    nImages = len(images)
    rows = int(ceil(sqrt(nImages)))
    if writer is not None:
        grid = make_grid(images, scale_each=True, normalize=True, nrow=rows)
        writer.add_image(tag, grid, step)
    if mlflowFile is not None:
        with tempfile.NamedTemporaryFile(prefix=mlflowFile + "_",
                                         suffix=".png") as f:
            save_image(images,
                       f.name,
                       scale_each=True,
                       normalize=True,
                       nrow=rows)
            mlflow.log_artifact(f.name)

    return
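A hedged usage sketch for addColorGrid; the batch and log directory are made up, and logging to MLflow would additionally require an active MLflow run:

import torch
from torch.utils.tensorboard import SummaryWriter

batch = torch.rand(16, 3, 64, 64)             # dummy NCHW image batch
writer = SummaryWriter(log_dir="runs/grids")  # assumed log directory
addColorGrid(batch, tag="samples", step=0, writer=writer)  # TensorBoard only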
Example #3
class TBManager:
    """A wrapper around TensorBoard."""
    def __init__(self):
        self.writer = SummaryWriter(log_dir=TENSORBOARD_LOG_DIR,
                                    comment='',
                                    purge_step=None)

    def add_scalar(self, name, scalar, epoch):
        self.writer.add_scalar(name, scalar, epoch)

    def add_images(self, name, model, arg_images):
        images = arg_images * 255  # Assumes float images in [0, 1].
        grid = make_grid(images)
        self.writer.add_image(name, grid, 0)
        # if model:
        #     self.writer.add_graph(model, images)

    def close(self):
        self.writer.close()
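A hedged usage sketch, assuming TENSORBOARD_LOG_DIR is defined in the surrounding module:

import torch

manager = TBManager()
manager.add_scalar("train/loss", 0.42, epoch=1)
manager.add_images("inputs", model=None, arg_images=torch.rand(8, 3, 32, 32))
manager.close()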
Example #4
    def forward(self,
                input,
                hidden,
                writer: SummaryWriter = None,
                step: int = None):
        """input: image [B, C, H, W], hidden: level set [B, C, H, W]; C == 1"""

        batch_size = input.size(0)

        c1, c2 = self.avg_inside(input, hidden.detach()), self.avg_outside(
            input, hidden.detach())
        I_c1 = (self.k1 * input - c1)**2
        I_c2 = (self.k2 * input - c2)**2
        kappa = self.curvature(hidden)
        make_grid_p = partial(make_grid, nrow=4, normalize=True)
        if writer is not None:
            writer.add_image(f"rls/hidden", make_grid_p(hidden.detach()), step)
            writer.add_image(f"rls/I_c1", make_grid_p(I_c1.detach()), step)
            writer.add_image(f"rls/I_c2", make_grid_p(I_c2.detach()), step)
            writer.add_image(f"rls/kappa", make_grid(kappa.detach(), 4), step)

        hidden = self.gru_rls_cell(hidden, kappa, I_c1, I_c2)
        plot(input, 'input')
        plot(hidden, "levelset")
        plot(kappa, "kappa")

        output = self.dense(hidden)
        return (
            output.view(batch_size, *self.img_size),
            hidden.view(batch_size, 1, *self.img_size),
        )
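For context: avg_inside and avg_outside are not shown above. Under a Chan-Vese-style reading of this module they would be the mean intensities inside and outside the zero level set; a sketch under that assumption (the real implementation presumably reduces per sample rather than globally):

import torch

def avg_inside(image, levelset):
    # Mean intensity where the level set is positive (inside the contour).
    mask = (levelset > 0).float()
    return (image * mask).sum() / mask.sum().clamp(min=1.0)

def avg_outside(image, levelset):
    # Mean intensity where the level set is non-positive (outside the contour).
    mask = (levelset <= 0).float()
    return (image * mask).sum() / mask.sum().clamp(min=1.0)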
Example #5
class TensorboardSummaryHook:
    """
    Logging object that automatically exports summaries to TensorBoard. Much of its functionality is
    automated, meaning that the hook exports as much information as possible to TensorBoard.

    Losses, metrics, inputs and outputs are all interpreted and exported according to their dimensionality. Vectors
    result in mean and standard deviation estimates as well as histograms; pictures result in image summaries and
    histograms; etc.

    There is also the possibility of comparing input/output pairs. This needs to be specified during object
    instantiation.

    Once the user instantiates this object, the workflow corresponding to the ID passed as argument will be
    tracked and its results will be exported to TensorBoard.

    .. code-block:: python

            from eisen.utils.logging import TensorboardSummaryHook

            workflow = # Eg. An instance of Training workflow

            logger = TensorboardSummaryHook(workflow.id, 'Training', '/artifacts/dir')
    """

    def __init__(
        self,
        workflow_id,
        phase,
        artifacts_dir,
        comparison_pairs=None,
        show_all_axes=False,
    ):
        """
        This method instantiates an object of type TensorboardSummaryHook. The signature of this method is similar to
        that of every other hook. There is one additional parameter called `comparison_pairs` which is meant to
        hold a list of lists each containing a pair of input/output names that share the same dimensionality and can be
        compared to each other.

        A typical use of `comparison_pairs` is when users want to plot a pr_curve or a confusion matrix by comparing
        some input with some output. Eg. by comparing the labels with the predictions.

        .. code-block:: python

            from eisen.utils.logging import TensorboardSummaryHook

            workflow = # Eg. An instance of Training workflow

            logger = TensorboardSummaryHook(
                workflow_id=workflow.id,
                phase='Training',
                artifacts_dir='/artifacts/dir',
                comparison_pairs=[['labels', 'predictions']]
            )

        :param workflow_id: string containing the workflow id of the workflow being monitored (workflow_instance.id)
        :type workflow_id: UUID
        :param phase: string containing the name of the phase (training, testing, ...) of the workflow monitored
        :type phase: str
        :param artifacts_dir: path of the directory where the TensorBoard summaries should be stored
        :type artifacts_dir: str
        :param comparison_pairs: list of lists of pairs, which are names of inputs and outputs to be compared directly
        :type comparison_pairs: list of lists of strings
        :param show_all_axes: whether any volumetric data should be shown as axial + sagittal + coronal
        :type show_all_axes: bool

        <json>
        [
            {"name": "comparison_pairs", "type": "list:list:string", "value": ""},
            {"name": "show_all_axes", "type": "bool", "value": "false"}
        ]
        </json>
        """
        self.workflow_id = workflow_id
        self.phase = phase

        self.comparison_pairs = comparison_pairs
        self.show_all_axes = show_all_axes

        if not os.path.exists(artifacts_dir):
            raise ValueError("The directory specified to save artifacts does not exist!")

        dispatcher.connect(self.end_epoch, signal=EISEN_END_EPOCH_EVENT, sender=workflow_id)

        self.artifacts_dir = os.path.join(artifacts_dir, "summaries", phase)

        if not os.path.exists(self.artifacts_dir):
            os.makedirs(self.artifacts_dir)

        self.writer = SummaryWriter(log_dir=self.artifacts_dir)

    def end_epoch(self, message):
        epoch = message["epoch"]

        # if epoch == 0:
        #     self.writer.add_graph(message['model'], ...)

        for typ in ["losses", "metrics"]:
            for dct in message[typ]:
                for key in dct.keys():
                    self.write_vector(typ + "/{}".format(key), dct[key], epoch)

        for typ in ["inputs", "outputs"]:
            for key in message[typ].keys():
                if message[typ][key].ndim == 5:
                    # Volumetric image (N, C, W, H, D)
                    self.write_volumetric_image(typ + "/{}".format(key), message[typ][key], epoch)

                if message[typ][key].ndim == 4:
                    self.write_image(typ + "/{}".format(key), message[typ][key], epoch)

                if message[typ][key].ndim == 3:
                    self.write_embedding(typ + "/{}".format(key), message[typ][key], epoch)

                if message[typ][key].ndim == 2:
                    self.write_class_probabilities(typ + "/{}".format(key), message[typ][key], epoch)

                if message[typ][key].ndim == 1:
                    self.write_vector(typ + "/{}".format(key), message[typ][key], epoch)

                if message[typ][key].ndim == 0:
                    self.write_scalar(typ + "/{}".format(key), message[typ][key], epoch)

        if self.comparison_pairs:
            for inp, out in self.comparison_pairs:
                assert message["inputs"][inp].ndim == message["outputs"][out].ndim

                if message["inputs"][inp].ndim == 1:
                    # in case of binary classification >> PR curve
                    if np.max(message["inputs"][inp]) <= 1 and np.max(message["outputs"][out]) <= 1:
                        self.write_pr_curve(
                            "{}_Vs_{}/pr_curve".format(inp, out),
                            message["inputs"][inp],
                            message["outputs"][out],
                            epoch,
                        )

                    # in any case for classification >> Confusion Matrix
                    self.write_confusion_matrix(
                        "{}_Vs_{}/confusion_matrix".format(inp, out),
                        message["inputs"][inp],
                        message["outputs"][out],
                        epoch,
                    )

    def write_volumetric_image(self, name, value, global_step):
        self.writer.add_scalar(name + "/mean", np.mean(value), global_step=global_step)
        self.writer.add_scalar(name + "/std", np.std(value), global_step=global_step)
        self.writer.add_histogram(name + "/histogram", value.flatten(), global_step=global_step)

        v = np.transpose(value, [0, 2, 1, 3, 4])

        if v.shape[2] != 3 and v.shape[2] != 1:
            v = np.average(v, axis=2, weights=np.arange(0, 1, 1 / v.shape[2]))[:, :, np.newaxis]

        torch_value = torch.tensor(v).float()

        self.writer.add_video(name + "_axis_1", torch_value, fps=10, global_step=global_step)

        if self.show_all_axes:
            v = np.transpose(value, [0, 3, 1, 2, 4])

            if v.shape[2] != 3 and v.shape[2] != 1:
                v = np.average(v, axis=2, weights=np.arange(0, 1, 1 / v.shape[2]))[:, :, np.newaxis]

            torch_value = torch.tensor(v).float()

            self.writer.add_video(name + "_axis_2", torch_value, fps=10, global_step=global_step)

            v = np.transpose(value, [0, 4, 1, 2, 3])

            if v.shape[2] != 3 and v.shape[2] != 1:
                v = np.average(v, axis=2, weights=np.arange(0, 1, 1 / v.shape[2]))[:, :, np.newaxis]

            torch_value = torch.tensor(v).float()

            self.writer.add_video(name + "_axis_3", torch_value, fps=10, global_step=global_step)

    def write_image(self, name, value, global_step):
        self.writer.add_scalar(name + "/mean", np.mean(value), global_step=global_step)
        self.writer.add_scalar(name + "/std", np.std(value), global_step=global_step)
        self.writer.add_histogram(name + "/histogram", value.flatten(), global_step=global_step)
        self.writer.add_images(name, value, global_step=global_step, dataformats="NCHW")

    def write_embedding(self, name, value, global_step):
        pass

    def write_pr_curve(self, name, labels, predictions, global_step):
        self.writer.add_pr_curve(name + "/pr_curve", labels, predictions, global_step)

    def write_confusion_matrix(self, name, labels, predictions, global_step):
        cnf_matrix = confusion_matrix(labels, predictions)
        image = plot_confusion_matrix(cnf_matrix, range(np.max(labels) + 1), normalize=True, title=name)[:, :, 0:3]
        self.writer.add_image(
            name,
            image.astype(float) / 255.0,
            global_step=global_step,
            dataformats="HWC",
        )

    def write_class_probabilities(self, name, value, global_step):
        self.writer.add_image(name, value, global_step=global_step, dataformats="HW")
        # Histogram of the predicted class per sample (argmax over the class axis).
        self.writer.add_histogram(name + "/distribution", np.argmax(value, axis=1), global_step=global_step)

    def write_vector(self, name, value, global_step):
        self.writer.add_histogram(name, value, global_step=global_step)
        self.writer.add_scalar(name + "/mean", np.mean(value), global_step=global_step)
        self.writer.add_scalar(name + "/std", np.std(value), global_step=global_step)

    def write_scalar(self, name, value, global_step):
        self.writer.add_scalar(name, value, global_step=global_step)
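A hedged sketch of the message payload that end_epoch consumes; the field names are taken from end_epoch above, but the keys and shapes are illustrative only:

import numpy as np

message = {
    "epoch": 3,
    "losses": [{"dice_loss": np.random.rand(8)}],        # vectors -> write_vector
    "metrics": [{"accuracy": np.random.rand(8)}],
    "inputs": {"images": np.random.rand(8, 1, 64, 64)},  # ndim == 4 -> write_image
    "outputs": {"predictions": np.random.rand(8)},       # ndim == 1 -> write_vector
}
# hook.end_epoch(message)  # `hook` would be a TensorboardSummaryHook instance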
class ModelTrainerIMG:
    """
    Model trainer for real-valued image domain losses.
    This model trainer can accept k-space and semi-k-space data, regardless of weighting.
    Both complex and real-valued image domain losses can be calculated.
    """
    def __init__(self,
                 args,
                 model,
                 optimizer,
                 train_loader,
                 val_loader,
                 input_train_transform,
                 input_val_transform,
                 output_train_transform,
                 output_val_transform,
                 losses,
                 scheduler=None):

        # Allow multiple processes to access tensors on GPU. Add checking for multiple continuous runs.
        if multiprocessing.get_start_method(allow_none=True) is None:
            multiprocessing.set_start_method(method='spawn')

        self.logger = get_logger(name=__name__,
                                 save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(model,
                          nn.Module), '`model` must be a Pytorch Module.'
        assert isinstance(
            optimizer,
            optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'

        assert callable(input_train_transform) and callable(input_val_transform), \
            'input_transforms must be callable functions.'
        # I think this would be best practice.
        assert isinstance(output_train_transform, nn.Module) and isinstance(output_val_transform, nn.Module), \
            '`output_train_transform` and `output_val_transform` must be Pytorch Modules.'

        # 'losses' is expected to be a dictionary.
        # Even composite losses should be a single loss module with a tuple as its output.
        losses = nn.ModuleDict(losses)

        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_scheduler = True
            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_scheduler = False
            else:
                raise TypeError(
                    '`scheduler` must be a Pytorch Learning Rate Scheduler.')

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images <= 0:
            self.display_interval = 0
        else:
            self.display_interval = int(
                len(val_loader.dataset) // (args.max_images * args.batch_size))

        self.manager = CheckpointManager(model,
                                         optimizer,
                                         mode='min',
                                         save_best_only=args.save_best_only,
                                         ckpt_dir=args.ckpt_path,
                                         max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified.
        if vars(args).get('prev_model_ckpt'):
            self.manager.load(load_dir=args.prev_model_ckpt,
                              load_optimizer=False)

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.input_train_transform = input_train_transform
        self.input_val_transform = input_val_transform
        self.output_train_transform = output_train_transform
        self.output_val_transform = output_val_transform
        self.losses = losses
        self.scheduler = scheduler
        self.writer = SummaryWriter(str(args.log_path))

        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.smoothing_factor = args.smoothing_factor
        self.shrink_scale = args.shrink_scale
        self.use_slice_metrics = args.use_slice_metrics

        # This part should get SSIM, not 1 - SSIM.
        self.ssim = SSIM(filter_size=7).to(
            device=args.device)  # Needed to cache the kernel.

        # Logging all components of the Model Trainer.
        # Train and Val input and output transforms are assumed to use the same input transform class.
        self.logger.info(f'''
        Summary of Model Trainer Components:
        Model: {get_class_name(model)}.
        Optimizer: {get_class_name(optimizer)}.
        Input Transforms: {get_class_name(input_val_transform)}.
        Output Transform: {get_class_name(output_val_transform)}.
        Image Domain Loss: {get_class_name(losses['img_loss'])}.
        Learning-Rate Scheduler: {get_class_name(scheduler)}.
        ''')  # This section differs between the IMG and CMG losses!

    def train_model(self):
        tic_tic = time()
        self.logger.info('Beginning Training Loop.')
        for epoch in range(1,
                           self.num_epochs + 1):  # 1 based indexing of epochs.
            tic = time()  # Training
            train_epoch_loss, train_epoch_metrics = self._train_epoch(
                epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch,
                                    train_epoch_loss,
                                    train_epoch_metrics,
                                    elapsed_secs=toc,
                                    training=True)

            tic = time()  # Validation
            val_epoch_loss, val_epoch_metrics = self._val_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch,
                                    val_epoch_loss,
                                    val_epoch_metrics,
                                    elapsed_secs=toc,
                                    training=False)

            self.manager.save(metric=val_epoch_loss, verbose=True)

            if self.scheduler is not None:
                if self.metric_scheduler:  # If the scheduler is a metric based scheduler, include metrics.
                    self.scheduler.step(metrics=val_epoch_loss)
                else:
                    self.scheduler.step()

        self.writer.close()  # Flushes remaining data to TensorBoard.
        toc_toc = int(time() - tic_tic)
        self.logger.info(
            f'Finishing Training Loop. Total elapsed time: '
            f'{toc_toc // 3600} hr {(toc_toc // 60) % 60} min {toc_toc % 60} sec.'
        )

    def _train_epoch(self, epoch):
        self.model.train()
        torch.autograd.set_grad_enabled(True)

        # Values are appended to a list because of numerical underflow and NaN values.
        epoch_loss = list()
        epoch_metrics = defaultdict(list)

        data_loader = enumerate(self.train_loader, start=1)
        if not self.verbose:  # tqdm has to be on the outermost iterator to function properly.
            # Total is the number of batches, i.e. the dataset size divided by the batch size.
            data_loader = tqdm(data_loader, total=len(self.train_loader))

        for step, data in data_loader:
            # Data pre-processing is expected to have gradient calculations removed inside already.
            inputs, targets, extra_params = self.input_train_transform(*data)

            # 'recons' is a dictionary containing k-space, complex image, and real image reconstructions.
            recons, step_loss, step_metrics = self._train_step(
                inputs, targets, extra_params)
            epoch_loss.append(step_loss.detach())  # Perhaps not elegant, but underflow makes this necessary.

            # Gradients are not calculated so as to boost speed and remove weird errors.
            with torch.no_grad():  # Update epoch loss and metrics
                if self.use_slice_metrics:
                    slice_metrics = self._get_slice_metrics(
                        recons, targets, extra_params)
                    step_metrics.update(slice_metrics)

                for key, value in step_metrics.items():
                    epoch_metrics[key].append(value.detach())

                if self.verbose:
                    self._log_step_outputs(epoch,
                                           step,
                                           step_loss,
                                           step_metrics,
                                           training=True)

        # Converted to scalar and dict with scalar values respectively.
        return self._get_epoch_outputs(epoch,
                                       epoch_loss,
                                       epoch_metrics,
                                       training=True)

    def _train_step(self, inputs, targets, extra_params):
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        recons = self.output_train_transform(outputs, targets, extra_params)
        step_loss, step_metrics = self._step(recons, targets, extra_params)
        step_loss.backward()
        self.optimizer.step()
        return recons, step_loss, step_metrics

    def _val_epoch(self, epoch):
        self.model.eval()
        torch.autograd.set_grad_enabled(False)

        epoch_loss = list()
        epoch_metrics = defaultdict(list)

        # 1 based indexing for steps.
        data_loader = enumerate(self.val_loader, start=1)
        if not self.verbose:
            data_loader = tqdm(data_loader, total=len(self.val_loader))  # Total = number of batches.

        for step, data in data_loader:
            inputs, targets, extra_params = self.input_val_transform(*data)
            recons, step_loss, step_metrics = self._val_step(
                inputs, targets, extra_params)
            epoch_loss.append(step_loss.detach())

            if self.use_slice_metrics:
                slice_metrics = self._get_slice_metrics(
                    recons, targets, extra_params)
                step_metrics.update(slice_metrics)

            for key, value in step_metrics.items():
                epoch_metrics[key].append(value.detach())

            if self.verbose:
                self._log_step_outputs(epoch,
                                       step,
                                       step_loss,
                                       step_metrics,
                                       training=False)

            # Visualize images on TensorBoard.
            self._visualize_images(recons,
                                   targets,
                                   extra_params,
                                   epoch,
                                   step,
                                   training=False)

        # Converted to scalar and dict with scalar values respectively.
        return self._get_epoch_outputs(epoch,
                                       epoch_loss,
                                       epoch_metrics,
                                       training=False)

    def _val_step(self, inputs, targets, extra_params):
        outputs = self.model(inputs)
        recons = self.output_val_transform(outputs, targets, extra_params)
        step_loss, step_metrics = self._step(recons, targets, extra_params)
        return recons, step_loss, step_metrics

    def _step(self, recons, targets, extra_params):
        step_loss = self.losses['img_loss'](recons['img_recons'],
                                            targets['img_targets'])

        # If img_loss is a tuple, it is expected to contain all its component losses as a dict in its second element.
        step_metrics = dict()
        if isinstance(step_loss, tuple):
            step_loss, step_metrics = step_loss

        acc = extra_params["acceleration"]
        if step_metrics:  # This has to be checked before anything is added to step_metrics.
            for key, value in step_metrics.items():
                step_metrics[f'acc_{acc}_{key}'] = value
        step_metrics[f'acc_{acc}_loss'] = step_loss
        return step_loss, step_metrics

    def _visualize_images(self,
                          recons,
                          targets,
                          extra_params,
                          epoch,
                          step,
                          training=False):
        mode = 'Training' if training else 'Validation'

        # This numbering scheme seems to have issues for certain numbers.
        # Please check cases when there is no remainder.
        if self.display_interval and (step % self.display_interval == 0):
            img_recon_grid = make_img_grid(recons['img_recons'],
                                           self.shrink_scale)

            # The delta image is obtained by subtracting the complex images, not the real-valued images.
            delta_image = complex_abs(targets['cmg_targets'] -
                                      recons['cmg_recons'])
            delta_img_grid = make_img_grid(delta_image, self.shrink_scale)

            acc = extra_params['acceleration']
            kwargs = dict(global_step=epoch, dataformats='HW')

            self.writer.add_image(f'{mode} Image Recons/{acc}/{step}',
                                  img_recon_grid, **kwargs)
            self.writer.add_image(f'{mode} Delta Image/{acc}/{step}',
                                  delta_img_grid, **kwargs)

            if 'kspace_recons' in recons:
                kspace_recon_grid = make_k_grid(recons['kspace_recons'],
                                                self.smoothing_factor,
                                                self.shrink_scale)
                self.writer.add_image(f'{mode} k-space Recons/{acc}/{step}',
                                      kspace_recon_grid, **kwargs)

            # Adding RSS images of reconstructions and targets.
            if 'rss_recons' in recons:
                recon_rss = standardize_image(recons['rss_recons'])
                delta_rss = standardize_image(make_rss_slice(delta_image))
                self.writer.add_image(f'{mode} RSS Recons/{acc}/{step}',
                                      recon_rss, **kwargs)
                self.writer.add_image(f'{mode} RSS Delta/{acc}/{step}',
                                      delta_rss, **kwargs)

            if 'semi_kspace_recons' in recons:
                semi_kspace_recon_grid = make_k_grid(
                    recons['semi_kspace_recons'], self.smoothing_factor,
                    self.shrink_scale)

                self.writer.add_image(
                    f'{mode} semi-k-space Recons/{acc}/{step}',
                    semi_kspace_recon_grid, **kwargs)

            if epoch == 1:  # Maybe add input images too later on.
                img_target_grid = make_img_grid(targets['img_targets'],
                                                self.shrink_scale)
                self.writer.add_image(f'{mode} Image Targets/{acc}/{step}',
                                      img_target_grid, **kwargs)

                if 'kspace_targets' in targets:
                    kspace_target_grid = \
                        make_k_grid(targets['kspace_targets'], self.smoothing_factor, self.shrink_scale)
                    self.writer.add_image(
                        f'{mode} k-space Targets/{acc}/{step}',
                        kspace_target_grid, **kwargs)

                if 'img_inputs' in targets:
                    # Not actually the input but what the input looks like as an image.
                    img_grid = make_img_grid(targets['img_inputs'],
                                             self.shrink_scale)
                    self.writer.add_image(
                        f'{mode} Inputs as Images/{acc}/{step}', img_grid,
                        **kwargs)

                if 'rss_targets' in targets:
                    target_rss = standardize_image(targets['rss_targets'])
                    self.writer.add_image(f'{mode} RSS Targets/{acc}/{step}',
                                          target_rss, **kwargs)

                if 'semi_kspace_targets' in targets:
                    semi_kspace_target_grid = make_k_grid(
                        targets['semi_kspace_targets'], self.smoothing_factor,
                        self.shrink_scale)

                    self.writer.add_image(
                        f'{mode} semi-k-space Targets/{acc}/{step}',
                        semi_kspace_target_grid, **kwargs)

    def _get_slice_metrics(self, recons, targets, extra_params):
        img_recons = recons['img_recons'].detach()  # Just in case.
        img_targets = targets['img_targets'].detach()
        max_range = img_targets.max() - img_targets.min()

        slice_ssim = self.ssim(img_recons, img_targets)
        slice_psnr = psnr(img_recons, img_targets, data_range=max_range)
        slice_nmse = nmse(img_recons, img_targets)

        slice_metrics = {
            'slice/ssim': slice_ssim,
            'slice/nmse': slice_nmse,
            'slice/psnr': slice_psnr
        }

        if 'rss_recons' in recons:
            rss_recons = recons['rss_recons'].detach()
            rss_targets = targets['rss_targets'].detach()
            max_range = rss_targets.max() - rss_targets.min()

            rss_ssim = self.ssim(rss_recons, rss_targets)
            rss_psnr = psnr(rss_recons, rss_targets, data_range=max_range)
            rss_nmse = nmse(rss_recons, rss_targets)

            slice_metrics['rss/ssim'] = rss_ssim
            slice_metrics['rss/psnr'] = rss_psnr
            slice_metrics['rss/nmse'] = rss_nmse
        else:
            rss_ssim = rss_psnr = rss_nmse = 0

        # Additional metrics for separating between acceleration factors.
        if 'acceleration' in extra_params:
            acc = extra_params["acceleration"]
            slice_metrics[f'slice_acc_{acc}/ssim'] = slice_ssim
            slice_metrics[f'slice_acc_{acc}/psnr'] = slice_psnr
            slice_metrics[f'slice_acc_{acc}/nmse'] = slice_nmse

            if 'rss_recons' in recons:
                slice_metrics[f'rss_acc_{acc}/ssim'] = rss_ssim
                slice_metrics[f'rss_acc_{acc}/psnr'] = rss_psnr
                slice_metrics[f'rss_acc_{acc}/nmse'] = rss_nmse

        return slice_metrics

    def _get_epoch_outputs(self,
                           epoch,
                           epoch_loss,
                           epoch_metrics,
                           training=True):
        mode = 'Training' if training else 'Validation'
        num_slices = len(self.train_loader.dataset) if training else len(
            self.val_loader.dataset)

        # Checking for nan values.
        epoch_loss = torch.stack(epoch_loss)
        is_finite = torch.isfinite(epoch_loss)
        num_nans = (is_finite.size(0) - is_finite.sum()).item()

        if num_nans > 0:
            self.logger.warning(
                f'Epoch {epoch} {mode}: {num_nans} NaN values present in {num_slices} slices. '
                f'Turning on anomaly detection.')
            # Turn on anomaly detection for finding where the nan values are.
            torch.autograd.set_detect_anomaly(True)
            epoch_loss = torch.mean(epoch_loss[is_finite]).item()
        else:
            epoch_loss = torch.mean(epoch_loss).item()

        for key, value in epoch_metrics.items():
            epoch_metric = torch.stack(value)
            is_finite = torch.isfinite(epoch_metric)
            num_nans = (is_finite.size(0) - is_finite.sum()).item()

            if num_nans > 0:
                self.logger.warning(
                    f'Epoch {epoch} {mode} {key}: {num_nans} NaN values present in {num_slices} slices. '
                    f'Turning on anomaly detection.')
                epoch_metrics[key] = torch.mean(epoch_metric[is_finite]).item()
            else:
                epoch_metrics[key] = torch.mean(epoch_metric).item()

        return epoch_loss, epoch_metrics

    def _log_step_outputs(self,
                          epoch,
                          step,
                          step_loss,
                          step_metrics,
                          training=True):
        mode = 'Training' if training else 'Validation'
        self.logger.info(
            f'Epoch {epoch:03d} Step {step:03d} {mode} loss: {step_loss.item():.4e}'
        )
        for key, value in step_metrics.items():
            self.logger.info(
                f'Epoch {epoch:03d} Step {step:03d}: {mode} {key}: {value.item():.4e}'
            )

    def _log_epoch_outputs(self,
                           epoch,
                           epoch_loss,
                           epoch_metrics,
                           elapsed_secs,
                           training=True):
        mode = 'Training' if training else 'Validation'
        self.logger.info(
            f'Epoch {epoch:03d} {mode}. loss: {epoch_loss:.4e}, '
            f'Time: {elapsed_secs // 60} min {elapsed_secs % 60} sec')
        self.writer.add_scalar(f'{mode} epoch_loss',
                               scalar_value=epoch_loss,
                               global_step=epoch)

        for key, value in epoch_metrics.items():
            self.logger.info(f'Epoch {epoch:03d} {mode}. {key}: {value:.4e}')
            # Tag naming matters: 'mode/...' groups scalars under one TensorBoard section, 'mode_...' does not.
            if 'loss' in key:
                self.writer.add_scalar(f'{mode}/epoch_{key}',
                                       scalar_value=value,
                                       global_step=epoch)
            else:
                self.writer.add_scalar(f'{mode}_epoch_{key}',
                                       scalar_value=value,
                                       global_step=epoch)

        if not training:  # Record learning rate.
            for idx, group in enumerate(self.optimizer.param_groups, start=1):
                self.writer.add_scalar(f'learning_rate_{idx}',
                                       group['lr'],
                                       global_step=epoch)
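The epoch reduction in _get_epoch_outputs above averages only the finite entries when NaNs appear; a standalone sketch of that pattern with made-up values:

import torch

losses = [torch.tensor(0.5), torch.tensor(float("nan")), torch.tensor(0.7)]
stacked = torch.stack(losses)
finite = torch.isfinite(stacked)
num_nans = (finite.numel() - finite.sum()).item()   # 1 NaN entry
epoch_loss = stacked[finite].mean().item()          # ~0.6; NaN entries are skipped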
class ModelTrainerK2C:

    def __init__(self, args, model, optimizer, train_loader, val_loader,
                 input_train_transform, input_val_transform, output_transform, losses, scheduler=None):

        # Allow multiple processes to access tensors on GPU. Add checking for multiple continuous runs.
        if multiprocessing.get_start_method(allow_none=True) is None:
            multiprocessing.set_start_method(method='spawn')

        self.logger = get_logger(name=__name__, save_file=args.log_path / args.run_name)

        # Checking whether inputs are correct.
        assert isinstance(model, nn.Module), '`model` must be a Pytorch Module.'
        assert isinstance(optimizer, optim.Optimizer), '`optimizer` must be a Pytorch Optimizer.'
        assert isinstance(train_loader, DataLoader) and isinstance(val_loader, DataLoader), \
            '`train_loader` and `val_loader` must be Pytorch DataLoader objects.'

        assert callable(input_train_transform) and callable(input_val_transform), \
            'input_transforms must be callable functions.'
        # I think this would be best practice.
        assert isinstance(output_transform, nn.Module), '`output_transform` must be a Pytorch Module.'

        # 'losses' is expected to be a dictionary.
        # Even composite losses should be a single loss module with multiple outputs.
        losses = nn.ModuleDict(losses)

        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                self.metric_scheduler = True
            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                self.metric_scheduler = False
            else:
                raise TypeError('`scheduler` must be a Pytorch Learning Rate Scheduler.')

        # Display interval of 0 means no display of validation images on TensorBoard.
        if args.max_images <= 0:
            self.display_interval = 0
        else:
            self.display_interval = int(len(val_loader.dataset) // (args.max_images * args.batch_size))

        self.manager = CheckpointManager(model, optimizer, mode='min', save_best_only=args.save_best_only,
                                         ckpt_dir=args.ckpt_path, max_to_keep=args.max_to_keep)

        # loading from checkpoint if specified.
        if vars(args).get('prev_model_ckpt'):
            self.manager.load(load_dir=args.prev_model_ckpt, load_optimizer=False)

        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.input_train_transform = input_train_transform
        self.input_val_transform = input_val_transform
        self.output_transform = output_transform
        self.losses = losses
        self.scheduler = scheduler

        self.verbose = args.verbose
        self.num_epochs = args.num_epochs
        self.smoothing_factor = args.smoothing_factor
        self.use_slice_metrics = args.use_slice_metrics
        self.writer = SummaryWriter(str(args.log_path))

    def train_model(self):
        tic_tic = time()
        self.logger.info('Beginning Training Loop.')
        for epoch in range(1, self.num_epochs + 1):  # 1 based indexing of epochs.
            tic = time()  # Training
            train_epoch_loss, train_epoch_metrics = self._train_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch, train_epoch_loss, train_epoch_metrics, elapsed_secs=toc, training=True)

            tic = time()  # Validation
            val_epoch_loss, val_epoch_metrics = self._val_epoch(epoch=epoch)
            toc = int(time() - tic)
            self._log_epoch_outputs(epoch, val_epoch_loss, val_epoch_metrics, elapsed_secs=toc, training=False)

            self.manager.save(metric=val_epoch_loss, verbose=True)

            if self.scheduler is not None:
                if self.metric_scheduler:  # If the scheduler is a metric based scheduler, include metrics.
                    self.scheduler.step(metrics=val_epoch_loss)
                else:
                    self.scheduler.step()

        self.writer.close()  # Flushes remaining data to TensorBoard.
        toc_toc = int(time() - tic_tic)
        self.logger.info(f'Finishing Training Loop. Total elapsed time: '
                         f'{toc_toc // 3600} hr {(toc_toc // 60) % 60} min {toc_toc % 60} sec.')

    def _train_epoch(self, epoch):
        self.model.train()
        torch.autograd.set_grad_enabled(True)

        epoch_loss = list()  # Appending values to list due to numerical underflow.
        epoch_metrics = defaultdict(list)

        data_loader = enumerate(self.train_loader, start=1)
        if not self.verbose:  # tqdm has to be on the outermost iterator to function properly.
            data_loader = tqdm(data_loader, total=len(self.train_loader))  # Total = number of batches.

        for step, data in data_loader:
            # Data pre-processing is expected to have gradient calculations removed already.
            inputs, targets, extra_params = self.input_train_transform(*data)

            # 'recons' is a dictionary containing k-space, complex image, and real image reconstructions.
            recons, step_loss, step_metrics = self._train_step(inputs, targets, extra_params)
            epoch_loss.append(step_loss.detach())  # Perhaps not elegant, but underflow makes this necessary.

            # Gradients are not calculated so as to boost speed and remove weird errors.
            with torch.no_grad():  # Update epoch loss and metrics
                if self.use_slice_metrics:
                    slice_metrics = self._get_slice_metrics(recons['img_recons'], targets['img_targets'])
                    step_metrics.update(slice_metrics)

                for key, value in step_metrics.items():
                    epoch_metrics[key].append(value.detach())

                if self.verbose:
                    self._log_step_outputs(epoch, step, step_loss, step_metrics, training=True)

        # Converted to scalar and dict with scalar forms.
        return self._get_epoch_outputs(epoch, epoch_loss, epoch_metrics, training=True)

    def _train_step(self, inputs, targets, extra_params):
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        recons = self.output_transform(outputs, targets, extra_params)
        step_loss = self.losses['cmg_loss'](recons['cmg_recons'], targets['cmg_targets'])
        step_loss.backward()
        self.optimizer.step()
        step_metrics = dict()
        return recons, step_loss, step_metrics

    def _val_epoch(self, epoch):
        self.model.eval()
        torch.autograd.set_grad_enabled(False)

        epoch_loss = list()
        epoch_metrics = defaultdict(list)

        # 1 based indexing for steps.
        data_loader = enumerate(self.val_loader, start=1)
        if not self.verbose:
            data_loader = tqdm(data_loader, total=len(self.val_loader))  # Total = number of batches.

        for step, data in data_loader:
            inputs, targets, extra_params = self.input_val_transform(*data)
            recons, step_loss, step_metrics = self._val_step(inputs, targets, extra_params)
            epoch_loss.append(step_loss.detach())

            if self.use_slice_metrics:
                slice_metrics = self._get_slice_metrics(recons['img_recons'], targets['img_targets'])
                step_metrics.update(slice_metrics)

            for key, value in step_metrics.items():
                epoch_metrics[key].append(value.detach())

            if self.verbose:
                self._log_step_outputs(epoch, step, step_loss, step_metrics, training=False)

            # This numbering scheme seems to have issues for certain numbers.
            # Please check cases when there is no remainder.
            if self.display_interval and (step % self.display_interval == 0):
                # Change image display function later.
                img_recon_grid, img_target_grid, img_delta_grid = \
                    make_grid_triplet(recons['img_recons'], targets['img_targets'])
                kspace_recon_grid = make_k_grid(recons['kspace_recons'], self.smoothing_factor)
                kspace_target_grid = make_k_grid(targets['kspace_targets'], self.smoothing_factor)

                self.writer.add_image(f'k-space_Recons/{step}', kspace_recon_grid, epoch, dataformats='HW')

                self.writer.add_image(f'Image_Recons/{step}', img_recon_grid, epoch, dataformats='HW')

                self.writer.add_image(f'Image_Deltas/{step}', img_delta_grid, epoch, dataformats='HW')

                if epoch == 1:
                    self.writer.add_image(f'k-space_Targets/{step}', kspace_target_grid, epoch, dataformats='HW')
                    self.writer.add_image(f'Image_Targets/{step}', img_target_grid, epoch, dataformats='HW')

                    # TODO: Add input images to visualization too.

                self.targets_recorded = True

        epoch_loss, epoch_metrics = self._get_epoch_outputs(epoch, epoch_loss, epoch_metrics, training=False)
        return epoch_loss, epoch_metrics

    def _val_step(self, inputs, targets, extra_params):
        outputs = self.model(inputs)
        recons = self.output_transform(outputs, targets, extra_params)
        step_loss = self.losses['cmg_loss'](recons['cmg_recons'], targets['cmg_targets'])
        step_metrics = dict()
        return recons, step_loss, step_metrics

    @staticmethod
    def _get_slice_metrics(img_recons, img_targets):

        img_recons = img_recons.detach()  # Just in case.
        img_targets = img_targets.detach()

        max_range = img_targets.max() - img_targets.min()
        slice_ssim = ssim_loss(img_recons, img_targets, max_val=max_range)
        slice_psnr = psnr(img_recons, img_targets, data_range=max_range)
        slice_nmse = nmse(img_recons, img_targets)

        return {'slice_ssim': slice_ssim, 'slice_nmse': slice_nmse, 'slice_psnr': slice_psnr}

    def _get_epoch_outputs(self, epoch, epoch_loss, epoch_metrics, training=True):
        mode = 'Training' if training else 'Validation'
        num_slices = len(self.train_loader.dataset) if training else len(self.val_loader.dataset)

        # Checking for nan values.
        epoch_loss = torch.stack(epoch_loss)
        is_finite = torch.isfinite(epoch_loss)
        num_nans = (is_finite.size(0) - is_finite.sum()).item()

        if num_nans > 0:
            self.logger.warning(f'Epoch {epoch} {mode}: {num_nans} NaN values present in {num_slices} slices. '
                                f'Turning on anomaly detection.')
            # Turn on anomaly detection for finding where the nan values are.
            torch.autograd.set_detect_anomaly(True)
            epoch_loss = torch.mean(epoch_loss[is_finite]).item()
        else:
            epoch_loss = torch.mean(epoch_loss).item()

        for key, value in epoch_metrics.items():
            epoch_metric = torch.stack(value)
            is_finite = torch.isfinite(epoch_metric)
            num_nans = (is_finite.size(0) - is_finite.sum()).item()

            if num_nans > 0:
                self.logger.warning(f'Epoch {epoch} {mode} {key}: {num_nans} NaN values present in {num_slices} slices. '
                                    f'Turning on anomaly detection.')
                epoch_metrics[key] = torch.mean(epoch_metric[is_finite]).item()
            else:
                epoch_metrics[key] = torch.mean(epoch_metric).item()

        return epoch_loss, epoch_metrics

    def _log_step_outputs(self, epoch, step, step_loss, step_metrics, training=True):
        mode = 'Training' if training else 'Validation'
        self.logger.info(f'Epoch {epoch:03d} Step {step:03d} {mode} loss: {step_loss.item():.4e}')
        for key, value in step_metrics.items():
            self.logger.info(f'Epoch {epoch:03d} Step {step:03d}: {mode} {key}: {value.item():.4e}')

    def _log_epoch_outputs(self, epoch, epoch_loss, epoch_metrics, elapsed_secs, training=True):
        mode = 'Training' if training else 'Validation'
        self.logger.info(f'Epoch {epoch:03d} {mode}. loss: {epoch_loss:.4e}, '
                         f'Time: {elapsed_secs // 60} min {elapsed_secs % 60} sec')
        self.writer.add_scalar(f'{mode}_epoch_loss', scalar_value=epoch_loss, global_step=epoch)

        for key, value in epoch_metrics.items():
            self.logger.info(f'Epoch {epoch:03d} {mode}. {key}: {value:.4e}')
            self.writer.add_scalar(f'{mode}_epoch_{key}', scalar_value=value, global_step=epoch)

        if not training:  # Record learning rate.
            for idx, group in enumerate(self.optimizer.param_groups, start=1):
                self.writer.add_scalar(f'learning_rate_{idx}', group['lr'], global_step=epoch)
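Illustrative arithmetic for the display_interval logic shared by both trainers (the numbers are made up; this is likely the "numbering scheme" issue noted in the code comments, since the floor division can select more than max_images batches when there is a remainder):

# With 1000 validation slices, max_images=4 and batch_size=2:
num_slices, max_images, batch_size = 1000, 4, 2
display_interval = num_slices // (max_images * batch_size)  # 125
num_steps = num_slices // batch_size                        # 500 batches
shown = [s for s in range(1, num_steps + 1) if s % display_interval == 0]
print(shown)  # [125, 250, 375, 500] -> exactly max_images batches visualized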