def image(self, tag, image, step=None):
    """Saves an image summary from onp.ndarray [H,W], [H,W,1], or [H,W,3].

    Greyscale inputs ([H,W] or [H,W,1]) are replicated to three channels so
    that matplotlib encodes them as grey instead of applying a colormap.

    Args:
      tag: str: label for this data
      image: ndarray: [H,W], [H,W,1], [H,W,3] save image in greyscale or colors
      step: int: training step; when None the writer's current step is used,
        otherwise it becomes the writer's current step.
    """
    image = onp.array(image)
    if step is None:
        step = self._step
    else:
        self._step = step
    if image.ndim == 2:
        image = image[:, :, onp.newaxis]
    if image.shape[-1] == 1:
        image = onp.repeat(image, 3, axis=-1)
    image_strio = io.BytesIO()
    plt.imsave(image_strio, image, format='png')
    # plt.imsave encodes PNG output as RGBA, so declare 4 channels (this also
    # matches the RGBA colorspace used for plot summaries).
    image_summary = Summary.Image(
        encoded_image_string=image_strio.getvalue(),
        colorspace=4,
        height=image.shape[0],
        width=image.shape[1])
    summary = Summary(value=[Summary.Value(tag=tag, image=image_summary)])
    self.add_summary(summary, step)
def plot(self, tag, mpl_plt, step=None, close_plot=True):
    """Saves the current matplotlib figure as an image summary.

    Args:
      tag: str: label for this data
      mpl_plt: matplotlib stateful pyplot object with prepared plotting state
      step: int: training step; when None the writer's current step is used,
        otherwise it becomes the writer's current step.
      close_plot: bool: automatically closes plot
    """
    import struct

    if step is None:
        step = self._step
    else:
        self._step = step
    image_buf = io.BytesIO()
    mpl_plt.savefig(image_buf, format='png')
    png_bytes = image_buf.getvalue()
    # Take the size from the PNG itself: the IHDR chunk stores big-endian
    # width and height at bytes 16-24. The canvas size can differ from what
    # savefig writes (e.g. savefig.dpi / savefig.bbox rcParams), and the
    # figure manager may be None for some figures.
    img_w, img_h = struct.unpack('>II', png_bytes[16:24])
    image_summary = Summary.Image(
        encoded_image_string=png_bytes,
        colorspace=4,  # RGBA
        height=img_h,
        width=img_w)
    summary = Summary(value=[Summary.Value(tag=tag, image=image_summary)])
    self.add_summary(summary, step)
    if close_plot:
        mpl_plt.close()
def log_colorimages(tag, images, tagsuffix=''):
    """Build an image Summary from a single color image array.

    Args:
      tag: base tag for the summary value.
      images: one image array accepted by ``plt.imsave`` (despite the plural
        name, only a single image is encoded).
      tagsuffix: string appended to ``tag``.

    Returns:
      A ``Summary`` holding one PNG-encoded image value.
    """
    import io

    img = images
    # PNG data is binary, so it must go into a bytes buffer; a text StringIO
    # rejects bytes on Python 3.
    buf = io.BytesIO()
    plt.imsave(buf, img, format='png')
    img_sum = Summary.Image(
        encoded_image_string=buf.getvalue(),
        height=img.shape[0],
        width=img.shape[1],
        colorspace=4)  # plt.imsave writes RGBA PNGs
    return Summary(
        value=[Summary.Value(tag='%s%s' % (tag, tagsuffix), image=img_sum)])
def log_images(tag, images, tagsuffix=''):
    """Log one greyscale image or a list of greyscale images.

    Each image is scaled by 255, converted to uint8, encoded as an 8-bit
    greyscale ('L') PNG and wrapped in a ``Summary.Image`` with colorspace=1.

    Args:
      tag: base tag. A single image is logged under ``tag + tagsuffix``; a
        list is logged as ``tag/<index> + tagsuffix`` for each element.
      images: a 2-D array, or a list of 2-D arrays. Values are presumably in
        [0, 1]; values outside that range are clipped.
      tagsuffix: string appended to every tag.

    Returns:
      A ``Summary`` with one value per logged image.
    """
    import io

    def convert_to_uint8(arr):
        # Clip before casting so out-of-range values saturate instead of
        # wrapping around in uint8.
        return np.uint8(np.clip(arr * 255, 0, 255))

    def encode(img):
        # Encode one image as a greyscale PNG Summary.Image.
        arr = np.asarray(img)
        # PNG output is binary, so it needs a bytes buffer, not StringIO.
        buf = io.BytesIO()
        Image.fromarray(convert_to_uint8(arr), mode='L').save(buf, 'png')
        # https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/core/framework/summary.proto
        return Summary.Image(
            encoded_image_string=buf.getvalue(),
            height=arr.shape[0],
            width=arr.shape[1],
            colorspace=1)

    if not isinstance(images, list):
        return Summary(value=[
            Summary.Value(tag='%s%s' % (tag, tagsuffix), image=encode(images))
        ])
    return Summary(value=[
        Summary.Value(tag='%s/%d%s' % (tag, nr, tagsuffix), image=encode(img))
        for nr, img in enumerate(images)
    ])
def to_summary(fig, tag):
    """
    Render the matplotlib figure ``fig`` into a TensorFlow Summary object
    that can be passed straight to ``Summary.FileWriter``.

    Example:

      >>> fig, ax = ...    # (as above)
      >>> summary = to_summary(fig, tag='MyFigure/image')
      >>> type(summary)
      tensorflow.core.framework.summary_pb2.Summary
      >>> summary_writer.add_summary(summary, global_step=global_step)

    Args:
      fig: A ``matplotlib.figure.Figure`` object.
      tag (string): The tag name of the created summary.

    Returns:
      A TensorFlow ``Summary`` protobuf whose single value is the figure
      rendered as an RGBA PNG image summary.

    Raises:
      TypeError: if ``tag`` is not a string.
    """
    if not isinstance(tag, six.string_types):
        raise TypeError("tag must be a string type")

    # Creating an Agg canvas re-binds fig.canvas; remember the current one so
    # it can be put back on every exit path.
    previous_canvas = fig.canvas
    try:
        agg = FigureCanvasAgg(fig)
        agg.draw()
        width, height = agg.get_width_height()

        buf = BytesIO()
        agg.print_png(buf)
        encoded = buf.getvalue()
        buf.close()

        image_value = Summary.Value(
            tag=tag,
            image=Summary.Image(
                height=height,
                width=width,
                colorspace=4,  # RGB-A
                encoded_image_string=encoded),
        )
        return Summary(value=[image_value])
    finally:
        fig.canvas = previous_canvas
def log_plot(self, tag, figure, global_step):
    """Write a figure to ``self.writer`` as a PNG image summary and flush.

    Args:
      tag: tag name for the image summary value.
      figure: object with a matplotlib-style ``savefig`` method.
      global_step: step passed to ``self.writer.add_summary``.
    """
    plot_buf = io.BytesIO()
    figure.savefig(plot_buf, format='png')
    # Grab the encoded bytes before handing the buffer to PIL.
    png_bytes = plot_buf.getvalue()
    plot_buf.seek(0)
    # Image.open parses the header lazily; reading ``size`` avoids decoding
    # every pixel into an array just to learn the dimensions, and the
    # context manager releases the PIL image.
    with Image.open(plot_buf) as img:
        width, height = img.size
    img_summary = Summary.Image(encoded_image_string=png_bytes,
                                height=height,
                                width=width)
    summary = Summary()
    summary.value.add(tag=tag, image=img_summary)
    self.writer.add_summary(summary, global_step=global_step)
    self.writer.flush()
def log_image(file_writer, tensor, epoch_no, tag):
    """Write one image array to ``file_writer`` as a PNG image summary.

    Args:
      file_writer: writer exposing ``add_summary`` and ``flush``.
      tensor: numpy array of shape (height, width, channels). The ``+ 1``
        offset suggests values in [-1, 1] (e.g. tanh output) -- TODO confirm
        with callers.
      epoch_no: step passed to ``file_writer.add_summary``.
      tag: tag name for the summary value.
    """
    import io

    import numpy as np

    height, width, channel = tensor.shape
    # Map [-1, 1] onto [0, 255]: (x + 1) spans [0, 2], so scale by 127.5.
    # Multiplying by 255 would reach 510 and wrap around in uint8; clipping
    # makes stray out-of-range values saturate instead.
    pixels = np.clip((np.asarray(tensor) + 1) * 127.5, 0, 255).astype('uint8')
    if channel == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the axis.
        pixels = pixels[:, :, 0]
    image = Image.fromarray(pixels)
    output = io.BytesIO()
    image.save(output, format='PNG')
    image_string = output.getvalue()
    output.close()
    tf_img = Summary.Image(height=height,
                           width=width,
                           colorspace=channel,
                           encoded_image_string=image_string)
    summary = Summary(value=[Summary.Value(tag=tag, image=tf_img)])
    file_writer.add_summary(summary, epoch_no)
    file_writer.flush()