Пример #1
0
 def test_variables(self):
     """Check Stan variable dims and ``stan_variables()`` for lotka-volterra.

     Builds a single-chain CmdStanMCMC from pre-generated sampler CSV output
     and verifies the recorded variable dimensions as well as the shapes of
     the arrays returned by ``stan_variables()``.
     """
     # construct fit using existing sampler output
     exe = os.path.join(DATAFILES_PATH, 'lotka-volterra' + EXTENSION)
     jdata = os.path.join(DATAFILES_PATH, 'lotka-volterra.data.json')
     sampler_args = SamplerArgs(iter_sampling=20)
     cmdstan_args = CmdStanArgs(
         model_name='lotka-volterra',
         model_exe=exe,
         chain_ids=[1],
         seed=12345,
         data=jdata,
         output_dir=DATAFILES_PATH,
         method_args=sampler_args,
     )
     runset = RunSet(args=cmdstan_args, chains=1)
     runset._csv_files = [
         os.path.join(DATAFILES_PATH, 'lotka-volterra.csv')
     ]
     runset._set_retcode(0, 0)
     fit = CmdStanMCMC(runset)
     self.assertEqual(20, fit.num_draws)
     self.assertEqual(8, len(fit._stan_variable_dims))
     # assertIn reports the container contents on failure
     self.assertIn('z', fit._stan_variable_dims)
     self.assertEqual(fit._stan_variable_dims['z'], (20, 2))
     # named stan_vars to avoid shadowing the builtin ``vars``
     stan_vars = fit.stan_variables()
     self.assertEqual(len(stan_vars), len(fit._stan_variable_dims))
     self.assertIn('z', stan_vars)
     # presumably a leading draws axis (num_draws == 20) then the var dims
     self.assertEqual(stan_vars['z'].shape, (20, 20, 2))
     self.assertIn('theta', stan_vars)
     self.assertEqual(stan_vars['theta'].shape, (20, 4))
Пример #2
0
    def test_check_retcodes(self):
        """RunSet retcodes start at -1 and pass the check only when all are 0."""
        model_exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
        data_file = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        method_args = SamplerArgs()
        chain_ids = [1, 2, 3, 4]  # default
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=model_exe,
            chain_ids=chain_ids,
            data=data_file,
            method_args=method_args,
        )
        runset = RunSet(args=cmdstan_args, chains=4)

        num_chains = len(runset._retcodes)
        self.assertEqual(4, num_chains)
        # A freshly built RunSet reports -1 for every chain.
        for chain in range(num_chains):
            self.assertEqual(-1, runset._retcode(chain))

        # Marking only the first chain as 0 is not enough to pass the check.
        runset._set_retcode(0, 0)
        self.assertEqual(0, runset._retcode(0))
        for chain in range(1, num_chains):
            self.assertEqual(-1, runset._retcode(chain))
        self.assertFalse(runset._check_retcodes())

        # With every chain at 0 the check passes.
        for chain in range(1, num_chains):
            runset._set_retcode(chain, 0)
        self.assertTrue(runset._check_retcodes())
Пример #3
0
 def _run_cmdstan(self,
                  runset: RunSet,
                  idx: int = 0,
                  pbar: Any = None) -> None:
     """
     Encapsulates call to CmdStan.

     Spawn the CmdStan process for chain ``idx``, capture its console
     output to the per-chain stdout/stderr files of ``runset``, and record
     the process return code on ``runset``.

     :param runset: RunSet supplying the command (``runset.cmds[idx]``)
         and the output paths (``stdout_files`` / ``stderr_files``).
     :param idx: zero-based chain index; ``idx + 1`` is used in log lines.
     :param pbar: optional progress bar; when truthy,
         ``self._read_progress`` is called while the process runs and its
         result is prepended to the captured stdout.
     """
     cmd = runset.cmds[idx]
     self._logger.info('start chain %u', idx + 1)
     self._logger.debug('threads: %s',
                        str(os.environ.get('STAN_NUM_THREADS')))
     self._logger.debug('sampling: %s', cmd)
     proc = subprocess.Popen(cmd,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             env=os.environ)
     if pbar:
         # presumably returns the stdout bytes it consumed while updating
         # the progress bar; they are concatenated with the rest below
         stdout_pbar = self._read_progress(proc, pbar, idx)
     stdout, stderr = proc.communicate()
     if pbar:
         stdout = stdout_pbar + stdout
     self._logger.info('finish chain %u', idx + 1)
     # The bytes are decoded as UTF-8, so write them back as UTF-8 too;
     # the locale default (e.g. cp1252 on Windows) can fail to encode them.
     if stdout:
         with open(runset.stdout_files[idx], 'w+', encoding='utf-8') as fd:
             fd.write(stdout.decode('utf-8'))
     if stderr:
         with open(runset.stderr_files[idx], 'w+', encoding='utf-8') as fd:
             fd.write(stderr.decode('utf-8'))
     runset._set_retcode(idx, proc.returncode)
Пример #4
0
 def _run_cmdstan(
     self, runset: RunSet, idx: int = 0, pbar: List[Any] = None
 ) -> None:
     """
     Encapsulates call to cmdstan.

     Spawn the CmdStan process for chain ``idx``, write its console output
     to ``runset.console_files[idx]`` (stdout first, then stderr preceded
     by the literal marker ``ERROR``), and record the process return code
     on ``runset``.

     :param runset: RunSet supplying the command (``runset.cmds[idx]``)
         and the transcript path (``runset.console_files[idx]``).
     :param idx: zero-based chain index; ``idx + 1`` is used in log lines.
     :param pbar: optional progress bar(s); when truthy,
         ``self._read_progress`` is called while the process runs and its
         result is prepended to the captured stdout.
         NOTE(review): the default is None, so the annotation is
         effectively ``Optional[List[Any]]``.
     """
     cmd = runset.cmds[idx]
     self._logger.info('start chain %u', idx + 1)
     self._logger.debug('sampling: %s', cmd)
     proc = subprocess.Popen(
         cmd,
         stdout=subprocess.PIPE,
         stderr=subprocess.PIPE,
         env=os.environ,
     )
     if pbar:
         # presumably returns the stdout bytes it consumed while updating
         # the progress bar; they are concatenated with the rest below
         stdout_pbar = self._read_progress(proc, pbar, idx)
     stdout, stderr = proc.communicate()
     if pbar:
         stdout = stdout_pbar + stdout
     transcript_file = runset.console_files[idx]
     self._logger.info('finish chain %u', idx + 1)
     # The bytes are decoded as UTF-8, so write them back as UTF-8 too;
     # the locale default (e.g. cp1252 on Windows) can fail to encode them.
     with open(transcript_file, 'w+', encoding='utf-8') as transcript:
         if stdout:
             transcript.write(stdout.decode('utf-8'))
         if stderr:
             transcript.write('ERROR')
             transcript.write(stderr.decode('utf-8'))
     runset._set_retcode(idx, proc.returncode)
Пример #5
0
    def test_good(self):
        """Check InferenceMetadata parsed from a good bernoulli sampler CSV.

        A RunSet is pointed at the four pre-generated ``runset-good`` CSV
        files; one of them is parsed with ``check_sampler_csv`` and the
        resulting ``InferenceMetadata`` is checked for its string form, its
        config, the HMC sampler diagnostic columns, and the dims/columns of
        the single model variable ``theta``.
        """
        # construct fit using existing sampler output
        exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
        jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        sampler_args = SamplerArgs(iter_sampling=100,
                                   max_treedepth=11,
                                   adapt_delta=0.95)
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=exe,
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=jdata,
            output_dir=DATAFILES_PATH,
            method_args=sampler_args,
        )
        runset = RunSet(args=cmdstan_args)
        runset._csv_files = [
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-1.csv'),
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-2.csv'),
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-3.csv'),
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-4.csv'),
        ]
        for i in range(len(runset._retcodes)):
            runset._set_retcode(i, 0)

        # Index the CSV list explicitly instead of reusing the loop variable
        # above: that only picks the right file when the number of retcodes
        # matches the number of CSVs, and raises NameError when it is zero.
        config = check_sampler_csv(
            path=runset.csv_files[-1],
            is_fixed_param=False,
            iter_sampling=100,
            iter_warmup=1000,
            save_warmup=False,
            thin=1,
        )
        expected = 'Metadata:\n{}\n'.format(config)
        metadata = InferenceMetadata(config)
        actual = '{}'.format(metadata)
        self.assertEqual(expected, actual)
        self.assertEqual(config, metadata.cmdstan_config)

        hmc_vars = {
            'lp__',
            'accept_stat__',
            'stepsize__',
            'treedepth__',
            'n_leapfrog__',
            'divergent__',
            'energy__',
        }

        sampler_vars_cols = metadata.sampler_vars_cols
        self.assertEqual(hmc_vars, sampler_vars_cols.keys())
        bern_model_vars = {'theta'}
        self.assertEqual(bern_model_vars, metadata.stan_vars_dims.keys())
        # theta is a scalar: no dims, a single column right after the
        # seven sampler diagnostic columns
        self.assertEqual((), metadata.stan_vars_dims['theta'])
        self.assertEqual(bern_model_vars, metadata.stan_vars_cols.keys())
        self.assertEqual((7, ), metadata.stan_vars_cols['theta'])
Пример #6
0
    def test_validate_good_run(self):
        """Build a CmdStanMCMC from the good bernoulli CSVs and validate it.

        Checks draw count, column names, drawset shape, that ``summary()``
        runs, and that ``diagnose()`` reports no sampler problems.
        """
        # construct fit using existing sampler output
        good_dir = os.path.join(DATAFILES_PATH, 'runset-good')
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION),
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=os.path.join(DATAFILES_PATH, 'bernoulli.data.json'),
            output_dir=DATAFILES_PATH,
            method_args=SamplerArgs(iter_sampling=100,
                                    max_treedepth=11,
                                    adapt_delta=0.95),
        )
        runset = RunSet(args=cmdstan_args, chains=4)
        runset._csv_files = [
            os.path.join(good_dir, 'bern-{}.csv'.format(chain))
            for chain in range(1, 5)
        ]
        self.assertEqual(4, runset.chains)
        for chain_idx in range(len(runset._retcodes)):
            runset._set_retcode(chain_idx, 0)
        self.assertTrue(runset._check_retcodes())

        fit = CmdStanMCMC(runset)
        self.assertEqual(100, fit.num_draws)
        self.assertEqual(len(BERNOULLI_COLS), len(fit.column_names))
        self.assertEqual('lp__', fit.column_names[0])

        # draws from all chains are stacked row-wise
        expected_shape = (fit.runset.chains * fit.num_draws,
                          len(fit.column_names))
        self.assertEqual(fit.get_drawset().shape, expected_shape)

        # smoke test: summary() must not raise
        fit.summary()

        # TODO - use cmdstan test files instead
        expected_lines = [
            'Checking sampler transitions treedepth.',
            'Treedepth satisfactory for all transitions.',
            '\nChecking sampler transitions for divergences.',
            'No divergent transitions found.',
            '\nChecking E-BFMI - sampler transitions HMC potential energy.',
            'E-BFMI satisfactory for all transitions.',
            '\nEffective sample size satisfactory.',
        ]
        # normalise Windows line endings before the substring check
        diagnostics = fit.diagnose().replace('\r\n', '\n')
        self.assertIn('\n'.join(expected_lines), diagnostics)
Пример #7
0
    def test_validate_summary_sig_figs(self):
        """``summary()`` honours ``sig_figs`` and rejects out-of-range values."""
        # construct CmdStanMCMC from logistic model output, config
        cmdstan_args = CmdStanArgs(
            model_name='logistic',
            model_exe=os.path.join(DATAFILES_PATH, 'logistic' + EXTENSION),
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=os.path.join(DATAFILES_PATH, 'logistic.data.R'),
            output_dir=DATAFILES_PATH,
            sig_figs=17,
            method_args=SamplerArgs(iter_sampling=100),
        )
        runset = RunSet(args=cmdstan_args)
        runset._csv_files = [
            os.path.join(DATAFILES_PATH,
                         'logistic_output_{}.csv'.format(chain))
            for chain in range(1, 5)
        ]
        for chain_idx in range(len(runset._retcodes)):
            runset._set_retcode(chain_idx, 0)
        fit = CmdStanMCMC(runset)

        def cell(summary):
            # row 1, column 0 of the summary, printed at full precision
            return format(summary.iloc[1, 0], '.18g')

        self.assertTrue(cell(fit.summary()).startswith('1.3'))

        if cmdstan_version_at(2, 25):
            prefixes = {17: '1.345767078273', 10: '1.34576707'}
            for figs, prefix in prefixes.items():
                value = cell(fit.summary(sig_figs=figs))
                self.assertTrue(value.startswith(prefix))

        for bad_figs in (20, -1):
            with self.assertRaises(ValueError):
                fit.summary(sig_figs=bad_figs)
Пример #8
0
 def test_get_err_msgs(self):
     """Error messages are collected from the stdout files of failed chains."""
     cmdstan_args = CmdStanArgs(
         model_name='logistic',
         model_exe=os.path.join(DATAFILES_PATH, 'logistic' + EXTENSION),
         chain_ids=[1, 2, 3],
         data=os.path.join(DATAFILES_PATH, 'logistic.data.R'),
         method_args=SamplerArgs(),
     )
     runset = RunSet(args=cmdstan_args, chains=3)
     for chain_idx in range(3):
         # non-zero retcode, with a canned stdout transcript per chain
         runset._set_retcode(chain_idx, 70)
         runset._stdout_files[chain_idx] = os.path.join(
             DATAFILES_PATH,
             'chain-{}-missing-data-stdout.txt'.format(chain_idx + 1),
         )
     messages = runset._get_err_msgs()
     self.assertIn('Exception', '\n\t'.join(messages))
Пример #9
0
 def test_variables_3d(self):
     """Check shapes of scalar, vector and 3-d Stan variables.

     Builds a single-chain CmdStanMCMC from the pre-generated
     ``multidim_vars.csv`` output and checks ``stan_vars_dims`` plus the
     shapes returned by ``stan_variable()`` and ``stan_variables()``.
     """
     # construct fit using existing sampler output
     exe = os.path.join(DATAFILES_PATH, 'multidim_vars' + EXTENSION)
     # NOTE(review): this is the logistic model's data file; presumably
     # unused because the sampler CSV is pre-generated - confirm
     jdata = os.path.join(DATAFILES_PATH, 'logistic.data.R')
     sampler_args = SamplerArgs(iter_sampling=20)
     cmdstan_args = CmdStanArgs(
         model_name='multidim_vars',
         model_exe=exe,
         chain_ids=[1],
         seed=12345,
         data=jdata,
         output_dir=DATAFILES_PATH,
         method_args=sampler_args,
     )
     runset = RunSet(args=cmdstan_args, chains=1)
     runset._csv_files = [os.path.join(DATAFILES_PATH, 'multidim_vars.csv')]
     runset._set_retcode(0, 0)
     fit = CmdStanMCMC(runset)
     self.assertEqual(20, fit.num_draws_sampling)
     self.assertEqual(3, len(fit.stan_vars_dims))
     # assertIn reports the container contents on failure
     self.assertIn('y_rep', fit.stan_vars_dims)
     self.assertEqual(fit.stan_vars_dims['y_rep'], (5, 4, 3))
     var_y_rep = fit.stan_variable(name='y_rep')
     self.assertEqual(var_y_rep.shape, (20, 5, 4, 3))
     var_beta = fit.stan_variable(name='beta')
     self.assertEqual(var_beta.shape, (20, 2))
     var_frac_60 = fit.stan_variable(name='frac_60')
     self.assertEqual(var_frac_60.shape, (20, ))
     # named stan_vars to avoid shadowing the builtin ``vars``
     stan_vars = fit.stan_variables()
     self.assertEqual(len(stan_vars), len(fit.stan_vars_dims))
     self.assertIn('y_rep', stan_vars)
     self.assertEqual(stan_vars['y_rep'].shape, (20, 5, 4, 3))
     self.assertIn('beta', stan_vars)
     self.assertEqual(stan_vars['beta'].shape, (20, 2))
     self.assertIn('frac_60', stan_vars)
     self.assertEqual(stan_vars['frac_60'].shape, (20, ))
Пример #10
0
    def test_validate_bad_run(self):
        """Error transcripts and malformed CSVs are reported for a RunSet.

        All chains exit with retcode 0, but error messages are still picked
        up from the stderr transcripts, and CmdStanMCMC construction fails
        on inconsistent headers, bad draws and header/draw column mismatch.
        """
        bad_dir = os.path.join(DATAFILES_PATH, 'runset-bad')

        def per_chain(template):
            # one path per chain 1..4 inside the runset-bad directory
            return [os.path.join(bad_dir, template.format(chain))
                    for chain in range(1, 5)]

        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION),
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=os.path.join(DATAFILES_PATH, 'bernoulli.data.json'),
            output_dir=DATAFILES_PATH,
            method_args=SamplerArgs(max_treedepth=11, adapt_delta=0.95),
        )
        runset = RunSet(args=cmdstan_args, chains=4)

        # every chain reports a zero return code ...
        for chain_idx in range(4):
            runset._set_retcode(chain_idx, 0)
        self.assertTrue(runset._check_retcodes())

        # ... yet each stderr transcript contains an error message
        runset._stderr_files = per_chain('bad-transcript-bern-{}.txt')
        self.assertEqual(len(runset._get_err_msgs()), 4)

        # csv file headers inconsistent
        runset._csv_files = per_chain('bad-hdr-bern-{}.csv')
        with self.assertRaisesRegex(ValueError, 'header mismatch'):
            CmdStanMCMC(runset)

        # bad draws
        runset._csv_files = per_chain('bad-draws-bern-{}.csv')
        with self.assertRaisesRegex(ValueError, 'draws'):
            CmdStanMCMC(runset)

        # mismatch - column headers, draws
        runset._csv_files = per_chain('bad-cols-bern-{}.csv')
        with self.assertRaisesRegex(ValueError,
                                    'bad draw, expecting 9 items, found 8'):
            CmdStanMCMC(runset)
Пример #11
0
    def test_validate_good_run(self):
        """Validate a CmdStanMCMC built from the good bernoulli CSVs.

        Covers draw count, column names, drawset shape, the default and
        custom percentile columns of ``summary()``, rejection of invalid
        percentiles, and the clean ``diagnose()`` report.
        """
        # construct fit using existing sampler output
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION),
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=os.path.join(DATAFILES_PATH, 'bernoulli.data.json'),
            output_dir=DATAFILES_PATH,
            method_args=SamplerArgs(iter_sampling=100,
                                    max_treedepth=11,
                                    adapt_delta=0.95),
        )
        runset = RunSet(args=cmdstan_args)
        runset._csv_files = [
            os.path.join(DATAFILES_PATH, 'runset-good',
                         'bern-{}.csv'.format(chain))
            for chain in range(1, 5)
        ]
        self.assertEqual(4, runset.chains)
        for chain_idx in range(len(runset._retcodes)):
            runset._set_retcode(chain_idx, 0)
        self.assertTrue(runset._check_retcodes())

        fit = CmdStanMCMC(runset)
        self.assertEqual(100, fit.num_draws)
        self.assertEqual(len(BERNOULLI_COLS), len(fit.column_names))
        self.assertEqual('lp__', fit.column_names[0])

        # draws from all chains are stacked row-wise
        self.assertEqual(
            fit.get_drawset().shape,
            (fit.runset.chains * fit.num_draws, len(fit.column_names)),
        )

        # default percentiles
        columns = list(fit.summary().columns)
        for label in ('5%', '50%', '95%'):
            self.assertIn(label, columns)
        for label in ('1%', '99%'):
            self.assertNotIn(label, columns)

        # caller-supplied percentiles replace the defaults
        columns = list(fit.summary(percentiles=[1, 45, 99]).columns)
        for label in ('1%', '45%', '99%'):
            self.assertIn(label, columns)
        for label in ('5%', '50%', '95%'):
            self.assertNotIn(label, columns)

        for bad_percentiles in ([], [-1]):
            with self.assertRaises(ValueError):
                fit.summary(percentiles=bad_percentiles)

        report = fit.diagnose()
        for message in (
            'Treedepth satisfactory for all transitions.',
            'No divergent transitions found.',
            'E-BFMI satisfactory for all transitions.',
            'Effective sample size satisfactory.',
        ):
            self.assertIn(message, report)
Пример #12
0
    def test_validate_bad_run(self):
        """Bad CmdStan output is detected for several failure modes.

        Console transcripts with errors raise from ``_check_console_msgs``;
        CSVs with inconsistent headers, bad draws, or mismatched column
        counts raise ``ValueError`` from ``_validate_csv_files``.
        """
        exe = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)
        jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
        sampler_args = SamplerArgs(sampling_iters=100,
                                   max_treedepth=11,
                                   adapt_delta=0.95)

        def make_runset(basename):
            # four-chain RunSet whose outputs live under badfiles_path
            args = CmdStanArgs(
                model_name='bernoulli',
                model_exe=exe,
                chain_ids=[1, 2, 3, 4],
                seed=12345,
                data=jdata,
                output_basename=os.path.join(badfiles_path, basename),
                method_args=sampler_args,
            )
            return RunSet(args=args, chains=4)

        def make_fit(basename):
            # mark every chain as returning 0, then wrap in CmdStanMCMC
            runset = make_runset(basename)
            for chain_idx in range(len(runset._retcodes)):
                runset._set_retcode(chain_idx, 0)
            self.assertTrue(runset._check_retcodes())
            return CmdStanMCMC(runset)

        # some chains had errors
        runset = make_runset('bad-transcript-bern')
        with self.assertRaisesRegex(Exception, 'Exception'):
            runset._check_console_msgs()

        # csv file headers inconsistent
        fit = make_fit('bad-hdr-bern')
        with self.assertRaisesRegex(ValueError, 'header mismatch'):
            fit._validate_csv_files()

        # bad draws
        fit = make_fit('bad-draws-bern')
        with self.assertRaisesRegex(ValueError, 'draws'):
            fit._validate_csv_files()

        # mismatch - column headers, draws
        fit = make_fit('bad-cols-bern')
        with self.assertRaisesRegex(ValueError, 'bad draw'):
            fit._validate_csv_files()
Пример #13
0
    def test_metadata(self):
        """CmdStanMCMC exposes run metadata parsed from the logistic CSVs.

        Checks chain info, draw counts, column names, sampler config
        entries, and the split between sampler diagnostic columns and
        Stan model variables.
        """
        # construct CmdStanMCMC from logistic model output, config
        cmdstan_args = CmdStanArgs(
            model_name='logistic',
            model_exe=os.path.join(DATAFILES_PATH, 'logistic' + EXTENSION),
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=os.path.join(DATAFILES_PATH, 'logistic.data.R'),
            output_dir=DATAFILES_PATH,
            sig_figs=17,
            method_args=SamplerArgs(iter_sampling=100),
        )
        runset = RunSet(args=cmdstan_args)
        runset._csv_files = [
            os.path.join(DATAFILES_PATH,
                         'logistic_output_{}.csv'.format(chain))
            for chain in range(1, 5)
        ]
        for chain_idx in range(len(runset._retcodes)):
            runset._set_retcode(chain_idx, 0)
        fit = CmdStanMCMC(runset)

        expected_columns = (
            'lp__',
            'accept_stat__',
            'stepsize__',
            'treedepth__',
            'n_leapfrog__',
            'divergent__',
            'energy__',
            'beta[1]',
            'beta[2]',
        )

        self.assertEqual(fit.chains, 4)
        self.assertEqual(fit.chain_ids, [1, 2, 3, 4])
        self.assertEqual(fit.num_draws_warmup, 1000)
        self.assertEqual(fit.num_draws_sampling, 100)
        self.assertEqual(fit.column_names, expected_columns)
        self.assertEqual(fit.num_unconstrained_params, 2)
        self.assertEqual(fit.metric_type, 'diag_e')

        config = fit.sampler_config
        self.assertEqual(config['num_samples'], 100)
        self.assertEqual(config['thin'], 1)
        self.assertEqual(config['algorithm'], 'hmc')
        self.assertEqual(config['metric'], 'diag_e')
        self.assertAlmostEqual(config['delta'], 0.80)

        # sampler diagnostics and model variables are tracked separately
        self.assertIn('n_leapfrog__', fit.sampler_vars_cols)
        self.assertIn('energy__', fit.sampler_vars_cols)
        self.assertNotIn('beta', fit.sampler_vars_cols)
        self.assertNotIn('energy__', fit.stan_vars_dims)
        self.assertIn('beta', fit.stan_vars_dims)
        self.assertIn('beta', fit.stan_vars_cols)
        self.assertEqual(fit.stan_vars_dims['beta'], (2,))
        self.assertEqual(fit.stan_vars_cols['beta'], (7, 8))