Beispiel #1
0
 def on_epoch_end(self, data: Data) -> None:
     """Render every requested input that is present in `data` on screen.

     Args:
         data: Container that may hold values under the keys listed in `self.inputs`.
     """
     for key in self.inputs:
         if key not in data:
             continue
         images = data[key]
         if not isinstance(images, ImgData):
             # Anything that is not an ImgData is iterated and drawn one element per figure.
             for position, image in enumerate(images):
                 show_image(image, title="{}_{}".format(key, position))
                 plt.show()
             continue
         # An ImgData is rasterized first; only the first entry of the painted result is displayed.
         painted = images.paint_numpy(dpi=96)
         plt.imshow(painted[0])
         plt.axis('off')
         plt.tight_layout()
         plt.show()
Beispiel #2
0
 def _save_images(self, data: Data) -> None:
     """Save every requested input that is present in `data` to `self.save_dir` as PNG files.

     `ImgData` values are painted into a single figure, `Summary` values (alone or as a non-empty list/tuple) are
     rendered through `visualize_logs`, and anything else is iterated and saved one element per file.

     Args:
         data: Container that may hold values under the keys listed in `self.inputs`.
     """
     for key in self.inputs:
         if key in data:
             imgs = data[key]
             im_path = os.path.join(self.save_dir,
                                    "{}_{}_epoch_{}.png".format(key, self.system.mode, self.system.epoch_idx))
             if isinstance(imgs, ImgData):
                 f = imgs.paint_figure()
                 plt.savefig(im_path, dpi=self.dpi, bbox_inches="tight")
                 plt.close(f)
                 print("FastEstimator-ImageSaver: saved image to {}".format(im_path))
             elif isinstance(imgs, Summary):
                 visualize_logs([imgs], save_path=im_path, dpi=self.dpi, verbose=False)
                 print("FastEstimator-ImageSaver: saved image to {}".format(im_path))
             # `all()` is vacuously True for an empty sequence, so require at least one element; otherwise an
             # empty list of images would be handed to `visualize_logs` as if it were a list of summaries.
             elif isinstance(imgs, (list, tuple)) and imgs and all(isinstance(img, Summary) for img in imgs):
                 visualize_logs(imgs, save_path=im_path, dpi=self.dpi, verbose=False)
                 print("FastEstimator-ImageSaver: saved image to {}".format(im_path))
             else:
                 for idx, img in enumerate(imgs):
                     f = show_image(img, title=key)
                     # Each element gets its own file, distinguished by its position in the collection.
                     im_path = os.path.join(
                         self.save_dir,
                         "{}_{}_epoch_{}_elem_{}.png".format(key, self.system.mode, self.system.epoch_idx, idx))
                     plt.savefig(im_path, dpi=self.dpi, bbox_inches="tight")
                     plt.close(f)
                     print("FastEstimator-ImageSaver: saved image to {}".format(im_path))
 def on_epoch_end(self, data: Data) -> None:
     """Write every requested input that is present in `data` to `self.save_dir` as PNG files.

     An `ImgData` value is painted into one figure and stored in a single file; any other value is iterated and
     each element is stored in its own file.

     Args:
         data: Container that may hold values under the keys listed in `self.inputs`.
     """
     for key in self.inputs:
         if key not in data:
             continue
         images = data[key]
         if isinstance(images, ImgData):
             figure = images.paint_figure()
             file_name = "{}_{}_epoch_{}.png".format(key, self.system.mode, self.system.epoch_idx)
             target = os.path.join(self.save_dir, file_name)
             plt.savefig(target, dpi=self.dpi, bbox_inches="tight")
             plt.close(figure)
             print("FastEstimator-ImageSaver: saved image to {}".format(target))
             continue
         for position, image in enumerate(images):
             figure = show_image(image, title=key)
             # The element index is part of the file name so successive elements do not overwrite each other.
             file_name = "{}_{}_epoch_{}_elem_{}.png".format(key, self.system.mode, self.system.epoch_idx, position)
             target = os.path.join(self.save_dir, file_name)
             plt.savefig(target, dpi=self.dpi, bbox_inches="tight")
             plt.close(figure)
             print("FastEstimator-ImageSaver: saved image to {}".format(target))
Beispiel #4
0
    def _display_images(self, data: Data) -> None:
        """A method to render images to the screen.

        `ImgData` values are rasterized and shown as a single image, `Summary` values (alone or as a non-empty
        list/tuple) are rendered through `visualize_logs`, and anything else is iterated and shown one element
        per figure.

        Args:
            data: Data possibly containing images to render.
        """
        for key in self.inputs:
            if key in data:
                imgs = data[key]
                if isinstance(imgs, ImgData):
                    # Only the first entry of the painted result is displayed.
                    fig = imgs.paint_numpy(dpi=96)
                    plt.imshow(fig[0])
                    plt.axis('off')
                    plt.tight_layout()
                    plt.show()
                elif isinstance(imgs, Summary):
                    visualize_logs([imgs])
                # `all()` is vacuously True for an empty sequence, so require at least one element; otherwise
                # an empty list of images would be handed to `visualize_logs` as if it were a list of summaries.
                elif isinstance(imgs, (list, tuple)) and imgs and all(isinstance(img, Summary) for img in imgs):
                    visualize_logs(imgs)
                else:
                    for idx, img in enumerate(imgs):
                        show_image(img, title="{}_{}".format(key, idx))
                        plt.show()
Beispiel #5
0
    def paint_figure(self,
                     height_gap: int = 100,
                     min_height: int = 200,
                     width_gap: int = 50,
                     min_width: int = 200,
                     dpi: int = 96,
                     save_path: Optional[str] = None) -> plt.Figure:
        """Visualize the current ImgData entries in a matplotlib figure.

        ```python
        d = fe.util.ImgData(y=tf.ones((4,)), x=0.5*tf.ones((4, 32, 32, 3)))
        fig = d.paint_figure()
        plt.show()
        ```

        Args:
            height_gap: How much space to put between each row.
            min_height: The minimum height of a row.
            width_gap: How much space to put between each column.
            min_width: The minimum width of a column.
            dpi: The resolution of the image to display.
            save_path: If provided, the figure will be saved to the given path.

        Returns:
            The handle to the generated matplotlib figure.
        """
        total_width = self._total_width(gap=width_gap, min_width=min_width)
        total_height = self._total_height(gap=height_gap, min_height=min_height)

        # The totals are divided by dpi because matplotlib expects the figure size in inches.
        fig = plt.figure(figsize=(total_width / dpi, total_height / dpi), dpi=dpi)

        grid = self._to_grid()
        # TODO - elements with batch size = 1 should be laid out in a grid like for plotting
        for row_idx, (start_height, end_height) in enumerate(self._heights(gap=height_gap, min_height=min_height)):
            row = grid[row_idx]
            batch_size = self._batch_size(row_idx)
            # One GridSpec per row: a grid line per batch element vertically, one grid column per unit of width.
            gs = GridSpec(nrows=batch_size,
                          ncols=total_width,
                          figure=fig,
                          left=0.0,
                          right=1.0,
                          bottom=start_height / total_height,
                          top=end_height / total_height,
                          hspace=0.05,
                          wspace=0.0)
            # The column spans are requested with the same arguments for every batch element, so resolve them
            # once per row instead of once per batch element.
            col_spans = list(self._widths(row=row_idx, gap=width_gap, min_width=min_width))
            for batch_idx in range(batch_size):
                for col_idx, width in enumerate(col_spans):
                    ax = fig.add_subplot(gs[batch_idx, width[0]:width[1]])
                    col_title, col_elems = row[col_idx][0], row[col_idx][1]
                    img_stack = [elem[batch_idx] for elem in col_elems]
                    for idx, img in enumerate(img_stack):
                        # Only the first layer of the first batch element carries the column title.
                        show_image(img,
                                   axis=ax,
                                   fig=fig,
                                   title=col_title if (batch_idx == 0 and idx == 0) else None,
                                   stack_depth=idx,
                                   color_map=self.colormap)
        if save_path:
            # Save through the figure handle so the output never depends on which figure pyplot currently
            # considers active.
            fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
        return fig