FNet

PAPER: FNet: Mixing Tokens with Fourier Transforms

FNet

  The vanilla attention mechanism connects each token in the input to every other token through a relevance-weighted basis, which is expensive in both computation and memory. Synthesizer and other related work have already questioned whether the dot-product attention sublayer is necessary. FNet likewise proposes an alternative to attention: it replaces the attention sublayer with a parameter-free Fourier Transform that mixes information across tokens. The mixing step is sketched below, followed by the code of an FNet encoder.
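
  A minimal sketch of the mixing step in isolation (the tensor shapes here are illustrative, not from the paper): a 2D FFT is applied over the sequence and hidden dimensions and only the real part is kept, so tokens are mixed with no learnable parameters.

import torch

# Parameter-free token mixing used by FNet (illustrative shapes):
# one FFT along the hidden dimension, one along the sequence dimension,
# then keep only the real part of the result.
x = torch.randn(2, 16, 64)                    # (batch, seq_len, d_model)
mixed = torch.fft.fft2(x, dim=(-2, -1)).real  # same shape as x, no weights involved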

import torch
import torch.nn as nn


class FNetBlock(nn.Module):
    def __init__(self, d_model, d_ff, dropout):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        self.norm_1 = nn.LayerNorm(d_model)
        self.norm_2 = nn.LayerNorm(d_model)

    def fourier_transform(self, x):
        # The paper applies a 2D FFT (one along the hidden dimension, one along
        # the sequence dimension) and keeps only the real part of the result.
        return torch.fft.fft2(x, dim=(-2, -1)).real

    def forward(self, x):
        # Fourier mixing sublayer with residual connection and LayerNorm
        residual = x
        x = self.fourier_transform(x)
        x = self.norm_1(x + residual)
        # Position-wise feed-forward sublayer with residual connection and LayerNorm
        residual = x
        x = self.mlp(x)
        out = self.norm_2(x + residual)

        return out


class FNet(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1, e_layer=3):
        super().__init__()
        # Stack of e_layer FNet encoder blocks
        self.encoder = nn.ModuleList([
            FNetBlock(d_model, d_ff, dropout) for _ in range(e_layer)
        ])

    def forward(self, x):
        for layer in self.encoder:
            x = layer(x)

        return x
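
  A quick usage sketch (the hyperparameters, batch size, and sequence length are arbitrary, chosen only for illustration):

import torch

# Illustrative usage; d_model, d_ff, sequence length, and batch size are arbitrary.
model = FNet(d_model=64, d_ff=256, dropout=0.1, e_layer=3)
x = torch.randn(2, 16, 64)      # (batch, seq_len, d_model)
out = model(x)
print(out.shape)                # torch.Size([2, 16, 64])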

  Notably, modified versions of FNet (e.g., replacing the FFT with a DCT or adding extra learnable parameters) degraded accuracy and reduced training stability. More careful experiments and a more precise explanation of existing token-mixing approaches are still needed. Further experiments and results for FNet can be found in the original paper.