Ejemplo n.º 1
0
 def test_default_kwargs_throw_error_on_compression(self):
     """Checks that compress/decompress raise with default constructor kwargs.

     The entropy model is built without passing `compression=True`; both
     `compress` and `decompress` are then expected to raise `RuntimeError`.
     """
     prior = uniform_noise.NoisyNormal(loc=.25, scale=10.)
     model = ContinuousBatchedEntropyModel(prior, 1)

     # Compressing a tensor of zeros must be rejected.
     values = tf.zeros(10)
     self.assertRaises(RuntimeError, model.compress, values)

     # Decompressing a string tensor must be rejected as well.
     strings = tf.zeros(10, dtype=tf.string)
     self.assertRaises(RuntimeError, model.decompress, strings, [10])
Ejemplo n.º 2
0
 def test_compression_consistent_with_quantization(self):
     """Checks that a compress/decompress round trip equals `quantize`."""
     prior = uniform_noise.NoisyNormal(loc=.25, scale=10.)
     model = ContinuousBatchedEntropyModel(prior, 1, compression=True)
     samples = prior.base.sample([100])

     # Reference result: quantize the samples directly.
     expected = model.quantize(samples)

     # Result under test: push the same samples through the codec.
     bitstring = model.compress(samples)
     roundtrip = model.decompress(bitstring, [100])

     self.assertAllEqual(roundtrip, expected)
Ejemplo n.º 3
0
 class Compressor:
     """Helper that round-trips values through a lazily built entropy model.

     NOTE: `noisy` is not defined in this class; it is resolved from an
     enclosing scope that is not shown here.
     """

     def compress(self, values):
         """Compresses `values` and returns the decompressed result.

         The entropy model is created on the first call and cached on the
         instance as `self.em`; later calls reuse it.
         """
         if not hasattr(self, "em"):
             self.em = ContinuousBatchedEntropyModel(
                 noisy, 1, compression=True)
         model = self.em
         return model.decompress(model.compress(values), [100])
Ejemplo n.º 4
0
 def test_compression_works_after_serialization_no_offset(self):
     """Checks compression is unchanged by a Keras serialize/restore cycle.

     Covers the case where the model's `_quantization_offset` is `None`.
     """
     prior = uniform_noise.NoisyNormal(loc=0, scale=5.)
     model = ContinuousBatchedEntropyModel(prior, 1, compression=True)
     self.assertIsNone(model._quantization_offset)

     # Capture the serialized form and weights, then the reference outputs
     # from the original model.
     serialized = tf.keras.utils.serialize_keras_object(model)
     saved_weights = model.get_weights()
     samples = prior.base.sample([100])
     expected_quantized = model.quantize(samples)
     expected_strings = model.compress(samples)

     # Rebuild the model from its serialized form and reload the weights.
     restored = tf.keras.utils.deserialize_keras_object(serialized)
     restored.set_weights(saved_weights)

     # The restored model must reproduce both the strings and the values.
     self.assertAllEqual(restored.compress(samples), expected_strings)
     self.assertAllEqual(
         restored.decompress(expected_strings, [100]), expected_quantized)
Ejemplo n.º 5
0
 def test_small_bitcost_for_dirac_prior(self):
   """Checks bit costs and decoding error for a prior with scale 1e-10."""
   prior = uniform_noise.NoisyNormal(loc=100 * tf.range(16.0), scale=1e-10)
   model = ContinuousBatchedEntropyModel(
       prior, coding_rank=2, compression=True)
   num_symbols = 1000
   samples = prior.base.sample((3, num_symbols))

   _, estimated_bits = model(samples, training=True)
   strings = model.compress(samples)
   decoded = model.decompress(strings, (num_symbols,))

   # Length of every compressed string in bits, laid out like `strings`.
   byte_lengths = [len(s) for s in strings.numpy().flatten()]
   actual_bits = tf.reshape([8 * n for n in byte_lengths], strings.shape)

   # Both the estimate and the real strings may use at most 2 bytes.
   self.assertAllLessEqual(estimated_bits, 16)
   self.assertAllLessEqual(actual_bits, 16)
   # Quantization error must stay within [-.5, .5].
   self.assertAllLessEqual(tf.abs(samples - decoded), 0.5)
Ejemplo n.º 6
0
 def test_dtypes_are_correct_with_mixed_precision(self):
     """Checks dtypes under the "mixed_float16" global Keras policy.

     With a float64 prior, the bottleneck, the model output and the
     decompressed values are expected to be float16, while the prior and
     the bit counts are expected to be float64.
     """
     tf.keras.mixed_precision.set_global_policy("mixed_float16")
     try:
         loc = tf.constant(0, dtype=tf.float64)
         scale = tf.constant(1, dtype=tf.float64)
         prior = uniform_noise.NoisyNormal(loc=loc, scale=scale)
         model = ContinuousBatchedEntropyModel(prior, 1, compression=True)
         self.assertEqual(model.bottleneck_dtype, tf.float16)
         self.assertEqual(model.prior.dtype, tf.float64)

         # Stateless sampling with a fixed seed keeps the input reproducible.
         inputs = tf.random.stateless_normal(
             (2, 5), seed=(0, 1), dtype=tf.float16)
         outputs, bits = model(inputs)
         strings = model.compress(inputs)
         decoded = model.decompress(strings, (5,))

         # Decompressed values: float16 and within 0.5 of the input.
         self.assertEqual(decoded.dtype, tf.float16)
         self.assertAllClose(inputs, decoded, rtol=0, atol=.5)

         # Model output: float16 and within 0.5 of the input.
         self.assertEqual(outputs.dtype, tf.float16)
         self.assertAllClose(inputs, outputs, rtol=0, atol=.5)

         # Bit counts: float64, one per leading-axis entry, non-negative.
         self.assertEqual(bits.dtype, tf.float64)
         self.assertEqual(bits.shape, (2,))
         self.assertAllGreaterEqual(bits, 0.)
     finally:
         # Always pass `None` back to the global policy setter, even when
         # an assertion above has failed.
         tf.keras.mixed_precision.set_global_policy(None)