Abstract
- temporal dimension에 multihead pooling attention을 추가한 ViT
- computational complexity 감소
- temporal dimension을 더 aware하는 ViT
Motivation
일반적으로 CNN에서 발전된 multiscale feature를 분석하는 방식과 ViT를 연결하는 multiscale feature hierarchies를 가진 trasnformer model을 만든다.
Fig. 1을 보면, 일반적인 ViT와 다르게 MViT는 channel-resolution 'scale' stage가 존재한다. hierarchical하게 존재하는 stages에서, channel은 증가하고 spatial resolution은 감소한다. 결과적으로 일반적인 hierarchical feature처럼, early layer는 high-resolution low-level feature를 가지고, deeper layer는 coarse, complex high-level feature를 가진다.
이 디자인의 장점은, video multiscale model에서 강력한 implicit temporal bias를 가진다는 것이다. 이는 일반적인 ViT를 shuffled frame video에 train시켜도 performance가 decay되지 않아 temporal feature를 거의 사용하지 않는다는 것을 보이는 것과 반대로, MViT를 같은 방식으로 train하면 performance가 크게 저하되어 temporal information을 사용한다는 것을 나타낸다.
Methods
Multi Head Pooling Attention
MHPA(Multi Head Pooling Attention)을 제안한다. 원래 MHA는 dimension이 고정되어 있지만, MHPA는 latent tensor를 pooling해서 resolution을 줄인다(See Fig. 3).
원래 MHA에서는 다음과 같이 QKV triple을 계산한다:
$$ \hat Q = XW_Q, \hat K = XW_K, \hat V = XW_V $$
각 weight matrix $W$는 $D×D$로 정의된다 .이를 줄이기 위해 pooling operator $\mathcal P$를 사용한다.
- Pooling Operator
pooling operator $\mathcal P (\cdot; Θ)$를 정의한다. 이는 각 dimension에 대해 kernel computation을 진행한다. 즉, $Θ:=(k,s,p)$에서 각각은 $k_T×k_H×k_W$와 같이 3 dim의 크기로 정의된 parameter이다. 각각은 pooling kernel, stride, padding을 의미한다.
이 경우 input tensor $L=T×H×W$에 대해 $\tilde L$은 pooling operation에 의해 다음과 같이 계산된다:
$$ \tilde L = \lfloor \frac{L+2p-k}{s} \rfloor + 1 $$
이후 $D$ dim으로 flatten되어 $ℝ^{\tilde L ×D}, \tilde L = \tilde T × \tilde H × \tilde W$가 된다.
여기서 $k$는 $p$와 함께 shape preserving 하게 사용되어 reduction의 factor는 stride $s_Ts_Hs_W$에만 dependent하게 된다.
→ 복잡하게 써 있지만 그냥 conv layer에서의 pooling과 같다. 다만 한 dim이 time이라는 점과 pooling을 하는 kernel이 temporal dimension까지 포함하는 tensor라는 점이 특이하고 재미있다. 즉, temporal wise로도 pooling을 해서 서로 다른 time에서의 matrix에 대해서도 정보가 intergrate되는 효과가 있는 것이다.
- Pooling Attention
pooling 된 QKV triple을 이용해서 attention을 수행한다. 이는 일반 attention과 같다.
$$PA(\cdot) = \text{Softmax}(\mathcal P (Q;Θ_Q)\mathcal P (K;Θ_K)^T/\sqrt d)\mathcal P (V;Θ_V)$$
normalizing factor $\sqrt d$는 row wise matrix의 inner product이다.
위에서 $s_K \equiv s_V$이므로 결과적으로 $PA$의 output은 $ℝ^{\tilde L_Q × D}$가 되고, 이는 stride factor $s^Q_Ts^Q_Hs^Q_W$로 reduce된 sequence length를 의미한다.
- Computational Analysis
attention 과정에서 pooling으로 인해 줄어드는 computation reduction factor는 다음과 같이 계산된다:
$$f_j=s^j_T\cdot s^j_H\cdot s^j_W, \forall j\in \{Q,K,V\}$$
그러나 MHPA의 complexity도 고려해야 하는데, 이는 input tensor가 $D×T×H×W$인 경우 다음과 같다:
$$O(THWD/h(D+THW/f_qf_K))$$
다만 이에 비례하여 memory도 증가하므로 channel $D$와 sequence length $THW/f_Qf_K$는 design choice가 된다. 자세한 내용은 paper를 참조하면 좋겠다.
Multiscale Transformer Networks
MViT는 channel은 점점 늘리고, spatiotemporal resolution(T, H, W)는 줄이는 방식으로 동작한다. 이 때문에 초반에는 fine resolution과 coarse channel을 가지고 있다가 뒷 layer에서는 coarse resolution과 fine channel을 갖게 된다. 이는 Tab. 1의 original ViT와 Tab. 2의 MViT를 비교하며 쉽게 확인할 수 있다.
Experiments
full length video에서 $T$ frame을 temporal stride $τ$로 추출한 경우 $T×τ$로 denote되었다.
Ablation on Kinetics
- Frame Shuffling
Table 9에서 input frame을 random shuffle했을 경우 top-1 accuracy가 ViT가 떨어지는 크기에 비해 크게 떨어지는 것을 볼 수 있다. 이는 기존 ViT가 temporal information을 반영하지 않는 bag-of-frames 형태로 inference함을 의미한다. 반면 MViT에서는 temporal한 방향으로도 modeling이 이루어진다고 생각할 수 있다.
- Separate Space & Time Embedding in MViT
Tab 11.은 positional embedding에 따른 param 개수와 accuracy를 보여준다.
References
[1] Fan, H., Xiong, B., Mangalam, K., Li, Y., Yan, Z., Malik, J., & Feichtenhofer, C. (2021). Multiscale vision transformers. In Proceedings of the IEEE/CVF international conference on computer vision (pp. 6824-6835).
Footnotes
'DL·ML > Paper' 카테고리의 다른 글
VideoChat2 (CVPR 2024, MLLM) (0) | 2024.05.28 |
---|---|
UMT(ICCV 2023 Oral, Video Foundation Model) (0) | 2024.05.28 |
U-Net (0) | 2024.04.15 |
GLA-GCN(ICCV 2023, 3D HPE) (0) | 2024.04.03 |
AGCN (CVPR 2019, action recognition) (0) | 2024.04.02 |