Example #1
    def __call__(self, sample):
        if self.normalize_key == "scaling_factor":  # This is a given real-valued number
            scaling_factor = sample["scaling_factor"]
        elif not self.normalize_key:
            scaling_factor = 1.0
        else:
            data = sample[self.normalize_key]

            # Compute the maximum and scale the input
            if self.percentile:
                # TODO: Fix when named tensors allow views.
                tview = -1.0 * T.modulus(data).rename(None).view(-1)
                scaling_factor, _ = torch.kthvalue(
                    tview, int((1 - self.percentile) * tview.size()[0]))
                scaling_factor = -1.0 * scaling_factor
            else:
                scaling_factor = T.modulus(data).max()

        # Normalize data
        if self.normalize_key:
            for key in sample.keys():
                if key != self.normalize_key and key not in self.other_keys:
                    continue
                sample[key] = sample[key] / scaling_factor

        sample["scaling_factor"] = scaling_factor
        return sample
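
The `percentile` branch above relies on `torch.kthvalue`, which only returns the k-th *smallest* element: negating the flattened modulus, taking the k-th smallest, and negating back yields the k-th largest value, i.e. an approximate upper percentile. A minimal sketch of that trick on a plain real tensor; the helper name `approx_percentile_max` is made up for illustration, and for large tensors the result is close to `torch.quantile`:

import torch

def approx_percentile_max(x: torch.Tensor, percentile: float) -> torch.Tensor:
    # kthvalue returns the k-th *smallest* value, so negate, pick, negate back.
    flat = x.reshape(-1)
    k = max(int((1 - percentile) * flat.numel()), 1)  # kthvalue requires k >= 1
    value, _ = torch.kthvalue(-flat, k)
    return -value

x = torch.rand(10_000)
print(approx_percentile_max(x, 0.99))  # roughly torch.quantile(x, 0.99)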
Example #2
    def __call__(self, sample):
        """

        Parameters
        ----------
        sample : dict

        Returns
        -------
        data dictionary

        TODO: Normalization of the sensitivity map should be done in the data loader.
        """
        data = sample[self.normalize_key]

        # Compute the maximum and scale the input
        if self.percentile:
            # TODO: Fix when named tensors allow views.
            tview = -1.0 * transforms.modulus(data).rename(None).view(-1)
            image_max, _ = torch.kthvalue(tview, int((1 - self.percentile) * tview.size()[0]))
            image_max = -1.0 * image_max
        else:
            image_max = transforms.modulus(data).max()

        # Normalize data
        for key in sample.keys():
            # TODO: Reconsider this.
            if any([_ in key for _ in [self.normalize_key, 'masked_kspace', 'target', 'kspace']]):
                sample[key] = sample[key] / image_max

        sample['scaling_factor'] = image_max

        return sample
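
As a usage illustration of the dict-in/dict-out contract shared by these transforms, here is a stripped-down stand-in that keeps only the non-percentile branch; the class name `MaxNormalize` and the sample shapes are made up and not taken from the listing:

import torch

class MaxNormalize:  # illustrative stand-in, not the class shown above
    def __init__(self, normalize_key="masked_kspace"):
        self.normalize_key = normalize_key

    def __call__(self, sample):
        data = sample[self.normalize_key]
        # The trailing dimension of size 2 is treated as (real, imag).
        image_max = (data ** 2).sum(-1).sqrt().max()
        for key in list(sample.keys()):
            if any(name in key for name in (self.normalize_key, "target", "kspace")):
                sample[key] = sample[key] / image_max
        sample["scaling_factor"] = image_max
        return sample

sample = {
    "masked_kspace": torch.randn(8, 320, 320, 2),  # coil, height, width, complex
    "target": torch.rand(320, 320),
}
sample = MaxNormalize()(sample)
print(sample["scaling_factor"])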
Example #3
    def __call__(self, sample):
        data = sample[self.normalize_key]

        # Compute the maximum and scale the input
        if self.percentile:
            # TODO: Fix when named tensors allow views.
            tview = -1.0 * T.modulus(data).rename(None).view(-1)
            image_max, _ = torch.kthvalue(
                tview, int((1 - self.percentile) * tview.size()[0])
            )
            image_max = -1.0 * image_max
        else:
            image_max = T.modulus(data).max()

        # Normalize data
        for key in sample.keys():
            # TODO: Reconsider this.
            if any(
                [
                    _ in key
                    for _ in [self.normalize_key, "masked_kspace", "target", "kspace"]
                ]
            ):
                sample[key] = sample[key] / image_max

        sample["scaling_factor"] = image_max

        return sample
Example #4
def test_modulus(shape):
    shape = shape + [2]
    data = create_input(shape)
    out_torch = transforms.modulus(data).numpy()
    input_numpy = tensor_to_complex_numpy(data)
    out_numpy = np.abs(input_numpy)
    assert np.allclose(out_torch, out_numpy)
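
The test above checks `transforms.modulus` against `np.abs` on the complex view of the same data, so the operation is the complex magnitude over the trailing (real, imag) dimension. A minimal sketch of the same check in plain torch/NumPy; the helpers `create_input` and `tensor_to_complex_numpy` are replaced by inline equivalents, which is an assumption about their behaviour:

import numpy as np
import torch

data = torch.randn(4, 32, 32, 2)  # last dimension: (real, imag)

# Complex magnitude over the trailing dimension, i.e. sqrt(re^2 + im^2).
out_torch = torch.sqrt((data ** 2).sum(dim=-1)).numpy()

# The same data viewed as a complex NumPy array.
input_numpy = data[..., 0].numpy() + 1j * data[..., 1].numpy()
out_numpy = np.abs(input_numpy)

assert np.allclose(out_torch, out_numpy)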
Example #5
    def cropper(self, source, target, resolution=(320, 320)):
        slice_index = target.names.index("slice")
        source = source.refine_names(*target.names)

        use_center_slice = True
        if use_center_slice:
            # Source and target have a different number of slices when trimming in depth.
            source = source.select(slice_index,
                                   source.size("slice") // 2).rename(None)
            target = target.select("slice",
                                   target.size("slice") // 2).rename(None)
        else:
            source = source.flatten(["batch", "slice"], "batch").rename(None)
            target = target.flatten(["batch", "slice"], "batch").rename(None)

        complex_names = self.complex_names().copy()
        complex_names.pop(slice_index)

        source_abs = T.modulus(source.refine_names(*complex_names))
        if not resolution or all([_ == 0 for _ in resolution]):
            return source_abs.rename(None).unsqueeze(1), target

        source_abs = T.center_crop(source_abs,
                                   resolution).rename(None).unsqueeze(1)
        target_abs = T.center_crop(target, resolution)
        return source_abs, target_abs
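
This cropper and the two that follow lean on PyTorch's (prototype) named-tensor API: `refine_names` attaches dimension names, `select` and `size` accept a name, and `rename(None)` strips the names again before ops that do not support them. A minimal sketch of that round trip with made-up shapes:

import torch

x = torch.randn(2, 4, 320, 320, 2)
x = x.refine_names("batch", "slice", "height", "width", "complex")

# Pick the center slice along the named "slice" dimension.
center = x.select("slice", x.size("slice") // 2)

# Drop the names before ops that do not support named tensors yet.
plain = center.rename(None)
print(plain.shape)  # torch.Size([2, 320, 320, 2])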
Example #6
    def cropper(self, source, target, resolution=(320, 320)):
        # Can also do reshaping and compute over the full volume
        slice_index = target.names.index("slice")

        use_center_slice = True
        if use_center_slice:
            center_slice = target.size("slice") // 2
            source = source.select(slice_index, center_slice)
            target = target.select("slice", center_slice).rename(None)
        else:
            source = source.refine_names(*target.names)
            source = source.flatten(["batch", "slice"], "batch").rename(None)
            target = target.flatten(["batch", "slice"], "batch").rename(None)

        complex_names = self.complex_names.copy()
        complex_names.pop(slice_index)

        source_abs = modulus(source.refine_names(*complex_names))
        if not resolution or all([_ == 0 for _ in resolution]):
            return source_abs.rename(None).unsqueeze(1), target

        source_abs = center_crop(source_abs,
                                 resolution).rename(None).unsqueeze(1)
        target_abs = center_crop(target, resolution)
        return source_abs, target_abs
Example #7
    def cropper(self, source, target, resolution):
        source = source.rename(None)
        target = target.align_to(*self.complex_names()).rename(None)
        source_abs = T.modulus(source.refine_names(*self.complex_names()))
        if not resolution or all([_ == 0 for _ in resolution]):
            return source_abs.rename(None).unsqueeze(1), target

        source_abs = T.center_crop(source_abs,
                                   resolution).rename(None).unsqueeze(1)
        target_abs = T.center_crop(target, resolution)
        return source_abs, target_abs
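
`T.center_crop` is used throughout these croppers to cut both the model output and the target down to `resolution` before comparison. A minimal sketch of a symmetric center crop over the last two dimensions, under the assumption that this is what the helper does (the function name `center_crop_2d` is illustrative):

import torch

def center_crop_2d(x: torch.Tensor, shape: tuple) -> torch.Tensor:
    """Crop the last two dimensions of `x` to `shape`, centered."""
    h, w = x.shape[-2], x.shape[-1]
    new_h, new_w = shape
    top = (h - new_h) // 2
    left = (w - new_w) // 2
    return x[..., top:top + new_h, left:left + new_w]

image = torch.randn(1, 640, 368)
print(center_crop_2d(image, (320, 320)).shape)  # torch.Size([1, 320, 320])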
Example #8
    def log_first_training_example_and_model(self, data):
        storage = get_event_storage()
        self.logger.info(
            f"First case: slice_no: {data['slice_no'][0]}, filename: {data['filename'][0]}."
        )

        # TODO(jt): Cleaner, loop over types of images
        first_sampling_mask = data["sampling_mask"][0][0]
        first_target = data["target"][0]

        if self.ndim == 3:
            first_sampling_mask = first_sampling_mask[0]
            num_slices = first_target.shape[first_target.names.index("slice")]
            first_target = first_target[num_slices // 2]
        elif self.ndim > 3:
            raise NotImplementedError

        storage.add_image(
            "train/mask", first_sampling_mask[..., 0].rename(None).unsqueeze(0)
        )
        storage.add_image(
            "train/target",
            normalize_image(first_target.rename(None).unsqueeze(0)),
        )

        if "initial_image" in data:
            storage.add_image(
                "train/initial_image",
                normalize_image(
                    T.modulus(data["initial_image"][0]).rename(None).unsqueeze(0)
                ),
            )

        # TODO: Add graph

        self.write_to_logs()
Example #9
    def _do_iteration(
            self, data: Dict[str, torch.Tensor],
            loss_fns: Optional[Dict[str,
                                    Callable]]) -> Tuple[torch.Tensor, Dict]:

        # loss_fns can be None, e.g. during validation
        if loss_fns is None:
            loss_fns = {}

        # TODO(jt): Target is not needed in the model input, but in the loss computation. Keep it here for now.
        target = data["target"].align_to(*self.complex_names).to(
            self.device)  # type: ignore
        # The first input_image in the iteration is the input_image with the mask applied and no first hidden state.
        input_image = data.pop("masked_image").to(self.device)  # type: ignore
        hidden_state = None
        output_image = None
        loss_dicts = []

        # TODO: Target might not need to be copied.
        data = dict_to_device(data, self.device)
        # TODO(jt): keys=['sampling_mask', 'sensitivity_map', 'target', 'masked_kspace', 'scaling_factor']

        sensitivity_map = data["sensitivity_map"]
        # Some things can be done with the sensitivity map here, e.g. apply a u-net
        if "sensitivity_model" in self.models:
            sensitivity_map = self.compute_model_per_coil(
                self.models["sensitivity_model"], sensitivity_map)

        # The sensitivity map needs to be normalized such that
        # \sum_{i \in \text{coils}} S_i S_i^* = 1
        sensitivity_map_norm = modulus(sensitivity_map).sum("coil")
        data["sensitivity_map"] = safe_divide(sensitivity_map,
                                              sensitivity_map_norm)

        for rim_step in range(self.cfg.model.steps):
            with autocast(enabled=self.mixed_precision):
                reconstruction_iter, hidden_state = self.model(
                    **data,
                    input_image=input_image,
                    hidden_state=hidden_state,
                )
                # TODO: Unclear why this refining is needed.

                output_image = reconstruction_iter[-1].refine_names(
                    *self.complex_names)

                loss_dict = {
                    k: torch.tensor([0.0], dtype=target.dtype).to(self.device)
                    for k in loss_fns.keys()
                }
                for output_image_iter in reconstruction_iter:
                    for k, v in loss_dict.items():
                        loss_dict[k] = v + loss_fns[k](
                            output_image_iter,
                            target,
                            reduction="mean",
                        )

                loss_dict = {
                    k: v / len(reconstruction_iter)
                    for k, v in loss_dict.items()
                }
                loss = sum(loss_dict.values())

            if self.model.training:
                self._scaler.scale(loss).backward()

            # Detach hidden state from computation graph, to ensure loss is only computed per RIM block.
            hidden_state = hidden_state.detach()
            input_image = output_image.detach()

            loss_dicts.append(
                detach_dict(loss_dict)
            )  # Need to detach dict as this is only used for logging.

        # Add the loss dicts together over RIM steps, divide by the number of steps.
        loss_dict = reduce_list_of_dicts(loss_dicts,
                                         mode="sum",
                                         divisor=self.cfg.model.steps)
        return output_image, loss_dict
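
The sensitivity-map block above divides each coil map by the modulus summed over the coil dimension, with `safe_divide` presumably guarding against a zero denominator outside the coil coverage. A minimal sketch of that normalization in plain torch, assuming a (coil, height, width, complex) layout; the zero-guard behaviour of `safe_divide` is an assumption:

import torch

sensitivity_map = torch.randn(8, 320, 320, 2)  # coil, height, width, (real, imag)

# Per-coil modulus, summed over the coil dimension -> (height, width).
modulus_sum = (sensitivity_map ** 2).sum(-1).sqrt().sum(0)

# Broadcast the denominator back to the full map and divide, mapping 0/0 to 0.
denominator = modulus_sum[None, ..., None].expand_as(sensitivity_map)
normalized = torch.where(
    denominator == 0,
    torch.zeros_like(sensitivity_map),
    sensitivity_map / denominator,
)
print(normalized.shape)  # torch.Size([8, 320, 320, 2])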
Example #10
def cropper(source, target, resolution=(320, 320)):
    source_abs = modulus(source.refine_names('batch', 'complex', 'height', 'width'))
    if not resolution or all([_ == 0 for _ in resolution]):
        return source_abs.rename(None).unsqueeze(1), target

    source_abs = center_crop(source_abs, resolution).rename(None).unsqueeze(1)
    target_abs = center_crop(target, resolution)
    return source_abs, target_abs