def test_quantizes_to_integers_modulo_offset(self):
    """Quantizing jittered points recovers the offset integer grid.

    The grid is the integers in [-20, 20) shifted by .25, matching the
    prior's `loc`. Each point is jittered by less than half a bin (±.49), so
    the test expects `quantize` to snap it back onto its original grid point.
    """
    prior = uniform_noise.NoisyNormal(loc=.25, scale=10.)
    model = ContinuousBatchedEntropyModel(prior, 1)
    grid = tf.range(-20., 20.) + .25
    jitter = tf.random.uniform(grid.shape, -.49, .49)
    self.assertAllEqual(grid, model.quantize(grid + jitter))
def test_compression_consistent_with_quantization(self):
    """Round-tripping through compress/decompress must equal quantize()."""
    prior = uniform_noise.NoisyNormal(loc=.25, scale=10.)
    model = ContinuousBatchedEntropyModel(prior, 1, compression=True)
    samples = prior.base.sample([100])
    expected = model.quantize(samples)
    compressed = model.compress(samples)
    # The second argument to decompress is the batch shape of the samples.
    self.assertAllEqual(model.decompress(compressed, [100]), expected)
def test_gradients_are_straight_through(self):
    """quantize() must pass gradients through unchanged (all ones)."""
    prior = uniform_noise.NoisyNormal(loc=0, scale=1)
    model = ContinuousBatchedEntropyModel(prior, 1)
    inputs = tf.range(-20., 20.)
    inputs = inputs + tf.random.uniform(inputs.shape, -.49, .49)
    with tf.GradientTape() as tape:
        # `inputs` is a plain tensor rather than a variable, so the tape
        # only records operations on it if it is watched explicitly.
        tape.watch(inputs)
        outputs = model.quantize(inputs)
    grads = tape.gradient(outputs, inputs)
    self.assertAllEqual(grads, tf.ones_like(grads))
def test_compression_works_after_serialization_no_offset(self):
    """A model without a quantization offset survives Keras serialization.

    The model is serialized and its weights are captured. A second model is
    rebuilt from that serialization and given the same weights. It must
    produce identical compressed output and decompress back to the original
    model's quantized values.
    """
    prior = uniform_noise.NoisyNormal(loc=0, scale=5.)
    model = ContinuousBatchedEntropyModel(prior, 1, compression=True)
    # Precondition for this test: with loc=0 no offset is stored.
    self.assertIs(model._quantization_offset, None)
    serialized = tf.keras.utils.serialize_keras_object(model)
    saved_weights = model.get_weights()
    samples = prior.base.sample([100])
    expected_quantized = model.quantize(samples)
    expected_compressed = model.compress(samples)
    restored = tf.keras.utils.deserialize_keras_object(serialized)
    restored.set_weights(saved_weights)
    self.assertAllEqual(restored.compress(samples), expected_compressed)
    self.assertAllEqual(
        restored.decompress(expected_compressed, [100]), expected_quantized)