コード例 #1
0
 def get_model(self, prior_fn=uniform_noise.NoisyNormal,
               coding_rank=1, **kwargs):
   """Builds the location-scale indexed entropy model used by the tests.

   Args:
     prior_fn: Prior distribution class forwarded to the model.
     coding_rank: Coding rank forwarded to the model.
     **kwargs: Extra keyword arguments forwarded verbatim to
       `LocationScaleIndexedEntropyModel`.

   Returns:
     A `continuous_indexed.LocationScaleIndexedEntropyModel` built with 64
     indexes, where index `i` maps to scale `exp(i / 8 - 5)`.
   """
   num_scales = 64

   def scale_fn(i):
     # Index 0 maps to exp(-5); every 8 index steps multiply the scale by e.
     return tf.exp(i / 8. - 5.)

   model_cls = continuous_indexed.LocationScaleIndexedEntropyModel
   return model_cls(prior_fn, num_scales, scale_fn, coding_rank, **kwargs)
コード例 #2
0
 def test_can_instantiate(self):
   """Checks the constructor keeps its prior and fills in default settings."""
   model = continuous_indexed.LocationScaleIndexedEntropyModel(
       uniform_noise.NoisyNormal,
       64,
       lambda i: tf.exp(i / 8 - 5),
       1,
   )
   self.assertIsInstance(model.prior, uniform_noise.NoisyNormal)
   # `coding_rank` echoes the constructor argument; every other attribute
   # below was not passed explicitly, so it reflects the constructor default.
   expected_attributes = (
       ("coding_rank", 1),
       ("likelihood_bound", 1e-9),
       ("tail_mass", 2**-8),
       ("range_coder_precision", 12),
       ("dtype", tf.float32),
   )
   for attr_name, expected_value in expected_attributes:
     self.assertEqual(getattr(model, attr_name), expected_value)
コード例 #3
0
 def test_can_instantiate(self):
     """Checks constructor defaults and the shapes of per-index helper tensors."""
     model = continuous_indexed.LocationScaleIndexedEntropyModel(
         uniform_noise.NoisyNormal,
         64,
         lambda i: tf.exp(i / 8 - 5),
         1,
     )
     self.assertIsInstance(model.distribution, uniform_noise.NoisyNormal)
     # `coding_rank` echoes the constructor argument; the remaining values are
     # constructor defaults since they were not passed explicitly.
     expected_attributes = (
         ("coding_rank", 1),
         ("likelihood_bound", 1e-9),
         ("tail_mass", 2**-8),
         ("range_coder_precision", 12),
         ("dtype", tf.float32),
     )
     for attr_name, expected_value in expected_attributes:
         self.assertEqual(getattr(model, attr_name), expected_value)
     # Each helper must return a tensor of shape [64], the same 64 passed as
     # the second constructor argument.
     for method_name in ("quantization_offset", "upper_tail", "lower_tail"):
         self.assertSequenceEqual(getattr(model, method_name)().shape, [64])
コード例 #4
0
 def test_can_instantiate_and_compress(self):
     """Checks a compression-enabled model round-trips within rounding error.

     Builds the model with `compression=True`, checks its configuration, then
     compresses and decompresses a random tensor and verifies that every
     element of the reconstruction lies within 0.5 of the input.
     """
     em = continuous_indexed.LocationScaleIndexedEntropyModel(
         uniform_noise.NoisyNormal,
         64,
         lambda i: tf.exp(i / 8 - 5),
         1,
         compression=True)
     self.assertIsInstance(em.prior, uniform_noise.NoisyNormal)
     self.assertEqual(em.coding_rank, 1)
     self.assertEqual(em.tail_mass, 2**-8)
     self.assertEqual(em.range_coder_precision, 12)
     self.assertEqual(em.dtype, tf.float32)
     shape = (3, 8, 16)
     # Stateless random ops return identical values for identical
     # (seed, shape, dtype). Give each input its own seed: sharing one seed
     # made `indexes` merely `loc` scaled by 10 and truncated, so index and
     # location were never varied independently.
     x = tf.random.stateless_normal(shape, seed=(0, 0))
     indexes = tf.cast(
         10 * tf.random.stateless_uniform(shape, seed=(0, 1)),
         tf.int32)
     loc = tf.random.stateless_uniform(shape, seed=(0, 2))
     # Calling the model directly must work on these inputs; its result is
     # not inspected here.
     em(x, indexes, loc=loc)
     bitstring = em.compress(x, indexes, loc=loc)
     x_hat = em.decompress(bitstring, indexes, loc=loc)
     # The reconstruction may differ from x only by rounding: |x - x_hat| <= 0.5.
     error = x - x_hat
     self.assertAllLessEqual(error, 0.5)
     self.assertAllGreaterEqual(error, -0.5)