Beispiel #1
0
 def test_validate_outputs(self):
     # Build a four-chain RunSet whose output prefix ('bern') lives under
     # goodfiles_path, i.e. it reuses sampler output that already exists.
     model = Model(
         exe_file=os.path.join(datafiles_path, 'bernoulli'),
         stan_file=os.path.join(datafiles_path, 'bernoulli.stan'),
     )
     sampler_args = SamplerArgs(
         model,
         chain_ids=[1, 2, 3, 4],
         seed=12345,
         data=os.path.join(datafiles_path, 'bernoulli.data.json'),
         output_file=os.path.join(goodfiles_path, 'bern'),
         sampling_iters=100,
         max_treedepth=11,
         adapt_delta=0.95,
     )
     run = RunSet(chains=4, args=sampler_args)

     # Set every chain's return code to 0 so that check_retcodes() passes.
     n_chains = len(run.retcodes)
     for chain in range(n_chains):
         run.set_retcode(chain, 0)
     self.assertTrue(run.check_retcodes())

     # Run the console-message and CSV checks; the CSV check must come
     # before the assertions on chains, draws and column names below.
     run.check_console_msgs()
     run.validate_csv_files()

     self.assertEqual(4, run.chains)
     self.assertEqual(100, run.draws)
     self.assertEqual(8, len(run.column_names))
     self.assertEqual('lp__', run.column_names[0])
Beispiel #2
0
 def test_check_retcodes(self):
     # Exercise retcode(), set_retcode() and check_retcodes() on a
     # four-chain RunSet built from the bernoulli example files.
     model = Model(
         exe_file=os.path.join(datafiles_path, 'bernoulli'),
         stan_file=os.path.join(datafiles_path, 'bernoulli.stan'),
     )
     sampler_args = SamplerArgs(
         model,
         chain_ids=[1, 2, 3, 4],
         seed=12345,
         data=os.path.join(datafiles_path, 'bernoulli.data.json'),
         output_file=os.path.join(goodfiles_path, 'bern'),
         sampling_iters=100,
         max_treedepth=11,
         adapt_delta=0.95,
     )
     run = RunSet(chains=4, args=sampler_args)
     codes = run.retcodes

     # A new RunSet has one return code per chain, each reported as -1.
     self.assertEqual(4, len(codes))
     for chain in range(len(codes)):
         self.assertEqual(-1, run.retcode(chain))

     # Updating only the first chain leaves the others at -1, and the
     # overall check still fails.
     run.set_retcode(0, 0)
     self.assertEqual(0, run.retcode(0))
     for chain in range(1, len(codes)):
         self.assertEqual(-1, run.retcode(chain))
     self.assertFalse(run.check_retcodes())

     # Once the remaining chains are also set to 0 the check passes.
     for chain in range(1, len(codes)):
         run.set_retcode(chain, 0)
     self.assertTrue(run.check_retcodes())
Beispiel #3
0
 def test_validate_bad_hdr(self):
     # Point the RunSet at the 'bad-hdr-bern' output prefix under
     # badfiles_path and expect CSV validation to reject it.
     model = Model(
         exe_file=os.path.join(datafiles_path, 'bernoulli'),
         stan_file=os.path.join(datafiles_path, 'bernoulli.stan'),
     )
     sampler_args = SamplerArgs(
         model,
         chain_ids=[1, 2, 3, 4],
         seed=12345,
         data=os.path.join(datafiles_path, 'bernoulli.data.json'),
         output_file=os.path.join(badfiles_path, 'bad-hdr-bern'),
         sampling_iters=100,
         max_treedepth=11,
         adapt_delta=0.95,
     )
     run = RunSet(chains=4, args=sampler_args)

     # Set every chain's return code to 0 so that check_retcodes() passes.
     n_chains = len(run.retcodes)
     for chain in range(n_chains):
         run.set_retcode(chain, 0)
     self.assertTrue(run.check_retcodes())

     # validate_csv_files() must raise ValueError matching 'header mismatch'.
     with self.assertRaisesRegex(ValueError, 'header mismatch'):
         run.validate_csv_files()
Beispiel #4
0
def do_sample(runset: RunSet, idx: int) -> None:
    """
    Encapsulates call to sampler.

    Spawns the command ``runset.cmds[idx]`` as a child process, captures
    its stdout and stderr, writes them to ``runset.console_files[idx]``
    and records the process return code via ``runset.set_retcode``.

    :param runset: run set supplying the per-chain command strings
        (``cmds``), transcript paths (``console_files``) and the
        ``set_retcode`` callback.
    :param idx: zero-based index of the chain to run; progress messages
        print it one-based.
    """
    cmd = runset.cmds[idx]
    print('start chain {}.  '.format(idx + 1))
    # The command string is split on whitespace, so no shell is involved.
    proc = subprocess.Popen(
        cmd.split(),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    # communicate() drains both pipes and then waits for the child to exit.
    # Calling wait() first can deadlock: once the child fills the OS pipe
    # buffer it blocks on write while the parent blocks waiting for exit.
    stdout, stderr = proc.communicate()
    transcript_file = runset.console_files[idx]
    print('finish chain {}.  '.format(idx + 1))
    with open(transcript_file, 'w+') as transcript:
        if stdout:
            transcript.write(stdout.decode('utf-8'))
        if stderr:
            # Prefix stderr content with a marker so it can be told apart
            # from regular console output in the transcript.
            transcript.write('ERROR')
            transcript.write(stderr.decode('utf-8'))
    # returncode is populated once communicate() has returned.
    runset.set_retcode(idx, proc.returncode)