예제 #1
0
    def test_gen_quantities_good(self):
        # Compile the posterior-predictive model, feed it CSV files from an
        # existing sampler run, and verify the generated-quantities output.
        stan_path = os.path.join(datafiles_path, 'bernoulli_ppc.stan')
        model = Model(stan_file=stan_path)
        model.compile()

        data_path = os.path.join(datafiles_path, 'bernoulli.data.json')

        # Synthesize a StanFit object on top of the 'runset-good' files
        # (same approach as test_stanfit.py, method 'test_validate_good_run').
        output_base = os.path.join(datafiles_path, 'runset-good', 'bern')
        method_args = SamplerArgs(
            sampling_iters=100, max_treedepth=11, adapt_delta=0.95
        )
        run_args = CmdStanArgs(
            model_name=model.name,
            model_exe=model.exe_file,
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=data_path,
            output_basename=output_base,
            method_args=method_args,
        )
        sampler_fit = StanFit(args=run_args, chains=4)
        # Mark every chain as having exited successfully.
        for chain in range(4):
            sampler_fit._set_retcode(chain, 0)

        bern_fit = model.run_generated_quantities(
            csv_files=sampler_fit.csv_files, data=data_path
        )

        # Check results: output files, quantities of interest, draws.
        self.assertEqual(bern_fit.chains, 4)
        for chain in range(4):
            self.assertEqual(bern_fit._retcodes[chain], 0)
            self.assertTrue(os.path.exists(bern_fit.csv_files[chain]))

        # Expect exactly 'y_rep.1' through 'y_rep.10', in order.
        expected_columns = tuple(
            'y_rep.{}'.format(k) for k in range(1, 11)
        )
        self.assertEqual(bern_fit.column_names, expected_columns)
        self.assertEqual(bern_fit.draws, 100)
예제 #2
0
    def test_validate_good_run(self):
        # Build a fit on top of existing sampler output, mark all chains as
        # successful, then exercise validation, drawset, summary and diagnose.
        exe_path = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)
        data_path = os.path.join(datafiles_path, 'bernoulli.data.json')
        output_base = os.path.join(goodfiles_path, 'bern')
        method_args = SamplerArgs(
            sampling_iters=100, max_treedepth=11, adapt_delta=0.95
        )
        run_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=exe_path,
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=data_path,
            output_basename=output_base,
            method_args=method_args,
        )
        fit = StanFit(args=run_args, chains=4)

        # Mark every chain as having exited with return code 0.
        chain_count = len(fit._retcodes)
        for chain in range(chain_count):
            fit._set_retcode(chain, 0)
        self.assertTrue(fit._check_retcodes())

        fit._check_console_msgs()
        fit._validate_csv_files()
        self.assertEqual(4, fit.chains)
        self.assertEqual(100, fit.draws)
        self.assertEqual(8, len(fit.column_names))
        self.assertEqual('lp__', fit.column_names[0])

        # The drawset has one row per draw per chain, one column per name.
        drawset = fit.get_drawset()
        actual_shape = drawset.shape
        expected_shape = (fit.chains * fit.draws, len(fit.column_names))
        self.assertEqual(actual_shape, expected_shape)

        # Return value is not inspected; the call only has to succeed.
        fit.summary()

        # TODO - use cmdstan test files instead
        expected_lines = (
            'Checking sampler transitions treedepth.',
            'Treedepth satisfactory for all transitions.',
            '\nChecking sampler transitions for divergences.',
            'No divergent transitions found.',
            '\nChecking E-BFMI - sampler transitions HMC potential energy.',
            'E-BFMI satisfactory for all transitions.',
            '\nEffective sample size satisfactory.',
        )
        expected = '\n'.join(expected_lines)
        # Normalize Windows line endings before comparing.
        diagnostics = fit.diagnose().replace("\r\n", "\n")
        self.assertIn(expected, diagnostics)
예제 #3
0
 def _do_sample(self, stanfit: StanFit, idx: int) -> None:
     """
     Run the sampler command for a single chain and record its outcome.

     Spawns ``stanfit.cmds[idx]`` as a subprocess, captures its stdout and
     stderr, writes them to ``stanfit.console_files[idx]`` and stores the
     process return code on ``stanfit`` via ``_set_retcode``.

     :param stanfit: object supplying ``cmds`` and ``console_files`` and
         receiving the return code.
     :param idx: zero-based index into ``cmds`` / ``console_files``;
         logged as chain ``idx + 1``.
     """
     cmd = stanfit.cmds[idx]
     self._logger.info('start chain %u', idx + 1)
     self._logger.debug('sampling: %s', cmd)
     # NOTE(review): whitespace splitting breaks any argument that itself
     # contains spaces (e.g. a path) - confirm cmds never contain such.
     proc = subprocess.Popen(cmd.split(),
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
     # communicate() drains both pipes while waiting for the child to exit
     # and sets proc.returncode.  Calling wait() first can deadlock: the
     # child blocks once an OS pipe buffer fills and never terminates.
     stdout, stderr = proc.communicate()
     transcript_file = stanfit.console_files[idx]
     self._logger.info('finish chain %u', idx + 1)
     with open(transcript_file, 'w+') as transcript:
         if stdout:
             transcript.write(stdout.decode('utf-8'))
         if stderr:
             # Prefix stderr output with an 'ERROR' marker in the transcript.
             transcript.write('ERROR')
             transcript.write(stderr.decode('utf-8'))
     stanfit._set_retcode(idx, proc.returncode)
예제 #4
0
 def test_check_retcodes(self):
     # A fresh fit reports -1 for every chain; _check_retcodes() only
     # passes once every chain's return code has been set to 0.
     exe_path = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)
     data_path = os.path.join(datafiles_path, 'bernoulli.data.json')
     method_args = SamplerArgs()
     run_args = CmdStanArgs(
         model_name='bernoulli',
         model_exe=exe_path,
         chain_ids=[1, 2, 3, 4],
         data=data_path,
         method_args=method_args,
     )
     fit = StanFit(args=run_args, chains=4)

     chain_count = len(fit._retcodes)
     self.assertEqual(4, chain_count)

     # Initially no chain has a return code recorded.
     for chain in range(chain_count):
         self.assertEqual(-1, fit._retcode(chain))

     # Setting chain 0 must not affect the remaining chains.
     fit._set_retcode(0, 0)
     self.assertEqual(0, fit._retcode(0))
     remaining = range(1, chain_count)
     for chain in remaining:
         self.assertEqual(-1, fit._retcode(chain))
     self.assertFalse(fit._check_retcodes())

     # Once all chains report 0 the check succeeds.
     for chain in remaining:
         fit._set_retcode(chain, 0)
     self.assertTrue(fit._check_retcodes())
예제 #5
0
    def test_validate_bad_run(self):
        # Each scenario points a 4-chain fit at a different set of files
        # under badfiles_path and expects validation to raise.
        exe_path = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)
        data_path = os.path.join(datafiles_path, 'bernoulli.data.json')
        method_args = SamplerArgs(
            sampling_iters=100, max_treedepth=11, adapt_delta=0.95
        )

        def build_fit(basename):
            # Construct a StanFit whose output basename is under badfiles_path.
            run_args = CmdStanArgs(
                model_name='bernoulli',
                model_exe=exe_path,
                chain_ids=[1, 2, 3, 4],
                seed=12345,
                data=data_path,
                output_basename=os.path.join(badfiles_path, basename),
                method_args=method_args,
            )
            return StanFit(args=run_args, chains=4)

        def build_finished_fit(basename):
            # As build_fit, but with every chain's return code set to 0.
            finished = build_fit(basename)
            for chain in range(len(finished._retcodes)):
                finished._set_retcode(chain, 0)
            self.assertTrue(finished._check_retcodes())
            return finished

        # some chains had errors
        fit = build_fit('bad-transcript-bern')
        with self.assertRaisesRegex(Exception, 'Exception'):
            fit._check_console_msgs()

        # csv file headers inconsistent
        fit = build_finished_fit('bad-hdr-bern')
        with self.assertRaisesRegex(ValueError, 'header mismatch'):
            fit._validate_csv_files()

        # bad draws
        fit = build_finished_fit('bad-draws-bern')
        with self.assertRaisesRegex(ValueError, 'draws'):
            fit._validate_csv_files()

        # mismatch - column headers, draws
        fit = build_finished_fit('bad-cols-bern')
        with self.assertRaisesRegex(ValueError, 'bad draw'):
            fit._validate_csv_files()