data-archetype commited on
Commit
518072c
·
verified ·
1 Parent(s): 1e745d3

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. capacitor_diffae/model.py +51 -0
capacitor_diffae/model.py CHANGED
@@ -65,10 +65,22 @@ class CapacitorDiffAE(nn.Module):
65
  recon = model.reconstruct(images)
66
  """
67
 
 
 
68
  def __init__(self, config: CapacitorDiffAEConfig) -> None:
69
  super().__init__()
70
  self.config = config
71
 
 
 
 
 
 
 
 
 
 
 
72
  self.encoder = Encoder(
73
  in_channels=config.in_channels,
74
  patch_size=config.patch_size,
@@ -154,6 +166,45 @@ class CapacitorDiffAE(nn.Module):
154
  model.eval()
155
  return model
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  def encode(self, images: Tensor) -> Tensor:
158
  """Encode images to latents (posterior mode).
159
 
 
65
  recon = model.reconstruct(images)
66
  """
67
 
68
+ _LATENT_NORM_EPS: float = 1e-4
69
+
70
  def __init__(self, config: CapacitorDiffAEConfig) -> None:
71
  super().__init__()
72
  self.config = config
73
 
74
+ # Latent running stats for whitening/dewhitening
75
+ self.register_buffer(
76
+ "latent_norm_running_mean",
77
+ torch.zeros((config.bottleneck_dim,), dtype=torch.float32),
78
+ )
79
+ self.register_buffer(
80
+ "latent_norm_running_var",
81
+ torch.ones((config.bottleneck_dim,), dtype=torch.float32),
82
+ )
83
+
84
  self.encoder = Encoder(
85
  in_channels=config.in_channels,
86
  patch_size=config.patch_size,
 
166
  model.eval()
167
  return model
168
 
169
+ def _latent_norm_stats(self) -> tuple[Tensor, Tensor]:
170
+ """Return (mean, std) tensors for latent whitening, shaped [1,C,1,1]."""
171
+ mean = self.latent_norm_running_mean.view(1, -1, 1, 1)
172
+ var = self.latent_norm_running_var.view(1, -1, 1, 1)
173
+ std = torch.sqrt(var.to(torch.float32) + self._LATENT_NORM_EPS)
174
+ return mean.to(torch.float32), std
175
+
176
+ def whiten(self, latents: Tensor) -> Tensor:
177
+ """Whiten encoder latents using per-channel running stats.
178
+
179
+ Use this before passing latents to a downstream latent-space
180
+ diffusion model. The whitened latents have approximately zero mean
181
+ and unit variance per channel.
182
+
183
+ Args:
184
+ latents: [B, bottleneck_dim, h, w] raw encoder output.
185
+
186
+ Returns:
187
+ Whitened latents [B, bottleneck_dim, h, w] in float32.
188
+ """
189
+ z = latents.to(torch.float32)
190
+ mean, std = self._latent_norm_stats()
191
+ return (z - mean.to(device=z.device)) / std.to(device=z.device)
192
+
193
+ def dewhiten(self, latents: Tensor) -> Tensor:
194
+ """Undo whitening to recover raw encoder latent scale.
195
+
196
+ Use this before passing whitened latents back to ``decode()``.
197
+
198
+ Args:
199
+ latents: [B, bottleneck_dim, h, w] whitened latents.
200
+
201
+ Returns:
202
+ Dewhitened latents [B, bottleneck_dim, h, w] in float32.
203
+ """
204
+ z = latents.to(torch.float32)
205
+ mean, std = self._latent_norm_stats()
206
+ return z * std.to(device=z.device) + mean.to(device=z.device)
207
+
208
  def encode(self, images: Tensor) -> Tensor:
209
  """Encode images to latents (posterior mode).
210