
Runners

High-level runner functions and their settings for common workflows.

PretrainedSAE pydantic-model

Bases: BaseModelConfig

Fields:

  • device (str)
  • dtype (dtype)
  • pretrained_name_or_path (str)
  • fold_activation_scale (bool)
  • strict_loading (bool)
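A minimal sketch of how these fields compose, using a plain dataclass as a stand-in for the pydantic model (field names come from the list above; the defaults shown here are illustrative assumptions, and the real `dtype` field is a torch dtype rather than a string):

```python
from dataclasses import dataclass

# Stand-in for the PretrainedSAE pydantic model. Field names match the
# documented fields; default values are assumptions for illustration only.
@dataclass
class PretrainedSAEStandIn:
    pretrained_name_or_path: str
    device: str = "cuda"
    dtype: str = "float32"  # the real field is a torch dtype
    fold_activation_scale: bool = False
    strict_loading: bool = True

cfg = PretrainedSAEStandIn(pretrained_name_or_path="path/to/sae")
print(cfg.strict_loading)  # → True
```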

TrainSAESettings

Bases: BaseSettings

Settings for training a Sparse Autoencoder (SAE).

sae instance-attribute

Configuration for the SAE model architecture and parameters, or the path to a pretrained SAE.

sae_name instance-attribute

sae_name: str

Name of the SAE model, used as its identifier in the database.

sae_series instance-attribute

sae_series: str

Series of the SAE model, used as its identifier in the database.

initializer class-attribute instance-attribute

initializer: InitializerConfig | None = None

Configuration for model initialization. Should be None for a pretrained SAE.

trainer instance-attribute

trainer: TrainerConfig

Configuration for the training process

activation_factory instance-attribute

activation_factory: ActivationFactoryConfig

Configuration for generating activations

wandb class-attribute instance-attribute

wandb: Optional[WandbConfig] = None

Configuration for Weights & Biases logging

eval class-attribute instance-attribute

eval: bool = False

Whether to run in evaluation mode

data_parallel_size class-attribute instance-attribute

data_parallel_size: int = 1

Size of data parallel mesh

model_parallel_size class-attribute instance-attribute

model_parallel_size: int = 1

Size of model parallel (tensor parallel) mesh

mongo class-attribute instance-attribute

mongo: Optional[MongoDBConfig] = None

Configuration for MongoDB

model class-attribute instance-attribute

model: Optional[LanguageModelConfig] = None

Configuration for the language model. Required if using dataset sources.

model_name class-attribute instance-attribute

model_name: Optional[str] = None

Name of the tokenizer to load. Mixcoder requires a tokenizer to get the modality indices.

datasets class-attribute instance-attribute

datasets: Optional[dict[str, Optional[DatasetConfig]]] = (
    None
)

Name to dataset config mapping. Required if using dataset sources.

device_type class-attribute instance-attribute

device_type: str = 'cuda'

Device type to use for distributed training ('cuda' or 'cpu')
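The two parallel-size fields above determine whether a device mesh is built at all: `train_sae` only initializes a ("data", "model") mesh when at least one size exceeds 1. A stdlib sketch of that gating logic (a stand-in that mirrors the condition in the source below without calling torch's `init_device_mesh`):

```python
def mesh_shape(data_parallel_size: int = 1, model_parallel_size: int = 1):
    """Mirror of the mesh gating in train_sae: a ("data", "model") mesh is
    only initialized when some parallel size exceeds 1; otherwise the run is
    single-process and no mesh is built."""
    if data_parallel_size > 1 or model_parallel_size > 1:
        return (data_parallel_size, model_parallel_size)
    return None

print(mesh_shape())      # → None (single-process run)
print(mesh_shape(4, 2))  # → (4, 2)
```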

train_sae

train_sae(settings: TrainSAESettings) -> None

Train a SAE model.

Parameters:

  settings (TrainSAESettings): Configuration settings for SAE training. Required.
Source code in src/lm_saes/runners/train.py
def train_sae(settings: TrainSAESettings) -> None:
    """Train a SAE model.

    Args:
        settings: Configuration settings for SAE training
    """
    # Set up logging
    setup_logging(level="INFO")

    device_mesh = (
        init_device_mesh(
            device_type=settings.device_type,
            mesh_shape=(settings.data_parallel_size, settings.model_parallel_size),
            mesh_dim_names=("data", "model"),
        )
        if settings.model_parallel_size > 1 or settings.data_parallel_size > 1
        else None
    )

    logger.info(f"Device mesh initialized: {device_mesh}")

    mongo_client = MongoClient(settings.mongo) if settings.mongo is not None else None
    if mongo_client:
        logger.info("MongoDB client initialized")

    # Load configurations
    model_cfg = load_config(
        config=settings.model,
        name=settings.model_name,
        mongo_client=mongo_client,
        config_type="model",
        required=False,
    )

    dataset_cfgs = (
        {
            dataset_name: load_config(
                config=dataset_cfg,
                name=dataset_name,
                mongo_client=mongo_client,
                config_type="dataset",
            )
            for dataset_name, dataset_cfg in settings.datasets.items()
        }
        if settings.datasets is not None
        else None
    )

    # Load model and datasets
    logger.info("Loading model and datasets")
    model = load_model(model_cfg) if model_cfg is not None else None
    datasets = (
        {
            dataset_name: load_dataset(dataset_cfg, device_mesh=device_mesh)
            for dataset_name, dataset_cfg in dataset_cfgs.items()
        }
        if dataset_cfgs is not None
        else None
    )

    activation_factory = ActivationFactory(settings.activation_factory, device_mesh=device_mesh)

    logger.info("Processing activations stream")
    activations_stream = activation_factory.process(
        model=model,
        model_name=settings.model_name,
        datasets=datasets,
    )

    logger.info("Initializing SAE")

    wandb_logger = (
        wandb.init(
            project=settings.wandb.wandb_project,
            config=settings.model_dump(),
            name=settings.wandb.exp_name,
            entity=settings.wandb.wandb_entity,
            settings=wandb.Settings(x_disable_stats=True),
            mode=os.getenv("WANDB_MODE", "online"),  # type: ignore
            resume=settings.wandb.wandb_resume,
            id=settings.wandb.wandb_run_id,
        )
        if settings.wandb is not None and (device_mesh is None or mesh_rank(device_mesh) == 0)
        else None
    )

    assert settings.initializer is None or not isinstance(settings.initializer, str), (
        "Cannot use an initializer for a pretrained SAE"
    )
    if isinstance(settings.sae, PretrainedSAE):
        sae = AbstractSparseAutoEncoder.from_pretrained(
            settings.sae.pretrained_name_or_path,
            device_mesh=device_mesh,
            fold_activation_scale=settings.sae.fold_activation_scale,
            strict_loading=settings.sae.strict_loading,
            device=settings.sae.device,
            dtype=settings.sae.dtype,
        )
    elif settings.initializer is not None:
        initializer = Initializer(settings.initializer)
        sae = initializer.initialize_sae_from_config(
            settings.sae,
            activation_stream=activations_stream,
            device_mesh=device_mesh,
            wandb_logger=wandb_logger,
            model=model,
        )
    else:
        sae = AbstractSparseAutoEncoder.from_config(settings.sae, device_mesh=device_mesh)

    if settings.trainer.from_pretrained_path is not None:
        trainer = Trainer.from_checkpoint(
            sae,
            settings.trainer.from_pretrained_path,
        )
        trainer.wandb_logger = wandb_logger
    else:
        trainer = Trainer(settings.trainer)

    logger.info(f"SAE initialized: {type(sae).__name__}")

    if wandb_logger is not None:
        logger.info("WandB logger initialized")

    # TODO: implement eval_fn
    eval_fn = (lambda x: None) if settings.eval else None

    logger.info("Starting training")

    sae.cfg.save_hyperparameters(settings.trainer.exp_result_path)
    end_of_stream = trainer.fit(
        sae=sae, activation_stream=activations_stream, eval_fn=eval_fn, wandb_logger=wandb_logger
    )
    logger.info("Training completed, saving model")
    if end_of_stream:
        trainer.save_checkpoint(
            sae=sae,
            checkpoint_path=settings.trainer.exp_result_path,
        )
    else:
        sae.save_pretrained(
            save_path=settings.trainer.exp_result_path,
        )
        if is_primary_rank(device_mesh) and mongo_client is not None:
            assert settings.sae_name is not None and settings.sae_series is not None, (
                "sae_name and sae_series must be provided when saving to MongoDB"
            )
            mongo_client.create_sae(
                name=settings.sae_name,
                series=settings.sae_series,
                path=str(Path(settings.trainer.exp_result_path).absolute()),
                cfg=sae.cfg,
            )

    if wandb_logger is not None:
        wandb_logger.finish()
        logger.info("WandB session closed")

    logger.info("SAE training completed successfully")
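The save logic at the end of `train_sae` can be condensed into a small decision sketch (`save_strategy` and its string step names are illustrative stand-ins, not part of the package):

```python
def save_strategy(end_of_stream: bool, has_mongo: bool) -> list[str]:
    """Mirror of train_sae's save branch: save a trainer checkpoint when the
    activation stream was exhausted; otherwise do a pretrained-style save,
    plus a MongoDB registration when a client is configured."""
    if end_of_stream:
        return ["trainer.save_checkpoint"]
    steps = ["sae.save_pretrained"]
    if has_mongo:
        steps.append("mongo_client.create_sae")
    return steps

print(save_strategy(end_of_stream=True, has_mongo=False))
# → ['trainer.save_checkpoint']
print(save_strategy(end_of_stream=False, has_mongo=True))
# → ['sae.save_pretrained', 'mongo_client.create_sae']
```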

TrainCLTSettings

Bases: BaseSettings

Settings for training a Cross Layer Transcoder (CLT). CLT works with multiple layers and their corresponding hook points.

sae instance-attribute

Configuration for the CLT model architecture and parameters, or the path to a pretrained CLT.

sae_name instance-attribute

sae_name: str

Name of the SAE model, used as its identifier in the database.

sae_series instance-attribute

sae_series: str

Series of the SAE model, used as its identifier in the database.

initializer class-attribute instance-attribute

initializer: InitializerConfig | None = None

Configuration for model initialization

trainer instance-attribute

trainer: TrainerConfig

Configuration for the training process

activation_factory instance-attribute

activation_factory: ActivationFactoryConfig

Configuration for generating activations

wandb class-attribute instance-attribute

wandb: Optional[WandbConfig] = None

Configuration for Weights & Biases logging

eval class-attribute instance-attribute

eval: bool = False

Whether to run in evaluation mode

data_parallel_size class-attribute instance-attribute

data_parallel_size: int = 1

Size of data parallel mesh

model_parallel_size class-attribute instance-attribute

model_parallel_size: int = 1

Size of model parallel (tensor parallel) mesh

mongo class-attribute instance-attribute

mongo: Optional[MongoDBConfig] = None

Configuration for MongoDB

model class-attribute instance-attribute

model: Optional[LanguageModelConfig] = None

Configuration for the language model. Required if using dataset sources.

model_name class-attribute instance-attribute

model_name: Optional[str] = None

Name of the tokenizer to load. CLT requires a tokenizer to get the modality indices.

datasets class-attribute instance-attribute

datasets: Optional[dict[str, Optional[DatasetConfig]]] = (
    None
)

Name to dataset config mapping. Required if using dataset sources.

device_type class-attribute instance-attribute

device_type: str = 'cuda'

Device type to use for distributed training ('cuda' or 'cpu')

train_clt

train_clt(settings: TrainCLTSettings) -> None

Train a Cross Layer Transcoder (CLT) model.

Parameters:

  settings (TrainCLTSettings): Configuration settings for CLT training. Required.
Source code in src/lm_saes/runners/train.py
def train_clt(settings: TrainCLTSettings) -> None:
    """Train a Cross Layer Transcoder (CLT) model.

    Args:
        settings: Configuration settings for CLT training
    """
    # Set up logging
    setup_logging(level="INFO")

    device_mesh = (
        init_device_mesh(
            device_type=settings.device_type,
            mesh_shape=(settings.data_parallel_size, settings.model_parallel_size),
            mesh_dim_names=("data", "model"),
        )
        if settings.model_parallel_size > 1 or settings.data_parallel_size > 1
        else None
    )

    logger.info(f"Device mesh initialized: {device_mesh}")

    mongo_client = MongoClient(settings.mongo) if settings.mongo is not None else None
    if mongo_client:
        logger.info("MongoDB client initialized")

    # Load configurations
    model_cfg = load_config(
        config=settings.model,
        name=settings.model_name,
        mongo_client=mongo_client,
        config_type="model",
        required=False,
    )

    dataset_cfgs = (
        {
            dataset_name: load_config(
                config=dataset_cfg,
                name=dataset_name,
                mongo_client=mongo_client,
                config_type="dataset",
            )
            for dataset_name, dataset_cfg in settings.datasets.items()
        }
        if settings.datasets is not None
        else None
    )

    # Load model and datasets
    logger.info("Loading model and datasets")
    model = load_model(model_cfg) if model_cfg is not None else None
    datasets = (
        {
            dataset_name: load_dataset(dataset_cfg, device_mesh=device_mesh)
            for dataset_name, dataset_cfg in dataset_cfgs.items()
        }
        if dataset_cfgs is not None
        else None
    )

    activation_factory = ActivationFactory(settings.activation_factory, device_mesh=device_mesh)

    logger.info("Processing activations stream")
    activations_stream = activation_factory.process(
        model=model,
        model_name=settings.model_name,
        datasets=datasets,
    )

    wandb_logger = (
        wandb.init(
            project=settings.wandb.wandb_project,
            config=settings.model_dump(),
            name=settings.wandb.exp_name,
            entity=settings.wandb.wandb_entity,
            settings=wandb.Settings(x_disable_stats=True),
            mode=os.getenv("WANDB_MODE", "online"),  # type: ignore
            resume=settings.wandb.wandb_resume,
            id=settings.wandb.wandb_run_id,
        )
        if settings.wandb is not None and (device_mesh is None or mesh_rank(device_mesh) == 0)
        else None
    )

    logger.info("Initializing CLT")
    assert settings.initializer is None or not isinstance(settings.initializer, str), (
        "Cannot use an initializer for a pretrained CLT"
    )
    if isinstance(settings.sae, PretrainedSAE):
        sae = AbstractSparseAutoEncoder.from_pretrained(
            settings.sae.pretrained_name_or_path,
            device_mesh=device_mesh,
            fold_activation_scale=settings.sae.fold_activation_scale,
            strict_loading=settings.sae.strict_loading,
            device=settings.sae.device,
            dtype=settings.sae.dtype,
        )
    elif settings.initializer is not None:
        initializer = Initializer(settings.initializer)
        sae = initializer.initialize_sae_from_config(
            settings.sae,
            activation_stream=activations_stream,
            device_mesh=device_mesh,
            wandb_logger=wandb_logger,
            model=model,
        )
    else:
        sae = AbstractSparseAutoEncoder.from_config(settings.sae, device_mesh=device_mesh)

    n_params = sum(p.numel() for p in sae.parameters())
    logger.info(f"CLT initialized with {n_params / 1e9:.2f}B parameters")

    if wandb_logger is not None:
        logger.info("WandB logger initialized")

    # TODO: implement eval_fn
    eval_fn = (lambda x: None) if settings.eval else None

    logger.info("Starting CLT training")
    if settings.trainer.from_pretrained_path is not None:
        trainer = Trainer.from_checkpoint(
            sae,
            settings.trainer.from_pretrained_path,
        )
        trainer.wandb_logger = wandb_logger
    else:
        trainer = Trainer(settings.trainer)
    sae.cfg.save_hyperparameters(settings.trainer.exp_result_path)
    end_of_stream = trainer.fit(
        sae=sae, activation_stream=activations_stream, eval_fn=eval_fn, wandb_logger=wandb_logger
    )

    logger.info("Training completed, saving CLT model")
    if end_of_stream:
        trainer.save_checkpoint(
            sae=sae,
            checkpoint_path=settings.trainer.exp_result_path,
        )
    else:
        sae.save_pretrained(
            save_path=settings.trainer.exp_result_path,
        )
        if is_primary_rank(device_mesh) and mongo_client is not None:
            assert settings.sae_name is not None and settings.sae_series is not None, (
                "sae_name and sae_series must be provided when saving to MongoDB"
            )
            mongo_client.create_sae(
                name=settings.sae_name,
                series=settings.sae_series,
                path=str(Path(settings.trainer.exp_result_path).absolute()),
                cfg=sae.cfg,
            )

    if wandb_logger is not None:
        wandb_logger.finish()
        logger.info("WandB session closed")

    logger.info("CLT training completed successfully")

TrainCrossCoderSettings

Bases: BaseSettings

Settings for training a CrossCoder. The main difference from TrainSAESettings is that activation_factories is a list of ActivationFactoryConfig, one for each head.

sae instance-attribute

Configuration for the CrossCoder model architecture and parameters, or the path to a pretrained CrossCoder.

sae_name instance-attribute

sae_name: str

Name of the SAE model, used as its identifier in the database.

sae_series instance-attribute

sae_series: str

Series of the SAE model, used as its identifier in the database.

initializer class-attribute instance-attribute

initializer: InitializerConfig | None = None

Configuration for model initialization

trainer instance-attribute

trainer: TrainerConfig

Configuration for the training process

activation_factories instance-attribute

activation_factories: list[ActivationFactoryConfig]

Configurations for generating activations, one per head

wandb class-attribute instance-attribute

wandb: Optional[WandbConfig] = None

Configuration for Weights & Biases logging

eval class-attribute instance-attribute

eval: bool = False

Whether to run in evaluation mode

data_parallel_size class-attribute instance-attribute

data_parallel_size: int = 1

Size of data parallel mesh

model_parallel_size class-attribute instance-attribute

model_parallel_size: int = 1

Size of model parallel (tensor parallel) mesh

mongo class-attribute instance-attribute

mongo: Optional[MongoDBConfig] = None

Configuration for MongoDB

model class-attribute instance-attribute

model: Optional[LanguageModelConfig] = None

Configuration for the language model. Required if using dataset sources.

model_name class-attribute instance-attribute

model_name: Optional[str] = None

Name of the tokenizer to load. Mixcoder requires a tokenizer to get the modality indices.

datasets class-attribute instance-attribute

datasets: Optional[dict[str, Optional[DatasetConfig]]] = (
    None
)

Name to dataset config mapping. Required if using dataset sources.

device_type class-attribute instance-attribute

device_type: str = 'cuda'

Device type to use for distributed training ('cuda' or 'cpu')

train_crosscoder

train_crosscoder(settings: TrainCrossCoderSettings) -> None

Train a CrossCoder.

Parameters:

  settings (TrainCrossCoderSettings): Configuration settings for CrossCoder training. Required.
Source code in src/lm_saes/runners/train.py
def train_crosscoder(settings: TrainCrossCoderSettings) -> None:
    """Train a CrossCoder.

    Args:
        settings: Configuration settings for CrossCoder training
    """
    # Set up logging
    setup_logging(level="INFO")

    assert isinstance(settings.sae, CrossCoderConfig), "CrossCoderConfig is required for training a CrossCoder"
    assert all(
        len(activation_factory.hook_points) == len(settings.activation_factories[0].hook_points)
        for activation_factory in settings.activation_factories
    ), "Number of hook points of activation factories must be the same"
    assert (
        len(settings.activation_factories) * len(settings.activation_factories[0].hook_points) == settings.sae.n_heads
    ), "Total number of hook points must match the number of heads in the CrossCoder"
    head_parallel_size = len(settings.activation_factories)

    device_mesh = init_device_mesh(
        device_type=settings.device_type,
        mesh_shape=(settings.data_parallel_size, head_parallel_size, settings.model_parallel_size),
        mesh_dim_names=("data", "head", "model"),
    )

    logger.info(
        f"Device mesh initialized with {settings.sae.n_heads} heads, {head_parallel_size} head parallel size, {settings.data_parallel_size} data parallel size, {settings.model_parallel_size} model parallel size"
    )

    mongo_client = MongoClient(settings.mongo) if settings.mongo is not None else None
    if mongo_client:
        logger.info("MongoDB client initialized")

    # Load configurations
    model_cfg = load_config(
        config=settings.model,
        name=settings.model_name,
        mongo_client=mongo_client,
        config_type="model",
        required=False,
    )

    dataset_cfgs = (
        {
            dataset_name: load_config(
                config=dataset_cfg,
                name=dataset_name,
                mongo_client=mongo_client,
                config_type="dataset",
            )
            for dataset_name, dataset_cfg in settings.datasets.items()
        }
        if settings.datasets is not None
        else None
    )

    # Load model and datasets
    logger.info("Loading model and datasets")
    model = load_model(model_cfg) if model_cfg is not None else None
    datasets = (
        {
            dataset_name: load_dataset(dataset_cfg, device_mesh=device_mesh)
            for dataset_name, dataset_cfg in dataset_cfgs.items()
        }
        if dataset_cfgs is not None
        else None
    )

    activation_factory_mesh = device_mesh[
        "data", "model"
    ]  # Remove the head dimension, since each activation factory should only be responsible for a subset of the heads.

    logger.info("Setting up activation factory for CrossCoder")
    activation_factory = ActivationFactory(
        settings.activation_factories[device_mesh.get_local_rank("head")], device_mesh=activation_factory_mesh
    )

    logger.info("Processing activations stream")
    activations_stream = activation_factory.process(
        model=model,
        model_name=settings.model_name,
        datasets=datasets,
    )

    wandb_logger = (
        wandb.init(
            project=settings.wandb.wandb_project,
            config=settings.model_dump(),
            name=settings.wandb.exp_name,
            entity=settings.wandb.wandb_entity,
            settings=wandb.Settings(x_disable_stats=True),
            mode=os.getenv("WANDB_MODE", "online"),  # type: ignore
            resume=settings.wandb.wandb_resume,
            id=settings.wandb.wandb_run_id,
        )
        if settings.wandb is not None and (device_mesh is None or mesh_rank(device_mesh) == 0)
        else None
    )

    if wandb_logger is not None:
        logger.info("WandB logger initialized")

    logger.info("Initializing CrossCoder")
    assert settings.initializer is None or not isinstance(settings.initializer, str), (
        "Cannot use an initializer for a pretrained CrossCoder"
    )
    if isinstance(settings.sae, PretrainedSAE):
        sae = AbstractSparseAutoEncoder.from_pretrained(
            settings.sae.pretrained_name_or_path,
            device_mesh=device_mesh,
            fold_activation_scale=settings.sae.fold_activation_scale,
            strict_loading=settings.sae.strict_loading,
            device=settings.sae.device,
            dtype=settings.sae.dtype,
        )
    elif settings.initializer is not None:
        initializer = Initializer(settings.initializer)
        sae = initializer.initialize_sae_from_config(
            settings.sae,
            activation_stream=activations_stream,
            device_mesh=device_mesh,
            wandb_logger=wandb_logger,
            model=model,
        )
    else:
        sae = AbstractSparseAutoEncoder.from_config(settings.sae, device_mesh=device_mesh)

    logger.info("CrossCoder initialized")

    # TODO: implement eval_fn
    eval_fn = (lambda x: None) if settings.eval else None

    logger.info("Starting CrossCoder training")
    if settings.trainer.from_pretrained_path is not None:
        trainer = Trainer.from_checkpoint(
            sae,
            settings.trainer.from_pretrained_path,
        )
        trainer.wandb_logger = wandb_logger
    else:
        trainer = Trainer(settings.trainer)

    sae.cfg.save_hyperparameters(settings.trainer.exp_result_path)
    end_of_stream = trainer.fit(
        sae=sae, activation_stream=activations_stream, eval_fn=eval_fn, wandb_logger=wandb_logger
    )

    logger.info("Training completed, saving CrossCoder")
    if end_of_stream:
        trainer.save_checkpoint(
            sae=sae,
            checkpoint_path=settings.trainer.exp_result_path,
        )
    else:
        sae.save_pretrained(
            save_path=settings.trainer.exp_result_path,
        )
        if is_primary_rank(device_mesh) and mongo_client is not None:
            assert settings.sae_name is not None and settings.sae_series is not None, (
                "sae_name and sae_series must be provided when saving to MongoDB"
            )
            mongo_client.create_sae(
                name=settings.sae_name,
                series=settings.sae_series,
                path=str(Path(settings.trainer.exp_result_path).absolute()),
                cfg=settings.sae,
            )

    if wandb_logger is not None:
        wandb_logger.finish()
        logger.info("WandB session closed")

    logger.info("CrossCoder training completed successfully")
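The head-count invariants asserted at the top of `train_crosscoder` can be sketched in isolation (`check_crosscoder_heads` is an illustrative stand-in, not a package function):

```python
def check_crosscoder_heads(hook_points_per_factory: list[int], n_heads: int) -> int:
    """Mirror of train_crosscoder's assertions: every activation factory must
    expose the same number of hook points, and the total across factories
    must equal the CrossCoder's n_heads. Returns the head parallel size,
    i.e. the number of activation factories."""
    first = hook_points_per_factory[0]
    assert all(n == first for n in hook_points_per_factory), (
        "Number of hook points of activation factories must be the same"
    )
    assert len(hook_points_per_factory) * first == n_heads, (
        "Total number of hook points must match the number of heads in the CrossCoder"
    )
    return len(hook_points_per_factory)

# Three factories with two hook points each serve a 6-head CrossCoder.
print(check_crosscoder_heads([2, 2, 2], 6))  # → 3
```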

TrainLorsaSettings

Bases: BaseSettings

Settings for training a Lorsa (Low-Rank Sparse Autoencoder) model.

sae instance-attribute

Configuration for the Lorsa model architecture and parameters, or the path to a pretrained Lorsa.

sae_name instance-attribute

sae_name: str

Name of the Lorsa model, used as its identifier in the database.

sae_series instance-attribute

sae_series: str

Series of the Lorsa model, used as its identifier in the database.

initializer class-attribute instance-attribute

initializer: InitializerConfig | None = None

Configuration for model initialization

trainer instance-attribute

trainer: TrainerConfig

Configuration for the training process

activation_factory instance-attribute

activation_factory: ActivationFactoryConfig

Configuration for generating activations

wandb class-attribute instance-attribute

wandb: Optional[WandbConfig] = None

Configuration for Weights & Biases logging

eval class-attribute instance-attribute

eval: bool = False

Whether to run in evaluation mode

model_parallel_size class-attribute instance-attribute

model_parallel_size: int = 1

Size of model parallel (tensor parallel) mesh

data_parallel_size class-attribute instance-attribute

data_parallel_size: int = 1

Size of data parallel mesh

mongo class-attribute instance-attribute

mongo: Optional[MongoDBConfig] = None

Configuration for MongoDB

model class-attribute instance-attribute

model: Optional[LanguageModelConfig] = None

Configuration for the language model. Required if using dataset sources.

model_name class-attribute instance-attribute

model_name: Optional[str] = None

Name of the tokenizer to load. LORSA may require a tokenizer to get the modality indices.

datasets class-attribute instance-attribute

datasets: Optional[dict[str, Optional[DatasetConfig]]] = (
    None
)

Name to dataset config mapping. Required if using dataset sources.

device_type class-attribute instance-attribute

device_type: str = 'cuda'

Device type to use for distributed training ('cuda' or 'cpu')

train_lorsa

train_lorsa(settings: TrainLorsaSettings) -> None

Train a LORSA (Low-Rank Sparse Autoencoder) model.

Parameters:

  settings (TrainLorsaSettings): Configuration settings for LORSA training. Required.
Source code in src/lm_saes/runners/train.py
def train_lorsa(settings: TrainLorsaSettings) -> None:
    """Train a LORSA (Low-Rank Sparse Autoencoder) model.

    Args:
        settings: Configuration settings for LORSA training
    """
    # Set up logging
    setup_logging(level="INFO")

    device_mesh = (
        init_device_mesh(
            device_type=settings.device_type,
            mesh_shape=(settings.data_parallel_size, settings.model_parallel_size),
            mesh_dim_names=("data", "model"),
        )
        if settings.model_parallel_size > 1 or settings.data_parallel_size > 1
        else None
    )

    logger.info(f"Device mesh initialized: {device_mesh}")

    mongo_client = MongoClient(settings.mongo) if settings.mongo is not None else None
    if mongo_client:
        logger.info("MongoDB client initialized")

    # Load configurations
    model_cfg = load_config(
        config=settings.model,
        name=settings.model_name,
        mongo_client=mongo_client,
        config_type="model",
        required=False,
    )

    dataset_cfgs = (
        {
            dataset_name: load_config(
                config=dataset_cfg,
                name=dataset_name,
                mongo_client=mongo_client,
                config_type="dataset",
            )
            for dataset_name, dataset_cfg in settings.datasets.items()
        }
        if settings.datasets is not None
        else None
    )

    # Load model and datasets
    logger.info("Loading model and datasets")
    model = load_model(model_cfg) if model_cfg is not None else None
    datasets = (
        {
            dataset_name: load_dataset(dataset_cfg, device_mesh=device_mesh)
            for dataset_name, dataset_cfg in dataset_cfgs.items()
        }
        if dataset_cfgs is not None
        else None
    )

    activation_factory = ActivationFactory(settings.activation_factory, device_mesh=device_mesh)

    logger.info("Processing activations stream")
    activations_stream = activation_factory.process(
        model=model,
        model_name=settings.model_name,
        datasets=datasets,
    )

    logger.info("Initializing lorsa")

    wandb_logger = (
        wandb.init(
            project=settings.wandb.wandb_project,
            config=settings.model_dump(),
            name=settings.wandb.exp_name,
            entity=settings.wandb.wandb_entity,
            settings=wandb.Settings(x_disable_stats=True),
            mode=os.getenv("WANDB_MODE", "online"),  # type: ignore
            resume=settings.wandb.wandb_resume,
            id=settings.wandb.wandb_run_id,
        )
        if settings.wandb is not None and (device_mesh is None or device_mesh.get_rank() == 0)
        else None
    )

    assert settings.initializer is None or not isinstance(settings.initializer, str), (
        "Cannot use an initializer for a pretrained Lorsa"
    )
    if isinstance(settings.sae, PretrainedSAE):
        sae = AbstractSparseAutoEncoder.from_pretrained(
            settings.sae.pretrained_name_or_path,
            device_mesh=device_mesh,
            fold_activation_scale=settings.sae.fold_activation_scale,
            strict_loading=settings.sae.strict_loading,
            device=settings.sae.device,
            dtype=settings.sae.dtype,
        )
    elif settings.initializer is not None:
        initializer = Initializer(settings.initializer)
        sae = initializer.initialize_sae_from_config(
            settings.sae,
            activation_stream=activations_stream,
            device_mesh=device_mesh,
            wandb_logger=wandb_logger,
            model=model,
        )
    else:
        sae = AbstractSparseAutoEncoder.from_config(settings.sae, device_mesh=device_mesh)

    n_params = sum(p.numel() for p in sae.parameters())
    logger.info(f"lorsa initialized with {n_params / 1e9:.2f}B parameters")

    if wandb_logger is not None:
        logger.info("WandB logger initialized")

    # TODO: implement eval_fn
    eval_fn = (lambda x: None) if settings.eval else None

    logger.info("Starting LORSA training")
    if settings.trainer.from_pretrained_path is not None:
        trainer = Trainer.from_checkpoint(
            sae,
            settings.trainer.from_pretrained_path,
        )
        trainer.wandb_logger = wandb_logger
    else:
        trainer = Trainer(settings.trainer)

    sae.cfg.save_hyperparameters(settings.trainer.exp_result_path)
    end_of_stream = trainer.fit(
        sae=sae, activation_stream=activations_stream, eval_fn=eval_fn, wandb_logger=wandb_logger
    )

    logger.info("Training completed, saving LORSA model")
    if end_of_stream:
        trainer.save_checkpoint(
            sae=sae,
            checkpoint_path=settings.trainer.exp_result_path,
        )
    else:
        sae.save_pretrained(
            save_path=settings.trainer.exp_result_path,
        )
        if is_primary_rank(device_mesh) and mongo_client is not None:
            assert settings.sae_name is not None and settings.sae_series is not None, (
                "sae_name and sae_series must be provided when saving to MongoDB"
            )
            mongo_client.create_sae(
                name=settings.sae_name,
                series=settings.sae_series,
                path=str(Path(settings.trainer.exp_result_path).absolute()),
                cfg=sae.cfg,
            )

    if wandb_logger is not None:
        wandb_logger.finish()
        logger.info("WandB session closed")

    logger.info("LORSA training completed successfully")
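The save logic above branches on `end_of_stream`: one path writes a trainer checkpoint via `trainer.save_checkpoint`, the other writes final weights via `sae.save_pretrained` and registers them in MongoDB. A minimal stdlib sketch of that branch (helper name is ours, not from the library):

```python
def choose_save_strategy(end_of_stream: bool) -> str:
    """Mirror the save branch at the end of the LORSA runner:
    end_of_stream=True -> trainer.save_checkpoint(...),
    otherwise -> sae.save_pretrained(...) plus optional DB registration."""
    if end_of_stream:
        return "trainer_checkpoint"
    return "pretrained_weights"
```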

TrainMOLTSettings

Bases: BaseSettings

Settings for training a Mixture of Linear Transforms (MOLT). MOLT is a more efficient alternative to transcoders that sparsely replaces MLP computation in transformers.

sae instance-attribute

Configuration for the MOLT model architecture and parameters

sae_name instance-attribute

sae_name: str

Name of the SAE model. Use as identifier for the SAE model in the database.

sae_series instance-attribute

sae_series: str

Series of the SAE model. Use as identifier for the SAE model in the database.

initializer class-attribute instance-attribute

initializer: InitializerConfig | None = None

Configuration for model initialization. Should be None for a pretrained MOLT.

trainer instance-attribute

trainer: TrainerConfig

Configuration for training process

activation_factory instance-attribute

activation_factory: ActivationFactoryConfig

Configuration for generating activations

wandb class-attribute instance-attribute

wandb: Optional[WandbConfig] = None

Configuration for Weights & Biases logging

eval class-attribute instance-attribute

eval: bool = False

Whether to run in evaluation mode

data_parallel_size class-attribute instance-attribute

data_parallel_size: int = 1

Size of data parallel mesh

model_parallel_size class-attribute instance-attribute

model_parallel_size: int = 1

Size of model parallel (tensor parallel) mesh

mongo class-attribute instance-attribute

mongo: Optional[MongoDBConfig] = None

Configuration for MongoDB

model class-attribute instance-attribute

model: Optional[LanguageModelConfig] = None

Configuration for the language model. Required if using dataset sources.

model_name class-attribute instance-attribute

model_name: Optional[str] = None

Name of the tokenizer to load. MOLT requires a tokenizer to get the modality indices.

datasets class-attribute instance-attribute

datasets: Optional[dict[str, Optional[DatasetConfig]]] = (
    None
)

Name to dataset config mapping. Required if using dataset sources.

device_type class-attribute instance-attribute

device_type: str = 'cuda'

Device type to use for distributed training ('cuda' or 'cpu')
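Taken together, a minimal settings payload for `train_molt` might look like the sketch below. A plain dict stands in for the pydantic model; the identifier strings and nested-config placeholders are illustrative assumptions, while the scalar values spell out the documented defaults:

```python
# Sketch of a TrainMOLTSettings payload using the documented fields.
# Nested configs (sae, trainer, activation_factory) are placeholders;
# consult their own config classes for the real schema.
molt_settings_payload = {
    "sae": {"...": "MOLT architecture config or a PretrainedSAE entry"},
    "sae_name": "molt-l6-mlp",    # database identifier (hypothetical)
    "sae_series": "molt-pilot",   # database identifier (hypothetical)
    "initializer": None,          # None when loading a pretrained MOLT
    "trainer": {"...": "TrainerConfig fields"},
    "activation_factory": {"...": "ActivationFactoryConfig fields"},
    # Documented defaults, spelled out explicitly:
    "wandb": None,
    "eval": False,
    "data_parallel_size": 1,
    "model_parallel_size": 1,
    "mongo": None,
    "model": None,
    "model_name": None,
    "datasets": None,
    "device_type": "cuda",
}
```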

train_molt

train_molt(settings: TrainMOLTSettings) -> None

Train a Mixture of Linear Transforms (MOLT) model.

Parameters:

settings (TrainMOLTSettings, required)
    Configuration settings for MOLT training
Source code in src/lm_saes/runners/train.py
def train_molt(settings: TrainMOLTSettings) -> None:
    """Train a Mixture of Linear Transforms (MOLT) model.

    Args:
        settings: Configuration settings for MOLT training
    """
    # Set up logging
    setup_logging(level="INFO")

    device_mesh = (
        init_device_mesh(
            device_type=settings.device_type,
            mesh_shape=(settings.model_parallel_size, settings.data_parallel_size),  # TODO: check the order
            mesh_dim_names=("model", "data"),
        )
        if settings.model_parallel_size > 1 or settings.data_parallel_size > 1
        else None
    )

    logger.info(f"Device mesh initialized: {device_mesh}")

    mongo_client = MongoClient(settings.mongo) if settings.mongo is not None else None
    if mongo_client:
        logger.info("MongoDB client initialized")

    # Load configurations
    model_cfg = load_config(
        config=settings.model,
        name=settings.model_name,
        mongo_client=mongo_client,
        config_type="model",
        required=False,
    )

    dataset_cfgs = (
        {
            dataset_name: load_config(
                config=dataset_cfg,
                name=dataset_name,
                mongo_client=mongo_client,
                config_type="dataset",
            )
            for dataset_name, dataset_cfg in settings.datasets.items()
        }
        if settings.datasets is not None
        else None
    )

    # Load model and datasets
    logger.info("Loading model and datasets")
    model = load_model(model_cfg) if model_cfg is not None else None
    datasets = (
        {
            dataset_name: load_dataset(dataset_cfg, device_mesh=device_mesh)
            for dataset_name, dataset_cfg in dataset_cfgs.items()
        }
        if dataset_cfgs is not None
        else None
    )

    activation_factory = ActivationFactory(settings.activation_factory, device_mesh=device_mesh)

    logger.info("Processing activations stream")
    activations_stream = activation_factory.process(
        model=model,
        model_name=settings.model_name,
        datasets=datasets,
    )

    wandb_logger = (
        wandb.init(
            project=settings.wandb.wandb_project,
            config=settings.model_dump(),
            name=settings.wandb.exp_name,
            entity=settings.wandb.wandb_entity,
            settings=wandb.Settings(x_disable_stats=True),
            mode=os.getenv("WANDB_MODE", "online"),  # type: ignore
            resume=settings.wandb.wandb_resume,
            id=settings.wandb.wandb_run_id,
        )
        if settings.wandb is not None and (device_mesh is None or mesh_rank(device_mesh) == 0)
        else None
    )

    logger.info("Initializing MOLT")

    assert settings.initializer is None or not isinstance(settings.sae, PretrainedSAE), (
        "Cannot use an initializer for a pretrained MOLT"
    )
    if isinstance(settings.sae, PretrainedSAE):
        sae = AbstractSparseAutoEncoder.from_pretrained(
            settings.sae.pretrained_name_or_path,
            device_mesh=device_mesh,
            fold_activation_scale=settings.sae.fold_activation_scale,
            strict_loading=settings.sae.strict_loading,
            device=settings.sae.device,
            dtype=settings.sae.dtype,
        )
    elif settings.initializer is not None:
        initializer = Initializer(settings.initializer)
        sae = initializer.initialize_sae_from_config(
            settings.sae,
            activation_stream=activations_stream,
            device_mesh=device_mesh,
            wandb_logger=wandb_logger,
            model=model,
        )
    else:
        sae = AbstractSparseAutoEncoder.from_config(settings.sae, device_mesh=device_mesh)

    logger.info(f"MOLT initialized: {type(sae).__name__}")

    if wandb_logger is not None:
        logger.info("WandB logger initialized")

    # TODO: implement eval_fn
    eval_fn = (lambda x: None) if settings.eval else None

    logger.info("Starting MOLT training")
    if settings.trainer.from_pretrained_path is not None:
        trainer = Trainer.from_checkpoint(
            sae,
            settings.trainer.from_pretrained_path,
        )
        trainer.wandb_logger = wandb_logger
    else:
        trainer = Trainer(settings.trainer)

    sae.cfg.save_hyperparameters(settings.trainer.exp_result_path)
    end_of_stream = trainer.fit(
        sae=sae, activation_stream=activations_stream, eval_fn=eval_fn, wandb_logger=wandb_logger
    )

    logger.info("Training completed, saving MOLT model")
    if end_of_stream:
        trainer.save_checkpoint(
            sae=sae,
            checkpoint_path=settings.trainer.exp_result_path,
        )
    else:
        sae.save_pretrained(
            save_path=settings.trainer.exp_result_path,
        )
        if is_primary_rank(device_mesh) and mongo_client is not None:
            assert settings.sae_name is not None and settings.sae_series is not None, (
                "sae_name and sae_series must be provided when saving to MongoDB"
            )
            mongo_client.create_sae(
                name=settings.sae_name,
                series=settings.sae_series,
                path=str(Path(settings.trainer.exp_result_path).absolute()),
                cfg=sae.cfg,
            )

    if wandb_logger is not None:
        wandb_logger.finish()
        logger.info("WandB session closed")

    logger.info("MOLT training completed successfully")
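The conditional mesh construction at the top of `train_molt` only builds a 2-D device mesh when either parallel size exceeds one; otherwise it runs without a mesh. A pure-Python sketch of that shape selection (helper name is ours):

```python
def molt_mesh_spec(model_parallel_size: int, data_parallel_size: int):
    """Return the (mesh_shape, mesh_dim_names) that train_molt would
    pass to init_device_mesh, or None for a single-device run."""
    if model_parallel_size > 1 or data_parallel_size > 1:
        return (model_parallel_size, data_parallel_size), ("model", "data")
    return None
```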

EvaluateSAESettings

Bases: BaseSettings

Settings for evaluating a Sparse Autoencoder.

sae instance-attribute

Path to a pretrained SAE model

sae_name instance-attribute

sae_name: str

Name of the SAE model. Use as identifier for the SAE model in the database.

sae_series instance-attribute

sae_series: str

Series of the SAE model. Use as identifier for the SAE model in the database.

activation_factory instance-attribute

activation_factory: ActivationFactoryConfig

Configuration for generating activations

model class-attribute instance-attribute

model: Optional[LanguageModelConfig] = None

Configuration for the language model. Required if using dataset sources.

eval instance-attribute

eval: EvalConfig

Configuration for evaluation

model_parallel_size class-attribute instance-attribute

model_parallel_size: int = 1

Size of model parallel (tensor parallel) mesh

fold_activation_scale class-attribute instance-attribute

fold_activation_scale: bool = False

Whether to fold the activation scale.

wandb class-attribute instance-attribute

wandb: Optional[WandbConfig] = None

Configuration for Weights & Biases logging

device_type class-attribute instance-attribute

device_type: str = 'cuda'

Device type to use for distributed training ('cuda' or 'cpu')

evaluate_sae

evaluate_sae(settings: EvaluateSAESettings) -> None

Evaluate a SAE model.

Parameters:

settings (EvaluateSAESettings, required)
    Configuration settings for SAE evaluation
Source code in src/lm_saes/runners/eval.py
def evaluate_sae(settings: EvaluateSAESettings) -> None:
    """Evaluate a SAE model.

    Args:
        settings: Configuration settings for SAE evaluation
    """
    # Set up logging
    setup_logging(level="INFO")

    device_mesh = (
        init_device_mesh(
            device_type=settings.device_type,
            mesh_shape=(settings.model_parallel_size,),
            mesh_dim_names=("model",),
        )
        if settings.model_parallel_size > 1
        else None
    )

    logger.info(f"Device mesh initialized: {device_mesh}")

    activation_factory = ActivationFactory(settings.activation_factory)

    logger.info("Loading SAE model")

    sae = AbstractSparseAutoEncoder.from_pretrained(
        settings.sae.pretrained_name_or_path,
        device_mesh=device_mesh,
        fold_activation_scale=settings.fold_activation_scale,
        device=settings.sae.device,
        dtype=settings.sae.dtype,
        strict_loading=settings.sae.strict_loading,
    )

    logger.info(f"SAE model loaded: {type(sae).__name__}")

    wandb_logger = (
        wandb.init(
            project=settings.wandb.wandb_project,
            config=settings.model_dump(),
            name=settings.wandb.exp_name,
            entity=settings.wandb.wandb_entity,
            settings=wandb.Settings(x_disable_stats=True),
            mode=os.getenv("WANDB_MODE", "online"),  # type: ignore
        )
        if settings.wandb is not None and (device_mesh is None or mesh_rank(device_mesh) == 0)
        else None
    )

    if wandb_logger is not None:
        logger.info("WandB logger initialized")

    logger.info("Processing activations for evaluation")
    activations = activation_factory.process()
    evaluator = Evaluator(settings.eval)
    evaluator.evaluate(sae, activations, wandb_logger)
    logger.info("Evaluation completed")

EvaluateCrossCoderSettings

Bases: BaseSettings

Settings for evaluating a CrossCoder model.

sae instance-attribute

Path to a pretrained CrossCoder model

sae_name instance-attribute

sae_name: str

Name of the SAE model. Use as identifier for the SAE model in the database.

sae_series instance-attribute

sae_series: str

Series of the SAE model. Use as identifier for the SAE model in the database.

activation_factories instance-attribute

activation_factories: list[ActivationFactoryConfig]

Configuration for generating activations

eval instance-attribute

eval: EvalConfig

Configuration for evaluation

wandb class-attribute instance-attribute

wandb: Optional[WandbConfig] = None

Configuration for Weights & Biases logging

device_type class-attribute instance-attribute

device_type: str = 'cuda'

Device type to use for distributed training ('cuda' or 'cpu')

evaluate_crosscoder

evaluate_crosscoder(
    settings: EvaluateCrossCoderSettings,
) -> None

Evaluate a CrossCoder model. The key difference from evaluate_sae is that the activation factories are a list of ActivationFactoryConfig, one per head, and the evaluation involves a device mesh transformation from head parallelism to model (feature) parallelism.

Parameters:

settings (EvaluateCrossCoderSettings, required)
    Configuration settings for CrossCoder evaluation
Source code in src/lm_saes/runners/eval.py
@torch.no_grad()
def evaluate_crosscoder(settings: EvaluateCrossCoderSettings) -> None:
    """Evaluate a CrossCoder model. The key difference from evaluate_sae is that the activation factories are a list of ActivationFactoryConfig, one per head, and the evaluation involves a device mesh transformation from head parallelism to model (feature) parallelism.

    Args:
        settings: Configuration settings for CrossCoder evaluation
    """
    # Set up logging
    setup_logging(level="INFO")

    parallel_size = len(settings.activation_factories)

    logger.info(f"Evaluating CrossCoder with parallel size {parallel_size}")

    device_mesh = init_device_mesh(
        device_type=settings.device_type,
        mesh_shape=(parallel_size,),
        mesh_dim_names=("head",),
    )

    logger.info("Device meshes initialized for CrossCoder analysis")

    logger.info("Setting up activation factory for CrossCoder head")
    activation_factory = ActivationFactory(settings.activation_factories[device_mesh.get_local_rank("head")])

    logger.info("Loading CrossCoder model")
    sae = CrossCoder.from_pretrained(
        settings.sae.pretrained_name_or_path,
        device_mesh=device_mesh,
        device=settings.sae.device,
        dtype=settings.sae.dtype,
        fold_activation_scale=settings.sae.fold_activation_scale,
        strict_loading=settings.sae.strict_loading,
    )

    assert len(settings.activation_factories) * len(settings.activation_factories[0].hook_points) == sae.cfg.n_heads, (
        "Total number of hook points must match the number of heads in the CrossCoder"
    )

    wandb_logger = (
        wandb.init(
            project=settings.wandb.wandb_project,
            config=settings.model_dump(),
            name=settings.wandb.exp_name,
            entity=settings.wandb.wandb_entity,
            settings=wandb.Settings(x_disable_stats=True),
            mode=os.getenv("WANDB_MODE", "online"),  # type: ignore
        )
        if settings.wandb is not None and (device_mesh is None or mesh_rank(device_mesh) == 0)
        else None
    )

    if wandb_logger is not None:
        logger.info("WandB logger initialized")

    logger.info("Processing activations for CrossCoder evaluation")
    activations = activation_factory.process()
    evaluator = Evaluator(settings.eval)
    evaluator.evaluate(sae, activations, wandb_logger)

    logger.info("CrossCoder evaluation completed successfully")
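The assertion in `evaluate_crosscoder` (and its analysis counterpart) encodes a simple invariant: each activation factory serves one "head" mesh rank and contributes the same number of hook points, so the product must equal `cfg.n_heads`. A stdlib sketch of that check (function name is ours):

```python
def hook_points_match_heads(factory_hook_points: list[list[str]], n_heads: int) -> bool:
    """Mirror the runner assertion: the factories (one per 'head' mesh rank)
    must jointly cover exactly n_heads hook points."""
    return len(factory_hook_points) * len(factory_hook_points[0]) == n_heads
```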

AnalyzeSAESettings

Bases: BaseSettings

Settings for analyzing a Sparse Autoencoder.

sae instance-attribute

Path to a pretrained SAE model

sae_name instance-attribute

sae_name: str

Name of the SAE model. Use as identifier for the SAE model in the database.

sae_series instance-attribute

sae_series: str

Series of the SAE model. Use as identifier for the SAE model in the database.

activation_factory instance-attribute

activation_factory: ActivationFactoryConfig

Configuration for generating activations

model class-attribute instance-attribute

model: Optional[LanguageModelConfig] = None

Configuration for the language model. Required if using dataset sources.

model_name class-attribute instance-attribute

model_name: Optional[str] = None

Name of the tokenizer to load. LORSA may require a tokenizer to get the modality indices.

datasets class-attribute instance-attribute

datasets: Optional[dict[str, Optional[DatasetConfig]]] = (
    None
)

Name to dataset config mapping. Required if using dataset sources.

analyzer instance-attribute

Configuration for feature analysis

feature_analysis_name class-attribute instance-attribute

feature_analysis_name: str = 'default'

Name of the feature analysis.

mongo class-attribute instance-attribute

mongo: MongoDBConfig | None = None

Configuration for the MongoDB database.

output_dir class-attribute instance-attribute

output_dir: Optional[Path] = None

Directory to save analysis results. Only used if MongoDB client is not provided.

model_parallel_size class-attribute instance-attribute

model_parallel_size: int = 1

Size of model parallel (tensor parallel) mesh

device_type class-attribute instance-attribute

device_type: str = 'cuda'

Device type to use for distributed training ('cuda' or 'cpu')

analyze_sae

analyze_sae(settings: AnalyzeSAESettings) -> None

Analyze a SAE model.

Parameters:

settings (AnalyzeSAESettings, required)
    Configuration settings for SAE analysis
Source code in src/lm_saes/runners/analyze.py
@torch.no_grad()
def analyze_sae(settings: AnalyzeSAESettings) -> None:
    """Analyze a SAE model.

    Args:
        settings: Configuration settings for SAE analysis
    """
    # Set up logging
    setup_logging(level="INFO")

    device_mesh = (
        init_device_mesh(
            device_type=settings.device_type,
            mesh_shape=(settings.model_parallel_size,),
            mesh_dim_names=("model",),
        )
        if settings.model_parallel_size > 1
        else None
    )

    logger.info(f"Device mesh initialized: {device_mesh}")

    mongo_client = None
    if settings.mongo is not None:
        mongo_client = MongoClient(settings.mongo)
        logger.info("MongoDB client initialized")
    else:
        assert settings.output_dir is not None, "Output directory must be provided if MongoDB client is not provided"
        logger.info(f"Analysis results will be saved to {settings.output_dir}")

    # Load configurations
    model_cfg = load_config(
        config=settings.model,
        name=settings.model_name,
        mongo_client=mongo_client,
        config_type="model",
        required=False,
    )

    dataset_cfgs = (
        {
            dataset_name: load_config(
                config=dataset_cfg,
                name=dataset_name,
                mongo_client=mongo_client,
                config_type="dataset",
            )
            for dataset_name, dataset_cfg in settings.datasets.items()
        }
        if settings.datasets is not None
        else None
    )

    model = load_model(model_cfg) if model_cfg is not None else None
    datasets = (
        {
            dataset_name: load_dataset(dataset_cfg, device_mesh=device_mesh)
            for dataset_name, dataset_cfg in dataset_cfgs.items()
        }
        if dataset_cfgs is not None
        else None
    )

    activation_factory = ActivationFactory(settings.activation_factory, device_mesh=device_mesh)

    sae = AbstractSparseAutoEncoder.from_pretrained(
        settings.sae.pretrained_name_or_path,
        device_mesh=device_mesh,
        device=settings.sae.device,
        dtype=settings.sae.dtype,
        fold_activation_scale=settings.sae.fold_activation_scale,
        strict_loading=settings.sae.strict_loading,
    )

    logger.info(f"SAE model loaded: {type(sae).__name__}")

    analyzer = FeatureAnalyzer(settings.analyzer)
    logger.info("Feature analyzer initialized")

    logger.info("Processing activations for analysis")

    with torch.amp.autocast(device_type=settings.device_type, dtype=settings.amp_dtype):
        result = analyzer.analyze_chunk(
            activation_factory,
            sae=sae,
            device_mesh=device_mesh,
            activation_factory_process_kwargs={
                "model": model,
                "model_name": settings.model_name,
                "datasets": datasets,
            },
        )

    logger.info("Analysis completed, saving results")
    start_idx = 0 if device_mesh is None else device_mesh.get_local_rank("model") * len(result)
    if mongo_client is not None:
        logger.info("Saving results to MongoDB")
        mongo_client.add_feature_analysis(
            name=settings.feature_analysis_name,
            sae_name=settings.sae_name,
            sae_series=settings.sae_series,
            analysis=result,
            start_idx=start_idx,
        )
        logger.info("Results saved to MongoDB")
    else:
        assert settings.output_dir is not None, "Output directory must be set when MongoDB is not used"
        logger.info(f"Saving results to output directory: {settings.output_dir}")
        pickle_path = save_analysis_to_file(
            output_dir=settings.output_dir,
            analysis_name=settings.feature_analysis_name,
            sae_name=settings.sae_name,
            sae_series=settings.sae_series,
            analysis=result,
            start_idx=start_idx,
            device_mesh=device_mesh,
        )
        logger.info(f"Results saved to: {pickle_path}")

    logger.info("SAE analysis completed successfully")
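Because each model-parallel rank analyzes an equal-sized contiguous chunk of features, the global index of a rank's first feature is `rank * len(result)`, as in the `start_idx` line of `analyze_sae`. A stdlib sketch of that mapping (helper name is ours):

```python
def global_feature_indices(rank: int, chunk: list) -> list[int]:
    """Global feature ids covered by one rank's analysis chunk,
    matching start_idx = rank * len(result) in analyze_sae."""
    start_idx = rank * len(chunk)
    return list(range(start_idx, start_idx + len(chunk)))
```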

AnalyzeCrossCoderSettings

Bases: BaseSettings

Settings for analyzing a CrossCoder model.

sae instance-attribute

Path to a pretrained CrossCoder model

sae_name instance-attribute

sae_name: str

Name of the SAE model. Use as identifier for the SAE model in the database.

sae_series instance-attribute

sae_series: str

Series of the SAE model. Use as identifier for the SAE model in the database.

activation_factories instance-attribute

activation_factories: list[ActivationFactoryConfig]

Configuration for generating activations

analyzer instance-attribute

Configuration for feature analysis

amp_dtype class-attribute instance-attribute

amp_dtype: dtype = bfloat16

The dtype to use for automatic mixed precision (autocast) during analysis.

feature_analysis_name class-attribute instance-attribute

feature_analysis_name: str = 'default'

Name of the feature analysis.

mongo class-attribute instance-attribute

mongo: MongoDBConfig | None = None

Configuration for the MongoDB database.

output_dir class-attribute instance-attribute

output_dir: Optional[Path] = None

Directory to save analysis results. Only used if MongoDB client is not provided.

device_type class-attribute instance-attribute

device_type: str = 'cuda'

Device type to use for distributed training ('cuda' or 'cpu')

analyze_crosscoder

analyze_crosscoder(
    settings: AnalyzeCrossCoderSettings,
) -> None

Analyze a CrossCoder model. The key difference from analyze_sae is that the activation factories are a list of ActivationFactoryConfig, one per head, and the analysis involves a device mesh transformation from head parallelism to model (feature) parallelism.

Parameters:

settings (AnalyzeCrossCoderSettings, required)
    Configuration settings for CrossCoder analysis
Source code in src/lm_saes/runners/analyze.py
@torch.no_grad()
def analyze_crosscoder(settings: AnalyzeCrossCoderSettings) -> None:
    """Analyze a CrossCoder model. The key difference from analyze_sae is that the activation factories are a list of ActivationFactoryConfig, one per head, and the analysis involves a device mesh transformation from head parallelism to model (feature) parallelism.

    Args:
        settings: Configuration settings for CrossCoder analysis
    """
    # Set up logging
    setup_logging(level="INFO")

    parallel_size = len(settings.activation_factories)

    logger.info(f"Analyzing CrossCoder with parallel size {parallel_size}")

    crosscoder_device_mesh = init_device_mesh(
        device_type=settings.device_type,
        mesh_shape=(parallel_size,),
        mesh_dim_names=("head",),
    )

    device_mesh = init_device_mesh(
        device_type=settings.device_type,
        mesh_shape=(parallel_size,),
        mesh_dim_names=("model",),
    )

    logger.info("Device meshes initialized for CrossCoder analysis")

    mongo_client = None
    if settings.mongo is not None:
        mongo_client = MongoClient(settings.mongo)
        logger.info("MongoDB client initialized")
    else:
        assert settings.output_dir is not None, "Output directory must be provided if MongoDB client is not provided"
        logger.info(f"Analysis results will be saved to: {settings.output_dir}")

    logger.info("Setting up activation factory for CrossCoder head")
    activation_factory = ActivationFactory(settings.activation_factories[crosscoder_device_mesh.get_local_rank("head")])

    logger.info("Loading CrossCoder model")
    sae = CrossCoder.from_pretrained(
        settings.sae.pretrained_name_or_path,
        device_mesh=crosscoder_device_mesh,
        device=settings.sae.device,
        dtype=settings.sae.dtype,
        fold_activation_scale=settings.sae.fold_activation_scale,
        strict_loading=settings.sae.strict_loading,
    )

    assert len(settings.activation_factories) * len(settings.activation_factories[0].hook_points) == sae.cfg.n_heads, (
        "Total number of hook points must match the number of heads in the CrossCoder"
    )

    logger.info("Feature analyzer initialized")
    analyzer = FeatureAnalyzer(settings.analyzer)

    logger.info("Processing activations for CrossCoder analysis")

    with torch.amp.autocast(device_type=settings.device_type, dtype=settings.amp_dtype):
        result = analyzer.analyze_chunk(
            activation_factory,
            sae=sae,
            device_mesh=device_mesh,
        )

    logger.info("CrossCoder analysis completed, saving results")
    start_idx = 0 if device_mesh is None else device_mesh.get_local_rank("model") * len(result)
    if mongo_client is not None:
        mongo_client.add_feature_analysis(
            name=settings.feature_analysis_name,
            sae_name=settings.sae_name,
            sae_series=settings.sae_series,
            analysis=result,
            start_idx=start_idx,
        )
    else:
        assert settings.output_dir is not None, "Output directory must be set when MongoDB is not used"
        logger.info(f"Saving results to output directory: {settings.output_dir}")
        pickle_path = save_analysis_to_file(
            output_dir=settings.output_dir,
            analysis_name=settings.feature_analysis_name,
            sae_name=settings.sae_name,
            sae_series=settings.sae_series,
            analysis=result,
            start_idx=start_idx,
            device_mesh=device_mesh,
        )
        logger.info(f"Results saved to: {pickle_path}")

    logger.info("CrossCoder analysis completed successfully")

GenerateActivationsSettings

Bases: BaseSettings

Settings for activation generation.

model class-attribute instance-attribute

model: Optional[LanguageModelConfig] = None

Configuration for loading the language model. If None, will read from the database.

model_name instance-attribute

model_name: str

Name of the model to load. Use as identifier for the model in the database.

dataset class-attribute instance-attribute

dataset: Optional[DatasetConfig] = None

Configuration for loading the dataset. If None, will read from the database.

dataset_name instance-attribute

dataset_name: str

Name of the dataset. Use as identifier for the dataset in the database.

hook_points instance-attribute

hook_points: list[str]

List of model hook points to capture activations from

output_dir instance-attribute

output_dir: str

Directory to save activation files

target class-attribute instance-attribute

Target type for activation generation

model_batch_size class-attribute instance-attribute

model_batch_size: int = 1

Batch size for model forward

batch_size instance-attribute

batch_size: int

Size of the batch for activation generation

buffer_size class-attribute instance-attribute

buffer_size: Optional[int] = None

Size of the buffer for activation generation

buffer_shuffle class-attribute instance-attribute

buffer_shuffle: Optional[BufferShuffleConfig] = None

Manual seed and device of the random generator used to produce permutations in the buffer

total_tokens class-attribute instance-attribute

total_tokens: Optional[int] = None

Optional total number of tokens to generate

context_size class-attribute instance-attribute

context_size: int = 128

Context window size for tokenization

n_samples_per_chunk class-attribute instance-attribute

n_samples_per_chunk: Optional[int] = None

Number of samples per saved chunk

num_workers class-attribute instance-attribute

num_workers: Optional[int] = None

Number of workers for parallel writing

format class-attribute instance-attribute

format: Literal['pt', 'safetensors'] = 'safetensors'

Format to save activations in ('pt' or 'safetensors')

n_shards class-attribute instance-attribute

n_shards: Optional[int] = None

Number of shards to split the dataset into. If None, the dataset is split into as many shards as the world size. If provided, must be larger than the world size.

start_shard class-attribute instance-attribute

start_shard: int = 0

The shard to start writing from

mongo class-attribute instance-attribute

mongo: Optional[MongoDBConfig] = None

Configuration for the MongoDB database. If None, will not use the database.

ignore_token_ids class-attribute instance-attribute

ignore_token_ids: Optional[list[int]] = None

Tokens to ignore in the activations.

device_type class-attribute instance-attribute

device_type: str = 'cuda'

Device type to use for distributed training ('cuda' or 'cpu')

override_dtype class-attribute instance-attribute

override_dtype: dtype | None = None

Dtype to override the activations to. If None, will not override the dtype.
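A minimal payload for these settings might look like the following sketch. A plain dict stands in for the pydantic model; the model, dataset, and hook-point names are illustrative assumptions, while the remaining values spell out the documented defaults:

```python
# Sketch of a minimal GenerateActivationsSettings payload.
# Identifier strings are hypothetical; only the required fields
# plus a few documented defaults are shown.
gen_payload = {
    "model_name": "gpt2",                      # model identifier (hypothetical)
    "dataset_name": "openwebtext",             # dataset identifier (hypothetical)
    "hook_points": ["blocks.6.hook_mlp_out"],  # hook point (hypothetical)
    "output_dir": "./activations",
    "batch_size": 4096,
    # Documented defaults, spelled out explicitly:
    "model_batch_size": 1,
    "context_size": 128,
    "format": "safetensors",
    "start_shard": 0,
    "device_type": "cuda",
}
```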

model_post_init

model_post_init(__context: dict) -> None

Validate configuration after initialization.

Source code in src/lm_saes/runners/generate.py
def model_post_init(self, __context: dict) -> None:
    """Validate configuration after initialization."""
    if self.mongo is None:
        assert self.model is not None, "Database not provided. Must manually provide model config."
        assert self.dataset is not None, "Database not provided. Must manually provide dataset config."
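Per the assertion messages, the intent of `model_post_init` is that when no MongoDB config is supplied, the configs cannot be looked up by name and must therefore be given inline. A stdlib sketch of that rule (function name is ours):

```python
def validate_generation_settings(mongo, model, dataset) -> None:
    """When no database is configured, model and dataset configs
    cannot be looked up by name and must be provided inline."""
    if mongo is None:
        if model is None:
            raise ValueError("Database not provided. Must manually provide model config.")
        if dataset is None:
            raise ValueError("Database not provided. Must manually provide dataset config.")
```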

generate_activations

generate_activations(
    settings: GenerateActivationsSettings,
) -> None

Generate and save model activations from a dataset.

Parameters:

settings (GenerateActivationsSettings, required)
    Configuration settings for activation generation
Source code in src/lm_saes/runners/generate.py
def generate_activations(settings: GenerateActivationsSettings) -> None:
    """Generate and save model activations from a dataset.

    Args:
        settings: Configuration settings for activation generation
    """
    # Set up logging
    setup_logging(level="INFO")

    # Initialize device mesh
    device_mesh = (
        init_device_mesh(
            device_type=settings.device_type,
            mesh_shape=(int(os.environ.get("WORLD_SIZE", 1)), 1),
            mesh_dim_names=("data", "model"),
        )
        if os.environ.get("WORLD_SIZE") is not None
        else None
    )

    logger.info(f"Device mesh initialized: {device_mesh}")

    mongo_client = MongoClient(settings.mongo) if settings.mongo is not None else None
    if mongo_client:
        logger.info("MongoDB client initialized")

    # Load configurations
    logger.info("Loading model and dataset configurations")
    model_cfg = load_config(
        config=settings.model, name=settings.model_name, mongo_client=mongo_client, config_type="model"
    )

    dataset_cfg = load_config(
        config=settings.dataset, name=settings.dataset_name, mongo_client=mongo_client, config_type="dataset"
    )

    # Load model and dataset
    logger.info("Loading model and dataset")
    model = load_model(model_cfg)
    dataset, metadata = load_dataset(
        dataset_cfg,
        device_mesh=device_mesh,
        n_shards=settings.n_shards,
        start_shard=settings.start_shard,
    )

    logger.info(f"Model loaded: {settings.model_name}")
    logger.info(f"Dataset loaded: {settings.dataset_name}")

    # Configure activation generation
    logger.info("Configuring activation factory")
    factory_cfg = ActivationFactoryConfig(
        sources=[ActivationFactoryDatasetSource(name=settings.dataset_name)],
        target=settings.target,
        hook_points=settings.hook_points,
        context_size=settings.context_size,
        model_batch_size=settings.model_batch_size,
        batch_size=settings.batch_size,
        buffer_size=settings.buffer_size,
        buffer_shuffle=settings.buffer_shuffle,
        ignore_token_ids=settings.ignore_token_ids,
        override_dtype=settings.override_dtype,
    )

    # Configure activation writer
    logger.info("Configuring activation writer")
    writer_cfg = ActivationWriterConfig(
        hook_points=settings.hook_points,
        total_generating_tokens=settings.total_tokens,
        n_samples_per_chunk=settings.n_samples_per_chunk,
        cache_dir=settings.output_dir,
        format=settings.format,
        num_workers=settings.num_workers,
    )

    # Create factory and writer
    factory = ActivationFactory(factory_cfg)
    writer = ActivationWriter(writer_cfg)

    logger.info("Starting activation generation and writing")
    # Generate and write activations
    activations = factory.process(
        model=model, model_name=settings.model_name, datasets={settings.dataset_name: (dataset, metadata)}
    )
    writer.process(activations, device_mesh=device_mesh, start_shard=settings.start_shard)

    logger.info("Activation generation completed successfully")
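
The mesh setup above only builds a device mesh when `WORLD_SIZE` is set, with shape `(WORLD_SIZE, 1)` over the `("data", "model")` dimensions. That shape rule can be isolated as a small helper (hypothetical, for illustration only):

```python
import os


def activation_mesh_shape(env=None):
    """Return the ("data", "model") mesh shape generate_activations would use,
    or None for single-process runs where WORLD_SIZE is unset.

    Hypothetical helper mirroring the logic in the runner above.
    """
    env = os.environ if env is None else env
    world_size = env.get("WORLD_SIZE")
    if world_size is None:
        return None
    # All ranks go into the "data" dimension; generation is not model-parallel.
    return (int(world_size), 1)
```

For example, `activation_mesh_shape({"WORLD_SIZE": "4"})` yields `(4, 1)`, and an empty environment yields `None`.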

AutoInterpSettings

Bases: BaseSettings

Settings for automatic interpretation of SAE features.

sae_name instance-attribute

sae_name: str

Name of the SAE model to interpret. Use as identifier for the SAE model in the database.

sae_series instance-attribute

sae_series: str

Series of the SAE model to interpret. Use as identifier for the SAE model in the database.

model instance-attribute

Configuration for the language model used to generate activations.

model_name instance-attribute

model_name: str

Name of the model to load.

auto_interp instance-attribute

auto_interp: AutoInterpConfig

Configuration for the auto-interpretation process.

mongo instance-attribute

Configuration for the MongoDB database.

features class-attribute instance-attribute

features: Optional[list[int]] = None

List of specific feature indices to interpret. If None, will interpret all features.

analysis_name class-attribute instance-attribute

analysis_name: str = 'default'

Name of the analysis to use for interpretation.

max_workers class-attribute instance-attribute

max_workers: int = 10

Maximum number of workers to use for interpretation.

auto_interp

auto_interp(settings: AutoInterpSettings)

Synchronous wrapper for interpret_feature.

Parameters:

Name Type Description Default
settings AutoInterpSettings

Configuration for feature interpretation

required
Source code in src/lm_saes/runners/autointerp.py
def auto_interp(settings: AutoInterpSettings):
    """Synchronous wrapper for interpret_feature.

    Args:
        settings: Configuration for feature interpretation
    """
    asyncio.run(interpret_feature(settings))
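
`max_workers` bounds how many features are interpreted concurrently. One common way to enforce such a bound inside an asyncio event loop is a semaphore, sketched here with a hypothetical `interpret_one` coroutine (this is not the library's actual implementation):

```python
import asyncio


async def interpret_features(feature_indices, interpret_one, max_workers=10):
    """Interpret features concurrently, with at most max_workers in flight.

    Sketch of a bounded-concurrency pattern; interpret_one is a hypothetical
    coroutine taking a feature index.
    """
    sem = asyncio.Semaphore(max_workers)

    async def run(idx):
        async with sem:
            return await interpret_one(idx)

    # gather preserves input order regardless of completion order.
    return await asyncio.gather(*(run(i) for i in feature_indices))
```

A call like `asyncio.run(interpret_features(range(100), interpret_one, max_workers=10))` keeps at most ten interpretation requests outstanding at once.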

SweepSAESettings

Bases: BaseSettings

Settings for sweeping a Sparse Autoencoder (SAE).

items instance-attribute

items: list[SweepingItem]

List of sweeping items

activation_factory instance-attribute

activation_factory: ActivationFactoryConfig

Configuration for generating activations

eval class-attribute instance-attribute

eval: bool = False

Whether to run in evaluation mode

data_parallel_size class-attribute instance-attribute

data_parallel_size: int = 1

Size of data parallel mesh

model_parallel_size class-attribute instance-attribute

model_parallel_size: int = 1

Size of model parallel (tensor parallel) mesh

mongo class-attribute instance-attribute

mongo: Optional[MongoDBConfig] = None

Configuration for MongoDB

model class-attribute instance-attribute

model: Optional[LanguageModelConfig] = None

Configuration for the language model. Required if using dataset sources.

model_name class-attribute instance-attribute

model_name: Optional[str] = None

Name of the tokenizer to load. Mixcoder requires a tokenizer to get the modality indices.

datasets class-attribute instance-attribute

datasets: Optional[dict[str, Optional[DatasetConfig]]] = (
    None
)

Name to dataset config mapping. Required if using dataset sources.

device_type class-attribute instance-attribute

device_type: str = 'cuda'

Device type to use for distributed training ('cuda' or 'cpu')

SweepingItem pydantic-model

Bases: BaseModel

A single item in a sweeping configuration.

Fields:

sae pydantic-field

Configuration for the SAE model architecture and parameters, or the path to a pretrained SAE.

sae_name pydantic-field

sae_name: str

Name of the SAE model. Use as identifier for the SAE model in the database.

sae_series pydantic-field

sae_series: str

Series of the SAE model. Use as identifier for the SAE model in the database.

initializer pydantic-field

initializer: InitializerConfig | None = None

Configuration for model initialization. Should be None for a pretrained SAE.

trainer pydantic-field

trainer: TrainerConfig

Configuration for training process

wandb pydantic-field

wandb: Optional[WandbConfig] = None

Configuration for Weights & Biases logging

sweep_sae

sweep_sae(settings: SweepSAESettings) -> None

Run sweep experiments for training SAE models.

Parameters:

Name Type Description Default
settings SweepSAESettings

Configuration settings for SAE sweeping

required
Source code in src/lm_saes/runners/train.py
def sweep_sae(settings: SweepSAESettings) -> None:
    """Run sweep experiments for training SAE models.

    Args:
        settings: Configuration settings for SAE sweeping
    """
    # Set up logging
    setup_logging(level="INFO")

    n_sweeps = len(settings.items)

    device_mesh = init_device_mesh(
        device_type=settings.device_type,
        mesh_shape=(n_sweeps, settings.data_parallel_size, settings.model_parallel_size),
        mesh_dim_names=("sweep", "data", "model"),
    )

    logger.info(f"Device mesh initialized for sweep with {n_sweeps} configurations")

    mongo_client = MongoClient(settings.mongo) if settings.mongo is not None else None

    logger.info("Loading configurations on rank 0")
    # Load configurations
    model_cfg = load_config(
        config=settings.model,
        name=settings.model_name,
        mongo_client=mongo_client,
        config_type="model",
        required=False,
    )

    dataset_cfgs = (
        {
            dataset_name: load_config(
                config=dataset_cfg,
                name=dataset_name,
                mongo_client=mongo_client,
                config_type="dataset",
            )
            for dataset_name, dataset_cfg in settings.datasets.items()
        }
        if settings.datasets is not None
        else None
    )

    # Load model and datasets
    model = load_model(model_cfg) if model_cfg is not None else None
    datasets = (
        {
            dataset_name: load_dataset(dataset_cfg, device_mesh=device_mesh)
            for dataset_name, dataset_cfg in dataset_cfgs.items()
        }
        if dataset_cfgs is not None
        else None
    )

    activation_factory = ActivationFactory(settings.activation_factory, device_mesh=device_mesh)

    logger.info("Processing activations stream")
    activations_stream = activation_factory.process(
        model=model,
        model_name=settings.model_name,
        datasets=datasets,
    )

    sae_device_mesh = device_mesh["data", "model"]
    logger.info(f"Created 2D sub-mesh for SAE: {sae_device_mesh}")

    item = settings.items[device_mesh.get_local_rank("sweep")]
    logger.info(f"Processing sweep item: {item.sae_name}/{item.sae_series}")

    def convert_activations_to_2d_mesh(stream_3d, mesh_2d):
        from torch.distributed.tensor import DTensor

        from lm_saes.utils.distributed import DimMap

        for batch in stream_3d:
            converted_batch = {}
            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    assert isinstance(value, DTensor), "value must be a DTensor"
                    converted_batch[key] = DTensor.from_local(
                        value.to_local(),
                        device_mesh=mesh_2d,
                        placements=DimMap({"data": 0}).placements(mesh_2d),
                    )
                else:
                    converted_batch[key] = value
            yield converted_batch

    activations_stream = convert_activations_to_2d_mesh(activations_stream, sae_device_mesh)

    logger.info("Initializing SAE on 2D sub-mesh")

    assert not (isinstance(item.sae, PretrainedSAE) and item.initializer is not None), (
        "Cannot use an initializer for a pretrained SAE"
    )
    if isinstance(item.sae, PretrainedSAE):
        sae = AbstractSparseAutoEncoder.from_pretrained(
            item.sae.pretrained_name_or_path,
            device_mesh=sae_device_mesh,
            fold_activation_scale=item.sae.fold_activation_scale,
            strict_loading=item.sae.strict_loading,
            device=item.sae.device,
            dtype=item.sae.dtype,
        )
    elif item.initializer is not None:
        initializer = Initializer(item.initializer)
        sae = initializer.initialize_sae_from_config(
            item.sae,
            activation_stream=activations_stream,
            device_mesh=sae_device_mesh,
            model=model,
        )
    else:
        sae = AbstractSparseAutoEncoder.from_config(item.sae, device_mesh=sae_device_mesh)

    wandb_logger = (
        wandb.init(
            project=item.wandb.wandb_project,
            config=item.model_dump(),
            name=item.wandb.exp_name,
            entity=item.wandb.wandb_entity,
            settings=wandb.Settings(x_disable_stats=True),
            mode=os.getenv("WANDB_MODE", "online"),  # type: ignore
        )
        if item.wandb is not None and is_primary_rank(device_mesh)
        else None
    )
    # TODO: implement eval_fn
    eval_fn = (lambda x: None) if settings.eval else None

    logger.info("Starting training for sweep item")
    trainer = Trainer(item.trainer)
    sae.cfg.save_hyperparameters(item.trainer.exp_result_path)
    trainer.fit(sae=sae, activation_stream=activations_stream, eval_fn=eval_fn, wandb_logger=wandb_logger)

    logger.info("Training completed, saving sweep item")
    sae.save_pretrained(
        save_path=item.trainer.exp_result_path,
    )
    if is_primary_rank(device_mesh) and mongo_client is not None:
        assert item.sae_name is not None and item.sae_series is not None, (
            "sae_name and sae_series must be provided when saving to MongoDB"
        )
        mongo_client.create_sae(
            name=item.sae_name,
            series=item.sae_series,
            path=str(Path(item.trainer.exp_result_path).absolute()),
            cfg=sae.cfg,
        )

    if wandb_logger is not None:
        wandb_logger.finish()
        logger.info("WandB session closed for sweep item")

    logger.info(f"Sweep item completed: {item.sae_name}/{item.sae_series}")
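
Each rank selects its sweep item via `device_mesh.get_local_rank("sweep")`. For a mesh of shape `(n_sweeps, data_parallel_size, model_parallel_size)`, `init_device_mesh` lays ranks out row-major, so a flat rank's coordinates can be recovered as follows (a sketch of that layout, not a torch API):

```python
def sweep_coordinates(rank, n_sweeps, dp_size, mp_size):
    """Map a flat rank to ("sweep", "data", "model") coordinates for a
    row-major mesh of shape (n_sweeps, dp_size, mp_size).

    Hypothetical helper illustrating how ranks partition across sweep items.
    """
    assert 0 <= rank < n_sweeps * dp_size * mp_size
    # The slowest-varying dimension picks the sweep item; each sweep item
    # therefore gets a contiguous block of dp_size * mp_size ranks.
    sweep, rest = divmod(rank, dp_size * mp_size)
    data, model = divmod(rest, mp_size)
    return sweep, data, model
```

With two sweep items and a 2x2 data/model mesh, ranks 0-3 train the first item and ranks 4-7 the second.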

DirectLogitAttributeSettings

Bases: BaseSettings

Settings for running direct logit attribution on an SAE model.

sae instance-attribute

Configuration for the SAE model architecture and parameters

sae_name instance-attribute

sae_name: str

Name of the SAE model. Use as identifier for the SAE model in the database.

layer_idx class-attribute instance-attribute

layer_idx: int | None = None

Index of the layer on which to run direct logit attribution (DLA).

sae_series instance-attribute

sae_series: str

Series of the SAE model. Use as identifier for the SAE model in the database.

model class-attribute instance-attribute

model: Optional[LanguageModelConfig] = None

Configuration for the language model.

model_name instance-attribute

model_name: str

Name of the language model.

direct_logit_attributor instance-attribute

direct_logit_attributor: DirectLogitAttributorConfig

Configuration for the direct logit attributor.

mongo class-attribute instance-attribute

mongo: MongoDBConfig | None = None

Configuration for the MongoDB database.

analysis_file class-attribute instance-attribute

analysis_file: Path | None = None

The analysis results file to be updated. Only used if MongoDB client is not provided.

device_type class-attribute instance-attribute

device_type: str = 'cuda'

Device type to use for distributed training ('cuda' or 'cpu')

direct_logit_attribute

direct_logit_attribute(
    settings: DirectLogitAttributeSettings,
) -> None

Run direct logit attribution on an SAE model.

Parameters:

Name Type Description Default
settings DirectLogitAttributeSettings

Configuration settings for DirectLogitAttributor

required
Source code in src/lm_saes/runners/analyze.py
@torch.no_grad()
def direct_logit_attribute(settings: DirectLogitAttributeSettings) -> None:
    """Run direct logit attribution on an SAE model.

    Args:
        settings: Configuration settings for DirectLogitAttributor
    """
    # Set up logging
    setup_logging(level="INFO")

    # device_mesh = (
    #     init_device_mesh(
    #         device_type=settings.device_type,
    #         mesh_shape=(settings.head_parallel_size, settings.data_parallel_size, settings.model_parallel_size),
    #         mesh_dim_names=("head", "data", "model"),
    #     )
    #     if settings.head_parallel_size > 1 or settings.data_parallel_size > 1 or settings.model_parallel_size > 1
    #     else None
    # )

    mongo_client = None
    if settings.mongo is not None:
        mongo_client = MongoClient(settings.mongo)
        logger.info("MongoDB client initialized")
    else:
        assert settings.analysis_file is not None, (
            "Analysis directory must be provided if MongoDB client is not provided"
        )
        # the analysis directory should contain the analysis results to be updated
        logger.info(f"Analysis results to be updated: {settings.analysis_file}")

    logger.info("Loading SAE model")
    sae = AbstractSparseAutoEncoder.from_pretrained(
        settings.sae.pretrained_name_or_path,
        device=settings.sae.device,
        dtype=settings.sae.dtype,
        fold_activation_scale=settings.sae.fold_activation_scale,
        strict_loading=settings.sae.strict_loading,
    )

    # Load configurations
    model_cfg = load_config(
        config=settings.model,
        name=settings.model_name,
        mongo_client=mongo_client,
        config_type="model",
        required=True,
    )
    model_cfg.device = settings.device_type
    model_cfg.dtype = sae.cfg.dtype

    model = load_model(model_cfg)
    assert isinstance(model, TransformerLensLanguageModel), (
        "DirectLogitAttributor only supports TransformerLensLanguageModel as the model backend"
    )

    logger.info("Direct logit attribution")
    direct_logit_attributor = DirectLogitAttributor(settings.direct_logit_attributor)
    results = direct_logit_attributor.direct_logit_attribute(sae, model, settings.layer_idx)

    # if is_master():
    if mongo_client is not None:
        logger.info("Direct logit attribution completed, saving results to MongoDB")
        mongo_client.update_features(
            sae_name=settings.sae_name,
            sae_series=settings.sae_series,
            update_data=[{"logits": result} for result in results],
            start_idx=0,
        )
    else:
        assert settings.analysis_file is not None, "analysis_file must be set when MongoDB is not used"
        logger.info(f"Loading analysis results from: {settings.analysis_file}")

        # Load existing analysis results
        with open(settings.analysis_file, "rb") as f:
            analysis_data = pickle.load(f)

        # Update each feature with logits
        assert len(analysis_data) == len(results), (
            f"Number of features in analysis file ({len(analysis_data)}) does not match "
            f"number of results from direct logit attribution ({len(results)})"
        )

        for i, result in enumerate(results):
            assert analysis_data[i]["feature_idx"] == i, "Feature index mismatch"
            analysis_data[i]["logits"] = result

        # Save updated analysis back to file
        logger.info(f"Saving updated analysis results to: {settings.analysis_file}")
        with open(settings.analysis_file, "wb") as f:
            pickle.dump(analysis_data, f)

        logger.info(f"Updated {len(results)} features with logit attributions")

    logger.info("Direct logit attribution completed successfully")
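
In the file-based fallback above, the merge step attaches one logit-attribution result to each analysis record, keyed by position. That step can be factored out as a small helper (hypothetical, mirroring the loop in the runner):

```python
def merge_logit_results(analysis_data, results):
    """Attach per-feature logit attributions to analysis records.

    Sketch of the merge performed by the file-based fallback: records and
    results are aligned by position, and each record's feature_idx must
    match its position.
    """
    if len(analysis_data) != len(results):
        raise ValueError(
            f"Number of features in analysis file ({len(analysis_data)}) does not "
            f"match number of results ({len(results)})"
        )
    for i, (record, result) in enumerate(zip(analysis_data, results)):
        assert record["feature_idx"] == i, "Feature index mismatch"
        record["logits"] = result
    return analysis_data
```

The helper mutates the records in place and returns them, so the caller can pickle the same object back to disk.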

CheckActivationConsistencySettings

Bases: BaseSettings

Settings for checking activation consistency. Verifies that activations are consistent across hook points by comparing their token ids.

paths instance-attribute

paths: dict[str, Path]

Paths to the activations to check.

device class-attribute instance-attribute

device: str = 'cuda'

Device to use for checking activation consistency

num_workers class-attribute instance-attribute

num_workers: int = 0

Number of workers to use for checking activation consistency

prefetch_factor class-attribute instance-attribute

prefetch_factor: int | None = None

Number of samples loaded in advance by each worker

check_activation_consistency

check_activation_consistency(
    settings: CheckActivationConsistencySettings,
) -> None

Check activation consistency.

Source code in src/lm_saes/runners/generate.py
def check_activation_consistency(settings: CheckActivationConsistencySettings) -> None:
    """Check activation consistency."""

    loader = CachedActivationLoader(
        cache_dirs=settings.paths,
        device=settings.device,
        num_workers=settings.num_workers,
        prefetch_factor=settings.prefetch_factor,
    )

    activations = loader.process()
    # Consuming the stream drives the loader, which performs the token-id
    # consistency checks across hook points.
    for _ in activations:
        pass
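
The check itself amounts to comparing token ids position by position across hook points. A self-contained sketch of that idea (the data shapes here are hypothetical, not the loader's actual batch format):

```python
def token_ids_consistent(streams):
    """Return True iff every hook point's stream yields identical token-id
    batches.

    Simplified stand-in for the loader's internal consistency check;
    streams maps hook-point names to lists of token-id batches.
    """
    hook_points = list(streams)
    if len(hook_points) < 2:
        return True
    reference = streams[hook_points[0]]
    return all(streams[hp] == reference for hp in hook_points[1:])
```

Activations cached from the same dataset pass the check only if every hook point saw the same tokens in the same order.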