# Language Model SAEs
Welcome to the documentation for Language Model SAEs - a library for training and analyzing Sparse Autoencoders (SAEs) on language models.
## Overview
Sparse Autoencoders (SAEs) are neural networks used to extract interpretable features from language models. They help address the superposition problem in neural networks by decomposing model activations into sparse, interpretable representations.
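As a minimal illustration of the core idea (a generic sketch in plain PyTorch, not this library's implementation), a vanilla SAE encodes an activation vector into a much wider sparse code and reconstructs it, trading reconstruction error against sparsity:

```python
import torch
import torch.nn as nn


class TinySAE(nn.Module):
    """Minimal vanilla SAE sketch: d_model -> d_sae -> d_model with ReLU sparsity."""

    def __init__(self, d_model: int, expansion_factor: int = 8):
        super().__init__()
        d_sae = d_model * expansion_factor
        self.encoder = nn.Linear(d_model, d_sae)
        self.decoder = nn.Linear(d_sae, d_model)

    def forward(self, x: torch.Tensor):
        f = torch.relu(self.encoder(x))  # sparse, non-negative feature activations
        x_hat = self.decoder(f)          # reconstruction of the input activations
        return x_hat, f


sae = TinySAE(d_model=16)
x = torch.randn(4, 16)          # a batch of (synthetic) model activations
x_hat, f = sae(x)
# Training minimizes reconstruction error plus a sparsity penalty, e.g. L1:
loss = (x - x_hat).pow(2).mean() + 1e-3 * f.abs().mean()
```

Real SAE variants differ mainly in the activation function (e.g. TopK instead of ReLU) and the sparsity penalty, as described below.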
This library provides:
- Scalability: Our framework is fully distributed with arbitrary combinations of data, model, and head parallelism for both training and analysis. Enjoy training SAEs with millions of features!
- Flexibility: We support a wide range of SAE variants, including vanilla SAEs, Lorsa (Low-rank Sparse Attention), CLT (Cross-layer Transcoder), MoLT (Mixture of Linear Transforms), CrossCoder, and more. Each variant can be combined with different activation functions (e.g., ReLU, JumpReLU, TopK, BatchTopK) and sparsity penalties (e.g., L1, Tanh).
- Easy to Use: We provide high-level `runners` APIs to quickly launch experiments with simple configurations. Check our examples for verified hyperparameters.
- Visualization: We provide a unified web interface to visualize learned SAE variants and their features.
## Quick Start
### Installation
We strongly recommend using uv for dependency management. uv is a modern drop-in replacement for poetry or pdm, with lightning-fast dependency resolution and package installation. See their instructions on how to initialize a Python project with uv.
To add our library as a project dependency, run:
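Assuming the package is published on PyPI under the name matching its import (`lm_saes`, i.e. `lm-saes`), the command would be:

```shell
uv add lm-saes
```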
We also support Ascend NPU as an accelerator backend. To add our library as a project dependency with NPU dependency constraints, run:
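The exact name of the NPU constraint set is not shown in this document; assuming a hypothetical `npu` extra on the `lm-saes` package, the command might look like:

```shell
uv add "lm-saes[npu]"
```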
Of course, you can also directly use pip to install our library. To install our library with pip, run:
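Again assuming the PyPI package name `lm-saes`:

```shell
pip install lm-saes
```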
We also support Ascend NPU as an accelerator backend. To install our library with NPU dependency constraints, run:
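As above, the NPU extra name is an assumption (a hypothetical `npu` extra):

```shell
pip install "lm-saes[npu]"
```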
### Load a trained Sparse Autoencoder from HuggingFace
Load any Sparse Autoencoder or other sparse dictionary in Language-Model-SAEs or SAELens format.
```python
from lm_saes import AbstractSparseAutoEncoder

# Load Gemma Scope 2 SAE
sae = AbstractSparseAutoEncoder.from_pretrained(
    "gemma-scope-2-1b-pt-res-all:layer_12_width_16k_l0_small"
)
```
### Training a Sparse Autoencoder
To train a simple Sparse Autoencoder on `blocks.6.hook_resid_post` of a Pythia-160M model with \(768 \times 8\) features, you can use the following:
```python
import torch

from lm_saes import (
    ActivationFactoryConfig,
    ActivationFactoryDatasetSource,
    ActivationFactoryTarget,
    BufferShuffleConfig,
    DatasetConfig,
    InitializerConfig,
    LanguageModelConfig,
    SAEConfig,
    TrainerConfig,
    TrainSAESettings,
    WandbConfig,
    train_sae,
)

settings = TrainSAESettings(
    sae=SAEConfig(
        hook_point_in="blocks.6.hook_resid_post",
        hook_point_out="blocks.6.hook_resid_post",
        d_model=768,
        expansion_factor=8,
        act_fn="topk",
        top_k=50,
        dtype=torch.float32,
        device="cuda",
    ),
    initializer=InitializerConfig(
        grid_search_init_norm=True,
    ),
    trainer=TrainerConfig(
        amp_dtype=torch.float32,
        lr=1e-4,
        initial_k=50,
        k_warmup_steps=0.1,
        k_schedule_type="linear",
        total_training_tokens=800_000_000,
        log_frequency=1000,
        eval_frequency=1000000,
        n_checkpoints=0,
        check_point_save_mode="linear",
        exp_result_path="results",
    ),
    model=LanguageModelConfig(
        model_name="EleutherAI/pythia-160m",
        device="cuda",
        dtype="torch.float16",
    ),
    model_name="pythia-160m",
    datasets={
        "SlimPajama-3B": DatasetConfig(
            dataset_name_or_path="Hzfinfdu/SlimPajama-3B",
        )
    },
    wandb=WandbConfig(
        wandb_project="lm-saes",
        exp_name="pythia-160m-sae",
    ),
    activation_factory=ActivationFactoryConfig(
        sources=[
            ActivationFactoryDatasetSource(
                name="SlimPajama-3B",
            )
        ],
        target=ActivationFactoryTarget.ACTIVATIONS_1D,
        hook_points=["blocks.6.hook_resid_post"],
        batch_size=4096,
        buffer_size=4096 * 4,
        buffer_shuffle=BufferShuffleConfig(
            perm_seed=42,
            generator_device="cuda",
        ),
    ),
    sae_name="pythia-160m-sae",
    sae_series="pythia-sae",
)
train_sae(settings)
```
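The configuration above sets `act_fn="topk"`, meaning only the k largest feature pre-activations per token are kept. A generic sketch of a TopK activation (an illustration, not this library's implementation) is:

```python
import torch


def topk_activation(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest pre-activations per row; zero out the rest."""
    values, indices = pre_acts.topk(k, dim=-1)
    out = torch.zeros_like(pre_acts)
    # TopK is commonly combined with ReLU so kept features stay non-negative
    out.scatter_(-1, indices, torch.relu(values))
    return out


x = torch.randn(2, 10)
y = topk_activation(x, k=3)  # at most 3 nonzero entries per row
```

Because sparsity is enforced structurally, a TopK SAE needs no L1 penalty; the `initial_k`/`k_warmup_steps` settings in `TrainerConfig` control how k is annealed toward `top_k` during training.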
### Analyze a trained Sparse Autoencoder
Requires a running MongoDB instance. See analyze-saes for details.
```python
from lm_saes import (
    ActivationFactoryConfig,
    ActivationFactoryDatasetSource,
    ActivationFactoryTarget,
    AnalyzeSAESettings,
    DatasetConfig,
    FeatureAnalyzerConfig,
    LanguageModelConfig,
    MongoDBConfig,
    PretrainedSAE,
    analyze_sae,
)

settings = AnalyzeSAESettings(
    sae=PretrainedSAE(pretrained_name_or_path="path/to/sae", device="cuda"),
    sae_name="pythia-160m-sae",
    activation_factory=ActivationFactoryConfig(
        sources=[ActivationFactoryDatasetSource(name="SlimPajama-3B")],
        target=ActivationFactoryTarget.ACTIVATIONS_2D,
        hook_points=["blocks.6.hook_resid_post"],
        batch_size=16,
        context_size=2048,
    ),
    model=LanguageModelConfig(model_name="EleutherAI/pythia-160m", device="cuda"),
    model_name="pythia-160m",
    datasets={"SlimPajama-3B": DatasetConfig(dataset_name_or_path="Hzfinfdu/SlimPajama-3B")},
    analyzer=FeatureAnalyzerConfig(total_analyzing_tokens=100_000_000),
    mongo=MongoDBConfig(),
    device_type="cuda",
)
analyze_sae(settings)
```
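Conceptually, feature analysis collects statistics such as each feature's top-activating tokens over a corpus. A toy sketch of that computation on synthetic data (an illustration, not this library's analyzer) is:

```python
import torch

# Synthetic feature activations: (n_tokens, n_features)
torch.manual_seed(0)
acts = torch.rand(100, 4)
tokens = [f"tok{i}" for i in range(100)]

# For each feature, record its 3 highest-activating tokens
top_values, top_indices = acts.topk(3, dim=0)  # both have shape (3, 4)
top_tokens = [
    [tokens[i] for i in top_indices[:, feature]]
    for feature in range(acts.shape[1])
]
```

The real analyzer streams activations from the `activation_factory`, stores results in MongoDB, and the web interface renders them per feature.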
### Convert a trained Sparse Autoencoder to SAELens format
Requires the `sae_lens` package to be installed. Supports ReLU, JumpReLU, and TopK SAEs.
```python
from lm_saes import SparseAutoEncoder

sae = SparseAutoEncoder.from_pretrained("path/to/sae")
sae_saelens = sae.to_saelens(model_name="pythia-160m")
```
You can then use `sae_saelens` with any tools compatible with SAELens.
## Citation
If you find this library useful in your research, please cite: