I doubt it, since jaxtyping supports some fairly advanced features, such as shape annotations that refer to a function's arguments or to attributes of `self`:
def full(size: int, fill: float) -> Float[Array, "{size}"]:
    """Return a 1-D array of length ``size`` whose elements all equal ``fill``.

    The return annotation ``"{size}"`` is a jaxtyping dimension expression.
    jaxtyping fills it in from the ``size`` argument, so the declared output
    shape follows the input value.
    """
    shape = (size,)
    return jax.numpy.full(shape=shape, fill_value=fill)
class SomeClass:
    """Show a jaxtyping shape annotation that depends on an instance attribute."""

    # Base length; `full` returns arrays of length some_value + 3.
    some_value = 5

    def full(self, fill: float) -> Float[Array, "{self.some_value}+3"]:
        """Return a 1-D array of length ``self.some_value + 3`` filled with ``fill``.

        In the annotation, jaxtyping fills in ``{self.some_value}`` from the
        bound instance and then evaluates the ``+3`` expression. The declared
        length therefore matches the array built below.
        """
        length = self.some_value + 3
        return jax.numpy.full(shape=(length,), fill_value=fill)