Ejemplo n.º 1
0
 def testNumericalPartitionIsAccurate(self):
     """Test numerical_base_partition_function against some golden data.

     For a handful of rational shape parameters alpha = numer / denom, the
     numerical partition function from `fit_partition_spline` must agree with
     `distribution.analytical_base_partition_function` to within 1e-10
     (absolute and relative).
     """
     for numer, denom in [(0, 1), (1, 8), (1, 2), (1, 1), (2, 1), (8, 1)]:
         # Build alpha as a float64 tensor: float32 (eps ~1.2e-7) could never
         # satisfy the 1e-10 tolerances asserted below.
         alpha = tf.cast(numer, tf.float64) / tf.cast(denom, tf.float64)
         # The analytical reference takes the fraction as two integers rather
         # than as a float alpha.
         z_true = distribution.analytical_base_partition_function(
             numer, denom)
         z = fit_partition_spline.numerical_base_partition_function(alpha)
         self.assertAllClose(z, z_true, atol=1e-10, rtol=1e-10)
Ejemplo n.º 2
0
 def testLogPartitionFractionsAreAccurate(self, float_dtype):
     """Check log Z(alpha) on the grid alpha = k / 11 for k = 0, 1, ..., 22.

     Expected values are logs of the analytical partition function evaluated
     one fraction at a time. Actual values come from a single batched call to
     `self._distribution.log_base_partition_function` on the same fractions,
     cast with `float_dtype` (presumably np.float32 / np.float64 supplied by a
     parameterized decorator -- confirm).
     """
     denominator = 11
     numerators = range(23)

     # Reference values: one analytical evaluation per fraction k / 11.
     expected = []
     for k in numerators:
         z = distribution.analytical_base_partition_function(k, denominator)
         expected.append(np.log(z))

     # The same fractions as one array in the requested precision.
     alphas = float_dtype(np.array(numerators)) / float_dtype(denominator)
     actual = self._distribution.log_base_partition_function(alphas)
     # NOTE(review): 1e-7 is only a few float32 ulps at these magnitudes; if
     # float32 runs become flaky, loosen the tolerance for that dtype.
     self.assertAllClose(actual, expected, atol=1e-7, rtol=1e-7)
Ejemplo n.º 3
0
 def _log_partition_fractions_are_accurate(self, float_dtype):
   """Compare log Z(n / 11) for n = 0, ..., 22 against analytical values.

   Because the name lacks a `test` prefix, the runner does not collect this
   method directly; presumably dtype-specific test methods call it with a
   concrete `float_dtype` -- confirm against the rest of the class.
   """
   denominator = 11
   numerators = range(23)

   # Ground truth: log of the analytical partition function per fraction.
   expected = []
   for k in numerators:
     z = distribution.analytical_base_partition_function(k, denominator)
     expected.append(np.log(z))

   # Build and evaluate the batched op inside the test session (graph mode).
   with self.session():
     alphas = float_dtype(np.array(numerators)) / float_dtype(denominator)
     actual = distribution.log_base_partition_function(alphas).eval()
   self.assertAllClose(actual, expected, atol=1e-7, rtol=1e-7)
Ejemplo n.º 4
0
 def testAnalyaticalPartitionIsCorrect(self):
   """Tests analytical_base_partition_function against some golden data."""
   # NOTE(review): "Analyatical" in the method name is a typo for "Analytical";
   # it is left unchanged so existing test IDs and filters keep matching.
   #
   # Here we enumerate a set of positive rational numbers n/d alongside
   # numerically approximated values of Z(n / d) to 10 significant digits,
   # stored as (n, d, Z(n/d)). These values were generated with an external
   # Mathematica script. The entries are not sorted by n/d.
   ground_truth_rational_partitions = (
       (1, 7, 4.080330073), (1, 6, 4.038544331), (1, 5, 3.984791180),
       (1, 4, 3.912448576), (1, 3, 3.808203509), (2, 5, 3.735479786),
       (3, 7, 3.706553276), (1, 2, 3.638993131), (3, 5, 3.553489270),
       (2, 3, 3.501024540), (3, 4, 3.439385624), (4, 5, 3.404121259),
       (1, 1, 3.272306973), (6, 5, 3.149249092), (5, 4, 3.119044506),
       (4, 3, 3.068687433), (7, 5, 3.028084866), (3, 2, 2.965924889),
       (8, 5, 2.901059987), (5, 3, 2.855391798), (7, 4, 2.794052016),
       (7, 3, 2.260434598), (5, 2, 2.218882601), (8, 3, 2.190349858),
       (3, 1, 2.153202857), (4, 1, 2.101960916), (7, 2, 2.121140098),
       (5, 1, 2.080000512), (9, 2, 2.089161164), (6, 1, 2.067751267),
       (7, 1, 2.059929623), (8, 1, 2.054500222), (10, 3, 2.129863884),
       (11, 3, 2.113763384), (13, 3, 2.092928254), (14, 3, 2.085788350),
       (16, 3, 2.075212740), (11, 2, 2.073116001), (17, 3, 2.071185791),
       (13, 2, 2.063452243), (15, 2, 2.056990258))  # pyformat: disable
   for numer, denom, z_true in ground_truth_rational_partitions:
     # The reference values are rounded to 9 decimals (error <= 5e-10), so a
     # 1e-9 tolerance still leaves headroom for that rounding.
     z = distribution.analytical_base_partition_function(numer, denom)
     self.assertAllClose(z, z_true, atol=1e-9, rtol=1e-9)