def test_quantizes_to_integers_modulo_offset(self):
     """Quantization snaps inputs back onto the offset integer grid.

     The prior is centered at `.25`, and the test expects quantization to round
     to the grid `k + .25`. Each input is a grid point plus noise drawn from
     [-.49, .49), so it is always closer to its own grid point than to any
     neighbor, and quantizing must recover that grid point exactly.
     """
     prior = uniform_noise.NoisyNormal(loc=.25, scale=10.)
     entropy_model = ContinuousBatchedEntropyModel(prior, 1)
     grid_points = tf.range(-20., 20.) + .25
     noise = tf.random.uniform(grid_points.shape, -.49, .49)
     quantized = entropy_model.quantize(grid_points + noise)
     self.assertAllEqual(grid_points, quantized)
 def test_compression_consistent_with_quantization(self):
     """Decompressing compressed samples reproduces `quantize` exactly."""
     prior = uniform_noise.NoisyNormal(loc=.25, scale=10.)
     entropy_model = ContinuousBatchedEntropyModel(prior, 1, compression=True)
     samples = prior.base.sample([100])
     expected = entropy_model.quantize(samples)
     compressed = entropy_model.compress(samples)
     # The second argument is the shape to decode back to: 100 coded values.
     decoded = entropy_model.decompress(compressed, [100])
     self.assertAllEqual(decoded, expected)
 def test_gradients_are_straight_through(self):
     """The gradient of `quantize` w.r.t. its input is all ones.

     Rounding has zero derivative almost everywhere, so an all-ones gradient
     shows that `quantize` uses a straight-through gradient estimator.
     """
     prior = uniform_noise.NoisyNormal(loc=0, scale=1)
     entropy_model = ContinuousBatchedEntropyModel(prior, 1)
     grid_points = tf.range(-20., 20.)
     inputs = grid_points + tf.random.uniform(grid_points.shape, -.49, .49)
     with tf.GradientTape() as tape:
         # `inputs` is a plain tensor rather than a trainable variable, so the
         # tape only records operations on it if it is watched explicitly.
         tape.watch(inputs)
         outputs = entropy_model.quantize(inputs)
     grads = tape.gradient(outputs, inputs)
     self.assertAllEqual(grads, tf.ones_like(grads))
 def test_compression_works_after_serialization_no_offset(self):
     """Compression survives a serialization round trip with no offset.

     A prior centered at 0 is expected to produce no stored quantization offset.
     The model's serialized config and weights are captured, the model is
     rebuilt from them, and the rebuilt model must (a) still have no offset,
     (b) produce the same compressed output and (c) decode that output back to
     the originally quantized values.
     """
     noisy = uniform_noise.NoisyNormal(loc=0, scale=5.)
     em = ContinuousBatchedEntropyModel(noisy, 1, compression=True)
     self.assertIs(em._quantization_offset, None)
     serialized = tf.keras.utils.serialize_keras_object(em)
     weights = em.get_weights()
     x = noisy.base.sample([100])
     x_quantized = em.quantize(x)
     x_compressed = em.compress(x)
     em = tf.keras.utils.deserialize_keras_object(serialized)
     # The "no offset" property must survive deserialization as well;
     # otherwise the rebuilt model would quantize on a different grid.
     self.assertIs(em._quantization_offset, None)
     em.set_weights(weights)
     self.assertAllEqual(em.compress(x), x_compressed)
     self.assertAllEqual(em.decompress(x_compressed, [100]), x_quantized)