Trainable Weight Averaging: Accelerating Training and Improving Generalization
Abstract
Weight averaging is a widely used technique for accelerating training and improving the generalization of deep neural networks (DNNs). Existing approaches such as stochastic weight averaging (SWA) rely on pre-set weighting schemes, which can be suboptimal when the weights being averaged are diverse. We introduce Trainable Weight Averaging (TWA), a novel optimization method that operates within a reduced subspace spanned by candidate weights and learns the weighting coefficients through optimization rather than fixing them in advance. TWA offers greater flexibility and can be applied to different training scenarios. For large-scale applications, we develop a distributed training framework that combines parallel computation with low-bit compression of the projection matrix, keeping memory and computational demands manageable. TWA can be implemented using either training data (TWA-t) or validation data (TWA-v), with the latter providing more effective averaging. Extensive experiments show TWA's advantages: (i) it consistently outperforms SWA in generalization performance and flexibility, (ii) when applied during early training, it reduces training time by over 40% on CIFAR datasets and 30% on ImageNet while maintaining comparable performance, and (iii) during fine-tuning, it significantly enhances generalization through weighted averaging of model checkpoints. In summary, we present an efficient and effective framework for trainable weight averaging. The code is available at https://github.com/nblt/TWA.
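The core idea of learning averaging coefficients over candidate weights can be illustrated with a toy sketch. This is not the authors' implementation (see the repository linked above for that): it assumes a linear regression model, synthetic "checkpoints" built as noisy copies of a ground-truth weight vector, and plain gradient descent on the coefficients. The function names, the data, and the step-size rule are all hypothetical choices made for illustration, and the projection matrix, low-bit compression, and distributed training described in the abstract are not modeled.

```python
import numpy as np

def averaged_weights(candidates, coeffs):
    """Combine candidate weight vectors (one per row) using the given coefficients."""
    return coeffs @ candidates

def mse_loss(w, X, y):
    """Mean squared error of the linear model X @ w against targets y."""
    residual = X @ w - y
    return float(np.mean(residual ** 2))

def learn_coefficients(candidates, X, y, steps=500):
    """Gradient descent on the averaging coefficients (illustrative, not the paper's solver)."""
    k = candidates.shape[0]
    coeffs = np.full(k, 1.0 / k)           # start from the uniform average
    F = X @ candidates.T                   # (n, k): predictions of each candidate
    # Step size 1/L, where L is the largest eigenvalue of the loss Hessian 2 F^T F / n.
    # This keeps the loss non-increasing for this quadratic objective.
    lipschitz = 2.0 * np.linalg.norm(F, 2) ** 2 / len(y)
    lr = 1.0 / lipschitz
    for _ in range(steps):
        residual = F @ coeffs - y
        grad = 2.0 * F.T @ residual / len(y)
        coeffs = coeffs - lr * grad
    return coeffs

# Synthetic data (assumption: a noiseless linear target).
rng = np.random.default_rng(0)
true_w = rng.normal(size=5)
X = rng.normal(size=(200, 5))
y = X @ true_w

# Hypothetical checkpoints of varying quality: noisy copies of the true weights.
candidates = np.stack(
    [true_w + scale * rng.normal(size=5) for scale in (0.1, 0.5, 1.0, 2.0)]
)

uniform = np.full(4, 0.25)                 # fixed, pre-set weighting
learned = learn_coefficients(candidates, X, y)

uniform_loss = mse_loss(averaged_weights(candidates, uniform), X, y)
learned_loss = mse_loss(averaged_weights(candidates, learned), X, y)
print(f"uniform average loss: {uniform_loss:.4f}")
print(f"learned average loss: {learned_loss:.4f}")
```

Because the coefficients start at the uniform average and the step size is chosen so the loss cannot increase, the learned combination does at least as well as the uniform one on the data used to fit it. In this sketch that data plays the role of the training set (analogous to TWA-t); fitting the coefficients on held-out data instead would correspond to the TWA-v variant.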
Cite
@article{arxiv.2205.13104,
title = {Trainable Weight Averaging: Accelerating Training and Improving Generalization},
author = {Tao Li and Zhehao Huang and Yingwen Wu and Zhengbao He and Qinghua Tao and Xiaolin Huang and Chih-Jen Lin},
journal = {arXiv preprint arXiv:2205.13104},
year = {2025}
}
Comments
Journal version in progress. Previously accepted to ICLR 2023.