Train Sparse Autoencoders
Language-Model-SAEs provides a general way to train, analyze and visualize Sparse Autoencoders and their variants. To help you get started quickly, we've included example scripts that guide you through each stage of working with SAEs. This guide begins with a foundational example and progressively introduces the core features and capabilities of the library.
Training Basic Sparse Autoencoders
A Sparse Autoencoder is trained to reconstruct model activations at specific positions. We rely on TransformerLens to extract activations from the model forward pass, specified by hook points. Language-Model-SAEs provides a complete abstraction over the components necessary to train Sparse Autoencoders with ease.
Load Model & Dataset
Generating our training data, the model activations, requires both a language model and a dataset. First, we load a pretrained language model:
```python
from lm_saes import LanguageModelConfig, TransformerLensLanguageModel

model = TransformerLensLanguageModel(
    LanguageModelConfig(
        model_name="EleutherAI/pythia-160m",
        device="cuda",
        dtype="torch.float16",
    )
)
```
where TransformerLensLanguageModel is a simple wrapper around the TransformerLens HookedTransformer, enhanced with:
- Unified interface for extracting activations (compatible with native HuggingFace transformers) and tracing token positions from original texts.
- Distributed training support with simple data parallelism integrated.
See the Use HuggingFace Backend section below to learn how to use HuggingFace transformers directly for generating activations.
Use Half Precision
Activation generation constitutes the majority of training time. We strongly recommend using half precision (float16 or bfloat16) to accelerate the forward pass and reduce GPU memory usage. Here we use FP16 since Pythia models are trained in FP16.
Next, we load a text corpus from HuggingFace. The choice of pretraining corpus usually has little effect on SAE training. Here we load Hzfinfdu/SlimPajama-3B, a 3B-token subset of the 627B-token SlimPajama dataset, which is typically sufficient for basic SAE training.
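For example, with the HuggingFace `datasets` library (the same call appears in the full script later in this guide):

```python
import datasets

# Load the training corpus from the HuggingFace Hub
dataset = datasets.load_dataset(
    "Hzfinfdu/SlimPajama-3B",
    split="train",
)
```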
Generate Activations
Model activations often require some further transformation to ensure correct and efficient SAE training. We provide ActivationFactory as the core abstraction for producing activation streams. It provides a comprehensive interface to generate activations from model forward passes, filter unnecessary tokens, and reshape, re-batch, and shuffle activations.
We can create an ActivationFactory as follows:
```python
from lm_saes import (
    ActivationFactory,
    ActivationFactoryConfig,
    ActivationFactoryDatasetSource,
    ActivationFactoryTarget,
    BufferShuffleConfig,
)

activation_factory = ActivationFactory(
    ActivationFactoryConfig(
        sources=[ActivationFactoryDatasetSource(name="SlimPajama-3B")],
        target=ActivationFactoryTarget.ACTIVATIONS_1D,
        hook_points=["blocks.6.hook_resid_post"],
        batch_size=4096,
        buffer_size=4096 * 4,  # Set to enable online activation shuffling
        buffer_shuffle=BufferShuffleConfig(
            perm_seed=42,
            generator_device="cuda",
        ),
    )
)
```
Then, we call the .process method, passing in our loaded model and dataset, to start streaming the processed activations.
```python
activations_stream = activation_factory.process(
    model=model,
    model_name="pythia-160m",
    datasets={"SlimPajama-3B": dataset},
)
```
It returns a streaming iterator. Each item is a dictionary mainly containing:
| Key | Description |
|---|---|
| `"blocks.6.hook_resid_post"` | Activation tensor with shape `(batch_size, d_model)`; in this config, `(4096, 768)` |
| `"tokens"` | The corresponding token IDs |
With target set to ActivationFactoryTarget.ACTIVATIONS_1D, the produced activations will have no sequence dimension. They are shuffled across both samples and context positions to ensure the SAE trains on randomly sampled activations from any position in any sample.
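For a quick sanity check, you can pull a single batch from the stream and inspect its contents, matching the table above (note that this consumes one batch, so it is best done in a throwaway session):

```python
batch = next(iter(activations_stream))
print(batch["blocks.6.hook_resid_post"].shape)  # torch.Size([4096, 768])
print(batch["tokens"].shape)  # corresponding token IDs
```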
Create and Initialize SAE
We've successfully prepared the data we need. It's time to turn to the SAE itself! But before training, we should first define the SAE architecture and initialize it. Create an SAEConfig to define the SAE architecture:
```python
import torch

from lm_saes import SAEConfig

sae_cfg = SAEConfig(
    hook_point_in="blocks.6.hook_resid_post",
    hook_point_out="blocks.6.hook_resid_post",
    d_model=768,
    expansion_factor=8,
    act_fn="relu",
    dtype=torch.float32,
    device="cuda",
)
```
Here are brief explanations of the config fields we set:
| Parameter | Description |
|---|---|
| `hook_point_in` / `hook_point_out` | When identical, this defines an SAE; when different, it becomes a transcoder |
| `d_model` | Must match the model's hidden size (768 for Pythia-160m) |
| `expansion_factor` | Multiplier for the latent dimension. Here, `d_sae = 768 × 8 = 6144` |
| `act_fn` | Activation function. Modern SAEs often use `"jumprelu"` or `"batchtopk"`, but we use `"relu"` for simplicity |
More SAEConfig options are described in the reference.
With only the SAEConfig defined, the created SAE will contain nothing but empty tensors as parameters. We need to fill these empty parameters with a proper initialization, which often proves crucial for the final SAE performance. The Initializer class handles parameter initialization. The grid_search_init_norm option (recommended) searches for the optimal encoder/decoder parameter scale that minimizes the initial MSE loss on the activation distribution.
```python
from lm_saes import Initializer, InitializerConfig

initializer = Initializer(InitializerConfig(grid_search_init_norm=True))
sae = initializer.initialize_sae_from_config(
    sae_cfg,
    activation_stream=activations_stream,
)
```
Train SAE
Finally, we can start training! A Trainer instance is responsible for holding optimizer & scheduler states.
```python
from lm_saes import Trainer, TrainerConfig

trainer = Trainer(
    TrainerConfig(
        amp_dtype=torch.float32,
        lr=1e-4,
        total_training_tokens=800_000_000,
        log_frequency=1000,
        exp_result_path="results",
    )
)
```
| Parameter | Description |
|---|---|
| `amp_dtype` | Mixed precision dtype. Also handles precision mismatches between SAE parameters and activations |
| `lr` | Learning rate |
| `total_training_tokens` | Total tokens for training. Training steps = total tokens / batch size |
| `log_frequency` | Logging interval (in steps) for console and W&B |
| `exp_result_path` | Directory for saving results and checkpoints |
More options on the optimizer/scheduler and other hyperparameters are available. See the reference for more detail.
Just run trainer.fit, passing in the initialized SAE and the activation stream, and keep an eye on the console log to check that training is going well!
```python
sae.cfg.save_hyperparameters("results")  # Save hyperparameters before training

trainer.fit(
    sae=sae,
    activation_stream=activations_stream,
)

sae.save_pretrained(save_path="results")  # Save the trained weights after training
```
Consistent Save Path
The path passed to sae.cfg.save_hyperparameters and sae.save_pretrained should be the same as the exp_result_path specified in TrainerConfig. Otherwise, the trained SAE may not load correctly.
Using the High-Level Runner API
For a more streamlined experience, Language-Model-SAEs also provides a high-level train_sae function that bundles all configuration into a single TrainSAESettings object. You can programmatically create the settings object and call train_sae, or use a configuration file and run it with our lm-saes CLI:
Create the TrainSAESettings in Python and call train_sae with it.
```python
import torch

from lm_saes import (
    TrainSAESettings,
    train_sae,
    SAEConfig,
    InitializerConfig,
    TrainerConfig,
    LanguageModelConfig,
    DatasetConfig,
    ActivationFactoryConfig,
    ActivationFactoryDatasetSource,
    ActivationFactoryTarget,
    BufferShuffleConfig,
)

settings = TrainSAESettings(
    sae=SAEConfig(
        hook_point_in="blocks.6.hook_resid_post",
        hook_point_out="blocks.6.hook_resid_post",
        d_model=768,
        expansion_factor=8,
        act_fn="relu",
        dtype=torch.float32,
        device="cuda",
    ),
    initializer=InitializerConfig(
        grid_search_init_norm=True,
    ),
    trainer=TrainerConfig(
        amp_dtype=torch.float32,
        lr=1e-4,
        total_training_tokens=800_000_000,
        log_frequency=1000,
        exp_result_path="results",
    ),
    model=LanguageModelConfig(
        model_name="EleutherAI/pythia-160m",
        device="cuda",
        dtype="torch.float16",
    ),
    model_name="pythia-160m",
    datasets={
        "SlimPajama-3B": DatasetConfig(
            dataset_name_or_path="Hzfinfdu/SlimPajama-3B",
        )
    },
    activation_factory=ActivationFactoryConfig(
        sources=[
            ActivationFactoryDatasetSource(
                name="SlimPajama-3B",
            )
        ],
        target=ActivationFactoryTarget.ACTIVATIONS_1D,
        hook_points=["blocks.6.hook_resid_post"],
        batch_size=4096,
        buffer_size=4096 * 4,
        buffer_shuffle=BufferShuffleConfig(
            perm_seed=42,
            generator_device="cuda",
        ),
    ),
    sae_name="pythia-160m-sae",
    sae_series="pythia-sae",
)

train_sae(settings)
```
The CLI-based workflow requires a configuration file containing settings consistent with TrainSAESettings. Common configuration file types like TOML, JSON, and YAML are supported.
Create a TOML configuration file (e.g., train_config.toml) with the following content:
```toml
sae_name = "pythia-160m-sae"
sae_series = "pythia-sae"
model_name = "pythia-160m"
device_type = "cuda"

[sae]
sae_type = "sae"
hook_point_in = "blocks.6.hook_resid_post"
hook_point_out = "blocks.6.hook_resid_post"
d_model = 768
expansion_factor = 8
act_fn = "relu"
dtype = "torch.float32"
device = "cuda"

[initializer]
grid_search_init_norm = true

[trainer]
amp_dtype = "torch.float32"
lr = 0.0001
total_training_tokens = 800_000_000
log_frequency = 1000
exp_result_path = "results"

[model]
model_name = "EleutherAI/pythia-160m"
device = "cuda"
dtype = "torch.float16"

[datasets."SlimPajama-3B"]
dataset_name_or_path = "Hzfinfdu/SlimPajama-3B"

[activation_factory]
target = "activations-1d"
hook_points = ["blocks.6.hook_resid_post"]
batch_size = 4096
buffer_size = 16384

[[activation_factory.sources]]
type = "dataset"
name = "SlimPajama-3B"

[activation_factory.buffer_shuffle]
perm_seed = 42
generator_device = "cuda"
```
Then run the training with:
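The exact invocation depends on your installed version; assuming the CLI exposes a train subcommand as described above (check `lm-saes --help` for the actual usage), it looks like:

```bash
lm-saes train train_config.toml
```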
We also recommend directly using the low-level APIs to launch training, which allows more granular control and easier customization:
```python
import datasets
import torch

from lm_saes import (
    ActivationFactory,
    ActivationFactoryConfig,
    ActivationFactoryDatasetSource,
    ActivationFactoryTarget,
    BufferShuffleConfig,
    Initializer,
    InitializerConfig,
    LanguageModelConfig,
    SAEConfig,
    Trainer,
    TrainerConfig,
    TransformerLensLanguageModel,
)

# 1. Load Model & Dataset
model = TransformerLensLanguageModel(
    LanguageModelConfig(
        model_name="EleutherAI/pythia-160m",
        device="cuda",
        dtype="torch.float16",
    )
)
dataset = datasets.load_dataset(
    "Hzfinfdu/SlimPajama-3B",
    split="train",
)

# 2. Generate Activations
activation_factory = ActivationFactory(
    ActivationFactoryConfig(
        sources=[ActivationFactoryDatasetSource(name="SlimPajama-3B")],
        target=ActivationFactoryTarget.ACTIVATIONS_1D,
        hook_points=["blocks.6.hook_resid_post"],
        batch_size=4096,
        buffer_size=4096 * 4,
        buffer_shuffle=BufferShuffleConfig(
            perm_seed=42,
            generator_device="cuda",
        ),
    )
)
activations_stream = activation_factory.process(
    model=model,
    model_name="pythia-160m",
    datasets={"SlimPajama-3B": dataset},
)

# 3. Create and Initialize SAE
sae_cfg = SAEConfig(
    hook_point_in="blocks.6.hook_resid_post",
    hook_point_out="blocks.6.hook_resid_post",
    d_model=768,
    expansion_factor=8,
    act_fn="relu",
    dtype=torch.float32,
    device="cuda",
)
initializer = Initializer(InitializerConfig(grid_search_init_norm=True))
sae = initializer.initialize_sae_from_config(
    sae_cfg,
    activation_stream=activations_stream,
)

# 4. Train SAE
trainer = Trainer(
    TrainerConfig(
        amp_dtype=torch.float32,
        lr=1e-4,
        total_training_tokens=800_000_000,
        log_frequency=1000,
        exp_result_path="results",
    )
)
sae.cfg.save_hyperparameters("results")
trainer.fit(
    sae=sae,
    activation_stream=activations_stream,
)
sae.save_pretrained(save_path="results")
```
Logging to W&B
Aside from the console logger, we support logging to Weights & Biases for tracking losses and metrics throughout training. Training metrics, including explained variance and \(L^0\) norm, are automatically recorded. Below is a screenshot of the W&B logging:
Screenshot of W&B logging in training our LlamaScope 2 Beta PLTs.
To enable W&B logging, add the wandb configuration to your training setup; see the reference for the available fields.
Checkpoints and Continue Training
WIP
Activation Functions
Activation functions are the primary architectural mechanism for enforcing sparse feature activations in SAEs and their variants.
ReLU
ReLU is the most classical activation function, proposed in the initial works (Sparse Autoencoders Find Highly Interpretable Features in Language Models and Towards Monosemanticity: Decomposing Language Models With Dictionary Learning) that used SAEs to disentangle superposition. Though its performance is inferior to other activation functions in terms of explained variance and \(L^0\) norm, it may be a good starting point for understanding how SAEs work due to its simplicity.
To use the ReLU activation function, just set act_fn = "relu" in SAEConfig.
JumpReLU
JumpReLU is a state-of-the-art activation function proposed in Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders, and adopted by Google DeepMind's GemmaScope and GemmaScope 2, as well as Anthropic's Cross Layer Transcoder.
JumpReLU modifies the ReLU activation function, allowing only elements that pass their corresponding element-wise thresholds to activate. For an input element \(x\) and a log-threshold \(t\), it computes:

\[ \operatorname{JumpReLU}(x) = H\left(x - e^{t}\right)\, x \]

where \(H(\cdot)\) is the Heaviside step function[^1]. For comparison, ReLU can be written as \(\operatorname{ReLU}(x) = H(x)x\).
Since the Heaviside step function is not differentiable, JumpReLU uses a straight-through estimator for the gradient through the discontinuity of the nonlinearity. See Anthropic's Circuits Update - January 2025 to learn how the (log) JumpReLU thresholds are optimized.
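To make the straight-through trick concrete, here is a minimal sketch, assuming a rectangle kernel around the threshold; this is illustrative, not the library's internal implementation, and `bandwidth` here is only loosely analogous to the jumprelu_threshold_window option:

```python
import torch


class JumpReLUSTE(torch.autograd.Function):
    """Illustrative JumpReLU with a straight-through estimator (STE),
    assuming a rectangle kernel of width `bandwidth` around the threshold.
    Not the library's internal implementation."""

    @staticmethod
    def forward(ctx, x, log_threshold, bandwidth):
        threshold = log_threshold.exp()
        ctx.save_for_backward(x, threshold)
        ctx.bandwidth = bandwidth
        # Only elements above their per-feature threshold pass through.
        return x * (x > threshold)

    @staticmethod
    def backward(ctx, grad_output):
        x, threshold = ctx.saved_tensors
        # Exact gradient w.r.t. x away from the discontinuity.
        grad_x = grad_output * (x > threshold)
        # STE: only pre-activations within `bandwidth` of the threshold
        # contribute gradient to the threshold.
        near = ((x - threshold).abs() < ctx.bandwidth / 2).to(x.dtype)
        grad_threshold = -(threshold / ctx.bandwidth) * near * grad_output
        # Chain rule through threshold = exp(log_threshold); sum over batch.
        grad_log_threshold = (grad_threshold * threshold).sum(dim=0)
        return grad_x, grad_log_threshold, None


# Usage: feature_acts = JumpReLUSTE.apply(pre_acts, log_threshold, 1e-3)
```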
To use the JumpReLU activation function, set act_fn = "jumprelu" in SAEConfig. You may also adjust the jumprelu_threshold_window to control the sensitivity of how JumpReLU thresholds update.
Dedicated Learning Rate for Log JumpReLU Thresholds
In our crosscoder training experiments in Evolution of Concepts in Language Model Pre-Training, we found it better to apply a smaller learning rate (0.1x) to the log JumpReLU thresholds. Though this setting hardly affects the final reconstruction and sparsity performance, it makes the training loss far smoother. The mean feature activation also becomes lower after the change.
TopK
TopK is an activation function proposed in Scaling and evaluating sparse autoencoders. It keeps only the \(k\) largest elements, zeroing out the rest, thus directly enforcing strict sparsity on feature activations. This removes the need for additional sparsity penalties (typically required for ReLU & JumpReLU activations) and enables direct control over the sparsity level (\(L^0\)) of feature activations.
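In code, the operation is a per-sample top-k selection; a minimal sketch (illustrative, not the library's implementation):

```python
import torch


def topk_activation(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest pre-activations per sample and zero out the rest.
    Illustrative sketch; not the library's internal implementation."""
    values, indices = pre_acts.topk(k, dim=-1)
    acts = torch.zeros_like(pre_acts)
    acts.scatter_(-1, indices, values)
    return acts
```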
To use the TopK activation function, set act_fn = "topk" in SAEConfig, and set the top_k value to control the final sparsity of feature activations. We also provide some options in TrainerConfig to enable scheduling of the \(k\) value during training:
| Parameter | Description | Default |
|---|---|---|
| `initial_k` | The starting \(k\) value for scheduling. Must be greater than or equal to `top_k` set in `SAEConfig`. | `None` |
| `k_warmup_steps` | Steps (int) or fraction of total steps (float) for \(k\) to decay from `initial_k` to `top_k`. | `0.1` |
| `k_cold_booting_steps` | Steps (int) or fraction of total steps (float) to keep \(k\) at `initial_k` before starting the decay. | `0` |
| `k_schedule_type` | Scheduling strategy: `"linear"` or `"exponential"`. | `"linear"` |
| `k_exponential_factor` | Controls the curvature of the exponential decay. | `3.0` |
Use BatchTopK Activation
TopK enforces an unnecessarily rigid fixed allocation of active latents per sample. For strict architectural sparsity control, we recommend using BatchTopK for better performance.
BatchTopK
BatchTopK is a state-of-the-art activation function proposed in BatchTopK Sparse Autoencoders. It follows the idea of TopK to directly enforce sparsity, but replaces the sample-level TopK operation with a batch-level BatchTopK operation. For pre-activations of shape (batch_size, d_sae), it selects the top batch_size * top_k activations across the entire batch, allowing for a more flexible allocation of active latents.
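A minimal sketch of the batch-level selection (illustrative, not the library's implementation):

```python
import torch


def batchtopk_activation(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the top (batch_size * k) pre-activations across the whole batch,
    rather than k per sample. Illustrative sketch only."""
    batch_size = pre_acts.shape[0]
    flat = pre_acts.flatten()
    values, indices = flat.topk(batch_size * k)
    acts = torch.zeros_like(flat)
    acts[indices] = values
    return acts.view_as(pre_acts)
```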
To use the BatchTopK activation function, set act_fn = "batchtopk" in SAEConfig, and set the top_k value to control the final sparsity of feature activations. The same \(k\)-scheduling options in TrainerConfig apply:
| Parameter | Description | Default |
|---|---|---|
| `initial_k` | The starting \(k\) value for scheduling. Must be greater than or equal to `top_k` set in `SAEConfig`. | `None` |
| `k_warmup_steps` | Steps (int) or fraction of total steps (float) for \(k\) to decay from `initial_k` to `top_k`. | `0.1` |
| `k_cold_booting_steps` | Steps (int) or fraction of total steps (float) to keep \(k\) at `initial_k` before starting the decay. | `0` |
| `k_schedule_type` | Scheduling strategy: `"linear"` or `"exponential"`. | `"linear"` |
| `k_exponential_factor` | Controls the curvature of the exponential decay. | `3.0` |
Convert BatchTopK to JumpReLU
BatchTopK introduces a dependency between the activations of samples within a batch. To eliminate this effect, it's better to estimate a threshold \(\theta\) as the average minimum positive activation value and convert the activation function to JumpReLU with this threshold.
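A minimal sketch of such a conversion (a hypothetical helper; the estimation details may differ from the library's implementation):

```python
import torch


@torch.no_grad()
def estimate_jumprelu_threshold(batches, num_batches: int = 100) -> torch.Tensor:
    """Estimate a scalar threshold as the average (over batches) of the
    minimum positive activation value produced by BatchTopK.
    Hypothetical helper; details may differ from the library's implementation."""
    mins = []
    for i, acts in enumerate(batches):
        if i >= num_batches:
            break
        positive = acts[acts > 0]
        if positive.numel() > 0:
            mins.append(positive.min())
    return torch.stack(mins).mean()
```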
Sparsity Penalties
Activation functions like ReLU and JumpReLU do not strictly enforce sparsity on feature activations. It is the responsibility of regularization terms to push feature activations toward sparsity. Language-Model-SAEs supports the following sparsity penalties on feature activations:
\(L^p\)-Norm
Sparsity refers to the number of active latents in a feature activation. In principle, we would like to add the \(L^0\) norm directly to the loss. However, the \(L^0\) norm is discontinuous and cannot be differentiated, so it is unsuitable for gradient-based optimization.

In practice, the \(L^1\) norm, as the best convex approximation to \(L^0\)[^2], is widely used in SAE training to control sparsity without losing convexity of the optimization. Language-Model-SAEs implements a more general \(L^p\) norm as regularization, computed as:

\[ \mathcal{L}_{\text{sparsity}} = \lambda \sum_i \left( f_i(x)\, \| W_{\text{dec},i} \|_2 \right)^p \]

where \(f(x)\) is the feature activation, \(\| W_\text{dec} \|_2\) is the decoder norm, \(p\) is the \(L^p\) power, and \(\lambda\) is the coefficient for the sparsity loss term.
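In code, the penalty amounts to the following (a sketch of the formula above under the stated definitions; not the library's internal implementation):

```python
import torch


def lp_sparsity_loss(
    feature_acts: torch.Tensor,   # (batch, d_sae) feature activations f(x), non-negative
    decoder_norms: torch.Tensor,  # (d_sae,) per-feature decoder norms
    p: float = 1.0,
    l1_coefficient: float = 8e-5,
) -> torch.Tensor:
    """Decoder-norm-weighted L^p sparsity penalty, averaged over the batch.
    Illustrative sketch of the formula above, not the library's code."""
    return l1_coefficient * ((feature_acts * decoder_norms) ** p).sum(dim=-1).mean()
```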
To use the \(L^p\) norm, set sparsity_loss_type="power" in TrainerConfig. Other parameters include:
| Parameter | Description | Default |
|---|---|---|
| `l1_coefficient` | Coefficient \(\lambda\) for the sparsity loss term. | `0.00008` |
| `l1_coefficient_warmup_steps` | Steps (int) or fraction of total steps (float) to warm up the sparsity coefficient from 0. | `0.1` |
| `p` | The power \(p\) for the \(L^p\) norm. Set to \(1\) for the \(L^1\) norm. | `1` |
Tanh
One challenge with the \(L^1\) penalty is shrinkage: in addition to encouraging sparsity, the penalty encourages activations to be smaller than they would otherwise be. This causes SAEs to recover a smaller fraction of the model loss than might be expected[^3].
The \(\tanh\) penalty addresses shrinkage by applying a bounded function to feature activations. For marginal cases where a feature is on the edge of activating, it provides the same gradient towards zero as \(L^1\), but for strongly-activating features it provides no penalty and hence no incentive to shrink the activation. The loss is computed as:

\[ \mathcal{L}_{\text{sparsity}} = \lambda \sum_i \tanh\left( c\, f_i(x)\, \| W_{\text{dec},i} \|_2 \right) \]

where \(f_i(x)\) is the \(i\)-th feature activation, \(\| W_{\text{dec},i} \|_2\) is the decoder norm for feature \(i\), and \(c\) is the stretch coefficient. Since \(\tanh(x) \to 1\) as \(x \to \infty\), this loss approximates counting the number of active features (\(L^0\) norm).
While the \(\tanh\) penalty was found to be a Pareto improvement in the \(L^0\)/MSE tradeoff, Anthropic's experiments showed that features trained with tanh were much harder to interpret due to many more high-frequency features (some activating on over 10% of inputs). However, their follow-up experiments show that these high-frequency features don't seem to be as pathological as previously thought.
To use the \(\tanh\) penalty, set sparsity_loss_type="tanh" in TrainerConfig. Other parameters include:
| Parameter | Description | Default |
|---|---|---|
| `l1_coefficient` | Coefficient \(\lambda\) for the sparsity loss term. | `0.00008` |
| `l1_coefficient_warmup_steps` | Steps (int) or fraction of total steps (float) to warm up the sparsity coefficient from 0. | `0.1` |
| `tanh_stretch_coefficient` | Stretch coefficient \(c\) controlling the steepness of the tanh function. | `4.0` |
Tanh-Quadratic
A key issue with standard sparsity penalties (\(L^0\), \(L^1\), or \(\tanh\)) is that they only control the average number of active features, but are indifferent to the distribution of firing frequencies (See Removing High Frequency Latents from JumpReLU SAEs from the GDM Mech Interp Team for a detailed analysis). This allows some features to fire on a large fraction of inputs (>10%), which often leads to uninterpretable high-frequency features.
The tanh-quadratic loss addresses this by adding a quadratic term that specifically penalizes high-frequency features. First, an approximate frequency \(\hat{p}_i\) is computed by averaging the tanh scores across the \(B\) samples in a batch:

\[ \hat{p}_i = \frac{1}{B} \sum_{b=1}^{B} \tanh\left( c\, f_i(x_b)\, \| W_{\text{dec},i} \|_2 \right) \]

Then the loss is:

\[ \mathcal{L}_{\text{sparsity}} = \lambda \sum_i \left( \hat{p}_i + \frac{\hat{p}_i^2}{s} \right) \]

where \(s\) is the frequency scale (controlled by frequency_scale). The first term \(\hat{p}_i\) behaves like a standard sparsity penalty for low-frequency features (\(\hat{p}_i \ll s\)), while the quadratic term \(\hat{p}_i^2 / s\) dominates for high-frequency features (\(\hat{p}_i \gtrsim s\)), making it increasingly expensive for features to activate on a large fraction of inputs.
This formulation successfully eliminates high-frequency latents with only a modest impact on reconstruction loss, while improving frequency-weighted interpretability scores compared to standard JumpReLU SAEs.
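Putting the two formulas above together in code (a sketch under the stated definitions; not the library's internal implementation):

```python
import torch


def tanh_quad_sparsity_loss(
    feature_acts: torch.Tensor,   # (batch, d_sae) feature activations
    decoder_norms: torch.Tensor,  # (d_sae,) per-feature decoder norms
    c: float = 4.0,               # tanh_stretch_coefficient
    s: float = 0.01,              # frequency_scale
    lam: float = 8e-5,            # l1_coefficient
) -> torch.Tensor:
    """Tanh-quadratic penalty: tanh scores averaged over the batch estimate
    per-feature firing frequency; the quadratic term penalizes high-frequency
    features. Illustrative sketch, not the library's code."""
    scores = torch.tanh(c * feature_acts * decoder_norms)
    p_hat = scores.mean(dim=0)  # approximate firing frequency per feature
    return lam * (p_hat + p_hat**2 / s).sum()
```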
Note
Our implementation of the quadratic loss term uses \(\tanh\) as a differentiable \(L^0\) proxy, which differs from the original GDM proposal that directly uses \(L^0\) paired with straight-through estimators.
To use tanh-quadratic, set sparsity_loss_type="tanh-quad" in TrainerConfig. Other parameters include:
| Parameter | Description | Default |
|---|---|---|
| `l1_coefficient` | Coefficient \(\lambda\) for the sparsity loss term. | `0.00008` |
| `l1_coefficient_warmup_steps` | Steps (int) or fraction of total steps (float) to warm up the sparsity coefficient from 0. | `0.1` |
| `tanh_stretch_coefficient` | Stretch coefficient \(c\) controlling the steepness of the tanh function. | `4.0` |
| `frequency_scale` | Scale factor \(s\) for the quadratic penalty. Smaller values penalize high-frequency features more aggressively. Typical values are `0.1` or `0.01` to suppress features firing on >10% of tokens. | `0.01` |
Auxiliary Losses
JumpReLU Pre-act Loss
For JumpReLU SAEs, an additional \(L_p\) penalty (called the "pre-act loss") proposed by Anthropic applies a small penalty to features which don't fire:

\[ \mathcal{L}_{\text{pre-act}} = \lambda_p \sum_i \operatorname{ReLU}\left( e^{\theta_i} - h_i \right) \]

where \(\theta_i\) is the log-threshold and \(h_i\) is the pre-activation. This loss has been found extremely helpful in reducing dead features by providing a gradient signal whenever a feature is inactive, pushing the threshold lower.
Note
Since this loss provides a gradient signal whenever a feature is inactive, the appropriate scale for lp_coefficient should be lower than that of other loss terms (e.g., l1_coefficient) by roughly a factor of the typical feature activation density.
To use the JumpReLU pre-act loss, set lp_coefficient to a positive value in TrainerConfig:
| Parameter | Description | Default |
|---|---|---|
| `lp_coefficient` | Coefficient \(\lambda_p\) for the JumpReLU pre-act loss. Recommended value is `3e-6`. Set to `None` to disable. | `None` |
Aux-K Loss
For the TopK activation function, an additional auxiliary loss (AuxK) proposed in Scaling and evaluating sparse autoencoders helps revive dead latents during training, similar to Ghost Grads.
A latent is flagged as "dead" during training if it has not activated for a predetermined number of tokens (typically 10 million). Given the reconstruction error of the main model \(e = x - \hat{x}\), the auxiliary loss models this error using the top-\(k_{\text{aux}}\) dead latents:

\[ \mathcal{L}_{\text{aux}} = \| e - \hat{e} \|_2^2 \]

where \(\hat{e} = W_{\text{dec}} z_{\text{dead}}\) is the reconstruction using only the top-\(k_{\text{aux}}\) dead latents. The full loss is then:

\[ \mathcal{L} = \| x - \hat{x} \|_2^2 + \alpha \mathcal{L}_{\text{aux}} \]

where \(\alpha\) is a small coefficient (typically \(1/32\)). Since the encoder forward pass can be shared (and it dominates the decoder cost and the encoder backward cost), adding this auxiliary loss only increases the computational cost by about 10%.
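A minimal sketch of the auxiliary term (illustrative, with hypothetical helper signature; not the library's internal implementation):

```python
import torch


def auxk_loss(
    error: torch.Tensor,     # (batch, d_model) residual e = x - x_hat
    pre_acts: torch.Tensor,  # (batch, d_sae) encoder pre-activations
    dead_mask: torch.Tensor, # (d_sae,) bool, True for dead latents
    W_dec: torch.Tensor,     # (d_sae, d_model) decoder weights
    k_aux: int = 512,
) -> torch.Tensor:
    """Aux-K loss: reconstruct the main reconstruction error with the
    top-k_aux dead latents. Illustrative sketch, not the library's code."""
    # Mask out live latents so TopK only selects among dead ones.
    dead_pre_acts = pre_acts.masked_fill(~dead_mask, float("-inf"))
    values, indices = dead_pre_acts.topk(k_aux, dim=-1)
    # relu() also zeroes the -inf fill values if fewer than k_aux latents are dead.
    z_dead = torch.zeros_like(pre_acts).scatter_(-1, indices, values.relu())
    e_hat = z_dead @ W_dec
    return (error - e_hat).pow(2).sum(dim=-1).mean()
```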
To use the Aux-K loss, set auxk_coefficient to a positive value in TrainerConfig:
| Parameter | Description | Default |
|---|---|---|
| `auxk_coefficient` | Coefficient \(\alpha\) for the Aux-K loss. Set to `None` to disable. Typical value is 1/32 (~`0.03125`). | `None` |
| `k_aux` | Number of top dead latents \(k_{\text{aux}}\) to use for auxiliary reconstruction. | `512` |
| `dead_threshold` | Number of tokens a latent must not activate for to be considered dead. | `10_000_000` |
Note
The Aux-K loss is specifically designed for TopK SAEs. It has no effect on other activation functions like ReLU or JumpReLU.
Legacy Strategies
Early research on SAEs employed strategies like Neuron Resampling and Ghost Grads to revive dead neurons. However, modern initialization schemes and sparsity losses have largely alleviated dead neurons, so we have removed support for these strategies to keep our internal code structure simple.
Caching Activations
Training with cached activations is a common workflow in practice. It enables efficient hyperparameter sweeping by reusing pre-generated activations, and facilitates parallelized training and analysis (DP/TP). This approach significantly accelerates training; for example, training an 8x-expansion SAE on Pythia 160M with 800M tokens typically drops from ~6 hours (on-the-fly) to ~30 minutes (cached). However, caching requires substantial disk space. For 800M tokens of activations from a single Pythia 160M layer (\(d_{\text{model}}=768\)) stored in FP16, the storage requirement is:

\[ 800 \times 10^6 \text{ tokens} \times 768 \text{ dims} \times 2 \text{ bytes} \approx 1.2\,\text{TB} \]
In this workflow, a separate task caches the output activations of the ActivationFactory to disk. When training, we re-configure the ActivationFactory to read directly from disk instead of generating activations from the language model on the fly.
To cache activations on disk, you can:
Create the GenerateActivationsSettings in Python and call generate_activations with it. All configurations except output_dir and total_tokens should be consistent with the on-the-fly settings above. output_dir is where the generated activations will be placed; ensure you have enough space in this directory. total_tokens should be equal to or greater than the total_training_tokens you want to train your SAE on.
```python
from lm_saes import (
    GenerateActivationsSettings,
    generate_activations,
    LanguageModelConfig,
    DatasetConfig,
    ActivationFactoryTarget,
    BufferShuffleConfig,
)

settings = GenerateActivationsSettings(
    model=LanguageModelConfig(
        model_name="EleutherAI/pythia-160m",
        device="cuda",
        dtype="torch.float16",
    ),
    model_name="pythia-160m",
    dataset=DatasetConfig(dataset_name_or_path="Hzfinfdu/SlimPajama-3B"),
    dataset_name="SlimPajama-3B",
    hook_points=["blocks.6.hook_resid_post"],
    output_dir="path/to/activations",
    total_tokens=800_000_000,
    context_size=1024,
    target=ActivationFactoryTarget.ACTIVATIONS_1D,
    model_batch_size=32,
    batch_size=4096,
    buffer_size=16384,
    buffer_shuffle=BufferShuffleConfig(
        perm_seed=42,
        generator_device="cuda",
    ),
    device_type="cuda",
)

generate_activations(settings)
```
The CLI-based workflow requires a configuration file containing settings consistent with GenerateActivationsSettings. Common configuration file types like TOML, JSON, and YAML are supported.
Create a TOML configuration file (e.g., generate_config.toml) with the following content:
```toml
model_name = "pythia-160m"
dataset_name = "SlimPajama-3B"
hook_points = ["blocks.6.hook_resid_post"]
output_dir = "path/to/activations"
total_tokens = 800_000_000
context_size = 1024
target = "activations-1d"
model_batch_size = 32
batch_size = 4096
buffer_size = 16384
device_type = "cuda"

[model]
model_name = "EleutherAI/pythia-160m"
device = "cuda"
dtype = "torch.float16"

[dataset]
dataset_name_or_path = "Hzfinfdu/SlimPajama-3B"

[buffer_shuffle]
perm_seed = 42
generator_device = "cuda"
```
Then run the generation with:
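As with training, the exact invocation depends on your installed version; assuming the CLI exposes a generate subcommand (check `lm-saes --help` for the actual usage), it looks like:

```bash
lm-saes generate generate_config.toml
```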
Also, you can directly create ActivationFactory and ActivationWriter instances to generate and write activations to disk.
```python
import datasets

from lm_saes import (
    LanguageModelConfig,
    TransformerLensLanguageModel,
    ActivationFactory,
    ActivationFactoryConfig,
    ActivationFactoryDatasetSource,
    ActivationFactoryTarget,
    BufferShuffleConfig,
    ActivationWriter,
    ActivationWriterConfig,
)

# Load the model and dataset in the same way as for on-the-fly generation
model = TransformerLensLanguageModel(
    LanguageModelConfig(
        model_name="EleutherAI/pythia-160m",
        device="cuda",
        dtype="torch.float16",
    )
)
dataset = datasets.load_dataset(
    "Hzfinfdu/SlimPajama-3B",
    split="train",
)

factory_cfg = ActivationFactoryConfig(
    sources=[ActivationFactoryDatasetSource(name="SlimPajama-3B")],
    target=ActivationFactoryTarget.ACTIVATIONS_1D,
    hook_points=["blocks.6.hook_resid_post"],
    context_size=1024,
    model_batch_size=32,
    batch_size=4096,
    buffer_size=16384,
    buffer_shuffle=BufferShuffleConfig(
        perm_seed=42,
        generator_device="cuda",
    ),
)
factory = ActivationFactory(factory_cfg)
activations = factory.process(
    model=model,
    model_name="pythia-160m",
    datasets={"SlimPajama-3B": (dataset, None)},
)

# Create an ActivationWriter to write the activation stream to disk
writer_cfg = ActivationWriterConfig(
    hook_points=["blocks.6.hook_resid_post"],
    total_generating_tokens=800_000_000,
    cache_dir="path/to/activations",
)
writer = ActivationWriter(writer_cfg)
writer.process(activations)
```
Training with Cached Activations
Once you have generated and saved activations to disk, you can configure the ActivationFactory to read from these files instead of running the language model. This is done by replacing ActivationFactoryDatasetSource with ActivationFactoryActivationsSource in the configuration.
```python
import torch

from lm_saes import (
    TrainSAESettings,
    train_sae,
    SAEConfig,
    TrainerConfig,
    ActivationFactoryConfig,
    ActivationFactoryActivationsSource,
    ActivationFactoryTarget,
)

settings = TrainSAESettings(
    sae=SAEConfig(
        hook_point_in="blocks.6.hook_resid_post",
        hook_point_out="blocks.6.hook_resid_post",
        d_model=768,
        expansion_factor=8,
        act_fn="relu",
        dtype=torch.float32,
        device="cuda",
    ),
    trainer=TrainerConfig(
        amp_dtype=torch.float32,
        lr=1e-4,
        total_training_tokens=800_000_000,
        exp_result_path="results",
    ),
    activation_factory=ActivationFactoryConfig(
        sources=[
            ActivationFactoryActivationsSource(
                path="path/to/activations",
                name="pythia-160m-cached",
                device="cuda",
            )
        ],
        target=ActivationFactoryTarget.ACTIVATIONS_1D,
        hook_points=["blocks.6.hook_resid_post"],
        batch_size=4096,
    ),
    sae_name="pythia-160m-sae",
    sae_series="pythia-sae",
)

train_sae(settings)
```
Update your training configuration to use the activations source type:
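A sketch of the corresponding TOML change; the fields mirror ActivationFactoryActivationsSource above, but the `type` value is an assumption by analogy with the `type = "dataset"` source shown earlier:

```toml
# Replace the [[activation_factory.sources]] dataset block with:
[[activation_factory.sources]]
type = "activations"  # assumed by analogy with type = "dataset"
path = "path/to/activations"
name = "pythia-160m-cached"
device = "cuda"
```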
In a full script, you can omit the language model and dataset loading, and directly use ActivationFactory with cached sources:
```python
import torch

from lm_saes import (
    ActivationFactory,
    ActivationFactoryConfig,
    ActivationFactoryActivationsSource,
    ActivationFactoryTarget,
    SAEConfig,
    Trainer,
    TrainerConfig,
    SparseAutoEncoder,
)

# 1. Configure Activation Factory with Cached Source
factory_cfg = ActivationFactoryConfig(
    sources=[
        ActivationFactoryActivationsSource(
            path="path/to/activations",
            name="pythia-160m-cached",
            device="cuda",
        )
    ],
    target=ActivationFactoryTarget.ACTIVATIONS_1D,
    hook_points=["blocks.6.hook_resid_post"],
    batch_size=4096,
)
factory = ActivationFactory(factory_cfg)
activations_stream = factory.process()

# 2. Initialize SAE and Trainer
sae = SparseAutoEncoder(
    SAEConfig(
        hook_point_in="blocks.6.hook_resid_post",
        hook_point_out="blocks.6.hook_resid_post",
        d_model=768,
        expansion_factor=8,
        act_fn="relu",
        device="cuda",
    )
)
trainer = Trainer(
    TrainerConfig(
        lr=1e-4,
        total_training_tokens=800_000_000,
        exp_result_path="results",
    )
)

# 3. Train
trainer.fit(sae=sae, activation_stream=activations_stream)
```
Use HuggingFace Backend
While TransformerLens provides a unified set of hook points across numerous Transformer variants, there are compelling reasons to use the Hugging Face Transformers library directly for generating activations:
- Provides a wider range of model architectures. Almost every frontier model is released with official Hugging Face-compatible modeling code, ensuring immediate support for the latest architectures.
- Integrates GPU-accelerated kernels, e.g. Flash Attention, for faster activation generation.
- Reduces numerical error. TransformerLens re-implements model logic using more semantic linear-algebra operations (primarily einsum), which inevitably introduces subtle numerical discrepancies. Using Hugging Face directly ensures that the activations used for training sparse dictionaries are exactly the same as those you use for any other purpose.
To use the HuggingFace backend, simply change the backend field in LanguageModelConfig to "huggingface", or (in a full script) load the model with the HuggingFaceLanguageModel wrapper.
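A minimal sketch (HuggingFaceLanguageModel and the backend field are named above; note that hook point names may differ from the TransformerLens-style names used elsewhere in this guide):

```python
from lm_saes import HuggingFaceLanguageModel, LanguageModelConfig

# Load the model through the HuggingFace backend instead of TransformerLens
model = HuggingFaceLanguageModel(
    LanguageModelConfig(
        model_name="EleutherAI/pythia-160m",
        backend="huggingface",  # switch the backend field as described above
        device="cuda",
        dtype="torch.float16",
    )
)
```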
[^1]: The Heaviside step function \(H(x)\) is defined as \[ H(x) = \begin{cases} 1 & \text{if } x > 0 \\ 0 & \text{otherwise} \end{cases} \]

[^2]: See more discussion on \(L^0\) approximation functions in Comparing Measures of Sparsity and Why \(\ell_1\) Is a Good Approximation to \(\ell_0\): A Geometric Explanation.

[^3]: See Fixing Feature Suppression in SAEs for more discussion.