Example #1
 def _get_cached_fn(self):
     if self._jit_model_args:
         args, kwargs = (None, ), (None, )
     else:
         args = tree_map(lambda x: _hashable(x), self._args)
         kwargs = tree_map(lambda x: _hashable(x),
                           tuple(sorted(self._kwargs.items())))
     key = args + kwargs
     try:
         fn = self._cache.get(key, None)
     # If unhashable arguments are provided, proceed normally
     # without caching
     except TypeError:
         fn, key = None, None
     if fn is None:
         if self._jit_model_args:
             fn = partial(_sample_fn_jit_args, sampler=self.sampler)
         else:
             fn = partial(_sample_fn_nojit_args,
                          sampler=self.sampler,
                          args=self._args,
                          kwargs=self._kwargs)
         if key is not None:
             self._cache[key] = fn
     return fn
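
A note on the pattern above: the cache key is built from a hashable view of the traced arguments, and the TypeError fallback simply skips caching when a leaf cannot be hashed. Below is a minimal, self-contained sketch of the same idea; the _to_hashable helper is a hypothetical stand-in, not the example's own _hashable.

from functools import partial
import numpy as np
from jax import tree_util

def _to_hashable(x):
    # hypothetical helper: turn array leaves into bytes so they can live in a dict key
    return np.asarray(x).tobytes() if hasattr(x, "__array__") else x

class CachedSampler:
    def __init__(self, sample_fn):
        self._sample_fn = sample_fn
        self._cache = {}

    def get_fn(self, *args, **kwargs):
        try:
            key = (tree_util.tree_map(_to_hashable, args),
                   tuple(sorted((k, _to_hashable(v)) for k, v in kwargs.items())))
            hash(key)  # raises TypeError when some leaf is still unhashable
        except TypeError:
            return partial(self._sample_fn, *args, **kwargs)  # proceed without caching
        if key not in self._cache:
            self._cache[key] = partial(self._sample_fn, *args, **kwargs)
        return self._cache[key]
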
Example #2
def get_potential_fn(rng_key, model, dynamic_args=False, model_args=(), model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns a
    function which, given unconstrained parameters, evaluates the potential
    energy (negative log joint density). In addition, this returns a
    function to transform unconstrained values at sample sites to constrained
    values within their respective support.

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param bool dynamic_args: if `True`, `potential_fn` and
        `constrain_fn` are themselves dependent on the model arguments.
        When provided with `*model_args, **model_kwargs`, they return
        `potential_fn` and `constrain_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: tuple of (`potential_fn`, `constrain_fn`). The latter is used
        to constrain unconstrained samples (e.g. those returned by HMC)
        to values that lie within the site's support.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    seeded_model = seed(model, rng_key if rng_key.ndim == 1 else rng_key[0])
    model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
    inv_transforms = {}
    has_transformed_dist = False
    for k, v in model_trace.items():
        if v['type'] == 'sample' and not v['is_observed']:
            if v['intermediates']:
                inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                has_transformed_dist = True
            else:
                inv_transforms[k] = biject_to(v['fn'].support)
        elif v['type'] == 'param':
            constraint = v['kwargs'].pop('constraint', real)
            transform = biject_to(constraint)
            if isinstance(transform, ComposeTransform):
                inv_transforms[k] = transform.parts[0]
                has_transformed_dist = True
            else:
                inv_transforms[k] = transform

    if dynamic_args:
        def potential_fn(*args, **kwargs):
            return jax.partial(potential_energy, model, inv_transforms, args, kwargs)
        if has_transformed_dist:
            def constrain_fun(*args, **kwargs):
                return jax.partial(constrain_fn, model, inv_transforms, args, kwargs)
        else:
            def constrain_fun(*args, **kwargs):
                return jax.partial(transform_fn, inv_transforms)
    else:
        potential_fn = jax.partial(potential_energy, model, inv_transforms, model_args, model_kwargs)
        if has_transformed_dist:
            constrain_fun = jax.partial(constrain_fn, model, inv_transforms, model_args, model_kwargs)
        else:
            constrain_fun = jax.partial(transform_fn, inv_transforms)

    return potential_fn, constrain_fun
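
A note that applies to most of these examples: in the older JAX releases targeted here, jax.partial was simply a re-export of functools.partial, and it has since been removed from JAX's public API. functools.partial is the drop-in replacement, as in this small sketch:

from functools import partial

def scaled_add(scale, x, y):
    return scale * (x + y)

# jax.partial(scaled_add, 2.0) in old code is exactly partial(scaled_add, 2.0) today.
add_twice = partial(scaled_add, 2.0)
assert add_twice(1.0, 3.0) == 8.0
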
Example #3
 def constrain_fun(*args, **kwargs):
     inv_transforms, has_transformed_dist = get_model_transforms(
         rng_key, model, args, kwargs)
     if has_transformed_dist:
         return jax.partial(constrain_fn, model, inv_transforms, args,
                            kwargs)
     else:
         return jax.partial(transform_fn, inv_transforms)
Example #4
    def run(self, rng, *args, collect_fields=('z',), collect_warmup=False, init_params=None, **kwargs):
        """
        Run the MCMC samplers and collect samples.

        :param random.PRNGKey rng: Random number generator key to be used for the sampling.
        :param args: Arguments to be provided to the :meth:`numpyro.mcmc.MCMCKernel.init` method.
            These are typically the arguments needed by the `model`.
        :param collect_fields: Fields from :data:`numpyro.mcmc.HMCState` to collect
            during the MCMC run. By default, only the latent sample sites `z` is collected.
        :type collect_fields: tuple or list
        :param bool collect_warmup: Whether to collect samples from the warmup phase. Defaults
            to `False`.
        :param init_params: Initial parameters to begin sampling. The type must be consistent
            with the input type to `potential_fn`.
        :param kwargs: Keyword arguments to be provided to the :meth:`numpyro.mcmc.MCMCKernel.init`
            method. These are typically the keyword arguments needed by the `model`.
        """
        chain_method = self.chain_method
        if chain_method == 'parallel' and xla_bridge.device_count() < self.num_chains:
            chain_method = 'sequential'
            warnings.warn('There are not enough devices to run parallel chains: expected {} but got {}.'
                          ' Chains will be drawn sequentially. If you are running `mcmc` on CPU,'
                          ' consider disabling XLA intra-op parallelism by setting the environment'
                          ' flag "XLA_FLAGS=--xla_force_host_platform_device_count={}".'
                          .format(self.num_chains, xla_bridge.device_count(), self.num_chains))

        if init_params is not None and self.num_chains > 1:
            prototype_init_val = tree_flatten(init_params)[0][0]
            if np.shape(prototype_init_val)[0] != self.num_chains:
                raise ValueError('`init_params` must have the same leading dimension'
                                 ' as `num_chains`.')
        assert isinstance(collect_fields, (tuple, list))
        self._collect_fields = collect_fields
        if self.num_chains == 1:
            samples_flat = self._single_chain_mcmc((rng, init_params), collect_fields, collect_warmup,
                                                   args, kwargs)
            samples = tree_map(lambda x: x[np.newaxis, ...], samples_flat)
        else:
            rngs = random.split(rng, self.num_chains)
            partial_map_fn = partial(self._single_chain_mcmc,
                                     collect_fields=collect_fields,
                                     collect_warmup=collect_warmup,
                                     args=args,
                                     kwargs=kwargs)
            if chain_method == 'sequential':
                map_fn = partial(lax.map, partial_map_fn)
            elif chain_method == 'parallel':
                map_fn = pmap(partial_map_fn)
            elif chain_method == 'vectorized':
                map_fn = vmap(partial_map_fn)
            else:
                raise ValueError('Only supporting the following methods to draw chains:'
                                 ' "sequential", "parallel", or "vectorized"')
            samples = map_fn((rngs, init_params))
            samples_flat = tree_map(lambda x: np.reshape(x, (-1,) + x.shape[2:]), samples)
        self._samples = samples
        self._samples_flat = samples_flat
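
The chain_method dispatch above maps one and the same single-chain routine over per-chain PRNG keys with three different JAX primitives. A toy sketch of the three strategies, where run_chain is a stand-in for the real single-chain MCMC function:

import jax
from jax import lax, random

def run_chain(rng_key):
    # stand-in for one MCMC chain: just draw a few "samples"
    return random.normal(rng_key, (5,))

num_chains = 4
rngs = random.split(random.PRNGKey(0), num_chains)

sequential = lax.map(run_chain, rngs)    # chains one after another
vectorized = jax.vmap(run_chain)(rngs)   # chains batched on a single device
# parallel = jax.pmap(run_chain)(rngs)   # one chain per device; needs >= num_chains devices

assert sequential.shape == vectorized.shape == (num_chains, 5)
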
Example #5
def InitRandomRotation(rng: PRNGKey, jitted=False):
    # create marginal functions
    key = rng
    f = jax.partial(
        get_random_rotation,
        return_params=True,
    )

    f_slim = jax.partial(
        get_random_rotation,
        return_params=False,
    )

    if jitted:
        f = jax.jit(f)
        f_slim = jax.jit(f_slim)

    def init_params(inputs, **kwargs):

        key_, rng = kwargs.get("rng", jax.random.split(key, 2))

        _, params = f(rng, inputs)
        return params

    def params_and_transform(inputs, **kwargs):

        key_, rng = kwargs.get("rng", jax.random.split(key, 2))

        outputs, params = f(rng, inputs)
        return outputs, params

    def transform(inputs, **kwargs):

        key_, rng = kwargs.get("rng", jax.random.split(key, 2))

        outputs = f_slim(rng, inputs)
        return outputs

    def bijector(inputs, **kwargs):
        params = init_params(inputs, **kwargs)
        bijector = Rotation(rotation=params.rotation, )
        return bijector

    def bijector_and_transform(inputs, **kwargs):
        outputs, params = params_and_transform(inputs, **kwargs)
        bijector = Rotation(rotation=params.rotation, )
        return outputs, bijector

    return InitLayersFunctions(
        bijector=bijector,
        bijector_and_transform=bijector_and_transform,
        transform=transform,
        params=init_params,
        params_and_transform=params_and_transform,
    )
Example #6
def get_optimizer(model, hyperparams: ServerHyperParams):
    ffgb = FedAlgorithm(
        sampler=sampler,
        server_init=jax.partial(server_init, model, hyperparams),
        client_init=jax.jit(jax.partial(client_init, model, hyperparams)),
        client_step=jax.partial(client_step, model, hyperparams),
        client_end=None,
        server_step=jax.partial(server_step, model, hyperparams))

    return ffgb
Example #7
def get_potential_fn(rng_key,
                     model,
                     dynamic_args=False,
                     model_args=(),
                     model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns a
    function which, given unconstrained parameters, evaluates the potential
    energy (negative log joint density). In addition, this returns a
    function to transform unconstrained values at sample sites to constrained
    values within their respective support.

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param bool dynamic_args: if `True`, `potential_fn` and
        `constrain_fn` are themselves dependent on the model arguments.
        When provided with `*model_args, **model_kwargs`, they return
        `potential_fn` and `constrain_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: tuple of (`potential_fn`, `constrain_fn`). The latter is used
        to constrain unconstrained samples (e.g. those returned by HMC)
        to values that lie within the site's support.
    """
    if dynamic_args:

        def potential_fn(*args, **kwargs):
            inv_transforms, has_transformed_dist = get_model_transforms(
                rng_key, model, args, kwargs)
            return jax.partial(potential_energy, model, inv_transforms, args,
                               kwargs)

        def constrain_fun(*args, **kwargs):
            inv_transforms, has_transformed_dist = get_model_transforms(
                rng_key, model, args, kwargs)
            if has_transformed_dist:
                return jax.partial(constrain_fn, model, inv_transforms, args,
                                   kwargs)
            else:
                return jax.partial(transform_fn, inv_transforms)
    else:
        inv_transforms, has_transformed_dist = get_model_transforms(
            rng_key, model, model_args, model_kwargs)
        potential_fn = jax.partial(potential_energy, model, inv_transforms,
                                   model_args, model_kwargs)
        if has_transformed_dist:
            constrain_fun = jax.partial(constrain_fn, model, inv_transforms,
                                        model_args, model_kwargs)
        else:
            constrain_fun = jax.partial(transform_fn, inv_transforms)

    return potential_fn, constrain_fun
Example #8
 def postprocess_fn(*args, **kwargs):
     inv_transforms, replay_model = get_model_transforms(
         rng_key, model, args, kwargs)
     if replay_model:
         return jax.partial(constrain_fn,
                            model,
                            inv_transforms,
                            args,
                            kwargs,
                            return_deterministic=True)
     else:
         return jax.partial(transform_fn, inv_transforms)
Example #9
def main():

    data = get_classic(100_000)

    # plot data
    plot_joint(
        data[:1_000],
        "blue",
        "Original Data",
        kind="kde",
        save_name=str(Path(FIG_PATH).joinpath("joint_data.png")),
    )

    # define marginal entropy function
    entropy_f = jax.partial(histogram_entropy, nbins=1_000, base=2)

    # define marginal uniformization function
    hist_transform_f = jax.partial(histogram_transform, nbins=1_000)

    n_iterations = 100

    X_trans, loss = total_corr_f(
        np.array(data).block_until_ready(),
        marginal_uni=hist_transform_f,
        marginal_entropy=entropy_f,
        n_iterations=n_iterations,
    )

    total_corr = np.sum(loss) * np.log(2)

    plot_info_loss(
        loss,
        n_layers=len(loss),
        save_name=str(Path(FIG_PATH).joinpath("info_loss.png")),
    )

    print(f"Total Correlation: {total_corr}")

    X_plot = onp.array(X_trans)

    plot_joint(
        X_plot[:10_000],
        "blue",
        "Latent Space",
        kind="kde",
        save_name=str(Path(FIG_PATH).joinpath("joint_latent.png")),
    )

Example #10
def e_step(
    rng: PRNGSequence,
    actor_target_params: FrozenDict,
    critic_target_params: FrozenDict,
    max_action: float,
    action_dim: int,
    temp: float,
    eps_eta: float,
    state: jnp.ndarray,
    batch_size: int,
    action_sample_size: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    The 'E-step' from the MPO paper. We calculate our weights, which correspond
    to the relative likelihood of obtaining the maximum reward for each of the
    sampled actions. We also take steps on our temperature parameter, which
    induces diversity in the weights.
    """
    Q1, sampled_actions = sample_actions_and_evaluate(
        rng,
        actor_target_params,
        critic_target_params,
        max_action,
        action_dim,
        state,
        batch_size,
        action_sample_size,
    )

    jac = jax.grad(dual, argnums=2)
    jac = jax.partial(jac, Q1, eps_eta)

    # use nonconvex optimizer to minimize the dual of the temperature parameter
    # we have direct access to the jacobian function with jax so we can take
    # advantage of it here
    this_dual = jax.partial(dual, Q1, eps_eta)
    bounds = [(1e-6, None)]
    res = minimize(this_dual, temp, jac=jac, method="SLSQP", bounds=bounds)
    temp = jax.lax.stop_gradient(res.x)

    # calculate the sample-based q distribution. we can think of these weights
    # as the relative likelihood of each of the sampled actions giving us the
    # maximum score when taken at the corresponding state.
    weights = jax.nn.softmax(Q1 / temp, axis=1)
    weights = jax.lax.stop_gradient(weights)
    weights = jnp.expand_dims(weights, axis=-1)

    return temp, weights, sampled_actions
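
The dual minimization above feeds a JAX-computed gradient directly into SciPy's optimizer as the jac argument. A minimal sketch of that pattern on a toy convex objective (the objective here is illustrative, not the MPO dual):

import jax
import jax.numpy as jnp
import numpy as onp
from scipy.optimize import minimize

def objective(x):
    return jnp.sum((x - 3.0) ** 2)

# SciPy expects plain NumPy values, so wrap the JAX function and its gradient.
fun = lambda x: float(objective(x))
jac = lambda x: onp.asarray(jax.grad(objective)(x), dtype=onp.float64)

res = minimize(fun, x0=onp.zeros(2), jac=jac, method="SLSQP", bounds=[(1e-6, None)] * 2)
print(res.x)  # close to [3., 3.]
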
Example #11
    def init_kernel(init_samples,
                    num_warmup_steps,
                    step_size=1.0,
                    num_steps=None,
                    adapt_step_size=True,
                    adapt_mass_matrix=True,
                    diag_mass=True,
                    target_accept_prob=0.8,
                    run_warmup=True,
                    rng=PRNGKey(0)):
        step_size = float(step_size)
        nonlocal trajectory_length, momentum_generator, wa_update

        if num_steps is None:
            trajectory_length = 2 * math.pi
        else:
            trajectory_length = num_steps * step_size

        z = init_samples
        z_flat, unravel_fn = ravel_pytree(z)
        momentum_generator = partial(_sample_momentum, unravel_fn)

        find_reasonable_ss = partial(find_reasonable_step_size, potential_fn,
                                     kinetic_fn, momentum_generator)

        wa_init, wa_update = warmup_adapter(
            num_warmup_steps,
            find_reasonable_step_size=find_reasonable_ss,
            adapt_step_size=adapt_step_size,
            adapt_mass_matrix=adapt_mass_matrix,
            diag_mass=diag_mass,
            target_accept_prob=target_accept_prob)

        rng_hmc, rng_wa = random.split(rng)
        wa_state = wa_init(z, rng_wa, mass_matrix_size=np.size(z_flat))
        r = momentum_generator(wa_state.inverse_mass_matrix, rng)
        vv_state = vv_init(z, r)
        hmc_state = HMCState(vv_state.z, vv_state.z_grad,
                             vv_state.potential_energy, 0, 0.,
                             wa_state.step_size, wa_state.inverse_mass_matrix,
                             rng_hmc)

        if run_warmup:
            hmc_state, _ = fori_loop(0, num_warmup_steps, warmup_update,
                                     (hmc_state, wa_state))
            return hmc_state
        else:
            return hmc_state, wa_state, warmup_update
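
The warmup phase above runs a fixed number of adaptation steps with fori_loop, whose body receives the iteration index plus the carried state. A toy sketch of that control-flow primitive:

import jax.numpy as jnp
from jax.lax import fori_loop

def body(i, carry):
    total, count = carry
    return total + i, count + 1

total, count = fori_loop(0, 10, body, (jnp.asarray(0), jnp.asarray(0)))
assert int(total) == 45 and int(count) == 10
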
Example #12
def mixture_gaussian_invcdf(
    x: JaxArray, prior_logits: JaxArray, means: JaxArray, scales: JaxArray
) -> JaxArray:
    """
    Args:
        x (JaxArray): input vector
            (D,)
        prior_logits (JaxArray): prior logits to weight the components
            (D, K)
        means (JaxArray): means per component per feature
            (D, K)
        scales (JaxArray): scales per component per feature
            (D, K)
    Returns:
        x_invcdf (JaxArray) : inverse CDF of the mixture distribution evaluated at x
    """
    # INITIALIZE BOUNDS
    init_lb = np.ones_like(means).max(axis=1) - 1_000.0
    init_ub = np.ones_like(means).max(axis=1) + 1_000.0

    # INITIALIZE FUNCTION
    f = jax.partial(
        mixture_gaussian_cdf, prior_logits=prior_logits, means=means, scales=scales,
    )

    return bisection_search(f, x, init_lb, init_ub)
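
bisection_search above inverts the monotone mixture CDF numerically. A small illustrative bisection for a monotone elementwise map, run for a fixed number of halvings (this is a sketch, not the project's bisection_search):

import jax.numpy as jnp
from jax.scipy.stats import norm

def bisection_invert(f, y, lower, upper, n_iters=60):
    # Solve f(x) = y elementwise for a monotonically increasing f.
    for _ in range(n_iters):
        mid = 0.5 * (lower + upper)
        too_low = f(mid) < y
        lower = jnp.where(too_low, mid, lower)
        upper = jnp.where(too_low, upper, mid)
    return 0.5 * (lower + upper)

# Example: invert the standard normal CDF at a few quantiles.
x = bisection_invert(norm.cdf, jnp.array([0.1, 0.5, 0.9]),
                     lower=-10.0 * jnp.ones(3), upper=10.0 * jnp.ones(3))
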
Example #13
def _call_init(primitive, rng, submodule_params, params, jaxpr, consts,
               freevar_vals, in_vals, **kwargs):
    jaxpr, = jaxpr
    consts, = consts
    freevar_vals, = freevar_vals
    f = lu.wrap_init(partial(jc.eval_jaxpr, jaxpr, consts, freevar_vals))
    return primitive.bind(f, *in_vals, **params), submodule_params
Example #14
    def body_fn(state):
        i, key, _, _ = state
        key, subkey = random.split(key)

        # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
        # Use `block` to not record sample primitives in `init_loc_fn`.
        seeded_model = substitute(model, substitute_fn=block(seed(init_strategy, subkey)))
        model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        for k, v in model_trace.items():
            if v['type'] == 'sample' and not v['is_observed']:
                if v['intermediates']:
                    constrained_values[k] = v['intermediates'][0][0]
                    inv_transforms[k] = biject_to(v['fn'].base_dist.support)
                else:
                    constrained_values[k] = v['value']
                    inv_transforms[k] = biject_to(v['fn'].support)
            elif v['type'] == 'param' and param_as_improper:
                constraint = v['kwargs'].pop('constraint', real)
                transform = biject_to(constraint)
                if isinstance(transform, ComposeTransform):
                    base_transform = transform.parts[0]
                    inv_transforms[k] = base_transform
                    constrained_values[k] = base_transform(transform.inv(v['value']))
                else:
                    inv_transforms[k] = transform
                    constrained_values[k] = v['value']
        params = transform_fn(inv_transforms,
                              {k: v for k, v in constrained_values.items()},
                              invert=True)
        potential_fn = jax.partial(potential_energy, model, inv_transforms, model_args, model_kwargs)
        pe, param_grads = value_and_grad(potential_fn)(params)
        z_grad = ravel_pytree(param_grads)[0]
        is_valid = np.isfinite(pe) & np.all(np.isfinite(z_grad))
        return i + 1, key, params, is_valid
Example #15
 def process_call(self, call_primitive, f, tracers, params):
     flat_inputs, submodule_params_iter = ApplyTrace.Tracer.merge(tracers)
     f = ApplyTrace._apply_subtrace(f, self.master,
                                    WrapHashably(submodule_params_iter))
     flat_outs = call_primitive.bind(f, *flat_inputs, **params)
     return map(partial(ApplyTrace.Tracer, self, submodule_params_iter),
                flat_outs)
Example #16
File: sample.py Project: rlouf/mcx
def build_loglikelihoods(model, args, observations):
    """Function to compute the loglikelihood contribution
    of each variable.
    """
    loglikelihoods = jax.partial(mcx.log_prob_contributions(model),
                                 **observations, **args)
    return loglikelihoods
Example #17
    def init(self, rng_key, *args, **kwargs):
        """

        :param jax.random.PRNGKey rng_key: random number generator seed.
        :param args: arguments to the model / guide (these can possibly vary during
            the course of fitting).
        :param kwargs: keyword arguments to the model / guide (these can possibly vary
            during the course of fitting).
        :return: tuple containing initial :data:`SVIState`, and `get_params`, a callable
            that transforms unconstrained parameter values from the optimizer to the
            specified constrained domain
        """
        rng_key, model_seed, guide_seed = random.split(rng_key, 3)
        model_init = seed(self.model, model_seed)
        guide_init = seed(self.guide, guide_seed)
        guide_trace = trace(guide_init).get_trace(*args, **kwargs,
                                                  **self.static_kwargs)
        model_trace = trace(model_init).get_trace(*args, **kwargs,
                                                  **self.static_kwargs)
        params = {}
        inv_transforms = {}
        # NB: params in model_trace will be overwritten by params in guide_trace
        for site in list(model_trace.values()) + list(guide_trace.values()):
            if site['type'] == 'param':
                constraint = site['kwargs'].pop('constraint', constraints.real)
                transform = biject_to(constraint)
                inv_transforms[site['name']] = transform
                params[site['name']] = transform.inv(site['value'])

        self.constrain_fn = jax.partial(transform_fn, inv_transforms)
        return SVIState(self.optim.init(params), rng_key)
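
The inv_transforms collected above come from biject_to, which maps a constraint to a bijection from unconstrained space onto that constraint's support; its inverse goes the other way, which is why parameters are stored as transform.inv(site['value']). A small round-trip sketch, assuming NumPyro is installed:

import jax.numpy as jnp
from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to

transform = biject_to(constraints.positive)     # bijection from R onto (0, inf)
unconstrained = transform.inv(jnp.array(2.5))   # what the optimizer actually updates
constrained = transform(unconstrained)          # back on the support, ~2.5
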
Example #18
def initialize_model(rng,
                     model,
                     *model_args,
                     init_strategy='uniform',
                     **model_kwargs):
    """
    Given a model with Pyro primitives, returns a function which, given
    unconstrained parameters, evaluates the potential energy (negative
    joint density). In addition, this also returns initial parameters
    sampled from the prior to initiate MCMC sampling and functions to
    transform unconstrained values at sample sites to constrained values
    within their respective support.

    :param jax.random.PRNGKey rng: random number generator seed to
        sample from the prior.
    :param model: Python callable containing Pyro primitives.
    :param `*model_args`: args provided to the model.
    :param str init_strategy: initialization strategy - `uniform`
        initializes the unconstrained parameters by drawing from
        a `Uniform(-2, 2)` distribution (as used by Stan), whereas
        `prior` initializes the parameters by sampling from the prior
        for each of the sample sites.
    :param `**model_kwargs`: kwargs provided to the model.
    :return: tuple of (`init_params`, `potential_fn`, `constrain_fn`)
        `init_params` are values from the prior used to initiate MCMC.
        `constrain_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support.
    """
    model = seed(model, rng)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    sample_sites = {
        k: v
        for k, v in model_trace.items()
        if v['type'] == 'sample' and not v['is_observed']
    }
    inv_transforms = {
        k: biject_to(v['fn'].support)
        for k, v in sample_sites.items()
    }
    prior_params = constrain_fn(
        inv_transforms, {k: v['value']
                         for k, v in sample_sites.items()},
        invert=True)
    if init_strategy == 'uniform':
        init_params = {}
        for k, v in prior_params.items():
            rng, = random.split(rng, 1)
            init_params[k] = random.uniform(rng,
                                            shape=np.shape(v),
                                            minval=-2,
                                            maxval=2)
    elif init_strategy == 'prior':
        init_params = prior_params
    else:
        raise ValueError(
            'initialize={} is not a valid initialization strategy.'.format(
                init_strategy))
    return init_params, potential_energy(model, model_args, model_kwargs, inv_transforms), \
        jax.partial(constrain_fn, inv_transforms)
Example #19
def module(name, nn, input_shape=None):
    """
    Declare a :mod:`~jax.experimental.stax` style neural network inside a
    model so that its parameters are registered for optimization via
    :func:`~numpyro.primitives.param` statements.

    :param str name: name of the module to be registered.
    :param tuple nn: a tuple of `(init_fn, apply_fn)` obtained by a :mod:`~jax.experimental.stax`
        constructor function.
    :param tuple input_shape: shape of the input taken by the
        neural network.
    :return: a `apply_fn` with bound parameters that takes an array
        as an input and returns the neural network transformed output
        array.
    """
    module_key = name + '$params'
    nn_init, nn_apply = nn
    nn_params = param(module_key)
    if nn_params is None:
        if input_shape is None:
            raise ValueError('Valid value for `input_shape` needed to initialize.')
        rng = numpyro.sample(name + '$rng', PRNGIdentity())
        _, nn_params = nn_init(rng, input_shape)
        param(module_key, nn_params)
    return jax.partial(nn_apply, nn_params)
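
A hedged usage sketch of the module primitive above inside a model, with the stax constructors it is designed for; the model and shapes here are made up for illustration:

import numpyro
import numpyro.distributions as dist
from jax.experimental import stax  # jax.example_libraries.stax in newer JAX

def model(x, y=None):
    net = numpyro.module('net',
                         stax.serial(stax.Dense(8), stax.Relu, stax.Dense(1)),
                         input_shape=(-1, x.shape[-1]))
    mean = net(x).squeeze(-1)
    numpyro.sample('obs', dist.Normal(mean, 1.0), obs=y)
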
Example #20
    def __init__(self, nin, nclass, scales, filters, filters_max,
                 pooling=objax.functional.max_pool_2d, **kwargs):
        """Creates ConvNet instance.

        Args:
            nin: number of channels in the input image.
            nclass: number of output classes.
            scales: number of pooling layers, each of which reduces spatial dimension by 2.
            filters: base number of convolution filters.
                     The number of convolution filters is doubled at every scale until it reaches filters_max.
            filters_max: maximum number of filters.
            pooling: type of pooling layer.
        """
        del kwargs

        def nf(scale):
            return min(filters_max, filters << scale)

        ops = [objax.nn.Conv2D(nin, nf(0), 3), objax.functional.leaky_relu]
        for i in range(scales):
            ops.extend([objax.nn.Conv2D(nf(i), nf(i), 3), objax.functional.leaky_relu,
                        objax.nn.Conv2D(nf(i), nf(i + 1), 3), objax.functional.leaky_relu,
                        jax.partial(pooling, size=2, strides=2)])
        ops.extend([objax.nn.Conv2D(nf(scales), nclass, 3), self._mean_reduce])
        super().__init__(ops)
Example #21
    def solve(
        self,
        graph: "StackedFactorGraph",
        initial_assignments: VariableAssignments,
    ) -> VariableAssignments:
        """Run MAP inference on a factor graph."""

        # Initialize
        assignments = initial_assignments
        cost, residual_vector = graph.compute_cost(assignments)
        state = self._initialize_state(graph, initial_assignments)

        # Optimization
        state = jax.lax.while_loop(
            cond_fun=lambda state: jnp.logical_and(
                jnp.logical_not(state.done), state.iterations < self.
                max_iterations),
            body_fun=jax.partial(self._step, graph),
            init_val=state,
        )

        self._hcb_print(
            lambda i, max_i, cost:
            f"Terminated @ iteration #{i}/{max_i}: cost={str(cost).ljust(15)}",
            i=state.iterations,
            max_i=self.max_iterations,
            cost=state.cost,
        )

        return state.assignments
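
jax.lax.while_loop above carries the solver state until either the convergence flag flips or the iteration cap is reached. A toy sketch of the same primitive, here running a Newton iteration for a square root:

import jax.numpy as jnp
from jax import lax

def newton_sqrt(y, tol=1e-6, max_iters=50):
    def cond_fun(state):
        x, it = state
        return jnp.logical_and(jnp.abs(x * x - y) > tol, it < max_iters)

    def body_fun(state):
        x, it = state
        return 0.5 * (x + y / x), it + 1

    init = (jnp.asarray(y, dtype=jnp.float32), jnp.asarray(0))
    x, _ = lax.while_loop(cond_fun, body_fun, init)
    return x

assert abs(float(newton_sqrt(2.0)) - 2.0 ** 0.5) < 1e-4
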
Example #22
 def batched(*batched_args):
     args = tree_map(lambda x: x[0], batched_args)
     params = Parameter(
         lambda key: unbatched_model.init_parameters(*args, key=key),
         'model')()
     batched_apply = vmap(partial(unbatched_model.apply, params), batch_dim)
     return batched_apply(*batched_args)
Example #23
    def log_likelihood_grad_bias(self, data, reward_model, bias_params):
        om, empirical_om, _ = self._ll_compute_oms(data, reward_model,
                                                   bias_params)

        def blind_reward(biases, obs_matrix):
            """Compute blind reward for all states in such a way that Jax can
            differentiate with respect to the bias/masking vector. This is
            trivial for linear rewards, but harder for more general
            RewardModels."""
            blind_obs_mat = obs_matrix * biases
            assert blind_obs_mat.shape == obs_matrix.shape
            return reward_model.out(blind_obs_mat)

        # compute gradient of reward in each state w.r.t. biases
        # (we do this separately for each input)
        blind_rew_grad_fn = jax.grad(blind_reward)
        lifted_blind_rew_grad_fn = jax.vmap(
            jax.partial(blind_rew_grad_fn, bias_params))
        lifted_grads = lifted_blind_rew_grad_fn(self.env.observation_matrix)

        empirical_grad_term = jnp.sum(
            empirical_om[:, None] * lifted_grads, axis=0)
        pi_grad_term = jnp.sum(om[:, None] * lifted_grads, axis=0)
        grads = empirical_grad_term - pi_grad_term

        return grads
Example #24
def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
                 batch_size):
    """Return differentially private gradients for params, evaluated on batch."""
    def _clipped_grad(params, single_example_batch):
        """Evaluate gradient for a single-example batch and clip its grad norm."""
        grads = grad_loss(params, single_example_batch)

        nonempty_grads, tree_def = tree_util.tree_flatten(grads)
        total_grad_norm = np.linalg.norm(
            [np.linalg.norm(neg.ravel()) for neg in nonempty_grads])
        divisor = stop_gradient(np.amax((total_grad_norm / l2_norm_clip, 1.)))
        normalized_nonempty_grads = [g / divisor for g in nonempty_grads]
        return tree_util.tree_unflatten(tree_def, normalized_nonempty_grads)

    px_clipped_grad_fn = vmap(partial(_clipped_grad, params))
    std_dev = l2_norm_clip * noise_multiplier
    noise_ = lambda n: n + std_dev * random.normal(rng, n.shape)
    normalize_ = lambda n: n / float(batch_size)
    tree_map = tree_util.tree_map
    sum_ = lambda n: np.sum(n, 0)  # aggregate
    aggregated_clipped_grads = tree_map(sum_, px_clipped_grad_fn(batch))
    noised_aggregated_clipped_grads = tree_map(noise_,
                                               aggregated_clipped_grads)
    normalized_noised_aggregated_clipped_grads = (tree_map(
        normalize_, noised_aggregated_clipped_grads))
    return normalized_noised_aggregated_clipped_grads
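
The per-example clipping above hinges on vmap-ing a single-example gradient over the batch axis while keeping the parameters fixed. The core of that pattern in isolation, with a toy linear model and the clipping/noising steps omitted:

import jax
import jax.numpy as jnp

def loss(params, example):
    x, y = example
    return (jnp.dot(x, params) - y) ** 2

params = jnp.zeros(3)
xs = jnp.ones((8, 3))
ys = jnp.arange(8.0)

# One gradient per example: map over the batch axis of `example`, not of `params`.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0))(params, (xs, ys))
assert per_example_grads.shape == (8, 3)
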
Example #25
    def fit_transform(self, X):

        self.n_features = X.shape[1]

        # initialize parameter storage
        params = []
        losses = []
        i_layer = 0

        # loop through
        while i_layer < self.max_layers:

            loss = jax.partial(self.loss_f, X=X)

            # fix info criteria
            X, block_params = self.block_forward(X)

            info_red = loss(Y=X)

            # append Parameters
            params.append(block_params)
            losses.append(info_red)

            i_layer += 1

        self.n_layers = i_layer
        self.params = params
        self.info_loss = np.array(losses)
        return X
Example #26
 def grads(self, inputs):
     in_grad_partial = jax.partial(self._net_grads, self._net_params)
     grad_vmap = jax.vmap(in_grad_partial)
     rich_grads = grad_vmap(inputs)
     flat_grads = np.asarray(self._flatten_batch(rich_grads))
     assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
     return flat_grads
Example #27
def _parse_nn_pepo_obc(C, D, vL, vR, vB, vT, lx, ly):
    assert (lx > 2) and (ly > 2)  # otherwise there is no bulk to put the Ds in
    x_d = lx // 2
    y_d = ly // 2

    # HORIZONTAL
    vL_C = np.tensordot(vL, C, [0, 2])  # (p,p*,r)
    C_vR = np.tensordot(vR, C, [0, 3])  # (p,p*,l)
    vB_D = np.tensordot(vB, D, [0, 4])  # (p,p*,l,r,u)
    D_vT = np.tensordot(vT, D, [0, 5])  # (p,p*,l,r,d)

    left_col = [
        vL_C[:, :, :, None]
    ] + [vL_C[:, :, None, :, None]] * (ly - 2) + [vL_C[:, :, None, :]]

    # bottom C:  (p,p*,i,j) = (p,p*,l,r) -> (p,p*,r,l) -> (p,p*,r,u,l)
    # bulk C: (p,p*,i,j) = (p,p*,l,r) -> (p,p*,u,l,d,r)
    # top C: (p,p*,i,j) = (p,p*,l,r) -> (p,p*,l,d,r)
    mid_col = [np.transpose(C, [0, 1, 3, 2])[:, :, :, None, :]] \
              + [C[:, :, None, :, None, :]] * (ly - 2) \
              + [C[:, :, :, None, :]]

    # vB_D: (p,p*,ijl) = (p,p*,lru) -> (p,p*,rul)
    # D: (p,p*,ijkl) -> (p,p*,likj) = (p,p*,uldr)
    # D_vT: (p,p*,ijk) = (p,p*,lrd) -> (p,p*,ldr)
    d_col = [np.transpose(vB_D, [0, 1, 3, 4, 2])] \
            + [np.transpose(D, [0, 1, 5, 2, 4, 3])] * (ly - 2) \
            + [np.transpose(D_vT, [0, 1, 2, 4, 3])]

    right_col = [
        C_vR[:, :, None, :]
    ] + [C_vR[:, :, None, :, None]] * (ly - 2) + [C_vR[:, :, :, None]]
    tensors = [left_col] \
              + [mid_col] * (x_d - 1) \
              + [d_col] \
              + [mid_col] * (lx - x_d - 2) \
              + [right_col]
    pepo_hor = Pepo(
        tensors, OBC, False
    )  # even if the NnPepo is hermitian, the two separate Pepos could be not.

    # VERTICAL
    # rotate tensors clockwise

    # (p,p*,u,l,d,r) -> (p,p*,l,d,r,u)
    _rotate90 = partial(np.transpose, axes=[0, 1, 3, 4, 5, 2])

    # tensor at new location (x,y) was at (-y-1,x) before
    tensors = [[tensors[-y - 1][0] for y in range(ly)]] \
              + [[tensors[-1][x]] + [_rotate90(tensors[-y - 1][x]) for y in range(1, ly - 1)]
                 + [tensors[0][x]] for x in range(1, lx - 1)] \
              + [[tensors[-y - 1][-1] for y in range(ly)]]

    pepo_vert = Pepo(
        tensors, OBC, False
    )  # even if the NnPepo is hermitian, the two separate Pepos could be not.

    return pepo_hor, pepo_vert
Example #28
def build_loglikelihoods(model, **kwargs):
    """Function to compute the loglikelihood contribution
    of each variable.
    """
    artifact = compile_to_loglikelihoods(model.graph, model.namespace)
    logpdf = artifact.compiled_fn
    loglikelihoods = jax.partial(logpdf, **kwargs)
    return loglikelihoods
Example #29
 def _potential_energy(params):
     params_constrained = constrain_fn(inv_transforms, params)
     log_joint = jax.partial(log_density, model, model_args,
                             model_kwargs)(params_constrained)[0]
     for name, t in inv_transforms.items():
         log_joint = log_joint + np.sum(
             t.log_abs_det_jacobian(params[name], params_constrained[name]))
     return -log_joint
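
The log_abs_det_jacobian term above is the change-of-variables correction: the potential is the negative log joint evaluated at the constrained values, minus the log-Jacobian of the unconstrained-to-constrained map. A minimal sketch with a single positive-constrained site, where the transform is exp and the log-Jacobian is simply the unconstrained value (the densities are illustrative, not the example's model):

import jax.numpy as jnp
from jax.scipy.stats import gamma, norm

def toy_potential_energy(u):
    # sigma ~ Gamma(2, 1), one observation x = 1.3 ~ Normal(0, sigma),
    # parameterized by u = log(sigma) so the sampler works on all of R.
    sigma = jnp.exp(u)                 # constrain: R -> (0, inf)
    log_det_jacobian = u               # log |d sigma / d u| = u
    log_joint = gamma.logpdf(sigma, a=2.0) + norm.logpdf(1.3, loc=0.0, scale=sigma)
    return -(log_joint + log_det_jacobian)
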
Example #30
 def _apply_subtrace(master, submodule_params, *vals):
     submodule_params = submodule_params.val
     trace = ApplyTrace(master, jc.cur_sublevel())
     outs = yield map(
         partial(ApplyTrace.Tracer, trace,
                 ApplyTrace.SubmoduleParamsIterator(submodule_params)),
         vals), {}
     out_tracers = map(trace.full_raise, outs)
     yield [t.val for t in out_tracers]