def test_validate_big_run(self):
    """Validate a large pre-existing run (2 chains, 2095 ``phi`` elements).

    Builds a ``StanFit`` over the CSV files with basename
    ``datafiles_path/runset-big/output_icar_nyc`` and validates them. It then
    checks column names, sampler metadata shapes, and the shapes returned by
    ``get_drawset`` for whole-vector and per-element selections. Finally, it
    checks that out-of-range or partial parameter names raise.
    """
    # Expected dimensions of the stored run. Every shape asserted below is
    # derived from these, so the assertions stay mutually consistent.
    num_chains = 2
    num_draws = 1000
    num_phis = 2095

    exe = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)
    # Fake out validation: point the fit at existing CSV output.
    output = os.path.join(datafiles_path, 'runset-big', 'output_icar_nyc')
    sampler_args = SamplerArgs()
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=list(range(1, num_chains + 1)),
        seed=12345,
        output_basename=output,
        method_args=sampler_args,
    )
    fit = StanFit(args=cmdstan_args, chains=num_chains)
    fit._validate_csv_files()

    sampler_state = [
        'lp__',
        'accept_stat__',
        'stepsize__',
        'treedepth__',
        'n_leapfrog__',
        'divergent__',
        'energy__',
    ]
    phi_names = ['phi.{}'.format(i) for i in range(1, num_phis + 1)]
    column_names = sampler_state + phi_names
    self.assertEqual(fit.columns, len(column_names))
    self.assertEqual(fit.column_names, tuple(column_names))
    self.assertEqual(fit.metric_type, 'diag_e')
    self.assertEqual(fit.stepsize.shape, (num_chains,))
    self.assertEqual(fit.metric.shape, (num_chains, num_phis))
    self.assertEqual(
        (num_draws, num_chains, len(column_names)), fit.sample.shape
    )

    # get_drawset flattens chains, so the row count is draws * chains.
    total_draws = num_draws * num_chains
    phi_draws = fit.get_drawset(params=['phi'])
    self.assertEqual((total_draws, num_phis), phi_draws.shape)
    phi1 = fit.get_drawset(params=['phi.1'])
    self.assertEqual((total_draws, 1), phi1.shape)
    mo_phis = fit.get_drawset(params=['phi.1', 'phi.10', 'phi.100'])
    self.assertEqual((total_draws, 3), mo_phis.shape)
    phi_last = fit.get_drawset(params=['phi.{}'.format(num_phis)])
    self.assertEqual((total_draws, 1), phi_last.shape)

    # An index one past the last element is expected to raise.
    with self.assertRaises(Exception):
        fit.get_drawset(params=['phi.{}'.format(num_phis + 1)])
    # A bare prefix of a parameter name is expected to raise.
    with self.assertRaises(Exception):
        fit.get_drawset(params=['ph'])
def test_validate_good_run(self):
    """Validate a fit built from known-good bernoulli sampler output.

    Constructs a 4-chain ``StanFit`` over the CSV files with basename
    ``goodfiles_path/bern`` and sets every chain's return code to 0. It then
    checks return codes, console messages and CSV validation, followed by the
    chain/draw/column counts and the ``get_drawset`` shape. It also checks
    that ``summary()`` runs and that ``diagnose()`` reports no problems.
    """
    # Construct the fit using existing sampler output.
    exe = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)
    jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
    output = os.path.join(goodfiles_path, 'bern')
    sampler_args = SamplerArgs(
        sampling_iters=100, max_treedepth=11, adapt_delta=0.95
    )
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=jdata,
        output_basename=output,
        method_args=sampler_args,
    )
    fit = StanFit(args=cmdstan_args, chains=4)

    # The output comes from pre-existing files, so each chain's return code
    # is set to 0 by hand before the retcode check.
    for chain_idx, _ in enumerate(fit._retcodes):
        fit._set_retcode(chain_idx, 0)
    self.assertTrue(fit._check_retcodes())
    fit._check_console_msgs()
    fit._validate_csv_files()

    self.assertEqual(4, fit.chains)
    self.assertEqual(100, fit.draws)
    self.assertEqual(8, len(fit.column_names))
    self.assertEqual('lp__', fit.column_names[0])

    df = fit.get_drawset()
    self.assertEqual(
        df.shape, (fit.chains * fit.draws, len(fit.column_names))
    )
    # summary() is only checked for not raising.
    fit.summary()

    # TODO - use cmdstan test files instead
    expected = '\n'.join([
        'Checking sampler transitions treedepth.',
        'Treedepth satisfactory for all transitions.',
        '\nChecking sampler transitions for divergences.',
        'No divergent transitions found.',
        '\nChecking E-BFMI - sampler transitions HMC potential energy.',
        'E-BFMI satisfactory for all transitions.',
        '\nEffective sample size satisfactory.',
    ])
    # Normalize Windows line endings before the substring comparison.
    self.assertIn(expected, fit.diagnose().replace("\r\n", "\n"))