[PyTorch] Automatic Differentiation with torch.autograd
โš ๏ธ ๋ณธ ๋‚ด์šฉ์€ PyTorch Korea์˜ ๊ณต์‹ ๋ฌธ์„œ์— ๊ธฐ๋ฐ˜ํ•˜์—ฌ ๊ณต๋ถ€ํ•œ ๋‚ด์šฉ์„ ์ ์€๊ฒƒ์ด๋‹ˆ ์–‘ํ•ด๋ฐ”๋ž๋‹ˆ๋‹ค!
 

Reference: PyTorch Basics tutorial — torch.autograd (tutorials.pytorch.kr)

 


Automatic Differentiation with torch.autograd

์ผ๋ฐ˜์ ์œผ๋กœ Neural Network(์‹ ๊ฒฝ๋ง)์„ ํ•™์Šตํ•  ๋–„ ๊ฐ€์žฅ ์ž์ฃผ ์‚ฌ์šฉ๋˜๋Š” ์•Œ๊ณ ๋ฆฌ์ฆ˜์€ Backpropagation(์—ญ์ „ํŒŒ)์ž…๋‹ˆ๋‹ค.

  • ์ด ์•Œ๊ณ ๋ฆฌ์ฆ˜์—์„œ, ๋งค๊ฐœ๋ณ€์ˆ˜(๋ชจ๋ธ ๊ฐ€์ค‘์น˜)๋Š” ์ฃผ์–ด์ง„ ๋งค๊ฐœ๋ณ€์ˆ˜์— ๋Œ€ํ•œ ์†์‹ค ํ•จ์ˆ˜์˜ ๋ณ€ํ™”๋„(gradient)์— ๋”ฐ๋ผ ์กฐ์ •๋ฉ๋‹ˆ๋‹ค.
  • ์ด๋Ÿฌํ•œ ๋ณ€ํ™”๋„๋ฅผ ๊ณ„์‚ฐํ•˜๊ธฐ ์œ„ํ•ด PyTorch์—๋Š” 'torch.autograd' ๋ผ๊ณ  ๋ถˆ๋ฆฌ๋Š” ์ž๋™ ๋ฏธ๋ถ„ ์—”์ง„์ด ๋‚ด์žฅ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.
  • ์ด๋Š” ๋ชจ๋“  ๊ณ„์‚ฐ ๊ทธ๋ž˜ํ”„์— ๋Œ€ํ•œ ๋ณ€ํ™”๋„์˜ ์ž๋™ ๊ณ„์‚ฐ์„ ์ง€์›ํ•ฉ๋‹ˆ๋‹ค.
  • ์•„๋ž˜๋Š” ์ž…๋ ฅ ํ…์„œ x, ๋งค๊ฐœ๋ณ€์ˆ˜ w์™€ b, ๊ทธ๋ฆฌ๊ณ  ์†์‹ค ํ•จ์ˆ˜๊ฐ€ ์žˆ๋Š” ๊ฐ€์žฅ ๊ฐ„๋‹จํ•œ ๋‹จ์ผ ๊ณ„์ธต ์‹ ๊ฒฝ๋ง์„ ์ •์˜ํ•˜๊ณ ,  Backpropagation(์—ญ์ „ํŒŒ)๋ฅผ ์ˆ˜ํ–‰ํ•˜๋Š” ์˜ˆ์ œ์ž…๋‹ˆ๋‹ค.
import torch

# Define the input and expected-output tensors
x = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output

# Define the model weights and bias
w = torch.randn(5, 3, requires_grad=True)  # requires_grad=True: needed for automatic differentiation
b = torch.randn(3, requires_grad=True)

# Linear transformation and loss computation
z = torch.matmul(x, w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

print(f"Loss: {loss}")

# Run backpropagation
loss.backward()

# Print the gradients
print(f"w.grad: {w.grad}")
print(f"b.grad: {b.grad}")
Loss: 0.5469902753829956
w.grad: tensor([[0.0285, 0.2291, 0.1074],
        [0.0285, 0.2291, 0.1074],
        [0.0285, 0.2291, 0.1074],
        [0.0285, 0.2291, 0.1074],
        [0.0285, 0.2291, 0.1074]])
b.grad: tensor([0.0285, 0.2291, 0.1074])

 

  • x = torch.ones(5): the input tensor x has size 5 with all elements equal to 1.
  • y = torch.zeros(3): the expected-output tensor y has size 3 with all elements equal to 0.
  • w = torch.randn(5, 3, requires_grad=True): the weight tensor w is a random 5x3 tensor; requires_grad=True enables gradient computation for it.
  • b = torch.randn(3, requires_grad=True): the bias tensor b is a random tensor of size 3; requires_grad=True enables gradient computation for it.
  • z = torch.matmul(x, w) + b: performs the linear transformation to compute the output tensor z.
  • loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y): computes the loss; here the binary_cross_entropy_with_logits function is used.
  • Backpropagation: loss.backward() runs backpropagation and computes the gradients of the loss function.
  • Printing gradients: w.grad prints the gradient with respect to the weights w, and b.grad prints the gradient with respect to the bias b (a sketch of using these gradients for a manual update step follows below).

 


Tensor, Function, and the Computational Graph

This code defines the following computational graph:

  • ์ด ์‹ ๊ฒฝ๋ง์—์„œ, w์™€ b๋Š” ์ตœ์ ํ™”๋ฅผ ํ•ด์•ผ ํ•˜๋Š” ๋งค๊ฐœ๋ณ€์ˆ˜์ž…๋‹ˆ๋‹ค.
  • ๋”ฐ๋ผ์„œ ์ด๋Ÿฌํ•œ ๋ณ€์ˆ˜๋“ค์— ๋Œ€ํ•œ ์†์‹ค ํ•จ์ˆ˜์˜ ๋ณ€ํ™”๋„๋ฅผ ๊ณ„์‚ฐํ•  ์ˆ˜ ์žˆ์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
  • ์ด๋ฅผ ์œ„ํ•ด์„œ ํ•ด๋‹น ํ…์„œ์— requires_grad ์†์„ฑ์„ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค.
requires_grad ์˜ ๊ฐ’์€ ํ…์„œ๋ฅผ ์ƒ์„ฑํ•  ๋•Œ ์„ค์ •ํ•˜๊ฑฐ๋‚˜, ๋‚˜์ค‘์—  x.requires_grad_(True) ๋ฉ”์†Œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋‚˜์ค‘์— ์„ค์ •ํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค.
  • ์—ฐ์‚ฐ ๊ทธ๋ž˜ํ”„๋ฅผ ๊ตฌ์„ฑํ•˜๊ธฐ ์œ„ํ•ด ํ…์„œ์— ์ ์šฉํ•˜๋Š” ํ•จ์ˆ˜๋Š” ์‚ฌ์‹ค Function ํด๋ž˜์Šค์˜ ๊ฐ์ฒด์ž…๋‹ˆ๋‹ค.
  • ์ด ๊ฐ์ฒด๋Š” ์ˆœ์ „ํŒŒ ๋ฐฉํ–ฅ์œผ๋กœ ํ•จ์ˆ˜๋ฅผ ๊ณ„์‚ฐํ•˜๋Š” ๋ฐฉ๋ฒ•๊ณผ, ์—ญ๋ฐฉํ–ฅ ์ „ํŒŒ ๋‹จ๊ณ„์—์„œ ๋„ํ•จ์ˆ˜(derivative)๋ฅผ ๊ณ„์‚ฐํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ์•Œ๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.
  • ์—ญ๋ฐฉํ–ฅ ์ „ํŒŒ ํ•จ์ˆ˜์— ๋Œ€ํ•œ ์ฐธ์กฐ(reference)๋Š” ํ…์„œ์˜ grad_fn ์†์„ฑ์— ์ €์žฅ๋ฉ๋‹ˆ๋‹ค.
  • Function์— ๋Œ€ํ•œ ์ž์„ธํ•œ ์ •๋ณด๋Š” ์•„๋ž˜์— ๊ณต์‹๋ฌธ์„œ ์—์„œ ์ฐพ์•„๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
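To make the idea of a Function more concrete, here is a minimal sketch (my own, not from the tutorial) of a custom operation, a hypothetical Exp, built on torch.autograd.Function: forward computes the result, backward computes its derivative, and the resulting tensor's grad_fn points at the backward machinery.

import torch

# Hypothetical example: a custom exp operation implemented as a Function
class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        result = input.exp()
        ctx.save_for_backward(result)  # stash what backward will need
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * result    # d/dx exp(x) = exp(x)

x = torch.randn(3, requires_grad=True)
y = Exp.apply(x)
print(y.grad_fn)   # reference to the backward function is stored here
y.sum().backward()
print(x.grad)      # equals exp(x), as computed by Exp.backward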

 

  • Official torch.autograd documentation: Automatic differentiation package - torch.autograd — PyTorch 2.4 documentation (pytorch.org)

print(f"Gradient function for z = {z.grad_fn}")
print(f"Gradient function for loss = {loss.grad_fn}")
Gradient function for z = <AddBackward0 object at 0x781041f2d420>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x7810196afe20>

๋ณ€ํ™”๋„(Gradient) ๊ณ„์‚ฐํ•˜๊ธฐ

  • ์‹ ๊ฒฝ๋ง์—์„œ ๋งค๊ฐœ๋ณ€์ˆ˜์˜ ๊ฐ€์ค‘์น˜๋ฅผ ์ตœ์ ํ™”ํ•˜๋ ค๋ฉด ๋งค๊ฐœ๋ณ€์ˆ˜์— ๋Œ€ํ•œ ์†์‹คํ•จ์ˆ˜์˜ ๋„ํ•จ์ˆ˜(derivative)๋ฅผ ๊ณ„์‚ฐํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
  • ์ฆ‰, x์™€ y์˜ ์ผ๋ถ€ ๊ณ ์ •๊ฐ’์—์„œ ∂๐‘™๐‘œ๐‘ ๐‘ /∂๐‘ค ์™€ ∂๐‘™๐‘œ๐‘ ๐‘ /∂๐‘ ๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.
  • ์ด๋Ÿฌํ•œ ๋„ํ•จ์ˆ˜๋ฅผ ๊ณ„์‚ฐํ•˜๊ธฐ ์œ„ํ•ด, loss.backward()๋ฅผ ํ˜ธ์ถœํ•œ ๋‹ค์Œ w.grad ์™€ b.grad ์—์„œ ๊ฐ’์„ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
loss.backward()
print(w.grad)
print(b.grad)
tensor([[0.0556, 0.1949, 0.1666],
        [0.0556, 0.1949, 0.1666],
        [0.0556, 0.1949, 0.1666],
        [0.0556, 0.1949, 0.1666],
        [0.0556, 0.1949, 0.1666]])
tensor([0.0556, 0.1949, 0.1666])
 ์—ฐ์‚ฐ ๊ทธ๋ž˜ํ”„์˜ ์žŽ(leaf) ๋…ธ๋“œ๋“ค ์ค‘ `requires_grad` ์†์„ฑ์ด `True`๋กœ ์„ค์ •๋œ ๋…ธ๋“œ๋“ค์˜ `grad` ์†์„ฑ๋งŒ ๊ตฌํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
๊ทธ๋ž˜ํ”„์˜ ๋‹ค๋ฅธ ๋ชจ๋“  ๋…ธ๋“œ์—์„œ๋Š” ๋ณ€ํ™”๋„๊ฐ€ ์œ ํšจํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.
์„ฑ๋Šฅ ์ƒ์˜ ์ด์œ ๋กœ, ์ฃผ์–ด์ง„ ๊ทธ๋ž˜ํ”„์—์„œ์˜ `backward` ๋ฅผ ์‚ฌ์šฉํ•œ ๋ณ€ํ™”๋„ ๊ณ„์‚ฐ์€ ํ•œ ๋ฒˆ๋งŒ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
๋งŒ์•ฝ ๋™์ผํ•œ ๊ทธ๋ž˜ํ”„์—์„œ ์—ฌ๋Ÿฌ๋ฒˆ์˜ `backward` ํ˜ธ์ถœ์ด ํ•„์š”ํ•˜๋ฉด, `backward` ํ˜ธ์ถœ ์‹œ์— `retrain_graph=True`๋ฅผ ์ „๋‹ฌํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
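A minimal sketch of this behavior, using a small throwaway graph (variable names here are arbitrary):

import torch

t = torch.ones(5)
u = torch.randn(5, 3, requires_grad=True)
s = torch.matmul(t, u).sum()

s.backward(retain_graph=True)  # keep the graph so it can be reused
s.backward()                   # this second call is fine; without retain_graph=True
                               # above, it would raise a RuntimeError
print(u.grad)                  # gradients from the two calls are accumulated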

๋ณ€ํ™”๋„(Gradient) ์ถ”์  ๋ฉˆ์ถ”๊ธฐ

  • ๊ธฐ๋ณธ์ ์œผ๋กœ, requires_grad=True์ธ ๋ชจ๋“  ํ…์„œ๋“ค์€ ์—ฐ์‚ฐ ๊ธฐ๋ก์„ ์ถ”์ ํ•˜๊ณ  ๋ณ€ํ™”๋„ ๊ณ„์‚ฐ์„ ์ง€์›ํ•ฉ๋‹ˆ๋‹ค.
  • ๊ทธ๋Ÿฌ๋‚˜ ๋ชจ๋ธ์„ ํ•™์Šตํ•œ ๋’ค ์ž…๋ ฅ ๋ฐ์ดํ„ฐ๋ฅผ ๋‹จ์ˆœํžˆ ์ ์šฉํ•˜๊ธฐ๋งŒ ํ•˜๋Š” ๊ฒฝ์šฐ์™€ ๊ฐ™์ด ์ˆœ์ „ํŒŒ ์—ฐ์‚ฐ๋งŒ ํ•„์š”ํ•œ ๊ฒฝ์šฐ์—๋Š”, ์ด๋Ÿฌํ•œ ์ถ”์ ์ด๋‚˜ ์ง€์›์ด ํ•„์š” ์—†์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • ์—ฐ์‚ฐ ์ฝ”๋“œ๋ฅผ torch.no_grad() ๋ธ”๋ก์œผ๋กœ ๋‘˜๋Ÿฌ์‹ธ์„œ ์—ฐ์‚ฐ ์ถ”์ ์„ ๋ฉˆ์ถœ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
z = torch.matmul(x, w)+b
print(z.requires_grad)

with torch.no_grad():
    z = torch.matmul(x, w)+b
print(z.requires_grad)
True
False
  • ๋™์ผํ•œ ๊ฒฐ๊ณผ๋ฅผ ์–ป๋Š” ๋‹ค๋ฅธ ๋ฐฉ๋ฒ•์€ ํ…์„œ์— detach() ๋ฉ”์†Œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค:
z = torch.matmul(x, w)+b
z_det = z.detach()
print(z_det.requires_grad)
False

 

๋˜ํ•œ ๋ณ€ํ™”๋„ ์ถ”์ ์„ ๋ฉˆ์ถฐ์•ผ ํ•˜๋Š” ์ด์œ ๋“ค์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค:

  • ์‹ ๊ฒฝ๋ง์˜ ์ผ๋ถ€ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ๊ณ ์ •๋œ ๋งค๊ฐœ๋ณ€์ˆ˜(frozen parameter) ๋กœ ํ‘œ์‹œํ•ฉ๋‹ˆ๋‹ค.
  • ๋ณ€ํ™”๋„๋ฅผ ์ถ”์ ํ•˜์ง€ ์•Š๋Š” ํ…์„œ์˜ ์—ฐ์‚ฐ์ด ๋” ํšจ์œจ์ ์ด๊ธฐ ๋•Œ๋ฌธ์—, ์ˆœ์ „ํŒŒ ๋‹จ๊ณ„๋งŒ ์ˆ˜ํ–‰ํ•  ๋•Œ ์—ฐ์‚ฐ ์†๋„๊ฐ€ ํ–ฅ์ƒ๋ฉ๋‹ˆ๋‹ค.

์—ฐ์‚ฐ ๊ทธ๋ž˜ํ”„์— ๋Œ€ํ•œ ์ถ”๊ฐ€ ์ •๋ณด

  • ๊ฐœ๋…์ ์œผ๋กœ, autograd๋Š” ๋ฐ์ดํ„ฐ(ํ…์„œ)์˜ ๋ฐ ์‹คํ–‰๋œ ๋ชจ๋“  ์—ฐ์‚ฐ๋“ค(๋ฐ ์—ฐ์‚ฐ ๊ฒฐ๊ณผ๊ฐ€ ์ƒˆ๋กœ์šด ํ…์„œ์ธ ๊ฒฝ์šฐ๋„ ํฌํ•จํ•˜์—ฌ)์˜ ๊ธฐ๋ก์„ Function ๊ฐ์ฒด๋กœ ๊ตฌ์„ฑ๋œ ๋ฐฉํ–ฅ์„ฑ ๋น„์ˆœํ™˜ ๊ทธ๋ž˜ํ”„(DAG; Directed Acyclic Graph)์— ์ €์žฅ(keep)ํ•ฉ๋‹ˆ๋‹ค.
  • torch.autograd Function ๊ณต์‹๋ฌธ์„œ
 

Automatic differentiation package - torch.autograd — PyTorch 2.4 documentation

Automatic differentiation package - torch.autograd torch.autograd provides classes and functions implementing automatic differentiation of arbitrary scalar valued functions. It requires minimal changes to the existing code - you only need to declare Tensor

pytorch.org

 

  • In this directed acyclic graph (DAG), the leaves are the input tensors and the roots are the output tensors.
  • By tracing this graph from roots to leaves, gradients can be automatically computed using the chain rule.
  • In the forward pass, autograd does two things simultaneously:
  • it runs the requested operation to compute a resulting tensor, and
  • it maintains the operation's gradient function in the DAG.

The backward pass kicks off when .backward() is called on the root of the DAG. autograd then:

  • computes the gradients from each .grad_fn,
  • accumulates them in the .grad property of each tensor, and
  • propagates all the way to the leaf tensors using the chain rule (a small sketch that walks this DAG follows below).
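To make this DAG a little more tangible, here is a small sketch (my own addition, not from the tutorial) that rebuilds the earlier example and walks the graph from the root's grad_fn back towards the leaves through each node's next_functions links:

import torch

x = torch.ones(5)
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
loss = (torch.matmul(x, w) + b).sum()

# Recursively print each Function node, starting from the root's grad_fn
def walk(fn, depth=0):
    if fn is None:          # inputs that do not require grad have no node
        return
    print("  " * depth + type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        walk(next_fn, depth + 1)

walk(loss.grad_fn)  # e.g. SumBackward0 -> AddBackward0 -> ... -> AccumulateGrad
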
PyTorch์—์„œ DAG๋“ค์€ ๋™์ (dynamic)์ž…๋‹ˆ๋‹ค.
์ฃผ๋ชฉํ•ด์•ผ ํ•  ์ค‘์š”ํ•œ ์ ์€ ๊ทธ๋ž˜ํ”„๊ฐ€ ์ฒ˜์Œ๋ถ€ํ„ฐ(from scratch) ๋‹ค์‹œ ์ƒ์„ฑ๋œ๋‹ค๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.
๋งค๋ฒˆ `bachward()` ๊ฐ€ ํ˜ธ์ถœ๋˜๊ณ  ๋‚˜๋ฉด, autograd๋Š” ์ƒˆ๋กœ์šด ๊ทธ๋ž˜ํ”„๋ฅผ ์ฑ„์šฐ๊ธฐ(populate) ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค.
์ด๋Ÿฌํ•œ์  ๋•๋ถ„์— ๋ชจ๋ธ์—์„œ ํ๋ฆ„ ์ œ์–ด(control flow) ๊ตฌ๋ฌธ๋“ค์„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๊ฒŒ ๋˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.
๋งค๋ฒˆ ๋ฐ˜๋ณต(iteration)ํ•  ๋•Œ๋งˆ๋‹ค ํ•„์š”ํ•˜๋ฉด ๋ชจ์–‘(shape)์ด๋‚˜ ํฌ๊ธฐ(size), ์—ฐ์‚ฐ(operation)์„ ๋ฐ”๊ฟ€ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
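A tiny sketch of what this dynamic behavior allows (an arbitrary toy example, not from the tutorial): the number of loop iterations below depends on the random input, so the recorded graph differs from run to run, yet backward still works.

import torch

a = torch.randn(3, requires_grad=True)

# Python control flow in the forward pass: the graph is rebuilt on every run,
# so the number of recorded multiplications can change with the input.
out = a * 2
while out.norm() < 100:
    out = out * 2

out.sum().backward()
print(a.grad)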

์„ ํƒ์ ์œผ๋กœ ์ฝ๊ธฐ(Optional Reading): ํ…์„œ ๋ณ€ํ™”๋„์™€ ์•ผ์ฝ”๋น„์•ˆ ๊ณฑ (Jacobian Product)

๋Œ€๋ถ€๋ถ„์˜ ๊ฒฝ์šฐ, ์Šค์นผ๋ผ ์†์‹ค ํ•จ์ˆ˜๋ฅผ ๊ฐ€์ง€๊ณ  ์ผ๋ถ€ ๋งค๊ฐœ๋ณ€์ˆ˜์™€ ๊ด€๋ จํ•œ ๋ณ€ํ™”๋„๋ฅผ ๊ณ„์‚ฐํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
๊ทธ๋Ÿฌ๋‚˜ ์ถœ๋ ฅ ํ•จ์ˆ˜๊ฐ€ ์ž„์˜์˜ ํ…์„œ์ธ ๊ฒฝ์šฐ๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค.
์ด๋Ÿด ๋•Œ, PyTorch๋Š” ์‹ค์ œ ๋ณ€ํ™”๋„๊ฐ€ ์•„๋‹Œ ์•ผ์ฝ”๋น„์•ˆ ๊ณฑ(Jacobian product)์„ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.
  • ๐‘ฅโƒ— =โŸจ๐‘ฅ1,…,๐‘ฅ๐‘›โŸฉ ์ด๊ณ , ๐‘ฆโƒ— =โŸจ๐‘ฆ1,…,๐‘ฆ๐‘šโŸฉ ์ผ ๋•Œ
  • ๋ฒกํ„ฐ ํ•จ์ˆ˜ ๐‘ฆโƒ— =๐‘“(๐‘ฅโƒ— ) ์—์„œ ๐‘ฅโƒ— ์— ๋Œ€ํ•œ ๐‘ฆโƒ—  ์˜ ๋ณ€ํ™”๋„๋Š” ์•ผ์ฝ”๋น„์•ˆ ํ–‰๋ ฌ(Jacobian matrix)๋กœ ์ฃผ์–ด์ง‘๋‹ˆ๋‹ค.

  • ์•ผ์ฝ”๋น„์•ˆ ํ–‰๋ ฌ ์ž์ฒด๋ฅผ ๊ณ„์‚ฐํ•˜๋Š” ๋Œ€์‹ , PyTorch๋Š” ์ฃผ์–ด์ง„ ์ž…๋ ฅ ๋ฒกํ„ฐ ๐‘ฃ=(๐‘ฃ1…๐‘ฃ๐‘š)์— ๋Œ€ํ•œ ์•ผ์ฝ”๋น„์•ˆ ๊ณฑ(Jacobian Product) ๐‘ฃ๐‘‡⋅๐ฝ ์„ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.
  • ์ด ๊ณผ์ •์€ ๐‘ฃ๋ฅผ ์ธ์ž๋กœ backward๋ฅผ ํ˜ธ์ถœํ•˜๋ฉด ์ด๋ค„์ง‘๋‹ˆ๋‹ค.
  • ๐‘ฃ์˜ ํฌ๊ธฐ๋Š” ๊ณฑ(product)์„ ๊ณ„์‚ฐํ•˜๋ ค๊ณ  ํ•˜๋Š” ์›๋ž˜ ํ…์„œ์˜ ํฌ๊ธฐ์™€ ๊ฐ™์•„์•ผ ํ•ฉ๋‹ˆ๋‹ค.
inp = torch.eye(4, 5, requires_grad=True)
out = (inp + 1).pow(2).t()
out.backward(torch.ones_like(out), retain_graph=True)  # v = torch.ones_like(out)
print(f"First call\n{inp.grad}")
out.backward(torch.ones_like(out), retain_graph=True)  # gradients accumulate
print(f"\nSecond call\n{inp.grad}")
inp.grad.zero_()  # zero out the accumulated gradients
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")
First call
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])

Second call
tensor([[8., 4., 4., 4., 4.],
        [4., 8., 4., 4., 4.],
        [4., 4., 8., 4., 4.],
        [4., 4., 4., 8., 4.]])

Call after zeroing gradients
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])
  • ๋™์ผํ•œ ์ธ์ž๋กœ backward๋ฅผ ๋‘์ฐจ๋ก€ ํ˜ธ์ถœํ•˜๋ฉด ๋ณ€ํ™”๋„ ๊ฐ’์ด ๋‹ฌ๋ผ์ง‘๋‹ˆ๋‹ค.
  • ์ด๋Š” ์—ญ๋ฐฉํ–ฅ ์ „ํŒŒ๋ฅผ ์ˆ˜ํ–‰ํ•  ๋•Œ, PyTorch๊ฐ€ ๋ณ€ํ™”๋„๋ฅผ ๋ˆ„์ (accumulate)ํ•ด์ฃผ๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค.
  • ์ฆ‰, ๊ณ„์‚ฐ๋œ ๋ณ€ํ™”๋„์˜ ๊ฐ’์ด ์—ฐ์‚ฐ ๊ทธ๋ž˜ํ”„์˜ ๋ชจ๋“  ์žŽ(leaf) ๋…ธ๋“œ์˜ grad ์†์„ฑ์— ์ถ”๊ฐ€๋ฉ๋‹ˆ๋‹ค.
  • ๋”ฐ๋ผ์„œ ์ œ๋Œ€๋กœ ๋œ ๋ณ€ํ™”๋„๋ฅผ ๊ณ„์‚ฐํ•˜๊ธฐ ์œ„ํ•ด์„œ๋Š” grad ์†์„ฑ์„ ๋จผ์ € 0์œผ๋กœ ๋งŒ๋“ค์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
  • ์‹ค์ œ ํ•™์Šต ๊ณผ์ •์—์„œ๋Š” ์˜ตํ‹ฐ๋งˆ์ด์ €(optimizer)๊ฐ€ ์ด ๊ณผ์ •์„ ๋„์™€์ค๋‹ˆ๋‹ค.
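For example, a single training step with an optimizer might look like the following sketch; torch.optim.SGD and the learning rate are illustrative choices, not prescribed by the tutorial.

import torch

x = torch.ones(5)
y = torch.zeros(3)
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

optimizer = torch.optim.SGD([w, b], lr=1e-2)  # lr chosen only for illustration

optimizer.zero_grad()  # zero out the grad properties before backward
z = torch.matmul(x, w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward()        # populate w.grad and b.grad
optimizer.step()       # update w and b using the freshly computed gradients
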
์ด์ „์—๋Š” ๋งค๊ฐœ๋ณ€์ˆ˜ ์—†์ด `backward()' ํ•จ์ˆ˜๋ฅผ ํ˜ธ์ถœํ–ˆ์Šต๋‹ˆ๋‹ค.
์ด๋Š” ๋ณธ์งˆ์ ์œผ๋กœ `backward(torch.tensor(1.0))`์„ ํ˜ธ์ถœํ•˜๋Š” ๊ฒƒ๊ณผ ๋™์ผํ•˜๋ฉฐ, ์‹ ๊ฒฝ๋ง ํ›ˆ๋ จ ์ค‘์˜ ์†์‹ค๊ณผ ๊ฐ™์€ ์Šค์นผ๋ผ-๊ฐ’ ํ•จ์ˆ˜์˜ ๋ณ€ํ™”๋„๋ฅผ ๊ณ„์‚ฐํ•˜๋Š” ์œ ์šฉํ•œ ๋ฐฉ๋ฒ•์ž…๋‹ˆ๋‹ค.

๋” ์ž์„ธํ•œ ๋‚ด์šฉ์„ ๋ณด๊ณ  ์‹ถ์œผ์‹œ๋ฉด ์•„๋ž˜ ๋งํฌ์— ๋“ค์–ด๊ฐ€์„œ ๊ณต์‹ ๋ฌธ์„œ๋ฅผ ์ฐธ๊ณ ํ•˜์„ธ์š”!

  • Official documentation on autograd mechanics: Autograd mechanics — PyTorch 2.4 documentation (pytorch.org)