def __init__(self, log_dir):
    """Set up a writer that records tfevents files under `log_dir`.

    Args:
      log_dir: path to record tfevents files in. The directory, and any
        missing parent directories, are created when it does not exist.
    """
    directory_present = gfile.isdir(log_dir)
    if not directory_present:
        gfile.makedirs(log_dir)
    # NOTE(review): the positional values 10, 120 and None are presumably the
    # queue size, the flush interval in seconds and the filename suffix --
    # confirm against EventFileWriter's signature.
    writer = EventFileWriter(log_dir, 10, 120, None)
    self._event_writer = writer
    # Flipped to True once the writer has been closed.
    self._closed = False
def __init__(self) -> None:
    """Create the event-file writer and the per-tag metadata bookkeeping."""
    super().__init__()
    base_path = tensorboard.get_base_path({})
    self.writer = EventFileWriter(logdir=str(base_path), filename_suffix=None)
    self.createSummary = tf.Summary

    # _seen_summary_tags is vendored from TensorFlow:
    # tensorflow/python/summary/writer/writer.py
    # It holds the tags of the Summary Values encountered so far. Only the
    # first Summary Value seen for a tag keeps its metadata property (a
    # SummaryMetadata proto); the metadata of later values carrying an
    # already-seen tag is stripped, which saves space.
    self._seen_summary_tags: Set[str] = set()
def __init__(self, log_dir, enable=True):
    """Set up a writer that records tfevents files under `log_dir`.

    Args:
      log_dir: path to record tfevents files in. The directory, and any
        missing parent directories, are created when it does not exist.
      enable: bool: if False don't actually write or flush data. Used in
        multihost training.
    """
    directory_present = tf.io.gfile.isdir(log_dir)
    if not directory_present:
        tf.io.gfile.makedirs(log_dir)
    # NOTE(review): the positional values 10, 120 and None are presumably the
    # queue size, the flush interval in seconds and the filename suffix --
    # confirm against EventFileWriter's signature.
    writer = EventFileWriter(log_dir, 10, 120, None)
    self._event_writer = writer
    # Step reused by logging calls that do not pass an explicit step.
    self._step = 0
    self._closed = False
    self._enabled = enable
def __init__(self, log_dir, enable=True):
    """Create a new SummaryWriter that logs into a fresh run sub-directory.

    The events are not written directly into `log_dir`: a sub-directory
    named after the number of entries already present in `log_dir` is used,
    and its path is exposed as `self.log_dir`.

    Args:
      log_dir: parent path to record tfevents files in. It is created,
        together with missing parent directories, when it does not exist.
      enable: bool: if False don't actually write or flush data. Used in
        multihost training.
    """
    # If needed, create log_dir directory as well as missing parent
    # directories.
    if not tf.io.gfile.isdir(log_dir):
        tf.io.gfile.makedirs(log_dir)
    # List the directory through the same filesystem layer that created it;
    # os.listdir only understands local paths.
    run_index = len(tf.io.gfile.listdir(log_dir))
    log_dir = os.path.join(log_dir, "{}".format(run_index))
    logging.info("Created new logger at: {}".format(log_dir))
    self.log_dir = log_dir
    # NOTE(review): the positional values 10, 120 and None are presumably the
    # queue size, the flush interval in seconds and the filename suffix --
    # confirm against EventFileWriter's signature.
    self._event_writer = EventFileWriter(log_dir, 10, 120, None)
    # Step reused by logging calls that do not pass an explicit step.
    self._step = 0
    self._closed = False
    self._enabled = enable
    # Lowest loss recorded so far; starts at +inf so any finite loss is
    # an improvement.
    self._best_loss = float("inf")
def __init__(self, logdir, graph=None, max_queue=10, flush_secs=120,
             graph_def=None, filename_suffix=None):
    """Creates a `FileWriter` and an event file.

    Constructing the summary writer creates a new event file in `logdir`.
    That file receives the `Event` protocol buffers built by calls to
    `add_summary()`, `add_session_log()`, `add_event()` or `add_graph()`.

    A `Graph` passed to the constructor is added to the event file, which
    is equivalent to calling `add_graph()` later. TensorBoard picks the
    graph up from the file and displays it so it can be explored
    interactively. Usually the graph of the launching session is passed:

    ```python
    ...create a graph...
    # Launch the graph in a session.
    sess = tf.Session()
    # Create a summary writer, add the 'graph' to the event file.
    writer = tf.summary.FileWriter(<some-directory>, sess.graph)
    ```

    The remaining arguments control the asynchronous writes to the event
    file:

    * `flush_secs`: How often, in seconds, to flush the added summaries
      and events to disk.
    * `max_queue`: Maximum number of summaries or events pending to be
      written to disk before one of the 'add' calls block.

    Args:
      logdir: A string. Directory where event file will be written.
      graph: A `Graph` object, such as `sess.graph`.
      max_queue: Integer. Size of the queue for pending events and
        summaries.
      flush_secs: Number. How often, in seconds, to flush the pending
        events and summaries to disk.
      graph_def: DEPRECATED: Use the `graph` argument instead.
      filename_suffix: A string. Every event file's name is suffixed with
        `suffix`.
    """
    super(FileWriter, self).__init__(
        EventFileWriter(logdir, max_queue, flush_secs, filename_suffix),
        graph,
        graph_def)
def __init__(self, logdir, graph=None, max_queue=10, flush_secs=120,
             graph_def=None):
    """Build the event writer, initialise the parent and proxy its API.

    Args:
      logdir: directory handed to the `EventFileWriter`.
      graph: forwarded unchanged to the parent constructor.
      max_queue: forwarded to the `EventFileWriter`.
      flush_secs: forwarded to the `EventFileWriter`.
      graph_def: forwarded unchanged to the parent constructor.
    """
    super(LegacySummaryWriter, self).__init__(
        EventFileWriter(logdir, max_queue, flush_secs), graph, graph_def)

    # Proxy the event_writer public API onto the LegacySummaryWriter;
    # this gives consistency with the tf.train.SummaryWriter API.
    proxied_methods = ('get_logdir', 'add_event', 'flush', 'close', 'reopen')
    for method_name in proxied_methods:
        setattr(self, method_name, getattr(self.event_writer, method_name))
class SummaryWriter(object):
    """Saves data in event and summary protos for tensorboard."""

    def __init__(self, log_dir):
        """Create a new SummaryWriter.

        Args:
          log_dir: path to record tfevents files in.
        """
        # If needed, create log_dir directory as well as missing parent
        # directories.
        if not gfile.isdir(log_dir):
            gfile.makedirs(log_dir)
        self._event_writer = EventFileWriter(log_dir, 10, 120, None)
        # Step reused by logging calls that do not pass an explicit step.
        self._step = 0
        self._closed = False

    def _resolve_step(self, step):
        """Returns the step to log at and remembers an explicitly given one."""
        if step is None:
            return self._step
        self._step = step
        return step

    @staticmethod
    def _encode_utf8(item):
        """Returns `item` as UTF-8 bytes; bytes are passed through as-is."""
        if isinstance(item, bytes):
            return item
        return item.encode(encoding='utf_8')

    def add_summary(self, summary, step):
        """Wraps `summary` in a timestamped Event and queues it for writing.

        Args:
          summary: Summary proto to record.
          step: training step stored on the event; left unset when None.
        """
        event = event_pb2.Event(summary=summary)
        event.wall_time = time.time()
        if step is not None:
            event.step = int(step)
        self._event_writer.add_event(event)

    def close(self):
        """Close SummaryWriter. Final!"""
        # A missing `_closed` attribute means __init__ never completed (for
        # example the event writer failed to open), so there is nothing to
        # close. This keeps __del__ from raising on such instances.
        if getattr(self, '_closed', True):
            return
        self._event_writer.close()
        self._closed = True
        del self._event_writer

    def __del__(self):
        # Safe on partially constructed or already closed instances, see
        # close().
        self.close()

    def flush(self):
        """Flushes the underlying event writer."""
        self._event_writer.flush()

    def scalar(self, tag, value, step=None):
        """Saves scalar value.

        Args:
          tag: str: label for this data
          value: int/float: number to log
          step: int: training step; defaults to the last step used
        """
        value = float(onp.array(value))
        step = self._resolve_step(step)
        summary = Summary(value=[Summary.Value(tag=tag, simple_value=value)])
        self.add_summary(summary, step)

    def image(self, tag, image, step=None):
        """Saves RGB image summary from onp.ndarray [H,W], [H,W,1], or [H,W,3].

        Args:
          tag: str: label for this data
          image: ndarray: [H,W], [H,W,1], [H,W,3] save image in greyscale or
            colors
          step: int: training step; defaults to the last step used
        """
        image = onp.array(image)
        step = self._resolve_step(step)
        if len(onp.shape(image)) == 2:
            image = image[:, :, onp.newaxis]
        if onp.shape(image)[-1] == 1:
            # Replicate the single channel so the PNG is encoded as RGB.
            image = onp.repeat(image, 3, axis=-1)
        image_strio = io.BytesIO()
        plt.imsave(image_strio, image, format='png')
        image_summary = Summary.Image(
            encoded_image_string=image_strio.getvalue(),
            colorspace=3,
            height=image.shape[0],
            width=image.shape[1])
        summary = Summary(value=[Summary.Value(tag=tag, image=image_summary)])
        self.add_summary(summary, step)

    def images(self, tag, images, step=None, rows=None, cols=None):
        """Saves (rows, cols) tiled images from onp.ndarray.

        If only one of rows or cols is given, the other is derived from the
        size of the image batch by integer division. If neither is given,
        a single row holding the whole batch (rows=1, cols=N) is requested.
        This truncates the image batch rather than padding if it doesn't
        fill the final row.

        Args:
          tag: str: label for this data
          images: ndarray: [N,H,W,1] or [N,H,W,3] to tile in 2d
          step: int: training step; defaults to the last step used
          rows: int: number of rows in tile
          cols: int: number of columns in tile
        """
        images = onp.array(images)
        step = self._resolve_step(step)
        n_images = onp.shape(images)[0]
        if rows is None and cols is None:
            rows = 1
            cols = n_images
        elif rows is None:
            rows = n_images // cols
        elif cols is None:
            cols = n_images // rows
        tiled_images = _pack_images(images, rows, cols)
        self.image(tag, tiled_images, step=step)

    def plot(self, tag, mpl_plt, step=None, close_plot=True):
        """Saves matplotlib plot output to summary image.

        Args:
          tag: str: label for this data
          mpl_plt: matplotlib stateful pyplot object with prepared plotting
            state
          step: int: training step; defaults to the last step used
          close_plot: bool: automatically closes plot
        """
        step = self._resolve_step(step)
        fig_manager = mpl_plt.get_current_fig_manager()
        img_w, img_h = fig_manager.canvas.get_width_height()
        image_buf = io.BytesIO()
        mpl_plt.savefig(image_buf, format='png')
        image_summary = Summary.Image(
            encoded_image_string=image_buf.getvalue(),
            colorspace=4,  # RGBA
            height=img_h,
            width=img_w)
        summary = Summary(value=[Summary.Value(tag=tag, image=image_summary)])
        self.add_summary(summary, step)
        if close_plot:
            mpl_plt.close()

    def audio(self, tag, audiodata, step=None, sample_rate=44100):
        """Saves audio.

        NB: single channel only right now.

        Args:
          tag: str: label for this data
          audiodata: ndarray [Nsamples,]: data between (-1.0,1.0) to save as
            wave; values outside that range are clipped
          step: int: training step; defaults to the last step used
          sample_rate: sample rate of passed in audio buffer

        Raises:
          ValueError: if the squeezed audio data is not one-dimensional.
        """
        audiodata = onp.array(audiodata)
        step = self._resolve_step(step)
        audiodata = onp.clip(onp.squeeze(audiodata), -1, 1)
        if audiodata.ndim != 1:
            raise ValueError('Audio data must be 1D.')
        # Scale [-1, 1] to the signed 16-bit sample range.
        sample_list = (32767.0 * audiodata).astype(int).tolist()
        with io.BytesIO() as wio:
            # The wave header is finalised when the writer is closed, so the
            # bytes are read only after the inner `with` has exited.
            with wave.open(wio, 'wb') as wav_buf:
                wav_buf.setnchannels(1)
                wav_buf.setsampwidth(2)
                wav_buf.setframerate(sample_rate)
                enc = b''.join([struct.pack('<h', v) for v in sample_list])
                wav_buf.writeframes(enc)
            encoded_audio_bytes = wio.getvalue()
        audio = Summary.Audio(
            sample_rate=sample_rate,
            num_channels=1,
            length_frames=len(sample_list),
            encoded_audio_string=encoded_audio_bytes,
            content_type='audio/wav')
        summary = Summary(value=[Summary.Value(tag=tag, audio=audio)])
        self.add_summary(summary, step)

    def histogram(self, tag, values, bins, step=None):
        """Saves histogram of values.

        Args:
          tag: str: label for this data
          values: ndarray: will be flattened by this routine
          bins: number of bins in histogram, or array of bins for
            onp.histogram
          step: int: training step; defaults to the last step used
        """
        step = self._resolve_step(step)
        values = onp.array(values)
        bins = onp.array(bins)
        values = onp.reshape(values, -1)
        counts, limits = onp.histogram(values, bins=bins)
        # boundary logic: trim the empty buckets at both ends.
        # The comparison result is summed as-is; comparison ufuncs have no
        # integer-output loop, so forcing dtype=int32 on onp.greater fails
        # on current NumPy.
        cum_counts = onp.cumsum(onp.greater(counts, 0))
        start, end = onp.searchsorted(
            cum_counts, [0, cum_counts[-1] - 1], side='right')
        start, end = int(start), int(end) + 1
        counts = (counts[start - 1:end] if start > 0
                  else onp.concatenate([[0], counts[:end]]))
        limits = limits[start:end + 1]
        sum_sq = values.dot(values)
        histo = HistogramProto(
            min=values.min(),
            max=values.max(),
            num=len(values),
            sum=values.sum(),
            sum_squares=sum_sq,
            bucket_limit=limits.tolist(),
            bucket=counts.tolist())
        summary = Summary(value=[Summary.Value(tag=tag, histo=histo)])
        self.add_summary(summary, step)

    def text(self, tag, textdata, step=None):
        """Saves a text summary.

        Args:
          tag: str: label for this data
          textdata: string or bytes, or 1D/2D list/numpy array of strings
          step: int: training step; defaults to the last step used

        Raises:
          ValueError: if array-like `textdata` is neither 1D nor 2D.

        Note: markdown formatting is rendered by tensorboard.
        """
        step = self._resolve_step(step)
        if isinstance(textdata, (str, bytes)):
            encoded = [self._encode_utf8(textdata)]
            shape = (1,)
        else:
            textdata = onp.array(textdata)  # convert lists, jax arrays, etc.
            shape = onp.shape(textdata)
            if len(shape) not in (1, 2):
                raise ValueError(
                    'Text data must be a string or a 1D/2D array of strings, '
                    'got an array of rank %d.' % len(shape))
            encoded = [self._encode_utf8(td)
                       for td in onp.reshape(textdata, -1)]
        smd = SummaryMetadata(
            plugin_data=SummaryMetadata.PluginData(plugin_name='text'))
        tensor = tf.make_tensor_proto(values=encoded, shape=shape)
        summary = Summary(
            value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
        self.add_summary(summary, step)
class TFWriter(tensorboard.MetricWriter):
    """
    TFWriter uses tensorflow file writers and summary operations to write
    out tfevent files containing scalar batch metrics.
    """

    def __init__(self) -> None:
        """Create the event-file writer and the per-tag metadata bookkeeping."""
        super().__init__()
        base_path = tensorboard.get_base_path({})
        self.writer = EventFileWriter(logdir=str(base_path), filename_suffix=None)
        self.createSummary = tf.Summary

        # _seen_summary_tags is vendored from TensorFlow:
        # tensorflow/python/summary/writer/writer.py
        # It holds the tags of the Summary Values encountered so far. Only
        # the first Summary Value seen for a tag keeps its metadata property
        # (a SummaryMetadata proto); the metadata of later values carrying an
        # already-seen tag is stripped, which saves space.
        self._seen_summary_tags: Set[str] = set()

    def add_scalar(self, name: str, value: Union[int, float, np.number], step: int) -> None:
        """Record one scalar `value` under tag `name` at training `step`."""
        scalar_summary = self.createSummary()
        entry = scalar_summary.value.add()
        entry.tag = name
        entry.simple_value = value
        self._add_summary(scalar_summary, step)

    def _add_summary(self,
                     summary: Union[str, summary_pb2.Summary],
                     global_step: Optional[int] = None) -> None:
        """
        _add_summary is vendored from TensorFlow:
        tensorflow/python/summary/writer/writer.py

        Adds a `Summary` protocol buffer to the event file.

        The summary is wrapped in an `Event` protocol buffer which is then
        added to the event file. The result of evaluating any summary op
        (via `tf.Session.run` or `tf.Tensor.eval`) can be passed in, as can
        a hand-populated `tf.compat.v1.Summary` protocol buffer; the latter
        is commonly done to report evaluation results in event files.

        Args:
          summary: A `Summary` protocol buffer, optionally serialized as a
            string.
          global_step: Number. Optional global step value to record with
            the summary.
        """
        # NOTE(review): the annotation says `str`, yet only `bytes` input is
        # parsed here -- confirm which one callers actually pass.
        if isinstance(summary, bytes):
            parsed = summary_pb2.Summary()
            parsed.ParseFromString(summary)
            summary = parsed

        # Metadata is stored only on the first value seen for a given tag;
        # it is stripped from every later value with that tag to save space.
        seen_tags = self._seen_summary_tags
        for entry in summary.value:
            if entry.metadata:
                if entry.tag in seen_tags:
                    # This tag has been encountered before. Strip the metadata.
                    entry.ClearField("metadata")
                else:
                    # First value with this tag that carries metadata:
                    # remember the tag so later values get stripped.
                    seen_tags.add(entry.tag)

        self._add_event(event_pb2.Event(summary=summary), global_step)

    def _add_event(self, event: event_pb2.Event, step: Optional[int]) -> None:
        """Timestamp `event`, set its step when given, and queue it."""
        # _add_event is vendored from TensorFlow:
        # tensorflow/python/summary/writer/writer.py
        event.wall_time = time.time()
        if step is not None:
            event.step = int(step)
        self.writer.add_event(event)

    def reset(self) -> None:
        """Close the underlying event writer and reopen it right away."""
        event_writer = self.writer
        event_writer.close()
        event_writer.reopen()
class SummaryWriter(object):
    """Saves data in event and summary protos for tensorboard."""

    def __init__(self, log_dir):
        """Create a new SummaryWriter.

        Args:
          log_dir: path to record tfevents files in.
        """
        # If needed, create log_dir directory as well as missing parent
        # directories.
        if not gfile.isdir(log_dir):
            gfile.makedirs(log_dir)
        self._event_writer = EventFileWriter(log_dir, 10, 120, None)
        self._closed = False

    @staticmethod
    def _encode_utf8(item):
        """Returns `item` as UTF-8 bytes; bytes are passed through as-is."""
        if isinstance(item, bytes):
            return item
        return item.encode(encoding='utf_8')

    def _add_summary(self, summary, step):
        """Wraps `summary` in a timestamped Event and queues it for writing.

        Args:
          summary: Summary proto to record.
          step: training step stored on the event; left unset when None.
        """
        event = event_pb2.Event(summary=summary)
        event.wall_time = time.time()
        if step is not None:
            event.step = int(step)
        self._event_writer.add_event(event)

    def close(self):
        """Close SummaryWriter. Final!"""
        if not self._closed:
            self._event_writer.close()
            self._closed = True
            del self._event_writer

    def flush(self):
        """Flushes the underlying event writer."""
        self._event_writer.flush()

    def scalar(self, tag, value, step):
        """Saves scalar value.

        Args:
          tag: str: label for this data
          value: int/float: number to log
          step: int: training step
        """
        value = float(onp.array(value))
        summary = Summary(value=[Summary.Value(tag=tag, simple_value=value)])
        self._add_summary(summary, step)

    def image(self, tag, image, step):
        """Saves RGB image summary from onp.ndarray [H,W], [H,W,1], or [H,W,3].

        Args:
          tag: str: label for this data
          image: ndarray: [H,W], [H,W,1], [H,W,3] save image in greyscale or
            colors. Pixel values should be in the range [0, 1].
          step: int: training step
        """
        image = onp.array(image)
        if len(onp.shape(image)) == 2:
            image = image[:, :, onp.newaxis]
        if onp.shape(image)[-1] == 1:
            # Replicate the single channel so the PNG is encoded as RGB.
            image = onp.repeat(image, 3, axis=-1)
        image_strio = io.BytesIO()
        plt.imsave(image_strio, image, vmin=0., vmax=1., format='png')
        image_summary = Summary.Image(
            encoded_image_string=image_strio.getvalue(),
            colorspace=3,
            height=image.shape[0],
            width=image.shape[1])
        summary = Summary(value=[Summary.Value(tag=tag, image=image_summary)])
        self._add_summary(summary, step)

    def audio(self, tag, audiodata, step, sample_rate=44100):
        """Saves audio.

        NB: single channel only right now.

        Args:
          tag: str: label for this data
          audiodata: ndarray [Nsamples,]: audio data to be saved as wave.
            The data will be clipped to [-1, 1].
          step: int: training step
          sample_rate: sample rate of passed in audio buffer

        Raises:
          ValueError: if the squeezed audio data is not one-dimensional.
        """
        audiodata = onp.array(audiodata)
        audiodata = onp.clip(onp.squeeze(audiodata), -1, 1)
        if audiodata.ndim != 1:
            raise ValueError('Audio data must be 1D.')
        # convert from [-1, 1] -> [-(2^15-1), 2^15-1]
        sample_list = (32767.0 * audiodata).astype(int).tolist()
        with io.BytesIO() as wio:
            # The wave header is finalised when the writer is closed, so the
            # bytes are read only after the inner `with` has exited.
            with wave.open(wio, 'wb') as wav_buf:
                wav_buf.setnchannels(1)
                wav_buf.setsampwidth(2)
                wav_buf.setframerate(sample_rate)
                enc = b''.join([struct.pack('<h', v) for v in sample_list])
                wav_buf.writeframes(enc)
            encoded_audio_bytes = wio.getvalue()
        audio = Summary.Audio(
            sample_rate=sample_rate,
            num_channels=1,
            length_frames=len(sample_list),
            encoded_audio_string=encoded_audio_bytes,
            content_type='audio/wav')
        summary = Summary(value=[Summary.Value(tag=tag, audio=audio)])
        self._add_summary(summary, step)

    def histogram(self, tag, values, bins, step):
        """Saves histogram of values.

        Args:
          tag: str: label for this data
          values: ndarray: will be flattened by this routine
          bins: number of bins in histogram, or a sequence defining a
            monotonically increasing array of bin edges, including the
            rightmost edge.
          step: int: training step
        """
        values = onp.array(values)
        bins = onp.array(bins)
        values = onp.reshape(values, -1)
        counts, limits = onp.histogram(values, bins=bins)
        # boundary logic: trim the empty buckets at both ends.
        # TODO(flax-dev) Investigate whether this logic can be simplified.
        # The comparison result is summed as-is; comparison ufuncs have no
        # integer-output loop, so forcing dtype=int32 on onp.greater fails
        # on current NumPy.
        cum_counts = onp.cumsum(onp.greater(counts, 0))
        start, end = onp.searchsorted(
            cum_counts, [0, cum_counts[-1] - 1], side='right')
        start, end = int(start), int(end) + 1
        counts = (counts[start - 1:end] if start > 0
                  else onp.concatenate([[0], counts[:end]]))
        limits = limits[start:end + 1]
        sum_sq = values.dot(values)
        histo = HistogramProto(
            min=values.min(),
            max=values.max(),
            num=len(values),
            sum=values.sum(),
            sum_squares=sum_sq,
            bucket_limit=limits.tolist(),
            bucket=counts.tolist())
        summary = Summary(value=[Summary.Value(tag=tag, histo=histo)])
        self._add_summary(summary, step)

    def text(self, tag, textdata, step):
        """Saves a text summary.

        Args:
          tag: str: label for this data
          textdata: string or bytes, or 1D/2D list/numpy array of strings
          step: int: training step

        Raises:
          ValueError: if array-like `textdata` is neither 1D nor 2D.

        Note: markdown formatting is rendered by tensorboard.
        """
        if isinstance(textdata, (str, bytes)):
            encoded = [self._encode_utf8(textdata)]
            shape = (1,)
        else:
            textdata = onp.array(textdata)  # convert lists, jax arrays, etc.
            shape = onp.shape(textdata)
            if len(shape) not in (1, 2):
                raise ValueError(
                    'Text data must be a string or a 1D/2D array of strings, '
                    'got an array of rank %d.' % len(shape))
            encoded = [self._encode_utf8(td)
                       for td in onp.reshape(textdata, -1)]
        smd = SummaryMetadata(
            plugin_data=SummaryMetadata.PluginData(plugin_name='text'))
        tensor = tf.make_tensor_proto(values=encoded, shape=shape)
        summary = Summary(
            value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
        self._add_summary(summary, step)
def __init__(self, logdir, graph=None, max_queue=10, flush_secs=120,
             graph_def=None, filename_suffix=None, session=None):
    """Creates a `FileWriter`, optionally shared within the given session.

    Typically, constructing a file writer creates a new event file in
    `logdir`. That file receives the `Event` protocol buffers built by
    calls to `add_summary()`, `add_session_log()`, `add_event()` or
    `add_graph()`.

    A `Graph` passed to the constructor is added to the event file, which
    is equivalent to calling `add_graph()` later. TensorBoard picks the
    graph up from the file and displays it so it can be explored
    interactively. Usually the graph of the launching session is passed:

    ```python
    ...create a graph...
    # Launch the graph in a session.
    sess = tf.Session()
    # Create a summary writer, add the 'graph' to the event file.
    writer = tf.summary.FileWriter(<some-directory>, sess.graph)
    ```

    Passing `session` turns the returned `FileWriter` into a compatibility
    layer over the new graph-based summaries (`tf.contrib.summary`).
    Crucially, the underlying writer resource and events file are then
    shared with any other `FileWriter` using the same `session` and
    `logdir`, and with any `tf.contrib.summary.SummaryWriter` in this
    session using the same shared resource name (which by default is
    scoped to the logdir). If no such resource exists, one is created from
    the remaining constructor arguments; if one already exists, those
    arguments are ignored. In either case, ops are added to
    `session.graph` to control the underlying file writer resource. See
    `tf.contrib.summary` for more details.

    Args:
      logdir: A string. Directory where event file will be written.
      graph: A `Graph` object, such as `sess.graph`.
      max_queue: Integer. Size of the queue for pending events and
        summaries.
      flush_secs: Number. How often, in seconds, to flush the pending
        events and summaries to disk.
      graph_def: DEPRECATED: Use the `graph` argument instead.
      filename_suffix: A string. Every event file's name is suffixed with
        `suffix`.
      session: A `tf.Session` object. See details above.

    Raises:
      RuntimeError: If called with eager execution enabled.

    @compatibility(eager)
    `FileWriter` is not compatible with eager execution. To write
    TensorBoard summaries under eager execution, use `tf.contrib.summary`
    instead.
    @end_compatibility
    """
    if context.executing_eagerly():
        raise RuntimeError(
            "tf.summary.FileWriter is not compatible with eager execution. "
            "Use tf.contrib.summary instead.")
    if session is None:
        # Plain writer that owns its own event file.
        event_writer = EventFileWriter(
            logdir, max_queue, flush_secs, filename_suffix)
    else:
        # Session-backed writer.
        event_writer = EventFileWriterV2(
            session, logdir, max_queue, flush_secs, filename_suffix)
    super(FileWriter, self).__init__(event_writer, graph, graph_def)
class SummaryWriter(object):
    """Saves data in event and summary protos for tensorboard."""

    def __init__(self, log_dir, enable=True):
        """Create a new SummaryWriter that logs into a fresh run sub-directory.

        The events are not written directly into `log_dir`: a sub-directory
        named after the number of entries already present in `log_dir` is
        used, and its path is exposed as `self.log_dir`.

        Args:
          log_dir: parent path to record tfevents files in. It is created,
            together with missing parent directories, when it does not
            exist.
          enable: bool: if False don't actually write or flush data. Used in
            multihost training.
        """
        # If needed, create log_dir directory as well as missing parent
        # directories.
        if not tf.io.gfile.isdir(log_dir):
            tf.io.gfile.makedirs(log_dir)
        # List the directory through the same filesystem layer that created
        # it; os.listdir only understands local paths.
        run_index = len(tf.io.gfile.listdir(log_dir))
        log_dir = os.path.join(log_dir, "{}".format(run_index))
        logging.info("Created new logger at: {}".format(log_dir))
        self.log_dir = log_dir
        self._event_writer = EventFileWriter(log_dir, 10, 120, None)
        # Step reused by logging calls that do not pass an explicit step.
        self._step = 0
        self._closed = False
        self._enabled = enable
        # Lowest loss seen by checkpoint(); +inf so any finite loss improves.
        self._best_loss = float("inf")

    def _resolve_step(self, step):
        """Returns the step to log at and remembers an explicitly given one."""
        if step is None:
            return self._step
        self._step = step
        return step

    @staticmethod
    def _encode_utf8(item):
        """Returns `item` as UTF-8 bytes; bytes are passed through as-is."""
        if isinstance(item, bytes):
            return item
        return item.encode(encoding='utf_8')

    def add_summary(self, summary, step):
        """Wraps `summary` in a timestamped Event and queues it for writing.

        Does nothing when the writer was created with `enable=False`.

        Args:
          summary: Summary proto to record.
          step: training step stored on the event; left unset when None.
        """
        if not self._enabled:
            return
        event = event_pb2.Event(summary=summary)
        event.wall_time = time.time()
        if step is not None:
            event.step = int(step)
        self._event_writer.add_event(event)

    def close(self):
        """Close SummaryWriter. Final!"""
        if not self._closed:
            self._event_writer.close()
            self._closed = True
            del self._event_writer

    def __del__(self):  # safe?
        # TODO(afrozm): Sometimes this complains with
        #  `TypeError: 'NoneType' object is not callable`
        # Deliberately best-effort: errors raised while the object is being
        # destroyed are ignored.
        try:
            self.close()
        except Exception:  # pylint: disable=broad-except
            pass

    def flush(self):
        """Flushes the underlying event writer unless writing is disabled."""
        if not self._enabled:
            return
        self._event_writer.flush()

    def scalar(self, tag, value, step=None):
        """Saves scalar value.

        Args:
          tag: str: label for this data
          value: int/float: number to log
          step: int: training step; defaults to the last step used
        """
        value = float(np.array(value))
        step = self._resolve_step(step)
        summary = tf.compat.v1.Summary(
            value=[tf.compat.v1.Summary.Value(tag=tag, simple_value=value)])
        self.add_summary(summary, step)

    def image(self, tag, image, step=None):
        """Saves RGB image summary from np.ndarray [H,W], [H,W,1], or [H,W,3].

        Args:
          tag: str: label for this data
          image: ndarray: [H,W], [H,W,1], [H,W,3] save image in greyscale or
            colors
          step: int: training step; defaults to the last step used
        """
        image = np.array(image)
        step = self._resolve_step(step)
        if len(np.shape(image)) == 2:
            image = image[:, :, np.newaxis]
        if np.shape(image)[-1] == 1:
            # Replicate the single channel so the PNG is encoded as RGB.
            image = np.repeat(image, 3, axis=-1)
        image_strio = io.BytesIO()
        plt.imsave(image_strio, image, format='png')
        image_summary = tf.compat.v1.Summary.Image(
            encoded_image_string=image_strio.getvalue(),
            colorspace=3,
            height=image.shape[0],
            width=image.shape[1])
        summary = tf.compat.v1.Summary(
            value=[tf.compat.v1.Summary.Value(tag=tag, image=image_summary)])
        self.add_summary(summary, step)

    def images(self, tag, images, step=None, rows=None, cols=None):
        """Saves (rows, cols) tiled images from np.ndarray.

        If only one of rows or cols is given, the other is derived from the
        size of the image batch by integer division. If neither is given,
        a single row holding the whole batch (rows=1, cols=N) is requested.
        This truncates the image batch rather than padding if it doesn't
        fill the final row.

        Args:
          tag: str: label for this data
          images: ndarray: [N,H,W,1] or [N,H,W,3] to tile in 2d
          step: int: training step; defaults to the last step used
          rows: int: number of rows in tile
          cols: int: number of columns in tile
        """
        images = np.array(images)
        step = self._resolve_step(step)
        n_images = np.shape(images)[0]
        if rows is None and cols is None:
            rows = 1
            cols = n_images
        elif rows is None:
            rows = n_images // cols
        elif cols is None:
            cols = n_images // rows
        tiled_images = _pack_images(images, rows, cols)
        self.image(tag, tiled_images, step=step)

    def plot(self, tag, mpl_plt, step=None, close_plot=True):
        """Saves matplotlib plot output to summary image.

        Args:
          tag: str: label for this data
          mpl_plt: matplotlib stateful pyplot object with prepared plotting
            state
          step: int: training step; defaults to the last step used
          close_plot: bool: automatically closes plot
        """
        step = self._resolve_step(step)
        fig_manager = mpl_plt.get_current_fig_manager()
        img_w, img_h = fig_manager.canvas.get_width_height()
        image_buf = io.BytesIO()
        mpl_plt.savefig(image_buf, format='png')
        image_summary = tf.compat.v1.Summary.Image(
            encoded_image_string=image_buf.getvalue(),
            colorspace=4,  # RGBA
            height=img_h,
            width=img_w)
        summary = tf.compat.v1.Summary(
            value=[tf.compat.v1.Summary.Value(tag=tag, image=image_summary)])
        self.add_summary(summary, step)
        if close_plot:
            mpl_plt.close()

    def figure(self, tag, fig, step=None, close_plot=True):
        """Saves a matplotlib figure to summary image.

        Args:
          tag: str: label for this data
          fig: matplotlib figure to render
          step: int: training step; defaults to the last step used
          close_plot: bool: automatically closes the figure
        """
        step = self._resolve_step(step)
        img_w, img_h = fig.canvas.get_width_height()
        image_buf = io.BytesIO()
        # Render the figure that was passed in; plt.savefig would save
        # whichever figure is current in pyplot instead.
        fig.savefig(image_buf, format='png')
        image_summary = tf.compat.v1.Summary.Image(
            encoded_image_string=image_buf.getvalue(),
            colorspace=4,  # RGBA
            height=img_h,
            width=img_w)
        summary = tf.compat.v1.Summary(
            value=[tf.compat.v1.Summary.Value(tag=tag, image=image_summary)])
        self.add_summary(summary, step)
        if close_plot:
            plt.close(fig)

    def audio(self, tag, audiodata, step=None, sample_rate=44100):
        """Saves audio.

        NB: single channel only right now.

        Args:
          tag: str: label for this data
          audiodata: ndarray [Nsamples,]: data between (-1.0,1.0) to save as
            wave; values outside that range are clipped
          step: int: training step; defaults to the last step used
          sample_rate: sample rate of passed in audio buffer

        Raises:
          ValueError: if the squeezed audio data is not one-dimensional.
        """
        audiodata = np.array(audiodata)
        step = self._resolve_step(step)
        audiodata = np.clip(np.squeeze(audiodata), -1, 1)
        if audiodata.ndim != 1:
            raise ValueError('Audio data must be 1D.')
        # Scale [-1, 1] to the signed 16-bit sample range.
        sample_list = (32767.0 * audiodata).astype(int).tolist()
        with io.BytesIO() as wio:
            # The wave header is finalised when the writer is closed, so the
            # bytes are read only after the inner `with` has exited.
            with wave.open(wio, 'wb') as wav_buf:
                wav_buf.setnchannels(1)
                wav_buf.setsampwidth(2)
                wav_buf.setframerate(sample_rate)
                enc = b''.join([struct.pack('<h', v) for v in sample_list])
                wav_buf.writeframes(enc)
            encoded_audio_bytes = wio.getvalue()
        audio = tf.compat.v1.Summary.Audio(
            sample_rate=sample_rate,
            num_channels=1,
            length_frames=len(sample_list),
            encoded_audio_string=encoded_audio_bytes,
            content_type='audio/wav')
        summary = tf.compat.v1.Summary(
            value=[tf.compat.v1.Summary.Value(tag=tag, audio=audio)])
        self.add_summary(summary, step)

    def histogram(self, tag, values, bins, step=None):
        """Saves histogram of values.

        Args:
          tag: str: label for this data
          values: ndarray: will be flattened by this routine
          bins: number of bins in histogram, or array of bins for
            np.histogram
          step: int: training step; defaults to the last step used
        """
        step = self._resolve_step(step)
        values = np.array(values)
        bins = np.array(bins)
        values = np.reshape(values, -1)
        counts, limits = np.histogram(values, bins=bins)
        # boundary logic: trim the empty buckets at both ends.
        # The comparison result is summed as-is; comparison ufuncs have no
        # integer-output loop, so forcing dtype=int32 on np.greater fails
        # on current NumPy.
        cum_counts = np.cumsum(np.greater(counts, 0))
        start, end = np.searchsorted(
            cum_counts, [0, cum_counts[-1] - 1], side='right')
        start, end = int(start), int(end) + 1
        counts = (counts[start - 1:end] if start > 0
                  else np.concatenate([[0], counts[:end]]))
        limits = limits[start:end + 1]
        sum_sq = values.dot(values)
        histo = tf.compat.v1.HistogramProto(
            min=values.min(),
            max=values.max(),
            num=len(values),
            sum=values.sum(),
            sum_squares=sum_sq,
            bucket_limit=limits.tolist(),
            bucket=counts.tolist())
        summary = tf.compat.v1.Summary(
            value=[tf.compat.v1.Summary.Value(tag=tag, histo=histo)])
        self.add_summary(summary, step)

    def text(self, tag, textdata, step=None):
        """Saves a text summary.

        Args:
          tag: str: label for this data
          textdata: string or bytes, or 1D/2D list/numpy array of strings
          step: int: training step; defaults to the last step used

        Raises:
          ValueError: if array-like `textdata` is neither 1D nor 2D.

        Note: markdown formatting is rendered by tensorboard.
        """
        step = self._resolve_step(step)
        if isinstance(textdata, (str, bytes)):
            encoded = [self._encode_utf8(textdata)]
            shape = (1,)
        else:
            textdata = np.array(textdata)  # convert lists, jax arrays, etc.
            shape = np.shape(textdata)
            if len(shape) not in (1, 2):
                raise ValueError(
                    'Text data must be a string or a 1D/2D array of strings, '
                    'got an array of rank %d.' % len(shape))
            encoded = [self._encode_utf8(td)
                       for td in np.reshape(textdata, -1)]
        smd = tf.compat.v1.SummaryMetadata(
            plugin_data=tf.compat.v1.SummaryMetadata.PluginData(
                plugin_name='text'))
        tensor = tf.make_tensor_proto(values=encoded, shape=shape)
        summary = tf.compat.v1.Summary(
            value=[tf.compat.v1.Summary.Value(
                tag=tag, metadata=smd, tensor=tensor)])
        self.add_summary(summary, step)

    def checkpoint(self, tag, optimiser_state, step, loss=None):
        """Pickles the optimiser state under `<log_dir>/checkpoints`.

        The state is always written to `<tag>_last.pickle`. When `loss` is
        given and is not greater than the best loss recorded so far, the
        state is also written to `<tag>_best.pickle` and the best loss is
        updated.

        Args:
          tag: str: prefix of the checkpoint file names
          optimiser_state: optimiser state to save; it is unpacked with
            `unpack_optimizer_state` first unless it already is a
            `JoinPoint`
          step: int: training step (currently unused)
          loss: optional number compared against the best loss so far

        Returns:
          True if a new best checkpoint was written, False otherwise.
        """
        parent = os.path.join(self.log_dir, "checkpoints")
        os.makedirs(parent, exist_ok=True)
        optimizers = jax.experimental.optimizers
        if not isinstance(optimiser_state, optimizers.JoinPoint):
            optimiser_state = optimizers.unpack_optimizer_state(optimiser_state)
        filepath_last = os.path.join(parent, str(tag) + "_last.pickle")
        with open(filepath_last, "wb") as f:
            pickle.dump(optimiser_state, f)
        if loss is None or loss > self._best_loss:
            # Explicit falsy result so both outcomes return a bool.
            return False
        self._best_loss = loss
        filepath_best = os.path.join(parent, str(tag) + "_best.pickle")
        with open(filepath_best, "wb") as f:
            pickle.dump(optimiser_state, f)
        return True