コード例 #1
0
    def test_logging(self):
        """
        Tests logging to screen and file.

        Runs ``NestedRejectionSampler`` three times: with all logging
        disabled, with logging to screen only, and with logging to a file
        only. Each time it checks the captured or written output.
        """
        # A progress row has an iteration count, an evaluation count and the
        # elapsed time as ``m:ss.s``. The dot is escaped so it must be literal.
        pattern = re.compile(
            r'[0-9]+[ ]+[0-9]+[ ]+[0-9]{1}:[0-9]{2}\.[0-9]{1}')

        # No logging: nothing may be written to the captured streams
        with StreamCapture() as c:
            sampler = pints.NestedRejectionSampler(self.log_likelihood,
                                                   self.log_prior)
            sampler.set_posterior_samples(2)
            sampler.set_iterations(10)
            sampler.set_active_points_rate(10)
            sampler.set_log_to_screen(False)
            sampler.set_log_to_file(False)
            sampler.run()
        self.assertEqual(c.text(), '')

        # Log to screen: four settings lines, a column header, then rows
        with StreamCapture() as c:
            sampler = pints.NestedRejectionSampler(self.log_likelihood,
                                                   self.log_prior)
            sampler.set_posterior_samples(2)
            sampler.set_iterations(20)
            sampler.set_active_points_rate(10)
            sampler.set_log_to_screen(True)
            sampler.set_log_to_file(False)
            sampler.run()
        lines = c.text().splitlines()
        self.assertEqual(lines[0], 'Running nested rejection sampling')
        self.assertEqual(lines[1], 'Number of active points: 10')
        self.assertEqual(lines[2], 'Total number of iterations: 20')
        self.assertEqual(lines[3], 'Total number of posterior samples: 2')
        self.assertEqual(lines[4], 'Iter. Eval. Time m:s')
        for line in lines[5:]:
            self.assertTrue(pattern.match(line))
        self.assertEqual(len(lines), 11)

        # Log to file: nothing on screen; the file holds the column header
        # followed by progress rows
        with StreamCapture() as c:
            with TemporaryDirectory() as d:
                filename = d.path('test.txt')
                sampler = pints.NestedRejectionSampler(self.log_likelihood,
                                                       self.log_prior)
                sampler.set_posterior_samples(2)
                sampler.set_iterations(10)
                sampler.set_active_points_rate(10)
                sampler.set_log_to_screen(False)
                sampler.set_log_to_file(filename)
                sampler.run()
                with open(filename, 'r') as f:
                    lines = f.read().splitlines()
            self.assertEqual(c.text(), '')
        self.assertEqual(len(lines), 6)
        self.assertEqual(lines[0], 'Iter. Eval. Time m:s')
        # Check every row after the single header line. The former
        # ``lines[5:]`` was copied from the screen case and only checked the
        # last row.
        for line in lines[1:]:
            self.assertTrue(pattern.match(line))
コード例 #2
0
    def test_ask(self):
        # Tests ask().
        # A single asked point must have a finite log-likelihood.
        single = pints.NestedRejectionSampler(self.log_prior)
        point = single.ask(1)
        self.assertTrue(np.isfinite(self.log_likelihood(point)))

        # Ask for a batch of points, evaluate each one, then tell the
        # evaluations back to the sampler.
        batch_sampler = pints.NestedRejectionSampler(self.log_prior)
        batch = batch_sampler.ask(50)
        self.assertEqual(len(batch), 50)
        evaluations = list(map(self.log_likelihood, batch))
        proposal = batch_sampler.tell(evaluations)
        self.assertGreater(len(proposal), 1)
コード例 #3
0
    def test_getters_and_setters(self):
        """
        Checks that each setter updates its matching getter, and that invalid
        values are rejected with a ``ValueError``.
        """
        sampler = pints.NestedRejectionSampler(self.log_likelihood,
                                               self.log_prior)

        # Iterations: round-trip, then reject a negative count
        new_iterations = sampler.iterations() + 1
        self.assertNotEqual(sampler.iterations(), new_iterations)
        sampler.set_iterations(new_iterations)
        self.assertEqual(sampler.iterations(), new_iterations)
        with self.assertRaisesRegex(ValueError, 'negative'):
            sampler.set_iterations(-1)

        # Active points rate: round-trip, then reject a rate of 5
        new_rate = sampler.active_points_rate() + 1
        self.assertNotEqual(sampler.active_points_rate(), new_rate)
        sampler.set_active_points_rate(new_rate)
        self.assertEqual(sampler.active_points_rate(), new_rate)
        with self.assertRaisesRegex(ValueError, 'greater than 5'):
            sampler.set_active_points_rate(5)

        # Posterior samples: round-trip, then reject zero
        new_samples = sampler.posterior_samples() + 1
        self.assertNotEqual(sampler.posterior_samples(), new_samples)
        sampler.set_posterior_samples(new_samples)
        self.assertEqual(sampler.posterior_samples(), new_samples)
        with self.assertRaisesRegex(ValueError, 'greater than zero'):
            sampler.set_posterior_samples(0)
コード例 #4
0
 def test_hyper_params(self):
     """
     Checks the hyper-parameter interface: the sampler reports a single
     hyper-parameter, and setting it changes the active points rate.
     """
     sampler = pints.NestedRejectionSampler(
         self.log_likelihood, self.log_prior)
     n_hyper = sampler.n_hyper_parameters()
     self.assertEqual(n_hyper, 1)
     sampler.set_hyper_parameters([6])
     self.assertEqual(sampler.active_points_rate(), 6)
コード例 #5
0
    def test_quick_run(self):
        """ Runs a short sampling job and checks the returned sample shape. """
        sampler = pints.NestedRejectionSampler(self.log_likelihood,
                                               self.log_prior)
        n_posterior = 10
        sampler.set_posterior_samples(n_posterior)
        sampler.set_iterations(50)
        sampler.set_active_points_rate(50)
        sampler.set_log_to_screen(False)
        samples, _ = sampler.run()

        # One row per requested posterior sample. The second dimension (2) is
        # presumably the problem's parameter count; it depends on the fixture.
        self.assertEqual(samples.shape, (n_posterior, 2))
コード例 #6
0
    def test_getters_and_setters(self):
        # Tests the active-points getter/setter pair, plus name() and
        # needs_initial_phase().
        sampler = pints.NestedRejectionSampler(self.log_prior)

        # Active points: setter round-trips; a value of 5 is rejected
        n_active = sampler.n_active_points() + 1
        self.assertNotEqual(sampler.n_active_points(), n_active)
        sampler.set_n_active_points(n_active)
        self.assertEqual(sampler.n_active_points(), n_active)
        with self.assertRaisesRegex(ValueError, 'greater than 5'):
            sampler.set_n_active_points(5)

        # Identification and initial-phase behaviour
        self.assertEqual(sampler.name(), 'Nested rejection sampler')
        self.assertFalse(sampler.needs_initial_phase())
コード例 #7
0
    def test_settings_check(self):
        """
        Checks that run() validates the sampler settings before sampling.
        """
        sampler = pints.NestedRejectionSampler(self.log_likelihood,
                                               self.log_prior)
        sampler.set_posterior_samples(2)
        sampler.set_iterations(10)
        sampler.set_active_points_rate(10)
        sampler.set_log_to_screen(False)

        # A valid configuration runs without raising.
        sampler.run()

        # Raising the posterior sample count to 10 must be rejected. The
        # message mentions a 0.25 limit; presumably this is a ratio involving
        # the posterior sample count (confirm against run()).
        sampler.set_posterior_samples(10)
        with self.assertRaisesRegex(ValueError, 'exceed 0.25'):
            sampler.run()
コード例 #8
0
 def test_hyper_params(self):
     # Tests the hyper parameter interface is working.
     # The sampler must report exactly one hyper-parameter, and calling
     # set_hyper_parameters with a single value must not raise.
     # NOTE(review): within the visible lines, the effect of
     # set_hyper_parameters([220]) is not asserted (e.g. via a getter) —
     # confirm whether a check follows or should be added.
     sampler = pints.NestedRejectionSampler(self.log_prior)
     self.assertEqual(sampler.n_hyper_parameters(), 1)
     sampler.set_hyper_parameters([220])