Pure Transformers are Powerful Graph Learners

tl;dr

They model each node and edge of the graph as a token, learn an embedding for each, and feed all tokens to a standard Transformer (“TokenGT”). In addition to its node/edge features \mathbf X_i, each token is given a pair of orthonormal vectors as an identifier, [\mathbf P_u, \mathbf P_v]: for a node v both entries are the same vector \mathbf P_v, while for an edge they are the identifiers of its source and destination nodes. Since the \{\mathbf P_i\} are orthonormal, the inner product [\mathbf P_u, \mathbf P_v][\mathbf P_j, \mathbf P_k]^T = \mathbf P_u \mathbf P_j^T + \mathbf P_v \mathbf P_k^T is nonzero exactly when u=j or v=k (1 if one index matches, 2 if both) and 0 otherwise, so attention can tell which tokens touch the same node. They also learn a “type identifier” \mathbf E^\tau, an embedding for “node” vs. “edge”, that is appended to the token. The final token embedding is then the concatenation \mathbf X = [\mathbf X_i, \mathbf P_u, \mathbf P_v, \mathbf E^\tau].
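
A minimal numpy sketch of the token construction as I understand it (the dimensions, placeholder features, and variable names are mine, not from the paper’s code):

```python
import numpy as np

rng = np.random.default_rng(0)

n_nodes, d_feat, d_id, d_type = 4, 8, 4, 2   # d_id >= n_nodes so identifiers can be orthonormal
edges = [(0, 1), (1, 2), (2, 3)]             # only the observed edges get tokens

# Orthonormal node identifiers P (here built as orthogonal random features via QR).
Q, _ = np.linalg.qr(rng.normal(size=(d_id, n_nodes)))
P = Q.T                                      # (n_nodes, d_id), rows are orthonormal

X_node = rng.normal(size=(n_nodes, d_feat))      # placeholder node features
X_edge = rng.normal(size=(len(edges), d_feat))   # placeholder edge features
E_type = rng.normal(size=(2, d_type))            # "type identifier" embeddings: row 0 = node, row 1 = edge

# Node token for v: [X_v, P_v, P_v, E^node]; edge token for (u, v): [X_(u,v), P_u, P_v, E^edge]
node_tokens = [np.concatenate([X_node[v], P[v], P[v], E_type[0]]) for v in range(n_nodes)]
edge_tokens = [np.concatenate([X_edge[i], P[u], P[v], E_type[1]]) for i, (u, v) in enumerate(edges)]
tokens = np.stack(node_tokens + edge_tokens)     # (n_nodes + n_edges, d_feat + 2*d_id + d_type)

# Identifier inner products: 2 on the diagonal, 1 when exactly one index position matches
# (e.g. a node token against an edge token incident to it, since node tokens repeat P_v), 0 otherwise.
ids = tokens[:, d_feat:d_feat + 2 * d_id]
print(np.round(ids @ ids.T, 2))
```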

There is a significant theory contribution where “we prove that with appropriate token-wise embeddings, self-attention over node and edge tokens can approximate any permutation equivariant linear operator on a graph…This provides a solid theoretical guarantee that, with the embeddings and enough attention heads, a Transformer is at least as expressive as a second-order invariant graph network (2-IGN), which is already more expressive than all message-passing GNNs”. This mainly builds on the Invariant Graph Networks (IGN) work of Maron et al., which is based on “linear equivariant layers” (arXiv).

Notes

  • Section 5 shows that Laplacian eigenvector positional encodings (LPEs) work better than orthogonal random features (ORFs) as node identifiers; a rough sketch of both constructions is below.
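
Roughly, the two identifier options being compared look like this (my own minimal sketch, not the authors’ implementation; the QR construction for ORFs and the use of the symmetric normalized Laplacian are assumptions on my part):

```python
import numpy as np

def orf_identifiers(n_nodes, dim, rng):
    """Orthogonal random features: orthonormal rows from a QR decomposition of a
    random Gaussian matrix (requires dim >= n_nodes); resampled on every forward pass."""
    q, _ = np.linalg.qr(rng.normal(size=(dim, n_nodes)))
    return q.T                               # (n_nodes, dim)

def lap_identifiers(adj, dim):
    """Laplacian eigenvector PEs: each node's row of the eigenvector matrix of the
    symmetric normalized Laplacian; the sign/basis ambiguity of eigenvectors remains."""
    deg = adj.sum(1)
    d_inv_sqrt = 1.0 / np.sqrt(deg)          # assumes no isolated nodes
    lap = np.eye(len(adj)) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    _, vecs = np.linalg.eigh(lap)            # columns sorted by ascending eigenvalue
    return vecs[:, :dim]                     # (n_nodes, dim); rows orthonormal iff dim == n_nodes

rng = np.random.default_rng(0)
adj = np.array([[0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [0, 0, 1, 0]], dtype=float)  # path graph on 4 nodes

print(orf_identifiers(4, 4, rng))
print(lap_identifiers(adj, 4))
```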

I have this paper on my reading list, but I’m curious whether you have a TL;DR on how this compares to the “Global Self Attention as a Replacement for Graph Convolution” paper (which you had mentioned is a top performer on the leaderboards in your TL;DR of it).

Welcome! One key difference is that in the Global Self-Attention (GSA) work, they have a token for every possible edge (i.e., always N^2 tokens, where N is the number of nodes), and the attention mechanism is modified so that these node-pairs can restrict or enhance information flow. Presumably the idea is that observed edges can enhance information exchange while unobserved edges can restrict it without cutting it off entirely.

In contrast, this work only includes tokens for the nodes and node-pairs that are actually observed in the graph, and the Transformer’s attention mechanism is not modified. Nodes can still attend to any other node, and the orthonormal node identifiers help the model reason about the connection structure, but there are no explicit tokens for edges that are unobserved in the graph.

I hope this helps.

Absolutely, thanks! PS I found this interesting comparison of various graph transformers:
[2302.04181] Attending to Graph Transformers (arxiv.org)
(I found it by looking for papers citing this one.)
Cheers!

EDIT: wrong link fixed

Thanks! I summarized this paper here: Attending to Graph Transformers


In thinking about an implementation of this method, it occurred to me that the randomness in the positional encodings (from either the orthogonal random feature method, or the sign-flip and basis ambiguity of Laplacian eigenvectors) means that multiple forward passes over the same graph would generate different predictions. How, then, to manage this variance in a production setting? I asked one of the authors about this and will paste the exchange here.

Question:

I’m playing with transformers on graphs and had a question about your approach of using ORFs. Won’t the output from your model be sensitive to these values, meaning that as you evaluate the same graph with different ORFs, the output will change? How do you deal with this, particularly in evaluation? Did you study the sensitivity to this by, e.g., looking at variance of predictions when different ORFs are sampled? I suppose this is also a problem with sign-flipping LPEs. Thanks for any insight.

Response:

Hi, thank you for your interest in our work. You are correct: with ORFs (or LPEs with test-time sign flipping), the model’s predictions should be interpreted as realizations (samples) of a random variable. That said, what we were doing in the paper is actually a 1-sample MC estimate of the mean of this random variable. More samples (k random forward passes, then ensembling the outputs) would reduce the variance of this estimate, but we didn’t look deeper into variance reduction for ORF models in the paper (partially because LPE worked better in practice than ORF).

After follow-up investigations, we concluded that we cannot fully avoid stochastic predictions (as in ORF) for pure-transformer-style graph models, as they always require random tie-breaking between automorphic nodes. So in follow-up research, we built a refined version of TokenGT (kind of) where the stochastic nature of the model is made more explicit.

In the experiments for this follow-up work (LPS), we observed that around 10-20 samples during inference is enough to obtain low-variance predictions. Although the model design is different, I weakly conjecture that a similar number of samples would work reasonably well for TokenGT with ORF.

Just saw your tweet from 2 hours ago: yes, this stochasticity also implicitly applies to Laplacian eigenvectors with repeated eigenvalues. So in this case, TokenGT-Lap would also be a stochastic predictor up to the internal random tie-breaking. The good news is that the expectation of this stochastic prediction would still be permutation invariant/equivariant. This paper has relevant discussions that one might find interesting.

For context, the tweet referenced is one of mine where I lament the discovery that Laplacian eigenvectors are not permutation equivariant if there is redundancy in the eigenvalues. Several experts chimed in with interesting follow-up works and explanations.
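
For anyone implementing this: the k-sample ensembling the author describes might look something like the following sketch, where `model` and `build_tokens` are hypothetical placeholders for your forward pass and your (stochastic) tokenization:

```python
import numpy as np

def predict_mc(model, graph, build_tokens, k=20, rng=None):
    """Monte Carlo estimate of the mean prediction: run k forward passes, each with
    freshly sampled node identifiers (ORFs, or Laplacian PEs with random sign flips /
    tie-breaking), then average the outputs. `model` and `build_tokens` are placeholders."""
    rng = rng or np.random.default_rng()
    outs = np.stack([model(build_tokens(graph, rng)) for _ in range(k)])
    return outs.mean(axis=0), outs.std(axis=0)   # mean prediction plus per-output spread as a sanity check
```

Watching the returned spread shrink as k grows is a quick way to decide how many samples your deployment actually needs.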