# --- Hugging Face Hub page residue (not Python; preserved as comments) ---
# data-archetype's picture
# Upload folder using huggingface_hub
# b32916f verified
# raw | history blame | 4.56 kB
"""Capacitor encoder: patchify -> FCDMBlocks -> diagonal Gaussian posterior.
No input RMSNorm (use_other_outer_rms_norms=False during training).
Post-bottleneck RMSNorm (affine=False) on the mean branch.
Encoder outputs posterior mode by default: alpha * RMSNorm(mean).
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from .fcdm_block import FCDMBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify
@dataclass(frozen=True)
class EncoderPosterior:
    """VP-parameterized diagonal Gaussian posterior.

    Attributes:
        mean: Clean signal branch mu, shape [B, bottleneck_dim, h, w].
        logsnr: Per-element log signal-to-noise ratio, same shape as `mean`.
    """

    mean: Tensor
    logsnr: Tensor

    @property
    def alpha(self) -> Tensor:
        """VP signal coefficient: sqrt(sigmoid(logsnr))."""
        return torch.sigmoid(self.logsnr).sqrt()

    @property
    def sigma(self) -> Tensor:
        """VP noise coefficient: sqrt(sigmoid(-logsnr))."""
        return torch.sigmoid(-self.logsnr).sqrt()

    def mode(self) -> Tensor:
        """Posterior mode in token space: alpha * mean."""
        return self.alpha.to(dtype=self.mean.dtype) * self.mean

    def sample(self, *, generator: torch.Generator | None = None) -> Tensor:
        """Sample from the posterior: alpha * mean + sigma * eps.

        Args:
            generator: Optional RNG for reproducible sampling.

        Returns:
            A tensor with the same shape/dtype/device as `mean`.
        """
        # NOTE: torch.randn_like() does not accept a `generator` kwarg, so we
        # use torch.randn() with explicit shape/dtype/device/layout instead.
        eps = torch.randn(
            self.mean.shape,
            dtype=self.mean.dtype,
            layout=self.mean.layout,
            device=self.mean.device,
            generator=generator,
        )
        alpha = self.alpha.to(dtype=self.mean.dtype)
        sigma = self.sigma.to(dtype=self.mean.dtype)
        return alpha * self.mean + sigma * eps
class Encoder(nn.Module):
    """Encoder: Image [B,3,H,W] -> latents [B,bottleneck_dim,h,w].

    With a diagonal_gaussian posterior, the to_bottleneck projection outputs
    2 * bottleneck_dim channels, split into mean and logsnr. The default
    encode() returns the posterior mode: alpha * RMSNorm(mean).
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        bottleneck_posterior_kind: str = "diagonal_gaussian",
        bottleneck_norm_mode: str = "disabled",
    ) -> None:
        """Build the encoder trunk and bottleneck projection.

        Args:
            in_channels: Input image channels (typically 3).
            patch_size: Spatial patch size used by Patchify.
            model_dim: Trunk channel width.
            depth: Number of FCDMBlocks in the trunk.
            bottleneck_dim: Latent channel width.
            mlp_ratio: MLP expansion ratio inside each FCDMBlock.
            depthwise_kernel_size: Depthwise conv kernel size per block.
            bottleneck_posterior_kind: "diagonal_gaussian" for a VP posterior,
                anything else for deterministic latents.
            bottleneck_norm_mode: "channel_wise" applies a non-affine RMSNorm
                to the mean branch; any other value disables normalization.
        """
        super().__init__()
        self.bottleneck_dim = int(bottleneck_dim)
        self.bottleneck_posterior_kind = bottleneck_posterior_kind
        self.bottleneck_norm_mode = bottleneck_norm_mode
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        self.blocks = nn.ModuleList(
            [
                FCDMBlock(
                    model_dim,
                    mlp_ratio,
                    depthwise_kernel_size=depthwise_kernel_size,
                    use_external_adaln=False,
                )
                for _ in range(depth)
            ]
        )
        # Gaussian posterior needs two channel groups: (mean, logsnr).
        out_dim = (
            2 * bottleneck_dim
            if bottleneck_posterior_kind == "diagonal_gaussian"
            else bottleneck_dim
        )
        self.to_bottleneck = nn.Conv2d(model_dim, out_dim, kernel_size=1, bias=True)
        if bottleneck_norm_mode == "channel_wise":
            # affine=False: pure normalization of the mean branch, no learned scale.
            self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False)
        else:
            self.norm_out = nn.Identity()

    def _project(self, images: Tensor) -> Tensor:
        """Shared trunk: patchify, run the FCDM blocks, project to bottleneck channels."""
        z = self.patchify(images)
        for block in self.blocks:
            z = block(z)
        return self.to_bottleneck(z)

    def encode_posterior(self, images: Tensor) -> EncoderPosterior:
        """Encode images and return the full posterior (mean + logsnr).

        Raises:
            ValueError: If bottleneck_posterior_kind != "diagonal_gaussian";
                without the doubled projection the chunk below would silently
                split the deterministic latents in half.
        """
        if self.bottleneck_posterior_kind != "diagonal_gaussian":
            raise ValueError(
                "encode_posterior() requires bottleneck_posterior_kind="
                f"'diagonal_gaussian', got {self.bottleneck_posterior_kind!r}"
            )
        mean, logsnr = self._project(images).chunk(2, dim=1)
        return EncoderPosterior(mean=self.norm_out(mean), logsnr=logsnr)

    def forward(self, images: Tensor) -> Tensor:
        """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w].

        Returns the posterior mode (alpha * mean) for diagonal_gaussian,
        or deterministic (optionally normalized) latents otherwise.
        """
        if self.bottleneck_posterior_kind == "diagonal_gaussian":
            # mode() == sigmoid(logsnr).sqrt() * mean, matching the trained objective.
            return self.encode_posterior(images).mode()
        return self.norm_out(self._project(images))