saagarjha | 9 hours ago
This is neat, although it would be nice to see it merged into PyTorch instead of just living in a paper :) The key insight (beyond "obvious" optimizations like not using graphs that are measured to be slower) seems to be that graphs "bake in" parameters, and if those change the graph has to be thrown away. The solution is to add a level of indirection, so that what gets captured is a pointer that can stay constant while the data behind it changes. This also removes the need to copy in and out of a graph-captured buffer, because you can just swap the pointer instead.

Of course there is overhead to this approach (I don't think the authors explore this much): you throw away information, like divisibility, that would let you construct better kernels. But often it is still worth it. (Or you could pass that information through too.)

Something worth exploring later would be better support in PyTorch for the rest of CUDA graphs, like conditional nodes.
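For anyone who hasn't played with this, here's a minimal sketch in plain CUDA of the indirection trick. This is my own illustration, not the paper's or PyTorch's actual mechanism, and names like scale/slot are made up: the graph captures the address of a device-resident pointer slot, and between launches you retarget the slot to a different buffer instead of re-capturing or staging copies into a dedicated input buffer (error checking omitted).

    #include <cuda_runtime.h>

    // The kernel reads its input through a device-resident pointer slot, so
    // the graph only ever captures the slot's (constant) address.
    __global__ void scale(float** in_slot, float* out, float factor, int n) {
        const float* in = *in_slot;  // extra level of indirection
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) out[i] = in[i] * factor;
    }

    int main() {
        const int n = 1 << 20;
        float *bufA, *bufB, *out, **slot;
        cudaMalloc(&bufA, n * sizeof(float));
        cudaMalloc(&bufB, n * sizeof(float));
        cudaMalloc(&out,  n * sizeof(float));
        cudaMalloc(&slot, sizeof(float*));

        // Point the slot at buffer A, then capture the graph once.
        cudaMemcpy(slot, &bufA, sizeof(float*), cudaMemcpyHostToDevice);

        cudaStream_t stream;
        cudaStreamCreate(&stream);
        cudaGraph_t graph;
        cudaGraphExec_t exec;

        cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
        scale<<<(n + 255) / 256, 256, 0, stream>>>(slot, out, 2.0f, n);
        cudaStreamEndCapture(stream, &graph);
        cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);

        cudaGraphLaunch(exec, stream);  // runs on bufA

        // Swap the data behind the captured pointer: no re-capture, no copy
        // into a staging buffer, just retarget the slot.
        cudaMemcpy(slot, &bufB, sizeof(float*), cudaMemcpyHostToDevice);
        cudaGraphLaunch(exec, stream);  // same graph, now reads bufB

        cudaStreamSynchronize(stream);
        cudaFree(bufA); cudaFree(bufB); cudaFree(out); cudaFree(slot);
        cudaStreamDestroy(stream);
        return 0;
    }

(You could also patch the kernel argument in the instantiated graph with cudaGraphExecKernelNodeSetParams, but the pointer-slot version leaves the graph executable completely untouched.)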