package com.imsl.test.example.datamining.neural;
import com.imsl.datamining.neural.Activation;
import com.imsl.datamining.neural.FeedForwardNetwork;
import com.imsl.datamining.neural.MultiClassification;
import com.imsl.datamining.neural.QuasiNewtonTrainer;
import com.imsl.math.PrintMatrix;
import com.imsl.math.PrintMatrixFormat;
import java.io.Serializable;
import java.util.logging.FileHandler;
import java.util.logging.Handler;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
 * Trains a 3-layer network on Fisher's iris data.
 * <p>
 * This example trains a 3-layer network using Fisher's Iris data with four
 * continuous input attributes and three output classifications.
 * Fisher's Iris data is perhaps the best-known data set in the pattern
 * recognition literature. Published in 1936, Fisher's paper is a classic in
 * the field. The data set contains 3 classes of 50 instances each, where
 * each class refers to a type of iris plant. The continuous attributes are
 * measurements of petal length, petal width, sepal length, and sepal width.
 * <p>
 * The structure of the network consists of four input nodes and three
 * layers, with four perceptrons in the first hidden layer, three perceptrons
 * in the second hidden layer, and three in the output layer.
 * <p>
 * The four input attributes represent:
 * <ul>
 * <li>Sepal length</li>
 * <li>Sepal width</li>
 * <li>Petal length</li>
 * <li>Petal width</li>
 * </ul>
 * <p>
 * The output attribute represents the class of the iris plant and is encoded
 * using binary encoding:
 * <ul>
 * <li>Iris Setosa</li>
 * <li>Iris Versicolour</li>
 * <li>Iris Virginica</li>
 * </ul>
 * <p>
 * With 4 inputs, 4 and 3 perceptrons in the hidden layers, and 3 perceptrons
 * in the output layer, there are a total of 47 weights in this network:
 * 4*4 + 4*3 + 3*3 = 37 connection weights plus 4 + 3 + 3 = 10 bias weights.
 * Both hidden layers use the logistic activation function. Since the target
 * output is a multi-classification, the softmax activation function is used
 * in the output layer, and the {@link MultiClassification} error function
 * class is used by the trainer. The {@link MultiClassification} error class
 * combines the cross-entropy error calculations and the softmax function.
*/
public class MultiClassificationEx1 implements Serializable {
private static int nObs = 150;
private static int nInputs = 4;
private static int nOutputs = 3;
private static int nPerceptrons1 = 4;
private static int nPerceptrons2 = 3;
private static boolean trace = false; // turn on/off tracing
// irisData: the first 4 columns are the continuous input attributes and
// the 5th column is the classification category (1-3).
private static double[][] irisData = {
{5.1, 3.5, 1.4, 0.2, 1}, {4.9, 3.0, 1.4, 0.2, 1},
{4.7, 3.2, 1.3, 0.2, 1}, {4.6, 3.1, 1.5, 0.2, 1},
{5.0, 3.6, 1.4, 0.2, 1}, {5.4, 3.9, 1.7, 0.4, 1},
{4.6, 3.4, 1.4, 0.3, 1}, {5.0, 3.4, 1.5, 0.2, 1},
{4.4, 2.9, 1.4, 0.2, 1}, {4.9, 3.1, 1.5, 0.1, 1},
{5.4, 3.7, 1.5, 0.2, 1}, {4.8, 3.4, 1.6, 0.2, 1},
{4.8, 3.0, 1.4, 0.1, 1}, {4.3, 3.0, 1.1, 0.1, 1},
{5.8, 4.0, 1.2, 0.2, 1}, {5.7, 4.4, 1.5, 0.4, 1},
{5.4, 3.9, 1.3, 0.4, 1}, {5.1, 3.5, 1.4, 0.3, 1},
{5.7, 3.8, 1.7, 0.3, 1}, {5.1, 3.8, 1.5, 0.3, 1},
{5.4, 3.4, 1.7, 0.2, 1}, {5.1, 3.7, 1.5, 0.4, 1},
{4.6, 3.6, 1.0, 0.2, 1}, {5.1, 3.3, 1.7, 0.5, 1},
{4.8, 3.4, 1.9, 0.2, 1}, {5.0, 3.0, 1.6, 0.2, 1},
{5.0, 3.4, 1.6, 0.4, 1}, {5.2, 3.5, 1.5, 0.2, 1},
{5.2, 3.4, 1.4, 0.2, 1}, {4.7, 3.2, 1.6, 0.2, 1},
{4.8, 3.1, 1.6, 0.2, 1}, {5.4, 3.4, 1.5, 0.4, 1},
{5.2, 4.1, 1.5, 0.1, 1}, {5.5, 4.2, 1.4, 0.2, 1},
{4.9, 3.1, 1.5, 0.1, 1}, {5.0, 3.2, 1.2, 0.2, 1},
{5.5, 3.5, 1.3, 0.2, 1}, {4.9, 3.1, 1.5, 0.1, 1},
{4.4, 3.0, 1.3, 0.2, 1}, {5.1, 3.4, 1.5, 0.2, 1},
{5.0, 3.5, 1.3, 0.3, 1}, {4.5, 2.3, 1.3, 0.3, 1},
{4.4, 3.2, 1.3, 0.2, 1}, {5.0, 3.5, 1.6, 0.6, 1},
{5.1, 3.8, 1.9, 0.4, 1}, {4.8, 3.0, 1.4, 0.3, 1},
{5.1, 3.8, 1.6, 0.2, 1}, {4.6, 3.2, 1.4, 0.2, 1},
{5.3, 3.7, 1.5, 0.2, 1}, {5.0, 3.3, 1.4, 0.2, 1},
{7.0, 3.2, 4.7, 1.4, 2}, {6.4, 3.2, 4.5, 1.5, 2},
{6.9, 3.1, 4.9, 1.5, 2}, {5.5, 2.3, 4.0, 1.3, 2},
{6.5, 2.8, 4.6, 1.5, 2}, {5.7, 2.8, 4.5, 1.3, 2},
{6.3, 3.3, 4.7, 1.6, 2}, {4.9, 2.4, 3.3, 1.0, 2},
{6.6, 2.9, 4.6, 1.3, 2}, {5.2, 2.7, 3.9, 1.4, 2},
{5.0, 2.0, 3.5, 1.0, 2}, {5.9, 3.0, 4.2, 1.5, 2},
{6.0, 2.2, 4.0, 1.0, 2}, {6.1, 2.9, 4.7, 1.4, 2},
{5.6, 2.9, 3.6, 1.3, 2}, {6.7, 3.1, 4.4, 1.4, 2},
{5.6, 3.0, 4.5, 1.5, 2}, {5.8, 2.7, 4.1, 1.0, 2},
{6.2, 2.2, 4.5, 1.5, 2}, {5.6, 2.5, 3.9, 1.1, 2},
{5.9, 3.2, 4.8, 1.8, 2}, {6.1, 2.8, 4.0, 1.3, 2},
{6.3, 2.5, 4.9, 1.5, 2}, {6.1, 2.8, 4.7, 1.2, 2},
{6.4, 2.9, 4.3, 1.3, 2}, {6.6, 3.0, 4.4, 1.4, 2},
{6.8, 2.8, 4.8, 1.4, 2}, {6.7, 3.0, 5.0, 1.7, 2},
{6.0, 2.9, 4.5, 1.5, 2}, {5.7, 2.6, 3.5, 1.0, 2},
{5.5, 2.4, 3.8, 1.1, 2}, {5.5, 2.4, 3.7, 1.0, 2},
{5.8, 2.7, 3.9, 1.2, 2}, {6.0, 2.7, 5.1, 1.6, 2},
{5.4, 3.0, 4.5, 1.5, 2}, {6.0, 3.4, 4.5, 1.6, 2},
{6.7, 3.1, 4.7, 1.5, 2}, {6.3, 2.3, 4.4, 1.3, 2},
{5.6, 3.0, 4.1, 1.3, 2}, {5.5, 2.5, 4.0, 1.3, 2},
{5.5, 2.6, 4.4, 1.2, 2}, {6.1, 3.0, 4.6, 1.4, 2},
{5.8, 2.6, 4.0, 1.2, 2}, {5.0, 2.3, 3.3, 1.0, 2},
{5.6, 2.7, 4.2, 1.3, 2}, {5.7, 3.0, 4.2, 1.2, 2},
{5.7, 2.9, 4.2, 1.3, 2}, {6.2, 2.9, 4.3, 1.3, 2},
{5.1, 2.5, 3.0, 1.1, 2}, {5.7, 2.8, 4.1, 1.3, 2},
{6.3, 3.3, 6.0, 2.5, 3}, {5.8, 2.7, 5.1, 1.9, 3},
{7.1, 3.0, 5.9, 2.1, 3}, {6.3, 2.9, 5.6, 1.8, 3},
{6.5, 3.0, 5.8, 2.2, 3}, {7.6, 3.0, 6.6, 2.1, 3},
{4.9, 2.5, 4.5, 1.7, 3}, {7.3, 2.9, 6.3, 1.8, 3},
{6.7, 2.5, 5.8, 1.8, 3}, {7.2, 3.6, 6.1, 2.5, 3},
{6.5, 3.2, 5.1, 2.0, 3}, {6.4, 2.7, 5.3, 1.9, 3},
{6.8, 3.0, 5.5, 2.1, 3}, {5.7, 2.5, 5.0, 2.0, 3},
{5.8, 2.8, 5.1, 2.4, 3}, {6.4, 3.2, 5.3, 2.3, 3},
{6.5, 3.0, 5.5, 1.8, 3}, {7.7, 3.8, 6.7, 2.2, 3},
{7.7, 2.6, 6.9, 2.3, 3}, {6.0, 2.2, 5.0, 1.5, 3},
{6.9, 3.2, 5.7, 2.3, 3}, {5.6, 2.8, 4.9, 2.0, 3},
{7.7, 2.8, 6.7, 2.0, 3}, {6.3, 2.7, 4.9, 1.8, 3},
{6.7, 3.3, 5.7, 2.1, 3}, {7.2, 3.2, 6.0, 1.8, 3},
{6.2, 2.8, 4.8, 1.8, 3}, {6.1, 3.0, 4.9, 1.8, 3},
{6.4, 2.8, 5.6, 2.1, 3}, {7.2, 3.0, 5.8, 1.6, 3},
{7.4, 2.8, 6.1, 1.9, 3}, {7.9, 3.8, 6.4, 2.0, 3},
{6.4, 2.8, 5.6, 2.2, 3}, {6.3, 2.8, 5.1, 1.5, 3},
{6.1, 2.6, 5.6, 1.4, 3}, {7.7, 3.0, 6.1, 2.3, 3},
{6.3, 3.4, 5.6, 2.4, 3}, {6.4, 3.1, 5.5, 1.8, 3},
{6.0, 3.0, 4.8, 1.8, 3}, {6.9, 3.1, 5.4, 2.1, 3},
{6.7, 3.1, 5.6, 2.4, 3}, {6.9, 3.1, 5.1, 2.3, 3},
{5.8, 2.7, 5.1, 1.9, 3}, {6.8, 3.2, 5.9, 2.3, 3},
{6.7, 3.3, 5.7, 2.5, 3}, {6.7, 3.0, 5.2, 2.3, 3},
{6.3, 2.5, 5.0, 1.9, 3}, {6.5, 3.0, 5.2, 2.0, 3},
{6.2, 3.4, 5.4, 2.3, 3}, {5.9, 3.0, 5.1, 1.8, 3}
};
public static void main(String[] args) throws Exception {
double xData[][] = new double[nObs][nInputs];
int yData[] = new int[nObs];
for (int i = 0; i < nObs; i++) {
System.arraycopy(irisData[i], 0, xData[i], 0, nInputs);
yData[i] = (int) irisData[i][4];
}
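        // Note: yData holds the raw class labels 1-3 from the 5th column;
        // MultiClassification maps these labels onto the three softmax
        // output perceptrons during training.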
// Create network
FeedForwardNetwork network = new FeedForwardNetwork();
network.getInputLayer().createInputs(nInputs);
network.createHiddenLayer().
createPerceptrons(nPerceptrons1, Activation.LOGISTIC, 0.0);
network.createHiddenLayer().
createPerceptrons(nPerceptrons2, Activation.LOGISTIC, 0.0);
network.getOutputLayer().
createPerceptrons(nOutputs, Activation.SOFTMAX, 0.0);
network.linkAll();
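        // linkAll() fully connects each layer to the next; together with the
        // bias terms this yields the 47 weights described in the class comment.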
MultiClassification classification = new MultiClassification(network);
// Create trainer
QuasiNewtonTrainer trainer = new QuasiNewtonTrainer();
trainer.setError(classification.getError());
trainer.setMaximumTrainingIterations(1000);
// Set up training logger if trace = true
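        // When enabled, the trainer's Level.FINEST log records are written to
        // ClassificationNetworkTraining.log, formatted by the trainer's own
        // log formatter.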
if (trace) {
Handler handler
= new FileHandler("ClassificationNetworkTraining.log");
Logger logger = Logger.getLogger("com.imsl.datamining.neural");
logger.setLevel(Level.FINEST);
logger.addHandler(handler);
handler.setFormatter(QuasiNewtonTrainer.getFormatter());
}
// Train network
long t0 = System.currentTimeMillis();
classification.train(trainer, xData, yData);
long t1 = System.currentTimeMillis();
        double time = (t1 - t0) / 1000.0;
        System.out.println("****************Time: " + time);
// Display network errors
double stats[] = classification.computeStatistics(xData, yData);
System.out.println("***********************************************");
System.out.println("--> Cross-entropy error: "
+ (float) stats[0]);
System.out.println("--> Classification error rate: "
+ (float) stats[1]);
System.out.println("***********************************************");
System.out.println("");
// Display weights and gradients
double weight[] = network.getWeights();
double gradient[] = trainer.getErrorGradient();
double wg[][] = new double[weight.length][2];
for (int i = 0; i < weight.length; i++) {
wg[i][0] = weight[i];
wg[i][1] = gradient[i];
}
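        // If training converged to a local minimum of the error surface, the
        // gradient entries should all be near zero.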
PrintMatrixFormat pmf = new PrintMatrixFormat();
pmf.setNumberFormat(new java.text.DecimalFormat("0.000000"));
pmf.setColumnLabels(new String[]{"Weights", "Gradients"});
new PrintMatrix().print(pmf, wg);
// Display fitted values
double report[][] = new double[nObs][nInputs + 5];
for (int i = 0; i < nObs; i++) {
System.arraycopy(xData[i], 0, report[i], 0, nInputs);
            double[] probabilities = classification.probabilities(xData[i]);
            report[i][nInputs] = probabilities[0];
            report[i][nInputs + 1] = probabilities[1];
            report[i][nInputs + 2] = probabilities[2];
            report[i][nInputs + 3] = irisData[i][4];
            report[i][nInputs + 4] = classification.predictedClass(xData[i]);
}
pmf = new PrintMatrixFormat();
pmf.setColumnLabels(new String[]{
"Sepal Length",
"Sepal Width",
"Petal Length",
"Petal Width",
"P(Y=1)",
"P(Y=2)",
"P(Y=3)",
"Actual",
"Predicted"}
);
new PrintMatrix("Fitted values").print(pmf, report);
}
}