Training
Training infrastructure: trainer, optimizer configs, initialization, and logging.
TrainerConfig
pydantic-model
Bases: BaseConfig
Config:
arbitrary_types_allowed: True
Fields:
- l1_coefficient (float | None)
- l1_coefficient_warmup_steps (int | float)
- lp_coefficient (float | None)
- auxk_coefficient (float | None)
- amp_dtype (dtype | None)
- sparsity_loss_type (Literal['power', 'tanh', 'tanh-quad', None])
- tanh_stretch_coefficient (float)
- frequency_scale (float)
- p (int)
- initial_k (int | float | None)
- k_warmup_steps (int | float)
- k_cold_booting_steps (int | float)
- k_schedule_type (Literal['linear', 'exponential'])
- k_exponential_factor (float)
- k_aux (int)
- dead_threshold (float)
- skip_metrics_calculation (bool)
- gradient_accumulation_steps (int)
- lr (float | dict[str, float])
- betas (Tuple[float, float])
- optimizer_class (Literal['adam', 'sparseadam'])
- optimizer_foreach (bool)
- lr_scheduler_name (Literal['constant', 'constantwithwarmup', 'linearwarmupdecay', 'cosineannealing', 'cosineannealingwarmup', 'exponentialwarmup'])
- lr_end_ratio (float)
- lr_warm_up_steps (int | float)
- lr_cool_down_steps (int | float)
- jumprelu_lr_factor (float)
- clip_grad_norm (float)
- feature_sampling_window (int)
- total_training_tokens (int)
- log_frequency (int)
- eval_frequency (int)
- n_checkpoints (int)
- check_point_save_mode (Literal['log', 'linear'])
- from_pretrained_path (str | None)
- exp_result_path (str)
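The `initial_k`, `k_warmup_steps`, `k_schedule_type`, and `k_exponential_factor` fields suggest that the top-k sparsity level is annealed from a loose initial value down to the target k during warmup. The sketch below illustrates how such a schedule might behave; the function name and semantics are assumptions for illustration, not the library's actual implementation.

```python
# Hypothetical sketch of a top-k warmup schedule driven by fields like
# `initial_k`, `k_warmup_steps`, and `k_schedule_type`. Assumed
# semantics: anneal k from initial_k down to final_k over the warmup.

def k_at_step(step: int, initial_k: int, final_k: int,
              k_warmup_steps: int, schedule: str = "linear") -> int:
    """Return the k to use at a given training step."""
    if step >= k_warmup_steps:
        return final_k
    frac = step / k_warmup_steps
    if schedule == "linear":
        k = initial_k + (final_k - initial_k) * frac
    elif schedule == "exponential":
        # Geometric interpolation between initial_k and final_k.
        k = initial_k * (final_k / initial_k) ** frac
    else:
        raise ValueError(f"unknown schedule: {schedule}")
    return max(final_k, round(k))
```

With `initial_k=512`, `final_k=64`, and 100 warmup steps, the linear schedule gives k = 288 at step 50, while the exponential schedule gives k = 181, decaying faster early on.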
l1_coefficient
pydantic-field
Coefficient for the L1 sparsity loss, which penalizes the L1 norm of the feature activations to encourage sparsity.
l1_coefficient_warmup_steps
pydantic-field
Steps (int) or fraction of total steps (float) to warm up the sparsity coefficient from 0.
lp_coefficient
pydantic-field
Coefficient for the \(L^p\) sparsity loss, which penalizes the \(L^p\) norm of the feature activations to encourage sparsity. To use the JumpReLU \(L^p\) penalty, set lp_coefficient to a positive value.
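As a minimal sketch of what an \(L^p\) penalty computes, assuming the loss is the coefficient times the sum of \(|a_i|^p\) over feature activations (the library's actual loss operates on tensors and may differ in detail):

```python
# Illustrative stand-in for an L^p sparsity penalty controlled by
# `lp_coefficient` and `p`; not the library's actual implementation.

def lp_penalty(feature_acts: list[float], p: int, lp_coefficient: float) -> float:
    """Return lp_coefficient * sum_i |a_i|^p for one activation vector."""
    return lp_coefficient * sum(abs(a) ** p for a in feature_acts)
```

For example, with p = 2 and coefficient 0.5, activations [1.0, -2.0, 0.0] yield a penalty of 0.5 * (1 + 4 + 0) = 2.5. Values of p below 1 penalize small activations relatively more and thus push harder toward exact zeros.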
Trainer
Trainer(cfg: TrainerConfig)
Source code in src/lm_saes/trainer.py
save_checkpoint
Save a complete checkpoint including model, optimizer, scheduler, and trainer state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| sae | AbstractSparseAutoEncoder | The sparse autoencoder model to save | required |
| checkpoint_path | Path \| str | Path where to save the checkpoint (without extension) | required |
Source code in src/lm_saes/trainer.py
from_checkpoint
classmethod
from_checkpoint(
sae: AbstractSparseAutoEncoder, checkpoint_path: str
) -> Trainer
Load a complete checkpoint including model, optimizer, scheduler, and trainer state.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| device_mesh |  | The device mesh to load the model into | required |
| checkpoint_path | str | Path where the checkpoint was saved (without extension) | required |
Returns:

| Name | Type | Description |
|---|---|---|
| Trainer | Trainer | A new trainer instance with loaded state |
Source code in src/lm_saes/trainer.py
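`save_checkpoint` and `from_checkpoint` together form a round trip: bundle model, optimizer, scheduler, and trainer state into one checkpoint, then restore all of it at once. The sketch below illustrates that pattern generically; the real Trainer serializes torch state dicts, so the plain dicts and JSON file here are illustrative stand-ins only.

```python
import json
import tempfile
from pathlib import Path

# Generic sketch of the checkpoint round-trip pattern: everything
# needed to resume training goes into a single payload.

def save_checkpoint(path: Path, model_state: dict, optim_state: dict,
                    trainer_state: dict) -> None:
    payload = {"model": model_state, "optimizer": optim_state,
               "trainer": trainer_state}
    path.write_text(json.dumps(payload))

def load_checkpoint(path: Path) -> dict:
    return json.loads(path.read_text())

ckpt = Path(tempfile.mkdtemp()) / "step_1000.json"
save_checkpoint(ckpt, {"W": [0.1]}, {"lr": 4e-4}, {"step": 1000})
restored = load_checkpoint(ckpt)
```

Saving the optimizer and trainer state alongside the model is what lets training resume at the exact step and learning-rate position rather than restarting the schedules from zero.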
update_dead_statistics
update_dead_statistics(
feature_acts: Tensor,
mask: Tensor | None,
specs: tuple[str, ...],
) -> Tensor
Update the dead latents tracking based on current feature activations.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| feature_acts | Tensor | Feature activations tensor of shape (batch, d_sae) or (batch, seq_len, d_sae) | required |
Returns:

| Name | Type | Description |
|---|---|---|
| is_dead | Tensor | Boolean tensor indicating which features are dead. |
Source code in src/lm_saes/trainer.py
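The `dead_threshold` and `feature_sampling_window` fields suggest that a feature counts as dead when its firing rate over a window of recent samples falls below the threshold. The list-based sketch below illustrates that bookkeeping; the real method operates on torch tensors (including sequence-shaped inputs with a mask), so treat this as an assumed simplification, not the library's code.

```python
# Hypothetical sketch of dead-latent tracking in the spirit of
# `update_dead_statistics`: count how often each feature fires,
# then flag features whose firing rate is below `dead_threshold`.

def update_dead_statistics(fire_counts: list[int],
                           feature_acts_batch: list[list[float]]) -> list[int]:
    """Accumulate, per feature, how many samples it fired on."""
    for acts in feature_acts_batch:      # one activation vector per sample
        for j, a in enumerate(acts):
            if a > 0.0:
                fire_counts[j] += 1
    return fire_counts

def dead_mask(fire_counts: list[int], n_samples: int,
              dead_threshold: float) -> list[bool]:
    """A feature is 'dead' if its firing rate is below dead_threshold."""
    return [c / n_samples < dead_threshold for c in fire_counts]

counts = update_dead_statistics([0, 0, 0],
                                [[0.5, 0.0, 0.0], [1.2, 0.0, 0.3]])
```

Here feature 1 never fires across the two samples, so with `dead_threshold=0.25` only it is flagged dead. Dead features are typically the ones targeted by the auxiliary top-k loss (`k_aux`, `auxk_coefficient`) to revive them.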
WandbConfig
pydantic-model
Bases: BaseConfig
Fields:
- wandb_project (str)
- exp_name (str | None)
- wandb_entity (str | None)
- wandb_run_id (str | None)
- wandb_resume (Literal['allow', 'must', 'never', 'auto'])
InitializerConfig
pydantic-model
Bases: BaseConfig
Fields:
- bias_init_method (Literal['all_zero', 'geometric_median'])
- decoder_uniform_bound (float)
- encoder_uniform_bound (float)
- init_encoder_with_decoder_transpose (bool)
- init_encoder_with_decoder_transpose_factor (float)
- init_log_jumprelu_threshold_value (float | None)
- grid_search_init_norm (bool)
- initialize_W_D_with_active_subspace (bool)
- d_active_subspace (int | None)
- initialize_lorsa_with_mhsa (bool | None)
- initialize_tc_with_mlp (bool | None)
- model_layer (int | None)
- init_encoder_bias_with_mean_hidden_pre (bool)
Initializer
Initializer(cfg: InitializerConfig)
Source code in src/lm_saes/initializer.py
initialize_parameters
Initialize the parameters of the SAE. Only used when the state is "training", to initialize the SAE from scratch.
Source code in src/lm_saes/initializer.py
initialization_search
initialization_search(
sae: AbstractSparseAutoEncoder,
activation_batch: Dict[str, Tensor],
wandb_logger: Run | None = None,
)
Search for the best initialization norm for the SAE decoder.
Source code in src/lm_saes/initializer.py
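The core idea behind an initialization norm search can be sketched as a grid search: try several candidate decoder scales, score each by reconstruction error on an activation batch, and keep the best. The toy model below uses a scalar pass-through "SAE" purely for illustration; the real function operates on the full AbstractSparseAutoEncoder and can log candidates to the wandb run, so every name and detail here is an assumption.

```python
# Hypothetical sketch of the idea behind `initialization_search`:
# pick the decoder norm scale that minimizes reconstruction MSE.

def reconstruction_mse(scale: float, activations: list[float]) -> float:
    # Toy reconstruction x_hat = scale * x, so the MSE measures how far
    # the candidate scale is from perfect (identity) reconstruction.
    return sum((scale * x - x) ** 2 for x in activations) / len(activations)

def search_init_norm(activations: list[float],
                     candidate_scales: list[float]) -> float:
    """Return the candidate decoder scale with the lowest reconstruction MSE."""
    return min(candidate_scales, key=lambda s: reconstruction_mse(s, activations))

best = search_init_norm([1.0, 2.0, -3.0], [0.25, 0.5, 1.0, 2.0])
```

In this toy setup the scale 1.0 reconstructs the batch exactly and wins the search; in practice the search trades off reconstruction against the sparsity penalty at the chosen initialization.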
initialize_sae_from_config
initialize_sae_from_config(
cfg: BaseSAEConfig,
activation_stream: Iterable[dict[str, Tensor]]
| None = None,
activation_norm: dict[str, float] | None = None,
device_mesh: DeviceMesh | None = None,
wandb_logger: Run | None = None,
model: LanguageModel | None = None,
)
Initialize the SAE from the SAE config.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| cfg | BaseSAEConfig | The SAE config. | required |
| activation_stream | Iterable[dict[str, Tensor]] \| None | The activation iterator. | None |
| activation_norm | dict[str, float] \| None | The activation normalization. Used for dataset-wise normalization when self.cfg.norm_activation is "dataset-wise". | None |
| device_mesh | DeviceMesh \| None | The device mesh. | None |
Source code in src/lm_saes/initializer.py