plugyawn's blog(?)

Notes on Distributed Training

This blog probes and develops the idea of distributed training of large models over heterogeneous devices. By distributed training we will mean two things: first, that the units of compute are geographically distant from each other; and second, a stronger condition, that we want to be able to train large models over the commodity internet. The latter is (mostly) solved by noticing and exploiting the low-rank, compressible structure of the gradients of large models during training.

Nous Research's Psyche is probably the closest we have come to a true protocol for distributed training of LLMs. This work primarily builds on top of Psyche.

In this blog we shall focus primarily on the latter part of the phrase: "heterogeneous devices". By heterogeneous, I mean heterogeneous in VRAM (henceforth, also just "memory"). There is another interpretation, where VRAM is similar but speed differs, whose solution probably lies in smart scheduling of data streaming and update aggregation. When both systems are in place, i.e., if we can accommodate both mixed-VRAM devices and mixed speeds on those devices, we unlock the ability to train foundational models over phones, laptops, and other consumer devices.

Currently, it is impossible for open source to meaningfully compete with frontier labs due to a vast gap in available compute. Allowing end-consumer devices to participate meaningfully (by which I mean that their absence from the training loop would cause a significant difference in the abilities of the eventually trained model) would let us harness massive amounts of what I'd term latent compute scattered across the world, passively cooled by people at their homes, on their jogs, while they sleep, etc.

We can see three distinct tiers of dormant, heterogeneous, harnessable compute distributed across the world –

This third layer already collectively possesses datacenter-scale FLOPs: idle, distributed, cooled, and networked by end-users. I believe it is possible to incrementally harness all three stages. There is a narrow opportunity today to make compute abundant and disrupt the economics of intelligence. Compute-cycles can become a fungible asset, as viable as labor and capital – a new currency, a benevolent alternative to malicious advertisements that turn humans into products.

If 1.17 billion smartphones (a rough count for just my country) — each conservatively modeled as an iPhone 8–class device at ~0.6 TFLOPS FP32—donate just one hour of idle compute per day, they collectively generate ~2.5 × 10¹² TFLOP-seconds of work, equivalent to roughly 10 million NVIDIA H100 GPU-hours every day, using hardware that is already paid for, cooled, distributed, and idle.
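The arithmetic behind that estimate is worth making explicit. A quick sketch, assuming ~0.6 TFLOPS FP32 per phone and ~67 TFLOPS of FP32 throughput per H100 (both rough figures):

```rust
// Back-of-envelope: one idle hour per day from a large phone fleet,
// expressed in H100 GPU-hours. All figures are rough assumptions.
fn phone_fleet_h100_hours(phones: f64, phone_tflops: f64, hours_per_day: f64, h100_tflops: f64) -> f64 {
    // Total work per day, in TFLOP-seconds.
    let tflop_seconds = phones * phone_tflops * hours_per_day * 3600.0;
    // One H100-hour of work, in TFLOP-seconds.
    let h100_tflop_seconds_per_hour = h100_tflops * 3600.0;
    tflop_seconds / h100_tflop_seconds_per_hour
}

fn main() {
    // 1.17B phones, ~0.6 TFLOPS each, one idle hour per day, H100 at ~67 TFLOPS FP32.
    let h100_hours = phone_fleet_h100_hours(1.17e9, 0.6, 1.0, 67.0);
    println!("~{:.1} million H100 GPU-hours per day", h100_hours / 1e6);
}
```

With these assumptions the fleet works out to roughly ten million H100 GPU-hours per day, matching the figure above.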


Preliminaries and Formalisms

Let us first formally define the FFN, and what we mean by prefixes and suffixes, because that shall be the language we speak for the rest of this piece.

(Nested FFN). Let FFN: ℝ^d → ℝ^d be a feed-forward block with hidden dimension h. A nested FFN is a family of subnetworks {FFN_i}, i = 1, …, g, defined by granularities m_1 < m_2 < … < m_g = h, where

FFN_i(x) = W_down[:, :m_i] · σ(W_up[:m_i, :] · x)

The slicing notation W[:m_i, :] denotes the first m_i rows; W[:, :m_i] denotes the first m_i columns.

(Tier). A tier t ∈ {0, 1, …, T} corresponds to granularity m_t = h/2^t. Tier 0 is the full model; higher tiers use exponentially smaller prefixes.

(Prefix/Suffix Partition). For tier t, the prefix consists of hidden units [0, m_t); the suffix consists of units [m_t, h). Note that the suffix is empty for tier 0.
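To make the slicing concrete, here is a minimal sketch of a nested FFN forward pass (plain `Vec` matrices, ReLU standing in for σ; illustrative only, not Psyche's implementation):

```rust
// Minimal sketch of a nested FFN. W_up is h x d, W_down is d x h;
// tier t uses only the first m = m_t hidden units.
fn relu(v: f64) -> f64 { v.max(0.0) }

/// Forward pass using only the first `m` hidden units:
/// y = W_down[:, :m] . relu(W_up[:m, :] . x)
fn ffn_prefix(w_up: &[Vec<f64>], w_down: &[Vec<f64>], x: &[f64], m: usize) -> Vec<f64> {
    let d = x.len();
    // hidden = relu(W_up[:m, :] . x) -- only the first m rows are read.
    let hidden: Vec<f64> = (0..m)
        .map(|j| relu(w_up[j].iter().zip(x).map(|(w, xi)| w * xi).sum()))
        .collect();
    // y = W_down[:, :m] . hidden -- only the first m columns are read.
    (0..d)
        .map(|i| (0..m).map(|j| w_down[i][j] * hidden[j]).sum())
        .collect()
}

fn main() {
    let (d, h) = (2, 4);
    let w_up: Vec<Vec<f64>> = (0..h).map(|j| vec![0.1 * (j as f64 + 1.0); d]).collect();
    let w_down: Vec<Vec<f64>> = (0..d).map(|_| (0..h).map(|j| 0.5 - 0.1 * j as f64).collect()).collect();
    let x = vec![1.0, 2.0];
    // Tier 1 (m = h/2) reads exactly the same prefix weights as tier 0.
    println!("tier 0: {:?}", ffn_prefix(&w_up, &w_down, &x, h));
    println!("tier 1: {:?}", ffn_prefix(&w_up, &w_down, &x, h / 2));
}
```

The point of the sketch: the tier-1 output is computed from exactly the same prefix weights that tier 0 reads; nothing is copied or retrained, the suffix is simply never touched.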


Matryoshka Transformers and how to train them

Standard training minimizes expected loss over the full model:

𝓛_standard(θ) = 𝔼_{(x,y)∼𝒟}[ L(M(θ; x), y) ]

MatFormer training modifies this to sample granularities:

𝓛_MatFormer(θ) = 𝔼_{i∼p(i)} 𝔼_{(x,y)∼𝒟}[ L(M_i(θ; x), y) ]

where p(i) is a distribution over granularities. In Psyche, p(i) is determined implicitly by the hardware distribution of participating clients rather than explicit sampling.

The FFN block is the natural target for elastic width because:

  1. Parameter dominance. In transformers with L layers, hidden dimension d, and FFN expansion factor r (typically r = 4), FFN parameters scale as 2Lrd² versus attention's 4Ld². At standard ratios, the FFN constitutes 2/3 of total parameters.

  2. Structural independence. FFN blocks are position-wise: FFN(X) = [FFN(x_1), …, FFN(x_n)]. Width reduction does not require changes to sequence handling.

  3. Clean slicing. The hidden dimension admits prefix slicing without architectural modification. Attention heads, by contrast, require either dropping entire heads or more complex per-head width reduction.
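The 2/3 figure in point 1 is just counting: per layer, the up and down projections contribute 2rd² parameters against attention's 4d² for the Q, K, V, and O projections. A quick check, ignoring embeddings and norms:

```rust
// Check the 2/3 claim: per layer, FFN = 2*r*d^2 params (up + down),
// attention = 4*d^2 (Q, K, V, O). Everything here is in units of d^2,
// so L and d cancel out of the fraction.
fn ffn_fraction(r: f64) -> f64 {
    let ffn = 2.0 * r;
    let attn = 4.0;
    ffn / (ffn + attn)
}

fn main() {
    // r = 4 gives 8 / 12 = 2/3.
    println!("r = 4 -> FFN fraction = {:.3}", ffn_fraction(4.0)); // 0.667
}
```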

Notice that when training at tier t, the suffix weights (hidden units m_t through h) are not in the computational graph. No forward pass touches them; no gradient flows to them.

The chain rule makes this explicit (but it should be intuitively obvious -- the forward passes don't know those suffix neurons exist):

∂L/∂W_suffix = (∂L/∂y) · (∂y/∂W_suffix) = (∂L/∂y) · 0 = 0

To preserve this isolation while aggregating across tiers, we must normalize by the number of contributing peers. The prefix receives gradients from all clients; the suffix receives gradients only from tier-0 clients. Each region is averaged over its actual contributors.
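A minimal sketch of that region-wise averaging (a hypothetical helper, not Psyche's actual aggregation code): each index is divided by the number of peers whose gradient actually covered it.

```rust
// Sketch: aggregate gradients over the FFN hidden dimension from mixed tiers.
// Prefix entries are averaged over all peers; suffix entries only over the
// peers that actually computed them (tier-0 peers here).
fn aggregate(grads: &[Vec<f64>], h: usize) -> Vec<f64> {
    let mut sum = vec![0.0; h];
    let mut count = vec![0u32; h];
    for g in grads {
        // A tier-t peer's gradient only covers its prefix, so shorter
        // vectors simply never touch the suffix entries.
        for (i, v) in g.iter().enumerate() {
            sum[i] += v;
            count[i] += 1;
        }
    }
    sum.iter()
        .zip(&count)
        .map(|(s, &c)| if c > 0 { s / c as f64 } else { 0.0 })
        .collect()
}

fn main() {
    // One tier-0 peer (full h = 4) and two tier-1 peers (prefix h/2 = 2).
    let grads = vec![
        vec![1.0, 1.0, 4.0, 4.0], // tier 0
        vec![2.0, 2.0],           // tier 1
        vec![3.0, 3.0],           // tier 1
    ];
    // Prefix: (1 + 2 + 3) / 3 = 2.0; suffix: 4 / 1 = 4.0, not 4 / 3.
    println!("{:?}", aggregate(&grads, 4));
}
```

Dividing the suffix by all N peers instead would silently down-weight it by the fraction of tier-0 clients, which is exactly the failure mode the normalization avoids.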

Our aim, now, is to verify that magnitude differences between tiers don't cause problems. Remember that sign-SGD discards magnitude – a gradient of +1000 becomes +1, same as a gradient of +0.01.

We don't expect dramatic savings for small models, since embeddings dominate the parameter count (although there are possible methods to slice this as well). At 20M, embeddings are 67% of parameters; tier-1 saves only ~9%.

At production scale, the picture changes:

Model Size   FFN %   Tier-1 Savings   Tier-2 Savings
1B           55%     ~28%             ~41%
7B           65%     ~32%             ~48%
70B          70%     ~35%             ~52%

Note the pattern: savings scale with FFN percentage. At 7B and above, tier-1 saves ~32% memory. That is the difference between "fits on a 3090" and "does not." Bandwidth follows similarly. After DisTrO compression:

Tier-0: 351 KB per FFN layer
Tier-1: 176 KB per FFN layer

Neuronal Specialization

Claim. Under MatFormer training, prefix neurons converge to representations that are useful independently of suffix neurons.

Three mechanisms enforce this:

The first is that we penalize the model directly via the loss to break permutation invariance in the neurons. When granularity i is sampled, only M_i participates in the forward pass. Any parameter update that degrades M_i's performance is penalized by L(M_i(x), y).

Formally: let θ_prefix denote the weights in [0, m_1). These weights are in the computational graph for all granularities. An update Δθ_prefix that hurts M_1 will be punished whenever i = 1 is sampled.

Next, note that early neurons receive many times more gradient updates than later ones. With g granularities and uniform sampling p(i) = 1/g:

Neuron Range      Active in Granularities   Gradient Frequency
[0, m_1)          all g                     100%
[m_1, m_2)        {2, …, g}                 (g−1)/g
[m_{g−1}, m_g)    {g} only                  1/g

Neurons in the first slice receive g× more gradient updates than those in the last. This is a counting argument that follows directly from the training rule.

Decompose the full model output:

M_g(x) = M_1(x) + (M_g(x) − M_1(x)), where the second term is the suffix contribution.

If M_1 is trained to be performant in isolation, the suffix learns a residual correction. The objectives are compatible, not adversarial: larger models have strictly more capacity and can represent M_1's function plus refinements.


There are a few pertinent questions here.

In a standard FFN, hidden neurons are exchangeable. For "the first m" to mean anything (that is, for us to be allowed to slice them later), we want to instruct the model to follow a certain permutation. Affecting the structure of the neurons via the loss is not a new idea: it was done in 2014 as nested dropout, and in 2023 as feature-sieve networks (although those use a gating mechanism, which simulates something similar to what we achieve by weighting losses and leaving some parts of the network untouched).

In vanilla training, the loss is invariant under permutation of hidden units (with corresponding permutation of W_up rows and W_down columns). Call this the permutation symmetry of the FFN. MatFormer training introduces terms L(M_i(x), y) where M_i uses only indices [0, m_i). These terms are not permutation-invariant: swapping neuron 0 with neuron h−1 changes which subnetworks contain each neuron. Once smaller subnetworks appear in the objective, the symmetry is broken. Position determines participation in the loss landscape. This is the same identifiability mechanism as nested dropout (Rippel et al., 2014), which recovers PCA-like orderings in autoencoders by training with random truncation.

Multi-objective optimization often suffers from interference. In some sense, we can expect M_1 and M_g to fight for gradients.

There are two ways to think about this. First, I'll explain the one I'm not entirely convinced by; then, the one that worked for me.

First explanation: the relationship is nested, not arbitrary. Let ℱ_i denote the function class representable by M_i. By construction, ℱ_1 ⊂ ℱ_2 ⊂ … ⊂ ℱ_g. This is strict inclusion: M_g can represent any function M_1 can, plus functions requiring the suffix. So if M_1 converges to some f*, then M_g can represent f* exactly (using suffix weights of zero) or improve upon it. The extra capacity is a refinement channel, not a competing objective. Compare to universally slimmable networks (Yu et al., 2019) and once-for-all (Cai et al., 2020), which establish that nested width training converges without destructive interference when the relationship is hierarchical.

The one that made it click for me is the second explanation: we can see the suffix as a residual stream, i.e., the suffix becomes a "correction", a residual, to the prefix stream. The residual decomposition here is:

M_g(x) = M_1(x) + (M_g(x) − M_1(x))

As written, that's almost tautological: it's just adding and subtracting the same thing. The insight is in how training shapes each term. When we sample granularity 1 (smallest width), only M_1 is in the computational graph. The loss L(M_1(x), y) forces M_1 to be a good model by itself; it must predict well using only the prefix neurons. When we sample granularity g (full width), M_g is trained. But here's the key: the prefix weights are shared. The weights that define M_1 are literally the same weights used in the prefix of M_g. So what can the suffix weights learn? They can't contradict what M_1 does, since those weights are fixed by the shared prefix. The suffix can only add to what the prefix computes.

Concrete example:

Say M_1 learns to predict "cat" with 70% confidence on some image. When M_g runs on the same image, the prefix computes exactly what M_1 computed, because the weights are shared. The suffix learns: "given what the prefix already figured out, what correction improves the output?"

Why this prevents fighting:

In standard multi-task learning, Task A might want weight w to increase while Task B wants it to decrease. They fight.

Here, the "tasks" are nested:

M_g has strictly more capacity. If M_1 finds a good solution, M_g can represent that exact solution (set the suffix contribution to zero) and then optionally improve. There's no dimension along which they must disagree.

We may think of this as a residual network (ResNet): the suffix computes

suffix contribution = M_g(x) − M_1(x)

This is the "refinement" or "correction" term. If the prefix already solved the problem, this term can be small. If the prefix got it mostly right but made systematic errors, the suffix can learn to fix those errors. This is why larger MatFormer slices perform better—they have more capacity for corrections—while smaller slices remain functional standalone models.


On Implementation, and Psyche

Here, I want to talk about a few properties of note in this implementation. These are important building blocks of the overall Psyche pipeline, and hence I want to spend some time on how they interact with elastic training.

Suffix weights are isolated when prefixes are training

We claim that, for tier t > 0, the gradient with respect to the suffix weights is identically zero. To see this, let y = FFN_t(x) denote the tier-t output. The suffix weights W_suffix = W_up[m_t:, :] do not participate in the computation of y. By the chain rule:

∂L/∂W_suffix = (∂L/∂y) · (∂y/∂W_suffix) = (∂L/∂y) · 0 = 0

This is exact, not approximate: the suffix is not in the computational graph. In heterogeneous aggregation, tier-t clients contribute zero to suffix gradient sums. Normalization must account for this by dividing by contributing peers only.


<<< under construction >>>

Sign-SGD, Magnitude Invariance, and a few quirks

Standard gradient aggregation is:

g = (1/N) Σ_{i=1}^N g_i,    θ ← θ − η · g

This means that if ‖g_tier-0‖ ≫ ‖g_tier-1‖, tier-0 gradients dominate the sum.

Sign-SGD discards magnitude. In the majority-vote form (Bernstein et al., 2018), each client transmits only the sign of its gradient, and the update is the sign of the vote:

θ ← θ − η · sign( Σ_{i=1}^N sign(g_i) )

In other words, each client contributes directional information only. A gradient of +10⁶ and a gradient of +10⁻⁶ both contribute +1 to the sign vote. This eliminates magnitude imbalance between tiers.
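A sketch of the majority vote on a single parameter vector (illustrative only; as noted below, DisTrO applies the sign after momentum accumulation, which this toy skips):

```rust
// Majority-vote sign aggregation: each client contributes only the sign
// of its gradient, so one huge-magnitude peer cannot dominate the vote.
fn majority_sign(grads: &[Vec<f64>]) -> Vec<f64> {
    let n = grads[0].len();
    (0..n)
        .map(|i| {
            // Sum of per-client signs, then the sign of that sum.
            let votes: f64 = grads.iter().map(|g| g[i].signum()).sum();
            // Explicit comparison: f64::signum maps +0.0 to 1.0, which
            // would turn a tied vote into an update.
            if votes > 0.0 { 1.0 } else if votes < 0.0 { -1.0 } else { 0.0 }
        })
        .collect()
}

fn main() {
    // One huge positive gradient vs. two tiny negative ones:
    // the two small peers outvote the big one.
    let grads = vec![vec![1e6], vec![-1e-6], vec![-1e-6]];
    println!("{:?}", majority_sign(&grads)); // [-1.0]
}
```

Had we summed raw gradients before taking the sign, the +10⁶ peer would win every coordinate it touches; the per-client sign is what makes the vote magnitude-blind.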

Note that Psyche uses DisTrO, which combines momentum, sparsification, and sign quantization. The sign operation occurs after momentum accumulation.


For hidden dimension h, model dimension d, and tier t:

SwiGLU_t(x) = W_down[:, :h_t] · (SiLU(W_gate[:h_t, :] · x) ⊙ W_up[:h_t, :] · x)

where h_t = h/2^t and ⊙ is elementwise multiplication.

Weight slicing:

Layer       Shape    Sliced Dimension
gate_proj   [h, d]   dim=0 (rows)
up_proj     [h, d]   dim=0 (rows)
down_proj   [d, h]   dim=1 (cols)

Tier calculation (llama.rs:296-302):

match config.matformer_tier {
    0 => None,
    tier => {
        let divisor = 1_i64.checked_shl(tier as u32)?;
        Some(config.intermediate_size / divisor)
    }
}

Forward Pass

flowchart LR
    X["x ∈ ℝ^d"] --> Gate["W_gate[:h_t, :] · x"]
    X --> Up["W_up[:h_t, :] · x"]
    Gate --> SiLU["SiLU(·)"]
    SiLU --> Mul["⊙"]
    Up --> Mul
    Mul --> Down["W_down[:, :h_t] · (·)"]
    Down --> Y["y ∈ ℝ^d"]

The .narrow() operation creates views without copying, preserving gradient flow through the original weight tensors.

Aspect       LLaMA   NanoGPT
Activation   SiLU    SiLU or ReLU²
MLP bias     Never   Optional, incompatible with tier > 0

ReLU² activation (nanogpt.rs:475-484):

if self.use_relu_squared {
    let gate = self.gate_proj.forward(xs).relu().square();
    self.down_proj.forward(&(gate * self.up_proj.forward(xs)))
}

Gradient Aggregation

Heterogeneous Path

When peer shapes differ, aggregation proceeds per-parameter:

flowchart TB
    subgraph Peers
        P0["Tier-0: g ∈ ℝ^h"]
        P1["Tier-1: g ∈ ℝ^{h/2}"]
        P2["Tier-1: g ∈ ℝ^{h/2}"]
    end
    P0 --> Align["align_matformer_prefix_grad()"]
    P1 --> Align
    P2 --> Align
    Align --> Sum["Σ aligned gradients"]
    Sum --> Norm["÷ contributing_peers"]
    Norm --> Sign["sign(·)"]

Alignment. Smaller gradients are zero-padded to the local shape. For down_proj, smaller gradients are sliced to match the prefix.

Normalization. Divide by the number of peers that contributed non-zero gradients to that parameter region.
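A sketch of the zero-padding step for row-sliced layers (gate_proj/up_proj). The real function is align_matformer_prefix_grad(); this helper is only an illustration of the idea:

```rust
// Sketch of prefix alignment for row-sliced layers: a tier-t gradient
// covering only the first m rows is zero-padded up to the full height
// before summation, so the missing suffix rows contribute nothing.
fn pad_rows(grad: Vec<Vec<f64>>, full_rows: usize) -> Vec<Vec<f64>> {
    let cols = grad[0].len();
    let mut out = grad;
    while out.len() < full_rows {
        // Suffix rows this peer never computed: explicit zeros.
        out.push(vec![0.0; cols]);
    }
    out
}

fn main() {
    // Gradient for the first m = 1 of h = 2 rows, from a tier-1 peer.
    let tier1 = vec![vec![1.0, 2.0]];
    println!("{:?}", pad_rows(tier1, 2)); // [[1.0, 2.0], [0.0, 0.0]]
}
```

The zeros are why the per-region peer count matters: a padded row looks like a contribution of zero, so dividing by all peers would shrink the suffix update unless the normalization counts actual contributors only.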

Aggregation Formulas

For prefix (all tiers contribute):

g_prefix = (1/N) Σ_{i=1}^N g_i^(prefix)

For the suffix (only the N_0 tier-0 peers contribute):

g_suffix = (1/N_0) Σ_{j : t_j = 0} g_j^(suffix)

Implementation (distro.rs:889-896):

let normalized = if contributing_peers > 1 {
    combined / (contributing_peers as f64)
} else {
    combined
};

Schema Hash Canonicalization

Problem. A tier-1 client with sliced checkpoint reports intermediate_size = h/2. A tier-0 client reports intermediate_size = h. Naive hashing rejects one.

Solution. Canonicalize to tier-0 before hashing.

Algorithm (init.rs:1094-1130):

  1. If checkpoint is sliced, restore intermediate_size to its base value: h_base = h_sliced · 2^t
  2. Set matformer_tier = 0
  3. Store matformer_base_intermediate_size for downstream use
  4. Hash the canonical config
sequenceDiagram
    participant C as Client
    participant H as hash()

    C->>C: Load config (possibly sliced)
    C->>C: Canonicalize: restore base size, set tier=0
    C->>H: canonical_config
    H-->>C: schema_hash
    Note over C,H: All clients hash to same value
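The canonicalization step above can be sketched as follows (the struct and its fields are illustrative stand-ins, though intermediate_size and matformer_tier mirror the config fields named in the algorithm):

```rust
// Sketch of canonicalizing a possibly-sliced config before hashing.
// Field names are illustrative, not Psyche's actual structs.
#[derive(Debug, PartialEq)]
struct Config {
    intermediate_size: i64,
    matformer_tier: u32,
}

fn canonicalize(mut cfg: Config) -> Config {
    // Restore the tier-0 hidden size: h_base = h_sliced * 2^t.
    cfg.intermediate_size <<= cfg.matformer_tier;
    cfg.matformer_tier = 0;
    cfg
}

fn main() {
    let tier0 = Config { intermediate_size: 4096, matformer_tier: 0 };
    let tier1 = Config { intermediate_size: 2048, matformer_tier: 1 };
    // After canonicalization, both clients would hash identical configs.
    assert_eq!(canonicalize(tier1), canonicalize(tier0));
    println!("canonical configs match");
}
```

Because every client canonicalizes before hashing, a sliced tier-1 checkpoint and a full tier-0 checkpoint produce the same schema hash, which is the whole point of the step.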

Verification

# Gradient isolation (suffix gradients = 0 for tier > 0)
cargo test -p psyche-modeling matformer_mlp_has_zero_tail_grads

# Gradient alignment (expand/slice operations)
cargo test -p psyche-modeling test_align_matformer_prefix_grad_expand_gate
cargo test -p psyche-modeling test_align_matformer_prefix_grad_slice_down

Memory and Bandwidth

Memory Savings by Scale

Model   Embedding %   FFN %   Tier-1   Tier-2
20M     67%           22%     ~9%      ~15%
1B      10%           55%     ~28%     ~41%
7B      2%            65%     ~32%     ~48%
70B     <1%           70%     ~35%     ~52%

Savings scale with FFN fraction. At small scales, embeddings dominate.

Bandwidth (7B Example)

Pre-compression:

Post-DisTrO (256× compression):

For 32 clients (8 tier-0, 24 tier-1) over 32 layers:

Config       Bandwidth/Round
All tier-0   360 MB
Mixed        225 MB
Savings      37%
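These round totals can be re-derived from the per-layer figures given earlier (351 KB and 176 KB per FFN layer post-DisTrO; MB here are decimal):

```rust
// Reproduce the mixed-fleet bandwidth numbers: per-round payload is
// clients x layers x per-layer KB, summed over the two tiers.
fn round_kb(tier0_clients: u64, tier1_clients: u64, layers: u64) -> u64 {
    // Post-DisTrO KB per FFN layer, from the text above.
    let (kb_t0, kb_t1) = (351, 176);
    layers * (tier0_clients * kb_t0 + tier1_clients * kb_t1)
}

fn main() {
    let all_t0 = round_kb(32, 0, 32); // 32 tier-0 clients, 32 layers
    let mixed = round_kb(8, 24, 32);  // 8 tier-0 + 24 tier-1 clients
    println!("all tier-0: ~{} MB", all_t0 / 1000);
    println!("mixed:      ~{} MB", mixed / 1000);
    println!("savings:    ~{:.0}%", 100.0 * (1.0 - mixed as f64 / all_t0 as f64));
}
```

Running the numbers gives ~359 MB vs. ~225 MB per round, i.e., the ~37% savings in the table.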

Prior Work

Work                                                         Contribution
Matryoshka Representation Learning (Kusupati et al., 2022)   Trains embedding prefixes as strong representations; direct predecessor
Nested Dropout (Rippel et al., 2014)                         Shows truncation induces ordering; recovers PCA in linear case
Universally Slimmable Networks (Yu et al., 2019)             One set of weights, many widths; sandwich rule
Once-for-All (Cai et al., 2020)                              Train once, specialize many subnets
DynaBERT (Hou et al., 2020)                                  Elastic width/depth for transformers
LayerDrop (Fan et al., 2020)                                 Depth analogue; structured dropout for subnet extraction
MatFormer (Devvrit et al., 2023)                             Nested transformer FFN; elastic inference

References

  1. Devvrit et al. (2023). MatFormer: Nested Transformer for Elastic Inference. arXiv:2310.07707

  2. Kusupati et al. (2022). Matryoshka Representation Learning. arXiv:2205.13147

  3. Rippel et al. (2014). Learning Ordered Representations with Nested Dropout. ICML

  4. Bernstein et al. (2018). signSGD: Compressed Optimisation for Non-Convex Problems. arXiv:1802.04434

  5. Yu et al. (2019). Universally Slimmable Networks. ICCV

  6. Cai et al. (2020). Once-for-All: Train One Network and Specialize it for Efficient Deployment. ICLR

  7. Hou et al. (2020). DynaBERT: Dynamic BERT with Adaptive Width and Depth. NeurIPS

  8. Fan et al. (2020). Reducing Transformer Depth on Demand with Structured Dropout. ICLR