Example no. 1
0
    def __getitem__(self, item):
        """Fetch the sample at index *item*.

        Training mode (``loadASTrain``) yields the image only; otherwise an
        (image, ground-truth) pair is produced.  Depending on ``torchiosub``
        the result is either a torchio ``Subject`` wrapping lazy
        ``H5DSImage`` datasets or plain torch tensors read eagerly from HDF5.
        """
        sample = self.samples[item]

        if self.loadASTrain:
            if self.torchiosub:
                return Subject(
                    {'img': H5DSImage(sample[0], lazypatch=self.lazypatch)})
            # Eager read of the whole dataset, with a leading channel dim.
            return torch.from_numpy(sample[0][()]).unsqueeze(0)

        if not self.torchiosub:
            image = torch.from_numpy(sample[0][()]).unsqueeze(0)
            target = torch.from_numpy(sample[1][()]).unsqueeze(0)
            return (image, target)

        if len(sample) == 2:
            img_ds, gt_ds = sample[0], sample[1]
        else:
            # No separate ground truth stored: reuse the sample for both
            # keys.  this is dirty. TODO
            img_ds = gt_ds = sample
        return Subject({
            'img': H5DSImage(img_ds, lazypatch=self.lazypatch),
            'gt': H5DSImage(gt_ds, lazypatch=self.lazypatch),
        })
Example no. 2
0
 def __getitem__(self, item):
     """Return the sample at *item*: a torchio ``Subject`` wrapping a lazy
     ``H5DSImage`` when ``torchiosub`` is set, otherwise an eagerly-read
     torch tensor with a leading channel dimension."""
     sample = self.samples[item]
     if not self.torchiosub:
         return torch.from_numpy(sample[()]).unsqueeze(0)
     return Subject({'img': H5DSImage(sample, lazypatch=self.lazypatch)})
Example no. 3
0
def plot_subject(subject: Subject, save_plot_path: str):
    """Show an interactive plot of *subject* and, when *save_plot_path* is
    given, additionally save one PNG per slice index and assemble them into
    a GIF named after the directory.

    Bug fix: previously only ``os.makedirs`` was guarded by
    ``save_plot_path`` — the per-slice PNGs and the GIF were still written
    unconditionally, producing an invalid output path (e.g. ``"/000.png"``)
    when no save path was supplied.  The whole save phase is now skipped in
    that case.
    """
    data_dict = {}
    sx, sy, sz = subject.spatial_shape
    # Scale each axis by the smallest extent so the rendered aspect ratio
    # is uniform regardless of the volume's anisotropy.
    smallest = min(sx, sy, sz)
    scale = np.array([smallest / sx, smallest / sy, smallest / sz, 1])
    for name, image in subject.get_images_dict(intensity_only=False).items():
        if isinstance(image, LabelMap):
            data_dict[name] = LabelMap(
                tensor=squeeze_segmentation(image),
                affine=np.eye(4) * scale,
            )
        else:
            data_dict[name] = ScalarImage(tensor=image.data,
                                          affine=np.eye(4) * scale)

    out_subject = Subject(data_dict)
    out_subject.plot(reorient=False, show=True, figsize=(10, 10))

    if not save_plot_path:
        return
    os.makedirs(save_plot_path, exist_ok=True)

    mpl, plt = import_mpl_plt()
    backend_ = mpl.get_backend()

    # Switch to the non-interactive Agg backend while batch-saving slices.
    plt.ioff()
    mpl.use("agg")
    for x in range(max(out_subject.spatial_shape)):
        out_subject.plot(
            reorient=False,
            # Clamp the index per axis so short axes repeat their last slice.
            indices=(
                min(x, out_subject.spatial_shape[0] - 1),
                min(x, out_subject.spatial_shape[1] - 1),
                min(x, out_subject.spatial_shape[2] - 1),
            ),
            output_path=f"{save_plot_path}/{x:03d}.png",
            show=False,
            figsize=(10, 10),
        )
        plt.close("all")
    plt.ion()
    mpl.use(backend_)

    create_gifs(save_plot_path,
                f"{save_plot_path}/{os.path.basename(save_plot_path)}.gif")
Example no. 4
0
def plot_aggregated_image(
    writer: SummaryWriter,  # NOTE(review): currently unused here; kept for interface compatibility
    epoch: int,
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,  # type: ignore
    device: torch.device,
    save_path: str,
):
    """Reassemble one randomly chosen subject from *data_loader* patch by
    patch, run *model* over its patches, and plot input / ground truth /
    thresholded prediction side by side under ``{save_path}/{epoch}-{id}``.

    Improvements: removed the unused ``log`` local and wrapped inference in
    ``torch.no_grad()`` so the forward passes do not build autograd graphs.
    """
    sampler, subject_id = random_subject_from_loader(data_loader)
    aggregator_x = GridAggregator(sampler)
    aggregator_y = GridAggregator(sampler)
    aggregator_y_pred = GridAggregator(sampler)
    with torch.no_grad():
        for batch, locations in batches_from_sampler(sampler,
                                                     data_loader.batch_size):
            x: torch.Tensor = batch["image"]["data"]
            aggregator_x.add_batch(x, locations)
            y: torch.Tensor = batch["seg"]["data"]
            aggregator_y.add_batch(y, locations)

            # Binarize the sigmoid output at 0.5 to get a hard mask.
            logits = model(x.to(device))
            y_pred = (torch.sigmoid(logits) > 0.5).float()
            aggregator_y_pred.add_batch(y_pred, locations)

    whole_x = aggregator_x.get_output_tensor()
    whole_y = aggregator_y.get_output_tensor()
    whole_y_pred = aggregator_y_pred.get_output_tensor()

    plot_subject(
        Subject(
            image=ScalarImage(tensor=whole_x),
            true_seg=LabelMap(tensor=whole_y),
            pred_seg=LabelMap(tensor=whole_y_pred),
        ),
        f"{save_path}/{epoch}-{subject_id}",
    )