Inference Combinators in JAX

coix

Coix (COmbinators In jaX) is a flexible and backend-agnostic implementation of inference combinators (Stites and Zimmermann et al., 2021), a set of program transformations for compositional inference with probabilistic programs. Coix ships with backends for numpyro and oryx, along with a set of pre-implemented losses and utility functions that make it possible to implement and run a wide variety of inference algorithms out of the box.
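At the heart of combinators such as "propose" is importance weighting: a proposal program draws a sample, the target program scores it, and the log weight is log p(x) - log q(x). The stdlib-only sketch below illustrates that idea for a single Gaussian target and proposal. The functions `propose`, `proposal` and `target_log_density` are hypothetical illustrations of the concept, not the coix API, and the densities are assumptions chosen for the example.

```python
import math
import random


def normal_logpdf(x, loc, scale):
    # Log density of a univariate Normal(loc, scale).
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def target_log_density(x):
    # Hypothetical target program: Normal(1, 1).
    return normal_logpdf(x, 1.0, 1.0)


def proposal(rng):
    # Hypothetical proposal program: draws x ~ Normal(0, 2), returns (x, log q(x)).
    x = rng.gauss(0.0, 2.0)
    return x, normal_logpdf(x, 0.0, 2.0)


def propose(target_logp, proposal_fn, rng):
    # Run the proposal, then score its sample under the target.
    # The log importance weight is log p(x) - log q(x).
    x, log_q = proposal_fn(rng)
    return x, target_logp(x) - log_q


def self_normalized_mean(samples):
    # Combine weighted samples into an estimate of E_p[x],
    # subtracting the max log weight for numerical stability.
    max_lw = max(lw for _, lw in samples)
    weights = [math.exp(lw - max_lw) for _, lw in samples]
    total = sum(weights)
    return sum(w * x for w, (x, _) in zip(weights, samples)) / total


rng = random.Random(0)
samples = [propose(target_log_density, proposal, rng) for _ in range(20000)]
print(self_normalized_mean(samples))  # approximately 1.0, the target mean
```

Because the proposal here is wider than the target, the weights stay bounded; a real combinator library additionally tracks these weights across composed programs.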

Coix is a lightweight framework that includes the following main components:

  • coix.api: Implementation of the program combinators.
  • coix.core: Basic program transformations used to modify the behavior of a stochastic program.
  • coix.loss: Common objectives for variational inference.
  • coix.algo: Example inference algorithms.
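Basic program transformations are commonly built on effect handlers: the program announces each random choice through a named sample call, and a handler decides how to treat it. The stdlib-only sketch below shows two such transformations, tracing (recording every choice with its log density) and substitution (replaying previously recorded values). The `Handler` class and its `sample` method are hypothetical and are not the coix.core API; whether coix uses exactly this mechanism internally is an assumption.

```python
import math
import random


class Handler:
    """Hypothetical effect handler that intercepts named `sample` calls."""

    def __init__(self, rng, values=None):
        self.rng = rng
        self.values = dict(values or {})  # substituted values, keyed by site name
        self.trace = {}                   # site name -> (value, log_prob)

    def sample(self, name, sampler, log_prob):
        # Reuse a substituted value if one exists; otherwise draw a fresh one.
        value = self.values[name] if name in self.values else sampler(self.rng)
        self.trace[name] = (value, log_prob(value))
        return value


def normal_logpdf(x, loc, scale):
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def model(h):
    # A tiny stochastic program: z ~ Normal(0, 1), x ~ Normal(z, 1).
    z = h.sample("z", lambda r: r.gauss(0.0, 1.0), lambda v: normal_logpdf(v, 0.0, 1.0))
    x = h.sample("x", lambda r: r.gauss(z, 1.0), lambda v: normal_logpdf(v, z, 1.0))
    return x


def log_joint(h):
    # Sum the log densities of all recorded choices.
    return sum(lp for _, lp in h.trace.values())


# Trace: run the program and record every random choice.
h1 = Handler(random.Random(0))
model(h1)

# Substitute: replay the recorded choices (with a different RNG) and re-score them.
replayed = {name: v for name, (v, _) in h1.trace.items()}
h2 = Handler(random.Random(1), values=replayed)
model(h2)
print(log_joint(h1) == log_joint(h2))  # True: same values, same log density
```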

Currently, we support the numpyro and oryx backends; other backends can be added with the coix.register_backend utility.

This is not an officially supported Google product.

Installation

To install Coix, you can use pip:

pip install coix

or you can clone the repository:

git clone https://github.com/jax-ml/coix.git
cd coix
pip install -e ".[dev,doc]"

Many examples run faster on accelerators. See the JAX installation instructions for how to install JAX with GPU or TPU support.

Download files

Source distribution: coix-0.1.0.tar.gz (26.3 kB)

Built distribution: coix-0.1.0-py3-none-any.whl (34.7 kB, Python 3)
