def test_variables(self):
    # Checks the Stan variable dimensions and the per-variable draw arrays
    # that CmdStanMCMC reports for the lotka-volterra sampler output.

    # construct fit using existing sampler output
    exe = os.path.join(DATAFILES_PATH, 'lotka-volterra' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'lotka-volterra.data.json')
    sampler_args = SamplerArgs(iter_sampling=20)
    cmdstan_args = CmdStanArgs(
        model_name='lotka-volterra',
        model_exe=exe,
        chain_ids=[1],
        seed=12345,
        data=jdata,
        output_dir=DATAFILES_PATH,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args, chains=1)
    runset._csv_files = [
        os.path.join(DATAFILES_PATH, 'lotka-volterra.csv')
    ]
    runset._set_retcode(0, 0)
    fit = CmdStanMCMC(runset)

    self.assertEqual(20, fit.num_draws)
    self.assertEqual(8, len(fit._stan_variable_dims))
    # assertIn reports both operands on failure, unlike assertTrue(x in y)
    self.assertIn('z', fit._stan_variable_dims)
    self.assertEqual(fit._stan_variable_dims['z'], (20, 2))

    # named stan_vars rather than vars to avoid shadowing the builtin
    stan_vars = fit.stan_variables()
    self.assertEqual(len(stan_vars), len(fit._stan_variable_dims))
    self.assertIn('z', stan_vars)
    self.assertEqual(stan_vars['z'].shape, (20, 20, 2))
    self.assertIn('theta', stan_vars)
    self.assertEqual(stan_vars['theta'].shape, (20, 4))
def test_check_retcodes(self):
    # Walks a 4-chain RunSet through its return-code states: all chains
    # start at -1, _check_retcodes() is False while any chain is still -1,
    # and True once every chain has been set to 0.
    chain_ids = [1, 2, 3, 4]  # default
    model_exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    data_file = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=model_exe,
        chain_ids=chain_ids,
        data=data_file,
        method_args=SamplerArgs(),
    )
    runset = RunSet(args=cmdstan_args, chains=4)

    num_chains = len(runset._retcodes)
    self.assertEqual(4, num_chains)
    for chain_idx in range(num_chains):
        self.assertEqual(-1, runset._retcode(chain_idx))

    # only the first chain has finished
    runset._set_retcode(0, 0)
    self.assertEqual(0, runset._retcode(0))
    remaining = range(1, num_chains)
    for chain_idx in remaining:
        self.assertEqual(-1, runset._retcode(chain_idx))
    self.assertFalse(runset._check_retcodes())

    # all chains have finished
    for chain_idx in remaining:
        runset._set_retcode(chain_idx, 0)
    self.assertTrue(runset._check_retcodes())
def _run_cmdstan(self, runset: RunSet, idx: int = 0, pbar: Any = None) -> None:
    """
    Run one CmdStan command in a child process.

    Starts ``runset.cmds[idx]`` with stdout and stderr piped, waits for it
    to exit, writes non-empty stdout and stderr to
    ``runset.stdout_files[idx]`` and ``runset.stderr_files[idx]``
    respectively, and records the return code via ``runset._set_retcode``.

    :param runset: object supplying the command list, the output file
        paths and the return-code store.
    :param idx: index into the runset's per-chain lists; logged as
        chain ``idx + 1``.
    :param pbar: if truthy, ``self._read_progress`` is called with it
        before waiting on the process, and the value it returns is
        prepended to the captured stdout.
    """
    command = runset.cmds[idx]
    chain_no = idx + 1
    self._logger.info('start chain %u', chain_no)
    self._logger.debug(
        'threads: %s', str(os.environ.get('STAN_NUM_THREADS'))
    )
    self._logger.debug('sampling: %s', command)
    proc = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=os.environ,
    )
    if pbar:
        progress_output = self._read_progress(proc, pbar, idx)
    stdout, stderr = proc.communicate()
    if pbar:
        # whatever the progress reader consumed comes first
        stdout = progress_output + stdout
    self._logger.info('finish chain %u', chain_no)

    # Output files are only created for streams that produced something.
    if stdout:
        with open(runset.stdout_files[idx], 'w+') as out_fd:
            out_fd.write(stdout.decode('utf-8'))
    if stderr:
        with open(runset.stderr_files[idx], 'w+') as err_fd:
            err_fd.write(stderr.decode('utf-8'))
    runset._set_retcode(idx, proc.returncode)
def _run_cmdstan(
    self, runset: RunSet, idx: int = 0, pbar: List[Any] = None
) -> None:
    """
    Run one CmdStan command in a child process and save a transcript.

    Starts ``runset.cmds[idx]`` with stdout and stderr piped, waits for it
    to exit, writes the captured output to ``runset.console_files[idx]``
    and records the return code via ``runset._set_retcode``. Captured
    stderr, if any, is written after the literal marker ``ERROR``.

    :param runset: object supplying the command list, the console file
        paths and the return-code store.
    :param idx: index into the runset's per-chain lists; logged as
        chain ``idx + 1``.
    :param pbar: if truthy, ``self._read_progress`` is called with it
        before waiting on the process, and the value it returns is
        prepended to the captured stdout.
    """
    command = runset.cmds[idx]
    chain_no = idx + 1
    self._logger.info('start chain %u', chain_no)
    self._logger.debug('sampling: %s', command)
    proc = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=os.environ,
    )
    if pbar:
        progress_output = self._read_progress(proc, pbar, idx)
    stdout, stderr = proc.communicate()
    if pbar:
        # whatever the progress reader consumed comes first
        stdout = progress_output + stdout
    console_path = runset.console_files[idx]
    self._logger.info('finish chain %u', chain_no)

    # The transcript file is always created, even when nothing was captured.
    with open(console_path, 'w+') as console:
        if stdout:
            console.write(stdout.decode('utf-8'))
        if stderr:
            console.write('ERROR')
            console.write(stderr.decode('utf-8'))
    runset._set_retcode(idx, proc.returncode)
def test_good(self):
    # Builds an InferenceMetadata object from the configuration returned by
    # check_sampler_csv for a bernoulli CSV file, then checks its string
    # form, its config and its sampler / Stan variable maps.

    # construct fit using existing sampler output
    exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    sampler_args = SamplerArgs(
        iter_sampling=100, max_treedepth=11, adapt_delta=0.95
    )
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=jdata,
        output_dir=DATAFILES_PATH,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args)
    runset._csv_files = [
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-1.csv'),
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-2.csv'),
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-3.csv'),
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-4.csv'),
    ]
    num_chains = len(runset._retcodes)
    for chain_idx in range(num_chains):
        runset._set_retcode(chain_idx, 0)

    # The config is read from the last chain's CSV file. The index is
    # spelled out here instead of reusing the loop variable after the
    # loop has ended.
    config = check_sampler_csv(
        path=runset.csv_files[num_chains - 1],
        is_fixed_param=False,
        iter_sampling=100,
        iter_warmup=1000,
        save_warmup=False,
        thin=1,
    )
    expected = 'Metadata:\n{}\n'.format(config)
    metadata = InferenceMetadata(config)
    actual = '{}'.format(metadata)
    self.assertEqual(expected, actual)
    self.assertEqual(config, metadata.cmdstan_config)

    hmc_vars = {
        'lp__',
        'accept_stat__',
        'stepsize__',
        'treedepth__',
        'n_leapfrog__',
        'divergent__',
        'energy__',
    }
    sampler_vars_cols = metadata.sampler_vars_cols
    self.assertEqual(hmc_vars, sampler_vars_cols.keys())

    bern_model_vars = {'theta'}
    self.assertEqual(bern_model_vars, metadata.stan_vars_dims.keys())
    self.assertEqual((), metadata.stan_vars_dims['theta'])
    self.assertEqual(bern_model_vars, metadata.stan_vars_cols.keys())
    self.assertEqual((7,), metadata.stan_vars_cols['theta'])
def test_validate_good_run(self):
    # Builds a CmdStanMCMC fit from four existing bernoulli CSV files and
    # checks draw counts, column names, the drawset shape, and the text
    # returned by diagnose().

    # construct fit using existing sampler output
    exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    sampler_args = SamplerArgs(
        iter_sampling=100, max_treedepth=11, adapt_delta=0.95
    )
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=jdata,
        output_dir=DATAFILES_PATH,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args, chains=4)
    runset._csv_files = [
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-1.csv'),
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-2.csv'),
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-3.csv'),
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-4.csv'),
    ]
    self.assertEqual(4, runset.chains)
    for chain_idx in range(len(runset._retcodes)):
        runset._set_retcode(chain_idx, 0)
    self.assertTrue(runset._check_retcodes())

    fit = CmdStanMCMC(runset)
    self.assertEqual(100, fit.num_draws)
    self.assertEqual(len(BERNOULLI_COLS), len(fit.column_names))
    self.assertEqual('lp__', fit.column_names[0])

    drawset = fit.get_drawset()
    self.assertEqual(
        drawset.shape,
        (fit.runset.chains * fit.num_draws, len(fit.column_names)),
    )

    # Smoke check only: the call must complete without raising. The
    # result is not inspected, so no assertion follows it.
    fit.summary()

    # TODO - use cmdstan test files instead
    expected = '\n'.join([
        'Checking sampler transitions treedepth.',
        'Treedepth satisfactory for all transitions.',
        '\nChecking sampler transitions for divergences.',
        'No divergent transitions found.',
        '\nChecking E-BFMI - sampler transitions HMC potential energy.',
        'E-BFMI satisfactory for all transitions.',
        '\nEffective sample size satisfactory.',
    ])
    # normalise Windows line endings before comparing
    self.assertIn(expected, fit.diagnose().replace('\r\n', '\n'))
def test_validate_summary_sig_figs(self):
    # Checks the precision of a value in the summary table for the default
    # and for explicit sig_figs settings, and that out-of-range sig_figs
    # values raise ValueError.

    def beta1_text(summary):
        # summary cell at row 1, column 0, rendered with 18 significant
        # digits
        return format(summary.iloc[1, 0], '.18g')

    # construct CmdStanMCMC from logistic model output, config
    model_exe = os.path.join(DATAFILES_PATH, 'logistic' + EXTENSION)
    data_file = os.path.join(DATAFILES_PATH, 'logistic.data.R')
    cmdstan_args = CmdStanArgs(
        model_name='logistic',
        model_exe=model_exe,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=data_file,
        output_dir=DATAFILES_PATH,
        sig_figs=17,
        method_args=SamplerArgs(iter_sampling=100),
    )
    runset = RunSet(args=cmdstan_args)
    runset._csv_files = [
        os.path.join(DATAFILES_PATH, 'logistic_output_1.csv'),
        os.path.join(DATAFILES_PATH, 'logistic_output_2.csv'),
        os.path.join(DATAFILES_PATH, 'logistic_output_3.csv'),
        os.path.join(DATAFILES_PATH, 'logistic_output_4.csv'),
    ]
    num_chains = len(runset._retcodes)
    for chain_idx in range(num_chains):
        runset._set_retcode(chain_idx, 0)
    fit = CmdStanMCMC(runset)

    self.assertTrue(beta1_text(fit.summary()).startswith('1.3'))

    # the explicit-precision checks are gated on the CmdStan version
    if cmdstan_version_at(2, 25):
        precision_cases = (
            (17, '1.345767078273'),
            (10, '1.34576707'),
        )
        for sig_figs, prefix in precision_cases:
            rendered = beta1_text(fit.summary(sig_figs=sig_figs))
            self.assertTrue(rendered.startswith(prefix))

    for bad_sig_figs in (20, -1):
        with self.assertRaises(ValueError):
            fit.summary(sig_figs=bad_sig_figs)
def test_get_err_msgs(self):
    # Points a 3-chain RunSet with non-zero return codes at stored stdout
    # files and checks that the collected error messages mention
    # 'Exception'.
    chain_ids = [1, 2, 3]
    model_exe = os.path.join(DATAFILES_PATH, 'logistic' + EXTENSION)
    data_file = os.path.join(DATAFILES_PATH, 'logistic.data.R')
    cmdstan_args = CmdStanArgs(
        model_name='logistic',
        model_exe=model_exe,
        chain_ids=chain_ids,
        data=data_file,
        method_args=SamplerArgs(),
    )
    runset = RunSet(args=cmdstan_args, chains=3)
    for chain_idx, chain_id in enumerate(chain_ids):
        runset._set_retcode(chain_idx, 70)
        file_name = 'chain-' + str(chain_id) + '-missing-data-stdout.txt'
        runset._stdout_files[chain_idx] = os.path.join(
            DATAFILES_PATH, file_name
        )
    joined_msgs = '\n\t'.join(runset._get_err_msgs())
    self.assertIn('Exception', joined_msgs)
def test_variables_3d(self):
    # Checks the dimensions and draw-array shapes reported for the
    # multidim_vars sampler output, both per variable (stan_variable) and
    # for the full mapping (stan_variables).

    # construct fit using existing sampler output
    exe = os.path.join(DATAFILES_PATH, 'multidim_vars' + EXTENSION)
    # the data file is in R dump format, hence data_file rather than jdata
    data_file = os.path.join(DATAFILES_PATH, 'logistic.data.R')
    sampler_args = SamplerArgs(iter_sampling=20)
    cmdstan_args = CmdStanArgs(
        model_name='multidim_vars',
        model_exe=exe,
        chain_ids=[1],
        seed=12345,
        data=data_file,
        output_dir=DATAFILES_PATH,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args, chains=1)
    runset._csv_files = [os.path.join(DATAFILES_PATH, 'multidim_vars.csv')]
    runset._set_retcode(0, 0)
    fit = CmdStanMCMC(runset)

    self.assertEqual(20, fit.num_draws_sampling)
    self.assertEqual(3, len(fit.stan_vars_dims))
    # assertIn reports both operands on failure, unlike assertTrue(x in y)
    self.assertIn('y_rep', fit.stan_vars_dims)
    self.assertEqual(fit.stan_vars_dims['y_rep'], (5, 4, 3))

    var_y_rep = fit.stan_variable(name='y_rep')
    self.assertEqual(var_y_rep.shape, (20, 5, 4, 3))
    var_beta = fit.stan_variable(name='beta')
    self.assertEqual(var_beta.shape, (20, 2))
    var_frac_60 = fit.stan_variable(name='frac_60')
    self.assertEqual(var_frac_60.shape, (20,))

    # named stan_vars rather than vars to avoid shadowing the builtin
    stan_vars = fit.stan_variables()
    self.assertEqual(len(stan_vars), len(fit.stan_vars_dims))
    self.assertIn('y_rep', stan_vars)
    self.assertEqual(stan_vars['y_rep'].shape, (20, 5, 4, 3))
    self.assertIn('beta', stan_vars)
    self.assertEqual(stan_vars['beta'].shape, (20, 2))
    self.assertIn('frac_60', stan_vars)
    self.assertEqual(stan_vars['frac_60'].shape, (20,))
def test_validate_bad_run(self):
    # Feeds a 4-chain RunSet a series of deliberately bad fixture files and
    # checks that error messages are collected from the stderr files and
    # that CmdStanMCMC rejects each set of malformed CSV files.

    def in_bad_dir(*file_names):
        # full paths for files under the 'runset-bad' fixture directory
        return [
            os.path.join(DATAFILES_PATH, 'runset-bad', file_name)
            for file_name in file_names
        ]

    model_exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    data_file = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    sampler_args = SamplerArgs(max_treedepth=11, adapt_delta=0.95)

    # some chains had errors
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=model_exe,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=data_file,
        output_dir=DATAFILES_PATH,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args, chains=4)
    for chain_idx in range(4):
        runset._set_retcode(chain_idx, 0)
    self.assertTrue(runset._check_retcodes())

    # errors reported
    runset._stderr_files = in_bad_dir(
        'bad-transcript-bern-1.txt',
        'bad-transcript-bern-2.txt',
        'bad-transcript-bern-3.txt',
        'bad-transcript-bern-4.txt',
    )
    self.assertEqual(len(runset._get_err_msgs()), 4)

    # Each entry: (regex expected in the ValueError, CSV fixture names).
    csv_cases = (
        # csv file headers inconsistent
        (
            'header mismatch',
            (
                'bad-hdr-bern-1.csv',
                'bad-hdr-bern-2.csv',
                'bad-hdr-bern-3.csv',
                'bad-hdr-bern-4.csv',
            ),
        ),
        # bad draws
        (
            'draws',
            (
                'bad-draws-bern-1.csv',
                'bad-draws-bern-2.csv',
                'bad-draws-bern-3.csv',
                'bad-draws-bern-4.csv',
            ),
        ),
        # mismatch - column headers, draws
        (
            'bad draw, expecting 9 items, found 8',
            (
                'bad-cols-bern-1.csv',
                'bad-cols-bern-2.csv',
                'bad-cols-bern-3.csv',
                'bad-cols-bern-4.csv',
            ),
        ),
    )
    for error_regex, csv_names in csv_cases:
        runset._csv_files = in_bad_dir(*csv_names)
        with self.assertRaisesRegex(ValueError, error_regex):
            CmdStanMCMC(runset)
def test_validate_good_run(self):
    # Builds a CmdStanMCMC fit from four existing bernoulli CSV files and
    # checks draw counts, column names, the drawset shape, the percentile
    # columns of summary(), and the messages returned by diagnose().

    # construct fit using existing sampler output
    model_exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    data_file = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    sampler_args = SamplerArgs(
        iter_sampling=100, max_treedepth=11, adapt_delta=0.95
    )
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=model_exe,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=data_file,
        output_dir=DATAFILES_PATH,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args)
    good_dir = os.path.join(DATAFILES_PATH, 'runset-good')
    runset._csv_files = [
        os.path.join(good_dir, csv_name)
        for csv_name in (
            'bern-1.csv',
            'bern-2.csv',
            'bern-3.csv',
            'bern-4.csv',
        )
    ]
    self.assertEqual(4, runset.chains)
    num_chains = len(runset._retcodes)
    for chain_idx in range(num_chains):
        runset._set_retcode(chain_idx, 0)
    self.assertTrue(runset._check_retcodes())

    fit = CmdStanMCMC(runset)
    self.assertEqual(100, fit.num_draws)
    self.assertEqual(len(BERNOULLI_COLS), len(fit.column_names))
    self.assertEqual('lp__', fit.column_names[0])

    drawset = fit.get_drawset()
    expected_shape = (
        fit.runset.chains * fit.num_draws,
        len(fit.column_names),
    )
    self.assertEqual(drawset.shape, expected_shape)

    # default percentile columns
    default_columns = list(fit.summary().columns)
    for label in ('5%', '50%', '95%'):
        self.assertIn(label, default_columns)
    for label in ('1%', '99%'):
        self.assertNotIn(label, default_columns)

    # requested percentiles replace the defaults
    custom_columns = list(fit.summary(percentiles=[1, 45, 99]).columns)
    for label in ('1%', '45%', '99%'):
        self.assertIn(label, custom_columns)
    for label in ('5%', '50%', '95%'):
        self.assertNotIn(label, custom_columns)

    # invalid percentile lists are rejected
    for bad_percentiles in ([], [-1]):
        with self.assertRaises(ValueError):
            fit.summary(percentiles=bad_percentiles)

    diagnostics = fit.diagnose()
    expected_messages = (
        'Treedepth satisfactory for all transitions.',
        'No divergent transitions found.',
        'E-BFMI satisfactory for all transitions.',
        'Effective sample size satisfactory.',
    )
    for message in expected_messages:
        self.assertIn(message, diagnostics)
def test_validate_bad_run(self):
    # Builds 4-chain RunSets whose output basenames point at deliberately
    # bad fixture files, then checks that the console-message check and
    # CSV validation raise the expected errors.
    exe = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)
    jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
    sampler_args = SamplerArgs(
        sampling_iters=100, max_treedepth=11, adapt_delta=0.95
    )

    def make_runset(basename):
        # 4-chain RunSet whose output basename lives under badfiles_path
        output = os.path.join(badfiles_path, basename)
        bad_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=exe,
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=jdata,
            output_basename=output,
            method_args=sampler_args,
        )
        return RunSet(args=bad_args, chains=4)

    # some chains had errors
    runset = make_runset('bad-transcript-bern')
    with self.assertRaisesRegex(Exception, 'Exception'):
        runset._check_console_msgs()

    # Each entry: (output basename, regex expected in the ValueError).
    csv_cases = (
        # csv file headers inconsistent
        ('bad-hdr-bern', 'header mismatch'),
        # bad draws
        ('bad-draws-bern', 'draws'),
        # mismatch - column headers, draws
        ('bad-cols-bern', 'bad draw'),
    )
    for basename, error_regex in csv_cases:
        runset = make_runset(basename)
        num_chains = len(runset._retcodes)
        for chain_idx in range(num_chains):
            runset._set_retcode(chain_idx, 0)
        self.assertTrue(runset._check_retcodes())
        fit = CmdStanMCMC(runset)
        with self.assertRaisesRegex(ValueError, error_regex):
            fit._validate_csv_files()
def test_metadata(self):
    # Checks the metadata that CmdStanMCMC exposes for the logistic model
    # output: chain and draw counts, column names, sampler configuration,
    # and the sampler / Stan variable maps.

    # construct CmdStanMCMC from logistic model output, config
    exe = os.path.join(DATAFILES_PATH, 'logistic' + EXTENSION)
    rdata = os.path.join(DATAFILES_PATH, 'logistic.data.R')
    sampler_args = SamplerArgs(iter_sampling=100)
    cmdstan_args = CmdStanArgs(
        model_name='logistic',
        model_exe=exe,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=rdata,
        output_dir=DATAFILES_PATH,
        sig_figs=17,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args)
    runset._csv_files = [
        os.path.join(DATAFILES_PATH, 'logistic_output_1.csv'),
        os.path.join(DATAFILES_PATH, 'logistic_output_2.csv'),
        os.path.join(DATAFILES_PATH, 'logistic_output_3.csv'),
        os.path.join(DATAFILES_PATH, 'logistic_output_4.csv'),
    ]
    for chain_idx in range(len(runset._retcodes)):
        runset._set_retcode(chain_idx, 0)
    fit = CmdStanMCMC(runset)

    col_names = (
        'lp__',
        'accept_stat__',
        'stepsize__',
        'treedepth__',
        'n_leapfrog__',
        'divergent__',
        'energy__',
        'beta[1]',
        'beta[2]',
    )

    self.assertEqual(fit.chains, 4)
    self.assertEqual(fit.chain_ids, [1, 2, 3, 4])
    self.assertEqual(fit.num_draws_warmup, 1000)
    self.assertEqual(fit.num_draws_sampling, 100)
    self.assertEqual(fit.column_names, col_names)
    self.assertEqual(fit.num_unconstrained_params, 2)
    self.assertEqual(fit.metric_type, 'diag_e')

    self.assertEqual(fit.sampler_config['num_samples'], 100)
    self.assertEqual(fit.sampler_config['thin'], 1)
    self.assertEqual(fit.sampler_config['algorithm'], 'hmc')
    self.assertEqual(fit.sampler_config['metric'], 'diag_e')
    self.assertAlmostEqual(fit.sampler_config['delta'], 0.80)

    # assertIn / assertNotIn report both operands on failure, unlike
    # assertTrue(x in y)
    self.assertIn('n_leapfrog__', fit.sampler_vars_cols)
    self.assertIn('energy__', fit.sampler_vars_cols)
    self.assertNotIn('beta', fit.sampler_vars_cols)

    self.assertNotIn('energy__', fit.stan_vars_dims)
    self.assertIn('beta', fit.stan_vars_dims)
    self.assertIn('beta', fit.stan_vars_cols)
    self.assertEqual(fit.stan_vars_dims['beta'], (2,))
    self.assertEqual(fit.stan_vars_cols['beta'], (7, 8))