def location(original_data, samples, transformed_samples, model_fn, new_data):
    # Optimization: For the data used for inference, values for `mu`
    # are already computed and available from `transformed_samples`.
    if new_data == original_data:
        return flatten(transformed_samples['mu'])
    else:
        return flatten(
            run_model_on_samples_and_data(model_fn, samples, new_data)['mu'])
def loc(d):
    # Optimization: For the data used for inference, values for
    # `mu` are already computed and available from
    # `transformed_samples`.
    # The `location` method on `Samples` isn't expected to
    # preserve the chain dim, so we flatten here.
    if d == data:
        return flatten(transformed_samples['mu'])
    else:
        # TODO: This computes more than is necessary. (i.e. it
        # builds additional tensors that we immediately throw away.)
        # This is minor, but might be worth addressing eventually.
        return flatten(
            run_model_on_samples_and_data(assets.fn, samples, d)['mu'])
import torch
import pyro.poutine as poutine


def run_model_on_samples_and_data(modelfn, samples, data):
    assert isinstance(samples, dict)
    assert len(samples) > 0
    num_chains, num_samples = next(iter(samples.values())).shape[0:2]
    assert all(arr.shape[0:2] == (num_chains, num_samples)
               for arr in samples.values())
    # Flatten the samples for easier iteration.
    flat_samples = {k: flatten(arr) for k, arr in samples.items()}

    def run(i):
        sample = {k: arr[i] for k, arr in flat_samples.items()}
        return poutine.condition(modelfn, sample)(**data)

    return_values = [run(i) for i in range(num_chains * num_samples)]
    # TODO: It would probably be better to allocate output arrays and
    # fill them as we run the model. However, I'm holding off on
    # making this change since this is good enough, and it's possible
    # this whole approach may be replaced by something vectorized.

    # We know the model structure is static, so names don't change
    # across executions.
    names = return_values[0].keys()
    # Build the output dict., restoring the chain dim.
    return {
        name: unflatten(
            torch.stack([retval[name] for retval in return_values]),
            num_chains, num_samples)
        for name in names
    }
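# The `flatten` and `unflatten` helpers used above (and below) aren't
# shown in this section. A minimal sketch of what they're assumed to
# do, inferred from their call sites: merge the leading
# (num_chains, num_samples) dims into a single dim, and split it back
# out. These work on anything with a numpy-style `reshape`, e.g. torch
# tensors or numpy arrays.

def flatten(arr):
    # (num_chains, num_samples, ...) -> (num_chains * num_samples, ...)
    return arr.reshape(-1, *arr.shape[2:])


def unflatten(arr, num_chains, num_samples):
    # (num_chains * num_samples, ...) -> (num_chains, num_samples, ...)
    return arr.reshape(num_chains, num_samples, *arr.shape[2:])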
def marginals(self, qs=default_quantiles):
    """Produces a table containing statistics of the marginal
    distributions of the parameters of the fitted model.

    :param qs: A list of quantiles to include in the output.
    :type qs: list
    :return: A table of marginal statistics.
    :rtype: brmp.fit.ArrReprWrapper

    Example::

        fit = brm('y ~ x', df).fit()
        print(fit.marginals())

        #          mean     sd   2.5%    25%    50%    75%  97.5%  n_eff  r_hat
        # b_x      0.42   0.33  -0.11   0.14   0.48   0.65   0.88   5.18   1.00
        # sigma    0.78   0.28   0.48   0.61   0.68   0.87   1.32   5.28   1.10

    """
    names = scalar_parameter_names(self.model_desc)
    vecs = [self.get_scalar_param(name, True) for name in names]
    col_labels = ['mean', 'sd'] + format_quantiles(qs) + ['n_eff', 'r_hat']
    samples = np.stack(vecs, axis=2)
    stats_arr = marginal_stats(flatten(samples), qs)
    n_eff = compute_diag_or_default(effective_sample_size, samples)
    r_hat = compute_diag_or_default(gelman_rubin, samples)
    arr = np.hstack(
        [stats_arr, n_eff[..., np.newaxis], r_hat[..., np.newaxis]])
    return ArrReprWrapper(arr, names, col_labels)
from jax import vmap
import numpyro.handlers as handler


def run_model_on_samples_and_data(modelfn, samples, data):
    assert isinstance(samples, dict)
    assert len(samples) > 0
    num_chains, num_samples = next(iter(samples.values())).shape[0:2]
    assert all(arr.shape[0:2] == (num_chains, num_samples)
               for arr in samples.values())
    flat_samples = {k: flatten(arr) for k, arr in samples.items()}
    out = vmap(lambda sample: handler.substitute(modelfn, sample)
               (**data, mode='prior_and_mu'))(flat_samples)
    # Restore the chain dim.
    return {
        k: unflatten(arr, num_chains, num_samples)
        for k, arr in out.items()
    }
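# For orientation, a hedged usage sketch exercising the Pyro variant
# of `run_model_on_samples_and_data` (the first definition above).
# The toy model, site names, and shapes here are illustrative
# assumptions, not taken from the source; the only contract relied on
# is that the model returns a dict of values keyed by name, including
# `mu`.

import pyro
import pyro.distributions as dist


def model_fn(X, y=None):
    b = pyro.sample('b', dist.Normal(torch.zeros(2), 1.).to_event(1))
    sigma = pyro.sample('sigma', dist.HalfNormal(1.))
    mu = X @ b
    pyro.sample('y', dist.Normal(mu, sigma), obs=y)
    return {'mu': mu, 'b': b, 'sigma': sigma}


samples = {
    'b': torch.randn(2, 5, 2),   # (num_chains, num_samples, ...)
    'sigma': torch.rand(2, 5),
}
data = {'X': torch.randn(50, 2), 'y': torch.randn(50)}

out = run_model_on_samples_and_data(model_fn, samples, data)
assert out['mu'].shape == (2, 5, 50)  # chain dim restored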
def get_param(name, preserve_chains):
    param = all_samples[name]
    return param if preserve_chains else flatten(param)
def get_param(samples, name, preserve_chains):
    assert isinstance(samples, dict)
    # Reminder to use the correct interface.
    assert name != 'mu', 'Use `location` to fetch `mu`.'
    param = samples[name]
    return param if preserve_chains else flatten(param)
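# A quick usage sketch for `get_param`; the samples dict and shapes
# are hypothetical.

samples = {'sigma': torch.rand(2, 5)}  # 2 chains, 5 draws

assert get_param(samples, 'sigma', preserve_chains=True).shape == (2, 5)
assert get_param(samples, 'sigma', preserve_chains=False).shape == (10,)

# Asking for `mu` trips the assertion, since `location` is the
# intended interface for that:
# get_param(samples, 'mu', False)  # AssertionError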