This example is the same as Example 1: it uses Fisher's (1936) Iris data to train a Naive Bayes classifier on 140 of the 150 continuous patterns, then classifies ten unknown plants using their sepal and petal measurements.
Instead of using the NormalDistribution class from the Imsl.Stat namespace, a user-supplied normal (Gaussian) distribution is used. Rather than calculating the means and standard deviations from the data, as the NormalDistribution class's Eval(double[]) method does, the user-supplied class requires the means and standard deviations in its constructor. The output is the same as in Example 1 because the means and standard deviations in this example are simply rounded versions of those computed from the actual data, subset by target classification.
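For reference, the density evaluated by the user-supplied class (the GaussianPdf method in this example's listing) can be written out explicitly. The second formula is the standard textbook Naive Bayes posterior for the four continuous attributes and three classes; that the library uses class priors \pi_k equal to the training-set class frequencies is an assumption here, not something this page states.

```latex
% Density computed by GaussianPdf(x, mean, stdev)
f(x \mid \mu, \sigma) = \frac{1}{\sqrt{2\pi}\,\sigma}
    \exp\!\left(-\frac{(x-\mu)^2}{2\sigma^2}\right)

% Textbook Naive Bayes posterior; \pi_k is the (assumed) prior of class k
P(k \mid x_1,\dots,x_4) =
    \frac{\pi_k \prod_{i=1}^{4} f(x_i \mid \mu_{ik}, \sigma_{ik})}
         {\sum_{c=1}^{3} \pi_c \prod_{i=1}^{4} f(x_i \mid \mu_{ic}, \sigma_{ic})}
```

Here \mu_{ik} and \sigma_{ik} correspond to means[i][j] and stdev[i][j] in the listing, with i indexing the attribute and k (j in the code) indexing the class.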
using System;
using Imsl.DataMining;
using IProbabilityDistribution = Imsl.Stat.IProbabilityDistribution;

public class NaiveBayesClassifierEx3
{
    private static double[][] irisFisherData = new double[][]{
        new double[]{1.0, 5.1, 3.5, 1.4, .2}, new double[]{1.0, 4.9, 3.0, 1.4, .2},
        new double[]{1.0, 4.7, 3.2, 1.3, .2}, new double[]{1.0, 4.6, 3.1, 1.5, .2},
        new double[]{1.0, 5.0, 3.6, 1.4, .2}, new double[]{1.0, 5.4, 3.9, 1.7, .4},
        new double[]{1.0, 4.6, 3.4, 1.4, .3}, new double[]{1.0, 5.0, 3.4, 1.5, .2},
        new double[]{1.0, 4.4, 2.9, 1.4, .2}, new double[]{1.0, 4.9, 3.1, 1.5, .1},
        new double[]{1.0, 5.4, 3.7, 1.5, .2}, new double[]{1.0, 4.8, 3.4, 1.6, .2},
        new double[]{1.0, 4.8, 3.0, 1.4, .1}, new double[]{1.0, 4.3, 3.0, 1.1, .1},
        new double[]{1.0, 5.8, 4.0, 1.2, .2}, new double[]{1.0, 5.7, 4.4, 1.5, .4},
        new double[]{1.0, 5.4, 3.9, 1.3, .4}, new double[]{1.0, 5.1, 3.5, 1.4, .3},
        new double[]{1.0, 5.7, 3.8, 1.7, .3}, new double[]{1.0, 5.1, 3.8, 1.5, .3},
        new double[]{1.0, 5.4, 3.4, 1.7, .2}, new double[]{1.0, 5.1, 3.7, 1.5, .4},
        new double[]{1.0, 4.6, 3.6, 1.0, .2}, new double[]{1.0, 5.1, 3.3, 1.7, .5},
        new double[]{1.0, 4.8, 3.4, 1.9, .2}, new double[]{1.0, 5.0, 3.0, 1.6, .2},
        new double[]{1.0, 5.0, 3.4, 1.6, .4}, new double[]{1.0, 5.2, 3.5, 1.5, .2},
        new double[]{1.0, 5.2, 3.4, 1.4, .2}, new double[]{1.0, 4.7, 3.2, 1.6, .2},
        new double[]{1.0, 4.8, 3.1, 1.6, .2}, new double[]{1.0, 5.4, 3.4, 1.5, .4},
        new double[]{1.0, 5.2, 4.1, 1.5, .1}, new double[]{1.0, 5.5, 4.2, 1.4, .2},
        new double[]{1.0, 4.9, 3.1, 1.5, .2}, new double[]{1.0, 5.0, 3.2, 1.2, .2},
        new double[]{1.0, 5.5, 3.5, 1.3, .2}, new double[]{1.0, 4.9, 3.6, 1.4, .1},
        new double[]{1.0, 4.4, 3.0, 1.3, .2}, new double[]{1.0, 5.1, 3.4, 1.5, .2},
        new double[]{1.0, 5.0, 3.5, 1.3, .3}, new double[]{1.0, 4.5, 2.3, 1.3, .3},
        new double[]{1.0, 4.4, 3.2, 1.3, .2}, new double[]{1.0, 5.0, 3.5, 1.6, .6},
        new double[]{1.0, 5.1, 3.8, 1.9, .4}, new double[]{1.0, 4.8, 3.0, 1.4, .3},
        new double[]{1.0, 5.1, 3.8, 1.6, .2}, new double[]{1.0, 4.6, 3.2, 1.4, .2},
        new double[]{1.0, 5.3, 3.7, 1.5, .2}, new double[]{1.0, 5.0, 3.3, 1.4, .2},
        new double[]{2.0, 7.0, 3.2, 4.7, 1.4}, new double[]{2.0, 6.4, 3.2, 4.5, 1.5},
        new double[]{2.0, 6.9, 3.1, 4.9, 1.5}, new double[]{2.0, 5.5, 2.3, 4.0, 1.3},
        new double[]{2.0, 6.5, 2.8, 4.6, 1.5}, new double[]{2.0, 5.7, 2.8, 4.5, 1.3},
        new double[]{2.0, 6.3, 3.3, 4.7, 1.6}, new double[]{2.0, 4.9, 2.4, 3.3, 1.0},
        new double[]{2.0, 6.6, 2.9, 4.6, 1.3}, new double[]{2.0, 5.2, 2.7, 3.9, 1.4},
        new double[]{2.0, 5.0, 2.0, 3.5, 1.0}, new double[]{2.0, 5.9, 3.0, 4.2, 1.5},
        new double[]{2.0, 6.0, 2.2, 4.0, 1.0}, new double[]{2.0, 6.1, 2.9, 4.7, 1.4},
        new double[]{2.0, 5.6, 2.9, 3.6, 1.3}, new double[]{2.0, 6.7, 3.1, 4.4, 1.4},
        new double[]{2.0, 5.6, 3.0, 4.5, 1.5}, new double[]{2.0, 5.8, 2.7, 4.1, 1.0},
        new double[]{2.0, 6.2, 2.2, 4.5, 1.5}, new double[]{2.0, 5.6, 2.5, 3.9, 1.1},
        new double[]{2.0, 5.9, 3.2, 4.8, 1.8}, new double[]{2.0, 6.1, 2.8, 4.0, 1.3},
        new double[]{2.0, 6.3, 2.5, 4.9, 1.5}, new double[]{2.0, 6.1, 2.8, 4.7, 1.2},
        new double[]{2.0, 6.4, 2.9, 4.3, 1.3}, new double[]{2.0, 6.6, 3.0, 4.4, 1.4},
        new double[]{2.0, 6.8, 2.8, 4.8, 1.4}, new double[]{2.0, 6.7, 3.0, 5.0, 1.7},
        new double[]{2.0, 6.0, 2.9, 4.5, 1.5}, new double[]{2.0, 5.7, 2.6, 3.5, 1.0},
        new double[]{2.0, 5.5, 2.4, 3.8, 1.1}, new double[]{2.0, 5.5, 2.4, 3.7, 1.0},
        new double[]{2.0, 5.8, 2.7, 3.9, 1.2}, new double[]{2.0, 6.0, 2.7, 5.1, 1.6},
        new double[]{2.0, 5.4, 3.0, 4.5, 1.5}, new double[]{2.0, 6.0, 3.4, 4.5, 1.6},
        new double[]{2.0, 6.7, 3.1, 4.7, 1.5}, new double[]{2.0, 6.3, 2.3, 4.4, 1.3},
        new double[]{2.0, 5.6, 3.0, 4.1, 1.3}, new double[]{2.0, 5.5, 2.5, 4.0, 1.3},
        new double[]{2.0, 5.5, 2.6, 4.4, 1.2}, new double[]{2.0, 6.1, 3.0, 4.6, 1.4},
        new double[]{2.0, 5.8, 2.6, 4.0, 1.2}, new double[]{2.0, 5.0, 2.3, 3.3, 1.0},
        new double[]{2.0, 5.6, 2.7, 4.2, 1.3}, new double[]{2.0, 5.7, 3.0, 4.2, 1.2},
        new double[]{2.0, 5.7, 2.9, 4.2, 1.3}, new double[]{2.0, 6.2, 2.9, 4.3, 1.3},
        new double[]{2.0, 5.1, 2.5, 3.0, 1.1}, new double[]{2.0, 5.7, 2.8, 4.1, 1.3},
        new double[]{3.0, 6.3, 3.3, 6.0, 2.5}, new double[]{3.0, 5.8, 2.7, 5.1, 1.9},
        new double[]{3.0, 7.1, 3.0, 5.9, 2.1}, new double[]{3.0, 6.3, 2.9, 5.6, 1.8},
        new double[]{3.0, 6.5, 3.0, 5.8, 2.2}, new double[]{3.0, 7.6, 3.0, 6.6, 2.1},
        new double[]{3.0, 4.9, 2.5, 4.5, 1.7}, new double[]{3.0, 7.3, 2.9, 6.3, 1.8},
        new double[]{3.0, 6.7, 2.5, 5.8, 1.8}, new double[]{3.0, 7.2, 3.6, 6.1, 2.5},
        new double[]{3.0, 6.5, 3.2, 5.1, 2.0}, new double[]{3.0, 6.4, 2.7, 5.3, 1.9},
        new double[]{3.0, 6.8, 3.0, 5.5, 2.1}, new double[]{3.0, 5.7, 2.5, 5.0, 2.0},
        new double[]{3.0, 5.8, 2.8, 5.1, 2.4}, new double[]{3.0, 6.4, 3.2, 5.3, 2.3},
        new double[]{3.0, 6.5, 3.0, 5.5, 1.8}, new double[]{3.0, 7.7, 3.8, 6.7, 2.2},
        new double[]{3.0, 7.7, 2.6, 6.9, 2.3}, new double[]{3.0, 6.0, 2.2, 5.0, 1.5},
        new double[]{3.0, 6.9, 3.2, 5.7, 2.3}, new double[]{3.0, 5.6, 2.8, 4.9, 2.0},
        new double[]{3.0, 7.7, 2.8, 6.7, 2.0}, new double[]{3.0, 6.3, 2.7, 4.9, 1.8},
        new double[]{3.0, 6.7, 3.3, 5.7, 2.1}, new double[]{3.0, 7.2, 3.2, 6.0, 1.8},
        new double[]{3.0, 6.2, 2.8, 4.8, 1.8}, new double[]{3.0, 6.1, 3.0, 4.9, 1.8},
        new double[]{3.0, 6.4, 2.8, 5.6, 2.1}, new double[]{3.0, 7.2, 3.0, 5.8, 1.6},
        new double[]{3.0, 7.4, 2.8, 6.1, 1.9}, new double[]{3.0, 7.9, 3.8, 6.4, 2.0},
        new double[]{3.0, 6.4, 2.8, 5.6, 2.2}, new double[]{3.0, 6.3, 2.8, 5.1, 1.5},
        new double[]{3.0, 6.1, 2.6, 5.6, 1.4}, new double[]{3.0, 7.7, 3.0, 6.1, 2.3},
        new double[]{3.0, 6.3, 3.4, 5.6, 2.4}, new double[]{3.0, 6.4, 3.1, 5.5, 1.8},
        new double[]{3.0, 6.0, 3.0, 4.8, 1.8}, new double[]{3.0, 6.9, 3.1, 5.4, 2.1},
        new double[]{3.0, 6.7, 3.1, 5.6, 2.4}, new double[]{3.0, 6.9, 3.1, 5.1, 2.3},
        new double[]{3.0, 5.8, 2.7, 5.1, 1.9}, new double[]{3.0, 6.8, 3.2, 5.9, 2.3},
        new double[]{3.0, 6.7, 3.3, 5.7, 2.5}, new double[]{3.0, 6.7, 3.0, 5.2, 2.3},
        new double[]{3.0, 6.3, 2.5, 5.0, 1.9}, new double[]{3.0, 6.5, 3.0, 5.2, 2.0},
        new double[]{3.0, 6.2, 3.4, 5.4, 2.3}, new double[]{3.0, 5.9, 3.0, 5.1, 1.8}};

    public static void Main(System.String[] args)
    {
        /* Data corrections described in the KDD data mining archive */
        irisFisherData[34][4] = 0.1;
        irisFisherData[37][2] = 3.1;
        irisFisherData[37][3] = 1.5;

        /* Train first 140 patterns of the iris Fisher Data */
        int[] irisClassificationData = new int[irisFisherData.Length - 10];
        double[][] irisContinuousData = new double[irisFisherData.Length - 10][];
        for (int i = 0; i < irisFisherData.Length - 10; i++)
        {
            irisContinuousData[i] = new double[irisFisherData[0].Length - 1];
        }
        for (int i = 0; i < irisFisherData.Length - 10; i++)
        {
            irisClassificationData[i] = (int) irisFisherData[i][0] - 1;
            Array.Copy(irisFisherData[i], 1, irisContinuousData[i], 0,
                irisFisherData[0].Length - 1);
        }

        int nNominal = 0; /* no nominal input attributes */
        int nContinuous = 4; /* four continuous input attributes */
        int nClasses = 3; /* three classification categories */

        NaiveBayesClassifier nbTrainer =
            new NaiveBayesClassifier(nContinuous, nNominal, nClasses);

        double[][] means = new double[][]{
            new double[]{5.06, 5.94, 6.58},
            new double[]{3.42, 2.8, 2.97},
            new double[]{1.5, 4.3, 5.6},
            new double[]{0.25, 1.33, 2.1}
        };
        double[][] stdev = new double[][]{
            new double[]{0.35, 0.52, 0.64},
            new double[]{0.38, 0.3, 0.32},
            new double[]{0.17, 0.47, 0.55},
            new double[]{0.12, 0.198, 0.275}
        };

        for (int i = 0; i < nContinuous; i++)
        {
            IProbabilityDistribution[] pdf = new IProbabilityDistribution[nClasses];
            for (int j = 0; j < nClasses; j++)
            {
                pdf[j] = new TestGaussFcn1(means[i][j], stdev[i][j]);
            }
            nbTrainer.CreateContinuousAttribute(pdf);
        }

        nbTrainer.Train(irisContinuousData, null, irisClassificationData);

        int[][] classErrors = nbTrainer.GetTrainingErrors();

        System.Console.Out.WriteLine(
            " Iris Classification Error Rates");
        System.Console.Out.WriteLine(
            "------------------------------------------------");
        System.Console.Out.WriteLine(
            " Setosa Versicolour Virginica | Total");
        System.Console.Out.WriteLine(" " + classErrors[0][0] + "/" + classErrors[0][1]
            + " " + classErrors[1][0] + "/" + classErrors[1][1]
            + " " + classErrors[2][0] + "/" + classErrors[2][1]
            + " | " + classErrors[3][0] + "/" + classErrors[3][1]);
        System.Console.Out.WriteLine(
            "------------------------------------------------\n\n\n");

        /* Classify last 10 iris data patterns
         * with the trained classifier */
        double[] continuousInput = new double[(irisFisherData[0].Length - 1)];
        double[] classifiedProbabilities = new double[nClasses];

        System.Console.Out.WriteLine(
            "Probabilities for Incorrect Classifications");
        System.Console.Out.WriteLine(" Predicted ");
        System.Console.Out.WriteLine(
            " Class | Class | P(0) P(1) P(2) ");
        System.Console.Out.WriteLine(
            "-------------------------------------------------------");
        for (int i = 0; i < 10; i++)
        {
            int targetClassification =
                (int) irisFisherData[(irisFisherData.Length - 10) + i][0] - 1;
            Array.Copy(irisFisherData[(irisFisherData.Length - 10) + i], 1,
                continuousInput, 0, (irisFisherData[0].Length - 1));
            classifiedProbabilities =
                nbTrainer.Probabilities(continuousInput, null);
            int classification = nbTrainer.PredictClass(continuousInput, null);
            if (classification == 0)
            {
                System.Console.Out.Write("Setosa |");
            }
            else if (classification == 1)
            {
                System.Console.Out.Write("Versicolour |");
            }
            else if (classification == 2)
            {
                System.Console.Out.Write("Virginica |");
            }
            else
            {
                System.Console.Out.Write("Missing |");
            }
            if (targetClassification == 0)
            {
                System.Console.Out.Write(" Setosa |");
            }
            else if (targetClassification == 1)
            {
                System.Console.Out.Write(" Versicolour |");
            }
            else if (targetClassification == 2)
            {
                System.Console.Out.Write(" Virginica |");
            }
            else
            {
                System.Console.Out.Write(" Missing |");
            }
            for (int j = 0; j < nClasses; j++)
            {
                System.Object[] pArgs =
                    new System.Object[] { (double)classifiedProbabilities[j] };
                System.Console.Out.Write(" {0, 2:f3} ", pArgs);
            }
            System.Console.Out.WriteLine("");
        }
    }

    public class TestGaussFcn1 : IProbabilityDistribution
    {
        virtual public System.Object[] GetParameters()
        {
            System.Object[] parms = new System.Object[2];
            parms[0] = this.mean;
            parms[1] = this.stdev;
            return parms;
        }

        private double mean;
        private double stdev;

        public TestGaussFcn1(double mean, double stdev)
        {
            this.mean = mean;
            this.stdev = stdev;
        }

        public virtual double[] Eval(double[] xData)
        {
            double[] pdf = new double[xData.Length];
            for (int i = 0; i < xData.Length; i++)
            {
                pdf[i] = Eval(xData[i], null);
            }
            return pdf;
        }

        public virtual double[] Eval(double[] xData, System.Object[] Params)
        {
            double[] pdf = new double[xData.Length];
            for (int i = 0; i < xData.Length; i++)
            {
                pdf[i] = Eval(xData[i], Params);
            }
            return pdf;
        }

        public virtual double Eval(double xData, System.Object[] Params)
        {
            return GaussianPdf(xData, mean, stdev);
        }

        private double GaussianPdf(double x, double mean, double stdev)
        {
            double e, phi2, z, s;
            double sqrt_pi2 = 2.506628274631; /* sqrt(2*pi) */
            if (System.Double.IsNaN(x))
            {
                return System.Double.NaN;
            }
            if (System.Double.IsNaN(mean) || System.Double.IsNaN(stdev))
            {
                return System.Double.NaN;
            }
            else
            {
                z = x;
                z -= mean;
                s = stdev;
                phi2 = sqrt_pi2 * s;
                e = (- 0.5) * (z * z) / (s * s);
                return System.Math.Exp(e) / phi2;
            }
        }
    }
}
The Naive Bayes classifier incorrectly classifies 6 of the 140 training patterns.
        Iris Classification Error Rates
------------------------------------------------
  Setosa   Versicolour   Virginica  |   Total
   0/50       2/50         4/40     |   6/140
------------------------------------------------

Probabilities for Incorrect Classifications
 Predicted
   Class      |    Class     |  P(0)   P(1)   P(2)
-------------------------------------------------------
Virginica     |  Virginica   |  0.000  0.000  1.000
Virginica     |  Virginica   |  0.000  0.000  1.000
Virginica     |  Virginica   |  0.000  0.051  0.949
Virginica     |  Virginica   |  0.000  0.000  1.000
Virginica     |  Virginica   |  0.000  0.000  1.000
Virginica     |  Virginica   |  0.000  0.000  1.000
Virginica     |  Virginica   |  0.000  0.048  0.952
Virginica     |  Virginica   |  0.000  0.001  0.999
Virginica     |  Virginica   |  0.000  0.000  1.000
Virginica     |  Virginica   |  0.000  0.126  0.874
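As a sanity check on the last row of the table, the posterior can be recomputed by hand without the library. The following is a minimal sketch, not the library's implementation: the class name PosteriorSketch and its members are hypothetical, and it assumes the class priors are the training-set frequencies (50, 50 and 40 of the 140 training patterns). The means and standard deviations are the ones passed to TestGaussFcn1 in the listing, and the input is the last Iris pattern (5.9, 3.0, 5.1, 1.8).

```csharp
using System;
using System.Globalization;

// Sketch only: textbook Gaussian Naive Bayes posterior, independent of IMSL.
// ASSUMPTION: priors are the training-set class frequencies (50, 50, 40 of 140).
// PosteriorSketch and its member names are hypothetical.
public static class PosteriorSketch
{
    // Rows are attributes, columns are classes, as in the example's tables.
    static readonly double[][] Means = {
        new[] {5.06, 5.94, 6.58}, new[] {3.42, 2.8, 2.97},
        new[] {1.5, 4.3, 5.6},    new[] {0.25, 1.33, 2.1}
    };
    static readonly double[][] Stdev = {
        new[] {0.35, 0.52, 0.64}, new[] {0.38, 0.3, 0.32},
        new[] {0.17, 0.47, 0.55}, new[] {0.12, 0.198, 0.275}
    };
    static readonly double[] Priors = { 50.0 / 140, 50.0 / 140, 40.0 / 140 };

    // Normal density with the given mean and standard deviation.
    static double GaussianPdf(double x, double mean, double stdev)
    {
        double z = (x - mean) / stdev;
        return Math.Exp(-0.5 * z * z) / (Math.Sqrt(2.0 * Math.PI) * stdev);
    }

    // Prior times product of per-attribute densities, normalized over classes.
    public static double[] Posterior(double[] x)
    {
        double[] p = new double[Priors.Length];
        double total = 0.0;
        for (int k = 0; k < p.Length; k++)
        {
            p[k] = Priors[k];
            for (int i = 0; i < x.Length; i++)
            {
                p[k] *= GaussianPdf(x[i], Means[i][k], Stdev[i][k]);
            }
            total += p[k];
        }
        for (int k = 0; k < p.Length; k++)
        {
            p[k] /= total;
        }
        return p;
    }

    public static void Main()
    {
        double[] p = Posterior(new[] { 5.9, 3.0, 5.1, 1.8 });
        Console.WriteLine(string.Format(CultureInfo.InvariantCulture,
            "{0:f3} {1:f3} {2:f3}", p[0], p[1], p[2]));
    }
}
```

Under the stated prior assumption this works out to roughly 0.000, 0.126 and 0.874, in agreement with the last row of the table. With equal priors the middle value would come out near 0.104 instead, which suggests (but does not prove) that the classifier weights classes by their training frequencies.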