"""Capacitor encoder: patchify -> FCDMBlocks -> diagonal Gaussian posterior.
No input RMSNorm (use_other_outer_rms_norms=False during training).
Post-bottleneck RMSNorm (affine=False) on the mean branch.
Encoder outputs posterior mode by default: alpha * RMSNorm(mean).
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from .fcdm_block import FCDMBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify
@dataclass(frozen=True)
class EncoderPosterior:
    """VP-parameterized diagonal Gaussian posterior.

    Attributes:
        mean: Clean signal branch mu [B, bottleneck_dim, h, w].
        logsnr: Per-element log signal-to-noise ratio [B, bottleneck_dim, h, w].
    """

    mean: Tensor
    logsnr: Tensor

    @property
    def alpha(self) -> Tensor:
        """VP signal coefficient: sqrt(sigmoid(logsnr))."""
        return torch.sigmoid(self.logsnr).sqrt()

    @property
    def sigma(self) -> Tensor:
        """VP noise coefficient: sqrt(sigmoid(-logsnr)); alpha^2 + sigma^2 == 1."""
        return torch.sigmoid(-self.logsnr).sqrt()

    def mode(self) -> Tensor:
        """Posterior mode in token space: alpha * mean."""
        return self.alpha.to(dtype=self.mean.dtype) * self.mean

    def sample(self, *, generator: torch.Generator | None = None) -> Tensor:
        """Reparameterized sample from the posterior: alpha * mean + sigma * eps.

        Args:
            generator: Optional RNG for reproducible noise draws.

        Returns:
            Tensor with the same shape and dtype as ``mean``.
        """
        # torch.randn_like() does not accept a `generator` keyword, so the
        # original call raised TypeError whenever a generator was passed.
        # Draw explicitly with torch.randn on the same shape/dtype/device.
        eps = torch.randn(
            self.mean.shape,
            dtype=self.mean.dtype,
            device=self.mean.device,
            generator=generator,
        )
        alpha = self.alpha.to(dtype=self.mean.dtype)
        sigma = self.sigma.to(dtype=self.mean.dtype)
        return alpha * self.mean + sigma * eps
class Encoder(nn.Module):
    """Encoder: Image [B,3,H,W] -> latents [B,bottleneck_dim,h,w].

    With the "diagonal_gaussian" posterior kind, the to_bottleneck projection
    outputs 2 * bottleneck_dim channels, split into mean and logsnr. The
    default forward() returns the posterior mode: alpha * RMSNorm(mean).
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        bottleneck_posterior_kind: str = "diagonal_gaussian",
        bottleneck_norm_mode: str = "disabled",
    ) -> None:
        """Build the patchify stem, FCDM trunk, and bottleneck head.

        Args:
            in_channels: Image channels (typically 3).
            patch_size: Spatial patch size for the patchify stem.
            model_dim: Trunk channel width.
            depth: Number of FCDMBlocks.
            bottleneck_dim: Latent channel width.
            mlp_ratio: MLP expansion ratio inside each FCDMBlock.
            depthwise_kernel_size: Depthwise conv kernel size in each block.
            bottleneck_posterior_kind: "diagonal_gaussian" (mean + logsnr
                heads) or anything else for a deterministic bottleneck.
            bottleneck_norm_mode: "channel_wise" enables a non-affine RMSNorm
                on the mean branch; any other value disables it.
        """
        super().__init__()
        self.bottleneck_dim = int(bottleneck_dim)
        self.bottleneck_posterior_kind = bottleneck_posterior_kind
        self.bottleneck_norm_mode = bottleneck_norm_mode
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        self.blocks = nn.ModuleList(
            [
                FCDMBlock(
                    model_dim,
                    mlp_ratio,
                    depthwise_kernel_size=depthwise_kernel_size,
                    use_external_adaln=False,
                )
                for _ in range(depth)
            ]
        )
        # Gaussian posterior needs two heads (mean, logsnr); deterministic needs one.
        out_dim = (
            2 * bottleneck_dim
            if bottleneck_posterior_kind == "diagonal_gaussian"
            else bottleneck_dim
        )
        self.to_bottleneck = nn.Conv2d(model_dim, out_dim, kernel_size=1, bias=True)
        if bottleneck_norm_mode == "channel_wise":
            # affine=False: pure normalization, no learned scale on the mean branch.
            self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False)
        else:
            self.norm_out = nn.Identity()

    def _project(self, images: Tensor) -> Tensor:
        """Shared backbone: patchify -> FCDM blocks -> 1x1 bottleneck projection.

        Factored out so encode_posterior() and forward() cannot drift apart
        (the original duplicated this pipeline in both methods).
        """
        z = self.patchify(images)
        for block in self.blocks:
            z = block(z)
        return self.to_bottleneck(z)

    def encode_posterior(self, images: Tensor) -> EncoderPosterior:
        """Encode images and return the full posterior (mean + logsnr).

        Only valid when bottleneck_posterior_kind == "diagonal_gaussian"
        (otherwise the projection has too few channels to chunk).
        """
        projection = self._project(images)
        mean, logsnr = projection.chunk(2, dim=1)
        # RMSNorm applies to the mean branch only; logsnr passes through raw.
        return EncoderPosterior(mean=self.norm_out(mean), logsnr=logsnr)

    def forward(self, images: Tensor) -> Tensor:
        """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w].

        Returns the posterior mode (alpha * mean) for diagonal_gaussian,
        or deterministic latents otherwise.
        """
        if self.bottleneck_posterior_kind == "diagonal_gaussian":
            # Reuse the posterior path so the mode computation stays in one place.
            return self.encode_posterior(images).mode()
        return self.norm_out(self._project(images))