
ColCarroll/numpyro

NumPyro (unstable)

Pyro on NumPy, using JAX for automatic differentiation (autograd) and JIT compilation. This is an early-stage, experimental library under active development: expect many changes to the API and internal classes as the design evolves.
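As a rough illustration of what autograd and JIT support mean here, the sketch below uses plain JAX (no NumPyro code) to differentiate and compile a toy log-density. The function names are hypothetical and chosen only for this example.

```python
import jax
import jax.numpy as jnp


def log_density(mu, x):
    # Unnormalized log-density of observations x under Normal(mu, 1).
    return -0.5 * jnp.sum((x - mu) ** 2)


# jax.grad differentiates with respect to the first argument (mu);
# jax.jit compiles the resulting gradient function with XLA.
grad_log_density = jax.jit(jax.grad(log_density))

x = jnp.array([1.0, 2.0, 3.0])
print(grad_log_density(0.0, x))  # sum(x - mu) at mu = 0, i.e. 6.0
```

The gradient of this log-density with respect to mu is the sum of residuals, which is the kind of quantity a gradient-based sampler such as Hamiltonian Monte Carlo needs at every step.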

Design Goals

  • Lightweight - We do not intend to reimplement Pyro's heavier inference machinery; instead, we want to provide a flexible substrate that others can build on. We will support Pyro primitives such as sample and param, whose side effects are interpreted by effect handlers. Users should be able to extend this to implement custom inference algorithms and to write their models using the familiar NumPy API.
  • Functional - The API for the inference algorithms and other utility functions may deviate from Pyro's in favor of a more functional style that works better with JAX. For example, there is no global parameter store and no global random state.
  • Fast - Using JAX, we aim to aggressively JIT-compile intermediate computations into XLA-optimized kernels. We will evaluate the benefits of JIT compilation and benchmark runtime for Hamiltonian Monte Carlo (HMC).

Longer-term Plans

Much of this code may eventually be absorbed into the Pyro project itself as an alternative NumPy backend.