Example #1
 def zoneout_decoder(inputs, prev_state):
   x, mask = inputs
   x, state = self.decoder(x, prev_state)
   state = jax.tree_multimap(lambda m, s1, s2: s1*m + s2*(1-m), mask, prev_state, state)
   return x, state
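
jax.tree_multimap applies a function leaf-wise across several pytrees with matching structure; in newer JAX releases the same behaviour is provided by jax.tree_map (and jax.tree.map), which accept multiple trees, and tree_multimap itself was deprecated and eventually removed. Below is a minimal, self-contained sketch of the zoneout-style mixing used above, with made-up toy values.

import jax

mask = {'h': 1.0, 'c': 0.0}
prev_state = {'h': 10.0, 'c': 20.0}
new_state = {'h': 30.0, 'c': 40.0}
# keep the previous value where mask == 1, take the new one where mask == 0
mixed = jax.tree_multimap(lambda m, s1, s2: s1 * m + s2 * (1 - m),
                          mask, prev_state, new_state)
print(mixed)  # 'h' keeps the previous value (10.0), 'c' takes the new one (40.0)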
Example #2
def prepare_centered_oks(
    apply_fun: Callable,
    params: PyTree,
    samples: Array,
    model_state: Optional[PyTree],
    mode: str,
    rescale_shift: bool,
    pdf=None,
    chunk_size: int = None,
) -> PyTree:
    """
    compute ΔOⱼₖ = Oⱼₖ - ⟨Oₖ⟩ = ∂/∂pₖ ln Ψ(σⱼ) - ⟨∂/∂pₖ ln Ψ⟩
    divided by √n

    In a somewhat opaque way this also internally splits all parameters into real
    parts in the 'real' and 'complex' modes (for C→R, R&C→R, R&C→C and general C→C),
    resulting in the respective ΔOⱼₖ, which is only compatible with split-to-real pytree vectors.

    Args:
        apply_fun: The forward pass of the Ansatz
        params : a pytree of parameters p
        samples : an array of (n in total) batched samples σ
        model_state: untrained state parameters of the model
        mode: differentiation mode, must be one of 'real', 'complex', 'holomorphic'
        rescale_shift: whether scale-invariant regularisation should be used (default: True)
        pdf: |ψ(x)|^2 if exact optimization is being used else None
        chunk_size: an int specifying the size of the chunks the gradient should be computed in (default: None)

    Returns:
        if not rescale_shift:
            a pytree representing the centered jacobian of ln Ψ evaluated at the samples σ, divided by √n;
            None
        else:
            the same pytree, but the entries for each parameter normalised to unit norm;
            pytree containing the norms that were divided out (same shape as params)

    """
    # un-batch the samples
    samples = samples.reshape((-1, samples.shape[-1]))

    # pre-apply the model state
    def forward_fn(W, σ):
        return apply_fun({"params": W, **model_state}, σ)

    if mode == "real":
        split_complex_params = True  # convert C→R and R&C→R to R→R
        centered_jacobian_fun = centered_jacobian_real_holo
        jacobian_fun = jacobian_real_holo
    elif mode == "complex":
        split_complex_params = True  # convert C→C and R&C→C to R→C
        # centered_jacobian_fun = compose(stack_jacobian, centered_jacobian_cplx)

        # avoid converting to complex and then back
        # by passing around the oks as a tuple of two pytrees representing the real and imag parts
        centered_jacobian_fun = compose(
            stack_jacobian_tuple,
            partial(centered_jacobian_cplx, _build_fn=lambda *x: x),
        )
        jacobian_fun = jacobian_cplx
    elif mode == "holomorphic":
        split_complex_params = False
        centered_jacobian_fun = centered_jacobian_real_holo
        jacobian_fun = jacobian_real_holo
    else:
        raise NotImplementedError(
            'Differentiation mode should be one of "real", "complex", or "holomorphic", got {}'
            .format(mode))

    if split_complex_params:
        # doesn't do anything if the params are already real
        params, reassemble = tree_to_real(params)

        def f(W, σ):
            return forward_fn(reassemble(W), σ)

    else:
        f = forward_fn

    if pdf is None:
        centered_oks = _divide_by_sqrt_n_samp(
            centered_jacobian_fun(
                f,
                params,
                samples,
                chunk_size=chunk_size,
            ),
            samples,
        )
    else:
        oks = jacobian_fun(f, params, samples)
        oks_mean = jax.tree_map(partial(jnp.sum, axis=0),
                                _multiply_by_pdf(oks, pdf))
        centered_oks = jax.tree_multimap(lambda x, y: x - y, oks, oks_mean)

        centered_oks = _multiply_by_pdf(centered_oks, jnp.sqrt(pdf))
    if rescale_shift:
        return _rescale(centered_oks)
    else:
        return centered_oks, None
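
As a rough, self-contained illustration of the quantity described in the docstring above (per-sample log-derivatives of ln Ψ, centered by their mean and divided by √n), here is a toy sketch in plain JAX. The log-amplitude toy_logpsi and all other names are invented for illustration; this is not NetKet's implementation and it ignores the complex/holomorphic modes, the model state and chunking.

import jax
import jax.numpy as jnp

def toy_logpsi(params, sample):
    # hypothetical real-parameter log-amplitude ln Ψ(σ)
    return jnp.dot(params, sample)

def toy_centered_oks(params, samples):
    n = samples.shape[0]
    # O_jk = ∂/∂p_k ln Ψ(σ_j): one gradient row per sample
    oks = jax.vmap(jax.grad(toy_logpsi), in_axes=(None, 0))(params, samples)
    # ΔO_jk = O_jk - ⟨O_k⟩, divided by √n
    return (oks - oks.mean(axis=0)) / jnp.sqrt(n)

params = jnp.ones(3)
samples = jnp.arange(6.0).reshape(2, 3)
print(toy_centered_oks(params, samples))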
Example #3
 def test_frozen_dict_partially_maps(self):
     x = jax.tree_multimap(lambda a, b: (a, b), freeze({'a': 2}),
                           freeze({'a': {
                               'b': 1
                           }}))
     self.assertEqual(unfreeze(x), {'a': (2, {'b': 1})})
Example #4
 def test_prng_sequence_split(self):
     k = jax.random.PRNGKey(42)
     s = base.PRNGSequence(k)
     hk_keys = s.take(10)
     jax_keys = tuple(jax.random.split(k, num=11)[1:])
     jax.tree_multimap(np.testing.assert_array_equal, hk_keys, jax_keys)
Example #5
def incremental_update(new_tensors, old_tensors, tau: Numeric):
    """Incrementally update all elements from a nested struct."""
    return jax.tree_multimap(lambda new, old: tau * new + (1.0 - tau) * old,
                             new_tensors, old_tensors)
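
A small usage sketch for the helper above (values invented): Polyak-averaging a target network's parameters toward the online network's, assuming incremental_update as defined above.

import jax.numpy as jnp

online = {'w': jnp.array([1.0, 2.0])}
target = {'w': jnp.array([0.0, 0.0])}
target = incremental_update(online, target, tau=0.01)
print(target)  # target moves 1% of the way toward online: ~[0.01, 0.02]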
Example #6
 def dequeue():
     batch = []
     for _ in range(batch_size):
         batch.append(q.get())
     batch = jax.tree_multimap(lambda *ts: np.stack(ts, axis=1), *batch)
     return jax.device_put(batch)
Example #7
def conditional_update(new_tensors: Any, old_tensors: Any, is_time: Numeric):
    """Checks whether to update the params and returns the correct params."""
    return jax.tree_multimap(
        lambda new, old: jax.lax.select(is_time, new, old), new_tensors,
        old_tensors)
Example #8
    def inner(scope_fn, repack_fn, variable_groups, rng_groups, init, *args):
        def find_length(axis, x):
            if axis is not axes_scan.broadcast:
                leaves = jax.tree_leaves(x)
                if leaves:
                    return leaves[0].shape[axis]
            return ()

        # split rngs
        lengths = jax.tree_multimap(find_length, in_axes, args)
        lengths = set(jax.tree_leaves(lengths))
        if length is None and len(lengths) == 1:
            d_length, = lengths
        elif len(lengths) > 1:
            raise ValueError(f'Inconsistent scan lengths: {lengths}')
        elif length is None:
            raise ValueError('length should be specified manually.')
        else:
            d_length = length
        split_fn = lambda rng: random.split(rng, d_length)

        rng_groups = tuple(
            jax.tree_map(split_fn, rng_group) if split else rng_group
            for rng_group, split in zip(rng_groups, rng_splits))

        @functools.partial(axes_scan.scan,
                           in_axes=(variable_in_axes, rng_axes, in_axes),
                           out_axes=(out_axes, variable_out_axes),
                           length=length,
                           reverse=reverse)
        def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups,
                    args):
            carry_vars, c = carry
            variable_groups = (broadcast_vars,
                               carry_vars) + scan_variable_groups
            if data_transform is not None:
                variable_groups, rng_groups = data_transform(
                    variable_groups, rng_groups)
            scope = scope_fn(variable_groups, rng_groups)
            c, y = fn(scope, c, *args)
            out_vars = repack_fn(scope)
            broadcast_vars_out = out_vars[0]
            carry_vars = out_vars[1]
            scan_vars = out_vars[2:]
            # add immutable broadcast vars back to broadcast output
            # otherwise they won't be fed to the actual scan body
            for in_group, out_group in zip(broadcast_vars, broadcast_vars_out):
                for col in in_group:
                    if col not in out_group:
                        out_group[col] = in_group[col]
            return broadcast_vars_out, (carry_vars, c), (y, scan_vars)

        broadcast_vars = variable_groups[0]
        carry_vars = variable_groups[1]
        scan_vars = variable_groups[2:]
        broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
            broadcast_vars, (carry_vars, init), scan_vars, rng_groups, args)
        # remove immutable broadcast vars otherwise they will be updated
        # with their own value which will cause an error
        for out_group in broadcast_vars:
            for name, col in tuple(out_group.items()):
                if isinstance(col, FrozenDict):
                    del out_group[name]
        out_vars = (
            broadcast_vars,
            carry_vars,
        ) + scan_vars
        return (c, ys), out_vars
Example #9
def tree_equal(x, y):
    return jax.tree_util.tree_all(jax.tree_multimap(lambda x, y: x == y, x, y))
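
A caveat worth noting for this helper (an observation, not part of the original source): for array leaves, x == y produces a boolean array, and tree_all then asks for its truth value, which is only well-defined for single-element arrays. A hedged sketch of an array-safe variant, using jnp.array_equal to reduce each leaf comparison to a single boolean:

import jax
import jax.numpy as jnp

def tree_equal_arrays(x, y):
    # jnp.array_equal collapses each leaf comparison to one boolean
    return jax.tree_util.tree_all(jax.tree_multimap(jnp.array_equal, x, y))

print(tree_equal_arrays({'a': jnp.arange(3)}, {'a': jnp.arange(3)}))  # True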
Example #10
def scan_enum(f, init, xs, length, reverse, rng_key=None, substitute_stack=None, history=1,
              first_available_dim=None):
    from numpyro.contrib.funsor import config_enumerate, enum, markov
    from numpyro.contrib.funsor import trace as packed_trace

    # number of steps to unroll
    history = min(history, length)
    unroll_steps = min(2 * history - 1, length)
    if reverse:
        x0 = tree_map(lambda x: x[-unroll_steps:][::-1], xs)
        xs_ = tree_map(lambda x: x[:-unroll_steps], xs)
    else:
        x0 = tree_map(lambda x: x[:unroll_steps], xs)
        xs_ = tree_map(lambda x: x[unroll_steps:], xs)

    carry_shapes = []

    def body_fn(wrapped_carry, x, prefix=None):
        i, rng_key, carry = wrapped_carry
        init = True if (not_jax_tracer(i) and i in range(unroll_steps)) else False
        rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

        # we need to tell unconstrained messenger in potential energy computation
        # that only the item at time `i` is needed when transforming
        fn = handlers.infer_config(f, config_fn=lambda msg: {'_scan_current_index': i})

        seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
        for subs_type, subs_map in substitute_stack:
            subs_fn = partial(_subs_wrapper, subs_map, i, length)
            if subs_type == 'condition':
                seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
            elif subs_type == 'substitute':
                seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

        if init:
            # prefix site names with _PREV_ to match the naming pattern of the sarkka_bilmes product
            with handlers.scope(prefix='_PREV_' * (unroll_steps - i), divider=''):
                new_carry, y = config_enumerate(seeded_fn)(carry, x)
                trace = {}
        else:
            # Like scan_wrapper, we collect the trace of scan's transition function
            # `seeded_fn` here. To put time dimension to the correct position, we need to
            # promote shapes to make `fn` and `value`
            # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
            # and value's batch_shape is (3,), then we promote shape of
            # value so that its batch shape is (1, 3)).
            # Here we will promote `fn` shape first. `value` shape will be promoted after scanned.
            # We don't promote `value` shape here because we need to store carry shape
            # at this step. If we reshape the `value` here, output carry might get wrong shape.
            with _promote_fn_shapes(), packed_trace() as trace:
                new_carry, y = config_enumerate(seeded_fn)(carry, x)

            # store shape of new_carry at a global variable
            if len(carry_shapes) < (history + 1):
                carry_shapes.append([jnp.shape(x) for x in tree_flatten(new_carry)[0]])
            # make new_carry have the same shape as carry
            # FIXME: is this rigorous?
            new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)),
                                      new_carry, carry)
        return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)

    with handlers.block(hide_fn=lambda site: not site["name"].startswith("_PREV_")), \
            enum(first_available_dim=first_available_dim):
        wrapped_carry = (0, rng_key, init)
        y0s = []
        # We run unroll_steps + 1 iterations; the last one is used for rolling with `lax.scan`
        for i in markov(range(unroll_steps + 1), history=history):
            if i < unroll_steps:
                wrapped_carry, (_, y0) = body_fn(wrapped_carry, tree_map(lambda z: z[i], x0))
                if i > 0:
                    # reshape y1, y2,... to have the same shape as y0
                    y0 = tree_multimap(lambda z0, z: jnp.reshape(z, jnp.shape(z0)), y0s[0], y0)
                y0s.append(y0)
                # shapes of the first `history - 1` steps are not useful to interpret the last carry
                # shape so we don't need to record them here
                if (i >= history - 1) and (len(carry_shapes) < history + 1):
                    carry_shapes.append([jnp.shape(x) for x in tree_flatten(wrapped_carry[-1])[0]])
            else:
                # this is the last rolling step
                y0s = tree_multimap(lambda *z: jnp.stack(z, axis=0), *y0s)
                # return early if length = unroll_steps
                if length == unroll_steps:
                    return wrapped_carry, (PytreeTrace({}), y0s)
                wrapped_carry = device_put(wrapped_carry)
                wrapped_carry, (pytree_trace, ys) = lax.scan(body_fn, wrapped_carry, xs_,
                                                             length - unroll_steps, reverse)

    first_var = None
    for name, site in pytree_trace.trace.items():
        # currently, we only record sample or deterministic in the trace
        # we don't need to adjust `dim_to_name` for deterministic site
        if site['type'] not in ('sample',):
            continue
        # add `time` dimension, the name will be '_time_{first variable in the trace}'
        if first_var is None:
            first_var = name

        # we haven't promoted the shapes of values during `lax.scan` yet, so we do it here
        site["value"] = _promote_scanned_value_shapes(site["value"], site["fn"])

        # XXX: site['infer']['dim_to_name'] is not enough to determine leftmost dimension because
        # we don't record 1-size dimensions in this field
        time_dim = -min(len(site['fn'].batch_shape), jnp.ndim(site['value']) - site['fn'].event_dim)
        site['infer']['dim_to_name'][time_dim] = '_time_{}'.format(first_var)

    # similar to carry, we need to reshape due to shape alternating in markov
    ys = tree_multimap(lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)[1:]), y0s, ys)
    # then join with y0s
    ys = tree_multimap(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys)
    # we also need to reshape `carry` to match sequential behavior
    i = (length + 1) % (history + 1)
    t, rng_key, carry = wrapped_carry
    carry_shape = carry_shapes[i]
    flatten_carry, treedef = tree_flatten(carry)
    flatten_carry = [jnp.reshape(x, t1_shape)
                     for x, t1_shape in zip(flatten_carry, carry_shape)]
    carry = tree_unflatten(treedef, flatten_carry)
    wrapped_carry = (t, rng_key, carry)
    return wrapped_carry, (pytree_trace, ys)
Example #11
 def init_fn(params):
     mu = jax.tree_multimap(jnp.zeros_like, params)  # momentum
     return ScaleByLarsState(mu=mu)
Example #12
def tree_add(a, b):
  return jax.tree_multimap(lambda e1, e2: e1+e2, a, b)
Example #13
 def trees2_map(fun, tree1, tree2, *args, **kwargs):
     return tree_multimap(lambda x, y: fun(x, y, *args, **kwargs), tree1, tree2)
Example #14
    def __call__(self, inputs, state):
        """Run one step of the wrapped core, handling state reset.

    Args:
      inputs: Tuple with two elements, ``inputs, should_reset``, where
        ``should_reset`` is the signal used to reset the wrapped core's state.
        ``should_reset`` can be either tensor or nest. If nest, ``should_reset``
        must match the state structure, and its components' shapes must be
        prefixes of the corresponding entries tensors' shapes in the state nest.
        If tensor, supported shapes are all common shape prefixes of the state
        component tensors, e.g. ``[batch_size]``.
      state: Previous wrapped core state.

    Returns:
      Tuple of the wrapped core's ``output, next_state``.
    """
        inputs, should_reset = inputs
        if jax.treedef_is_leaf(jax.tree_structure(should_reset)):
            # Equivalent to not tree.is_nested, but with support for Jax extensible
            # pytrees.
            should_reset = jax.tree_map(lambda _: should_reset, state)

        # We now need to manually pad 'on the right' to ensure broadcasting operates
        # correctly.
        # Automatic broadcasting would in fact implicitly pad 'on the left',
        # resulting in the signal to trigger resets for parts of the state
        # across batch entries. For example:
        #
        # import jax
        # import jax.numpy as jnp
        #
        # shape = (2, 2, 2)
        # x = jnp.zeros(shape)
        # y = jnp.ones(shape)
        # should_reset = jnp.array([False, True])
        # v = jnp.where(should_reset, x, y)
        # for batch_entry in range(shape[0]):
        #   print("batch_entry {}:\n".format(batch_entry), v[batch_entry])
        #
        # >> batch_entry 0:
        # >>  [[1. 0.]
        # >>  [1. 0.]]
        # >> batch_entry 1:
        # >>  [[1. 0.]
        # >>  [1. 0.]]
        #
        # Note how manually padding the should_reset tensor yields the desired
        # behavior.
        #
        # import jax
        # import jax.numpy as jnp
        #
        # shape = (2, 2, 2)
        # x = jnp.zeros(shape)
        # y = jnp.ones(shape)
        # should_reset = jnp.array([False, True])
        # dims_to_add = x.ndim - should_reset.ndim
        # should_reset = should_reset.reshape(should_reset.shape + (1,)*dims_to_add)
        # v = jnp.where(should_reset, x, y)
        # for batch_entry in range(shape[0]):
        #   print("batch_entry {}:\n".format(batch_entry), v[batch_entry])
        #
        # >> batch_entry 0:
        # >>  [[1. 1.]
        # >>  [1. 1.]]
        # >> batch_entry 1:
        # >>  [[0. 0.]
        # >>  [0. 0.]]
        should_reset = jax.tree_multimap(_validate_and_conform, should_reset,
                                         state)
        if self._is_batched(state):
            batch_size = jax.tree_leaves(inputs)[0].shape[0]
        else:
            batch_size = None
        initial_state = jax.tree_multimap(lambda s, i: i.astype(s.dtype),
                                          state,
                                          self.initial_state(batch_size))
        state = jax.tree_multimap(jnp.where, should_reset, initial_state,
                                  state)
        return self.core(inputs, state)
Example #15
def stack_forest(forest):
  stack_args = lambda *args: onp.stack(args)
  return jax.tree_multimap(stack_args, *forest)
Example #16
 def update_fn(updates, state, params=None):
     del params
     multiplied_updates = jax.tree_multimap(
         lambda m, update: jax.tree_map(lambda u: u * m, update),
         state.multipliers, updates)
     return multiplied_updates, state
Example #17
 def concatenate_leaves(pytrees):
     return jax.tree_multimap(
         lambda *leaves: onp.concatenate(leaves, axis=0), *pytrees)
Example #18
# inv_D2P_eq_min = lambda v: jax.tree_map(lambda x: x, v)
inv_D2P_eq_min = lambda v: jax.tree_map(id_func, v)
inv_D2P_ineq_min = lambda v: jax.tree_map(inv_D2P_pd, v)
min_augmented_D2P_inv = (inv_D2P_eq_min, inv_D2P_ineq_min)

key1 = random.PRNGKey(0)
key = random.PRNGKey(1)
x1 = jnp.array([1., 2., 3., 4., 5.])
x2 = random.normal(key1, (5, ))
W1 = random.normal(key, (3, 3))
W2 = random.normal(key1, (3, 3))

x = ((x1, x2), (x1, x1, x2))
W = ((W1, W2), (W1, W1, W2))

print(jax.tree_multimap(lambda f, x: f(x), min_augmented_DP, x))
print(DP_pd(x1))
print(DP_pd(x2))

# Check if the inv(D2P) match the closed form.
print(inv_D2P_pd(W1)(jnp.identity(W1.shape[0])))
print(jnp.linalg.matrix_power(W1, 2).T)

print(inv_D2P_pd(x2)(x1))
print(jnp.dot(jnp.diag(x2), x1))


def make_bound_breg_original(lb=-1.0, ub=1.0):
    def breg_bound_internal(lb, ub, *args, **kwargs):
        return lambda vec: jnp.sum((-vec + ub) * jnp.log(-vec + ub) +
                                   (vec - lb) * jnp.log(vec - lb))
Example #19
    def warmup(
        self,
        rng_key: jax.random.PRNGKey,
        initial_state: HMCState,
        kernel_factory: Callable,
        num_chains,
        num_warmup_steps: int = 1000,
        accelerate=False,
        initial_step_size: float = 0.1,
    ) -> Tuple[HMCState, HMCParameters, Optional[StanWarmupState]]:
        """I don't like having a ton of warmup logic in here."""

        if not self.needs_warmup:
            parameters = HMCParameters(
                jnp.ones(initial_state.position.shape[0], dtype=jnp.int32) *
                self.parameters.num_integration_steps,
                jnp.ones(initial_state.position.shape[0]) *
                self.parameters.step_size,
                jnp.array([
                    self.parameters.inverse_mass_matrix
                    for _ in range(initial_state.position.shape[0])
                ]),
            )
            return initial_state, parameters, None

        hmc_factory = jax.partial(kernel_factory,
                                  self.parameters.num_integration_steps)
        init, update, final = stan_hmc_warmup(hmc_factory,
                                              self.is_mass_matrix_diagonal)

        rng_keys = jax.random.split(rng_key, num_chains)
        chain_state = initial_state
        warmup_state = jax.vmap(init,
                                in_axes=(0, 0, None))(rng_keys, chain_state,
                                                      initial_step_size)

        schedule = jnp.array(stan_warmup_schedule(num_warmup_steps))

        if accelerate:

            print(
                f"sampler: warmup {num_chains:,} chains for {num_warmup_steps:,} iterations.",
                end=" ",
            )
            start = datetime.now()

            @jax.jit
            def update_chain(carry, interval):
                rng_key, chain_state, warmup_state = carry
                stage, is_middle_window_end = interval

                _, rng_key = jax.random.split(rng_key)
                keys = jax.random.split(rng_key, num_chains)
                chain_state, warmup_state, chain_info = jax.vmap(
                    update,
                    in_axes=(0, None, None, 0, 0))(keys, stage,
                                                   is_middle_window_end,
                                                   chain_state, warmup_state)

                return (
                    (rng_key, chain_state, warmup_state),
                    (chain_state, warmup_state, chain_info),
                )

            last_state, warmup_chain = jax.lax.scan(
                update_chain, (rng_key, chain_state, warmup_state), schedule)
            _, last_chain_state, last_warmup_state = last_state

            print(
                f"Done in {(datetime.now()-start).total_seconds():.1f} seconds."
            )

        else:

            @jax.jit
            def update_fn(rng_key, interval, chain_state, warmup_state):
                rng_keys = jax.random.split(rng_key, num_chains)
                stage, is_middle_window_end = interval
                chain_state, warmup_state, chain_info = jax.vmap(
                    update,
                    in_axes=(0, None, None, 0, 0))(rng_keys, stage,
                                                   is_middle_window_end,
                                                   chain_state, warmup_state)
                return chain_state, warmup_state, chain_info

            chain = []
            with tqdm(schedule, unit="samples") as progress:
                progress.set_description(
                    f"Warming up {num_chains} chains for {num_warmup_steps} steps",
                    refresh=False,
                )
                for interval in progress:
                    _, rng_key = jax.random.split(rng_key)
                    chain_state, warmup_state, chain_info = update_fn(
                        rng_key, interval, chain_state, warmup_state)
                    chain.append((chain_state, warmup_state, chain_info))

            last_chain_state, last_warmup_state, _ = chain[-1]

            # The sampling process, the composition between scan and for loop
            # is identical for the warmup and the sampling.  Should we
            # generalize this to only call a single `scan` function?
            stack = lambda y, *ys: jnp.stack((y, *ys))
            warmup_chain = jax.tree_multimap(stack, *chain)

        step_size, inverse_mass_matrix = jax.vmap(
            final, in_axes=(0, ))(last_warmup_state)
        num_integration_steps = self.parameters.num_integration_steps

        parameters = HMCParameters(
            jnp.ones(initial_state.position.shape[0], dtype=jnp.int32) *
            num_integration_steps,
            step_size,
            inverse_mass_matrix,
        )

        return last_chain_state, parameters, warmup_chain
Example #20
def test_subsample_gradient(scale, subsample):
    data = jnp.array([-0.5, 2.0])
    subsample_size = 1 if subsample else len(data)
    precision = 0.06 * scale

    def model(subsample):
        with handlers.substitute(data={"data": subsample}):
            with numpyro.plate("data", len(data), subsample_size) as ind:
                x = data[ind]
                z = numpyro.sample("z", dist.Normal(0, 1))
                numpyro.sample("x", dist.Normal(z, 1), obs=x)

    def guide(subsample):
        scale = numpyro.param("scale", 1.)
        with handlers.substitute(data={"data": subsample}):
            with numpyro.plate("data", len(data), subsample_size):
                loc = numpyro.param("loc", jnp.zeros(len(data)), event_dim=0)
                numpyro.sample("z", dist.Normal(loc, scale))

    if scale != 1.:
        model = handlers.scale(model, scale=scale)
        guide = handlers.scale(guide, scale=scale)

    num_particles = 50000
    optimizer = optim.Adam(0.1)
    elbo = Trace_ELBO(num_particles=num_particles)
    svi = SVI(model, guide, optimizer, loss=elbo)
    svi_state = svi.init(random.PRNGKey(0), None)
    params = svi.optim.get_params(svi_state.optim_state)
    normalizer = 2 if subsample else 1
    if subsample_size == 1:
        subsample = jnp.array([0])
        loss1, grads1 = value_and_grad(
            lambda x: svi.loss.loss(svi_state.rng_key, svi.constrain_fn(x), svi
                                    .model, svi.guide, subsample))(params)
        subsample = jnp.array([1])
        loss2, grads2 = value_and_grad(
            lambda x: svi.loss.loss(svi_state.rng_key, svi.constrain_fn(x), svi
                                    .model, svi.guide, subsample))(params)
        grads = tree_multimap(lambda *vals: vals[0] + vals[1], grads1, grads2)
        loss = loss1 + loss2
    else:
        subsample = jnp.array([0, 1])
        loss, grads = value_and_grad(
            lambda x: svi.loss.loss(svi_state.rng_key, svi.constrain_fn(x), svi
                                    .model, svi.guide, subsample))(params)

    actual_loss = loss / normalizer
    expected_loss, _ = value_and_grad(lambda x: svi.loss.loss(
        svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, None))(
            params)
    assert_allclose(actual_loss, expected_loss, rtol=precision, atol=precision)

    actual_grads = {name: grad / normalizer for name, grad in grads.items()}
    expected_grads = {
        'loc': scale * jnp.array([0.5, -2.0]),
        'scale': scale * jnp.array([2.0])
    }
    assert actual_grads.keys() == expected_grads.keys()
    for name in expected_grads:
        assert_allclose(actual_grads[name],
                        expected_grads[name],
                        rtol=precision,
                        atol=precision)
Example #21
def tree_allclose(t1, t2):
    t = jax.tree_multimap(jnp.allclose, t1, t2)
    return all(jax.tree_util.tree_flatten(t)[0])
Example #22
def list_dot(x: Vector, y: Vector) -> Scalar:
    return np.sum(
        tree_multimap(lambda arr_x, arr_y: np.sum(arr_x * arr_y), x, y))
Example #23
def periodic_update(new_tensors: Any, old_tensors: Any, is_time: Numeric):
    """Periodically switch all elements from a nested struct with new elements."""
    return jax.tree_multimap(
        lambda new, old: jax.lax.select(is_time, new, old), new_tensors,
        old_tensors)
Example #24
def list_add_prefactor(x: Vector, a: Scalar, y: Vector) -> Vector:
    # mimics x + a * y
    return tree_multimap(lambda arr_x, arr_y: arr_x + a * arr_y, x, y)
Example #25
def _jvp(oks: PyTree, v: PyTree) -> Array:
    """
    Compute the matrix-vector product between the pytree jacobian oks and the pytree vector v
    """
    td = lambda x, y: jnp.tensordot(x, y, axes=y.ndim)
    return jax.tree_util.tree_reduce(jnp.add, jax.tree_multimap(td, oks, v))
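
A toy usage sketch for _jvp (inputs invented; assumes the _jvp definition above together with import jax and import jax.numpy as jnp): each leaf of oks carries a leading sample axis, each leaf of v has the corresponding parameter shape, so the tensordot contracts away all of v's axes and tree_reduce sums the per-leaf contributions.

import jax.numpy as jnp

oks = {'w': jnp.ones((4, 2, 3)), 'b': jnp.ones((4, 3))}  # (n_samples, *param_shape)
v = {'w': jnp.ones((2, 3)), 'b': jnp.ones(3)}            # param_shape
print(_jvp(oks, v))  # shape (4,); every entry is 6 + 3 = 9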
Example #26
    def test_routing_reduce_correct(self, reduction):
        """Compare JAX implementations to a (slow but correct) iterative one."""
        n_variants = 2
        n_states = 4

        def make_range_shaped(shape):
            return np.arange(np.prod(shape)).reshape(shape).astype("float32")

        schema = self.build_simple_schema()
        builder = automaton_builder.AutomatonBuilder(schema)
        routing_params = automaton_builder.RoutingParams(
            move=make_range_shaped([
                n_variants,
                len(builder.in_out_route_types),
                n_states,
                n_states,
            ]),
            special=make_range_shaped([
                n_variants,
                len(builder.in_route_types),
                n_states,
                len(builder.special_actions),
            ]),
        )

        # Compute aggregates with JAX
        if reduction == "softmax":
            routing_aggregates = builder.routing_softmax(routing_params)
        else:
            routing_aggregates = builder.routing_reduce(routing_params,
                                                        reduction=reduction)
            routing_aggregates = jax.tree_multimap(
                lambda s, p: np.array(jnp.broadcast_to(s, p.shape)),
                routing_aggregates, routing_params)

        # Manual looping aggregates
        for variant in range(n_variants):
            for current_state in range(n_states):
                for in_route_type in builder.in_route_types:
                    # Compute aggregates
                    distn_vals = []
                    iroute_idx = builder.in_route_type_to_index[in_route_type]
                    for out_edge_type in schema[
                            in_route_type.node_type].out_edges:
                        ioroute_idx = builder.in_out_route_type_to_index[
                            automaton_builder.InOutRouteType(
                                in_route_type.node_type, in_route_type.in_edge,
                                out_edge_type)]
                        for next_state in range(n_states):
                            distn_vals.append(
                                routing_params.move[variant, ioroute_idx,
                                                    current_state, next_state])

                    for action_idx in range(len(builder.special_actions)):
                        distn_vals.append(routing_params.special[variant,
                                                                 iroute_idx,
                                                                 current_state,
                                                                 action_idx])

                    if reduction == "sum":
                        distn_aggregate = [sum(distn_vals)] * len(distn_vals)
                    elif reduction == "max":
                        distn_aggregate = [max(distn_vals)] * len(distn_vals)
                    elif reduction == "softmax":
                        distn_aggregate = list(
                            jax.nn.softmax(jnp.array(distn_vals)))
                    else:
                        raise ValueError(f"Invalid reduction {reduction}")

                    i = 0
                    # Check them with the JAX version
                    for out_edge_type in schema[
                            in_route_type.node_type].out_edges:
                        ioroute_idx = builder.in_out_route_type_to_index[
                            automaton_builder.InOutRouteType(
                                in_route_type.node_type, in_route_type.in_edge,
                                out_edge_type)]
                        for next_state in range(n_states):
                            np.testing.assert_allclose(
                                routing_aggregates.move[variant, ioroute_idx,
                                                        current_state,
                                                        next_state],
                                distn_aggregate[i],
                                rtol=1e-6)
                            i += 1

                    for action_idx in range(len(builder.special_actions)):
                        np.testing.assert_allclose(
                            routing_aggregates.special[variant, iroute_idx,
                                                       current_state,
                                                       action_idx],
                            distn_aggregate[i],
                            rtol=1e-6)
                        i += 1
Example #27
def _sum_combinator(*args):
  return functools.reduce(
      lambda x, y: jax.tree_multimap(lambda i, j: i + j, x, y), args)
Example #28
def train_step(state,
               inputs,
               outputs,
               programs,
               pretrain,
               bos_token,
               eos_token,
               learning_rate_fn,
               config,
               lp_config,
               train_rng=None):
    """Train on batch of program tasks."""
    # We handle PRNG splitting inside the top pmap, rather
    # than handling it outside in the training loop - doing the
    # latter can add some stalls to the devices.
    train_rng, new_train_rng = jax.random.split(train_rng)

    weights = jnp.where(programs > 0, 1, 0).astype(jnp.float32)

    # Embedding mask for autoencoding.
    emb_mask = jnp.ones((1, FLAGS.latent_vocab_size),
                        jnp.float32).at[:, [0, bos_token, eos_token]].set(0)

    def ae_loss_fn(params):
        """Loss function used for training autoencoder."""
        (logits,
         vq), new_variables = models.LatentProgramTransformer(config).apply(
             {
                 'params': params,
                 'vqvae': state.model_state
             },
             inputs,
             outputs,
             programs,
             emb_mask,
             pretrain=pretrain,
             mutable=['vqvae'],
             rngs={'dropout': train_rng})
        loss, weight_sum = compute_weighted_cross_entropy(
            logits, programs, weights)

        # Add EOS token for latent predictor loss.
        vq_weight_sum = jnp.sum(
            jnp.where(vq['latent_indices'] > 0, 1, 0).astype(jnp.float32))
        latent_indices = add_eos_token(vq['latent_indices'], eos_token)

        mean_loss = loss / weight_sum + vq['loss'] / vq_weight_sum
        return mean_loss, (new_variables['vqvae'], logits, latent_indices)

    step = state.step
    optimizer = state.optimizer
    lp_optimizer = state.lp_optimizer
    lr = learning_rate_fn(step)
    grad_fn = jax.value_and_grad(ae_loss_fn, has_aux=True)
    (_, (new_model_state, ae_logits,
         latent_indices)), ae_grad = grad_fn(optimizer.target)
    ae_grad = jax.lax.pmean(ae_grad, 'batch')

    latent_weights = jnp.where(latent_indices > 0, 1, 0).astype(jnp.float32)

    encoded_mask = jnp.where(outputs > 0, 1, 0).astype(jnp.float32)
    # Additionally mask out eos token in latents.
    latents_mask = jnp.where(
        jnp.logical_and(latent_indices > 0, latent_indices != eos_token), 1,
        0).astype(jnp.float32)

    def loss_fn(params, lp_params):
        """Loss function used for training."""
        latent_logits = models.ProgramTransformer(lp_config).apply(
            {'params': lp_params},
            inputs,
            outputs,
            latent_indices,
            rngs={'dropout': train_rng})
        latent_loss, latent_weight_sum = compute_weighted_cross_entropy(
            latent_logits, latent_indices, latent_weights)

        # End-to-end prediction.
        encoded = models.LatentProgramTransformer(config).apply(
            {
                'params': params,
                'vqvae': state.model_state
            },
            inputs,
            outputs,
            mutable=False,
            rngs={'dropout': train_rng},
            method=models.LatentProgramTransformer.encode)
        latents = models.LatentProgramTransformer(config).apply(
            {
                'params': params,
                'vqvae': state.model_state
            },
            latent_logits,
            mutable=False,
            rngs={'dropout': train_rng},
            method=models.LatentProgramTransformer.quantize)
        logits = models.LatentProgramTransformer(config).apply(
            {
                'params': params,
                'vqvae': state.model_state
            },
            programs,
            latents,
            encoded,
            latents_mask,
            encoded_mask,
            mutable=False,
            rngs={'dropout': train_rng},
            method=models.LatentProgramTransformer.decode)
        loss, weight_sum = compute_weighted_cross_entropy(
            logits, programs, weights)

        mean_loss = latent_loss / latent_weight_sum
        if not pretrain:
            mean_loss += loss / weight_sum
        return mean_loss, (logits, latent_logits)

    grad_fn = jax.value_and_grad(loss_fn, argnums=[0, 1], has_aux=True)
    (_, (logits, latent_logits)), grads = grad_fn(optimizer.target,
                                                  lp_optimizer.target)
    grads = jax.lax.pmean(grads, 'batch')
    new_optimizer = optimizer.apply_gradient(jax.tree_multimap(
        jnp.add, grads[0], ae_grad),
                                             learning_rate=lr)
    new_lp_optimizer = lp_optimizer.apply_gradient(grads[1], learning_rate=lr)

    metrics = compute_metrics(logits, programs, weights)
    metrics['learning_rate'] = lr
    metrics.update(compute_metrics(ae_logits, programs, weights, prefix='ae_'))
    latent_metrics = compute_metrics(latent_logits,
                                     latent_indices,
                                     latent_weights,
                                     prefix='latent_')

    new_state = state.replace(step=step + 1,
                              optimizer=new_optimizer,
                              model_state=jax.lax.pmean(
                                  new_model_state, 'batch'),
                              lp_optimizer=new_lp_optimizer)
    return new_state, metrics, latent_metrics, new_train_rng
Example #29
 def soft_update_func(old, new, tau):
     return jax.tree_multimap(lambda a, b: (1 - tau) * a + tau * b, old,
                              new)
Example #30
def update(params, grads):
    return jax.tree_multimap(lambda p, g: p - 0.05 * g, params, grads)
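
A hypothetical end-to-end usage of this SGD-style step (the loss and parameter names are made up): compute gradients for a toy quadratic loss with jax.grad, which returns a pytree matching params, and apply update once.

import jax
import jax.numpy as jnp

def toy_loss(params):
    return jnp.sum(params['w'] ** 2) + jnp.sum(params['b'] ** 2)

params = {'w': jnp.ones(3), 'b': jnp.ones(2)}
grads = jax.grad(toy_loss)(params)  # gradient pytree: 2 * params
params = update(params, grads)      # each entry: 1.0 - 0.05 * 2.0 = 0.9
print(params)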