예제 #1
0
    def _do_iteration(self,
                      data: Dict[str, torch.Tensor],
                      loss_fns: Dict[str, Callable]) -> Tuple[torch.Tensor, Dict]:
        """Unroll ``self.cfg.model.steps`` RIM blocks on one batch and compute the losses.

        The first block receives ``data['masked_image']`` and no hidden state. Every later block
        receives the detached output image and hidden state of the previous block, so gradients
        never flow across blocks. When the model is training, the loss of each block is
        back-propagated right away, through ``amp.scale_loss`` when mixed precision is enabled.

        Args:
            data: Batch dictionary. It must contain ``'target'`` and ``'masked_image'``.
                ``'masked_image'`` is popped from this dict in place. The remaining entries
                are moved to ``self.device`` and passed to ``self.model`` as keyword arguments.
            loss_fns: Mapping from loss name to a callable invoked as
                ``fn(prediction, target, reduction='mean')`` on unnamed tensors.

        Returns:
            A tuple ``(output_image, loss_dict)``. ``output_image`` is the last reconstruction
            of the last block, refined to names ``('batch', 'complex', 'height', 'width')``.
            ``loss_dict`` holds the per-name losses summed over RIM steps and divided by the
            number of steps (via ``reduce_list_of_dicts``), detached for logging.
        """
        # Target is not needed in the model input, only in the loss computation.
        target = data['target'].align_to('batch', 'complex', 'height', 'width').to(self.device)  # type: ignore
        # Loss functions work on unnamed tensors; target never changes, so strip the names once.
        unnamed_target = target.rename(None)
        # The first input_image in the iteration is the input_image with the mask applied and no first hidden state.
        input_image = data.pop('masked_image').to(self.device)  # type: ignore
        # `data` is not modified inside the loop, so move it to the device once instead of every step.
        model_inputs = dict_to_device(data, self.device)
        hidden_state = None
        output_image = None
        loss_dicts = []
        for _ in range(self.cfg.model.steps):
            reconstruction_iter, hidden_state = self.model(
                **model_inputs,
                input_image=input_image,
                hidden_state=hidden_state,
            )
            # TODO: Unclear why this refining is needed.
            output_image = reconstruction_iter[-1].refine_names('batch', 'complex', 'height', 'width')

            # Accumulate every loss over all intermediate reconstructions of this block, then average.
            loss_dict = {k: torch.zeros(1, dtype=target.dtype, device=self.device) for k in loss_fns}
            for output_image_iter in reconstruction_iter:
                for k, v in loss_dict.items():
                    loss_dict[k] = v + loss_fns[k](
                        output_image_iter.rename(None), unnamed_target, reduction='mean'
                    )
            loss_dict = {k: v / len(reconstruction_iter) for k, v in loss_dict.items()}
            loss = sum(loss_dict.values())

            if self.model.training:
                if self.mixed_precision:
                    with amp.scale_loss(loss, self.__optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()  # type: ignore

            # Detach hidden state from computation graph, to ensure loss is only computed per RIM block.
            hidden_state = hidden_state.detach()
            input_image = output_image.detach()

            loss_dicts.append(detach_dict(loss_dict))  # Need to detach dict as this is only used for logging.

        # Add the loss dicts together over RIM steps, divide by the number of steps.
        loss_dict = reduce_list_of_dicts(loss_dicts, mode='sum', divisor=self.cfg.model.steps)
        return output_image, loss_dict
예제 #2
0
    def _do_iteration(
        self,
        data: Dict[str, torch.Tensor],
        loss_fns: Optional[Dict[str, Callable]] = None,
        regularizer_fns: Optional[Dict[str, Callable]] = None,
    ) -> namedtuple:
        """Unroll the RIM for ``self.cfg.model.steps`` blocks on a single batch.

        The sensitivity map is optionally refined by the ``"sensitivity_model"`` entry of
        ``self.models``. It is then divided by its root-sum-of-squares over the ``"complex"``
        and ``"coil"`` axes. The first block starts without an input image and without a
        hidden state. Every later block receives the detached output and hidden state of the
        previous one. While training, each block's loss is back-propagated through
        ``self._scaler``.

        Args:
            data: Batch dictionary, moved to ``self.device`` with ``dict_to_device``. It is
                read for ``"sensitivity_map"``, and for ``"scaling_factor"`` when
                ``cfg.model.scale_loglikelihood`` is truthy. ``"target"`` is read (for its
                dtype) only when there is at least one loss or regularizer. Every entry is
                forwarded to the model, loss and regularizer callables as keyword arguments.
            loss_fns: Mapping name -> ``fn(prediction, **data, reduction="mean")``.
                ``None`` is treated as an empty mapping.
            regularizer_fns: Mapping name -> ``fn(prediction, **data)``. Each result has its
                names stripped before accumulation. ``None`` is treated as an empty mapping.

        Returns:
            A ``do_iteration`` named tuple ``(output_image, sensitivity_map, data_dict)``:
            the final reconstruction of the last block, the normalized sensitivity map, and
            the detached losses and regularizers reduced over the RIM steps.

        Raises:
            NotImplementedError: If ``"noise_model"`` is present in ``self.models``.
        """
        loss_fns = {} if loss_fns is None else loss_fns
        regularizer_fns = {} if regularizer_fns is None else regularizer_fns

        data = dict_to_device(data, self.device)
        # TODO(jt): keys=['sampling_mask', 'sensitivity_map', 'target', 'masked_kspace', 'scaling_factor']
        sens_map = data["sensitivity_map"]

        if "noise_model" in self.models:
            raise NotImplementedError()

        if "sensitivity_model" in self.models:
            # The per-coil model gets the map with the channel axes first; its output is
            # named like that input and then reordered so "complex" is the last axis.
            channels_first = sens_map.align_to(*self.complex_names(add_coil=True))
            refined = self.compute_model_per_coil("sensitivity_model", channels_first)
            sens_map = refined.refine_names(*channels_first.names).align_to(
                *self.complex_names_complex_last(add_coil=True))

        # Normalize such that \sum_{i \in \text{coils}} S_i S_i^* = 1.
        root_sum_of_squares = torch.sqrt(((sens_map**2).sum("complex")).sum("coil"))
        data["sensitivity_map"] = T.safe_divide(sens_map, root_sum_of_squares)

        loglikelihood_scale = self.cfg.model.scale_loglikelihood
        if not loglikelihood_scale:
            # Needs fixing.
            scaling_factor = torch.tensor([1.0]).to(sens_map.device).refine_names("complex")
        else:
            raw_factor = 1.0 * loglikelihood_scale / (data["scaling_factor"]**2)
            scaling_factor = raw_factor.reshape(-1, 1).refine_names("batch", "complex")
            self.logger.debug(f"Scaling factor is: {scaling_factor}")

        def zero_accumulators(names):
            # One zero accumulator per name. data["target"] is only touched when `names` is non-empty.
            return {
                name: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device)
                for name in names
            }

        num_steps = self.cfg.model.steps
        image_in = None
        hidden = None
        final_image = None
        step_losses = []
        step_regularizers = []

        for _ in range(num_steps):
            with autocast(enabled=self.mixed_precision):
                iterates, hidden = self.model(
                    **data,
                    input_image=image_in,
                    hidden_state=hidden,
                    loglikelihood_scaling=scaling_factor,
                )
                # TODO: Unclear why this refining is needed.
                final_image = iterates[-1].refine_names(*self.complex_names())

                losses = zero_accumulators(loss_fns)
                regularizers = zero_accumulators(regularizer_fns)

                # TODO: Loss and regularizer accumulation are near-duplicates; a partial could merge them.
                for iterate in iterates:
                    for name in losses:
                        losses[name] = losses[name] + loss_fns[name](
                            iterate,
                            **data,
                            reduction="mean",
                        )
                    for name in regularizers:
                        regularizers[name] = regularizers[name] + regularizer_fns[name](
                            iterate,
                            **data,
                        ).rename(None)

                n_iterates = len(iterates)
                losses = {name: total / n_iterates for name, total in losses.items()}
                regularizers = {name: total / n_iterates for name, total in regularizers.items()}

                loss = sum(losses.values()) + sum(regularizers.values())

            if self.model.training:
                self._scaler.scale(loss).backward()

            # Cut the graph between blocks so each block's loss only reaches that block.
            hidden = hidden.detach()
            image_in = final_image.detach()

            # Detached copies: these are only used for logging.
            step_losses.append(detach_dict(losses))
            step_regularizers.append(detach_dict(regularizers))

        # Sum over RIM steps and divide by the number of steps.
        loss_summary = reduce_list_of_dicts(step_losses, mode="sum", divisor=num_steps)
        regularizer_summary = reduce_list_of_dicts(step_regularizers, mode="sum", divisor=num_steps)

        iteration_output_type = namedtuple(
            "do_iteration",
            ["output_image", "sensitivity_map", "data_dict"],
        )
        return iteration_output_type(
            final_image,
            data["sensitivity_map"],
            {**loss_summary, **regularizer_summary},
        )
예제 #3
0
    def evaluate(
        self,
        data_loader: DataLoader,
        loss_fns: Optional[Dict[str, Callable]],
        regularizer_fns: Optional[Dict[str, Callable]] = None,
        crop: Optional[str] = None,
        is_validation_process=True,
    ):
        """Run the models over ``data_loader`` and reassemble per-volume reconstructions.

        Each batch must hold slices of a single file, and slices must arrive volume by volume
        (see the sampler note below). A volume is stacked once its filename changes, or once
        the loader is exhausted. At that point, in validation mode, its metrics are computed.

        Args:
            data_loader: Loader whose batches each contain slices of exactly one file.
            loss_fns: Loss functions forwarded to ``self._do_iteration``.
            regularizer_fns: Regularizer functions forwarded to ``self._do_iteration``.
            crop: Currently unused. The crop key is read from ``self.cfg.validation.crop``.
            is_validation_process: If True, targets are processed and per-volume metrics and
                TensorBoard slices are collected. Per-volume buffers are then freed after each
                volume. If False, all reconstructions are kept and returned.

        Returns:
            If ``is_validation_process`` is True:
            ``(loss_dict, all_gathered_metrics, visualize_slices, visualize_target)``.
            Otherwise: ``(loss_dict, reconstruction_output)``. ``reconstruction_output`` maps
            each filename to a list of ``(slice_no, slice)`` tuples.

        Raises:
            ValueError: If a batch contains slices from more than one file.
        """
        self.models_to_device()
        self.models_validation_mode()
        torch.cuda.empty_cache()

        # Variables required for evaluation.
        # TODO(jt): Consider if this needs to be in the main engine.py or here. Might be possible we have different
        # types needed, perhaps even a FastMRI engine or something similar depending on the metrics.
        volume_metrics = self.build_metrics(self.cfg.validation.metrics)

        # filenames can be in the volume_indices attribute of the dataset. Without it the number
        # of volumes is unknown, and progress is reported per batch instead.
        all_filenames = None
        num_for_this_process = None
        if hasattr(data_loader.dataset, "volume_indices"):
            all_filenames = list(data_loader.dataset.volume_indices.keys())
            num_for_this_process = len(
                list(data_loader.batch_sampler.sampler.volume_indices.keys()))
            self.logger.info(
                f"Reconstructing a total of {len(all_filenames)} volumes. "
                f"This process has {num_for_this_process} volumes (world size: {communication.get_world_size()})."
            )
        filenames_seen = 0

        reconstruction_output = defaultdict(list)
        targets_output = defaultdict(list)
        val_losses = []
        val_volume_metrics = defaultdict(dict)
        last_filename = None
        batches_done = 0

        # Containers for the slices which can be visualized in TensorBoard.
        visualize_slices = []
        visualize_target = []

        extra_visualization_keys = (self.cfg.logging.log_as_image
                                    if self.cfg.logging.log_as_image else [])

        def _finalize_volume(volume_name, num_batches_done):
            """Stack the buffered slices of ``volume_name``, score the volume and log progress."""
            nonlocal filenames_seen, time_start, reconstruction_output, targets_output
            filenames_seen += 1
            # torch.stack does not support named tensors, hence rename(None).
            volume = torch.stack([
                slice_data.rename(None)
                for _, slice_data in reconstruction_output[volume_name]
            ])
            if is_validation_process:
                target = torch.stack([
                    slice_data.rename(None)
                    for _, slice_data in targets_output[volume_name]
                ])
                val_volume_metrics[volume_name] = {
                    metric_name: metric_fn(target, volume)
                    for metric_name, metric_fn in volume_metrics.items()
                }
                # Log the center slice of the volume
                if len(visualize_slices) < self.cfg.logging.tensorboard.num_images:
                    visualize_slices.append(volume[volume.shape[0] // 2])
                    visualize_target.append(target[target.shape[0] // 2])

                # Free the per-volume buffers. Outside validation the reconstructions are the
                # result that gets returned, so they are kept.
                targets_output = defaultdict(list)
                reconstruction_output = defaultdict(list)

            if all_filenames:
                log_prefix = f"{filenames_seen} of {num_for_this_process} volumes reconstructed:"
            else:
                log_prefix = f"{num_batches_done} of {len(data_loader)} slices reconstructed:"

            self.logger.info(
                f"{log_prefix} {volume_name}"
                f" (shape = {list(volume.shape)}) in {time.time() - time_start:.3f}s."
            )
            # restart timer
            time_start = time.time()

        # Loop over dataset. This requires the use of direct.data.sampler.DistributedSequentialSampler as this sampler
        # splits the data over the different processes, and outputs the slices linearly. The implicit assumption here is
        # that the slices are outputted from the Dataset *sequentially* for each volume one by one.
        time_start = time.time()

        for iter_idx, data in enumerate(data_loader):
            batches_done = iter_idx + 1
            data = AddNames()(data)
            filenames = data.pop("filename")
            if len(set(filenames)) != 1:
                raise ValueError(
                    f"Expected a batch during validation to only contain filenames of one case. "
                    f"Got {set(filenames)}.")

            slice_nos = data.pop("slice_no")
            scaling_factors = data["scaling_factor"]

            # NOTE(review): the `crop` argument is not used here.
            resolution = self.compute_resolution(
                key=self.cfg.validation.crop,
                reconstruction_size=data.get("reconstruction_size", None),
            )

            # Compute output and loss.
            iteration_output = self._do_iteration(
                data, loss_fns, regularizer_fns=regularizer_fns)
            output = iteration_output.output_image.detach()
            val_losses.append(detach_dict(iteration_output.data_dict))

            # Output is complex-valued, and has to be cropped. This holds for both output and target.
            output_abs = self.process_output(
                output.refine_names(*self.complex_names()),
                scaling_factors,
                resolution=resolution,
            )

            if is_validation_process:
                target_abs = self.process_output(
                    data["target"].detach().refine_names(*self.real_names()),
                    scaling_factors,
                    resolution=resolution,
                )
                for key in extra_visualization_keys:
                    curr_data = data[key].detach()  # noqa: F841 -- placeholder, not used yet.
                    # Here we need to discover which keys are actually normalized or not
                    # this requires a solution to issue #23: https://github.com/directgroup/direct/issues/23

            del output  # Explicitly call delete to clear memory.

            # Aggregate volumes to be able to compute the metrics on complete volumes.
            for idx, filename in enumerate(filenames):
                # Slices arrive sequentially, so a new filename means the previous volume is complete.
                if last_filename is not None and filename != last_filename:
                    _finalize_volume(last_filename, batches_done)
                last_filename = filename

                slice_no = int(slice_nos[idx].numpy())
                reconstruction_output[filename].append(
                    (slice_no, output_abs[idx].detach().cpu()))

                if is_validation_process:
                    targets_output[filename].append(
                        (slice_no, target_abs[idx].cpu()))

        # Nothing follows the final volume to trigger it inside the loop. Finish it here, after
        # all of its slices (including the very last one) have been buffered.
        if last_filename is not None:
            _finalize_volume(last_filename, batches_done)

        # Average loss dict
        loss_dict = reduce_list_of_dicts(val_losses)
        reduce_tensor_dict(loss_dict)

        communication.synchronize()
        torch.cuda.empty_cache()

        # TODO: Does not work yet with normal gather.
        all_gathered_metrics = merge_list_of_dicts(
            communication.all_gather(val_volume_metrics))
        if not is_validation_process:
            return loss_dict, reconstruction_output

        # TODO: Apply named tuples where applicable
        # TODO: Several functions have multiple output values, in many cases
        # TODO: it would be more convenient to convert this to namedtuples.
        return loss_dict, all_gathered_metrics, visualize_slices, visualize_target
예제 #4
0
    def _do_iteration(
            self, data: Dict[str, torch.Tensor],
            loss_fns: Optional[Dict[str,
                                    Callable]]) -> Tuple[torch.Tensor, Dict]:
        """Unroll ``self.cfg.model.steps`` RIM blocks on one batch and compute the losses.

        The sensitivity map is optionally refined by ``self.models["sensitivity_model"]``. It is
        then normalized coil-wise by its root-sum-of-squares. Each block starts from the
        detached output and hidden state of the previous one. While training, each block's
        loss is back-propagated through ``self._scaler``.

        Args:
            data: Batch dictionary containing ``"target"``, ``"masked_image"`` and
                ``"sensitivity_map"``. ``"masked_image"`` is popped from this dict in place.
                The rest is moved to ``self.device`` and passed to ``self.model`` as keyword
                arguments.
            loss_fns: Mapping name -> ``fn(prediction, target, reduction="mean")``. ``None``
                (e.g. during validation) is treated as an empty mapping.

        Returns:
            A tuple ``(output_image, loss_dict)``. ``output_image`` is the last reconstruction
            of the last block. ``loss_dict`` holds the detached per-name losses, summed over
            the RIM steps and divided by their number.
        """
        # loss_fns can be None, e.g. during validation
        if loss_fns is None:
            loss_fns = {}

        # TODO(jt): Target is not needed in the model input, but in the loss computation. Keep it here for now.
        target = data["target"].align_to(*self.complex_names).to(
            self.device)  # type: ignore
        # The first input_image in the iteration is the input_image with the mask applied and no first hidden state.
        input_image = data.pop("masked_image").to(self.device)  # type: ignore
        hidden_state = None
        output_image = None
        loss_dicts = []

        # TODO: Target might not need to be copied.
        data = dict_to_device(data, self.device)
        # TODO(jt): keys=['sampling_mask', 'sensitivity_map', 'target', 'masked_kspace', 'scaling_factor']

        sensitivity_map = data["sensitivity_map"]
        # Some things can be done with the sensitivity map here, e.g. apply a u-net
        if "sensitivity_model" in self.models:
            sensitivity_map = self.compute_model_per_coil(
                self.models["sensitivity_model"], sensitivity_map)

        # The sensitivity map needs to be normalized such that
        # \sum_{i \in \text{coils}} S_i S_i^* = 1, i.e. divided by the root-sum-of-squares
        # sqrt(\sum_i |S_i|^2). Dividing by the plain sum of moduli does not satisfy this.
        sensitivity_map_norm = torch.sqrt(
            (modulus(sensitivity_map)**2).sum("coil"))
        data["sensitivity_map"] = safe_divide(sensitivity_map,
                                              sensitivity_map_norm)

        for _ in range(self.cfg.model.steps):
            with autocast(enabled=self.mixed_precision):
                reconstruction_iter, hidden_state = self.model(
                    **data,
                    input_image=input_image,
                    hidden_state=hidden_state,
                )
                # TODO: Unclear why this refining is needed.
                output_image = reconstruction_iter[-1].refine_names(
                    *self.complex_names)

                # Accumulate each loss over all intermediate reconstructions, then average.
                loss_dict = {
                    k: torch.tensor([0.0], dtype=target.dtype).to(self.device)
                    for k in loss_fns.keys()
                }
                for output_image_iter in reconstruction_iter:
                    for k, v in loss_dict.items():
                        loss_dict[k] = v + loss_fns[k](
                            output_image_iter,
                            target,
                            reduction="mean",
                        )

                loss_dict = {
                    k: v / len(reconstruction_iter)
                    for k, v in loss_dict.items()
                }
                loss = sum(loss_dict.values())

            if self.model.training:
                self._scaler.scale(loss).backward()

            # Detach hidden state from computation graph, to ensure loss is only computed per RIM block.
            hidden_state = hidden_state.detach()
            input_image = output_image.detach()

            loss_dicts.append(
                detach_dict(loss_dict)
            )  # Need to detach dict as this is only used for logging.

        # Add the loss dicts together over RIM steps, divide by the number of steps.
        loss_dict = reduce_list_of_dicts(loss_dicts,
                                         mode="sum",
                                         divisor=self.cfg.model.steps)
        return output_image, loss_dict