def test_gen_quantities_good(self):
    """
    Run generated quantities from a synthesized sampler fit and check
    the return codes, output files, column names and draw count.
    """
    n_chains = 4

    ppc_program = os.path.join(datafiles_path, 'bernoulli_ppc.stan')
    model = Model(stan_file=ppc_program)
    model.compile()

    data_file = os.path.join(datafiles_path, 'bernoulli.data.json')

    # synthesize stanfit object -
    # see test_stanfit.py, method 'test_validate_good_run'
    goodfiles_path = os.path.join(datafiles_path, 'runset-good')
    output_base = os.path.join(goodfiles_path, 'bern')
    sampler_args = SamplerArgs(
        sampling_iters=100, max_treedepth=11, adapt_delta=0.95
    )
    cmdstan_args = CmdStanArgs(
        model_name=model.name,
        model_exe=model.exe_file,
        chain_ids=list(range(1, n_chains + 1)),
        seed=12345,
        data=data_file,
        output_basename=output_base,
        method_args=sampler_args,
    )
    sampler_fit = StanFit(args=cmdstan_args, chains=n_chains)
    # mark every chain as having exited with return code 0
    for chain in range(n_chains):
        sampler_fit._set_retcode(chain, 0)

    bern_fit = model.run_generated_quantities(
        csv_files=sampler_fit.csv_files, data=data_file
    )

    # check results - output files, quantities of interest, draws
    self.assertEqual(bern_fit.chains, n_chains)
    for chain in range(n_chains):
        self.assertEqual(bern_fit._retcodes[chain], 0)
        chain_csv = bern_fit.csv_files[chain]
        self.assertTrue(os.path.exists(chain_csv))

    # expected columns are y_rep.1 through y_rep.10
    expected_columns = tuple('y_rep.{}'.format(k) for k in range(1, 11))
    self.assertEqual(bern_fit.column_names, expected_columns)
    self.assertEqual(bern_fit.draws, 100)
def test_validate_good_run(self):
    """
    Build a StanFit over existing sampler output, validate it, and check
    the chains, draws, column names, drawset shape and diagnose text.
    """
    # construct fit using existing sampler output
    bernoulli_exe = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)
    data_file = os.path.join(datafiles_path, 'bernoulli.data.json')
    output_base = os.path.join(goodfiles_path, 'bern')
    sampler_args = SamplerArgs(
        sampling_iters=100, max_treedepth=11, adapt_delta=0.95
    )
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=bernoulli_exe,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=data_file,
        output_basename=output_base,
        method_args=sampler_args,
    )
    fit = StanFit(args=cmdstan_args, chains=4)

    # mark every chain as having exited with return code 0
    retcodes = fit._retcodes
    for chain in range(len(retcodes)):
        fit._set_retcode(chain, 0)
    self.assertTrue(fit._check_retcodes())

    fit._check_console_msgs()
    fit._validate_csv_files()
    self.assertEqual(4, fit.chains)
    self.assertEqual(100, fit.draws)
    self.assertEqual(8, len(fit.column_names))
    self.assertEqual('lp__', fit.column_names[0])

    # drawset has one row per draw per chain, one column per name
    drawset = fit.get_drawset()
    expected_shape = (fit.chains * fit.draws, len(fit.column_names))
    self.assertEqual(drawset.shape, expected_shape)

    # called for its side effects only; the result is not inspected
    fit.summary()

    # TODO - use cmdstan test files instead
    expected_lines = (
        'Checking sampler transitions treedepth.',
        'Treedepth satisfactory for all transitions.',
        '\nChecking sampler transitions for divergences.',
        'No divergent transitions found.',
        '\nChecking E-BFMI - sampler transitions HMC potential energy.',
        'E-BFMI satisfactory for all transitions.',
        '\nEffective sample size satisfactory.',
    )
    expected = '\n'.join(expected_lines)
    # normalize Windows line endings before comparing
    self.assertIn(expected, fit.diagnose().replace("\r\n", "\n"))
def _do_sample(self, stanfit: StanFit, idx: int) -> None:
    """
    Encapsulates call to sampler.
    Spawn process, capture console output to file, record returncode.

    :param stanfit: fit object supplying the command (``cmds[idx]``) and
        the console transcript path (``console_files[idx]``); receives
        the process return code via ``_set_retcode``.
    :param idx: zero-based chain index.
    """
    cmd = stanfit.cmds[idx]
    self._logger.info('start chain %u', idx + 1)
    self._logger.debug('sampling: %s', cmd)
    proc = subprocess.Popen(
        cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    # Do NOT call proc.wait() before communicate(): with stdout/stderr
    # attached to pipes, the child blocks once an OS pipe buffer fills
    # and wait() then never returns. communicate() drains both pipes
    # and waits for the process to exit, so returncode is set afterwards.
    stdout, stderr = proc.communicate()
    transcript_file = stanfit.console_files[idx]
    self._logger.info('finish chain %u', idx + 1)
    with open(transcript_file, 'w+') as transcript:
        if stdout:
            transcript.write(stdout.decode('utf-8'))
        if stderr:
            # anything on stderr is written after an 'ERROR' marker
            transcript.write('ERROR')
            transcript.write(stderr.decode('utf-8'))
    stanfit._set_retcode(idx, proc.returncode)
def test_check_retcodes(self):
    """
    Check that a new StanFit reports -1 for every chain and that
    _check_retcodes() passes only once all chains are set to 0.
    """
    bernoulli_exe = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)
    data_file = os.path.join(datafiles_path, 'bernoulli.data.json')
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=bernoulli_exe,
        chain_ids=[1, 2, 3, 4],
        data=data_file,
        method_args=SamplerArgs(),
    )
    fit = StanFit(args=cmdstan_args, chains=4)

    # every chain starts out with return code -1
    retcodes = fit._retcodes
    self.assertEqual(4, len(retcodes))
    for chain in range(len(retcodes)):
        self.assertEqual(-1, fit._retcode(chain))

    # set the first chain only; the others must remain at -1
    fit._set_retcode(0, 0)
    self.assertEqual(0, fit._retcode(0))
    remaining_chains = range(1, len(retcodes))
    for chain in remaining_chains:
        self.assertEqual(-1, fit._retcode(chain))
    self.assertFalse(fit._check_retcodes())

    # once the rest are set to 0 the overall check passes
    for chain in remaining_chains:
        fit._set_retcode(chain, 0)
    self.assertTrue(fit._check_retcodes())
def test_validate_bad_run(self):
    """
    Check that validation raises for each bad output set under
    badfiles_path: bad transcript, header mismatch, bad draws, bad columns.
    """
    bernoulli_exe = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)
    data_file = os.path.join(datafiles_path, 'bernoulli.data.json')
    sampler_args = SamplerArgs(
        sampling_iters=100, max_treedepth=11, adapt_delta=0.95
    )

    def build_fit(basename):
        # four-chain fit whose output files live under badfiles_path
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=bernoulli_exe,
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=data_file,
            output_basename=os.path.join(badfiles_path, basename),
            method_args=sampler_args,
        )
        return StanFit(args=cmdstan_args, chains=4)

    # some chains had errors
    fit = build_fit('bad-transcript-bern')
    with self.assertRaisesRegex(Exception, 'Exception'):
        fit._check_console_msgs()

    csv_cases = (
        # csv file headers inconsistent
        ('bad-hdr-bern', 'header mismatch'),
        # bad draws
        ('bad-draws-bern', 'draws'),
        # mismatch - column headers, draws
        ('bad-cols-bern', 'bad draw'),
    )
    for basename, error_regex in csv_cases:
        fit = build_fit(basename)
        # set all return codes to 0 so only the csv contents are at fault
        retcodes = fit._retcodes
        for chain in range(len(retcodes)):
            fit._set_retcode(chain, 0)
        self.assertTrue(fit._check_retcodes())
        with self.assertRaisesRegex(ValueError, error_regex):
            fit._validate_csv_files()