D-Machine 2 days ago

Yeah, I am aware that there are sort of long residuals even in classic ViTs, and that, as you say, you can sort of even skip the whole transformer. Like you said, though, this seems very unlikely in practice, and in any case it is a different kind of long residual than in DenseNets or U-Nets (and yes, Dense Transformers, though I know very little about these). I.e. the long residual connections in those architectures seem far more "direct" and less "sequential" than the "long residuals" in a classic transformer.

It is hard for me to say what the different consequences for training and gradients are between these two kinds of long residuals; that sounds more like your expertise. But practically, if you implement your own DenseNet, say, with a torch `forward` built from Conv layers and adds (or concats), and then implement your own little ViT with multiple MultiheadAttention layers, these really don't feel like the same thing at all, in terms of the values you need to keep around and what you pass in to deeper layers. From a bit of research, it seems these dense residual transformers are being used for super-resolution tasks. This again looks like the U-Net long residuals, in that the direct long residuals are about more efficient information propagation and less clearly about gradients, whereas the "sequential" long residuals implicit in transformers feel more like a gradient thing.
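To make the "values you need to keep around" point concrete, here is a minimal, dependency-free sketch. It is not real DenseNet or ViT code: plain floats stand in for tensors, and the hypothetical `make_layer` stands in for a Conv or attention block. It only counts what the forward pass itself must hold on to in order to compute later layers (autograd would save activations for the backward pass in both cases).

```python
# Toy stand-ins (assumptions, not real model code):
# a "tensor" is a float, a "layer" is any function float -> float.

def make_layer(scale):
    """Hypothetical layer that just scales its input."""
    return lambda x: scale * x

def dense_forward(x, layers):
    """DenseNet-style: each layer consumes ALL earlier outputs (direct long skips)."""
    saved = [x]                    # every earlier activation must stay alive
    for layer in layers:
        combined = sum(saved)      # stand-in for add/concat of all earlier features
        saved.append(layer(combined))
    return saved[-1], len(saved)   # output, number of activations retained

def sequential_forward(x, layers):
    """ViT-style: a single running residual stream, x = x + f(x) per block."""
    live = 1                       # only the current stream is passed forward
    for layer in layers:
        x = x + layer(x)
    return x, live

layers = [make_layer(0.5)] * 3
print(dense_forward(1.0, layers))       # (1.125, 4)
print(sequential_forward(1.0, layers))  # (3.375, 1)
```

In the dense version the list of retained values grows with depth; in the sequential version the earlier `x` is overwritten at every block, so deeper layers only ever see it mixed into the running sum.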

But, I am definitely NOT an expert here, I just have done a lot of practical twiddling with custom architectures in academic research contexts. I've also often worked with smaller datasets and more unusual data (e.g. 3D or 4D images like MRI, fMRI, or multivariate timeseries like continuous bedside monitoring data), also often with a limited training budget, so my focus has been more on practical differences than theoretical claims / arguments. The DenseNet and "direct" long residual architectures (e.g. U-Net) tended to be unworkable or inordinately expensive for larger 3D or 4D image data, because you have to hold so many monstrously large tensors in memory (or manually move them between CPU and GPU to avoid this problem) for the long direct skips. The absence of clear performance (or training efficiency) evidence for these architectures made me skeptical of the hand-wavy "feature reuse" claims made in their support, especially when the shorter, more sequential residuals (as in classic ViTs, or, in my case, HighResNet for 3D images: https://arxiv.org/abs/1707.01992) seemed just obviously better practically in almost every way.
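As a rough back-of-the-envelope illustration of the memory cost, here is the arithmetic for a single retained skip tensor. The shape is a made-up example (batch 2, 32 channels, 128^3 voxels), not taken from any particular model, and float32 storage is assumed.

```python
from math import prod

def tensor_bytes(shape, bytes_per_element=4):
    """Memory for one dense tensor; 4 bytes per element assumes float32."""
    return prod(shape) * bytes_per_element

# Hypothetical 3D feature map: (batch, channels, depth, height, width).
shape = (2, 32, 128, 128, 128)
one_skip = tensor_bytes(shape)
print(one_skip / 2**30)  # 0.5 (GiB held by one retained skip tensor)
```

Every additional direct skip at this resolution keeps another tensor of this size alive until the layer that consumes it, which is where the cost adds up for 3D and 4D data.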

But of course, we still have much to learn about all this!

godelski 2 days ago | parent [-]

  > you can sort of even skip the whole transformer
I don't mean "sort of", I mean literally.

  > I am definitely NOT an expert here, I just have done a lot of practical twiddling with custom architectures in academic research contexts
If a PhD makes me an expert, then I am. My thesis was on the design of neural architectures.

D-Machine 2 days ago | parent [-]

>> I don't mean "sort of", I mean literally.

Well, then we disagree, or are talking past each other, and I think writing out the equations and code shows that, as I said, these are really not exactly the same thing. From a code standpoint, direct long skips require retaining copies of earlier "x" values, and this is a memory cost that is a problem for certain purposes. Mathematically, this also means that in a long, direct skip you are adding in that exact earlier x value.

In the sequential or indirect "long skips" of a transformer, this is not the case. Yes, if you write the equations you can see there is a "path" for identity information to theoretically flow from any layer to any layer unmolested, but in practice this is not how it is implemented, and identity information is not flowing through the layers unchanged.
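For concreteness, one way to write the two cases, as a sketch under simplifying assumptions: any normalization is folded into the block function f_l (pre-norm style), and the direct skip is shown in its additive form, with k an arbitrary earlier layer index and g an arbitrary later block. The symbols here are illustrative notation, not taken from a specific paper.

```latex
% Sequential residuals (classic transformer block stack)
x_{l+1} = x_l + f_l(x_l)
\quad\Longrightarrow\quad
x_L = x_0 + \sum_{l=0}^{L-1} f_l(x_l)

% Direct long skip (U-Net style, additive form), with k < L
y = g(x_L) + x_k

% Dense connectivity (DenseNet style, concatenation)
x_l = f_l\big([x_0, x_1, \dots, x_{l-1}]\big)
```

In the first case x_0 does appear in the unrolled sum, but only as one term of the running stream that each block receives; no block is handed x_0 on its own. In the second and third cases the exact earlier x_k (or every earlier x_i) is an explicit input to a later layer, which is why it has to be stored.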

If everyone thought these subtle differences were irrelevant, then I am not sure why anyone would bother with making a dense residual transformer over the classic transformer. EDIT: nor would many of the papers adding special long skip connections to various transformer architectures make much sense. The point I was making was merely that long skips generally serve a very different purpose than shorter / classic residual connections.