Example #1
 def test_args_iters_adapt_mismatch(self):
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     with self.assertRaises(ValueError):
         args = SamplerArgs(model,
                            chain_ids=[1, 2],
                            warmup_iters=0,
                            adapt_engaged=True)
Example #2
 def test_args_bad_inits_files_2(self):
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     jinits = os.path.join(datafiles_path, 'bernoulli.init.json')
     with self.assertRaises(ValueError):
         args = SamplerArgs(model,
                            chain_ids=[1, 2],
                            inits=[jinits, 'no/such/file.json'])
Example #3
 def test_args_good(self):
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     rdata = os.path.join(datafiles_path, 'bernoulli.data.R')
     output = os.path.join(TMPDIR, 'bernoulli.output')
     args = SamplerArgs(model,
                        chain_ids=[1, 2],
                        seed=12345,
                        data=rdata,
                        output_file=output,
                        max_treedepth=15,
                        adapt_delta=0.99)
     cmd = args.compose_command(0, ''.join([output, '-1.csv']))
     self.assertIn('random seed=12345', cmd)
     self.assertIn('data file=', cmd)
     self.assertIn(
         'algorithm=hmc engine=nuts max_depth=15 adapt delta=0.99', cmd)
Example #4
 def test_args_bad_metric_file_2(self):
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     jmetric = os.path.join(datafiles_path, 'bernoulli.metric.json')
     with self.assertRaises(ValueError):
         args = SamplerArgs(model,
                            chain_ids=[1, 2],
                            metric=[jmetric, jmetric])
Example #5
 def test_args_bad_seed_4(self):
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     output = os.path.join(TMPDIR, 'bernoulli.output')
     with self.assertRaises(ValueError):
         args = SamplerArgs(
             model, chain_ids=[1, 2], output_file=output, seed=4294967299
         )
Example #6
 def test_args_bad_data(self):
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     output = os.path.join(TMPDIR, 'bernoulli.output')
     with self.assertRaises(ValueError):
         args = SamplerArgs(model,
                            chain_ids=[1, 2],
                            output_file=output,
                            data='/no/such/path/to.file')
Example #7
 def test_args_iters_schedule_mismatch(self):
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     with self.assertRaises(ValueError):
         args = SamplerArgs(
             model,
             chain_ids=[1, 2],
             warmup_iters=0,
             warmup_schedule=(0.1, 0.8, 0.1),
         )
Example #8
 def test_args_bad_inits_files_1(self):
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     jinits1 = os.path.join(datafiles_path, 'bernoulli.init_1.json')
     jinits2 = os.path.join(datafiles_path, 'bernoulli.init_2.json')
     jinits3 = os.path.join(datafiles_path, 'bernoulli.init.json')
     with self.assertRaises(ValueError):
         args = SamplerArgs(model,
                            chain_ids=[1, 2],
                            inits=[jinits1, jinits2, jinits3])
Example #9
 def test_args_typical(self):
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
     output = os.path.join(TMPDIR, 'bernoulli.output')
     args = SamplerArgs(model,
                        chain_ids=[1, 2],
                        seed=12345,
                        sampling_iters=100,
                        data=jdata,
                        output_file=output,
                        max_treedepth=11,
                        adapt_delta=0.9)
     cmd = args.compose_command(0, ''.join([output, '-1.csv']))
     self.assertIn('bernoulli', cmd)
     self.assertIn('seed=12345', cmd)
     self.assertIn('num_samples=100', cmd)
     self.assertIn('bernoulli.data.json', cmd)
     self.assertIn('algorithm=hmc engine=nuts max_depth=11 adapt delta=0.9',
                   cmd)
Example #10
 def test_validate_bad_transcript(self):
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
     output = os.path.join(badfiles_path, 'bad-transcript-bern')
     args = SamplerArgs(model, chain_ids=[1, 2, 3, 4],
                        seed=12345,
                        data=jdata,
                        output_file=output,
                        sampling_iters=100,
                        max_treedepth=11,
                        adapt_delta=0.95)
     runset = RunSet(chains=4, args=args)
     with self.assertRaisesRegex(Exception, 'Exception'):
         runset.check_console_msgs()
Example #11
 def test_validate_bad_hdr(self):
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
     output = os.path.join(badfiles_path, 'bad-hdr-bern')
     args = SamplerArgs(model, chain_ids=[1, 2, 3, 4],
                        seed=12345,
                        data=jdata,
                        output_file=output,
                        sampling_iters=100,
                        max_treedepth=11,
                        adapt_delta=0.95)
     runset = RunSet(chains=4, args=args)
     retcodes = runset.retcodes
     for i in range(len(retcodes)):
         runset.set_retcode(i, 0)
     self.assertTrue(runset.check_retcodes())
     with self.assertRaisesRegex(ValueError, 'header mismatch'):
         runset.validate_csv_files()
Example #12
 def test_sample_big(self):
     # construct runset using existing sampler output
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     output = os.path.join(datafiles_path, 'runset-big', 'output_icar_nyc')
     args = SamplerArgs(model, chain_ids=[1, 2], output_file=output)
     runset = RunSet(chains=2, args=args)
     runset.validate_csv_files()
     runset.assemble_sample()
     sampler_state = [
         'lp__',
         'accept_stat__',
         'stepsize__',
         'treedepth__',
         'n_leapfrog__',
         'divergent__',
         'energy__',
     ]
     phis = ['phi.{}'.format(str(x + 1)) for x in range(2095)]
     column_names = sampler_state + phis
     self.assertEqual(runset.columns, len(column_names))
     self.assertEqual(runset.column_names, tuple(column_names))
     self.assertEqual(runset.metric_type, 'diag_e')
     self.assertEqual(runset.stepsize.shape, (2, ))
     self.assertEqual(runset.metric.shape, (2, 2095))
     self.assertEqual((1000, 2, 2102), runset.sample.shape)
     phis = get_drawset(runset, params=['phi'])
     self.assertEqual((2000, 2095), phis.shape)
     phi1 = get_drawset(runset, params=['phi.1'])
     self.assertEqual((2000, 1), phi1.shape)
     mo_phis = get_drawset(runset, params=['phi.1', 'phi.10', 'phi.100'])
     self.assertEqual((2000, 3), mo_phis.shape)
     phi2095 = get_drawset(runset, params=['phi.2095'])
     self.assertEqual((2000, 1), phi2095.shape)
     with self.assertRaises(Exception):
         get_drawset(runset, params=['phi.2096'])
     with self.assertRaises(Exception):
         get_drawset(runset, params=['ph'])
Example #13
    def test_diagnose_divergences(self):
        stan = os.path.join(datafiles_path, 'bernoulli.stan')
        exe = os.path.join(datafiles_path, 'bernoulli')
        model = Model(exe_file=exe, stan_file=stan)
        output = os.path.join(datafiles_path, 'diagnose-good',
                              'corr_gauss_depth8')
        args = SamplerArgs(model, chain_ids=[1], output_file=output)
        runset = RunSet(args=args, chains=1)

        # TODO - use cmdstan test files instead
        expected = ''.join([
            '424 of 1000 (42%) transitions hit the maximum ',
            'treedepth limit of 8, or 2^8 leapfrog steps. ',
            'Trajectories that are prematurely terminated ',
            'due to this limit will result in slow ',
            'exploration and you should increase the ',
            'limit to ensure optimal performance.\n',
        ])

        capturedOutput = io.StringIO()
        sys.stdout = capturedOutput
        diagnose(runset)
        sys.stdout = sys.__stdout__
        self.assertEqual(capturedOutput.getvalue(), expected)
Example #14
 def test_args_iters_4(self):
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     with self.assertRaises(ValueError):
         args = SamplerArgs(model, chain_ids=[1, 2], sampling_iters=-123)
Example #15
 def test_args_chain_ids_bad(self):
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     with self.assertRaisesRegex(ValueError, 'invalid chain_id -99'):
         args = SamplerArgs(model, chain_ids=[7, -99])
Example #16
def sample(
    stan_model: Model,
    data: Union[Dict, str] = None,
    chains: int = 4,
    cores: int = 1,
    seed: Union[int, List[int]] = None,
    chain_ids: Union[int, List[int]] = None,
    inits: Union[Dict, float, str, List[str]] = None,
    warmup_iters: int = None,
    sampling_iters: int = None,
    warmup_schedule: Tuple[float, float, float] = (0.15, 0.75, 0.10),
    save_warmup: bool = False,
    thin: int = None,
    max_treedepth: float = None,
    metric: Union[str, List[str]] = None,
    step_size: Union[float, List[float]] = None,
    adapt_engaged: bool = True,
    adapt_delta: float = None,
    csv_output_file: str = None,
    show_progress: bool = False,
) -> RunSet:
    """
    Run one or more chains of the NUTS sampler to produce a set of draws
    from the posterior distribution of a model conditioned on some data.
    The caller must specify the model and data; all other arguments
    are optional.

    This function validates the specified configuration, composes a call to
    the CmdStan ``sample`` method, spawns one subprocess per chain to run
    the sampler, and waits for all chains to run to completion.
    The composed call to CmdStan omits arguments left unspecified (i.e., value
    is ``None``) so that the default CmdStan configuration values will be used.

    For each chain, the ``RunSet`` object records the command, the return code,
    the paths to the sampler output files, and the corresponding subprocess
    console outputs, if any.

    :param stan_model: Compiled Stan model.

    :param data: Values for all data variables in the model, specified either
        as a dictionary with entries matching the data variables,
        or as the path of a data file in JSON or Rdump format.

    :param chains: Number of sampler chains, should be > 1.

    :param cores: Number of processes to run in parallel. Must be an integer
        between 1 and the number of CPUs in the system.

    :param seed: The seed for the random number generator, or a list of
        per-chain seeds. Must be an integer between 0 and 2^32 - 1.
        If unspecified, numpy.random.RandomState() is used to generate a seed
        which will be used for all chains. When the same seed is used across
        all chains, the chain-id is used to advance the RNG to avoid
        dependent samples.

    :param chain_ids: The offset for the random number generator, either
        an integer or a list of unique per-chain offsets.  If unspecified,
        chain ids are numbered sequentially starting from 1.

    :param inits: Specifies how the sampler initializes parameter values.
        Initialization is either uniform random on a range centered on 0,
        exactly 0, or a dictionary or file of initial values for some or all
        parameters in the model.  The default initialization behavior is to
        initialize all parameter values on the range [-2, 2].  If these values
        are too far from the expected parameter values, explicit initialization
        may improve adaptation. The following value types are allowed:

        * Single number ``n > 0`` - initialization range is [-n, n].
        * ``0`` - all parameters are initialized to 0.
        * dictionary - pairs parameter name : initial value.
        * string - pathname to a JSON or Rdump file of initial parameter values.
        * list of strings - per-chain pathname to data file.

    :param warmup_iters: Number of iterations during warmup for each chain.

    :param sampling_iters: Number of draws from the posterior for each chain.

    :param warmup_schedule: Triple specifying the fraction of total warmup
        iterations allocated to each adaptation phase.  The default schedule
        is (.15, .75, .10) where:

        * Phase I is "fast" adaptation to find the typical set
        * Phase II is "slow" adaptation to find the metric
        * Phase III is "fast" adaptation to find the step_size.

        For further details, see the Stan Reference Manual, section
        HMC algorithm parameters.

    :param save_warmup: When True, sampler saves warmup draws as part of
        the Stan csv output file.

    :param thin: Period between saved samples.

    :param max_treedepth: Maximum depth of trees evaluated by NUTS sampler
        per iteration.

    :param metric: Specification of the mass matrix, either as a
        vector consisting of the diagonal elements of the covariance
        matrix (``diag`` or ``diag_e``) or the full covariance matrix
        (``dense`` or ``dense_e``).

        If the value of the metric argument is a string other than
        ``diag``, ``diag_e``, ``dense``, or ``dense_e``, it must be
        a valid filepath to a JSON or Rdump file which contains an entry
        ``inv_metric`` whose value is either the diagonal vector or
        the full covariance matrix. This can be used to restart sampling
        with no adaptation given the outputs of all chains from a previous run.

        If the value of the metric argument is a list of paths, its
        length must match the number of chains and all paths must be
        unique.

    :param step_size: Initial stepsize for HMC sampler.  The value is either
        a single number or a list of numbers which will be used as the global
        or per-chain initial step_size, respectively.

        The length of the list of step sizes must match the number of chains.
        This feature can be used to restart sampling with no adaptation
        given the outputs of all chains from a previous run.

    :param adapt_engaged: When True, adapt the stepsize and metric.
        *Note: If True, ``warmup_iters`` must be > 0.*

    :param adapt_delta: Adaptation target Metropolis acceptance rate.
        The default value is 0.8.  Increasing this value, which must be
        strictly less than 1, causes adaptation to use smaller step sizes.
        It improves the effective sample size, but may increase the time
        per iteration.

    :param csv_output_file: A path or file name which will be used as the
        base name for the sampler output files.  The csv output files
        for each chain are written to file ``<basename>-<chain_id>.csv``
        and the console output and error messages are written to file
        ``<basename>-<chain_id>.txt``.

    :param show_progress: When True, the command sends progress messages to
        the console. When False, the command executes silently.
    """
    if chains < 1:
        raise ValueError(
            'chains must be a positive integer value, found {}'.format(chains))

    if chain_ids is None:
        chain_ids = [x + 1 for x in range(chains)]
    else:
        if type(chain_ids) is int:
            if chain_ids < 1:
                raise ValueError('chain_id must be a positive integer value,'
                                 ' found {}'.format(chain_ids))
            offset = chain_ids
            chain_ids = [x + offset + 1 for x in range(chains)]
        else:
            if not len(chain_ids) == chains:
                raise ValueError(
                    'chain_ids must correspond to the number of chains:'
                    ' specified {} chains, found {} chain_ids'.format(
                        chains, len(chain_ids)))
            for i in range(len(chain_ids)):
                if chain_ids[i] < 1:
                    raise ValueError(
                        'chain_id must be a positive integer value,'
                        ' found {}'.format(chain_ids[i]))

    if cores < 1:
        raise ValueError(
            'cores must be a positive integer value, found {}'.format(cores))
    if cores > cpu_count():
        print('requested {} cores, only {} available'.format(
            cores, cpu_count()))
        cores = cpu_count()

    if data is not None:
        if isinstance(data, dict):
            with tempfile.NamedTemporaryFile(mode='w+',
                                             suffix='.json',
                                             dir=TMPDIR,
                                             delete=False) as fd:
                data_file = fd.name
                print('input data tempfile: {}'.format(fd.name))
            sd = StanData(data_file)
            sd.write_json(data)
            data_dict = data
            data = data_file

    if inits is not None:
        if isinstance(inits, dict):
            with tempfile.NamedTemporaryFile(mode='w+',
                                             suffix='.json',
                                             dir=TMPDIR,
                                             delete=False) as fd:
                inits_file = fd.name
                print('inits tempfile: {}'.format(fd.name))
            sd = StanData(inits_file)
            sd.write_json(inits)
            inits_dict = inits
            inits = inits_file
        # TODO:  issue 49: inits can be initialization function

    args = SamplerArgs(
        model=stan_model,
        chain_ids=chain_ids,
        data=data,
        seed=seed,
        inits=inits,
        warmup_iters=warmup_iters,
        sampling_iters=sampling_iters,
        warmup_schedule=warmup_schedule,
        save_warmup=save_warmup,
        thin=thin,
        max_treedepth=max_treedepth,
        metric=metric,
        step_size=step_size,
        adapt_engaged=adapt_engaged,
        adapt_delta=adapt_delta,
        output_file=csv_output_file,
    )

    runset = RunSet(args=args, chains=chains)
    tp = ThreadPool(cores)
    try:
        for i in range(chains):
            tp.apply_async(do_sample, (runset, i))
    finally:
        # ensure the pool is shut down and every chain has finished
        tp.close()
        tp.join()
    if not runset.check_retcodes():
        msg = 'Error during sampling'
        for i in range(chains):
            if runset.retcode(i) != 0:
                msg = '{}, chain {} returned error code {}'.format(
                    msg, i, runset.retcode(i))
        raise Exception(msg)
    runset.validate_csv_files()
    return runset
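
A minimal usage sketch for the sample() function above. This is not part of the original examples: it assumes that Model, sample, datafiles_path, and TMPDIR are provided by the same package as the surrounding snippets, that the bernoulli example model has already been compiled, and the argument values are purely illustrative.

import os

# Model, sample, datafiles_path and TMPDIR are assumed to already be in scope,
# exactly as they are in the snippets above; their import paths are not shown
# in these examples.

stan = os.path.join(datafiles_path, 'bernoulli.stan')
exe = os.path.join(datafiles_path, 'bernoulli')
jdata = os.path.join(datafiles_path, 'bernoulli.data.json')
output = os.path.join(TMPDIR, 'bernoulli.output')

bernoulli_model = Model(exe_file=exe, stan_file=stan)
runset = sample(bernoulli_model,
                data=jdata,              # JSON data file for the model
                chains=4,                # one subprocess per chain
                cores=2,                 # at most two chains run in parallel
                seed=12345,              # shared seed; chain ids offset the RNG
                sampling_iters=1000,     # post-warmup draws per chain
                adapt_delta=0.9,         # adaptation target acceptance rate
                csv_output_file=output)  # chain i writes <output>-<i>.csv
# sample() raises if any chain fails, so a successful return means every
# return code was zero and the csv files passed validation.
assert runset.check_retcodes()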
Example #17
 def test_args_missing_args_1(self):
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     with self.assertRaises(Exception):
         args = SamplerArgs()
Example #18
 def test_args_adapt_delta_3(self):
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     with self.assertRaises(ValueError):
         args = SamplerArgs(model, chain_ids=[1, 2], adapt_delta=1.3)
Example #19
 def test_args_step_size_bad_1(self):
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     with self.assertRaises(ValueError):
         args = SamplerArgs(model, chain_ids=[1, 2], step_size=-0.99)
Example #20
 def test_args_missing_args_2(self):
     with self.assertRaises(Exception):
         args = SamplerArgs(model)
Example #21
 def test_args_max_treedepth_bad(self):
     stan = os.path.join(datafiles_path, 'bernoulli.stan')
     exe = os.path.join(datafiles_path, 'bernoulli')
     model = Model(exe_file=exe, stan_file=stan)
     with self.assertRaises(ValueError):
         args = SamplerArgs(model, chain_ids=[1, 2], max_treedepth=-3)