# Example #1
def test_ravel_pytree_batched(pytree, nbatch_dims):
    """Check that batch_ravel_pytree round-trips both values and dtypes."""
    flat, _, unravel_fn = batch_ravel_pytree(pytree, nbatch_dims)
    roundtripped = unravel_fn(flat)
    # Leaf values must survive the ravel/unravel round trip exactly
    # (assert_allclose raises on mismatch; tree_flatten just forces traversal).
    tree_flatten(
        tree_multimap(lambda a, b: assert_allclose(a, b), roundtripped, pytree))
    # Dtypes must be preserved for every leaf as well.
    dtype_matches = tree_multimap(
        lambda a, b: jnp.result_type(a) == jnp.result_type(b),
        roundtripped, pytree)
    assert all(tree_flatten(dtype_matches)[0])
# Example #2
    def init(self, rng_key, *args, **kwargs):
        """
        Initialize the SteinVI state: seed and trace the model/guide, register
        parameter sites with their transforms, and set up the kernel over the
        initial Stein particles.

        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: initial :data:`SteinVIState`
        """
        # Derive independent seeds for the kernel, model, and guide so their
        # randomness does not interact.
        rng_key, kernel_seed, model_seed, guide_seed = jax.random.split(
            rng_key, 4)
        model_init = handlers.seed(self.model, model_seed)
        guide_init = handlers.seed(self.guide, guide_seed)
        # Trace guide and model once with the provided (and static) arguments
        # to record their sample/param sites.
        guide_trace = handlers.trace(guide_init).get_trace(
            *args, **kwargs, **self.static_kwargs)
        model_trace = handlers.trace(model_init).get_trace(
            *args, **kwargs, **self.static_kwargs)
        rng_key, particle_seed = jax.random.split(rng_key)
        # Draw initial per-particle values for the guide's parameters.
        guide_init_params = self._find_init_params(particle_seed, self.guide,
                                                   guide_trace)
        params = {}  # initial values in unconstrained space, keyed by site name
        transforms = {}  # constrained -> unconstrained, per site
        inv_transforms = {}  # unconstrained -> constrained, per site
        particle_transforms = {}  # optional per-site particle reparameterization
        guide_param_names = set()
        should_enum = False
        # Discrete latent sample sites in the model require enumeration: they
        # must support it AND enumeration must be enabled, otherwise fail fast.
        for site in model_trace.values():
            if ("fn" in site and site["type"] == "sample"
                    and not site["is_observed"]
                    and isinstance(site["fn"], Distribution)
                    and site["fn"].is_discrete):
                if site["fn"].has_enumerate_support and self.enum:
                    should_enum = True
                else:
                    raise Exception(
                        "Cannot enumerate model with discrete variables without enumerate support"
                    )
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in chain(model_trace.values(), guide_trace.values()):
            if site["type"] == "param":
                transform = get_parameter_transform(site)
                inv_transforms[site["name"]] = transform
                transforms[site["name"]] = transform.inv
                particle_transforms[site["name"]] = site.get(
                    "particle_transform", IdentityTransform())
                if site["name"] in guide_init_params:
                    pval, _ = guide_init_params[site["name"]]
                    # Classic (non-Stein) guide params keep a single value, so
                    # take the first particle's copy.
                    if self.classic_guide_params_fn(site["name"]):
                        pval = tree_map(lambda x: x[0], pval)
                else:
                    pval = site["value"]
                # Store the initial value in unconstrained space.
                params[site["name"]] = transform.inv(pval)
                if site["name"] in guide_trace:
                    guide_param_names.add(site["name"])

        if should_enum:
            # Wrap the model for enumeration one level beyond the deepest
            # plate nesting observed in the trace.
            mpn = _guess_max_plate_nesting(model_trace)
            self._inference_model = enum(config_enumerate(self.model),
                                         -mpn - 1)
        self.guide_param_names = guide_param_names
        self.constrain_fn = partial(transform_fn, inv_transforms)
        self.uconstrain_fn = partial(transform_fn, transforms)
        self.particle_transforms = particle_transforms
        self.particle_transform_fn = partial(transform_fn, particle_transforms)
        # Ravel the particle-initialized guide params with one leading batch
        # (particle) dimension; only the resulting shape is used, to size the
        # Stein kernel.
        stein_particles, _, _ = batch_ravel_pytree(
            {
                k: params[k]
                for k, site in guide_trace.items() if site["type"] == "param"
                and site["name"] in guide_init_params
            },
            nbatch_dims=1,
        )

        self.kernel_fn.init(kernel_seed, stein_particles.shape)
        return SteinVIState(self.optim.init(params), rng_key)
# Example #3
    def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
        """
        Compute the Stein variational loss and the SVGD update directions
        ("gradients") for one step.

        :param rng_key: random key used when evaluating the loss.
        :param dict unconstr_params: current parameter values, in
            unconstrained space.
        :return: tuple of the scalar loss (mean over particles, negated back
            to a minimization objective) and a pytree of gradients matching
            ``unconstr_params``.
        """
        # 0. Separate model and guide parameters, since only guide parameters are updated using Stein
        classic_uparams = {
            p: v
            for p, v in unconstr_params.items() if
            p not in self.guide_param_names or self.classic_guide_params_fn(p)
        }
        stein_uparams = {
            p: v
            for p, v in unconstr_params.items() if p not in classic_uparams
        }
        # 1. Collect each guide parameter into monolithic particles that capture correlations
        # between parameter values across each individual particle
        stein_particles, unravel_pytree, unravel_pytree_batched = batch_ravel_pytree(
            stein_uparams, nbatch_dims=1)
        particle_info = self._calc_particle_info(stein_uparams,
                                                 stein_particles.shape[0])

        # 2. Calculate loss and gradients for each parameter
        def scaled_loss(rng_key, classic_params, stein_params):
            # Negated so that a larger return value means a better fit
            # (the forces below ascend this quantity).
            params = {**classic_params, **stein_params}
            loss_val = self.loss.loss(
                rng_key,
                params,
                handlers.scale(self._inference_model, self.loss_temperature),
                self.guide,
                *args,
                **kwargs,
            )
            return -loss_val

        def kernel_particle_loss_fn(ps):
            # Loss of a single flat particle `ps`, with the classic params
            # held fixed; both are mapped back to constrained space first.
            return scaled_loss(
                rng_key,
                self.constrain_fn(classic_uparams),
                self.constrain_fn(unravel_pytree(ps)),
            )

        def particle_transform_fn(particle):
            # Apply the per-site particle transforms to one flat particle and
            # re-ravel the result.
            params = unravel_pytree(particle)

            tparams = self.particle_transform_fn(params)
            tparticle, _ = ravel_pytree(tparams)
            return tparticle

        tstein_particles = jax.vmap(particle_transform_fn)(stein_particles)

        # Per-particle loss values and log-joint gradients, evaluated at the
        # transformed particles.
        loss, particle_ljp_grads = jax.vmap(
            jax.value_and_grad(kernel_particle_loss_fn))(tstein_particles)
        # Classic parameters get ordinary gradients, averaged over particles.
        classic_param_grads = jax.vmap(
            lambda ps: jax.grad(lambda cps: scaled_loss(
                rng_key,
                self.constrain_fn(cps),
                self.constrain_fn(unravel_pytree(ps)),
            ))(classic_uparams))(stein_particles)
        classic_param_grads = tree_map(partial(jnp.mean, axis=0),
                                       classic_param_grads)

        # 3. Calculate kernel on monolithic particle
        kernel = self.kernel_fn.compute(stein_particles, particle_info,
                                        kernel_particle_loss_fn)

        # 4. Calculate the attractive force and repulsive force on the monolithic particles
        attractive_force = jax.vmap(lambda y: jnp.sum(
            jax.vmap(lambda x, x_ljp_grad: self._apply_kernel(
                kernel, x, y, x_ljp_grad))
            (tstein_particles, particle_ljp_grads),
            axis=0,
        ))(tstein_particles)
        repulsive_force = jax.vmap(lambda y: jnp.sum(
            jax.vmap(lambda x: self.repulsion_temperature * self._kernel_grad(
                kernel, x, y))(tstein_particles),
            axis=0,
        ))(tstein_particles)

        def single_particle_grad(particle, att_forces, rep_forces):
            # Jacobian of the inverse particle transform, per parameter leaf.
            reparam_jac = {
                k: jax.tree_map(
                    lambda variable: jax.jacfwd(self.particle_transforms[k].inv
                                                )(variable),
                    variables,
                )
                for k, variables in unravel_pytree(particle).items()
            }
            # Chain rule: push the combined (attractive + repulsive) forces
            # through the reparameterization Jacobian for each leaf.
            jac_params = jax.tree_multimap(
                lambda af, rf, rjac:
                ((af.reshape(-1) + rf.reshape(-1)) @ rjac.reshape(
                    (_numel(rjac.shape[:len(rjac.shape) // 2]), -1))).reshape(
                        rf.shape),
                unravel_pytree(att_forces),
                unravel_pytree(rep_forces),
                reparam_jac,
            )
            jac_particle, _ = ravel_pytree(jac_params)
            return jac_particle

        particle_grads = (jax.vmap(single_particle_grad)(
            stein_particles, attractive_force, repulsive_force) /
                          self.num_particles)

        # 5. Decompose the monolithic particle forces back to concrete parameter values
        stein_param_grads = unravel_pytree_batched(particle_grads)

        # 6. Return loss and gradients (based on parameter forces)
        # Negated because the optimizer minimizes while the forces ascend.
        res_grads = tree_map(lambda x: -x, {
            **classic_param_grads,
            **stein_param_grads
        })
        return -jnp.mean(loss), res_grads