JAX dispatcher

Static arguments

Where they are needed

  • shape parameters
  • Scan's length parameter

Hashable static arguments to JIT-compiled functions

i.e. no lists or NumPy arrays.

TypeError: Shape

Related issues:

There are two underlying issues:

  • JAX needs shapes to be determined at tracing time.
  • Random variables need their size specified as a tuple (see the docstring of the Dirichlet distribution)

    Let's reproduce the example exactly:

    import jax
    
    # Shape held as a jax.numpy array (not a tuple); read by jax_funcified below.
    shape = jax.numpy.array([1000])

    def jax_funcified(prng_key):
        """Sample from jax.random.normal using ``prng_key`` and the module-level ``shape``."""
        samples = jax.random.normal(key=prng_key, shape=shape)
        return samples
    
    # Fixed seed so the example is reproducible.
    key = jax.random.PRNGKey(0)
    # Run the function under jit; the call is guarded so that any error raised
    # is printed instead of stopping the script.
    # NOTE(review): presumably this is expected to fail because `shape` is a
    # jax.numpy array — confirm against the printed message.
    try:
        jax.jit(jax_funcified)(key)
    except Exception as e:
        print(e)
    
    import jax
    import numpy as np
    
    # Shape held as a NumPy array (not a tuple); read by jax_funcified below.
    shape = np.array([10])

    def jax_funcified(prng_key):
        """Sample from jax.random.normal using ``prng_key`` and the module-level ``shape``."""
        samples = jax.random.normal(key=prng_key, shape=shape)
        return samples
    
    # Fixed seed so the example is reproducible.
    key = jax.random.PRNGKey(0)
    # Run the function under jit and print the result directly (no error
    # handling around this call).
    print(jax.jit(jax_funcified)(key))
    
import jax
import numpy as np

# Fixed seed, reused for every call below.
rng_key = jax.random.PRNGKey(0)
# Call the same sampler (outside jit) with several ways of spelling `shape`.
# The bare-int variant is guarded so that, if it is rejected, the error
# message is printed instead of stopping the script.
try:
    print(jax.random.normal(rng_key, shape=10))
except Exception as e:
    print(e)
# The remaining variants are called without error handling.
print(jax.random.normal(rng_key, shape=[3]))  # list
print(jax.random.normal(rng_key, shape=(3,)))  # tuple
print(jax.random.normal(rng_key, shape=np.array([3])))  # NumPy array
print(jax.random.normal(rng_key, shape=jax.numpy.array([3])))  # jax.numpy array

import jax

def fun(x):
    """Return jax.random.normal samples of shape ``x``, drawn with the fixed key PRNGKey(0)."""
    return jax.random.normal(jax.random.PRNGKey(0), shape=x)

# Call the jitted `fun` with the shape passed as an ordinary (non-static)
# argument, once per kind of Python value. Each call is guarded so that an
# error, if raised, is printed with a label instead of stopping the script.
# NOTE(review): presumably all three fail because a non-static jit argument
# is traced and so is not a concrete shape — confirm against the output.
try:
    jax.jit(fun)(1)
except Exception as e:
    print(f"shape as int: {e}")

try:
    jax.jit(fun)([1, 2])
except Exception as e:
    print(f"shape as list: {e}")

try:
    jax.jit(fun)((1, 2))
except Exception as e:
    print(f"shape as tuple: {e}")

# using static_argnums
# Positional argument 0 (the shape) is marked static and a tuple is passed;
# this call is not guarded and its result is printed.
res = jax.jit(fun, static_argnums=(0,))((1, 2))
print(f"shape as tuple (static argnum): {res}")

# Same static_argnums setting, but with a list as the static argument.
# NOTE(review): presumably fails because a list is not hashable — confirm.
try:
    res = jax.jit(fun, static_argnums=(0,))([1, 2])
except Exception as e:
    print(f"shape as list (static_argnums): {e}")

Links to this note