for textmining

Gated Recurrent Unit (GRU)

|

이번 글에서는 Recursive Neural Networks(RNN)의 대표적인 셀 가운데 하나인 Gated Recurrent Unit(GRU)에 대해 살펴보도록 하겠습니다. 이 글은 기본적으로 미국 스탠포드 대학의 CS231n, CS224d 강좌를 참고로 하되 고려대학교 데이터사이언스 연구실의 김해동 석사과정이 만든 자료를 정리하였음을 먼저 밝힙니다. 김해동 석사과정은 GRU의 순전파(forward propagation)역전파(backward propagation) 과정을 알기 쉽게 설명하였습니다. 기본적인 RNN 구조와 Long-Short Term Memory(LSTM) 셀에 대해 궁금하신 분은 이곳을 참고하시기 바랍니다. 그럼 시작하겠습니다.

GRU 개요

GRU는 LSTM의 장점을 유지하면서도 계산복잡성을 확 낮춘 셀 구조입니다. GRU도 Gradient Vanishing/Explosion 문제를 극복했다는 점에서 LSTM과 유사하지만 게이트 일부를 생략한 형태입니다. GRU는 크게 update gatereset gate 두 가지로 나뉩니다.

두 게이트 모두 현 시점의 입력값($x_t$)과 직전 시점 은닉층값($h_{t-1}$)을 반영해 구합니다. 활성함수($σ$)는 시그모이드를 씁니다.

$W, U$는 각각 입력값과 은닉층값과 선형결합하는 파라메터이고요. 위 첨자 $z, r$은 각각 update gate, reset gate에 속한다는 뜻입니다.

자 이제부터는 기억(memory)에 관련된 과정입니다. 우선 현 시점($t$)에서 기억해둘 만한 정보를 아래와 같이 정의합니다.

위 식을 해석하면 이렇습니다. 현 시점 정보($Wx_t$)와 과거 정보($Uh_{t-1}$)를 반영하되, 과거 정보를 얼마나 반영할지는 reset gate 값에 의존한다는 이야기입니다.

reset gate의 활성함수는 시그모이드이므로 0~1 사이의 범위를 갖습니다. $r_t$값이 0이라면 과거 정보를 모두 잊고, 1이라면 과거 정보를 모두 기억합니다. $r_t$값에 상관없이 현재 정보는 반영됩니다.

위 식의 활성함수는 하이퍼볼릭탄젠트입니다. -1~1 사이의 범위를 갖습니다. 위 식엔 update, reset gate와 달리 $W,U$에 위 첨자가 없는데, 아예 다른 파라메터라는 점에 주의해야 합니다. 한편 ⊙는 Hadamard product(요소별 곱셈)를 뜻합니다.

다음 상태(state)로 업데이트하는 식은 아래와 같습니다.

위 식을 해석하면 이렇습니다. $h_{t-1}$는 과거 정보, $h_t$ 틸다는 현재 정보입니다. 이를 얼마나 조합할지 결정하는 건 update gate의 $z_t$입니다.

update gate의 활성함수는 시그모이드이므로 0~1 사이의 범위를 갖습니다. $z_t$가 0이라면 과거 정보를 모두 잊고, 현재 정보만을 기억합니다. $z_1$이 1이라면 과거 정보를 모두 기억하지만, 현재 정보는 모두 무시합니다.

GRU 셀을 그림(출처)으로 나타내면 아래와 같습니다.

GRU의 순전파

이제부터 본격적으로 GRU의 순전파와 역전파에 대해 설명할 예정인데요. 기본적인 방식에 관해서는 이 글을 참고하시면 좋을 것 같습니다.

GRU의 순전파를 계산그래프로 나타내면 아래 그림과 같습니다. 그림으로 보면 복잡해보이지만 지금까지 설명드린 수식과 완전히 동일합니다. 다만 $h_t$에 $W_{out}$을 곱해 $y_t$를 만드는 과정은 엄밀히 말해 GRU 셀 내부 작동은 아니지만 이해를 돕기 위하여 셀 내부에 그렸습니다.

어쨌든 $t$ 시점의 GRU 셀의 입력은 $x_t$, $h_{t-1}$, 출력은 $h_t$입니다. GRU 셀과 연결된 Softmax-with-Loss 계층은 $y_t$를 입력으로 받아 $t$ 시점의 Loss $L_t$를 출력합니다.

우리가 업데이트해야 할 파라메터는 $U^{z}$, $U^r$, $U$, $W^z$, $W^r$, $W$입니다.

GRU의 역전파

$t$ 시점의 GRU 셀이 Softmax-with-Loss 노드로부터 최초로 받는 그래디언트는 $∂L_t/∂y_t$입니다. 아래 그림에서는 이를 편의상 $dy_t$라고 적었습니다. 이후 모든 표기는 이 방식을 따랐습니다. $h_t$와 $y_t$는 곱셈 노드로 연결돼 있기 때문에 $dW_{out}$과 $dh_t$는 흘러들어온 그래디언트 $dy_t$에 순전파 때 입력 신호들을 서로 바꾼 값을 각각 곱한 값입니다. $dh_t$ 틸다는 흘러들어온 그래디언트 $dh_t$에 Hadamard product의 로컬 그래디언트를 곱해 구합니다.

우선 위쪽부터 살펴보겠습니다. $dr_t$ 틸다는 흘러들어온 그래디언트 $dh_t$ 틸다에 하이퍼볼릭탄젠트의 로컬 그래디언트를 곱해 구합니다. $dr_t$는 흘러들어온 그래디언트 $dr_t$ 틸다에 Hadamard product의 로컬 그래디언트를 곱해 구합니다. $dinput_r$은 흘러들어온 그래디언트 $dr_t$에 시그모이드 함수의 로컬 그래디언트를 곱해 구합니다. $dUh_{t-1}$은 흘러들어온 그래디언트 $dr_t$ 틸다에 Hadamard product의 로컬 그래디언트를 곱해 구합니다.

이제 아래쪽을 살펴보겠습니다. $dz_t$는 흘러들어온 그래디언트 $dh_t$에 Hadamard product의 로컬 그래디언트를 곱해 구합니다. $dinput_z$는 흘러들어온 그래디언트 $dz_t$에 시그모이드 함수의 로컬 그래디언트를 곱해 구합니다. $dWx_t$는 흘러들어온 그래디언트 $dh_t$ 틸다에 하이퍼볼릭탄젠트의 로컬 그래디언트를 곱해 구합니다. $input_r$과 $input_r$은 덧셈 노드로 연결돼 있으므로 흘러들어온 그래디언트가 그대로 전파됩니다.

이상을 종합하면 우리가 구해야 하는 $dW_x$와 $dU_h$는 아래와 같습니다. $dh_t$는 $∂L_t/∂h_t$를 의미하기 때문에 $cvx$를 기준으로 위 그림 위쪽은 모두 $∂h_t/∂W_x$를 계산하는 과정이라고 이해하면 될 것 같습니다. 마찬가지로 $cvx$를 기준으로 위 그림 위쪽은 모두 $∂h_t/∂U_h$를 구하는 과정으로 보면 됩니다.

the backpropagation through time (BPTT)

지금까지 설명해드린 그림들은 $t$ 시점의 셀 하나를 보여드렸습니다. 그런데 셀을 다양하게 구성해 RNN 네트워크를 구성할 수 있습니다. GRU 역전파시 그래디언트는 아래 그림의 파란색 셀 개수만큼 재귀적으로 합쳐져 전파됩니다.

다시 $t$ 시점의 GRU 셀로 돌아가보겠습니다. $δ_t$는 흘러들어온 그래디언트 $dh_t(=∂L_t/∂h_t)$에 Hadamard product의 로컬 그래디언트 $∂h_t/∂h_{t-1}$를 곱해 구합니다. 이를 수식으로 정리하면 아래와 같습니다.

여기서 $h_{t-1}$은 $t-1$ 시점의 GRU 셀의 출력 결과라는 점에 주목해야 합니다. 다시 말해 $δ_t$는 적색 화살표를 따라 $t-1$ 시점의 GRU 셀로 흘러들어간다는 이야기입니다.

이번엔 $t-1$ 시점의 GRU 셀을 살펴 보겠습니다. $δ_t$는 $t-1$ 시점의 Loss인 $L_{t-1}$로부터 전파되는 그래디언트 $dh_{t-1}$과 합쳐져 역전파되는 걸 확인할 수 있습니다. 이렇듯 GRU 셀은 직전 시점의 정보를 받아 다음 스텝으로 순전파하기 때문에 역전파시엔 그래디언트가 재귀적으로 합쳐져 흘러 들어갑니다.

수식으로 이해하는 GRU 역전파

시퀀스 길이가 3인 GRU 구조를 도식적으로 나타내면 그림과 같습니다. 우리는 파라메터 $W_x$와 $U_h$를 업데이트하는 데 관심이 있다는 걸 상기하고 아래 설명을 천천히 따라가 봅시다.

$δ_3$은 아래 식과 같습니다. 흘러들어온 그래디언트 $dh_3(=∂L_3/∂h_3)$에 Hadamard product의 로컬 그래디언트 $∂h_3/∂h_2$를 곱해 구합니다.

$δ_2$는 흘러들어온 그래디언트 $δ_3+dh_2(=∂L_2/∂h_{2})$에 로컬 그래디언트 $δh_2/δh_1$를 곱해 구합니다.

자 이제 $∂L_1/∂W_x$를 구할 준비가 다 되었습니다. 아래 식과 같습니다. 자세히 살펴보시면 두번째 시점과 세번째 시점에서 전파된 그래디언트가 모두 반영이 되고 있는 점을 확인할 수 있습니다.

같은 방식으로 $∂L_2/∂W_x, ∂L_3/∂W_x$을 구해 모두 더하면 아래와 같습니다.

이를 일반화하여 식을 정리하면 아래와 같습니다. 같은 방식으로 $U_h$에 대한 그래디언트도 구할 수 있습니다.

Comments