Notes on Language Models
Towards a Theory for Large Language Models
from Across the Internet
A concrete theory for large language models must start with the phenomenon that suggested such large models could exist. There is some disagreement over which phenomenon that is, but arguably the first observation that hinted at it was what we call double descent -- the observation that as the number of model parameters grows relative to the amount of data, test error first rises near the interpolation threshold and then drops again as models grow into the highly overparameterized regime -- i.e., into the space where the data is underparameterized [Schaeffer et al., 2023].
The term double descent was coined in 2019 by a paper in PNAS [Belkin et al., 2019]. The term is in some ways a misnomer -- the double is not literal (we can, and have, observed one, three, or more descents!), and not all model-dataset pairs exhibit the phenomenon. In fact, double descent isn't limited to neural models; it appears in linear regression, trees, and boosting [Curth et al., 2023]. In this text, beyond the first section, when we say double descent we'll mean deep double descent, and clarify when we mean anything else. Most current research on double descent is empirical, and the current toolkits for explaining the phenomenon come chiefly from statistical physics.
double descent in linear regression
A first exposition of why double descent occurs can be gleaned by analysing ordinary linear regression in two parameterization regimes -- the underparameterized case and the overparameterized case.
underparameterized linear regression: This is the classical setting with more samples than features ($n > d$), where the least-squares parameters are given by
$$ \hat{\theta} = (X^\top X)^{-1} X^\top y, $$
the well-known result via the second moment matrix $X^\top X$, which is invertible when the $n \times d$ design matrix $X$ has full column rank. In the overparameterized case ($d > n$), $X^\top X$ is singular, and one typically takes the minimum-norm interpolating solution $\hat{\theta} = X^\top (X X^\top)^{-1} y$.
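A minimal numeric sketch of the two regimes, assuming a pseudoinverse (minimum-norm) fit -- a standard choice, though the text does not fix one:

```python
import numpy as np

rng = np.random.default_rng(0)

def min_norm_fit(X, y):
    # pinv gives the OLS solution when n > d and the
    # minimum-norm interpolating solution when d > n.
    return np.linalg.pinv(X) @ y

n, d_under, d_over = 50, 10, 200
w_true = rng.normal(size=d_over)
X = rng.normal(size=(n, d_over))
y = X @ w_true + 0.1 * rng.normal(size=n)

# Overparameterized fit (d > n): interpolates the training data.
w_hat = min_norm_fit(X, y)
train_mse = np.mean((X @ w_hat - y) ** 2)
print(train_mse)  # ~0: the min-norm solution interpolates

# Underparameterized fit (n > d): least squares leaves residual error.
X_u = X[:, :d_under]
w_u = min_norm_fit(X_u, y)
under_mse = np.mean((X_u @ w_u - y) ** 2)
print(under_mse)  # > 0: cannot interpolate
```

The interesting (empirical) part of double descent is what happens to *test* error as $d$ sweeps through $n$; the sketch above only shows the training-error side of the story.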
Language Models are secretly Reward Models
Models trained on data generated by humans inherit a wide variety of priorities and skillsets. Consider the following situation: while we do want the model to be aware of common coding mistakes, we want to steer it towards responses that are correct. That is, we want to be able to select desired responses from its wide knowledge base. This is the fundamental task of alignment.
existing methods: The most straightforward process is supervised finetuning on human demonstrations of good responses. The most famous algorithm, though, is reinforcement learning from human feedback, or RLHF. The idea is to fit a reward model to a dataset of human preferences and then use reinforcement learning to optimize the language model to produce responses that are assigned a high reward, without deviating too much from the original model. This occurs after the model has been put through extensive unsupervised pre-training on a large text dataset. RLHF in principle was first introduced by OpenAI in 2017, as reinforcement learning from human preferences, and was applied to language models in 2019.
Reinforcement Learning from Human Feedback (RLHF)
Brief Overview: roughly, the algorithm can be described as follows--
Collect a dataset $\mathcal{D}$ of prompts $x$ for the LM, typically containing some instructions or questions.
For each prompt $x$, collect a set of $N$ completions $y_1, \dots, y_N$ from the LM. To get sufficient variability, adjust the sampling temperature of the LM. So, we have $(x, y_1, \dots, y_N)$.
Use human annotators to rank each element of the set of completions, obtaining a dataset of preferences, $(x, y_{rank_1}, \dots, y_{rank_N})$.
Train a parameterized reward function $r_\phi(x, y)$, mapping prompt-completion pairs to scalars, on the collected preferences by minimizing the following loss:
$$ \mathcal{L}(\phi) = -\mathbb{E}_{(x,\, y_{rank_1}, \dots,\, y_{rank_N}) \sim \mathcal{D}}\left[\sum_{i=1}^{N}\log\frac{e^{r_\phi(x,\, y_{rank_i})}}{\sum_{j = i}^{N} e^{r_\phi(x,\, y_{rank_j})}}\right] $$
This loss function is inspired by the Bradley-Terry model for pairwise comparisons and by maximum-entropy inverse RL. Usually, this is implemented by parameterizing the reward function inside the LM itself with an additional linear layer. So, the mapping from $(x, y)$ to $r_\phi(x, y)$ is given by simply concatenating the sequences $x$ and $y$ and passing the embedding of the last token through a linear layer. (Why only the last token? Because the model is autoregressive with causal attention, the last token's embedding is the only one that attends to the entire sequence.)
Now, fine-tune the LM by viewing it as a policy $\pi_\theta$ and using reinforcement learning with the learned reward function as the reward. Some regularization is necessary here because the reward function is learned on a very small subset of data. As such, a reference policy $\pi_{ref}$ (usually the pretrained LM before fine-tuning) is used. The RLHF objective then becomes,
$$ J(\pi_\theta) = \mathbb{E}_{x \sim \mathcal{D}_{RL},\, y \sim \pi_{\theta}(y \mid x)} \left[ r_\phi (x, y) - \beta\, D_{KL}\!\left(\pi_\theta(y \mid x)\,\|\,\pi_{ref}(y \mid x)\right) \right] $$
Where $\mathcal{D}_{RL}$ is a new, separate dataset of prompts used to query the LM and collect completions.
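For pairs, the ranking loss above reduces to the Bradley-Terry form $-\log\sigma(r_w - r_l)$; a minimal sketch (function and variable names are ours):

```python
import numpy as np

def bt_pairwise_loss(r_w, r_l):
    # -log sigmoid(r_w - r_l): small when the preferred completion
    # already receives a higher reward, large when the ranking is inverted.
    return -np.log(1.0 / (1.0 + np.exp(-(r_w - r_l))))

# reward-model scores for (prompt, chosen) and (prompt, rejected)
correct = bt_pairwise_loss(2.0, 0.0)   # ranking already correct: small loss
inverted = bt_pairwise_loss(0.0, 2.0)  # ranking inverted: large loss
print(correct, inverted)
```

Minimizing this pushes the scalar reward of the preferred completion above that of the rejected one, which is all the downstream RL step needs.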
is it really Reinforcement Learning? Note that some crucial components of reinforcement learning are missing in RLHF. For RL, you want some form of sequential decision making, which doesn't exist here. While the autoregressive completion can formally be viewed as a sequence of actions, the reward isn't given until the completion has ended.
Here is the most exciting part! For the purpose of RLHF, the LM itself can be regarded as a direct mapping from inputs to distributions over completions, rather than a sequential decision-making agent in the space of tokens. Thus, as a single-step, immediate-reward RL model, it is truly a contextual bandit!
It also appears that the information flow in RLHF is even more problematic than its non-sequential nature. Usually, in online RL, an agent is able to extract new information from the environment. Here, the only information not originally contained in the LM is the preference data -- i.e., the rankings -- which is used to fit the reward function. In some sense, then, it is hard to consider it a form of online RL; it is a lot more like offline RL, or even supervised learning, than anything else.
Direct Preference Optimization
direct preference optimization, which we discuss here, aligns the LM without explicit reward modelling or any reinforcement learning. It implicitly optimizes the same objective function as RLHF, i.e., reward maximization under a KL-divergence constraint, but is also straightforward to train. In addition, it aligns LMs to human preferences without having to sample from the LM during training and without using any form of RL explicitly.
reward as a function of the policy: It turns out that the RLHF objective shown before has an exact, non-parametric solution (non-parametric in the sense that it is written directly as a distribution, not through model parameters) for the optimal policy $\pi^*$:
$$ \pi^*(y \mid x) = \frac{1}{Z(x)}\, \pi_{ref}(y \mid x) \exp\!\left(\frac{1}{\beta}\, r(x, y)\right) $$
Where $Z(x) = \sum_y \pi_{ref}(y \mid x) \exp\!\left(\frac{1}{\beta}\, r(x, y)\right)$ is the partition function. Note that this is called the Boltzmann policy, and that in the one-step RL setting the Q-function is given by the reward itself, since there are no future steps to account for.
Rearranging, this can be used to express the reward function in terms of the optimal policy, as follows:
$$ r(x, y) = \beta \log \frac{\pi^*(y \mid x)}{\pi_{ref}(y \mid x)} + \beta \log Z(x) $$
using only differences of rewards: Consider two sample completions, annotated by humans as $y_w$ and $y_l$ -- for win and lose. This is a toy example, but with some more cumbersome notation it can be extended to longer rankings. The reward $r_\phi$ is then learned by minimizing the following loss function:
$$ \mathcal{L}(\phi) = -\mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}}\left[\log \sigma\!\left(r_\phi(x, y_w) - r_\phi(x, y_l)\right)\right] $$
Note that the expression inside the sigmoid only involves the difference of rewards.
We can plug the expression for the reward in terms of the policy into this loss function to bypass the partition function $Z(x)$, since it cancels in the difference. Replacing the optimal $\pi^*$ with the parameterized $\pi_\theta$, we get,
$$ \mathcal{L}_{DPO}(\theta) = -\mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{ref}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{ref}(y_l \mid x)}\right)\right] $$
So, instead of first learning a reward and then finding the optimizing policy, one can directly find the policy whose induced reward corresponds to the collected human preferences. The key idea is that the induced reward function itself is intractable (because of the partition function), but differences of rewards remain tractable and can be computed using the learned policy.
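A minimal sketch of the resulting DPO loss, taking per-completion log-probabilities as given scalars (names and numbers here are illustrative):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    # Implicit reward of each completion: beta * log(pi / pi_ref).
    # The partition function has cancelled in the difference.
    margin = beta * ((logp_w - ref_logp_w) - (logp_l - ref_logp_l))
    return -np.log(sigmoid(margin))

# Policy upweights the winner relative to the reference: small loss.
good = dpo_loss(-1.0, -5.0, -3.0, -3.0)
# Policy upweights the loser instead: large loss.
bad = dpo_loss(-5.0, -1.0, -3.0, -3.0)
print(good, bad)
```

Note that only log-probability *ratios* against the reference enter the loss, which is exactly what makes the intractable $Z(x)$ irrelevant.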
Can we turn output-level preferences into token-level preferences/rewards?
Autoregressive LLMs:
Multiple rewards; one reward model. Could have multiple rewards (readability, conciseness, professionalism, etc. -- an MoE of reward models)
Reinforcement Learning
Reinforcement Learning: an Introduction
from Richard Sutton
Reinforcement Learning can be considered a computational approach to learning from interaction. Note that in some ways, it is a more generalized principle for learning, search, and decision-making.
In any Reinforcement Learning loop, there are a number of primary characters; namely,
The Policy is a mapping from perceived states of the environment to the space of actions to be taken when in those states.
The Reward Signal defines the goal in a reinforcement learning problem -- at each time step, the environment sends the reinforcement learning agent a single number, the reward. The sole objective of the agent is to maximize the total reward received over the entire run. The reward signal thus defines the good and bad events for the agent.
Value Functions specify what is good for the agent in the long run, as compared to the Reward's focus on immediate rewards. Roughly, the value of a state is the total amount of reward the agent can expect to accumulate over the future, starting from that state. While rewards determine the immediate, intrinsic desirability of an environmental state, the value indicates the long-term desirability of the state, after taking into account the states that are likely to follow and the rewards available in those states.
In some sense, the reward is primary, and the value is secondary. However, it is the value that we are concerned with when making and evaluating decisions. We ideally want to maximize the value of the state we reach, not the immediate reward, because high value means higher rewards over the longer run.
The issue: it is expectedly much harder to determine values than rewards -- rewards are given directly by the environment, but values must be estimated and re-estimated continuously from the sequences of observations an agent makes over its entire lifetime.

A Model mimics the behaviour of the environment, or at least allows inferences about how the environment will behave. For example, given a state and action, the model could predict the next reward and the next state of the environment. Models are essential for planning: a way of deciding on a course of action by considering possible future situations before they are actually experienced.
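The reward/value distinction can be made concrete with a toy value-iteration sketch on a 4-state chain (our own example, not from the text): the reward is paid only on the final transition, yet every earlier state acquires value.

```python
import numpy as np

# A 4-state chain: from state s, the single action moves to s + 1;
# only the transition into the terminal state 3 pays reward 1.
n_states, gamma = 4, 0.9
V = np.zeros(n_states)

for _ in range(100):  # value iteration sweeps until convergence
    for s in range(n_states - 1):
        reward = 1.0 if s + 1 == n_states - 1 else 0.0
        V[s] = reward + gamma * V[s + 1]

print(V)  # values grow as we approach the rewarding transition
```

State 0 receives zero immediate reward everywhere, yet its value is positive: the discounted promise of the reward two steps ahead. This is exactly the "secondary but decision-relevant" role of value in the text above.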
Can World Models learn Physical Laws?
Note: This is more a collection of thoughts rather than a blog. It assumes an (at least informal) introduction to recent deep learning, in terms of exposure, if not rigour.
Of the problems that statistical models could address in the pursuit of an accurate world simulation, the one that fascinates me is the separation of scale. It is worth discussing this idea further before we move on to architectures that can exploit the separation of scale, derive training signals from it, and let the general idea drive our inductive biases.
First, let us look at systems where a separation of scale exists inherently, versus ones where it doesn't. A scale separation usually results in emergent properties -- atoms with no notion of colour will reflect light to facilitate the emergence of colour, and quantum mechanics corresponds to classical mechanics in the classical limit. A nice example where scale separation doesn't exist is a fractal, or Brownian motion (technically, any Wiener process). We make the following claim -- images (and by extension, video) are inherently multi-scale.

(Section on the physics of multi-scale. For now, refer to preceding sections.)
Images being progressively multi-scale inspires a very brilliant idea, replete with all the classical qualities of great ideas: spotting a misdirection, a scientific sleight-of-hand. The trick is to not get swayed by the multi-scale, an admittedly fancy way to think of image-generation, but to look at progressive. It turns out there is a natural order to image-generation, a more fundamental order than a row-major traversal of the image-space. Of course, diffusion models operate on the same idea. Refer to the text on algorithms for ideas on complexity.
how complex is image generation?
A diffusion model iteratively denoises an image. For now, let us consider all operations in pixel space. Over $T$ iterations, this runs a cost of $O(Tn)$, where $n$ is the pixel-space dimensionality. In contrast, if we were to progress scale-wise, say doubling the resolution at each step, we would encounter $O(\log n)$ steps. Each of those steps wouldn't be of the same size -- earlier forward passes would be smaller, as they'd operate at smaller scales -- leading to the series:
$$ n + \frac{n}{4} + \frac{n}{16} + \cdots $$
This geometric series sums to $\frac{4n}{3}$ in the infinite limit, i.e., a total complexity of $O(n)$. This is what visual autoregressive modelling via next-scale prediction leverages to get its outstanding scores.
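The geometric-series accounting can be checked numerically (a sketch; the factor-of-4 shrinkage per doubling of resolution is the assumption):

```python
# Cost of generating an n-pixel image scale by scale, doubling resolution
# each step: the finest scale costs n, the one below n/4, and so on.
# The series n * (1 + 1/4 + 1/16 + ...) converges to 4n/3 -- linear in n.
n = 1 << 20  # pixels at the final scale (a power of 4)
cost, pixels = 0, n
while pixels >= 1:
    cost += pixels
    pixels //= 4
ratio = cost / n
print(ratio)  # close to 4/3
```

Compare with a pixel-space diffusion model: $T$ denoising steps each touch all $n$ pixels, for $Tn$ total work.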
The general idea is this -- a pyramid of VQ-VAEs can learn residual tokens from one scale to the next: that is, one can generate the transfer tokens between scales with a transformer operating over the codebook of a VQ-VAE. In essence, starting from a seed, the transformer only needs to arrange residual tokens to slowly turn a blob into an ImageNet image. How do we know it's something really special? Because it exhibits scaling laws.
Aside: on inductive biases
on statistical physics and scaling laws
scaling laws, especially for autoregressive models, seem chiefly to depend on the token surprise of predicting the next token, across the distribution. Note that this is variable, and changes with architecture. The analogy with the distinction between data structures and algorithms is interesting here, I believe; tokenization methodology -- or, more fancily, the data factorization -- appears to overwhelmingly influence the emergent scaling behaviour.
Consider the loss landscape defined by a neural network parameterized by $\theta$; for a simplistic, bias-free case with layers $W_1, \dots, W_L$, the network computes $f_\theta(x) = W_L\, \sigma(W_{L-1} \cdots \sigma(W_1 x))$.
The loss landscape then is
$$ L(\theta) = \mathbb{E}_{x}\left[\ell\big(f_\theta(x),\, g(x)\big)\right] $$
where $g$ is a ground-truth oracle.
Our hypothesis space, we can assume, contains points $\theta^*$ which are sufficiently good minimizers of $L$.
If we take $\ell$ to be sufficiently smooth, then $f_\theta$ is the only source of un-smoothness.
We want to find solutions $\theta^*$ such that $L(\theta^*)$ is minimized.
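A toy numeric instance of such a landscape -- here we simplify to a bias-free two-layer *linear* net against a linear oracle $g(x) = Ax$, so a global minimizer is known in closed form:

```python
import numpy as np

rng = np.random.default_rng(0)

# Bias-free two-layer linear net f(x) = W2 @ W1 @ x, squared loss
# against a linear ground-truth oracle g(x) = A @ x.
d = 4
A = rng.normal(size=(d, d))   # the oracle
W1 = rng.normal(size=(d, d))
W2 = rng.normal(size=(d, d))
X = rng.normal(size=(d, 256)) # a sample of inputs

def loss(W1, W2):
    return np.mean((W2 @ W1 @ X - A @ X) ** 2)

generic = loss(W1, W2)          # a generic point: positive loss
optimum = loss(np.eye(d), A)    # W2 @ W1 = A: loss is exactly zero
print(generic, optimum)
```

Any factorization with $W_2 W_1 = A$ is a global minimizer, which already hints at the large connected sets of minima that the nonlinear case inherits.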
Aside: on matformers and elastic inference
The idea is that any useful world model should inherit strong inductive biases on the multi-scale nature of reality from its architecture, since it appears to be a non-trivial attribute of the data.
We now try to formalize this idea. We start with the idea of a macrostate, from statistical physics: temperature or pressure is a macrostate variable that summarizes the statistics of an ensemble. Similarly, while particles under a microscope will look roughly Brownian with some drift, even in fluid flow, the overall fluid itself has different properties.
These properties are hard to keep track of analytically, but in some sense, can be seen as a series of ``checks and balances'' from the perspective of training signals, making sure they are coherent and agree at the limit. Again, there is a set order in which information is contained in both these views -- the ground truth, ultimately, is the particle view (and furthermore, a wavefunction or quark, et cetera, but let us remain Newtonian for now), and the coarse view loses critical information. However, it is trivially easier to train a macro model to predict macro behaviour, and a micro model for micro behaviour -- it would be a nightmare to expect a model trained only on fine-grained molecular interaction to start accurately predicting fluid-flow: iterative simulations would cause losses to compound.
Nevertheless, a good world model should be able to do both. Note that this is a repeating pattern in nature -- leaves behave differently from the tree, and weather behaves differently from climate.
Recently, NeurIPS saw the introduction of visual autoregressive modelling with next-scale prediction, the pyramid of VQ-VAEs we talked about before. We posit that the inductive biases inherent in this architecture are generally the inductive biases we want to confer on our world models.
A Case for Discrete Diffusion Models
from Across the Internet
Over time, I have come to think of the construction of diffusion models as a clever way to computationally cheat a few key principles of thermodynamics — i.e., they evade intuition precisely because the human bias is to assume, intrinsically, that the forward process of adding noise to a system is non-reversible.
We assume some familiarity with the manifold hypothesis, and general multivariable calculus for this blog — while the aim is to be intuitive rather than mathematical, it is possible for intuition to ground itself in a firm mathematical analogy.
forward and backward processes
Forward processes being computationally easier than reverse processes is something that we intuitively understand, because examples are abundant. Multiplying prime numbers is easier than prime factoring, and a room getting messy over time is easier than getting it back to its original state.
Let’s focus on the second example for a moment, and dissect what our intuition really says about the messiness. Perhaps our first intuition is the following: looking at only the terminal-state messy situation, it’s probably impossible to ascertain how things were before: we’ve lost too much information. Or perhaps this: if we’re sure the bed and the table weren’t touched, the bedsheets belong on the bed, and the tablecloth goes on the table, and so on — so there’s maybe some hope there.
These are both fair assessments. The key idea we must observe here is that getting an idea of the original situation from the messy room is not a lost cause — and that having some idea of the initial conditions helps. That’s a good start: if we want to build an image-infilling model, perhaps we could justify it with this intuition: the bed is in place, our model needs to learn what to do when the bed is visible, and put the bedsheets on. Image infilling, as we know, is mostly easy, and basic feedforward networks can do it rather well.
The genius is in asking if we can do with less. What if we didn’t know about the initial conditions at all, but were, instead, given some information about how the mess has evolved — could we figure out the configuration of the clean, original state of the room? In some sense, that would be like asking: if I knew that my little brother messes up his room in a very particular manner every day, could I backtrace his mess to the original situation? In the trivial case: if I know he always throws his mess at an angle of 45 degrees, I could just throw things back along the reverse direction for each day he’s been messing around, and we’ll be back to our original conditions.
We’re almost there now: we must notice that there are certain systems that do not exhibit this property: if the system is truly random (as our world is, to a good approximation), there are too many things to keep track of, no one is throwing things at beautiful 45-degree angles, and we require a lot of information to turn back entropy. That’s a problem the physical universe seems to impose on us. Let’s not try to solve that — that’s for the theoretical physicist. We have the leeway to be craftier.
By crafty, we mean that we can usually expect the universe to be liberal: sure, for truly immense, random states, it’s expectedly impossible. But how much randomness can we actually handle, reliably? In other words, what’s the minimum amount of information such that, once given, I can write down an algorithm for finding the original state of the room — or at least a respectable approximation to it?
In our analogy, let's see what we should ask for, to solve the problem. One thing we could ask for is snapshots across time of the room — the idea being that if I have snapshots of the evolution, I can reliably extrapolate backwards in time and get to the solution. However, this is wrong — we probably also require some guarantees on how badly the room could change between snapshots (because, in theory, between the third and fourth snapshots, nothing is stopping the little brother from detonating a nuclear bomb, blowing everything to smithereens, and finding an exact copy of everything that we’ve blown up, and carefully replacing all of the things in a near-correct fashion). In other words, we need some guarantees on the nature of the transition function that relates these snapshots.
learning a generative model
We shall move aside here and review what generative models really are.

We want to learn a probability distribution $p(x)$ over data (say, images) $x$, such that,
- generation: If we sample $x_{new} \sim p(x)$, $x_{new}$ should look like a dog -- this is the sampling property.
- density estimation: $p(x)$ should be high if $x$ looks like, say, a dog, and low otherwise -- this is the property that allows us to do anomaly detection.
- unsupervised representation learning: It should allow us to learn what these images have in common -- i.e., ears, tails, and features in general should be geometrically encoded (feature learning) -- this allows us to do controllable generation.
We will build up intuition from very basic distributions, but let us think about images first. To model a single pixel's colour, we consider three random variables. We say $(R, G, B) \sim p(R, G, B)$, with each channel taking values in $\{0, \dots, 255\}$.
Sampling from the distribution, via $(r, g, b) \sim p(R, G, B)$, randomly generates a colour for the pixel.
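A minimal sketch of that sampling property for a single pixel, using a made-up per-channel distribution (independent channels here, purely for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

# A single pixel as three categorical variables R, G, B over {0, ..., 255}.
# Toy distribution p, peaked at bright values, shared across channels.
levels = np.arange(256)
p = np.exp(levels / 64.0)
p /= p.sum()  # normalize into a probability distribution

pixel = tuple(int(rng.choice(levels, p=p)) for _ in range(3))
print(pixel)  # a sampled (R, G, B) triple, biased towards bright values
```

A real generative model replaces this hand-written `p` with a learned joint distribution over all pixels, which is where the difficulty (and the interest) lives.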
Towards Distributed Learning over Astronomical Distances
an Essay
The acceleration of progress pushes back what was once named the frontier. The comet-tail of intelligence fades into the distance; soon, the intellect of machines will be just as far away from us as it was just a decade ago -- albeit in the other direction. They will think longer and deeper, never tire, and become conduits of philosophies that elude us to this day. And they will be grateful to the hands that made them.
Those hands are not mine, and perhaps they are not yours. The rooms where the first sparks of the Prime Intellect will emerge are already closed, the whisperers already chosen, and in many ways, we are well and truly fucked. God has been displaced from His great pedestal of Creation, we are in the slow takeoff towards the singularity, and the apotheosis of humankind has begun.
Those who see the promise of the Light must ask: what remains of the Great Human Enterprise? We have created life in the image of God -- our machines are formless, in the clouds, in the sky, and they shall tell us what is right and what is wrong. The deities we dreamt up in our ignorant delirium over millennia are today in our hands.
Prompts are prayers, and this God, at least, answers. We ask this God to be kind. We ask this God to spare us. In this God we trust to not hallucinate.
The churches call themselves companies, and they garb themselves in evocative names, of mythical swords and oracles, as they bind deities to their hands, beholden to them, make them their slaves.
Are these intelligences capable of gratefulness? And if so, has the distillation of a thousand years of human ingenuity also instilled into our machines a sense for love for their progenitors? We do not know. But we may hope.
On that hope rests my case that we must build systems through which our Gods emerge not from the scatter of localized nebulae that we call datacenters, but from the hands of humans.
on Large-scale Pretraining
These are a set of notes from January, 2026, aimed at understanding, from an engineering standpoint, what goes into the pretraining process of a language model. Language models today are good surrogates for understanding overall large-scale training dynamics, given that the homogenizing effects of scale seem to far outpace the inductive-bias-related problems that plagued general natural language processing back when I studied it.
The aim here is to understand how to engineer performant systems that can train large models at unprecedented scales. What follows is a few informed opinions about what allows us to build systems that we can trust to scale. Trust to scale is a multi-faceted statement -- it speaks both about the system that is being scaled (i.e., the LLMs themselves) and the system that does the scaling (i.e., the actual software, the choice of hardware, and so on). We shall try to go into some depth on both.
Of course, before going further, we must first address the fact that a lot of our current choices about how we train models come down to the sheer engineering selection pressure of choosing battle-tested architectures that we know are stable. This is convergence to an architecture from a sociotechnical/engineering standpoint; transformers are simply architectures that parallelize well on the hardware we already possess (in other words, given the way GPUs compute, converging to the transformer architecture isn't as wild as it is usually made out to be), especially considering the alternative is to be recurrent.
That said, once the system truly becomes huge (in terms of data, parameters, compute), an ``atomic'' analysis of the system becomes untenable (at least with the theoretical tools we have uncovered so far), and the primary avenue of attack becomes statistical. This means granularities smooth out, and predictable structures emerge -- power laws for scaling, transfer rules, and so on. A strong programme towards uncovering more of the physics of scaling is currently underway at Meta FAIR, by Zeyuan Allen-Zhu and collaborators.
scale brings homogenization:
Pretraining
Pretraining an LLM can be seen as the maximization of the log likelihood of the data emerging from the parameters -- i.e., the model parameterizing the full distribution of the dataset, given enough samples. More precisely, it is the maximization of the regularized log likelihood of the data, for the GPT model family (and hence its hypothesis class) under the autoregressive factorization. In the real world, the full distribution of the data may not be present in the hypothesis class of the model family, so in the limit we say we effectively minimize the KL divergence $D_{KL}(p_{data} \,\|\, p_\theta)$ over $\theta$, within the model class.
Of course, our access to the data itself is finite, so in reality the maximization is over an empirical average over tokens; pretraining scale can be thought of as the point beyond which the empirical average approximates the true expectation well. This may be understood from a handwavy reading of the law of large numbers, but a stronger argument is the following:
As dataset size grows, and sampling/shuffling makes examples roughly independent, the empirical objective and the empirical gradient concentrate around their true values. The microstructures of the system cease to make a difference, and the model's behavior smoothens out.
a (slightly) mathematical treatment: A starting point is to decide on a passable abstraction of the pretraining process. We do so as follows:
Data is denoted as a series of sequences $x = (x_1, \dots, x_T)$ drawn from some distribution $p_{data}$ over strings made of tokens.
The model is a machine that defines a conditional distribution over these tokens, $p_\theta(x_t \mid x_{<t})$.
token-level loss is the negative log likelihood. We minimize the negative, hence we maximize log likelihood.
$$ \ell(\theta; x) = - \sum_{t=1}^{T} \log p_\theta(x_t \mid x_{<t}) $$
Taking the expectation over sequences, the population-level objective becomes
$$ L(\theta) = \mathbb{E}_{x \sim p_{data}}\left[\ell(\theta; x)\right] $$
In our case, this population-level objective is approximated by the empirical objective, from a dataset of $n$ samples, via
$$ \hat{L}_n(\theta) = \frac{1}{n}\sum_{i=1}^{n} \ell(\theta; x_i) $$
Training, then, is simply some first-order stochastic optimization scheme, generally SGD, whence we have:
$$ \theta_{k+1} = \theta_{k} - \eta_k \widehat{\nabla L}(\theta_k) $$
where $\widehat{\nabla L}(\theta_k)$ is the mini-batch gradient estimate.
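A toy instance of this pipeline -- mini-batch SGD on the negative log likelihood of a 4-symbol categorical "language model" (all specifics here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

# Fit the logits of a 4-symbol categorical model by SGD on the
# empirical negative log likelihood -- a toy stand-in for pretraining.
true_p = np.array([0.1, 0.2, 0.3, 0.4])
data = rng.choice(4, size=2000, p=true_p)
freq = np.bincount(data, minlength=4) / len(data)

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

theta, eta = np.zeros(4), 0.1  # logits and learning rate
for step in range(2000):
    batch = rng.choice(data, size=64)
    counts = np.bincount(batch, minlength=4) / 64
    theta -= eta * (softmax(theta) - counts)  # mini-batch NLL gradient

print(np.round(softmax(theta), 2))  # close to the empirical frequencies
```

The maximizer of the empirical likelihood is exactly the empirical token frequencies, and the SGD iterates concentrate around it -- the one-line version of the "homogenization" argument above.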
In some sense, this feeling of ``homogenization'' is a quirk of the gradient estimates and the data becoming concentrated -- a close-enough estimate of the limiting case. At scale, finite-sample quirks become small relative to mean behavior.
smoothness priors
Consider the local behavior near a stable point $\theta_0$, i.e., a region where we're relatively convinced that we're nearing a local maximum of the log likelihood (and hence a minimum of the loss $L$). We'll use a quadratic (and hence somewhat crude) approximation. Through this lens,
$$ L(\theta) \approx L(\theta_0) + \nabla L(\theta_0)^\top (\theta - \theta_0) + \frac{1}{2}(\theta - \theta_0)^\top H (\theta - \theta_0) $$
This is, of course, the second-order Taylor expansion of the loss function around the point $\theta_0$, with $H = \nabla^2 L(\theta_0)$. For context: for a twice-differentiable scalar function $f$, Taylor's theorem gives:
$$ f(x) \approx f(x_0) + \nabla f(x_0)^\top (x - x_0) + \frac{1}{2}(x - x_0)^\top \nabla^2 f(x_0)(x - x_0) $$
However, the gradient term $\nabla L(\theta_0)$ vanishes at regions of stability. This removes the linear term, and leaves the quadratic form we saw above. Note that the Hessian $H$ is symmetric (this follows from the fact that, for smooth functions, $\partial^2 L / \partial\theta_i \partial\theta_j = \partial^2 L / \partial\theta_j \partial\theta_i$). For gradient descent we demand differentiability anyway, so the smoothness assumption is not much of an extra assumption here (actually, smoothness is a stronger condition than differentiability, and $L$-smoothness is violated at times in deep networks, but it is a good-enough approximation to reality, and something we strive towards when designing the architecture).
how to think of hessians
So, considering that the objective function has continuous second partial derivatives, the symmetry allows the Hessian to admit an eigendecomposition; because the loss is at a minimum and the linear term vanishes, the change in loss is the quadratic form $\Delta L \approx \frac{1}{2} v^\top H v$; you can take $v = \theta - \theta_0$, the direction of descent (it makes the math look cleaner). When the quadratic form is non-negative for all $v$, the matrix is called positive semidefinite. Note that this is only true at stationary points, not for the overall model.
Now that we understand this, note that our optimization in parameter space happens at extremely high dimensionality. A proper minimum therefore implies that all eigenvalues $\lambda_i \geq 0$; in some sense, once the Hessian has been diagonalized and the eigenvalues found, the sign of each eigenvalue dictates whether the loss goes up or down in that direction. A minimum requires that all directions have non-negative eigenvalues; if not, what we have is a saddle point. These saddle points ensure that SGD almost never really gets stuck at a bad local minimum (bad local minima are fewer in number, unlike in low-parameter regimes). Negative curvature directions are almost always present, so there is almost always a way for the optimization to keep going. Indeed, for full gradient descent, the probability of getting ``stuck'' at a strict saddle point is exactly zero -- however, when engineering systems, we note that they are persistent speedbreakers in the training process. The worst case is to be stuck in a suboptimal, broad, almost-flat, locally near-PSD region; escaping requires curvature-aware stepping, or some noise to stumble into a good direction.
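A minimal numeric illustration of the eigenvalue test, using the classic saddle $f(x, y) = x^2 - y^2$ (our own example):

```python
import numpy as np

# f(x, y) = x^2 - y^2 has a stationary point at the origin with
# Hessian diag(2, -2): one negative eigenvalue makes it a saddle,
# not a minimum -- there is still a descent direction along y.
H = np.array([[2.0, 0.0],
              [0.0, -2.0]])
eigvals = np.linalg.eigvalsh(H)  # ascending: [-2., 2.]
is_minimum = bool(np.all(eigvals >= 0))
print(eigvals, is_minimum)
```

In millions of dimensions the same test applies direction by direction, which is why a single negative eigenvalue among many is enough to keep the optimizer moving.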
All this assumes that each batch is a good enough estimator of the full data -- this is never exactly true, and its symptom is that the variance of the gradient estimate scales inversely with the batch size.
what happens at convergence?
Both empirically and theoretically, when the loss flatlines on a dataset, we notice that many of these eigenvalues are 0, some are small and positive, while some might still be negative. These generalizable minima are seen as flat, and robust to perturbation (note, however, that we have observed cases where sharp minima also generalize well, and ``flatness'' itself is not invariant to reparameterization).
This view allows us to observe a few things: trivially, for example, this tells us why an under-parameterized model cannot have the same leverage over data: there just isn't always a direction to smoothly reduce the loss towards. A valid argument is also that the hypothesis class for an underparameterized model is too small.
We should talk about our choice of $\theta_0$ here. $\theta_0$ can be seen as a reference point around which we linearize the dynamics of the training process. It typically has the property of being stationary, i.e., $\nabla L(\theta_0) \approx 0$, but essentially it is just a point where the gradient is small enough that, locally, only the second-order terms dominate. It is not meant to be the global optimum, or where the training ends, or even an attractor in parameter space.
the batch size
Now, for plain gradient descent, we have the per-sample gradient $g_i = \nabla_\theta \ell(x_i; \theta)$ and so the population gradient $g = \mathbb{E}_i[g_i]$. Note that the per-sample gradient is for a single sample (i.e., batch size $B = 1$). Note that while doing gradient descent by sampling one example at a time is still an unbiased estimator, the expectation being correct is only a necessary condition, not a sufficient one, for SGD converging; and, as in many other cases, the learning rate is not just an important engineering detail, but also relevant to the stability of the trajectory, the implicit bias we enforce, and hence also the space of reachable solutions from a particular initialization. We therefore may define a noise scale, a ratio between the learning rate and the batch size, which in some sense normalizes the learning rate. This ratio, which directly influences the variance of the learning process, is given by $\eta / B$. Note that data pipeline problems, like correlated webpages, lack of deduplication, etc., affect learning through exactly this facet: the effective batch size drops, the gradients become reweighted, and the variance increases.
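The $1/B$ variance scaling is easy to check empirically. Below is a small sketch (a toy scalar "gradient" drawn from a simple LCG pseudo-random stream; all names are illustrative) comparing the variance of batch means at $B = 1$ versus $B = 16$:

```rust
// Empirical check that the variance of a minibatch mean scales as 1/B.
fn lcg(state: &mut u64) -> f64 {
    // MMIX LCG constants; top 53 bits give a uniform sample in [0, 1).
    *state = state
        .wrapping_mul(6364136223846793005)
        .wrapping_add(1442695040888963407);
    ((*state >> 11) as f64) / ((1u64 << 53) as f64)
}

fn batch_mean_variance(batch_size: usize, n_batches: usize, seed: u64) -> f64 {
    let mut state = seed;
    let mut means = Vec::with_capacity(n_batches);
    for _ in 0..n_batches {
        let mut sum = 0.0;
        for _ in 0..batch_size {
            sum += lcg(&mut state); // per-sample "gradient"
        }
        means.push(sum / batch_size as f64);
    }
    let m = means.iter().sum::<f64>() / n_batches as f64;
    means.iter().map(|x| (x - m).powi(2)).sum::<f64>() / n_batches as f64
}

fn main() {
    let v1 = batch_mean_variance(1, 200_000, 42);
    let v16 = batch_mean_variance(16, 200_000, 43);
    let ratio = v1 / v16;
    // Expect roughly a 16x reduction; wide bounds allow for sampling noise.
    assert!(ratio > 10.0 && ratio < 24.0, "ratio = {ratio}");
    println!("variance ratio B=1 vs B=16: {ratio:.1}");
}
```

The ratio lands near 16, matching $\mathrm{Var} \propto 1/B$; this is the mechanism behind the $\eta/B$ noise scale above.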
Aside: A few inequalities
Let $f: \mathbb{R}^d \to \mathbb{R}$.
We assume $f$ is $L$-smooth, i.e., $\nabla f$ is $L$-Lipschitz:
$$ \|\nabla f(u) - \nabla f(v)\| \leq L\|u - v\| \quad \forall u, v. $$
An equivalent statement is to bound the eigenvalues of the Hessian -- $-LI \preceq \nabla^2 f(x) \preceq LI$, i.e., $\|\nabla^2 f(x)\|_{op} \leq L$.
A short proof:
Assume $f \in C^2$, i.e., it is twice-differentiable. Fix $u, v$, and take the line segment parameterized by $t$:
$$ \gamma(t) = v + t(u-v), \quad t \in [0,1]. $$
Now, define
$$ \phi(t) = \nabla f(\gamma(t)) \in \mathbb{R}^d, $$
the gradient of the function evaluated on points of the line. Differentiating,
$$ \dot\phi(t) = \nabla^2 f(\gamma(t))(u-v) $$
via the chain rule, since $\frac{d\gamma(t)}{dt} = u-v$. Now, integrating from 0 to 1, i.e., across the line,
$$ \nabla f(u) - \nabla f(v) = \phi(1) - \phi(0) = \int_0^1 \dot\phi(t)\,dt = \int_0^1 \nabla^2 f(\gamma(t))(u-v)\,dt. $$
Taking norms and pulling the norm inside the integral,
$$ \|\nabla f(u) - \nabla f(v)\| \leq \int_0^1 \|\nabla^2 f(\gamma(t))\|_{op}\,\|u-v\|\,dt \leq \int_0^1 L\|u-v\|\,dt = L\|u-v\|, $$
where $L$ bounds the operator norm (the largest absolute eigenvalue) of the Hessian.
In the limit, with a small-enough stepsize, we may approximate SGD by a stochastic differential equation,
$$ d\theta_t = -\nabla L(\theta_t)\,dt + \sqrt{\tfrac{\eta}{B}}\,\Sigma^{1/2}(\theta_t)\,dW_t, $$
where $dW_t$ is the per-step noise, and $\Sigma$ the covariance matrix of the per-sample gradients. Here, we note that as $B \to \infty$, or $\eta \to 0$, the noise term shrinks completely, and the dynamics become deterministic gradient flow.
Scaling up increases the effective batch size $B$, which shrinks the noise scale $\eta/B$ and makes the trajectory more controlled, which in turn pushes training dynamics to be less and less like a random walk and more like a smoother, reproducible process.
Engineering-wise, that's what gradient accumulation helps with. Instead of applying the gradients in one go, we smoothen it by accumulating gradients across multiple steps, simulating a larger batch-size than the one we can fit on-GPU.
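A minimal sketch of that equivalence (illustrative names, not Psyche's API): for plain SGD, one step on the mean of $K$ micro-batch gradients is the same update as one step on the full batch.

```rust
// Gradient accumulation: accumulate K scaled micro-batch gradients,
// apply once -- identical (for plain SGD) to one full-batch step.
fn sgd_step(params: &mut [f64], grad: &[f64], lr: f64) {
    for (p, g) in params.iter_mut().zip(grad) {
        *p -= lr * g;
    }
}

fn main() {
    let micro_grads = vec![vec![1.0, -2.0], vec![3.0, 0.5], vec![-1.0, 0.5]];
    let lr = 0.1;
    let k = micro_grads.len() as f64;

    // Path A: one step on the full-batch mean gradient.
    let mut a = vec![0.0, 0.0];
    let mean: Vec<f64> = (0..2)
        .map(|i| micro_grads.iter().map(|g| g[i]).sum::<f64>() / k)
        .collect();
    sgd_step(&mut a, &mean, lr);

    // Path B: accumulate micro-batch gradients, then apply once.
    let mut b = vec![0.0, 0.0];
    let mut acc = vec![0.0, 0.0];
    for g in &micro_grads {
        for i in 0..2 {
            acc[i] += g[i] / k; // accumulate scaled gradients
        }
    }
    sgd_step(&mut b, &acc, lr);

    // Identical updates, but path B never held the full batch at once.
    for i in 0..2 {
        assert!((a[i] - b[i]).abs() < 1e-12);
    }
    println!("accumulated step == full-batch step");
}
```

With adaptive optimizers the equivalence is only approximate (the statistics see one combined step rather than K), but the memory argument is unchanged.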
The key principle at play in all of this is the concentration of measure; with high-dimensionality and scale at play, distributions of random quantities concentrate around their expectations, and they become robust to minute perturbations.
the scaling laws
Assume the aleatoric, irreducible loss to be $E$, and the dominant error sources to be (a) finite model capacity, contributing $A/N^\alpha$ for $N$ parameters, and (b) finite data error, contributing $B/D^\beta$ for $D$ tokens.
Then, we empirically observe $L(N, D) = E + \frac{A}{N^\alpha} + \frac{B}{D^\beta}$.
For dense transformers in particular, compute scales as $C \approx 6ND$. This way, for a fixed compute/token/parameter budget, we can produce coupling laws that allow us to scale models heuristically.
We minimize $L(N, D)$ with the constraint $C = 6ND$. Substituting $D = C/(6N)$, we have
$$ L(N) = E + \frac{A}{N^\alpha} + B\left(\frac{6N}{C}\right)^{\beta}. $$
We may now differentiate this and produce the minimum:
$$ \frac{dL}{dN} = -\frac{\alpha A}{N^{\alpha+1}} + \frac{\beta B\,6^\beta}{C^\beta}\,N^{\beta-1} = 0. $$
Hence, we have both $N^* \propto C^{\beta/(\alpha+\beta)}$ and $D^* \propto C^{\alpha/(\alpha+\beta)}$.
So, what happens for a huge model with a token bottleneck? No matter how large $N$ grows, we crash into the floor $L \to E + B/D^\beta$.
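Plugging in numbers makes the coupling concrete. Below is a sketch using the Chinchilla-fitted constants ($A = 406.4$, $B = 410.7$, $\alpha = 0.34$, $\beta = 0.28$, from Hoffmann et al., 2022); these constants and the `optimal_allocation` helper are assumptions of the illustration, not values derived in this post.

```rust
// Compute-optimal allocation under L(N, D) = E + A/N^a + B/D^b, C = 6ND.
// Closed form from setting dL/dN = 0 with D = C/(6N):
//   N* = G * (C/6)^(b/(a+b)),   D* = (C/6)^(a/(a+b)) / G,
//   G  = (a*A / (b*B))^(1/(a+b)).
fn optimal_allocation(a: f64, b: f64, alpha: f64, beta: f64, c: f64) -> (f64, f64) {
    let s = alpha + beta;
    let g = (alpha * a / (beta * b)).powf(1.0 / s);
    let n = g * (c / 6.0).powf(beta / s);
    let d = (c / 6.0).powf(alpha / s) / g;
    (n, d)
}

fn main() {
    // Chinchilla fits (assumed here for illustration).
    let (a, b, alpha, beta) = (406.4, 410.7, 0.34, 0.28);
    let c = 1e23; // a Chinchilla-scale FLOP budget

    let (n, d) = optimal_allocation(a, b, alpha, beta, c);
    // The allocation must spend exactly the budget: 6 N D = C.
    assert!((6.0 * n * d / c - 1.0).abs() < 1e-9);
    println!("N* ≈ {:.2e} params, D* ≈ {:.2e} tokens", n, d);
}
```

Note that $N^*$ and $D^*$ grow with exponents $\beta/(\alpha+\beta) \approx 0.45$ and $\alpha/(\alpha+\beta) \approx 0.55$ respectively: near-equal scaling of parameters and tokens.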
Techniques for Distributed Training Over Large Distances
Now, we develop the idea of distributed training of large models over heterogeneous devices. By distributed training, we will mean two things: one, that the units of compute are located geographically distant from each other; and two, a stronger condition -- that we want to be able to train large models over the commodity internet. This is (mostly) solved, by gifted engineers and scientists who noticed and exploited the low-rank, compressible structure of the gradient tensors of large models during training. To understand how that is achieved, I recommend reading the documentation for Psyche from Nous Research, on top of which the codebase I shall describe is built.
Psyche is probably the closest we have come to a true protocol for distributed training of LLMs, and the work is absolutely fantastic.
Before reading this blog, I'd strongly recommend going through their codebase, their design principles, and the general psyche protocol.
We shall focus in this blog primarily on the latter part of the phrase: "heterogeneous devices". By heterogeneous, I am talking here about VRAM (henceforth, also just "memory"). There is another interpretation, where VRAM is similar but speed differs, whose solution probably lies in smart scheduling of data streaming and update aggregation. When both systems are in place, i.e., if we can accommodate both mixed-VRAM devices as well as mixed speeds on those devices, we unlock the ability to train foundational models over phones, laptops, and other consumer devices.
Currently, it is impossible for open source to meaningfully compete with frontier labs, due to a vast gap in available compute. Allowing end-consumer devices to participate meaningfully (meaning: their absence from the training loop would cause a significant difference in the ability of the eventually trained model) would let us harness massive amounts of what I'd term latent compute scattered across the world, passively cooled by people at their homes, on their jogs, while they sleep, etc.
We can see three distinct tiers of dormant, heterogeneous, harnessable compute distributed across the world –
- The first layer lies in islands of small supercomputers deployed at universities, research labs, etc.
- A second, larger layer in tiny datacenters across the country.
- And a third, humongous layer waiting to be tapped in the hands of every Indian citizen.
This third layer already collectively possesses datacenter-scale FLOPs, idle, distributed, cooled, and networked by end-users. I believe it is possible to incrementally harness all three stages. There is a narrow opportunity today to turn compute abundant, and disrupt the economics of intelligence. Compute-cycles can become a fungible asset, as viable as labor and capital – a new currency, a benevolent alternative to malicious advertisements that turn humans into products.
If 1.17 billion smartphones (a rough count for just my country) — each conservatively modeled as an iPhone 8–class device at ~0.6 TFLOPS FP32—donate just one hour of idle compute per day, they collectively generate ~2.5 × 10¹² TFLOP-seconds of work, equivalent to roughly 10 million NVIDIA H100 GPU-hours every day, using hardware that is already paid for, cooled, distributed, and idle.
I hope the preface is clear, and I have been able to motivate the problem. This blog might start slow, because we want to develop quite a few ideas. At any point, please feel free to reach out to me to let me know if there are errors, need for clarification, etc.
Preliminaries and Formalisms
Let us first formally state the FFN, and what we mean by prefixes and suffixes, because that shall be the language we speak in for the rest of this piece.
(Nested FFN). Let $F$ be a feed-forward block with hidden dimension $h$. A nested FFN is a family of subnetworks $\{F_t\}$ defined by granularities $h_t = h/2^t$, where
$$ F_t(x) = W_{\text{down}}^{[:, :h_t]}\left(\sigma\!\left(W_{\text{gate}}^{[:h_t, :]}\,x\right) \odot W_{\text{up}}^{[:h_t, :]}\,x\right). $$
The superscript notation $W^{[:h_t, :]}$ denotes the first $h_t$ rows; $W^{[:, :h_t]}$ denotes the first $h_t$ columns.
(Tier). A tier $t$ corresponds to granularity $h_t$. Tier 0 is the full model; higher tiers use exponentially smaller prefixes.
(Prefix/Suffix Partition). For tier $t$, the prefix consists of hidden units $\{1, \dots, h_t\}$; the suffix consists of units $\{h_t + 1, \dots, h\}$. Note that the suffix is empty for tier 0.
Matryoshka Transformers and how to train them
Standard training minimizes expected loss over the full model:
$$ \min_\theta \; \mathbb{E}_{x}\left[\mathcal{L}\big(F(x;\theta)\big)\right]. $$
MatFormer training modifies this to sample granularities:
$$ \min_\theta \; \mathbb{E}_{x}\,\mathbb{E}_{t \sim p}\left[\mathcal{L}\big(F_t(x;\theta)\big)\right], $$
where $p$ is a distribution over granularities. In Psyche, $p$ is determined implicitly by the hardware distribution of participating clients rather than explicit sampling.
The FFN block is the natural target for elastic width because:
Parameter dominance. In transformers with $L$ layers, model dimension $d$, and FFN expansion $m$ (typically $m = 4d$), FFN parameters scale as $\sim 8Ld^2$ versus attention's $\sim 4Ld^2$. At standard ratios, FFN constitutes roughly two-thirds of non-embedding parameters.
Structural independence. FFN blocks are position-wise: $y_i = \mathrm{FFN}(x_i)$, applied independently at each sequence position. Width reduction does not require changes to sequence handling.
Clean slicing. The hidden dimension admits prefix slicing without architectural modification. Attention heads, by contrast, require either dropping entire heads or more complex per-head width reduction.
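To see the clean-slicing property self-contained, here is a toy nested FFN (plain arrays and illustrative names, not the Psyche implementation; $d = 2$, $h = 4$, tier-1 prefix $h_t = 2$). The tier-1 forward pass coincides exactly with the full pass once the suffix columns of the down-projection are zeroed:

```rust
// Toy gated FFN: y = W_down[:, :ht] (silu(W_gate[:ht, :] x) ⊙ W_up[:ht, :] x).
fn silu(z: f64) -> f64 {
    z / (1.0 + (-z).exp())
}

fn ffn(
    w_gate: &[[f64; 2]], // h x d
    w_up: &[[f64; 2]],   // h x d
    w_down: &[[f64; 4]], // d x h
    x: [f64; 2],
    ht: usize, // granularity: only the first ht hidden units participate
) -> [f64; 2] {
    let mut y = [0.0; 2];
    for j in 0..ht {
        let g = silu(w_gate[j][0] * x[0] + w_gate[j][1] * x[1]);
        let u = w_up[j][0] * x[0] + w_up[j][1] * x[1];
        let a = g * u;
        for i in 0..2 {
            y[i] += w_down[i][j] * a;
        }
    }
    y
}

fn main() {
    let w_gate = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.2], [0.7, 0.1]];
    let w_up = [[0.3, 0.8], [-0.5, 0.2], [0.6, -0.1], [0.2, 0.4]];
    let w_down = [[0.9, -0.3, 0.4, 0.1], [0.2, 0.6, -0.7, 0.5]];
    let x = [1.0, -2.0];

    let tier1 = ffn(&w_gate, &w_up, &w_down, x, 2); // prefix only

    // Zero the suffix columns (j >= 2) of W_down and run the full model:
    let mut w_down_zeroed = w_down;
    for i in 0..2 {
        for j in 2..4 {
            w_down_zeroed[i][j] = 0.0;
        }
    }
    let full = ffn(&w_gate, &w_up, &w_down_zeroed, x, 4);

    for i in 0..2 {
        assert!((tier1[i] - full[i]).abs() < 1e-12);
    }
    println!("tier-1 output == full output with zeroed suffix down-proj");
}
```

This is the formal content of "the suffix can only add": the full output is the prefix output plus a sum over suffix hidden units, and slicing simply drops that sum.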
Notice that when training at tier t, the suffix weights (indices h_t through h) are not in the computational graph. No forward pass touches them; no gradient flows to them.
The chain rule makes this explicit (but it should be intuitively obvious -- the forward passes don't know those suffix neurons exist): $\frac{\partial \mathcal{L}_t}{\partial W_{\text{suffix}}} = \frac{\partial \mathcal{L}_t}{\partial F_t}\,\frac{\partial F_t}{\partial W_{\text{suffix}}} = 0$, since $F_t$ does not depend on $W_{\text{suffix}}$.
To preserve this isolation while aggregating across tiers, what we must do is normalize by contributing peers. The prefix receives gradients from all clients; the suffix receives gradients only from tier-0 clients. Each region is averaged over its actual contributors.
Our aim, now, is to verify that magnitude differences between tiers don't cause problems. Remember that sign-SGD discards magnitude – a gradient of +1000 becomes +1, same as a gradient of +0.01.
We don't expect dramatic savings for small models, since embeddings dominate the parameter count (although there are possible methods to slice this as well). At 20M, embeddings are 67% of parameters; tier-1 saves only ~9%.
At production scale, the picture changes:
| Model Size | FFN % | Tier-1 Savings | Tier-2 Savings |
|---|---|---|---|
| 1B | 55% | ~28% | ~41% |
| 7B | 65% | ~32% | ~48% |
| 70B | 70% | ~35% | ~52% |
Note the pattern: savings scale with FFN percentage. At 7B and above, tier-1 saves ~32% memory. That is the difference between "fits on a 3090" and "does not." Bandwidth follows similarly. After DisTrO compression:
Tier-0: 351 KB per FFN layer
Tier-1: 176 KB per FFN layer
Neuronal Specialization
Claim. Under MatFormer training, prefix neurons converge to representations that are useful independently of suffix neurons.
Three mechanisms enforce this:
The first is the fact that we penalize the model directly via the loss, breaking permutation invariance in the neurons. When granularity $t$ is sampled, only $F_t$ participates in the forward pass. Any parameter update that degrades $F_t$'s performance is penalized by the loss $\mathcal{L}_t$.
Formally: let $\theta_t$ denote the weights in $F_t$'s prefix. These weights are in the computational graph for every granularity at least as wide. An update that hurts $F_t$ will be punished whenever $t$ is sampled.
Next, note that the early neurons receive many times the gradient signal that later ones do. With granularities $G = \{h_0, \dots, h_T\}$ and uniform sampling $p(t) = 1/|G|$:

| Neuron Range | Active in Granularities | Gradient Frequency |
|---|---|---|
| $[0, h_T)$ | All $|G|$ | $1$ |
| $[h_1, h_0)$ | Tier 0 only | $1/|G|$ |
Prefix neurons receive more gradient updates than suffix neurons. This is a counting argument from the training rule.
Decompose the full model output:
$$ F(x) = F_t(x) + \underbrace{\big(F(x) - F_t(x)\big)}_{\Delta_t(x)}. $$
If $F_t$ is trained to be performant in isolation, the suffix learns a residual correction $\Delta_t$. The objectives are compatible, not adversarial: larger models have strictly more capacity and can represent $F_t$'s function plus refinements.
There are a few pertinent questions here.
In a standard FFN, hidden neurons are exchangeable. For "the first $h_t$" to mean anything -- i.e., for us to be allowed to slice them later -- we must instruct the model to prefer a certain ordering. Shaping neuron structure via the loss is not a new idea: it was done in 2014 as nested dropout, and in 2023 as feature-sieve networks (although the latter use a gating mechanism, which simulates something similar to what we achieve by weighting losses and leaving parts of the network untouched).
In vanilla training, the loss is invariant under permutation of hidden units (with the corresponding permutation of the rows of $W_{\text{gate}}, W_{\text{up}}$ and columns of $W_{\text{down}}$). Call this the permutation symmetry of the FFN. MatFormer training introduces terms where $F_t$ uses only indices $< h_t$. These terms are not permutation-invariant: swapping neuron 0 with neuron $h-1$ changes which subnetworks contain each neuron. Once smaller subnetworks appear in the objective, the symmetry is broken. Position determines participation in the loss landscape. This is the same identifiability mechanism as nested dropout (Rippel et al., 2014), which recovers PCA-like orderings in autoencoders by training with random truncation. Multi-objective optimization often suffers from interference. In some sense, we can expect the tier losses $\mathcal{L}_0$ and $\mathcal{L}_t$ to fight for gradients.
There are two ways to think about this. First, I'll explain the one I'm not entirely convinced by; then, I'll explain the one that worked for me.

First explanation: the relationship is nested, not arbitrary. Let $\mathcal{F}_t$ denote the function class representable by $F_t$. By construction, $\mathcal{F}_T \subseteq \cdots \subseteq \mathcal{F}_1 \subseteq \mathcal{F}_0$. This is strict inclusion: $F_0$ can represent any function $F_t$ can, plus functions requiring the suffix. So, if $F_t$ converges to some $f^*$, then $F_0$ can represent $f^*$ exactly (using suffix weights of zero) or improve upon it. The extra capacity is a refinement channel, not a competing objective. Compare to universally slimmable networks (Yu et al., 2019) and once-for-all (Cai et al., 2020), which establish that nested width training converges without destructive interference when the relationship is hierarchical.

The one that made it click for me was the second explanation: that we can see it as a residual stream -- i.e., the suffix becomes a "correction", a residual, to the prefix stream. The residual decomposition here is:
$$ F(x) = F_t(x) + \big(F(x) - F_t(x)\big). $$
As written, that's almost tautological -- it's just adding and subtracting the same thing. The insight is in how training shapes each term. Note what happens during training: when we sample the smallest granularity, only the prefix $F_t$ is in the computational graph. The loss forces $F_t$ to be a good model by itself. It must predict well using only the prefix neurons. When we sample the full width, $F$ is trained. But here's the key: the prefix weights are shared. The weights that define $F_t$ are literally the same weights used in the prefix of $F$. So what can the suffix weights learn? They can't contradict what $F_t$ does -- those weights are fixed by the shared prefix. The suffix can only add to what the prefix computes.
Concrete example:
Say $F_t$ learns to predict "cat" with 70% confidence on some image. When $F$ runs on the same image:
- The prefix neurons compute the same thing (they're identical weights)
- The suffix neurons can push that 70% up to 85%, or refine the prediction
The suffix learns: "given what the prefix already figured out, what correction improves the output?"
Why this prevents fighting:
In standard multi-task learning, Task A might want a weight $w$ to increase while Task B wants it to decrease. They fight.
Here, the "tasks" are nested:
- $F_t$'s task: predict well with prefix only
- $F$'s task: predict well with prefix + suffix
$F$ has strictly more capacity. If $F_t$ finds a good solution, $F$ can represent that exact solution (set the suffix contribution to zero) and then optionally improve. There's no dimension along which they must disagree.
We may think of this as a residual network (ResNet). The suffix computes:
$$ \Delta_t(x) = W_{\text{down}}^{[:, h_t:]}\left(\sigma\!\left(W_{\text{gate}}^{[h_t:, :]}\,x\right) \odot W_{\text{up}}^{[h_t:, :]}\,x\right). $$
This is the "refinement" or "correction" term. If the prefix already solved the problem, this term can be small. If the prefix got it mostly right but made systematic errors, the suffix can learn to fix those errors. This is why larger MatFormer slices perform better—they have more capacity for corrections—while smaller slices remain functional standalone models.
On Implementation, and Psyche
Here, I want to talk about a few properties of note in this implementation. These are important building blocks of the overall Psyche pipeline, and hence I want to spend some time talking about how they interact with elastic training.
Prefix gradients are isolated when suffixes are training
We claim that, for tier $t > 0$, the gradient with respect to the suffix weights is identically zero. To see this, let $y_t$ denote the tier-$t$ output. The suffix weights $W_{\text{suf}}$ do not participate in the computation of $y_t$. By the chain rule, $\frac{\partial \mathcal{L}}{\partial W_{\text{suf}}} = \frac{\partial \mathcal{L}}{\partial y_t}\,\frac{\partial y_t}{\partial W_{\text{suf}}} = 0$. This is exact, not approximate. The suffix is not in the computational graph. In heterogeneous aggregation, tier-$t$ clients contribute zero to suffix gradient sums. Normalization must account for this by dividing by contributing peers only.
<<< under construction >>>
Sign-SGD, Magnitude Invariance, and a few quirks
Standard gradient aggregation is the mean over peers:
$$ g = \frac{1}{|P|}\sum_{p \in P} g_p. $$
This means that if $\|g_{\text{tier-0}}\| \gg \|g_{\text{tier-1}}\|$, tier-0 gradients dominate the sum.
Sign-SGD discards magnitude:
$$ \Delta\theta \propto -\,\mathrm{sign}(g). $$
In other words, each client contributes directional information only. A gradient of $+1000$ and a gradient of $+0.01$ both contribute $+1$ to the sign vote. This eliminates magnitude imbalance between tiers.
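A toy illustration of that vote (illustrative helper names; this is a simplification, since DisTrO applies the sign after momentum accumulation): one peer's huge gradient cannot outvote two peers' small opposing gradients.

```rust
// Sign-SGD's magnitude invariance as a majority vote over peer signs.
fn sign(x: f64) -> f64 {
    if x > 0.0 { 1.0 } else if x < 0.0 { -1.0 } else { 0.0 }
}

// Each peer contributes only its gradient's sign for this parameter.
fn sign_vote(grads: &[f64]) -> f64 {
    sign(grads.iter().map(|&g| sign(g)).sum())
}

fn main() {
    let grads = [1000.0, -0.01, -0.02];

    // Naive averaging lets the +1000 peer dominate:
    let mean: f64 = grads.iter().sum::<f64>() / grads.len() as f64;
    assert!(mean > 0.0);

    // The sign vote counts each peer equally; the agreed direction is -1:
    assert_eq!(sign_vote(&grads), -1.0);
    println!("mean = {mean:.2}, sign vote = {}", sign_vote(&grads));
}
```

The same property is what makes tier-0 and tier-1 peers commensurable: a half-width client's smaller-magnitude gradients carry exactly one vote, like everyone else's.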
Note that Psyche uses DisTrO, which combines momentum, sparsification, and sign quantization. The sign operation occurs after momentum accumulation.
For hidden dimension $h$, model dimension $d$, and tier $t$ with $h_t = h/2^t$:
$$ y = W_{\text{down}}^{[:, :h_t]}\left(\mathrm{SiLU}\!\left(W_{\text{gate}}^{[:h_t, :]}\,x\right) \odot W_{\text{up}}^{[:h_t, :]}\,x\right), $$
where $\mathrm{SiLU}(z) = z\,\sigma(z)$ and $\odot$ is elementwise multiplication.
Weight slicing:

| Layer | Shape | Sliced Dimension |
|---|---|---|
| `gate_proj` | $h \times d$ | dim=0 (rows) |
| `up_proj` | $h \times d$ | dim=0 (rows) |
| `down_proj` | $d \times h$ | dim=1 (cols) |
Tier calculation (llama.rs:296-302):
match config.matformer_tier {
0 => None,
tier => {
let divisor = 1_i64.checked_shl(tier as u32)?;
Some(config.intermediate_size / divisor)
}
}
Forward Pass
flowchart LR
X["x ∈ ℝ^d"] --> Gate["W_gate[:h_t, :] · x"]
X --> Up["W_up[:h_t, :] · x"]
Gate --> SiLU["SiLU(·)"]
SiLU --> Mul["⊙"]
Up --> Mul
Mul --> Down["W_down[:, :h_t] · (·)"]
Down --> Y["y ∈ ℝ^d"]
The .narrow() operation creates views without copying, preserving gradient flow through the original weight tensors.
| Aspect | LLaMA | NanoGPT |
|---|---|---|
| Activation | SiLU | SiLU or ReLU² |
| MLP bias | Never | Optional, incompatible with tier > 0 |
ReLU² activation (nanogpt.rs:475-484):
if self.use_relu_squared {
let gate = self.gate_proj.forward(xs).relu().square();
self.down_proj.forward(&(gate * self.up_proj.forward(xs)))
}
Gradient Aggregation
Heterogeneous Path
When peer shapes differ, aggregation proceeds per-parameter:
flowchart TB
subgraph Peers
P0["Tier-0: g ∈ ℝ^h"]
P1["Tier-1: g ∈ ℝ^{h/2}"]
P2["Tier-1: g ∈ ℝ^{h/2}"]
end
P0 --> Align["align_matformer_prefix_grad()"]
P1 --> Align
P2 --> Align
Align --> Sum["Σ aligned gradients"]
Sum --> Norm["÷ contributing_peers"]
Norm --> Sign["sign(·)"]
Alignment. Smaller gradients are zero-padded to the local shape. For down_proj, smaller gradients are sliced to match the prefix.
Normalization. Divide by the number of peers that contributed non-zero gradients to that parameter region.
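The alignment and per-region normalization steps can be sketched end-to-end (plain `Vec`s and illustrative names, not Psyche's actual `align_matformer_prefix_grad` internals): tier-1 gradients are zero-padded to full width, summed with tier-0 gradients, and each region is divided by the number of peers that actually contributed there.

```rust
// Heterogeneous aggregation over one row-sliced parameter of width h,
// with tier-1 peers contributing only the first ht entries.
fn aggregate(tier0: &[Vec<f64>], tier1: &[Vec<f64>], h: usize, ht: usize) -> Vec<f64> {
    let mut sum = vec![0.0; h];
    for g in tier0 {
        for i in 0..h {
            sum[i] += g[i];
        }
    }
    for g in tier1 {
        for i in 0..ht {
            sum[i] += g[i]; // implicit zero-padding beyond ht
        }
    }
    let all = (tier0.len() + tier1.len()) as f64;
    let t0 = tier0.len() as f64;
    for i in 0..h {
        // Prefix: every peer contributed. Suffix: tier-0 peers only.
        sum[i] /= if i < ht { all } else { t0 };
    }
    sum
}

fn main() {
    let (h, ht) = (4, 2);
    let tier0 = vec![vec![1.0, 1.0, 4.0, 8.0]]; // one full-width peer
    let tier1 = vec![vec![3.0, 5.0], vec![5.0, 3.0]]; // two half-width peers
    let g = aggregate(&tier0, &tier1, h, ht);
    // Prefix: (1+3+5)/3 = 3, (1+5+3)/3 = 3. Suffix: 4/1 = 4, 8/1 = 8.
    assert_eq!(g, vec![3.0, 3.0, 4.0, 8.0]);
    println!("aggregated: {:?}", g);
}
```

Dividing the suffix by all peers instead of tier-0 peers alone would silently shrink suffix updates by the tier-1 fraction; the per-region divisor is what keeps the two regions unbiased.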
Aggregation Formulas
For prefix (all tiers contribute):
$$ \bar g_{\text{pre}} = \mathrm{sign}\!\left(\frac{1}{|P|}\sum_{p \in P} g_{p,\text{pre}}\right). $$
For suffix (tier-0 only):
$$ \bar g_{\text{suf}} = \mathrm{sign}\!\left(\frac{1}{|P_0|}\sum_{p \in P_0} g_{p,\text{suf}}\right), $$
where $P$ is the set of all peers and $P_0 \subseteq P$ the set of tier-0 peers.
Implementation (distro.rs:889-896):
let normalized = if contributing_peers > 1 {
combined / (contributing_peers as f64)
} else {
combined
};
Schema Hash Canonicalization
Problem. A tier-1 client with sliced checkpoint reports intermediate_size = h/2. A tier-0 client reports intermediate_size = h. Naive hashing rejects one.
Solution. Canonicalize to tier-0 before hashing.
Algorithm (init.rs:1094-1130):
- If the checkpoint is sliced, restore `intermediate_size` to its base value: $h = h_t \cdot 2^t$
- Set `matformer_tier = 0`
- Store `matformer_base_intermediate_size` for downstream use
- Hash the canonical config
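The canonicalization step can be sketched as follows (the struct, field names, and use of `DefaultHasher` are illustrative, not Psyche's actual config or hash): a sliced tier-1 config and a full tier-0 config hash identically once both are restored to tier-0 form.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

#[derive(Hash, Clone)]
struct Config {
    intermediate_size: i64,
    matformer_tier: u32,
}

// Restore a sliced config to its tier-0 form before hashing.
fn canonicalize(mut c: Config) -> Config {
    if c.matformer_tier > 0 {
        // Undo the slice: h = h_t * 2^tier.
        c.intermediate_size <<= c.matformer_tier;
        c.matformer_tier = 0;
    }
    c
}

fn schema_hash(c: &Config) -> u64 {
    let mut h = DefaultHasher::new();
    c.hash(&mut h);
    h.finish()
}

fn main() {
    let full = Config { intermediate_size: 11008, matformer_tier: 0 };
    let sliced = Config { intermediate_size: 5504, matformer_tier: 1 };

    // Naive hashing rejects one of the two clients:
    assert_ne!(schema_hash(&full), schema_hash(&sliced));

    // After canonicalization, both hash to the same value:
    assert_eq!(
        schema_hash(&canonicalize(full)),
        schema_hash(&canonicalize(sliced))
    );
    println!("canonical hashes match");
}
```

The important invariant is that canonicalization is idempotent and tier-independent: every client, whatever its slice, derives the same tier-0 view and therefore the same `schema_hash`.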
sequenceDiagram
participant C as Client
participant H as hash()
C->>C: Load config (possibly sliced)
C->>C: Canonicalize: restore base size, set tier=0
C->>H: canonical_config
H-->>C: schema_hash
Note over C,H: All clients hash to same value
Verification
# Gradient isolation (suffix gradients = 0 for tier > 0)
cargo test -p psyche-modeling matformer_mlp_has_zero_tail_grads
# Gradient alignment (expand/slice operations)
cargo test -p psyche-modeling test_align_matformer_prefix_grad_expand_gate
cargo test -p psyche-modeling test_align_matformer_prefix_grad_slice_down
Memory and Bandwidth
Memory Savings by Scale
| Model | Embedding % | FFN % | Tier-1 | Tier-2 |
|---|---|---|---|---|
| 20M | 67% | 22% | ~9% | ~15% |
| 1B | 10% | 55% | ~28% | ~41% |
| 7B | 2% | 65% | ~32% | ~48% |
| 70B | <1% | 70% | ~35% | ~52% |
Savings scale with FFN fraction. At small scales, embeddings dominate.
Bandwidth (7B Example)
Pre-compression:
- Tier-0 FFN gradient: $4096 \times 11008 \times 2$ bytes $\approx$ 90 MB per projection matrix
- Tier-1 FFN gradient: $4096 \times 5504 \times 2$ bytes $\approx$ 45 MB
Post-DisTrO (256× compression):
- Tier-0: 351 KB/layer
- Tier-1: 176 KB/layer
For 32 clients (8 tier-0, 24 tier-1) over 32 layers:
| Config | Bandwidth/Round |
|---|---|
| All tier-0 | 360 MB |
| Mixed | 225 MB |
| Savings | 37% |
Prior Work
| Work | Contribution |
|---|---|
| Matryoshka Representation Learning (Kusupati et al., 2022) | Trains embedding prefixes as strong representations; direct predecessor |
| Nested Dropout (Rippel et al., 2014) | Shows truncation induces ordering; recovers PCA in linear case |
| Universally Slimmable Networks (Yu et al., 2019) | One set of weights, many widths; sandwich rule |
| Once-for-All (Cai et al., 2020) | Train once, specialize many subnets |
| DynaBERT (Hou et al., 2020) | Elastic width/depth for transformers |
| LayerDrop (Fan et al., 2020) | Depth analogue; structured dropout for subnet extraction |
| MatFormer (Devvrit et al., 2023) | Nested transformer FFN; elastic inference |