Sinkhorn Distance

PAPER: Sinkhorn Distances: Lightspeed Computation of Optimal Transport

Optimal Transport

  Suppose $(X,r)$ and $(Y,c)$ are metric spaces equipped with probability measures that have the same total mass, $\int_Xr\,\text{d}x=\int_Yc\,\text{d}y$. In general $r$ and $c$ represent marginal probability distributions, so their values sum to one. Let $U(r,c)$ be the set of nonnegative transport matrices between $r\in\mathbb{R}^n$ and $c\in\mathbb{R}^m$:

\[U(r,c)=\{P\in\mathbb{R}^{n\times m}_+\mid P\mathbf{1}_m=r,P^\top\mathbf{1}_n=c\}\]

where the rows of $P$ sum to $r$ and the columns sum to $c$. $U(r,c)$ contains all joint probability tables of $(X,Y)$ with these marginals, where $p(X=i,Y=j)=p_{ij}$. We can therefore define the entropy $h$ and the KL divergence between $P,Q\in U(r,c)$ as

\[h(r)=-\sum_{i}r_i\log r_i,\quad h(P)=-\sum_{ij}p_{ij}\log p_{ij},\quad KL(P\Vert Q)=\sum_{ij}p_{ij}\log\frac{p_{ij}}{q_{ij}}\]
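
As a small numerical illustration of these definitions, the sketch below builds a $2\times 3$ table, checks that it lies in $U(r,c)$, and evaluates its entropy and its KL divergence from the independence table $rc^\top$. The helper names `entropy` and `kl` and the example values are illustrative assumptions, not taken from the paper; the convention $0\log 0=0$ is assumed.

```python
import numpy as np


def entropy(p):
    """Shannon entropy -sum p log p, with the convention 0 log 0 = 0."""
    p = np.asarray(p, dtype=float).ravel()
    nz = p > 0
    return -np.sum(p[nz] * np.log(p[nz]))


def kl(p, q):
    """KL(p || q) = sum p log(p / q), assuming q > 0 wherever p > 0."""
    p = np.asarray(p, dtype=float).ravel()
    q = np.asarray(q, dtype=float).ravel()
    nz = p > 0
    return np.sum(p[nz] * np.log(p[nz] / q[nz]))


r = np.array([0.5, 0.5])
c = np.array([0.25, 0.25, 0.5])
P = np.array([[0.25, 0.05, 0.20],
              [0.00, 0.20, 0.30]])

# Membership in U(r, c): nonnegative entries, row sums r, column sums c.
in_U = bool((P >= 0).all()
            and np.allclose(P.sum(axis=1), r)
            and np.allclose(P.sum(axis=0), c))
print(in_U, entropy(P), kl(P, np.outer(r, c)))
```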

  Given an $n\times m$ cost matrix $M$, the cost of mapping $r$ to $c$ using a transport matrix $P$ can be quantified as $\langle P,M\rangle$. The optimal transport distance between $r$ and $c$ is defined as

\[d_M(r,c)=\min_{P\in U(r,c)}\langle P,M\rangle\]

where the goal is to find the $P$ with the lowest cost. The optimum $d_M(r,c)$ is the Wasserstein distance between $r$ and $c$ whenever $M$ is itself a metric matrix. It can be computed by linear programming, but for histograms of dimension $n=m=d$ the cost is at least $O(d^3\log d)$, which becomes expensive for large $d$.
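
To make the linear program concrete, here is a minimal sketch that solves it with `scipy.optimize.linprog` on a three-point example. The function name `exact_ot` and the example data are assumptions for illustration; this is a generic LP formulation, not the specialized network solver that large problems would need.

```python
import numpy as np
from scipy.optimize import linprog


def exact_ot(M, r, c):
    """Solve min <P, M> over U(r, c) as a linear program."""
    n, m = M.shape
    # Row-sum constraints: sum_j P[i, j] = r[i] (P is flattened row-major).
    A_rows = np.kron(np.eye(n), np.ones((1, m)))
    # Column-sum constraints: sum_i P[i, j] = c[j].
    A_cols = np.kron(np.ones((1, n)), np.eye(m))
    A_eq = np.vstack([A_rows, A_cols])
    b_eq = np.concatenate([r, c])
    res = linprog(M.ravel(), A_eq=A_eq, b_eq=b_eq,
                  bounds=(0, None), method="highs")
    return res.x.reshape(n, m), res.fun


# Three points on a line with cost |i - j|.
M = np.abs(np.subtract.outer(np.arange(3), np.arange(3))).astype(float)
r = np.array([0.2, 0.3, 0.5])
c = np.array([0.5, 0.3, 0.2])
P_star, d_M = exact_ot(M, r, c)
print(P_star)
print(d_M)
```

The number of variables is $nm$, which is why this direct approach stops being practical as the histograms grow.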

Sinkhorn Distance

  Lemma 1: The independence table $rc^\top$ has entropy $h(rc^\top)=h(r)+h(c)$.

\[\begin{aligned} h(rc^\top) &= -\sum_{ij}r_ic_j\log(r_ic_j) \\ &= -\sum_{ij}r_ic_j\log r_i-\sum_{ij} r_ic_j\log c_j \\ &= \sum_jc_jh(r) + \sum_ir_ih(c) \\ &= h(r) + h(c) \end{aligned}\]

  Lemma 2: $\forall r,c\in\Sigma_d,\forall P\in U(r,c),h(P)\leq h(r)+h(c)$, where $\Sigma_d$ denotes the probability simplex in $\mathbb{R}^d$.
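
Before the proof, both lemmas can be sanity-checked numerically. The sketch below compares the entropy of the independence table with $h(r)+h(c)$, then draws a few other tables with the same marginals and checks that none exceeds that bound. The helper `random_table`, which rescales a random positive matrix until its marginals match, is an illustrative assumption and not part of the paper.

```python
import numpy as np


def entropy(p):
    """Shannon entropy with the convention 0 log 0 = 0."""
    p = np.asarray(p, dtype=float).ravel()
    nz = p > 0
    return -np.sum(p[nz] * np.log(p[nz]))


def random_table(r, c, rng, iters=500):
    """Scale a random positive matrix until its marginals match r and c."""
    P = rng.random((r.size, c.size)) + 0.1
    for _ in range(iters):
        P *= (r / P.sum(axis=1))[:, None]   # fix row sums
        P *= (c / P.sum(axis=0))[None, :]   # fix column sums
    return P


rng = np.random.default_rng(0)
r = np.array([0.2, 0.3, 0.5])
c = np.array([0.4, 0.6])
bound = entropy(r) + entropy(c)

# Lemma 1: the independence table attains the bound.
lemma1_gap = abs(entropy(np.outer(r, c)) - bound)
# Lemma 2: other tables in U(r, c) do not exceed it.
gaps = [bound - entropy(random_table(r, c, rng)) for _ in range(5)]
print(lemma1_gap, min(gaps))
```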

  Consider the Lagrangian for maximizing $h(P)$ over $U(r,c)$, with multipliers $\alpha$ and $\beta$ for the two marginal constraints:

\[L(P,\alpha,\beta)=-\sum_{ij}p_{ij}\log p_{ij}+\alpha^\top(P\mathbf{1}_d-r)+\beta^\top(P^\top\mathbf{1}_d-c)\]

Setting its partial derivatives to zero gives

\[\begin{aligned} \frac{\partial L}{\partial p_{ij}} &= -1-\log p_{ij}+\alpha_i+\beta_j=0 \\ \frac{\partial L}{\partial \alpha_i} &= \sum_jp_{ij}-r_i=0 \\ \frac{\partial L}{\partial \beta_j} &= \sum_ip_{ij}-c_j=0 \end{aligned}\]

  Solving these equations gives $p_{ij}=e^{\alpha_i+\beta_j-1}$ and

\[\begin{aligned} r_i &= \sum_jp_{ij}=e^{\alpha_i-1}\sum_je^{\beta_j} \\ c_j &= \sum_ip_{ij}=e^{\beta_j-1}\sum_ie^{\alpha_i} \\ \sum_ie^{\alpha_i} &= \sum_i\frac{r_ie}{\sum_je^{\beta_j}}=\frac{e}{\sum_je^{\beta_j}} \end{aligned}\]

  Multiplying the first two equations and using the third gives $r_ic_j=e^{\alpha_i+\beta_j-1}=p_{ij}$. The entropy over $U(r,c)$ is therefore maximized at $P=rc^\top$, where $h(P)=h(r)+h(c)$ by Lemma 1, and every other $P\in U(r,c)$ has lower entropy because $h$ is strictly concave. Based on Lemmas 1 and 2, we can introduce the convex set

\[U_\alpha(r,c)=\{P\in U(r,c)\mid KL(P\Vert rc^\top)\leq\alpha\}=\{P\in U(r,c)\mid h(P)\geq h(r)+h(c)-\alpha\}\]

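where $P$ is constrained to have sufficient entropy, as controlled by $\alpha$. The two descriptions coincide because $KL(P\Vert rc^\top)=h(r)+h(c)-h(P)$ for every $P\in U(r,c)$, which follows from expanding the logarithm and using the marginal constraints. Since the uniform distribution has maximum entropy, the entropic constraint smooths the transport matrix $P$. The Sinkhorn distance with entropic constraint is then defined as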

\[d_{M,\alpha}(r,c)=\min_{P\in U_\alpha(r,c)}\langle P,M\rangle\]

where $\alpha$ controls the entropy of $P$. When $\alpha$ is large enough, the constraint is inactive and the Sinkhorn distance coincides with the classic OT distance. When $\alpha=0$, the only feasible matrix is $P=rc^\top$, so $d_{M,0}(r,c)=r^\top Mc$; the paper further shows that this is a negative definite kernel if $M$ is a Euclidean distance matrix. Introducing a Lagrange multiplier for the entropy constraint of the Sinkhorn distance gives

\[d^\lambda_M(r,c)=\langle P^\lambda,M\rangle\quad P^\lambda=\arg\min_{P\in U(r,c)}\langle P,M\rangle-\frac{1}{\lambda}h(P)\]

where each $\alpha$ corresponds to a $\lambda\in[0,+\infty)$ such that $d_{M,\alpha}(r,c)=d_M^\lambda(r,c)$. The figure below summarizes the relationships among $d_M$, $d_{M,\alpha}$ and $d_M^\lambda$. Since the entropy of $P^\lambda$ decreases monotonically with $\lambda$, $d_{M,\alpha}$ can be computed by evaluating $d_M^\lambda$ for increasing values of $\lambda$ until $h(P^\lambda)$ reaches $h(rc^\top)-\alpha$.

[Figure: relationships among $d_M$, $d_{M,\alpha}$ and $d_M^\lambda$]
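
The search over $\lambda$ described above can be sketched as follows. This is a coarse illustration under stated assumptions: the function names, the geometric growth schedule, the cap `lam_max`, and the fixed iteration count are all hypothetical choices, and the loop stops at the first $\lambda$ on the schedule whose entropy has dropped to the target, so the constraint is met only approximately.

```python
import numpy as np


def entropy(P):
    """Shannon entropy with the convention 0 log 0 = 0."""
    P = P[P > 0]
    return -np.sum(P * np.log(P))


def sinkhorn_plan(M, r, c, lam, n_iter=2000):
    """Plain Sinkhorn scaling with a fixed iteration count; returns P^lam."""
    K = np.exp(-lam * M)
    u = np.ones_like(r)
    for _ in range(n_iter):
        v = c / (K.T @ u)   # match column sums
        u = r / (K @ v)     # match row sums
    return u[:, None] * K * v[None, :]


def sinkhorn_distance_alpha(M, r, c, alpha, lam=0.5, growth=1.5, lam_max=50.0):
    """Increase lam until h(P^lam) drops to h(r) + h(c) - alpha."""
    target = entropy(r) + entropy(c) - alpha
    while True:
        P = sinkhorn_plan(M, r, c, lam)
        # Stop once the entropy target is reached or the cap would be exceeded.
        if entropy(P) <= target or lam * growth > lam_max:
            return np.sum(P * M), lam, P
        lam *= growth


M = np.abs(np.subtract.outer(np.arange(3), np.arange(3))).astype(float)
r = np.array([0.2, 0.3, 0.5])
c = np.array([0.5, 0.3, 0.2])
d, lam, P = sinkhorn_distance_alpha(M, r, c, alpha=0.1)
print(d, lam)
```

A bisection on $\lambda$ between the last two values tried would tighten the match to $\alpha$ if more precision is needed.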

  Lemma 3: For $\lambda>0$, the solution $P^\lambda$ is unique and has the form $P^\lambda=\text{diag}(u)K\text{diag}(v)$, where $u$ and $v$ are two non-negative vectors of $\mathbb{R}^d$ uniquely defined up to a multiplicative factor and $K=e^{-\lambda M}$ is the element-wise exponential of $-\lambda M$.

  That $P^\lambda$ can be written as a rescaled version of $K$ is a well-known fact in transport theory. Let $L(P,\alpha,\beta)$ be the Lagrangian of $d_M^\lambda(r,c)$. Then

\[L(P,\alpha,\beta)=\sum_{ij}\left(\frac{1}{\lambda}p_{ij}\log p_{ij}+p_{ij}m_{ij}\right)+\alpha^\top(P\mathbf{1}_d-r)+\beta^\top(P^\top\mathbf{1}_d-c)\]

\[\frac{\partial L}{\partial p_{ij}}=\frac{1}{\lambda}\log p_{ij}+\frac{1}{\lambda}+m_{ij}+\alpha_i+\beta_j=0\]

\[p_{ij}=e^{-\lambda(m_{ij}+\alpha_i+\beta_j)-1}\]
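
To connect this stationarity condition with the form stated in Lemma 3, the exponent can be split into a factor depending only on $i$, the kernel entry, and a factor depending only on $j$. The even split of the constant $-1$ below is one arbitrary choice, consistent with $u$ and $v$ being defined only up to a multiplicative factor:

\[p_{ij}=e^{-\frac{1}{2}-\lambda\alpha_i}\,e^{-\lambda m_{ij}}\,e^{-\frac{1}{2}-\lambda\beta_j}=u_iK_{ij}v_j,\quad u_i=e^{-\frac{1}{2}-\lambda\alpha_i},\quad v_j=e^{-\frac{1}{2}-\lambda\beta_j}\]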

  Since $K$ is strictly positive, Sinkhorn’s theorem states that there exists a unique matrix of the form $\text{diag}(u)K\text{diag}(v)$ that belongs to $U(r,c)$, where $u,v\geq\mathbf{0}_d$. $P^\lambda$ is thus necessarily that matrix, and it can be computed with a rescaling iteration that alternately normalizes the rows and the columns.

import numpy as np


def compute_optimal_transport(M, r, c, lam, epsilon=1e-5):
    """
    Computes the optimal transport matrix and Sinkhorn distance using the
    Sinkhorn-Knopp algorithm

    Inputs:
        - M : cost matrix (n x m)
        - r : vector of marginals (n, )
        - c : vector of marginals (m, )
        - lam : inverse strength of the entropic regularization
                (larger lam means less smoothing)
        - epsilon : convergence parameter

    Output:
        - P : optimal transport matrix (n x m)
        - dist : Sinkhorn distance
    """
    n, m = M.shape
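    # Gibbs kernel K = exp(-lam * M), element-wise
    P = np.exp(-lam * M)
    # Rescale so the entries sum to one, which improves numerical conditioning
    P /= P.sum()
    u = np.zeros(n)
    # Alternately rescale rows and columns until P.sum(1) == r and P.sum(0) == c
    while np.max(np.abs(u - P.sum(1))) > epsilon:
        # Current row sums, shape (n, )
        u = P.sum(1)
        # Match the row marginals
        P *= (r / u).reshape((-1, 1))
        # Match the column marginals
        P *= (c / P.sum(0)).reshape((1, -1))

    return P, np.sum(P * M)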
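
For comparison with Lemma 3, the same iteration can be written on the scaling vectors $u$ and $v$ while keeping $K$ fixed. The sketch below is an illustrative variant, not the author's code; the name `sinkhorn_uv`, the stopping rule on the column marginals, and the example data are assumptions.

```python
import numpy as np


def sinkhorn_uv(M, r, c, lam, n_iter=1000, tol=1e-10):
    """Sinkhorn iteration on scaling vectors u, v with K = exp(-lam * M)."""
    K = np.exp(-lam * M)
    u = np.ones_like(r, dtype=float)
    v = np.ones_like(c, dtype=float)
    for _ in range(n_iter):
        v = c / (K.T @ u)   # enforce the column marginals
        u = r / (K @ v)     # enforce the row marginals
        # After the u update the rows are exact, so check the columns.
        if np.max(np.abs(v * (K.T @ u) - c)) < tol:
            break
    P = u[:, None] * K * v[None, :]   # diag(u) K diag(v)
    return P, np.sum(P * M)


M = np.abs(np.subtract.outer(np.arange(3), np.arange(3))).astype(float)
r = np.array([0.2, 0.3, 0.5])
c = np.array([0.5, 0.3, 0.2])
P_uv, d_uv = sinkhorn_uv(M, r, c, lam=2.0)
print(P_uv.sum(axis=1), P_uv.sum(axis=0), d_uv)
```

Updating only $u$ and $v$ costs two matrix-vector products per iteration and avoids rewriting the full matrix at every step.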

  In practice, $r$ and $c$ are often taken to be uniform distributions, so the transport matrix $P$ and the Sinkhorn distance are determined mainly by the cost matrix $M$. The following example transports $x$ to $y$.

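import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import pairwise_distances

base = np.linspace(-4, 4, 128)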
x = np.sin(4 * base)[:, np.newaxis]
y = np.cos(4 * base)[:, np.newaxis]
M = pairwise_distances(x, y, metric='euclidean')

n, m = M.shape
# Uniform distribution
r = np.ones(n) / n
c = np.ones(m) / m
P, d = compute_optimal_transport(M, r, c, lam=500, epsilon=1e-6)
# Normalize, so each row sums to 1 (i.e. probability)
P /= r.reshape((-1, 1))

plt.subplot(1, 2, 1)
plt.title('Measure')
sns.heatmap(M, square=False, cmap='Reds')
plt.subplot(1, 2, 2)
plt.title('Transport')
sns.heatmap(P, square=False, cmap='Reds')
plt.show()

[Figure: heatmaps of the cost matrix ('Measure') and the row-normalized transport matrix ('Transport')]

# x and y are the two curves defined in the previous snippet
plt.plot(x, label='curve 1')
plt.plot(y, label='curve 2')
plt.plot(P.T @ x, label='1 -> 2')
plt.legend()
plt.show()

[Figure: curve 1, curve 2, and curve 1 transported to curve 2]