package com.imsl.test.example.datamining; import com.imsl.datamining.*; /** * *
* This example trains a NaiveBayesClassifier on 24 training patterns * with four nominal input attributes.
*
* The first nominal attribute has three categories and the others have
* two. The target classifications are contact lens prescriptions: hard, soft,
* or neither recommended. These data are benchmark data from the UCI Machine
* Learning Repository maintained at the University of California, Irvine:
*
* http://archive.ics.uci.edu/ml/datasets/Lenses
*
*
* @see Code
* @see Output
*/
public class NaiveBayesClassifierEx2 {

    /** Number of target classes: hard, soft, neither. */
    private static final int N_CLASSES = 3;

    /**
     * Table-column labels indexed by zero-based class value; shared by the
     * predicted and actual columns of the classification table.
     */
    private static final String[] CLASS_LABELS = {" Hard |", " Soft |", " Neither |"};

    /** Label printed for any class index outside {@code CLASS_LABELS}. */
    private static final String MISSING_LABEL = " Missing |";

    /**
     * Trains a naive Bayes classifier on the lenses data, then prints the
     * training error rates and the class probabilities for every pattern.
     *
     * @param args unused
     * @throws Exception propagated from the IMSL classifier calls
     */
    public static void main(String[] args) throws Exception {
        int[][] contactLensData = {
            {1, 1, 1, 1}, {1, 1, 1, 2}, {1, 1, 2, 1}, {1, 1, 2, 2},
            {1, 2, 1, 1}, {1, 2, 1, 2}, {1, 2, 2, 1}, {1, 2, 2, 2},
            {2, 1, 1, 1}, {2, 1, 1, 2}, {2, 1, 2, 1}, {2, 1, 2, 2},
            {2, 2, 1, 1}, {2, 2, 1, 2}, {2, 2, 2, 1}, {2, 2, 2, 2},
            {3, 1, 1, 1}, {3, 1, 1, 2}, {3, 1, 2, 1}, {3, 1, 2, 2},
            {3, 2, 1, 1}, {3, 2, 1, 2}, {3, 2, 2, 1}, {3, 2, 2, 2}
        };
        int[] classificationData = {
            3, 2, 3, 1, 3, 2, 3, 1, 3, 2, 3, 1,
            3, 2, 3, 3, 3, 3, 3, 1, 3, 2, 3, 3
        };

        /* classification values must start at 0 */
        // The data above is one-based. Shift each array on its own so that a
        // length mismatch between them cannot skip rows or index past the end.
        for (int[] row : contactLensData) {
            for (int j = 0; j < row.length; j++) {
                row[j] -= 1;
            }
        }
        for (int i = 0; i < classificationData.length; i++) {
            classificationData[i] -= 1;
        }

        // Number of categories for each nominal attribute, in column order.
        int[] categories = {3, 2, 2, 2};
        int nContinuous = 0;
        int nNominal = categories.length;

        NaiveBayesClassifier nbTrainer =
                new NaiveBayesClassifier(nContinuous, nNominal, N_CLASSES);
        for (int nCategories : categories) {
            nbTrainer.createNominalAttribute(nCategories);
        }
        nbTrainer.train(contactLensData, classificationData);

        printTrainingErrors(nbTrainer.getTrainingErrors());
        printClassifications(nbTrainer, contactLensData, classificationData);
    }

    /**
     * Prints the per-class and total training error table.
     *
     * @param classErrors rows 0..N_CLASSES-1 are printed as the per-class
     *                    columns and row N_CLASSES as the total; each row is
     *                    printed as {@code [0]/[1]} (presumably
     *                    misclassified/total patterns — confirm against the
     *                    IMSL {@code getTrainingErrors} documentation)
     */
    private static void printTrainingErrors(int[][] classErrors) {
        System.out.println("\n Contact Lens Error Rates");
        System.out.println("------------------------------------------------");
        System.out.println(" Hard Soft Neither | Total");
        System.out.println(" " + classErrors[0][0] + "/" + classErrors[0][1]
                + " " + classErrors[1][0] + "/" + classErrors[1][1]
                + " " + classErrors[2][0] + "/" + classErrors[2][1]
                + " | " + classErrors[N_CLASSES][0] + "/" + classErrors[N_CLASSES][1]);
        System.out.println(
                "------------------------------------------------\n\n\n");
    }

    /**
     * Classifies every training pattern with the trained classifier and prints
     * the predicted class, the actual class, the class probabilities and the
     * classification error.
     *
     * @param nbTrainer          the trained classifier
     * @param contactLensData    zero-based nominal inputs, one row per pattern
     * @param classificationData zero-based actual class of each pattern
     * @throws Exception propagated from the IMSL classifier calls
     */
    private static void printClassifications(NaiveBayesClassifier nbTrainer,
            int[][] contactLensData, int[] classificationData) throws Exception {
        // NOTE(review): despite this heading, every training pattern is listed;
        // the last column reports each pattern's classification error.
        System.out.println("Probabilities for Incorrect Classifications");
        System.out.println(" Predicted ");
        System.out.println(" Class | Class | P(0) P(1) P(2) | classification error");
        System.out.println("---------------------------------------"
                + "-----------------------------------------");
        for (int i = 0; i < contactLensData.length; i++) {
            // Classify a copy of the row rather than the training array itself.
            int[] nominalInput = contactLensData[i].clone();
            double[] classifiedProbabilities = nbTrainer.probabilities(null, nominalInput);
            int classification = nbTrainer.predictClass(null, nominalInput);
            double error = nbTrainer.classError(null, nominalInput, classificationData[i]);

            System.out.print(classLabel(classification));
            System.out.print(classLabel(classificationData[i]));
            for (int j = 0; j < N_CLASSES; j++) {
                // Autoboxing replaces the deprecated new Double(...) wrapper.
                System.out.printf(" %2.3f ", classifiedProbabilities[j]);
            }
            System.out.println(" | " + error);
        }
    }

    /**
     * Returns the table-column label for a zero-based class index.
     *
     * @param classIndex zero-based class value
     * @return the matching label, or {@code " Missing |"} for any index outside
     *         0..N_CLASSES-1
     */
    private static String classLabel(int classIndex) {
        if (classIndex >= 0 && classIndex < CLASS_LABELS.length) {
            return CLASS_LABELS[classIndex];
        }
        return MISSING_LABEL;
    }
}