[NLP] Attention

1. Attention

Attention์€ CS ๋ฐ ML์—์„œ ์ค‘์š”ํ•œ ๊ฐœ๋…์ค‘ ํ•˜๋‚˜๋กœ ์—ฌ๊ฒจ์ง‘๋‹ˆ๋‹ค. Attention์˜ ๋งค์ปค๋‹ˆ์ฆ˜์€ ์ฃผ๋กœ Sequence Data๋ฅผ ์ฒ˜๋ฆฌํ•˜๊ฑฐ๋‚˜ ์ƒ์„ฑํ•˜๋Š” ๋ชจ๋ธ์—์„œ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค. -> Sequence ์ž…๋ ฅ์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๋จธ์‹ ๋Ÿฌ๋‹ ํ•™์Šต ๋ฐฉ๋ฒ•์˜ ์ผ์ข…
  • Attention์˜ ๊ฐœ๋…์€ Decoder์—์„œ ์ถœ๋ ฅ์„ ์˜ˆ์ธกํ•˜๋Š” ๋งค์‹œ์ (time step)๋งˆ๋‹ค, Encoder์—์„œ์˜ ์ „์ฒด์˜ ์ž…๋ ฅ ๋ฌธ์žฅ์„ ๋‹ค์‹œ ํ•œ๋ฒˆ ์ฐธ๊ณ ํ•˜๊ฒŒ ํ•˜๋Š” ๊ธฐ๋ฒ•์ž…๋‹ˆ๋‹ค.
  • ๋‹จ, ์ „์ฒด ์ž…๋ ฅ ๋ฌธ์žฅ์„ ์ „๋ถ€ ๋‹ค ์ข…์ผํ•œ ๋น„์œจ๋กœ ์ฐธ๊ณ ํ•˜๋Š” ๊ฒƒ์ด ์•„๋‹ˆ๋ผ, ํ•ด๋‹น ์‹œ์ ์—์„œ ์˜ˆ์ธกํ•ด์•ผ ํ•  ์š”์†Œ์™€ ์—ฐ๊ด€์ด ์žˆ๋Š” ์ž…๋ ฅ ์š”์†Œ ๋ถ€๋ถ„์„ Attention(์ง‘์ค‘)ํ•ด์„œ ๋ณด๊ฒŒ ํ•ฉ๋‹ˆ๋‹ค.
    • ์ด ๋ฐฉ๋ฒ•์ด ๋ฌธ๋งฅ์„ ํŒŒ์•…ํ•˜๋Š” ํ•ต์‹ฌ์˜ ๋ฐฉ๋ฒ•์ด๋ฉฐ, ์ด๋Ÿฌํ•œ ๋ฐฉ์‹์„ DL(๋”ฅ๋Ÿฌ๋‹)๋ชจ๋ธ์— ์ ์šฉํ•œ๊ฒƒ์ด 'Attention" ๋ฉ”์ปค๋‹ˆ์ฆ˜์ž…๋‹ˆ๋‹ค.
    • ์ฆ‰, ์ด๋ง์€ Task ์ˆ˜ํ–‰์— ํฐ ์˜ํ–ฅ์„ ์ฃผ๋Š” Element(์š”์†Œ)์— Weight(๊ฐ€์ค‘์น˜)๋ฅผ ํฌ๊ฒŒ ์ฃผ๊ณ , ๊ทธ๋ ‡์ง€ ์•Š์€ Weight(๊ฐ€์ค‘์น˜)๋Š” ๋‚ฎ๊ฒŒ ์ค๋‹ˆ๋‹ค.

Attention์ด ๋“ฑ์žฅํ•œ ์ด์œ ?

  • Attention ๊ธฐ๋ฒ•์ด ๋“ฑ์žฅํ•œ ์ด์œ ์— ๋Œ€ํ•ด์„œ ํ•œ๋ฒˆ ์•Œ์•„๋ณด๋ฉด ์•ž์„œ ๋ดค๋˜ ๊ธ€์ค‘์— Seq2Seq Model์— ๋Œ€ํ•œ ๊ธ€์ด ์žˆ์—ˆ์Šต๋‹ˆ๋‹ค.
  • ๊ฐœ๋…์— ๋ฐํ•˜์—ฌ ์„ค๋ช…์„ ํ•ด๋ณด๋ฉด, Seq2Seq๋Š” Encoder์—์„œ Input Sequence(์ž…๋ ฅ ์‹œํ€€์Šค)๋ฅผ ํ•˜๋‚˜์˜ ๊ณ ์ •๋œ ํฌ๊ธฐ์˜ Vector๋กœ ์••์ถ•ํ•ฉ๋‹ˆ๋‹ค. ์ด๋ฅผ Context Vector๋ผ๊ณ  ํ•˜๊ณ , Decoder๋Š” ์ด Context Vector๋ฅผ ํ™œ์šฉํ•ด์„œ Output Sequence(์ถœ๋ ฅ ์‹œํ€€์Šค)๋ฅผ ๋งŒ๋“ญ๋‹ˆ๋‹ค.
  • ๊ทผ๋ฐ, Seq2Seq ๋ชจ๋ธ์€ RNN ๋ชจ๋ธ์— ๊ธฐ๋ฐ˜ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋ฉด ๋ฌธ์ œ๊ฐ€ ๋ฐœ์ƒํ•˜๋Š”๋ฐ
    • ํ•˜๋‚˜๋Š” Context Vector์˜ ํฌ๊ธฐ๊ฐ€ ๊ณ ์ •๋˜์–ด ์žˆ์œผ๋ฏ€๋กœ ๊ณ ์ •๋œ ํฌ๊ธฐ์˜ Vector์— Input์œผ๋กœ ๋“ค์–ด์˜ค๋Š” ์ •๋ณด๋ฅผ ์••์ถ•ํ•˜๋ ค๊ณ  ํ•˜๋‹ˆ๊นŒ ์ •๋ณด์†์‹ค์ด ๋ฐœ์ƒํ• ์ˆ˜๋„ ์žˆ๋‹ค๋Š”์ .
    • ๋‹ค๋ฅธ ํ•˜๋‚˜๋Š” RNN์ด๋‚˜ RNN๊ธฐ๋ฐ˜ ๋ชจ๋ธ๋“ค์˜ ๊ณ ์งˆ์ ์ธ ๋ฌธ์ œ๋Š” Gradient Vanishing(๊ธฐ์šธ๊ธฐ ์†Œ์‹ค)๋ฌธ์ œ๊ฐ€ ์กด์žฌํ•˜๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค.
  • ๊ทธ๋ž˜์„œ Seq2Seq ๋ชจ๋ธ์€ Input Sequence(์ž…๋ ฅ ์‹œํ€€์Šค)๊ฐ€ ๊ธธ์–ด์ง€๋ฉด Output Sequence(์ถœ๋ ฅ ์‹œํ€€์Šค)์˜ ์ •ํ™•๋„๊ฐ€ ๋–จ์–ด์ง‘๋‹ˆ๋‹ค.
  • ๊ทธ๋ž˜์„œ ์ •ํ™•๋„๋ฅผ ๋–จ์–ด์ง€๋Š”๊ฒƒ์„ ์–ด๋Š์ •๋„ ๋ฐฉ์ง€ํ•ด์ฃผ๊ธฐ ์œ„ํ•ด์„œ ๋“ฑ์žฅํ•œ ๋ฐฉ๋ฒ•์ด Attention ์ž…๋‹ˆ๋‹ค.
Seq2seq, Encoder, Decoder ๊ฐœ๋…์„ ์ •๋ฆฌํ•ด ๋†“์€ ๊ธ€์„ ๋งํฌํ•ด ๋†“๊ฒ ์Šต๋‹ˆ๋‹ค. ์ฐธ๊ณ ํ•ด์ฃผ์„ธ์š”!
 

[NLP] Seq2Seq, Encoder & Decoder (daehyun-bigbread.tistory.com)


2. Attention Function

  • Let's look at the Attention function. Expressed as a function, Attention looks like the following.

Attention ํ•จ์ˆ˜์˜ ๋Œ€๋žต์ ์ธ Flow. ์ถœ์ฒ˜: https://wikidocs.net/22893

Attention(Q, K, V) = Attention Value

Q = Query : the hidden state of the Decoder cell at time step t
K = Keys : the hidden states of the Encoder cells at all time steps
V = Values : the hidden states of the Encoder cells at all time steps
  • The Attention function computes the similarity between the given Query and every Key. It then applies each similarity to the Value mapped to the corresponding Key.
  • Finally, it sums up the similarity-weighted Values and returns the result. This returned value is called the Attention Value (a schematic formula is written below). Let's look at three examples of Attention: Scaled Dot-Product Attention, Self-Attention, and Multi-Head Attention.

3. Dot-Product Attention

Attention์˜ ์ข…๋ฅ˜๊ฐ€ ์žˆ๋Š”๋ฐ, ๊ทธ ์ค‘์˜ ํ•˜๋‚˜์ธ Dot-Product Attention์— ๋Œ€ํ•ด ์•Œ์•„๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

LSTM Cell์—์„œ ์ถœ๋ ฅ ๋‹จ์–ด๋ฅผ ์˜ˆ์ธกํ• ๋•Œ Attention ๋งค์ปค๋‹ˆ์ฆ˜ ์‚ฌ์šฉ.  ์ถœ์ฒ˜: https://wikidocs.net/22893

  • Decoder์˜ 3๋ฒˆ์งธ LSTM Cell์—์„œ ์ถœ๋ ฅ ๋‹จ์–ด๋ฅผ ์˜ˆ์ธกํ• ๋•Œ, Attention ๋งค์ปค๋‹ˆ์ฆ˜์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
  • ์•ž์˜ Decoder์˜ ์ฒซ, ๋‘๋ฒˆ์งธ ์…€์€ ์ด๋ฏธ Attention ๋งค์ปค๋‹ˆ์ฆ˜์„ ์ด์šฉํ•ด์„œ ๋‹จ์–ด๋ฅผ ์˜ˆ์ธกํ•˜๋Š” ๊ณผ์ •์„ ๊ฑฐ์ณค์Šต๋‹ˆ๋‹ค.
  • Decoder์˜ 3๋ฒˆ์งธ LSTM Cell์€ ์ถœ๋ ฅ ๋‹จ์–ด๋ฅผ ์˜ˆ์ธกํ•˜๊ธฐ ์œ„ํ•˜์—ฌ Encoder์˜ ๋ชจ๋“  ๋‹จ์–ด๋“ค์˜ ์ •๋ณด๋ฅผ ์ฐธ๊ณ ํ•˜๋ ค๊ณ  ํ•ฉ๋‹ˆ๋‹ค.
  • Encoder์˜ Softmax ํ•จ์ˆ˜๋ฅผ ํ†ตํ•ด ๋‚˜์˜จ ๊ฒฐ๊ณผ๊ฐ’์€ I, am, a, student ๋‹จ์–ด ๊ฐ๊ฐ์ด ์ถœ๋ ฅ์ด ์–ผ๋งˆ๋‚˜ ๋„์›€์ด ๋˜๋Š”์ง€ ์ •๋„๋ฅผ ์ˆ˜์น˜ํ™” ํ•œ ๊ฐ’์ž…๋‹ˆ๋‹ค.
  • ์—ฌ๊ธฐ์„œ๋Š” Softmax ํ•จ์ˆ˜์œ„์˜ ์žˆ๋Š” ์ง์‚ฌ๊ฐํ˜•์˜ ํฌ๊ธฐ๊ฐ€ ํด์ˆ˜๋ก ๊ฒฐ๊ณผ๊ฐ’์ด ๋‹จ์–ด ๊ฐ๊ฐ์˜ ์ถœ๋ ฅ์— ๋„์›€์ด ๋˜๋Š” ์ •๋„๋ฅผ ์ˆ˜์น˜ํ™” ํ•œ ํฌ๊ธฐ ์ž…๋‹ˆ๋‹ค.
  • ๊ฐ ์ž…๋ ฅ ๋‹จ์–ด๊ฐ€ Decoder ์˜ˆ์ธก์— ๋„์›€์ด ๋˜๋Š” ์ •๋„๋ฅผ ์ˆ˜์น˜ํ™” ํ•˜๋ฉด ์ด๋ฅผ ํ•˜๋‚˜์˜ ์ •๋ณด๋กœ ํ•ฉ์นœํ›„ Decoder๋กœ ๋ณด๋ƒ…๋‹ˆ๋‹ค (๋งจ ์œ„์˜ ์ดˆ๋ก์ƒ‰ ์‚ผ๊ฐํ˜•) -> ๊ฒฐ๋ก ์€ Decoder๋Š” ์ถœ๋ ฅ ๋‹จ์–ด๋ฅผ ๋” ์ •ํ™•ํ•˜๊ฒŒ ์˜ˆ์ธกํ•  ํ™•๋ฅ ์ด ๋†’์•„์ง‘๋‹ˆ๋‹ค. ๊ตฌ์ œ์ ์œผ๋กœ ์„ค๋ช…ํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

1. Computing the Attention Score

Source: https://wikidocs.net/22893

  • Encoder์˜ Time Step(Encoder์˜ ์‹œ์ )์„ ๊ฐ๊ฐ 1๋ถ€ํ„ฐ N๊นŒ์ง€ ๋ผ๊ณ  ํ–ˆ์„๋•Œ Hidden State(์€๋‹‰ ์ƒํƒœ)๋ฅผ ๊ฐ๊ฐ h1, h2, ~ hn์ด๋ผ๊ณ  ํ•˜๊ณ , Decoder์˜ Time Step(ํ˜„์žฌ ์‹œ์ ) t์—์„œ์˜ Decoder์˜  Hidden State(์€๋‹‰ ์ƒํƒœ)๋ฅผ st๋ผ๊ณ  ํ•ฉ๋‹ˆ๋‹ค.
  • ์—ฌ๊ธฐ์„œ๋Š” Encoder์˜ Hidden State(์€๋‹‰ ์ƒํƒœ), Decoder์˜ Hidden State(์€๋‹‰ ์ƒํƒœ)์˜ ์ฐจ์›์ด ๊ฐ™๋‹ค๊ณ  ๊ฐ€์ •ํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
  • ๊ทธ๋Ÿฌ๋ฉด Encoder, Decoder์˜ Hidden state(์€๋‹‰ ์ƒํƒœ)์˜ Dimension(์ฐจ์›)์€ ๋™์ผํ•˜๊ฒŒ 4๋กœ ๊ฐ™์Šต๋‹ˆ๋‹ค.
  • Attention ๋ฉ”์ปค๋‹ˆ์ฆ˜์—์„œ๋Š” ์ถœ๋ ฅ๋‹จ์–ด๋ฅผ ์˜ˆ์ธกํ• ๋•Œ Attention Value(์–ดํ…์…˜ ๊ฐ’)์„ ํ•„์š”๋กœ ํ•ฉ๋‹ˆ๋‹ค. t๋ฒˆ์งธ ๋‹จ์–ด๋ฅผ ์˜ˆ์ธกํ•˜๊ธฐ ์œ„ํ•œ Decoder์˜ Cell์€ 2๊ฐœ์˜ input๊ฐ’์„ ํ•„์š”๋กœ ํ•˜๋Š”๋ฐ, ์ด์ „์‹œ์ ์ธ t-1์˜ ์€๋‹‰ ์ƒํƒœ์™€ ์ด์ „ ์‹œ์  t-1์— ๋‚˜์˜จ ์ถœ๋ ฅ ๋‹จ์–ด์ž…๋‹ˆ๋‹ค. 
  • ๊ทธ๋ฆฌ๊ณ  Attention Value(์–ดํ…์…˜ ๊ฐ’)์„ ์ข…ํ•ฉํ•˜์—ฌ Decoder Cell์€ ํ˜„์žฌ ์‹œ์ (t)์˜ Output์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
+  Hidden state(์€๋‹‰ ์ƒํƒœ): RNN(์ˆœํ™˜์‹ ๊ฒฝ๋ง), LSTM(์žฅ, ๋‹จ๊ธฐ ๋ฉ”๋ชจ๋ฆฌ)์—์„œ ์‹œ๊ฐ„์  ์˜์กด์„ฑ์„ ํ•™์Šต ๋ฐ ๋ชจ๋ธ์˜ ์ด์ „์ •๋ณด ์œ ์ง€์— ์‚ฌ์šฉ๋˜๋Š” ๋‚ด๋ถ€ ์ƒํƒœ์ž…๋‹ˆ๋‹ค.
Input Sequence๋ฅผ ํ•œ๋ฒˆ์— ํ•˜๋‚˜์”ฉ ์ฒ˜๋ฆฌํ•˜๋ฉด์„œ ๋‚ด๋ถ€ ์ƒํƒœ๋ฅผ ์—…๋ฐ์ดํŠธ ํ•˜๋Š”๋ฐ, ์ด๋•Œ ํ˜„์žฌ ์‹œ์ ์˜ input ๋ฐ ์ด์ „์‹œ์ ์˜ hidden state(์€๋‹‰ ์ƒํƒœ)์— ์˜ํ–ฅ์„ ๋ฐ›์œผ๋ฉฐ, ์ด์ „์ •๋ณด ์œ ์ง€ ๋ฐ ํ˜„์žฌ ์ž…๋ ฅ์— ๋”ฐ๋ผ ๋ณ€ํ™”ํ•ฉ๋‹ˆ๋‹ค.

 

Let's define 'at' as the Attention Value used to predict the t-th word, the word at the current time step.
  • To obtain 'at', the Attention Value for predicting the t-th word, we first need the Attention Scores.
  • The Attention Score measures how similar each of the Encoder's hidden states is to 'st', the Decoder's hidden state at the current time step t, for the purpose of predicting the word at time step t.

 

  • In Dot-Product Attention, the score is obtained by transposing 'st', the Decoder's hidden state at the current time step, and taking the dot product with each Encoder hidden state.
  • For example, the Attention Score between 'st' and the i-th hidden state of the Encoder is computed as in the figure below.

Attention Score์˜ ๊ณ„์‚ฐ ๋ฐฉ๋ฒ•  ์ถœ์ฒ˜: https://wikidocs.net/22893

  • The Attention Score function is defined as in the formula on the left below.
  • The collection of Attention Scores between 'st', the Decoder's hidden state at the current time step, and all of the Encoder's hidden states is called 'et', and 'et' is given by the formula on the right below.


2. Computing the Attention Distribution with the Softmax Function

Source: https://wikidocs.net/22893

  • Applying the softmax function to 'et' yields a probability distribution whose values sum to 1.
  • This is called the Attention Distribution, and each value in it is called an Attention Weight.
  • In other words, the attention weights obtained by applying the softmax function always sum to 1.
  • In the figure, the size of each attention weight over the Encoder's hidden states is visualized by the size of a rectangle: the larger the attention weight, the larger the rectangle.
  • If we call the Attention Distribution at Decoder time step t 'αt' (each entry αt_i is the attention weight for the i-th Encoder hidden state), then 'αt' is defined by the formula below.

'at'๋ฅผ ์‹์œผ๋กœ ์ •์˜.  ์ถœ์ฒ˜: https://wikidocs.net/22893

 


3. ๊ฐ Encoder์˜ Attention ๊ฐ€์ค‘์น˜ & Hidden state๋ฅผ ํ•ฉ์ณ ์–ดํ…์…˜ ๊ฐ’ ๊ตฌํ•˜๊ธฐ

  • Attention์˜ ์ตœ์ข… ๊ฒฐ๊ณผ ๊ฐ’์„ ์–ป๊ธฐ ์œ„ํ•ด์„œ๋Š” ๊ฐ Encoder์˜ Hidden state(์€๋‹‰์ƒํƒœ), Attention Weight(์–ดํ…์…˜ ๊ฐ€์ค‘์น˜)๊ฐ’๋“ค์„ ๊ณฑํ•˜๊ณ , ๋ชจ๋‘ ๋”ํ•ฉ๋‹ˆ๋‹ค.
  • ์ด๋Ÿฌํ•œ ๊ณผ์ •์„ Weight Sum(๊ฐ€์ค‘ํ•ฉ)์ด๋ผ๊ณ  ํ•˜๋ฉฐ, ์•„๋ž˜์˜ ์‹์€ Attention์˜ ์ตœ์ข…๊ฒฐ๊ณผ์ธ Attention Value(์–ดํ…์…˜ ๊ฐ’)์— ๋Œ€ํ•œ ์‹์„ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค.

Attention ํ•จ์ˆ˜์˜ ์ถœ๋ ฅ๊ฐ’ ์ธ Attention Value 'at'์— ๋Œ€ํ•œ ์‹

  • This Attention Value 'at' often contains the context of the Encoder, so it is also called the Context Vector.
  • Note that this is different from calling the Encoder's last hidden state the Context Vector, as in Seq2Seq.

4. Concatenating the Attention Value and the Decoder's hidden state at time step t (the core of the Attention mechanism)

Attention ๋งค์ปค๋‹ˆ์ฆ˜์˜ ํ•ต์‹ฌ

  • Once the Attention Value 'at' is obtained, the Attention mechanism concatenates 'at' with 'st' into a single vector.
  • The result of this operation is defined as 'vt', and 'vt' is used as the input to the ŷ prediction computation, so that the information obtained from the Encoder helps predict ŷ more accurately.

5. Computing the input to the output-layer operation and using it as the input

์ถœ๋ ฅ์ธต ์—ฐ์‚ฐ์˜ ์ž…๋ ฅ์ด ๋˜๋Š” 'st' ๊ณ„์‚ฐ

  • This is the step performed before sending 'vt' straight to the output layer.
  • 'vt' is multiplied by a weight matrix and passed through the tanh (hyperbolic tangent) function to obtain 's̃t' (s-tilde at time t), the vector used for the output-layer operation.
  • In Seq2Seq the input to the output layer was 'st', the hidden state at time step t, but with the Attention mechanism the input to the output layer becomes 's̃t'.
  • The formula is shown below. 'Wc' is a learnable weight matrix, and 'bc' is the bias.

'bc', the bias, is omitted here.

  • Using 's̃t' as the input to the output layer, we obtain the prediction vector.

Formula for obtaining the prediction vector
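
Written out from the description above (Wy and by, the output layer's weight and bias, are named here only for illustration):

s̃t = tanh(Wc[at ; st] + bc)
ŷt = Softmax(Wy · s̃t + by)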


Dot-Product Attention Model Example Code

Here is an implementation in TF (TensorFlow).
import tensorflow as tf

def scaled_dot_product_attention(q, k, v, mask):
    # Compute the dot product between the Query and the Key
    matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

    # Scale the dot product to normalize it
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # If a mask is provided, add -1e9 to the masked positions so the softmax ignores them
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)

    # ์–ดํ…์…˜ ๊ฐ€์ค‘์น˜ ๊ณ„์‚ฐ (์†Œํ”„ํŠธ๋งฅ์Šค ํ•จ์ˆ˜ ์ ์šฉ)
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

    # ์–ดํ…์…˜ ๊ฐ€์ค‘ ํ‰๊ท ์„ ์‚ฌ์šฉํ•˜์—ฌ ๊ฐ’์„ ๊ณ„์‚ฐ
    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

    return output, attention_weights

 

  • This code implements (Scaled) Dot-Product Attention.
  1. First, the dot product between the given Query ('q') and Key ('k') is computed by transposing the Key matrix.
  2. The dot product is divided by the square root of the Key dimension to normalize it.
    • This scaling helps prevent gradient vanishing / exploding.
  3. Masking restricts 'Attention' to the valid positions. -> An attention mask is applied optionally; invalid positions have -1e9 added so that they are masked out.
  4. The softmax function is applied to compute the attention weights. These weights represent the importance between the Query and the Keys.
  5. The attention weights are used to take a weighted sum with the Value matrix, producing 'output'.
    • 'output' is the attention-weighted combination of the Values for the Query. The function returns 'output' together with the attention weights. (A quick shape check follows below.)

4. Self-Attention

Self-Attention ๊ธฐ๋ฒ•์€ Attention ๊ธฐ๋ฒ•์„ ๋ง ๊ทธ๋Œ€๋กœ ์ž๊ธฐ ์ž์‹ ์—๊ฒŒ ์ˆ˜ํ–‰ํ•˜๋Š” Attention ๊ธฐ๋ฒ•์ž…๋‹ˆ๋‹ค.
  • Input Sequence ๋‚ด์˜ ๊ฐ ์š”์†Œ๊ฐ„์˜ ์ƒ๋Œ€์ ์ธ ์ค‘์š”๋„๋ฅผ ๊ณ„์‚ฐํ•˜๋Š” ๋งค์ปค๋‹ˆ์ฆ˜์ด๋ฉฐ, Sequence์˜ ๋‹ค์–‘ํ•œ ์œ„์น˜๊ฐ„์˜ ์„œ๋กœ ๋‹ค๋ฅธ ๊ด€๊ณ„๋ฅผ ํ•™์Šตํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • Sequence ์š”์†Œ ๊ฐ€์šด๋ฐ Task ์ˆ˜ํ–‰์— ์ค‘์š”ํ•œ Element(์š”์†Œ)์— ์ง‘์ค‘ํ•˜๊ณ  ๊ทธ๋ ‡์ง€ ์•Š์€ Element(์š”์†Œ)๋Š” ๋ฌด์‹œํ•ฉ๋‹ˆ๋‹ค.
    • ์ด๋Ÿฌ๋ฉด Task๊ฐ€ ์ˆ˜ํ–‰ํ•˜๋Š” ์„ฑ๋Šฅ์ด ์ƒ์Šนํ•  ๋ฟ๋”๋Ÿฌ Decoding ํ• ๋•Œ Source Sequence ๊ฐ€์šด๋ฐ ์ค‘์š”ํ•œ Element(์š”์†Œ)๋“ค๋งŒ ์ถ”๋ฆฝ๋‹ˆ๋‹ค.
  • ๋ฌธ๋งฅ์— ๋”ฐ๋ผ ์ง‘์ค‘ํ•  ๋‹จ์–ด๋ฅผ ๊ฒฐ์ •ํ•˜๋Š” ๋ฐฉ์‹์„ ์˜๋ฏธ -> ์ค‘์š”ํ•œ ๋‹จ์–ด์—๋งŒ ์ง‘์ค‘์„ ํ•˜๊ณ  ๋‚˜๋จธ์ง€๋Š” ๊ทธ๋ƒฅ ์ฝ์Šต๋‹ˆ๋‹ค.
    • ์ด ๋ฐฉ๋ฒ•์ด ๋ฌธ๋งฅ์„ ํŒŒ์•…ํ•˜๋Š” ํ•ต์‹ฌ์ด๋ฉฐ, ์ด๋Ÿฌํ•œ ๋ฐฉ์‹์„ Deep Learning ๋ชจ๋ธ์— ์ ์šฉํ•œ๊ฒƒ์ด 'Attention' ๋งค์ปค๋‹ˆ์ฆ˜์ด๋ฉฐ, ์ด ๋งค์ปค๋‹ˆ์ฆ˜์„ ์ž๊ธฐ ์ž์‹ ์—๊ฒŒ ์ ์šฉํ•œ๊ฒƒ์ด 'Self-Attention' ์ž…๋‹ˆ๋‹ค. ํ•œ๋ฒˆ ๋งค์ปค๋‹ˆ์ฆ˜์„ ์˜ˆ์‹œ๋ฅผ ๋“ค์–ด์„œ ์„ค๋ช…ํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

Self-Attention์˜ ๊ณ„์‚ฐ ์˜ˆ์‹œ

Self-Attention์€ Query(์ฟผ๋ฆฌ), Key(ํ‚ค), Value(๋ฐธ๋ฅ˜) 3๊ฐ€์ง€ ์š”์†Œ๊ฐ€ ์„œ๋กœ ์˜ํ–ฅ์„ ์ฃผ๊ณ  ๋ฐ›๋Š” ๊ตฌ์กฐ์ž…๋‹ˆ๋‹ค.
  • ๋ฌธ์žฅ๋‚ด ๊ฐ ๋‹จ์–ด๊ฐ€ Vector(๋ฒกํ„ฐ) ํ˜•ํƒœ๋กœ input(์ž…๋ ฅ)์„ ๋ฐ›์Šต๋‹ˆ๋‹ค.

* Vector: roughly, a list of numbers

  • ๊ฐ ๋‹จ์–ด์˜ Vector๋Š” 3๊ฐ€์ง€ ๊ณผ์ •์„ ๊ฑฐ์ณ์„œ ๋ฐ˜ํ™˜์ด ๋ฉ๋‹ˆ๋‹ค.
  • Query(์ฟผ๋ฆฌ) - ๋‚ด๊ฐ€ ์ฐพ๊ณ  ์ •๋ณด๋ฅผ ์š”์ฒญํ•˜๋Š”๊ฒƒ ์ž…๋‹ˆ๋‹ค. 
  • Key(ํ‚ค) - ๋‚ด๊ฐ€ ์ฐพ๋Š” ์ •๋ณด๊ฐ€ ์žˆ๋Š” ์ฐพ์•„๋ณด๋Š” ๊ณผ์ •์ž…๋‹ˆ๋‹ค.
  • Value(๋ฐธ๋ฅ˜) - ์ฐพ์•„์„œ ์ œ๊ณต๋œ ์ •๋ณด๊ฐ€ ๊ฐ€์น˜ ์žˆ๋Š”์ง€ ํŒ๋‹จํ•˜๋Š” ๊ณผ์ •์ž…๋‹ˆ๋‹ค.

  • ์œ„์˜ ๊ทธ๋ฆผ์„ ๋ณด๋ฉด ์ž…๋ ฅ๋˜๋Š” ๋ฌธ์žฅ "์–ด์ œ ์นดํŽ˜ ๊ฐ”์—ˆ์–ด ๊ฑฐ๊ธฐ ์‚ฌ๋žŒ ๋งŽ๋”๋ผ" ์ด 6๊ฐœ ๋‹จ์–ด๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ๋‹ค๋ฉด?
  • ์—ฌ๊ธฐ์„œ์˜ Self-Attention ๊ณ„์‚ฐ ๋Œ€์ƒ์€ Query(์ฟผ๋ฆฌ) Vector 6๊ฐœ, Key(ํ‚ค) Vector 6๊ฐœ, Value(๋ฐธ๋ฅ˜) Vector 6๊ฐœ๋“ฑ ๋ชจ๋‘ 18๊ฐœ๊ฐ€ ๋ฉ๋‹ˆ๋‹ค.

  • ์œ„์˜ ํ‘œ๋Š” ๋” ์„ธ๋ถ€์ ์œผ๋กœ ๋‚˜ํƒ€๋‚ธ๊ฒƒ์ž…๋‹ˆ๋‹ค. Self-Attention์€ Query ๋‹จ์–ด ๊ฐ๊ฐ์— ๋Œ€ํ•ด ๋ชจ๋“  Key ๋‹จ์–ด์™€ ์–ผ๋งˆ๋‚˜ ์œ ๊ธฐ์ ์ธ ๊ด€๊ณ„๋ฅผ ๋งบ๊ณ  ์žˆ๋Š”์ง€์˜ ํ™•๋ฅ ๊ฐ’๋“ค์˜ ํ•ฉ์ด 1์ธ ํ™•๋ฅ ๊ฐ’์œผ๋กœ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค.
  • ์ด๊ฒƒ์„ ๋ณด๋ฉด Self-Attention ๋ชจ๋“ˆ์€ Value(๋ฐธ๋ฅ˜) Vector๋“ค์„ Weighted Sum(๊ฐ€์ค‘ํ•ฉ)ํ•˜๋Š” ๋ฐฉ์‹์œผ๋กœ ๊ณ„์‚ฐ์„ ๋งˆ๋ฌด๋ฆฌ ํ•ฉ๋‹ˆ๋‹ค.
  • ํ™•๋ฅ ๊ฐ’์ด ๊ฐ€์žฅ ๋†’์€ ํ‚ค ๋‹จ์–ด๊ฐ€ ์ฟผ๋ฆฌ ๋‹จ์–ด์™€ ๊ฐ€์žฅ ๊ด€๋ จ์ด ๋†’์€ ๋‹จ์–ด๋ผ๊ณ  ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

  • Here the computation was illustrated only for '์นดํŽ˜' (cafe), but Self-Attention is performed in the same way for each of the remaining words.

Self-Attention์˜ ๋™์ž‘ ๋ฐฉ์‹

Self-Attention์—์„œ ๊ฐ€์žฅ ์ค‘์š”ํ•œ ๊ฐœ๋…์€ Query(์ฟผ๋ฆฌ), Key(ํ‚ค), Value(๋ฐธ๋ฅ˜)์˜ ์‹œ์ž‘๊ฐ’์ด ๋™์ผํ•˜๋‹ค๋Š” ์ ์ž…๋‹ˆ๋‹ค.
  • ๊ทธ๋ ‡๋‹ค๊ณ  Query(์ฟผ๋ฆฌ), Key(ํ‚ค), Value(๋ฐธ๋ฅ˜)๊ฐ€ ๋™์ผ ํ•˜๋‹ค๋Š” ๋ง์ด ์ด๋‹™๋‹ˆ๋‹ค.
  • ๊ทธ๋ฆผ์„ ๋ณด๋ฉด ๊ฐ€์ค‘์น˜ Weight W๊ฐ’์— ์˜ํ•ด์„œ ์ตœ์ข…์ ์ธ Query(์ฟผ๋ฆฌ), Key(ํ‚ค), Value(๋ฐธ๋ฅ˜)๊ฐ’์€ ์„œ๋กœ ๋‹ฌ๋ผ์ง‘๋‹ˆ๋‹ค.

The formula for computing Attention

Attention์„ ๊ตฌํ•˜๋Š” ๊ณต์‹์„ ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
  • ์ผ๋‹จ Query(์ฟผ๋ฆฌ)๋ž‘ Key(ํ‚ค)๋ฅผ ๋‚ด์ ํ•ด์ค๋‹ˆ๋‹ค. ์ด๋ ‡๊ฒŒ ๋‚ด์ ์„ ํ•ด์ฃผ๋Š” ์ด์œ ๋Š” ๋‘˜ ์‚ฌ์ด์˜ ์—ฐ๊ด€์„ฑ์„ ๊ณ„์‚ฐํ•˜๊ธฐ ์œ„ํ•ด์„œ์ž…๋‹ˆ๋‹ค.
  • ์ด ๋‚ด์ ๋œ ๊ฐ’์„ "Attention Score"๋ผ๊ณ  ํ•ฉ๋‹ˆ๋‹ค. Dot-Product Attention ๋ถ€๋ถ„์—์„œ ์ž์„ธํžˆ ์„ค๋ช… ํ–ˆ์ง€๋งŒ ์ด๋ฒˆ์—๋Š” ๊ฐ„๋‹จํ•˜๊ฒŒ ์„ค๋ช…ํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
  • ๋งŒ์•ฝ, Query(์ฟผ๋ฆฌ)๋ž‘ Key(ํ‚ค)์˜ Dimension(์ฐจ์›)์ด ์ปค์ง€๋ฉด, ๋‚ด์  ๊ฐ’์˜ Attention Score๊ฐ€ ์ปค์ง€๊ฒŒ ๋˜์–ด์„œ ๋ชจ๋ธ์ด ํ•™์Šตํ•˜๋Š”๋ฐ ์–ด๋ ค์›€์ด ์ƒ๊น๋‹ˆ๋‹ค.
  • ์ด ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด์„œ ์ฐจ์› d_k์˜ ๋ฃจํŠธ๋งŒํผ ๋‚˜๋ˆ„์–ด์ฃผ๋Š” Scaling ์ž‘์—…์„ ์ง„ํ–‰ํ•ฉ๋‹ˆ๋‹ค. ์ด๊ณผ์ •์„ "Scaled Dot-Product Attention" ์ด๋ผ๊ณ  ํ•ฉ๋‹ˆ๋‹ค.
  • ๊ทธ๋ฆฌ๊ณ  Scaled Dot-Product Attention"์„ ์ง„ํ–‰ํ•œ ๊ฐ’์„ ์ •๊ทœํ™”๋ฅผ ์ง„ํ–‰ํ•ด์ฃผ๊ธฐ ์œ„ํ•ด์„œ Softmax ํ•จ์ˆ˜๋ฅผ ๊ฑฐ์ณ์„œ ๋ณด์ •์„ ์œ„ํ•ด ๊ณ„์‚ฐ๋œ scoreํ–‰๋ ฌ, valueํ–‰๋ ฌ์„ ๋‚ด์ ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋ฉด ์ตœ์ข…์ ์œผ๋กœ Attention ํ–‰๋ ฌ์„ ์–ป์„ ์ˆ˜ ์žˆ๊ฒŒ ๋ฉ๋‹ˆ๋‹ค. ์˜ˆ์‹œ ๋ฌธ์žฅ์„ ๋“ค์–ด์„œ ์„ค๋ช…ํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
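
Putting the steps above together, the Scaled Dot-Product Attention formula is:

Attention(Q, K, V) = softmax(QK^T / √d_k) · V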

 

"I am a student"๋ผ๋Š” ๋ฌธ์žฅ์œผ๋กœ ์˜ˆ์‹œ๋ฅผ ๋“ค์–ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

  • Self-Attention์€ Query(์ฟผ๋ฆฌ), Key(ํ‚ค), Value(๋ฐธ๋ฅ˜) 3๊ฐœ ์š”์†Œ ์‚ฌ์ด์˜ ๋ฌธ๋งฅ์  ๊ด€๊ณ„์„ฑ์„ ์ถ”์ถœํ•ฉ๋‹ˆ๋‹ค.
Q = X * Wq, K = X * Wk, W = X * Wv
  • ์œ„์˜ ์ˆ˜์‹์ฒ˜๋Ÿผ Input Vector Sequence(์ž…๋ ฅ ๋ฒกํ„ฐ ์‹œํ€€์Šค) X์— Query(์ฟผ๋ฆฌ), Key(ํ‚ค), Value(๋ฐธ๋ฅ˜)๋ฅผ ๋งŒ๋“ค์–ด์ฃผ๋Š” ํ–‰๋ ฌ(W)๋ฅผ ๊ฐ๊ฐ ๊ณฑํ•ด์ค๋‹ˆ๋‹ค.
  • Input Vector Sequence(์ž…๋ ฅ ๋ฒกํ„ฐ ์‹œํ€€์Šค)๊ฐ€ 4๊ฐœ์ด๋ฉด ์™ผ์ชฝ์— ์žˆ๋Š” ํ–‰๋ ฌ์„ ์ ์šฉํ•˜๋ฉด Query(์ฟผ๋ฆฌ), Key(ํ‚ค), Value(๋ฐธ๋ฅ˜) ๊ฐ๊ฐ 4๊ฐœ์”ฉ, ์ด 12๊ฐœ๊ฐ€ ๋‚˜์˜ต๋‹ˆ๋‹ค.
* Word Embedding: converting a word into a vector by mapping it into a dense vector space, representing it as a real-valued vector.
  1. ๊ฐ ๋‹จ์–ด์— ๋ฐํ•˜์—ฌ Word Embedding(๋‹จ์–ด ์ž„๋ฒ ๋”ฉ)์„ ํ•ฉ๋‹ˆ๋‹ค. ๋‹จ์–ด 'i'์˜ Embedding์ด [1,1,1,1]์ด๋ผ๊ณ  ํ–ˆ์„ ๋•Œ, ์ฒ˜์Œ 'i'์˜ ์ฒ˜์Œ Query(์ฟผ๋ฆฌ), Key(ํ‚ค), Value(๋ฐธ๋ฅ˜)๋ฅผ ๊ฐ๊ฐ 'Q_i, original', 'K_i, original', 'V_i, original', ๋ผ๊ณ  ํ•ฉ๋‹ˆ๋‹ค.
    • Embedding์˜ ๊ฐ’์ด ๋‹ค [1,1,1,1]๋กœ ๊ฐ™์€ ์ด์œ ๋Š” Self-Attention ๋งค์ปค๋‹ˆ์ฆ˜์—์„œ๋Š” ๊ฐ™์•„์•ผ ํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๋ชจ๋‘ [1,1,1,1]์ด๋ผ๊ณ  ๋™์ผํ•ฉ๋‹ˆ๋‹ค.
  2. ํ•™์Šต๋œ Weight(๊ฐ€์ค‘์น˜)๊ฐ’์ด 'WQ', 'WK', 'WV'๋ผ๊ณ  ํ• ๋•Œ Original ๊ฐ’๋“ค๊ณผ ์ ๊ณฑ์„ ํ•ด์ฃผ๋ฉด ์ตœ์ข…์ ์œผ๋กœ 'Q', 'K', 'V'๊ฐ’์ด ๋„์ถœ๋ฉ๋‹ˆ๋‹ค.
    • 'Q', 'K', 'V'๊ฐ’์„ ์ด์šฉํ•ด์„œ ์œ„์—์„œ ์„ค์ •ํ•œ ๋ณด์ •๋œ 'Attention Score'๋ฅผ ๊ณฑํ•ด์ฃผ๋ฉด ์•„๋ž˜์˜ ์™ผ์ชฝ์‹๊ณผ ๊ฐ™์ด 1.5๋ผ๋Š” ๊ฐ’์ด ๋‚˜์˜ต๋‹ˆ๋‹ค.
    • ํ–‰๋ ฌ 'Q', 'K'๋Š” ์„œ๋กœ ์ ๊ณฑ ๊ณ„์‚ฐํ•ด์ฃผ๊ณ , ์—ฌ๊ธฐ์„œ ํ–‰๋ ฌ 'Q', 'K', 'V'์˜ Dimension(์ฐจ์›)์€ 4์ด๋ฏ€๋กœ ๋ฃจํŠธ 4๋กœ ๋‚˜๋ˆ„์–ด์ค๋‹ˆ๋‹ค.

  • 'i' ๋ฟ๋งŒ ์•„๋‹ˆ๋ผ ๋ชจ๋“  ๋‹จ์–ด๊ฐ„์˜ 'Self-Attention'์„ ํ•ด์ฃผ๋ฉด ์œ„์˜ ์˜ค๋ฅธ์ชฝ ํ–‰๋ ฌ๊ณผ ๊ฐ™์€ ๊ฒฐ๊ณผ๊ฐ€ ๋‚˜์˜ต๋‹ˆ๋‹ค.
  • ๊ฐ€์šด๋ฐ ๋…ธ๋ฝ์ƒ‰ ๋ถ€๋ถ„์€ ์ž๊ธฐ ์ž์‹ ์— ๋Œ€ํ•œ 'Attention'์ด๋ฏ€๋กœ ๋‹น์—ฐํžˆ ๊ฐ’์ด ์ œ์ผ ํฌ๊ณ , ์–‘์ชฝ ์ดˆ๋ก์ƒ‰ ๋ถ€๋ถ„์„ ๋ณด๋ฉด ์ ์ˆ˜๊ฐ€ ๋†’์Šต๋‹ˆ๋‹ค.

Attention ๊ตฌํ•˜๋Š” ๊ณผ์ •์„ ๋„์‹ํ™”

  • ์œ„์˜ ๊ทธ๋ฆผ์€ ๋‹จ์–ด ํ•˜๋‚˜ํ•˜๋‚˜์˜ 'Attention'์„ ๊ตฌํ•˜๋Š” ๊ณผ์ •์„ ๋„์‹ํ™” ํ•œ ๊ทธ๋ฆผ์ž…๋‹ˆ๋‹ค.
  • ์‹ค์ œ๋กœ๋Š” ์—ฌ๋Ÿฌ ๋‹จ์–ด๋ฅผ ์œ„์˜ ๊ทธ๋ฆผ๊ณผ ๊ฐ™์ด ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ๋ฅผ ํ•ด์„œ ๊ณ„์‚ฐ์„ ํ•ฉ๋‹ˆ๋‹ค.
  • ๋ณ‘๋ ฌ์ฒ˜๋ฆฌ๋ฅผ ํ•˜๋ฉด ์—ฐ์‚ฐ์†๋„๊ฐ€ ๋นจ๋ฆฌ์ง€๋Š” ์ด์ ์ด ์žˆ๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค. ๊ฐ„๋‹จํžˆ 'Self-Attention' ๊ณผ์ •์„ ์š”์•ฝํ•ด์„œ ์„ค๋ช…ํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
Self-Attention process summary
1. Embed the target sentence and, through training, learn the weight matrices for the Query, Key, and Value.
2. Take the dot product of each word's embedding (at this point Query = Key = Value) with the weights to obtain the final Q, K, V.
3. Use the Attention score formula to derive the Self-Attention value for each word.
4. Compare the entries of the Self-Attention values to find the words with high correlation.

Self-Attention Example Code

PyTorch ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ๋ถˆ๋Ÿฌ์™€์„œ ํ•œ๋ฒˆ ์˜ˆ์ œ ์ฝ”๋“œ๋ฅผ ๋Œ๋ ค๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
import torch

1. ๋ณ€์ˆ˜์ •์˜

  • PyTorch ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ํ™œ์šฉํ•ด์„œ ์ฝ”๋“œ๋กœ ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. ์šฐ์„  Input Vector Sequence X์™€ Query(์ฟผ๋ฆฌ), Key(ํ‚ค), Value(๋ฐธ๋ฅ˜)๊ตฌ์ถ•์— ํ•„์š”ํ•œ ํ–‰๋ ฌ๋“ค์„ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค. 
x = torch.tensor([
  [1.0, 0.0, 1.0, 0.0],
  [0.0, 2.0, 0.0, 2.0],
  [1.0, 1.0, 1.0, 1.0],
])

w_query = torch.tensor([
  [1.0, 0.0, 1.0],
  [1.0, 0.0, 0.0],
  [0.0, 0.0, 1.0],
  [0.0, 1.0, 1.0]
])

w_key = torch.tensor([
  [0.0, 0.0, 1.0],
  [1.0, 1.0, 0.0],
  [0.0, 1.0, 0.0],
  [1.0, 1.0, 0.0]
])

w_value = torch.tensor([
  [0.0, 2.0, 0.0],
  [0.0, 3.0, 0.0],
  [1.0, 0.0, 3.0],
  [1.0, 1.0, 0.0]
])

2. Building the Query, Key, and Value

  • The Query, Key, and Value are built from the vector sequence.
  • Here, torch.matmul is the function that performs matrix multiplication.
keys = torch.matmul(x, w_key)
querys = torch.matmul(x, w_query)
values = torch.matmul(x, w_value)

3. Computing the Attention Scores

  • ์œ„์—์„œ ๋งŒ๋“  Query(์ฟผ๋ฆฌ), Key(ํ‚ค) Vector๋“ค์„ ํ–‰๋ ฌ๊ณฑํ•ด์„œ Attention Score๋ฅผ ๋งŒ๋“œ๋Š” ๊ณผ์ •์ž…๋‹ˆ๋‹ค.
  • ์—ฌ๊ธฐ์„œ keys.T๋Š” Key(ํ‚ค) Vector๋“ค์„ Transpose(์ „์น˜)ํ•œ ํ–‰๋ ฌ์ž…๋‹ˆ๋‹ค.
attn_scores = torch.matmul(querys, keys.T)
attn_scores
tensor([[ 2.,  4.,  4.],
        [ 4., 16., 12.],
        [ 4., 12., 10.]])

4. Turning the scores into softmax probabilities

  • Key(ํ‚ค) Vector์˜ ์ฐจ์šฐ๋„ˆ์ˆ˜์˜ ์ œ๊ณฑ๊ทผ์œผ๋กœ ๋‚˜๋ˆ ์ค€๋’ค Softmax๋ฅผ ์ทจํ•˜๋Š” ๊ณผ์ •์ž…๋‹ˆ๋‹ค.
import numpy as np
from torch.nn.functional import softmax
key_dim_sqrt = np.sqrt(keys.shape[-1])
attn_scores_softmax = softmax(attn_scores / key_dim_sqrt, dim=-1)
  • keys is the tensor containing the key values used in the Attention mechanism.
  • keys.shape[-1] is the size of the last dimension of the keys tensor, i.e. the key dimension.
  • key_dim_sqrt is the square root of the key dimension, used to normalize the attention scores.
  • The normalization divides the attention scores by the square root of the key dimension.
  • This keeps the scores on a reasonable scale, which helps convergence and stability.
attn_scores_softmax
tensor([[1.3613e-01, 4.3194e-01, 4.3194e-01],
        [8.9045e-04, 9.0884e-01, 9.0267e-02],
        [7.4449e-03, 7.5471e-01, 2.3785e-01]])

5. Weighted sum of the softmax probabilities and the Values

  • The softmax probabilities and the Value vectors are combined with a weighted sum.
  • What Self-Attention actually learns are the weight matrices that produce the Query, Key, and Value.
  • In the code, these are 'w_query', 'w_key', and 'w_value'.
  • They are updated during training in the direction that best performs the task (e.g. machine translation).
weighted_values = torch.matmul(attn_scores_softmax, values)
weighted_values
tensor([[1.8639, 6.3194, 1.7042],
        [1.9991, 7.8141, 0.2735],
        [1.9926, 7.4796, 0.7359]])

Self-Attention Model Example Code

This is a Self-Attention implementation using PyTorch.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size  # embedding dimension
        self.heads = heads  # number of heads
        self.head_dim = embed_size // heads  # embedding dimension per head

        assert (
            self.head_dim * heads == embed_size
        ), "embed_size must be divisible by heads"

        # Linear layers for the values, keys, and queries
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)

        # ์–ดํ…์…˜ ์ดํ›„ ์ตœ์ข… ์ถœ๋ ฅ์„ ์œ„ํ•œ ์„ ํ˜• ๋ ˆ์ด์–ด
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        # Number of training examples (batch size)
        N = query.shape[0]

        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # ์ž„๋ฒ ๋”ฉ์„ ํ—ค๋“œ ๋‹น ๋‹ค๋ฅธ ๋ถ€๋ถ„์œผ๋กœ ๋‚˜๋ˆ„๊ธฐ
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # Linear projections for the values, keys, and queries
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Use einsum to compute the energy (attention scores)
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        # ๋งˆ์Šคํฌ ์ ์šฉ ์—ฌ๋ถ€ ํ™•์ธ
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Apply softmax to normalize the attention scores
        attention = torch.nn.functional.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        # ์–ดํ…์…˜ ์Šค์ฝ”์–ด๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๊ฐ’์˜ ๊ฐ€์ค‘ ํ‰๊ท  ๊ณ„์‚ฐ
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        # ์ตœ์ข… ์ถœ๋ ฅ์„ ์œ„ํ•œ ์„ ํ˜• ๋ ˆ์ด์–ด ์ ์šฉ
        out = self.fc_out(out)
        return out

 

  • This code implements a Self-Attention model.
  1. First, each of the Query, Key, and Value is mapped to the appropriate dimension through a linear layer.
  2. The embedding is split into a separate part for each head.
  3. Einsum is used to compute the Attention Scores between the Queries and the Keys.
  4. Optionally, a mask is applied so that invalid positions are masked out of the 'Attention Score'.
  5. The Attention Scores are normalized with the softmax function, and a weighted sum of the Values is computed.
  6. Finally, a linear layer is applied to the result to obtain and return the final output.

5. Multi-Head Attention

Multi-Head Attention์€ ์—ฌ๋Ÿฌ๊ฐœ์˜ 'Attention Head'๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. ์—ฌ๊ธฐ์„œ ๊ฐ 'Head Attention'์—์„œ ๋‚˜์˜จ ๊ฐ’์„ ์—ฐ๊ฒฐํ•ด ์‚ฌ์šฉํ•˜๋Š” ๋ฐฉ์‹์ž…๋‹ˆ๋‹ค.
  • ํ•œ๋ฒˆ์˜ 'Attention'์„ ์ด์šฉํ•˜์—ฌ ํ•™์Šต์„ ์‹œํ‚ค๋Š”๊ฒƒ ๋ณด๋‹ค๋Š” 'Attention'์„ ๋ณ‘๋ ฌ๋กœ ์—ฌ๋Ÿฌ๊ฐœ ์‚ฌ์šฉํ•˜๋Š” ๋ฐฉ์‹์ž…๋‹ˆ๋‹ค.
  1. ์ˆœ์„œ๋Š” ์›๋ž˜์˜ Query(์ฟผ๋ฆฌ), Key(ํ‚ค) Value(๋ฐธ๋ฅ˜) ํ–‰๋ ฌ ๊ฐ’์„ Head์ˆ˜ ๋งŒํผ ๋ถ„ํ• ํ•ฉ๋‹ˆ๋‹ค.
  2. ๋ถ„ํ• ํ•œ ํ–‰๋ ฌ ๊ฐ’์„ ํ†ตํ•ด, ๊ฐ 'Attention' Value๊ฐ’๋“ค์„ ๋„์ถœํ•ฉ๋‹ˆ๋‹ค.
  3. ๋„์ถœ๋œ 'Attention value'๊ฐ’๋“ค์„ concatenate(์Œ“์•„ ํ•ฉ์น˜๊ธฐ)ํ•˜์—ฌ์„œ ์ตœ์ข… 'Attention Value'๋ฅผ ๋„์ถœํ•ฉ๋‹ˆ๋‹ค.

  • [4x4] ํฌ๊ธฐ์˜ ๋ฌธ์žฅ Embedding Vector์™€ [4x8]์˜ Query(์ฟผ๋ฆฌ), Key(ํ‚ค) Value(๋ฐธ๋ฅ˜)๊ฐ€ ์žˆ์„ ๋•Œ, ์ผ๋ฐ˜์ ์ธ ํ•œ ๋ฒˆ์— ๊ณ„์‚ฐํ•˜๋Š” Attention ๋ฉ”์ปค๋‹ˆ์ฆ˜์€ [4x4]*[4x8]=[4x8]์˜ 'Attention Value'๊ฐ€ ํ•œ ๋ฒˆ์— ๋„์ถœ๋ฉ๋‹ˆ๋‹ค.

  • Viewed as 'Multi-Head Attention', there are 4 heads here: 'I, am, a, student'.
  • Since there are 4 heads, each computation step needs only 1/4 of the work.
  • In the figure above, the [4x8] Query, Key, and Value are each split into 4 parts of size [4x2]. The 'Attention Value' of each part is then [4x2].
  • Finally, concatenating these 'Attention Values' gives a [4x8] matrix, the same size as the result of the ordinary Attention mechanism. Let's look at an example.
Summary: instead of computing the Query, Key, and Value all at once, the mechanism splits them by the number of heads, computes each part, and then combines the Attention Values. In short: split, compute, then merge. (A formula sketch follows below.)
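
Written out, with W0 denoting the final output weight matrix mentioned in the example below, and Qi, Ki, Vi denoting the per-head splits of Q, K, V:

MultiHead(Q, K, V) = Concat(Attention(Q1, K1, V1), ..., Attention(Qh, Kh, Vh)) · W0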

Multi-Head Attention Example

  • The figure shows multi-head attention with 2 input words, a value dimension of 3, and 8 heads.
  • The self-attention result of an individual head is a matrix of size 'number of input words × value dimension', i.e. 2×3.
  • Concatenating the self-attention results of the 8 heads as in ① of the following figure gives a 2×24 matrix.
  • The final result of Multi-Head Attention has size 'number of input words' x 'target dimension', and it is applied in both the Encoder and Decoder blocks.
Multi-Head Attention finishes by matrix-multiplying the concatenation of the individual heads' self-attention results (①) with W0.
→ W0 has size: (number of columns of the concatenated self-attention result matrix) × (target dimension)

Multi-Head Attention Model Example Code

Here is example code for a Multi-Head Attention model, implemented in TF (TensorFlow).
class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, **kargs):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = kargs['num_heads']
        self.d_model = kargs['d_model']

        assert self.d_model % self.num_heads == 0

        self.depth = self.d_model // self.num_heads

        # Initialize the weight matrices for the Query, Key, and Value
        self.wq = tf.keras.layers.Dense(kargs['d_model'])  # weights for the Query
        self.wk = tf.keras.layers.Dense(kargs['d_model'])  # weights for the Key
        self.wv = tf.keras.layers.Dense(kargs['d_model'])  # weights for the Value

        # Initialize the weight matrix for the final output
        self.dense = tf.keras.layers.Dense(kargs['d_model'])

    def split_heads(self, x, batch_size):
        # ๋งˆ์ง€๋ง‰ ์ฐจ์›์„ (num_heads, depth)๋กœ ๋‚˜๋ˆ„๊ณ  ๊ฒฐ๊ณผ๋ฅผ (batch_size, num_heads, seq_len, depth)๋กœ ๋ณ€ํ™˜
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        batch_size = tf.shape(q)[0]

        # Query, Key, Value์— ๋Œ€ํ•œ ๊ฐ€์ค‘์น˜ ํ–‰๋ ฌ ์ ์šฉ
        q = self.wq(q)  # Query์— ๋Œ€ํ•œ ๊ฐ€์ค‘์น˜ ์ ์šฉ
        k = self.wk(k)  # Key์— ๋Œ€ํ•œ ๊ฐ€์ค‘์น˜ ์ ์šฉ
        v = self.wv(v)  # Value์— ๋Œ€ํ•œ ๊ฐ€์ค‘์น˜ ์ ์šฉ

        # ๊ฐ ํ—ค๋“œ๋กœ ์ž„๋ฒ ๋”ฉ์„ ๋‚˜๋ˆ„๊ธฐ
        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # Perform Scaled Dot-Product Attention
        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask)

        # ์–ดํ…์…˜ ํ—ค๋“œ๋ฅผ ๋‹ค์‹œ ํ•ฉ์น˜๊ณ  ์ตœ์ข… ์ถœ๋ ฅ์œผ๋กœ ๋ณ€ํ™˜
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
        concat_attention = tf.reshape(scaled_attention,
                                      (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

        # Compute the final output
        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

        return output, attention_weights
  • This code implements a Multi-Head Attention model.
  1. The class is initialized by setting the needed parameters and defining the weights used for the Query, Key, Value, and final output.
  2. The last dimension of the input tensor is split across the heads, and the result is transposed so that the head dimension comes first.
  3. The respective weights are applied to the Query, Key, and Value, which are then split across the heads.
  4. Scaled Dot-Product Attention is called and its result is processed.
  5. The final output is computed and returned. (A usage sketch follows below.)

  • Ps. ์ด๋ ‡๊ฒŒ Attention์— ๋ฐํ•ด์„œ ์„ค๋ช…ํ•œ ๊ธ€์„ ์ผ๋Š”๋ฐ, ๋‚ด์šฉ๋„ ๋” ์“ธ๊ฒŒ ์žˆ์ง€๋งŒ ๋„ˆ๋ฌด ๊ธธ์–ด์งˆ๊ฑฐ ๊ฐ™์•„์„œ ์—ฌ๊ธฐ๊นŒ์ง€๋งŒ ์“ฐ๊ณ  ๋‹ค์Œ ๊ธ€๋กœ ๋„˜๊ธฐ๊ฒ ์Šต๋‹ˆ๋‹ค.
  • ๋‹ค์Œ๊ธ€์—๋Š” ํ•ฉ์„ฑ๊ณฑ, ์ˆœํ™˜ ์‹ ๊ฒฝ๋ง์ด๋ž‘ ๋น„๊ตํ•œ๊ฑฐ๋ž‘, Encoder, Decoder์—์„œ ์ˆ˜ํ–‰ํ•˜๋Š” Self-Attention์— ๋ฐํ•˜์—ฌ ์•Œ์•„๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.