Example #1
0
 def test_args_sig_figs(self):
     """Check how CmdStanArgs handles ``sig_figs`` for the active CmdStan.

     When CmdStan is older than 2.25, constructing the args is expected to
     log a warning on the 'cmdstanpy' logger.  Otherwise ``sig_figs`` must
     show up in the composed command and the values -1 and 20 must be
     rejected with ValueError.
     """
     sampler_args = SamplerArgs()
     cmdstan_path()  # sets os.environ['CMDSTAN']

     def build_args(sig_figs):
         # Every case in this test differs only in the sig_figs value.
         return CmdStanArgs(
             model_name='bernoulli',
             model_exe='bernoulli.exe',
             chain_ids=[1, 2, 3, 4],
             sig_figs=sig_figs,
             method_args=sampler_args,
         )

     if cmdstan_version_at(2, 25):
         cmdstan_args = build_args(12)
         composed = cmdstan_args.compose_command(
             idx=0, csv_file='bern-output-1.csv')
         self.assertIn('sig_figs=', ' '.join(composed))
         for out_of_range in (-1, 20):
             with self.assertRaises(ValueError):
                 build_args(out_of_range)
         return

     # Older CmdStan: the argument should be reported via a warning.
     with LogCapture() as log:
         logging.getLogger()
         build_args(12)
     template = ('arg sig_figs not valid, CmdStan version must be 2.25 '
                 'or higher, using verson {} in directory {}')
     expect = template.format(
         os.path.basename(cmdstan_path()),
         os.path.dirname(cmdstan_path()),
     )
     log.check_present(('cmdstanpy', 'WARNING', expect))
Example #2
0
    def test_validate_summary_sig_figs(self):
        """Exercise the ``sig_figs`` argument of ``CmdStanMCMC.summary``.

        A fit object is assembled by hand from the stored logistic model
        output files, then the first-column value of the second summary row
        is compared against expected prefixes for the default precision and,
        on CmdStan 2.25 or later, for 17 and 10 significant figures.
        Out-of-range values (20 and -1) must raise ValueError.
        """
        sampler_args = SamplerArgs(iter_sampling=100)
        cmdstan_args = CmdStanArgs(
            model_name='logistic',
            model_exe=os.path.join(DATAFILES_PATH, 'logistic' + EXTENSION),
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=os.path.join(DATAFILES_PATH, 'logistic.data.R'),
            output_dir=DATAFILES_PATH,
            sig_figs=17,
            method_args=sampler_args,
        )
        runset = RunSet(args=cmdstan_args)
        runset._csv_files = [
            os.path.join(DATAFILES_PATH,
                         'logistic_output_{}.csv'.format(chain))
            for chain in range(1, 5)
        ]
        # Mark every chain as having finished with return code 0.
        for chain_idx in range(len(runset._retcodes)):
            runset._set_retcode(chain_idx, 0)
        fit = CmdStanMCMC(runset)

        def beta1_text(summary_frame):
            # Row 1, column 0 of the summary, rendered with 18 digits.
            return format(summary_frame.iloc[1, 0], '.18g')

        self.assertTrue(beta1_text(fit.summary()).startswith('1.3'))

        if cmdstan_version_at(2, 25):
            expectations = (
                (17, '1.345767078273'),
                (10, '1.34576707'),
            )
            for digits, prefix in expectations:
                rendered = beta1_text(fit.summary(sig_figs=digits))
                self.assertTrue(rendered.startswith(prefix))

        for invalid in (20, -1):
            with self.assertRaises(ValueError):
                fit.summary(sig_figs=invalid)
Example #3
0
    def test_validate_sample_sig_figs(self, stanfile='bernoulli.stan'):
        """
        Check ``sig_figs`` handling in ``CmdStanModel.sample``.

        The checks only run when CmdStan is at version 2.25 or later:

        * a default-precision run must not reproduce the full-precision
          text of the reference draw value;
        * a run with ``sig_figs=17`` must produce a non-empty draws array;
        * ``sig_figs=27`` and ``sig_figs=-1`` must each raise ValueError.

        :param stanfile: name of the Stan program file under DATAFILES_PATH.
        """
        if not cmdstan_version_at(2, 25):
            return

        stan = os.path.join(DATAFILES_PATH, stanfile)
        bern_model = CmdStanModel(stan_file=stan)
        jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')

        bern_fit = bern_model.sample(
            data=jdata,
            chains=1,
            seed=12345,
            iter_sampling=100,
        )
        bern_draws = bern_fit.draws()
        theta = format(bern_draws[99, 0, 7], '.18g')
        self.assertFalse(theta.startswith('0.21238045821757600'))

        bern_fit_17 = bern_model.sample(
            data=jdata,
            chains=1,
            seed=12345,
            iter_sampling=100,
            sig_figs=17,
        )
        self.assertTrue(bern_fit_17.draws().size)

        # Each invalid value needs its own assertRaises context: once the
        # call inside a context raises, nothing after it in that context
        # runs, so nesting the second check inside the first would leave it
        # unexecuted.
        for invalid_sig_figs in (27, -1):
            with self.assertRaises(ValueError):
                bern_model.sample(
                    data=jdata,
                    chains=1,
                    seed=12345,
                    iter_sampling=100,
                    sig_figs=invalid_sig_figs,
                )
Example #4
0
 def test_cmdstan_version_at(self):
     """The installed CmdStan must not report being at version 99.99."""
     cmdstan_path()  # sets os.environ['CMDSTAN']
     at_far_future_version = cmdstan_version_at(99, 99)
     self.assertFalse(at_far_future_version)
Example #5
0
    def summary(self,
                percentiles: List[int] = None,
                sig_figs: int = None) -> pd.DataFrame:
        """
        Run ``cmdstan/bin/stansummary`` over all output csv files of this
        run and return its CSV report as a DataFrame indexed by the first
        column of that report.  Rows whose name ends in ``__`` are dropped,
        except for ``lp__``.

        :param percentiles: Ordered non-empty list of percentiles to report.
            Values must be strictly increasing and lie in (1, 99), inclusive.
            If unspecified, ``5,50,95`` is requested.

        :param sig_figs: Number of significant figures to report.
            Must be an integer between 1 and 18.  If unspecified,
            ``--sig_figs=2`` is passed to stansummary.  When the requested
            value exceeds the precision recorded for the csv files
            (``self._sig_figs``, taken as 6 when unset) a warning is logged
            and the request is still passed through.

        :raises ValueError: if ``percentiles`` is empty, unordered or out of
            range, or if ``sig_figs`` is not an integer between 1 and 18.

        :return: pandas.DataFrame
        """
        # --percentiles option
        if percentiles is None:
            percentiles_str = '--percentiles=5,50,95'
        else:
            if len(percentiles) == 0:
                raise ValueError(
                    'invalid percentiles argument, must be ordered'
                    ' non-empty list from (1, 99), inclusive.')
            previous = 0
            for pct in percentiles:
                # Each entry must exceed its predecessor and stay <= 99.
                if not (previous < pct <= 99):
                    raise ValueError(
                        'invalid percentiles spec, must be ordered'
                        ' non-empty list from (1, 99), inclusive.')
                previous = pct
            percentiles_str = '--percentiles=' + ','.join(
                str(pct) for pct in percentiles)

        # --sig_figs option
        if sig_figs is None:
            sig_figs_str = '--sig_figs=2'
        else:
            if not (isinstance(sig_figs, int) and 1 <= sig_figs <= 18):
                raise ValueError(
                    'sig_figs must be an integer between 1 and 18,'
                    ' found {}'.format(sig_figs))
            csv_sig_figs = self._sig_figs or 6
            if sig_figs > csv_sig_figs:
                self._logger.warning(
                    'Requesting %d significant digits of output, but CSV files'
                    ' only have %d digits of precision.',
                    sig_figs,
                    csv_sig_figs,
                )
            sig_figs_str = '--sig_figs={}'.format(sig_figs)

        cmd_path = os.path.join(cmdstan_path(), 'bin',
                                'stansummary' + EXTENSION)
        tmp_csv_path = create_named_text_file(
            dir=_TMPDIR,
            prefix='stansummary-{}-'.format(self.runset._args.model_name),
            suffix='.csv',
            name_only=True,
        )
        # The output-file option is spelled differently before CmdStan 2.24.
        if cmdstan_version_at(2, 24):
            csv_option = '--csv_filename={}'
        else:
            csv_option = '--csv_file={}'
        csv_str = csv_option.format(tmp_csv_path)

        cmd = [cmd_path, percentiles_str, sig_figs_str, csv_str]
        cmd.extend(self.runset.csv_files)
        do_command(cmd, logger=self.runset._logger)

        with open(tmp_csv_path, 'rb') as fd:
            summary_data = pd.read_csv(
                fd,
                delimiter=',',
                header=0,
                index_col=0,
                comment='#',
                float_precision='high',
            )
        keep_row = [
            name == 'lp__' or not name.endswith('__')
            for name in summary_data.index
        ]
        return summary_data[keep_row]
Example #6
0
    def validate(self) -> None:
        """
        Validate the argument set; raise ValueError on the first problem.

        Checks, in the order performed:

        * model name and model executable are set
        * every chain id is at least 1
        * output_dir, if given, is expanded to a real path, created when
          missing, and verified to be a directory a file can be written to
        * refresh, if given, is a positive integer
        * sig_figs, if given, is an integer between 1 and 18; it is reset
          to None with a warning when CmdStan is older than 2.25
        * seed: a random one is drawn if unset; otherwise it is an integer
          in [0, 2**32-1] or a list of such values, one per chain
        * data is None, a dict, or the path of an existing file
        * inits is None, a non-negative number, the path of an existing
          file, or a list of distinct existing files, one per chain
        """
        if self.model_name is None:
            raise ValueError('no stan model specified')
        if self.model_exe is None:
            raise ValueError('model not compiled')

        if self.chain_ids is not None:
            for chain_id in self.chain_ids:
                if chain_id < 1:
                    raise ValueError('invalid chain_id {}'.format(chain_id))

        if self.output_dir is not None:
            out_dir = os.path.realpath(os.path.expanduser(self.output_dir))
            self.output_dir = out_dir
            if not os.path.exists(out_dir):
                try:
                    os.makedirs(out_dir)
                    self._logger.info('created output directory: %s',
                                      out_dir)
                except (RuntimeError, PermissionError) as exc:
                    raise ValueError(
                        'invalid path for output files, no such dir: {}'.
                        format(out_dir)) from exc
            if not os.path.isdir(out_dir):
                raise ValueError(
                    'specified output_dir not a directory: {}'.format(
                        out_dir))
            # Probe writeability by creating and removing a scratch file
            # named after the current timestamp.
            try:
                probe_path = os.path.join(out_dir, str(time()))
                with open(probe_path, 'w+'):
                    pass
                os.remove(probe_path)
            except Exception as exc:
                raise ValueError('invalid path for output files,'
                                 ' cannot write to dir: {}'.format(
                                     out_dir)) from exc

        if self.refresh is not None:
            if not (isinstance(self.refresh, int) and self.refresh >= 1):
                raise ValueError(
                    'Argument refresh must be a positive integer value, '
                    'found {}.'.format(self.refresh))

        if self.sig_figs is not None:
            if not (isinstance(self.sig_figs, int)
                    and 1 <= self.sig_figs <= 18):
                raise ValueError(
                    'sig_figs must be an integer between 1 and 18,'
                    ' found {}'.format(self.sig_figs))
            if not cmdstan_version_at(2, 25):
                # Drop the argument rather than fail on older CmdStan.
                self.sig_figs = None
                self._logger.warning(
                    'arg sig_figs not valid, CmdStan version must be 2.25 '
                    'or higher, using verson %s in directory %s',
                    os.path.basename(cmdstan_path()),
                    os.path.dirname(cmdstan_path()),
                )

        max_seed = 2**32 - 1
        if self.seed is None:
            self.seed = RandomState().randint(1, 99999 + 1)
        elif isinstance(self.seed, int):
            if self.seed < 0 or self.seed > max_seed:
                raise ValueError(
                    'seed must be an integer between 0 and 2**32-1,'
                    ' found {}'.format(self.seed))
        elif isinstance(self.seed, list):
            if self.chain_ids is None:
                raise ValueError(
                    'seed must not be a list when no chains used')

            if len(self.seed) != len(self.chain_ids):
                raise ValueError(
                    'number of seeds must match number of chains,'
                    ' found {} seed for {} chains '.format(
                        len(self.seed), len(self.chain_ids)))
            for chain_seed in self.seed:
                if chain_seed < 0 or chain_seed > max_seed:
                    raise ValueError('seed must be an integer value'
                                     ' between 0 and 2**32-1,'
                                     ' found {}'.format(chain_seed))
        else:
            raise ValueError(
                'seed must be an integer between 0 and 2**32-1,'
                ' found {}'.format(self.seed))

        if isinstance(self.data, str):
            if not os.path.exists(self.data):
                raise ValueError('no such file {}'.format(self.data))
        elif not (self.data is None or isinstance(self.data, dict)):
            raise ValueError('data must be string or dict')

        if self.inits is not None:
            if isinstance(self.inits, (Integral, Real)):
                if self.inits < 0:
                    raise ValueError('inits must be > 0, found {}'.format(
                        self.inits))
            elif isinstance(self.inits, str):
                if not os.path.exists(self.inits):
                    raise ValueError('no such file {}'.format(self.inits))
            elif isinstance(self.inits, list):
                if self.chain_ids is None:
                    raise ValueError(
                        'inits must not be a list when no chains are used')

                if len(self.inits) != len(self.chain_ids):
                    raise ValueError(
                        'number of inits files must match number of chains,'
                        ' found {} inits files for {} chains '.format(
                            len(self.inits), len(self.chain_ids)))
                if len(set(self.inits)) != len(self.inits):
                    raise ValueError('each chain must have its own init file,'
                                     ' found duplicates in inits files list.')
                for init_file in self.inits:
                    if not os.path.exists(init_file):
                        raise ValueError('no such file {}'.format(init_file))