def when_nonempty():
    """Build histogram buckets for a non-empty `data` tensor.

    Reads `data` and `bucket_count` from the enclosing scope. `data` is
    presumably a float64 tensor; the float64 division and concat below
    require it. TODO confirm against the caller.

    Returns:
      A float64 tensor whose rows are `[left_edge, right_edge, count]`:
      * if every element of `data` is equal: one row, a width-1.0 bucket
        centred on that value that holds all `tf.size(data)` elements;
      * otherwise: `bucket_count` equal-width buckets spanning
        `[min(data), max(data)]`.
    """
    min_ = tf.reduce_min(input_tensor=data)
    max_ = tf.reduce_max(input_tensor=data)
    range_ = max_ - min_
    is_singular = tf.equal(range_, 0)

    def when_nonsingular():
        """Spread `data` over `bucket_count` equal-width buckets."""
        bucket_width = range_ / tf.cast(bucket_count, tf.float64)
        offsets = data - min_
        bucket_indices = tf.cast(
            tf.floor(offsets / bucket_width), dtype=tf.int32
        )
        # The maximum element can land on index `bucket_count` (one past the
        # last bucket); fold it into the last bucket so the right edge of the
        # histogram is inclusive.
        clamped_indices = tf.minimum(bucket_indices, bucket_count - 1)
        # Use float64 instead of the default float32. A float32 running sum of
        # `1.0` values stops increasing after 2^24, so large inputs would be
        # under-counted, and casting the sum afterwards cannot recover that.
        # See https://github.com/tensorflow/tensorflow/issues/51419.
        one_hots = tf.one_hot(
            clamped_indices, depth=bucket_count, dtype=tf.float64
        )
        bucket_counts = tf.reduce_sum(input_tensor=one_hots, axis=0)
        edges = tf.linspace(min_, max_, bucket_count + 1)
        # Ensure edges[-1] == max_, which TF's linspace implementation does not
        # do, leaving it subject to the whim of floating point rounding error.
        edges = tf.concat([edges[:-1], [max_]], 0)
        left_edges = edges[:-1]
        right_edges = edges[1:]
        # stack -> (3, bucket_count); transpose -> (bucket_count, 3) rows.
        return tf.transpose(
            a=tf.stack([left_edges, right_edges, bucket_counts]))

    def when_singular():
        """Return a single width-1.0 bucket centred on the only value."""
        center = min_
        bucket_starts = tf.stack([center - 0.5])
        bucket_ends = tf.stack([center + 0.5])
        bucket_counts = tf.stack(
            [tf.cast(tf.size(input=data), tf.float64)])
        # stack -> (3, 1); transpose -> a single (1, 3) row.
        return tf.transpose(
            a=tf.stack([bucket_starts, bucket_ends, bucket_counts]))

    return tf.cond(is_singular, when_singular, when_nonsingular)
def when_nonempty():
    """Compute `[left_edge, right_edge, count]` histogram rows for `data`.

    Reads `data`, `bucket_count` and `data_size` from the enclosing scope.
    `data` is presumably float64 (the float64 division below requires it).
    `data_size` is presumably an int32 scalar (it is concatenated with int32
    zeros). Both are TODOs to confirm against the caller.

    Both branches yield a float64 tensor with `bucket_count` rows:
    * distinct values present: equal-width buckets over `[min, max]`;
    * a single unique value: zero-width buckets at that value, with every
      element counted in the last bucket.
    """
    lo = tf.reduce_min(input_tensor=data)
    hi = tf.reduce_max(input_tensor=data)
    spread = hi - lo

    def _spread_buckets():
        """Bucket `data` when it holds at least two distinct values."""
        width = spread / tf.cast(bucket_count, tf.float64)
        raw_index = tf.cast(tf.floor((data - lo) / width), dtype=tf.int32)
        # The maximum value may index one past the end; keep it in the last bucket.
        index = tf.minimum(raw_index, bucket_count - 1)
        # Accumulate in float64: a float32 sum of more than 2^24 `1.0` values
        # loses precision. See https://github.com/tensorflow/tensorflow/issues/51419.
        counts = tf.reduce_sum(
            input_tensor=tf.one_hot(index, depth=bucket_count, dtype=tf.float64),
            axis=0,
        )
        # tf.linspace does not guarantee that its final point equals `hi`
        # exactly, so replace that point with `hi` itself.
        boundaries = tf.concat(
            [tf.linspace(lo, hi, bucket_count + 1)[:-1], [hi]], 0
        )
        rows = tf.stack([boundaries[:-1], boundaries[1:], counts])
        return tf.transpose(a=rows)

    def _collapsed_buckets():
        """Bucket `data` when every element has the same value."""
        # Each bucket degenerates to the zero-width interval [hi, hi].
        bounds = tf.fill([bucket_count], hi)
        # All elements go into the final bucket. The trailing slice trims the
        # vector back to `bucket_count` entries, which matters when
        # bucket_count == 0 and `[data_size]` would otherwise add one.
        padding = tf.fill([bucket_count], 0)
        tally = tf.concat([padding[:-1], [data_size]], 0)[:bucket_count]
        tally = tf.cast(tally, dtype=tf.float64)
        return tf.transpose(a=tf.stack([bounds, bounds, tally]))

    return tf.cond(tf.equal(spread, 0), _collapsed_buckets, _spread_buckets)
def when_nonempty():
    """Compute `[left_edge, right_edge, count]` histogram rows for `data`.

    Reads `data` and `bucket_count` from the enclosing scope. `data` is
    presumably float64 (the float64 division below requires it). TODO
    confirm against the caller.

    Returns a float64 tensor:
    * distinct values present: `bucket_count` equal-width buckets over
      `[min, max]`;
    * a single unique value: one row, a width-1.0 bucket centred on that
      value that holds all `tf.size(data)` elements.
    """
    lo = tf.reduce_min(input_tensor=data)
    hi = tf.reduce_max(input_tensor=data)
    span = hi - lo

    def _single_bucket():
        """One bucket of width 1.0 around the lone value."""
        population = tf.cast(tf.size(input=data), tf.float64)
        row = tf.stack([lo - 0.5, lo + 0.5, population])
        # Wrap the row so the result is a 2-D (1, 3) tensor of rows.
        return tf.stack([row])

    def _uniform_buckets():
        """Equal-width buckets when `data` holds distinct values."""
        width = span / tf.cast(bucket_count, tf.float64)
        raw_index = tf.cast(tf.floor((data - lo) / width), dtype=tf.int32)
        # The maximum value may index one past the end; keep it in the last bucket.
        index = tf.minimum(raw_index, bucket_count - 1)
        # Accumulate in float64: a float32 sum of more than 2^24 `1.0` values
        # loses precision. See https://github.com/tensorflow/tensorflow/issues/51419.
        counts = tf.reduce_sum(
            input_tensor=tf.one_hot(index, depth=bucket_count, dtype=tf.float64),
            axis=0,
        )
        # tf.linspace does not guarantee that its final point equals `hi`
        # exactly, so replace that point with `hi` itself.
        boundaries = tf.concat(
            [tf.linspace(lo, hi, bucket_count + 1)[:-1], [hi]], 0
        )
        rows = tf.stack([boundaries[:-1], boundaries[1:], counts])
        return tf.transpose(a=rows)

    return tf.cond(tf.equal(span, 0), _single_bucket, _uniform_buckets)