A fast and memory-saving implementation of competitive gradient descent in PyTorch
torch-cgd
A fast and memory-efficient implementation of Adaptive Competitive Gradient Descent (ACGD) for PyTorch. The non-adaptive version of the algorithm (CGD) was originally proposed in this paper, while the adaptive version was proposed in a later paper. This repository is essentially a fork of devzhk's cgd-package, but the code has been heavily refactored for readability and customizability. You can install this package with pip:

```
pip install torch-cgd
```
Get started
You can use ACGD for any competitive loss of the form $\min_x \max_y f(x,y)$, in other words, losses where one player tries to minimize the loss and another player tries to maximize it. For example, you can use it to replace a conventional loss function such as the MSE loss with a competitive loss function. This can be beneficial because a competitive loss can push your network toward a more uniform error across the samples, which can lead to considerably lower losses, albeit at a higher computational cost. It is especially useful for knowledge distillation tasks.
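As a minimal illustration in plain Python (not torch-cgd itself), consider the bilinear game $f(x,y) = xy$, whose equilibrium is $(0,0)$. Simultaneous gradient descent-ascent (GDA) spirals away from the equilibrium, whereas the zero-sum update of the original, non-adaptive CGD, $\Delta x = -\eta (I + \eta^2 D_{xy}f\, D_{yx}f)^{-1}(\nabla_x f + \eta D_{xy}f\, \nabla_y f)$ and $\Delta y = \eta (I + \eta^2 D_{yx}f\, D_{xy}f)^{-1}(\nabla_y f - \eta D_{yx}f\, \nabla_x f)$, converges. For $f(x,y)=xy$ both mixed derivatives equal 1, so the update reduces to scalars. The step size and starting point below are arbitrary choices for the sketch.

```python
import math


def gda_step(x, y, lr):
    # Simultaneous gradient descent-ascent on f(x, y) = x * y:
    # x descends along df/dx = y, y ascends along df/dy = x
    return x - lr * y, y + lr * x


def cgd_step(x, y, lr):
    # Zero-sum CGD update specialised to f(x, y) = x * y, where
    # grad_x f = y, grad_y f = x and both mixed derivatives equal 1
    denom = 1.0 + lr ** 2
    dx = -lr * (y + lr * x) / denom
    dy = lr * (x - lr * y) / denom
    return x + dx, y + dy


def run(step, x=1.0, y=1.0, lr=0.1, iters=100):
    # Apply the given update rule repeatedly and return the
    # final distance to the equilibrium (0, 0)
    for _ in range(iters):
        x, y = step(x, y, lr)
    return math.hypot(x, y)


print("GDA distance to equilibrium:", run(gda_step))
print("CGD distance to equilibrium:", run(cgd_step))
```

Each GDA step multiplies the squared distance to the equilibrium by $1+\eta^2$, while each CGD step divides it by $1+\eta^2$, so only CGD approaches $(0,0)$.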
Example
The following code block shows an example of this replacement for a network that learns the function $y=\sin(x)$. Define the loss as $D(x)\,(G(x) - y)$, where the term in parentheses is the error of the generator with respect to the target solution. In other words, the loss measures how well the discriminator is able to estimate the errors of the generator: the discriminator tries to maximize it while the generator tries to minimize it, so a competitive game arises.
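Written out over the $N$ training points, with $\theta_G$ and $\theta_D$ denoting the parameters of the two networks, the game is:

$$\min_{\theta_G}\;\max_{\theta_D}\;\frac{1}{N}\sum_{i=1}^{N} D(x_i)\,\bigl(G(x_i) - y_i\bigr)$$

Because the objective is linear in the discriminator's outputs, any remaining error $G(x_i) \neq y_i$ lets the discriminator increase the objective by scaling $D(x_i)$ with the sign of that error. Only $G(x_i) = y_i$ for all $i$ leaves the discriminator nothing to exploit, which is why the game drives the generator toward the target.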
```python
import torch
import torch.nn as nn
import torch_cgd

# Create the dataset
N = 100
x = torch.linspace(0, 2 * torch.pi, N).reshape(N, 1)
y = torch.sin(x)

# Create the models (D = discriminator, G = generator)
G = nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 1))
D = nn.Sequential(nn.Linear(1, 40), nn.ReLU(), nn.Linear(40, 1))

# Initialize the optimizer with a GMRES solver
solver = torch_cgd.solvers.GMRES(tol=1e-7, atol=1e-20)
optimizer = torch_cgd.ACGD(G.parameters(), D.parameters(), 1e-3, solver=solver)

# Training loop
for i in range(10000):
    optimizer.zero_grad()

    g_out = G(x)
    d_out = D(x)

    # Competitive loss: G tries to minimize it, D tries to maximize it
    loss = (d_out * (g_out - y)).mean()
    optimizer.step(loss)

    mse = torch.mean((g_out - y) ** 2).item()  # Calculate the MSE for monitoring
    print(i, mse)
```
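For comparison, here is a hypothetical baseline (not part of torch-cgd) that trains the same generator and discriminator with plain simultaneous gradient descent-ascent, using two `torch.optim.SGD` optimizers. It makes the roles in the game explicit: the generator descends the loss while the discriminator ascends it. The seed, learning rate, and step count are arbitrary choices; the sketch only shows the plain alternative and does not reproduce any reported result.

```python
import torch
import torch.nn as nn

# Same toy problem as in the ACGD example
torch.manual_seed(0)
N = 100
x = torch.linspace(0, 2 * torch.pi, N).reshape(N, 1)
y = torch.sin(x)
G = nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 1))
D = nn.Sequential(nn.Linear(1, 40), nn.ReLU(), nn.Linear(40, 1))

# Baseline: one plain SGD optimizer per player (hypothetical, not torch-cgd)
opt_g = torch.optim.SGD(G.parameters(), lr=1e-3)
opt_d = torch.optim.SGD(D.parameters(), lr=1e-3)


def train_gda(steps):
    for _ in range(steps):
        opt_g.zero_grad()
        opt_d.zero_grad()
        loss = (D(x) * (G(x) - y)).mean()
        loss.backward()
        # G descends the loss; flipping D's gradients makes its SGD step ascend it
        for p in D.parameters():
            p.grad.neg_()
        # Both players update from gradients taken at the same point (simultaneous GDA)
        opt_g.step()
        opt_d.step()
    # Report the generator's MSE after training
    with torch.no_grad():
        return torch.mean((G(x) - y) ** 2).item()


print(train_gda(1000))
```

Here each player reacts only to the current gradient and ignores how the other player is about to move; CGD accounts for that interaction through its mixed second-derivative terms.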
Choosing the right solver
One of the steps in ACGD involves solving a linear system, i.e. applying the inverse of a matrix, for which many different methods exist. This library offers two solvers, namely the Conjugate Gradient method (CG) and the Generalized Minimal RESidual method (GMRES). You can initialize them, for example, as follows:

```python
solver = torch_cgd.solvers.CG(tol=1e-7, atol=1e-20)
solver = torch_cgd.solvers.GMRES(tol=1e-7, atol=1e-20)
```

You can then pass either solver to the ACGD optimizer:

```python
optimizer = torch_cgd.ACGD(..., solver=solver)
```
In my experience, GMRES gives the best results. A direct solver is currently available for CGD, but not yet for ACGD. Note that a direct solver is considerably slower and more memory-intensive, even for small networks.
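To illustrate the difference between the solvers, here is a minimal sketch that uses SciPy rather than torch-cgd. It assumes the linear system has the form of the plain zero-sum CGD update, $(I + \eta^2 A A^\top)\,v = b$, where the random matrix `A` is a hypothetical stand-in for the mixed second derivative $D_{xy}f$. Such a matrix is symmetric positive definite, which is what CG requires, while GMRES also handles non-symmetric systems. The iterative solvers only need matrix-vector products, whereas the direct solve has to build the full matrix, which is where its extra memory cost comes from.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, cg, gmres

rng = np.random.default_rng(0)
n_x, n_y, lr = 30, 20, 0.1

# Hypothetical stand-in for the mixed second derivative D_xy f (random, for illustration only)
A = rng.standard_normal((n_x, n_y))
b = rng.standard_normal(n_x)


def matvec(v):
    # Applies (I + lr^2 * A @ A.T) to v without forming the n_x-by-n_x matrix,
    # similar to how Hessian-vector products avoid building the full Hessian
    return v + lr ** 2 * (A @ (A.T @ v))


op = LinearOperator((n_x, n_x), matvec=matvec, dtype=np.float64)

# Iterative solvers: only need matrix-vector products
v_cg, info_cg = cg(op, b)           # requires a symmetric positive definite operator
v_gmres, info_gmres = gmres(op, b)  # also works for non-symmetric operators

# Direct solver: builds the full matrix and solves with it
M = np.eye(n_x) + lr ** 2 * (A @ A.T)
v_direct = np.linalg.solve(M, b)

# Relative residual of each solution
for name, v in [("CG", v_cg), ("GMRES", v_gmres), ("direct", v_direct)]:
    print(name, np.linalg.norm(matvec(v) - b) / np.linalg.norm(b))
```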
More examples
See the examples folder.
Cite
If you use this code for your research, please cite it as follows:
```bibtex
@misc{torch-cgd,
  author = {Thomas Wagenaar},
  title = {torch-cgd: A fast and memory-efficient implementation of adaptive competitive gradient descent in PyTorch},
  year = {2023},
  url = {https://github.com/wagenaartje/torch-cgd}
}
```