Notes on Distributed Training
This blog probes and develops the idea of distributed training of large models over heterogeneous devices. By distributed training, we will mean two things: first, that the units of compute are geographically distant from each other. Second, a stronger condition: we also want to be able to train large models over the commodity internet. This is (mostly) solved by noticing and exploiting the low-rank, compressible structure of the gradients of large models during training.
Nous Research's Psyche is probably the closest we have come to a true protocol for distributed training of LLMs. This work primarily builds on top of Psyche.
We shall focus in this blog primarily on the latter part of the phrase: "heterogeneous devices". By heterogeneous, I mean heterogeneous in VRAM (henceforth, also just "memory"). There is another interpretation, where VRAM is similar but speed differs, whose solution probably lies in smart scheduling of data streaming and update aggregation. When both systems are in place, i.e., when we can accommodate both mixed-VRAM devices and mixed speeds on those devices, we unlock the ability to train foundation models over phones, laptops, and other consumer devices.
Currently, it is impossible for open source to meaningfully compete with frontier labs due to a vast gap in available compute. Allowing end-consumer devices to participate meaningfully (by which I mean: their absence from the training loop would cause significant differences in the ability of the eventually trained model) would allow us to harness massive amounts of what I'd term latent compute scattered across the world, passively cooled by people at their homes, on their jogs, while they sleep, etc.
We can see three distinct tiers of dormant, heterogeneous, harnessable compute distributed across the world –
- The first layer lies in islands of small supercomputers deployed at universities, research labs, etc.
- A second, larger layer in tiny datacenters across the country.
- And a third, humongous layer waiting to be tapped in the hands of every Indian citizen.
This third layer already collectively possesses datacenter-scale FLOPs: idle, distributed, cooled, and networked by end-users. I believe it is possible to incrementally harness all three layers. There is a narrow opportunity today to make compute abundant and disrupt the economics of intelligence. Compute-cycles can become a fungible asset, as viable as labor and capital: a new currency, a benevolent alternative to malicious advertisements that turn humans into products.
If 1.17 billion smartphones (a rough count for just my country) — each conservatively modeled as an iPhone 8–class device at ~0.6 TFLOPS FP32—donate just one hour of idle compute per day, they collectively generate ~2.5 × 10¹² TFLOP-seconds of work, equivalent to roughly 10 million NVIDIA H100 GPU-hours every day, using hardware that is already paid for, cooled, distributed, and idle.
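This arithmetic is easy to sanity-check. A minimal sketch, taking the figures above (1.17 billion phones, 0.6 TFLOPS each, one hour a day) and assuming ~67 TFLOPS FP32 for an H100:

```rust
// Back-of-envelope check of the latent-compute claim. All inputs are the
// blog's assumptions; the ~67 TFLOPS FP32 H100 figure is an assumption too.
fn daily_tflop_seconds(phones: f64, tflops_per_phone: f64, seconds: f64) -> f64 {
    phones * tflops_per_phone * seconds
}

fn h100_gpu_hours(tflop_seconds: f64, h100_tflops: f64) -> f64 {
    tflop_seconds / (h100_tflops * 3600.0)
}

fn main() {
    let work = daily_tflop_seconds(1.17e9, 0.6, 3600.0);
    assert!((work / 2.5e12 - 1.0).abs() < 0.02); // ~2.5e12 TFLOP-seconds/day
    let hours = h100_gpu_hours(work, 67.0);
    println!("~{:.1e} H100 GPU-hours per day", hours); // on the order of 1e7
}
```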
Preliminaries and Formalisms
Let us first formally state the FFN, and what we mean by prefixes and suffixes, because that shall be the language we speak in for the rest of this piece.
(Nested FFN). Let $\mathrm{FFN}: \mathbb{R}^d \to \mathbb{R}^d$ be a feed-forward block with hidden dimension $h$. A nested FFN is a family of subnetworks $\{\mathrm{FFN}_t\}_{t=0}^{T}$ defined by granularities $h_t = h / 2^t$, where

$$\mathrm{FFN}_t(x) = W_{\text{down}}^{[:, :h_t]} \left( \mathrm{SiLU}\!\left(W_{\text{gate}}^{[:h_t, :]} x\right) \odot \left(W_{\text{up}}^{[:h_t, :]} x\right) \right)$$
The superscript notation $W^{[:h_t, :]}$ denotes the first $h_t$ rows; $W^{[:, :h_t]}$ denotes the first $h_t$ columns.
(Tier). A tier $t$ corresponds to granularity $h_t = h / 2^t$. Tier 0 is the full model; higher tiers use exponentially smaller prefixes.
(Prefix/Suffix Partition). For tier $t$, the prefix consists of hidden units $[0, h_t)$; the suffix consists of units $[h_t, h)$. Note that the suffix is empty for tier 0.
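These definitions are mechanical enough to sketch directly. In the snippet below, $h = 11008$ is an assumed LLaMA-7B-like hidden dimension:

```rust
// Granularity h_t = h / 2^t; prefix [0, h_t), suffix [h_t, h).
fn granularity(h: usize, tier: u32) -> usize {
    h >> tier // h / 2^t
}

fn partition(h: usize, tier: u32) -> (std::ops::Range<usize>, std::ops::Range<usize>) {
    let h_t = granularity(h, tier);
    (0..h_t, h_t..h) // (prefix, suffix)
}

fn main() {
    let (prefix, suffix) = partition(11008, 1);
    assert_eq!(prefix, 0..5504);
    assert_eq!(suffix, 5504..11008);
    let (p0, s0) = partition(11008, 0);
    assert_eq!(p0, 0..11008);
    assert!(s0.is_empty()); // tier 0: suffix empty
}
```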
Matryoshka Transformers and how to train them
Standard training minimizes expected loss over the full model:

$$\min_\theta \; \mathbb{E}_{(x, y) \sim \mathcal{D}} \left[ \mathcal{L}(f_\theta(x), y) \right]$$
MatFormer training modifies this to sample granularities:

$$\min_\theta \; \mathbb{E}_{t \sim p} \, \mathbb{E}_{(x, y) \sim \mathcal{D}} \left[ \mathcal{L}(f_{\theta_t}(x), y) \right]$$
where $p$ is a distribution over granularities. In Psyche, $p$ is determined implicitly by the hardware distribution of participating clients rather than explicit sampling.
The FFN block is the natural target for elastic width because:
Parameter dominance. In transformers with $L$ layers, model dimension $d$, and FFN expansion factor $m$ (typically $m \approx 4$), FFN parameters scale as $2 L m d^2$ versus attention's $4 L d^2$. At standard ratios, FFN constitutes roughly two-thirds of non-embedding parameters.
Structural independence. FFN blocks are position-wise: $y_i = \mathrm{FFN}(x_i)$ independently for each token position $i$. Width reduction does not require changes to sequence handling.
Clean slicing. The hidden dimension admits prefix slicing without architectural modification. Attention heads, by contrast, require either dropping entire heads or more complex per-head width reduction.
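To see the parameter dominance concretely, here is a rough count for an assumed 7B-class LLaMA-style config (d = 4096, L = 32, h = 11008, vocab 32000), counting three FFN matrices and four attention projections per layer:

```rust
// Rough parameter accounting for a 7B-class LLaMA-style model.
// All config values (d, layers, h, vocab) are assumptions for illustration.
fn ffn_share(d: u64, layers: u64, h: u64, vocab: u64) -> f64 {
    let ffn = layers * 3 * d * h;   // gate + up + down per layer
    let attn = layers * 4 * d * d;  // q, k, v, o projections
    let emb = 2 * vocab * d;        // input + output embedding matrices
    100.0 * ffn as f64 / (ffn + attn + emb) as f64
}

fn main() {
    let pct = ffn_share(4096, 32, 11008, 32000);
    println!("FFN share: {:.1}%", pct);
    assert!(pct > 60.0 && pct < 70.0); // consistent with the ~65% quoted for 7B
}
```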
Notice that when training at tier $t > 0$, the suffix weights (indices $h_t$ through $h$) are not in the computational graph. No forward pass touches them; no gradient flows to them.
The chain rule makes this explicit (but it should be intuitively obvious -- the forward pass doesn't know those suffix neurons exist):

$$\frac{\partial \mathcal{L}_t}{\partial W^{[h_t:, :]}} = \frac{\partial \mathcal{L}_t}{\partial y_t} \cdot \frac{\partial y_t}{\partial W^{[h_t:, :]}} = \frac{\partial \mathcal{L}_t}{\partial y_t} \cdot 0 = 0$$
To preserve this isolation while aggregating across tiers, we must normalize by contributing peers. The prefix receives gradients from all clients; the suffix receives gradients only from tier-0 clients. Each region is averaged over its actual contributors.
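A minimal sketch of this per-region normalization, with made-up gradients and granularities (not the Psyche implementation):

```rust
// Tier-aware aggregation: each coordinate is averaged over the peers that
// actually computed it. Each peer reports (gradient over its prefix, h_t).
fn aggregate(grads: &[(Vec<f64>, usize)], h: usize) -> Vec<f64> {
    let mut sum = vec![0.0; h];
    let mut count = vec![0usize; h];
    for (g, h_t) in grads {
        for i in 0..*h_t {
            sum[i] += g[i];
            count[i] += 1;
        }
    }
    sum.iter()
        .zip(&count)
        .map(|(s, &c)| if c > 0 { s / c as f64 } else { 0.0 })
        .collect()
}

fn main() {
    // one tier-0 peer (full h = 4) and two tier-1 peers (h/2 = 2)
    let peers = vec![
        (vec![4.0, 4.0, 2.0, 2.0], 4),
        (vec![1.0, 1.0], 2),
        (vec![1.0, 1.0], 2),
    ];
    let g = aggregate(&peers, 4);
    assert_eq!(g, vec![2.0, 2.0, 2.0, 2.0]); // prefix averaged /3, suffix /1
}
```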
Our aim, now, is to verify that magnitude differences between tiers don't cause problems. Remember that sign-SGD discards magnitude – a gradient of +1000 becomes +1, same as a gradient of +0.01.
We don't expect dramatic savings for small models, since embeddings dominate the parameter count (although there are possible methods to slice this as well). At 20M, embeddings are 67% of parameters; tier-1 saves only ~9%.
At production scale, the picture changes:
| Model Size | FFN % | Tier-1 Savings | Tier-2 Savings |
|---|---|---|---|
| 1B | 55% | ~28% | ~41% |
| 7B | 65% | ~32% | ~48% |
| 70B | 70% | ~35% | ~52% |
Note the pattern: savings scale with FFN percentage. At 7B and above, tier-1 saves ~32% memory. That is the difference between "fits on a 3090" and "does not." Bandwidth follows similarly. After DisTrO compression:
- Tier-0: 351 KB per FFN layer
- Tier-1: 176 KB per FFN layer
Neuronal Specialization
Claim. Under MatFormer training, prefix neurons converge to representations that are useful independently of suffix neurons.
Three mechanisms enforce this:
The first is that the loss directly penalizes the model in a way that breaks permutation invariance among the neurons. When granularity $h_t$ is sampled, only $\mathrm{FFN}_t$ participates in the forward pass. Any parameter update that degrades $\mathrm{FFN}_t$'s performance is penalized by $\mathcal{L}_t$.
Formally: let $\theta_{\text{pre}}$ denote weights in the smallest prefix $\mathrm{FFN}_T$. These weights are in the computational graph for all granularities. An update that hurts $\mathrm{FFN}_T$ will be punished whenever $h_T$ is sampled.
Next, note that the early neurons receive many times more gradient signal than the late ones. With $T + 1$ granularities and uniform sampling $p(t) = \tfrac{1}{T+1}$:
| Neuron Range | Active in Granularities | Gradient Frequency |
|---|---|---|
| $[0, h_T)$ | All | $1$ |
| $[h_1, h_0)$ | $t = 0$ only | $\tfrac{1}{T+1}$ |
Prefix neurons receive more gradient updates than suffix neurons. This is a counting argument from the training rule.
Decompose the full model output:

$$\mathrm{FFN}_0(x) = \mathrm{FFN}_t(x) + \underbrace{\left( \mathrm{FFN}_0(x) - \mathrm{FFN}_t(x) \right)}_{\text{suffix contribution}}$$
If $\mathrm{FFN}_t$ is trained to be performant in isolation, the suffix learns a residual correction. The objectives are compatible, not adversarial: larger models have strictly more capacity and can represent $\mathrm{FFN}_t$'s function plus refinements.
There are a few pertinent questions here.
In a standard FFN, hidden neurons are exchangeable. For "the first $h_t$" to mean anything (in that we are allowed to slice them later), we must instruct the model to privilege a particular ordering. Shaping the structure of the neurons via the loss is not a new idea: it was done in 2014 as nested dropout, and in 2023 as feature-sieve networks (although they use a gating mechanism, which simulates something similar to what we achieve by weighting losses while leaving some parts of the network untouched).
In vanilla training, the loss is invariant under permutation of hidden units (with corresponding permutation of rows and columns). Call this the permutation symmetry of the FFN. MatFormer training introduces terms where $\mathcal{L}_t$ uses only indices $[0, h_t)$. These terms are not permutation-invariant: swapping neuron $0$ with neuron $h - 1$ changes which subnetworks contain each neuron. Once smaller subnetworks appear in the objective, the symmetry is broken. Position determines participation in the loss landscape. This is the same identifiability mechanism as nested dropout (Rippel et al., 2014), which recovers PCA-like orderings in autoencoders by training with random truncation.

Multi-objective optimization often suffers from interference. In some sense, we can expect $\mathcal{L}_0$ and $\mathcal{L}_t$ to fight for gradients.
There are two ways to think about this. First, I'll explain the one I'm not entirely convinced by; then, I'll explain the one that worked for me.

First explanation: the relationship is nested, not arbitrary. Let $\mathcal{F}_t$ denote the function class representable by $\mathrm{FFN}_t$. By construction:

$$\mathcal{F}_T \subset \mathcal{F}_{T-1} \subset \cdots \subset \mathcal{F}_0$$

This is strict inclusion: $\mathrm{FFN}_0$ can represent any function $\mathrm{FFN}_t$ can, plus functions requiring the suffix. So, if $\mathrm{FFN}_t$ converges to some $f^*$, then $\mathrm{FFN}_0$ can represent $f^*$ exactly (using suffix weights of zero) or improve upon it. The extra capacity is a refinement channel, not a competing objective. Compare to universally slimmable networks (Yu et al., 2019) and once-for-all (Cai et al., 2020), which establish that nested width training converges without destructive interference when the relationship is hierarchical.

The one that made it click for me was the second explanation: we can see it as a residual stream, i.e., the suffix becomes a "correction", a residual, to the prefix stream. The residual decomposition here is:

$$\mathrm{FFN}_0(x) = \mathrm{FFN}_t(x) + \left( \mathrm{FFN}_0(x) - \mathrm{FFN}_t(x) \right)$$
As written, that's almost tautological -- it's just adding and subtracting the same thing. The insight is in how training shapes each term. When we sample tier 1 (the smaller width), only $\mathrm{FFN}_1$ is in the computational graph. The loss forces $\mathrm{FFN}_1$ to be a good model by itself. It must predict well using only the prefix neurons. When we sample tier 0 (full width), $\mathrm{FFN}_0$ is trained. But here's the key: the prefix weights are shared. The weights that define $\mathrm{FFN}_1$ are literally the same weights used in the prefix of $\mathrm{FFN}_0$. So what can the suffix weights learn? They can't contradict what $\mathrm{FFN}_1$ does -- those weights are fixed by the shared prefix. The suffix can only add to what the prefix computes.
Concrete example:
Say $\mathrm{FFN}_1$ learns to predict "cat" with 70% confidence on some image. When $\mathrm{FFN}_0$ runs on the same image:
- The prefix neurons compute the same thing (they're identical weights)
- The suffix neurons can push that 70% up to 85%, or refine the prediction
The suffix learns: "given what the prefix already figured out, what correction improves the output?"
Why this prevents fighting:
In standard multi-task learning, Task A might want a weight $w_{ij}$ to increase while Task B wants it to decrease. They fight.
Here, the "tasks" are nested:
- $\mathrm{FFN}_1$'s task: predict well with prefix only
- $\mathrm{FFN}_0$'s task: predict well with prefix + suffix

$\mathrm{FFN}_0$ has strictly more capacity. If $\mathrm{FFN}_1$ finds a good solution, $\mathrm{FFN}_0$ can represent that exact solution (set the suffix contribution to zero) and then optionally improve. There's no dimension along which they must disagree.
We may think of this as a residual network (ResNet). The suffix computes:

$$r(x) = W_{\text{down}}^{[:, h_t:]} \left( \mathrm{SiLU}\!\left(W_{\text{gate}}^{[h_t:, :]} x\right) \odot \left(W_{\text{up}}^{[h_t:, :]} x\right) \right)$$
This is the "refinement" or "correction" term. If the prefix already solved the problem, this term can be small. If the prefix got it mostly right but made systematic errors, the suffix can learn to fix those errors. This is why larger MatFormer slices perform better—they have more capacity for corrections—while smaller slices remain functional standalone models.
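The additivity underlying this residual view can be checked numerically: the full gated-FFN output is exactly the prefix output plus the suffix term, because the down-projection sums over hidden units. A tiny sketch with made-up weights (d = 2, h = 4, tier 1):

```rust
// Numeric check that the nested FFN is exactly additive:
// FFN_0(x) = FFN_1(x) + suffix(x). SwiGLU: down( silu(gate·x) ⊙ (up·x) ).
fn silu(x: f64) -> f64 {
    x / (1.0 + (-x).exp())
}

// FFN restricted to hidden units [lo, hi)
fn ffn_slice(gate: &[[f64; 2]; 4], up: &[[f64; 2]; 4], down: &[[f64; 4]; 2],
             x: [f64; 2], lo: usize, hi: usize) -> [f64; 2] {
    let mut y = [0.0; 2];
    for j in lo..hi {
        let g = silu(gate[j][0] * x[0] + gate[j][1] * x[1]);
        let u = up[j][0] * x[0] + up[j][1] * x[1];
        for o in 0..2 {
            y[o] += down[o][j] * g * u;
        }
    }
    y
}

fn main() {
    let gate = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.6], [0.2, 0.2]];
    let up = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-0.3, 0.7]];
    let down = [[0.2, -0.1, 0.4, 0.3], [0.6, 0.5, -0.2, 0.1]];
    let x = [1.5, -0.7];
    let full = ffn_slice(&gate, &up, &down, x, 0, 4);   // FFN_0
    let prefix = ffn_slice(&gate, &up, &down, x, 0, 2); // FFN_1
    let suffix = ffn_slice(&gate, &up, &down, x, 2, 4); // residual term
    for o in 0..2 {
        assert!((full[o] - (prefix[o] + suffix[o])).abs() < 1e-12);
    }
}
```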
On Implementation, and Psyche
Here, I want to talk about a few properties of note in this implementation. These are important building blocks of the overall Psyche pipeline, and hence I want to spend some time talking about how they interact with elastic training.
Prefix gradients are isolated when suffixes are training
We claim that, for tier $t > 0$, the gradient with respect to suffix weights is identically zero. To see this, let $y_t$ denote the tier-$t$ output. The suffix weights $W^{[h_t:, :]}$ do not participate in the computation of $y_t$. By the chain rule:

$$\frac{\partial \mathcal{L}}{\partial W^{[h_t:, :]}} = \frac{\partial \mathcal{L}}{\partial y_t} \cdot \frac{\partial y_t}{\partial W^{[h_t:, :]}} = 0$$

This is exact, not approximate. The suffix is not in the computational graph. In heterogeneous aggregation, tier-$t$ clients contribute zero to suffix gradient sums. Normalization must account for this by dividing by contributing peers only.
<<< under construction >>>
Sign-SGD, Magnitude Invariance, and a few quirks
Standard gradient aggregation is:

$$\bar{g} = \frac{1}{N} \sum_{i=1}^{N} g_i$$

This means that if $\lVert g_{\text{tier-0}} \rVert \gg \lVert g_{\text{tier-1}} \rVert$, tier-0 gradients dominate the sum.
Sign-SGD discards magnitude:

$$\bar{g} = \mathrm{sign}\!\left( \sum_{i=1}^{N} \mathrm{sign}(g_i) \right)$$

In other words, each client contributes directional information only. A gradient of $+1000$ and a gradient of $+0.01$ both contribute $+1$ to the sign vote. This eliminates magnitude imbalance between tiers.
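A few lines make the magnitude invariance concrete (a toy majority vote, not DisTrO itself):

```rust
// Sign-SGD's magnitude invariance: +1000.0 and +0.01 cast identical votes.
fn sign_vote(grads: &[f64]) -> f64 {
    grads.iter().map(|g| g.signum()).sum::<f64>().signum()
}

fn main() {
    assert_eq!(sign_vote(&[1000.0, 0.01, 0.5]), 1.0);
    // a huge tier-0 gradient cannot outvote two small opposing gradients
    assert_eq!(sign_vote(&[1000.0, -0.01, -0.5]), -1.0);
}
```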
Note that Psyche uses DisTrO, which combines momentum, sparsification, and sign quantization. The sign operation occurs after momentum accumulation.
For hidden dimension $h$, model dimension $d$, and tier $t$:

$$y = W_{\text{down}}^{[:, :h_t]} \left( \mathrm{SiLU}\!\left(W_{\text{gate}}^{[:h_t, :]} x\right) \odot \left(W_{\text{up}}^{[:h_t, :]} x\right) \right)$$

where $h_t = h / 2^t$ and $\odot$ is elementwise multiplication.
Weight slicing:
| Layer | Shape | Sliced Dimension |
|---|---|---|
| `gate_proj` | $h \times d$ | dim=0 (rows) |
| `up_proj` | $h \times d$ | dim=0 (rows) |
| `down_proj` | $d \times h$ | dim=1 (cols) |
Tier calculation (llama.rs:296-302):
```rust
match config.matformer_tier {
    0 => None,
    tier => {
        let divisor = 1_i64.checked_shl(tier as u32)?;
        Some(config.intermediate_size / divisor)
    }
}
```
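For illustration, a standalone version of the same calculation, with the tier passed as a parameter rather than read from the config (a hypothetical helper, not Psyche code):

```rust
// Standalone tier calculation: tier 0 means "no slicing", tier t > 0
// divides the intermediate size by 2^t.
fn tier_intermediate_size(intermediate_size: i64, tier: u32) -> Option<i64> {
    match tier {
        0 => None, // tier 0 uses the full width
        t => {
            let divisor = 1_i64.checked_shl(t)?;
            Some(intermediate_size / divisor)
        }
    }
}

fn main() {
    assert_eq!(tier_intermediate_size(11008, 0), None);
    assert_eq!(tier_intermediate_size(11008, 1), Some(5504));
    assert_eq!(tier_intermediate_size(11008, 2), Some(2752));
}
```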
Forward Pass
```mermaid
flowchart LR
    X["x ∈ ℝ^d"] --> Gate["W_gate[:h_t, :] · x"]
    X --> Up["W_up[:h_t, :] · x"]
    Gate --> SiLU["SiLU(·)"]
    SiLU --> Mul["⊙"]
    Up --> Mul
    Mul --> Down["W_down[:, :h_t] · (·)"]
    Down --> Y["y ∈ ℝ^d"]
```
The .narrow() operation creates views without copying, preserving gradient flow through the original weight tensors.
| Aspect | LLaMA | NanoGPT |
|---|---|---|
| Activation | SiLU | SiLU or ReLU² |
| MLP bias | Never | Optional, incompatible with tier > 0 |
ReLU² activation (nanogpt.rs:475-484):
```rust
if self.use_relu_squared {
    let gate = self.gate_proj.forward(xs).relu().square();
    self.down_proj.forward(&(gate * self.up_proj.forward(xs)))
}
```
Gradient Aggregation
Heterogeneous Path
When peer shapes differ, aggregation proceeds per-parameter:
```mermaid
flowchart TB
    subgraph Peers
        P0["Tier-0: g ∈ ℝ^h"]
        P1["Tier-1: g ∈ ℝ^{h/2}"]
        P2["Tier-1: g ∈ ℝ^{h/2}"]
    end
    P0 --> Align["align_matformer_prefix_grad()"]
    P1 --> Align
    P2 --> Align
    Align --> Sum["Σ aligned gradients"]
    Sum --> Norm["÷ contributing_peers"]
    Norm --> Sign["sign(·)"]
```
Alignment. Smaller gradients are zero-padded to the local shape. For down_proj, smaller gradients are sliced to match the prefix.
Normalization. Divide by the number of peers that contributed non-zero gradients to that parameter region.
Aggregation Formulas
For prefix (all tiers contribute):

$$\bar{g}_{\text{pre}} = \frac{1}{N} \sum_{i=1}^{N} g_i^{\text{pre}}$$

For suffix (tier-0 only):

$$\bar{g}_{\text{suf}} = \frac{1}{N_0} \sum_{i \,\in\, \text{tier-0}} g_i^{\text{suf}}$$
Implementation (distro.rs:889-896):
```rust
let normalized = if contributing_peers > 1 {
    combined / (contributing_peers as f64)
} else {
    combined
};
```
Schema Hash Canonicalization
Problem. A tier-1 client with sliced checkpoint reports intermediate_size = h/2. A tier-0 client reports intermediate_size = h. Naive hashing rejects one.
Solution. Canonicalize to tier-0 before hashing.
Algorithm (init.rs:1094-1130):

- If the checkpoint is sliced, restore `intermediate_size` to the base value: $h = h_t \cdot 2^t$
- Set `matformer_tier = 0`
- Store `matformer_base_intermediate_size` for downstream use
- Hash the canonical config
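A sketch of the canonicalization step, with a made-up two-field config struct standing in for the real one (field names follow the keys mentioned above):

```rust
// Canonicalize a possibly-sliced config to tier 0 before hashing, so that
// tier-0 and tier-1 clients agree on the schema. Struct is illustrative only.
#[derive(Clone, PartialEq, Debug)]
struct Config {
    intermediate_size: i64,
    matformer_tier: u32,
}

fn canonicalize(mut c: Config) -> Config {
    // invert the slicing: h = h_t * 2^tier
    c.intermediate_size <<= c.matformer_tier;
    c.matformer_tier = 0;
    c
}

fn main() {
    let tier0 = Config { intermediate_size: 11008, matformer_tier: 0 };
    let tier1 = Config { intermediate_size: 5504, matformer_tier: 1 };
    // both clients end up hashing the same canonical config
    assert_eq!(canonicalize(tier0.clone()), canonicalize(tier1));
}
```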
```mermaid
sequenceDiagram
    participant C as Client
    participant H as hash()
    C->>C: Load config (possibly sliced)
    C->>C: Canonicalize: restore base size, set tier=0
    C->>H: canonical_config
    H-->>C: schema_hash
    Note over C,H: All clients hash to same value
```
Verification
```shell
# Gradient isolation (suffix gradients = 0 for tier > 0)
cargo test -p psyche-modeling matformer_mlp_has_zero_tail_grads

# Gradient alignment (expand/slice operations)
cargo test -p psyche-modeling test_align_matformer_prefix_grad_expand_gate
cargo test -p psyche-modeling test_align_matformer_prefix_grad_slice_down
```
Memory and Bandwidth
Memory Savings by Scale
| Model | Embedding % | FFN % | Tier-1 | Tier-2 |
|---|---|---|---|---|
| 20M | 67% | 22% | ~9% | ~15% |
| 1B | 10% | 55% | ~28% | ~41% |
| 7B | 2% | 65% | ~32% | ~48% |
| 70B | <1% | 70% | ~35% | ~52% |
Savings scale with FFN fraction. At small scales, embeddings dominate.
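The table's pattern can be stated as a formula: tier $t$ drops a $(1 - 2^{-t})$ fraction of every FFN matrix, so savings are roughly FFN share times $(1 - 2^{-t})$. A quick check against the table:

```rust
// Savings ≈ FFN% × (1 - 2^-t): tier t keeps only a 2^-t prefix of the FFN.
fn savings(ffn_pct: f64, tier: u32) -> f64 {
    ffn_pct * (1.0 - 0.5f64.powi(tier as i32))
}

fn main() {
    assert!((savings(55.0, 1) - 28.0).abs() < 1.0); // 1B tier-1 ≈ 28%
    assert!((savings(65.0, 2) - 48.0).abs() < 1.0); // 7B tier-2 ≈ 48%
    assert!((savings(70.0, 2) - 52.0).abs() < 1.0); // 70B tier-2 ≈ 52%
}
```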
Bandwidth (7B Example)
Pre-compression:
- Tier-0 FFN gradient: $4096 \times 11008 \times 2$ bytes ≈ 90 MB
- Tier-1 FFN gradient: $4096 \times 5504 \times 2$ bytes ≈ 45 MB
Post-DisTrO (256× compression):
- Tier-0: 351 KB/layer
- Tier-1: 176 KB/layer
For 32 clients (8 tier-0, 24 tier-1) over 32 layers:
| Config | Bandwidth/Round |
|---|---|
| All tier-0 | 360 MB |
| Mixed | 225 MB |
| Savings | 37% |
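These round-trip numbers can be recomputed from the per-layer payloads quoted above (351 KB tier-0, 176 KB tier-1):

```rust
// Per-round bandwidth: clients × layers × per-layer payload, summed per tier.
fn round_mb(clients_t0: u32, clients_t1: u32, layers: u32) -> f64 {
    let t0 = clients_t0 as f64 * layers as f64 * 351.0; // KB
    let t1 = clients_t1 as f64 * layers as f64 * 176.0; // KB
    (t0 + t1) / 1000.0 // KB -> MB
}

fn main() {
    let all_t0 = round_mb(32, 0, 32);
    let mixed = round_mb(8, 24, 32);
    assert!((all_t0 - 360.0).abs() < 5.0); // ~360 MB
    assert!((mixed - 225.0).abs() < 5.0);  // ~225 MB
    let savings = 100.0 * (1.0 - mixed / all_t0);
    assert!((savings - 37.0).abs() < 1.0); // ~37%
}
```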
Prior Work
| Work | Contribution |
|---|---|
| Matryoshka Representation Learning (Kusupati et al., 2022) | Trains embedding prefixes as strong representations; direct predecessor |
| Nested Dropout (Rippel et al., 2014) | Shows truncation induces ordering; recovers PCA in linear case |
| Universally Slimmable Networks (Yu et al., 2019) | One set of weights, many widths; sandwich rule |
| Once-for-All (Cai et al., 2020) | Train once, specialize many subnets |
| DynaBERT (Hou et al., 2020) | Elastic width/depth for transformers |
| LayerDrop (Fan et al., 2020) | Depth analogue; structured dropout for subnet extraction |
| MatFormer (Devvrit et al., 2023) | Nested transformer FFN; elastic inference |
References
Devvrit et al. (2023). MatFormer: Nested Transformer for Elastic Inference. arXiv:2310.07707
Kusupati et al. (2022). Matryoshka Representation Learning. arXiv:2205.13147
Rippel et al. (2014). Learning Ordered Representations with Nested Dropout. ICML
Bernstein et al. (2018). signSGD: Compressed Optimisation for Non-Convex Problems. arXiv:1802.04434
Yu et al. (2019). Universally Slimmable Networks. ICCV
Cai et al. (2020). Once-for-All: Train One Network and Specialize it for Efficient Deployment. ICLR
Hou et al. (2020). DynaBERT: Dynamic BERT with Adaptive Width and Depth. NeurIPS
Fan et al. (2020). Reducing Transformer Depth on Demand with Structured Dropout. ICLR