
Loopy belief propagation for factor graphs on discrete variables in JAX



PGMax

PGMax implements general factor graphs for discrete probabilistic graphical models (PGMs), together with hardware-accelerated, differentiable loopy belief propagation (LBP) in JAX.
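As background, here is the textbook sum-product form of LBP on a factor graph. N(·) denotes the neighbors of a node, φ_f is the potential of factor f, and x_f are the variables that factor f touches. The updates are iterated for a fixed number of steps or until the messages stop changing. On a tree they give exact marginals; on graphs with cycles the beliefs b_v are only approximations. This page does not say which variant PGMax uses (for example log-space messages, damping, or max-product decoding), so treat these equations as a reference point rather than a description of PGMax's internals.

```latex
\begin{aligned}
m_{v \to f}(x_v) &= \prod_{g \in N(v) \setminus \{f\}} m_{g \to v}(x_v), \\
m_{f \to v}(x_v) &= \sum_{\mathbf{x}_f \setminus x_v} \phi_f(\mathbf{x}_f) \prod_{u \in N(f) \setminus \{v\}} m_{u \to f}(x_u), \\
b_v(x_v) &\propto \prod_{f \in N(v)} m_{f \to v}(x_v).
\end{aligned}
```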

  • General factor graphs: PGMax supports easy specification of general factor graphs with potentially complicated topology, factor definitions, and discrete variables with a varying number of states.
  • LBP in JAX: PGMax generates pure JAX functions implementing LBP for a given factor graph. The generated pure JAX functions run on modern accelerators (GPU/TPU), work with JAX transformations (e.g. vmap for processing batches of models/samples, grad for differentiating through the LBP iterative process), and can be easily used as part of a larger end-to-end differentiable system.
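The bullet points above can be made concrete without relying on PGMax's own API, which this page does not document. The block below is a minimal sketch of loopy BP written directly in JAX. It assumes a pairwise factor graph on a hypothetical 4-variable cycle of binary variables, with log-space sum-product updates and damping. `run_lbp`, `EDGES`, and all parameter names are illustrative only and are not PGMax functions. The sketch shows the three properties the list describes:

- the inference loop is a pure, jit-compiled JAX function;
- `jax.vmap` batches it over many sets of potentials;
- `jax.grad` differentiates through the iterations.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Hypothetical example graph (not a PGMax API): a 4-variable cycle 0-1-2-3-0,
# so the factor graph has a loop and BP is approximate rather than exact.
EDGES = jnp.array([[0, 1], [1, 2], [2, 3], [3, 0]])
NUM_VARS, NUM_STATES = 4, 2


@partial(jax.jit, static_argnames=("num_iters",))
def run_lbp(unary, pairwise, num_iters=50, damping=0.5):
    """Log-space sum-product loopy BP on the pairwise graph EDGES.

    unary:    (NUM_VARS, NUM_STATES) unary log-potentials.
    pairwise: (NUM_STATES, NUM_STATES) log-potential shared by every edge.
    Returns approximate marginals of shape (NUM_VARS, NUM_STATES).
    """
    # Row 2*e of the flattened messages targets EDGES[e, 0]; row 2*e + 1 targets EDGES[e, 1].
    endpoints = EDGES.reshape(-1)

    def log_beliefs(ftov):
        # Unary term plus the sum of all factor-to-variable messages into each variable.
        incoming = jax.ops.segment_sum(
            ftov.reshape(-1, NUM_STATES), endpoints, num_segments=NUM_VARS)
        return unary + incoming

    def step(_, ftov):
        # Variable-to-factor message: belief minus the message from that same factor.
        vtof = log_beliefs(ftov)[EDGES] - ftov  # (num_edges, 2, NUM_STATES)
        # Factor-to-variable messages: marginalize out the edge's other endpoint.
        to_first = logsumexp(pairwise[None] + vtof[:, 1, None, :], axis=2)
        to_second = logsumexp(pairwise[None] + vtof[:, 0, :, None], axis=1)
        new = jnp.stack([to_first, to_second], axis=1)
        new = new - logsumexp(new, axis=-1, keepdims=True)  # normalize each message
        return damping * ftov + (1.0 - damping) * new  # damped update

    ftov0 = jnp.zeros((EDGES.shape[0], 2, NUM_STATES))
    ftov = jax.lax.fori_loop(0, num_iters, step, ftov0)
    return jax.nn.softmax(log_beliefs(ftov), axis=-1)


key_u, key_b = jax.random.split(jax.random.PRNGKey(0))
unary = jax.random.normal(key_u, (NUM_VARS, NUM_STATES))
pairwise = jnp.array([[1.0, -1.0], [-1.0, 1.0]])  # favors neighbors that agree

marginals = run_lbp(unary, pairwise)

# vmap: run LBP on a batch of 8 different unary potentials in one call.
batch_unary = jax.random.normal(key_b, (8, NUM_VARS, NUM_STATES))
batch_marginals = jax.vmap(run_lbp, in_axes=(0, None))(batch_unary, pairwise)

# grad: differentiate one marginal through all LBP iterations.
grad_unary = jax.grad(lambda u: run_lbp(u, pairwise)[0, 1])(unary)
print(marginals, batch_marginals.shape, grad_unary.shape)
```

The sketch uses `jax.lax.fori_loop` with a static iteration count (hence `static_argnames`). That keeps the loop reverse-mode differentiable. A convergence-based `jax.lax.while_loop` would not support `jax.grad`.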

See our companion paper for more details.

PGMax is under active development: APIs may change without notice, so expect some rough edges!


Installation

Install from PyPI

pip install pgmax

Install the latest version from GitHub

pip install git+https://github.com/deepmind/PGMax.git

Developer

While you can install PGMax in your standard Python environment, we strongly recommend using a Python virtual environment to manage your dependencies. This helps avoid version conflicts and generally makes the installation process easier.

git clone https://github.com/deepmind/PGMax.git
cd PGMax
python3 -m venv pgmax_env
source pgmax_env/bin/activate
pip install --upgrade pip setuptools
pip install -r requirements.txt
python3 setup.py develop

Install on GPU

By default, the above commands install the CPU version of JAX. If you have access to a GPU, follow the official JAX installation instructions to install JAX with GPU support.
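After installing a GPU-enabled JAX, you can quickly check which backend and devices JAX sees. This matters because PGMax's generated functions are ordinary JAX functions, so they run on whatever device JAX uses by default. The check below uses only standard JAX calls. The device names it prints depend on your hardware.

```python
import jax
import jax.numpy as jnp

# Default backend name, typically "cpu", "gpu", or "tpu".
backend = jax.default_backend()
print("default backend:", backend)
print("devices:", jax.devices())

# Smoke test: a tiny jitted computation on the default device.
x = jax.jit(lambda a: a @ a.T)(jnp.ones((4, 4)))
print("smoke test sum:", float(x.sum()))  # 4x4 matrix of 4s -> 64.0
```

If the backend still reports `cpu` on a machine with a GPU, the CPU-only JAX wheel is most likely still installed.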

Getting Started

Here are a few self-contained Colab notebooks to help you get started with PGMax:

Citing PGMax

Please consider citing our companion paper

@article{zhou2022pgmax,
  author = {Zhou, Guangyao and Dedieu, Antoine and Kumar, Nishanth and L{\'a}zaro-Gredilla, Miguel and Kushagra, Shrinu and George, Dileep},
  title = {{PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX}},
  journal = {arXiv preprint arXiv:2202.04110},
  year = {2022}
}

and using the DeepMind JAX Ecosystem citation if you use PGMax in your work.

Note

This is not an officially supported Google product.



Download files


Source Distribution

pgmax-0.6.1.tar.gz (51.3 kB)

Built Distribution

pgmax-0.6.1-py3-none-any.whl (77.5 kB)
