Infrastructure
Language model backend, dataset, and database configuration.
LanguageModelConfig
pydantic-model
Bases: BaseModelConfig
Fields:
- device (str)
- dtype (dtype)
- model_name (str)
- model_from_pretrained_path (str | None)
- use_flash_attn (bool)
- cache_dir (str | None)
- local_files_only (bool)
- max_length (int)
- backend (Literal['huggingface', 'transformer_lens', 'auto'])
- load_ckpt (bool)
- tokenizer_only (bool)
- prepend_bos (bool)
- bos_token_id (int | None)
- eos_token_id (int | None)
- pad_token_id (int | None)
model_from_pretrained_path
pydantic-field
Path to the pretrained model. If None, the model is loaded from HuggingFace.
cache_dir
pydantic-field
The HuggingFace cache directory. Should have the same effect as setting the HF_HOME environment variable.
local_files_only
pydantic-field
Whether to load the model only from local files. Should have the same effect as setting HF_HUB_OFFLINE=1.
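The cache_dir and local_files_only descriptions above each name an equivalent environment variable. As a rough illustration of that correspondence only, the sketch below maps the two fields onto those variables. The helper hf_env_overrides and the path /tmp/hf-cache are hypothetical and are not part of lm_saes.

```python
from __future__ import annotations

import os


def hf_env_overrides(cache_dir: str | None, local_files_only: bool) -> dict[str, str]:
    """Hypothetical helper: map the two fields to the variables named in the text."""
    env: dict[str, str] = {}
    if cache_dir is not None:
        env["HF_HOME"] = cache_dir  # counterpart named in the cache_dir description
    if local_files_only:
        env["HF_HUB_OFFLINE"] = "1"  # counterpart named in the local_files_only description
    return env


# Placeholder path; substitute a real cache directory.
overrides = hf_env_overrides(cache_dir="/tmp/hf-cache", local_files_only=True)
# Environment variables are usually read when the HuggingFace libraries are imported,
# so set them before that import.
os.environ.update(overrides)
print(overrides)
```

Setting the fields on LanguageModelConfig keeps the choice local to one model load, whereas environment variables apply to the whole process.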
backend
pydantic-field
The backend to use for the language model.
bos_token_id
pydantic-field
The ID of the BOS token. If None, will use the default BOS token.
eos_token_id
pydantic-field
The ID of the EOS token. If None, will use the default EOS token.
pad_token_id
pydantic-field
The ID of the padding token. If None, will use the default padding token.
from_pretrained_sae
staticmethod
Load the LanguageModelConfig from a pretrained SAE name or path. Config is read from
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `pretrained_name_or_path` | `str` | The path to the pretrained SAE. | required |
| `**kwargs` | `Any` | Additional keyword arguments to pass to the LanguageModelConfig constructor. | `{}` |

Source code in src/lm_saes/backend/language_model.py
TransformerLensLanguageModel
TransformerLensLanguageModel(
cfg: LanguageModelConfig,
device_mesh: DeviceMesh | None = None,
)
Bases: LanguageModel
Source code in src/lm_saes/backend/language_model.py
run_with_cache_until
Run with activation caching, stopping at a given hook for efficiency.
Source code in src/lm_saes/backend/language_model.py
HuggingFaceLanguageModel
HuggingFaceLanguageModel(cfg: LanguageModelConfig)
Bases: LanguageModel
Source code in src/lm_saes/backend/language_model.py
DatasetConfig
pydantic-model
Bases: BaseConfig
Fields:
- dataset_name_or_path (str)
- cache_dir (str | None)
- is_dataset_on_disk (bool)
dataset_name_or_path
pydantic-field
The name or path of the dataset. Must be valid for datasets.load_from_disk when is_dataset_on_disk is set, and for datasets.load_dataset otherwise.
cache_dir
pydantic-field
The directory to cache the dataset. Will be passed to datasets.load_dataset.
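The choice between the two loaders can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's loading code: DatasetSettings is a hypothetical stand-in for DatasetConfig, and loader_name and load are illustrative helpers. Only datasets.load_dataset and datasets.load_from_disk are real functions from the datasets package.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class DatasetSettings:
    """Hypothetical stand-in mirroring the three documented DatasetConfig fields."""

    dataset_name_or_path: str
    cache_dir: str | None
    is_dataset_on_disk: bool


def loader_name(settings: DatasetSettings) -> str:
    """Report which datasets function the flag selects."""
    return "load_from_disk" if settings.is_dataset_on_disk else "load_dataset"


def load(settings: DatasetSettings):
    """Dispatch to the matching loader; needs the datasets package installed."""
    import datasets  # imported lazily so the sketch can be read without it

    if settings.is_dataset_on_disk:
        return datasets.load_from_disk(settings.dataset_name_or_path)
    return datasets.load_dataset(settings.dataset_name_or_path, cache_dir=settings.cache_dir)


# Placeholder names; nothing is downloaded until load() is called.
hub = DatasetSettings("some/dataset", cache_dir=None, is_dataset_on_disk=False)
local = DatasetSettings("/data/my_dataset", cache_dir=None, is_dataset_on_disk=True)
print(loader_name(hub), loader_name(local))
```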
MongoDBConfig
pydantic-model
Bases: BaseModel
Fields:
- mongo_uri (str)
- mongo_db (str)
MongoClient
MongoClient(cfg: MongoDBConfig)
Source code in src/lm_saes/database.py
enable_gridfs
disable_gridfs
is_gridfs_enabled
update_sae
Update an SAE and all its references.
If the name is updated, all references in other collections are also updated within a transaction.
Source code in src/lm_saes/database.py
update_sae_set
Update an SAE set and all its references.
If the name is updated, all references in other collections are also updated within a transaction.
Source code in src/lm_saes/database.py
get_random_alive_feature
get_random_alive_feature(
sae_name: str,
sae_series: str,
name: str | None = None,
metric_filters: dict[str, dict[str, float]] | None = None,
) -> FeatureRecord | None
Get a random feature that has non-zero activation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sae_name` | `str` | Name of the SAE model | required |
| `sae_series` | `str` | Series of the SAE model | required |
| `name` | `str \| None` | Name of the analysis | `None` |
| `metric_filters` | `dict[str, dict[str, float]] \| None` | Optional dict of metric filters in the format `{"metric_name": {"$gte": value, "$lte": value}}` | `None` |

Returns:

| Type | Description |
|---|---|
| `FeatureRecord \| None` | A random feature record with non-zero activation, or None if no such feature exists |

Source code in src/lm_saes/database.py
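The metric_filters argument uses MongoDB's comparison operators: $gte is an inclusive lower bound and $lte is an inclusive upper bound. The sketch below builds such a filter and applies the same bounds to plain dictionaries to show which values pass. The matches helper and the metric name "sparsity" are hypothetical. The actual filtering happens in the database query, and how features that lack a metric are treated there is not stated in this reference.

```python
from __future__ import annotations


def matches(metrics: dict[str, float], metric_filters: dict[str, dict[str, float]] | None) -> bool:
    """Hypothetical in-memory check that mirrors inclusive $gte/$lte bounds."""
    if not metric_filters:
        return True  # no filters: every feature passes
    for metric_name, bounds in metric_filters.items():
        value = metrics.get(metric_name)
        if value is None:
            return False  # assumption in this sketch: a missing metric fails the filter
        if "$gte" in bounds and value < bounds["$gte"]:
            return False
        if "$lte" in bounds and value > bounds["$lte"]:
            return False
    return True


# Keep features whose (hypothetical) "sparsity" metric lies in [0.1, 0.5].
metric_filters = {"sparsity": {"$gte": 0.1, "$lte": 0.5}}

print(matches({"sparsity": 0.3}, metric_filters))
print(matches({"sparsity": 0.05}, metric_filters))
```

The same dictionary shape is accepted by count_features_with_filters, so a filter can be checked for how many features it keeps before sampling from it.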
update_feature
update_feature(
sae_name: str,
feature_index: int,
update_data: dict,
sae_series: str | None = None,
) -> UpdateResult
Update a feature with additional data.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sae_name` | `str` | Name of the SAE | required |
| `feature_index` | `int` | Index of the feature to update | required |
| `update_data` | `dict` | Dictionary with data to update | required |
| `sae_series` | `str \| None` | Optional series of the SAE | `None` |

Returns:

| Type | Description |
|---|---|
| `UpdateResult` | Result of the update operation. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If the feature doesn't exist |

Source code in src/lm_saes/database.py
add_bookmark
add_bookmark(
sae_name: str,
sae_series: str,
feature_index: int,
tags: list[str] | None = None,
notes: str | None = None,
) -> bool
Add a bookmark for a feature.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sae_name` | `str` | Name of the SAE | required |
| `sae_series` | `str` | Series of the SAE | required |
| `feature_index` | `int` | Index of the feature to bookmark | required |
| `tags` | `list[str] \| None` | Optional list of tags for the bookmark | `None` |
| `notes` | `str \| None` | Optional notes for the bookmark | `None` |

Returns:

| Name | Type | Description |
|---|---|---|
| `bool` | `bool` | True if bookmark was added, False if it already exists |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If the feature doesn't exist |

Source code in src/lm_saes/database.py
remove_bookmark
Remove a bookmark for a feature.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sae_name` | `str` | Name of the SAE | required |
| `sae_series` | `str` | Series of the SAE | required |
| `feature_index` | `int` | Index of the feature to remove bookmark from | required |

Returns:

| Name | Type | Description |
|---|---|---|
| `bool` | `bool` | True if bookmark was removed, False if it didn't exist |

Source code in src/lm_saes/database.py
is_bookmarked
Check if a feature is bookmarked.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sae_name` | `str` | Name of the SAE | required |
| `sae_series` | `str` | Series of the SAE | required |
| `feature_index` | `int` | Index of the feature | required |

Returns:

| Name | Type | Description |
|---|---|---|
| `bool` | `bool` | True if the feature is bookmarked, False otherwise |

Source code in src/lm_saes/database.py
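The return conventions of add_bookmark, remove_bookmark, and is_bookmarked can be mimicked with a small in-memory stand-in, which may help when writing tests against code that calls them. InMemoryBookmarks is hypothetical and is not the MongoClient implementation. In particular, it skips the feature-existence check that makes add_bookmark raise ValueError, and the SAE names used below are placeholders.

```python
from __future__ import annotations


class InMemoryBookmarks:
    """Hypothetical stand-in that reproduces only the documented return values."""

    def __init__(self) -> None:
        # Key: (sae_name, sae_series, feature_index); value: bookmark payload.
        self._items: dict[tuple[str, str, int], dict] = {}

    def add_bookmark(self, sae_name, sae_series, feature_index, tags=None, notes=None) -> bool:
        key = (sae_name, sae_series, feature_index)
        if key in self._items:
            return False  # already bookmarked
        self._items[key] = {"tags": list(tags or []), "notes": notes}
        return True

    def remove_bookmark(self, sae_name, sae_series, feature_index) -> bool:
        # True only if something was actually removed.
        return self._items.pop((sae_name, sae_series, feature_index), None) is not None

    def is_bookmarked(self, sae_name, sae_series, feature_index) -> bool:
        return (sae_name, sae_series, feature_index) in self._items


store = InMemoryBookmarks()
first = store.add_bookmark("my-sae", "my-series", 42, tags=["interesting"])
second = store.add_bookmark("my-sae", "my-series", 42)  # duplicate
print(first, second)
```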
get_bookmark
Get a specific bookmark.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sae_name` | `str` | Name of the SAE | required |
| `sae_series` | `str` | Series of the SAE | required |
| `feature_index` | `int` | Index of the feature | required |

Returns:

| Name | Type | Description |
|---|---|---|
| `BookmarkRecord` | `BookmarkRecord \| None` | The bookmark record if it exists, None otherwise |

Source code in src/lm_saes/database.py
list_bookmarks
list_bookmarks(
sae_name: str | None = None,
sae_series: str | None = None,
tags: list[str] | None = None,
limit: int | None = None,
skip: int = 0,
) -> list[BookmarkRecord]
List bookmarks with optional filtering.
Source code in src/lm_saes/database.py
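The limit and skip parameters lend themselves to page-by-page iteration. The sketch below shows one way to walk through all results in fixed-size pages. The iter_pages helper and the list-backed fetch function are hypothetical, and the sketch assumes that limit caps the number of returned items and skip offsets into the result set, which the signature suggests but this reference does not spell out.

```python
from typing import Callable, Iterator, List


def iter_pages(fetch: Callable[[int, int], List], page_size: int) -> Iterator[List]:
    """Yield successive pages from fetch(limit, skip) until an empty page comes back."""
    skip = 0
    while True:
        page = fetch(page_size, skip)
        if not page:
            return
        yield page
        skip += len(page)


# Stand-in data source so the sketch runs without a database.
records = ["b0", "b1", "b2", "b3", "b4"]


def fetch_from_list(limit: int, skip: int) -> List:
    return records[skip : skip + limit]


pages = list(iter_pages(fetch_from_list, page_size=2))
print(pages)
```

With a real client, the fetch argument could be a lambda that forwards to list_bookmarks(limit=limit, skip=skip). The same pattern applies to list_circuits, which takes the same two parameters.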
update_bookmark
update_bookmark(
sae_name: str,
sae_series: str,
feature_index: int,
tags: list[str] | None = None,
notes: str | None = None,
) -> bool
Update an existing bookmark.
Source code in src/lm_saes/database.py
get_bookmark_count
Get the total count of bookmarks with optional filtering.
Source code in src/lm_saes/database.py
get_available_metrics
Get available metrics for an SAE by checking the first feature.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sae_name` | `str` | Name of the SAE model | required |
| `sae_series` | `str` | Series of the SAE model | required |

Returns:

| Type | Description |
|---|---|
| `list[str]` | List of available metric names |

Source code in src/lm_saes/database.py
count_features_with_filters
count_features_with_filters(
sae_name: str,
sae_series: str,
name: str | None = None,
metric_filters: dict[str, dict[str, float]] | None = None,
) -> int
Count features that match the given filters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sae_name` | `str` | Name of the SAE model | required |
| `sae_series` | `str` | Series of the SAE model | required |
| `name` | `str \| None` | Name of the analysis | `None` |
| `metric_filters` | `dict[str, dict[str, float]] \| None` | Optional dict of metric filters in the format `{"metric_name": {"$gte": value, "$lte": value}}` | `None` |

Returns:

| Type | Description |
|---|---|
| `int` | Number of features matching the filters |

Source code in src/lm_saes/database.py
create_circuit
create_circuit(
sae_set_name: str,
sae_series: str,
prompt: str,
input: CircuitInput,
config: CircuitConfig,
name: str | None = None,
group: str | None = None,
parent_id: str | None = None,
) -> str
Create a new circuit record with pending status.
Source code in src/lm_saes/database.py
get_circuit
Get a circuit by its ID.
Source code in src/lm_saes/database.py
list_circuits
list_circuits(
sae_series: str | None = None,
group: str | None = None,
limit: int | None = None,
skip: int = 0,
) -> list[dict[str, Any]]
List circuits with optional filtering.
Note: raw_graph_id is excluded from the listing for efficiency.
Source code in src/lm_saes/database.py
update_circuits_group
Update the group for multiple circuits.
Source code in src/lm_saes/database.py
update_circuit
Update a circuit by its ID.
delete_circuit
Delete a circuit by its ID.
Also deletes the associated attribution from GridFS if it exists.
Source code in src/lm_saes/database.py
update_circuit_progress
update_circuit_progress(
circuit_id: str,
progress: float,
progress_phase: str | None = None,
) -> bool
Update the progress of a circuit generation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `circuit_id` | `str` | The circuit ID. | required |
| `progress` | `float` | Progress percentage (0-100). | required |
| `progress_phase` | `str \| None` | Optional phase description. | `None` |

Returns:

| Type | Description |
|---|---|
| `bool` | True if update was successful. |

Source code in src/lm_saes/database.py
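Because progress is expressed as a percentage between 0 and 100, callers that track work in steps need to convert before reporting. A minimal sketch of that conversion follows. The progress_percent helper is hypothetical, and clamping out-of-range values is a choice made for this sketch rather than documented behavior.

```python
def progress_percent(done: int, total: int) -> float:
    """Convert completed/total steps to a percentage clamped to [0, 100]."""
    if total <= 0:
        return 0.0  # nothing to do yet; avoids division by zero
    return max(0.0, min(100.0, 100.0 * done / total))


print(progress_percent(5, 20))
print(progress_percent(30, 20))
```

The result could then be passed as the progress argument of update_circuit_progress, with a short label such as the current stage name as progress_phase.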
update_circuit_status
Update the status of a circuit.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `circuit_id` | `str` | The circuit ID. | required |
| `status` | `str` | New status (pending, running, completed, failed). | required |
| `error_message` | `str \| None` | Optional error message for failed status. | `None` |

Returns:

| Type | Description |
|---|---|
| `bool` | True if update was successful. |

Source code in src/lm_saes/database.py
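Since status is a plain string with four documented values, callers may want to validate it before sending an update. The sketch below is one way to do that. CircuitStatus and status_update_args are hypothetical helpers and are not part of lm_saes. Dropping the error message for non-failed statuses is an assumption based on the parameter description.

```python
from __future__ import annotations

from enum import Enum


class CircuitStatus(str, Enum):
    """The four status values listed in the parameter table."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"


def status_update_args(circuit_id: str, status: str, error_message: str | None = None) -> dict:
    """Validate the status and build keyword arguments for a status update."""
    parsed = CircuitStatus(status)  # raises ValueError for an unknown status
    if parsed is not CircuitStatus.FAILED:
        error_message = None  # assumption: the message only applies to failures
    return {"circuit_id": circuit_id, "status": parsed.value, "error_message": error_message}


# "abc123" is a placeholder circuit ID.
args = status_update_args("abc123", "failed", "out of memory")
print(args)
```

The resulting dictionary uses the documented parameter names, so it could be unpacked into update_circuit_status as keyword arguments.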
store_attribution
Store attribution data to GridFS and update circuit record.
Source code in src/lm_saes/database.py
load_attribution
Load attribution data from GridFS.
Source code in src/lm_saes/database.py
get_circuit_status
Get just the status information for a circuit.