def _quantize(x, params, randomize=True):
    """Quantize x according to params, optionally randomizing the rounding.

    When ``params.quantize`` is falsy, ``x`` is returned unchanged. Otherwise
    ``x / params.quantization_scale`` is converted to int16 codes, and the raw
    bits are reinterpreted as float16. The result is only a 16-bit container
    for the int16 codes, not a meaningful float16 value. ``_dequantize``
    reverses the encoding, up to rounding.

    Args:
      x: Float tensor to quantize.
      params: Object exposing ``quantize`` (truthy flag) and
        ``quantization_scale`` (divisor applied before conversion).
      randomize: If true, use stochastic rounding of the magnitude
        (``floor(|y| + u)``, with ``u ~ U[0, 1)``) and restore the sign
        afterwards. If false, truncate toward zero via ``tf.cast``.

    Returns:
      ``x`` itself, or a float16 tensor holding the int16 bit patterns.
    """
    if not params.quantize:
        return x
    limit = tf.int16.max
    if not randomize:
        # Saturate to the same symmetric range the randomized path produces.
        # Casting an out-of-range float to int16 is not well defined and can
        # wrap around, flipping the sign of large values.
        y = tf.clip_by_value(x / params.quantization_scale, -limit, limit)
        return tf.bitcast(tf.cast(y, tf.int16), tf.float16)
    abs_x = tf.abs(x)
    sign_x = tf.sign(x)
    y = abs_x / params.quantization_scale
    # floor(y + u) with u ~ U[0, 1) rounds up with probability equal to the
    # fractional part of y, i.e. unbiased stochastic rounding.
    y = tf.floor(y + tf.random_uniform(common_layers.shape_list(x)))
    # Clamp the magnitude before restoring the sign so the int16 cast cannot
    # overflow.
    y = tf.minimum(y, limit) * sign_x
    return tf.bitcast(tf.cast(y, tf.int16), tf.float16)
def create_topk_unique(inputs, k):
    """Select the k largest values of each row together with their indices.

    This is the public entry point for ``_create_topk_unique``, which is
    defined in this file and performs the same computation. Delegating keeps a
    single implementation. That matters because the low-bit index mask it
    uses must agree with the one written by ``_create_make_unique``.

    Column indices are read back from the low-order bits of each selected
    value's float32 bit pattern. ``inputs`` is therefore expected to have gone
    through ``create_make_unique`` / ``_create_make_unique`` first.

    Args:
      inputs: Rank-2 float32 tensor, ``[batch_size, original_size]``.
      k: Python int, number of elements to select per row.

    Returns:
      A pair ``(topk_r2, topk_indices_r2)``, both ``[batch_size, k]``:
      the selected float32 values in descending order, and their int32
      indices.
    """
    return _create_topk_unique(inputs, k)
def parse_image(value, image_key="image/encoded", label_key="image/class/label",
                label_bias=0, decode=True):
    """Parse one serialized tf.Example into image bytes, label and hash.

    Args:
      value: Serialized ``tf.Example`` accepted by ``tf.parse_single_example``.
      image_key: Feature key holding the encoded image bytes.
      label_key: Feature key holding the int64 class label (default -1).
      label_bias: Integer added to the label after casting it to int32.
      decode: If true, also decode the bytes into a 3-channel image tensor
        stored under ``'image'``.

    Returns:
      A dict with ``'image_bytes'`` (scalar string), ``'label'`` (int32
      scalar), ``'hash'`` (int64 scalar farmhash64 fingerprint of the bytes),
      plus ``'image'`` when ``decode`` is true.
    """
    feature_spec = {
        image_key: tf.FixedLenFeature((), tf.string, ''),
        label_key: tf.FixedLenFeature([], tf.int64, -1),
    }
    example = tf.parse_single_example(value, feature_spec)
    encoded = tf.reshape(example[image_key], shape=[])
    decoded = tf.io.decode_image(encoded, 3) if decode else None

    # For imagenet records, set label_bias = -1 so that labels are in [0, 1000).
    raw_label = tf.reshape(example[label_key], shape=[])
    label = tf.cast(raw_label, dtype=tf.int32) + label_bias

    # Fingerprint yields a [1, 8] uint8 tensor. Bitcasting to int64 packs the
    # 8 bytes into shape [1], and [0] takes the scalar.
    digest = tf.raw_ops.Fingerprint(data=[encoded], method="farmhash64")
    digest = tf.bitcast(digest, tf.int64)[0]

    features = {'image_bytes': encoded, 'label': label, 'hash': digest}
    if decode:
        features['image'] = decoded
    return features
def create_make_unique(inputs):
    """Overwrite the low-order bits of each element with its column index.

    This is the public entry point for ``_create_make_unique``, which is
    defined in this file and performs the same computation. Delegating keeps a
    single copy of the bit manipulation. The number of low bits it replaces,
    ``ceil(log2(width))``, must match the mask ``_create_topk_unique`` uses to
    read the indices back.

    The embedded index makes every element in a row distinct, which breaks
    ties, and lets the top-k routine recover positions from the values alone.

    Args:
      inputs: Rank-2 float32 tensor, ``[batch_size, original_size]``.

    Returns:
      float32 tensor of the same shape with the index embedded in each value.

    Raises:
      ValueError: If ``inputs`` is not rank 2.
    """
    return _create_make_unique(inputs)
def body(bit_index, value):
    """Run one step of the bitwise binary search over a float32 threshold.

    Tentatively sets bit ``bit_index`` in the int32 pattern ``value``. It then
    counts the scores strictly greater than the float32 that pattern encodes.
    The new bit is kept when ``count >= k`` differs from ``kth_negative``
    (XOR); otherwise ``value`` is kept unchanged.

    NOTE(review): ``scores``, ``k``, ``kth_negative`` and ``larger_count`` are
    free names here and are not bound at module level in the visible source.
    An identical nested ``body`` inside ``_topk_mask`` binds them. TODO
    confirm whether this top-level copy is reachable at all.

    Returns:
      ``(bit_index - 1, updated_value)`` for the next loop iteration.
    """
    candidate = tf.bitwise.bitwise_or(value, tf.bitwise.left_shift(1, bit_index))
    above = larger_count(scores, tf.bitcast(candidate, tf.float32))
    keep_bit = tf.logical_xor(above >= k, kth_negative)
    chosen = tf.where(keep_bit, candidate, value)
    return bit_index - 1, chosen
def triangles_to_edges(faces):
    """Compute the deduplicated, bidirectional edge list of a triangle mesh.

    Args:
      faces: ``[num_faces, 3]`` tensor of vertex indices. NOTE(review): each
        (max, min) pair is bitcast into one int64 and back to int32 below, so
        this presumably requires int32 indices. Confirm with callers.

    Returns:
      ``(senders, receivers)``: two 1-D tensors of equal length in which every
      undirected edge appears once in each direction.
    """
    # Each triangle (a, b, c) contributes the edges (a, b), (b, c) and (c, a).
    edge_ab = faces[:, 0:2]
    edge_bc = faces[:, 1:3]
    edge_ca = tf.stack([faces[:, 2], faces[:, 0]], axis=1)
    all_edges = tf.concat([edge_ab, edge_bc, edge_ca], axis=0)

    # An edge may be duplicated (shared inside the mesh) or single (on the
    # boundary). Canonicalize each edge as (max, min) and pack the pair into a
    # single int64 so tf.unique can drop the duplicates.
    low = tf.reduce_min(all_edges, axis=1)
    high = tf.reduce_max(all_edges, axis=1)
    packed = tf.bitcast(tf.stack([high, low], axis=1), tf.int64)
    deduped = tf.bitcast(tf.unique(packed)[0], tf.int32)
    first, second = tf.unstack(deduped, axis=1)

    # Emit every edge in both directions.
    senders = tf.concat([first, second], axis=0)
    receivers = tf.concat([second, first], axis=0)
    return senders, receivers
def from_characters(raw, lookup_):
    """Convert ascii+2 encoded character codes back into string tokens.

    Args:
      raw: Integer tensor of character codes offset by +2. Presumably rank 2,
        ``[batch, length]``, and wider than 8 bits (int32/int64); TODO confirm.
      lookup_: Tensor indexed by byte value via ``tf.gather``. Presumably a
        table of single-character strings; verify against the caller.

    Returns:
      The output of ``tf.string_split`` on each row's joined characters,
      split on spaces.
    """
    # Undo the +2 offset and clamp into the byte range [0, 255].
    byte_codes = tf.clip_by_value(tf.subtract(raw, 2), 0, 255)
    # NOTE(review): bitcasting a wider integer dtype to uint8 appends a
    # trailing axis of per-byte values. The [:, :, 0] after the gather keeps
    # the first byte, which equals the clamped code only on little-endian
    # hosts. A plain tf.cast would avoid this dependency.
    as_bytes = tf.bitcast(byte_codes, tf.uint8)
    chars = tf.gather(lookup_, tf.cast(as_bytes, tf.int32))[:, :, 0]
    # Concatenate each row's characters, strip NUL characters, then split.
    sentences = tf.regex_replace(tf.reduce_join(chars, axis=1), b"\0", b"")
    return tf.string_split(sentences, " ")
def _create_topk_unique(inputs, k):
    """Select the k largest entries of each row, in descending order.

    Indices are not tracked separately. They are recovered from the low-order
    bits of each selected value's float32 bit pattern. This yields real column
    positions only if those bits were overwritten with the index beforehand,
    as ``_create_make_unique`` does, so ``inputs`` is expected to have gone
    through that transform.

    Args:
      inputs: Rank-2 float32 tensor, ``[batch_size, original_size]``.
      k: Python int, number of elements to select per row.

    Returns:
      A pair ``(topk_r2, topk_indices_r2)``, both ``[batch_size, k]``:
      the selected float32 values and their int32 indices.
    """
    height = inputs.shape[0]
    width = inputs.shape[1]
    neg_inf = tf.ones([height, width], dtype=tf.float32) * tf.constant(
        -np.inf, dtype=tf.float32)
    # Map NaNs to -inf so they rank below every finite value.
    inputs = tf.where(tf.is_nan(inputs), neg_inf, inputs)

    # On each pass, take the row maximum and write it into slot `slot`. Then
    # hide every element >= that maximum so the next pass finds the next one.
    remaining = inputs
    topk_r2 = tf.zeros([height, k], dtype=tf.float32)
    for slot in range(k):
        row_max = tf.reduce_max(remaining, axis=1, keepdims=True)
        in_slot = tf.tile(
            tf.expand_dims(tf.equal(tf.range(k), tf.fill([k], slot)), 0),
            [height, 1])
        topk_r2 = tf.where(in_slot, tf.tile(row_max, [1, k]), topk_r2)
        at_or_above = tf.greater_equal(inputs, tf.tile(row_max, [1, width]))
        remaining = tf.where(at_or_above, neg_inf, inputs)

    # The low ceil(log2(width)) bits hold the column index. This must match
    # the bit count used by _create_make_unique.
    index_bits = int(math.ceil(math.log(float(int(width)), 2)))
    index_mask = (1 << index_bits) - 1
    mask_r2 = tf.fill([height, k], tf.constant(index_mask))
    topk_bits = tf.bitcast(topk_r2, tf.int32)
    topk_indices_r2 = tf.bitwise.bitwise_and(topk_bits, mask_r2)
    return topk_r2, topk_indices_r2
def _topk_mask(scores, k):
    """Efficient implementation of topk_mask for TPUs.

    Finds, for each leading-axis row, a float32 threshold by binary search
    over its int32 bit pattern. The search is shared across all tensors in
    ``scores``, and each tensor is then masked against that threshold.

    Args:
      scores: Iterable of float tensors that share the same leading (batch)
        dimension. Counts are summed across all of them per row.
      k: Number of elements to keep per row, counted across all tensors.

    Returns:
      A list with one float32 tensor per entry of ``scores``, shaped like it:
      1.0 where ``score >= threshold`` and 0.0 elsewhere. Because ``>=`` is
      used, ties at the threshold are all kept, so more than ``k`` ones are
      possible.
    """

    def larger_count(data, limit):
        """Count elements strictly greater than limit, per leading-axis row.

        ``limit`` is reshaped to broadcast along the leading axis. Each tensor
        is reduced over all its non-leading axes, and the per-tensor counts
        are summed with ``tf.add_n``.
        """
        ret = []
        for d in data:
            ret.append(
                tf.reduce_sum(tf.cast(
                    d > tf.reshape(limit, [-1] + [1] * (d.shape.ndims - 1)),
                    tf.int32),
                              axis=range(1, d.shape.ndims)))
        return tf.add_n(ret)

    def body(bit_index, value):
        """Body for the while loop executing the binary search.

        Tentatively sets bit ``bit_index`` of the int32 pattern ``value``. The
        bit is kept when ``larger >= k`` differs from ``kth_negative`` (XOR).
        ``kth_negative`` is bound later in the enclosing function and is
        captured by closure.
        """
        new_value = tf.bitwise.bitwise_or(value, tf.bitwise.left_shift(1, bit_index))
        larger = larger_count(scores, tf.bitcast(new_value, tf.float32))
        next_value = tf.where(tf.logical_xor(larger >= k, kth_negative),
                              new_value, value)
        return bit_index - 1, next_value

    # True where fewer than k elements are strictly positive, i.e. the k-th
    # largest value is <= 0. The search then runs on negative bit patterns,
    # where a larger magnitude means a smaller float, hence the XOR in body.
    kth_negative = (larger_count(scores, 0.0) < k)
    limit_sign = tf.where(kth_negative,
                          tf.broadcast_to(1, kth_negative.shape),
                          tf.broadcast_to(0, kth_negative.shape))
    # Start from a pattern with only the sign bit (bit 31) possibly set.
    next_value = tf.bitwise.left_shift(limit_sign, 31)
    # Decide bits 30 down to 0 one at a time (exponent and mantissa).
    _, limit = tf.while_loop(lambda bit_index, _: bit_index >= 0, body,
                             (30, next_value))
    ret = []
    for score in scores:
        # Filter scores that are smaller than the threshold.
        ret.append(
            tf.where(
                score >= tf.reshape(
                    tf.bitcast(limit, tf.float32), [-1] +
                    [1] * (score.shape.ndims - 1)),
                tf.ones(score.shape),
                tf.zeros(score.shape)))
    return ret
def _dequantize(q, params):
    """Turn quantized codes back into floats.

    Reverses the encoding produced by ``_quantize`` (up to its rounding). The
    float16 container is reinterpreted as int16 codes, converted to float32,
    and multiplied by ``params.quantization_scale``. If ``params.quantize`` is
    falsy, ``q`` is returned untouched.
    """
    if params.quantize:
        codes = tf.bitcast(q, tf.int16)
        return tf.cast(codes, tf.float32) * params.quantization_scale
    return q
def _create_make_unique(inputs):
    """Replaces the lower bits of each element with iota.

    The iota (the element's column position) serves two purposes. It can be
    recovered later by masking the low bits, and it makes every element of a
    row distinct, which breaks ties. The lowest ceil(log2(width)) bits of each
    float32's bit pattern are discarded and replaced, so values that differ
    only in those bits become ordered by column index.

    NOTE(review): for +/-inf inputs (all-ones exponent, zero mantissa),
    OR-ing in a nonzero index yields a NaN bit pattern.

    Args:
      inputs: A tensor with rank of 2 and dtype of tf.float32.
        [batch_size, original_size].

    Returns:
      A float32 tensor of the same shape with the index embedded in each
      element's low-order bits. [batch_size, original_size].

    Raises:
      ValueError: If the rank of the input tensor does not equal 2.
    """
    if inputs.shape.ndims != 2:
        raise ValueError("Input of top_k_with_unique must be rank-2 "
                         "but got: %s" % inputs.shape)

    height = inputs.shape[0]
    width = inputs.shape[1]
    zeros = tf.zeros([height, width], dtype=tf.int32)

    # Count_mask is used to mask away the low order bits to ensure that every
    # element is distinct. It clears ceil(log2(width)) low bits; the top-k
    # routine must use the same bit count to extract the indices.
    log2_ceiling = int(math.ceil(math.log(int(width), 2)))
    next_power_of_two = 1 << log2_ceiling
    count_mask = ~(next_power_of_two - 1)
    count_mask_r0 = tf.constant(count_mask)
    count_mask_r2 = tf.fill([height, width], count_mask_r0)

    # Smallest_normal is the bit representation of the smallest positive normal
    # floating point number. The sign is zero, exponent is one, and the fraction
    # is zero.
    smallest_normal = 1 << 23
    smallest_normal_r0 = tf.constant(smallest_normal, dtype=tf.int32)
    smallest_normal_r2 = tf.fill([height, width], smallest_normal_r0)

    # Low_bit_mask is used to mask away the sign bit when computing the absolute
    # value.
    low_bit_mask = ~(1 << 31)
    low_bit_mask_r0 = tf.constant(low_bit_mask, dtype=tf.int32)
    low_bit_mask_r2 = tf.fill([height, width], low_bit_mask_r0)

    # Column index 0..width-1, repeated for every row.
    iota = tf.tile(tf.expand_dims(tf.range(width, dtype=tf.int32), 0),
                   [height, 1])

    # Compare the absolute value with positive zero to handle negative zero.
    # Pseudocode: input_no_zeros = abs(input) == 0 ? FLT_MIN : input
    # (sign preserved). NOTE(review): presumably done because a zero with
    # index bits OR'd in would be a denormal, which hardware may flush to
    # zero and so lose the index. Confirm.
    input_r2 = tf.bitcast(inputs, tf.int32)
    abs_r2 = tf.bitwise.bitwise_and(input_r2, low_bit_mask_r2)
    if_zero_r2 = tf.equal(abs_r2, zeros)
    smallest_normal_preserving_sign_r2 = tf.bitwise.bitwise_or(
        input_r2, smallest_normal_r2)
    input_no_zeros_r2 = tf.where(if_zero_r2, smallest_normal_preserving_sign_r2,
                                 input_r2)

    # Discard the low-order bits and replace with iota.
    and_r2 = tf.bitwise.bitwise_and(input_no_zeros_r2, count_mask_r2)
    or_r2 = tf.bitwise.bitwise_or(and_r2, iota)
    return tf.bitcast(or_r2, tf.float32)