[논문리뷰] Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
소개
오늘 리뷰하는 논문은 Swin Transformer(Swin)입니다. 이 논문은 Vision Transformer(ViT)의 후속작이라고 보시면 될 것 같습니다. 그렇기 때문에 ViT 기반으로 모델이 동작하는 부분이 대다수이기 때문에 꼭 ViT 논문을 보고 오시는 것을 추천드립니다.
위에서 언급한 것처럼, Swin은 ViT 기반으로 만들어진 백본입니다. 그렇기 때문에 컴퓨터 비전에 CNN의 구조가 아닌 Transformer 구조를 따릅니다. 기존 ViT와 달리 본 논문에서 제안한 기술은 제목에서 볼 수 있듯이 2가지로 구성됩니다.
- 첫 번째 제안 기술(계층적 구조): 기존 CNN의 계층적 구조는 다들 아시죠? 일반적인 CNN 즉, 백본(Backbone)이라고 불리는 구조는 아래 그림과 같습니다.
다 그렇다고 할 수는 없지만, 일반적으로 동일한 Kernel size를 가져가고 중간에 Maxpooling으로 이미지의 해상도를 줄여가며 수용 필드를 늘려가는 구조입니다. 이렇게 함으로써 더 많은 Representations을 학습할 수 있고, 더 나은 정확도를 가져왔었죠. 이 Pooling 구조를 Transformer에도 도입한 것이 Swin의 첫 번째 제안된 기술입니다. 이렇게 함으로써 다양한 Representations을 학습할 수 있을뿐더러, 줄어든 해상도에서 Attention 연산을 진행하기 때문에 속도에도 이점이 있습니다. - 두 번째 제안 기술(Shift windows): Shift랑 Windows 둘 다 처음 보는 기술 같습니다. 먼저, Windows 방식이 무엇이냐 하면 아래 그림과 같습니다.
본 논문은 ViT의 후속작이라고 했습니다. 그렇다면 기존 기술보다 나은 점이 당연 있어야겠죠? ViT는 아신다고 생각하고 설명을 진행하겠습니다. 이미지가 (b)의 그림같이 패치로 나누어져 이 패치들이 각각의 Query, Key로 작동하여 기존 Transformer 같이 연산이 진행됩니다. 하지만, 여기서 문제점이 무엇이냐 하면 일반적으로 문장에서 단어의 길이는 어느 정도 되죠? 많아야 50 정도 될 것입니다. 그럼 이 단어들이 Query, Key로 각각 연산되겠지요. 하지만 이미지의 경우를 볼까요? 이미지는 보통 224 * 224 해상도가 되는데 이 모든 픽셀이 Query, Key로 작동한다면 시간이 엄청 걸리겠죠. ViT의 경우 이미지 패치로 작동되지만 그래도 시간이 오래 걸리는 것에는 변함이 없습니다.
여기서 제안한 기술은 고정된 Window(M * M의 크기)를 이미지 패치에 적용하여, Attention의 연산이 Window 안에서만 연산되게 하는 방식입니다. 전체가 아닌, 위 빨간색 테두리 안에서만 연산을 진행하자는 것입니다. 이렇게 하면 당연히 전체 이미지 패치에 대해 연산하는 ViT에 비해서는 당연히 속도가 올라가겠지요.
하지만 마냥 속도만 올라가면 될까요?
위 그림은 각 Window의 인덱스를 매겨보았습니다. [0-15]의 인덱스를 가지는 Windows가 존재하는데, 각 Windows가 독립적으로 Attention 연산을 진행하면 위 그림처럼 서로 다른 인덱스를 가지는 Window 끼리는 Attention 연산이 진행되지 않습니다. 그럼 장기의존성을 가져 전역적으로 정보를 집계하는 Transformer의 구조를 깨버리게 되는 것이고, 당연히 성능 또한 저하됩니다.
그래서 이 부분을 Shift Windows 방식으로 각각의 윈도끼리의 Attention 연산을 진행하는 것입니다. 모든 Windows끼리는 아니더라도 일부 Shift Windows 방식을 진행하여 연산 속도와 정확도 둘 다 월등한 성능을 기록하였습니다. 이 방식은 이따가 자세히 설명하겠습니다. 그냥 이런 것이 있구나 정도로만 이해해 주시면 되겠습니다. - 이러한 방식으로 Swin Transformer는 아주 범용적인 백본이 되었습니다. 이미지 분류뿐만 아니라, Object Detection, Semantic Degmentation 모두 월등한 성능을 기록했다고 하네요.
대략적인 부분은 살펴보았고 이제 자세한 구조를 보겠습니다.
Swin Transformer
전반적인 구조는 아래와 같습니다.
아주 복잡해 보이지만, 기존 ViT와 다른 부분은 따로 표시해 두었습니다. 파란색으로 표시한 것은 첫 번째 제안기술인 계층적 구조를 되게 만든 Patch Merging이라는 것입니다. 그리고 주황색으로 표시한 것은 두 번째 제안 기술인 Shift windows를 기존 Transformer 블록에 적용한 것입니다. 자잘한 부분은 제외하고 색으로 표시하지 않은 모든 구조는 ViT와 동일합니다. 이제 네트워크가 어떤 식으로 흘러가는지 확인해 보겠습니다.
네트워크 흐름
- "Input": 먼저 H x W x 3의 해상도를 가지는 어떤 이미지가 입력으로 들어가고, 이미지가 겹치지 않게 각각의 이미지 패치를 나눕니다(Patch Partition).
- "Stage 1": Transformer 학습을 위해 사용자가 정의한 C차원으로 매핑해줍니다(Linear Embedding, 출력은 \( \frac {H} {4} * \frac {W} {4} * C \) 차원). 여기서 2개로 구성된 Swin Transformer Block으로 입력되어 동일한 차원인 \( \frac {H} {4} * \frac {W} {4} * C \)가 출력됩니다.
- "Stage 2": Patch Merging으로 \( \frac {H} {4} * \frac {W} {4} * C \)의 해상도가 \( \frac {H} {8} * \frac {W} {8} * 2C \)의 해상도로 줄어듭니다. \( \frac {H} {8} * \frac {W} {8} * 2C \)의 해상도가 2개로 구성된 Swin Transformer Block으로 입력되어 동일한 차원인 \( \frac {H} {8} * \frac {W} {8} * 2C \)를 출력합니다.
- "Stage 3, Stage 4": Stage 2와 차원과 Swin Transformer Block의 개수만 다르고 나머지는 동일합니다
Patch Merging
그럼 해상도를 어떻게 줄일까요? Stage 2를 예시로 들면 아래와 같은 그림이 나옵니다.
- Stage 1의 출력인 \( \frac {H} {4} * \frac {W} {4} * C \)의 차원을 2 * 2 그룹들로 나눕니다.
- 나눠진 하나의 그룹은 \( \frac{H} {8} * \frac{W} {8} * C \)의 차원을 가지고, 4개의 그룹들을 채널을 기준으로 병합합니다(Concat).
- 병합된 \( \frac{H} {8} * \frac{W} {8} * 4C \)의 차원 축소를 위해 절반인 \( 2C \)의 차원으로 축소합니다.
- 위 과정들은 모든 Stage에서 동일하게 작용합니다.
이러한 계층적 구조는, 일반적인 Representations보다 더 계층적인 Representations을 학습할 수 있고, 앞선 전술한 것처럼 줄어든 차원만큼 연산속도에도 이점이 있습니다.
Swin Tranfomer Block
Window
앞에서 Swin은 Window로 쪼개는 방식으로 ViT보다 연산에 이점이 있다고 하였습니다. 얼마나 이점이 있는지 확인해 보겠습니다.
Ω 기호는 연산에 얼마나 시간이 걸리는지 측정한 기준입니다. 먼저 (1)은 ViT, (2)는 Swin입니다. 차이는 (1)의 \(2(hw)^2 \)와 (2)의 \(2M^2 hwC \)입니다. 여기서 \(M \)은 윈도의 크기인데 보통은 7로 고정합니다.
기존 방법은 해상도에 따라 2차원적으로 계산량이 증가합니다. 다시 말해, 해상도가 올라가면 계산량이 기하급수적으로 증가한다는 의미입니다. 반면, Swin의 경우 윈도우의 크기는 보통 고정되어 있으니 상수처럼 취급하고, hw의 크기에서만 선형적으로 계산량이 증가하기 때문에 계산적인 부분에서는 상당한 이점이 있습니다!
(a)는 ViT, (b)는 Swin의 Transformer가 동작하는 방식입니다. 빨간색 부분이 연산 속도에 연관을 주는 부분이라고 생각하시면 됩니다. 그리고 \( A * B \) 행렬과 \( B * C \) 행렬의 곱의 시간복잡도는 \( Ω (A * B * C) \)입니다.
- (a) ViT
- 먼저, \( hw * c \)의 Patch Embedding이 \( C * C \)의 차원을 가지는 각각의 Query, Key, Value의 Linear layer로 입력됩니다 ( \( [hw * C] \) x \( [C * C] \) x 3 = \( \Omega \left(3hw * C^2 \right) \)).
- 각각의 Query, Key, Value는 \( hw * C \)의 차원을 가집니다. 다음으로 Key와 Query의 곱으로 유사도를 구합니다 (\( [hw * C] \) x \( [C * hw] \) = \( \Omega \left(hw^2 * C \right) \)).
- 본 논문에서 softmax연산은 시간 복잡도에 생략한다고 합니다. 유사도를 구한 \( [hw * hw]\)차원과 \( [hw * C] \)차원을 가지는 Value를 곱해주면 (\( [hw * hw] \) x \( [hw * C] \) = \( \Omega \left(hw^2 * C \right) \)).
- 최종적으로 나온 \( hw * C \) 차원의 결과에 \( C * C \) 차원의 Linear layer를 통과하면 (\( [hw * C] \) x \( [C * C] \) = \( \Omega \left(hw * C^2 \right) \)).
- \( \Omega \left( 3hw * C^2 + hw^2 * C + hw^2 * C + hw * C ^ 2 \right) \) = \( \Omega \left(4hwC^2 + 2hw^2C \right) \)
- 먼저, \( hw * c \)의 Patch Embedding이 \( C * C \)의 차원을 가지는 각각의 Query, Key, Value의 Linear layer로 입력됩니다 ( \( [hw * C] \) x \( [C * C] \) x 3 = \( \Omega \left(3hw * C^2 \right) \)).
- (b) Swin
- 위의 1번과 동일합니다.
- 각각의 Query, Key, Value는 \( hw * C \)의 차원을 가집니다.
- Query, Key, Value는 M * M의 window size만큼 분할되고 각각 W-Query, W-Key, W-Value은 \( M^2 * C \)의 차원을 가집니다.
- 각 Window에 대한 W-Query와 W-Key의 곱으로 유사도를 구합니다(\( [M^2 * C] \) x \( [C * M^2] \) = \( \Omega \left( M^4 * C \right) \)).
- 각 Window에 대한 유사도 \( [M^2 * M^2] \)차원과 \( [M^2 * C] \)차원을 가지는 각 Window에 대한 Value를 곱해주면 (\( \Omega \left(M^4 * C \right) \)).
- 최종적으로 하나의 Window에 대한 시간 복잡도는 \( \Omega \left(M^4 * C + M^4 * C \right) \) = \( \Omega \left(2M^4 * C \right) \) 입니다.
- Window의 개수는 \( hw / M^2 \)개이므로, 전체에 대한 시간 복잡도는 \( \left( hw / M^2 \right) \) x \( \Omega \left(2M^4 * C \right) \) = \( \Omega \left( 2M^2hwC \right) \) 입니다.
- Query, Key, Value는 M * M의 window size만큼 분할되고 각각 W-Query, W-Key, W-Value은 \( M^2 * C \)의 차원을 가집니다.
- 위의 4번과 동일합니다.
- 전체 시간복잡도를 구하면 \( \Omega \left( 3hw * C^2 + 2M^{2}hwC + hw * C^2 \right) \) = \( \Omega \left( 4hwC^2 + 2M^2hwC \right) \)
- 위의 1번과 동일합니다.
Shift window
Window로 쪼개서 연산하는 방식은 연산속도에 큰 이점이 있지만, 앞에서 말한 것과 같이 Window 끼리의 상호작용이 부족합니다. 그렇기 때문에 본 논문에서는 연산 속도에 이점은 가져가 돼, Window 끼리의 상호작용을 충분히 할 수 있는 navie 한 Shifted window 방식은 아래와 같습니다.
왼쪽은 기존 Window 방식입니다. 이 \( \lceil \frac {h} {M} * \frac {w} {M} \rceil \) 개의 Window들에 대해 각각 독립적으로 self-attention을 시행하고, \( \lceil \left( \frac {h} {M} + 1 \right) \rceil \) * \( \lceil \left( \frac {w} {M} + 1 \right) \rceil \) 개의 추가적인 Windows로 나누어 각 독립적인 self-attention을 추가적으로 시행합니다. 이 Windows를 나누는 방식은 아래와 같습니다.
window size인 M이 4, 가로 세로는 각각 8이라고 했을 때, 2 * 2 파티션이 3 * 3개의 파티션으로 분할됩니다. 분할되는 기준은 중앙에 하나의 Window를 배치하고(위에선 Window 5), 연이어서 상, 하, 좌, 우, 대각에 배치하여 파티션을 나눕니다.
이러한 방식도 2 * 2의 windows가 3 * 3으로 늘어남에 따라 2.25배 정도의 계산량을 요구한다고, 본 논문에서는 이러한 추가적인 계산을 거의 요구하지 않는 cyclic-shifting 방식을 제안하였습니다.
기존 window 파티션에서 좌상단 부분을, 우하단으로 cyclic-shifting 시키는 것입니다. 이 상태에서 self-attention을 진행하면 각 window에 독립적으로 수행된 self-attention이 다른 window에도 적용될 수 있는 것입니다. 위 그림에서 같은 색으로 칠해진 부분은 기존 2 * 2 windows에서 self-attention이 수행된 것이므로, 중복 attention 연산을 제한하기 위해 masked self-attention을 진행합니다.
위와 같이 mask multihead self-attention을 진행하면 중복 연산은 제한하고, naive shifted window 방식과 같이 기존 2 * 2 windows에서 3 * 3 windows로 늘린 만큼의 계산량을 요구하는 것이 아닌, 기존과 같은 2 * 2 windows만큼의 계산량에 연산을 진행함에 따라 더 효과적인 방법입니다. 최종적으로 cyclic shift에서 연산 진행 후, 원래대로 reverse cyclic shift로 돌려놓습니다.
위 방법을 진행하는 전체 Swin Transformer block은 아래와 같이 계산됩니다.
기존 W-MSA의 방식과 cyclic shift 하여 연산하는 SW-MSA 방식으로 두 가지 MSA가 연산된 것을 볼 수 있습니다. MLP와 LayerNorm, residual connctection은 기존과 동일합니다.
Relative position bias
본 논문에서는 이 Swin Transformer의 성능을 더 이끌 방법으로 Relative position bias를 각 헤드에 더해주었습니다. 기존에는 Absolute bias를 더해주었는데, 이 부분은 오히려 성능 저하를 나타냈다고 합니다.
Relative position bias를 적용하는 방식은
- 먼저, 현재 위치를 기준으로 상대적인 거리를 계산하는 방식임. 예컨대 한 축을 기준으로 0 -> 1 = -1, 0 -> 2는 -2의 거리를 나타냅니다([0 <- 1 = 1, 0 <- 2 = 2], 역은 서로 반대 부호).
- 한 Window의 패치의 수는 \( M^2 \) 이므로, \( B \in \mathbb {R}^{M^2 * M^2} \)는 [-M + 1, M - 1]의 범위를 가지고 아래는 M이 2일 때 예시입니다.
- 왼쪽은 x축이 변할 때의 거리 변화가 나타내지는 그림이고, 오른쪽은 y축이 변할 때의 거리 변화가 나타내집니다(각 색깔이 변할 때 거리가 측정됨)
- 이 두 축에 대한 결과를 Index로 나타내기 위해 M-1씩 더해주고, 결합해 줍니다.
- 이제 이 Index들을 더 작은 사이즈 bias matrix인 \( \hat {B} \in \mathbb {R}^{\left( 2M - 1\right) * \left( 2M -1 \right)} \)로 접근하여 Relative position bias값을 구합니다. 이 bias matrix는 위치 정보를 학습할 수 있게 파라미터화 한 것입니다.
- 최종적으로 구한 Relative Position을 더해줍니다.
실험
이미지 분류에서, 왼쪽 결과를 보면 비슷한 파라미터를 가진 모델 중, FLOPs가 가장 낮으며 정확도는 가장 높게 나왔습니다. 이것은 속도와 정확도 간에 더 나은 Trade-off를 가졌다고 할 수 있고, 모든 구조가 Transformer로 구성되었기 때문에 더 나은 향상을 위한 잠재력이 남아있다고 합니다.
또한, 오른쪽 결과는 더 큰 데이터 세트에 대해 사전 학습 시킨 것인데, 역시 Transformer의 특성에 맞게 많은 데이터에 대한 사전 학습을 시키니 성능이 CNN 모델과 기존 모델의 성능 향상을 볼 수 있습니다.
이미지 분류뿐만 아니라 Object Detection, Semantic Segmentation에서도 CNN을 백본으로 사용했을 때 보다 더 나은 결과를 보여줍니다.
왼쪽은 Object Detection, 오른쪽은 Semantic Segmentation의 결과입니다.
다음은 분류, Object Detection, Semantic Segmentaiton에서 Shifted windows 방식과 Relative position bias를 제거하였을 때 성능에 얼마나 영향을 미치는 가에 대한 결과입니다.
모든 Task에서 확연한 성능 차이를 보였네요. 마지막으로 속도를 보면
맨 아래 cyclic 방식이 모든 스테이지에서 가장 빠른 성능을 보였습니다.
오늘은 요즘 범용적으로 사용하고 있는 백본 모델인 Swin Transformer 논문을 한 번 리뷰해보았습니다.
자세한 내용은 본 논문을 참고하시고, 언제나 게시글에 대한 피드백은 환영입니다.
긴 글 읽어주셔서 감사합니다.