Token Statistics Transformer: Linear-Time Attention via Variational Rate Reduction

Ziyang Wu1
Tianjiao Ding2†
Yifu Lu3†
Druv Pai1
Jingyuan Zhang4
Weida Wang5
Yaodong Yu1
Yi Ma1,6
Benjamin D. Haeffele7
1UC Berkeley 2UPenn 3UMich 4THU & Transcengram 5Tsinghua SIGS 6HKU 7JHU

TL;DR: ToST (Token Statistics Transformer) is a novel transformer architecture that achieves linear-time complexity and competitive performance by unrolling the optimization of a variational rate reduction objective, yielding significant gains in computational efficiency and interpretability.

ToST Architecture

One layer of the Token Statistics Transformer (ToST). The Token Statistics Self-Attention (TSSA) operator transforms tokens efficiently by multiplying each row of the projected tokens by a scalar, leading to linear complexity in both space and time.

Key Contributions

  • A new variational form of the MCR² objective that leads to the Token Statistics Self-Attention (TSSA) operator
  • A novel transformer-like architecture constructed from first principles with theoretical grounding
  • Linear time and memory complexity in the number of tokens
  • Competitive performance across diverse tasks with improved interpretability

Method

Main Idea

Following CRATE, ToST aims to map input tokens (e.g., image patches) to a lower-dimensional, structured feature space. Tokens with similar semantics may lie on the same geometric structures in the original space and should therefore be grouped together. A learned mapping $\phi$ converts these tokens into features that are compressed, linearized, and discriminative.

We derive our network architecture by extending the prior work CRATE, which showed that a transformer-style architecture arises naturally from "white-box" architecture design, in which each layer of the network implements an incremental optimization step of the maximal coding rate reduction (MCR²) objective.

Specifically, we derive a novel variational form of the MCR² objective and show that unrolling gradient descent on this variational objective yields a new attention module, which we call Token Statistics Self-Attention (TSSA). TSSA has linear computational and memory complexity and departs radically from the typical attention architecture, which computes pairwise similarities between tokens.

Key Formulations

The MCR² objective is defined as:

$$\Delta R(\mathbf{Z},\mathbf{\Pi})\doteq\underbrace{\frac{1}{2}\log\det\left(\mathbf{I}+\frac{d}{\epsilon^{2}}\frac{1}{n}\mathbf{Z}\mathbf{Z}^{\top}\right)}_{\doteq R(\mathbf{Z})}-\underbrace{\frac{1}{2}\sum_{k=1}^{K}\frac{n_{k}}{n}\log\det\left(\mathbf{I}+\frac{d}{\epsilon^{2}}\frac{1}{n_{k}}\mathbf{Z}\mathrm{Diag}(\bm{\pi}_{k})\mathbf{Z}^{\top}\right)}_{\doteq R_{c}(\mathbf{Z},\mathbf{\Pi})}$$
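To make the objective concrete, here is a minimal NumPy sketch (not the authors' implementation) that evaluates $\Delta R(\mathbf{Z},\mathbf{\Pi})$ directly from this definition, with $\mathbf{Z}\in\mathbb{R}^{d\times n}$ and the memberships $\bm{\pi}_k$ stored as the columns of an $n\times K$ array; the quantization parameter $\epsilon$ is left as an argument.

```python
# Minimal sketch (not the authors' code): evaluate the MCR^2 objective
# Delta R(Z, Pi) = R(Z) - R_c(Z, Pi) for Z in R^{d x n} and an n x K
# membership matrix Pi whose k-th column is pi_k.
import numpy as np

def coding_rate(M, d, n_eff, eps):
    """1/2 * logdet(I + d / (eps^2 * n_eff) * M) for a d x d second-moment matrix M."""
    _, logdet = np.linalg.slogdet(np.eye(d) + (d / (eps**2 * n_eff)) * M)
    return 0.5 * logdet

def rate_reduction(Z, Pi, eps=0.5):
    d, n = Z.shape
    R = coding_rate(Z @ Z.T, d, n, eps)                  # expansion term R(Z)
    Rc = 0.0
    for k in range(Pi.shape[1]):
        pi_k = Pi[:, k]                                  # soft membership of each token in group k
        n_k = pi_k.sum()
        # Z * pi_k scales column j of Z by pi_k[j], so (Z * pi_k) @ Z.T = Z Diag(pi_k) Z^T
        Rc += (n_k / n) * coding_rate((Z * pi_k) @ Z.T, d, n_k, eps)
    return R - Rc                                        # Delta R(Z, Pi)
```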

We propose a novel variational form of the compression objective:

$$R^{var}_{c,f} (\mathbf{Z},\mathbf{\Pi} \mid \{\mathbf{U}_{k}\}_{k = 1}^{K}) \doteq \frac{1}{2}\sum_{k=1}^K \frac{n_k}{n} \sum_{i=1}^{d} f\left(\frac{1}{n_k} (\mathbf{U}_{k}^{\top} \mathbf{Z} \mathrm{Diag}(\bm{\pi}_{k}) \mathbf{Z}^{\top} \mathbf{U}_{k})_{ii} \right)$$
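The sketch below evaluates this variational term under two stated assumptions: each $\mathbf{U}_k$ is taken as a $d\times p$ projection matrix, and $f$ (left unspecified on this page) is replaced by the placeholder $f(x)=\log(1+x)$. The point to notice is that only the diagonal entries of the projected second moments are needed, so no log-determinant or $n\times n$ matrix is ever formed.

```python
# Minimal sketch of the variational compression term R^var_{c,f}(Z, Pi | {U_k}).
# Assumptions: each U_k is a d x p projection matrix, and f(x) = log(1 + x) is
# only an illustrative placeholder for the scalar function f.
import numpy as np

def variational_compression(Z, Pi, Us, f=np.log1p):
    d, n = Z.shape
    total = 0.0
    for k, U in enumerate(Us):
        pi_k = Pi[:, k]                                # soft membership in group k
        n_k = pi_k.sum()
        P = U.T @ Z                                    # p x n projected tokens
        stats = (P**2 * pi_k).sum(axis=1) / n_k        # diag of U^T Z Diag(pi_k) Z^T U / n_k
        total += 0.5 * (n_k / n) * f(stats).sum()
    return total
```

Since only these diagonal statistics enter the objective, the cost is $O(ndp)$ per group, linear in the number of tokens $n$.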

The Token Statistics Self-Attention (TSSA) operator is defined as:

$$\operatorname{\texttt{TSSA}}(\mathbf{Z}\mid\{\mathbf{U}_{k}\}_{k=1}^{K})\doteq-\frac{\tau}{n}\sum_{k=1}^{K}\mathbf{U}_{k}\mathbf{D}(\mathbf{Z},\bm{\pi}_{k}\mid\mathbf{U}_{k})\mathbf{U}_{k}^{\top}\mathbf{Z}\mathrm{Diag}(\bm{\pi}_{k})$$
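The diagonal matrix $\mathbf{D}(\mathbf{Z},\bm{\pi}_{k}\mid\mathbf{U}_{k})$ collects statistics of the projected tokens; the sketch below reads TSSA as an unrolled gradient step on the variational compression term above and therefore assumes its entries are $f'$ applied to the same per-coordinate second moments, with $f'(x)=1/(1+x)$ matching the placeholder $f$ used earlier. This is an illustrative reading, not the released implementation.

```python
# Minimal sketch of the TSSA operator (illustrative, not the released code).
# Assumption: D(Z, pi_k | U_k) = Diag(f'(stats)), where stats are the same
# per-coordinate second moments as in the variational compression term; this
# is what an unrolled gradient step on that term suggests.
import numpy as np

def tssa(Z, Pi, Us, tau=1.0, fprime=lambda s: 1.0 / (1.0 + s)):
    d, n = Z.shape
    out = np.zeros_like(Z)
    for k, U in enumerate(Us):
        pi_k = Pi[:, k]
        n_k = pi_k.sum()
        P = U.T @ Z                                    # p x n projected tokens
        stats = (P**2 * pi_k).sum(axis=1) / n_k        # per-coordinate second moments
        D = fprime(stats)                              # diagonal of D(Z, pi_k | U_k)
        # -(tau / n) * U_k D U_k^T Z Diag(pi_k): each projected row is rescaled by a
        # scalar, so the cost stays O(n d p) and no pairwise token similarities appear.
        out -= (tau / n) * (U @ (D[:, None] * P)) * pi_k
    return out
```

A ToST layer then updates the tokens with a residual step of (roughly) the form $\mathbf{Z}\leftarrow\mathbf{Z}+\operatorname{\texttt{TSSA}}(\mathbf{Z}\mid\{\mathbf{U}_{k}\}_{k=1}^{K})$, i.e., one unrolled incremental optimization step, with the memberships $\mathbf{\Pi}$ estimated inside the layer.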

Empirical Results

Linear Complexity of Compute and Memory

ToST achieves linear scaling with sequence length in both computation time and memory usage, making it significantly more efficient than standard transformers.
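As a back-of-the-envelope illustration (entry counts, not measured numbers), standard attention must materialize an $n\times n$ pairwise similarity map, whereas the TSSA sketch above only keeps the projected tokens and their $p$ diagonal statistics per projection:

```python
# Entry counts only (not measurements): the n x n pairwise map of standard
# attention vs. the per-projection quantities TSSA keeps, assuming p = 64.
p = 64
for n in (1_024, 4_096, 16_384, 65_536):           # sequence length
    pairwise = n * n                               # entries of the n x n attention map
    tssa_side = n * p + p                          # projected tokens plus diagonal statistics
    print(f"n={n:6d}: pairwise map {pairwise:.1e} entries vs. TSSA {tssa_side:.1e}")
```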

ToST speed

Comparison of computational complexity.

ToST speed

Comparison of speed and memory usage evaluated on GPUs.

Competitive Performance on Vision

ToST achieves performance comparable to conventional transformers while being significantly more computationally efficient.

ToST Results

Experiments on Long Sequence Tasks and Language Modeling

ToST extends to a variety of task scenarios, including long-sequence tasks and causal language modeling.

ToST Architecture

Performance on NLP tasks.

Principled Design

Since ToST is derived from a learning objective through unrolling, we can analyze the behavior of a learned model layer-by-layer in a principled manner.

Coding Rate across layers

The variational compression term of the TSSA outputs at different layers of the ToST model.

Interpretability in Learned Representations

ToST naturally produces interpretable attention patterns without complex self-supervised training.

ToST head vis

Comparison of [CLS] token attention maps from the last head in the penultimate global class attention layer.

ToST Pi vis

Visualization of each row (after reshaping) of the estimated membership matrix $\mathbf{\Pi}$ in the TSSA layer.

In summary, we develop a novel, efficient attention mechanism derived from a theoretically principled objective of data compression and representation learning. Our proposed TSSA operator is unique among attention operators in that it does not require computing pairwise interactions between tokens; instead, it is constructed from second-moment statistics of projected token features. This makes our operator significantly more efficient than standard attention operators while achieving performance similar to comparable transformers. We believe this work provides an initial demonstration of the tremendous potential of designing novel and efficient deep architectures from mathematical principles.

Acknowledgements

This website template was adapted from CRATE's project page.