# sae.py
import logging
from pathlib import Path
import torch
from torch import nn


class TopKSparseAutoencoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_latents: int,
        k: int,
        b_pre: torch.Tensor,
        dtype: torch.dtype,
        normalize_eps: float = 1e-6,
    ):
        """
        :param d_model: dimensionality of the input activations
        :param n_latents: number of latent dictionary features
        :param k: number of latents kept active per input (TopK sparsity)
        :param b_pre: initial value of the shared pre-bias, shape (d_model,)
        :param dtype: parameter dtype of the autoencoder
        :param normalize_eps: epsilon added to the std when normalizing inputs
        """
        super().__init__()
        self.d_model = d_model
        self.n_latents = n_latents
        self.k = k
        self.dtype = dtype
        self.normalize_eps = normalize_eps
        self.h_bias = None
        # Initialize the training data mean (or median) as a shared trainable pre-bias parameter for encoder and decoder
        self.b_pre = nn.Parameter(b_pre.to(dtype), requires_grad=True)
        # Initialize encoder and decoder. The encoder has an additional bias term b_enc on top of b_pre in the
        # forward pass, whereas the decoder has no bias term.
        self.encoder = nn.Linear(d_model, n_latents, bias=True, dtype=dtype)
        self.decoder = nn.Linear(n_latents, d_model, bias=False, dtype=dtype)
        # Use orthogonal initialization for the encoder to ensure well-distributed, independent directions and copy
        # the transposed encoder weights to the decoder to ensure parallel initialization as per the paper.
        nn.init.orthogonal_(self.encoder.weight)
        with torch.no_grad():
            self.decoder.weight.copy_(self.encoder.weight.t())
        self.normalize_decoder_weights()

    def normalize_decoder_weights(self) -> None:
        """Normalize the decoder weights to unit norm for each latent (corresponding to decoder columns)."""
        with torch.no_grad():
            # decoder.weight has shape (d_model, n_latents), so dim=0 takes the norm of each latent's column
            self.decoder.weight.div_(self.decoder.weight.norm(dim=0, keepdim=True))

    def project_decoder_grads(self) -> None:
        """Project out gradient information parallel to the dictionary vectors (decoder columns)."""
        with torch.no_grad():
            # Compute the dot product of each decoder column with its gradient, then subtract the projection
            # from the gradients in place to save memory
            proj = torch.sum(self.decoder.weight * self.decoder.weight.grad, dim=0, keepdim=True)
            self.decoder.weight.grad.sub_(proj * self.decoder.weight)

    def preprocess_input(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Preprocess input by converting to model dtype, centering and normalizing."""
        x = x.to(self.dtype)
        mean = x.mean(dim=-1, keepdim=True)
        norm = x.std(dim=-1, keepdim=True) + self.normalize_eps
        x = (x - mean) / norm
        return x, mean, norm

    @staticmethod
    def postprocess_output(
        reconstructed: torch.Tensor,
        mean: torch.Tensor,
        norm: torch.Tensor,
    ) -> torch.Tensor:
        """Postprocess output by denormalizing and adding back the input mean."""
        return (reconstructed * norm) + mean

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: input tensor of shape (batch_size, seq_len, d_model)
        :return: reconstructed tensor of shape (batch_size, seq_len, d_model)
        """
        # Store original dtype and preprocess input
        orig_dtype = x.dtype
        x, mean, norm = self.preprocess_input(x)
        # Reshape to flatten batch and sequence dimensions
        batch_size, seq_len, d_model = x.shape
        x = x.reshape(-1, d_model)
        # Forward pass through the model in normalized space
        normalized_recon, _, _ = self.forward_1d_normalized(x)
        # Reshape back to (batch_size, seq_len, d_model)
        normalized_recon = normalized_recon.reshape(batch_size, seq_len, -1)
        # Postprocess output and return
        reconstructed = self.postprocess_output(normalized_recon, mean, norm).to(orig_dtype)
        return reconstructed

    def forward_1d_normalized(
        self,
        x: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param x: normalized input tensor of shape (batch_size, d_model)
        :return: tuple of (reconstruction, dense latent pre-activations, sparse latent representation)
        """
        # Subtract pre-bias and encode input
        x = x - self.b_pre
        h = self.encoder(x)
        if self.h_bias is not None:
            h = h + self.h_bias
        # Reconstruct input and latent representation with the default k sparsity
        reconstructed, h_sparse = self.decode_latent(h=h, k=self.k)
        return reconstructed, h, h_sparse

    def decode_latent(self, h: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        :param h: dense latent pre-activations of shape (batch_size, n_latents)
        :param k: number of latents to keep active per input
        :return: tuple of (reconstruction, sparse latent representation)
        """
        # Apply TopK activation, then ReLU to guarantee positive top-k values, and build the sparse representation
        topk_values, topk_indices = torch.topk(h, k=k, dim=-1)
        topk_values = torch.relu(topk_values)
        h_sparse = torch.zeros_like(h).scatter_(1, topk_indices, topk_values)
        # Decode h_sparse and add the pre-bias back
        reconstructed = self.decoder(h_sparse) + self.b_pre
        return reconstructed, h_sparse

    def set_latent_bias(self, h_bias: torch.Tensor) -> None:
        """Set an additive bias that is applied to the latent pre-activations in the forward pass."""
        assert h_bias.shape == (self.n_latents,), "h_bias must have shape (n_latents,)"
        self.h_bias = h_bias.to(self.dtype)

    def unset_latent_bias(self) -> None:
        """Remove the additive latent bias."""
        self.h_bias = None
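
# Usage sketch (illustrative; not part of the original repository). The dimensions and k
# below are made-up values, chosen only to show how decode_latent keeps at most k latents
# active per row. forward() would additionally handle normalization and the
# (batch_size, seq_len, d_model) reshaping; forward_1d_normalized is called directly here
# just to inspect the sparse code:
#
#   sae = TopKSparseAutoencoder(
#       d_model=16, n_latents=64, k=4, b_pre=torch.zeros(16), dtype=torch.float32
#   )
#   x = torch.randn(2, 16)
#   recon, h, h_sparse = sae.forward_1d_normalized(x)
#   assert recon.shape == (2, 16)
#   assert (h_sparse != 0).sum(dim=-1).le(4).all()  # at most k=4 active latents per row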


def load_sae_model(
    model_path: Path,
    sae_top_k: int,
    sae_normalization_eps: float,
    device: torch.device,
    dtype: torch.dtype,
) -> TopKSparseAutoencoder:
    """
    Load a trained TopK SAE state dict from disk and return the model on the given device in eval mode.

    :param model_path: path to the saved state dict
    :param sae_top_k: TopK sparsity level k to use at inference time
    :param sae_normalization_eps: epsilon used when normalizing inputs
    :param device: device to move the model to
    :param dtype: parameter dtype of the model
    :return: the loaded TopKSparseAutoencoder
    """
    logging.info(f"Loading TopK SAE model weights and config from: {model_path}")
    state_dict = torch.load(
        model_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    b_pre = state_dict["b_pre"]
    d_model = b_pre.shape[0]
    n_latents = state_dict["encoder.weight"].shape[0]
    logging.info("Initializing TopK SAE model and loading state dict...")
    model = TopKSparseAutoencoder(
        d_model=d_model,
        n_latents=n_latents,
        k=sae_top_k,
        b_pre=b_pre,
        dtype=dtype,
        normalize_eps=sae_normalization_eps,
    )
    model.load_state_dict(state_dict)
    del state_dict
    logging.info(f"Moving model to device {device} and setting to eval mode...")
    model.to(device)
    model.eval()
    return model
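

if __name__ == "__main__":
    # Hedged usage sketch, not part of the original module: build a small SAE with
    # illustrative (made-up) dimensions and run a forward pass to check shapes.
    torch.manual_seed(0)
    demo_sae = TopKSparseAutoencoder(
        d_model=32,
        n_latents=256,
        k=8,
        b_pre=torch.zeros(32),
        dtype=torch.float32,
    )
    demo_input = torch.randn(4, 10, 32)  # (batch_size, seq_len, d_model)
    with torch.no_grad():
        demo_recon = demo_sae(demo_input)
    print(f"input shape:  {tuple(demo_input.shape)}")
    print(f"output shape: {tuple(demo_recon.shape)}")

    # Loading a trained checkpoint would look roughly like the call below; the path and
    # hyperparameters are placeholders, not values from the original repository.
    # sae = load_sae_model(
    #     model_path=Path("path/to/sae_checkpoint.pt"),
    #     sae_top_k=32,
    #     sae_normalization_eps=1e-6,
    #     device=torch.device("cpu"),
    #     dtype=torch.float32,
    # )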