def _visualize_outputs(c_img_recons, c_img_targets, smoothing_factor=8):
    """Build the k-space and image display grids for a pair of image batches.

    Args:
        c_img_recons: Reconstructed images; passed to both ``complex_abs``
            and ``fft2`` (presumably complex-valued -- confirm with callers).
        c_img_targets: Target images, processed the same way as the
            reconstructions.
        smoothing_factor: Forwarded as the second argument of ``make_k_grid``.

    Returns:
        A 5-tuple ``(kspace_recons, kspace_targets, image_recons,
        image_targets, image_deltas)``: the two ``make_k_grid`` results
        followed by the three results of ``make_grid_triplet``.
    """
    abs_recons = complex_abs(c_img_recons)
    abs_targets = complex_abs(c_img_targets)

    # k-space grids come from the transformed inputs, not from the
    # ``complex_abs`` outputs above.
    recon_spectrum = fft2(c_img_recons)
    kspace_recons = make_k_grid(recon_spectrum, smoothing_factor)
    target_spectrum = fft2(c_img_targets)
    kspace_targets = make_k_grid(target_spectrum, smoothing_factor)

    image_grids = make_grid_triplet(abs_recons, abs_targets)
    image_recons, image_targets, image_deltas = image_grids
    return kspace_recons, kspace_targets, image_recons, image_targets, image_deltas
    def _val_epoch(self, epoch):
        self.model.eval()
        torch.autograd.set_grad_enabled(False)

        epoch_loss = list()
        epoch_metrics = defaultdict(list)

        # 1 based indexing for steps.
        data_loader = enumerate(self.val_loader, start=1)
        if not self.verbose:
            data_loader = tqdm(data_loader, total=len(self.val_loader.dataset))

        for step, data in data_loader:
            inputs, targets, extra_params = self.input_val_transform(*data)
            recons, step_loss, step_metrics = self._val_step(inputs, targets, extra_params)
            epoch_loss.append(step_loss.detach())

            if self.use_slice_metrics:
                slice_metrics = self._get_slice_metrics(recons['img_recons'], targets['img_targets'])
                step_metrics.update(slice_metrics)

            [epoch_metrics[key].append(value.detach()) for key, value in step_metrics.items()]

            if self.verbose:
                self._log_step_outputs(epoch, step, step_loss, step_metrics, training=False)

            # This numbering scheme seems to have issues for certain numbers.
            # Please check cases when there is no remainder.
            if self.display_interval and (step % self.display_interval == 0):
                # Change image display function later.
                img_recon_grid, img_target_grid, img_delta_grid = \
                    make_grid_triplet(recons['img_recons'], targets['img_targets'])
                kspace_recon_grid = make_k_grid(recons['kspace_recons'], self.smoothing_factor)
                kspace_target_grid = make_k_grid(targets['kspace_targets'], self.smoothing_factor)

                self.writer.add_image(f'k-space_Recons/{step}', kspace_recon_grid, epoch, dataformats='HW')
                self.writer.add_image(f'Image_Recons/{step}', img_recon_grid, epoch, dataformats='HW')
                self.writer.add_image(f'Image_Deltas/{step}', img_delta_grid, epoch, dataformats='HW')

                if 'semi_kspace_recons' in recons:
                    semi_kspace_recon_grid = make_k_grid(recons['semi_kspace_recons'])
                    self.writer.add_image(f'semi-k-space_Recons/{step}',
                                          semi_kspace_recon_grid, epoch, dataformats='HW')

                if epoch == 1:  # Maybe add input images too later on.
                    self.writer.add_image(f'k-space_Targets/{step}', kspace_target_grid, epoch, dataformats='HW')
                    self.writer.add_image(f'Image_Targets/{step}', img_target_grid, epoch, dataformats='HW')

                    if 'semi_kspace_targets' in targets:
                        semi_kspace_target_grid = make_k_grid(targets['semi_kspace_targets'])
                        self.writer.add_image(f'semi-k-space_Targets/{step}',
                                              semi_kspace_target_grid, epoch, dataformats='HW')

        # Converted to scalar and dict with scalar values respectively.
        return self._get_epoch_outputs(epoch, epoch_loss, epoch_metrics, training=False)
    def _val_epoch(self, epoch):
        self.model.eval()
        torch.autograd.set_grad_enabled(False)

        epoch_loss_lst = list()
        epoch_metrics_lst = [list()
                             for _ in self.metrics] if self.metrics else None

        for step, (inputs, targets, extra_params) in enumerate(self.val_loader,
                                                               start=1):
            step_loss, image_recons, kspace_recons = self._val_step(
                inputs, targets, extra_params)

            epoch_loss_lst.append(step_loss.item())
            # Step functions have internalized conditional statements deciding whether to execute or not.
            step_metrics = self._get_step_metrics(image_recons, targets,
                                                  epoch_metrics_lst)
            self._log_step_outputs(epoch,
                                   step,
                                   step_loss,
                                   step_metrics,
                                   training=False)

            # Save images to TensorBoard. Send this to a separate function later on.
            # Condition ensures that self.display_interval != 0 and that the step is right for display.
            if self.display_interval and (step % self.display_interval == 0):
                recons_grid, targets_grid, deltas_grid = make_grid_triplet(
                    image_recons, targets)
                kspace_grid = make_k_grid(kspace_recons)

                self.writer.add_image(f'k-space_Recons/{step}',
                                      kspace_grid,
                                      epoch,
                                      dataformats='HW')
                self.writer.add_image(f'Image_Recons/{step}',
                                      recons_grid,
                                      epoch,
                                      dataformats='HW')
                self.writer.add_image(f'Targets/{step}',
                                      targets_grid,
                                      epoch,
                                      dataformats='HW')
                self.writer.add_image(f'Deltas/{step}',
                                      deltas_grid,
                                      epoch,
                                      dataformats='HW')

        epoch_loss, epoch_metrics = self._get_epoch_outputs(epoch,
                                                            epoch_loss_lst,
                                                            epoch_metrics_lst,
                                                            training=False)
        return epoch_loss, epoch_metrics
    def _val_epoch(self, epoch):
        """Run one validation pass with the generator and discriminator in eval mode.

        Each ``(inputs, targets, extra_params)`` batch from
        ``self.val_loader`` is passed to ``self._val_step``. The step loss and
        its named components are accumulated, optionally logged, and on every
        ``self.display_interval``-th step the reconstruction, target and delta
        grids are written to ``self.writer``.

        Gradient tracking is disabled globally here and is not re-enabled
        before returning.

        Args:
            epoch: Epoch number; used as the global step of written images
                and forwarded to the logging and summary helpers.

        Returns:
            The result of ``self._get_epoch_outputs`` for the collected loss
            and loss components, called with ``training=False``.
        """
        self.generator.eval()
        self.discriminator.eval()
        torch.autograd.set_grad_enabled(False)

        epoch_loss = list()  # Appending values to list due to numerical underflow.
        epoch_loss_components = defaultdict(list)

        # The loop advances once per batch, so the progress bar total is the
        # loader length; the dataset length only matches when every batch
        # holds exactly one sample.
        num_batches = len(self.val_loader)
        for step, (inputs, targets, extra_params) in tqdm(enumerate(self.val_loader, start=1),
                                                          total=num_batches):
            recons, step_loss, step_loss_components = self._val_step(
                inputs, targets, extra_params)

            # Append to list to prevent errors from NaN and Inf values.
            epoch_loss.append(step_loss)
            for key, value in step_loss_components.items():
                epoch_loss_components[key].append(value)

            if self.verbose:
                self._log_step_outputs(epoch,
                                       step,
                                       step_loss,
                                       step_loss_components,
                                       training=False)

            # Save images to TensorBoard.
            # Condition ensures that self.display_interval != 0 and that the step is right for display.
            if self.display_interval and (step % self.display_interval == 0):
                recon_grid, target_grid, delta_grid = make_grid_triplet(
                    recons, targets)

                self.writer.add_image(f'Recons/{step}',
                                      recon_grid,
                                      global_step=epoch,
                                      dataformats='HW')
                self.writer.add_image(f'Targets/{step}',
                                      target_grid,
                                      global_step=epoch,
                                      dataformats='HW')
                self.writer.add_image(f'Deltas/{step}',
                                      delta_grid,
                                      global_step=epoch,
                                      dataformats='HW')

        return self._get_epoch_outputs(epoch,
                                       epoch_loss,
                                       epoch_loss_components,
                                       training=False)
# Example no. 5
# 0
    def _val_epoch(self, epoch):
        self.model.eval()
        torch.autograd.set_grad_enabled(False)

        epoch_loss = list()
        epoch_metrics = defaultdict(list)

        # 1 based indexing for steps.
        data_loader = enumerate(self.val_loader, start=1)
        if not self.verbose:
            data_loader = tqdm(data_loader, total=len(self.val_loader.dataset))

        # 'targets' is a dictionary containing k-space targets, cmg_targets, and img_targets.
        for step, data in data_loader:
            inputs, targets, extra_params = self.input_val_transform(*data)
            # 'recons' is a dictionary containing k-space, complex image, and real image reconstructions.
            recons, step_loss, step_metrics = self._val_step(inputs, targets, extra_params)
            epoch_loss.append(step_loss.detach())

            if self.use_slice_metrics:
                slice_metrics = self._get_slice_metrics(recons['img_recons'], targets['img_targets'])
                step_metrics.update(slice_metrics)

            [epoch_metrics[key].append(value.detach()) for key, value in step_metrics.items()]

            if self.verbose:
                self._log_step_outputs(epoch, step, step_loss, step_metrics, training=False)

            # Save images to TensorBoard.
            # Condition ensures that self.display_interval != 0 and that the step is right for display.

            # This numbering scheme seems to have issues for certain numbers.
            # Please check cases when there is no remainder.
            if self.display_interval and (step % self.display_interval == 0):
                img_recon_grid, img_target_grid, img_delta_grid = \
                    make_grid_triplet(recons['img_recons'], targets['img_targets'])
                kspace_recon_grid = make_k_grid(recons['kspace_recons'], self.smoothing_factor)
                kspace_target_grid = make_k_grid(targets['kspace_targets'], self.smoothing_factor)

                self.writer.add_image(f'k-space_Recons/{step}', kspace_recon_grid, epoch, dataformats='HW')
                self.writer.add_image(f'k-space_Targets/{step}', kspace_target_grid, epoch, dataformats='HW')
                self.writer.add_image(f'Image_Recons/{step}', img_recon_grid, epoch, dataformats='HW')
                self.writer.add_image(f'Image_Targets/{step}', img_target_grid, epoch, dataformats='HW')
                self.writer.add_image(f'Image_Deltas/{step}', img_delta_grid, epoch, dataformats='HW')

        epoch_loss, epoch_metrics = self._get_epoch_outputs(epoch, epoch_loss, epoch_metrics, training=False)
        return epoch_loss, epoch_metrics