示例#1
0
    def test_multi_proc(self):
        """Check the chain start/finish log records for several ``cores`` values.

        Samples the logistic model with four chains using 1, 2 and 4 cores
        and asserts that the expected ``cmdstanpy`` INFO records are present
        in the captured log.  The checks for 2 and 4 cores are only made when
        ``cpu_count()`` reports at least four CPUs.
        """
        logistic_stan = os.path.join(datafiles_path, 'logistic.stan')
        logistic_model = CmdStanModel(stan_file=logistic_stan)
        logistic_data = os.path.join(datafiles_path, 'logistic.data.R')

        # single core: expect 'finish chain 1' and 'start chain 2' records
        with LogCapture() as log:
            logistic_model.sample(data=logistic_data, chains=4, cores=1)
        log.check_present(
            ('cmdstanpy', 'INFO', 'finish chain 1'),
            ('cmdstanpy', 'INFO', 'start chain 2'),
        )

        # the two-core run is always executed; its log is only inspected
        # on machines with enough CPUs
        with LogCapture() as log:
            logistic_model.sample(data=logistic_data, chains=4, cores=2)
        if cpu_count() >= 4:
            # finish chains 1, 2 before starting chains 3, 4
            log.check_present(
                ('cmdstanpy', 'INFO', 'finish chain 1'),
                ('cmdstanpy', 'INFO', 'start chain 4'),
            )
        if cpu_count() >= 4:
            with LogCapture() as log:
                logistic_model.sample(data=logistic_data,
                                      chains=4,
                                      cores=4)
                log.check_present(
                    ('cmdstanpy', 'INFO', 'start chain 4'),
                    ('cmdstanpy', 'INFO', 'finish chain 1'),
                )
示例#2
0
    def test_bernoulli_bad(self):
        """Sampling the bernoulli model without a ``data`` argument must fail.

        The raised exception's message has to match 'Error during sampling'.
        """
        model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
        model = CmdStanModel(stan_file=model_path)

        with self.assertRaisesRegex(Exception, 'Error during sampling'):
            model.sample(
                chains=4, cores=2, seed=12345, sampling_iters=100
            )
示例#3
0
    def test_save_csv(self):
        """Exercise ``save_csvfiles`` with explicit and default target dirs.

        Samples the bernoulli model, checks that the per-chain CSV and stdout
        files exist, then calls ``save_csvfiles`` with ``DATAFILES_PATH``
        (a second call to the same dir must raise 'file exists'), with a
        scratch dir under ``HERE``, and, after sampling again, with no
        ``dir`` argument.  Files referenced by the runset are deleted after
        each stage; the scratch dir is removed even when an assertion fails.
        """
        stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
        jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        bern_model = CmdStanModel(stan_file=stan)

        def run_sampler():
            # identical sampler configuration is used for both runs
            return bern_model.sample(
                data=jdata,
                chains=2,
                parallel_chains=2,
                seed=12345,
                iter_sampling=200,
            )

        def assert_csv_files_exist(fit):
            # every chain's CSV path recorded in the runset must exist
            for i in range(fit.runset.chains):
                self.assertTrue(os.path.exists(fit.runset.csv_files[i]))

        def remove_output_files(fit):
            # delete whatever files the runset currently points at
            for i in range(fit.runset.chains):
                os.remove(fit.runset.csv_files[i])
                if os.path.exists(fit.runset.stdout_files[i]):
                    os.remove(fit.runset.stdout_files[i])
                if os.path.exists(fit.runset.stderr_files[i]):
                    os.remove(fit.runset.stderr_files[i])

        bern_fit = run_sampler()
        for i in range(bern_fit.runset.chains):
            csv_file = bern_fit.runset.csv_files[i]
            stdout_file = bern_fit.runset.stdout_files[i]
            self.assertTrue(os.path.exists(csv_file))
            self.assertTrue(os.path.exists(stdout_file))

        # save files to good dir
        bern_fit.save_csvfiles(dir=DATAFILES_PATH)
        assert_csv_files_exist(bern_fit)
        with self.assertRaisesRegex(Exception, 'file exists'):
            bern_fit.save_csvfiles(dir=DATAFILES_PATH)

        tmp2_dir = os.path.join(HERE, 'tmp2')
        # tolerate a directory left behind by an earlier aborted run
        os.makedirs(tmp2_dir, exist_ok=True)
        try:
            bern_fit.save_csvfiles(dir=tmp2_dir)
            assert_csv_files_exist(bern_fit)
            remove_output_files(bern_fit)
        finally:
            # always drop the scratch dir, even if an assertion above failed
            shutil.rmtree(tmp2_dir, ignore_errors=True)

        # regenerate to tmpdir, save to good dir
        bern_fit = run_sampler()
        bern_fit.save_csvfiles()  # default dir
        assert_csv_files_exist(bern_fit)
        remove_output_files(bern_fit)  # cleanup default dir
示例#4
0
 def test_custom_seed(self):
     """Smoke test: ``sample`` accepts a list of seeds (two chains, two seeds)."""
     model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
     data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
     model = CmdStanModel(stan_file=model_path)
     # no assertions: the test passes when sample() returns without raising
     model.sample(
         data=data_path,
         chains=2,
         parallel_chains=2,
         seed=[44444, 55555],
         iter_sampling=200,
     )
示例#5
0
 def test_custom_metric(self):
     """Smoke test: ``sample`` accepts a metric given as a JSON file path."""
     model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
     data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
     model = CmdStanModel(stan_file=model_path)
     metric_path = os.path.join(DATAFILES_PATH, 'bernoulli.metric.json')
     # no assertions: the test passes when sample() returns without raising
     model.sample(
         data=data_path,
         chains=2,
         cores=2,
         seed=12345,
         iter_sampling=200,
         metric=metric_path,
     )
示例#6
0
    def test_deprecated(self):
        """Deprecated ``sample`` / ``warmup`` accessors still work and warn.

        Each accessor must return an array of the expected shape and a
        ``cmdstanpy`` WARNING record with the deprecation text must be logged.
        """
        model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
        data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')

        model = CmdStanModel(stan_file=model_path)
        fit = model.sample(
            data=data_path,
            chains=2,
            seed=12345,
            iter_warmup=200,
            iter_sampling=100,
            save_warmup=True,
        )

        sample_warning = (
            'cmdstanpy',
            'WARNING',
            'method "sample" will be deprecated,'
            ' use method "draws" instead.',
        )
        warmup_warning = (
            'cmdstanpy',
            'WARNING',
            'method "warmup" has been deprecated, instead use method'
            ' "draws(inc_warmup=True)", returning draws from both'
            ' warmup and sampling iterations.',
        )

        # ``sample``: sampling draws only
        with LogCapture() as captured:
            self.assertEqual(fit.sample.shape,
                             (100, 2, len(BERNOULLI_COLS)))
        captured.check_present(sample_warning)

        # ``warmup``: 300 draws per chain are expected here
        with LogCapture() as captured:
            self.assertEqual(fit.warmup.shape,
                             (300, 2, len(BERNOULLI_COLS)))
        captured.check_present(warmup_warning)
    def test_show_console(self):
        """``generate_quantities(show_console=True)`` echoes per-chain output.

        Standard output is redirected into a buffer while the generated
        quantities run; the buffer must then contain a
        'Chain [k] method = generate' line for each of the four chains.
        """
        stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
        bern_model = CmdStanModel(stan_file=stan)
        jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        bern_fit = bern_model.sample(
            data=jdata,
            chains=4,
            parallel_chains=2,
            seed=12345,
            iter_sampling=100,
        )
        stan = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
        model = CmdStanModel(stan_file=stan)

        sys_stdout = io.StringIO()
        with contextlib.redirect_stdout(sys_stdout):
            model.generate_quantities(
                data=jdata,
                mcmc_sample=bern_fit,
                show_console=True,
            )
        console = sys_stdout.getvalue()
        # assertIn reports the captured console text when a line is missing
        for chain_id in range(1, 5):
            self.assertIn(
                'Chain [{}] method = generate'.format(chain_id), console
            )
    def test_no_xarray(self):
        """``draws_xr`` must raise ``RuntimeError`` when xarray is unavailable.

        Runs inside ``self.without_import('xarray', cmdstanpy.stanfit)``,
        which is expected to make ``import xarray`` fail.
        """
        with self.without_import('xarray', cmdstanpy.stanfit):
            # harness sanity check:
            # if this fails the testing framework is the problem
            with self.assertRaises(ImportError):
                import xarray as _  # noqa

            fit_model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
            fit_model = CmdStanModel(stan_file=fit_model_path)
            data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
            fit = fit_model.sample(
                data=data_path,
                chains=4,
                parallel_chains=2,
                seed=12345,
                iter_sampling=100,
            )
            gq_model_path = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
            gq_model = CmdStanModel(stan_file=gq_model_path)

            gqs = gq_model.generate_quantities(
                data=data_path, mcmc_sample=fit
            )

            with self.assertRaises(RuntimeError):
                gqs.draws_xr()
 def test_sample_plus_quantities_dedup(self):
     """``stan_variable`` on the GQ result returns the GQ model's ``y_rep``.

     The fit comes from 'bernoulli_ppc.stan' and the generated quantities
     from 'bernoulli_ppc_dup.stan'.  Their ``y_rep`` arrays must differ, and
     the first ten entries of the first GQ draw must equal the data's ``y``.
     """
     # fitted_params - model GQ block: y_rep is PPC of theta
     ppc_path = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
     ppc_model = CmdStanModel(stan_file=ppc_path)
     data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
     fit = ppc_model.sample(
         data=data_path,
         chains=4,
         parallel_chains=2,
         seed=12345,
         iter_sampling=100,
     )
     # gq_model - y_rep[n] == y[n]
     dup_path = os.path.join(DATAFILES_PATH, 'bernoulli_ppc_dup.stan')
     dup_model = CmdStanModel(stan_file=dup_path)
     gqs = dup_model.generate_quantities(data=data_path, mcmc_sample=fit)
     # check that models have different y_rep values:
     # assert_array_equal is expected to raise AssertionError here
     assert_raises(
         AssertionError,
         assert_array_equal,
         fit.stan_variable(var='y_rep'),
         gqs.stan_variable(var='y_rep'),
     )
     # check that stan_variable returns values from gq model
     with open(data_path) as data_fd:
         observed = json.load(data_fd)
     gq_y_rep = gqs.stan_variable(var='y_rep')
     for n in range(10):
         self.assertEqual(gq_y_rep[0, n], observed['y'][n])
    def test_from_mcmc_sample_variables(self):
        """Check ``stan_variable`` / ``stan_variables`` on a GQ result.

        Verifies the shapes of ``theta`` and ``y_rep``, that the names 'eta'
        and 'lp__' raise ``ValueError``, and that ``stan_variables()`` has
        exactly the variable names of the sample plus those of the GQ run.
        """
        fit_model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
        fit_model = CmdStanModel(stan_file=fit_model_path)
        data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        fit = fit_model.sample(
            data=data_path,
            chains=4,
            parallel_chains=2,
            seed=12345,
            iter_sampling=100,
        )
        gq_model_path = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
        gq_model = CmdStanModel(stan_file=gq_model_path)

        gqs = gq_model.generate_quantities(data=data_path, mcmc_sample=fit)

        # 4 chains x 100 draws -> 400 rows
        self.assertEqual(gqs.stan_variable(var='theta').shape, (400,))
        self.assertEqual(gqs.stan_variable(var='y_rep').shape, (400, 10))
        for rejected_name in ('eta', 'lp__'):
            with self.assertRaises(ValueError):
                gqs.stan_variable(var=rejected_name)

        all_vars = gqs.stan_variables()
        sample_names = list(gqs.mcmc_sample.metadata.stan_vars_cols.keys())
        gq_names = list(gqs.metadata.stan_vars_cols.keys())
        self.assertEqual(
            set(sample_names + gq_names), set(all_vars.keys())
        )
示例#11
0
    def test_sampler_diags(self):
        """Compare ``sampler_variables`` with deprecated ``sampler_diagnostics``.

        Both must yield the keys in ``SAMPLER_STATE`` with (100, 2) arrays.
        While the second accessor is used, a WARNING record about the
        deprecated ``sample`` method must appear in the captured log.
        """
        model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
        data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        model = CmdStanModel(stan_file=model_path)
        fit = model.sample(
            data=data_path,
            chains=2,
            seed=12345,
            iter_warmup=100,
            iter_sampling=100,
        )

        sampler_vars = fit.sampler_variables()
        self.assertEqual(SAMPLER_STATE, list(sampler_vars))
        for name in sampler_vars:
            self.assertEqual(sampler_vars[name].shape, (100, 2))
            self.assertEqual(
                fit.sample.shape, (100, 2, len(BERNOULLI_COLS))
            )

        with LogCapture() as captured:
            sampler_vars = fit.sampler_diagnostics()
            self.assertEqual(SAMPLER_STATE, list(sampler_vars))
            for name in sampler_vars:
                self.assertEqual(sampler_vars[name].shape, (100, 2))
                self.assertEqual(
                    fit.sample.shape, (100, 2, len(BERNOULLI_COLS))
                )
        deprecation_record = (
            'cmdstanpy',
            'WARNING',
            'method "sample" will be deprecated,'
            ' use method "draws" instead.',
        )
        captured.check_present(deprecation_record)
示例#12
0
    def test_dont_save_warmup(self):
        """With ``save_warmup=False`` only sampling draws are available.

        ``draws(inc_warmup=True)`` must still return the (100, 2, ncols)
        array and log a WARNING that warmup draws are not available.
        """
        model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
        data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')

        model = CmdStanModel(stan_file=model_path)
        fit = model.sample(
            data=data_path,
            chains=2,
            seed=12345,
            iter_warmup=200,
            iter_sampling=100,
            save_warmup=False,
        )
        self.assertEqual(fit.column_names, tuple(BERNOULLI_COLS))
        self.assertEqual(fit.num_draws, 100)
        self.assertEqual(fit.draws().shape, (100, 2, len(BERNOULLI_COLS)))

        with LogCapture() as captured:
            with_warmup = fit.draws(inc_warmup=True)
            self.assertEqual(
                with_warmup.shape, (100, 2, len(BERNOULLI_COLS))
            )
        missing_warmup_record = (
            'cmdstanpy',
            'WARNING',
            'draws from warmup iterations not available,'
            ' must run sampler with "save_warmup=True".',
        )
        captured.check_present(missing_warmup_record)
    def test_from_mcmc_sample(self):
        """Run ``generate_quantities`` from a fit and check the resulting runset.

        Asserts the method recorded in the run arguments, two substrings of
        the object's repr, the chain count, and that each chain has a zero
        return code and an existing CSV file.
        """
        # fitted_params sample
        fit_model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
        fit_model = CmdStanModel(stan_file=fit_model_path)
        data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        fit = fit_model.sample(
            data=data_path,
            chains=4,
            parallel_chains=2,
            seed=12345,
            iter_sampling=100,
        )
        # gq_model
        gq_model_path = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
        gq_model = CmdStanModel(stan_file=gq_model_path)

        gqs = gq_model.generate_quantities(data=data_path, mcmc_sample=fit)

        self.assertEqual(gqs.runset._args.method, Method.GENERATE_QUANTITIES)
        self.assertIn('CmdStanGQ: model=bernoulli_ppc', repr(gqs))
        self.assertIn('method=generate_quantities', repr(gqs))
        self.assertEqual(gqs.runset.chains, 4)
        for chain_idx in range(gqs.runset.chains):
            self.assertEqual(gqs.runset._retcode(chain_idx), 0)
            self.assertTrue(
                os.path.exists(gqs.runset.csv_files[chain_idx])
            )
示例#14
0
    def test_validate(self):
        """Check behaviour when CSV validation is deferred.

        The model is sampled with ``validate_csv=False``.  Before
        ``validate_csv_files()`` is called, each property listed below must
        be ``None`` and the last captured log message must start with
        'csv files not yet validated'.  Afterwards the draw count, number of
        columns, variable dims and metric type are checked.
        """
        stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
        jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        bern_model = CmdStanModel(stan_file=stan)
        bern_fit = bern_model.sample(
            data=jdata,
            chains=2,
            seed=12345,
            iter_warmup=200,
            iter_sampling=100,
            thin=2,
            save_warmup=True,
            validate_csv=False,
        )
        # check error messages: same expectation for every property
        expect = 'csv files not yet validated'
        unvalidated_props = (
            'column_names',
            'stan_variable_dims',
            'metric_type',
            'metric',
            'stepsize',
        )
        for prop in unvalidated_props:
            with LogCapture() as log:
                self.assertIsNone(getattr(bern_fit, prop))
            # last element of the last captured record is the message text
            msg = log.actual()[-1][-1]
            self.assertTrue(
                msg.startswith(expect),
                '{}: unexpected log message {!r}'.format(prop, msg),
            )

        # check computations match
        # NOTE(review): 150 presumably is (200 warmup + 100 sampling) / thin=2
        self.assertEqual(bern_fit.num_draws, 150)
        bern_fit.validate_csv_files()
        self.assertEqual(bern_fit.num_draws, 150)
        self.assertEqual(len(bern_fit.column_names), 8)
        self.assertEqual(len(bern_fit.stan_variable_dims), 1)
        self.assertEqual(bern_fit.metric_type, 'diag_e')
示例#15
0
    def test_save_csv(self):
        """Exercise ``save_csvfiles`` with a custom basename.

        After sampling, each chain's CSV file and the '.txt' file sharing
        its stem must exist.  Saving into ``datafiles_path`` twice with the
        same basename must raise 'file exists'.  A second fit is saved to
        the default dir.  The CSV and console files recorded in the runset
        are removed after each stage.
        """
        model_path = os.path.join(datafiles_path, 'bernoulli.stan')
        data_path = os.path.join(datafiles_path, 'bernoulli.data.json')
        model = CmdStanModel(stan_file=model_path)
        fit = model.sample(
            data=data_path,
            chains=4,
            cores=2,
            seed=12345,
            sampling_iters=200,
        )

        for idx in range(fit.runset.chains):
            csv_path = fit.runset.csv_files[idx]
            txt_path = os.path.splitext(csv_path)[0] + '.txt'
            self.assertTrue(os.path.exists(csv_path))
            self.assertTrue(os.path.exists(txt_path))

        # save files to good dir
        prefix = 'bern_save_csvfiles_test'
        fit.save_csvfiles(dir=datafiles_path, basename=prefix)
        for idx in range(fit.runset.chains):
            self.assertTrue(os.path.exists(fit.runset.csv_files[idx]))
        with self.assertRaisesRegex(Exception, 'file exists'):
            fit.save_csvfiles(dir=datafiles_path, basename=prefix)
        # cleanup datafile_path dir
        for idx in range(fit.runset.chains):
            os.remove(fit.runset.csv_files[idx])
            os.remove(fit.runset.console_files[idx])

        # regenerate to tmpdir, save to good dir
        fit = model.sample(
            data=data_path,
            chains=4,
            cores=2,
            seed=12345,
            sampling_iters=200,
        )
        fit.save_csvfiles(basename=prefix)  # default dir
        for idx in range(fit.runset.chains):
            self.assertTrue(os.path.exists(fit.runset.csv_files[idx]))
        # cleanup default dir
        for idx in range(fit.runset.chains):
            os.remove(fit.runset.csv_files[idx])
            os.remove(fit.runset.console_files[idx])
示例#16
0
    def test_bernoulli_bad(self):
        """Check the error paths of ``sample`` for bad data and output dir.

        * No data, or data lacking the model's variables, must raise
          ``RuntimeError`` matching 'variable does not exist'.
        * On non-Windows platforms, an ``output_dir`` nested inside a
          directory created with mode ``0o644`` must raise ``ValueError``
          matching 'invalid path for output files'.
        """
        stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
        bern_model = CmdStanModel(stan_file=stan)

        with self.assertRaisesRegex(RuntimeError, 'variable does not exist'):
            bern_model.sample(chains=2, cores=2, seed=12345, iter_sampling=100)

        with self.assertRaisesRegex(RuntimeError, 'variable does not exist'):
            bern_model.sample(
                data={'foo': 1},
                chains=2,
                cores=2,
                seed=12345,
                iter_sampling=100,
            )
        if platform.system() != 'Windows':
            jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
            dirname1 = 'tmp1' + str(time())
            # Octal literal: a bare 644 is decimal (== 0o1204), which is not
            # the rw-r--r-- permission set the digits suggest.  Without the
            # execute/search bit the nested path below is presumably not
            # creatable, which is what the test relies on.
            os.mkdir(dirname1, mode=0o644)
            try:
                dirname2 = 'tmp2' + str(time())
                path = os.path.join(dirname1, dirname2)
                with self.assertRaisesRegex(ValueError,
                                            'invalid path for output files'):
                    bern_model.sample(data=jdata, chains=1, output_dir=path)
            finally:
                # remove the scratch dir even when the assertion fails
                os.rmdir(dirname1)
    def test_sample_plus_quantities_dedup(self):
        """Same model for fit and GQ: the combined frame gains no columns.

        ``sample_plus_quantities`` must have as many columns as
        ``mcmc_sample``.
        """
        model_path = os.path.join(datafiles_path, 'bernoulli_ppc.stan')
        ppc_model = CmdStanModel(stan_file=model_path)

        data_path = os.path.join(datafiles_path, 'bernoulli.data.json')
        fit = ppc_model.sample(
            data=data_path,
            chains=4,
            cores=2,
            seed=12345,
            sampling_iters=100,
        )
        gqs = ppc_model.generate_quantities(data=data_path, mcmc_sample=fit)
        combined_cols = gqs.sample_plus_quantities.shape[1]
        sample_cols = gqs.mcmc_sample.shape[1]
        self.assertEqual(combined_cols, sample_cols)
示例#18
0
 def test_sampler_diags(self):
     """``sampler_diagnostics`` yields the ``SAMPLER_STATE`` keys, each (100, 2)."""
     model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
     data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
     model = CmdStanModel(stan_file=model_path)
     fit = model.sample(
         data=data_path,
         chains=2,
         seed=12345,
         iter_warmup=100,
         iter_sampling=100,
     )
     diagnostics = fit.sampler_diagnostics()
     self.assertEqual(SAMPLER_STATE, list(diagnostics))
     for name in diagnostics:
         # 100 sampling iterations x 2 chains
         self.assertEqual(diagnostics[name].shape, (100, 2))
    def test_from_mcmc_sample_draws(self):
        """Check ``draws_pd`` and ``draws_xr`` on a generated-quantities result.

        pandas: the GQ frame is (400, 10); with ``inc_sample=True`` the
        column count is the sum of sample and GQ columns and the first row
        equals the concatenation of the first sample and GQ rows.
        xarray: ``y_rep`` (and ``theta`` with ``inc_sample=True``) must have
        the expected dims and shapes.
        """
        fit_model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
        fit_model = CmdStanModel(stan_file=fit_model_path)
        data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        fit = fit_model.sample(
            data=data_path,
            chains=4,
            parallel_chains=2,
            seed=12345,
            iter_sampling=100,
        )
        gq_model_path = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
        gq_model = CmdStanModel(stan_file=gq_model_path)

        gqs = gq_model.generate_quantities(data=data_path, mcmc_sample=fit)

        # draws_pd
        self.assertEqual(gqs.draws_pd().shape, (400, 10))
        combined_cols = gqs.draws_pd(inc_sample=True).shape[1]
        sample_cols = gqs.mcmc_sample.draws_pd().shape[1]
        gq_cols = gqs.draws_pd().shape[1]
        self.assertEqual(combined_cols, sample_cols + gq_cols)

        first_sample_row = fit.draws_pd().iloc[0]
        first_gq_row = gqs.draws_pd().iloc[0]
        expected_row = pd.concat(
            (first_sample_row, first_gq_row), axis=0
        ).values
        actual_row = gqs.draws_pd(inc_sample=True).iloc[0].values
        self.assertTrue(np.array_equal(expected_row, actual_row))

        # draws_xr: all variables, a single name, a list of names
        y_rep_dims = ('chain', 'draw', 'y_rep_dim_0')
        for xr_kwargs in ({}, {'vars': 'y_rep'}, {'vars': ['y_rep']}):
            dataset = gqs.draws_xr(**xr_kwargs)
            self.assertEqual(dataset.y_rep.dims, y_rep_dims)
            self.assertEqual(dataset.y_rep.values.shape, (4, 100, 10))

        dataset_plus = gqs.draws_xr(inc_sample=True)
        self.assertEqual(dataset_plus.y_rep.dims, y_rep_dims)
        self.assertEqual(dataset_plus.y_rep.values.shape, (4, 100, 10))
        self.assertEqual(dataset_plus.theta.dims, ('chain', 'draw'))
        self.assertEqual(dataset_plus.theta.values.shape, (4, 100))
示例#20
0
    def test_init_types(self):
        """Check which ``inits`` values ``sample`` accepts.

        ``1.1`` and ``1`` are accepted and show up as 'init=...' in the
        runset repr; a tuple and ``-1`` must raise ``ValueError``.
        """
        model_path = os.path.join(datafiles_path, 'bernoulli.stan')
        model = CmdStanModel(stan_file=model_path)
        data_path = os.path.join(datafiles_path, 'bernoulli.data.json')

        # arguments shared by every sample() call in this test
        common = {
            'data': data_path,
            'chains': 4,
            'cores': 2,
            'seed': 12345,
            'sampling_iters': 100,
        }

        fit = model.sample(**common, inits=1.1)
        self.assertIn('init=1.1', fit.runset.__repr__())

        fit = model.sample(**common, inits=1)
        self.assertIn('init=1', fit.runset.__repr__())

        for rejected in ((1, 2), -1):
            with self.assertRaises(ValueError):
                fit = model.sample(**common, inits=rejected)
示例#21
0
    def test_gen_quanties_mcmc_sample(self):
        """Run ``generate_quantities`` from a fit and inspect the result.

        Checks the recorded method, the repr, per-chain return codes and CSV
        files, the GQ column names ('y_rep[1]' .. 'y_rep[10]'), and that the
        combined frame's column count is the sum of sample and GQ columns.
        """
        fit_model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
        fit_model = CmdStanModel(stan_file=fit_model_path)

        data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        fit = fit_model.sample(
            data=data_path,
            chains=4,
            parallel_chains=2,
            seed=12345,
            iter_sampling=100,
        )

        gq_model_path = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
        gq_model = CmdStanModel(stan_file=gq_model_path)

        gqs = gq_model.generate_quantities(data=data_path, mcmc_sample=fit)
        self.assertEqual(gqs.runset._args.method, Method.GENERATE_QUANTITIES)
        self.assertIn('CmdStanGQ: model=bernoulli_ppc', gqs.__repr__())
        self.assertIn('method=generate_quantities', gqs.__repr__())

        # check results - output files, quantities of interest, draws
        self.assertEqual(gqs.runset.chains, 4)
        for chain_idx in range(gqs.runset.chains):
            self.assertEqual(gqs.runset._retcode(chain_idx), 0)
            self.assertTrue(
                os.path.exists(gqs.runset.csv_files[chain_idx])
            )

        expected_columns = tuple(
            'y_rep[{}]'.format(n) for n in range(1, 11)
        )
        self.assertEqual(gqs.column_names, expected_columns)
        self.assertEqual(fit.draws_pd().shape, gqs.mcmc_sample.shape)

        combined_cols = gqs.sample_plus_quantities.shape[1]
        sample_cols = gqs.mcmc_sample.shape[1]
        gq_cols = gqs.generated_quantities_pd.shape[1]
        self.assertEqual(combined_cols, sample_cols + gq_cols)
示例#22
0
    def test_check_sampler_csv_thin(self):
        """Check ``check_sampler_csv`` against a thinned sampler run.

        The model is sampled with ``thin=7`` and 490 sampling iterations.
        The parsed config must report those settings and 70 sampling draws.
        A mismatching ``thin`` must raise 'config error', and omitting
        ``thin`` must raise 'expected 490 draws, found 70'.
        """
        stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
        bern_model = CmdStanModel(stan_file=stan)
        bern_model.compile()
        jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        bern_fit = bern_model.sample(
            data=jdata,
            chains=1,
            parallel_chains=1,
            seed=12345,
            iter_sampling=490,
            iter_warmup=490,
            thin=7,
            max_treedepth=11,
            adapt_delta=0.98,
        )
        csv_file = bern_fit.runset.csv_files[0]
        # named ``config`` rather than ``dict`` so the builtin is not shadowed
        config = check_sampler_csv(
            path=csv_file,
            is_fixed_param=False,
            iter_sampling=490,
            iter_warmup=490,
            thin=7,
        )
        self.assertEqual(config['num_samples'], 490)
        self.assertEqual(config['thin'], 7)
        self.assertEqual(config['draws_sampling'], 70)
        self.assertEqual(config['seed'], 12345)
        self.assertEqual(config['max_depth'], 11)
        self.assertEqual(config['delta'], 0.98)

        # thin value that disagrees with the one recorded in the CSV
        with self.assertRaisesRegex(ValueError, 'config error'):
            check_sampler_csv(
                path=csv_file,
                is_fixed_param=False,
                iter_sampling=490,
                iter_warmup=490,
                thin=9,
            )
        # no thin argument: the full 490 draws are expected but 70 are found
        with self.assertRaisesRegex(ValueError,
                                    'expected 490 draws, found 70'):
            check_sampler_csv(
                path=csv_file,
                is_fixed_param=False,
                iter_sampling=490,
                iter_warmup=490,
            )
示例#23
0
    def test_sample_plus_quantities_dedup(self):
        """Same model for fit and GQ: the combined frame gains no columns.

        ``sample_plus_quantities`` must have as many columns as
        ``mcmc_sample``.
        """
        model_path = os.path.join(DATAFILES_PATH, 'bernoulli_ppc.stan')
        ppc_model = CmdStanModel(stan_file=model_path)

        data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        fit = ppc_model.sample(
            data=data_path,
            chains=4,
            parallel_chains=2,
            seed=12345,
            iter_sampling=100,
        )
        gqs = ppc_model.generate_quantities(data=data_path, mcmc_sample=fit)
        combined_cols = gqs.sample_plus_quantities.shape[1]
        sample_cols = gqs.mcmc_sample.shape[1]
        self.assertEqual(combined_cols, sample_cols)
    def test_gen_quanties_mcmc_sample(self):
        """Run ``generate_quantities`` from a fit and inspect the result.

        Checks the recorded method, the repr, per-chain return codes and CSV
        files, the GQ column names ('y_rep.1' .. 'y_rep.10'), and that the
        combined frame's column count is the sum of sample and GQ columns.
        """
        fit_model_path = os.path.join(datafiles_path, 'bernoulli.stan')
        fit_model = CmdStanModel(stan_file=fit_model_path)

        data_path = os.path.join(datafiles_path, 'bernoulli.data.json')
        fit = fit_model.sample(
            data=data_path,
            chains=4,
            cores=2,
            seed=12345,
            sampling_iters=100,
        )

        gq_model_path = os.path.join(datafiles_path, 'bernoulli_ppc.stan')
        gq_model = CmdStanModel(stan_file=gq_model_path)

        gqs = gq_model.generate_quantities(data=data_path, mcmc_sample=fit)
        self.assertEqual(gqs.runset._args.method, Method.GENERATE_QUANTITIES)
        self.assertIn('CmdStanGQ: model=bernoulli_ppc', gqs.__repr__())
        self.assertIn('method=generate_quantities', gqs.__repr__())

        # check results - output files, quantities of interest, draws
        self.assertEqual(gqs.runset.chains, 4)
        for chain_idx in range(gqs.runset.chains):
            self.assertEqual(gqs.runset._retcode(chain_idx), 0)
            self.assertTrue(
                os.path.exists(gqs.runset.csv_files[chain_idx])
            )

        expected_columns = tuple(
            'y_rep.{}'.format(n) for n in range(1, 11)
        )
        self.assertEqual(gqs.column_names, expected_columns)
        self.assertEqual(fit.get_drawset().shape, gqs.mcmc_sample.shape)

        combined_cols = gqs.sample_plus_quantities.shape[1]
        sample_cols = gqs.mcmc_sample.shape[1]
        gq_cols = gqs.generated_quantities_pd.shape[1]
        self.assertEqual(combined_cols, sample_cols + gq_cols)
示例#25
0
    def test_dont_save_warmup(self):
        """With ``save_warmup=False`` the fit exposes no warmup draws.

        ``num_draws_warmup`` must be 0 and ``warmup`` must compare equal to
        ``None``, while the sampling draws keep their (100, 2, ncols) shape.
        """
        model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
        data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')

        model = CmdStanModel(stan_file=model_path)
        fit = model.sample(
            data=data_path,
            chains=2,
            seed=12345,
            iter_warmup=200,
            iter_sampling=100,
            save_warmup=False,
        )
        self.assertEqual(fit.column_names, tuple(BERNOULLI_COLS))
        # warmup side: nothing recorded
        self.assertEqual(fit.num_draws_warmup, 0)
        self.assertEqual(fit.warmup, None)
        # sampling side: 100 draws for each of the 2 chains
        self.assertEqual(fit.num_draws, 100)
        self.assertEqual(fit.sample.shape, (100, 2, len(BERNOULLI_COLS)))
示例#26
0
 def test_variable_bern(self):
     """Check ``stan_variable`` on the bernoulli fit.

     ``theta`` must be the only entry in ``_stan_variable_dims`` (value 1)
     and come back as a (200,) array; 'eta' and 'lp__' raise ``ValueError``.
     """
     model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
     data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
     model = CmdStanModel(stan_file=model_path)
     fit = model.sample(
         data=data_path,
         chains=2,
         seed=12345,
         iter_warmup=100,
         iter_sampling=100,
     )
     self.assertEqual(1, len(fit._stan_variable_dims))
     self.assertTrue('theta' in fit._stan_variable_dims)
     self.assertEqual(fit._stan_variable_dims['theta'], 1)
     # 2 chains x 100 sampling iterations -> 200 values
     theta_draws = fit.stan_variable(name='theta')
     self.assertEqual(theta_draws.shape, (200, ))
     for rejected_name in ('eta', 'lp__'):
         with self.assertRaises(ValueError):
             fit.stan_variable(name=rejected_name)
示例#27
0
    def test_fixed_param_good(self):
        """Check a fixed-parameter sampling run of the poisson GLM generator.

        Runs ``datagen_poisson_glm.stan`` with ``fixed_param=True`` and no
        data, then verifies that each chain's CSV file and matching ``.txt``
        file exist, that a single chain was run, that the output columns are
        as expected, and that no metric or step size is reported.
        """
        stan = os.path.join(DATAFILES_PATH, 'datagen_poisson_glm.stan')
        datagen_model = CmdStanModel(stan_file=stan)
        no_data = {}
        datagen_fit = datagen_model.sample(data=no_data,
                                           seed=12345,
                                           sampling_iters=100,
                                           fixed_param=True)
        self.assertEqual(datagen_fit.runset._args.method, Method.SAMPLE)

        for i in range(datagen_fit.runset.chains):
            csv_file = datagen_fit.runset.csv_files[i]
            # The .txt file shares the CSV file's base name.
            txt_file = os.path.splitext(csv_file)[0] + '.txt'
            self.assertTrue(os.path.exists(csv_file))
            self.assertTrue(os.path.exists(txt_file))

        self.assertEqual(datagen_fit.runset.chains, 1)

        def indexed(name, count=20):
            """Return ['name.1', ..., 'name.<count>']."""
            return ['%s.%d' % (name, i) for i in range(1, count + 1)]

        # Expected CSV header: sampler columns, N, then the 20-element
        # vectors y_sim, x_sim, pop_sim, two scalars, and the eta vector.
        column_names = (
            ['lp__', 'accept_stat__', 'N']
            + indexed('y_sim')
            + indexed('x_sim')
            + indexed('pop_sim')
            + ['alpha_sim', 'beta_sim']
            + indexed('eta')
        )
        self.assertEqual(datagen_fit.column_names, tuple(column_names))
        self.assertEqual(datagen_fit.draws, 100)
        self.assertEqual(datagen_fit.sample.shape, (100, 1, len(column_names)))
        # Identity checks rather than assertEqual(..., None): an equality
        # comparison against an array-like value would raise an
        # ambiguous-truth error instead of reporting a clean test failure.
        self.assertIsNone(datagen_fit.metric)
        self.assertIsNone(datagen_fit.metric_type)
        self.assertIsNone(datagen_fit.stepsize)
示例#28
0
    def test_save_warmup_thin(self):
        """Check draw counts when warmup is saved and thinning is applied.

        Samples the bernoulli model with 2 chains, 200 warmup and 100
        sampling iterations, ``thin=5`` and ``save_warmup=True``, then
        verifies the reported draw count and the shapes returned by
        ``draws()`` with and without warmup.
        """
        model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
        data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')

        fit = CmdStanModel(stan_file=model_path).sample(
            data=data_path,
            chains=2,
            seed=12345,
            iter_warmup=200,
            iter_sampling=100,
            thin=5,
            save_warmup=True,
        )

        self.assertEqual(fit.column_names, tuple(BERNOULLI_COLS))
        self.assertEqual(fit.num_draws, 60)

        # Sampling-only draws first, then sampling plus saved warmup.
        post_warmup_shape = fit.draws().shape
        self.assertEqual(post_warmup_shape, (20, 2, len(BERNOULLI_COLS)))
        all_draws_shape = fit.draws(inc_warmup=True).shape
        self.assertEqual(all_draws_shape, (60, 2, len(BERNOULLI_COLS)))
示例#29
0
 def test_adapt_schedule(self):
     """Check that adaptation schedule arguments reach the sampler output.

     Samples one chain of the bernoulli model with explicit
     ``adapt_init_phase``, ``adapt_metric_window`` and ``adapt_step_size``
     values, then looks for the corresponding ``init_buffer``, ``window``
     and ``term_buffer`` settings in the chain's stdout file.
     """
     model_path = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
     data_path = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
     model = CmdStanModel(stan_file=model_path)
     fit = model.sample(
         data=data_path,
         chains=1,
         seed=12345,
         iter_sampling=200,
         iter_warmup=200,
         adapt_init_phase=11,
         adapt_metric_window=12,
         adapt_step_size=13,
     )
     stdout_path = fit.runset.stdout_files[0]
     with open(stdout_path, 'r') as handle:
         # Strip surrounding whitespace so each setting matches exactly.
         settings = [raw.strip() for raw in handle]
     self.assertIn('init_buffer = 11', settings)
     self.assertIn('window = 12', settings)
     self.assertIn('term_buffer = 13', settings)
 def test_single_row_csv(self):
     """Check generated quantities built from a one-draw sample.

     Draws a single sampling iteration from the bernoulli model, feeds that
     fit to ``generate_quantities`` on ``matrix_var.stan``, and verifies the
     shape and contents of variable ``z`` both as an ndarray and via
     ``draws_xr``.
     """
     bern_stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
     bern_model = CmdStanModel(stan_file=bern_stan)
     bern_data = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
     bern_fit = bern_model.sample(
         data=bern_data,
         chains=1,
         seed=12345,
         iter_sampling=1,
     )

     matrix_stan = os.path.join(DATAFILES_PATH, 'matrix_var.stan')
     matrix_model = CmdStanModel(stan_file=matrix_stan)
     gqs = matrix_model.generate_quantities(mcmc_sample=bern_fit)

     z_as_ndarray = gqs.stan_variable(var="z")
     self.assertEqual(z_as_ndarray.shape, (1, 4, 3))  # flattens chains
     z_as_xr = gqs.draws_xr(vars="z")
     self.assertEqual(z_as_xr.z.data.shape, (1, 1, 4, 3))  # keeps chains

     # Every entry in row i is expected to equal i + 1 in both views.
     for row in range(4):
         expected = row + 1
         for col in range(3):
             self.assertEqual(int(z_as_ndarray[0, row, col]), expected)
             self.assertEqual(int(z_as_xr.z.data[0, 0, row, col]), expected)