An Artificial Neural Network (ANN) is a mathematical model used for machine learning. An ANN consists of input neurons, weights, hidden neurons, biases and output neurons. The input neurons form the input layer, the hidden neurons form the hidden layers and the output neurons form the output layer. The number of neurons in the input layer may differ from the number of hidden neurons in each hidden layer and from the number of neurons in the output layer; likewise, the number of output neurons may differ from the number of hidden neurons in each hidden layer. There can be multiple hidden layers, but only one input layer and one output layer. A typical ANN looks as in Figure-1, where each coloured circle represents one neuron.
(Figure-1)
The weights and biases of an ANN typically start with random values (if the network is not initialized from saved data). The input neurons are given the input values. For each non-input neuron, the output value is calculated as shown in Figure-2.
(Figure-2)
Here,
xi : the input values coming from the neurons of the previous layer
wi : the weight of each connection from one neuron to another
Taking all values together, the weighted sum is:
z = (w1*x1 + w2*x2 + ... + wn*xn) + b
where b is the bias.
This value is then passed through an activation function to get the neuron's output value. Many activation functions are used in ANNs. For example, the binary step function used in Figure-2 gives an output of either 0 or 1 for any given value.
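The weighted-sum-then-activation step described above can be sketched as follows. This is only an illustrative sketch, not the code from my implementation; the class and method names are my own, and the binary step activation follows the Figure-2 example.

```java
// Sketch of one neuron's forward step: weighted sum plus bias,
// followed by a binary step activation. Names are illustrative only.
public class Neuron {

    // z = (w1*x1 + w2*x2 + ... + wn*xn) + b
    static double weightedSum(double[] x, double[] w, double b) {
        double z = b;
        for (int i = 0; i < x.length; i++) {
            z += w[i] * x[i];
        }
        return z;
    }

    // Binary step activation: outputs 1 if z >= 0, otherwise 0.
    static double binaryStep(double z) {
        return z >= 0 ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        double[] x = {0.5, -1.0, 2.0};   // inputs from the previous layer
        double[] w = {0.4, 0.3, -0.2};   // one weight per connection
        double b = 0.1;                  // bias
        double z = weightedSum(x, w, b); // 0.2 - 0.3 - 0.4 + 0.1 = -0.4
        System.out.println("z = " + z + ", output = " + binaryStep(z));
    }
}
```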
The above steps are repeated from the first hidden layer through to the output neurons, which finally gives the output value at each output neuron. This is called forward propagation, because the calculations proceed from left to right towards the output. After we get the output values (there may be one or multiple output neurons), we compare them with the desired output values, and the error at each output neuron is calculated from the difference between the actual output and the expected (target) output. Many error functions are used in ANNs. For example, the error function below takes the squared difference between the actual and target output at each output neuron, sums these squared values over all output neurons, and finally halves the sum:
E = (1/2) * sum over all output neurons of (target - output)^2
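The half-sum-of-squared-errors function described above can be sketched like this (again an illustrative sketch with my own names, not the original code):

```java
// Sketch of the error function described above:
// E = 0.5 * sum_k (target_k - output_k)^2 over all output neurons.
public class SquaredError {

    static double error(double[] target, double[] output) {
        double sum = 0.0;
        for (int k = 0; k < target.length; k++) {
            double d = target[k] - output[k];
            sum += d * d; // squared difference at output neuron k
        }
        return 0.5 * sum; // finally halve the sum
    }

    public static void main(String[] args) {
        double[] target = {1.0, 0.0}; // desired outputs
        double[] output = {0.8, 0.3}; // actual outputs
        System.out.println("E = " + error(target, output)); // 0.5*(0.04 + 0.09) ≈ 0.065
    }
}
```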
If this error is above a threshold value, it is propagated back from the output layer to the first hidden layer, and all the weights and biases are readjusted so as to move the output towards the target for the given input. This process is called back propagation, and it is what adjusts the weights. One pass of forward propagation and the corresponding back propagation over all inputs is called one epoch. You may have to run a few epochs to bring the error within the threshold. Performing these steps for many inputs with their target outputs is called training the ANN.
The numbers of input and output neurons depend on the problem statement. For example, for handwritten digit recognition we need 10 output neurons (digits 0 to 9), and if the input image (of a digit) has a fixed dimension, the number of input neurons is (height * width) of the input image. The number of hidden layers and the number of neurons in each hidden layer are decided based on many other factors; there is no formula for these numbers, and they are often chosen based on the programmer's experience.
The back propagation algorithm is given below. There are many versions of the back propagation algorithm; this is one of them.
where,
j : any hidden neuron (left side)
k : any hidden or output neuron (right side)
O : output value
Eta : learning rate
W : weight
Wjk : weight of the connection from neuron j to neuron k
Theta : bias
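Since the algorithm itself is shown only as an image here, the following is a sketch of one common textbook form of the gradient-descent update for an output-layer weight Wjk, using the symbols defined above and assuming a sigmoid activation with squared-error loss. It is not necessarily the exact version used in this post.

```java
// Sketch of a single back-propagation weight update for an output neuron k,
// assuming sigmoid activation and squared-error loss (one common textbook
// form; names and assumptions are mine, not the author's exact algorithm).
public class BackpropStep {

    // Returns the new value of Wjk after one gradient-descent step:
    //   deltaK = (Ok - target_k) * Ok * (1 - Ok)   // error signal at k
    //   Wjk    = Wjk - Eta * deltaK * Oj           // descend the gradient
    static double updatedWeight(double wjk, double eta,
                                double oj, double ok, double tk) {
        double deltaK = (ok - tk) * ok * (1.0 - ok);
        return wjk - eta * deltaK * oj;
    }

    public static void main(String[] args) {
        double w = 0.5;     // current weight Wjk
        double eta = 0.005; // learning rate (the value used later in this post)
        double oj = 0.9;    // output of hidden neuron j
        double ok = 0.7;    // output of output neuron k
        double tk = 1.0;    // target output at neuron k
        // Output was too low, so the weight should increase slightly.
        System.out.println("updated Wjk = " + updatedWeight(w, eta, oj, ok, tk));
    }
}
```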
I will later update this page with the back propagation algorithm which I used.
If you have any questions or doubts about this back propagation algorithm, ask me in the comments.
My ANN Implementation
After an ANN is trained with a lot of sample data, it is tested to measure its performance and accuracy. Usually 70% of the available input data is used to train the network, 10% is used to validate the training output and the remaining 20% is used for testing.
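The 70/10/20 split described above can be sketched as below. This is an illustrative sketch (shuffling, then cutting by proportion), not the original code; the class and method names are my own.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Sketch of splitting a dataset into ~70% training, ~10% validation
// and ~20% testing sets, as described above. Names are illustrative only.
public class DataSplit {

    // Returns [training, validation, testing] as views over the input list.
    static List<List<Integer>> split(List<Integer> data) {
        int n = data.size();
        int trainEnd = (int) (n * 0.7);            // first 70% for training
        int valEnd = trainEnd + (int) (n * 0.1);   // next 10% for validation
        List<List<Integer>> parts = new ArrayList<>();
        parts.add(data.subList(0, trainEnd));
        parts.add(data.subList(trainEnd, valEnd));
        parts.add(data.subList(valEnd, n));        // remaining ~20% for testing
        return parts;
    }

    public static void main(String[] args) {
        List<Integer> data = new ArrayList<>();
        for (int i = 0; i < 100; i++) data.add(i);
        Collections.shuffle(data); // shuffle before splitting
        List<List<Integer>> parts = split(data);
        System.out.println(parts.get(0).size() + " train, "
                + parts.get(1).size() + " validation, "
                + parts.get(2).size() + " test");
    }
}
```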
(Figure-3)
Figure-3 shows how the error decreases over the epochs; after the 8th epoch the training error and the validation error have almost converged.
I have implemented an ANN following the above steps in Java to recognize handwritten digits. The training and testing data sets were downloaded from MNIST. Since the input images are of 28*28 dimension, the number of input neurons is 784, and the number of output neurons is 10. There are 2 hidden layers with 300 neurons each. I used the sigmoid function as the activation function and chose a learning rate of 0.005. I trained the network with ~42k images, validated the training after every epoch with ~5k images and then tested it with another ~13k images. My ANN recognizes images of handwritten digits with an accuracy of approximately 81%. I still have to work on improving the accuracy by tweaking the parameters.
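For reference, the sigmoid activation mentioned above maps any real value into (0, 1), which is why it pairs naturally with 0/1 target outputs. A minimal sketch:

```java
// Sketch of the sigmoid activation function: sigmoid(z) = 1 / (1 + e^(-z)).
// Maps any real input into the open interval (0, 1).
public class Sigmoid {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    public static void main(String[] args) {
        System.out.println(sigmoid(0.0));  // 0.5
        System.out.println(sigmoid(4.0));  // close to 1
        System.out.println(sigmoid(-4.0)); // close to 0
    }
}
```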
I have developed generic code that can be used to train and test any type of data, not limited to MNIST handwritten digits.
Below is the output of my program during training:
Training Neural Network...
Time taken to train 41930 input data in epoch #1 = 11.95 minutes
Average training time for one input = 17 milliseconds
Completed validation of 4966 data after epoch #1. Time taken for validation = 3.60 seconds
Average validation time for one input = 0.72 milliseconds
-----------------------------------------------------------------
Time taken to train 41930 input data in epoch #2 = 17.51 minutes
Average training time for one input = 25 milliseconds
Completed validation of 4966 data after epoch #2. Time taken for validation = 3.58 seconds
Average validation time for one input = 0.72 milliseconds
-----------------------------------------------------------------
Time taken to train 41930 input data in epoch #3 = 11.95 minutes
Average training time for one input = 17 milliseconds
Completed validation of 4966 data after epoch #3. Time taken for validation = 3.61 seconds
Average validation time for one input = 0.73 milliseconds
-----------------------------------------------------------------
Time taken to train 41930 input data in epoch #4 = 11.95 minutes
Average training time for one input = 17 milliseconds
Completed validation of 4966 data after epoch #4. Time taken for validation = 4.08 seconds
Average validation time for one input = 0.82 milliseconds
-----------------------------------------------------------------
Time taken to train 41930 input data in epoch #5 = 12.02 minutes
Average training time for one input = 17 milliseconds
Completed validation of 4966 data after epoch #5. Time taken for validation = 3.57 seconds
Average validation time for one input = 0.72 milliseconds
-----------------------------------------------------------------
Time taken to train 41930 input data in epoch #6 = 11.62 minutes
Average training time for one input = 16 milliseconds
Completed validation of 4966 data after epoch #6. Time taken for validation = 3.51 seconds
Average validation time for one input = 0.71 milliseconds
-----------------------------------------------------------------
Time taken to train 41930 input data in epoch #7 = 12.00 minutes
Average training time for one input = 17 milliseconds
Completed validation of 4966 data after epoch #7. Time taken for validation = 3.67 seconds
Average validation time for one input = 0.74 milliseconds
-----------------------------------------------------------------
Time taken to train 41930 input data in epoch #8 = 11.91 minutes
Average training time for one input = 17 milliseconds
Completed validation of 4966 data after epoch #8. Time taken for validation = 3.52 seconds
Average validation time for one input = 0.71 milliseconds
-----------------------------------------------------------------
Training Completed
Saved the trained model successfully.
Time Taken for Training: 1.69 hours
MNIST handwritten character dataset. Testing the input images...
Testing Completed.
Time Taken for Testing: 21.23 seconds
Total Tested Data = 13104
Incorrectly Recognized Data = 2445
Correctly Recognized Data = 10659
Accuracy = 81.34 %
Confusion matrix: input digits (leftmost vertical) vs recognized digits (bottommost horizontal)
9| 18 11 20 27 70 21 15 114 18 996
8| 15 49 39 68 12 48 46 4 920 45
7| 20 25 20 8 29 0 2 1163 23 56
6| 13 19 19 6 8 13 1190 5 22 1
5| 29 50 19 138 50 662 57 26 66 21
4| 3 16 37 7 929 3 65 8 35 113
3| 8 15 59 1076 6 42 24 38 58 24
2| 22 29 1088 36 40 3 68 31 56 7
1| 0 1454 26 4 3 3 14 3 31 6
0| 1181 0 28 16 3 11 19 5 32 3
---------------------------------------------------------------------
0 1 2 3 4 5 6 7 8 9
Thanks,
S. R. Giri