A A
[NLP] GRU Model - LSTM Model์„ ๊ฐ€๋ณ๊ฒŒ ๋งŒ๋“  ๋ชจ๋ธ

1. GRU Model์€ ๋ฌด์—‡์ผ๊นŒ?

GRU (Gated Recurrent Unit)๋Š” ์ˆœํ™˜ ์‹ ๊ฒฝ๋ง(RNN)์˜ ํ•œ ์ข…๋ฅ˜๋กœ, ์•ž์—์„œ ์„ค๋ช…ํ•œ LSTM(Long Short-Term Memory)๋ชจ๋ธ์˜ ๋‹จ์ˆœํ™”๋œ ํ˜•ํƒœ๋กœ ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • GRU Model์€ LSTM Model๊ณผ ๋น„์Šทํ•œ ๋ฐฉ์‹์œผ๋กœ ์ž‘๋™ํ•˜์ง€๋งŒ, ๋” ๊ฐ„๋‹จํ•œ ๊ตฌ์กฐ๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.
  • LSTM Model์˜ ์žฅ์ ์„ ์œ ์ง€ํ•˜๋˜, Gate(๊ฒŒ์ดํŠธ)์˜ ๊ตฌ์กฐ๋ฅผ ๋‹จ์ˆœํ•˜๊ฒŒ ๋งŒ๋“  ๋ชจ๋ธ์ด GRU Model ์ž…๋‹ˆ๋‹ค.
  • ๋˜ํ•œ GRU, LSTM Model์€ ๋‘˜๋‹ค Long-Term Dependency(์žฅ๊ธฐ ์˜์กด์„ฑ) ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•˜์—ฌ ๋งŒ๋“ค์–ด ์กŒ์Šต๋‹ˆ๋‹ค.

GRU Model - LSTM ์„ ๊ฐ€๋ณ๊ฒŒ(๊ฒฝ๋Ÿ‰ํ™”) ๋งŒ๋“  ๋ชจ๋ธ

  • LSTM Model์„ ์„ค๋ช…ํ•œ ๊ธ€์—์„œ ์„ค๋ช…ํ–ˆ์ง€๋งŒ LSTM Model์€ "Cell State(์…€ ์ƒํƒœ)"์™€ "Hidden state(์ˆจ๊ฒจ์ง„ ์ƒํƒœ)"๋ฅผ ๊ฐ๊ฐ ๋‹ค๋ฃน๋‹ˆ๋‹ค. ๊ทธ๋ฆฌ๊ณ  4๊ฐœ์˜ "Gate" ์ž…๋ ฅ, ๋ง๊ฐ, ์ถœ๋ ฅ, ๊ธฐ์–ต ๊ฒŒ์ดํŠธ๋ฅผ ๊ฐ€์ง‘๋‹ˆ๋‹ค.
+ ์ œ๊ฐ€ ์ฐพ์•„๋ณด๋ฉด์„œ ์•ˆ ๋‚ด์šฉ์ธ๋ฐ, ๊ธฐ์–ต ๊ฒŒ์ดํŠธ๋Š” ๋ฌธํ—Œ์— ๋”ฐ๋ผ ๊ฒŒ์ดํŠธ๋ผ๊ณ  ์–ธ๊ธ‰ ๋˜๊ธฐ๋„ ํ•˜๊ณ  ์•ˆ๋œ๋‹ค๊ณ  ํ•˜๊ธฐ๋„ ํ•ฉ๋‹ˆ๋‹ค.
์ž…๋ ฅ ๊ฒŒ์ดํŠธ์—์„œ ํ†ต๊ณผํ•œ ์ •๋ณด & ๋ง๊ฐ ๊ฒŒ์ดํŠธ์—์„œ ์ œ๊ฑฐ๋˜์ง€ ์•Š์€ ์ •๋ณด๋ฅผ ํ•ฉ์ณ Cell State(์…€ ์ƒํƒœ)๋กœ ์—…๋ฐ์ดํŠธ ํ•˜๋Š” ๊ณผ์ •์„ "๊ธฐ์–ต" ๋‹จ๊ณ„๋ผ๊ณ  ํ•˜๋Š”๋ฐ ์ด ๊ณผ์ •์„ "๊ธฐ์–ต ๊ฒŒ์ดํŠธ"๋ผ๊ณ  ํ•œ๋‹ค๊ณ  ํ•˜๋„ค์š”.
  • ๊ทผ๋ฐ GRU Model์€ LSTM Model๊ณผ ๋‹ค๋ฅธ์ ์ด ํ•˜๋‚˜์˜ "Hidden state(์ˆจ๊ฒจ์ง„ ์ƒํƒœ)" ๋งŒ ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค.
  • ๊ทธ๋ฆฌ๊ณ  LSTM์€ 4๊ฐœ์˜ ๊ฒŒ์ดํŠธ๋ฅผ ๊ฐ€์ง„๊ฒƒ๊ณผ ๋‹ฌ๋ฆฌ 2๊ฐœ์˜ ๊ฒŒ์ดํŠธ "Update Gate(์—…๋ฐ์ดํŠธ ๊ฒŒ์ดํŠธ)", "Reset Gate(๋ฆฌ์…‹ ๊ฒŒ์ดํŠธ)"๋งŒ ๊ฐ€์ง‘๋‹ˆ๋‹ค.

2. Update Gate(์—…๋ฐ์ดํŠธ ๊ฒŒ์ดํŠธ), Reset Gate(๋ฆฌ์…‹ ๊ฒŒ์ดํŠธ)

๊ทธ๋Ÿฌ๋ฉด ํ•œ๋ฒˆ "Update Gate(์—…๋ฐ์ดํŠธ ๊ฒŒ์ดํŠธ)", "Reset Gate(๋ฆฌ์…‹ ๊ฒŒ์ดํŠธ)"์— ๋ฐํ•˜์—ฌ ์„ค๋ช…ํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

Update Gate(์—…๋ฐ์ดํŠธ ๊ฒŒ์ดํŠธ)

  • ์ด ๊ฒŒ์ดํŠธ๋Š” ์ƒˆ๋กœ์šด input(์ž…๋ ฅ)๊ฐ’์„ ์–ผ๋งˆ๋‚˜ "Hidden state(์ˆจ๊ฒจ์ง„ ์ƒํƒœ)"์— ๋ฐ˜์˜ํ• ์ง€๋ฅผ ๊ฒฐ์ •ํ•ฉ๋‹ˆ๋‹ค.
  • ๋˜ํ•œ "Update Gate(์—…๋ฐ์ดํŠธ ๊ฒŒ์ดํŠธ)"๋Š” LSTM Model์˜ "Input Gate(์ž…๋ ฅ ๊ฒŒ์ดํŠธ)" ์™€ "Forget Gate(๋ง๊ฐ ๊ฒŒ์ดํŠธ)"์˜ ์—ญํ• ์„ ํ•ฉ์นœ ๊ฒƒ์œผ๋กœ ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Reset Gate(๋ฆฌ์…‹ ๊ฒŒ์ดํŠธ)

  • ์ด ๊ฒŒ์ดํŠธ๋Š” ๊ณผ๊ฑฐ์˜ "Hidden state(์ˆจ๊ฒจ์ง„ ์ƒํƒœ)"๋ฅผ ์–ผ๋งˆ๋‚˜ '์žŠ์–ด๋ฒ„๋ฆด(Forget)'์ง€๋ฅผ ๊ฒฐ์ •ํ•ฉ๋‹ˆ๋‹ค.
  • ์ด๊ฑด LSTM Model์˜ "Cell State(์…€ ์ƒํƒœ)"๋ฅผ ์–ด๋–ป๊ฒŒ ์—…๋ฐ์ดํŠธ๋ฅผ ํ• ์ง€ ๊ฒฐ์ •ํ•˜๋Š” ์—ญํ• ์„ ํ•ฉ๋‹ˆ๋‹ค.

3. GRU Model ์•ˆ์— ๋“ค์–ด๊ฐ€๋ณด๊ธฐ

LSTM์˜ ๋ชจ๋ธ๊ณผ ๋‹ค๋ฅธ์ ์€ "Cell State(์…€ ์ƒํƒœ)"๋ฅผ ์—†์• ๊ณ , "Hidden state(์ˆจ๊ฒจ์ง„ ์ƒํƒœ, ๊ทธ๋ฆผ์—์„  h)"๊ฐ€ ๊ทธ ์—ญํ• ์„ ํ•˜๋„๋ก ์ˆ˜์ •ํ•œ๊ฒƒ ์ž…๋‹ˆ๋‹ค.

 

GRU Model - LSTM ์„ ๊ฐ€๋ณ๊ฒŒ(๊ฒฝ๋Ÿ‰ํ™”) ๋งŒ๋“  ๋ชจ๋ธ

  • ์„ค๋ช…ํ•ด๋ณด์ž๋ฉด ์—ฌ๊ธฐ์„œ "ht"๋Š” t ์‹œ์ ์—์„œ์˜ Hidden state(์ˆจ๊ฒจ์ง„ ์ƒํƒœ)๋ฅผ ์˜๋ฏธํ•˜๊ณ  ์ฃผ๋กœ ์ˆœํ™˜์‹ ๊ฒฝ๋ง(LSTM, GRU)๋ชจ๋ธ์—์„œ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.
    • Hidden state(์ˆจ๊ฒจ์ง„ ์ƒํƒœ)๋Š” Sequence์˜ ๊ฐ Element(์š”์†Œ)๋ฅผ ์ฒ˜๋ฆฌํ• ๋•Œ ๋งˆ๋‹ค ์—…๋ฐ์ดํŠธ ๋˜๋Š” ์‹ ๊ฒฝ๋ง์˜ "๊ธฐ์–ต"์„ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค.
    • ๋˜ํ•œ ์ด์ „ Element์˜ ์ •๋ณด๋“ค์„ ํฌํ•จํ•˜๊ณ  ์žˆ์œผ๋ฉฐ, ์ด ์ •๋ณด๋“ค์€ ๋‹ค์Œ Element๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š”๋ฐ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.
    • ๋˜ํ•œ "ht"๋Š” t -1 ์‹œ์ ์˜ Hidden state & "ht -1"์˜ t ์‹œ์ ์— input(์ž…๋ ฅ)๊ฐ’์— ๊ธฐ๋ฐ˜ํ•˜์—ฌ ๊ณ„์‚ฐ๋ฉ๋‹ˆ๋‹ค.

 

  • "rt"๋Š” Reset Gate(๋ฆฌ์…‹ ๊ฒŒ์ดํŠธ)๋ฅผ ๋‚˜ํƒ€๋‚ด๋ฉฐ ๋‹ค์Œ ์ƒํƒœ "ht", Hidden state(์ˆจ๊ฒจ์ง„ ์ƒํƒœ)๋ฅผ ๊ณ„์‚ฐ ํ•˜๊ธฐ ์œ„ํ•ด ์ด์ „ ์ƒํƒœ"ht-1"๊ฐ’์—์„œ ์‚ฌ์šฉํ•  ๋ถ€๋ถ„์„ ์ œ์–ด ํ•ฉ๋‹ˆ๋‹ค.
    • ์ด๋Š” ํ˜„์žฌ input(์ž…๋ ฅ)๊ฐ’๊ณผ ์ด์ „ Hidden state(์ˆจ๊ฒจ์ง„ ์ƒํƒœ)์— ๊ธฐ๋ฐ˜ํ•˜์—ฌ ๊ณ„์‚ฐ๋˜๋ฉฐ, 0~ 1 ์‚ฌ์ด์˜ ๊ฐ’์„ ๊ฐ€์ง‘๋‹ˆ๋‹ค.
    • Reset Gate(๋ฆฌ์…‹ ๊ฒŒ์ดํŠธ)์˜ ๊ฐ’์ด 0์— ๊ฐ€๊นŒ์›Œ์ง€๋ฉด ์ด์ „ Hidden state(์ˆจ๊ฒจ์ง„ ์ƒํƒœ)์˜ ์ •๋ณด๋Š” "Forget" ๋ฉ๋‹ˆ๋‹ค.
    • Reset Gate(๋ฆฌ์…‹ ๊ฒŒ์ดํŠธ)์˜ ๊ฐ’์ด 1์— ๊ฐ€๊นŒ์›Œ์ง€๋ฉด ์ด์ „ Hidden state(์ˆจ๊ฒจ์ง„ ์ƒํƒœ)์˜ ์ •๋ณด๋Š” ์œ ์ง€๋ฉ๋‹ˆ๋‹ค.

 

  • "zt"๋Š” Update Gate(์—…๋ฐ์ดํŠธ ๊ฒŒ์ดํŠธ)๋ฅผ ๋‚˜ํƒ€๋‚ด๋ฉฐ ์ด์ „์ƒํƒœ "ht-1"๊ฐ’์„ ์œ ์ง€ํ• ์ง€, ์•„๋‹ˆ๋ฉด ์ƒˆ๋กœ์šด ์ƒํƒœ "ht"๋กœ ๋Œ€์ฒดํ• ์ง€๋ฅผ ๊ฒฐ์ •ํ•ฉ๋‹ˆ๋‹ค.
    • ์ด๋Š” ํ˜„์žฌ input(์ž…๋ ฅ)๊ฐ’๊ณผ ์ด์ „ Hidden state(์ˆจ๊ฒจ์ง„ ์ƒํƒœ)์— ๊ธฐ๋ฐ˜ํ•˜์—ฌ ๊ณ„์‚ฐ๋˜๋ฉฐ, 0~ 1 ์‚ฌ์ด์˜ ๊ฐ’์„ ๊ฐ€์ง‘๋‹ˆ๋‹ค.
    • Update Gate(์—…๋ฐ์ดํŠธ ๊ฒŒ์ดํŠธ)์˜ ๊ฐ’์ด 1์— ๊ฐ€๊นŒ์›Œ์ง€๋ฉด ์ด์ „ Hidden state(์ˆจ๊ฒจ์ง„ ์ƒํƒœ)์˜ ์ •๋ณด๊ฐ€ ์œ ์ง€๋ฉ๋‹ˆ๋‹ค.
    • Update Gate(์—…๋ฐ์ดํŠธ ๊ฒŒ์ดํŠธ)์˜ ๊ฐ’์ด 0์— ๊ฐ€๊นŒ์›Œ์ง€๋ฉด ์ƒˆ๋กœ์šด Hidden state(์ˆจ๊ฒจ์ง„ ์ƒํƒœ)์˜ ์ •๋ณด๋„ ๋Œ€์ฒด๋ฉ๋‹ˆ๋‹ค.

4. GRU Model ์ฝ”๋“œ ์˜ˆ์‹œ

ํ•œ๋ฒˆ GRU Model์˜ ์ฝ”๋“œ ์˜ˆ์‹œ๋ฅผ ํ•œ๋ฒˆ ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
from keras.models import Sequential
from keras.layers import GRU, Dense

# ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
model = Sequential()

# GRU layer add
# ์ž…๋ ฅ ์ฐจ์›์€ ํŠน์„ฑ์˜ ์ˆ˜์— ๋”ฐ๋ผ ๋ณ€๊ฒฝํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. 100 - hidden unit ๊ฐœ์ˆ˜
model.add(GRU(100, input_shape=(timesteps, input_dim)))

# fully connected layer add, ์ถœ๋ ฅ Node ๊ฐœ์ˆ˜๊ฐ€ 1๊ฐœ ์ด๋ฏ€๋กœ 1์ด๋ผ๊ณ  ํ•จ
model.add(Dense(1, activation='sigmoid'))

# model compile
# ์ด์ง„ ๋ถ„๋ฅ˜ ๋ฌธ์ œ์ด๋ฏ€๋กœ ์†์‹ค ํ•จ์ˆ˜๋กœ 'binary_crossentropy'๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
  • GRU layer๋ฅผ ์„ ์–ธํ•ด์ค„๋•Œ, "model.add(GRU(100, input_shape=(timesteps, input_dim)))" ์ด๋Ÿฐ ํ˜•์‹์œผ๋กœ ์„ ์–ธ์„ ํ•ด์ค˜์•ผ ํ•ฉ๋‹ˆ๋‹ค.
    • ์—ฌ๊ธฐ์„œ๋Š” 100์€ hidden Unit์˜ ๊ฐœ์ˆ˜์ž…๋‹ˆ๋‹ค.
    • ์—ฌ๊ธฐ์„œ "timesteps"์€ Sequence์˜ Length(๊ธธ์ด), input_dim์€ ๊ฐ Sequence ์š”์†Œ์˜ ํŠน์„ฑ ๊ฐœ์ˆ˜๋ฅผ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค.

 

  • ๊ทธ๋ฆฌ๊ณ  ํ•˜๋‚˜์˜ ์ถœ๋ ฅ ๋…ธ๋“œ๋ฅผ ๊ฐ€์ง„ Dense Layer(์™„์ „ ์—ฐ๊ฒฐ ๋ ˆ์ด์–ด)๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
    • "model.add(Dense(1, activation='sigmoid'))" -> ์ถœ๋ ฅ Node์˜ ๊ฐœ์ˆ˜๊ฐ€ ํ•˜๋‚˜์ด๋ฏ€๋กœ, activation ์™ผ์ชฝ ์˜†์— ์ถœ๋ ฅ๋…ธ๋“œ์˜ ๊ฐœ์ˆ˜๋ฅผ ์ ๋Š” ํŒŒ๋ผ๋ฏธํ„ฐ ๊ณต๊ฐ„์— 1์„ ์ ์–ด์ฃผ์—ˆ์Šต๋‹ˆ๋‹ค.

 

  • ๋˜ํ•œ ์ œ๊ฐ€ ๊ฐ„๋‹จํžˆ ๋งŒ๋“  ๋ชจ๋ธ์€ Binary(์ด์ง„)๋ถ„๋ฅ˜ ๋ฌธ์ œ์— ์ ํ•ฉํ•˜๊ฒŒ ๋งŒ๋“ค์—ˆ์Šต๋‹ˆ๋‹ค.
    • ๊ทธ๋ž˜์„œ ์ถœ๋ ฅ Dense Layer์— 0~1์‚ฌ์ด์˜ ํ™•๋ฅ ๊ฐ’์„ ์ถœ๋ ฅํ•ด์ฃผ๋Š” "Sigmoid ํ•จ์ˆ˜", Compile ์ฝ”๋“œ์—์„œ๋Š” Binary(์ด์ง„)๋ถ„๋ฅ˜๋ฅผ ์œ„ํ•ด์„œ "binary_crossentropy"๋ฅผ ์‚ฌ์šฉํ–ˆ์Šต๋‹ˆ๋‹ค.

 

  • ๋งŒ์•ฝ์— ๋‹ค์ค‘ Class๋ฅผ ๋ถ„๋ฅ˜ํ•˜๋Š” ๋ฌธ์ œ๋ผ๋ฉด ํ™œ์„ฑํ™” ํ•จ์ˆ˜ -> "activation"์„ "softmax"ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•ด์•ผ ํ•˜๊ณ 
  • ๋ชจ๋ธ Compile ์ฝ”๋“œ์—์„œ ์†์‹ค ํ•จ์ˆ˜ -> "loss" ๋ฅผ 'categorical_crossentropy'๋กœ ์„ค์ •ํ•ด์ฃผ๋ฉด ๋ฉ๋‹ˆ๋‹ค.
+ Softmax ํ•จ์ˆ˜: Vector๋ฅผ input์œผ๋กœ ๋ฐ›์•„์„œ ๊ฐ ์›์†Œ์˜ ๊ฐ’์„ 0~1 ์‚ฌ์ด์˜ ๊ฐ’์œผ๋กœ ๋ฐ˜ํ™˜ํ•˜์—ฌ, ์ด ๊ฐ’๋“ค์˜ ํ•ฉ์ด 1์ด ๋˜๋„๋ก ๋งŒ๋“œ๋Š” ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค. ์ฃผ๋กœ ํ™•๋ฅ  ๋ถ„ํฌ๋ฅผ ๋‚˜ํƒ€๋‚ด๋Š”๋ฐ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.
  • ๊ณต์‹๋ฌธ์„œ์— ์–ด๋– ํ•œ parameter๊ฐ€ ๋“ค์–ด๊ฐˆ์ˆ˜ ์žˆ๋Š”์ง€ ๋‚˜์™€์žˆ์œผ๋‹ˆ๊นŒ ์ฐธ๊ณ ํ•ด์ฃผ์„ธ์š”.
 

tf.keras.layers.GRU  |  TensorFlow v2.15.0.post1

Gated Recurrent Unit - Cho et al. 2014.

www.tensorflow.org