def text(self, tag, textdata, step):
    """Saves a text summary.

    Args:
      tag: str: label for this data
      textdata: string (str or bytes), or 1D/2D list/numpy array of strings
      step: int: training step

    Raises:
      ValueError: if `textdata` is array-like but not 1D or 2D.

    Note: markdown formatting is rendered by tensorboard.
    """

    def _to_utf8(item):
        # bytes (incl. numpy.bytes_ from 'S'-dtype arrays) are already encoded;
        # calling .encode on them would raise AttributeError.
        if isinstance(item, bytes):
            return bytes(item)
        return item.encode(encoding='utf_8')

    smd = SummaryMetadata(plugin_data=SummaryMetadata.PluginData(
        plugin_name='text'))
    if isinstance(textdata, (str, bytes)):
        tensor = tf.make_tensor_proto(
            values=[_to_utf8(textdata)], shape=(1, ))
    else:
        textdata = onp.array(textdata)  # convert lists, jax arrays, etc.
        if textdata.ndim not in (1, 2):
            # Without this check `tensor` would be unbound below.
            raise ValueError(
                'textdata must be a string or a 1D/2D array of strings; '
                f'got shape {textdata.shape}.')
        # Flatten row-major; the proto shape restores the 1D/2D layout.
        tensor = tf.make_tensor_proto(
            values=[_to_utf8(td) for td in onp.reshape(textdata, -1)],
            shape=tuple(textdata.shape))
    summary = Summary(
        value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
    self._add_summary(summary, step)
def histogram(self, tag, values, bins, step):
    """Saves histogram of values.

    Args:
      tag: str: label for this data
      values: ndarray: will be flattened by this routine
      bins: number of bins in histogram, or a sequence defining a
        monotonically increasing array of bin edges, including the rightmost
        edge.
      step: int: training step

    Raises:
      ValueError: if `values` is empty.
    """
    # Work in float64 so sum / sum-of-squares cannot overflow for integer
    # inputs, and so the proto's numeric fields receive plain Python floats.
    values = onp.reshape(onp.asarray(values, dtype=onp.float64), -1)
    if values.size == 0:
        raise ValueError('Cannot compute a histogram of empty values.')
    bins = onp.array(bins)
    counts, limits = onp.histogram(values, bins=bins)

    # Trim empty bins from both ends. cum_counts[i] is the number of non-empty
    # bins in counts[:i + 1], so `start` is the index of the first non-empty
    # bin and `end` is one past the last non-empty bin.
    # NOTE(review): if no value falls inside explicit bin edges, all counts are
    # zero and these slices degenerate (typically to empty buckets).
    # TODO(flax-dev) Investigate whether this logic can be simplified.
    cum_counts = onp.cumsum(counts > 0)
    start, end = onp.searchsorted(
        cum_counts, [0, cum_counts[-1] - 1], side='right')
    start, end = int(start), int(end) + 1
    # Prepend one zero-count bucket so that counts[k] pairs with limits[k] as
    # that bucket's right edge (limits[start] is the left edge of the first
    # non-empty bin).
    counts = (counts[start - 1:end] if start > 0 else
              onp.concatenate([[0], counts[:end]]))
    limits = limits[start:end + 1]

    histo = HistogramProto(
        min=float(values.min()),
        max=float(values.max()),
        num=len(values),
        sum=float(values.sum()),
        sum_squares=float(values.dot(values)),
        bucket_limit=limits.tolist(),
        bucket=counts.tolist())
    summary = Summary(value=[Summary.Value(tag=tag, histo=histo)])
    self._add_summary(summary, step)
def scalar(self, tag, value, step):
    """Writes one scalar value to the summary log.

    Args:
      tag: str: name under which the value is recorded
      value: int/float (or anything accepted by float(onp.array(value))):
        the number to log
      step: int: training step
    """
    simple_value = float(onp.asarray(value))
    entry = Summary.Value(tag=tag, simple_value=simple_value)
    self._add_summary(Summary(value=[entry]), step)
def saveToTensorBoard(self, image, tag, batch=None):
    """Write an already-built protobuf image under `tag` via the writer.

    :param image: a protobuf image, stored as the `image` field of a
        Summary.Value
    :param tag: str, the tag of the image
    :param batch: the current number of steps, forwarded as `global_step`
    :return: None
    """
    # NOTE(review): unlike the other summary methods here, this goes through
    # self.writer.add_summary rather than self._add_summary -- confirm that
    # `self.writer` is defined on this class.
    self.writer.add_summary(
        Summary(value=[Summary.Value(tag=tag, image=image)]),
        global_step=batch,
    )
def image(self, tag, image, step):
    """Saves an image summary encoded as PNG from an onp.ndarray.

    Greyscale input ([H,W] or [H,W,1]) is expanded to three identical
    channels before encoding.

    Args:
      tag: str: label for this data
      image: ndarray: [H,W], [H,W,1] or [H,W,3]; greyscale or colour image.
        Pixel values should be in the range [0, 1].
      step: int: training step
    """
    pixels = onp.array(image)
    if pixels.ndim == 2:
        pixels = pixels[..., onp.newaxis]
    if pixels.shape[-1] == 1:
        pixels = onp.repeat(pixels, 3, axis=-1)

    png_buffer = io.BytesIO()
    plt.imsave(png_buffer, pixels, vmin=0., vmax=1., format='png')

    height, width = pixels.shape[0], pixels.shape[1]
    encoded = Summary.Image(
        encoded_image_string=png_buffer.getvalue(),
        colorspace=3,
        height=height,
        width=width)
    self._add_summary(
        Summary(value=[Summary.Value(tag=tag, image=encoded)]), step)
def audio(self, tag, audiodata, step, sample_rate=44100):
    """Saves audio as a mono 16-bit PCM WAV summary.

    NB: single channel only right now.

    Args:
      tag: str: label for this data
      audiodata: ndarray [Nsamples,]: audio data to be saved as wave. The data
        will be clipped to [-1, 1].
      step: int: training step
      sample_rate: sample rate of passed in audio buffer

    Raises:
      ValueError: if the squeezed data is not 1D, or if it contains NaN.
    """
    audiodata = onp.array(audiodata)
    audiodata = onp.clip(onp.squeeze(audiodata), -1, 1)
    if audiodata.ndim != 1:
        raise ValueError('Audio data must be 1D.')
    # Clipping maps +/-inf to +/-1 but leaves NaN; NaN has no PCM value, and a
    # float->int16 cast of NaN would silently produce garbage.
    if onp.isnan(audiodata).any():
        raise ValueError('Audio data must not contain NaN.')
    # Convert from [-1, 1] -> [-(2^15 - 1), 2^15 - 1], truncating toward zero,
    # as little-endian signed 16-bit samples (matches sampwidth=2 below).
    samples = (32767.0 * audiodata).astype('<i2')

    with io.BytesIO() as wio:
        # Wave_write.close() (run by the `with`) finalizes the header but does
        # not close a file object it did not open, so wio is still readable.
        with wave.open(wio, 'wb') as wav_buf:
            wav_buf.setnchannels(1)
            wav_buf.setsampwidth(2)
            wav_buf.setframerate(sample_rate)
            wav_buf.writeframes(samples.tobytes())
        encoded_audio_bytes = wio.getvalue()

    audio = Summary.Audio(
        sample_rate=sample_rate,
        num_channels=1,
        length_frames=len(samples),
        encoded_audio_string=encoded_audio_bytes,
        content_type='audio/wav')
    summary = Summary(value=[Summary.Value(tag=tag, audio=audio)])
    self._add_summary(summary, step)