Models
Sparse dictionary model architectures and their configuration classes.
BaseSAEConfig
pydantic-model
Bases: BaseModelConfig, ABC
Base class for SAE configs, holding common settings that apply across SAE variants. This class should not be used directly, but only as a base config class for other SAE variants such as SAEConfig, CrossCoderConfig, etc.
Fields:
-
device(str) -
dtype(dtype) -
sae_type(str) -
d_model(int) -
expansion_factor(float) -
use_decoder_bias(bool) -
act_fn(Literal['relu', 'jumprelu', 'topk', 'batchtopk', 'batchlayertopk', 'layertopk']) -
norm_activation(Literal['token-wise', 'batch-wise', 'dataset-wise', 'inference']) -
sparsity_include_decoder_norm(bool) -
top_k(int) -
use_triton_kernel(bool) -
sparsity_threshold_for_triton_spmm_kernel(float) -
jumprelu_threshold_window(float)
sae_type
pydantic-field
The type of the sparse dictionary. Must be one of the registered SAE types.
d_model
pydantic-field
The dimension of the input/label activation space. In common settings where activations come from a transformer, this is the dimension of the model (also known as hidden_size).
expansion_factor
pydantic-field
The expansion factor of the sparse dictionary. The hidden dimension of the sparse dictionary d_sae is d_model * expansion_factor.
use_decoder_bias
pydantic-field
Whether to use a bias term in the decoder. Including a bias term may make it easier to train a good sparse dictionary, in exchange for increased architectural complexity.
act_fn
pydantic-field
The activation function to use for the sparse dictionary. Currently supported activation functions are relu, jumprelu, topk, batchtopk, batchlayertopk, and layertopk.
- relu: ReLU activation function. Used in the most vanilla SAE settings.
- jumprelu: JumpReLU activation function, adding a trainable element-wise threshold that pre-activations must exceed to be activated, formally defined as f(x) = x · H(x − θ), where H is the Heaviside step function and θ is the threshold. Proposed in Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders.
- topk: TopK activation function. Retains the top k activations per sample, zeroing out the rest. Proposed in Scaling and evaluating sparse autoencoders.
- batchtopk: BatchTopK activation function. Relaxes TopK to the batch level, retaining the top k * batch_size activations per batch and zeroing out the rest. This allows a more adaptive allocation of latents across samples. Proposed in BatchTopK Sparse Autoencoders.
- batchlayertopk: (For CrossLayerTranscoder only) Extension of BatchTopK that is layer-and-batch-aware, retaining the top k * batch_size * n_layers activations per batch and zeroing out the rest.
- layertopk: (For CrossLayerTranscoder only) Extension of BatchTopK that is layer-aware, retaining the top k * n_layers activations per layer and zeroing out the rest. Note that this activation function does not take the batch dimension into account.
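The TopK family above can be sketched in a few lines of PyTorch. This is an illustrative sketch, not the library's implementation; `batchtopk_activation` here assumes a 2-D (batch, d_sae) input:

```python
import torch

def topk_activation(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the top-k pre-activations per sample, zero out the rest."""
    values, indices = torch.topk(pre_acts, k, dim=-1)
    out = torch.zeros_like(pre_acts)
    out.scatter_(-1, indices, torch.relu(values))
    return out

def batchtopk_activation(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the top (k * batch_size) pre-activations across the whole batch."""
    batch_size = pre_acts.shape[0]
    flat = pre_acts.reshape(-1)
    values, indices = torch.topk(flat, k * batch_size)
    out = torch.zeros_like(flat)
    out[indices] = torch.relu(values)
    return out.reshape(pre_acts.shape)
```

With BatchTopK, a sample with strong features can claim more than k active latents, while a sample with weak features claims fewer; the budget is only fixed in aggregate.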
norm_activation
pydantic-field
norm_activation: Literal[
"token-wise", "batch-wise", "dataset-wise", "inference"
] = "dataset-wise"
The activation normalization strategy to use for the input/label activations. During the call of normalize_activations (which will be called by the Trainer during training), the input/label activations are normalized to an average norm of sqrt(d_model). This makes hyperparameters (mostly the learning rate) easier to transfer between models with different activation scales, since without normalization the MSE loss is proportional to the square of the activation norm.
The strategy determines over which view the norm is averaged, with the following options:
- token-wise: Norm is computed directly for each token's activation. No averaging is performed.
- batch-wise: Norm is computed for each token in a batch, then averaged over the batch dimension.
- dataset-wise: Norm is computed from several samples of the activation dataset. Compared to batch-wise, dataset-wise gives a fixed average norm for all activations, preserving the linearity of pre-activation encoding and decoding.
- inference: No normalization is performed. An inference mode is produced after calling the standardize_parameters_of_dataset_norm method, which folds the dataset-wise average norm into the weights and biases of the model. Switching to inference mode doesn't affect the encoding and decoding as a whole, that is, the reconstructed activations remain the same as the denormalized reconstructed activations in dataset-wise mode. However, the feature activations will reflect the activation scale. This allows the real magnitude of feature activations to be observed during inference.
sparsity_include_decoder_norm
pydantic-field
Whether to include the decoder norm term in feature activation gating. If true, the pre-activation hidden states are scaled by the decoder norm before applying the activation function, then scaled back afterwards. Formally, given an activation function f(x), define the gating function g(x) = f(x) / x (element-wise division). When sparsity_include_decoder_norm is True, f(x) is replaced with x * g(x * ||W_dec||). This suppresses a training dynamic where the model increases the decoder norm in exchange for smaller feature activation magnitudes, obtaining a lower sparsity loss (L1 norm) without actually becoming sparser.
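As a sketch with a TopK activation, where the gating role of g is easiest to see (a hypothetical helper, not the library's implementation):

```python
import torch

def gated_topk_with_decoder_norm(hidden_pre: torch.Tensor,
                                 decoder_norm: torch.Tensor,
                                 k: int) -> torch.Tensor:
    """TopK gating computed on decoder-norm-scaled pre-activations.

    Features are *selected* by hidden_pre * ||W_dec||, but the surviving
    activations keep their original (unscaled) magnitudes, so the model
    gains nothing by inflating decoder norms to shrink feature magnitudes.
    """
    scaled = hidden_pre * decoder_norm          # x * ||W_dec||, element-wise
    _, indices = torch.topk(scaled, k, dim=-1)  # gate on the scaled view
    gate = torch.zeros_like(hidden_pre)
    gate.scatter_(-1, indices, 1.0)
    return torch.relu(hidden_pre) * gate        # scale back: x * g(x * ||W_dec||)
```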
top_k
pydantic-field
The k value to use for the topk family of activation functions. For vanilla TopK, the L0 norm of the feature activations will be exactly equal to top_k.
use_triton_kernel
pydantic-field
Whether to use the Triton SpMM kernel for the sparse matrix multiplication. Currently only supported for vanilla SAE.
sparsity_threshold_for_triton_spmm_kernel
pydantic-field
The sparsity threshold for the Triton SpMM kernel. The Triton SpMM kernel is used for the sparse matrix multiplication only once feature activation sparsity reaches this threshold. This is useful for JumpReLU, or for TopK with a k-annealing schedule, where sparsity is not guaranteed throughout training.
jumprelu_threshold_window
pydantic-field
The window size for the JumpReLU threshold. When pre-activations are element-wise within the window-neighborhood of the threshold, the threshold begins to receive gradient. See Anthropic's Circuits Update - January 2025 for more details on how JumpReLU is optimized (where this window is referred to as ε).
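The window-based threshold gradient can be sketched as a custom autograd function. This assumes a rectangular kernel over the window, following the scheme described above; it is a sketch, not the library's implementation, and the reduction of the threshold gradient over batch dimensions is elided:

```python
import torch

class JumpReLUFn(torch.autograd.Function):
    """JumpReLU f(x) = x * H(x - theta), with a pseudo-gradient for theta
    that is nonzero only when x falls within +/- window/2 of the threshold."""

    @staticmethod
    def forward(ctx, x, theta, window):
        ctx.save_for_backward(x, theta)
        ctx.window = window
        return x * (x > theta)

    @staticmethod
    def backward(ctx, grad_out):
        x, theta = ctx.saved_tensors
        in_window = ((x - theta).abs() < ctx.window / 2).to(x.dtype)
        grad_x = grad_out * (x > theta)
        # Rectangle-kernel pseudo-gradient for the threshold; in practice this
        # would be reduced over batch dimensions to match theta's shape.
        grad_theta = grad_out * (-theta / ctx.window) * in_window
        return grad_x, grad_theta, None

x = torch.tensor([0.5, 1.2, 2.0])
theta = torch.tensor([1.0, 1.0, 1.0])
out = JumpReLUFn.apply(x, theta, 0.5)
```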
d_sae
property
The hidden dimension of the sparse dictionary. Calculated as d_model * expansion_factor.
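For instance, the relationship between the fields can be sketched with a plain dataclass (the real config is a pydantic model; `SAEShape` is a hypothetical name):

```python
from dataclasses import dataclass

@dataclass
class SAEShape:
    d_model: int
    expansion_factor: float

    @property
    def d_sae(self) -> int:
        # The dictionary's hidden dimension: d_model * expansion_factor.
        return int(self.d_model * self.expansion_factor)

print(SAEShape(d_model=768, expansion_factor=8).d_sae)  # 6144
```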
associated_hook_points
abstractmethod
property
List of hook points used by the SAE, including all input and label hook points. This is used to retrieve useful data from the input activation source.
from_pretrained
classmethod
Load the config of the sparse dictionary from a pretrained name or path. Config is read from
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| pretrained_name_or_path | str | The path to the pretrained sparse dictionary. | required |
| **kwargs | | Additional keyword arguments to pass to the config constructor. | {} |
Source code in src/lm_saes/abstract_sae.py
SAEConfig
pydantic-model
Bases: BaseSAEConfig
Fields:
-
device(str) -
dtype(dtype) -
d_model(int) -
expansion_factor(float) -
use_decoder_bias(bool) -
act_fn(Literal['relu', 'jumprelu', 'topk', 'batchtopk', 'batchlayertopk', 'layertopk']) -
norm_activation(Literal['token-wise', 'batch-wise', 'dataset-wise', 'inference']) -
sparsity_include_decoder_norm(bool) -
top_k(int) -
use_triton_kernel(bool) -
sparsity_threshold_for_triton_spmm_kernel(float) -
jumprelu_threshold_window(float) -
sae_type(str) -
hook_point_in(str) -
hook_point_out(str) -
use_glu_encoder(bool)
SparseAutoEncoder
SparseAutoEncoder(
cfg: SAEConfig, device_mesh: DeviceMesh | None = None
)
Bases: AbstractSparseAutoEncoder
Source code in src/lm_saes/sae.py
encoder_norm
Compute the norm of the encoder weight.
Source code in src/lm_saes/sae.py
decoder_norm
Compute the norm of the decoder weight.
Source code in src/lm_saes/sae.py
set_decoder_to_fixed_norm
Set the decoder weights to a fixed norm.
Source code in src/lm_saes/sae.py
set_encoder_to_fixed_norm
dim_maps
Return a dictionary mapping parameter names to dimension maps.
Returns:

| Type | Description |
|---|---|
| dict[str, DimMap] | A dictionary mapping parameter names to DimMap objects. |
Source code in src/lm_saes/sae.py
standardize_parameters_of_dataset_norm
Standardize the parameters of the model to account for dataset_norm during inference. This function should be called during inference by the Initializer.
During training, the activations correspond to an input x where the norm is sqrt(d_model).
However, during inference, the norm of the input x corresponds to the dataset_norm.
To ensure consistency between training and inference, the activations during inference
are scaled by the factor:
scaled_activation = training_activation * (dataset_norm / sqrt(d_model))
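Numerically, with hypothetical values for d_model and the measured dataset-average input norm:

```python
import math

# Hypothetical numbers: d_model = 1024, measured dataset-average input norm = 64.
d_model, dataset_norm = 1024, 64.0
scale = dataset_norm / math.sqrt(d_model)  # 64 / 32 = 2.0

# During training, inputs were normalized to norm sqrt(d_model); at inference
# the raw activations are `scale` times larger, so folding 1/scale into the
# encoder weights (and scale into the decoder output) reproduces the training
# behavior on unnormalized inputs.
print(scale)  # 2.0
```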
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dataset_average_activation_norm | dict[str, float] | A dictionary whose keys are the in/out hook points and whose values are the average activation norm of the dataset during inference, e.g. dataset_average_activation_norm = {self.cfg.hook_point_in: 1.0, self.cfg.hook_point_out: 1.0}. | required |

Returns:

| Name | Type | Description |
|---|---|---|
| None | | Updates the internal parameters to reflect the standardized activations and switches norm_activation to "inference" mode. |
Source code in src/lm_saes/sae.py
encode
encode(
x: Union[
Float[Tensor, "batch d_model"],
Float[Tensor, "batch seq_len d_model"],
],
return_hidden_pre: Literal[False] = False,
**kwargs,
) -> Union[
Float[Tensor, "batch d_sae"],
Float[Tensor, "batch seq_len d_sae"],
]
encode(
x: Union[
Float[Tensor, "batch d_model"],
Float[Tensor, "batch seq_len d_model"],
],
return_hidden_pre: Literal[True],
**kwargs,
) -> tuple[
Union[
Float[Tensor, "batch d_sae"],
Float[Tensor, "batch seq_len d_sae"],
],
Union[
Float[Tensor, "batch d_sae"],
Float[Tensor, "batch seq_len d_sae"],
],
]
encode(
x: Union[
Float[Tensor, "batch d_model"],
Float[Tensor, "batch seq_len d_model"],
],
return_hidden_pre: bool = False,
**kwargs,
) -> Union[
Float[Tensor, "batch d_sae"],
Float[Tensor, "batch seq_len d_sae"],
tuple[
Union[
Float[Tensor, "batch d_sae"],
Float[Tensor, "batch seq_len d_sae"],
],
Union[
Float[Tensor, "batch d_sae"],
Float[Tensor, "batch seq_len d_sae"],
],
],
]
Encode input tensor through the sparse autoencoder.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Union[Float[Tensor, 'batch d_model'], Float[Tensor, 'batch seq_len d_model']] | Input tensor of shape (batch, d_model) or (batch, seq_len, d_model) | required |
| return_hidden_pre | bool | If True, also return the pre-activation hidden states | False |

Returns:

| Type | Description |
|---|---|
| Union[Float[Tensor, 'batch d_sae'], Float[Tensor, 'batch seq_len d_sae'], tuple[Union[Float[Tensor, 'batch d_sae'], Float[Tensor, 'batch seq_len d_sae']], Union[Float[Tensor, 'batch d_sae'], Float[Tensor, 'batch seq_len d_sae']]]] | If return_hidden_pre is False: feature activations tensor of shape (batch, d_sae) or (batch, seq_len, d_sae). If return_hidden_pre is True: tuple of (feature_acts, hidden_pre) where both have shape (batch, d_sae) or (batch, seq_len, d_sae). |
Source code in src/lm_saes/sae.py
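The documented shapes can be illustrated with a toy TopK encoder pass. This is a self-contained sketch of what encode does, not the library's SparseAutoEncoder:

```python
import torch

# Toy shapes: encode maps (batch, seq_len, d_model) -> (batch, seq_len, d_sae);
# with return_hidden_pre=True it also yields the pre-activation hidden states.
d_model, d_sae, k = 16, 64, 4
W_enc = torch.randn(d_model, d_sae) / d_model ** 0.5
b_enc = torch.zeros(d_sae)

def encode(x: torch.Tensor, return_hidden_pre: bool = False):
    hidden_pre = x @ W_enc + b_enc
    values, indices = torch.topk(hidden_pre, k, dim=-1)
    feature_acts = torch.zeros_like(hidden_pre).scatter_(-1, indices, torch.relu(values))
    return (feature_acts, hidden_pre) if return_hidden_pre else feature_acts

x = torch.randn(2, 8, d_model)
feats, pre = encode(x, return_hidden_pre=True)
```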
decode_coo
Decode feature activations back to model space using COO format.
Source code in src/lm_saes/sae.py
init_W_D_with_active_subspace
Initialize W_D with the active subspace.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | dict[str, Tensor] | The batch. | required |
| d_active_subspace | int | The dimension of the active subspace. | required |
Source code in src/lm_saes/sae.py
CrossCoderConfig
pydantic-model
Bases: BaseSAEConfig
Fields:
-
device(str) -
dtype(dtype) -
d_model(int) -
expansion_factor(float) -
use_decoder_bias(bool) -
act_fn(Literal['relu', 'jumprelu', 'topk', 'batchtopk', 'batchlayertopk', 'layertopk']) -
norm_activation(Literal['token-wise', 'batch-wise', 'dataset-wise', 'inference']) -
sparsity_include_decoder_norm(bool) -
top_k(int) -
use_triton_kernel(bool) -
sparsity_threshold_for_triton_spmm_kernel(float) -
jumprelu_threshold_window(float) -
sae_type(str) -
hook_points(list[str])
CrossCoder
CrossCoder(
cfg: CrossCoderConfig,
device_mesh: Optional[DeviceMesh] = None,
)
Bases: AbstractSparseAutoEncoder
Sparse AutoEncoder model.
An autoencoder model that learns to compress the input activation tensor into a high-dimensional but sparse feature activation tensor.
Can also act as a transcoder model, which learns to compress the input activation tensor into a feature activation tensor, and then reconstruct a label activation tensor from the feature activation tensor.
Source code in src/lm_saes/crosscoder.py
specs
class-attribute
instance-attribute
Tensor specs for CrossCoder with n_heads dimension.
init_parameters
Initialize the weights of the model.
Source code in src/lm_saes/crosscoder.py
encode
encode(
x: Union[
Float[Tensor, "batch d_model"],
Float[Tensor, "batch seq_len d_model"],
],
return_hidden_pre: Literal[False] = False,
*,
no_einsum: bool = True,
**kwargs,
) -> Union[
Float[Tensor, "batch d_sae"],
Float[Tensor, "batch seq_len d_sae"],
]
encode(
x: Union[
Float[Tensor, "batch n_heads d_model"],
Float[Tensor, "batch seq_len n_heads d_model"],
],
return_hidden_pre: Literal[True],
*,
no_einsum: bool = True,
**kwargs,
) -> tuple[
Union[
Float[Tensor, "batch n_heads d_sae"],
Float[Tensor, "batch seq_len n_heads d_sae"],
],
Union[
Float[Tensor, "batch n_heads d_sae"],
Float[Tensor, "batch seq_len n_heads d_sae"],
],
]
encode(
x: Union[
Float[Tensor, "batch n_heads d_model"],
Float[Tensor, "batch seq_len n_heads d_model"],
],
return_hidden_pre: bool = False,
*,
no_einsum: bool = True,
**kwargs,
) -> Union[
Float[Tensor, "batch n_heads d_sae"],
Float[Tensor, "batch seq_len n_heads d_sae"],
tuple[
Union[
Float[Tensor, "batch n_heads d_sae"],
Float[Tensor, "batch seq_len n_heads d_sae"],
],
Union[
Float[Tensor, "batch n_heads d_sae"],
Float[Tensor, "batch seq_len n_heads d_sae"],
],
],
]
Encode the input tensor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Union[Float[Tensor, 'batch n_heads d_model'], Float[Tensor, 'batch seq_len n_heads d_model']] | Input tensor of shape (..., n_heads, d_model). | required |

Returns:

| Type | Description |
|---|---|
| Union[Float[Tensor, 'batch n_heads d_sae'], Float[Tensor, 'batch seq_len n_heads d_sae'], tuple[Union[Float[Tensor, 'batch n_heads d_sae'], Float[Tensor, 'batch seq_len n_heads d_sae']], Union[Float[Tensor, 'batch n_heads d_sae'], Float[Tensor, 'batch seq_len n_heads d_sae']]]] | Encoded tensor of shape (..., n_heads, d_sae). |
Source code in src/lm_saes/crosscoder.py
decode
decode(
feature_acts: Union[
Float[Tensor, "batch d_sae"],
Float[Tensor, "batch seq_len d_sae"],
],
*,
no_einsum: bool = True,
**kwargs,
) -> Union[
Float[Tensor, "batch d_model"],
Float[Tensor, "batch seq_len d_model"],
]
Decode the encoded tensor.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| feature_acts | Union[Float[Tensor, 'batch d_sae'], Float[Tensor, 'batch seq_len d_sae']] | Encoded tensor of shape (..., n_heads, d_sae). | required |

Returns:

| Type | Description |
|---|---|
| Union[Float[Tensor, 'batch d_model'], Float[Tensor, 'batch seq_len d_model']] | Decoded tensor of shape (..., n_heads, d_model). |
Source code in src/lm_saes/crosscoder.py
decoder_norm
Calculate the norm of the decoder weights.
Returns:

| Type | Description |
|---|---|
| Tensor | Norm of decoder weights of shape (n_heads, d_sae). |
Source code in src/lm_saes/crosscoder.py
encoder_norm
Calculate the norm of the encoder weights.
Returns:

| Type | Description |
|---|---|
| Tensor | Norm of encoder weights of shape (n_heads, d_sae). |
Source code in src/lm_saes/crosscoder.py
standardize_parameters_of_dataset_norm
Standardize the parameters of the model to account for dataset_norm during inference.
Source code in src/lm_saes/crosscoder.py
compute_training_metrics
compute_training_metrics(
*,
feature_acts: Tensor,
l_rec: Tensor,
l0: Tensor,
explained_variance: Tensor,
**kwargs,
) -> dict[str, float]
Compute per-head training metrics for CrossCoder.
Source code in src/lm_saes/crosscoder.py
dim_maps
Return a dictionary mapping parameter names to dimension maps.
Returns:

| Type | Description |
|---|---|
| dict[str, DimMap] | A dictionary mapping parameter names to DimMap objects. |
Source code in src/lm_saes/crosscoder.py
CLTConfig
pydantic-model
Bases: BaseSAEConfig
Configuration for Cross Layer Transcoder (CLT).
A CLT consists of L encoders and L(L+1)/2 decoders: the encoder at each layer reads from the residual stream at that layer, and its features can decode to that layer and all subsequent layers.
Fields:
-
device(str) -
dtype(dtype) -
d_model(int) -
expansion_factor(float) -
use_decoder_bias(bool) -
norm_activation(Literal['token-wise', 'batch-wise', 'dataset-wise', 'inference']) -
sparsity_include_decoder_norm(bool) -
top_k(int) -
use_triton_kernel(bool) -
sparsity_threshold_for_triton_spmm_kernel(float) -
jumprelu_threshold_window(float) -
sae_type(str) -
act_fn(Literal['relu', 'jumprelu', 'topk', 'batchtopk', 'batchlayertopk', 'layertopk']) -
init_cross_layer_decoder_all_zero(bool) -
hook_points_in(list[str]) -
hook_points_out(list[str]) -
decode_with_csr(bool) -
sparsity_threshold_for_csr(float)
hook_points_in
pydantic-field
List of hook points to capture input activations from, one for each layer.
hook_points_out
pydantic-field
List of hook points to capture output activations from, one for each layer.
decode_with_csr
pydantic-field
Whether to decode with CSR matrices. If True, will use CSR matrices for decoding. If False, will use dense matrices for decoding.
sparsity_threshold_for_csr
pydantic-field
The sparsity threshold for the CSR matrices. If the sparsity of the feature activations reaches this threshold, the CSR matrices will be used for decoding. The current conditioning for sparsity is dependent on usage of TopK family of activation functions, so this will not work with other activation functions like relu or jumprelu.
CrossLayerTranscoder
CrossLayerTranscoder(
cfg: CLTConfig, device_mesh: Optional[DeviceMesh] = None
)
Bases: AbstractSparseAutoEncoder
Cross Layer Transcoder (CLT) implementation.
A CLT has L encoders (one per layer) and L(L+1)/2 decoders arranged in an upper triangular pattern. The encoder at each layer reads from the residual stream at that layer, and its features can decode to that layer and all subsequent layers.
We store all parameters in the same object and shard them across GPUs for efficient distributed training.
Initialize the Cross Layer Transcoder.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| cfg | CLTConfig | Configuration for the CLT. | required |
| device_mesh | Optional[DeviceMesh] | Device mesh for distributed training. | None |
Source code in src/lm_saes/clt.py
specs
class-attribute
instance-attribute
Tensor specs for CrossLayerTranscoder with layer dimension.
init_parameters
Initialize parameters.
Encoders: uniformly initialized in the range (-1/sqrt(d_sae), 1/sqrt(d_sae)). Decoders at layer L: uniformly initialized in the range (-1/sqrt(L * d_model), 1/sqrt(L * d_model)).
Source code in src/lm_saes/clt.py
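The documented initialization can be sketched as follows (illustrative shapes; the real parameters are sharded DTensors):

```python
import math
import torch

n_layers, d_model, d_sae = 4, 16, 64

# Encoders: one per layer, drawn from U(-1/sqrt(d_sae), 1/sqrt(d_sae)).
W_E = torch.empty(n_layers, d_model, d_sae).uniform_(
    -1 / math.sqrt(d_sae), 1 / math.sqrt(d_sae)
)

# One decoder per (layer_from, layer_to) pair with layer_from <= layer_to,
# giving L(L+1)/2 decoders drawn from U(-1/sqrt(L*d_model), 1/sqrt(L*d_model)).
bound = 1 / math.sqrt(n_layers * d_model)
W_D = {
    (lf, lt): torch.empty(d_sae, d_model).uniform_(-bound, bound)
    for lf in range(n_layers)
    for lt in range(lf, n_layers)
}
```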
get_decoder_weights
Get decoder weights for all source layers 0..layer_to that decode to layer layer_to.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| layer_to | int | Target layer (0 to n_layers-1) | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Decoder weights for all source layers to the specified target layer |
Source code in src/lm_saes/clt.py
encode
encode(
x: Union[
Float[Tensor, "batch n_layers d_model"],
Float[Tensor, "batch seq_len n_layers d_model"],
],
return_hidden_pre: Literal[False] = False,
**kwargs,
) -> Union[
Float[Tensor, "batch n_layers d_sae"],
Float[Tensor, "batch seq_len n_layers d_sae"],
]
encode(
x: Union[
Float[Tensor, "batch n_layers d_model"],
Float[Tensor, "batch seq_len n_layers d_model"],
],
return_hidden_pre: Literal[True],
**kwargs,
) -> tuple[
Union[
Float[Tensor, "batch n_layers d_sae"],
Float[Tensor, "batch seq_len n_layers d_sae"],
],
Union[
Float[Tensor, "batch n_layers d_sae"],
Float[Tensor, "batch seq_len n_layers d_sae"],
],
]
encode(
x: Union[
Float[Tensor, "batch n_layers d_model"],
Float[Tensor, "batch seq_len n_layers d_model"],
],
return_hidden_pre: bool = False,
**kwargs,
) -> Union[
Float[Tensor, "batch n_layers d_sae"],
Float[Tensor, "batch seq_len n_layers d_sae"],
tuple[
Union[
Float[Tensor, "batch n_layers d_sae"],
Float[Tensor, "batch seq_len n_layers d_sae"],
],
Union[
Float[Tensor, "batch n_layers d_sae"],
Float[Tensor, "batch seq_len n_layers d_sae"],
],
],
]
Encode input activations to CLT features using L encoders.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Union[Float[Tensor, 'batch n_layers d_model'], Float[Tensor, 'batch seq_len n_layers d_model']] | Input activations from all layers (..., n_layers, d_model) | required |
| return_hidden_pre | bool | Whether to return pre-activation values | False |

Returns:

| Type | Description |
|---|---|
| Union[Float[Tensor, 'batch n_layers d_sae'], Float[Tensor, 'batch seq_len n_layers d_sae'], tuple[Union[Float[Tensor, 'batch n_layers d_sae'], Float[Tensor, 'batch seq_len n_layers d_sae']], Union[Float[Tensor, 'batch n_layers d_sae'], Float[Tensor, 'batch seq_len n_layers d_sae']]]] | Feature activations for all layers (..., n_layers, d_sae) |
Source code in src/lm_saes/clt.py
encode_single_layer
encode_single_layer(
x: Union[
Float[Tensor, "batch d_model"],
Float[Tensor, "batch seq_len d_model"],
],
layer: int,
return_hidden_pre: bool = False,
**kwargs,
) -> Union[
Float[Tensor, "batch d_sae"],
Float[Tensor, "batch seq_len d_sae"],
tuple[
Union[
Float[Tensor, "batch d_sae"],
Float[Tensor, "batch seq_len d_sae"],
],
Union[
Float[Tensor, "batch d_sae"],
Float[Tensor, "batch seq_len d_sae"],
],
],
]
Encode input activations from a single layer to CLT features using that layer's encoder.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Union[Float[Tensor, 'batch d_model'], Float[Tensor, 'batch seq_len d_model']] | Input activations from a given layer (..., d_model) | required |
| layer | int | The layer to encode | required |
| return_hidden_pre | bool | Whether to return pre-activation values | False |

Returns:

| Type | Description |
|---|---|
| Union[Float[Tensor, 'batch d_sae'], Float[Tensor, 'batch seq_len d_sae'], tuple[Union[Float[Tensor, 'batch d_sae'], Float[Tensor, 'batch seq_len d_sae']], Union[Float[Tensor, 'batch d_sae'], Float[Tensor, 'batch seq_len d_sae']]]] | Feature activations for the given layer (..., d_sae) |
Source code in src/lm_saes/clt.py
decode
decode(
feature_acts: Union[
Float[Tensor, "batch n_layers d_sae"],
Float[Tensor, "batch seq_len n_layers d_sae"],
List[Float[Tensor, "seq_len d_sae"]],
],
batch_first: bool = False,
**kwargs,
) -> Union[
Float[Tensor, "n_layers batch d_model"],
Float[Tensor, "n_layers batch seq_len d_model"],
Float[Tensor, "batch n_layers d_model"],
Float[Tensor, "batch seq_len n_layers d_model"],
]
Decode CLT features to output activations using the upper triangular pattern.
The output at layer L is the sum of contributions from all layers 0 through L: y_L = Σ_{i=0}^{L} W_D[i→L] @ feature_acts[..., i, :] + b_D[L]
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| feature_acts | Union[Float[Tensor, 'batch n_layers d_sae'], Float[Tensor, 'batch seq_len n_layers d_sae'], List[Float[Tensor, 'seq_len d_sae']]] | CLT feature activations (..., n_layers, d_sae) | required |

Returns:

| Type | Description |
|---|---|
| Union[Float[Tensor, 'n_layers batch d_model'], Float[Tensor, 'n_layers batch seq_len d_model'], Float[Tensor, 'batch n_layers d_model'], Float[Tensor, 'batch seq_len n_layers d_model']] | Reconstructed activations for all layers (..., n_layers, d_model) |
Source code in src/lm_saes/clt.py
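The upper-triangular sum above can be sketched with hypothetical decoder weights (illustrative shapes, not the library's implementation):

```python
import torch

n_layers, d_model, d_sae = 3, 8, 32
# Hypothetical decoders: W_D[(i, l)] decodes layer-i features to layer l (i <= l).
W_D = {(i, l): torch.randn(d_sae, d_model) * 0.01
       for i in range(n_layers) for l in range(i, n_layers)}
b_D = [torch.zeros(d_model) for _ in range(n_layers)]

def clt_decode(feature_acts: torch.Tensor) -> torch.Tensor:
    """y_l = sum_{i=0}^{l} feature_acts[..., i, :] @ W_D[(i, l)] + b_D[l]"""
    outs = []
    for l in range(n_layers):
        y = sum(feature_acts[..., i, :] @ W_D[(i, l)] for i in range(l + 1))
        outs.append(y + b_D[l])
    return torch.stack(outs, dim=-2)  # (..., n_layers, d_model)

acts = torch.randn(4, n_layers, d_sae)
recon = clt_decode(acts)
```

Note how the output at layer 0 depends only on layer-0 features, while later layers accumulate contributions from all earlier layers.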
decoder_norm
Compute the effective norm of decoder weights for each feature.
Source code in src/lm_saes/clt.py
encoder_norm
Compute the norm of encoder weights averaged across layers.
Source code in src/lm_saes/clt.py
decoder_bias_norm
Compute the norm of decoder bias for each target layer.
Source code in src/lm_saes/clt.py
set_encoder_to_fixed_norm
keep_only_decoders_for_layer_from
Keep only the decoders originating from the given layer.
Source code in src/lm_saes/clt.py
decoder_norm_per_feature
Compute the norm of decoder weights for each feature. If layer is not None, only compute the norm for the decoder weights from layer to subsequent layers.
Source code in src/lm_saes/clt.py
decoder_norm_per_decoder
Compute the L2 norm of decoder weights for each decoder (layer_from -> layer_to).
Returns:
norms: torch.Tensor or DTensor of shape (n_decoders,), where n_decoders = n_layers * (n_layers + 1) // 2
Source code in src/lm_saes/clt.py
standardize_parameters_of_dataset_norm
Standardize parameters for dataset-wise normalization during inference.
Source code in src/lm_saes/clt.py
prepare_input
prepare_input(
batch: dict[str, Tensor], **kwargs
) -> tuple[Tensor, dict[str, Any], dict[str, Any]]
Prepare input tensor from batch by stacking all layer activations from hook_points_in.
Source code in src/lm_saes/clt.py
prepare_input_single_layer
prepare_input_single_layer(
batch: dict[str, Tensor], layer: int, **kwargs
) -> tuple[Tensor, dict[str, Any], dict[str, Any]]
Prepare the input tensor for a single layer from the batch, using the corresponding hook point in hook_points_in.
Source code in src/lm_saes/clt.py
prepare_label
Prepare label tensor from batch using hook_points_out.
Source code in src/lm_saes/clt.py
compute_training_metrics
compute_training_metrics(
*,
l0: Tensor,
explained_variance_legacy: Tensor,
**kwargs,
) -> dict[str, float]
Compute per-layer training metrics for CLT.
Source code in src/lm_saes/clt.py
compute_loss
compute_loss(
batch: dict[str, Tensor],
label: Optional[
Union[
Float[Tensor, "batch d_model"],
Float[Tensor, "batch seq_len d_model"],
]
] = None,
*,
sparsity_loss_type: Literal[
"power", "tanh", "tanh-quad", None
] = None,
tanh_stretch_coefficient: float = 4.0,
frequency_scale: float = 0.01,
p: int = 1,
l1_coefficient: float = 1.0,
return_aux_data: bool = True,
**kwargs,
) -> Union[Float[Tensor, " batch"], dict[str, Any]]
Compute the loss for the autoencoder.
Ensure that the input activations are normalized by calling normalize_activations before calling this method.
Source code in src/lm_saes/clt.py
dim_maps
Return dimension maps for distributed training along feature dimension.
Source code in src/lm_saes/clt.py
LorsaConfig
pydantic-model
Bases: BaseSAEConfig
Configuration for Low Rank Sparse Attention.
Fields:
-
device(str) -
dtype(dtype) -
d_model(int) -
expansion_factor(float) -
use_decoder_bias(bool) -
act_fn(Literal['relu', 'jumprelu', 'topk', 'batchtopk', 'batchlayertopk', 'layertopk']) -
norm_activation(Literal['token-wise', 'batch-wise', 'dataset-wise', 'inference']) -
sparsity_include_decoder_norm(bool) -
top_k(int) -
use_triton_kernel(bool) -
sparsity_threshold_for_triton_spmm_kernel(float) -
jumprelu_threshold_window(float) -
sae_type(str) -
hook_point_in(str) -
hook_point_out(str) -
n_qk_heads(int) -
d_qk_head(int) -
positional_embedding_type(Literal['rotary', 'none']) -
rotary_dim(int) -
rotary_base(int) -
rotary_adjacent_pairs(bool) -
rotary_scale(int) -
use_NTK_by_parts_rope(bool) -
NTK_by_parts_factor(float) -
NTK_by_parts_low_freq_factor(float) -
NTK_by_parts_high_freq_factor(float) -
old_context_len(int) -
n_ctx(int) -
attn_scale(float | None) -
use_post_qk_ln(bool) -
normalization_type(Literal['LN', 'RMS'] | None) -
eps(float)
LowRankSparseAttention
LowRankSparseAttention(
cfg: LorsaConfig,
device_mesh: Optional[DeviceMesh] = None,
)
Bases: AbstractSparseAutoEncoder
Source code in src/lm_saes/lorsa.py
init_parameters
Initialize parameters.
Source code in src/lm_saes/lorsa.py
init_lorsa_with_mhsa
Initialize Lorsa from the original multi-head self-attention (MHSA) weights.
Source code in src/lm_saes/lorsa.py
init_W_D_with_active_subspace_per_head
init_W_D_with_active_subspace_per_head(
batch: dict[str, Tensor],
mhsa: Attention | GroupedQueryAttention,
)
Initialize W_D with the active subspace for each head.
Source code in src/lm_saes/lorsa.py
init_W_V_with_active_subspace_per_head
init_W_V_with_active_subspace_per_head(
batch: dict[str, Tensor],
mhsa: Attention | GroupedQueryAttention,
)
Initialize W_V with the active subspace for each head.
Source code in src/lm_saes/lorsa.py
encoder_norm
Norm of encoder (Q/K weights).
Source code in src/lm_saes/lorsa.py
decoder_norm
Norm of decoder (O weights).
Source code in src/lm_saes/lorsa.py
decoder_bias_norm
transform_to_unit_decoder_norm
Transform to unit decoder norm.
standardize_parameters_of_dataset_norm
Standardize parameters for dataset norm.
Source code in src/lm_saes/lorsa.py
compute_hidden_pre
compute_hidden_pre(
x: Float[Tensor, "batch seq_len d_model"],
) -> Float[Tensor, "batch seq_len d_sae"]
Compute the hidden pre-activations.
Source code in src/lm_saes/lorsa.py
encode
encode(
x: Float[Tensor, "batch seq_len d_model"],
return_hidden_pre: bool = False,
return_attention_pattern: bool = False,
return_attention_score: bool = False,
**kwargs,
) -> Union[
Float[Tensor, "batch seq_len d_sae"],
Tuple[
Float[Tensor, "batch seq_len d_sae"],
Float[Tensor, "batch seq_len d_sae"],
],
Tuple[
Float[Tensor, "batch seq_len d_sae"],
Float[Tensor, "batch seq_len d_sae"],
Float[Tensor, "batch n_qk_heads q_pos k_pos"],
],
]
Encode to sparse head activations.
Source code in src/lm_saes/lorsa.py
decode
Decode head activations to output.
Source code in src/lm_saes/lorsa.py
set_decoder_to_fixed_norm
Set decoder weights to a fixed norm.
Source code in src/lm_saes/lorsa.py
set_encoder_to_fixed_norm
dim_maps
Return a dictionary mapping parameter names to dimension maps.
Returns:

| Type | Description |
|---|---|
| dict[str, DimMap] | A dictionary mapping parameter names to DimMap objects. |
Source code in src/lm_saes/lorsa.py
MOLTConfig
pydantic-model
Bases: BaseSAEConfig
Configuration for Mixture of Linear Transforms (MOLT).
MOLT is a more efficient alternative to transcoders that sparsely replaces MLP computation in transformers. It converts dense MLP layers into sparse, interpretable linear transforms.
Config:
arbitrary_types_allowed:True
Fields:
-
device(str) -
dtype(dtype) -
d_model(int) -
expansion_factor(float) -
use_decoder_bias(bool) -
act_fn(Literal['relu', 'jumprelu', 'topk', 'batchtopk', 'batchlayertopk', 'layertopk']) -
norm_activation(Literal['token-wise', 'batch-wise', 'dataset-wise', 'inference']) -
sparsity_include_decoder_norm(bool) -
top_k(int) -
use_triton_kernel(bool) -
sparsity_threshold_for_triton_spmm_kernel(float) -
jumprelu_threshold_window(float) -
sae_type(str) -
hook_point_in(str) -
hook_point_out(str) -
rank_counts(dict[int, int])
rank_counts
pydantic-field
Dictionary mapping rank values to their integer counts. Example: {4: 128, 8: 256, 16: 128} means 128 transforms of rank 4, 256 transforms of rank 8, and 128 transforms of rank 16.
generate_rank_assignments
Generate a rank assignment for each of the d_sae linear transforms.
Returns:
| Type | Description |
|---|---|
| list[int] | List of rank assignments for each transform, e.g. [1, 1, 1, 1, 2, 2, 4]. |
Source code in src/lm_saes/molt.py
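The expansion from a rank_counts mapping to a flat per-transform assignment list can be sketched as follows. This is an illustrative re-implementation of the behaviour described above, not the library's code; the actual implementation lives in src/lm_saes/molt.py.

```python
def generate_rank_assignments(rank_counts: dict[int, int]) -> list[int]:
    """Expand a {rank: count} mapping into one rank value per transform."""
    assignments: list[int] = []
    for rank in sorted(rank_counts):          # deterministic, ascending order
        assignments.extend([rank] * rank_counts[rank])
    return assignments

# d_sae is simply the total number of transforms across all rank groups:
ranks = generate_rank_assignments({4: 128, 8: 256, 16: 128})
assert len(ranks) == 512                      # 128 + 256 + 128
assert ranks[0] == 4 and ranks[-1] == 16
```

With the smaller example from above, `{1: 4, 2: 2, 4: 1}` expands to `[1, 1, 1, 1, 2, 2, 4]`.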
get_local_rank_assignments
Get rank assignments for a specific local device in distributed running.
Each device gets all rank groups, with each group evenly divided across devices. This ensures consistent encoder/decoder sharding without feature_acts redistribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model_parallel_size | int | Number of model parallel devices for training and inference. | required |
Returns:
| Type | Description |
|---|---|
| list[int] | List of rank assignments for this local device. For example: global_rank_assignments = [1, 1, 2, 2], model_parallel_size = 2 -> local_rank_assignments = [1, 2]. |
Source code in src/lm_saes/molt.py
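The sharding rule described above ("each device gets all rank groups, with each group evenly divided across devices") can be sketched as below. This is a hypothetical re-implementation for illustration only; it assumes the global assignment list is grouped into contiguous runs of equal rank and that each group's size is divisible by model_parallel_size.

```python
def get_local_rank_assignments(
    global_rank_assignments: list[int], model_parallel_size: int
) -> list[int]:
    """Give every device an even slice of each contiguous rank group."""
    local: list[int] = []
    i = 0
    while i < len(global_rank_assignments):
        # Find the end of the current contiguous rank group.
        rank = global_rank_assignments[i]
        j = i
        while j < len(global_rank_assignments) and global_rank_assignments[j] == rank:
            j += 1
        group_size = j - i
        assert group_size % model_parallel_size == 0, "group not evenly divisible"
        # Every device keeps group_size // model_parallel_size transforms of this rank.
        local.extend([rank] * (group_size // model_parallel_size))
        i = j
    return local

assert get_local_rank_assignments([1, 1, 2, 2], model_parallel_size=2) == [1, 2]
```

Because every device sees the same sequence of rank groups in the same order, encoder and decoder shards line up and feature_acts never need to be redistributed between them.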
MixtureOfLinearTransform
MixtureOfLinearTransform(
cfg: MOLTConfig, device_mesh: DeviceMesh | None = None
)
Bases: AbstractSparseAutoEncoder
Mixture of Linear Transforms (MOLT) model.
MOLT is a sparse autoencoder variant that uses d_sae linear transforms, each with its own rank for its UᵢVᵢ low-rank decomposition.
Mathematical Formulation:
- Encoder: fᵢ = ϕ(eᵢ · x - bᵢ), where ϕ is the activation function
- Decoder: Σᵢ fᵢ · (Uᵢ Vᵢ x), where fᵢ are the feature activations
- Decoder norm: ‖Uᵢ Vᵢ‖_F for each transform i
The rank of each transform is determined by the rank_counts configuration, allowing for adaptive model capacity allocation.
Source code in src/lm_saes/molt.py
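The encoder/decoder formulation above can be checked with a small NumPy sketch. All shapes and variable names here are illustrative assumptions, not the library's API; the sketch only mirrors the decoder sum Σᵢ fᵢ · (Uᵢ Vᵢ x) and the per-transform Frobenius norm.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, ranks = 8, [1, 2, 4]          # three transforms with different ranks (assumed)
x = rng.normal(size=d_model)           # a single input activation vector
f = np.array([0.0, 1.5, 0.7])          # sparse feature activations, one per transform

# Each transform i is a low-rank linear map U_i @ V_i : d_model -> d_model.
U = [rng.normal(size=(d_model, r)) for r in ranks]
V = [rng.normal(size=(r, d_model)) for r in ranks]

# Decoder: sum_i f_i * (U_i @ V_i @ x); transforms with f_i == 0 contribute nothing.
y = sum(f_i * (U_i @ (V_i @ x)) for f_i, U_i, V_i in zip(f, U, V))

# Decoder norm per transform: Frobenius norm of the product U_i V_i.
norms = [np.linalg.norm(U_i @ V_i) for U_i, V_i in zip(U, V)]

assert y.shape == (d_model,)
```

Because the rank factors are never multiplied out during decoding (each term costs O(rᵢ · d_model) rather than O(d_model²)), low-rank transforms stay cheap even when d_sae is large.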
dim_maps
Return dimension maps for distributed training.
Encoder and decoder use consistent sharding:
- W_E sharded along the d_sae (output) dimension
- U/V matrices sharded along the transform count (first) dimension

This ensures feature_acts from the encoder can directly feed the decoder without redistribution.
Source code in src/lm_saes/molt.py
encoder_norm
Compute the norm of the encoder weight.
Source code in src/lm_saes/molt.py
decoder_norm
Compute the Frobenius norm of each linear transform's UᵢVᵢ decomposition.
Source code in src/lm_saes/molt.py
set_encoder_to_fixed_norm
Set encoder weights to a fixed norm.
decode
decode(
feature_acts: Union[
Float[Tensor, "batch d_sae"],
Float[Tensor, "batch seq_len d_sae"],
],
**kwargs,
) -> Union[
Float[Tensor, "batch d_model"],
Float[Tensor, "batch seq_len d_model"],
]
Decode feature activations back to model space using MOLT transforms.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feature_acts | Union[Float[Tensor, 'batch d_sae'], Float[Tensor, 'batch seq_len d_sae']] | Feature activations from encode(). | required |
| **kwargs | | Must contain 'original_x', the original input tensor. | {} |

Returns:
| Type | Description |
|---|---|
| Union[Float[Tensor, 'batch d_model'], Float[Tensor, 'batch seq_len d_model']] | Reconstructed tensor in model space. |
Source code in src/lm_saes/molt.py
compute_training_metrics
Compute per-rank group training metrics for MOLT.
Source code in src/lm_saes/molt.py
forward
forward(
x: Union[
Float[Tensor, "batch d_model"],
Float[Tensor, "batch seq_len d_model"],
],
encoder_kwargs: dict[str, Any] = {},
decoder_kwargs: dict[str, Any] = {},
) -> Union[
Float[Tensor, "batch d_model"],
Float[Tensor, "batch seq_len d_model"],
]
Forward pass through the autoencoder.
Ensure that the input activations are normalized by calling normalize_activations before calling this method.
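The full forward pass (encode to sparse feature activations, then decode against the original input) can be sketched in NumPy under assumed shapes. Every name here is illustrative, not the library's API: a ReLU + top-k encoder stands in for the configured act_fn, all transforms share rank 2 for brevity, and the input is assumed already normalized as the note above requires.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_sae, top_k = 8, 16, 3

W_E = rng.normal(size=(d_model, d_sae))     # encoder weight
b_E = np.zeros(d_sae)                       # encoder bias
U = rng.normal(size=(d_sae, d_model, 2))    # rank-2 decoder factors (assumed)
V = rng.normal(size=(d_sae, 2, d_model))

def encode(x):
    pre = x @ W_E - b_E                     # hidden pre-activations
    acts = np.maximum(pre, 0.0)             # relu
    acts[np.argsort(acts)[:-top_k]] = 0.0   # topk: zero all but the top_k largest
    return acts

def decode(feature_acts, original_x):
    # sum_i f_i * (U_i @ V_i @ original_x), vectorized over all transforms
    per_transform = np.einsum("idr,irm,m->id", U, V, original_x)
    return feature_acts @ per_transform

def forward(x):
    # x is assumed already normalized (cf. normalize_activations above)
    return decode(encode(x), original_x=x)

x = rng.normal(size=d_model)
y = forward(x)
assert y.shape == (d_model,)
```

Note how the decoder needs the original input: unlike a plain SAE, MOLT reconstructs a linear transform *of x*, so `original_x` must be threaded through decoder_kwargs.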