Comment by sieve
4 months ago
Nice! His Shakespeare generator was one of the first projects I tried after ollama. The goal was to understand what LLMs were about.
I have been on an LLM binge this last week or so, trying to build a from-scratch training and inference system with two back ends:
- CPU (backed by JAX)
- GPU (backed by wgpu-py). This is critical for me, as I am unwilling to deal with the nonsense that is rocm/pytorch. Vulkan works for me; that is what I use with llama-cpp.
I got both back ends working last week, but the GPU back end was buggy. So this week has been about fixing bugs, refactoring the WGSL code, and making things more efficient.
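For context, the wgpu-py path boils down to dispatching WGSL compute kernels from Python. Here is a minimal sketch of that pattern, not the author's actual code: the kernel, the buffer bindings, and the use of wgpu-py's compute_with_buffers helper (and its output-spec format) are my assumptions.

```python
# Illustrative wgpu-py compute sketch: a WGSL kernel that doubles a float buffer.
# This is a guess at the general approach, not the author's code; it assumes
# wgpu-py's compute_with_buffers convenience helper.
import numpy as np
from wgpu.utils.compute import compute_with_buffers

WGSL = """
@group(0) @binding(0) var<storage, read> x: array<f32>;
@group(0) @binding(1) var<storage, read_write> y: array<f32>;

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if (i < arrayLength(&x)) {
        y[i] = 2.0 * x[i];
    }
}
"""

n = 1024
x = np.arange(n, dtype=np.float32)

# Bind x at binding 0, request an n-element f32 output at binding 1, and
# dispatch one workgroup per element (a real kernel would use a larger
# workgroup size for efficiency).
out = compute_with_buffers({0: x}, {1: (n, "f")}, WGSL, n=n)
y = np.frombuffer(out[1], dtype=np.float32)
print(y[:4])  # [0. 2. 4. 6.]
```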
I am using LLMs extensively in this process, and they have been a revelation. Give them a good refactoring prompt and they can fix things one by one, resulting in something fully functional and type-checked by Astral's ty.
Unwilling to deal with pytorch? You couldn't possibly hobble yourself any more if you tried.
If you want to train/sample large models, then use what the rest of the industry uses.
My use case is different. I want something that I can run quickly on one GPU without worrying about whether it is supported or not.
I am interested in convenience, not in squeezing out the last bit of performance from a card.
You wildly misunderstand pytorch.
If you're not writing or modifying the model itself but only training, fine-tuning, and running inference, ONNX now supports all of these with basically any backend execution provider, without getting into dependency version hell.
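For example, picking a backend is largely a matter of listing execution providers when creating the session. A minimal inference sketch with ONNX Runtime; the model file name and input shape below are illustrative, not from this thread:

```python
# Minimal ONNX Runtime inference sketch. "model.onnx" and the input shape
# are illustrative assumptions.
import numpy as np
import onnxruntime as ort

# Ask for a GPU execution provider first; ONNX Runtime falls back to CPU
# if it is not available.
session = ort.InferenceSession(
    "model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

input_name = session.get_inputs()[0].name
x = np.random.rand(1, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {input_name: x})
print(outputs[0].shape)
```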
What are your thoughts on using JAX? I've used TensorFlow and Pytorch and I feel like I'm missing out by not having experience with JAX. But at the same time, I'm not sure what the advantages are.
I only used it to build the CPU back end. It was a fair bit faster than the previous numpy back end. One good thing about JAX (unlike numpy) is that it also gives you access to a GPU back end if you have the appropriate packages and drivers installed.
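As a small illustration of what that buys you (the function and sizes here are just examples, not from the project): the same jit-compiled code runs on CPU by default and on GPU if a GPU-enabled jax is installed.

```python
# Minimal JAX sketch: one jit-compiled function, same code on CPU or GPU
# depending on which jax backend is installed. The function and sizes are
# illustrative only.
import jax
import jax.numpy as jnp

@jax.jit
def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + jnp.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

print(jax.devices())  # e.g. [CpuDevice(id=0)], or a GPU device if available

x = jnp.linspace(-3.0, 3.0, 8)
print(gelu(x))
```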