
Learning a Decision Tree Algorithm with Transformers

Subjects: Machine Learning; Artificial Intelligence; Computation and Language. Version: v2, 2024-08-27.

Abstract

Decision trees are renowned for their ability to achieve high predictive performance while remaining interpretable, especially on tabular data. Traditionally, they are constructed by recursive algorithms that partition the data at every node of the tree. However, identifying a good partition is challenging, because splits optimized for local segments of the data may not yield trees that generalize well globally. To address this, we introduce MetaTree, a transformer-based model trained via meta-learning to directly produce strong decision trees. Specifically, we fit both greedy decision trees and globally optimized decision trees on a large number of datasets, and train MetaTree to produce only the trees that achieve strong generalization performance. This training enables MetaTree to emulate these algorithms and to adapt its strategy intelligently to the context, thereby achieving superior generalization performance.
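To make the contrast with traditional construction concrete, here is a minimal Python sketch of a greedy, recursive, CART-style tree builder using Gini impurity. It is an illustration only, not MetaTree and not the paper's baselines; the names `build_tree`, `best_split`, `accuracy`, and the `min_gain` stopping parameter are hypothetical. On XOR-labelled data no single split reduces impurity, so a builder that stops when the immediate gain is zero returns one leaf, while accepting a zero-gain first split lets the next level separate the classes perfectly. This is a small instance of locally optimized splits missing a globally good tree.

```python
from collections import Counter
from dataclasses import dataclass
from typing import List, Optional, Tuple

# Illustrative greedy (CART-style) tree builder; NOT MetaTree.
# Names such as build_tree, best_split, and min_gain are hypothetical.


def gini(labels: List[int]) -> float:
    """Gini impurity of a list of class labels (0.0 for an empty list)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_split(X: List[List[float]], y: List[int]) -> Optional[Tuple[int, float, float]]:
    """Return (feature, threshold, gain) for the split that most reduces
    weighted Gini impurity at this node only, or None if no split exists."""
    n = len(y)
    parent = gini(y)
    best = None
    for j in range(len(X[0])):
        for t in sorted({row[j] for row in X}):
            left = [y[i] for i in range(n) if X[i][j] <= t]
            right = [y[i] for i in range(n) if X[i][j] > t]
            if not left or not right:
                continue  # threshold does not actually separate the data
            child = (len(left) * gini(left) + len(right) * gini(right)) / n
            gain = parent - child
            if best is None or gain > best[2]:
                best = (j, t, gain)
    return best


@dataclass
class Node:
    prediction: int  # majority label of the training rows at this node
    feature: Optional[int] = None  # None marks a leaf
    threshold: Optional[float] = None
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def build_tree(X, y, depth=0, max_depth=3, min_gain=1e-12):
    """Recursively partition (X, y), choosing each split greedily."""
    node = Node(prediction=Counter(y).most_common(1)[0][0])
    if depth >= max_depth or len(set(y)) == 1:
        return node
    split = best_split(X, y)
    if split is None or split[2] <= min_gain:
        return node  # stop: no split improves impurity right now
    j, t, _ = split
    node.feature, node.threshold = j, t
    left_idx = [i for i in range(len(y)) if X[i][j] <= t]
    right_idx = [i for i in range(len(y)) if X[i][j] > t]
    node.left = build_tree([X[i] for i in left_idx], [y[i] for i in left_idx],
                           depth + 1, max_depth, min_gain)
    node.right = build_tree([X[i] for i in right_idx], [y[i] for i in right_idx],
                            depth + 1, max_depth, min_gain)
    return node


def predict(node: Node, x: List[float]) -> int:
    """Route x down the tree to a leaf and return that leaf's label."""
    while node.feature is not None:
        node = node.left if x[node.feature] <= node.threshold else node.right
    return node.prediction


def accuracy(tree: Node, X, y) -> float:
    """Fraction of rows whose predicted label matches y."""
    return sum(predict(tree, x) == t for x, t in zip(X, y)) / len(y)


# XOR labels: no single axis-aligned split reduces Gini impurity.
X = [[0, 0], [0, 1], [1, 0], [1, 1]]
y = [0, 1, 1, 0]
print(accuracy(build_tree(X, y), X, y))                 # 0.5 (single leaf)
print(accuracy(build_tree(X, y, min_gain=-1.0), X, y))  # 1.0 (two levels)
```

The stopping rule is what makes this builder greedy: each decision looks only one split ahead, so structure that pays off only after two splits is invisible to it.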


Cite

@article{arxiv.2402.03774,
  title  = {Learning a Decision Tree Algorithm with Transformers},
  author = {Yufan Zhuang and Liyuan Liu and Chandan Singh and Jingbo Shang and Jianfeng Gao},
  journal= {arXiv preprint arXiv:2402.03774},
  year   = {2024}
}