Ejemplo n.º 1
0
 def test_validate_outputs(self):
     """Validate a RunSet built over existing, known-good sampler output.

     No sampling method is invoked here: the RunSet is pointed at the
     ``bern`` output files under ``goodfiles_path``. Every chain's return
     code is forced to 0 so the post-run checks (return codes, console
     messages, CSV validation) can be exercised directly. Afterwards the
     parsed metadata is checked: chain count, draws and column names.
     """
     # construct runset using existing sampler output
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
     output = os.path.join(goodfiles_path, 'bern')
     args = SamplerArgs(
         model,
         chain_ids=[1, 2, 3, 4],
         seed=12345,
         data=jdata,
         output_file=output,
         sampling_iters=100,
         max_treedepth=11,
         adapt_delta=0.95,
     )
     runset = RunSet(chains=4, args=args)
     # Mark every chain as having exited successfully. Only the index is
     # needed, so enumerate the return codes instead of range(len(...)).
     for idx, _ in enumerate(runset.retcodes):
         runset.set_retcode(idx, 0)
     self.assertTrue(runset.check_retcodes())
     # Return values of these two calls are not checked. The test relies
     # on them not raising for well-formed output (presumably they raise on
     # malformed output, as the bad-transcript test expects).
     runset.check_console_msgs()
     runset.validate_csv_files()
     self.assertEqual(4, runset.chains)
     self.assertEqual(100, runset.draws)
     # The fixture's CSV header is expected to have 8 columns, 'lp__' first.
     self.assertEqual(8, len(runset.column_names))
     self.assertEqual('lp__', runset.column_names[0])
Ejemplo n.º 2
0
 def test_validate_bad_transcript(self):
     """Expect check_console_msgs() to raise for the bad-transcript fixture.

     The RunSet is built over the ``bad-transcript-bern`` files under
     ``badfiles_path`` rather than over a fresh sampler run. Checking its
     console messages must raise an exception whose message matches the
     regex ``'Exception'``.
     """
     stan_src = os.path.join(datafiles_path, 'bernoulli.stan')
     exe_path = os.path.join(datafiles_path, 'bernoulli')
     bern_model = Model(exe_file=exe_path, stan_file=stan_src)
     data_json = os.path.join(datafiles_path, 'bernoulli.data.json')
     output_prefix = os.path.join(badfiles_path, 'bad-transcript-bern')
     sampler_args = SamplerArgs(
         bern_model,
         chain_ids=[1, 2, 3, 4],
         seed=12345,
         data=data_json,
         output_file=output_prefix,
         sampling_iters=100,
         max_treedepth=11,
         adapt_delta=0.95,
     )
     bad_runset = RunSet(args=sampler_args, chains=4)
     # NOTE(review): assertRaisesRegex does re.search on str(exc), so this
     # passes only if the raised message literally contains "Exception".
     # Confirm that is the intended check.
     with self.assertRaisesRegex(Exception, 'Exception'):
         bad_runset.check_console_msgs()