arjvik 3 days ago

Isn't this just Mixture-of-Depths but for DiTs?

If so, what are the DiT specific changes that needed to be made?

yorwba 3 days ago

Mixture-of-Depths trains the model to choose different numbers of layers for different tokens, in order to reduce inference compute. This method is more like stochastic depth / layer dropout: whether the intermediate layers are skipped for a token is random, independent of the token's value, and they're only using it as a training optimization. As far as I can tell, during inference all tokens are always processed by all layers.
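A toy pure-Python sketch of the distinction, assuming each layer is a residual update so that a skipped layer acts as the identity. `make_layer`, `forward_stochastic_depth` and `forward_mod_block` are hypothetical names, not taken from the paper or any library. The first function skips all intermediate layers for a token at once based on a coin flip that never looks at the token, and only during training. The second is a simplified Mixture-of-Depths-style router (no router training, and no scaling of the block output by the router weight) that picks tokens by their value.

```python
import random
from typing import Callable, List

Vector = List[float]
Layer = Callable[[Vector], Vector]


def make_layer(offset: float) -> Layer:
    """Hypothetical toy residual layer: h -> h + offset (elementwise)."""
    def layer(h: Vector) -> Vector:
        return [v + offset for v in h]
    return layer


def forward_stochastic_depth(tokens: List[Vector], layers: List[Layer],
                             skip_prob: float, training: bool,
                             rng: random.Random) -> List[Vector]:
    """Per-token layer dropout used only as a training optimization.

    The skip decision is a coin flip drawn without looking at the token,
    so it is independent of the token value. At inference time
    (training=False) every token goes through every layer.
    """
    out = []
    for x in tokens:
        skip = training and rng.random() < skip_prob
        h = layers[0](x)
        if not skip:
            # Intermediate layers run only when the coin flip allows it.
            for layer in layers[1:-1]:
                h = layer(h)
        h = layers[-1](h)
        out.append(h)
    return out


def forward_mod_block(tokens: List[Vector], block: Layer,
                      router: Callable[[Vector], float],
                      capacity: int) -> List[Vector]:
    """Simplified Mixture-of-Depths-style routing for a single block.

    A router scores each token from its value and only the top-`capacity`
    tokens are processed; the rest bypass the block via the residual path.
    The choice is also made at inference, which is where compute is saved.
    """
    scores = [router(x) for x in tokens]
    ranked = sorted(range(len(tokens)), key=lambda i: scores[i], reverse=True)
    chosen = set(ranked[:capacity])
    return [block(x) if i in chosen else x for i, x in enumerate(tokens)]


if __name__ == "__main__":
    layers = [make_layer(o) for o in (1.0, 10.0, 100.0, 1000.0)]
    tokens = [[0.0], [0.0], [0.0]]
    rng = random.Random(0)
    # Training: some tokens randomly skip the two middle layers (-> 1001.0).
    print(forward_stochastic_depth(tokens, layers, 0.5, True, rng))
    # Inference: all tokens use all layers -> [[1111.0], [1111.0], [1111.0]]
    print(forward_stochastic_depth(tokens, layers, 0.5, False, rng))
    # MoD: the two highest-scoring tokens get processed
    # -> [[13.0], [1.0], [12.0]]
    print(forward_mod_block([[3.0], [1.0], [2.0]], make_layer(10.0),
                            lambda x: x[0], capacity=2))
```

The key difference shows up in the inputs each decision depends on: the stochastic-depth skip uses only `rng`, while the router uses `x`, so only the latter can save compute in a content-dependent way at inference.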