Low-Rank Sparse Attention (Lorsa)
Low-Rank Sparse Attention (Lorsa) is a specialized sparse dictionary architecture designed to decompose attention layers into interpretable sparse components. Just as transcoders decompose MLP computation, Lorsa decomposes attention computation, explicitly modeling a sparse query-key-value structure.
Given an input sequence \(X \in \mathbb{R}^{n_{\text{ctx}} \times d_{\text{model}}}\), Lorsa has:
- \(n_{\text{qk}}\) QK heads, each with projections \(W_q^h, W_k^h \in \mathbb{R}^{d_{\text{model}} \times d_{\text{qk}}}\)
- \(n_{\text{ov}}\) rank-1 OV heads, each with projections \(\mathbf{w}_v^i \in \mathbb{R}^{d_{\text{model}} \times 1}\), \(\mathbf{w}_o^i \in \mathbb{R}^{1 \times d_{\text{model}}}\)
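Every group of \(n_{\text{ov}} / n_{\text{qk}}\) consecutive OV heads shares the same QK head. Let \(h\) be the QK head assigned to OV head \(i\). The forward pass for each OV head \(i\) is:
\[
\mathbf{Q}^h = X W_q^h, \qquad \mathbf{K}^h = X W_k^h, \qquad \mathbf{A}^h = \operatorname{softmax}\!\left(\operatorname{mask}\!\left(\frac{\mathbf{Q}^h (\mathbf{K}^h)^\top}{\sqrt{d_{\text{qk}}}}\right)\right),
\]
\[
\mathbf{v}^i = X \mathbf{w}_v^i, \qquad \mathbf{z}^i = \mathbf{A}^h \mathbf{v}^i \in \mathbb{R}^{n_{\text{ctx}} \times 1},
\]
where \(\operatorname{mask}(\cdot)\) applies the causal mask and \(\mathbf{A}^h \in \mathbb{R}^{n_{\text{ctx}} \times n_{\text{ctx}}}\) is the attention pattern of QK head \(h\).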
The pre-activations across all OV heads are then passed through a sparsity-inducing activation function \(\sigma(\cdot)\):
\[
\tilde{\mathbf{Z}} = \sigma(\mathbf{Z}),
\]
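where \(\mathbf{Z} = [\mathbf{z}^0, \ldots, \mathbf{z}^{n_{\text{ov}}-1}]\) and \(\tilde{\mathbf{Z}} = [\tilde{\mathbf{z}}^0, \ldots, \tilde{\mathbf{z}}^{n_{\text{ov}}-1}]\). The final output sums the contributions of all OV heads weighted by their activations:
\[
\mathbf{Y} = \sum_{i=0}^{n_{\text{ov}}-1} \tilde{\mathbf{z}}^i \mathbf{w}_o^i \in \mathbb{R}^{n_{\text{ctx}} \times d_{\text{model}}}.
\]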
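The forward pass described above can be sketched in NumPy for a single sequence. This is a minimal illustration under stated assumptions, not the lm_saes implementation: it assumes \(\sigma\) is a per-position TopK over OV heads (matching `act_fn="topk"`), uses the default \(1/\sqrt{d_{\text{qk}}}\) scale, and omits rotary embeddings, post-QK normalization, and batching. The function name `lorsa_forward` and the weight layouts are hypothetical.

```python
import numpy as np


def lorsa_forward(X, W_q, W_k, w_v, w_o, top_k):
    """Hypothetical minimal Lorsa forward pass for one sequence.

    X:        (n_ctx, d_model) input sequence
    W_q, W_k: (n_qk, d_model, d_qk) per-QK-head projections
    w_v:      (n_ov, d_model) row i is the rank-1 value vector w_v^i
    w_o:      (n_ov, d_model) row i is the rank-1 output vector w_o^i
    Returns the output Y (n_ctx, d_model) and sparse activations (n_ctx, n_ov).
    """
    n_ctx = X.shape[0]
    n_qk, _, d_qk = W_q.shape
    n_ov = w_v.shape[0]
    group = n_ov // n_qk  # OV heads per QK head; OV head i uses QK head i // group

    # Attention pattern per QK head: A^h = softmax(causal_mask(Q K^T / sqrt(d_qk)))
    Q = np.einsum("td,hde->hte", X, W_q)  # (n_qk, n_ctx, d_qk)
    K = np.einsum("td,hde->hte", X, W_k)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_qk)  # (n_qk, n_ctx, n_ctx)
    causal = np.tril(np.ones((n_ctx, n_ctx), dtype=bool))
    scores = np.where(causal, scores, -np.inf)
    A = np.exp(scores - scores.max(axis=-1, keepdims=True))
    A /= A.sum(axis=-1, keepdims=True)

    # Rank-1 values v^i = X w_v^i, grouped so consecutive OV heads share a QK head
    V = (X @ w_v.T).reshape(n_ctx, n_qk, group)
    # z^i = A^h v^i for every OV head at once
    Z = np.einsum("hts,shg->thg", A, V).reshape(n_ctx, n_ov)

    # sigma: keep the top_k pre-activations at each position, zero the rest
    idx = np.argpartition(Z, -top_k, axis=-1)[:, -top_k:]
    Z_tilde = np.zeros_like(Z)
    np.put_along_axis(Z_tilde, idx, np.take_along_axis(Z, idx, axis=-1), axis=-1)

    # Y = sum_i z~^i w_o^i
    return Z_tilde @ w_o, Z_tilde


# Toy example with random weights
rng = np.random.default_rng(0)
n_ctx, d_model, n_qk, d_qk, n_ov, k = 6, 8, 2, 4, 16, 3
X = rng.standard_normal((n_ctx, d_model))
W_q = rng.standard_normal((n_qk, d_model, d_qk))
W_k = rng.standard_normal((n_qk, d_model, d_qk))
w_v = rng.standard_normal((n_ov, d_model))
w_o = rng.standard_normal((n_ov, d_model))
Y, Z_tilde = lorsa_forward(X, W_q, W_k, w_v, w_o, top_k=k)
```

Because every OV head in a group reuses one attention pattern, the pattern is computed \(n_{\text{qk}}\) times rather than \(n_{\text{ov}}\) times; only the rank-1 value and output vectors are per-OV-head.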
The architecture was introduced in Towards Understanding the Nature of Attention with Low-Rank Sparse Decomposition (ICLR 2026), which proposes using sparse dictionary learning to address attention superposition—the challenge of disentangling attention-mediated interactions between features at different token positions. For detailed architectural specifications and mathematical formulations, please refer to this paper.
Configuration
Lorsa is configured using the LorsaConfig class. All sparse dictionary models inherit common parameters from SparseDictionaryConfig. See the Common Configuration Parameters section for the full list of inherited parameters.
Lorsa-Specific Parameters
```python
from lm_saes import LorsaConfig
import torch

lorsa_config = LorsaConfig(
    # Hook points
    hook_point_in="blocks.13.ln1.hook_normalized",
    hook_point_out="blocks.13.hook_attn_out",
    # Attention dimensions
    n_qk_heads=16,
    d_qk_head=128,
    n_ctx=2048,
    # Positional embeddings
    positional_embedding_type="rotary",
    rotary_dim=128,
    rotary_base=1000000,
    rotary_adjacent_pairs=False,
    rotary_scale=1,
    # NTK-aware RoPE (optional)
    use_NTK_by_parts_rope=False,
    NTK_by_parts_factor=1.0,
    NTK_by_parts_low_freq_factor=1.0,
    NTK_by_parts_high_freq_factor=1.0,
    old_context_len=2048,
    # Attention settings
    attn_scale=None,
    use_post_qk_ln=True,
    normalization_type="RMS",
    eps=1e-6,
    # Common parameters
    d_model=2048,
    expansion_factor=32,
    act_fn="topk",
    top_k=256,
    dtype=torch.float32,
    device="cuda",
)
```
Hook Points
| Parameter | Type | Description | Default |
|---|---|---|---|
| `hook_point_in` | `str` | Input hook point, typically the attention input (e.g., `blocks.L.ln1.hook_normalized`). | Required |
| `hook_point_out` | `str` | Output hook point, typically the attention output (e.g., `blocks.L.hook_attn_out`). | Required |
Attention Dimensions
We recommend setting d_qk_head to match the target model's head dimension. n_qk_heads can be chosen freely: a natural starting point is n_qk_heads = n_heads * expansion_factor, where n_heads is the number of attention heads in the target attention layer. A smaller value is also reasonable if you want to reduce Lorsa's parameter count, but it should not be less than n_heads.
| Parameter | Type | Description | Default |
|---|---|---|---|
| `n_qk_heads` | `int` | Number of QK heads. | Required |
| `d_qk_head` | `int` | Dimension per QK head. | Required |
| `n_ctx` | `int` | Maximum context length. | Required |
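To make the trade-off in the recommendation concrete, the sketch below computes the suggested starting point, checks the lower bound of n_heads, and counts QK parameters. It assumes each QK head has only its two \(d_{\text{model}} \times d_{\text{qk}}\) projection matrices (biases and normalization weights ignored), and it plugs in d_model=2048 and d_qk_head=128 from the configuration example with an assumed 16-head target layer. All helper names are hypothetical.

```python
def suggested_n_qk_heads(n_heads, expansion_factor):
    """Natural starting point from the recommendation: n_heads * expansion_factor."""
    return n_heads * expansion_factor


def check_n_qk_heads(n_qk_heads, n_heads):
    """Reject values below the recommended lower bound of n_heads."""
    if n_qk_heads < n_heads:
        raise ValueError(f"n_qk_heads={n_qk_heads} is below n_heads={n_heads}")


def qk_param_count(n_qk_heads, d_model, d_qk_head):
    """Parameters in W_q^h and W_k^h summed over all QK heads (biases/norms ignored)."""
    return 2 * n_qk_heads * d_model * d_qk_head


n_heads, expansion_factor, d_model, d_qk_head = 16, 32, 2048, 128
start = suggested_n_qk_heads(n_heads, expansion_factor)
for n_qk in (start, n_heads):
    check_n_qk_heads(n_qk, n_heads)
    print(n_qk, qk_param_count(n_qk, d_model, d_qk_head))
# 512 268435456
# 16 8388608
```

Under these assumptions, going from the starting point down to the lower bound shrinks the QK parameters by a factor of 32 (the expansion factor).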
Number of OV Heads
The number of OV heads is automatically computed as: n_ov_heads = expansion_factor * d_model (same as d_sae).
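With the configuration example (d_model=2048, expansion_factor=32, n_qk_heads=16), the OV heads and their consecutive grouping can be worked out as below. This is a sketch with hypothetical helper names; the divisibility check is an assumption that follows from the consecutive grouping described in the architecture section, not a documented library error.

```python
def ov_head_layout(d_model, expansion_factor, n_qk_heads):
    """Return (n_ov_heads, OV heads per QK head)."""
    n_ov_heads = expansion_factor * d_model  # same as d_sae
    # Assumed requirement: OV heads split evenly into consecutive groups
    if n_ov_heads % n_qk_heads != 0:
        raise ValueError("n_ov_heads must be divisible by n_qk_heads")
    return n_ov_heads, n_ov_heads // n_qk_heads


def qk_head_of(ov_head, heads_per_group):
    """Index h of the QK head shared by OV head i under consecutive grouping."""
    return ov_head // heads_per_group


n_ov_heads, per_group = ov_head_layout(d_model=2048, expansion_factor=32, n_qk_heads=16)
print(n_ov_heads, per_group)  # 65536 4096
print(qk_head_of(0, per_group), qk_head_of(5000, per_group), qk_head_of(n_ov_heads - 1, per_group))  # 0 1 15
```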
Positional Embeddings
It is strongly recommended to copy the positional embedding parameters directly from the target model's implementation. Incorrect settings will make it harder for Lorsa to learn the target attention patterns.
| Parameter | Type | Description | Default |
|---|---|---|---|
| `positional_embedding_type` | `str` | Type of positional embedding: "rotary" or "none". | "rotary" |
| `rotary_dim` | `int` | Dimension of rotary embeddings (typically `d_qk_head`). | Required |
| `rotary_base` | `int` | Base of the rotary embedding frequencies. | 10000 |
| `rotary_adjacent_pairs` | `bool` | Whether to apply RoPE on adjacent pairs of dimensions. | True |
| `rotary_scale` | `int` | Scaling factor of the head dimension for rotary embeddings. | 1 |
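The sketch below illustrates what rotary_base and rotary_adjacent_pairs control. It assumes the usual meaning of the flag (as in TransformerLens): True rotates adjacent dimension pairs \((x_0, x_1), (x_2, x_3), \ldots\) (GPT-J style), while False rotates \((x_j, x_{j + d/2})\) (the "rotate_half" style used by GPT-NeoX and Hugging Face models). It is an illustration, not the Lorsa code, and `apply_rope` is a hypothetical name; rotary_scale is not modeled.

```python
import numpy as np


def apply_rope(x, positions, rotary_dim, rotary_base=10000, adjacent_pairs=True):
    """Rotate the first rotary_dim dims of x (n_ctx, d_head); the rest pass through."""
    half = rotary_dim // 2
    # Frequencies theta_j = base^(-2j / rotary_dim), j = 0 .. half - 1
    inv_freq = 1.0 / (rotary_base ** (np.arange(0, rotary_dim, 2) / rotary_dim))
    angles = np.outer(positions, inv_freq)  # (n_ctx, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x_rot, x_pass = x[:, :rotary_dim], x[:, rotary_dim:]
    if adjacent_pairs:  # pairs (x0, x1), (x2, x3), ...
        a, b = x_rot[:, 0::2], x_rot[:, 1::2]
    else:  # pairs (x_j, x_{j + half})
        a, b = x_rot[:, :half], x_rot[:, half:]
    out = np.empty_like(x_rot)
    if adjacent_pairs:
        out[:, 0::2], out[:, 1::2] = a * cos - b * sin, a * sin + b * cos
    else:
        out[:, :half], out[:, half:] = a * cos - b * sin, a * sin + b * cos
    return np.concatenate([out, x_pass], axis=-1)


rng = np.random.default_rng(0)
q, k = rng.standard_normal((1, 8)), rng.standard_normal((1, 8))
for adjacent in (True, False):
    # Scores depend only on the relative offset (5 - 2 == 13 - 10)
    s1 = apply_rope(q, [5], 8, adjacent_pairs=adjacent) @ apply_rope(k, [2], 8, adjacent_pairs=adjacent).T
    s2 = apply_rope(q, [13], 8, adjacent_pairs=adjacent) @ apply_rope(k, [10], 8, adjacent_pairs=adjacent).T
    print(adjacent, np.allclose(s1, s2))
```

Both conventions encode relative position, but they pair different coordinates, so a setting that does not match the target model rotates the wrong pairs of dimensions and makes its attention patterns harder to reproduce.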
NTK-Aware RoPE (Llama 3.1 and Llama 3.2 models only)
| Parameter | Type | Description | Default |
|---|---|---|---|
| `use_NTK_by_parts_rope` | `bool` | Enable NTK-aware RoPE scaling for extended context | False |
| `NTK_by_parts_factor` | `float` | NTK scaling factor | 1.0 |
| `NTK_by_parts_low_freq_factor` | `float` | Low-frequency component scaling factor | 1.0 |
| `NTK_by_parts_high_freq_factor` | `float` | High-frequency component scaling factor | 1.0 |
| `old_context_len` | `int` | Original context length before scaling | 2048 |
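These parameter names match the "NTK-by-parts" frequency scaling that Meta released with Llama 3.1. The sketch below reproduces that reference rule, assuming Lorsa applies the same rule to its rotary frequencies: components whose wavelength is shorter than old_context_len / NTK_by_parts_high_freq_factor are left unchanged, components whose wavelength is longer than old_context_len / NTK_by_parts_low_freq_factor are divided by NTK_by_parts_factor, and those in between are blended smoothly. The example values (factor 8, low 1, high 4, original context 8192, base 500000) are the ones published for Llama 3.1; for another target model, copy them from its config. The function name is hypothetical.

```python
import math


def ntk_by_parts_inv_freq(inv_freq, factor, low_freq_factor, high_freq_factor, old_context_len):
    """Rescale RoPE inverse frequencies following the Llama 3.1 reference rule."""
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    scaled = []
    for freq in inv_freq:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:  # high frequency: keep as is
            scaled.append(freq)
        elif wavelen > low_freq_wavelen:  # low frequency: interpolate by `factor`
            scaled.append(freq / factor)
        else:  # in between: blend the two regimes
            span = high_freq_factor - low_freq_factor
            smooth = 1.0 if span == 0 else (old_context_len / wavelen - low_freq_factor) / span
            scaled.append((1 - smooth) * freq / factor + smooth * freq)
    return scaled


rotary_dim, base = 128, 500000
inv_freq = [1.0 / base ** (i / rotary_dim) for i in range(0, rotary_dim, 2)]
scaled = ntk_by_parts_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                               high_freq_factor=4.0, old_context_len=8192)
```

With all factors at their 1.0 defaults the rule leaves every frequency unchanged, which is consistent with use_NTK_by_parts_rope=False being harmless for models that do not use it.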
Attention Computation Details
| Parameter | Type | Description | Default |
|---|---|---|---|
| `attn_scale` | `float` or `None` | Attention scaling factor. If `None`, uses \(\frac{1}{\sqrt{d_{\text{qk\_head}}}}\) | None |
| `use_post_qk_ln` | `bool` | Apply LayerNorm/RMSNorm after computing Q and K projections | False |
| `normalization_type` | `str` or `None` | Normalization type: "LN" (LayerNorm) or "RMS" (RMSNorm). Only used when `use_post_qk_ln=True` | None |
| `eps` | `float` | Epsilon for numerical stability in normalization | 1e-6 |
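A minimal sketch of how these settings could enter the pre-softmax scores of one QK head, under stated assumptions: only the "RMS" option is shown (with "LN" the mean would also be subtracted), the norm is taken over the head dimension with its learnable gain omitted, and where rotary embeddings are applied relative to the norm is left out. Function names are hypothetical.

```python
import numpy as np


def resolve_attn_scale(attn_scale, d_qk_head):
    """attn_scale=None falls back to 1 / sqrt(d_qk_head)."""
    return 1.0 / np.sqrt(d_qk_head) if attn_scale is None else attn_scale


def rms_norm(x, eps=1e-6):
    """RMSNorm over the head dimension (learnable gain omitted)."""
    return x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)


def qk_scores(q, k, attn_scale=None, use_post_qk_ln=False, eps=1e-6):
    """Pre-softmax scores for one QK head; q and k have shape (n_ctx, d_qk_head)."""
    if use_post_qk_ln:
        q, k = rms_norm(q, eps), rms_norm(k, eps)
    return (q @ k.T) * resolve_attn_scale(attn_scale, q.shape[-1])


rng = np.random.default_rng(0)
q = rng.standard_normal((4, 128)) * 10.0
k = rng.standard_normal((4, 128)) * 10.0
print(resolve_attn_scale(None, 128))  # approximately 0.0884
scores = qk_scores(q, k, use_post_qk_ln=True)
```

The post-QK norm bounds the logits: after RMSNorm each row has a norm of about \(\sqrt{d_{\text{qk\_head}}}\), so by Cauchy-Schwarz every scaled score lies within \(\pm\sqrt{d_{\text{qk\_head}}}\) regardless of the input magnitude.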
Initialization Strategy
For Lorsa, initialization from the original model's attention weights is highly recommended:
```python
from lm_saes import InitializerConfig

initializer_config = InitializerConfig(
    grid_search_init_norm=True,
    initialize_lorsa_with_mhsa=True,  # Initialize Q, K from attention weights
    initialize_W_D_with_active_subspace=True,  # Initialize V, O from attention weights
    model_layer=13,  # Layer to extract attention weights from
)
```
This initialization helps Lorsa start from a good approximation of the attention computation.
Training
Basic Training Setup
Lorsa requires 2D activations with sequence dimension preserved (ActivationFactoryTarget.ACTIVATIONS_2D) since it models positional attention patterns:
```python
from lm_saes import (
    TrainLorsaSettings,
    train_lorsa,
    LorsaConfig,
    InitializerConfig,
    TrainerConfig,
    ActivationFactoryConfig,
    ActivationFactoryActivationsSource,
    ActivationFactoryTarget,
    LanguageModelConfig,
)
import torch

settings = TrainLorsaSettings(
    sae=LorsaConfig(
        hook_point_in="blocks.13.ln1.hook_normalized",
        hook_point_out="blocks.13.hook_attn_out",
        # ... other settings ...
    ),
    initializer=InitializerConfig(
        grid_search_init_norm=True,
        initialize_lorsa_with_mhsa=True,  # Initialize with original attention weights
        initialize_W_D_with_active_subspace=True,
        model_layer=13,
    ),
    trainer=TrainerConfig(
        lr=2e-4,
        total_training_tokens=800_000_000,
        initial_k=256,
        k_warmup_steps=1500,
        log_frequency=1000,
        exp_result_path="results/lorsa",
    ),
    activation_factory=ActivationFactoryConfig(
        sources=[
            ActivationFactoryActivationsSource(
                path="path/to/cached/activations",
                name="lorsa-activations",
                device="cuda",
            )
        ],
        target=ActivationFactoryTarget.ACTIVATIONS_2D,  # Preserve sequence dimension
        hook_points=[
            "blocks.13.ln1.hook_normalized",
            "blocks.13.hook_attn_out",
        ],
        batch_size=16,  # Batch size is per-sequence, not per-token
    ),
    sae_name="qwen-lorsa",
    sae_series="qwen-interpretability",
    model_name="Qwen/Qwen3-1.7B",
    model=LanguageModelConfig(
        model_name="Qwen/Qwen3-1.7B",
        device="cuda",
        dtype=torch.float16,
        model_from_pretrained_path="path/to/model",
    ),
    data_parallel_size=1,
    model_parallel_size=1,
)

train_lorsa(settings)
```
Important Training Considerations
- Sequence batching: Since Lorsa operates on sequences, `batch_size` in `ActivationFactoryConfig` is the number of sequences, not tokens. The effective token batch size is `batch_size * n_ctx`.
- Memory requirements: Lorsa stores attention patterns and requires more memory than standard SAEs. Consider using parallelism (see distributed-guidelines) or reducing the batch size.
- Context length: Ensure `n_ctx` in `LorsaConfig` matches the `context_size` in `ActivationFactoryConfig` used during activation generation.
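The arithmetic behind these considerations can be sketched with the values from the examples on this page (batch_size=16, total_training_tokens=800_000_000, n_ctx=2048, n_qk_heads=16). The helpers are hypothetical; the step count assumes one optimizer step per batch, and the memory figure assumes one float32 \(n_{\text{ctx}} \times n_{\text{ctx}}\) pattern is materialized per sequence and QK head, ignoring activations, gradients, and optimizer state. Treat it as a rough order of magnitude for the patterns alone, not a measurement.

```python
def tokens_per_batch(batch_size, n_ctx):
    """batch_size counts sequences, so each batch holds batch_size * n_ctx tokens."""
    return batch_size * n_ctx


def approx_training_steps(total_training_tokens, batch_size, n_ctx):
    """Assumes one optimizer step per batch."""
    return total_training_tokens // tokens_per_batch(batch_size, n_ctx)


def attention_pattern_bytes(batch_size, n_qk_heads, n_ctx, bytes_per_elem=4):
    """Size of materialized (n_ctx x n_ctx) patterns for every sequence and QK head."""
    return batch_size * n_qk_heads * n_ctx * n_ctx * bytes_per_elem


def check_context_length(n_ctx, context_size):
    """LorsaConfig.n_ctx should equal the context size used to generate activations."""
    if n_ctx != context_size:
        raise ValueError(f"n_ctx={n_ctx} does not match context_size={context_size}")


check_context_length(n_ctx=2048, context_size=2048)
print(tokens_per_batch(16, 2048))  # 32768
print(approx_training_steps(800_000_000, 16, 2048))  # 24414
print(attention_pattern_bytes(16, 16, 2048) / 2**30)  # 4.0 (GiB, float32)
```

Because the pattern cost grows with batch_size and quadratically with n_ctx, halving the batch size halves this figure, while halving the context length would cut it by four (at the price of re-generating activations with a matching context size).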