TabPFN: A Transformer That Solves Small Tabular Classification

Authors:

  • Noah Hollmann
  • Samuel Müller
  • Katharina Eggensperger
  • Frank Hutter

tl;dr

The authors train a transformer on a diverse collection of synthetic tabular datasets. The result is a model that can solve new classification problems on small tabular datasets without any further training: prediction on a new dataset takes only a single forward pass of the transformer. The paper mainly applies the core method presented in Transformers Can Do Bayesian Inference (arXiv), but extends it to a broader range of tasks and demonstrates its effectiveness on a wide variety of small-data benchmarks.
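The sketch below illustrates the usage pattern this implies, under loud assumptions. `pfn_forward` and `InContextClassifier` are hypothetical names, and the stand-in "model" is a single softmax-attention pass from the test rows over the training rows (essentially a kernel smoother), not TabPFN's trained network. The only point is that `fit` performs no optimization. All adaptation to the new dataset happens inside one forward call.

```python
import numpy as np

def pfn_forward(X_train, y_train, X_test, n_classes, temperature=1.0):
    """Hypothetical stand-in for a pretrained PFN: each test row (query)
    attends to every training row (key) and reads out the one-hot training
    labels (values). This is NOT TabPFN's network, just a toy forward pass."""
    # Attention scores: negative squared Euclidean distance test -> train.
    d2 = ((X_test[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=-1)
    scores = -d2 / temperature
    scores -= scores.max(axis=1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    values = np.eye(n_classes)[y_train]           # one-hot labels
    return weights @ values                       # (n_test, n_classes) probabilities

class InContextClassifier:
    """sklearn-style wrapper in which 'fitting' only stores the training set."""

    def fit(self, X, y):
        self.X_train_ = np.asarray(X, dtype=float)
        self.y_train_ = np.asarray(y, dtype=int)
        self.n_classes_ = int(self.y_train_.max()) + 1  # assumes labels 0..K-1
        return self  # no gradient steps on the new dataset

    def predict_proba(self, X):
        # All of the "learning" happens inside this single forward pass.
        return pfn_forward(self.X_train_, self.y_train_,
                           np.asarray(X, dtype=float), self.n_classes_)

    def predict(self, X):
        return self.predict_proba(X).argmax(axis=1)

# Usage on a tiny two-cluster dataset.
X = np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [1.1, 0.9]])
y = np.array([0, 0, 1, 1])
clf = InContextClassifier().fit(X, y)
print(clf.predict(np.array([[0.05, 0.0], [1.0, 1.0]])))  # [0 1]
```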

During training, a small synthetic dataset \mathcal D is sampled and split into \mathcal D_\text{train} and \mathcal D_\text{test}. For the test samples (x_\text{test}, y_\text{test}) \in \mathcal D_\text{test}, the transformer q_\theta is trained to minimize -\log q_\theta (y_\text{test} \mid x_\text{test}, \mathcal D_\text{train}), averaged over test samples and over many sampled datasets. In other words, for a training set of size m the transformer's input is \{ (x_1, y_1), (x_2, y_2), \ldots, (x_m, y_m) \} \cup \{ x_\text{test} \}, and it is trained to predict y_\text{test}. The attention is masked so that every position can attend to all of the training samples, but a test sample cannot attend to other test samples. Because the model is trained on train/test splits from a wide variety of tasks, it can generalize to new, previously unseen datasets and tasks. Loosely, the idea is that the transformer has learned to approximate Bayesian inference itself (the posterior predictive distribution under the prior) rather than a mapping from inputs to outputs for any particular task.

In practice, TabPFN makes multiple forward passes on transformed versions of the data (e.g., shuffled column order, power transforms) and ensembles the results.

Key to the Bayesian framing is specifying the “prior”, which here is a distribution over data-generating processes. The authors' primary prior is built from Structural Causal Models (SCMs), from which the synthetic datasets are generated. The underlying causal graphs are implemented efficiently by instantiating neural networks with random parameters, randomly dropping some connections (i.e., removing causal relationships), and then randomly designating nodes as inputs, the output, or hidden. To generate data, noise is sampled and propagated through the network to assign a value to every node; the values of the input and output nodes are then collected as one row of the synthetic dataset. Sampling a variety of network architectures yields diverse datasets with differing numbers of inputs and different causal mechanisms.
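A minimal NumPy sketch of this SCM-style prior, loosely following the description above. Everything concrete here is an assumption for illustration: the function name `sample_scm_dataset`, the tanh activation, the layer and width ranges, the edge-drop probability, the noise scale, and the median split used to turn the continuous target into class labels. It omits many details of the actual prior.

```python
import numpy as np

def sample_scm_dataset(n_rows, n_features, rng, n_layers=3, width=16,
                       drop_prob=0.5, noise_std=0.1):
    """Draw one synthetic binary-classification dataset from a randomly
    initialised MLP viewed as a structural causal model (illustrative only)."""
    n_nodes = width * (n_layers + 1)
    if n_features + 1 > n_nodes:
        raise ValueError("not enough nodes for the requested features + target")

    # 1. Random weights; zeroing an entry removes that causal edge.
    weights = [rng.normal(size=(width, width)) * (rng.random((width, width)) > drop_prob)
               for _ in range(n_layers)]

    # 2. Sample root noise and propagate it, adding fresh noise at each layer.
    h = rng.normal(size=(n_rows, width))
    layers = [h]
    for W in weights:
        h = np.tanh(h @ W) + noise_std * rng.normal(size=(n_rows, width))
        layers.append(h)
    nodes = np.concatenate(layers, axis=1)  # (n_rows, n_nodes): one value per node

    # 3. Randomly pick feature nodes and one target node; the rest stay hidden.
    perm = rng.permutation(n_nodes)
    X = nodes[:, perm[:n_features]]
    y_cont = nodes[:, perm[n_features]]

    # 4. Discretise the continuous target into two classes (assumed: median split).
    y = (y_cont > np.median(y_cont)).astype(int)
    return X, y

# Usage: vary the graph itself from dataset to dataset, as the prior does.
rng = np.random.default_rng(0)
datasets = [
    sample_scm_dataset(n_rows=100, n_features=int(rng.integers(2, 10)), rng=rng,
                       n_layers=int(rng.integers(1, 5)), width=int(rng.integers(8, 32)))
    for _ in range(3)
]
```

Sampling the depth, width and number of features per dataset is what gives the diversity of input counts and causal mechanisms mentioned above; a model trained across many such draws never sees the same task twice.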

In the experiments, several small tabular benchmarks are used to compare TabPFN against baselines, including standard methods such as Logistic Regression, XGBoost and CatBoost, as well as AutoML frameworks such as AutoGluon and Auto-sklearn. Despite requiring no training on the evaluation datasets, TabPFN outperforms the standard baselines and is competitive with the AutoML frameworks, with its advantage most pronounced once the time budget is taken into account.