This is a quick post mortem on an experiment that I’m shelving for now.
Why Latent Tokens
I’ve been playing around with the idea of using latent tokens for reasoning models. Normally, when an LLM generates text, it produces a vector embedding. This embedding then gets translated into probabilities for different tokens (roughly, tokens = words), and after a token is chosen, it is converted back into a vector embedding. Typically, the probabilities are calculated such that the re-embedded token is similar to the initial embedding. This can be thought of like snapping the vectors onto known embeddings: the predicted vector is pretty close to some token x, so we just snap it onto x. This way, the model learns an intuitive structure, and we can easily interpret the output vectors by looking at the corresponding decoded words.
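To make the snapping concrete, here is a rough sketch of that decode-and-re-embed step in PyTorch. The matrix and variable names are mine, purely for illustration:

```python
import torch

# h:        hidden vector the model produced for the next position, shape (hidden_dim,)
# W_out:    output projection ("unembedding") matrix, shape (vocab_size, hidden_dim)
# W_embed:  input embedding matrix, shape (vocab_size, hidden_dim)
def decode_and_reembed(h, W_out, W_embed):
    logits = W_out @ h                       # how similar h is to each token's direction
    probs = torch.softmax(logits, dim=-1)    # probabilities over the vocabulary
    token_id = torch.multinomial(probs, 1)   # pick a token (sampling here; argmax also common)
    next_input = W_embed[token_id]           # "snap": feed that token's embedding back in
    return token_id, next_input
```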
With latent tokens, you skip the decoding step and just keep the original embedding vector. The motivation is simple: while decoding/snapping makes the outputs human-readable, it’s also a limitation. The model might be capable of learning richer representations that go beyond individual tokens, but there is no incentive for it to do so. The richer representation will get reduced back to a token, and any extra information will be lost. Tokens are also a very expensive and inefficient approach: intuitively, expressing something in words is often much lengthier than the actual thought or feeling. The advantage of having richer representations from latent tokens is also potentially much bigger in some domains than others. In particular, there are domains like spatial reasoning where it is quite difficult to communicate concepts with just language. The theory goes that a model might be much more capable at a domain like spatial reasoning if it is able to develop a more native (e.g. “visual”) latent representation.
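In those terms, a latent token just skips the snap. Continuing the sketch above, the hidden vector itself becomes the next input:

```python
def next_input(h, W_out, W_embed, latent: bool):
    if latent:
        # latent token: no decoding, the raw hidden vector is carried forward as-is
        return h
    # language token: snap to a known token embedding, as in the sketch above
    token_id = torch.argmax(W_out @ h)
    return W_embed[token_id]
```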
There has been some work on this idea, and I think the most notable is the Coconut paper from Meta. The Coconut model takes example chains of thought and gradually replaces the steps with latent tokens i.e. you go from “Problem Step 1 Step 2… Answer” to “Problem Latent 1 Latent 2… Answer”. Here a step is a full sentence i.e. a full thought.
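As I understand it, the curriculum can be pictured as a simple data transformation over the training examples. Here is a sketch of the idea (my paraphrase, not the paper’s code; the paper can also use more than one latent per replaced step):

```python
LATENT = "<latent>"  # stand-in for a latent step in this sketch

def coconut_stage(problem: str, steps: list[str], answer: str, k: int) -> list[str]:
    """At curriculum stage k, the first k reasoning steps are replaced by latents.

    Stage 0 is ordinary chain of thought; by the final stage every step is latent.
    """
    return [problem, *([LATENT] * k), *steps[k:], answer]
```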
I believe that this is the wrong approach. First, it’s hard to have any interpretability or training stability if you completely replace the language tokens with latent steps. You’re switching to a completely alien paradigm and potentially losing many of the benefits of pretraining on all of the language data that’s on the Internet. Second, it doesn’t make sense to replace an entire sentence with a single latent embedding. The size of the embedding is meant to represent a token, so if you try to represent an entire sentence with a single embedding, then you’re squeezing a lot of information in. Third, using latents can be seen as a form of variable compute: instead of using a single forward pass (i.e. a single run) of the model to predict the next token, you can instead produce a few latents before predicting the next language token. Each of these latents is another forward pass, allowing you to use more compute to predict that next language token. Not all tokens are equally difficult to predict or equally consequential. For example, if you’re trying to remember a name or complete an equation, then getting the answer wrong is significant because it materially affects your downstream reasoning. If you’re working on a math problem, then subsequent steps often rely on intermediate results calculated in prior steps. If you say Bob ate 3 pairs of bananas in one step, then you’re gonna say he ate 6 total bananas in the next step. On the other hand, many language tokens are either easy to predict or inconsequential. They might be parts of phrases or little bits of grammar like articles and prepositions. They could also be easily interchangeable: for example, you might substitute “also” with “as well”.
In short, my thesis is that you can make better use of latent tokens if you use them intermittently, sparingly, and intentionally. Most of language generation does not require latent reasoning or extra compute: in these cases, the model should just output language tokens. In cases where extra compute is necessary or where reasoning in language is difficult, the model should use latents to hone in on the right answer before switching back to language. Interleaving latents and language like this should help preserve some interpretability because humans can still read the language output and maybe use it to piece together what’s happening in the latent steps.
How
My plan was to implement this with two phases of training. In the first phase, I would hold the model’s hand as it learned to utilize the latents with supervised learning. Then in the second phase, I would give the model free rein with reinforcement learning.
In the supervised learning phase, you tell the model where and how to use latents. The Grade School Math (GSM8k) dataset from OpenAI is a collection of 8,000 grade school level math problems. Each problem is accompanied by an example step-by-step answer, and the answer is marked with places where it’d be beneficial to use a calculator e.g. Tom takes 3,200 steps a day for a total of 7 * 3,200 = «use_calculator» 22,400 steps a week. The calculator markings are a clear place where it may be particularly beneficial to use latents: these are consequential steps where having some extra compute could help you work through the math. My setup was simple: break up the step-by-step answers at the calculator markings. Then have the model learn to use latents at these breakpoints: we can say that a model is using latents well if it is more likely to correctly predict the next step. Finally, you can train a pair of classifier models to decide when to start and stop using latent tokens. The start classifier is simple: you just train it to start using latents at the calculator breakpoints. The end classifier is much trickier, but really the only thing you care about in this first stage is that it does something reasonable i.e. it should not decide to output 500 latent tokens in a row — especially on a grade school level math problem.
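A minimal sketch of the data prep for this phase is below. The regex (matching GSM8k’s calculator annotations), the special-token names, and the latent budget are illustrative assumptions rather than exactly what the repo does:

```python
import re

NUM_LATENTS = 4                 # fixed latent budget per breakpoint (a hyperparameter)
START, END = "<bot>", "<eot>"   # special tokens marking a latent span (names are made up)

# GSM8k writes each calculation like "7 * 3,200 = <<7*3200=22400>>22,400",
# so the <<...>> annotation marks exactly where a calculator would be used.
CALC = re.compile(r"<<[^>]*>>")

def insert_latent_spans(answer: str) -> str:
    """Replace each calculator annotation with a span of latent placeholders.

    The numeric result stays in the text right after the span, so the training
    objective becomes: given the latents, correctly predict that next step.
    """
    return CALC.sub(START + "<latent>" * NUM_LATENTS + END, answer)
```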
Through the first stage, the model would gain a baseline understanding of how you could use latents. After this jump start, the model can go further by learning on its own through reinforcement learning or specifically reinforcement learning with verifiable rewards (RLVR). In this paradigm, the output of a language model is graded using some automated criteria. For example, on a math problem, you might look at the model’s final answer and check if it matches the correct answer. This approach has produced impressive results like DeepSeek’s R1 reasoning model. In this phase of training, you’d no longer dictate when or how to use latents. Instead we’d simply have the model output tokens — both language and latents. Next, we’d automatically grade it on just the language tokens, and give the model a reward for correct answers. For example, if the model outputs 6 latent tokens and then says “The answer is 5,” we’d only grade the language snippet at the end. Through this process, the model can learn to use latents however it likes — as long as it gets to the right answer.
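To make the grading concrete, here is a sketch of what a verifiable reward that ignores latent positions might look like. The placeholder representation of latents is purely for this illustration:

```python
import re

LATENT = "<latent>"  # placeholder standing in for a latent position in this sketch

def grade(output_tokens: list[str], gold_answer: str) -> float:
    """Verifiable reward: drop latent positions, then check the final number in the language."""
    text = "".join(t for t in output_tokens if t != LATENT)
    numbers = re.findall(r"-?\d[\d,]*\.?\d*", text)
    if not numbers:
        return 0.0
    return 1.0 if numbers[-1].replace(",", "") == gold_answer.replace(",", "") else 0.0
```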
Results
I tried training the Qwen2 0.5 billion parameter model from Alibaba on the first phase. The results were not promising enough to attempt the second phase. Again, the first phase is supervised, and we’re just trying to get the model to utilize latents to correctly predict the next step. The fear with this kind of approach is always that the model will just memorize the specific examples instead of learning something more generalizable. This is known as overfitting because you’re matching your training data too closely.
I tried two experiments. In the first run, I just used the pretrained Qwen2 model and finetuned the whole thing. In this experiment, the train loss consistently improved, but there wasn’t a clear improvement on test loss. This is typically indicative of overfitting because the model’s performance is not generalizing to the held out test set. It’s also unclear if the latents are actually being utilized. Because the entire model gets the training signal, the model might just be memorizing the ground truth language answers and then learning to ignore the latents.
In the second run, I tried to avoid these pitfalls by freezing the model’s weights: I created two copies of the model — one for processing latents and one for processing language. The language copy is frozen while the latent copy is trained on the same signal as the first run. This setup was meant to ensure that the latents are actually utilized, because the latent copy is the only part the training process can change. It also ensures that the model can’t just memorize the answer, because the copy that processes language is frozen. In essence, the purpose of the latent copy is to produce latent embeddings that elicit the right answer when those embeddings are seen by the language copy.
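Roughly, the setup looks like this. This is a simplified sketch using Hugging Face-style loading; the real training loop is in the repo:

```python
import copy
import torch
from transformers import AutoModelForCausalLM

# Two copies of the same pretrained model: the language copy is frozen and is what
# computes the loss on the language tokens; the latent copy is the only thing trained,
# and its job is to produce the latent embeddings.
language_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
latent_model = copy.deepcopy(language_model)

for p in language_model.parameters():
    p.requires_grad = False  # frozen, so it cannot simply memorize the answers

# Gradients only flow into the latent copy, so the only way to reduce the loss is to
# emit latent embeddings that help the frozen language copy predict the next step.
optimizer = torch.optim.AdamW(latent_model.parameters(), lr=1e-5)
```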
In this second run, I saw a rapid drop off in train loss before it flatlined. Similarly, test loss was also flat. The alignment between train and test loss suggests that the model isn’t overfitting, but the flatlining of both losses suggests that the model just isn’t learning anything at all. The most likely explanation here is that the model was initially perplexed by the sudden introduction of latent tokens. It then quickly learned to ignore these latents — at which point it stopped improving. To validate this hypothesis, I tried decoding the latent embeddings into language tokens. If the model is learning some useful representation, then you might still see some similarity to normal language tokens — after all, the model is initialized from a pretrained language model and only trained for a short while. You’d also expect the decoded tokens to be meaningful and have some relation to the actual problem and process. For example, the tokens might be numeric, or they might connect to the necessary operations.
Instead I saw nonsense tokens that reappeared in different chains of thought without any clear correlation to the actual problem. Recurring tokens included “BaseContext” and “-gages”, and the tokens often involved language switching, with several instances of Chinese, Japanese, and Korean characters. I’m guessing the recurring English tokens were either very common or very rare in Qwen2’s training data: both cases may lead to effectively meaningless representations. There was also a pair of learned embeddings that represented the beginning and end of the chain of thought. These decoded to even less meaningful tokens, and language switching was even more frequent — ranging from Arabic to Russian to Thai. All of this suggests that the model did just learn to ignore the latent tokens. It was not able to improve beyond this point and meaningfully use the latents because this is a local optimum: the base model’s ability to predict the next set of language tokens is already pretty good, so the immediate and easy solution is to just ignore the latents.
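For reference, the “decoding” here was just a nearest-neighbor lookup of each latent embedding against the model’s token embedding matrix, roughly as follows (function and argument names are mine, assuming a Hugging Face-style model):

```python
import torch

def nearest_tokens(latent, model, tokenizer, k=5):
    """Return the k vocabulary tokens whose embeddings are most similar to a latent vector."""
    embed = model.get_input_embeddings().weight          # (vocab_size, hidden_dim)
    sims = torch.nn.functional.cosine_similarity(
        latent.unsqueeze(0), embed, dim=-1               # cosine similarity against every row
    )
    top_ids = sims.topk(k).indices.tolist()
    return [tokenizer.decode([i]) for i in top_ids]
```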
Next Steps
It’s possible that just letting the model learn freely with RLVR will work much better for latents. I’m pretty skeptical of this approach being able to get off the ground, and I suspect we’d see the same result of the model just learning to ignore latents. Another possibility is that you just need much more data — or maybe a different dataset or base model. I’m personally not interested enough to try these things, but my code is available at this repo. The main branch contains the initial method of training the whole model, and the two-trunk branch has the second experiment with separate language and latent copies.
References
Hao, Shibo, et al. “Training Large Language Models to Reason in a Continuous Latent Space.” arXiv preprint arXiv:2412.06769 (2024).
Yang, An, et al. “Qwen2 Technical Report.” arXiv preprint arXiv:2407.10671 (2024).