D-Machine 3 days ago

I'm not sure your claim is correct, or particularly precise. Skip connections do propagate feature maps from earlier to later/deeper layers in the network, and one can model this mathematically as applying an identity function, but I wouldn't say this is "about preserving identity". Also, the identity is useful here because it prevents vanishing gradients (the identity function's Jacobian is the identity, so gradients pass through the skip unattenuated). Arguably, preventing vanishing gradients is exactly about numerical stability, so I would in fact argue that skip connections are there to make gradient descent on certain architectures more numerically stable.
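
To make the gradient point concrete, here is a toy sketch (entirely my own illustrative example, not from any paper; the depth and layer choices are arbitrary): a deep stack of tanh blocks where the skip can be toggled, printing the gradient norm that reaches the first layer.

  import torch
  import torch.nn as nn

  class Block(nn.Module):
      def __init__(self, dim, use_skip):
          super().__init__()
          self.f = nn.Sequential(nn.Linear(dim, dim), nn.Tanh())
          self.use_skip = use_skip

      def forward(self, x):
          out = self.f(x)
          # with the skip, d(output)/d(input) = I + dF/dx, so the Jacobian
          # never collapses toward zero even when dF/dx does
          return x + out if self.use_skip else out

  def first_layer_grad_norm(use_skip, depth=50, dim=32):
      torch.manual_seed(0)
      net = nn.Sequential(*[Block(dim, use_skip) for _ in range(depth)])
      net(torch.randn(8, dim)).sum().backward()
      return net[0].f[0].weight.grad.norm().item()

  print("no skips:", first_layer_grad_norm(False))  # tiny gradient at layer 0
  print("skips:   ", first_layer_grad_norm(True))   # healthy gradient at layer 0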

EDIT: Just to provide some research support: https://arxiv.org/html/2405.01725v1#S5.SS4.5.1.1. I.e. skip connections do a lot of things, but most of those things boil down to making gradients better conditioned. Insofar as one symptom of poor conditioning is gradient vanishing, and vanishing is really only a problem because of the limited precision of floating point, it is hard for me not to think that skip connections are primarily there for numerical stability, broadly.

One case supporting your point, though, where one might argue the "main" reason for skip connections is propagating information rather than stability, is LONG skip connections, as in U-Nets. There, since feature maps shrink with network depth, long skips are needed to carry high-resolution information from earlier layers to the deeper ones.
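
A minimal sketch of what I mean by a long skip there (toy shapes and layer choices are mine, not a real U-Net): the high-resolution encoder map is saved and concatenated back in after the low-resolution path.

  import torch
  import torch.nn as nn

  enc = nn.Conv2d(3, 16, 3, padding=1)   # high-resolution encoder features
  down = nn.MaxPool2d(2)
  mid = nn.Conv2d(16, 16, 3, padding=1)
  up = nn.Upsample(scale_factor=2)
  dec = nn.Conv2d(32, 16, 3, padding=1)  # 32 = 16 decoder + 16 skipped channels

  x = torch.randn(1, 3, 64, 64)
  e = enc(x)                             # kept alive for the long skip
  d = up(mid(down(e)))                   # low-resolution path, back up to 64x64
  out = dec(torch.cat([d, e], dim=1))    # long skip: concat the saved high-res map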

storus 7 hours ago

Stanford's CS224N told us it was about preserving identity because non-linear neural networks are bad at learning identity.

godelski 3 days ago

You're correct, and I also added some info in a sibling comment that might interest you. But you might also be interested in reading the DenseNet paper[0]. While it's convs, the ideas still apply. There's actually a lot of research on this topic, though it tends to live in more "theory" papers, and the ideas are slow to propagate into mainstream work.

[0] https://arxiv.org/abs/1608.06993

D-Machine 3 days ago

Yup, I am pretty up to date with the literature (mostly) and read about DenseNet a long time ago (and have trained and tuned quite a few of them too). I've also done enough experimentation with tuning custom CNN architectures (on smaller medical datasets, mind you) where the inclusion of residual connections was itself a hyperparameter to be tuned. It is pretty obvious that with deeper networks, turning off the skip connections just makes your loss curves hit a premature (and bad) plateau far more often than not.

Residual connections in CNNs clearly don't change the representational capacity, and don't obviously induce any bias that seems beneficial. I.e. if you needed information from earlier feature maps, the optimizer could surely just learn to make intermediate layers preserve that information (by having some conv-layer channels behave essentially as identities). Actually, if you inspect a lot of feature maps through visualizations, it is pretty clear that information is lost very gradually, over many, many layers, and it just isn't plausible that typical (short) skip connections are meaningfully helpful because they "preserve information". This is in stark contrast to U-Nets, where the skips are long enough, and come from layers of sufficiently different resolutions, that the purpose is clearly very different.

EDIT: With DenseNet it is trickier, because some of the residual connections are quite long, and the paper does provide some evidence that the length is useful. But there are not any long residual connections in Vision Transformers, and broadly it seems like ViTs and modern conv nets (which don't use DenseNet "long" concats) have basically equivalent performance (https://arxiv.org/abs/2310.16764). So the whole idea of long connections being important just doesn't sit right with me. The short, typical residual connection is useful, and that seems most clearly to be about gradients.
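
Going back to the capacity point: a conv layer can represent the identity exactly (PyTorch even ships a Dirac initialization for this), which is a toy way of seeing why short skips add no representational capacity (the sketch below is just illustrative):

  import torch
  import torch.nn as nn

  conv = nn.Conv2d(16, 16, 3, padding=1, bias=False)
  nn.init.dirac_(conv.weight)   # each output channel simply copies its input channel
  x = torch.randn(1, 16, 8, 8)
  assert torch.allclose(conv(x), x, atol=1e-6)  # the conv is exactly the identity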

And practically, when simply switching on residual connections leads to better training in an otherwise identical architecture, it is really hard to see them as anything other than a tool for stabilizing optimization.

godelski 2 days ago

  > there are not any long residual connections in Vision Transformers
That's incorrect. If you look at the architecture a little closer you'll see you have very long residuals. The transformer arch is res(norm + attn) + res(norm + FFN). That allows you to just skip a whole layer. Going backwards, we take the first shortcut, skipping the norm + FFN, then we can take the second shortcut, skipping the norm + attn. So you skip the whole transformer! It's pretty unlikely that this happens in practice, so you can strengthen it by using dense transformers, but as far as I've seen that isn't really much help, and it does look like the gradient propagates the whole way.
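
A minimal sketch of the block I mean (standard pre-norm wiring; the class and variable names are mine, just for illustration):

  import torch
  import torch.nn as nn

  class PreNormBlock(nn.Module):
      def __init__(self, dim, heads=4):
          super().__init__()
          self.norm1 = nn.LayerNorm(dim)
          self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
          self.norm2 = nn.LayerNorm(dim)
          self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(),
                                   nn.Linear(4 * dim, dim))

      def forward(self, x):
          h = self.norm1(x)
          x = x + self.attn(h, h, h, need_weights=False)[0]  # res(norm + attn)
          x = x + self.ffn(self.norm2(x))                    # res(norm + FFN)
          # expanding the two adds: out = x_in + attn_term + ffn_term, i.e.
          # an additive path from input to output that skips the whole block
          return x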

I'd recommend looking at the famous 3 Things paper [0]. You'll see what I said in their diagram. It's got some other good gems.

[0] https://arxiv.org/abs/2203.09795

D-Machine 2 days ago

Yeah, I am aware that there are sort of long residuals even in classic ViTs, and that, as you say, you can sort of even skip the whole transformer. Like you said, though, this seems very unlikely in practice, and at the least, this is a different kind of long residual than in DenseNets or U-Nets (and yes, dense transformers, though I know very, very little about these). I.e. the long residual connections in those architectures seem far more "direct" and less "sequential" than the "long residuals" in a classic transformer.

It is hard for me to say what the different consequences for training and gradients are between these two kinds of long residuals; that sounds more like your expertise. But practically, if you implement your own e.g. DenseNet `forward` calls in torch with Conv layers and adds (or concats), and then implement your own little ViT with multiple MultiheadAttention layers, these really don't feel like the same thing at all, in terms of the values you need to keep access to and what you pass in to deeper layers. From a bit of research, it seems like these dense residual transformers are mostly used for super-resolution tasks. That again looks like the U-Net long residuals, in that the point of the direct long residuals is more efficient information propagation, and less clearly gradients, whereas the "sequential" long residuals implicit in transformers feel again more like a gradient thing.
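
To make concrete what I mean by the two styles, here is a deliberately crude toy version of the two forward passes (my own simplification, not either paper's actual code):

  import torch
  import torch.nn as nn

  def dense_style_forward(x, layers):
      feats = [x]                                 # every earlier map is kept alive
      for layer in layers:
          feats.append(layer(torch.cat(feats, dim=1)))
      return torch.cat(feats, dim=1)

  def sequential_style_forward(x, layers):
      for layer in layers:
          x = x + layer(x)                        # only the running activation is kept
      return x

  C = 8
  dense_layers = nn.ModuleList([nn.Conv2d((i + 1) * C, C, 3, padding=1) for i in range(4)])
  seq_layers = nn.ModuleList([nn.Conv2d(C, C, 3, padding=1) for _ in range(4)])
  x = torch.randn(1, C, 16, 16)
  y_dense = dense_style_forward(x, dense_layers)   # shape (1, 5*C, 16, 16)
  y_seq = sequential_style_forward(x, seq_layers)  # shape (1, C, 16, 16)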

But, I am definitely NOT an expert here, I just have done a lot of practical twiddling with custom architectures in academic research contexts. I've also often worked with smaller datasets and more unusual data (e.g. 3D or 4D images like MRI, fMRI, or multivariate timeseries like continuous bedside monitoring data), often with a limited training budget, so my focus has been more on practical differences than on theoretical claims / arguments. The DenseNet and "direct" long residual architectures (e.g. U-Net) tended to be unworkable or inordinately expensive for larger 3D or 4D image data, because you have to hold so many monstrously large tensors in memory (or manually move them between CPU and GPU to avoid this) for the long direct skips. The absence of clear performance (or training efficiency) evidence for these architectures made me skeptical of the hand-wavy "feature reuse" claims made in their support, especially when the shorter, more sequential residuals (as in classic ViTs, or, in my case, HighResNet for 3D images: https://arxiv.org/abs/1707.01992) seemed just obviously better in practice in almost every way.
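
For a sense of scale (purely illustrative numbers, not from any particular dataset of mine), a single fp32 3D feature map that a long skip has to keep alive is already substantial, per skip and per batch element:

  channels, D, H, W = 32, 128, 128, 128            # hypothetical 3D feature-map size
  bytes_fp32 = channels * D * H * W * 4
  print(bytes_fp32 / 2**20, "MiB per saved skip")  # 256.0 MiB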

But of course, we still have much to learn about all this!

godelski 2 days ago

  > you can sort of even skip the whole transformer
I don't mean "sort of", I mean literally.

  > I am definitely NOT an expert here, I just have done a lot of practical twiddling with custom architectures in academic research contexts
If a PhD makes me an expert, then I am. My thesis was on the design of neural architectures.

D-Machine 2 days ago

>> I don't mean "sort of", I mean literally.

Well, then we disagree, or are talking past each other, and I think writing out the equations and code shows that, as I said, these are really not exactly the same thing. From a code standpoint, direct long skips require retaining copies of earlier "x" values, and this is a memory cost that is a problem for certain purposes. Mathematically, it also means that in a long, direct skip you are adding in that exact earlier x value.

In the sequential or indirect "long skips" of a transformer, this is not the case. Yes, if you write the equations you can see there is a "path" for identity information to theoretically flow from any layer to any layer unmolested, but in practice this is not how it is implemented, and identity information is not flowing through the layers unchanged.
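
To spell out the distinction (my own notation, written additively and ignoring norms for simplicity):

  Direct long skip (U-Net / dense style), re-injecting a stored activation:
    y_L = f_L(y_{L-1}) + y_k    # the saved y_k from a much earlier layer k is added verbatim

  Sequential residuals (classic transformer style):
    y_i = y_{i-1} + f_i(y_{i-1})
    y_L = y_0 + f_1(y_0) + f_2(y_1) + ... + f_L(y_{L-1})   # unrolled

So there is an additive identity path on paper in the sequential case, but no stored copy of an earlier activation is ever passed directly into a deeper layer, which is exactly the implementation difference I mean.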

If everyone thought these subtle differences were irrelevant, then I am not sure why anyone would bother making a dense residual transformer over the classic transformer. EDIT: nor would many of the papers adding special extra long skip connections to various transformer architectures make much sense. The point I was merely making was that long skips generally serve a very different purpose than shorter / classic residual connections.