def _validate_csv_files(self) -> None:
    """
    Checks that csv output files for all chains are consistent.
    Populates attributes for draws, column_names, num_params, metric_type.
    Raises exception when inconsistencies detected.
    """
    reference = {}
    for chain_idx in range(self.runset.chains):
        parsed = check_sampler_csv(
            path=self.runset.csv_files[chain_idx],
            is_fixed_param=self._is_fixed_param,
            iter_sampling=self._iter_sampling,
            iter_warmup=self._iter_warmup,
            save_warmup=self._save_warmup,
            thin=self._thin,
        )
        if chain_idx == 0:
            # first chain's header is the reference for all others
            reference = parsed
            continue
        for key in reference:
            # 'id' and 'diagnostic_file' legitimately vary per chain
            if key in ['id', 'diagnostic_file']:
                continue
            if reference[key] != parsed[key]:
                raise ValueError(
                    'csv file header mismatch, '
                    'file {}, key {} is {}, expected {}'.format(
                        self.runset.csv_files[chain_idx],
                        key,
                        reference[key],
                        parsed[key],
                    ))
    self._draws_sampling = reference['draws_sampling']
    if self._save_warmup:
        self._draws_warmup = reference['draws_warmup']
    else:
        self._draws_warmup = 0
    self._column_names = reference['column_names']
    if not self._is_fixed_param:
        self._num_params = reference['num_params']
        self._metric_type = reference.get('metric')
def test_check_sampler_csv_1(self):
    """Spot-check header fields parsed from a known-good sampler csv."""
    csv_good = os.path.join(DATAFILES_PATH, 'bernoulli_output_1.csv')
    # renamed local from 'dict' to avoid shadowing the builtin type
    config = check_sampler_csv(csv_good)
    self.assertEqual('bernoulli_model', config['model'])
    self.assertEqual(10, config['num_samples'])
    self.assertNotIn('save_warmup', config)
    self.assertEqual(10, config['draws'])
    self.assertEqual(8, len(config['column_names']))
def test_check_sampler_csv_1(self):
    """Check parsed config and error paths for a known-good sampler csv."""
    csv_good = os.path.join(DATAFILES_PATH, 'bernoulli_output_1.csv')
    # renamed local from 'dict' to avoid shadowing the builtin type
    config = check_sampler_csv(
        path=csv_good,
        is_fixed_param=False,
        iter_warmup=100,
        iter_sampling=10,
        thin=1,
    )
    self.assertEqual('bernoulli_model', config['model'])
    self.assertEqual(10, config['num_samples'])
    self.assertNotIn('save_warmup', config)
    self.assertEqual(10, config['draws_sampling'])
    self.assertEqual(8, len(config['column_names']))
    # mismatched thin argument must be reported
    with self.assertRaisesRegex(
        ValueError, 'config error, expected thin = 2'
    ):
        check_sampler_csv(
            path=csv_good, iter_warmup=100, iter_sampling=20, thin=2
        )
    # file was produced without save_warmup
    with self.assertRaisesRegex(
        ValueError, 'config error, expected save_warmup'
    ):
        check_sampler_csv(
            path=csv_good,
            iter_warmup=100,
            iter_sampling=10,
            save_warmup=True,
        )
    # default iter_sampling disagrees with the file's 10 draws
    with self.assertRaisesRegex(ValueError, 'expected 1000 draws'):
        check_sampler_csv(path=csv_good, iter_warmup=100)
def test_good(self):
    """InferenceMetadata built from a parsed csv reflects its config."""
    # construct fit using existing sampler output
    exe_path = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    sampler_args = SamplerArgs(
        iter_sampling=100, max_treedepth=11, adapt_delta=0.95
    )
    cmdstan_args = CmdStanArgs(
        model_name='bernoulli',
        model_exe=exe_path,
        chain_ids=[1, 2, 3, 4],
        seed=12345,
        data=data_path,
        output_dir=DATAFILES_PATH,
        method_args=sampler_args,
    )
    runset = RunSet(args=cmdstan_args)
    runset._csv_files = [
        os.path.join(DATAFILES_PATH, 'runset-good', 'bern-{}.csv'.format(n))
        for n in (1, 2, 3, 4)
    ]
    n_chains = len(runset._retcodes)
    for idx in range(n_chains):
        runset._set_retcode(idx, 0)
    # parse the last chain's csv file
    cfg = check_sampler_csv(
        path=runset.csv_files[n_chains - 1],
        is_fixed_param=False,
        iter_sampling=100,
        iter_warmup=1000,
        save_warmup=False,
        thin=1,
    )
    metadata = InferenceMetadata(cfg)
    expected = 'Metadata:\n{}\n'.format(cfg)
    self.assertEqual(expected, '{}'.format(metadata))
    self.assertEqual(cfg, metadata.cmdstan_config)
    hmc_vars = {
        'lp__',
        'accept_stat__',
        'stepsize__',
        'treedepth__',
        'n_leapfrog__',
        'divergent__',
        'energy__',
    }
    self.assertEqual(hmc_vars, metadata.sampler_vars_cols.keys())
    bern_model_vars = {'theta'}
    self.assertEqual(bern_model_vars, metadata.stan_vars_dims.keys())
    self.assertEqual((), metadata.stan_vars_dims['theta'])
    self.assertEqual(bern_model_vars, metadata.stan_vars_cols.keys())
    self.assertEqual((7, ), metadata.stan_vars_cols['theta'])
def validate_csv_files(self) -> None:
    """
    Checks that csv output files for all chains are consistent.
    Populates attributes for metadata, draws, metric, step size.
    Raises exception when inconsistencies detected.
    """
    # fields that legitimately vary per chain are excluded from the check
    per_chain_keys = (
        'id',
        'diagnostic_file',
        'metric_file',
        'stepsize',
        'init',
        'seed',
    )
    reference = {}
    for chain_idx in range(self.chains):
        parsed = check_sampler_csv(
            path=self.runset.csv_files[chain_idx],
            is_fixed_param=self._is_fixed_param,
            iter_sampling=self._iter_sampling,
            iter_warmup=self._iter_warmup,
            save_warmup=self._save_warmup,
            thin=self._thin,
        )
        if chain_idx == 0:
            # first chain's header is the reference for all others
            reference = parsed
            continue
        for key in reference:
            if key in per_chain_keys:
                continue
            if reference[key] != parsed[key]:
                raise ValueError(
                    'csv file header mismatch, '
                    'file {}, key {} is {}, expected {}'.format(
                        self.runset.csv_files[chain_idx],
                        key,
                        reference[key],
                        parsed[key],
                    ))
    self._metadata = InferenceMetadata(reference)
def test_check_sampler_csv_thin(self):
    """Verify thin-related fields in the csv produced by a thinned run."""
    stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    bern_model = CmdStanModel(stan_file=stan)
    bern_model.compile()
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    bern_fit = bern_model.sample(
        data=jdata,
        chains=1,
        cores=1,
        seed=12345,
        sampling_iters=490,
        warmup_iters=490,
        thin=7,
        max_treedepth=11,
        adapt_delta=0.98,
    )
    csv_file = bern_fit.runset.csv_files[0]
    # renamed local from 'dict' to avoid shadowing the builtin type
    config = check_sampler_csv(csv_file)
    self.assertEqual(config['num_samples'], 490)
    self.assertEqual(config['thin'], 7)
    # 490 samples thinned by 7 -> 70 retained draws
    self.assertEqual(config['draws'], 70)
    self.assertEqual(config['seed'], 12345)
    self.assertEqual(config['max_depth'], 11)
    self.assertEqual(config['delta'], 0.98)
def test_check_sampler_csv_metric_1(self):
    """A csv file with a bad metric section is rejected."""
    bad_path = os.path.join(DATAFILES_PATH, 'output_bad_metric_1.csv')
    with self.assertRaisesRegex(Exception, 'expecting metric'):
        check_sampler_csv(bad_path)
def test_check_sampler_csv_4(self):
    """A csv file with the wrong number of data rows is rejected."""
    bad_path = os.path.join(DATAFILES_PATH, 'output_bad_rows.csv')
    with self.assertRaisesRegex(Exception, 'found 9'):
        check_sampler_csv(bad_path)
def test_check_sampler_csv_3(self):
    """A csv file with the wrong number of columns is rejected."""
    bad_path = os.path.join(DATAFILES_PATH, 'output_bad_cols.csv')
    with self.assertRaisesRegex(Exception, '8 items'):
        check_sampler_csv(bad_path)
def test_check_sampler_csv_2(self):
    """A nonexistent csv path raises an exception."""
    missing_path = os.path.join(DATAFILES_PATH, 'no_such_file.csv')
    with self.assertRaises(Exception):
        check_sampler_csv(missing_path)
def test_check_sampler_csv_3(self):
    """A csv file with the wrong number of columns is rejected."""
    csv_bad = os.path.join(datafiles_path, 'output_bad_cols.csv')
    with self.assertRaisesRegex(Exception, '8 items'):
        # dropped unused assignment to 'dict' (shadowed the builtin;
        # the call raises, so the value was never used)
        check_sampler_csv(csv_bad)
def test_check_sampler_csv_metric_4(self):
    """A csv file with a malformed mass matrix is rejected."""
    csv_bad = os.path.join(datafiles_path, 'output_bad_metric_4.csv')
    with self.assertRaisesRegex(
        Exception, 'invalid or missing mass matrix specification'
    ):
        # dropped unused assignment to 'dict' (shadowed the builtin;
        # the call raises, so the value was never used)
        check_sampler_csv(csv_bad)
def test_check_sampler_csv_metric_2(self):
    """A csv file with an invalid stepsize entry is rejected."""
    csv_bad = os.path.join(datafiles_path, 'output_bad_metric_2.csv')
    with self.assertRaisesRegex(Exception, 'invalid stepsize'):
        # dropped unused assignment to 'dict' (shadowed the builtin;
        # the call raises, so the value was never used)
        check_sampler_csv(csv_bad)
def test_check_sampler_csv_metric_1(self):
    """A csv file with a bad metric section is rejected."""
    csv_bad = os.path.join(datafiles_path, 'output_bad_metric_1.csv')
    with self.assertRaisesRegex(Exception, 'expecting metric'):
        # dropped unused assignment to 'dict' (shadowed the builtin;
        # the call raises, so the value was never used)
        check_sampler_csv(csv_bad)
def test_check_sampler_csv_4(self):
    """A csv file with the wrong number of data rows is rejected."""
    csv_bad = os.path.join(datafiles_path, 'output_bad_rows.csv')
    with self.assertRaisesRegex(Exception, 'found 9'):
        # dropped unused assignment to 'dict' (shadowed the builtin;
        # the call raises, so the value was never used)
        check_sampler_csv(csv_bad)
def test_check_sampler_csv_metric_2(self):
    """A csv file with an invalid stepsize entry is rejected."""
    bad_path = os.path.join(DATAFILES_PATH, 'output_bad_metric_2.csv')
    with self.assertRaisesRegex(Exception, 'invalid stepsize'):
        check_sampler_csv(bad_path)
def test_check_sampler_csv_metric_4(self):
    """A csv file with a malformed mass matrix is rejected."""
    bad_path = os.path.join(DATAFILES_PATH, 'output_bad_metric_4.csv')
    with self.assertRaisesRegex(
        Exception, 'invalid or missing mass matrix specification'
    ):
        check_sampler_csv(bad_path)
def test_check_sampler_csv_2(self):
    """A nonexistent csv path raises an exception."""
    csv_bad = os.path.join(datafiles_path, 'no_such_file.csv')
    with self.assertRaises(Exception):
        # dropped unused assignment to 'dict' (shadowed the builtin;
        # the call raises, so the value was never used)
        check_sampler_csv(csv_bad)