Skip to content

Analysis

Post-training feature analysis and interpretability tools.

FeatureAnalyzerConfig pydantic-model

Bases: BaseConfig

Config:

  • arbitrary_types_allowed: True

Fields:

total_analyzing_tokens pydantic-field

total_analyzing_tokens: int

Total number of tokens to analyze

ignore_token_ids pydantic-field

ignore_token_ids: list[int] | None = None

Tokens to ignore in the activations.

subsamples pydantic-field

subsamples: dict[str, dict[str, int | float]]

Dictionary mapping each subsample name to its parameters:

  • proportion: Proportion of max activation to consider
  • n_samples: Number of samples to keep
  • max_length: Maximum length of the sample

clt_layer pydantic-field

clt_layer: int | None = None

Layer to analyze for a CLT (CrossLayerTranscoder). Set this if and only if the SAE being analyzed is a CLT.

FeatureAnalyzer

FeatureAnalyzer(cfg: FeatureAnalyzerConfig)

Analyzes feature activations from a sparse autoencoder.

This class processes activation data from a sparse autoencoder to:

  1. Track activation statistics like frequency and magnitude
  2. Sample and store representative activations
  3. Organize results by feature for analysis

Initialize the feature analyzer.

Parameters:

Name Type Description Default
cfg FeatureAnalyzerConfig

Analysis configuration specifying parameters like sample sizes and thresholds

required
Source code in src/lm_saes/analysis/feature_analyzer.py
def __init__(self, cfg: FeatureAnalyzerConfig):
    """Create an analyzer bound to the given configuration.

    Args:
        cfg: Settings that drive the analysis (token budget, ignored token IDs,
            subsample definitions, and the optional CLT layer).
    """
    self.cfg = cfg

compute_ignore_token_masks

compute_ignore_token_masks(
    tokens: Tensor,
    ignore_token_ids: Optional[list[int]] = None,
) -> Tensor

Compute ignore token masks for the given tokens.

Parameters:

Name Type Description Default
tokens Tensor

The tokens to compute the ignore token masks for

required
ignore_token_ids Optional[list[int]]

The token IDs to ignore

None
Source code in src/lm_saes/analysis/feature_analyzer.py
def compute_ignore_token_masks(
    self,
    tokens: torch.Tensor,
    ignore_token_ids: Optional[list[int]] = None,
) -> torch.Tensor:
    """Build a boolean keep-mask that is False wherever a token should be ignored.

    Args:
        tokens: Token IDs to inspect.
        ignore_token_ids: Token IDs to exclude. ``None`` emits a ``UserWarning`` and is
            then treated as an empty list, so nothing is filtered.

    Returns:
        A boolean tensor shaped like ``tokens``; ``True`` marks positions to keep.
    """
    ids_to_drop = ignore_token_ids
    if ids_to_drop is None:
        # Not passing anything is most likely an oversight, so tell the caller how to opt out.
        warnings.warn(
            "ignore_token_ids are not provided. No tokens (including pad tokens) will be filtered out. If this is intentional, set ignore_token_ids explicitly to an empty list to avoid this warning.",
            UserWarning,
            stacklevel=2,
        )
        ids_to_drop = []

    # Start from "keep everything" and knock out one ignored ID at a time.
    keep = torch.ones_like(tokens, dtype=torch.bool)
    for ignored_id in ids_to_drop:
        keep &= tokens.ne(ignored_id)
    return keep

get_post_analysis_func

get_post_analysis_func(sae_type: str)

Get the post-analysis processor for the given SAE type.

Parameters:

Name Type Description Default
sae_type str

The SAE type identifier

required

Returns:

Type Description

The post-analysis processor instance

Source code in src/lm_saes/analysis/feature_analyzer.py
def get_post_analysis_func(self, sae_type: str):
    """Look up the post-analysis processor registered for an SAE type.

    Args:
        sae_type: The SAE type identifier

    Returns:
        The processor registered under ``sae_type``, or the one registered under
        ``"generic"`` when the lookup for ``sae_type`` raises ``KeyError``.
    """
    try:
        processor = get_post_analysis_processor(sae_type)
    except KeyError:
        # No dedicated processor for this SAE type: use the generic one instead
        processor = get_post_analysis_processor("generic")
    return processor

analyze_chunk

analyze_chunk(
    activation_factory: ActivationFactory,
    sae: AbstractSparseAutoEncoder,
    device_mesh: DeviceMesh | None = None,
    activation_factory_process_kwargs: dict[str, Any] = {},
) -> list[dict[str, Any]]

Analyze feature activations for a chunk of the SAE.

Processes activation data to:

  1. Track activation statistics
  2. Sample representative activations
  3. Organize results by feature

Parameters:

Name Type Description Default
activation_factory ActivationFactory

The activation factory to use

required
sae AbstractSparseAutoEncoder

The sparse autoencoder model

required
device_mesh DeviceMesh | None

The device mesh to use

None
activation_factory_process_kwargs dict[str, Any]

Keyword arguments to pass to the activation factory's process method

{}

Returns:

Type Description
list[dict[str, Any]]

List of dictionaries containing per-feature analysis results:

  • Activation counts and maximums
  • Sampled activations with metadata
Source code in src/lm_saes/analysis/feature_analyzer.py
@torch.no_grad()
def analyze_chunk(
    self,
    activation_factory: ActivationFactory,
    sae: AbstractSparseAutoEncoder,
    device_mesh: DeviceMesh | None = None,
    activation_factory_process_kwargs: dict[str, Any] = {},
) -> list[dict[str, Any]]:
    """Analyze feature activations for a chunk of the SAE.

    Processes activation data to:
    1. Track activation statistics
    2. Sample representative activations
    3. Organize results by feature

    Batches are consumed from the activation stream until at least
    ``cfg.total_analyzing_tokens`` tokens have been seen (ignored tokens count towards
    this budget) or the stream is exhausted. The accumulated statistics are then handed
    to the post-analysis processor selected by ``sae.cfg.sae_type``.

    Args:
        activation_factory: The activation factory to use
        sae: The sparse autoencoder model
        device_mesh: The device mesh to use
        activation_factory_process_kwargs: Keyword arguments to pass to the activation factory's process method

    Returns:
        List of dictionaries containing per-feature analysis results:
        - Activation counts and maximums
        - Sampled activations with metadata

    Note:
        When ``sae`` is a ``CrossLayerTranscoder``, ``encode``, ``prepare_input`` and
        ``decoder_norm_per_feature`` are rebound on the passed-in instance to their
        ``cfg.clt_layer`` variants and ``keep_only_decoders_for_layer_from`` is called.
        These changes to the caller's object are not undone here.

        NOTE(review): ``activation_factory_process_kwargs`` defaults to a shared mutable
        ``{}``. This function does not mutate it, but it is forwarded to the post-analysis
        processor; confirm nothing downstream mutates it.
    """
    activation_stream = activation_factory.process(**activation_factory_process_kwargs)
    # n_tokens counts every token seen; n_analyzed_tokens counts only positions kept by the mask.
    n_tokens = n_analyzed_tokens = 0

    # Progress tracking
    pbar = tqdm(
        total=self.cfg.total_analyzing_tokens,
        desc="Analyzing SAE",
        smoothing=0.01,
        disable=not is_primary_rank(device_mesh),
    )

    # Per-rank feature count: d_sae split over the "model" mesh dimension when one exists.
    if device_mesh is not None and device_mesh.mesh_dim_names is not None and "model" in device_mesh.mesh_dim_names:
        d_sae_local = sae.cfg.d_sae // device_mesh["model"].size()
    else:
        d_sae_local = sae.cfg.d_sae

    # Initialize tracking variables
    # One slot per configured subsample; slots stay None until _process_batch fills them.
    sample_result = {k: None for k in self.cfg.subsamples.keys()}
    if device_mesh is not None:
        # Distributed case: accumulators are created with the global d_sae and placed via DimMap.
        act_times = torch.distributed.tensor.zeros(
            (sae.cfg.d_sae,),
            dtype=torch.long,
            device_mesh=device_mesh,
            placements=DimMap({"model": 0}).placements(device_mesh),
        )
        max_feature_acts = torch.distributed.tensor.zeros(
            (sae.cfg.d_sae,),
            dtype=sae.cfg.dtype,
            device_mesh=device_mesh,
            placements=DimMap({"model": 0}).placements(device_mesh),
        )
    else:
        act_times = torch.zeros((d_sae_local,), dtype=torch.long, device=sae.cfg.device)
        max_feature_acts = torch.zeros((d_sae_local,), dtype=sae.cfg.dtype, device=sae.cfg.device)
    # Encodes string metadata as integers; also passed on to the post-analysis processor.
    mapper = KeyedDiscreteMapper()

    # TODO: Make a wrapper for CLT
    # For a CLT, rebind the instance's methods to operate on cfg.clt_layer only.
    # This mutates the sae object passed in by the caller.
    if isinstance(sae, CrossLayerTranscoder):
        sae.encode = partial(sae.encode_single_layer, layer=self.cfg.clt_layer)  # type: ignore
        sae.prepare_input = partial(sae.prepare_input_single_layer, layer=self.cfg.clt_layer)  # type: ignore
        sae.decoder_norm_per_feature = partial(sae.decoder_norm_per_feature, layer=self.cfg.clt_layer)  # type: ignore
        sae.keep_only_decoders_for_layer_from(self.cfg.clt_layer)  # type: ignore
        torch.cuda.empty_cache()

    # Process activation batches
    for batch in activation_stream:
        # Reshape meta to zip outer dimensions to inner
        # (list of per-sample dicts -> dict of per-key lists; keys are taken from the first sample)
        meta = {k: [m[k] for m in batch["meta"]] for k in batch["meta"][0].keys()}

        # Get feature activations from SAE
        x, encoder_kwargs, _ = sae.prepare_input(batch)
        tokens = batch["tokens"]
        feature_acts: torch.Tensor = sae.encode(x, **encoder_kwargs)
        if isinstance(feature_acts, DTensor):
            assert device_mesh is not None, "Device mesh is required for DTensor feature activations"
            if device_mesh is not feature_acts.device_mesh:
                # The SAE produced activations on a different mesh: redistribute on that mesh,
                # drop to the local tensor, then re-wrap it on the analysis mesh.
                feature_acts = DTensor.from_local(
                    feature_acts.redistribute(
                        placements=DimMap({"head": -1, "model": -1}).placements(feature_acts.device_mesh)
                    ).to_local(),
                    device_mesh,
                    placements=DimMap({"model": -1}).placements(device_mesh),
                )
                # TODO: Remove this once redistributing across device meshes is supported
            feature_acts = feature_acts.redistribute(placements=DimMap({"model": -1}).placements(device_mesh))
            if not isinstance(tokens, DTensor):
                # Wrap tokens as a DTensor too so they can be combined with feature_acts below.
                tokens = DTensor.from_local(tokens, device_mesh, placements=DimMap({}).placements(device_mesh))
        if isinstance(sae, CrossCoder):
            # Collapse the second-to-last axis by taking its maximum
            # (presumably the per-head axis of a CrossCoder; confirm).
            feature_acts = feature_acts.amax(dim=-2)
        # From here on feature_acts must be (batch, n_ctx, d_sae), aligned with tokens.
        assert feature_acts.shape == (tokens.shape[0], tokens.shape[1], sae.cfg.d_sae), (
            f"feature_acts.shape: {feature_acts.shape}, expected: {(tokens.shape[0], tokens.shape[1], sae.cfg.d_sae)}"
        )

        # Compute and apply ignore token masks
        # The batch-provided mask is used only when no ignore_token_ids are configured;
        # otherwise the mask is derived from the token IDs.
        if self.cfg.ignore_token_ids is None and batch.get("mask") is not None:
            ignore_token_masks = batch["mask"]
            if device_mesh is not None and not isinstance(ignore_token_masks, DTensor):
                ignore_token_masks = DTensor.from_local(
                    ignore_token_masks, device_mesh, placements=DimMap({}).placements(device_mesh)
                )
        else:
            ignore_token_masks = self.compute_ignore_token_masks(tokens, self.cfg.ignore_token_ids)
        # In-place: zero the activations at masked-out positions before any statistics are taken.
        feature_acts *= rearrange(ignore_token_masks, "batch_size n_ctx -> batch_size n_ctx 1")

        # Update activation statistics
        # Per-feature count of positions with a strictly positive activation, and running maximum.
        active_feature_count = feature_acts.gt(0.0).sum(dim=[0, 1])
        act_times += active_feature_count
        max_feature_acts = torch.max(max_feature_acts, feature_acts.max(dim=0).values.max(dim=0).values)

        # Apply discrete mapper encoding only to string metadata, keep others as-is
        discrete_meta = {}
        for k, v in meta.items():
            if all(isinstance(item, str) for item in v):
                # Apply discrete mapper encoding to string metadata
                discrete_meta[k] = torch.tensor(mapper.encode(k, v), device=sae.cfg.device, dtype=torch.int32)
            else:
                # Keep non-string metadata as-is (assuming they are already tensors or can be converted)
                discrete_meta[k] = torch.tensor(v, device=sae.cfg.device)
        # Repeat each per-sample metadata vector across the local feature dimension.
        if device_mesh is not None:
            discrete_meta = {
                k: DTensor.from_local(
                    local_tensor=repeat(v, "batch_size -> batch_size d_sae", d_sae=d_sae_local),
                    device_mesh=device_mesh,
                    placements=DimMap({"model": 1}).placements(device_mesh),
                )
                for k, v in discrete_meta.items()
            }
        else:
            discrete_meta = {
                k: repeat(v, "batch_size -> batch_size d_sae", d_sae=d_sae_local) for k, v in discrete_meta.items()
            }
        sample_result = self._process_batch(
            feature_acts, discrete_meta, sample_result, max_feature_acts, device_mesh
        )

        # Update progress
        n_tokens_current = tokens.numel()
        n_tokens += n_tokens_current
        n_analyzed_tokens += cast(int, item(ignore_token_masks.int().sum()))
        pbar.update(n_tokens_current)
        # The budget is measured in all tokens seen, not only the ones kept by the mask.
        if n_tokens >= self.cfg.total_analyzing_tokens:
            break

    pbar.close()

    # Filter out None values and format final per-feature results
    sample_result = {k: v for k, v in sample_result.items() if v is not None}
    sample_result = {
        name: {k: to_local(v) for k, v in subsample.items()} for name, subsample in sample_result.items()
    }

    # Hand the accumulated statistics to the processor registered for this SAE type.
    return self.get_post_analysis_func(sae.cfg.sae_type).process(
        sae=sae,
        act_times=to_local(act_times),
        n_analyzed_tokens=n_analyzed_tokens,
        max_feature_acts=to_local(max_feature_acts),
        sample_result=sample_result,
        mapper=mapper,
        device_mesh=device_mesh,
        activation_factory=activation_factory,
        activation_factory_process_kwargs=activation_factory_process_kwargs,
    )

DirectLogitAttributorConfig pydantic-model

Bases: BaseConfig

Fields:

top_k pydantic-field

top_k: int = 10

The number of top tokens to attribute to.

clt_layer pydantic-field

clt_layer: int | None = None

Layer to analyze for a CLT (CrossLayerTranscoder). Set this if and only if the SAE being analyzed is a CLT.

DirectLogitAttributor

DirectLogitAttributor(cfg: DirectLogitAttributorConfig)
Source code in src/lm_saes/analysis/direct_logit_attributor.py
def __init__(self, cfg: DirectLogitAttributorConfig):
    """Store the direct logit attribution configuration.

    Args:
        cfg: Attribution settings (``top_k`` and the optional ``clt_layer``).
    """
    self.cfg = cfg

direct_logit_attribute

direct_logit_attribute(
    sae, model: LanguageModel, layer_idx: int | None = None
)

Compute direct logit attribution for the given SAE.

Parameters:

Name Type Description Default
sae

The SAE model to attribute.

required
model LanguageModel

The language model backend.

required
layer_idx int | None

The layer index (required for some SAE types like CrossLayerTranscoder).

None

Returns:

Type Description

A list of dictionaries containing top positive and negative logits for each feature.

Source code in src/lm_saes/analysis/direct_logit_attributor.py
@torch.no_grad()
def direct_logit_attribute(self, sae, model: LanguageModel, layer_idx: int | None = None):
    """Find, for every SAE feature, the tokens with the largest and smallest attributed logits.

    Args:
        sae: The SAE whose features are attributed.
        model: The language model backend. Must be a ``TransformerLensLanguageModel``
            whose underlying model is loaded.
        layer_idx: Layer index forwarded to ``compute_logits_and_d_sae`` (required for
            some SAE types like CrossLayerTranscoder).

    Returns:
        One dict per feature with the keys ``"top_positive"`` and ``"top_negative"``;
        each holds ``cfg.top_k`` entries of the form ``{"token": ..., "logit": ...}``.
    """
    assert isinstance(model, TransformerLensLanguageModel), (
        "DirectLogitAttributor only supports TransformerLensLanguageModel as the model backend"
    )
    hooked_model: HookedTransformer | None = model.model
    assert hooked_model is not None, "Model ckpt must be loaded for direct logit attribution"

    # Use singledispatch to compute logits and d_sae based on SAE type
    logits, d_sae = compute_logits_and_d_sae(sae, hooked_model, layer_idx)

    k = self.cfg.top_k
    expected_shape = (d_sae, k)
    feature_ids = range(d_sae)

    def as_entries(tokens, values):
        """Pair each token string with its logit value."""
        entries = []
        for token, logit in zip(tokens, values):
            entries.append({"token": token, "logit": logit})
        return entries

    # Largest k logits per feature, with their token strings
    top_k_logits, top_k_indices = torch.topk(logits, k, dim=-1)
    top_k_tokens = [hooked_model.to_str_tokens(top_k_indices[fid]) for fid in feature_ids]

    assert top_k_logits.shape == top_k_indices.shape == expected_shape, (
        f"Top k logits and indices should have shape (d_sae, top_k), but got {top_k_logits.shape} and {top_k_indices.shape}"
    )
    assert (len(top_k_tokens), len(top_k_tokens[0])) == expected_shape, (
        f"Top k tokens should have shape (d_sae, top_k), but got {len(top_k_tokens)} and {len(top_k_tokens[0])}"
    )

    # Smallest k logits per feature, with their token strings
    bottom_k_logits, bottom_k_indices = torch.topk(logits, k, dim=-1, largest=False)
    bottom_k_tokens = [hooked_model.to_str_tokens(bottom_k_indices[fid]) for fid in feature_ids]

    assert bottom_k_logits.shape == bottom_k_indices.shape == expected_shape, (
        f"Bottom k logits and indices should have shape (d_sae, top_k), but got {bottom_k_logits.shape} and {bottom_k_indices.shape}"
    )
    assert (len(bottom_k_tokens), len(bottom_k_tokens[0])) == expected_shape, (
        f"Bottom k tokens should have shape (d_sae, top_k), but got {len(bottom_k_tokens)} and {len(bottom_k_tokens[0])}"
    )

    # Assemble one record per feature
    result = []
    for fid in feature_ids:
        result.append(
            {
                "top_positive": as_entries(top_k_tokens[fid], top_k_logits[fid].tolist()),
                "top_negative": as_entries(bottom_k_tokens[fid], bottom_k_logits[fid].tolist()),
            }
        )
    return result