RevIN

PAPER: Reversible Instance Normalization for Accurate Time-Series Forecasting against Distribution Shift

Motivations

  We typically describe the distribution of a time series with its moments, such as the mean and variance. In practice, time-series data often suffers from severe distribution shift: these statistics change over time, so the data a model is trained on and the data it must forecast can follow noticeably different distributions. Some existing methods disentangle the series into trend and seasonality components to alleviate this problem, but distribution shift still hurts forecasting performance.

[Figure: illustration of distribution shift in time-series data]

RevIN

  To eliminate the impact of these changing moments, RevIN removes the non-stationary information from the input sequence, specifically the mean and standard deviation of each instance, and restores it on the model output. The figure and code below demonstrate the mechanism of RevIN.

[Figure: overview of the RevIN mechanism (normalization at the input, denormalization at the output)]
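
  In symbols (a sketch of the two steps implemented by the code below; here $x^{(i)}_{kt}$ denotes channel $k$ at time step $t$ of instance $i$, $\tilde{y}$ the backbone's prediction, and $\gamma_k$, $\beta_k$ the learnable affine parameters):

$$
\hat{x}^{(i)}_{kt} \;=\; \gamma_k \,\frac{x^{(i)}_{kt} - \mathbb{E}_t\!\left[x^{(i)}_{kt}\right]}{\sqrt{\mathrm{Var}_t\!\left[x^{(i)}_{kt}\right] + \epsilon}} \;+\; \beta_k,
\qquad
\hat{y}^{(i)}_{kt} \;=\; \sqrt{\mathrm{Var}_t\!\left[x^{(i)}_{kt}\right] + \epsilon}\,\cdot\frac{\tilde{y}^{(i)}_{kt} - \beta_k}{\gamma_k} \;+\; \mathbb{E}_t\!\left[x^{(i)}_{kt}\right].
$$

  The mean and variance are computed over the time dimension of each instance and reused at denormalization, which is what makes the operation reversible.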

```python
import torch
import torch.nn as nn


class RevIN(nn.Module):
    def __init__(self, num_features: int, eps=1e-5, affine=True):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        if self.affine:
            self._init_params()

    def forward(self, x, mode:str):
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else: raise NotImplementedError
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        dim2reduce = tuple(range(1, x.ndim-1))
        self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

    def _normalize(self, x):
        x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.affine:
            x = x - self.affine_bias
            x = x / (self.affine_weight + self.eps*self.eps)
        x = x * self.stdev
        x = x + self.mean
        return x
```

  RevIN can be attached to any existing time-series forecasting model: the input is normalized before it enters the model, and the prediction is denormalized with the same statistics afterwards. Because RevIN removes the per-instance mean and standard deviation channel by channel, the input is normalized while the relative alignment of the different channels is preserved, as the figure below illustrates. Meanwhile, the learnable affine parameters help the model capture the discrepancies (largely caused by trend) between different segments of the series, which mitigates the distribution-shift problem. A usage sketch follows, and the table below shows the performance of RevIN applied to existing time-series forecasting models.
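
  As a usage sketch (not from the paper's code), the snippet below wraps the RevIN module defined above around a placeholder backbone, a plain per-channel linear layer standing in for an arbitrary forecaster; the input shape (batch, seq_len, num_channels) and the names RevINForecaster, seq_len, and pred_len are assumptions for this example:

```python
import torch
import torch.nn as nn


class RevINForecaster(nn.Module):
    def __init__(self, num_channels: int, seq_len: int, pred_len: int):
        super().__init__()
        self.revin = RevIN(num_channels)
        # Placeholder backbone: maps the input window to the prediction
        # window independently per channel (stands in for any forecaster).
        self.backbone = nn.Linear(seq_len, pred_len)

    def forward(self, x):
        # x: (batch, seq_len, num_channels)
        x = self.revin(x, mode='norm')    # remove per-instance mean/std
        y = self.backbone(x.transpose(1, 2)).transpose(1, 2)  # forecast in normalized space
        y = self.revin(y, mode='denorm')  # restore the removed statistics
        return y                          # (batch, pred_len, num_channels)


model = RevINForecaster(num_channels=7, seq_len=96, pred_len=24)
out = model(torch.randn(32, 96, 7))
print(out.shape)  # torch.Size([32, 24, 7])
```

  The statistics saved by the `norm` call are reused by the `denorm` call, so the prediction comes back on the original scale of each instance.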

[Figure: effect of RevIN normalization on the input channels]

[Table: forecasting performance of existing models with and without RevIN]