Billy Ian's Short Leisure-time Wander

into learning, investment, intelligence and beyond

Shortcut Learning Hypothesis of Modern Language Models


Disclaimer: This post was completed in my spare time and is unrelated to my current work. The opinions in this post are my own, not my employers’.

As the already gigantic modern language models become ever larger by the day, people seem to overlook much of the existing work on their limitations. In this blog post, I try to connect the results and observations from several such works to form the “shortcut learning hypothesis” of those language models. The hypothesis implies that modern language models are fundamentally limited in their ability to capture long-range dependencies or complicated structures of the data, as long as we stick with the current training objective. Hopefully, this post can help people take a step back and ponder a bit before going all in on the current scale competition of language models.

A 137B parameter language model fails to answer a simple question with prompting. (The example comes from this paper)

Preliminaries

A language model (LM) is an estimate $Q(x_{1:T})$ of the true underlying probability distribution $P(x_{1:T})$ over sequences of text $x_{1:T}=(x_1,\dots,x_T)$, consisting of tokens $x_t$ from a fixed vocabulary. Prevailing neural language models estimate the joint distribution $Q(x_{1:T})$ autoregressively, so that it is implicitly defined by the conditional distributions $Q(x_t \vert x_{:t})$, where $x_{:t}$ denotes the tokens before position $t$. Such an autoregressive factorization has several benefits:

  • It's efficiently computable;
  • It's well aligned with the human intuition to speak, read and write text sequentially;
  • It can be trained in an unsupervised fashion with massive amounts of raw text;
  • Text can be conveniently generated by sampling from $Q(x_t \vert x_{:t})$ in a sequential fashion.

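It is standard for such a language model to be trained to minimize the cross entropy objective:

$$CE(P \Vert Q) = -\frac{1}{T}\,\mathbb{E}_{x_{1:T} \sim P}\left[\sum_{t=1}^{T} \log Q(x_t \vert x_{:t})\right]$$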

With such a formulation, the problem becomes one of the most basic learning tasks: predicting the next observation $x_t$ given a sequence of past observations $x_1,x_2,\dots,x_{t-1}$. The objective is basically to minimize the “average error (uncertainty / entropy)” at the token level, which is commonly reported in the literature as “perplexity”, the exponential of the per-token cross entropy. Masked token prediction, a popular variant of the standard objective popularized by the BERT paper, breaks the autoregressive factorization, but its characteristic of “average error” remains unchanged.
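To make the “average error” concrete, here is a minimal sketch of how per-token cross entropy and perplexity are computed from the probabilities a model assigns to the observed tokens. The function names and the probability values are made up for illustration; they do not come from any real model.

```python
import math

def cross_entropy(token_probs):
    """Average negative log-probability (in nats) of the observed tokens.

    token_probs[t] stands for Q(x_t | x_{:t}), the probability the model
    assigned to the token that actually occurred at position t.
    """
    return -sum(math.log(p) for p in token_probs) / len(token_probs)

def perplexity(token_probs):
    """Perplexity is the exponential of the per-token cross entropy."""
    return math.exp(cross_entropy(token_probs))

# Hypothetical per-token probabilities for a five-token sequence.
probs = [0.5, 0.25, 0.5, 0.125, 0.5]
ce = cross_entropy(probs)
ppl = perplexity(probs)
print(f"cross entropy = {ce:.4f} nats, perplexity = {ppl:.4f}")
```

Note that a single poorly predicted token (0.125 here) is simply averaged together with the easy ones, which is the property the rest of this post is concerned with.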

Long-Range Dependencies

For natural languages, the true distribution $P$ exhibits complex interactions between distant observations, i.e., long-range dependencies. Modern LMs based on the Transformer architecture have achieved tremendous success in NLP by pushing the “average error” low enough. However, these large models still struggle to effectively capture long-range dependencies, for example, when generating long and coherent stories, answering questions that depend on long context, or robustly performing numerical reasoning.

One limitation of Transformer-based LMs is the fixed number of tokens they can encode at once, and the computational cost of self-attention grows quadratically with this number. Recently, a lot of effort has been dedicated to making the Transformer architecture more efficient at encoding longer context, in the hope of better capturing long-range dependencies. However, one recent EMNLP paper showed that it remains unclear whether encoding longer context actually helps capture long-range dependencies. It seems that the LMs still largely rely on the most recent observations to make their predictions, even though they have direct access to the distant past.

“Average Error” is Theoretically Not Good Enough

The STOC paper “Prediction with a Short Memory” presents an interesting theoretical result. The paper is quite technical, but the key result is quite intuitive. Here is the most important proposition (to me) from the paper:

Let $\mathcal{M}$ be any distribution over sequences with mutual information $I(\mathcal{M})$ between past observations $\dots,x_{t-2},x_{t-1}$ and future observations $x_t, x_{t+1},\dots$. The best $l$-th order Markov model, which makes predictions based only on the most recent $l$ observations, predicts the distribution of the next observation with average KL error $I(\mathcal{M})/l$, with respect to the actual conditional distribution of $x_t$ given all past observations.

Essentially, it shows that a Markov model – a model that cannot capture long-range dependencies or structure of the data – can predict accurately on any data-generating distribution, provided the order of the Markov model scales with the complexity of the distribution, as parameterized by the mutual information between the past and future. Strikingly, this parameterization is indifferent to whether the dependencies in the sequence are relatively short-range or very long-range. Independent of the nature of these dependencies, provided the mutual information is small, accurate prediction is possible based only on the most recent few observations.

Intuitively, it means that the “average error” can be pushed low enough without capturing long-range dependencies at all, by only doing well on the time steps where prediction relies little on long-range dependencies. There is one condition for this argument to be valid: the amount of long-range dependencies must be small. To get a sense of whether this condition is met, let’s partition the sequence $x_{1:T}$ into $A = x_{1:t}$ and $B=x_{t+1:T}$. Then, by the chain rule $Q(x_{1:T}) = Q(A)\,Q(B \vert A)$, the cross entropy objective is equivalent to:

$$CE(P \Vert Q) = -\frac{1}{T}\,\mathbb{E}_{(A,B) \sim P}\left[\log Q(A) + \log Q(B \vert A)\right]$$

As $T$ increases, the configuration space of $B$ grows exponentially as $\lvert\mathcal{B}\rvert \sim d^{T-t}$, where $d$ is the vocabulary size and $\mathcal{B}$ is the set of all possible instances of $B$. However, with one specific instance of $A$ fixed, the amount of possible dependencies with $\mathcal{B}$ remains relatively small. As a result, the dependencies between $A$ and $B$ for large $T$ are very rare. In short, long-range dependencies are likely very sparse on average in the data.

The above result and intuition imply that the “average error”, though ubiquitously used in practice, is not a good metric to train and evaluate the LMs, if we are interested in capturing long-range dependencies. As long as the number of dependencies is not too large (usually valid), models with no capability to capture long-range dependencies can still perform well under the “average error”.
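The effect can be seen in a deliberately simple toy, sketched below. It is my own construction for illustration, not an example from the paper, and it uses finite sequences rather than the paper's setting. The first token is a random bit, the middle tokens are a fixed filler, and the last token copies the first, so the only dependency spans the whole sequence. The code computes the exact average log loss of the best predictor that sees only the last `window` tokens (it is also assumed to know the position $t$).

```python
import math
from collections import defaultdict

def make_sequences(T):
    """The two equally likely sequences of the toy distribution: a random
    first bit, a run of filler tokens (2), and a last token copying the bit."""
    return [(b,) + (2,) * (T - 2) + (b,) for b in (0, 1)]

def avg_loss_bits(seqs, window):
    """Exact average per-token log loss (in bits) of the best predictor that
    conditions only on the previous `window` tokens and the position."""
    T = len(seqs[0])
    weight = 1.0 / len(seqs)
    total = 0.0
    for t in range(T):
        joint = defaultdict(float)     # P(context, x_t)
        marginal = defaultdict(float)  # P(context)
        for s in seqs:
            ctx = s[max(0, t - window):t]
            joint[(ctx, s[t])] += weight
            marginal[ctx] += weight
        # The best achievable loss at step t is H(x_t | context).
        total -= sum(p * math.log2(p / marginal[ctx])
                     for (ctx, _), p in joint.items())
    return total / T

T = 100
full = avg_loss_bits(make_sequences(T), window=T - 1)  # sees the whole past
short = avg_loss_bits(make_sequences(T), window=1)     # sees one token only
print(f"full memory: {full:.4f} bits/token, short memory: {short:.4f} bits/token")
```

The short-memory predictor misses the one long-range dependency completely (it loses a full bit on the last token), yet its average loss exceeds the full-memory one by only $1/T$ bits per token, and this gap shrinks further as the sequence gets longer.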

Shortcut Learning Hypothesis

One may argue that such a result only reveals that models with a short memory can perform well as measured by the “average error”, and that it is not direct evidence that modern LMs trained with the “average error” are fundamentally limited in capturing long-range dependencies. Though I have no rigorous proof, I hypothesize that shortcut learning exists in modern LMs as a direct outcome of optimizing the “average error”. As put in the paper “Shortcut Learning in Deep Neural Networks”, “shortcut learning typically reveals itself by a strong discrepancy between intended and actual learning strategy, causing an unexpected failure”. The figure below shows a toy example of shortcut learning.

Toy example of shortcut learning. When trained on a simple dataset of stars and moons (top row), a standard neural network can easily categorise novel similar exemplars (middle row). However, testing it on a slightly different dataset (bottom row) reveals a shortcut strategy: The network has learned to associate object location with a category.

In the case of modern LMs, we hope that they will capture long-range dependencies naturally by scaling up with more and more parameters and data. However, with the “average error” as the main (usually, the only) objective, they learn unintended shortcuts, leveraging only recent observations and ignoring most of the long-range dependencies. And as suggested by the above-mentioned theoretical result, such shortcuts indeed exist. (Such a phenomenon can also be interpreted as learning spurious correlations in the data, encouraged by the inductive bias introduced by the “average error” objective.)

Actually, many empirical observations support the “shortcut learning hypothesis”. The most obvious evidence to me is the current large LMs’ inability to effectively capture long-range dependencies and understand complicated structures of the data, manifested in many complicated real-world tasks:

  • They fail to generate coherent long documents;
  • They perform poorly answering questions requiring logic, reasoning or grounding;
  • They easily misunderstand the high-level structure of the long documents;
  • They tend to hallucinate facts contradicting either the common knowledge or the given long context.

From my point of view, the models are learning in an unintended way that differs from what we, as humans, expect. Despite performing well under the “average error”, they try to generalize where memorization is actually needed (e.g., factual information), and memorize where generalization is actually needed (e.g., logic and reasoning). This is the typical characteristic of shortcut learning!

More clues can also be observed in different quantitative ways. Our previous AISTATS paper found that there is a large discrepancy between $I(P)$ and $I(Q)$. Recall that $I(\mathcal{M})$ is the mutual information between past and future observations of any distribution $\mathcal{M}$, $P$ is the true distribution, and $Q$ is the model distribution. Basically, it shows that the learnt model distribution $Q$ exhibits far fewer long-range dependencies than the true distribution $P$. This may be the reason why the current language models are unable to generate long coherent documents. Long-range dependencies are largely lost when conditioning on the models’ own outputs.
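As a rough sketch of what such a discrepancy looks like, the snippet below computes the mutual information between the first and second half of a sequence for two hand-built toy distributions. Both are hypothetical and only serve as an illustration; they are not taken from the AISTATS paper. Under $P$ the last token copies the first bit, while under the short-memory $Q$ the last token is an independent coin flip.

```python
import math
from collections import defaultdict

def mutual_information_bits(dist, split):
    """I(past; future) in bits, with past = seq[:split], future = seq[split:].

    `dist` maps each sequence (a tuple of tokens) to its probability.
    """
    past = defaultdict(float)
    future = defaultdict(float)
    for seq, p in dist.items():
        past[seq[:split]] += p
        future[seq[split:]] += p
    return sum(p * math.log2(p / (past[seq[:split]] * future[seq[split:]]))
               for seq, p in dist.items() if p > 0)

T = 6
# Toy "true" distribution P: the last token copies the first bit.
P = {(b,) + (2,) * (T - 2) + (b,): 0.5 for b in (0, 1)}
# Toy short-memory model Q: the last token ignores the first bit.
Q = {(b,) + (2,) * (T - 2) + (c,): 0.25 for b in (0, 1) for c in (0, 1)}

i_p = mutual_information_bits(P, split=T // 2)
i_q = mutual_information_bits(Q, split=T // 2)
print(f"I(P) = {i_p:.1f} bit, I(Q) = {i_q:.1f} bit")
```

In this toy, $Q$ assigns the correct marginal to every individual token, yet it carries none of the information that $P$ passes from the past to the future.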

In a similar vein, this ICML paper observed that the entropy rate of the model distribution $Q$

$$EntRate(Q) = -\frac{1}{T}\,\mathbb{E}_{x_{1:T} \sim Q}\left[\sum_{t=1}^{T} \log Q(x_t \vert x_{:t})\right]$$

diverges quickly from the cross entropy objective

$$CE(P \Vert Q) = -\frac{1}{T}\,\mathbb{E}_{x_{1:T} \sim P}\left[\sum_{t=1}^{T} \log Q(x_t \vert x_{:t})\right]$$

as the length of the generated sequences $T$ increases. Ideally, for an accurate language model, we expect that $CE(P \Vert Q) \approx EntRate(Q)$. Such a divergence means that the language models become increasingly uncertain when conditioned on their own outputs, even though they are able to push the “average error” to a very low level with respect to the true distribution. It also highlights that the learnt model distribution $Q$ ignores some crucial properties of the true distribution $P$. Why? Likely because these properties, albeit important, do not offer much direct help in decreasing the “average error”.
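The two quantities average the same log-probabilities $\log Q$ but under different sampling distributions: text from $P$ for the cross entropy, the model's own samples for the entropy rate. The sketch below makes this explicit by exact enumeration on a toy pair of two-state Markov chains. The chains and their parameters are assumptions chosen for illustration; the example only shows that the two quantities can differ, and does not reproduce the paper's finding about growth with $T$.

```python
import itertools
import math

def seq_prob(seq, stay):
    """Probability of a binary sequence under a symmetric two-state Markov
    chain: uniform first token, then repeat the previous token w.p. `stay`."""
    p = 0.5
    for prev, cur in zip(seq, seq[1:]):
        p *= stay if cur == prev else 1.0 - stay
    return p

def avg_neg_log_q(sample_stay, model_stay, T):
    """-(1/T) E_{x ~ sampler}[log2 Q(x_{1:T})], by enumerating all sequences."""
    total = 0.0
    for seq in itertools.product((0, 1), repeat=T):
        total -= seq_prob(seq, sample_stay) * math.log2(seq_prob(seq, model_stay))
    return total / T

T = 8
# Hypothetical: P repeats the previous token w.p. 0.9, the model Q w.p. 0.8.
ce = avg_neg_log_q(sample_stay=0.9, model_stay=0.8, T=T)        # CE(P || Q)
ent_rate = avg_neg_log_q(sample_stay=0.8, model_stay=0.8, T=T)  # EntRate(Q)
print(f"CE(P||Q) = {ce:.4f} bits, EntRate(Q) = {ent_rate:.4f} bits")
```

Here the model is more uncertain on its own samples than on the true data, so its entropy rate is higher than its cross entropy.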

Both observations provide additional empirical support that shortcut learning indeed exists in modern LMs, as a direct outcome of only optimizing the “average error”.

How to Do Better?

The shortcut learning hypothesis clearly suggests that a better metric/objective may be essential for modern LMs to better capture long-range dependencies or complicated structure of the data. One possibility suggested in the “Prediction with a Short Memory” paper is to only train and evaluate the models at a chosen set of (hard) time steps instead of all time steps. Hence the models can no longer do well with those unintended shortcuts.
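A minimal sketch of this idea is to restrict the loss to a chosen subset of positions. How to pick the hard steps is left open here; the `hard_steps` argument and the probabilities below are hypothetical placeholders.

```python
import math

def masked_cross_entropy(token_probs, hard_steps):
    """Average negative log-probability (nats) over the chosen positions only.

    token_probs[t] stands for Q(x_t | x_{:t}); hard_steps lists the positions
    that count towards the loss.
    """
    chosen = [token_probs[t] for t in hard_steps]
    return -sum(math.log(p) for p in chosen) / len(chosen)

# Hypothetical: the model is confident everywhere except at position 5,
# where the right answer depends on distant context.
probs = [0.9, 0.9, 0.9, 0.9, 0.9, 0.1]
all_steps = masked_cross_entropy(probs, range(len(probs)))
hard_only = masked_cross_entropy(probs, [5])
print(f"all steps: {all_steps:.4f} nats, hard steps only: {hard_only:.4f} nats")
```

Averaged over all steps, the failure at the hard position is diluted by the easy ones; evaluated on the hard position alone, it is fully exposed.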

Actually, a similar idea is already widely adopted in the practice of natural language processing. When GPT-3 and BERT first came out, they seemed to possess amazing “zero-shot” or “few-shot” transferability. However, people soon realized that such capabilities were largely overestimated, especially on harder tasks, e.g., question answering. Instead, “fine-tuning” the model weights with a downstream objective usually achieves much better performance. Just as revealed by the latest InstructGPT from OpenAI, a smart fine-tuning strategy can make small models outperform much larger ones. The wide success of fine-tuning again validates the “shortcut learning hypothesis” and implies that only optimizing the “average error” is not enough in the end.

However, such a strategy is imperfect in many ways:

  • Fine-tuning relies on small annotated downstream datasets which are much more expensive to scale;
  • The fine-tuned models on a specific task may perform worse on the other downstream tasks, so a separate copy of the model may be required for each downstream task;
  • Fine-tuning is usually conducted with a very small learning rate, which I doubt can really help escape those unintended shortcuts learnt during the pre-training.

Many attempts to resolve these issues did not achieve much success, like the multi-task pre-training and fine-tuning reported in the T5 paper from Google. This is unsurprising to me, since the annotated data, even combined across tasks, is tiny compared to the raw data used for unsupervised pre-training. As a result, fine-tuning is not good enough, and we need a better unsupervised objective than the “average error”.

While the “average error” is still the dominant unsupervised objective for training LMs, the original BERT actually introduced two promising directions to improve upon the “average error” metric:

  • Masked token prediction
    • predict the masked tokens corrupted with different strategies.
    • Pros: it breaks the autoregressive factorization, which makes more complicated interactions among tokens possible; and it works pretty well in practice.
    • Cons: it is still “average error”.
    • Masked token prediction is now widely used in follow-up works with many variants.
  • Next sentence prediction
    • distinguish between next sentences and randomly selected sentences.
    • Pros: it encourages a higher mutual information $I(Q)$ under the model distribution.
    • Cons: it does not work well in practice.
    • Next sentence prediction was quickly dropped by the community due to its empirical ineffectiveness. In our AISTATS paper, some initial explorations were conducted to help improve this objective on RNN-based LMs.
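For reference, the masked token corruption can be sketched as follows. The 15% selection rate and the 80/10/10 split between mask, random token and unchanged token follow the recipe described in the BERT paper; the function name, the toy vocabulary and the whitespace tokenization are assumptions made for this illustration.

```python
import random

def corrupt_for_mlm(tokens, vocab, mask_token="[MASK]", select_prob=0.15, seed=0):
    """Return (corrupted tokens, {position: original token to predict})."""
    rng = random.Random(seed)
    corrupted = list(tokens)
    targets = {}
    for i, tok in enumerate(tokens):
        if rng.random() >= select_prob:
            continue  # position not selected: no loss is computed here
        targets[i] = tok
        r = rng.random()
        if r < 0.8:
            corrupted[i] = mask_token         # 80%: replace with the mask token
        elif r < 0.9:
            corrupted[i] = rng.choice(vocab)  # 10%: replace with a random token
        # remaining 10%: keep the original token unchanged
    return corrupted, targets

# Toy whitespace-tokenized input, repeated so that some positions get selected.
tokens = "the cat sat on the mat".split() * 20
corrupted, targets = corrupt_for_mlm(tokens, vocab=sorted(set(tokens)))
print(len(targets), "of", len(tokens), "positions selected for prediction")
```

The training loss is then the average negative log-probability of the original tokens at the selected positions, which is why this objective is still an “average error”.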

Generally, I think both new ways to factorize the joint distribution $P(x_{1:T})$ and new unsupervised objectives beyond the “average error” are worth exploring further down the road. Another interesting direction to me is to leverage the massive information available on the web as an external memory for the LM, that is, retrieval-based LMs. Interestingly, three prestigious industrial AI labs all released their efforts in this direction recently. However, they all still rely either on the “average error” (RETRO from DeepMind) or on fine-tuning with a human-annotated dataset (LaMDA from Google and WebGPT from OpenAI). This recent paper from Brain points out an intriguing direction: breaking down long complicated dependencies into short simple dependencies, though it is achieved through prompting. As suggested by this paper, prompting is unlikely to work on task semantics not close enough to the LM pre-training objective. But these initial efforts are definitely meaningful and provide us guidance towards a better LM paradigm which may circumvent the “shortcut learning hypothesis” to better capture long-range dependencies and understand complicated structures of the data.

Conclusion

Benefitting from the success of scaling, modern LMs are becoming general-purpose models able to handle all kinds of tasks (through fine-tuning and prompting). As a result, their impact and implications are also becoming ever larger. In addition, the discussions of modern LMs are becoming more and more controversial on social media. As the stakes are so high right now, frank discussions about their limitations appear all the more indispensable, like this latest paper from Anthropic. Hopefully, this post can also contribute to such a purpose.

Acknowledgements

Thanks for the valuable feedback from Vatsal Sharan (USC/Stanford), Robert Geirhos (University of Tübingen), Yu Hou (USC), Guy Gur-Ari (X, Blueshift), Denny Zhou (Google Brain), Jeff Dean (Google Cloud ML / Google Research) and many other Google colleagues of mine.
