"""Capacitor encoder: patchify -> FCDMBlocks -> diagonal Gaussian posterior. No input RMSNorm (use_other_outer_rms_norms=False during training). Post-bottleneck RMSNorm (affine=False) on the mean branch. Encoder outputs posterior mode by default: alpha * RMSNorm(mean). """ from __future__ import annotations from dataclasses import dataclass import torch from torch import Tensor, nn from .fcdm_block import FCDMBlock from .norms import ChannelWiseRMSNorm from .straight_through_encoder import Patchify @dataclass(frozen=True) class EncoderPosterior: """VP-parameterized diagonal Gaussian posterior. mean: Clean signal branch mu [B, bottleneck_dim, h, w] logsnr: Per-element log signal-to-noise ratio [B, bottleneck_dim, h, w] """ mean: Tensor logsnr: Tensor @property def alpha(self) -> Tensor: """VP signal coefficient: sqrt(sigmoid(logsnr)).""" return torch.sigmoid(self.logsnr).sqrt() @property def sigma(self) -> Tensor: """VP noise coefficient: sqrt(sigmoid(-logsnr)).""" return torch.sigmoid(-self.logsnr).sqrt() def mode(self) -> Tensor: """Posterior mode in token space: alpha * mean.""" return self.alpha.to(dtype=self.mean.dtype) * self.mean def sample(self, *, generator: torch.Generator | None = None) -> Tensor: """Sample from posterior: alpha * mean + sigma * eps.""" eps = torch.randn_like(self.mean, generator=generator) # type: ignore[call-overload] alpha = self.alpha.to(dtype=self.mean.dtype) sigma = self.sigma.to(dtype=self.mean.dtype) return alpha * self.mean + sigma * eps class Encoder(nn.Module): """Encoder: Image [B,3,H,W] -> latents [B,bottleneck_dim,h,w]. With diagonal_gaussian posterior, the to_bottleneck projection outputs 2 * bottleneck_dim channels, split into mean and logsnr. The default encode() returns the posterior mode: alpha * RMSNorm(mean). """ def __init__( self, in_channels: int, patch_size: int, model_dim: int, depth: int, bottleneck_dim: int, mlp_ratio: float, depthwise_kernel_size: int, bottleneck_posterior_kind: str = "diagonal_gaussian", bottleneck_norm_mode: str = "disabled", ) -> None: super().__init__() self.bottleneck_dim = int(bottleneck_dim) self.bottleneck_posterior_kind = bottleneck_posterior_kind self.bottleneck_norm_mode = bottleneck_norm_mode self.patchify = Patchify(in_channels, patch_size, model_dim) self.blocks = nn.ModuleList( [ FCDMBlock( model_dim, mlp_ratio, depthwise_kernel_size=depthwise_kernel_size, use_external_adaln=False, ) for _ in range(depth) ] ) out_dim = ( 2 * bottleneck_dim if bottleneck_posterior_kind == "diagonal_gaussian" else bottleneck_dim ) self.to_bottleneck = nn.Conv2d(model_dim, out_dim, kernel_size=1, bias=True) if bottleneck_norm_mode == "channel_wise": self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False) else: self.norm_out = nn.Identity() def encode_posterior(self, images: Tensor) -> EncoderPosterior: """Encode images and return the full posterior (mean + logsnr). Only valid when bottleneck_posterior_kind == "diagonal_gaussian". """ z = self.patchify(images) for block in self.blocks: z = block(z) projection = self.to_bottleneck(z) mean, logsnr = projection.chunk(2, dim=1) mean = self.norm_out(mean) return EncoderPosterior(mean=mean, logsnr=logsnr) def forward(self, images: Tensor) -> Tensor: """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w]. Returns posterior mode (alpha * mean) for diagonal_gaussian, or deterministic latents otherwise. """ z = self.patchify(images) for block in self.blocks: z = block(z) projection = self.to_bottleneck(z) if self.bottleneck_posterior_kind == "diagonal_gaussian": mean, logsnr = projection.chunk(2, dim=1) mean = self.norm_out(mean) alpha = torch.sigmoid(logsnr).sqrt().to(dtype=mean.dtype) return alpha * mean z = self.norm_out(projection) return z