def _validate_csv_files(self) -> None:
    """
    Check that the csv output files for all chains are consistent.

    The metadata returned by ``check_csv`` for the first file is the
    reference; every other file's metadata must match it on every key
    except ``'id'`` and ``'first_draw'``.

    Populates ``_draws``, ``_column_names``, ``_num_params``,
    ``_first_draw`` and ``_metric_type`` from the reference metadata.

    Raises:
        ValueError: if a file's metadata differs from the first file's
            on a compared key, or lacks a key the first file has.
    """
    dzero = {}
    for i in range(self._chains):
        if i == 0:
            dzero = check_csv(
                self.csv_files[i], is_optimizing=self.is_optimizing)
            continue
        d = check_csv(self.csv_files[i], is_optimizing=self.is_optimizing)
        for key in dzero:
            # 'id' and 'first_draw' are excluded from the comparison
            # (presumably they are allowed to vary per chain).
            if key in ('id', 'first_draw'):
                continue
            # A missing key is reported as a mismatch instead of
            # escaping as a bare KeyError.
            if key not in d or dzero[key] != d[key]:
                # Report this file's value first, then the reference
                # (first file's) value as the expected one.
                raise ValueError(
                    'csv file header mismatch, '
                    'file {}, key {} is {}, expected {}'.format(
                        self.csv_files[i], key, d.get(key), dzero[key]))
    # NOTE(review): with zero chains dzero stays empty and the lookups
    # below raise KeyError, as in the original implementation.
    self._draws = dzero['draws']
    self._column_names = dzero['column_names']
    self._num_params = dzero['num_params']
    self._first_draw = dzero.get('first_draw')
    self._metric_type = dzero.get('metric')
def test_check_csv_1(self):
    """check_csv extracts the expected metadata from a known-good csv."""
    csv_good = os.path.join(datafiles_path, 'bernoulli_output_1.csv')
    # Named 'meta' rather than 'dict' to avoid shadowing the builtin.
    meta = check_csv(csv_good)
    self.assertEqual('bernoulli_model', meta['model'])
    self.assertEqual('10', meta['num_samples'])
    self.assertNotIn('save_warmup', meta)
    self.assertEqual(10, meta['draws'])
    self.assertEqual(8, len(meta['column_names']))
def _set_attrs_gq_csv_files(self, sample_csv_0: str) -> None:
    """
    Propagate metadata from the original sample to the additional
    sample returned by run_generated_quantities.

    Draw count, parameter count, first draw and metric type are read
    from ``sample_csv_0`` (parsed by ``check_csv`` with
    ``is_sampling=True``); the column names are read from this object's
    own first csv file via ``scan_stan_csv``.
    """
    meta = check_csv(sample_csv_0, is_optimizing=False, is_sampling=True)
    self._draws = meta['draws']
    self._num_params = meta['num_params']
    # Optional entries: left as None when absent from the metadata.
    self._first_draw = meta.get('first_draw')
    self._metric_type = meta.get('metric')
    gq_header = scan_stan_csv(self._csv_files[0], is_sampling=False)
    self._column_names = gq_header['column_names']
def test_check_csv_metric_4(self):
    """check_csv rejects output_bad_metric_4.csv (bad mass matrix spec)."""
    csv_bad = os.path.join(datafiles_path, 'output_bad_metric_4.csv')
    with self.assertRaisesRegex(
            Exception, 'invalid or missing mass matrix specification'):
        # Result intentionally unused: only the raised error matters.
        check_csv(csv_bad)
def test_check_csv_metric_2(self):
    """check_csv rejects output_bad_metric_2.csv with an 'invalid stepsize' error."""
    csv_bad = os.path.join(datafiles_path, 'output_bad_metric_2.csv')
    with self.assertRaisesRegex(Exception, 'invalid stepsize'):
        # Result intentionally unused: only the raised error matters.
        check_csv(csv_bad)
def test_check_csv_metric_1(self):
    """check_csv rejects output_bad_metric_1.csv with an 'expecting metric' error."""
    csv_bad = os.path.join(datafiles_path, 'output_bad_metric_1.csv')
    with self.assertRaisesRegex(Exception, 'expecting metric'):
        # Result intentionally unused: only the raised error matters.
        check_csv(csv_bad)
def test_check_csv_4(self):
    """check_csv rejects output_bad_rows.csv, reporting 'found 9'."""
    csv_bad = os.path.join(datafiles_path, 'output_bad_rows.csv')
    with self.assertRaisesRegex(Exception, 'found 9'):
        # Result intentionally unused: only the raised error matters.
        check_csv(csv_bad)
def test_check_csv_3(self):
    """check_csv rejects output_bad_cols.csv, reporting '8 items'."""
    csv_bad = os.path.join(datafiles_path, 'output_bad_cols.csv')
    with self.assertRaisesRegex(Exception, '8 items'):
        # Result intentionally unused: only the raised error matters.
        check_csv(csv_bad)
def test_check_csv_2(self):
    """check_csv raises when the csv file does not exist."""
    csv_bad = os.path.join(datafiles_path, 'no_such_file.csv')
    # NOTE(review): the exception type is not pinned here; consider
    # narrowing once check_csv's missing-file error type is confirmed.
    with self.assertRaises(Exception):
        # Result intentionally unused: only the raised error matters.
        check_csv(csv_bad)