Example #1
def G(input_shape, num_layers, activation=sigmoid, bias=False, **kwargs):
    if bias:
        activation = act_but_last(activation)
    activation = elementwise(activation)
    block = [Linear(input_shape, **kwargs), Accumulator, activation]
    layers = flatten_shallow(block for i in range(num_layers))[:-1]
    mask = flatten_shallow([False, True, False]
                           for i in range(num_layers))[:-1]
    init_fun, apply_fun = serial_with_outputs(*layers, mask=mask)

    def init(rng, input_shape):
        output_shape, params = init_fun(rng, input_shape)
        opt_params = params[::3]
        const_params = params[1::3]
        if bias:
            opt_params = [affine_padder(p) for p in opt_params]
            const_params = [augment_dummy_params(p) for p in const_params]
        return params, opt_params, const_params  # opt: Linear params, const: Accumulator params

    def g(opt_params, const_params, inputs):
        if bias:
            inputs = augment_data(inputs)
        o, ys = apply_fun(
            interleave(opt_params, const_params, (() for p in opt_params)),
            inputs, **kwargs)
        if bias:
            o = column_stripper(o)
            ys = [column_stripper(y) for y in ys]
        return o, ys

    return init, g
Example #2
    def __init__(self,
                 indim,
                 outdim,
                 topology,
                 omega=1.0,
                 transform=None,
                 seed=0):
        """
        Arguments
        ---------

        indim:
            Dimensionality of a single data input.

        outdim:
            Dimensionality of a single data output.

        topology: Tuple
            Defines the structure of the inner layers for the network.

        omega:
            Weight distribution factor ω₀ for the first layer (as described in [1]).

        transform: Optional[Callable]
            Optional pre-network transformation function.

        seed: Optional[int]
            Initial seed for weight initialization.
        """

        tlayer = build_transform_layer(transform)
        # Weight initialization for Sirens
        pdf_in = variance_scaling(1.0 / 3, "fan_in", "uniform")
        pdf = variance_scaling(2.0 / omega**2, "fan_in", "uniform")
        # Sine activation function
        σ_in = stax.elementwise(lambda x: np.sin(omega * x))
        σ = stax.elementwise(lambda x: np.sin(x))
        # Build layers
        layer_in = [
            stax.Flatten, *tlayer,
            stax.Dense(topology[0], pdf_in), σ_in
        ]
        layers = list(
            chain.from_iterable((stax.Dense(i, pdf), σ) for i in topology[1:]))
        layers = layer_in + layers + [stax.Dense(outdim, pdf)]
        #
        super().__init__(indim, layers, seed)
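For reference, a minimal standalone sketch of the same SIREN-style construction (a sketch only, assuming the jax.experimental.stax API used throughout these examples; layer sizes and omega are arbitrary, and newer JAX releases expose the same module as jax.example_libraries.stax):

import jax.numpy as np
from jax import random
from jax.experimental import stax
from jax.nn.initializers import variance_scaling

omega = 30.0
pdf_in = variance_scaling(1.0 / 3, "fan_in", "uniform")      # first-layer init
pdf = variance_scaling(2.0 / omega**2, "fan_in", "uniform")  # hidden-layer init
sine_in = stax.elementwise(lambda x: np.sin(omega * x))      # scaled sine after the first Dense
sine = stax.elementwise(np.sin)

init_fun, apply_fun = stax.serial(
    stax.Flatten,
    stax.Dense(64, pdf_in), sine_in,
    stax.Dense(64, pdf), sine,
    stax.Dense(1, pdf),
)
_, params = init_fun(random.PRNGKey(0), (-1, 2))
outputs = apply_fun(params, np.ones((8, 2)))  # shape (8, 1)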
Example #3
    def q_network(self):
        #no regression !
        if self.dueling:
            init, apply = stax.serial(
                elementwise(lambda x: x / 10000.0),
                stax.serial(Dense(128), Relu, Dense(64), Relu),  #base layers
                FanOut(2),
                stax.parallel(
                    stax.serial(Dense(32), Relu, Dense(1)),  #state value
                    stax.serial(Dense(32), Relu,
                                Dense(self.num_actions)))  #advantage func
            )

        else:
            init, apply = stax.serial(
                elementwise(lambda x: x / 10000.0), Dense(64), Relu,
                Dense(32), Relu, Dense(self.num_actions))

        return init, apply
Example #4
def negativity_transform():
    """Layer construction function for negativity transform.

  This layer is used as the last layer of the xc energy density network since
  exchange and correlation must be negative according to exact conditions.

  Note we use a 'soft' negativity transformation here. The range after this
  transformation is (-inf, 0.278].

  Returns:
    (init_fn, apply_fn) pair.
  """
    def negative_fn(x):
        return -nn.swish(x)

    return stax.elementwise(negative_fn)
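A hedged usage sketch of negativity_transform() as the final layer of a small stax network (layer sizes are illustrative; it assumes `from jax import nn` as in the module imports shown later, and the ~0.278 upper bound is -min(swish)):

from jax import random
from jax.experimental import stax
import jax.numpy as jnp

init_fn, apply_fn = stax.serial(
    stax.Dense(16), stax.Softplus,
    stax.Dense(1),
    negativity_transform(),  # outputs lie in (-inf, 0.278]
)
_, params = init_fn(random.PRNGKey(0), (-1, 3))
out = apply_fn(params, jnp.zeros((4, 3)))  # shape (4, 1)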
Example #5
def constructDuelNetwork(n_actions, seed, input_shape):
    advantage_stream = stax.serial(Dense(512), Relu, Dense(n_actions))

    state_function_stream = stax.serial(Dense(512), Relu, Dense(1))
    dueling_architecture = stax.serial(
        elementwise(lambda x: x / 255.0),
        GeneralConv(dim_nums, 32, (8, 8), strides=(4, 4)),
        Relu,
        GeneralConv(dim_nums, 64, (4, 4), strides=(2, 2)),
        Relu,
        GeneralConv(dim_nums, 64, (3, 3), strides=(1, 1)),
        Relu,
        Flatten,
        FanOut(2),
        parallel(advantage_stream, state_function_stream),
    )

    def duelingNetworkMapping(inputs):
        advantage_values = inputs[0]
        state_values = inputs[1]
        advantage_sums = jnp.sum(advantage_values, axis=1)

        advantage_sums = advantage_sums / float(n_actions)
        advantage_sums = advantage_sums.reshape(-1, 1)

        Q_values = state_values + (advantage_values - advantage_sums)

        return Q_values

    duelArchitectureMapping = jit(duelingNetworkMapping)

    ##### Create deep neural net
    model = DDQN(n_actions,
                 input_shape,
                 adam_params,
                 architecture=dueling_architecture,
                 seed=seed,
                 mappingFunction=duelArchitectureMapping)

    return model
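A standalone sketch of the dueling aggregation that duelingNetworkMapping computes above, Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)); the values are illustrative:

import jax.numpy as jnp

advantages = jnp.array([[1.0, 2.0, 3.0]])  # advantage stream output, (batch, n_actions)
state_value = jnp.array([[0.5]])           # state-value stream output, (batch, 1)
q_values = state_value + (advantages - advantages.mean(axis=1, keepdims=True))
# q_values == [[-0.5, 0.5, 1.5]]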
Example #6
def constructSingleStreamNetwork(n_actions, seed, input_shape):
    single_stream_architecture = stax.serial(
        elementwise(lambda x: x / 255.0),  # normalize
        ### convolutional NN (CNN)
        GeneralConv(dim_nums, 32, (8, 8), strides=(4, 4)),
        Relu,
        GeneralConv(dim_nums, 64, (4, 4), strides=(2, 2)),
        Relu,
        GeneralConv(dim_nums, 64, (3, 3), strides=(1, 1)),
        Relu,
        Flatten,  # flatten output
        Dense(1024),
        Relu,
        Dense(n_actions))

    model = DDQN(n_actions,
                 input_shape,
                 adam_params,
                 architecture=single_stream_architecture,
                 seed=seed)

    return model
Example #7
import jax
from jax import lax
from jax import nn
from jax import tree_util
from jax.experimental import stax
import jax.numpy as jnp
from jax.scipy import ndimage

from jax_dft import scf
from jax_dft import utils

_STAX_ACTIVATION = {
    'relu': stax.Relu,
    'elu': stax.Elu,
    'softplus': stax.Softplus,
    'swish': stax.elementwise(nn.swish),
}
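A hedged sketch of how a lookup table like _STAX_ACTIVATION can pick the activation layer from a config string (build_mlp is a hypothetical helper for illustration, not part of jax_dft):

def build_mlp(num_hidden_layers, width, activation='swish'):
    layers = []
    for _ in range(num_hidden_layers):
        layers.extend([stax.Dense(width), _STAX_ACTIVATION[activation]])
    layers.append(stax.Dense(1))
    return stax.serial(*layers)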


def negativity_transform():
    """Layer construction function for negativity transform.

  This layer is used as the last layer of the xc energy density network since
  exchange and correlation must be negative according to exact conditions.

  Note we use a 'soft' negativity transformation here. The range after this
  transformation is (-inf, 0.278].

  Returns:
    (init_fn, apply_fn) pair.
  """
Example #8
def fcnn(
    output_dim: int,
    layers=2,
    units=256,
    skips=False,
    output_activation=None,
    first_layer_factor=None,
    parallel_with=None,
):
    """
    Creates init and apply functions for a fully-connected NN with a Gaussian
    likelihood.

    :param output_dim: Number of output dimensions
    :param first_layer_factor: Scaling factor for first dense layer initialization.
    """

    first_layer_factor = 10.0 if first_layer_factor is None else first_layer_factor

    activation = stax.elementwise(np.sin)
    block = (skip_connect(
        stax.serial(Dense(units, W_init=siren_w_init()()),
                    activation)) if skips else stax.serial(
                        Dense(units, W_init=siren_w_init()()), activation))

    layer_list = [
        stax.serial(
            Dense(
                units,
                W_init=siren_w_init(factor=first_layer_factor)(),
                b_init=jax.nn.initializers.normal(stddev=2.0 * math.pi),
            ),
            activation,
        ),
        *([block] * (layers - 1)),
        Dense(output_dim),  # No skips on the last one!
    ]
    if output_activation is not None:
        layer_list.append(output_activation)
    if parallel_with is None:
        _init_fun, _apply_fun = stax.serial(*layer_list)
    else:
        _init_fun, _apply_fun = parallel(stax.serial(*layer_list),
                                         parallel_with)

    def init_fun(rng, input_shape):
        output_shape, net_params = _init_fun(rng, input_shape)
        params = {"net": net_params, "raw_noise": np.array(-2.0)}
        return output_shape, params

    # Conform to API:
    @t_wrapper
    def apply_fun(params, rng, inputs):
        return _apply_fun(params["net"], inputs)

    @t_wrapper
    def gaussian_fun(params, rng, inputs):
        pred_mean = apply_fun(params, None, inputs)
        pred_std = params["noise"] * np.ones(pred_mean.shape)
        return pred_mean, pred_std

    @t_wrapper
    def loss_fun(params,
                 rng,
                 data,
                 batch_size=None,
                 n=None,
                 loss_type="nlp",
                 reduce="sum"):
        """
        :param batch_size: How large a batch to subselect from the provided data
        :param n: The total size of the dataset (to multiply batch estimate by)
        """
        assert loss_type in ("nlp", "mse")
        inputs, targets = data
        n = inputs.shape[0] if n is None else n
        if batch_size is not None:
            rng, rng_batch = random.split(rng)
            i = random.permutation(rng_batch, n)[:batch_size]
            inputs, targets = inputs[i], targets[i]
        preds = apply_fun(params, rng, inputs).squeeze()
        mean_loss = (
            -norm.logpdf(targets.squeeze(), preds, params["noise"]).mean()
            if loss_type == "nlp" else np.power(targets.squeeze() -
                                                preds, 2).mean())
        if reduce == "sum":
            loss = n * mean_loss
        elif reduce == "mean":
            loss = mean_loss
        return loss

    def sample_fun_fun(rng, params):
        def f(x):
            return apply_fun(params, rng, x)

        return f

    return {
        "init": init_fun,
        "apply": apply_fun,
        "gaussian": gaussian_fun,
        "loss": loss_fun,
        "sample_function": sample_fun_fun,
    }
Example #9
def build_transform_layer(transform: Optional[Callable] = None):
    return () if transform is None else (stax.elementwise(transform), )
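Hedged usage sketch: build_transform_layer returns either an empty tuple or a one-element tuple, so it can be splatted into stax.serial in both cases (as done with *tlayer in the SIREN constructor of Example #2):

import jax.numpy as jnp
from jax.experimental import stax

tlayer = build_transform_layer(jnp.tanh)  # -> (stax.elementwise(jnp.tanh),)
init_fn, apply_fn = stax.serial(stax.Flatten, *tlayer, stax.Dense(8))

no_tlayer = build_transform_layer(None)   # -> ()
init_fn2, apply_fn2 = stax.serial(stax.Flatten, *no_tlayer, stax.Dense(8))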
Example #10
def main():
    num_iter = 50000
    # Most people run 1000 steps and the OpenAI gym pendulum is 0.05s per step.
    # The max torque that can be applied is also 2 in their setup.
    T = 1000
    time_delta = 0.05
    max_torque = 2.0
    rng = random.PRNGKey(0)

    dynamics = pendulum_dynamics(
        mass=1.0,
        length=1.0,
        gravity=9.8,
        friction=0.0,
    )

    policy_init, policy_nn = stax.serial(
        Dense(64),
        Tanh,
        Dense(64),
        Tanh,
        Dense(1),
        Tanh,
        stax.elementwise(lambda x: max_torque * x),
    )

    # Should it matter whether theta is wrapped into [0, 2pi]?
    policy = lambda params, x: policy_nn(
        params,
        jnp.array([x[0] % (2 * jnp.pi), x[1],
                   jnp.cos(x[0]),
                   jnp.sin(x[0])]))

    def loss(policy_params, x0):
        x = x0
        acc_cost = 0.0
        for _ in range(T):
            u = policy(policy_params, x)
            x += time_delta * dynamics(x, u)
            acc_cost += time_delta * cost(x, u)
        return acc_cost

    rng_init_params, rng = random.split(rng)
    _, init_policy_params = policy_init(rng_init_params, (4, ))
    opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)
    loss_and_grad = jit(value_and_grad(loss))

    loss_per_iter = []
    elapsed_per_iter = []
    x0s = vmap(sample_x0)(random.split(rng, num_iter))
    for i in range(num_iter):
        t0 = time.time()
        loss, g = loss_and_grad(opt.value, x0s[i])
        opt = opt.update(g)
        elapsed = time.time() - t0

        loss_per_iter.append(loss)
        elapsed_per_iter.append(elapsed)

        print(f"Episode {i}")
        print(f"    loss = {loss}")
        print(f"    elapsed = {elapsed}")

    blt.remember({
        "loss_per_iter": loss_per_iter,
        "elapsed_per_iter": elapsed_per_iter,
        "final_params": opt.value
    })

    plt.figure()
    plt.plot(loss_per_iter)
    plt.yscale("log")
    plt.title("ODE control of an inverted pendulum")
    plt.xlabel("Iteration")
    plt.ylabel(f"Policy cost (T = {total_secs}s)")

    # Viz
    num_viz_rollouts = 50
    framerate = 30
    timesteps = jnp.linspace(0,
                             int(T * time_delta),
                             num=int(T * time_delta * framerate))
    rollout = lambda x0: ode.odeint(
        lambda x, _: dynamics(x, policy(opt.value, x)), y0=x0, t=timesteps)

    plt.figure()
    states = rollout(jnp.zeros(2))
    plt.plot(states[:, 0], states[:, 1], marker=".")
    plt.xlabel("theta")
    plt.ylabel("theta dot")
    plt.title("Swing up trajectory")

    plt.figure()
    states = vmap(rollout)(x0s[:num_viz_rollouts])
    for i in range(num_viz_rollouts):
        plt.plot(states[i, :, 0], states[i, :, 1], marker='.', alpha=0.5)
    plt.xlabel("theta")
    plt.ylabel("theta dot")
    plt.title("Phase space trajectory")

    plot_control_contour(lambda x: policy(opt.value, x))
    plot_policy_dynamics(dynamics, lambda x: policy(opt.value, x))

    blt.show()
Example #11
        return output1, (params1, params2)

    def apply_fun(params, inputs, **kwargs):
        # TODO: don't use diffeq wrapper, get t, x = inputs??
        params1, params2 = params
        t1, out1 = apply_fn1(params1, inputs, **kwargs)
        t2, out2 = apply_fn2(params2, inputs, **kwargs)
        # t1 == t2 since sdeint calls
        return t2, out2 + out1

    return init_fun, apply_fun


# activations
rbf = lambda x: np.exp(-x**2)
Rbf = elementwise(rbf)
Rbf = register('rbf')(Rbf)  # @register('rbf')
Elu = elementwise(jax.nn.elu)
Elu = register('elu')(stax.Elu)
Softplus = register('softplus')(stax.Softplus)
swish = lambda x: x * jax.nn.sigmoid(x)
Swish = elementwise(swish)
Swish = register('swish_nobeta')(Swish)
Relu = register('relu')(stax.Relu)
Tanh = register('tanh')(stax.Tanh)
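A hypothetical minimal register decorator consistent with how it is used above: it records each (init_fun, apply_fun) layer in a name-keyed registry and returns the layer unchanged, so register('rbf')(Rbf) both registers and re-binds Rbf. This is a sketch, not the original implementation:

_ACTIVATION_REGISTRY = {}


def register(name):
    def wrap(layer):
        _ACTIVATION_REGISTRY[name] = layer
        return layer
    return wrap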


@register('swish')
def Swish_(out_dim, beta_init=0.5):
    """
  Trainable Swish function to learn 
    averages_mat = averages_mat[1:, :].T
    averages_df = pd.DataFrame(data=averages_mat,
                               index=adata_ref.raw.var_names,
                               columns=all_clusters)

    return averages_df


from jax import random
from jax.experimental import stax
import jax.numpy as jnp
from jax.random import PRNGKey

import numpyro

Log1p = stax.elementwise(jax.lax.log1p)


def encoder(hidden_dim, z_dim):
    return stax.serial(
        Log1p,
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp)),
    )
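Hedged usage sketch of encoder(): FanOut(2) followed by stax.parallel yields a (location, scale) pair, with the scale branch made positive by stax.Exp (the dimensions below are illustrative):

enc_init, enc_apply = encoder(hidden_dim=32, z_dim=8)
_, enc_params = enc_init(PRNGKey(0), (-1, 100))
z_loc, z_scale = enc_apply(enc_params, jnp.ones((4, 100)))  # each (4, 8)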


def decoder(hidden_dim, out_dim):
Example #13
    def apply_fun(params, inputs, **kwargs):
        return inputs.sum(axis=-1)

    return init_fun, apply_fun


SumLayer = SumLayer()


def logcosh(x):
    x = x * jax.numpy.sign(x.real)
    return x + jax.numpy.log(1.0 +
                             jax.numpy.exp(-2.0 * x)) - jax.numpy.log(2.0)


LogCoshLayer = stax.elementwise(logcosh)


def JaxRbm(hilbert, alpha, dtype=complex):
    return Jax(
        hilbert,
        stax.serial(stax.Dense(alpha * hilbert.size), LogCoshLayer, SumLayer),
        dtype=dtype,
    )


def MPSPeriodic(hilbert,
                graph,
                bond_dim,
                diag=False,
                symperiod=None,
Example #14
import jax.numpy as np
from jax import random
from jax.experimental.stax import elementwise, serial, Dense, Softmax
from jax.nn import sigmoid
from jax.nn.initializers import glorot_normal, kaiming_normal, orthogonal

#from utils import flatten_shallow, interleave
from utils import *
Abs = elementwise(np.abs)


def serial_with_outputs(*layers, mask):
    """Combinator for composing layers in serial.
    Args:
      *layers: a sequence of layers, each an (init_fun, apply_fun) pair.
      mask: a sequence of booleans, one per layer, marking which layers'
        intermediate outputs are collected and returned alongside the final
        output by apply_fun.
    Returns:
      A new layer, meaning an (init_fun, apply_fun) pair, representing the serial
      composition of the given sequence of layers.
    """
    nlayers = len(layers)
    init_funs, apply_funs = zip(*layers)

    def init_fun(rng, input_shape):
        params = []
        for init_fun in init_funs:
            rng, layer_rng = random.split(rng)
            input_shape, param = init_fun(layer_rng, input_shape)
            params.append(param)
        return input_shape, params

    def apply_fun(params, inputs, **kwargs):
Example #15
        jnp.array(x) * jax.nn.softplus(beta))  # no / 1.1 for lipschitz

    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim, )
        beta0 = jnp.ones((out_dim, )) * beta_init
        return input_shape, (jnp.array(beta0), )

    def apply_fun(params, inputs, **kwargs):
        beta_, = params
        ret = swish(beta_, inputs)
        return ret

    return init_fun, apply_fun


Rbf = stax.elementwise(lambda x: jnp.exp(-x**2))
ACTFNS = {
    "softplus": stax.Softplus,
    "tanh": stax.Tanh,
    "elu": stax.Elu,
    "rbf": Rbf,
    "swish": Swish_
}


def MLP(hidden_dims=[1, 64, 1],
        actfn="softplus",
        xt=False,
        ou_dw=False,
        p_scale=-1.0,
        nonzero_w=-1.0,
Example #16
import functools
import sys
import random

import jax
import jax.numpy as jnp
from jax import jit
from jax.experimental import stax

import my_sampler

@jax.jit
def logcosh(x):
    """logcosh activation function. To use this function as layer, use LogCoshLayer.
    """
    x = x * jax.numpy.sign(x.real)
    return x + jax.numpy.log(1.0 + jax.numpy.exp(-2.0 * x)) - jax.numpy.log(2.0)
LogCoshLayer = stax.elementwise(logcosh)

#https://arxiv.org/pdf/1705.09792.pdf
#complex activation function, see https://arxiv.org/pdf/1802.08026.pdf
@jax.jit
def modrelu(x):
    """modrelu activation function. To use this function as layer, use ModReLu.

        See https://arxiv.org/pdf/1705.09792.pdf
        """
    return jnp.maximum(1, jnp.abs(x)) * x/jnp.abs(x)
ModReLu = stax.elementwise(modrelu)

def learned_dynamics(params):
    @jit
    def dynamics(q, q_t):
        #     assert q.shape == (2,)
        state = wrap_coords(jnp.concatenate([q, q_t]))
        return jnp.squeeze(nn_forward_fn(params, state), axis=-1)

    return dynamics


from jax.experimental.stax import serial, Dense, Softplus, Tanh, elementwise, Relu

sigmoid = jit(lambda x: 1 / (1 + jnp.exp(-x)))
swish = jit(lambda x: x / (1 + jnp.exp(-x)))
relu3 = jit(lambda x: jnp.clip(x, 0.0, float('inf'))**3)
Swish = elementwise(swish)
Relu3 = elementwise(relu3)


def extended_mlp(args):
    act = {
        'softplus': [Softplus, Softplus],
        'swish': [Swish, Swish],
        'tanh': [Tanh, Tanh],
        'tanh_relu': [Tanh, Relu],
        'soft_relu': [Softplus, Relu],
        'relu_relu': [Relu, Relu],
        'relu_relu3': [Relu, Relu3],
        'relu3_relu': [Relu3, Relu],
        'relu_tanh': [Relu, Tanh],
    }[args.act]
def _elementwise(fun, **fun_kwargs):
    init_fun, apply_fun = stax.elementwise(fun, **fun_kwargs)
    ker_fun = lambda kernels: _transform_kernels(kernels, fun, **fun_kwargs)
    return init_fun, apply_fun, ker_fun
Example #19
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum,
                                   FanOut, Flatten, GeneralConv, Identity,
                                   MaxPool, Relu, LogSoftmax, Softplus)

from jax.nn import sigmoid, swish

from neural_tangents import stax
import jax.experimental.stax as jax_stax
from layers import MyConv, MyDense

Swish = jax_stax.elementwise(swish)


def swish_ten(x):
    return x * sigmoid(10 * x)


Swishten = jax_stax.elementwise(swish_ten)


def CNNStandard(n_channels,
                L,
                filter=(3, 3),
                data='cifar10',
                gap=True,
                nonlinearity='relu',
                parameterization='standard',
Example #20
from research.estop import mdp

tau = 1e-4
buffer_size = 2**15
batch_size = 64
num_eval_rollouts = 128
opt_init = make_optimizer(optimizers.adam(step_size=1e-3))
init_noise = Normal(jp.array(0.0), jp.array(0.0))
noise = lambda _1, _2: Normal(jp.array(0.0), jp.array(0.5))

actor_init, actor = stax.serial(
    Dense(64),
    Relu,
    Dense(1),
    Tanh,
    stax.elementwise(lambda x: config.max_torque * x),
)

critic_init, critic = stax.serial(
    FanInConcat(),
    Dense(64),
    Relu,
    Dense(64),
    Relu,
    Dense(1),
    stax.elementwise(lambda x: x + 1.0 / (1 - config.gamma)),
    Scalarify,
)

policy = lambda p: lambda s: Deterministic(actor(p, s))
Example #21
def _hpm(structure, datasets, root_type="gp"):
    """
    Custom hidden physics model for the ultrasound problem.

    Leaves have a(x,y) and u(x,y,t).
    Both are MLPs.

    Physics is
    utt = a(x, y)^2 f(...),
    where f is a GP or the wave operator.

    :return: HPM
    """

    n_leaves = len(datasets)
    n = sum([d[0].shape[0] for d in datasets])

    # Public funcs
    u_funcs = fcnn.fcnn(1, layers=5, units=128, first_layer_factor=1)
    a_funcs = fcnn.fcnn(
        1,
        layers=5,
        units=128,
        first_layer_factor=8.0,
        output_activation=stax.serial(stax.elementwise(lambda x: 0.1 * x),
                                      stax.elementwise(jnp.exp)),
    )
    if root_type == "gp":
        _root_mean = gp_mean_functions.linear()
        _root_kernel = kernels.rbf()
        _root_likelihood = likelihoods.gaussian()
        f_funcs = sparse_layer.sparse_layer(n,
                                            _root_mean,
                                            _root_kernel,
                                            _root_likelihood,
                                            safe=False,
                                            jitter=1.0e-5)
    elif root_type == "wave":
        if not (len(structure.f_inputs) == 1
                and structure.u_operators[structure.f_inputs[0]] == "div"):
            raise ValueError(
                'Wave operator requires "div" be the only u operator fed to f.'
            )
        f_funcs = wave()

    def init_fun(rng):
        params = {"leaf": [], "root": None}
        for _ in range(n_leaves):
            rng, rng_u, rng_a = split(rng, num=3)
            u_params = u_funcs["init"](rng_u, (-1, 3))[1]  # x, y, t
            a_params = a_funcs["init"](
                rng_a, (-1, 2 + len(structure.a_inputs)))[1]  # x, y, ...
            params["leaf"].append(LeafParams(u_params, a_params))
        rng, rng_root = split(rng)
        if root_type == "gp":
            params["root"] = f_funcs["init"](rng_root,
                                             (-1, len(structure.f_inputs)),
                                             m=NUM_INDUCING)[1]
        elif root_type == "wave":
            params["root"] = f_funcs["init"]()[1]
        return params

    def apply_u_ops(params, inputs):
        """
        :param params: For u net
        :param inputs: (N, 3)
        :return: jnp.array, shape (N, N_Ops)
        """
        return jnp.stack([op(params, inputs) for op in u_ops]).T

    def apply_lhs(params, inputs, leaf):
        """
        Compute utt(inputs) for a specified leaf

        :param params: For the whole HPM
        :param inputs: (N,3)
        :return: (N,)
        """
        u_params = params["leaf"][leaf].u
        return utt(u_params, inputs)

    def apply_rhs(params, inputs, leaf, rng=None):
        """
        Compute a^2()f() for a specified leaf.

        Use the mean of f for now.

        :param inputs: (N,3)
        :return: (N,)
        """
        rng = PRNGKey(42) if rng is None else rng
        f_params = params["root"]
        u_params, a_params = params["leaf"][leaf]
        u_op_vals = apply_u_ops(u_params, inputs)
        f_inputs = u_op_vals[:, structure.f_inputs]
        a_inputs = jnp.concatenate(
            (inputs[:, :2], u_op_vals[:, structure.a_inputs]), axis=1)
        a = a_funcs["apply"](a_params, None, a_inputs)[:, 0]  # (N,)
        rng, rng_f = split(rng)
        # Posterior mean, flattened from (N,1) to (N,)
        f = f_funcs["gaussian"](f_params, rng_f, f_inputs)[0].squeeze()
        return a * a * f

    def train(
        rng,
        params,
        datasets,
        u_iters=None,
        af_iters=None,
        freeze_f=False,
    ):
        """
        Train u's to data
        Train a and f to physics
        """

        rng, rng_leaf = split(rng)
        params["leaf"] = tuple([
            LeafParams(
                _train_u(rng_u, pl.u, data, iters=u_iters),
                pl.a,
            ) for rng_u, pl, data in zip(split(rng_leaf, num=n_leaves),
                                         params["leaf"], datasets)
        ])

        # Root pre-training: freeze u's.
        rng, rng_af = split(rng)
        a_params, f_params = _train_af(
            params,
            rng_af,
            datasets,
            iters=af_iters,
            freeze_f=freeze_f,
        )
        params["leaf"] = tuple(
            [LeafParams(pl.u, ap) for pl, ap in zip(params["leaf"], a_params)])
        params["root"] = f_params

        return params

    def loss_fun(
        params,
        batches,
        rng,
        leaf_ns,
        root_batch=None,
        loss_type="nlp",
        u_loss=True,
        af_loss=True,
    ):
        """
        * Data: u targets
        * Physics: utt = a^2 * f

        :param leaf_ns: (tuple of ints) How many data in each leaf.
        :param root_batch: n per leaf to use in root
        """
        # Leaves:
        leaves_loss = 0.0
        a_list, f_in_list, lhs_list = [], [], []
        for lp, batch, leaf_n in zip(params["leaf"], batches, leaf_ns):
            # Data loss:
            rng, rng_u = split(rng)
            leaves_loss = leaves_loss + u_funcs["loss"](
                lp.u,
                rng_u,
                batch,
                n=leaf_n,
                loss_type=loss_type,
                reduce="sum",  # mean later
            )
            # Gather inputs needed for root/physics
            xt_root = batch.x[:root_batch]
            u_op_vals = apply_u_ops(lp.u, xt_root)
            f_in_list.append(u_op_vals[:, structure.f_inputs])
            # columns are 0,3,4 = u, uxx, uyy
            a_in = jnp.concatenate(
                (xt_root[:, :2], u_op_vals[:, structure.a_inputs]), axis=1)
            a_list.append(a_funcs["apply"](lp.a, None, a_in)[:,
                                                             0])  # (B_root,)
            lhs_list.append(utt(lp.u, xt_root))

        # Physics loss
        f_in = jnp.concatenate(f_in_list)
        rng, rng_f = split(rng)
        f_mean, f_std = f_funcs["gaussian"](params["root"], rng_f, f_in)
        f_mean, f_std = f_mean.squeeze(), f_std.squeeze()  # NN returns (N,1)'s...
        a = jnp.concatenate(a_list)
        rhs_mean, rhs_std = a * a * f_mean, a * a * f_std
        lhs = jnp.concatenate(lhs_list)

        # Yuck :( TODO into own func w/ root_type, loss_type to dispatch...
        if loss_type == "mse":
            root_loss = jnp.power(lhs - rhs_mean, 2).mean() * n
        else:  # nlp
            if root_type == "gp":  # GP
                lik_scale = transforms.apply_transform(
                    params["root"],
                    sparse_layer.param_transform)["likelihood"]["noise"]
                root_loss = -(n * norm.logpdf(
                    lhs, loc=rhs_mean, scale=f_std + lik_scale).mean() -
                              f_funcs["kl"](params["root"]))
            elif root_type == "wave":
                lik_scale = _wave_xform(params["root"][0])
                root_loss = -(n * norm.logpdf(
                    lhs, loc=rhs_mean, scale=f_std + lik_scale).mean())
        # Also yuck; can shovel into a loss aggregation func.
        if u_loss and af_loss:
            combined_loss = leaves_loss + root_loss
        elif u_loss:
            combined_loss = leaves_loss
        else:
            combined_loss = root_loss
        loss = combined_loss / n
        return loss

    def print_mses(params, datasets, rng=None):
        """
        Quick helper to show roughly how well the data & physics are learned.
        Currently it just takes a single minibatch instead of the actual full datasets.
        """
        rng = PRNGKey(42) if rng is None else rng
        leaf_ns = [len(d.x) for d in datasets]
        kwargs = {"loss_type": "mse"}
        for batch in DataLoader(datasets, 16384):
            args = (params, batch, rng, leaf_ns)
            print("u  MSE : %.2e" % loss_fun(*args, af_loss=False, **kwargs))
            print("af MSE : %.2e" % loss_fun(*args, u_loss=False, **kwargs))
            break

    # Private funcs
    def _train_u(
        rng,
        params,
        data,
        batch_size=4096,
        iters=None,
        plot_every=1000,
        checkpoint_every=1000,
    ):
        """
        Train the observation function on data, using a held-out split to early-stop
        """

        # Could definitely just pull this out as a general-purpose NN-training func and
        # give things space to breathe...
        # Collapse in your editor is your friend!

        def train_test_split(data, rng=None, n_test=None):
            """
            Create a train-test split
            """
            rng = PRNGKey(42) if rng is None else rng
            n = len(data.x)
            rng, rng_perm = split(rng)
            i = permutation(rng_perm, n)
            n_test = min(16384, int(0.1 * n)) if n_test is None else n_test
            if isinstance(n_test, float):
                n_test = int(n_test * n)
            n_train = n - n_test
            i_train, i_test = i[:n_train], i[n_train:]
            return (
                Data(data.x[i_train], data.y[i_train]),
                Data(data.x[i_test], data.y[i_test]),
            )

        show = plot_every > 0
        iters = 20000 if iters is None else iters
        data_train, data_test = train_test_split(data)
        n_train = len(data_train.x)
        dataloader = DataLoader((data_train, ), batch_size)
        opt_init, opt_update, opt_params = optimizers.adam(
            cosine_scheduler(0.001, 0.0003, iters))
        state = opt_init(params)

        @jit
        def fstep(i, state, batch):
            f, g = value_and_grad(u_funcs["loss"])(
                opt_params(state),
                None,
                batch,
                n=n_train,
                loss_type="nlp",
                reduce="mean",
            )
            return f, opt_update(i, g, state)

        jitloss = jit(partial(u_funcs["loss"], reduce="mean"))

        Checkpoint = namedtuple("Checkpoint", ("iter", "score", "params"))
        checkpoints = []
        if show:
            losses = []
            fig = plt.figure()
            ax = fig.gca()
            lines = None
        for i, batch in tqdm(enumerate(dataloader, 1), total=iters):
            if i > iters:
                break
            f, state = fstep(i, state, batch[0])
            if i % checkpoint_every == 0:
                rng, rng_loss = split(rng)
                checkpoints.append(
                    Checkpoint(
                        i,
                        jitloss(opt_params(state), rng_loss, data_test),
                        opt_params(state),
                    ))
            if show:
                losses.append(f)
                if i % plot_every == 0:
                    x_train_line, y_train_line = (
                        jnp.arange(1, len(losses), 100),
                        jnp.array(losses[::100]),
                    )
                    x_test_line, y_test_line, _ = zip(*checkpoints)
                    if lines is None:
                        lines = (
                            plt.plot(x_train_line,
                                     y_train_line,
                                     linestyle="none",
                                     marker=".")[0],
                            plt.plot(x_test_line, y_test_line)[0],
                        )
                        ax.set_xlabel("Iteration")
                        ax.set_ylabel("Loss")
                    else:
                        lines[0].set_data(x_train_line, y_train_line)
                        lines[1].set_data(x_test_line, y_test_line)
                    ax.set_xlim(x_train_line.min(),
                                max(max(x_train_line), max(x_test_line)))
                    y_min = min(min(y_train_line), min(y_test_line))
                    ax.set_yscale("symlog" if y_min < 0.0 else "log")
                    ax.set_ylim(y_min, y_train_line[:10].max())
                    plt.pause(0.001)

        return (checkpoints[np.argmin([c.score for c in checkpoints])].params
                if len(checkpoints) > 0 else opt_params(state))

    def _reinit_gp_root(params, datasets):
        """
        Helper to reinitialize a GP root with better initial guesses.
        Assumes rbf kernel

        :param params: For whole BHPM
        :param datasets: all datasets.
        :return: The BHPM params
        """
        # Avoid computing all x's & y's here if possible.
        x = jnp.concatenate([
            batch_apply(
                partial(
                    lambda params, inputs: apply_u_ops(params, inputs)
                    [:, structure.f_inputs],
                    lp.u,
                ),
                data.x,
                32000,
            ) for lp, data in zip(params["leaf"], datasets)
        ])

        # Mean function...
        # Kernel...
        params["root"]["kernel"]["raw_scales"] = jnp.log(
            x.max(axis=0) - x.min(axis=0))
        # Inducings
        xu = kmeans_centers(x, params["root"]["xu"].shape[0])
        params["root"]["xu"] = xu
        params["root"]["q_mu"] = jnp.zeros((len(xu), ))
        params["root"]["raw_q_sqrt"] = transforms.lower_cholesky()["inverse"](
            0.1 *
            jitchol(_root_kernel["apply"](params["root"]["kernel"], xu, xu)))
        # Likelihood...

        return params

    def _train_af(
        params,
        rng,
        datasets,
        iters=None,
        leaf_batch=8192,
        root_batch=8192,
        plot_every=1000,
        freeze_f=False,
    ):
        """
        Train a and f while keeping u frozen.
        """
        iters = 20000 if iters is None else iters
        show = plot_every is not None
        leaf_ns = tuple([len(d.x) for d in datasets])
        root_batch_per = root_batch // len(datasets)
        dataloader = DataLoader(datasets, leaf_batch)
        if root_type == "gp" and not freeze_f:
            params = _reinit_gp_root(params, datasets)
        # Repack params to "freeze" u...
        u_params = tuple([pl.u for pl in params["leaf"]])
        if freeze_f:
            af_params = tuple([pl.a for pl in params["leaf"]])
        else:
            af_params = (tuple([pl.a
                                for pl in params["leaf"]]), params["root"])

        opt_init, opt_update, opt_params = optimizers.adam(
            cosine_scheduler(0.001, 0.0003, iters))

        state = opt_init(af_params)

        def loss_af(af_params, batch, rng):
            if freeze_f:
                a_params = af_params
                f_params = params["root"]
            else:
                a_params, f_params = af_params
            p = {
                "leaf":
                tuple([
                    LeafParams(up, ap) for up, ap in zip(u_params, a_params)
                ]),
                "root":
                f_params,
            }
            return loss_fun(
                p,
                batch,
                rng,
                leaf_ns,
                root_batch=root_batch_per,
                loss_type="nlp",
                u_loss=False,
            )

        @jit
        def fstep(i, state, batch, rng):
            f, g = value_and_grad(loss_af)(opt_params(state), batch, rng)
            return f, opt_update(i, g, state)

        if show:
            fig = plt.figure()
            ax = fig.gca()
            line = None
            losses = []
        for i, batch in enumerate(tqdm(dataloader, total=iters), 1):
            if i > iters:
                break
            rng, rng_step = split(rng)
            loss, state = fstep(i, state, batch, rng_step)
            if show:
                losses.append(loss)
                if i % plot_every == 0:
                    x_line, y_line = (
                        jnp.arange(1, len(losses), 100),
                        jnp.array(losses[::100]),
                    )
                    if line is None:
                        line = plt.plot(x_line,
                                        y_line,
                                        linestyle="none",
                                        marker=".")[0]
                        ax.set_xlabel("Iteration")
                        ax.set_ylabel("Loss")
                    else:
                        line.set_data(x_line, y_line)
                    ax.set_xlim(x_line.min(), x_line.max())
                    y_min = y_line.min()
                    ax.set_yscale("symlog" if y_min < 0.0 else "log")
                    ax.set_ylim(y_min, y_line[:10].max())
                    plt.pause(0.001)

        if freeze_f:
            return opt_params(state), params["root"]
        else:
            return opt_params(state)

    def _grad_op(f, grads):
        """
        Apply grads (partial derivatives actually) to an op

        :return: apply(params, inputs) -> (N,)
        """
        f_single = fcnn.as_single(f)
        for g in grads:
            f_single = fcnn.ddx(f_single, g)
        g_vmap = vmap(f_single, in_axes=(None, None, 0))

        def g(params, inputs):
            return g_vmap(params, None, inputs).squeeze()

        return g

    def _div(f):
        """
        Construct divergence operator
        """
        fxx = _grad_op(f, (0, 0))
        fyy = _grad_op(f, (1, 1))

        def g(params, inputs):
            return fxx(params, inputs) + fyy(params, inputs)

        return g

    def _grad_sq(f):
        """
        Construct gradient's squared magnitude operator
        """
        fx = _grad_op(f, (0, ))
        fy = _grad_op(f, (1, ))

        def g(params, inputs):
            return jnp.power(fx(params, inputs), 2) + jnp.power(
                fy(params, inputs), 2)

        return g

    def _construct_op(f, op):
        """
        Construct either a partial derivative or a compositional operator (div or grad squared)
        """
        if isinstance(op, tuple):
            return _grad_op(f, op)
        elif op == "grad_sq":
            return _grad_sq(f)
        elif op == "div":
            return _div(f)
        else:
            raise ValueError("Operator %s not recognized" % str(op))

    # Private data:
    u_ops = tuple(
        [_construct_op(u_funcs["apply"], op) for op in structure.u_operators])
    utt = _grad_op(u_funcs["apply"], (2, 2))  # utt

    return {
        "u_funcs": u_funcs,
        "a_funcs": a_funcs,
        "f_funcs": f_funcs,
        "init": init_fun,
        "apply_u_ops": apply_u_ops,
        "apply_lhs": apply_lhs,
        "apply_rhs": apply_rhs,
        "train": train,
        "loss": loss_fun,
        "print_mses": print_mses,
    }