class SummaryWriter(object):
  """Saves data in event and summary protos for tensorboard."""

  def __init__(self, log_dir):
    """Create a new SummaryWriter.

    Args:
      log_dir: path to record tfevents files in.
    """
    # If needed, create log_dir directory as well as missing parent directories.
    if not gfile.isdir(log_dir):
      gfile.makedirs(log_dir)

    self._event_writer = EventFileWriter(log_dir, 10, 120, None)
    self._step = 0
    self._closed = False

  def add_summary(self, summary, step):
    event = event_pb2.Event(summary=summary)
    event.wall_time = time.time()
    if step is not None:
      event.step = int(step)
    self._event_writer.add_event(event)

  def close(self):
    """Close SummaryWriter. Final!"""
    if not self._closed:
      self._event_writer.close()
      self._closed = True
      del self._event_writer

  def __del__(self):  # safe?
    self.close()

  def flush(self):
    self._event_writer.flush()

  def scalar(self, tag, value, step=None):
    """Saves scalar value.

    Args:
      tag: str: label for this data
      value: int/float: number to log
      step: int: training step
    """
    value = float(onp.array(value))
    if step is None:
      step = self._step
    else:
      self._step = step
    summary = Summary(value=[Summary.Value(tag=tag, simple_value=value)])
    self.add_summary(summary, step)

  def image(self, tag, image, step=None):
    """Saves RGB image summary from onp.ndarray [H,W], [H,W,1], or [H,W,3].

    Args:
      tag: str: label for this data
      image: ndarray: [H,W], [H,W,1], [H,W,3] save image in greyscale or colors.
      step: int: training step
    """
    image = onp.array(image)
    if step is None:
      step = self._step
    else:
      self._step = step
    if len(onp.shape(image)) == 2:
      image = image[:, :, onp.newaxis]
    if onp.shape(image)[-1] == 1:
      image = onp.repeat(image, 3, axis=-1)
    image_strio = io.BytesIO()
    plt.imsave(image_strio, image, format='png')
    image_summary = Summary.Image(
        encoded_image_string=image_strio.getvalue(),
        colorspace=3,
        height=image.shape[0],
        width=image.shape[1])
    summary = Summary(value=[Summary.Value(tag=tag, image=image_summary)])
    self.add_summary(summary, step)

  def images(self, tag, images, step=None, rows=None, cols=None):
    """Saves (rows, cols) tiled images from onp.ndarray.

    If either rows or cols aren't given, they are determined automatically
    from the size of the image batch; if neither is given, a long column of
    images is produced. This truncates the image batch rather than padding
    if it doesn't fill the final row.

    Args:
      tag: str: label for this data
      images: ndarray: [N,H,W,1] or [N,H,W,3] to tile in 2d
      step: int: training step
      rows: int: number of rows in tile
      cols: int: number of columns in tile
    """
    images = onp.array(images)
    if step is None:
      step = self._step
    else:
      self._step = step
    n_images = onp.shape(images)[0]
    if rows is None and cols is None:
      rows = 1
      cols = n_images
    elif rows is None:
      rows = n_images // cols
    elif cols is None:
      cols = n_images // rows
    tiled_images = _pack_images(images, rows, cols)
    self.image(tag, tiled_images, step=step)

  def plot(self, tag, mpl_plt, step=None, close_plot=True):
    """Saves matplotlib plot output to summary image.

    Args:
      tag: str: label for this data
      mpl_plt: matplotlib stateful pyplot object with prepared plotting state
      step: int: training step
      close_plot: bool: automatically closes plot
    """
    if step is None:
      step = self._step
    else:
      self._step = step
    fig = mpl_plt.get_current_fig_manager()
    img_w, img_h = fig.canvas.get_width_height()
    image_buf = io.BytesIO()
    mpl_plt.savefig(image_buf, format='png')
    image_summary = Summary.Image(
        encoded_image_string=image_buf.getvalue(),
        colorspace=4,  # RGBA
        height=img_h,
        width=img_w)
    summary = Summary(value=[Summary.Value(tag=tag, image=image_summary)])
    self.add_summary(summary, step)
    if close_plot:
      mpl_plt.close()

  def audio(self, tag, audiodata, step=None, sample_rate=44100):
    """Saves audio. NB: single channel only right now.

    Args:
      tag: str: label for this data
      audiodata: ndarray [Nsamples,]: data between (-1.0,1.0) to save as wave
      step: int: training step
      sample_rate: sample rate of passed in audio buffer
    """
    audiodata = onp.array(audiodata)
    if step is None:
      step = self._step
    else:
      self._step = step
    audiodata = onp.clip(onp.squeeze(audiodata), -1, 1)
    if audiodata.ndim != 1:
      raise ValueError('Audio data must be 1D.')
    # Convert from [-1.0, 1.0] floats to 16-bit PCM sample values.
    sample_list = (32767.0 * audiodata).astype(int).tolist()
    wio = io.BytesIO()
    wav_buf = wave.open(wio, 'wb')
    wav_buf.setnchannels(1)
    wav_buf.setsampwidth(2)
    wav_buf.setframerate(sample_rate)
    enc = b''.join([struct.pack('<h', v) for v in sample_list])
    wav_buf.writeframes(enc)
    wav_buf.close()
    encoded_audio_bytes = wio.getvalue()
    wio.close()
    audio = Summary.Audio(
        sample_rate=sample_rate,
        num_channels=1,
        length_frames=len(sample_list),
        encoded_audio_string=encoded_audio_bytes,
        content_type='audio/wav')
    summary = Summary(value=[Summary.Value(tag=tag, audio=audio)])
    self.add_summary(summary, step)

  def histogram(self, tag, values, bins, step=None):
    """Saves histogram of values.

    Args:
      tag: str: label for this data
      values: ndarray: will be flattened by this routine
      bins: number of bins in histogram, or array of bins for onp.histogram
      step: int: training step
    """
    if step is None:
      step = self._step
    else:
      self._step = step
    values = onp.array(values)
    bins = onp.array(bins)
    values = onp.reshape(values, -1)
    counts, limits = onp.histogram(values, bins=bins)
    # Boundary logic: trim leading and trailing empty buckets.
    cum_counts = onp.cumsum(onp.greater(counts, 0, dtype=onp.int32))
    start, end = onp.searchsorted(
        cum_counts, [0, cum_counts[-1] - 1], side='right')
    start, end = int(start), int(end) + 1
    counts = (counts[start - 1:end]
              if start > 0 else onp.concatenate([[0], counts[:end]]))
    limits = limits[start:end + 1]
    sum_sq = values.dot(values)
    histo = HistogramProto(
        min=values.min(),
        max=values.max(),
        num=len(values),
        sum=values.sum(),
        sum_squares=sum_sq,
        bucket_limit=limits.tolist(),
        bucket=counts.tolist())
    summary = Summary(value=[Summary.Value(tag=tag, histo=histo)])
    self.add_summary(summary, step)

  def text(self, tag, textdata, step=None):
    """Saves a text summary.

    Args:
      tag: str: label for this data
      textdata: string, or 1D/2D list/numpy array of strings
      step: int: training step
    Note: markdown formatting is rendered by tensorboard.
    """
    if step is None:
      step = self._step
    else:
      self._step = step
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name='text'))
    if isinstance(textdata, (str, bytes)):
      tensor = tf.make_tensor_proto(
          values=[textdata.encode(encoding='utf_8')], shape=(1,))
    else:
      textdata = onp.array(textdata)  # convert lists, jax arrays, etc.
      datashape = onp.shape(textdata)
      if len(datashape) == 1:
        tensor = tf.make_tensor_proto(
            values=[td.encode(encoding='utf_8') for td in textdata],
            shape=(datashape[0],))
      elif len(datashape) == 2:
        tensor = tf.make_tensor_proto(
            values=[
                td.encode(encoding='utf_8')
                for td in onp.reshape(textdata, -1)
            ],
            shape=(datashape[0], datashape[1]))
    summary = Summary(
        value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
    self.add_summary(summary, step)
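# Usage sketch (illustrative, not part of the original source): exercises the
# writer above end to end, assuming its module-level imports (gfile,
# EventFileWriter, Summary, onp, plt, ...) are available; the log directory
# path is hypothetical.
if __name__ == '__main__':
  writer = SummaryWriter('/tmp/summary_demo')
  for step in range(10):
    writer.scalar('train/loss', 1.0 / (step + 1), step=step)
  writer.histogram('params/w', onp.random.randn(1000), bins=32, step=9)
  writer.text('notes', 'loss curve for a **demo** run', step=9)
  writer.flush()
  writer.close()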
class SummaryWriter(object):
  """Saves data in event and summary protos for tensorboard."""

  def __init__(self, log_dir):
    """Create a new SummaryWriter.

    Args:
      log_dir: path to record tfevents files in.
    """
    # If needed, create log_dir directory as well as missing parent directories.
    if not gfile.isdir(log_dir):
      gfile.makedirs(log_dir)

    self._event_writer = EventFileWriter(log_dir, 10, 120, None)
    self._closed = False

  def _add_summary(self, summary, step):
    event = event_pb2.Event(summary=summary)
    event.wall_time = time.time()
    if step is not None:
      event.step = int(step)
    self._event_writer.add_event(event)

  def close(self):
    """Close SummaryWriter. Final!"""
    if not self._closed:
      self._event_writer.close()
      self._closed = True
      del self._event_writer

  def flush(self):
    self._event_writer.flush()

  def scalar(self, tag, value, step):
    """Saves scalar value.

    Args:
      tag: str: label for this data
      value: int/float: number to log
      step: int: training step
    """
    value = float(onp.array(value))
    summary = Summary(value=[Summary.Value(tag=tag, simple_value=value)])
    self._add_summary(summary, step)

  def image(self, tag, image, step):
    """Saves RGB image summary from onp.ndarray [H,W], [H,W,1], or [H,W,3].

    Args:
      tag: str: label for this data
      image: ndarray: [H,W], [H,W,1], [H,W,3] save image in greyscale or
        colors. Pixel values should be in the range [0, 1].
      step: int: training step
    """
    image = onp.array(image)
    if len(onp.shape(image)) == 2:
      image = image[:, :, onp.newaxis]
    if onp.shape(image)[-1] == 1:
      image = onp.repeat(image, 3, axis=-1)
    image_strio = io.BytesIO()
    plt.imsave(image_strio, image, vmin=0., vmax=1., format='png')
    image_summary = Summary.Image(
        encoded_image_string=image_strio.getvalue(),
        colorspace=3,
        height=image.shape[0],
        width=image.shape[1])
    summary = Summary(value=[Summary.Value(tag=tag, image=image_summary)])
    self._add_summary(summary, step)

  def audio(self, tag, audiodata, step, sample_rate=44100):
    """Saves audio. NB: single channel only right now.

    Args:
      tag: str: label for this data
      audiodata: ndarray [Nsamples,]: audio data to be saved as a wave file.
        The data will be clipped to [-1, 1].
      step: int: training step
      sample_rate: sample rate of passed in audio buffer
    """
    audiodata = onp.array(audiodata)
    audiodata = onp.clip(onp.squeeze(audiodata), -1, 1)
    if audiodata.ndim != 1:
      raise ValueError('Audio data must be 1D.')
    # Convert from [-1, 1] floats to 16-bit PCM in [-2^15+1, 2^15-1].
    sample_list = (32767.0 * audiodata).astype(int).tolist()
    wio = io.BytesIO()
    wav_buf = wave.open(wio, 'wb')
    wav_buf.setnchannels(1)
    wav_buf.setsampwidth(2)
    wav_buf.setframerate(sample_rate)
    enc = b''.join([struct.pack('<h', v) for v in sample_list])
    wav_buf.writeframes(enc)
    wav_buf.close()
    encoded_audio_bytes = wio.getvalue()
    wio.close()
    audio = Summary.Audio(
        sample_rate=sample_rate,
        num_channels=1,
        length_frames=len(sample_list),
        encoded_audio_string=encoded_audio_bytes,
        content_type='audio/wav')
    summary = Summary(value=[Summary.Value(tag=tag, audio=audio)])
    self._add_summary(summary, step)

  def histogram(self, tag, values, bins, step):
    """Saves histogram of values.

    Args:
      tag: str: label for this data
      values: ndarray: will be flattened by this routine
      bins: number of bins in histogram, or a sequence defining a
        monotonically increasing array of bin edges, including the
        rightmost edge.
      step: int: training step
    """
    values = onp.array(values)
    bins = onp.array(bins)
    values = onp.reshape(values, -1)
    counts, limits = onp.histogram(values, bins=bins)
    # Boundary logic: trim leading and trailing empty buckets.
    # TODO(flax-dev) Investigate whether this logic can be simplified.
    cum_counts = onp.cumsum(onp.greater(counts, 0, dtype=onp.int32))
    start, end = onp.searchsorted(
        cum_counts, [0, cum_counts[-1] - 1], side='right')
    start, end = int(start), int(end) + 1
    counts = (counts[start - 1:end]
              if start > 0 else onp.concatenate([[0], counts[:end]]))
    limits = limits[start:end + 1]
    sum_sq = values.dot(values)
    histo = HistogramProto(
        min=values.min(),
        max=values.max(),
        num=len(values),
        sum=values.sum(),
        sum_squares=sum_sq,
        bucket_limit=limits.tolist(),
        bucket=counts.tolist())
    summary = Summary(value=[Summary.Value(tag=tag, histo=histo)])
    self._add_summary(summary, step)

  def text(self, tag, textdata, step):
    """Saves a text summary.

    Args:
      tag: str: label for this data
      textdata: string, or 1D/2D list/numpy array of strings
      step: int: training step
    Note: markdown formatting is rendered by tensorboard.
    """
    smd = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name='text'))
    if isinstance(textdata, (str, bytes)):
      tensor = tf.make_tensor_proto(
          values=[textdata.encode(encoding='utf_8')], shape=(1,))
    else:
      textdata = onp.array(textdata)  # convert lists, jax arrays, etc.
      datashape = onp.shape(textdata)
      if len(datashape) == 1:
        tensor = tf.make_tensor_proto(
            values=[td.encode(encoding='utf_8') for td in textdata],
            shape=(datashape[0],))
      elif len(datashape) == 2:
        tensor = tf.make_tensor_proto(
            values=[
                td.encode(encoding='utf_8')
                for td in onp.reshape(textdata, -1)
            ],
            shape=(datashape[0], datashape[1]))
    summary = Summary(
        value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
    self._add_summary(summary, step)
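# Usage sketch (illustrative, not part of the original source): in this
# variant `step` is a required positional argument rather than defaulting to
# an internal counter; the path and values are hypothetical.
if __name__ == '__main__':
  writer = SummaryWriter('/tmp/summary_demo')
  writer.scalar('eval/accuracy', 0.91, 1000)
  writer.image('eval/sample', onp.random.rand(32, 32, 3), 1000)  # values in [0, 1]
  writer.flush()
  writer.close()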
class TFWriter(tensorboard.MetricWriter):
    """
    TFWriter uses tensorflow file writers and summary operations to write out
    tfevent files containing scalar batch metrics.
    """

    def __init__(self) -> None:
        super().__init__()

        self.writer = EventFileWriter(
            logdir=str(tensorboard.get_base_path({})), filename_suffix=None)
        self.createSummary = tf.Summary

        # _seen_summary_tags is vendored from TensorFlow:
        # tensorflow/python/summary/writer/writer.py
        # This set contains tags of Summary Values that have been encountered
        # already. The motivation here is that the SummaryWriter only keeps the
        # metadata property (which is a SummaryMetadata proto) of the first
        # Summary Value encountered for each tag. The SummaryWriter strips away
        # the SummaryMetadata for all subsequent Summary Values with tags seen
        # previously. This saves space.
        self._seen_summary_tags: Set[str] = set()

    def add_scalar(self, name: str, value: Union[int, float, np.number],
                   step: int) -> None:
        summary = self.createSummary()
        summary_value = summary.value.add()
        summary_value.tag = name
        summary_value.simple_value = value
        self._add_summary(summary, step)

    def _add_summary(self,
                     summary: Union[str, summary_pb2.Summary],
                     global_step: Optional[int] = None) -> None:
        """
        _add_summary is vendored from TensorFlow:
        tensorflow/python/summary/writer/writer.py

        Adds a `Summary` protocol buffer to the event file.

        This method wraps the provided summary in an `Event` protocol buffer
        and adds it to the event file.

        You can pass the result of evaluating any summary op, using
        `tf.Session.run` or `tf.Tensor.eval`, to this function. Alternatively,
        you can pass a `tf.compat.v1.Summary` protocol buffer that you populate
        with your own data. The latter is commonly done to report evaluation
        results in event files.

        Args:
          summary: A `Summary` protocol buffer, optionally serialized as a
            string.
          global_step: Number. Optional global step value to record with the
            summary.
        """
        if isinstance(summary, bytes):
            summ = summary_pb2.Summary()
            summ.ParseFromString(summary)
            summary = summ

        # We strip metadata from values with tags that we have seen before in
        # order to save space - we just store the metadata on the first value
        # with a specific tag.
        for value in summary.value:
            if not value.metadata:
                continue

            if value.tag in self._seen_summary_tags:
                # This tag has been encountered before. Strip the metadata.
                value.ClearField("metadata")
                continue

            # We encounter a value with a tag we have not encountered
            # previously. And it has metadata. Remember to strip metadata from
            # future values with this tag string.
            self._seen_summary_tags.add(value.tag)

        event = event_pb2.Event(summary=summary)
        self._add_event(event, global_step)

    def _add_event(self, event: event_pb2.Event, step: Optional[int]) -> None:
        # _add_event is vendored from TensorFlow:
        # tensorflow/python/summary/writer/writer.py
        event.wall_time = time.time()
        if step is not None:
            event.step = int(step)
        self.writer.add_event(event)

    def reset(self) -> None:
        self.writer.close()
        self.writer.reopen()
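# Usage sketch (illustrative, not part of the original source): assumes the
# surrounding harness `tensorboard` module (providing MetricWriter and
# get_base_path) is importable and configured, as TFWriter's __init__ expects.
if __name__ == "__main__":
    writer = TFWriter()
    for batch_idx in range(100):
        writer.add_scalar("train/loss", 1.0 / (batch_idx + 1), step=batch_idx)
    # close() flushes pending events to disk; reopen() starts a fresh file.
    writer.reset()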
class SummaryWriter(object):
  """Saves data in event and summary protos for tensorboard."""

  def __init__(self, log_dir, enable=True):
    """Create a new SummaryWriter.

    Args:
      log_dir: path to record tfevents files in.
      enable: bool: if False don't actually write or flush data.  Used in
        multihost training.
    """
    # If needed, create log_dir directory as well as missing parent directories.
    if not tf.io.gfile.isdir(log_dir):
      tf.io.gfile.makedirs(log_dir)

    # Number run directories sequentially inside log_dir.
    log_dir = os.path.join(log_dir, "{}".format(len(os.listdir(log_dir))))
    logging.info("Created new logger at: {}".format(log_dir))
    self.log_dir = log_dir

    self._event_writer = EventFileWriter(log_dir, 10, 120, None)
    self._step = 0
    self._closed = False
    self._enabled = enable
    self._best_loss = float("inf")

  def add_summary(self, summary, step):
    if not self._enabled:
      return
    event = event_pb2.Event(summary=summary)
    event.wall_time = time.time()
    if step is not None:
      event.step = int(step)
    self._event_writer.add_event(event)

  def close(self):
    """Close SummaryWriter. Final!"""
    if not self._closed:
      self._event_writer.close()
      self._closed = True
      del self._event_writer

  def __del__(self):  # safe?
    # TODO(afrozm): Sometimes this complains with
    # `TypeError: 'NoneType' object is not callable`
    try:
      self.close()
    except Exception:  # pylint: disable=broad-except
      pass

  def flush(self):
    if not self._enabled:
      return
    self._event_writer.flush()

  def scalar(self, tag, value, step=None):
    """Saves scalar value.

    Args:
      tag: str: label for this data
      value: int/float: number to log
      step: int: training step
    """
    value = float(np.array(value))
    if step is None:
      step = self._step
    else:
      self._step = step
    summary = tf.compat.v1.Summary(
        value=[tf.compat.v1.Summary.Value(tag=tag, simple_value=value)])
    self.add_summary(summary, step)

  def image(self, tag, image, step=None):
    """Saves RGB image summary from np.ndarray [H,W], [H,W,1], or [H,W,3].

    Args:
      tag: str: label for this data
      image: ndarray: [H,W], [H,W,1], [H,W,3] save image in greyscale or colors.
      step: int: training step
    """
    image = np.array(image)
    if step is None:
      step = self._step
    else:
      self._step = step
    if len(np.shape(image)) == 2:
      image = image[:, :, np.newaxis]
    if np.shape(image)[-1] == 1:
      image = np.repeat(image, 3, axis=-1)
    image_strio = io.BytesIO()
    plt.imsave(image_strio, image, format='png')
    image_summary = tf.compat.v1.Summary.Image(
        encoded_image_string=image_strio.getvalue(),
        colorspace=3,
        height=image.shape[0],
        width=image.shape[1])
    summary = tf.compat.v1.Summary(
        value=[tf.compat.v1.Summary.Value(tag=tag, image=image_summary)])
    self.add_summary(summary, step)

  def images(self, tag, images, step=None, rows=None, cols=None):
    """Saves (rows, cols) tiled images from np.ndarray.

    If either rows or cols aren't given, they are determined automatically
    from the size of the image batch; if neither is given, a long column of
    images is produced. This truncates the image batch rather than padding
    if it doesn't fill the final row.

    Args:
      tag: str: label for this data
      images: ndarray: [N,H,W,1] or [N,H,W,3] to tile in 2d
      step: int: training step
      rows: int: number of rows in tile
      cols: int: number of columns in tile
    """
    images = np.array(images)
    if step is None:
      step = self._step
    else:
      self._step = step
    n_images = np.shape(images)[0]
    if rows is None and cols is None:
      rows = 1
      cols = n_images
    elif rows is None:
      rows = n_images // cols
    elif cols is None:
      cols = n_images // rows
    tiled_images = _pack_images(images, rows, cols)
    self.image(tag, tiled_images, step=step)

  def plot(self, tag, mpl_plt, step=None, close_plot=True):
    """Saves matplotlib plot output to summary image.

    Args:
      tag: str: label for this data
      mpl_plt: matplotlib stateful pyplot object with prepared plotting state
      step: int: training step
      close_plot: bool: automatically closes plot
    """
    if step is None:
      step = self._step
    else:
      self._step = step
    fig = mpl_plt.get_current_fig_manager()
    img_w, img_h = fig.canvas.get_width_height()
    image_buf = io.BytesIO()
    mpl_plt.savefig(image_buf, format='png')
    image_summary = tf.compat.v1.Summary.Image(
        encoded_image_string=image_buf.getvalue(),
        colorspace=4,  # RGBA
        height=img_h,
        width=img_w)
    summary = tf.compat.v1.Summary(
        value=[tf.compat.v1.Summary.Value(tag=tag, image=image_summary)])
    self.add_summary(summary, step)
    if close_plot:
      mpl_plt.close()

  def figure(self, tag, fig, step=None, close_plot=True):
    """Saves a matplotlib figure to a summary image.

    Args:
      tag: str: label for this data
      fig: matplotlib figure object with prepared plotting state
      step: int: training step
      close_plot: bool: automatically closes the figure
    """
    if step is None:
      step = self._step
    else:
      self._step = step
    img_w, img_h = fig.canvas.get_width_height()
    image_buf = io.BytesIO()
    # Save the passed-in figure itself, not whichever figure is current.
    fig.savefig(image_buf, format='png')
    image_summary = tf.compat.v1.Summary.Image(
        encoded_image_string=image_buf.getvalue(),
        colorspace=4,  # RGBA
        height=img_h,
        width=img_w)
    summary = tf.compat.v1.Summary(
        value=[tf.compat.v1.Summary.Value(tag=tag, image=image_summary)])
    self.add_summary(summary, step)
    if close_plot:
      plt.close(fig)

  def audio(self, tag, audiodata, step=None, sample_rate=44100):
    """Saves audio. NB: single channel only right now.

    Args:
      tag: str: label for this data
      audiodata: ndarray [Nsamples,]: data between (-1.0,1.0) to save as wave
      step: int: training step
      sample_rate: sample rate of passed in audio buffer
    """
    audiodata = np.array(audiodata)
    if step is None:
      step = self._step
    else:
      self._step = step
    audiodata = np.clip(np.squeeze(audiodata), -1, 1)
    if audiodata.ndim != 1:
      raise ValueError('Audio data must be 1D.')
    # Convert from [-1.0, 1.0] floats to 16-bit PCM sample values.
    sample_list = (32767.0 * audiodata).astype(int).tolist()
    wio = io.BytesIO()
    wav_buf = wave.open(wio, 'wb')
    wav_buf.setnchannels(1)
    wav_buf.setsampwidth(2)
    wav_buf.setframerate(sample_rate)
    enc = b''.join([struct.pack('<h', v) for v in sample_list])
    wav_buf.writeframes(enc)
    wav_buf.close()
    encoded_audio_bytes = wio.getvalue()
    wio.close()
    audio = tf.compat.v1.Summary.Audio(
        sample_rate=sample_rate,
        num_channels=1,
        length_frames=len(sample_list),
        encoded_audio_string=encoded_audio_bytes,
        content_type='audio/wav')
    summary = tf.compat.v1.Summary(
        value=[tf.compat.v1.Summary.Value(tag=tag, audio=audio)])
    self.add_summary(summary, step)

  def histogram(self, tag, values, bins, step=None):
    """Saves histogram of values.

    Args:
      tag: str: label for this data
      values: ndarray: will be flattened by this routine
      bins: number of bins in histogram, or array of bins for np.histogram
      step: int: training step
    """
    if step is None:
      step = self._step
    else:
      self._step = step
    values = np.array(values)
    bins = np.array(bins)
    values = np.reshape(values, -1)
    counts, limits = np.histogram(values, bins=bins)
    # Boundary logic: trim leading and trailing empty buckets.
    cum_counts = np.cumsum(np.greater(counts, 0, dtype=np.int32))
    start, end = np.searchsorted(
        cum_counts, [0, cum_counts[-1] - 1], side='right')
    start, end = int(start), int(end) + 1
    counts = (
        counts[start - 1:end]
        if start > 0 else np.concatenate([[0], counts[:end]]))
    limits = limits[start:end + 1]
    sum_sq = values.dot(values)
    histo = tf.compat.v1.HistogramProto(
        min=values.min(),
        max=values.max(),
        num=len(values),
        sum=values.sum(),
        sum_squares=sum_sq,
        bucket_limit=limits.tolist(),
        bucket=counts.tolist())
    summary = tf.compat.v1.Summary(
        value=[tf.compat.v1.Summary.Value(tag=tag, histo=histo)])
    self.add_summary(summary, step)

  def text(self, tag, textdata, step=None):
    """Saves a text summary.

    Args:
      tag: str: label for this data
      textdata: string, or 1D/2D list/numpy array of strings
      step: int: training step
    Note: markdown formatting is rendered by tensorboard.
    """
    if step is None:
      step = self._step
    else:
      self._step = step
    smd = tf.compat.v1.SummaryMetadata(
        plugin_data=tf.compat.v1.SummaryMetadata.PluginData(
            plugin_name='text'))
    if isinstance(textdata, (str, bytes)):
      tensor = tf.make_tensor_proto(
          values=[textdata.encode(encoding='utf_8')], shape=(1,))
    else:
      textdata = np.array(textdata)  # convert lists, jax arrays, etc.
      datashape = np.shape(textdata)
      if len(datashape) == 1:
        tensor = tf.make_tensor_proto(
            values=[td.encode(encoding='utf_8') for td in textdata],
            shape=(datashape[0],))
      elif len(datashape) == 2:
        tensor = tf.make_tensor_proto(
            values=[
                td.encode(encoding='utf_8')
                for td in np.reshape(textdata, -1)
            ],
            shape=(datashape[0], datashape[1]))
    summary = tf.compat.v1.Summary(
        value=[tf.compat.v1.Summary.Value(
            tag=tag, metadata=smd, tensor=tensor)])
    self.add_summary(summary, step)

  def checkpoint(self, tag, optimiser_state, step, loss=None):
    """Saves a copy of the optimiser state using pickle.

    Args:
      tag: str: label for this data; used as the checkpoint file prefix
      optimiser_state: optimizer state (from jax.experimental.optimizers)
        containing the parameters of the model
      step: int: training step
      loss: optional float: when given, an additional "<tag>_best.pickle"
        copy is kept for the lowest loss seen so far

    Returns:
      True if a new best checkpoint was written, None otherwise.
    """
    parent = os.path.join(self.log_dir, "checkpoints")
    os.makedirs(parent, exist_ok=True)
    # Unpack the optimizer state into a picklable pytree unless it is
    # already unpacked.
    if not isinstance(optimiser_state, jax.experimental.optimizers.JoinPoint):
      optimiser_state = jax.experimental.optimizers.unpack_optimizer_state(
          optimiser_state)
    filepath_last = os.path.join(parent, str(tag) + "_last.pickle")
    with open(filepath_last, "wb") as f:
      pickle.dump(optimiser_state, f)
    if loss is None or loss > self._best_loss:
      return
    self._best_loss = loss
    filepath_best = os.path.join(parent, str(tag) + "_best.pickle")
    with open(filepath_best, "wb") as f:
      pickle.dump(optimiser_state, f)
    return True
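# Usage sketch (illustrative, not part of the original source): enable-gated
# logging plus best-loss checkpointing. Assumes a JAX version where
# jax.experimental.optimizers is still available (it later moved to
# jax.example_libraries); the path and loss values are hypothetical.
if __name__ == "__main__":
  from jax.experimental import optimizers

  opt_init, opt_update, get_params = optimizers.sgd(step_size=0.1)
  opt_state = opt_init({"w": np.zeros(3)})

  writer = SummaryWriter("/tmp/summary_demo", enable=True)
  for step, loss in enumerate([0.5, 0.4, 0.45]):
    writer.scalar("train/loss", loss, step=step)
    # Writes "<log_dir>/checkpoints/model_last.pickle" on every call and
    # "model_best.pickle" whenever `loss` improves on the best seen so far.
    writer.checkpoint("model", opt_state, step, loss=loss)
  writer.flush()
  writer.close()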