Static arguments
Where they are needed
- `shape` parameters
- Scan's `length` parameter
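As a minimal sketch of the second case: when `jax.lax.scan` is called with `xs=None`, the iteration count has to come from `length`, which must be a concrete Python integer while the function is traced. The function name `count_up` and the argument name `n_steps` are made up for this illustration; marking the argument static via `static_argnames` is one way to satisfy the requirement, not necessarily the only one.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `n_steps` is declared static, so it reaches `scan` as a plain Python int.
@partial(jax.jit, static_argnames=("n_steps",))
def count_up(n_steps):
    def step(carry, _):
        new_carry = carry + 1
        # Return the updated carry and the per-step output.
        return new_carry, new_carry

    init = jnp.zeros((), dtype=jnp.int32)
    # With xs=None there is no input array to infer the loop length from,
    # so `length` has to be known at trace time.
    final, history = jax.lax.scan(step, init, xs=None, length=n_steps)
    return final, history


final, history = count_up(n_steps=5)
print(final, history)
```

Each distinct value of `n_steps` leads to a separate trace and compilation, which is the usual cost of a static argument.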
Hashable static arguments to JIT-compiled functions
Static arguments must be hashable, i.e. not a list or a NumPy array.
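One way to meet the hashability requirement is to normalise a shape to a tuple of Python ints before it reaches the jitted function. The sketch below assumes a hypothetical helper, `as_static_shape`, and a hypothetical function `sample`; neither is part of JAX.

```python
import jax
import numpy as np


def sample(shape):
    key = jax.random.PRNGKey(0)
    return jax.random.normal(key, shape=shape)


# Argument 0 is static, so it must be hashable.
sample_jit = jax.jit(sample, static_argnums=(0,))


def as_static_shape(shape):
    # Hypothetical helper: accepts an int, list, tuple or NumPy array
    # and returns a hashable tuple of Python ints.
    return tuple(int(d) for d in np.atleast_1d(shape))


out_from_list = sample_jit(as_static_shape([2, 3]))
out_from_array = sample_jit(as_static_shape(np.array([2, 3])))
print(out_from_list.shape, out_from_array.shape)
```

Because both calls pass the same tuple, the second call can reuse the compilation from the first.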
TypeError: Shape
Related issues:
There are two underlying issues:
- JAX needs shapes to be determined at tracing time.
- Random variables need their size specified as a tuple (see the docstring of the Dirichlet distribution).
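To illustrate the first point: inside a jitted function the values of an argument are traced, but its `shape` attribute is an ordinary tuple of Python ints, so it can be passed straight to a shape parameter. This is a small sketch; the function name `noise_like` is invented for the example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def noise_like(key, x):
    # `x` is a tracer here, but `x.shape` is a concrete tuple of ints
    # that is already known while the function is being traced.
    return jax.random.normal(key, shape=x.shape)


key = jax.random.PRNGKey(0)
out = noise_like(key, jnp.zeros((2, 3)))
print(out.shape)
```

Calling `noise_like` with an input of a different shape triggers a new trace, since the shape is part of what JAX specialises on.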
Let’s reproduce the example exactly:
#+begin_src python
import jax

shape = jax.numpy.array([1000])

def jax_funcified(prng_key):
    return jax.random.normal(prng_key, shape)

key = jax.random.PRNGKey(0)
try:
    jax.jit(jax_funcified)(key)
except Exception as e:
    print(e)
#+end_src

#+RESULTS:
: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>,). If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

#+begin_src python
import jax
import numpy as np

shape = np.array([10])

def jax_funcified(prng_key):
    return jax.random.normal(prng_key, shape)

key = jax.random.PRNGKey(0)
print(jax.jit(jax_funcified)(key))
#+end_src

#+RESULTS:
: [-0.3721109 0.26423115 -0.18252768 -0.7368197 -0.44030377 -0.1521442 -0.67135346 -0.5908641 0.73168886 0.5673026 ]

#+begin_src python
import jax
import numpy as np

rng_key = jax.random.PRNGKey(0)
try:
    print(jax.random.normal(rng_key, shape=10))
except Exception as e:
    print(e)

print(jax.random.normal(rng_key, shape=[3]))
print(jax.random.normal(rng_key, shape=(3,)))
print(jax.random.normal(rng_key, shape=np.array([3])))
print(jax.random.normal(rng_key, shape=jax.numpy.array([3])))
#+end_src

#+RESULTS:
: jax.core.NamedShape() argument after * must be an iterable, not int
: [ 1.8160863 -0.48262316 0.33988908]
: [ 1.8160863 -0.48262316 0.33988908]
: [ 1.8160863 -0.48262316 0.33988908]
: [ 1.8160863 -0.48262316 0.33988908]

#+begin_src python
import jax

def fun(x):
    rng_key = jax.random.PRNGKey(0)
    return jax.random.normal(rng_key, shape=x)

try:
    jax.jit(fun)(1)
except Exception as e:
    print(f"shape as int: {e}")

try:
    jax.jit(fun)([1, 2])
except Exception as e:
    print(f"shape as list: {e}")

try:
    jax.jit(fun)((1, 2))
except Exception as e:
    print(f"shape as tuple: {e}")

# using static_argnums
res = jax.jit(fun, static_argnums=(0,))((1, 2))
print(f"shape as tuple (static argnum): {res}")

try:
    res = jax.jit(fun, static_argnums=(0,))([1, 2])
except Exception as e:
    print(f"shape as list (static_argnums): {e}")
#+end_src

#+RESULTS:
: shape as int: iteration over a 0-d array
: shape as list: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>). If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
: shape as tuple: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>). If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
: shape as tuple (static argnum): [[-0.78476596 0.85644484]]
: shape as list (static_argnums): Non-hashable static arguments are not supported. An error occurred during a call to 'fun' while trying to hash an object of type <class 'list'>, [1, 2]. The error was: TypeError: unhashable type: 'list'
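The first failing example closes over a shape stored as a JAX array. One possible workaround, sketched here as an assumption rather than a fix taken from the related issues, is to convert that array to a tuple of Python ints outside the jitted function, so that only concrete integers are seen during tracing. The name `shape_array` is introduced for this example.

```python
import jax
import jax.numpy as jnp

shape_array = jnp.array([1000])

# Convert to concrete Python ints before tracing starts.
shape = tuple(int(d) for d in shape_array.tolist())


@jax.jit
def jax_funcified(prng_key):
    # `shape` is a closed-over tuple of ints, not a JAX array.
    return jax.random.normal(prng_key, shape)


key = jax.random.PRNGKey(0)
samples = jax_funcified(key)
print(samples.shape)
```

The shape is then fixed when the function is defined; if it has to vary between calls, passing it as a static tuple argument, as in the last example above, is the alternative.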