Skip to main content

A JAX wrapper around ceviche to make interoperability easier. In the future it might make sense to update ceviche itself to use JAX internally.

Project description

javiche

Small package to enable using ceviche with a JAX optimizer easily.

Install

This package is not yet published. As soon as it is install with:

pip install javiche

or

conda install javiche

How to use

Import the decorator

from javiche import jaxit

decorate your function (will be differentiated using ceviches jacobian -> HIPS autograd)

@jaxit()
def square(A):
  """squares number/array"""
  return A**2

Now you can use jax as usual:

grad_fn = jax.grad(square)
grad_fn(2.0)
Array(4., dtype=float32, weak_type=True)

In this toy example that was already possible without the jaxit() decorator. However jaxit() decorated functions can contain autograd operators (but no jax operators):

import autograd.numpy as npa
def sin(A):
  """computes sin of number/array using autograds numpy"""
  return npa.sin(A)
grad_sin = jax.grad(sin)
try:
  print(grad_sin(0.0))
except Exception as e:
  print(e)
The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray(0.0, dtype=float32, weak_type=True)>with<JVPTrace(level=2/0)> with
  primal = 0.0
  tangent = Traced<ShapedArray(float32[], weak_type=True)>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[], weak_type=True), None)
    recipe = LambdaBinding()
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
@jaxit()
def cos(A):
  """computes sin of number/array using autograds numpy"""
  return npa.cos(A)

grad_cos = jax.grad(cos)
try:
  print(grad_cos(0.0))
except Exception as e:
  print(e)
-0.0

Usecase

This library is intended for use with ceviche, while running a JAX optimization stack as demonstated in the inverse design example

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

javiche-0.0.6.tar.gz (10.1 kB view hashes)

Uploaded Source

Built Distribution

javiche-0.0.6-py3-none-any.whl (9.9 kB view hashes)

Uploaded Python 3

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page