Exemple #1
0
 def apply_transform(self, sample: Subject) -> dict:
     """Add noise with a single sampled standard deviation to ``sample``.

     One ``std`` value is drawn via ``self.get_params(self.std_range)`` and
     reused for everything processed in this call.

     When ``self.is_tensor`` is true, ``sample`` is passed straight to
     ``add_noise`` and that result is returned. Otherwise ``sample`` is
     handled as a subject: its shape consistency is checked, the ``DATA``
     entry of every image is replaced by its noisy version, and the
     transform is recorded on the subject together with the ``std`` used
     for each image.

     Args:
         sample: The subject (or, in tensor mode, the raw data) to perturb.

     Returns:
         The same ``sample`` object after modification, or the output of
         ``add_noise`` in tensor mode.
     """
     std = self.get_params(self.std_range)

     # Tensor mode: nothing to record, just return the noisy data.
     if self.is_tensor:
         return add_noise(sample, std)

     sample.check_consistent_shape()
     params_by_image = {}
     for image_name, image_dict in sample.get_images_dict().items():
         # Each image gets its own parameter dict, all holding the same std.
         params_by_image[image_name] = {'std': std}
         image_dict[DATA] = add_noise(image_dict[DATA], std)
     sample.add_transform(self, params_by_image)
     return sample
Exemple #2
0
def plot_subject(subject: Subject, save_plot_path: str):
    """Display a subject and, if a directory is given, save slice plots and a GIF.

    A copy of the subject is built in which every image gets a diagonal
    affine whose spacing along each axis is ``min(shape) / shape[axis]``.
    The copy is first plotted interactively (``show=True``). If
    ``save_plot_path`` is truthy, one PNG per index along the longest axis
    is then written to ``save_plot_path`` (indices are clamped on the
    shorter axes) using the non-interactive ``agg`` backend, and
    ``create_gifs`` is called to assemble them into
    ``<save_plot_path>/<directory name>.gif``.

    Args:
        subject: Subject whose images are plotted. Its ``spatial_shape``
            must have exactly three entries.
        save_plot_path: Output directory; created if missing. When falsy,
            only the interactive plot is shown and nothing is written.
    """
    if save_plot_path:
        os.makedirs(save_plot_path, exist_ok=True)

    # Per-axis spacing relative to the shortest axis; computed once because
    # it is identical for every image of the subject.
    smallest = min(subject.spatial_shape)
    sx, sy, sz = (smallest / size for size in subject.spatial_shape)
    affine = np.eye(4) * np.array([sx, sy, sz, 1])

    data_dict = {}
    for name, image in subject.get_images_dict(intensity_only=False).items():
        # Give each image its own affine array so none of them share state.
        if isinstance(image, LabelMap):
            data_dict[name] = LabelMap(
                tensor=squeeze_segmentation(image),
                affine=affine.copy(),
            )
        else:
            data_dict[name] = ScalarImage(
                tensor=image.data,
                affine=affine.copy(),
            )

    out_subject = Subject(data_dict)
    out_subject.plot(reorient=False, show=True, figsize=(10, 10))

    # Without an output directory there is nowhere to write the frames;
    # previously this fell through and tried to write to e.g. "None/000.png".
    if not save_plot_path:
        return

    mpl, plt = import_mpl_plt()
    backend_ = mpl.get_backend()
    shape = out_subject.spatial_shape

    try:
        # Render off-screen while dumping the frames.
        plt.ioff()
        mpl.use("agg")
        for x in range(max(shape)):
            out_subject.plot(
                reorient=False,
                # Clamp so shorter axes stay on their last valid slice.
                indices=(
                    min(x, shape[0] - 1),
                    min(x, shape[1] - 1),
                    min(x, shape[2] - 1),
                ),
                output_path=f"{save_plot_path}/{x:03d}.png",
                show=False,
                figsize=(10, 10),
            )
            plt.close("all")
    finally:
        # Always hand the caller back the backend they started with, even
        # if one of the plot calls raised.
        plt.ion()
        mpl.use(backend_)

    # normpath strips a trailing separator so "out/" yields "out", not "".
    gif_name = os.path.basename(os.path.normpath(save_plot_path))
    create_gifs(save_plot_path, f"{save_plot_path}/{gif_name}.gif")
Exemple #3
0
 def apply_normalization(
     self,
     subject: Subject,
     image_name: str,
     mask: torch.Tensor,
 ) -> None:
     """Standardize one image of ``subject`` using its non-zero voxels as mask.

     The image's data is passed to ``self.znorm`` together with a boolean
     mask of the elements that differ from zero, and the result is stored
     back under the ``"data"`` key of the matching intensity image.

     Args:
         subject: Subject holding the image to normalize.
         image_name: Key of the image inside ``subject``.
         mask: Accepted for interface compatibility but not used; the mask
             is always recomputed as ``image.data != 0``.

     Raises:
         RuntimeError: If ``self.znorm`` returns ``None``.
     """
     image = subject[image_name]
     voxels = image.data
     # The caller-supplied mask is intentionally replaced by a non-zero mask.
     nonzero = voxels != 0
     normalized = self.znorm(voxels, nonzero)
     if normalized is None:
         raise RuntimeError(
             "Standard deviation is 0 for masked values"
             f' in image "{image_name}" ({image.path})'
         )
     intensity_images = subject.get_images_dict(intensity_only=True)
     intensity_images[image_name]["data"] = normalized