200行の心臓部。Value クラスは「1つの数値 data と、それが結果にどれだけ効くか grad をセットで持つ箱」。計算するたびに、どの演算をしたか(_children)と各入力への効き(_local_grads)を記録しておく。
会計で言えば、差異がどの科目から来たかを自動で遡る原因分解エンジン。順方向で金額(data)を積み上げ、backward を呼ぶと逆順に各段の局所勾配を掛けながら遡って、すべての入力の grad(感応度)を埋める。
この段は4つの用語に分かれる。加算・乗算の勾配、ReLU、連鎖律、トポロジカルソート。下のリンクから1演算ずつ詳しく見られる。
この段が、Karpathy のオリジナル200行のどこに当たるか。
class Value:
def __init__(self, data, children=(), local_grads=()):
self.data = data # 順方向で決まる数値(金額)
self.grad = 0 # 損失に対する効き(感応度)
self._children = children
self._local_grads = local_grads
def __add__(self, other): # 局所勾配 (1, 1)
return Value(self.data + other.data, (self, other), (1, 1))
def __mul__(self, other): # 局所勾配 (b, a)
return Value(self.data * other.data, (self, other), (other.data, self.data))
def relu(self): # 局所勾配 1.0(正) / 0.0(負)
return Value(max(0, self.data), (self,), (float(self.data > 0),))
def backward(self):
topo, visited = [], set()
def build_topo(v):
if v not in visited:
visited.add(v)
for child in v._children:
build_topo(child)
topo.append(v)
build_topo(self)
self.grad = 1
for v in reversed(topo): # 逆順に遡る
for child, local_grad in zip(v._children, v._local_grads):
child.grad += local_grad * v.grad # 連鎖律出典: karpathy / microgpt.py (本体は原文ベース、抜粋・コメントは日本語に補足)
触って動かせるデモ+簿記アナロジー+該当コード付き。