Exemplo n.º 1
0
 def _validate_csv_files(self) -> None:
     """
     Check that the csv output files of all chains are consistent.

     Each file in ``self.runset.csv_files`` is parsed by
     ``check_sampler_csv`` using this object's sampler settings.  The
     config of the first chain is the reference; every other chain's
     config must agree with it on all entries except the per-chain ones
     ('id', 'diagnostic_file').

     Populates ``_draws_sampling``, ``_draws_warmup``, ``_column_names``
     and, unless the run used the fixed_param sampler, ``_num_params``
     and ``_metric_type``.

     Raises ValueError when a header entry of a chain differs from the
     first chain.  Errors raised by ``check_sampler_csv`` itself are not
     caught here.
     """
     # Header entries that are allowed to differ from chain to chain.
     per_chain_keys = ('id', 'diagnostic_file')
     dzero = {}
     for i in range(self.runset.chains):
         path = self.runset.csv_files[i]
         config = check_sampler_csv(
             path=path,
             is_fixed_param=self._is_fixed_param,
             iter_sampling=self._iter_sampling,
             iter_warmup=self._iter_warmup,
             save_warmup=self._save_warmup,
             thin=self._thin,
         )
         if i == 0:
             # First chain: nothing to compare against yet.
             dzero = config
             continue
         for key in dzero:
             if key not in per_chain_keys and dzero[key] != config[key]:
                 # Report this file's value first and the reference
                 # (first chain) value as the expected one.
                 raise ValueError(
                     'csv file header mismatch, '
                     'file {}, key {} is {}, expected {}'.format(
                         path,
                         key,
                         config[key],
                         dzero[key],
                     ))
     self._draws_sampling = dzero['draws_sampling']
     # Warmup draws are only present in the output when they were saved.
     self._draws_warmup = dzero['draws_warmup'] if self._save_warmup else 0
     self._column_names = dzero['column_names']
     if not self._is_fixed_param:
         self._num_params = dzero['num_params']
         self._metric_type = dzero.get('metric')
Exemplo n.º 2
0
 def test_check_sampler_csv_1(self):
     # Parse a known-good sampler output file with default arguments and
     # check selected entries of the returned config.
     csv_good = os.path.join(DATAFILES_PATH, 'bernoulli_output_1.csv')
     # Named `config` so the builtin `dict` is not shadowed.
     config = check_sampler_csv(csv_good)
     self.assertEqual('bernoulli_model', config['model'])
     self.assertEqual(10, config['num_samples'])
     self.assertNotIn('save_warmup', config)
     self.assertEqual(10, config['draws'])
     self.assertEqual(8, len(config['column_names']))
Exemplo n.º 3
0
    def test_check_sampler_csv_1(self):
        # Parse a known-good sampler output file with matching settings,
        # check selected config entries, then verify that settings which
        # disagree with the file are rejected with a ValueError.
        csv_good = os.path.join(DATAFILES_PATH, 'bernoulli_output_1.csv')
        # Named `config` so the builtin `dict` is not shadowed.
        config = check_sampler_csv(
            path=csv_good,
            is_fixed_param=False,
            iter_warmup=100,
            iter_sampling=10,
            thin=1,
        )
        self.assertEqual('bernoulli_model', config['model'])
        self.assertEqual(10, config['num_samples'])
        self.assertNotIn('save_warmup', config)
        self.assertEqual(10, config['draws_sampling'])
        self.assertEqual(8, len(config['column_names']))

        # Mismatched thinning.
        with self.assertRaisesRegex(
            ValueError, 'config error, expected thin = 2'
        ):
            check_sampler_csv(
                path=csv_good, iter_warmup=100, iter_sampling=20, thin=2
            )
        # Caller expects saved warmup draws.
        with self.assertRaisesRegex(
            ValueError, 'config error, expected save_warmup'
        ):
            check_sampler_csv(
                path=csv_good,
                iter_warmup=100,
                iter_sampling=10,
                save_warmup=True,
            )
        # iter_sampling left at its default.
        with self.assertRaisesRegex(ValueError, 'expected 1000 draws'):
            check_sampler_csv(path=csv_good, iter_warmup=100)
Exemplo n.º 4
0
    def test_good(self):
        # Build InferenceMetadata from the config parsed out of existing
        # sampler output and check its string form, stored config and the
        # sampler / Stan variable bookkeeping.

        # construct fit using existing sampler output
        exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
        jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        sampler_args = SamplerArgs(iter_sampling=100,
                                   max_treedepth=11,
                                   adapt_delta=0.95)
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=exe,
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=jdata,
            output_dir=DATAFILES_PATH,
            method_args=sampler_args,
        )
        runset = RunSet(args=cmdstan_args)
        runset._csv_files = [
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-1.csv'),
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-2.csv'),
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-3.csv'),
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-4.csv'),
        ]
        # Mark every chain as having finished successfully.
        retcodes = runset._retcodes
        for i in range(len(retcodes)):
            runset._set_retcode(i, 0)
        # Use the last chain's file explicitly rather than relying on a
        # loop variable leaking out of the loop above.
        config = check_sampler_csv(
            path=runset.csv_files[-1],
            is_fixed_param=False,
            iter_sampling=100,
            iter_warmup=1000,
            save_warmup=False,
            thin=1,
        )
        expected = 'Metadata:\n{}\n'.format(config)
        metadata = InferenceMetadata(config)
        actual = '{}'.format(metadata)
        self.assertEqual(expected, actual)
        self.assertEqual(config, metadata.cmdstan_config)

        hmc_vars = {
            'lp__',
            'accept_stat__',
            'stepsize__',
            'treedepth__',
            'n_leapfrog__',
            'divergent__',
            'energy__',
        }

        sampler_vars_cols = metadata.sampler_vars_cols
        self.assertEqual(hmc_vars, sampler_vars_cols.keys())
        bern_model_vars = {'theta'}
        self.assertEqual(bern_model_vars, metadata.stan_vars_dims.keys())
        self.assertEqual((), metadata.stan_vars_dims['theta'])
        self.assertEqual(bern_model_vars, metadata.stan_vars_cols.keys())
        self.assertEqual((7, ), metadata.stan_vars_cols['theta'])
0
 def validate_csv_files(self) -> None:
     """
     Check that the csv output files of all chains are consistent.

     Each file in ``self.runset.csv_files`` is parsed by
     ``check_sampler_csv`` using this object's sampler settings.  The
     config of the first chain is the reference; every other chain's
     config must agree with it on all entries except those listed in
     ``per_chain_keys`` below.

     Stores an ``InferenceMetadata`` built from the first chain's config
     in ``self._metadata``.

     Raises ValueError when a header entry of a chain differs from the
     first chain.  Errors raised by ``check_sampler_csv`` itself are not
     caught here.
     """
     # Header entries that are allowed to differ from chain to chain.
     per_chain_keys = (
         'id',
         'diagnostic_file',
         'metric_file',
         'stepsize',
         'init',
         'seed',
     )
     dzero = {}
     for i in range(self.chains):
         path = self.runset.csv_files[i]
         config = check_sampler_csv(
             path=path,
             is_fixed_param=self._is_fixed_param,
             iter_sampling=self._iter_sampling,
             iter_warmup=self._iter_warmup,
             save_warmup=self._save_warmup,
             thin=self._thin,
         )
         if i == 0:
             # First chain: nothing to compare against yet.
             dzero = config
             continue
         for key in dzero:
             if key not in per_chain_keys and dzero[key] != config[key]:
                 # Report this file's value first and the reference
                 # (first chain) value as the expected one.
                 raise ValueError(
                     'csv file header mismatch, '
                     'file {}, key {} is {}, expected {}'.format(
                         path,
                         key,
                         config[key],
                         dzero[key],
                     ))
     self._metadata = InferenceMetadata(dzero)
Exemplo n.º 6
0
 def test_check_sampler_csv_thin(self):
     # Run the bernoulli model with thinning and check that the config
     # parsed from the resulting csv file reflects the requested settings.
     stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
     bern_model = CmdStanModel(stan_file=stan)
     bern_model.compile()
     jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
     bern_fit = bern_model.sample(
         data=jdata,
         chains=1,
         cores=1,
         seed=12345,
         sampling_iters=490,
         warmup_iters=490,
         thin=7,
         max_treedepth=11,
         adapt_delta=0.98,
     )
     csv_file = bern_fit.runset.csv_files[0]
     # Named `config` so the builtin `dict` is not shadowed.
     config = check_sampler_csv(csv_file)
     self.assertEqual(config['num_samples'], 490)
     self.assertEqual(config['thin'], 7)
     # 490 sampling iterations thinned by 7.
     self.assertEqual(config['draws'], 70)
     self.assertEqual(config['seed'], 12345)
     self.assertEqual(config['max_depth'], 11)
     self.assertEqual(config['delta'], 0.98)
Exemplo n.º 7
0
 def test_check_sampler_csv_metric_1(self):
     # The 'output_bad_metric_1.csv' fixture must make check_sampler_csv
     # raise an error whose message matches 'expecting metric'.
     bad_metric_csv = os.path.join(
         DATAFILES_PATH, 'output_bad_metric_1.csv'
     )
     expected_msg = 'expecting metric'
     with self.assertRaisesRegex(Exception, expected_msg):
         check_sampler_csv(bad_metric_csv)
Exemplo n.º 8
0
 def test_check_sampler_csv_4(self):
     # The 'output_bad_rows.csv' fixture must make check_sampler_csv
     # raise an error whose message matches 'found 9'.
     bad_rows_csv = os.path.join(DATAFILES_PATH, 'output_bad_rows.csv')
     expected_msg = 'found 9'
     with self.assertRaisesRegex(Exception, expected_msg):
         check_sampler_csv(bad_rows_csv)
Exemplo n.º 9
0
 def test_check_sampler_csv_3(self):
     # The 'output_bad_cols.csv' fixture must make check_sampler_csv
     # raise an error whose message matches '8 items'.
     bad_cols_csv = os.path.join(DATAFILES_PATH, 'output_bad_cols.csv')
     expected_msg = '8 items'
     with self.assertRaisesRegex(Exception, expected_msg):
         check_sampler_csv(bad_cols_csv)
Exemplo n.º 10
0
 def test_check_sampler_csv_2(self):
     # Passing the path 'no_such_file.csv' must make check_sampler_csv
     # raise.
     missing_csv = os.path.join(DATAFILES_PATH, 'no_such_file.csv')
     self.assertRaises(Exception, check_sampler_csv, missing_csv)
Exemplo n.º 11
0
 def test_check_sampler_csv_3(self):
     # The 'output_bad_cols.csv' fixture must make check_sampler_csv
     # raise an error whose message matches '8 items'.
     csv_bad = os.path.join(datafiles_path, 'output_bad_cols.csv')
     with self.assertRaisesRegex(Exception, '8 items'):
         # The call is expected to raise, so its result is not bound.
         check_sampler_csv(csv_bad)
Exemplo n.º 12
0
 def test_check_sampler_csv_metric_4(self):
     # The 'output_bad_metric_4.csv' fixture must make check_sampler_csv
     # raise an error whose message matches
     # 'invalid or missing mass matrix specification'.
     csv_bad = os.path.join(datafiles_path, 'output_bad_metric_4.csv')
     with self.assertRaisesRegex(
             Exception, 'invalid or missing mass matrix specification'):
         # The call is expected to raise, so its result is not bound.
         check_sampler_csv(csv_bad)
Exemplo n.º 13
0
 def test_check_sampler_csv_metric_2(self):
     # The 'output_bad_metric_2.csv' fixture must make check_sampler_csv
     # raise an error whose message matches 'invalid stepsize'.
     csv_bad = os.path.join(datafiles_path, 'output_bad_metric_2.csv')
     with self.assertRaisesRegex(Exception, 'invalid stepsize'):
         # The call is expected to raise, so its result is not bound.
         check_sampler_csv(csv_bad)
Exemplo n.º 14
0
 def test_check_sampler_csv_metric_1(self):
     # The 'output_bad_metric_1.csv' fixture must make check_sampler_csv
     # raise an error whose message matches 'expecting metric'.
     csv_bad = os.path.join(datafiles_path, 'output_bad_metric_1.csv')
     with self.assertRaisesRegex(Exception, 'expecting metric'):
         # The call is expected to raise, so its result is not bound.
         check_sampler_csv(csv_bad)
Exemplo n.º 15
0
 def test_check_sampler_csv_4(self):
     # The 'output_bad_rows.csv' fixture must make check_sampler_csv
     # raise an error whose message matches 'found 9'.
     csv_bad = os.path.join(datafiles_path, 'output_bad_rows.csv')
     with self.assertRaisesRegex(Exception, 'found 9'):
         # The call is expected to raise, so its result is not bound.
         check_sampler_csv(csv_bad)
Exemplo n.º 16
0
 def test_check_sampler_csv_metric_2(self):
     # The 'output_bad_metric_2.csv' fixture must make check_sampler_csv
     # raise an error whose message matches 'invalid stepsize'.
     bad_stepsize_csv = os.path.join(
         DATAFILES_PATH, 'output_bad_metric_2.csv'
     )
     expected_msg = 'invalid stepsize'
     with self.assertRaisesRegex(Exception, expected_msg):
         check_sampler_csv(bad_stepsize_csv)
Exemplo n.º 17
0
 def test_check_sampler_csv_metric_4(self):
     # The 'output_bad_metric_4.csv' fixture must make check_sampler_csv
     # raise an error whose message matches
     # 'invalid or missing mass matrix specification'.
     bad_metric_csv = os.path.join(
         DATAFILES_PATH, 'output_bad_metric_4.csv'
     )
     expected_msg = 'invalid or missing mass matrix specification'
     with self.assertRaisesRegex(Exception, expected_msg):
         check_sampler_csv(bad_metric_csv)
Exemplo n.º 18
0
 def test_check_sampler_csv_2(self):
     # Passing the path 'no_such_file.csv' must make check_sampler_csv
     # raise.
     csv_bad = os.path.join(datafiles_path, 'no_such_file.csv')
     with self.assertRaises(Exception):
         # The call is expected to raise, so its result is not bound.
         check_sampler_csv(csv_bad)