Example 3: Naive Bayes Classifier Using a User-Supplied Probability Function

Like Example 1, this example uses Fisher's (1936) Iris data to train a Naive Bayes classifier on 140 of the 150 continuous patterns and then classifies the remaining ten plants from their sepal and petal measurements.

Instead of the NormalDistribution class from the com.imsl.stat package, a user-supplied normal (Gaussian) distribution is used. Rather than estimating the means and standard deviations from the data, as NormalDistribution's eval(double[]) method does, the user-supplied class takes the means and standard deviations in its constructor. The output is the same as in Example 1, because the values supplied here are simply the rounded means and standard deviations of the training data, grouped by target classification.
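The user-supplied class evaluates the normal (Gaussian) density with fixed parameters $\mu$ (mean) and $\sigma$ (standard deviation),

$$
f(x \mid \mu, \sigma) \;=\; \frac{1}{\sigma\sqrt{2\pi}}\,\exp\!\left(-\frac{(x-\mu)^2}{2\sigma^2}\right),
$$

which is exactly what the l_gaussian_pdf method in the listing below computes, with the constant sqrt_pi2 holding $\sqrt{2\pi}$.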


import com.imsl.datamining.*;
import com.imsl.stat.ProbabilityDistribution;
import java.io.*;

public class NaiveBayesClassifierEx3 {

   private static double[][] irisFisherData = {
      {1.0, 5.1, 3.5, 1.4, .2}, {1.0, 4.9, 3.0, 1.4, .2},
      {1.0, 4.7, 3.2, 1.3, .2}, {1.0, 4.6, 3.1, 1.5, .2},
      {1.0, 5.0, 3.6, 1.4, .2}, {1.0, 5.4, 3.9, 1.7, .4},
      {1.0, 4.6, 3.4, 1.4, .3}, {1.0, 5.0, 3.4, 1.5, .2},
      {1.0, 4.4, 2.9, 1.4, .2}, {1.0, 4.9, 3.1, 1.5, .1},
      {1.0, 5.4, 3.7, 1.5, .2}, {1.0, 4.8, 3.4, 1.6, .2},
      {1.0, 4.8, 3.0, 1.4, .1}, {1.0, 4.3, 3.0, 1.1, .1},
      {1.0, 5.8, 4.0, 1.2, .2}, {1.0, 5.7, 4.4, 1.5, .4},
      {1.0, 5.4, 3.9, 1.3, .4}, {1.0, 5.1, 3.5, 1.4, .3},
      {1.0, 5.7, 3.8, 1.7, .3}, {1.0, 5.1, 3.8, 1.5, .3},
      {1.0, 5.4, 3.4, 1.7, .2}, {1.0, 5.1, 3.7, 1.5, .4},
      {1.0, 4.6, 3.6, 1.0, .2}, {1.0, 5.1, 3.3, 1.7, .5},
      {1.0, 4.8, 3.4, 1.9, .2}, {1.0, 5.0, 3.0, 1.6, .2},
      {1.0, 5.0, 3.4, 1.6, .4}, {1.0, 5.2, 3.5, 1.5, .2},
      {1.0, 5.2, 3.4, 1.4, .2}, {1.0, 4.7, 3.2, 1.6, .2},
      {1.0, 4.8, 3.1, 1.6, .2}, {1.0, 5.4, 3.4, 1.5, .4},
      {1.0, 5.2, 4.1, 1.5, .1}, {1.0, 5.5, 4.2, 1.4, .2},
      {1.0, 4.9, 3.1, 1.5, .2}, {1.0, 5.0, 3.2, 1.2, .2},
      {1.0, 5.5, 3.5, 1.3, .2}, {1.0, 4.9, 3.6, 1.4, .1},
      {1.0, 4.4, 3.0, 1.3, .2}, {1.0, 5.1, 3.4, 1.5, .2},
      {1.0, 5.0, 3.5, 1.3, .3}, {1.0, 4.5, 2.3, 1.3, .3},
      {1.0, 4.4, 3.2, 1.3, .2}, {1.0, 5.0, 3.5, 1.6, .6},
      {1.0, 5.1, 3.8, 1.9, .4}, {1.0, 4.8, 3.0, 1.4, .3},
      {1.0, 5.1, 3.8, 1.6, .2}, {1.0, 4.6, 3.2, 1.4, .2},
      {1.0, 5.3, 3.7, 1.5, .2}, {1.0, 5.0, 3.3, 1.4, .2},
      {2.0, 7.0, 3.2, 4.7, 1.4}, {2.0, 6.4, 3.2, 4.5, 1.5},
      {2.0, 6.9, 3.1, 4.9, 1.5}, {2.0, 5.5, 2.3, 4.0, 1.3},
      {2.0, 6.5, 2.8, 4.6, 1.5}, {2.0, 5.7, 2.8, 4.5, 1.3},
      {2.0, 6.3, 3.3, 4.7, 1.6}, {2.0, 4.9, 2.4, 3.3, 1.0},
      {2.0, 6.6, 2.9, 4.6, 1.3}, {2.0, 5.2, 2.7, 3.9, 1.4},
      {2.0, 5.0, 2.0, 3.5, 1.0}, {2.0, 5.9, 3.0, 4.2, 1.5},
      {2.0, 6.0, 2.2, 4.0, 1.0}, {2.0, 6.1, 2.9, 4.7, 1.4},
      {2.0, 5.6, 2.9, 3.6, 1.3}, {2.0, 6.7, 3.1, 4.4, 1.4},
      {2.0, 5.6, 3.0, 4.5, 1.5}, {2.0, 5.8, 2.7, 4.1, 1.0},
      {2.0, 6.2, 2.2, 4.5, 1.5}, {2.0, 5.6, 2.5, 3.9, 1.1},
      {2.0, 5.9, 3.2, 4.8, 1.8}, {2.0, 6.1, 2.8, 4.0, 1.3},
      {2.0, 6.3, 2.5, 4.9, 1.5}, {2.0, 6.1, 2.8, 4.7, 1.2},
      {2.0, 6.4, 2.9, 4.3, 1.3}, {2.0, 6.6, 3.0, 4.4, 1.4},
      {2.0, 6.8, 2.8, 4.8, 1.4}, {2.0, 6.7, 3.0, 5.0, 1.7},
      {2.0, 6.0, 2.9, 4.5, 1.5}, {2.0, 5.7, 2.6, 3.5, 1.0},
      {2.0, 5.5, 2.4, 3.8, 1.1}, {2.0, 5.5, 2.4, 3.7, 1.0},
      {2.0, 5.8, 2.7, 3.9, 1.2}, {2.0, 6.0, 2.7, 5.1, 1.6},
      {2.0, 5.4, 3.0, 4.5, 1.5}, {2.0, 6.0, 3.4, 4.5, 1.6},
      {2.0, 6.7, 3.1, 4.7, 1.5}, {2.0, 6.3, 2.3, 4.4, 1.3},
      {2.0, 5.6, 3.0, 4.1, 1.3}, {2.0, 5.5, 2.5, 4.0, 1.3},
      {2.0, 5.5, 2.6, 4.4, 1.2}, {2.0, 6.1, 3.0, 4.6, 1.4},
      {2.0, 5.8, 2.6, 4.0, 1.2}, {2.0, 5.0, 2.3, 3.3, 1.0},
      {2.0, 5.6, 2.7, 4.2, 1.3}, {2.0, 5.7, 3.0, 4.2, 1.2},
      {2.0, 5.7, 2.9, 4.2, 1.3}, {2.0, 6.2, 2.9, 4.3, 1.3},
      {2.0, 5.1, 2.5, 3.0, 1.1}, {2.0, 5.7, 2.8, 4.1, 1.3},
      {3.0, 6.3, 3.3, 6.0, 2.5}, {3.0, 5.8, 2.7, 5.1, 1.9},
      {3.0, 7.1, 3.0, 5.9, 2.1}, {3.0, 6.3, 2.9, 5.6, 1.8},
      {3.0, 6.5, 3.0, 5.8, 2.2}, {3.0, 7.6, 3.0, 6.6, 2.1},
      {3.0, 4.9, 2.5, 4.5, 1.7}, {3.0, 7.3, 2.9, 6.3, 1.8},
      {3.0, 6.7, 2.5, 5.8, 1.8}, {3.0, 7.2, 3.6, 6.1, 2.5},
      {3.0, 6.5, 3.2, 5.1, 2.0}, {3.0, 6.4, 2.7, 5.3, 1.9},
      {3.0, 6.8, 3.0, 5.5, 2.1}, {3.0, 5.7, 2.5, 5.0, 2.0},
      {3.0, 5.8, 2.8, 5.1, 2.4}, {3.0, 6.4, 3.2, 5.3, 2.3},
      {3.0, 6.5, 3.0, 5.5, 1.8}, {3.0, 7.7, 3.8, 6.7, 2.2},
      {3.0, 7.7, 2.6, 6.9, 2.3}, {3.0, 6.0, 2.2, 5.0, 1.5},
      {3.0, 6.9, 3.2, 5.7, 2.3}, {3.0, 5.6, 2.8, 4.9, 2.0},
      {3.0, 7.7, 2.8, 6.7, 2.0}, {3.0, 6.3, 2.7, 4.9, 1.8},
      {3.0, 6.7, 3.3, 5.7, 2.1}, {3.0, 7.2, 3.2, 6.0, 1.8},
      {3.0, 6.2, 2.8, 4.8, 1.8}, {3.0, 6.1, 3.0, 4.9, 1.8},
      {3.0, 6.4, 2.8, 5.6, 2.1}, {3.0, 7.2, 3.0, 5.8, 1.6},
      {3.0, 7.4, 2.8, 6.1, 1.9}, {3.0, 7.9, 3.8, 6.4, 2.0},
      {3.0, 6.4, 2.8, 5.6, 2.2}, {3.0, 6.3, 2.8, 5.1, 1.5},
      {3.0, 6.1, 2.6, 5.6, 1.4}, {3.0, 7.7, 3.0, 6.1, 2.3},
      {3.0, 6.3, 3.4, 5.6, 2.4}, {3.0, 6.4, 3.1, 5.5, 1.8},
      {3.0, 6.0, 3.0, 4.8, 1.8}, {3.0, 6.9, 3.1, 5.4, 2.1},
      {3.0, 6.7, 3.1, 5.6, 2.4}, {3.0, 6.9, 3.1, 5.1, 2.3},
      {3.0, 5.8, 2.7, 5.1, 1.9}, {3.0, 6.8, 3.2, 5.9, 2.3},
      {3.0, 6.7, 3.3, 5.7, 2.5}, {3.0, 6.7, 3.0, 5.2, 2.3},
      {3.0, 6.3, 2.5, 5.0, 1.9}, {3.0, 6.5, 3.0, 5.2, 2.0},
      {3.0, 6.2, 3.4, 5.4, 2.3}, {3.0, 5.9, 3.0, 5.1, 1.8}
   };

   public static void main(String[] args) throws Exception {

      /* Data corrections described in the KDD data mining archive */
      irisFisherData[34][4] = 0.1;
      irisFisherData[37][2] = 3.1;
      irisFisherData[37][3] = 1.5;

      /*  Train first 140 patterns of the iris Fisher Data */
      int[] irisClassificationData = new int[irisFisherData.length - 10];
      double[][] irisContinuousData =
            new double[irisFisherData.length - 10][irisFisherData[0].length - 1];

      for (int i = 0; i < irisFisherData.length - 10; i++) {
         irisClassificationData[i] = (int) irisFisherData[i][0] - 1;
         System.arraycopy(irisFisherData[i], 1,
               irisContinuousData[i], 0, irisFisherData[0].length - 1);
      }

      int nNominal = 0;   /* no nominal input attributes      */
      int nContinuous = 4;   /* four continuous input attributes */
      int nClasses = 3;   /* three classification categories  */

      NaiveBayesClassifier nbTrainer =
            new NaiveBayesClassifier(nContinuous, nNominal, nClasses);

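      /* Rounded means and standard deviations of each continuous attribute
         (rows) for each target class (columns). */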
      double[][] means = {
         {5.06, 5.94, 6.58},
         {3.42, 2.8, 2.97},
         {1.5, 4.3, 5.6},
         {0.25, 1.33, 2.1}
      };
      double[][] stdev = {
         {0.35, 0.52, 0.64},
         {0.38, 0.3, 0.32},
         {0.17, 0.47, 0.55},
         {0.12, 0.198, 0.275}
      };

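      /* Register one user-supplied Gaussian per target class for each of
         the four continuous attributes. */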
      for (int i = 0; i < nContinuous; i++) {
         ProbabilityDistribution[] pdf =
               new ProbabilityDistribution[nClasses];
         for (int j = 0; j < nClasses; j++) {
            pdf[j] = new TestGaussFcn1(means[i][j], stdev[i][j]);
         }
         nbTrainer.createContinuousAttribute(pdf);
      }
      nbTrainer.train(irisContinuousData, irisClassificationData);

      int[][] classErrors = nbTrainer.getTrainingErrors();

      System.out.println("     Iris Classification Error Rates");
      System.out.println("------------------------------------------------");
      System.out.println("  Setosa   Versicolour   Virginica    |   Total");
      System.out.println("  " + classErrors[0][0] + "/" + classErrors[0][1] +
            "         " + classErrors[1][0] + "/" + classErrors[1][1] +
            "        " + classErrors[2][0] + "/" + classErrors[2][1] +
            "       |   " + classErrors[3][0] + "/" + classErrors[3][1]);
      System.out.println(
            "------------------------------------------------\n\n\n");


      /*  Classify last  10 iris data patterns with the trained classifier */
      double[] continuousInput = new double[(irisFisherData[0].length - 1)];
      double[] classifiedProbabilities = new double[nClasses];

      System.out.println("Probabilities for Incorrect Classifications");
      System.out.println(" Predicted   ");
      System.out.println(
            "   Class     |  Class       |   P(0)     P(1)     P(2) ");
      System.out.println(
            "-------------------------------------------------------");
      for (int i = 0; i < 10; i++) {
         int targetClassification =
               (int) irisFisherData[(irisFisherData.length - 10) + i][0] - 1;
         System.arraycopy(irisFisherData[(irisFisherData.length - 10) + i], 1,
               continuousInput, 0, (irisFisherData[0].length - 1));

         classifiedProbabilities =
               nbTrainer.probabilities(continuousInput, null);
         int classification = nbTrainer.predictClass(continuousInput, null);
         if (classification == 0) {
            System.out.print("Setosa       |");
         } else if (classification == 1) {
            System.out.print("Versicolour  |");
         } else if (classification == 2) {
            System.out.print("Virginica    |");
         } else {
            System.out.print("Missing      |");
         }
         if (targetClassification == 0) {
            System.out.print(" Setosa       |");
         } else if (targetClassification == 1) {
            System.out.print(" Versicolour  |");
         } else if (targetClassification == 2) {
            System.out.print(" Virginica    |");
         } else {
            System.out.print(" Missing      |");
         }
         for (int j = 0; j < nClasses; j++) {
            System.out.printf("   %2.3f ", classifiedProbabilities[j]);
         }
         System.out.println("");
      }
   }

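   /*
    * User-supplied Gaussian density with fixed mean and standard deviation,
    * implementing the com.imsl.stat.ProbabilityDistribution interface.
    */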
   public static class TestGaussFcn1 implements ProbabilityDistribution {

      private double mean;
      private double stdev;

      public TestGaussFcn1(double mean, double stdev) {
         this.mean = mean;
         this.stdev = stdev;
      }

      public double[] eval(double[] xData) {
         double[] pdf = new double[xData.length];
         for (int i = 0; i < xData.length; i++) {
            pdf[i] = eval(xData[i], null);
         }
         return pdf;
      }

      public double[] eval(double[] xData, Object[] Params) {
         double[] pdf = new double[xData.length];
         for (int i = 0; i < xData.length; i++) {
            pdf[i] = eval(xData[i], Params);
         }
         return pdf;
      }

      public double eval(double xData, Object[] Params) {
         return l_gaussian_pdf(xData, mean, stdev);
      }

      public Object[] getParameters() {
         Double[] parms = new Double[2];
         parms[0] = this.mean;
         parms[1] = this.stdev;
         return parms;
      }

      private double l_gaussian_pdf(double x, double mean, double stdev) {
         double e, phi2, z, s;
         double sqrt_pi2 = 2.506628274631; /* sqrt(2*pi) */
         if (Double.isNaN(x)) {
            return Double.NaN;
         }
         if (Double.isNaN(mean) || Double.isNaN(stdev)) {
            return Double.NaN;
         } else {
            z = x;
            z -= mean;
            s = stdev;
            phi2 = sqrt_pi2 * s;
            e = -0.5 * (z * z) / (s * s);
            return Math.exp(e) / phi2;
         }
      }
   }
}
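
To exercise the user-supplied distribution on its own, the following minimal sketch (not part of the original example; the class name TestGaussFcn1Check is made up for illustration) evaluates the density at its own mean, where it should reduce to 1/(σ√(2π)). The parameters 5.06 and 0.35 are the rounded Setosa sepal-length values from the listing above.

public class TestGaussFcn1Check {

   public static void main(String[] args) {
      /* Setosa sepal-length mean and standard deviation from the example. */
      NaiveBayesClassifierEx3.TestGaussFcn1 g =
            new NaiveBayesClassifierEx3.TestGaussFcn1(5.06, 0.35);

      /* At x = mean, the density reduces to 1/(stdev * sqrt(2*pi)). */
      double atMean = g.eval(5.06, null);
      double expected = 1.0 / (0.35 * Math.sqrt(2.0 * Math.PI));

      System.out.println(atMean + "  vs  " + expected);  /* both ~1.1398 */
   }
}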

Output

The Naive Bayes classifier incorrectly classifies 6 of the 140 training patterns.

     Iris Classification Error Rates
------------------------------------------------
  Setosa   Versicolour   Virginica    |   Total
  0/50         2/50        4/40       |   6/140
------------------------------------------------



Probabilities for Incorrect Classifications
 Predicted   
   Class     |  Class       |   P(0)     P(1)     P(2) 
-------------------------------------------------------
Virginica    | Virginica    |   0.000    0.000    1.000 
Virginica    | Virginica    |   0.000    0.000    1.000 
Virginica    | Virginica    |   0.000    0.051    0.949 
Virginica    | Virginica    |   0.000    0.000    1.000 
Virginica    | Virginica    |   0.000    0.000    1.000 
Virginica    | Virginica    |   0.000    0.000    1.000 
Virginica    | Virginica    |   0.000    0.048    0.952 
Virginica    | Virginica    |   0.000    0.001    0.999 
Virginica    | Virginica    |   0.000    0.000    1.000 
Virginica    | Virginica    |   0.000    0.126    0.874 