{ "cells": [ { "cell_type": "code", "execution_count": 142, "id": "672d74cf-2fad-4493-9e3d-7ac2286dbfee", "metadata": {}, "outputs": [], "source": [ "import math\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] },
{ "cell_type": "markdown", "id": "e7295cd8-1f33-40d2-aca5-e226e5b945ba", "metadata": {}, "source": [ "Let's define a function `f`:" ] },
{ "cell_type": "code", "execution_count": 143, "id": "07a0fdd9-e9c4-4a28-9201-99429a638573", "metadata": {}, "outputs": [], "source": [ "def f(x):\n", "    return 3*x**2 - 4*x + 5" ] },
{ "cell_type": "code", "execution_count": 144, "id": "67093e12-40d4-4abd-a81d-2f7f63881e19", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "20.0" ] }, "execution_count": 144, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f(3.0)" ] },
{ "cell_type": "markdown", "id": "ec1297be-8d68-4d59-af98-39e45eb9857b", "metadata": {}, "source": [ "We can also plot it over a range of values:" ] },
{ "cell_type": "code", "execution_count": 145, "id": "0a3f756f-886b-455d-b1f2-53060088dce9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 145, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "xs = np.arange(-5, 5, 0.25)\n", "ys = f(xs)\n", "plt.plot(xs, ys)" ] },
{ "cell_type": "markdown", "id": "5c9cb820-159e-420d-b0d0-a8706e19f9d5", "metadata": {}, "source": [ "Now, what's a derivative?\n", "> It measures how sensitive a function's output is to a small change in its input.\n", "\n", "Formally, $f'(x) = \\lim_{h \\to 0} \\frac{f(x+h) - f(x)}{h}$, which is the slope of `f` at `x`. Numerically, we approximate it with a small but finite `h`. For our `f`, the exact derivative is $f'(x) = 6x - 4$, which is zero at $x = 2/3$, so the estimate below should be close to 0; what remains is the finite-step error, which for this quadratic is exactly $3h$ (up to floating-point rounding)." ] },
{ "cell_type": "code", "execution_count": 146, "id": "7866b622-e10c-4003-844f-9044e7bdc080", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3.0000002482211127e-05" ] }, "execution_count": 146, "metadata": {}, "output_type": "execute_result" } ], "source": [ "h = 0.00001\n", "x = 2/3\n", "(f(x+h) - f(x))/h" ] },
{ "cell_type": "markdown", "id": "4f027e6d-d889-4827-867e-335a96c20045", "metadata": {}, "source": [ "Let's define a function with multiple inputs:" ] },
{ "cell_type": "code", "execution_count": 147, "id": "a3bcd24a-4a48-4d85-8d49-3fb751589b40", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4.0\n" ] } ], "source": [ "a = 2.0\n", "b = -3.0\n", "c = 10.0\n", "\n", "d = a*b + c\n", "print(d)" ] },
{ "cell_type": "code", "execution_count": 148, "id": "78dcb378-190f-4421-8d95-af79cdf05b18", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "d1 4.0\n", "d2 3.999699999999999\n", "slope -3.000000000010772\n" ] } ], "source": [ "h = 0.0001\n", "a = 2.0\n", "b = -3.0\n", "c = 10.0\n", "\n", "d1 = a*b + c\n", "a += h\n", "d2 = a*b + c\n", "\n", "print('d1', d1)\n", "print('d2', d2)\n", "print('slope', (d2 - d1)/h) # analytically, the derivative of a*b + c with respect to a is b (= -3.0)" ] },
{ "cell_type": "code", "execution_count": 149, "id": "9fb1208f-1a1e-4fdb-a290-a7d0e8a5050a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "d1 4.0\n", "d2 4.0002\n", "slope 2.0000000000042206\n" ] } ], 
"source": [ "h = 0.0001\n", "a = 2.0\n", "b = -3.0\n", "c = 10.0\n", "\n", "d1 = a*b + c\n", "b += h\n", "d2 = a*b + c\n", "\n", "print('d1', d1)\n", "print('d2', d2)\n", "print('slope', (d2 - d1)/h) # analytically, the derivative of a*b + c with respect to b is a (= 2.0)" ] },
{ "cell_type": "code", "execution_count": 150, "id": "e7bfb7ef-4c16-44c6-afce-ac20ab858b4f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "d1 4.0\n", "d2 4.0001\n", "slope 0.9999999999976694\n" ] } ], "source": [ "h = 0.0001\n", "a = 2.0\n", "b = -3.0\n", "c = 10.0\n", "\n", "d1 = a*b + c\n", "c += h\n", "d2 = a*b + c\n", "\n", "print('d1', d1)\n", "print('d2', d2)\n", "print('slope', (d2 - d1)/h) # analytically, the derivative of a*b + c with respect to c is 1" ] },
{ "cell_type": "markdown", "id": "ad332637-3dbe-4630-bb5a-d5a76e935a9b", "metadata": {}, "source": [ "Mathematically, neural networks are just very large expressions. Let's start building the data structures to represent them, beginning with the `Value` object from micrograd." ] },
{ "cell_type": "code", "execution_count": 167, "id": "e11f7575-00be-44a4-9bde-ae9288b1922a", "metadata": { "scrolled": true }, "outputs": [], "source": [ "class Value:\n", "\n", "    def __init__(self, data, _children=(), _op='', label=''):\n", "        # _children: tuple of the Value objects this one was computed from (empty for leaf inputs)\n", "        # _op: string naming the operation that created this Value ('' for leaf inputs)\n", "        self.data = data\n", "        self.grad = 0.0  # derivative of the final output with respect to this Value; starts at 0.0 (no effect on the output is assumed until backprop runs)\n", "        self._backward = lambda: None  # propagates out.grad to the children; a no-op for leaf nodes\n", "        self._prev = set(_children)  # empty for leaf Values such as a, b, c\n", "        self._op = _op\n", "        self.label = label\n", "\n", "    def __repr__(self):\n", "        return f\"Value(data={self.data})\"\n", "\n", "    def __add__(self, other):  # a + b calls a.__add__(b)\n", "        other = other if isinstance(other, Value) else Value(other)  # if other is NOT a Value, wrap it in one (so a + 1 works)\n", "        out = Value(self.data + other.data, (self, other), '+')  # self.data and other.data are plain Python floats, so this is ordinary float addition\n", "\n", "        def _backward():\n", "            # the local derivative of a sum is 1 for both inputs; += accumulates gradients\n", "            # when the same Value is used in more than one operation\n", "            self.grad += 1.0 * out.grad\n", "            other.grad += 1.0 * out.grad\n", "        out._backward = _backward\n", "\n", "        return out\n", "\n", "    def __radd__(self, other):  # other + self\n", "        return self + other\n", "\n", "    def __mul__(self, other):  # a * b calls a.__mul__(b); the dunder name is what lets Python dispatch the * operator\n", "        other = other if isinstance(other, Value) else Value(other)  # if other is NOT a Value, wrap it in one\n", "        out = Value(self.data * other.data, (self, other), '*')\n", "\n", "        def _backward():\n", "            self.grad += other.data * out.grad\n", "            other.grad += self.data * out.grad\n", "        out._backward = _backward\n", "\n", "        return out\n", "\n", "    def __pow__(self, other):  # self ** other\n", "        assert isinstance(other, (int, float))  # only int/float exponents are supported\n", "        out = Value(self.data**other, (self, ), f'**{other}')\n", "\n", "        def _backward():\n", "            self.grad += other * self.data**(other-1) * out.grad  # power rule\n", "        out._backward = _backward\n", "\n", "        return out\n", "\n", "    def __rmul__(self, other):  # other * self\n", "        return self * other\n", "\n", "    def __truediv__(self, other):  # self / other\n", "        return self * other**-1\n", "\n", "    def __neg__(self):  # -self\n", "        return self * -1\n", "\n", "    def __sub__(self, other):  # self - other\n", "        return self + (-other)\n", "\n", "    def tanh(self):\n", "        x = self.data\n", "        t = (math.exp(2*x) - 1)/(math.exp(2*x) + 1)\n", "        out = Value(t, (self, ), 'tanh')\n", "\n", "        def _backward():\n", "            self.grad += (1 - t**2) * out.grad\n", "        out._backward = _backward\n", "\n", "        return out\n", "\n", "    def exp(self):  # e**self; takes no extra argument, so it is called as a.exp()\n", "        x = self.data\n", "        out = Value(math.exp(x), (self, ), 'exp')\n", "\n", "        def _backward():\n", "            self.grad += out.data * out.grad  # d/dx e**x = e**x, which is out.data\n", "        out._backward = _backward\n", "\n", "        return out\n", "\n", "    def backward(self):\n", "        # order the graph topologically, so every node comes after all of its children\n", "        topo = []\n", "        visited = set()\n", "        def build_topo(v):\n", "            if v not in visited:\n", "                visited.add(v)\n", "                for child in v._prev:\n", "                    build_topo(child)\n", "                topo.append(v)\n", "        build_topo(self)\n", "\n", "        self.grad = 1.0  # base case: the derivative of the output with respect to itself\n", "\n", "        # go from the output back towards the leaves, applying the chain rule at each node\n", "        for node in reversed(topo):\n", "            node._backward()" ] },
{ "cell_type": "code", "execution_count": 152, "id": "cdf24235-ee48-4675-bb65-bd78d1ae201e", "metadata": {}, "outputs": [], "source": [ "# taken from micrograd's codebase\n", "\n", "from graphviz import Digraph\n", "\n", "def trace(root):\n", "    # builds a set of all nodes and edges in a graph\n", "    nodes, edges = set(), set()\n", "    def build(v):\n", "        if v not in nodes:\n", "            nodes.add(v)\n", "            for child in v._prev:\n", "                edges.add((child, v))\n", "                build(child)\n", "    build(root)\n", "    return nodes, edges\n", "\n", "def draw_dot(root):\n", "    dot = Digraph(format='svg', graph_attr={'rankdir': 'LR'})  # LR = left to right\n", "\n", "    nodes, edges = trace(root)\n", "    for n in nodes:\n", "        uid = str(id(n))\n", "        # for any value in the graph, create a rectangular ('record') node for it\n", "        dot.node(name = uid, label = \"{%s | data %.4f | grad %.4f}\" % (n.label, n.data, n.grad), shape='record')\n", "        if n._op:\n", "            # if this value is a result of some operation, create an op node for it\n", "            dot.node(name = uid + n._op, label = n._op)\n", "            # and connect this node to it\n", "            dot.edge(uid + n._op, uid)\n", "\n", "    for n1, n2 in edges:\n", "        # connect n1 to the op node of n2\n", "        dot.edge(str(id(n1)), str(id(n2)) + n2._op)\n", "\n", "    return dot" ] },
{ "attachments": { "34f0ebc5-cb10-4db9-b7a9-9f501a3732da.webp": { "image/webp": "" } }, "cell_type": "markdown", "id": "4eff7563-3a93-493d-8788-bb055a23d07a", "metadata": {}, 
"source": [ "Let's look at another example of backpropagation, this time through a single neuron:\n", "\n", "![neuron_model.webp](attachment:34f0ebc5-cb10-4db9-b7a9-9f501a3732da.webp)\n", "\n", "The activation function is usually a sigmoid or tanh [[more here]](https://en.wikipedia.org/wiki/Activation_function); here we use tanh." ] },
{ "cell_type": "code", "execution_count": 153, "id": "29f8a9d0-16a6-49e3-b1db-52084db2ce24", "metadata": {}, "outputs": [], "source": [ "# inputs x1,x2\n", "x1 = Value(2.0, label='x1')\n", "x2 = Value(0.0, label='x2')\n", "# weights w1,w2\n", "w1 = Value(-3.0, label='w1')\n", "w2 = Value(1.0, label='w2')\n", "# bias of the neuron (6 + 0.8813..., so n is about 0.8814 and tanh(n) is about 0.7071, i.e. 1/sqrt(2))\n", "b = Value(6.8813735870195432, label='b')\n", "# x1*w1 + x2*w2 + b\n", "x1w1 = x1*w1; x1w1.label = 'x1*w1'\n", "x2w2 = x2*w2; x2w2.label = 'x2*w2'\n", "x1w1x2w2 = x1w1 + x2w2; x1w1x2w2.label = 'x1*w1 + x2*w2'\n", "n = x1w1x2w2 + b; n.label = 'n'\n", "o = n.tanh(); o.label = 'output'\n", "o.backward()" ] },
{ "cell_type": "code", "execution_count": 154, "id": "1a1bdc0d-23e1-44bc-8b8e-1e88cdb64b2c", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 154, "metadata": {}, "output_type": "execute_result" } ], "source": [ "draw_dot(o)" ] },
{ "cell_type": "code", "execution_count": 155, "id": "7d33e845-75ed-4aec-9310-0f55e866e065", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Value(data=3.0)" ] }, "execution_count": 155, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = Value(2.0)\n", "a + 1" ] },
{ "cell_type": "markdown", "id": "59924a5f-ae4c-4d0f-b45d-611b32ba6df7", "metadata": {}, "source": [ "This works because of a quality-of-life (QoL) improvement: `__add__` and `__mul__` wrap plain numbers in a `Value`. `__rmul__` extends this to `2 * a`: Python first tries `int.__mul__`, which returns `NotImplemented` for a `Value`, and then falls back to `Value.__rmul__`." 
] }, { "cell_type": "code", "execution_count": 156, "id": "5ec7a25c-1540-47d3-aefc-eea6cfac8e8a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Value(data=4.0)\n", "Value(data=4.0)\n" ] } ], "source": [ "print(a * 2)\n", "print(2 * a)" ] },
{ "cell_type": "code", "execution_count": 157, "id": "048d58a9-e70d-45b7-885f-e34317323283", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Value(data=54.598150033144236)" ] }, "execution_count": 157, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = Value(4.0)\n", "a.exp()" ] },
{ "cell_type": "markdown", "id": "8830493e-47d3-44dd-957f-9c84a9fafa99", "metadata": {}, "source": [ "Division can be rewritten using multiplication and a power:\n", "\n", "$$\\frac{a}{b} = a \\cdot \\frac{1}{b} = a \\cdot b^{-1}$$\n", "\n", "So let's implement `x**k` for a constant `k` (by the power rule, its local derivative is $k x^{k-1}$) and treat division as a special case of it: `a / b` becomes `a * b**-1`." ] },
{ "cell_type": "code", "execution_count": 158, "id": "c75d090e-6f68-41d8-b874-fd72a5c2303a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Value(data=125.0)" ] }, "execution_count": 158, "metadata": {}, "output_type": "execute_result" } ], "source": [ "z = Value(5.0)\n", "z**3" ] },
{ "cell_type": "code", "execution_count": 159, "id": "ff12ee1b-b022-4790-81dd-b88527719e29", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Value(data=0.5)" ] }, "execution_count": 159, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = Value(4.0)\n", "b = Value(8.0)\n", "a / b" ] },
{ "cell_type": "code", "execution_count": 160, "id": "b313cee1-dd94-40bc-a414-ac6475443585", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Value(data=-4.0)" ] }, "execution_count": 160, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a = Value(4.0)\n", "b = Value(8.0)\n", "a - b" ] },
{ "cell_type": "code", "execution_count": 161, "id": "7079d013-6e83-4b81-b1d1-d4860d0bf3e7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 161, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# inputs x1,x2\n", "x1 = Value(2.0, label='x1')\n", "x2 = Value(0.0, label='x2')\n", "# weights w1,w2\n", "w1 = Value(-3.0, label='w1')\n", "w2 = Value(1.0, label='w2')\n", "# bias of the neuron\n", "b = Value(6.8813735870195432, label='b')\n", "# x1*w1 + x2*w2 + b\n", "x1w1 = x1*w1; x1w1.label = 'x1*w1'\n", "x2w2 = x2*w2; x2w2.label = 'x2*w2'\n", "x1w1x2w2 = x1w1 + x2w2; x1w1x2w2.label = 'x1*w1 + x2*w2'\n", "n = x1w1x2w2 + b; n.label = 'n'\n", "# this time, build tanh out of exp, subtraction and division: tanh(n) = (e**(2n) - 1) / (e**(2n) + 1)\n", "e = (2*n).exp()\n", "o = (e - 1) / (e + 1)\n", "#---\n", "o.label = 'output'\n", "o.backward()\n", "draw_dot(o)" ] },
{ "cell_type": "markdown", "id": "c817b8f7-a9e5-4193-a759-27973d83868d", "metadata": {}, "source": [ "Now let's compute the same thing in PyTorch and check that the gradients match." ] },
{ "cell_type": "code", "execution_count": 162, "id": "961afc05-1958-41e0-b434-40d8b0a148e3", "metadata": {}, "outputs": [], "source": [ "import torch" ] },
{ "cell_type": "code", "execution_count": 163, "id": "67c41ab6-8b5d-4908-a868-0d63c196ae5f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.7071066904050358\n", "---\n", "x2 0.5000001283844369\n", "w2 0.0\n", "x1 -1.5000003851533106\n", "w1 1.0000002567688737\n" ] } ], "source": [ "# leaf tensors have requires_grad=False by default (inputs usually don't need gradients), so we enable it explicitly\n", "# torch.Tensor([...]) creates a float32 tensor before .double(), so b is rounded to float32 precision;\n", "# that is why the values below differ slightly from the micrograd ones\n", "x1 = torch.Tensor([2.0]).double() ; x1.requires_grad = True\n", "x2 = torch.Tensor([0.0]).double() ; x2.requires_grad = True\n", "w1 = torch.Tensor([-3.0]).double() ; w1.requires_grad = True\n", "w2 = torch.Tensor([1.0]).double() ; w2.requires_grad = True\n", "b = torch.Tensor([6.8813735870195432]).double() ; b.requires_grad = True\n", "n = x1*w1 + x2*w2 + b\n", "o = torch.tanh(n)\n", "\n", "print(o.data.item())\n", "o.backward()\n", "\n", "print('---')\n", "print('x2', x2.grad.item())\n", "print('w2', w2.grad.item())\n", "print('x1', x1.grad.item())\n", "print('w1', w1.grad.item())" ] },
{ "cell_type": "markdown", "id": "43f61380-db69-4b53-9a87-e6121fe2eae1", "metadata": {}, "source": [ "Let's now build a small neural-net library (a multi-layer perceptron) in micrograd." ] },
{ "cell_type": "code", "execution_count": 175, "id": "9399d558-a3c6-44ae-8b3c-a995d750f20e", "metadata": {}, "outputs": [], "source": [ "import random\n", "\n", "class Neuron:\n", "\n", "    def __init__(self, nin):  # nin is the number of inputs\n", "        self.w = [Value(random.uniform(-1,1)) for _ in range(nin)]\n", "        self.b = Value(random.uniform(-1,1))\n", "\n", "    def __call__(self, x):\n", "        # activation = w . x + b (the sum starts from self.b)\n", "        act = sum((wi * xi for wi, xi in zip(self.w, x)), self.b)\n", "        out = act.tanh()\n", "        return out\n", "\n", "    def parameters(self):\n", "        return self.w + [self.b]\n", "\n", "class Layer:\n", "\n", "    def __init__(self, nin, nout):  # nin is the number of inputs to each neuron, nout is the number of neurons in the layer\n", "        self.neurons = [Neuron(nin) for _ in range(nout)]\n", "\n", "    def __call__(self, x):\n", "        outs = [n(x) for n in self.neurons]\n", "        return outs[0] if len(outs) == 1 else outs  # QoL: return a bare Value instead of a one-element list\n", "\n", "    def parameters(self):\n", "        return [p for neuron in self.neurons for p in neuron.parameters()]\n", "        # equivalent loop:\n", "        # params = []\n", "        # for neuron in self.neurons:\n", "        #     ps = neuron.parameters()\n", "        #     params.extend(ps)\n", "        # return params\n", "\n", "class MLP:\n", "\n", "    def __init__(self, nin, nouts):  # nouts is the list of layer sizes (number of neurons in each layer)\n", "        sz = [nin] + nouts\n", "        self.layers = [Layer(sz[i], sz[i+1]) for i in range(len(nouts))]\n", "\n", "    def __call__(self, x):\n", "        for layer in self.layers:\n", "            x = layer(x)\n", "        return x\n", "\n", "    def parameters(self):\n", "        return [p for layer in self.layers for p in layer.parameters()]\n", "\n", "# x = [2.0, 3.0]\n", "# n = Layer(2, 3)  # a layer of three neurons, each taking 2 inputs\n", "# n(x)" ] }, 
{ "attachments": { "ee0672be-bf08-4422-b9a4-24bd40eea517.jpg": { "image/jpeg": "" } }, "cell_type": "markdown", "id": "b7f37088-3ee3-466b-974b-bdf7a1df0d25", "metadata": {}, "source": [ "![images.jpg](attachment:ee0672be-bf08-4422-b9a4-24bd40eea517.jpg)\n", "\n", "A `Layer` is a list of neurons, and an MLP (multi-layer perceptron) is a sequence of layers, where each layer's outputs are the next layer's inputs." ] },
{ "cell_type": "markdown", "id": "7e885e5c-3de4-4d81-89c5-02423da45778", "metadata": {}, "source": [ "Let's train a neural net (finally!). This is a simple binary classifier." ] },
{ "cell_type": "code", "execution_count": 240, "id": "58ab19a9-8f0f-41a5-b270-95f31a1127e4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Value(data=0.5295324547878887)" ] }, "execution_count": 240, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = [2.0, 3.0, -1.0]\n", "n = MLP(3, [4, 4, 1])  # 3 inputs, two hidden layers of 4 neurons, 1 output (the network in the image above)\n", "n(x)" ] },
{ "cell_type": "code", "execution_count": 184, "id": "78f114fb-4736-458e-9eb0-3e03405dc63c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "41" ] }, "execution_count": 184, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(n.parameters())  # (3*4 + 4) + (4*4 + 4) + (4*1 + 1) = 41 weights and biases" ] },
{ "cell_type": "code", "execution_count": 177, "id": "e095ac0b-5aa5-4224-a2d7-67a5af1c9742", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Value(data=0.07700064719309661),\n", " Value(data=0.6874724332729629),\n", " Value(data=-0.7191542943909137),\n", " Value(data=-0.12016611563222777)]" ] }, "execution_count": 177, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# feeding xs[i] should ideally produce ys[i]: the first input should give 1.0, and so on\n", "\n", "xs = [\n", "    [2.0, 3.0, -1.0],\n", "    [3.0, -1.0, 0.5],\n", "    [0.5, 1.0, 1.0],\n", "    [1.0, 1.0, -1.0],\n", "]\n", "ys = [1.0, -1.0, -1.0, 1.0]  # desired targets\n", "ypred = [n(x) for x in xs]\n", "ypred" ] },
{ "cell_type": "markdown", "id": "c9676e56-fff4-4d14-8120-955a9a9622ee", "metadata": {}, "source": [ "We want `ypred` to get as close as possible to `ys`. To do this (here, and in deep learning generally), we compute a single number that measures the network's performance, called the *loss*. We use the squared-error loss $\\sum_i (\\hat{y}_i - y_i)^2$, summed over the examples (a mean-squared error without the division by the number of examples)." ] },
{ "cell_type": "code", "execution_count": 178, "id": "584c98fa-5fef-4310-88b1-3638b41821e8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Value(data=5.033137455307795)" ] }, "execution_count": 178, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loss = sum((yout - ygt)**2 for ygt, yout in zip(ys, ypred))\n", "loss" ] },
{ "cell_type": "code", "execution_count": 179, "id": "7e1b37ad-e268-4cbb-8f98-fa7e5daf44c8", "metadata": {}, "outputs": [], "source": [ "loss.backward()" ] },
{ "cell_type": "code", "execution_count": 180, "id": "4768716a-2ff9-4e50-b7c1-fff9691b2ef5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.7368026977666375" ] }, "execution_count": 180, "metadata": {}, "output_type": "execute_result" } ], "source": [ "n.layers[0].neurons[0].w[0].grad  # non-zero now that loss.backward() has run" ] },
{ "cell_type": "code", "execution_count": 185, "id": "1eb237df-ca08-4e5f-b22f-1c6d4d10341e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "-0.32973208920024066" ] }, "execution_count": 185, "metadata": {}, "output_type": "execute_result" } ], "source": [ "n.layers[0].neurons[0].w[0].data" ] },
{ "cell_type": "code", "execution_count": 187, "id": "fcfcb22d-b973-45b5-aee4-78d7d22e6d5b", "metadata": {}, "outputs": [], "source": [ "for p in n.parameters():\n", "    p.data += -0.01 * p.grad  # minus sign: the gradient points in the direction that increases the loss, and we want to decrease it. This is gradient descent.\n", "    # https://en.wikipedia.org/wiki/Gradient_descent#An_analogy_for_understanding_gradient_descent" ] },
{ "cell_type": "markdown", "id": "033750e6-d9ca-4ce7-9582-e5a2841f953a", "metadata": {}, "source": [ "Let's look at the new loss (it has decreased):" ] },
{ "cell_type": "code", "execution_count": 189, "id": "a78061fa-fca6-4a04-9a15-ebe3908dc021", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Value(data=3.7015356808242585)" ] }, "execution_count": 189, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ypred = [n(x) for x in xs]\n", "loss = sum((yout - ygt)**2 for ygt, yout in zip(ys, ypred))\n", "loss" ] },
{ "cell_type": "markdown", "id": "5922cb79-0965-4f6b-b6c4-e30253cc57cb", "metadata": {}, "source": [ "Let's reset the gradients, backprop this new loss, and update the weights again." ] },
{ "cell_type": "code", "execution_count": 190, "id": "3bd72e52-5dec-4e1b-84d5-b7435f90d19f", "metadata": {}, "outputs": [], "source": [ "# backward() accumulates into .grad with +=, so clear the gradients from the previous backward pass first\n", "for p in n.parameters():\n", "    p.grad = 0.0\n", "loss.backward()" ] },
{ "cell_type": "code", "execution_count": 191, "id": "3dd28be9-dc33-415b-bf2c-3da6c8d59160", "metadata": {}, "outputs": [], "source": [ "for p in n.parameters():\n", "    p.data += -0.01 * p.grad  # minus sign: the gradient points in the direction that increases the loss, and we want to decrease it. 
This is gradient descent.\n", " # https://en.wikipedia.org/wiki/Gradient_descent#An_analogy_for_understanding_gradient_descent" ] }, { "cell_type": "markdown", "id": "c2b0b435-a7c0-4ab2-b907-55709bfb7342", "metadata": {}, "source": [ "Now let's find the new loss" ] }, { "cell_type": "code", "execution_count": 192, "id": "8cc95aa5-ecb4-4a24-b631-1daf8e9f7b78", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Value(data=3.193172494954373)" ] }, "execution_count": 192, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ypred = [n(x) for x in xs]\n", "loss = sum((yout - ygt)**2 for ygt, yout in zip(ys, ypred))\n", "loss" ] }, { "cell_type": "markdown", "id": "3faa653d-3787-4ab3-9abc-ea8df46153a1", "metadata": {}, "source": [ "So the flow is forward pass -> backward pass -> update" ] }, { "cell_type": "code", "execution_count": 241, "id": "f530b51e-9e4b-47b2-9fb0-335c48c3c424", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step # 0 6.063052689312497\n", "Step # 1 3.4913909241492496\n", "Step # 2 1.2084297159180784\n", "Step # 3 0.4656671369410741\n", "Step # 4 0.24785984660117516\n", "Step # 5 0.17205836536367008\n", "Step # 6 0.13140359668124182\n", "Step # 7 0.10607370364195885\n", "Step # 8 0.08881662744244229\n", "Step # 9 0.07632247962333527\n", "Step # 10 0.06686857143197884\n", "Step # 11 0.05947134159212191\n", "Step # 12 0.05352900687658706\n", "Step # 13 0.04865302737738714\n", "Step # 14 0.04458150502268218\n", "Step # 15 0.04113156601356174\n", "Step # 16 0.03817169980897338\n", "Step # 17 0.03560494239910917\n", "Step # 18 0.0333582524439753\n", "Step # 19 0.031375575848589486\n", "Step # 20 0.02961318784557092\n", "Step # 21 0.028036486100278167\n", "Step # 22 0.026617733994873168\n", "Step # 23 0.025334441397667242\n", "Step # 24 0.024168182466436462\n", "Step # 25 0.023103718921689623\n", "Step # 26 0.022128340591535754\n", "Step # 27 0.021231362959825167\n", "Step # 28 
0.020403739813356792\n", "Step # 29 0.01963776138740571\n", "Step # 30 0.0189268167944545\n", "Step # 31 0.01826520532701045\n", "Step # 32 0.017647985303954276\n", "Step # 33 0.017070852033527407\n", "Step # 34 0.016530038559047344\n", "Step # 35 0.016022234379549593\n", "Step # 36 0.015544518462321941\n", "Step # 37 0.015094303701612632\n", "Step # 38 0.014669290606996854\n", "Step # 39 0.014267428481870884\n", "Step # 40 0.01388688271714819\n", "Step # 41 0.013526007106099663\n", "Step # 42 0.013183320304248127\n", "Step # 43 0.01285748572855482\n", "Step # 44 0.012547294324111991\n", "Step # 45 0.01225164973260155\n", "Step # 46 0.011969555481211223\n", "Step # 47 0.011700103878302559\n", "Step # 48 0.011442466356546074\n", "Step # 49 0.011195885048268026\n", "Step # 50 0.010959665413551208\n", "Step # 51 0.01073316977087405\n", "Step # 52 0.010515811604061303\n", "Step # 53 0.010307050539087767\n", "Step # 54 0.01010638790062819\n", "Step # 55 0.009913362771829845\n", "Step # 56 0.009727548492106068\n", "Step # 57 0.009548549537221928\n", "Step # 58 0.009375998733897512\n", "Step # 59 0.009209554767852004\n", "Step # 60 0.00904889994987455\n", "Step # 61 0.008893738209305405\n", "Step # 62 0.008743793288391685\n", "Step # 63 0.008598807114459579\n", "Step # 64 0.008458538329820275\n", "Step # 65 0.00832276096187455\n", "Step # 66 0.008191263218074911\n", "Step # 67 0.00806384639228904\n", "Step # 68 0.007940323870742451\n", "Step # 69 0.007820520227127632\n", "Step # 70 0.007704270397693438\n", "Step # 71 0.007591418928192805\n", "Step # 72 0.007481819285496838\n", "Step # 73 0.007375333227492005\n", "Step # 74 0.0072718302255883925\n", "Step # 75 0.007171186934788122\n", "Step # 76 0.0070732867068092455\n", "Step # 77 0.006978019142241778\n", "Step # 78 0.0068852796781358255\n", "Step # 79 0.006794969207796124\n", "Step # 80 0.00670699372988889\n", "Step # 81 0.006621264024259348\n", "Step # 82 0.006537695352119681\n", "Step # 83 0.006456207178497548\n", 
"Step # 84 0.006376722915042138\n", "Step # 85 0.006299169681467748\n", "Step # 86 0.0062234780840792975\n", "Step # 87 0.006149582009970493\n", "Step # 88 0.006077418435616704\n", "Step # 89 0.0060069272487023664\n", "Step # 90 0.005938051082127848\n", "Step # 91 0.005870735159236637\n", "Step # 92 0.0058049271493877865\n", "Step # 93 0.005740577033077032\n", "Step # 94 0.005677636975877802\n", "Step # 95 0.005616061210537649\n", "Step # 96 0.005555805926620974\n", "Step # 97 0.005496829167141412\n", "Step # 98 0.0054390907316724645\n", "Step # 99 0.005382552085468513\n", "Step # 100 0.005327176274165607\n", "Step # 101 0.005272927843666831\n", "Step # 102 0.005219772764848314\n", "Step # 103 0.005167678362751362\n", "Step # 104 0.005116613249951687\n", "Step # 105 0.005066547263821784\n", "Step # 106 0.005017451407423288\n", "Step # 107 0.00496929779378706\n", "Step # 108 0.004922059593356691\n", "Step # 109 0.004875710984387967\n", "Step # 110 0.00483022710611221\n", "Step # 111 0.004785584014485676\n", "Step # 112 0.004741758640360024\n", "Step # 113 0.00469872874992065\n", "Step # 114 0.004656472907251062\n", "Step # 115 0.004614970438891062\n", "Step # 116 0.004574201400266073\n", "Step # 117 0.004534146543873533\n", "Step # 118 0.004494787289119742\n", "Step # 119 0.004456105693708702\n", "Step # 120 0.004418084426490032\n", "Step # 121 0.004380706741680469\n", "Step # 122 0.004343956454378153\n", "Step # 123 0.004307817917295212\n", "Step # 124 0.0042722759986379205\n", "Step # 125 0.004237316061069736\n", "Step # 126 0.004202923941695134\n", "Step # 127 0.00416908593300756\n", "Step # 128 0.0041357887647473605\n", "Step # 129 0.004103019586619787\n", "Step # 130 0.004070765951825515\n", "Step # 131 0.0040390158013597155\n", "Step # 132 0.00400775744903826\n", "Step # 133 0.003976979567211711\n", "Step # 134 0.003946671173131104\n", "Step # 135 0.003916821615930658\n", "Step # 136 0.0038874205641950834\n", "Step # 137 0.003858457994081638\n", "Step # 138 
0.0038299241779675714\n", "Step # 139 0.0038018096735966287\n", "Step # 140 0.0037741053136988835\n", "Step # 141 0.003746802196060106\n", "Step # 142 0.003719891674018089\n", "Step # 143 0.0036933653473650166\n", "Step # 144 0.003667215053634969\n", "Step # 145 0.003641432859758701\n", "Step # 146 0.00361601105406707\n", "Step # 147 0.0035909421386263116\n", "Step # 148 0.003566218821889728\n", "Step # 149 0.0035418340116498666\n", "Step # 150 0.003517780808277683\n", "Step # 151 0.003494052498234713\n", "Step # 152 0.0034706425478457192\n", "Step # 153 0.0034475445973193913\n", "Step # 154 0.003424752455006112\n", "Step # 155 0.0034022600918815187\n", "Step # 156 0.0033800616362457997\n", "Step # 157 0.0033581513686285923\n", "Step # 158 0.003336523716890975\n", "Step # 159 0.003315173251514751\n", "Step # 160 0.0032940946810716405\n", "Step # 161 0.003273282847863682\n", "Step # 162 0.0032527327237279312\n", "Step # 163 0.0032324394059978774\n", "Step # 164 0.0032123981136149104\n", "Step # 165 0.003192604183383576\n", "Step # 166 0.0031730530663640144\n", "Step # 167 0.0031537403243963964\n", "Step # 168 0.003134661626750936\n", "Step # 169 0.003115812746899099\n", "Step # 170 0.003097189559400266\n", "Step # 171 0.003078788036899308\n", "Step # 172 0.0030606042472306025\n", "Step # 173 0.003042634350623824\n", "Step # 174 0.0030248745970078003\n", "Step # 175 0.0030073213234078264\n", "Step # 176 0.0029899709514333668\n", "Step # 177 0.0029728199848519754\n", "Step # 178 0.00295586500724621\n", "Step # 179 0.002939102679750448\n", "Step # 180 0.002922529738863839\n", "Step # 181 0.002906142994337289\n", "Step # 182 0.0028899393271306856\n", "Step # 183 0.002873915687438391\n", "Step # 184 0.0028580690927797886\n", "Step # 185 0.002842396626152884\n", "Step # 186 0.002826895434248074\n", "Step # 187 0.0028115627257202106\n", "Step # 188 0.0027963957695165496\n", "Step # 189 0.0027813918932584844\n", "Step # 190 0.0027665484816751626\n", "Step # 191 
0.002751862975086844\n", "Step # 192 0.002737332867936595\n", "Step # 193 0.002722955707367939\n", "Step # 194 0.0027087290918473375\n", "Step # 195 0.002694650669829513\n", "Step # 196 0.002680718138464186\n", "Step # 197 0.0026669292423429185\n", "Step # 198 0.002653281772284076\n", "Step # 199 0.0026397735641553854\n", "Step # 200 0.0026264024977318504\n", "Step # 201 0.0026131664955885165\n", "Step # 202 0.0026000635220264513\n", "Step # 203 0.0025870915820308447\n", "Step # 204 0.0025742487202602853\n", "Step # 205 0.002561533020065842\n", "Step # 206 0.0025489426025392856\n", "Step # 207 0.002536475625589039\n", "Step # 208 0.002524130283043393\n", "Step # 209 0.0025119048037794776\n", "Step # 210 0.0024997974508777603\n", "Step # 211 0.0024878065208007068\n", "Step # 212 0.002475930342594991\n", "Step # 213 0.0024641672771165517\n", "Step # 214 0.0024525157162775986\n", "Step # 215 0.002440974082314912\n", "Step # 216 0.0024295408270786777\n", "Step # 217 0.002418214431341317\n", "Step # 218 0.0024069934041255226\n", "Step # 219 0.002395876282050965\n", "Step # 220 0.0023848616286989387\n", "Step # 221 0.0023739480339946685\n", "Step # 222 0.0023631341136063117\n", "Step # 223 0.002352418508360469\n", "Step # 224 0.002341799883673461\n", "Step # 225 0.0023312769289979738\n", "Step # 226 0.002320848357284638\n", "Step # 227 0.0023105129044579276\n", "Step # 228 0.002300269328906075\n", "Step # 229 0.002290116410984435\n", "Step # 230 0.0022800529525321095\n", "Step # 231 0.002270077776401008\n", "Step # 232 0.00226018972599765\n", "Step # 233 0.0022503876648364727\n", "Step # 234 0.002240670476104961\n", "Step # 235 0.0022310370622400065\n", "Step # 236 0.0022214863445151186\n", "Step # 237 0.0022120172626381566\n", "Step # 238 0.0022026287743595026\n", "Step # 239 0.0021933198550899\n", "Step # 240 0.0021840894975282666\n", "Step # 241 0.0021749367112986412\n", "Step # 242 0.0021658605225963992\n", "Step # 243 0.0021568599738431664\n", "Step # 244 
0.002147934123350419\n", "Step # 245 0.0021390820449913844\n", "Step # 246 0.00213030282788102\n", "Step # 247 0.0021215955760638706\n", "Step # 248 0.0021129594082095427\n", "Step # 249 0.002104393457315599\n", "Step # 250 0.0020958968704177764\n", "Step # 251 0.0020874688083069564\n", "Step # 252 0.002079108445253255\n", "Step # 253 0.0020708149687364515\n", "Step # 254 0.002062587579183047\n", "Step # 255 0.002054425489709562\n", "Step # 256 0.0020463279258717145\n", "Step # 257 0.0020382941254198286\n", "Step # 258 0.0020303233380597743\n", "Step # 259 0.00202241482521953\n", "Step # 260 0.0020145678598212833\n", "Step # 261 0.002006781726058751\n", "Step # 262 0.001999055719179698\n", "Step # 263 0.00199138914527338\n", "Step # 264 0.001983781321063024\n", "Step # 265 0.0019762315737029182\n", "Step # 266 0.001968739240580211\n", "Step # 267 0.00196130366912113\n", "Step # 268 0.0019539242166017054\n", "Step # 269 0.0019466002499626315\n", "Step # 270 0.0019393311456283979\n", "Step # 271 0.0019321162893303435\n", "Step # 272 0.001924955075933836\n", "Step # 273 0.0019178469092690424\n", "Step # 274 0.001910791201965695\n", "Step # 275 0.001903787375291321\n", "Step # 276 0.0018968348589930967\n", "Step # 277 0.0018899330911432088\n", "Step # 278 0.0018830815179874684\n", "Step # 279 0.0018762795937972963\n", "Step # 280 0.001869526780724939\n", "Step # 281 0.0018628225486617025\n", "Step # 282 0.0018561663750993143\n", "Step # 283 0.00184955774499431\n", "Step # 284 0.0018429961506351163\n", "Step # 285 0.0018364810915121475\n", "Step # 286 0.0018300120741906545\n", "Step # 287 0.0018235886121860377\n", "Step # 288 0.0018172102258421216\n", "Step # 289 0.001810876442211629\n", "Step # 290 0.001804586794939505\n", "Step # 291 0.00179834082414832\n", "Step # 292 0.001792138076326316\n", "Step # 293 0.001785978104217711\n", "Step # 294 0.0017798604667150894\n", "Step # 295 0.0017737847287542753\n", "Step # 296 0.0017677504612111706\n", "Step # 297 
0.001761757240800724\n", "Step # 298 0.0017558046499780197\n", "Step # 299 0.0017498922768413362\n", "Step # 300 0.0017440197150370828\n", "Step # 301 0.001738186563666814\n", "Step # 302 0.0017323924271959753\n", "Step # 303 0.0017266369153644931\n", "Step # 304 0.0017209196430992434\n", "Step # 305 0.0017152402304281022\n", "Step # 306 0.0017095983023958218\n", "Step # 307 0.0017039934889815072\n", "Step # 308 0.0016984254250177687\n", "Step # 309 0.0016928937501113808\n", "Step # 310 0.0016873981085655373\n", "Step # 311 0.001681938149303642\n", "Step # 312 0.0016765135257945742\n", "Step # 313 0.0016711238959792873\n", "Step # 314 0.0016657689221990137\n", "Step # 315 0.0016604482711246983\n", "Step # 316 0.0016551616136878556\n", "Step # 317 0.0016499086250127352\n", "Step # 318 0.0016446889843497237\n", "Step # 319 0.0016395023750101637\n", "Step # 320 0.0016343484843021818\n", "Step # 321 0.001629227003467954\n", "Step # 322 0.0016241376276219487\n", "Step # 323 0.001619080055690501\n", "Step # 324 0.001614053990352399\n", "Step # 325 0.001609059137980624\n", "Step # 326 0.0016040952085851895\n", "Step # 327 0.001599161915756965\n", "Step # 328 0.0015942589766126342\n", "Step # 329 0.001589386111740561\n", "Step # 330 0.0015845430451477908\n", "Step # 331 0.0015797295042078328\n", "Step # 332 0.0015749452196096151\n", "Step # 333 0.0015701899253071257\n", "Step # 334 0.0015654633584701642\n", "Step # 335 0.0015607652594359178\n", "Step # 336 0.0015560953716613338\n", "Step # 337 0.0015514534416765074\n", "Step # 338 0.0015468392190387274\n", "Step # 339 0.0015422524562874749\n", "Step # 340 0.0015376929089001685\n", "Step # 341 0.0015331603352486678\n", "Step # 342 0.0015286544965566408\n", "Step # 343 0.0015241751568575948\n", "Step # 344 0.0015197220829536045\n", "Step # 345 0.0015152950443749508\n", "Step # 346 0.0015108938133403076\n", "Step # 347 0.001506518164717592\n", "Step # 348 0.0015021678759856764\n", "Step # 349 0.001497842727196576\n", "Step # 
350 0.0014935425009383841\n", "Step # 351 0.0014892669822988897\n", "Step # 352 0.0014850159588296608\n", "Step # 353 0.0014807892205109291\n", "Step # 354 0.0014765865597169416\n", "Step # 355 0.0014724077711819657\n", "Step # 356 0.0014682526519668873\n", "Step # 357 0.0014641210014262843\n", "Step # 358 0.0014600126211761822\n", "Step # 359 0.0014559273150622164\n", "Step # 360 0.0014518648891284985\n", "Step # 361 0.0014478251515867807\n", "Step # 362 0.0014438079127864097\n", "Step # 363 0.0014398129851845107\n", "Step # 364 0.0014358401833168212\n", "Step # 365 0.0014318893237690515\n", "Step # 366 0.001427960225148534\n", "Step # 367 0.0014240527080565143\n", "Step # 368 0.0014201665950608262\n", "Step # 369 0.001416301710669022\n", "Step # 370 0.001412457881301926\n", "Step # 371 0.0014086349352676292\n", "Step # 372 0.001404832702735981\n", "Step # 373 0.0014010510157132982\n", "Step # 374 0.001397289708017771\n", "Step # 375 0.0013935486152549387\n", "Step # 376 0.0013898275747938728\n", "Step # 377 0.0013861264257434972\n", "Step # 378 0.0013824450089294398\n", "Step # 379 0.0013787831668711824\n", "Step # 380 0.0013751407437595923\n", "Step # 381 0.0013715175854347953\n", "Step # 382 0.001367913539364437\n", "Step # 383 0.0013643284546222562\n", "Step # 384 0.0013607621818670114\n", "Step # 385 0.0013572145733217174\n", "Step # 386 0.0013536854827532383\n", "Step # 387 0.0013501747654521238\n", "Step # 388 0.0013466822782129332\n", "Step # 389 0.0013432078793146056\n", "Step # 390 0.0013397514285014057\n", "Step # 391 0.0013363127869639735\n", "Step # 392 0.001332891817320703\n", "Step # 393 0.0013294883835995312\n", "Step # 394 0.0013261023512197853\n", "Step # 395 0.0013227335869745427\n", "Step # 396 0.0013193819590130663\n", "Step # 397 0.001316047336823662\n", "Step # 398 0.0013127295912166706\n", "Step # 399 0.001309428594307766\n", "Step # 400 0.001306144219501601\n", "Step # 401 0.0013028763414754797\n", "Step # 402 0.0012996248361635194\n", 
"Step # 403 0.0012963895807408523\n", "Step # 404 0.0012931704536082113\n", "Step # 405 0.0012899673343766713\n", "Step # 406 0.0012867801038525559\n", "Step # 407 0.0012836086440227708\n", "Step # 408 0.001280452838040121\n", "Step # 409 0.0012773125702089847\n", "Step # 410 0.0012741877259711454\n", "Step # 411 0.001271078191891917\n", "Step # 412 0.001267983855646318\n", "Step # 413 0.0012649046060055527\n", "Step # 414 0.0012618403328237514\n", "Step # 415 0.001258790927024721\n", "Step # 416 0.0012557562805890292\n", "Step # 417 0.001252736286541282\n", "Step # 418 0.0012497308389374862\n", "Step # 419 0.0012467398328526748\n", "Step # 420 0.0012437631643686506\n", "Step # 421 0.0012408007305620295\n", "Step # 422 0.0012378524294922606\n", "Step # 423 0.0012349181601899752\n", "Step # 424 0.0012319978226454458\n", "Step # 425 0.0012290913177972365\n", "Step # 426 0.0012261985475209183\n", "Step # 427 0.0012233194146180962\n", "Step # 428 0.0012204538228054985\n", "Step # 429 0.0012176016767041437\n", "Step # 430 0.0012147628818288965\n", "Step # 431 0.001211937344577902\n", "Step # 432 0.0012091249722223463\n", "Step # 433 0.001206325672896266\n", "Step # 434 0.0012035393555866172\n", "Step # 435 0.0012007659301232624\n", "Step # 436 0.0011980053071693644\n", "Step # 437 0.0011952573982117374\n", "Step # 438 0.0011925221155513496\n", "Step # 439 0.0011897993722939958\n", "Step # 440 0.001187089082341147\n", "Step # 441 0.00118439116038077\n", "Step # 442 0.0011817055218784353\n", "Step # 443 0.0011790320830684352\n", "Step # 444 0.0011763707609451254\n", "Step # 445 0.0011737214732542571\n", "Step # 446 0.0011710841384845044\n", "Step # 447 0.0011684586758591653\n", "Step # 448 0.001165845005327818\n", "Step # 449 0.0011632430475582091\n", "Step # 450 0.0011606527239282583\n", "Step # 451 0.0011580739565180418\n", "Step # 452 0.0011555066681020623\n", "Step # 453 0.0011529507821414794\n", "Step # 454 0.0011504062227765052\n", "Step # 455 
0.0011478729148188818\n", "Step # 456 0.0011453507837445055\n", "Step # 457 0.001142839755686052\n", "Step # 458 0.001140339757425801\n", "Step # 459 0.0011378507163885086\n", "Step # 460 0.001135372560634334\n", "Step # 461 0.0011329052188519735\n", "Step # 462 0.00113044862035173\n", "Step # 463 0.0011280026950588058\n", "Step # 464 0.0011255673735066606\n", "Step # 465 0.0011231425868303204\n", "Step # 466 0.0011207282667599853\n", "Step # 467 0.0011183243456145617\n", "Step # 468 0.0011159307562953504\n", "Step # 469 0.0011135474322797852\n", "Step # 470 0.0011111743076152691\n", "Step # 471 0.0011088113169130636\n", "Step # 472 0.001106458395342307\n", "Step # 473 0.0011041154786240565\n", "Step # 474 0.001101782503025445\n", "Step # 475 0.00109945940535386\n", "Step # 476 0.0010971461229512618\n", "Step # 477 0.0010948425936885542\n", "Step # 478 0.0010925487559599748\n", "Step # 479 0.0010902645486776347\n", "Step # 480 0.001087989911266041\n", "Step # 481 0.0010857247836567896\n", "Step # 482 0.0010834691062832588\n", "Step # 483 0.0010812228200753333\n", "Step # 484 0.0010789858664543004\n", "Step # 485 0.001076758187327722\n", "Step # 486 0.0010745397250843912\n", "Step # 487 0.0010723304225894125\n", "Step # 488 0.0010701302231791916\n", "Step # 489 0.0010679390706566816\n", "Step # 490 0.0010657569092865494\n", "Step # 491 0.0010635836837904397\n", "Step # 492 0.0010614193393423154\n", "Step # 493 0.0010592638215638258\n", "Step # 494 0.0010571170765197595\n", "Step # 495 0.0010549790507135449\n", "Step # 496 0.0010528496910827633\n", "Step # 497 0.0010507289449948108\n", "Step # 498 0.0010486167602425077\n", "Step # 499 0.0010465130850398144\n" ] } ], "source": [ "xs = [\n", " [2.0, 3.0, -1.0],\n", " [3.0, -1.0, 0.5],\n", " [0.5, 1.0, 1.0],\n", " [1.0, 1.0, -1.0],\n", "]\n", "ys = [1.0, -1.0, -1.0, 1.0] # desired targets\n", "ls = []\n", "\n", "for k in range(500):\n", " # forward pass\n", " ypred = [n(x) for x in xs]\n", " loss = sum((yout - ygt)**2 
for ygt, yout in zip(ys, ypred))\n", "\n", " # backward pass\n", " for p in n.parameters():\n", "     p.grad = 0.0 # reset before every backward pass; otherwise backward() keeps accumulating (+=) into .grad\n", " loss.backward()\n", "\n", " # update: one gradient descent step with learning rate 0.05\n", " for p in n.parameters():\n", "     p.data += -0.05 * p.grad\n", "\n", " print('Step #', k, ' ', loss.data)\n", " ls.append(loss.data) # record the loss for plotting" ] }, { "cell_type": "code", "execution_count": 245, "id": "8615e709-2365-4988-bad8-473590ff22b9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Value(data=0.9868926891382584),\n", " Value(data=-0.9822284602914287),\n", " Value(data=-0.9834981983414189),\n", " Value(data=0.9831334891456814)]" ] }, "execution_count": 245, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ypred = [n(x) for x in xs]\n", "ypred" ] }, { "cell_type": "code", "execution_count": 246, "id": "046c7e20-5384-476f-abb7-d6d313561ebb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 246, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(np.arange(0, 500, 1), ls)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 5 }