Upload folder using huggingface_hub
Browse files- capacitor_diffae/model.py +51 -0
capacitor_diffae/model.py
CHANGED
|
@@ -65,10 +65,22 @@ class CapacitorDiffAE(nn.Module):
|
|
| 65 |
recon = model.reconstruct(images)
|
| 66 |
"""
|
| 67 |
|
|
|
|
|
|
|
| 68 |
def __init__(self, config: CapacitorDiffAEConfig) -> None:
|
| 69 |
super().__init__()
|
| 70 |
self.config = config
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
self.encoder = Encoder(
|
| 73 |
in_channels=config.in_channels,
|
| 74 |
patch_size=config.patch_size,
|
|
@@ -154,6 +166,45 @@ class CapacitorDiffAE(nn.Module):
|
|
| 154 |
model.eval()
|
| 155 |
return model
|
| 156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
def encode(self, images: Tensor) -> Tensor:
|
| 158 |
"""Encode images to latents (posterior mode).
|
| 159 |
|
|
|
|
| 65 |
recon = model.reconstruct(images)
|
| 66 |
"""
|
| 67 |
|
| 68 |
+
_LATENT_NORM_EPS: float = 1e-4
|
| 69 |
+
|
| 70 |
def __init__(self, config: CapacitorDiffAEConfig) -> None:
|
| 71 |
super().__init__()
|
| 72 |
self.config = config
|
| 73 |
|
| 74 |
+
# Latent running stats for whitening/dewhitening
|
| 75 |
+
self.register_buffer(
|
| 76 |
+
"latent_norm_running_mean",
|
| 77 |
+
torch.zeros((config.bottleneck_dim,), dtype=torch.float32),
|
| 78 |
+
)
|
| 79 |
+
self.register_buffer(
|
| 80 |
+
"latent_norm_running_var",
|
| 81 |
+
torch.ones((config.bottleneck_dim,), dtype=torch.float32),
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
self.encoder = Encoder(
|
| 85 |
in_channels=config.in_channels,
|
| 86 |
patch_size=config.patch_size,
|
|
|
|
| 166 |
model.eval()
|
| 167 |
return model
|
| 168 |
|
| 169 |
+
def _latent_norm_stats(self) -> tuple[Tensor, Tensor]:
    """Return (mean, std) tensors for latent whitening, shaped [1,C,1,1].

    Both values are derived from the ``latent_norm_running_mean`` and
    ``latent_norm_running_var`` buffers and are returned in float32.

    Returns:
        A ``(mean, std)`` pair. ``std`` is ``sqrt(var + _LATENT_NORM_EPS)``.
    """
    # Reshape the per-channel vectors so they broadcast over [B, C, h, w].
    mean = self.latent_norm_running_mean.view(1, -1, 1, 1)
    var = self.latent_norm_running_var.view(1, -1, 1, 1)
    # The epsilon keeps std strictly positive when a channel's variance is 0,
    # so the division in whiten() cannot divide by zero.
    std = torch.sqrt(var.to(torch.float32) + self._LATENT_NORM_EPS)
    return mean.to(torch.float32), std
|
| 175 |
+
|
| 176 |
+
def whiten(self, latents: Tensor) -> Tensor:
    """Whiten encoder latents using per-channel running stats.

    Use this before passing latents to a downstream latent-space
    diffusion model. The whitened latents have approximately zero mean
    and unit variance per channel.

    Args:
        latents: [B, bottleneck_dim, h, w] raw encoder output.

    Returns:
        Whitened latents [B, bottleneck_dim, h, w] in float32.
    """
    # Always compute in float32, whatever dtype the caller passed in.
    z = latents.to(torch.float32)
    mean, std = self._latent_norm_stats()
    # Move the stats to the latents' device so the arithmetic does not fail
    # when the module and the input live on different devices.
    return (z - mean.to(device=z.device)) / std.to(device=z.device)
|
| 192 |
+
|
| 193 |
+
def dewhiten(self, latents: Tensor) -> Tensor:
    """Undo whitening to recover raw encoder latent scale.

    Use this before passing whitened latents back to ``decode()``.

    Args:
        latents: [B, bottleneck_dim, h, w] whitened latents.

    Returns:
        Dewhitened latents [B, bottleneck_dim, h, w] in float32.
    """
    # Always compute in float32, whatever dtype the caller passed in.
    z = latents.to(torch.float32)
    mean, std = self._latent_norm_stats()
    # Exact inverse of whiten(): rescale by std, then shift by mean, with
    # the stats moved to the latents' device.
    return z * std.to(device=z.device) + mean.to(device=z.device)
|
| 207 |
+
|
| 208 |
def encode(self, images: Tensor) -> Tensor:
|
| 209 |
"""Encode images to latents (posterior mode).
|
| 210 |
|