Like Example 1, this example uses Fisher's (1936) Iris data to train a Naive Bayes classifier on 140 of the 150 continuous patterns, and then classifies the ten held-out plants using their sepal and petal measurements.
Instead of using the `NormalDistribution` class from the `com.imsl.stat` package, this example uses a user-supplied normal (Gaussian) distribution. The `eval(double[])` method of `NormalDistribution` calculates the means and standard deviations from the data. The user-supplied class instead takes the means and standard deviations as constructor arguments. The output is the same as in Example 1, because the means and standard deviations used here are simply the rounded means and standard deviations of the actual data, grouped by target classification.
import com.imsl.datamining.*;
import com.imsl.stat.ProbabilityDistribution;
import java.io.*;

/**
 * Trains a {@code NaiveBayesClassifier} on Fisher's (1936) Iris data and then
 * classifies the last {@value #N_TEST} patterns.
 *
 * <p>Each continuous attribute/class pair is modelled by a user-supplied
 * Gaussian density ({@link TestGaussFcn1}). Its mean and standard deviation
 * are fixed in the constructor rather than estimated from the training data.
 */
public class NaiveBayesClassifierEx3 {

    /** Number of patterns at the end of the data set held out for classification. */
    private static final int N_TEST = 10;

    /**
     * Fisher's Iris data, one row per plant.
     *
     * <p>Column 0 is the class code: 1 = Setosa, 2 = Versicolour, 3 = Virginica.
     * Columns 1-4 are the four continuous sepal and petal measurements.
     */
    private static double[][] irisFisherData = {
        {1.0, 5.1, 3.5, 1.4, .2},
        {1.0, 4.9, 3.0, 1.4, .2},
        {1.0, 4.7, 3.2, 1.3, .2},
        {1.0, 4.6, 3.1, 1.5, .2},
        {1.0, 5.0, 3.6, 1.4, .2},
        {1.0, 5.4, 3.9, 1.7, .4},
        {1.0, 4.6, 3.4, 1.4, .3},
        {1.0, 5.0, 3.4, 1.5, .2},
        {1.0, 4.4, 2.9, 1.4, .2},
        {1.0, 4.9, 3.1, 1.5, .1},
        {1.0, 5.4, 3.7, 1.5, .2},
        {1.0, 4.8, 3.4, 1.6, .2},
        {1.0, 4.8, 3.0, 1.4, .1},
        {1.0, 4.3, 3.0, 1.1, .1},
        {1.0, 5.8, 4.0, 1.2, .2},
        {1.0, 5.7, 4.4, 1.5, .4},
        {1.0, 5.4, 3.9, 1.3, .4},
        {1.0, 5.1, 3.5, 1.4, .3},
        {1.0, 5.7, 3.8, 1.7, .3},
        {1.0, 5.1, 3.8, 1.5, .3},
        {1.0, 5.4, 3.4, 1.7, .2},
        {1.0, 5.1, 3.7, 1.5, .4},
        {1.0, 4.6, 3.6, 1.0, .2},
        {1.0, 5.1, 3.3, 1.7, .5},
        {1.0, 4.8, 3.4, 1.9, .2},
        {1.0, 5.0, 3.0, 1.6, .2},
        {1.0, 5.0, 3.4, 1.6, .4},
        {1.0, 5.2, 3.5, 1.5, .2},
        {1.0, 5.2, 3.4, 1.4, .2},
        {1.0, 4.7, 3.2, 1.6, .2},
        {1.0, 4.8, 3.1, 1.6, .2},
        {1.0, 5.4, 3.4, 1.5, .4},
        {1.0, 5.2, 4.1, 1.5, .1},
        {1.0, 5.5, 4.2, 1.4, .2},
        {1.0, 4.9, 3.1, 1.5, .2},
        {1.0, 5.0, 3.2, 1.2, .2},
        {1.0, 5.5, 3.5, 1.3, .2},
        {1.0, 4.9, 3.6, 1.4, .1},
        {1.0, 4.4, 3.0, 1.3, .2},
        {1.0, 5.1, 3.4, 1.5, .2},
        {1.0, 5.0, 3.5, 1.3, .3},
        {1.0, 4.5, 2.3, 1.3, .3},
        {1.0, 4.4, 3.2, 1.3, .2},
        {1.0, 5.0, 3.5, 1.6, .6},
        {1.0, 5.1, 3.8, 1.9, .4},
        {1.0, 4.8, 3.0, 1.4, .3},
        {1.0, 5.1, 3.8, 1.6, .2},
        {1.0, 4.6, 3.2, 1.4, .2},
        {1.0, 5.3, 3.7, 1.5, .2},
        {1.0, 5.0, 3.3, 1.4, .2},
        {2.0, 7.0, 3.2, 4.7, 1.4},
        {2.0, 6.4, 3.2, 4.5, 1.5},
        {2.0, 6.9, 3.1, 4.9, 1.5},
        {2.0, 5.5, 2.3, 4.0, 1.3},
        {2.0, 6.5, 2.8, 4.6, 1.5},
        {2.0, 5.7, 2.8, 4.5, 1.3},
        {2.0, 6.3, 3.3, 4.7, 1.6},
        {2.0, 4.9, 2.4, 3.3, 1.0},
        {2.0, 6.6, 2.9, 4.6, 1.3},
        {2.0, 5.2, 2.7, 3.9, 1.4},
        {2.0, 5.0, 2.0, 3.5, 1.0},
        {2.0, 5.9, 3.0, 4.2, 1.5},
        {2.0, 6.0, 2.2, 4.0, 1.0},
        {2.0, 6.1, 2.9, 4.7, 1.4},
        {2.0, 5.6, 2.9, 3.6, 1.3},
        {2.0, 6.7, 3.1, 4.4, 1.4},
        {2.0, 5.6, 3.0, 4.5, 1.5},
        {2.0, 5.8, 2.7, 4.1, 1.0},
        {2.0, 6.2, 2.2, 4.5, 1.5},
        {2.0, 5.6, 2.5, 3.9, 1.1},
        {2.0, 5.9, 3.2, 4.8, 1.8},
        {2.0, 6.1, 2.8, 4.0, 1.3},
        {2.0, 6.3, 2.5, 4.9, 1.5},
        {2.0, 6.1, 2.8, 4.7, 1.2},
        {2.0, 6.4, 2.9, 4.3, 1.3},
        {2.0, 6.6, 3.0, 4.4, 1.4},
        {2.0, 6.8, 2.8, 4.8, 1.4},
        {2.0, 6.7, 3.0, 5.0, 1.7},
        {2.0, 6.0, 2.9, 4.5, 1.5},
        {2.0, 5.7, 2.6, 3.5, 1.0},
        {2.0, 5.5, 2.4, 3.8, 1.1},
        {2.0, 5.5, 2.4, 3.7, 1.0},
        {2.0, 5.8, 2.7, 3.9, 1.2},
        {2.0, 6.0, 2.7, 5.1, 1.6},
        {2.0, 5.4, 3.0, 4.5, 1.5},
        {2.0, 6.0, 3.4, 4.5, 1.6},
        {2.0, 6.7, 3.1, 4.7, 1.5},
        {2.0, 6.3, 2.3, 4.4, 1.3},
        {2.0, 5.6, 3.0, 4.1, 1.3},
        {2.0, 5.5, 2.5, 4.0, 1.3},
        {2.0, 5.5, 2.6, 4.4, 1.2},
        {2.0, 6.1, 3.0, 4.6, 1.4},
        {2.0, 5.8, 2.6, 4.0, 1.2},
        {2.0, 5.0, 2.3, 3.3, 1.0},
        {2.0, 5.6, 2.7, 4.2, 1.3},
        {2.0, 5.7, 3.0, 4.2, 1.2},
        {2.0, 5.7, 2.9, 4.2, 1.3},
        {2.0, 6.2, 2.9, 4.3, 1.3},
        {2.0, 5.1, 2.5, 3.0, 1.1},
        {2.0, 5.7, 2.8, 4.1, 1.3},
        {3.0, 6.3, 3.3, 6.0, 2.5},
        {3.0, 5.8, 2.7, 5.1, 1.9},
        {3.0, 7.1, 3.0, 5.9, 2.1},
        {3.0, 6.3, 2.9, 5.6, 1.8},
        {3.0, 6.5, 3.0, 5.8, 2.2},
        {3.0, 7.6, 3.0, 6.6, 2.1},
        {3.0, 4.9, 2.5, 4.5, 1.7},
        {3.0, 7.3, 2.9, 6.3, 1.8},
        {3.0, 6.7, 2.5, 5.8, 1.8},
        {3.0, 7.2, 3.6, 6.1, 2.5},
        {3.0, 6.5, 3.2, 5.1, 2.0},
        {3.0, 6.4, 2.7, 5.3, 1.9},
        {3.0, 6.8, 3.0, 5.5, 2.1},
        {3.0, 5.7, 2.5, 5.0, 2.0},
        {3.0, 5.8, 2.8, 5.1, 2.4},
        {3.0, 6.4, 3.2, 5.3, 2.3},
        {3.0, 6.5, 3.0, 5.5, 1.8},
        {3.0, 7.7, 3.8, 6.7, 2.2},
        {3.0, 7.7, 2.6, 6.9, 2.3},
        {3.0, 6.0, 2.2, 5.0, 1.5},
        {3.0, 6.9, 3.2, 5.7, 2.3},
        {3.0, 5.6, 2.8, 4.9, 2.0},
        {3.0, 7.7, 2.8, 6.7, 2.0},
        {3.0, 6.3, 2.7, 4.9, 1.8},
        {3.0, 6.7, 3.3, 5.7, 2.1},
        {3.0, 7.2, 3.2, 6.0, 1.8},
        {3.0, 6.2, 2.8, 4.8, 1.8},
        {3.0, 6.1, 3.0, 4.9, 1.8},
        {3.0, 6.4, 2.8, 5.6, 2.1},
        {3.0, 7.2, 3.0, 5.8, 1.6},
        {3.0, 7.4, 2.8, 6.1, 1.9},
        {3.0, 7.9, 3.8, 6.4, 2.0},
        {3.0, 6.4, 2.8, 5.6, 2.2},
        {3.0, 6.3, 2.8, 5.1, 1.5},
        {3.0, 6.1, 2.6, 5.6, 1.4},
        {3.0, 7.7, 3.0, 6.1, 2.3},
        {3.0, 6.3, 3.4, 5.6, 2.4},
        {3.0, 6.4, 3.1, 5.5, 1.8},
        {3.0, 6.0, 3.0, 4.8, 1.8},
        {3.0, 6.9, 3.1, 5.4, 2.1},
        {3.0, 6.7, 3.1, 5.6, 2.4},
        {3.0, 6.9, 3.1, 5.1, 2.3},
        {3.0, 5.8, 2.7, 5.1, 1.9},
        {3.0, 6.8, 3.2, 5.9, 2.3},
        {3.0, 6.7, 3.3, 5.7, 2.5},
        {3.0, 6.7, 3.0, 5.2, 2.3},
        {3.0, 6.3, 2.5, 5.0, 1.9},
        {3.0, 6.5, 3.0, 5.2, 2.0},
        {3.0, 6.2, 3.4, 5.4, 2.3},
        {3.0, 5.9, 3.0, 5.1, 1.8}
    };

    /**
     * Trains the classifier on all but the last {@value #N_TEST} patterns and
     * prints the training error rates. It then prints the predicted class,
     * target class and class probabilities for each held-out pattern.
     *
     * @param args unused
     * @throws Exception propagated from the JMSL classifier calls
     */
    public static void main(String[] args) throws Exception {
        /* Data corrections described in the KDD data mining archive */
        // NOTE(review): the UCI iris.names notes give the corrected 35th sample
        // as 4.9,3.1,1.5,0.2 and the 38th as 4.9,3.6,1.4,0.1. Those are already
        // the values of rows 34 and 37 in the table above, so these assignments
        // appear to re-introduce the archive's uncorrected values. They are left
        // unchanged so the output matches the documented example; confirm
        // before altering.
        irisFisherData[34][4] = 0.1;
        irisFisherData[37][2] = 3.1;
        irisFisherData[37][3] = 1.5;

        // Column 0 is the class code; the remaining columns are attributes.
        int nAttributes = irisFisherData[0].length - 1;
        int nTrain = irisFisherData.length - N_TEST;

        /* Train first 140 patterns of the iris Fisher Data */
        int[] irisClassificationData = new int[nTrain];
        double[][] irisContinuousData = new double[nTrain][nAttributes];
        for (int i = 0; i < nTrain; i++) {
            // Class codes 1..3 become the zero-based class indices 0..2.
            irisClassificationData[i] = (int) irisFisherData[i][0] - 1;
            System.arraycopy(irisFisherData[i], 1, irisContinuousData[i], 0, nAttributes);
        }

        int nNominal = 0;    /* no nominal input attributes */
        int nContinuous = 4; /* four continuous input attributes */
        int nClasses = 3;    /* three classification categories */
        NaiveBayesClassifier nbTrainer =
                new NaiveBayesClassifier(nContinuous, nNominal, nClasses);

        // means[attribute][class] and stdev[attribute][class]: fixed parameters
        // for the user-supplied Gaussian densities.
        double[][] means = {
            {5.06, 5.94, 6.58},
            {3.42, 2.8, 2.97},
            {1.5, 4.3, 5.6},
            {0.25, 1.33, 2.1}
        };
        double[][] stdev = {
            {0.35, 0.52, 0.64},
            {0.38, 0.3, 0.32},
            {0.17, 0.47, 0.55},
            {0.12, 0.198, 0.275}
        };
        for (int i = 0; i < nContinuous; i++) {
            ProbabilityDistribution[] pdf = new ProbabilityDistribution[nClasses];
            for (int j = 0; j < nClasses; j++) {
                pdf[j] = new TestGaussFcn1(means[i][j], stdev[i][j]);
            }
            nbTrainer.createContinuousAttribute(pdf);
        }
        nbTrainer.train(irisContinuousData, irisClassificationData);

        int[][] classErrors = nbTrainer.getTrainingErrors();
        System.out.println(" Iris Classification Error Rates");
        System.out.println("------------------------------------------------");
        System.out.println(" Setosa Versicolour Virginica | Total");
        System.out.println(" " + classErrors[0][0] + "/" + classErrors[0][1]
                + " " + classErrors[1][0] + "/" + classErrors[1][1]
                + " " + classErrors[2][0] + "/" + classErrors[2][1]
                + " | " + classErrors[3][0] + "/" + classErrors[3][1]);
        System.out.println(
                "------------------------------------------------\n\n\n");

        /* Classify last 10 iris data patterns with the trained classifier */
        // NOTE: despite the heading, every held-out pattern is printed, whether
        // or not it was classified correctly.
        double[] continuousInput = new double[nAttributes];
        System.out.println("Probabilities for Incorrect Classifications");
        System.out.println(" Predicted ");
        System.out.println(
                " Class | Class | P(0) P(1) P(2) ");
        System.out.println(
                "-------------------------------------------------------");
        for (int i = 0; i < N_TEST; i++) {
            double[] pattern = irisFisherData[nTrain + i];
            int targetClassification = (int) pattern[0] - 1;
            System.arraycopy(pattern, 1, continuousInput, 0, nAttributes);

            double[] classifiedProbabilities =
                    nbTrainer.probabilities(continuousInput, null);
            int classification = nbTrainer.predictClass(continuousInput, null);

            System.out.print(className(classification) + " |");
            System.out.print(" " + className(targetClassification) + " |");
            for (int j = 0; j < nClasses; j++) {
                // Autoboxing replaces the deprecated new Double(...) wrapper.
                System.out.printf(" %2.3f ", classifiedProbabilities[j]);
            }
            System.out.println("");
        }
    }

    /** Returns the display name for a zero-based class index, or "Missing" if it is out of range. */
    private static String className(int classIndex) {
        switch (classIndex) {
            case 0:
                return "Setosa";
            case 1:
                return "Versicolour";
            case 2:
                return "Virginica";
            default:
                return "Missing";
        }
    }

    /**
     * Normal (Gaussian) probability density with a fixed mean and standard
     * deviation supplied through the constructor. The parameters are not
     * estimated from the data passed to {@code eval}.
     */
    public static class TestGaussFcn1 implements ProbabilityDistribution {

        /* sqrt(2*pi), the normalizing factor of the normal density. */
        private static final double SQRT_2PI = 2.506628274631;

        private final double mean;
        private final double stdev;

        /**
         * @param mean  mean of the distribution
         * @param stdev standard deviation of the distribution; NaN is accepted
         *              and makes every evaluation return NaN
         * @throws IllegalArgumentException if {@code stdev} is zero or negative,
         *         which would otherwise yield NaN or negative "densities"
         */
        public TestGaussFcn1(double mean, double stdev) {
            if (stdev <= 0.0) {
                throw new IllegalArgumentException(
                        "stdev must be positive, got " + stdev);
            }
            this.mean = mean;
            this.stdev = stdev;
        }

        /** Evaluates the density at each element of {@code xData}. */
        public double[] eval(double[] xData) {
            return eval(xData, null);
        }

        /** Evaluates the density at each element of {@code xData}; {@code params} is ignored. */
        public double[] eval(double[] xData, Object[] params) {
            double[] pdf = new double[xData.length];
            for (int i = 0; i < xData.length; i++) {
                pdf[i] = eval(xData[i], params);
            }
            return pdf;
        }

        /** Evaluates the density at {@code xData}; {@code params} is ignored. */
        public double eval(double xData, Object[] params) {
            return l_gaussian_pdf(xData, mean, stdev);
        }

        /** Returns {@code {mean, stdev}} as {@code Double}s. */
        public Object[] getParameters() {
            return new Double[] {mean, stdev};
        }

        /* Normal density at x; NaN if any argument is NaN. */
        private double l_gaussian_pdf(double x, double mean, double stdev) {
            if (Double.isNaN(x) || Double.isNaN(mean) || Double.isNaN(stdev)) {
                return Double.NaN;
            }
            double z = x - mean;
            double e = -0.5 * (z * z) / (stdev * stdev);
            return Math.exp(e) / (SQRT_2PI * stdev);
        }
    }
}
The Naive Bayes classifier incorrectly classifies 6 of the 140 training patterns.
 Iris Classification Error Rates
------------------------------------------------
 Setosa Versicolour Virginica | Total
 0/50 2/50 4/40 | 6/140
------------------------------------------------



Probabilities for Incorrect Classifications
 Predicted
 Class | Class | P(0) P(1) P(2)
-------------------------------------------------------
Virginica | Virginica | 0.000 0.000 1.000
Virginica | Virginica | 0.000 0.000 1.000
Virginica | Virginica | 0.000 0.051 0.949
Virginica | Virginica | 0.000 0.000 1.000
Virginica | Virginica | 0.000 0.000 1.000
Virginica | Virginica | 0.000 0.000 1.000
Virginica | Virginica | 0.000 0.048 0.952
Virginica | Virginica | 0.000 0.001 0.999
Virginica | Virginica | 0.000 0.000 1.000
Virginica | Virginica | 0.000 0.126 0.874

Link to Java source.