๋ฐ์ํ
1. Attention
Attention์ CS ๋ฐ ML์์ ์ค์ํ ๊ฐ๋ ์ค ํ๋๋ก ์ฌ๊ฒจ์ง๋๋ค. Attention์ ๋งค์ปค๋์ฆ์ ์ฃผ๋ก Sequence Data๋ฅผ ์ฒ๋ฆฌํ๊ฑฐ๋ ์์ฑํ๋ ๋ชจ๋ธ์์ ์ฌ์ฉ๋ฉ๋๋ค. -> Sequence ์ ๋ ฅ์ ์ํํ๋ ๋จธ์ ๋ฌ๋ ํ์ต ๋ฐฉ๋ฒ์ ์ผ์ข
- Attention์ ๊ฐ๋ ์ Decoder์์ ์ถ๋ ฅ์ ์์ธกํ๋ ๋งค์์ (time step)๋ง๋ค, Encoder์์์ ์ ์ฒด์ ์ ๋ ฅ ๋ฌธ์ฅ์ ๋ค์ ํ๋ฒ ์ฐธ๊ณ ํ๊ฒ ํ๋ ๊ธฐ๋ฒ์ ๋๋ค.
- ๋จ, ์ ์ฒด ์
๋ ฅ ๋ฌธ์ฅ์ ์ ๋ถ ๋ค ์ข
์ผํ ๋น์จ๋ก ์ฐธ๊ณ ํ๋ ๊ฒ์ด ์๋๋ผ, ํด๋น ์์ ์์ ์์ธกํด์ผ ํ ์์์ ์ฐ๊ด์ด ์๋ ์
๋ ฅ ์์ ๋ถ๋ถ์ Attention(์ง์ค)ํด์ ๋ณด๊ฒ ํฉ๋๋ค.
- ์ด ๋ฐฉ๋ฒ์ด ๋ฌธ๋งฅ์ ํ์ ํ๋ ํต์ฌ์ ๋ฐฉ๋ฒ์ด๋ฉฐ, ์ด๋ฌํ ๋ฐฉ์์ DL(๋ฅ๋ฌ๋)๋ชจ๋ธ์ ์ ์ฉํ๊ฒ์ด 'Attention" ๋ฉ์ปค๋์ฆ์ ๋๋ค.
- ์ฆ, ์ด๋ง์ Task ์ํ์ ํฐ ์ํฅ์ ์ฃผ๋ Element(์์)์ Weight(๊ฐ์ค์น)๋ฅผ ํฌ๊ฒ ์ฃผ๊ณ , ๊ทธ๋ ์ง ์์ Weight(๊ฐ์ค์น)๋ ๋ฎ๊ฒ ์ค๋๋ค.
Attention์ด ๋ฑ์ฅํ ์ด์ ?
- Attention ๊ธฐ๋ฒ์ด ๋ฑ์ฅํ ์ด์ ์ ๋ํด์ ํ๋ฒ ์์๋ณด๋ฉด ์์ ๋ดค๋ ๊ธ์ค์ Seq2Seq Model์ ๋ํ ๊ธ์ด ์์์ต๋๋ค.
- ๊ฐ๋ ์ ๋ฐํ์ฌ ์ค๋ช ์ ํด๋ณด๋ฉด, Seq2Seq๋ Encoder์์ Input Sequence(์ ๋ ฅ ์ํ์ค)๋ฅผ ํ๋์ ๊ณ ์ ๋ ํฌ๊ธฐ์ Vector๋ก ์์ถํฉ๋๋ค. ์ด๋ฅผ Context Vector๋ผ๊ณ ํ๊ณ , Decoder๋ ์ด Context Vector๋ฅผ ํ์ฉํด์ Output Sequence(์ถ๋ ฅ ์ํ์ค)๋ฅผ ๋ง๋ญ๋๋ค.
- ๊ทผ๋ฐ, Seq2Seq ๋ชจ๋ธ์ RNN ๋ชจ๋ธ์ ๊ธฐ๋ฐํ๊ณ ์์ต๋๋ค. ๊ทธ๋ฌ๋ฉด ๋ฌธ์ ๊ฐ ๋ฐ์ํ๋๋ฐ
- ํ๋๋ Context Vector์ ํฌ๊ธฐ๊ฐ ๊ณ ์ ๋์ด ์์ผ๋ฏ๋ก ๊ณ ์ ๋ ํฌ๊ธฐ์ Vector์ Input์ผ๋ก ๋ค์ด์ค๋ ์ ๋ณด๋ฅผ ์์ถํ๋ ค๊ณ ํ๋๊น ์ ๋ณด์์ค์ด ๋ฐ์ํ ์๋ ์๋ค๋์ .
- ๋ค๋ฅธ ํ๋๋ RNN์ด๋ RNN๊ธฐ๋ฐ ๋ชจ๋ธ๋ค์ ๊ณ ์ง์ ์ธ ๋ฌธ์ ๋ Gradient Vanishing(๊ธฐ์ธ๊ธฐ ์์ค)๋ฌธ์ ๊ฐ ์กด์ฌํ๊ธฐ ๋๋ฌธ์ ๋๋ค.
- ๊ทธ๋์ Seq2Seq ๋ชจ๋ธ์ Input Sequence(์ ๋ ฅ ์ํ์ค)๊ฐ ๊ธธ์ด์ง๋ฉด Output Sequence(์ถ๋ ฅ ์ํ์ค)์ ์ ํ๋๊ฐ ๋จ์ด์ง๋๋ค.
- ๊ทธ๋์ ์ ํ๋๋ฅผ ๋จ์ด์ง๋๊ฒ์ ์ด๋์ ๋ ๋ฐฉ์งํด์ฃผ๊ธฐ ์ํด์ ๋ฑ์ฅํ ๋ฐฉ๋ฒ์ด Attention ์ ๋๋ค.
Seq2seq, Encoder, Decoder ๊ฐ๋ ์ ์ ๋ฆฌํด ๋์ ๊ธ์ ๋งํฌํด ๋๊ฒ ์ต๋๋ค. ์ฐธ๊ณ ํด์ฃผ์ธ์!
2. Attention Function
- Attention ํจ์์ ๋ฐํ์ฌ ์์๋ณด๊ฒ ์ต๋๋ค. Attention์ ํจ์๋ก ํํํ๋ฉด ๋ค์๊ณผ ๊ฐ์ด ํํ๋ฉ๋๋ค.
Attention(Q, K, V) = Attention Value
Q = Query : t ์์ ์ Decoder Cell ์์์ Hidden state(์๋ ์ํ)
K = Keys : ๋ชจ๋ ์์ ์ Encoder Cell ์
์ Hidden state(์๋ ์ํ)๋ค
V = Values : ๋ชจ๋ ์์ ์ Encdoer Cell ์
์ Hidden state(์๋ ์ํ)๋ค
- Attention function์ ์ฃผ์ด์ง 'Query(์ฟผ๋ฆฌ)'์ ๋ํด์ ๋ชจ๋ 'Key(ํค)'์์ ์ ์ฌ๋๋ฅผ ๊ฐ๊ฐ ๊ตฌํฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ ๊ตฌํ ์ด ์ ์ฌ๋๋ฅผ ํค์ ๋งตํ ๋์ด์๋ ๊ฐ๊ฐ์ 'Value(๊ฐ)'์ ๋ฐ์ํฉ๋๋ค.
- ๊ทธ๋ฆฌ๊ณ ์ ์ฌ๋๊ฐ ๋ฐ์๋ 'Value(๊ฐ)'๋ค์ ๋ํํ Returnํฉ๋๋ค. ์ด Returnํ ๊ฐ์ 'Attention Value(์ดํ ์ ๊ฐ)'์ด๋ผ๊ณ ํฉ๋๋ค. ํ๋ฒ Attention์ ์์ ๋ค์ ๋ณด๊ฒ ์ต๋๋ค. Scaled Dot-Product Attention, Self-Attention, Multi-Head Attention3๊ฐ๋ฅผ ํ๋ฒ ๋ณด๊ฒ ์ต๋๋ค.
3. Dot-Product Attention
Attention์ ์ข ๋ฅ๊ฐ ์๋๋ฐ, ๊ทธ ์ค์ ํ๋์ธ Dot-Product Attention์ ๋ํด ์์๋ณด๊ฒ ์ต๋๋ค.
- Decoder์ 3๋ฒ์งธ LSTM Cell์์ ์ถ๋ ฅ ๋จ์ด๋ฅผ ์์ธกํ ๋, Attention ๋งค์ปค๋์ฆ์ ์ฌ์ฉํฉ๋๋ค.
- ์์ Decoder์ ์ฒซ, ๋๋ฒ์งธ ์ ์ ์ด๋ฏธ Attention ๋งค์ปค๋์ฆ์ ์ด์ฉํด์ ๋จ์ด๋ฅผ ์์ธกํ๋ ๊ณผ์ ์ ๊ฑฐ์ณค์ต๋๋ค.
- Decoder์ 3๋ฒ์งธ LSTM Cell์ ์ถ๋ ฅ ๋จ์ด๋ฅผ ์์ธกํ๊ธฐ ์ํ์ฌ Encoder์ ๋ชจ๋ ๋จ์ด๋ค์ ์ ๋ณด๋ฅผ ์ฐธ๊ณ ํ๋ ค๊ณ ํฉ๋๋ค.
- Encoder์ Softmax ํจ์๋ฅผ ํตํด ๋์จ ๊ฒฐ๊ณผ๊ฐ์ I, am, a, student ๋จ์ด ๊ฐ๊ฐ์ด ์ถ๋ ฅ์ด ์ผ๋ง๋ ๋์์ด ๋๋์ง ์ ๋๋ฅผ ์์นํ ํ ๊ฐ์ ๋๋ค.
- ์ฌ๊ธฐ์๋ Softmax ํจ์์์ ์๋ ์ง์ฌ๊ฐํ์ ํฌ๊ธฐ๊ฐ ํด์๋ก ๊ฒฐ๊ณผ๊ฐ์ด ๋จ์ด ๊ฐ๊ฐ์ ์ถ๋ ฅ์ ๋์์ด ๋๋ ์ ๋๋ฅผ ์์นํ ํ ํฌ๊ธฐ ์ ๋๋ค.
- ๊ฐ ์ ๋ ฅ ๋จ์ด๊ฐ Decoder ์์ธก์ ๋์์ด ๋๋ ์ ๋๋ฅผ ์์นํ ํ๋ฉด ์ด๋ฅผ ํ๋์ ์ ๋ณด๋ก ํฉ์นํ Decoder๋ก ๋ณด๋ ๋๋ค (๋งจ ์์ ์ด๋ก์ ์ผ๊ฐํ) -> ๊ฒฐ๋ก ์ Decoder๋ ์ถ๋ ฅ ๋จ์ด๋ฅผ ๋ ์ ํํ๊ฒ ์์ธกํ ํ๋ฅ ์ด ๋์์ง๋๋ค. ๊ตฌ์ ์ ์ผ๋ก ์ค๋ช ํด ๋ณด๊ฒ ์ต๋๋ค.
1. Attention Score ๊ตฌํ๊ธฐ
- Encoder์ Time Step(Encoder์ ์์ )์ ๊ฐ๊ฐ 1๋ถํฐ N๊น์ง ๋ผ๊ณ ํ์๋ Hidden State(์๋ ์ํ)๋ฅผ ๊ฐ๊ฐ h1, h2, ~ hn์ด๋ผ๊ณ ํ๊ณ , Decoder์ Time Step(ํ์ฌ ์์ ) t์์์ Decoder์ Hidden State(์๋ ์ํ)๋ฅผ st๋ผ๊ณ ํฉ๋๋ค.
- ์ฌ๊ธฐ์๋ Encoder์ Hidden State(์๋ ์ํ), Decoder์ Hidden State(์๋ ์ํ)์ ์ฐจ์์ด ๊ฐ๋ค๊ณ ๊ฐ์ ํด๋ณด๊ฒ ์ต๋๋ค.
- ๊ทธ๋ฌ๋ฉด Encoder, Decoder์ Hidden state(์๋ ์ํ)์ Dimension(์ฐจ์)์ ๋์ผํ๊ฒ 4๋ก ๊ฐ์ต๋๋ค.
- Attention ๋ฉ์ปค๋์ฆ์์๋ ์ถ๋ ฅ๋จ์ด๋ฅผ ์์ธกํ ๋ Attention Value(์ดํ ์ ๊ฐ)์ ํ์๋ก ํฉ๋๋ค. t๋ฒ์งธ ๋จ์ด๋ฅผ ์์ธกํ๊ธฐ ์ํ Decoder์ Cell์ 2๊ฐ์ input๊ฐ์ ํ์๋ก ํ๋๋ฐ, ์ด์ ์์ ์ธ t-1์ ์๋ ์ํ์ ์ด์ ์์ t-1์ ๋์จ ์ถ๋ ฅ ๋จ์ด์ ๋๋ค.
- ๊ทธ๋ฆฌ๊ณ Attention Value(์ดํ ์ ๊ฐ)์ ์ข ํฉํ์ฌ Decoder Cell์ ํ์ฌ ์์ (t)์ Output์ ์์ฑํฉ๋๋ค.
+ Hidden state(์๋ ์ํ): RNN(์ํ์ ๊ฒฝ๋ง), LSTM(์ฅ, ๋จ๊ธฐ ๋ฉ๋ชจ๋ฆฌ)์์ ์๊ฐ์ ์์กด์ฑ์ ํ์ต ๋ฐ ๋ชจ๋ธ์ ์ด์ ์ ๋ณด ์ ์ง์ ์ฌ์ฉ๋๋ ๋ด๋ถ ์ํ์ ๋๋ค.
Input Sequence๋ฅผ ํ๋ฒ์ ํ๋์ฉ ์ฒ๋ฆฌํ๋ฉด์ ๋ด๋ถ ์ํ๋ฅผ ์ ๋ฐ์ดํธ ํ๋๋ฐ, ์ด๋ ํ์ฌ ์์ ์ input ๋ฐ ์ด์ ์์ ์ hidden state(์๋ ์ํ)์ ์ํฅ์ ๋ฐ์ผ๋ฉฐ, ์ด์ ์ ๋ณด ์ ์ง ๋ฐ ํ์ฌ ์ ๋ ฅ์ ๋ฐ๋ผ ๋ณํํฉ๋๋ค.
ํ๋ฒ ํ์ฌ ์์ ์ ๋จ์ด์ธ t๋ฒ์งธ ๋จ์ด๋ฅผ ์์ธกํ๊ธฐ ์ํ Attention ๊ฐ์ 'at'๋ผ๊ณ ์ ์ํด๋ณด๊ฒ ์ต๋๋ค.
- ํ์ฌ ์์ ์ ๋จ์ด์ธ t๋ฒ์งธ ๋จ์ด๋ฅผ ์์ธกํ๊ธฐ ์ํ Attention ๊ฐ์ธ 'at' ๊ฐ์ ๊ตฌํ๋ ค๋ฉด Attention Score(์ดํ ์ ์ค์ฝ์ด)๋ฅผ ๊ตฌํด์ผ ํฉ๋๋ค.
- ์ฌ๊ธฐ์ Attention Score(์ดํ ์ ์ค์ฝ์ด)๋ ํ์ฌ Decoder์ ์์ t์์ ๋จ์ด๋ฅผ ์์ธกํ๊ธฐ ์ํด์ Encoder์ ๋ชจ๋ Hidden state(์๋ ์ํ) ๊ฐ๊ฐ์ด Decoder์ ํ์์ ์ Hidden state(์๋ ์ํ)๊ฐ์ธ 'st'์ ์ผ๋ง๋ ์ ์ฌํ์ง๋ฅผ ํ๋จํ๋ ๊ฐ์ ๋๋ค.
- Dot-Product Attention์์๋ Score๊ฐ์ ๊ตฌํ๊ธฐ ์ํ์ฌ Decoder์ ํ์์ ์ Hidden state(์๋ ์ํ)๊ฐ์ธ 'St'๋ฅผ Transpose(์ ์น)ํ๊ณ ๊ฐ Hidden state(์๋ ์ํ)์ Dot Product(๋ด์ )์ ์ํํฉ๋๋ค.
- ์๋ฅผ ๋ค์ด์ 'St'์ Encoder์ i๋ฒ์งธ Hidden state์ Attention Score์ ๊ณ์ฐ ๋ฐฉ๋ฒ์ ์๋์ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ต๋๋ค.
- Attention Score Function(์ดํ ์ ์ค์ฝ์ด ํจ์)๋ฅผ ์ ์ํด๋ณด๋ฉด ์๋ ์ผ์ชฝ์ ์์๊ณผ ๊ฐ์ต๋๋ค.
- ๊ทธ๋ฆฌ๊ณ Decoder์ ํ์์ ์ Hidden state(์๋ ์ํ)๊ฐ์ธ 'st'์ Encoder์ ๋ชจ๋ Hidden State(์๋ ์ํ)์ Attention Score์ ๋ชจ์๊ฐ์ 'et'๋ผ๊ณ ํ๊ณ , 'et'์ ์์์ ์๋์ ์ค๋ฅธ์ชฝ ์์๊ณผ ๊ฐ์ต๋๋ค.
2. Attention Distribution(์ดํ ์ ๋ถํฌ) ๊ตฌํ๊ธฐ by Softmax ํจ์
- 'et'์ Softmax Function(Softmax ํจ์)๋ฅผ ์ ์ฉํ์ฌ, ๋ชจ๋ ๊ฐ์ ํฉ์น๋ฉด 1์ด ๋๋ ํ๋ฅ ๋ถํฌ๋ฅผ ์ป์ด๋ ๋๋ค.
- ์ด๋ฅผ Attention Distribution(์ดํ ์ ๋ถํฌ)๋ผ๊ณ ํ๋ฉฐ, ๊ฐ๊ฐ์ ๊ฐ์ Attention Weight(์ดํ ์ ๊ฐ์ค์น)๋ผ๊ณ ํฉ๋๋ค.
- ์๋ฅผ ๋ค์ด์ Softmax Function(Softmax ํจ์)๋ฅผ ์ ์ฉํ์ฌ ์ป์ ์ถ๋ ฅ๊ฐ๋ค์ Attention ๊ฐ์ค์น๋ค์ ํฉ์ 1์ ๋๋ค.
- ์ด ๊ทธ๋ฆผ์์๋ ๊ฐ Encoder์ Hidden State(์๋ ์ํ)์์์ Attention Weight(์ดํ ์ ๊ฐ์ค์น)์ ํฌ๊ธฐ๋ฅผ ์ง์ฌ๊ฐํ์ ํฌ๊ธฐ๋ก ์๊ฐํ ํ์ฌ, Attention Weight(์ดํ ์ ๊ฐ์ค์น)๊ฐ ํด์๋ก ์ง์ฌ๊ฐํ์ด ํฝ๋๋ค.
- Decoder์ ์์ ์ธ 't'์์ Attention Weight(์ดํ ์ ๊ฐ์ค์น)์ ๋ชจ์๊ฐ์ธ Attention Distribution(์ดํ ์ ๋ถํฌ)๋ฅผ 'at'๋ผ๊ณ ํ ๋, 'at'๋ฅผ ์์ผ๋ก ์ ์ํ๋ฉด ์๋์ ์๊ณผ ๊ฐ์ต๋๋ค.
3. ๊ฐ Encoder์ Attention ๊ฐ์ค์น & Hidden state๋ฅผ ํฉ์ณ ์ดํ ์ ๊ฐ ๊ตฌํ๊ธฐ
- Attention์ ์ต์ข ๊ฒฐ๊ณผ ๊ฐ์ ์ป๊ธฐ ์ํด์๋ ๊ฐ Encoder์ Hidden state(์๋์ํ), Attention Weight(์ดํ ์ ๊ฐ์ค์น)๊ฐ๋ค์ ๊ณฑํ๊ณ , ๋ชจ๋ ๋ํฉ๋๋ค.
- ์ด๋ฌํ ๊ณผ์ ์ Weight Sum(๊ฐ์คํฉ)์ด๋ผ๊ณ ํ๋ฉฐ, ์๋์ ์์ Attention์ ์ต์ข ๊ฒฐ๊ณผ์ธ Attention Value(์ดํ ์ ๊ฐ)์ ๋ํ ์์ ๋ณด์ฌ์ค๋๋ค.
- ์ด Attention Value(์ดํ ์ ๊ฐ) 'at'๋ ์ข ์ข Encoder์ ๋ฌธ๋งฅ์ ํฌํจํ๋ฉฐ, Context Vector(์ปจํ ์คํธ ๋ฒกํฐ)๋ผ๊ณ ๋ ํฉ๋๋ค.
- Seq2seq์ Encoder์ ๋ง์ง๋ง Hidden state(์๋์ํ)๋ฅผ Context Vector(์ปจํ ์คํธ ๋ฒกํฐ)๋ก ๋ถ๋ฅด๋๊ฒ๊ณผ ๋ค๋ฆ ๋๋ค.
4. Attention Value & Decoder์ t์์ ์ Hidden state๋ฅผ ์ฐ๊ฒฐ (Attention ๋งค์ปค๋์ฆ์ ํต์ฌ)
- Attention Value(์ดํ ์ ๊ฐ)'at'๊ฐ ๊ตฌํด์ง๋ฉด Attention ๋ฉ์ปค๋์ฆ์ 'at'๋ฅผ 'st'์ ํฉ์ณ์(Concentrate) ํ๋์ Vector๋ก ๋ง๋ญ๋๋ค.
- ์ด ์์ ์ ๊ฒฐ๊ณผ๋ฅผ 'vt'๋ผ๊ณ ์ ์ํ๊ณ , ์ด 'vt'๋ฅผ y^ ์์ธก ์ฐ์ฐ์ input์ผ๋ก ์ฌ์ฉํ์ฌ Encoder์์ ์ป์ ์ ๋ณด๋ฅผ ํ์ฉํ์ฌ y^๋ฅผ ๋ ์ ํํ๊ฒ ์์ธกํ๋๋ฐ ๋์์ ์ค๋๋ค.
5. Output Layer ์ฐ์ฐ์ ์ ๋ ฅ๊ฐ ๊ณ์ฐ ๋ฐ ์ ๋ ฅ์ผ๋ก ์ฌ์ฉํ๊ธฐ
- 'vt'๋ฅผ ๋ฐ๋ก Output Layer(์ถ๋ ฅ์ธต)์ผ๋ก ๋ณด๋ด๊ธฐ ์ ์ ์ถ๋ ฅ์ธต ์ฐ์ฐ์ ํ๋ ๊ณผ์ ์ ๋๋ค.
- Weight(๊ฐ์ค์น) ํ๋ ฌ์ ๊ณฑํ๊ณ , tanh(ํ์ดํผ๋ณผ๋ฆญํ์ ํธ) ํจ์๋ฅผ ์ง๋๋๋ก ํด์ Output Layer ์ฐ์ฐ์ ํ๊ธฐ ์ํ Vector์ธ 'st'๋ฅผ ์ป์ต๋๋ค.
- Seq2seq๋ Output Layer(์ถ๋ ฅ์ธต)์ Input(์ ๋ ฅ)์ด t์์ ์ Hidden State(์๋ ์ํ)์ธ 'st'์์ง๋ง, Attention ๋งค์ปค๋์ฆ ์์๋ Output Layer(์ถ๋ ฅ์ธต)์ Input(์ ๋ ฅ)์ด 'st'๊ฐ ๋ฉ๋๋ค.
- ์์ผ๋ก๋ ์๋์ ์๊ณผ ๊ฐ์ต๋๋ค. 'Wc'๋ ํ์ต ๊ฐ๋ฅํ Weight(๊ฐ์ค์น) ํ๋ ฌ, 'bc'๋ bias(ํธํฅ)์ ๋๋ค.
- 'st'๋ฅผ Output Layer(์ถ๋ ฅ์ธต)์ Input(์ ๋ ฅ)์ผ๋ก ์ฌ์ฉํ์ฌ Predict Vector(์์ธก ๋ฒกํฐ)๋ฅผ ์ป์ต๋๋ค.
Dot-Product Attention Model Example Code
TF(Tensorflow)๋ก ๊ตฌํํ ์ฝ๋์ ๋๋ค.
import tensorflow as tf
def scaled_dot_product_attention(q, k, v, mask):
# Query์ Key ๊ฐ์ dot product ๊ณ์ฐ
matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k)
# dot product๋ฅผ ์ ๊ทํํ๊ธฐ ์ํด ์ค์ผ์ผ๋ง
dk = tf.cast(tf.shape(k)[-1], tf.float32)
scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
# ๋ง์ฝ mask๊ฐ ์ ๊ณต๋์๋ค๋ฉด, ์ ํจํ ๋ถ๋ถ์ ๋ํด์๋ง -1e9๋ก ๋ง์คํน
if mask is not None:
scaled_attention_logits += (mask * -1e9)
# ์ดํ
์
๊ฐ์ค์น ๊ณ์ฐ (์ํํธ๋งฅ์ค ํจ์ ์ ์ฉ)
attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (..., seq_len_q, seq_len_k)
# ์ดํ
์
๊ฐ์ค ํ๊ท ์ ์ฌ์ฉํ์ฌ ๊ฐ์ ๊ณ์ฐ
output = tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v)
return output, attention_weights
- ์ด ์ฝ๋๋ Dot-Product Attention์ ๊ตฌํํ ์ฝ๋์ ๋๋ค.
- ์์๋ ์ฃผ์ด์ง Query(์ฟผ๋ฆฌ, 'q'), Key(ํค, 'k') ๊ฐ์ Product๋ฅผ Keyํ๋ ฌ์ ์ ์นํ์ฌ Dot Product๋ฅผ ๊ณ์ฐํฉ๋๋ค.
- Dot Product๋ฅผ Key์ ์ฐจ์์๋ก ๋๋ ์ฃผ์ด ์ ๊ทํ ํฉ๋๋ค.
- Scaling์ ํตํด Gradient Vanishing / Exploding (๊ธฐ์ธ๊ธฐ ์์ค, ํญํ)์ ๋ฐฉ์งํฉ๋๋ค.
- Masking์ ์ ํจํ ๋ถ๋ถ์๋ง 'Attention'์ ์ ์ฉํฉ๋๋ค. -> ์ ํ์ ์ผ๋ก ์ดํ ์ ๋ง์คํฌ๋ฅผ ์ ์ฉํ๋ฉฐ, ์ ํจํ์ง ์์ ๋ถ๋ถ์๋ -1e9๋ฅผ ๋ํ์ฌ ๋ง์คํน ํฉ๋๋ค.
- Softmax ํจ์๋ฅผ ์ ์ฉํ์ฌ Attention Weight(์ดํ ์ ๊ฐ์ค์น)๋ฅผ ๊ณ์ฐํฉ๋๋ค. ์ฌ๊ธฐ์ Weight(๊ฐ์ค์น)๋ Query(์ฟผ๋ฆฌ), Key(ํค)๊ฐ์ ์ค์์ฑ์ ๋ํ๋ ๋๋ค.
- Attention Weight(๊ฐ์ค์น)๋ฅผ ์ฌ์ฉํ์ฌ Value(๋ฐธ๋ฅ)ํ๋ ฌ๊ณผ ํ๊ท ๊ฐ์ ๊ณ์ฐํ๊ณ 'output'์ ์ถ๋ ฅํฉ๋๋ค.
- 'output'์ Query์ ๋ํ Attention ๊ฐ์ค ํ๊ท ๊ฐ์ ๋๋ค. ๊ทธ๋ฆฌ๊ณ 'output'๊ณผ Attention Weight(๊ฐ์ค์น)๋ฅผ ๋ฐํํฉ๋๋ค.
4. Self-Attention
Self-Attention ๊ธฐ๋ฒ์ Attention ๊ธฐ๋ฒ์ ๋ง ๊ทธ๋๋ก ์๊ธฐ ์์ ์๊ฒ ์ํํ๋ Attention ๊ธฐ๋ฒ์ ๋๋ค.
- Input Sequence ๋ด์ ๊ฐ ์์๊ฐ์ ์๋์ ์ธ ์ค์๋๋ฅผ ๊ณ์ฐํ๋ ๋งค์ปค๋์ฆ์ด๋ฉฐ, Sequence์ ๋ค์ํ ์์น๊ฐ์ ์๋ก ๋ค๋ฅธ ๊ด๊ณ๋ฅผ ํ์ตํ ์ ์์ต๋๋ค.
- Sequence ์์ ๊ฐ์ด๋ฐ Task ์ํ์ ์ค์ํ Element(์์)์ ์ง์คํ๊ณ ๊ทธ๋ ์ง ์์ Element(์์)๋ ๋ฌด์ํฉ๋๋ค.
- ์ด๋ฌ๋ฉด Task๊ฐ ์ํํ๋ ์ฑ๋ฅ์ด ์์นํ ๋ฟ๋๋ฌ Decoding ํ ๋ Source Sequence ๊ฐ์ด๋ฐ ์ค์ํ Element(์์)๋ค๋ง ์ถ๋ฆฝ๋๋ค.
- ๋ฌธ๋งฅ์ ๋ฐ๋ผ ์ง์คํ ๋จ์ด๋ฅผ ๊ฒฐ์ ํ๋ ๋ฐฉ์์ ์๋ฏธ -> ์ค์ํ ๋จ์ด์๋ง ์ง์ค์ ํ๊ณ ๋๋จธ์ง๋ ๊ทธ๋ฅ ์ฝ์ต๋๋ค.
- ์ด ๋ฐฉ๋ฒ์ด ๋ฌธ๋งฅ์ ํ์ ํ๋ ํต์ฌ์ด๋ฉฐ, ์ด๋ฌํ ๋ฐฉ์์ Deep Learning ๋ชจ๋ธ์ ์ ์ฉํ๊ฒ์ด 'Attention' ๋งค์ปค๋์ฆ์ด๋ฉฐ, ์ด ๋งค์ปค๋์ฆ์ ์๊ธฐ ์์ ์๊ฒ ์ ์ฉํ๊ฒ์ด 'Self-Attention' ์ ๋๋ค. ํ๋ฒ ๋งค์ปค๋์ฆ์ ์์๋ฅผ ๋ค์ด์ ์ค๋ช ํด ๋ณด๊ฒ ์ต๋๋ค.
Self-Attention์ ๊ณ์ฐ ์์
Self-Attention์ Query(์ฟผ๋ฆฌ), Key(ํค), Value(๋ฐธ๋ฅ) 3๊ฐ์ง ์์๊ฐ ์๋ก ์ํฅ์ ์ฃผ๊ณ ๋ฐ๋ ๊ตฌ์กฐ์ ๋๋ค.
- ๋ฌธ์ฅ๋ด ๊ฐ ๋จ์ด๊ฐ Vector(๋ฒกํฐ) ํํ๋ก input(์ ๋ ฅ)์ ๋ฐ์ต๋๋ค.
* Vector: ์ซ์์ ๋์ด ์ ๋
- ๊ฐ ๋จ์ด์ Vector๋ 3๊ฐ์ง ๊ณผ์ ์ ๊ฑฐ์ณ์ ๋ฐํ์ด ๋ฉ๋๋ค.
- Query(์ฟผ๋ฆฌ) - ๋ด๊ฐ ์ฐพ๊ณ ์ ๋ณด๋ฅผ ์์ฒญํ๋๊ฒ ์ ๋๋ค.
- Key(ํค) - ๋ด๊ฐ ์ฐพ๋ ์ ๋ณด๊ฐ ์๋ ์ฐพ์๋ณด๋ ๊ณผ์ ์ ๋๋ค.
- Value(๋ฐธ๋ฅ) - ์ฐพ์์ ์ ๊ณต๋ ์ ๋ณด๊ฐ ๊ฐ์น ์๋์ง ํ๋จํ๋ ๊ณผ์ ์ ๋๋ค.
- ์์ ๊ทธ๋ฆผ์ ๋ณด๋ฉด ์ ๋ ฅ๋๋ ๋ฌธ์ฅ "์ด์ ์นดํ ๊ฐ์์ด ๊ฑฐ๊ธฐ ์ฌ๋ ๋ง๋๋ผ" ์ด 6๊ฐ ๋จ์ด๋ก ๊ตฌ์ฑ๋์ด ์๋ค๋ฉด?
- ์ฌ๊ธฐ์์ Self-Attention ๊ณ์ฐ ๋์์ Query(์ฟผ๋ฆฌ) Vector 6๊ฐ, Key(ํค) Vector 6๊ฐ, Value(๋ฐธ๋ฅ) Vector 6๊ฐ๋ฑ ๋ชจ๋ 18๊ฐ๊ฐ ๋ฉ๋๋ค.
- ์์ ํ๋ ๋ ์ธ๋ถ์ ์ผ๋ก ๋ํ๋ธ๊ฒ์ ๋๋ค. Self-Attention์ Query ๋จ์ด ๊ฐ๊ฐ์ ๋ํด ๋ชจ๋ Key ๋จ์ด์ ์ผ๋ง๋ ์ ๊ธฐ์ ์ธ ๊ด๊ณ๋ฅผ ๋งบ๊ณ ์๋์ง์ ํ๋ฅ ๊ฐ๋ค์ ํฉ์ด 1์ธ ํ๋ฅ ๊ฐ์ผ๋ก ๋ํ๋ ๋๋ค.
- ์ด๊ฒ์ ๋ณด๋ฉด Self-Attention ๋ชจ๋์ Value(๋ฐธ๋ฅ) Vector๋ค์ Weighted Sum(๊ฐ์คํฉ)ํ๋ ๋ฐฉ์์ผ๋ก ๊ณ์ฐ์ ๋ง๋ฌด๋ฆฌ ํฉ๋๋ค.
- ํ๋ฅ ๊ฐ์ด ๊ฐ์ฅ ๋์ ํค ๋จ์ด๊ฐ ์ฟผ๋ฆฌ ๋จ์ด์ ๊ฐ์ฅ ๊ด๋ จ์ด ๋์ ๋จ์ด๋ผ๊ณ ํ ์ ์์ต๋๋ค.
- ์ฌ๊ธฐ์๋ '์นดํ'์ ๋ํด์๋ง ๊ณ์ฐ ์์๋ฅผ ๋ค์์ง๋ง, ์ด๋ฌํ ๋ฐฉ์์ผ๋ก ๋๋จธ์ง ๋จ์ด๋ค์ค Self-Attention์ ๊ฐ๊ฐ ์ํํฉ๋๋ค.
Self-Attention์ ๋์ ๋ฐฉ์
Self-Attention์์ ๊ฐ์ฅ ์ค์ํ ๊ฐ๋ ์ Query(์ฟผ๋ฆฌ), Key(ํค), Value(๋ฐธ๋ฅ)์ ์์๊ฐ์ด ๋์ผํ๋ค๋ ์ ์ ๋๋ค.
- ๊ทธ๋ ๋ค๊ณ Query(์ฟผ๋ฆฌ), Key(ํค), Value(๋ฐธ๋ฅ)๊ฐ ๋์ผ ํ๋ค๋ ๋ง์ด ์ด๋๋๋ค.
- ๊ทธ๋ฆผ์ ๋ณด๋ฉด ๊ฐ์ค์น Weight W๊ฐ์ ์ํด์ ์ต์ข ์ ์ธ Query(์ฟผ๋ฆฌ), Key(ํค), Value(๋ฐธ๋ฅ)๊ฐ์ ์๋ก ๋ฌ๋ผ์ง๋๋ค.
Attention์ ๊ตฌํ๋ ๊ณต์์ ๋ณด๊ฒ ์ต๋๋ค.
- ์ผ๋จ Query(์ฟผ๋ฆฌ)๋ Key(ํค)๋ฅผ ๋ด์ ํด์ค๋๋ค. ์ด๋ ๊ฒ ๋ด์ ์ ํด์ฃผ๋ ์ด์ ๋ ๋ ์ฌ์ด์ ์ฐ๊ด์ฑ์ ๊ณ์ฐํ๊ธฐ ์ํด์์ ๋๋ค.
- ์ด ๋ด์ ๋ ๊ฐ์ "Attention Score"๋ผ๊ณ ํฉ๋๋ค. Dot-Product Attention ๋ถ๋ถ์์ ์์ธํ ์ค๋ช ํ์ง๋ง ์ด๋ฒ์๋ ๊ฐ๋จํ๊ฒ ์ค๋ช ํด ๋ณด๊ฒ ์ต๋๋ค.
- ๋ง์ฝ, Query(์ฟผ๋ฆฌ)๋ Key(ํค)์ Dimension(์ฐจ์)์ด ์ปค์ง๋ฉด, ๋ด์ ๊ฐ์ Attention Score๊ฐ ์ปค์ง๊ฒ ๋์ด์ ๋ชจ๋ธ์ด ํ์ตํ๋๋ฐ ์ด๋ ค์์ด ์๊น๋๋ค.
- ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด์ ์ฐจ์ d_k์ ๋ฃจํธ๋งํผ ๋๋์ด์ฃผ๋ Scaling ์์ ์ ์งํํฉ๋๋ค. ์ด๊ณผ์ ์ "Scaled Dot-Product Attention" ์ด๋ผ๊ณ ํฉ๋๋ค.
- ๊ทธ๋ฆฌ๊ณ Scaled Dot-Product Attention"์ ์งํํ ๊ฐ์ ์ ๊ทํ๋ฅผ ์งํํด์ฃผ๊ธฐ ์ํด์ Softmax ํจ์๋ฅผ ๊ฑฐ์ณ์ ๋ณด์ ์ ์ํด ๊ณ์ฐ๋ scoreํ๋ ฌ, valueํ๋ ฌ์ ๋ด์ ํฉ๋๋ค. ๊ทธ๋ฌ๋ฉด ์ต์ข ์ ์ผ๋ก Attention ํ๋ ฌ์ ์ป์ ์ ์๊ฒ ๋ฉ๋๋ค. ์์ ๋ฌธ์ฅ์ ๋ค์ด์ ์ค๋ช ํด ๋ณด๊ฒ ์ต๋๋ค.
"I am a student"๋ผ๋ ๋ฌธ์ฅ์ผ๋ก ์์๋ฅผ ๋ค์ด๋ณด๊ฒ ์ต๋๋ค.
- Self-Attention์ Query(์ฟผ๋ฆฌ), Key(ํค), Value(๋ฐธ๋ฅ) 3๊ฐ ์์ ์ฌ์ด์ ๋ฌธ๋งฅ์ ๊ด๊ณ์ฑ์ ์ถ์ถํฉ๋๋ค.
Q = X * Wq, K = X * Wk, W = X * Wv
- ์์ ์์์ฒ๋ผ Input Vector Sequence(์ ๋ ฅ ๋ฒกํฐ ์ํ์ค) X์ Query(์ฟผ๋ฆฌ), Key(ํค), Value(๋ฐธ๋ฅ)๋ฅผ ๋ง๋ค์ด์ฃผ๋ ํ๋ ฌ(W)๋ฅผ ๊ฐ๊ฐ ๊ณฑํด์ค๋๋ค.
- Input Vector Sequence(์ ๋ ฅ ๋ฒกํฐ ์ํ์ค)๊ฐ 4๊ฐ์ด๋ฉด ์ผ์ชฝ์ ์๋ ํ๋ ฌ์ ์ ์ฉํ๋ฉด Query(์ฟผ๋ฆฌ), Key(ํค), Value(๋ฐธ๋ฅ) ๊ฐ๊ฐ 4๊ฐ์ฉ, ์ด 12๊ฐ๊ฐ ๋์ต๋๋ค.
* Word Embedding: ๋จ์ด๋ฅผ Vector๋ก ๋ณํํด์ Dense(๋ฐ์ง)ํ Vector๊ณต๊ฐ์ Mapping ํ์ฌ ์ค์ Vector๋ก ํํํฉ๋๋ค.
- ๊ฐ ๋จ์ด์ ๋ฐํ์ฌ Word Embedding(๋จ์ด ์๋ฒ ๋ฉ)์ ํฉ๋๋ค. ๋จ์ด 'i'์ Embedding์ด [1,1,1,1]์ด๋ผ๊ณ ํ์ ๋, ์ฒ์ 'i'์ ์ฒ์ Query(์ฟผ๋ฆฌ), Key(ํค), Value(๋ฐธ๋ฅ)๋ฅผ ๊ฐ๊ฐ 'Q_i, original', 'K_i, original', 'V_i, original', ๋ผ๊ณ ํฉ๋๋ค.
- Embedding์ ๊ฐ์ด ๋ค [1,1,1,1]๋ก ๊ฐ์ ์ด์ ๋ Self-Attention ๋งค์ปค๋์ฆ์์๋ ๊ฐ์์ผ ํ๊ธฐ ๋๋ฌธ์ ๋ชจ๋ [1,1,1,1]์ด๋ผ๊ณ ๋์ผํฉ๋๋ค.
- ํ์ต๋ Weight(๊ฐ์ค์น)๊ฐ์ด 'WQ', 'WK', 'WV'๋ผ๊ณ ํ ๋ Original ๊ฐ๋ค๊ณผ ์ ๊ณฑ์ ํด์ฃผ๋ฉด ์ต์ข
์ ์ผ๋ก 'Q', 'K', 'V'๊ฐ์ด ๋์ถ๋ฉ๋๋ค.
- 'Q', 'K', 'V'๊ฐ์ ์ด์ฉํด์ ์์์ ์ค์ ํ ๋ณด์ ๋ 'Attention Score'๋ฅผ ๊ณฑํด์ฃผ๋ฉด ์๋์ ์ผ์ชฝ์๊ณผ ๊ฐ์ด 1.5๋ผ๋ ๊ฐ์ด ๋์ต๋๋ค.
- ํ๋ ฌ 'Q', 'K'๋ ์๋ก ์ ๊ณฑ ๊ณ์ฐํด์ฃผ๊ณ , ์ฌ๊ธฐ์ ํ๋ ฌ 'Q', 'K', 'V'์ Dimension(์ฐจ์)์ 4์ด๋ฏ๋ก ๋ฃจํธ 4๋ก ๋๋์ด์ค๋๋ค.
- 'i' ๋ฟ๋ง ์๋๋ผ ๋ชจ๋ ๋จ์ด๊ฐ์ 'Self-Attention'์ ํด์ฃผ๋ฉด ์์ ์ค๋ฅธ์ชฝ ํ๋ ฌ๊ณผ ๊ฐ์ ๊ฒฐ๊ณผ๊ฐ ๋์ต๋๋ค.
- ๊ฐ์ด๋ฐ ๋ ธ๋ฝ์ ๋ถ๋ถ์ ์๊ธฐ ์์ ์ ๋ํ 'Attention'์ด๋ฏ๋ก ๋น์ฐํ ๊ฐ์ด ์ ์ผ ํฌ๊ณ , ์์ชฝ ์ด๋ก์ ๋ถ๋ถ์ ๋ณด๋ฉด ์ ์๊ฐ ๋์ต๋๋ค.
- ์์ ๊ทธ๋ฆผ์ ๋จ์ด ํ๋ํ๋์ 'Attention'์ ๊ตฌํ๋ ๊ณผ์ ์ ๋์ํ ํ ๊ทธ๋ฆผ์ ๋๋ค.
- ์ค์ ๋ก๋ ์ฌ๋ฌ ๋จ์ด๋ฅผ ์์ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ํด์ ๊ณ์ฐ์ ํฉ๋๋ค.
- ๋ณ๋ ฌ์ฒ๋ฆฌ๋ฅผ ํ๋ฉด ์ฐ์ฐ์๋๊ฐ ๋นจ๋ฆฌ์ง๋ ์ด์ ์ด ์๊ธฐ ๋๋ฌธ์ ๋๋ค. ๊ฐ๋จํ 'Self-Attention' ๊ณผ์ ์ ์์ฝํด์ ์ค๋ช ํด๋ณด๊ฒ ์ต๋๋ค.
Self-Attention ๊ณผ์ Summary
1. ์ํ๋ ๋ฌธ์ฅ์ ์๋ฒ ๋ฉํ๊ณ ํ์ต์ ํตํด ๊ฐ Query, Key, Value์ ๋ง๋ weight๋ค์ ๊ตฌํด์ค.
2. ๊ฐ ๋จ์ด์ ์๋ฒ ๋ฉ์ Query, Key, Value(Query = Key = Value)์ weight๋ฅผ ์ ๊ณฑ(๋ด์ )ํด ์ต์ข Q, K, V๋ฅผ ๊ตฌํจ.
3. Attention score ๊ณต์์ ํตํด ๊ฐ ๋จ์ด๋ณ Self Attention value๋ฅผ ๋์ถ
4. Self Attention value์ ๋ด๋ถ๋ฅผ ๋น๊ตํ๋ฉด์ ์๊ด๊ด๊ณ๊ฐ ๋์ ๋จ์ด๋ค์ ๋์ถ
Self-Attention Example Code
PyTorch ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ๋ถ๋ฌ์์ ํ๋ฒ ์์ ์ฝ๋๋ฅผ ๋๋ ค๋ณด๊ฒ ์ต๋๋ค.
import torch
1. ๋ณ์์ ์
- PyTorch ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ํ์ฉํด์ ์ฝ๋๋ก ๋ณด๊ฒ ์ต๋๋ค. ์ฐ์ Input Vector Sequence X์ Query(์ฟผ๋ฆฌ), Key(ํค), Value(๋ฐธ๋ฅ)๊ตฌ์ถ์ ํ์ํ ํ๋ ฌ๋ค์ ์ ์ํฉ๋๋ค.
x = torch.tensor([
[1.0, 0.0, 1.0, 0.0],
[0.0, 2.0, 0.0, 2.0],
[1.0, 1.0, 1.0, 1.0],
])
w_query = torch.tensor([
[1.0, 0.0, 1.0],
[1.0, 0.0, 0.0],
[0.0, 0.0, 1.0],
[0.0, 1.0, 1.0]
])
w_key = torch.tensor([
[0.0, 0.0, 1.0],
[1.0, 1.0, 0.0],
[0.0, 1.0, 0.0],
[1.0, 1.0, 0.0]
])
w_value = torch.tensor([
[0.0, 2.0, 0.0],
[0.0, 3.0, 0.0],
[1.0, 0.0, 3.0],
[1.0, 1.0, 0.0]
])
2. Query(์ฟผ๋ฆฌ), Key(ํค), Value(๋ฐธ๋ฅ) ๋ง๋ค๊ธฐ
- Vector Sequence๋ก Query(์ฟผ๋ฆฌ), Key(ํค), Value(๋ฐธ๋ฅ)๋ฅผ ๋ง๋ญ๋๋ค.
- ์ฌ๊ธฐ์ torch.matmul์ ํ๋ ฌ๊ณฑ์ ์ํํ๋ ํจ์์ ๋๋ค.
keys = torch.matmul(x, w_key)
querys = torch.matmul(x, w_query)
values = torch.matmul(x, w_value)
3. Attention Score ๋ง๋ค๊ธฐ
- ์์์ ๋ง๋ Query(์ฟผ๋ฆฌ), Key(ํค) Vector๋ค์ ํ๋ ฌ๊ณฑํด์ Attention Score๋ฅผ ๋ง๋๋ ๊ณผ์ ์ ๋๋ค.
- ์ฌ๊ธฐ์ keys.T๋ Key(ํค) Vector๋ค์ Transpose(์ ์น)ํ ํ๋ ฌ์ ๋๋ค.
attn_scores = torch.matmul(querys, keys.T)
attn_scores
tensor([[ 2., 4., 4.],
[ 4., 16., 12.],
[ 4., 12., 10.]])
4. Softmax ํ๋ฅ ๊ฐ ๋ง๋ค๊ธฐ
- Key(ํค) Vector์ ์ฐจ์ฐ๋์์ ์ ๊ณฑ๊ทผ์ผ๋ก ๋๋ ์ค๋ค Softmax๋ฅผ ์ทจํ๋ ๊ณผ์ ์ ๋๋ค.
import numpy as np
from torch.nn.functional import softmax
key_dim_sqrt = np.sqrt(keys.shape[-1])
attn_scores_softmax = softmax(attn_scores / key_dim_sqrt, dim=-1)
- keys๋ Attention ๋ฉ์ปค๋์ฆ์์ ์ฌ์ฉ๋๋ ํค(key) ๊ฐ๋ค์ ๋ํ๋ด๋ Tensor์ ๋๋ค.
- keys.shape[-1]๋ keys Tensor์ ๋ง์ง๋ง Dimension(์ฐจ์)์ ํฌ๊ธฐ, ์ฆ ํค์ ์ฐจ์(dimension) ์๋ฅผ ๋ํ๋ ๋๋ค.
- key_dim_sqrt๋ ํค ์ฐจ์์ ์ ๊ณฑ๊ทผ์ ๋ํ ๊ฐ์ผ๋ก, ์ดํ ์ ์ค์ฝ์ด๋ฅผ ์ ๊ทํํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค.
- ์ ๊ทํ๋ ์ดํ ์ ์ค์ฝ์ด๋ฅผ ํค ์ฐจ์์ ์ ๊ณฑ๊ทผ์ผ๋ก ๋๋์ด์ฃผ๋ ๊ณผ์ ์ ๋๋ค.
- ์ด๋ ๊ฒ ํจ์ผ๋ก์จ ์ดํ ์ ์ ํฌ๊ธฐ์ ๋ฐ๋ผ ์ ๊ทํ๊ฐ ์ด๋ฃจ์ด์ง๋ฉฐ, ์๋ ด์ ๋๋ชจํ๊ณ ์์ ์ฑ์ ์ ๊ณตํฉ๋๋ค.
attn_scores_softmax
tensor([[1.3613e-01, 4.3194e-01, 4.3194e-01],
[8.9045e-04, 9.0884e-01, 9.0267e-02],
[7.4449e-03, 7.5471e-01, 2.3785e-01]])
5. Softmax ํ๋ฅ , Value๋ฅผ ๊ฐ์คํฉํ๊ธฐ
- Softmax ํ๋ฅ ๊ณผ Value Vector๋ค์ ๊ฐ์คํฉ ํ๋ ๊ณผ์ ์ ๋๋ค.
- Self-Attention์ ํ์ต ๋์์ Query(์ฟผ๋ฆฌ), Key(ํค) Value(๋ฐธ๋ฅ)๋ฅผ ๋ง๋๋ Weight(๊ฐ์ค์น) ํ๋ ฌ์ ๋๋ค.
- ์ฝ๋์์๋ 'w_query', 'w_key', 'w_value' ์ ๋๋ค.
- ์ด๋ค์ Task(์: ๊ธฐ๊ณ๋ฒ์ญ)๋ฅผ ๊ฐ์ฅ ์ ์ํํ๋ ๋ฐฉํฅ์ผ๋ก ํ์ต ๊ณผ์ ์์ ์ ๋ฐ์ดํธ ๋ฉ๋๋ค.
weighted_values = torch.matmul(attn_scores_softmax, values)
weighted_values
tensor([[1.8639, 6.3194, 1.7042],
[1.9991, 7.8141, 0.2735],
[1.9926, 7.4796, 0.7359]])
Self-Attention Model Example Code
PyTorch๋ฅผ ์ฌ์ฉํ์ฌ Self-Attention์ผ๋ก ๊ตฌํํ์์ต๋๋ค.
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SelfAttention, self).__init__()
self.embed_size = embed_size # ์๋ฒ ๋ฉ ์ฐจ์
self.heads = heads # ํค๋ ์
self.head_dim = embed_size // heads # ํค๋ ๋น ์๋ฒ ๋ฉ ์ฐจ์
assert (
self.head_dim * heads == embed_size
),
# ๊ฐ, ํค, ์ฟผ๋ฆฌ๋ฅผ ์ํ ์ ํ ๋ ์ด์ด
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
# ์ดํ
์
์ดํ ์ต์ข
์ถ๋ ฅ์ ์ํ ์ ํ ๋ ์ด์ด
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
def forward(self, values, keys, query, mask):
# ํ๋ จ ์์ ์ ํ์ธ
N = query.shape[0]
value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
# ์๋ฒ ๋ฉ์ ํค๋ ๋น ๋ค๋ฅธ ๋ถ๋ถ์ผ๋ก ๋๋๊ธฐ
values = values.reshape(N, value_len, self.heads, self.head_dim)
keys = keys.reshape(N, key_len, self.heads, self.head_dim)
queries = query.reshape(N, query_len, self.heads, self.head_dim)
# ๊ฐ, ํค, ์ฟผ๋ฆฌ์ ๋ํ ์ ํ ํ๋ก์ ์
values = self.values(values)
keys = self.keys(keys)
queries = self.queries(queries)
# ์๋์ง(์ดํ
์
์ค์ฝ์ด) ๊ณ์ฐ์ ์ํด Einsum ์ฌ์ฉ
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
# ๋ง์คํฌ ์ ์ฉ ์ฌ๋ถ ํ์ธ
if mask is not None:
energy = energy.masked_fill(mask == 0, float("-1e20"))
# ์ํํธ๋งฅ์ค๋ฅผ ์ ์ฉํ์ฌ ์ดํ
์
์ค์ฝ์ด ์ ๊ทํ
attention = torch.nn.functional.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
# ์ดํ
์
์ค์ฝ์ด๋ฅผ ์ฌ์ฉํ์ฌ ๊ฐ์ ๊ฐ์ค ํ๊ท ๊ณ์ฐ
out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
N, query_len, self.heads * self.head_dim
)
# ์ต์ข
์ถ๋ ฅ์ ์ํ ์ ํ ๋ ์ด์ด ์ ์ฉ
out = self.fc_out(out)
return out
- ์ด ์ฝ๋๋ Self-Attention ๋ชจ๋ธ์ ๊ตฌํํ ์ฝ๋์ ๋๋ค.
- ์์๋ ๊ฐ๊ฐ์ Query, Key, Value์ ๋ํด ์ ํ ๋ ์ด์ด๋ฅผ ํตํด ์ ์ ํ ์ฐจ์์ผ๋ก ๋งคํํฉ๋๋ค.
- ๊ฐ Head์ ๋ํ ๋ถ๋ถ์ผ๋ก Embedding์ ๋๋์ด์ ๊ด๋ฆฌํฉ๋๋ค.
- Attention Score ๊ณ์ฐ์ ์ํด์ Einsum์ ์ฌ์ฉํฉ๋๋ค. - Query(์ฟผ๋ฆฌ), Key(ํค)๊ฐ์ Attention Score ๊ณ์ฐ.
- ์ ํ์ ์ผ๋ก ์ ์ฉ๋ 'Attention Score'๋ฅผ ์ฌ์ฉํ์ฌ ์ ํจํ์ง ์์ ๋ถ๋ถ์ masking ํฉ๋๋ค.
- Attention Score๋ฅผ Softmax ํจ์๋ฅผ ์ฌ์ฉํ์ฌ ์ ๊ทํํ๊ณ , Value(๋ฐธ๋ฅ)์ด ๊ฐ์ค ํ๊ท ์ ๊ณ์ฐํฉ๋๋ค.
- ๊ทธ๋ฆฌ๊ณ ๊ณ์ฐ๋ ๊ฒฐ๊ณผ์ ๋ํด์ ์ ํ ๋ ์ด์ด๋ฅผ ์ ์ฉํ์ฌ ์ต์ข ์ถ๋ ฅ์ ์ป์ด์ ๋ฐํํฉ๋๋ค.
5. Multi-Head Attention
Multi-Head Attention์ ์ฌ๋ฌ๊ฐ์ 'Attention Head'๋ฅผ ๊ฐ์ง๊ณ ์์ต๋๋ค. ์ฌ๊ธฐ์ ๊ฐ 'Head Attention'์์ ๋์จ ๊ฐ์ ์ฐ๊ฒฐํด ์ฌ์ฉํ๋ ๋ฐฉ์์ ๋๋ค.
- ํ๋ฒ์ 'Attention'์ ์ด์ฉํ์ฌ ํ์ต์ ์ํค๋๊ฒ ๋ณด๋ค๋ 'Attention'์ ๋ณ๋ ฌ๋ก ์ฌ๋ฌ๊ฐ ์ฌ์ฉํ๋ ๋ฐฉ์์ ๋๋ค.
- ์์๋ ์๋์ Query(์ฟผ๋ฆฌ), Key(ํค) Value(๋ฐธ๋ฅ) ํ๋ ฌ ๊ฐ์ Head์ ๋งํผ ๋ถํ ํฉ๋๋ค.
- ๋ถํ ํ ํ๋ ฌ ๊ฐ์ ํตํด, ๊ฐ 'Attention' Value๊ฐ๋ค์ ๋์ถํฉ๋๋ค.
- ๋์ถ๋ 'Attention value'๊ฐ๋ค์ concatenate(์์ ํฉ์น๊ธฐ)ํ์ฌ์ ์ต์ข 'Attention Value'๋ฅผ ๋์ถํฉ๋๋ค.
- [4x4] ํฌ๊ธฐ์ ๋ฌธ์ฅ Embedding Vector์ [4x8]์ Query(์ฟผ๋ฆฌ), Key(ํค) Value(๋ฐธ๋ฅ)๊ฐ ์์ ๋, ์ผ๋ฐ์ ์ธ ํ ๋ฒ์ ๊ณ์ฐํ๋ Attention ๋ฉ์ปค๋์ฆ์ [4x4]*[4x8]=[4x8]์ 'Attention Value'๊ฐ ํ ๋ฒ์ ๋์ถ๋ฉ๋๋ค.
- 'Multi-Head Attention' ๋งค์ปค๋์ฆ์ผ๋ก ๋ณด๋ฉด ์ฌ๊ธฐ์ Head๋ 4๊ฐ ์ ๋๋ค. 'I, am, a, student'
- Head๊ฐ 4๊ฐ ์ด๋ฏ๋ก ๊ฐ ์ฐ์ฐ๊ณผ์ ์ด 1/4๋งํผ ํ์ํฉ๋๋ค.
- ์์ ๊ทธ๋ฆผ์ผ๋ก ๋ณด๋ฉด ํฌ๊ธฐ๊ฐ [4x8]์ด์๋, Query(์ฟผ๋ฆฌ), Key(ํค) Value(๋ฐธ๋ฅ)๋ฅผ 4๋ฑ๋ถ ํ์ฌ [4x2]๋ก ๋ง๋ญ๋๋ค. ๊ทธ๋ฌ๋ฉด ์ฌ๊ธฐ์์ 'Attention Value'๋ [4x2]๊ฐ ๋ฉ๋๋ค.
- ์ด 'Attention Value'๋ค์ ๋ง์ง๋ง์ผ๋ก Concatenate(ํฉ์ณ์ค๋ค)ํฉ์ณ์ฃผ๋ฉด, ํฌ๊ธฐ๊ฐ [4x8]๊ฐ ๋์ด ์ผ๋ฐ์ ์ธ Attention ๋งค์ปค๋์ฆ์ ๊ฒฐ๊ณผ๊ฐ๊ณผ ๋์ผํ๊ฒ ๋ฉ๋๋ค. ์์๋ฅผ ํ๋ฒ ๋ค์ด๋ณด๊ฒ ์ต๋๋ค.
Summary: Query(์ฟผ๋ฆฌ), Key(ํค), Value(๋ฐธ๋ฅ)๊ฐ์ ํ ๋ฒ์ ๊ณ์ฐํ์ง ์๊ณ head ์๋งํผ ๋๋ ๊ณ์ฐ ํ ๋์ค์ Attention Value๋ค์ ํฉ์น๋ ๋ฉ์ปค๋์ฆ. ํ๋ง๋๋ก ๋ถํ ๊ณ์ฐ ํ ํฉ์ฐํ๋ ๋ฐฉ์.
Multi-Head Attention Example
- ์ ๋ ฅ ๋จ์ด ์๋ 2๊ฐ, ๋ฐธ๋ฅ์ ์ฐจ์์๋ 3, ํค๋๋ 8๊ฐ์ธ ๋ฉํฐ-ํค๋ ์ดํ ์ ์ ๋ํ๋ธ ๊ทธ๋ฆผ ์ ๋๋ค.
- ๊ฐ๋ณ ํค๋์ ์ ํ ์ดํ ์ ์ํ ๊ฒฐ๊ณผ๋ ‘์ ๋ ฅ ๋จ์ด ์ ×, ๋ฐธ๋ฅ ์ฐจ์์’, ์ฆ 2×3 ํฌ๊ธฐ๋ฅผ ๊ฐ๋ ํ๋ ฌ์ ๋๋ค.
- 8๊ฐ ํค๋์ ์ ํ ์ดํ ์ ์ํ ๊ฒฐ๊ณผ๋ฅผ ๋ค์ ๊ทธ๋ฆผ์ โ ์ฒ๋ผ ์ด์ด ๋ถ์ด๋ฉด 2×24 ์ ํ๋ ฌ์ด ๋ฉ๋๋ค.
- Multi-Head Attention์ ์ต์ข ์ํ ๊ฒฐ๊ณผ๋ '์ ๋ ฅ ๋จ์ด ์' x '๋ชฉํ ์ฐจ์ ์' ์ด๋ฉฐ, Encoder, Decoder Block ๋ชจ๋์ ์ ์ฉ๋ฉ๋๋ค.
Multi-Head Attention์ ๊ฐ๋ณ Head์ Self-Attention ์ํ ๊ฒฐ๊ณผ๋ฅผ ์ด์ด ๋ถ์ธ ํ๋ ฌ (โ )์ W0๋ฅผ ํ๋ ฌ๊ณฑํด์ ๋ง๋ฌด๋ฆฌ ๋๋ค.
→ ์ ํ ์ดํ ์ ์ํ ๊ฒฐ๊ณผ ํ๋ ฌ์ ์ด(column)์ ์ × ๋ชฉํ ์ฐจ์์
Multi-Head Attention Model Example Code
Multi-Head Attention Model ์์ ์ฝ๋์ ๋๋ค. TF(Tensorflow)๋ก ๊ตฌํํ์ต๋๋ค.
class MultiHeadAttention(tf.keras.layers.Layer):
def __init__(self, **kargs):
super(MultiHeadAttention, self).__init__()
self.num_heads = kargs['num_heads']
self.d_model = kargs['d_model']
assert self.d_model % self.num_heads == 0
self.depth = self.d_model // self.num_heads
# Query, Key, Value์ ๋ํ ๊ฐ์ค์น ํ๋ ฌ ์ด๊ธฐํ
self.wq = tf.keras.layers.Dense(kargs['d_model']) # Query์ ๋ํ ๊ฐ์ค์น
self.wk = tf.keras.layers.Dense(kargs['d_model']) # Key์ ๋ํ ๊ฐ์ค์น
self.wv = tf.keras.layers.Dense(kargs['d_model']) # Value์ ๋ํ ๊ฐ์ค์น
# ์ต์ข
์ถ๋ ฅ์ ์ํ ๊ฐ์ค์น ํ๋ ฌ ์ด๊ธฐํ
self.dense = tf.keras.layers.Dense(kargs['d_model'])
def split_heads(self, x, batch_size):
# ๋ง์ง๋ง ์ฐจ์์ (num_heads, depth)๋ก ๋๋๊ณ ๊ฒฐ๊ณผ๋ฅผ (batch_size, num_heads, seq_len, depth)๋ก ๋ณํ
x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, v, k, q, mask):
batch_size = tf.shape(q)[0]
# Query, Key, Value์ ๋ํ ๊ฐ์ค์น ํ๋ ฌ ์ ์ฉ
q = self.wq(q) # Query์ ๋ํ ๊ฐ์ค์น ์ ์ฉ
k = self.wk(k) # Key์ ๋ํ ๊ฐ์ค์น ์ ์ฉ
v = self.wv(v) # Value์ ๋ํ ๊ฐ์ค์น ์ ์ฉ
# ๊ฐ ํค๋๋ก ์๋ฒ ๋ฉ์ ๋๋๊ธฐ
q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth)
k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_k, depth)
v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_v, depth)
# Scaled Dot-Product Attention ์ํ
# scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
# attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
scaled_attention, attention_weights = scaled_dot_product_attention(
q, k, v, mask)
# ์ดํ
์
ํค๋๋ฅผ ๋ค์ ํฉ์น๊ณ ์ต์ข
์ถ๋ ฅ์ผ๋ก ๋ณํ
scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3]) # (batch_size, seq_len_q, num_heads, depth)
concat_attention = tf.reshape(scaled_attention,
(batch_size, -1, self.d_model)) # (batch_size, seq_len_q, d_model)
# ์ต์ข
์ถ๋ ฅ ๊ณ์ฐ
output = self.dense(concat_attention) # (batch_size, seq_len_q, d_model)
return output, attention_weights
- ์ด ์ฝ๋๋ Multi-Head Attention Model์ ๊ตฌํํ ์ฝ๋์ ๋๋ค.
- ํด๋์ค๋ฅผ ์ด๊ธฐํํ๋ฉด์ ํ์ํ ํ๋ผ๋ฏธํฐ๋ค์ ์ค์ ํ๊ณ , ๊ฐ๊ฐ์ Query, Key, Value, ์ต์ข ์ถ๋ ฅ์ ์ฌ์ฉ๋ ๊ฐ์ค์น๋ค์ ์ ์ํฉ๋๋ค.
- ์ ๋ ฅ ํ ์์ ๋ง์ง๋ง ์ฐจ์์ ์ฌ๋ฌ Heads๋ก ๋๋๊ณ , ๊ฒฐ๊ณผ๋ฅผ ์ ์ ํ๊ฒ ๋ณํํ์ฌ Heads์ ์ฐจ์์ ์์ผ๋ก ๊ฐ์ ธ์ต๋๋ค.
- Query, Key, Value์ ๊ฐ๊ฐ์ ๊ฐ์ค์น๋ฅผ ์ ์ฉํ๊ณ , Heads๋ก ๋ถํ ํฉ๋๋ค.
- Scaled Dot-Product Attention์ ํธ์ถํ๊ณ , ๊ฒฐ๊ณผ๋ฅผ ์ฒ๋ฆฌํฉ๋๋ค.
- ์ต์ข ์ถ๋ ฅ์ ๊ณ์ฐํ๊ณ ๋ฐํํฉ๋๋ค.
- Ps. ์ด๋ ๊ฒ Attention์ ๋ฐํด์ ์ค๋ช ํ ๊ธ์ ์ผ๋๋ฐ, ๋ด์ฉ๋ ๋ ์ธ๊ฒ ์์ง๋ง ๋๋ฌด ๊ธธ์ด์ง๊ฑฐ ๊ฐ์์ ์ฌ๊ธฐ๊น์ง๋ง ์ฐ๊ณ ๋ค์ ๊ธ๋ก ๋๊ธฐ๊ฒ ์ต๋๋ค.
- ๋ค์๊ธ์๋ ํฉ์ฑ๊ณฑ, ์ํ ์ ๊ฒฝ๋ง์ด๋ ๋น๊ตํ๊ฑฐ๋, Encoder, Decoder์์ ์ํํ๋ Self-Attention์ ๋ฐํ์ฌ ์์๋ณด๊ฒ ์ต๋๋ค.
๋ฐ์ํ
'๐ NLP (์์ฐ์ด์ฒ๋ฆฌ) > ๐ Natural Language Processing' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[NLP] Transformer Model - ํธ๋์คํฌ๋จธ ๋ชจ๋ธ ์์๋ณด๊ธฐ (0) | 2024.03.07 |
---|---|
[NLP] ํฉ์ฑ๊ณฑ, ์ํ์ ๊ฒฝ๋ง, Encoder, Decoder์์ ์ํํ๋ Self-Attention (0) | 2024.03.01 |
[NLP] Word Embedding - ์๋ ์๋ฒ ๋ฉ (0) | 2024.02.12 |
[NLP] Word2Vec, CBOW, Skip-Gram - ๊ฐ๋ & Model (0) | 2024.02.03 |
[NLP] GRU Model - LSTM Model์ ๊ฐ๋ณ๊ฒ ๋ง๋ ๋ชจ๋ธ (0) | 2024.01.30 |