예제 #1
0
    def test_validate_good_run(self):
        # Validates a StanFit assembled from sampler output that already
        # exists on disk (nothing is sampled here): return codes, console
        # messages, CSV files, drawset shape, summary and diagnose output.
        # Kept as comments rather than a docstring so unittest's verbose
        # test description stays unchanged.
        model_exe = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)
        data_file = os.path.join(datafiles_path, 'bernoulli.data.json')
        csv_basename = os.path.join(goodfiles_path, 'bern')

        nuts_args = SamplerArgs(
            sampling_iters=100, max_treedepth=11, adapt_delta=0.95
        )
        run_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=model_exe,
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=data_file,
            output_basename=csv_basename,
            method_args=nuts_args,
        )
        fit = StanFit(args=run_args, chains=4)

        # Overwrite every recorded return code with 0 so that the
        # _check_retcodes() call below is expected to report success.
        for chain_idx, _ in enumerate(fit._retcodes):
            fit._set_retcode(chain_idx, 0)
        self.assertTrue(fit._check_retcodes())
        fit._check_console_msgs()
        fit._validate_csv_files()

        # Basic dimensions read from the fit.
        self.assertEqual(4, fit.chains)
        self.assertEqual(100, fit.draws)
        self.assertEqual(8, len(fit.column_names))
        self.assertEqual('lp__', fit.column_names[0])

        # The drawset must have one row per draw across all chains and one
        # column per reported column name.
        drawset = fit.get_drawset()
        n_rows = fit.chains * fit.draws
        n_cols = len(fit.column_names)
        self.assertEqual(drawset.shape, (n_rows, n_cols))

        # summary() only has to run without raising; its result is unused.
        fit.summary()

        # TODO - use cmdstan test files instead
        expected_lines = (
            'Checking sampler transitions treedepth.',
            'Treedepth satisfactory for all transitions.',
            '\nChecking sampler transitions for divergences.',
            'No divergent transitions found.',
            '\nChecking E-BFMI - sampler transitions HMC potential energy.',
            'E-BFMI satisfactory for all transitions.',
            '\nEffective sample size satisfactory.',
        )
        expected = '\n'.join(expected_lines)
        # Normalise Windows line endings before the substring comparison.
        report = fit.diagnose().replace("\r\n", "\n")
        self.assertIn(expected, report)
예제 #2
0
 def test_diagnose_divergences(self):
     # Runs diagnose() over the existing 'corr_gauss_depth8' output under
     # 'diagnose-good' and checks that the expected warning text appears.
     # NOTE(review): the method name says "divergences", but the text
     # asserted below is the treedepth report - confirm the intended name.
     model_exe = os.path.join(
         datafiles_path, 'bernoulli' + EXTENSION
     )  # fake out validation
     csv_basename = os.path.join(
         datafiles_path, 'diagnose-good', 'corr_gauss_depth8'
     )
     default_sampler_args = SamplerArgs()
     fit = StanFit(
         args=CmdStanArgs(
             model_name='bernoulli',
             model_exe=model_exe,
             chain_ids=[1],
             output_basename=csv_basename,
             method_args=default_sampler_args,
         ),
         chains=1,
     )
     # TODO - use cmdstan test files instead
     expected_lines = (
         'Checking sampler transitions treedepth.',
         '424 of 1000 (42%) transitions hit the maximum '
         'treedepth limit of 8, or 2^8 leapfrog steps.',
         'Trajectories that are prematurely terminated '
         'due to this limit will result in slow exploration.',
         'For optimal performance, increase this limit.',
     )
     expected = '\n'.join(expected_lines)
     # Normalise Windows line endings before the substring comparison.
     report = fit.diagnose().replace("\r\n", "\n")
     self.assertIn(expected, report)