def _setup_prototype(self, *args, **kwargs):
    """Run the model once to build a prototype trace and collect plate metadata.

    ``initialize_model`` is called under ``handlers.block()`` so the sites it
    creates are hidden from outer handlers. Its results are stored on ``self``:
    the initial locations (``init_params[0]``), the postprocess function and
    the prototype trace. The trace is then scanned to fill:

    * ``self._prototype_frames``: frame name -> frame, for every frame in the
      ``cond_indep_stack`` of a ``sample`` site;
    * ``self._prototype_frame_full_sizes``: plate site name ->
      ``site["args"][0]`` for every ``plate`` site. ``self._prototype_plate_sizes``
      is kept as an alias of the same dict.

    :param args: positional arguments forwarded to the model.
    :param kwargs: keyword arguments forwarded to the model.
    """
    rng_key = numpyro.prng_key()
    with handlers.block():
        (
            init_params,
            _,
            self._postprocess_fn,
            self.prototype_trace,
        ) = initialize_model(
            rng_key,
            self.model,
            init_strategy=self.init_loc_fn,
            dynamic_args=False,
            model_args=args,
            model_kwargs=kwargs,
        )
    self._init_locs = init_params[0]

    self._prototype_frames = {}
    # The loop below writes plate sizes into `_prototype_frame_full_sizes`, but
    # previously only `_prototype_plate_sizes` was initialised here. Unless the
    # attribute was created elsewhere, that raised AttributeError; if it was,
    # entries from earlier calls leaked in. Initialise the dict that is
    # actually written, and keep the old name as an alias of the same object.
    self._prototype_frame_full_sizes = {}
    self._prototype_plate_sizes = self._prototype_frame_full_sizes
    for name, site in self.prototype_trace.items():
        if site["type"] == "sample":
            for frame in site["cond_indep_stack"]:
                self._prototype_frames[frame.name] = frame
        elif site["type"] == "plate":
            self._prototype_frame_full_sizes[name] = site["args"][0]
def log_likelihood(params_flat, subsample_indices=None):
    """Per-plate log-likelihood of the observed sites of ``model``.

    Closure: reads ``subsample_plate_sizes``, ``unravel_fn``,
    ``prototype_trace``, ``model``, ``model_args`` and ``model_kwargs`` from
    the enclosing scope.

    :param params_flat: flat parameter vector. It is unraveled with
        ``unravel_fn``, and each value is passed through
        ``biject_to(prototype_trace[name]["fn"].support)``.
    :param subsample_indices: optional mapping plate name -> indices that is
        substituted into the model. Defaults to ``jnp.arange(v[0])`` for each
        ``(k, v)`` in ``subsample_plate_sizes``.
    :return: dict mapping frame name -> sum over all observed sample sites in
        that frame of ``log_prob`` reduced over every dim except ``frame.dim``.
    """
    if subsample_indices is None:
        subsample_indices = {
            k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()
        }
    params = unravel_fn(params_flat)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        params = {
            name: biject_to(prototype_trace[name]["fn"].support)(value)
            for name, value in params.items()
        }
        with block(), trace() as tr, substitute(
            data=subsample_indices
        ), substitute(data=params):
            model(*model_args, **model_kwargs)

    log_lik = {}
    for site in tr.values():
        if site["type"] != "sample" or not site["is_observed"]:
            continue
        frames = site["cond_indep_stack"]
        if not frames:
            continue
        # log_prob does not depend on the frame: compute it once per site
        # instead of once per enclosing frame.
        site_log_prob = site["fn"].log_prob(site["value"])
        for frame in frames:
            frame_log_lik = _sum_all_except_at_dim(site_log_prob, frame.dim)
            if frame.name in log_lik:
                log_lik[frame.name] += frame_log_lik
            else:
                log_lik[frame.name] = frame_log_lik
    return log_lik
def body_fn(state):
    """One attempt at drawing initial parameters and checking they are usable.

    Closure: uses ``model``, ``model_args``, ``model_kwargs``,
    ``init_strategy`` and ``param_as_improper`` from the enclosing scope.

    :param state: tuple ``(i, key, _, _)``; only the counter and key are read.
    :return: ``(i + 1, key, params, is_valid)``. ``params`` are the values
        produced by ``transform_fn(..., invert=True)``. ``is_valid`` is true
        when the potential energy and every entry of its gradient are finite.
    """
    i, key, _, _ = state
    key, subkey = random.split(key)

    # Wrap model in a `substitute` handler to initialize from `init_strategy`.
    # Use `block` to not record sample primitives in `init_strategy`.
    seeded_model = substitute(model, substitute_fn=block(seed(init_strategy, subkey)))
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    constrained_values, inv_transforms = {}, {}
    for k, v in model_trace.items():
        if v['type'] == 'sample' and not v['is_observed']:
            if v['intermediates']:
                # transformed distribution: use the first intermediate value
                # and the base distribution's support
                constrained_values[k] = v['intermediates'][0][0]
                inv_transforms[k] = biject_to(v['fn'].base_dist.support)
            else:
                constrained_values[k] = v['value']
                inv_transforms[k] = biject_to(v['fn'].support)
        elif v['type'] == 'param' and param_as_improper:
            # NOTE: `pop` removes 'constraint' from this trace site's kwargs
            constraint = v['kwargs'].pop('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                # keep only the first part of a composed transform
                base_transform = transform.parts[0]
                inv_transforms[k] = base_transform
                constrained_values[k] = base_transform(transform.inv(v['value']))
            else:
                inv_transforms[k] = transform
                constrained_values[k] = v['value']
    params = transform_fn(inv_transforms,
                          {k: v for k, v in constrained_values.items()},
                          invert=True)
    potential_fn = jax.partial(potential_energy, model, inv_transforms, model_args, model_kwargs)
    pe, param_grads = value_and_grad(potential_fn)(params)
    z_grad = ravel_pytree(param_grads)[0]
    # a draw is accepted only if both energy and gradient are finite
    is_valid = np.isfinite(pe) & np.all(np.isfinite(z_grad))
    return i + 1, key, params, is_valid
def _setup_prototype(self, *args, **kwargs):
    """Prepare the flattened latent representation used by this guide.

    Runs the parent ``_setup_prototype``. It then obtains initial parameters
    from ``find_valid_initial_params`` (hidden by ``handlers.block``) and
    records ``biject_to`` of the relevant support for each unobserved sample
    site of ``self.prototype_trace``. Finally it ravels the initial parameters
    into ``self._init_latent`` / ``self._unpack_latent``.

    If ``self.base_dist`` is ``None``, it becomes an ``Independent`` standard
    ``Normal`` of size ``self.latent_size``.

    :raises RuntimeError: if the model has no latent variables.
    """
    super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
    rng_key = numpyro.sample("_{}_rng_key_init".format(self.prefix), dist.PRNGIdentity())
    init_params, _ = handlers.block(find_valid_initial_params)(
        rng_key, self.model,
        init_strategy=self.init_strategy,
        model_args=args,
        model_kwargs=kwargs)

    self._inv_transforms = {}
    self._has_transformed_dist = False
    for name, site in self.prototype_trace.items():
        if site['type'] == 'sample' and not site['is_observed']:
            if site['intermediates']:
                # transformed distribution: transform defined on the base support
                self._inv_transforms[name] = biject_to(site['fn'].base_dist.support)
                self._has_transformed_dist = True
            else:
                self._inv_transforms[name] = biject_to(site['fn'].support)
    # (The previously computed `unconstrained_sites` dict was never used and
    # has been dropped.)

    self._init_latent, self._unpack_latent = ravel_pytree(init_params)
    self.latent_size = np.size(self._init_latent)
    # Reject an empty latent space before building a zero-sized base_dist.
    if self.latent_size == 0:
        raise RuntimeError(
            '{} found no latent variables; Use an empty guide instead'.
            format(type(self).__name__))
    if self.base_dist is None:
        self.base_dist = dist.Independent(
            dist.Normal(np.zeros(self.latent_size), 1.), 1)
def body_fn(wrapped_carry, x, prefix=None):
    """One step of the enumerated scan.

    Closure: uses ``f``, ``substitute_stack``, ``length`` and the enclosing
    variable ``carry_shape_at_t1`` (rebound via ``nonlocal``).

    :param wrapped_carry: ``(i, rng_key, carry)``; ``rng_key`` may be ``None``.
    :param x: per-step input passed to ``f``.
    :param prefix: unused here.
    :return: ``((i + 1, rng_key, new_carry), (PytreeTrace(trace), y))``.
    """
    i, rng_key, carry = wrapped_carry
    # the first step is only recognised when `i` is concrete (not traced)
    init = True if (not_jax_tracer(i) and i == 0) else False
    rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

    seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
    for subs_type, subs_map in substitute_stack:
        subs_fn = partial(_subs_wrapper, subs_map, i, length)
        if subs_type == 'condition':
            seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
        elif subs_type == 'substitute':
            seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

    if init:
        # first step: sites are renamed with the "_init" scope and no trace is kept
        with handlers.scope(prefix="_init"):
            new_carry, y = seeded_fn(carry, x)
            trace = {}
    else:
        with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov():
            # Like scan_wrapper, we collect the trace of scan's transition function
            # `seeded_fn` here. To put time dimension to the correct position, we need to
            # promote shapes to make `fn` and `value`
            # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
            # and value's batch_shape is (3,), then we promote shape of
            # value so that its batch shape is (1, 3)).
            new_carry, y = config_enumerate(seeded_fn)(carry, x)

        # store shape of new_carry in the enclosing scope's `carry_shape_at_t1`
        nonlocal carry_shape_at_t1
        carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
        # make new_carry have the same shape as carry
        # FIXME: is this rigorous?
        new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)),
                                  new_carry, carry)
    return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)
def _setup_prototype(self, *args, **kwargs):
    """Prepare the flattened latent representation used by this guide.

    Runs the parent ``_setup_prototype``, obtains initial parameters from
    ``find_valid_initial_params`` seeded with ``self._init_rng``, records
    ``biject_to`` of the relevant support for each unobserved sample site of
    ``self.prototype_trace``, and ravels the initial parameters into
    ``self._init_latent`` / ``self.unpack_latent``.

    :raises RuntimeError: if the model has no latent variables.
    """
    super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
    # FIXME: without block statement, get AssertionError: all sites must have unique names
    init_params, _ = block(find_valid_initial_params)(self._init_rng, self.model, *args,
                                                      init_strategy=self.init_strategy,
                                                      **kwargs)
    self._inv_transforms = {}
    self._has_transformed_dist = False
    for name, site in self.prototype_trace.items():
        if site['type'] == 'sample' and not site['is_observed']:
            if site['intermediates']:
                # transformed distribution: transform defined on the base support
                self._inv_transforms[name] = biject_to(site['fn'].base_dist.support)
                self._has_transformed_dist = True
            else:
                self._inv_transforms[name] = biject_to(site['fn'].support)
    # (The previously computed `unconstrained_sites` dict was never read and
    # has been dropped, as has the unused `is_valid` flag.)

    self._init_latent, self.unpack_latent = ravel_pytree(init_params)
    self.latent_size = np.size(self._init_latent)
    if self.latent_size == 0:
        raise RuntimeError(
            '{} found no latent variables; Use an empty guide instead'.
            format(type(self).__name__))
def body_fn(wrapped_carry, x):
    """Run ``f`` for one scan step and capture the sites it records.

    Closure: uses ``f``, ``substitute_stack`` and ``length``.

    :param wrapped_carry: ``(step, key, state)``; ``key`` may be ``None``.
    :param x: the step's input slice.
    :return: ``((step + 1, key, new_state), (PytreeTrace(step_trace), y))``.
    """
    step, key, state = wrapped_carry
    subkey = None
    if key is not None:
        key, subkey = random.split(key)

    with handlers.block():
        # Attach the step index to each site's infer config so the potential
        # energy's unconstrained messenger only transforms the item at `step`.
        step_fn = handlers.infer_config(
            f, config_fn=lambda msg: {"_scan_current_index": step})
        wrapped_fn = step_fn if subkey is None else handlers.seed(step_fn, subkey)
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, step, length)
            if subs_type == "condition":
                wrapped_fn = handlers.condition(wrapped_fn, condition_fn=subs_fn)
            elif subs_type == "substitute":
                wrapped_fn = handlers.substitute(wrapped_fn, substitute_fn=subs_fn)
        with handlers.trace() as step_trace:
            state, y = wrapped_fn(state, x)

    return (step + 1, key, state), (PytreeTrace(step_trace), y)
def _setup_prototype(self, *args, **kwargs):
    """Trace one seeded, blocked run of the model and keep its inputs.

    Sets ``self.prototype_trace``, ``self._args`` and ``self._kwargs``.
    """
    # draw a key under a prefix-specific site name, then trace a seeded run
    setup_key = numpyro.sample("_{}_rng_key_setup".format(self.prefix), dist.PRNGIdentity())
    seeded_model = handlers.seed(self.model, setup_key)
    get_trace = handlers.trace(seeded_model).get_trace
    # `block` keeps the prototype run's sites from reaching outer handlers
    self.prototype_trace = handlers.block(get_trace)(*args, **kwargs)
    self._args = args
    self._kwargs = kwargs
def log_prior(params):
    """Evaluate ``log_density`` of ``model`` at ``params`` with empty subsamples.

    Every plate listed in ``subsample_plate_sizes`` is given an empty int32
    index array through ``substitute``. NOTE(review): presumably this drops
    the subsampled (likelihood) terms so only the prior remains — confirm.
    ``UserWarning``s raised meanwhile are ignored.

    Closure: uses ``model``, ``model_args``, ``model_kwargs`` and
    ``subsample_plate_sizes`` from the enclosing scope.

    :param params: mapping of site name -> value passed to ``log_density``.
    :return: the first element returned by ``log_density``.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        dummy_subsample = {
            k: jnp.array([], dtype=jnp.int32) for k in subsample_plate_sizes
        }
        with block(), substitute(data=dummy_subsample):
            # log_density takes (model, model_args, model_kwargs, params); the
            # previous call omitted model_args, shifting every argument.
            prior_prob, _ = log_density(model, model_args, model_kwargs, params)
    return prior_prob
def __init__(self, rng, model, get_params_fn, prefix="auto", init_loc_fn=init_to_median):
    """Wrap ``model`` so initial values come from ``init_loc_fn``.

    :param rng: key used to seed ``init_loc_fn``.
    :param model: the model to wrap.
    :param get_params_fn: forwarded to the parent constructor.
    :param prefix: forwarded to the parent constructor.
    :param init_loc_fn: used as ``substitute_fn`` (seeded, and wrapped in
        ``block`` so its own sample statements are not recorded).
    """
    seeded_init_fn = seed(init_loc_fn, rng)
    init_model = substitute(model, substitute_fn=block(seeded_init_fn))
    super(AutoContinuous, self).__init__(init_model, get_params_fn, prefix=prefix)
def find_params(self, rng_keys, *args, **kwargs):
    """Compute initial parameter values for every latent site of the prototype.

    Sets ``self._param_map`` (site name -> ``(prefixed name, value, support)``)
    and ``self._init_params`` (prefixed name -> ``(value, support)``), where
    ``value`` is ``biject_to(support)`` applied to the value from
    ``find_valid_initial_params``.
    """
    init_params, _ = handlers.block(find_valid_initial_params)(
        rng_keys, self.model,
        init_strategy=self.init_strategy,
        model_args=args,
        model_kwargs=kwargs)

    param_map = {}
    for name, site in self.prototype_trace.items():
        if site['type'] != 'sample' or site['is_observed']:
            continue
        support = site['fn'].support
        param_map[name] = (
            "{}_{}".format(self.prefix, name),
            biject_to(support)(init_params[name]),
            support,
        )
    self._param_map = param_map
    self._init_params = {
        prefixed: (val, constr) for prefixed, val, constr in param_map.values()
    }
def body_fn(wrapped_carry, x):
    """Run ``f`` for one scan step under ``block`` and record its trace.

    Closure: uses ``f``, ``substitute_stack`` and ``length``.

    :param wrapped_carry: ``(step, key, state)``; ``key`` may be ``None``.
    :param x: the step's input slice.
    :return: ``((step + 1, key, new_state), (PytreeTrace(step_trace), y))``.
    """
    step, key, state = wrapped_carry
    subkey = None
    if key is not None:
        key, subkey = random.split(key)

    with handlers.block():
        step_fn = f if subkey is None else handlers.seed(f, subkey)
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, step, length)
            if subs_type == 'condition':
                step_fn = handlers.condition(step_fn, condition_fn=subs_fn)
            elif subs_type == 'substitute':
                step_fn = handlers.substitute(step_fn, substitute_fn=subs_fn)
        with handlers.trace() as step_trace:
            state, y = step_fn(state, x)

    return (step + 1, key, state), (PytreeTrace(step_trace), y)
def _setup_prototype(self, *args, **kwargs):
    """Initialise the model once and build the flat latent representation.

    Sets ``self._postprocess_fn``, ``self.prototype_trace``,
    ``self._init_latent``, ``self._unpack_latent`` (an ``UnpackTransform``)
    and ``self.latent_dim``.

    :raises RuntimeError: if the model has no latent variables.
    """
    setup_key = numpyro.sample("_{}_rng_key_setup".format(self.prefix), dist.PRNGIdentity())
    with handlers.block():
        model_info = initialize_model(
            setup_key, self.model,
            init_strategy=self.init_strategy,
            dynamic_args=False,
            model_args=args,
            model_kwargs=kwargs)
        init_params, _, self._postprocess_fn, self.prototype_trace = model_info

    self._init_latent, raw_unpack = ravel_pytree(init_params[0])
    # wrap the unravel function so, as in Pyro, it can be applied to a batch
    # of latent samples
    self._unpack_latent = UnpackTransform(raw_unpack)
    self.latent_dim = jnp.size(self._init_latent)
    if self.latent_dim != 0:
        return
    raise RuntimeError('{} found no latent variables; Use an empty guide instead'
                       .format(type(self).__name__))
def wrapper(wrapped_operand):
    """Call ``fn`` on the operand under ``block`` and capture its trace.

    Closure: uses ``fn`` and ``substitute_stack``.

    :param wrapped_operand: ``(key, operand)``; ``key`` may be ``None``.
    :return: ``(value, PytreeTrace(branch_trace))``.
    """
    key, operand = wrapped_operand
    with handlers.block():
        branch_fn = fn if key is None else handlers.seed(fn, key)
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map)
            if subs_type == "condition":
                branch_fn = handlers.condition(branch_fn, condition_fn=subs_fn)
            elif subs_type == "substitute":
                branch_fn = handlers.substitute(branch_fn, substitute_fn=subs_fn)
        with handlers.trace() as branch_trace:
            value = branch_fn(operand)
    return value, PytreeTrace(branch_trace)
def get_observations_scale(model, model_args, model_kwargs, params):
    """Return the single scale attached to the model's observed sample sites.

    The model is seeded with ``0``, ``params`` are substituted in, and every
    site except observed ``sample`` sites is hidden from the trace. A site
    whose ``scale`` is ``None`` counts as ``1``.

    :return: the shared scale, or ``1.`` when there are no observed sites.
    :raises ValueError: if observed sites carry more than one distinct scale.
    """
    # todo(lumip): is there a way to avoid tracing through the entire model?
    # need to experiment with effect handlers and what exactly blocking achieves
    def is_hidden(msg):
        # only observed sample sites stay visible to the trace
        return msg['type'] != 'sample' or not msg['is_observed']

    observed_only_model = block(substitute(seed(model, 0), data=params), is_hidden)
    observed_trace = trace(observed_only_model).get_trace(*model_args, **model_kwargs)
    site_scales = [
        1 if site['scale'] is None else site['scale']
        for site in observed_trace.values()
    ]
    scales = np.unique(site_scales)
    if len(scales) == 0:
        return 1.
    if len(scales) > 1:
        raise ValueError("The model received several observation sites with different example counts. This is not supported in DPSVI.")
    return scales[0]
def find_params(self, rng_keys, *args, **kwargs):
    """Compute initial values for every ``param`` site of the guide.

    Sets ``self._init_params`` to a dict mapping param name ->
    ``(value, constraint)``. ``value`` is ``biject_to(constraint)`` applied to
    the value found by ``find_valid_initial_params``. The constraint is popped
    from the site's kwargs and defaults to ``real``.
    """
    seeded_guide = handlers.seed(self.fn, rng_keys[0])
    guide_trace = handlers.trace(seeded_guide).get_trace(*args, **kwargs)
    init_params, _ = handlers.block(find_valid_initial_params)(
        rng_keys, self.fn,
        init_strategy=self.init_strategy,
        param_as_improper=True,  # To get new values for existing parameters
        model_args=args,
        model_kwargs=kwargs)

    init_map = {}
    for name, site in guide_trace.items():
        if site['type'] != 'param':
            continue
        constraint = site['kwargs'].pop('constraint', real)
        init_map[name] = (biject_to(constraint)(init_params[name]), constraint)
    self._init_params = init_map
def log_likelihood(params, subsample_indices=None):
    """Per-subsample-plate log-likelihood of the observed sites of ``model``.

    Closure: uses ``subsample_plate_sizes``, ``model``, ``model_args`` and
    ``model_kwargs``, plus the helpers ``_unconstrain_reparam`` and
    ``_sum_all_except_at_dim``.

    :param params: pytree of parameters, substituted through
        ``partial(_unconstrain_reparam, params)``.
    :param subsample_indices: optional plate name -> indices mapping that is
        substituted into the model. Defaults to ``jnp.arange(v[0])`` for each
        ``(k, v)`` in ``subsample_plate_sizes``.
    :return: ``defaultdict(float)`` mapping each plate name found in
        ``subsample_plate_sizes`` -> summed log_prob of observed sites in that
        plate, reduced over every dim except the plate's.
    """
    # NOTE(review): ravel followed by unravel leaves values unchanged but turns
    # leaves into arrays — presumably intentional; confirm before removing.
    params_flat, unravel_fn = ravel_pytree(params)
    if subsample_indices is None:
        subsample_indices = {
            k: jnp.arange(v[0]) for k, v in subsample_plate_sizes.items()
        }
    params = unravel_fn(params_flat)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        with block(), trace() as tr, substitute(data=subsample_indices), substitute(
                substitute_fn=partial(_unconstrain_reparam, params)):
            model(*model_args, **model_kwargs)

    log_lik = defaultdict(float)
    for site in tr.values():
        if site["type"] == "sample" and site["is_observed"]:
            for frame in site["cond_indep_stack"]:
                # only subsample plates are reported
                if frame.name in subsample_plate_sizes:
                    log_lik[frame.name] += _sum_all_except_at_dim(
                        site["fn"].log_prob(site["value"]), frame.dim)
    return log_lik
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args, **kwargs):
    """Infer values of the model's enumerated sample sites.

    With ``temperature == 0`` the max / argmax approximation is used; with
    ``temperature == 1`` a ``MonteCarlo`` approximation keyed by a split of
    ``rng_key`` is used. Sites named ``_PREV_<k>foo`` (unrolled time steps)
    are concatenated back into ``foo`` using the shape of ``foo`` from a
    prototype (non-enumerated) trace of the model.

    :param model: the model.
    :param first_available_dim: passed to ``enum``. When ``None`` it is derived
        from a prototype trace as ``-_guess_max_plate_nesting(trace) - 1``.
    :param temperature: ``0`` (MAP) or ``1`` (sample).
    :param rng_key: JAX PRNG key.
    :return: dict mapping sample site name -> value.
    :raises ValueError: for any other temperature.
    """
    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    def _prototype_trace():
        # Ordinary seeded run of the model, hidden from outer handlers.
        with block():
            return trace(seed(model, rng_key)).get_trace(*args, **kwargs)

    # The prototype trace is needed both to guess `first_available_dim` and to
    # reshape `_PREV_` time variables below. Previously it was only bound in
    # the first case, so passing `first_available_dim` raised NameError for
    # models with time variables. It is now computed lazily when required.
    model_trace = None
    if first_available_dim is None:
        model_trace = _prototype_trace()
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    with funsor.adjoint.AdjointTape() as tape:
        with block(), enum(first_available_dim=first_available_dim):
            log_prob, model_tr, log_measures = _enum_log_density(
                model, args, kwargs, {}, sum_op, prod_op)

    with approx:
        approx_factors = tape.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            output = funsor.Reals[node["fn"].event_shape]
            value = funsor.to_funsor(node["value"], output,
                                     dim_to_name=node["infer"]["dim_to_name"])
            value = value(**sample_subs)
            node["value"] = funsor.to_data(
                value, name_to_dim=node["infer"]["name_to_dim"])
        else:
            log_measure = approx_factors[log_measures[name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                sample_subs[name], name_to_dim=node["infer"]["name_to_dim"])

    data = {
        name: site["value"]
        for name, site in sample_tr.items()
        if site["type"] == "sample"
    }

    # concatenate _PREV_foo to foo
    time_vars = defaultdict(list)
    for name in data:
        if name.startswith("_PREV_"):
            root_name = _shift_name(name, -_get_shift(name))
            time_vars[root_name].append(name)
    for name in time_vars:
        if name in data:
            time_vars[name].append(name)
        # longest names (most "_PREV_" prefixes, i.e. earliest steps) first
        time_vars[name] = sorted(time_vars[name], key=len, reverse=True)

    if time_vars and model_trace is None:
        model_trace = _prototype_trace()
    for root_name, var_names in time_vars.items():
        prototype_shape = model_trace[root_name]["value"].shape
        values = [data.pop(name) for name in var_names]
        if len(values) == 1:
            data[root_name] = values[0].reshape(prototype_shape)
        else:
            assert len(prototype_shape) >= 1
            values = [v.reshape((-1,) + prototype_shape[1:]) for v in values]
            data[root_name] = jnp.concatenate(values)
    return data
def scan_enum(
    f,
    init,
    xs,
    length,
    reverse,
    rng_key=None,
    substitute_stack=None,
    history=1,
    first_available_dim=None,
):
    """Scan ``f`` with enumeration support: unroll a prefix, roll the rest.

    The first ``unroll_steps = min(2 * history - 1, length)`` steps are run
    eagerly inside ``markov(..., history=history)``, with their sites renamed
    by a ``"_PREV_"`` prefix. The remaining ``length - unroll_steps`` steps run
    through ``lax.scan``. Afterwards each scanned sample site gets a
    ``"_time_<first sample site>"`` entry in ``infer["dim_to_name"]``, the
    scanned outputs are concatenated after the unrolled ones, and the final
    carry is reshaped using shapes recorded during the run.

    :param f: transition function ``(carry, x) -> (new_carry, y)``.
    :param init: initial carry.
    :param xs: pytree of per-step inputs, sliced along the leading axis.
    :param length: number of steps.
    :param reverse: if true, consume ``xs`` from the end.
    :param rng_key: optional PRNG key; ``None`` disables seeding.
    :param substitute_stack: iterable of ``(type, map)`` pairs with ``type`` in
        ``{"condition", "substitute"}``, applied at every step. ``None`` is
        treated as empty.
    :param history: passed to ``markov``; clipped to ``length``.
    :param first_available_dim: passed to ``enum``.
    :return: ``((t, rng_key, carry), (PytreeTrace, ys))``.
    """
    from numpyro.contrib.funsor import (
        config_enumerate,
        enum,
        markov,
        trace as packed_trace,
    )

    # The default `None` used to raise TypeError as soon as `body_fn` iterated it.
    if substitute_stack is None:
        substitute_stack = []

    # number of steps to unroll
    history = min(history, length)
    unroll_steps = min(2 * history - 1, length)
    if reverse:
        x0 = tree_map(lambda x: x[-unroll_steps:][::-1], xs)
        xs_ = tree_map(lambda x: x[:-unroll_steps], xs)
    else:
        x0 = tree_map(lambda x: x[:unroll_steps], xs)
        xs_ = tree_map(lambda x: x[unroll_steps:], xs)

    # Each entry is a list of leaf shapes of a carry, recorded during the run
    # and used at the end to restore the carry shape of the sequential loop.
    carry_shapes = []

    def body_fn(wrapped_carry, x, prefix=None):
        i, rng_key, carry = wrapped_carry
        # a concrete (non-traced) index inside the unrolled prefix
        is_unrolled_step = not_jax_tracer(i) and i in range(unroll_steps)
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        # we need to tell unconstrained messenger in potential energy computation
        # that only the item at time `i` is needed when transforming
        fn = handlers.infer_config(f, config_fn=lambda msg: {"_scan_current_index": i})

        seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == "condition":
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == "substitute":
                seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

        if is_unrolled_step:
            # rename sites to match the naming pattern of the sarkka_bilmes product
            with handlers.scope(prefix="_PREV_" * (unroll_steps - i), divider=""):
                new_carry, y = config_enumerate(seeded_fn)(carry, x)
                trace = {}
        else:
            # Like scan_wrapper, we collect the trace of scan's transition function
            # `seeded_fn` here. To put time dimension to the correct position, we need to
            # promote shapes to make `fn` and `value`
            # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
            # and value's batch_shape is (3,), then we promote shape of
            # value so that its batch shape is (1, 3)).
            # Here we will promote `fn` shape first. `value` shape will be promoted after scanned.
            # We don't promote `value` shape here because we need to store carry shape
            # at this step. If we reshape the `value` here, output carry might get wrong shape.
            with _promote_fn_shapes(), packed_trace() as trace:
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # record the shape of new_carry in the enclosing `carry_shapes` list
            if len(carry_shapes) < (history + 1):
                carry_shapes.append([jnp.shape(x) for x in tree_flatten(new_carry)[0]])
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(
                lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry)
        return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)

    with handlers.block(
        hide_fn=lambda site: not site["name"].startswith("_PREV_")
    ), enum(first_available_dim=first_available_dim):
        wrapped_carry = (0, rng_key, init)
        y0s = []
        # We run unroll_steps + 1 where the last step is used for rolling with `lax.scan`
        for i in markov(range(unroll_steps + 1), history=history):
            if i < unroll_steps:
                wrapped_carry, (_, y0) = body_fn(wrapped_carry, tree_map(lambda z: z[i], x0))
                if i > 0:
                    # reshape y1, y2,... to have the same shape as y0
                    y0 = tree_multimap(
                        lambda z0, z: jnp.reshape(z, jnp.shape(z0)), y0s[0], y0)
                y0s.append(y0)
                # shapes of the first `history - 1` steps are not useful to interpret the last carry
                # shape so we don't need to record them here
                if (i >= history - 1) and (len(carry_shapes) < history + 1):
                    # Store a list, like body_fn does; the previous code stored
                    # a one-shot generator expression here.
                    carry_shapes.append(
                        [jnp.shape(x) for x in tree_flatten(wrapped_carry[-1])[0]])
            else:
                # this is the last rolling step
                y0s = tree_multimap(lambda *z: jnp.stack(z, axis=0), *y0s)
                # return early if length = unroll_steps
                if length == unroll_steps:
                    return wrapped_carry, (PytreeTrace({}), y0s)
                wrapped_carry = device_put(wrapped_carry)
                wrapped_carry, (pytree_trace, ys) = lax.scan(
                    body_fn, wrapped_carry, xs_, length - unroll_steps, reverse)

    first_var = None
    for name, site in pytree_trace.trace.items():
        # currently, we only record sample or deterministic in the trace
        # we don't need to adjust `dim_to_name` for deterministic site
        if site["type"] not in ("sample",):
            continue
        # add `time` dimension, the name will be '_time_{first variable in the trace}'
        if first_var is None:
            first_var = name

        # we haven't promoted shapes of values yet during `lax.scan`, so we do it here
        site["value"] = _promote_scanned_value_shapes(site["value"], site["fn"])

        # XXX: site['infer']['dim_to_name'] is not enough to determine leftmost dimension because
        # we don't record 1-size dimensions in this field
        time_dim = -min(len(site["fn"].batch_shape),
                        jnp.ndim(site["value"]) - site["fn"].event_dim)
        site["infer"]["dim_to_name"][time_dim] = "_time_{}".format(first_var)

    # similar to carry, we need to reshape due to shape alternating in markov
    ys = tree_multimap(
        lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)[1:]), y0s, ys)
    # then join with y0s
    ys = tree_multimap(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys)
    # we also need to reshape `carry` to match sequential behavior
    i = (length + 1) % (history + 1)
    t, rng_key, carry = wrapped_carry
    carry_shape = carry_shapes[i]
    flatten_carry, treedef = tree_flatten(carry)
    flatten_carry = [
        jnp.reshape(x, t1_shape) for x, t1_shape in zip(flatten_carry, carry_shape)
    ]
    carry = tree_unflatten(treedef, flatten_carry)
    wrapped_carry = (t, rng_key, carry)
    return wrapped_carry, (pytree_trace, ys)
def init(self, rng_key, *args, **kwargs):
    """
    :param jax.random.PRNGKey rng_key: random number generator seed.
    :param args: arguments to the model / guide (these can possibly vary during
        the course of fitting).
    :param kwargs: keyword arguments to the model / guide (these can possibly vary
        during the course of fitting).
    :return: initial :data:`SVGDState`
    """
    # `jax.partial` was a re-export of `functools.partial` that newer JAX
    # releases no longer provide; use the standard library directly.
    import functools

    rng_key, model_seed, guide_seed = jax.random.split(rng_key, 3)
    model_init = handlers.seed(self.model, model_seed)
    guide_init = handlers.seed(self.guide, guide_seed)
    guide_trace = handlers.trace(guide_init).get_trace(
        *args, **kwargs, **self.static_kwargs)
    model_trace = handlers.trace(model_init).get_trace(
        *args, **kwargs, **self.static_kwargs)
    rng_key, particle_seed = jax.random.split(rng_key)
    particle_seeds = jax.random.split(particle_seed, num=self.num_stein_particles)
    # Get parameter values for each particle
    self.guide.find_params(
        particle_seeds, *args, **kwargs, **self.static_kwargs)
    guide_init_params = self.guide.init_params()
    params = {}
    transforms = {}
    inv_transforms = {}
    guide_param_names = set()
    # NB: params in model_trace will be overwritten by params in guide_trace
    for site in list(model_trace.values()) + list(guide_trace.values()):
        if site['type'] == 'param':
            # NOTE: `pop` removes 'constraint' from this trace site's kwargs
            constraint = site['kwargs'].pop('constraint', constraints.real)
            transform = biject_to(constraint)
            inv_transforms[site['name']] = transform
            transforms[site['name']] = transform.inv
            if site['name'] in guide_init_params:
                pval, _ = guide_init_params[site['name']]
            else:
                pval = site['value']
            params[site['name']] = transform.inv(pval)
            if site['name'] in guide_trace:
                guide_param_names.add(site['name'])
    self.guide_param_names = guide_param_names
    self.constrain_fn = functools.partial(transform_fn, inv_transforms)
    self.uconstrain_fn = functools.partial(transform_fn, transforms)
    classic_uparam_names = {
        p for p in params.keys()
        if p not in self.guide_param_names or self.classic_guide_params_fn(p)
    }
    # Ensure not to sample parameters that should be classically updated
    sampler = self.sampler_fn(
        handlers.block(self.model, lambda site: site['name'] in classic_uparam_names),
        **self.sampler_kwargs)
    self.mcmc = MCMC(sampler, self.num_mcmc_warmup, self.num_mcmc_updates,
                     num_chains=self.num_mcmc_particles, progress_bar=False,
                     **self.mcmc_kwargs)
    return SVGDState(self.optim.init(params), rng_key)
def _setup_prototype(self, *args, **kwargs):
    """Record the model's structure in ``self.prototype_trace``.

    The model is traced once with the given arguments. The run is wrapped in
    ``block`` so its sites are not seen by outer handlers.
    """
    blocked_get_trace = block(trace(self.model).get_trace)
    self.prototype_trace = blocked_get_trace(*args, **kwargs)
def test_block():
    """A site listed in ``block(hide=...)`` must not reach an outer trace."""
    with handlers.trace() as trace:
        with handlers.block(hide=['x']), handlers.seed(rng_seed=0):
            numpyro.sample('x', dist.Normal())
    assert 'x' not in trace
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args, **kwargs):
    """Infer the model's enumerated sites and replay the model with them.

    With ``temperature == 0`` the max / argmax approximation is used; with
    ``temperature == 1`` a ``MonteCarlo`` approximation keyed by a split of
    ``rng_key`` is used. The model is traced under ``enum`` and
    ``plate_to_enum_plate``, and its factors are combined with funsor's
    ``sum_product``. Values are read off the adjoint, and the model is finally
    run under ``replay`` of the resulting trace.

    :param first_available_dim: passed to ``enum``. When ``None`` it is derived
        from a seeded prototype trace as ``-_guess_max_plate_nesting(...) - 1``.
    :return: the model's return value under ``replay(guide_trace=sample_tr)``.
    :raises ValueError: if ``temperature`` is neither 0 nor 1.
    """
    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    if first_available_dim is None:
        with block():
            model_trace = trace(seed(model, rng_key)).get_trace(*args, **kwargs)
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    with block(), enum(first_available_dim=first_available_dim):
        with plate_to_enum_plate():
            model_tr = packed_trace(model).get_trace(*args, **kwargs)

    terms = terms_from_trace(model_tr)
    # terms["log_factors"] = [log p(x) for each observed or latent sample site x]
    # terms["log_measures"] = [log p(z) or other Dice factor
    #                          for each latent sample site z]

    # build the (lazy) eliminated expression, then optimize it before the adjoint pass
    with funsor.interpretations.lazy:
        log_prob = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            list(terms["log_factors"].values()) + list(terms["log_measures"].values()),
            eliminate=terms["measure_vars"] | terms["plate_vars"],
            plates=terms["plate_vars"],
        )
        log_prob = funsor.optimizer.apply_optimizer(log_prob)

    with approx:
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            output = funsor.Reals[node["fn"].event_shape]
            value = funsor.to_funsor(node["value"], output,
                                     dim_to_name=node["infer"]["dim_to_name"])
            value = value(**sample_subs)
            node["value"] = funsor.to_data(
                value, name_to_dim=node["infer"]["name_to_dim"])
        else:
            log_measure = approx_factors[terms["log_measures"][name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                sample_subs[name], name_to_dim=node["infer"]["name_to_dim"])

    with replay(guide_trace=sample_tr):
        return model(*args, **kwargs)
def test_block():
    """Sites hidden by ``block(hide=...)`` are absent from an enclosing trace."""
    with handlers.trace() as trace, handlers.block(hide=["x"]), handlers.seed(rng_seed=0):
        numpyro.sample("x", dist.Normal())
    assert "x" not in trace