# Low-Rank Sparse Attention (Lorsa)
Low-Rank Sparse Attention (Lorsa) is a specialized sparse dictionary architecture designed to decompose attention layers into interpretable sparse components. Unlike standard SAEs, which treat attention as a black box, Lorsa explicitly models the query-key-value structure while maintaining sparsity and interpretability. Because each sparse feature carries an explicit query-key attention mechanism, Lorsa preserves positional information, enabling fine-grained analysis of attention patterns and of how models route information based on both content and position.
The architecture was introduced in *Towards Understanding the Nature of Attention with Low-Rank Sparse Decomposition* (ICLR 2026), which proposes using sparse dictionary learning to address attention superposition: the challenge of disentangling attention-mediated interactions between features at different token positions. For detailed architectural specifications and mathematical formulations, refer to the paper.
## Configuration

Lorsa is configured using the `LorsaConfig` class. All sparse dictionary models inherit common parameters from `BaseSAEConfig`. See the Common Configuration Parameters section for the full list of inherited parameters.
### Lorsa-Specific Parameters
```python
from lm_saes import LorsaConfig
import torch

lorsa_config = LorsaConfig(
    # Hook points
    hook_point_in="blocks.13.ln1.hook_normalized",
    hook_point_out="blocks.13.hook_attn_out",
    # Attention dimensions
    n_qk_heads=16,
    d_qk_head=128,
    n_ctx=2048,
    # Positional embeddings
    positional_embedding_type="rotary",
    rotary_dim=128,
    rotary_base=1000000,
    rotary_adjacent_pairs=False,
    rotary_scale=1,
    # NTK-aware RoPE (optional)
    use_NTK_by_parts_rope=False,
    NTK_by_parts_factor=1.0,
    NTK_by_parts_low_freq_factor=1.0,
    NTK_by_parts_high_freq_factor=1.0,
    old_context_len=2048,
    # Attention settings
    attn_scale=None,
    use_post_qk_ln=True,
    normalization_type="RMS",
    eps=1e-6,
    # Common parameters
    d_model=2048,
    expansion_factor=32,
    act_fn="topk",
    top_k=256,
    dtype=torch.float32,
    device="cuda",
)
```
#### Hook Points

| Parameter | Type | Description | Default |
|---|---|---|---|
| `hook_point_in` | `str` | Input hook point, typically the attention input (e.g., `blocks.L.ln1.hook_normalized`). Must differ from `hook_point_out` | Required |
| `hook_point_out` | `str` | Output hook point, typically the attention output (e.g., `blocks.L.hook_attn_out`). Must differ from `hook_point_in` | Required |
#### Attention Dimensions

| Parameter | Type | Description | Default |
|---|---|---|---|
| `n_qk_heads` | `int` | Number of query-key attention heads | Required |
| `d_qk_head` | `int` | Dimension per query-key head | Required |
| `n_ctx` | `int` | Maximum context length / sequence length | Required |
#### Number of Value Heads

The number of value heads (output features) is automatically computed as `n_ov_heads = expansion_factor * d_model` (equal to `d_sae`). The `ov_group_size` is `n_ov_heads // n_qk_heads`.
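To make the derived sizes concrete, here is the arithmetic with hypothetical numbers mirroring the example configuration above:

```python
# Hypothetical values mirroring the example config above.
d_model = 2048
expansion_factor = 32
n_qk_heads = 16

# One value (OV) head per dictionary feature, so n_ov_heads equals d_sae.
n_ov_heads = expansion_factor * d_model
# Value heads are split evenly across the query-key heads.
ov_group_size = n_ov_heads // n_qk_heads

print(n_ov_heads, ov_group_size)  # 65536 4096
```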
#### Positional Embeddings

| Parameter | Type | Description | Default |
|---|---|---|---|
| `positional_embedding_type` | `str` | Type of positional embedding: `"rotary"` or `"none"` | `"rotary"` |
| `rotary_dim` | `int` | Dimension of rotary embeddings (typically `d_qk_head`) | Required |
| `rotary_base` | `int` | Base for the rotary embedding frequencies | `10000` |
| `rotary_adjacent_pairs` | `bool` | Whether to apply RoPE on adjacent pairs vs. all dimensions | `True` |
| `rotary_scale` | `int` | Scaling factor for rotary embeddings | `1` |
#### NTK-Aware RoPE (for Llama 3.1 and 3.2 herd models)

| Parameter | Type | Description | Default |
|---|---|---|---|
| `use_NTK_by_parts_rope` | `bool` | Enable NTK-aware RoPE scaling for extended context | `False` |
| `NTK_by_parts_factor` | `float` | NTK scaling factor | `1.0` |
| `NTK_by_parts_low_freq_factor` | `float` | Low-frequency component scaling factor | `1.0` |
| `NTK_by_parts_high_freq_factor` | `float` | High-frequency component scaling factor | `1.0` |
| `old_context_len` | `int` | Original context length before scaling | `2048` |
#### Attention Settings

| Parameter | Type | Description | Default |
|---|---|---|---|
| `attn_scale` | `float \| None` | Attention scaling factor. If `None`, uses \(\frac{1}{\sqrt{d_{\text{qk\_head}}}}\) | `None` |
| `use_post_qk_ln` | `bool` | Apply LayerNorm/RMSNorm after computing the Q and K projections | `False` |
| `normalization_type` | `str \| None` | Normalization type: `"LN"` (LayerNorm) or `"RMS"` (RMSNorm). Only used when `use_post_qk_ln=True` | `None` |
| `eps` | `float` | Epsilon for numerical stability in normalization | `1e-6` |
## Training

### Basic Training Setup

Lorsa requires 2D activations with the sequence dimension preserved (`ActivationFactoryTarget.ACTIVATIONS_2D`), since it models positional attention patterns:
```python
from lm_saes import (
    TrainLorsaSettings,
    train_lorsa,
    LorsaConfig,
    InitializerConfig,
    TrainerConfig,
    ActivationFactoryConfig,
    ActivationFactoryActivationsSource,
    ActivationFactoryTarget,
    LanguageModelConfig,
)
import torch

settings = TrainLorsaSettings(
    sae=LorsaConfig(
        hook_point_in="blocks.13.ln1.hook_normalized",
        hook_point_out="blocks.13.hook_attn_out",
        d_model=2048,
        expansion_factor=32,
        # Attention configuration
        n_qk_heads=16,
        d_qk_head=128,
        n_ctx=2048,
        # RoPE configuration
        positional_embedding_type="rotary",
        rotary_dim=128,
        rotary_base=1000000,
        rotary_adjacent_pairs=False,
        # Sparsity
        act_fn="topk",
        top_k=256,
        # Normalization
        use_post_qk_ln=True,
        normalization_type="RMS",
        eps=1e-6,
        dtype=torch.float32,
        device="cuda",
    ),
    initializer=InitializerConfig(
        grid_search_init_norm=True,
        initialize_lorsa_with_mhsa=True,  # Initialize with original attention weights
        initialize_W_D_with_active_subspace=True,
        model_layer=13,
    ),
    trainer=TrainerConfig(
        lr=2e-4,
        total_training_tokens=800_000_000,
        initial_k=256,
        k_warmup_steps=1500,
        log_frequency=1000,
        exp_result_path="results/lorsa",
    ),
    activation_factory=ActivationFactoryConfig(
        sources=[
            ActivationFactoryActivationsSource(
                path="path/to/cached/activations",
                name="lorsa-activations",
                device="cuda",
            )
        ],
        target=ActivationFactoryTarget.ACTIVATIONS_2D,  # Preserve sequence dimension
        hook_points=[
            "blocks.13.ln1.hook_normalized",
            "blocks.13.hook_attn_out",
        ],
        batch_size=16,  # Batch size is per-sequence, not per-token
    ),
    sae_name="qwen-lorsa",
    sae_series="qwen-interpretability",
    model_name="Qwen/Qwen3-1.7B",
    model=LanguageModelConfig(
        model_name="Qwen/Qwen3-1.7B",
        device="cuda",
        dtype=torch.float16,
        model_from_pretrained_path="path/to/model",
    ),
    data_parallel_size=1,
    model_parallel_size=1,
)

train_lorsa(settings)
```
### Initialization Strategy

For Lorsa, initialization from the original model's attention weights is highly recommended:
```python
InitializerConfig(
    grid_search_init_norm=True,
    initialize_lorsa_with_mhsa=True,  # Initialize Q, K from attention weights
    initialize_W_D_with_active_subspace=True,  # Initialize V, O from attention weights
    model_layer=13,  # Layer to extract attention weights from
)
```
This initialization helps Lorsa start from a good approximation of the attention computation.
### Important Training Considerations

- **Sequence batching**: Since Lorsa operates on sequences, `batch_size` in `ActivationFactoryConfig` represents the number of sequences (not tokens). The effective token batch size is `batch_size * n_ctx`.
- **Memory requirements**: Lorsa stores attention patterns and requires more memory than standard SAEs. Consider using parallelism (see distributed-guidelines) or reducing the batch size.
- **Context length**: Ensure `n_ctx` in `LorsaConfig` matches the `context_size` in `ActivationFactoryConfig` used during activation generation.
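The batching and context-length constraints above reduce to plain arithmetic; a quick sanity check with hypothetical values mirroring the training example:

```python
# Hypothetical values mirroring the training example above.
batch_size = 16      # sequences per batch (ActivationFactoryConfig.batch_size)
n_ctx = 2048         # LorsaConfig.n_ctx
context_size = 2048  # context length used when the activations were cached

# n_ctx must match the context length of the cached activations.
assert n_ctx == context_size, "n_ctx must equal context_size"

# Effective token batch size: sequences per batch times tokens per sequence.
effective_tokens = batch_size * n_ctx
print(effective_tokens)  # 32768
```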