def test_multi_proc(self):
    # Run the logistic model with 4 chains on 1, 2 and 4 cores and verify
    # that the expected chain start/finish records were logged.
    logistic_stan = os.path.join(datafiles_path, 'logistic.stan')
    logistic_model = CmdStanModel(stan_file=logistic_stan)
    logistic_data = os.path.join(datafiles_path, 'logistic.data.R')
    # The multi-core expectations are only checked on hosts with >= 4 CPUs.
    enough_cpus = cpu_count() >= 4

    # single core
    with LogCapture() as log:
        logistic_model.sample(data=logistic_data, chains=4, cores=1)
    log.check_present(
        ('cmdstanpy', 'INFO', 'finish chain 1'),
        ('cmdstanpy', 'INFO', 'start chain 2'),
    )

    # two cores
    with LogCapture() as log:
        logistic_model.sample(data=logistic_data, chains=4, cores=2)
    if enough_cpus:
        # finish chains 1, 2 before starting chains 3, 4
        log.check_present(
            ('cmdstanpy', 'INFO', 'finish chain 1'),
            ('cmdstanpy', 'INFO', 'start chain 4'),
        )

    # four cores
    if enough_cpus:
        with LogCapture() as log:
            logistic_model.sample(data=logistic_data, chains=4, cores=4)
        log.check_present(
            ('cmdstanpy', 'INFO', 'start chain 4'),
            ('cmdstanpy', 'INFO', 'finish chain 1'),
        )
def test_bernoulli_bad(self):
    # Calling sample() with no data argument is expected to raise an error
    # whose message matches 'Error during sampling'.
    model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    model = CmdStanModel(stan_file=model_path)
    sampler_kwargs = dict(chains=4, cores=2, seed=12345, sampling_iters=100)
    with self.assertRaisesRegex(Exception, 'Error during sampling'):
        model.sample(**sampler_kwargs)
def test_save_csv(self):
    # Exercise save_csvfiles(): save into an explicit directory, check that
    # a second save to the same directory is rejected, save into a scratch
    # directory, and finally save with the default directory.
    stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    bern_model = CmdStanModel(stan_file=stan)

    def remove_run_files(fit):
        # Delete the csv files of ``fit`` plus any stdout/stderr files
        # that still exist on disk.
        for i in range(fit.runset.chains):
            os.remove(fit.runset.csv_files[i])
            if os.path.exists(fit.runset.stdout_files[i]):
                os.remove(fit.runset.stdout_files[i])
            if os.path.exists(fit.runset.stderr_files[i]):
                os.remove(fit.runset.stderr_files[i])

    bern_fit = bern_model.sample(
        data=jdata,
        chains=2,
        parallel_chains=2,
        seed=12345,
        iter_sampling=200,
    )
    for i in range(bern_fit.runset.chains):
        csv_file = bern_fit.runset.csv_files[i]
        stdout_file = bern_fit.runset.stdout_files[i]
        self.assertTrue(os.path.exists(csv_file))
        self.assertTrue(os.path.exists(stdout_file))

    # save files to good dir
    bern_fit.save_csvfiles(dir=DATAFILES_PATH)
    for i in range(bern_fit.runset.chains):
        csv_file = bern_fit.runset.csv_files[i]
        self.assertTrue(os.path.exists(csv_file))
    with self.assertRaisesRegex(Exception, 'file exists'):
        bern_fit.save_csvfiles(dir=DATAFILES_PATH)

    tmp2_dir = os.path.join(HERE, 'tmp2')
    os.mkdir(tmp2_dir)
    try:
        bern_fit.save_csvfiles(dir=tmp2_dir)
        for i in range(bern_fit.runset.chains):
            csv_file = bern_fit.runset.csv_files[i]
            self.assertTrue(os.path.exists(csv_file))
        # cleanup the files produced by this run
        remove_run_files(bern_fit)
    finally:
        # Always drop the scratch dir: a leftover 'tmp2' would make the
        # os.mkdir() above fail with FileExistsError on the next run.
        shutil.rmtree(tmp2_dir, ignore_errors=True)

    # regenerate to tmpdir, save to good dir
    bern_fit = bern_model.sample(
        data=jdata,
        chains=2,
        parallel_chains=2,
        seed=12345,
        iter_sampling=200,
    )
    bern_fit.save_csvfiles()  # default dir
    for i in range(bern_fit.runset.chains):
        csv_file = bern_fit.runset.csv_files[i]
        self.assertTrue(os.path.exists(csv_file))
    # cleanup default dir
    remove_run_files(bern_fit)
def test_custom_seed(self):
    # Smoke test: sampling with a list of seeds (one entry per chain here)
    # should complete without raising.
    model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    model = CmdStanModel(stan_file=model_path)
    chain_seeds = [44444, 55555]
    model.sample(
        data=data_path,
        chains=2,
        parallel_chains=2,
        seed=chain_seeds,
        iter_sampling=200,
    )
def test_custom_metric(self):
    # Smoke test: sampling with a metric supplied as a JSON file path
    # should complete without raising.
    model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    metric_path = os.path.join(DATAFILES_PATH, 'bernoulli.metric.json')
    model = CmdStanModel(stan_file=model_path)
    sampler_kwargs = dict(
        data=data_path,
        chains=2,
        cores=2,
        seed=12345,
        iter_sampling=200,
        metric=metric_path,
    )
    model.sample(**sampler_kwargs)
def test_deprecated(self):
    # Accessing the ``sample`` and ``warmup`` attributes must still return
    # arrays of the expected shape and must log a deprecation warning.
    model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    model = CmdStanModel(stan_file=model_path)
    fit = model.sample(
        data=data_path,
        chains=2,
        seed=12345,
        iter_warmup=200,
        iter_sampling=100,
        save_warmup=True,
    )
    n_cols = len(BERNOULLI_COLS)

    sample_warning = (
        'cmdstanpy',
        'WARNING',
        'method "sample" will be deprecated,'
        ' use method "draws" instead.',
    )
    with LogCapture() as captured:
        self.assertEqual(fit.sample.shape, (100, 2, n_cols))
    captured.check_present(sample_warning)

    warmup_warning = (
        'cmdstanpy',
        'WARNING',
        'method "warmup" has been deprecated, instead use method'
        ' "draws(inc_warmup=True)", returning draws from both'
        ' warmup and sampling iterations.',
    )
    with LogCapture() as captured:
        self.assertEqual(fit.warmup.shape, (300, 2, n_cols))
    captured.check_present(warmup_warning)
def test_show_console(self):
    # With show_console=True, generate_quantities() should write a
    # 'method = generate' line for each of the four chains to stdout.
    fit_model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    fit_model = CmdStanModel(stan_file=fit_model_path)
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    fit = fit_model.sample(
        data=data_path,
        chains=4,
        parallel_chains=2,
        seed=12345,
        iter_sampling=100,
    )

    gq_model_path = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
    gq_model = CmdStanModel(stan_file=gq_model_path)

    # Capture everything printed to stdout while the GQ run executes.
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        gq_model.generate_quantities(
            data=data_path,
            mcmc_sample=fit,
            show_console=True,
        )
    console = buffer.getvalue()

    expected_lines = (
        'Chain [1] method = generate',
        'Chain [2] method = generate',
        'Chain [3] method = generate',
        'Chain [4] method = generate',
    )
    for expected in expected_lines:
        self.assertTrue(expected in console)
def test_no_xarray(self):
    # With xarray made unimportable, draws_xr() on a generate_quantities
    # result must raise RuntimeError.
    with self.without_import('xarray', cmdstanpy.stanfit):
        # if this fails the testing framework is the problem
        with self.assertRaises(ImportError):
            import xarray as _  # noqa

        data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')

        fit_model = CmdStanModel(
            stan_file=os.path.join(DATAFILES_PATH, 'bernoulli.stan')
        )
        fit = fit_model.sample(
            data=data_path,
            chains=4,
            parallel_chains=2,
            seed=12345,
            iter_sampling=100,
        )

        gq_model = CmdStanModel(
            stan_file=os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
        )
        gqs = gq_model.generate_quantities(data=data_path, mcmc_sample=fit)

        with self.assertRaises(RuntimeError):
            gqs.draws_xr()
def test_sample_plus_quantities_dedup(self):
    # Both models define y_rep; stan_variable() on the GQ result must
    # return the GQ model's values, not the ones from the fitted sample.
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')

    # fitted_params - model GQ block: y_rep is PPC of theta
    ppc_model = CmdStanModel(
        stan_file=os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
    )
    fit = ppc_model.sample(
        data=data_path,
        chains=4,
        parallel_chains=2,
        seed=12345,
        iter_sampling=100,
    )

    # gq_model - y_rep[n] == y[n]
    dup_model = CmdStanModel(
        stan_file=os.path.join(DATAFILES_PATH, 'bernoulli_ppc_dup.stan')
    )
    gqs = dup_model.generate_quantities(data=data_path, mcmc_sample=fit)

    # check that models have different y_rep values
    assert_raises(
        AssertionError,
        assert_array_equal,
        fit.stan_variable(var='y_rep'),
        gqs.stan_variable(var='y_rep'),
    )

    # check that stan_variable returns values from gq model
    with open(data_path) as fd:
        observed = json.load(fd)['y']
    y_rep = gqs.stan_variable(var='y_rep')
    for idx in range(10):
        self.assertEqual(y_rep[0, idx], observed[idx])
def test_from_mcmc_sample_variables(self):
    # Check stan_variable()/stan_variables() on a generate_quantities
    # result built from an existing MCMC sample.
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    fit_model = CmdStanModel(
        stan_file=os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    )
    fit = fit_model.sample(
        data=data_path,
        chains=4,
        parallel_chains=2,
        seed=12345,
        iter_sampling=100,
    )

    gq_model = CmdStanModel(
        stan_file=os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
    )
    gqs = gq_model.generate_quantities(data=data_path, mcmc_sample=fit)

    # 4 chains x 100 draws = 400 rows
    self.assertEqual(gqs.stan_variable(var='theta').shape, (400,))
    self.assertEqual(gqs.stan_variable(var='y_rep').shape, (400, 10))

    # names that stan_variable() must reject
    with self.assertRaises(ValueError):
        gqs.stan_variable(var='eta')
    with self.assertRaises(ValueError):
        gqs.stan_variable(var='lp__')

    # stan_variables() must cover the variables of both the sample and the
    # generated quantities.
    all_vars = gqs.stan_variables()
    expected_names = set(gqs.mcmc_sample.metadata.stan_vars_cols.keys())
    expected_names |= set(gqs.metadata.stan_vars_cols.keys())
    self.assertEqual(expected_names, set(all_vars.keys()))
def test_sampler_diags(self):
    # sampler_variables() and sampler_diagnostics() must both return the
    # SAMPLER_STATE entries with one (draws, chains) array each; reading
    # ``sample`` must log its deprecation warning.
    model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    model = CmdStanModel(stan_file=model_path)
    fit = model.sample(
        data=data_path,
        chains=2,
        seed=12345,
        iter_warmup=100,
        iter_sampling=100,
    )
    per_var_shape = (100, 2)
    sample_shape = (100, 2, len(BERNOULLI_COLS))

    state = fit.sampler_variables()
    self.assertEqual(SAMPLER_STATE, list(state))
    for name in state:
        self.assertEqual(state[name].shape, per_var_shape)
    self.assertEqual(fit.sample.shape, sample_shape)

    with LogCapture() as captured:
        state = fit.sampler_diagnostics()
        self.assertEqual(SAMPLER_STATE, list(state))
        for name in state:
            self.assertEqual(state[name].shape, per_var_shape)
        # accessed inside the capture so the warning below is recorded
        self.assertEqual(fit.sample.shape, sample_shape)
    captured.check_present((
        'cmdstanpy',
        'WARNING',
        'method "sample" will be deprecated,'
        ' use method "draws" instead.',
    ))
def test_dont_save_warmup(self):
    # With save_warmup=False only sampling draws are available; asking for
    # warmup draws returns the same shape and logs a warning.
    model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    model = CmdStanModel(stan_file=model_path)
    fit = model.sample(
        data=data_path,
        chains=2,
        seed=12345,
        iter_warmup=200,
        iter_sampling=100,
        save_warmup=False,
    )
    expected_shape = (100, 2, len(BERNOULLI_COLS))

    self.assertEqual(fit.column_names, tuple(BERNOULLI_COLS))
    self.assertEqual(fit.num_draws, 100)
    self.assertEqual(fit.draws().shape, expected_shape)

    with LogCapture() as captured:
        with_warmup = fit.draws(inc_warmup=True)
        self.assertEqual(with_warmup.shape, expected_shape)
    captured.check_present((
        'cmdstanpy',
        'WARNING',
        'draws from warmup iterations not available,'
        ' must run sampler with "save_warmup=True".',
    ))
def test_from_mcmc_sample(self):
    # Run generate_quantities() from an MCMC sample and check the method,
    # the repr, the chain count, the return codes and the output files.
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')

    # fitted_params sample
    fit_model = CmdStanModel(
        stan_file=os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    )
    fit = fit_model.sample(
        data=data_path,
        chains=4,
        parallel_chains=2,
        seed=12345,
        iter_sampling=100,
    )

    # gq_model
    gq_model = CmdStanModel(
        stan_file=os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
    )
    gqs = gq_model.generate_quantities(data=data_path, mcmc_sample=fit)
    runset = gqs.runset

    self.assertEqual(runset._args.method, Method.GENERATE_QUANTITIES)
    description = repr(gqs)
    self.assertIn('CmdStanGQ: model=bernoulli_ppc', description)
    self.assertIn('method=generate_quantities', description)

    self.assertEqual(runset.chains, 4)
    for chain_idx in range(runset.chains):
        self.assertEqual(runset._retcode(chain_idx), 0)
        self.assertTrue(os.path.exists(runset.csv_files[chain_idx]))
def test_validate(self):
    # After sampling with validate_csv=False, several properties return
    # None and log a 'not yet validated' message; after
    # validate_csv_files() they report real values.
    model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    model = CmdStanModel(stan_file=model_path)
    fit = model.sample(
        data=data_path,
        chains=2,
        seed=12345,
        iter_warmup=200,
        iter_sampling=100,
        thin=2,
        save_warmup=True,
        validate_csv=False,
    )

    # check error messages
    expect = 'csv files not yet validated'
    unvalidated_attrs = (
        'column_names',
        'stan_variable_dims',
        'metric_type',
        'metric',
        'stepsize',
    )
    for attr in unvalidated_attrs:
        with LogCapture() as captured:
            logging.getLogger()
            self.assertIsNone(getattr(fit, attr))
        # the message field of the most recent captured record
        last_message = captured.actual()[-1][-1]
        self.assertTrue(last_message.startswith(expect))

    # check computations match
    self.assertEqual(fit.num_draws, 150)
    fit.validate_csv_files()
    self.assertEqual(fit.num_draws, 150)
    self.assertEqual(len(fit.column_names), 8)
    self.assertEqual(len(fit.stan_variable_dims), 1)
    self.assertEqual(fit.metric_type, 'diag_e')
def test_save_csv(self):
    # Exercise save_csvfiles() with an explicit basename: into an explicit
    # directory, a rejected second save, and a save to the default dir.
    model_path = os.path.join(datafiles_path, 'bernoulli.stan')
    data_path = os.path.join(datafiles_path, 'bernoulli.data.json')
    model = CmdStanModel(stan_file=model_path)
    sampler_kwargs = dict(
        data=data_path, chains=4, cores=2, seed=12345, sampling_iters=200
    )

    fit = model.sample(**sampler_kwargs)
    for chain_idx in range(fit.runset.chains):
        csv_path = fit.runset.csv_files[chain_idx]
        # the console output sits next to the csv with a .txt extension
        txt_path = os.path.splitext(csv_path)[0] + '.txt'
        self.assertTrue(os.path.exists(csv_path))
        self.assertTrue(os.path.exists(txt_path))

    # save files to good dir
    basename = 'bern_save_csvfiles_test'
    fit.save_csvfiles(dir=datafiles_path, basename=basename)
    for chain_idx in range(fit.runset.chains):
        self.assertTrue(os.path.exists(fit.runset.csv_files[chain_idx]))
    with self.assertRaisesRegex(Exception, 'file exists'):
        fit.save_csvfiles(dir=datafiles_path, basename=basename)
    # cleanup datafile_path dir
    for chain_idx in range(fit.runset.chains):
        os.remove(fit.runset.csv_files[chain_idx])
        os.remove(fit.runset.console_files[chain_idx])

    # regenerate to tmpdir, save to good dir
    fit = model.sample(**sampler_kwargs)
    fit.save_csvfiles(basename=basename)  # default dir
    for chain_idx in range(fit.runset.chains):
        self.assertTrue(os.path.exists(fit.runset.csv_files[chain_idx]))
    # cleanup default dir
    for chain_idx in range(fit.runset.chains):
        os.remove(fit.runset.csv_files[chain_idx])
        os.remove(fit.runset.console_files[chain_idx])
def test_bernoulli_bad(self):
    # sample() must fail when required data is missing or wrong, and must
    # reject an output_dir that cannot be created.
    stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    bern_model = CmdStanModel(stan_file=stan)

    # no data at all
    with self.assertRaisesRegex(RuntimeError, 'variable does not exist'):
        bern_model.sample(chains=2, cores=2, seed=12345, iter_sampling=100)

    # data that lacks the model's variables
    with self.assertRaisesRegex(RuntimeError, 'variable does not exist'):
        bern_model.sample(
            data={'foo': 1},
            chains=2,
            cores=2,
            seed=12345,
            iter_sampling=100,
        )

    if platform.system() != 'Windows':
        jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        dirname1 = 'tmp1' + str(time())
        # Octal literal: rw-r--r--. Without the execute bit no entry can be
        # created inside the directory. (The previous decimal ``644`` was
        # silently interpreted as mode 0o1204.)
        os.mkdir(dirname1, mode=0o644)
        try:
            dirname2 = 'tmp2' + str(time())
            path = os.path.join(dirname1, dirname2)
            with self.assertRaisesRegex(
                ValueError, 'invalid path for output files'
            ):
                bern_model.sample(data=jdata, chains=1, output_dir=path)
        finally:
            # remove the scratch directory even when the assertion fails
            os.rmdir(dirname1)
def test_sample_plus_quantities_dedup(self):
    # When the same model produces both the sample and the generated
    # quantities, the combined frame must have exactly as many columns as
    # the sample alone.
    model = CmdStanModel(
        stan_file=os.path.join(datafiles_path, 'bernoulli_ppc.stan')
    )
    data_path = os.path.join(datafiles_path, 'bernoulli.data.json')
    fit = model.sample(
        data=data_path, chains=4, cores=2, seed=12345, sampling_iters=100
    )
    gqs = model.generate_quantities(data=data_path, mcmc_sample=fit)

    combined_cols = gqs.sample_plus_quantities.shape[1]
    sample_cols = gqs.mcmc_sample.shape[1]
    self.assertEqual(combined_cols, sample_cols)
def test_sampler_diags(self):
    # sampler_diagnostics() must return the SAMPLER_STATE entries, each
    # holding a (draws, chains) array.
    model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    model = CmdStanModel(stan_file=model_path)
    fit = model.sample(
        data=data_path,
        chains=2,
        seed=12345,
        iter_warmup=100,
        iter_sampling=100,
    )

    state = fit.sampler_diagnostics()
    self.assertEqual(SAMPLER_STATE, list(state))
    for name in state:
        self.assertEqual(state[name].shape, (100, 2))
def test_from_mcmc_sample_draws(self):
    # Check the pandas and xarray views of a generate_quantities result,
    # both with and without the originating sample's columns.
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    fit_model = CmdStanModel(
        stan_file=os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    )
    fit = fit_model.sample(
        data=data_path,
        chains=4,
        parallel_chains=2,
        seed=12345,
        iter_sampling=100,
    )

    gq_model = CmdStanModel(
        stan_file=os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
    )
    gqs = gq_model.generate_quantities(data=data_path, mcmc_sample=fit)

    # draws_pd
    self.assertEqual(gqs.draws_pd().shape, (400, 10))
    combined_width = gqs.draws_pd(inc_sample=True).shape[1]
    sample_width = gqs.mcmc_sample.draws_pd().shape[1]
    gq_width = gqs.draws_pd().shape[1]
    self.assertEqual(combined_width, sample_width + gq_width)

    # the first combined row is the sample row followed by the GQ row
    first_sample_row = fit.draws_pd().iloc[0]
    first_gq_row = gqs.draws_pd().iloc[0]
    stitched = pd.concat((first_sample_row, first_gq_row), axis=0).values
    first_combined_row = gqs.draws_pd(inc_sample=True).iloc[0].values
    self.assertTrue(np.array_equal(stitched, first_combined_row))

    # draws_xr
    y_rep_dims = ('chain', 'draw', 'y_rep_dim_0')
    y_rep_shape = (4, 100, 10)
    for xr_kwargs in ({}, {'vars': 'y_rep'}, {'vars': ['y_rep']}):
        dataset = gqs.draws_xr(**xr_kwargs)
        self.assertEqual(dataset.y_rep.dims, y_rep_dims)
        self.assertEqual(dataset.y_rep.values.shape, y_rep_shape)

    with_sample = gqs.draws_xr(inc_sample=True)
    self.assertEqual(with_sample.y_rep.dims, y_rep_dims)
    self.assertEqual(with_sample.y_rep.values.shape, y_rep_shape)
    self.assertEqual(with_sample.theta.dims, ('chain', 'draw'))
    self.assertEqual(with_sample.theta.values.shape, (4, 100))
def test_init_types(self):
    # Numeric ``inits`` values show up in the runset repr; a tuple or a
    # negative number must be rejected with ValueError.
    model = CmdStanModel(
        stan_file=os.path.join(datafiles_path, 'bernoulli.stan')
    )
    data_path = os.path.join(datafiles_path, 'bernoulli.data.json')
    common = dict(
        data=data_path, chains=4, cores=2, seed=12345, sampling_iters=100
    )

    fit = model.sample(inits=1.1, **common)
    self.assertIn('init=1.1', repr(fit.runset))

    fit = model.sample(inits=1, **common)
    self.assertIn('init=1', repr(fit.runset))

    for bad_inits in ((1, 2), -1):
        with self.assertRaises(ValueError):
            model.sample(inits=bad_inits, **common)
def test_gen_quanties_mcmc_sample(self):
    # Run generate_quantities() from an MCMC sample and check the run
    # metadata, output files, column names and frame shapes.
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    fit_model = CmdStanModel(
        stan_file=os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    )
    fit = fit_model.sample(
        data=data_path,
        chains=4,
        parallel_chains=2,
        seed=12345,
        iter_sampling=100,
    )

    gq_model = CmdStanModel(
        stan_file=os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
    )
    gqs = gq_model.generate_quantities(data=data_path, mcmc_sample=fit)
    runset = gqs.runset

    self.assertEqual(runset._args.method, Method.GENERATE_QUANTITIES)
    description = repr(gqs)
    self.assertIn('CmdStanGQ: model=bernoulli_ppc', description)
    self.assertIn('method=generate_quantities', description)

    # check results - output files, quantities of interest, draws
    self.assertEqual(runset.chains, 4)
    for chain_idx in range(runset.chains):
        self.assertEqual(runset._retcode(chain_idx), 0)
        self.assertTrue(os.path.exists(runset.csv_files[chain_idx]))

    # 'y_rep[1]' ... 'y_rep[10]'
    expected_columns = tuple(
        'y_rep[{}]'.format(k) for k in range(1, 11)
    )
    self.assertEqual(gqs.column_names, expected_columns)
    self.assertEqual(fit.draws_pd().shape, gqs.mcmc_sample.shape)

    sample_width = gqs.mcmc_sample.shape[1]
    gq_width = gqs.generated_quantities_pd.shape[1]
    self.assertEqual(
        gqs.sample_plus_quantities.shape[1], sample_width + gq_width
    )
def test_check_sampler_csv_thin(self):
    # check_sampler_csv() must report the thinned run's configuration and
    # must raise when the expected thinning does not match the file.
    stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    bern_model = CmdStanModel(stan_file=stan)
    bern_model.compile()
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    bern_fit = bern_model.sample(
        data=jdata,
        chains=1,
        parallel_chains=1,
        seed=12345,
        iter_sampling=490,
        iter_warmup=490,
        thin=7,
        max_treedepth=11,
        adapt_delta=0.98,
    )
    csv_file = bern_fit.runset.csv_files[0]

    # Named ``config`` rather than ``dict`` so the builtin is not shadowed.
    config = check_sampler_csv(
        path=csv_file,
        is_fixed_param=False,
        iter_sampling=490,
        iter_warmup=490,
        thin=7,
    )
    self.assertEqual(config['num_samples'], 490)
    self.assertEqual(config['thin'], 7)
    # 490 sampling iterations thinned by 7 -> 70 draws
    self.assertEqual(config['draws_sampling'], 70)
    self.assertEqual(config['seed'], 12345)
    self.assertEqual(config['max_depth'], 11)
    self.assertEqual(config['delta'], 0.98)

    # wrong thin value
    with self.assertRaisesRegex(ValueError, 'config error'):
        check_sampler_csv(
            path=csv_file,
            is_fixed_param=False,
            iter_sampling=490,
            iter_warmup=490,
            thin=9,
        )
    # thin omitted
    with self.assertRaisesRegex(ValueError, 'expected 490 draws, found 70'):
        check_sampler_csv(
            path=csv_file,
            is_fixed_param=False,
            iter_sampling=490,
            iter_warmup=490,
        )
def test_sample_plus_quantities_dedup(self):
    # When the same model produces both the sample and the generated
    # quantities, the combined frame must have exactly as many columns as
    # the sample alone.
    model = CmdStanModel(
        stan_file=os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
    )
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    fit = model.sample(
        data=data_path,
        chains=4,
        parallel_chains=2,
        seed=12345,
        iter_sampling=100,
    )
    gqs = model.generate_quantities(data=data_path, mcmc_sample=fit)

    combined_cols = gqs.sample_plus_quantities.shape[1]
    sample_cols = gqs.mcmc_sample.shape[1]
    self.assertEqual(combined_cols, sample_cols)
def test_gen_quanties_mcmc_sample(self):
    # Run generate_quantities() from an MCMC sample and check the run
    # metadata, output files, column names and frame shapes.
    data_path = os.path.join(datafiles_path, 'bernoulli.data.json')
    fit_model = CmdStanModel(
        stan_file=os.path.join(datafiles_path, 'bernoulli.stan')
    )
    fit = fit_model.sample(
        data=data_path, chains=4, cores=2, seed=12345, sampling_iters=100
    )

    gq_model = CmdStanModel(
        stan_file=os.path.join(datafiles_path, 'bernoulli_ppc.stan')
    )
    gqs = gq_model.generate_quantities(data=data_path, mcmc_sample=fit)
    runset = gqs.runset

    self.assertEqual(runset._args.method, Method.GENERATE_QUANTITIES)
    description = repr(gqs)
    self.assertIn('CmdStanGQ: model=bernoulli_ppc', description)
    self.assertIn('method=generate_quantities', description)

    # check results - output files, quantities of interest, draws
    self.assertEqual(runset.chains, 4)
    for chain_idx in range(runset.chains):
        self.assertEqual(runset._retcode(chain_idx), 0)
        self.assertTrue(os.path.exists(runset.csv_files[chain_idx]))

    # 'y_rep.1' ... 'y_rep.10'
    expected_columns = tuple(
        'y_rep.{}'.format(k) for k in range(1, 11)
    )
    self.assertEqual(gqs.column_names, expected_columns)
    self.assertEqual(fit.get_drawset().shape, gqs.mcmc_sample.shape)

    sample_width = gqs.mcmc_sample.shape[1]
    gq_width = gqs.generated_quantities_pd.shape[1]
    self.assertEqual(
        gqs.sample_plus_quantities.shape[1], sample_width + gq_width
    )
def test_dont_save_warmup(self):
    # With save_warmup=False the fit must report zero warmup draws, a None
    # ``warmup`` and only the sampling draws in ``sample``.
    model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    model = CmdStanModel(stan_file=model_path)
    sampler_kwargs = dict(
        data=data_path,
        chains=2,
        seed=12345,
        iter_warmup=200,
        iter_sampling=100,
        save_warmup=False,
    )
    fit = model.sample(**sampler_kwargs)

    expected_columns = tuple(BERNOULLI_COLS)
    self.assertEqual(fit.column_names, expected_columns)
    self.assertEqual(fit.num_draws_warmup, 0)
    self.assertEqual(fit.warmup, None)
    self.assertEqual(fit.num_draws, 100)
    self.assertEqual(fit.sample.shape, (100, 2, len(BERNOULLI_COLS)))
def test_variable_bern(self):
    # The bernoulli fit exposes exactly one Stan variable, theta;
    # stan_variable() returns its draws and rejects other names.
    model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    model = CmdStanModel(stan_file=model_path)
    fit = model.sample(
        data=data_path,
        chains=2,
        seed=12345,
        iter_warmup=100,
        iter_sampling=100,
    )

    var_dims = fit._stan_variable_dims
    self.assertEqual(1, len(var_dims))
    self.assertTrue('theta' in var_dims)
    self.assertEqual(var_dims['theta'], 1)

    # 2 chains x 100 draws = 200 values
    theta_draws = fit.stan_variable(name='theta')
    self.assertEqual(theta_draws.shape, (200, ))

    for rejected_name in ('eta', 'lp__'):
        with self.assertRaises(ValueError):
            fit.stan_variable(name=rejected_name)
def test_fixed_param_good(self):
    # Sample the data-generating model with fixed_param=True and check the
    # output files, column names, draw count and the absence of adaptation
    # results (metric, metric_type, stepsize).
    model_path = os.path.join(DATAFILES_PATH, 'datagen_poisson_glm.stan')
    datagen_model = CmdStanModel(stan_file=model_path)
    no_data = {}
    datagen_fit = datagen_model.sample(
        data=no_data, seed=12345, sampling_iters=100, fixed_param=True
    )
    runset = datagen_fit.runset
    self.assertEqual(runset._args.method, Method.SAMPLE)

    for chain_idx in range(runset.chains):
        csv_path = runset.csv_files[chain_idx]
        # the console output sits next to the csv with a .txt extension
        txt_path = os.path.splitext(csv_path)[0] + '.txt'
        self.assertTrue(os.path.exists(csv_path))
        self.assertTrue(os.path.exists(txt_path))
    self.assertEqual(runset.chains, 1)

    def indexed(prefix):
        # '<prefix>.1' ... '<prefix>.20'
        return ['{}.{}'.format(prefix, k) for k in range(1, 21)]

    column_names = (
        ['lp__', 'accept_stat__', 'N']
        + indexed('y_sim')
        + indexed('x_sim')
        + indexed('pop_sim')
        + ['alpha_sim', 'beta_sim']
        + indexed('eta')
    )
    self.assertEqual(datagen_fit.column_names, tuple(column_names))
    self.assertEqual(datagen_fit.draws, 100)
    self.assertEqual(datagen_fit.sample.shape, (100, 1, len(column_names)))
    self.assertEqual(datagen_fit.metric, None)
    self.assertEqual(datagen_fit.metric_type, None)
    self.assertEqual(datagen_fit.stepsize, None)
def test_save_warmup_thin(self):
    # With thin=5 and save_warmup=True: 100 sampling iterations give 20
    # draws, and including the 200 warmup iterations gives 60.
    model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    model = CmdStanModel(stan_file=model_path)
    fit = model.sample(
        data=data_path,
        chains=2,
        seed=12345,
        iter_warmup=200,
        iter_sampling=100,
        thin=5,
        save_warmup=True,
    )
    n_cols = len(BERNOULLI_COLS)

    self.assertEqual(fit.column_names, tuple(BERNOULLI_COLS))
    self.assertEqual(fit.num_draws, 60)

    sampling_only = fit.draws()
    self.assertEqual(sampling_only.shape, (20, 2, n_cols))

    with_warmup = fit.draws(inc_warmup=True)
    self.assertEqual(with_warmup.shape, (60, 2, n_cols))
def test_adapt_schedule(self):
    # The adapt_* arguments passed to sample() must appear in the chain's
    # stdout file as init_buffer / window / term_buffer settings.
    model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    model = CmdStanModel(stan_file=model_path)
    fit = model.sample(
        data=data_path,
        chains=1,
        seed=12345,
        iter_sampling=200,
        iter_warmup=200,
        adapt_init_phase=11,
        adapt_metric_window=12,
        adapt_step_size=13,
    )

    stdout_path = fit.runset.stdout_files[0]
    with open(stdout_path, 'r') as fd:
        console_lines = [raw.strip() for raw in fd]

    for setting in ('init_buffer = 11', 'window = 12', 'term_buffer = 13'):
        self.assertIn(setting, console_lines)
def test_single_row_csv(self):
    # A sample with a single draw must still yield correctly shaped
    # generated quantities, both as an ndarray and as an xarray dataset.
    fit_model = CmdStanModel(
        stan_file=os.path.join(DATAFILES_PATH, 'bernoulli.stan')
    )
    data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
    fit = fit_model.sample(
        data=data_path,
        chains=1,
        seed=12345,
        iter_sampling=1,
    )

    matrix_model = CmdStanModel(
        stan_file=os.path.join(DATAFILES_PATH, 'matrix_var.stan')
    )
    gqs = matrix_model.generate_quantities(mcmc_sample=fit)

    z_array = gqs.stan_variable(var="z")
    self.assertEqual(z_array.shape, (1, 4, 3))  # flattens chains
    z_dataset = gqs.draws_xr(vars="z")
    self.assertEqual(z_dataset.z.data.shape, (1, 1, 4, 3))  # keeps chains

    # every entry of row r is expected to equal r + 1
    for row in range(4):
        expected = row + 1
        for col in range(3):
            self.assertEqual(int(z_array[0, row, col]), expected)
            self.assertEqual(int(z_dataset.z.data[0, 0, row, col]), expected)