Esempio n. 1
0
    def __init__(self, log_dir):
        """Set up a SummaryWriter that records into `log_dir`.

        Args:
          log_dir: path to record tfevents files in. It is created, together
            with any missing parent directories, when it is not already a
            directory.
        """
        log_dir_exists = gfile.isdir(log_dir)
        if not log_dir_exists:
            gfile.makedirs(log_dir)

        # NOTE(review): the positional arguments are presumably max_queue=10,
        # flush_secs=120 and filename_suffix=None -- confirm against the
        # EventFileWriter in use.
        self._event_writer = EventFileWriter(log_dir, 10, 120, None)
        # Freshly constructed writers start out open.
        self._closed = False
Esempio n. 2
0
    def __init__(self) -> None:
        """Open an event file writer under the tensorboard base path."""
        super().__init__()
        base_path = str(tensorboard.get_base_path({}))
        self.writer = EventFileWriter(logdir=base_path, filename_suffix=None)
        self.createSummary = tf.Summary

        # Vendored from TensorFlow: tensorflow/python/summary/writer/writer.py
        # Tags of the Summary Values written so far. Only the first Summary
        # Value seen for a tag keeps its metadata property (a SummaryMetadata
        # proto); the metadata of every later value with the same tag is
        # stripped to save space.
        self._seen_summary_tags: Set[str] = set()
Esempio n. 3
0
  def __init__(self, log_dir, enable=True):
    """Set up a SummaryWriter that records into `log_dir`.

    Args:
      log_dir: path to record tfevents files in. It is created, together with
        any missing parent directories, when it is not already a directory.
      enable: bool: if False don't actually write or flush data.  Used in
        multihost training.
    """
    log_dir_exists = tf.io.gfile.isdir(log_dir)
    if not log_dir_exists:
      tf.io.gfile.makedirs(log_dir)

    # NOTE(review): the positional arguments are presumably max_queue=10,
    # flush_secs=120 and filename_suffix=None -- confirm against the
    # EventFileWriter in use.
    self._event_writer = EventFileWriter(log_dir, 10, 120, None)
    # Bookkeeping: start at step 0, open, with the requested enable state.
    self._step, self._closed, self._enabled = 0, False, enable
Esempio n. 4
0
    def __init__(self, log_dir, enable=True):
        """Create a new SummaryWriter that logs into a fresh numbered sub-directory.

        Args:
            log_dir: parent directory. It is created (with missing parents) if
                needed; event files are recorded in a sub-directory of it named
                after the number of entries `log_dir` already contains.
            enable: bool: if False don't actually write or flush data. Used in
                multihost training.
        """
        # If needed, create log_dir directory as well as missing parent directories.
        if not tf.io.gfile.isdir(log_dir):
            tf.io.gfile.makedirs(log_dir)

        # List through tf.io.gfile rather than os.listdir so that the directory
        # created above can be enumerated on every filesystem gfile supports,
        # not only on the local one.
        run_index = len(tf.io.gfile.listdir(log_dir))
        log_dir = os.path.join(log_dir, str(run_index))
        logging.info("Created new logger at: %s", log_dir)

        self.log_dir = log_dir
        # NOTE(review): the positional arguments are presumably max_queue=10,
        # flush_secs=120 and filename_suffix=None -- confirm against the
        # EventFileWriter in use.
        self._event_writer = EventFileWriter(log_dir, 10, 120, None)
        self._step = 0
        self._closed = False
        self._enabled = enable
        # Starts at +inf so that any finite loss compares as smaller.
        self._best_loss = float("inf")
Esempio n. 5
0
  def __init__(self,
               logdir,
               graph=None,
               max_queue=10,
               flush_secs=120,
               graph_def=None,
               filename_suffix=None):
    """Creates a `FileWriter` and an event file.

    Constructing the summary writer creates a new event file in `logdir`.
    `Event` protocol buffers are appended to that file whenever one of
    `add_summary()`, `add_session_log()`, `add_event()` or `add_graph()` is
    called.

    A `Graph` passed to the constructor is added to the event file, which is
    equivalent to calling `add_graph()` later. TensorBoard picks the graph up
    from the file and displays it so that you can explore it interactively.
    Usually you pass the graph of the session in which you launched it:

    ```python
    ...create a graph...
    # Launch the graph in a session.
    sess = tf.Session()
    # Create a summary writer, add the 'graph' to the event file.
    writer = tf.summary.FileWriter(<some-directory>, sess.graph)
    ```

    The remaining arguments control the asynchronous writes to the event file:

    *  `flush_secs`: How often, in seconds, to flush the added summaries
       and events to disk.
    *  `max_queue`: Maximum number of summaries or events pending to be
       written to disk before one of the 'add' calls block.

    Args:
      logdir: A string. Directory where event file will be written.
      graph: A `Graph` object, such as `sess.graph`.
      max_queue: Integer. Size of the queue for pending events and summaries.
      flush_secs: Number. How often, in seconds, to flush the
        pending events and summaries to disk.
      graph_def: DEPRECATED: Use the `graph` argument instead.
      filename_suffix: A string. Every event file's name is suffixed with
        `filename_suffix`.
    """
    # Build the file-backed event writer inline and let the base class take
    # care of the graph / graph_def handling.
    super(FileWriter, self).__init__(
        EventFileWriter(logdir, max_queue, flush_secs, filename_suffix),
        graph,
        graph_def)
Esempio n. 6
0
 def __init__(self,
              logdir,
              graph=None,
              max_queue=10,
              flush_secs=120,
              graph_def=None):
     """Create a writer backed by an `EventFileWriter` on `logdir`.

     Args:
       logdir: directory handed to the `EventFileWriter`.
       graph: forwarded unchanged to the base-class constructor.
       max_queue: forwarded to the `EventFileWriter`.
       flush_secs: forwarded to the `EventFileWriter`.
       graph_def: forwarded unchanged to the base-class constructor.
     """
     writer = EventFileWriter(logdir, max_queue, flush_secs)
     super(LegacySummaryWriter, self).__init__(writer, graph, graph_def)
     # Proxy the event_writer public API onto the LegacySummaryWriter
     # this gives consistency with the tf.train.SummaryWriter API.
     proxied_names = ('get_logdir', 'add_event', 'flush', 'close', 'reopen')
     for method_name in proxied_names:
         setattr(self, method_name, getattr(self.event_writer, method_name))
Esempio n. 7
0
class SummaryWriter(object):
    """Saves data in event and summary protos for tensorboard.

    Each logging method accepts an optional ``step``. An explicit step is
    remembered; when ``step`` is None the most recently remembered step
    (initially 0) is reused.
    """

    def __init__(self, log_dir):
        """Create a new SummaryWriter.

        Args:
          log_dir: path to record tfevents files in.
        """
        # If needed, create log_dir directory as well as missing parent directories.
        if not gfile.isdir(log_dir):
            gfile.makedirs(log_dir)

        # NOTE(review): the positional arguments are presumably max_queue=10,
        # flush_secs=120 and filename_suffix=None -- confirm against the
        # EventFileWriter in use.
        self._event_writer = EventFileWriter(log_dir, 10, 120, None)
        self._step = 0
        self._closed = False

    def _resolve_step(self, step):
        """Return the step to log at, remembering an explicitly given one."""
        if step is None:
            return self._step
        self._step = step
        return step

    def add_summary(self, summary, step):
        """Wrap `summary` in a timestamped Event and hand it to the event writer.

        Args:
          summary: Summary proto to record.
          step: training step; the event's step is left unset when None.
        """
        event = event_pb2.Event(summary=summary)
        event.wall_time = time.time()
        if step is not None:
            event.step = int(step)
        self._event_writer.add_event(event)

    def close(self):
        """Close SummaryWriter. Final!"""
        if not self._closed:
            self._event_writer.close()
            self._closed = True
            del self._event_writer

    def __del__(self):
        # Best effort only: a finalizer must not raise. close() can fail here
        # when __init__ never completed (no `_closed` attribute yet) or while
        # the interpreter is being torn down.
        try:
            self.close()
        except Exception:  # pylint: disable=broad-except
            pass

    def flush(self):
        """Ask the underlying event writer to flush."""
        self._event_writer.flush()

    def scalar(self, tag, value, step=None):
        """Saves scalar value.

        Args:
          tag: str: label for this data
          value: int/float: number to log
          step: int: training step
        """
        value = float(onp.array(value))
        step = self._resolve_step(step)
        summary = Summary(value=[Summary.Value(tag=tag, simple_value=value)])
        self.add_summary(summary, step)

    def image(self, tag, image, step=None):
        """Saves RGB image summary from onp.ndarray [H,W], [H,W,1], or [H,W,3].

        Args:
          tag: str: label for this data
          image: ndarray: [H,W], [H,W,1], [H,W,3] save image in greyscale or
            colors.
          step: int: training step
        """
        image = onp.array(image)
        step = self._resolve_step(step)
        # Promote [H,W] to [H,W,1], then replicate a single channel to three.
        if len(onp.shape(image)) == 2:
            image = image[:, :, onp.newaxis]
        if onp.shape(image)[-1] == 1:
            image = onp.repeat(image, 3, axis=-1)
        image_strio = io.BytesIO()
        plt.imsave(image_strio, image, format='png')
        image_summary = Summary.Image(
            encoded_image_string=image_strio.getvalue(),
            colorspace=3,
            height=image.shape[0],
            width=image.shape[1])
        summary = Summary(value=[Summary.Value(tag=tag, image=image_summary)])
        self.add_summary(summary, step)

    def images(self, tag, images, step=None, rows=None, cols=None):
        """Saves (rows, cols) tiled images from onp.ndarray.

        If either rows or cols aren't given, they are determined automatically
        from the size of the image batch; if neither is given, rows=1 and
        cols=N are used. This truncates the image batch rather than padding
        if it doesn't fill the final row.

        Args:
          tag: str: label for this data
          images: ndarray: [N,H,W,1] or [N,H,W,3] to tile in 2d
          step: int: training step
          rows: int: number of rows in tile
          cols: int: number of columns in tile
        """
        images = onp.array(images)
        step = self._resolve_step(step)
        n_images = onp.shape(images)[0]
        if rows is None and cols is None:
            rows = 1
            cols = n_images
        elif rows is None:
            rows = n_images // cols
        elif cols is None:
            cols = n_images // rows
        tiled_images = _pack_images(images, rows, cols)
        self.image(tag, tiled_images, step=step)

    def plot(self, tag, mpl_plt, step=None, close_plot=True):
        """Saves matplotlib plot output to summary image.

        Args:
          tag: str: label for this data
          mpl_plt: matplotlib stateful pyplot object with prepared plotting state
          step: int: training step
          close_plot: bool: automatically closes plot
        """
        step = self._resolve_step(step)
        fig_manager = mpl_plt.get_current_fig_manager()
        img_w, img_h = fig_manager.canvas.get_width_height()
        image_buf = io.BytesIO()
        mpl_plt.savefig(image_buf, format='png')
        image_summary = Summary.Image(
            encoded_image_string=image_buf.getvalue(),
            colorspace=4,  # RGBA
            height=img_h,
            width=img_w)
        summary = Summary(value=[Summary.Value(tag=tag, image=image_summary)])
        self.add_summary(summary, step)
        if close_plot:
            mpl_plt.close()

    def audio(self, tag, audiodata, step=None, sample_rate=44100):
        """Saves audio.

        NB: single channel only right now.

        Args:
          tag: str: label for this data
          audiodata: ndarray [Nsamples,]: audio data to save as wave; it is
            clipped to [-1, 1] first.
          step: int: training step
          sample_rate: sample rate of passed in audio buffer

        Raises:
          ValueError: if the squeezed audio data is not one-dimensional.
        """
        audiodata = onp.array(audiodata)
        step = self._resolve_step(step)
        audiodata = onp.clip(onp.squeeze(audiodata), -1, 1)
        if audiodata.ndim != 1:
            raise ValueError('Audio data must be 1D.')
        # Scale [-1, 1] to 16-bit integer samples.
        sample_list = (32767.0 * audiodata).astype(int).tolist()
        wio = io.BytesIO()
        wav_buf = wave.open(wio, 'wb')
        wav_buf.setnchannels(1)
        wav_buf.setsampwidth(2)
        wav_buf.setframerate(sample_rate)
        enc = b''.join([struct.pack('<h', v) for v in sample_list])
        wav_buf.writeframes(enc)
        wav_buf.close()
        encoded_audio_bytes = wio.getvalue()
        wio.close()
        audio = Summary.Audio(sample_rate=sample_rate,
                              num_channels=1,
                              length_frames=len(sample_list),
                              encoded_audio_string=encoded_audio_bytes,
                              content_type='audio/wav')
        summary = Summary(value=[Summary.Value(tag=tag, audio=audio)])
        self.add_summary(summary, step)

    def histogram(self, tag, values, bins, step=None):
        """Saves histogram of values.

        Args:
          tag: str: label for this data
          values: ndarray: will be flattened by this routine
          bins: number of bins in histogram, or array of bins for onp.histogram
          step: int: training step
        """
        step = self._resolve_step(step)
        values = onp.array(values)
        bins = onp.array(bins)
        values = onp.reshape(values, -1)
        counts, limits = onp.histogram(values, bins=bins)
        # boundary logic: drop the empty buckets at both ends, keeping one
        # zero-count bucket in front of the first populated one.
        # The comparison is converted with astype() because comparison ufuncs
        # only accept bool/object for their `dtype=` argument on recent NumPy.
        cum_counts = onp.cumsum(onp.greater(counts, 0).astype(onp.int32))
        start, end = onp.searchsorted(cum_counts, [0, cum_counts[-1] - 1],
                                      side='right')
        start, end = int(start), int(end) + 1
        counts = (counts[start - 1:end]
                  if start > 0 else onp.concatenate([[0], counts[:end]]))
        limits = limits[start:end + 1]
        sum_sq = values.dot(values)
        histo = HistogramProto(min=values.min(),
                               max=values.max(),
                               num=len(values),
                               sum=values.sum(),
                               sum_squares=sum_sq,
                               bucket_limit=limits.tolist(),
                               bucket=counts.tolist())
        summary = Summary(value=[Summary.Value(tag=tag, histo=histo)])
        self.add_summary(summary, step)

    def text(self, tag, textdata, step=None):
        """Saves a text summary.

        Args:
          tag: str: label for this data
          textdata: string, or 1D/2D list/numpy array of strings
          step: int: training step

        Raises:
          ValueError: if `textdata` is an array that is neither 1D nor 2D.

        Note: markdown formatting is rendered by tensorboard.
        """
        step = self._resolve_step(step)
        if isinstance(textdata, (str, bytes)):
            flat_text, shape = [textdata], (1, )
        else:
            textdata = onp.array(textdata)  # convert lists, jax arrays, etc.
            shape = tuple(onp.shape(textdata))
            if len(shape) not in (1, 2):
                raise ValueError(
                    'Text data must be a string or a 1D/2D array of strings.')
            flat_text = onp.reshape(textdata, -1)
        # bytes values are passed through untouched (they have no .encode);
        # everything else is encoded as UTF-8.
        encoded = [
            td if isinstance(td, bytes) else td.encode(encoding='utf_8')
            for td in flat_text
        ]
        smd = SummaryMetadata(plugin_data=SummaryMetadata.PluginData(
            plugin_name='text'))
        tensor = tf.make_tensor_proto(values=encoded, shape=shape)
        summary = Summary(
            value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
        self.add_summary(summary, step)
Esempio n. 8
0
class TFWriter(tensorboard.MetricWriter):
    """
    Metric writer that records scalar batch metrics as tfevent files through
    an ``EventFileWriter``.
    """

    def __init__(self) -> None:
        """Open an event file writer under the tensorboard base path."""
        super().__init__()
        base_path = str(tensorboard.get_base_path({}))
        self.writer = EventFileWriter(logdir=base_path, filename_suffix=None)
        self.createSummary = tf.Summary

        # Vendored from TensorFlow: tensorflow/python/summary/writer/writer.py
        # Tags of the Summary Values written so far. Only the first Summary
        # Value seen for a tag keeps its metadata property (a SummaryMetadata
        # proto); the metadata of every later value with the same tag is
        # stripped to save space.
        self._seen_summary_tags: Set[str] = set()

    def add_scalar(self, name: str, value: Union[int, float, np.number],
                   step: int) -> None:
        """Record one scalar `value` under tag `name` at `step`."""
        summary = self.createSummary()
        entry = summary.value.add()
        entry.tag = name
        entry.simple_value = value
        self._add_summary(summary, step)

    def _add_summary(self,
                     summary: Union[str, summary_pb2.Summary],
                     global_step: Optional[int] = None) -> None:
        """
        Vendored from TensorFlow: tensorflow/python/summary/writer/writer.py

        Wraps a `Summary` protocol buffer in an `Event` protocol buffer and
        adds it to the event file.

        The summary may be the result of evaluating any summary op (via
        `tf.Session.run` or `tf.Tensor.eval`), or a `tf.compat.v1.Summary`
        protocol buffer populated by hand, which is the usual way to report
        evaluation results in event files.

        NOTE(review): the annotation says `str`, but only `bytes` input is
        parsed as a serialized proto below -- confirm the intended type.

        Args:
          summary: A `Summary` protocol buffer, optionally serialized.
          global_step: Number. Optional global step value to record with the
            summary.
        """
        if isinstance(summary, bytes):
            serialized = summary
            summary = summary_pb2.Summary()
            summary.ParseFromString(serialized)

        # Metadata is stored only on the first value carrying a given tag;
        # for tags seen before it is stripped to save space.
        seen_tags = self._seen_summary_tags
        for entry in summary.value:
            if entry.metadata:
                if entry.tag in seen_tags:
                    # Tag encountered before: drop the repeated metadata.
                    entry.ClearField("metadata")
                else:
                    # First value with this tag that has metadata: remember
                    # the tag so later values get stripped.
                    seen_tags.add(entry.tag)

        self._add_event(event_pb2.Event(summary=summary), global_step)

    def _add_event(self, event: event_pb2.Event, step: Optional[int]) -> None:
        """Timestamp `event`, set its step when given, and queue it for writing."""
        # Vendored from TensorFlow: tensorflow/python/summary/writer/writer.py
        event.wall_time = time.time()
        if step is not None:
            event.step = int(step)
        self.writer.add_event(event)

    def reset(self) -> None:
        """Close the underlying writer and reopen it."""
        self.writer.close()
        self.writer.reopen()
Esempio n. 9
0
class SummaryWriter(object):
    """Saves data in event and summary protos for tensorboard."""

    def __init__(self, log_dir):
        """Create a new SummaryWriter.

        Args:
          log_dir: path to record tfevents files in.
        """
        # If needed, create log_dir directory as well as missing parent directories.
        if not gfile.isdir(log_dir):
            gfile.makedirs(log_dir)

        # NOTE(review): the positional arguments are presumably max_queue=10,
        # flush_secs=120 and filename_suffix=None -- confirm against the
        # EventFileWriter in use.
        self._event_writer = EventFileWriter(log_dir, 10, 120, None)
        self._closed = False

    def _add_summary(self, summary, step):
        """Wrap `summary` in a timestamped Event and hand it to the event writer.

        Args:
          summary: Summary proto to record.
          step: training step; the event's step is left unset when None.
        """
        event = event_pb2.Event(summary=summary)
        event.wall_time = time.time()
        if step is not None:
            event.step = int(step)
        self._event_writer.add_event(event)

    def close(self):
        """Close SummaryWriter. Final!"""
        if not self._closed:
            self._event_writer.close()
            self._closed = True
            del self._event_writer

    def flush(self):
        """Ask the underlying event writer to flush."""
        self._event_writer.flush()

    def scalar(self, tag, value, step):
        """Saves scalar value.

        Args:
          tag: str: label for this data
          value: int/float: number to log
          step: int: training step
        """
        value = float(onp.array(value))
        summary = Summary(value=[Summary.Value(tag=tag, simple_value=value)])
        self._add_summary(summary, step)

    def image(self, tag, image, step):
        """Saves RGB image summary from onp.ndarray [H,W], [H,W,1], or [H,W,3].

        Args:
          tag: str: label for this data
          image: ndarray: [H,W], [H,W,1], [H,W,3] save image in greyscale or
            colors. Pixel values should be in the range [0, 1].
          step: int: training step
        """
        image = onp.array(image)
        # Promote [H,W] to [H,W,1], then replicate a single channel to three.
        if len(onp.shape(image)) == 2:
            image = image[:, :, onp.newaxis]
        if onp.shape(image)[-1] == 1:
            image = onp.repeat(image, 3, axis=-1)
        image_strio = io.BytesIO()
        plt.imsave(image_strio, image, vmin=0., vmax=1., format='png')
        image_summary = Summary.Image(
            encoded_image_string=image_strio.getvalue(),
            colorspace=3,
            height=image.shape[0],
            width=image.shape[1])
        summary = Summary(value=[Summary.Value(tag=tag, image=image_summary)])
        self._add_summary(summary, step)

    def audio(self, tag, audiodata, step, sample_rate=44100):
        """Saves audio.

        NB: single channel only right now.

        Args:
          tag: str: label for this data
          audiodata: ndarray [Nsamples,]: audio data to be saved as wave.
            The data will be clipped to [-1, 1].
          step: int: training step
          sample_rate: sample rate of passed in audio buffer

        Raises:
          ValueError: if the squeezed audio data is not one-dimensional.
        """
        audiodata = onp.array(audiodata)
        audiodata = onp.clip(onp.squeeze(audiodata), -1, 1)
        if audiodata.ndim != 1:
            raise ValueError('Audio data must be 1D.')
        # convert from [-1, 1] -> [-(2^15 - 1), 2^15 - 1]
        sample_list = (32767.0 * audiodata).astype(int).tolist()
        wio = io.BytesIO()
        wav_buf = wave.open(wio, 'wb')
        wav_buf.setnchannels(1)
        wav_buf.setsampwidth(2)
        wav_buf.setframerate(sample_rate)
        enc = b''.join([struct.pack('<h', v) for v in sample_list])
        wav_buf.writeframes(enc)
        wav_buf.close()
        encoded_audio_bytes = wio.getvalue()
        wio.close()
        audio = Summary.Audio(sample_rate=sample_rate,
                              num_channels=1,
                              length_frames=len(sample_list),
                              encoded_audio_string=encoded_audio_bytes,
                              content_type='audio/wav')
        summary = Summary(value=[Summary.Value(tag=tag, audio=audio)])
        self._add_summary(summary, step)

    def histogram(self, tag, values, bins, step):
        """Saves histogram of values.

        Args:
          tag: str: label for this data
          values: ndarray: will be flattened by this routine
          bins: number of bins in histogram, or a sequence defining a
            monotonically increasing array of bin edges, including the
            rightmost edge.
          step: int: training step
        """
        values = onp.array(values)
        bins = onp.array(bins)
        values = onp.reshape(values, -1)
        counts, limits = onp.histogram(values, bins=bins)
        # boundary logic: drop the empty buckets at both ends, keeping one
        # zero-count bucket in front of the first populated one.
        # TODO(flax-dev) Investigate whether this logic can be simplified.
        # The comparison is converted with astype() because comparison ufuncs
        # only accept bool/object for their `dtype=` argument on recent NumPy.
        cum_counts = onp.cumsum(onp.greater(counts, 0).astype(onp.int32))
        start, end = onp.searchsorted(cum_counts, [0, cum_counts[-1] - 1],
                                      side='right')
        start, end = int(start), int(end) + 1
        counts = (counts[start - 1:end]
                  if start > 0 else onp.concatenate([[0], counts[:end]]))
        limits = limits[start:end + 1]
        sum_sq = values.dot(values)
        histo = HistogramProto(min=values.min(),
                               max=values.max(),
                               num=len(values),
                               sum=values.sum(),
                               sum_squares=sum_sq,
                               bucket_limit=limits.tolist(),
                               bucket=counts.tolist())
        summary = Summary(value=[Summary.Value(tag=tag, histo=histo)])
        self._add_summary(summary, step)

    def text(self, tag, textdata, step):
        """Saves a text summary.

        Args:
          tag: str: label for this data
          textdata: string, or 1D/2D list/numpy array of strings
          step: int: training step

        Raises:
          ValueError: if `textdata` is an array that is neither 1D nor 2D.

        Note: markdown formatting is rendered by tensorboard.
        """
        if isinstance(textdata, (str, bytes)):
            flat_text, shape = [textdata], (1, )
        else:
            textdata = onp.array(textdata)  # convert lists, jax arrays, etc.
            shape = tuple(onp.shape(textdata))
            if len(shape) not in (1, 2):
                raise ValueError(
                    'Text data must be a string or a 1D/2D array of strings.')
            flat_text = onp.reshape(textdata, -1)
        # bytes values are passed through untouched (they have no .encode);
        # everything else is encoded as UTF-8.
        encoded = [
            td if isinstance(td, bytes) else td.encode(encoding='utf_8')
            for td in flat_text
        ]
        smd = SummaryMetadata(plugin_data=SummaryMetadata.PluginData(
            plugin_name='text'))
        tensor = tf.make_tensor_proto(values=encoded, shape=shape)
        summary = Summary(
            value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
        self._add_summary(summary, step)
Esempio n. 10
0
  def __init__(self,
               logdir,
               graph=None,
               max_queue=10,
               flush_secs=120,
               graph_def=None,
               filename_suffix=None,
               session=None):
    """Creates a `FileWriter`, optionally shared within the given session.

    Constructing a file writer typically creates a new event file in `logdir`.
    `Event` protocol buffers are appended to that file whenever one of
    `add_summary()`, `add_session_log()`, `add_event()` or `add_graph()` is
    called.

    A `Graph` passed to the constructor is added to the event file, which is
    equivalent to calling `add_graph()` later. TensorBoard picks the graph up
    from the file and displays it so that you can explore it interactively.
    Usually you pass the graph of the session in which you launched it:

    ```python
    ...create a graph...
    # Launch the graph in a session.
    sess = tf.Session()
    # Create a summary writer, add the 'graph' to the event file.
    writer = tf.summary.FileWriter(<some-directory>, sess.graph)
    ```

    Passing `session` turns the returned `FileWriter` into a compatibility
    layer over the new graph-based summaries (`tf.contrib.summary`). The
    underlying writer resource and events file are then shared with every
    other `FileWriter` using the same `session` and `logdir`, and with any
    `tf.contrib.summary.SummaryWriter` in this session that uses the same
    shared resource name (by default scoped to the logdir). If no such
    resource exists yet, it is created from the remaining constructor
    arguments; if one already exists, those arguments are ignored. Either
    way, ops are added to `session.graph` to control the underlying file
    writer resource. See `tf.contrib.summary` for more details.

    Args:
      logdir: A string. Directory where event file will be written.
      graph: A `Graph` object, such as `sess.graph`.
      max_queue: Integer. Size of the queue for pending events and summaries.
      flush_secs: Number. How often, in seconds, to flush the
        pending events and summaries to disk.
      graph_def: DEPRECATED: Use the `graph` argument instead.
      filename_suffix: A string. Every event file's name is suffixed with
        `filename_suffix`.
      session: A `tf.Session` object. See details above.

    Raises:
      RuntimeError: If called with eager execution enabled.

    @compatibility(eager)
    `FileWriter` is not compatible with eager execution. To write TensorBoard
    summaries under eager execution, use `tf.contrib.summary` instead.
    @end_compatibility
    """
    if context.executing_eagerly():
      raise RuntimeError(
          "tf.summary.FileWriter is not compatible with eager execution. "
          "Use tf.contrib.summary instead.")
    # Without a session the writer is a plain file-backed one; with a session
    # the session-based V2 writer is used instead.
    if session is None:
      event_writer = EventFileWriter(logdir, max_queue, flush_secs,
                                     filename_suffix)
    else:
      event_writer = EventFileWriterV2(
          session, logdir, max_queue, flush_secs, filename_suffix)
    super(FileWriter, self).__init__(event_writer, graph, graph_def)
Esempio n. 11
0
class SummaryWriter(object):
    """Saves data in event and summary protos for tensorboard.

    Provides helpers for scalar, image, audio, histogram and text summaries
    as well as pickled optimiser checkpoints.  Summary helpers that are
    called without an explicit ``step`` reuse the most recently supplied one.
    """

    def __init__(self, log_dir, enable=True):
        """Create a new SummaryWriter that logs into a fresh run directory.

        Events are not written to ``log_dir`` itself but to a numbered
        sub-directory of it, so that every writer gets its own run folder.
        The resulting path is kept in ``self.log_dir``.

        Args:
            log_dir: path of the parent directory for tfevents run folders; it
                is created (including missing parents) when it does not exist.
            enable: bool: if False don't actually write or flush data.    Used in
                multihost training.
        """
        # If needed, create log_dir directory as well as missing parent directories.
        if not tf.io.gfile.isdir(log_dir):
            tf.io.gfile.makedirs(log_dir)

        # Start numbering at the count of existing entries, then skip forward
        # past names that are already taken.  The plain count alone collides
        # with an existing run once an earlier run directory has been removed
        # (e.g. runs 0,1,2 with 0 deleted would yield "2" again).
        run_index = len(os.listdir(log_dir))
        while os.path.exists(os.path.join(log_dir, "{}".format(run_index))):
            run_index += 1
        log_dir = os.path.join(log_dir, "{}".format(run_index))
        logging.info("Created new logger at: {}".format(log_dir))

        self.log_dir = log_dir
        self._event_writer = EventFileWriter(log_dir, 10, 120, None)
        self._step = 0
        self._closed = False
        self._enabled = enable
        # Lowest loss seen so far by checkpoint(); inf so any real loss wins.
        self._best_loss = float("inf")

    def add_summary(self, summary, step):
        if not self._enabled:
            return
        event = event_pb2.Event(summary=summary)
        event.wall_time = time.time()
        if step is not None:
            event.step = int(step)
        self._event_writer.add_event(event)

    def close(self):
        """Close SummaryWriter. Final!"""
        if not self._closed:
            self._event_writer.close()
            self._closed = True
            del self._event_writer

    def __del__(self):    # safe?
        """Best-effort finalizer that closes the writer on garbage collection.

        Any exception raised by close() is deliberately swallowed so that
        object teardown never propagates an error.
        """
        # TODO(afrozm): Sometimes this complains with
        #    `TypeError: 'NoneType' object is not callable`
        # NOTE(review): presumably this happens during interpreter shutdown,
        # when names used by close() have already been torn down -- confirm.
        try:
            self.close()
        except Exception:    # pylint: disable=broad-except
            pass

    def flush(self):
        if not self._enabled:
            return
        self._event_writer.flush()

    def scalar(self, tag, value, step=None):
        """Logs a single number under the given tag.

        Args:
            tag: str: label for this data
            value: int/float: number to log (converted to a Python float)
            step: int: training step; when omitted the last used step is
                reused, otherwise it becomes the new current step
        """
        number = float(np.array(value))
        if step is not None:
            self._step = step
        step = self._step
        entry = tf.compat.v1.Summary.Value(tag=tag, simple_value=number)
        self.add_summary(tf.compat.v1.Summary(value=[entry]), step)

    def image(self, tag, image, step=None):
        """Logs one image given as an array of shape [H,W], [H,W,1] or [H,W,3].

        Two-dimensional input gets a trailing channel axis and single-channel
        input is replicated to three channels before PNG encoding.

        Args:
            tag: str: label for this data
            image: ndarray: [H,W], [H,W,1], [H,W,3] image, greyscale or colour
            step: int: training step; when omitted the last used step is
                reused, otherwise it becomes the new current step
        """
        image = np.array(image)
        if step is not None:
            self._step = step
        step = self._step
        if image.ndim == 2:
            image = image[..., np.newaxis]
        if image.shape[-1] == 1:
            image = image.repeat(3, axis=-1)
        png_buffer = io.BytesIO()
        plt.imsave(png_buffer, image, format='png')
        height, width = image.shape[0], image.shape[1]
        encoded = tf.compat.v1.Summary.Image(
            encoded_image_string=png_buffer.getvalue(),
            colorspace=3,
            height=height,
            width=width)
        entry = tf.compat.v1.Summary.Value(tag=tag, image=encoded)
        self.add_summary(tf.compat.v1.Summary(value=[entry]), step)

    def images(self, tag, images, step=None, rows=None, cols=None):
        """Logs a batch of images tiled into one (rows, cols) image.

        A missing ``rows`` or ``cols`` is derived from the batch size by
        integer division with the other one; when both are missing a layout
        of one row with one column per image is requested.  The tiling itself
        is delegated to ``_pack_images``.

        Args:
            tag: str: label for this data
            images: ndarray: [N,H,W,1] or [N,H,W,3] to tile in 2d
            step: int: training step; when omitted the last used step is
                reused, otherwise it becomes the new current step
            rows: int: number of rows in tile
            cols: int: number of columns in tile
        """
        images = np.array(images)
        if step is not None:
            self._step = step
        step = self._step
        batch_size = np.shape(images)[0]
        if rows is None:
            if cols is None:
                rows, cols = 1, batch_size
            else:
                rows = batch_size // cols
        elif cols is None:
            cols = batch_size // rows
        self.image(tag, _pack_images(images, rows, cols), step=step)

    def plot(self, tag, mpl_plt, step=None, close_plot=True):
        """Logs the current pyplot figure as an image summary.

        Args:
            tag: str: label for this data
            mpl_plt: matplotlib stateful pyplot object with prepared plotting state
            step: int: training step; when omitted the last used step is
                reused, otherwise it becomes the new current step
            close_plot: bool: call ``mpl_plt.close()`` once the summary is added
        """
        if step is not None:
            self._step = step
        step = self._step
        manager = mpl_plt.get_current_fig_manager()
        width, height = manager.canvas.get_width_height()
        png_buffer = io.BytesIO()
        mpl_plt.savefig(png_buffer, format='png')
        encoded = tf.compat.v1.Summary.Image(
            encoded_image_string=png_buffer.getvalue(),
            colorspace=4,    # RGBA
            height=height,
            width=width)
        entry = tf.compat.v1.Summary.Value(tag=tag, image=encoded)
        self.add_summary(tf.compat.v1.Summary(value=[entry]), step)
        if close_plot:
            mpl_plt.close()

    def figure(self, tag, fig, step=None, close_plot=True):
        """Saves a matplotlib figure as a summary image.

        Args:
            tag: str: label for this data
            fig: matplotlib figure object to render
            step: int: training step; when omitted the last used step is
                reused, otherwise it becomes the new current step
            close_plot: bool: automatically closes the figure afterwards
        """
        if step is None:
            step = self._step
        else:
            self._step = step
        img_w, img_h = fig.canvas.get_width_height()
        image_buf = io.BytesIO()
        # Render the figure that was passed in.  pyplot's module-level savefig
        # writes whichever figure is *current*, which need not be `fig`.
        fig.savefig(image_buf, format='png')
        image_summary = tf.compat.v1.Summary.Image(
                encoded_image_string=image_buf.getvalue(),
                colorspace=4,    # RGBA
                height=img_h,
                width=img_w)
        summary = tf.compat.v1.Summary(
                value=[tf.compat.v1.Summary.Value(tag=tag, image=image_summary)])
        self.add_summary(summary, step)
        if close_plot:
            plt.close(fig)

    def audio(self, tag, audiodata, step=None, sample_rate=44100):
        """Saves audio.

        NB: single channel only right now.

        Args:
            tag: str: label for this data
            audiodata: ndarray [Nsamples,]: data between (-1.0,1.0) to save as wave
            step: int: training step
            sample_rate: sample rate of passed in audio buffer
        """
        audiodata = np.array(audiodata)
        if step is None:
            step = self._step
        else:
            self._step = step
        audiodata = np.clip(np.squeeze(audiodata), -1, 1)
        if audiodata.ndim != 1:
            raise ValueError('Audio data must be 1D.')
        sample_list = (32767.0 * audiodata).astype(int).tolist()
        wio = io.BytesIO()
        wav_buf = wave.open(wio, 'wb')
        wav_buf.setnchannels(1)
        wav_buf.setsampwidth(2)
        wav_buf.setframerate(sample_rate)
        enc = b''.join([struct.pack('<h', v) for v in sample_list])
        wav_buf.writeframes(enc)
        wav_buf.close()
        encoded_audio_bytes = wio.getvalue()
        wio.close()
        audio = tf.compat.v1.Summary.Audio(
                sample_rate=sample_rate,
                num_channels=1,
                length_frames=len(sample_list),
                encoded_audio_string=encoded_audio_bytes,
                content_type='audio/wav')
        summary = tf.compat.v1.Summary(
                value=[tf.compat.v1.Summary.Value(tag=tag, audio=audio)])
        self.add_summary(summary, step)

    def histogram(self, tag, values, bins, step=None):
        """Saves histogram of values.

        Args:
            tag: str: label for this data
            values: ndarray: will be flattened by this routine
            bins: number of bins in histogram, or array of bins for np.histogram
            step: int: training step
        """
        if step is None:
            step = self._step
        else:
            self._step = step
        values = np.array(values)
        bins = np.array(bins)
        values = np.reshape(values, -1)
        counts, limits = np.histogram(values, bins=bins)
        # boundary logic
        cum_counts = np.cumsum(np.greater(counts, 0, dtype=np.int32))
        start, end = np.searchsorted(
                cum_counts, [0, cum_counts[-1] - 1], side='right')
        start, end = int(start), int(end) + 1
        counts = (
                counts[start -
                             1:end] if start > 0 else np.concatenate([[0], counts[:end]]))
        limits = limits[start:end + 1]
        sum_sq = values.dot(values)
        histo = tf.compat.v1.HistogramProto(
                min=values.min(),
                max=values.max(),
                num=len(values),
                sum=values.sum(),
                sum_squares=sum_sq,
                bucket_limit=limits.tolist(),
                bucket=counts.tolist())
        summary = tf.compat.v1.Summary(
                value=[tf.compat.v1.Summary.Value(tag=tag, histo=histo)])
        self.add_summary(summary, step)

    def text(self, tag, textdata, step=None):
        """Saves a text summary.

        Args:
            tag: str: label for this data
            textdata: string, or 1D/2D list/numpy array of strings
            step: int: training step
        Note: markdown formatting is rendered by tensorboard.
        """
        if step is None:
            step = self._step
        else:
            self._step = step
        smd = tf.compat.v1.SummaryMetadata(
                plugin_data=tf.compat.v1.SummaryMetadata.PluginData(plugin_name='text'))
        if isinstance(textdata, (str, bytes)):
            tensor = tf.make_tensor_proto(
                    values=[textdata.encode(encoding='utf_8')], shape=(1,))
        else:
            textdata = np.array(textdata)    # convert lists, jax arrays, etc.
            datashape = np.shape(textdata)
            if len(datashape) == 1:
                tensor = tf.make_tensor_proto(
                        values=[td.encode(encoding='utf_8') for td in textdata],
                        shape=(datashape[0],))
            elif len(datashape) == 2:
                tensor = tf.make_tensor_proto(
                        values=[
                                td.encode(encoding='utf_8') for td in np.reshape(textdata, -1)
                        ],
                        shape=(datashape[0], datashape[1]))
        summary = tf.compat.v1.Summary(
                value=[tf.compat.v1.Summary.Value(
                        tag=tag, metadata=smd, tensor=tensor)])
        self.add_summary(summary, step)

    def checkpoint(self, tag, optimiser_state, step, loss=None):
        """Pickles the optimiser state as latest and, if improved, best checkpoint.

        Files are written below ``<log_dir>/checkpoints`` as
        ``<tag>_last.pickle`` (always) and ``<tag>_best.pickle`` (only when
        ``loss`` is not worse than the best loss seen so far by this writer).

        Args:
            tag: str: label for this data, used as the file-name prefix
            optimiser_state: optimiser state to save; anything that is not a
                ``JoinPoint`` is first converted with
                ``unpack_optimizer_state``
            step: int: training step (currently unused)
            loss: optional metric where lower is better; when None only the
                ``_last`` file is written

        Returns:
            True when the ``_best`` file was (re)written, otherwise None.
        """
        parent = os.path.join(self.log_dir, "checkpoints")
        os.makedirs(parent, exist_ok=True)

        if not isinstance(optimiser_state, jax.experimental.optimizers.JoinPoint):
            optimiser_state = jax.experimental.optimizers.unpack_optimizer_state(optimiser_state)

        # Serialise once; the same bytes go to the "last" and "best" files.
        payload = pickle.dumps(optimiser_state)

        filepath_last = os.path.join(parent, str(tag) + "_last.pickle")
        with open(filepath_last, "wb") as f:
            f.write(payload)

        # `not (loss <= best)` instead of `loss > best`: every comparison with
        # NaN is False, so a NaN loss used to be stored as the best loss and
        # then made every later loss look like an improvement.
        if loss is None or not (loss <= self._best_loss):
            return None
        self._best_loss = loss
        filepath_best = os.path.join(parent, str(tag) + "_best.pickle")
        with open(filepath_best, "wb") as f:
            f.write(payload)
        return True