Example #1
0
def location(original_data, samples, transformed_samples, model_fn, new_data):
    """Return the flattened per-draw values of `mu` evaluated on `new_data`.

    If `new_data` compares equal to `original_data` (the data used for
    inference), the already-computed `transformed_samples['mu']` is reused.
    Otherwise the model is re-run for every draw in `samples` against
    `new_data` via `run_model_on_samples_and_data`. Either way the result
    is passed through `flatten`, so the chain dim is not preserved.
    """
    # NOTE(review): assuming the data are dicts of arrays/tensors, `==`
    # compares values pairwise. Passing the very same object is safe, because
    # identical values short-circuit. Two distinct dicts holding multi-element
    # arrays may instead raise an "ambiguous truth value" error -- confirm
    # callers only ever pass the original object or genuinely new data.
    if new_data == original_data:
        # Optimization: `mu` for the inference data is already available.
        mu = transformed_samples['mu']
    else:
        mu = run_model_on_samples_and_data(model_fn, samples, new_data)['mu']
    return flatten(mu)
Example #2
0
    def loc(d):
        """Return the flattened values of `mu` for data `d`.

        The `location` method on `Samples` isn't expected to preserve the
        chain dim, so the result is always passed through `flatten`.
        """
        if d == data:
            # Optimization: for the data used for inference, `mu` has
            # already been computed and lives in `transformed_samples`.
            mu = transformed_samples['mu']
        else:
            # TODO: This computes more than is necessary (i.e. it builds
            # additional tensors we immediately throw away). This is
            # minor, but might be worth addressing eventually.
            mu = run_model_on_samples_and_data(assets.fn, samples, d)['mu']
        return flatten(mu)
Example #3
0
def run_model_on_samples_and_data(modelfn, samples, data):
    """Run `modelfn` once per posterior draw and collect its return values.

    Each draw is conditioned into the model with `poutine.condition`, and
    the conditioned model is called as `modelfn(**data)`.

    :param modelfn: Model function. It is called with the entries of `data`
        as keyword arguments and must return a mapping whose values can be
        combined with `torch.stack`. The keys of the first result are used
        for all results.
    :param samples: Non-empty dict (or dict subclass) mapping site names to
        arrays whose first two dims are `(num_chains, num_samples)`.
    :param data: Dict of keyword arguments passed to the model.
    :return: Dict mapping each name returned by the model to the stacked
        results, reshaped by `unflatten` to restore the chain dim.
    """
    # isinstance (rather than an exact type check) also admits dict subclasses.
    assert isinstance(samples, dict)
    assert len(samples) > 0
    num_chains, num_samples = next(iter(samples.values())).shape[0:2]
    assert all(arr.shape[0:2] == (num_chains, num_samples)
               for arr in samples.values())

    # Flatten the samples for easier iteration. After flattening, index
    # i in range(num_chains * num_samples) addresses a single draw.
    flat_samples = {k: flatten(arr) for k, arr in samples.items()}

    def run(i):
        sample = {k: arr[i] for k, arr in flat_samples.items()}
        return poutine.condition(modelfn, sample)(**data)

    return_values = [run(i) for i in range(num_chains * num_samples)]

    # TODO: It would probably be better to allocate output arrays and
    # fill them as we run the model. However, I'm holding off on
    # making this change since this is good enough, and it's possible
    # this whole approach may be replaced by something vectorized.

    # We know the model structure is static, so names don't change
    # across executions.
    names = return_values[0].keys()
    # Build the output dict, restoring the chain dim.
    return {
        name:
        unflatten(torch.stack([retval[name] for retval in return_values]),
                  num_chains, num_samples)
        for name in names
    }
Example #4
0
File: fit.py Project: pyro-ppl/brmp
    def marginals(self, qs=default_quantiles):
        """Produces a table containing statistics of the marginal
        distributions of the parameters of the fitted model.

        :param qs: A list of quantiles to include in the output.
        :type qs: list
        :return: A table of marginal statistics.
        :rtype: brmp.fit.ArrReprWrapper

        Example::

          fit = brm('y ~ x', df).fit()
          print(fit.marginals())

          #        mean    sd  2.5%   25%   50%   75% 97.5% n_eff r_hat
          # b_x    0.42  0.33 -0.11  0.14  0.48  0.65  0.88  5.18  1.00
          # sigma  0.78  0.28  0.48  0.61  0.68  0.87  1.32  5.28  1.10

        """
        param_names = scalar_parameter_names(self.model_desc)
        # One array of draws per scalar parameter. The second argument is
        # presumably `preserve_chains` -- confirm against get_scalar_param.
        per_param = [self.get_scalar_param(name, True) for name in param_names]
        header = ['mean', 'sd'] + format_quantiles(qs) + ['n_eff', 'r_hat']
        # Parameters are stacked along a new axis 2.
        draws = np.stack(per_param, axis=2)
        # Summary statistics use the flattened draws. The two diagnostics
        # below receive the unflattened draws.
        summary = marginal_stats(flatten(draws), qs)
        n_eff = compute_diag_or_default(effective_sample_size, draws)
        r_hat = compute_diag_or_default(gelman_rubin, draws)
        # Append n_eff and r_hat as trailing columns of the table.
        columns = [summary] + [diag[..., np.newaxis] for diag in (n_eff, r_hat)]
        return ArrReprWrapper(np.hstack(columns), param_names, header)
Example #5
0
def run_model_on_samples_and_data(modelfn, samples, data):
    """Run `modelfn` on every posterior draw in a single vectorized call.

    Each draw is substituted into the model with `handler.substitute`. The
    model is then called with the entries of `data` as keyword arguments
    plus `mode='prior_and_mu'`, and `vmap` maps this over all draws.

    :param modelfn: Model function accepting `**data` and a `mode` keyword.
    :param samples: Non-empty dict (or dict subclass) mapping site names to
        arrays whose first two dims are `(num_chains, num_samples)`.
    :param data: Dict of keyword arguments passed to the model.
    :return: Dict mapping each key of the vmapped output to its array,
        reshaped by `unflatten` to restore the chain dim.
    """
    # isinstance (rather than an exact type check) also admits dict subclasses.
    assert isinstance(samples, dict)
    assert len(samples) > 0
    num_chains, num_samples = next(iter(samples.values())).shape[0:2]
    assert all(arr.shape[0:2] == (num_chains, num_samples)
               for arr in samples.values())
    # Drop the chain dim so vmap maps over one flat axis of draws.
    flat_samples = {k: flatten(arr) for k, arr in samples.items()}

    def run_one(sample):
        # NOTE(review): 'prior_and_mu' is presumably a mode understood by the
        # generated model (returning prior sites and `mu`) -- confirm.
        return handler.substitute(modelfn, sample)(**data, mode='prior_and_mu')

    out = vmap(run_one)(flat_samples)
    # Restore chain dim.
    return {
        k: unflatten(arr, num_chains, num_samples)
        for k, arr in out.items()
    }
Example #6
0
 def get_param(name, preserve_chains):
     """Look up `name` in `all_samples`, flattening unless chains are kept."""
     value = all_samples[name]
     if not preserve_chains:
         value = flatten(value)
     return value
Example #7
0
def get_param(samples, name, preserve_chains):
    """Fetch the draws stored under `name` in `samples`.

    :param samples: Dict (or dict subclass) mapping parameter names to
        arrays of draws.
    :param name: Parameter name. Must not be ``'mu'``; use `location` to
        fetch `mu` instead.
    :param preserve_chains: If true, return the stored value unchanged;
        otherwise pass it through `flatten` (presumably merging away the
        chain dim -- confirm against `flatten`).
    :return: The stored value, flattened unless `preserve_chains` is true.
    :raises AssertionError: If `samples` is not a dict or `name` is ``'mu'``.
    """
    # isinstance (rather than an exact type check) also admits dict subclasses.
    assert isinstance(samples, dict)
    # Reminder to use correct interface.
    assert name != 'mu', 'Use `location` to fetch `mu`.'
    param = samples[name]
    return param if preserve_chains else flatten(param)