def on_epoch_end(self, data: Data) -> None:
    """Render each requested entry of ``data`` to the screen when an epoch finishes.

    Only keys listed in ``self.inputs`` that are actually present in ``data`` are shown.
    ``ImgData`` entries are painted into a single numpy image and shown once; any other
    entry is iterated and every element is shown in its own window.

    Args:
        data: Data possibly containing images to render.
    """
    for name in self.inputs:
        if name not in data:
            continue
        value = data[name]
        if not isinstance(value, ImgData):
            # Generic iterable of images: show one window per element, titled "<key>_<index>".
            for position, image in enumerate(value):
                show_image(image, title="{}_{}".format(name, position))
                plt.show()
            continue
        # ImgData knows how to lay itself out; paint it at 96 dpi and show the first image of the result.
        canvas = value.paint_numpy(dpi=96)
        plt.imshow(canvas[0])
        plt.axis('off')
        plt.tight_layout()
        plt.show()
def _save_images(self, data: Data) -> None:
    """Save each requested entry of ``data`` to ``self.save_dir`` as PNG file(s).

    Only keys listed in ``self.inputs`` that are present in ``data`` are saved. File names
    are built from the key, ``self.system.mode`` and ``self.system.epoch_idx``:

    * ``ImgData``: painted into one figure, saved as ``<key>_<mode>_epoch_<epoch>.png``.
    * ``Summary`` or a list/tuple of ``Summary``: rendered via ``visualize_logs`` to the
      same file name.
    * Anything else: iterated, and each element is saved separately as
      ``<key>_<mode>_epoch_<epoch>_elem_<idx>.png``.

    Args:
        data: Data possibly containing images to save.
    """

    def _report(path: str) -> None:
        print("FastEstimator-ImageSaver: saved image to {}".format(path))

    def _save_and_close(fig, path: str) -> None:
        # Close the figure even if saving fails, so a failing save does not leave
        # open matplotlib figures behind.
        try:
            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
        finally:
            plt.close(fig)
        _report(path)

    for key in self.inputs:
        if key not in data:
            continue
        imgs = data[key]
        im_path = os.path.join(
            self.save_dir, "{}_{}_epoch_{}.png".format(key, self.system.mode, self.system.epoch_idx))
        if isinstance(imgs, ImgData):
            _save_and_close(imgs.paint_figure(), im_path)
        elif isinstance(imgs, Summary):
            visualize_logs([imgs], save_path=im_path, dpi=self.dpi, verbose=False)
            _report(im_path)
        elif isinstance(imgs, (list, tuple)) and all(isinstance(img, Summary) for img in imgs):
            visualize_logs(imgs, save_path=im_path, dpi=self.dpi, verbose=False)
            _report(im_path)
        else:
            for idx, img in enumerate(imgs):
                f = show_image(img, title=key)
                elem_path = os.path.join(
                    self.save_dir,
                    "{}_{}_epoch_{}_elem_{}.png".format(key, self.system.mode, self.system.epoch_idx, idx))
                _save_and_close(f, elem_path)
def on_epoch_end(self, data: Data) -> None:
    """Save each requested entry of ``data`` to ``self.save_dir`` when an epoch finishes.

    Only keys listed in ``self.inputs`` that are present in ``data`` are saved. An ``ImgData``
    entry is painted into one figure and saved as ``<key>_<mode>_epoch_<epoch>.png``; any
    other entry is iterated and each element is saved as
    ``<key>_<mode>_epoch_<epoch>_elem_<idx>.png``.

    Args:
        data: Data possibly containing images to save.
    """

    def _save_and_close(fig, path: str) -> None:
        # Close the figure even if saving fails, so a failing save does not leave
        # open matplotlib figures behind.
        try:
            plt.savefig(path, dpi=self.dpi, bbox_inches="tight")
        finally:
            plt.close(fig)
        print("FastEstimator-ImageSaver: saved image to {}".format(path))

    for key in self.inputs:
        if key not in data:
            continue
        imgs = data[key]
        if isinstance(imgs, ImgData):
            f = imgs.paint_figure()
            im_path = os.path.join(
                self.save_dir, "{}_{}_epoch_{}.png".format(key, self.system.mode, self.system.epoch_idx))
            _save_and_close(f, im_path)
        else:
            for idx, img in enumerate(imgs):
                f = show_image(img, title=key)
                im_path = os.path.join(
                    self.save_dir,
                    "{}_{}_epoch_{}_elem_{}.png".format(key, self.system.mode, self.system.epoch_idx, idx))
                _save_and_close(f, im_path)
def _display_images(self, data: Data) -> None:
    """Show each requested entry of ``data`` on screen.

    Only keys from ``self.inputs`` that exist in ``data`` are rendered. ``ImgData`` is
    painted to a numpy image and shown once; a ``Summary`` (or a list/tuple made up only of
    ``Summary`` objects) is handed to ``visualize_logs``; anything else is iterated and each
    element is shown in its own window titled ``<key>_<index>``.

    Args:
        data: Data possibly containing images to render.
    """
    for name in self.inputs:
        if name not in data:
            continue
        value = data[name]
        if isinstance(value, ImgData):
            canvas = value.paint_numpy(dpi=96)
            plt.imshow(canvas[0])
            plt.axis('off')
            plt.tight_layout()
            plt.show()
            continue
        if isinstance(value, Summary):
            visualize_logs([value])
            continue
        only_summaries = isinstance(value, (list, tuple)) and all(isinstance(entry, Summary) for entry in value)
        if only_summaries:
            visualize_logs(value)
            continue
        for position, image in enumerate(value):
            show_image(image, title="{}_{}".format(name, position))
            plt.show()
def paint_figure(self,
                 height_gap: int = 100,
                 min_height: int = 200,
                 width_gap: int = 50,
                 min_width: int = 200,
                 dpi: int = 96,
                 save_path: Optional[str] = None) -> plt.Figure:
    """Visualize the current ImgData entries in a matplotlib figure.

    ```python
    d = fe.util.ImgData(y=tf.ones((4,)), x=0.5*tf.ones((4, 32, 32, 3)))
    fig = d.paint_figure()
    plt.show()
    ```

    Args:
        height_gap: How much space to put between each row.
        min_height: The minimum height of a row.
        width_gap: How much space to put between each column.
        min_width: The minimum width of a column.
        dpi: The resolution of the image to display.
        save_path: If provided, the figure will be saved to the given path.

    Returns:
        The handle to the generated matplotlib figure.
    """
    total_width = self._total_width(gap=width_gap, min_width=min_width)
    total_height = self._total_height(gap=height_gap, min_height=min_height)
    # figsize is in inches, so dividing the totals by dpi yields a canvas of
    # total_width x total_height pixels.
    fig = plt.figure(figsize=(total_width / dpi, total_height / dpi), dpi=dpi)
    grid = self._to_grid()
    # TODO - elements with batch size = 1 should be laid out in a grid like for plotting
    for row_idx, (start_height, end_height) in enumerate(self._heights(gap=height_gap, min_height=min_height)):
        # Each grid row appears to be a sequence of (title, elements) pairs: [0] is used as
        # the title and [1] is iterated below.
        row = grid[row_idx]
        batch_size = self._batch_size(row_idx)
        # One GridSpec per row, vertically bounded by the row's [start, end] band (as fractions
        # of the figure height). It has one grid row per batch element and ncols=total_width,
        # so the column spans produced by _widths index straight into this grid.
        gs = GridSpec(nrows=batch_size,
                      ncols=total_width,
                      figure=fig,
                      left=0.0,
                      right=1.0,
                      bottom=start_height / total_height,
                      top=end_height / total_height,
                      hspace=0.05,
                      wspace=0.0)
        # Column extents depend only on the row, so compute them once instead of once per batch element.
        widths = list(self._widths(row=row_idx, gap=width_gap, min_width=min_width))
        for batch_idx in range(batch_size):
            for col_idx, width in enumerate(widths):
                ax = fig.add_subplot(gs[batch_idx, width[0]:width[1]])
                img_stack = [elem[batch_idx] for elem in row[col_idx][1]]
                for idx, img in enumerate(img_stack):
                    # Draw every element of the stack into the same axis, each at increasing
                    # stack_depth; only the first layer of the first batch entry gets a title.
                    show_image(img,
                               axis=ax,
                               fig=fig,
                               title=row[col_idx][0] if (batch_idx == 0 and idx == 0) else None,
                               stack_depth=idx,
                               color_map=self.colormap)
    if save_path:
        # Save this specific figure rather than whatever pyplot considers the "current" figure.
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
    return fig