示例#1
0
 def test_check_retcodes(self):
     """Exercise per-chain return-code bookkeeping on a fresh ``StanFit``.

     Asserts that every chain starts with return code -1, that
     ``_check_retcodes`` is false while any chain is still at -1, and
     that it becomes true once all chains have been set to 0.
     """
     model_exe = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)
     data_file = os.path.join(datafiles_path, 'bernoulli.data.json')
     args = CmdStanArgs(
         model_name='bernoulli',
         model_exe=model_exe,
         chain_ids=[1, 2, 3, 4],
         data=data_file,
         method_args=SamplerArgs(),
     )
     fit = StanFit(args=args, chains=4)

     n_chains = len(fit._retcodes)
     self.assertEqual(4, n_chains)

     # initial state: no chain has reported a return code yet
     for chain in range(n_chains):
         self.assertEqual(-1, fit._retcode(chain))

     # only the first chain succeeds; the rest must be untouched
     fit._set_retcode(0, 0)
     self.assertEqual(0, fit._retcode(0))
     remaining = range(1, n_chains)
     for chain in remaining:
         self.assertEqual(-1, fit._retcode(chain))
     self.assertFalse(fit._check_retcodes())

     # all chains succeed
     for chain in remaining:
         fit._set_retcode(chain, 0)
     self.assertTrue(fit._check_retcodes())
示例#2
0
    def test_validate_good_run(self):
        """Validate a ``StanFit`` built on existing, well-formed output files.

        Marks all four chains as successful, runs the console and csv
        validation steps, then checks the reported dimensions, the drawset
        shape, and the text returned by ``diagnose()``.
        """
        # construct fit using existing sampler output
        model_exe = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)
        data_file = os.path.join(datafiles_path, 'bernoulli.data.json')
        basename = os.path.join(goodfiles_path, 'bern')
        method_args = SamplerArgs(
            sampling_iters=100, max_treedepth=11, adapt_delta=0.95
        )
        args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=model_exe,
            chain_ids=[1, 2, 3, 4],
            seed=12345,
            data=data_file,
            output_basename=basename,
            method_args=method_args,
        )
        fit = StanFit(args=args, chains=4)

        # pretend every chain exited cleanly
        n_chains = len(fit._retcodes)
        for chain in range(n_chains):
            fit._set_retcode(chain, 0)
        self.assertTrue(fit._check_retcodes())

        fit._check_console_msgs()
        fit._validate_csv_files()
        self.assertEqual(4, fit.chains)
        self.assertEqual(100, fit.draws)
        self.assertEqual(8, len(fit.column_names))
        self.assertEqual('lp__', fit.column_names[0])

        drawset = fit.get_drawset()
        expected_shape = (fit.chains * fit.draws, len(fit.column_names))
        self.assertEqual(drawset.shape, expected_shape)

        # only checks that summary() runs; its result is not inspected
        fit.summary()

        # TODO - use cmdstan test files instead
        diagnose_lines = [
            'Checking sampler transitions treedepth.',
            'Treedepth satisfactory for all transitions.',
            '\nChecking sampler transitions for divergences.',
            'No divergent transitions found.',
            '\nChecking E-BFMI - sampler transitions HMC potential energy.',
            'E-BFMI satisfactory for all transitions.',
            '\nEffective sample size satisfactory.',
        ]
        expected = '\n'.join(diagnose_lines)
        # normalize Windows line endings before comparing
        actual = fit.diagnose().replace("\r\n", "\n")
        self.assertIn(expected, actual)
示例#3
0
    def sample(
        self,
        data: Union[Dict, str] = None,
        chains: int = 4,
        cores: Union[int, None] = None,
        seed: Union[int, List[int]] = None,
        chain_ids: Union[int, List[int]] = None,
        inits: Union[Dict, float, str, List[str]] = None,
        warmup_iters: int = None,
        sampling_iters: int = None,
        save_warmup: bool = False,
        thin: int = None,
        max_treedepth: float = None,
        metric: Union[str, List[str]] = None,
        step_size: Union[float, List[float]] = None,
        adapt_engaged: bool = True,
        adapt_delta: float = None,
        csv_basename: str = None,
        show_progress: bool = False,
    ) -> StanFit:
        """
        Run one or more chains of the NUTS sampler to produce a set of draws
        from the posterior distribution of a model conditioned on some data.

        This function validates the specified configuration, composes a call to
        the CmdStan ``sample`` method and spawns one subprocess per chain to run
        the sampler and waits for all chains to run to completion.
        Unspecified arguments are not included in the call to CmdStan, i.e.,
        those arguments will have CmdStan default values.

        For each chain, the ``StanFit`` object records the command,
        the return code, the sampler output file paths, and the corresponding
        subprocess console outputs, if any.

        :param data: Values for all data variables in the model, specified
            either as a dictionary with entries matching the data variables,
            or as the path of a data file in JSON or Rdump format.

        :param chains: Number of sampler chains, must be a positive integer.

        :param cores: Number of processes to run in parallel. Must be an
            integer between 1 and the number of CPUs in the system.
            If none then set automatically to `chains` but no more
            than `total_cpu_count - 2` (and never less than 1).  A value
            larger than the number of available CPUs is reduced to that
            number and a warning is logged.

        :param seed: The seed for random number generator or a list of per-chain
            seeds. Must be an integer between 0 and 2^32 - 1. If unspecified,
            numpy.random.RandomState() is used to generate a seed which will be
            used for all chains. When the same seed is used across all chains,
            the chain-id is used to advance the RNG to avoid dependent samples.

        :param chain_ids: The offset for the random number generator, either
            an integer or a list of unique per-chain offsets.  If unspecified,
            chain ids are numbered sequentially starting from 1.  A single
            integer ``n`` (which must be >= 1) produces the chain ids
            ``n + 1, ..., n + chains``.  A list must contain exactly
            ``chains`` entries, each of them >= 1.

        :param inits: Specifies how the sampler initializes parameter values.
            Initialization is either uniform random on a range centered on 0,
            exactly 0, or a dictionary or file of initial values for some or all
            parameters in the model.  The default initialization behavior will
            initialize all parameter values on range [-2, 2] on the
            _unconstrained_ support.  If the expected parameter values are
            too far from this range, this option may improve adaptation.
            The following value types are allowed:

            * Single number ``n > 0`` - initialization range is [-n, n].
            * ``0`` - all parameters are initialized to 0.
            * dictionary - pairs parameter name : initial value.
            * string - pathname to a JSON or Rdump data file.
            * list of strings - per-chain pathname to data file.

        :param warmup_iters: Number of warmup iterations for each chain.

        :param sampling_iters: Number of draws from the posterior for each
            chain.

        :param save_warmup: When True, sampler saves warmup draws as part of
            the Stan csv output file.

        :param thin: Period between saved samples.

        :param max_treedepth: Maximum depth of trees evaluated by NUTS sampler
            per iteration.

        :param metric: Specification of the mass matrix, either as a
            vector consisting of the diagonal elements of the covariance
            matrix (``diag`` or ``diag_e``) or the full covariance matrix
            (``dense`` or ``dense_e``).

            If the value of the metric argument is a string other than
            ``diag``, ``diag_e``, ``dense``, or ``dense_e``, it must be
            a valid filepath to a JSON or Rdump file which contains an entry
            ``inv_metric`` whose value is either the diagonal vector or
            the full covariance matrix.

            If the value of the metric argument is a list of paths, its
            length must match the number of chains and all paths must be
            unique.

        :param step_size: Initial stepsize for HMC sampler.  The value is either
            a single number or a list of numbers which will be used as the
            global or per-chain initial step_size, respectively.
            The length of the list of step sizes must match the number of
            chains.

        :param adapt_engaged: When True, adapt stepsize and metric.
            *Note: If True, ``warmup_iters`` must be > 0.*

        :param adapt_delta: Adaptation target Metropolis acceptance rate.
            The default value is 0.8.  Increasing this value, which must be
            strictly less than 1, causes adaptation to use smaller step sizes.
            It improves the effective sample size, but may increase the time
            per iteration.

        :param csv_basename: A path or file name which will be used as the
            base name for the sampler output files.  The csv output files
            for each chain are written to file ``<basename>-<chain_id>.csv``
            and the console output and error messages are written to file
            ``<basename>-<chain_id>.txt``.

        :param show_progress: Accepted for interface compatibility; this
            method does not currently read it.

        :return: StanFit object

        :raises ValueError: if ``chains``, ``chain_ids`` or ``cores`` fail
            the checks described above.
        :raises RuntimeError: if ``StanFit._check_retcodes()`` reports a
            failure after all chains have finished; the message lists every
            chain index with a non-zero return code.
        """
        if chains < 1:
            raise ValueError(
                'chains must be a positive integer value, found {}'.format(
                    chains))

        if chain_ids is None:
            chain_ids = [x + 1 for x in range(chains)]
        elif isinstance(chain_ids, int):
            if chain_ids < 1:
                raise ValueError(
                    'chain_id must be a positive integer value,'
                    ' found {}'.format(chain_ids))
            offset = chain_ids
            chain_ids = [x + offset + 1 for x in range(chains)]
        else:
            if len(chain_ids) != chains:
                raise ValueError(
                    'chain_ids must correspond to number of chains'
                    ' specified {} chains, found {} chain_ids'.format(
                        chains, len(chain_ids)))
            # iterate over the ids themselves; iterating over
            # ``len(chain_ids)`` raised TypeError for every list input
            for chain_id in chain_ids:
                if chain_id < 1:
                    raise ValueError(
                        'chain_id must be a positive integer value,'
                        ' found {}'.format(chain_id))

        cores_avail = cpu_count()

        if cores is None:
            # leave two CPUs free when possible, but always use at least one
            cores = max(min(cores_avail - 2, chains), 1)

        if cores < 1:
            raise ValueError(
                'cores must be a positive integer value, found {}'.format(
                    cores))
        if cores > cores_avail:
            self._logger.warning('requested %u cores, only %u available',
                                 cores, cores_avail)
            cores = cores_avail

        # TODO:  issue 49: inits can be initialization function

        sampler_args = SamplerArgs(
            warmup_iters=warmup_iters,
            sampling_iters=sampling_iters,
            save_warmup=save_warmup,
            thin=thin,
            max_treedepth=max_treedepth,
            metric=metric,
            step_size=step_size,
            adapt_engaged=adapt_engaged,
            adapt_delta=adapt_delta,
        )
        with MaybeDictToFilePath(data, inits) as (_data, _inits):
            args = CmdStanArgs(
                self._name,
                self._exe_file,
                chain_ids=chain_ids,
                data=_data,
                seed=seed,
                inits=_inits,
                output_basename=csv_basename,
                method_args=sampler_args,
            )

            stanfit = StanFit(args=args, chains=chains)
            # create the pool before entering ``try`` so that a failure in
            # the constructor is not masked by a NameError in ``finally``
            tp = ThreadPool(cores)
            try:
                for i in range(chains):
                    tp.apply_async(self._do_sample, (stanfit, i))
            finally:
                # wait for every submitted chain before inspecting results
                tp.close()
                tp.join()
            if not stanfit._check_retcodes():
                msg = 'Error during sampling'
                for i in range(chains):
                    if stanfit._retcode(i) != 0:
                        msg = '{}, chain {} returned error code {}'.format(
                            msg, i, stanfit._retcode(i))
                raise RuntimeError(msg)
            stanfit._validate_csv_files()
        return stanfit
示例#4
0
    def optimize(
        self,
        data: Union[Dict, str] = None,
        seed: int = None,
        inits: Union[Dict, float, str] = None,
        csv_basename: str = None,
        algorithm: str = None,
        init_alpha: float = None,
        iter: int = None,
    ) -> StanFit:
        """
        Wrapper for the optimize call.

        Builds the optimizer arguments, invokes ``self._do_sample`` once
        (using chain index 0) and checks the recorded return code before
        validating the csv output.

        :param data: Values for all data variables in the model, specified
            either as a dictionary with entries matching the data variables,
            or as the path of a data file in JSON or Rdump format.

        :param seed: The seed for random number generator. Must be an integer
            between 0 and 2^32 - 1. If unspecified, numpy.random.RandomState()
            is used to generate a seed which will be used for all chains.

        :param inits: Specifies how the sampler initializes parameter values.
            Initialization is either uniform random on a range centered on 0,
            exactly 0, or a dictionary or file of initial values for some or
            all parameters in the model.  The default initialization behavior
            will initialize all parameter values on range [-2, 2] on the
            _unconstrained_ support.  If the expected parameter values are
            too far from this range, this option may improve adaptation.
            The following value types are allowed:

            * Single number ``n > 0`` - initialization range is [-n, n].
            * ``0`` - all parameters are initialized to 0.
            * dictionary - pairs parameter name : initial value.
            * string - pathname to a JSON or Rdump data file.

        :param csv_basename: A path or file name which will be used as the
            base name for the sampler output files.  The csv output files
            for each chain are written to file ``<basename>-0.csv``
            and the console output and error messages are written to file
            ``<basename>-0.txt``.

        :param algorithm: Algorithm to use. One of: "BFGS", "LBFGS", "Newton"

        :param init_alpha: Line search step size for first iteration

        :param iter: Total number of iterations

        :return: StanFit object

        :raises RuntimeError: if ``StanFit._check_retcodes()`` reports a
            failure after the run.
        """
        method_args = OptimizeArgs(
            algorithm=algorithm, init_alpha=init_alpha, iter=iter
        )

        with MaybeDictToFilePath(data, inits) as (_data, _inits):
            cmdstan_args = CmdStanArgs(
                self._name,
                self._exe_file,
                chain_ids=None,
                data=_data,
                seed=seed,
                inits=_inits,
                output_basename=csv_basename,
                method_args=method_args,
            )

            # a single run, addressed as chain index 0
            stanfit = StanFit(args=cmdstan_args, chains=1)
            dummy_chain_id = 0
            self._do_sample(stanfit, dummy_chain_id)

        if stanfit._check_retcodes():
            stanfit._validate_csv_files()
            return stanfit

        # failure path: report the return code when it is non-zero
        msg = 'Error during optimizing'
        if stanfit._retcode(dummy_chain_id) != 0:
            msg = '{} Got returned error code {}'.format(
                msg, stanfit._retcode(dummy_chain_id))
        raise RuntimeError(msg)
示例#5
0
    def run_generated_quantities(
        self,
        data: Union[Dict, str] = None,
        csv_files: List[str] = None,
        seed: int = None,
        gq_csv_basename: str = None,
    ) -> StanFit:
        """
        Wrapper for generated quantities call.  Given the sampler output
        csv files from a fitted model, along with the corresponding dataset
        for that fit, run just the generated quantities block of the model
        in order to get additional quantities of interest.

        One CmdStan run is dispatched per entry of ``csv_files``; the runs
        are submitted to a thread pool whose size is the number of csv
        files, capped at ``cpu_count() - 2`` and never less than 1.

        :param data: Values for all data variables in the model, specified
            either as a dictionary with entries matching the data variables,
            or as the path of a data file in JSON or Rdump format.

        :param csv_files: A list of sampler output csv files generated by
            fitting the model to the data, either using CmdStanPy's `sample`
            method or via another Stan interface.

        :param seed: The seed for random number generator. Must be an integer
            between 0 and 2^32 - 1. If unspecified, numpy.random.RandomState()
            is used to generate a seed which will be used for all chains.
            *NOTE: Specifying the seed will guarantee the same result for
            multiple invocations of this method with the same inputs.  However
            this will not reproduce results from the sample method given
            the same inputs because the RNG will be in a different state.*

        :param gq_csv_basename: A path or file name which will be used as the
            basename for the sampler output files.  The csv output files
            for each chain are written to file ``<basename>-<chain_id>.csv``
            and the console output and error messages are written to file
            ``<basename>-<chain_id>.txt``.

        :return: StanFit object

        :raises RuntimeError: if ``StanFit._check_retcodes()`` reports a
            failure after all runs have finished; the message lists every
            chain index with a non-zero return code.
        """
        generate_quantities_args = GenerateQuantitiesArgs(csv_files=csv_files)
        chains = len(csv_files)
        generate_quantities_args.validate(chains)
        with MaybeDictToFilePath(data, None) as (_data, _inits):
            args = CmdStanArgs(self._name,
                               self._exe_file,
                               chain_ids=[x + 1 for x in range(chains)],
                               data=_data,
                               seed=seed,
                               output_basename=gq_csv_basename,
                               method_args=generate_quantities_args)
            stanfit = StanFit(args=args, chains=chains)

            cores_avail = cpu_count()
            # leave two CPUs free when possible, but always use at least one
            cores = max(min(cores_avail - 2, chains), 1)
            with ThreadPoolExecutor(max_workers=cores) as executor:
                for i in range(chains):
                    # pass the callable and its arguments separately;
                    # ``submit(self._run_cmdstan(stanfit, i))`` ran the chain
                    # on the calling thread and submitted its return value
                    executor.submit(self._run_cmdstan, stanfit, i)
            # leaving the ``with`` block waits for all submitted runs
            if not stanfit._check_retcodes():
                msg = 'Error during sampling'
                for i in range(chains):
                    if stanfit._retcode(i) != 0:
                        msg = '{}, chain {} returned error code {}'.format(
                            msg, i, stanfit._retcode(i))
                raise RuntimeError(msg)
            stanfit._set_attrs_gq_csv_files(csv_files[0])
        return stanfit
示例#6
0
    def test_validate_bad_run(self):
        """Check that each malformed output fixture is rejected.

        The first fixture must make ``_check_console_msgs`` raise; the
        remaining three must make ``_validate_csv_files`` raise a
        ``ValueError`` whose message matches the expected pattern.
        """
        exe = os.path.join(datafiles_path, 'bernoulli' + EXTENSION)
        jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
        sampler_args = SamplerArgs(
            sampling_iters=100, max_treedepth=11, adapt_delta=0.95
        )

        def build_fit(basename):
            # 4-chain fit whose output files live under ``badfiles_path``
            output = os.path.join(badfiles_path, basename)
            cmdstan_args = CmdStanArgs(
                model_name='bernoulli',
                model_exe=exe,
                chain_ids=[1, 2, 3, 4],
                seed=12345,
                data=jdata,
                output_basename=output,
                method_args=sampler_args,
            )
            return StanFit(args=cmdstan_args, chains=4)

        def mark_all_chains_ok(fit):
            # set every return code to 0 and confirm the fit accepts it
            for chain in range(len(fit._retcodes)):
                fit._set_retcode(chain, 0)
            self.assertTrue(fit._check_retcodes())

        # some chains had errors
        fit = build_fit('bad-transcript-bern')
        with self.assertRaisesRegex(Exception, 'Exception'):
            fit._check_console_msgs()

        csv_failures = (
            # csv file headers inconsistent
            ('bad-hdr-bern', 'header mismatch'),
            # bad draws
            ('bad-draws-bern', 'draws'),
            # mismatch - column headers, draws
            ('bad-cols-bern', 'bad draw'),
        )
        for basename, error_pattern in csv_failures:
            fit = build_fit(basename)
            mark_all_chains_ok(fit)
            with self.assertRaisesRegex(ValueError, error_pattern):
                fit._validate_csv_files()