Example 1: MultiClassification

This example trains a 3-layer network using Fisher's Iris data with four continuous input attributes and three output classes. This is perhaps the best known database to be found in the pattern recognition literature. Fisher's paper is a classic in the field. The data set contains 3 classes of 50 instances each, where each class refers to a type of iris plant.

The structure of the network consists of four input nodes and three layers, with four perceptrons in the first hidden layer, three perceptrons in the second hidden layer and three in the output layer.

The four input attributes are:

  1. Sepal length
  2. Sepal width
  3. Petal length
  4. Petal width

The output attribute represents the class of the iris plant and is encoded using binary encoding:

  1. Iris Setosa
  2. Iris Versicolour
  3. Iris Virginica

There are a total of 47 weights in this network, including the bias weights: 4×(4+1) = 20 for the first hidden layer, 3×(4+1) = 15 for the second hidden layer, and 3×(3+1) = 12 for the output layer. All hidden layers use the logistic activation function. Since the target output is a multi-class classification, the softmax activation function is used in the output layer and the MultiClassification error function class is used by the trainer. The error class MultiClassification combines the cross-entropy error calculations and the softmax function.

import com.imsl.datamining.neural.*;
import com.imsl.math.PrintMatrix;
import com.imsl.math.PrintMatrixFormat;
import java.io.*;
import java.util.logging.*;

//*****************************************************************************
// Three Layer Feed-Forward Network with 4 inputs, all 
// continuous, and 3 classification categories.
//
//  new classification training_ex5.c
//
// This is perhaps the best known database to be found in the pattern
//     recognition literature.  Fisher's paper is a classic in the field.
//     The data set contains 3 classes of 50 instances each,
//     where each class refers to a type of iris plant.  One class is
//     linearly separable from the other 2; the latter are NOT linearly
//     separable from each other.
//
//  Predicted attribute: class of iris plant.
//     1=Iris Setosa, 2=Iris Versicolour, and 3=Iris Virginica
//
//  Input Attributes (4 Continuous Attributes)
//     X1: Sepal length, X2: Sepal width, X3: Petal length, and X4: Petal width
//*****************************************************************************

/**
 * Three-layer feed-forward network example: trains a classifier on Fisher's
 * Iris data (4 continuous inputs, 3 classes) with a quasi-Newton trainer and
 * prints the training statistics, the final weights with their error
 * gradients, and the expected and predicted class of every training pattern.
 */
public class MultiClassificationEx1 implements Serializable {
    private static final int N_OBS = 150;      // number of training patterns
    private static final int N_INPUTS = 4;     // 4 continuous input attributes
    private static final int N_OUTPUTS = 3;    // one softmax output per iris class
    private static final boolean TRACE = true; // turns on/off the training log

    // Strong reference to the training logger. LogManager may retain only a
    // weak reference to a logger, so a logger configured solely through a
    // local variable can be garbage collected and lose its level and handler.
    private static final Logger TRAINING_LOGGER =
            Logger.getLogger("com.imsl.datamining.neural");

    // Separator line printed around each statistics summary.
    private static final String SEPARATOR =
            "***********************************************";

    // IRIS_DATA: the raw data matrix, 150 rows by 5 columns. The first 4
    // columns are the continuous input attributes (sepal length, sepal width,
    // petal length, petal width) and the 5th column is the classification
    // category (1-3). These data contain no categorical input attributes.
    private static final double[][] IRIS_DATA = {
        {5.1,3.5,1.4,0.2,1},{4.9,3.0,1.4,0.2,1},{4.7,3.2,1.3,0.2,1},{4.6,3.1,1.5,0.2,1},
        {5.0,3.6,1.4,0.2,1},{5.4,3.9,1.7,0.4,1},{4.6,3.4,1.4,0.3,1},{5.0,3.4,1.5,0.2,1},
        {4.4,2.9,1.4,0.2,1},{4.9,3.1,1.5,0.1,1},{5.4,3.7,1.5,0.2,1},{4.8,3.4,1.6,0.2,1},
        {4.8,3.0,1.4,0.1,1},{4.3,3.0,1.1,0.1,1},{5.8,4.0,1.2,0.2,1},{5.7,4.4,1.5,0.4,1},
        {5.4,3.9,1.3,0.4,1},{5.1,3.5,1.4,0.3,1},{5.7,3.8,1.7,0.3,1},{5.1,3.8,1.5,0.3,1},
        {5.4,3.4,1.7,0.2,1},{5.1,3.7,1.5,0.4,1},{4.6,3.6,1.0,0.2,1},{5.1,3.3,1.7,0.5,1},
        {4.8,3.4,1.9,0.2,1},{5.0,3.0,1.6,0.2,1},{5.0,3.4,1.6,0.4,1},{5.2,3.5,1.5,0.2,1},
        {5.2,3.4,1.4,0.2,1},{4.7,3.2,1.6,0.2,1},{4.8,3.1,1.6,0.2,1},{5.4,3.4,1.5,0.4,1},
        {5.2,4.1,1.5,0.1,1},{5.5,4.2,1.4,0.2,1},{4.9,3.1,1.5,0.1,1},{5.0,3.2,1.2,0.2,1},
        {5.5,3.5,1.3,0.2,1},{4.9,3.1,1.5,0.1,1},{4.4,3.0,1.3,0.2,1},{5.1,3.4,1.5,0.2,1},
        {5.0,3.5,1.3,0.3,1},{4.5,2.3,1.3,0.3,1},{4.4,3.2,1.3,0.2,1},{5.0,3.5,1.6,0.6,1},
        {5.1,3.8,1.9,0.4,1},{4.8,3.0,1.4,0.3,1},{5.1,3.8,1.6,0.2,1},{4.6,3.2,1.4,0.2,1},
        {5.3,3.7,1.5,0.2,1},{5.0,3.3,1.4,0.2,1},

        {7.0,3.2,4.7,1.4,2},{6.4,3.2,4.5,1.5,2},{6.9,3.1,4.9,1.5,2},{5.5,2.3,4.0,1.3,2},
        {6.5,2.8,4.6,1.5,2},{5.7,2.8,4.5,1.3,2},{6.3,3.3,4.7,1.6,2},{4.9,2.4,3.3,1.0,2},
        {6.6,2.9,4.6,1.3,2},{5.2,2.7,3.9,1.4,2},{5.0,2.0,3.5,1.0,2},{5.9,3.0,4.2,1.5,2},
        {6.0,2.2,4.0,1.0,2},{6.1,2.9,4.7,1.4,2},{5.6,2.9,3.6,1.3,2},{6.7,3.1,4.4,1.4,2},
        {5.6,3.0,4.5,1.5,2},{5.8,2.7,4.1,1.0,2},{6.2,2.2,4.5,1.5,2},{5.6,2.5,3.9,1.1,2},
        {5.9,3.2,4.8,1.8,2},{6.1,2.8,4.0,1.3,2},{6.3,2.5,4.9,1.5,2},{6.1,2.8,4.7,1.2,2},
        {6.4,2.9,4.3,1.3,2},{6.6,3.0,4.4,1.4,2},{6.8,2.8,4.8,1.4,2},{6.7,3.0,5.0,1.7,2},
        {6.0,2.9,4.5,1.5,2},{5.7,2.6,3.5,1.0,2},{5.5,2.4,3.8,1.1,2},{5.5,2.4,3.7,1.0,2},
        {5.8,2.7,3.9,1.2,2},{6.0,2.7,5.1,1.6,2},{5.4,3.0,4.5,1.5,2},{6.0,3.4,4.5,1.6,2},
        {6.7,3.1,4.7,1.5,2},{6.3,2.3,4.4,1.3,2},{5.6,3.0,4.1,1.3,2},{5.5,2.5,4.0,1.3,2},
        {5.5,2.6,4.4,1.2,2},{6.1,3.0,4.6,1.4,2},{5.8,2.6,4.0,1.2,2},{5.0,2.3,3.3,1.0,2},
        {5.6,2.7,4.2,1.3,2},{5.7,3.0,4.2,1.2,2},{5.7,2.9,4.2,1.3,2},{6.2,2.9,4.3,1.3,2},
        {5.1,2.5,3.0,1.1,2},{5.7,2.8,4.1,1.3,2},

        {6.3,3.3,6.0,2.5,3},{5.8,2.7,5.1,1.9,3},{7.1,3.0,5.9,2.1,3},{6.3,2.9,5.6,1.8,3},
        {6.5,3.0,5.8,2.2,3},{7.6,3.0,6.6,2.1,3},{4.9,2.5,4.5,1.7,3},{7.3,2.9,6.3,1.8,3},
        {6.7,2.5,5.8,1.8,3},{7.2,3.6,6.1,2.5,3},{6.5,3.2,5.1,2.0,3},{6.4,2.7,5.3,1.9,3},
        {6.8,3.0,5.5,2.1,3},{5.7,2.5,5.0,2.0,3},{5.8,2.8,5.1,2.4,3},{6.4,3.2,5.3,2.3,3},
        {6.5,3.0,5.5,1.8,3},{7.7,3.8,6.7,2.2,3},{7.7,2.6,6.9,2.3,3},{6.0,2.2,5.0,1.5,3},
        {6.9,3.2,5.7,2.3,3},{5.6,2.8,4.9,2.0,3},{7.7,2.8,6.7,2.0,3},{6.3,2.7,4.9,1.8,3},
        {6.7,3.3,5.7,2.1,3},{7.2,3.2,6.0,1.8,3},{6.2,2.8,4.8,1.8,3},{6.1,3.0,4.9,1.8,3},
        {6.4,2.8,5.6,2.1,3},{7.2,3.0,5.8,1.6,3},{7.4,2.8,6.1,1.9,3},{7.9,3.8,6.4,2.0,3},
        {6.4,2.8,5.6,2.2,3},{6.3,2.8,5.1,1.5,3},{6.1,2.6,5.6,1.4,3},{7.7,3.0,6.1,2.3,3},
        {6.3,3.4,5.6,2.4,3},{6.4,3.1,5.5,1.8,3},{6.0,3.0,4.8,1.8,3},{6.9,3.1,5.4,2.1,3},
        {6.7,3.1,5.6,2.4,3},{6.9,3.1,5.1,2.3,3},{5.8,2.7,5.1,1.9,3},{6.8,3.2,5.9,2.3,3},
        {6.7,3.3,5.7,2.5,3},{6.7,3.0,5.2,2.3,3},{6.3,2.5,5.0,1.9,3},{6.5,3.0,5.2,2.0,3},
        {6.2,3.4,5.4,2.3,3},{5.9,3.0,5.1,1.8,3}
    };

    /**
     * Builds, trains and reports on the Iris classification network.
     *
     * @param args unused
     * @throws Exception if the training log file cannot be opened or any
     *                   network operation fails
     */
    public static void main(String[] args) throws Exception {
        // Split the raw data into the input matrix and the class labels; the
        // label column directly follows the N_INPUTS input columns.
        double[][] xData = new double[N_OBS][];
        int[] yData = new int[N_OBS];
        for (int i = 0; i < N_OBS; i++) {
            xData[i] = java.util.Arrays.copyOf(IRIS_DATA[i], N_INPUTS);
            yData[i] = (int) IRIS_DATA[i][N_INPUTS];
        }

        // Create network: 4 inputs -> 4 logistic -> 3 logistic -> 3 softmax.
        FeedForwardNetwork network = new FeedForwardNetwork();
        network.getInputLayer().createInputs(N_INPUTS);
        network.createHiddenLayer().createPerceptrons(4, Activation.LOGISTIC, 0.0);
        network.createHiddenLayer().createPerceptrons(3, Activation.LOGISTIC, 0.0);
        network.getOutputLayer().createPerceptrons(N_OUTPUTS, Activation.SOFTMAX, 0.0);
        network.linkAll();

        MultiClassification classification = new MultiClassification(network);

        // Create trainer, using the error function supplied by the classifier.
        QuasiNewtonTrainer trainer = new QuasiNewtonTrainer();
        trainer.setError(classification.getError());
        trainer.setMaximumTrainingIterations(1000);

        // If tracing is requested, send the training log to a file.
        Handler handler = null;
        if (TRACE) {
            handler = new FileHandler("ClassificationNetworkTraining.log");
            handler.setFormatter(QuasiNewtonTrainer.getFormatter());
            TRAINING_LOGGER.setLevel(Level.FINEST);
            TRAINING_LOGGER.addHandler(handler);
        }
        try {
            // Train network
            long t0 = System.currentTimeMillis();
            classification.train(trainer, xData, yData);

            // stats[0] is the cross-entropy error, stats[1] the error rate.
            double[] stats = classification.computeStatistics(xData, yData);
            printStatistics(stats,
                    "--> Cross-entropy error:        ",
                    "--> Classification error rate:  ");

            // Display each trained weight next to its error gradient.
            double[] weight = network.getWeights();
            double[] gradient = trainer.getErrorGradient();
            double[][] wg = new double[weight.length][2];
            for (int i = 0; i < weight.length; i++) {
                wg[i][0] = weight[i];
                wg[i][1] = gradient[i];
            }
            PrintMatrixFormat pmf = new PrintMatrixFormat();
            pmf.setNumberFormat(new java.text.DecimalFormat("0.000000"));
            pmf.setColumnLabels(new String[]{"Weights", "Gradients"});
            new PrintMatrix().print(pmf, wg);

            // Display every pattern with its expected and predicted class.
            double[][] report = new double[N_OBS][N_INPUTS + 2];
            for (int i = 0; i < N_OBS; i++) {
                System.arraycopy(xData[i], 0, report[i], 0, N_INPUTS);
                report[i][N_INPUTS] = yData[i];
                report[i][N_INPUTS + 1] = classification.predictedClass(xData[i]);
            }
            pmf = new PrintMatrixFormat();
            pmf.setColumnLabels(new String[]{
                "Sepal Length",
                "Sepal Width",
                "Petal Length",
                "Petal Width",
                "Expected",
                "Predicted"});
            new PrintMatrix("Forecast").print(pmf, report);

            // DISPLAY CLASSIFICATION STATISTICS
            // Only predictions and printing happen after training, so the
            // statistics computed above are reused instead of recomputed.
            printStatistics(stats,
                    "--> Cross-Entropy Error:      ",
                    "--> Classification Error:     ");

            // Elapsed wall-clock time in seconds, measured from the start of
            // training through the reports above (not training alone).
            long t1 = System.currentTimeMillis();
            double time = (t1 - t0) / 1000.0;
            System.out.println("****************Time:  " + time);

            System.out.println("Cross-Entropy Error Value = " + trainer.getErrorValue());
        } finally {
            // Detach and close the log file once all output has been produced.
            if (handler != null) {
                TRAINING_LOGGER.removeHandler(handler);
                handler.close();
            }
        }
    }

    /**
     * Prints a statistics summary framed by separator lines, followed by a
     * blank line.
     *
     * @param stats      statistics as returned by {@code computeStatistics}:
     *                   element 0 is the cross-entropy error, element 1 the
     *                   classification error rate
     * @param errorLabel label printed before the cross-entropy error
     * @param rateLabel  label printed before the classification error rate
     */
    private static void printStatistics(double[] stats, String errorLabel, String rateLabel) {
        System.out.println(SEPARATOR);
        System.out.println(errorLabel + (float) stats[0]);
        System.out.println(rateLabel + (float) stats[1]);
        System.out.println(SEPARATOR);
        System.out.println("");
    }
}

Output

***********************************************
--> Cross-entropy error:        4.640623
--> Classification error rate:  0.006666667
***********************************************

       Weights     Gradients  
 0     -51.777881  -0.021660  
 1     605.119380   0.000000  
 2    -284.226877   0.000000  
 3     327.038883   0.000000  
 4     -41.160485  -0.009887  
 5    -867.891312   0.000000  
 6   -1210.846071   0.000000  
 7    -994.103717   0.000000  
 8      73.932788  -0.016740  
 9    -346.829319   0.000000  
10     704.482597   0.000000  
11    -497.908892   0.000000  
12      51.636506  -0.006301  
13    1943.984336   0.000000  
14    1516.711136   0.000000  
15    1935.687178   0.000000  
16      -3.143561  -2.271656  
17    -443.852301  -7.201949  
18     242.475544  -0.000024  
19      23.461487  -2.272793  
20     189.287779  -7.201954  
21     260.386655  -0.096456  
22     564.420647  -2.272793  
23     607.227248  -7.201954  
24     -62.368750  -0.096456  
25     163.370794  -2.272793  
26     216.054929  -7.201954  
27     296.537883  -0.096456  
28  -15686.506783   0.000000  
29    3478.164215   0.004606  
30   12209.342568  -0.004606  
31  -15443.797985   0.000000  
32    4719.334347   0.002674  
33   10725.463639  -0.002674  
34  -15303.926099   0.000000  
35    3602.472102   0.004863  
36   11702.453998  -0.004863  
37     -19.854440  -0.003322  
38     965.005400   0.000000  
39     874.394173   0.000000  
40     898.666721   0.000000  
41    -745.305267  -2.272793  
42    -568.545362  -7.201954  
43    -494.170957  -0.096456  
44   36175.248628   0.000000  
45   -8292.572938   0.004882  
46  -27882.675691  -0.004882  

                                    Forecast
     Sepal Length  Sepal Width  Petal Length  Petal Width  Expected  Predicted  
  0      5.1           3.5          1.4           0.2         1          1      
  1      4.9           3            1.4           0.2         1          1      
  2      4.7           3.2          1.3           0.2         1          1      
  3      4.6           3.1          1.5           0.2         1          1      
  4      5             3.6          1.4           0.2         1          1      
  5      5.4           3.9          1.7           0.4         1          1      
  6      4.6           3.4          1.4           0.3         1          1      
  7      5             3.4          1.5           0.2         1          1      
  8      4.4           2.9          1.4           0.2         1          1      
  9      4.9           3.1          1.5           0.1         1          1      
 10      5.4           3.7          1.5           0.2         1          1      
 11      4.8           3.4          1.6           0.2         1          1      
 12      4.8           3            1.4           0.1         1          1      
 13      4.3           3            1.1           0.1         1          1      
 14      5.8           4            1.2           0.2         1          1      
 15      5.7           4.4          1.5           0.4         1          1      
 16      5.4           3.9          1.3           0.4         1          1      
 17      5.1           3.5          1.4           0.3         1          1      
 18      5.7           3.8          1.7           0.3         1          1      
 19      5.1           3.8          1.5           0.3         1          1      
 20      5.4           3.4          1.7           0.2         1          1      
 21      5.1           3.7          1.5           0.4         1          1      
 22      4.6           3.6          1             0.2         1          1      
 23      5.1           3.3          1.7           0.5         1          1      
 24      4.8           3.4          1.9           0.2         1          1      
 25      5             3            1.6           0.2         1          1      
 26      5             3.4          1.6           0.4         1          1      
 27      5.2           3.5          1.5           0.2         1          1      
 28      5.2           3.4          1.4           0.2         1          1      
 29      4.7           3.2          1.6           0.2         1          1      
 30      4.8           3.1          1.6           0.2         1          1      
 31      5.4           3.4          1.5           0.4         1          1      
 32      5.2           4.1          1.5           0.1         1          1      
 33      5.5           4.2          1.4           0.2         1          1      
 34      4.9           3.1          1.5           0.1         1          1      
 35      5             3.2          1.2           0.2         1          1      
 36      5.5           3.5          1.3           0.2         1          1      
 37      4.9           3.1          1.5           0.1         1          1      
 38      4.4           3            1.3           0.2         1          1      
 39      5.1           3.4          1.5           0.2         1          1      
 40      5             3.5          1.3           0.3         1          1      
 41      4.5           2.3          1.3           0.3         1          1      
 42      4.4           3.2          1.3           0.2         1          1      
 43      5             3.5          1.6           0.6         1          1      
 44      5.1           3.8          1.9           0.4         1          1      
 45      4.8           3            1.4           0.3         1          1      
 46      5.1           3.8          1.6           0.2         1          1      
 47      4.6           3.2          1.4           0.2         1          1      
 48      5.3           3.7          1.5           0.2         1          1      
 49      5             3.3          1.4           0.2         1          1      
 50      7             3.2          4.7           1.4         2          2      
 51      6.4           3.2          4.5           1.5         2          2      
 52      6.9           3.1          4.9           1.5         2          2      
 53      5.5           2.3          4             1.3         2          2      
 54      6.5           2.8          4.6           1.5         2          2      
 55      5.7           2.8          4.5           1.3         2          2      
 56      6.3           3.3          4.7           1.6         2          2      
 57      4.9           2.4          3.3           1           2          2      
 58      6.6           2.9          4.6           1.3         2          2      
 59      5.2           2.7          3.9           1.4         2          2      
 60      5             2            3.5           1           2          2      
 61      5.9           3            4.2           1.5         2          2      
 62      6             2.2          4             1           2          2      
 63      6.1           2.9          4.7           1.4         2          2      
 64      5.6           2.9          3.6           1.3         2          2      
 65      6.7           3.1          4.4           1.4         2          2      
 66      5.6           3            4.5           1.5         2          2      
 67      5.8           2.7          4.1           1           2          2      
 68      6.2           2.2          4.5           1.5         2          2      
 69      5.6           2.5          3.9           1.1         2          2      
 70      5.9           3.2          4.8           1.8         2          2      
 71      6.1           2.8          4             1.3         2          2      
 72      6.3           2.5          4.9           1.5         2          2      
 73      6.1           2.8          4.7           1.2         2          2      
 74      6.4           2.9          4.3           1.3         2          2      
 75      6.6           3            4.4           1.4         2          2      
 76      6.8           2.8          4.8           1.4         2          2      
 77      6.7           3            5             1.7         2          2      
 78      6             2.9          4.5           1.5         2          2      
 79      5.7           2.6          3.5           1           2          2      
 80      5.5           2.4          3.8           1.1         2          2      
 81      5.5           2.4          3.7           1           2          2      
 82      5.8           2.7          3.9           1.2         2          2      
 83      6             2.7          5.1           1.6         2          3      
 84      5.4           3            4.5           1.5         2          2      
 85      6             3.4          4.5           1.6         2          2      
 86      6.7           3.1          4.7           1.5         2          2      
 87      6.3           2.3          4.4           1.3         2          2      
 88      5.6           3            4.1           1.3         2          2      
 89      5.5           2.5          4             1.3         2          2      
 90      5.5           2.6          4.4           1.2         2          2      
 91      6.1           3            4.6           1.4         2          2      
 92      5.8           2.6          4             1.2         2          2      
 93      5             2.3          3.3           1           2          2      
 94      5.6           2.7          4.2           1.3         2          2      
 95      5.7           3            4.2           1.2         2          2      
 96      5.7           2.9          4.2           1.3         2          2      
 97      6.2           2.9          4.3           1.3         2          2      
 98      5.1           2.5          3             1.1         2          2      
 99      5.7           2.8          4.1           1.3         2          2      
100      6.3           3.3          6             2.5         3          3      
101      5.8           2.7          5.1           1.9         3          3      
102      7.1           3            5.9           2.1         3          3      
103      6.3           2.9          5.6           1.8         3          3      
104      6.5           3            5.8           2.2         3          3      
105      7.6           3            6.6           2.1         3          3      
106      4.9           2.5          4.5           1.7         3          3      
107      7.3           2.9          6.3           1.8         3          3      
108      6.7           2.5          5.8           1.8         3          3      
109      7.2           3.6          6.1           2.5         3          3      
110      6.5           3.2          5.1           2           3          3      
111      6.4           2.7          5.3           1.9         3          3      
112      6.8           3            5.5           2.1         3          3      
113      5.7           2.5          5             2           3          3      
114      5.8           2.8          5.1           2.4         3          3      
115      6.4           3.2          5.3           2.3         3          3      
116      6.5           3            5.5           1.8         3          3      
117      7.7           3.8          6.7           2.2         3          3      
118      7.7           2.6          6.9           2.3         3          3      
119      6             2.2          5             1.5         3          3      
120      6.9           3.2          5.7           2.3         3          3      
121      5.6           2.8          4.9           2           3          3      
122      7.7           2.8          6.7           2           3          3      
123      6.3           2.7          4.9           1.8         3          3      
124      6.7           3.3          5.7           2.1         3          3      
125      7.2           3.2          6             1.8         3          3      
126      6.2           2.8          4.8           1.8         3          3      
127      6.1           3            4.9           1.8         3          3      
128      6.4           2.8          5.6           2.1         3          3      
129      7.2           3            5.8           1.6         3          3      
130      7.4           2.8          6.1           1.9         3          3      
131      7.9           3.8          6.4           2           3          3      
132      6.4           2.8          5.6           2.2         3          3      
133      6.3           2.8          5.1           1.5         3          3      
134      6.1           2.6          5.6           1.4         3          3      
135      7.7           3            6.1           2.3         3          3      
136      6.3           3.4          5.6           2.4         3          3      
137      6.4           3.1          5.5           1.8         3          3      
138      6             3            4.8           1.8         3          3      
139      6.9           3.1          5.4           2.1         3          3      
140      6.7           3.1          5.6           2.4         3          3      
141      6.9           3.1          5.1           2.3         3          3      
142      5.8           2.7          5.1           1.9         3          3      
143      6.8           3.2          5.9           2.3         3          3      
144      6.7           3.3          5.7           2.5         3          3      
145      6.7           3            5.2           2.3         3          3      
146      6.3           2.5          5             1.9         3          3      
147      6.5           3            5.2           2           3          3      
148      6.2           3.4          5.4           2.3         3          3      
149      5.9           3            5.1           1.8         3          3      

***********************************************
--> Cross-Entropy Error:      4.640623
--> Classification Error:     0.006666667
***********************************************

****************Time:  2.094
Cross-Entropy Error Value = 4.6406232788035595
Link to Java source.