Activation
Activation extraction, caching, and processing.
ActivationFactoryConfig
pydantic-model
Bases: BaseConfig
Config:
arbitrary_types_allowed: True
Fields:
- sources (list[ActivationFactoryDatasetSource | ActivationFactoryActivationsSource])
- target (ActivationFactoryTarget)
- hook_points (list[str])
- batch_size (int)
- num_workers (int)
- context_size (int | None)
- model_batch_size (int)
- override_dtype (Optional[dtype])
- buffer_size (int | None)
- buffer_shuffle (BufferShuffleConfig | None)
- ignore_token_ids (list[int] | None)
sources
pydantic-field
sources: list[ActivationFactoryDatasetSource | ActivationFactoryActivationsSource]
List of sources to use for activations. Can be a dataset or a path to activations.
num_workers
pydantic-field
The number of workers to use for loading the dataset.
context_size
pydantic-field
The context size to use for generating activations. All tokens will be padded or truncated to this size. If None, tokens will not be padded or truncated, which may cause errors when re-batching activations of different context sizes.
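The pad-or-truncate behaviour described above can be sketched as follows (a hypothetical helper, not the library's actual implementation; the pad token id is an assumption):

```python
def pad_or_truncate(tokens, context_size, pad_id=0):
    """Illustrative sketch of the context_size behaviour: truncate long
    sequences, pad short ones, or pass tokens through when context_size
    is None."""
    if context_size is None:
        return list(tokens)  # leave tokens untouched
    tokens = list(tokens)[:context_size]  # truncate long sequences
    tokens += [pad_id] * (context_size - len(tokens))  # pad short ones
    return tokens
```

With context_size=None every sequence keeps its own length, which is why mixing context sizes can break later re-batching.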
model_batch_size
pydantic-field
The batch size to use for model forward pass when generating activations.
override_dtype
pydantic-field
The dtype to use for outputting activations. If None, will not override the dtype.
buffer_size
pydantic-field
Buffer size for online shuffling. If None, no shuffling will be performed.
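Online shuffling with a fixed-size buffer can be sketched like this (a minimal, hypothetical implementation of the idea, not the library's code: fill a buffer, then repeatedly yield a random element and replace it with the next incoming item):

```python
import random

def shuffle_buffer(stream, buffer_size, seed=None):
    """Buffered online shuffle: approximate shuffling of a stream using
    O(buffer_size) memory."""
    rng = random.Random(seed)
    buf = []
    for item in stream:
        if len(buf) < buffer_size:
            buf.append(item)  # fill phase
            continue
        idx = rng.randrange(len(buf))
        yield buf[idx]        # emit a random buffered element
        buf[idx] = item       # replace it with the incoming item
    rng.shuffle(buf)          # drain the remaining buffer in random order
    yield from buf
```

Every input item is emitted exactly once; a larger buffer gives a better approximation to a full shuffle.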
buffer_shuffle
pydantic-field
buffer_shuffle: BufferShuffleConfig | None = None
Manual seed and device of the generator used to produce random permutations in the buffer.
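Putting the fields together, a configuration might look like the following. This is shown as a plain dict since the exact constructor call is not reproduced here; all values, including the spelling of the target value, are illustrative assumptions:

```python
# Hypothetical configuration sketch; field names follow the model above,
# but every value (and the "activations_2d" target spelling) is illustrative.
factory_cfg = {
    "sources": [
        {"type": "dataset", "name": "openwebtext", "sample_weights": 1.0,
         "is_dataset_tokenized": False, "prepend_bos": True},
    ],
    "target": "activations_2d",
    "hook_points": ["blocks.0.hook_resid_post"],
    "batch_size": 64,
    "num_workers": 4,
    "context_size": 128,
    "model_batch_size": 8,
    "override_dtype": None,
    "buffer_size": 100_000,
    "buffer_shuffle": {"perm_seed": 42, "generator_device": None},
    "ignore_token_ids": None,
}
```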
ActivationFactory
ActivationFactory(
cfg: ActivationFactoryConfig,
before_aggregation_interceptor: Callable[
[dict[str, Any], int], dict[str, Any]
]
| None = None,
device_mesh: Optional[Any] = None,
)
Factory class for generating activation data from different sources.
This class handles loading data from datasets or activation files, processing it through a pipeline of processors, and aggregating the results based on configured weights.
The overall pipeline is shaped like a tree: multiple chains collect data from different sources, and the results are then aggregated together. In detail:
1. Pre-aggregation processors: process data from each source through a series of processors.
2. Aggregator: aggregate the processed data streams.
3. Post-aggregation processor: process the aggregated data through a final processor.
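The three stages can be sketched as composed iterators (hypothetical names; the real factory wires these stages internally):

```python
def run_pipeline(sources, pre_processors, aggregator, post_processor):
    """Sketch of the tree-shaped pipeline: per-source processing,
    aggregation, then a final post-processing step."""
    # 1. Pre-aggregation: each source gets its own processor chain.
    streams = [pre(src) for pre, src in zip(pre_processors, sources)]
    # 2. Aggregation: merge the processed streams into one.
    merged = aggregator(streams)
    # 3. Post-aggregation: final processing of the merged stream.
    return post_processor(merged)

# Toy usage: double each item per source, concatenate streams, add one.
result = run_pipeline(
    sources=[[1, 2], [3, 4]],
    pre_processors=[lambda s: (x * 2 for x in s)] * 2,
    aggregator=lambda streams: (x for s in streams for x in s),
    post_processor=lambda s: [x + 1 for x in s],
)
# result == [3, 5, 7, 9]
```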
Initialize the factory with the given configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cfg | ActivationFactoryConfig | Configuration object specifying data sources, processing pipeline and output format | required |
Source code in src/lm_saes/activation/factory.py
build_pre_aggregation_processors
Build processors that run before aggregation for each data source.
Returns:
| Type | Description |
|---|---|
| | List of callables that process data from each source |
Source code in src/lm_saes/activation/factory.py
build_post_aggregation_processor
Build processor that runs after aggregation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cfg | | Factory configuration object | required |
Returns:
| Type | Description |
|---|---|
| | Callable that processes aggregated data |
Source code in src/lm_saes/activation/factory.py
build_aggregator
Build function to aggregate data from multiple sources.
Returns:
| Type | Description |
|---|---|
| | Callable that aggregates data streams. Currently a simple weighted random sampler. |
Source code in src/lm_saes/activation/factory.py
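A weighted random sampler over several streams can be sketched as below (a hypothetical implementation of the idea, not the library's code; behaviour when a stream is exhausted is an assumption):

```python
import random

def weighted_sample_streams(streams, weights, seed=None):
    """Yield items by repeatedly picking a source stream at random in
    proportion to its weight, dropping streams as they run out."""
    rng = random.Random(seed)
    iters = [iter(s) for s in streams]
    live = list(range(len(iters)))  # indices of non-exhausted streams
    while live:
        i = rng.choices(live, weights=[weights[j] for j in live])[0]
        try:
            yield next(iters[i])
        except StopIteration:
            live.remove(i)  # stream exhausted; stop sampling from it
```

With weights proportional to the sample_weights of each source, this reproduces the "simple weighted random sampler" behaviour described above.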
process
Process data through the full pipeline.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| **kwargs | Any | Arguments passed to processors (must include required args) | {} |
Returns:
| Type | Description |
|---|---|
| | Iterable of processed activation data |
Source code in src/lm_saes/activation/factory.py
ActivationFactoryTarget
Bases: Enum
TOKENS
class-attribute
instance-attribute
Output non-padded and non-truncated tokens.
ACTIVATIONS_2D
class-attribute
instance-attribute
Output activations in (batch_size, seq_len, d_model) shape. Tokens are padded and truncated to the same length.
ActivationFactoryDatasetSource
pydantic-model
Bases: ActivationFactorySource
Fields:
- name (str)
- sample_weights (float)
- type (str)
- is_dataset_tokenized (bool)
- prepend_bos (bool)
is_dataset_tokenized
pydantic-field
Whether the dataset is tokenized. Non-tokenized datasets should have records with fields such as text, images, etc. Tokenized datasets should have records with a tokens field, which may contain either padded or non-padded tokens.
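For illustration, records for the two dataset kinds might look like this (field names follow the description above; the values, including the token ids, are made up):

```python
# Non-tokenized dataset record: raw fields such as text.
non_tokenized_record = {"text": "The quick brown fox jumps over the lazy dog."}

# Tokenized dataset record: a tokens field (padded or non-padded).
tokenized_record = {"tokens": [101, 1996, 4248, 2829, 4419, 102]}  # made-up ids
```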
ActivationFactoryActivationsSource
pydantic-model
Bases: ActivationFactorySource
Config:
arbitrary_types_allowed: True
Fields:
- name (str)
- sample_weights (float)
- type (str)
- path (str | dict[str, str])
- device (str)
- dtype (Optional[dtype])
- num_workers (int)
- prefetch (int | None)
dtype
pydantic-field
The dtype to convert loaded activations to. Useful, for example, for converting pre-saved bf16 activations to fp32.
num_workers
pydantic-field
The number of workers to use for loading the activations.
BufferShuffleConfig
pydantic-model
Bases: BaseConfig
Fields:
- perm_seed (int)
- generator_device (str | None)
perm_seed
pydantic-field
Permutation seed used for aligned permutations when generating activations. If None, no manual seed will be used for the Generator.
ActivationWriterConfig
pydantic-model
Bases: BaseConfig
Fields:
- hook_points (list[str])
- total_generating_tokens (int | None)
- n_samples_per_chunk (int | None)
- cache_dir (str)
- format (Literal['pt', 'safetensors'])
- num_workers (int | None)
total_generating_tokens
pydantic-field
The total number of tokens to generate. If None, will write all activations to disk.
n_samples_per_chunk
pydantic-field
The number of samples to write to disk per chunk. If None, will not further batch the activations.
ActivationWriter
ActivationWriter(
cfg: ActivationWriterConfig,
executor: Optional[ThreadPoolExecutor] = None,
)
Writes activations to disk in a format compatible with CachedActivationLoader.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cfg | ActivationWriterConfig | Configuration for writing activations | required |
| executor | Optional[ThreadPoolExecutor] | Optional ThreadPoolExecutor for parallel writing. If None, a new executor will be created with max_workers=2. | None |
Source code in src/lm_saes/activation/writer.py
process
process(
data: Iterable[dict[str, Any]],
*,
device_mesh: Optional[DeviceMesh] = None,
start_shard: int = 0,
) -> None
Write activation data to disk in chunks.
Processes a stream of activation dictionaries, accumulating samples until reaching the configured chunk size, then writes each chunk to disk. Files are organized by hook point with names following the pattern 'chunk-{N}.pt'.
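The accumulate-then-write chunking described above can be sketched as follows (a hypothetical helper that only shows the chunking and naming logic, not the actual disk I/O or per-hook-point layout):

```python
from itertools import islice

def iter_chunks(samples, n_samples_per_chunk):
    """Accumulate samples into fixed-size chunks and yield
    (filename, chunk) pairs named 'chunk-{N}.pt'."""
    it = iter(samples)
    n = 0
    while True:
        chunk = list(islice(it, n_samples_per_chunk))
        if not chunk:
            break  # stream exhausted
        yield f"chunk-{n}.pt", chunk
        n += 1
```

The final chunk may be smaller than n_samples_per_chunk; in the real writer each chunk is serialized in the configured format ('pt' or 'safetensors') under the hook point's directory.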
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Iterable[dict[str, Any]] | Stream of activation dictionaries containing: activations for each hook point, original tokens, and meta information | required |
| device_mesh | Optional[DeviceMesh] | The device mesh to use for distributed writing. If None, will write to disk on the current rank. | None |
| start_shard | int | The shard to start writing from. | 0 |
Source code in src/lm_saes/activation/writer.py