xg15 4 days ago

> While we can use pretrained models such as Word2Vec to generate embeddings for machine learning models, LLMs commonly produce their own embeddings that are part of the input layer and are updated during training.

So out of interest: During inference, the embedding is simply a lookup table "token ID -> embedding vector". Mathematically, you could represent this as encoding the token ID as a (very very long) one-hot vector, then passing that through a linear layer to get the embedding vector. The linear layer would contain exactly the information from the lookup table.
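To make the equivalence concrete, here is a minimal NumPy sketch. The sizes, the random table, and the names `table` and `token_id` are assumptions for illustration, not taken from any particular model:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, embed_dim = 10, 4                      # illustrative sizes (assumed)
table = rng.normal(size=(vocab_size, embed_dim))   # lookup table, one row per token ID

token_id = 7

# Path 1: plain lookup "token ID -> embedding vector".
looked_up = table[token_id]

# Path 2: one-hot vector passed through a bias-free linear layer
# whose weight matrix is exactly the lookup table.
one_hot = np.zeros(vocab_size)
one_hot[token_id] = 1.0
via_linear = one_hot @ table

same = np.allclose(looked_up, via_linear)
print(same)
```

Both paths select the same row; the one-hot version just spends a full vector-matrix product to do it.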

My question: Is this also how the embeddings are trained? I.e. just treat them as a linear layer and include them in the normal backpropagation of the model?

montebicyclelo 4 days ago

So, they are included in the normal backpropagation of the model. But there is no one-hot encoding: you are correct that it is mathematically equivalent, but it would be very inefficient to do it that way. You can make indexing differentiable, i.e. the gradient flows back only to the vectors that were selected, which is more efficient than a one-hot matmul.

(If you're curious about the details, there's an example of making indexing differentiable in my minimal deep learning library here: https://github.com/sradc/SmallPebble/blob/2cd915c4ba72bf2d92...)

xg15 4 days ago

Ah, that makes sense, thanks a lot!

asjir 4 days ago

To expand upon the other comment: indexing and multiplying by one-hot vectors are equivalent.

If N is the vocab size and L is the sequence length, you'd need to create an NxL matrix and multiply it with the embedding matrix. But since your NxL matrix will be sparse, with only a single 1 per column, it makes sense to represent it internally as just one number per column: the index at which the 1 sits. At that point, if you defined a new multiplication for this representation, it would basically just index with this number.
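A rough sketch of the two representations in NumPy, for anyone who wants to see the sizes. The embedding dimension `D`, the concrete numbers, and the variable names are assumptions for illustration. Note that with an NxL one-hot matrix and an NxD embedding matrix, the product is taken with the one-hot matrix transposed:

```python
import numpy as np

rng = np.random.default_rng(0)
N, L, D = 1000, 6, 8                  # vocab size, sequence length, embedding dim (D assumed)
E = rng.normal(size=(N, D))           # embedding matrix, one row per token
ids = rng.integers(0, N, size=L)      # one index per column of the one-hot matrix

# Dense form: N x L matrix with a single 1 per column.
one_hot = np.zeros((N, L))
one_hot[ids, np.arange(L)] = 1.0
dense_result = one_hot.T @ E          # (L x N) @ (N x D) -> L x D

# Compact form: keep only the row index of each column's 1 and index with it.
indexed_result = E[ids]               # L x D

results_match = np.allclose(dense_result, indexed_result)
dense_entries = one_hot.size          # N * L stored numbers
compact_entries = ids.size            # L stored numbers
print(results_match, dense_entries, compact_entries)
```

The dense form stores N*L numbers and does a full matmul; the compact form stores L integers and copies L rows.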

And just as you write a special forward pass, you can write a special backward pass so that backpropagation reaches the embedding vectors.
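A minimal sketch of what such a forward/backward pair could look like, assuming plain NumPy and a hypothetical `EmbeddingLookup` class (this is not the SmallPebble code linked earlier in the thread, nor any framework's actual implementation). The backward pass scatter-adds the incoming gradient into the rows that were selected, and the result is checked against the gradient of the equivalent one-hot matmul:

```python
import numpy as np

class EmbeddingLookup:
    """Hypothetical layer: forward indexes the table, backward scatter-adds gradients."""

    def __init__(self, table):
        self.table = table
        self.ids = None

    def forward(self, ids):
        self.ids = ids                 # remember which rows were selected
        return self.table[ids]         # shape (L, D)

    def backward(self, grad_out):
        grad_table = np.zeros_like(self.table)
        # np.add.at accumulates correctly when the same token ID repeats;
        # rows that were never selected keep a zero gradient.
        np.add.at(grad_table, self.ids, grad_out)
        return grad_table

rng = np.random.default_rng(0)
N, L, D = 12, 5, 3                     # illustrative sizes (assumed)
table = rng.normal(size=(N, D))
ids = np.array([3, 7, 3, 0, 11])       # token 3 appears twice on purpose

layer = EmbeddingLookup(table)
out = layer.forward(ids)
grad_out = rng.normal(size=(L, D))     # stand-in for the gradient from the rest of the model
grad_table = layer.backward(grad_out)

# Reference: for out = one_hot.T @ table, the gradient w.r.t. table is one_hot @ grad_out.
one_hot = np.zeros((N, L))
one_hot[ids, np.arange(L)] = 1.0
grad_reference = one_hot @ grad_out

grads_match = np.allclose(grad_table, grad_reference)
print(grads_match)
```

Using `np.add.at` rather than `grad_table[ids] += grad_out` matters here: the plain fancy-indexed assignment would keep only one of the two contributions for the repeated token.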