셀프 어텐션 동작 원리
트랜스포머(transformer)의 핵심 구성요소는 셀프 어텐션(self attention)입니다. 이 글에서는 셀프 어텐션의 내부 동작 원리에 대해 살펴보겠습니다.
Table of contents
모델 입력과 출력
셀프 어텐션을 이해하려면 먼저 입력부터 살펴봐야 합니다. 그림1은 트랜스포머 모델의 전체 구조를, 그림2는 그림1에서 인코더 입력만을 떼어서 나타낸 그림입니다. 그림2와 같이 모델 입력을 만드는 계층(layer)을 입력층(input layer)이라고 합니다.
그림1 Transformer 전체 구조
그림2 인코더 입력
그림2에서 확인할 수 있듯이 인코더 입력은 소스 시퀀스의 입력 임베딩(input embedding)에 위치 정보(positional encoding)을 더해서 만듭니다. 한국어에서 영어로 기계 번역을 수행하는 트랜스포머 모델을 구축하다고 가정해 봅시다. 이때 인코더 입력은 소스 언어 문장의 토큰 인덱스(index) 시퀀스가 됩니다.*
* 우리는 전처리 과정에서 입력 문장을 토큰화한 뒤 이를 인덱스로 변환한 적이 있는데요. 토큰화 및 인덱싱과 관련해서는 이 글을 참고하면 좋을 것 같습니다.
예를 들어 소스 언어의 토큰 시퀀스가 어제
, 카페
, 갔었어
라면 인코더 입력층의 직접적인 입력값은 이들 토큰들에 대응하는 인덱스 시퀀스가 되며 인코더 입력은 그림3과 같은 방식으로 만들어집니다. 다음 그림은 이해를 돕고자 토큰 인덱스(어제
의 고유 ID) 대신 토큰(어제
)으로 표기했습니다.
그림3 인코더 입력 예시
그림3의 왼편 행렬(matrix)은 소스 언어의 각 어휘에 대응하는 단어 수준 임베딩인데요. 단어 수준 임베딩 행렬에서 현재 입력의 각 토큰 인덱스에 대응하는 벡터를 참조(lookup)해 가져온 것이 그림2의 입력 임베딩(input embedding)입니다. 단어 수준 임베딩은 트랜스포머의 다른 요소들처럼 소스 언어를 타겟 언어로 번역하는 태스크를 수행하는 과정에서 같이 업데이트(학습)됩니다.
입력 임베딩에 더하는 위치 정보는 해당 토큰이 문장 내에서 몇 번째 위치인지 정보를 나타냅니다. 그림3 예시에서는 어제
가 첫번째, 카페
가 두번째, 갔었어
가 세번째입니다.
트랜스포머 모델은 이같은 방식으로 소스 언어의 토큰 시퀀스를 이에 대응하는 벡터 시퀀스로 변환해 인코더 입력을 만듭니다. 디코더 입력 역시 만드는 방식이 거의 같습니다.
그림4는 그림1에서 인코더와 디코더 블록만을 떼어 그린 그림인데요. 인코더 입력층(그림2)에서 만들어진 벡터 시퀀스가 최초 인코더 블록의 입력이 되며, 그 출력 벡터 시퀀스가 두 번째 인코더 블록의 입력이 됩니다. 다음 인코더 블록의 입력은 이전 블록의 출력입니다. 이를 $N$번 반복합니다.
그림4 인코더-디코더
그림5는 트랜스포머의 전체 구조에서 모델의 출력층(output layer)만을 떼어낸 것입니다. 이 출력층의 입력은 디코더 마지막 블록의 출력 벡터 시퀀스입니다. 출력층의 출력은 타깃 언어의 어휘 수만큼의 차원을 갖는 확률 벡터가 됩니다. 소스 언어의 어휘가 총 3만개라고 가정하면 이 벡터의 차원수는 3만이 되며 3만 개 요솟값을 모두 더하면 그 합은 1이 됩니다. 이 벡터는 디코더에 입력된 타깃 시퀀스의 다음 토큰 확률 분포를 가리킵니다.
그림5 디코더 출력
트랜스포머의 학습(train)은 인코더와 디코더 입력이 주어졌을 때 모델 최종 출력에서 정답에 해당하는 단어의 확률 값을 높이는 방식으로 수행됩니다.
셀프 어텐션 내부 동작
그러면 트랜스포머 모델 핵심인 셀프 어텐션 기법이 내부에서 어떻게 동작하는지 살펴보겠습니다. 셀프 어텐션은 트랜스포머의 인코더와 디코더 블록 모두에서 수행되는데요. 이 글에서는 인코더의 셀프 어텐션을 살펴보겠습니다.
(1) 쿼리, 키, 밸류 만들기
그림4를 보면 인코더에서 수행되는 셀프 어텐션의 입력은 이전 인코더 블록의 출력 벡터 시퀀스입니다. 그림3의 단어 임베딩 차원수($d$)가 4이고, 인코더에 입력된 단어 갯수가 3일 경우 셀프 어텐션 입력은 수식1의 $\mathbf{X}$과 같은 형태가 됩니다. 4차원짜리 단어 임베딩이 3개 모였음을 확인할 수 있습니다. 수식1의 $\mathbf{X}$의 요소값이 모두 정수(integer)인데요. 이는 예시일 뿐 실제 계산에서는 거의 대부분이 실수(real number)입니다.
수식1 입력 벡터 시퀀스 X
$ \mathbf{X}=\begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 2 & 0 & 2 \\ 1 & 1 & 1 & 1 \end{bmatrix} $
셀프 어텐션은 쿼리(query), 키(key), 밸류(value) 3개 요소 사이의 문맥적 관계성을 추출하는 과정입니다. 다음 수식처럼 입력 벡터 시퀀스($\mathbf{X}$)에 쿼리, 키, 밸류를 만들어주는 행렬($\mathbf{W}$)을 각각 곱합니다. 입력 벡터 시퀀스가 3개라면 수식2를 적용하면 쿼리, 키, 밸류는 각각 3개씩 총 9개의 벡터가 나옵니다.
참고로 수식2에서 $\times$ 기호는 행렬 곱셈(matrix multiplication)을 가리키는 연산자인데요. 해당 기호를 생략하는 경우도 있습니다. 행렬 곱셈이 익숙하지 않은 분들은 이 글을 참고하시면 좋겠습니다.
수식2 쿼리, 키, 밸류 만들기
\[\mathbf{Q}=\mathbf{X} \times { \mathbf{W} }_{ \text{Q} } \\ \mathbf{K}=\mathbf{X} \times { \mathbf{W} }_{ \text{K} } \\ \mathbf{V}=\mathbf{X} \times { \mathbf{W} }_{ \text{V} } \\\]수식3은 수식1의 입력 벡터 시퀀스 가운데 첫번째 입력 벡터($\mathbf{X}_{1}$)로 쿼리를 만드는 예시입니다.
- 수식3 좌변의 첫번째가 바로 $\mathbf{X}_{1}$입니다.
- 그리고 좌변 두번째가 수식2의 $\mathbf{W}_{\text{Q}}$에 대응합니다.
수식3 ‘쿼리’ 만들기 (1)
$ \begin{bmatrix} 1 & 0 & 1 & 0 \end{bmatrix}\times \begin{bmatrix} 1 & 0 & 1 \\ 1 & 0 & 0 \\ 0 & 0 & 1 \\ 0 & 1 & 1 \end{bmatrix}=\begin{bmatrix} 1 & 0 & 2 \end{bmatrix} $
수식1의 입력 벡터 시퀀스 가운데 두번째 입력 벡터($\mathbf{X}_2$)로 쿼리를 만드는 식은 수식4, 세번째($\mathbf{X}_3$)로 쿼리를 만드는 과정은 수식5와 같습니다. 이때 쿼리 만드는 방식은 이전과 같습니다.
수식4 ‘쿼리’ 만들기 (2)
$ \begin{bmatrix} 0 & 2 & 0 & 2 \end{bmatrix}\times \begin{bmatrix} 1 & 0 & 1 \\ 1 & 0 & 0 \\ 0 & 0 & 1 \\ 0 & 1 & 1 \end{bmatrix}=\begin{bmatrix} 2 & 2 & 2 \end{bmatrix} $
수식5 ‘쿼리’ 만들기 (3)
$ \begin{bmatrix} 1 & 1 & 1 & 1 \end{bmatrix}\times \begin{bmatrix} 1 & 0 & 1 \\ 1 & 0 & 0 \\ 0 & 0 & 1 \\ 0 & 1 & 1 \end{bmatrix}=\begin{bmatrix} 2 & 1 & 3 \end{bmatrix} $
수식6은 입력 벡터 시퀀스 $\mathbf{X}$를 한꺼번에 쿼리 벡터 시퀀스로 변환하는 식입니다. 입력 벡터 시퀀스에서 하나씩 떼어서 쿼리로 바꾸는 수식3, 수식4, 수식5와 비교했을 때 그 결과가 완전히 같음을 확인할 수 있습니다. 실제 쿼리 벡터 구축은 수식6과 같은 방식으로 이뤄집니다.
수식6 ‘쿼리’ 만들기 (4)
$ \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 2 & 0 & 2 \\ 1 & 1 & 1 & 1 \end{bmatrix}\times \begin{bmatrix} 1 & 0 & 1 \\ 1 & 0 & 0 \\ 0 & 0 & 1 \\ 0 & 1 & 1 \end{bmatrix}=\begin{bmatrix} 1 & 0 & 2 \\ 2 & 2 & 2 \\ 2 & 1 & 3 \end{bmatrix} $
수식7은 입력 벡터 시퀀스 $\mathbf{X}$를 통째로 한꺼번에 키 벡터 시퀀스로 변환하는 식입니다. 수식7 좌변에서 입력 벡터 시퀀스에 곱해지는 행렬은 수식2의 ${\mathbf{W}}_{\text{Q}}$에 대응합니다.
수식7 ‘키’ 만들기
$ \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 2 & 0 & 2 \\ 1 & 1 & 1 & 1 \end{bmatrix}\times \begin{bmatrix} 0 & 0 & 1 \\ 1 & 1 & 0 \\ 0 & 1 & 0 \\ 1 & 1 & 0 \end{bmatrix}=\begin{bmatrix} 0 & 1 & 1 \\ 4 & 4 & 0 \\ 2 & 3 & 1 \end{bmatrix} $
수식8은 입력 벡터 시퀀스 $\mathbf{X}$를 통째로 한꺼번에 밸류 벡터 시퀀스로 변환하는 걸 나타냅니다. 수식8 좌변에서 입력 벡터 시퀀스에 곱해지는 행렬은 수식2의 ${\mathbf{W}}_{\text{V}}$에 대응합니다.
수식8 ‘밸류’ 만들기
$ \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 2 & 0 & 2 \\ 1 & 1 & 1 & 1 \end{bmatrix}\times \begin{bmatrix} 0 & 2 & 0 \\ 0 & 3 & 0 \\ 1 & 0 & 3 \\ 1 & 1 & 0 \end{bmatrix}=\begin{bmatrix} 1 & 2 & 3 \\ 2 & 8 & 0 \\ 2 & 6 & 3 \end{bmatrix} $
다음 세 가지 행렬은 태스크(예: 기계 번역)를 가장 잘 수행하는 방향으로 학습 과정에서 업데이트됩니다.
- ${\mathbf{W}}_{\text{Q}}$
- ${\mathbf{W}}_{\text{K}}$
- ${\mathbf{W}}_{\text{V}}$
(2) 첫 번째 쿼리의 셀프 어텐션 출력값 계산하기
이제 셀프 어텐션을 계산하기 위한 준비가 모두 끝났습니다! 수식9는 셀프 어텐션의 정의입니다.
수식9 셀프 어텐션
\[\text{Attention} (\mathbf{Q},\mathbf{K},\mathbf{V})= \text{softmax} (\frac { \mathbf{Q} { \mathbf{K} }^{ \top } }{ \sqrt { { d }_{ \text{K} } } } ) \mathbf{V}\]수식9를 말로 풀면 이렇습니다. 쿼리와 키를 행렬곱한 뒤 해당 행렬의 모든 요소값을 키 차원수의 제곱근 값으로 나눠주고, 이 행렬을 행(row) 단위로 소프트맥스(softmax)*를 취해 스코어 행렬을 만들어줍니다. 이 스코어 행렬에 밸류를 행렬곱해 줘서 셀프 어텐션 계산을 마칩니다.
* 소프트맥스(softmax)란 입력 벡터의 모든 요솟값 범위를 0 이상, 1 이하로 하고 총합을 1이 되게끔 하는 함수입니다. 어떤 입력이든 소프트맥스 함수를 적용하면 해당 값이 확률로 변환됩니다.
수식6의 쿼리 벡터 세 개 가운데 첫번째 쿼리만 가지고 수식9에 정의된 셀프 어텐션 계산을 수행해보겠습니다(수식10~수식12). 수식10은 첫번째 쿼리 벡터와 모든 키 벡터들에 전치(transpose)를 취한 행렬($\mathbf{K}^{\top}$)을 행렬곱한 결과입니다. 여기에서 전치란 원래 행렬의 행(row)과 열(column)을 교환해 주는 걸 뜻합니다.
수식10 첫번째 쿼리 벡터에 관한 셀프 어텐션 계산 (1)
$ \begin{bmatrix} 1 & 0 & 2 \end{bmatrix}\times \begin{bmatrix} 0 & 4 & 2 \\ 1 & 4 & 3 \\ 1 & 0 & 1 \end{bmatrix}=\begin{bmatrix} 2 & 4 & 4 \end{bmatrix} $
수식10 우변에 있는 벡터의 첫번째 요소값(2)은 첫번째 쿼리 벡터와 첫번째 키 벡터 사이의 문맥적 관계성이 녹아든 결과입니다. 두번째 요소값(4)은 첫번째 쿼리 벡터와 두번째 쿼리 벡터 사이의 문맥적 관계성, 세번째 요소값(4)은 첫번째 쿼리와 세번째 쿼리 벡터 사이의 문맥적 관계성이 포함돼 있습니다.
수식11은 수식10의 결과에 키 벡터의 차원수($d_{\text{K}}=3$)의 제곱근으로 나눈 후 소프트맥스를 취해 만든 벡터입니다.
수식11 첫번째 쿼리 벡터에 관한 셀프 어텐션 계산 (2)
\[\text{softmax} ([ \frac{2}{\sqrt{3}}, \frac{4}{\sqrt{3}}, \frac{4}{\sqrt{3}} ]) = [ 0.13613, 0.43194, 0.43194 ]\]첫번째 쿼리 벡터에 대한 셀프 어텐션 계산의 마지막은 수식12와 같습니다. 수식11의 소프트맥스 벡터와 수식8의 밸류 벡터들을 행렬곱해서 계산을 수행합니다. 이는 소프트맥스 벡터의 각 요솟값에 대응하는 밸류 벡터들을 가중합(weighted sum)한 결과와 같습니다. 다시 말해 수식12는 $0.13613 * [1, 2, 3] + 0.43194 * [2, 8, 0] + 0.43194 * [2, 6, 3]$과 동일한 결과라는 이야기입니다.
수식12 첫번째 쿼리 벡터에 관한 셀프 어텐션 계산 (3)
$ \begin{bmatrix} 0.13163 & 0.43194 & 0.43194 \end{bmatrix}\times \begin{bmatrix} 1 & 2 & 3 \\ 2 & 8 & 0 \\ 2 & 6 & 3 \end{bmatrix}=\begin{bmatrix} 1.8639 & 6.3194 & 1.7042 \end{bmatrix} $
(3) 두 번째 쿼리의 셀프 어텐션 출력값 계산하기
이번에는 수식6의 두번째 쿼리 벡터에 대해 셀프 어텐션 계산을 해보겠습니다. 수식13은 두번째 쿼리 벡터와 모든 키 벡터들에 전치(transpose)를 취한 행렬($\mathbf{K}^{\top}$)을 행렬곱한 결과입니다.
수식13 우변에 있는 벡터의 첫번째 요소값(4)은 두번째 쿼리 벡터와 첫번째 키 벡터 사이의 문맥적 관계성이 녹아든 결과입니다. 두번째 요소값(16)은 두번째 쿼리 벡터와 두번째 쿼리 벡터 사이, 세번째 요소값(12)은 두번째 쿼리와 세번째 쿼리 벡터 사이의 문맥적 관계성이 포함돼 있습니다.
수식13 두번째 쿼리 벡터에 관한 셀프 어텐션 계산 (1)
$ \begin{bmatrix} 2 & 2 & 2 \end{bmatrix}\times \begin{bmatrix} 0 & 4 & 2 \\ 1 & 4 & 3 \\ 1 & 0 & 1 \end{bmatrix}=\begin{bmatrix} 4 & 16 & 12 \end{bmatrix} $
수식14는 수식13의 결과에 키 벡터의 차원수($d_{\text{K}}=3$)의 제곱근으로 나눠준 뒤 소프트맥스를 취해 만든 벡터입니다.
수식14 두번째 쿼리 벡터에 관한 셀프 어텐션 계산 (2)
\[\text{softmax} ([ \frac{4}{\sqrt{3}}, \frac{16}{\sqrt{3}}, \frac{12}{\sqrt{3}} ])= [ 0.00089, 0.90884, 0.09027 ]\]두번째 쿼리 벡터에 대한 셀프 어텐션 계산의 마지막은 수식15와 같습니다. 수식14의 소프트맥스 벡터와 수식8의 밸류 벡터들을 행렬곱해서 계산을 수행합니다. 이는 소프트맥스 벡터의 각 요소값에 대응하는 밸류 벡터들을 가중합한 결과와 동치입니다.
수식15 두번째 쿼리 벡터에 관한 셀프 어텐션 계산 (3)
$ \begin{bmatrix} 0.00089 & 0.90884 & 0.09027 \end{bmatrix}\times \begin{bmatrix} 1 & 2 & 3 \\ 2 & 8 & 0 \\ 2 & 6 & 3 \end{bmatrix}=\begin{bmatrix} 1.9991 & 7.8141 & 0.2735 \end{bmatrix} $
(4) 세 번째 쿼리의 셀프 어텐션 출력값 계산하기
수식6의 마지막 세번째 쿼리 벡터에 대해 셀프 어텐션 계산을 해보겠습니다. 수식16은 세번째 쿼리 벡터와 모든 키 벡터들에 전치(transpose)를 취한 행렬($\mathbf{K}^{\top}$)을 행렬곱한 결과입니다.
수식16 우변에 있는 벡터의 첫번째 요소값(4)은 세번째 쿼리 벡터와 첫번째 키 벡터 사이의 문맥적 관계성이 녹아든 결과입니다. 두번째 요소값(12)은 세번째 쿼리 벡터와 두번째 쿼리 벡터 사이, 세번째 요소값(10)은 세번째 쿼리와 세번째 쿼리 벡터 사이의 문맥적 관계성이 포함돼 있습니다.
수식16 세번째 쿼리 벡터에 관한 셀프 어텐션 계산 (1)
$ \begin{bmatrix} 2 & 1 & 3 \end{bmatrix}\times \begin{bmatrix} 0 & 4 & 2 \\ 1 & 4 & 3 \\ 1 & 0 & 1 \end{bmatrix}=\begin{bmatrix} 4 & 12 & 10 \end{bmatrix} $
수식17은 수식16의 결과에 키 벡터의 차원수($d_{\text{K}}=3$)의 제곱근으로 나눠준 뒤 소프트맥스를 취해 만든 벡터입니다.
수식17 세번째 쿼리 벡터에 관한 셀프 어텐션 계산 (2)
\[\text{softmax} ([ \frac{4}{\sqrt{3}}, \frac{12}{\sqrt{3}}, \frac{10}{\sqrt{3}} ])= [ 0.00744, 0.75471, 0.23785 ]\]세번째 쿼리 벡터에 대한 셀프 어텐션 계산의 마지막은 수식18과 같습니다. 수식17의 소프트맥스 벡터와 수식8의 밸류 벡터들을 행렬곱해서 계산을 수행합니다. 이는 소프트맥스 벡터의 각 요소값에 대응하는 밸류 벡터들을 가중합한 결과와 동치입니다.
수식18 세번째 쿼리 벡터에 관한 셀프 어텐션 계산 (3)
$ \begin{bmatrix} 0.00744 & 0.75471 & 0.23785 \end{bmatrix}\times \begin{bmatrix} 1 & 2 & 3 \\ 2 & 8 & 0 \\ 2 & 6 & 3 \end{bmatrix}=\begin{bmatrix} 1.9926 & 7.4796 & 0.7359 \end{bmatrix} $
지금까지는 손 계산으로 셀프 어텐션을 살펴봤는데요. 파이토치를 활용해 코드로도 확인해 보겠습니다. 우선 입력 벡터 시퀀스 $\mathbf{X}$와 쿼리, 키, 밸류 구축에 필요한 행렬들을 앞선 예시 그대로 정의합니다. 코드1과 같습니다.
코드1 변수 정의
import torch
x = torch.tensor([
[1.0, 0.0, 1.0, 0.0],
[0.0, 2.0, 0.0, 2.0],
[1.0, 1.0, 1.0, 1.0],
])
w_query = torch.tensor([
[1.0, 0.0, 1.0],
[1.0, 0.0, 0.0],
[0.0, 0.0, 1.0],
[0.0, 1.0, 1.0]
])
w_key = torch.tensor([
[0.0, 0.0, 1.0],
[1.0, 1.0, 0.0],
[0.0, 1.0, 0.0],
[1.0, 1.0, 0.0]
])
w_value = torch.tensor([
[0.0, 2.0, 0.0],
[0.0, 3.0, 0.0],
[1.0, 0.0, 3.0],
[1.0, 1.0, 0.0]
])
코드2는 수식2를 계산해 입력 벡터 시퀀스로 쿼리, 키, 밸류 벡터들을 만드는 파트입니다. torch.matmul
는 행렬곱을 수행하는 함수입니다.
코드2 쿼리, 키, 밸류 만들기
keys = torch.matmul(x, w_key)
querys = torch.matmul(x, w_query)
values = torch.matmul(x, w_value)
코드3은 코드2에서 만든 쿼리와 키 벡터들을 행렬곱해서 어텐션 스코어를 만드는 과정입니다. keys.T
는 키 벡터들을 전치한 행렬입니다.
코드3 어텐션 스코어 만들기
attn_scores = torch.matmul(querys, keys.T)
코드3을 수행한 결과는 다음과 같습니다. 정확히 수식10, 수식13, 수식16과 같습니다. 이들은 쿼리 벡터를 하나씩 떼어서 계산을 수행한 것인데요. 코드3처럼 쿼리 벡터들을 한꺼번에 모아서 키 벡터들과 행렬곱을 수행하여도 같은 결과를 낼 수 있음을 확인할 수 있습니다.
>>> attn_scores
tensor([[ 2., 4., 4.],
[ 4., 16., 12.],
[ 4., 12., 10.]])
코드4는 코드3의 결과에 키 벡터의 차원수의 제곱근으로 나눠준 뒤 소프트맥스를 취하는 과정입니다.
코드4 소프트맥스 확률값 만들기
import numpy as np
from torch.nn.functional import softmax
key_dim_sqrt = np.sqrt(keys.shape[-1])
attn_scores_softmax = softmax(attn_scores / key_dim_sqrt, dim=-1)
코드4를 수행한 결과는 다음과 같습니다. 정확히 수식11, 수식14, 수식17과 같습니다.
>>> attn_scores_softmax
tensor([[1.3613e-01, 4.3194e-01, 4.3194e-01],
[8.9045e-04, 9.0884e-01, 9.0267e-02],
[7.4449e-03, 7.5471e-01, 2.3785e-01]])
코드5는 코드4에서 구한 소프트맥스 확률과 밸류 벡터들을 가중합하는 과정을 수행합니다.
코드5 소프트맥스 확률과 밸류를 가중합하기
weighted_values = torch.matmul(attn_probs, values)
코드5의 수행 결과는 다음과 같습니다. 정확히 수식12, 수식15, 수식18과 같습니다.
>>> weighted_values
tensor([[1.8639, 6.3194, 1.7042],
[1.9991, 7.8141, 0.2735],
[1.9926, 7.4796, 0.7359]])
셀프 어텐션의 학습 대상은 쿼리, 키, 밸류를 만드는 가중치 행렬입니다. 코드 예시에서는 w_query
, w_key
, w_value
입니다. 이들은 태스크(예: 기계 번역)를 가장 잘 수행하는 방향으로 학습 과정에서 업데이트됩니다.
멀티 헤드 어텐션
멀티-헤드 어텐션(Multi-Head Attention)은 셀프 어텐션(self attention)을 여러 번 수행한 걸 가리킵니다. 여러 헤드가 독자적으로 셀프 어텐션을 계산한다는 이야기입니다. 비유하자면 같은 문서(입력)를 두고 독자(헤드) 여러 명이 함께 읽는 구조라 할 수 있겠습니다.
그림9는 입력 단어 수는 2개, 밸류의 차원수는 3, 헤드는 8개인 멀티-헤드 어텐션을 나타낸 그림입니다. 개별 헤드의 셀프 어텐션 수행 결과는 ‘입력 단어 수 $\times$ 밸류 차원수’, 즉 $2 \times 3$ 크기를 갖는 행렬입니다. 8개 헤드의 셀프 어텐션 수행 결과를 다음 그림의 ①처럼 이어 붙이면 $2 \times 24$의 행렬이 됩니다.
그림9 멀티-헤드 어텐션(Multi-Head Attention)
멀티-헤드 어텐션은 개별 헤드의 셀프 어텐션 수행 결과를 이어붙인 행렬(①)에 $\mathbf{W}^O$를 행렬곱해서 마무리됩니다. $\mathbf{W}^O$의 크기는 ‘셀프 어텐션 수행 결과 행렬의 열(column)의 수 $\times$ 목표 차원수’가 됩니다. 만일 멀티-헤드 어텐션 수행 결과를 그림9와 같이 4차원으로 설정해 두고 싶다면 $\mathbf{W}^O$는 $24 \times 4$ 크기의 행렬이 되어야 합니다.
멀티-헤드 어텐션의 최종 수행 결과는 ‘입력 단어 수 $\times$ 목표 차원수’입니다. 그림9에서는 입력 단어 두 개 각각에 대해 3차원짜리 벡터가 멀티-헤드 어텐션의 최종 결과물로 도출되었습니다. 멀티 헤드 어텐션은 인코더, 디코더 블록 모두에 적용됩니다. 앞으로 특별한 언급이 없다면 셀프 어텐션은 멀티 헤드 어텐션인 것으로 이해하면 좋겠습니다.
인코더에서 수행하는 셀프 어텐션
이번엔 트랜스포머 인코더에서 수행하는 계산 과정을 셀프 어텐션을 중심으로 살펴보겠습니다. 그림10은 트랜스포머 인코더 블록을 나타낸 그림인데요. 기억을 떠올리고자 다시 가져와 봤습니다. 인코더 블록의 입력은 이전 블록의 단어 벡터 시퀀스, 출력은 이번 블록 수행 결과로 도출된 단어 벡터 시퀀스입니다.
그림10 트랜스포머 인코더 블록
인코더에서 수행되는 셀프 어텐션은 쿼리, 키, 밸류가 모두 소스 시퀀스와 관련된 정보입니다. 트랜스포머의 학습 과제가 한국어에서 영어로 번역하는 태스크라면, 인코더의 쿼리, 키, 밸류는 모두 한국어가 된다는 이야기입니다.
그림11은 쿼리가 어제
인 경우의 셀프 어텐션을 나타냈습니다. 잘 학습된 트랜스포머라면 쿼리, 키로부터 계산한 소프트맥스 확률(코드5의 attn_scores_softmax
에 대응) 가운데 과거 시제에 해당하는 갔었어
, 많더라
등의 단어가 높은 값을 지닐 겁니다. 이 확률값들과 밸류 벡터를 가중합해서 셀프 어텐션 계산을 마칩니다.
그림11 쿼리가 ‘어제’일 때 셀프 어텐션
그림12는 쿼리가 카페
인 경우의 셀프 어텐션을 나타냈습니다. 잘 학습된 트랜스포머라면 쿼리, 키로부터 계산한 소프트맥스 확률 가운데 장소를 지칭하는 대명사 거기
가 높은 값을 지닐 겁니다. 이 확률값들과 밸류 벡터를 가중합해서 셀프 어텐션 계산을 마칩니다.
그림12 쿼리가 ‘카페’일 때 셀프 어텐션
이같은 계산을 갔었어
, 거기
, 사람
, 많더라
에 대해서도 수행합니다. 결국 인코더에서 수행하는 셀프 어텐션은 소스 시퀀스 내의 모든 단어 쌍(pair) 사이의 관계를 고려하게 됩니다.
디코더에서 수행하는 셀프 어텐션
그림13은 인코더와 디코더 블록을 나타낸 그림입니다. 그림13에서도 확인할 수 있듯 디코더 입력은 ① 인코더 마지막 블록에서 나온 소스 단어 벡터 시퀀스 ② 이전 디코더 블록의 수행 결과로 도출된 타깃 단어 벡터 시퀀스입니다.
그림13 인코더-디코더
그러면 디코더에서 수행되는 셀프 어텐션을 순서대로 살펴보겠습니다. 우선 마스크 멀티 헤드 어텐션(Masked Multi-Head Attention)입니다. 이 모듈에서는 타깃 언어의 단어 벡터 시퀀스를 계산 대상으로 합니다. 한국어를 영어로 번역하는 태스크를 수행하는 트랜스포머 모델이라면 여기서 계산되는 대상은 영어 단어 시퀀스가 됩니다.
이 파트에서는 입력 시퀀스가 타깃 언어(영어)로 바뀌었을 뿐 인코더 쪽 셀프 어텐션과 크게 다를 바가 없습니다. 그림14는 쿼리가 cafe
인 경우의 마스크 멀티 헤드 어텐션을 나타낸 것입니다. 학습이 잘 되었다면 쿼리, 키로부터 계산한 소프트맥스 확률(코드5의 attn_scores_softmax
에 대응) 가운데 장소를 지칭하는 대명사 There
가 높은 값을 지닐 겁니다. 이 확률값들과 밸류 벡터를 가중합해서 셀프 어텐션 계산을 마칩니다.
그림14 타깃 문장의 셀프 어텐션
그 다음은 멀티 헤드 어텐션입니다. 인코더와 디코더 쪽 정보를 모두 활용합니다. 인코더에서 넘어온 정보는 소스 언어의 문장(어제 카페 갔었어 거기 사람 많더라
)의 단어 벡터 시퀀스입니다. 디코더 정보는 타깃 언어 문장(<s> I went to the cafe yesterday There ...
)의 단어 벡터 시퀀스입니다. 전자를 키, 후자를 쿼리로 삼아 셀프 어텐션 계산을 수행합니다.
그림15는 쿼리 단어가 cafe
인 멀티 헤드 어텐션 계산을 나타낸 것입니다. 학습이 잘 되었다면 쿼리(타깃 언어 문장), 키(소스 언어 문장)로부터 계산한 소프트맥스 확률(코드5의 attn_scores_softmax
에 대응) 가운데 쿼리에 대응하는 해당 장소를 지칭하는 단어 카페
가 높은 값을 지닐 겁니다. 이 확률값들과 밸류 벡터를 가중합해서 셀프 어텐션 계산을 마칩니다.
그림15 소스-타깃 문장 간 셀프 어텐션
그런데 학습 과정에서는 약간의 트릭을 씁니다. 트랜스포머 모델의 최종 출력은 타겟 시퀀스 각각에 대한 확률 분포인데요. 모델이 한국어를 영어로 번역하는 태스크를 수행하고 있다면 영어 문장의 다음 단어가 어떤 것이 적절할지에 관한 확률이 됩니다.
예컨대 인코더에 어제 카페 갔었어 거기 사람 많더라
가, 디코더에 <s>
가 입력된 상황이라면 트랜스포머 모델은 다음 영어 단어 I
를 맞추도록 학습됩니다. 하지만 학습 과정에서 모델에 이번에 맞춰야할 정답인 I
를 알려주게 되면 학습하는 의미가 사라집니다.
따라서 정답을 포함한 미래 정보를 셀프 어텐션 계산에서 제외하게 됩니다. 이 때문에 디코더 블록의 첫번째 어텐션을 마스크 멀티-헤드 어텐션(Masked Multi-Head Attention)이라고 부릅니다. 그림16과 같습니다. 마스킹은 확률(코드5의 attn_scores_softmax
에 대응)이 0이 되도록 하여, 밸류와의 가중합에서 해당 단어 정보들이 무시되게끔 하는 방식으로 수행됩니다.
그림16 학습 시 디코더에서 수행되는 셀프 어텐션 (1)
그림16처럼 셀프 어텐션을 수행하면 디코더 마지막 블록 출력 벡터 가운데 <s>
에 해당하는 벡터에는 소스 문장 전체의 문맥적 관계성이 함축되어 있습니다. 트랜스포머 모델은 이 <s>
벡터를 가지고 I
를 맞추도록 학습합니다. 다시 말해 정답 I
에 관한 확률은 높이고 다른 단어들의 확률은 낮아지도록 합니다. 그림17과 같습니다.
그림17 Model Update (1)
그림18은 인코더에 어제 카페 갔었어 거기 사람 많더라
가, 디코더에 <s> I
가 입력된 상황입니다. 따라서 이때의 마스크 멀티-헤드 어텐션은 정답 단어 went
이후의 모든 타겟 언어 단어들을 모델이 보지 못하도록 하는 방식으로 수행됩니다.
그림18 학습 시 디코더에서 수행되는 셀프 어텐션 (2)
디코더 마지막 블록의 I
벡터에는 소스 문장(어제
… 갔더라
)과 <s> I
사이의 문맥적 관계성이 녹아 있습니다. 트랜스포머 모델은 이 I
벡터를 가지고 went
를 맞히도록 학습합니다. 다시 말해 정답 went
에 관한 확률은 높이고 다른 단어들의 확률은 낮아지도록 합니다. 그림19와 같습니다.
그림19 Model Update (2)
그림20은 인코더에 어제 카페 갔었어 거기 사람 많더라
가, 디코더에 <s> I went
가 입력된 상황입니다. 따라서 이때의 마스크 멀티-헤드 어텐션은 정답 단어 to
이후의 모든 타깃 언어 단어들을 모델이 보지 못하도록 하는 방식으로 수행됩니다.
그림20 Masked Attention (3)
디코더 마지막 블록의 went
벡터에는 소스 문장과 <s> I went
사이의 문맥적 관계성이 녹아 있습니다. 트랜스포머 모델은 이 went
에 해당하는 벡터를 가지고 to
를 맞추도록 학습합니다. 다시 말해 정답 to
에 관한 확률은 높이고 다른 단어들의 확률은 낮아지도록 합니다. 그림21과 같습니다.
그림21 Model Update (3)
트랜스포머 모델은 이런 방식으로 말뭉치 전체를 훑어가면서 반복 학습합니다. 학습을 마친 모델은 다음처럼 기계 번역을 수행(인퍼런스)합니다.
- 소스 언어(한국어) 문장을 인코더에 입력해 인코더 마지막 블록의 단어 벡터 시퀀스를 추출합니다.
- 인코더에서 넘어온 소스 언어 문장 정보와 디코더에 타깃 문장 시작을 알리는 스페셜 토큰
<s>
를 넣어서, 타깃 언어(영어)의 첫 번째 토큰을 생성합니다. - 인코더 쪽에서 넘어온 소스 언어 문장 정보와 이전에 생성된 타깃 언어 토큰 시퀀스를 디코더에 넣어서 만든 정보로 타깃 언어의 다음 토큰을 생성합니다.
- 생성된 문장 길이가 충분하거나 문장 끝을 알리는 스페셜 토큰
</s>
가 나올 때까지 3을 반복합니다.
한편 </s>
는 보통 타깃 언어 문장 맨 마지막에 붙여서 학습합니다. 이 토큰이 나타났다는 것은 모델이 타깃 문장 생성을 마쳤다는 의미입니다.