Ejemplo n.º 1
0
    def test_validate_good_run(self):
        """
        Build a ``CmdStanMCMC`` fit from the saved ``runset-good`` csv files
        and check its draws, column names, summary call and diagnostics text.
        """
        # construct fit using existing sampler output
        exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
        jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
        sampler_args = SamplerArgs(iter_sampling=100,
                                   max_treedepth=11,
                                   adapt_delta=0.95)
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=exe,
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=jdata,
            output_dir=DATAFILES_PATH,
            method_args=sampler_args,
        )
        runset = RunSet(args=cmdstan_args, chains=4)
        # point the runset at the pre-generated csv files, one per chain
        runset._csv_files = [
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-1.csv'),
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-2.csv'),
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-3.csv'),
            os.path.join(DATAFILES_PATH, 'runset-good', 'bern-4.csv'),
        ]
        self.assertEqual(4, runset.chains)
        # set every chain's return code to 0 before checking them
        for idx, _ in enumerate(runset._retcodes):
            runset._set_retcode(idx, 0)
        self.assertTrue(runset._check_retcodes())

        fit = CmdStanMCMC(runset)
        self.assertEqual(100, fit.num_draws)
        self.assertEqual(len(BERNOULLI_COLS), len(fit.column_names))
        self.assertEqual('lp__', fit.column_names[0])

        drawset = fit.get_drawset()
        self.assertEqual(
            drawset.shape,
            (fit.runset.chains * fit.num_draws, len(fit.column_names)),
        )

        # smoke check only: the test fails if summary() raises
        fit.summary()

        # TODO - use cmdstan test files instead
        expected = '\n'.join([
            'Checking sampler transitions treedepth.',
            'Treedepth satisfactory for all transitions.',
            '\nChecking sampler transitions for divergences.',
            'No divergent transitions found.',
            '\nChecking E-BFMI - sampler transitions HMC potential energy.',
            'E-BFMI satisfactory for all transitions.',
            '\nEffective sample size satisfactory.',
        ])
        # normalise Windows line endings before comparing
        self.assertIn(expected, fit.diagnose().replace('\r\n', '\n'))
Ejemplo n.º 2
0
 def test_validate_big_run(self):
     """
     Check a two-chain fit assembled from the ``runset-big`` csv output:
     column names, metric/stepsize/sample shapes, and ``get_drawset``
     parameter selection, including rejection of unknown parameter names.
     """
     cmdstan_args = CmdStanArgs(
         model_name='bernoulli',
         model_exe=os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION),
         chain_ids=[1, 2],
         seed=12345,
         output_dir=DATAFILES_PATH,
         method_args=SamplerArgs(),
     )
     runset = RunSet(args=cmdstan_args, chains=2)
     # both chains read the same csv file
     big_csv = os.path.join(
         DATAFILES_PATH, 'runset-big', 'output_icar_nyc-1.csv'
     )
     runset._csv_files = [big_csv, big_csv]
     fit = CmdStanMCMC(runset)
     fit._validate_csv_files()

     # expected header: sampler state columns, then phi.1 .. phi.2095
     expected_names = [
         'lp__',
         'accept_stat__',
         'stepsize__',
         'treedepth__',
         'n_leapfrog__',
         'divergent__',
         'energy__',
     ]
     expected_names.extend('phi.{}'.format(n) for n in range(1, 2096))
     self.assertEqual(fit.columns, len(expected_names))
     self.assertEqual(fit.column_names, tuple(expected_names))
     self.assertEqual(fit.metric_type, 'diag_e')
     self.assertEqual(fit.stepsize.shape, (2, ))
     self.assertEqual(fit.metric.shape, (2, 2095))
     self.assertEqual((1000, 2, 2102), fit.sample.shape)

     # requested params -> expected drawset shape
     selections = [
         (['phi'], (2000, 2095)),
         (['phi.1'], (2000, 1)),
         (['phi.1', 'phi.10', 'phi.100'], (2000, 3)),
         (['phi.2095'], (2000, 1)),
     ]
     for names, shape in selections:
         self.assertEqual(shape, fit.get_drawset(params=names).shape)

     # names that match no column must be rejected
     for unknown in (['phi.2096'], ['ph']):
         with self.assertRaises(Exception):
             fit.get_drawset(params=unknown)
Ejemplo n.º 3
0
    def test_validate_good_run(self):
        """
        Build a ``CmdStanMCMC`` fit from the saved ``runset-good`` csv files,
        then check draws, summary percentile columns (default and custom),
        rejection of invalid percentiles, and the diagnostics text.
        """
        # construct fit using existing sampler output
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION),
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=os.path.join(DATAFILES_PATH, 'bernoulli.data.json'),
            output_dir=DATAFILES_PATH,
            method_args=SamplerArgs(iter_sampling=100,
                                    max_treedepth=11,
                                    adapt_delta=0.95),
        )
        runset = RunSet(args=cmdstan_args)
        runset._csv_files = [
            os.path.join(DATAFILES_PATH, 'runset-good',
                         'bern-{}.csv'.format(chain_id))
            for chain_id in (1, 2, 3, 4)
        ]
        self.assertEqual(4, runset.chains)
        # set every chain's return code to 0 before checking them
        for idx, _ in enumerate(runset._retcodes):
            runset._set_retcode(idx, 0)
        self.assertTrue(runset._check_retcodes())

        fit = CmdStanMCMC(runset)
        self.assertEqual(100, fit.num_draws)
        self.assertEqual(len(BERNOULLI_COLS), len(fit.column_names))
        self.assertEqual('lp__', fit.column_names[0])

        draws = fit.get_drawset()
        expected_shape = (
            fit.runset.chains * fit.num_draws,
            len(fit.column_names),
        )
        self.assertEqual(draws.shape, expected_shape)

        # summary() without arguments reports 5%, 50% and 95% columns
        default_cols = list(fit.summary().columns)
        for label in ('5%', '50%', '95%'):
            self.assertIn(label, default_cols)
        for label in ('1%', '99%'):
            self.assertNotIn(label, default_cols)

        # explicit percentiles replace the default columns
        custom_cols = list(fit.summary(percentiles=[1, 45, 99]).columns)
        for label in ('1%', '45%', '99%'):
            self.assertIn(label, custom_cols)
        for label in ('5%', '50%', '95%'):
            self.assertNotIn(label, custom_cols)

        # empty or negative percentiles are rejected
        for bad_percentiles in ([], [-1]):
            with self.assertRaises(ValueError):
                fit.summary(percentiles=bad_percentiles)

        report = fit.diagnose()
        for message in (
            'Treedepth satisfactory for all transitions.',
            'No divergent transitions found.',
            'E-BFMI satisfactory for all transitions.',
            'Effective sample size satisfactory.',
        ):
            self.assertIn(message, report)
Ejemplo n.º 4
0
    def generate_quantities(
        self,
        data: Union[Dict, str] = None,
        mcmc_sample: Union[CmdStanMCMC, List[str]] = None,
        seed: int = None,
        gq_output_dir: str = None,
    ) -> CmdStanGQ:
        """
        Run CmdStan's generate_quantities method which runs the generated
        quantities block of a model given an existing sample.

        This function takes a CmdStanMCMC object and the dataset used to
        generate that sample and calls to the CmdStan ``generate_quantities``
        method to generate additional quantities of interest.

        The ``CmdStanGQ`` object records the command, the return code,
        and the paths to the generate method output csv and console files.
        The output files are written either to a specified output directory
        or to a temporary directory which is deleted upon session exit.

        Output filenames are composed of the model name, a timestamp
        in the form 'YYYYMMDDhhmm' and the chain id, plus the corresponding
        filetype suffix, either '.csv' for the CmdStan output or '.txt' for
        the console messages, e.g. `bernoulli_ppc-201912081451-1.csv`. Output
        files written to the temporary directory contain an additional
        8-character random string, e.g.
        `bernoulli_ppc-201912081451-1-5nm6as7u.csv`.

        :param data: Values for all data variables in the model, specified
            either as a dictionary with entries matching the data variables,
            or as the path of a data file in JSON or Rdump format.

        :param mcmc_sample: Can be either a ``CmdStanMCMC`` object returned by
            the ``sample`` method or a list of stan-csv files generated
            by fitting the model to the data using any Stan interface.

        :param seed: The seed for random number generator. Must be an integer
            between 0 and 2^32 - 1. If unspecified,
            ``numpy.random.RandomState()``
            is used to generate a seed which will be used for all chains.
            *NOTE: Specifying the seed will guarantee the same result for
            multiple invocations of this method with the same inputs.  However
            this will not reproduce results from the sample method given
            the same inputs because the RNG will be in a different state.*

        :param gq_output_dir:  Name of the directory in which the CmdStan output
            files are saved.  If unspecified, files will be written to a
            temporary directory which is deleted upon session exit.

        :return: CmdStanGQ object

        :raises ValueError: if ``mcmc_sample`` is neither a ``CmdStanMCMC``
            object nor a list, if it provides no csv files, or if a sample
            cannot be assembled from the listed csv files.
        :raises RuntimeError: if the return-code check on the run fails.
        """
        sample_drawset = None
        if isinstance(mcmc_sample, CmdStanMCMC):
            sample_csv_files = mcmc_sample.runset.csv_files
            sample_drawset = mcmc_sample.get_drawset()
        elif isinstance(mcmc_sample, list):
            sample_csv_files = mcmc_sample
        else:
            raise ValueError('MCMC sample must be either CmdStanMCMC object'
                             ' or list of paths to sample csv_files.')

        # one chain per sample csv file
        chains = len(sample_csv_files)
        if chains == 0:
            # fail with a clear message instead of an IndexError when the
            # first csv file is accessed below
            raise ValueError('MCMC sample must contain at least one'
                             ' sample csv file.')

        if sample_drawset is None:  # assemble sample from csv files
            try:
                # scan 1st csv file to get config; on ValueError retry with
                # the second argument set to True
                # NOTE(review): presumably a fixed-param flag - confirm
                # against scan_sampler_csv
                try:
                    config = scan_sampler_csv(sample_csv_files[0])
                except ValueError:
                    config = scan_sampler_csv(sample_csv_files[0], True)
                # carry over only the settings recorded in the csv config
                conf_iter_sampling = (
                    int(config['num_samples'])
                    if 'num_samples' in config else None
                )
                conf_iter_warmup = (
                    int(config['num_warmup'])
                    if 'num_warmup' in config else None
                )
                conf_thin = int(config['thin']) if 'thin' in config else None
                sampler_args = SamplerArgs(
                    iter_sampling=conf_iter_sampling,
                    iter_warmup=conf_iter_warmup,
                    thin=conf_thin,
                )
                args = CmdStanArgs(
                    self._name,
                    self._exe_file,
                    chain_ids=[x + 1 for x in range(chains)],
                    method_args=sampler_args,
                )
                runset = RunSet(args=args, chains=chains)
                runset._csv_files = sample_csv_files
                sample_fit = CmdStanMCMC(runset)
                sample_drawset = sample_fit.get_drawset()
            except ValueError as e:
                raise ValueError('Invalid mcmc_sample, error:\n\t{}\n\t'
                                 ' while processing files\n\t{}'.format(
                                     repr(e), '\n\t'.join(sample_csv_files))
                                 ) from e

        generate_quantities_args = GenerateQuantitiesArgs(
            csv_files=sample_csv_files)
        generate_quantities_args.validate(chains)
        with MaybeDictToFilePath(data, None) as (_data, _inits):
            args = CmdStanArgs(
                self._name,
                self._exe_file,
                chain_ids=[x + 1 for x in range(chains)],
                data=_data,
                seed=seed,
                output_dir=gq_output_dir,
                method_args=generate_quantities_args,
            )
            runset = RunSet(args=args, chains=chains)

            # use at most one worker per chain, leaving two cores free when
            # possible, but always at least one worker
            cores_avail = cpu_count()
            cores = max(min(cores_avail - 2, chains), 1)
            # the submitted futures are not inspected; failures are detected
            # through the runset return codes checked below
            with ThreadPoolExecutor(max_workers=cores) as executor:
                for i in range(chains):
                    executor.submit(self._run_cmdstan, runset, i)

            if not runset._check_retcodes():
                msg = 'Error during generate_quantities'
                # i is the zero-based chain index, not the chain id
                for i in range(chains):
                    if runset._retcode(i) != 0:
                        msg = '{}, chain {} returned error code {}'.format(
                            msg, i, runset._retcode(i))
                raise RuntimeError(msg)
            quantities = CmdStanGQ(runset=runset, mcmc_sample=sample_drawset)
        return quantities
Ejemplo n.º 5
0
    def generate_quantities(
        self,
        data: Union[Dict, str] = None,
        mcmc_sample: Union[CmdStanMCMC, List[str]] = None,
        seed: int = None,
        gq_csv_basename: str = None,
    ) -> CmdStanGQ:
        """
        Wrapper for generated quantities call.  Given a CmdStanMCMC object
        containing a sample from the fitted model, along with the
        corresponding dataset for that fit, run just the generated quantities
        block of the model in order to get additional quantities of interest.

        :param data: Values for all data variables in the model, specified
            either as a dictionary with entries matching the data variables,
            or as the path of a data file in JSON or Rdump format.

        :param mcmc_sample: Can be either a CmdStanMCMC object returned by
            CmdStanPy's `sample` method or a list of stan-csv files generated
            by fitting the model to the data using any Stan interface.

        :param seed: The seed for random number generator. Must be an integer
            between ``0`` and ``2^32 - 1``. If unspecified,
            ``numpy.random.RandomState()``
            is used to generate a seed which will be used for all chains.
            *NOTE: Specifying the seed will guarantee the same result for
            multiple invocations of this method with the same inputs.  However
            this will not reproduce results from the sample method given
            the same inputs because the RNG will be in a different state.*

        :param gq_csv_basename: A path or file name which will be used as the
            basename for the sampler output files.  The csv output files
            for each chain are written to file ``<basename>-<chain_id>.csv``
            and the console output and error messages are written to file
            ``<basename>-<chain_id>.txt``.

        :return: CmdStanGQ object

        :raises ValueError: if ``mcmc_sample`` is neither a ``CmdStanMCMC``
            object nor a list, if it provides no csv files, or if a sample
            cannot be assembled from the listed csv files.
        :raises RuntimeError: if the return-code check on the run fails.
        """
        sample_drawset = None
        if isinstance(mcmc_sample, CmdStanMCMC):
            sample_csv_files = mcmc_sample.runset.csv_files
            sample_drawset = mcmc_sample.get_drawset()
        elif isinstance(mcmc_sample, list):
            sample_csv_files = mcmc_sample
        else:
            raise ValueError(
                'mcmc_sample must be either CmdStanMCMC object'
                ' or list of paths to sample csv_files'
            )

        # one chain per sample csv file
        chains = len(sample_csv_files)
        if chains == 0:
            # fail with a clear message instead of an IndexError when the
            # first csv file is accessed at the end of this method
            raise ValueError(
                'mcmc_sample must contain at least one sample csv file'
            )

        if sample_drawset is None:  # assemble sample from csv files
            try:
                sampler_args = SamplerArgs()
                args = CmdStanArgs(
                    self._name,
                    self._exe_file,
                    chain_ids=[x + 1 for x in range(chains)],
                    method_args=sampler_args,
                )
                runset = RunSet(args=args, chains=chains)
                runset._csv_files = sample_csv_files
                sample_fit = CmdStanMCMC(runset)
                sample_fit._validate_csv_files()
                sample_drawset = sample_fit.get_drawset()
            except ValueError as e:
                raise ValueError(
                    'Invalid mcmc_sample, error:\n\t{}\n\t'
                    ' while processing files\n\t{}'.format(
                        repr(e), '\n\t'.join(sample_csv_files))
                ) from e

        generate_quantities_args = GenerateQuantitiesArgs(
            csv_files=sample_csv_files
        )
        generate_quantities_args.validate(chains)
        with MaybeDictToFilePath(data, None) as (_data, _inits):
            args = CmdStanArgs(
                self._name,
                self._exe_file,
                chain_ids=[x + 1 for x in range(chains)],
                data=_data,
                seed=seed,
                output_basename=gq_csv_basename,
                method_args=generate_quantities_args,
            )
            runset = RunSet(args=args, chains=chains)

            # use at most one worker per chain, leaving two cores free when
            # possible, but always at least one worker
            cores_avail = cpu_count()
            cores = max(min(cores_avail - 2, chains), 1)
            # the submitted futures are not inspected; failures are detected
            # through the runset return codes checked below
            with ThreadPoolExecutor(max_workers=cores) as executor:
                for i in range(chains):
                    executor.submit(self._run_cmdstan, runset, i)

            if not runset._check_retcodes():
                msg = 'Error during generate_quantities'
                # i is the zero-based chain index, not the chain id
                for i in range(chains):
                    if runset._retcode(i) != 0:
                        msg = '{}, chain {} returned error code {}'.format(
                            msg, i, runset._retcode(i)
                        )
                raise RuntimeError(msg)
            quantities = CmdStanGQ(runset=runset, mcmc_sample=sample_drawset)
            # pass the first sample csv file on to the result object
            quantities._set_attrs_gq_csv_files(sample_csv_files[0])
        return quantities