Grokking modular arithmetic
Abstract
We present a simple neural network that can learn modular arithmetic tasks and exhibits a sudden jump in generalization known as "grokking". Concretely, we present (i) fully connected two-layer networks that exhibit grokking on various modular arithmetic tasks under vanilla gradient descent with the MSE loss function in the absence of any regularization; (ii) evidence that grokking modular arithmetic corresponds to learning specific feature maps whose structure is determined by the task; (iii) analytic expressions for the weights (and thus for the feature maps) that solve a large class of modular arithmetic tasks; and (iv) evidence that these feature maps are also found by vanilla gradient descent as well as by AdamW, thereby establishing complete interpretability of the representations learnt by the network.
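The claim that the weights admit analytic expressions can be illustrated with a small sketch. The construction below is an assumption made for illustration and is not quoted from the abstract. It uses a two-layer network y = W2 (W1 x)^2 with a quadratic activation, with the input pair (n1, n2) encoded as two concatenated one-hot vectors. The weights are cosines with frequencies sigma_k and independent random phases phi1_k, phi2_k. Expanding (cos a + cos b)^2 produces a cross term cos(a + b). W2 is chosen so that this term adds up coherently only at the output q = (n1 + n2) mod p, while the remaining terms largely cancel across hidden units. The helper names (one_hot_pair, cosine_weights, forward, accuracy) are hypothetical. No training is performed: the sketch only checks that weights of this form solve modular addition and does not reproduce gradient descent or grokking.

```python
import math
import random


def one_hot_pair(n1: int, n2: int, p: int) -> list[float]:
    """Encode the input pair (n1, n2) as two concatenated one-hot vectors (length 2p)."""
    x = [0.0] * (2 * p)
    x[n1] = 1.0
    x[p + n2] = 1.0
    return x


def cosine_weights(p: int, width: int, seed: int = 0):
    """Hypothetical cosine weights for modular addition (illustrative assumption).

    Hidden unit k has frequency sigma_k and independent random phases phi1_k, phi2_k.
    Returns W1 as `width` rows of length 2p and W2 as `p` rows of length `width`.
    """
    rng = random.Random(seed)
    w1, sigmas, phases = [], [], []
    for k in range(width):
        sigma = 1 + k % (p - 1)  # cycle through the frequencies 1..p-1
        phi1 = rng.uniform(0.0, 2 * math.pi)
        phi2 = rng.uniform(0.0, 2 * math.pi)
        # First p inputs see phase phi1, last p inputs see phase phi2.
        row = [math.cos(2 * math.pi * sigma * n / p + phi1) for n in range(p)]
        row += [math.cos(2 * math.pi * sigma * n / p + phi2) for n in range(p)]
        w1.append(row)
        sigmas.append(sigma)
        phases.append(phi1 + phi2)
    # W2[q][k] = cos(2*pi*sigma_k*q/p + phi1_k + phi2_k), matched to the cos(a + b) term.
    w2 = [[math.cos(2 * math.pi * s * q / p + ph) for s, ph in zip(sigmas, phases)]
          for q in range(p)]
    return w1, w2


def forward(x: list[float], w1, w2) -> list[float]:
    """Two-layer network with quadratic activation: y = W2 (W1 x)^2."""
    hidden = [sum(w * xi for w, xi in zip(row, x)) ** 2 for row in w1]
    return [sum(w * h for w, h in zip(row, hidden)) for row in w2]


def accuracy(p: int, width: int, seed: int = 0) -> float:
    """Fraction of all p*p pairs whose argmax output equals (n1 + n2) mod p."""
    w1, w2 = cosine_weights(p, width, seed)
    correct = 0
    for n1 in range(p):
        for n2 in range(p):
            y = forward(one_hot_pair(n1, n2, p), w1, w2)
            if y.index(max(y)) == (n1 + n2) % p:
                correct += 1
    return correct / (p * p)


if __name__ == "__main__":
    print("accuracy:", accuracy(p=11, width=500))
```

The coherent term at the correct output grows roughly like width/2. The incoherent terms grow only like the square root of the width. With the width large compared with p, the argmax readout is therefore expected to be correct on every pair.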
Cite
@article{arxiv.2301.02679,
title = {Grokking modular arithmetic},
author = {Andrey Gromov},
journal = {arXiv preprint arXiv:2301.02679},
year = {2023}
}
Comments
11+5 pages, 10 figures