Infrastructure
Language model backend, dataset, and database configuration.
LanguageModelConfig
pydantic-model
Bases: BaseModelConfig
Fields:
- device (str)
- dtype (dtype)
- model_name (str)
- model_from_pretrained_path (str | None)
- use_flash_attn (bool)
- cache_dir (str | None)
- local_files_only (bool)
- max_length (int)
- backend (Literal['huggingface', 'transformer_lens', 'auto'])
- load_ckpt (bool)
- tokenizer_only (bool)
- prepend_bos (bool)
- bos_token_id (int | None)
- eos_token_id (int | None)
- pad_token_id (int | None)
model_from_pretrained_path
pydantic-field
Path to the pretrained model. If None, the model is loaded from HuggingFace.
cache_dir
pydantic-field
The directory of the HuggingFace cache. Should have the same effect as HF_HOME.
local_files_only
pydantic-field
Whether to only load the model from the local files. Should have the same effect as HF_HUB_OFFLINE=1.
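The two fields above are described in terms of HuggingFace environment variables. As a rough illustration of that correspondence, the sketch below maps the field values to the variables they are said to mirror. The helper `hf_env_overrides` and the example path are hypothetical and not part of lm_saes. The mapping simply restates the field descriptions, not how the library applies them.

```python
from __future__ import annotations


def hf_env_overrides(cache_dir: str | None, local_files_only: bool) -> dict[str, str]:
    """Return the environment variables the field docs name as equivalents.

    Hypothetical helper for illustration only; it is not part of lm_saes.
    """
    env: dict[str, str] = {}
    if cache_dir is not None:
        # cache_dir is documented as having the same effect as HF_HOME.
        env["HF_HOME"] = cache_dir
    if local_files_only:
        # local_files_only is documented as having the same effect as HF_HUB_OFFLINE=1.
        env["HF_HUB_OFFLINE"] = "1"
    return env


# Example values are made up.
overrides = hf_env_overrides(cache_dir="/data/hf-cache", local_files_only=True)
print(overrides)
```

Leaving both fields at None and False yields an empty mapping, which corresponds to not overriding the HuggingFace defaults.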
backend
pydantic-field
The backend to use for the language model: 'huggingface', 'transformer_lens', or 'auto'.
bos_token_id
pydantic-field
The ID of the BOS token. If None, the default BOS token is used.
eos_token_id
pydantic-field
The ID of the EOS token. If None, the default EOS token is used.
pad_token_id
pydantic-field
The ID of the padding token. If None, the default padding token is used.
from_pretrained_sae
staticmethod
Load the LanguageModelConfig from a pretrained SAE name or path. Config is read from
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pretrained_name_or_path | str | The path to the pretrained SAE. | required |
| **kwargs | | Additional keyword arguments to pass to the LanguageModelConfig constructor. | {} |
Source code in src/lm_saes/backend/language_model.py
TransformerLensLanguageModel
TransformerLensLanguageModel(
cfg: LanguageModelConfig,
device_mesh: DeviceMesh | None = None,
)
Bases: LanguageModel
Source code in src/lm_saes/backend/language_model.py
run_with_cache_until
Run with activation caching, stopping at a given hook for efficiency.
Source code in src/lm_saes/backend/language_model.py
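The idea behind stopping at a hook can be sketched without any model code: run the layers in order, cache each output, and return as soon as the requested hook has been recorded so that later layers are never evaluated. The function `run_until_hook`, the layer list, and the hook names below are illustrative stand-ins, not the implementation in lm_saes.

```python
from typing import Callable, Dict, List, Tuple


def run_until_hook(
    layers: List[Tuple[str, Callable[[float], float]]],
    x: float,
    until: str,
) -> Dict[str, float]:
    """Run layers in order, caching each output, and stop after `until`."""
    cache: Dict[str, float] = {}
    for name, fn in layers:
        x = fn(x)
        cache[name] = x
        if name == until:
            # Later layers are never evaluated, which is where the saving comes from.
            return cache
    raise KeyError(f"hook {until!r} not found")


calls: List[str] = []


def make_layer(name: str, scale: float) -> Tuple[str, Callable[[float], float]]:
    """Build a toy layer that multiplies its input and records that it ran."""

    def fn(v: float) -> float:
        calls.append(name)
        return v * scale

    return name, fn


# Hook names are made up for the example.
layers = [make_layer("blocks.0", 2.0), make_layer("blocks.1", 3.0), make_layer("blocks.2", 5.0)]
cache = run_until_hook(layers, 1.0, until="blocks.1")
print(cache, calls)
```

In this toy run the third layer is never called, so the cache holds only the first two activations.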
DatasetConfig
pydantic-model
Bases: BaseConfig
Fields:
- dataset_name_or_path (str)
- cache_dir (str | None)
- is_dataset_on_disk (bool)
MongoDBConfig
pydantic-model
Bases: BaseModel
Fields:
- mongo_uri (str)
- mongo_db (str)
MongoClient
MongoClient(cfg: MongoDBConfig)
Source code in src/lm_saes/database.py
enable_gridfs
disable_gridfs
is_gridfs_enabled
update_sae
Update an SAE and all its references.
If the name is updated, all references in other collections are also updated within a transaction.
Source code in src/lm_saes/database.py
update_sae_set
Update an SAE set and all its references.
If the name is updated, all references in other collections are also updated within a transaction.
Source code in src/lm_saes/database.py
get_random_alive_feature
get_random_alive_feature(
sae_name: str,
sae_series: str,
name: str | None = None,
metric_filters: Optional[
dict[str, dict[str, float]]
] = None,
) -> Optional[FeatureRecord]
Get a random feature that has non-zero activation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sae_name | str | Name of the SAE model | required |
| sae_series | str | Series of the SAE model | required |
| name | str \| None | Name of the analysis | None |
| metric_filters | Optional[dict[str, dict[str, float]]] | Optional dict of metric filters in the format {"metric_name": {"$gte": value, "$lte": value}} | None |

Returns:
| Type | Description |
|---|---|
| Optional[FeatureRecord] | A random feature record with non-zero activation, or None if no such feature exists |
Source code in src/lm_saes/database.py
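The metric_filters argument uses MongoDB-style comparison operators, with `$gte` as a lower bound and `$lte` as an upper bound per metric. The sketch below evaluates such a filter against plain dictionaries to show what the format expresses. The metric names are made up, and treating a missing metric as a non-match is an assumption. How the library turns the filter into a database query is not shown in this reference.

```python
from typing import Dict, Mapping

MetricFilters = Dict[str, Dict[str, float]]


def matches_filters(metrics: Mapping[str, float], metric_filters: MetricFilters) -> bool:
    """Return True if every filtered metric is present and within its bounds."""
    for name, bounds in metric_filters.items():
        if name not in metrics:
            return False  # assumption: a missing metric does not match
        value = metrics[name]
        if "$gte" in bounds and value < bounds["$gte"]:
            return False
        if "$lte" in bounds and value > bounds["$lte"]:
            return False
    return True


# Hypothetical metric names, chosen only for the example.
metric_filters = {"sparsity": {"$gte": 0.1, "$lte": 0.5}, "max_activation": {"$gte": 1.0}}

features = [
    {"sparsity": 0.3, "max_activation": 2.0},  # inside both ranges
    {"sparsity": 0.7, "max_activation": 2.0},  # sparsity above $lte
    {"sparsity": 0.3},                         # max_activation missing
]
kept = [f for f in features if matches_filters(f, metric_filters)]
print(len(kept))
```

Either bound may be omitted, as in the `max_activation` entry, which leaves that side of the range open.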
update_feature
update_feature(
sae_name: str,
feature_index: int,
update_data: dict,
sae_series: str | None = None,
)
Update a feature with additional data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sae_name | str | Name of the SAE | required |
| feature_index | int | Index of the feature to update | required |
| update_data | dict | Dictionary with data to update | required |
| sae_series | str \| None | Optional series of the SAE | None |

Returns:
| Type | Description |
|---|---|
| | Result of the update operation |

Raises:
| Type | Description |
|---|---|
| ValueError | If the feature doesn't exist |
Source code in src/lm_saes/database.py
add_bookmark
add_bookmark(
sae_name: str,
sae_series: str,
feature_index: int,
tags: Optional[list[str]] = None,
notes: Optional[str] = None,
) -> bool
Add a bookmark for a feature.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sae_name | str | Name of the SAE | required |
| sae_series | str | Series of the SAE | required |
| feature_index | int | Index of the feature to bookmark | required |
| tags | Optional[list[str]] | Optional list of tags for the bookmark | None |
| notes | Optional[str] | Optional notes for the bookmark | None |

Returns:
| Name | Type | Description |
|---|---|---|
| bool | bool | True if bookmark was added, False if it already exists |

Raises:
| Type | Description |
|---|---|
| ValueError | If the feature doesn't exist |
Source code in src/lm_saes/database.py
remove_bookmark
Remove a bookmark for a feature.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sae_name | str | Name of the SAE | required |
| sae_series | str | Series of the SAE | required |
| feature_index | int | Index of the feature to remove bookmark from | required |

Returns:
| Name | Type | Description |
|---|---|---|
| bool | bool | True if bookmark was removed, False if it didn't exist |
Source code in src/lm_saes/database.py
is_bookmarked
Check if a feature is bookmarked.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sae_name | str | Name of the SAE | required |
| sae_series | str | Series of the SAE | required |
| feature_index | int | Index of the feature | required |

Returns:
| Name | Type | Description |
|---|---|---|
| bool | bool | True if the feature is bookmarked, False otherwise |
Source code in src/lm_saes/database.py
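The return values documented for add_bookmark, remove_bookmark, and is_bookmarked can be summarized with a small in-memory stand-in. The class `InMemoryBookmarks` is hypothetical. It keys bookmarks by (sae_name, sae_series, feature_index), ignores tags and notes, and only mirrors the documented True/False/ValueError outcomes, not the MongoDB-backed implementation.

```python
from typing import Set, Tuple

Key = Tuple[str, str, int]  # (sae_name, sae_series, feature_index)


class InMemoryBookmarks:
    """Hypothetical stand-in that mimics the documented return values."""

    def __init__(self, known_features: Set[Key]) -> None:
        self._features = known_features
        self._bookmarks: Set[Key] = set()

    def add_bookmark(self, sae_name: str, sae_series: str, feature_index: int) -> bool:
        key = (sae_name, sae_series, feature_index)
        if key not in self._features:
            raise ValueError(f"feature {key} does not exist")
        if key in self._bookmarks:
            return False  # already bookmarked
        self._bookmarks.add(key)
        return True

    def remove_bookmark(self, sae_name: str, sae_series: str, feature_index: int) -> bool:
        key = (sae_name, sae_series, feature_index)
        if key not in self._bookmarks:
            return False  # nothing to remove
        self._bookmarks.remove(key)
        return True

    def is_bookmarked(self, sae_name: str, sae_series: str, feature_index: int) -> bool:
        return (sae_name, sae_series, feature_index) in self._bookmarks


# SAE name and series are placeholders.
store = InMemoryBookmarks({("my-sae", "my-series", 7)})
first = store.add_bookmark("my-sae", "my-series", 7)
second = store.add_bookmark("my-sae", "my-series", 7)
present = store.is_bookmarked("my-sae", "my-series", 7)
removed = store.remove_bookmark("my-sae", "my-series", 7)
removed_again = store.remove_bookmark("my-sae", "my-series", 7)
print(first, second, present, removed, removed_again)
```

Because repeated calls return False instead of raising, both add and remove are safe to retry in this sketch.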
get_bookmark
Get a specific bookmark.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sae_name | str | Name of the SAE | required |
| sae_series | str | Series of the SAE | required |
| feature_index | int | Index of the feature | required |

Returns:
| Name | Type | Description |
|---|---|---|
| BookmarkRecord | Optional[BookmarkRecord] | The bookmark record if it exists, None otherwise |
Source code in src/lm_saes/database.py
list_bookmarks
list_bookmarks(
sae_name: Optional[str] = None,
sae_series: Optional[str] = None,
tags: Optional[list[str]] = None,
limit: Optional[int] = None,
skip: int = 0,
) -> list[BookmarkRecord]
List bookmarks with optional filtering.
Source code in src/lm_saes/database.py
update_bookmark
update_bookmark(
sae_name: str,
sae_series: str,
feature_index: int,
tags: Optional[list[str]] = None,
notes: Optional[str] = None,
) -> bool
Update an existing bookmark.
Source code in src/lm_saes/database.py
get_bookmark_count
Get the total count of bookmarks with optional filtering.
Source code in src/lm_saes/database.py
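list_bookmarks takes skip and limit, and get_bookmark_count returns a total, which together are enough to page through results. The sketch below only does the arithmetic: given a total and a page size, it yields the (skip, limit) pairs that would cover every record. The helper `page_windows` is hypothetical, and the ordering of results across pages is not specified in this reference.

```python
import math
from typing import Iterator, Tuple


def page_windows(total: int, page_size: int) -> Iterator[Tuple[int, int]]:
    """Yield (skip, limit) pairs that cover `total` records page by page."""
    if page_size <= 0:
        raise ValueError("page_size must be positive")
    for page in range(math.ceil(total / page_size)):
        yield page * page_size, page_size


# 45 bookmarks at 20 per page need three requests.
windows = list(page_windows(total=45, page_size=20))
print(windows)
```

Each pair would be passed as the skip and limit arguments of one list_bookmarks call, with the same filters used for the count.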
get_available_metrics
Get available metrics for an SAE by checking the first feature.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sae_name | str | Name of the SAE model | required |
| sae_series | str | Series of the SAE model | required |

Returns:
| Type | Description |
|---|---|
| list[str] | List of available metric names |
Source code in src/lm_saes/database.py
count_features_with_filters
count_features_with_filters(
sae_name: str,
sae_series: str,
name: str | None = None,
metric_filters: Optional[
dict[str, dict[str, float]]
] = None,
) -> int
Count features that match the given filters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sae_name | str | Name of the SAE model | required |
| sae_series | str | Series of the SAE model | required |
| name | str \| None | Name of the analysis | None |
| metric_filters | Optional[dict[str, dict[str, float]]] | Optional dict of metric filters in the format {"metric_name": {"$gte": value, "$lte": value}} | None |

Returns:
| Type | Description |
|---|---|
| int | Number of features matching the filters |
Source code in src/lm_saes/database.py
create_circuit
create_circuit(
sae_set_name: str,
sae_series: str,
prompt: str,
input: CircuitInput,
config: CircuitConfig,
name: Optional[str] = None,
group: Optional[str] = None,
parent_id: Optional[str] = None,
clt_names: Optional[list[str]] = None,
lorsa_names: Optional[list[str]] = None,
use_lorsa: bool = True,
) -> str
Create a new circuit graph record with pending status.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sae_set_name | str | Name of the SAE set used. | required |
| sae_series | str | Series of the SAE. | required |
| prompt | str | The prompt used for generation. | required |
| input | CircuitInput | The circuit input configuration. | required |
| config | CircuitConfig | The circuit configuration. | required |
| name | Optional[str] | Optional custom name for the circuit. | None |
| group | Optional[str] | Optional group name. | None |
| parent_id | Optional[str] | Optional parent circuit ID. | None |
| clt_names | Optional[list[str]] | Names of CLT SAEs used. | None |
| lorsa_names | Optional[list[str]] | Names of LORSA SAEs used. | None |
| use_lorsa | bool | Whether LORSA was used. | True |

Returns:
| Type | Description |
|---|---|
| str | The ID of the created circuit. |
Source code in src/lm_saes/database.py
get_circuit
Get a circuit by its ID.
Source code in src/lm_saes/database.py
list_circuits
list_circuits(
sae_series: Optional[str] = None,
group: Optional[str] = None,
limit: Optional[int] = None,
skip: int = 0,
) -> list[dict[str, Any]]
List circuits with optional filtering.
Note: raw_graph_id is excluded from the listing for efficiency.
Source code in src/lm_saes/database.py
update_circuits_group
Update the group for multiple circuits.
Source code in src/lm_saes/database.py
update_circuit
Update a circuit by its ID.
Source code in src/lm_saes/database.py
delete_circuit
Delete a circuit by its ID.
Also deletes the associated raw graph from GridFS if it exists.
Source code in src/lm_saes/database.py
update_circuit_progress
update_circuit_progress(
circuit_id: str,
progress: float,
progress_phase: Optional[str] = None,
) -> bool
Update the progress of a circuit generation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| circuit_id | str | The circuit ID. | required |
| progress | float | Progress percentage (0-100). | required |
| progress_phase | Optional[str] | Optional phase description. | None |

Returns:
| Type | Description |
|---|---|
| bool | True if update was successful. |
Source code in src/lm_saes/database.py
update_circuit_status
Update the status of a circuit.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| circuit_id | str | The circuit ID. | required |
| status | str | New status (pending, running, completed, failed). | required |
| error_message | Optional[str] | Optional error message for failed status. | None |

Returns:
| Type | Description |
|---|---|
| bool | True if update was successful. |
Source code in src/lm_saes/database.py
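The progress and status updates above accept a percentage in the 0-100 range and one of four status strings. A caller may want to check both before writing, which the sketch below illustrates. The helpers `build_status_update` and `clamp_progress` are hypothetical. Whether the library itself validates or clamps these values is not stated in this reference, and attaching error_message only to the failed status is an assumption drawn from the parameter description.

```python
from typing import Dict, Optional

VALID_STATUSES = ("pending", "running", "completed", "failed")


def build_status_update(status: str, error_message: Optional[str] = None) -> Dict[str, str]:
    """Validate a status value and build the fields to write for it."""
    if status not in VALID_STATUSES:
        raise ValueError(f"unknown status {status!r}; expected one of {VALID_STATUSES}")
    update = {"status": status}
    # Assumption: an error message is only meaningful for the failed status.
    if status == "failed" and error_message is not None:
        update["error_message"] = error_message
    return update


def clamp_progress(progress: float) -> float:
    """Keep a progress percentage inside the documented 0-100 range."""
    return max(0.0, min(100.0, float(progress)))


failed = build_status_update("failed", error_message="out of memory")
running = build_status_update("running", error_message="ignored")
print(failed, running, clamp_progress(120), clamp_progress(-5))
```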
store_raw_graph
Store raw graph data to GridFS and update circuit record.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| circuit_id | str | The circuit ID. | required |
| graph_data | dict[str, Any] | The raw graph data dictionary with numpy arrays. | required |

Returns:
| Type | Description |
|---|---|
| bool | True if storage was successful. |
Source code in src/lm_saes/database.py
load_raw_graph
Load raw graph data from GridFS.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| circuit_id | str | The circuit ID. | required |

Returns:
| Type | Description |
|---|---|
| Optional[dict[str, Any]] | The raw graph data dictionary with numpy arrays, or None if not found. |
Source code in src/lm_saes/database.py
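GridFS stores binary blobs, so a dictionary of numpy arrays has to be turned into bytes before it is stored and rebuilt after it is loaded. One common way to do that round trip is sketched below with `np.savez_compressed` and `np.load`. The helper names and dictionary keys are hypothetical, and the serialization format lm_saes actually uses is not documented here.

```python
import io

import numpy as np


def pack_arrays(graph_data: dict) -> bytes:
    """Serialize a dict of numpy arrays to a single bytes payload."""
    buffer = io.BytesIO()
    np.savez_compressed(buffer, **graph_data)
    return buffer.getvalue()


def unpack_arrays(payload: bytes) -> dict:
    """Inverse of pack_arrays: rebuild the dict of numpy arrays."""
    with np.load(io.BytesIO(payload)) as archive:
        return {key: archive[key] for key in archive.files}


# Keys and shapes are made up for the example.
graph_data = {
    "adjacency_matrix": np.arange(6, dtype=np.float32).reshape(2, 3),
    "node_ids": np.array([10, 20], dtype=np.int64),
}
payload = pack_arrays(graph_data)
restored = unpack_arrays(payload)
print(sorted(restored), len(payload) > 0)
```

The resulting bytes object is the kind of payload a blob store accepts. Dtypes and shapes survive the round trip.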
get_circuit_status
Get just the status information for a circuit.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| circuit_id | str | The circuit ID. | required |

Returns:
| Type | Description |
|---|---|
| Optional[dict[str, Any]] | Dict with status, progress, progress_phase, and error_message. |
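A lightweight status lookup like this is typically used for polling. The sketch below polls a caller-supplied function until the status is completed or failed, the record is missing, or a poll budget runs out. The helper `wait_for_circuit`, its parameters, and the fake status sequence are hypothetical. In real use `fetch_status` would presumably wrap a call to get_circuit_status for one circuit ID.

```python
import time
from typing import Any, Callable, Dict, Optional

TERMINAL = {"completed", "failed"}


def wait_for_circuit(
    fetch_status: Callable[[], Optional[Dict[str, Any]]],
    interval: float = 1.0,
    max_polls: int = 60,
    sleep: Callable[[float], None] = time.sleep,
) -> Optional[Dict[str, Any]]:
    """Poll until the status is terminal, the record is missing, or polls run out."""
    info: Optional[Dict[str, Any]] = None
    for _ in range(max_polls):
        info = fetch_status()
        if info is None or info["status"] in TERMINAL:
            return info
        sleep(interval)
    return info  # last seen state if the budget was exhausted


# Fake status sequence standing in for successive database reads.
fake = iter([
    {"status": "pending", "progress": 0.0, "progress_phase": None, "error_message": None},
    {"status": "running", "progress": 50.0, "progress_phase": "attribution", "error_message": None},
    {"status": "completed", "progress": 100.0, "progress_phase": None, "error_message": None},
])
final = wait_for_circuit(lambda: next(fake), sleep=lambda _seconds: None)
print(final["status"], final["progress"])
```

Injecting `sleep` keeps the example instantaneous. The default `time.sleep` applies when the argument is omitted.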