Skip to content

Activation

Activation extraction, caching, and processing.

ActivationFactoryConfig pydantic-model

Bases: BaseConfig

Config:

  • arbitrary_types_allowed: True

Fields:

sources pydantic-field

List of sources to use for activations. Can be a dataset or a path to activations.

target pydantic-field

The target to produce.

hook_points pydantic-field

hook_points: list[str]

The hook points to capture activations from.

batch_size pydantic-field

batch_size: int

The batch size to use for outputting activations.

num_workers pydantic-field

num_workers: int = 4

The number of workers to use for loading the dataset.

context_size pydantic-field

context_size: int | None = None

The context size to use for generating activations. All tokens will be padded or truncated to this size. If None, tokens will not be padded or truncated; this may lead to errors when re-batching activations of different context sizes.

model_batch_size pydantic-field

model_batch_size: int = 1

The batch size to use for model forward pass when generating activations.

override_dtype pydantic-field

override_dtype: Optional[dtype] = None

The dtype to use for outputting activations. If None, will not override the dtype.

buffer_size pydantic-field

buffer_size: int | None = None

Buffer size for online shuffling. If None, no shuffling will be performed.

buffer_shuffle pydantic-field

buffer_shuffle: BufferShuffleConfig | None = None

Manual seed and device of the generator used to generate random permutations in the buffer.

ignore_token_ids pydantic-field

ignore_token_ids: list[int] | None = None

Tokens to ignore in the activations.

ActivationFactory

ActivationFactory(
    cfg: ActivationFactoryConfig,
    before_aggregation_interceptor: Callable[
        [dict[str, Any], int], dict[str, Any]
    ]
    | None = None,
    device_mesh: Optional[Any] = None,
)

Factory class for generating activation data from different sources.

This class handles loading data from datasets or activation files, processing it through a pipeline of processors, and aggregating the results based on configured weights.

The overall pipeline is like a tree, where multiple chains collect data from different sources, and then aggregated together, which in detail is: 1. Pre-aggregation processors: Process data from each source through a series of processors. 2. Aggregator: Aggregate the processed data streams. 3. Post-aggregation processor: Process the aggregated data through a final processor.

Initialize the factory with the given configuration.

Parameters:

Name Type Description Default
cfg ActivationFactoryConfig

Configuration object specifying data sources, processing pipeline and output format

required
Source code in src/lm_saes/activation/factory.py
def __init__(
    self,
    cfg: ActivationFactoryConfig,
    before_aggregation_interceptor: Callable[[dict[str, Any], int], dict[str, Any]] | None = None,
    device_mesh: Optional[Any] = None,
):
    """Initialize the factory with the given configuration.

    Args:
        cfg: Configuration object specifying data sources, processing pipeline and output format
        before_aggregation_interceptor: Optional hook applied to each record before
            aggregation (the int argument is presumably the source index — confirm against
            the pre-aggregation builders).
        device_mesh: Optional device mesh handed to the post-aggregation batchler.
    """
    self.cfg = cfg
    self.device_mesh = device_mesh
    self.before_aggregation_interceptor = before_aggregation_interceptor

    # Assemble the three pipeline stages up front so `process` only wires them together.
    self.pre_aggregation_processors = self.build_pre_aggregation_processors()
    self.post_aggregation_processor = self.build_post_aggregation_processor()
    self.aggregator = self.build_aggregator()

build_pre_aggregation_processors

build_pre_aggregation_processors()

Build processors that run before aggregation for each data source.

Returns:

Type Description

List of callables that process data from each source

Source code in src/lm_saes/activation/factory.py
def build_pre_aggregation_processors(self):
    """Build processors that run before aggregation for each data source.

    Returns:
        List of callables that process data from each source
    """
    # Partition the configured sources by kind, keeping their relative order.
    dataset_sources = []
    activations_sources = []
    for source in self.cfg.sources:
        if isinstance(source, ActivationFactoryDatasetSource):
            dataset_sources.append(source)
        elif isinstance(source, ActivationFactoryActivationsSource):
            activations_sources.append(source)

    # Dataset sources take indices [0, len(dataset_sources)); activations
    # sources continue the numbering after them.
    processors = [
        self._build_pre_aggregation_dataset_source_processors(source, idx)
        for idx, source in enumerate(dataset_sources)
    ]
    offset = len(dataset_sources)
    processors.extend(
        self._build_pre_aggregation_activations_source_processors(source, offset + idx)
        for idx, source in enumerate(activations_sources)
    )
    return processors

build_post_aggregation_processor

build_post_aggregation_processor()

Build processor that runs after aggregation.

Parameters: none — configuration is read from `self.cfg`.

Returns:

Type Description

Callable that processes aggregated data

Source code in src/lm_saes/activation/factory.py
def build_post_aggregation_processor(self):
    """Build processor that runs after aggregation.

    Configuration is read from ``self.cfg``; this method takes no arguments.

    Returns:
        Callable that processes aggregated data
    """

    def build_batchler():
        """Create batchler for batched activations."""
        assert self.cfg.batch_size is not None, "Batch size must be provided for outputting batched activations"
        return ActivationBatchler(
            batch_size=self.cfg.batch_size,
            buffer_size=self.cfg.buffer_size,
            buffer_shuffle_config=self.cfg.buffer_shuffle,
            device_mesh=self.device_mesh,
        )

    def build_override_dtype_processor():
        """Create processor that overrides the dtype of the activations."""
        assert self.cfg.override_dtype is not None, (
            "Override dtype must be provided for outputting activations with different dtype"
        )
        return OverrideDtypeProcessor(dtype=self.cfg.override_dtype)

    # Assemble the post-aggregation chain: re-batching first (if configured),
    # then dtype override (if configured). An empty chain is a pass-through.
    processors = []
    if self.cfg.batch_size is not None:
        processors.append(build_batchler())
    if self.cfg.override_dtype is not None:
        processors.append(build_override_dtype_processor())

    def process_activations(activations: Iterable[dict[str, Any]], **kwargs: Any):
        """Process aggregated activations through post-processors.

        Args:
            activations: Stream of aggregated activation data
            **kwargs: Additional arguments passed to processors

        Returns:
            Processed activation stream
        """
        for processor in processors:
            activations = processor.process(activations, **kwargs)
        return activations

    return process_activations

build_aggregator

build_aggregator()

Build function to aggregate data from multiple sources.

Returns:

Type Description

Callable that aggregates data streams. Currently is a simple weighted random sampler.

Source code in src/lm_saes/activation/factory.py
def build_aggregator(self):
    """Build function to aggregate data from multiple sources.

    Returns:
        Callable that aggregates data streams. Currently is a simple weighted random sampler.
    """
    source_sample_weights = np.array([source.sample_weights for source in self.cfg.sources])

    def aggregate(activations: list[Iterable[dict[str, Any]]], **kwargs: Any) -> Iterable[dict[str, Any]]:
        """Aggregate multiple activation streams by sampling based on weights.

        Args:
            activations: List of activation streams from different sources
            **kwargs: Additional arguments (unused)

        Yields:
            Sampled activation data with source info
        """
        ran_out_of_samples = np.zeros(len(self.cfg.sources), dtype=bool)
        iterators: list[Iterator[dict[str, Any]]] = [iter(activation) for activation in activations]

        while not all(ran_out_of_samples):
            # Re-mask exhausted sources on every draw. (Previously the mask and
            # normalization were computed once before the loop — when it was still
            # all-False — so dead sources kept being sampled and rejected; the
            # resulting distribution among live sources is the same, but draws
            # were wasted.) Sampling proportionally to the configured weights,
            # renormalized over the live sources only.
            alive = np.flatnonzero(~ran_out_of_samples)
            weights = source_sample_weights[alive]
            weights = weights / weights.sum()

            sampled_source = np.random.choice(alive, replace=True, p=weights)
            try:
                result = next(iterators[sampled_source])
            except StopIteration:
                # Source exhausted: mark it dead and resample from the rest.
                ran_out_of_samples[sampled_source] = True
                continue
            yield result

    return aggregate

process

process(**kwargs: Any)

Process data through the full pipeline.

Parameters:

Name Type Description Default
**kwargs Any

Arguments passed to processors (must include required args)

{}

Returns:

Type Description

Iterable of processed activation data

Source code in src/lm_saes/activation/factory.py
def process(self, **kwargs: Any):
    """Process data through the full pipeline.

    Runs every pre-aggregation processor, merges the resulting streams with the
    aggregator, then applies the post-aggregation processor.

    Args:
        **kwargs: Arguments passed to processors (must include required args)

    Returns:
        Iterable of processed activation data
    """
    per_source_streams = [build_stream(**kwargs) for build_stream in self.pre_aggregation_processors]
    merged_stream = self.aggregator(per_source_streams)
    return self.post_aggregation_processor(merged_stream, **kwargs)

ActivationFactoryTarget

Bases: Enum

TOKENS class-attribute instance-attribute

TOKENS = 'tokens'

Output non-padded and non-truncated tokens.

ACTIVATIONS_2D class-attribute instance-attribute

ACTIVATIONS_2D = 'activations-2d'

Output activations in (batch_size, seq_len, d_model) shape. Tokens are padded and truncated to the same length.

ACTIVATIONS_1D class-attribute instance-attribute

ACTIVATIONS_1D = 'activations-1d'

Output activations in (n_filtered_tokens, d_model) shape. Tokens are filtered in this stage.

ActivationFactoryDatasetSource pydantic-model

Bases: ActivationFactorySource

Fields:

is_dataset_tokenized pydantic-field

is_dataset_tokenized: bool = False

Whether the dataset is tokenized. Non-tokenized datasets should have records with fields text, images, etc. Tokenized datasets should have records with fields tokens, which could contain either padded or non-padded tokens.

prepend_bos pydantic-field

prepend_bos: bool = True

Whether to prepend the BOS token to each record when tokenizing.

ActivationFactoryActivationsSource pydantic-model

Bases: ActivationFactorySource

Config:

  • arbitrary_types_allowed: True

Fields:

path pydantic-field

path: str | dict[str, str]

The path to the cached activations.

device pydantic-field

device: str = 'cpu'

The device to load the activations on.

dtype pydantic-field

dtype: Optional[dtype] = None

We might want to convert presaved bf16 activations to fp32

num_workers pydantic-field

num_workers: int = 4

The number of workers to use for loading the activations.

prefetch pydantic-field

prefetch: int | None = 8

The number of chunks to prefetch.

BufferShuffleConfig pydantic-model

Bases: BaseConfig

Fields:

perm_seed pydantic-field

perm_seed: int = 42

Perm seed for aligned permutation for generating activations. If None, will not use manual seed for Generator.

generator_device pydantic-field

generator_device: str | None = None

The device on which the torch.Generator is created. If None, the generator is initialized on CPU (the PyTorch default).

ActivationWriterConfig pydantic-model

Bases: BaseConfig

Fields:

hook_points pydantic-field

hook_points: list[str]

The hook points to capture activations from.

total_generating_tokens pydantic-field

total_generating_tokens: int | None = None

The total number of tokens to generate. If None, will write all activations to disk.

n_samples_per_chunk pydantic-field

n_samples_per_chunk: int | None = None

The number of samples to write to disk per chunk. If None, will not further batch the activations.

cache_dir pydantic-field

cache_dir: str = 'activations'

The directory to save the activations.

num_workers pydantic-field

num_workers: int | None = None

The number of workers to use for writing the activations. If None, will not use multi-threaded writing.

ActivationWriter

ActivationWriter(
    cfg: ActivationWriterConfig,
    executor: Optional[ThreadPoolExecutor] = None,
)

Writes activations to disk in a format compatible with CachedActivationLoader.

Parameters:

Name Type Description Default
cfg ActivationWriterConfig

Configuration for writing activations

required
executor Optional[ThreadPoolExecutor]

Optional ThreadPoolExecutor for parallel writing. If None and cfg.num_workers is set, a new executor will be created with max_workers=cfg.num_workers.

None
Source code in src/lm_saes/activation/writer.py
def __init__(
    self,
    cfg: ActivationWriterConfig,
    executor: Optional[ThreadPoolExecutor] = None,
):
    """Set up the writer: record config, pick an executor, create output directories.

    Args:
        cfg: Configuration for writing activations.
        executor: Optional ThreadPoolExecutor for parallel writing. When
            cfg.num_workers is set and no executor is supplied, a new one is
            created (and owned) here.
    """
    self.cfg = cfg
    self.cache_dir = Path(cfg.cache_dir)

    if cfg.num_workers is None:
        # Multi-threaded writing disabled: all writes happen synchronously.
        self.executor = None
        self._owned_executor = False
    else:
        self.executor = executor if executor is not None else ThreadPoolExecutor(max_workers=cfg.num_workers)
        # We only own (and must later shut down) an executor we created ourselves.
        self._owned_executor = executor is None

    # Ensure one output directory exists per hook point.
    for hook_point in self.cfg.hook_points:
        (self.cache_dir / hook_point).mkdir(parents=True, exist_ok=True)

process

process(
    data: Iterable[dict[str, Any]],
    *,
    device_mesh: Optional[DeviceMesh] = None,
    start_shard: int = 0,
) -> None

Write activation data to disk in chunks.

Processes a stream of activation dictionaries, accumulating samples until reaching the configured chunk size, then writes each chunk to disk. Files are organized by hook point with names following the pattern 'chunk-{N}.pt'.

Parameters:

Name Type Description Default
data Iterable[dict[str, Any]]

Stream of activation dictionaries containing: - Activations for each hook point - Original tokens - Meta information

required
device_mesh Optional[DeviceMesh]

The device mesh to use for distributed writing. If None, will write to disk on the current rank.

None
start_shard int

The shard to start writing from.

0
Source code in src/lm_saes/activation/writer.py
def process(
    self,
    data: Iterable[dict[str, Any]],
    *,
    device_mesh: Optional[DeviceMesh] = None,
    start_shard: int = 0,
) -> None:
    """Write activation data to disk in chunks.

    Processes a stream of activation dictionaries, accumulating samples until reaching
    the configured chunk size, then writes each chunk to disk. Files are organized by
    hook point with names following the pattern 'chunk-{N}.pt'.

    Args:
        data: Stream of activation dictionaries containing:
            - Activations for each hook point
            - Original tokens
            - Meta information
        device_mesh: The device mesh to use for distributed writing. If None, will write to disk on the current rank.
        start_shard: The shard to start writing from.
    """
    # In distributed mode, split the total token budget evenly across the
    # data-parallel group; each rank writes only its share.
    total = (
        self.cfg.total_generating_tokens // device_mesh.get_group("data").size()
        if device_mesh is not None and self.cfg.total_generating_tokens is not None
        else self.cfg.total_generating_tokens
    )
    pbar = tqdm(desc="Writing activations to disk", total=total)
    n_tokens_written = 0

    # futures is None <=> synchronous mode (cfg.num_workers unset).
    futures = set() if self.cfg.num_workers is not None else None

    if self.cfg.n_samples_per_chunk is not None:

        def collate_batch(batch: Sequence[dict[str, Any]]) -> dict[str, Any]:
            # Assert that all samples have the same keys
            # (NOTE: only checks that the first sample's keys appear in every
            # sample; extra keys in later samples would pass unnoticed.)
            assert all(k in d for k in batch[0] for d in batch), (
                f"All samples must have the same keys: {batch[0].keys()}"
            )
            # Stack tensor fields along a new batch dimension; gather
            # non-tensor fields (e.g. meta) into plain lists.
            return {
                k: torch.stack([d[k] for d in batch])
                if isinstance(batch[0][k], torch.Tensor)
                else [d[k] for d in batch]
                for k in batch[0].keys()
            }

        # Re-chunk the incoming stream into fixed-size batches before writing.
        data = map(collate_batch, more_itertools.batched(data, self.cfg.n_samples_per_chunk))

    for chunk_id, chunk in enumerate(data):
        assert all(k in chunk for k in self.cfg.hook_points), (
            f"All samples must have the hook points: {self.cfg.hook_points}"
        )

        # Distributed runs prefix the chunk name with this rank's shard id
        # (offset by start_shard) so ranks never collide on file names.
        chunk_name = (
            f"chunk-{chunk_id:08d}"
            if device_mesh is None
            else f"shard-{device_mesh.get_group('data').rank() + start_shard}-chunk-{chunk_id:08d}"
        )

        # Submit writing tasks for each hook point
        with timer.time("write_chunk"):
            for hook_point in self.cfg.hook_points:
                # Each hook point's chunk carries its own activation plus all
                # shared non-hook-point fields; "meta" is passed separately.
                chunk_data = {"activation": chunk[hook_point]} | {
                    k: v for k, v in chunk.items() if k not in ["meta", *self.cfg.hook_points]
                }
                if futures is None:
                    self._write_chunk(
                        hook_point, chunk_data, chunk_name, chunk["meta"] if "meta" in chunk else None
                    )
                else:
                    assert self.executor is not None, "Executor is not initialized"
                    future = self.executor.submit(
                        self._write_chunk,
                        hook_point,
                        chunk_data,
                        chunk_name,
                        chunk["meta"] if "meta" in chunk else None,
                    )
                    futures.add(future)

            if futures is not None:
                assert self.cfg.num_workers is not None, "num_workers must be set to use parallel writing"
                # Wait for some futures to complete if we have too many pending
                # (backpressure: bounds in-flight writes at 2x the worker count;
                # `wait` returns (done, not_done), so futures is rebound to the
                # still-pending set).
                while len(futures) >= self.cfg.num_workers * 2:
                    done, futures = wait(futures, return_when="FIRST_COMPLETED")
                    for future in done:
                        future.result()  # Raise any exceptions that occurred

        if timer.enabled:
            logger.info(f"\nTimer Summary:\n{timer.summary()}\n")

        # Progress is tracked in tokens, not chunks; assumes every chunk has a
        # "tokens" tensor — TODO confirm for activation-only targets.
        n_tokens_written += chunk["tokens"].numel()
        pbar.update(chunk["tokens"].numel())

        if total is not None and n_tokens_written >= total:
            break

    if futures is not None:
        # Wait for remaining futures to complete
        for future in as_completed(futures):
            future.result()

    pbar.close()

__del__

__del__() -> None

Cleanup the executor if we own it.

Source code in src/lm_saes/activation/writer.py
def __del__(self) -> None:
    """On teardown, shut down the executor if this instance created it."""
    # Executors passed in by the caller are the caller's responsibility.
    if self.executor is not None and self._owned_executor:
        self.executor.shutdown(wait=True)