Esempio n. 1
0
def plot_2d_or_3d_image(
    data: Union[torch.Tensor, np.ndarray],
    step: int,
    writer: SummaryWriter,
    index: int = 0,
    max_channels: int = 1,
    max_frames: int = 64,
    tag: str = "output",
) -> None:
    """Plot 2D or 3D image on the TensorBoard, 3D image will be converted to GIF image.

    Note:
        Plot 3D or 2D image (with more than 3 channels) as separate images.

    Args:
        data: target data to be plotted as image on the TensorBoard.
            The data is expected to have 'NCHW[D]' dimensions (or be a sequence whose
            items are 'CHW[D]'); only the element selected by ``index`` is plotted.
            A tensor element is detached and copied to host memory first.
        step: current step to plot in a chart.
        writer: specify TensorBoard SummaryWriter to plot the image.
        index: plot which element in the input data batch, default is the first element.
        max_channels: number of channels to plot. A 3-channel ``CHW`` element with
            ``max_channels == 3`` is written unchanged as a single RGB image.
        max_frames: number of frames for 2D-t plot; forwarded positionally to
            ``add_animated_gif`` for data with 4 or more dims.
        tag: tag of the plotted image on TensorBoard.
    """
    d = data[index]
    # Decide on the selected element rather than the container, so a list/tuple of
    # tensors is converted exactly like a batched tensor would be.
    if torch.is_tensor(d):
        d = d.detach().cpu().numpy()

    if d.ndim == 2:
        dataformats = "HW"
        writer.add_image(f"{tag}_{dataformats}",
                         rescale_array(d, 0, 1),
                         step,
                         dataformats=dataformats)
        return

    if d.ndim == 3:
        if d.shape[0] == 3 and max_channels == 3:  # RGB, written without rescaling
            dataformats = "CHW"
            writer.add_image(f"{tag}_{dataformats}",
                             d,
                             step,
                             dataformats=dataformats)
            return
        # otherwise every channel (up to max_channels) becomes its own grayscale image
        dataformats = "HW"
        for j, d2 in enumerate(d[:max_channels]):
            writer.add_image(f"{tag}_{dataformats}_{j}",
                             rescale_array(d2, 0, 1),
                             step,
                             dataformats=dataformats)
        return

    if d.ndim >= 4:
        # collapse all leading dims into one axis, keeping the trailing 3 spatial dims
        spatial = d.shape[-3:]
        for j, d3 in enumerate(d.reshape([-1] + list(spatial))[:max_channels]):
            d3 = rescale_array(d3, 0, 255)
            add_animated_gif(writer, f"{tag}_HWD_{j}", d3[None], max_frames,
                             1.0, step)
        return
Esempio n. 2
0
    def __call__(self, engine: Engine, action) -> None:
        """Accumulate per-region samples/metrics and flush them on boundaries.

        A boundary is either an ``"epoch"`` action (when ``self.fold_size`` is falsy)
        or a running step count that is a multiple of ``self.fold_size``. On a
        boundary the collected images/metrics are written (subject to
        ``self.images``, ``self.image_interval`` and ``self.add_scalar``) and the
        buffers are reset. Otherwise the current batch is added to the buffers.

        Args:
            engine: engine whose ``state`` provides the iteration/epoch counters,
                ``device``, ``batch`` and ``output``.
            action: event name; ``"epoch"`` marks an epoch-level call.
        """
        state = engine.state

        # When the iteration counter is smaller than one epoch, turn it into a
        # running step count across epochs.
        steps = state.iteration
        if steps < state.epoch_length:
            steps = state.epoch_length * (state.epoch - 1) + steps

        # Work out which "epoch" (if any) this call closes.
        if action == "epoch" and not self.fold_size:
            epoch = state.epoch
        elif self.fold_size and steps % self.fold_size == 0:
            epoch = int(steps / self.fold_size)
        else:
            epoch = None

        if epoch:
            if self.images and epoch % self.image_interval == 0:
                self.write_images(epoch)
            if self.add_scalar:
                self.write_region_metrics(epoch)

        # Boundaries (and every epoch event) clear the buffers and skip collection.
        if action == "epoch" or epoch:
            self.plot_data = {}
            self.metric_data = {}
            return

        device = state.device
        batch = state.batch
        output = state.output

        regions = batch.get("region", [])
        for idx, region in enumerate(regions):
            if torch.is_tensor(region):
                region = region.item()

            # only the first sample seen for a region is kept for plotting
            if self.images and self.plot_data.get(region) is None:
                self.plot_data[region] = [
                    rescale_array(batch["image"][idx][0].detach().cpu().numpy()[np.newaxis], 0, 1),
                    rescale_array(batch["label"][idx].detach().cpu().numpy(), 0, 1),
                    rescale_array(output["pred"][idx].detach().cpu().numpy(), 0, 1),
                ]

            if self.compute_metric:
                tracker = self.metric_data.get(region)
                if tracker is None:
                    tracker = RegionDice()
                    self.metric_data[region] = tracker
                tracker.update(
                    y_pred=output["pred"][idx].to(device),
                    y=batch["label"][idx].to(device),
                    batched=False,
                )
Esempio n. 3
0
    def write_images(self, epoch):
        """Write the per-region samples gathered in ``self.plot_data`` to TensorBoard.

        For 2D samples (3-dim ``CHW`` arrays) every region contributes a text panel
        followed by its three stored arrays, and everything is written as one grid
        image. The panel shows the region id, plus dice mean/stdev when
        ``self.compute_metric`` is set and a metric exists for the region. For 3D
        samples (4-dim arrays) each stored array is plotted separately through
        ``plot_2d_or_3d_image``, with tags prefixed by the distributed rank when a
        process group is initialized.

        Args:
            epoch: value used as the TensorBoard global step.
        """
        if not self.plot_data:
            return

        all_imgs = []
        for region in sorted(self.plot_data):
            region_data = self.plot_data[region]
            if len(region_data[0].shape) == 3:
                height, width = region_data[0].shape[1:]
                # PIL sizes are (width, height) while the samples are (C, H, W);
                # passing the numpy order would transpose non-square panels and make
                # the later np.array(all_imgs) stack fail on mismatched shapes.
                panel = Image.new("RGB", (width, height))
                draw = ImageDraw.Draw(panel)
                text = "region: {}".format(region)
                metric = self.metric_data.get(region)
                if self.compute_metric and metric is not None:
                    text = text + "\ndice: {:.4f}".format(metric.mean())
                    text = text + "\nstdev: {:.4f}".format(metric.stdev())
                draw.multiline_text((10, 10), text, fill=(255, 255, 0))
                # keep only the first (red) channel as a (1, H, W) tile like the samples
                panel = rescale_array(
                    np.rollaxis(np.array(panel), 2, 0)[0][np.newaxis])
                all_imgs.append(panel)
            all_imgs.extend(region_data)

        if len(all_imgs[0].shape) == 3:
            img_tensor = make_grid(tensor=torch.from_numpy(np.array(all_imgs)),
                                   nrow=4,
                                   normalize=True,
                                   pad_value=2)
            self.writer.add_image(tag=f"Deepgrow Regions ({self.tag_name})",
                                  img_tensor=img_tensor,
                                  global_step=epoch)

        if len(all_imgs[0].shape) == 4:
            # `is_initialized` only exists when torch was built with distributed support
            rank = ""
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                rank = "r{}-".format(torch.distributed.get_rank())
            for region in sorted(self.plot_data):
                tags = [
                    f"{rank}region_{region}_{kind}"
                    for kind in ("image", "label", "output")
                ]
                for tag, img in zip(tags, self.plot_data[region]):
                    img = np.moveaxis(img, -3, -1)
                    # keyword arguments so the call is correct whether or not the
                    # plot_2d_or_3d_image in scope has the extra `frame_dim` parameter
                    plot_2d_or_3d_image(img[np.newaxis],
                                        epoch,
                                        self.writer,
                                        index=0,
                                        max_channels=self.max_channels,
                                        max_frames=self.max_frames,
                                        tag=tag)

        self.logger.info(
            "Saved {} Regions {} into Tensorboard at epoch: {}".format(
                len(self.plot_data), sorted([*self.plot_data]), epoch))
        self.writer.flush()
Esempio n. 4
0
    def __call__(self, engine: Engine):
        """Dump every sample of the current batch to ``self.output_dir``.

        For each sample the mean dice (``compute_meandice`` with
        ``include_background=False``) is computed and embedded in the file names,
        together with the region, iteration and batch index. File names get an
        ``r{rank}-`` prefix when a distributed process group is initialized.

        * ``self.save_np``: image (first channel, channel axis kept), label and
          prediction are stored with ``np.savez``.
        * ``self.images`` and a 2D sample: a PNG holding image, label and a
          composite of the prediction with channels 1 and 2 of
          ``output_data["image"]`` (named pos/neg, presumably the guidance maps).
        * ``self.images`` and a 3D sample: three NIfTI files (identity affine)
          after moving axis -3 to the end.

        Args:
            engine: engine whose ``state`` provides ``batch``, ``output``,
                ``device`` and ``iteration``.
        """
        batch_data = engine.state.batch
        output_data = engine.state.output
        device = engine.state.device
        step = engine.state.iteration
        # `is_initialized` only exists when torch was built with distributed support
        tag = ""
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            tag = "r{}-".format(torch.distributed.get_rank())

        for bidx in range(len(batch_data.get("image"))):
            region = batch_data.get("region")[bidx]
            region = region.item() if torch.is_tensor(region) else region

            image = batch_data["image"][bidx][0].detach().cpu().numpy()[
                np.newaxis]
            label = batch_data["label"][bidx].detach().cpu().numpy()
            pred = output_data["pred"][bidx].detach().cpu().numpy()
            dice = compute_meandice(
                y_pred=output_data["pred"][bidx][None].to(device),
                y=batch_data["label"][bidx][None].to(device),
                include_background=False,
            ).mean()
            stem = "{}img_label_pred_{}_{:0>4d}_{:0>2d}_{:.4f}".format(
                tag, region, step, bidx, dice)

            if self.save_np:
                np.savez(os.path.join(self.output_dir, stem), image, label,
                         pred)

            if self.images and len(image.shape) == 3:
                img = make_grid(torch.from_numpy(
                    rescale_array(image, 0, 1)[0]))
                lab = make_grid(torch.from_numpy(
                    rescale_array(label, 0, 1)[0]))

                guidance = output_data["image"][bidx]
                pos = rescale_array(
                    guidance[1].detach().cpu().numpy()[np.newaxis], 0, 1)[0]
                neg = rescale_array(
                    guidance[2].detach().cpu().numpy()[np.newaxis], 0, 1)[0]
                pre = make_grid(
                    torch.from_numpy(
                        np.array([rescale_array(pred, 0, 1)[0], pos, neg])))

                torchvision.utils.save_image(
                    tensor=[img, lab, pre],
                    nrow=3,
                    pad_value=2,
                    fp=os.path.join(self.output_dir, stem + ".png"),
                )
            elif self.images and len(image.shape) == 4:
                samples = {
                    "image": image[0],
                    "label": label[0],
                    "pred": pred[0]
                }
                for sample, volume in samples.items():
                    nifti = nib.Nifti1Image(np.moveaxis(volume, -3, -1),
                                            np.eye(4))
                    nib.save(
                        nifti,
                        os.path.join(
                            self.output_dir,
                            "{}{}_{:0>4d}_{:0>2d}_{:.4f}.nii.gz".format(
                                tag, sample, step, bidx, dice)),
                    )
Esempio n. 5
0
def plot_2d_or_3d_image(
    data: Union[NdarrayTensor, List[NdarrayTensor]],
    step: int,
    writer: SummaryWriter,
    index: int = 0,
    max_channels: int = 1,
    frame_dim: int = -3,
    max_frames: int = 24,
    tag: str = "output",
) -> None:
    """Send one element of ``data`` to TensorBoard as an image, RGB video or GIF.

    Dispatch depends on the number of dims of the selected element:

    * 2 dims: one grayscale image rescaled to [0, 1].
    * 3 dims: one RGB ``CHW`` image when there are 3 channels and
      ``max_channels == 3``; otherwise one grayscale image per channel, up to
      ``max_channels``.
    * 4+ dims: leading dims are merged into a channel axis. A TensorBoardX writer
      with 3 channels and ``max_channels == 3`` receives a video. Otherwise each
      channel, up to ``max_channels``, is rescaled to [0, 255] and plotted as an
      animated GIF.

    Args:
        data: batch in ``NCHW[D]`` layout, or a list of ``CHW[D]`` items.
        step: global step recorded with the plot.
        writer: TensorBoard or TensorBoardX ``SummaryWriter``.
        index: which element of ``data`` to plot.
        max_channels: maximum number of channels to plot.
        frame_dim: dimension used as the frame/time axis for 3D data, expressed
            for the batched ``NCHWD`` layout; default ``-3`` (first spatial dim).
        max_frames: FPS used when a 3D RGB volume is written as a TensorBoardX video.
        tag: base tag for the TensorBoard entry.
    """
    selected = data[index]
    # The selected element has no batch axis, so a positive frame index moves left by one.
    if frame_dim > 0:
        frame_dim -= 1

    img: np.ndarray = (selected.detach().cpu().numpy()
                       if isinstance(selected, torch.Tensor) else selected)

    if img.ndim == 2:
        dataformats = "HW"
        writer.add_image(f"{tag}_{dataformats}",
                         rescale_array(img, 0, 1),
                         step,
                         dataformats=dataformats)
    elif img.ndim == 3:
        if img.shape[0] == 3 and max_channels == 3:  # RGB
            dataformats = "CHW"
            writer.add_image(f"{tag}_{dataformats}",
                             img,
                             step,
                             dataformats=dataformats)
        else:
            dataformats = "HW"
            for channel, plane in enumerate(img[:max_channels]):
                writer.add_image(f"{tag}_{dataformats}_{channel}",
                                 rescale_array(plane, 0, 1),
                                 step,
                                 dataformats=dataformats)
    elif img.ndim >= 4:
        volume = img.reshape([-1] + list(img.shape[-3:]))
        as_rgb_video = (volume.shape[0] == 3 and max_channels == 3
                        and has_tensorboardx
                        and isinstance(writer, SummaryWriterX))
        if as_rgb_video:
            # TensorBoardX videos want the time axis last (NCHWT)
            writer.add_video(tag,
                             np.moveaxis(volume, frame_dim, -1)[None],
                             step,
                             fps=max_frames,
                             dataformats="NCHWT")
        else:
            n_channels = min(max_channels, volume.shape[0])
            # rescale each channel to 0 - 255 independently for visualization
            scaled = np.stack(
                [rescale_array(ch, 0, 255) for ch in volume[:n_channels]],
                axis=0)
            add_animated_gif(writer,
                             f"{tag}_HWD",
                             scaled,
                             max_out=n_channels,
                             frame_dim=frame_dim,
                             global_step=step)