Example #1
    def __call__(self, sample):
        if self.type_of_map == "unit":
            kspace = sample["kspace"]
            # TODO(kp): Find a way to skip this transform entirely when a sensitivity map is already in the sample and estimation is disabled.
            # (kp): Return early so an existing sensitivity map is not overwritten.
            if "sensitivity_map" in sample:
                return sample
            sensitivity_map = torch.zeros(kspace.shape).float()
            # TODO(jt): Named variant, this assumes the complex channel is last.
            if kspace.names[-1] != "complex":
                raise NotImplementedError("Assuming the last channel is complex.")
            sensitivity_map[..., 0] = 1.0
            sample["sensitivity_map"] = sensitivity_map.refine_names(*kspace.names).to(
                kspace.device
            )

        elif self.type_of_map == "rss_estimate":
            acs_image = self.estimate_acs_image(sample)
            acs_image_rss = T.root_sum_of_squares(acs_image, dim="coil").align_as(
                acs_image
            )
            sample["sensitivity_map"] = T.safe_divide(acs_image, acs_image_rss)
        else:
            raise ValueError(
                f"Expected type of map to be either `unit` or `rss_estimate`. Got {self.type_of_map}."
            )

        return sample
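
A minimal standalone sketch of the `unit` branch above, on a dummy k-space tensor (the shape, dimension names, and device handling here are assumptions for the demo, not values taken from the transform):

    import torch

    kspace = torch.zeros(4, 8, 8, 2).refine_names("coil", "height", "width", "complex")
    sensitivity_map = torch.zeros(kspace.shape).float()  # unnamed, so plain indexing works
    sensitivity_map[..., 0] = 1.0  # real = 1, imag = 0, i.e. the complex number 1 + 0j
    sensitivity_map = sensitivity_map.refine_names(*kspace.names).to(kspace.device)
    assert sensitivity_map.names[-1] == "complex"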
Example #2
    def __call__(self, sample):
        if self.type_of_map == "unit":
            kspace = sample["kspace"]
            sensitivity_map = torch.zeros(kspace.shape).float()
            # TODO(jt): Named variant, this assumes the complex channel is last.
            if kspace.names[-1] != "complex":
                raise NotImplementedError("Assuming the last channel is complex.")
            sensitivity_map[..., 0] = 1.0
            sample["sensitivity_map"] = sensitivity_map.refine_names(
                *kspace.names).to(kspace.device)

        elif self.type_of_map == "rss_estimate":
            acs_image = self.estimate_acs_image(sample)
            acs_image_rss = T.root_sum_of_squares(
                acs_image, dim="coil").align_as(acs_image)
            sample["sensitivity_map"] = T.safe_divide(acs_image, acs_image_rss)
        else:
            raise ValueError(
                f"Expected type of map to be either `unit` or `rss_estimate`. Got {self.type_of_map}."
            )

        return sample
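
The `rss_estimate` branch divides the ACS image by its root-sum-of-squares over the coil dimension. A hedged sketch of that idea, using a real-valued stand-in in place of `T.root_sum_of_squares` and `T.safe_divide` (names and shapes are assumptions):

    import torch

    acs_image = torch.rand(4, 8, 8).refine_names("coil", "height", "width")
    acs_image_rss = (acs_image ** 2).sum("coil").sqrt()  # root-sum-of-squares over coils
    # align_as broadcasts the RSS image back over the "coil" dimension, as in the code above
    sensitivity = acs_image / acs_image_rss.align_as(acs_image)
    # By construction, the per-pixel RSS of the resulting map is 1.
    assert torch.allclose((sensitivity ** 2).sum("coil").rename(None), torch.ones(8, 8))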
Example #3
    def _do_iteration(
        self,
        data: Dict[str, torch.Tensor],
        loss_fns: Optional[Dict[str, Callable]] = None,
        regularizer_fns: Optional[Dict[str, Callable]] = None,
    ) -> NamedTuple:

        # loss_fns can be None, e.g. during validation.
        if loss_fns is None:
            loss_fns = {}

        if regularizer_fns is None:
            regularizer_fns = {}

        # The first iteration starts with no input_image and no hidden state (both None).
        input_image = None
        hidden_state = None
        output_image = None
        loss_dicts = []
        regularizer_dicts = []

        data = dict_to_device(data, self.device)
        # TODO(jt): keys=['sampling_mask', 'sensitivity_map', 'target', 'masked_kspace', 'scaling_factor']
        sensitivity_map = data["sensitivity_map"]

        if "noise_model" in self.models:
            raise NotImplementedError()

        # The sensitivity map can be refined here, e.g. by applying a U-Net.
        if "sensitivity_model" in self.models:
            # Move channels to first axis
            sensitivity_map = sensitivity_map.align_to(*self.complex_names(
                add_coil=True))

            sensitivity_map = (self.compute_model_per_coil(
                "sensitivity_model",
                sensitivity_map).refine_names(*sensitivity_map.names).align_to(
                    *self.complex_names_complex_last(add_coil=True)))
            # Output has channels first: ("batch", "coil", "complex", ...).

        # The sensitivity map needs to be normalized such that
        # \sum_{i \in \text{coils}} S_i S_i^* = 1.
        sensitivity_map_norm = torch.sqrt(
            ((sensitivity_map**2).sum("complex")).sum("coil"))

        data["sensitivity_map"] = T.safe_divide(sensitivity_map,
                                                sensitivity_map_norm)
        if self.cfg.model.scale_loglikelihood:
            scaling_factor = (1.0 * self.cfg.model.scale_loglikelihood /
                              (data["scaling_factor"]**2))
            scaling_factor = scaling_factor.reshape(-1, 1).refine_names(
                "batch", "complex")
            self.logger.debug(f"Scaling factor is: {scaling_factor}")
        else:
            # TODO: Needs fixing; this falls back to a unit scaling factor.
            scaling_factor = (torch.tensor([1.0]).to(
                sensitivity_map.device).refine_names("complex"))

        for _ in range(self.cfg.model.steps):
            with autocast(enabled=self.mixed_precision):
                reconstruction_iter, hidden_state = self.model(
                    **data,
                    input_image=input_image,
                    hidden_state=hidden_state,
                    loglikelihood_scaling=scaling_factor,
                )
                # TODO: Unclear why this refining is needed.
                output_image = reconstruction_iter[-1].refine_names(
                    *self.complex_names())

                loss_dict = {
                    k: torch.tensor([0.0],
                                    dtype=data["target"].dtype).to(self.device)
                    for k in loss_fns.keys()
                }
                regularizer_dict = {
                    k: torch.tensor([0.0],
                                    dtype=data["target"].dtype).to(self.device)
                    for k in regularizer_fns.keys()
                }

                # TODO: The loss and regularizer loops are nearly identical; a functools.partial could merge them.
                for output_image_iter in reconstruction_iter:
                    for k, v in loss_dict.items():
                        loss_dict[k] = v + loss_fns[k](
                            output_image_iter,
                            **data,
                            reduction="mean",
                        )
                    for k, v in regularizer_dict.items():
                        regularizer_dict[k] = (v + regularizer_fns[k](
                            output_image_iter,
                            **data,
                        ).rename(None))

                loss_dict = {
                    k: v / len(reconstruction_iter)
                    for k, v in loss_dict.items()
                }
                regularizer_dict = {
                    k: v / len(reconstruction_iter)
                    for k, v in regularizer_dict.items()
                }

                loss = sum(loss_dict.values()) + sum(regularizer_dict.values())

            if self.model.training:
                self._scaler.scale(loss).backward()

            # Detach hidden state from computation graph, to ensure loss is only computed per RIM block.
            hidden_state = hidden_state.detach()
            input_image = output_image.detach()

            loss_dicts.append(detach_dict(loss_dict))
            regularizer_dicts.append(
                detach_dict(regularizer_dict)
            )  # Need to detach dict as this is only used for logging.

        # Add the loss dicts together over RIM steps, divide by the number of steps.
        loss_dict = reduce_list_of_dicts(loss_dicts,
                                         mode="sum",
                                         divisor=self.cfg.model.steps)
        regularizer_dict = reduce_list_of_dicts(regularizer_dicts,
                                                mode="sum",
                                                divisor=self.cfg.model.steps)
        output = namedtuple(
            "do_iteration",
            ["output_image", "sensitivity_map", "data_dict"],
        )

        return output(
            output_image=output_image,
            sensitivity_map=data["sensitivity_map"],
            data_dict={
                **loss_dict,
                **regularizer_dict
            },
        )
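
The normalization step above enforces \sum_{i \in \text{coils}} S_i S_i^* = 1 per pixel: with complex values stored as a trailing (real, imag) channel, S_i S_i^* is the squared modulus, so dividing by the square root of its sum over coils does the job. A self-contained sketch (dimension names and shapes are assumptions):

    import torch

    sensitivity_map = torch.rand(4, 8, 8, 2).refine_names("coil", "height", "width", "complex")
    # sqrt of the squared modulus summed over coils, mirroring the code above
    norm = ((sensitivity_map ** 2).sum("complex")).sum("coil").sqrt()
    normalized = sensitivity_map / norm.align_as(sensitivity_map)
    # After normalization, sum_i |S_i|^2 == 1 at every pixel.
    total = (normalized ** 2).sum("complex").sum("coil").rename(None)
    assert torch.allclose(total, torch.ones(8, 8))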
Example #4
    def _do_iteration(
            self, data: Dict[str, torch.Tensor],
            loss_fns: Optional[Dict[str,
                                    Callable]]) -> Tuple[torch.Tensor, Dict]:

        # loss_fns can be None, e.g. during validation.
        if loss_fns is None:
            loss_fns = {}

        # TODO(jt): Target is not needed in the model input, but in the loss computation. Keep it here for now.
        target = data["target"].align_to(*self.complex_names).to(
            self.device)  # type: ignore
        # The first input_image in the iteration is the input_image with the mask applied and no first hidden state.
        input_image = data.pop("masked_image").to(self.device)  # type: ignore
        hidden_state = None
        output_image = None
        loss_dicts = []

        # TODO: Target might not need to be copied.
        data = dict_to_device(data, self.device)
        # TODO(jt): keys=['sampling_mask', 'sensitivity_map', 'target', 'masked_kspace', 'scaling_factor']

        sensitivity_map = data["sensitivity_map"]
        # The sensitivity map can be refined here, e.g. by applying a U-Net.
        if "sensitivity_model" in self.models:
            sensitivity_map = self.compute_model_per_coil(
                self.models["sensitivity_model"], sensitivity_map)

        # The sensitivity map needs to be normalized such that
        # \sum_{i \in \text{coils}} S_i S_i^* = 1.
        # TODO: This divides by \sum_i |S_i| rather than \sqrt{\sum_i |S_i|^2}.
        sensitivity_map_norm = modulus(sensitivity_map).sum("coil")
        data["sensitivity_map"] = safe_divide(sensitivity_map,
                                              sensitivity_map_norm)

        for rim_step in range(self.cfg.model.steps):
            with autocast(enabled=self.mixed_precision):
                reconstruction_iter, hidden_state = self.model(
                    **data,
                    input_image=input_image,
                    hidden_state=hidden_state,
                )
                # TODO: Unclear why this refining is needed.

                output_image = reconstruction_iter[-1].refine_names(
                    *self.complex_names)

                loss_dict = {
                    k: torch.tensor([0.0], dtype=target.dtype).to(self.device)
                    for k in loss_fns.keys()
                }
                for output_image_iter in reconstruction_iter:
                    for k, v in loss_dict.items():
                        loss_dict[k] = v + loss_fns[k](
                            output_image_iter,
                            target,
                            reduction="mean",
                        )

                loss_dict = {
                    k: v / len(reconstruction_iter)
                    for k, v in loss_dict.items()
                }
                loss = sum(loss_dict.values())

            if self.model.training:
                self._scaler.scale(loss).backward()

            # Detach hidden state from computation graph, to ensure loss is only computed per RIM block.
            hidden_state = hidden_state.detach()
            input_image = output_image.detach()

            loss_dicts.append(
                detach_dict(loss_dict)
            )  # Need to detach dict as this is only used for logging.

        # Add the loss dicts together over RIM steps, divide by the number of steps.
        loss_dict = reduce_list_of_dicts(loss_dicts,
                                         mode="sum",
                                         divisor=self.cfg.model.steps)
        return output_image, loss_dict
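
Both engines first average the loss over the intermediate reconstructions within a step, then reduce the per-step dicts with reduce_list_of_dicts(..., mode="sum", divisor=self.cfg.model.steps). A minimal stand-in for that reduction (a hypothetical helper, not the library's implementation):

    import torch

    def reduce_list_of_dicts_sketch(dicts, divisor):
        # Sum each key across the list of dicts, then divide by `divisor`,
        # mimicking reduce_list_of_dicts(..., mode="sum", divisor=divisor).
        return {k: sum(d[k] for d in dicts) / divisor for k in dicts[0]}

    loss_dicts = [{"l1": torch.tensor([0.4])}, {"l1": torch.tensor([0.2])}]
    print(reduce_list_of_dicts_sketch(loss_dicts, divisor=len(loss_dicts)))
    # {'l1': tensor([0.3000])}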