Population Transformer (PopT)

*Equal contribution. ¹California Institute of Technology, ²Massachusetts Institute of Technology
ICLR 2025 Oral Presentation

We present the Population Transformer (PopT), a technique for learning population-level codes from arbitrary ensembles of neural recordings at scale.

PopT Method

(a) The Population Transformer (PopT) stacks on top of pretrained temporal embeddings (shown: BrainBERT in red dotted outline) and enhances downstream decoding by enabling learned aggregation of multiple spatially-sparse data channels. The model is trained using a self-supervised loss that optimizes both (b) ensemble-level and (c) channel-level objectives. This allows PopT to learn robust spatial and temporal representations that generalize across subjects and datasets.
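
To make the aggregation step concrete, here is a minimal PyTorch sketch of the idea: per-channel temporal embeddings are combined with electrode position information and pooled through a transformer with a learnable [CLS] token. The module name, dimensions, and the linear projection of 3-D electrode coordinates are illustrative assumptions, not the released implementation.

```python
import torch
import torch.nn as nn

class PopTSketch(nn.Module):
    """Illustrative sketch: aggregate per-channel temporal embeddings with a
    transformer and a learnable [CLS] token (names and dims are assumptions)."""

    def __init__(self, embed_dim=768, n_layers=6, n_heads=8):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Map 3-D electrode coordinates into the embedding space so the model
        # knows where each channel sits (one plausible choice, not the only one).
        self.pos_proj = nn.Linear(3, embed_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=n_heads, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, channel_emb, coords, pad_mask=None):
        # channel_emb: (B, C, D) temporal embeddings (e.g. from BrainBERT)
        # coords:      (B, C, 3) electrode positions
        # pad_mask:    (B, C) bool, True where a channel slot is padding
        x = channel_emb + self.pos_proj(coords)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1)                    # prepend [CLS]
        if pad_mask is not None:
            cls_pad = torch.zeros_like(pad_mask[:, :1])   # never mask [CLS]
            pad_mask = torch.cat([cls_pad, pad_mask], dim=1)
        out = self.encoder(x, src_key_padding_mask=pad_mask)
        return out[:, 0], out[:, 1:]                      # ensemble, per-channel
```

In this sketch, the [CLS] output plays the role of the ensemble-level representation (used by the ensemble-level objective and downstream decoding), while the per-channel outputs support channel-level objectives and analyses.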


We address key challenges in scaling models with neural time-series data, namely the sparse and variable electrode distributions across subjects and datasets. The pretrained PopT lowers the amount of data required for downstream decoding while increasing accuracy, even on held-out subjects and tasks. Our framework generalizes across multiple time-series embeddings and neural data modalities, suggesting broad applicability of our technique.
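
As a concrete illustration of what "arbitrary ensembles" means in practice, the hypothetical collate function below pads recordings with different channel counts into a single batch and builds the padding mask consumed by the sketch above; the function name and tensor layout are assumptions, not part of the released code.

```python
import torch

def collate_ensembles(examples):
    """Pad variable-size channel ensembles into one batch (illustrative).

    examples: list of (channel_emb, coords) pairs with shapes (C_i, D) and
    (C_i, 3), where the channel count C_i differs across subjects/sessions.
    """
    max_c = max(emb.size(0) for emb, _ in examples)
    dim = examples[0][0].size(1)
    batch_emb = torch.zeros(len(examples), max_c, dim)
    batch_xyz = torch.zeros(len(examples), max_c, 3)
    pad_mask = torch.ones(len(examples), max_c, dtype=torch.bool)  # True = padding
    for i, (emb, xyz) in enumerate(examples):
        c = emb.size(0)
        batch_emb[i, :c] = emb
        batch_xyz[i, :c] = xyz
        pad_mask[i, :c] = False
    return batch_emb, batch_xyz, pad_mask
```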

Decoding performance

Pretraining PopT on large amounts of neural data significantly improves downstream decoding performance compared to non-pretrained baseline aggregation methods.

We see this across tasks (x-axis), data modalities (a: iEEG, b: EEG), and temporal embedding types (hatching):

PopT vs Baseline Performance across datasets, modalities, tasks, and temporal encoding type



and across channel ensemble sizes (x-axis):

PopT Performance
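
For the downstream decoding experiments compared above, one standard recipe, sketched here only as an assumption about the setup, is to train a lightweight classification head on the aggregated [CLS] representation, with the pretrained PopT weights either frozen or fine-tuned.

```python
import torch.nn as nn

class DecodingProbe(nn.Module):
    """Hypothetical decoding head on top of the PopTSketch defined earlier:
    classify a task label from the pooled [CLS] representation."""

    def __init__(self, popt, embed_dim=768, n_classes=2, freeze_backbone=True):
        super().__init__()
        self.popt = popt
        if freeze_backbone:
            for p in self.popt.parameters():
                p.requires_grad = False
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, channel_emb, coords, pad_mask=None):
        ensemble, _ = self.popt(channel_emb, coords, pad_mask)
        return self.head(ensemble)
```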

Sample efficiency

Pretraining PopT improves sample efficiency, requiring fewer labeled examples (x-axis) for strong performance:

PopT Sample Efficiency

Subject generalizability

Gains in decoding performance transfer to new (held-out) subjects:

PopT Subject Generalizability

Pretraining data scaling

Increasing the amount of pretraining data (colors) leads to improvements in downstream decoding performance:

PopT Pretraining Benefits

Interpretability

We explore how a pretrained PopT can generate insights into neural data.

We can recover connectivity maps from the pretrained model (right) and compare them with coherence analysis (left):

PopT connectivity map
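
One way such a map can be read out, sketched here under the assumption that connectivity is derived from the encoder's self-attention, is to collect the channel-to-channel attention weights of a chosen layer in the PopTSketch defined earlier:

```python
import torch

@torch.no_grad()
def layer_attention(popt, channel_emb, coords, layer_idx=0):
    """Illustrative: return the (B, 1+C, 1+C) self-attention weights of one
    encoder layer in the PopT sketch above ([CLS] token at index 0)."""
    x = channel_emb + popt.pos_proj(coords)
    x = torch.cat([popt.cls_token.expand(x.size(0), -1, -1), x], dim=1)
    layer_idx = layer_idx % len(popt.encoder.layers)
    for i, layer in enumerate(popt.encoder.layers):
        if i == layer_idx:
            # Re-apply the layer's attention module to its input so the
            # head-averaged weights are returned explicitly.
            _, weights = layer.self_attn(x, x, x, need_weights=True)
            return weights
        x = layer(x)

# Channel-to-channel "connectivity": drop the [CLS] row and column.
# connectivity = layer_attention(popt, channel_emb, coords)[:, 1:, 1:]
```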


We can identify functional brain regions by analyzing the [CLS] token's attention weights, which concentrate on auditory and language regions (arrows):

PopT attention weights
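
A matching sketch for the region analysis, again assuming the scores come from self-attention, reads off how strongly the [CLS] token attends to each channel (reusing `layer_attention` from above, with `popt`, `channel_emb`, and `coords` as in the earlier sketches); averaging these scores over a dataset and projecting them onto electrode coordinates is one way to highlight the regions the model relies on.

```python
# [CLS] -> channel attention in the last encoder layer, one score per channel.
cls_scores = layer_attention(popt, channel_emb, coords, layer_idx=-1)[:, 0, 1:]
region_score = cls_scores.mean(dim=0)   # (C,) average over the batch
```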

BibTeX

@misc{chau2024populationtransformer,
  title={Population Transformer: Learning Population-level Representations of Neural Activity},
  author={Geeling Chau and Christopher Wang and Sabera Talukder and Vighnesh Subramaniam and Saraswati Soedarmadji and Yisong Yue and Boris Katz and Andrei Barbu},
  year={2024},
  eprint={2406.03044},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2406.03044},
}