Style Guide for Language-Model-SAEs
Language-Model-SAEs uses Python for the core library and backend, and TypeScript (React) for the frontend visualization. This style guide lists common dos and don'ts.
Python Style Guide
The Python style guide mainly follows the best practices of the Google Python Style Guide, and additionally covers writing tensor computations and distributed programs.
Lint and Format
Language-Model-SAEs uses ruff as its Python linter and formatter. ruff detects stylistic inconsistencies and potential bugs in Python source code. The formatter enforces consistent formatting throughout the codebase, including indentation, line width, trailing commas, and string quote style. The linter checks code quality, catching issues such as unused variables and non-standard naming conventions. Make sure ruff reports no issues before committing by running:
These commands check the Python code for formatting and linting issues based on the rules defined in pyproject.toml. They also fix all formatting problems and the linting problems that ruff can fix automatically. Check any remaining linting problems manually and fix them.
We also have a pre-commit hook configured in .pre-commit-config.yaml. Install it by running:
Once installed, the hook automatically runs the above ruff formatter and linter checks before each commit.
Imports and Exports
- Use `import x` for importing packages and modules.
- Use `from x import y` where `x` is the package prefix and `y` is the module name with no prefix.
- Use `from x import y as z` in any of the following circumstances:
  - Two modules named `y` are to be imported.
  - `y` conflicts with a top-level name defined in the current module.
  - `y` conflicts with a common parameter name that is part of the public API (e.g., `features`).
  - `y` is an inconveniently long name.
  - `y` is too generic in the context of your code (e.g., `from storage.file_system import options as fs_options`).
- Use `import y as z` only when `z` is a standard abbreviation (e.g., `import numpy as np`).
- Always use complete absolute paths to import first-party modules.
- For any functions or classes intended to be exposed to users, add them to `__all__` in `__init__.py`.
Exceptions
WIP
Mutable States
Immutability produces code that's easier to reason about, easier to test, and easier to verify for correctness. Immutable data doesn't change, so we can safely reuse it without worrying that results will differ between calls. Immutable data can also be safely passed between threads without race conditions or other concurrency issues. We should avoid mutable states whenever possible.
Mutable Global States
Avoid mutable global states in the core library, as they significantly compromise the purity of core functionalities. Some global caches are permitted in the visualization server.
Mutable Local States
While a purely functional style (which avoids mutable state entirely) is preferred, some mutability is pragmatic or even necessary. Completely eliminating mutability may lead to overly complicated program structures and decreased readability. Thus, the preference for immutability follows a best-effort principle. Below are some guidelines on when and how to use mutable state; immutable alternatives should always be considered first:
- Use list/dictionary/set comprehensions to create containers instead of resorting to procedural loops, `map`, `filter`, etc. Comprehensions remove temporary mutable state and keep code concise.
- The principle of avoiding "empty first, then fill" applies to other cases as well, such as concatenating strings (instead of accumulating them piece by piece) and building dictionaries.
- Avoid in-place modifications: create new containers instead of modifying existing ones. For example, create a new dictionary when a modification is needed.
- When mutable state is inevitable, limit its scope and preserve the purity of the outer function. Ensure that only a minimal portion of the code has access to the mutable state.
Localized Mutable State:
```python
def compute_statistics(data: list[float]) -> dict[str, float]:
    """Pure function that returns statistics without side effects."""
    # Mutable state is confined within this function
    stats = {}
    total = 0.0
    for value in data:
        total += value
    stats["mean"] = total / len(data)
    stats["sum"] = total
    return stats  # Return new object, no external mutation
```
Leaked Mutable State:
```python
# Global mutable state
accumulated_stats = {}


def compute_statistics(data: list[float]) -> None:
    """Impure function that mutates global state."""
    total = 0.0
    for value in data:
        total += value
    # Mutates external state - breaks purity
    accumulated_stats["mean"] = total / len(data)
    accumulated_stats["sum"] = total
```
The localized version remains referentially transparent: given the same input, it always produces the same output without observable side effects. Internal mutability (e.g., for performance) is acceptable as long as it doesn't leak outside the function boundary.
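The first three guidelines in the list above can be sketched as follows. All function names are hypothetical, and the specific idioms (`str.join` for strings, `{**old, key: value}` for a modified copy of a dictionary) are common Python choices assumed here for illustration, not patterns taken from the codebase.

```python
def squares_of_evens(values: list[int]) -> list[int]:
    # List comprehension instead of "empty list, then append in a loop".
    return [v * v for v in values if v % 2 == 0]


def format_row(cells: list[str]) -> str:
    # Concatenate with str.join instead of accumulating with `+=` in a loop.
    return ", ".join(cells)


def index_by_name(names: list[str]) -> dict[str, int]:
    # Build the dictionary in one expression instead of filling an empty one.
    return {name: i for i, name in enumerate(names)}


def with_learning_rate(config: dict[str, float], lr: float) -> dict[str, float]:
    # Create a new dictionary; the caller's `config` is left unchanged.
    return {**config, "lr": lr}
```

Because `with_learning_rate` returns a copy, callers holding the original `config` never observe the change.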
Tensor Computation
PyTorch provides a wide range of tensor operations, but most of them can be decomposed into a few basic operators. Because Language-Model-SAEs is an interpretability library, readability matters: we encourage using einops, which explicitly specifies input shapes, output shapes, and named semantic dimensions instead of numeric indices.
- Use `einops.einsum` to perform tensor products, including batch matrix multiplication and more complex operations. Use full names for dimensions, e.g., `batch` instead of `b`. Plain matrix multiplication can be written as `x @ y` for simplicity.
- Use `einops.rearrange` to perform reshape, transpose, permute, squeeze, and unsqueeze operations while explicitly showing how dimensions change.
- Use `einops.reduce` to perform reductions over specific dimensions. Only use `.mean()` or `.sum()` for overall reductions that produce a scalar output.
- Use `einops.repeat` to broadcast or tile tensors along specific dimensions instead of manual reshaping and expanding.
- Avoid using complicated operators and modules from `torch.nn` and `torch.nn.functional` unless there are significant performance gaps.
Distributed Programming
The distributed support in Language-Model-SAEs relies on DeviceMesh and DTensor. The design of DeviceMesh and DTensor is heavily inspired by JAX. DeviceMesh allows users to easily manage multi-dimensional parallelism by creating a "mesh" that controls all devices and specifies how different parallelism strategies are distributed across them. Built on DeviceMesh, DTensor provides a global view of how tensors are distributed across devices, following the SPMD (Single Program, Multiple Data) programming model. With DTensor, users can (ideally) work as if they have infinite logical device memory to accommodate large tensors and perform operations on them. DTensor automatically splits the data and computation across physical devices based on the DeviceMesh it operates on and the sharding strategy it uses.
Below are some rules for making better use of DTensor for distributed programming in Language-Model-SAEs:
- Avoid hardcoding `DTensor` placements. Use `DimMap` (designed to be similar to `PartitionSpec` in JAX) to generate placements dynamically from the current `DeviceMesh`. This keeps the code working when some mesh dimensions are absent from the `DeviceMesh`.
- Avoid set
Type Annotation
All code should be annotated with type hints. Language-Model-SAEs relies on basedpyright for static type checking. Some extra rules:
- Type hints of generic types should follow PEP 585. Use the built-in types `list`, `dict`, `set`, etc. rather than their counterparts from the `typing` module.
- Type hints of union types should follow PEP 604 syntax. Use `X | Y` rather than `Union[X, Y]`, and `X | None` rather than `Optional[X]`.
- Tensors with known shapes should be annotated with `jaxtyping`.
Some type hints in the current codebase may not follow the above rules, since fixing them all would be a lot of work. New code is expected to follow these rules.
TypeScript Style Guide
TBD