zerotohero/notebooks/01a_mircograd.ipynb

1976 lines
158 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "672d74cf-2fad-4493-9e3d-7ac2286dbfee",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"id": "e7295cd8-1f33-40d2-aca5-e226e5b945ba",
"metadata": {},
"source": [
"Let's define a simple quadratic function $f(x) = 3x^2 - 4x + 5$."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "07a0fdd9-e9c4-4a28-9201-99429a638573",
"metadata": {},
"outputs": [],
"source": [
"def f(x):\n",
"    \"\"\"Return 3x^2 - 4x + 5 (evaluated element-wise when x is a numpy array).\"\"\"\n",
"    return 3 * x**2 - 4 * x + 5"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "67093e12-40d4-4abd-a81d-2f7f63881e19",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"20.0"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f(x=3.0)  # 3*9 - 4*3 + 5 = 20.0"
]
},
{
"cell_type": "markdown",
"id": "ec1297be-8d68-4d59-af98-39e45eb9857b",
"metadata": {},
"source": [
"We can also plot it for a range of values"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0a3f756f-886b-455d-b1f2-53060088dce9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x229ffcad010>]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"xs = np.arange(-5, 5, 0.25)  # sample points in [-5, 5) with step 0.25\n",
"ys = f(xs)  # numpy evaluates f element-wise over the whole array\n",
"\n",
"# Explicit figure/axes with a title and labels so the plot stands on its own;\n",
"# the trailing ';' suppresses the matplotlib repr noise in the cell output.\n",
"fig, ax = plt.subplots()\n",
"ax.plot(xs, ys)\n",
"ax.set(title='f(x) = 3x^2 - 4x + 5', xlabel='x', ylabel='f(x)');"
]
},
{
"cell_type": "markdown",
"id": "5c9cb820-159e-420d-b0d0-a8706e19f9d5",
"metadata": {},
"source": [
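"Now, what's a derivative?\n",
"> It measures how sensitive the function's output is to a small change in its input.\n",
"\n",
"Formally, $f'(x) = \\lim_{h \\to 0} \\frac{f(x+h) - f(x)}{h}$, which is the slope of the function at $x$. Numerically, we can approximate it by evaluating $(f(x+h) - f(x))/h$ with a very small $h$."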
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "7866b622-e10c-4003-844f-9044e7bdc080",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3.0000002482211127e-05"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"h = 1e-5  # small step size\n",
"x = 2/3  # vertex of f: f'(x) = 6x - 4 = 0 here, so the estimate should be ~0\n",
"(f(x + h) - f(x)) / h  # forward-difference estimate of f'(x)"
]
},
{
"cell_type": "markdown",
"id": "4f027e6d-d889-4827-867e-335a96c20045",
"metadata": {},
"source": [
"Let's define a function with multiple inputs"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a3bcd24a-4a48-4d85-8d49-3fb751589b40",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.0\n"
]
}
],
"source": [
"a, b, c = 2.0, -3.0, 10.0\n",
"\n",
"d = a*b + c  # 2*(-3) + 10 = 4\n",
"print(d)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "78dcb378-190f-4421-8d95-af79cdf05b18",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"d1 4.0\n",
"d2 3.999699999999999\n",
"slope -3.000000000010772\n"
]
}
],
"source": [
"h = 1e-4\n",
"a, b, c = 2.0, -3.0, 10.0\n",
"\n",
"d1 = a*b + c\n",
"a = a + h  # nudge a, keep b and c fixed\n",
"d2 = a*b + c\n",
"\n",
"print('d1', d1)\n",
"print('d2', d2)\n",
"print('slope', (d2 - d1)/h)  # analytically, dd/da = b = -3"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9fb1208f-1a1e-4fdb-a290-a7d0e8a5050a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"d1 4.0\n",
"d2 4.0002\n",
"slope 2.0000000000042206\n"
]
}
],
"source": [
"h = 1e-4\n",
"a, b, c = 2.0, -3.0, 10.0\n",
"\n",
"d1 = a*b + c\n",
"b = b + h  # nudge b, keep a and c fixed\n",
"d2 = a*b + c\n",
"\n",
"print('d1', d1)\n",
"print('d2', d2)\n",
"print('slope', (d2 - d1)/h)  # analytically, dd/db = a = 2"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e7bfb7ef-4c16-44c6-afce-ac20ab858b4f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"d1 4.0\n",
"d2 4.0001\n",
"slope 0.9999999999976694\n"
]
}
],
"source": [
"h = 1e-4\n",
"a, b, c = 2.0, -3.0, 10.0\n",
"\n",
"d1 = a*b + c\n",
"c = c + h  # nudge c, keep a and b fixed\n",
"d2 = a*b + c\n",
"\n",
"print('d1', d1)\n",
"print('d2', d2)\n",
"print('slope', (d2 - d1)/h)  # analytically, dd/dc = 1"
]
},
{
"cell_type": "markdown",
"id": "ad332637-3dbe-4630-bb5a-d5a76e935a9b",
"metadata": {},
"source": [
"A neural network is just a very large mathematical expression, so we first need data structures to build and traverse such expressions. Let's start by building the `Value` object from micrograd."
]
},
{
"cell_type": "code",
"execution_count": 131,
"id": "e11f7575-00be-44a4-9bde-ae9288b1922a",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"Value(data=4.0)"
]
},
"execution_count": 131,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class Value:\n",
"    \"\"\"A scalar node in a computation graph that supports reverse-mode autodiff.\n",
"\n",
"    Attributes:\n",
"        data: the scalar value held by this node.\n",
"        grad: derivative of the output that `backward()` was called on with\n",
"            respect to this node; starts at 0.0.\n",
"        _backward: closure that pushes this node's grad into its operands' grads.\n",
"        _prev: set of the Value objects this node was computed from (empty for leaves).\n",
"        _op: string naming the operation that produced this node ('' for leaves).\n",
"        label: optional display name (shown by draw_dot).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, data, _children=(), _op='', label=''):\n",
"        self.data = data\n",
"        self.grad = 0.0  # assume no influence on the output until backward() says otherwise\n",
"        self._backward = lambda: None  # a leaf node has nothing to propagate\n",
"        self._prev = set(_children)  # the operands that produced this node\n",
"        self._op = _op\n",
"        self.label = label\n",
"\n",
"    def __repr__(self):\n",
"        return f\"Value(data={self.data})\"\n",
"\n",
"    def __add__(self, other):  # self + other\n",
"        # Wrap plain numbers so expressions like `a + 1` work.\n",
"        other = other if isinstance(other, Value) else Value(other)\n",
"        out = Value(self.data + other.data, (self, other), '+')\n",
"\n",
"        def _backward():\n",
"            # d(out)/d(self) = d(out)/d(other) = 1: the gradient flows straight through.\n",
"            self.grad += 1.0 * out.grad\n",
"            other.grad += 1.0 * out.grad\n",
"        out._backward = _backward\n",
"\n",
"        return out\n",
"\n",
"    def __radd__(self, other):  # other + self, when `other` is a plain number\n",
"        return self + other\n",
"\n",
"    def __mul__(self, other):  # self * other (Python calls __mul__ for the `*` operator)\n",
"        other = other if isinstance(other, Value) else Value(other)\n",
"        out = Value(self.data * other.data, (self, other), '*')\n",
"\n",
"        def _backward():\n",
"            # Product rule: each operand's local derivative is the other operand.\n",
"            self.grad += other.data * out.grad\n",
"            other.grad += self.data * out.grad\n",
"        out._backward = _backward\n",
"\n",
"        return out\n",
"\n",
"    def __rmul__(self, other):  # other * self, when `other` is a plain number\n",
"        return self * other\n",
"\n",
"    def tanh(self):\n",
"        # math.tanh is numerically stable; the textbook form (e^2x - 1)/(e^2x + 1)\n",
"        # raises OverflowError in math.exp once x exceeds roughly 355.\n",
"        t = math.tanh(self.data)\n",
"        out = Value(t, (self, ), 'tanh')\n",
"\n",
"        def _backward():\n",
"            # d/dx tanh(x) = 1 - tanh(x)^2\n",
"            self.grad += (1 - t**2) * out.grad\n",
"        out._backward = _backward\n",
"\n",
"        return out\n",
"\n",
"    def backward(self):\n",
"        \"\"\"Backpropagate from this node, accumulating `grad` into every node it depends on.\n",
"\n",
"        Gradients are accumulated with `+=`, so calling this twice on the same graph\n",
"        (without resetting grads to 0.0) sums the results of both calls.\n",
"        \"\"\"\n",
"        # Topologically sort the graph: operands are appended before the nodes built from them.\n",
"        topo = []\n",
"        visited = set()\n",
"        def build_topo(v):\n",
"            if v not in visited:\n",
"                visited.add(v)\n",
"                for child in v._prev:\n",
"                    build_topo(child)\n",
"                topo.append(v)\n",
"        build_topo(self)\n",
"\n",
"        self.grad = 1.0  # base case: d(self)/d(self) = 1\n",
"        # Walk from the output back to the leaves, so each node's grad is complete\n",
"        # before it is pushed further down to its operands.\n",
"        for node in reversed(topo):\n",
"            node._backward()\n",
"\n",
"# Build L = (a*b + c) * f, labelling every node so draw_dot can show it.\n",
"a = Value(2.0, label='a')\n",
"b = Value(-3.0, label='b')\n",
"c = Value(10.0, label='c')\n",
"e = a * b  # same as a.__mul__(b)\n",
"e.label = 'e'\n",
"d = e + c  # same as (a.__mul__(b)).__add__(c)\n",
"d.label = 'd'\n",
"f = Value(-2.0, label='f')  # NOTE: rebinds the name f, shadowing the quadratic defined earlier\n",
"L = d * f\n",
"L.label = 'L'\n",
"d"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "aa1c0bd9-5c01-4ff6-b397-bf8b37ca0a70",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{Value(data=-6.0), Value(data=10.0)}"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"d._prev  # the operand Values that d was computed from (micrograd calls them its children)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "7d88f630-07a0-4e12-a0e6-177094ed4999",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"set()"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c._prev  # empty set: c is a leaf created directly, not produced by an operation"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "a24046f3-53a4-4bea-ae0b-756bb9a80dd3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'+'"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"d._op  # the operation that produced d"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "cdf24235-ee48-4675-bb65-bd78d1ae201e",
"metadata": {},
"outputs": [],
"source": [
"# imported from micrograd's codebase\n",
"\n",
"from graphviz import Digraph\n",
"\n",
"def trace(root):\n",
"    \"\"\"Collect every node reachable from `root` and the edges between them.\n",
"\n",
"    Returns:\n",
"        (nodes, edges): a set of Value objects and a set of (child, parent) tuples.\n",
"    \"\"\"\n",
"    seen, links = set(), set()\n",
"\n",
"    def visit(node):\n",
"        if node in seen:\n",
"            return  # already explored\n",
"        seen.add(node)\n",
"        for child in node._prev:\n",
"            links.add((child, node))\n",
"            visit(child)\n",
"\n",
"    visit(root)\n",
"    return seen, links\n",
"\n",
"def draw_dot(root):\n",
"    \"\"\"Render the graph rooted at `root` as a left-to-right graphviz Digraph (SVG).\n",
"\n",
"    Each Value becomes a record node showing its label, data and grad; each\n",
"    operation becomes a separate op node wired between operands and result.\n",
"    \"\"\"\n",
"    graph = Digraph(format='svg', graph_attr={'rankdir': 'LR'})  # LR = left to right\n",
"    nodes, edges = trace(root)\n",
"\n",
"    for node in nodes:\n",
"        node_id = str(id(node))\n",
"        record = \"{%s | data %.4f | grad %.4f}\" % (node.label, node.data, node.grad)\n",
"        graph.node(name=node_id, label=record, shape='record')\n",
"        if not node._op:\n",
"            continue  # leaf value: no operation node to draw\n",
"        # Draw the operation as its own node feeding into the result value.\n",
"        op_id = node_id + node._op\n",
"        graph.node(name=op_id, label=node._op)\n",
"        graph.edge(op_id, node_id)\n",
"\n",
"    # Connect each operand to the op node of the value it feeds into.\n",
"    for child, parent in edges:\n",
"        graph.edge(str(id(child)), str(id(parent)) + parent._op)\n",
"\n",
"    return graph"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "68a93a67-c9f0-462b-b61d-3442e3e8523a",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 13.1.0 (20250701.0955)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"822pt\" height=\"127pt\"\n",
" viewBox=\"0.00 0.00 822.00 127.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 123)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-123 817.75,-123 817.75,4 -4,4\"/>\n",
"<!-- 2377653933648 -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>2377653933648</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"315.38,-27.5 315.38,-63.5 502.88,-63.5 502.88,-27.5 315.38,-27.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"326.38\" y=\"-40.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">e</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"337.38,-28 337.38,-63.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"379.5\" y=\"-40.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data &#45;6.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"421.62,-28 421.62,-63.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"462.25\" y=\"-40.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 2377653934672+ -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>2377653934672+</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"567\" cy=\"-72.5\" rx=\"27\" ry=\"18\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"567\" y=\"-67.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">+</text>\n",
"</g>\n",
"<!-- 2377653933648&#45;&gt;2377653934672+ -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>2377653933648&#45;&gt;2377653934672+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M503.05,-61.6C512.25,-63.19 521.16,-64.73 529.24,-66.13\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"528.48,-69.55 538.93,-67.81 529.68,-62.66 528.48,-69.55\"/>\n",
"</g>\n",
"<!-- 2377653933648* -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>2377653933648*</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"251.25\" cy=\"-45.5\" rx=\"27\" ry=\"18\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"251.25\" y=\"-40.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">*</text>\n",
"</g>\n",
"<!-- 2377653933648*&#45;&gt;2377653933648 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>2377653933648*&#45;&gt;2377653933648</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M278.69,-45.5C286.1,-45.5 294.64,-45.5 303.71,-45.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"303.51,-49 313.51,-45.5 303.51,-42 303.51,-49\"/>\n",
"</g>\n",
"<!-- 2377653934672 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>2377653934672</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"630,-54.5 630,-90.5 813.75,-90.5 813.75,-54.5 630,-54.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"641.38\" y=\"-67.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">d</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"652.75,-55 652.75,-90.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"692.62\" y=\"-67.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 4.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"732.5,-55 732.5,-90.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"773.12\" y=\"-67.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 2377653934672+&#45;&gt;2377653934672 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>2377653934672+&#45;&gt;2377653934672</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M594.28,-72.5C601.42,-72.5 609.61,-72.5 618.32,-72.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"618.06,-76 628.06,-72.5 618.06,-69 618.06,-76\"/>\n",
"</g>\n",
"<!-- 2377653955280 -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>2377653955280</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"2.62,-55.5 2.62,-91.5 185.62,-91.5 185.62,-55.5 2.62,-55.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"13.62\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">a</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"24.62,-56 24.62,-91.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"64.5\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 2.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"104.38,-56 104.38,-91.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"145\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 2377653955280&#45;&gt;2377653933648* -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>2377653955280&#45;&gt;2377653933648*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M185.81,-57.13C195.66,-55.35 205.21,-53.63 213.82,-52.08\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"214.19,-55.57 223.41,-50.34 212.95,-48.68 214.19,-55.57\"/>\n",
"</g>\n",
"<!-- 2379410649872 -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>2379410649872</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"0,-0.5 0,-36.5 188.25,-36.5 188.25,-0.5 0,-0.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"11.38\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">b</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"22.75,-1 22.75,-36.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"64.88\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data &#45;3.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"107,-1 107,-36.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"147.62\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 2379410649872&#45;&gt;2377653933648* -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>2379410649872&#45;&gt;2377653933648*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M188.49,-34.75C197.27,-36.28 205.77,-37.76 213.51,-39.1\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"212.66,-42.51 223.11,-40.78 213.86,-35.61 212.66,-42.51\"/>\n",
"</g>\n",
"<!-- 2379407120272 -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>2379407120272</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"314.25,-82.5 314.25,-118.5 504,-118.5 504,-82.5 314.25,-82.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"325.25\" y=\"-95.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">c</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"336.25,-83 336.25,-118.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"379.5\" y=\"-95.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 10.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"422.75,-83 422.75,-118.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"463.38\" y=\"-95.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 2379407120272&#45;&gt;2377653934672+ -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>2379407120272&#45;&gt;2377653934672+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M504.39,-83.57C513.16,-81.99 521.65,-80.47 529.37,-79.08\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"529.73,-82.57 538.95,-77.36 528.49,-75.68 529.73,-82.57\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x229ffd18690>"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"draw_dot(d)  # graph of d = a*b + c"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "c580a0dd-f363-4fd8-b221-4c80f78cc1c1",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 13.1.0 (20250701.0955)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"1151pt\" height=\"154pt\"\n",
" viewBox=\"0.00 0.00 1151.00 154.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 150)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-150 1147,-150 1147,4 -4,4\"/>\n",
"<!-- 2377653933648 -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>2377653933648</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"319.88,-27.5 319.88,-63.5 511.88,-63.5 511.88,-27.5 319.88,-27.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"330.88\" y=\"-40.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">e</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"341.88,-28 341.88,-63.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"384\" y=\"-40.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data &#45;6.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"426.12,-28 426.12,-63.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"469\" y=\"-40.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad &#45;2.0000</text>\n",
"</g>\n",
"<!-- 2377653934672+ -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>2377653934672+</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"576\" cy=\"-72.5\" rx=\"27\" ry=\"18\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"576\" y=\"-67.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">+</text>\n",
"</g>\n",
"<!-- 2377653933648&#45;&gt;2377653934672+ -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>2377653933648&#45;&gt;2377653934672+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M512.05,-61.75C521.28,-63.33 530.2,-64.85 538.29,-66.23\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"537.53,-69.65 547.98,-67.89 538.71,-62.75 537.53,-69.65\"/>\n",
"</g>\n",
"<!-- 2377653933648* -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>2377653933648*</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"255.75\" cy=\"-45.5\" rx=\"27\" ry=\"18\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"255.75\" y=\"-40.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">*</text>\n",
"</g>\n",
"<!-- 2377653933648*&#45;&gt;2377653933648 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>2377653933648*&#45;&gt;2377653933648</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M283.22,-45.5C290.59,-45.5 299.09,-45.5 308.13,-45.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"307.89,-49 317.89,-45.5 307.89,-42 307.89,-49\"/>\n",
"</g>\n",
"<!-- 2377653997136 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>2377653997136</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"953.25,-81.5 953.25,-117.5 1143,-117.5 1143,-81.5 953.25,-81.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"965.38\" y=\"-94.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">L</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"977.5,-82 977.5,-117.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"1019.62\" y=\"-94.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data &#45;8.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"1061.75,-82 1061.75,-117.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"1102.38\" y=\"-94.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 1.0000</text>\n",
"</g>\n",
"<!-- 2377653997136* -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>2377653997136*</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"890.25\" cy=\"-99.5\" rx=\"27\" ry=\"18\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"890.25\" y=\"-94.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">*</text>\n",
"</g>\n",
"<!-- 2377653997136*&#45;&gt;2377653997136 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>2377653997136*&#45;&gt;2377653997136</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M917.69,-99.5C924.84,-99.5 933.03,-99.5 941.74,-99.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"941.51,-103 951.51,-99.5 941.51,-96 941.51,-103\"/>\n",
"</g>\n",
"<!-- 2377653934672 -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>2377653934672</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"639,-54.5 639,-90.5 827.25,-90.5 827.25,-54.5 639,-54.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"650.38\" y=\"-67.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">d</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"661.75,-55 661.75,-90.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"701.62\" y=\"-67.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 4.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"741.5,-55 741.5,-90.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"784.38\" y=\"-67.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad &#45;2.0000</text>\n",
"</g>\n",
"<!-- 2377653934672&#45;&gt;2377653997136* -->\n",
"<g id=\"edge8\" class=\"edge\">\n",
"<title>2377653934672&#45;&gt;2377653997136*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M827.49,-88.75C836.27,-90.28 844.77,-91.76 852.51,-93.1\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"851.66,-96.51 862.11,-94.78 852.86,-89.61 851.66,-96.51\"/>\n",
"</g>\n",
"<!-- 2377653934672+&#45;&gt;2377653934672 -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>2377653934672+&#45;&gt;2377653934672</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M603.31,-72.5C610.49,-72.5 618.72,-72.5 627.47,-72.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"627.29,-76 637.29,-72.5 627.29,-69 627.29,-76\"/>\n",
"</g>\n",
"<!-- 2377653955280 -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>2377653955280</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"4.88,-55.5 4.88,-91.5 187.88,-91.5 187.88,-55.5 4.88,-55.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"15.88\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">a</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"26.88,-56 26.88,-91.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"66.75\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 2.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"106.62,-56 106.62,-91.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"147.25\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 6.0000</text>\n",
"</g>\n",
"<!-- 2377653955280&#45;&gt;2377653933648* -->\n",
"<g id=\"edge9\" class=\"edge\">\n",
"<title>2377653955280&#45;&gt;2377653933648*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M188.01,-57.37C198.67,-55.48 209.04,-53.63 218.3,-51.99\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"218.65,-55.48 227.88,-50.28 217.42,-48.59 218.65,-55.48\"/>\n",
"</g>\n",
"<!-- 2377653995280 -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>2377653995280</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"640.12,-109.5 640.12,-145.5 826.12,-145.5 826.12,-109.5 640.12,-109.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"650.38\" y=\"-122.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">f</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"660.62,-110 660.62,-145.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"702.75\" y=\"-122.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data &#45;2.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"744.88,-110 744.88,-145.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"785.5\" y=\"-122.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 4.0000</text>\n",
"</g>\n",
"<!-- 2377653995280&#45;&gt;2377653997136* -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>2377653995280&#45;&gt;2377653997136*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M826.6,-110.81C835.76,-109.16 844.62,-107.56 852.67,-106.1\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"853.09,-109.58 862.31,-104.36 851.84,-102.7 853.09,-109.58\"/>\n",
"</g>\n",
"<!-- 2379410649872 -->\n",
"<g id=\"node9\" class=\"node\">\n",
"<title>2379410649872</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"0,-0.5 0,-36.5 192.75,-36.5 192.75,-0.5 0,-0.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"11.38\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">b</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"22.75,-1 22.75,-36.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"64.88\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data &#45;3.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"107,-1 107,-36.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"149.88\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad &#45;4.0000</text>\n",
"</g>\n",
"<!-- 2379410649872&#45;&gt;2377653933648* -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>2379410649872&#45;&gt;2377653933648*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M193,-34.91C201.81,-36.42 210.32,-37.88 218.06,-39.2\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"217.21,-42.61 227.66,-40.85 218.4,-35.71 217.21,-42.61\"/>\n",
"</g>\n",
"<!-- 2379407120272 -->\n",
"<g id=\"node10\" class=\"node\">\n",
"<title>2379407120272</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"318.75,-82.5 318.75,-118.5 513,-118.5 513,-82.5 318.75,-82.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"329.75\" y=\"-95.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">c</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"340.75,-83 340.75,-118.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"384\" y=\"-95.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 10.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"427.25,-83 427.25,-118.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"470.12\" y=\"-95.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad &#45;2.0000</text>\n",
"</g>\n",
"<!-- 2379407120272&#45;&gt;2377653934672+ -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>2379407120272&#45;&gt;2377653934672+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M513.41,-83.41C522.21,-81.85 530.7,-80.34 538.42,-78.98\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"538.77,-82.47 548,-77.28 537.55,-75.58 538.77,-82.47\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x229ffc31710>"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"draw_dot(L)  # full graph of L = (a*b + c) * f"
]
},
{
"cell_type": "markdown",
"id": "9886d9da-645c-489e-af3a-60d48facf40b",
"metadata": {},
"source": [
"The value of the forward pass is -8. We would now like to run backpropagation: starting at the output and moving backwards through the graph, we calculate the gradient at every intermediate value, i.e. the derivative of the output L with respect to each node. \n",
"\n",
"In a NN, we will be very interested in the derivative of the loss function with respect to the weights of the NN. The network has both data and weights, but the data is fixed, so our focus will be only on the derivatives with respect to the weights."
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "f1dd3efe-8498-4c25-9fb6-2f7e71447f6b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"a.grad 6.000000000000227\n",
"b.grad -3.9999999999995595\n",
"c.grad -1.9999999999988916\n",
"e.grad -2.000000000000668\n",
"f.grad 3.9999999999995595\n",
"d.grad -2.000000000000668\n",
"L.grad 1.000000000000334\n"
]
}
],
"source": [
"def local_staging():  # just for our local testing\n",
"    \"\"\"Numerically estimate dL/d(node) for every node of L = (a*b + c) * f.\n",
"\n",
"    The graph is rebuilt once unperturbed and once per node with that node's\n",
"    data nudged by h; the forward-difference slope (L2 - L1) / h is printed\n",
"    for each node, in the order a, b, c, e, f, d, L.\n",
"    \"\"\"\n",
"    h = 0.001\n",
"\n",
"    def forward(bump=None):\n",
"        \"\"\"Build the graph, adding h to the node labelled `bump` (if any); return L.data.\"\"\"\n",
"        def nudge(v):\n",
"            # Perturb v in place only if it is the node under test.\n",
"            if v.label == bump:\n",
"                v.data += h\n",
"            return v\n",
"\n",
"        a = nudge(Value(2.0, label='a'))\n",
"        b = nudge(Value(-3.0, label='b'))\n",
"        c = nudge(Value(10.0, label='c'))\n",
"        e = a*b\n",
"        e.label = 'e'\n",
"        nudge(e)\n",
"        d = e + c\n",
"        d.label = 'd'\n",
"        nudge(d)\n",
"        f = nudge(Value(-2.0, label='f'))\n",
"        L = d * f\n",
"        L.label = 'L'\n",
"        nudge(L)\n",
"        return L.data\n",
"\n",
"    L1 = forward()\n",
"    # Analytically: a=6 (b*f), b=-4 (a*f), c=-2 (f), e=-2 (f), f=4 (d), d=-2 (f), L=1\n",
"    for name in ['a', 'b', 'c', 'e', 'f', 'd', 'L']:\n",
"        L2 = forward(bump=name)\n",
"        print(f'{name}.grad', (L2 - L1)/h)\n",
"\n",
"local_staging()"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "78e46a2f-7ba5-458f-beb3-133ebca6a839",
"metadata": {},
"outputs": [],
"source": [
"L.grad = 1.0"
]
},
{
"cell_type": "markdown",
"id": "90580c47-bf74-4599-81c9-8bccb8f3c69d",
"metadata": {},
"source": [
"L = d * f\n",
"\n",
"dL/dd = ?\n",
"\n",
"By calculus, we know that this will be f (and, symmetrically, dL/df = d)."
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "bbd2d2c2-f136-4244-a901-ff7c0165efe2",
"metadata": {},
"outputs": [],
"source": [
"f.grad = 4.0\n",
"d.grad = -2.0"
]
},
{
"attachments": {
"8211c8be-62e8-4102-aaca-a5bd17838677.png": {
"image/png": ""
}
},
"cell_type": "markdown",
"id": "57e6a7b9-30af-4273-95f1-906bc3da4756",
"metadata": {},
"source": [
"To find dL/dc, let's start with dd/dc.\n",
"\n",
"d = c + e;\n",
"so, dd/dc = 1 (and dd/de = 1)\n",
"\n",
"These are local derivatives: they only tell us how c and e affect d. But we need their effect on L, and in a NN there will be 1000s of such local derivatives to combine.\n",
"\n",
"So, let's use chain rule\n",
"\n",
"![image.png](attachment:8211c8be-62e8-4102-aaca-a5bd17838677.png)\n",
"\n",
"> Intuitively, the chain rule states that knowing the instantaneous rate of change of z relative to y and that of y relative to x allows one to calculate the instantaneous rate of change of z relative to x as the product of the two rates of change.\n",
"\n",
"so,\n",
"\n",
"dL/dc = dL/dd * dd/dc = -2.0 * 1 = -2.0"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "416bb5f7-8e8a-49d6-b6b8-784e0bf26447",
"metadata": {},
"outputs": [],
"source": [
"c.grad = -2.0\n",
"e.grad = -2.0"
]
},
{
"cell_type": "markdown",
"id": "f5989699-809e-4b6b-bba6-7af1d836fe95",
"metadata": {},
"source": [
"Now, to find dL/da and dL/db,\n",
"\n",
"dL/da = (dL/dd * dd/de) * de/da = (dL/de) * de/da\n",
"\n",
"dL/db = (dL/dd * dd/de) * de/db = (dL/de) * de/db\n",
"\n",
"de/da = d(a*b)/da = b\n",
"\n",
"de/db = d(a*b)/db = a\n",
"\n",
"dL/da = -2.0 * -3.0 = 6.0\n",
"\n",
"dL/db = -2.0 * 2.0 = -4.0"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "9c66472a-bcd8-42c7-93c0-2c44c8941469",
"metadata": {},
"outputs": [],
"source": [
"a.grad = 6.0\n",
"b.grad = -4.0"
]
},
{
"cell_type": "markdown",
"id": "bbab9a31-8408-475a-98e3-58796dbe78cd",
"metadata": {},
"source": [
"This was manual backpropagation! Now let's nudge each leaf input a small step in the direction of its gradient to increase L (a single, manual optimization step)."
]
},
{
"cell_type": "code",
"execution_count": 62,
"id": "15926195-7075-48e8-bf02-d8b06ad904a4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-7.713030143999999\n"
]
}
],
"source": [
"a.data += 0.001 * a.grad # 0.001 is the step size\n",
"b.data += 0.001 * b.grad\n",
"c.data += 0.001 * c.grad\n",
"f.data += 0.001 * f.grad\n",
"\n",
"e = a * b\n",
"d = e + c\n",
"L = d * f\n",
"\n",
"print(L.data)"
]
},
{
"attachments": {
"34f0ebc5-cb10-4db9-b7a9-9f501a3732da.webp": {
"image/webp": ""
}
},
"cell_type": "markdown",
"id": "4eff7563-3a93-493d-8788-bb055a23d07a",
"metadata": {},
"source": [
"Let's work through another backpropagation example, this time on a single neuron.\n",
"\n",
"![neuron_model.webp](attachment:34f0ebc5-cb10-4db9-b7a9-9f501a3732da.webp)\n",
"\n",
"The activation function is usually sigmoid or tanh; below we use tanh [[more here]](https://en.wikipedia.org/wiki/Activation_function)."
]
},
{
"cell_type": "code",
"execution_count": 123,
"id": "29f8a9d0-16a6-49e3-b1db-52084db2ce24",
"metadata": {},
"outputs": [],
"source": [
"# inputs x1,x2\n",
"x1 = Value(2.0, label='x1')\n",
"x2 = Value(0.0, label='x2')\n",
"# weights w1,w2\n",
"w1 = Value(-3.0, label='w1')\n",
"w2 = Value(1.0, label='w2')\n",
"# bias of the neuron\n",
"b = Value(6.8813735870195432, label='b')\n",
"# x1*w1 + x2*w2 + b\n",
"x1w1 = x1*w1; x1w1.label = 'x1*w1'\n",
"x2w2 = x2*w2; x2w2.label = 'x2*w2'\n",
"x1w1x2w2 = x1w1 + x2w2; x1w1x2w2.label = 'x1*w1 + x2*w2'\n",
"n = x1w1x2w2 + b; n.label = 'n'\n",
"o = n.tanh(); o.label = 'output'"
]
},
{
"cell_type": "code",
"execution_count": 126,
"id": "1a1bdc0d-23e1-44bc-8b8e-1e88cdb64b2c",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 13.1.0 (20250701.0955)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"1594pt\" height=\"210pt\"\n",
" viewBox=\"0.00 0.00 1594.00 210.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 206)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-206 1590.25,-206 1590.25,4 -4,4\"/>\n",
"<!-- 2379410984016 -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>2379410984016</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"326.25,-55.5 326.25,-91.5 540,-91.5 540,-55.5 326.25,-55.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"352.62\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x2*w2</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"379,-56 379,-91.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"418.88\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 0.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"458.75,-56 458.75,-91.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"499.38\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 0.5000</text>\n",
"</g>\n",
"<!-- 2379410985680+ -->\n",
"<g id=\"node9\" class=\"node\">\n",
"<title>2379410985680+</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"605.25\" cy=\"-100.5\" rx=\"27\" ry=\"18\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"605.25\" y=\"-95.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">+</text>\n",
"</g>\n",
"<!-- 2379410984016&#45;&gt;2379410985680+ -->\n",
"<g id=\"edge14\" class=\"edge\">\n",
"<title>2379410984016&#45;&gt;2379410985680+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M540.42,-90.37C549.89,-91.87 559,-93.32 567.21,-94.62\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"566.63,-98.07 577.05,-96.18 567.72,-91.16 566.63,-98.07\"/>\n",
"</g>\n",
"<!-- 2379410984016* -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>2379410984016*</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"261\" cy=\"-73.5\" rx=\"27\" ry=\"18\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"261\" y=\"-68.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">*</text>\n",
"</g>\n",
"<!-- 2379410984016*&#45;&gt;2379410984016 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>2379410984016*&#45;&gt;2379410984016</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M288.21,-73.5C296,-73.5 305.08,-73.5 314.82,-73.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"314.55,-77 324.55,-73.5 314.55,-70 314.55,-77\"/>\n",
"</g>\n",
"<!-- 2379410987088 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>2379410987088</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"3.75,-55.5 3.75,-91.5 194.25,-91.5 194.25,-55.5 3.75,-55.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"18.5\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x2</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"33.25,-56 33.25,-91.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"73.12\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 0.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"113,-56 113,-91.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"153.62\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 0.5000</text>\n",
"</g>\n",
"<!-- 2379410987088&#45;&gt;2379410984016* -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>2379410987088&#45;&gt;2379410984016*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M194.46,-73.5C204.21,-73.5 213.66,-73.5 222.21,-73.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"222.12,-77 232.12,-73.5 222.12,-70 222.12,-77\"/>\n",
"</g>\n",
"<!-- 2379410981968 -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>2379410981968</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"324,-110.5 324,-146.5 542.25,-146.5 542.25,-110.5 324,-110.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"350.38\" y=\"-123.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x1*w1</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"376.75,-111 376.75,-146.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"418.88\" y=\"-123.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data &#45;6.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"461,-111 461,-146.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"501.62\" y=\"-123.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 0.5000</text>\n",
"</g>\n",
"<!-- 2379410981968&#45;&gt;2379410985680+ -->\n",
"<g id=\"edge13\" class=\"edge\">\n",
"<title>2379410981968&#45;&gt;2379410985680+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M542.35,-110.69C551.12,-109.24 559.54,-107.86 567.19,-106.6\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"567.72,-110.06 577.02,-104.98 566.58,-103.15 567.72,-110.06\"/>\n",
"</g>\n",
"<!-- 2379410981968* -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>2379410981968*</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"261\" cy=\"-128.5\" rx=\"27\" ry=\"18\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"261\" y=\"-123.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">*</text>\n",
"</g>\n",
"<!-- 2379410981968*&#45;&gt;2379410981968 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>2379410981968*&#45;&gt;2379410981968</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M288.21,-128.5C295.29,-128.5 303.43,-128.5 312.17,-128.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"312.01,-132 322.01,-128.5 312.01,-125 312.01,-132\"/>\n",
"</g>\n",
"<!-- 2379410982992 -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>2379410982992</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"0,-165.5 0,-201.5 198,-201.5 198,-165.5 0,-165.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"16.25\" y=\"-178.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">w1</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"32.5,-166 32.5,-201.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"74.62\" y=\"-178.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data &#45;3.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"116.75,-166 116.75,-201.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"157.38\" y=\"-178.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 1.0000</text>\n",
"</g>\n",
"<!-- 2379410982992&#45;&gt;2379410981968* -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>2379410982992&#45;&gt;2379410981968*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M172.12,-165.01C180.9,-162.35 189.7,-159.5 198,-156.5 208.09,-152.86 218.82,-148.27 228.47,-143.88\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"229.86,-147.09 237.45,-139.69 226.9,-140.75 229.86,-147.09\"/>\n",
"</g>\n",
"<!-- 2379410988688 -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>2379410988688</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"1.5,-110.5 1.5,-146.5 196.5,-146.5 196.5,-110.5 1.5,-110.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"16.25\" y=\"-123.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x1</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"31,-111 31,-146.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"70.88\" y=\"-123.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 2.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"110.75,-111 110.75,-146.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"153.62\" y=\"-123.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad &#45;1.5000</text>\n",
"</g>\n",
"<!-- 2379410988688&#45;&gt;2379410981968* -->\n",
"<g id=\"edge12\" class=\"edge\">\n",
"<title>2379410988688&#45;&gt;2379410981968*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M196.76,-128.5C205.77,-128.5 214.47,-128.5 222.4,-128.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"222.25,-132 232.25,-128.5 222.25,-125 222.25,-132\"/>\n",
"</g>\n",
"<!-- 2379410985680 -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>2379410985680</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"668.25,-82.5 668.25,-118.5 939,-118.5 939,-82.5 668.25,-82.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"720.88\" y=\"-95.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x1*w1 + x2*w2</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"773.5,-83 773.5,-118.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"815.62\" y=\"-95.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data &#45;6.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"857.75,-83 857.75,-118.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"898.38\" y=\"-95.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 0.5000</text>\n",
"</g>\n",
"<!-- 2379410981264+ -->\n",
"<g id=\"node13\" class=\"node\">\n",
"<title>2379410981264+</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"1002\" cy=\"-127.5\" rx=\"27\" ry=\"18\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"1002\" y=\"-122.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">+</text>\n",
"</g>\n",
"<!-- 2379410985680&#45;&gt;2379410981264+ -->\n",
"<g id=\"edge10\" class=\"edge\">\n",
"<title>2379410985680&#45;&gt;2379410981264+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M939.14,-118.99C947.95,-120.21 956.31,-121.36 963.87,-122.4\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"963.16,-125.83 973.55,-123.73 964.12,-118.9 963.16,-125.83\"/>\n",
"</g>\n",
"<!-- 2379410985680+&#45;&gt;2379410985680 -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>2379410985680+&#45;&gt;2379410985680</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M632.73,-100.5C639.73,-100.5 647.79,-100.5 656.52,-100.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"656.39,-104 666.39,-100.5 656.39,-97 656.39,-104\"/>\n",
"</g>\n",
"<!-- 2377655405264 -->\n",
"<g id=\"node10\" class=\"node\">\n",
"<title>2377655405264</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"1374.75,-109.5 1374.75,-145.5 1586.25,-145.5 1586.25,-109.5 1374.75,-109.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"1400\" y=\"-122.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">output</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"1425.25,-110 1425.25,-145.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"1465.12\" y=\"-122.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 0.7071</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"1505,-110 1505,-145.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"1545.62\" y=\"-122.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 1.0000</text>\n",
"</g>\n",
"<!-- 2377655405264tanh -->\n",
"<g id=\"node11\" class=\"node\">\n",
"<title>2377655405264tanh</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"1311.75\" cy=\"-127.5\" rx=\"27\" ry=\"18\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"1311.75\" y=\"-122.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">tanh</text>\n",
"</g>\n",
"<!-- 2377655405264tanh&#45;&gt;2377655405264 -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>2377655405264tanh&#45;&gt;2377655405264</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M1339.17,-127.5C1346.27,-127.5 1354.43,-127.5 1363.17,-127.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"1362.99,-131 1372.99,-127.5 1362.99,-124 1362.99,-131\"/>\n",
"</g>\n",
"<!-- 2379410981264 -->\n",
"<g id=\"node12\" class=\"node\">\n",
"<title>2379410981264</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"1065,-109.5 1065,-145.5 1248.75,-145.5 1248.75,-109.5 1065,-109.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"1076.38\" y=\"-122.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">n</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"1087.75,-110 1087.75,-145.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"1127.62\" y=\"-122.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 0.8814</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"1167.5,-110 1167.5,-145.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"1208.12\" y=\"-122.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 0.5000</text>\n",
"</g>\n",
"<!-- 2379410981264&#45;&gt;2377655405264tanh -->\n",
"<g id=\"edge11\" class=\"edge\">\n",
"<title>2379410981264&#45;&gt;2377655405264tanh</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M1249.01,-127.5C1257.39,-127.5 1265.52,-127.5 1272.98,-127.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"1272.95,-131 1282.95,-127.5 1272.95,-124 1272.95,-131\"/>\n",
"</g>\n",
"<!-- 2379410981264+&#45;&gt;2379410981264 -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>2379410981264+&#45;&gt;2379410981264</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M1029.28,-127.5C1036.42,-127.5 1044.61,-127.5 1053.32,-127.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"1053.06,-131 1063.06,-127.5 1053.06,-124 1053.06,-131\"/>\n",
"</g>\n",
"<!-- 2379410985424 -->\n",
"<g id=\"node14\" class=\"node\">\n",
"<title>2379410985424</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"2.25,-0.5 2.25,-36.5 195.75,-36.5 195.75,-0.5 2.25,-0.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"18.5\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">w2</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"34.75,-1 34.75,-36.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"74.62\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 1.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"114.5,-1 114.5,-36.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"155.12\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 2379410985424&#45;&gt;2379410984016* -->\n",
"<g id=\"edge8\" class=\"edge\">\n",
"<title>2379410985424&#45;&gt;2379410984016*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M168.91,-36.94C178.74,-39.93 188.67,-43.15 198,-46.5 207.96,-50.07 218.58,-54.47 228.18,-58.68\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"226.54,-61.78 237.1,-62.67 229.4,-55.39 226.54,-61.78\"/>\n",
"</g>\n",
"<!-- 2379410985936 -->\n",
"<g id=\"node15\" class=\"node\">\n",
"<title>2379410985936</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"711.75,-137.5 711.75,-173.5 895.5,-173.5 895.5,-137.5 711.75,-137.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"723.12\" y=\"-150.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">b</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"734.5,-138 734.5,-173.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"774.38\" y=\"-150.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 6.8814</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"814.25,-138 814.25,-173.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"854.88\" y=\"-150.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 0.5000</text>\n",
"</g>\n",
"<!-- 2379410985936&#45;&gt;2379410981264+ -->\n",
"<g id=\"edge9\" class=\"edge\">\n",
"<title>2379410985936&#45;&gt;2379410981264+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M895.91,-142.48C919.97,-139.05 944.58,-135.55 963.93,-132.79\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"964.34,-136.26 973.75,-131.39 963.35,-129.33 964.34,-136.26\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x22997449150>"
]
},
"execution_count": 126,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"draw_dot(o)"
]
},
{
"cell_type": "code",
"execution_count": 73,
"id": "2e570d89-e528-4d41-a691-da62fce06b2e",
"metadata": {},
"outputs": [],
"source": [
"o.grad = 1.0"
]
},
{
"cell_type": "code",
"execution_count": 80,
"id": "35b5853e-489a-43d3-8c41-560f4a876131",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.4999999999999999\n"
]
}
],
"source": [
"# o = tanh(n)\n",
"# do/dn = 1 - tanh(n)**2 = 1 - o**2\n",
"print(1 - o.data**2)\n",
"n.grad = 0.5"
]
},
{
"cell_type": "code",
"execution_count": 78,
"id": "aa8737b2-6e53-4aa5-bb2c-2221e1dce011",
"metadata": {},
"outputs": [],
"source": [
"# a `+` node just passes its grad unchanged to each input: n.grad flows to b and x1w1x2w2, and x1w1x2w2.grad flows to x1w1 and x2w2\n",
"x1w1x2w2.grad = 0.5\n",
"b.grad = 0.5\n",
"x1w1.grad = 0.5\n",
"x2w2.grad = 0.5"
]
},
{
"cell_type": "code",
"execution_count": 82,
"id": "ff460f77-a807-4fce-93b7-079662202a70",
"metadata": {},
"outputs": [],
"source": [
"# for a `*` node, each input's grad is the output's grad times the other input's data\n",
"x1.grad = w1.data * x1w1.grad\n",
"w1.grad = x1.data * x1w1.grad\n",
"x2.grad = w2.data * x2w2.grad\n",
"w2.grad = x2.data * x2w2.grad"
]
},
{
"cell_type": "markdown",
"id": "d6965ba7-fefa-4c10-aae7-49ee2e9da76d",
"metadata": {},
"source": [
"Doing this manually is ridiculous. Let's implement this backprop in our Value object."
]
},
{
"cell_type": "code",
"execution_count": 101,
"id": "32f79032-f121-4b86-b77d-dacdbdd01ee4",
"metadata": {},
"outputs": [],
"source": [
"o.grad = 1.0 # we need to set the base case for o"
]
},
{
"cell_type": "code",
"execution_count": 102,
"id": "66038692-d288-4052-9861-4aeeb8a84836",
"metadata": {},
"outputs": [],
"source": [
"o._backward()"
]
},
{
"cell_type": "code",
"execution_count": 108,
"id": "ac1b36fa-1eae-473a-9182-2e18c0d18c1b",
"metadata": {},
"outputs": [],
"source": [
"n._backward()\n",
"x1w1x2w2._backward()\n",
"x1w1._backward()\n",
"x2w2._backward()"
]
},
{
"cell_type": "markdown",
"id": "aeb2bc8d-76cb-4a0a-8fec-848a5ea6600e",
"metadata": {},
"source": [
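"Next, let's stop calling `_backward` manually. A node's `_backward` should only run once its own grad is complete, i.e. after every node that uses it has been processed. A topological sort, walked in reverse, gives exactly that order [[more here](https://en.wikipedia.org/wiki/Topological_sorting)]."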
]
},
{
"cell_type": "code",
"execution_count": 110,
"id": "4adbf2b0-5130-4908-8536-6da47081037a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Value(data=6.881373587019543),\n",
" Value(data=-3.0),\n",
" Value(data=2.0),\n",
" Value(data=-6.0),\n",
" Value(data=0.0),\n",
" Value(data=1.0),\n",
" Value(data=0.0),\n",
" Value(data=-6.0),\n",
" Value(data=0.8813735870195432),\n",
" Value(data=0.7071067811865476)]"
]
},
"execution_count": 110,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# taken from micrograd's codebase\n",
"\n",
"topo = []\n",
"visited = set()\n",
"def build_topo(v):\n",
" if v not in visited:\n",
" visited.add(v)\n",
" for child in v._prev:\n",
" build_topo(child)\n",
" topo.append(v)\n",
"build_topo(o)\n",
"topo"
]
},
{
"cell_type": "code",
"execution_count": 120,
"id": "d9626671-9a47-41ce-a7d3-3517a3f61888",
"metadata": {},
"outputs": [],
"source": [
"o.grad = 1.0\n",
"\n",
"topo = []\n",
"visited = set()\n",
"def build_topo(v):\n",
" if v not in visited:\n",
" visited.add(v)\n",
" for child in v._prev:\n",
" build_topo(child)\n",
" topo.append(v)\n",
"build_topo(o)\n",
"\n",
"for node in reversed(topo):\n",
" node._backward()"
]
},
{
"cell_type": "markdown",
"id": "174e74ee-c432-41e6-acd4-a1e8c3ccbf7b",
"metadata": {},
"source": [
"Now let's put this inside our Value object"
]
},
{
"cell_type": "code",
"execution_count": 125,
"id": "ebcd762e-0b83-4303-a6aa-c60f2b4e24c7",
"metadata": {},
"outputs": [],
"source": [
"o.backward()"
]
},
{
"cell_type": "markdown",
"id": "3507e421-b28b-486b-8b6b-4126bafe1964",
"metadata": {},
"source": [
"We have a bug: a node that is used more than once gets the wrong gradient. Look at `a` below."
]
},
{
"cell_type": "code",
"execution_count": 127,
"id": "82af7380-6803-4d7d-8088-8ee42edd0cdd",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 13.1.0 (20250701.0955)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"501pt\" height=\"45pt\"\n",
" viewBox=\"0.00 0.00 501.00 45.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 41)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-41 496.75,-41 496.75,4 -4,4\"/>\n",
"<!-- 2377655700112 -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>2377655700112</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"309,-0.5 309,-36.5 492.75,-36.5 492.75,-0.5 309,-0.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"320.38\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">b</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"331.75,-1 331.75,-36.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"371.62\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 6.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"411.5,-1 411.5,-36.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"452.12\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 1.0000</text>\n",
"</g>\n",
"<!-- 2377655700112+ -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>2377655700112+</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"246\" cy=\"-18.5\" rx=\"27\" ry=\"18\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"246\" y=\"-13.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">+</text>\n",
"</g>\n",
"<!-- 2377655700112+&#45;&gt;2377655700112 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>2377655700112+&#45;&gt;2377655700112</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M273.28,-18.5C280.42,-18.5 288.61,-18.5 297.32,-18.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"297.06,-22 307.06,-18.5 297.06,-15 297.06,-22\"/>\n",
"</g>\n",
"<!-- 2377654771536 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>2377654771536</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"0,-0.5 0,-36.5 183,-36.5 183,-0.5 0,-0.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"11\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">a</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"22,-1 22,-36.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"61.88\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 3.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"101.75,-1 101.75,-36.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"142.38\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 1.0000</text>\n",
"</g>\n",
"<!-- 2377654771536&#45;&gt;2377655700112+ -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>2377654771536&#45;&gt;2377655700112+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M183.41,-18.5C191.77,-18.5 199.88,-18.5 207.32,-18.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"207.26,-22 217.26,-18.5 207.26,-15 207.26,-22\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x229fff13f10>"
]
},
"execution_count": 127,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = Value(3.0, label='a')\n",
"b = a + a ; b.label = 'b'\n",
"b.backward()\n",
"draw_dot(b)"
]
},
{
"cell_type": "markdown",
"id": "88b4e727-0ab8-448e-8076-45fcb46cbe0a",
"metadata": {},
"source": [
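"The grad of `a` should be 2, NOT 1! Since b = a + a = 2a, db/da = 2: each use of `a` contributes a gradient of 1, and these contributions must be summed (accumulated) across uses rather than overwritten."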
]
},
{
"cell_type": "code",
"execution_count": 129,
"id": "8d38e9af-126f-4574-9fd7-5ea1e9285d23",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 13.1.0 (20250701.0955)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"823pt\" height=\"100pt\"\n",
" viewBox=\"0.00 0.00 823.00 100.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 96)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-96 818.5,-96 818.5,4 -4,4\"/>\n",
"<!-- 2377654777360 -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>2377654777360</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"628.5,-27.5 628.5,-63.5 814.5,-63.5 814.5,-27.5 628.5,-27.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"638.75\" y=\"-40.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">f</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"649,-28 649,-63.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"691.12\" y=\"-40.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data &#45;6.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"733.25,-28 733.25,-63.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"773.88\" y=\"-40.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 1.0000</text>\n",
"</g>\n",
"<!-- 2377654777360* -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>2377654777360*</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"565.5\" cy=\"-45.5\" rx=\"27\" ry=\"18\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"565.5\" y=\"-40.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">*</text>\n",
"</g>\n",
"<!-- 2377654777360*&#45;&gt;2377654777360 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>2377654777360*&#45;&gt;2377654777360</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M592.97,-45.5C600.11,-45.5 608.28,-45.5 616.96,-45.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"616.69,-49 626.69,-45.5 616.69,-42 616.69,-49\"/>\n",
"</g>\n",
"<!-- 2377654827536 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>2377654827536</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"0,-55.5 0,-91.5 188.25,-91.5 188.25,-55.5 0,-55.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"11.38\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">b</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"22.75,-56 22.75,-91.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"62.62\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 3.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"102.5,-56 102.5,-91.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"145.38\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad &#45;2.0000</text>\n",
"</g>\n",
"<!-- 2377654825616* -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>2377654825616*</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"251.25\" cy=\"-73.5\" rx=\"27\" ry=\"18\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"251.25\" y=\"-68.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">*</text>\n",
"</g>\n",
"<!-- 2377654827536&#45;&gt;2377654825616* -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>2377654827536&#45;&gt;2377654825616*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M188.49,-73.5C196.9,-73.5 205.05,-73.5 212.52,-73.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"212.49,-77 222.49,-73.5 212.49,-70 212.49,-77\"/>\n",
"</g>\n",
"<!-- 2377654692304+ -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>2377654692304+</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"251.25\" cy=\"-18.5\" rx=\"27\" ry=\"18\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"251.25\" y=\"-13.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">+</text>\n",
"</g>\n",
"<!-- 2377654827536&#45;&gt;2377654692304+ -->\n",
"<g id=\"edge8\" class=\"edge\">\n",
"<title>2377654827536&#45;&gt;2377654692304+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M163.14,-55C171.65,-52.33 180.19,-49.47 188.25,-46.5 198.23,-42.82 208.87,-38.26 218.46,-33.9\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"219.79,-37.14 227.38,-29.75 216.84,-30.79 219.79,-37.14\"/>\n",
"</g>\n",
"<!-- 2377653712016 -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>2377653712016</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"0.38,-0.5 0.38,-36.5 187.88,-36.5 187.88,-0.5 0.38,-0.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"11.38\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">a</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"22.38,-1 22.38,-36.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"64.5\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data &#45;2.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"106.62,-1 106.62,-36.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"147.25\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 3.0000</text>\n",
"</g>\n",
"<!-- 2377653712016&#45;&gt;2377654825616* -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>2377653712016&#45;&gt;2377654825616*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M160.08,-36.95C169.59,-39.96 179.21,-43.19 188.25,-46.5 198.1,-50.11 208.63,-54.49 218.16,-58.66\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"216.47,-61.74 227.03,-62.62 219.32,-55.35 216.47,-61.74\"/>\n",
"</g>\n",
"<!-- 2377653712016&#45;&gt;2377654692304+ -->\n",
"<g id=\"edge9\" class=\"edge\">\n",
"<title>2377653712016&#45;&gt;2377654692304+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M188.05,-18.5C196.6,-18.5 204.9,-18.5 212.49,-18.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"212.28,-22 222.28,-18.5 212.28,-15 212.28,-22\"/>\n",
"</g>\n",
"<!-- 2377654825616 -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>2377654825616</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"314.25,-55.5 314.25,-91.5 502.5,-91.5 502.5,-55.5 314.25,-55.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"325.62\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">d</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"337,-56 337,-91.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"379.12\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data &#45;6.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"421.25,-56 421.25,-91.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"461.88\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 1.0000</text>\n",
"</g>\n",
"<!-- 2377654825616&#45;&gt;2377654777360* -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>2377654825616&#45;&gt;2377654777360*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M502.74,-56.65C511.61,-55.05 520.18,-53.5 527.98,-52.09\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"528.45,-55.56 537.67,-50.34 527.2,-48.68 528.45,-55.56\"/>\n",
"</g>\n",
"<!-- 2377654825616*&#45;&gt;2377654825616 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>2377654825616*&#45;&gt;2377654825616</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M278.56,-73.5C285.74,-73.5 293.97,-73.5 302.72,-73.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"302.54,-77 312.54,-73.5 302.54,-70 302.54,-77\"/>\n",
"</g>\n",
"<!-- 2377654692304 -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>2377654692304</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"314.62,-0.5 314.62,-36.5 502.12,-36.5 502.12,-0.5 314.62,-0.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"325.62\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">e</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"336.62,-1 336.62,-36.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"376.5\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 1.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"416.38,-1 416.38,-36.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"459.25\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad &#45;6.0000</text>\n",
"</g>\n",
"<!-- 2377654692304&#45;&gt;2377654777360* -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>2377654692304&#45;&gt;2377654777360*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M502.3,-34.67C511.21,-36.22 519.84,-37.73 527.69,-39.09\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"527.01,-42.53 537.46,-40.79 528.21,-35.63 527.01,-42.53\"/>\n",
"</g>\n",
"<!-- 2377654692304+&#45;&gt;2377654692304 -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>2377654692304+&#45;&gt;2377654692304</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M278.56,-18.5C285.8,-18.5 294.11,-18.5 302.95,-18.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"302.87,-22 312.87,-18.5 302.87,-15 302.87,-22\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x22997449b10>"
]
},
"execution_count": 129,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Leaves\n",
"a = Value(-2.0, label='a')\n",
"b = Value(3.0, label='b')\n",
"\n",
"# Both a and b feed two nodes: d (product) and e (sum)\n",
"d = a * b\n",
"d.label = 'd'\n",
"e = a + b\n",
"e.label = 'e'\n",
"\n",
"# Output node\n",
"f = d * e\n",
"f.label = 'f'\n",
"\n",
"f.backward()\n",
"draw_dot(f)"
]
},
{
"cell_type": "markdown",
"id": "6bd645d9-55c0-43fc-8d42-cc047042fa5e",
"metadata": {},
"source": [
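"Both a and b are used twice (in d = a * b and in e = a + b), so each should receive two gradient contributions: -6.0 from e for both of them, plus 3.0 for a and -2.0 for b from d. But the backward pass above kept only the contribution coming through d (a.grad = 3.0, b.grad = -2.0): one path's gradient replaced the other instead of being added to it."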
]
},
{
"cell_type": "markdown",
"id": "a9931101-1a6a-4ea3-8c89-51d2918b508e",
"metadata": {},
"source": [
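"By the multivariable chain rule, when a value feeds into several downstream nodes, the gradients flowing back along each path add up. So we fix our Value object to accumulate gradients instead of overwriting them. With the fix we expect a.grad = 3.0 + (-6.0) = -3.0 and b.grad = -2.0 + (-6.0) = -8.0."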
]
},
{
"cell_type": "code",
"execution_count": 133,
"id": "042f67cb-6a18-484b-809d-43d14ac67b57",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 13.1.0 (20250701.0955)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"826pt\" height=\"100pt\"\n",
" viewBox=\"0.00 0.00 826.00 100.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 96)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-96 822.25,-96 822.25,4 -4,4\"/>\n",
"<!-- 2379410983504 -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>2379410983504</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"0,-55.5 0,-91.5 192,-91.5 192,-55.5 0,-55.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"11\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">a</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"22,-56 22,-91.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"64.12\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data &#45;2.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"106.25,-56 106.25,-91.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"149.12\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad &#45;3.0000</text>\n",
"</g>\n",
"<!-- 2379410986704* -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>2379410986704*</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"255\" cy=\"-73.5\" rx=\"27\" ry=\"18\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"255\" y=\"-68.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">*</text>\n",
"</g>\n",
"<!-- 2379410983504&#45;&gt;2379410986704* -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>2379410983504&#45;&gt;2379410986704*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M192.4,-73.5C200.73,-73.5 208.79,-73.5 216.18,-73.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"216.05,-77 226.05,-73.5 216.05,-70 216.05,-77\"/>\n",
"</g>\n",
"<!-- 2377654486480+ -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>2377654486480+</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"255\" cy=\"-18.5\" rx=\"27\" ry=\"18\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"255\" y=\"-13.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">+</text>\n",
"</g>\n",
"<!-- 2379410983504&#45;&gt;2377654486480+ -->\n",
"<g id=\"edge8\" class=\"edge\">\n",
"<title>2379410983504&#45;&gt;2377654486480+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M166.39,-55.06C175.08,-52.38 183.79,-49.5 192,-46.5 201.99,-42.85 212.63,-38.3 222.22,-33.93\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"223.55,-37.17 231.15,-29.78 220.6,-30.83 223.55,-37.17\"/>\n",
"</g>\n",
"<!-- 2379410986704 -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>2379410986704</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"318,-55.5 318,-91.5 506.25,-91.5 506.25,-55.5 318,-55.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"329.38\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">d</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"340.75,-56 340.75,-91.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"382.88\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data &#45;6.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"425,-56 425,-91.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"465.62\" y=\"-68.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 1.0000</text>\n",
"</g>\n",
"<!-- 2377654489040* -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>2377654489040*</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"569.25\" cy=\"-45.5\" rx=\"27\" ry=\"18\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"569.25\" y=\"-40.45\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">*</text>\n",
"</g>\n",
"<!-- 2379410986704&#45;&gt;2377654489040* -->\n",
"<g id=\"edge9\" class=\"edge\">\n",
"<title>2379410986704&#45;&gt;2377654489040*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M506.49,-56.65C515.36,-55.05 523.93,-53.5 531.73,-52.09\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"532.2,-55.56 541.42,-50.34 530.95,-48.68 532.2,-55.56\"/>\n",
"</g>\n",
"<!-- 2379410986704*&#45;&gt;2379410986704 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>2379410986704*&#45;&gt;2379410986704</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M282.31,-73.5C289.49,-73.5 297.72,-73.5 306.47,-73.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"306.29,-77 316.29,-73.5 306.29,-70 306.29,-77\"/>\n",
"</g>\n",
"<!-- 2379410909584 -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>2379410909584</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"1.88,-0.5 1.88,-36.5 190.12,-36.5 190.12,-0.5 1.88,-0.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"13.25\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">b</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"24.62,-1 24.62,-36.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"64.5\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 3.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"104.38,-1 104.38,-36.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"147.25\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad &#45;8.0000</text>\n",
"</g>\n",
"<!-- 2379410909584&#45;&gt;2379410986704* -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>2379410909584&#45;&gt;2379410986704*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M163.53,-36.96C173.14,-39.96 182.87,-43.18 192,-46.5 201.86,-50.08 212.39,-54.46 221.92,-58.63\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"220.23,-61.71 230.79,-62.59 223.09,-55.32 220.23,-61.71\"/>\n",
"</g>\n",
"<!-- 2379410909584&#45;&gt;2377654486480+ -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>2379410909584&#45;&gt;2377654486480+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M190.59,-18.5C199.58,-18.5 208.29,-18.5 216.23,-18.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"216.11,-22 226.11,-18.5 216.11,-15 216.11,-22\"/>\n",
"</g>\n",
"<!-- 2377654486480 -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>2377654486480</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"318.38,-0.5 318.38,-36.5 505.88,-36.5 505.88,-0.5 318.38,-0.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"329.38\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">e</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"340.38,-1 340.38,-36.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"380.25\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data 1.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"420.12,-1 420.12,-36.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"463\" y=\"-13.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad &#45;6.0000</text>\n",
"</g>\n",
"<!-- 2377654486480&#45;&gt;2377654489040* -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>2377654486480&#45;&gt;2377654489040*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M506.05,-34.67C514.96,-36.22 523.59,-37.73 531.44,-39.09\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"530.76,-42.53 541.21,-40.79 531.96,-35.63 530.76,-42.53\"/>\n",
"</g>\n",
"<!-- 2377654486480+&#45;&gt;2377654486480 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>2377654486480+&#45;&gt;2377654486480</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M282.31,-18.5C289.55,-18.5 297.86,-18.5 306.7,-18.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"306.62,-22 316.62,-18.5 306.62,-15 306.62,-22\"/>\n",
"</g>\n",
"<!-- 2377654489040 -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>2377654489040</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"632.25,-27.5 632.25,-63.5 818.25,-63.5 818.25,-27.5 632.25,-27.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"642.5\" y=\"-40.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">f</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"652.75,-28 652.75,-63.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"694.88\" y=\"-40.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">data &#45;6.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"737,-28 737,-63.5\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"777.62\" y=\"-40.7\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">grad 1.0000</text>\n",
"</g>\n",
"<!-- 2377654489040*&#45;&gt;2377654489040 -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>2377654489040*&#45;&gt;2377654489040</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M596.72,-45.5C603.86,-45.5 612.03,-45.5 620.71,-45.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"620.44,-49 630.44,-45.5 620.44,-42 620.44,-49\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x2299737d310>"
]
},
"execution_count": 133,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Same expression as before, now with gradient accumulation in Value\n",
"a = Value(-2.0, label='a')\n",
"b = Value(3.0, label='b')\n",
"\n",
"d = a * b\n",
"d.label = 'd'\n",
"e = a + b\n",
"e.label = 'e'\n",
"\n",
"f = d * e\n",
"f.label = 'f'\n",
"\n",
"# a and b should now collect gradient from both the d and e paths\n",
"f.backward()\n",
"draw_dot(f)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}