This example trains a 3-layer network using Fisher's Iris data with four continuous input attributes and three output classifications. This is perhaps the best known database to be found in the pattern recognition literature. Fisher's paper is a classic in the field. The data set contains 3 classes of 50 instances each, where each class refers to a type of iris plant.
The network consists of four input nodes and three layers: four perceptrons in the first hidden layer, three perceptrons in the second hidden layer, and three perceptrons in the output layer.
The four continuous input attributes represent the sepal length, sepal width, petal length, and petal width of the sampled iris plants.
The output attribute represents the class of the iris plant (1 = Iris Setosa, 2 = Iris Versicolour, 3 = Iris Virginica) and is encoded using binary encoding.
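Binary (one-of-n) encoding maps each integer class label to a target vector containing a single 1 in the position of that class. MultiClassification performs this encoding internally, so the following sketch is purely illustrative; the class and method names below are hypothetical, not part of the library.

import java.util.Arrays;

// Illustration only: one-of-n ("binary") encoding of the class labels 1..3.
// MultiClassification handles the equivalent encoding internally.
public class OneOfNEncodingSketch {

    static double[] encode(int classLabel, int nClasses) {
        double[] target = new double[nClasses];
        target[classLabel - 1] = 1.0;   // labels in this example are 1-based
        return target;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(encode(2, 3)));  // prints [0.0, 1.0, 0.0]
    }
}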
There are a total of 47 weights in this network, including the bias weights (4×4 + 4×3 + 3×3 = 37 link weights plus 4 + 3 + 3 = 10 bias weights). All hidden layers use the logistic activation function. Since the target output is a multi-class category, the softmax activation function is used in the output layer and the MultiClassification error function class is used by the trainer. The MultiClassification error class combines the cross-entropy error calculation with the softmax function.
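As a rough sketch of what that combined error function computes (an illustration of the standard formulas, not the JMSL implementation): softmax converts the three raw output-layer values z_i into class probabilities p_i = exp(z_i) / sum_j exp(z_j), and the cross-entropy error contributed by a pattern whose true class is k is -ln(p_k). The example values in main() below are hypothetical.

// Sketch of the softmax + cross-entropy calculation combined by the
// MultiClassification error function (illustration only).
public class SoftmaxCrossEntropySketch {

    // Softmax: convert raw output-layer values into class probabilities.
    static double[] softmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) {
            max = Math.max(max, v);      // shift by the maximum for numerical stability
        }
        double[] p = new double[z.length];
        double sum = 0.0;
        for (int i = 0; i < z.length; i++) {
            p[i] = Math.exp(z[i] - max);
            sum += p[i];
        }
        for (int i = 0; i < z.length; i++) {
            p[i] /= sum;
        }
        return p;
    }

    // Cross-entropy error for one pattern whose true class is targetClass (1-based).
    static double crossEntropy(double[] probabilities, int targetClass) {
        return -Math.log(probabilities[targetClass - 1]);
    }

    public static void main(String[] args) {
        double[] z = {2.0, 0.5, -1.0};   // hypothetical output-layer potentials
        double[] p = softmax(z);
        System.out.println(java.util.Arrays.toString(p));
        System.out.println("Cross-entropy for class 1: " + crossEntropy(p, 1));
    }
}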
import com.imsl.datamining.neural.*;
import com.imsl.math.PrintMatrix;
import com.imsl.math.PrintMatrixFormat;
import java.io.*;
import java.util.logging.*;

//*****************************************************************************
// Three Layer Feed-Forward Network with 4 inputs, all
// continuous, and 3 classification categories.
//
// new classification training_ex5.c
//
// This is perhaps the best known database to be found in the pattern
// recognition literature.  Fisher's paper is a classic in the field.
// The data set contains 3 classes of 50 instances each,
// where each class refers to a type of iris plant.  One class is
// linearly separable from the other 2; the latter are NOT linearly
// separable from each other.
//
// Predicted attribute: class of iris plant.
//     1=Iris Setosa, 2=Iris Versicolour, and 3=Iris Virginica
//
// Input Attributes (4 Continuous Attributes)
//     X1: Sepal length, X2: Sepal width, X3: Petal length, and X4: Petal width
//*****************************************************************************
public class MultiClassificationEx1 implements Serializable {

    private static int nObs = 150;        // number of training patterns
    private static int nInputs = 4;       // four continuous input attributes
    private static int nOutputs = 3;      // three classification categories
    private static boolean trace = true;  // turns on/off training log

    // irisData[]:  The raw data matrix.  This is a 2-D matrix with 150 rows
    //              and 5 columns.  The first 4 columns are the continuous
    //              input attributes and the 5th column is the classification
    //              category (1-3).  These data contain no categorical input
    //              attributes.
    private static double[][] irisData = {
        {5.1,3.5,1.4,0.2,1},{4.9,3.0,1.4,0.2,1},{4.7,3.2,1.3,0.2,1},{4.6,3.1,1.5,0.2,1},
        {5.0,3.6,1.4,0.2,1},{5.4,3.9,1.7,0.4,1},{4.6,3.4,1.4,0.3,1},{5.0,3.4,1.5,0.2,1},
        {4.4,2.9,1.4,0.2,1},{4.9,3.1,1.5,0.1,1},{5.4,3.7,1.5,0.2,1},{4.8,3.4,1.6,0.2,1},
        {4.8,3.0,1.4,0.1,1},{4.3,3.0,1.1,0.1,1},{5.8,4.0,1.2,0.2,1},{5.7,4.4,1.5,0.4,1},
        {5.4,3.9,1.3,0.4,1},{5.1,3.5,1.4,0.3,1},{5.7,3.8,1.7,0.3,1},{5.1,3.8,1.5,0.3,1},
        {5.4,3.4,1.7,0.2,1},{5.1,3.7,1.5,0.4,1},{4.6,3.6,1.0,0.2,1},{5.1,3.3,1.7,0.5,1},
        {4.8,3.4,1.9,0.2,1},{5.0,3.0,1.6,0.2,1},{5.0,3.4,1.6,0.4,1},{5.2,3.5,1.5,0.2,1},
        {5.2,3.4,1.4,0.2,1},{4.7,3.2,1.6,0.2,1},{4.8,3.1,1.6,0.2,1},{5.4,3.4,1.5,0.4,1},
        {5.2,4.1,1.5,0.1,1},{5.5,4.2,1.4,0.2,1},{4.9,3.1,1.5,0.1,1},{5.0,3.2,1.2,0.2,1},
        {5.5,3.5,1.3,0.2,1},{4.9,3.1,1.5,0.1,1},{4.4,3.0,1.3,0.2,1},{5.1,3.4,1.5,0.2,1},
        {5.0,3.5,1.3,0.3,1},{4.5,2.3,1.3,0.3,1},{4.4,3.2,1.3,0.2,1},{5.0,3.5,1.6,0.6,1},
        {5.1,3.8,1.9,0.4,1},{4.8,3.0,1.4,0.3,1},{5.1,3.8,1.6,0.2,1},{4.6,3.2,1.4,0.2,1},
        {5.3,3.7,1.5,0.2,1},{5.0,3.3,1.4,0.2,1},
        {7.0,3.2,4.7,1.4,2},{6.4,3.2,4.5,1.5,2},{6.9,3.1,4.9,1.5,2},{5.5,2.3,4.0,1.3,2},
        {6.5,2.8,4.6,1.5,2},{5.7,2.8,4.5,1.3,2},{6.3,3.3,4.7,1.6,2},{4.9,2.4,3.3,1.0,2},
        {6.6,2.9,4.6,1.3,2},{5.2,2.7,3.9,1.4,2},{5.0,2.0,3.5,1.0,2},{5.9,3.0,4.2,1.5,2},
        {6.0,2.2,4.0,1.0,2},{6.1,2.9,4.7,1.4,2},{5.6,2.9,3.6,1.3,2},{6.7,3.1,4.4,1.4,2},
        {5.6,3.0,4.5,1.5,2},{5.8,2.7,4.1,1.0,2},{6.2,2.2,4.5,1.5,2},{5.6,2.5,3.9,1.1,2},
        {5.9,3.2,4.8,1.8,2},{6.1,2.8,4.0,1.3,2},{6.3,2.5,4.9,1.5,2},{6.1,2.8,4.7,1.2,2},
        {6.4,2.9,4.3,1.3,2},{6.6,3.0,4.4,1.4,2},{6.8,2.8,4.8,1.4,2},{6.7,3.0,5.0,1.7,2},
        {6.0,2.9,4.5,1.5,2},{5.7,2.6,3.5,1.0,2},{5.5,2.4,3.8,1.1,2},{5.5,2.4,3.7,1.0,2},
        {5.8,2.7,3.9,1.2,2},{6.0,2.7,5.1,1.6,2},{5.4,3.0,4.5,1.5,2},{6.0,3.4,4.5,1.6,2},
        {6.7,3.1,4.7,1.5,2},{6.3,2.3,4.4,1.3,2},{5.6,3.0,4.1,1.3,2},{5.5,2.5,4.0,1.3,2},
        {5.5,2.6,4.4,1.2,2},{6.1,3.0,4.6,1.4,2},{5.8,2.6,4.0,1.2,2},{5.0,2.3,3.3,1.0,2},
        {5.6,2.7,4.2,1.3,2},{5.7,3.0,4.2,1.2,2},{5.7,2.9,4.2,1.3,2},{6.2,2.9,4.3,1.3,2},
        {5.1,2.5,3.0,1.1,2},{5.7,2.8,4.1,1.3,2},
        {6.3,3.3,6.0,2.5,3},{5.8,2.7,5.1,1.9,3},{7.1,3.0,5.9,2.1,3},{6.3,2.9,5.6,1.8,3},
        {6.5,3.0,5.8,2.2,3},{7.6,3.0,6.6,2.1,3},{4.9,2.5,4.5,1.7,3},{7.3,2.9,6.3,1.8,3},
        {6.7,2.5,5.8,1.8,3},{7.2,3.6,6.1,2.5,3},{6.5,3.2,5.1,2.0,3},{6.4,2.7,5.3,1.9,3},
        {6.8,3.0,5.5,2.1,3},{5.7,2.5,5.0,2.0,3},{5.8,2.8,5.1,2.4,3},{6.4,3.2,5.3,2.3,3},
        {6.5,3.0,5.5,1.8,3},{7.7,3.8,6.7,2.2,3},{7.7,2.6,6.9,2.3,3},{6.0,2.2,5.0,1.5,3},
        {6.9,3.2,5.7,2.3,3},{5.6,2.8,4.9,2.0,3},{7.7,2.8,6.7,2.0,3},{6.3,2.7,4.9,1.8,3},
        {6.7,3.3,5.7,2.1,3},{7.2,3.2,6.0,1.8,3},{6.2,2.8,4.8,1.8,3},{6.1,3.0,4.9,1.8,3},
        {6.4,2.8,5.6,2.1,3},{7.2,3.0,5.8,1.6,3},{7.4,2.8,6.1,1.9,3},{7.9,3.8,6.4,2.0,3},
        {6.4,2.8,5.6,2.2,3},{6.3,2.8,5.1,1.5,3},{6.1,2.6,5.6,1.4,3},{7.7,3.0,6.1,2.3,3},
        {6.3,3.4,5.6,2.4,3},{6.4,3.1,5.5,1.8,3},{6.0,3.0,4.8,1.8,3},{6.9,3.1,5.4,2.1,3},
        {6.7,3.1,5.6,2.4,3},{6.9,3.1,5.1,2.3,3},{5.8,2.7,5.1,1.9,3},{6.8,3.2,5.9,2.3,3},
        {6.7,3.3,5.7,2.5,3},{6.7,3.0,5.2,2.3,3},{6.3,2.5,5.0,1.9,3},{6.5,3.0,5.2,2.0,3},
        {6.2,3.4,5.4,2.3,3},{5.9,3.0,5.1,1.8,3}
    };

    public static void main(String[] args) throws Exception {
        double xData[][] = new double[nObs][nInputs];
        int yData[] = new int[nObs];
        for (int i = 0; i < nObs; i++) {
            for (int j = 0; j < nInputs; j++) {
                xData[i][j] = irisData[i][j];
            }
            yData[i] = (int) irisData[i][4];
        }

        // Create network
        FeedForwardNetwork network = new FeedForwardNetwork();
        network.getInputLayer().createInputs(nInputs);
        network.createHiddenLayer().createPerceptrons(4, Activation.LOGISTIC, 0.0);
        network.createHiddenLayer().createPerceptrons(3, Activation.LOGISTIC, 0.0);
        network.getOutputLayer().createPerceptrons(nOutputs, Activation.SOFTMAX, 0.0);
        network.linkAll();
        MultiClassification classification = new MultiClassification(network);

        // Create trainer
        QuasiNewtonTrainer trainer = new QuasiNewtonTrainer();
        trainer.setError(classification.getError());
        trainer.setMaximumTrainingIterations(1000);

        // If tracing is requested setup training logger
        if (trace) {
            Handler handler = new FileHandler("ClassificationNetworkTraining.log");
            Logger logger = Logger.getLogger("com.imsl.datamining.neural");
            logger.setLevel(Level.FINEST);
            logger.addHandler(handler);
            handler.setFormatter(QuasiNewtonTrainer.getFormatter());
        }

        // Train Network
        long t0 = System.currentTimeMillis();
        classification.train(trainer, xData, yData);

        // Display Network Errors
        double stats[] = classification.computeStatistics(xData, yData);
        System.out.println("***********************************************");
        System.out.println("--> Cross-entropy error:       " + (float) stats[0]);
        System.out.println("--> Classification error rate: " + (float) stats[1]);
        System.out.println("***********************************************");
        System.out.println("");

        double weight[] = network.getWeights();
        double gradient[] = trainer.getErrorGradient();
        double wg[][] = new double[weight.length][2];
        for (int i = 0; i < weight.length; i++) {
            wg[i][0] = weight[i];
            wg[i][1] = gradient[i];
        }
        PrintMatrixFormat pmf = new PrintMatrixFormat();
        pmf.setNumberFormat(new java.text.DecimalFormat("0.000000"));
        pmf.setColumnLabels(new String[]{"Weights", "Gradients"});
        new PrintMatrix().print(pmf, wg);

        double report[][] = new double[nObs][nInputs + 2];
        for (int i = 0; i < nObs; i++) {
            for (int j = 0; j < nInputs; j++) {
                report[i][j] = xData[i][j];
            }
            report[i][nInputs] = irisData[i][4];
            report[i][nInputs + 1] =
                    classification.predictedClass(xData[i]);
        }
        pmf = new PrintMatrixFormat();
        pmf.setColumnLabels(new String[]{
            "Sepal Length", "Sepal Width", "Petal Length",
            "Petal Width", "Expected", "Predicted"});
        new PrintMatrix("Forecast").print(pmf, report);

        // **********************************************************************
        // DISPLAY CLASSIFICATION STATISTICS
        // **********************************************************************
        double statsClass[] = classification.computeStatistics(xData, yData);
        // Display Network Errors
        System.out.println("***********************************************");
        System.out.println("--> Cross-Entropy Error:  " + (float) statsClass[0]);
        System.out.println("--> Classification Error: " + (float) statsClass[1]);
        System.out.println("***********************************************");
        System.out.println("");

        long t1 = System.currentTimeMillis();
        double small = 1.e-7;
        double time = t1 - t0; //Math.max(small, (double)(t1-t0)/(double)iters);
        time = time / 1000;
        System.out.println("****************Time: " + time);
        System.out.println("Cross-Entropy Error Value = " + trainer.getErrorValue());
    }
}
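Once trained, the same classification object can score observations that were not part of the training set with predictedClass, exactly as the report loop above does for the training patterns. A minimal sketch, assuming it is placed at the end of main() above; the measurement values are hypothetical:

        // Hypothetical new measurement: sepal length, sepal width, petal length, petal width.
        double[] newFlower = {6.1, 2.9, 4.6, 1.4};
        int predicted = classification.predictedClass(newFlower);
        System.out.println("Predicted iris class: " + predicted);  // 1, 2, or 3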
Output

***********************************************
--> Cross-entropy error:       4.640623
--> Classification error rate: 0.006666667
***********************************************

      Weights         Gradients
 0    -51.777881      -0.021660
 1    605.119380       0.000000
 2    -284.226877      0.000000
 3    327.038883       0.000000
 4    -41.160485      -0.009887
 5    -867.891312      0.000000
 6    -1210.846071     0.000000
 7    -994.103717      0.000000
 8    73.932788       -0.016740
 9    -346.829319      0.000000
10    704.482597       0.000000
11    -497.908892      0.000000
12    51.636506       -0.006301
13    1943.984336      0.000000
14    1516.711136      0.000000
15    1935.687178      0.000000
16    -3.143561       -2.271656
17    -443.852301     -7.201949
18    242.475544      -0.000024
19    23.461487       -2.272793
20    189.287779      -7.201954
21    260.386655      -0.096456
22    564.420647      -2.272793
23    607.227248      -7.201954
24    -62.368750      -0.096456
25    163.370794      -2.272793
26    216.054929      -7.201954
27    296.537883      -0.096456
28    -15686.506783    0.000000
29    3478.164215      0.004606
30    12209.342568    -0.004606
31    -15443.797985    0.000000
32    4719.334347      0.002674
33    10725.463639    -0.002674
34    -15303.926099    0.000000
35    3602.472102      0.004863
36    11702.453998    -0.004863
37    -19.854440      -0.003322
38    965.005400       0.000000
39    874.394173       0.000000
40    898.666721       0.000000
41    -745.305267     -2.272793
42    -568.545362     -7.201954
43    -494.170957     -0.096456
44    36175.248628     0.000000
45    -8292.572938     0.004882
46    -27882.675691   -0.004882

Forecast
      Sepal Length  Sepal Width  Petal Length  Petal Width  Expected  Predicted
  0  5.1  3.5  1.4  0.2  1  1
  1  4.9  3  1.4  0.2  1  1
  2  4.7  3.2  1.3  0.2  1  1
  3  4.6  3.1  1.5  0.2  1  1
  4  5  3.6  1.4  0.2  1  1
  5  5.4  3.9  1.7  0.4  1  1
  6  4.6  3.4  1.4  0.3  1  1
  7  5  3.4  1.5  0.2  1  1
  8  4.4  2.9  1.4  0.2  1  1
  9  4.9  3.1  1.5  0.1  1  1
 10  5.4  3.7  1.5  0.2  1  1
 11  4.8  3.4  1.6  0.2  1  1
 12  4.8  3  1.4  0.1  1  1
 13  4.3  3  1.1  0.1  1  1
 14  5.8  4  1.2  0.2  1  1
 15  5.7  4.4  1.5  0.4  1  1
 16  5.4  3.9  1.3  0.4  1  1
 17  5.1  3.5  1.4  0.3  1  1
 18  5.7  3.8  1.7  0.3  1  1
 19  5.1  3.8  1.5  0.3  1  1
 20  5.4  3.4  1.7  0.2  1  1
 21  5.1  3.7  1.5  0.4  1  1
 22  4.6  3.6  1  0.2  1  1
 23  5.1  3.3  1.7  0.5  1  1
 24  4.8  3.4  1.9  0.2  1  1
 25  5  3  1.6  0.2  1  1
 26  5  3.4  1.6  0.4  1  1
 27  5.2  3.5  1.5  0.2  1  1
 28  5.2  3.4  1.4  0.2  1  1
 29  4.7  3.2  1.6  0.2  1  1
 30  4.8  3.1  1.6  0.2  1  1
 31  5.4  3.4  1.5  0.4  1  1
 32  5.2  4.1  1.5  0.1  1  1
 33  5.5  4.2  1.4  0.2  1  1
 34  4.9  3.1  1.5  0.1  1  1
 35  5  3.2  1.2  0.2  1  1
 36  5.5  3.5  1.3  0.2  1  1
 37  4.9  3.1  1.5  0.1  1  1
 38  4.4  3  1.3  0.2  1  1
 39  5.1  3.4  1.5  0.2  1  1
 40  5  3.5  1.3  0.3  1  1
 41  4.5  2.3  1.3  0.3  1  1
 42  4.4  3.2  1.3  0.2  1  1
 43  5  3.5  1.6  0.6  1  1
 44  5.1  3.8  1.9  0.4  1  1
 45  4.8  3  1.4  0.3  1  1
 46  5.1  3.8  1.6  0.2  1  1
 47  4.6  3.2  1.4  0.2  1  1
 48  5.3  3.7  1.5  0.2  1  1
 49  5  3.3  1.4  0.2  1  1
 50  7  3.2  4.7  1.4  2  2
 51  6.4  3.2  4.5  1.5  2  2
 52  6.9  3.1  4.9  1.5  2  2
 53  5.5  2.3  4  1.3  2  2
 54  6.5  2.8  4.6  1.5  2  2
 55  5.7  2.8  4.5  1.3  2  2
 56  6.3  3.3  4.7  1.6  2  2
 57  4.9  2.4  3.3  1  2  2
 58  6.6  2.9  4.6  1.3  2  2
 59  5.2  2.7  3.9  1.4  2  2
 60  5  2  3.5  1  2  2
 61  5.9  3  4.2  1.5  2  2
 62  6  2.2  4  1  2  2
 63  6.1  2.9  4.7  1.4  2  2
 64  5.6  2.9  3.6  1.3  2  2
 65  6.7  3.1  4.4  1.4  2  2
 66  5.6  3  4.5  1.5  2  2
 67  5.8  2.7  4.1  1  2  2
 68  6.2  2.2  4.5  1.5  2  2
 69  5.6  2.5  3.9  1.1  2  2
 70  5.9  3.2  4.8  1.8  2  2
 71  6.1  2.8  4  1.3  2  2
 72  6.3  2.5  4.9  1.5  2  2
 73  6.1  2.8  4.7  1.2  2  2
 74  6.4  2.9  4.3  1.3  2  2
 75  6.6  3  4.4  1.4  2  2
 76  6.8  2.8  4.8  1.4  2  2
 77  6.7  3  5  1.7  2  2
 78  6  2.9  4.5  1.5  2  2
 79  5.7  2.6  3.5  1  2  2
 80  5.5  2.4  3.8  1.1  2  2
 81  5.5  2.4  3.7  1  2  2
 82  5.8  2.7  3.9  1.2  2  2
 83  6  2.7  5.1  1.6  2  3
 84  5.4  3  4.5  1.5  2  2
 85  6  3.4  4.5  1.6  2  2
 86  6.7  3.1  4.7  1.5  2  2
 87  6.3  2.3  4.4  1.3  2  2
 88  5.6  3  4.1  1.3  2  2
 89  5.5  2.5  4  1.3  2  2
 90  5.5  2.6  4.4  1.2  2  2
 91  6.1  3  4.6  1.4  2  2
 92  5.8  2.6  4  1.2  2  2
 93  5  2.3  3.3  1  2  2
 94  5.6  2.7  4.2  1.3  2  2
 95  5.7  3  4.2  1.2  2  2
 96  5.7  2.9  4.2  1.3  2  2
 97  6.2  2.9  4.3  1.3  2  2
 98  5.1  2.5  3  1.1  2  2
 99  5.7  2.8  4.1  1.3  2  2
100  6.3  3.3  6  2.5  3  3
101  5.8  2.7  5.1  1.9  3  3
102  7.1  3  5.9  2.1  3  3
103  6.3  2.9  5.6  1.8  3  3
104  6.5  3  5.8  2.2  3  3
105  7.6  3  6.6  2.1  3  3
106  4.9  2.5  4.5  1.7  3  3
107  7.3  2.9  6.3  1.8  3  3
108  6.7  2.5  5.8  1.8  3  3
109  7.2  3.6  6.1  2.5  3  3
110  6.5  3.2  5.1  2  3  3
111  6.4  2.7  5.3  1.9  3  3
112  6.8  3  5.5  2.1  3  3
113  5.7  2.5  5  2  3  3
114  5.8  2.8  5.1  2.4  3  3
115  6.4  3.2  5.3  2.3  3  3
116  6.5  3  5.5  1.8  3  3
117  7.7  3.8  6.7  2.2  3  3
118  7.7  2.6  6.9  2.3  3  3
119  6  2.2  5  1.5  3  3
120  6.9  3.2  5.7  2.3  3  3
121  5.6  2.8  4.9  2  3  3
122  7.7  2.8  6.7  2  3  3
123  6.3  2.7  4.9  1.8  3  3
124  6.7  3.3  5.7  2.1  3  3
125  7.2  3.2  6  1.8  3  3
126  6.2  2.8  4.8  1.8  3  3
127  6.1  3  4.9  1.8  3  3
128  6.4  2.8  5.6  2.1  3  3
129  7.2  3  5.8  1.6  3  3
130  7.4  2.8  6.1  1.9  3  3
131  7.9  3.8  6.4  2  3  3
132  6.4  2.8  5.6  2.2  3  3
133  6.3  2.8  5.1  1.5  3  3
134  6.1  2.6  5.6  1.4  3  3
135  7.7  3  6.1  2.3  3  3
136  6.3  3.4  5.6  2.4  3  3
137  6.4  3.1  5.5  1.8  3  3
138  6  3  4.8  1.8  3  3
139  6.9  3.1  5.4  2.1  3  3
140  6.7  3.1  5.6  2.4  3  3
141  6.9  3.1  5.1  2.3  3  3
142  5.8  2.7  5.1  1.9  3  3
143  6.8  3.2  5.9  2.3  3  3
144  6.7  3.3  5.7  2.5  3  3
145  6.7  3  5.2  2.3  3  3
146  6.3  2.5  5  1.9  3  3
147  6.5  3  5.2  2  3  3
148  6.2  3.4  5.4  2.3  3  3
149  5.9  3  5.1  1.8  3  3

***********************************************
--> Cross-Entropy Error:  4.640623
--> Classification Error: 0.006666667
***********************************************

****************Time: 3.345
Cross-Entropy Error Value = 4.6406232788035595
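Note that the classification error rate of 0.006666667 corresponds to a single misclassified training pattern out of 150 (1/150): observation 83, an Iris Versicolour that the network predicts as Iris Virginica.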