package com.imsl.test.example.datamining.neural;

import com.imsl.datamining.neural.Activation;
import com.imsl.datamining.neural.FeedForwardNetwork;
import com.imsl.datamining.neural.MultiClassification;
import com.imsl.datamining.neural.QuasiNewtonTrainer;
import com.imsl.math.PrintMatrix;
import com.imsl.math.PrintMatrixFormat;
import java.io.Serializable;
import java.util.logging.FileHandler;
import java.util.logging.Handler;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Trains a 3-layer network on Fisher's iris data.
 * <p>
 * This example trains a 3-layer network using Fisher's iris data, which has
 * four continuous input attributes and three output classifications.
 * Fisher's iris data is perhaps the best-known data set in the pattern
 * recognition literature; published in 1936, Fisher's paper is a classic in
 * the field. The data set contains 3 classes of 50 instances each, where each
 * class refers to a type of iris plant. The continuous attributes are
 * measurements of sepal length, sepal width, petal length, and petal width.
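 * For example, the first row of the {@code irisData} array below,
 * {5.1, 3.5, 1.4, 0.2, 1}, records a flower with sepal length 5.1, sepal
 * width 3.5, petal length 1.4, and petal width 0.2, belonging to class 1
 * (Iris Setosa).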

 * <p>
 * The structure of the network consists of four input nodes and three
 * layers, with four perceptrons in the first hidden layer, three perceptrons
 * in the second hidden layer, and three in the output layer.
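 * Schematically, with the activation function used by each layer:
 * <pre>
 * 4 inputs --> [4 perceptrons] --> [3 perceptrons] --> [3 perceptrons]
 *                 (logistic)          (logistic)           (softmax)
 * </pre>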

 * <p>
 * The four input attributes represent:
 * <ol>
 * <li>Sepal length</li>
 * <li>Sepal width</li>
 * <li>Petal length</li>
 * <li>Petal width</li>
 * </ol>
 * <p>
 * The output attribute represents the class of the iris plant and is encoded
 * using binary encoding:
 * <ol>
 * <li>Iris Setosa</li>
 * <li>Iris Versicolour</li>
 * <li>Iris Virginica</li>
 * </ol>
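 * Under binary (one-of-N) encoding, each class code in the data corresponds
 * to a 0/1 target vector with a single 1 in the position of the class:
 * <pre>
 * 1 (Iris Setosa)       {1, 0, 0}
 * 2 (Iris Versicolour)  {0, 1, 0}
 * 3 (Iris Virginica)    {0, 0, 1}
 * </pre>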

 * <p>
 * With 4 inputs, 4 and 3 perceptrons in the hidden layers, and 3 perceptrons
 * in the output layer, there are a total of 47 weights in this network,
 * including the bias weights: (4+1)*4 = 20 weights feed the first hidden
 * layer, (4+1)*3 = 15 feed the second, and (3+1)*3 = 12 feed the output
 * layer. All hidden layers use the logistic activation function. Since the
 * target output is a multi-classification, the softmax activation function
 * is used in the output layer and the {@code MultiClassification} error
 * function class is used by the trainer. The {@code MultiClassification}
 * error class combines the cross-entropy error calculation and the softmax
 * function.
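 * <p>
 * For reference, the two pieces combined by the error class are, in their
 * usual formulation (z_i is the net input of the i-th output perceptron,
 * y_i the corresponding class probability, and t_i the 0/1 target; this is
 * illustrative, not the library's source):
 * <pre>
 * softmax:        y_i = exp(z_i) / (exp(z_1) + exp(z_2) + exp(z_3))
 * cross-entropy:  E = -(t_1*log(y_1) + t_2*log(y_2) + t_3*log(y_3))
 * </pre>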

 *
 * @see Code
 * @see Output
 */
public class MultiClassificationEx1 implements Serializable {

    private static int nObs = 150;
    private static int nInputs = 4;
    private static int nOutputs = 3;
    private static int nPerceptrons1 = 4;
    private static int nPerceptrons2 = 3;
    private static boolean trace = false; // turn on/off tracing

    // irisData: the first 4 columns are the continuous input attributes and
    // the 5th column is the classification category (1-3).
    private static double[][] irisData = {
        {5.1, 3.5, 1.4, 0.2, 1}, {4.9, 3.0, 1.4, 0.2, 1}, {4.7, 3.2, 1.3, 0.2, 1}, {4.6, 3.1, 1.5, 0.2, 1}, {5.0, 3.6, 1.4, 0.2, 1},
        {5.4, 3.9, 1.7, 0.4, 1}, {4.6, 3.4, 1.4, 0.3, 1}, {5.0, 3.4, 1.5, 0.2, 1}, {4.4, 2.9, 1.4, 0.2, 1}, {4.9, 3.1, 1.5, 0.1, 1},
        {5.4, 3.7, 1.5, 0.2, 1}, {4.8, 3.4, 1.6, 0.2, 1}, {4.8, 3.0, 1.4, 0.1, 1}, {4.3, 3.0, 1.1, 0.1, 1}, {5.8, 4.0, 1.2, 0.2, 1},
        {5.7, 4.4, 1.5, 0.4, 1}, {5.4, 3.9, 1.3, 0.4, 1}, {5.1, 3.5, 1.4, 0.3, 1}, {5.7, 3.8, 1.7, 0.3, 1}, {5.1, 3.8, 1.5, 0.3, 1},
        {5.4, 3.4, 1.7, 0.2, 1}, {5.1, 3.7, 1.5, 0.4, 1}, {4.6, 3.6, 1.0, 0.2, 1}, {5.1, 3.3, 1.7, 0.5, 1}, {4.8, 3.4, 1.9, 0.2, 1},
        {5.0, 3.0, 1.6, 0.2, 1}, {5.0, 3.4, 1.6, 0.4, 1}, {5.2, 3.5, 1.5, 0.2, 1}, {5.2, 3.4, 1.4, 0.2, 1}, {4.7, 3.2, 1.6, 0.2, 1},
        {4.8, 3.1, 1.6, 0.2, 1}, {5.4, 3.4, 1.5, 0.4, 1}, {5.2, 4.1, 1.5, 0.1, 1}, {5.5, 4.2, 1.4, 0.2, 1}, {4.9, 3.1, 1.5, 0.1, 1},
        {5.0, 3.2, 1.2, 0.2, 1}, {5.5, 3.5, 1.3, 0.2, 1}, {4.9, 3.1, 1.5, 0.1, 1}, {4.4, 3.0, 1.3, 0.2, 1}, {5.1, 3.4, 1.5, 0.2, 1},
        {5.0, 3.5, 1.3, 0.3, 1}, {4.5, 2.3, 1.3, 0.3, 1}, {4.4, 3.2, 1.3, 0.2, 1}, {5.0, 3.5, 1.6, 0.6, 1}, {5.1, 3.8, 1.9, 0.4, 1},
        {4.8, 3.0, 1.4, 0.3, 1}, {5.1, 3.8, 1.6, 0.2, 1}, {4.6, 3.2, 1.4, 0.2, 1}, {5.3, 3.7, 1.5, 0.2, 1}, {5.0, 3.3, 1.4, 0.2, 1},
        {7.0, 3.2, 4.7, 1.4, 2}, {6.4, 3.2, 4.5, 1.5, 2}, {6.9, 3.1, 4.9, 1.5, 2}, {5.5, 2.3, 4.0, 1.3, 2}, {6.5, 2.8, 4.6, 1.5, 2},
        {5.7, 2.8, 4.5, 1.3, 2}, {6.3, 3.3, 4.7, 1.6, 2}, {4.9, 2.4, 3.3, 1.0, 2}, {6.6, 2.9, 4.6, 1.3, 2}, {5.2, 2.7, 3.9, 1.4, 2},
        {5.0, 2.0, 3.5, 1.0, 2}, {5.9, 3.0, 4.2, 1.5, 2}, {6.0, 2.2, 4.0, 1.0, 2}, {6.1, 2.9, 4.7, 1.4, 2}, {5.6, 2.9, 3.6, 1.3, 2},
        {6.7, 3.1, 4.4, 1.4, 2}, {5.6, 3.0, 4.5, 1.5, 2}, {5.8, 2.7, 4.1, 1.0, 2}, {6.2, 2.2, 4.5, 1.5, 2}, {5.6, 2.5, 3.9, 1.1, 2},
        {5.9, 3.2, 4.8, 1.8, 2}, {6.1, 2.8, 4.0, 1.3, 2}, {6.3, 2.5, 4.9, 1.5, 2}, {6.1, 2.8, 4.7, 1.2, 2}, {6.4, 2.9, 4.3, 1.3, 2},
        {6.6, 3.0, 4.4, 1.4, 2}, {6.8, 2.8, 4.8, 1.4, 2}, {6.7, 3.0, 5.0, 1.7, 2}, {6.0, 2.9, 4.5, 1.5, 2}, {5.7, 2.6, 3.5, 1.0, 2},
        {5.5, 2.4, 3.8, 1.1, 2}, {5.5, 2.4, 3.7, 1.0, 2}, {5.8, 2.7, 3.9, 1.2, 2}, {6.0, 2.7, 5.1, 1.6, 2}, {5.4, 3.0, 4.5, 1.5, 2},
        {6.0, 3.4, 4.5, 1.6, 2}, {6.7, 3.1, 4.7, 1.5, 2}, {6.3, 2.3, 4.4, 1.3, 2}, {5.6, 3.0, 4.1, 1.3, 2}, {5.5, 2.5, 4.0, 1.3, 2},
        {5.5, 2.6, 4.4, 1.2, 2}, {6.1, 3.0, 4.6, 1.4, 2}, {5.8, 2.6, 4.0, 1.2, 2}, {5.0, 2.3, 3.3, 1.0, 2}, {5.6, 2.7, 4.2, 1.3, 2},
        {5.7, 3.0, 4.2, 1.2, 2}, {5.7, 2.9, 4.2, 1.3, 2}, {6.2, 2.9, 4.3, 1.3, 2}, {5.1, 2.5, 3.0, 1.1, 2}, {5.7, 2.8, 4.1, 1.3, 2},
        {6.3, 3.3, 6.0, 2.5, 3}, {5.8, 2.7, 5.1, 1.9, 3}, {7.1, 3.0, 5.9, 2.1, 3}, {6.3, 2.9, 5.6, 1.8, 3}, {6.5, 3.0, 5.8, 2.2, 3},
        {7.6, 3.0, 6.6, 2.1, 3}, {4.9, 2.5, 4.5, 1.7, 3}, {7.3, 2.9, 6.3, 1.8, 3}, {6.7, 2.5, 5.8, 1.8, 3}, {7.2, 3.6, 6.1, 2.5, 3},
        {6.5, 3.2, 5.1, 2.0, 3}, {6.4, 2.7, 5.3, 1.9, 3}, {6.8, 3.0, 5.5, 2.1, 3}, {5.7, 2.5, 5.0, 2.0, 3}, {5.8, 2.8, 5.1, 2.4, 3},
        {6.4, 3.2, 5.3, 2.3, 3}, {6.5, 3.0, 5.5, 1.8, 3}, {7.7, 3.8, 6.7, 2.2, 3}, {7.7, 2.6, 6.9, 2.3, 3}, {6.0, 2.2, 5.0, 1.5, 3},
        {6.9, 3.2, 5.7, 2.3, 3}, {5.6, 2.8, 4.9, 2.0, 3}, {7.7, 2.8, 6.7, 2.0, 3}, {6.3, 2.7, 4.9, 1.8, 3}, {6.7, 3.3, 5.7, 2.1, 3},
        {7.2, 3.2, 6.0, 1.8, 3}, {6.2, 2.8, 4.8, 1.8, 3}, {6.1, 3.0, 4.9, 1.8, 3}, {6.4, 2.8, 5.6, 2.1, 3}, {7.2, 3.0, 5.8, 1.6, 3},
        {7.4, 2.8, 6.1, 1.9, 3}, {7.9, 3.8, 6.4, 2.0, 3}, {6.4, 2.8, 5.6, 2.2, 3}, {6.3, 2.8, 5.1, 1.5, 3}, {6.1, 2.6, 5.6, 1.4, 3},
        {7.7, 3.0, 6.1, 2.3, 3}, {6.3, 3.4, 5.6, 2.4, 3}, {6.4, 3.1, 5.5, 1.8, 3}, {6.0, 3.0, 4.8, 1.8, 3}, {6.9, 3.1, 5.4, 2.1, 3},
        {6.7, 3.1, 5.6, 2.4, 3}, {6.9, 3.1, 5.1, 2.3, 3}, {5.8, 2.7, 5.1, 1.9, 3}, {6.8, 3.2, 5.9, 2.3, 3}, {6.7, 3.3, 5.7, 2.5, 3},
        {6.7, 3.0, 5.2, 2.3, 3}, {6.3, 2.5, 5.0, 1.9, 3}, {6.5, 3.0, 5.2, 2.0, 3}, {6.2, 3.4, 5.4, 2.3, 3}, {5.9, 3.0, 5.1, 1.8, 3}
    };

    public static void main(String[] args) throws Exception {
        double xData[][] = new double[nObs][nInputs];
        int yData[] = new int[nObs];
        for (int i = 0; i < nObs; i++) {
            System.arraycopy(irisData[i], 0, xData[i], 0, nInputs);
            yData[i] = (int) irisData[i][4];
        }

        // Create network
        FeedForwardNetwork network = new FeedForwardNetwork();
        network.getInputLayer().createInputs(nInputs);
        network.createHiddenLayer()
                .createPerceptrons(nPerceptrons1, Activation.LOGISTIC, 0.0);
        network.createHiddenLayer()
                .createPerceptrons(nPerceptrons2, Activation.LOGISTIC, 0.0);
        network.getOutputLayer()
                .createPerceptrons(nOutputs, Activation.SOFTMAX, 0.0);
        network.linkAll();
        MultiClassification classification = new MultiClassification(network);

        // Create trainer
        QuasiNewtonTrainer trainer = new QuasiNewtonTrainer();
        trainer.setError(classification.getError());
        trainer.setMaximumTrainingIterations(1000);

        // Set up training logger if trace = true
        if (trace) {
            Handler handler =
                    new FileHandler("ClassificationNetworkTraining.log");
            Logger logger = Logger.getLogger("com.imsl.datamining.neural");
            logger.setLevel(Level.FINEST);
            logger.addHandler(handler);
            handler.setFormatter(QuasiNewtonTrainer.getFormatter());
        }

        // Train network
        long t0 = System.currentTimeMillis();
        classification.train(trainer, xData, yData);
        long t1 = System.currentTimeMillis();
        double time = t1 - t0;
        time = time / 1000;
        System.out.println("****************Time: " + time);

        // Display network errors
        double stats[] = classification.computeStatistics(xData, yData);
        System.out.println("***********************************************");
        System.out.println("--> Cross-entropy error: " + (float) stats[0]);
        System.out.println("--> Classification error rate: " + (float) stats[1]);
        System.out.println("***********************************************");
        System.out.println("");

        // Display weights and gradients
        double weight[] = network.getWeights();
        double gradient[] = trainer.getErrorGradient();
        double wg[][] = new double[weight.length][2];
        for (int i = 0; i < weight.length; i++) {
            wg[i][0] = weight[i];
            wg[i][1] = gradient[i];
        }
        PrintMatrixFormat pmf = new PrintMatrixFormat();
        pmf.setNumberFormat(new java.text.DecimalFormat("0.000000"));
        pmf.setColumnLabels(new String[]{"Weights", "Gradients"});
        new PrintMatrix().print(pmf, wg);

        // Display fitted values
        double report[][] = new double[nObs][nInputs + 5];
        for (int i = 0; i < nObs; i++) {
            System.arraycopy(xData[i], 0, report[i], 0, nInputs);
            double[] probabilities = classification.probabilities(xData[i]);
            report[i][nInputs] = probabilities[0];
            report[i][nInputs + 1] = probabilities[1];
            report[i][nInputs + 2] = probabilities[2];
            report[i][nInputs + 3] = irisData[i][4];
            report[i][nInputs + 4] = classification.predictedClass(xData[i]);
        }
        pmf = new PrintMatrixFormat();
        pmf.setColumnLabels(new String[]{
            "Sepal Length", "Sepal Width", "Petal Length", "Petal Width",
            "P(Y=1)", "P(Y=2)", "P(Y=3)", "Actual", "Predicted"});
        new PrintMatrix("Fitted values").print(pmf, report);
    }
}