def test_validate_outputs(self):
    """
    Build a four-chain RunSet over the existing 'bern' sampler output in
    goodfiles_path, mark every chain as returned 0, then check that the
    console and CSV validation steps run and report the expected chain
    count, draw count and column names.
    """
    # construct runset using existing sampler output
    stan_path = os.path.join(datafiles_path, 'bernoulli.stan')
    exe_path = os.path.join(datafiles_path, 'bernoulli')
    data_path = os.path.join(datafiles_path, 'bernoulli.data.json')
    output_base = os.path.join(goodfiles_path, 'bern')
    bernoulli_model = Model(exe_file=exe_path, stan_file=stan_path)
    sampler_args = SamplerArgs(
        bernoulli_model,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=data_path,
        output_file=output_base,
        sampling_iters=100,
        max_treedepth=11,
        adapt_delta=0.95,
    )
    runset = RunSet(chains=4, args=sampler_args)

    # no sampler process is run here, so record a successful exit per chain
    num_chains = len(runset.retcodes)
    for chain_idx in range(num_chains):
        runset.set_retcode(chain_idx, 0)
    self.assertTrue(runset.check_retcodes())

    runset.check_console_msgs()
    runset.validate_csv_files()

    self.assertEqual(4, runset.chains)
    self.assertEqual(100, runset.draws)
    column_names = runset.column_names
    self.assertEqual(8, len(column_names))
    self.assertEqual('lp__', column_names[0])
def test_check_retcodes(self):
    """
    Check return-code bookkeeping on a four-chain RunSet: every chain
    starts at -1, check_retcodes() is False while any chain is still -1,
    and it becomes True once every chain has been set to 0.
    """
    stan_path = os.path.join(datafiles_path, 'bernoulli.stan')
    exe_path = os.path.join(datafiles_path, 'bernoulli')
    data_path = os.path.join(datafiles_path, 'bernoulli.data.json')
    output_base = os.path.join(goodfiles_path, 'bern')
    bernoulli_model = Model(exe_file=exe_path, stan_file=stan_path)
    sampler_args = SamplerArgs(
        bernoulli_model,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=data_path,
        output_file=output_base,
        sampling_iters=100,
        max_treedepth=11,
        adapt_delta=0.95,
    )
    runset = RunSet(chains=4, args=sampler_args)

    initial_codes = runset.retcodes
    self.assertEqual(4, len(initial_codes))
    every_chain = range(len(initial_codes))
    remaining_chains = range(1, len(initial_codes))

    # a fresh runset reports -1 for every chain
    for chain_idx in every_chain:
        self.assertEqual(-1, runset.retcode(chain_idx))

    # setting only the first chain leaves the others untouched
    runset.set_retcode(0, 0)
    self.assertEqual(0, runset.retcode(0))
    for chain_idx in remaining_chains:
        self.assertEqual(-1, runset.retcode(chain_idx))
    self.assertFalse(runset.check_retcodes())

    # once the rest are set to 0 the overall check passes
    for chain_idx in remaining_chains:
        runset.set_retcode(chain_idx, 0)
    self.assertTrue(runset.check_retcodes())
def test_validate_bad_hdr(self):
    """
    Point a four-chain RunSet at the 'bad-hdr-bern' output in
    badfiles_path and check that validate_csv_files() raises a
    ValueError whose message matches 'header mismatch'.
    """
    stan_path = os.path.join(datafiles_path, 'bernoulli.stan')
    exe_path = os.path.join(datafiles_path, 'bernoulli')
    data_path = os.path.join(datafiles_path, 'bernoulli.data.json')
    output_base = os.path.join(badfiles_path, 'bad-hdr-bern')
    bernoulli_model = Model(exe_file=exe_path, stan_file=stan_path)
    sampler_args = SamplerArgs(
        bernoulli_model,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=data_path,
        output_file=output_base,
        sampling_iters=100,
        max_treedepth=11,
        adapt_delta=0.95,
    )
    runset = RunSet(chains=4, args=sampler_args)

    # return codes are all set to 0 so only the CSV check can fail
    num_chains = len(runset.retcodes)
    for chain_idx in range(num_chains):
        runset.set_retcode(chain_idx, 0)
    self.assertTrue(runset.check_retcodes())

    with self.assertRaisesRegex(ValueError, 'header mismatch'):
        runset.validate_csv_files()
def do_sample(runset: RunSet, idx: int) -> None:
    """
    Encapsulates call to sampler.
    Spawn process, capture console output to file, record returncode.

    :param runset: object supplying ``cmds`` (command strings),
        ``console_files`` (transcript paths) and ``set_retcode``.
    :param idx: zero-based chain index into ``runset.cmds`` and
        ``runset.console_files``; printed progress messages use ``idx + 1``.
    """
    cmd = runset.cmds[idx]
    print('start chain {}. '.format(idx + 1))
    # NOTE(review): whitespace split means an argument containing spaces
    # (e.g. a path) is broken into several argv entries - confirm callers
    # never build such commands.
    proc = subprocess.Popen(
        cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    # communicate() reads both pipes until EOF and then waits for exit.
    # Calling wait() first, as before, deadlocks as soon as the child writes
    # more than the OS pipe buffer holds, because nobody is draining it.
    stdout, stderr = proc.communicate()
    transcript_file = runset.console_files[idx]
    print('finish chain {}. '.format(idx + 1))
    with open(transcript_file, 'w+') as transcript:
        if stdout:
            transcript.write(stdout.decode('utf-8'))
        if stderr:
            # stderr text is appended after an 'ERROR' marker
            transcript.write('ERROR')
            transcript.write(stderr.decode('utf-8'))
    runset.set_retcode(idx, proc.returncode)