Comment by sillysaurusx

4 years ago

I used Ray to train a massive GPT model by putting each layer on a separate TPU. Ray was able to send all the gradients back and forth as needed.

It scaled fine up to 33 TPUs (i.e. 33 layers).

Ray is impressive as hell.

By the way, I didn't write the code to do any of that. kindiana, aka "the guy that wrote GPT-J", also happened to write this:

https://news.ycombinator.com/item?id=27728225

That's the last I'll be mentioning it for some time, though.

Ray + JAX is such a killer combo.
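
A minimal sketch of the pattern (this is not kindiana's actual code; the actor, layer, and sizes are illustrative): one Ray actor per TPU host, each holding one layer, with activations hopping forward and gradients hopping back.

    import ray
    import jax
    import jax.numpy as jnp

    ray.init()

    @ray.remote  # in the real setup, each actor would be pinned to its own TPU host
    class LayerWorker:
        # Holds one layer's params; runs that layer's forward/backward locally.
        def __init__(self, dim, seed):
            self.w = jax.random.normal(jax.random.PRNGKey(seed), (dim, dim)) * 0.01
            self._vjp = None  # saved between forward and backward

        def forward(self, x):
            out, self._vjp = jax.vjp(lambda w, x: jnp.tanh(x @ w), self.w, x)
            return out

        def backward(self, grad_out, lr=1e-3):
            grad_w, grad_x = self._vjp(grad_out)
            self.w = self.w - lr * grad_w  # local SGD step on this layer only
            return grad_x                  # gradient handed back to the previous layer

    # Toy 4-layer pipeline; the real run used 33 workers, one layer per TPU.
    workers = [LayerWorker.remote(128, seed=i) for i in range(4)]

    acts = jnp.ones((8, 128))
    for w in workers:                    # forward: activations hop worker to worker
        acts = w.forward.remote(acts)    # Ray resolves the ObjectRef on the next actor

    grad = jnp.ones((8, 128))            # stand-in for d(loss)/d(output)
    for w in reversed(workers):          # backward: gradients flow the other way
        grad = ray.get(w.backward.remote(grad))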

As someone new to ML/DL, where can I get started? Would you recommend the fast.ai course, which uses PyTorch-based libraries, or something that focuses on TensorFlow?

  • The key is to find something you think is fun, and play with it. It doesn’t matter what it’s written in. I don’t think it would be fun to take a course, so I never did. But it was a blast to get all this working: https://youtube.com/channel/UCqpwMaJbb-zj-MlMkRd28QQ

    You can see my early videos were crude meme attempts, and eventually that morphed into 100-TPU training. You can’t really predict the things you’ll like, so don’t try.

    What’s fun to you? That’s the question to focus on. For me, it was language modeling and image generation, which I learned from https://gwern.net/GPT-2 and Peter Baylor’s StyleGAN notebook, respectively. Then I transitioned to audio memes (https://youtu.be/koU3L7WBz_s) and kept going from there.

    TensorFlow is bullshit, PyTorch is different bullshit, JAX is pretty friggin awesome but still has one or two flecks of bs. But none of it feels hard to deal with, because it’s all an enormous box of Legos; as long as you chase your interests, you’ll never* feel like it’s annoying.

    * you’ll be annoyed all the time, but at least you’ll keep going.

Can this be done with Dask?

  • I'm not sure. I thought it was nothing short of a miracle that it could be done at all. I tried, hard, in TensorFlow, to make it work. But there was no way to communicate directly from TPU to TPU; I had to serialize to GCS buckets as a middle step, which added massive complexity to the code.

    The Ray solution here is so simple. It was a joy to use. But I don't know anything about Dask.
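
    For a sense of what that old workaround looked like (illustrative only; the bucket name and helpers are made up, not the actual code), every layer-to-layer handoff went through the bucket roughly like this:

        import time
        import numpy as np
        import tensorflow as tf  # only used here for tf.io.gfile's gs:// support

        BUCKET = "gs://my-training-scratch"  # made-up bucket name

        def send_tensor(step, name, array):
            # Producer host: dump the tensor into the bucket.
            with tf.io.gfile.GFile(f"{BUCKET}/{step}/{name}.npy", "wb") as f:
                np.save(f, np.asarray(array))

        def recv_tensor(step, name, timeout_s=600):
            # Consumer host: poll the bucket until the tensor shows up.
            path = f"{BUCKET}/{step}/{name}.npy"
            deadline = time.time() + timeout_s
            while not tf.io.gfile.exists(path):
                if time.time() > deadline:
                    raise TimeoutError(path)
                time.sleep(1.0)
            with tf.io.gfile.GFile(path, "rb") as f:
                return np.load(f)

    With Ray, that whole dance collapses into passing an object between actors.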

    By the way, if anyone wants to see how the loss graphs turned out: https://twitter.com/theshawwn/status/1406171487988498433

    (I wish I'd uploaded the logs to tensorboard.dev for posterity. But you can see quite clearly in the screenshots all the information you'd want to see anyway, with apologies to blind engineers. Oh lord, is there a single blind engineer working in ML? Suddenly it's an appealing idea to try to make tensorboard accessible... I wonder how blind people could interpret graphs. Averages, probably.)

  • I don't know about TPUs, but in GPU land, yeah, you can do fast GPU<>GPU transfers without much code, including for Dask kernels. More typically, though, the code is automatically optimized well enough without manual tuning here, and at least for us, we end up spending our optimization time elsewhere.

    I don't remember what's normal for direct GPU<>GPU, but in most of the cases we see, the times we've done it have been through a special pinned-memory mode with a staging area. That used to be hard, but nowadays the rapids.ai ecosystem (cupy / rmm / etc.) gives you nice Python wrappers.
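
    E.g., a minimal cupy sketch of a device-to-device copy (assuming a box with at least two GPUs; whether it actually goes peer-to-peer or stages through the host depends on the hardware):

        import cupy as cp

        # Can device 1 read device 0's memory directly (PCIe/NVLink P2P)?
        can_p2p = cp.cuda.runtime.deviceCanAccessPeer(1, 0)

        with cp.cuda.Device(0):
            x = cp.ones((4096, 4096), dtype=cp.float32)   # lives on GPU 0

        with cp.cuda.Device(1):
            y = cp.asarray(x)   # GPU 0 -> GPU 1 copy (P2P if available, else via host)

        print(can_p2p, y.device)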

    Dask is part of that ecosystem ("dask-cudf"), but it helps more with automation around bigger-than-memory paging, multi-GPU dispatch, and multi-node dispatch. Underneath, it does some nice things for you, like setting CPU<>GPU affinities. However, doing custom peer-to-peer / NUMA stuff quickly gets you back to cupy/rmm, and thankfully that's Python and integrates nicely with pandas/arrow/etc. :)
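
    A rough sketch of what that automation looks like (dask-cuda's LocalCUDACluster spins up one worker per visible GPU and handles the affinity bookkeeping; the array sizes here are arbitrary):

        import cupy as cp
        import dask.array as da
        from dask.distributed import Client
        from dask_cuda import LocalCUDACluster

        cluster = LocalCUDACluster()   # one worker per visible GPU, affinities set for you
        client = Client(cluster)

        # Chunked array, with each chunk moved onto a GPU as a cupy array.
        x = da.random.random((100_000, 1_000), chunks=(10_000, 1_000))
        x = x.map_blocks(cp.asarray)

        print(x.mean().compute())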

    EDIT: There's a wild world of fancy NVLink GPU<>GPU interconnects in more expensive multi-GPU boxes. We mostly deal with end-to-end I/O issues where the bottleneck is network/SSD array -> PCIe cards -> GPUs, such as during bigger-than-memory and one-shot/cold use, so I can't speak to the p2p bits as much.

  • I've not trained a model using Dask, but I've used it for distributed computing exploration over a local network for some data science workloads. I found Dask much more stable than Ray and Modin for multi-architecture distributed computing, i.e. nodes with different CPU architectures (ARMv8, x86_64).

    My goal was to explore how far distributed computing could go using local low-power compute nodes, where different architectures are common; it's not to be compared with professional work like the GP has detailed on homogeneous architectures.
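
    The wiring itself is plain dask.distributed; once the per-architecture installs work, connecting the nodes looks roughly like this (the scheduler address is made up):

        # On the scheduler node:      $ dask-scheduler
        # On each worker node (ARMv8 or x86_64, with its own dask install):
        #                             $ dask-worker tcp://<scheduler-ip>:8786

        from dask.distributed import Client

        client = Client("tcp://192.168.1.10:8786")    # made-up LAN address
        print(client.scheduler_info()["workers"])     # which nodes joined

        # Tiny smoke test fanning work out across the mixed-arch nodes.
        futures = client.map(lambda i: i * i, range(100))
        print(sum(client.gather(futures)))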

    But in case you'd like to indulge in similar masochist activities, I have a couple of gists, like installing Ray on ARM [1] and Apache Arrow on ARM [2].

    [1] https://gist.github.com/heavyinfo/aa0bf2feb02aedb3b38eef203b...

    [2] https://gist.github.com/heavyinfo/04e1326bb9bed9cecb19c2d603...