Ejemplo n.º 1
0
 def test_diagnose_divergences(self):
     """Run ``diagnose()`` on one pre-recorded chain and check its report.

     Note: despite the test name, the text asserted here is the treedepth
     section of the report. The test expects
     ``diagnose-good/corr_gauss_depth8-1.csv`` to show 424 of 1000
     transitions hitting the treedepth limit of 8.
     """
     # TODO - use cmdstan test files instead
     treedepth_report = '\n'.join((
         'Checking sampler transitions treedepth.',
         '424 of 1000 (42%) transitions hit the maximum '
         'treedepth limit of 8, or 2^8 leapfrog steps.',
         'Trajectories that are prematurely terminated '
         'due to this limit will result in slow exploration.',
         'For optimal performance, increase this limit.',
     ))

     model_exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
     csv_path = os.path.join(
         DATAFILES_PATH, 'diagnose-good', 'corr_gauss_depth8-1.csv'
     )
     runset = RunSet(
         args=CmdStanArgs(
             model_name='bernoulli',
             model_exe=model_exe,
             chain_ids=[1],
             output_dir=DATAFILES_PATH,
             method_args=SamplerArgs(),
         ),
         chains=1,
     )
     # Point the runset at existing output instead of running the sampler.
     runset._csv_files = [csv_path]
     fit = CmdStanMCMC(runset)

     # Normalise Windows line endings so the multi-line match is portable.
     report = fit.diagnose().replace('\r\n', '\n')
     self.assertIn(treedepth_report, report)
Ejemplo n.º 2
0
    def test_validate_good_run(self):
        """Build a CmdStanMCMC fit from existing, well-behaved sampler output.

        Uses the four ``runset-good/bern-*.csv`` files (100 sampling
        iterations per chain). Checks:

        - the draw count and column names;
        - the drawset shape;
        - that ``summary()`` reports the default percentile columns;
        - the full "all satisfactory" ``diagnose()`` report.
        """
        # construct fit using existing sampler output
        exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
        jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        sampler_args = SamplerArgs(iter_sampling=100,
                                   max_treedepth=11,
                                   adapt_delta=0.95)
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=exe,
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=jdata,
            output_dir=DATAFILES_PATH,
            method_args=sampler_args,
        )
        runset = RunSet(args=cmdstan_args, chains=4)
        runset._csv_files = [
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-1.csv'),
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-2.csv'),
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-3.csv'),
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-4.csv'),
        ]
        self.assertEqual(4, runset.chains)
        # Mark every chain as having exited cleanly, since no process ran.
        for chain_idx, _ in enumerate(runset._retcodes):
            runset._set_retcode(chain_idx, 0)
        self.assertTrue(runset._check_retcodes())

        fit = CmdStanMCMC(runset)
        self.assertEqual(100, fit.num_draws)
        self.assertEqual(len(BERNOULLI_COLS), len(fit.column_names))
        self.assertEqual('lp__', fit.column_names[0])

        drawset = fit.get_drawset()
        self.assertEqual(
            drawset.shape,
            (fit.runset.chains * fit.num_draws, len(fit.column_names)),
        )

        # Previously summary() was only called, never checked; assert that
        # the default central percentiles are reported.
        summary_cols = list(fit.summary().columns)
        for pct in ('5%', '50%', '95%'):
            self.assertIn(pct, summary_cols)

        # TODO - use cmdstan test files instead
        expected = '\n'.join([
            'Checking sampler transitions treedepth.',
            'Treedepth satisfactory for all transitions.',
            '\nChecking sampler transitions for divergences.',
            'No divergent transitions found.',
            '\nChecking E-BFMI - sampler transitions HMC potential energy.',
            'E-BFMI satisfactory for all transitions.',
            '\nEffective sample size satisfactory.',
        ])
        # Normalise Windows line endings so the multi-line match is portable.
        self.assertIn(expected, fit.diagnose().replace('\r\n', '\n'))
Ejemplo n.º 3
0
    def test_validate_good_run(self):
        """Build a fit from the four ``runset-good`` Bernoulli CSVs and check it.

        Covers:

        - draw and column metadata, and the drawset shape;
        - ``summary()`` percentile handling: the defaults, a custom list,
          and rejection of empty or negative percentile lists;
        - the key "satisfactory" lines of the ``diagnose()`` report.
        """
        # construct fit using existing sampler output
        good_dir = os.path.join(DATAFILES_PATH, 'runset-good')
        sampler_args = SamplerArgs(
            iter_sampling=100, max_treedepth=11, adapt_delta=0.95
        )
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION),
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=os.path.join(DATAFILES_PATH, 'bernoulli.data.json'),
            output_dir=DATAFILES_PATH,
            method_args=sampler_args,
        )
        runset = RunSet(args=cmdstan_args)
        runset._csv_files = [
            os.path.join(good_dir, csv_name)
            for csv_name in (
                'bern-1.csv', 'bern-2.csv', 'bern-3.csv', 'bern-4.csv'
            )
        ]
        self.assertEqual(4, runset.chains)
        # No sampler process ran, so record a clean exit for every chain.
        for chain_idx, _ in enumerate(runset._retcodes):
            runset._set_retcode(chain_idx, 0)
        self.assertTrue(runset._check_retcodes())

        fit = CmdStanMCMC(runset)
        self.assertEqual(100, fit.num_draws)
        self.assertEqual(len(BERNOULLI_COLS), len(fit.column_names))
        self.assertEqual('lp__', fit.column_names[0])

        drawset = fit.get_drawset()
        expected_shape = (
            fit.runset.chains * fit.num_draws,
            len(fit.column_names),
        )
        self.assertEqual(drawset.shape, expected_shape)

        # Default percentiles.
        default_cols = list(fit.summary().columns)
        for pct in ('5%', '50%', '95%'):
            self.assertIn(pct, default_cols)
        for pct in ('1%', '99%'):
            self.assertNotIn(pct, default_cols)

        # Caller-supplied percentiles replace the defaults.
        custom_cols = list(fit.summary(percentiles=[1, 45, 99]).columns)
        for pct in ('1%', '45%', '99%'):
            self.assertIn(pct, custom_cols)
        for pct in ('5%', '50%', '95%'):
            self.assertNotIn(pct, custom_cols)

        # Empty or out-of-range percentile lists are rejected.
        for bad_percentiles in ([], [-1]):
            with self.assertRaises(ValueError):
                fit.summary(percentiles=bad_percentiles)

        diagnostics = fit.diagnose()
        for message in (
            'Treedepth satisfactory for all transitions.',
            'No divergent transitions found.',
            'E-BFMI satisfactory for all transitions.',
            'Effective sample size satisfactory.',
        ):
            self.assertIn(message, diagnostics)