TL;DR: ToST (Token Statistics Transformer) is a novel transformer architecture derived by unrolled optimization of a variational rate reduction objective. Its attention operator scales linearly with sequence length in both time and memory, yielding significantly improved computational efficiency and interpretability while remaining competitive in performance.
One layer of the Token Statistics Transformer (ToST). The Token Statistics Self-Attention (TSSA) operator transforms tokens efficiently by multiplying each row of the projected tokens by a scalar, leading to linear complexity in both space and time.
Following CRATE, ToST aims to map input tokens (e.g., image patches) to a structured feature space of lower dimension. Tokens with similar semantics may belong to the same geometric structures in the original space and should therefore be grouped together. A learned mapping $\phi$ converts these tokens into features that are compressed, linearized, and discriminative.
We derive our network architecture by extending the prior work CRATE, which showed that a transformer-style architecture arises naturally from "white-box" architecture design, where each layer of the network is designed to implement an incremental optimization step of the maximal coding rate reduction (MCR²) objective.
Specifically, we derive a novel variational form of the MCR² objective and show that unrolling gradient descent on this variational objective leads to a new attention module called Token Statistics Self-Attention (TSSA). TSSA has linear computational and memory complexity and radically departs from the typical attention architecture, which computes pairwise similarities between tokens.
The MCR² objective is defined as:
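Using the standard notation of the rate reduction literature, with token features $Z = [z_1, \dots, z_n] \in \mathbb{R}^{d \times n}$, diagonal membership matrices $\Pi = \{\Pi_k\}_{k=1}^{K}$ assigning tokens to $K$ groups, and quantization error $\epsilon$ (the paper's exact normalization may differ slightly), the objective takes the form

$$
\Delta R(Z; \Pi) \;=\; \underbrace{\tfrac{1}{2}\log\det\!\Big(I + \tfrac{d}{n\epsilon^2}\, Z Z^{\top}\Big)}_{R(Z):\ \text{expansion}} \;-\; \underbrace{\sum_{k=1}^{K} \tfrac{\operatorname{tr}(\Pi_k)}{2n}\,\log\det\!\Big(I + \tfrac{d}{\operatorname{tr}(\Pi_k)\,\epsilon^2}\, Z \Pi_k Z^{\top}\Big)}_{R^{c}(Z;\Pi):\ \text{compression}}.
$$

Maximizing $\Delta R$ expands the representation of all tokens as a whole while compressing the tokens within each group.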
We propose a novel variational form of the compression objective:
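One concrete way to obtain such a variational form (sketched here with square orthogonal projections $U_k$; the construction, weighting, and constants in the paper may differ) is to bound each $\log\det$ in $R^{c}$ by a sum over the diagonal of the projected second moment:

$$
R^{c}(Z;\Pi) \;\le\; \sum_{k=1}^{K} \frac{\operatorname{tr}(\Pi_k)}{2n} \sum_{j} \log\!\Big(1 + \frac{d}{\operatorname{tr}(\Pi_k)\,\epsilon^2}\,\big(U_k^{\top} Z \Pi_k Z^{\top} U_k\big)_{jj}\Big),
$$

with equality when $U_k$ diagonalizes the $k$-th group covariance. Since $\Pi_k$ is diagonal, each entry $\big(U_k^{\top} Z \Pi_k Z^{\top} U_k\big)_{jj} = \sum_{i=1}^{n} [\Pi_k]_{ii}\,\big(u_{kj}^{\top} z_i\big)^2$ is a second-moment statistic of the projected tokens, which is why the resulting operator never needs pairwise token similarities.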
The Token Statistics Self-Attention (TSSA) operator is defined as:
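In schematic form (the precise membership estimator, scaling function, and constants in the paper may differ from this sketch), one unrolled gradient step of a diagonal compression term of the above kind with respect to $Z$, treating the memberships $\pi_k(Z)$ estimated from the current tokens as coefficients, gives an operator of the structure

$$
\mathrm{TSSA}(Z) \;\approx\; -\,\beta \sum_{k=1}^{K} U_k\, D_k(Z)\, U_k^{\top} Z\, \operatorname{Diag}\!\big(\pi_k(Z)\big),
\qquad
[D_k(Z)]_{jj} \;=\; \frac{1}{1 + \alpha \sum_{i=1}^{n} [\pi_k(Z)]_i\,\big(u_{kj}^{\top} z_i\big)^2}.
$$

Each projected coordinate $(U_k^{\top} z_i)_j$ is simply rescaled by the scalar $[D_k(Z)]_{jj}\,[\pi_k(Z)]_i$, matching the row-wise scaling described in the figure caption above, and $D_k(Z)$ depends on $Z$ only through sums over tokens, so it can be formed in a single pass.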
ToST achieves linear scaling with sequence length in both computation time and memory usage, making it significantly more efficient than standard transformers, whose self-attention scales quadratically with sequence length.
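As a rough illustration of where the linear scaling comes from, the sketch below implements a TSSA-style update in PyTorch following the schematic form above; the tensor shapes, the softmax-based membership estimate, and the constants are assumptions made for this sketch rather than the released implementation.

```python
import torch

def tssa_sketch(Z, U, alpha=1.0, beta=0.1):
    """TSSA-style update built from token statistics (illustrative sketch).

    Z: (n, d) token features; U: (K, d, p) per-head projections.
    Every reduction below is a sum over tokens, so the cost is O(n * K * d * p):
    linear in the sequence length n, with no n-by-n attention matrix.
    """
    proj = torch.einsum('nd,kdp->knp', Z, U)            # (K, n, p) projected tokens
    # Membership estimate: soft assignment of each token to a head, here via a
    # softmax over projected energies (an assumption of this sketch).
    pi = torch.softmax((proj ** 2).sum(dim=-1), dim=0)  # (K, n)
    # Second-moment statistic per head and coordinate: sum_i pi_ki * (U_k^T z_i)_j^2.
    second_moment = torch.einsum('kn,knp->kp', pi, proj ** 2)  # (K, p)
    scale = 1.0 / (1.0 + alpha * second_moment)                # (K, p), diagonal of D_k
    # Rescale each projected coordinate by a scalar, weight by membership,
    # and map back to the token space.
    update = torch.einsum('kn,knp,kp,kdp->nd', pi, proj, scale, U)
    return Z - beta * update  # one unrolled step: Z + TSSA(Z)

if __name__ == "__main__":
    Z = torch.randn(1024, 64)       # n tokens, d dimensions
    U = torch.randn(4, 64, 16)      # K heads, p-dimensional projections
    print(tssa_sketch(Z, U).shape)  # torch.Size([1024, 64])
```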
Comparison of computational complexity.
Comparison of speed and memory usage, evaluated on GPUs.
ToST demonstrates performance comparable to conventional transformers while being significantly more computationally efficient.
ToST can be extended to a variety of task scenarios, including causal language modeling.
Performance on NLP tasks.
Since ToST is derived by unrolling a learning objective, we can analyze the behavior of a learned model layer by layer in a principled manner.
The variational compression term of the TSSA outputs at different layers of the ToST model.
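As a rough sketch of how such a layer-wise analysis can be reproduced, the snippet below evaluates a diagonal (Hadamard-style) compression measure on token features collected from each layer; the measure, the projection, and the stand-in activations are illustrative assumptions rather than the exact quantity reported in the paper.

```python
import torch

def compression_measure(Z, U, alpha=1.0):
    """Diagonal compression measure of projected token features (illustrative).

    Z: (n, d) tokens from one layer; U: (d, p) projection with orthonormal columns.
    Returns sum_j log(1 + alpha * second_moment_j), computed from token statistics only.
    """
    proj = Z @ U                              # (n, p)
    second_moment = (proj ** 2).mean(dim=0)   # (p,) per-coordinate second moment
    return torch.log1p(alpha * second_moment).sum()

# Example: track the measure across per-layer features of a trained model,
# assuming `features[l]` holds the (n, d) TSSA outputs of layer l.
features = [torch.randn(196, 64) for _ in range(12)]  # stand-in activations
U, _ = torch.linalg.qr(torch.randn(64, 16))           # orthonormal projection
per_layer = [compression_measure(Z, U).item() for Z in features]
print(per_layer)
```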
ToST naturally produces interpretable attention patterns without complex self-supervised training.
Comparison of the [CLS] token attention map from the last head in the penultimate global class attention layer.
Visualization of each row (after reshaping) of the estimated membership matrix $\Pi$ in the TSSA layer.
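A minimal sketch of such a visualization, assuming the memberships for a chosen layer are available as a `(K, N)` array over `N = H × W` image patches (the shapes and the random stand-in values below are purely illustrative):

```python
import numpy as np
import matplotlib.pyplot as plt

# Illustrative stand-in: K soft memberships over N = H * W image patches.
K, H, W = 8, 14, 14
pi = np.random.rand(K, H * W)
pi /= pi.sum(axis=0, keepdims=True)  # each patch's memberships sum to 1

fig, axes = plt.subplots(1, K, figsize=(2 * K, 2))
for k, ax in enumerate(axes):
    ax.imshow(pi[k].reshape(H, W), cmap="viridis")  # one row, reshaped to the patch grid
    ax.set_title(f"row {k}")
    ax.axis("off")
plt.tight_layout()
plt.show()
```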
In summary, we develop a novel, efficient attention mechanism derived from a theoretically principled objective of data compression and representation learning. Our proposed TSSA operator is unique among attention operators in that it does not require computing pairwise interactions between tokens and instead is constructed from a second moment statistic of projected token features. This results in our operator being significantly more efficient than standard attention operators, while still achieving similar performance to comparable transformers. We believe that this work provides an initial demonstration of the tremendous potential in designing novel and efficient deep architectures from mathematical principles.
This website template was adapted from CRATE's project page.