예제 #1
0
    def test_summary(self):
        """Check the summary table shape and the headers in its text form.

        Without a run time the table has 3 rows and 10 columns; with a run
        time of 20 it gains an 11th column, labelled 'ess per sec.'.
        """
        headers = [
            'param', 'mean', 'std.', '2.5%', '25%', '50%', '75%', '97.5%',
            'rhat', 'ess',
        ]

        def check(results, n_columns, expected_headers):
            # Shape of the tabular summary, then presence of each header
            # in the printed representation.
            table = np.array(results.summary())
            self.assertEqual(table.shape[0], 3)
            self.assertEqual(table.shape[1], n_columns)
            rendered = str(results)
            for header in expected_headers:
                self.assertIn(header, rendered)

        # No run time supplied.
        check(pints.MCMCSummary(self.chains), 10, headers)

        # Run time supplied: an extra ESS-per-second column appears.
        check(pints.MCMCSummary(self.chains, 20), 11,
              headers + ['ess per sec.'])
예제 #2
0
    def test_single_chain(self):
        # Tests MCMCSummary on a single chain (intended to be broken up into
        # two bits internally), and that results.chains() returns chains
        # matching the ones it was given.
        xs = [self.real_parameters * 0.9]
        mcmc = pints.MCMCController(self.log_posterior,
                                    1,
                                    xs,
                                    method=pints.HaarioBardenetACMC)
        mcmc.set_max_iterations(200)
        mcmc.set_initial_phase_iterations(50)
        mcmc.set_log_to_screen(False)
        chains = mcmc.run()
        results = pints.MCMCSummary(chains)
        chains1 = results.chains()
        self.assertEqual(chains[0].shape[0], chains1[0].shape[0])
        self.assertEqual(chains[0].shape[1], chains1[0].shape[1])
        # Compare a sample of the input chain with the returned chain.
        self.assertEqual(chains[0][10, 1], chains1[0][10, 1])

        # No run time was supplied, so time-based quantities are None.
        self.assertIsNone(results.time())
        self.assertIsNone(results.ess_per_second())

        # One value per parameter. assertEqual is required here:
        # assertTrue(len(x), 3) treats 3 as the failure message and so
        # passes for any non-empty x.
        ess = results.ess()
        mean = results.mean()
        rhat = results.rhat()
        std = results.std()
        self.assertEqual(len(ess), 3)
        self.assertEqual(len(mean), 3)
        self.assertEqual(len(rhat), 3)
        self.assertEqual(len(std), 3)

        # check positive quantities are so
        for i in range(3):
            self.assertGreater(ess[i], 0)
            self.assertLess(ess[i], 1000)
            self.assertGreater(rhat[i], 0)
            self.assertGreater(std[i], 0)
            self.assertGreater(mean[i], 0)

        # check means are vaguely near true values
        self.assertLess(np.abs(mean[0] - 0.015), 0.5)
        self.assertLess(np.abs(mean[1] - 500), 200)
        self.assertLess(np.abs(mean[2] - 10), 30)

        # check quantiles object: 5 quantiles x 3 parameters, all positive
        quantiles = results.quantiles()
        self.assertEqual(quantiles.shape[0], 5)
        self.assertEqual(quantiles.shape[1], 3)
        for i in range(5):
            for j in range(3):
                self.assertGreater(quantiles[i, j], 0)

        # Test with odd number of iterations: an odd-length single chain
        # cannot be halved evenly. This section only checks that building
        # the summary does not raise.
        mcmc = pints.MCMCController(self.log_posterior,
                                    1,
                                    xs,
                                    method=pints.HaarioBardenetACMC)
        mcmc.set_max_iterations(99)
        mcmc.set_initial_phase_iterations(40)
        mcmc.set_log_to_screen(False)
        chains = mcmc.run()
        results = pints.MCMCSummary(chains)
예제 #3
0
    def test_running(self):
        # Tests that MCMCSummary computes sensible statistics for the
        # multi-chain fixture when no run time is given.
        results = pints.MCMCSummary(self.chains)

        # No run time was supplied, so time-based quantities are None.
        self.assertIsNone(results.time())
        self.assertIsNone(results.ess_per_second())

        # One value per parameter. assertEqual is required here:
        # assertTrue(len(x), 3) treats 3 as the failure message and so
        # passes for any non-empty x.
        ess = results.ess()
        mean = results.mean()
        rhat = results.rhat()
        std = results.std()
        self.assertEqual(len(ess), 3)
        self.assertEqual(len(mean), 3)
        self.assertEqual(len(rhat), 3)
        self.assertEqual(len(std), 3)

        # check positive quantities are so
        for i in range(3):
            self.assertGreater(ess[i], 0)
            self.assertLess(ess[i], 1000)
            self.assertGreater(rhat[i], 0)
            self.assertGreater(std[i], 0)
            self.assertGreater(mean[i], 0)

        # check means are vaguely near true values
        self.assertLess(np.abs(mean[0] - 0.015), 0.1)
        self.assertLess(np.abs(mean[1] - 500), 100)
        self.assertLess(np.abs(mean[2] - 10), 20)

        # check quantiles object: 5 quantiles x 3 parameters, all positive
        quantiles = results.quantiles()
        self.assertEqual(quantiles.shape[0], 5)
        self.assertEqual(quantiles.shape[1], 3)
        for i in range(5):
            for j in range(3):
                self.assertGreater(quantiles[i, j], 0)
예제 #4
0
def plotting(old_model=False):
    """Load pickled MCMC results, save trace/pairwise plots, print a summary.

    Parameters
    ----------
    old_model
        If True, read 'results2.pickle' and use the old model's parameter
        names; otherwise read 'results.pickle' with the new model's names.

    Writes 'electrochem_pde_trace.pdf' and 'electrochem_pde_chains.pdf' to
    the current directory.
    """
    if old_model:
        results_file = 'results2.pickle'
        names = ['k0', 'E0', 'Cdl', 'Ru', 'alpha', 'omega', 'sigma']
    else:
        results_file = 'results.pickle'
        names = ['k0', 'E0', 'a', 'Ru', 'Cdl', 'omega', 'sigma']

    # SECURITY NOTE: pickle.load can execute arbitrary code; only load
    # result files produced by this project, never untrusted ones.
    # The `with` block ensures the file handle is closed.
    with open(results_file, 'rb') as f:
        (_xs, _log_posterior, _log_prior, chains, mcmc_method) = pickle.load(f)

    print(
        'Found results using mcmc={} containing {} chains with {} samples'.format(
            mcmc_method, len(chains), chains[0].shape[0]
        )
    )

    # `chains` is indexed as a 3-d array below (chain, sample, parameter),
    # so it is expected to be a numpy array rather than a list.
    print(chains.shape)
    pints.plot.trace(chains)
    plt.savefig('electrochem_pde_trace.pdf')

    # Pairwise plot of the first chain only.
    pints.plot.pairwise(chains[0, :, :])
    plt.savefig('electrochem_pde_chains.pdf')

    results = pints.MCMCSummary(chains, parameter_names=names)
    print(results)
예제 #5
0
 def test_ess_per_second(self):
     # tests that ess per second is calculated when time is supplied
     t = 10
     results = pints.MCMCSummary(self.chains, t)
     self.assertEqual(results.time(), t)
     ess_per_second = results.ess_per_second()
     ess = results.ess()
     # assertEqual, not assertTrue(len(...), 3): the latter treats 3 as a
     # message and passes for any non-empty result.
     self.assertEqual(len(ess_per_second), 3)
     for i in range(3):
         self.assertEqual(ess_per_second[i], ess[i] / t)
예제 #6
0
    def test_named_parameters(self):
        """Supplied parameter names must appear in the printed summary.

        Checked both without and with a run time. A names list whose
        length differs from the number of parameters must raise ValueError.
        """
        names = ['rrrr', 'kkkk', 'ssss']

        # First without a run time, then with time=20.
        for extra_kwargs in ({}, {'time': 20}):
            summary = pints.MCMCSummary(self.chains,
                                        parameter_names=names,
                                        **extra_kwargs)
            rendered = str(summary)
            for name in names:
                self.assertIn(name, rendered)

        # Number of parameter names must equal number of parameters
        with self.assertRaises(ValueError):
            pints.MCMCSummary(self.chains, parameter_names=['a', 'b'])
    def inference_problem_setup(self, times, num_iter, wd=1, wp=1):
        """
        Runs the parameter inference routine for the PHE model.

        Parameters
        ----------
        times
            (list) List of time points at which we have data for the
            log-likelihood computation.
        num_iter
            Number of iterations the MCMC sampler algorithm is run for.
        wd
            Proportion of contribution of the deaths_data to the
            log-likelihood.
        wp
            Proportion of contribution of the positives_data to the
            log-likelihood.

        Returns
        -------
        The chains returned by ``pints.MCMCController.run()`` for the three
        MCMC chains, which are run in parallel and all start from the
        optimisation result.

        """
        # Starting points using optimisation object. All three chains start
        # from the same point, but each gets its own list so no chain's
        # starting point aliases another's.
        x_opt = self.optimisation_problem_setup(times, wd, wp)[0].tolist()
        x0 = [list(x_opt) for _ in range(3)]

        # Create MCMC routine
        mcmc = pints.MCMCController(self._log_posterior, 3, x0)
        mcmc.set_max_iterations(num_iter)
        mcmc.set_log_to_screen(True)
        mcmc.set_parallel(True)

        print('Running...')
        chains = mcmc.run()
        print('Done!')

        # The number of beta parameters per region is the number of time
        # indices 44, 51, 58, ... below len(times) (presumably one per
        # week). It does not depend on the region, so compute it once.
        n_betas = len(range(44, len(times), 7))

        param_names = ['initial_r']
        for region in self._model.regions:
            param_names.extend(
                'beta_W{}_{}'.format(i + 1, region) for i in range(n_betas))
        param_names.append('sigma_b')

        # Check convergence and other properties of chains
        results = pints.MCMCSummary(chains=chains,
                                    time=mcmc.time(),
                                    parameter_names=param_names)
        print(results)

        return chains