コード例 #1
0
 def test_gelman_rubin_raise_error(self):
     # gelman_rubin is expected to raise ValueError when it is handed a
     # collection that has been trimmed down to a single chain.
     chains = setup_chains()
     # Drop chains from the end until only the first one is left.
     while len(chains) > 1:
         chains.pop(-1)
     with self.assertRaises(ValueError, msg='Must have multiple chains'):
         CS.gelman_rubin(chains=chains)
コード例 #2
0
 def test_string_display(self):
     # Function-scope import so this method does not depend on the file's
     # import block.
     from contextlib import redirect_stdout

     # Capture what print_chain_acceptance_info writes to stdout.  The
     # context manager restores the previously active stream on exit, even
     # if the call raises, so a failure here cannot leave stdout redirected
     # for the tests that run afterwards.
     buffer = io.StringIO()
     with redirect_stdout(buffer):
         CS.print_chain_acceptance_info(chain)
     printed = buffer.getvalue()
     self.assertIsInstance(printed, str, msg='Caputured string')
     # No results were passed, so the results section must be absent.
     self.assertNotIn('Results dictionary', printed,
                      msg='Expect results dictionary not included')
コード例 #3
0
 def test_string_display_w_details(self):
     # Function-scope import so this method does not depend on the file's
     # import block.
     from contextlib import redirect_stdout

     # Capture what chainstats prints when display_details=True.  The
     # context manager restores the previously active stream on exit, even
     # if the call raises, so stdout is never left redirected.
     buffer = io.StringIO()
     with redirect_stdout(buffer):
         CS.chainstats(chain, display_details=True)
     printed = buffer.getvalue()
     self.assertIsInstance(printed, str, msg='Caputured string')
     self.assertNotIn('Results dictionary', printed,
                      msg='Expect results dictionary not included')
     # The detailed display is expected to include the definitions header.
     self.assertIn('Definition for items displayed:', printed)
コード例 #4
0
 def test_string_display_with_results(self):
     # Function-scope import so this method does not depend on the file's
     # import block.
     from contextlib import redirect_stdout

     # Results dictionary handed to the function under test.
     results = dict(nsimu=5000, iacce=np.array([200, 800]))
     # Capture the printed report.  The context manager restores the
     # previously active stream on exit, even if the call raises, so stdout
     # is never left redirected.
     buffer = io.StringIO()
     with redirect_stdout(buffer):
         CS.print_chain_acceptance_info(chain, results=results)
     printed = buffer.getvalue()
     self.assertIsInstance(printed, str, msg='Caputured string')
     # With results supplied, the results section must be present.
     self.assertIn('Results dictionary', printed,
                   msg='Expect results dictionary included')
コード例 #5
0
 def test_gelman_rubin_with_pres(self):
     # Wrap every chain in a dict carrying the chain itself and its
     # first-axis length as 'nsimu', then pass the list of dicts to
     # gelman_rubin in place of the raw chains.
     pres = [dict(chain=samples, nsimu=samples.shape[0])
             for samples in setup_chains()]
     psrf = CS.gelman_rubin(chains=pres)
     self.standard_check(psrf)
コード例 #6
0
 def test_gelman_rubin(self):
     # Function-scope import so this method does not depend on the file's
     # import block.
     from contextlib import redirect_stdout

     chains = setup_chains()
     # Run with display=True while capturing stdout.  The context manager
     # restores the previously active stream on exit, even if gelman_rubin
     # raises, so stdout is never left redirected.
     buffer = io.StringIO()
     with redirect_stdout(buffer):
         psrf = CS.gelman_rubin(chains=chains, display=True)
     self.assertIsInstance(buffer.getvalue(), str, msg='Caputured string')
     self.standard_check(psrf)
コード例 #7
0
 def test_calc_psrf(self):
     # Input: a (1000, 2) array whose columns are evenly spaced values on
     # [0, 1] and [2.5, 3.3]; it is evaluated with nsimu=1000, nchains=2.
     x = np.concatenate((np.linspace(0, 1, 1000).reshape(1000, 1),
                         np.linspace(2.5, 3.3, 1000).reshape(1000, 1)),
                        axis=1)
     psrf = CS.calculate_psrf(x, nsimu=1000, nchains=2)
     self.assertIsInstance(psrf, dict, msg='Expect dictionary output')
     # Reference value for each entry of the returned dictionary.
     expected = (
         ('R', 8.001818935964492),
         ('B', 2879.9999999999964),
         ('W', 0.06853867547894903),
         ('V', 4.388470136803464),
         ('neff', 3.047548706113521),
     )
     for key, reference in expected:
         # Label the failure message with the entry actually being checked
         # (previously every message was prefixed with 'R').
         self.assertAlmostEqual(
             psrf[key], reference, places=6,
             msg='{}: {} neq {}'.format(key, psrf[key], reference))
コード例 #8
0
 def test_too_few_batches(self):
     # Passing b equal to the chain's first-axis length is expected to make
     # batch_mean_standard_deviation raise SystemExit.
     first_axis_len = chain.shape[0]
     with self.assertRaises(SystemExit, msg='too few batches'):
         CS.batch_mean_standard_deviation(chain, b=first_axis_len)
コード例 #9
0
 def test_len_s(self):
     # With b=None the returned object is expected to have length 2.
     n_entries = len(CS.batch_mean_standard_deviation(chain, b=None))
     self.assertEqual(n_entries, 2)
コード例 #10
0
 def test_cs_eval_with_no_chain(self):
     # With chain=None and returnstats=True, chainstats is expected to
     # return a string.
     stats = CS.chainstats(chain=None, returnstats=True)
     # assertIsInstance reports the offending object on failure instead of
     # a bare "False is not true".
     self.assertIsInstance(stats, str)
コード例 #11
0
 def test_cs_eval_with_no_return(self):
     # With returnstats=False, chainstats is expected to return None.
     stats = CS.chainstats(chain=chain, returnstats=False)
     # Identity check: `== None` could be satisfied (or broken) by an object
     # with a custom __eq__, whereas only None itself passes assertIsNone.
     self.assertIsNone(stats)
コード例 #12
0
 def test_cs_eval_with_return(self):
     # With returnstats=True and a chain supplied, chainstats is expected to
     # return a dictionary.
     stats = CS.chainstats(chain=chain, returnstats=True)
     # assertIsInstance reports the offending object on failure instead of
     # a bare "False is not true".
     self.assertIsInstance(stats, dict)
コード例 #13
0
 def test_gelman_rubin_with_names(self):
     # Call gelman_rubin with names=['a', 'b'] and run the shared checks on
     # the result.
     labels = ['a', 'b']
     psrf = CS.gelman_rubin(chains=setup_chains(), names=labels)
     self.standard_check(psrf)
コード例 #14
0
 def test_nw_not_none(self):
     # With nw given explicitly (and nfft left at its default), the output
     # length is expected to be min(len(x), 256) halved and rounded down.
     signal = chain[:, 0]
     expected_len = min(len(signal), 256) // 2
     spectrum = CS.power_spectral_density_using_hanning_window(
             x=signal, nw=len(signal))
     self.assertEqual(expected_len, len(spectrum))
コード例 #15
0
 def test_nfft_not_none_size(self):
     # With nfft given explicitly, the output length is expected to be nfft
     # halved and rounded down.
     nfft = 100
     spectrum = CS.power_spectral_density_using_hanning_window(
             x=chain[:, 0], nfft=nfft)
     self.assertEqual(nfft // 2, len(spectrum))