MLP-Mixer

PAPER: MLP-Mixer: An all-MLP Architecture for Vision

Motivations

  Modern deep vision architectures consist of layers that mix features (i) at a given spatial location, (ii) between different spatial locations, or both at once. In CNNs, (ii) is implemented with $N\times N$ convolutions (for $N>1$) and pooling, and neurons in deeper layers have a larger receptive field due to downsampling. In particular, $1\times 1$ convolutions perform only (i), while larger kernels perform both (i) and (ii). In Vision Transformers and other attention-based architectures, the self-attention layers perform both (i) and (ii), and the MLP (FFN) blocks perform (i).
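
  As an aside on point (i), a $1\times 1$ convolution mixes only channels: it is equivalent to applying the same linear layer at every spatial location. The short sketch below checks this equivalence in PyTorch (the tensor sizes and variable names are illustrative, not from the paper).

import torch
import torch.nn as nn

B, C_in, C_out, H, W = 2, 8, 16, 4, 4
x = torch.randn(B, C_in, H, W)

conv1x1 = nn.Conv2d(C_in, C_out, kernel_size=1)

# Reuse the 1x1 conv weights as a per-location linear layer.
linear = nn.Linear(C_in, C_out)
linear.weight.data.copy_(conv1x1.weight.data.view(C_out, C_in))
linear.bias.data.copy_(conv1x1.bias.data)

out_conv = conv1x1(x)                                        # [B, C_out, H, W]
out_lin = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # [B, C_out, H, W]
print(torch.allclose(out_conv, out_lin, atol=1e-6))          # expected: True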

  We can summarize the two types of feature mixing above as per-location operations (channel-mixing) and cross-location operations (token-mixing). The idea behind the MLP-Mixer architecture is to clearly separate these two operations and implement each with its own MLP, one acting across channels and one acting across tokens, as sketched below.
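
  Concretely, for an input of shape [batch, #tokens, channels], channel-mixing applies an MLP along the channel dimension of each token, while token-mixing transposes the tensor and applies an MLP along the token dimension of each channel. A minimal shape-level sketch (the sizes here are illustrative):

import torch
import torch.nn as nn

B, S, C = 2, 196, 512                  # batch, #tokens (patches), channels
x = torch.randn(B, S, C)

channel_mixing = nn.Linear(C, C)       # acts on each token independently
token_mixing = nn.Linear(S, S)         # acts on each channel independently

y = channel_mixing(x)                                   # [B, S, C]
z = token_mixing(x.transpose(1, 2)).transpose(1, 2)     # [B, S, C]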

MLP-Mixer

  The figure below illustrates the architecture of MLP-Mixer. MLP-Mixer contains two types of layers: one with MLPs applied independently to image patches (i.e. “mixing” the per-location features), and one with MLPs applied across patches (i.e. “mixing” spatial information). The computational complexity of MLP-Mixer is linear in the number of input patches, unlike ViT, whose complexity is quadratic. Unlike ViT, Mixer also does not use position embeddings, because the token-mixing MLPs are already sensitive to the order of the input tokens. A runnable PyTorch implementation is given below.

[Figure: overview of the MLP-Mixer architecture]

import torch
import torch.nn as nn


class MLPBlock(nn.Module):
    """Two-layer MLP with a GELU nonlinearity, applied along the last dimension."""
    def __init__(self, input_dim, mlp_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, mlp_dim)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(mlp_dim, input_dim)

    def forward(self, x):
        # x: [B, #tokens, D] (channel-mixing) or [B, D, #tokens] (token-mixing)
        return self.fc2(self.gelu(self.fc1(x)))


class MixerBlock(nn.Module):
    def __init__(self, tokens, channels, tokens_hidden_dim, channels_hidden_dim):
        super().__init__()
        self.token_mixing = MLPBlock(tokens, tokens_hidden_dim)
        self.channel_mixing = MLPBlock(channels, channels_hidden_dim)
        # Separate LayerNorms for the two sub-blocks, as in the paper.
        self.norm1 = nn.LayerNorm(channels)
        self.norm2 = nn.LayerNorm(channels)

    def forward(self, x):
        # token-mixing: transpose to [B, D, #tokens] and mix across tokens
        y = self.norm1(x).transpose(1, 2)
        y = self.token_mixing(y)
        x = y.transpose(1, 2) + x

        # channel-mixing: stay in [B, #tokens, D] and mix across channels
        y = self.norm2(x)
        return x + self.channel_mixing(y)


class MLPMixer(nn.Module):
    def __init__(self, num_classes, num_blocks, patch_size, tokens_hidden_dim, channels_hidden_dim, tokens, channels):
        super().__init__()
        # Per-patch linear embedding, implemented as a strided convolution.
        self.embed = nn.Conv2d(3, channels, kernel_size=patch_size, stride=patch_size)
        self.mlp_blocks = nn.ModuleList([
            MixerBlock(tokens, channels, tokens_hidden_dim, channels_hidden_dim) for _ in range(num_blocks)
        ])
        self.norm = nn.LayerNorm(channels)
        self.fc = nn.Linear(channels, num_classes)

    def forward(self, x):
        # [B, 3, H, W] -> [B, D, H/P, W/P] -> [B, #tokens, D]
        y = self.embed(x)
        B, D, _, _ = y.shape
        y = y.view(B, D, -1).transpose(1, 2)

        for block in self.mlp_blocks:
            y = block(y)

        y = self.norm(y)
        y = y.mean(dim=1)        # global average pooling over tokens -> [B, D]
        logits = self.fc(y)      # [B, num_classes]

        return logits
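
  As a quick sanity check, the model above can be instantiated and run on a dummy batch. The hyperparameters below are illustrative choices (roughly the scale of the paper's small model): a $224\times 224$ input with $16\times 16$ patches gives $(224/16)^2 = 196$ tokens.

import torch

model = MLPMixer(
    num_classes=1000,
    num_blocks=8,
    patch_size=16,
    tokens_hidden_dim=256,
    channels_hidden_dim=2048,
    tokens=196,       # (224 / 16) ** 2 patches
    channels=512,     # per-patch embedding dimension D
)

x = torch.randn(2, 3, 224, 224)   # dummy batch of RGB images
logits = model(x)
print(logits.shape)               # torch.Size([2, 1000])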

  More experiments and results can be found in the original paper. As an alternative to attention-based architectures, MLP-Mixer has a simpler structure while remaining competitive. Attention may not be all you need.