Skip to content

Infrastructure

Language model backend, dataset, and database configuration.

LanguageModelConfig pydantic-model

Bases: BaseModelConfig

Fields:

model_name pydantic-field

model_name: str = 'gpt2'

The name of the model to use.

model_from_pretrained_path pydantic-field

model_from_pretrained_path: str | None = None

The path to the pretrained model. If None, will use the model from HuggingFace.

use_flash_attn pydantic-field

use_flash_attn: bool = False

Whether to use Flash Attention.

cache_dir pydantic-field

cache_dir: str | None = None

The directory of the HuggingFace cache. Should have the same effect as HF_HOME.

local_files_only pydantic-field

local_files_only: bool = False

Whether to only load the model from the local files. Should have the same effect as HF_HUB_OFFLINE=1.

max_length pydantic-field

max_length: int = 2048

The maximum length of the input.

backend pydantic-field

backend: Literal[
    "huggingface", "transformer_lens", "auto"
] = "auto"

The backend to use for the language model.

tokenizer_only pydantic-field

tokenizer_only: bool = False

Whether to only load the tokenizer.

prepend_bos pydantic-field

prepend_bos: bool = True

Whether to prepend the BOS token to the input.

bos_token_id pydantic-field

bos_token_id: int | None = None

The ID of the BOS token. If None, will use the default BOS token.

eos_token_id pydantic-field

eos_token_id: int | None = None

The ID of the EOS token. If None, will use the default EOS token.

pad_token_id pydantic-field

pad_token_id: int | None = None

The ID of the padding token. If None, will use the default padding token.

from_pretrained_sae staticmethod

from_pretrained_sae(pretrained_name_or_path: str, **kwargs)

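Load the LanguageModelConfig from a pretrained SAE name or path. Config is read from `lm_config.json` in the directory containing `pretrained_name_or_path` (for local storage), from `<name>/lm_config.json` in the `<repo_id>` repository (for HuggingFace Hub, with `pretrained_name_or_path` given as `<repo_id>:<name>`), or constructed from the model name (for SAELens).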

Parameters:

Name Type Description Default
pretrained_name_or_path str

The path to the pretrained SAE.

required
**kwargs

Additional keyword arguments to pass to the LanguageModelConfig constructor.

{}
Source code in src/lm_saes/backend/language_model.py
@staticmethod
def from_pretrained_sae(pretrained_name_or_path: str, **kwargs):
    """Load the LanguageModelConfig that accompanies a pretrained SAE.

    The source depends on the SAE type inferred by ``auto_infer_pretrained_sae_type``:

    * local storage: ``lm_config.json`` in the directory containing ``pretrained_name_or_path``;
    * HuggingFace Hub (``"<repo_id>:<name>"``): ``<name>/lm_config.json`` downloaded from ``<repo_id>``;
    * SAELens (``"<repo_id>:<name>"``): constructed from the model name registered for that SAE
      in the SAELens pretrained-SAE directory.

    Args:
        pretrained_name_or_path (str): The path to the pretrained SAE, or a ``"<repo_id>:<name>"`` identifier.
        **kwargs: Additional LanguageModelConfig fields. For local/HuggingFace configs they override
            the values read from ``lm_config.json``.

    Raises:
        ValueError: If the inferred pretrained SAE type is not supported.
    """
    sae_type = auto_infer_pretrained_sae_type(pretrained_name_or_path.split(":")[0])
    if sae_type == PretrainedSAEType.LOCAL:
        path = os.path.join(os.path.dirname(pretrained_name_or_path), "lm_config.json")
    elif sae_type == PretrainedSAEType.HUGGINGFACE:
        repo_id, name = pretrained_name_or_path.split(":")
        # hf_hub_download returns the local filesystem path of the downloaded file.
        path = hf_hub_download(repo_id=repo_id, filename=f"{name}/lm_config.json")
    elif sae_type == PretrainedSAEType.SAELENS:
        from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory

        repo_id, name = pretrained_name_or_path.split(":")
        lookups = get_pretrained_saes_directory()
        assert lookups.get(repo_id) is not None and lookups[repo_id].saes_map.get(name) is not None, (
            f"Pretrained SAE {pretrained_name_or_path} not found in SAELens. This might indicate bugs in `auto_infer_pretrained_sae_type`."
        )
        model_name = lookups[repo_id].model
        return LanguageModelConfig(model_name=model_name, **kwargs)
    else:
        raise ValueError(f"Unsupported pretrained type: {sae_type}")
    # `path` already points at the lm_config.json file itself, so open it directly instead of
    # appending "lm_config.json" a second time.
    with open(path, "r") as f:
        lm_config = json.load(f)
    # `model_validate` does not take field overrides as keyword arguments; merge them into the data.
    return LanguageModelConfig.model_validate({**lm_config, **kwargs})

TransformerLensLanguageModel

TransformerLensLanguageModel(
    cfg: LanguageModelConfig,
    device_mesh: DeviceMesh | None = None,
)

Bases: LanguageModel

Source code in src/lm_saes/backend/language_model.py
def __init__(self, cfg: LanguageModelConfig, device_mesh: DeviceMesh | None = None):
    """Build a TransformerLens-backed language model (and tokenizer) from ``cfg``.

    The HuggingFace weights are loaded only when ``cfg.load_ckpt`` is true and ``cfg.tokenizer_only``
    is false. Otherwise ``self.model`` is ``None`` and only the tokenizer is available.

    Args:
        cfg: Language model configuration. Besides the documented LanguageModelConfig fields, this
            reads ``cfg.device``, ``cfg.dtype`` and ``cfg.load_ckpt``.
        device_mesh: Optional device mesh, stored as ``self.device_mesh``.
    """
    self.cfg = cfg
    self.device_mesh = device_mesh
    # Pin a bare accelerator type to the concrete device index that is currently selected.
    if cfg.device == "cuda":
        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
    elif cfg.device == "npu":
        self.device = torch.device(f"npu:{torch.npu.current_device()}")  # type: ignore[reportAttributeAccessIssue]
    else:
        self.device = torch.device(cfg.device)

    # Weights and tokenizer come from the explicit pretrained path when given, else from the model name.
    model_path = cfg.model_name if cfg.model_from_pretrained_path is None else cfg.model_from_pretrained_path

    hf_model = (
        AutoModelForCausalLM.from_pretrained(
            model_path,
            cache_dir=cfg.cache_dir,
            local_files_only=cfg.local_files_only,
            dtype=cfg.dtype,
            trust_remote_code=True,
        )
        if cfg.load_ckpt and not cfg.tokenizer_only
        else None
    )
    hf_tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        cache_dir=cfg.cache_dir,
        trust_remote_code=True,
        use_fast=True,
        add_bos_token=True,
        local_files_only=cfg.local_files_only,
    )
    self.tokenizer = set_tokens(
        hf_tokenizer,
        cfg.bos_token_id,
        cfg.eos_token_id,
        cfg.pad_token_id,
    )
    # hf_model is already None when tokenizer_only is set, so checking it alone is sufficient.
    # NOTE(review): cfg.model_name (not model_path) is passed, presumably so TransformerLens can
    # resolve its architecture by official name while reusing the already-loaded weights.
    self.model = (
        HookedTransformer.from_pretrained_no_processing(
            cfg.model_name,
            use_flash_attn=cfg.use_flash_attn,
            device=self.device,
            cache_dir=cfg.cache_dir,
            hf_model=hf_model,
            hf_config=hf_model.config,
            tokenizer=hf_tokenizer,
            dtype=cfg.dtype,  # type: ignore ; issue with transformer_lens
        )
        if hf_model is not None
        else None
    )

run_with_cache_until

run_with_cache_until(*args, **kwargs) -> Any

Run with activation caching, stopping at a given hook for efficiency.

Source code in src/lm_saes/backend/language_model.py
def run_with_cache_until(self, *args, **kwargs) -> Any:
    """Run the model with activation caching, stopping early at a given hook.

    Without a device mesh the call is forwarded unchanged to the module-level
    ``run_with_cache_until``. With a mesh, every leaf of the arguments is first mapped
    through ``self._to_tensor``, and every leaf of the result through ``self._to_dtensor``.
    """
    assert self.model is not None, "model must be initialized"
    if self.device_mesh is not None:
        local_args = pytree.tree_map(self._to_tensor, args)
        local_kwargs = pytree.tree_map(self._to_tensor, kwargs)
        outputs = run_with_cache_until(self.model, *local_args, **local_kwargs)
        return pytree.tree_map(self._to_dtensor, outputs)
    return run_with_cache_until(self.model, *args, **kwargs)

DatasetConfig pydantic-model

Bases: BaseConfig

Fields:

  • dataset_name_or_path (str)
  • cache_dir (str | None)
  • is_dataset_on_disk (bool)

MongoDBConfig pydantic-model

Bases: BaseModel

Fields:

  • mongo_uri (str)
  • mongo_db (str)

MongoClient

MongoClient(cfg: MongoDBConfig)
Source code in src/lm_saes/database.py
def __init__(self, cfg: MongoDBConfig):
    """Connect to MongoDB, bind the application's collections and create their indexes.

    GridFS is initialized at the end through ``self._init_fs()``.
    """
    self.client: pymongo.MongoClient = pymongo.MongoClient(cfg.mongo_uri)
    self.db = self.client[cfg.mongo_db]
    self.fs: gridfs.GridFS | None = None

    database = self.db
    self.feature_collection = database["features"]
    self.sae_collection = database["saes"]
    self.analysis_collection = database["analyses"]
    self.dataset_collection = database["datasets"]
    self.model_collection = database["models"]
    self.bookmark_collection = database["bookmarks"]
    self.sae_set_collection = database["sae_sets"]
    self.circuit_collection = database["circuits"]

    asc, desc = pymongo.ASCENDING, pymongo.DESCENDING
    # (collection, index keys, unique?) -- created in this exact order.
    index_specs = [
        (self.sae_collection, [("name", asc), ("series", asc)], True),
        (self.sae_collection, [("series", asc)], False),
        (self.analysis_collection, [("name", asc), ("sae_name", asc), ("sae_series", asc)], True),
        (self.feature_collection, [("sae_name", asc), ("sae_series", asc), ("index", asc)], True),
        (self.dataset_collection, [("name", asc)], True),
        (self.model_collection, [("name", asc)], True),
        (self.sae_set_collection, [("name", asc)], True),
        (self.sae_set_collection, [("sae_series", asc)], False),
        (self.bookmark_collection, [("sae_name", asc), ("sae_series", asc), ("feature_index", asc)], True),
        (self.bookmark_collection, [("created_at", desc)], False),
        (self.circuit_collection, [("sae_series", asc)], False),
        (self.circuit_collection, [("created_at", desc)], False),
    ]
    for collection, keys, unique in index_specs:
        if unique:
            collection.create_index(keys, unique=True)
        else:
            collection.create_index(keys)

    # GridFS is on by default.
    self._init_fs()

enable_gridfs

enable_gridfs() -> None

Enable GridFS for storing large binary data.

Source code in src/lm_saes/database.py
def enable_gridfs(self) -> None:
    """Turn on GridFS support, initializing it only when no handle is currently set."""
    if self.fs is not None:
        return
    self._init_fs()

disable_gridfs

disable_gridfs() -> None

Disable GridFS usage.

Source code in src/lm_saes/database.py
def disable_gridfs(self) -> None:
    """Stop using GridFS by discarding the handle stored in ``self.fs``.

    Nothing is closed or deleted here; ``enable_gridfs`` re-creates the handle on demand.
    """
    self.fs = None

is_gridfs_enabled

is_gridfs_enabled() -> bool

Check if GridFS is enabled.

Source code in src/lm_saes/database.py
def is_gridfs_enabled(self) -> bool:
    """Report whether a GridFS handle is currently set on this client."""
    enabled = self.fs is not None
    return enabled

update_sae

update_sae(
    name: str, series: str, update_data: dict[str, Any]
)

Update an SAE and all its references.

If the name is updated, all references in other collections are also updated within a transaction.

Source code in src/lm_saes/database.py
def update_sae(self, name: str, series: str, update_data: dict[str, Any]):
    """Apply ``update_data`` to the SAE identified by ``(name, series)``.

    A rename (a truthy ``update_data["name"]`` that differs from ``name``) runs inside one
    transaction. That transaction also rewrites every reference to the old name within the same series:
    the ``sae_name`` field of features, analyses and bookmarks, the ``sae_names`` entries of SAE
    sets, and the ``clt_names`` / ``lorsa_names`` entries of circuits. Any other update is a
    plain ``$set`` on the SAE document.
    """
    new_name = update_data.get("name")
    sae_filter = {"name": name, "series": series}

    if not new_name or new_name == name:
        self.sae_collection.update_one(sae_filter, {"$set": update_data})
        return

    # Collections that store the SAE name in a scalar ``sae_name`` field.
    scalar_refs = (self.feature_collection, self.analysis_collection, self.bookmark_collection)
    # (collection, array field, positional update path) for collections storing lists of SAE names.
    array_refs = (
        (self.sae_set_collection, "sae_names", "sae_names.$[elem]"),
        (self.circuit_collection, "clt_names", "clt_names.$[elem]"),
        (self.circuit_collection, "lorsa_names", "lorsa_names.$[elem]"),
    )

    with self.client.start_session() as session:
        with session.start_transaction():
            for collection in scalar_refs:
                collection.update_many(
                    {"sae_name": name, "sae_series": series},
                    {"$set": {"sae_name": new_name}},
                    session=session,
                )
            for collection, field, path in array_refs:
                # array_filters restricts the positional update to elements equal to the old name.
                collection.update_many(
                    {field: name, "sae_series": series},
                    {"$set": {path: new_name}},
                    array_filters=[{"elem": name}],
                    session=session,
                )
            self.sae_collection.update_one(sae_filter, {"$set": update_data}, session=session)

update_sae_set

update_sae_set(name: str, update_data: dict[str, Any])

Update an SAE set and all its references.

If the name is updated, all references in other collections are also updated within a transaction.

Source code in src/lm_saes/database.py
def update_sae_set(self, name: str, update_data: dict[str, Any]):
    """Apply ``update_data`` to the SAE set called ``name``.

    When ``update_data`` renames the set (a truthy ``"name"`` different from ``name``), circuits
    whose ``sae_set_name`` is the old name are repointed in the same transaction as the set
    update. Otherwise only the set document is updated.
    """
    new_name = update_data.get("name")
    set_filter = {"name": name}
    set_update = {"$set": update_data}

    if not new_name or new_name == name:
        self.sae_set_collection.update_one(set_filter, set_update)
        return

    with self.client.start_session() as session, session.start_transaction():
        self.circuit_collection.update_many(
            {"sae_set_name": name}, {"$set": {"sae_set_name": new_name}}, session=session
        )
        self.sae_set_collection.update_one(set_filter, set_update, session=session)

get_random_alive_feature

get_random_alive_feature(
    sae_name: str,
    sae_series: str,
    name: str | None = None,
    metric_filters: Optional[
        dict[str, dict[str, float]]
    ] = None,
) -> Optional[FeatureRecord]

Get a random feature that has non-zero activation.

Parameters:

Name Type Description Default
sae_name str

Name of the SAE model

required
sae_series str

Series of the SAE model

required
name str | None

Name of the analysis

None
metric_filters Optional[dict[str, dict[str, float]]]

Optional dict of metric filters in the format {"metric_name": {"$gte": value, "$lte": value}}

None

Returns:

Type Description
Optional[FeatureRecord]

A random feature record with non-zero activation, or None if no such feature exists

Source code in src/lm_saes/database.py
def get_random_alive_feature(
    self,
    sae_name: str,
    sae_series: str,
    name: str | None = None,
    metric_filters: Optional[dict[str, dict[str, float]]] = None,
) -> Optional[FeatureRecord]:
    """Sample one feature that has at least one analysis with ``max_feature_acts > 0``.

    Args:
        sae_name: Name of the SAE model
        sae_series: Series of the SAE model
        name: Name of the analysis; when given, the qualifying analysis must also have this name
        metric_filters: Optional dict of metric filters in the format {"metric_name": {"$gte": value, "$lte": value}}

    Returns:
        A random feature record with non-zero activation, or None if no such feature exists
    """
    analysis_criteria: dict[str, Any] = {"max_feature_acts": {"$gt": 0}}
    if name is not None:
        analysis_criteria["name"] = name

    query: dict[str, Any] = {
        "sae_name": sae_name,
        "sae_series": sae_series,
        "analyses": {"$elemMatch": analysis_criteria},
    }
    # Metric constraints address the nested ``metric`` sub-document via dotted paths.
    query.update({f"metric.{metric}": bounds for metric, bounds in (metric_filters or {}).items()})

    sampled = self.feature_collection.aggregate([{"$match": query}, {"$sample": {"size": 1}}])
    document = next(sampled, None)
    if document is None:
        return None

    # Stored GridFS references are resolved back into array data before validation.
    if self.is_gridfs_enabled():
        document = self._from_gridfs(document)
    return FeatureRecord.model_validate(document)

update_feature

update_feature(
    sae_name: str,
    feature_index: int,
    update_data: dict,
    sae_series: str | None = None,
)

Update a feature with additional data.

Parameters:

Name Type Description Default
sae_name str

Name of the SAE

required
feature_index int

Index of the feature to update

required
update_data dict

Dictionary with data to update

required
sae_series str | None

Optional series of the SAE

None

Returns:

Type Description

Result of the update operation

Raises:

Type Description
ValueError

If the feature doesn't exist

Source code in src/lm_saes/database.py
def update_feature(self, sae_name: str, feature_index: int, update_data: dict, sae_series: str | None = None):
    """Update a feature with additional data.

    Args:
        sae_name: Name of the SAE
        feature_index: Index of the feature to update
        update_data: Dictionary with data to update
        sae_series: Optional series of the SAE

    Returns:
        Result of the update operation

    Raises:
        ValueError: If the feature doesn't exist
    """
    # Ensure we have a non-None sae_series
    if sae_series is None:
        raise ValueError("sae_series cannot be None")

    feature = self.get_feature(sae_name, sae_series, feature_index)
    if feature is None:
        raise ValueError(f"Feature {feature_index} not found for SAE {sae_name}/{sae_series}")

    # Initialize GridFS if not already done
    if not self.is_gridfs_enabled():
        self.enable_gridfs()

    # Convert numpy arrays to GridFS references
    processed_update_data = self._to_gridfs(update_data)

    result = self.feature_collection.update_one(
        {"sae_name": sae_name, "sae_series": sae_series, "index": feature_index}, {"$set": processed_update_data}
    )

    return result

add_bookmark

add_bookmark(
    sae_name: str,
    sae_series: str,
    feature_index: int,
    tags: Optional[list[str]] = None,
    notes: Optional[str] = None,
) -> bool

Add a bookmark for a feature.

Parameters:

Name Type Description Default
sae_name str

Name of the SAE

required
sae_series str

Series of the SAE

required
feature_index int

Index of the feature to bookmark

required
tags Optional[list[str]]

Optional list of tags for the bookmark

None
notes Optional[str]

Optional notes for the bookmark

None

Returns:

Name Type Description
bool bool

True if bookmark was added, False if it already exists

Raises:

Type Description
ValueError

If the feature doesn't exist

Source code in src/lm_saes/database.py
def add_bookmark(
    self,
    sae_name: str,
    sae_series: str,
    feature_index: int,
    tags: Optional[list[str]] = None,
    notes: Optional[str] = None,
) -> bool:
    """Add a bookmark for a feature.

    Args:
        sae_name: Name of the SAE
        sae_series: Series of the SAE
        feature_index: Index of the feature to bookmark
        tags: Optional list of tags for the bookmark (stored as [] when omitted)
        notes: Optional notes for the bookmark

    Returns:
        bool: True if bookmark was added, False if it already exists

    Raises:
        ValueError: If the feature doesn't exist
    """
    from datetime import timezone

    # Check if feature exists
    feature = self.get_feature(sae_name, sae_series, feature_index)
    if feature is None:
        raise ValueError(f"Feature {feature_index} not found for SAE {sae_name}/{sae_series}")

    bookmark_data = {
        "sae_name": sae_name,
        "sae_series": sae_series,
        "feature_index": feature_index,
        # Timezone-aware UTC timestamp: datetime.utcnow() is deprecated since Python 3.12.
        # PyMongo converts aware datetimes to UTC on write, so the stored instant is the same.
        "created_at": datetime.now(timezone.utc),
        "tags": tags or [],
        "notes": notes,
    }

    try:
        result = self.bookmark_collection.insert_one(bookmark_data)
    except pymongo.errors.DuplicateKeyError:
        # The unique (sae_name, sae_series, feature_index) index rejects a second bookmark.
        return False
    return result.inserted_id is not None

remove_bookmark

remove_bookmark(
    sae_name: str, sae_series: str, feature_index: int
) -> bool

Remove a bookmark for a feature.

Parameters:

Name Type Description Default
sae_name str

Name of the SAE

required
sae_series str

Series of the SAE

required
feature_index int

Index of the feature to remove bookmark from

required

Returns:

Name Type Description
bool bool

True if bookmark was removed, False if it didn't exist

Source code in src/lm_saes/database.py
def remove_bookmark(self, sae_name: str, sae_series: str, feature_index: int) -> bool:
    """Delete the bookmark on one feature.

    Args:
        sae_name: Name of the SAE
        sae_series: Series of the SAE
        feature_index: Index of the feature to remove bookmark from

    Returns:
        bool: True if a bookmark document was deleted, False if none matched
    """
    bookmark_key = {"sae_name": sae_name, "sae_series": sae_series, "feature_index": feature_index}
    deleted = self.bookmark_collection.delete_one(bookmark_key).deleted_count
    return deleted > 0

is_bookmarked

is_bookmarked(
    sae_name: str, sae_series: str, feature_index: int
) -> bool

Check if a feature is bookmarked.

Parameters:

Name Type Description Default
sae_name str

Name of the SAE

required
sae_series str

Series of the SAE

required
feature_index int

Index of the feature

required

Returns:

Name Type Description
bool bool

True if the feature is bookmarked, False otherwise

Source code in src/lm_saes/database.py
def is_bookmarked(self, sae_name: str, sae_series: str, feature_index: int) -> bool:
    """Tell whether a bookmark exists for the given feature.

    Args:
        sae_name: Name of the SAE
        sae_series: Series of the SAE
        feature_index: Index of the feature

    Returns:
        bool: True if a matching bookmark document exists, False otherwise
    """
    bookmark_key = {"sae_name": sae_name, "sae_series": sae_series, "feature_index": feature_index}
    return self.bookmark_collection.find_one(bookmark_key) is not None

get_bookmark

get_bookmark(
    sae_name: str, sae_series: str, feature_index: int
) -> Optional[BookmarkRecord]

Get a specific bookmark.

Parameters:

Name Type Description Default
sae_name str

Name of the SAE

required
sae_series str

Series of the SAE

required
feature_index int

Index of the feature

required

Returns:

Name Type Description
BookmarkRecord Optional[BookmarkRecord]

The bookmark record if it exists, None otherwise

Source code in src/lm_saes/database.py
def get_bookmark(self, sae_name: str, sae_series: str, feature_index: int) -> Optional[BookmarkRecord]:
    """Fetch the bookmark on one feature.

    Args:
        sae_name: Name of the SAE
        sae_series: Series of the SAE
        feature_index: Index of the feature

    Returns:
        BookmarkRecord: The validated bookmark record, or None if no bookmark matches
    """
    found = self.bookmark_collection.find_one(
        {"sae_name": sae_name, "sae_series": sae_series, "feature_index": feature_index}
    )
    return None if found is None else BookmarkRecord.model_validate(found)

list_bookmarks

list_bookmarks(
    sae_name: Optional[str] = None,
    sae_series: Optional[str] = None,
    tags: Optional[list[str]] = None,
    limit: Optional[int] = None,
    skip: int = 0,
) -> list[BookmarkRecord]

List bookmarks with optional filtering.

Source code in src/lm_saes/database.py
def list_bookmarks(
    self,
    sae_name: Optional[str] = None,
    sae_series: Optional[str] = None,
    tags: Optional[list[str]] = None,
    limit: Optional[int] = None,
    skip: int = 0,
) -> list[BookmarkRecord]:
    """List bookmarks, newest first, with optional filtering and paging.

    ``sae_name`` / ``sae_series`` filter only when not None. A non-empty ``tags`` list matches
    bookmarks carrying any of the tags. ``skip`` is applied only when positive and ``limit`` only
    when not None.
    """
    criteria = {"sae_name": sae_name, "sae_series": sae_series}
    query = {field: value for field, value in criteria.items() if value is not None}
    if tags:
        query["tags"] = {"$in": tags}

    cursor = self.bookmark_collection.find(query).sort("created_at", pymongo.DESCENDING)
    if skip > 0:
        cursor = cursor.skip(skip)
    if limit is not None:
        cursor = cursor.limit(limit)

    return list(map(BookmarkRecord.model_validate, cursor))

update_bookmark

update_bookmark(
    sae_name: str,
    sae_series: str,
    feature_index: int,
    tags: Optional[list[str]] = None,
    notes: Optional[str] = None,
) -> bool

Update an existing bookmark.

Source code in src/lm_saes/database.py
def update_bookmark(
    self,
    sae_name: str,
    sae_series: str,
    feature_index: int,
    tags: Optional[list[str]] = None,
    notes: Optional[str] = None,
) -> bool:
    """Update the tags and/or notes of an existing bookmark.

    Only arguments that are not None are written. Passing neither is a no-op.

    Returns:
        bool: True if a matching bookmark exists (or there was nothing to update), False if no
        bookmark matches.
    """
    update_data = {}
    if tags is not None:
        update_data["tags"] = tags
    if notes is not None:
        update_data["notes"] = notes

    if not update_data:
        return True  # Nothing to update

    result = self.bookmark_collection.update_one(
        {
            "sae_name": sae_name,
            "sae_series": sae_series,
            "feature_index": feature_index,
        },
        {"$set": update_data},
    )
    # matched_count, not modified_count: re-saving identical tags/notes matches the bookmark but
    # modifies nothing, and must not be reported as a missing bookmark.
    return result.matched_count > 0

get_bookmark_count

get_bookmark_count(
    sae_name: Optional[str] = None,
    sae_series: Optional[str] = None,
) -> int

Get the total count of bookmarks with optional filtering.

Source code in src/lm_saes/database.py
def get_bookmark_count(self, sae_name: Optional[str] = None, sae_series: Optional[str] = None) -> int:
    """Count bookmarks, optionally restricted to an SAE name and/or series (None means no filter)."""
    criteria = {"sae_name": sae_name, "sae_series": sae_series}
    query = {field: value for field, value in criteria.items() if value is not None}
    return self.bookmark_collection.count_documents(query)

get_available_metrics

get_available_metrics(
    sae_name: str, sae_series: str
) -> list[str]

Get available metrics for an SAE by checking the first feature.

Parameters:

Name Type Description Default
sae_name str

Name of the SAE model

required
sae_series str

Series of the SAE model

required

Returns:

Type Description
list[str]

List of available metric names

Source code in src/lm_saes/database.py
def get_available_metrics(self, sae_name: str, sae_series: str) -> list[str]:
    """List metric names recorded for an SAE, taken from a single feature document.

    The document inspected is whichever one ``find_one`` returns first for this SAE (no sort is
    applied).

    Args:
        sae_name: Name of the SAE model
        sae_series: Series of the SAE model

    Returns:
        The keys of that feature's ``metric`` sub-document, or [] if there is no feature or no metrics
    """
    # Project only ``metric`` so large per-analysis arrays are not transferred.
    feature = self.feature_collection.find_one({"sae_name": sae_name, "sae_series": sae_series}, {"metric": 1})
    metrics = None if feature is None else feature.get("metric")
    if metrics is None:
        return []
    return list(metrics.keys())

count_features_with_filters

count_features_with_filters(
    sae_name: str,
    sae_series: str,
    name: str | None = None,
    metric_filters: Optional[
        dict[str, dict[str, float]]
    ] = None,
) -> int

Count features that match the given filters.

Parameters:

Name Type Description Default
sae_name str

Name of the SAE model

required
sae_series str

Series of the SAE model

required
name str | None

Name of the analysis

None
metric_filters Optional[dict[str, dict[str, float]]]

Optional dict of metric filters in the format {"metric_name": {"$gte": value, "$lte": value}}

None

Returns:

Type Description
int

Number of features matching the filters

Source code in src/lm_saes/database.py
def count_features_with_filters(
    self,
    sae_name: str,
    sae_series: str,
    name: str | None = None,
    metric_filters: Optional[dict[str, dict[str, float]]] = None,
) -> int:
    """Count features that match the given filters.

    Args:
        sae_name: Name of the SAE model
        sae_series: Series of the SAE model
        name: Name of the analysis
        metric_filters: Optional dict of metric filters in the format {"metric_name": {"$gte": value, "$lte": value}}

    Returns:
        Number of features matching the filters
    """
    elem_match: dict[str, Any] = {"max_feature_acts": {"$gt": 0}}
    if name is not None:
        elem_match["name"] = name

    match_filter: dict[str, Any] = {
        "sae_name": sae_name,
        "sae_series": sae_series,
        "analyses": {"$elemMatch": elem_match},
    }

    # Add metric filters if provided
    if metric_filters:
        for metric_name, filters in metric_filters.items():
            match_filter[f"metric.{metric_name}"] = filters

    return self.feature_collection.count_documents(match_filter)

create_circuit

create_circuit(
    sae_set_name: str,
    sae_series: str,
    prompt: str,
    input: CircuitInput,
    config: CircuitConfig,
    name: Optional[str] = None,
    group: Optional[str] = None,
    parent_id: Optional[str] = None,
    clt_names: Optional[list[str]] = None,
    lorsa_names: Optional[list[str]] = None,
    use_lorsa: bool = True,
) -> str

Create a new circuit graph record with pending status.

Parameters:

Name Type Description Default
sae_set_name str

Name of the SAE set used.

required
sae_series str

Series of the SAE.

required
prompt str

The prompt used for generation.

required
input CircuitInput

The circuit input configuration.

required
config CircuitConfig

The circuit configuration.

required
name Optional[str]

Optional custom name for the circuit.

None
group Optional[str]

Optional group name.

None
parent_id Optional[str]

Optional parent circuit ID.

None
clt_names Optional[list[str]]

Names of CLT SAEs used.

None
lorsa_names Optional[list[str]]

Names of LORSA SAEs used.

None
use_lorsa bool

Whether LORSA was used.

True

Returns:

Type Description
str

The ID of the created circuit.

Source code in src/lm_saes/database.py
def create_circuit(
    self,
    sae_set_name: str,
    sae_series: str,
    prompt: str,
    input: CircuitInput,
    config: CircuitConfig,
    name: Optional[str] = None,
    group: Optional[str] = None,
    parent_id: Optional[str] = None,
    clt_names: Optional[list[str]] = None,
    lorsa_names: Optional[list[str]] = None,
    use_lorsa: bool = True,
) -> str:
    """Create a new circuit graph record with pending status.

    The record starts with ``status=CircuitStatus.PENDING``, ``progress=0.0`` and no
    ``progress_phase``, ``error_message`` or ``raw_graph_id``.

    Args:
        sae_set_name: Name of the SAE set used.
        sae_series: Series of the SAE.
        prompt: The prompt used for generation.
        input: The circuit input configuration (stored via ``model_dump()``).
        config: The circuit configuration (stored via ``model_dump()``).
        name: Optional custom name for the circuit.
        group: Optional group name.
        parent_id: Optional parent circuit ID.
        clt_names: Names of CLT SAEs used.
        lorsa_names: Names of LORSA SAEs used.
        use_lorsa: Whether LORSA was used.

    Returns:
        The ID of the created circuit, as a string.
    """
    from datetime import timezone

    circuit_data = {
        "name": name,
        "group": group,
        "parent_id": parent_id,
        "sae_set_name": sae_set_name,
        "sae_series": sae_series,
        "prompt": prompt,
        "input": input.model_dump(),
        "config": config.model_dump(),
        # Timezone-aware UTC timestamp: datetime.utcnow() is deprecated since Python 3.12.
        # PyMongo converts aware datetimes to UTC on write, so the stored instant is the same.
        "created_at": datetime.now(timezone.utc),
        "status": CircuitStatus.PENDING,
        "progress": 0.0,
        "progress_phase": None,
        "error_message": None,
        "raw_graph_id": None,
        "clt_names": clt_names,
        "lorsa_names": lorsa_names,
        "use_lorsa": use_lorsa,
    }
    result = self.circuit_collection.insert_one(circuit_data)
    return str(result.inserted_id)

get_circuit

get_circuit(circuit_id: str) -> Optional[CircuitRecord]

Get a circuit by its ID.

Source code in src/lm_saes/database.py
def get_circuit(self, circuit_id: str) -> Optional[CircuitRecord]:
    """Get a circuit by its ID.

    Returns None when ``circuit_id`` cannot be converted to an ObjectId or no circuit matches.
    Errors raised by the database query itself propagate to the caller instead of being
    reported as "not found".
    """
    # Only the ID conversion is guarded; a malformed ID is treated as "not found".
    try:
        object_id = ObjectId(circuit_id)
    except Exception:
        return None
    circuit = self.circuit_collection.find_one({"_id": object_id})
    if circuit is None:
        return None
    # Expose the Mongo ``_id`` as a string ``id`` field for CircuitRecord.
    circuit["id"] = str(circuit.pop("_id"))
    return CircuitRecord.model_validate(circuit)

list_circuits

list_circuits(
    sae_series: Optional[str] = None,
    group: Optional[str] = None,
    limit: Optional[int] = None,
    skip: int = 0,
) -> list[dict[str, Any]]

List circuits with optional filtering.

Note: raw_graph_id is excluded from the listing for efficiency.

Source code in src/lm_saes/database.py
def list_circuits(
    self,
    sae_series: Optional[str] = None,
    group: Optional[str] = None,
    limit: Optional[int] = None,
    skip: int = 0,
) -> list[dict[str, Any]]:
    """List circuits with optional filtering.

    Note: raw_graph_id is excluded from the listing for efficiency.

    Args:
        sae_series: If given, only circuits of this SAE series are returned.
        group: If given, only circuits in this group are returned.
        limit: Maximum number of circuits to return; None means no limit.
        skip: Number of leading results to skip (only applied when > 0).

    Returns:
        Circuit documents sorted by ``created_at`` descending, each with the
        Mongo ``_id`` replaced by a string ``id`` key.
    """
    # Only filters that were actually supplied become part of the query.
    candidate_filters = {"sae_series": sae_series, "group": group}
    query: dict[str, Any] = {
        field: value for field, value in candidate_filters.items() if value is not None
    }

    cursor = self.circuit_collection.find(query, projection={"raw_graph_id": 0})
    cursor = cursor.sort("created_at", pymongo.DESCENDING)
    if skip > 0:
        cursor = cursor.skip(skip)
    if limit is not None:
        cursor = cursor.limit(limit)

    def _with_str_id(doc: dict[str, Any]) -> dict[str, Any]:
        # Replace the ObjectId primary key with a JSON-friendly string id.
        doc["id"] = str(doc.pop("_id"))
        return doc

    return [_with_str_id(doc) for doc in cursor]

update_circuits_group

update_circuits_group(
    circuit_ids: list[str], group: Optional[str]
) -> int

Update the group for multiple circuits.

Source code in src/lm_saes/database.py
def update_circuits_group(self, circuit_ids: list[str], group: Optional[str]) -> int:
    """Update the group for multiple circuits.

    Args:
        circuit_ids: String ObjectIds of the circuits to regroup.
        group: The new group name (None is stored as-is).

    Returns:
        The number of documents MongoDB reports as modified. If any ID cannot
        be converted to an ObjectId, nothing is updated and 0 is returned.
    """
    try:
        ids = list(map(ObjectId, circuit_ids))
    except Exception:
        return 0

    selector = {"_id": {"$in": ids}}
    change = {"$set": {"group": group}}
    return self.circuit_collection.update_many(selector, change).modified_count

update_circuit

update_circuit(
    circuit_id: str, update_data: dict[str, Any]
) -> bool

Update a circuit by its ID.

Source code in src/lm_saes/database.py
def update_circuit(self, circuit_id: str, update_data: dict[str, Any]) -> bool:
    """Update a circuit by its ID.

    The fields in ``update_data`` are applied with ``$set``.

    Args:
        circuit_id: String form of the circuit's ObjectId.
        update_data: Field names and the values to set on the circuit.

    Returns:
        True if a circuit with this ID exists and the update was applied,
        including when the stored values already equalled the new ones.
        False if the ID is invalid, no circuit matches, or the update raised.
    """
    try:
        result = self.circuit_collection.update_one({"_id": ObjectId(circuit_id)}, {"$set": update_data})
    except Exception:
        return False
    # Use matched_count rather than modified_count: MongoDB reports
    # modified_count == 0 for a no-op $set, which is still a successful update.
    return result.matched_count > 0

delete_circuit

delete_circuit(circuit_id: str) -> bool

Delete a circuit by its ID.

Also deletes the associated raw graph from GridFS if it exists.

Source code in src/lm_saes/database.py
def delete_circuit(self, circuit_id: str) -> bool:
    """Delete a circuit by its ID.

    Also deletes the associated raw graph from GridFS if it exists.

    The circuit record is removed first, atomically, and its ``raw_graph_id``
    is returned in the same operation. The GridFS file is then removed
    best-effort. A failure there does not change the result, because the
    circuit itself is already gone. GridFS is enabled on demand, so the graph
    is cleaned up even if this process has not used GridFS yet.

    Args:
        circuit_id: String form of the circuit's ObjectId.

    Returns:
        True if a circuit was deleted. False if the ID is invalid, no circuit
        matches, or the delete raised.
    """
    try:
        # find_one_and_delete returns the deleted document (or None), so the
        # lookup and the delete are a single atomic operation.
        circuit = self.circuit_collection.find_one_and_delete(
            {"_id": ObjectId(circuit_id)},
            projection={"raw_graph_id": 1},
        )
    except Exception:
        return False

    if circuit is None:
        return False

    raw_graph_id = circuit.get("raw_graph_id")
    if raw_graph_id:
        try:
            if not self.is_gridfs_enabled():
                self.enable_gridfs()
            if self.fs is not None:
                self.fs.delete(ObjectId(raw_graph_id))
        except Exception:
            pass  # Best-effort cleanup; the circuit record is already deleted.
        # NOTE(review): only the top-level graph file is removed here. Any
        # array files referenced inside it (via _to_gridfs) are not.

    return True

update_circuit_progress

update_circuit_progress(
    circuit_id: str,
    progress: float,
    progress_phase: Optional[str] = None,
) -> bool

Update the progress of a circuit generation.

Parameters:

Name Type Description Default
circuit_id str

The circuit ID.

required
progress float

Progress percentage (0-100).

required
progress_phase Optional[str]

Optional phase description.

None

Returns:

Type Description
bool

True if update was successful.

Source code in src/lm_saes/database.py
def update_circuit_progress(
    self,
    circuit_id: str,
    progress: float,
    progress_phase: Optional[str] = None,
) -> bool:
    """Update the progress of a circuit generation.

    Args:
        circuit_id: The circuit ID.
        progress: Progress percentage (0-100).
        progress_phase: Optional phase description. When None, the stored
            phase is left unchanged.

    Returns:
        True if the circuit exists and the update was applied, including when
        the progress values were already current. False if the ID is invalid,
        no circuit matches, or the update raised.
    """
    update_data: dict[str, Any] = {"progress": progress}
    if progress_phase is not None:
        update_data["progress_phase"] = progress_phase

    try:
        result = self.circuit_collection.update_one(
            {"_id": ObjectId(circuit_id)},
            {"$set": update_data},
        )
    except Exception:
        return False
    # Re-reporting the same progress is a no-op $set (modified_count == 0),
    # so use matched_count to avoid reporting a spurious failure.
    return result.matched_count > 0

update_circuit_status

update_circuit_status(
    circuit_id: str,
    status: str,
    error_message: Optional[str] = None,
) -> bool

Update the status of a circuit.

Parameters:

Name Type Description Default
circuit_id str

The circuit ID.

required
status str

New status (pending, running, completed, failed).

required
error_message Optional[str]

Optional error message for failed status.

None

Returns:

Type Description
bool

True if update was successful.

Source code in src/lm_saes/database.py
def update_circuit_status(
    self,
    circuit_id: str,
    status: str,
    error_message: Optional[str] = None,
) -> bool:
    """Update the status of a circuit.

    Args:
        circuit_id: The circuit ID.
        status: New status (pending, running, completed, failed).
        error_message: Optional error message for failed status. When None,
            the stored error message is left unchanged.

    Returns:
        True if the circuit exists and the update was applied, including when
        it already had this status. False if the ID is invalid, no circuit
        matches, or the update raised.
    """
    update_data: dict[str, Any] = {"status": status}
    if error_message is not None:
        update_data["error_message"] = error_message

    try:
        result = self.circuit_collection.update_one(
            {"_id": ObjectId(circuit_id)},
            {"$set": update_data},
        )
    except Exception:
        return False
    # Setting the same status twice is a no-op $set (modified_count == 0),
    # so use matched_count to avoid reporting a spurious failure.
    return result.matched_count > 0

store_raw_graph

store_raw_graph(
    circuit_id: str, graph_data: dict[str, Any]
) -> bool

Store raw graph data to GridFS and update circuit record.

Parameters:

Name Type Description Default
circuit_id str

The circuit ID.

required
graph_data dict[str, Any]

The raw graph data dictionary with numpy arrays.

required

Returns:

Type Description
bool

True if storage was successful.

Source code in src/lm_saes/database.py
def store_raw_graph(self, circuit_id: str, graph_data: dict[str, Any]) -> bool:
    """Store raw graph data to GridFS and update circuit record.

    Numpy arrays are first converted to GridFS references (``_to_gridfs``).
    The result is pickled into a single GridFS file, whose ID is written to
    the circuit's ``raw_graph_id``.

    GridFS housekeeping:
        * The circuit ID is validated before anything is written.
        * If no circuit matches, the freshly written file is removed again.
        * If the circuit already referenced an older graph file, that file is
          removed best-effort after the reference is swapped.

    Args:
        circuit_id: The circuit ID.
        graph_data: The raw graph data dictionary with numpy arrays.

    Returns:
        True if storage was successful.
    """
    import pickle

    if not self.is_gridfs_enabled():
        self.enable_gridfs()

    assert self.fs is not None

    try:
        # Parse the ID up front so an invalid ID fails before any file is
        # written to GridFS.
        circuit_oid = ObjectId(circuit_id)

        # Convert numpy arrays to GridFS references, then pickle the result
        # into a single GridFS file.
        processed_data = self._to_gridfs(graph_data)
        graph_bytes = pickle.dumps(processed_data)
        graph_id = self.fs.put(graph_bytes, filename=f"circuit_{circuit_id}_graph")

        # Atomically point the circuit at the new graph. By default the
        # returned document is the pre-update one, so it tells us which graph
        # (if any) was referenced before. None means no circuit matched.
        previous = self.circuit_collection.find_one_and_update(
            {"_id": circuit_oid},
            {"$set": {"raw_graph_id": str(graph_id)}},
            projection={"raw_graph_id": 1},
        )
    except Exception:
        return False

    if previous is None:
        # Nothing references the file we just wrote; drop it.
        try:
            self.fs.delete(graph_id)
        except Exception:
            pass  # Best-effort cleanup.
        return False

    old_graph_id = previous.get("raw_graph_id")
    if old_graph_id and old_graph_id != str(graph_id):
        # The previous graph file is no longer referenced by this circuit.
        try:
            self.fs.delete(ObjectId(old_graph_id))
        except Exception:
            pass  # Best-effort cleanup; the new graph is already stored.
        # NOTE(review): array files referenced inside the old graph (created
        # by _to_gridfs) are not removed here.

    return True

load_raw_graph

load_raw_graph(circuit_id: str) -> Optional[dict[str, Any]]

Load raw graph data from GridFS.

Parameters:

Name Type Description Default
circuit_id str

The circuit ID.

required

Returns:

Type Description
Optional[dict[str, Any]]

The raw graph data dictionary with numpy arrays, or None if not found.

Source code in src/lm_saes/database.py
def load_raw_graph(self, circuit_id: str) -> Optional[dict[str, Any]]:
    """Load raw graph data from GridFS.

    Args:
        circuit_id: The circuit ID.

    Returns:
        The raw graph data dictionary with numpy arrays, or None if not found.
        None is also returned if the ID is invalid or any step raises.
    """
    import pickle

    if not self.is_gridfs_enabled():
        self.enable_gridfs()

    assert self.fs is not None

    try:
        # Only the graph reference is needed, so don't fetch the whole record.
        circuit = self.circuit_collection.find_one(
            {"_id": ObjectId(circuit_id)},
            projection={"raw_graph_id": 1},
        )
        if circuit is None or circuit.get("raw_graph_id") is None:
            return None

        # EAFP: if the file is missing, fs.get raises. The handler below turns
        # that into None, so a separate exists() round trip is unnecessary.
        graph_bytes = self.fs.get(ObjectId(circuit["raw_graph_id"])).read()

        # SECURITY NOTE(review): pickle.loads can execute arbitrary code.
        # This is only safe as long as the GridFS bucket is written
        # exclusively by store_raw_graph and not by untrusted parties.
        processed_data = pickle.loads(graph_bytes)

        # Convert GridFS references back to numpy arrays
        return self._from_gridfs(processed_data)
    except Exception:
        return None

get_circuit_status

get_circuit_status(
    circuit_id: str,
) -> Optional[dict[str, Any]]

Get just the status information for a circuit.

Parameters:

Name Type Description Default
circuit_id str

The circuit ID.

required

Returns:

Type Description
Optional[dict[str, Any]]

Dict with status, progress, progress_phase, and error_message.

Source code in src/lm_saes/database.py
def get_circuit_status(self, circuit_id: str) -> Optional[dict[str, Any]]:
    """Get just the status information for a circuit.

    Args:
        circuit_id: The circuit ID.

    Returns:
        Dict with status, progress, progress_phase, and error_message.
        A missing ``status`` defaults to ``CircuitStatus.PENDING`` and a
        missing ``progress`` defaults to 0.0; the other two default to None.
        Returns None if the ID is invalid, the lookup raised, or no circuit
        matches.
    """
    status_fields = ("status", "progress", "progress_phase", "error_message")

    try:
        doc = self.circuit_collection.find_one(
            {"_id": ObjectId(circuit_id)},
            projection={name: 1 for name in status_fields},
        )
    except Exception:
        return None

    if doc is None:
        return None

    fallbacks = {"status": CircuitStatus.PENDING, "progress": 0.0}
    return {name: doc.get(name, fallbacks.get(name)) for name in status_fields}