Training
Training infrastructure: trainer, optimizer configs, initialization, and logging.
TrainerConfig
pydantic-model
Bases: BaseConfig
Config:
- arbitrary_types_allowed: True

Fields:
- l1_coefficient (float | None)
- l1_coefficient_warmup_steps (int | float)
- lp_coefficient (float | None)
- auxk_coefficient (float | None)
- amp_dtype (dtype | None)
- sparsity_loss_type (Literal['power', 'tanh', 'tanh-quad', None])
- tanh_stretch_coefficient (float)
- frequency_scale (float)
- p (int)
- initial_k (int | float | None)
- k_warmup_steps (int | float)
- k_cold_booting_steps (int | float)
- k_schedule_type (Literal['linear', 'exponential'])
- k_exponential_factor (float)
- k_aux (int)
- dead_threshold (float)
- skip_metrics_calculation (bool)
- gradient_accumulation_steps (int)
- lr (float | dict[str, float])
- betas (tuple[float, float])
- optimizer_class (Literal['adam', 'sparseadam'])
- optimizer_foreach (bool)
- lr_scheduler_name (Literal['constant', 'constantwithwarmup', 'linearwarmupdecay', 'cosineannealing', 'cosineannealingwarmup', 'exponentialwarmup'])
- lr_end_ratio (float)
- lr_warm_up_steps (int | float)
- lr_cool_down_steps (int | float)
- jumprelu_lr_factor (float)
- clip_grad_norm (float)
- feature_sampling_window (int)
- total_training_tokens (int)
- log_frequency (int)
- eval_frequency (int)
- n_checkpoints (int)
- check_point_save_mode (Literal['log', 'linear'])
- from_pretrained_path (str | None)
- exp_result_path (str)
l1_coefficient
pydantic-field
Coefficient for the L1 sparsity loss. This loss penalizes the magnitude of the feature activations to encourage sparsity.
l1_coefficient_warmup_steps
pydantic-field
Steps (int) or fraction of total steps (float) to warm up the sparsity coefficient from 0.
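Fields typed int | float above follow this convention: an int is an absolute step count, a float is a fraction of total training steps. A minimal sketch of that resolution plus the linear coefficient warmup (an illustration of the documented behavior, not the library's actual code):

```python
def resolve_steps(steps, total_steps):
    """int | float convention: an int is an absolute step count,
    a float is a fraction of total training steps."""
    return int(steps * total_steps) if isinstance(steps, float) else steps

def l1_coeff_at_step(step, l1_coefficient, warmup_steps, total_steps):
    """Linearly warm the sparsity coefficient from 0 up to its final value."""
    n = resolve_steps(warmup_steps, total_steps)
    if n <= 0 or step >= n:
        return l1_coefficient
    return l1_coefficient * step / n
```

For example, with 10,000 total steps, `l1_coefficient_warmup_steps=0.1` resolves to 1,000 warmup steps, and the coefficient reaches half its final value at step 500.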
lp_coefficient
pydantic-field
Coefficient for the Lp sparsity loss. This loss penalizes the Lp norm of the feature activations. To use the JumpReLU \(L^p\) penalty, set lp_coefficient to a positive value.
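Several fields above (initial_k, k_warmup_steps, k_schedule_type) suggest a TopK warmup in which k decays from initial_k down to the target k over training. A hypothetical sketch of the two k_schedule_type options; the library's exact formulas (including its use of k_exponential_factor) may differ:

```python
def k_at_step(step, k_target, initial_k, warmup_steps, schedule="linear"):
    """Decay k from initial_k to k_target over warmup_steps.

    Hypothetical reading of k_schedule_type; the library's exact formula
    (and its use of k_exponential_factor) may differ.
    """
    if warmup_steps <= 0 or step >= warmup_steps:
        return k_target
    progress = step / warmup_steps  # 0.0 -> 1.0 over the warmup
    if schedule == "linear":
        k = initial_k + (k_target - initial_k) * progress
    elif schedule == "exponential":
        # Geometric interpolation: k falls faster early in training.
        k = initial_k * (k_target / initial_k) ** progress
    else:
        raise ValueError(f"unknown k_schedule_type: {schedule}")
    return max(int(round(k)), k_target)
```

Starting with a large k and shrinking it avoids a flood of dead features early in training, when the dictionary is still poorly fit.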
Trainer
Trainer(cfg: TrainerConfig)
Source code in src/lm_saes/trainer.py
save_checkpoint
save_checkpoint(
sae: SparseDictionary, checkpoint_path: Path | str
) -> None
Save a complete checkpoint including model, optimizer, scheduler, and trainer state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| sae | SparseDictionary | The sparse autoencoder model to save | required |
| checkpoint_path | Path \| str | Path where to save the checkpoint (without extension) | required |
Source code in src/lm_saes/trainer.py
from_checkpoint
classmethod
from_checkpoint(
sae: SparseDictionary, checkpoint_path: str
) -> Trainer
Load a complete checkpoint including model, optimizer, scheduler, and trainer state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| sae | SparseDictionary | The SAE model instance. | required |
| checkpoint_path | str | Path where the checkpoint was saved (without extension). | required |

Returns:

| Name | Type | Description |
|---|---|---|
| Trainer | Trainer | A new trainer instance with loaded state. |
Source code in src/lm_saes/trainer.py
update_dead_statistics
update_dead_statistics(
feature_acts: Tensor,
mask: Tensor | None,
specs: tuple[str, ...],
) -> Tensor
Update the dead latents tracking based on current feature activations.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| feature_acts | Tensor | Feature activations tensor of shape (batch, d_sae) or (batch, seq_len, d_sae) | required |

Returns:

| Name | Type | Description |
|---|---|---|
| is_dead | Tensor | Boolean tensor indicating which features are dead. |
Source code in src/lm_saes/trainer.py
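The bookkeeping behind update_dead_statistics can be illustrated without tensors: per feature, count consecutive updates in which it never fired above dead_threshold, and flag it dead once the count reaches a window. A pure-Python sketch (the dead_window parameter is illustrative; the library's tensor-based version and its exact window differ):

```python
def update_dead_statistics(steps_since_fired, batch_acts, dead_threshold, dead_window):
    """Per feature, count consecutive updates with no activation above
    dead_threshold; a feature is 'dead' once the count reaches dead_window.

    steps_since_fired: mutable per-feature counters (updated in place).
    batch_acts: one per-feature activation list per example in the batch.
    Returns the boolean is_dead mask.
    """
    n_features = len(steps_since_fired)
    fired = [False] * n_features
    for acts in batch_acts:
        for j, a in enumerate(acts):
            if a > dead_threshold:
                fired[j] = True
    for j in range(n_features):
        steps_since_fired[j] = 0 if fired[j] else steps_since_fired[j] + 1
    return [s >= dead_window for s in steps_since_fired]
```

The is_dead mask is what makes auxiliary losses like auxk possible: they route reconstruction error through the dead features to revive them.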
WandbConfig
pydantic-model
Bases: BaseConfig
Fields:
- wandb_project (str)
- exp_name (str | None)
- wandb_entity (str | None)
- wandb_run_id (str | None)
- wandb_resume (Literal['allow', 'must', 'never', 'auto'])
InitializerConfig
pydantic-model
Bases: BaseConfig
Fields:
- bias_init_method (Literal['all_zero', 'geometric_median'])
- decoder_uniform_bound (float)
- encoder_uniform_bound (float)
- init_encoder_with_decoder_transpose (bool)
- init_encoder_with_decoder_transpose_factor (float)
- init_log_jumprelu_threshold_value (float | None)
- grid_search_init_norm (bool)
- initialize_W_D_with_active_subspace (bool)
- d_active_subspace (int | None)
- initialize_lorsa_with_mhsa (bool | None)
- initialize_tc_with_mlp (bool | None)
- model_layer (int | None)
- init_encoder_bias_with_mean_hidden_pre (bool)
bias_init_method
pydantic-field
Method for initializing the decoder bias. "geometric_median" sets the bias to the geometric median of the activation distribution, which is more robust than "all_zero" for skewed activations.
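For intuition: the geometric median minimizes the sum of Euclidean distances to the activation points, which is why it resists outliers better than the mean. It has no closed form but can be approximated with Weiszfeld's algorithm; a pure-Python sketch (the library's version operates on activation tensors):

```python
import math

def geometric_median(points, n_iter=100, eps=1e-8):
    """Weiszfeld's algorithm: iteratively re-average the points, weighting
    each by the inverse of its distance to the current estimate."""
    dim = len(points[0])
    # Start from the coordinate-wise mean.
    est = [sum(p[i] for p in points) / len(points) for i in range(dim)]
    for _ in range(n_iter):
        weights = [1.0 / max(math.dist(p, est), eps) for p in points]
        total = sum(weights)
        est = [sum(w * p[i] for w, p in zip(weights, points)) / total
               for i in range(dim)]
    return est
```

With three points clustered near the origin and one far-away outlier, the mean is dragged toward the outlier while the geometric median stays inside the cluster, which is exactly the robustness property the docstring above refers to.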
decoder_uniform_bound
pydantic-field
Half-range of the uniform distribution used to initialize decoder weights; weights are sampled from U(-decoder_uniform_bound, decoder_uniform_bound).
encoder_uniform_bound
pydantic-field
Half-range of the uniform distribution used to initialize encoder weights; weights are sampled from U(-encoder_uniform_bound, encoder_uniform_bound).
init_encoder_with_decoder_transpose
pydantic-field
If True, the encoder weight matrix is initialized as the transpose of the decoder weight matrix (scaled by init_encoder_with_decoder_transpose_factor), providing a better starting point for SAE training.
init_encoder_with_decoder_transpose_factor
pydantic-field
Scaling factor applied to the transposed decoder weights when initializing the encoder.
init_log_jumprelu_threshold_value
pydantic-field
Initial value for the log-threshold parameter of JumpReLU activations. Only used when the SAE uses a JumpReLU activation function.
grid_search_init_norm
pydantic-field
Performs a coarse-then-fine grid search over decoder norms to find the value that minimizes the initial reconstruction loss, then sets the decoder to that norm.
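The coarse-then-fine search described above can be sketched abstractly: evaluate a loss on a coarse grid of candidate norms, then refine around the best candidate. An illustrative sketch with a generic 1-D loss function (the actual search in src/lm_saes/initializer.py evaluates reconstruction loss on an activation batch, and its grid bounds differ):

```python
def grid_search_norm(loss_fn, lo=0.1, hi=10.0, n_coarse=20, n_fine=20):
    """Two-stage 1-D grid search: a coarse pass over [lo, hi], then a fine
    pass bracketing the best coarse candidate by one coarse step each side."""
    def best_on_grid(a, b, n):
        grid = [a + (b - a) * i / (n - 1) for i in range(n)]
        return min(grid, key=loss_fn)

    coarse_best = best_on_grid(lo, hi, n_coarse)
    step = (hi - lo) / (n_coarse - 1)
    return best_on_grid(max(lo, coarse_best - step),
                        min(hi, coarse_best + step), n_fine)
```

The two-stage structure keeps the number of loss evaluations small (n_coarse + n_fine) while still locating the minimum to within a fine-grid step.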
initialize_W_D_with_active_subspace
pydantic-field
Initializes the decoder weight matrix within the active (high-variance) subspace of the input activations via SVD. Recommended for low-rank activations such as attention outputs to reduce dead features.
d_active_subspace
pydantic-field
Dimension of the active subspace used when initialize_W_D_with_active_subspace=True.
initialize_lorsa_with_mhsa
pydantic-field
Initializes the Lorsa QK weights from the target model's attention (MHSA) weights at model_layer.
initialize_tc_with_mlp
pydantic-field
Initializes the transcoder decoder weights from the target model's MLP weights at model_layer.
model_layer
pydantic-field
Layer index of the target model from which to extract weights for initialize_lorsa_with_mhsa or initialize_tc_with_mlp.
Initializer
Initializer(cfg: InitializerConfig)
Source code in src/lm_saes/initializer.py
initialize_parameters
initialize_parameters(sae: SparseDictionary)
Initialize the parameters of the SAE. Only used to initialize the SAE when the state is "training".
Source code in src/lm_saes/initializer.py
initialization_search
initialization_search(
sae: SparseDictionary,
activation_batch: dict[str, Tensor],
wandb_logger: Run | None = None,
)
Search for the best initialization norm for the SAE decoder.
Source code in src/lm_saes/initializer.py
initialize_sae_from_config
initialize_sae_from_config(
cfg: SparseDictionaryConfig,
activation_stream: Iterable[dict[str, Tensor]]
| None = None,
activation_norm: dict[str, float] | None = None,
device_mesh: DeviceMesh | None = None,
wandb_logger: Run | None = None,
model: LanguageModel | None = None,
)
Initialize the SAE from the SAE config.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| cfg | SparseDictionaryConfig | The SAE config. | required |
| activation_stream | Iterable[dict[str, Tensor]] \| None | The activation iterator. | None |
| activation_norm | dict[str, float] \| None | The activation normalization. Used for dataset-wise normalization when self.cfg.norm_activation is "dataset-wise". | None |
| device_mesh | DeviceMesh \| None | The device mesh. | None |
Source code in src/lm_saes/initializer.py