This example trains a 3-layer network using Fisher's Iris data, which has four continuous input attributes and three output classes. This is perhaps the best-known data set in the pattern recognition literature, and Fisher's paper is a classic in the field. The data set contains 3 classes of 50 instances each, where each class refers to a type of iris plant.
The structure of the network consists of four input nodes and three layers, with four perceptrons in the first hidden layer, three perceptrons in the second hidden layer and three in the output layer.
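The weight count follows from the layer sizes above and can be checked by hand. Assume every node is linked to every node in the next layer and each perceptron has one bias. Then each layer contributes (nodes in the previous layer × nodes in this layer) + (nodes in this layer) weights. The sketch below assumes exactly that connectivity for `linkAll()`. `WeightCount` and `countWeights` are hypothetical names and are not part of the IMSL API. The result, 47, matches the 47 rows (indices 0 to 46) of the weight listing in the output.

```java
// Sketch: count link weights plus bias weights in a fully connected
// feed-forward network. Assumes each layer is fully linked to the next
// and every perceptron has one bias. Hypothetical helper, not IMSL code.
public class WeightCount {

    /** Returns the number of link weights plus one bias per non-input node. */
    public static int countWeights(int[] layerSizes) {
        int total = 0;
        for (int k = 1; k < layerSizes.length; k++) {
            int links = layerSizes[k - 1] * layerSizes[k]; // previous layer fully linked to layer k
            int biases = layerSizes[k];                    // one bias per perceptron in layer k
            total += links + biases;
        }
        return total;
    }

    public static void main(String[] args) {
        int[] sizes = {4, 4, 3, 3}; // inputs, hidden layer 1, hidden layer 2, output layer
        System.out.println(countWeights(sizes));
    }
}
```

For this network the terms are 16 + 4, 12 + 3 and 9 + 3, which sum to 47.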
The four input attributes represent sepal length, sepal width, petal length, and petal width.
The output attribute represents the class of the iris plant and is encoded using binary encoding.
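For a three-class target, one common reading of "binary encoding" is a one-of-three indicator vector. That vector has a 1 in the position of the true class and 0 elsewhere. The sketch below assumes that reading and is illustrative only. The example program itself passes the integer labels 1 to 3 directly to `train()`, and any encoding is presumably handled inside `MultiClassification`. `OneHot` and `encode` are hypothetical names.

```java
import java.util.Arrays;

// Sketch: one-of-c indicator encoding of 1-based class labels.
// Assumption: this is what "binary encoding" refers to; not IMSL code.
public class OneHot {

    /** Maps a label in 1..nClasses to a 0/1 indicator vector of length nClasses. */
    public static double[] encode(int label, int nClasses) {
        if (label < 1 || label > nClasses) {
            throw new IllegalArgumentException("label out of range: " + label);
        }
        double[] target = new double[nClasses];
        target[label - 1] = 1.0; // Iris labels are 1-based
        return target;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(encode(1, 3))); // Iris Setosa
        System.out.println(Arrays.toString(encode(3, 3))); // Iris Virginica
    }
}
```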
There are a total of 47 weights in this network, including the bias weights: 4 × 4 + 4 × 3 + 3 × 3 = 37 link weights plus 4 + 3 + 3 = 10 bias weights. All hidden layers use the logistic activation function. Since the target output is a multi-class classification, the output layer uses the softmax activation function, and the trainer uses the error function of the MultiClassification class. The MultiClassification error combines the cross-entropy error calculations with the softmax function.
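Below is a minimal plain-Java sketch of a softmax output layer followed by the cross-entropy error for a single pattern. It is not the IMSL implementation, and the class and method names are hypothetical. It also shows one common reason for pairing the two. If p is the softmax output and t the 0/1 target vector, the cross-entropy gradient with respect to the output-layer potentials is p_k - t_k.

```java
// Sketch: softmax + cross-entropy for one pattern with a 1-based class label.
// Hypothetical names; not the IMSL MultiClassification implementation.
public class SoftmaxCrossEntropy {

    /** Numerically stable softmax: subtract the largest potential before exp. */
    public static double[] softmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) {
            max = Math.max(max, v);
        }
        double sum = 0.0;
        double[] p = new double[z.length];
        for (int k = 0; k < z.length; k++) {
            p[k] = Math.exp(z[k] - max);
            sum += p[k];
        }
        for (int k = 0; k < z.length; k++) {
            p[k] /= sum; // outputs are positive and sum to 1
        }
        return p;
    }

    /** Cross-entropy with a one-of-c target: only the true class term survives. */
    public static double crossEntropy(double[] p, int label) {
        return -Math.log(p[label - 1]);
    }

    /** Gradient of the cross-entropy w.r.t. the softmax potentials: p_k - t_k. */
    public static double[] gradient(double[] p, int label) {
        double[] g = p.clone();
        g[label - 1] -= 1.0;
        return g;
    }

    public static void main(String[] args) {
        double[] z = {2.0, 1.0, 0.1}; // illustrative output-layer potentials
        double[] p = softmax(z);
        System.out.printf("p = %.4f %.4f %.4f%n", p[0], p[1], p[2]);
        System.out.printf("E = %.4f%n", crossEntropy(p, 1));
    }
}
```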
import com.imsl.datamining.neural.*;
import com.imsl.math.*;
import java.io.*;
import java.util.logging.*;
//***************************************************************************
// Three Layer Feed-Forward Network with 4 inputs, all
// continuous, and 3 classification categories.
//
// This is perhaps the best known database to be found in the pattern
// recognition literature. Fisher's paper is a classic in the field.
// The data set contains 3 classes of 50 instances each,
// where each class refers to a type of iris plant. One class is
// linearly separable from the other 2; the latter are NOT linearly
// separable from each other.
//
// Predicted attribute: class of iris plant.
// 1=Iris Setosa, 2=Iris Versicolour, and 3=Iris Virginica
//
// Input Attributes (4 Continuous Attributes)
// X1: Sepal length, X2: Sepal width, X3: Petal length,
// and X4: Petal width
//***************************************************************************
public class MultiClassificationEx1 implements Serializable {
    private static int nObs = 150;        // number of training patterns
    private static int nInputs = 4;       // 4 continuous input attributes
    private static int nOutputs = 3;      // 3 output classes (Setosa, Versicolour, Virginica)
    private static boolean trace = true;  // turns on/off the training log
    // irisData[][]: The raw data matrix. This is a 2-D matrix with 150 rows
    //               and 5 columns. The first 4 columns are the continuous
    //               input attributes and the 5th column is the
    //               classification category (1-3). These data contain no
    //               categorical input attributes.
    private static double[][] irisData = {
        {5.1, 3.5, 1.4, 0.2, 1}, {4.9, 3.0, 1.4, 0.2, 1},
        {4.7, 3.2, 1.3, 0.2, 1}, {4.6, 3.1, 1.5, 0.2, 1},
        {5.0, 3.6, 1.4, 0.2, 1}, {5.4, 3.9, 1.7, 0.4, 1},
        {4.6, 3.4, 1.4, 0.3, 1}, {5.0, 3.4, 1.5, 0.2, 1},
        {4.4, 2.9, 1.4, 0.2, 1}, {4.9, 3.1, 1.5, 0.1, 1},
        {5.4, 3.7, 1.5, 0.2, 1}, {4.8, 3.4, 1.6, 0.2, 1},
        {4.8, 3.0, 1.4, 0.1, 1}, {4.3, 3.0, 1.1, 0.1, 1},
        {5.8, 4.0, 1.2, 0.2, 1}, {5.7, 4.4, 1.5, 0.4, 1},
        {5.4, 3.9, 1.3, 0.4, 1}, {5.1, 3.5, 1.4, 0.3, 1},
        {5.7, 3.8, 1.7, 0.3, 1}, {5.1, 3.8, 1.5, 0.3, 1},
        {5.4, 3.4, 1.7, 0.2, 1}, {5.1, 3.7, 1.5, 0.4, 1},
        {4.6, 3.6, 1.0, 0.2, 1}, {5.1, 3.3, 1.7, 0.5, 1},
        {4.8, 3.4, 1.9, 0.2, 1}, {5.0, 3.0, 1.6, 0.2, 1},
        {5.0, 3.4, 1.6, 0.4, 1}, {5.2, 3.5, 1.5, 0.2, 1},
        {5.2, 3.4, 1.4, 0.2, 1}, {4.7, 3.2, 1.6, 0.2, 1},
        {4.8, 3.1, 1.6, 0.2, 1}, {5.4, 3.4, 1.5, 0.4, 1},
        {5.2, 4.1, 1.5, 0.1, 1}, {5.5, 4.2, 1.4, 0.2, 1},
        {4.9, 3.1, 1.5, 0.1, 1}, {5.0, 3.2, 1.2, 0.2, 1},
        {5.5, 3.5, 1.3, 0.2, 1}, {4.9, 3.1, 1.5, 0.1, 1},
        {4.4, 3.0, 1.3, 0.2, 1}, {5.1, 3.4, 1.5, 0.2, 1},
        {5.0, 3.5, 1.3, 0.3, 1}, {4.5, 2.3, 1.3, 0.3, 1},
        {4.4, 3.2, 1.3, 0.2, 1}, {5.0, 3.5, 1.6, 0.6, 1},
        {5.1, 3.8, 1.9, 0.4, 1}, {4.8, 3.0, 1.4, 0.3, 1},
        {5.1, 3.8, 1.6, 0.2, 1}, {4.6, 3.2, 1.4, 0.2, 1},
        {5.3, 3.7, 1.5, 0.2, 1}, {5.0, 3.3, 1.4, 0.2, 1},
        {7.0, 3.2, 4.7, 1.4, 2}, {6.4, 3.2, 4.5, 1.5, 2},
        {6.9, 3.1, 4.9, 1.5, 2}, {5.5, 2.3, 4.0, 1.3, 2},
        {6.5, 2.8, 4.6, 1.5, 2}, {5.7, 2.8, 4.5, 1.3, 2},
        {6.3, 3.3, 4.7, 1.6, 2}, {4.9, 2.4, 3.3, 1.0, 2},
        {6.6, 2.9, 4.6, 1.3, 2}, {5.2, 2.7, 3.9, 1.4, 2},
        {5.0, 2.0, 3.5, 1.0, 2}, {5.9, 3.0, 4.2, 1.5, 2},
        {6.0, 2.2, 4.0, 1.0, 2}, {6.1, 2.9, 4.7, 1.4, 2},
        {5.6, 2.9, 3.6, 1.3, 2}, {6.7, 3.1, 4.4, 1.4, 2},
        {5.6, 3.0, 4.5, 1.5, 2}, {5.8, 2.7, 4.1, 1.0, 2},
        {6.2, 2.2, 4.5, 1.5, 2}, {5.6, 2.5, 3.9, 1.1, 2},
        {5.9, 3.2, 4.8, 1.8, 2}, {6.1, 2.8, 4.0, 1.3, 2},
        {6.3, 2.5, 4.9, 1.5, 2}, {6.1, 2.8, 4.7, 1.2, 2},
        {6.4, 2.9, 4.3, 1.3, 2}, {6.6, 3.0, 4.4, 1.4, 2},
        {6.8, 2.8, 4.8, 1.4, 2}, {6.7, 3.0, 5.0, 1.7, 2},
        {6.0, 2.9, 4.5, 1.5, 2}, {5.7, 2.6, 3.5, 1.0, 2},
        {5.5, 2.4, 3.8, 1.1, 2}, {5.5, 2.4, 3.7, 1.0, 2},
        {5.8, 2.7, 3.9, 1.2, 2}, {6.0, 2.7, 5.1, 1.6, 2},
        {5.4, 3.0, 4.5, 1.5, 2}, {6.0, 3.4, 4.5, 1.6, 2},
        {6.7, 3.1, 4.7, 1.5, 2}, {6.3, 2.3, 4.4, 1.3, 2},
        {5.6, 3.0, 4.1, 1.3, 2}, {5.5, 2.5, 4.0, 1.3, 2},
        {5.5, 2.6, 4.4, 1.2, 2}, {6.1, 3.0, 4.6, 1.4, 2},
        {5.8, 2.6, 4.0, 1.2, 2}, {5.0, 2.3, 3.3, 1.0, 2},
        {5.6, 2.7, 4.2, 1.3, 2}, {5.7, 3.0, 4.2, 1.2, 2},
        {5.7, 2.9, 4.2, 1.3, 2}, {6.2, 2.9, 4.3, 1.3, 2},
        {5.1, 2.5, 3.0, 1.1, 2}, {5.7, 2.8, 4.1, 1.3, 2},
        {6.3, 3.3, 6.0, 2.5, 3}, {5.8, 2.7, 5.1, 1.9, 3},
        {7.1, 3.0, 5.9, 2.1, 3}, {6.3, 2.9, 5.6, 1.8, 3},
        {6.5, 3.0, 5.8, 2.2, 3}, {7.6, 3.0, 6.6, 2.1, 3},
        {4.9, 2.5, 4.5, 1.7, 3}, {7.3, 2.9, 6.3, 1.8, 3},
        {6.7, 2.5, 5.8, 1.8, 3}, {7.2, 3.6, 6.1, 2.5, 3},
        {6.5, 3.2, 5.1, 2.0, 3}, {6.4, 2.7, 5.3, 1.9, 3},
        {6.8, 3.0, 5.5, 2.1, 3}, {5.7, 2.5, 5.0, 2.0, 3},
        {5.8, 2.8, 5.1, 2.4, 3}, {6.4, 3.2, 5.3, 2.3, 3},
        {6.5, 3.0, 5.5, 1.8, 3}, {7.7, 3.8, 6.7, 2.2, 3},
        {7.7, 2.6, 6.9, 2.3, 3}, {6.0, 2.2, 5.0, 1.5, 3},
        {6.9, 3.2, 5.7, 2.3, 3}, {5.6, 2.8, 4.9, 2.0, 3},
        {7.7, 2.8, 6.7, 2.0, 3}, {6.3, 2.7, 4.9, 1.8, 3},
        {6.7, 3.3, 5.7, 2.1, 3}, {7.2, 3.2, 6.0, 1.8, 3},
        {6.2, 2.8, 4.8, 1.8, 3}, {6.1, 3.0, 4.9, 1.8, 3},
        {6.4, 2.8, 5.6, 2.1, 3}, {7.2, 3.0, 5.8, 1.6, 3},
        {7.4, 2.8, 6.1, 1.9, 3}, {7.9, 3.8, 6.4, 2.0, 3},
        {6.4, 2.8, 5.6, 2.2, 3}, {6.3, 2.8, 5.1, 1.5, 3},
        {6.1, 2.6, 5.6, 1.4, 3}, {7.7, 3.0, 6.1, 2.3, 3},
        {6.3, 3.4, 5.6, 2.4, 3}, {6.4, 3.1, 5.5, 1.8, 3},
        {6.0, 3.0, 4.8, 1.8, 3}, {6.9, 3.1, 5.4, 2.1, 3},
        {6.7, 3.1, 5.6, 2.4, 3}, {6.9, 3.1, 5.1, 2.3, 3},
        {5.8, 2.7, 5.1, 1.9, 3}, {6.8, 3.2, 5.9, 2.3, 3},
        {6.7, 3.3, 5.7, 2.5, 3}, {6.7, 3.0, 5.2, 2.3, 3},
        {6.3, 2.5, 5.0, 1.9, 3}, {6.5, 3.0, 5.2, 2.0, 3},
        {6.2, 3.4, 5.4, 2.3, 3}, {5.9, 3.0, 5.1, 1.8, 3}
    };
    public static void main(String[] args) throws Exception {
        // Split the raw matrix into inputs (columns 0-3) and class labels (column 4)
        double xData[][] = new double[nObs][nInputs];
        int yData[] = new int[nObs];
        for (int i = 0; i < nObs; i++) {
            for (int j = 0; j < nInputs; j++) {
                xData[i][j] = irisData[i][j];
            }
            yData[i] = (int) irisData[i][4];
        }
        // Create network
        FeedForwardNetwork network = new FeedForwardNetwork();
        network.getInputLayer().createInputs(nInputs);
        network.createHiddenLayer().
                createPerceptrons(4, Activation.LOGISTIC, 0.0);
        network.createHiddenLayer().
                createPerceptrons(3, Activation.LOGISTIC, 0.0);
        network.getOutputLayer().
                createPerceptrons(nOutputs, Activation.SOFTMAX, 0.0);
        network.linkAll();
        MultiClassification classification = new MultiClassification(network);
        // Create trainer
        QuasiNewtonTrainer trainer = new QuasiNewtonTrainer();
        trainer.setError(classification.getError());
        trainer.setMaximumTrainingIterations(1000);
        // If tracing is requested, set up the training logger
        if (trace) {
            Handler handler
                    = new FileHandler("ClassificationNetworkTraining.log");
            Logger logger = Logger.getLogger("com.imsl.datamining.neural");
            logger.setLevel(Level.FINEST);
            logger.addHandler(handler);
            handler.setFormatter(QuasiNewtonTrainer.getFormatter());
        }
        // Train network
        long t0 = System.currentTimeMillis();
        classification.train(trainer, xData, yData);
        // Display network errors
        double stats[] = classification.computeStatistics(xData, yData);
        System.out.println("***********************************************");
        System.out.println("--> Cross-entropy error: "
                + (float) stats[0]);
        System.out.println("--> Classification error rate: "
                + (float) stats[1]);
        System.out.println("***********************************************");
        System.out.println("");
        // Print each weight next to its error gradient
        double weight[] = network.getWeights();
        double gradient[] = trainer.getErrorGradient();
        double wg[][] = new double[weight.length][2];
        for (int i = 0; i < weight.length; i++) {
            wg[i][0] = weight[i];
            wg[i][1] = gradient[i];
        }
        PrintMatrixFormat pmf = new PrintMatrixFormat();
        pmf.setNumberFormat(new java.text.DecimalFormat("0.000000"));
        pmf.setColumnLabels(new String[]{"Weights", "Gradients"});
        new PrintMatrix().print(pmf, wg);
        // Compare the expected and predicted class for every pattern
        double report[][] = new double[nObs][nInputs + 2];
        for (int i = 0; i < nObs; i++) {
            for (int j = 0; j < nInputs; j++) {
                report[i][j] = xData[i][j];
            }
            report[i][nInputs] = irisData[i][4];
            report[i][nInputs + 1] = classification.predictedClass(xData[i]);
        }
        pmf = new PrintMatrixFormat();
        pmf.setColumnLabels(new String[]{
            "Sepal Length",
            "Sepal Width",
            "Petal Length",
            "Petal Width",
            "Expected",
            "Predicted"}
        );
        new PrintMatrix("Forecast").print(pmf, report);
        // ******************************************************************
        // DISPLAY CLASSIFICATION STATISTICS
        // ******************************************************************
        double statsClass[] = classification.computeStatistics(xData, yData);
        // Display network errors
        System.out.println("***********************************************");
        System.out.println("--> Cross-Entropy Error: "
                + (float) statsClass[0]);
        System.out.println("--> Classification Error: "
                + (float) statsClass[1]);
        System.out.println("***********************************************");
        System.out.println("");
        // Elapsed time in seconds for training and reporting
        long t1 = System.currentTimeMillis();
        double time = t1 - t0;
        time = time / 1000;
        System.out.println("****************Time: " + time);
        System.out.println("Cross-Entropy Error Value = "
                + trainer.getErrorValue());
    }
}
Output
***********************************************
--> Cross-entropy error: 4.653512
--> Classification error rate: 0.006666667
***********************************************
Weights Gradients
0 -42.381828 0.030801
1 193.055878 0.000000
2 -30.384656 0.000000
3 95.352605 0.000000
4 -33.692976 0.012782
5 -282.844912 0.000000
6 -422.581218 0.000000
7 -317.896968 0.000000
8 60.505928 0.023458
9 -94.286590 0.000000
10 109.828939 0.000000
11 -168.351914 0.000000
12 42.250439 0.008464
13 691.012987 0.000000
14 602.474794 0.000000
15 694.122349 0.000000
16 -3.036409 -1.035514
17 -151.673802 -5.466157
18 75.471899 0.000003
19 3.479083 -1.035303
20 46.614896 -5.466182
21 56.347637 0.032349
22 153.010906 -1.035303
23 204.597137 -5.466182
24 21.557752 0.032349
25 64.075577 -1.035303
26 67.599168 -5.466182
27 78.904088 0.032349
28 -5672.118029 0.000000
29 1244.905099 0.010077
30 4428.212930 -0.010077
31 -5600.671740 0.000000
32 1746.525173 0.004710
33 3855.146566 -0.004710
34 -5562.390356 0.000000
35 1230.760279 0.010431
36 4332.630076 -0.010431
37 -15.417798 0.004103
38 328.841061 0.000000
39 323.847338 0.000000
40 306.946067 0.000000
41 -214.124377 -1.035303
42 -167.320245 -5.466182
43 -156.514239 0.032349
44 13108.354735 0.000000
45 -2985.466557 0.010413
46 -10122.888178 -0.010413
Forecast
Sepal Length Sepal Width Petal Length Petal Width Expected Predicted
0 5.1 3.5 1.4 0.2 1 1
1 4.9 3 1.4 0.2 1 1
2 4.7 3.2 1.3 0.2 1 1
3 4.6 3.1 1.5 0.2 1 1
4 5 3.6 1.4 0.2 1 1
5 5.4 3.9 1.7 0.4 1 1
6 4.6 3.4 1.4 0.3 1 1
7 5 3.4 1.5 0.2 1 1
8 4.4 2.9 1.4 0.2 1 1
9 4.9 3.1 1.5 0.1 1 1
10 5.4 3.7 1.5 0.2 1 1
11 4.8 3.4 1.6 0.2 1 1
12 4.8 3 1.4 0.1 1 1
13 4.3 3 1.1 0.1 1 1
14 5.8 4 1.2 0.2 1 1
15 5.7 4.4 1.5 0.4 1 1
16 5.4 3.9 1.3 0.4 1 1
17 5.1 3.5 1.4 0.3 1 1
18 5.7 3.8 1.7 0.3 1 1
19 5.1 3.8 1.5 0.3 1 1
20 5.4 3.4 1.7 0.2 1 1
21 5.1 3.7 1.5 0.4 1 1
22 4.6 3.6 1 0.2 1 1
23 5.1 3.3 1.7 0.5 1 1
24 4.8 3.4 1.9 0.2 1 1
25 5 3 1.6 0.2 1 1
26 5 3.4 1.6 0.4 1 1
27 5.2 3.5 1.5 0.2 1 1
28 5.2 3.4 1.4 0.2 1 1
29 4.7 3.2 1.6 0.2 1 1
30 4.8 3.1 1.6 0.2 1 1
31 5.4 3.4 1.5 0.4 1 1
32 5.2 4.1 1.5 0.1 1 1
33 5.5 4.2 1.4 0.2 1 1
34 4.9 3.1 1.5 0.1 1 1
35 5 3.2 1.2 0.2 1 1
36 5.5 3.5 1.3 0.2 1 1
37 4.9 3.1 1.5 0.1 1 1
38 4.4 3 1.3 0.2 1 1
39 5.1 3.4 1.5 0.2 1 1
40 5 3.5 1.3 0.3 1 1
41 4.5 2.3 1.3 0.3 1 1
42 4.4 3.2 1.3 0.2 1 1
43 5 3.5 1.6 0.6 1 1
44 5.1 3.8 1.9 0.4 1 1
45 4.8 3 1.4 0.3 1 1
46 5.1 3.8 1.6 0.2 1 1
47 4.6 3.2 1.4 0.2 1 1
48 5.3 3.7 1.5 0.2 1 1
49 5 3.3 1.4 0.2 1 1
50 7 3.2 4.7 1.4 2 2
51 6.4 3.2 4.5 1.5 2 2
52 6.9 3.1 4.9 1.5 2 2
53 5.5 2.3 4 1.3 2 2
54 6.5 2.8 4.6 1.5 2 2
55 5.7 2.8 4.5 1.3 2 2
56 6.3 3.3 4.7 1.6 2 2
57 4.9 2.4 3.3 1 2 2
58 6.6 2.9 4.6 1.3 2 2
59 5.2 2.7 3.9 1.4 2 2
60 5 2 3.5 1 2 2
61 5.9 3 4.2 1.5 2 2
62 6 2.2 4 1 2 2
63 6.1 2.9 4.7 1.4 2 2
64 5.6 2.9 3.6 1.3 2 2
65 6.7 3.1 4.4 1.4 2 2
66 5.6 3 4.5 1.5 2 2
67 5.8 2.7 4.1 1 2 2
68 6.2 2.2 4.5 1.5 2 2
69 5.6 2.5 3.9 1.1 2 2
70 5.9 3.2 4.8 1.8 2 2
71 6.1 2.8 4 1.3 2 2
72 6.3 2.5 4.9 1.5 2 2
73 6.1 2.8 4.7 1.2 2 2
74 6.4 2.9 4.3 1.3 2 2
75 6.6 3 4.4 1.4 2 2
76 6.8 2.8 4.8 1.4 2 2
77 6.7 3 5 1.7 2 2
78 6 2.9 4.5 1.5 2 2
79 5.7 2.6 3.5 1 2 2
80 5.5 2.4 3.8 1.1 2 2
81 5.5 2.4 3.7 1 2 2
82 5.8 2.7 3.9 1.2 2 2
83 6 2.7 5.1 1.6 2 3
84 5.4 3 4.5 1.5 2 2
85 6 3.4 4.5 1.6 2 2
86 6.7 3.1 4.7 1.5 2 2
87 6.3 2.3 4.4 1.3 2 2
88 5.6 3 4.1 1.3 2 2
89 5.5 2.5 4 1.3 2 2
90 5.5 2.6 4.4 1.2 2 2
91 6.1 3 4.6 1.4 2 2
92 5.8 2.6 4 1.2 2 2
93 5 2.3 3.3 1 2 2
94 5.6 2.7 4.2 1.3 2 2
95 5.7 3 4.2 1.2 2 2
96 5.7 2.9 4.2 1.3 2 2
97 6.2 2.9 4.3 1.3 2 2
98 5.1 2.5 3 1.1 2 2
99 5.7 2.8 4.1 1.3 2 2
100 6.3 3.3 6 2.5 3 3
101 5.8 2.7 5.1 1.9 3 3
102 7.1 3 5.9 2.1 3 3
103 6.3 2.9 5.6 1.8 3 3
104 6.5 3 5.8 2.2 3 3
105 7.6 3 6.6 2.1 3 3
106 4.9 2.5 4.5 1.7 3 3
107 7.3 2.9 6.3 1.8 3 3
108 6.7 2.5 5.8 1.8 3 3
109 7.2 3.6 6.1 2.5 3 3
110 6.5 3.2 5.1 2 3 3
111 6.4 2.7 5.3 1.9 3 3
112 6.8 3 5.5 2.1 3 3
113 5.7 2.5 5 2 3 3
114 5.8 2.8 5.1 2.4 3 3
115 6.4 3.2 5.3 2.3 3 3
116 6.5 3 5.5 1.8 3 3
117 7.7 3.8 6.7 2.2 3 3
118 7.7 2.6 6.9 2.3 3 3
119 6 2.2 5 1.5 3 3
120 6.9 3.2 5.7 2.3 3 3
121 5.6 2.8 4.9 2 3 3
122 7.7 2.8 6.7 2 3 3
123 6.3 2.7 4.9 1.8 3 3
124 6.7 3.3 5.7 2.1 3 3
125 7.2 3.2 6 1.8 3 3
126 6.2 2.8 4.8 1.8 3 3
127 6.1 3 4.9 1.8 3 3
128 6.4 2.8 5.6 2.1 3 3
129 7.2 3 5.8 1.6 3 3
130 7.4 2.8 6.1 1.9 3 3
131 7.9 3.8 6.4 2 3 3
132 6.4 2.8 5.6 2.2 3 3
133 6.3 2.8 5.1 1.5 3 3
134 6.1 2.6 5.6 1.4 3 3
135 7.7 3 6.1 2.3 3 3
136 6.3 3.4 5.6 2.4 3 3
137 6.4 3.1 5.5 1.8 3 3
138 6 3 4.8 1.8 3 3
139 6.9 3.1 5.4 2.1 3 3
140 6.7 3.1 5.6 2.4 3 3
141 6.9 3.1 5.1 2.3 3 3
142 5.8 2.7 5.1 1.9 3 3
143 6.8 3.2 5.9 2.3 3 3
144 6.7 3.3 5.7 2.5 3 3
145 6.7 3 5.2 2.3 3 3
146 6.3 2.5 5 1.9 3 3
147 6.5 3 5.2 2 3 3
148 6.2 3.4 5.4 2.3 3 3
149 5.9 3 5.1 1.8 3 3
***********************************************
--> Cross-Entropy Error: 4.653512
--> Classification Error: 0.006666667
***********************************************
****************Time: 1.31
Cross-Entropy Error Value = 4.653511831588219
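In the Forecast table, row 83 (expected class 2, predicted class 3) is the only mismatch. The reported classification error of 0.006666667 is therefore consistent with 1 misclassified pattern out of 150. A minimal sketch of that calculation follows. `ErrorRate` and `errorRate` are hypothetical names, and this is not how `computeStatistics` is implemented internally.

```java
// Sketch: classification error rate = misclassified patterns / total patterns.
// Hypothetical helper; not the IMSL computeStatistics implementation.
public class ErrorRate {

    public static double errorRate(int[] expected, int[] predicted) {
        if (expected.length == 0 || expected.length != predicted.length) {
            throw new IllegalArgumentException("arrays must be non-empty and of equal length");
        }
        int wrong = 0;
        for (int i = 0; i < expected.length; i++) {
            if (expected[i] != predicted[i]) {
                wrong++;
            }
        }
        return (double) wrong / expected.length;
    }

    public static void main(String[] args) {
        // 150 patterns ordered in blocks of 50 per class, as in irisData
        int[] expected = new int[150];
        int[] predicted = new int[150];
        for (int i = 0; i < 150; i++) {
            expected[i] = i / 50 + 1;
            predicted[i] = expected[i];
        }
        predicted[83] = 3; // the single mismatch from the Forecast table
        System.out.println((float) errorRate(expected, predicted)); // 1/150 as a float
    }
}
```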