Exemplo n.º 1
0
        def sample_chains_arviz(self,
                                n_sample,
                                init_states,
                                chain_var_funcs=None,
                                n_process=1):
            """Run Markov chains and return them as `arviz.InferenceData`.

            All arguments are forwarded unchanged to `sample_chains`; the two
            dictionaries it returns (chain variable samples and transition
            statistics) are each converted to a dataset and wrapped in a
            single `arviz.InferenceData` container.

            Args:
                n_sample (int): Number of samples to draw per chain.
                init_states (Iterable[HamiltonianState] or Iterable[array]):
                    Initial chain states, one chain being run per entry. An
                    entry is either a `HamiltonianState` instance or an array
                    giving the position component of the state. When an array
                    is given, or the `mom` attribute of a state is not set,
                    the momentum component is sampled independently from its
                    conditional distribution.
                chain_var_funcs (dict[str, callable]): Functions computing
                    the chain variables to record at each iteration. Each
                    function is passed the current state and returns an array
                    of the variable(s) to store; the dictionary keys index
                    the recorded arrays in the returned data. If `None` (the
                    default) a single function returning the position
                    component of the state is used.
                n_process (int or None): Number of parallel processes to run
                    the chains over. With a value of one the chains are run
                    sequentially, otherwise a `multiprocessing.Pool` is used
                    to dynamically assign chains across processes. `None`
                    defaults to the output of `os.cpu_count()`.

            Returns:
                arviz.InferenceData:
                    Container with groups `posterior` and `sample_stats`,
                    both `xarray.Dataset` instances. `posterior` holds the
                    chain variable samples computed by the `chain_var_funcs`
                    entries, under the same keys. `sample_stats` holds the
                    statistics of the integration transition, such as the
                    acceptance probabilities and number of integrator steps.
            """
            traces, stats = self.sample_chains(
                n_sample, init_states, chain_var_funcs, n_process)
            # Convert in the same order as the keyword arguments below:
            # posterior samples first, then the transition statistics.
            posterior_group = arviz.dict_to_dataset(traces, library=hmc)
            stats_group = arviz.dict_to_dataset(stats, library=hmc)
            return arviz.InferenceData(
                posterior=posterior_group, sample_stats=stats_group)
Exemplo n.º 2
0
    def convert_to_arviz_inference_data(
            traces, chain_stats, sample_stats_key=None):
        """Wrap chain outputs in an `arviz.InferenceData` container object.

        The `traces` and `chain_stats` arguments should correspond to a
        multiple-chain sampler output, i.e. the values returned from a
        `sample_chains` call. Neither argument is modified.

        Args:
            traces (Dict[str, List[array]]): Trace arrays, with one entry per
                function in `trace_funcs` passed to the sampler method. Each
                entry consists of a list of arrays, one per chain, with the
                first axis of each array corresponding to the sampling (draw)
                index.
            chain_stats (Dict[str, List[array]] or Dict[str, Dict]): Chain
                transition statistics. Either a single dictionary mapping
                statistic names to lists of arrays (one array per chain, first
                axis the sampling index), or a dictionary of such dictionaries
                keyed by transition.
            sample_stats_key (str): Optional. Key of the transition in
                `chain_stats` whose recorded statistics populate the
                `sample_stats` group of the returned `InferenceData` object.

        Returns:
            arviz.InferenceData:
                An arviz data container with groups `posterior` and
                `sample_stats`, both instances of `xarray.Dataset`. The
                `posterior` group corresponds to the chain variable traces
                provided in the `traces` argument and the `sample_stats`
                group to the chain transition statistics selected from the
                `chain_stats` argument.

        Raises:
            ValueError: If `sample_stats_key` is given but is not a key of
                `chain_stats`, or if it is not given and `chain_stats` holds
                more than one transition statistics dictionary.
            StopIteration: If `sample_stats_key` is not given and
                `chain_stats` is empty.
        """
        if sample_stats_key is not None:
            if sample_stats_key not in chain_stats:
                raise ValueError(
                    f'Specified `sample_stats_key` ({sample_stats_key}) does '
                    'not match any transition in `chain_stats`.')
            sample_stats = chain_stats[sample_stats_key]
        else:
            first_entry = next(iter(chain_stats.values()))
            if not isinstance(first_entry, dict):
                # Values are not themselves dictionaries, therefore assume
                # chain_stats holds the statistics of a single transition.
                sample_stats = chain_stats
            elif len(chain_stats) == 1:
                # Single transition statistics dictionary, so the choice is
                # unambiguous. Read it without `popitem()` so the caller's
                # dictionary is left intact.
                sample_stats = first_entry
            else:
                raise ValueError(
                    '`sample_stats_key` must be specified as `chain_stats` '
                    'contains multiple transtitiion statistics dictionaries.')
        return arviz.InferenceData(
            posterior=arviz.dict_to_dataset(traces, library=mici),
            sample_stats=arviz.dict_to_dataset(sample_stats, library=mici))
Exemplo n.º 3
0
 def to_arviz(self, burn_in=0, thin=0):
     """Export posterior samples as an ArviZ dataset for visualisation.

     The samples are obtained from `self.get_results` and reshaped, one
     parameter at a time, into a matrix with one row per `chain` value and
     one column per `iter` value.

     Args:
         burn_in: Forwarded to `self.get_results` as `burn_in`.
         thin: Forwarded to `self.get_results` as `thin`.

     Returns:
         The object produced by `arviz.dict_to_dataset`, holding one
         variable per name in `self.par_names`. Note this is the dataset
         itself, not an `arviz.InferenceData` container.
     """
     import arviz as az
     # NOTE(review): presumably a pandas DataFrame with `chain` and `iter`
     # columns plus one column per parameter - confirm against get_results.
     samples = self.get_results(burn_in=burn_in, thin=thin)
     par_dict = {
         name: samples.pivot_table(
             values=name, columns='iter', index='chain').values
         for name in self.par_names
     }
     return az.dict_to_dataset(par_dict)
Exemplo n.º 4
0
        def sample_chains_arviz(self,
                                n_sample,
                                init_states,
                                chain_var_funcs=None,
                                sample_stats_key=None,
                                **kwargs):
            """Run Markov chains and return them as `arviz.InferenceData`.

            The sampling itself is delegated to `sample_chains`; the chain
            variable samples it returns are wrapped in the `posterior` group
            of an `arviz.InferenceData` container and, when a transition is
            selected with `sample_stats_key`, that transition's statistics
            are added as the `sample_stats` group.

            Args:
                n_sample (int): Number of samples to draw per chain.
                init_states (Iterable[ChainState] or Iterable[array]):
                    Initial chain states, one chain being run per entry.
                    Each entry is either an array specifying the state or a
                    `ChainState` instance.
                chain_var_funcs (dict[str, callable]): Functions computing
                    the chain variables to record at each iteration. Each
                    function is passed the current state and returns an array
                    of the variable(s) to store; the dictionary keys index
                    the recorded arrays in the returned data.
                sample_stats_key (str): Key of the transition whose recorded
                    statistics populate the `sample_stats` group of the
                    returned `InferenceData` object. If `None` (the default)
                    no `sample_stats` group is created.

            Kwargs:
                n_process (int or None): Number of parallel processes to run
                    the chains over. With a value of one the chains are run
                    sequentially, otherwise a `multiprocessing.Pool` is used
                    to dynamically assign chains across processes. `None`
                    defaults to the output of `os.cpu_count()`.
                memmap_enabled (bool): Whether to memory-map the arrays used
                    to store chain data to files on disk, avoiding excessive
                    system memory usage for long chains and/or large chain
                    states. The data is written to `.npy` files in the
                    directory given by `memmap_path` (or a temporary
                    directory if not provided). These files persist after
                    the function returns so should be deleted manually when
                    no longer required.
                memmap_path (str): Directory to write memory-mapped chain
                    data to. If not provided, a temporary directory is
                    created and the chain data written to files there.

            Returns:
                arviz.InferenceData:
                    Container with a `posterior` group holding the chain
                    variable samples computed by the `chain_var_funcs`
                    entries, under the same keys. If `sample_stats_key` is
                    given, a `sample_stats` group holding the statistics of
                    that transition is also present. Groups are
                    `xarray.Dataset` instances.

            Raises:
                ValueError: If `sample_stats_key` is given but is not in
                    `self.transitions`.
            """
            key_given = sample_stats_key is not None
            # Validate before sampling so a bad key fails fast.
            if key_given and sample_stats_key not in self.transitions:
                raise ValueError(
                    f'Specified `sample_stats_key` ({sample_stats_key}) does '
                    f'not match any transition.')
            traces, stats = self.sample_chains(
                n_sample, init_states, chain_var_funcs, **kwargs)
            groups = {'posterior': arviz.dict_to_dataset(traces, library=hmc)}
            if key_given:
                groups['sample_stats'] = arviz.dict_to_dataset(
                    stats[sample_stats_key], library=hmc)
            return arviz.InferenceData(**groups)
Exemplo n.º 5
0
        def sample_chains_arviz(self, n_sample, init_states,
                                chain_var_funcs=None, **kwargs):
            """Run Markov chains and return them as `arviz.InferenceData`.

            The positional arguments and any extra keyword arguments are
            forwarded to `sample_chains`; the two dictionaries it returns
            (chain variable samples and transition statistics) are each
            converted to a dataset and wrapped in a single
            `arviz.InferenceData` container.

            Args:
                n_sample (int): Number of samples to draw per chain.
                init_states (Iterable[HamiltonianState] or Iterable[array]):
                    Initial chain states, one chain being run per entry. An
                    entry is either a `HamiltonianState` instance or an array
                    giving the position component of the state. When an array
                    is given, or the `mom` attribute of a state is not set,
                    the momentum component is sampled independently from its
                    conditional distribution.
                chain_var_funcs (dict[str, callable]): Functions computing
                    the chain variables to record at each iteration. Each
                    function is passed the current state and returns an array
                    of the variable(s) to store; the dictionary keys index
                    the recorded arrays in the returned data. If `None` (the
                    default) a single function returning the position
                    component of the state is used.

            Kwargs:
                n_process (int or None): Number of parallel processes to run
                    the chains over. With a value of one the chains are run
                    sequentially, otherwise a `multiprocessing.Pool` is used
                    to dynamically assign chains across processes. `None`
                    defaults to the output of `os.cpu_count()`.
                memmap_enabled (bool): Whether to memory-map the arrays used
                    to store chain data to files on disk, avoiding excessive
                    system memory usage for long chains and/or large chain
                    states. The data is written to `.npy` files in the
                    directory given by `memmap_path` (or a temporary
                    directory if not provided). These files persist after
                    the function returns so should be deleted manually when
                    no longer required.
                memmap_path (str): Directory to write memory-mapped chain
                    data to. If not provided, a temporary directory is
                    created and the chain data written to files there.

            Returns:
                arviz.InferenceData:
                    Container with groups `posterior` and `sample_stats`,
                    both `xarray.Dataset` instances. `posterior` holds the
                    chain variable samples computed by the `chain_var_funcs`
                    entries, under the same keys. `sample_stats` holds the
                    statistics of the integration transition, such as the
                    acceptance probabilities and number of integrator steps.
            """
            traces, stats = self.sample_chains(
                n_sample, init_states, chain_var_funcs, **kwargs)
            # Group names map directly onto InferenceData keyword arguments.
            groups = {
                'posterior': arviz.dict_to_dataset(traces, library=hmc),
                'sample_stats': arviz.dict_to_dataset(stats, library=hmc),
            }
            return arviz.InferenceData(**groups)