Train transformer attention layers more precisely by accounting for how paired weight matrices interact during updates.
Drop Compositional Muon into an existing training loop alongside your current optimizer to handle only the attention pairs.
Reproduce or extend research on partner-aware optimizers for language or vision transformer models.
Experiment with spectral normalization techniques applied to query-key and output-value matrix pairs.
Requires PyTorch. The demo in src/main.py runs a small transformer out of the box. The caller must manage momentum state and update all non-attention parameters with a separate optimizer.
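The division of labor described above can be sketched as follows. This is a minimal illustration, not the repository's code: the `TinyAttention` module, the `q_proj`/`k_proj`/`v_proj`/`o_proj` names, and the `split_parameters` helper are all hypothetical, and AdamW is only one possible choice for the separate optimizer. The actual calls to `cm_qk` and `cm_ov` are left as a comment because their exact signatures should be taken from the repository.

```python
import torch
from torch import nn

# Hypothetical naming convention for the four attention projections.
ATTN_NAMES = ("q_proj", "k_proj", "v_proj", "o_proj")


class TinyAttention(nn.Module):
    """Single-head attention block used only to illustrate parameter splitting."""

    def __init__(self, d_model: int = 16):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        weights = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
        return self.norm(x + self.o_proj(weights @ v))


def split_parameters(model: nn.Module):
    """Return (attention weights by name, list of every other parameter)."""
    attn, other = {}, []
    for name, param in model.named_parameters():
        parts = name.split(".")
        is_attn = len(parts) >= 2 and parts[-1] == "weight" and parts[-2] in ATTN_NAMES
        if is_attn:
            attn[name] = param
        else:
            other.append(param)
    return attn, other


torch.manual_seed(0)
model = TinyAttention()
attn_params, other_params = split_parameters(model)

# Momentum state is owned by the caller: one buffer per attention matrix.
momentum = {name: torch.zeros_like(p) for name, p in attn_params.items()}

# Everything that is not an attention matrix goes to an ordinary optimizer.
other_opt = torch.optim.AdamW(other_params, lr=1e-3)

x = torch.randn(2, 5, 16)
loss = model(x).pow(2).mean()
loss.backward()

# Attention pairs: this is where the repository's cm_qk / cm_ov would be
# called with the weights, their gradients, and the momentum buffers above.
other_opt.step()
other_opt.zero_grad()
for p in attn_params.values():
    p.grad = None  # other_opt does not own these, so clear them manually
```

Keeping the momentum buffers in a plain dictionary keyed by parameter name makes it easy to save and restore them alongside the separate optimizer's state.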
This repository is a Python implementation of Compositional Muon, a research-oriented method for training the attention layers of transformer models more precisely. Transformers are the model architecture behind most modern AI language and vision systems. Attention layers are a core component: they determine how different parts of an input relate to each other.
Training a transformer means repeatedly adjusting millions of numerical parameters based on how wrong the model's outputs are. Most training methods treat each parameter matrix independently when deciding how to update it. The problem with attention layers is that the model never actually sees the individual matrices: it only sees the result of multiplying pairs of them together. Compositional Muon accounts for this by using information about one matrix in a pair to shape the update applied to the other. The README calls this approach partner whitening.
The method builds on an existing optimizer called Muon, which applies a kind of spectral normalization to each gradient before using it. Compositional Muon extends that idea to the two matrix pairs that make up transformer attention: the query-key pair and the output-value pair. The update rule for each matrix is adjusted based on the geometry of its partner, so that the effective step size adapts to how stretched or compressed the partner matrix is.
The code provides two functions, cm_qk and cm_ov, which take the relevant weight matrices, their gradients, and momentum buffers as arguments and apply the update in place. The caller is responsible for managing momentum state. These two functions handle only the attention pairs; all other model parameters are updated with a separate optimizer of the caller's choice. A runnable demo in src/main.py shows a small transformer trained with this combination. The library requires PyTorch and is released under the Apache 2.0 license.
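Two of the ideas above can be checked numerically. The sketch below first shows that the query-key product is unchanged when the two matrices are transformed in opposite ways, which is why the individual matrices are not uniquely determined by what the model computes. It then shows the kind of spectral normalization Muon is built on: replacing a gradient with a matrix whose singular values are all 1. The classic cubic Newton-Schulz iteration is used here as an assumption for clarity; Muon implementations typically use a tuned polynomial instead, and the repository's partner whitening step is not reproduced. All variable and function names are illustrative.

```python
import torch

torch.manual_seed(0)
d_model, d_head = 8, 4
W_q = torch.randn(d_model, d_head, dtype=torch.float64)
W_k = torch.randn(d_model, d_head, dtype=torch.float64)

# 1) The attention scores depend only on the product W_q @ W_k.T.
#    Multiplying W_q by an invertible A and W_k by inverse(A).T leaves it unchanged.
A = torch.eye(d_head, dtype=torch.float64) + 0.1 * torch.randn(
    d_head, d_head, dtype=torch.float64
)
W_q_alt = W_q @ A
W_k_alt = W_k @ torch.linalg.inv(A).T
product_unchanged = torch.allclose(W_q @ W_k.T, W_q_alt @ W_k_alt.T)


# 2) Spectral normalization of a gradient via cubic Newton-Schulz iteration.
def newton_schulz_orthogonalize(G: torch.Tensor, steps: int = 30) -> torch.Tensor:
    """Approximate the orthogonal polar factor of G (all singular values -> 1)."""
    # The Frobenius norm bounds the spectral norm, so after scaling every
    # singular value lies in (0, 1], inside the iteration's convergence region.
    X = G / (G.norm() + 1e-12)
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ X.T @ X
    return X


G = torch.randn(d_model, d_head, dtype=torch.float64)  # stand-in for a gradient
O = newton_schulz_orthogonalize(G)

# Reference answer from an explicit SVD: U @ Vh keeps directions, drops scales.
U, S, Vh = torch.linalg.svd(G, full_matrices=False)
matches_polar = torch.allclose(O, U @ Vh, atol=1e-8)
singular_values = torch.linalg.svdvals(O)
```

The iteration avoids an explicit SVD, which is why variants of it are attractive inside an optimizer step; the SVD here serves only as a reference for checking the result.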
Verify against the repo before relying on details.