C-J-Cundy/NuX

NuX - Normalizing Flows using JAX

What is NuX?

NuX is a library for building normalizing flows using JAX.

What are normalizing flows?

Normalizing flows are parametric models of an unknown probability density that are fit to data. We assume that data points are sampled i.i.d. from an unknown distribution p(x). A normalizing flow approximates p(x) by passing a simple base distribution through an invertible, differentiable transformation whose parameters are trained by maximum likelihood. Because the transformation is invertible, the learned distribution can be sampled from efficiently and its log-likelihood can be evaluated exactly.
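To make the idea concrete, here is a minimal sketch of the smallest possible flow: a one-dimensional affine map z = (x - mu) / sigma with a unit Gaussian base distribution. It uses only the Python standard library and is not NuX's API; the function names (fit_affine_flow, log_prob, sample) are hypothetical and chosen for illustration. For this particular flow the maximum likelihood parameters have a closed form, so no gradient steps are needed. Larger flows generally need iterative training.

```python
import math
import random

def fit_affine_flow(xs):
    """Closed-form maximum likelihood fit of the flow z = (x - mu) / sigma."""
    n = len(xs)
    mu = sum(xs) / n
    sigma = math.sqrt(sum((x - mu) ** 2 for x in xs) / n)
    return mu, sigma

def log_prob(x, mu, sigma):
    """Exact log-density via the change-of-variables formula."""
    z = (x - mu) / sigma                                   # map data to base space
    log_base = -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)  # unit Gaussian log-density
    log_det = -math.log(sigma)                             # log |dz/dx|
    return log_base + log_det

def sample(rng, mu, sigma):
    """Draw from the base distribution, then apply the inverse map."""
    z = rng.gauss(0.0, 1.0)
    return mu + sigma * z

# Toy data: i.i.d. draws from a Gaussian the model does not know in advance.
rng = random.Random(0)
data = [rng.gauss(3.0, 2.0) for _ in range(5000)]

mu, sigma = fit_affine_flow(data)
avg_ll = sum(log_prob(x, mu, sigma) for x in data) / len(data)
new_x = sample(rng, mu, sigma)
```

The same three pieces (an invertible map, its log-determinant, and a base density) are what a flow library composes at scale. Here the map is a single shift and scale, which is why the fit reduces to the sample mean and standard deviation.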

Why use NuX?

NuX makes it easy to build, train, and evaluate normalizing flows:

import nux
import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(0)

# Build a dummy dataset: two arrays of shape (100, 2), unpacked along the first axis
x_train, x_test = jnp.ones((2, 100, 2))
train_inputs, test_inputs = {"x": x_train}, {"x": x_test}

# Build a simple normalizing flow
def create_flow():
  return nux.sequential(nux.Coupling(), nux.AffineLDU(), nux.UnitGaussianPrior())

# Perform data-dependent initialization
flow = nux.Flow(create_flow, key, train_inputs, batch_axes=(0,))

# Run the flow on inputs
outputs = flow.apply(key, test_inputs)
finv_x, log_px = outputs["x"], outputs["log_px"]

# Generate reconstructions
outputs = flow.reconstruct(key, {"x": finv_x})
reconstr = outputs["x"]

# Sample from the flow
outputs = flow.sample(key, n_samples=8)
fz, log_pfz = outputs["x"], outputs["log_px"]

# Construct a maximum likelihood trainer for the flow
trainer = nux.MaximumLikelihoodTrainer(flow)

# Train the flow
keys = jax.random.split(key, 10)
for key in keys:
  trainer.grad_step(key, train_inputs)
