Comment by thomasahle

9 days ago

You can do that in Python using https://github.com/patrick-kidger/torchtyping

It looks like this:

    from torchtyping import TensorType

    def batch_outer_product(x:   TensorType["batch", "x_channels"],
                            y:   TensorType["batch", "y_channels"]
                            ) -> TensorType["batch", "x_channels", "y_channels"]:
        return x.unsqueeze(-1) * y.unsqueeze(-2)
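For anyone wondering what the named dimensions in that signature buy you: the idea is that a checker binds each name to a size at call time and complains when the same name shows up with two different sizes. Here is a rough stdlib-only sketch of that idea. It is not torchtyping's actual implementation, and `FakeTensor` and `check_shapes` are names I made up for illustration:

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: it only carries a shape."""
    shape: tuple


def check_shapes(specs, tensors):
    """Bind each dimension name to a size and raise on a conflict.

    specs   -- one tuple of dimension names per tensor
    tensors -- the objects to check, each with a .shape tuple
    Returns the name -> size bindings.
    """
    bindings = {}
    for names, tensor in zip(specs, tensors):
        if len(names) != len(tensor.shape):
            raise TypeError(f"expected {len(names)} dims, got shape {tensor.shape}")
        for name, size in zip(names, tensor.shape):
            # setdefault records the first size seen for this name
            if bindings.setdefault(name, size) != size:
                raise TypeError(f"{name!r} is {bindings[name]} but also {size}")
    return bindings


specs = [("batch", "x_channels"), ("batch", "y_channels")]
ok = check_shapes(specs, [FakeTensor((8, 3)), FakeTensor((8, 5))])
```

Passing shapes `(8, 3)` and `(4, 5)` instead would raise, because "batch" would be bound to both 8 and 4.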

There's also https://github.com/thomasahle/tensorgrad which uses sympy for "axis" dimension variables:

    import sympy as sp
    import tensorgrad as tg

    b, x, y = sp.symbols("b x y")
    X = tg.Variable("X", b, x)
    Y = tg.Variable("Y", b, y)
    W = tg.Variable("W", x, y)
    XWmY = X @ W - Y

Is there a mypy plugin or other tool that checks this statically, before runtime? To my knowledge, jaxtyping annotations can only be checked at runtime.

  • I doubt it, since jaxtyping supports some quite advanced stuff:

        import jax
        from jaxtyping import Array, Float

        def full(size: int, fill: float) -> Float[Array, "{size}"]:
            return jax.numpy.full((size,), fill)

        class SomeClass:
            some_value = 5

            def full(self, fill: float) -> Float[Array, "{self.some_value}+3"]:
                return jax.numpy.full((self.some_value + 3,), fill)
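To make the point concrete: a dimension like "{size}" or "{self.some_value}+3" only has a value once the function is actually called, which is what makes a purely static check hard. A minimal sketch of resolving such an expression per call, using only the standard library. This is my own illustration, not how jaxtyping does it internally; `resolve_dim` is a made-up helper, and plain lists stand in for arrays:

```python
import inspect


def resolve_dim(func, dim_expr, *args, **kwargs):
    """Evaluate a dimension expression such as "{size}+3" for one call.

    Hypothetical helper: binds the call's arguments to func's parameter
    names, substitutes them into the expression, and evaluates the result.
    """
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()
    # str.format handles "{size}" as well as attribute access
    # like "{self.some_value}"
    filled = dim_expr.format(**bound.arguments)
    # eval without builtins: acceptable for a sketch on trusted input only
    return eval(filled, {"__builtins__": {}})


def full(size: int, fill: float):
    return [fill] * size


class SomeClass:
    some_value = 5

    def full(self, fill: float):
        return [fill] * (self.some_value + 3)


dim_a = resolve_dim(full, "{size}", 4, 0.0)
dim_b = resolve_dim(SomeClass.full, "{self.some_value}+3", SomeClass(), 1.0)
```

A static checker would have to know `size` and `self.some_value` without running anything, which in general it can't.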