Scene_Cast2 5 days ago

How would this matrix get trained with PyTorch? I currently have a toy Transformer network, and I ended up marking the matrix as sparse and using SparseAdam. That gives a bit of a performance boost, but at the same time I can't use torch.compile() on the fetch from this matrix.
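
For context, here is a minimal sketch of the kind of setup described above, assuming "the matrix" is a lookup table held in an `nn.Embedding` created with `sparse=True`. The class name `ToyModel`, the sizes, and the learning rates are all made up for illustration and are not taken from the actual network. It only shows the sparse-gradient training step and does not attempt to reproduce or work around the torch.compile() issue.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
VOCAB, DIM = 1000, 16


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # sparse=True makes the lookup produce a sparse gradient for the table.
        self.table = nn.Embedding(VOCAB, DIM, sparse=True)
        self.head = nn.Linear(DIM, VOCAB)

    def forward(self, ids):
        return self.head(self.table(ids))


model = ToyModel()

# SparseAdam only accepts sparse gradients, so the dense parameters
# get their own optimizer.
sparse_opt = torch.optim.SparseAdam(list(model.table.parameters()), lr=1e-2)
dense_opt = torch.optim.AdamW(model.head.parameters(), lr=1e-2)

ids = torch.randint(0, VOCAB, (8, 4))
targets = torch.randint(0, VOCAB, (8, 4))

before = model.table.weight.detach().clone()

logits = model(ids)
loss = nn.functional.cross_entropy(logits.reshape(-1, VOCAB), targets.reshape(-1))
loss.backward()

# Record whether the table's gradient really is a sparse tensor.
grad_is_sparse = model.table.weight.grad.is_sparse

sparse_opt.step()
dense_opt.step()
sparse_opt.zero_grad()
dense_opt.zero_grad()

# Indices of the table rows that the optimizer step actually modified.
changed_rows = (model.table.weight.detach() != before).any(dim=1).nonzero().flatten()
```

In this sketch only the rows that were looked up in the batch get updated, which is presumably where the performance boost comes from when the table is large relative to the batch.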