[DL] RNN(Recursive Neural Network)의 이해

2021. 12. 26. 16:00·DL·ML
목차
  1. RNN(Recurrent Neural Network)
  2. Sequential data
  3. Recurrent Neural Network
  4. Backpropagation
  5. Vanishing gradient and Exploding Gradient problem
  6. References

RNN(Recurrent Neural Network)

Sequential data

일반적으로 deep neural network에서는 input과 output이 하나의 vector인 one-to-one으로 이루어진다. 하지만 어떤 경우에는 sequential data가 input이나 output에 포함되는 경우가 있다. 예를 들어 stochastic process(time series)[각주:1]나 ordered data structures가 있다. speech recognition이 time series에 해당하는 주요 예시가 될 것이고, machine translation이 ordered data structure의 예가 될 것이다. 이런 경우에 one-to-many나 many-to-many, many-to-one의 형태로 input과 output이 구성될 수 있다. 

 

Figure 1 : 이미지를 glimpse의 series로 classify한다.

요즈음에는 non-sequence data의 경우에도 sequential하게 processing하는 경우도 있다. 위의 Figure은 MNIST dataset을 visual attention의 sequence로 처리하는 과정의 일부이다.

 

즉, observation sequence x={x1,x2,...,xT}와 corresponding label y={y1,y2,...,yT}에 대해서, mapping transformation f:x↦y을 학습하는 것이다. 이 경우에, sequential data를 model하는 방법으로서 RNN이 사용된다.


Recurrent Neural Network

Figure 2 : RNN example, 좌측은 recursive description이고, 오른쪽은 unfolded한 것이다.

RNN은 이전의 입력 데이터들에 대한 정보를 hidden state h라는 모델 내에 숨겨진 unit에 저장한다. 그리고 시간 t에서 새로운 input xt가 들어올 때, xt 뿐 아니라 이전 state의 ht−1도 함께 포함하여 ht와 output zt를 계산하게 된다. 이를 수식으로 표현하면 다음과 같다 : 

ht=f(ht−1,xt)

따라서, 현재 ht는 현재의 input xt에 대한 정보만 저장하고 있는 것이 아니라 지금까지 입력된 모든 x에 대한 정보를 포함하고 있다. 

 

이때 hidden state에 대한 non-linear transformation은 tanh를 사용하고, output으로는 softmax를 사용한다.

ht=tanh(Whhht−1+Wxhxt+bh)

zt=softmax(Whzht+bz)

 

한 가지 기억해야 할 점은, 각 time에서 사용하는 parameter W는 모든 sequence에 대해서 동일하게 사용된다. 또한 여기에서는 Whh,Wxh등으로 나누어 표기하고 있지만, 실제로는 하나의 행렬 W의 부분을 나누어 사용하게 된다. 즉, update해야 할 paramter는 W하나 뿐이다.


Backpropagation

sequence labeling에서, Loss function은 negative log likelihood를 사용한다. 

L(x,y)=−∑tytlog⁡zt

여기서 zt는 output vector z에서 정답 레이블의 값을 나타낸다. y_t는 정답 레이블일 때만 1이고 나머지는 0이므로, 정답 레이블일 경우에만 loss function에 포함되고, 그 경우 z_t가 높을 경우 log의 값이 커지고, negative log likelihood, 즉 loss function의 값이 줄어들게 된다. 

 

이때 softmax에 들어가기 전 값을 αt=Whzht+bz라고 하면 위의 값은 다음과 같이 나타낼 수 있다.

 

L(x,y)=−∑tytlog⁡(softmax(αt))

 

이를 α에 대해서 미분해 보자. 먼저 정답 레이블 j일 때, 

p(y^h|ht;Θ)=exp⁡(αj(Θ)∑kexp⁡(αk(Θ))

이므로 αj(Θ)의 partial derivative를 구하면

∂yjlog⁡p(y^j|ht;Θ)∂αj=∂yjlog⁡(exp⁡(αj)∑kexp⁡(αk))∂αj=yj∑kexp⁡(αk)(−(exp⁡(αj))2(∑kexp⁡(αk))2+exp⁡(αj)∑kexp⁡(αk))=yj(∑kexp⁡(αk)−exp⁡(αj)∑kexp⁡(αk))=yj(1−pj)

 

이다.

 

비슷한 방법으로 ∀k≠j인 경우에 그 prediction p(y^k) 에 대해서도, αj(Θ)에 대한 미분을 구하면 다음과 같다.

 

∂yklog⁡p(y^k|yt;Θ)∂αj=ykp(y^k−exp⁡(αk(Θ))exp⁡(αj(Θ))[∑sexp⁡(αs(Θ))]2=−ykp(y^j)

 

전체 vector y의 αj(Θ)에 대한 그래디언트는 위 두 값의 합이다. 따라서 다음과 같이 계산된다.

 

 

∂p(y^)∂αj=∑j∂yjlog⁡p(y^j|hi;Θ)∂αj=∂log⁡p(y^j)|hi;Θ)∂αj+∑k≠j∂log⁡p(y^k|hi;Θ)∂αj=yj−yjp(y^j)−∑k≠jykp(y^j)=yj−p(y^j)(yj+∑k≠jyk)=yj−p(y^j)

 

이제 W와 b에 대한 derivative를 계산하면 된다. 먼저 Whz에 대한 loss의 propagation는 다음과 같다 :

∂LWhz=∑t∂L∂zt∂zt∂Whz

 

Whh와 Wxh에 대해서는 이전 hidden state에 있는 dependent들을 모두 생각해야 한다. 따라서 t+1에서의 output만 고려했을 떄, propagation 값은 다음과 같다 : 

∂L(t+1)∂Whh=∑k=1t∂L(t+1)∂zt+1∂zt+1∂hh+1∂ht+1∂hk∂ht∂Whh

 

만약 모든 time에 대해서 backpropagation 값을 고려한다면 다음과 같다 : 

∂L∂Whh=∑t∑k=1t∂L(t+1)∂zt+1∂zt+1∂hh+1∂ht+1∂hk∂ht∂Whh

 

같은 방법으로 이전 hidden state에 dependency가 있는 Wxh에 대한 backpropagation 값도 계산할 수 있다 : 

∂L∂Wxh=∑t∑k=1t∂L(t+1)∂zt+1∂zt+1∂hh+1∂ht+1∂hk∂ht∂Wxh

 


Vanishing gradient and Exploding Gradient problem

위의 gradient 계산 부분을 보면 이전 hidden state에 dependency가 있는 경우에 ∂ht+1∂hk이 포함되어 있는 것을 볼 수 있다.

 

이는 chain rule에 의해서 다음과 같이 계산된다 : 

∂ht+1∂ht∂ht∂ht−1⋯∂hk+1∂hk

 

이는 풀어서 쓰면 다음과 같다 : 

∂tanh⁡(Wht+Wx+b)∂ht∂tanh⁡(Wht−1+Wx+b)∂ht−1⋯∂tanh⁡(Whk+Wx+b)∂hk

 

이때 

∂tanh⁡(Wht+Wx+b)∂ht=(1−tanh2⁡(Wht+Wx+b))W

 

문제가 되는 것은 W term인데, 긴 sequence에 대해 backpropagate를 할 때에는 이 동일한 matrix를 계속해서 곱하게 되기 때문이다. 같은 값을 곱할 때에는 정확히 같은 값이 되거나, 0에 수렴하게 되거나(vanishing gradient), 발산하는 경우(exploding gradient)가 있다. 같은 값이 되는 경우는 practical하게는 거의 나타나지 않으므로 긴 sequence의 backpropagation에서는 항상 두 문제 중 하나가 발생한다고 보아야 한다.

 

이 심각한 문제를 해결하기 위해 여러 가지 방법이 제안되었는데, Gradient clipping도 그 중 하나이다. gradient clipping은 값의 L2 norm이 너무 커지면 나누어 주는 heuristic이다. 

 

결과적으로는 LSTM이 가장 효과적인 대안으로 제시되었고, 오늘날에는 vanilla RNN은 거의 사용하지 않는다. 


References

Gang Chen. (2016). A Gentle Tutorial of Recurrent Neural Network with Error Backpropagation. Department of Computer Science and Engineering, SUNY at Buffalo. 

Fei-Fei Li, J. Johnson, S. Yeung. (2017). Convolutional Neural Networks for Visual Recognition. Dept. of Comp. Sci., Stanford University, Palo Alto, CA, USA, spring 2017 [Online]. 

 

Footnotes

  1. time series는 index가 integer t로 이루어진 특수한 경우를 가리킨다. reference :  https://stats.stackexchange.com/questions/126791/is-a-time-series-the-same-as-a-stochastic-process [본문으로]

'DL·ML' 카테고리의 다른 글

[DL] Hierarchical Softmax  (0) 2022.07.01
[ML] Reduced Error Pruning  (0) 2022.03.18
Xavier initializer  (0) 2021.12.22
[paper review] ImageNet Classification with Deep Convolutional Neural Network (AlexNet)  (0) 2021.11.29
경사하강법(Gradient Descent)  (0) 2021.06.21
  1. RNN(Recurrent Neural Network)
  2. Sequential data
  3. Recurrent Neural Network
  4. Backpropagation
  5. Vanishing gradient and Exploding Gradient problem
  6. References
'DL·ML' Other articles in this category
  • [DL] Hierarchical Softmax
  • [ML] Reduced Error Pruning
  • Xavier initializer
  • [paper review] ImageNet Classification with Deep Convolutional Neural Network (AlexNet)
Jordano
Jordano
JordanoJordano 님의 블로그입니다.
  • Jordano
    Jordano
    Jordano
  • Total
    Today
    Yesterday
    • All categories
      • Introduction
      • Theatre⋅Play
      • Thinking
        • iDeAs
        • Philosophy
      • History
        • Cuba
        • China
      • CS
        • HTML·CSS·JavaScript
        • Dart·Flutter
        • C, C++
        • Python
        • PS
        • Algorithm
        • Network
        • OS
        • etc
      • DL·ML
        • Paper
        • Study
        • Project
      • Mathematics
        • Information Theory
        • Linear Algebra
        • Statistics
        • etc
      • etc
        • Paper
      • Private
      • Travel
  • Blog Menu

    • 홈
    • 태그
    • 방명록
  • Link

  • hELLO· Designed By정상우.v4.10.3
Jordano
[DL] RNN(Recursive Neural Network)의 이해
상단으로

티스토리툴바

단축키

내 블로그

내 블로그 - 관리자 홈 전환
Q
Q
새 글 쓰기
W
W

블로그 게시글

글 수정 (권한 있는 경우)
E
E
댓글 영역으로 이동
C
C

모든 영역

이 페이지의 URL 복사
S
S
맨 위로 이동
T
T
티스토리 홈 이동
H
H
단축키 안내
Shift + /
⇧ + /

* 단축키는 한글/영문 대소문자로 이용 가능하며, 티스토리 기본 도메인에서만 동작합니다.