def test_diagnose_divergences(self):
    """diagnose() reports treedepth saturation from a known-bad CSV."""
    exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    sampler_args = SamplerArgs()
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1],
        output_dir=DATAFILES_PATH,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args, chains=1)
    # point the runset at a pre-generated CSV in which 42% of
    # transitions hit the treedepth limit
    runset._csv_files = [
        os.path.join(
            DATAFILES_PATH, 'diagnose-good', 'corr_gauss_depth8-1.csv'
        )
    ]
    fit = CmdStanMCMC(runset)
    # TODO - use cmdstan test files instead
    expected = '\n'.join([
        'Checking sampler transitions treedepth.',
        '424 of 1000 (42%) transitions hit the maximum '
        'treedepth limit of 8, or 2^8 leapfrog steps.',
        'Trajectories that are prematurely terminated '
        'due to this limit will result in slow exploration.',
        'For optimal performance, increase this limit.',
    ])
    # normalize line endings so the comparison also passes on Windows
    self.assertIn(expected, fit.diagnose().replace('\r\n', '\n'))
def test_variables(self):
    """stan_variables() returns every model variable keyed by name,
    each shaped (num_draws, *dims)."""
    # construct fit using existing sampler output
    exe = os.path.join(DATAFILES_PATH, 'lotka-volterra' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'lotka-volterra.data.json')
    sampler_args = SamplerArgs(iter_sampling=20)
    cmdstan_args = CmdStanArgs(
        model_name='lotka-volterra',
        model_exe=exe,
        chain_ids=[1],
        seed=12345,
        data=jdata,
        output_dir=DATAFILES_PATH,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args, chains=1)
    runset._csv_files = [
        os.path.join(DATAFILES_PATH, 'lotka-volterra.csv')
    ]
    runset._set_retcode(0, 0)
    fit = CmdStanMCMC(runset)
    self.assertEqual(20, fit.num_draws)
    self.assertEqual(8, len(fit._stan_variable_dims))
    # 'z' is declared as a 20 x 2 container in the model
    self.assertIn('z', fit._stan_variable_dims)
    self.assertEqual(fit._stan_variable_dims['z'], (20, 2))
    # renamed from 'vars' to avoid shadowing the builtin vars()
    variables = fit.stan_variables()
    self.assertEqual(len(variables), len(fit._stan_variable_dims))
    # assertIn gives a useful failure message, unlike assertTrue(x in y)
    self.assertIn('z', variables)
    self.assertEqual(variables['z'].shape, (20, 20, 2))
    self.assertIn('theta', variables)
    self.assertEqual(variables['theta'].shape, (20, 4))
def test_no_chains(self):
    """With chain_ids=None, per-chain seed lists and init lists are
    invalid and must raise ValueError."""
    exe = os.path.join(datafiles_path, 'bernoulli')
    jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
    jinits = os.path.join(datafiles_path, 'bernoulli.init.json')
    sampler_args = SamplerArgs()
    # each kwargs dict is invalid when no chain ids are supplied
    invalid_extras = [
        {'seed': [1, 2, 3], 'inits': jinits},
        {'inits': [jinits]},
    ]
    for extra in invalid_extras:
        with self.assertRaises(ValueError):
            CmdStanArgs(
                model_name='bernoulli',
                model_exe=exe,
                chain_ids=None,
                data=jdata,
                method_args=sampler_args,
                **extra,
            )
def test_args_good(self):
    """compose_command emits the expected sampler options, and chain
    idx 0 picks up the first user-supplied chain id."""
    exe = os.path.join(datafiles_path, 'bernoulli')
    jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
    sampler_args = SamplerArgs()
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        data=jdata,
        method_args=sampler_args,
    )
    self.assertEqual(cmdstan_args.method, Method.SAMPLE)
    cmd_str = ' '.join(
        cmdstan_args.compose_command(idx=0, csv_file='bern-output-1.csv')
    )
    for fragment in (
        'id=1 random seed=',
        'data file=',
        'output file=',
        'method=sample algorithm=hmc',
    ):
        self.assertIn(fragment, cmd_str)
    # non-default chain ids: idx 0 maps to the first id in the list
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[7, 11, 18, 29],
        data=jdata,
        method_args=sampler_args,
    )
    cmd_str = ' '.join(
        cmdstan_args.compose_command(idx=0, csv_file='bern-output-1.csv')
    )
    self.assertIn('id=7 random seed=', cmd_str)
def test_save_diagnostics(self):
    """Diagnostic files default to the temp dir; an explicit
    output_dir redirects them there."""
    exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    sampler_args = SamplerArgs()
    chain_ids = [1, 2, 3, 4]
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=chain_ids,
        data=jdata,
        method_args=sampler_args,
        save_diagnostics=True,
    )
    runset = RunSet(args=cmdstan_args)
    # no output_dir: diagnostics land in the session temp directory
    self.assertIn(_TMPDIR, runset.diagnostic_files[0])
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=chain_ids,
        data=jdata,
        method_args=sampler_args,
        save_diagnostics=True,
        output_dir=os.path.abspath('.'),
    )
    runset = RunSet(args=cmdstan_args)
    # explicit output_dir: diagnostics are written under it
    self.assertIn(os.path.abspath('.'), runset.diagnostic_files[0])
def test_validate_big_run(self):
    """Validate a fit with 2095 parameters loaded from saved CSVs:
    column names, metric, draws shape, and draws_pd selection."""
    exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    sampler_args = SamplerArgs(iter_warmup=1500, iter_sampling=1000)
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2],
        seed=12345,
        output_dir=DATAFILES_PATH,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args, chains=2)
    # both chains point at the same CSV; only the content matters here
    runset._csv_files = [
        os.path.join(DATAFILES_PATH, 'runset-big', 'output_icar_nyc-1.csv'),
        os.path.join(DATAFILES_PATH, 'runset-big', 'output_icar_nyc-1.csv'),
    ]
    fit = CmdStanMCMC(runset)
    # model declares parameters phi[1] ... phi[2095]
    phis = ['phi[{}]'.format(str(x + 1)) for x in range(2095)]
    column_names = SAMPLER_STATE + phis
    self.assertEqual(fit.num_draws_sampling, 1000)
    self.assertEqual(fit.column_names, tuple(column_names))
    self.assertEqual(fit.metric_type, 'diag_e')
    self.assertEqual(fit.step_size.shape, (2, ))
    self.assertEqual(fit.metric.shape, (2, 2095))
    # draws shape: (iterations, chains, sampler cols + parameter cols)
    self.assertEqual((1000, 2, 2102), fit.draws().shape)
    phis = fit.draws_pd(params=['phi'])
    self.assertEqual((2000, 2095), phis.shape)
    # unknown parameter names are rejected
    with self.assertRaisesRegex(ValueError, r'unknown parameter: gamma'):
        fit.draws_pd(params=['gamma'])
def test_check_retcodes(self):
    """Retcode bookkeeping: every chain starts at -1 and
    _check_retcodes() is True only once all chains record 0."""
    exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    sampler_args = SamplerArgs()
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        data=jdata,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args, chains=4)
    # repr(x) is the idiomatic spelling of x.__repr__(); compute it once
    rep = repr(runset)
    self.assertIn('RunSet: chains=4', rep)
    self.assertIn('method=sample', rep)
    retcodes = runset._retcodes
    self.assertEqual(4, len(retcodes))
    # all chains start out unfinished
    for i in range(len(retcodes)):
        self.assertEqual(-1, runset._retcode(i))
    runset._set_retcode(0, 0)
    self.assertEqual(0, runset._retcode(0))
    # one successful chain is not enough
    for i in range(1, len(retcodes)):
        self.assertEqual(-1, runset._retcode(i))
    self.assertFalse(runset._check_retcodes())
    for i in range(1, len(retcodes)):
        runset._set_retcode(i, 0)
    self.assertTrue(runset._check_retcodes())
def test_good(self):
    """InferenceMetadata built from a good sampler CSV exposes the
    parsed config, HMC diagnostic columns, and model variable dims."""
    # construct fit using existing sampler output
    exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    sampler_args = SamplerArgs(
        iter_sampling=100, max_treedepth=11, adapt_delta=0.95
    )
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=jdata,
        output_dir=DATAFILES_PATH,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args)
    runset._csv_files = [
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-1.csv'),
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-2.csv'),
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-3.csv'),
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-4.csv'),
    ]
    retcodes = runset._retcodes
    for i in range(len(retcodes)):
        runset._set_retcode(i, 0)
        # validate each chain's CSV; 'config' from the final iteration
        # is used below (all four files share one configuration)
        config = check_sampler_csv(
            path=runset.csv_files[i],
            is_fixed_param=False,
            iter_sampling=100,
            iter_warmup=1000,
            save_warmup=False,
            thin=1,
        )
    expected = 'Metadata:\n{}\n'.format(config)
    metadata = InferenceMetadata(config)
    actual = '{}'.format(metadata)
    self.assertEqual(expected, actual)
    self.assertEqual(config, metadata.cmdstan_config)
    # NUTS/HMC sampler diagnostic column names
    hmc_vars = {
        'lp__',
        'accept_stat__',
        'stepsize__',
        'treedepth__',
        'n_leapfrog__',
        'divergent__',
        'energy__',
    }
    sampler_vars_cols = metadata.sampler_vars_cols
    self.assertEqual(hmc_vars, sampler_vars_cols.keys())
    # bernoulli model has a single scalar parameter 'theta'
    bern_model_vars = {'theta'}
    self.assertEqual(bern_model_vars, metadata.stan_vars_dims.keys())
    self.assertEqual((), metadata.stan_vars_dims['theta'])
    self.assertEqual(bern_model_vars, metadata.stan_vars_cols.keys())
    self.assertEqual((7, ), metadata.stan_vars_cols['theta'])
def test_args_inits(self):
    """inits may be a single file, a per-chain file list, or a number;
    each form shows up in the composed command line."""
    exe = os.path.join(datafiles_path, 'bernoulli')
    jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
    sampler_args = SamplerArgs()
    jinits = os.path.join(datafiles_path, 'bernoulli.init.json')
    jinits1 = os.path.join(datafiles_path, 'bernoulli.init_1.json')
    jinits2 = os.path.join(datafiles_path, 'bernoulli.init_2.json')
    # single init file shared by all chains
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        data=jdata,
        inits=jinits,
        method_args=sampler_args,
    )
    cmd = cmdstan_args.compose_command(idx=0, csv_file='bern-output-1.csv')
    self.assertIn('init=', ' '.join(cmd))
    # list of files: chain idx selects the matching init file
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2],
        data=jdata,
        inits=[jinits1, jinits2],
        method_args=sampler_args,
    )
    cmd = cmdstan_args.compose_command(idx=0, csv_file='bern-output-1.csv')
    self.assertIn('bernoulli.init_1.json', ' '.join(cmd))
    cmd = cmdstan_args.compose_command(idx=1, csv_file='bern-output-1.csv')
    self.assertIn('bernoulli.init_2.json', ' '.join(cmd))
    # numeric inits are passed straight through as init=<value>
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        data=jdata,
        inits=0,
        method_args=sampler_args,
    )
    cmd = cmdstan_args.compose_command(idx=0, csv_file='bern-output-1.csv')
    self.assertIn('init=0', ' '.join(cmd))
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        data=jdata,
        inits=3.33,
        method_args=sampler_args,
    )
    cmd = cmdstan_args.compose_command(idx=0, csv_file='bern-output-1.csv')
    self.assertIn('init=3.33', ' '.join(cmd))
def test_validate_good_run(self):
    """Full 4-chain fit from saved CSVs: dimensions, drawset shape,
    summary smoke test, and a clean diagnose report."""
    # construct fit using existing sampler output
    exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    sampler_args = SamplerArgs(
        iter_sampling=100, max_treedepth=11, adapt_delta=0.95
    )
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=jdata,
        output_dir=DATAFILES_PATH,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args, chains=4)
    runset._csv_files = [
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-1.csv'),
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-2.csv'),
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-3.csv'),
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-4.csv'),
    ]
    self.assertEqual(4, runset.chains)
    retcodes = runset._retcodes
    for i in range(len(retcodes)):
        runset._set_retcode(i, 0)
    self.assertTrue(runset._check_retcodes())
    fit = CmdStanMCMC(runset)
    self.assertEqual(100, fit.num_draws)
    self.assertEqual(len(BERNOULLI_COLS), len(fit.column_names))
    self.assertEqual('lp__', fit.column_names[0])
    drawset = fit.get_drawset()
    self.assertEqual(
        drawset.shape,
        (fit.runset.chains * fit.num_draws, len(fit.column_names)),
    )
    # smoke test: summary() must not raise
    # (the vacuous assertTrue(True) that followed it has been removed)
    _ = fit.summary()
    # TODO - use cmdstan test files instead
    expected = '\n'.join([
        'Checking sampler transitions treedepth.',
        'Treedepth satisfactory for all transitions.',
        '\nChecking sampler transitions for divergences.',
        'No divergent transitions found.',
        '\nChecking E-BFMI - sampler transitions HMC potential energy.',
        'E-BFMI satisfactory for all transitions.',
        '\nEffective sample size satisfactory.',
    ])
    # normalize line endings so the comparison also passes on Windows
    self.assertIn(expected, fit.diagnose().replace('\r\n', '\n'))
def test_compose(self):
    """compose_command rejects chain indexes outside [0, chains)."""
    exe = os.path.join(datafiles_path, 'bernoulli')
    sampler_args = SamplerArgs()
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        method_args=sampler_args,
    )
    # one past the end, and a negative index
    for bad_idx in (4, -1):
        with self.assertRaises(ValueError):
            cmdstan_args.compose_command(idx=bad_idx, csv_file='foo')
def test_gen_quantities_good(self):
    """run_generated_quantities over a known-good sampler run yields
    per-chain CSVs with the y_rep columns and the right draw count."""
    stan = os.path.join(datafiles_path, 'bernoulli_ppc.stan')
    model = Model(stan_file=stan)
    model.compile()
    jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
    # synthesize stanfit object -
    # see test_stanfit.py, method 'test_validate_good_run'
    goodfiles_path = os.path.join(datafiles_path, 'runset-good')
    output = os.path.join(goodfiles_path, 'bern')
    sampler_args = SamplerArgs(
        sampling_iters=100, max_treedepth=11, adapt_delta=0.95
    )
    cmdstan_args = CmdStanArgs(
        model_name=model.name,
        model_exe=model.exe_file,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=jdata,
        output_basename=output,
        method_args=sampler_args,
    )
    sampler_fit = StanFit(args=cmdstan_args, chains=4)
    # mark all chains successful so csv_files can be consumed
    for i in range(4):
        sampler_fit._set_retcode(i, 0)
    bern_fit = model.run_generated_quantities(
        csv_files=sampler_fit.csv_files, data=jdata
    )
    # check results - output files, quantities of interest, draws
    self.assertEqual(bern_fit.chains, 4)
    for i in range(4):
        self.assertEqual(bern_fit._retcodes[i], 0)
        csv_file = bern_fit.csv_files[i]
        self.assertTrue(os.path.exists(csv_file))
    # generated quantities block declares y_rep[10]
    column_names = [
        'y_rep.1', 'y_rep.2', 'y_rep.3', 'y_rep.4', 'y_rep.5',
        'y_rep.6', 'y_rep.7', 'y_rep.8', 'y_rep.9', 'y_rep.10'
    ]
    self.assertEqual(bern_fit.column_names, tuple(column_names))
    self.assertEqual(bern_fit.draws, 100)
def test_commands(self):
    """RunSet builds one command per chain, each carrying its own id."""
    exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        data=jdata,
        method_args=SamplerArgs(),
    )
    runset = RunSet(args=args, chains=4)
    # spot-check the first and last chain commands
    self.assertIn('id=1', runset._cmds[0])
    self.assertIn('id=4', runset._cmds[3])
def test_validate_big_run(self):
    """Validate a 2095-parameter fit loaded from saved CSVs, including
    drawset selection by container name and by indexed element."""
    exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    sampler_args = SamplerArgs()
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2],
        seed=12345,
        output_dir=DATAFILES_PATH,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args, chains=2)
    # both chains point at the same CSV; only the content matters here
    runset._csv_files = [
        os.path.join(DATAFILES_PATH, 'runset-big', 'output_icar_nyc-1.csv'),
        os.path.join(DATAFILES_PATH, 'runset-big', 'output_icar_nyc-1.csv'),
    ]
    fit = CmdStanMCMC(runset)
    fit._validate_csv_files()
    # HMC sampler state columns precede the model parameters
    sampler_state = [
        'lp__',
        'accept_stat__',
        'stepsize__',
        'treedepth__',
        'n_leapfrog__',
        'divergent__',
        'energy__',
    ]
    # model declares parameters phi.1 ... phi.2095
    phis = ['phi.{}'.format(str(x + 1)) for x in range(2095)]
    column_names = sampler_state + phis
    self.assertEqual(fit.columns, len(column_names))
    self.assertEqual(fit.column_names, tuple(column_names))
    self.assertEqual(fit.metric_type, 'diag_e')
    self.assertEqual(fit.stepsize.shape, (2, ))
    self.assertEqual(fit.metric.shape, (2, 2095))
    self.assertEqual((1000, 2, 2102), fit.sample.shape)
    # select the whole container, then single/multiple elements
    phis = fit.get_drawset(params=['phi'])
    self.assertEqual((2000, 2095), phis.shape)
    phi1 = fit.get_drawset(params=['phi.1'])
    self.assertEqual((2000, 1), phi1.shape)
    mo_phis = fit.get_drawset(params=['phi.1', 'phi.10', 'phi.100'])
    self.assertEqual((2000, 3), mo_phis.shape)
    phi2095 = fit.get_drawset(params=['phi.2095'])
    self.assertEqual((2000, 1), phi2095.shape)
    # out-of-range element and unknown name must both fail
    with self.assertRaises(Exception):
        fit.get_drawset(params=['phi.2096'])
    with self.assertRaises(Exception):
        fit.get_drawset(params=['ph'])
def test_output_filenames(self):
    """Generated CSV names embed the model name and a chain-id suffix."""
    exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    config = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        data=jdata,
        method_args=SamplerArgs(),
    )
    runset = RunSet(args=config, chains=4)
    first_csv = runset._csv_files[0]
    self.assertIn('bernoulli-', first_csv)
    self.assertIn('_1.csv', first_csv)
    self.assertIn('_4.csv', runset._csv_files[3])
def test_ctor_checks(self):
    """RunSet rejects zero chains and a chain_ids list whose length
    doesn't match the chain count."""
    exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[11, 12, 13, 14],
        data=jdata,
        method_args=SamplerArgs(),
    )
    with self.assertRaises(ValueError):
        RunSet(args=args, chains=0)
    with self.assertRaises(ValueError):
        RunSet(args=args, chains=4, chain_ids=[1, 2, 3])
def test_args_sig_figs(self):
    """sig_figs is warned-and-ignored before CmdStan 2.25, passed
    through on 2.25+, and range-checked in both cases."""
    sampler_args = SamplerArgs()
    cmdstan_path()  # sets os.environ['CMDSTAN']
    if cmdstan_version_before(2, 25):
        # older CmdStan: the argument is ignored with a logged warning
        with LogCapture() as log:
            logging.getLogger()
            CmdStanArgs(
                model_name='bernoulli',
                model_exe='bernoulli.exe',
                chain_ids=[1, 2, 3, 4],
                sig_figs=12,
                method_args=sampler_args,
            )
        expect = (
            'Argument "sig_figs" invalid for CmdStan versions < 2.25, '
            'using version {} in directory {}'
        ).format(
            os.path.basename(cmdstan_path()),
            os.path.dirname(cmdstan_path()),
        )
        log.check_present(('cmdstanpy', 'WARNING', expect))
    else:
        # 2.25+: sig_figs appears in the composed command
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe='bernoulli.exe',
            chain_ids=[1, 2, 3, 4],
            sig_figs=12,
            method_args=sampler_args,
        )
        cmd = cmdstan_args.compose_command(
            idx=0, csv_file='bern-output-1.csv'
        )
        self.assertIn('sig_figs=', ' '.join(cmd))
    # negative and too-large values are rejected regardless of version
    with self.assertRaises(ValueError):
        CmdStanArgs(
            model_name='bernoulli',
            model_exe='bernoulli.exe',
            chain_ids=[1, 2, 3, 4],
            sig_figs=-1,
            method_args=sampler_args,
        )
    with self.assertRaises(ValueError):
        CmdStanArgs(
            model_name='bernoulli',
            model_exe='bernoulli.exe',
            chain_ids=[1, 2, 3, 4],
            sig_figs=20,
            method_args=sampler_args,
        )
def test_chain_ids(self):
    """User-supplied chain ids show up in both the per-chain commands
    and the CSV file name suffixes."""
    exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    ids = [11, 12, 13, 14]
    args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=ids,
        data=jdata,
        method_args=SamplerArgs(),
    )
    runset = RunSet(args=args, chains=4, chain_ids=ids)
    # spot-check the first and last chains
    self.assertIn('id=11', runset.cmd(0))
    self.assertIn('_11.csv', runset._csv_files[0])
    self.assertIn('id=14', runset.cmd(3))
    self.assertIn('_14.csv', runset._csv_files[3])
def test_validate_good_run(self):
    """End-to-end checks on a known-good 4-chain run (old StanFit API):
    retcodes, console messages, CSVs, drawset, summary, diagnose."""
    # construct fit using existing sampler output
    exe = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)
    jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
    output = os.path.join(goodfiles_path, 'bern')
    sampler_args = SamplerArgs(
        sampling_iters=100, max_treedepth=11, adapt_delta=0.95
    )
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=jdata,
        output_basename=output,
        method_args=sampler_args,
    )
    fit = StanFit(args=cmdstan_args, chains=4)
    # mark all chains successful so validation can proceed
    retcodes = fit._retcodes
    for i in range(len(retcodes)):
        fit._set_retcode(i, 0)
    self.assertTrue(fit._check_retcodes())
    fit._check_console_msgs()
    fit._validate_csv_files()
    self.assertEqual(4, fit.chains)
    self.assertEqual(100, fit.draws)
    self.assertEqual(8, len(fit.column_names))
    self.assertEqual('lp__', fit.column_names[0])
    df = fit.get_drawset()
    self.assertEqual(
        df.shape, (fit.chains * fit.draws, len(fit.column_names))
    )
    # smoke test: summary() must not raise
    _ = fit.summary()
    # TODO - use cmdstan test files instead
    expected = '\n'.join([
        'Checking sampler transitions treedepth.',
        'Treedepth satisfactory for all transitions.',
        '\nChecking sampler transitions for divergences.',
        'No divergent transitions found.',
        '\nChecking E-BFMI - sampler transitions HMC potential energy.',
        'E-BFMI satisfactory for all transitions.',
        '\nEffective sample size satisfactory.',
    ])
    # normalize line endings so the comparison also passes on Windows
    self.assertIn(expected, fit.diagnose().replace("\r\n", "\n"))
def test_validate_summary_sig_figs(self):
    """summary() honors sig_figs on CmdStan 2.25+ and rejects values
    outside the accepted range."""
    # construct CmdStanMCMC from logistic model output, config
    exe = os.path.join(DATAFILES_PATH, 'logistic' + EXTENSION)
    rdata = os.path.join(DATAFILES_PATH, 'logistic.data.R')
    sampler_args = SamplerArgs(iter_sampling=100)
    cmdstan_args = CmdStanArgs(
        model_name='logistic',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=rdata,
        output_dir=DATAFILES_PATH,
        sig_figs=17,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args)
    runset._csv_files = [
        os.path.join(DATAFILES_PATH, 'logistic_output_1.csv'),
        os.path.join(DATAFILES_PATH, 'logistic_output_2.csv'),
        os.path.join(DATAFILES_PATH, 'logistic_output_3.csv'),
        os.path.join(DATAFILES_PATH, 'logistic_output_4.csv'),
    ]
    retcodes = runset._retcodes
    for i in range(len(retcodes)):
        runset._set_retcode(i, 0)
    fit = CmdStanMCMC(runset)
    sum_default = fit.summary()
    # default precision: only the leading digits are stable
    beta1_default = format(sum_default.iloc[1, 0], '.18g')
    self.assertTrue(beta1_default.startswith('1.3'))
    if cmdstan_version_at(2, 25):
        # higher sig_figs yield progressively more stable digits
        sum_17 = fit.summary(sig_figs=17)
        beta1_17 = format(sum_17.iloc[1, 0], '.18g')
        self.assertTrue(beta1_17.startswith('1.345767078273'))
        sum_10 = fit.summary(sig_figs=10)
        beta1_10 = format(sum_10.iloc[1, 0], '.18g')
        self.assertTrue(beta1_10.startswith('1.34576707'))
    # out-of-range values are rejected
    with self.assertRaises(ValueError):
        fit.summary(sig_figs=20)
    with self.assertRaises(ValueError):
        fit.summary(sig_figs=-1)
def test_get_err_msgs(self):
    """Error messages are harvested from per-chain stdout transcripts."""
    exe = os.path.join(DATAFILES_PATH, 'logistic' + EXTENSION)
    rdata = os.path.join(DATAFILES_PATH, 'logistic.data.R')
    args = CmdStanArgs(
        model_name='logistic',
        model_exe=exe,
        chain_ids=[1, 2, 3],
        data=rdata,
        method_args=SamplerArgs(),
    )
    runset = RunSet(args=args, chains=3)
    for i in range(3):
        # mark each chain as failed (exit code 70) and attach a
        # canned transcript containing an exception message
        runset._set_retcode(i, 70)
        transcript = f'chain-{i + 1}-missing-data-stdout.txt'
        runset._stdout_files[i] = os.path.join(DATAFILES_PATH, transcript)
    errs = '\n\t'.join(runset._get_err_msgs())
    self.assertIn('Exception', errs)
def test_args_good(self):
    """compose_command includes sampler and refresh options; a missing
    output_dir is created by the CmdStanArgs constructor."""
    exe = os.path.join(DATAFILES_PATH, 'bernoulli')
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    sampler_args = SamplerArgs()
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        data=jdata,
        method_args=sampler_args,
        refresh=10,
    )
    self.assertEqual(cmdstan_args.method, Method.SAMPLE)
    cmd = cmdstan_args.compose_command(idx=0, csv_file='bern-output-1.csv')
    self.assertIn('id=1 random seed=', ' '.join(cmd))
    self.assertIn('data file=', ' '.join(cmd))
    self.assertIn('output file=', ' '.join(cmd))
    self.assertIn('method=sample algorithm=hmc', ' '.join(cmd))
    self.assertIn('refresh=10', ' '.join(cmd))
    # chain idx 0 picks up the first of the user-supplied chain ids
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[7, 11, 18, 29],
        data=jdata,
        method_args=sampler_args,
    )
    cmd = cmdstan_args.compose_command(idx=0, csv_file='bern-output-1.csv')
    self.assertIn('id=7 random seed=', ' '.join(cmd))
    # nonexistent output_dir is created as a side effect of the ctor;
    # timestamped name avoids collisions between test runs
    dirname = 'tmp' + str(time())
    if os.path.exists(dirname):
        os.rmdir(dirname)
    CmdStanArgs(
        model_name='bernoulli',
        model_exe='bernoulli.exe',
        chain_ids=[1, 2, 3, 4],
        output_dir=dirname,
        method_args=sampler_args,
    )
    self.assertTrue(os.path.exists(dirname))
    os.rmdir(dirname)
def test_check_repr(self):
    """RunSet repr reports chains, method, retcodes, and output files,
    and omits the diagnostics file when none was requested."""
    exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    sampler_args = SamplerArgs()
    chain_ids = [1, 2, 3, 4]
    # default
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=chain_ids,
        data=jdata,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args, chains=4)
    # repr(x) is the idiomatic spelling of x.__repr__(); compute it once
    rep = repr(runset)
    self.assertIn('RunSet: chains=4', rep)
    self.assertIn('method=sample', rep)
    self.assertIn('retcodes=[-1, -1, -1, -1]', rep)
    self.assertIn('csv_file', rep)
    self.assertIn('console_msgs', rep)
    self.assertNotIn('diagnostics_file', rep)
def test_check_retcodes(self):
    """Retcode bookkeeping on StanFit (old API): chains start at -1
    and _check_retcodes is True only when every chain reports 0."""
    exe = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)
    jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
    sampler_args = SamplerArgs()
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        data=jdata,
        method_args=sampler_args,
    )
    fit = StanFit(args=cmdstan_args, chains=4)
    retcodes = fit._retcodes
    self.assertEqual(4, len(retcodes))
    # all chains start out unfinished
    for i in range(len(retcodes)):
        self.assertEqual(-1, fit._retcode(i))
    fit._set_retcode(0, 0)
    self.assertEqual(0, fit._retcode(0))
    # one successful chain is not enough
    for i in range(1, len(retcodes)):
        self.assertEqual(-1, fit._retcode(i))
    self.assertFalse(fit._check_retcodes())
    for i in range(1, len(retcodes)):
        fit._set_retcode(i, 0)
    self.assertTrue(fit._check_retcodes())
def test_diagnose_divergences(self): exe = os.path.join(datafiles_path, 'bernoulli' + EXTENSION) # fake out validation output = os.path.join(datafiles_path, 'diagnose-good', 'corr_gauss_depth8') sampler_args = SamplerArgs() cmdstan_args = CmdStanArgs( model_name='bernoulli', model_exe=exe, chain_ids=[1], output_basename=output, method_args=sampler_args, ) fit = StanFit(args=cmdstan_args, chains=1) # TODO - use cmdstan test files instead expected = '\n'.join([ 'Checking sampler transitions treedepth.', '424 of 1000 (42%) transitions hit the maximum ' 'treedepth limit of 8, or 2^8 leapfrog steps.', 'Trajectories that are prematurely terminated ' 'due to this limit will result in slow exploration.', 'For optimal performance, increase this limit.', ]) self.assertIn(expected, fit.diagnose().replace("\r\n", "\n"))
def test_variables_3d(self):
    """stan_variable / stan_variables preserve 3-D container shapes:
    each result is shaped (num_draws, *dims)."""
    # construct fit using existing sampler output
    exe = os.path.join(DATAFILES_PATH, 'multidim_vars' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'logistic.data.R')
    sampler_args = SamplerArgs(iter_sampling=20)
    cmdstan_args = CmdStanArgs(
        model_name='multidim_vars',
        model_exe=exe,
        chain_ids=[1],
        seed=12345,
        data=jdata,
        output_dir=DATAFILES_PATH,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args, chains=1)
    runset._csv_files = [os.path.join(DATAFILES_PATH, 'multidim_vars.csv')]
    runset._set_retcode(0, 0)
    fit = CmdStanMCMC(runset)
    self.assertEqual(20, fit.num_draws_sampling)
    self.assertEqual(3, len(fit.stan_vars_dims))
    # assertIn gives a useful failure message, unlike assertTrue(x in y)
    self.assertIn('y_rep', fit.stan_vars_dims)
    self.assertEqual(fit.stan_vars_dims['y_rep'], (5, 4, 3))
    # single-variable accessor
    var_y_rep = fit.stan_variable(name='y_rep')
    self.assertEqual(var_y_rep.shape, (20, 5, 4, 3))
    var_beta = fit.stan_variable(name='beta')
    self.assertEqual(var_beta.shape, (20, 2))
    var_frac_60 = fit.stan_variable(name='frac_60')
    self.assertEqual(var_frac_60.shape, (20, ))
    # renamed from 'vars' to avoid shadowing the builtin vars()
    variables = fit.stan_variables()
    self.assertEqual(len(variables), len(fit.stan_vars_dims))
    self.assertIn('y_rep', variables)
    self.assertEqual(variables['y_rep'].shape, (20, 5, 4, 3))
    self.assertIn('beta', variables)
    self.assertEqual(variables['beta'].shape, (20, 2))
    self.assertIn('frac_60', variables)
    self.assertEqual(variables['frac_60'].shape, (20, ))
def test_validate_bad_run(self):
    """Bad runs are detected: error transcripts are collected, and
    inconsistent headers, bad draws, and truncated rows each raise
    ValueError from the CmdStanMCMC constructor."""
    exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    sampler_args = SamplerArgs(max_treedepth=11, adapt_delta=0.95)
    # some chains had errors
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=jdata,
        output_dir=DATAFILES_PATH,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args, chains=4)
    for i in range(4):
        runset._set_retcode(i, 0)
    self.assertTrue(runset._check_retcodes())
    # errors reported
    runset._stderr_files = [
        os.path.join(DATAFILES_PATH,
                     'runset-bad', 'bad-transcript-bern-1.txt'),
        os.path.join(DATAFILES_PATH,
                     'runset-bad', 'bad-transcript-bern-2.txt'),
        os.path.join(DATAFILES_PATH,
                     'runset-bad', 'bad-transcript-bern-3.txt'),
        os.path.join(DATAFILES_PATH,
                     'runset-bad', 'bad-transcript-bern-4.txt'),
    ]
    self.assertEqual(len(runset._get_err_msgs()), 4)
    # csv file headers inconsistent
    runset._csv_files = [
        os.path.join(DATAFILES_PATH,
                     'runset-bad', 'bad-hdr-bern-1.csv'),
        os.path.join(DATAFILES_PATH,
                     'runset-bad', 'bad-hdr-bern-2.csv'),
        os.path.join(DATAFILES_PATH,
                     'runset-bad', 'bad-hdr-bern-3.csv'),
        os.path.join(DATAFILES_PATH,
                     'runset-bad', 'bad-hdr-bern-4.csv'),
    ]
    with self.assertRaisesRegex(ValueError, 'header mismatch'):
        CmdStanMCMC(runset)
    # bad draws
    runset._csv_files = [
        os.path.join(DATAFILES_PATH,
                     'runset-bad', 'bad-draws-bern-1.csv'),
        os.path.join(DATAFILES_PATH,
                     'runset-bad', 'bad-draws-bern-2.csv'),
        os.path.join(DATAFILES_PATH,
                     'runset-bad', 'bad-draws-bern-3.csv'),
        os.path.join(DATAFILES_PATH,
                     'runset-bad', 'bad-draws-bern-4.csv'),
    ]
    with self.assertRaisesRegex(ValueError, 'draws'):
        CmdStanMCMC(runset)
    # mismatch - column headers, draws
    runset._csv_files = [
        os.path.join(DATAFILES_PATH,
                     'runset-bad', 'bad-cols-bern-1.csv'),
        os.path.join(DATAFILES_PATH,
                     'runset-bad', 'bad-cols-bern-2.csv'),
        os.path.join(DATAFILES_PATH,
                     'runset-bad', 'bad-cols-bern-3.csv'),
        os.path.join(DATAFILES_PATH,
                     'runset-bad', 'bad-cols-bern-4.csv'),
    ]
    with self.assertRaisesRegex(
        ValueError, 'bad draw, expecting 9 items, found 8'
    ):
        CmdStanMCMC(runset)
def test_validate_good_run(self):
    """Good 4-chain run: drawset shape, summary percentile handling
    (defaults vs user-specified vs invalid), and clean diagnose."""
    # construct fit using existing sampler output
    exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    sampler_args = SamplerArgs(
        iter_sampling=100, max_treedepth=11, adapt_delta=0.95
    )
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=jdata,
        output_dir=DATAFILES_PATH,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args)
    runset._csv_files = [
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-1.csv'),
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-2.csv'),
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-3.csv'),
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-4.csv'),
    ]
    self.assertEqual(4, runset.chains)
    retcodes = runset._retcodes
    for i in range(len(retcodes)):
        runset._set_retcode(i, 0)
    self.assertTrue(runset._check_retcodes())
    fit = CmdStanMCMC(runset)
    self.assertEqual(100, fit.num_draws)
    self.assertEqual(len(BERNOULLI_COLS), len(fit.column_names))
    self.assertEqual('lp__', fit.column_names[0])
    drawset = fit.get_drawset()
    self.assertEqual(
        drawset.shape,
        (fit.runset.chains * fit.num_draws, len(fit.column_names)),
    )
    # default percentiles are 5/50/95
    summary = fit.summary()
    self.assertIn('5%', list(summary.columns))
    self.assertIn('50%', list(summary.columns))
    self.assertIn('95%', list(summary.columns))
    self.assertNotIn('1%', list(summary.columns))
    self.assertNotIn('99%', list(summary.columns))
    # user-specified percentiles replace the defaults entirely
    summary = fit.summary(percentiles=[1, 45, 99])
    self.assertIn('1%', list(summary.columns))
    self.assertIn('45%', list(summary.columns))
    self.assertIn('99%', list(summary.columns))
    self.assertNotIn('5%', list(summary.columns))
    self.assertNotIn('50%', list(summary.columns))
    self.assertNotIn('95%', list(summary.columns))
    # empty and out-of-range percentile lists are rejected
    with self.assertRaises(ValueError):
        fit.summary(percentiles=[])
    with self.assertRaises(ValueError):
        fit.summary(percentiles=[-1])
    diagnostics = fit.diagnose()
    self.assertIn('Treedepth satisfactory for all transitions.',
                  diagnostics)
    self.assertIn('No divergent transitions found.', diagnostics)
    self.assertIn('E-BFMI satisfactory for all transitions.', diagnostics)
    self.assertIn('Effective sample size satisfactory.', diagnostics)
def test_args_chains(self):
    """validate() requires an explicit chain count."""
    with self.assertRaises(ValueError):
        SamplerArgs().validate(chains=None)
def test_bad(self):
    """Every invalid SamplerArgs configuration fails validate().

    The original sixteen copy-pasted construct/validate pairs are
    collapsed into a data-driven loop; subTest isolates each case so
    one failure no longer hides the remaining checks.
    """
    # (constructor kwargs, chain count to validate against)
    bad_cases = [
        ({'warmup_iters': -10}, 2),
        ({'warmup_iters': 0, 'adapt_engaged': True}, 2),
        ({'sampling_iters': -10}, 2),
        ({'thin': -10}, 2),
        ({'max_treedepth': -10}, 2),
        ({'step_size': -10}, 2),
        # per-chain step_size list longer than the chain count
        ({'step_size': [1.0, 1.1]}, 1),
        ({'step_size': [1.0, -1.1]}, 2),
        ({'adapt_delta': 1.1}, 2),
        ({'adapt_delta': -0.1}, 2),
        # adaptation / NUTS options conflict with fixed_param
        ({'warmup_iters': 100, 'fixed_param': True}, 2),
        ({'save_warmup': True, 'fixed_param': True}, 2),
        ({'max_treedepth': 12, 'fixed_param': True}, 2),
        ({'metric': 'dense', 'fixed_param': True}, 2),
        ({'step_size': 0.5, 'fixed_param': True}, 2),
        ({'adapt_delta': 0.88, 'fixed_param': True}, 2),
    ]
    for kwargs, chains in bad_cases:
        with self.subTest(kwargs=kwargs):
            args = SamplerArgs(**kwargs)
            with self.assertRaises(ValueError):
                args.validate(chains=chains)