thomasahle 4 days ago

You can do that in Python using https://github.com/patrick-kidger/torchtyping. It looks like this:

    def batch_outer_product(x: TensorType["batch", "x_channels"],
                            y: TensorType["batch", "y_channels"]
                            ) -> TensorType["batch", "x_channels", "y_channels"]:
        return x.unsqueeze(-1) * y.unsqueeze(-2)
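A minimal sketch of how those annotations get enforced at runtime, following the typeguard integration described in the torchtyping README (the failing call at the end is illustrative):

    import torch
    from torchtyping import TensorType, patch_typeguard
    from typeguard import typechecked

    patch_typeguard()  # hook torchtyping's TensorType into typeguard's checks

    @typechecked
    def batch_outer_product(x: TensorType["batch", "x_channels"],
                            y: TensorType["batch", "y_channels"]
                            ) -> TensorType["batch", "x_channels", "y_channels"]:
        return x.unsqueeze(-1) * y.unsqueeze(-2)

    batch_outer_product(torch.rand(2, 3), torch.rand(2, 4))  # ok: "batch" dims agree
    batch_outer_product(torch.rand(2, 3), torch.rand(5, 4))  # raises: "batch" is 2 vs 5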
There's also https://github.com/thomasahle/tensorgrad, which uses sympy symbols as "axis" dimension variables:

    b, x, y = sp.symbols("b x y")
    X = tg.Variable("X", b, x)
    Y = tg.Variable("Y", b, y)
    W = tg.Variable("W", x, y)
    XWmY = X @ W - Y
patrickkidger 4 days ago

Quick heads-up that these days I recommend https://github.com/patrick-kidger/jaxtyping over the older repository you've linked there. I learnt a lot the first time around, so the newer one is much better :)
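For comparison, a sketch of the same batch outer product under jaxtyping, assuming its documented jaxtyped-plus-beartype setup for runtime checking (the broadcasting body mirrors the torchtyping version above):

    import jax.numpy as jnp
    from jaxtyping import Array, Float, jaxtyped
    from beartype import beartype

    @jaxtyped(typechecker=beartype)
    def batch_outer_product(
        x: Float[Array, "batch x_channels"],
        y: Float[Array, "batch y_channels"],
    ) -> Float[Array, "batch x_channels y_channels"]:
        # Dimension names are space-separated inside one string; matching
        # names must agree across arguments and the return value.
        return x[:, :, None] * y[:, None, :]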
ydj 4 days ago

Is there a mypy plugin or other tool to check this via static analysis before runtime? To my knowledge, jaxtyping can only be checked at runtime.
thomasahle 4 days ago

I doubt it, since jaxtyping supports some quite advanced stuff:

    import jax
    from jaxtyping import Array, Float

    def full(size: int, fill: float) -> Float[Array, "{size}"]:
        # The declared shape depends on the runtime value of `size`.
        return jax.numpy.full((size,), fill)

    class SomeClass:
        some_value = 5

        def full(self, fill: float) -> Float[Array, "{self.some_value}+3"]:
            # Here the shape even depends on instance state plus arithmetic.
            return jax.numpy.full((self.some_value + 3,), fill)
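To make the limitation concrete, here is a hedged sketch of what runtime (rather than static) enforcement of such a value-dependent annotation looks like, again assuming the jaxtyped-plus-beartype setup; the off-by-one body is deliberate:

    import jax.numpy as jnp
    from jaxtyping import Array, Float, jaxtyped
    from beartype import beartype

    @jaxtyped(typechecker=beartype)
    def full(size: int, fill: float) -> Float[Array, "{size}"]:
        # Deliberate bug: returns size + 1 elements. mypy cannot flag this,
        # because "{size}" is only evaluated against the actual argument
        # when the function is called.
        return jnp.full((size + 1,), fill)

    full(3, 0.0)  # fails the return-shape check at call time, not at lint time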