Beispiel #1
0
    def _setup_prototype(self, *args, **kwargs):
        """Initialise the model once and cache what is needed from the result.

        Stores the post-processing function and the prototype trace returned
        by ``initialize_model``, keeps the first entry of the returned initial
        parameters as ``self._init_locs``, and indexes the plate frames and
        plate sizes found in the prototype trace.

        :param args: positional arguments forwarded to the model.
        :param kwargs: keyword arguments forwarded to the model.
        """
        # The key is requested outside of `handlers.block()`; only the
        # initialisation itself runs inside the block.
        rng_key = numpyro.prng_key()
        with handlers.block():
            init_result = initialize_model(
                rng_key,
                self.model,
                init_strategy=self.init_loc_fn,
                dynamic_args=False,
                model_args=args,
                model_kwargs=kwargs,
            )
            init_params, _, self._postprocess_fn, self.prototype_trace = init_result
        self._init_locs = init_params[0]

        self._prototype_frames = {}
        # NOTE(review): this dict is reset here, but plate sizes below are
        # written to `_prototype_frame_full_sizes`, which is presumably
        # created elsewhere (e.g. in `__init__`) -- confirm.
        self._prototype_plate_sizes = {}
        for name, site in self.prototype_trace.items():
            site_type = site["type"]
            if site_type == "sample":
                # Index every plate frame the sample site lives in by name.
                self._prototype_frames.update(
                    (frame.name, frame) for frame in site["cond_indep_stack"])
            elif site_type == "plate":
                # The first positional argument of the plate site is recorded
                # as its full size.
                self._prototype_frame_full_sizes[name] = site["args"][0]
Beispiel #2
0
        def log_likelihood(params_flat, subsample_indices=None):
            """Accumulate log-probabilities of observed sites, keyed by plate name.

            :param params_flat: flat parameter vector; it is unravelled with
                ``unravel_fn`` and every entry is mapped through
                ``biject_to(<site support>)`` before being substituted.
            :param subsample_indices: optional dict of indices per plate;
                defaults to ``jnp.arange(v[0])`` for every entry of
                ``subsample_plate_sizes``.
            :return: dict mapping plate (frame) name to the log-probability
                reduced with ``_sum_all_except_at_dim`` at that frame's dim.
            """
            if subsample_indices is None:
                subsample_indices = {
                    plate: jnp.arange(sizes[0])
                    for plate, sizes in subsample_plate_sizes.items()
                }
            unravelled = unravel_fn(params_flat)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                params = {}
                for name, value in unravelled.items():
                    to_support = biject_to(prototype_trace[name]["fn"].support)
                    params[name] = to_support(value)
                with block(), trace() as tr:
                    with substitute(data=subsample_indices), substitute(data=params):
                        model(*model_args, **model_kwargs)

            log_lik = {}
            for site in tr.values():
                if site["type"] != "sample" or not site["is_observed"]:
                    continue
                # An observed site contributes once to every plate it is in.
                for frame in site["cond_indep_stack"]:
                    contribution = _sum_all_except_at_dim(
                        site["fn"].log_prob(site["value"]), frame.dim)
                    if frame.name in log_lik:
                        log_lik[frame.name] += contribution
                    else:
                        log_lik[frame.name] = contribution
            return log_lik
Beispiel #3
0
    def body_fn(state):
        """Draw one candidate set of initial parameters and check that it is usable.

        :param state: tuple ``(i, key, params, is_valid)``; the last two
            entries of the incoming state are ignored.
        :return: ``(i + 1, new_key, params, is_valid)`` where ``is_valid`` is
            true when the potential energy and all of its gradients are finite.
        """
        i, key, _, _ = state
        key, subkey = random.split(key)

        # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
        # Use `block` to not record sample primitives in `init_loc_fn`.
        init_fn = block(seed(init_strategy, subkey))
        seeded_model = substitute(model, substitute_fn=init_fn)
        model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)

        constrained_values = {}
        inv_transforms = {}
        for name, site in model_trace.items():
            site_type = site['type']
            if site_type == 'sample' and not site['is_observed']:
                if site['intermediates']:
                    # Sites with intermediates use the first recorded
                    # intermediate value and the base distribution's support.
                    constrained_values[name] = site['intermediates'][0][0]
                    inv_transforms[name] = biject_to(site['fn'].base_dist.support)
                else:
                    constrained_values[name] = site['value']
                    inv_transforms[name] = biject_to(site['fn'].support)
            elif site_type == 'param' and param_as_improper:
                constraint = site['kwargs'].pop('constraint', real)
                transform = biject_to(constraint)
                if isinstance(transform, ComposeTransform):
                    # Only the first part of a composed transform is kept; the
                    # value is pulled back through the full transform and
                    # pushed forward through that first part.
                    first_part = transform.parts[0]
                    inv_transforms[name] = first_part
                    constrained_values[name] = first_part(transform.inv(site['value']))
                else:
                    inv_transforms[name] = transform
                    constrained_values[name] = site['value']

        params = transform_fn(inv_transforms,
                              dict(constrained_values),
                              invert=True)
        potential_fn = jax.partial(potential_energy, model, inv_transforms, model_args, model_kwargs)
        pe, param_grads = value_and_grad(potential_fn)(params)
        flat_grads = ravel_pytree(param_grads)[0]
        is_valid = np.isfinite(pe) & np.all(np.isfinite(flat_grads))
        return i + 1, key, params, is_valid
Beispiel #4
0
    def _setup_prototype(self, *args, **kwargs):
        """Set up transforms, the flat initial latent and the base distribution.

        After the parent class has built ``self.prototype_trace``, this looks
        for valid initial parameters, records one transform per unobserved
        sample site in ``self._inv_transforms``, flattens the initial
        parameters into ``self._init_latent`` and, if none was supplied,
        creates a standard-normal ``self.base_dist`` of matching size.

        :param args: positional arguments forwarded to the model.
        :param kwargs: keyword arguments forwarded to the model.
        :raises RuntimeError: if the flattened initial latent is empty.
        """
        super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
        rng_key = numpyro.sample("_{}_rng_key_init".format(self.prefix),
                                 dist.PRNGIdentity())
        init_params, _ = handlers.block(find_valid_initial_params)(
            rng_key,
            self.model,
            init_strategy=self.init_strategy,
            model_args=args,
            model_kwargs=kwargs)
        self._inv_transforms = {}
        self._has_transformed_dist = False
        for name, site in self.prototype_trace.items():
            if site['type'] == 'sample' and not site['is_observed']:
                if site['intermediates']:
                    # Sites with intermediates are transformed w.r.t. the
                    # support of their base distribution.
                    self._inv_transforms[name] = biject_to(
                        site['fn'].base_dist.support)
                    self._has_transformed_dist = True
                else:
                    self._inv_transforms[name] = biject_to(site['fn'].support)

        self._init_latent, self._unpack_latent = ravel_pytree(init_params)
        self.latent_size = np.size(self._init_latent)
        # Fail before building a zero-dimensional base distribution.
        if self.latent_size == 0:
            raise RuntimeError(
                '{} found no latent variables; Use an empty guide instead'.
                format(type(self).__name__))
        if self.base_dist is None:
            self.base_dist = dist.Independent(
                dist.Normal(np.zeros(self.latent_size), 1.), 1)
Beispiel #5
0
    def body_fn(wrapped_carry, x, prefix=None):
        i, rng_key, carry = wrapped_carry
        init = True if (not_jax_tracer(i) and i == 0) else False
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        seeded_fn = handlers.seed(f, subkey) if subkey is not None else f
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == 'condition':
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == 'substitute':
                seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

        if init:
            with handlers.scope(prefix="_init"):
                new_carry, y = seeded_fn(carry, x)
                trace = {}
        else:
            with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov():
                # Like scan_wrapper, we collect the trace of scan's transition function
                # `seeded_fn` here. To put time dimension to the correct position, we need to
                # promote shapes to make `fn` and `value`
                # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
                # and value's batch_shape is (3,), then we promote shape of
                # value so that its batch shape is (1, 3)).
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store shape of new_carry at a global variable
            nonlocal carry_shape_at_t1
            carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)),
                                      new_carry, carry)
        return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)
Beispiel #6
0
    def _setup_prototype(self, *args, **kwargs):
        """Set up per-site transforms and the flat initial latent vector.

        After the parent class has built ``self.prototype_trace``, this looks
        for valid initial parameters, records one transform per unobserved
        sample site in ``self._inv_transforms`` and flattens the initial
        parameters into ``self._init_latent`` / ``self.unpack_latent``.

        :param args: positional arguments forwarded to the model.
        :param kwargs: keyword arguments forwarded to the model.
        :raises RuntimeError: if the flattened initial latent is empty.
        """
        super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
        # FIXME: without block statement, get AssertionError: all sites must have unique names
        init_params, _ = block(find_valid_initial_params)(
            self._init_rng,
            self.model,
            *args,
            init_strategy=self.init_strategy,
            **kwargs)
        self._inv_transforms = {}
        self._has_transformed_dist = False
        for name, site in self.prototype_trace.items():
            if site['type'] == 'sample' and not site['is_observed']:
                if site['intermediates']:
                    # Sites with intermediates are transformed w.r.t. the
                    # support of their base distribution.
                    self._inv_transforms[name] = biject_to(
                        site['fn'].base_dist.support)
                    self._has_transformed_dist = True
                else:
                    self._inv_transforms[name] = biject_to(site['fn'].support)

        self._init_latent, self.unpack_latent = ravel_pytree(init_params)
        self.latent_size = np.size(self._init_latent)
        if self.latent_size == 0:
            raise RuntimeError(
                '{} found no latent variables; Use an empty guide instead'.
                format(type(self).__name__))
Beispiel #7
0
    def body_fn(wrapped_carry, x):
        """Run one step of ``f`` under the handler stack and record its trace.

        :param wrapped_carry: tuple ``(i, rng_key, carry)``.
        :param x: the input slice for this step.
        :return: ``((i + 1, rng_key, carry), (PytreeTrace(trace), y))``.
        """
        i, rng_key, carry = wrapped_carry
        if rng_key is None:
            subkey = None
        else:
            rng_key, subkey = random.split(rng_key)

        with handlers.block():

            # we need to tell unconstrained messenger in potential energy computation
            # that only the item at time `i` is needed when transforming
            fn = handlers.infer_config(
                f, config_fn=lambda msg: {"_scan_current_index": i})

            seeded_fn = fn if subkey is None else handlers.seed(fn, subkey)
            # Re-apply every recorded condition/substitute entry around the
            # step function.
            for subs_type, subs_map in substitute_stack:
                subs_fn = partial(_subs_wrapper, subs_map, i, length)
                if subs_type == "condition":
                    seeded_fn = handlers.condition(
                        seeded_fn, condition_fn=subs_fn)
                elif subs_type == "substitute":
                    seeded_fn = handlers.substitute(
                        seeded_fn, substitute_fn=subs_fn)

            with handlers.trace() as trace:
                carry, y = seeded_fn(carry, x)

        next_state = (i + 1, rng_key, carry)
        return next_state, (PytreeTrace(trace), y)
Beispiel #8
0
 def _setup_prototype(self, *args, **kwargs):
     """Trace the seeded model once and remember the call arguments.

     :param args: positional arguments forwarded to the model.
     :param kwargs: keyword arguments forwarded to the model.
     """
     # run the model so we can inspect its structure
     key_site = "_{}_rng_key_setup".format(self.prefix)
     rng_key = numpyro.sample(key_site, dist.PRNGIdentity())
     seeded_model = handlers.seed(self.model, rng_key)
     get_trace = handlers.trace(seeded_model).get_trace
     # The trace is collected through `handlers.block`.
     self.prototype_trace = handlers.block(get_trace)(*args, **kwargs)
     self._args = args
     self._kwargs = kwargs
Beispiel #9
0
 def log_prior(params):
     """Evaluate ``log_density`` with an empty index array for every subsample plate.

     :param params: parameter values passed through to ``log_density``.
     :return: the first element of the ``log_density`` result.
     """
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", category=UserWarning)
         # One empty int32 index array per subsample plate.
         empty_indices = {}
         for plate_name in subsample_plate_sizes:
             empty_indices[plate_name] = jnp.array([], dtype=jnp.int32)
         with block():
             with substitute(data=empty_indices):
                 # NOTE(review): NumPyro's `log_density` takes
                 # (model, model_args, model_kwargs, params); only three
                 # arguments are passed here -- confirm the signature of the
                 # `log_density` that is in scope.
                 prior_prob, _ = log_density(model, model_kwargs, params)
     return prior_prob
Beispiel #10
0
 def __init__(self,
              rng,
              model,
              get_params_fn,
              prefix="auto",
              init_loc_fn=init_to_median):
     """Wrap ``model`` so its sites are initialised from ``init_loc_fn``.

     :param rng: key used to seed ``init_loc_fn``.
     :param model: the model to wrap.
     :param get_params_fn: passed through to the parent constructor.
     :param prefix: passed through to the parent constructor.
     :param init_loc_fn: initialisation function, ``init_to_median`` by default.
     """
     # Use `block` to not record sample primitives in `init_loc_fn`.
     seeded_init = block(seed(init_loc_fn, rng))
     # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
     model = substitute(model, substitute_fn=seeded_init)
     super(AutoContinuous, self).__init__(model, get_params_fn, prefix=prefix)
Beispiel #11
0
 def find_params(self, rng_keys, *args, **kwargs):
     """Find initial values for every unobserved sample site of the prototype trace.

     Fills ``self._param_map`` (site name -> ``(param_name, value, support)``)
     and ``self._init_params`` (param name -> ``(value, support)``).

     :param rng_keys: key(s) handed to ``find_valid_initial_params``.
     :param args: positional arguments forwarded to the model.
     :param kwargs: keyword arguments forwarded to the model.
     """
     find_init = handlers.block(find_valid_initial_params)
     init_params, _ = find_init(rng_keys, self.model,
                                init_strategy=self.init_strategy,
                                model_args=args,
                                model_kwargs=kwargs)
     param_map = {}
     init_values = {}
     for name, site in self.prototype_trace.items():
         if site['type'] != 'sample' or site['is_observed']:
             continue
         support = site['fn'].support
         param_name = "{}_{}".format(self.prefix, name)
         # Map the found initial value through the bijection for the support.
         param_val = biject_to(support)(init_params[name])
         param_map[name] = (param_name, param_val, support)
         init_values[param_name] = (param_val, support)
     self._param_map = param_map
     self._init_params = init_values
Beispiel #12
0
    def body_fn(wrapped_carry, x):
        """Run one step of ``f`` under the handler stack and record its trace.

        :param wrapped_carry: tuple ``(i, rng_key, carry)``.
        :param x: the input slice for this step.
        :return: ``((i + 1, rng_key, carry), (PytreeTrace(trace), y))``.
        """
        i, rng_key, carry = wrapped_carry
        if rng_key is None:
            subkey = None
        else:
            rng_key, subkey = random.split(rng_key)

        with handlers.block():
            seeded_fn = f if subkey is None else handlers.seed(f, subkey)
            # Re-apply every recorded condition/substitute entry around the
            # step function.
            for subs_type, subs_map in substitute_stack:
                subs_fn = partial(_subs_wrapper, subs_map, i, length)
                if subs_type == 'condition':
                    seeded_fn = handlers.condition(
                        seeded_fn, condition_fn=subs_fn)
                elif subs_type == 'substitute':
                    seeded_fn = handlers.substitute(
                        seeded_fn, substitute_fn=subs_fn)

            with handlers.trace() as trace:
                carry, y = seeded_fn(carry, x)

        next_state = (i + 1, rng_key, carry)
        return next_state, (PytreeTrace(trace), y)
Beispiel #13
0
    def _setup_prototype(self, *args, **kwargs):
        """Initialise the model and derive the flat latent representation.

        Stores the post-processing function and prototype trace returned by
        ``initialize_model``, the flattened initial latent, an
        ``UnpackTransform`` around the unravel function, and the latent size.

        :param args: positional arguments forwarded to the model.
        :param kwargs: keyword arguments forwarded to the model.
        :raises RuntimeError: if the flattened latent has size zero.
        """
        key_site = "_{}_rng_key_setup".format(self.prefix)
        rng_key = numpyro.sample(key_site, dist.PRNGIdentity())
        with handlers.block():
            init_result = initialize_model(
                rng_key, self.model,
                init_strategy=self.init_strategy,
                dynamic_args=False,
                model_args=args,
                model_kwargs=kwargs)
            init_params, _, self._postprocess_fn, self.prototype_trace = init_result

        self._init_latent, unpack_latent = ravel_pytree(init_params[0])
        # this is to match the behavior of Pyro, where we can apply
        # unpack_latent for a batch of samples
        self._unpack_latent = UnpackTransform(unpack_latent)
        self.latent_dim = jnp.size(self._init_latent)
        if self.latent_dim == 0:
            message = '{} found no latent variables; Use an empty guide instead'
            raise RuntimeError(message.format(type(self).__name__))
Beispiel #14
0
    def wrapper(wrapped_operand):
        """Call ``fn`` on the operand under the handler stack and record its trace.

        :param wrapped_operand: tuple ``(rng_key, operand)``; ``fn`` is seeded
            only when ``rng_key`` is not ``None``.
        :return: ``(value, PytreeTrace(trace))``.
        """
        rng_key, operand = wrapped_operand

        with handlers.block():
            seeded_fn = fn if rng_key is None else handlers.seed(fn, rng_key)
            # Re-apply every recorded condition/substitute entry around `fn`.
            for subs_type, subs_map in substitute_stack:
                subs_fn = partial(_subs_wrapper, subs_map)
                if subs_type == "condition":
                    seeded_fn = handlers.condition(
                        seeded_fn, condition_fn=subs_fn)
                elif subs_type == "substitute":
                    seeded_fn = handlers.substitute(
                        seeded_fn, substitute_fn=subs_fn)

            with handlers.trace() as trace:
                value = seeded_fn(operand)

        return value, PytreeTrace(trace)
Beispiel #15
0
def get_observations_scale(model, model_args, model_kwargs, params):
    """
    Traces through a model to extract the scale applied to observation log-likelihood.

    :param model: the model to trace.
    :param model_args: positional arguments for the model.
    :param model_kwargs: keyword arguments for the model.
    :param params: values substituted into the model while tracing.
    :return: the single scale shared by all observed sample sites (a missing
        scale counts as 1), or ``1.`` if no observed site was traced.
    :raises ValueError: if observed sites carry more than one distinct scale.
    """

    def _hide_unless_observed(msg):
        # Hide everything that is not an observed sample site.
        return msg['type'] != 'sample' or not msg['is_observed']

    # todo(lumip): is there a way to avoid tracing through the entire model?
    #       need to experiment with effect handlers and what exactly blocking achieves
    model = substitute(seed(model, 0), data=params)
    model = block(model, _hide_unless_observed)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)

    site_scales = []
    for msg in model_trace.values():
        scale = msg['scale']
        site_scales.append(1 if scale is None else scale)
    scales = np.unique(site_scales)

    if len(scales) == 0:
        return 1.
    if len(scales) > 1:
        raise ValueError("The model received several observation sites with different example counts. This is not supported in DPSVI.")

    return scales[0]
Beispiel #16
0
 def find_params(self, rng_keys, *args, **kwargs):
     """Find initial values for every ``param`` site of the wrapped function.

     Stores ``self._init_params`` as a dict mapping parameter name to
     ``(value, constraint)``.

     :param rng_keys: keys; the first one seeds the tracing run, all of them
         are handed to ``find_valid_initial_params``.
     :param args: positional arguments forwarded to ``self.fn``.
     :param kwargs: keyword arguments forwarded to ``self.fn``.
     """
     seeded_fn = handlers.seed(self.fn, rng_keys[0])
     guide_trace = handlers.trace(seeded_fn).get_trace(*args, **kwargs)
     find_init = handlers.block(find_valid_initial_params)
     init_params, _ = find_init(
         rng_keys,
         self.fn,
         init_strategy=self.init_strategy,
         param_as_improper=True,  # To get new values for existing parameters
         model_args=args,
         model_kwargs=kwargs)
     init_by_name = {}
     for name, site in guide_trace.items():
         if site['type'] != 'param':
             continue
         # A missing constraint defaults to `real`; note that `pop` removes
         # the entry from the traced site's kwargs.
         constraint = site['kwargs'].pop('constraint', real)
         param_val = biject_to(constraint)(init_params[name])
         init_by_name[name] = (param_val, constraint)
     self._init_params = init_by_name
Beispiel #17
0
        def log_likelihood(params, subsample_indices=None):
            """Accumulate log-probabilities of observed sites per subsample plate.

            :param params: parameter pytree; it is ravelled and unravelled
                before being substituted via ``_unconstrain_reparam``.
            :param subsample_indices: optional dict of indices per plate;
                defaults to ``jnp.arange(v[0])`` for every entry of
                ``subsample_plate_sizes``.
            :return: ``defaultdict(float)`` mapping each plate name found in
                ``subsample_plate_sizes`` to its accumulated log-probability.
            """
            params_flat, unravel_fn = ravel_pytree(params)
            if subsample_indices is None:
                subsample_indices = {
                    plate: jnp.arange(sizes[0])
                    for plate, sizes in subsample_plate_sizes.items()
                }
            params = unravel_fn(params_flat)
            reparam_fn = partial(_unconstrain_reparam, params)
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                with block(), trace() as tr:
                    with substitute(data=subsample_indices):
                        with substitute(substitute_fn=reparam_fn):
                            model(*model_args, **model_kwargs)

            log_lik = defaultdict(float)
            for site in tr.values():
                if site["type"] != "sample" or not site["is_observed"]:
                    continue
                for frame in site["cond_indep_stack"]:
                    # Only plates that are being subsampled are accumulated.
                    if frame.name not in subsample_plate_sizes:
                        continue
                    log_lik[frame.name] += _sum_all_except_at_dim(
                        site["fn"].log_prob(site["value"]), frame.dim)
            return log_lik
Beispiel #18
0
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args,
                      **kwargs):
    """Compute values for the sample sites of ``model`` via enumeration and adjoint.

    :param model: the model to run.
    :param first_available_dim: first dim available for enumeration; when
        ``None`` it is guessed from a plain trace of the seeded model.
    :param temperature: ``0`` selects the max/argmax operators, ``1`` selects
        logaddexp with a Monte Carlo approximation.
    :param rng_key: random key used to seed the model and, for
        ``temperature == 1``, the Monte Carlo approximation.
    :param args: positional arguments forwarded to the model.
    :param kwargs: keyword arguments forwarded to the model.
    :return: dict mapping sample site name to value, with ``_PREV_``-prefixed
        sites concatenated into their root site.
    :raises ValueError: if ``temperature`` is neither 0 nor 1.
    """

    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    # The plain model trace is needed both to guess the plate nesting and,
    # further below, to recover the shapes of time-dependent sites. It used to
    # be bound only when `first_available_dim` was None, which made the later
    # lookup fail for callers that pass an explicit dim.
    model_trace = None
    if first_available_dim is None:
        with block():
            model_trace = trace(seed(model,
                                     rng_key)).get_trace(*args, **kwargs)
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    with funsor.adjoint.AdjointTape() as tape:
        with block(), enum(first_available_dim=first_available_dim):
            log_prob, model_tr, log_measures = _enum_log_density(
                model, args, kwargs, {}, sum_op, prod_op)

    with approx:
        approx_factors = tape.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            output = funsor.Reals[node["fn"].event_shape]
            value = funsor.to_funsor(node["value"],
                                     output,
                                     dim_to_name=node["infer"]["dim_to_name"])
            value = value(**sample_subs)
            node["value"] = funsor.to_data(
                value, name_to_dim=node["infer"]["name_to_dim"])
        else:
            log_measure = approx_factors[log_measures[name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                sample_subs[name], name_to_dim=node["infer"]["name_to_dim"])

    data = {
        name: site["value"]
        for name, site in sample_tr.items() if site["type"] == "sample"
    }

    # concatenate _PREV_foo to foo
    time_vars = defaultdict(list)
    for name in data:
        if name.startswith("_PREV_"):
            root_name = _shift_name(name, -_get_shift(name))
            time_vars[root_name].append(name)
    for name in time_vars:
        if name in data:
            time_vars[name].append(name)
        # longest names (most `_PREV_` prefixes) come first
        time_vars[name] = sorted(time_vars[name], key=len, reverse=True)

    # Only trace the model here if it has not been traced above and the
    # prototype shapes are actually needed.
    if time_vars and model_trace is None:
        with block():
            model_trace = trace(seed(model,
                                     rng_key)).get_trace(*args, **kwargs)

    for root_name, names in time_vars.items():
        prototype_shape = model_trace[root_name]["value"].shape
        values = [data.pop(name) for name in names]
        if len(values) == 1:
            data[root_name] = values[0].reshape(prototype_shape)
        else:
            assert len(prototype_shape) >= 1
            values = [v.reshape((-1, ) + prototype_shape[1:]) for v in values]
            data[root_name] = jnp.concatenate(values)

    return data
Beispiel #19
0
def scan_enum(
    f,
    init,
    xs,
    length,
    reverse,
    rng_key=None,
    substitute_stack=None,
    history=1,
    first_available_dim=None,
):
    """Scan ``f`` over ``xs`` with enumeration, unrolling the first few steps.

    The first ``min(2 * history - 1, length)`` steps are run in Python under
    ``_PREV_``-prefixed scopes; the remaining steps are run with ``lax.scan``.

    :param f: the step function, called as ``f(carry, x)`` and returning
        ``(new_carry, y)``.
    :param init: the initial carry.
    :param xs: pytree of inputs, sliced along the leading axis.
    :param length: total number of steps.
    :param reverse: when true, the unrolled steps are taken from the end of
        ``xs`` (in reversed order) and ``lax.scan`` runs in reverse.
    :param rng_key: optional random key; when ``None`` the step function is
        not seeded.
    :param substitute_stack: iterable of ``(type, map)`` pairs with type
        ``"condition"`` or ``"substitute"``; ``None`` is treated as empty.
    :param history: markov history length; clipped to ``length``.
    :param first_available_dim: passed to the ``enum`` handler.
    :return: ``(wrapped_carry, (pytree_trace, ys))`` where ``wrapped_carry``
        is ``(step_count, rng_key, carry)``.
    """
    from numpyro.contrib.funsor import (
        config_enumerate,
        enum,
        markov,
        trace as packed_trace,
    )

    # The default of None would otherwise fail when iterated in `body_fn`.
    if substitute_stack is None:
        substitute_stack = []

    # amount number of steps to unroll
    history = min(history, length)
    unroll_steps = min(2 * history - 1, length)
    if reverse:
        x0 = tree_map(lambda x: x[-unroll_steps:][::-1], xs)
        xs_ = tree_map(lambda x: x[:-unroll_steps], xs)
    else:
        x0 = tree_map(lambda x: x[:unroll_steps], xs)
        xs_ = tree_map(lambda x: x[unroll_steps:], xs)

    carry_shapes = []

    def body_fn(wrapped_carry, x, prefix=None):
        i, rng_key, carry = wrapped_carry
        init = True if (not_jax_tracer(i)
                        and i in range(unroll_steps)) else False
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (
            None, None)

        # we need to tell unconstrained messenger in potential energy computation
        # that only the item at time `i` is needed when transforming
        fn = handlers.infer_config(
            f, config_fn=lambda msg: {"_scan_current_index": i})

        seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == "condition":
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == "substitute":
                seeded_fn = handlers.substitute(seeded_fn,
                                                substitute_fn=subs_fn)

        if init:
            # handle the names to match the pattern of the sarkka_bilmes product
            with handlers.scope(prefix="_PREV_" * (unroll_steps - i),
                                divider=""):
                new_carry, y = config_enumerate(seeded_fn)(carry, x)
                trace = {}
        else:
            # Like scan_wrapper, we collect the trace of scan's transition function
            # `seeded_fn` here. To put time dimension to the correct position, we need to
            # promote shapes to make `fn` and `value`
            # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
            # and value's batch_shape is (3,), then we promote shape of
            # value so that its batch shape is (1, 3)).
            # Here we will promote `fn` shape first. `value` shape will be promoted after scanned.
            # We don't promote `value` shape here because we need to store carry shape
            # at this step. If we reshape the `value` here, output carry might get wrong shape.
            with _promote_fn_shapes(), packed_trace() as trace:
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store shape of new_carry at a global variable
            if len(carry_shapes) < (history + 1):
                carry_shapes.append(
                    [jnp.shape(x) for x in tree_flatten(new_carry)[0]])
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(
                lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry)
        return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)

    with handlers.block(
            hide_fn=lambda site: not site["name"].startswith("_PREV_")), enum(
                first_available_dim=first_available_dim):
        wrapped_carry = (0, rng_key, init)
        y0s = []
        # We run unroll_steps + 1 where the last step is used for rolling with `lax.scan`
        for i in markov(range(unroll_steps + 1), history=history):
            if i < unroll_steps:
                wrapped_carry, (_, y0) = body_fn(wrapped_carry,
                                                 tree_map(lambda z: z[i], x0))
                if i > 0:
                    # reshape y1, y2,... to have the same shape as y0
                    y0 = tree_multimap(
                        lambda z0, z: jnp.reshape(z, jnp.shape(z0)), y0s[0],
                        y0)
                y0s.append(y0)
                # shapes of the first `history - 1` steps are not useful to interpret the last carry
                # shape so we don't need to record them here
                if (i >= history - 1) and (len(carry_shapes) < history + 1):
                    # Stored as a list (not a one-shot generator) so that it
                    # matches the entries appended inside `body_fn`.
                    carry_shapes.append([
                        jnp.shape(x)
                        for x in tree_flatten(wrapped_carry[-1])[0]
                    ])
            else:
                # this is the last rolling step
                y0s = tree_multimap(lambda *z: jnp.stack(z, axis=0), *y0s)
                # return early if length = unroll_steps
                if length == unroll_steps:
                    return wrapped_carry, (PytreeTrace({}), y0s)
                wrapped_carry = device_put(wrapped_carry)
                wrapped_carry, (pytree_trace,
                                ys) = lax.scan(body_fn, wrapped_carry, xs_,
                                               length - unroll_steps, reverse)

    first_var = None
    for name, site in pytree_trace.trace.items():
        # currently, we only record sample or deterministic in the trace
        # we don't need to adjust `dim_to_name` for deterministic site
        if site["type"] not in ("sample", ):
            continue
        # add `time` dimension, the name will be '_time_{first variable in the trace}'
        if first_var is None:
            first_var = name

        # we haven't promoted shapes of values yet during `lax.scan`, so we do it here
        site["value"] = _promote_scanned_value_shapes(site["value"],
                                                      site["fn"])

        # XXX: site['infer']['dim_to_name'] is not enough to determine leftmost dimension because
        # we don't record 1-size dimensions in this field
        time_dim = -min(len(site["fn"].batch_shape),
                        jnp.ndim(site["value"]) - site["fn"].event_dim)
        site["infer"]["dim_to_name"][time_dim] = "_time_{}".format(first_var)

    # similar to carry, we need to reshape due to shape alternating in markov
    ys = tree_multimap(
        lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)[1:]), y0s, ys)
    # then join with y0s
    ys = tree_multimap(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys)
    # we also need to reshape `carry` to match sequential behavior
    i = (length + 1) % (history + 1)
    t, rng_key, carry = wrapped_carry
    carry_shape = carry_shapes[i]
    flatten_carry, treedef = tree_flatten(carry)
    flatten_carry = [
        jnp.reshape(x, t1_shape)
        for x, t1_shape in zip(flatten_carry, carry_shape)
    ]
    carry = tree_unflatten(treedef, flatten_carry)
    wrapped_carry = (t, rng_key, carry)
    return wrapped_carry, (pytree_trace, ys)
Beispiel #20
0
 def init(self, rng_key, *args, **kwargs):
     """
     Trace model and guide, collect their parameters and build the MCMC sampler.

     :param jax.random.PRNGKey rng_key: random number generator seed.
     :param args: arguments to the model / guide (these can possibly vary during
         the course of fitting).
     :param kwargs: keyword arguments to the model / guide (these can possibly vary
         during the course of fitting).
     :return: initial :data:`SVGDState`
     """
     rng_key, model_seed, guide_seed = jax.random.split(rng_key, 3)
     seeded_model = handlers.seed(self.model, model_seed)
     seeded_guide = handlers.seed(self.guide, guide_seed)
     guide_trace = handlers.trace(seeded_guide).get_trace(
         *args, **kwargs, **self.static_kwargs)
     model_trace = handlers.trace(seeded_model).get_trace(
         *args, **kwargs, **self.static_kwargs)

     # One seed per Stein particle.
     rng_key, particle_seed = jax.random.split(rng_key)
     particle_seeds = jax.random.split(particle_seed,
                                       num=self.num_stein_particles)
     # Get parameter values for each particle
     self.guide.find_params(
         particle_seeds, *args, **kwargs, **self.static_kwargs)
     guide_init_params = self.guide.init_params()

     params = {}
     transforms = {}
     inv_transforms = {}
     guide_param_names = set()
     # NB: params in model_trace will be overwritten by params in guide_trace
     all_sites = list(model_trace.values()) + list(guide_trace.values())
     for site in all_sites:
         if site['type'] != 'param':
             continue
         name = site['name']
         constraint = site['kwargs'].pop('constraint', constraints.real)
         transform = biject_to(constraint)
         inv_transforms[name] = transform
         transforms[name] = transform.inv
         # Prefer the value found by the guide over the traced value.
         if name in guide_init_params:
             pval, _ = guide_init_params[name]
         else:
             pval = site['value']
         params[name] = transform.inv(pval)
         if name in guide_trace:
             guide_param_names.add(name)

     self.guide_param_names = guide_param_names
     self.constrain_fn = jax.partial(transform_fn, inv_transforms)
     self.uconstrain_fn = jax.partial(transform_fn, transforms)
     classic_uparam_names = {
         p
         for p in params.keys()
         if p not in self.guide_param_names or self.classic_guide_params_fn(p)
     }

     # Ensure not to sample parameters that should be classically updated
     blocked_model = handlers.block(
         self.model, lambda site: site['name'] in classic_uparam_names)
     sampler = self.sampler_fn(blocked_model, **self.sampler_kwargs)
     self.mcmc = MCMC(sampler,
                      self.num_mcmc_warmup,
                      self.num_mcmc_updates,
                      num_chains=self.num_mcmc_particles,
                      progress_bar=False,
                      **self.mcmc_kwargs)
     return SVGDState(self.optim.init(params), rng_key)
Beispiel #21
0
 def _setup_prototype(self, *args, **kwargs):
     """Trace the model once and store the result as ``self.prototype_trace``.

     :param args: positional arguments forwarded to the model.
     :param kwargs: keyword arguments forwarded to the model.
     """
     # run the model so we can inspect its structure
     get_trace = trace(self.model).get_trace
     self.prototype_trace = block(get_trace)(*args, **kwargs)
Beispiel #22
0
def test_block():
    """A site hidden by ``handlers.block`` must not show up in an outer trace."""
    tracer = handlers.trace()
    hider = handlers.block(hide=['x'])
    with tracer as trace, hider, handlers.seed(rng_seed=0):
        numpyro.sample('x', dist.Normal())
    assert 'x' not in trace
Beispiel #23
0
def _sample_posterior(model, first_available_dim, temperature, rng_key, *args,
                      **kwargs):
    """Run ``model`` with its sample sites replayed from an enumerated trace.

    :param model: the model to run.
    :param first_available_dim: first dim available for enumeration; when
        ``None`` it is guessed from a plain trace of the seeded model.
    :param temperature: ``0`` selects the max/argmax operators, ``1`` selects
        logaddexp with a Monte Carlo approximation.
    :param rng_key: random key used to seed the model and, for
        ``temperature == 1``, the Monte Carlo approximation.
    :param args: positional arguments forwarded to the model.
    :param kwargs: keyword arguments forwarded to the model.
    :return: whatever ``model(*args, **kwargs)`` returns under the replay.
    :raises ValueError: if ``temperature`` is neither 0 nor 1.
    """

    if temperature == 0:
        sum_op = funsor.ops.max
        prod_op = funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op = funsor.ops.logaddexp
        prod_op = funsor.ops.add
        rng_key, sub_key = random.split(rng_key)
        approx = funsor.montecarlo.MonteCarlo(rng_key=sub_key)
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    if first_available_dim is None:
        seeded_model = seed(model, rng_key)
        with block():
            model_trace = trace(seeded_model).get_trace(*args, **kwargs)
        first_available_dim = -_guess_max_plate_nesting(model_trace) - 1

    with block(), enum(first_available_dim=first_available_dim):
        with plate_to_enum_plate():
            model_tr = packed_trace(model).get_trace(*args, **kwargs)

    terms = terms_from_trace(model_tr)
    # terms["log_factors"] = [log p(x) for each observed or latent sample site x]
    # terms["log_measures"] = [log p(z) or other Dice factor
    #                          for each latent sample site z]

    with funsor.interpretations.lazy:
        factors = list(terms["log_factors"].values()) + list(
            terms["log_measures"].values())
        log_prob = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            factors,
            eliminate=terms["measure_vars"] | terms["plate_vars"],
            plates=terms["plate_vars"],
        )
        log_prob = funsor.optimizer.apply_optimizer(log_prob)

    with approx:
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.items():
        if node["type"] != "sample":
            continue
        name_to_dim = node["infer"]["name_to_dim"]
        if not node["is_observed"]:
            log_measure = approx_factors[terms["log_measures"][name]]
            sample_subs[name] = _get_support_value(log_measure, name)
            node["value"] = funsor.to_data(
                sample_subs[name], name_to_dim=name_to_dim)
            continue
        # "observed" values may be collapsed samples that depend on enumerated
        # values, so we have to slice them down
        # TODO this should really be handled entirely under the hood by adjoint
        output = funsor.Reals[node["fn"].event_shape]
        value = funsor.to_funsor(node["value"],
                                 output,
                                 dim_to_name=node["infer"]["dim_to_name"])
        value = value(**sample_subs)
        node["value"] = funsor.to_data(value, name_to_dim=name_to_dim)

    with replay(guide_trace=sample_tr):
        return model(*args, **kwargs)
Beispiel #24
0
def test_block():
    """A site hidden by ``handlers.block`` must not show up in an outer trace."""
    tracer = handlers.trace()
    hider = handlers.block(hide=["x"])
    with tracer as trace, hider, handlers.seed(rng_seed=0):
        numpyro.sample("x", dist.Normal())
    assert "x" not in trace