package com.imsl.test.example.datamining;

import com.imsl.datamining.*;
import com.imsl.stat.ProbabilityDistribution;

/**
 * Trains a classifier with a user supplied probability function.
 * <p>
 * This example is similar to {@link NaiveBayesClassifierEx1}, where we train a
 * classifier on Fisher's Iris data using 140 of the 150 continuous patterns,
 * and then classify ten remaining plants using their sepal and petal
 * measurements.
 * <p>
 * Instead of using the NormalDistribution class, a user supplied normal
 * (Gaussian) distribution is supplied directly. Rather than calculating the
 * means and standard deviations from the data, as is done by the
 * NormalDistribution's eval(double[]) method, the user supplied class requires
 * the means and standard deviations in the class constructor. The output is
 * the same as in {@link NaiveBayesClassifierEx1}, since the means and standard
 * deviations in this example are simply rounded means and standard deviations
 * of the actual data subset by target classifications.
 *
* @see Example
* @see Output
*/
public class NaiveBayesClassifierEx3 {

    /**
     * Fisher's Iris data: 150 patterns of five values each. Column 0 is the
     * target classification (1 = Setosa, 2 = Versicolour, 3 = Virginica);
     * columns 1-4 are the sepal and petal measurements. Not deeply immutable:
     * two known transcription errors are corrected element-wise in main().
     */
    private static final double[][] irisData = {
        {1.0, 5.1, 3.5, 1.4, .2}, {1.0, 4.9, 3.0, 1.4, .2}, {1.0, 4.7, 3.2, 1.3, .2},
        {1.0, 4.6, 3.1, 1.5, .2}, {1.0, 5.0, 3.6, 1.4, .2}, {1.0, 5.4, 3.9, 1.7, .4},
        {1.0, 4.6, 3.4, 1.4, .3}, {1.0, 5.0, 3.4, 1.5, .2}, {1.0, 4.4, 2.9, 1.4, .2},
        {1.0, 4.9, 3.1, 1.5, .1}, {1.0, 5.4, 3.7, 1.5, .2}, {1.0, 4.8, 3.4, 1.6, .2},
        {1.0, 4.8, 3.0, 1.4, .1}, {1.0, 4.3, 3.0, 1.1, .1}, {1.0, 5.8, 4.0, 1.2, .2},
        {1.0, 5.7, 4.4, 1.5, .4}, {1.0, 5.4, 3.9, 1.3, .4}, {1.0, 5.1, 3.5, 1.4, .3},
        {1.0, 5.7, 3.8, 1.7, .3}, {1.0, 5.1, 3.8, 1.5, .3}, {1.0, 5.4, 3.4, 1.7, .2},
        {1.0, 5.1, 3.7, 1.5, .4}, {1.0, 4.6, 3.6, 1.0, .2}, {1.0, 5.1, 3.3, 1.7, .5},
        {1.0, 4.8, 3.4, 1.9, .2}, {1.0, 5.0, 3.0, 1.6, .2}, {1.0, 5.0, 3.4, 1.6, .4},
        {1.0, 5.2, 3.5, 1.5, .2}, {1.0, 5.2, 3.4, 1.4, .2}, {1.0, 4.7, 3.2, 1.6, .2},
        {1.0, 4.8, 3.1, 1.6, .2}, {1.0, 5.4, 3.4, 1.5, .4}, {1.0, 5.2, 4.1, 1.5, .1},
        {1.0, 5.5, 4.2, 1.4, .2}, {1.0, 4.9, 3.1, 1.5, .2}, {1.0, 5.0, 3.2, 1.2, .2},
        {1.0, 5.5, 3.5, 1.3, .2}, {1.0, 4.9, 3.6, 1.4, .1}, {1.0, 4.4, 3.0, 1.3, .2},
        {1.0, 5.1, 3.4, 1.5, .2}, {1.0, 5.0, 3.5, 1.3, .3}, {1.0, 4.5, 2.3, 1.3, .3},
        {1.0, 4.4, 3.2, 1.3, .2}, {1.0, 5.0, 3.5, 1.6, .6}, {1.0, 5.1, 3.8, 1.9, .4},
        {1.0, 4.8, 3.0, 1.4, .3}, {1.0, 5.1, 3.8, 1.6, .2}, {1.0, 4.6, 3.2, 1.4, .2},
        {1.0, 5.3, 3.7, 1.5, .2}, {1.0, 5.0, 3.3, 1.4, .2},
        {2.0, 7.0, 3.2, 4.7, 1.4}, {2.0, 6.4, 3.2, 4.5, 1.5}, {2.0, 6.9, 3.1, 4.9, 1.5},
        {2.0, 5.5, 2.3, 4.0, 1.3}, {2.0, 6.5, 2.8, 4.6, 1.5}, {2.0, 5.7, 2.8, 4.5, 1.3},
        {2.0, 6.3, 3.3, 4.7, 1.6}, {2.0, 4.9, 2.4, 3.3, 1.0}, {2.0, 6.6, 2.9, 4.6, 1.3},
        {2.0, 5.2, 2.7, 3.9, 1.4}, {2.0, 5.0, 2.0, 3.5, 1.0}, {2.0, 5.9, 3.0, 4.2, 1.5},
        {2.0, 6.0, 2.2, 4.0, 1.0}, {2.0, 6.1, 2.9, 4.7, 1.4}, {2.0, 5.6, 2.9, 3.6, 1.3},
        {2.0, 6.7, 3.1, 4.4, 1.4}, {2.0, 5.6, 3.0, 4.5, 1.5}, {2.0, 5.8, 2.7, 4.1, 1.0},
        {2.0, 6.2, 2.2, 4.5, 1.5}, {2.0, 5.6, 2.5, 3.9, 1.1}, {2.0, 5.9, 3.2, 4.8, 1.8},
        {2.0, 6.1, 2.8, 4.0, 1.3}, {2.0, 6.3, 2.5, 4.9, 1.5}, {2.0, 6.1, 2.8, 4.7, 1.2},
        {2.0, 6.4, 2.9, 4.3, 1.3}, {2.0, 6.6, 3.0, 4.4, 1.4}, {2.0, 6.8, 2.8, 4.8, 1.4},
        {2.0, 6.7, 3.0, 5.0, 1.7}, {2.0, 6.0, 2.9, 4.5, 1.5}, {2.0, 5.7, 2.6, 3.5, 1.0},
        {2.0, 5.5, 2.4, 3.8, 1.1}, {2.0, 5.5, 2.4, 3.7, 1.0}, {2.0, 5.8, 2.7, 3.9, 1.2},
        {2.0, 6.0, 2.7, 5.1, 1.6}, {2.0, 5.4, 3.0, 4.5, 1.5}, {2.0, 6.0, 3.4, 4.5, 1.6},
        {2.0, 6.7, 3.1, 4.7, 1.5}, {2.0, 6.3, 2.3, 4.4, 1.3}, {2.0, 5.6, 3.0, 4.1, 1.3},
        {2.0, 5.5, 2.5, 4.0, 1.3}, {2.0, 5.5, 2.6, 4.4, 1.2}, {2.0, 6.1, 3.0, 4.6, 1.4},
        {2.0, 5.8, 2.6, 4.0, 1.2}, {2.0, 5.0, 2.3, 3.3, 1.0}, {2.0, 5.6, 2.7, 4.2, 1.3},
        {2.0, 5.7, 3.0, 4.2, 1.2}, {2.0, 5.7, 2.9, 4.2, 1.3}, {2.0, 6.2, 2.9, 4.3, 1.3},
        {2.0, 5.1, 2.5, 3.0, 1.1}, {2.0, 5.7, 2.8, 4.1, 1.3},
        {3.0, 6.3, 3.3, 6.0, 2.5}, {3.0, 5.8, 2.7, 5.1, 1.9}, {3.0, 7.1, 3.0, 5.9, 2.1},
        {3.0, 6.3, 2.9, 5.6, 1.8}, {3.0, 6.5, 3.0, 5.8, 2.2}, {3.0, 7.6, 3.0, 6.6, 2.1},
        {3.0, 4.9, 2.5, 4.5, 1.7}, {3.0, 7.3, 2.9, 6.3, 1.8}, {3.0, 6.7, 2.5, 5.8, 1.8},
        {3.0, 7.2, 3.6, 6.1, 2.5}, {3.0, 6.5, 3.2, 5.1, 2.0}, {3.0, 6.4, 2.7, 5.3, 1.9},
        {3.0, 6.8, 3.0, 5.5, 2.1}, {3.0, 5.7, 2.5, 5.0, 2.0}, {3.0, 5.8, 2.8, 5.1, 2.4},
        {3.0, 6.4, 3.2, 5.3, 2.3}, {3.0, 6.5, 3.0, 5.5, 1.8}, {3.0, 7.7, 3.8, 6.7, 2.2},
        {3.0, 7.7, 2.6, 6.9, 2.3}, {3.0, 6.0, 2.2, 5.0, 1.5}, {3.0, 6.9, 3.2, 5.7, 2.3},
        {3.0, 5.6, 2.8, 4.9, 2.0}, {3.0, 7.7, 2.8, 6.7, 2.0}, {3.0, 6.3, 2.7, 4.9, 1.8},
        {3.0, 6.7, 3.3, 5.7, 2.1}, {3.0, 7.2, 3.2, 6.0, 1.8}, {3.0, 6.2, 2.8, 4.8, 1.8},
        {3.0, 6.1, 3.0, 4.9, 1.8}, {3.0, 6.4, 2.8, 5.6, 2.1}, {3.0, 7.2, 3.0, 5.8, 1.6},
        {3.0, 7.4, 2.8, 6.1, 1.9}, {3.0, 7.9, 3.8, 6.4, 2.0}, {3.0, 6.4, 2.8, 5.6, 2.2},
        {3.0, 6.3, 2.8, 5.1, 1.5}, {3.0, 6.1, 2.6, 5.6, 1.4}, {3.0, 7.7, 3.0, 6.1, 2.3},
        {3.0, 6.3, 3.4, 5.6, 2.4}, {3.0, 6.4, 3.1, 5.5, 1.8}, {3.0, 6.0, 3.0, 4.8, 1.8},
        {3.0, 6.9, 3.1, 5.4, 2.1}, {3.0, 6.7, 3.1, 5.6, 2.4}, {3.0, 6.9, 3.1, 5.1, 2.3},
        {3.0, 5.8, 2.7, 5.1, 1.9}, {3.0, 6.8, 3.2, 5.9, 2.3}, {3.0, 6.7, 3.3, 5.7, 2.5},
        {3.0, 6.7, 3.0, 5.2, 2.3}, {3.0, 6.3, 2.5, 5.0, 1.9}, {3.0, 6.5, 3.0, 5.2, 2.0},
        {3.0, 6.2, 3.4, 5.4, 2.3}, {3.0, 5.9, 3.0, 5.1, 1.8}
    };

    /**
     * The main method for the example. Trains a naive Bayes classifier on the
     * first 140 iris patterns using user supplied Gaussian distributions,
     * prints the training error rates, then classifies the remaining ten
     * patterns and prints their class probabilities.
     *
     * @param args not used
     * @throws Exception if training or classification fails
     */
    public static void main(String[] args) throws Exception {
        /* Data corrections described in the KDD data mining archive */
        irisData[34][4] = 0.1;
        irisData[37][2] = 3.1;
        irisData[37][3] = 1.5;

        /* Train first 140 patterns of the iris Fisher Data */
        int nTrain = irisData.length - 10;              // 140 training patterns
        int nVariables = irisData[0].length - 1;        // 4 continuous attributes
        int[] irisClassificationData = new int[nTrain];
        double[][] irisContinuousData = new double[nTrain][nVariables];
        for (int i = 0; i < nTrain; i++) {
            // Shift target classes from {1,2,3} to zero-based {0,1,2}.
            irisClassificationData[i] = (int) irisData[i][0] - 1;
            System.arraycopy(irisData[i], 1, irisContinuousData[i], 0, nVariables);
        }

        int nNominal = 0;       // no nominal input attributes
        int nContinuous = 4;    // four continuous input attributes
        int nClasses = 3;       // three classification categories

        NaiveBayesClassifier nbTrainer
                = new NaiveBayesClassifier(nContinuous, nNominal, nClasses);

        // Rounded per-class means and standard deviations of the training
        // data; row = attribute, column = class. Supplying these directly
        // replaces the NormalDistribution estimation used in Example 1.
        double[][] means = {
            {5.06, 5.94, 6.58},
            {3.42, 2.8, 2.97},
            {1.5, 4.3, 5.6},
            {0.25, 1.33, 2.1}
        };
        double[][] stdev = {
            {0.35, 0.52, 0.64},
            {0.38, 0.3, 0.32},
            {0.17, 0.47, 0.55},
            {0.12, 0.198, 0.275}
        };

        // One Gaussian per (attribute, class) pair.
        for (int i = 0; i < nContinuous; i++) {
            ProbabilityDistribution[] pdf = new ProbabilityDistribution[nClasses];
            for (int j = 0; j < nClasses; j++) {
                pdf[j] = new TestGaussFcn1(means[i][j], stdev[i][j]);
            }
            nbTrainer.createContinuousAttribute(pdf);
        }
        nbTrainer.train(irisContinuousData, irisClassificationData);

        // classErrors rows 0-2 are per-class {errors, patterns}; row 3 totals.
        int[][] classErrors = nbTrainer.getTrainingErrors();
        System.out.println(" Iris Classification Error Rates");
        System.out.println("------------------------------------------------");
        System.out.println(" Setosa Versicolour Virginica | Total");
        System.out.println(" " + classErrors[0][0] + "/" + classErrors[0][1]
                + " " + classErrors[1][0] + "/" + classErrors[1][1]
                + " " + classErrors[2][0] + "/" + classErrors[2][1]
                + " | " + classErrors[3][0] + "/" + classErrors[3][1]);
        System.out.println(
                "------------------------------------------------\n\n\n");

        /* Classify last 10 iris data patterns with the trained classifier */
        double[] continuousInput = new double[nVariables];
        System.out.println("Probabilities for Incorrect Classifications");
        System.out.println(" Predicted ");
        System.out.println(
                " Class | Class | P(0) P(1) P(2) ");
        System.out.println(
                "-------------------------------------------------------");
        for (int i = 0; i < 10; i++) {
            int row = nTrain + i;
            int targetClassification = (int) irisData[row][0] - 1;
            System.arraycopy(irisData[row], 1, continuousInput, 0, nVariables);
            double[] classifiedProbabilities
                    = nbTrainer.probabilities(continuousInput, null);
            int classification = nbTrainer.predictClass(continuousInput, null);
            // Predicted class first, then the known target class.
            System.out.print(speciesName(classification) + " |");
            System.out.print(" " + speciesName(targetClassification) + " |");
            for (int j = 0; j < nClasses; j++) {
                System.out.printf(" %2.3f ", classifiedProbabilities[j]);
            }
            System.out.println("");
        }
    }

    /**
     * Returns the iris species name for a zero-based class index, or
     * {@code "Missing"} for any index outside 0-2.
     *
     * @param classification the zero-based class index
     * @return the species display name
     */
    private static String speciesName(int classification) {
        switch (classification) {
            case 0:
                return "Setosa";
            case 1:
                return "Versicolour";
            case 2:
                return "Virginica";
            default:
                return "Missing";
        }
    }

    /**
     * Defines the user supplied probability distribution.
     * <p>
     * In this case, the probability distribution is Gaussian and results
     * should match the default case in Example 1.
     */
    static public class TestGaussFcn1 implements ProbabilityDistribution {

        /** sqrt(2*pi), the Gaussian pdf normalizing constant. */
        private static final double SQRT_2PI = 2.506628274631;

        private final double mean;   // distribution mean
        private final double stdev;  // distribution standard deviation

        /**
         * Constructs a Gaussian distribution with fixed parameters.
         *
         * @param mean  the mean of the distribution
         * @param stdev the standard deviation of the distribution
         */
        public TestGaussFcn1(double mean, double stdev) {
            this.mean = mean;
            this.stdev = stdev;
        }

        /**
         * Evaluates the pdf at each point of {@code xData}.
         *
         * @param xData the evaluation points
         * @return the pdf values, one per input point
         */
        @Override
        public double[] eval(double[] xData) {
            // Delegate to the parameterized overload; no extra parameters.
            return eval(xData, (Object[]) null);
        }

        /**
         * Evaluates the pdf at each point of {@code xData}. The extra
         * parameters are ignored by this implementation.
         */
        @Override
        public double[] eval(double[] xData, Object[] params) {
            double[] pdf = new double[xData.length];
            for (int i = 0; i < xData.length; i++) {
                pdf[i] = eval(xData[i], params);
            }
            return pdf;
        }

        /**
         * Evaluates the pdf at a single point. The extra parameters are
         * ignored by this implementation.
         */
        @Override
        public double eval(double xData, Object[] params) {
            return gaussianPdf(xData, mean, stdev);
        }

        /**
         * Returns the distribution parameters as {@code {mean, stdev}}.
         */
        @Override
        public Object[] getParameters() {
            return new Double[]{this.mean, this.stdev};
        }

        /**
         * Gaussian pdf at {@code x}; returns NaN when any argument is NaN.
         * The evaluation order matches the original implementation so the
         * results are bit-identical.
         */
        private double gaussianPdf(double x, double mean, double stdev) {
            if (Double.isNaN(x) || Double.isNaN(mean) || Double.isNaN(stdev)) {
                return Double.NaN;
            }
            double z = x - mean;
            double e = -0.5 * (z * z) / (stdev * stdev);
            return Math.exp(e) / (SQRT_2PI * stdev);
        }
    }
}