NuX is a library for building normalizing flows using JAX.
Normalizing flows learn a parametric model of an unknown probability density from data. We assume that data points are sampled i.i.d. from an unknown distribution p(x); a normalizing flow learns a parametric approximation of p(x) by maximum likelihood. The learned distribution can be sampled from efficiently, and its log-likelihood can be evaluated exactly.
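Concretely, if an invertible map f sends data x to a latent z with a simple base density, the change-of-variables formula gives log p(x) = log p(f(x)) + log |det Jf(x)|, which is what makes the log-likelihood exact. A minimal hand-rolled 1-D sketch of this idea (plain JAX, not NuX; `mu` and `sigma` are made-up constants for illustration):

```python
import jax.numpy as jnp

# Toy 1-D "flow": an affine map f(x) = (x - mu) / sigma sends data to a
# standard normal latent z. The change-of-variables formula then gives
#   log p(x) = log N(f(x); 0, 1) + log |df/dx|.
mu, sigma = 2.0, 0.5  # made-up parameters for illustration

def log_prob(x):
    z = (x - mu) / sigma                            # forward pass f(x)
    log_pz = -0.5 * (z ** 2 + jnp.log(2 * jnp.pi))  # standard normal log-density
    log_det = -jnp.log(sigma)                       # log |df/dx| = -log sigma
    return log_pz + log_det                         # exact log-likelihood

print(log_prob(jnp.array(2.5)))
```

Because f here is affine, this recovers exactly the Gaussian N(mu, sigma²) density; real flows compose many such invertible layers so that the modeled density can be far more expressive.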
It is easy to build, train, and evaluate normalizing flows with NuX:
```python
import nux
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Build a dummy dataset
x_train, x_test = jnp.ones((2, 100, 2))
train_inputs, test_inputs = {"x": x_train}, {"x": x_test}

# Build a simple normalizing flow
def create_flow():
    return nux.sequential(nux.Coupling(), nux.AffineLDU(), nux.UnitGaussianPrior())

# Perform data-dependent initialization
flow = nux.Flow(create_flow, key, train_inputs, batch_axes=(0,))

# Run the flow on inputs
outputs = flow.apply(key, test_inputs)
finv_x, log_px = outputs["x"], outputs["log_px"]

# Generate reconstructions
outputs = flow.reconstruct(key, {"x": finv_x})
reconstr = outputs["x"]

# Sample from the flow
outputs = flow.sample(key, n_samples=8)
fz, log_pfz = outputs["x"], outputs["log_px"]

# Construct a maximum likelihood trainer for the flow
trainer = nux.MaximumLikelihoodTrainer(flow)

# Train the flow
keys = jax.random.split(key, 10)
for train_key in keys:
    trainer.grad_step(train_key, train_inputs)
```
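Maximum likelihood training amounts to minimizing the average negative log-likelihood with gradient descent. A minimal pure-JAX sketch of that pattern on a toy 1-D affine model (an illustration of the training loop, not NuX's actual trainer internals; the parameterization and learning rate are made up):

```python
import jax
import jax.numpy as jnp

# Toy model: f(x) = (x - mu) / sigma maps data to a standard normal latent.
# We fit (mu, log_sigma) by gradient descent on the negative log-likelihood.
def nll(params, x):
    mu, log_sigma = params
    z = (x - mu) * jnp.exp(-log_sigma)             # forward pass f(x)
    log_pz = -0.5 * (z ** 2 + jnp.log(2 * jnp.pi))  # base density term
    log_px = log_pz - log_sigma                     # add log |df/dx|
    return -log_px.mean()                           # mean negative log-likelihood

params = jnp.array([0.0, 0.0])   # initial (mu, log_sigma)
x = jnp.array([1.0, 2.0, 3.0])   # made-up training data
lr = 0.1                         # made-up learning rate

for _ in range(200):
    params = params - lr * jax.grad(nll)(params, x)
# mu converges to the sample mean, sigma to the sample standard deviation
```

NuX's `MaximumLikelihoodTrainer` packages this loss-and-gradient-step pattern for full flows, with `grad_step` consuming a PRNG key and a batch of inputs as shown above.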