示例#1
0
class SummaryWriter(object):
    """Saves data in event and summary protos for tensorboard.

    Each write method accepts an optional `step`. When it is omitted, the step
    most recently passed to any write method (initially 0) is reused; when it
    is given, it becomes the new default for later calls.
    """

    def __init__(self, log_dir):
        """Create a new SummaryWriter.

        Args:
            log_dir: path to record tfevents files in.
        """
        # If needed, create log_dir directory as well as missing parent directories.
        if not gfile.isdir(log_dir):
            gfile.makedirs(log_dir)

        # NOTE(review): the positional arguments are presumably
        # (max_queue=10, flush_secs=120, filename_suffix=None); confirm
        # against the EventFileWriter signature.
        self._event_writer = EventFileWriter(log_dir, 10, 120, None)
        self._step = 0
        self._closed = False

    def _resolve_step(self, step):
        """Returns the step to log at and remembers it as the new default.

        Args:
            step: int or None: explicit training step, or None to reuse the
                last step seen by this writer.

        Returns:
            `step` if it is given, otherwise the last remembered step.
        """
        if step is None:
            return self._step
        self._step = step
        return step

    def add_summary(self, summary, step):
        """Wraps `summary` in an Event proto and hands it to the event writer.

        Args:
            summary: Summary proto to record.
            step: int or None: training step stored on the event; left unset
                when None.
        """
        event = event_pb2.Event(summary=summary)
        event.wall_time = time.time()
        if step is not None:
            event.step = int(step)
        self._event_writer.add_event(event)

    def close(self):
        """Close SummaryWriter. Final!

        Only the first call closes the event writer; later calls are no-ops.
        The event writer is deleted afterwards, so `flush` and the write
        methods can no longer be used.
        """
        if not self._closed:
            self._event_writer.close()
            self._closed = True
            del self._event_writer

    def __del__(self):
        # `_closed` is the last attribute __init__ sets. If __init__ raised
        # (e.g. while creating log_dir) it is missing and there is no event
        # writer to close, so do nothing instead of raising AttributeError.
        if getattr(self, '_closed', True):
            return
        self.close()

    def flush(self):
        """Asks the event writer to flush any buffered events."""
        self._event_writer.flush()

    def scalar(self, tag, value, step=None):
        """Saves scalar value.

        Args:
            tag: str: label for this data
            value: int/float: number to log
            step: int: training step
        """
        value = float(onp.array(value))
        step = self._resolve_step(step)
        summary = Summary(value=[Summary.Value(tag=tag, simple_value=value)])
        self.add_summary(summary, step)

    def image(self, tag, image, step=None):
        """Saves RGB image summary from onp.ndarray [H,W], [H,W,1], or [H,W,3].

        Args:
            tag: str: label for this data
            image: ndarray: [H,W], [H,W,1], [H,W,3] save image in greyscale or
                colors.
            step: int: training step
        """
        image = onp.array(image)
        step = self._resolve_step(step)
        if len(onp.shape(image)) == 2:
            image = image[:, :, onp.newaxis]
        if onp.shape(image)[-1] == 1:
            # Greyscale: replicate the single channel so the PNG is RGB.
            image = onp.repeat(image, 3, axis=-1)
        image_strio = io.BytesIO()
        plt.imsave(image_strio, image, format='png')
        image_summary = Summary.Image(
            encoded_image_string=image_strio.getvalue(),
            colorspace=3,
            height=image.shape[0],
            width=image.shape[1])
        summary = Summary(value=[Summary.Value(tag=tag, image=image_summary)])
        self.add_summary(summary, step)

    def images(self, tag, images, step=None, rows=None, cols=None):
        """Saves (rows, cols) tiled images from onp.ndarray.

        If only one of rows or cols is given, the other is derived from the
        size of the image batch (integer division, so a partial final row is
        truncated rather than padded). If neither is given, a single row of
        tiles is requested (rows=1, cols=N).

        Args:
            tag: str: label for this data
            images: ndarray: [N,H,W,1] or [N,H,W,3] to tile in 2d
            step: int: training step
            rows: int: number of rows in tile
            cols: int: number of columns in tile
        """
        images = onp.array(images)
        step = self._resolve_step(step)
        n_images = onp.shape(images)[0]
        if rows is None and cols is None:
            rows = 1
            cols = n_images
        elif rows is None:
            rows = n_images // cols
        elif cols is None:
            cols = n_images // rows
        tiled_images = _pack_images(images, rows, cols)
        self.image(tag, tiled_images, step=step)

    def plot(self, tag, mpl_plt, step=None, close_plot=True):
        """Saves matplotlib plot output to summary image.

        Args:
            tag: str: label for this data
            mpl_plt: matplotlib stateful pyplot object with prepared plotting
                state
            step: int: training step
            close_plot: bool: automatically closes plot
        """
        step = self._resolve_step(step)
        manager = mpl_plt.get_current_fig_manager()
        img_w, img_h = manager.canvas.get_width_height()
        image_buf = io.BytesIO()
        mpl_plt.savefig(image_buf, format='png')
        image_summary = Summary.Image(
            encoded_image_string=image_buf.getvalue(),
            colorspace=4,  # RGBA
            height=img_h,
            width=img_w)
        summary = Summary(value=[Summary.Value(tag=tag, image=image_summary)])
        self.add_summary(summary, step)
        if close_plot:
            mpl_plt.close()

    def audio(self, tag, audiodata, step=None, sample_rate=44100):
        """Saves audio as a 16-bit mono WAV summary.

        NB: single channel only right now.

        Args:
            tag: str: label for this data
            audiodata: ndarray [Nsamples,]: data between (-1.0,1.0) to save as
                wave; values outside [-1, 1] are clipped.
            step: int: training step
            sample_rate: sample rate of passed in audio buffer

        Raises:
            ValueError: if the squeezed data is not 1D or contains NaN.
        """
        audiodata = onp.array(audiodata)
        step = self._resolve_step(step)
        audiodata = onp.clip(onp.squeeze(audiodata), -1, 1)
        if audiodata.ndim != 1:
            raise ValueError('Audio data must be 1D.')
        if onp.any(onp.isnan(audiodata)):
            raise ValueError('Audio data must not contain NaN.')
        # Scale [-1, 1] to [-(2^15 - 1), 2^15 - 1] and truncate toward zero
        # into little-endian int16 PCM samples in one vectorized step.
        samples = (32767.0 * audiodata).astype('<i2')
        wio = io.BytesIO()
        with wave.open(wio, 'wb') as wav_buf:
            wav_buf.setnchannels(1)
            wav_buf.setsampwidth(2)
            wav_buf.setframerate(sample_rate)
            wav_buf.writeframes(samples.tobytes())
        encoded_audio_bytes = wio.getvalue()
        wio.close()
        audio = Summary.Audio(sample_rate=sample_rate,
                              num_channels=1,
                              length_frames=len(samples),
                              encoded_audio_string=encoded_audio_bytes,
                              content_type='audio/wav')
        summary = Summary(value=[Summary.Value(tag=tag, audio=audio)])
        self.add_summary(summary, step)

    def histogram(self, tag, values, bins, step=None):
        """Saves histogram of values.

        Args:
            tag: str: label for this data
            values: ndarray: will be flattened by this routine
            bins: number of bins in histogram, or array of bins for
                onp.histogram
            step: int: training step
        """
        step = self._resolve_step(step)
        values = onp.array(values)
        bins = onp.array(bins)
        values = onp.reshape(values, -1)
        counts, limits = onp.histogram(values, bins=bins)
        # Trim empty bins at both ends: `start` is the index of the first
        # non-empty bin and `end` is one past the last non-empty bin.
        # NOTE(review): if no value falls inside `bins`, cum_counts is all zero
        # and the trimmed buckets come out empty -- confirm this is acceptable.
        cum_counts = onp.cumsum(onp.greater(counts, 0, dtype=onp.int32))
        start, end = onp.searchsorted(cum_counts, [0, cum_counts[-1] - 1],
                                      side='right')
        start, end = int(start), int(end) + 1
        # Keep one zero-count bucket in front of the first non-empty bin (the
        # preceding bin, or a prepended 0 when there is none); `limits` then
        # holds exactly one edge per remaining bucket.
        counts = (counts[start - 1:end]
                  if start > 0 else onp.concatenate([[0], counts[:end]]))
        limits = limits[start:end + 1]
        sum_sq = values.dot(values)
        histo = HistogramProto(min=values.min(),
                               max=values.max(),
                               num=len(values),
                               sum=values.sum(),
                               sum_squares=sum_sq,
                               bucket_limit=limits.tolist(),
                               bucket=counts.tolist())
        summary = Summary(value=[Summary.Value(tag=tag, histo=histo)])
        self.add_summary(summary, step)

    def text(self, tag, textdata, step=None):
        """Saves a text summary.

        Note: markdown formatting is rendered by tensorboard.

        Args:
            tag: str: label for this data
            textdata: string (str or bytes), or 1D/2D list/numpy array of
                strings
            step: int: training step

        Raises:
            ValueError: if array-like `textdata` is not 1D or 2D.
        """
        step = self._resolve_step(step)

        def to_utf8(s):
            # bytes (including numpy.bytes_) are already encoded; str
            # (including numpy.str_) is encoded as UTF-8.
            if isinstance(s, bytes):
                return bytes(s)
            return s.encode(encoding='utf_8')

        if isinstance(textdata, (str, bytes)):
            values, shape = [to_utf8(textdata)], (1, )
        else:
            textdata = onp.array(textdata)  # convert lists, jax arrays, etc.
            shape = tuple(onp.shape(textdata))
            if len(shape) not in (1, 2):
                raise ValueError(
                    'textdata must be a string or a 1D/2D array of strings, '
                    'got shape {}.'.format(shape))
            values = [to_utf8(td) for td in onp.reshape(textdata, -1)]
        tensor = tf.make_tensor_proto(values=values, shape=shape)
        smd = SummaryMetadata(plugin_data=SummaryMetadata.PluginData(
            plugin_name='text'))
        summary = Summary(
            value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
        self.add_summary(summary, step)
示例#2
0
class SummaryWriter(object):
    """Saves data in event and summary protos for tensorboard."""

    def __init__(self, log_dir):
        """Create a new SummaryWriter.

        Args:
            log_dir: path to record tfevents files in.
        """
        # If needed, create log_dir directory as well as missing parent directories.
        if not gfile.isdir(log_dir):
            gfile.makedirs(log_dir)

        # NOTE(review): the positional arguments are presumably
        # (max_queue=10, flush_secs=120, filename_suffix=None); confirm
        # against the EventFileWriter signature.
        self._event_writer = EventFileWriter(log_dir, 10, 120, None)
        self._closed = False

    def _add_summary(self, summary, step):
        """Wraps `summary` in an Event proto and hands it to the event writer.

        Args:
            summary: Summary proto to record.
            step: int or None: training step stored on the event; left unset
                when None.
        """
        event = event_pb2.Event(summary=summary)
        event.wall_time = time.time()
        if step is not None:
            event.step = int(step)
        self._event_writer.add_event(event)

    def close(self):
        """Close SummaryWriter. Final!

        Only the first call closes the event writer; later calls are no-ops.
        The event writer is deleted afterwards, so `flush` and the write
        methods can no longer be used.
        """
        if not self._closed:
            self._event_writer.close()
            self._closed = True
            del self._event_writer

    def flush(self):
        """Asks the event writer to flush any buffered events."""
        self._event_writer.flush()

    def scalar(self, tag, value, step):
        """Saves scalar value.

        Args:
            tag: str: label for this data
            value: int/float: number to log
            step: int: training step
        """
        value = float(onp.array(value))
        summary = Summary(value=[Summary.Value(tag=tag, simple_value=value)])
        self._add_summary(summary, step)

    def image(self, tag, image, step):
        """Saves RGB image summary from onp.ndarray [H,W], [H,W,1], or [H,W,3].

        Args:
            tag: str: label for this data
            image: ndarray: [H,W], [H,W,1], [H,W,3] save image in greyscale or
                colors. Pixel values should be in the range [0, 1].
            step: int: training step
        """
        image = onp.array(image)
        if len(onp.shape(image)) == 2:
            image = image[:, :, onp.newaxis]
        if onp.shape(image)[-1] == 1:
            # Greyscale: replicate the single channel so the PNG is RGB.
            image = onp.repeat(image, 3, axis=-1)
        image_strio = io.BytesIO()
        plt.imsave(image_strio, image, vmin=0., vmax=1., format='png')
        image_summary = Summary.Image(
            encoded_image_string=image_strio.getvalue(),
            colorspace=3,
            height=image.shape[0],
            width=image.shape[1])
        summary = Summary(value=[Summary.Value(tag=tag, image=image_summary)])
        self._add_summary(summary, step)

    def audio(self, tag, audiodata, step, sample_rate=44100):
        """Saves audio as a 16-bit mono WAV summary.

        NB: single channel only right now.

        Args:
            tag: str: label for this data
            audiodata: ndarray [Nsamples,]: audio data to be saved as a wave.
                The data will be clipped to [-1, 1].
            step: int: training step
            sample_rate: sample rate of passed in audio buffer

        Raises:
            ValueError: if the squeezed data is not 1D or contains NaN.
        """
        audiodata = onp.array(audiodata)
        audiodata = onp.clip(onp.squeeze(audiodata), -1, 1)
        if audiodata.ndim != 1:
            raise ValueError('Audio data must be 1D.')
        if onp.any(onp.isnan(audiodata)):
            raise ValueError('Audio data must not contain NaN.')
        # Convert from [-1, 1] -> [-(2^15 - 1), 2^15 - 1], truncating toward
        # zero into little-endian int16 PCM samples in one vectorized step.
        samples = (32767.0 * audiodata).astype('<i2')
        wio = io.BytesIO()
        with wave.open(wio, 'wb') as wav_buf:
            wav_buf.setnchannels(1)
            wav_buf.setsampwidth(2)
            wav_buf.setframerate(sample_rate)
            wav_buf.writeframes(samples.tobytes())
        encoded_audio_bytes = wio.getvalue()
        wio.close()
        audio = Summary.Audio(sample_rate=sample_rate,
                              num_channels=1,
                              length_frames=len(samples),
                              encoded_audio_string=encoded_audio_bytes,
                              content_type='audio/wav')
        summary = Summary(value=[Summary.Value(tag=tag, audio=audio)])
        self._add_summary(summary, step)

    def histogram(self, tag, values, bins, step):
        """Saves histogram of values.

        Args:
            tag: str: label for this data
            values: ndarray: will be flattened by this routine
            bins: number of bins in histogram, or a sequence defining a
                monotonically increasing array of bin edges, including the
                rightmost edge.
            step: int: training step
        """
        values = onp.array(values)
        bins = onp.array(bins)
        values = onp.reshape(values, -1)
        counts, limits = onp.histogram(values, bins=bins)
        # Boundary logic: trim empty bins at both ends. `start` is the index of
        # the first non-empty bin and `end` is one past the last non-empty bin.
        # TODO(flax-dev) Investigate whether this logic can be simplified.
        cum_counts = onp.cumsum(onp.greater(counts, 0, dtype=onp.int32))
        start, end = onp.searchsorted(cum_counts, [0, cum_counts[-1] - 1],
                                      side='right')
        start, end = int(start), int(end) + 1
        # Keep one zero-count bucket in front of the first non-empty bin (the
        # preceding bin, or a prepended 0 when there is none); `limits` then
        # holds exactly one edge per remaining bucket.
        counts = (counts[start - 1:end]
                  if start > 0 else onp.concatenate([[0], counts[:end]]))
        limits = limits[start:end + 1]
        sum_sq = values.dot(values)
        histo = HistogramProto(min=values.min(),
                               max=values.max(),
                               num=len(values),
                               sum=values.sum(),
                               sum_squares=sum_sq,
                               bucket_limit=limits.tolist(),
                               bucket=counts.tolist())
        summary = Summary(value=[Summary.Value(tag=tag, histo=histo)])
        self._add_summary(summary, step)

    def text(self, tag, textdata, step):
        """Saves a text summary.

        Note: markdown formatting is rendered by tensorboard.

        Args:
            tag: str: label for this data
            textdata: string (str or bytes), or 1D/2D list/numpy array of
                strings
            step: int: training step

        Raises:
            ValueError: if array-like `textdata` is not 1D or 2D.
        """

        def to_utf8(s):
            # bytes (including numpy.bytes_) are already encoded; str
            # (including numpy.str_) is encoded as UTF-8.
            if isinstance(s, bytes):
                return bytes(s)
            return s.encode(encoding='utf_8')

        if isinstance(textdata, (str, bytes)):
            values, shape = [to_utf8(textdata)], (1, )
        else:
            textdata = onp.array(textdata)  # convert lists, jax arrays, etc.
            shape = tuple(onp.shape(textdata))
            if len(shape) not in (1, 2):
                raise ValueError(
                    'textdata must be a string or a 1D/2D array of strings, '
                    'got shape {}.'.format(shape))
            values = [to_utf8(td) for td in onp.reshape(textdata, -1)]
        tensor = tf.make_tensor_proto(values=values, shape=shape)
        smd = SummaryMetadata(plugin_data=SummaryMetadata.PluginData(
            plugin_name='text'))
        summary = Summary(
            value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
        self._add_summary(summary, step)
示例#3
0
class TFWriter(tensorboard.MetricWriter):
    """
    TFWriter uses tensorflow file writers and summary operations to write out
    tfevent files containing scalar batch metrics.
    """
    def __init__(self) -> None:
        super().__init__()
        self.writer = EventFileWriter(
            logdir=str(tensorboard.get_base_path({})), filename_suffix=None)
        self.createSummary = tf.Summary

        # _seen_summary_tags is vendored from TensorFlow: tensorflow/python/summary/writer/writer.py
        # This set contains tags of Summary Values that have been encountered
        # already. The motivation here is that the SummaryWriter only keeps the
        # metadata property (which is a SummaryMetadata proto) of the first Summary
        # Value encountered for each tag. The SummaryWriter strips away the
        # SummaryMetadata for all subsequent Summary Values with tags seen
        # previously. This saves space.
        self._seen_summary_tags: Set[str] = set()

    def add_scalar(self, name: str, value: Union[int, float, np.number],
                   step: int) -> None:
        """Records `value` as a scalar summary tagged `name` at `step`."""
        summary = self.createSummary()
        summary_value = summary.value.add()
        summary_value.tag = name
        # Convert ints and NumPy scalars to a plain Python float before
        # assigning them to the proto field.
        summary_value.simple_value = float(value)
        self._add_summary(summary, step)

    def _add_summary(self,
                     summary: Union[bytes, summary_pb2.Summary],
                     global_step: Optional[int] = None) -> None:
        """
        _add_summary is vendored from TensorFlow: tensorflow/python/summary/writer/writer.py

        Adds a `Summary` protocol buffer to the event file.

        This method wraps the provided summary in an `Event` protocol buffer
        and adds it to the event file.

        You can pass the result of evaluating any summary op, using
        `tf.Session.run` or
        `tf.Tensor.eval`, to this
        function. Alternatively, you can pass a `tf.compat.v1.Summary` protocol
        buffer that you populate with your own data. The latter is
        commonly done to report evaluation results in event files.

        Args:
          summary: A `Summary` protocol buffer, optionally serialized as bytes.
          global_step: Number. Optional global step value to record with the
            summary.
        """
        if isinstance(summary, bytes):
            summ = summary_pb2.Summary()
            summ.ParseFromString(summary)
            summary = summ

        # We strip metadata from values with tags that we have seen before in order
        # to save space - we just store the metadata on the first value with a
        # specific tag.
        for value in summary.value:
            if not value.metadata:
                continue

            if value.tag in self._seen_summary_tags:
                # This tag has been encountered before. Strip the metadata.
                value.ClearField("metadata")
                continue

            # We encounter a value with a tag we have not encountered previously. And
            # it has metadata. Remember to strip metadata from future values with this
            # tag string.
            self._seen_summary_tags.add(value.tag)

        event = event_pb2.Event(summary=summary)
        self._add_event(event, global_step)

    def _add_event(self, event: event_pb2.Event, step: Optional[int]) -> None:
        """Stamps `event` with the wall time and `step`, then writes it."""
        # _add_event is vendored from TensorFlow: tensorflow/python/summary/writer/writer.py
        event.wall_time = time.time()
        if step is not None:
            event.step = int(step)
        self.writer.add_event(event)

    def reset(self) -> None:
        """Closes the event file writer and reopens it."""
        self.writer.close()
        self.writer.reopen()
示例#4
0
class SummaryWriter(object):
    """Saves data in event and summary protos for tensorboard.

    Each write method accepts an optional `step`. When it is omitted, the step
    most recently passed to any write method (initially 0) is reused; when it
    is given, it becomes the new default for later calls.
    """

    def __init__(self, log_dir, enable=True):
        """Create a new SummaryWriter.

        Events are written under a new numbered sub-directory of `log_dir`
        (`log_dir/0`, `log_dir/1`, ...), which is exposed as `self.log_dir`.

        Args:
            log_dir: parent directory for this run; created (with missing
                parents) if needed.
            enable: bool: if False don't actually write or flush data. Used in
                multihost training.
        """
        # If needed, create log_dir directory as well as missing parent directories.
        if not tf.io.gfile.isdir(log_dir):
            tf.io.gfile.makedirs(log_dir)

        # Start from the number of existing entries, but skip indices that are
        # already taken (e.g. "0" and "2" remain after "1" was deleted), so a
        # new run never writes into an existing run's directory. Listing via
        # tf.io.gfile matches the isdir/makedirs calls above.
        run_index = len(tf.io.gfile.listdir(log_dir))
        while tf.io.gfile.exists(os.path.join(log_dir, str(run_index))):
            run_index += 1
        log_dir = os.path.join(log_dir, str(run_index))
        logging.info("Created new logger at: %s", log_dir)

        self.log_dir = log_dir
        # NOTE(review): the positional arguments are presumably
        # (max_queue=10, flush_secs=120, filename_suffix=None); confirm
        # against the EventFileWriter signature.
        self._event_writer = EventFileWriter(log_dir, 10, 120, None)
        self._step = 0
        self._closed = False
        self._enabled = enable
        # Lowest loss passed to `checkpoint` so far.
        self._best_loss = float("inf")

    def _resolve_step(self, step):
        """Returns the step to log at and remembers it as the new default.

        Args:
            step: int or None: explicit training step, or None to reuse the
                last step seen by this writer.

        Returns:
            `step` if it is given, otherwise the last remembered step.
        """
        if step is None:
            return self._step
        self._step = step
        return step

    def add_summary(self, summary, step):
        """Wraps `summary` in an Event proto and hands it to the event writer.

        Does nothing when the writer was created with `enable=False`.

        Args:
            summary: Summary proto to record.
            step: int or None: training step stored on the event; left unset
                when None.
        """
        if not self._enabled:
            return
        event = event_pb2.Event(summary=summary)
        event.wall_time = time.time()
        if step is not None:
            event.step = int(step)
        self._event_writer.add_event(event)

    def close(self):
        """Close SummaryWriter. Final!

        Only the first call closes the event writer; later calls are no-ops.
        The event writer is deleted afterwards, so `flush` and the write
        methods can no longer be used on an enabled writer.
        """
        if not self._closed:
            self._event_writer.close()
            self._closed = True
            del self._event_writer

    def __del__(self):    # safe?
        # TODO(afrozm): Sometimes this complains with
        #    `TypeError: 'NoneType' object is not callable`
        # Best-effort cleanup: errors during finalization (including a
        # partially constructed instance) are deliberately ignored.
        try:
            self.close()
        except Exception:    # pylint: disable=broad-except
            pass

    def flush(self):
        """Asks the event writer to flush buffered events (no-op if disabled)."""
        if not self._enabled:
            return
        self._event_writer.flush()

    def scalar(self, tag, value, step=None):
        """Saves scalar value.

        Args:
            tag: str: label for this data
            value: int/float: number to log
            step: int: training step
        """
        value = float(np.array(value))
        step = self._resolve_step(step)
        summary = tf.compat.v1.Summary(
                value=[tf.compat.v1.Summary.Value(tag=tag, simple_value=value)])
        self.add_summary(summary, step)

    def image(self, tag, image, step=None):
        """Saves RGB image summary from np.ndarray [H,W], [H,W,1], or [H,W,3].

        Args:
            tag: str: label for this data
            image: ndarray: [H,W], [H,W,1], [H,W,3] save image in greyscale or
                colors.
            step: int: training step
        """
        image = np.array(image)
        step = self._resolve_step(step)
        if len(np.shape(image)) == 2:
            image = image[:, :, np.newaxis]
        if np.shape(image)[-1] == 1:
            # Greyscale: replicate the single channel so the PNG is RGB.
            image = np.repeat(image, 3, axis=-1)
        image_strio = io.BytesIO()
        plt.imsave(image_strio, image, format='png')
        image_summary = tf.compat.v1.Summary.Image(
                encoded_image_string=image_strio.getvalue(),
                colorspace=3,
                height=image.shape[0],
                width=image.shape[1])
        summary = tf.compat.v1.Summary(
                value=[tf.compat.v1.Summary.Value(tag=tag, image=image_summary)])
        self.add_summary(summary, step)

    def images(self, tag, images, step=None, rows=None, cols=None):
        """Saves (rows, cols) tiled images from np.ndarray.

        If only one of rows or cols is given, the other is derived from the
        size of the image batch (integer division, so a partial final row is
        truncated rather than padded). If neither is given, a single row of
        tiles is requested (rows=1, cols=N).

        Args:
            tag: str: label for this data
            images: ndarray: [N,H,W,1] or [N,H,W,3] to tile in 2d
            step: int: training step
            rows: int: number of rows in tile
            cols: int: number of columns in tile
        """
        images = np.array(images)
        step = self._resolve_step(step)
        n_images = np.shape(images)[0]
        if rows is None and cols is None:
            rows = 1
            cols = n_images
        elif rows is None:
            rows = n_images // cols
        elif cols is None:
            cols = n_images // rows
        tiled_images = _pack_images(images, rows, cols)
        self.image(tag, tiled_images, step=step)

    def plot(self, tag, mpl_plt, step=None, close_plot=True):
        """Saves matplotlib plot output to summary image.

        Args:
            tag: str: label for this data
            mpl_plt: matplotlib stateful pyplot object with prepared plotting
                state
            step: int: training step
            close_plot: bool: automatically closes plot
        """
        step = self._resolve_step(step)
        manager = mpl_plt.get_current_fig_manager()
        img_w, img_h = manager.canvas.get_width_height()
        image_buf = io.BytesIO()
        mpl_plt.savefig(image_buf, format='png')
        image_summary = tf.compat.v1.Summary.Image(
                encoded_image_string=image_buf.getvalue(),
                colorspace=4,    # RGBA
                height=img_h,
                width=img_w)
        summary = tf.compat.v1.Summary(
                value=[tf.compat.v1.Summary.Value(tag=tag, image=image_summary)])
        self.add_summary(summary, step)
        if close_plot:
            mpl_plt.close()

    def figure(self, tag, fig, step=None, close_plot=True):
        """Saves a matplotlib Figure to a summary image.

        Args:
            tag: str: label for this data
            fig: matplotlib Figure to render
            step: int: training step
            close_plot: bool: close `fig` after saving it
        """
        step = self._resolve_step(step)
        img_w, img_h = fig.canvas.get_width_height()
        image_buf = io.BytesIO()
        # Render `fig` itself; plt.savefig would render pyplot's *current*
        # figure, which is not necessarily `fig`.
        fig.savefig(image_buf, format='png')
        image_summary = tf.compat.v1.Summary.Image(
                encoded_image_string=image_buf.getvalue(),
                colorspace=4,    # RGBA
                height=img_h,
                width=img_w)
        summary = tf.compat.v1.Summary(
                value=[tf.compat.v1.Summary.Value(tag=tag, image=image_summary)])
        self.add_summary(summary, step)
        if close_plot:
            plt.close(fig)

    def audio(self, tag, audiodata, step=None, sample_rate=44100):
        """Saves audio as a 16-bit mono WAV summary.

        NB: single channel only right now.

        Args:
            tag: str: label for this data
            audiodata: ndarray [Nsamples,]: data between (-1.0,1.0) to save as
                wave; values outside [-1, 1] are clipped.
            step: int: training step
            sample_rate: sample rate of passed in audio buffer

        Raises:
            ValueError: if the squeezed data is not 1D or contains NaN.
        """
        audiodata = np.array(audiodata)
        step = self._resolve_step(step)
        audiodata = np.clip(np.squeeze(audiodata), -1, 1)
        if audiodata.ndim != 1:
            raise ValueError('Audio data must be 1D.')
        if np.any(np.isnan(audiodata)):
            raise ValueError('Audio data must not contain NaN.')
        # Scale [-1, 1] to [-(2^15 - 1), 2^15 - 1] and truncate toward zero
        # into little-endian int16 PCM samples in one vectorized step.
        samples = (32767.0 * audiodata).astype('<i2')
        wio = io.BytesIO()
        with wave.open(wio, 'wb') as wav_buf:
            wav_buf.setnchannels(1)
            wav_buf.setsampwidth(2)
            wav_buf.setframerate(sample_rate)
            wav_buf.writeframes(samples.tobytes())
        encoded_audio_bytes = wio.getvalue()
        wio.close()
        audio = tf.compat.v1.Summary.Audio(
                sample_rate=sample_rate,
                num_channels=1,
                length_frames=len(samples),
                encoded_audio_string=encoded_audio_bytes,
                content_type='audio/wav')
        summary = tf.compat.v1.Summary(
                value=[tf.compat.v1.Summary.Value(tag=tag, audio=audio)])
        self.add_summary(summary, step)

    def histogram(self, tag, values, bins, step=None):
        """Saves histogram of values.

        Args:
            tag: str: label for this data
            values: ndarray: will be flattened by this routine
            bins: number of bins in histogram, or array of bins for
                np.histogram
            step: int: training step
        """
        step = self._resolve_step(step)
        values = np.array(values)
        bins = np.array(bins)
        values = np.reshape(values, -1)
        counts, limits = np.histogram(values, bins=bins)
        # Trim empty bins at both ends: `start` is the index of the first
        # non-empty bin and `end` is one past the last non-empty bin.
        # NOTE(review): if no value falls inside `bins`, cum_counts is all zero
        # and the trimmed buckets come out empty -- confirm this is acceptable.
        cum_counts = np.cumsum(np.greater(counts, 0, dtype=np.int32))
        start, end = np.searchsorted(
                cum_counts, [0, cum_counts[-1] - 1], side='right')
        start, end = int(start), int(end) + 1
        # Keep one zero-count bucket in front of the first non-empty bin (the
        # preceding bin, or a prepended 0 when there is none); `limits` then
        # holds exactly one edge per remaining bucket.
        counts = (counts[start - 1:end]
                  if start > 0 else np.concatenate([[0], counts[:end]]))
        limits = limits[start:end + 1]
        sum_sq = values.dot(values)
        histo = tf.compat.v1.HistogramProto(
                min=values.min(),
                max=values.max(),
                num=len(values),
                sum=values.sum(),
                sum_squares=sum_sq,
                bucket_limit=limits.tolist(),
                bucket=counts.tolist())
        summary = tf.compat.v1.Summary(
                value=[tf.compat.v1.Summary.Value(tag=tag, histo=histo)])
        self.add_summary(summary, step)

    def text(self, tag, textdata, step=None):
        """Saves a text summary.

        Note: markdown formatting is rendered by tensorboard.

        Args:
            tag: str: label for this data
            textdata: string (str or bytes), or 1D/2D list/numpy array of
                strings
            step: int: training step

        Raises:
            ValueError: if array-like `textdata` is not 1D or 2D.
        """
        step = self._resolve_step(step)

        def to_utf8(s):
            # bytes (including numpy.bytes_) are already encoded; str
            # (including numpy.str_) is encoded as UTF-8.
            if isinstance(s, bytes):
                return bytes(s)
            return s.encode(encoding='utf_8')

        if isinstance(textdata, (str, bytes)):
            values, shape = [to_utf8(textdata)], (1,)
        else:
            textdata = np.array(textdata)    # convert lists, jax arrays, etc.
            shape = tuple(np.shape(textdata))
            if len(shape) not in (1, 2):
                raise ValueError(
                        'textdata must be a string or a 1D/2D array of strings, '
                        'got shape {}.'.format(shape))
            values = [to_utf8(td) for td in np.reshape(textdata, -1)]
        tensor = tf.make_tensor_proto(values=values, shape=shape)
        smd = tf.compat.v1.SummaryMetadata(
                plugin_data=tf.compat.v1.SummaryMetadata.PluginData(plugin_name='text'))
        summary = tf.compat.v1.Summary(
                value=[tf.compat.v1.Summary.Value(
                        tag=tag, metadata=smd, tensor=tensor)])
        self.add_summary(summary, step)

    @staticmethod
    def _pickle_atomically(path, obj):
        """Pickles `obj` to `path` through a temporary file.

        The temporary file is renamed over `path` only after it has been fully
        written, so an interruption mid-write never leaves a truncated file at
        `path`.
        """
        tmp_path = path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(obj, f)
        os.replace(tmp_path, path)

    def checkpoint(self, tag, optimiser_state, step, loss=None):
        """Pickles the optimiser state under `<self.log_dir>/checkpoints/`.

        The state is always written to `<tag>_last.pickle`. When `loss` is given
        and is not greater than the lowest loss previously passed to this
        writer, it is also written to `<tag>_best.pickle`.

        Args:
            tag: label used as the checkpoint file-name prefix.
            optimiser_state: optimizer state; unless it already is a
                `jax.experimental.optimizers.JoinPoint`, it is converted with
                `unpack_optimizer_state` before pickling.
            step: int: training step (currently unused).
            loss: optional metric used to track the best checkpoint; lower is
                better.

        Returns:
            True if the state was also saved as the new best checkpoint,
            otherwise None.
        """
        # NOTE(review): uses local `os`/`open` even though `log_dir` was set up
        # through tf.io.gfile, and ignores `enable`; confirm for remote or
        # multihost setups.
        parent = os.path.join(self.log_dir, "checkpoints")
        os.makedirs(parent, exist_ok=True)

        if not isinstance(optimiser_state, jax.experimental.optimizers.JoinPoint):
            optimiser_state = jax.experimental.optimizers.unpack_optimizer_state(optimiser_state)

        self._pickle_atomically(
                os.path.join(parent, str(tag) + "_last.pickle"), optimiser_state)

        if loss is None or loss > self._best_loss:
            return None
        self._best_loss = loss
        self._pickle_atomically(
                os.path.join(parent, str(tag) + "_best.pickle"), optimiser_state)
        return True