Provably learning a multi-head attention layer

Machine Learning 2024-02-07 v1 Data Structures and Algorithms Machine Learning

Abstract

The multi-head attention layer is one of the key components of the transformer architecture that sets it apart from traditional feed-forward models. Given a sequence length $k$, attention matrices $\mathbf{\Theta}_1,\ldots,\mathbf{\Theta}_m\in\mathbb{R}^{d\times d}$, and projection matrices $\mathbf{W}_1,\ldots,\mathbf{W}_m\in\mathbb{R}^{d\times d}$, the corresponding multi-head attention layer $F: \mathbb{R}^{k\times d}\to \mathbb{R}^{k\times d}$ transforms length-$k$ sequences of $d$-dimensional tokens $\mathbf{X}\in\mathbb{R}^{k\times d}$ via $F(\mathbf{X}) \triangleq \sum^m_{i=1} \mathrm{softmax}(\mathbf{X}\mathbf{\Theta}_i\mathbf{X}^\top)\mathbf{X}\mathbf{W}_i$. In this work, we initiate the study of provably learning a multi-head attention layer from random examples and give the first nontrivial upper and lower bounds for this problem:

- Provided $\{\mathbf{W}_i, \mathbf{\Theta}_i\}$ satisfy certain non-degeneracy conditions, we give a $(dk)^{O(m^3)}$-time algorithm that learns $F$ to small error given random labeled examples drawn uniformly from $\{\pm 1\}^{k\times d}$.
- We prove computational lower bounds showing that in the worst case, exponential dependence on $m$ is unavoidable.

We focus on Boolean $\mathbf{X}$ to mimic the discrete nature of tokens in large language models, though our techniques naturally extend to standard continuous settings, e.g. Gaussian. Our algorithm, which is centered around using examples to sculpt a convex body containing the unknown parameters, is a significant departure from existing provable algorithms for learning feedforward networks, which predominantly exploit algebraic and rotation invariance properties of the Gaussian distribution. In contrast, our analysis is more flexible as it primarily relies on various upper and lower tail bounds for the input distribution and "slices" thereof.
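
To make the definition of $F$ concrete, the following is a minimal sketch of the forward map and of sampling a uniform input from $\{\pm 1\}^{k\times d}$. It is an illustration of the formula in the abstract only, not the learning algorithm from the paper. The function names are hypothetical, and the softmax is assumed to act row-wise on the $k\times k$ score matrix, which the abstract does not state explicitly.

```python
import math
import random


def matmul(A, B):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    """Return the transpose of a matrix given as a list of rows."""
    return [list(col) for col in zip(*A)]


def softmax_rows(S):
    """Apply a numerically stable softmax to each row of S.

    Assumption: softmax is taken row-wise over the k x k score matrix.
    """
    out = []
    for row in S:
        m = max(row)  # subtract the row max before exponentiating
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def multi_head_attention(X, thetas, Ws):
    """Evaluate F(X) = sum_i softmax(X Theta_i X^T) X W_i.

    X is k x d; thetas and Ws are lists of m matrices, each d x d.
    """
    k, d = len(X), len(X[0])
    Xt = transpose(X)
    out = [[0.0] * d for _ in range(k)]
    for theta, W in zip(thetas, Ws):
        scores = matmul(matmul(X, theta), Xt)      # k x k attention scores
        attn = softmax_rows(scores)                # k x k attention weights
        head = matmul(matmul(attn, X), W)          # k x d output of one head
        out = [[o + h for o, h in zip(orow, hrow)]
               for orow, hrow in zip(out, head)]
    return out


def random_boolean_example(k, d, rng):
    """Draw X uniformly from {+1, -1}^{k x d} (hypothetical helper)."""
    return [[rng.choice((-1, 1)) for _ in range(d)] for _ in range(k)]
```

As a sanity check on the sketch: with a single head, $\mathbf{\Theta}_1 = 0$ and $\mathbf{W}_1 = \mathrm{Id}$, every row of the attention weights is uniform, so each output token is the average of the input tokens. A labeled example for the learning problem is then the pair `(X, multi_head_attention(X, thetas, Ws))` with `X` drawn by `random_boolean_example`.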

Cite

@article{arxiv.2402.04084,
  title  = {Provably learning a multi-head attention layer},
  author = {Sitan Chen and Yuanzhi Li},
  journal= {arXiv preprint arXiv:2402.04084},
  year   = {2024}
}

Comments

105 pages, comments welcome