This example trains a 2-layer network using three binary inputs (X1, X2, X3) and one three-level classification variable (Y), where

Y = 1 if X1 = 1
Y = 2 if X2 = 1
Y = 3 if X3 = 1
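Because the output layer uses softmax activation, the three network outputs behave as class probabilities that sum to one, and the quasi-Newton trainer minimizes the cross-entropy error between those probabilities and the target classes. For reference, the standard softmax and cross-entropy definitions (with $z_k$ denoting the $k$-th output-layer potential for a training pattern) are

$$P(C_k) = \frac{e^{z_k}}{\sum_{j=1}^{3} e^{z_j}}, \qquad E = -\sum_{i=1}^{6} \sum_{k=1}^{3} t_{ik} \ln P_i(C_k),$$

where $t_{ik} = 1$ if pattern $i$ belongs to class $k$ and 0 otherwise. A perfect classifier drives the target-class probability of every pattern to 1, which is why both error statistics in the output below reach 0.0.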
import com.imsl.datamining.neural.*;
import com.imsl.math.PrintMatrix;
import com.imsl.math.PrintMatrixFormat;
import java.io.*;
import java.util.logging.*;

//*****************************************************************************
// Two-Layer FFN with 3 binary inputs (X1, X2, X3) and one three-level
// classification variable (Y)
// Y = 1 if X1 = 1
// Y = 2 if X2 = 1
// Y = 3 if X3 = 1
// (training_ex6)
//*****************************************************************************
public class MultiClassificationEx2 implements Serializable {

    private static int nObs = 6;          // number of training patterns
    private static int nInputs = 3;       // 3 inputs, all categorical
    private static int nOutputs = 3;      // 3 classes, one output perceptron each
    private static boolean trace = true;  // turns the training log on/off

    private static double xData[][] = {
        {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1}, {0, 0, 1}
    };
    private static int yData[] = {1, 1, 2, 2, 3, 3};

    // Initial network weights
    private static double weights[] = {
         1.29099444873580580000, -0.64549722436790280000, -0.64549722436790291000,
         0.00000000000000000000,  1.11803398874989490000, -1.11803398874989470000,
         0.57735026918962584000,  0.57735026918962584000,  0.57735026918962584000,
         0.33333333333333331000,  0.33333333333333331000,  0.33333333333333331000,
         0.33333333333333331000,  0.33333333333333331000,  0.33333333333333331000,
         0.33333333333333331000,  0.33333333333333331000,  0.33333333333333331000,
        -0.00000000000000005851, -0.00000000000000005851, -0.57735026918962573000,
         0.00000000000000000000,  0.00000000000000000000,  0.00000000000000000000
    };

    public static void main(String[] args) throws Exception {
        // Build the network: 3 inputs, one hidden layer of 3 linear
        // perceptrons, and 3 softmax output perceptrons
        FeedForwardNetwork network = new FeedForwardNetwork();
        network.getInputLayer().createInputs(nInputs);
        network.createHiddenLayer().createPerceptrons(3, Activation.LINEAR, 0.0);
        // Alternative hidden-layer configuration:
        //network.createHiddenLayer().createPerceptrons(4, Activation.TANH, 0.0);
        network.getOutputLayer().createPerceptrons(nOutputs,
                Activation.SOFTMAX, 0.0);
        network.linkAll();
        network.setWeights(weights);

        MultiClassification classification = new MultiClassification(network);

        // Configure the quasi-Newton trainer
        QuasiNewtonTrainer trainer = new QuasiNewtonTrainer();
        trainer.setError(classification.getError());
        trainer.setMaximumTrainingIterations(1000);
        trainer.setFalseConvergenceTolerance(1.0e-20);
        trainer.setGradientTolerance(1.0e-20);
        trainer.setRelativeTolerance(1.0e-20);
        trainer.setStepTolerance(1.0e-20);

        // If tracing is requested, set up the training logger
        if (trace) {
            Handler handler = new FileHandler("ClassificationNetworkEx2.log");
            Logger logger = Logger.getLogger("com.imsl.datamining.neural");
            logger.setLevel(Level.FINEST);
            logger.addHandler(handler);
            handler.setFormatter(QuasiNewtonTrainer.getFormatter());
        }

        // Train Network
        classification.train(trainer, xData, yData);

        // Display Network Errors
        double stats[] = classification.computeStatistics(xData, yData);
        System.out.println("***********************************************");
        System.out.println("--> Cross-Entropy Error: " + (float) stats[0]);
        System.out.println("--> Classification Error: " + (float) stats[1]);
        System.out.println("***********************************************");
        System.out.println();

        // Display the fitted weights and the error gradient at the solution
        double weight[] = network.getWeights();
        double gradient[] = trainer.getErrorGradient();
        double wg[][] = new double[weight.length][2];
        for (int i = 0; i < weight.length; i++) {
            wg[i][0] = weight[i];
            wg[i][1] = gradient[i];
        }
        PrintMatrixFormat pmf = new PrintMatrixFormat();
        pmf.setNumberFormat(new java.text.DecimalFormat("0.000000"));
        pmf.setColumnLabels(new String[]{"Weights", "Gradients"});
        new PrintMatrix().print(pmf, wg);

        // Build a forecast report: inputs, target class, class
        // probabilities, and predicted class for each training pattern
        double report[][] = new double[nObs][nInputs + nOutputs + 2];
        for (int i = 0; i < nObs; i++) {
            for (int j = 0; j < nInputs; j++) {
                report[i][j] = xData[i][j];
            }
            report[i][nInputs] = yData[i];
            double p[] = classification.probabilities(xData[i]);
            for (int j = 0; j < nOutputs; j++) {
                report[i][nInputs + 1 + j] = p[j];
            }
            report[i][nInputs + nOutputs + 1] =
                    classification.predictedClass(xData[i]);
        }
        pmf = new PrintMatrixFormat();
        pmf.setColumnLabels(new String[]{"X1", "X2", "X3", "Y",
                "P(C1)", "P(C2)", "P(C3)", "Predicted"});
        new PrintMatrix("Forecast").print(pmf, report);
        System.out.println("Cross-Entropy Error Value = "
                + trainer.getErrorValue());

        // **********************************************************************
        // DISPLAY CLASSIFICATION STATISTICS
        // **********************************************************************
        double statsClass[] = classification.computeStatistics(xData, yData);
        System.out.println("***********************************************");
        System.out.println("--> Cross-Entropy Error: " + (float) statsClass[0]);
        System.out.println("--> Classification Error: " + (float) statsClass[1]);
        System.out.println("***********************************************");
        System.out.println("");
    }
}
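In the forecast report, classification.predictedClass(xData[i]) returns the class with the largest estimated probability for that pattern. Below is a minimal sketch of that relationship, assuming 1-based class labels as in this example; ArgMaxSketch and argMaxClass are illustrative names, not part of JMSL.

// Illustrative sketch, not part of JMSL: picks the class whose
// probability is largest, mirroring the "Predicted" column of the
// Forecast table.
public class ArgMaxSketch {
    public static int argMaxClass(double[] probabilities) {
        int best = 0;
        for (int k = 1; k < probabilities.length; k++) {
            if (probabilities[k] > probabilities[best]) {
                best = k;
            }
        }
        return best + 1; // convert the 0-based index to a 1-based class label
    }

    public static void main(String[] args) {
        // Probabilities like row 2 of the Forecast table: P(C2) is largest
        System.out.println(argMaxClass(new double[]{0.0, 1.0, 0.0})); // prints 2
    }
}

Running the program produces the output below.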
***********************************************
--> Cross-Entropy Error: 0.0
--> Classification Error: 0.0
***********************************************

      Weights   Gradients
 0   3.401208   -0.000000
 1  -4.126657    0.000000
 2  -2.201606   -0.000000
 3  -2.009527    0.000000
 4   3.173323   -0.000000
 5  -4.200377   -0.000000
 6   0.028736   -0.000000
 7   2.657051    0.000000
 8   4.868134   -0.000000
 9   3.711295   -0.000000
10  -2.723536   -0.000000
11   0.012241    0.000000
12  -4.996359    0.000000
13   4.296983    0.000000
14   1.699376   -0.000000
15  -1.993114    0.000000
16  -4.048833    0.000000
17   7.041948   -0.000000
18  -0.447927   -0.000000
19   0.653830    0.000000
20  -0.925019   -0.000000
21  -0.078963    0.000000
22   0.247835    0.000000
23  -0.168872   -0.000000

                   Forecast
   X1  X2  X3  Y  P(C1)  P(C2)  P(C3)  Predicted
0   1   0   0  1      1      0      0          1
1   1   0   0  1      1      0      0          1
2   0   1   0  2      0      1      0          2
3   0   1   0  2      0      1      0          2
4   0   0   1  3      0      0      1          3
5   0   0   1  3      0      0      1          3

Cross-Entropy Error Value = 0.0
***********************************************
--> Cross-Entropy Error: 0.0
--> Classification Error: 0.0
***********************************************
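Finally, note that MultiClassificationEx2 implements Serializable. Assuming the trained FeedForwardNetwork is itself serializable, it can be written to disk after training and reloaded later for forecasting without retraining. A minimal sketch using standard Java serialization follows; NetworkPersistenceSketch and the file name passed to it are hypothetical.

import com.imsl.datamining.neural.FeedForwardNetwork;
import java.io.*;

// Sketch only: persist a trained network with standard Java serialization
// so it can be reloaded for forecasting without retraining.
public class NetworkPersistenceSketch {

    // Write the trained network to disk
    public static void save(FeedForwardNetwork network, String fileName)
            throws IOException {
        try (ObjectOutputStream out =
                new ObjectOutputStream(new FileOutputStream(fileName))) {
            out.writeObject(network);
        }
    }

    // Read a previously saved network back for forecasting
    public static FeedForwardNetwork load(String fileName)
            throws IOException, ClassNotFoundException {
        try (ObjectInputStream in =
                new ObjectInputStream(new FileInputStream(fileName))) {
            return (FeedForwardNetwork) in.readObject();
        }
    }
}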