thomasahle 4 days ago

You can do that in Python using https://github.com/patrick-kidger/torchtyping

It looks like this:

    from torchtyping import TensorType

    def batch_outer_product(x:   TensorType["batch", "x_channels"],
                            y:   TensorType["batch", "y_channels"]
                            ) -> TensorType["batch", "x_channels", "y_channels"]:
        return x.unsqueeze(-1) * y.unsqueeze(-2)

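For readers without torch at hand, here is a minimal sketch of what that broadcasting expression computes, in plain Python over nested lists (a hypothetical helper, not part of torchtyping):

```python
def batch_outer_product_lists(x, y):
    """Batch outer product over nested lists.

    x has shape (batch, x_channels), y has shape (batch, y_channels);
    the result has shape (batch, x_channels, y_channels), with
    out[b][i][j] == x[b][i] * y[b][j], which is exactly what
    x.unsqueeze(-1) * y.unsqueeze(-2) computes via broadcasting.
    """
    return [[[xi * yj for yj in yb] for xi in xb] for xb, yb in zip(x, y)]

out = batch_outer_product_lists([[1, 2]], [[3, 4, 5]])
# one batch element, 2 x_channels, 3 y_channels
```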
There's also https://github.com/thomasahle/tensorgrad which uses sympy for "axis" dimension variables:

    import sympy as sp
    import tensorgrad as tg

    b, x, y = sp.symbols("b x y")
    X = tg.Variable("X", b, x)
    Y = tg.Variable("Y", b, y)
    W = tg.Variable("W", x, y)
    XWmY = X @ W - Y
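The underlying idea can be sketched in plain Python: each tensor carries named symbolic axes, and every operation checks that the axes it contracts or aligns actually agree. This is a hypothetical toy (the `Var` class and its checks are not tensorgrad's API):

```python
class Var:
    """Toy symbolic tensor: a name plus a tuple of named axis symbols."""

    def __init__(self, name, *dims):
        self.name, self.dims = name, dims

    def __matmul__(self, other):
        # Contract the last axis of self with the first axis of other.
        if self.dims[-1] != other.dims[0]:
            raise ValueError(f"axis mismatch: {self.dims[-1]} vs {other.dims[0]}")
        return Var(f"({self.name}@{other.name})",
                   *self.dims[:-1], *other.dims[1:])

    def __sub__(self, other):
        # Elementwise ops require identical axis tuples.
        if self.dims != other.dims:
            raise ValueError(f"shape mismatch: {self.dims} vs {other.dims}")
        return Var(f"({self.name}-{other.name})", *self.dims)

X = Var("X", "b", "x")
Y = Var("Y", "b", "y")
W = Var("W", "x", "y")
XWmY = X @ W - Y  # X@W has axes ("b", "y"), so the subtraction checks out
```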
patrickkidger 4 days ago | parent | next

Quick heads-up that these days I recommend https://github.com/patrick-kidger/jaxtyping over the older repository you've linked there.

I learnt a lot the first time around, so the newer one is much better :)

thomasahle 4 days ago | parent

Ah, I would never have guessed that jaxtyping supports torch :)

ydj 4 days ago | parent | prev

Is there a mypy plugin or other tool to check this via static analysis, before runtime? To my knowledge, jaxtyping annotations can only be checked at runtime.

thomasahle 4 days ago | parent

I doubt it, since jaxtyping supports some quite advanced features, such as shapes that depend on runtime values:

    import jax
    from jaxtyping import Array, Float

    def full(size: int, fill: float) -> Float[Array, "{size}"]:
        return jax.numpy.full((size,), fill)

    class SomeClass:
        some_value = 5

        def full(self, fill: float) -> Float[Array, "{self.some_value}+3"]:
            return jax.numpy.full((self.some_value + 3,), fill)
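A toy decorator illustrates why this resists static analysis: an expression like "{size}" has to be evaluated against the arguments of the actual call, which a static checker never sees. This is a hypothetical sketch, not jaxtyping's implementation:

```python
import inspect

def check_length(dim_expr):
    """Toy runtime shape check: evaluate dim_expr against the call's
    arguments and compare it to the length of the returned sequence.
    (Hypothetical decorator, not jaxtyping's API.)"""
    def decorator(fn):
        def wrapper(*args, **kwargs):
            bound = inspect.signature(fn).bind(*args, **kwargs)
            bound.apply_defaults()
            # e.g. dim_expr == "size" evaluates to the value passed for size
            expected = eval(dim_expr, {}, dict(bound.arguments))
            result = fn(*args, **kwargs)
            if len(result) != expected:
                raise TypeError(f"expected length {expected}, got {len(result)}")
            return result
        return wrapper
    return decorator

@check_length("size")
def full(size, fill):
    return [fill] * size

full(3, 0.5)  # the result's length matches the evaluated annotation, so it passes
```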