| ▲ | Were RNNs all we needed? A GPU programming perspective(dhruvmsheth.github.io) |
| 107 points by omegablues 6 days ago | 31 comments |
| |
|
| ▲ | lettergram 3 days ago | parent | next [-] |
| Back in 2016 - 2018 my work at Capital One resulted in a modified C-RNN style architecture that was producing gpt-2 level results. Using that model we were able to build a general purpose system that could generate data for any dataset (with minimal training, from scratch): https://medium.com/capital-one-tech/why-you-dont-necessarily... At the time it was clear to all on the team that RNNs, just like transformers later on, are general purpose frameworks that really only require more data and size to function. In the 2018-2020 era and probably today, they are slower to train. They also are less prone to certain pitfalls, but overall had the same characteristics. In 2019-2020 I was convinced that transformers would give way to a better architecture. The RNNs in particular trained faster and required less data, particularly when combined with several architectural components I won’t get into. I believe that’s still true today, though I haven’t worked on it in the last 2-3 years. That said, transformers “won” because they are better overall building blocks and don’t require the nuances of RNNs. Combined with the compute optimizations that are now present I don’t see that changing in the near term. Folks are even working to convert transformers to RNNs: https://medium.com/@techsachin/supra-technique-for-linearizi... There are also RNN based models beating Qwen 3 8B in certain benchmarks: https://www.rwkv.com/ I suspect over time the other methods my team explored and other types of networks and nodes will continue to expand beyond transformers for state of the art LLMs. |
| |
| ▲ | algo_trader 3 days ago | parent [-] | | > RNN based models beating Qwen 3 8B
> https://www.rwkv.com/ Counter consensus is where the alpha is... Do you think rnn/rwkv have an edge with verifiable domains and tree search inference time? You can use cheaper gpus and do multiple sampling. (but of course, it's hard to beat the sunk cost of a foundation model) |
|
|
| ▲ | bob1029 3 days ago | parent | prev | next [-] |
| > You simply cannot compute the gates for the entire sequence in one shot, because each step requires the output from the one before it. This forces a sequential loop, which is notoriously inefficient on parallel hardware like GPUs. > The crux of the paper is to remove this direct dependency. The simplified models, minGRU and minLSTM, redefine the gates to depend only on the current input The entire hypothesis of my machine learning experiments has been that we should embrace the time domain and causal dependencies. I really think biology got these elements correct. Now, the question remains - Which kind of computer system is most ideal to run a very branchy and recursive workload? Constantly adapting our experiments to satisfy the constraints of a single kind of compute vendor is probably not healthy for science over the long term. |
| |
| ▲ | fennecbutt 2 days ago | parent | next [-] | | Absolutely, I remember an article from ages ago about a self learning algo implemented on an FPGA (I think) that could modify its own makeup on a hardware level. It ended up optimising in a way that wasn't obvious at first, but turned out to be the noise of one part interacting with another. Aha:
Here's the paper https://osmarks.net/assets/misc/evolved-circuit.pdf And a fluff article https://www.damninteresting.com/on-the-origin-of-circuits And as per usual, Google was hopeless in finding the article from a rough description. No chance, at all. ChatGPT thought for 10s and delivered the correct result, first time. | |
| ▲ | jstanley 3 days ago | parent | prev | next [-] | | > Which kind of computer system is most ideal to run a very branchy and recursive workload? An analogue one, possibly? | | |
| ▲ | tripplyons 2 days ago | parent [-] | | I think of analogue computing as more continuous than branchy from what I have heard about it. I don't know much about it though. |
| |
| ▲ | tripplyons 2 days ago | parent | prev | next [-] | | The output of the recurrence is still dependent on previous tokens, but it is usually less expressive within the recurrence in order to make parallelism possible. In minGRU the main operation used to share information between tokens is addition (with a simple weighting). You could imagine after one layer of the recurrence, the tokens already have some information about each other, so the input to the following layers is dependent on previous tokens, although the dependence is indirect compared to a traditional RNN. | |
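To make the "addition with a simple weighting" concrete, here is a minimal scalar sketch of a minGRU-style recurrence, h_t = (1 - z_t) * h_{t-1} + z_t * h~_t, where the gate z_t and candidate h~_t depend only on the current input. The scalar weights `w_z` and `w_h` and all function names are assumptions made for illustration (real models use learned linear layers on vectors). Because each step is an affine map of the previous state, steps can be composed associatively, which is what lets a scan reproduce the loop:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def min_gru_sequential(xs, w_z, w_h, h0=0.0):
    """Plain loop: each h_t is a weighted sum of h_{t-1} and a candidate from x_t."""
    h, out = h0, []
    for x in xs:
        z = sigmoid(w_z * x)      # gate depends only on x_t, not on h_{t-1}
        h_tilde = w_h * x         # candidate state, also input-only
        h = (1.0 - z) * h + z * h_tilde
        out.append(h)
    return out

def combine(left, right):
    """Compose two affine maps h -> a*h + b, applying `left` first."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)

def min_gru_scan(xs, w_z, w_h, h0=0.0):
    """Same result via a Hillis-Steele style scan over (a_t, b_t) pairs.

    The inner loop over i has no dependencies between iterations, so on
    parallel hardware each round could run all positions at once.
    """
    elems = []
    for x in xs:
        z = sigmoid(w_z * x)
        elems.append((1.0 - z, z * w_h * x))
    step = 1
    while step < len(elems):
        prev = elems[:]           # read from the previous round only
        for i in range(step, len(elems)):
            elems[i] = combine(prev[i - step], prev[i])
        step *= 2
    return [a * h0 + b for a, b in elems]

xs = [0.5, -1.0, 2.0, 0.1, -0.3]
seq = min_gru_sequential(xs, w_z=0.8, w_h=1.5, h0=0.2)
par = min_gru_scan(xs, w_z=0.8, w_h=1.5, h0=0.2)
```

A traditional GRU computes z_t from h_{t-1} as well, so its steps are not affine maps with input-only coefficients and this composition trick does not apply directly.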
| ▲ | inciampati 3 days ago | parent | prev | next [-] | | It turns out you can use a fused Triton kernel for a true RNN GRU and run just as fast as the minGRU model in training. Yeah, it doesn't work for very long context but neither does minGRU (activation memory...) | |
| ▲ | nickpsecurity 3 days ago | parent | prev [-] | | Analog or FPGA. Cerebras' wafer-scale technology could help. |
|
|
| ▲ | akst 3 days ago | parent | prev | next [-] |
| I think the author may have linked to the wrong paper in their opening paragraph (hopefully they see this and update the link) https://arxiv.org/abs/2410.01201 I opened it wondering what RNNs were, and for anyone else wondering they are apparently Recurrent Neural Networks, though you find that answer if you continue reading on (though the lack of a definition kind of stopped me in my tracks). |
| |
|
| ▲ | tripplyons 2 days ago | parent | prev | next [-] |
| Have you explored chunkwise parallel approaches? They use the O(n log n) parallel algorithm within a subsequence and update between chunks recurrently like the O(n) recurrent algorithm. These are usually the fastest kernels for these kinds of RNNs. Here is the best explanation I have seen: https://sustcsonglin.github.io/blog/2024/deltanet-2/ |
|
| ▲ | TimorousBestie 3 days ago | parent | prev | next [-] |
| The single-thread performance of the parallel prefix sum that they use is O(N log N), so the improvement from that to O(log N) on N threads is not as surprising. The way the headline is written, it sounds like Amdahl’s law was violated. It wasn’t, of course. |
| |
| ▲ | casta 3 days ago | parent [-] | | How's the prefix sum on a single thread O(N log(N))? Isn't it trivially O(N)? It's just a for loop. | | |
| ▲ | TimorousBestie 3 days ago | parent | next [-] | | Yes, but the for loop comes with all those data dependencies that prevent it from being parallelized trivially. The algorithm with fewer data dependencies is O(N log N). This is covered in more detail in the article. | |
| ▲ | gyrovagueGeist 3 days ago | parent | prev [-] | | It's from the depth of the computation, not the work |
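The work-versus-depth distinction in this sub-thread can be checked by counting operations. The sketch below is an illustration only (function names are made up here, and it is not the article's kernel): the sequential loop does N - 1 additions in N - 1 dependent steps, while a Hillis-Steele style scan does roughly N log2 N additions but needs only about log2 N rounds, each of which has no internal dependencies.

```python
def sequential_prefix_sum(xs):
    """O(N) work, but every addition depends on the one before it."""
    out, total, additions = [], 0, 0
    for i, x in enumerate(xs):
        if i == 0:
            total = x
        else:
            total += x
            additions += 1
        out.append(total)
    return out, additions

def hillis_steele_prefix_sum(xs):
    """More total additions, but only ceil(log2 N) dependent rounds."""
    vals = list(xs)
    additions = 0
    rounds = 0
    step = 1
    while step < len(vals):
        prev = vals[:]                      # snapshot of the previous round
        for i in range(step, len(vals)):    # independent across i
            vals[i] = prev[i - step] + prev[i]
            additions += 1
        rounds += 1
        step *= 2
    return vals, additions, rounds

data = [1, 2, 3, 4, 5, 6, 7, 8]
seq_result, seq_adds = sequential_prefix_sum(data)
par_result, par_adds, par_rounds = hillis_steele_prefix_sum(data)
print(seq_adds, par_adds, par_rounds)  # 7 17 3
```

Run on one thread, the second version is strictly more work. Its advantage only appears when each round's inner loop is spread across many threads.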
|
|
|
| ▲ | vitus 3 days ago | parent | prev | next [-] |
| > The GPU implementation's logarithmic scaling becomes evident at longer sequence lengths. I don't see logarithmic scaling, actually. From the table for GRU performance, going from 16384 -> 65536 (namely: increasing the input by 4x) is roughly a 4x increase in time whether looking at CPU-scan or GPU-scan. Okay, maybe the inputs need to be bigger. Looking at the next plot, which goes up to 524288, we see the same behavior: the delta between CPU-scan and GPU-scan doubles as we double the input. That's a constant multiplicative factor. Same holds for LSTM performance. Is this an artifact of the benchmark setup? Are we actually measuring the amount of time needed to load the full context into RAM? Or perhaps we're bottlenecked on memory bandwidth? > Success: The gate extraction kernel, which was a huge bottleneck, now only takes 8% of the total time and is memory-bandwidth bound, saturating L2 bandwidth at 1.9 TB/s. This is a good place to be. Sounds like that might be the case. |
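As a quick arithmetic check on the observation above: if runtime really grew like log N, a 4x larger input in this range would cost only about 14% more time, not 4x. The helper names below are hypothetical, and the two growth models are idealized assumptions rather than measurements from the article.

```python
import math

def time_ratio_log(n_small, n_large):
    """Expected slowdown if runtime were proportional to log2(N)."""
    return math.log2(n_large) / math.log2(n_small)

def time_ratio_linear(n_small, n_large):
    """Expected slowdown if runtime were proportional to N."""
    return n_large / n_small

log_ratio = time_ratio_log(16384, 65536)        # 16 / 14
linear_ratio = time_ratio_linear(16384, 65536)  # 4.0
print(f"log model: {log_ratio:.3f}x, linear model: {linear_ratio:.1f}x")
```

A roughly 4x jump matches the linear model. That is what one would expect once the sequence no longer fits one element per hardware lane: with a fixed number of lanes P, an idealized scan takes on the order of N/P + log N steps, and the N/P term dominates for large N.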
| |
| ▲ | tripplyons 2 days ago | parent [-] | | Typically the fastest approaches for associative RNNs combine the advantages of the parallel O(n log n) algorithm with a recurrent non-parallel O(n) approach by computing results for subsequence chunks in parallel and moving to the next chunk in a recurrent manner. This blog post explains the method (chunkwise parallel algorithm) well: https://sustcsonglin.github.io/blog/2024/deltanet-2/ |
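A toy version of the chunkwise idea, for the scalar linear recurrence h_t = a_t * h_{t-1} + b_t, might look like the sketch below. It is a simplified illustration under assumed names, not the DeltaNet kernels from the linked post: within a chunk, prefix compositions of the affine maps are computed independently of the carried state (written as a loop here for brevity, but this is the part that could use a parallel scan), and only the chunk-to-chunk carry is sequential.

```python
def combine(left, right):
    """Compose affine maps h -> a*h + b, applying `left` first."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)

def local_prefix(chunk):
    """Inclusive prefix composition inside one chunk.

    Does not depend on the incoming state, so all chunks could be
    processed at the same time.
    """
    out, acc = [], (1.0, 0.0)   # (1, 0) is the identity map
    for elem in chunk:
        acc = combine(acc, elem)
        out.append(acc)
    return out

def chunkwise_recurrence(a, b, chunk_size, h0=0.0):
    h_carry, hs = h0, []
    for start in range(0, len(a), chunk_size):
        chunk = list(zip(a[start:start + chunk_size], b[start:start + chunk_size]))
        for big_a, big_b in local_prefix(chunk):
            hs.append(big_a * h_carry + big_b)  # apply the carry to each position
        h_carry = hs[-1]                        # sequential hand-off to the next chunk
    return hs

def reference_recurrence(a, b, h0=0.0):
    h, hs = h0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        hs.append(h)
    return hs

a = [0.5, 0.9, 0.1, 0.7, 0.3]
b = [1.0, 2.0, 3.0, 4.0, 5.0]
chunked = chunkwise_recurrence(a, b, chunk_size=2, h0=0.25)
expected = reference_recurrence(a, b, h0=0.25)
```

The chunk size trades off the two costs: larger chunks mean fewer sequential hand-offs but more scan work per chunk.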
|
|
| ▲ | DoctorOetker 4 days ago | parent | prev | next [-] |
| Is it not much simpler to parallelize by having different "readers" (using the same model parameters/weights) process different parts of the corpus in parallel? reader A is reading book A, while reader B is reading book B etc...? Is there a deeper reason why more complicated parallelization as in the OP or the article it references is more desirable? |
| |
| ▲ | jsharf 3 days ago | parent | next [-] | | If you have independent copies of the network learning gradients, then you’re effectively making the batch size smaller, unless you’re doing an all collect and making them sync, in which case there’s a lot of overhead. When you take a batch and calculate gradients, you’re effectively calculating a direction the weights should move in, and then taking a step in that direction. You can do more steps at once by doing what you say, but they might not all be exactly in the right direction, so overall efficiency is hard to compare. I am not an expert, but if I understand correctly I think this is the answer. | | |
| ▲ | immibis 3 days ago | parent [-] | | Batch size is just averaging the gradients from multiple calculations. |
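The averaging point can be illustrated with a one-parameter model. In this sketch (a toy example with made-up data and names, assuming a mean-squared-error loss and equal-sized shards), two "readers" each compute a gradient on their own half of the batch, and averaging those gradients gives the same result as computing the gradient on the full batch:

```python
def mse_gradient(xs, ys, w):
    """Gradient of mean((w*x - y)^2) with respect to the scalar weight w."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n

xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 9.0]
w = 0.5

# One worker sees the whole batch.
full_batch_grad = mse_gradient(xs, ys, w)

# Two workers each see half, then their gradients are averaged (the sync step).
grad_a = mse_gradient(xs[:2], ys[:2], w)
grad_b = mse_gradient(xs[2:], ys[2:], w)
averaged_grad = (grad_a + grad_b) / 2.0
```

If the workers instead each take their own optimizer step without syncing, every step is based on a smaller batch, which is the trade-off described in the parent comment.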
| |
| ▲ | zozbot234 3 days ago | parent | prev [-] | | AIUI, the thinking when developing transformers might have been that "reading text A vs. text B" just isn't parallel enough for truly large-scale training. The problem was to somehow also parallelize the learning of very long range dependencies within a single sequence, and transformers managed to do that. |
|
|
| ▲ | zozbot234 3 days ago | parent | prev | next [-] |
| My understanding is that RNN and LSTM (potentially augmented with some bespoke attention mechanism) have the potential to be more training- and inference-efficient than the transformer models that are more common today, but transformers were adopted because of their unique ability to greatly scale up and parallelize training in a way that just isn't feasible with the older models. So transformers can get better outcomes from their allowed scale of compute despite possibly being less compute-efficient overall. |
| |
| ▲ | spwa4 3 days ago | parent [-] | | I wonder about that. I mean, attention is sort-of obvious isn't it? Just calculate across all data available, look at every moment in time simultaneously. I know the point of attention is to select the correct data and process that, but in order to select the correct data we do a truly massive matrix multiplication. That that was going to work, I believe, was not a mystery even in 1960. It just wasn't possible, and even in 2016 it was not really possible to train with such a principle outside of the FANGs. (not trying to say it wasn't an incredible accomplishment for the authors. There are quite a few details to get right in order to get to the "obvious" advance) Even today it's pretty obvious how such a thing might be further extended. Create a network so big in input it just contains and does attention across its entire dataset. Such a network would only need basic understanding of language and would not hallucinate. Also it'd be obvious where anything came from, as the attention vectors would show what data was used by what specific part of the network. But this is a theoretical exercise as all the compute power in the world can't do that for any decent size dataset. RNN and LSTM are more training and inference efficient because they don't do this. They do compute for every token and more-or-less then just add every thought any part of the network had together, sequentially. We need to go the opposite direction from attention. It has to be the case that attention is extremely inefficient. | | |
| ▲ | zozbot234 3 days ago | parent [-] | | > I know the point of attention is to select the correct data and process that In a way, the point of attention mechanisms is to bias the model towards long-range dependencies (as seen e.g. in language sentence structure) and away from the very short-term ones that a plain RNN would tend to focus on. LSTM is sort-of in the middle; it manages to persist information beyond the very short run (so "paying attention" to it in a very fuzzy sense) but not as long-term as attention does. |
|
|
|
| ▲ | marcosdumay 3 days ago | parent | prev [-] |
| Well, on the title, our brain seems to be equivalent to an RNN... so yeah, possibly. Anyway, claiming that they are equivalent to transformers, when RNNs are Turing complete and forward-only NNs are not, is such a strange take. |
| |
| ▲ | imtringued 3 days ago | parent [-] | | It's a much stranger take to associate the brain with RNNs. It's far more likely that the brain does something similar to a path-independent spiking equilibrium model trying to find a fixed point, because those models are inherently robust with respect to noise and adversarial attacks, do not require more than one layer, inherently contain feedback loops, and tend to generalize well to out-of-distribution data. Of course in practice they end up somewhere between 2x and 3x slower than a finite-layer transformer for the same in-distribution performance. | | |
| ▲ | RaftPeople 3 days ago | parent [-] | | > It's a much stranger take to associate the brain with RNNs That seems like too strong of a position. The equilibrium model seems to be a good candidate for some activity in the brain, but it doesn't seem like it applies to everything. For example, the inter-layer recurrence in vision/object detection processing seems to be something different. |
|
|