Ejemplo n.º 1
0
 def test_differs_given_different_seed(self):
     """Checks that two different seeds produce different random floats."""
     # Build one random tensor per seed; the seeds differ in a single entry.
     draws = [
         tf_utils.random_floats(100, tf.constant(seed_values, tf.int64))
         for seed_values in ([123, 456], [122, 456])
     ]
     first_draw, second_draw = self.evaluate(draws)
     self.assertFalse(np.array_equal(first_draw, second_draw))
Ejemplo n.º 2
0
    def encode(self, x, encode_params):
        """See base class.

        Rescales `x` to the range `[0, max_value]` and applies randomized
        rounding, driven by a freshly sampled seed that is returned alongside
        the quantized values.

        Args:
          x: Tensor to encode. Its static shape should be fully defined, since
            the number of rounding floats is read from `x.shape`.
          encode_params: Mapping providing `MAX_INT_VALUE_PARAMS_KEY`, the
            upper end of the quantized range.

        Returns:
          A dict with the quantized values under `ENCODED_VALUES_KEY`, the
          random seed under `SEED_PARAMS_KEY` and the stacked
          `[min_x, max_x]` under `MIN_MAX_VALUES_KEY`.
        """
        min_x = tf.reduce_min(x)
        max_x = tf.reduce_max(x)

        max_value = tf.cast(encode_params[self.MAX_INT_VALUE_PARAMS_KEY],
                            x.dtype)
        # Shift the values to range [0, max_value].
        # In the case of min_x == max_x, this will return all zeros.
        # `tf.math.divide_no_nan` replaces `tf.div_no_nan`, which is not
        # exposed in the TF 2.x namespace.
        x = tf.math.divide_no_nan(x - min_x, max_x - min_x) * max_value

        # Randomized rounding.
        floored_x = tf.floor(x)
        random_seed = tf.random.uniform((2, ),
                                        maxval=tf.int64.max,
                                        dtype=tf.int64)
        num_elements = x.shape.num_elements()
        # The random values are requested by element count, so give them the
        # shape of `x` explicitly. Without this, a multi-dimensional `x`
        # would rely on broadcasting against a flat vector, which either
        # fails or silently changes the output shape.
        rounding_floats = tf.reshape(
            tf_utils.random_floats(num_elements, random_seed, x.dtype),
            x.shape)

        # Round up with probability equal to the fractional part of x.
        bernoulli = rounding_floats < (x - floored_x)
        quantized_x = floored_x + tf.cast(bernoulli, x.dtype)

        # Include the random seed in the encoded tensors so that it can be used to
        # generate the same random sequence in the decode method.
        encoded_tensors = {
            self.ENCODED_VALUES_KEY: quantized_x,
            self.SEED_PARAMS_KEY: random_seed,
            self.MIN_MAX_VALUES_KEY: tf.stack([min_x, max_x])
        }

        return encoded_tensors
Ejemplo n.º 3
0
    def decode(self,
               encoded_tensors,
               decode_params,
               num_summands=None,
               shape=None):
        """See base class.

        Maps quantized values back to the original value range, regenerating
        the rounding floats from the seed stored in `encoded_tensors`.

        Args:
          encoded_tensors: Dict providing `ENCODED_VALUES_KEY`,
            `SEED_PARAMS_KEY` and `MIN_MAX_VALUES_KEY`.
          decode_params: Mapping providing `MAX_INT_VALUE_PARAMS_KEY`, the
            upper end of the quantized range.
          num_summands: Not referenced by this method.
          shape: Unused.

        Returns:
          A tensor of decoded values with the same shape as the quantized
          values.
        """
        del shape  # Unused.
        quantized_x = encoded_tensors[self.ENCODED_VALUES_KEY]
        random_seed = encoded_tensors[self.SEED_PARAMS_KEY]
        min_max = encoded_tensors[self.MIN_MAX_VALUES_KEY]
        min_x, max_x = min_max[0], min_max[1]
        max_value = tf.cast(decode_params[self.MAX_INT_VALUE_PARAMS_KEY],
                            quantized_x.dtype)

        num_elements = quantized_x.shape.num_elements()
        # The rounding_floats are identical to those used in the encode method.
        # They are requested by element count, so give them the shape of
        # `quantized_x` explicitly instead of relying on broadcasting a flat
        # vector against a possibly multi-dimensional tensor.
        rounding_floats = tf.reshape(
            tf_utils.random_floats(num_elements, random_seed, min_x.dtype),
            quantized_x.shape)

        # Regenerating the random values used in encode enables us to determine
        # a narrower range of possible original values, before quantization was
        # applied. We shift the quantized values into the middle of this range,
        # corresponding to the intersection of
        # [quantized_x - 1 + rounding_floats, quantized_x + rounding_floats]
        # in the quantized range. This shifted value can be out of the range
        # [0, max_value] and therefore the decoded value can be out of the range
        # [min_x, max_x], which no original value could be, but it ensures
        # that the decoded x is an unbiased estimator of the original values
        # before quantization.
        q_shifted = quantized_x + rounding_floats - 0.5

        # Undo the rescaling from [0, max_value] back to [min_x, max_x].
        x = q_shifted / max_value * (max_x - min_x) + min_x
        return x
Ejemplo n.º 4
0
 def test_type_error_raises(self, dtype):
     """Checks that an unsupported `dtype` is rejected with a `TypeError`."""
     # `assertRaisesRegexp` is a deprecated alias that was removed in
     # Python 3.12; `assertRaisesRegex` is the supported name.
     with self.assertRaisesRegex(
             TypeError, 'Supported types are tf.float32 and '
             'tf.float64 values'):
         tf_utils.random_floats(10, tf.constant([456, 123], tf.int64),
                                dtype)
Ejemplo n.º 5
0
 def test_expected_dtype(self, dtype):
     """Checks that the generated floats carry the requested `dtype`."""
     seed = tf.constant([456, 123], tf.int64)
     generated = tf_utils.random_floats(10, seed, dtype)
     self.assertEqual(dtype, generated.dtype)
Ejemplo n.º 6
0
 def _random_floats(self, num_elements, seed, dtype):
     """Forwards to `tf_utils.random_floats` with the same arguments."""
     floats = tf_utils.random_floats(num_elements, seed, dtype)
     return floats