
Models

Sparse dictionary model architectures and their configuration classes.

BaseSAEConfig pydantic-model

Bases: BaseModelConfig, ABC

Base class for SAE configs, holding settings common to the various SAE variants. This class should not be used directly. It is meant only as the base config class for specific SAE variants such as SAEConfig and CrossCoderConfig.

Fields:

sae_type pydantic-field

sae_type: str

The type of the sparse dictionary. Must be one of the registered SAE types.

d_model pydantic-field

d_model: int

The dimension of the input/label activation space. In the common setting where activations come from a transformer, this is the model dimension (also known as hidden_size).

expansion_factor pydantic-field

expansion_factor: float

The expansion factor of the sparse dictionary. The hidden dimension of the sparse dictionary d_sae is d_model * expansion_factor.

use_decoder_bias pydantic-field

use_decoder_bias: bool = True

Whether to use a bias term in the decoder. Including a bias term may make it easier to train a better sparse dictionary, at the cost of increased architectural complexity.

act_fn pydantic-field

act_fn: Literal[
    "relu",
    "jumprelu",
    "topk",
    "batchtopk",
    "batchlayertopk",
    "layertopk",
] = "relu"

The activation function to use for the sparse dictionary. Currently supported activation functions are relu, jumprelu, topk, batchtopk, batchlayertopk, and layertopk.

  • relu: ReLU activation function. Used in the most vanilla SAE settings.
  • jumprelu: JumpReLU activation function, which adds a trainable element-wise threshold that pre-activations must exceed to be activated. It is formally defined as $f(x) = x \cdot H(x - \theta)$, where $H$ is the Heaviside step function and $\theta$ is the threshold. Proposed in Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders.
  • topk: TopK activation function. Retains the top K activations per sample, zeroing out the rest. Proposed in Scaling and evaluating sparse autoencoders.
  • batchtopk: BatchTopK activation function. BatchTopK relaxes the TopK constraint to the batch level, retaining the top k * batch_size activations per batch and zeroing out the rest. This allows a more adaptive allocation of latents across samples. Proposed in BatchTopK Sparse Autoencoders.
  • batchlayertopk: (For CrossLayerTranscoder only) Layer- and batch-aware extension of BatchTopK, retaining the top k * batch_size * n_layers activations per batch and layer and zeroing out the rest.
  • layertopk: (For CrossLayerTranscoder only) Layer-aware extension of BatchTopK, retaining the top k * n_layers activations per layer and zeroing out the rest. Note that this activation function does not take the batch dimension into account.
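A minimal PyTorch sketch of the TopK and BatchTopK variants described above, for pre-activations whose last dimension is d_sae. The function names are hypothetical, and the sketch assumes the kept values pass through unchanged; the library's own implementations may differ (for example in how they treat negative values).

```python
import torch


def topk_activation(x: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest pre-activations of each sample (last dim), zero the rest."""
    values, indices = x.topk(k, dim=-1)
    return torch.zeros_like(x).scatter(-1, indices, values)


def batchtopk_activation(x: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k * n_samples largest pre-activations over the whole batch."""
    flat = x.reshape(-1)
    n_samples = flat.numel() // x.shape[-1]  # batch_size (or batch * seq_len)
    values, indices = flat.topk(k * n_samples)
    return torch.zeros_like(flat).scatter(0, indices, values).view_as(x)


x = torch.tensor([[3.0, 1.0, 2.0, 0.5], [0.1, 0.2, 0.3, 4.0]])
topk_out = topk_activation(x, k=2)
batch_out = batchtopk_activation(x, k=2)
```

Here topk keeps exactly two latents per sample, while batchtopk keeps four latents in total and gives three to the first sample and one to the second.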

norm_activation pydantic-field

norm_activation: Literal[
    "token-wise", "batch-wise", "dataset-wise", "inference"
] = "dataset-wise"

The activation normalization strategy to use for the input/label activations. When normalize_activations is called (the Trainer calls it during training), the input/label activations are normalized to an average norm of $\sqrt{d_\text{model}}$. This makes it easier to transfer hyperparameters (mostly the learning rate) between models with different activation scales, since without normalization the MSE loss is proportional to the square of the activation norm.

The strategies differ in the view over which the norm is averaged:

  • token-wise: The norm is computed directly for each token's activation. No averaging is performed.
  • batch-wise: The norm is computed for each activation in the batch, then averaged over the batch dimension.
  • dataset-wise: The average norm is computed from a number of samples of the activation source. Compared to batch-wise, dataset-wise uses one fixed average norm for all activations, preserving the linearity of pre-activation encoding and decoding.
  • inference: No normalization is performed. Inference mode is entered by calling the standardize_parameters_of_dataset_norm method, which folds the dataset-wise average norm into the weights and biases of the model. Switching to inference mode does not change encoding and decoding as a whole: the reconstructed activations remain the same as the denormalized reconstructed activations in dataset-wise mode. However, the feature activations now reflect the real activation scale, so their true magnitudes are visible during inference.
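The sketch below shows how the scaling factor differs between the strategies. `norm_scale` is a hypothetical helper, not part of lm_saes, and the dataset-wise average norm is assumed to have been estimated beforehand from dataset samples.

```python
import math
from typing import Optional

import torch


def norm_scale(x: torch.Tensor, strategy: str, dataset_avg_norm: Optional[float] = None) -> torch.Tensor:
    """Factor that rescales x (last dim = d_model) toward an average norm of sqrt(d_model)."""
    target = math.sqrt(x.shape[-1])
    if strategy == "token-wise":
        # One factor per token: every token ends up with norm exactly sqrt(d_model).
        return target / x.norm(dim=-1, keepdim=True)
    if strategy == "batch-wise":
        # One factor for the whole batch, from the mean per-token norm.
        return target / x.norm(dim=-1).mean()
    if strategy == "dataset-wise":
        # One fixed factor, from an average norm estimated beforehand.
        if dataset_avg_norm is None:
            raise ValueError("dataset-wise normalization needs dataset_avg_norm")
        return torch.tensor(target / dataset_avg_norm)
    if strategy == "inference":
        # The normalization has been folded into the weights: leave x unchanged.
        return torch.tensor(1.0)
    raise ValueError(f"unknown strategy: {strategy}")


x = torch.tensor([[3.0, 4.0], [6.0, 8.0]])  # per-token norms 5 and 10
token_scaled = x * norm_scale(x, "token-wise")
batch_scaled = x * norm_scale(x, "batch-wise")
```

With token-wise scaling each row has norm sqrt(2) on its own; with batch-wise scaling only the mean norm is sqrt(2), so the relative scale of the two tokens is preserved.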

sparsity_include_decoder_norm pydantic-field

sparsity_include_decoder_norm: bool = True

Whether to include the decoder norm in feature activation gating. If True, the pre-activation hidden states are scaled by the decoder norm before the activation function is applied, and scaled back afterwards. Formally, for an activation function $f(x)$, define the gating function $g(x) = f(x) / x$ (element-wise division). When sparsity_include_decoder_norm is True, $f(x)$ is replaced with $x \cdot g(x \cdot \|W_\text{dec}\|)$, which equals $f(x \cdot \|W_\text{dec}\|) / \|W_\text{dec}\|$. This suppresses a training dynamic in which the model increases the decoder norm in exchange for smaller feature activation magnitudes, which would otherwise lower the sparsity loss (L1 norm).
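A small numeric check of this gating, using JumpReLU with a hypothetical threshold θ = 1 as $f$. It mirrors the scale, activate, and rescale steps of SparseAutoEncoder.encode on plain tensors and confirms that the result equals $x \cdot g(x \cdot \|W_\text{dec}\|)$.

```python
import torch

THETA = 1.0  # hypothetical JumpReLU threshold, for illustration only


def jumprelu(z: torch.Tensor) -> torch.Tensor:
    return z * (z > THETA).to(z.dtype)


def gate(z: torch.Tensor) -> torch.Tensor:
    # g(z) = f(z) / z: 1 where the unit fires, 0 where it is gated off (z != 0 assumed).
    return jumprelu(z) / z


def encode_with_decoder_norm(x: torch.Tensor, decoder_norm: torch.Tensor) -> torch.Tensor:
    # Scale by ||W_dec||, apply f, then scale back.
    return jumprelu(x * decoder_norm) / decoder_norm


x = torch.tensor([0.8, 0.8])
decoder_norm = torch.tensor([1.0, 2.0])
acts = encode_with_decoder_norm(x, decoder_norm)
same_as_gate_form = torch.allclose(acts, x * gate(x * decoder_norm))
```

Both features have the same pre-activation 0.8, but only the second fires, because its larger decoder norm pushes $x \cdot \|W_\text{dec}\|$ past θ.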

top_k pydantic-field

top_k: int = 50

The k value to use for the topk family of activation functions. For vanilla TopK, the L0 norm of the feature activations will be exactly equal to top_k.

use_triton_kernel pydantic-field

use_triton_kernel: bool = False

Whether to use the Triton SpMM kernel for the sparse matrix multiplication. Currently only supported for vanilla SAE.

sparsity_threshold_for_triton_spmm_kernel pydantic-field

sparsity_threshold_for_triton_spmm_kernel: float = 0.996

The sparsity threshold for the Triton SpMM kernel. The kernel is used for the sparse matrix multiplication only once the feature activation sparsity reaches this threshold. This is useful for JumpReLU, or for TopK with a k annealing schedule, where sparsity is not guaranteed throughout training.
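The dispatch idea can be sketched as follows. `decode_auto` is hypothetical, and `torch.sparse.mm` stands in for the Triton SpMM kernel, whose interface is not shown here.

```python
import torch


def decode_auto(feature_acts: torch.Tensor, W_D: torch.Tensor, threshold: float = 0.996):
    """Decode (n, d_sae) activations, switching to a sparse matmul once sparse enough.

    Returns the reconstruction and whether the sparse path was taken.
    """
    # Fraction of exactly-zero feature activations.
    sparsity = (feature_acts == 0).float().mean().item()
    if sparsity >= threshold:
        # torch.sparse.mm stands in for the Triton SpMM kernel in this sketch.
        return torch.sparse.mm(feature_acts.to_sparse(), W_D), True
    return feature_acts @ W_D, False


torch.manual_seed(0)
W_D = torch.randn(1000, 8)
sparse_acts = torch.zeros(4, 1000)
sparse_acts[torch.arange(4), torch.tensor([1, 10, 100, 999])] = 1.0  # sparsity 0.999
dense_acts = torch.rand(4, 1000) + 0.1  # sparsity 0.0

out_sparse, used_sparse = decode_auto(sparse_acts, W_D)
out_dense, used_dense = decode_auto(dense_acts, W_D)
```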

jumprelu_threshold_window pydantic-field

jumprelu_threshold_window: float = 2.0

The window size for the JumpReLU threshold. The threshold receives a gradient only from pre-activations that lie, element-wise, within this window around the threshold. See Anthropic's Circuits Update - January 2025 for more details on how JumpReLU is optimized (where this window is referred to as $\epsilon$).
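A sketch of how a threshold can receive gradient only inside the window, using a straight-through estimator with a rectangle kernel as in the JumpReLU paper: $\partial f / \partial \theta = -(\theta / \epsilon)\, K((x - \theta)/\epsilon)$, where $K$ is 1 on $(-1/2, 1/2)$ and 0 elsewhere. Treating jumprelu_threshold_window as $\epsilon$, and this exact pseudo-derivative, are assumptions; lm_saes may use a different form.

```python
import torch


class JumpReLUWithSTE(torch.autograd.Function):
    """JumpReLU whose threshold gets a rectangle-kernel pseudo-gradient (sketch)."""

    @staticmethod
    def forward(ctx, x, threshold, window: float):
        ctx.save_for_backward(x, threshold)
        ctx.window = window
        return x * (x > threshold).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_out):
        x, threshold = ctx.saved_tensors
        window = ctx.window
        # d/dx: the input gradient flows only through active units.
        grad_x = grad_out * (x > threshold).to(x.dtype)
        # d/dtheta: nonzero only for pre-activations within window / 2 of theta.
        in_window = ((x - threshold).abs() < window / 2).to(x.dtype)
        grad_theta = -(threshold / window) * in_window * grad_out
        # Sum over all leading (batch/seq) dims so the shape matches threshold.
        grad_theta = grad_theta.reshape(-1, x.shape[-1]).sum(dim=0)
        return grad_x, grad_theta, None  # no gradient for the window size


x = torch.tensor([[0.5, 1.5, 3.0]], requires_grad=True)
theta = torch.ones(3, requires_grad=True)
y = JumpReLUWithSTE.apply(x, theta, 2.0)
y.sum().backward()
```

The third pre-activation (3.0) is active but lies outside the window, so it passes gradient to x but not to its threshold.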

d_sae property

d_sae: int

The hidden dimension of the sparse dictionary. Calculated as d_model * expansion_factor.

associated_hook_points abstractmethod property

associated_hook_points: list[str]

List of hook points used by the SAE, including all input and label hook points. It determines which activations are retrieved from the activation source.

from_pretrained classmethod

from_pretrained(pretrained_name_or_path: str, **kwargs)

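Load the config of the sparse dictionary from a pretrained name or path. The config is read from <pretrained_name_or_path>/config.json (for local storage) or <repo_id>/<name>/config.json (for HuggingFace Hub, where pretrained_name_or_path has the form repo_id:name).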

Parameters:

  • pretrained_name_or_path (str, required): The path to the pretrained sparse dictionary.
  • **kwargs (default: {}): Additional keyword arguments to pass to the config constructor.

Source code in src/lm_saes/abstract_sae.py
@classmethod
def from_pretrained(cls, pretrained_name_or_path: str, **kwargs):
    """Load the config of the sparse dictionary from a pretrained name or path. Config is read from <pretrained_name_or_path>/config.json (for local storage) or <repo_id>/<name>/config.json (for HuggingFace Hub).

    Args:
        pretrained_name_or_path (str): The path to the pretrained sparse dictionary.
        **kwargs: Additional keyword arguments to pass to the config constructor.
    """
    sae_type = auto_infer_pretrained_sae_type(pretrained_name_or_path)
    if sae_type == PretrainedSAEType.LOCAL:
        path = os.path.join(pretrained_name_or_path, "config.json")
    elif sae_type == PretrainedSAEType.HUGGINGFACE:
        repo_id, name = pretrained_name_or_path.split(":")
        # hf_hub_download already returns the local path of the downloaded file.
        path = hf_hub_download(repo_id=repo_id, filename=f"{name}/config.json")
    elif sae_type == PretrainedSAEType.SAELENS:
        raise ValueError(
            "Generating a config directly from SAELens is not currently supported. Try converting the whole model from SAELens using the `from_saelens` or `from_pretrained` method instead."
        )
    else:
        raise ValueError(f"Unsupported pretrained type: {sae_type}")

    with open(path, "r") as f:
        sae_config = json.load(f)

    if cls is BaseSAEConfig:
        cls = SAE_TYPE_TO_CONFIG_CLASS[sae_config["sae_type"]]

    return cls.model_validate({**sae_config, **kwargs})
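The local-path branch of this loading and dispatch logic can be sketched with the standard library. `ToySAEConfig`, `CONFIG_REGISTRY`, and `load_config` are hypothetical stand-ins for the registered config classes, SAE_TYPE_TO_CONFIG_CLASS, and the classmethod, and a dataclass replaces pydantic's model_validate. The point illustrated is that keyword arguments override the values stored in config.json.

```python
import json
import os
import tempfile
from dataclasses import dataclass


@dataclass
class ToySAEConfig:  # hypothetical stand-in for a registered config class
    sae_type: str
    d_model: int
    expansion_factor: float
    top_k: int = 50


CONFIG_REGISTRY = {"sae": ToySAEConfig}  # hypothetical registry keyed by sae_type


def load_config(path: str, **kwargs):
    with open(os.path.join(path, "config.json")) as f:
        stored = json.load(f)
    # Pick the concrete config class from the stored sae_type.
    cls = CONFIG_REGISTRY[stored["sae_type"]]
    # Keyword arguments take precedence over values from config.json.
    return cls(**{**stored, **kwargs})


with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "config.json"), "w") as f:
        json.dump({"sae_type": "sae", "d_model": 768, "expansion_factor": 8, "top_k": 50}, f)
    cfg = load_config(tmp, top_k=32)
```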

SAEConfig pydantic-model

Bases: BaseSAEConfig

Fields:

SparseAutoEncoder

SparseAutoEncoder(
    cfg: SAEConfig, device_mesh: DeviceMesh | None = None
)

Bases: AbstractSparseAutoEncoder

Source code in src/lm_saes/sae.py
def __init__(self, cfg: SAEConfig, device_mesh: DeviceMesh | None = None):
    super(SparseAutoEncoder, self).__init__(cfg, device_mesh=device_mesh)
    self.cfg = cfg

    if device_mesh is None:
        self.W_E = nn.Parameter(torch.empty(cfg.d_model, cfg.d_sae, device=cfg.device, dtype=cfg.dtype))
        self.b_E = nn.Parameter(torch.empty(cfg.d_sae, device=cfg.device, dtype=cfg.dtype))
        self.W_D = nn.Parameter(torch.empty(cfg.d_sae, cfg.d_model, device=cfg.device, dtype=cfg.dtype))
        if cfg.use_decoder_bias:
            self.b_D = nn.Parameter(torch.empty(cfg.d_model, device=cfg.device, dtype=cfg.dtype))

        if cfg.use_glu_encoder:
            self.W_E_glu = nn.Parameter(torch.empty(cfg.d_model, cfg.d_sae, device=cfg.device, dtype=cfg.dtype))
            self.b_E_glu = nn.Parameter(torch.empty(cfg.d_sae, device=cfg.device, dtype=cfg.dtype))
    else:
        self.W_E = nn.Parameter(
            torch.distributed.tensor.empty(
                cfg.d_model,
                cfg.d_sae,
                dtype=cfg.dtype,
                device_mesh=device_mesh,
                placements=self.dim_maps()["W_E"].placements(device_mesh),
            )
        )
        self.b_E = nn.Parameter(
            torch.distributed.tensor.empty(
                cfg.d_sae,
                dtype=cfg.dtype,
                device_mesh=device_mesh,
                placements=self.dim_maps()["b_E"].placements(device_mesh),
            )
        )
        self.W_D = nn.Parameter(
            torch.distributed.tensor.empty(
                cfg.d_sae,
                cfg.d_model,
                dtype=cfg.dtype,
                device_mesh=device_mesh,
                placements=self.dim_maps()["W_D"].placements(device_mesh),
            )
        )
        if cfg.use_decoder_bias:
            self.b_D = nn.Parameter(
                torch.distributed.tensor.empty(
                    cfg.d_model,
                    dtype=cfg.dtype,
                    device_mesh=device_mesh,
                    placements=self.dim_maps()["b_D"].placements(device_mesh),
                )
            )
        if cfg.use_glu_encoder:
            self.W_E_glu = nn.Parameter(
                torch.distributed.tensor.empty(
                    cfg.d_model,
                    cfg.d_sae,
                    dtype=cfg.dtype,
                    device_mesh=device_mesh,
                    placements=self.dim_maps()["W_E_glu"].placements(device_mesh),
                )
            )
            self.b_E_glu = nn.Parameter(
                torch.distributed.tensor.empty(
                    cfg.d_sae,
                    dtype=cfg.dtype,
                    device_mesh=device_mesh,
                    placements=self.dim_maps()["b_E_glu"].placements(device_mesh),
                )
            )

    self.hook_hidden_pre = HookPoint()
    self.hook_feature_acts = HookPoint()
    self.hook_reconstructed = HookPoint()

encoder_norm

encoder_norm(keepdim: bool = False)

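Compute the L2 norm of each feature's encoder vector, i.e. of each column of W_E. The result has shape (d_sae,), or (1, d_sae) if keepdim is True.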

Source code in src/lm_saes/sae.py
@override
def encoder_norm(self, keepdim: bool = False):
    """Compute the norm of the encoder weight."""
    if not isinstance(self.W_E, DTensor):
        return torch.norm(self.W_E, p=2, dim=0, keepdim=keepdim).to(self.cfg.device)
    else:
        assert self.device_mesh is not None
        return DTensor.from_local(
            torch.norm(self.W_E.to_local(), p=2, dim=0, keepdim=keepdim),
            device_mesh=self.device_mesh,
            placements=DimMap({"model": 1 if keepdim else 0}).placements(self.device_mesh),
        )

decoder_norm

decoder_norm(keepdim: bool = False) -> Tensor

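Compute the L2 norm of each feature's decoder vector, i.e. of each row of W_D. The result has shape (d_sae,), or (d_sae, 1) if keepdim is True.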

Source code in src/lm_saes/sae.py
@override
def decoder_norm(self, keepdim: bool = False) -> torch.Tensor:
    """Compute the norm of the decoder weight."""
    if not isinstance(self.W_D, DTensor):
        return torch.norm(self.W_D, p=2, dim=1, keepdim=keepdim).to(self.cfg.device)
    else:
        assert self.device_mesh is not None
        return DTensor.from_local(
            torch.norm(self.W_D.to_local(), p=2, dim=1, keepdim=keepdim),
            device_mesh=self.device_mesh,
            placements=DimMap({"model": 0}).placements(self.device_mesh),
        )
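A shape check for the two norms on plain (non-distributed) tensors with hypothetical sizes. keepdim=True is what lets the norm broadcast against the weight for rescaling.

```python
import torch

torch.manual_seed(0)
d_model, d_sae = 4, 6
W_E = torch.randn(d_model, d_sae)  # one column per feature
W_D = torch.randn(d_sae, d_model)  # one row per feature

enc_norm = torch.norm(W_E, p=2, dim=0)  # (d_sae,)
dec_norm = torch.norm(W_D, p=2, dim=1)  # (d_sae,)
dec_norm_kd = torch.norm(W_D, p=2, dim=1, keepdim=True)  # (d_sae, 1), broadcasts over rows

W_D_unit = W_D / dec_norm_kd  # every decoder row now has unit norm
```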

set_decoder_to_fixed_norm

set_decoder_to_fixed_norm(value: float, force_exact: bool)

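Set the decoder weights to a fixed norm. If force_exact is True, every row of W_D (one per feature) is rescaled to have norm value. Otherwise, only rows whose norm exceeds value are scaled down to value, and the remaining rows are left unchanged.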

Source code in src/lm_saes/sae.py
@override
@torch.no_grad()
def set_decoder_to_fixed_norm(self, value: float, force_exact: bool):
    """Set the decoder weights to a fixed norm."""
    if force_exact:
        self.W_D.mul_(value / self.decoder_norm(keepdim=True))
    else:
        self.W_D.mul_(value / torch.clamp(self.decoder_norm(keepdim=True), min=value))
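A standalone version of the same logic on a plain tensor (a sketch, not the library method) makes the effect of force_exact visible.

```python
import torch


@torch.no_grad()
def set_decoder_to_fixed_norm(W_D: torch.Tensor, value: float, force_exact: bool) -> None:
    norm = torch.norm(W_D, p=2, dim=1, keepdim=True)  # (d_sae, 1)
    if force_exact:
        W_D.mul_(value / norm)  # every row ends up with norm == value
    else:
        # Divisor is max(norm, value): rows below value get factor 1 and are untouched.
        W_D.mul_(value / torch.clamp(norm, min=value))


W = torch.tensor([[3.0, 4.0], [0.3, 0.4]])  # row norms 5.0 and 0.5
capped = W.clone()
set_decoder_to_fixed_norm(capped, 1.0, force_exact=False)
exact = W.clone()
set_decoder_to_fixed_norm(exact, 1.0, force_exact=True)
```

With force_exact=False the norm acts as a cap (rows end with norms 1.0 and 0.5); with force_exact=True both rows end with norm 1.0.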

set_encoder_to_fixed_norm

set_encoder_to_fixed_norm(value: float)

Set the encoder weights to a fixed norm: every column of W_E (one per feature) is rescaled to have norm value.

Source code in src/lm_saes/sae.py
@torch.no_grad()
def set_encoder_to_fixed_norm(self, value: float):
    """Set the encoder weights to a fixed norm."""
    self.W_E.mul_(value / self.encoder_norm(keepdim=True))

dim_maps

dim_maps() -> dict[str, DimMap]

Return a dictionary mapping parameter names to dimension maps.

Returns:

  • dict[str, DimMap]: A dictionary mapping parameter names to DimMap objects.

Source code in src/lm_saes/sae.py
def dim_maps(self) -> dict[str, DimMap]:
    """Return a dictionary mapping parameter names to dimension maps.

    Returns:
        A dictionary mapping parameter names to DimMap objects.
    """
    parent_maps = super().dim_maps()
    sae_maps = {
        "W_E": DimMap({"model": 1}),
        "W_D": DimMap({"model": 0}),
        "b_E": DimMap({"model": 0}),
    }
    if self.cfg.use_decoder_bias:
        sae_maps["b_D"] = DimMap({})
    if self.cfg.use_glu_encoder:
        sae_maps["W_E_glu"] = DimMap({"model": 1})
        sae_maps["b_E_glu"] = DimMap({"model": 0})
    return parent_maps | sae_maps
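These DimMaps shard W_E along dim 1 and W_D and b_E along dim 0 over the "model" mesh dimension, which in every case is the d_sae (feature) axis, while b_D is replicated. The sketch below simulates this with torch.chunk instead of DTensor, assuming an element-wise activation such as ReLU: each shard computes a partial reconstruction from its own features, and summing the partials (an all-reduce in a real deployment) recovers the unsharded output.

```python
import torch

torch.manual_seed(0)
d_model, d_sae, n_shards = 8, 12, 3
x = torch.randn(5, d_model)
W_E = torch.randn(d_model, d_sae)
b_E = torch.randn(d_sae)
W_D = torch.randn(d_sae, d_model)
b_D = torch.randn(d_model)

# Unsharded forward pass (ReLU activation).
full = torch.relu(x @ W_E + b_E) @ W_D + b_D

# Shard the feature axis: W_E on dim 1, b_E and W_D on dim 0; b_D stays replicated.
partials = [
    torch.relu(x @ W_E_s + b_E_s) @ W_D_s
    for W_E_s, b_E_s, W_D_s in zip(
        W_E.chunk(n_shards, dim=1), b_E.chunk(n_shards, dim=0), W_D.chunk(n_shards, dim=0)
    )
]
# Summing the partials plays the role of an all-reduce; the bias is added once.
sharded = torch.stack(partials).sum(dim=0) + b_D
```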

standardize_parameters_of_dataset_norm

standardize_parameters_of_dataset_norm()

Standardize the parameters of the model to account for dataset_norm during inference, so that the model can run on raw, unnormalized activations. This method should be called by the Initializer when preparing the model for inference.

During training, the input x is normalized to an average norm of sqrt(d_model). During inference, however, the input x has the dataset's average norm (dataset_norm). To keep training and inference consistent, activations during inference are scaled by the factor:

scaled_activation = training_activation * (dataset_norm / sqrt(d_model))

Parameters:

  • None. The method reads self.dataset_average_activation_norm (dict[str, float]), which maps the input and output hook point names to the dataset's average activation norm, e.g. { self.cfg.hook_point_in: 1.0, self.cfg.hook_point_out: 1.0 }. This attribute must be set, and norm_activation must be "dataset-wise".

Returns:

  • None: Updates the parameters in place (b_E, b_D, and W_D are rescaled; W_E is unchanged) and switches norm_activation to "inference" mode.

Source code in src/lm_saes/sae.py
@torch.no_grad()
def standardize_parameters_of_dataset_norm(self):  # should be overridden by subclasses due to side effects
    """
    Standardize the parameters of the model to account for dataset_norm during inference.
    This function should be called during inference by the Initializer.

    During training, the activations correspond to an input `x` where the norm is sqrt(d_model).
    However, during inference, the norm of the input `x` corresponds to the dataset_norm.
    To ensure consistency between training and inference, the activations during inference
    are scaled by the factor:

        scaled_activation = training_activation * (dataset_norm / sqrt(d_model))

    Args:
        dataset_average_activation_norm (dict[str, float]):
            A dictionary where keys represent in or out and values
            specify the average activation norm of the dataset during inference.

            dataset_average_activation_norm = {
                self.cfg.hook_point_in: 1.0,
                self.cfg.hook_point_out: 1.0,
            }

    Returns:
        None: Updates the internal parameters to reflect the standardized activations and change the norm_activation to "inference" mode.
    """
    assert self.cfg.norm_activation == "dataset-wise"
    assert self.dataset_average_activation_norm is not None
    input_norm_factor: float = (
        math.sqrt(self.cfg.d_model) / self.dataset_average_activation_norm[self.cfg.hook_point_in]
    )
    output_norm_factor: float = (
        math.sqrt(self.cfg.d_model) / self.dataset_average_activation_norm[self.cfg.hook_point_out]
    )
    self.b_E.div_(input_norm_factor)
    if self.cfg.use_decoder_bias:
        assert self.b_D is not None, "Decoder bias should exist if use_decoder_bias is True"
        self.b_D.div_(output_norm_factor)
    self.W_D.mul_(input_norm_factor / output_norm_factor)
    self.cfg.norm_activation = "inference"
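A numeric check of the folding, assuming a ReLU activation without decoder-norm gating (thresholds of other activation functions are not handled here); the average norms are hypothetical. Because ReLU is positively homogeneous, dividing b_E by the input factor divides the feature activations by that factor, and the rescaling of W_D and b_D compensates exactly.

```python
import math

import torch

torch.manual_seed(0)
d_model, d_sae = 16, 64
W_E = torch.randn(d_model, d_sae, dtype=torch.float64)
b_E = torch.randn(d_sae, dtype=torch.float64)
W_D = torch.randn(d_sae, d_model, dtype=torch.float64)
b_D = torch.randn(d_model, dtype=torch.float64)
norm_in, norm_out = 37.0, 12.0  # hypothetical dataset average activation norms

x = torch.randn(10, d_model, dtype=torch.float64)
a = math.sqrt(d_model) / norm_in   # input_norm_factor
c = math.sqrt(d_model) / norm_out  # output_norm_factor

# dataset-wise mode: normalize input, encode/decode, then denormalize the output.
acts_train = torch.relu((x * a) @ W_E + b_E)
recon_train = (acts_train @ W_D + b_D) / c

# Fold the factors into the parameters, as standardize_parameters_of_dataset_norm does.
b_E_inf = b_E / a
W_D_inf = W_D * (a / c)
b_D_inf = b_D / c

# inference mode: raw input in, raw-scale output out, no normalization step.
acts_inf = torch.relu(x @ W_E + b_E_inf)
recon_inf = acts_inf @ W_D_inf + b_D_inf
```

The reconstructions match, while the inference-mode feature activations equal the training-mode ones divided by the input factor, i.e. they carry the real activation scale.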

encode

encode(
    x: Union[
        Float[Tensor, "batch d_model"],
        Float[Tensor, "batch seq_len d_model"],
    ],
    return_hidden_pre: Literal[False] = False,
    **kwargs,
) -> Union[
    Float[Tensor, "batch d_sae"],
    Float[Tensor, "batch seq_len d_sae"],
]
encode(
    x: Union[
        Float[Tensor, "batch d_model"],
        Float[Tensor, "batch seq_len d_model"],
    ],
    return_hidden_pre: Literal[True],
    **kwargs,
) -> tuple[
    Union[
        Float[Tensor, "batch d_sae"],
        Float[Tensor, "batch seq_len d_sae"],
    ],
    Union[
        Float[Tensor, "batch d_sae"],
        Float[Tensor, "batch seq_len d_sae"],
    ],
]
encode(
    x: Union[
        Float[Tensor, "batch d_model"],
        Float[Tensor, "batch seq_len d_model"],
    ],
    return_hidden_pre: bool = False,
    **kwargs,
) -> Union[
    Float[Tensor, "batch d_sae"],
    Float[Tensor, "batch seq_len d_sae"],
    tuple[
        Union[
            Float[Tensor, "batch d_sae"],
            Float[Tensor, "batch seq_len d_sae"],
        ],
        Union[
            Float[Tensor, "batch d_sae"],
            Float[Tensor, "batch seq_len d_sae"],
        ],
    ],
]

Encode input tensor through the sparse autoencoder.

Parameters:

  • x (Float[Tensor, 'batch d_model'] or Float[Tensor, 'batch seq_len d_model'], required): Input tensor of shape (batch, d_model) or (batch, seq_len, d_model).
  • return_hidden_pre (bool, default: False): If True, also return the pre-activation hidden states.

Returns:

  • If return_hidden_pre is False: feature activations of shape (batch, d_sae) or (batch, seq_len, d_sae).
  • If return_hidden_pre is True: a tuple (feature_acts, hidden_pre), both of shape (batch, d_sae) or (batch, seq_len, d_sae).

Source code in src/lm_saes/sae.py
def encode(
    self,
    x: Union[
        Float[torch.Tensor, "batch d_model"],
        Float[torch.Tensor, "batch seq_len d_model"],
    ],
    return_hidden_pre: bool = False,
    **kwargs,
) -> Union[
    Float[torch.Tensor, "batch d_sae"],
    Float[torch.Tensor, "batch seq_len d_sae"],
    tuple[
        Union[
            Float[torch.Tensor, "batch d_sae"],
            Float[torch.Tensor, "batch seq_len d_sae"],
        ],
        Union[
            Float[torch.Tensor, "batch d_sae"],
            Float[torch.Tensor, "batch seq_len d_sae"],
        ],
    ],
]:
    """Encode input tensor through the sparse autoencoder.

    Args:
        x: Input tensor of shape (batch, d_model) or (batch, seq_len, d_model)
        return_hidden_pre: If True, also return the pre-activation hidden states

    Returns:
        If return_hidden_pre is False:
            Feature activations tensor of shape (batch, d_sae) or (batch, seq_len, d_sae)
        If return_hidden_pre is True:
            Tuple of (feature_acts, hidden_pre) where both have shape (batch, d_sae) or (batch, seq_len, d_sae)
    """
    # Pass through encoder
    hidden_pre = x @ self.W_E + self.b_E

    # Apply GLU if configured
    if self.cfg.use_glu_encoder:
        hidden_pre_glu = torch.sigmoid(x @ self.W_E_glu + self.b_E_glu)
        hidden_pre = hidden_pre * hidden_pre_glu

    hidden_pre = self.hook_hidden_pre(hidden_pre)

    # Scale feature activations by decoder norm if configured
    if self.cfg.sparsity_include_decoder_norm:
        hidden_pre = hidden_pre * self.decoder_norm()

    feature_acts = self.activation_function(hidden_pre)
    feature_acts = self.hook_feature_acts(feature_acts)

    if self.cfg.sparsity_include_decoder_norm:
        feature_acts = feature_acts / self.decoder_norm()
        hidden_pre = hidden_pre / self.decoder_norm()

    if return_hidden_pre:
        return feature_acts, hidden_pre
    return feature_acts
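A plain-tensor sketch of the same encoding path, with ReLU standing in for activation_function and the hook points omitted. The standalone `encode` function and its argument names are hypothetical.

```python
from typing import Optional

import torch


def encode(
    x: torch.Tensor,
    W_E: torch.Tensor,
    b_E: torch.Tensor,
    W_D: torch.Tensor,
    W_E_glu: Optional[torch.Tensor] = None,
    b_E_glu: Optional[torch.Tensor] = None,
    include_decoder_norm: bool = True,
):
    # Matmul over the last dim, so (batch, d_model) and (batch, seq_len, d_model) both work.
    hidden_pre = x @ W_E + b_E
    if W_E_glu is not None and b_E_glu is not None:
        # GLU-style gate in (0, 1) that scales each pre-activation.
        hidden_pre = hidden_pre * torch.sigmoid(x @ W_E_glu + b_E_glu)
    if include_decoder_norm:
        decoder_norm = W_D.norm(p=2, dim=1)  # (d_sae,), broadcasts over leading dims
        feature_acts = torch.relu(hidden_pre * decoder_norm) / decoder_norm
    else:
        feature_acts = torch.relu(hidden_pre)
    return feature_acts, hidden_pre


torch.manual_seed(0)
d_model, d_sae = 8, 32
params = dict(
    W_E=torch.randn(d_model, d_sae),
    b_E=torch.zeros(d_sae),
    W_D=torch.randn(d_sae, d_model),
    W_E_glu=torch.randn(d_model, d_sae),
    b_E_glu=torch.zeros(d_sae),
)
x = torch.randn(2, 5, d_model)  # (batch, seq_len, d_model)
feature_acts, hidden_pre = encode(x, **params)
```

For ReLU the decoder-norm scaling cancels out, since ReLU is positively homogeneous; it only changes the result for activations with a threshold, such as JumpReLU.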

decode_coo

decode_coo(
    feature_acts: Float[Tensor, "seq_len d_sae"],
) -> Float[Tensor, "seq_len d_model"]

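Decode feature activations given in sparse COO format back to the model space. The matrix multiplication is performed in float32, and the result is cast back to cfg.dtype.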

Source code in src/lm_saes/sae.py
def decode_coo(
    self,
    feature_acts: Float[torch.sparse.Tensor, "seq_len d_sae"],
) -> Float[torch.Tensor, "seq_len d_model"]:
    """Decode feature activations back to model space using COO format."""
    reconstructed = feature_acts.to(torch.float32) @ self.W_D.to(torch.float32)
    if self.cfg.use_decoder_bias:
        reconstructed = reconstructed + self.b_D
    return reconstructed.to(self.cfg.dtype)
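A sketch of decoding COO-format activations on plain tensors. It uses `torch.sparse.mm` for the sparse-dense product (equivalent to `@` for a 2-D COO tensor); the sizes and the bfloat16 dtype are assumptions for illustration.

```python
import torch

torch.manual_seed(0)
d_sae, d_model = 32, 8
W_D = torch.randn(d_sae, d_model, dtype=torch.bfloat16)
b_D = torch.randn(d_model, dtype=torch.bfloat16)

dense_acts = torch.zeros(3, d_sae)
dense_acts[0, 4] = 1.5
dense_acts[1, 20] = 0.7
dense_acts[2, 31] = 2.0
coo_acts = dense_acts.to_sparse()  # COO layout: only the 3 nonzeros are stored

# Sparse-dense matmul in float32, add the bias, then cast back to the model dtype.
reconstructed = torch.sparse.mm(coo_acts, W_D.to(torch.float32)) + b_D
reconstructed = reconstructed.to(torch.bfloat16)
```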

init_W_D_with_active_subspace

init_W_D_with_active_subspace(
    batch: dict[str, Tensor], d_active_subspace: int
)

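Initialize W_D with the active subspace, i.e. so that every decoder row lies in the span of the top d_active_subspace principal directions of the label activations (computed by an SVD of the demeaned labels).

Parameters:

  • batch (dict[str, Tensor], required): The batch from which the label activations are taken.
  • d_active_subspace (int, required): The dimension of the active subspace.
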
Source code in src/lm_saes/sae.py
@override
@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def init_W_D_with_active_subspace(self, batch: dict[str, torch.Tensor], d_active_subspace: int):
    """Initialize W_D with the active subspace.

    Args:
        batch: The batch.
        d_active_subspace: The dimension of the active subspace.
    """
    label = self.prepare_label(batch)
    if self.device_mesh is not None:
        assert isinstance(label, DTensor)
        label = label.to_local()
        torch.distributed.broadcast(tensor=label, group=self.device_mesh.get_group("data"), group_src=0)
    demeaned_label = label - label.mean(dim=0)
    U, S, V = torch.svd(demeaned_label.T.to(torch.float32))
    proj_weight = U[:, :d_active_subspace]  # [d_model, d_active_subspace]
    self.W_D.copy_(self.W_D.data[:, :d_active_subspace] @ proj_weight.T.to(self.cfg.dtype))
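The projection can be reproduced on synthetic data. The sketch uses torch.linalg.svd, the non-deprecated counterpart of torch.svd (both return the same U), and hypothetical label activations concentrated in a 4-dimensional subspace. It checks that every new decoder row lies in the span of the top principal directions.

```python
import torch

torch.manual_seed(0)
n, d_model, d_sae, k = 256, 16, 40, 4
# Hypothetical labels: a 4-dim subspace plus small isotropic noise.
basis = torch.linalg.qr(torch.randn(d_model, k)).Q  # (d_model, k), orthonormal columns
label = torch.randn(n, k) @ basis.T * 5 + 0.01 * torch.randn(n, d_model)

demeaned = label - label.mean(dim=0)
U, S, Vh = torch.linalg.svd(demeaned.T, full_matrices=False)  # U: (d_model, d_model)
proj_weight = U[:, :k]  # top-k principal directions, (d_model, k)

W_D = torch.randn(d_sae, d_model)
W_D_new = W_D[:, :k] @ proj_weight.T  # (d_sae, d_model), rows inside the subspace

P = proj_weight @ proj_weight.T  # orthogonal projector onto the active subspace
```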

CrossCoderConfig pydantic-model

Bases: BaseSAEConfig

Fields:

CrossCoder

CrossCoder(
    cfg: CrossCoderConfig,
    device_mesh: Optional[DeviceMesh] = None,
)

Bases: AbstractSparseAutoEncoder

Sparse AutoEncoder model.

An autoencoder model that learns to compress the input activation tensor into a high-dimensional but sparse feature activation tensor.

Can also act as a transcoder model, which learns to compress the input activation tensor into a feature activation tensor, and then reconstruct a label activation tensor from the feature activation tensor.

Source code in src/lm_saes/crosscoder.py
def __init__(self, cfg: CrossCoderConfig, device_mesh: Optional[DeviceMesh] = None):
    super(CrossCoder, self).__init__(cfg, device_mesh)
    self.cfg = cfg

    # Assertions
    assert cfg.sparsity_include_decoder_norm, "Sparsity should include decoder norm in CrossCoder"
    assert cfg.use_decoder_bias, "Decoder bias should be used in CrossCoder"
    assert not cfg.use_triton_kernel, "Triton kernel is not supported in CrossCoder"

    # Initialize weights and biases
    if device_mesh is None:
        self.W_E = nn.Parameter(
            torch.empty(cfg.n_heads, cfg.d_model, cfg.d_sae, device=cfg.device, dtype=cfg.dtype)
        )
        self.b_E = nn.Parameter(torch.empty(cfg.n_heads, cfg.d_sae, device=cfg.device, dtype=cfg.dtype))
        self.W_D = nn.Parameter(
            torch.empty(cfg.n_heads, cfg.d_sae, cfg.d_model, device=cfg.device, dtype=cfg.dtype)
        )
        self.b_D = nn.Parameter(torch.empty(cfg.n_heads, cfg.d_model, device=cfg.device, dtype=cfg.dtype))
    else:
        self.W_E = nn.Parameter(
            torch.distributed.tensor.empty(
                cfg.n_heads,
                cfg.d_model,
                cfg.d_sae,
                dtype=cfg.dtype,
                device_mesh=device_mesh,
                placements=self.dim_maps()["W_E"].placements(device_mesh),
            )
        )
        self.b_E = nn.Parameter(
            torch.distributed.tensor.empty(
                cfg.n_heads,
                cfg.d_sae,
                dtype=cfg.dtype,
                device_mesh=device_mesh,
                placements=self.dim_maps()["b_E"].placements(device_mesh),
            )
        )
        self.W_D = nn.Parameter(
            torch.distributed.tensor.empty(
                cfg.n_heads,
                cfg.d_sae,
                cfg.d_model,
                dtype=cfg.dtype,
                device_mesh=device_mesh,
                placements=self.dim_maps()["W_D"].placements(device_mesh),
            )
        )
        self.b_D = nn.Parameter(
            torch.distributed.tensor.empty(
                cfg.n_heads,
                cfg.d_model,
                dtype=cfg.dtype,
                device_mesh=device_mesh,
                placements=self.dim_maps()["b_D"].placements(device_mesh),
            )
        )

specs class-attribute instance-attribute

specs: type[TensorSpecs] = CrossCoderSpecs

Tensor specs for CrossCoder with n_heads dimension.

init_parameters

init_parameters(**kwargs) -> None

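Initialize the weights of the model. A single encoder and decoder are sampled uniformly (within ±encoder_uniform_bound and ±decoder_uniform_bound, read from kwargs) and copied to every head, and all biases are set to zero.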

Source code in src/lm_saes/crosscoder.py
@torch.no_grad()
def init_parameters(self, **kwargs) -> None:
    """Initialize the weights of the model."""
    super().init_parameters(**kwargs)
    # Initialize a single head's weights
    W_E_per_head = torch.empty(
        self.cfg.d_model, self.cfg.d_sae, device=self.cfg.device, dtype=self.cfg.dtype
    ).uniform_(-kwargs["encoder_uniform_bound"], kwargs["encoder_uniform_bound"])
    W_D_per_head = torch.empty(
        self.cfg.d_sae, self.cfg.d_model, device=self.cfg.device, dtype=self.cfg.dtype
    ).uniform_(-kwargs["decoder_uniform_bound"], kwargs["decoder_uniform_bound"])

    # Repeat for all heads
    if self.device_mesh is None:
        W_E = einops.repeat(W_E_per_head, "d_model d_sae -> n_heads d_model d_sae", n_heads=self.cfg.n_heads)
        W_D = einops.repeat(W_D_per_head, "d_sae d_model -> n_heads d_sae d_model", n_heads=self.cfg.n_heads)
        b_E = torch.zeros(self.cfg.n_heads, self.cfg.d_sae, device=self.cfg.device, dtype=self.cfg.dtype)
        b_D = torch.zeros(self.cfg.n_heads, self.cfg.d_model, device=self.cfg.device, dtype=self.cfg.dtype)
    else:
        with timer.time("init_parameters_distributed"):
            W_E_slices = self.dim_maps()["W_E"].local_slices(
                (self.cfg.n_heads, self.cfg.d_model, self.cfg.d_sae), self.device_mesh
            )
            W_D_slices = self.dim_maps()["W_D"].local_slices(
                (self.cfg.n_heads, self.cfg.d_sae, self.cfg.d_model), self.device_mesh
            )
            W_E_head_repeats = get_slice_length(W_E_slices[0], self.cfg.n_heads)
            W_D_head_repeats = get_slice_length(W_D_slices[0], self.cfg.n_heads)
            W_E_local = einops.repeat(
                W_E_per_head[*W_E_slices[1:]], "d_model d_sae -> n_heads d_model d_sae", n_heads=W_E_head_repeats
            )
            W_D_local = einops.repeat(
                W_D_per_head[*W_D_slices[1:]], "d_sae d_model -> n_heads d_sae d_model", n_heads=W_D_head_repeats
            )
            W_E = DTensor.from_local(
                W_E_local, self.device_mesh, self.dim_maps()["W_E"].placements(self.device_mesh)
            )
            W_D = DTensor.from_local(
                W_D_local, self.device_mesh, self.dim_maps()["W_D"].placements(self.device_mesh)
            )
            b_E = torch.distributed.tensor.zeros(
                self.cfg.n_heads,
                self.cfg.d_sae,
                device_mesh=self.device_mesh,
                placements=self.dim_maps()["b_E"].placements(self.device_mesh),
                dtype=self.cfg.dtype,
            )
            b_D = torch.distributed.tensor.zeros(
                self.cfg.n_heads,
                self.cfg.d_model,
                device_mesh=self.device_mesh,
                placements=self.dim_maps()["b_D"].placements(self.device_mesh),
                dtype=self.cfg.dtype,
            )

    # Assign to parameters
    self.W_E.copy_(W_E)
    self.W_D.copy_(W_D)
    self.b_E.copy_(b_E)
    self.b_D.copy_(b_D)
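The head-replication step can be sketched without einops: `unsqueeze(0).repeat(n_heads, 1, 1)` produces the same tensor as the einops.repeat pattern above. The sizes and bound values are hypothetical.

```python
import torch

d_model, d_sae, n_heads = 4, 8, 3
enc_bound, dec_bound = 0.1, 0.1  # hypothetical encoder/decoder_uniform_bound values

W_E_per_head = torch.empty(d_model, d_sae).uniform_(-enc_bound, enc_bound)
W_D_per_head = torch.empty(d_sae, d_model).uniform_(-dec_bound, dec_bound)

# Equivalent to einops.repeat(W, "d_model d_sae -> n_heads d_model d_sae", n_heads=n_heads).
W_E = W_E_per_head.unsqueeze(0).repeat(n_heads, 1, 1)  # (n_heads, d_model, d_sae)
W_D = W_D_per_head.unsqueeze(0).repeat(n_heads, 1, 1)  # (n_heads, d_sae, d_model)
b_E = torch.zeros(n_heads, d_sae)
b_D = torch.zeros(n_heads, d_model)
```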

encode

encode(
    x: Union[
        Float[Tensor, "batch n_heads d_model"],
        Float[Tensor, "batch seq_len n_heads d_model"],
    ],
    return_hidden_pre: Literal[False] = False,
    *,
    no_einsum: bool = True,
    **kwargs,
) -> Union[
    Float[Tensor, "batch n_heads d_sae"],
    Float[Tensor, "batch seq_len n_heads d_sae"],
]
encode(
    x: Union[
        Float[Tensor, "batch n_heads d_model"],
        Float[Tensor, "batch seq_len n_heads d_model"],
    ],
    return_hidden_pre: Literal[True],
    *,
    no_einsum: bool = True,
    **kwargs,
) -> tuple[
    Union[
        Float[Tensor, "batch n_heads d_sae"],
        Float[Tensor, "batch seq_len n_heads d_sae"],
    ],
    Union[
        Float[Tensor, "batch n_heads d_sae"],
        Float[Tensor, "batch seq_len n_heads d_sae"],
    ],
]
encode(
    x: Union[
        Float[Tensor, "batch n_heads d_model"],
        Float[Tensor, "batch seq_len n_heads d_model"],
    ],
    return_hidden_pre: bool = False,
    *,
    no_einsum: bool = True,
    **kwargs,
) -> Union[
    Float[Tensor, "batch n_heads d_sae"],
    Float[Tensor, "batch seq_len n_heads d_sae"],
    tuple[
        Union[
            Float[Tensor, "batch n_heads d_sae"],
            Float[Tensor, "batch seq_len n_heads d_sae"],
        ],
        Union[
            Float[Tensor, "batch n_heads d_sae"],
            Float[Tensor, "batch seq_len n_heads d_sae"],
        ],
    ],
]

Encode the input tensor.

Parameters:

Name Type Description Default
x Union[Float[Tensor, 'batch n_heads d_model'], Float[Tensor, 'batch seq_len n_heads d_model']]

Input tensor of shape (..., n_heads, d_model).

required

Returns:

Type Description
Union[Float[Tensor, 'batch n_heads d_sae'], Float[Tensor, 'batch seq_len n_heads d_sae'], tuple[Union[Float[Tensor, 'batch n_heads d_sae'], Float[Tensor, 'batch seq_len n_heads d_sae']], Union[Float[Tensor, 'batch n_heads d_sae'], Float[Tensor, 'batch seq_len n_heads d_sae']]]]

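Encoded tensor of shape (..., n_heads, d_sae). If return_hidden_pre is True, a tuple (feature_acts, hidden_pre) is returned instead, where hidden_pre is the head-summed pre-activation repeated for every head.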

Source code in src/lm_saes/crosscoder.py
@override
@timer.time("encode")
def encode(
    self,
    x: Union[
        Float[torch.Tensor, "batch n_heads d_model"],
        Float[torch.Tensor, "batch seq_len n_heads d_model"],
    ],
    return_hidden_pre: bool = False,
    *,
    no_einsum: bool = True,
    **kwargs,
) -> Union[
    Float[torch.Tensor, "batch n_heads d_sae"],
    Float[torch.Tensor, "batch seq_len n_heads d_sae"],
    tuple[
        Union[
            Float[torch.Tensor, "batch n_heads d_sae"],
            Float[torch.Tensor, "batch seq_len n_heads d_sae"],
        ],
        Union[
            Float[torch.Tensor, "batch n_heads d_sae"],
            Float[torch.Tensor, "batch seq_len n_heads d_sae"],
        ],
    ],
]:
    """Encode the input tensor.

    Args:
        x: Input tensor of shape (..., n_heads, d_model).

    Returns:
        Encoded tensor of shape (..., n_heads, d_sae).
    """
    # Apply encoding per head
    hidden_pre = self._apply_encoding(x, no_einsum=no_einsum)

    # Sum pre-activations across heads
    if not isinstance(hidden_pre, DTensor):
        accumulated_hidden_pre = torch.sum(hidden_pre, dim=-2)  # "... n_heads d_sae -> ... d_sae"
    else:
        accumulated_hidden_pre = cast(
            DTensor,
            cast(
                DTensor,
                local_map(
                    lambda x: torch.sum(x, dim=-2, keepdim=True),
                    list(hidden_pre.placements),
                )(hidden_pre),
            ).sum(dim=-2),
        )  # "... n_heads d_sae -> ... d_sae"

        with timer.time("encode_redistribute_tensor_pre_repeat"):
            accumulated_hidden_pre = DimMap({"data": 0, "model": -1}).redistribute(accumulated_hidden_pre)

    accumulated_hidden_pre = einops.repeat(
        accumulated_hidden_pre, "... d_sae -> ... n_heads d_sae", n_heads=self.cfg.n_heads
    )

    with timer.time("encode_redistribute_tensor_post_repeat"):
        if isinstance(accumulated_hidden_pre, DTensor):
            accumulated_hidden_pre = DimMap({"data": 0, "head": -2, "model": -1}).redistribute(
                accumulated_hidden_pre
            )

    # Apply activation function
    feature_acts = self.activation_function(accumulated_hidden_pre * self.decoder_norm()) / self.decoder_norm()

    if return_hidden_pre:
        return feature_acts, accumulated_hidden_pre
    return feature_acts
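The head-summing logic in encode can be sketched with plain PyTorch. This is a minimal, non-distributed sketch under stated assumptions: a ReLU activation, an encoder W_E of shape (n_heads, d_model, d_sae), a bias b_E of shape (n_heads, d_sae) added per head, and a decoder W_D of shape (n_heads, d_sae, d_model). The per-head step stands in for _apply_encoding, whose body is not shown here, and the helper name is hypothetical.

```python
import torch


def crosscoder_encode_sketch(
    x: torch.Tensor,    # (..., n_heads, d_model)
    W_E: torch.Tensor,  # (n_heads, d_model, d_sae), assumed layout
    b_E: torch.Tensor,  # (n_heads, d_sae), assumed to be added per head
    W_D: torch.Tensor,  # (n_heads, d_sae, d_model), assumed layout
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper mirroring CrossCoder.encode for a ReLU activation."""
    # Encode each head separately: x[..., h, :] @ W_E[h] + b_E[h]
    per_head = torch.einsum("...hd,hds->...hs", x, W_E) + b_E
    # Sum across heads: (..., n_heads, d_sae) -> (..., d_sae)
    summed = per_head.sum(dim=-2)
    # Repeat the shared pre-activation for every head
    n_heads = W_E.shape[0]
    hidden_pre = summed.unsqueeze(-2).expand(*summed.shape[:-1], n_heads, summed.shape[-1])
    # Per-head, per-feature decoder norm: (n_heads, d_sae)
    dec_norm = torch.norm(W_D, dim=-1)
    # Activate the norm-scaled pre-activations, then undo the scaling
    feature_acts = torch.relu(hidden_pre * dec_norm) / dec_norm
    return feature_acts, hidden_pre


torch.manual_seed(0)
x = torch.randn(2, 3, 8)  # batch=2, n_heads=3, d_model=8
W_E = torch.randn(3, 8, 16)
b_E = torch.zeros(3, 16)
W_D = torch.randn(3, 16, 8)
acts, pre = crosscoder_encode_sketch(x, W_E, b_E, W_D)
print(acts.shape)  # torch.Size([2, 3, 16])
```

With ReLU and positive decoder norms, relu(h * n) / n equals relu(h), so the decoder-norm scaling only changes the result for activations such as TopK or JumpReLU, where it affects which features survive.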

decode

decode(
    feature_acts: Union[
        Float[Tensor, "batch d_sae"],
        Float[Tensor, "batch seq_len d_sae"],
    ],
    *,
    no_einsum: bool = True,
    **kwargs,
) -> Union[
    Float[Tensor, "batch d_model"],
    Float[Tensor, "batch seq_len d_model"],
]

Decode the encoded tensor.

Parameters:

Name Type Description Default
feature_acts

Encoded tensor of shape (..., n_heads, d_sae).

required

Returns:

Type Description
Union[Float[Tensor, 'batch d_model'], Float[Tensor, 'batch seq_len d_model']]

Decoded tensor of shape (..., n_heads, d_model).

Source code in src/lm_saes/crosscoder.py
@override
@timer.time("decode")
def decode(
    self,
    feature_acts: Union[
        Float[torch.Tensor, "batch d_sae"],
        Float[torch.Tensor, "batch seq_len d_sae"],
    ],
    *,
    no_einsum: bool = True,
    **kwargs,
) -> Union[
    Float[torch.Tensor, "batch d_model"],
    Float[torch.Tensor, "batch seq_len d_model"],
]:  # may be overridden by subclasses
    """Decode the encoded tensor.

    Args:
        feature_acts: Encoded tensor of shape (..., n_heads, d_sae).

    Returns:
        Decoded tensor of shape (..., n_heads, d_model).
    """
    return self._apply_decoding(feature_acts, no_einsum=no_einsum)

decoder_norm

decoder_norm(keepdim: bool = False) -> Tensor

Calculate the norm of the decoder weights.

Returns:

Type Description
Tensor

Norm of decoder weights of shape (n_heads, d_sae).

Source code in src/lm_saes/crosscoder.py
@override
def decoder_norm(self, keepdim: bool = False) -> torch.Tensor:
    """Calculate the norm of the decoder weights.

    Returns:
        Norm of decoder weights of shape (n_heads, d_sae).
    """
    with timer.time("decoder_norm_computation"):
        if not isinstance(self.W_D, DTensor):
            return torch.norm(self.W_D, dim=-1, keepdim=keepdim)
        else:
            assert self.device_mesh is not None
            return DTensor.from_local(
                torch.norm(self.W_D.to_local(), dim=-1, keepdim=keepdim),
                device_mesh=self.device_mesh,
                placements=DimMap({"head": 0, "model": 1}).placements(self.device_mesh),
            )

encoder_norm

encoder_norm(keepdim: bool = False) -> Tensor

Calculate the norm of the encoder weights.

Returns:

Type Description
Tensor

Norm of encoder weights of shape (n_heads, d_sae).

Source code in src/lm_saes/crosscoder.py
@override
@timer.time("encoder_norm")
def encoder_norm(self, keepdim: bool = False) -> torch.Tensor:
    """Calculate the norm of the encoder weights.

    Returns:
        Norm of encoder weights of shape (n_heads, d_sae).
    """
    return torch.norm(self.W_E, dim=-2, keepdim=keepdim)
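Both norm methods reduce over the d_model axis, which sits at a different position in each weight. The sketch below assumes W_E has shape (n_heads, d_model, d_sae) and W_D has shape (n_heads, d_sae, d_model), which is consistent with the documented (n_heads, d_sae) output shapes and with the head/model placements in dim_maps; the sizes are arbitrary.

```python
import torch

n_heads, d_model, d_sae = 2, 4, 6
W_E = torch.randn(n_heads, d_model, d_sae)  # assumed encoder layout
W_D = torch.randn(n_heads, d_sae, d_model)  # assumed decoder layout

# decoder_norm: reduce over d_model, the last axis of W_D
dec_norm = torch.norm(W_D, dim=-1)
# encoder_norm: reduce over d_model, the second-to-last axis of W_E
enc_norm = torch.norm(W_E, dim=-2)
# keepdim=True keeps the reduced axis with size 1, which is handy for broadcasting
dec_norm_kept = torch.norm(W_D, dim=-1, keepdim=True)

print(dec_norm.shape, enc_norm.shape, dec_norm_kept.shape)
# torch.Size([2, 6]) torch.Size([2, 6]) torch.Size([2, 6, 1])
```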

standardize_parameters_of_dataset_norm

standardize_parameters_of_dataset_norm()

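Standardize the model parameters to account for dataset-wise activation normalization during inference. Requires norm_activation == "dataset-wise" and a recorded dataset_average_activation_norm. For each hook point, the encoder and decoder biases are divided by the normalization factor sqrt(d_model) / average_norm, and norm_activation is then set to "inference".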

Source code in src/lm_saes/crosscoder.py
@override
@timer.time("standardize_parameters_of_dataset_norm")
@torch.no_grad()
def standardize_parameters_of_dataset_norm(self):
    """
    Standardize the parameters of the model to account for dataset_norm during inference.
    """
    assert self.cfg.norm_activation == "dataset-wise"
    assert self.dataset_average_activation_norm is not None
    norm_factors = torch.tensor(
        [
            math.sqrt(self.cfg.d_model) / self.dataset_average_activation_norm[hook_point]
            for hook_point in self.cfg.hook_points
        ],
        dtype=self.cfg.dtype,
        device=self.cfg.device,
    )
    self.b_E.div_(norm_factors.view(self.cfg.n_heads, 1))
    self.b_D.div_(norm_factors.view(self.cfg.n_heads, 1))
    self.cfg.norm_activation = "inference"
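The normalization factors computed in standardize_parameters_of_dataset_norm can be checked by hand. In this sketch the hook point names and average norms are made up, and only the bias division is reproduced, not the config bookkeeping.

```python
import math

import torch

d_model, d_sae = 4, 3
hook_points = ["hook_a", "hook_b"]  # hypothetical hook point names
dataset_average_activation_norm = {"hook_a": 4.0, "hook_b": 1.0}  # made-up values

# One factor per hook point: sqrt(d_model) / average activation norm
norm_factors = torch.tensor(
    [math.sqrt(d_model) / dataset_average_activation_norm[h] for h in hook_points]
)  # sqrt(4) / 4 = 0.5 and sqrt(4) / 1 = 2.0

n_heads = len(hook_points)
b_E = torch.ones(n_heads, d_sae)
b_D = torch.ones(n_heads, d_model)
# Divide each head's biases by that head's factor, as the method does in place
b_E_std = b_E / norm_factors.view(n_heads, 1)
b_D_std = b_D / norm_factors.view(n_heads, 1)
print(b_D_std[:, 0].tolist())  # [2.0, 0.5]
```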

compute_training_metrics

compute_training_metrics(
    *,
    feature_acts: Tensor,
    l_rec: Tensor,
    l0: Tensor,
    explained_variance: Tensor,
    **kwargs,
) -> dict[str, float]

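Compute per-head training metrics for CrossCoder: explained variance, L0, and reconstruction loss for each hook point, plus the mean decoder norm of activated features, split by whether the feature fires in a given head.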

Source code in src/lm_saes/crosscoder.py
@override
@torch.no_grad()
def compute_training_metrics(
    self,
    *,
    feature_acts: torch.Tensor,
    l_rec: torch.Tensor,
    l0: torch.Tensor,
    explained_variance: torch.Tensor,
    **kwargs,
) -> dict[str, float]:
    """Compute per-head training metrics for CrossCoder."""
    assert explained_variance.ndim == 1 and len(explained_variance) == len(self.cfg.hook_points)
    feature_act_spec = self.specs.feature_acts(feature_acts)
    l0_spec = tuple(spec for spec in feature_act_spec if spec != "sae")
    l_rec_spec = tuple(
        spec for spec in feature_act_spec if spec != "model" and spec != "batch" and spec != "context"
    )
    metrics = {}
    for i, k in enumerate(self.cfg.hook_points):
        metrics.update(
            {
                f"crosscoder_metrics/{k}/explained_variance": item(explained_variance[i].mean()),
                f"crosscoder_metrics/{k}/l0": item(l0.select(l0_spec.index("heads"), i).mean()),
                f"crosscoder_metrics/{k}/l_rec": item(l_rec.select(l_rec_spec.index("heads"), i).mean()),
            }
        )
    indices = feature_acts.amax(dim=1).nonzero(as_tuple=True)
    activated_feature_acts = feature_acts.permute(0, 2, 1)[indices].permute(1, 0)
    activated_decoder_norms = full_tensor(self.decoder_norm())[:, indices[1]]
    mean_decoder_norm_non_activated_in_activated = item(activated_decoder_norms[activated_feature_acts == 0].mean())
    mean_decoder_norm_activated_in_activated = item(activated_decoder_norms[activated_feature_acts != 0].mean())
    metrics.update(
        {
            "crosscoder_metrics/mean_decoder_norm_non_activated_in_activated": mean_decoder_norm_non_activated_in_activated,
            "crosscoder_metrics/mean_decoder_norm_activated_in_activated": mean_decoder_norm_activated_in_activated,
        }
    )
    return metrics
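The last two metrics in compute_training_metrics only look at (sample, feature) pairs where the feature fires in at least one head, then average the decoder norms separately over heads where it is zero and heads where it is nonzero. The toy example below assumes feature_acts has shape (batch, n_heads, d_sae), which matches the amax over dim=1 and the permute calls in the source; the numbers are invented so the result can be checked by hand.

```python
import torch

# batch=1, n_heads=2, d_sae=3 (assumed layout)
feature_acts = torch.tensor(
    [
        [[1.0, 0.0, 0.0],
         [0.0, 0.0, 2.0]],
    ]
)
decoder_norm = torch.tensor(
    [[0.5, 1.0, 0.1],
     [0.2, 3.0, 0.9]]
)  # (n_heads, d_sae), made-up values

# (sample, feature) positions where the feature fires in at least one head
indices = feature_acts.amax(dim=1).nonzero(as_tuple=True)
# (n_heads, n_activated): activation of each activated pair in every head
activated_acts = feature_acts.permute(0, 2, 1)[indices].permute(1, 0)
activated_norms = decoder_norm[:, indices[1]]

mean_norm_inactive_heads = activated_norms[activated_acts == 0].mean()
mean_norm_active_heads = activated_norms[activated_acts != 0].mean()
print(round(mean_norm_inactive_heads.item(), 4), round(mean_norm_active_heads.item(), 4))  # 0.15 0.7
```

Feature 1 never fires, so its large decoder norm (1.0 and 3.0) is excluded from both averages.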

dim_maps

dim_maps() -> dict[str, DimMap]

Return a dictionary mapping parameter names to dimension maps.

Returns:

Type Description
dict[str, DimMap]

A dictionary mapping parameter names to DimMap objects.

Source code in src/lm_saes/crosscoder.py
def dim_maps(self) -> dict[str, DimMap]:
    """Return a dictionary mapping parameter names to dimension maps.

    Returns:
        A dictionary mapping parameter names to DimMap objects.
    """
    parent_maps = super().dim_maps()
    crosscoder_maps = {
        "W_E": DimMap({"head": 0, "model": 2}),
        "W_D": DimMap({"head": 0, "model": 1}),
        "b_E": DimMap({"head": 0, "model": 1}),
        "b_D": DimMap({"head": 0}),
    }
    return parent_maps | crosscoder_maps

CLTConfig pydantic-model

Bases: BaseSAEConfig

Configuration for Cross Layer Transcoder (CLT).

A CLT consists of L encoders and L(L+1)/2 decoders. The encoder at layer l reads from the residual stream at layer l, and its features can decode to every layer from l through L-1 (layers are 0-indexed).
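The decoder count follows from the triangular layout: target layer j receives one decoder from each source layer 0..j, and summing j + 1 over j = 0..L-1 gives L(L+1)/2. The helper below is hypothetical and only enumerates (source, target) pairs in the same target-major order that the W_D groups use.

```python
def clt_decoder_pairs(n_layers: int) -> list[tuple[int, int]]:
    """Hypothetical helper: all (source_layer, target_layer) pairs with source <= target."""
    # Group by target layer, matching W_D[target] with shape (target + 1, d_sae, d_model)
    return [(src, tgt) for tgt in range(n_layers) for src in range(tgt + 1)]


pairs = clt_decoder_pairs(3)
print(pairs)  # [(0, 0), (0, 1), (1, 1), (0, 2), (1, 2), (2, 2)]
print(len(pairs) == 3 * (3 + 1) // 2)  # True
```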

Fields:

hook_points_in pydantic-field

hook_points_in: list[str]

List of hook points to capture input activations from, one for each layer.

hook_points_out pydantic-field

hook_points_out: list[str]

List of hook points to capture output activations from, one for each layer.

decode_with_csr pydantic-field

decode_with_csr: bool = False

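Whether to decode with sparse CSR matrices. If True, CSR matrices are used for decoding whenever the feature activations are sparse enough (see sparsity_threshold_for_csr); otherwise, and whenever this flag is False, dense matrices are used.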

sparsity_threshold_for_csr pydantic-field

sparsity_threshold_for_csr: float = 0.05

The density threshold below which CSR decoding is used. When decode_with_csr is True, decoding switches to CSR matrices if current_k / (d_sae * n_layers) falls below this threshold. Because this check relies on the k of a TopK-family activation function, it does not work with other activation functions such as relu or jumprelu.
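The switch between CSR and dense decoding in CrossLayerTranscoder.decode reduces to a single density comparison. The function below is a hypothetical standalone restatement of that check; current_k stands for the model attribute of the same name that decode reads, and the values used here are arbitrary.

```python
def would_decode_with_csr(
    current_k: int,
    d_sae: int,
    n_layers: int,
    decode_with_csr: bool,
    sparsity_threshold_for_csr: float = 0.05,
) -> bool:
    """Hypothetical helper mirroring the CSR condition in CrossLayerTranscoder.decode."""
    if not decode_with_csr:
        return False  # dense decoding is always used when the flag is off
    density = current_k / (d_sae * n_layers)
    # CSR is chosen only when the density is strictly below the threshold
    return density < sparsity_threshold_for_csr


print(would_decode_with_csr(50, 1000, 2, decode_with_csr=True))   # 0.025 < 0.05, so True
print(would_decode_with_csr(200, 1000, 2, decode_with_csr=True))  # 0.1, so False
```

In decode, this check only runs after ruling out a list of sparse COO tensors as input, which takes the COO path regardless of the flag.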

n_layers property

n_layers: int

Number of layers in the CLT.

n_decoders property

n_decoders: int

Number of decoders in the CLT.

associated_hook_points property

associated_hook_points: list[str]

All hook points used by the CLT.

CrossLayerTranscoder

CrossLayerTranscoder(
    cfg: CLTConfig, device_mesh: Optional[DeviceMesh] = None
)

Bases: AbstractSparseAutoEncoder

Cross Layer Transcoder (CLT) implementation.

A CLT has L encoders (one per layer) and L(L+1)/2 decoders arranged in an upper triangular pattern. The encoder at layer l reads from the residual stream at layer l, and its features can decode to every layer from l through L-1.

We store all parameters in the same object and shard them across GPUs for efficient distributed training.

Initialize the Cross Layer Transcoder.

Parameters:

Name Type Description Default
cfg CLTConfig

Configuration for the CLT.

required
device_mesh Optional[DeviceMesh]

Device mesh for distributed training.

None
Source code in src/lm_saes/clt.py
def __init__(self, cfg: CLTConfig, device_mesh: Optional[DeviceMesh] = None):
    """Initialize the Cross Layer Transcoder.

    Args:
        cfg: Configuration for the CLT.
        device_mesh: Device mesh for distributed training.
    """
    super().__init__(cfg, device_mesh)
    self.cfg = cfg
    # CLT requires specific configuration settings
    # assert not cfg.sparsity_include_decoder_norm, "CLT requires sparsity_include_decoder_norm=False"
    # assert cfg.use_decoder_bias, "CLT requires use_decoder_bias=True"

    # Initialize weights and biases for cross-layer architecture
    if device_mesh is None:
        # L encoders: one for each layer
        self.W_E = nn.Parameter(
            torch.empty(cfg.n_layers, cfg.d_model, cfg.d_sae, device=cfg.device, dtype=cfg.dtype)
        )
        self.b_E = nn.Parameter(torch.empty(cfg.n_layers, cfg.d_sae, device=cfg.device, dtype=cfg.dtype))

        # L decoder groups: W_D[i] contains decoders from layers 0..i to layer i
        self.W_D = nn.ParameterList(
            [
                nn.Parameter(data=torch.empty(i + 1, cfg.d_sae, cfg.d_model, device=cfg.device, dtype=cfg.dtype))
                for i in range(cfg.n_layers)
            ]
        )

        # L decoder biases: one bias per target layer
        self.b_D = nn.ParameterList(
            [
                nn.Parameter(torch.empty(cfg.d_model, device=cfg.device, dtype=cfg.dtype))
                for _ in range(cfg.n_layers)
            ]
        )
    else:
        # Distributed initialization - shard along feature dimension
        self.W_E = nn.Parameter(
            torch.distributed.tensor.empty(
                cfg.n_layers,
                cfg.d_model,
                cfg.d_sae,
                dtype=cfg.dtype,
                device_mesh=device_mesh,
                placements=self.dim_maps()["W_E"].placements(device_mesh),
            )  # shard along d_sae
        )
        self.b_E = nn.Parameter(
            torch.distributed.tensor.empty(
                cfg.n_layers,
                cfg.d_sae,
                dtype=cfg.dtype,
                device_mesh=device_mesh,
                placements=self.dim_maps()["b_E"].placements(device_mesh),
            )  # shard along d_sae
        )

        # L decoder groups: W_D[i] contains decoders from layers 0..i to layer i
        self.W_D = nn.ParameterList(
            [
                nn.Parameter(
                    torch.distributed.tensor.empty(
                        i + 1,
                        cfg.d_sae,
                        cfg.d_model,
                        dtype=cfg.dtype,
                        device_mesh=device_mesh,
                        placements=self.dim_maps()["W_D"].placements(device_mesh),
                    )
                )  # shard along d_sae
                for i in range(cfg.n_layers)
            ]
        )

        self.b_D = nn.ParameterList(
            [
                nn.Parameter(
                    torch.distributed.tensor.empty(
                        cfg.d_model,
                        dtype=cfg.dtype,
                        device_mesh=device_mesh,
                        placements=self.dim_maps()["b_D"].placements(device_mesh),
                    )
                )
                for _ in range(cfg.n_layers)
            ]
        )
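For the non-distributed branch of __init__, the resulting parameter shapes can be reproduced with small, arbitrary sizes. This sketch allocates tensors the same way as the code above (uninitialized, via torch.empty) and checks that the decoder groups hold L(L+1)/2 decoders in total.

```python
import torch
from torch import nn

n_layers, d_model, d_sae = 3, 4, 8

W_E = nn.Parameter(torch.empty(n_layers, d_model, d_sae))  # one encoder per layer
b_E = nn.Parameter(torch.empty(n_layers, d_sae))
# W_D[i] stacks the decoders from source layers 0..i into target layer i
W_D = nn.ParameterList(
    [nn.Parameter(torch.empty(i + 1, d_sae, d_model)) for i in range(n_layers)]
)
b_D = nn.ParameterList([nn.Parameter(torch.empty(d_model)) for _ in range(n_layers)])

print([tuple(w.shape) for w in W_D])  # [(1, 8, 4), (2, 8, 4), (3, 8, 4)]
total_decoders = sum(w.shape[0] for w in W_D)
print(total_decoders)  # 6
```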

specs class-attribute instance-attribute

specs: type[TensorSpecs] = CrossLayerTranscoderSpecs

Tensor specs for CrossLayerTranscoder with layer dimension.

init_parameters

init_parameters(**kwargs)

Initialize parameters.

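Encoders: uniformly initialized in the range (-1/sqrt(d_sae), 1/sqrt(d_sae)). Decoders: uniformly initialized in the range (-1/sqrt(n_layers * d_model), 1/sqrt(n_layers * d_model)); if init_cross_layer_decoder_all_zero is set, the cross-layer decoders (source layer below the target layer) are initialized to zero. Encoder and decoder biases are initialized to zero.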

Source code in src/lm_saes/clt.py
@override
@torch.no_grad()
def init_parameters(self, **kwargs):
    """Initialize parameters.

    Encoders: uniformly initialized in range (-1/sqrt(d_sae), 1/sqrt(d_sae))
    Decoders: uniformly initialized in range (-1/sqrt(n_layers*d_model), 1/sqrt(n_layers*d_model))
    """
    super().init_parameters(**kwargs)  # jump ReLU threshold is initialized in super()

    # Initialize encoder weights and biases
    encoder_bound = 1.0 / math.sqrt(self.cfg.d_sae)

    if self.device_mesh is None:
        # Non-distributed initialization

        # Initialize encoder weights: (n_layers, d_model, d_sae)
        W_E = torch.empty(
            self.cfg.n_layers, self.cfg.d_model, self.cfg.d_sae, device=self.cfg.device, dtype=self.cfg.dtype
        ).uniform_(-encoder_bound, encoder_bound)

        # Initialize encoder biases: (n_layers, d_sae) - set to zero
        nn.init.zeros_(self.b_E)

        # Initialize decoder weights
        W_D_initialized = []
        scale = 1.0 / math.sqrt(self.cfg.n_layers * self.cfg.d_model)
        for layer_to in range(self.cfg.n_layers):
            # Initialize decoder weights for layer layer_to
            # W_D[layer_to] has shape (layer_to+1, d_sae, d_model)
            # Scale by 1/sqrt(n_layers*d_model), the same bound for every target layer

            W_D_layer = torch.empty(
                layer_to + 1, self.cfg.d_sae, self.cfg.d_model, device=self.cfg.device, dtype=self.cfg.dtype
            )
            nn.init.uniform_(W_D_layer, -scale, scale)
            if self.cfg.init_cross_layer_decoder_all_zero:
                W_D_layer[:-1] = 0
            W_D_initialized.append(W_D_layer)

        # Initialize decoder biases
        for layer_to in range(self.cfg.n_layers):
            # Initialize decoder bias for layer layer_to to zero
            nn.init.zeros_(self.b_D[layer_to])

    else:
        # Distributed initialization
        # Initialize encoder weights
        W_E_local = torch.empty(
            self.cfg.n_layers, self.cfg.d_model, self.cfg.d_sae, device=self.cfg.device, dtype=self.cfg.dtype
        ).uniform_(-encoder_bound, encoder_bound)
        W_E = self.dim_maps()["W_E"].distribute(W_E_local, self.device_mesh)

        # Initialize encoder biases
        nn.init.zeros_(self.b_E)

        # Initialize decoder weights for each layer
        W_D_initialized = []
        for layer_to in range(self.cfg.n_layers):
            decoder_bound = 1.0 / math.sqrt(self.cfg.n_layers * self.cfg.d_model)
            W_D_layer_local = torch.empty(
                layer_to + 1, self.cfg.d_sae, self.cfg.d_model, device=self.cfg.device, dtype=self.cfg.dtype
            ).uniform_(-decoder_bound, decoder_bound)
            if self.cfg.init_cross_layer_decoder_all_zero:
                W_D_layer_local[:-1] = 0
            W_D_layer = self.dim_maps()["W_D"].distribute(tensor=W_D_layer_local, device_mesh=self.device_mesh)
            W_D_initialized.append(W_D_layer)

        # Initialize decoder biases
        for layer_to in range(self.cfg.n_layers):
            # Initialize decoder bias for layer layer_to to zero
            nn.init.zeros_(self.b_D[layer_to])

    # Copy initialized values to parameters
    self.W_E.copy_(W_E)

    for layer_to, W_D_layer in enumerate(W_D_initialized):
        self.W_D[layer_to].copy_(W_D_layer)
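The initialization bounds can be checked with a small non-distributed sketch. The function name is hypothetical; it reproduces the uniform ranges and the init_cross_layer_decoder_all_zero behavior of init_parameters on plain tensors and skips the JumpReLU threshold that the parent class sets up.

```python
import math

import torch


def init_clt_weights_sketch(
    n_layers: int, d_model: int, d_sae: int, init_cross_layer_decoder_all_zero: bool = False
) -> tuple[torch.Tensor, list[torch.Tensor]]:
    """Hypothetical helper mirroring CrossLayerTranscoder.init_parameters on plain tensors."""
    encoder_bound = 1.0 / math.sqrt(d_sae)
    W_E = torch.empty(n_layers, d_model, d_sae).uniform_(-encoder_bound, encoder_bound)

    # One bound shared by every target layer
    decoder_bound = 1.0 / math.sqrt(n_layers * d_model)
    W_D = []
    for layer_to in range(n_layers):
        W_D_layer = torch.empty(layer_to + 1, d_sae, d_model).uniform_(-decoder_bound, decoder_bound)
        if init_cross_layer_decoder_all_zero:
            # Keep only the same-layer decoder (the last slice); zero the cross-layer ones
            W_D_layer[:-1] = 0
        W_D.append(W_D_layer)
    return W_E, W_D


torch.manual_seed(0)
W_E, W_D = init_clt_weights_sketch(3, 4, 16, init_cross_layer_decoder_all_zero=True)
print(W_E.abs().max().item() <= 1 / math.sqrt(16))  # True
print(W_D[2][:2].abs().sum().item())  # 0.0
```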

get_decoder_weights

get_decoder_weights(layer_to: int) -> Tensor

Get the decoder weights from every source layer 0..layer_to into the target layer layer_to.

Parameters:

Name Type Description Default
layer_to int

Target layer (0 to n_layers-1)

required

Returns:

Type Description
Tensor

Decoder weights from all source layers into the specified target layer, of shape (layer_to + 1, d_sae, d_model).

Source code in src/lm_saes/clt.py
def get_decoder_weights(self, layer_to: int) -> torch.Tensor:
    """Get decoder weights for all layers from 0..layer_to to layer_to.

    Args:
        layer_to: Target layer (0 to n_layers-1)

    Returns:
        Decoder weights for all source layers to the specified target layer
    """
    return self.W_D[layer_to]

encode

encode(
    x: Union[
        Float[Tensor, "batch n_layers d_model"],
        Float[Tensor, "batch seq_len n_layers d_model"],
    ],
    return_hidden_pre: Literal[False] = False,
    **kwargs,
) -> Union[
    Float[Tensor, "batch n_layers d_sae"],
    Float[Tensor, "batch seq_len n_layers d_sae"],
]
encode(
    x: Union[
        Float[Tensor, "batch n_layers d_model"],
        Float[Tensor, "batch seq_len n_layers d_model"],
    ],
    return_hidden_pre: Literal[True],
    **kwargs,
) -> tuple[
    Union[
        Float[Tensor, "batch n_layers d_sae"],
        Float[Tensor, "batch seq_len n_layers d_sae"],
    ],
    Union[
        Float[Tensor, "batch n_layers d_sae"],
        Float[Tensor, "batch seq_len n_layers d_sae"],
    ],
]
encode(
    x: Union[
        Float[Tensor, "batch n_layers d_model"],
        Float[Tensor, "batch seq_len n_layers d_model"],
    ],
    return_hidden_pre: bool = False,
    **kwargs,
) -> Union[
    Float[Tensor, "batch n_layers d_sae"],
    Float[Tensor, "batch seq_len n_layers d_sae"],
    tuple[
        Union[
            Float[Tensor, "batch n_layers d_sae"],
            Float[Tensor, "batch seq_len n_layers d_sae"],
        ],
        Union[
            Float[Tensor, "batch n_layers d_sae"],
            Float[Tensor, "batch seq_len n_layers d_sae"],
        ],
    ],
]

Encode input activations to CLT features using L encoders.

Parameters:

Name Type Description Default
x Union[Float[Tensor, 'batch n_layers d_model'], Float[Tensor, 'batch seq_len n_layers d_model']]

Input activations from all layers (..., n_layers, d_model)

required
return_hidden_pre bool

Whether to return pre-activation values

False

Returns:

Type Description
Union[Float[Tensor, 'batch n_layers d_sae'], Float[Tensor, 'batch seq_len n_layers d_sae'], tuple[Union[Float[Tensor, 'batch n_layers d_sae'], Float[Tensor, 'batch seq_len n_layers d_sae']], Union[Float[Tensor, 'batch n_layers d_sae'], Float[Tensor, 'batch seq_len n_layers d_sae']]]]

Feature activations for all layers (..., n_layers, d_sae)

Source code in src/lm_saes/clt.py
@override
def encode(
    self,
    x: Union[
        Float[torch.Tensor, "batch n_layers d_model"],
        Float[torch.Tensor, "batch seq_len n_layers d_model"],
    ],
    return_hidden_pre: bool = False,
    **kwargs,
) -> Union[
    Float[torch.Tensor, "batch n_layers d_sae"],
    Float[torch.Tensor, "batch seq_len n_layers d_sae"],
    tuple[
        Union[
            Float[torch.Tensor, "batch n_layers d_sae"],
            Float[torch.Tensor, "batch seq_len n_layers d_sae"],
        ],
        Union[
            Float[torch.Tensor, "batch n_layers d_sae"],
            Float[torch.Tensor, "batch seq_len n_layers d_sae"],
        ],
    ],
]:
    """Encode input activations to CLT features using L encoders.

    Args:
        x: Input activations from all layers (..., n_layers, d_model)
        return_hidden_pre: Whether to return pre-activation values

    Returns:
        Feature activations for all layers (..., n_layers, d_sae)
    """
    with timer.time("encoder_matmul"):
        hidden_pre = torch.einsum("...ld,lds->...ls", x, self.W_E) + self.b_E

    if self.cfg.sparsity_include_decoder_norm:
        hidden_pre = hidden_pre * self.decoder_norm_per_feature()

    # Apply activation function (ReLU, TopK, etc.)
    with timer.time("activation_function"):
        feature_acts = self.activation_function(hidden_pre)

    if self.cfg.sparsity_include_decoder_norm:
        feature_acts = feature_acts / self.decoder_norm_per_feature()

    if return_hidden_pre:
        return feature_acts, hidden_pre
    return feature_acts
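The einsum in encode applies each layer's encoder only to that layer's activations. The quick equivalence check below compares it against an explicit loop, with arbitrary sizes; ReLU stands in for the configured act_fn and sparsity_include_decoder_norm is ignored, both as simplifying assumptions.

```python
import torch

torch.manual_seed(0)
batch, n_layers, d_model, d_sae = 2, 3, 4, 5
x = torch.randn(batch, n_layers, d_model)
W_E = torch.randn(n_layers, d_model, d_sae)
b_E = torch.randn(n_layers, d_sae)

# Batched form used by CrossLayerTranscoder.encode
hidden_pre = torch.einsum("...ld,lds->...ls", x, W_E) + b_E
# Explicit loop: the encoder of layer i sees only x[:, i]
looped = torch.stack([x[:, i] @ W_E[i] + b_E[i] for i in range(n_layers)], dim=1)
feature_acts = torch.relu(hidden_pre)  # ReLU stands in for the configured act_fn

print(hidden_pre.shape)  # torch.Size([2, 3, 5])
print(torch.allclose(hidden_pre, looped, atol=1e-5))  # True
```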

encode_single_layer

encode_single_layer(
    x: Union[
        Float[Tensor, "batch d_model"],
        Float[Tensor, "batch seq_len d_model"],
    ],
    layer: int,
    return_hidden_pre: bool = False,
    **kwargs,
) -> Union[
    Float[Tensor, "batch d_sae"],
    Float[Tensor, "batch seq_len d_sae"],
    tuple[
        Union[
            Float[Tensor, "batch d_sae"],
            Float[Tensor, "batch seq_len d_sae"],
        ],
        Union[
            Float[Tensor, "batch d_sae"],
            Float[Tensor, "batch seq_len d_sae"],
        ],
    ],
]

Encode input activations from a single layer to CLT features using that layer's encoder.

Parameters:

Name Type Description Default
x Union[Float[Tensor, 'batch d_model'], Float[Tensor, 'batch seq_len d_model']]

Input activations from a given layer (..., d_model)

required
layer int

The layer to encode

required
return_hidden_pre bool

Whether to return pre-activation values

False

Returns:

Type Description
Union[Float[Tensor, 'batch d_sae'], Float[Tensor, 'batch seq_len d_sae'], tuple[Union[Float[Tensor, 'batch d_sae'], Float[Tensor, 'batch seq_len d_sae']], Union[Float[Tensor, 'batch d_sae'], Float[Tensor, 'batch seq_len d_sae']]]]

Feature activations for the given layer (..., d_sae)

Source code in src/lm_saes/clt.py
def encode_single_layer(
    self,
    x: Union[
        Float[torch.Tensor, "batch d_model"],
        Float[torch.Tensor, "batch seq_len d_model"],
    ],
    layer: int,
    return_hidden_pre: bool = False,
    **kwargs,
) -> Union[
    Float[torch.Tensor, "batch d_sae"],
    Float[torch.Tensor, "batch seq_len d_sae"],
    tuple[
        Union[
            Float[torch.Tensor, "batch d_sae"],
            Float[torch.Tensor, "batch seq_len d_sae"],
        ],
        Union[
            Float[torch.Tensor, "batch d_sae"],
            Float[torch.Tensor, "batch seq_len d_sae"],
        ],
    ],
]:
    """Encode input activations from a single layer to CLT features using that layer's encoder.

    Args:
        x: Input activations from a given layer (..., d_model)
        layer: The layer to encode
        return_hidden_pre: Whether to return pre-activation values

    Returns:
        Feature activations for the given layer (..., d_sae)
    """
    # Apply each encoder to its corresponding layer: x[..., layer, :] @ W_E[layer] + b_E[layer]
    hidden_pre = torch.einsum("...d,ds->...s", x, self.W_E[layer]) + self.b_E[layer]

    if self.cfg.sparsity_include_decoder_norm:
        hidden_pre = hidden_pre * self.decoder_norm_per_feature(layer=layer)

    # Apply activation function (ReLU, TopK, etc.)
    if self.cfg.act_fn.lower() == "jumprelu":
        assert isinstance(self.activation_function, JumpReLU)
        jumprelu_threshold = self.activation_function.get_jumprelu_threshold()
        feature_acts = hidden_pre * hidden_pre.gt(jumprelu_threshold[layer])
    else:
        feature_acts = self.activation_function(hidden_pre)

    if return_hidden_pre:
        return feature_acts, hidden_pre
    return feature_acts
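For JumpReLU, encode_single_layer bypasses the activation module and thresholds directly with the selected layer's row of the JumpReLU threshold. The toy example below uses made-up thresholds with an assumed shape of (n_layers, d_sae), the layout that the jumprelu_threshold[layer] indexing suggests.

```python
import torch

# Made-up thresholds, assumed shape (n_layers, d_sae)
jumprelu_threshold = torch.tensor([[0.5, 0.5, 0.5], [1.0, 2.0, 3.0]])
hidden_pre = torch.tensor([[1.5, 2.5, 2.5], [3.5, 0.0, 0.5]])  # (batch=2, d_sae=3)
layer = 1

# Keep a pre-activation only where it exceeds its threshold; kept values are not shifted
feature_acts = hidden_pre * hidden_pre.gt(jumprelu_threshold[layer])
print(feature_acts.tolist())  # [[1.5, 2.5, 0.0], [3.5, 0.0, 0.0]]
```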

decode

decode(
    feature_acts: Union[
        Float[Tensor, "batch n_layers d_sae"],
        Float[Tensor, "batch seq_len n_layers d_sae"],
        List[Float[Tensor, "seq_len d_sae"]],
    ],
    batch_first: bool = False,
    **kwargs,
) -> Union[
    Float[Tensor, "n_layers batch d_model"],
    Float[Tensor, "n_layers batch seq_len d_model"],
    Float[Tensor, "batch n_layers d_model"],
    Float[Tensor, "batch seq_len n_layers d_model"],
]

Decode CLT features to output activations using the upper triangular pattern.

The output at layer L is the sum of contributions from all layers 0 through L: y_L = Σ_{i=0}^{L} W_D[i→L] @ feature_acts[..., i, :] + b_D[L]

Parameters:

Name Type Description Default
feature_acts Union[Float[Tensor, 'batch n_layers d_sae'], Float[Tensor, 'batch seq_len n_layers d_sae'], List[Float[Tensor, 'seq_len d_sae']]]

CLT feature activations (..., n_layers, d_sae)

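required
batch_first bool

If True, stack the per-layer outputs along dimension 1; otherwise along dimension 0.

False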

Returns:

Type Description
Union[Float[Tensor, 'n_layers batch d_model'], Float[Tensor, 'n_layers batch seq_len d_model'], Float[Tensor, 'batch n_layers d_model'], Float[Tensor, 'batch seq_len n_layers d_model']]

Reconstructed activations for all layers, stacked along a new layer dimension at dim 0 by default (n_layers, ..., d_model), or at dim 1 when batch_first is True.

Source code in src/lm_saes/clt.py
@override
def decode(
    self,
    feature_acts: Union[
        Float[torch.Tensor, "batch n_layers d_sae"],
        Float[torch.Tensor, "batch seq_len n_layers d_sae"],
        List[Float[torch.sparse.Tensor, "seq_len d_sae"]],
    ],
    batch_first: bool = False,
    **kwargs,
) -> Union[
    Float[torch.Tensor, "n_layers batch d_model"],
    Float[torch.Tensor, "n_layers batch seq_len d_model"],
    Float[torch.Tensor, "batch n_layers d_model"],
    Float[torch.Tensor, "batch seq_len n_layers d_model"],
]:
    """Decode CLT features to output activations using the upper triangular pattern.

    The output at layer L is the sum of contributions from all layers 0 through L:
    y_L = Σ_{i=0}^{L} W_D[i→L] @ feature_acts[..., i, :] + b_D[L]

    Args:
        feature_acts: CLT feature activations (..., n_layers, d_sae)

    Returns:
        Reconstructed activations for all layers (..., n_layers, d_model)
    """
    # TODO: make this cleaner

    reconstructed = []
    # For each output layer L
    if (
        isinstance(feature_acts, list)
        and isinstance(feature_acts[0], torch.Tensor)
        and feature_acts[0].layout == torch.sparse_coo
    ):
        decode_single_output_layer = self._decode_single_output_layer_coo
    elif self.cfg.decode_with_csr:
        if self.current_k / (self.cfg.d_sae * self.cfg.n_layers) < self.cfg.sparsity_threshold_for_csr:
            decode_single_output_layer = self._decode_single_output_layer_csr
            assert not isinstance(feature_acts, list), (
                "feature_acts must not be a list when decode_with_csr is True"
            )
            if isinstance(feature_acts, DTensor):
                feature_acts = feature_acts.to_local()
            if feature_acts.layout != torch.sparse_csr:
                feature_acts = [fa.to_sparse_csr() for fa in feature_acts.permute(1, 0, 2)]
        else:
            decode_single_output_layer = self._decode_single_output_layer_dense
    else:
        decode_single_output_layer = self._decode_single_output_layer_dense

    for layer_to in range(self.cfg.n_layers):
        # we only compute W_D @ feature_acts here, without b_D
        contribution = decode_single_output_layer(feature_acts, layer_to)  # type: ignore

        # Add bias contribution (single bias vector for this target layer)
        contribution = contribution + self.b_D[layer_to]  # (d_model,)
        if isinstance(contribution, DTensor):
            contribution = DimMap({"data": 0}).redistribute(contribution)

        reconstructed.append(contribution)

    return torch.stack(reconstructed, dim=1 if batch_first else 0)
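The formula above can be checked with a dense, non-distributed sketch of decode. It assumes feature_acts has shape (batch, n_layers, d_sae) and W_D[j] has shape (j + 1, d_sae, d_model) as in __init__, and it skips the CSR and COO paths; the function name is hypothetical.

```python
import torch


def clt_decode_sketch(
    feature_acts: torch.Tensor,   # (batch, n_layers, d_sae)
    W_D: list[torch.Tensor],      # W_D[j]: (j + 1, d_sae, d_model)
    b_D: list[torch.Tensor],      # b_D[j]: (d_model,)
    batch_first: bool = False,
) -> torch.Tensor:
    """Hypothetical dense version of CrossLayerTranscoder.decode."""
    outputs = []
    for layer_to, W in enumerate(W_D):
        # Sum over source layers i = 0..layer_to of feature_acts[:, i] @ W_D[layer_to][i]
        contribution = torch.einsum("bis,isd->bd", feature_acts[:, : layer_to + 1], W)
        outputs.append(contribution + b_D[layer_to])
    # Layer axis at dim 0 by default, dim 1 when batch_first is True
    return torch.stack(outputs, dim=1 if batch_first else 0)


torch.manual_seed(0)
batch, n_layers, d_sae, d_model = 2, 3, 5, 4
feature_acts = torch.relu(torch.randn(batch, n_layers, d_sae))
W_D = [torch.randn(i + 1, d_sae, d_model) for i in range(n_layers)]
b_D = [torch.randn(d_model) for _ in range(n_layers)]

out = clt_decode_sketch(feature_acts, W_D, b_D)
print(out.shape)  # torch.Size([3, 2, 4])
```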

decoder_norm

decoder_norm(keepdim: bool = False)

Return a tensor of ones of shape (n_decoders,). Per-feature decoder norms are computed by decoder_norm_per_feature.

Source code in src/lm_saes/clt.py
@override
def decoder_norm(self, keepdim: bool = False):
    """Return a tensor of ones of shape (n_decoders,); see decoder_norm_per_feature for per-feature norms."""
    return torch.ones(self.cfg.n_decoders, device=self.cfg.device, dtype=self.cfg.dtype)

encoder_norm

encoder_norm(keepdim: bool = False)

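Compute the L2 norm of each feature's encoder vector at every layer, giving shape (n_layers, d_sae), or (n_layers, 1, d_sae) with keepdim=True. No averaging across layers is performed.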

Source code in src/lm_saes/clt.py
@override
def encoder_norm(self, keepdim: bool = False):
    """Compute the L2 norm of each feature's encoder vector at every layer, of shape (n_layers, d_sae)."""
    if not isinstance(self.W_E, DTensor):
        return torch.norm(self.W_E, p=2, dim=1, keepdim=keepdim).to(self.cfg.device)
    else:
        assert self.device_mesh is not None
        return DTensor.from_local(
            torch.norm(self.W_E.to_local(), p=2, dim=1, keepdim=keepdim),
            device_mesh=self.device_mesh,
            placements=DimMap({"model": 1 if keepdim else 0}).placements(self.device_mesh),
        )

decoder_bias_norm

decoder_bias_norm()

Return a tensor of ones of shape (n_layers,), one entry per target layer.

Source code in src/lm_saes/clt.py
@override
def decoder_bias_norm(self):
    """Return a tensor of ones of shape (n_layers,), one entry per target layer."""
    return torch.ones(self.cfg.n_layers, device=self.cfg.device, dtype=self.cfg.dtype)

set_encoder_to_fixed_norm

set_encoder_to_fixed_norm(value: float)

Not supported for CLT: always raises NotImplementedError.

Source code in src/lm_saes/clt.py
@override
@torch.no_grad()
def set_encoder_to_fixed_norm(self, value: float):
    """Not supported for CLT; always raises NotImplementedError."""
    raise NotImplementedError("set_encoder_to_fixed_norm does not make sense for CLT")

keep_only_decoders_for_layer_from

keep_only_decoders_for_layer_from(layer_from: int)

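Cache the decoder weights from source layer layer_from to each target layer layer_from..n_layers-1 in self.decoders_for_layer_from, so that decoder_norm_per_feature(layer=layer_from) can reuse them. W_D itself is left unchanged.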

Source code in src/lm_saes/clt.py
@torch.no_grad()
def keep_only_decoders_for_layer_from(self, layer_from: int):
    """Cache the decoder weights from source layer layer_from to every target layer from layer_from onward."""
    new_W_D = []
    for layer_to, decoder_weights in enumerate(self.W_D):
        if layer_to >= layer_from:
            new_W_D.append(decoder_weights[layer_from])
    self.decoders_for_layer_from = (layer_from, new_W_D)
    torch.cuda.empty_cache()
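
The sketch below lists which (layer_from, layer_to) decoders keep_only_decoders_for_layer_from retains, assuming the layout used throughout this class: W_D[layer_to] holds one decoder per source layer 0..layer_to. Position i of the kept list corresponds to target layer layer_from + i, which is the offset decoder_norm_per_feature adds back. kept_decoder_pairs is a hypothetical helper, not part of lm_saes.

```python
from __future__ import annotations


def kept_decoder_pairs(layer_from: int, n_layers: int) -> list[tuple[int, int]]:
    """Hypothetical helper: the (layer_from, layer_to) decoders kept for one source layer."""
    pairs = []
    for layer_to in range(n_layers):
        # W_D[layer_to] only holds decoders from source layers 0..layer_to,
        # so a source layer reaches every target layer at or after itself.
        if layer_to >= layer_from:
            pairs.append((layer_from, layer_to))
    return pairs


pairs = kept_decoder_pairs(layer_from=1, n_layers=4)
print(pairs)  # [(1, 1), (1, 2), (1, 3)]
# Position i in the kept list corresponds to target layer layer_from + i.
assert all(layer_to == 1 + i for i, (_, layer_to) in enumerate(pairs))
```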

decoder_norm_per_feature

decoder_norm_per_feature(
    layer: int | None = None,
) -> Float[Tensor, "n_layers d_sae"]

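Compute the L2 norm of each feature's decoder weights. If layer is None, row i holds the norm of the feature's decoder vectors from source layer i, concatenated over every target layer at or after i. If layer is given, row j holds the norm of the decoder from layer into target layer j (rows before layer stay zero); weights cached by keep_only_decoders_for_layer_from are used when available.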

Source code in src/lm_saes/clt.py
@torch.no_grad()
def decoder_norm_per_feature(
    self,
    layer: int | None = None,
) -> Float[torch.Tensor, "n_layers d_sae"]:
    """
    Compute the norm of decoder weights for each feature.
    If layer is not None, only compute the norm for the decoder weights from layer to subsequent layers.
    """

    if self.device_mesh is None:
        decoder_norms = torch.zeros(
            self.cfg.n_layers,
            self.cfg.d_sae,
            dtype=self.cfg.dtype,
            device=self.cfg.device,
        )
    else:
        decoder_norms = torch.distributed.tensor.zeros(
            self.cfg.n_layers,
            self.cfg.d_sae,
            dtype=self.cfg.dtype,
            device_mesh=self.device_mesh,
            placements=self.dim_maps()["decoder_norms"].placements(self.device_mesh),
        )
    if layer is not None:
        if getattr(self, "decoders_for_layer_from", None) is not None:
            kept_layer_from, kept_decoders = getattr(self, "decoders_for_layer_from")
            assert kept_layer_from == layer
            for layer_to, decoder_weights in enumerate(kept_decoders):
                layer_to += layer
                decoder_norms[layer_to] = decoder_weights.pow(2).sum(dim=-1).sqrt()
        else:
            for layer_to, decoder_weights in enumerate(self.W_D[layer:]):
                layer_to += layer
                decoder_norms[layer_to] = decoder_weights[layer].pow(2).sum(dim=-1).sqrt()
    else:
        for layer_to, decoder_weights in enumerate(self.W_D):
            decoder_norms[: layer_to + 1] = decoder_norms[: layer_to + 1] + decoder_weights.pow(2).sum(dim=-1)
        decoder_norms = decoder_norms.sqrt()
    return decoder_norms
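
A plain-Python mirror of the two modes of decoder_norm_per_feature on a toy CLT, assuming W_D[layer_to][layer_from][feature] is one decoder vector (the indexing the code above uses). Note that the rows mean different things: without layer they are source layers, with layer they are target layers. The helper and TOY_W_D are hypothetical.

```python
from __future__ import annotations

import math


def sq_norm(v: list[float]) -> float:
    """Squared L2 norm of one decoder vector."""
    return sum(x * x for x in v)


def decoder_norm_per_feature(W_D: list, d_sae: int, layer: int | None = None) -> list[list[float]]:
    """Hypothetical mirror; W_D[layer_to][layer_from][feature] is a d_model-length vector."""
    n_layers = len(W_D)
    norms = [[0.0] * d_sae for _ in range(n_layers)]
    if layer is not None:
        # Row layer_to: norm of the decoder from `layer` into layer_to (rows before `layer` stay 0).
        for layer_to in range(layer, n_layers):
            for f in range(d_sae):
                norms[layer_to][f] = math.sqrt(sq_norm(W_D[layer_to][layer][f]))
        return norms
    # Row layer_from: accumulate squared norms over every target layer this source writes to.
    for layer_to in range(n_layers):
        for layer_from in range(layer_to + 1):
            for f in range(d_sae):
                norms[layer_from][f] += sq_norm(W_D[layer_to][layer_from][f])
    return [[math.sqrt(x) for x in row] for row in norms]


# Toy CLT with n_layers=2, d_sae=1, d_model=2.
TOY_W_D = [
    [[[3.0, 4.0]]],                 # decoder 0 -> 0
    [[[0.0, 12.0]], [[1.0, 0.0]]],  # decoders 0 -> 1 and 1 -> 1
]
print(decoder_norm_per_feature(TOY_W_D, d_sae=1))           # [[13.0], [1.0]]
print(decoder_norm_per_feature(TOY_W_D, d_sae=1, layer=0))  # [[5.0], [12.0]]
```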

decoder_norm_per_decoder

decoder_norm_per_decoder() -> Union[
    Float[Tensor, "n_decoders"], DTensor
]

Compute, for each decoder (layer_from -> layer_to), the L2 norm of every feature's decoder vector, averaged over features.

Returns: norms, a torch.Tensor or DTensor of shape (n_decoders,), where n_decoders = n_layers * (n_layers + 1) // 2.

Source code in src/lm_saes/clt.py
def decoder_norm_per_decoder(self) -> Union[Float[torch.Tensor, "n_decoders"], DTensor]:  # noqa: F821
    """Compute the L2 norm of decoder weights for each decoder (layer_from -> layer_to).
    Returns:
        norms: torch.Tensor or DTensor of shape (n_decoders,), where n_decoders = n_layers * (n_layers + 1) // 2
    """
    n_decoders: int = self.cfg.n_layers * (self.cfg.n_layers + 1) // 2
    if self.device_mesh is None:
        decoder_norms = torch.zeros(
            n_decoders,
            self.cfg.d_sae,
            dtype=self.cfg.dtype,
            device=self.cfg.device,
        )
    else:
        decoder_norms = torch.distributed.tensor.zeros(
            n_decoders,
            self.cfg.d_sae,
            dtype=self.cfg.dtype,
            device_mesh=self.device_mesh,
            placements=self.dim_maps()["decoder_norms"].placements(self.device_mesh),
        )
    idx = 0
    for layer_to, decoder_weights in enumerate(self.W_D):
        for layer_from in range(layer_to + 1):
            decoder_norms[idx] = decoder_weights[layer_from].pow(2).sum(dim=-1)
            idx += 1
    decoder_norms = decoder_norms.sqrt().mean(dim=-1)
    return decoder_norms
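
Because decoder_norm_per_decoder walks layer_to in the outer loop and layer_from in the inner loop, each flat index maps to exactly one (layer_from, layer_to) pair. The hypothetical helpers below enumerate that order and give a closed form for it, which can help when reading the returned (n_decoders,) tensor.

```python
from __future__ import annotations


def decoder_index_pairs(n_layers: int) -> list[tuple[int, int]]:
    """Hypothetical helper: (layer_from, layer_to) for each flat decoder index,
    in the order decoder_norm_per_decoder fills them."""
    pairs = []
    for layer_to in range(n_layers):
        for layer_from in range(layer_to + 1):
            pairs.append((layer_from, layer_to))
    return pairs


def flat_decoder_index(layer_from: int, layer_to: int) -> int:
    """Closed form of the same order: the blocks for target layers 0..layer_to-1
    hold 1 + 2 + ... + layer_to = layer_to * (layer_to + 1) // 2 decoders."""
    assert 0 <= layer_from <= layer_to
    return layer_to * (layer_to + 1) // 2 + layer_from


pairs = decoder_index_pairs(3)
print(pairs)  # [(0, 0), (0, 1), (1, 1), (0, 2), (1, 2), (2, 2)]
assert len(pairs) == 3 * (3 + 1) // 2
assert all(flat_decoder_index(lf, lt) == i for i, (lf, lt) in enumerate(pairs))
```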

standardize_parameters_of_dataset_norm

standardize_parameters_of_dataset_norm()

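Fold dataset-wise activation normalization into the parameters so the model can take unnormalized activations at inference time. The encoder bias and the JumpReLU threshold are divided by each input layer's normalization factor, the decoder weights and biases are rescaled by the input and output factors, and norm_activation is then set to "inference".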

Source code in src/lm_saes/clt.py
@override
@torch.no_grad()
def standardize_parameters_of_dataset_norm(self):
    """Standardize parameters for dataset-wise normalization during inference."""
    assert self.cfg.norm_activation == "dataset-wise"
    assert self.dataset_average_activation_norm is not None
    dataset_average_activation_norm = self.dataset_average_activation_norm

    def input_norm_factor(layer: int) -> float:
        return math.sqrt(self.cfg.d_model) / dataset_average_activation_norm[self.cfg.hook_points_in[layer]]

    def output_norm_factor(layer: int) -> float:
        return math.sqrt(self.cfg.d_model) / dataset_average_activation_norm[self.cfg.hook_points_out[layer]]

    # For CLT, we need to handle multiple input and output layers
    for layer_from in range(self.cfg.n_layers):
        # Adjust encoder bias for this layer
        self.b_E.data[layer_from].div_(input_norm_factor(layer_from))

        if self.cfg.act_fn.lower() == "jumprelu":
            assert isinstance(self.activation_function, JumpReLU)
            threshold = self.activation_function.log_jumprelu_threshold.data[layer_from].exp()
            threshold = threshold / input_norm_factor(layer_from)
            self.activation_function.log_jumprelu_threshold.data[layer_from] = torch.log(threshold)

    for layer_to in range(self.cfg.n_layers):
        self.b_D[layer_to].data.div_(output_norm_factor(layer_to))
        for layer_from in range(layer_to + 1):
            self.W_D[layer_to].data[layer_from].mul_(input_norm_factor(layer_from) / output_norm_factor(layer_to))

    self.cfg.norm_activation = "inference"
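
One way to check these rescalings is to fold the normalization into the parameters explicitly. Assuming each layer's pre-activation has the form W_E x + b_E, and using that for any c > 0 both ReLU and JumpReLU satisfy σ_θ(c z) = c σ_{θ/c}(z), a model trained on normalized activations x̂ and ŷ (with n̄ the dataset_average_activation_norm at the matching hook point) gives:

```latex
% c_in^{(i)} = input_norm_factor(i), c_out^{(j)} = output_norm_factor(j)
\begin{aligned}
&\hat{x}^{(i)} = c_{\mathrm{in}}^{(i)} x^{(i)}, \qquad
  \hat{y}^{(j)} = c_{\mathrm{out}}^{(j)} y^{(j)}, \qquad
  c = \frac{\sqrt{d_{\mathrm{model}}}}{\bar{n}} \\
&\sigma_{\theta}\!\left(W_E^{(i)} \hat{x}^{(i)} + b_E^{(i)}\right)
  = c_{\mathrm{in}}^{(i)}\, \sigma_{\theta / c_{\mathrm{in}}^{(i)}}\!\left(W_E^{(i)} x^{(i)} + \frac{b_E^{(i)}}{c_{\mathrm{in}}^{(i)}}\right)
  = c_{\mathrm{in}}^{(i)} f'^{(i)} \\
&y^{(j)} = \frac{1}{c_{\mathrm{out}}^{(j)}} \Big( \sum_{i \le j} W_D^{(j,i)} c_{\mathrm{in}}^{(i)} f'^{(i)} + b_D^{(j)} \Big)
  = \sum_{i \le j} \frac{c_{\mathrm{in}}^{(i)}}{c_{\mathrm{out}}^{(j)}} W_D^{(j,i)} f'^{(i)} + \frac{b_D^{(j)}}{c_{\mathrm{out}}^{(j)}}
\end{aligned}
```

Reading off the coefficients: b_E and the JumpReLU threshold are divided by c_in, W_D[layer_to][layer_from] is multiplied by c_in(layer_from) / c_out(layer_to), b_D is divided by c_out, and W_E is unchanged, which matches the in-place updates in the code above. The rescaled features f' equal the original features divided by c_in.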

prepare_input

prepare_input(
    batch: dict[str, Tensor], **kwargs
) -> tuple[Tensor, dict[str, Any], dict[str, Any]]

Prepare input tensor from batch by stacking all layer activations from hook_points_in.

Source code in src/lm_saes/clt.py
@override
def prepare_input(
    self, batch: "dict[str, torch.Tensor]", **kwargs
) -> "tuple[torch.Tensor, dict[str, Any], dict[str, Any]]":
    """Prepare input tensor from batch by stacking all layer activations from hook_points_in."""
    x_layers = []
    for hook_point in self.cfg.hook_points_in:
        if hook_point not in batch:
            raise ValueError(f"Missing hook point {hook_point} in batch")
        x_layers.append(batch[hook_point])
    # Workaround for a DTensor bug: ideally we would stack along dim=-2, but that raises an error on the sharded dim, so pass the equivalent non-negative index.
    x = torch.stack(x_layers, dim=x_layers[0].ndim - 1)  # (..., n_layers, d_model)

    encoder_kwargs = {}
    decoder_kwargs = {}
    return x, encoder_kwargs, decoder_kwargs

prepare_input_single_layer

prepare_input_single_layer(
    batch: dict[str, Tensor], layer: int, **kwargs
) -> tuple[Tensor, dict[str, Any], dict[str, Any]]

Prepare the input tensor for a single layer: return the activation at hook_points_in[layer] from the batch, with empty encoder and decoder kwargs.

Source code in src/lm_saes/clt.py
def prepare_input_single_layer(
    self, batch: "dict[str, torch.Tensor]", layer: int, **kwargs
) -> "tuple[torch.Tensor, dict[str, Any], dict[str, Any]]":
    """Prepare the input tensor for a single layer from the activation at hook_points_in[layer]."""
    hook_point_in = self.cfg.hook_points_in[layer]
    if hook_point_in not in batch:
        raise ValueError(f"Missing hook point {hook_point_in} in batch")
    x = batch[hook_point_in]
    return x, {}, {}

prepare_label

prepare_label(batch: dict[str, Tensor], **kwargs) -> Tensor

Prepare the label tensor by stacking the activations at hook_points_out along a new leading dimension, giving shape (n_layers, ..., d_model).

Source code in src/lm_saes/clt.py
@override
def prepare_label(self, batch: "dict[str, torch.Tensor]", **kwargs) -> torch.Tensor:
    """Prepare label tensor from batch using hook_points_out."""
    x_layers = []
    for hook_point in self.cfg.hook_points_out:
        if hook_point not in batch:
            raise ValueError(f"Missing hook point {hook_point} in batch")
        x_layers.append(batch[hook_point])
    labels = torch.stack(x_layers, dim=0)  # (n_layers, ..., d_model)
    return labels
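
prepare_input and prepare_label stack the per-hook activations along different axes, so inputs and labels end up with different layouts. The hypothetical helper below only computes the resulting shapes, assuming every hook point has the same shape (..., d_model); it also shows why dim = ndim - 1 is the same axis as dim=-2 of the stacked result.

```python
from __future__ import annotations


def clt_io_shapes(per_hook_shape: tuple[int, ...], n_layers: int) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Hypothetical helper: shapes produced by prepare_input and prepare_label."""
    ndim = len(per_hook_shape)
    # prepare_input stacks at dim ndim - 1, just before d_model. The stacked tensor has
    # ndim + 1 dims, so this is the same axis as dim=-2: (..., n_layers, d_model).
    input_shape = per_hook_shape[: ndim - 1] + (n_layers,) + per_hook_shape[ndim - 1 :]
    # prepare_label stacks at dim 0: (n_layers, ..., d_model).
    label_shape = (n_layers,) + per_hook_shape
    return input_shape, label_shape


print(clt_io_shapes((4, 128, 768), n_layers=12))
# ((4, 128, 12, 768), (12, 4, 128, 768))
```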

compute_training_metrics

compute_training_metrics(
    *,
    l0: Tensor,
    explained_variance_legacy: Tensor,
    **kwargs,
) -> dict[str, float]

Compute per-layer training metrics for CLT.

Source code in src/lm_saes/clt.py
@override
@torch.no_grad()
def compute_training_metrics(
    self,
    *,
    l0: torch.Tensor,
    explained_variance_legacy: torch.Tensor,
    **kwargs,
) -> dict[str, float]:
    """Compute per-layer training metrics for CLT."""
    assert explained_variance_legacy.ndim == 1 and len(explained_variance_legacy) == self.cfg.n_layers, (
        f"explained_variance_legacy should be of shape (n_layers,), but got {explained_variance_legacy.shape}"
    )
    clt_per_layer_ev_dict = {
        f"metrics/explained_variance_L{l}": item(explained_variance_legacy[l].mean())
        for l in range(explained_variance_legacy.size(0))
    }
    clt_per_layer_l0_dict = {f"metrics/l0_layer{l}": item(l0[:, l].mean()) for l in range(l0.size(1))}
    return {**clt_per_layer_ev_dict, **clt_per_layer_l0_dict}
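
A plain-Python mirror of the metric names compute_training_metrics produces, assuming l0 is 2-D with one column per layer (as the l0[:, l] indexing suggests). The helper and the numbers are hypothetical.

```python
from __future__ import annotations


def clt_training_metrics(explained_variance: list[float], l0: list[list[float]]) -> dict[str, float]:
    """Hypothetical mirror of compute_training_metrics.

    explained_variance: one value per layer.
    l0: assumed 2-D, one row per token with one L0 value per layer.
    """
    metrics = {f"metrics/explained_variance_L{layer}": ev for layer, ev in enumerate(explained_variance)}
    n_layers = len(l0[0])
    for layer in range(n_layers):
        # Mean over tokens of this layer's L0, like item(l0[:, l].mean()).
        metrics[f"metrics/l0_layer{layer}"] = sum(row[layer] for row in l0) / len(l0)
    return metrics


metrics = clt_training_metrics([0.9, 0.8], [[10.0, 20.0], [30.0, 40.0]])
print(metrics["metrics/l0_layer0"], metrics["metrics/l0_layer1"])  # 20.0 30.0
```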

compute_loss

compute_loss(
    batch: dict[str, Tensor],
    *,
    sparsity_loss_type: Literal[
        "power", "tanh", "tanh-quad", None
    ] = None,
    tanh_stretch_coefficient: float = 4.0,
    p: int = 1,
    l1_coefficient: float = 1.0,
    return_aux_data: Literal[True] = True,
    **kwargs,
) -> dict[str, Any]
compute_loss(
    batch: dict[str, Tensor],
    *,
    sparsity_loss_type: Literal[
        "power", "tanh", "tanh-quad", None
    ] = None,
    tanh_stretch_coefficient: float = 4.0,
    p: int = 1,
    l1_coefficient: float = 1.0,
    return_aux_data: Literal[False],
    **kwargs,
) -> Float[Tensor, " batch"]
compute_loss(
    batch: dict[str, Tensor],
    label: Optional[
        Union[
            Float[Tensor, "batch d_model"],
            Float[Tensor, "batch seq_len d_model"],
        ]
    ] = None,
    *,
    sparsity_loss_type: Literal[
        "power", "tanh", "tanh-quad", None
    ] = None,
    tanh_stretch_coefficient: float = 4.0,
    frequency_scale: float = 0.01,
    p: int = 1,
    l1_coefficient: float = 1.0,
    return_aux_data: bool = True,
    **kwargs,
) -> Union[Float[Tensor, " batch"], dict[str, Any]]

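Compute the loss for the autoencoder. Ensure that the input activations are normalized by calling normalize_activations before calling this method. The label argument is ignored: labels are always rebuilt from batch with prepare_label. The returned loss is the reconstruction loss plus, when sparsity_loss_type is set, the mean of l1_coefficient times the sparsity penalty.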

Source code in src/lm_saes/clt.py
@timer.time("compute_loss")
def compute_loss(
    self,
    batch: dict[str, torch.Tensor],
    label: (
        Optional[
            Union[
                Float[torch.Tensor, "batch d_model"],
                Float[torch.Tensor, "batch seq_len d_model"],
            ]
        ]
    ) = None,
    *,
    sparsity_loss_type: Literal["power", "tanh", "tanh-quad", None] = None,
    tanh_stretch_coefficient: float = 4.0,
    frequency_scale: float = 0.01,
    p: int = 1,
    l1_coefficient: float = 1.0,
    return_aux_data: bool = True,
    **kwargs,
) -> Union[
    Float[torch.Tensor, " batch"],
    dict[str, Any],
]:
    """Compute the loss for the autoencoder.
    Ensure that the input activations are normalized by calling `normalize_activations` before calling this method.
    """
    x, encoder_kwargs, decoder_kwargs = self.prepare_input(batch)
    label = self.prepare_label(batch, **kwargs)

    with timer.time("encode"):
        feature_acts = self.encode(x, **encoder_kwargs)

    with timer.time("decode"):
        reconstructed = self.decode(feature_acts, **decoder_kwargs)

    with timer.time("loss_calculation"):
        l_rec = (reconstructed - label).pow(2)
        l_rec = l_rec.sum(dim=-1).mean()
        if isinstance(l_rec, DTensor):
            l_rec: Tensor = l_rec.full_tensor()
        loss_dict: dict[str, Optional[torch.Tensor]] = {
            "l_rec": l_rec,
        }
        loss = l_rec

        if sparsity_loss_type is not None:
            decoder_norm: Union[Float[torch.Tensor, "n_layers d_sae"], DTensor] = self.decoder_norm_per_feature()
            with timer.time("sparsity_loss_calculation"):
                if sparsity_loss_type == "power":
                    l_s = torch.norm(feature_acts * decoder_norm, p=p, dim=-1)
                elif sparsity_loss_type == "tanh":
                    l_s = torch.tanh(tanh_stretch_coefficient * feature_acts * decoder_norm).sum(dim=-1)
                elif sparsity_loss_type == "tanh-quad":
                    approx_frequency = einops.reduce(
                        torch.tanh(tanh_stretch_coefficient * feature_acts * decoder_norm),
                        "... d_sae -> d_sae",
                        "mean",
                    )
                    l_s = (approx_frequency * (1 + approx_frequency / frequency_scale)).sum(dim=-1)
                else:
                    raise ValueError(f"sparsity_loss_type {sparsity_loss_type} not supported.")
                if isinstance(l_s, DTensor):
                    l_s = l_s.full_tensor()
                l_s = l1_coefficient * l_s
                # WARNING: because of DTensor bugs, computing l1_coefficient * l_s before full_tensor() caches the l1_coefficient value internally and makes the backward pass fail with a redistribution error. See https://github.com/pytorch/pytorch/issues/153603 and https://github.com/pytorch/pytorch/issues/153615 .
                loss_dict["l_s"] = l_s
                loss = loss + l_s.mean()
        else:
            loss_dict["l_s"] = None

    if return_aux_data:
        return {
            "loss": loss,
            **loss_dict,
            "label": label,
            "mask": batch.get("mask"),
            "n_tokens": batch["tokens"].numel() if batch.get("mask") is None else int(item(batch["mask"].sum())),
            "feature_acts": feature_acts,
            "reconstructed": reconstructed,
        }
    return loss
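
The three sparsity_loss_type options differ in how they penalize decoder-norm-weighted activations. The sketch below mirrors them for a single layer in plain Python, before multiplication by l1_coefficient; feature_acts and decoder_norm are toy lists and sparsity_penalty is a hypothetical name. "power" and "tanh" return one value per sample, while "tanh-quad" first averages over samples and returns a single scalar.

```python
from __future__ import annotations

import math


def sparsity_penalty(
    feature_acts: list[list[float]],
    decoder_norm: list[float],
    kind: str,
    p: int = 1,
    tanh_stretch_coefficient: float = 4.0,
    frequency_scale: float = 0.01,
) -> list[float] | float:
    """Hypothetical one-layer mirror of the sparsity terms in compute_loss.

    feature_acts: one list of d_sae activations per sample; decoder_norm: one norm per feature.
    """
    # Weight every activation by its feature's decoder norm.
    scaled = [[a * n for a, n in zip(row, decoder_norm)] for row in feature_acts]
    c = tanh_stretch_coefficient
    if kind == "power":
        # Per-sample L_p norm over features, like torch.norm(..., p=p, dim=-1).
        return [sum(abs(x) ** p for x in row) ** (1 / p) for row in scaled]
    if kind == "tanh":
        # Per-sample smooth count of active features.
        return [sum(math.tanh(c * x) for x in row) for row in scaled]
    if kind == "tanh-quad":
        # Approximate firing frequency of each feature, averaged over samples...
        freq = [sum(math.tanh(c * row[f]) for row in scaled) / len(scaled) for f in range(len(decoder_norm))]
        # ...then one scalar penalty that grows quadratically once freq exceeds frequency_scale.
        return sum(q * (1 + q / frequency_scale) for q in freq)
    raise ValueError(f"sparsity_loss_type {kind} not supported.")


acts = [[1.0, 0.0], [0.0, 2.0]]
norms = [3.0, 0.5]
print(sparsity_penalty(acts, norms, "power"))  # [3.0, 1.0]
```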

dim_maps

dim_maps() -> dict[str, DimMap]

Return dimension maps for distributed training along feature dimension.

Source code in src/lm_saes/clt.py
def dim_maps(self) -> "dict[str, DimMap]":
    """Return dimension maps for distributed training along feature dimension."""
    base_maps = super().dim_maps()

    clt_maps = {
        "W_E": DimMap({"model": 2}),  # Shard along d_sae dimension
        "b_E": DimMap({"model": 1}),  # Shard along d_sae dimension
        "W_D": DimMap({"model": 1}),  # Shard along d_sae dimension
        "b_D": DimMap({}),  # Replicate decoder biases
        "decoder_norms": DimMap({"model": 1}),  # Shard along d_sae dimension
    }

    return base_maps | clt_maps

LorsaConfig pydantic-model

Bases: BaseSAEConfig

Configuration for Low Rank Sparse Attention.

Fields:

  • device (str)
  • dtype (dtype)
  • d_model (int)
  • expansion_factor (float)
  • use_decoder_bias (bool)
  • act_fn (Literal['relu', 'jumprelu', 'topk', 'batchtopk', 'batchlayertopk', 'layertopk'])
  • norm_activation (Literal['token-wise', 'batch-wise', 'dataset-wise', 'inference'])
  • sparsity_include_decoder_norm (bool)
  • top_k (int)
  • use_triton_kernel (bool)
  • sparsity_threshold_for_triton_spmm_kernel (float)
  • jumprelu_threshold_window (float)
  • sae_type (str)
  • hook_point_in (str)
  • hook_point_out (str)
  • n_qk_heads (int)
  • d_qk_head (int)
  • positional_embedding_type (Literal['rotary', 'none'])
  • rotary_dim (int)
  • rotary_base (int)
  • rotary_adjacent_pairs (bool)
  • rotary_scale (int)
  • use_NTK_by_parts_rope (bool)
  • NTK_by_parts_factor (float)
  • NTK_by_parts_low_freq_factor (float)
  • NTK_by_parts_high_freq_factor (float)
  • old_context_len (int)
  • n_ctx (int)
  • attn_scale (float | None)
  • use_post_qk_ln (bool)
  • normalization_type (Literal['LN', 'RMS'] | None)
  • eps (float)

associated_hook_points property

associated_hook_points: list[str]

All hook points used by Lorsa.

LowRankSparseAttention

LowRankSparseAttention(
    cfg: LorsaConfig,
    device_mesh: Optional[DeviceMesh] = None,
)

Bases: AbstractSparseAutoEncoder

Source code in src/lm_saes/lorsa.py
def __init__(self, cfg: LorsaConfig, device_mesh: Optional[DeviceMesh] = None):
    super().__init__(cfg, device_mesh=device_mesh)
    self.cfg = cfg

    if device_mesh is None:
        # Local parameters
        def _get_param_with_shape(shape: tuple[int, ...]) -> nn.Parameter:
            return nn.Parameter(
                torch.empty(
                    shape,
                    dtype=self.cfg.dtype,
                    device=self.cfg.device,
                )
            )

        self.W_Q = _get_param_with_shape((self.cfg.n_qk_heads, self.cfg.d_model, self.cfg.d_qk_head))
        self.W_K = _get_param_with_shape((self.cfg.n_qk_heads, self.cfg.d_model, self.cfg.d_qk_head))
        self.W_V = _get_param_with_shape((self.cfg.n_ov_heads, self.cfg.d_model))
        self.W_O = _get_param_with_shape((self.cfg.n_ov_heads, self.cfg.d_model))
        self.b_Q = _get_param_with_shape((self.cfg.n_qk_heads, self.cfg.d_qk_head))
        self.b_K = _get_param_with_shape((self.cfg.n_qk_heads, self.cfg.d_qk_head))
        self.b_V = _get_param_with_shape((self.cfg.n_ov_heads,))
        if self.cfg.use_decoder_bias:
            self.b_D = _get_param_with_shape((self.cfg.d_model,))
    else:
        # Distributed parameters with head sharding
        dim_maps = self.dim_maps()

        def _get_param_with_shape(shape: tuple[int, ...], placements: Sequence[Any]) -> nn.Parameter:
            return nn.Parameter(
                torch.distributed.tensor.empty(
                    shape,
                    dtype=self.cfg.dtype,
                    device_mesh=device_mesh,
                    placements=placements,
                )
            )

        self.W_Q = _get_param_with_shape(
            (self.cfg.n_qk_heads, self.cfg.d_model, self.cfg.d_qk_head),
            placements=dim_maps["W_Q"].placements(device_mesh),
        )
        self.W_K = _get_param_with_shape(
            (self.cfg.n_qk_heads, self.cfg.d_model, self.cfg.d_qk_head),
            placements=dim_maps["W_K"].placements(device_mesh),
        )
        self.W_V = _get_param_with_shape(
            (self.cfg.n_ov_heads, self.cfg.d_model), placements=dim_maps["W_V"].placements(device_mesh)
        )
        self.W_O = _get_param_with_shape(
            (self.cfg.n_ov_heads, self.cfg.d_model), placements=dim_maps["W_O"].placements(device_mesh)
        )
        self.b_Q = _get_param_with_shape(
            (self.cfg.n_qk_heads, self.cfg.d_qk_head), placements=dim_maps["b_Q"].placements(device_mesh)
        )
        self.b_K = _get_param_with_shape(
            (self.cfg.n_qk_heads, self.cfg.d_qk_head), placements=dim_maps["b_K"].placements(device_mesh)
        )
        self.b_V = _get_param_with_shape((self.cfg.n_ov_heads,), placements=dim_maps["b_V"].placements(device_mesh))
        if self.cfg.use_decoder_bias:
            self.b_D = _get_param_with_shape(
                (self.cfg.d_model,), placements=dim_maps["b_D"].placements(device_mesh)
            )

    # Attention mask
    mask = torch.tril(
        torch.ones(
            (self.cfg.n_ctx, self.cfg.n_ctx),
            device=self.cfg.device,
            dtype=self.cfg.dtype,
        ).bool(),
    )
    if self.device_mesh is not None:
        mask = DimMap({}).distribute(mask, self.device_mesh)
    self.register_buffer("mask", mask)

    if self.device_mesh is not None:
        IGNORE = DimMap({}).distribute(torch.tensor(-torch.inf, device=self.cfg.device), self.device_mesh)
    else:
        IGNORE = torch.tensor(-torch.inf, device=self.cfg.device)
    self.register_buffer("IGNORE", IGNORE)

    if self.cfg.use_post_qk_ln:
        # if self.cfg.normalization_type == "LN":
        #     # TODO: fix this
        #     pass
        if self.cfg.normalization_type == "RMS":
            self.qk_ln_type = RMSNormPerHead
        else:
            raise ValueError(f"Invalid normalization type for QK-norm: {self.cfg.normalization_type}")
    else:
        self.qk_ln_type = None

    if self.cfg.use_post_qk_ln:
        assert self.qk_ln_type is not None
        self.ln_q = self.qk_ln_type(self.cfg, n_heads=self.cfg.n_qk_heads, device_mesh=device_mesh)
        self.ln_k = self.qk_ln_type(self.cfg, n_heads=self.cfg.n_qk_heads, device_mesh=device_mesh)

    if self.cfg.positional_embedding_type == "rotary":
        # Applies a rotation to each two-element chunk of keys and queries before the dot product, to bake in relative position.
        if self.cfg.rotary_dim is None:  # keep mypy happy
            raise ValueError("Rotary dim must be provided for rotary positional embeddings")
        sin, cos = self._calculate_sin_cos_rotary(
            self.cfg.rotary_dim,
            self.cfg.n_ctx,
            base=self.cfg.rotary_base,
            dtype=self.cfg.dtype,
            device=self.cfg.device,
        )
        if self.device_mesh is not None:
            sin = DimMap({}).distribute(sin, self.device_mesh)
            cos = DimMap({}).distribute(cos, self.device_mesh)
        self.register_buffer("rotary_sin", sin)
        self.register_buffer("rotary_cos", cos)
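
The mask buffer is lower-triangular (torch.tril), so query position i may attend only to key positions j <= i, and the IGNORE buffer holds -inf. The forward pass is not shown on this page; assuming it follows the common TransformerLens pattern of replacing masked scores with IGNORE before the softmax, the effect on one query row can be sketched in plain Python (helper names are hypothetical).

```python
from __future__ import annotations

import math


def causal_mask(n_ctx: int) -> list[list[bool]]:
    # Same pattern as torch.tril(torch.ones(n_ctx, n_ctx)).bool(): True where key j <= query i.
    return [[j <= i for j in range(n_ctx)] for i in range(n_ctx)]


def masked_softmax_row(scores: list[float], allowed: list[bool]) -> list[float]:
    IGNORE = -math.inf  # plays the role of the IGNORE buffer
    masked = [s if ok else IGNORE for s, ok in zip(scores, allowed)]
    # Subtract the max for numerical stability; it is finite because the diagonal is always allowed.
    m = max(masked)
    exps = [math.exp(s - m) for s in masked]  # exp(-inf) == 0.0, so future keys get zero weight
    total = sum(exps)
    return [e / total for e in exps]


mask = causal_mask(3)
print(masked_softmax_row([0.5, 2.0, 9.0], mask[0]))  # [1.0, 0.0, 0.0]
```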

init_parameters

init_parameters(**kwargs)

Initialize parameters.

Source code in src/lm_saes/lorsa.py
def init_parameters(self, **kwargs):
    """Initialize parameters."""
    super().init_parameters(**kwargs)

    torch.nn.init.xavier_uniform_(self.W_Q)
    torch.nn.init.xavier_uniform_(self.W_K)

    W_V_bound = 1 / math.sqrt(self.cfg.d_sae)
    # torch.nn.init.uniform_(self.W_V, -W_V_bound, W_V_bound)
    torch.nn.init.normal_(self.W_V, mean=0, std=W_V_bound)

    W_O_bound = 1 / math.sqrt(self.cfg.d_model)
    # torch.nn.init.uniform_(self.W_O, -W_O_bound, W_O_bound)
    torch.nn.init.normal_(self.W_O, mean=0, std=W_O_bound)

    torch.nn.init.zeros_(self.b_Q)
    torch.nn.init.zeros_(self.b_K)
    torch.nn.init.zeros_(self.b_V)
    if self.cfg.use_decoder_bias:
        torch.nn.init.zeros_(self.b_D)
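
init_parameters applies xavier_uniform_ to the 3-D tensors W_Q and W_K. PyTorch computes fans for such tensors as fan_in = size(1) * receptive and fan_out = size(0) * receptive, where receptive is the product of the trailing dimensions, so for shape (n_qk_heads, d_model, d_qk_head) the bound depends on n_qk_heads as well as d_model. The sketch below reproduces that arithmetic in plain Python with hypothetical sizes. Note also that W_V_bound and W_O_bound are used as standard deviations for normal_, not as uniform bounds.

```python
import math


def xavier_uniform_bound(shape: tuple, gain: float = 1.0) -> float:
    """Bound a of U(-a, a) used by torch.nn.init.xavier_uniform_ for a tensor of this shape."""
    receptive = math.prod(shape[2:])  # product of trailing dims (1 for 2-D tensors)
    fan_in = shape[1] * receptive
    fan_out = shape[0] * receptive
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


# Hypothetical sizes: n_qk_heads=8, d_model=512, d_qk_head=64, d_sae=4096.
print(xavier_uniform_bound((8, 512, 64)))  # sqrt(6 / (512*64 + 8*64)) = sqrt(6 / 33280)
print(1 / math.sqrt(4096))  # std of W_V: 0.015625
print(1 / math.sqrt(512))   # std of W_O
```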

init_lorsa_with_mhsa

init_lorsa_with_mhsa(
    mhsa: Attention | GroupedQueryAttention,
)

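Initialize Lorsa's query and key weights from the original multi-head self-attention (MHSA) module. Each original head is copied into n_qk_heads // n_heads consecutive Lorsa QK heads and divided by the input normalization factor; with post-QK RMSNorm enabled, the QK norm weights are copied as well.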

Source code in src/lm_saes/lorsa.py
@torch.no_grad()
def init_lorsa_with_mhsa(self, mhsa: Attention | GroupedQueryAttention):
    """Initialize Lorsa with the original multi-head self-attention (MHSA)."""
    assert self.cfg.n_qk_heads % mhsa.W_Q.size(0) == 0
    assert self.cfg.d_qk_head == mhsa.W_Q.size(2)
    assert self.dataset_average_activation_norm is not None
    input_norm_factor = math.sqrt(self.cfg.d_model) / self.dataset_average_activation_norm[self.cfg.hook_point_in]
    qk_exp_factor = self.cfg.n_qk_heads // mhsa.W_Q.size(0)
    if self.device_mesh is not None:
        model_parallel_rank = self.device_mesh.get_local_rank(mesh_dim="model")
        model_parallel_size = mesh_dim_size(self.device_mesh, "model")
        lorsa_qk_start_idx = model_parallel_rank * self.cfg.n_qk_heads // model_parallel_size
        lorsa_qk_end_idx = lorsa_qk_start_idx + self.cfg.n_qk_heads // model_parallel_size
        lorsa_qk_indices = torch.arange(lorsa_qk_start_idx, lorsa_qk_end_idx)
        W_Q_local = mhsa.W_Q[lorsa_qk_indices // qk_exp_factor] / input_norm_factor
        W_K_local = mhsa.W_K[lorsa_qk_indices // qk_exp_factor] / input_norm_factor
        W_Q = DTensor.from_local(
            W_Q_local,
            device_mesh=self.device_mesh,
            placements=self.dim_maps()["W_Q"].placements(self.device_mesh),
        )
        W_K = DTensor.from_local(
            W_K_local,
            device_mesh=self.device_mesh,
            placements=self.dim_maps()["W_K"].placements(self.device_mesh),
        )
        self.W_Q.copy_(W_Q)
        self.W_K.copy_(W_K)
        if self.cfg.use_post_qk_ln and self.cfg.normalization_type == "RMS":
            assert FORKED_TL, "Post-QK layer normalization requires the forked TransformerLens (lmsaes)."
            ln_q_w_local = mhsa.ln_q.w[lorsa_qk_indices // qk_exp_factor]  # type: ignore[attr-defined]
            if mhsa.cfg.n_key_value_heads is not None:
                ln_k_w_local = torch.repeat_interleave(
                    mhsa.ln_k.w,  # type: ignore[attr-defined]
                    mhsa.cfg.n_heads // mhsa.cfg.n_key_value_heads,
                    dim=0,
                )[lorsa_qk_indices // qk_exp_factor]
            else:
                ln_k_w_local = mhsa.ln_k.w[lorsa_qk_indices // qk_exp_factor]  # type: ignore[attr-defined]
            ln_q_w = DTensor.from_local(
                ln_q_w_local,
                device_mesh=self.device_mesh,
                placements=self.ln_q.dim_maps()["w"].placements(self.device_mesh),
            )
            ln_k_w = DTensor.from_local(
                ln_k_w_local,
                device_mesh=self.device_mesh,
                placements=self.ln_k.dim_maps()["w"].placements(self.device_mesh),
            )
            self.ln_q.w.copy_(ln_q_w)
            self.ln_k.w.copy_(ln_k_w)
    else:
        self.W_Q = nn.Parameter(
            torch.repeat_interleave(mhsa.W_Q, qk_exp_factor, dim=0).to(self.cfg.dtype) / input_norm_factor
        )
        self.W_K = nn.Parameter(
            torch.repeat_interleave(mhsa.W_K, qk_exp_factor, dim=0).to(self.cfg.dtype) / input_norm_factor
        )
        if self.cfg.use_post_qk_ln and self.cfg.normalization_type == "RMS":
            assert FORKED_TL, "Post-QK layer normalization requires the forked TransformerLens (lmsaes)."
            self.ln_q.w = nn.Parameter(
                torch.repeat_interleave(mhsa.ln_q.w, qk_exp_factor, dim=0).to(self.cfg.dtype)  # type: ignore[attr-defined]
            )
            self.ln_k.w = nn.Parameter(
                torch.repeat_interleave(mhsa.ln_k.w, self.ln_k.w.size(0) // mhsa.ln_k.w.size(0), dim=0).to(  # type: ignore[attr-defined]
                    self.cfg.dtype
                )
            )
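
Per the code above, Lorsa QK head idx copies original head idx // qk_exp_factor, and under model parallelism each rank fills a contiguous block of n_qk_heads // world_size heads. The hypothetical helper below reproduces that index arithmetic; world_size=1 matches the non-distributed repeat_interleave path. The division by input_norm_factor is separate: Lorsa sees activations scaled by c, and (c x)(W / c) = x W keeps the queries and keys unchanged.

```python
from __future__ import annotations


def lorsa_qk_source_heads(n_qk_heads: int, n_orig_heads: int, rank: int = 0, world_size: int = 1) -> list[int]:
    """Hypothetical helper: original MHSA head copied into each Lorsa QK head held by `rank`."""
    assert n_qk_heads % n_orig_heads == 0  # same check as init_lorsa_with_mhsa
    qk_exp_factor = n_qk_heads // n_orig_heads
    # Contiguous block of Lorsa heads owned by this model-parallel rank.
    start = rank * n_qk_heads // world_size
    end = start + n_qk_heads // world_size
    return [idx // qk_exp_factor for idx in range(start, end)]


print(lorsa_qk_source_heads(8, 2))                        # [0, 0, 0, 0, 1, 1, 1, 1]
print(lorsa_qk_source_heads(8, 2, rank=1, world_size=2))  # [1, 1, 1, 1]
```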

init_W_D_with_active_subspace_per_head

init_W_D_with_active_subspace_per_head(
    batch: dict[str, Tensor],
    mhsa: Attention | GroupedQueryAttention,
)

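Initialize W_O (the Lorsa decoder) from the active output subspace of each original head: the centered per-head outputs are decomposed with SVD, each head's group of W_O rows is replaced by combinations of the top d_qk_head singular directions (using the current values as coefficients), W_V is derived from W_O through that head's OV circuit, and both are normalized to unit row norm.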

Source code in src/lm_saes/lorsa.py
@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def init_W_D_with_active_subspace_per_head(
    self, batch: dict[str, torch.Tensor], mhsa: Attention | GroupedQueryAttention
):
    """
    Initialize W_D with the active subspace for each head.
    """
    x = self.prepare_input(batch)[0]
    if isinstance(x, DTensor):
        x = x.to_local()

    captured_z = None

    def capture_hook(tensor, hook):
        nonlocal captured_z
        captured_z = tensor.clone().detach()
        return tensor

    mhsa.hook_z.add_hook(capture_hook)
    _ = mhsa.forward(
        query_input=x,
        key_input=x,
        value_input=x,
    )
    output_per_head = torch.einsum("b s n h, n h d -> b s n d", captured_z, mhsa.W_O)
    n_ov_per_orig_head = self.cfg.n_ov_heads // mhsa.cfg.n_heads
    if self.device_mesh is not None:
        assert isinstance(self.W_O, DTensor)
        assert isinstance(self.W_V, DTensor)
        model_parallel_rank = self.device_mesh.get_local_rank(mesh_dim="model")
        model_parallel_size = mesh_dim_size(self.device_mesh, "model")
        orig_start_idx = model_parallel_rank * mhsa.cfg.n_heads // model_parallel_size
        orig_end_idx = orig_start_idx + mhsa.cfg.n_heads // model_parallel_size
        W_O_local = torch.empty_like(self.W_O.to_local())
        W_V_local = torch.empty_like(self.W_V.to_local())
        for orig_head_index in range(orig_start_idx, orig_end_idx):
            output = output_per_head[:, :, orig_head_index, :]
            output_flattened = output.flatten(0, 1)
            demeaned_output = output_flattened - output_flattened.mean(dim=0)
            U, S, V = torch.svd(demeaned_output.T.to(torch.float32))
            proj_weight = U[:, : self.cfg.d_qk_head]
            start_idx = (orig_head_index - orig_start_idx) * n_ov_per_orig_head
            end_idx = min(start_idx + n_ov_per_orig_head, W_O_local.size(0))
            W_O_local[start_idx:end_idx] = (
                self.W_O.to_local()[start_idx:end_idx, : self.cfg.d_qk_head] @ proj_weight.T
            )
            W_V_local[start_idx:end_idx] = (
                W_O_local[start_idx:end_idx] @ (mhsa.W_V[orig_head_index] @ mhsa.W_O[orig_head_index]).T
            )
        W_V_local = W_V_local / W_V_local.norm(dim=1, keepdim=True)
        W_O_local = W_O_local / W_O_local.norm(dim=1, keepdim=True)
        torch.distributed.broadcast(tensor=W_O_local, group=self.device_mesh.get_group("data"), group_src=0)
        torch.distributed.broadcast(tensor=W_V_local, group=self.device_mesh.get_group("data"), group_src=0)
        W_O_global = DTensor.from_local(
            W_O_local, device_mesh=self.device_mesh, placements=self.dim_maps()["W_O"].placements(self.device_mesh)
        )
        W_V_global = DTensor.from_local(
            W_V_local, device_mesh=self.device_mesh, placements=self.dim_maps()["W_V"].placements(self.device_mesh)
        )
        self.W_O.copy_(W_O_global)
        self.W_V.copy_(W_V_global)
    else:
        for orig_head_index in range(mhsa.cfg.n_heads):
            output = output_per_head[:, :, orig_head_index, :]
            output_flattened = output.flatten(0, 1)
            demeaned_output = output_flattened - output_flattened.mean(dim=0)
            U, S, V = torch.svd(demeaned_output.T.to(torch.float32))
            proj_weight = U[:, : self.cfg.d_qk_head]
            self.W_O[orig_head_index * n_ov_per_orig_head : (orig_head_index + 1) * n_ov_per_orig_head] = (
                self.W_O[
                    orig_head_index * n_ov_per_orig_head : (orig_head_index + 1) * n_ov_per_orig_head,
                    : self.cfg.d_qk_head,
                ]
                @ proj_weight.T
            )
            self.W_V[orig_head_index * n_ov_per_orig_head : (orig_head_index + 1) * n_ov_per_orig_head] = (
                self.W_O[orig_head_index * n_ov_per_orig_head : (orig_head_index + 1) * n_ov_per_orig_head]
                @ (mhsa.W_V[orig_head_index] @ mhsa.W_O[orig_head_index]).T
            )
        self.W_V.copy_(self.W_V / self.W_V.norm(dim=1, keepdim=True))
        self.W_O.copy_(self.W_O / self.W_O.norm(dim=1, keepdim=True))
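
The per-head step keeps the top d_qk_head left singular vectors of the centered head output, i.e. its principal directions in d_model space. As a dependency-free illustration of that idea, the sketch below swaps the SVD for power iteration and recovers only the top-1 direction; top_principal_direction is a hypothetical helper, not part of lm_saes, and the toy points are made up.

```python
from __future__ import annotations

import math


def top_principal_direction(points: list[list[float]], n_iter: int = 100) -> list[float]:
    n, d = len(points), len(points[0])
    # Center the data, as the original does with output_flattened.mean(dim=0).
    mean = [sum(p[k] for p in points) / n for k in range(d)]
    X = [[p[k] - mean[k] for k in range(d)] for p in points]
    # Unnormalized covariance C = X^T X; its top eigenvector is the top left singular
    # vector of X^T (the matrix the original passes to torch.svd).
    C = [[sum(row[a] * row[b] for row in X) for b in range(d)] for a in range(d)]
    v = [1.0] * d  # assumes the start vector is not orthogonal to the top eigenvector
    for _ in range(n_iter):
        w = [sum(C[a][b] * v[b] for b in range(d)) for a in range(d)]
        norm = math.sqrt(sum(x * x for x in w))
        if norm == 0.0:
            raise ValueError("points have no variance along the start direction")
        v = [x / norm for x in w]
    return v


print(top_principal_direction([[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]]))  # approx [1/sqrt(5), 2/sqrt(5)]
```

The original then takes a full set of d_qk_head directions at once, mixes them into each head's group of W_O rows, and normalizes the rows.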

init_W_V_with_active_subspace_per_head

init_W_V_with_active_subspace_per_head(
    batch: dict[str, Tensor],
    mhsa: Attention | GroupedQueryAttention,
)

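Initialize W_V from the active value subspace of each original head: the centered per-head values (mapped back to d_model through W_V) are decomposed with SVD, each head's group of W_V rows is replaced by combinations of the top d_qk_head singular directions (using the current values as coefficients), W_O is derived from W_V through that head's OV circuit, and both are normalized to unit row norm.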

Source code in src/lm_saes/lorsa.py
@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def init_W_V_with_active_subspace_per_head(
    self, batch: dict[str, torch.Tensor], mhsa: Attention | GroupedQueryAttention
):
    """
    Initialize W_V with the active subspace for each head.
    """
    x = self.prepare_input(batch)[0]
    if isinstance(x, DTensor):
        x = x.to_local()

    v_per_head = (
        x.reshape(-1, self.cfg.d_model) @ mhsa.W_V.permute(1, 0, 2).reshape(mhsa.cfg.d_model, mhsa.cfg.n_heads * mhsa.cfg.d_head)
    ).reshape(-1, mhsa.cfg.n_heads, mhsa.cfg.d_head)
    captured_v = torch.einsum("bnh,nhd->bnd", v_per_head, mhsa.W_V.permute(0, 2, 1))

    n_ov_per_orig_head = self.cfg.n_ov_heads // mhsa.cfg.n_heads
    if self.device_mesh is not None:
        assert isinstance(self.W_O, DTensor)
        assert isinstance(self.W_V, DTensor)
        model_parallel_rank = self.device_mesh.get_local_rank(mesh_dim="model")
        model_parallel_size = mesh_dim_size(self.device_mesh, "model")
        orig_start_idx = model_parallel_rank * mhsa.cfg.n_heads // model_parallel_size
        orig_end_idx = orig_start_idx + mhsa.cfg.n_heads // model_parallel_size
        W_O_local = torch.empty_like(self.W_O.to_local())
        W_V_local = torch.empty_like(self.W_V.to_local())
        for orig_head_index in range(orig_start_idx, orig_end_idx):
            v = captured_v[:, orig_head_index]
            demeaned_v = v - v.mean(dim=0)
            U, S, V = torch.svd(demeaned_v.T.to(torch.float32))
            proj_weight = U[:, : self.cfg.d_qk_head]
            start_idx = (orig_head_index - orig_start_idx) * n_ov_per_orig_head
            end_idx = min(start_idx + n_ov_per_orig_head, W_O_local.size(0))
            W_V_local[start_idx:end_idx] = (
                self.W_V.to_local()[start_idx:end_idx, : self.cfg.d_qk_head] @ proj_weight.T
            )
            W_O_local[start_idx:end_idx] = (
                W_V_local[start_idx:end_idx] @ mhsa.W_V[orig_head_index] @ mhsa.W_O[orig_head_index]
            )
        W_V_local = W_V_local / W_V_local.norm(dim=1, keepdim=True)
        W_O_local = W_O_local / W_O_local.norm(dim=1, keepdim=True)
        torch.distributed.broadcast(tensor=W_O_local, group=self.device_mesh.get_group("data"), group_src=0)
        torch.distributed.broadcast(tensor=W_V_local, group=self.device_mesh.get_group("data"), group_src=0)
        W_O_global = DTensor.from_local(
            W_O_local, device_mesh=self.device_mesh, placements=self.dim_maps()["W_O"].placements(self.device_mesh)
        )
        W_V_global = DTensor.from_local(
            W_V_local, device_mesh=self.device_mesh, placements=self.dim_maps()["W_V"].placements(self.device_mesh)
        )
        self.W_O.copy_(W_O_global)
        self.W_V.copy_(W_V_global)
    else:
        for orig_head_index in range(mhsa.cfg.n_heads):
            v = captured_v[:, orig_head_index]
            demeaned_v = v - v.mean(dim=0)
            U, S, V = torch.svd(demeaned_v.T.to(torch.float32))
            proj_weight = U[:, : self.cfg.d_qk_head]
            self.W_V[orig_head_index * n_ov_per_orig_head : (orig_head_index + 1) * n_ov_per_orig_head] = (
                self.W_V[
                    orig_head_index * n_ov_per_orig_head : (orig_head_index + 1) * n_ov_per_orig_head,
                    : self.cfg.d_qk_head,
                ]
                @ proj_weight.T
            )
            self.W_O[orig_head_index * n_ov_per_orig_head : (orig_head_index + 1) * n_ov_per_orig_head] = (
                self.W_V[orig_head_index * n_ov_per_orig_head : (orig_head_index + 1) * n_ov_per_orig_head]
                @ mhsa.W_V[orig_head_index]
                @ mhsa.W_O[orig_head_index]
            )
        self.W_V.copy_(self.W_V / self.W_V.norm(dim=1, keepdim=True))
        self.W_O.copy_(self.W_O / self.W_O.norm(dim=1, keepdim=True))
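
For each original head, the method above maps that head's value vectors back into d_model space, subtracts their mean, and keeps the top d_qk_head left singular vectors as the head's active subspace; the LoRSA V rows assigned to that head are projected into this subspace and then row-normalized. The NumPy sketch below reproduces only that per-head step on synthetic data. The helper `active_subspace_init` and its argument layout are hypothetical, not part of lm_saes, and W_O is not computed here.

```python
import numpy as np

def active_subspace_init(captured_v, w_v_init, d_sub):
    """Hypothetical helper (not an lm_saes API): project V rows onto the active subspace.

    captured_v: (n_samples, d_model) value vectors of one original head, in d_model space.
    w_v_init:   (n_new_heads, d_sub) random coefficients for the new V rows.
    """
    demeaned = captured_v - captured_v.mean(axis=0)
    # Left singular vectors of the (d_model, n_samples) matrix give the principal directions
    U, _, _ = np.linalg.svd(demeaned.T, full_matrices=False)
    basis = U[:, :d_sub]                                  # (d_model, d_sub), orthonormal columns
    w_v = w_v_init @ basis.T                              # (n_new_heads, d_model), rows in span(basis)
    w_v = w_v / np.linalg.norm(w_v, axis=1, keepdims=True)  # unit-norm rows
    return w_v, basis

rng = np.random.default_rng(0)
d_model, d_sub, n_samples = 16, 4, 200
# Synthetic activations that lie (up to small noise) in a 4-dimensional subspace
latent = rng.normal(size=(n_samples, d_sub))
mixing = rng.normal(size=(d_sub, d_model))
captured_v = latent @ mixing + 0.01 * rng.normal(size=(n_samples, d_model))

w_v, basis = active_subspace_init(captured_v, rng.normal(size=(3, d_sub)), d_sub)
```

The recovered basis spans (approximately) the rows of `mixing`, so every initialized V row reads from the directions the original head actually uses.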

encoder_norm

encoder_norm(keepdim: bool = False) -> Tensor

Per-head L2 norm of the encoder, which here is the value weight W_V (one row per OV head).

Source code in src/lm_saes/lorsa.py
@override
def encoder_norm(self, keepdim: bool = False) -> torch.Tensor:
    """Norm of encoder (V weights)."""
    if not isinstance(self.W_V, DTensor):
        return torch.norm(self.W_V, p=2, dim=1, keepdim=keepdim).to(self.cfg.device)
    else:
        assert self.device_mesh is not None
        return DTensor.from_local(
            torch.norm(self.W_V.to_local(), p=2, dim=1, keepdim=keepdim),
            device_mesh=self.device_mesh,
            placements=self.dim_maps()["W_V"].placements(self.device_mesh),
        )

decoder_norm

decoder_norm(keepdim: bool = False) -> Tensor

Norm of decoder (O weights).

Source code in src/lm_saes/lorsa.py
@override
def decoder_norm(self, keepdim: bool = False) -> torch.Tensor:
    """Norm of decoder (O weights)."""
    if not isinstance(self.W_O, DTensor):
        return torch.norm(self.W_O, p=2, dim=1, keepdim=keepdim).to(self.cfg.device)
    else:
        assert self.device_mesh is not None
        return DTensor.from_local(
            torch.norm(self.W_O.to_local(), p=2, dim=1, keepdim=keepdim),
            device_mesh=self.device_mesh,
            placements=self.dim_maps()["W_O"].placements(self.device_mesh),
        )

decoder_bias_norm

decoder_bias_norm() -> Tensor

Norm of decoder bias.

Source code in src/lm_saes/lorsa.py
@override
def decoder_bias_norm(self) -> torch.Tensor:
    """Norm of decoder bias."""
    if not self.cfg.use_decoder_bias:
        raise ValueError("Decoder bias not used")
    return torch.norm(self.b_D, p=2, dim=0, keepdim=True)

transform_to_unit_decoder_norm

transform_to_unit_decoder_norm()

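Rescale each row of W_O to unit norm, multiplying the matching rows of W_V and entries of b_V by the same per-head norm to compensate.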

Source code in src/lm_saes/lorsa.py
@override
@torch.no_grad()
def transform_to_unit_decoder_norm(self):
    """Transform to unit decoder norm."""
    norm = self.decoder_norm(keepdim=True)
    self.W_O /= norm
    self.W_V *= norm
    self.b_V *= norm.squeeze()
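
Moving each head's W_O norm into W_V and b_V leaves the output unchanged when the activation is positively homogeneous: the head's value, and hence its attended output, is scaled by the same positive factor that its W_O row loses. A NumPy check on a simplified single-sequence model, assuming ReLU and a hypothetical `toy_forward` that stands in for the real LoRSA forward pass:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def toy_forward(x, pattern, W_V, b_V, W_O):
    """Hypothetical, simplified LoRSA-style forward with one scalar value per head.

    x: (seq, d_model); pattern: (n_heads, seq, seq) attention weights;
    W_V: (n_heads, d_model); b_V: (n_heads,); W_O: (n_heads, d_model).
    """
    v = x @ W_V.T + b_V                              # (seq, n_heads) scalar value per head
    head_out = np.einsum("hqk,kh->qh", pattern, v)   # (seq, n_heads) attended values
    return relu(head_out) @ W_O                      # (seq, d_model)

def to_unit_decoder_norm(W_V, b_V, W_O):
    norm = np.linalg.norm(W_O, axis=1, keepdims=True)  # (n_heads, 1)
    return W_V * norm, b_V * norm.squeeze(-1), W_O / norm

rng = np.random.default_rng(0)
seq, d_model, n_heads = 5, 8, 3
x = rng.normal(size=(seq, d_model))
scores = rng.normal(size=(n_heads, seq, seq))
pattern = np.exp(scores) / np.exp(scores).sum(-1, keepdims=True)
W_V = rng.normal(size=(n_heads, d_model))
b_V = rng.normal(size=n_heads)
W_O = rng.normal(size=(n_heads, d_model))

before = toy_forward(x, pattern, W_V, b_V, W_O)
W_V2, b_V2, W_O2 = to_unit_decoder_norm(W_V, b_V, W_O)
after = toy_forward(x, pattern, W_V2, b_V2, W_O2)
```

With thresholded (jumprelu) or top-k activations the rescaling can change which heads pass, so exact equivalence is not expected there.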

standardize_parameters_of_dataset_norm

standardize_parameters_of_dataset_norm()

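Fold dataset-wise input and output normalization into the parameters so the model can run directly on unnormalized activations. W_Q, W_K and W_V are multiplied by the input factor sqrt(d_model) / average_norm(hook_point_in); W_O and b_D are divided by the output factor sqrt(d_model) / average_norm(hook_point_out). Afterwards cfg.norm_activation is set to "inference".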

Source code in src/lm_saes/lorsa.py
@override
@torch.no_grad()
def standardize_parameters_of_dataset_norm(self):
    """Standardize parameters for dataset norm."""
    assert self.cfg.norm_activation == "dataset-wise"
    assert self.dataset_average_activation_norm is not None

    hook_point_in = self.cfg.hook_point_in
    hook_point_out = self.cfg.hook_point_out

    input_norm_factor = math.sqrt(self.cfg.d_model) / self.dataset_average_activation_norm[hook_point_in]
    output_norm_factor = math.sqrt(self.cfg.d_model) / self.dataset_average_activation_norm[hook_point_out]

    self.W_Q.data *= input_norm_factor
    self.W_K.data *= input_norm_factor

    self.W_V.data *= input_norm_factor

    self.W_O.data = self.W_O.data / output_norm_factor
    self.b_D.data = self.b_D.data / output_norm_factor

    self.cfg.norm_activation = "inference"
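
The folding works because scaling an input by a constant before a linear map equals scaling the weight, while biases added after the map are untouched; on the output side, dividing the prediction by the output factor equals dividing W_O and b_D. A NumPy check with made-up average norms and a simplified value-ReLU-decoder readout (hypothetical, not the full LoRSA forward), plus the Q·K scores:

```python
import math
import numpy as np

rng = np.random.default_rng(0)
d_model, n_heads, d_qk = 8, 3, 4
avg_norm_in, avg_norm_out = 12.0, 30.0        # made-up dataset-average activation norms
s_in = math.sqrt(d_model) / avg_norm_in       # input_norm_factor
s_out = math.sqrt(d_model) / avg_norm_out     # output_norm_factor

W_V, b_V = rng.normal(size=(n_heads, d_model)), rng.normal(size=n_heads)
W_O, b_D = rng.normal(size=(n_heads, d_model)), rng.normal(size=d_model)
W_Q, W_K = rng.normal(size=(d_model, d_qk)), rng.normal(size=(d_model, d_qk))
x = rng.normal(size=(4, d_model))             # raw (unnormalized) activations

def toy(x, W_V, b_V, W_O, b_D):
    # Simplified per-position readout: value projection, ReLU, decoder
    return np.maximum(x @ W_V.T + b_V, 0.0) @ W_O + b_D

# Training-time view: normalize the input, then undo the output normalization
y_train_view = toy(x * s_in, W_V, b_V, W_O, b_D) / s_out
# Inference-time view: fold the factors into the weights and feed raw activations
y_folded = toy(x, W_V * s_in, b_V, W_O / s_out, b_D / s_out)

# Attention scores: scaling both W_Q and W_K by s_in reproduces the normalized-input scores
scores_train = (x * s_in) @ W_Q @ ((x * s_in) @ W_K).T
scores_folded = x @ (W_Q * s_in) @ (x @ (W_K * s_in)).T
```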

compute_hidden_pre

compute_hidden_pre(
    x: Float[Tensor, "batch seq_len d_model"],
) -> Float[Tensor, "batch seq_len d_sae"]

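Compute the hidden pre-activations: the per-OV-head outputs of causal scaled dot-product attention, before the activation function is applied.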

Source code in src/lm_saes/lorsa.py
def compute_hidden_pre(
    self, x: Float[torch.Tensor, "batch seq_len d_model"]
) -> Float[torch.Tensor, "batch seq_len d_sae"]:
    """Compute the hidden pre-activations."""
    q, k, v = self._compute_qkv(x)
    query = q.permute(0, 2, 1, 3)
    key = k.permute(0, 2, 1, 3)
    value = v.reshape(*k.shape[:3], -1).permute(0, 2, 1, 3)
    with sdpa_kernel(
        backends=[
            SDPBackend.FLASH_ATTENTION,
            SDPBackend.CUDNN_ATTENTION,
            SDPBackend.EFFICIENT_ATTENTION,
            SDPBackend.MATH,
        ]
    ):
        z = F.scaled_dot_product_attention(
            query, key, value, scale=1 / self.attn_scale, is_causal=True, enable_gqa=True
        )
    return z.permute(0, 2, 1, 3).reshape(*v.shape)
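
compute_hidden_pre reshapes the (batch, seq_len, n_ov_heads) value tensor into (batch, seq_len, n_qk_heads, n_ov_per_qk), so each OV head's scalar value becomes one value channel of the QK head whose attention pattern it shares, and one attention call covers all of them. The NumPy sketch below checks that this grouped computation matches attending separately for each OV head. It assumes contiguous grouping (OV head o uses QK head o // n_ov_per_qk, which is what the reshape implies) and a scale of 1/sqrt(d_qk_head); the real code passes scale=1/self.attn_scale, whose value is not shown here.

```python
import numpy as np

def causal_softmax_attention(q, k, v, scale):
    """q, k: (seq, d_head); v: (seq, d_v). Returns (seq, d_v)."""
    scores = (q @ k.T) * scale
    seq = scores.shape[0]
    future = np.triu(np.ones((seq, seq), dtype=bool), k=1)   # True for key positions after the query
    scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
seq, n_qk_heads, d_qk_head, n_ov_per_qk = 6, 2, 4, 3
n_ov_heads = n_qk_heads * n_ov_per_qk
q = rng.normal(size=(seq, n_qk_heads, d_qk_head))
k = rng.normal(size=(seq, n_qk_heads, d_qk_head))
v = rng.normal(size=(seq, n_ov_heads))          # one scalar value per OV head
scale = 1.0 / np.sqrt(d_qk_head)                # assumed value of 1/attn_scale

# Grouped: OV heads that share a QK head become that head's value channels
v_grouped = v.reshape(seq, n_qk_heads, n_ov_per_qk)
z = np.stack([causal_softmax_attention(q[:, h], k[:, h], v_grouped[:, h], scale)
              for h in range(n_qk_heads)], axis=1)   # (seq, n_qk_heads, n_ov_per_qk)
hidden_pre_grouped = z.reshape(seq, n_ov_heads)

# Reference: attend separately for every OV head, using the QK head it belongs to
hidden_pre_loop = np.stack([
    causal_softmax_attention(q[:, o // n_ov_per_qk], k[:, o // n_ov_per_qk], v[:, o:o + 1], scale)[:, 0]
    for o in range(n_ov_heads)], axis=1)
```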

encode

encode(
    x: Float[Tensor, "batch seq_len d_model"],
    return_hidden_pre: Literal[False] = False,
    **kwargs,
) -> Float[Tensor, "batch seq_len d_sae"]
encode(
    x: Float[Tensor, "batch seq_len d_model"],
    return_hidden_pre: Literal[True],
    **kwargs,
) -> Tuple[
    Float[Tensor, "batch seq_len d_sae"],
    Float[Tensor, "batch seq_len d_sae"],
]
encode(
    x: Float[Tensor, "batch seq_len d_model"],
    return_hidden_pre: bool = False,
    return_attention_pattern: bool = False,
    return_attention_score: bool = False,
    **kwargs,
) -> Union[
    Float[Tensor, "batch seq_len d_sae"],
    Tuple[
        Float[Tensor, "batch seq_len d_sae"],
        Float[Tensor, "batch seq_len d_sae"],
    ],
    Tuple[
        Float[Tensor, "batch seq_len d_sae"],
        Float[Tensor, "batch seq_len d_sae"],
        Float[Tensor, "batch n_qk_heads q_pos k_pos"],
    ],
]

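Encode to sparse head activations. Returns feature_acts alone when no extra outputs are requested; otherwise returns a tuple of feature_acts followed, in this order, by hidden_pre (if return_hidden_pre), the attention pattern (if return_attention_pattern), and the causally masked, pre-softmax attention scores (if return_attention_score), the last two with shape (batch, n_qk_heads, q_pos, k_pos). Requesting the pattern or scores switches from the fused scaled_dot_product_attention kernel to an explicit softmax computation.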

Source code in src/lm_saes/lorsa.py
@override
def encode(
    self,
    x: Float[torch.Tensor, "batch seq_len d_model"],
    return_hidden_pre: bool = False,
    return_attention_pattern: bool = False,
    return_attention_score: bool = False,
    **kwargs,
) -> Union[
    Float[torch.Tensor, "batch seq_len d_sae"],
    Tuple[
        Float[torch.Tensor, "batch seq_len d_sae"],
        Float[torch.Tensor, "batch seq_len d_sae"],
    ],
    Tuple[
        Float[torch.Tensor, "batch seq_len d_sae"],
        Float[torch.Tensor, "batch seq_len d_sae"],
        Float[torch.Tensor, "batch n_qk_heads q_pos k_pos"],
    ],
]:
    """Encode to sparse head activations."""
    # Compute Q, K, V
    q, k, v = self._compute_qkv(x)

    pattern: Optional[torch.Tensor] = None
    scores: Optional[torch.Tensor] = None

    if not (return_attention_pattern or return_attention_score):
        query = q.permute(0, 2, 1, 3)
        key = k.permute(0, 2, 1, 3)
        value = v.reshape(*k.shape[:3], -1).permute(0, 2, 1, 3)
        with sdpa_kernel(
            backends=[
                SDPBackend.FLASH_ATTENTION,
                SDPBackend.CUDNN_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION,
                SDPBackend.MATH,
            ]
        ):
            z = F.scaled_dot_product_attention(
                query, key, value, scale=1 / self.attn_scale, is_causal=True, enable_gqa=True
            )
        hidden_pre = z.permute(0, 2, 1, 3).reshape(*v.shape)
    else:
        # Attention pattern
        # n_qk_heads batch q_pos k_pos
        q = q.permute(2, 0, 1, 3)  # (n_qk_heads, batch, seq_len, d_qk_head)
        k = k.permute(2, 0, 3, 1)  # (n_qk_heads, batch, d_qk_head, seq_len)
        scores = torch.einsum("nbqd,nbdk->nbqk", q, k) / self.attn_scale
        scores = self._apply_causal_mask(scores)
        pattern = F.softmax(scores, dim=-1)

        # Head outputs
        hidden_pre = self._compute_head_outputs(pattern, v)

    # Scale feature activations by decoder norm if configured
    if self.cfg.sparsity_include_decoder_norm:
        hidden_pre = hidden_pre * self.decoder_norm()

    feature_acts = self.activation_function(hidden_pre)

    if self.cfg.sparsity_include_decoder_norm:
        feature_acts = feature_acts / self.decoder_norm()
        hidden_pre = hidden_pre / self.decoder_norm()

    return_values: list[torch.Tensor] = [feature_acts]
    if return_hidden_pre:
        return_values.append(hidden_pre)
    if return_attention_pattern and pattern is not None:
        return_values.append(pattern.permute(1, 0, 2, 3))
    if return_attention_score and scores is not None:
        return_values.append(scores.permute(1, 0, 2, 3))
    return tuple(return_values) if len(return_values) > 1 else return_values[0]  # type: ignore[return-value]
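
With sparsity_include_decoder_norm enabled, encode multiplies hidden_pre by the per-head decoder norm before the activation and divides it back out afterwards. For ReLU this changes nothing, but for top-k or thresholded activations the selection then depends on each head's contribution to the output rather than on the raw pre-activation alone. A small NumPy illustration; the `topk` and `activate_with_decoder_norm` helpers are hypothetical stand-ins, not lm_saes functions:

```python
import numpy as np

def topk(x, k):
    """Hypothetical helper: keep the k largest entries in the last dimension, zero the rest."""
    idx = np.argsort(x, axis=-1)[..., -k:]
    out = np.zeros_like(x)
    np.put_along_axis(out, idx, np.take_along_axis(x, idx, axis=-1), axis=-1)
    return out

def activate_with_decoder_norm(hidden_pre, decoder_norm, act_fn):
    # Mirrors the sparsity_include_decoder_norm branch: select on hidden_pre * ||W_O||,
    # then divide the norm back out of the returned activations
    return act_fn(hidden_pre * decoder_norm) / decoder_norm

hidden_pre = np.array([[1.0, 0.9, 0.2]])
decoder_norm = np.array([0.5, 2.0, 1.0])

plain = topk(hidden_pre, 1)                                                    # keeps head 0
norm_aware = activate_with_decoder_norm(hidden_pre, decoder_norm, lambda h: topk(h, 1))  # keeps head 1
```

Head 1 has a slightly smaller pre-activation but a four times larger decoder norm, so it wins once the norm is included.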

decode

decode(feature_acts, **kwargs)

Decode head activations to output.

Source code in src/lm_saes/lorsa.py
@override
def decode(self, feature_acts, **kwargs):
    """Decode head activations to output."""
    if feature_acts.layout == torch.sparse_coo:
        return (
            torch.sparse.mm(
                feature_acts.to(torch.float32),
                self.W_O.to(torch.float32),
            ).to(self.cfg.dtype)
            + self.b_D
        )
    out = torch.einsum("bps,sd->bpd", feature_acts, self.W_O)
    if self.cfg.use_decoder_bias:
        out = out + self.b_D
    if isinstance(out, DTensor):
        out = DimMap({"data": 0}).redistribute(out)
    return out
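
On the dense path, decode contracts the head dimension of feature_acts against W_O and adds b_D when use_decoder_bias is set, so each position's output is b_D plus the W_O rows of its active heads, weighted by their activations. A NumPy illustration with random data (shapes chosen arbitrarily):

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq, d_sae, d_model = 2, 3, 5, 4
W_O = rng.normal(size=(d_sae, d_model))
b_D = rng.normal(size=d_model)
feature_acts = np.maximum(rng.normal(size=(batch, seq, d_sae)), 0.0)  # sparse-ish activations

# Dense path, as in decode: contract the head dimension s against W_O, then add the bias
out = np.einsum("bps,sd->bpd", feature_acts, W_O) + b_D

# Equivalent per-position view: b_D plus the W_O rows of the active heads, weighted
bi, pi = 1, 2
active = np.nonzero(feature_acts[bi, pi])[0]
manual = b_D + sum(feature_acts[bi, pi, s] * W_O[s] for s in active)
```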

set_decoder_to_fixed_norm

set_decoder_to_fixed_norm(value: float, force_exact: bool)

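Rescale W_O toward a fixed decoder norm. With force_exact=True, W_O is scaled so that the mean per-head decoder norm equals value; with force_exact=False, it is only scaled down, and only when that mean exceeds value. Individual heads keep their relative norms.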

Source code in src/lm_saes/lorsa.py
@override
@torch.no_grad()
def set_decoder_to_fixed_norm(self, value: float, force_exact: bool):
    """Set decoder weights to a fixed norm."""
    if force_exact:
        self.W_O.mul_(value / self.decoder_norm(keepdim=True).mean())
    else:
        self.W_O.mul_(value / torch.clamp(self.decoder_norm(keepdim=True).mean(), min=value))
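
The two modes differ in whether small decoders are scaled up. A NumPy mirror of the method, written as a hypothetical stand-alone function rather than the lm_saes method itself, with hand-checkable numbers (rows of norm 5 and 1, mean norm 3):

```python
import numpy as np

def set_decoder_to_fixed_norm(W_O, value, force_exact):
    """Hypothetical NumPy mirror of the method above; W_O has one row per head."""
    mean_norm = np.linalg.norm(W_O, axis=1).mean()
    if force_exact:
        return W_O * (value / mean_norm)
    # value / max(mean_norm, value) <= 1, so this branch can only shrink W_O
    return W_O * (value / max(mean_norm, value))

W_O = np.array([[3.0, 4.0], [0.0, 1.0]])   # row norms 5 and 1, mean 3
exact = set_decoder_to_fixed_norm(W_O, 1.5, force_exact=True)
capped_small = set_decoder_to_fixed_norm(W_O, 6.0, force_exact=False)  # mean 3 < 6: unchanged
capped_large = set_decoder_to_fixed_norm(W_O, 1.5, force_exact=False)  # mean 3 > 1.5: halved
```

Only the mean is pinned: after the exact rescale the rows have norms 2.5 and 0.5, not 1.5 each.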

set_encoder_to_fixed_norm

set_encoder_to_fixed_norm(value: float)

Not supported for LoRSA: always raises NotImplementedError.

Source code in src/lm_saes/lorsa.py
@override
@torch.no_grad()
def set_encoder_to_fixed_norm(self, value: float):
    """Set encoder weights to fixed norm."""
    raise NotImplementedError("set_encoder_to_fixed_norm does not make sense for lorsa")

dim_maps

dim_maps() -> dict[str, DimMap]

Return a dictionary mapping parameter names to dimension maps.

Returns:

  • dict[str, DimMap]: A dictionary mapping parameter names to DimMap objects.

Source code in src/lm_saes/lorsa.py
@override
def dim_maps(self) -> dict[str, DimMap]:
    """Return a dictionary mapping parameter names to dimension maps.

    Returns:
        A dictionary mapping parameter names to DimMap objects.
    """
    base_maps = super().dim_maps()
    return {
        **base_maps,
        "W_Q": DimMap({"model": 0}),
        "W_K": DimMap({"model": 0}),
        "W_V": DimMap({"model": 0}),
        "W_O": DimMap({"model": 0}),
        "b_Q": DimMap({"model": 0}),
        "b_K": DimMap({"model": 0}),
        "b_V": DimMap({"model": 0}),
        "b_D": DimMap({}),
    }

prepare_input

prepare_input(
    batch: dict[str, Tensor], **kwargs
) -> tuple[Tensor, dict[str, Any], dict[str, Any]]

Prepare the input tensor: return the activations at cfg.hook_point_in from the batch, together with two empty dicts.

Source code in src/lm_saes/lorsa.py
@override
def prepare_input(
    self, batch: dict[str, torch.Tensor], **kwargs
) -> tuple[torch.Tensor, dict[str, Any], dict[str, Any]]:
    """Prepare input tensor."""
    x = batch[self.cfg.hook_point_in]
    return x, {}, {}

prepare_label

prepare_label(batch: dict[str, Tensor], **kwargs)

Prepare the label tensor: return the activations at cfg.hook_point_out from the batch.

Source code in src/lm_saes/lorsa.py
@override
def prepare_label(self, batch: dict[str, torch.Tensor], **kwargs):
    """Prepare label tensor."""
    label = batch[self.cfg.hook_point_out]
    return label

MOLTConfig pydantic-model

Bases: BaseSAEConfig

Configuration for Mixture of Linear Transforms (MOLT).

MOLT is a more efficient alternative to transcoders that sparsely replaces MLP computation in transformers. It converts dense MLP layers into sparse, interpretable linear transforms.

Config:

  • arbitrary_types_allowed: True

Fields:

hook_point_in pydantic-field

hook_point_in: str

Hook point to capture input activations from.

hook_point_out pydantic-field

hook_point_out: str

Hook point to capture the target (output) activations from.

rank_counts pydantic-field

rank_counts: dict[int, int]

Dictionary mapping each rank to the number of transforms with that rank. For example, {4: 128, 8: 256, 16: 128} means 128 transforms of rank 4, 256 transforms of rank 8, and 128 transforms of rank 16.

d_sae property

d_sae: int

Total number of linear transforms, i.e. the sum of the counts in rank_counts.

available_ranks property

available_ranks: list[int]

Get sorted list of available ranks.

num_rank_types property

num_rank_types: int

Number of different rank types.

generate_rank_assignments

generate_rank_assignments() -> list[int]

Generate rank assignment for each of the d_sae linear transforms.

Returns:

  • list[int]: List of rank assignments for each transform, in ascending rank order. For example, rank_counts {1: 4, 2: 2, 4: 1} gives [1, 1, 1, 1, 2, 2, 4].

Source code in src/lm_saes/molt.py
def generate_rank_assignments(self) -> list[int]:
    """Generate rank assignment for each of the d_sae linear transforms.

    Returns:
        List of rank assignments for each transform.
        For example: [1, 1, 1, 1, 2, 2, 4].
    """
    assignments = []
    for rank in sorted(self.rank_counts.keys()):
        assignments.extend([rank] * self.rank_counts[rank])
    return assignments
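
A pure-Python stand-in for the rank bookkeeping in MOLTConfig. The class `RankCounts` is hypothetical; computing d_sae as the sum of the counts and num_rank_types as the number of distinct ranks are assumptions read off the property descriptions above, while generate_rank_assignments follows the source.

```python
from dataclasses import dataclass

@dataclass
class RankCounts:
    """Hypothetical stand-in for the rank-related parts of MOLTConfig."""
    rank_counts: dict[int, int]

    @property
    def d_sae(self) -> int:
        # Assumption: one feature per linear transform
        return sum(self.rank_counts.values())

    @property
    def available_ranks(self) -> list[int]:
        return sorted(self.rank_counts)

    @property
    def num_rank_types(self) -> int:
        return len(self.rank_counts)

    def generate_rank_assignments(self) -> list[int]:
        # Same loop as the source: ranks in ascending order, each repeated by its count
        assignments: list[int] = []
        for rank in sorted(self.rank_counts):
            assignments.extend([rank] * self.rank_counts[rank])
        return assignments

cfg = RankCounts({4: 1, 1: 4, 2: 2})   # insertion order does not matter; ranks are sorted
```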

get_local_rank_assignments

get_local_rank_assignments(
    model_parallel_size: int,
) -> list[int]

Get the rank assignments for the local device in a distributed run.

Each device receives every rank group, with each group divided evenly across devices. This keeps encoder and decoder sharding consistent, so feature_acts need not be redistributed between them. Every count in rank_counts must be divisible by model_parallel_size.

Parameters:

  • model_parallel_size (int, required): Number of model-parallel devices used for training and inference.

Returns:

  • list[int]: List of rank assignments for this local device. For example, global_rank_assignments = [1, 1, 2, 2] with model_parallel_size = 2 gives local_rank_assignments = [1, 2].

Source code in src/lm_saes/molt.py
def get_local_rank_assignments(self, model_parallel_size: int) -> list[int]:
    """Get rank assignments for a specific local device in distributed running.

    Each device gets all rank groups, with each group evenly divided across devices.
    This ensures consistent encoder/decoder sharding without feature_acts redistribution.

    Args:
        model_parallel_size: Number of model parallel devices for training and inference.

    Returns:
        List of rank assignments for this local device
        For example:
        global_rank_assignments = [1, 1, 2, 2], model_parallel_size = 2 -> local_rank_assignments = [1, 2]
    """
    local_assignments = []
    for rank in sorted(self.rank_counts.keys()):
        global_count = self.rank_counts[rank]

        # Verify even division
        assert global_count % model_parallel_size == 0, (
            f"Transform rank {rank} global count {global_count} not divisible by "
            f"model_parallel_size {model_parallel_size}"
        )

        local_count = global_count // model_parallel_size
        local_assignments.extend([rank] * local_count)

    return local_assignments
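
The same even split as a plain function, applied to the docstring example and to the rank_counts example from the field description (the device count of 4 is arbitrary). This sketch raises ValueError where the source uses assert, so the check survives `python -O`; that is a choice of the sketch, not of lm_saes.

```python
def get_local_rank_assignments(rank_counts: dict[int, int], model_parallel_size: int) -> list[int]:
    """Sketch of the per-device split: every rank group is divided evenly across devices."""
    local: list[int] = []
    for rank in sorted(rank_counts):
        count = rank_counts[rank]
        if count % model_parallel_size != 0:
            raise ValueError(f"rank {rank}: count {count} not divisible by {model_parallel_size}")
        local.extend([rank] * (count // model_parallel_size))
    return local

local = get_local_rank_assignments({1: 2, 2: 2}, model_parallel_size=2)
big = get_local_rank_assignments({4: 128, 8: 256, 16: 128}, model_parallel_size=4)
```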

MixtureOfLinearTransform

MixtureOfLinearTransform(
    cfg: MOLTConfig, device_mesh: DeviceMesh | None = None
)

Bases: AbstractSparseAutoEncoder

Mixture of Linear Transforms (MOLT) model.

MOLT is a sparse autoencoder variant that uses d_sae linear transforms, each factorized as a low-rank product $U_i V_i$ with its own rank.

Mathematical formulation:

  • Encoder: $f_i = \phi(e_i \cdot x - b_i)$, where $\phi$ is the activation function
  • Decoder: $\sum_i f_i \, U_i V_i x$, where $f_i$ are the feature activations
  • Decoder norm: $\lVert U_i V_i \rVert_F$ for each transform $i$

The rank of each transform is determined by the rank_counts configuration, allowing for adaptive model capacity allocation.
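
The formulation above can be sketched in NumPy, grouping transforms by rank the way the parameters are stored (U as (count, d_model, rank), V as (count, rank, d_model)). The sketch assumes ReLU for $\phi$, takes the minus sign on the encoder bias from the formula, omits the optional decoder bias b_D, and uses made-up sizes; the result is checked against materializing every dense $U_i V_i$.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, d_model = 3, 6
rank_counts = {1: 2, 2: 1}               # made-up: two rank-1 transforms and one rank-2 transform
d_sae = sum(rank_counts.values())

W_E, b_E = rng.normal(size=(d_model, d_sae)), rng.normal(size=d_sae)
U = {r: rng.normal(size=(c, d_model, r)) for r, c in rank_counts.items()}
V = {r: rng.normal(size=(c, r, d_model)) for r, c in rank_counts.items()}
x = rng.normal(size=(batch, d_model))

# Encoder: one gate per transform, f_i = ReLU(e_i . x - b_i) (ReLU is an assumption)
f = np.maximum(x @ W_E - b_E, 0.0)       # (batch, d_sae)

# Decoder, grouped by rank: sum_i f_i * U_i V_i x
out = np.zeros((batch, d_model))
start = 0
for r in sorted(rank_counts):
    c = rank_counts[r]
    Vx = np.einsum("crd,bd->bcr", V[r], x)      # project x into each transform's rank-r space
    UVx = np.einsum("cdr,bcr->bcd", U[r], Vx)   # map back to d_model
    out += np.einsum("bc,bcd->bd", f[:, start:start + c], UVx)
    start += c

# Reference: materialize every dense d_model x d_model transform U_i V_i
dense = [U[r][j] @ V[r][j] for r in sorted(rank_counts) for j in range(rank_counts[r])]
ref = sum(f[:, i:i + 1] * (x @ dense[i].T) for i in range(d_sae))
```

Applying V_i before U_i keeps the cost at O(d_model · rank) per transform instead of O(d_model²).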

Source code in src/lm_saes/molt.py
def __init__(self, cfg: MOLTConfig, device_mesh: DeviceMesh | None = None) -> None:
    super().__init__(cfg, device_mesh=device_mesh)
    self.cfg = cfg

    # Generate rank assignment for each linear transform
    if device_mesh is not None:
        # In distributed training/inference, get local rank assignments
        # Use model dimension for tensor parallelism
        mesh_dim_names = device_mesh.mesh_dim_names
        if mesh_dim_names is None:
            model_dim_index = 0
        else:
            model_dim_index = mesh_dim_names.index("model") if "model" in mesh_dim_names else 0
        local_rank = device_mesh.get_local_rank(
            mesh_dim=model_dim_index
        )  # this rank stands for device rank of this process
        model_parallel_size = device_mesh.size(mesh_dim=model_dim_index)

        self.rank_assignments = cfg.get_local_rank_assignments(model_parallel_size)

        for k, v in cfg.rank_counts.items():
            logger.info(
                f"Rank {k} has {v} global transforms, device rank {local_rank} has {self.rank_assignments.count(k)} transforms"
            )
    else:
        # Non-distributed case
        self.rank_assignments = cfg.generate_rank_assignments()

    # Encoder parameters (standard SAE encoder)
    if device_mesh is None:
        self.W_E = nn.Parameter(torch.empty(cfg.d_model, cfg.d_sae, device=cfg.device, dtype=cfg.dtype))
        self.b_E = nn.Parameter(torch.empty(cfg.d_sae, device=cfg.device, dtype=cfg.dtype))

        # Decoder parameters: d_sae linear transforms, each with UtVt decomposition
        # Group by rank for efficient parameter storage
        self.U_matrices = nn.ParameterDict()
        self.V_matrices = nn.ParameterDict()

        for rank in cfg.available_ranks:
            count = sum(1 for r in self.rank_assignments if r == rank)
            # Always create parameters for all rank types for consistency
            # In non-distributed case, we can skip empty tensors
            if count > 0:
                self.U_matrices[str(rank)] = nn.Parameter(
                    torch.empty(count, cfg.d_model, rank, device=cfg.device, dtype=cfg.dtype)
                )
                self.V_matrices[str(rank)] = nn.Parameter(
                    torch.empty(count, rank, cfg.d_model, device=cfg.device, dtype=cfg.dtype)
                )

        if cfg.use_decoder_bias:
            self.b_D = nn.Parameter(torch.empty(cfg.d_model, device=cfg.device, dtype=cfg.dtype))
    else:
        # Distributed initialization
        w_e_placements = self.dim_maps()["W_E"].placements(device_mesh)
        b_e_placements = self.dim_maps()["b_E"].placements(device_mesh)
        self.W_E = nn.Parameter(
            torch.distributed.tensor.empty(
                cfg.d_model,
                cfg.d_sae,
                dtype=cfg.dtype,
                device_mesh=device_mesh,
                placements=w_e_placements,
            )
        )

        self.b_E = nn.Parameter(
            torch.distributed.tensor.empty(
                cfg.d_sae,
                dtype=cfg.dtype,
                device_mesh=device_mesh,
                placements=b_e_placements,
            )
        )

        # Decoder parameters: d_sae linear transforms, each with UtVt decomposition
        # Group by rank for efficient parameter storage
        self.U_matrices = nn.ParameterDict()
        self.V_matrices = nn.ParameterDict()

        for rank in cfg.available_ranks:
            local_count = sum(1 for r in self.rank_assignments if r == rank)
            assert local_count > 0, f"Rank {rank} has local_count=0, sharding logic error"

            # Create DTensor with GLOBAL shape
            self.U_matrices[str(rank)] = nn.Parameter(
                torch.distributed.tensor.empty(
                    self.cfg.rank_counts[rank],  # GLOBAL count
                    cfg.d_model,
                    rank,
                    dtype=cfg.dtype,
                    device_mesh=device_mesh,
                    placements=self.dim_maps()["U_matrices"].placements(device_mesh),
                )
            )

            self.V_matrices[str(rank)] = nn.Parameter(
                torch.distributed.tensor.empty(
                    self.cfg.rank_counts[rank],  # GLOBAL count
                    rank,
                    cfg.d_model,
                    dtype=cfg.dtype,
                    device_mesh=device_mesh,
                    placements=self.dim_maps()["V_matrices"].placements(device_mesh),
                )
            )

        if cfg.use_decoder_bias:
            self.b_D = nn.Parameter(
                torch.distributed.tensor.empty(
                    cfg.d_model,
                    dtype=cfg.dtype,
                    device_mesh=device_mesh,
                    placements=self.dim_maps()["b_D"].placements(device_mesh),
                )
            )

dim_maps

dim_maps() -> dict[str, DimMap]

Return dimension maps for distributed training.

Encoder and decoder use consistent sharding:

  • W_E is sharded along its d_sae (output) dimension.
  • U/V matrices are sharded along their transform-count (first) dimension.

This ensures that feature_acts from the encoder can feed the decoder directly, without redistribution.

Source code in src/lm_saes/molt.py
def dim_maps(self) -> dict[str, DimMap]:
    """Return dimension maps for distributed training.

    Encoder and decoder use consistent sharding:
    - W_E sharded along d_sae (output) dimension
    - U/V matrices sharded along transform count (first) dimension
    This ensures feature_acts from encoder can directly feed decoder without redistribution.
    """
    base_maps = super().dim_maps()

    molt_maps = {
        "W_E": DimMap({"model": 1}),  # Shard along d_sae dimension
        "b_E": DimMap({"model": 0}),  # Shard along d_sae dimension
        # U and V matrices sharded along transform count dimension
        # This matches the W_E sharding pattern for feature_acts compatibility
        "U_matrices": DimMap({"model": 0}),  # Shard along transform count
        "V_matrices": DimMap({"model": 0}),  # Shard along transform count
        "b_D": DimMap({}),  # Replicate decoder bias
    }

    return base_maps | molt_maps
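
To see why the two sharding choices line up, compare local shapes when every sharded dimension splits evenly. The helper `local_shape` is hypothetical (it does not model DTensor itself or uneven shards) and the sizes are made up: each device holds d_sae / P encoder columns and, because every rank group is split P ways, the same number of transforms in total.

```python
def local_shape(global_shape, shard_dim, model_parallel_size):
    """Hypothetical helper: shard shape when dimension shard_dim is split evenly across devices."""
    if shard_dim is None:                     # replicated, like DimMap({})
        return tuple(global_shape)
    assert global_shape[shard_dim] % model_parallel_size == 0
    shape = list(global_shape)
    shape[shard_dim] //= model_parallel_size
    return tuple(shape)

d_model, P = 16, 4
rank_counts = {4: 8, 8: 16}
d_sae = sum(rank_counts.values())

w_e_local = local_shape((d_model, d_sae), 1, P)                 # W_E: DimMap({"model": 1})
u_locals = {r: local_shape((c, d_model, r), 0, P)               # U_matrices: DimMap({"model": 0})
            for r, c in rank_counts.items()}
b_d_local = local_shape((d_model,), None, P)                    # b_D: DimMap({})
```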

encoder_norm

encoder_norm(keepdim: bool = False) -> Tensor

Compute the L2 norm of each encoder column, i.e. one norm per feature (d_sae entries).

Source code in src/lm_saes/molt.py
@override
@timer.time("encoder_norm")
def encoder_norm(self, keepdim: bool = False) -> torch.Tensor:
    """Compute the norm of the encoder weight."""
    if not isinstance(self.W_E, DTensor):
        return torch.norm(self.W_E, p=2, dim=0, keepdim=keepdim).to(self.cfg.device)
    else:
        assert self.device_mesh is not None
        return DTensor.from_local(
            torch.norm(self.W_E.to_local(), p=2, dim=0, keepdim=keepdim),
            device_mesh=self.device_mesh,
            placements=DimMap({"model": 1 if keepdim else 0}).placements(self.device_mesh),
        )

decoder_norm

decoder_norm(keepdim: bool = False) -> Tensor

Compute the Frobenius norm $\lVert U_i V_i \rVert_F$ of each linear transform.

Source code in src/lm_saes/molt.py
@override
@timer.time("decoder_norm")
def decoder_norm(self, keepdim: bool = False) -> torch.Tensor:
    """Compute the Frobenius norm of each linear transform's UtVt decomposition."""
    # Pre-compute norms for all rank groups and concatenate
    norm_list = []

    for rank in self.cfg.available_ranks:
        rank_str = str(rank)
        if rank_str in self.U_matrices:
            U = self.U_matrices[rank_str]  # (count, d_model, rank)
            V = self.V_matrices[rank_str]  # (count, rank, d_model)

            assert isinstance(U, DTensor) == isinstance(V, DTensor), "U and V must have the same type"
            # Handle DTensor case - work with local shards
            if isinstance(U, DTensor) and isinstance(V, DTensor):
                U_local = U.to_local()
                V_local = V.to_local()

                # Compute ||U_i @ V_i||_F for each transform (local shard)
                UV_local = torch.bmm(U_local, V_local)  # (local_count, d_model, d_model)
                UV_norms_local = torch.norm(UV_local.view(UV_local.shape[0], -1), p="fro", dim=1)  # (local_count,)

                # Convert back to DTensor with proper placement
                assert self.device_mesh is not None
                UV_norms = DTensor.from_local(
                    UV_norms_local,
                    device_mesh=self.device_mesh,
                    placements=self.dim_maps()["U_matrices"].placements(self.device_mesh)[
                        0:1
                    ],  # Only keep first dimension placement
                )
                norm_list.append(UV_norms)
            else:
                # Non-distributed case
                UV = torch.bmm(U, V)  # (count, d_model, d_model)
                UV_norms = torch.norm(UV.view(UV.shape[0], -1), p="fro", dim=1)  # (count,)
                norm_list.append(UV_norms)

    if not norm_list:
        if self.device_mesh is not None:
            # Create replicated DTensor for zero norms
            norms = DTensor.from_local(
                torch.zeros(self.cfg.d_sae, device=self.cfg.device, dtype=self.cfg.dtype),
                device_mesh=self.device_mesh,
                placements=self.dim_maps()["b_E"].placements(self.device_mesh),  # Same as b_E sharding
            )
        else:
            norms = torch.zeros(self.cfg.d_sae, device=self.cfg.device, dtype=self.cfg.dtype)
    else:
        # Concatenate all norms in correct order
        if isinstance(norm_list[0], DTensor):
            # CRITICAL FIX: Avoid full_tensor() to prevent numerical errors
            # Instead, directly concatenate the DTensors which preserves numerical precision
            assert self.device_mesh is not None

            # Convert each DTensor norm to local tensor and concatenate locally
            local_norms = [norm.to_local() for norm in norm_list]

            # Concatenate local norms and convert back to DTensor
            norms_local = torch.cat(local_norms, dim=0)
            norms = DTensor.from_local(
                norms_local,
                device_mesh=self.device_mesh,
                placements=self.dim_maps()["b_E"].placements(self.device_mesh),  # Same as b_E (d_sae dimension)
            )
        else:
            norms = torch.cat(norm_list, dim=0)  # (d_sae,)

    if keepdim:
        return norms.unsqueeze(-1)
    else:
        return norms
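
decoder_norm materializes each d_model × d_model product with torch.bmm. The same value follows from the identity $\lVert UV \rVert_F^2 = \operatorname{tr}\big((U^\top U)(V V^\top)\big)$, which only needs rank × rank Gram matrices. The NumPy check below compares both; the Gram-matrix route is an alternative shown for illustration, not what lm_saes does.

```python
import numpy as np

rng = np.random.default_rng(0)
count, d_model, rank = 5, 32, 4
U = rng.normal(size=(count, d_model, rank))
V = rng.normal(size=(count, rank, d_model))

# As in decoder_norm: build U_i V_i (count x d_model x d_model) and take the Frobenius norm
UV = U @ V
norms_direct = np.linalg.norm(UV.reshape(count, -1), axis=1)

# Alternative: ||U V||_F^2 = trace((U^T U)(V V^T)), avoiding the d_model x d_model product
gram_U = np.transpose(U, (0, 2, 1)) @ U        # (count, rank, rank)
gram_V = V @ np.transpose(V, (0, 2, 1))        # (count, rank, rank)
norms_gram = np.sqrt(np.einsum("cij,cji->c", gram_U, gram_V))
```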

set_encoder_to_fixed_norm

set_encoder_to_fixed_norm(value: float) -> None

Rescale each encoder column (one per feature) to L2 norm value.

Source code in src/lm_saes/molt.py
@torch.no_grad()
@timer.time("set_encoder_to_fixed_norm")
def set_encoder_to_fixed_norm(self, value: float) -> None:
    """Set encoder weights to a fixed norm."""
    self.W_E.mul_(value / self.encoder_norm(keepdim=True))
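
Because encoder_norm(keepdim=True) has shape (1, d_sae), the division broadcasts over rows and rescales every column independently while keeping its direction. A NumPy equivalent with arbitrary sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
W_E = rng.normal(size=(8, 5))                               # (d_model, d_sae)
value = 0.1
encoder_norm = np.linalg.norm(W_E, axis=0, keepdims=True)   # (1, d_sae), like encoder_norm(keepdim=True)
W_E_fixed = W_E * (value / encoder_norm)
```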

decode

decode(
    feature_acts: Union[
        Float[Tensor, "batch d_sae"],
        Float[Tensor, "batch seq_len d_sae"],
    ],
    **kwargs,
) -> Union[
    Float[Tensor, "batch d_model"],
    Float[Tensor, "batch seq_len d_model"],
]

Decode feature activations back to model space using MOLT transforms.

Parameters:

  • feature_acts (Union[Float[Tensor, 'batch d_sae'], Float[Tensor, 'batch seq_len d_sae']], required): Feature activations from encode().
  • **kwargs (default {}): Must contain 'original_x', the original input tensor.

Returns:

  • Union[Float[Tensor, 'batch d_model'], Float[Tensor, 'batch seq_len d_model']]: Reconstructed tensor in model space.

Source code in src/lm_saes/molt.py
@override
@timer.time("decode")
def decode(
    self,
    feature_acts: Union[
        Float[torch.Tensor, "batch d_sae"],
        Float[torch.Tensor, "batch seq_len d_sae"],
    ],
    **kwargs,
) -> Union[
    Float[torch.Tensor, "batch d_model"],
    Float[torch.Tensor, "batch seq_len d_model"],
]:
    """Decode feature activations back to model space using MOLT transforms.

    Args:
        feature_acts: Feature activations from encode()
        **kwargs: Must contain 'original_x' - the original input tensor

    Returns:
        Reconstructed tensor in model space
    """
    assert "original_x" in kwargs, "MOLT decode requires 'original_x' in kwargs"

    x = kwargs["original_x"]

    # Choose decoding strategy based on distributed setup
    is_distributed = any(
        isinstance(self.U_matrices[str(rank)], DTensor)
        for rank in self.cfg.available_ranks
        if str(rank) in self.U_matrices
    )

    if is_distributed:
        reconstruction = self._decode_distributed(feature_acts, x)
    else:
        reconstruction = self._decode_single_gpu(feature_acts, x)

    return reconstruction

compute_training_metrics

compute_training_metrics(
    *, l0: Tensor, feature_acts: Tensor, **kwargs
) -> dict[str, float]

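Compute per-rank-group training metrics for MOLT: for each rank, the mean number of active transforms (l0_rank{rank}) and that number as a fraction of the group size (l0_rank{rank}_ratio), plus total_rank_sum, the sum over groups of rank × mean l0.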

Source code in src/lm_saes/molt.py
@override
@torch.no_grad()
def compute_training_metrics(
    self,
    *,
    l0: torch.Tensor,
    feature_acts: torch.Tensor,
    **kwargs,
) -> dict[str, float]:
    """Compute per-rank group training metrics for MOLT."""
    metrics = {}
    feature_idx = 0
    total_rank_sum = 0.0

    for rank in self.cfg.available_ranks:
        rank_str = str(rank)
        if rank_str in self.U_matrices:
            # Extract features for this rank group
            end_idx = (
                feature_idx + self.cfg.rank_counts[rank]
            )  # rank_counts[rank] is the GLOBAL count of this rank group
            rank_features = feature_acts[..., feature_idx:end_idx]

            # Count active transforms (l0) for this rank group
            rank_l0 = (rank_features > 0).float().sum(-1)
            rank_l0_mean = item(rank_l0.mean())

            # Record metrics
            metrics[f"molt_metrics/l0_rank{rank}"] = rank_l0_mean
            metrics[f"molt_metrics/l0_rank{rank}_ratio"] = rank_l0_mean / self.cfg.rank_counts[rank]
            total_rank_sum += rank_l0_mean * rank

            feature_idx += self.cfg.rank_counts[rank]

    # Record total rank sum
    metrics["molt_metrics/total_rank_sum"] = total_rank_sum
    return metrics
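
A NumPy mirror of these metrics on a two-token example small enough to check by hand. The function `molt_rank_metrics` is hypothetical; like the source, it assumes feature columns are ordered by ascending rank with each group occupying rank_counts[rank] consecutive columns.

```python
import numpy as np

def molt_rank_metrics(feature_acts, rank_counts):
    """Hypothetical mirror of compute_training_metrics; columns grouped by ascending rank."""
    metrics = {}
    start, total_rank_sum = 0, 0.0
    for rank in sorted(rank_counts):
        count = rank_counts[rank]
        group = feature_acts[..., start:start + count]
        l0 = float((group > 0).sum(-1).mean())     # mean number of active transforms in the group
        metrics[f"molt_metrics/l0_rank{rank}"] = l0
        metrics[f"molt_metrics/l0_rank{rank}_ratio"] = l0 / count
        total_rank_sum += l0 * rank
        start += count
    metrics["molt_metrics/total_rank_sum"] = total_rank_sum
    return metrics

# rank_counts {1: 2, 4: 2}: columns 0-1 are rank 1, columns 2-3 are rank 4
feature_acts = np.array([[0.5, 0.0, 1.0, 0.3],
                         [0.2, 0.7, 0.0, 0.0]])
m = molt_rank_metrics(feature_acts, {1: 2, 4: 2})
```

Here the rank-1 group averages 1.5 active transforms and the rank-4 group 1.0, so total_rank_sum = 1.5 × 1 + 1.0 × 4 = 5.5: the average total rank in use per token.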

forward

forward(
    x: Union[
        Float[Tensor, "batch d_model"],
        Float[Tensor, "batch seq_len d_model"],
    ],
    encoder_kwargs: dict[str, Any] = {},
    decoder_kwargs: dict[str, Any] = {},
) -> Union[
    Float[Tensor, "batch d_model"],
    Float[Tensor, "batch seq_len d_model"],
]

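Forward pass through the autoencoder: encode x with encoder_kwargs, then decode with original_x=x. Note that decoder_kwargs is not forwarded to decode in the current implementation. Make sure the input activations are normalized by calling normalize_activations before calling this method.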

Source code in src/lm_saes/molt.py
@override
@timer.time("forward")
def forward(
    self,
    x: Union[
        Float[torch.Tensor, "batch d_model"],
        Float[torch.Tensor, "batch seq_len d_model"],
    ],
    encoder_kwargs: dict[str, Any] = {},
    decoder_kwargs: dict[str, Any] = {},
) -> Union[
    Float[torch.Tensor, "batch d_model"],
    Float[torch.Tensor, "batch seq_len d_model"],
]:
    """Forward pass through the autoencoder.
    Ensure that the input activations are normalized by calling `normalize_activations` before calling this method.
    """
    feature_acts = self.encode(x, **encoder_kwargs)
    reconstructed = self.decode(feature_acts, original_x=x)
    return reconstructed