Example 2: MultiClassification

This example trains a 2-layer network using three binary inputs (X1, X2, X3) and one three-level classification variable (Y), where

Y = 1 if X1 = 1

Y = 2 if X2 = 1

Y = 3 if X3 = 1

import com.imsl.datamining.neural.*;
import com.imsl.math.PrintMatrix;
import com.imsl.math.PrintMatrixFormat;
import java.io.*;
import java.util.logging.*;

//*****************************************************************************
// Two-Layer FFN with 3 binary inputs (X1, X2, X3) and one three-level
// classification variable (Y)
// Y = 1 if X1 = 1
// Y = 2 if X2 = 1
// Y = 3 if X3 = 1
//  (training_ex6)
//*****************************************************************************

/**
 * Trains a two-layer feed-forward network on three binary inputs
 * (X1, X2, X3) and one three-level classification target (Y), where
 * Y = 1 if X1 = 1, Y = 2 if X2 = 1, and Y = 3 if X3 = 1.
 */
public class MultiClassificationEx2 implements Serializable {

    private static final long serialVersionUID = 1L;

    private static final int nObs      = 6;    // number of training patterns
    private static final int nInputs   = 3;    // 3 inputs, all categorical
    private static final int nOutputs  = 3;    // one output perceptron per class
    private static final boolean trace = true; // Turns on/off training log

    // Each row is one training pattern: exactly one of the three binary
    // inputs is set, and yData holds the matching class label (1, 2, or 3).
    private static final double xData[][] = {
        {1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0}, {0, 0, 1}, {0, 0, 1}
    };
    private static final int yData[] = {1, 1, 2, 2, 3, 3};

    // Initial network weights supplied before training.
    private static final double weights[] = {
         1.29099444873580580000,-0.64549722436790280000,-0.64549722436790291000,
         0.00000000000000000000, 1.11803398874989490000,-1.11803398874989470000,
         0.57735026918962584000, 0.57735026918962584000, 0.57735026918962584000,
         0.33333333333333331000, 0.33333333333333331000, 0.33333333333333331000,
         0.33333333333333331000, 0.33333333333333331000, 0.33333333333333331000,
         0.33333333333333331000, 0.33333333333333331000, 0.33333333333333331000,
        -0.00000000000000005851,-0.00000000000000005851,-0.57735026918962573000,
         0.00000000000000000000, 0.00000000000000000000, 0.00000000000000000000};

    /**
     * Builds and trains the classification network, then prints the error
     * statistics, the fitted weights with their error gradients, and a
     * per-pattern forecast table.
     *
     * @param args command-line arguments (unused)
     * @throws Exception if training fails or the trace log file cannot be
     *         created
     */
    public static void main(String[] args) throws Exception {
        // Build a 2-layer feed-forward network: 3 inputs -> 3 linear hidden
        // perceptrons -> 3 softmax outputs (one per class).
        FeedForwardNetwork network = new FeedForwardNetwork();
        network.getInputLayer().createInputs(nInputs);
        network.createHiddenLayer().createPerceptrons(3, Activation.LINEAR, 0.0);
        network.getOutputLayer().createPerceptrons(nOutputs, Activation.SOFTMAX, 0.0);
        network.linkAll();
        network.setWeights(weights);

        MultiClassification classification = new MultiClassification(network);

        // Quasi-Newton trainer with very tight tolerances so training runs
        // until the cross-entropy error is effectively zero.
        QuasiNewtonTrainer trainer = new QuasiNewtonTrainer();
        trainer.setError(classification.getError());
        trainer.setMaximumTrainingIterations(1000);
        trainer.setFalseConvergenceTolerance(1.0e-20);
        trainer.setGradientTolerance(1.0e-20);
        trainer.setRelativeTolerance(1.0e-20);
        trainer.setStepTolerance(1.0e-20);

        // If tracing is requested set up the training logger; the handler is
        // left attached so it keeps capturing records for the JVM's lifetime.
        if (trace) {
            Handler handler = new FileHandler("ClassificationNetworkEx2.log");
            Logger logger = Logger.getLogger("com.imsl.datamining.neural");
            logger.setLevel(Level.FINEST);
            logger.addHandler(handler);
            handler.setFormatter(QuasiNewtonTrainer.getFormatter());
        }

        // Train Network
        classification.train(trainer, xData, yData);

        // Display Network Errors
        printStatistics(classification.computeStatistics(xData, yData));
        System.out.println();

        // Display fitted weights side-by-side with the error gradient.
        double weight[]   = network.getWeights();
        double gradient[] = trainer.getErrorGradient();
        double wg[][] = new double[weight.length][2];
        for (int i = 0; i < weight.length; i++) {
            wg[i][0] = weight[i];
            wg[i][1] = gradient[i];
        }
        PrintMatrixFormat pmf = new PrintMatrixFormat();
        pmf.setNumberFormat(new java.text.DecimalFormat("0.000000"));
        pmf.setColumnLabels(new String[]{"Weights", "Gradients"});
        new PrintMatrix().print(pmf, wg);

        // Forecast table: inputs, target class, per-class probabilities, and
        // the predicted class for each training pattern.
        double report[][] = new double[nObs][nInputs + nOutputs + 2];
        for (int i = 0; i < nObs; i++) {
            for (int j = 0; j < nInputs; j++) {
                report[i][j] = xData[i][j];
            }
            report[i][nInputs] = yData[i];
            double p[] = classification.probabilities(xData[i]);
            for (int j = 0; j < nOutputs; j++) {
                report[i][nInputs + 1 + j] = p[j];
            }
            report[i][nInputs + nOutputs + 1] = classification.predictedClass(xData[i]);
        }
        pmf = new PrintMatrixFormat();
        pmf.setColumnLabels(new String[]{"X1", "X2", "X3", "Y", "P(C1)", "P(C2)",
            "P(C3)", "Predicted"});
        new PrintMatrix("Forecast").print(pmf, report);
        System.out.println("Cross-Entropy Error Value = "+trainer.getErrorValue());

        // **********************************************************************
        // DISPLAY CLASSIFICATION STATISTICS
        // **********************************************************************
        printStatistics(classification.computeStatistics(xData, yData));
        System.out.println("");
    }

    /**
     * Prints the error summary banner for the statistics array returned by
     * {@code MultiClassification.computeStatistics}.
     *
     * @param stats statistics array; element 0 is the cross-entropy error
     *        and element 1 the classification error
     */
    private static void printStatistics(double[] stats) {
        System.out.println("***********************************************");
        System.out.println("--> Cross-Entropy Error:      "+(float)stats[0]);
        System.out.println("--> Classification Error:     "+(float)stats[1]);
        System.out.println("***********************************************");
    }
}

Output

***********************************************
--> Cross-Entropy Error:      0.0
--> Classification Error:     0.0
***********************************************

     Weights   Gradients  
 0   3.401208  -0.000000  
 1  -4.126657   0.000000  
 2  -2.201606  -0.000000  
 3  -2.009527   0.000000  
 4   3.173323  -0.000000  
 5  -4.200377  -0.000000  
 6   0.028736  -0.000000  
 7   2.657051   0.000000  
 8   4.868134  -0.000000  
 9   3.711295  -0.000000  
10  -2.723536  -0.000000  
11   0.012241   0.000000  
12  -4.996359   0.000000  
13   4.296983   0.000000  
14   1.699376  -0.000000  
15  -1.993114   0.000000  
16  -4.048833   0.000000  
17   7.041948  -0.000000  
18  -0.447927  -0.000000  
19   0.653830   0.000000  
20  -0.925019  -0.000000  
21  -0.078963   0.000000  
22   0.247835   0.000000  
23  -0.168872  -0.000000  

                     Forecast
   X1  X2  X3  Y  P(C1)  P(C2)  P(C3)  Predicted  
0  1   0   0   1    1      0      0        1      
1  1   0   0   1    1      0      0        1      
2  0   1   0   2    0      1      0        2      
3  0   1   0   2    0      1      0        2      
4  0   0   1   3    0      0      1        3      
5  0   0   1   3    0      0      1        3      

Cross-Entropy Error Value = 0.0
***********************************************
--> Cross-Entropy Error:      0.0
--> Classification Error:     0.0
***********************************************

Link to Java source.