A collection of utility functions to work with PyTorch sparse tensors
Sparsity-preserving gradient utility tools for PyTorch
A collection of utility functions for working with PyTorch sparse tensors. This is a work in progress: here be dragons.
Currently available features with backprop support include:
- Memory-efficient sparse mm with batch support (workaround for https://github.com/pytorch/pytorch/issues/41128)
- Sparse triangular solver with batch support (see discussion in https://github.com/pytorch/pytorch/issues/87358)
- Generic sparse linear solver (requires a non-differentiable backbone sparse solver)
- Generic sparse linear least-squares solver (requires a non-differentiable backbone sparse linear least-squares solver)
- Wrappers around cupy sparse solvers (see discussion in https://github.com/pytorch/pytorch/issues/69538)
- Wrappers around jax sparse solvers
- Sparse multivariate normal distribution with sparse covariance and precision parameterisation, with reparameterised sampling (rsample)
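To make the "generic sparse linear solver" item more concrete, here is a minimal sketch of how a non-differentiable backbone solver can be wrapped so that gradients flow back only to the stored non-zero values of the matrix. It is an illustration of the general technique, not this package's implementation: the names `SparseSolveSketch` and `dense_backbone` are hypothetical, the backbone densifies the matrix purely for brevity, and the right-hand side is assumed to be 2D with shape `(n, k)`.

```python
import torch


class SparseSolveSketch(torch.autograd.Function):
    """Illustrative only: x = A^{-1} b, with A given as COO (indices, values)."""

    @staticmethod
    def forward(ctx, indices, values, b, backbone):
        n = b.shape[0]
        A = torch.sparse_coo_tensor(indices, values, (n, n))
        x = backbone(A, b)  # backbone is treated as a non-differentiable black box
        ctx.save_for_backward(indices, values, x)
        ctx.backbone = backbone
        return x

    @staticmethod
    def backward(ctx, grad_x):
        indices, values, x = ctx.saved_tensors
        n = x.shape[0]
        # Swapping the row and column indices yields A^T without densifying.
        At = torch.sparse_coo_tensor(indices.flip(0), values, (n, n))
        grad_b = ctx.backbone(At, grad_x)  # dL/db = A^{-T} dL/dx
        rows, cols = indices
        # dL/dA = -grad_b x^T, evaluated only at the stored entries of A,
        # so the gradient has the same sparsity pattern as A.
        grad_values = -(grad_b[rows] * x[cols]).sum(dim=-1)
        return None, grad_values, grad_b, None


def dense_backbone(A, b):
    # Hypothetical stand-in backbone; a real one would keep A sparse.
    return torch.linalg.solve(A.to_dense(), b)


# Upper-bidiagonal 4x4 example: 3 on the diagonal, 1 on the superdiagonal.
indices = torch.tensor([[0, 1, 2, 3, 0, 1, 2], [0, 1, 2, 3, 1, 2, 3]])
values = torch.tensor(
    [3.0, 3.0, 3.0, 3.0, 1.0, 1.0, 1.0], dtype=torch.float64, requires_grad=True
)
b = torch.ones(4, 2, dtype=torch.float64, requires_grad=True)

x = SparseSolveSketch.apply(indices, values, b, dense_backbone)
x.sum().backward()  # values.grad now holds one entry per stored non-zero
```

The backward pass reuses the same backbone on the transposed system, which is why the backbone itself never needs to be differentiable.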
Additional backbone solvers implemented in PyTorch with no additional dependencies include:
- BICGSTAB (ported from pykrylov)
- CG (ported from cornellius-gp/linear_operator)
- LSMR (ported from pytorch-minimize)
- MINRES (ported from cornellius-gp/linear_operator)
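As a rough illustration of what a dependency-free backbone solver of this kind looks like, the following is a textbook conjugate gradient loop written in plain PyTorch. It is a sketch, not the ported code shipped in this package: the name `cg_sketch` and its parameters are hypothetical, and it assumes a symmetric positive-definite operator and a 1D right-hand side.

```python
import torch


def cg_sketch(matvec, b, tol=1e-10, max_iter=1000):
    """Textbook conjugate gradient for a symmetric positive-definite operator."""
    x = torch.zeros_like(b)
    r = b - matvec(x)  # initial residual
    p = r.clone()      # initial search direction
    rs_old = torch.dot(r, r)
    for _ in range(max_iter):
        if rs_old.sqrt() < tol:  # stop once the residual norm is small
            break
        Ap = matvec(p)
        alpha = rs_old / torch.dot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = torch.dot(r, r)
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


# Sparse tridiagonal SPD test matrix: 2 on the diagonal, -1 off the diagonal.
n = 5
i = torch.arange(n)
rows = torch.cat([i, i[:-1], i[1:]])
cols = torch.cat([i, i[1:], i[:-1]])
vals = torch.cat(
    [torch.full((n,), 2.0), torch.full((2 * (n - 1),), -1.0)]
).double()
A = torch.sparse_coo_tensor(torch.stack([rows, cols]), vals, (n, n)).coalesce()
b = torch.ones(n, dtype=torch.float64)

# Only sparse matrix-vector products are needed; A is never densified.
x = cg_sketch(lambda v: torch.sparse.mm(A, v.unsqueeze(1)).squeeze(1), b)
```

Because the solver only touches the matrix through `matvec`, the same loop works for any sparse layout that supports a matrix-vector product.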
Additional features:
- Pairwise voxel encoder for encoding local neighbourhood relationships in a 3D spatial volume with multiple channels, into a sparse COO or CSR matrix.
Things that are missing may be listed as issues.
Installation
The provided package can be installed using:
pip install torchsparsegradutils
or
pip install git+https://github.com/cai4cai/torchsparsegradutils
Unit Tests
A number of unit tests are provided, which can be run with:
python -m pytest
(Note that this also runs the tests from unittest.)
Source Distribution
Hashes for torchsparsegradutils-0.1.2.tar.gz
Algorithm | Hash digest
---|---
SHA256 | bcffddee7a2dd4db0552992b612618694f8d76a02dad2671db4cb0ccea7fbafd
MD5 | 0aa16c29c0cfd3a9d5441c87eac351b8
BLAKE2b-256 | c69ccc1719a1ac675c59ca08a3361b86585f517a8d9a344ffc3eaaa829210bd5
Built Distribution
Hashes for torchsparsegradutils-0.1.2-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 3611408bf1f2ff13eceec01c68820d141dd3be42083f2bb84721cb5c771aba6a
MD5 | 5d6b3cdbcc82c3048611685edd0868c3
BLAKE2b-256 | 87ee753527362adfef288b0ef2273f95ead81e6deb4f505d1c9dae0cddcf3318