This example trains a Naive Bayes classifier using 24 training patterns with four nominal input attributes.
The first nominal attribute has three categories and the others have two. The target classifications are contact lens prescriptions: hard, soft, or neither recommended. These data are benchmark data from the Knowledge Discovery Databases archive maintained at the University of California, Irvine: http://archive.ics.uci.edu/ml/datasets/Lenses.
import com.imsl.datamining.*;
import java.io.*;
/**
 * Trains a Naive Bayes classifier on the 24-pattern contact lens data set
 * (four nominal attributes, three target classes) and prints the training
 * error rates followed by the class probabilities, predicted class, actual
 * class and classification error for every training pattern.
 */
public class NaiveBayesClassifierEx2 {

    /** Number of target classes: hard, soft, neither. */
    private static final int N_CLASSES = 3;

    /**
     * Runs the example.
     *
     * @param args ignored
     * @throws Exception if the classifier rejects the data or fails to train
     */
    public static void main(String[] args) throws Exception {
        int[][] contactLensData = {
            {1, 1, 1, 1}, {1, 1, 1, 2}, {1, 1, 2, 1}, {1, 1, 2, 2},
            {1, 2, 1, 1}, {1, 2, 1, 2}, {1, 2, 2, 1}, {1, 2, 2, 2},
            {2, 1, 1, 1}, {2, 1, 1, 2}, {2, 1, 2, 1}, {2, 1, 2, 2},
            {2, 2, 1, 1}, {2, 2, 1, 2}, {2, 2, 2, 1}, {2, 2, 2, 2},
            {3, 1, 1, 1}, {3, 1, 1, 2}, {3, 1, 2, 1}, {3, 1, 2, 2},
            {3, 2, 1, 1}, {3, 2, 1, 2}, {3, 2, 2, 1}, {3, 2, 2, 2}
        };

        int[] classificationData = {
            3, 2, 3, 1,
            3, 2, 3, 1,
            3, 2, 3, 1,
            3, 2, 3, 3,
            3, 3, 3, 1,
            3, 2, 3, 3
        };

        /* The raw data are 1-based; classification and category values
         * must start at 0, so shift everything down by one. */
        for (int i = 0; i < classificationData.length; i++) {
            classificationData[i] -= 1;
            for (int j = 0; j < contactLensData[0].length; j++) {
                contactLensData[i][j] -= 1;
            }
        }

        /* Number of categories of each nominal attribute, in column order. */
        int[] categories = {3, 2, 2, 2};
        int nNominal = categories.length;

        /* No continuous attributes, nNominal nominal attributes, N_CLASSES classes. */
        NaiveBayesClassifier nbTrainer =
                new NaiveBayesClassifier(0, nNominal, N_CLASSES);
        for (int i = 0; i < nNominal; i++) {
            nbTrainer.createNominalAttribute(categories[i]);
        }

        nbTrainer.train(contactLensData, classificationData);

        /* Row k holds {errors, pattern count} for class k; the row after
         * the last class holds the overall totals. */
        int[][] classErrors = nbTrainer.getTrainingErrors();
        System.out.println("\n Contact Lens Error Rates");
        System.out.println("------------------------------------------------");
        System.out.println(" Hard Soft Neither | Total");
        System.out.println(" " + classErrors[0][0] + "/" + classErrors[0][1]
                + " " + classErrors[1][0] + "/" + classErrors[1][1]
                + " " + classErrors[2][0] + "/" + classErrors[2][1]
                + " | " + classErrors[N_CLASSES][0] + "/" + classErrors[N_CLASSES][1]);
        System.out.println(
                "------------------------------------------------\n\n\n");

        /* Classify all patterns with the trained classifier. Note that the
         * table lists every pattern, not only the misclassified ones. */
        System.out.println("Probabilities for Incorrect Classifications");
        System.out.println(" Predicted ");
        System.out.println(" Class | Class | "
                + "P(0) P(1) P(2) | classification error");
        System.out.println("---------------------------------------"
                + "-----------------------------------------");

        /* Scratch buffer so the classifier is handed a copy of each pattern
         * rather than a row of the training array itself. */
        int[] nominalInput = new int[contactLensData[0].length];
        for (int i = 0; i < contactLensData.length; i++) {
            System.arraycopy(contactLensData[i], 0,
                    nominalInput, 0, nominalInput.length);

            /* null: there are no continuous inputs for this model. */
            double[] classifiedProbabilities =
                    nbTrainer.probabilities(null, nominalInput);
            int classification = nbTrainer.predictClass(null, nominalInput);
            double error =
                    nbTrainer.classError(null, nominalInput, classificationData[i]);

            System.out.print(classLabel(classification));
            System.out.print(classLabel(classificationData[i]));
            for (int j = 0; j < N_CLASSES; j++) {
                System.out.printf(" %2.3f ", classifiedProbabilities[j]);
            }
            System.out.println(" | " + error);
        }
    }

    /**
     * Returns the table cell text for a zero-based class index.
     *
     * @param classIndex 0 for hard, 1 for soft, 2 for neither
     * @return the label followed by the column separator; any other index
     *         yields the "Missing" cell
     */
    private static String classLabel(int classIndex) {
        switch (classIndex) {
            case 0:
                return " Hard |";
            case 1:
                return " Soft |";
            case 2:
                return " Neither |";
            default:
                return " Missing |";
        }
    }
}
Contact Lens Error Rates
------------------------------------------------
Hard Soft Neither | Total
0/4 0/5 1/15 | 1/24
------------------------------------------------
Probabilities for Incorrect Classifications
Predicted
Class | Class | P(0) P(1) P(2) | classification error
--------------------------------------------------------------------------------
Neither | Neither | 0.044 0.130 0.827 | 0.17328273537224337
Soft | Soft | 0.174 0.622 0.203 | 0.3777037515962849
Neither | Neither | 0.186 0.018 0.795 | 0.20481484453118481
Hard | Hard | 0.724 0.086 0.190 | 0.2762213881704395
Neither | Neither | 0.019 0.154 0.827 | 0.1731192809428289
Soft | Soft | 0.076 0.724 0.200 | 0.27580075885359856
Neither | Neither | 0.092 0.024 0.884 | 0.11636647828065916
Hard | Hard | 0.524 0.166 0.310 | 0.4759671271915248
Neither | Neither | 0.025 0.113 0.862 | 0.1379488917335967
Soft | Soft | 0.118 0.633 0.248 | 0.36667756449941913
Neither | Neither | 0.113 0.017 0.870 | 0.1300941614979002
Hard | Hard | 0.606 0.108 0.286 | 0.3943953202159177
Neither | Neither | 0.011 0.133 0.856 | 0.1438071275344872
Soft | Soft | 0.050 0.714 0.236 | 0.28621890950268947
Neither | Neither | 0.054 0.021 0.925 | 0.07477065629881141
Neither | Neither | 0.394 0.187 0.419 | 0.5812071082121901
Neither | Neither | 0.023 0.068 0.909 | 0.09075297709414065
Soft | Neither | 0.142 0.509 0.349 | 0.6509258682358112
Neither | Neither | 0.099 0.010 0.891 | 0.1092518501445181
Hard | Hard | 0.599 0.071 0.330 | 0.40138301592296255
Neither | Neither | 0.010 0.081 0.909 | 0.09065883447725098
Soft | Soft | 0.062 0.594 0.344 | 0.40625618912343875
Neither | Neither | 0.047 0.012 0.941 | 0.059009463869502565
Neither | Neither | 0.391 0.124 0.485 | 0.514955474332508
Link to Java source.