def test_args_sig_figs(self):
    """
    Exercise the ``sig_figs`` argument of CmdStanArgs.

    When ``cmdstan_version_at(2, 25)`` holds, a valid value must show up in
    the composed command and out-of-range values must raise ValueError.
    Otherwise constructing the args must log the version warning.
    """
    sampler_args = SamplerArgs()
    cmdstan_path()  # sets os.environ['CMDSTAN']

    def make_args(sig_figs):
        # Fresh argument object per call; only sig_figs varies.
        return CmdStanArgs(
            model_name='bernoulli',
            model_exe='bernoulli.exe',
            chain_ids=[1, 2, 3, 4],
            sig_figs=sig_figs,
            method_args=sampler_args,
        )

    if cmdstan_version_at(2, 25):
        command = make_args(12).compose_command(
            idx=0, csv_file='bern-output-1.csv'
        )
        self.assertIn('sig_figs=', ' '.join(command))
        for out_of_range in (-1, 20):
            with self.assertRaises(ValueError):
                make_args(out_of_range)
        return

    # Installed CmdStan does not support sig_figs: expect a warning only.
    with LogCapture() as log:
        logging.getLogger()
        make_args(12)
    template = (
        'arg sig_figs not valid, CmdStan version must be 2.25 '
        'or higher, using verson {} in directory {}'
    )
    expect = template.format(
        os.path.basename(cmdstan_path()),
        os.path.dirname(cmdstan_path()),
    )
    log.check_present(('cmdstanpy', 'WARNING', expect))
def test_validate_summary_sig_figs(self):
    """
    Check the precision of ``CmdStanMCMC.summary`` output for several
    ``sig_figs`` settings, using the stored logistic model CSV files.
    """
    # construct CmdStanMCMC from logistic model output, config
    cmdstan_args = CmdStanArgs(
        model_name='logistic',
        model_exe=os.path.join(DATAFILES_PATH, 'logistic' + EXTENSION),
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=os.path.join(DATAFILES_PATH, 'logistic.data.R'),
        output_dir=DATAFILES_PATH,
        sig_figs=17,
        method_args=SamplerArgs(iter_sampling=100),
    )
    runset = RunSet(args=cmdstan_args)
    runset._csv_files = [
        os.path.join(DATAFILES_PATH, 'logistic_output_{}.csv'.format(chain))
        for chain in range(1, 5)
    ]
    # mark every chain as having finished successfully
    for chain_idx in range(len(runset._retcodes)):
        runset._set_retcode(chain_idx, 0)
    fit = CmdStanMCMC(runset)

    def beta1(frame):
        # second row, first column of the summary, printed at 18 digits
        return format(frame.iloc[1, 0], '.18g')

    self.assertTrue(beta1(fit.summary()).startswith('1.3'))

    if cmdstan_version_at(2, 25):
        self.assertTrue(
            beta1(fit.summary(sig_figs=17)).startswith('1.345767078273')
        )
        self.assertTrue(
            beta1(fit.summary(sig_figs=10)).startswith('1.34576707')
        )

    # out-of-range precision requests are rejected
    with self.assertRaises(ValueError):
        fit.summary(sig_figs=20)
    with self.assertRaises(ValueError):
        fit.summary(sig_figs=-1)
def test_validate_sample_sig_figs(self, stanfile='bernoulli.stan'):
    """
    Run the sampler with default and explicit ``sig_figs`` and check that
    out-of-range values raise ValueError.

    The whole test is a no-op unless ``cmdstan_version_at(2, 25)`` holds.
    """
    if not cmdstan_version_at(2, 25):
        return

    bern_model = CmdStanModel(
        stan_file=os.path.join(DATAFILES_PATH, stanfile)
    )
    run_kwargs = {
        'data': os.path.join(DATAFILES_PATH, 'bernoulli.data.json'),
        'chains': 1,
        'seed': 12345,
        'iter_sampling': 100,
    }

    # default precision: the stored draw must not carry the full digits
    default_draws = bern_model.sample(**run_kwargs).draws()
    theta = format(default_draws[99, 0, 7], '.18g')
    self.assertFalse(theta.startswith('0.21238045821757600'))

    # explicit high precision run must produce draws
    bern_fit_17 = bern_model.sample(sig_figs=17, **run_kwargs)
    self.assertTrue(bern_fit_17.draws().size)

    for out_of_range in (27, -1):
        with self.assertRaises(ValueError):
            bern_model.sample(sig_figs=out_of_range, **run_kwargs)
def test_cmdstan_version_at(self):
    """Version 99.99 must not be reported as reached by the install."""
    cmdstan_path()  # sets os.environ['CMDSTAN']
    reached_future_version = cmdstan_version_at(99, 99)
    self.assertFalse(reached_future_version)
def summary(self,
            percentiles: List[int] = None,
            sig_figs: int = None) -> pd.DataFrame:
    """
    Run cmdstan/bin/stansummary over all output csv files and return the
    result as a DataFrame.

    The returned frame keeps the row for the total joint log probability
    `lp__` and every row whose name does not end in `__`.

    :param percentiles: Ordered non-empty list of percentiles to report.
        Must be integers from (1, 99), inclusive.  When omitted,
        ``--percentiles=5,50,95`` is used.

    :param sig_figs: Number of significant figures to report.
        Must be an integer between 1 and 18.  When omitted, stansummary is
        invoked with ``--sig_figs=2``.  A warning is logged when the
        requested value exceeds the precision recorded for the CSV files
        (``self._sig_figs``, taken as 6 when unset).  If precision above 6
        is requested, sample must have been produced by CmdStan version
        2.25 or later and sampler output precision must equal to or
        greater than the requested summary precision.

    :raises ValueError: if ``percentiles`` or ``sig_figs`` is invalid.

    :return: pandas.DataFrame
    """
    pct_arg = '--percentiles=5,50,95'
    if percentiles is not None:
        if len(percentiles) < 1:
            raise ValueError(
                'invalid percentiles argument, must be ordered'
                ' non-empty list from (1, 99), inclusive.')
        previous = 0
        for pct in percentiles:
            # strictly increasing, starting above 0 and never beyond 99
            if not (previous < pct <= 99):
                raise ValueError(
                    'invalid percentiles spec, must be ordered'
                    ' non-empty list from (1, 99), inclusive.')
            previous = pct
        pct_arg = '--percentiles=' + ','.join(str(pct) for pct in percentiles)

    precision_arg = '--sig_figs=2'
    if sig_figs is not None:
        if not (isinstance(sig_figs, int) and 1 <= sig_figs <= 18):
            raise ValueError(
                'sig_figs must be an integer between 1 and 18,'
                ' found {}'.format(sig_figs))
        csv_sig_figs = self._sig_figs or 6
        if sig_figs > csv_sig_figs:
            self._logger.warning(
                'Requesting %d significant digits of output, but CSV files'
                ' only have %d digits of precision.',
                sig_figs,
                csv_sig_figs,
            )
        precision_arg = '--sig_figs=' + str(sig_figs)

    stansummary_exe = os.path.join(
        cmdstan_path(), 'bin', 'stansummary' + EXTENSION
    )
    out_prefix = 'stansummary-{}-'.format(self.runset._args.model_name)
    out_path = create_named_text_file(
        dir=_TMPDIR, prefix=out_prefix, suffix='.csv', name_only=True
    )
    # the output-file flag is spelled differently before version 2.24
    if cmdstan_version_at(2, 24):
        out_arg = '--csv_filename={}'.format(out_path)
    else:
        out_arg = '--csv_file={}'.format(out_path)

    cmd = [stansummary_exe, pct_arg, precision_arg, out_arg]
    cmd = cmd + self.runset.csv_files
    do_command(cmd, logger=self.runset._logger)

    with open(out_path, 'rb') as fd:
        summary_data = pd.read_csv(
            fd,
            delimiter=',',
            header=0,
            index_col=0,
            comment='#',
            float_precision='high',
        )
    # keep lp__ but drop the other rows whose names end in '__'
    keep_row = [
        name == 'lp__' or not name.endswith('__')
        for name in summary_data.index
    ]
    return summary_data[keep_row]
def validate(self) -> None:
    """
    Check arguments correctness and consistency.

    * model name and model executable must be set
    * chain ids must be >= 1
    * output files must be in a writeable directory; the directory is
      created when it does not exist
    * refresh must be a positive integer
    * sig_figs must be an integer between 1 and 18; it is reset to None
      with a warning when ``cmdstan_version_at(2, 25)`` is false
    * if no seed specified, set random seed.
    * input files must exist
    * length of per-chain lists equals specified # of chains

    :raises ValueError: if any of the checks above fails.
    """
    if self.model_name is None:
        raise ValueError('no stan model specified')
    if self.model_exe is None:
        raise ValueError('model not compiled')

    if self.chain_ids is not None:
        for chain_id in self.chain_ids:
            if chain_id < 1:
                raise ValueError('invalid chain_id {}'.format(chain_id))

    if self.output_dir is not None:
        self.output_dir = os.path.realpath(
            os.path.expanduser(self.output_dir))
        if not os.path.exists(self.output_dir):
            try:
                os.makedirs(self.output_dir)
                self._logger.info('created output directory: %s',
                                  self.output_dir)
            # os.makedirs reports failures as OSError subclasses
            # (PermissionError, NotADirectoryError, FileExistsError, ...);
            # surface all of them as ValueError, like the write probe below.
            except (RuntimeError, OSError) as exc:
                raise ValueError(
                    'invalid path for output files, no such dir: {}'.
                    format(self.output_dir)) from exc
        if not os.path.isdir(self.output_dir):
            raise ValueError(
                'specified output_dir not a directory: {}'.format(
                    self.output_dir))
        # probe writeability by creating and removing a scratch file
        try:
            testpath = os.path.join(self.output_dir, str(time()))
            with open(testpath, 'w+'):
                pass
            os.remove(testpath)  # cleanup
        except Exception as exc:
            raise ValueError('invalid path for output files,'
                             ' cannot write to dir: {}'.format(
                                 self.output_dir)) from exc

    if self.refresh is not None:
        if not isinstance(self.refresh, int) or self.refresh < 1:
            raise ValueError(
                'Argument refresh must be a positive integer value, '
                'found {}.'.format(self.refresh))

    if self.sig_figs is not None:
        if (not isinstance(self.sig_figs, int) or self.sig_figs < 1
                or self.sig_figs > 18):
            raise ValueError(
                'sig_figs must be an integer between 1 and 18,'
                ' found {}'.format(self.sig_figs))
        if not cmdstan_version_at(2, 25):
            # unsupported by this CmdStan: drop the argument, warn only
            self.sig_figs = None
            self._logger.warning(
                'arg sig_figs not valid, CmdStan version must be 2.25 '
                'or higher, using verson %s in directory %s',
                os.path.basename(cmdstan_path()),
                os.path.dirname(cmdstan_path()),
            )

    if self.seed is None:
        rng = RandomState()
        self.seed = rng.randint(1, 99999 + 1)
    else:
        if not isinstance(self.seed, (int, list)):
            raise ValueError(
                'seed must be an integer between 0 and 2**32-1,'
                ' found {}'.format(self.seed))
        if isinstance(self.seed, int):
            if self.seed < 0 or self.seed > 2**32 - 1:
                raise ValueError(
                    'seed must be an integer between 0 and 2**32-1,'
                    ' found {}'.format(self.seed))
        else:
            # per-chain seeds: one entry per chain id
            if self.chain_ids is None:
                raise ValueError(
                    'seed must not be a list when no chains used')
            if len(self.seed) != len(self.chain_ids):
                raise ValueError(
                    'number of seeds must match number of chains,'
                    ' found {} seed for {} chains '.format(
                        len(self.seed), len(self.chain_ids)))
            for chain_seed in self.seed:
                if chain_seed < 0 or chain_seed > 2**32 - 1:
                    raise ValueError('seed must be an integer value'
                                     ' between 0 and 2**32-1,'
                                     ' found {}'.format(chain_seed))

    if isinstance(self.data, str):
        if not os.path.exists(self.data):
            raise ValueError('no such file {}'.format(self.data))
    elif self.data is not None and not isinstance(self.data, (str, dict)):
        raise ValueError('data must be string or dict')

    if self.inits is not None:
        if isinstance(self.inits, (Integral, Real)):
            # zero is accepted; only negative values are rejected
            if self.inits < 0:
                raise ValueError('inits must be >= 0, found {}'.format(
                    self.inits))
        elif isinstance(self.inits, str):
            if not os.path.exists(self.inits):
                raise ValueError('no such file {}'.format(self.inits))
        elif isinstance(self.inits, list):
            if self.chain_ids is None:
                raise ValueError(
                    'inits must not be a list when no chains are used')
            if len(self.inits) != len(self.chain_ids):
                raise ValueError(
                    'number of inits files must match number of chains,'
                    ' found {} inits files for {} chains '.format(
                        len(self.inits), len(self.chain_ids)))
            names_set = set(self.inits)
            if len(names_set) != len(self.inits):
                raise ValueError('each chain must have its own init file,'
                                 ' found duplicates in inits files list.')
            for inits_file in self.inits:
                if not os.path.exists(inits_file):
                    raise ValueError('no such file {}'.format(inits_file))