def G(input_shape, num_layers, activation=sigmoid, bias=False, **kwargs):
    if bias:
        activation = act_but_last(activation)
    activation = elementwise(activation)
    block = [Linear(input_shape, **kwargs), Accumulator, activation]
    layers = flatten_shallow(block for i in range(num_layers))[:-1]
    mask = flatten_shallow([False, True, False] for i in range(num_layers))[:-1]
    init_fun, apply_fun = serial_with_outputs(*layers, mask=mask)

    def init(rng, input_shape):
        output_shape, params = init_fun(rng, input_shape)
        opt_params = params[::3]
        const_params = params[1::3]
        if bias:
            opt_params = [affine_padder(p) for p in opt_params]
            const_params = [augment_dummy_params(p) for p in const_params]
        return params, opt_params, const_params  # even(params), odd(params)

    def g(opt_params, const_params, inputs):
        if bias:
            inputs = augment_data(inputs)
        o, ys = apply_fun(
            interleave(opt_params, const_params, (() for p in opt_params)),
            inputs, **kwargs)
        if bias:
            o = column_stripper(o)
            ys = [column_stripper(y) for y in ys]
        return o, ys

    return init, g
def __init__(self, indim, outdim, topology, omega=1.0, transform=None, seed=0):
    """
    Arguments
    ---------
    indim:
        Dimensionality of a single data input.
    outdim:
        Dimensionality of a single data output.
    topology: Tuple
        Defines the structure of the inner layers for the network.
    omega:
        Weight distribution factor ω₀ for the first layer (as described in [1]).
    transform: Optional[Callable]
        Optional pre-network transformation function.
    seed: Optional[int]
        Initial seed for weight initialization.
    """
    tlayer = build_transform_layer(transform)

    # Weight initialization for Sirens
    pdf_in = variance_scaling(1.0 / 3, "fan_in", "uniform")
    pdf = variance_scaling(2.0 / omega**2, "fan_in", "uniform")

    # Sine activation function
    σ_in = stax.elementwise(lambda x: np.sin(omega * x))
    σ = stax.elementwise(lambda x: np.sin(x))

    # Build layers
    layer_in = [stax.Flatten, *tlayer, stax.Dense(topology[0], pdf_in), σ_in]
    layers = list(
        chain.from_iterable((stax.Dense(i, pdf), σ) for i in topology[1:]))
    layers = layer_in + layers + [stax.Dense(outdim, pdf)]

    #
    super().__init__(indim, layers, seed)
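# Note (added, hedged): assuming `variance_scaling` above is
# jax.nn.initializers.variance_scaling, the "uniform" mode draws weights from
# U(-limit, limit) with limit = sqrt(3 * scale / fan_in). The hidden-layer
# scale 2 / omega**2 therefore gives limit = sqrt(6 / fan_in) / omega, the
# bound used for SIREN hidden layers, while the first-layer scale of 1/3
# gives limit = 1 / sqrt(fan_in). A standalone check of the hidden-layer bound
# (shapes are illustrative):
import jax
import jax.numpy as jnp

omega, fan_in = 30.0, 256
w = jax.nn.initializers.variance_scaling(2.0 / omega**2, "fan_in", "uniform")(
    jax.random.PRNGKey(0), (fan_in, 128))
assert jnp.all(jnp.abs(w) <= jnp.sqrt(6.0 / fan_in) / omega + 1e-6)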
def q_network(self):
    # no regression!
    if self.dueling:
        init, apply = stax.serial(
            elementwise(lambda x: x / 10000.0),
            stax.serial(Dense(128), Relu, Dense(64), Relu),  # base layers
            FanOut(2),
            stax.parallel(
                stax.serial(Dense(32), Relu, Dense(1)),                 # state value
                stax.serial(Dense(32), Relu, Dense(self.num_actions)))  # advantage func
        )
    else:
        init, apply = stax.serial(
            elementwise(lambda x: x / 10000.0),
            Dense(64), Relu,
            Dense(32), Relu,
            Dense(self.num_actions))
    return init, apply
def negativity_transform():
  """Layer construction function for negativity transform.

  This layer is used as the last layer of the xc energy density network, since
  exchange and correlation must be negative according to exact conditions.
  Note we use a 'soft' negativity transformation here. The range after this
  transformation is (-inf, 0.278].

  Returns:
    (init_fn, apply_fn) pair.
  """

  def negative_fn(x):
    return -nn.swish(x)

  return stax.elementwise(negative_fn)
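# Usage sketch (added, not part of the original module): append
# negativity_transform() as the final layer of a small stax model and check
# the (-inf, 0.278] range claimed above. The MLP shape is illustrative.
from jax import random
from jax.experimental import stax
import jax.numpy as jnp

init_fn, apply_fn = stax.serial(
    stax.Dense(16), stax.Softplus,
    stax.Dense(1),
    negativity_transform(),  # soft negativity as the last layer
)
_, params = init_fn(random.PRNGKey(0), (-1, 4))
out = apply_fn(params, jnp.ones((8, 4)))
assert jnp.all(out <= 0.279)  # -nn.swish(x) is bounded above by ~0.278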
def constructDuelNetwork(n_actions, seed, input_shape):
    advantage_stream = stax.serial(Dense(512), Relu, Dense(n_actions))
    state_function_stream = stax.serial(Dense(512), Relu, Dense(1))

    dueling_architecture = stax.serial(
        elementwise(lambda x: x / 255.0),
        GeneralConv(dim_nums, 32, (8, 8), strides=(4, 4)), Relu,
        GeneralConv(dim_nums, 64, (4, 4), strides=(2, 2)), Relu,
        GeneralConv(dim_nums, 64, (3, 3), strides=(1, 1)), Relu,
        Flatten,
        FanOut(2),
        parallel(advantage_stream, state_function_stream),
    )

    def duelingNetworkMapping(inputs):
        advantage_values = inputs[0]
        state_values = inputs[1]
        advantage_sums = jnp.sum(advantage_values, axis=1)
        advantage_sums = advantage_sums / float(n_actions)
        advantage_sums = advantage_sums.reshape(-1, 1)
        Q_values = state_values + (advantage_values - advantage_sums)
        return Q_values

    duelArchitectureMapping = jit(duelingNetworkMapping)

    ##### Create deep neural net
    model = DDQN(n_actions,
                 input_shape,
                 adam_params,
                 architecture=dueling_architecture,
                 seed=seed,
                 mappingFunction=duelArchitectureMapping)
    return model
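# Quick sanity check of the dueling aggregation above (added; toy numbers are
# hypothetical): Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)), so shifting all
# advantages by a constant leaves the Q-values unchanged.
import jax.numpy as jnp

advantages = jnp.array([[1.0, 2.0, 3.0]])  # (batch=1, n_actions=3)
state_value = jnp.array([[5.0]])           # (batch=1, 1)
q = state_value + (advantages - advantages.mean(axis=1, keepdims=True))
q_shifted = state_value + ((advantages + 10.0) -
                           (advantages + 10.0).mean(axis=1, keepdims=True))
assert jnp.allclose(q, q_shifted)  # both equal [[4., 5., 6.]]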
def constructSingleStreamNetwork(n_actions, seed, input_shape):
    single_stream_architecture = stax.serial(
        elementwise(lambda x: x / 255.0),  # normalize
        ### convolutional NN (CNN)
        GeneralConv(dim_nums, 32, (8, 8), strides=(4, 4)), Relu,
        GeneralConv(dim_nums, 64, (4, 4), strides=(2, 2)), Relu,
        GeneralConv(dim_nums, 64, (3, 3), strides=(1, 1)), Relu,
        Flatten,  # flatten output
        Dense(1024), Relu,
        Dense(n_actions))

    model = DDQN(n_actions,
                 input_shape,
                 adam_params,
                 architecture=single_stream_architecture,
                 seed=seed)
    return model
import jax
from jax import lax
from jax import nn
from jax import tree_util
from jax.experimental import stax
import jax.numpy as jnp
from jax.scipy import ndimage

from jax_dft import scf
from jax_dft import utils


_STAX_ACTIVATION = {
    'relu': stax.Relu,
    'elu': stax.Elu,
    'softplus': stax.Softplus,
    'swish': stax.elementwise(nn.swish),
}


def negativity_transform():
  """Layer construction function for negativity transform.

  This layer is used as the last layer of the xc energy density network, since
  exchange and correlation must be negative according to exact conditions.
  Note we use a 'soft' negativity transformation here. The range after this
  transformation is (-inf, 0.278].

  Returns:
    (init_fn, apply_fn) pair.
  """
def fcnn( output_dim: int, layers=2, units=256, skips=False, output_activation=None, first_layer_factor=None, parallel_with=None, ): """ Creates init and apply functions for a fully-connected NN with a Gaussian likelihood. :param output_dim: Number of output dimensions :param first_layer_factor: Scaling factor for first dense layer initialization. """ first_layer_factor = 10.0 if first_layer_factor is None else first_layer_factor activation = stax.elementwise(np.sin) block = (skip_connect( stax.serial(Dense(units, W_init=siren_w_init()()), activation)) if skips else stax.serial( Dense(units, W_init=siren_w_init()()), activation)) layer_list = [ stax.serial( Dense( units, W_init=siren_w_init(factor=first_layer_factor)(), b_init=jax.nn.initializers.normal(stddev=2.0 * math.pi), ), activation, ), *([block] * (layers - 1)), Dense(output_dim), # No skips on the last one! ] if output_activation is not None: layer_list.append(output_activation) if parallel_with is None: _init_fun, _apply_fun = stax.serial(*layer_list) else: _init_fun, _apply_fun = parallel(stax.serial(*layer_list), parallel_with) def init_fun(rng, input_shape): output_shape, net_params = _init_fun(rng, input_shape) params = {"net": net_params, "raw_noise": np.array(-2.0)} return output_shape, params # Conform to API: @t_wrapper def apply_fun(params, rng, inputs): return _apply_fun(params["net"], inputs) @t_wrapper def gaussian_fun(params, rng, inputs): pred_mean = apply_fun(params, None, inputs) pred_std = params["noise"] * np.ones(pred_mean.shape) return pred_mean, pred_std @t_wrapper def loss_fun(params, rng, data, batch_size=None, n=None, loss_type="nlp", reduce="sum"): """ :param batch_size: How large a batch to subselect from the provided data :param n: The total size of the dataset (to multiply batch estimate by) """ assert loss_type in ("nlp", "mse") inputs, targets = data n = inputs.shape[0] if n is None else n if batch_size is not None: rng, rng_batch = random.split(rng) i = random.permutation(rng_batch, n)[:batch_size] inputs, targets = inputs[i], targets[i] preds = apply_fun(params, rng, inputs).squeeze() mean_loss = ( -norm.logpdf(targets.squeeze(), preds, params["noise"]).mean() if loss_type == "nlp" else np.power(targets.squeeze() - preds, 2).mean()) if reduce == "sum": loss = n * mean_loss elif reduce == "mean": loss = mean_loss return loss def sample_fun_fun(rng, params): def f(x): return apply_fun(params, rng, x) return f return { "init": init_fun, "apply": apply_fun, "gaussian": gaussian_fun, "loss": loss_fun, "sample_function": sample_fun_fun, }
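# Usage sketch (added, hedged): the dict returned by fcnn bundles the model's
# functions. The shapes below and the assumption that t_wrapper preserves the
# (params, rng, inputs) signature are mine, not from the original code.
from jax import random
import jax.numpy as jnp

funcs = fcnn(output_dim=1, layers=3, units=64)
rng = random.PRNGKey(0)
_, params = funcs["init"](rng, (-1, 2))
x_batch = jnp.zeros((8, 2))                    # placeholder inputs
pred_mean = funcs["apply"](params, None, x_batch)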
def build_transform_layer(transform: Optional[Callable] = None):
    return () if transform is None else (stax.elementwise(transform), )
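# Usage sketch (added, hypothetical layer sizes; `stax` and `np` are assumed
# to be this module's jax.experimental.stax and jax.numpy imports): because
# build_transform_layer returns a tuple, it can be unpacked directly into
# stax.serial, and transform=None contributes no layer at all.
init_fn, apply_fn = stax.serial(
    *build_transform_layer(np.sin),  # one elementwise pre-network layer
    stax.Dense(32), stax.Relu,
    stax.Dense(1),
)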
def main(): num_iter = 50000 # Most people run 1000 steps and the OpenAI gym pendulum is 0.05s per step. # The max torque that can be applied is also 2 in their setup. T = 1000 time_delta = 0.05 max_torque = 2.0 rng = random.PRNGKey(0) dynamics = pendulum_dynamics( mass=1.0, length=1.0, gravity=9.8, friction=0.0, ) policy_init, policy_nn = stax.serial( Dense(64), Tanh, Dense(64), Tanh, Dense(1), Tanh, stax.elementwise(lambda x: max_torque * x), ) # Should it matter whether theta is wrapped into [0, 2pi]? policy = lambda params, x: policy_nn( params, jnp.array([x[0] % (2 * jnp.pi), x[1], jnp.cos(x[0]), jnp.sin(x[0])])) def loss(policy_params, x0): x = x0 acc_cost = 0.0 for _ in range(T): u = policy(policy_params, x) x += time_delta * dynamics(x, u) acc_cost += time_delta * cost(x, u) return acc_cost rng_init_params, rng = random.split(rng) _, init_policy_params = policy_init(rng_init_params, (4, )) opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params) loss_and_grad = jit(value_and_grad(loss)) loss_per_iter = [] elapsed_per_iter = [] x0s = vmap(sample_x0)(random.split(rng, num_iter)) for i in range(num_iter): t0 = time.time() loss, g = loss_and_grad(opt.value, x0s[i]) opt = opt.update(g) elapsed = time.time() - t0 loss_per_iter.append(loss) elapsed_per_iter.append(elapsed) print(f"Episode {i}") print(f" loss = {loss}") print(f" elapsed = {elapsed}") blt.remember({ "loss_per_iter": loss_per_iter, "elapsed_per_iter": elapsed_per_iter, "final_params": opt.value }) plt.figure() plt.plot(loss_per_iter) plt.yscale("log") plt.title("ODE control of an inverted pendulum") plt.xlabel("Iteration") plt.ylabel(f"Policy cost (T = {total_secs}s)") # Viz num_viz_rollouts = 50 framerate = 30 timesteps = jnp.linspace(0, int(T * time_delta), num=int(T * time_delta * framerate)) rollout = lambda x0: ode.odeint( lambda x, _: dynamics(x, policy(opt.value, x)), y0=x0, t=timesteps) plt.figure() states = rollout(jnp.zeros(2)) plt.plot(states[:, 0], states[:, 1], marker=".") plt.xlabel("theta") plt.ylabel("theta dot") plt.title("Swing up trajectory") plt.figure() states = vmap(rollout)(x0s[:num_viz_rollouts]) for i in range(num_viz_rollouts): plt.plot(states[i, :, 0], states[i, :, 1], marker='.', alpha=0.5) plt.xlabel("theta") plt.ylabel("theta dot") plt.title("Phase space trajectory") plot_control_contour(lambda x: policy(opt.value, x)) plot_policy_dynamics(dynamics, lambda x: policy(opt.value, x)) blt.show()
return output1, (params1, params2) def apply_fun(params, inputs, **kwargs): # TODO: don't use diffeq wrapper, get t, x = inputs?? params1, params2 = params t1, out1 = apply_fn1(params1, inputs, **kwargs) t2, out2 = apply_fn2(params2, inputs, **kwargs) # t1 == t2 since sdeint calls return t2, out2 + out1 return init_fun, apply_fun # activations rbf = lambda x: np.exp(-x**2) Rbf = elementwise(rbf) Rbf = register('rbf')(Rbf) # @register('rbf') Elu = elementwise(jax.nn.elu) Elu = register('elu')(stax.Elu) Softplus = register('softplus')(stax.Softplus) swish = lambda x: x * jax.nn.sigmoid(x) Swish = elementwise(swish) Swish = register('swish_nobeta')(Swish) Relu = register('relu')(stax.Relu) Tanh = register('tanh')(stax.Tanh) @register('swish') def Swish_(out_dim, beta_init=0.5): """ Trainable Swish function to learn
    averages_mat = averages_mat[1:, :].T
    averages_df = pd.DataFrame(data=averages_mat,
                               index=adata_ref.raw.var_names,
                               columns=all_clusters)

    return averages_df


import jax
from jax import random
from jax.experimental import stax
import jax.numpy as jnp
from jax.random import PRNGKey
import numpyro

Log1p = stax.elementwise(jax.lax.log1p)


def encoder(hidden_dim, z_dim):
    return stax.serial(
        Log1p,
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp)),
    )


def decoder(hidden_dim, out_dim):
    def apply_fun(params, inputs, **kwargs):
        return inputs.sum(axis=-1)

    return init_fun, apply_fun


SumLayer = SumLayer()


def logcosh(x):
    x = x * jax.numpy.sign(x.real)
    return x + jax.numpy.log(1.0 + jax.numpy.exp(-2.0 * x)) - jax.numpy.log(2.0)


LogCoshLayer = stax.elementwise(logcosh)


def JaxRbm(hilbert, alpha, dtype=complex):
    return Jax(
        hilbert,
        stax.serial(stax.Dense(alpha * hilbert.size), LogCoshLayer, SumLayer),
        dtype=dtype,
    )


def MPSPeriodic(hilbert, graph, bond_dim, diag=False, symperiod=None,
import jax.numpy as np
from jax import random
from jax.experimental.stax import elementwise, serial, Dense, Softmax
from jax.nn import sigmoid
from jax.nn.initializers import glorot_normal, kaiming_normal, orthogonal

# from utils import flatten_shallow, interleave
from utils import *

Abs = elementwise(np.abs)


def serial_with_outputs(*layers, mask):
    """Combinator for composing layers in serial.

    Args:
      *layers: a sequence of layers, each an (init_fun, apply_fun) pair.

    Returns:
      A new layer, meaning an (init_fun, apply_fun) pair, representing the
      serial composition of the given sequence of layers.
    """
    nlayers = len(layers)
    init_funs, apply_funs = zip(*layers)

    def init_fun(rng, input_shape):
        params = []
        for init_fun in init_funs:
            rng, layer_rng = random.split(rng)
            input_shape, param = init_fun(layer_rng, input_shape)
            params.append(param)
        return input_shape, params

    def apply_fun(params, inputs, **kwargs):
jnp.array(x) * jax.nn.softplus(beta)) # no / 1.1 for lipschitz def init_fun(rng, input_shape): output_shape = input_shape[:-1] + (out_dim, ) beta0 = jnp.ones((out_dim, )) * beta_init return input_shape, (jnp.array(beta0), ) def apply_fun(params, inputs, **kwargs): beta_, = params ret = swish(beta_, inputs) return ret return init_fun, apply_fun Rbf = stax.elementwise(lambda x: jnp.exp(-x**2)) ACTFNS = { "softplus": stax.Softplus, "tanh": stax.Tanh, "elu": stax.Elu, "rbf": Rbf, "swish": Swish_ } def MLP(hidden_dims=[1, 64, 1], actfn="softplus", xt=False, ou_dw=False, p_scale=-1.0, nonzero_w=-1.0,
import functools
import sys

import my_sampler
import random


@jax.jit
def logcosh(x):
    """logcosh activation function.

    To use this function as a layer, use LogCoshLayer.
    """
    x = x * jax.numpy.sign(x.real)
    return x + jax.numpy.log(1.0 + jax.numpy.exp(-2.0 * x)) - jax.numpy.log(2.0)


LogCoshLayer = stax.elementwise(logcosh)


# https://arxiv.org/pdf/1705.09792.pdf
# complex activation function, see https://arxiv.org/pdf/1802.08026.pdf
@jax.jit
def modrelu(x):
    """modrelu activation function.

    To use this function as a layer, use ModReLu.
    See https://arxiv.org/pdf/1705.09792.pdf
    """
    return jnp.maximum(1, jnp.abs(x)) * x / jnp.abs(x)


ModReLu = stax.elementwise(modrelu)


# https://arxiv.org/pdf/1705.09792.pdf
# complex activation function, see https://arxiv.org/pdf/1802.08026.pdf
@jax.jit
def learned_dynamics(params):
    @jit
    def dynamics(q, q_t):
        # assert q.shape == (2,)
        state = wrap_coords(jnp.concatenate([q, q_t]))
        return jnp.squeeze(nn_forward_fn(params, state), axis=-1)
    return dynamics


from jax.experimental.stax import serial, Dense, Softplus, Tanh, elementwise, Relu

sigmoid = jit(lambda x: 1 / (1 + jnp.exp(-x)))
swish = jit(lambda x: x / (1 + jnp.exp(-x)))
relu3 = jit(lambda x: jnp.clip(x, 0.0, float('inf'))**3)
Swish = elementwise(swish)
Relu3 = elementwise(relu3)


def extended_mlp(args):
    act = {
        'softplus': [Softplus, Softplus],
        'swish': [Swish, Swish],
        'tanh': [Tanh, Tanh],
        'tanh_relu': [Tanh, Relu],
        'soft_relu': [Softplus, Relu],
        'relu_relu': [Relu, Relu],
        'relu_relu3': [Relu, Relu3],
        'relu3_relu': [Relu3, Relu],
        'relu_tanh': [Relu, Tanh],
    }[args.act]
def _elementwise(fun, **fun_kwargs):
  init_fun, apply_fun = stax.elementwise(fun, **fun_kwargs)
  ker_fun = lambda kernels: _transform_kernels(kernels, fun, **fun_kwargs)
  return init_fun, apply_fun, ker_fun
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from jax.experimental.stax import (AvgPool, BatchNorm, Conv, Dense, FanInSum,
                                   FanOut, Flatten, GeneralConv, Identity,
                                   MaxPool, Relu, LogSoftmax, Softplus)
from jax.nn import sigmoid, swish
from neural_tangents import stax
import jax.experimental.stax as jax_stax

from layers import MyConv, MyDense

Swish = jax_stax.elementwise(swish)


def swish_ten(x):
    return x * sigmoid(10 * x)


Swishten = jax_stax.elementwise(swish_ten)


def CNNStandard(n_channels, L, filter=(3, 3), data='cifar10', gap=True,
                nonlinearity='relu', parameterization='standard',
from research.estop import mdp

tau = 1e-4
buffer_size = 2**15
batch_size = 64
num_eval_rollouts = 128

opt_init = make_optimizer(optimizers.adam(step_size=1e-3))

init_noise = Normal(jp.array(0.0), jp.array(0.0))
noise = lambda _1, _2: Normal(jp.array(0.0), jp.array(0.5))

actor_init, actor = stax.serial(
    Dense(64),
    Relu,
    Dense(1),
    Tanh,
    stax.elementwise(lambda x: config.max_torque * x),
)

critic_init, critic = stax.serial(
    FanInConcat(),
    Dense(64),
    Relu,
    Dense(64),
    Relu,
    Dense(1),
    stax.elementwise(lambda x: x + 1.0 / (1 - config.gamma)),
    Scalarify,
)

policy = lambda p: lambda s: Deterministic(actor(p, s))
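# Note with a toy check (added): Tanh bounds the raw actor output in (-1, 1),
# so the final elementwise rescale maps actions into (-max_torque, max_torque).
# The critic's constant shift of 1 / (1 - gamma) offsets its outputs by the
# scale of a geometric discounted sum, presumably to start Q estimates near
# that range. The value below is illustrative, mirroring config.max_torque.
import jax.numpy as jnp

max_torque = 2.0
raw = jnp.tanh(jnp.array([-10.0, 0.0, 10.0]))  # in (-1, 1)
action = max_torque * raw                      # in (-2, 2)
assert jnp.all(jnp.abs(action) <= max_torque)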
def _hpm(structure, datasets, root_type="gp"): """ Custom hidden physics model for the ultrasound problem. Leaves have a(x,y) and u(x,y,t). Both are MLPs. Physics is utt = a(x, y) f(...), where f is a GP or the wave operator. :return: HPM """ n_leaves = len(datasets) n = sum([d[0].shape[0] for d in datasets]) # Public funcs u_funcs = fcnn.fcnn(1, layers=5, units=128, first_layer_factor=1) a_funcs = fcnn.fcnn( 1, layers=5, units=128, first_layer_factor=8.0, output_activation=stax.serial(stax.elementwise(lambda x: 0.1 * x), stax.elementwise(jnp.exp)), ) if root_type == "gp": _root_mean = gp_mean_functions.linear() _root_kernel = kernels.rbf() _root_likelihood = likelihoods.gaussian() f_funcs = sparse_layer.sparse_layer(n, _root_mean, _root_kernel, _root_likelihood, safe=False, jitter=1.0e-5) elif root_type == "wave": if not (len(structure.f_inputs) == 1 and structure.u_operators[structure.f_inputs[0]] == "div"): raise ValueError( 'Wave operator requires "div" be the only u operator fed to f.' ) f_funcs = wave() def init_fun(rng): params = {"leaf": [], "root": None} for _ in range(n_leaves): rng, rng_u, rng_a = split(rng, num=3) u_params = u_funcs["init"](rng_u, (-1, 3))[1] # x, y, t a_params = a_funcs["init"]( rng_a, (-1, 2 + len(structure.a_inputs)))[1] # x, y, ... params["leaf"].append(LeafParams(u_params, a_params)) rng, rng_root = split(rng) if root_type == "gp": params["root"] = f_funcs["init"](rng_root, (-1, len(structure.f_inputs)), m=NUM_INDUCING)[1] elif root_type == "wave": params["root"] = f_funcs["init"]()[1] return params def apply_u_ops(params, inputs): """ :param params: For u net :param inputs: (N, 3) :return: jnp.array, shape (N_Ops,N) """ return jnp.stack([op(params, inputs) for op in u_ops]).T def apply_lhs(params, inputs, leaf): """ Compute utt(inputs) for a specified leaf :param params: For the whole HPM :param inputs: (N,3) :return: (N,) """ u_params = params["leaf"][leaf].u return utt(u_params, inputs) def apply_rhs(params, inputs, leaf, rng=None): """ Compute a^2()f() for a specified leaf. Use the mean of f for now. :params inputs: (N,3) :return: (N,) """ rng = PRNGKey(42) if rng is None else rng f_params = params["root"] u_params, a_params = params["leaf"][leaf] u_op_vals = apply_u_ops(u_params, inputs) f_inputs = u_op_vals[:, structure.f_inputs] a_inputs = jnp.concatenate( (inputs[:, :2], u_op_vals[:, structure.a_inputs]), axis=1) a = a_funcs["apply"](a_params, None, a_inputs)[:, 0] # (N,) rng, rng_f = split(rng) # Posterior mean, flattened from (N,1) to (N,) f = f_funcs["gaussian"](f_params, rng_f, f_inputs)[0].squeeze() return a * a * f def train( rng, params, datasets, u_iters=None, af_iters=None, freeze_f=False, ): """ Train u's to data Train a and f to physics """ rng, rng_leaf = split(rng) params["leaf"] = tuple([ LeafParams( _train_u(rng_u, pl.u, data, iters=u_iters), pl.a, ) for rng_u, pl, data in zip(split(rng_leaf, num=n_leaves), params["leaf"], datasets) ]) # Root pre-training: freeze u's. rng, rng_af = split(rng) a_params, f_params = _train_af( params, rng_af, datasets, iters=af_iters, freeze_f=freeze_f, ) params["leaf"] = tuple( [LeafParams(pl.u, ap) for pl, ap in zip(params["leaf"], a_params)]) params["root"] = f_params return params def loss_fun( params, batches, rng, leaf_ns, root_batch=None, loss_type="nlp", u_loss=True, af_loss=True, ): """ * Data: u targets * Physics: utt = a^2 * f :param leaf_ns: (tuple of ints) How many data in each leaf. 
:param root_batch: n per leaf to use in root """ # Leaves: leaves_loss = 0.0 a_list, f_in_list, lhs_list = [], [], [] for lp, batch, leaf_n in zip(params["leaf"], batches, leaf_ns): # Data loss: rng, rng_u = split(rng) leaves_loss = leaves_loss + u_funcs["loss"]( lp.u, rng_u, batch, n=leaf_n, loss_type=loss_type, reduce="sum", # mean later ) # Gather inputs needed for root/physics xt_root = batch.x[:root_batch] u_op_vals = apply_u_ops(lp.u, xt_root) f_in_list.append(u_op_vals[:, structure.f_inputs]) # columns are 0,3,4 = u, uxx, uyy a_in = jnp.concatenate( (xt_root[:, :2], u_op_vals[:, structure.a_inputs]), axis=1) a_list.append(a_funcs["apply"](lp.a, None, a_in)[:, 0]) # (B_root,) lhs_list.append(utt(lp.u, xt_root)) # Physics loss f_in = jnp.concatenate(f_in_list) rng, rng_f = split(rng) f_mean, f_std = f_funcs["gaussian"](params["root"], rng_f, f_in) f_mean, f_std = f_mean.squeeze(), f_std.squeeze( ) # NN returns (N,1)'s... a = jnp.concatenate(a_list) rhs_mean, rhs_std = a * a * f_mean, a * a * f_std lhs = jnp.concatenate(lhs_list) # Yuck :( TODO into own func w/ root_type, loss_type to dispatch... if loss_type == "mse": root_loss = jnp.power(lhs - rhs_mean, 2).mean() * n else: # nlp if root_type == "gp": # GP lik_scale = transforms.apply_transform( params["root"], sparse_layer.param_transform)["likelihood"]["noise"] root_loss = -(n * norm.logpdf( lhs, loc=rhs_mean, scale=f_std + lik_scale).mean() - f_funcs["kl"](params["root"])) elif root_type == "wave": lik_scale = _wave_xform(params["root"][0]) root_loss = -(n * norm.logpdf( lhs, loc=rhs_mean, scale=f_std + lik_scale).mean()) # Also yuck; can shovel into a loss aggregation func. if u_loss and af_loss: combined_loss = leaves_loss + root_loss elif u_loss: combined_loss = leaves_loss else: combined_loss = root_loss loss = combined_loss / n return loss def print_mses(params, datasets, rng=None): """ Quick helper to show about how good the data & physics are learned. Current just takes a single minibatch instead of the actual full datasets. """ rng = PRNGKey(42) if rng is None else rng leaf_ns = [len(d.x) for d in datasets] kwargs = {"loss_type": "mse"} for batch in DataLoader(datasets, 16384): args = (params, batch, rng, leaf_ns) print("u MSE : %.2e" % loss_fun(*args, af_loss=False, **kwargs)) print("af MSE : %.2e" % loss_fun(*args, u_loss=False, **kwargs)) break # Private funcs def _train_u( rng, params, data, batch_size=4096, iters=None, plot_every=1000, checkpoint_every=1000, ): """ Train observation function on data, using cross-validation to early stop """ # Could definitely just pull this out as a general-purpose NN-training func and # give things space to breathe... # Collapse in your editor is your friend! 
def train_test_split(data, rng=None, n_test=None): """ Create a train-test split """ rng = PRNGKey(42) if rng is None else rng n = len(data.x) rng, rng_perm = split(rng) i = permutation(rng, n) n_test = min(16384, int(0.1 * n)) if n_test is None else n_test if isinstance(n_test, float): n_test = int(n_test * n) n_train = n - n_test i_train, i_test = i[:n_train], i[n_train:] return ( Data(data.x[i_train], data.y[i_train]), Data(data.x[i_test], data.y[i_test]), ) show = plot_every > 0 iters = 20000 if iters is None else iters data_train, data_test = train_test_split(data) n_train = len(data_train.x) dataloader = DataLoader((data_train, ), batch_size) opt_init, opt_update, opt_params = optimizers.adam( cosine_scheduler(0.001, 0.0003, iters)) state = opt_init(params) @jit def fstep(i, state, batch): f, g = value_and_grad(u_funcs["loss"])( opt_params(state), None, batch, n=n_train, loss_type="nlp", reduce="mean", ) return f, opt_update(i, g, state) jitloss = jit(partial(u_funcs["loss"], reduce="mean")) Checkpoint = namedtuple("Checkpoint", ("iter", "score", "params")) checkpoints = [] if show: losses = [] fig = plt.figure() ax = fig.gca() lines = None for i, batch in tqdm(enumerate(dataloader, 1), total=iters): if i > iters: break f, state = fstep(i, state, batch[0]) if i % checkpoint_every == 0: rng, rng_loss = split(rng) checkpoints.append( Checkpoint( i, jitloss(opt_params(state), rng_loss, data_test), opt_params(state), )) if show: losses.append(f) if i % plot_every == 0: x_train_line, y_train_line = ( jnp.arange(1, len(losses), 100), jnp.array(losses[::100]), ) x_test_line, y_test_line, _ = zip(*checkpoints) if lines is None: lines = ( plt.plot(x_train_line, y_train_line, linestyle="none", marker=".")[0], plt.plot(x_test_line, y_test_line)[0], ) ax.set_xlabel("Iteration") ax.set_ylabel("Loss") else: lines[0].set_data(x_train_line, y_train_line) lines[1].set_data(x_test_line, y_test_line) ax.set_xlim(x_train_line.min(), max(max(x_train_line), max(x_test_line))) y_min = min(min(y_train_line), min(y_test_line)) ax.set_yscale("symlog" if y_min < 0.0 else "log") ax.set_ylim(y_min, y_train_line[:10].max()) plt.pause(0.001) return (checkpoints[np.argmin([c.score for c in checkpoints])].params if len(checkpoints) > 0 else opt_params(state)) def _reinit_gp_root(params, datasets): """ Helper to reinitialize a GP root with better initial guesses. Assumes rbf kernel :param params: For whole BHPM :param datasets: all datasets. :return: The BHPM params """ # Avoid computing all x's & y's here if possible. x = jnp.concatenate([ batch_apply( partial( lambda params, inputs: apply_u_ops(params, inputs) [:, structure.f_inputs], lp.u, ), data.x, 32000, ) for lp, data in zip(params["leaf"], datasets) ]) # Mean function... # Kernel... params["root"]["kernel"]["raw_scales"] = jnp.log( x.max(axis=0) - x.min(axis=0)) # Inducings xu = kmeans_centers(x, params["root"]["xu"].shape[0]) params["root"]["xu"] = xu params["root"]["q_mu"] = jnp.zeros((len(xu), )) params["root"]["raw_q_sqrt"] = transforms.lower_cholesky()["inverse"]( 0.1 * jitchol(_root_kernel["apply"](params["root"]["kernel"], xu, xu))) # Likelihood... return params def _train_af( params, rng, datasets, iters=None, leaf_batch=8192, root_batch=8192, plot_every=1000, freeze_f=False, ): """ Train a and f while keeping u frozen. 
""" iters = 20000 if iters is None else iters show = plot_every is not None leaf_ns = tuple([len(d.x) for d in datasets]) root_batch_per = root_batch // len(datasets) dataloader = DataLoader(datasets, leaf_batch) if root_type == "gp" and not freeze_f: params = _reinit_gp_root(params, datasets) # Repack params to "freeze" u... u_params = tuple([pl.u for pl in params["leaf"]]) if freeze_f: af_params = tuple([pl.a for pl in params["leaf"]]) else: af_params = (tuple([pl.a for pl in params["leaf"]]), params["root"]) opt_init, opt_update, opt_params = optimizers.adam( cosine_scheduler(0.001, 0.0003, iters)) state = opt_init(af_params) def loss_af(af_params, batch, rng): if freeze_f: a_params = af_params f_params = params["root"] else: a_params, f_params = af_params p = { "leaf": tuple([ LeafParams(up, ap) for up, ap in zip(u_params, a_params) ]), "root": f_params, } return loss_fun( p, batch, rng, leaf_ns, root_batch=root_batch_per, loss_type="nlp", u_loss=False, ) @jit def fstep(i, state, batch, rng): f, g = value_and_grad(loss_af)(opt_params(state), batch, rng) return f, opt_update(i, g, state) if show: fig = plt.figure() ax = fig.gca() line = None losses = [] for i, batch in enumerate(tqdm(dataloader, total=iters), 1): if i > iters: break rng, rng_step = split(rng) loss, state = fstep(i, state, batch, rng_step) if show: losses.append(loss) if i % plot_every == 0: x_line, y_line = ( jnp.arange(1, len(losses), 100), jnp.array(losses[::100]), ) if line is None: line = plt.plot(x_line, y_line, linestyle="none", marker=".")[0] ax.set_xlabel("Iteration") ax.set_ylabel("Loss") else: line.set_data(x_line, y_line) ax.set_xlim(x_line.min(), x_line.max()) y_min = y_line.min() ax.set_yscale("symlog" if y_min < 0.0 else "log") ax.set_ylim(y_min, y_line[:10].max()) plt.pause(0.001) if freeze_f: return opt_params(state), params["root"] else: return opt_params(state) def _grad_op(f, grads): """ Apply grads (partial derivatives actually) to an op :return: apply(params, inputs) -> (N,) """ f_single = fcnn.as_single(f) for g in grads: f_single = fcnn.ddx(f_single, g) g_vmap = vmap(f_single, in_axes=(None, None, 0)) def g(params, inputs): return g_vmap(params, None, inputs).squeeze() return g def _div(f): """ Construct divergence operator """ fxx = _grad_op(f, (0, 0)) fyy = _grad_op(f, (1, 1)) def g(params, inputs): return fxx(params, inputs) + fyy(params, inputs) return g def _grad_sq(f): """ Construct gradient's squared magnitude operator """ fx = _grad_op(f, (0, )) fy = _grad_op(f, (1, )) def g(params, inputs): return jnp.power(fx(params, inputs), 2) + jnp.power( fy(params, inputs), 2) return g def _construct_op(f, op): """ Construct either a partial derivative or a compositional operator (div or grad squared) """ if isinstance(op, tuple): return _grad_op(f, op) elif op == "grad_sq": return _grad_sq(f) elif op == "div": return _div(f) else: return ValueError("Operator %s not recognized" % str(op)) # Private data: u_ops = tuple( [_construct_op(u_funcs["apply"], op) for op in structure.u_operators]) utt = _grad_op(u_funcs["apply"], (2, 2)) # utt return { "u_funcs": u_funcs, "a_funcs": a_funcs, "f_funcs": f_funcs, "init": init_fun, "apply_u_ops": apply_u_ops, "apply_lhs": apply_lhs, "apply_rhs": apply_rhs, "train": train, "loss": loss_fun, "print_mses": print_mses, }