Skip to content

Training

Training infrastructure: trainer, optimizer configs, initialization, and logging.

TrainerConfig pydantic-model

Bases: BaseConfig

Config:

  • arbitrary_types_allowed: True

Fields:

  • l1_coefficient (float | None)
  • l1_coefficient_warmup_steps (int | float)
  • lp_coefficient (float | None)
  • auxk_coefficient (float | None)
  • amp_dtype (dtype | None)
  • sparsity_loss_type (Literal['power', 'tanh', 'tanh-quad', None])
  • tanh_stretch_coefficient (float)
  • frequency_scale (float)
  • p (int)
  • initial_k (int | float | None)
  • k_warmup_steps (int | float)
  • k_cold_booting_steps (int | float)
  • k_schedule_type (Literal['linear', 'exponential'])
  • k_exponential_factor (float)
  • k_aux (int)
  • dead_threshold (float)
  • skip_metrics_calculation (bool)
  • gradient_accumulation_steps (int)
  • lr (float | dict[str, float])
  • betas (Tuple[float, float])
  • optimizer_class (Literal['adam', 'sparseadam'])
  • optimizer_foreach (bool)
  • lr_scheduler_name (Literal['constant', 'constantwithwarmup', 'linearwarmupdecay', 'cosineannealing', 'cosineannealingwarmup', 'exponentialwarmup'])
  • lr_end_ratio (float)
  • lr_warm_up_steps (int | float)
  • lr_cool_down_steps (int | float)
  • jumprelu_lr_factor (float)
  • clip_grad_norm (float)
  • feature_sampling_window (int)
  • total_training_tokens (int)
  • log_frequency (int)
  • eval_frequency (int)
  • n_checkpoints (int)
  • check_point_save_mode (Literal['log', 'linear'])
  • from_pretrained_path (str | None)
  • exp_result_path (str)

l1_coefficient pydantic-field

l1_coefficient: float | None = 8e-05

Coefficient for the L1 sparsity loss. This loss is used to penalize the sparsity of the feature activations.

l1_coefficient_warmup_steps pydantic-field

l1_coefficient_warmup_steps: int | float = 0.1

Steps (int) or fraction of total steps (float) to warm up the sparsity coefficient from 0.

lp_coefficient pydantic-field

lp_coefficient: float | None = None

Coefficient for the Lp sparsity loss. This loss penalizes the feature activations via an \(L^p\) norm. To use the JumpReLU \(L^p\) penalty, set lp_coefficient to a positive value.

auxk_coefficient pydantic-field

auxk_coefficient: float | None = None

Coefficient for the Aux-K auxiliary loss. This loss is used to revive dead latents during training. To use the Aux-K loss, set auxk_coefficient to a positive value.

Trainer

Trainer(cfg: TrainerConfig)
Source code in src/lm_saes/trainer.py
def __init__(self, cfg: TrainerConfig):
    """Create a trainer holding only empty state; optimizer, scheduler and
    logger are attached later (or restored by ``from_checkpoint``)."""
    self.cfg = cfg
    # Training progress counters.
    self.cur_step: int = 0
    self.cur_tokens: int = 0
    # Schedule lengths; resolved later (and restored by ``from_checkpoint``).
    self.total_training_steps: int = 0
    self.lr_warm_up_steps: int = 0
    self.lr_cool_down_steps: int = 0
    self.k_warmup_steps: int = 0
    self.k_cold_booting_steps: int = 0
    self.l1_coefficient_warmup_steps: int = 0
    self.checkpoint_thresholds: list[int] = []
    # Lazily-initialized training components.
    self.optimizer: Optimizer | None = None
    self.scheduler: lr_scheduler.LRScheduler | None = None
    self.wandb_logger: Run | None = None
    self.metrics: list[Metric] = []
    # Dead-latent tracking used by the Aux-K auxiliary loss.
    self.tokens_since_last_activation: Tensor | None = None
    self.is_dead: Tensor | None = None

save_checkpoint

save_checkpoint(
    sae: AbstractSparseAutoEncoder,
    checkpoint_path: Path | str,
) -> None

Save a complete checkpoint including model, optimizer, scheduler, and trainer state.

Parameters:

Name Type Description Default
sae AbstractSparseAutoEncoder

The sparse autoencoder model to save

required
checkpoint_path Path | str

Path where to save the checkpoint (without extension)

required
Source code in src/lm_saes/trainer.py
def save_checkpoint(self, sae: AbstractSparseAutoEncoder, checkpoint_path: Path | str) -> None:
    """
    Save a complete checkpoint including model, optimizer, scheduler, and
    trainer state.

    Everything is written under
    ``{checkpoint_path}/checkpoints/step_{cur_step}/``. Local (non-sharded)
    state is saved with ``torch.save`` on the primary rank only; sharded
    state (when ``sae.device_mesh`` is set) goes through
    ``torch.distributed.checkpoint`` (dcp) on every rank.

    Args:
        sae: The sparse autoencoder model to save
        checkpoint_path: Path where to save the checkpoint (without extension)
    """

    # One directory per training step; mkdir is idempotent so no need to
    # check for existence first.
    checkpoint_dir = Path(checkpoint_path) / "checkpoints" / f"step_{self.cur_step}"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    sae.cfg.save_hyperparameters(checkpoint_dir)
    # Model weights: plain safetensors for single-device runs, dcp when sharded.
    if sae.device_mesh is None:
        sae.save_checkpoint(checkpoint_dir / "sae_weights.safetensors")
    else:
        sae.save_checkpoint(checkpoint_dir / "sae_weights.dcp")

    if is_primary_rank(sae.device_mesh):
        # Trainer bookkeeping needed to resume training later.
        trainer_state = {
            "cur_step": self.cur_step,
            "cur_tokens": self.cur_tokens,
            "total_training_steps": self.total_training_steps,
            "lr_warm_up_steps": self.lr_warm_up_steps,
            "lr_cool_down_steps": self.lr_cool_down_steps,
            "k_warmup_steps": self.k_warmup_steps,
            "k_cold_booting_steps": self.k_cold_booting_steps,
            "l1_coefficient_warmup_steps": self.l1_coefficient_warmup_steps,
            "checkpoint_thresholds": self.checkpoint_thresholds,
            "cfg": self.cfg,
        }
        # Save trainer state
        trainer_path = checkpoint_dir / "trainer.pt"
        torch.save(trainer_state, trainer_path)
        # Persist the wandb run id so resuming can re-attach to the same run.
        if self.wandb_logger is not None:
            with open(checkpoint_dir / "wandb_run_id.json", "w") as f:
                json.dump({"wandb_run_id": self.wandb_logger.id}, f)

    def _save_component(state: dict, name: str) -> None:
        # Distributed tensors must be saved through dcp; local state is
        # written by the primary rank only via torch.save.
        if sae.device_mesh is None:
            if is_primary_rank(sae.device_mesh):
                torch.save(state, checkpoint_dir / f"{name}.pt")
        else:
            fs_writer = FileSystemWriter(checkpoint_dir / f"{name}.dcp")
            dcp.save(state, storage_writer=fs_writer)

    # Save optimizer / scheduler state - handle distributed tensors.
    if self.optimizer is not None:
        _save_component(self.optimizer.state_dict(), "optimizer")
    if self.scheduler is not None:
        _save_component(self.scheduler.state_dict(), "scheduler")

    # Report the actual directory written, not just the base path.
    logger.info(f"Checkpoint saved to {checkpoint_dir}")

from_checkpoint classmethod

from_checkpoint(
    sae: AbstractSparseAutoEncoder, checkpoint_path: str
) -> Trainer

Load a complete checkpoint including model, optimizer, scheduler, and trainer state.

Parameters:

Name Type Description Default
sae AbstractSparseAutoEncoder

The sparse autoencoder model to load the checkpoint state into

required
checkpoint_path str

Path where the checkpoint was saved (without extension)

required

Returns:

Name Type Description
Trainer Trainer

A new trainer instance with loaded state

Source code in src/lm_saes/trainer.py
@classmethod
def from_checkpoint(
    cls,
    sae: AbstractSparseAutoEncoder,
    checkpoint_path: str,
) -> "Trainer":
    """
    Load a complete checkpoint including model, optimizer, scheduler, and
    trainer state.

    Args:
        sae: The sparse autoencoder the optimizer/scheduler are rebuilt for;
            its ``device_mesh`` selects between local ``torch.load`` and
            distributed (dcp) loading.
        checkpoint_path: Path where the checkpoint was saved (without extension)

    Returns:
        Trainer: A new trainer instance with loaded state
    """
    # Load trainer state first to get the config
    checkpoint_dir = Path(checkpoint_path)
    trainer_path = checkpoint_dir / "trainer.pt"
    if os.path.exists(trainer_path):
        # NOTE(review): weights_only=False unpickles arbitrary objects (needed
        # for the stored TrainerConfig) — only load trusted checkpoints.
        trainer_state = torch.load(trainer_path, map_location="cpu", weights_only=False)
        cfg = trainer_state.get("cfg")
        if cfg is None:
            raise ValueError("Checkpoint does not contain trainer config")

        # Create trainer instance with loaded config
        trainer = cls(cfg)
        trainer.cfg.from_pretrained_path = checkpoint_path

        # Restore trainer state variables
        trainer.cur_step = trainer_state["cur_step"]
        trainer.cur_tokens = trainer_state["cur_tokens"]
        trainer.total_training_steps = trainer_state["total_training_steps"]
        trainer.lr_warm_up_steps = trainer_state["lr_warm_up_steps"]
        trainer.lr_cool_down_steps = trainer_state["lr_cool_down_steps"]
        trainer.k_warmup_steps = trainer_state["k_warmup_steps"]
        trainer.k_cold_booting_steps = trainer_state["k_cold_booting_steps"]
        trainer.l1_coefficient_warmup_steps = trainer_state["l1_coefficient_warmup_steps"]
        trainer.checkpoint_thresholds = trainer_state["checkpoint_thresholds"]

        logger.info(f"Loaded trainer state from step {trainer.cur_step}")
    else:
        raise ValueError(f"Trainer checkpoint not found at {trainer_path}")

    # Fresh optimizer/scheduler are built first, then overwritten with the
    # checkpointed state below.
    trainer._initialize_optimizer(sae)
    assert trainer.optimizer is not None and trainer.scheduler is not None, (
        "Optimizer and scheduler should be already initialized"
    )

    # Load optimizer state
    if sae.device_mesh is None:
        optimizer_path = checkpoint_dir / "optimizer.pt"
        optimizer_state = torch.load(optimizer_path, map_location="cpu")
        trainer.optimizer.load_state_dict(optimizer_state)
        logger.info("Loaded optimizer state")
    else:
        optimizer_path = checkpoint_dir / "optimizer.dcp"
        fs_reader = FileSystemReader(str(optimizer_path))
        # dcp loads in place into the current state dict, which is then fed
        # back through load_state_dict.
        optimizer_state = trainer.optimizer.state_dict()
        dcp.load(optimizer_state, storage_reader=fs_reader)
        trainer.optimizer.load_state_dict(optimizer_state)
        logger.info("Loaded optimizer state")
        logger.info(f"trainer.optimizer.state_dict(): {trainer.optimizer.state_dict()}")

    # Load scheduler state
    if sae.device_mesh is None:
        scheduler_path = checkpoint_dir / "scheduler.pt"
        scheduler_state = torch.load(scheduler_path, map_location="cpu")
        trainer.scheduler.load_state_dict(scheduler_state)
        logger.info("Loaded scheduler state")
    else:
        scheduler_path = checkpoint_dir / "scheduler.dcp"
        fs_reader = FileSystemReader(str(scheduler_path))
        scheduler_state = trainer.scheduler.state_dict()
        dcp.load(scheduler_state, storage_reader=fs_reader)
        trainer.scheduler.load_state_dict(scheduler_state)
        logger.info("Loaded scheduler state")
        logger.info(f"trainer.scheduler.state_dict(): {trainer.scheduler.state_dict()}")

    logger.info(f"Checkpoint loaded from {checkpoint_path}")
    return trainer

update_dead_statistics

update_dead_statistics(
    feature_acts: Tensor,
    mask: Tensor | None,
    specs: tuple[str, ...],
) -> Tensor

Update the dead latents tracking based on current feature activations.

Parameters:

Name Type Description Default
feature_acts Tensor

Feature activations tensor of shape (batch, d_sae) or (batch, seq_len, d_sae)

required

Returns:

Name Type Description
is_dead Tensor

Boolean tensor indicating which features are dead.

Source code in src/lm_saes/trainer.py
@torch.no_grad()
def update_dead_statistics(self, feature_acts: Tensor, mask: Tensor | None, specs: tuple[str, ...]) -> Tensor:
    """Update the dead latents tracking based on current feature activations.

    Args:
        feature_acts: Feature activations tensor of shape (batch, d_sae) or (batch, seq_len, d_sae)
        mask: Optional token mask; when given, only masked-in tokens count
            toward the valid-token total and the activation sum.
        specs: Dimension specs forwarded to `apply_token_mask` (presumably
            naming the dims of `feature_acts` — TODO confirm).

    Returns:
        is_dead: Boolean tensor indicating which features are dead.
    """
    assert self.tokens_since_last_activation is not None, (
        "tokens_since_last_activation must be initialized before calling update_dead_statistics"
    )
    assert self.is_dead is not None, "is_dead must be initialized before calling update_dead_statistics"

    # Number of tokens seen in this batch (masked count if a mask is given).
    valid_tokens = mask.sum() if mask is not None else feature_acts[..., 0].numel()

    # A feature counts as activated if it fired on any valid token.
    feature_acts_sum, _ = apply_token_mask(feature_acts, specs, mask, "sum")
    activated = feature_acts_sum.gt(0)

    # Reset the counter for features that fired; advance it by this batch's
    # token count for those that did not.
    self.tokens_since_last_activation = torch.where(
        activated,
        torch.zeros_like(self.tokens_since_last_activation),
        self.tokens_since_last_activation + valid_tokens,
    )
    # Dead = silent for at least cfg.dead_threshold tokens.
    self.is_dead = self.tokens_since_last_activation >= self.cfg.dead_threshold
    return self.is_dead

WandbConfig pydantic-model

Bases: BaseConfig

Fields:

  • wandb_project (str)
  • exp_name (str | None)
  • wandb_entity (str | None)
  • wandb_run_id (str | None)
  • wandb_resume (Literal['allow', 'must', 'never', 'auto'])

InitializerConfig pydantic-model

Bases: BaseConfig

Fields:

  • bias_init_method (Literal['all_zero', 'geometric_median'])
  • decoder_uniform_bound (float)
  • encoder_uniform_bound (float)
  • init_encoder_with_decoder_transpose (bool)
  • init_encoder_with_decoder_transpose_factor (float)
  • init_log_jumprelu_threshold_value (float | None)
  • grid_search_init_norm (bool)
  • initialize_W_D_with_active_subspace (bool)
  • d_active_subspace (int | None)
  • initialize_lorsa_with_mhsa (bool | None)
  • initialize_tc_with_mlp (bool | None)
  • model_layer (int | None)
  • init_encoder_bias_with_mean_hidden_pre (bool)

Initializer

Initializer(cfg: InitializerConfig)
Source code in src/lm_saes/initializer.py
def __init__(self, cfg: InitializerConfig):
    """Store the initializer configuration for later use."""
    self.cfg = cfg

initialize_parameters

initialize_parameters(sae: AbstractSparseAutoEncoder)

Initialize the parameters of the SAE. Only used when the state is "training" to initialize sae.

Source code in src/lm_saes/initializer.py
@torch.no_grad()
def initialize_parameters(self, sae: AbstractSparseAutoEncoder):
    """Apply the configured random initialization to *sae* and return it.

    Only used when the state is "training" to initialize a fresh SAE.
    """
    cfg = self.cfg
    sae.init_parameters(
        encoder_uniform_bound=cfg.encoder_uniform_bound,
        decoder_uniform_bound=cfg.decoder_uniform_bound,
        init_log_jumprelu_threshold_value=cfg.init_log_jumprelu_threshold_value,
    )

    # Optionally tie the encoder weights to the (scaled) decoder transpose.
    if cfg.init_encoder_with_decoder_transpose:
        sae.init_encoder_with_decoder_transpose(cfg.init_encoder_with_decoder_transpose_factor)

    return sae
initialization_search(
    sae: AbstractSparseAutoEncoder,
    activation_batch: Dict[str, Tensor],
    wandb_logger: Run | None = None,
)

This function is used to search for the best initialization norm for the SAE decoder.

Source code in src/lm_saes/initializer.py
@torch.no_grad()
def initialization_search(
    self,
    sae: AbstractSparseAutoEncoder,
    activation_batch: Dict[str, Tensor],
    wandb_logger: Run | None = None,
):
    """
    This function is used to search for the best initialization norm for the SAE decoder.

    Also performs the data-dependent parts of initialization: decoder bias
    from the mean activation (when configured) and encoder bias from the
    mean hidden pre-activation. Returns the (mutated) *sae*.
    """
    batch = sae.normalize_activations(activation_batch)

    # Initialize the decoder bias to the mean of the (normalized) target
    # activations — a stand-in for the geometric median.
    if self.cfg.bias_init_method == "geometric_median":
        assert sae.b_D is not None, "Decoder bias should exist if use_decoder_bias is True"
        if isinstance(sae, CrossLayerTranscoder):
            # One bias row per output hook point.
            for i in range(sae.cfg.n_layers):
                hook_point_out = sae.cfg.hook_points_out[i]
                normalized_mean_activation = batch[hook_point_out].mean(0)
                sae.b_D[i].copy_(normalized_mean_activation)
        elif (
            isinstance(sae, MixtureOfLinearTransform)
            or isinstance(sae, LowRankSparseAttention)
            or isinstance(sae, SparseAutoEncoder)
        ):
            # Mean over all leading (non-feature) dims of the label tensor.
            label = sae.prepare_label(batch)
            normalized_mean_activation = label.mean(dim=list(range((batch[sae.cfg.hook_point_out].ndim - 1))))
            sae.b_D.copy_(normalized_mean_activation)
        else:
            raise ValueError(
                f"Bias initialization method {self.cfg.bias_init_method} is not supported for {sae.cfg.sae_type}"
            )

    if self.cfg.init_encoder_bias_with_mean_hidden_pre:
        sae.init_encoder_bias_with_mean_hidden_pre(batch)

    @torch.autocast(device_type=sae.cfg.device, dtype=sae.cfg.dtype)
    def grid_search_best_init_norm(search_range: List[float]) -> float:
        """Return the decoder norm from *search_range* with the lowest
        reconstruction MSE on the current batch."""
        losses: Dict[float, float] = {}

        for norm in search_range:
            sae.set_decoder_to_fixed_norm(norm, force_exact=True)
            # Re-derive encoder weights/bias after each decoder rescale so the
            # measured loss reflects the full initialization recipe.
            if self.cfg.init_encoder_with_decoder_transpose:
                sae.init_encoder_with_decoder_transpose(self.cfg.init_encoder_with_decoder_transpose_factor)
            if self.cfg.init_encoder_bias_with_mean_hidden_pre:
                sae.init_encoder_bias_with_mean_hidden_pre(batch)
            mse = item(sae.compute_loss(batch)["l_rec"].mean())
            losses[norm] = mse
        best_norm = min(losses, key=losses.get)  # type: ignore
        return best_norm

    if self.cfg.grid_search_init_norm:
        # Coarse pass over [0.1, 5.0], then a fine pass around the winner.
        best_norm_coarse = grid_search_best_init_norm(torch.linspace(0.1, 5.0, 50).numpy().tolist())
        best_norm_fine_grained = grid_search_best_init_norm(
            torch.linspace(best_norm_coarse - 0.09, best_norm_coarse + 0.1, 20).numpy().tolist()
        )

        logger.info(f"The best (i.e. lowest MSE) initialized norm is {best_norm_fine_grained}")
        if wandb_logger is not None:
            wandb_logger.log({"best_norm_fine_grained": best_norm_fine_grained})

        sae.set_decoder_to_fixed_norm(best_norm_fine_grained, force_exact=True)

    # Final encoder initialization against the chosen decoder norm.
    if self.cfg.init_encoder_with_decoder_transpose:
        sae.init_encoder_with_decoder_transpose(self.cfg.init_encoder_with_decoder_transpose_factor)
    if self.cfg.init_encoder_bias_with_mean_hidden_pre:
        sae.init_encoder_bias_with_mean_hidden_pre(batch)

    return sae

initialize_sae_from_config

initialize_sae_from_config(
    cfg: BaseSAEConfig,
    activation_stream: Iterable[dict[str, Tensor]]
    | None = None,
    activation_norm: dict[str, float] | None = None,
    device_mesh: DeviceMesh | None = None,
    wandb_logger: Run | None = None,
    model: LanguageModel | None = None,
)

Initialize the SAE from the SAE config. Args: cfg (SAEConfig): The SAE config. activation_iter (Iterable[dict[str, Tensor]] | None): The activation iterator. activation_norm (dict[str, float] | None): The activation normalization. Used for dataset-wise normalization when self.cfg.norm_activation is "dataset-wise". device_mesh (DeviceMesh | None): The device mesh.

Source code in src/lm_saes/initializer.py
def initialize_sae_from_config(
    self,
    cfg: BaseSAEConfig,
    activation_stream: Iterable[dict[str, Tensor]] | None = None,
    activation_norm: dict[str, float] | None = None,
    device_mesh: DeviceMesh | None = None,
    wandb_logger: Run | None = None,
    model: LanguageModel | None = None,
):
    """
    Initialize the SAE from the SAE config.
    Args:
        cfg (SAEConfig): The SAE config.
        activation_stream (Iterable[dict[str, Tensor]] | None): The activation stream; required for dataset-wise normalization (when no activation_norm is given) and for the final initialization search.
        activation_norm (dict[str, float] | None): The activation normalization. Used for dataset-wise normalization when self.cfg.norm_activation is "dataset-wise".
        device_mesh (DeviceMesh | None): The device mesh.
        wandb_logger (Run | None): Optional wandb run for logging the initialization search.
        model (LanguageModel | None): The language model; required when initializing from MHSA/MLP weights or active subspaces.
    """
    sae: AbstractSparseAutoEncoder = AbstractSparseAutoEncoder.from_config(
        cfg,
        device_mesh=device_mesh,
    )

    # Random parameter initialization per the initializer config.
    sae = self.initialize_parameters(sae)
    if sae.cfg.norm_activation == "dataset-wise":
        # Compute activation norms from the stream unless they were supplied.
        if activation_norm is None:
            assert activation_stream is not None, (
                "Activation iterator must be provided for dataset-wise normalization"
            )

            activation_norm = calculate_activation_norm(
                activation_stream, cfg.associated_hook_points, device_mesh=device_mesh
            )
        sae.set_dataset_average_activation_norm(activation_norm)

    # Optionally copy the model's multi-head attention weights into a Lorsa SAE.
    if isinstance(sae, LowRankSparseAttention) and self.cfg.initialize_lorsa_with_mhsa:
        assert sae.cfg.norm_activation == "dataset-wise", (
            "Norm activation must be dataset-wise for Lorsa if use initialize_lorsa_with_mhsa"
        )
        assert isinstance(model, TransformerLensLanguageModel) and model.model is not None, (
            "Only support TransformerLens backend for initializing Lorsa with Original Multi Head Sparse Attention"
        )
        assert self.cfg.model_layer is not None, (
            "Model layer must be provided for initializing Lorsa with Original Multi Head Sparse Attention"
        )
        assert isinstance(model.model, HookedTransformer), "Model must be a TransformerLens model"
        assert isinstance(model.model.blocks[self.cfg.model_layer], TransformerBlock), (
            "Block must be a TransformerBlock"
        )
        assert isinstance(model.model.blocks[self.cfg.model_layer].attn, Attention | GroupedQueryAttention), (
            "Attention must be an Attention or GroupedQueryAttention"
        )
        sae.init_lorsa_with_mhsa(
            cast(
                Attention | GroupedQueryAttention,
                model.model.blocks[self.cfg.model_layer].attn,
            )
        )

    # A single batch is needed for the data-dependent initialization below.
    assert activation_stream is not None, "Activation iterator must be provided for initialization search"
    activation_batch = next(iter(activation_stream))  # type: ignore

    # Optionally initialize a transcoder (distinct in/out hook points) from
    # the model's MLP weights.
    if (
        isinstance(sae, SparseAutoEncoder)
        and sae.cfg.hook_point_in != sae.cfg.hook_point_out
        and self.cfg.initialize_tc_with_mlp
    ):
        batch = sae.normalize_activations(activation_batch)
        assert sae.cfg.norm_activation == "dataset-wise"
        assert isinstance(model, TransformerLensLanguageModel) and model.model is not None
        assert self.cfg.model_layer is not None
        assert isinstance(model.model, HookedTransformer), "Model must be a TransformerLens model"
        assert isinstance(model.model.blocks[self.cfg.model_layer], TransformerBlock), (
            "Block must be a TransformerBlock"
        )
        assert isinstance(model.model.blocks[self.cfg.model_layer].mlp, CanBeUsedAsMLP)
        sae.init_tc_with_mlp(
            batch=batch,
            mlp=cast(CanBeUsedAsMLP, model.model.blocks[self.cfg.model_layer].mlp),
        )

    # Optionally initialize decoder weights from an active subspace of the data.
    if self.cfg.initialize_W_D_with_active_subspace:
        batch = sae.normalize_activations(activation_batch)
        if isinstance(sae, LowRankSparseAttention):
            # Lorsa: per-head V initialization from the attention module.
            assert sae.cfg.norm_activation == "dataset-wise", (
                "Norm activation must be dataset-wise for Lorsa if use initialize_W_D_with_active_subspace"
            )
            assert isinstance(model, TransformerLensLanguageModel) and model.model is not None, (
                "Only support TransformerLens backend for initializing Lorsa decoder weight with active subspace"
            )
            assert self.cfg.model_layer is not None, (
                "Model layer must be provided for initializing Lorsa decoder weight with active subspace"
            )
            sae.init_W_V_with_active_subspace_per_head(
                batch=batch,
                mhsa=cast(
                    Attention | GroupedQueryAttention,
                    model.model.blocks[self.cfg.model_layer].attn,
                ),
            )
        else:
            assert self.cfg.d_active_subspace is not None, (
                "d_active_subspace must be provided for initializing other SAEs with active subspace"
            )
            sae.init_W_D_with_active_subspace(batch=batch, d_active_subspace=self.cfg.d_active_subspace)

    # Final data-dependent pass: bias init and decoder-norm grid search.
    sae = self.initialization_search(sae, activation_batch, wandb_logger=wandb_logger)

    return sae