for textmining

효과적인 RNN 학습

|

이번 글에서는 Recurrent Neural Network(RNN)을 효과적으로 학습시키는 전략들에 대해 살펴보도록 하겠습니다. 이 글은 Oxford Deep NLP 2017 course을 기본으로 하되 저희 연구실 장명준 석사과정이 만든 자료를 정리했음을 먼저 밝힙니다. 그럼 시작하겠습니다.

그래디언트 문제

RNN 역전파시 그래디언트가 너무 작아지거나, 반대로 너무 커져서 학습이 제대로 이뤄지지 않는 경우가 많습니다. 이를 각각 gradient vanishing, gradient exploding 문제라고 합니다. 이 문제를 수식으로 살펴보겠습니다. vanilla RNN 셀 $t$번째 시점의 히든스테이트 $h_t$는 다음과 같이 정의됩니다. 단 여기에서 하이퍼볼릭탄젠트 안의 식들을 모두 합쳐 $z_t$이라 두겠습니다. ($x_t$ : $t$번째 시점의 입력, $W_{hh}, W_{xh}, b_h$ : 학습 파라메터)

\[\begin{align*} { h }_{ t }=&\tanh { \left( { W }_{ hh }{ h }_{ t-1 }+{ W }_{ xh }{ x }_{ t }+{ b }_{ h } \right) }\\=&\tanh({z}_{t}) \end{align*}\]

다음 그림과 같은 RNN 구조에서 네번째 시점의 손실 $cost_4$에 대한 $h_1$의 그래디언트는 체인룰(chain rule)에 의해 아래와 같이 계산할 수 있습니다. 바꿔 말해 적색 화살표에 해당하는 그래디언트를 순차적으로 곱한 값이라고 보면 될 것 같습니다.

\[\frac { \partial { cost }_{ 4 } }{ \partial { h }_{ 1 } } =\frac { \partial { cost }_{ 4 } }{ \partial { \hat { p } }_{ 4 } } \frac { \partial { \hat { p } }_{ 4 } }{ \partial { h }_{ 4 } } \frac { \partial { h }_{ 4 } }{ \partial { h }_{ 3 } } \frac { \partial { h }_{ 3 } }{ \partial { h }_{ 2 } } \frac { \partial { h }_{ 2 } }{ \partial { h }_{ 1 } }\]

위 식을 일반화하여 $n$번째 시점의 손실 $cost_n$에 대한 $h_1$의 그래디언트는 다음과 같이 표시할 수 있습니다.

\[\frac { \partial { cost }_{ n } }{ \partial { h }_{ 1 } } =\frac { \partial { cost }_{ n } }{ \partial { \hat { p } }_{ n } } \frac { \partial { \hat { p } }_{ n } }{ \partial { h }_{ n } } \left( \prod _{ t=2 }^{ n }{ \frac { \partial { h }_{ t } }{ \partial { h }_{ t-1 } } } \right)\]

우리의 관심은 괄호 안입니다. 체인룰에 의해 그래디언트가 곱해지면서 이 값이 커지는지 작아지는지가 관심인 것입니다. $h_t=\tanh(z_t)$이므로 $t$번째 히든스테이트에 대한 $t-1$번째 히든스테이트의 그래디언트는 체인룰에 의해 다음과 같이 표시할 수 있습니다.

\[\frac { \partial { h }_{ t } }{ \partial { h }_{ t-1 } } =\frac { \partial { h }_{ t } }{ \partial { z }_{ t } } \frac { \partial { z }_{ t } }{ \partial { h }_{ t-1 } }\]

$z_t=W_{hh}h_{t-1}+W_{xh}x_t+b_h$이므로 $∂z_t/∂h_{t-1}$은 $W_{hh}$입니다. norm의 성질에 의해 다음 부등식이 성립합니다.

\[\left\| \frac { \partial { h }_{ t } }{ \partial { h }_{ t-1 } } \right\| \le \left\| \frac { \partial { h }_{ t } }{ \partial { z }_{ t } } \right\| \left\| \frac { \partial { z }_{ t } }{ \partial { h }_{ t-1 } } \right\| =\left\| \frac { \partial { h }_{ t } }{ \partial { z }_{ t } } \right\| \left\| { W }_{ hh } \right\|\]

norm의 성질에 의해 다음이 성립한다고 합니다(증명). 다시 말해 $W_{hh}$의 L2 norm은 $W_{hh}$의 가장 큰 고유값(eigenvalue)라는 뜻입니다.

\[{ \left\| { W }_{ hh } \right\| }_{ 2 }=\sqrt { { \lambda }_{ max } }\]

다시 우리의 관심인 문제로 돌아오겠습니다. RNN 역전파시 체인룰에 의해 $∂h_t/∂h_{t-1}$을 지속적으로 곱해주어야 합니다. 그런데 우리가 살펴봤듯이 $∂h_t/∂h_{t-1}$의 L2 norm은 절대적으로 $W_{hh}$의 L2 norm 크기에 달려 있습니다. 다시 말해 $W_{hh}$의 가장 큰 고유값이 1보다 크다면 앞쪽 스텝으로 올수록 그래디언트가 매우 커질 것이며, 1보다 작다면 반대로 매우 작아질 것입니다.

이와 별개로 $z_t$를 계산한 이후 취해지는 비선형함수 하이퍼볼릭탄젠트의 문제도 있습니다. 하이퍼볼릭탄젠트를 1차 미분한 그래프는 다음과 같습니다. 입력값이 -5보다 작거나 5보다 크면 그래디언트가 0으로 작아지는 걸 확인할 수 있습니다. RNN 역전파 과정에서 하이퍼볼릭탄젠트의 그래디언트도 계속 곱해지기 때문에 하이퍼볼릭탄젠트도 그래디언트 배니싱 문제에 일부 영향을 끼칠 수 있습니다.

셀 구조를 바꾸기 : LSTM

그래디언트 문제를 해결하기 위한 방법 가운데 가장 각광받는 것은 셀 구조를 아예 바꿔보는 겁니다. 계속 곱해서 문제가 생겼으니 이번에는 더하는 걸로 바꿔 봅시다. 이를 cell state라 하겠습니다.

\[{ c }_{ t }={ c }_{ t-1 }+\tanh { \left( V\left[ { x }_{ t-1 };{ h }_{ t-1 } \right] +{ b }_{ c } \right) }\]

하이퍼볼릭탄젠트 안에 있는 내용은 vanilla RNN 셀과 본질적으로 다르지 않습니다. $c_t$를 $c_{t-1}$로 미분하는 경우 우변에 $c_{t-1}$가 더해졌기 때문에 기본적으로 1을 확보할 수 있습니다. 그래디언트가 0으로 죽는 걸 막고자 함입니다.

그런데 각 스텝마다 그래디언트가 지속적으로 커진다면 되레 그래디언트 익스플로딩 문제가 발생할 수 있습니다. 밸런스를 맞춰주어야 합니다. cell state를 아래와 같이 바꿔보겠습니다.

\[\begin{align*} { c }_{ t }={ f }_{ t }\odot { c }_{ t-1 }+&{ i }_{ t }\odot \tanh { \left( V\left[ { x }_{ t-1 };{ h }_{ t-1 } \right] +{ b }_{ c } \right) } \\ \\ where\quad { i }_{ t }=&\sigma \left( { W }_{ i }\left[ { x }_{ t-1 };{ h }_{ t-1 } \right] +{ b }_{ i } \right) \\ { f }_{ t }=&\sigma \left( { W }_{ f }\left[ { x }_{ t-1 };{ h }_{ t-1 } \right] +{ b }_{ f } \right) \end{align*}\]

$i_t$와 $f_t$는 시그모이드가 취해진 값으로 0~1사이의 값을 가집니다. 각각 직전 시점의 정보, 현 시점의 정보를 얼마나 반영할지를 결정합니다. 물론 다음 히든스테이트를 만들 때도 gate를 둘 수 있습니다.

\[{ h }_{ t }={ o }_{ t }\odot \tanh { \left( { W }_{ h }{ c }_{ t }+{ b }_{ h } \right) } \\ where\quad { o }_{ t }=\sigma \left( { W }_{ o }\left[ { x }_{ t-1 };{ h }_{ t-1 } \right] +{ b }_{ o } \right)\]

Skip connection

셀 구조 말고 skip connection 기법을 사용할 수도 있습니다. 그래디언트가 흐를 수 있는 지름길(shortcut)을 만들어주는 것입니다. skip connection은 원래 Convolutional Neural Network에서 자주 사용된 기법인데 RNN 구조에서도 다음과 같이 적용할 수 있다고 합니다.

과적합, 드롭아웃

일반적이지는 않지만, RNN에서도 드롭아웃을 적용해 과적합을 피할 수 있습니다. 그런데 히든스테이트와 히든스테이트 사이에 드롭아웃을 적용하는 것은 그리 좋은 선택이 아니라고 합니다. 히든스테이드와 히든스테이트를 연결하는 가중치 $W_{hh}$가 모든 step에서 공유되기 때문입니다. 아래 그림을 보겠습니다.

위 예시에서 $h_t$를 만들 때 1, 3, 5번째 노드를 드롭아웃할 경우 $W_{hh}$에서 해당 행벡터가 업데이트되지 않습니다. 하지만 다음 step에서 얼마든지 업데이트될 가능성이 있습니다. 과적합을 피하기 위해서 학습을 덜한다는 본래 취지에 맞지 않는다는 이야기입니다.

굳이 과적합을 적용할 경우 다음과 같이 할 수 있다고 합니다. 배치 단위로 파라메터를 업데이트하는 경우가 많은데, 1회 iteration마다 드롭아웃을 하는 노드를 모든 step에 대해 통일시키는 겁니다.

vocabulary 문제

예측해야 할 단어(범주) 수가 너무 많아서 소프트맥스 확률 계산시 과부하가 걸리는 경우가 많습니다. 이를 위해 일부 단어만 뽑아 소프트맥스 확률을 구하고 여기서 구한 손실(loss)만큼만 파라메터를 업데이트하는 방법도 있습니다. 이와 관련 자세한 내용은 이곳을 참고하시면 좋을 것 같습니다.



Comments