예제 #1
0
        def when_nonempty():
            """Build histogram rows for `data` when it has at least one element.

            Returns:
              The result of a `tf.cond` whose rows are
              `[left_edge, right_edge, count]`. When `data` spans more than
              one value there are `bucket_count` equal-width rows; when every
              element is equal there is a single row of width 1 centered on
              that value, counting all elements.
            """
            # NOTE(review): assumes `data` is a flattened float64 tensor (the
            # float64 division and the axis-0 reduction below rely on it) --
            # confirm against the enclosing function.
            min_ = tf.reduce_min(input_tensor=data)
            max_ = tf.reduce_max(input_tensor=data)
            range_ = max_ - min_
            is_singular = tf.equal(range_, 0)

            def when_nonsingular():
                """Bucket `data` into `bucket_count` equal-width buckets."""
                bucket_width = range_ / tf.cast(bucket_count, tf.float64)
                offsets = data - min_
                bucket_indices = tf.cast(tf.floor(offsets / bucket_width),
                                         dtype=tf.int32)
                # The maximum value maps to index `bucket_count`; clamp it so
                # that it is counted in the last bucket.
                clamped_indices = tf.minimum(bucket_indices, bucket_count - 1)
                # Use float64 instead of the float32 default to avoid accumulating
                # floating point error in tf.reduce_sum when summing more than 2^24
                # individual `1.0` values.
                # See https://github.com/tensorflow/tensorflow/issues/51419 for details.
                one_hots = tf.one_hot(clamped_indices,
                                      depth=bucket_count,
                                      dtype=tf.float64)
                bucket_counts = tf.cast(tf.reduce_sum(input_tensor=one_hots,
                                                      axis=0),
                                        dtype=tf.float64)
                edges = tf.linspace(min_, max_, bucket_count + 1)
                # Ensure edges[-1] == max_, which TF's linspace implementation does not
                # do, leaving it subject to the whim of floating point rounding error.
                edges = tf.concat([edges[:-1], [max_]], 0)
                left_edges = edges[:-1]
                right_edges = edges[1:]
                return tf.transpose(
                    a=tf.stack([left_edges, right_edges, bucket_counts]))

            def when_singular():
                """Return one bucket of width 1 around the only distinct value."""
                center = min_
                bucket_starts = tf.stack([center - 0.5])
                bucket_ends = tf.stack([center + 0.5])
                # Every element of `data` falls into this single bucket.
                bucket_counts = tf.stack(
                    [tf.cast(tf.size(input=data), tf.float64)])
                return tf.transpose(
                    a=tf.stack([bucket_starts, bucket_ends, bucket_counts]))

            return tf.cond(is_singular, when_singular, when_nonsingular)
예제 #2
0
        def when_nonempty():
            """Compute histogram rows for `data`, which has at least one element.

            Returns:
              The result of a `tf.cond` that selects between the two builders
              below, depending on whether the minimum and maximum of `data`
              coincide. Either way the rows are `[left_edge, right_edge, count]`
              and there are `bucket_count` of them.
            """
            # NOTE(review): assumes `data` is a flattened float64 tensor (the
            # float64 division and the axis-0 reduction below rely on it) --
            # confirm against the enclosing function.
            lowest = tf.reduce_min(input_tensor=data)
            highest = tf.reduce_max(input_tensor=data)
            span = highest - lowest
            all_values_equal = tf.equal(span, 0)

            def when_multiple_values():
                """Handle input whose minimum and maximum differ."""
                width = span / tf.cast(bucket_count, tf.float64)
                raw_indices = tf.cast(
                    tf.floor((data - lowest) / width), dtype=tf.int32
                )
                # The maximum value maps to index `bucket_count`; fold it into
                # the last bucket.
                indices = tf.minimum(raw_indices, bucket_count - 1)
                # float64 (rather than the float32 default) keeps tf.reduce_sum
                # exact once a bucket collects more than 2^24 entries of `1.0`.
                # See https://github.com/tensorflow/tensorflow/issues/51419 for details.
                membership = tf.one_hot(
                    indices, depth=bucket_count, dtype=tf.float64
                )
                per_bucket = tf.reduce_sum(input_tensor=membership, axis=0)
                counts = tf.cast(per_bucket, dtype=tf.float64)
                boundaries = tf.linspace(lowest, highest, bucket_count + 1)
                # TF's linspace does not guarantee that its final element equals
                # `highest` exactly, so overwrite it to rule out rounding error.
                boundaries = tf.concat([boundaries[:-1], [highest]], 0)
                columns = tf.stack(
                    [boundaries[:-1], boundaries[1:], counts]
                )
                return tf.transpose(a=columns)

            def when_single_value():
                """Handle input in which every element has the same value."""
                # Each bucket starts and ends at that one value.
                repeated_edge = tf.fill([bucket_count], highest)
                # All counts are zero apart from the last bucket, which gets
                # `data_size` (taken from the enclosing scope). Slicing to
                # `bucket_count` afterwards keeps the vector at exactly that
                # length, including when bucket_count == 0.
                zero_counts = tf.fill([bucket_count], 0)
                padded = tf.concat([zero_counts[:-1], [data_size]], 0)
                counts = tf.cast(padded[:bucket_count], dtype=tf.float64)
                columns = tf.stack([repeated_edge, repeated_edge, counts])
                return tf.transpose(a=columns)

            return tf.cond(
                all_values_equal, when_single_value, when_multiple_values
            )
예제 #3
0
        def when_nonempty():
            """Produce histogram rows for `data` when it is not empty.

            Returns:
              The result of a `tf.cond` whose rows are
              `[left_edge, right_edge, count]`: `bucket_count` equal-width rows
              when `data` spans more than one value, otherwise a single row of
              width 1 centered on the only value.
            """
            # NOTE(review): assumes `data` is a flattened float64 tensor (the
            # float64 division and the axis-0 reduction below rely on it) --
            # confirm against the enclosing function.
            data_min = tf.reduce_min(input_tensor=data)
            data_max = tf.reduce_max(input_tensor=data)
            extent = data_max - data_min
            zero_extent = tf.equal(extent, 0)

            def when_nonsingular():
                """Spread `data` over `bucket_count` equal-width buckets."""
                step = extent / tf.cast(bucket_count, tf.float64)
                shifted = data - data_min
                positions = tf.cast(tf.floor(shifted / step), dtype=tf.int32)
                # The maximum value maps to index `bucket_count`; cap it so it
                # is counted in the last bucket.
                positions = tf.minimum(positions, bucket_count - 1)
                # Build the one-hots as float64, not float32: summing more than
                # 2^24 individual `1.0` values in float32 accumulates error.
                # See https://github.com/tensorflow/tensorflow/issues/51419 for details.
                indicator = tf.one_hot(
                    positions,
                    depth=bucket_count,
                    dtype=tf.float64,
                )
                tallies = tf.cast(
                    tf.reduce_sum(input_tensor=indicator, axis=0),
                    dtype=tf.float64,
                )
                limits = tf.linspace(data_min, data_max, bucket_count + 1)
                # Pin the last limit to `data_max`; TF's linspace leaves it
                # exposed to floating point rounding error.
                limits = tf.concat([limits[:-1], [data_max]], 0)
                lower = limits[:-1]
                upper = limits[1:]
                rows = tf.stack([lower, upper, tallies])
                return tf.transpose(a=rows)

            def when_singular():
                """Emit one bucket of width 1 centered on the only value."""
                half_width = 0.5
                lower = tf.stack([data_min - half_width])
                upper = tf.stack([data_min + half_width])
                # Every element of `data` belongs to this bucket.
                element_count = tf.cast(tf.size(input=data), tf.float64)
                tallies = tf.stack([element_count])
                rows = tf.stack([lower, upper, tallies])
                return tf.transpose(a=rows)

            return tf.cond(zero_extent, when_singular, when_nonsingular)