This example trains a 3-layer network using 48 training patterns from four nominal input attributes. The first two nominal attributes have two categories each. The third and fourth nominal attributes have three and four categories, respectively. All four attributes are encoded using binary encoding, which produces one column per category: 2 + 2 + 3 + 4 = eleven binary network input columns. The output class is 1 if both of the first two nominal attributes equal 1 (that is, they sum to 2), and 0 otherwise.
The structure of the network consists of eleven input nodes and three layers, with three perceptrons in the first hidden layer, two perceptrons in the second hidden layer, and one perceptron in the output layer.
There are a total of 47 weights in this network: 41 link weights (11×3 + 3×2 + 2×1) plus six bias weights, one for each perceptron. The linear activation function is used for both hidden layers. Since the target output is a binary classification, the logistic activation function is used in the output layer. Training is conducted using the quasi-Newton trainer with the binary cross-entropy error function provided by the BinaryClassification class.
import com.imsl.datamining.neural.*; import java.io.*; import java.util.logging.*; import com.imsl.math.PrintMatrix; import com.imsl.math.PrintMatrixFormat; import java.util.Random; //***************************************************************************** // Two Layer Feed-Forward Network with 11 inputs: 4 nominal with 2,2,3,4 categories, // encoded using binary encoding, and 1 output target (class). // // new classification training_ex1.c //***************************************************************************** public class BinaryClassificationEx1 implements Serializable { // Network Settings private static int nObs = 48; // number of training patterns private static int nInputs = 11; // four nominal with 2,2,3,4 categories private static int nCategorical = 11; // three categorical attributes private static int nOutputs = 1; // one continuous output (nClasses=2) private static int nPerceptrons1 = 3; // perceptrons in 1st hidden layer private static int nPerceptrons2 = 2; // perceptrons in 2nd hidden layer private static boolean trace = true; // Turns on/off training log private static Activation hiddenLayerActivation = Activation.LINEAR; private static Activation outputLayerActivation = Activation.LOGISTIC; /* 2 classifications */ private static int[] x1 = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; /* 2 classifications */ private static int[] x2 = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}; /* 3 classifications */ private static int[] x3 = { 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3 }; /* 4 classifications */ private static int[] x4 = { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 
3, 4, 1, 2, 3, 4, 1, 2, 3, 4 }; // ********************************************************************** // MAIN // ********************************************************************** public static void main(String[] args) throws Exception { double x[]; // temporary x space for generating forecasts double xData[][]; // Input Attributes for Trainer int yData[]; // Output Attributes for Trainer int i, j; // array indicies int nWeights = 0; // Number of weights obtained from network String trainLogName = "BinaryClassificationExample.log"; // ****************************************************************** // Binary encode 4 categorical variables. // Var x1 contains 2 classes // Var x2 contains 2 classes // Var x3 contains 3 classes // Var x4 contains 4 classes // ******************************************************************* int[][] z1; int[][] z2; int[][] z3; int[][] z4; UnsupervisedNominalFilter filter = new UnsupervisedNominalFilter(2); z1 = filter.encode(x1); z2 = filter.encode(x2); filter = new UnsupervisedNominalFilter(3); z3 = filter.encode(x3); filter = new UnsupervisedNominalFilter(4); z4 = filter.encode(x4); /* Concatenate binary encoded z's */ xData = new double[nObs][nInputs]; yData = new int[nObs]; for (i=0; i<(nObs); i++) { for (j=0; j <nCategorical; j++) { xData[i][j] = 0; if (j < 2) xData[i][j] = (double) z1[i][j]; if (j > 1 && j < 4) xData[i][j] = (double) z2[i][j-2]; if (j > 3 && j < 7) xData[i][j] = (double) z3[i][j-4]; if (j > 6) xData[i][j] = (double)z4[i][j-7]; } yData[i] = ((x1[i] +x2[i] == 2) ? 
1 : 0); } // ********************************************************************** // CREATE FEEDFORWARD NETWORK // ********************************************************************** long t0 = System.currentTimeMillis(); FeedForwardNetwork network = new FeedForwardNetwork(); network.getInputLayer().createInputs(nInputs); network.createHiddenLayer().createPerceptrons(nPerceptrons1); network.createHiddenLayer().createPerceptrons(nPerceptrons2); network.getOutputLayer().createPerceptrons(nOutputs); BinaryClassification classification = new BinaryClassification(network); network.linkAll(); Random r = new Random(123457L); network.setRandomWeights(xData, r); Perceptron perceptrons[] = network.getPerceptrons(); for (i=0; i < perceptrons.length-1; i++) { perceptrons[i].setActivation(hiddenLayerActivation); } perceptrons[perceptrons.length-1].setActivation(outputLayerActivation); // ********************************************************************** // TRAIN NETWORK USING QUASI-NEWTON TRAINER // ********************************************************************** QuasiNewtonTrainer trainer = new QuasiNewtonTrainer(); trainer.setError(classification.getError()); trainer.setMaximumTrainingIterations(1000); trainer.setMaximumStepsize(3.0); trainer.setGradientTolerance(1.0e-20); trainer.setFalseConvergenceTolerance(1.0e-20); trainer.setStepTolerance(1.0e-20); trainer.setRelativeTolerance(1.0e-20); if (trace) { try { Handler handler = new FileHandler(trainLogName); Logger logger = Logger.getLogger("com.imsl.datamining.neural"); logger.setLevel(Level.FINEST); logger.addHandler(handler); handler.setFormatter(QuasiNewtonTrainer.getFormatter()); System.out.println("--> Training Log Created in "+ trainLogName); } catch (Exception e) { System.out.println("--> Cannot Create Training Log."); } } classification.train(trainer, xData, yData); // ********************************************************************** // DISPLAY TRAINING STATISTICS // 
********************************************************************** double stats[] = classification.computeStatistics(xData, yData); System.out.println("***********************************************"); System.out.println("--> Cross-entropy error: "+(float)stats[0]); System.out.println("--> Classification error rate: "+(float)stats[1]); System.out.println("***********************************************"); System.out.println(""); // ********************************************************************** // OBTAIN AND DISPLAY NETWORK WEIGHTS AND GRADIENTS // ********************************************************************** double weight[] = network.getWeights(); double gradient[] = trainer.getErrorGradient(); double wg[][] = new double[weight.length][2]; for(i = 0; i < weight.length; i++) { wg[i][0] = weight[i]; wg[i][1] = gradient[i]; } PrintMatrixFormat pmf = new PrintMatrixFormat(); pmf.setNumberFormat(new java.text.DecimalFormat("0.000000")); pmf.setColumnLabels(new String[]{"Weights", "Gradients"}); new PrintMatrix().print(pmf,wg); // **************************** // forecast the network // **************************** double report[][] = new double[nObs][6]; for ( i = 0; i < nObs; i++) { report[i][0] = x1[i]; report[i][1] = x2[i]; report[i][2] = x3[i]; report[i][3] = x4[i]; report[i][4] = yData[i]; report[i][5] = classification.predictedClass(xData[i]); } pmf = new PrintMatrixFormat(); pmf.setColumnLabels(new String[]{ "X1", "X2", "X3", "X4", "Expected", "Predicted"}); new PrintMatrix("Forecast").print(pmf, report); // ********************************************************************** // DISPLAY CLASSIFICATION STATISTICS // ********************************************************************** double statsClass[] = classification.computeStatistics(xData, yData); // Display Network Errors System.out.println("***********************************************"); System.out.println("--> Cross-Entropy Error: "+(float)statsClass[0]); System.out.println("--> 
Classification Error: "+(float)statsClass[1]); System.out.println("***********************************************"); System.out.println(""); long t1 = System.currentTimeMillis(); double small = 1.e-7; double time = t1-t0; time = time/1000; System.out.println("****************Time: "+time); System.out.println("trainer.getErrorValue = "+trainer.getErrorValue()); } }
--> Training Log Created in BinaryClassificationExample.log *********************************************** --> Cross-entropy error: 1.8296475E-13 --> Classification error rate: 0.0 *********************************************** Weights Gradients 0 2.575599 -0.000000 1 1.770546 -0.000000 2 1.675687 -0.000000 3 -5.859796 0.000000 4 -1.794721 0.000000 5 -4.925026 0.000000 6 3.654187 0.000000 7 2.089872 0.000000 8 2.485173 0.000000 9 -5.238608 0.000000 10 -1.396975 0.000000 11 -4.730949 0.000000 12 0.143083 0.000000 13 0.777367 0.000000 14 0.316769 0.000000 15 -3.270781 -0.000000 16 0.283153 -0.000000 17 -0.162338 -0.000000 18 1.153316 0.000000 19 0.782549 0.000000 20 -0.387279 0.000000 21 -2.010958 -0.000000 22 0.273662 -0.000000 23 -0.670019 -0.000000 24 2.096144 0.000000 25 -0.264374 0.000000 26 0.351305 0.000000 27 1.190361 0.000000 28 -0.053966 0.000000 29 0.555192 0.000000 30 -2.001125 -0.000000 31 0.735950 -0.000000 32 -0.829534 -0.000000 33 -4.824521 0.000000 34 -4.824521 0.000000 35 -0.652606 0.000000 36 -0.652606 0.000000 37 -2.921224 0.000000 38 -2.921224 0.000000 39 -1.621591 0.000000 40 -1.621591 0.000000 41 -1.967947 0.000000 42 1.534864 0.000000 43 0.907830 0.000000 44 1.594078 -0.000000 45 1.594078 -0.000000 46 -0.169361 0.000000 Forecast X1 X2 X3 X4 Expected Predicted 0 1 1 1 1 1 1 1 1 1 1 2 1 1 2 1 1 1 3 1 1 3 1 1 1 4 1 1 4 1 1 2 1 1 1 5 1 1 2 2 1 1 6 1 1 2 3 1 1 7 1 1 2 4 1 1 8 1 1 3 1 1 1 9 1 1 3 2 1 1 10 1 1 3 3 1 1 11 1 1 3 4 1 1 12 1 2 1 1 0 0 13 1 2 1 2 0 0 14 1 2 1 3 0 0 15 1 2 1 4 0 0 16 1 2 2 1 0 0 17 1 2 2 2 0 0 18 1 2 2 3 0 0 19 1 2 2 4 0 0 20 1 2 3 1 0 0 21 1 2 3 2 0 0 22 1 2 3 3 0 0 23 1 2 3 4 0 0 24 2 1 1 1 0 0 25 2 1 1 2 0 0 26 2 1 1 3 0 0 27 2 1 1 4 0 0 28 2 1 2 1 0 0 29 2 1 2 2 0 0 30 2 1 2 3 0 0 31 2 1 2 4 0 0 32 2 1 3 1 0 0 33 2 1 3 2 0 0 34 2 1 3 3 0 0 35 2 1 3 4 0 0 36 2 2 1 1 0 0 37 2 2 1 2 0 0 38 2 2 1 3 0 0 39 2 2 1 4 0 0 40 2 2 2 1 0 0 41 2 2 2 2 0 0 42 2 2 2 3 0 0 43 2 2 2 4 0 0 44 2 2 3 1 0 0 45 2 2 3 2 0 0 46 2 2 3 3 0 0 
47 2 2 3 4 0 0 *********************************************** --> Cross-Entropy Error: 1.8296475E-13 --> Classification Error: 0.0 *********************************************** ****************Time: 0.203 trainer.getErrorValue = 1.8296475445823478E-13Link to Java source.