גרפי חישוב

חשבון דיפרנציאלי ואינטגרלי רב־משתני מיסודות ראשונים

מבנה נתונים אחד מארגן את כל מה שלמדנו בשני השיעורים האחרונים: גרף החישוב. כל פעולה אריתמטית במודל (חיבור, כפל, matmul, הפעלה) הופכת לצומת בגרף מכוון. הגרף הזה הוא הדרך שבה PyTorch, JAX ו־TensorFlow מחשבים גרדיאנטים באופן אוטומטי.

האימון מריץ את הגרף בשני מעברים. המעבר קדימה זורם משמאל לימין, מחשב ושומר את הערך של כל צומת. המעבר אחורה זורם מימין לשמאל, ומשתמש בכלל השרשרת כדי לדחוף את הגרדיאנט מההפסד בחזרה אל כל קלט, צומת אחר צומת.

הרעיון שמאפשר את ההתרחבות הזו: כל צומת צריך לדעת רק את הנגזרת המקומית של עצמו. כדי לשלוח את הגרדיאנט אחורה דרך צומת, מכפילים את הגרדיאנט הנכנס (מלמעלה) ביעקוביאן המקומי של הצומת (האופן שבו הפלט שלו תלוי בקלטים שלו). אף צומת אינו זקוק לתמונה הגלובלית; כללים מקומיים, כשהם משורשרים יחד, מייצרים את הגרדיאנט הכולל המדויק.

איפה זה ב־MLגרף חישוב הוא autograd. כשכותבים מודל ב־PyTorch, כל פעולה רושמת בשקט צומת; קריאה ל־loss.backward() עוברת על הגרף בסדר הפוך, מכפילה את היעקוביאנים המקומיים לפי כלל השרשרת, ומפקידה ∂loss/∂w על כל פרמטר. לעולם אין צורך לכתוב נגזרת ביד, והנוחות הזו, חישוב מדויק וזול של נגזרות, היא חלק גדול מהסיבה שלמידה עמוקה מודרנית מעשית בכלל.
▶ גרפי חישוב
← כלל השרשרת: צורה מטריציתנקודות קריטיות ב־Rⁿ →