What is the chain rule?
Consider the function y = m * x + b.
I am going to work with a function that we can use in our machine learning problems.
Consider a loss function h = [y - y_hat]^2, which is the squared error; its average over many samples is called the mean squared error.
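y - the output of the line; it depends on x (and on m and b)
y_hat - the value y is compared against; it is treated as a constant, so it does not depend on m or b
h - the loss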
So why two functions? OK, substitute y into the loss:
h = [m * x + b - y_hat]^2 looks complicated, right? This is where we will use the
chain rule. How, then?
The partial derivatives of y with respect to m and b are
[1] dy/dm = x , dy/db = 1
Substitute h = u^2, where u = y - y_hat.
[2]
dh/du = 2u = 2 * (y - y_hat)
du/dy = 1
So dh/dy = dh/du * du/dy --------[chain rule]
= 2 * (y - y_hat) * 1
dh/dy = 2 * (y - y_hat)
So what are the partial derivatives dh/dm and dh/db, then?
Chain rule again, using [1] and [2]:
dh/dm = dh/dy * dy/dm
[3] dh/dm = 2 * (y - y_hat) * x
dh/db = dh/dy * dy/db
[4] dh/db = 2 * (y - y_hat) * 1
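One way to convince yourself that [3] and [4] are right is to compare them with a numerical estimate of the slope. The sketch below is only an illustration: the helper names (loss, analytic_grads, numeric_grads) and the sample numbers are made up for this check and are not from any library.

```python
# A minimal sketch, assuming a single data point.
# All function names and sample values here are hypothetical.

def loss(m, b, x, y_hat):
    """h = (y - y_hat)^2 with y = m * x + b."""
    y = m * x + b
    return (y - y_hat) ** 2


def analytic_grads(m, b, x, y_hat):
    """Gradients from the chain rule: formulas [3] and [4]."""
    y = m * x + b
    dh_dy = 2 * (y - y_hat)   # step [2]
    dh_dm = dh_dy * x         # [3]: dh/dy * dy/dm
    dh_db = dh_dy * 1         # [4]: dh/dy * dy/db
    return dh_dm, dh_db


def numeric_grads(m, b, x, y_hat, eps=1e-6):
    """Central finite differences: nudge m (or b) a little and watch h change."""
    dh_dm = (loss(m + eps, b, x, y_hat) - loss(m - eps, b, x, y_hat)) / (2 * eps)
    dh_db = (loss(m, b + eps, x, y_hat) - loss(m, b - eps, x, y_hat)) / (2 * eps)
    return dh_dm, dh_db


m, b, x, y_hat = 0.5, 1.0, 3.0, 4.0
print(analytic_grads(m, b, x, y_hat))  # (-9.0, -3.0)
print(numeric_grads(m, b, x, y_hat))   # should agree to several decimal places
```

Here y = 0.5 * 3 + 1 = 2.5, so y - y_hat = -1.5, which gives dh/dm = 2 * (-1.5) * 3 = -9 and dh/db = 2 * (-1.5) * 1 = -3.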
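The values [3] and [4] are used to correct m and b:
new value of m = m - dh/dm
new value of b = b - dh/db
In practice each derivative is first multiplied by a small number (the learning rate) so that the step does not overshoot.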
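The correction step can be sketched as a small loop. This is a rough illustration and not a finished implementation: the function name train, the learning rate lr, the number of steps, the averaging over several points, and the sample data are all assumptions added here.

```python
# A minimal sketch of repeated corrections to m and b.
# Assumptions (not from the post): the name `train`, the learning rate `lr`,
# the step count, averaging the gradients over all points, and the toy data.

def train(xs, targets, lr=0.1, steps=1000):
    m, b = 0.0, 0.0
    n = len(xs)
    for _ in range(steps):
        dh_dm = 0.0
        dh_db = 0.0
        for x, y_hat in zip(xs, targets):
            y = m * x + b
            dh_dy = 2 * (y - y_hat)   # step [2]
            dh_dm += dh_dy * x / n    # [3], averaged over the points
            dh_db += dh_dy * 1 / n    # [4], averaged over the points
        # the correction, scaled by the learning rate
        m = m - lr * dh_dm
        b = b - lr * dh_db
    return m, b


# toy data that lies exactly on the line 2 * x + 1
xs = [0.0, 1.0, 2.0, 3.0]
targets = [1.0, 3.0, 5.0, 7.0]

m, b = train(xs, targets)
print(m, b)  # should end up close to 2 and 1
```

With this data a step size of 1 (no learning rate) makes m and b blow up instead of settling, which is why the sketch scales the correction down.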
OK, I think this is enough for now. I will do an implementation in our machine learning sample for better understanding.