def test_validate_outputs(self):
    """Check a RunSet built over the 'bern' output under ``goodfiles_path``.

    Every chain's return code is set to 0, after which the return-code
    check, console-message check and CSV validation are run. The test then
    asserts the chain count, draw count, number of columns and the name of
    the first column reported by the RunSet.
    """
    # construct runset using existing sampler output
    stan_path = os.path.join(datafiles_path, 'bernoulli.stan')
    exe_path = os.path.join(datafiles_path, 'bernoulli')
    bernoulli_model = Model(exe_file=exe_path, stan_file=stan_path)

    data_path = os.path.join(datafiles_path, 'bernoulli.data.json')
    output_base = os.path.join(goodfiles_path, 'bern')

    sampler_options = {
        'chain_ids': [1, 2, 3, 4],
        'seed': 12345,
        'data': data_path,
        'output_file': output_base,
        'sampling_iters': 100,
        'max_treedepth': 11,
        'adapt_delta': 0.95,
    }
    sampler_args = SamplerArgs(bernoulli_model, **sampler_options)
    runset = RunSet(chains=4, args=sampler_args)

    # No sampler is run here, so each chain's return code is set to 0
    # by hand before check_retcodes() is asserted.
    for chain_idx in range(len(runset.retcodes)):
        runset.set_retcode(chain_idx, 0)
    self.assertTrue(runset.check_retcodes())

    runset.check_console_msgs()
    runset.validate_csv_files()

    self.assertEqual(4, runset.chains)
    self.assertEqual(100, runset.draws)
    self.assertEqual(8, len(runset.column_names))
    self.assertEqual('lp__', runset.column_names[0])
def test_validate_bad_transcript(self):
    """Expect check_console_msgs() to raise for the bad-transcript output.

    The RunSet is pointed at the 'bad-transcript-bern' output under
    ``badfiles_path``; the console-message check must raise an exception
    whose message matches 'Exception'.
    """
    stan_path = os.path.join(datafiles_path, 'bernoulli.stan')
    exe_path = os.path.join(datafiles_path, 'bernoulli')
    bernoulli_model = Model(exe_file=exe_path, stan_file=stan_path)

    data_path = os.path.join(datafiles_path, 'bernoulli.data.json')
    output_base = os.path.join(badfiles_path, 'bad-transcript-bern')

    sampler_args = SamplerArgs(
        bernoulli_model,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=data_path,
        output_file=output_base,
        sampling_iters=100,
        max_treedepth=11,
        adapt_delta=0.95,
    )
    runset = RunSet(chains=4, args=sampler_args)

    with self.assertRaisesRegex(Exception, 'Exception'):
        runset.check_console_msgs()