Beispiel #1
0
    def test_validate_good_run(self):
        """Validate a CmdStanMCMC fit assembled from stored bernoulli output.

        The RunSet is created for 4 chains, pointed at the four CSV files
        under ``runset-good``, and every return code is set to 0 by hand.
        The test then checks the draw count, the column names, the drawset
        shape, that ``summary()`` can be called, and that the expected
        diagnostic text is contained in the output of ``diagnose()``.
        """
        # construct fit using existing sampler output
        model_exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
        data_file = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        method_args = SamplerArgs(
            iter_sampling=100, max_treedepth=11, adapt_delta=0.95
        )
        args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=model_exe,
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=data_file,
            output_dir=DATAFILES_PATH,
            method_args=method_args,
        )
        runset = RunSet(args=args, chains=4)
        csv_names = ('bern-1.csv', 'bern-2.csv', 'bern-3.csv', 'bern-4.csv')
        runset._csv_files = [
            os.path.join(DATAFILES_PATH, 'runset-good', csv_name)
            for csv_name in csv_names
        ]
        self.assertEqual(4, runset.chains)

        # mark every chain as finished with return code 0
        num_retcodes = len(runset._retcodes)
        for chain_idx in range(num_retcodes):
            runset._set_retcode(chain_idx, 0)
        self.assertTrue(runset._check_retcodes())

        fit = CmdStanMCMC(runset)
        self.assertEqual(100, fit.num_draws)
        self.assertEqual(len(BERNOULLI_COLS), len(fit.column_names))
        self.assertEqual('lp__', fit.column_names[0])

        draws = fit.get_drawset()
        expected_shape = (
            fit.runset.chains * fit.num_draws,
            len(fit.column_names),
        )
        self.assertEqual(draws.shape, expected_shape)

        # the summary result is not inspected; the call only has to return
        fit.summary()
        # NOTE(review): this assertion can never fail - it only marks that
        # the summary() call above did not raise.
        self.assertTrue(True)

        # TODO - use cmdstan test files instead
        expected_lines = [
            'Checking sampler transitions treedepth.',
            'Treedepth satisfactory for all transitions.',
            '\nChecking sampler transitions for divergences.',
            'No divergent transitions found.',
            '\nChecking E-BFMI - sampler transitions HMC potential energy.',
            'E-BFMI satisfactory for all transitions.',
            '\nEffective sample size satisfactory.',
        ]
        expected = '\n'.join(expected_lines)
        # convert CRLF line endings to LF before the comparison
        diagnostics = fit.diagnose().replace('\r\n', '\n')
        self.assertIn(expected, diagnostics)
Beispiel #2
0
    def test_validate_summary_sig_figs(self):
        """Check how ``summary(sig_figs=...)`` affects the reported values.

        A CmdStanMCMC object is assembled from the stored logistic model
        CSV files.  The value at ``iloc[1, 0]`` of the summary is formatted
        with ``'.18g'`` and compared against known prefixes: the default
        summary, and - only when ``cmdstan_version_at(2, 25)`` is true -
        summaries requested with 17 and 10 significant figures.  Finally,
        ``sig_figs`` values of 20 and -1 must raise ``ValueError``.
        """
        # construct CmdStanMCMC from logistic model output, config
        model_exe = os.path.join(DATAFILES_PATH, 'logistic' + EXTENSION)
        data_file = os.path.join(DATAFILES_PATH, 'logistic.data.R')
        method_args = SamplerArgs(iter_sampling=100)
        args = CmdStanArgs(
            model_name='logistic',
            model_exe=model_exe,
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=data_file,
            output_dir=DATAFILES_PATH,
            sig_figs=17,
            method_args=method_args,
        )
        runset = RunSet(args=args)
        csv_names = (
            'logistic_output_1.csv',
            'logistic_output_2.csv',
            'logistic_output_3.csv',
            'logistic_output_4.csv',
        )
        runset._csv_files = [
            os.path.join(DATAFILES_PATH, csv_name) for csv_name in csv_names
        ]

        # mark every chain as finished with return code 0
        num_retcodes = len(runset._retcodes)
        for chain_idx in range(num_retcodes):
            runset._set_retcode(chain_idx, 0)
        fit = CmdStanMCMC(runset)

        default_summary = fit.summary()
        default_text = format(default_summary.iloc[1, 0], '.18g')
        self.assertTrue(default_text.startswith('1.3'))

        if cmdstan_version_at(2, 25):
            # (requested sig_figs, prefix the formatted value must start with)
            precision_cases = (
                (17, '1.345767078273'),
                (10, '1.34576707'),
            )
            for figures, prefix in precision_cases:
                summary = fit.summary(sig_figs=figures)
                value_text = format(summary.iloc[1, 0], '.18g')
                self.assertTrue(value_text.startswith(prefix))

        # these sig_figs values are expected to be rejected
        for rejected in (20, -1):
            with self.assertRaises(ValueError):
                fit.summary(sig_figs=rejected)
Beispiel #3
0
    def test_validate_good_run(self):
        """Validate a fit built from stored bernoulli output, incl. summary.

        The RunSet is pointed at the four CSV files under ``runset-good``
        and every return code is set to 0 by hand.  Besides the draw count,
        column names and drawset shape, the test asserts which percentile
        columns ``summary()`` reports by default and with
        ``percentiles=[1, 45, 99]``, that ``percentiles=[]`` and
        ``percentiles=[-1]`` raise ``ValueError``, and that ``diagnose()``
        contains the four expected messages.
        """
        # construct fit using existing sampler output
        model_exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
        data_file = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        method_args = SamplerArgs(
            iter_sampling=100, max_treedepth=11, adapt_delta=0.95
        )
        args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=model_exe,
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=data_file,
            output_dir=DATAFILES_PATH,
            method_args=method_args,
        )
        runset = RunSet(args=args)
        csv_names = ('bern-1.csv', 'bern-2.csv', 'bern-3.csv', 'bern-4.csv')
        runset._csv_files = [
            os.path.join(DATAFILES_PATH, 'runset-good', csv_name)
            for csv_name in csv_names
        ]
        self.assertEqual(4, runset.chains)

        # mark every chain as finished with return code 0
        num_retcodes = len(runset._retcodes)
        for chain_idx in range(num_retcodes):
            runset._set_retcode(chain_idx, 0)
        self.assertTrue(runset._check_retcodes())

        fit = CmdStanMCMC(runset)
        self.assertEqual(100, fit.num_draws)
        self.assertEqual(len(BERNOULLI_COLS), len(fit.column_names))
        self.assertEqual('lp__', fit.column_names[0])

        draws = fit.get_drawset()
        expected_shape = (
            fit.runset.chains * fit.num_draws,
            len(fit.column_names),
        )
        self.assertEqual(draws.shape, expected_shape)

        # summary() without arguments: 5/50/95 present, 1/99 absent
        default_columns = list(fit.summary().columns)
        for present in ('5%', '50%', '95%'):
            self.assertIn(present, default_columns)
        for absent in ('1%', '99%'):
            self.assertNotIn(absent, default_columns)

        # explicit percentiles replace the default ones
        custom_columns = list(fit.summary(percentiles=[1, 45, 99]).columns)
        for present in ('1%', '45%', '99%'):
            self.assertIn(present, custom_columns)
        for absent in ('5%', '50%', '95%'):
            self.assertNotIn(absent, custom_columns)

        # an empty list and a negative percentile must both be rejected
        for rejected in ([], [-1]):
            with self.assertRaises(ValueError):
                fit.summary(percentiles=rejected)

        diagnostics = fit.diagnose()
        expected_messages = (
            'Treedepth satisfactory for all transitions.',
            'No divergent transitions found.',
            'E-BFMI satisfactory for all transitions.',
            'Effective sample size satisfactory.',
        )
        for message in expected_messages:
            self.assertIn(message, diagnostics)