def image(name, data, step=None, max_outputs=3, description=None):
    """Write an image summary.

    Arguments:
      name: A name for this summary. The summary tag used for TensorBoard
        will be this name prefixed by any active name scopes.
      data: A `Tensor` of pixel data with shape `[k, h, w, c]`: `k` images
        of height `h` and width `w`, each with `c` channels, where `c`
        should be 1, 2, 3, or 4 (grayscale, grayscale with alpha, RGB,
        RGBA). Any dimension may be statically unknown (`None`). The data
        is converted to `uint8` with saturation, so floating-point values
        outside [0, 1] are clipped.
      step: Explicit `int64`-castable monotonic step value for this summary.
        If omitted, this defaults to `tf.summary.experimental.get_step()`,
        which must not be None.
      max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this
        many images are emitted per step; images beyond the first
        `max_outputs` are silently discarded.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.

    Returns:
      True on success, or false if no summary was emitted because no default
      summary writer was available.

    Raises:
      ValueError: if a default writer exists, but no step was provided and
        `tf.summary.experimental.get_step()` is None.
    """
    meta = metadata.create_summary_metadata(
        display_name=None, description=description)

    # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback
    # Prefer the experimental summary_scope when this TF build provides it.
    scope_fn = getattr(tf.summary.experimental, 'summary_scope', None)
    if not scope_fn:
        scope_fn = tf.summary.summary_scope

    scope_values = [data, max_outputs, step]
    with scope_fn(name, 'image_summary', values=scope_values) as (tag, _):
        tf.debugging.assert_rank(data, 4)
        tf.debugging.assert_non_negative(max_outputs)

        pixels = tf.image.convert_image_dtype(data, tf.uint8, saturate=True)
        png_strings = tf.map_fn(
            tf.image.encode_png,
            pixels[:max_outputs],
            dtype=tf.string,
            name='encode_each_image',
        )
        # map_fn reports a float dtype when given zero elements, so swap in an
        # explicitly string-typed empty tensor in that case.
        has_images = tf.shape(input=png_strings)[0] > 0
        png_strings = tf.cond(has_images,
                              lambda: png_strings,
                              lambda: tf.constant([], tf.string))

        # Serialized layout: [width, height, png_0, png_1, ...].
        shape = tf.shape(input=pixels)
        width_str = tf.as_string(shape[2], name='width')
        height_str = tf.as_string(shape[1], name='height')
        header = tf.stack([width_str, height_str], name='dimensions')

        return tf.summary.write(
            tag=tag,
            tensor=tf.concat([header, png_strings], axis=0),
            step=step,
            metadata=meta,
        )
Example #2
0
 def lazy_tensor():
     """Encode the image batch as a string tensor `[width, height, png_0, ...]`.

     Reads the free variables `data` and `max_outputs` from the enclosing
     scope, which lies outside this excerpt.
     """
     tf.debugging.assert_rank(data, 4)
     tf.debugging.assert_non_negative(max_outputs)
     as_uint8 = tf.image.convert_image_dtype(data, tf.uint8, saturate=True)
     pngs = tf.map_fn(tf.image.encode_png,
                      as_uint8[:max_outputs],
                      dtype=tf.string,
                      name="encode_each_image")
     # Workaround: map_fn yields a float dtype for an empty elems input, so
     # substitute an explicitly string-typed empty tensor.
     nonempty = tf.shape(input=pngs)[0] > 0
     pngs = tf.cond(nonempty, lambda: pngs, lambda: tf.constant([], tf.string))
     dims = tf.shape(input=as_uint8)
     header = tf.stack(
         [tf.as_string(dims[2], name="width"),
          tf.as_string(dims[1], name="height")],
         name="dimensions",
     )
     return tf.concat([header, pngs], axis=0)
Example #3
0
def _buckets(data, bucket_count=None):
    """Create a TensorFlow op to group data into histogram buckets.

    Arguments:
      data: A `Tensor` of any shape. Must be castable to `float64`.
      bucket_count: Optional positive `int` or scalar `int32` `Tensor`.
        When None, `DEFAULT_BUCKET_COUNT` is used.

    Returns:
      A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is
      a triple `[left_edge, right_edge, count]` for a single bucket.
      `k` is `0` when `data` is empty, `1` when every element of `data`
      has the same value (one bucket of width 1 centered on that value),
      and `bucket_count` otherwise.
    """
    if bucket_count is None:
        bucket_count = DEFAULT_BUCKET_COUNT
    with tf.name_scope('buckets'):
        tf.debugging.assert_scalar(bucket_count)
        tf.debugging.assert_type(bucket_count, tf.int32)
        data = tf.reshape(data, shape=[-1])  # flatten
        data = tf.cast(data, tf.float64)
        is_empty = tf.equal(tf.size(input=data), 0)

        def when_empty():
            """No data: return zero buckets."""
            return tf.constant([], shape=(0, 3), dtype=tf.float64)

        def when_nonempty():
            """Bucket the flattened, non-empty `data`."""
            min_ = tf.reduce_min(input_tensor=data)
            max_ = tf.reduce_max(input_tensor=data)
            range_ = max_ - min_
            is_singular = tf.equal(range_, 0)

            def when_nonsingular():
                """Split [min_, max_] into `bucket_count` equal-width buckets."""
                bucket_width = range_ / tf.cast(bucket_count, tf.float64)
                offsets = data - min_
                bucket_indices = tf.cast(tf.floor(offsets / bucket_width),
                                         dtype=tf.int32)
                # Values at max_ can produce index `bucket_count`; fold them
                # into the last bucket so its right edge is inclusive.
                clamped_indices = tf.minimum(bucket_indices, bucket_count - 1)
                # Use float64 instead of the float32 default to avoid
                # accumulating floating point error in tf.reduce_sum when
                # summing more than 2^24 individual `1.0` values. See
                # https://github.com/tensorflow/tensorflow/issues/51419.
                one_hots = tf.one_hot(clamped_indices,
                                      depth=bucket_count,
                                      dtype=tf.float64)
                # Already float64 because of the one_hot dtype above.
                bucket_counts = tf.reduce_sum(input_tensor=one_hots, axis=0)
                edges = tf.linspace(min_, max_, bucket_count + 1)
                # Ensure edges[-1] == max_, which TF's linspace implementation does not
                # do, leaving it subject to the whim of floating point rounding error.
                edges = tf.concat([edges[:-1], [max_]], 0)
                left_edges = edges[:-1]
                right_edges = edges[1:]
                return tf.transpose(
                    a=tf.stack([left_edges, right_edges, bucket_counts]))

            def when_singular():
                """All values equal: one bucket [v - 0.5, v + 0.5] holding all of them."""
                center = min_
                bucket_starts = tf.stack([center - 0.5])
                bucket_ends = tf.stack([center + 0.5])
                bucket_counts = tf.stack(
                    [tf.cast(tf.size(input=data), tf.float64)])
                return tf.transpose(
                    a=tf.stack([bucket_starts, bucket_ends, bucket_counts]))

            return tf.cond(is_singular, when_singular, when_nonsingular)

        return tf.cond(is_empty, when_empty, when_nonempty)
Example #4
0
        def when_nonempty():
            """Bucket non-empty `data` into `[left_edge, right_edge, count]` rows.

            Reads the free variables `data` and `bucket_count` from the
            enclosing scope, which lies outside this excerpt. NOTE(review):
            assumes `data` is already flattened and cast to float64, as in the
            sibling `_buckets` implementations; confirm in the caller.
            """
            min_ = tf.reduce_min(input_tensor=data)
            max_ = tf.reduce_max(input_tensor=data)
            range_ = max_ - min_
            is_singular = tf.equal(range_, 0)

            def when_nonsingular():
                """Split [min_, max_] into `bucket_count` equal-width buckets."""
                bucket_width = range_ / tf.cast(bucket_count, tf.float64)
                offsets = data - min_
                bucket_indices = tf.cast(tf.floor(offsets / bucket_width),
                                         dtype=tf.int32)
                # Values at max_ can produce index `bucket_count`; fold them
                # into the last bucket so its right edge is inclusive.
                clamped_indices = tf.minimum(bucket_indices, bucket_count - 1)
                # Use float64 instead of the float32 default to avoid
                # accumulating floating point error in tf.reduce_sum when
                # summing more than 2^24 individual `1.0` values. See
                # https://github.com/tensorflow/tensorflow/issues/51419.
                one_hots = tf.one_hot(clamped_indices,
                                      depth=bucket_count,
                                      dtype=tf.float64)
                # Already float64 because of the one_hot dtype above.
                bucket_counts = tf.reduce_sum(input_tensor=one_hots, axis=0)
                edges = tf.linspace(min_, max_, bucket_count + 1)
                # Ensure edges[-1] == max_, which TF's linspace implementation does not
                # do, leaving it subject to the whim of floating point rounding error.
                edges = tf.concat([edges[:-1], [max_]], 0)
                left_edges = edges[:-1]
                right_edges = edges[1:]
                return tf.transpose(
                    a=tf.stack([left_edges, right_edges, bucket_counts]))

            def when_singular():
                """All values equal: one bucket [v - 0.5, v + 0.5] holding all of them."""
                center = min_
                bucket_starts = tf.stack([center - 0.5])
                bucket_ends = tf.stack([center + 0.5])
                bucket_counts = tf.stack(
                    [tf.cast(tf.size(input=data), tf.float64)])
                return tf.transpose(
                    a=tf.stack([bucket_starts, bucket_ends, bucket_counts]))

            return tf.cond(is_singular, when_singular, when_nonsingular)
Example #5
0
        def when_nonempty():
            """Bucket non-empty `data` into `bucket_count` float64 rows.

            Reads the free variables `data`, `bucket_count`, and `data_size`
            from the enclosing scope, which lies outside this excerpt. Each
            returned row is `[left_edge, right_edge, count]`.
            """
            lo = tf.reduce_min(input_tensor=data)
            hi = tf.reduce_max(input_tensor=data)
            span = hi - lo
            is_constant = tf.equal(span, 0)

            def multi_valued():
                """Data holds several distinct values: equal-width buckets over [lo, hi]."""
                width = span / tf.cast(bucket_count, tf.float64)
                index = tf.cast(tf.floor((data - lo) / width), dtype=tf.int32)
                # Indices can reach `bucket_count` (values at hi); clamp them
                # into the last bucket.
                index = tf.minimum(index, bucket_count - 1)
                # float64 one-hots avoid accumulating floating point error in
                # tf.reduce_sum when summing more than 2^24 `1.0` values; see
                # https://github.com/tensorflow/tensorflow/issues/51419.
                indicator = tf.one_hot(index,
                                       depth=bucket_count,
                                       dtype=tf.float64)
                counts = tf.cast(
                    tf.reduce_sum(input_tensor=indicator, axis=0),
                    dtype=tf.float64,
                )
                boundaries = tf.linspace(lo, hi, bucket_count + 1)
                # Pin the last boundary to exactly hi; tf.linspace alone is
                # subject to floating point rounding there.
                boundaries = tf.concat([boundaries[:-1], [hi]], 0)
                return tf.transpose(
                    a=tf.stack([boundaries[:-1], boundaries[1:], counts]))

            def single_valued():
                """Data holds one distinct value: zero-width buckets at hi.

                Every count is 0 except the last bucket's, which is
                `data_size`.
                """
                edges = tf.fill([bucket_count], hi)
                # The trailing [:bucket_count] slice keeps the vector length
                # equal to bucket_count, including when bucket_count == 0.
                padded = tf.concat(
                    [tf.fill([bucket_count], 0)[:-1], [data_size]], 0)
                counts = tf.cast(padded[:bucket_count], dtype=tf.float64)
                return tf.transpose(a=tf.stack([edges, edges, counts]))

            return tf.cond(is_constant, single_valued, multi_valued)
Example #6
0
 def lazy_tensor():
     """Encode up to `max_outputs` audio clips as WAV strings.

     Reads the free variables `data`, `max_outputs`, and `sample_rate` from
     the enclosing scope, which lies outside this excerpt. Returns the
     transpose of the stacked `[encoded_wavs, labels]` pair, i.e. one
     `[wav, '']` row per clip.
     """
     tf.debugging.assert_rank(data, 3)
     tf.debugging.assert_non_negative(max_outputs)
     clips = data[:max_outputs]

     def encode_clip(clip):
         return audio_ops.encode_wav(clip, sample_rate=sample_rate)

     wavs = tf.map_fn(encode_clip,
                      clips,
                      dtype=tf.string,
                      name='encode_each_audio')
     # map_fn produces a float dtype for an empty elems input; substitute an
     # explicitly string-typed empty tensor in that case.
     has_clips = tf.shape(input=wavs)[0] > 0
     wavs = tf.cond(has_clips, lambda: wavs, lambda: tf.constant([], tf.string))
     labels = tf.tile([''], tf.shape(input=clips)[:1])
     return tf.transpose(a=tf.stack([wavs, labels]))
Example #7
0
        def when_nonempty():
            """Bucket non-empty `data` into `[left_edge, right_edge, count]` rows.

            Reads the free variables `data` and `bucket_count` from the
            enclosing scope, which lies outside this excerpt.
            """
            lo = tf.reduce_min(input_tensor=data)
            hi = tf.reduce_max(input_tensor=data)
            span = hi - lo
            is_constant = tf.equal(span, 0)

            def multi_valued():
                """Data holds several distinct values: equal-width buckets over [lo, hi]."""
                width = span / tf.cast(bucket_count, tf.float64)
                index = tf.cast(tf.floor((data - lo) / width), dtype=tf.int32)
                # Indices can reach `bucket_count` (values at hi); clamp them
                # into the last bucket.
                index = tf.minimum(index, bucket_count - 1)
                # float64 one-hots avoid accumulating floating point error in
                # tf.reduce_sum when summing more than 2^24 `1.0` values; see
                # https://github.com/tensorflow/tensorflow/issues/51419.
                indicator = tf.one_hot(index,
                                       depth=bucket_count,
                                       dtype=tf.float64)
                counts = tf.cast(
                    tf.reduce_sum(input_tensor=indicator, axis=0),
                    dtype=tf.float64,
                )
                boundaries = tf.linspace(lo, hi, bucket_count + 1)
                # Pin the last boundary to exactly hi; tf.linspace alone is
                # subject to floating point rounding there.
                boundaries = tf.concat([boundaries[:-1], [hi]], 0)
                return tf.transpose(
                    a=tf.stack([boundaries[:-1], boundaries[1:], counts]))

            def single_valued():
                """Data holds one distinct value: a single unit-width bucket around it."""
                starts = tf.stack([lo - 0.5])
                ends = tf.stack([lo + 0.5])
                counts = tf.stack([tf.cast(tf.size(input=data), tf.float64)])
                return tf.transpose(a=tf.stack([starts, ends, counts]))

            return tf.cond(is_constant, single_valued, multi_valued)
Example #8
0
        def lazy_tensor():
            """Encode up to `max_outputs` audio clips as WAV strings.

            Reads the free variables `data`, `max_outputs`, `sample_rate`, and
            `lengths` from the enclosing scope, which lies outside this
            excerpt. When `lengths` is not None, clip `i` is encoded as
            `clip[:lengths[i]]`. Returns the transpose of the stacked
            `[encoded_wavs, labels]` pair, i.e. one `[wav, ""]` row per clip.
            """
            tf.debugging.assert_rank(data, 3)
            tf.debugging.assert_non_negative(max_outputs)
            clips = data[:max_outputs]

            def encode_clip(clip):
                return audio_ops.encode_wav(clip, sample_rate=sample_rate)

            if lengths is None:
                wavs = tf.map_fn(
                    encode_clip,
                    clips,
                    dtype=tf.string,
                    name="encode_each_audio",
                )
            else:
                tf.debugging.assert_rank(lengths, 1)
                clip_lengths = lengths[:max_outputs]

                def encode_truncated(clip_and_length):
                    clip, length = clip_and_length
                    return encode_clip(clip[:length])

                wavs = tf.map_fn(
                    encode_truncated,
                    (clips, clip_lengths),
                    dtype=tf.string,
                    name="encode_each_audio",
                )
            # Workaround: map_fn returns a float dtype for an empty elems
            # input, so substitute a string-typed empty tensor.
            wavs = tf.cond(
                tf.shape(input=wavs)[0] > 0,
                lambda: wavs,
                lambda: tf.constant([], tf.string),
            )
            labels = tf.tile([""], tf.shape(input=clips)[:1])
            return tf.transpose(a=tf.stack([wavs, labels]))
Example #9
0
def _buckets(data, bucket_count=None):
    """Create a TensorFlow op to group data into histogram buckets.

    Arguments:
      data: A `Tensor` of any shape. Must be castable to `float64`.
      bucket_count: Optional non-negative `int` or scalar `int32` `Tensor`.
        When None, `DEFAULT_BUCKET_COUNT` (documented as 30) is used.
        Negative values are treated as zero.
    Returns:
      A `Tensor` of shape `[k, 3]` and type `float64`. The `i`th row is
      a triple `[left_edge, right_edge, count]` for a single bucket.
      `k` always equals `bucket_count` (after clamping negatives to zero):
        - if `bucket_count` is zero, the result has shape `[0, 3]`;
        - if `data` is empty, every row is all zeros;
        - if every element of `data` has the same value, every row's edges
          equal that value and only the last row has a nonzero count;
        - otherwise the rows are equal-width buckets spanning
          `[min(data), max(data)]`.
    """
    if bucket_count is None:
        bucket_count = DEFAULT_BUCKET_COUNT
    with tf.name_scope("buckets"):
        tf.debugging.assert_scalar(bucket_count)
        tf.debugging.assert_type(bucket_count, tf.int32)
        # Treat a negative bucket count as zero.
        bucket_count = tf.math.maximum(0, bucket_count)
        data = tf.reshape(data, shape=[-1])  # flatten
        data = tf.cast(data, tf.float64)
        data_size = tf.size(input=data)
        is_empty = tf.logical_or(tf.equal(data_size, 0),
                                 tf.less_equal(bucket_count, 0))

        def when_empty():
            """When input data is empty or bucket_count is zero.

            1. If bucket_count is specified as zero, an empty tensor of shape
              (0, 3) will be returned.
            2. If the input data is empty, a tensor of shape (bucket_count, 3)
              of all zero values will be returned.
            """
            return tf.zeros((bucket_count, 3), dtype=tf.float64)

        def when_nonempty():
            """Bucket the flattened, non-empty `data` into `bucket_count` rows."""
            min_ = tf.reduce_min(input_tensor=data)
            max_ = tf.reduce_max(input_tensor=data)
            range_ = max_ - min_
            has_single_value = tf.equal(range_, 0)

            def when_multiple_values():
                """When input data contains multiple values."""
                bucket_width = range_ / tf.cast(bucket_count, tf.float64)
                offsets = data - min_
                bucket_indices = tf.cast(tf.floor(offsets / bucket_width),
                                         dtype=tf.int32)
                # Values at max_ can produce index `bucket_count`; fold them
                # into the last bucket so its right edge is inclusive.
                clamped_indices = tf.minimum(bucket_indices, bucket_count - 1)
                # Use float64 instead of float32 to avoid accumulating floating point error
                # later in tf.reduce_sum when summing more than 2^24 individual `1.0` values.
                # See https://github.com/tensorflow/tensorflow/issues/51419 for details.
                one_hots = tf.one_hot(clamped_indices,
                                      depth=bucket_count,
                                      dtype=tf.float64)
                # Already float64 because of the one_hot dtype above.
                bucket_counts = tf.reduce_sum(input_tensor=one_hots, axis=0)
                edges = tf.linspace(min_, max_, bucket_count + 1)
                # Ensure edges[-1] == max_, which TF's linspace implementation does not
                # do, leaving it subject to the whim of floating point rounding error.
                edges = tf.concat([edges[:-1], [max_]], 0)
                left_edges = edges[:-1]
                right_edges = edges[1:]
                return tf.transpose(
                    a=tf.stack([left_edges, right_edges, bucket_counts]))

            def when_single_value():
                """When input data contains a single unique value."""
                # Left and right edges are the same for single value input.
                edges = tf.fill([bucket_count], max_)
                # Bucket counts are 0 except the last bucket (if bucket_count > 0),
                # which is `data_size`. Ensure that the resulting counts vector has
                # length `bucket_count` always, including the bucket_count==0 case.
                zeroes = tf.fill([bucket_count], 0)
                bucket_counts = tf.cast(
                    tf.concat([zeroes[:-1], [data_size]], 0)[:bucket_count],
                    dtype=tf.float64,
                )
                return tf.transpose(a=tf.stack([edges, edges, bucket_counts]))

            return tf.cond(has_single_value, when_single_value,
                           when_multiple_values)

        return tf.cond(is_empty, when_empty, when_nonempty)
Example #10
0
def audio(name,
          data,
          sample_rate,
          step=None,
          max_outputs=3,
          encoding=None,
          description=None):
    """Write an audio summary.

    Arguments:
      name: A name for this summary. The summary tag used for TensorBoard will
        be this name prefixed by any active name scopes.
      data: A `Tensor` representing audio data with shape `[k, t, c]`,
        where `k` is the number of audio clips, `t` is the number of
        frames, and `c` is the number of channels. Elements should be
        floating-point values in `[-1.0, 1.0]`. Any of the dimensions may
        be statically unknown (i.e., `None`).
      sample_rate: An `int` or rank-0 `int32` `Tensor` that represents the
        sample rate, in Hz. Must be positive.
      step: Explicit `int64`-castable monotonic step value for this summary.
        If omitted, this defaults to `tf.summary.experimental.get_step()`,
        which must not be None.
      max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this
        many audio clips will be emitted at each step. When more than
        `max_outputs` many clips are provided, the first `max_outputs`
        many clips will be used and the rest silently discarded.
      encoding: Optional constant `str` for the desired encoding. Only "wav"
        is currently supported, but this is not guaranteed to remain the
        default, so if you want "wav" in particular, set this explicitly.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.

    Returns:
      True on success, or false if no summary was emitted because no default
      summary writer was available.

    Raises:
      ValueError: if `encoding` is neither None nor "wav"; or if a default
        writer exists, but no step was provided and
        `tf.summary.experimental.get_step()` is None.
    """
    # TODO(nickfelt): get encode_wav() exported in the public API.
    from tensorflow.python.ops import gen_audio_ops

    if encoding is None:
        encoding = 'wav'
    if encoding != 'wav':
        raise ValueError('Unknown encoding: %r' % encoding)
    summary_metadata = metadata.create_summary_metadata(
        display_name=None,
        description=description,
        encoding=metadata.Encoding.Value('WAV'))
    # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback
    # Same lookup as `image`/`histogram`. It tolerates TF builds that expose
    # summary_scope only under tf.summary.experimental.
    summary_scope = (getattr(tf.summary.experimental, 'summary_scope', None)
                     or tf.summary.summary_scope)
    inputs = [data, sample_rate, max_outputs, step]
    with summary_scope(name, 'audio_summary', values=inputs) as (tag, _):
        tf.debugging.assert_rank(data, 3)
        tf.debugging.assert_non_negative(max_outputs)
        limited_audio = data[:max_outputs]
        encode_fn = functools.partial(gen_audio_ops.encode_wav,
                                      sample_rate=sample_rate)
        encoded_audio = tf.map_fn(encode_fn,
                                  limited_audio,
                                  dtype=tf.string,
                                  name='encode_each_audio')
        # Workaround for map_fn returning float dtype for an empty elems input.
        encoded_audio = tf.cond(
            tf.shape(input=encoded_audio)[0] > 0, lambda: encoded_audio,
            lambda: tf.constant([], tf.string))
        # One empty-string label per emitted clip; each row of the result
        # is [encoded_wav, label].
        limited_labels = tf.tile([''], tf.shape(input=limited_audio)[:1])
        tensor = tf.transpose(a=tf.stack([encoded_audio, limited_labels]))
        # With step=None, tf.summary.write falls back to
        # tf.summary.experimental.get_step().
        return tf.summary.write(tag=tag,
                                tensor=tensor,
                                step=step,
                                metadata=summary_metadata)
Example #11
0
def histogram(name, data, step=None, buckets=None, description=None):
    """Write a histogram summary.

    See also `tf.summary.scalar`, `tf.summary.SummaryWriter`.

    Writes a histogram to the current default summary writer, for later analysis
    in TensorBoard's 'Histograms' and 'Distributions' dashboards (data written
    using this API will appear in both places). Like `tf.summary.scalar` points,
    each histogram is associated with a `step` and a `name`. All the histograms
    with the same `name` constitute a time series of histograms.

    The histogram is calculated over all the elements of the given `Tensor`
    without regard to its shape or rank.

    This example writes 2 histograms:

    ```python
    w = tf.summary.create_file_writer('test/logs')
    with w.as_default():
        tf.summary.histogram("activations", tf.random.uniform([100, 50]), step=0)
        tf.summary.histogram("initial_weights", tf.random.normal([1000]), step=0)
    ```

    A common use case is to examine the changing activation patterns (or lack
    thereof) at specific layers in a neural network, over time.

    ```python
    w = tf.summary.create_file_writer('test/logs')
    with w.as_default():
        for step in range(100):
            # Generate fake "activations".
            activations = [
                tf.random.normal([1000], mean=step, stddev=1),
                tf.random.normal([1000], mean=step, stddev=10),
                tf.random.normal([1000], mean=step, stddev=100),
            ]

            tf.summary.histogram("layer1/activate", activations[0], step=step)
            tf.summary.histogram("layer2/activate", activations[1], step=step)
            tf.summary.histogram("layer3/activate", activations[2], step=step)
    ```

    Arguments:
      name: A name for this summary. The summary tag used for TensorBoard will
        be this name prefixed by any active name scopes.
      data: A `Tensor` of any shape. The histogram is computed over its elements,
        which must be castable to `float64`.
      step: Explicit `int64`-castable monotonic step value for this summary. If
        omitted, this defaults to `tf.summary.experimental.get_step()`, which must
        not be None.
      buckets: Optional positive `int`, passed unchanged to the `_buckets`
        helper, which computes the bucket edges and counts. The output
        normally has this many buckets. NOTE(review): the edge-case behavior
        for empty data, or for data whose elements all share one value, is
        decided by `_buckets`; confirm it against the implementation this
        module actually uses.
      description: Optional long-form description for this summary, as a
        constant `str`. Markdown is supported. Defaults to empty.

    Returns:
      True on success, or false if no summary was emitted because no default
      summary writer was available.

    Raises:
      ValueError: if a default writer exists, but no step was provided and
        `tf.summary.experimental.get_step()` is None.
    """
    # Avoid building unused gradient graphs for conds below. This works around
    # an error building second-order gradient graphs when XlaDynamicUpdateSlice
    # is used, and will generally speed up graph building slightly.
    data = tf.stop_gradient(data)
    summary_metadata = metadata.create_summary_metadata(
        display_name=None, description=description)
    # TODO(https://github.com/tensorflow/tensorboard/issues/2109): remove fallback
    summary_scope = (getattr(tf.summary.experimental, "summary_scope", None)
                     or tf.summary.summary_scope)

    # Try to capture current name scope so we can re-enter it below within our
    # histogram_summary helper. We do this to avoid having the `tf.cond` below
    # insert an extra `cond` into the tag name.
    # TODO(https://github.com/tensorflow/tensorboard/issues/2885): Remove this
    # special handling once the format no longer requires dynamic output shapes.
    name_scope_cms = []
    if hasattr(tf, "get_current_name_scope"):
        # Coerce None to ""; this API should only return a string but as of TF
        # 2.5 it returns None in contexts that re-enter the empty scope.
        current_scope = tf.get_current_name_scope() or ""
        # Append a "/" to the scope name, which causes that scope to be treated
        # as absolute instead of relative to the current scope, so that we can
        # re-enter it. It also prevents auto-incrementing of the scope name.
        # This is legacy graph mode behavior, undocumented except in comments:
        # https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/python/framework/ops.py#L6664-L6666
        scope_to_reenter = current_scope + "/" if current_scope else ""
        name_scope_cms.append(tf.name_scope(scope_to_reenter))

    # Writes the summary after re-entering the caller's captured name scope (if
    # any), so the tag is unaffected by the `tf.cond` created further below.
    # NOTE(review): `histogram_metadata` is unused; the closed-over
    # `summary_metadata` is written instead (both call sites pass that same
    # object, so behavior is unaffected).
    def histogram_summary(data, buckets, histogram_metadata, step):
        with contextlib.ExitStack() as stack:
            for cm in name_scope_cms:
                stack.enter_context(cm)
            with summary_scope(name,
                               "histogram_summary",
                               values=[data, buckets, step]) as (tag, _):
                # Defer histogram bucketing logic by passing it as a callable to
                # write(), wrapped in a LazyTensorCreator for backwards
                # compatibility, so that we only do this work when summaries are
                # actually written.
                @lazy_tensor_creator.LazyTensorCreator
                def lazy_tensor():
                    return _buckets(data, buckets)

                return tf.summary.write(
                    tag=tag,
                    tensor=lazy_tensor,
                    step=step,
                    metadata=summary_metadata,
                )

    # `_buckets()` has dynamic output shapes which is not supported on TPU's.
    # To address this, explicitly mark this logic for outside compilation so it
    # will be executed on the CPU, and skip it entirely if we aren't actually
    # recording summaries to avoid overhead of transferring data.
    # TODO(https://github.com/tensorflow/tensorboard/issues/2885): Remove this
    # special handling once the format no longer requires dynamic output shapes.
    if isinstance(
            tf.distribute.get_strategy(),
        (tf.distribute.experimental.TPUStrategy, tf.distribute.TPUStrategy),
    ):
        # When summaries are not being recorded, the false branch reports that
        # nothing was written by returning False.
        return tf.cond(
            tf.summary.should_record_summaries(),
            lambda: tf.compat.v1.tpu.outside_compilation(
                histogram_summary, data, buckets, summary_metadata, step),
            lambda: False,
        )
    return histogram_summary(data, buckets, summary_metadata, step)