๋ฐ์ํ
1. LSTM Model์ ๋ฌด์์ผ๊น?
LSTM์ Long Short-Term Memory์ ์ฝ์์ ๋๋ค. RNN - Recurrent Neural Network (์ํ ์ ๊ฒฝ๋ง)์ ๋ฌธ์ ์ธ Long-Term Dependency (์ฅ๊ธฐ ์์กด์ฑ) ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์ ์๋ ๋ชจ๋ธ์ ๋๋ค.
- ๊ธฐ์กด์ RNN(์ํ ์ ๊ฒฝ๋ง)๋ชจ๋ธ์ ์๊ฐ & ๊ณต๊ฐ์ ํจํด์ ํ์ตํ๊ณ ์์ธกํ๋๋ฐ ์ ์ฉํฉ๋๋ค. ๊ทธ๋์ ์์ฐจ์ ์ธ ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋๋ฐ์๋ ๊ฐ์ ์ด ์๋ ๋ชจ๋ธ์ ๋๋ค.
- ๋ค๋ง Long-Term Dependency(์ฅ๊ธฐ ์์กด์ฑ) ๋ฌธ์ ๊ฐ ์์ด์ ๊ธด Sequence์ ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๋๋ฐ ์ด๋ ค์์ด ์์ต๋๋ค.
- Long-Term Dependency(์ฅ๊ธฐ ์์กด์ฑ)์ ๋ํ ์ค๋ช ์ ์๋์ ๊ธ์ ์ ํ์์ผ๋๊น ์ฐธ๊ณ ํด์ฃผ์ธ์.
- ๊ทธ๋ฆฌ๊ณ LSTM ๋ชจ๋ธ์ ๊ธฐ์กด์ RNN ๋ชจ๋ธ๊ณผ๋ ๋ค๋ฅธ์ ์ Gradient Flow์ Weight(๊ฐ์ค์น)๊ฐ ๊ณฑํด์ง์ง ์๋๋ก ๋ชจ๋ธ์ ๊ตฌ์กฐ๋ฅผ ๋ณ๊ฒฝํ์ต๋๋ค.
- ์ด์ ๋ ๊ธฐ๋ณธ์ ์ธ RNN์ ๋ชจ๋ธ์ ํ์ต ๋ฌธ์ ์ธ Gradient Vanishing(๊ทธ๋๋์ธํธ ์์ค)๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด์ ์ ๋๋ค.
- ๊ธฐ๋ณธ RNN Model์ Backpropagation(์ญ์ ํ) ๊ณผ์ ์์ Gradient(๊ธฐ์ธ๊ธฐ)๊ฐ ์๊ฐ์ ๊ฑฐ์ฌ๋ฌ ์ฌ๋ผ๊ฐ๋ฉด์ w๊ฐ ๊ณฑํด์ง๋๋ค.
- ์ด ๋ฌธ์ ๋ก ์ธํด์ ๋ง์ฝ Time-step์ด ๊ธธ์ด์ง๋ฉด, Gradient(๊ธฐ์ธ๊ธฐ)๊ฐ 0์ ๊ฐ๊น์์ง๋๋ค.
- ๊ฒฐ๊ตญ, Network๊ฐ ์ด์ Time-step์ ์ ๋ณด๋ฅผ ์ ๊ธฐ์ต ๋ชปํ๋ค๋ Gradient loss(๊ธฐ์ธ๊ธฐ ์์ค) ๋ฌธ์ ๊ฐ ๋ฐ์ํฉ๋๋ค.
- ๊ทธ๋์ Gradient loss(๊ธฐ์ธ๊ธฐ ์์ค) ๋ฌธ์ ๋ฅผ ํด๊ฒฐ ํ๊ธฐ ์ํด์ "Uninterrupted Gradient flow" ์ฆ, Gradient(๊ธฐ์ธ๊ธฐ)๊ฐ ์ ๊ฒฝ๋ง์ ํตํด ์ํํ๊ฒ ํ๋ฅผ์ ์๊ฒ ํด์ค์ผ ํฉ๋๋ค.
Tip. Gradient๋ฅผ ์ํํ ํ๋ฅด๊ฒ ํ๋ ๋ฐฉ์์ด ResNet์ Residual Connection๊ณผ ์ ์ฌํฉ๋๋ค.
* Residual Connection: ResNet์ ํต์ฌ ๊ฐ๋ ์ค ํ๋์ด์ ์ ๊ฒฝ๋ง์ ํ์ต์ ๋๋ ์ฐ๊ฒฐ ๋ฐฉ์์ ๋๋ค.
- ๊ทธ๋์ผ Gradient loss(๊ธฐ์ธ๊ธฐ ์์ค) ๋ฌธ์ ์์ด Backpropagation(์ญ์ ํ)๊ณผ์ ์ด ์ ์ด๋ฃจ์ด ์ง์ ์๊ธฐ ๋๋ฌธ์ ๋๋ค.
2. LSTM ๋ชจ๋ธ์ ๊ตฌ์กฐ
LSTM ๋ชจ๋ธ์ ์ด๋ ํ ๊ตฌ์กฐ๋ก ์ด๋ฃจ์ด์ ธ ์๊ณ , ์ด๋ ํ ๋ฐฉ์์ผ๋ก ์๋ํ๋์ง ํ๋ฒ ์์๋ณด๊ฒ ์ต๋๋ค.
- LSTM ๋ชจ๋ธ์ ์ฅ๊ธฐ ๊ธฐ์ต๊ณผ ๋จ๊ธฐ ๊ธฐ์ต์ด ์๋ก์ด Event์ ํฉ์ณ์ ์ ๊ฐฑ์ ๋๋ ๋ฐฉ์์ผ๋ก ์งํ๋ฉ๋๋ค.
- ์ฅ๊ธฐ ๊ธฐ์ต์ ์ค๋ ์ง์๋๋๋ก, ๋จ๊ธฐ ๊ธฐ์ต์ ์ต๊ทผ ์ฌ๊ฑด์ ์ค์ฌ์ผ๋ก ๊ธฐ์ตํ๋๋ก ๊ธฐ์ต์ด ํ์ฑ๋๋ ๊ณผ์ ์ด ๋ถ๋ฆฌ๋ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง๋๋ค.
- ๊ทธ๋ฆฌ๊ณ ์ฅ๊ธฐ ๊ธฐ์ต์ "Cell State", ๋จ๊ธฐ ๊ธฐ์ต์ "Hidden state" ์ํ๋ฅผ ๊ฐ์ง๋๋ค. Cell State & Hidden state์ ๋ํ ์ค๋ช ์ ์๋์์ ํ๋๋ก ํ๊ฒ ์ต๋๋ค.
์ด๋ฒ์๋ ํ๋ฒ LSTM ๋ชจ๋ธ์ ์์ธํ ๊ตฌ์กฐ์ ๋ํ์ฌ ํ๋ฒ ๋ด๋ณด๊ฒ ์ต๋๋ค.
๊ทผ๋ฐ, ๋ชจ๋ธ์ ๊ตฌ์กฐ๋ฅผ ๋ณด๋ฉด, "Gate"๋ผ๋ ๊ฐ๋ ์ด ์์ต๋๋ค. ํ๋ฒ ์ค๋ช ์ ํด๋ณด๋ฉด
LSTM ๋ชจ๋ธ์ Long-Term Dependency(์ฅ๊ธฐ ์์กด์ฑ)์ ํด๊ฒฐํ๊ธฐ ์ํด์ 'Cell State (์ ์ํ)' ๋ผ๋ ๊ฐ๋ ์ด ์กด์ฌํฉ๋๋ค.
- "Cell State (์ ์ํ)"๋ LSTM Model์์ ๊ฐ Cell์์ ์ ์ง๋๋ 'Memory' ๋ผ๋ ๊ฐ๋ ์ผ๋ก ์กด์ฌํฉ๋๋ค.
- ์ด๋ฅผ ํตํด์ LSTM Model์ ์ ๋ณด๋ฅผ ๊ณ์ ๊ฐ์ง๊ณ ์๊ฑฐ๋ ํ์์๋ ์ ๋ณด๋ฅผ ๋ฒ๋ฆด ์ ์์ต๋๋ค. ์ด๋ "Gate(๊ฒ์ดํธ)" ๋ผ๋ ๊ตฌ์กฐ๋ก ์กฐ์ ๋ฉ๋๋ค.
- LSTM ๋ชจ๋ธ์ 4์ข ๋ฅ์ Gate๋ฅผ ๊ฐ์ง๊ณ ์์ต๋๋ค. ๊ฐ "Gate"๋ ์ด๋ ํ ์ญํ ์ ํ๋์ง ํ๋ฒ ์ค๋ช ํด ๋ณด๊ฒ ์ต๋๋ค.
- ๋ง๊ฐ ๊ฒ์ดํธ (Forget Gate)
- ๊ธฐ์ต ๊ฒ์ดํธ (Remember Gate)
- ์ ๋ ฅ ๊ฒ์ดํธ (Input Gate)
- ์ถ๋ ฅ ๊ฒ์ดํธ (Output Gate)
Forget Gate (๋ง๊ฐ ๊ฒ์ดํธ)
Forget Gate (๋ง๊ฐ ๊ฒ์ดํธ)๋ ์ฅ๊ธฐ ๊ธฐ์ต์ค Cell State (์ ์ํ) ์์ ์ด๋ ํ ์ ๋ณด๋ฅผ ์ ๊ฑฐํ ์ง๋ฅผ, ์ด๋ ํ ์ ๋ณด๋ฅผ ๊ธฐ์ตํ ์ง๋ฅผ ์ ํํฉ๋๋ค.
- ๊ณผ๊ฑฐ์ ์ ๋ณด์ค ํ์ํ์ง ์์ ๋ถ๋ถ์ Cell State (์ ์ํ)์์ ์ ๊ฑฐํ์ฌ, ์ ๊ฒฝ๋ง์ด ํ์ํ ์ ๋ณด๋ง์ ์ ์ง ํ ์ ์๋๋ก ๋๋ ์ญํ ์ ํฉ๋๋ค.
Forget Gate(๋ง๊ฐ ๊ฒ์ดํธ)์ ์๋๋ฐฉ์์ ์๋์ ๊ฐ์ต๋๋ค.
- ํ์ฌ์ input๊ฐ๊ณผ ์ด์ ์ hidden state(์จ๊ฒจ์ง ์ํ)๊ฐ Forget Gate(๋ง๊ฐ ๊ฒ์ดํธ)๋ก ์ ๋ฌ๋ฉ๋๋ค.
- ํ์ฌ์ input๊ฐ, hidden state(์จ๊ฒจ์ง ์ํ)๋ Weight(๊ฐ์ค์น)์ Bias(ํธํฅ)์ผ๋ก ์กฐ์ ๋ํ, Sigmoid ํจ์๋ฅผ ํต๊ณผํฉ๋๋ค. ์ด Sigmoid ํจ์๋ 0~1 ์ฌ์ด์ ๊ฐ์ ์ถ๋ ฅํ๋ฉฐ, ์ถ๋ ฅ๊ฐ์ด ์ด๋ค ์ ๋ณด๋ฅผ 'Forget' ์ฆ, '๋ง๊ฐ' ํ ์ง๋ฅผ ๊ฒฐ์ ํฉ๋๋ค.
- Sigmoid ํจ์์ ์ถ๋ ฅ๊ฐ์ด 1์ ๊ฐ๊น์ฐ๋ฉด, 'Forget' ํ์ง ์์ต๋๋ค. ์ฆ, '๋ง๊ฐ'ํ์ง ์๋๋ค๋ ์๋ฏธ์ ๋๋ค. ๋ง์ฝ 0์ ๊ฐ๊น์ฐ๋ฉด ๊ทธ ์ ๋ณด๋ฅผ ๋ง๊ฐํ๋ค๋ ์๋ฏธ์ ๋๋ค.
- ์ด ์ถ๋ ฅ ๊ฐ์ด "Cell State(์ ์ํ)"์ ์์๋ณ๋ก ๊ณฑํด์ ธ์ "Cell State(์ ์ํ)"์์ ํน์ ์ ๋ณด๋ฅผ Forget(๋ง๊ฐ)ํ๊ฒ ๋ฉ๋๋ค.
Remember Gate (๊ธฐ์ต ๊ฒ์ดํธ)
Remember Gate (๊ธฐ์ต ๊ฒ์ดํธ)๋ ์ ํ๋ ์๋ก์ด ๊ธฐ์ต์ผ๋ก ์ฅ๊ธฐ ๊ธฐ์ต์ ๊ฐฑ์ ํฉ๋๋ค. ๋ค๋ฅธ ๊ฒ์ดํธ์ ๋ฌ๋ฆฌ ๋ณ๋์ Gate ์ฐ์ฐ ์์ด ๋ํ๊ธฐ๋ก๋ง ๊ตฌ๋ณ๋ฉ๋๋ค.
- ์๋ก์ด ์ ๋ณด๋ฅผ Cell State(์ ์ํ)์ ์ถ๊ฐํ๋ ๋ฐฉ์์ผ๋ก ๊ฐฑ์ ํฉ๋๋ค. ์๋ก์ด input์ ๊ธฐ์ตํ๊ณ ์ด์ ์ ์ํ๋ฅผ ์ ๋ฐ์ดํธ ํ๋ ๋ฐฉ๋ฒ์ ๊ฒฐ์ ํ๋๋ฐ ๋์์ ์ฃผ๋ ์ญํ ์ ํฉ๋๋ค.
Remember Gate(๊ธฐ์ต ๊ฒ์ดํธ)์ ์๋๋ฐฉ์์ ์๋์ ๊ฐ์ต๋๋ค.
- ํ์ฌ input๊ฐ๊ณผ ์ด์ hidden state(์จ๊ฒจ์ง ์ํ)๊ฐ Remember Gate(๊ธฐ์ต ๊ฒ์ดํธ)๋ก ์ ๋ฌ๋ฉ๋๋ค.
- ์ด๋ค์ ๊ฐ๊ฐ Weight(๊ฐ์ค์น)์ Bias(ํธํฅ)์ผ๋ก ์กฐ์ ๋ ํ ํ๋๋ Sigmoid ํจ์๋ฅผ ํต๊ณผํ๊ณ , ๋ค๋ฅธ ํ๋๋ tan(ํ์ ํธ) ํจ์๋ฅผ ํต๊ณผํฉ๋๋ค.
- Sigmoid ํจ์๋ 0 ~ 1 ์ฌ์ด์ ๊ฐ์ ์ถ๋ ฅํ์ฌ ์๋ก์ด ์ ๋ณด์ ์ค์์ฑ์ ๊ฒฐ์ ํ๊ณ , tan(ํ์ ํธ) ํจ์๋ -1์์ 1 ์ฌ์ด์ ๊ฐ์ ์ถ๋ ฅํ์ฌ ์๋ก์ด ๊ฐ๋ค์ ์์ฑํฉ๋๋ค.
- Sigmoid ํจ์์ ๊ฐ, tan(ํ์ ํธ)ํจ์์ ๊ฐ ๋ค์ด ๊ณฑํด์ ธ์ ์ ์ํ์ ์ถ๊ฐ๋ ์๋ก์ด ๊ฐ์ ์์ฑํฉ๋๋ค.
- ์๋ก์ด ๊ฐ์ด ๋ง๊ฐ ๊ฒ์ดํธ์์ ์์ฑ๋ ์ ๋ฐ์ดํธ๋ Cell State(์ ์ํ)์ ๋ํด์ ธ์ ์ต์ข Cell State(์ ์ํ)๋ฅผ ํ์ฑํฉ๋๋ค.
Input Gate (์ ๋ ฅ ๊ฒ์ดํธ)
Input Gate(์ ๋ ฅ ๊ฒ์ดํธ)๋ Cell State(์ ์ํ)์ ์ด๋ค ์๋ก์ด ์ ๋ณด๋ฅผ ์ถ๊ฐํ ์ง๋ฅผ ๊ฒฐ์ ํ๋ ์ญํ ์ ํฉ๋๋ค.
- ๋จ๊ธฐ ๊ธฐ์ต๊ณผ ์๋ก์ด Event๊ฐ ํฉ์ณ์ง ์๋ก์ด ๊ธฐ์ต์์ Predict(์์ธก)์ ํ์ํ ๋ถ๋ถ์ ์ ํํฉ๋๋ค.
- ์ด๋ถ๋ถ์ ์ค๋ช ํด๋ณด๋ฉด, Cell State(์ ์ํ - ๊ธฐ์ต)์ hidden state(๋จ๊ธฐ ๊ธฐ์ต)์ ์ ์งํ๋ฉด์ ์๋ก์ด ์ ๋ ฅ(Event)๊ฐ ๋ค์ด์ฌ ๋ ๋ง๋ค ์ด๋ฅผ ์ ๋ฐ์ดํธ ํฉ๋๋ค.
Input Gate(์ ๋ ฅ ๊ฒ์ดํธ)์ ์๋๋ฐฉ์์ ์๋์ ๊ฐ์ต๋๋ค.
- ํ์ฌ ์ ๋ ฅ๊ณผ ์ด์ hidden state(์จ๊ฒจ์ง ์ํ)๊ฐ Input Gate(์ ๋ ฅ ๊ฒ์ดํธ) ๋ก ์ ๋ฌ๋ฉ๋๋ค.
- ์ด๋ค์ ๊ฐ๊ฐ Weight(๊ฐ์ค์น)์ Bias(ํธํฅ)์ผ๋ก ์กฐ์ ๋ ํ, ํ๋๋ Sigmoid ํจ์๋ฅผ ํต๊ณผํ๊ณ , ๋ค๋ฅธ ํ๋๋ tan(ํ์ ํธ)ํจ์๋ฅผ ํต๊ณผํฉ๋๋ค.
- Sigmoid ํจ์๋ 0 ~ 1 ์ฌ์ด์ ๊ฐ์ ์ถ๋ ฅํ์ฌ ์ด๋ค ์ ๋ณด๋ฅผ Cell State(์ ์ํ)์ ์ถ๊ฐํ ์ง ๊ฒฐ์ ํ๊ณ , tan(ํ์ ํธ)ํจ์๋ -1์์ 1 ์ฌ์ด์ ๊ฐ์ ์ถ๋ ฅํ์ฌ ์๋ก์ด ํ๋ณด ๊ฐ๋ค์ ์์ฑํฉ๋๋ค.
- Sigmoid ํจ์์ ๊ฐ, tan(ํ์ ํธ)ํจ์์ ๊ฐ ๋ค์ด ๊ณฑํด์ ธ์ ์ ์ํ์ ์ถ๊ฐ๋ ์๋ก์ด ๊ฐ์ ์์ฑํฉ๋๋ค.
- ์ด ์๋ก ์์ฑ๋ ๊ฐ์ Forget Gate(๋ง๊ฐ ๊ฒ์ดํธ)์์ ์ ๋ฐ์ดํธ๋ Cell State(์ ์ํ) ์ ๋ํด์ ธ์ ์ต์ข ์ ์ธ Cell State(์ ์ํ) ๋ฅผ ํ์ฑํ๊ฒ ๋ฉ๋๋ค.
Output Gate (์ถ๋ ฅ ๊ฒ์ดํธ)
Output Gate (์ถ๋ ฅ ๊ฒ์ดํธ)๋ Cell State(์ ์ํ)์์ ์ด๋ค ์ ๋ณด๋ฅผ ์ต์ข ์ถ๋ ฅ์ผ๋ก ๋ณด๋ผ์ง ๊ฒฐ์ ํ๋ ์ญํ ์ ํฉ๋๋ค.
LSTM Model์ด ํ์ํ ์ ๋ณด๋ง์ ์ ํํ์ฌ ์ต์ข ์ถ๋ ฅ์ผ๋ก ์ ๋ฌํ๊ฒ ๋์ต๋๋ค.
- ์ฌ๊ฑด, ๋จ๊ธฐ, ์ฅ๊ธฐ ๊ธฐ์ต์ด ์ฐํฉ๋์ด ์๋ ๊ฐฑ์ ๋ ์ฅ๊ธฐ ๊ธฐ์ต์์ ์์ธก์ ํ์ํ ๋ถ๋ถ์ ์ ํํฉ๋๋ค.
Output Gate(์ถ๋ ฅ ๊ฒ์ดํธ)์ ์๋๋ฐฉ์์ ์๋์ ๊ฐ์ต๋๋ค.
- ํ์ฌ ์ ๋ ฅ๊ณผ ์ด์ hidden state(์จ๊ฒจ์ง ์ํ)๊ฐ Output Gate(์ถ๋ ฅ ๊ฒ์ดํธ)๋ก ์ ๋ฌ๋ฉ๋๋ค.
- ์ด๋ค์ ๊ฐ๊ฐ Weight(๊ฐ์ค์น)์ Bias(ํธํฅ)์ผ๋ก ์กฐ์ ๋ ํ, ํ๋๋ Sigmoid ํจ์๋ฅผ ํต๊ณผํฉ๋๋ค.
- Sigmoid ํจ์๋ 0 ~ 1 ์ฌ์ด์ ๊ฐ์ ์ถ๋ ฅํ์ฌ ์ด๋ค Cell State(์ ์ํ)์ ๋ถ๋ถ์ ์ถ๋ ฅํ ์ง ๊ฒฐ์ ํฉ๋๋ค.
- ๋์์, ํ์ฌ Cell State(์ ์ํ)๋ tan(ํ์ ํธ)ํจ์๋ฅผ ํต๊ณผํ์ฌ -1์์ 1 ์ฌ์ด์ ๊ฐ์ ์ถ๋ ฅํฉ๋๋ค.
- Sigmoid ํจ์์ ๊ฐ, tan(ํ์ ํธ)ํจ์์ ๊ฐ๋ค์ด ๊ณฑํด์ ธ์ ์ต์ข ์ถ๋ ฅ์ ์์ฑํฉ๋๋ค.
- Output Gate(์ถ๋ ฅ ๊ฒ์ดํธ)๋ Cell State(์ ์ํ)์ ์ ๋ณด ์ค ์ด๋ค ๋ถ๋ถ์ด ๋ค์ hidden state(์จ๊ฒจ์ง ์ํ)๋ก ์ ๋ฌ๋ ์ง, ๋๋ ์ ๊ฒฝ๋ง์ ์ต์ข Output(์ถ๋ ฅ)์ผ๋ก ์ฌ์ฉ๋ ์ง๋ฅผ ๊ฒฐ์ ํฉ๋๋ค.
3. LSTM Model์ ๋ค์ด๊ฐ๋ณด๊ธฐ
์ฌ๊ฑด์ ๋จ๊ธฐ ๊ธฐ์ต, ์ฅ๊ธฐ ๊ธฐ์ต์ด ์ด๋์ ๋ ์์ธก์ ๊ด์ฌ ํ๋์ง๋ LSTM์ Gate ๊ตฌ์กฐ๋ก ์กฐ์ ๋ฉ๋๋ค.
- ๊ทธ๋ฆฌ๊ณ LSTM ๋ชจ๋ธ์ ์ฅ๊ธฐ๊ธฐ์ต์ ์ค๋ ๊ธฐ์ตํ ์ ์๊ณ , ์ด๋ ๋ถ๋ถ์ ๊ธฐ์ตํ ์ง๋ฅผ ์ ํ ํ ์ ์์ต๋๋ค.
๊ทธ๋ฌ๋ฉด ์ด์ LSTM์ ๋ชจ๋ธ์ ์์ธํ๊ฒ ํ๋ฒ ๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.
- ๊ทธ๋ฆผ์ ๋ฐํ์ฌ ์ค๋ช ์ ํด๋ณด๋ฉด Gate ์ข ๋ฅ ft(๋ง๊ฐ ๊ฒ์ดํธ), it(์ ๋ ฅ ๊ฒ์ดํธ), ot(์ถ๋ ฅ ๊ฒ์ดํธ)๋ ์์์ ์ญํ ๊ณผ ์๋ ๋ฐฉ์์ ๋ฐํ์ฌ ์ค๋ช ํ์ผ๋ ํจ์ค ํ๊ฒ ์ต๋๋ค. ์์ ์ ์ ๋ด์ฉ์ ์ฐธ๊ณ ํด์ฃผ๋ฉด์ ํ๋ฒ ๋ด์ฃผ์ธ์!
์์ธํ ๊ตฌ์กฐ ์ค๋ช
- ์ฌ๊ธฐ์ "Ct"๋ Cell State(์
์ํ)๋ฅผ ๋ํ๋ด๋ฉฐ *Internal state๋ก ์ฅ๊ธฐ ๊ธฐ์ต์ด ๋๋ค๊ณ ๋์ด์์ต๋๋ค.
- *Internal state๋ Model์ด ์๊ฐ์ ๋ฐ๋ฅธ input์ Sequence๋ฅผ ์ฒ๋ฆฌํ๋ฉด์ ์ ์งํ๋ "State(์ํ)"๋ฅผ ์๋ฏธํฉ๋๋ค.
- RNN ์์๋ Internal stater๊ฐ ๊ฐ time-step์์์ hidden state(์จ๊ฒจ์ง ์ํ)๋ก ํํ ๋๋ฉฐ, ์ด๋ ์ด์ time-step์ hidden state(์จ๊ฒจ์ง ์ํ)์ ํ์ฌ์ time-step์ ์ ๋ ฅ์ ๊ธฐ๋ฐํฉ๋๋ค.
- ๊ทธ๋ฆฌ๊ณ Cell State(์
์ํ)๋ Sigmoid ํจ์(์์ Network)์ ๊ธฐ๋ฐํ๋ฉฐ, Cell State(์
์ํ) ๊ฐ์ *Linear Interaction ์ํ๋ก ๊ตฌ์ฑํ๋ฉฐ *Gradient Flow ์ง๋ฆ๊ธธ์ ์์ฑํ ๊ฒ์ด ํต์ฌ idea์
๋๋ค.
- *Linear Interaction: input & output ์ฌ์ด์ ์ํธ์์ฉ์ด ์ ํ์ ์ด ๊ด๊ณ๋ฅผ ๊ฐ์ง๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค. ์ฆ, input์ ํฌ๊ธฐ๊ฐ ๋ณํ๋ฉด output๋ ๊ทธ์ ๋น๋กํ์ฌ ๋ณํฉ๋๋ค.
- Gradient Flow: ์ ๊ฒฝ๋ง ํ์ต ๊ณผ์ ์์ ์ค์ฐจ๋ฅผ Backpropagation(์ญ์ ํ)ํ์ฌ Weight(๊ฐ์ค์น)๋ฅผ ์ ๋ฐ์ดํธ ํ๋ ๋งค์ปค๋์ฆ ์ ๋๋ค.
- Cell State(์ ์ํ)๋ Forget Gate(๋ง๊ฐ ๊ฒ์ดํธ)์ ์ ์ Cell State(์ ์ํ)๋ฅผ ๊ณฑํ ๊ฐ๊ณผ, Input(์ ๋ ฅ ๊ฒ์ดํธ)์ Cell์์ ํ์ฑ๋ ์๋ก์ด ๊ธฐ์ต์ ๊ณฑํ ๊ฐ์ ๋ํด์ Cell State(์ ์ํ)๋ฅผ ์ ๋ฐ์ดํธ ํฉ๋๋ค.
- ์ด๋ Cell State(์
์ํ)์ ๊ฐ์ด 1์ฉ ์ฆ๊ฐ or ๊ฐ์ ํ๋๋ฐ, Element(์์)๋ณ๋ก Integer Counter ํฉ๋๋ค.
- ๋ฌด์จ ๋ง์ด๋๋ฉด, LSTM Model์ Sequence์ ๊ฐ ์์๋ฅผ ์ฒ๋ฆฌํ๋ฉด์ ์ ๋ณด๋ฅผ ์๊ฑฐ๋ ์ ๊ฑฐํ๋ ๋์์ ์ํํ๋ค๋ ์๋ฏธ์ ๋๋ค.
- ๊ทธ๋ฆฌ๊ณ "ht", Hidden state(์จ๊ฒจ์ง ์ํ)๋ Cell์ Output๊ฐ์ผ๋ก ๋จ๊ธฐ ๊ธฐ์ต์ ์๋ฏธํฉ๋๋ค.
- Hidden state(์จ๊ฒจ์ง ์ํ), "ht"์ ๊ฐ์ Output Gate(์ถ๋ ฅ ๊ฒ์ดํธ)๋ "Ct", Cell State(์ ์ํ)๊ฐ์ ํ์ดํผ๋ณผ๋ฆญํ์ ํธ(tanh)ํจ์์ ๋ฃ์ ๊ฐ๊ณผ ๊ณฑํด์ ์ถ๋ ฅ๊ฐ์ ์ป์ต๋๋ค.
- ์ด๋ Counter ๊ฐ์ -1 ~ 1 ์ฌ์ด์ ๋ฒ์๋ก squashing ํฉ๋๋ค. "squashing" ํ๋ค๋ ๋ง์ Cell State(์ ์ํ)์ ๊ฐ์ -1 ~ 1 ์ฌ์ด์ ๋ฒ์๋ก ๋ณํํ๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค.
๊ทธ๋ฆฌ๊ณ LSTM Model์ Gradient Flow์ Weight(๊ฐ์ค์น)๊ฐ ๊ณฑํด์ง์ง ์๋๋ก ๊ตฌ์กฐ๋ฅผ ๋ณ๊ฒฝํ์ต๋๋ค.
- ์ค๋ช ์ ํด๋ณด์๋ฉด, "Ct" -> ํ์ฌ์ ์ ์ํ ์์ "Ct-1" -> ์ด์ ์ ์ ์ํ ์ฌ์ด Gradient(๊ธฐ์ธ๊ธฐ) ์ฐ์ฐ์์ W(๊ฐ์ค์น)์ ์์ ํ ์ฌ๋ผ์ง๋๋ค.
- ์ด์ ๋ "ft" -> Forget Gate(๋ง๊ฐ ๊ฒ์ดํธ)์์ Element(์์) ๊ณฑ์ด ์๊ธฐ ๋๋ฌธ์ Local Gradient๋ Forget Gate(๋ง๊ฐ ๊ฒ์ดํธ)๊ฐ ๋ฉ๋๋ค.
- ์ด๊ฑฐ์ ๋ฐํ์ฌ ์ค๋ช ์ ํด๋ณด๋ฉด, Forget Gate(๋ง๊ฐ ๊ฒ์ดํธ)์ ์ถ๋ ฅ๊ฐ์ธ 0~1 ์ฌ์ด์ ๊ฐ (Cell state์ ๊ฐ ์์์ ๊ณฑํด์ ธ์ ์ ๋ฐ์ดํธ) ์ด Gradient(๊ธฐ์ธ๊ธฐ)์ ํ๋ฆ์ ๊ฒฐ์ ํ๊ธฐ ๋๋ฌธ์, Local Gradient๋ Forget Gate(๋ง๊ฐ ๊ฒ์ดํธ)์ ๊ฐ์ด ๋ฉ๋๋ค.
๊ทธ๋ฌ๋ฉด ์ด๋ฒ์๋ LSTM์์ Gradient Flow๊ฐ ์ข์ ์ด์ ๋ฅผ ์ค๋ช ํด ๋ณด๊ฒ ์ต๋๋ค.
- "ft" -> Forget Gate(๋ง๊ฐ ๊ฒ์ดํธ) ๊ฐ์ ๋งค๋ฒ ๋ด๋๊ธฐ ๋๋ฌธ์ ๊ฐ์ ๊ฐ์ด ๋ฐ๋ณตํด์ ๊ณฑํด์ง์ง ์์ต๋๋ค.
- Forget Gate(๋ง๊ฐ ๊ฒ์ดํธ) ๊ฐ์ด (0 ~ 1) ์ฌ์ด ๋ฒ์์์ ์กด์ฌํ๊ธฐ ๋๋ฌธ์ Gradient Exploding(๊ธฐ์ธ๊ธฐ ํญํ)์ ์ผ์ด๋์ง ์์ต๋๋ค.
- Final hidden state(์ต์ข ์จ๊ฒจ์ง ์ํ)์์ Fist Cell State(์ฒซ๋ฒ์ฌ ์ ์ํ)๊น์ง backward path(์ญ์ ํ ๋จ๊ณ)์๋ tanh(ํ์ดํผ๋ณผ๋ฆญํ์ ํธ)ํจ์๋ฅผ ํ๋ฒ๋ง ๋ํ๋ ๋๋ค. (์ฆ, ๋ฐ๋ณต์ ์ธ tanh(ํ์ดํผ๋ณผ๋ฆญํ์ ํธ)ํจ์์ ๊ณฑ์ ์ด ์ฌ๋ผ์ง๋๋ค)
+ Forget Gate(๋ง๊ฐ ๊ฒ์ดํธ)๊ฐ์ด 1๋ณด๋ค ์๊ธฐ ๋๋ฌธ์ Gradient Vanishing (๊ธฐ์ธ๊ธฐ ์์ค)์ด ์ผ์ด๋ ์ ์์ง๋ง, ์ด๋ฅผ ๋ฐฉ์ง ํ๊ธฐ ์ํด์ Forget Gate(๋ง๊ฐ ๊ฒ์ดํธ)์ bias(ํธํฅ)์ 1๋ก ์ด๊ธฐํ ํฉ๋๋ค.
4. LSTM Model ์ฝ๋ ์์
ํ๋ฒ LSTM Model์ ์ฝ๋ ์์๋ฅผ ํ๋ฒ ๋ณด๊ฒ ์ต๋๋ค.
from keras.models import Sequential
from keras.layers import LSTM, Dense
# model ์ด๊ธฐํ
model = Sequential()
# LSTM layer add
# ์
๋ ฅ ์ฐจ์์ ํน์ฑ์ ์์ ๋ฐ๋ผ ๋ณ๊ฒฝํด์ผ ํฉ๋๋ค. 100 - hidden unit ๊ฐ์
model.add(LSTM(100, input_shape=(timesteps, input_dim)))
# fully connected layer add, ์ถ๋ ฅ Node ๊ฐ์๊ฐ 1๊ฐ์ด๋ฏ๋ก 1์ด๋ผ๊ณ ํจ
model.add(Dense(1, activation='sigmoid'))
# model compile
# ์ด์ง ๋ถ๋ฅ ๋ฌธ์ ์ด๋ฏ๋ก ์์ค ํจ์๋ก 'binary_crossentropy'๋ฅผ ์ฌ์ฉํฉ๋๋ค.
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
- LSTM layer๋ฅผ ์ ์ธํด์ค๋, "model.add(LSTM(100, input_shape=(timesteps, input_dim)))" ์ด๋ฐ ํ์์ผ๋ก ์ ์ธ์ ํด์ค์ผ ํฉ๋๋ค.
- ์ฌ๊ธฐ์๋ 100์ hidden Unit์ ๊ฐ์์ ๋๋ค.
- ์ฌ๊ธฐ์ "timesteps"์ Sequence์ Length(๊ธธ์ด), input_dim์ ๊ฐ Sequence ์์์ ํน์ฑ ๊ฐ์๋ฅผ ๋ํ๋ ๋๋ค.
- ๊ทธ๋ฆฌ๊ณ ํ๋์ ์ถ๋ ฅ ๋
ธ๋๋ฅผ ๊ฐ์ง Dense Layer(์์ ์ฐ๊ฒฐ ๋ ์ด์ด)๋ฅผ ์ฌ์ฉํฉ๋๋ค.
- "model.add(Dense(1, activation='sigmoid'))" -> ์ถ๋ ฅ Node์ ๊ฐ์๊ฐ ํ๋์ด๋ฏ๋ก, activation ์ผ์ชฝ ์์ ์ถ๋ ฅ๋ ธ๋์ ๊ฐ์๋ฅผ ์ ๋ ํ๋ผ๋ฏธํฐ ๊ณต๊ฐ์ 1์ ์ ์ด์ฃผ์์ต๋๋ค.
- ๋ํ ์ ๊ฐ ๊ฐ๋จํ ๋ง๋ ๋ชจ๋ธ์ Binary(์ด์ง)๋ถ๋ฅ ๋ฌธ์ ์ ์ ํฉํ๊ฒ ๋ง๋ค์์ต๋๋ค.
- ๊ทธ๋์ ์ถ๋ ฅ Dense Layer์ 0~1์ฌ์ด์ ํ๋ฅ ๊ฐ์ ์ถ๋ ฅํด์ฃผ๋ "Sigmoid ํจ์", Compile ์ฝ๋์์๋ Binary(์ด์ง)๋ถ๋ฅ๋ฅผ ์ํด์ "binary_crossentropy"๋ฅผ ์ฌ์ฉํ์ต๋๋ค.
- ๋ง์ฝ์ ๋ค์ค Class๋ฅผ ๋ถ๋ฅํ๋ ๋ฌธ์ ๋ผ๋ฉด ํ์ฑํ ํจ์ -> "activation"์ "softmax"ํจ์๋ฅผ ์ฌ์ฉํด์ผ ํ๊ณ
- ๋ชจ๋ธ Compile ์ฝ๋์์ ์์ค ํจ์ -> "loss" ๋ฅผ 'categorical_crossentropy'๋ก ์ค์ ํด์ฃผ๋ฉด ๋ฉ๋๋ค.
- ๊ณต์๋ฌธ์์ ์ด๋ ํ parameter๊ฐ ๋ค์ด๊ฐ์ ์๋์ง ๋์์์ผ๋๊น ์ฐธ๊ณ ํด์ฃผ์ธ์.
- ๋ค์์ GRU Model์ ๋ฐํ์ฌ ์ค๋ช ํ๋ ๊ธ๋ก ๋์์ค๊ฒ ์ต๋๋ค.
+ Softmax ํจ์: Vector๋ฅผ input์ผ๋ก ๋ฐ์์ ๊ฐ ์์์ ๊ฐ์ 0~1 ์ฌ์ด์ ๊ฐ์ผ๋ก ๋ฐํํ์ฌ, ์ด ๊ฐ๋ค์ ํฉ์ด 1์ด ๋๋๋ก ๋ง๋๋ ํจ์์ ๋๋ค. ์ฃผ๋ก ํ๋ฅ ๋ถํฌ๋ฅผ ๋ํ๋ด๋๋ฐ ์ฌ์ฉ๋ฉ๋๋ค.
๋ฐ์ํ
'๐ NLP (์์ฐ์ด์ฒ๋ฆฌ) > ๐ Natural Language Processing' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[NLP] Word2Vec, CBOW, Skip-Gram - ๊ฐ๋ & Model (0) | 2024.02.03 |
---|---|
[NLP] GRU Model - LSTM Model์ ๊ฐ๋ณ๊ฒ ๋ง๋ ๋ชจ๋ธ (0) | 2024.01.30 |
[NLP] Vanilla RNN Model, Long-Term Dependency - ์ฅ๊ธฐ ์์กด์ฑ ๋ฌธ์ (0) | 2024.01.23 |
[NLP] RNN (Recurrent Netural Network) - ์ํ์ ๊ฒฝ๋ง (0) | 2024.01.22 |
[NLP] Seq2Seq, Encoder & Decoder (0) | 2024.01.19 |