File size: 4,560 Bytes
b32916f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""Capacitor encoder: patchify -> FCDMBlocks -> diagonal Gaussian posterior.

No input RMSNorm (use_other_outer_rms_norms=False during training).
Post-bottleneck RMSNorm (affine=False) on the mean branch.
Encoder outputs posterior mode by default: alpha * RMSNorm(mean).
"""

from __future__ import annotations

from dataclasses import dataclass

import torch
from torch import Tensor, nn

from .fcdm_block import FCDMBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify


@dataclass(frozen=True)
class EncoderPosterior:
    """VP-parameterized diagonal Gaussian posterior.

    mean: Clean signal branch mu [B, bottleneck_dim, h, w]
    logsnr: Per-element log signal-to-noise ratio [B, bottleneck_dim, h, w]
    """

    mean: Tensor
    logsnr: Tensor

    @property
    def alpha(self) -> Tensor:
        """VP signal coefficient: sqrt(sigmoid(logsnr))."""
        return torch.sigmoid(self.logsnr).sqrt()

    @property
    def sigma(self) -> Tensor:
        """VP noise coefficient: sqrt(sigmoid(-logsnr))."""
        return torch.sigmoid(-self.logsnr).sqrt()

    def mode(self) -> Tensor:
        """Posterior mode in token space: alpha * mean."""
        return self.alpha.to(dtype=self.mean.dtype) * self.mean

    def sample(self, *, generator: torch.Generator | None = None) -> Tensor:
        """Sample from posterior: alpha * mean + sigma * eps.

        Args:
            generator: Optional RNG for reproducible draws. ``None`` uses
                the global default generator.

        Returns:
            alpha * mean + sigma * eps with eps ~ N(0, I), matching
            ``mean``'s shape, dtype, and device.
        """
        # BUG FIX: torch.randn_like() does not accept a `generator` keyword
        # (the old `# type: ignore[call-overload]` papered over a runtime
        # TypeError). Draw noise via empty_like + in-place normal_, which
        # does honor the generator and keeps shape/dtype/device.
        eps = torch.empty_like(self.mean).normal_(generator=generator)
        alpha = self.alpha.to(dtype=self.mean.dtype)
        sigma = self.sigma.to(dtype=self.mean.dtype)
        return alpha * self.mean + sigma * eps


class Encoder(nn.Module):
    """Encoder: Image [B,3,H,W] -> latents [B,bottleneck_dim,h,w].

    With diagonal_gaussian posterior, the to_bottleneck projection outputs
    2 * bottleneck_dim channels, split into mean and logsnr. The default
    encode() returns the posterior mode: alpha * RMSNorm(mean).
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
        bottleneck_posterior_kind: str = "diagonal_gaussian",
        bottleneck_norm_mode: str = "disabled",
    ) -> None:
        """Build the patchify stem, FCDM trunk, and bottleneck projection.

        Args:
            in_channels: Input image channels (3 for RGB).
            patch_size: Spatial patch size for the patchify stem.
            model_dim: Channel width of the FCDM trunk.
            depth: Number of FCDMBlocks in the trunk.
            bottleneck_dim: Latent channel count.
            mlp_ratio: MLP expansion ratio inside each FCDMBlock.
            depthwise_kernel_size: Depthwise conv kernel size per block.
            bottleneck_posterior_kind: "diagonal_gaussian" doubles the
                projection width (mean + logsnr); any other value yields a
                deterministic bottleneck of bottleneck_dim channels.
            bottleneck_norm_mode: "channel_wise" applies a non-affine
                RMSNorm to the mean/latent branch; any other value disables
                normalization (identity).
        """
        super().__init__()
        self.bottleneck_dim = int(bottleneck_dim)
        self.bottleneck_posterior_kind = bottleneck_posterior_kind
        self.bottleneck_norm_mode = bottleneck_norm_mode
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        self.blocks = nn.ModuleList(
            [
                FCDMBlock(
                    model_dim,
                    mlp_ratio,
                    depthwise_kernel_size=depthwise_kernel_size,
                    use_external_adaln=False,
                )
                for _ in range(depth)
            ]
        )
        # Gaussian posterior needs two heads (mean, logsnr) in one projection.
        out_dim = (
            2 * bottleneck_dim
            if bottleneck_posterior_kind == "diagonal_gaussian"
            else bottleneck_dim
        )
        self.to_bottleneck = nn.Conv2d(model_dim, out_dim, kernel_size=1, bias=True)
        if bottleneck_norm_mode == "channel_wise":
            self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False)
        else:
            self.norm_out = nn.Identity()

    def _trunk(self, images: Tensor) -> Tensor:
        """Run patchify -> FCDM blocks -> 1x1 bottleneck projection.

        Shared by encode_posterior() and forward() so the two paths cannot
        drift apart.
        """
        z = self.patchify(images)
        for block in self.blocks:
            z = block(z)
        return self.to_bottleneck(z)

    def encode_posterior(self, images: Tensor) -> EncoderPosterior:
        """Encode images and return the full posterior (mean + logsnr).

        Only valid when bottleneck_posterior_kind == "diagonal_gaussian".

        Raises:
            ValueError: If the encoder was built with a non-Gaussian
                posterior kind (the projection then has no logsnr half).
        """
        if self.bottleneck_posterior_kind != "diagonal_gaussian":
            raise ValueError(
                "encode_posterior() requires bottleneck_posterior_kind="
                f"'diagonal_gaussian', got {self.bottleneck_posterior_kind!r}"
            )
        mean, logsnr = self._trunk(images).chunk(2, dim=1)
        mean = self.norm_out(mean)
        return EncoderPosterior(mean=mean, logsnr=logsnr)

    def forward(self, images: Tensor) -> Tensor:
        """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w].

        Returns posterior mode (alpha * mean) for diagonal_gaussian,
        or deterministic latents otherwise.
        """
        if self.bottleneck_posterior_kind == "diagonal_gaussian":
            # mode() == alpha * RMSNorm(mean), alpha cast to mean's dtype.
            return self.encode_posterior(images).mode()
        return self.norm_out(self._trunk(images))