Example #1
0
    def apply_transform(self, sample: Subject) -> dict:
        """Apply one randomly drawn scaling + rotation to a sample.

        Parameters are drawn once via ``self.get_params`` and shared by
        every image of the sample.

        * Tensor input (``self.is_tensor``): the tensor is passed to
          ``apply_affine_transform`` with an identity affine and
          ``self.interpolation``. The result is returned and no transform
          history is recorded.
        * Subject input: shapes are checked for consistency. Each image
          (intensity and label) is transformed in place, using nearest
          neighbour for ``LABEL`` images and ``self.interpolation``
          otherwise. The drawn parameters are recorded with
          ``add_transform``.

        Returns:
            The mutated Subject, or the tensor produced by
            ``apply_affine_transform``.
        """
        scaling_params, rotation_params = self.get_params(
            self.scales, self.degrees, self.isotropic)

        if self.is_tensor:
            return self.apply_affine_transform(
                sample,
                np.identity(4),
                scaling_params,
                rotation_params,
                self.interpolation,
            )

        sample.check_consistent_shape()
        for image in sample.get_images(intensity_only=False):
            # Label maps must not be blended, so they use nearest neighbour.
            is_label = image[TYPE] == LABEL
            mode = Interpolation.NEAREST if is_label else self.interpolation
            image[DATA] = self.apply_affine_transform(
                image[DATA],
                image[AFFINE],
                scaling_params,
                rotation_params,
                mode,
            )
        sample.add_transform(self, {
            'scaling': scaling_params,
            'rotation': rotation_params,
        })
        return sample
Example #2
0
    def __getitem__(self, item):
        """Return entry ``item`` in the format selected by the dataset flags.

        * ``loadASTrain`` set: only the image (``entry[0]``) is returned.
        * otherwise: image and ground truth are returned together.

        With ``torchiosub`` set, entries are wrapped as ``H5DSImage``
        objects (honouring ``self.lazypatch``) inside a ``Subject``.
        Otherwise each entry is materialised with ``[()]``, converted with
        ``torch.from_numpy`` and given a leading axis via ``unsqueeze(0)``.
        """
        entry = self.samples[item]

        def as_image(source):
            return H5DSImage(source, lazypatch=self.lazypatch)

        def as_tensor(source):
            return torch.from_numpy(source[()]).unsqueeze(0)

        if self.loadASTrain:
            if self.torchiosub:
                return Subject({'img': as_image(entry[0])})
            return as_tensor(entry[0])

        if not self.torchiosub:
            return as_tensor(entry[0]), as_tensor(entry[1])

        if len(entry) == 2:
            return Subject({
                'img': as_image(entry[0]),
                'gt': as_image(entry[1]),
            })
        # NOTE(review): when the entry is not an (img, gt) pair, the whole
        # entry is used for both keys -- the original author marked this
        # as "dirty. TODO".
        return Subject({
            'img': as_image(entry),
            'gt': as_image(entry),
        })
Example #3
0
    def apply_transform(self, sample: Subject) -> dict:
        """Apply one randomly drawn B-spline deformation to a sample.

        The coarse control-point grid is drawn once via ``self.get_params``
        from ``num_control_points``, ``max_displacement`` and
        ``num_locked_borders``, and is shared by every image.

        * Tensor input (``self.is_tensor``): the tensor is deformed with an
          identity affine and ``self.interpolation``, and the result is
          returned without recording history.
        * Subject input: shapes are checked for consistency. Each image is
          deformed in place, using nearest neighbour for ``LABEL`` images.
          The grid is then recorded under ``'coarse_grid'``.

        Returns:
            The mutated Subject, or the tensor produced by
            ``apply_bspline_transform``.
        """
        bspline_params = self.get_params(
            self.num_control_points,
            self.max_displacement,
            self.num_locked_borders,
        )

        if self.is_tensor:
            return self.apply_bspline_transform(
                sample, np.identity(4), bspline_params, self.interpolation)

        sample.check_consistent_shape()
        for image in sample.get_images(intensity_only=False):
            if image[TYPE] == LABEL:
                mode = Interpolation.NEAREST  # keep label values discrete
            else:
                mode = self.interpolation
            image[DATA] = self.apply_bspline_transform(
                image[DATA], image[AFFINE], bspline_params, mode)
        sample.add_transform(self, {'coarse_grid': bspline_params})
        return sample
Example #4
0
 def apply_transform(self, sample: Subject) -> dict:
     """Randomly flip axes of a Subject or a raw tensor.

     ``self.get_params`` decides which of ``self.axes`` to flip, using
     ``self.flip_probability``. For a Subject, every image (intensity and
     label) is flipped identically and the choice is recorded under
     ``'axes'``. For a raw tensor, the flipped tensor is returned and
     nothing is recorded.
     """
     flip_hot = self.get_params(self.axes, self.flip_probability)
     if not self.is_tensor:
         for image in sample.get_images(intensity_only=False):
             image[DATA] = self.flip_dimensions(image[DATA], flip_hot)
         sample.add_transform(self, {'axes': flip_hot})
         return sample
     return self.flip_dimensions(sample, flip_hot)
Example #5
0
 def apply_transform(self, sample: Subject) -> dict:
     """Add noise with one randomly drawn standard deviation.

     ``std`` is drawn once from ``self.std_range`` and shared by all images.
     For a raw tensor (``self.is_tensor``), the noisy tensor is returned
     directly. For a Subject, shapes are checked, each image returned by
     ``get_images_dict()`` is updated in place, and a per-image
     ``{'std': std}`` record is stored via ``add_transform``.
     """
     std = self.get_params(self.std_range)
     if self.is_tensor:
         return add_noise(sample, std)

     sample.check_consistent_shape()
     per_image_params = {}
     for name, image in sample.get_images_dict().items():
         per_image_params[name] = {'std': std}
         image[DATA] = add_noise(image[DATA], std)
     sample.add_transform(self, per_image_params)
     return sample
Example #6
0
 def __getitem__(self, item):
     """Return entry ``item`` as a ``Subject`` or as a tensor.

     With ``torchiosub`` set, the entry is wrapped in an ``H5DSImage``
     (honouring ``self.lazypatch``) under the key ``'img'``. Otherwise it
     is materialised with ``[()]``, converted with ``torch.from_numpy``
     and given a leading axis.
     """
     source = self.samples[item]
     if not self.torchiosub:
         return torch.from_numpy(source[()]).unsqueeze(0)
     return Subject({'img': H5DSImage(source, lazypatch=self.lazypatch)})
Example #7
0
 def apply_normalization(
     self,
     subject: Subject,
     image_name: str,
     mask: torch.Tensor,
 ) -> None:
     """Replace ``subject[image_name]``'s data with its z-normalised version.

     ``self.znorm`` receives the image data together with a mask of its
     nonzero voxels. The result is stored back through
     ``subject.get_images_dict(intensity_only=True)``.

     NOTE(review): the ``mask`` argument is ignored -- it is replaced by
     ``image.data != 0``. Confirm this override is intentional.

     Raises:
         RuntimeError: if ``self.znorm`` returns ``None``.
     """
     image = subject[image_name]
     nonzero = image.data != 0
     result = self.znorm(image.data, nonzero)
     if result is None:
         raise RuntimeError(
             "Standard deviation is 0 for masked values"
             f' in image "{image_name}" ({image.path})')
     intensity_images = subject.get_images_dict(intensity_only=True)
     intensity_images[image_name]["data"] = result
Example #8
0
def plot_subject(subject: Subject, save_plot_path: str):
    """Show ``subject`` and optionally save per-index PNGs plus a GIF.

    Every image is re-wrapped with a diagonal affine whose per-axis
    spacing is ``min(shape) / shape``, so all three axes span the same
    extent in the plot. ``LabelMap`` images go through
    ``squeeze_segmentation``; all others become ``ScalarImage`` from
    ``image.data``. The re-wrapped subject is first shown interactively.

    If ``save_plot_path`` is truthy, it is created if missing and one PNG
    is written to it for each ``x`` in ``range(max(spatial_shape))``
    (``x`` clamped to each axis). The PNGs are then combined by
    ``create_gifs`` into ``<save_plot_path>/<dirname>.gif``. While saving,
    matplotlib is switched to the ``agg`` backend with interactive mode
    off. The previous backend and interactive state are restored
    afterwards, even if plotting fails.

    Args:
        subject: Subject whose images are plotted.
        save_plot_path: Output directory. When empty or ``None``, only the
            interactive plot is shown and nothing is written.
    """
    sx, sy, sz = subject.spatial_shape
    smallest = min(sx, sy, sz)
    spacing = np.array([smallest / sx, smallest / sy, smallest / sz, 1])

    data_dict = {}
    for name, image in subject.get_images_dict(intensity_only=False).items():
        affine = np.eye(4) * spacing  # fresh array per image
        if isinstance(image, LabelMap):
            data_dict[name] = LabelMap(
                tensor=squeeze_segmentation(image),
                affine=affine,
            )
        else:
            data_dict[name] = ScalarImage(tensor=image.data, affine=affine)

    out_subject = Subject(data_dict)
    out_subject.plot(reorient=False, show=True, figsize=(10, 10))

    # Without a target directory the paths below would resolve to
    # "None/..." or the filesystem root, so stop after the interactive plot.
    if not save_plot_path:
        return
    os.makedirs(save_plot_path, exist_ok=True)

    mpl, plt = import_mpl_plt()
    previous_backend = mpl.get_backend()
    was_interactive = plt.isinteractive()

    plt.ioff()
    mpl.use("agg")
    try:
        shape = out_subject.spatial_shape
        for x in range(max(shape)):
            out_subject.plot(
                reorient=False,
                indices=(
                    min(x, shape[0] - 1),
                    min(x, shape[1] - 1),
                    min(x, shape[2] - 1),
                ),
                output_path=f"{save_plot_path}/{x:03d}.png",
                show=False,
                figsize=(10, 10),
            )
            plt.close("all")
    finally:
        # Restore global matplotlib state for the caller.
        if was_interactive:
            plt.ion()
        mpl.use(previous_backend)

    # normpath drops a trailing separator so basename is never empty.
    gif_name = os.path.basename(os.path.normpath(save_plot_path))
    create_gifs(save_plot_path, f"{save_plot_path}/{gif_name}.gif")
Example #9
0
def plot_aggregated_image(
    writer: SummaryWriter,
    epoch: int,
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,  # type: ignore
    device: torch.device,
    save_path: str,
):
    """Predict one random subject patch-by-patch and plot the stitched result.

    A subject is chosen with ``random_subject_from_loader``. Its patches
    (in batches of ``data_loader.batch_size``) are run through ``model``.
    The input (``batch["image"]``), target (``batch["seg"]``) and
    prediction (``sigmoid(logits) > 0.5`` as float) are each reassembled
    with a ``GridAggregator`` and passed to ``plot_subject``, which saves
    under ``{save_path}/{epoch}-{subject_id}``.

    Args:
        writer: Not used by this function; kept for interface compatibility.
        epoch: Epoch number, used in the output directory name.
        model: Segmentation model producing logits.
        data_loader: Loader from which the subject and batch size are taken.
        device: Device the input patches are moved to before the forward pass.
        save_path: Parent directory for the plot output.
    """
    sampler, subject_id = random_subject_from_loader(data_loader)
    aggregator_x = GridAggregator(sampler)
    aggregator_y = GridAggregator(sampler)
    aggregator_y_pred = GridAggregator(sampler)

    # Inference only: avoid building an autograd graph for every patch.
    # NOTE(review): the model is used in whatever train/eval mode the caller
    # left it in; dropout/batch-norm behave accordingly.
    with torch.no_grad():
        for batch, locations in batches_from_sampler(sampler,
                                                     data_loader.batch_size):
            x: torch.Tensor = batch["image"]["data"]
            aggregator_x.add_batch(x, locations)
            y: torch.Tensor = batch["seg"]["data"]
            aggregator_y.add_batch(y, locations)

            logits = model(x.to(device))
            y_pred = (torch.sigmoid(logits) > 0.5).float()
            aggregator_y_pred.add_batch(y_pred, locations)

    whole_x = aggregator_x.get_output_tensor()
    whole_y = aggregator_y.get_output_tensor()
    whole_y_pred = aggregator_y_pred.get_output_tensor()

    plot_subject(
        Subject(
            image=ScalarImage(tensor=whole_x),
            true_seg=LabelMap(tensor=whole_y),
            pred_seg=LabelMap(tensor=whole_y_pred),
        ),
        f"{save_path}/{epoch}-{subject_id}",
    )
Example #10
0
 def apply_transform(self, sample: Subject) -> dict:
     """Resize every image of ``sample`` in place to a (SIZE, SIZE, SIZE) grid.

     Each tensor gets a temporary leading batch axis for ``F.interpolate``
     (default mode), which is removed afterwards. Affines are left
     unchanged.
     """
     target = (SIZE, SIZE, SIZE)
     for image in sample.get_images(intensity_only=False):
         batched = image[DATA].unsqueeze(0)
         image[DATA] = F.interpolate(batched, size=target).squeeze(0)
     return sample
Example #11
0
 def apply_transform(self, sample: Subject) -> dict:
     """Normalise every image tensor to a single leading channel axis.

     All size-1 dimensions are squeezed out, then one leading axis is
     added back, e.g. (1, X, Y, Z, 1) -> (1, X, Y, Z). Images are updated
     in place.
     """
     for image in sample.get_images(intensity_only=False):
         squeezed = image[DATA].squeeze()
         image[DATA] = squeezed[None, ...]
     return sample