How can I visualise this gradient descent algorithm?

Issue

How can I visually display this gradient descent algorithm (e.g. as a graph)?

import matplotlib.pyplot as plt
import numpy

def sigmoid(sop):
    return 1.0 / (1 + numpy.exp(-1 * sop))

def error(predicted, target):
    return numpy.power(predicted - target, 2)

def error_predicted_deriv(predicted, target):
    return 2 * (predicted - target)

def activation_sop_deriv(sop):
    return sigmoid(sop) * (1.0 - sigmoid(sop))

def sop_w_deriv(x):
    return x

def update_w(w, grad, learning_rate):
    return w - learning_rate * grad

x = 0.1
target = 0.3
learning_rate = 0.01
w = numpy.random.rand()
print("Initial W : ", w)

iterations = 10000

for k in range(iterations):
    # Forward Pass
    y = w * x
    predicted = sigmoid(y)
    err = error(predicted, target)

    # Backward Pass
    g1 = error_predicted_deriv(predicted, target)

    g2 = activation_sop_deriv(y)  # sigmoid derivative is evaluated at the sop y, not at predicted

    g3 = sop_w_deriv(x)

    grad = g3 * g2 * g1
    # print(predicted)

    w = update_w(w, grad, learning_rate)
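
The backward pass above multiplies three chain-rule factors. Writing the sum of products as y = w x, the target as t and the learning rate as η (these symbol names are chosen here for readability), the squared error and its gradient with respect to w are:

```latex
E = (\sigma(y) - t)^2, \qquad y = w x, \qquad \sigma(y) = \frac{1}{1 + e^{-y}}

\frac{\partial E}{\partial w}
  = \underbrace{2\,(\sigma(y) - t)}_{g_1}
    \cdot \underbrace{\sigma(y)\,\bigl(1 - \sigma(y)\bigr)}_{g_2}
    \cdot \underbrace{x}_{g_3}

w \leftarrow w - \eta \, \frac{\partial E}{\partial w}
```

Since g_2 is the sigmoid's derivative evaluated at y, activation_sop_deriv needs y as its argument; calling it with predicted evaluates σ'(σ(y)) instead, which is a different (wrong) value.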

I tried making a very simple plot with matplotlib, but I couldn’t get the line to actually display (the figure was initialised properly, but no line appeared).

Here’s what I did:

plt.plot(iterations, predicted)
plt.ylabel("Prediction")
plt.xlabel("Iteration Number")
plt.show()
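
To see why the plot stays empty, it helps to compare what plot() builds in each case. This is a minimal sketch: 10000 and 0.3 are placeholder values standing in for the final iterations and predicted, and the small arrays are made up purely for illustration.

```python
import matplotlib.pyplot as plt
import numpy

# Scalars, as in the question: the Line2D returned by plot() holds one point,
# and the default style ("-" with no marker) has nothing to connect it to.
scalar_line, = plt.plot(10000, 0.3)

# Arrays: one value per iteration, so consecutive points are joined by a line.
steps = numpy.arange(5)
values = numpy.linspace(0.5, 0.3, 5)
array_line, = plt.plot(steps, values)

# A marker would at least make a single point visible.
marker_line, = plt.plot(10000, 0.3, "o")

print(len(scalar_line.get_xdata()), len(array_line.get_xdata()))
plt.close("all")
```

Passing a marker format such as "o" only shows the last value as a dot; to see the descent itself you still need one value per iteration.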

I searched, but none of the resources I found applied to this particular form of gradient descent.

Solution

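Both iterations and predicted are scalar values in your code: iterations is just the number 10000, and predicted only holds the prediction from the last iteration. plt.plot therefore receives a single point, and with the default line style (no markers) a single point draws nothing, which is why the line chart stays empty. To plot them, store the value from every iteration in two arrays: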

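# Reuses sigmoid, error, error_predicted_deriv, activation_sop_deriv,
# sop_w_deriv, update_w, x, target, learning_rate and w from the question.
K = 10000

iterations = numpy.arange(K)
predicted = numpy.zeros(K)

for k in range(K):
    # Forward Pass
    y = w * x
    predicted[k] = sigmoid(y)
    err = error(predicted[k], target)

    # Backward Pass
    g1 = error_predicted_deriv(predicted[k], target)
    g2 = activation_sop_deriv(y)  # sigmoid derivative is evaluated at y
    g3 = sop_w_deriv(x)

    grad = g3 * g2 * g1

    # print(predicted[k])

    w = update_w(w, grad, learning_rate)

# Both arrays now hold one value per iteration, so plot() can draw a line.
plt.plot(iterations, predicted)
plt.ylabel("Prediction")
plt.xlabel("Iteration Number")
plt.show()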
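
For reference, here is a self-contained sketch that records both the prediction and the squared error at every iteration and plots them against the iteration number. It inlines the question's helpers, takes the sigmoid derivative at y (using predicted[k] * (1 - predicted[k]), which equals σ'(y)), and the function name run_gradient_descent, the errors array and the two-panel layout are illustrative choices, not part of the original answer.

```python
import matplotlib.pyplot as plt
import numpy


def sigmoid(sop):
    return 1.0 / (1.0 + numpy.exp(-sop))


def run_gradient_descent(w, x, target, learning_rate, iterations):
    """Single-weight gradient descent that records its history.

    Returns the final weight plus arrays holding the prediction and the
    squared error from every iteration, ready to plot against k.
    """
    predicted = numpy.zeros(iterations)
    errors = numpy.zeros(iterations)
    for k in range(iterations):
        # Forward pass
        y = w * x
        predicted[k] = sigmoid(y)
        errors[k] = (predicted[k] - target) ** 2

        # Backward pass (chain rule): dE/dpred * dpred/dy * dy/dw
        g1 = 2 * (predicted[k] - target)
        g2 = predicted[k] * (1.0 - predicted[k])  # sigmoid'(y), as predicted[k] == sigmoid(y)
        g3 = x
        w = w - learning_rate * g1 * g2 * g3
    return w, predicted, errors


x = 0.1
target = 0.3
learning_rate = 0.01
iterations = 10000

w_final, predicted, errors = run_gradient_descent(
    numpy.random.rand(), x, target, learning_rate, iterations
)

# One panel for the prediction (with the target as a reference line),
# one for the squared error, sharing the iteration axis.
steps = numpy.arange(iterations)
fig, (ax_pred, ax_err) = plt.subplots(2, 1, sharex=True)
ax_pred.plot(steps, predicted)
ax_pred.axhline(target, linestyle="--", color="gray", label="target")
ax_pred.set_ylabel("Prediction")
ax_pred.legend()
ax_err.plot(steps, errors)
ax_err.set_ylabel("Squared error")
ax_err.set_xlabel("Iteration Number")
plt.show()
```

With x = 0.1 and learning_rate = 0.01, each update moves w by only about 1e-4, so after 10000 iterations the prediction is still well above the 0.3 target and both curves change slowly; a larger learning rate or more iterations makes the convergence easier to see.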

Answered By – Flavia Giammarino

This answer, collected from Stack Overflow, is licensed under CC BY-SA 2.5, CC BY-SA 3.0 and CC BY-SA 4.0.
