Exemple #1
0
    def _do_iteration(self,
                      data: Dict[str, torch.Tensor],
                      loss_fns: Dict[str, Callable]) -> Tuple[torch.Tensor, Dict]:
        """Run every RIM step on one batch, backpropagating after each step while training.

        Parameters
        ----------
        data : Dict[str, torch.Tensor]
            Batch dictionary. Must contain 'target' and 'masked_image'. 'masked_image' is popped from the
            caller's dictionary; all remaining entries are forwarded to the model as keyword arguments.
        loss_fns : Dict[str, Callable]
            Maps a loss name to a callable invoked as ``fn(prediction, target, reduction='mean')``.

        Returns
        -------
        Tuple[torch.Tensor, Dict]
            The last reconstruction of the final RIM step (``None`` when ``cfg.model.steps`` is 0) and the
            detached per-loss values combined over the RIM steps with ``reduce_list_of_dicts``.
        """
        # Target is not needed in the model input
        target = data['target'].align_to('batch', 'complex', 'height', 'width').to(self.device)  # type: ignore
        # The first input_image in the iteration is the input_image with the mask applied and no first hidden state.
        input_image = data.pop('masked_image').to(self.device)  # type: ignore
        # The remaining entries do not change between RIM steps, so move them to the device only once.
        model_inputs = dict_to_device(data, self.device)
        hidden_state = None
        output_image = None
        loss_dicts = []
        for _ in range(self.cfg.model.steps):
            reconstruction_iter, hidden_state = self.model(
                **model_inputs,
                input_image=input_image,
                hidden_state=hidden_state,
            )
            # TODO: Unclear why this refining is needed.
            output_image = reconstruction_iter[-1].refine_names('batch', 'complex', 'height', 'width')

            # Accumulate every loss over all intermediate reconstructions of this step.
            loss_dict = {k: torch.zeros(1, dtype=target.dtype, device=self.device) for k in loss_fns.keys()}
            for output_image_iter in reconstruction_iter:
                for k, v in loss_dict.items():
                    loss_dict[k] = v + loss_fns[k](
                        output_image_iter.rename(None), target.rename(None), reduction='mean'
                    )

            # Average over the intermediate reconstructions; the step loss is the sum of all loss terms.
            loss_dict = {k: v / len(reconstruction_iter) for k, v in loss_dict.items()}
            loss = sum(loss_dict.values())

            if self.model.training:
                if self.mixed_precision:
                    with amp.scale_loss(loss, self.__optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()  # type: ignore

            # Detach hidden state from computation graph, to ensure loss is only computed per RIM block.
            hidden_state = hidden_state.detach()
            input_image = output_image.detach()

            loss_dicts.append(detach_dict(loss_dict))  # Need to detach dict as this is only used for logging.

        # Add the loss dicts together over RIM steps, divide by the number of steps.
        loss_dict = reduce_list_of_dicts(loss_dicts, mode='sum', divisor=self.cfg.model.steps)
        return output_image, loss_dict
Exemple #2
0
    def training_loop(
        self,
        data_loader: DataLoader,
        start_iter: int,
        validation_data_loaders: Optional[List[DataLoader]] = None,
        experiment_directory: Optional[pathlib.Path] = None,
    ):
        """Train from ``start_iter`` until ``cfg.training.num_iterations`` or until the loader is exhausted.

        Each iteration runs ``_do_iteration`` (which backpropagates), optionally steps the optimizer
        (every ``cfg.training.gradient_steps`` iterations), steps the learning-rate scheduler, logs the
        reduced losses and metrics, and periodically checkpoints, validates and writes logs.

        Parameters
        ----------
        data_loader : DataLoader
            Training batches; iterated in lockstep with the iteration counter.
        start_iter : int
            Index of the first iteration (non-zero when resuming).
        validation_data_loaders : Optional[List[DataLoader]]
            Iterated as ``(dataset_name, data_loader)`` pairs. When ``None`` no validation is run.
        experiment_directory : Optional[pathlib.Path]
            When set, per-case validation metrics are written here as JSON.

        Notes
        -----
        On ``ProcessKilledException`` a checkpoint is saved and the process exits with status -1.
        """
        self.logger.info(f"Local rank: {communication.get_local_rank()}.")
        self.models_training_mode()

        loss_fns = self.build_loss()
        metric_fns = self.build_metrics(self.cfg.training.metrics)
        storage = get_event_storage()

        total_iter = self.cfg.training.num_iterations  # noqa
        for data, iter_idx in zip(data_loader, range(start_iter, total_iter)):
            data = AddNames()(data)
            if iter_idx == start_iter:
                self.ndim = self.compute_dimensionality_from_sample(data)
                self.logger.info(f"Data dimensionality: {self.ndim}.")

            if iter_idx == 0:
                self.log_first_training_example(data)

            try:
                output, loss_dict = self._do_iteration(data, loss_fns)
            except ProcessKilledException as e:
                # If the process is killed, the output is saved at state iter_idx, which is the current state,
                # so the computation can restart from the last iteration.
                self.logger.exception(f"Exiting with exception: {e}.")
                self.checkpointer.save(iter_idx)  # Save checkpoint at kill. # noqa
                # TODO: This causes the issue that current metrics are not written,
                # and you end up with an empty line.
                self.write_to_logs()
                sys.exit(-1)

            # Gradient accumulation: gradients keep accumulating across iterations and the optimizer only
            # steps every `gradient_steps` iterations.
            if (iter_idx + 1) % self.cfg.training.gradient_steps == 0:  # type: ignore
                if self.cfg.training.gradient_steps > 1:  # type: ignore
                    for parameter in self.model.parameters():
                        if parameter.grad is not None:
                            # In-place division, turning the accumulated sum into an average.
                            parameter.grad.div_(self.cfg.training.gradient_steps)  # type: ignore
                if self.cfg.training.gradient_clipping > 0.0:  # type: ignore
                    self._scaler.unscale_(self.__optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.cfg.training.gradient_clipping)

                # Gradient norm
                if self.cfg.training.gradient_debug:  # type: ignore
                    warnings.warn(
                        "Gradient debug set. This will affect training performance. Only use for debugging. "
                        "This message will only be displayed once.")
                    # Global L2 norm: reduce each gradient to its scalar squared norm first, since
                    # gradients of different parameters have different shapes and cannot be added.
                    squared_norms = [
                        parameter.grad.detach().pow(2).sum()
                        for parameter in self.model.parameters()
                        if parameter.grad is not None
                    ]
                    if squared_norms:
                        gradient_norm = torch.stack(squared_norms).sum().sqrt()
                        storage.add_scalar("train/gradient_norm", gradient_norm)

                # Same as self.__optimizer.step() for mixed precision.
                self._scaler.step(self.__optimizer)
                # Updates the scale for next iteration.
                self._scaler.update()
                # Only clear the gradients once they have been applied; clearing them on every
                # iteration would throw away the gradients that are being accumulated.
                self.__optimizer.zero_grad()  # type: ignore

            # Incorrect inference by mypy and pyflake
            self.__lr_scheduler.step()  # type: ignore # noqa
            storage.add_scalar("lr",
                               self.__optimizer.param_groups[0]["lr"],
                               smoothing_hint=False)

            # Reduce the loss over all devices
            loss_dict_reduced = communication.reduce_tensor_dict(loss_dict)
            loss_reduced = sum(loss_dict_reduced.values())

            metrics_dict = evaluate_dict(
                metric_fns,
                transforms.modulus_if_complex(output.detach()).rename(None),
                data["target"].rename(None).detach().to(self.device),
                reduction="mean",
            )
            metrics_dict_reduced = (
                communication.reduce_tensor_dict(metrics_dict)
                if metrics_dict else {})
            storage.add_scalars(loss=loss_reduced,
                                **loss_dict_reduced,
                                **metrics_dict_reduced)

            is_last_iteration = (iter_idx + 1) == total_iter

            if iter_idx > 5 and (
                    iter_idx % self.cfg.training.checkpointer.checkpoint_steps
                    == 0 or is_last_iteration):
                self.logger.info(f"Checkpointing at iteration {iter_idx}.")
                self.checkpointer.save(iter_idx)

            if (validation_data_loaders is not None and iter_idx > 5
                    and (iter_idx % self.cfg.training.validation_steps == 0
                         or is_last_iteration)):
                for (
                        curr_dataset_name,
                        curr_validation_data_loader,
                ) in validation_data_loaders:
                    self.logger.info(
                        f"Evaluating: {curr_dataset_name}..."
                    )  # TODO(jt): Fix with better names and stuff.
                    (
                        curr_val_loss_dict,
                        curr_val_metric_dict_per_case,
                        visualize_slices,
                        visualize_target,
                    ) = self.evaluate(
                        curr_validation_data_loader,
                        loss_fns,
                        is_validation_process=True,
                    )

                    if experiment_directory:
                        # Make dictionary serializable for logging
                        serializable_val_metric_dict = {
                            k0: {k1: float(v1)
                                 for k1, v1 in v0.items()}
                            for k0, v0 in
                            curr_val_metric_dict_per_case.items()
                        }
                        write_json(
                            experiment_directory /
                            f"metrics_val_{curr_dataset_name}_{iter_idx}.json",
                            serializable_val_metric_dict,
                        )

                    # Metric dict still needs to be reduced as it gives values *per* data
                    curr_val_metric_dict = reduce_list_of_dicts(
                        list(curr_val_metric_dict_per_case.values()),
                        mode="average")

                    key_prefix = ("val/" if not curr_dataset_name else
                                  f"val/{curr_dataset_name}/")
                    val_loss_reduced = sum(curr_val_loss_dict.values())
                    storage.add_scalars(
                        **{key_prefix + "loss": val_loss_reduced},
                        **{
                            **prefix_dict_keys(curr_val_metric_dict, key_prefix),
                            **prefix_dict_keys(curr_val_loss_dict, key_prefix),
                        },
                        smoothing_hint=False,
                    )
                    visualize_slices = self.process_slices_for_visualization(
                        visualize_slices, visualize_target)
                    storage.add_image(f"{key_prefix}prediction",
                                      visualize_slices)

                    # The target images are only logged while this quotient equals 1.
                    if iter_idx // self.cfg.training.validation_steps - 1 == 0:
                        visualize_target = make_grid(
                            crop_to_largest(visualize_target, pad_value=0),
                            nrow=self.cfg.tensorboard.num_images,
                            scale_each=True,
                        )
                        storage.add_image(f"{key_prefix}target",
                                          visualize_target)

                    # Logged per dataset, inside the loop, so every dataset is reported and the
                    # name is always bound (also when the list of loaders is empty).
                    self.logger.info(
                        f"Done evaluation of {curr_dataset_name} at iteration {iter_idx}."
                    )
                # Evaluation leaves the model in another mode; switch it back to training mode.
                self.model.train()

            # Log every 20 iterations, or at a validation step or at the end of training.
            if iter_idx > 5 and (iter_idx % 20 == 0 or iter_idx %
                                 self.cfg.training.validation_steps == 0 or
                                 is_last_iteration):
                self.write_to_logs()

            storage.step()
Exemple #3
0
    def _do_iteration(
        self,
        data: Dict[str, torch.Tensor],
        loss_fns: Optional[Dict[str, Callable]] = None,
        regularizer_fns: Optional[Dict[str, Callable]] = None,
    ) -> namedtuple:
        """Run every RIM step on one batch and collect the losses and regularizers.

        Parameters
        ----------
        data : Dict[str, torch.Tensor]
            Batch dictionary; it is moved to ``self.device`` and forwarded to the model, the losses and
            the regularizers as keyword arguments. The keys "sensitivity_map" and "target" are read
            here, and "scaling_factor" too when ``cfg.model.scale_loglikelihood`` is set.
        loss_fns : Optional[Dict[str, Callable]]
            Losses called as ``fn(prediction, **data, reduction="mean")``. May be ``None``.
        regularizer_fns : Optional[Dict[str, Callable]]
            Regularizers called as ``fn(prediction, **data)``. May be ``None``.

        Returns
        -------
        namedtuple
            Fields ``output_image`` (last reconstruction of the final step), ``sensitivity_map`` (the
            normalized map) and ``data_dict`` (detached losses and regularizers merged together).

        Raises
        ------
        NotImplementedError
            When a "noise_model" is registered in ``self.models``.
        """
        # Either set of callables may be absent, e.g. during validation.
        loss_fns = {} if loss_fns is None else loss_fns
        regularizer_fns = {} if regularizer_fns is None else regularizer_fns

        # The first step starts without an input image and without a hidden state.
        input_image = None
        hidden_state = None
        output_image = None
        step_losses = []
        step_regularizers = []

        data = dict_to_device(data, self.device)
        # TODO(jt): keys=['sampling_mask', 'sensitivity_map', 'target', 'masked_kspace', 'scaling_factor']
        sensitivity_map = data["sensitivity_map"]

        if "noise_model" in self.models:
            raise NotImplementedError()

        # Some things can be done with the sensitivity map here, e.g. apply a u-net
        if "sensitivity_model" in self.models:
            # Move channels to first axis
            channel_first_map = sensitivity_map.align_to(
                *self.complex_names(add_coil=True))
            # Output has channel first, it is ("batch, "coil", "complex", ...)
            refined_map = self.compute_model_per_coil(
                "sensitivity_model",
                channel_first_map).refine_names(*channel_first_map.names)
            sensitivity_map = refined_map.align_to(
                *self.complex_names_complex_last(add_coil=True))

        # The sensitivity map needs to be normalized such that
        # So \sum_{i \in \text{coils}} S_i S_i^* = 1
        squared_map = sensitivity_map**2
        sensitivity_map_norm = torch.sqrt(
            squared_map.sum("complex").sum("coil"))
        data["sensitivity_map"] = T.safe_divide(sensitivity_map,
                                                sensitivity_map_norm)

        if self.cfg.model.scale_loglikelihood:
            scaling_factor = (1.0 * self.cfg.model.scale_loglikelihood /
                              (data["scaling_factor"]**2))
            scaling_factor = scaling_factor.reshape(-1, 1).refine_names(
                "batch", "complex")
            self.logger.debug(f"Scaling factor is: {scaling_factor}")
        else:
            # Needs fixing.
            unit_scale = torch.tensor([1.0]).to(sensitivity_map.device)
            scaling_factor = unit_scale.refine_names("complex")

        def _zero_accumulators(fns):
            # One single-element zero tensor per entry, with the dtype of the target, on the device.
            return {
                name: torch.tensor([0.0],
                                   dtype=data["target"].dtype).to(self.device)
                for name in fns.keys()
            }

        for _ in range(self.cfg.model.steps):
            with autocast(enabled=self.mixed_precision):
                reconstruction_iter, hidden_state = self.model(
                    **data,
                    input_image=input_image,
                    hidden_state=hidden_state,
                    loglikelihood_scaling=scaling_factor,
                )
                # TODO: Unclear why this refining is needed.
                output_image = reconstruction_iter[-1].refine_names(
                    *self.complex_names())

                loss_dict = _zero_accumulators(loss_fns)
                regularizer_dict = _zero_accumulators(regularizer_fns)

                # TODO: This seems too similar not to be able to do this, perhaps a partial can help here
                for prediction in reconstruction_iter:
                    for name in loss_dict:
                        loss_dict[name] = loss_dict[name] + loss_fns[name](
                            prediction,
                            **data,
                            reduction="mean",
                        )
                    for name in regularizer_dict:
                        regularizer_dict[name] = (
                            regularizer_dict[name] + regularizer_fns[name](
                                prediction,
                                **data,
                            ).rename(None))

                # Average over the intermediate reconstructions of this step.
                num_reconstructions = len(reconstruction_iter)
                loss_dict = {
                    name: value / num_reconstructions
                    for name, value in loss_dict.items()
                }
                regularizer_dict = {
                    name: value / num_reconstructions
                    for name, value in regularizer_dict.items()
                }

                loss = sum(loss_dict.values()) + sum(regularizer_dict.values())

            if self.model.training:
                self._scaler.scale(loss).backward()

            # Detach hidden state from computation graph, to ensure loss is only computed per RIM block.
            hidden_state = hidden_state.detach()
            input_image = output_image.detach()

            # Need to detach the dicts as they are only used for logging.
            step_losses.append(detach_dict(loss_dict))
            step_regularizers.append(detach_dict(regularizer_dict))

        # Add the dicts together over RIM steps, divide by the number of steps.
        loss_dict = reduce_list_of_dicts(step_losses,
                                         mode="sum",
                                         divisor=self.cfg.model.steps)
        regularizer_dict = reduce_list_of_dicts(step_regularizers,
                                                mode="sum",
                                                divisor=self.cfg.model.steps)

        result_type = namedtuple(
            "do_iteration",
            ["output_image", "sensitivity_map", "data_dict"],
        )
        merged_dict = {**loss_dict, **regularizer_dict}
        return result_type(
            output_image=output_image,
            sensitivity_map=data["sensitivity_map"],
            data_dict=merged_dict,
        )
Exemple #4
0
    def evaluate(
        self,
        data_loader: DataLoader,
        loss_fns: Optional[Dict[str, Callable]],
        regularizer_fns: Optional[Dict[str, Callable]] = None,
        crop: Optional[str] = None,
        is_validation_process=True,
    ):
        """Reconstruct all slices of ``data_loader`` and aggregate them per volume.

        Parameters
        ----------
        data_loader : DataLoader
            Must yield batches whose "filename" entries all belong to one case, volume after volume.
        loss_fns : Optional[Dict[str, Callable]]
            Losses forwarded to ``_do_iteration``.
        regularizer_fns : Optional[Dict[str, Callable]]
            Regularizers forwarded to ``_do_iteration``.
        crop : Optional[str]
            Not used in this method; the crop is taken from ``cfg.validation.crop``.
        is_validation_process : bool
            When True, targets are processed too and volume metrics and visualizations are computed.

        Returns
        -------
        tuple
            ``(loss_dict, all_gathered_metrics, visualize_slices, visualize_target)`` when
            ``is_validation_process`` is True, otherwise ``(loss_dict, reconstruction_output)``.

        Raises
        ------
        ValueError
            When a batch contains slices of more than one filename.
        """
        self.models_to_device()
        self.models_validation_mode()
        torch.cuda.empty_cache()

        # Variables required for evaluation.
        # TODO(jt): Consider if this needs to be in the main engine.py or here. Might be possible we have different
        # types needed, perhaps even a FastMRI engine or something similar depending on the metrics.
        volume_metrics = self.build_metrics(self.cfg.validation.metrics)

        # filenames can be in the volume_indices attribute of the dataset
        if hasattr(data_loader.dataset, "volume_indices"):
            all_filenames = list(data_loader.dataset.volume_indices.keys())
            num_for_this_process = len(
                list(data_loader.batch_sampler.sampler.volume_indices.keys()))
            self.logger.info(
                f"Reconstructing a total of {len(all_filenames)} volumes. "
                f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})."
            )
        else:
            # Keep the name bound: it selects the progress message format further down.
            all_filenames = []
            num_for_this_process = None
        filenames_seen = 0

        reconstruction_output = defaultdict(list)
        targets_output = defaultdict(list)
        val_losses = []
        val_volume_metrics = defaultdict(dict)
        last_filename = None

        # Containers for the slices which can be visualized in TensorBoard.
        visualize_slices = []
        visualize_target = []

        extra_visualization_keys = (self.cfg.logging.log_as_image
                                    if self.cfg.logging.log_as_image else [])

        # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
        # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
        # that the slices are outputted from the Dataset *sequentially* for each volume one by one.
        time_start = time.time()

        for iter_idx, data in enumerate(data_loader):
            data = AddNames()(data)
            filenames = data.pop("filename")
            if len(set(filenames)) != 1:
                raise ValueError(
                    f"Expected a batch during validation to only contain filenames of one case. "
                    f"Got {set(filenames)}.")

            slice_nos = data.pop("slice_no")
            scaling_factors = data["scaling_factor"]

            resolution = self.compute_resolution(
                key=self.cfg.validation.crop,
                reconstruction_size=data.get("reconstruction_size", None),
            )

            # Compute output and loss.
            iteration_output = self._do_iteration(
                data, loss_fns, regularizer_fns=regularizer_fns)
            output = iteration_output.output_image
            loss_dict = iteration_output.data_dict
            # sensitivity_map = iteration_output.sensitivity_map

            loss_dict = detach_dict(loss_dict)
            output = output.detach()
            val_losses.append(loss_dict)

            # Output is complex-valued, and has to be cropped. This holds for both output and target.
            output_abs = self.process_output(
                output.refine_names(*self.complex_names()),
                scaling_factors,
                resolution=resolution,
            )

            if is_validation_process:
                target_abs = self.process_output(
                    data["target"].detach().refine_names(*self.real_names()),
                    scaling_factors,
                    resolution=resolution,
                )
                for key in extra_visualization_keys:
                    # The result is not used yet.
                    curr_data = data[key].detach()
                    # Here we need to discover which keys are actually normalized or not
                    # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23

            del output  # Explicitly call delete to clear memory.
            # TODO: Is a hack.

            # Aggregate volumes to be able to compute the metrics on complete volumes.
            for idx, filename in enumerate(filenames):
                if last_filename is None:
                    # First iteration last_filename is not set.
                    last_filename = filename

                # If the new filename is not the previous one, then we can reconstruct the volume as the sampling
                # is linear.
                # For the last case we need to check if we are at the last batch *and* at the last element in the batch.
                is_last_element_of_last_batch = iter_idx + 1 == len(
                    data_loader) and idx + 1 == len(data["target"])
                if filename != last_filename or is_last_element_of_last_batch:
                    # NOTE(review): the current slice is only appended after this branch, so on the very
                    # last element the volume stacked here does not contain it — confirm this is intended.
                    filenames_seen += 1
                    # Now we can ditch the reconstruction dict by reconstructing the volume,
                    # will take too much memory otherwise.
                    # TODO: Stack does not support named tensors.
                    volume = torch.stack([
                        slice_tensor.rename(None)
                        for _, slice_tensor in reconstruction_output[last_filename]
                    ])
                    if is_validation_process:
                        target = torch.stack([
                            slice_tensor.rename(None)
                            for _, slice_tensor in targets_output[last_filename]
                        ])
                        curr_metrics = {
                            metric_name: metric_fn(target, volume)
                            for metric_name, metric_fn in
                            volume_metrics.items()
                        }
                        val_volume_metrics[last_filename] = curr_metrics
                        # Log the center slice of the volume
                        if (len(visualize_slices) <
                                self.cfg.logging.tensorboard.num_images):
                            visualize_slices.append(volume[volume.shape[0] //
                                                           2])
                            visualize_target.append(target[target.shape[0] //
                                                           2])

                        # Delete outputs from memory, and recreate dictionary. This is not needed when not in validation
                        # as we are actually interested in the output
                        del targets_output
                        targets_output = defaultdict(list)
                        del reconstruction_output
                        reconstruction_output = defaultdict(list)

                    if all_filenames:
                        log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:"
                    else:
                        log_prefix = f"{iter_idx + 1} of {len(data_loader)} slices reconstructed:"

                    self.logger.info(
                        f"{log_prefix} {last_filename}"
                        f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s."
                    )
                    # restart timer
                    time_start = time.time()
                    last_filename = filename

                curr_slice = output_abs[idx].detach()
                slice_no = int(slice_nos[idx].numpy())

                # TODO: CPU?
                reconstruction_output[filename].append(
                    (slice_no, curr_slice.cpu()))

                if is_validation_process:
                    targets_output[filename].append(
                        (slice_no, target_abs[idx].cpu()))

        # Average loss dict
        loss_dict = reduce_list_of_dicts(val_losses)
        # NOTE(review): the return value is discarded here; if reduce_tensor_dict returns a new dict
        # rather than reducing in place, the reduced values are lost — confirm.
        reduce_tensor_dict(loss_dict)

        communication.synchronize()
        torch.cuda.empty_cache()

        # TODO: Does not work yet with normal gather.
        all_gathered_metrics = merge_list_of_dicts(
            communication.all_gather(val_volume_metrics))
        if not is_validation_process:
            return loss_dict, reconstruction_output

        # TODO: Apply named tuples where applicable
        # TODO: Several functions have multiple output values, in many cases
        # TODO: it would be more convenient to convert this to namedtuples.
        return loss_dict, all_gathered_metrics, visualize_slices, visualize_target
Exemple #5
0
    def _do_iteration(
            self, data: Dict[str, torch.Tensor],
            loss_fns: Optional[Dict[str,
                                    Callable]]) -> Tuple[torch.Tensor, Dict]:
        """Run every RIM step on one batch and collect the per-step losses.

        Parameters
        ----------
        data : Dict[str, torch.Tensor]
            Batch dictionary. "target", "masked_image" and "sensitivity_map" are read here;
            "masked_image" is popped from the caller's dictionary. The remaining entries are moved to
            ``self.device`` and forwarded to the model as keyword arguments.
        loss_fns : Optional[Dict[str, Callable]]
            Losses called as ``fn(prediction, target, reduction="mean")``. May be ``None``.

        Returns
        -------
        Tuple[torch.Tensor, Dict]
            The last reconstruction of the final RIM step and the detached losses combined over steps.
        """
        # No losses may be given, e.g. during validation.
        loss_fns = {} if loss_fns is None else loss_fns

        # TODO(jt): Target is not needed in the model input, but in the loss computation. Keep it here for now.
        aligned_target = data["target"].align_to(*self.complex_names)
        target = aligned_target.to(self.device)  # type: ignore
        # The first input_image in the iteration is the input_image with the mask applied and no first hidden state.
        masked_image = data.pop("masked_image")
        input_image = masked_image.to(self.device)  # type: ignore
        hidden_state = None
        output_image = None
        step_losses = []

        # TODO: Target might not need to be copied.
        data = dict_to_device(data, self.device)
        # TODO(jt): keys=['sampling_mask', 'sensitivity_map', 'target', 'masked_kspace', 'scaling_factor']

        sensitivity_map = data["sensitivity_map"]
        # Some things can be done with the sensitivity map here, e.g. apply a u-net
        if "sensitivity_model" in self.models:
            sensitivity_model = self.models["sensitivity_model"]
            sensitivity_map = self.compute_model_per_coil(
                sensitivity_model, sensitivity_map)

        # The sensitivity map needs to be normalized such that
        # So \sum_{i \in \text{coils}} S_i S_i^* = 1
        sensitivity_map_norm = modulus(sensitivity_map).sum("coil")
        data["sensitivity_map"] = safe_divide(sensitivity_map,
                                              sensitivity_map_norm)

        for _ in range(self.cfg.model.steps):
            with autocast(enabled=self.mixed_precision):
                reconstruction_iter, hidden_state = self.model(
                    **data,
                    input_image=input_image,
                    hidden_state=hidden_state,
                )
                # TODO: Unclear why this refining is needed.
                final_reconstruction = reconstruction_iter[-1]
                output_image = final_reconstruction.refine_names(
                    *self.complex_names)

                # One single-element zero accumulator per loss.
                loss_dict = {
                    name: torch.tensor([0.0],
                                       dtype=target.dtype).to(self.device)
                    for name in loss_fns.keys()
                }
                for prediction in reconstruction_iter:
                    for name in loss_dict:
                        loss_dict[name] = loss_dict[name] + loss_fns[name](
                            prediction,
                            target,
                            reduction="mean",
                        )

                # Average over the intermediate reconstructions of this step.
                num_reconstructions = len(reconstruction_iter)
                loss_dict = {
                    name: value / num_reconstructions
                    for name, value in loss_dict.items()
                }
                loss = sum(loss_dict.values())

            if self.model.training:
                self._scaler.scale(loss).backward()

            # Detach hidden state from computation graph, to ensure loss is only computed per RIM block.
            hidden_state = hidden_state.detach()
            input_image = output_image.detach()

            # Need to detach dict as this is only used for logging.
            step_losses.append(detach_dict(loss_dict))

        # Add the loss dicts together over RIM steps, divide by the number of steps.
        loss_dict = reduce_list_of_dicts(step_losses,
                                         mode="sum",
                                         divisor=self.cfg.model.steps)
        return output_image, loss_dict
Exemple #6
0
    def evaluate(
        self,
        data_loader: DataLoader,
        loss_fns: Optional[Dict[str, Callable]],
        crop: Optional[str] = None,
        is_validation_process=True,
    ):

        # TODO(jt): Also log other models output (e.g. sensitivity map).
        # TODO(jt): This can be simplified as the sampler now only outputs batches belonging to the same volume.
        self.models_to_device()
        self.models_validation_mode()
        torch.cuda.empty_cache()

        # Variables required for evaluation.
        # TODO(jt): Consider if this needs to be in the main engine.py or here. Might be possible we have different
        # types needed, perhaps even a FastMRI engine or something similar depending on the metrics.
        volume_metrics = self.build_metrics(self.cfg.validation.metrics)

        reconstruction_output = defaultdict(list)
        targets_output = defaultdict(list)
        val_losses = []
        val_volume_metrics = defaultdict(dict)
        last_filename = None

        # Container to for the slices which can be visualized in TensorBoard.
        visualize_slices = []
        visualize_target = []

        # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
        # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
        # that the slices are outputted from the Dataset *sequentially* for each volume one by one.
        for iter_idx, data in enumerate(data_loader):
            self.log_process(iter_idx, len(data_loader))
            data = AddNames()(data)
            filenames = data.pop("filename")
            if len(set(filenames)) != 1:
                raise ValueError(
                    f"Expected a batch during validation to only contain filenames of one case. "
                    f"Got {set(filenames)}.")

            slice_nos = data.pop("slice_no")
            scaling_factors = data.pop("scaling_factor")

            # Check if reconstruction size is the data
            if self.cfg.validation.crop == "header":
                # This will be of the form [tensor(x_0, x_1, ...), tensor(y_0, y_1,...), tensor(z_0, z_1, ...)] over
                # batches.
                resolution = [
                    _.cpu().numpy().tolist()
                    for _ in data["reconstruction_size"]
                ]
                # The volume sampler should give validation indices belonging to the *same* volume, so it should be
                # safe taking the first element, the matrix size are in x,y,z (we work in z,x,y).
                resolution = [_[0] for _ in resolution][:-1]
            elif self.cfg.validation.crop == "training":
                resolution = self.cfg.training.loss.crop
            elif not self.cfg.validation.loss.crop:
                resolution = None
            else:
                raise ValueError(
                    f"Cropping should be either set to `header` to get the values from the header or "
                    f"`training` to take the same value as training.")

            # Compute output and loss.
            output, loss_dict = self._do_iteration(data, loss_fns)
            val_losses.append(loss_dict)

            # Output is complex-valued, and has to be cropped. This holds for both output and target.
            output_abs = self.process_output(
                output.refine_names(*self.complex_names).detach(),
                scaling_factors,
                resolution=resolution,
            )

            if is_validation_process:
                target_abs = self.process_output(
                    data["target"].refine_names(*self.real_names).detach(),
                    scaling_factors,
                    resolution=resolution,
                )
            del output  # Explicitly call delete to clear memory.
            # TODO: Is a hack.

            # Aggregate volumes to be able to compute the metrics on complete volumes.
            for idx, filename in enumerate(filenames):
                if last_filename is None:
                    last_filename = (
                        filename  # First iteration last_filename is not set.
                    )
                # If the new filename is not the previous one, then we can reconstruct the volume as the sampling
                # is linear.
                # For the last case we need to check if we are at the last batch *and* at the last element in the batch.
                if filename != last_filename or (
                        iter_idx + 1 == len(data_loader)
                        and idx + 1 == len(data["target"])):
                    # Now we can ditch the reconstruction dict by reconstructing the volume,
                    # will take too much memory otherwise.
                    # TODO: Stack does not support named tensors.
                    volume = torch.stack([
                        _[1].rename(None)
                        for _ in reconstruction_output[last_filename]
                    ])
                    self.logger.info(
                        f"Reconstructed {last_filename} (shape = {list(volume.shape)})."
                    )
                    if is_validation_process:
                        target = torch.stack([
                            _[1].rename(None)
                            for _ in targets_output[last_filename]
                        ])
                        curr_metrics = {
                            metric_name: metric_fn(volume, target)
                            for metric_name, metric_fn in
                            volume_metrics.items()
                        }
                        val_volume_metrics[last_filename] = curr_metrics
                        # Log the center slice of the volume
                        if len(visualize_slices
                               ) < self.cfg.tensorboard.num_images:
                            visualize_slices.append(
                                normalize_image(volume[volume.shape[0] // 2]))
                            visualize_target.append(
                                normalize_image(target[target.shape[0] // 2]))

                        # Delete outputs from memory, and recreate dictionary. This is not needed when not in validation
                        # as we are actually interested in the output
                        del targets_output
                        targets_output = defaultdict(list)
                        del reconstruction_output
                        reconstruction_output = defaultdict(list)

                    last_filename = filename

                curr_slice = output_abs[idx]
                slice_no = int(slice_nos[idx].numpy())

                # TODO: CPU?
                reconstruction_output[filename].append(
                    (slice_no, curr_slice.cpu()))

                if is_validation_process:
                    targets_output[filename].append(
                        (slice_no, target_abs[idx].cpu()))

        # Average loss dict
        loss_dict = reduce_list_of_dicts(val_losses)
        reduce_tensor_dict(loss_dict)

        communication.synchronize()
        torch.cuda.empty_cache()

        # TODO(jt): Does not work yet with normal gather.
        all_gathered_metrics = merge_list_of_dicts(
            communication.all_gather(val_volume_metrics))

        if not is_validation_process:
            return loss_dict, reconstruction_output

        # TODO(jt): Make named tuple
        return loss_dict, all_gathered_metrics, visualize_slices, visualize_target
Exemple #7
0
    def validation_loop(
        self,
        validation_datasets,
        loss_fns,
        experiment_directory,
        iter_idx,
        num_workers: int = 6,
    ):
        """Evaluate each validation dataset and publish the results to the event storage.

        For every ``(name, loader)`` pair produced by ``self.build_validation_loaders`` the
        method runs ``self.evaluate``, optionally dumps the per-case metrics to a JSON file,
        and logs the averaged metrics, the individual losses, their sum and a prediction
        image. The model is switched back to training mode once all datasets are processed.

        Parameters
        ----------
        validation_datasets
            Forwarded to ``self.build_validation_loaders``. When falsy the method returns
            immediately without touching the model or the storage.
        loss_fns
            Forwarded unchanged to ``self.evaluate``.
        experiment_directory
            When truthy, per-case metrics are written below this directory. It must support
            the ``/`` operator and expose ``.parent.mkdir`` (presumably a ``pathlib.Path`` --
            confirm against the caller).
        iter_idx
            Current iteration; used in the JSON filename, in the log message, and to decide
            whether the target images are logged.
        num_workers : int
            Forwarded to ``self.build_validation_loaders``.

        Returns
        -------
        None
        """
        if not validation_datasets:
            return

        storage = get_event_storage()
        loaders = self.build_validation_loaders(
            validation_data=validation_datasets, num_workers=num_workers
        )

        for dataset_name, loader in loaders:
            self.logger.info(f"Evaluating: {dataset_name}...")
            losses, metrics_per_case, predictions, targets = self.evaluate(
                loader, loss_fns, is_validation_process=True
            )

            if experiment_directory:
                metrics_path = experiment_directory / f"metrics_val_{dataset_name}_{iter_idx}.json"
                # A "/" inside the dataset name ends up as a subfolder, so create it first.
                metrics_path.parent.mkdir(parents=True, exist_ok=True)
                # Only the main process writes the file; the message below is logged by all.
                if communication.is_main_process():
                    write_json(metrics_path, metrics_per_case)
                self.logger.info(f"Wrote per image logs to: {metrics_path}.")

            # The metrics are stored per case, so they still have to be averaged.
            averaged_metrics = reduce_list_of_dicts(
                list(metrics_per_case.values()), mode="average"
            )

            key_prefix = f"val/{dataset_name}/" if dataset_name else "val/"
            total_loss = sum(losses.values())
            prefixed_scalars = {
                **prefix_dict_keys(averaged_metrics, key_prefix),
                **prefix_dict_keys(losses, key_prefix),
            }
            storage.add_scalars(
                **{key_prefix + "loss": total_loss},
                **prefixed_scalars,
                smoothing_hint=False,
            )

            prediction_image = self.process_slices_for_visualization(predictions, targets)
            storage.add_image(f"{key_prefix}prediction", prediction_image)

            # The targets are only logged when this quotient equals one.
            if iter_idx // self.cfg.training.validation_steps == 1:
                target_grid = make_grid(
                    crop_to_largest(targets, pad_value=0),
                    nrow=self.cfg.logging.tensorboard.num_images,
                    scale_each=True,
                )
                storage.add_image(f"{key_prefix}target", target_grid)

            self.logger.info(
                f"Done evaluation of {dataset_name} at iteration {iter_idx}."
            )

        self.model.train()
Exemple #8
0
    def evaluate(self,
                 data_loader: DataLoader,
                 loss_fns: Dict[str, Callable],
                 volume_metrics: Optional[Dict[str, Callable]] = None,
                 evaluation_round=0):
        """Run the model over `data_loader`, score complete volumes and log example images.

        Slices are buffered per filename. A volume is stacked and scored as soon as a slice
        of a different filename arrives, and the volume that is still buffered when the
        loader is exhausted is scored afterwards. This relies on the loader yielding the
        slices of each volume consecutively and in order; the buffered slices are not
        re-sorted by slice number.

        Parameters
        ----------
        data_loader : DataLoader
            Yields batch dictionaries that contain at least the keys 'filename', 'slice_no',
            'scaling_factor' and 'target'.
        loss_fns : Dict[str, Callable]
            Forwarded to ``self._do_iteration`` together with each batch.
        volume_metrics : Optional[Dict[str, Callable]]
            Maps a metric name to a callable invoked as ``metric_fn(volume, target)``.
            Defaults to ``self.build_metrics()`` when None.
        evaluation_round : int
            Target images are only collected and logged when this equals 0.

        Returns
        -------
        The result of ``reduce_list_of_dicts`` over the per-batch loss dictionaries, after it
        has been passed to ``reduce_tensor_dict`` (whose return value is ignored, so it
        presumably reduces in place -- TODO confirm).
        """
        self.logger.info('Evaluating...')
        self.model.eval()
        torch.cuda.empty_cache()

        # Variables required for evaluation.
        volume_metrics = volume_metrics if volume_metrics is not None else self.build_metrics()
        storage = get_event_storage()

        reconstruction_output = defaultdict(list)
        targets_output = defaultdict(list)
        val_losses = []
        # NOTE(review): filled per volume but neither returned nor logged by this method.
        val_volume_metrics = defaultdict(dict)
        last_filename = None

        # Containers for the slices which can be visualized in TensorBoard.
        visualize_slices = []
        visualize_target = []

        def _finish_volume(name):
            """Stack the buffered slices of `name`, compute its metrics and drop the buffers."""
            # Popping releases the slice buffers, which would take too much memory otherwise.
            # TODO: Stack does not support named tensors.
            volume = torch.stack([_[1].rename(None) for _ in reconstruction_output.pop(name)])
            target = torch.stack([_[1].rename(None) for _ in targets_output.pop(name)])
            self.logger.info(f'Reconstructed {name} (shape = {list(volume.shape)}).')
            val_volume_metrics[name] = {
                metric_name: metric_fn(volume, target) for metric_name, metric_fn in volume_metrics.items()}

            # Log the center slice of the volume.
            if len(visualize_slices) < self.cfg.tensorboard.num_images:
                visualize_slices.append(normalize_image(volume[volume.shape[0] // 2]))
                # Target only needs to be logged once.
                if evaluation_round == 0:
                    visualize_target.append(normalize_image(target[target.shape[0] // 2]))

        # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
        # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
        # that the slices are outputted from the Dataset *sequentially* for each volume one by one.
        for iter_idx, data in enumerate(data_loader):
            self.log_process(iter_idx, len(data_loader))
            data = AddNames()(data)
            filenames = data.pop('filename')
            slice_nos = data.pop('slice_no')
            scaling_factors = data.pop('scaling_factor')

            # Compute output and loss.
            output, loss_dict = self._do_iteration(data, loss_fns)
            val_losses.append(loss_dict)

            # Output is complex-valued, and has to be cropped. This holds for both output and target.
            # NOTE(review): 320 is hard-coded; presumably the crop resolution -- confirm against process_output.
            output_abs = self.process_output(
                output.refine_names('batch', 'complex', 'height', 'width').detach(), scaling_factors, 320)
            target_abs = self.process_output(
                data['target'].refine_names('batch', 'height', 'width').detach(), scaling_factors, 320)
            del output  # Explicitly call delete to clear memory.

            # Aggregate volumes to be able to compute the metrics on complete volumes.
            for idx, filename in enumerate(filenames):
                # A new filename means the previous volume is complete, as the sampling is linear.
                if last_filename is not None and filename != last_filename:
                    _finish_volume(last_filename)
                last_filename = filename

                # Buffer the slice *before* any later flush so that no slice is lost.
                slice_no = int(slice_nos[idx].numpy())

                # TODO: CPU?
                reconstruction_output[filename].append((slice_no, output_abs[idx].cpu()))
                targets_output[filename].append((slice_no, target_abs[idx].cpu()))

        # The final volume has no successor that could trigger its reconstruction, so it is
        # finished here, after its last slice has been buffered.
        if last_filename is not None:
            _finish_volume(last_filename)

        # Average loss dict
        loss_dict = reduce_list_of_dicts(val_losses)
        reduce_tensor_dict(loss_dict)

        # Log slices; skipped when nothing was collected.
        if visualize_slices:
            prediction_grid = make_grid(visualize_slices, nrow=4, scale_each=True)
            storage.add_image('validation/prediction', prediction_grid)

        if evaluation_round == 0 and visualize_target:
            target_grid = make_grid(visualize_target, nrow=4, scale_each=True)
            storage.add_image('validation/target', target_grid)

        communication.synchronize()
        torch.cuda.empty_cache()

        return loss_dict