def test_differs_given_different_seed(self):
  """Two seeds that differ in one component must yield different floats."""
  count = 100
  seed_a = tf.constant([123, 456], tf.int64)
  values_a = tf_utils.random_floats(count, seed_a)
  seed_b = tf.constant([122, 456], tf.int64)
  values_b = tf_utils.random_floats(count, seed_b)
  evaluated_a, evaluated_b = self.evaluate([values_a, values_b])
  self.assertFalse(np.array_equal(evaluated_a, evaluated_b))
def encode(self, x, encode_params):
  """See base class.

  Linearly maps `x` onto [0, max_value] using its own min and max, then
  rounds every value either down or up at random, rounding up when the
  corresponding rounding float is below the value's fractional part. The
  random seed is returned with the encoded tensors so that `decode` can
  regenerate the exact same rounding floats.

  Args:
    x: The tensor to quantize. Its static shape must be fully defined, since
      `x.shape.num_elements()` is used to size the random draw.
    encode_params: Dict holding the maximum integer value under
      `self.MAX_INT_VALUE_PARAMS_KEY`.

  Returns:
    A dict with the quantized values (same shape as `x`), the random seed,
    and the stacked `[min_x, max_x]` pair.
  """
  min_x = tf.reduce_min(x)
  max_x = tf.reduce_max(x)

  max_value = tf.cast(encode_params[self.MAX_INT_VALUE_PARAMS_KEY], x.dtype)
  # Shift the values to range [0, max_value].
  # In the case of min_x == max_x, this will return all zeros.
  x = tf.div_no_nan(x - min_x, max_x - min_x) * max_value

  # Randomized rounding.
  floored_x = tf.floor(x)
  random_seed = tf.random.uniform((2, ), maxval=tf.int64.max, dtype=tf.int64)
  num_elements = x.shape.num_elements()
  # The helper produces `num_elements` floats (presumably as a flat vector;
  # NOTE(review): assumed to lie in [0, 1) -- confirm in tf_utils). Reshape
  # them to x's shape so the comparison below is strictly elementwise;
  # without this, inputs of rank > 1 either fail to broadcast or, e.g. for
  # shape [n, 1], silently broadcast to [n, n].
  rounding_floats = tf.reshape(
      tf_utils.random_floats(num_elements, random_seed, x.dtype), x.shape)
  bernoulli = rounding_floats < (x - floored_x)
  quantized_x = floored_x + tf.cast(bernoulli, x.dtype)

  # Include the random seed in the encoded tensors so that it can be used to
  # generate the same random sequence in the decode method.
  encoded_tensors = {
      self.ENCODED_VALUES_KEY: quantized_x,
      self.SEED_PARAMS_KEY: random_seed,
      self.MIN_MAX_VALUES_KEY: tf.stack([min_x, max_x])
  }
  return encoded_tensors
def decode(self, encoded_tensors, decode_params, num_summands=None,
           shape=None):
  """See base class.

  Maps quantized values back to the original range using the stored min and
  max. The rounding floats used in `encode` are regenerated from the stored
  seed, which lets each value be shifted to the middle of the interval of
  original values consistent with its quantization.

  Args:
    encoded_tensors: Dict produced by `encode` (quantized values, seed, and
      stacked `[min_x, max_x]`).
    decode_params: Dict holding the maximum integer value under
      `self.MAX_INT_VALUE_PARAMS_KEY`.
    num_summands: Not read by this implementation.
    shape: Unused.

  Returns:
    The decoded tensor, with the same shape as the quantized values.
  """
  del shape  # Unused.
  quantized_x = encoded_tensors[self.ENCODED_VALUES_KEY]
  random_seed = encoded_tensors[self.SEED_PARAMS_KEY]
  min_max = encoded_tensors[self.MIN_MAX_VALUES_KEY]
  min_x, max_x = min_max[0], min_max[1]
  max_value = tf.cast(decode_params[self.MAX_INT_VALUE_PARAMS_KEY],
                      quantized_x.dtype)
  num_elements = quantized_x.shape.num_elements()
  # The rounding_floats are identical to those used in the encode method.
  # Reshape them to the shape of quantized_x so each float lines up with the
  # element it was drawn for; without this, decoded inputs of rank > 1 either
  # fail to broadcast or silently broadcast to the wrong shape.
  rounding_floats = tf.reshape(
      tf_utils.random_floats(num_elements, random_seed, min_x.dtype),
      quantized_x.shape)
  # Regenerating the random values used in encode, enables us to determine a
  # narrower range of possible original values, before quantization was
  # applied. We shift the quantized values into the middle of this range,
  # corresponding to the intersection of
  # [quantized_x - 1 + rounding_floats, quantized_x + rounding_floats]
  # in the quantized range. This shifted value can be out of the range
  # [0, max_value] and therefore the decoded value can be out of the range
  # [min_x, max_x], which is impossible, but it ensures that the decoded x
  # is an unbiased estimator of the original values before quantization.
  q_shifted = quantized_x + rounding_floats - 0.5
  x = q_shifted / max_value * (max_x - min_x) + min_x
  return x
def test_type_error_raises(self, dtype):
  """random_floats raises TypeError naming the supported float types.

  `dtype` is supplied by the test parameterization (not visible here) and is
  presumably neither tf.float32 nor tf.float64.
  """
  # assertRaisesRegexp is a deprecated alias that was removed from
  # unittest.TestCase in Python 3.12; assertRaisesRegex is the real name.
  with self.assertRaisesRegex(
      TypeError, 'Supported types are tf.float32 and '
      'tf.float64 values'):
    tf_utils.random_floats(10, tf.constant([456, 123], tf.int64), dtype)
def test_expected_dtype(self, dtype):
  """The tensor returned by random_floats has the requested dtype."""
  seed = tf.constant([456, 123], tf.int64)
  generated = tf_utils.random_floats(10, seed, dtype)
  self.assertEqual(dtype, generated.dtype)
def _random_floats(self, num_elements, seed, dtype):
  """Delegates to tf_utils.random_floats with the same arguments.

  Args:
    num_elements: Number of floats to generate.
    seed: Seed tensor forwarded unchanged to tf_utils.random_floats.
    dtype: Float dtype forwarded unchanged to tf_utils.random_floats.

  Returns:
    Whatever tf_utils.random_floats returns for these arguments.
  """
  generated = tf_utils.random_floats(num_elements, seed, dtype)
  return generated