package com.imsl.test.example.datamining;
import com.imsl.datamining.*;
import com.imsl.stat.ProbabilityDistribution;
/**
*
 * Trains a classifier with a user supplied probability
* function.
*
*
* This example is similar to {@link NaiveBayesClassifierEx1}, where we train a
* classifier on Fisher's Iris data using 140 of the 150 continuous patterns,
* and then classify ten remaining plants using their sepal and petal
* measurements.
*
* Instead of using the {@code NormalDistribution} class, a user supplied
* normal (Gaussian) distribution is provided directly. Rather than calculating
* the means and standard deviations from the data, as is done by the
* {@code NormalDistribution}'s {@code eval(double[])} method, the
* user supplied class requires the means and standard deviations in the class
* constructor. The output is the same as in {@link NaiveBayesClassifierEx1},
* since the means and standard deviations in this example are simply rounded
* means and standard deviations of the actual data subset by target
* classifications.
*
* @see Example
* @see Output
*
*/
public class NaiveBayesClassifierEx3 {

    /**
     * Fisher's Iris data. Each row is {classification, sepal length,
     * sepal width, petal length, petal width} where the classification is
     * 1 = Setosa, 2 = Versicolour, 3 = Virginica.
     */
    private static double[][] irisData = {
        {1.0, 5.1, 3.5, 1.4, .2}, {1.0, 4.9, 3.0, 1.4, .2},
        {1.0, 4.7, 3.2, 1.3, .2}, {1.0, 4.6, 3.1, 1.5, .2},
        {1.0, 5.0, 3.6, 1.4, .2}, {1.0, 5.4, 3.9, 1.7, .4},
        {1.0, 4.6, 3.4, 1.4, .3}, {1.0, 5.0, 3.4, 1.5, .2},
        {1.0, 4.4, 2.9, 1.4, .2}, {1.0, 4.9, 3.1, 1.5, .1},
        {1.0, 5.4, 3.7, 1.5, .2}, {1.0, 4.8, 3.4, 1.6, .2},
        {1.0, 4.8, 3.0, 1.4, .1}, {1.0, 4.3, 3.0, 1.1, .1},
        {1.0, 5.8, 4.0, 1.2, .2}, {1.0, 5.7, 4.4, 1.5, .4},
        {1.0, 5.4, 3.9, 1.3, .4}, {1.0, 5.1, 3.5, 1.4, .3},
        {1.0, 5.7, 3.8, 1.7, .3}, {1.0, 5.1, 3.8, 1.5, .3},
        {1.0, 5.4, 3.4, 1.7, .2}, {1.0, 5.1, 3.7, 1.5, .4},
        {1.0, 4.6, 3.6, 1.0, .2}, {1.0, 5.1, 3.3, 1.7, .5},
        {1.0, 4.8, 3.4, 1.9, .2}, {1.0, 5.0, 3.0, 1.6, .2},
        {1.0, 5.0, 3.4, 1.6, .4}, {1.0, 5.2, 3.5, 1.5, .2},
        {1.0, 5.2, 3.4, 1.4, .2}, {1.0, 4.7, 3.2, 1.6, .2},
        {1.0, 4.8, 3.1, 1.6, .2}, {1.0, 5.4, 3.4, 1.5, .4},
        {1.0, 5.2, 4.1, 1.5, .1}, {1.0, 5.5, 4.2, 1.4, .2},
        {1.0, 4.9, 3.1, 1.5, .2}, {1.0, 5.0, 3.2, 1.2, .2},
        {1.0, 5.5, 3.5, 1.3, .2}, {1.0, 4.9, 3.6, 1.4, .1},
        {1.0, 4.4, 3.0, 1.3, .2}, {1.0, 5.1, 3.4, 1.5, .2},
        {1.0, 5.0, 3.5, 1.3, .3}, {1.0, 4.5, 2.3, 1.3, .3},
        {1.0, 4.4, 3.2, 1.3, .2}, {1.0, 5.0, 3.5, 1.6, .6},
        {1.0, 5.1, 3.8, 1.9, .4}, {1.0, 4.8, 3.0, 1.4, .3},
        {1.0, 5.1, 3.8, 1.6, .2}, {1.0, 4.6, 3.2, 1.4, .2},
        {1.0, 5.3, 3.7, 1.5, .2}, {1.0, 5.0, 3.3, 1.4, .2},
        {2.0, 7.0, 3.2, 4.7, 1.4}, {2.0, 6.4, 3.2, 4.5, 1.5},
        {2.0, 6.9, 3.1, 4.9, 1.5}, {2.0, 5.5, 2.3, 4.0, 1.3},
        {2.0, 6.5, 2.8, 4.6, 1.5}, {2.0, 5.7, 2.8, 4.5, 1.3},
        {2.0, 6.3, 3.3, 4.7, 1.6}, {2.0, 4.9, 2.4, 3.3, 1.0},
        {2.0, 6.6, 2.9, 4.6, 1.3}, {2.0, 5.2, 2.7, 3.9, 1.4},
        {2.0, 5.0, 2.0, 3.5, 1.0}, {2.0, 5.9, 3.0, 4.2, 1.5},
        {2.0, 6.0, 2.2, 4.0, 1.0}, {2.0, 6.1, 2.9, 4.7, 1.4},
        {2.0, 5.6, 2.9, 3.6, 1.3}, {2.0, 6.7, 3.1, 4.4, 1.4},
        {2.0, 5.6, 3.0, 4.5, 1.5}, {2.0, 5.8, 2.7, 4.1, 1.0},
        {2.0, 6.2, 2.2, 4.5, 1.5}, {2.0, 5.6, 2.5, 3.9, 1.1},
        {2.0, 5.9, 3.2, 4.8, 1.8}, {2.0, 6.1, 2.8, 4.0, 1.3},
        {2.0, 6.3, 2.5, 4.9, 1.5}, {2.0, 6.1, 2.8, 4.7, 1.2},
        {2.0, 6.4, 2.9, 4.3, 1.3}, {2.0, 6.6, 3.0, 4.4, 1.4},
        {2.0, 6.8, 2.8, 4.8, 1.4}, {2.0, 6.7, 3.0, 5.0, 1.7},
        {2.0, 6.0, 2.9, 4.5, 1.5}, {2.0, 5.7, 2.6, 3.5, 1.0},
        {2.0, 5.5, 2.4, 3.8, 1.1}, {2.0, 5.5, 2.4, 3.7, 1.0},
        {2.0, 5.8, 2.7, 3.9, 1.2}, {2.0, 6.0, 2.7, 5.1, 1.6},
        {2.0, 5.4, 3.0, 4.5, 1.5}, {2.0, 6.0, 3.4, 4.5, 1.6},
        {2.0, 6.7, 3.1, 4.7, 1.5}, {2.0, 6.3, 2.3, 4.4, 1.3},
        {2.0, 5.6, 3.0, 4.1, 1.3}, {2.0, 5.5, 2.5, 4.0, 1.3},
        {2.0, 5.5, 2.6, 4.4, 1.2}, {2.0, 6.1, 3.0, 4.6, 1.4},
        {2.0, 5.8, 2.6, 4.0, 1.2}, {2.0, 5.0, 2.3, 3.3, 1.0},
        {2.0, 5.6, 2.7, 4.2, 1.3}, {2.0, 5.7, 3.0, 4.2, 1.2},
        {2.0, 5.7, 2.9, 4.2, 1.3}, {2.0, 6.2, 2.9, 4.3, 1.3},
        {2.0, 5.1, 2.5, 3.0, 1.1}, {2.0, 5.7, 2.8, 4.1, 1.3},
        {3.0, 6.3, 3.3, 6.0, 2.5}, {3.0, 5.8, 2.7, 5.1, 1.9},
        {3.0, 7.1, 3.0, 5.9, 2.1}, {3.0, 6.3, 2.9, 5.6, 1.8},
        {3.0, 6.5, 3.0, 5.8, 2.2}, {3.0, 7.6, 3.0, 6.6, 2.1},
        {3.0, 4.9, 2.5, 4.5, 1.7}, {3.0, 7.3, 2.9, 6.3, 1.8},
        {3.0, 6.7, 2.5, 5.8, 1.8}, {3.0, 7.2, 3.6, 6.1, 2.5},
        {3.0, 6.5, 3.2, 5.1, 2.0}, {3.0, 6.4, 2.7, 5.3, 1.9},
        {3.0, 6.8, 3.0, 5.5, 2.1}, {3.0, 5.7, 2.5, 5.0, 2.0},
        {3.0, 5.8, 2.8, 5.1, 2.4}, {3.0, 6.4, 3.2, 5.3, 2.3},
        {3.0, 6.5, 3.0, 5.5, 1.8}, {3.0, 7.7, 3.8, 6.7, 2.2},
        {3.0, 7.7, 2.6, 6.9, 2.3}, {3.0, 6.0, 2.2, 5.0, 1.5},
        {3.0, 6.9, 3.2, 5.7, 2.3}, {3.0, 5.6, 2.8, 4.9, 2.0},
        {3.0, 7.7, 2.8, 6.7, 2.0}, {3.0, 6.3, 2.7, 4.9, 1.8},
        {3.0, 6.7, 3.3, 5.7, 2.1}, {3.0, 7.2, 3.2, 6.0, 1.8},
        {3.0, 6.2, 2.8, 4.8, 1.8}, {3.0, 6.1, 3.0, 4.9, 1.8},
        {3.0, 6.4, 2.8, 5.6, 2.1}, {3.0, 7.2, 3.0, 5.8, 1.6},
        {3.0, 7.4, 2.8, 6.1, 1.9}, {3.0, 7.9, 3.8, 6.4, 2.0},
        {3.0, 6.4, 2.8, 5.6, 2.2}, {3.0, 6.3, 2.8, 5.1, 1.5},
        {3.0, 6.1, 2.6, 5.6, 1.4}, {3.0, 7.7, 3.0, 6.1, 2.3},
        {3.0, 6.3, 3.4, 5.6, 2.4}, {3.0, 6.4, 3.1, 5.5, 1.8},
        {3.0, 6.0, 3.0, 4.8, 1.8}, {3.0, 6.9, 3.1, 5.4, 2.1},
        {3.0, 6.7, 3.1, 5.6, 2.4}, {3.0, 6.9, 3.1, 5.1, 2.3},
        {3.0, 5.8, 2.7, 5.1, 1.9}, {3.0, 6.8, 3.2, 5.9, 2.3},
        {3.0, 6.7, 3.3, 5.7, 2.5}, {3.0, 6.7, 3.0, 5.2, 2.3},
        {3.0, 6.3, 2.5, 5.0, 1.9}, {3.0, 6.5, 3.0, 5.2, 2.0},
        {3.0, 6.2, 3.4, 5.4, 2.3}, {3.0, 5.9, 3.0, 5.1, 1.8}
    };

    /**
     * The main method for the example. Trains on the first 140 patterns,
     * prints the training error table, then classifies the last ten
     * patterns and prints the predicted class probabilities.
     *
     * @param args unused
     * @throws Exception if training or classification fails
     */
    public static void main(String[] args) throws Exception {
        /* Data corrections described in the KDD data mining archive */
        irisData[34][4] = 0.1;
        irisData[37][2] = 3.1;
        irisData[37][3] = 1.5;

        /* Train first 140 patterns of the iris Fisher Data */
        int[] irisClassificationData = new int[irisData.length - 10];
        double[][] irisContinuousData
                = new double[irisData.length - 10][irisData[0].length - 1];
        for (int i = 0; i < irisData.length - 10; i++) {
            // Column 0 is the 1-based class label; make it a 0-based target.
            irisClassificationData[i] = (int) irisData[i][0] - 1;
            // Columns 1..4 are the continuous attributes.
            System.arraycopy(irisData[i], 1,
                    irisContinuousData[i], 0, irisData[0].length - 1);
        }

        int nNominal = 0;    // no nominal input attributes
        int nContinuous = 4; // four continuous input attributes
        int nClasses = 3;    // three classification categories

        NaiveBayesClassifier nbTrainer
                = new NaiveBayesClassifier(nContinuous, nNominal, nClasses);

        // Rounded per-class means and standard deviations of each attribute;
        // row = attribute index, column = target class.
        double[][] means = {
            {5.06, 5.94, 6.58}, {3.42, 2.8, 2.97},
            {1.5, 4.3, 5.6}, {0.25, 1.33, 2.1}
        };
        double[][] stdev = {
            {0.35, 0.52, 0.64}, {0.38, 0.3, 0.32},
            {0.17, 0.47, 0.55}, {0.12, 0.198, 0.275}
        };

        // Register one user supplied Gaussian per (attribute, class) pair.
        for (int i = 0; i < nContinuous; i++) {
            ProbabilityDistribution[] pdf
                    = new ProbabilityDistribution[nClasses];
            for (int j = 0; j < nClasses; j++) {
                pdf[j] = new TestGaussFcn1(means[i][j], stdev[i][j]);
            }
            nbTrainer.createContinuousAttribute(pdf);
        }
        nbTrainer.train(irisContinuousData, irisClassificationData);

        int[][] classErrors = nbTrainer.getTrainingErrors();
        System.out.println(" Iris Classification Error Rates");
        System.out.println("------------------------------------------------");
        System.out.println(" Setosa Versicolour Virginica | Total");
        System.out.println(" " + classErrors[0][0] + "/" + classErrors[0][1]
                + " " + classErrors[1][0] + "/" + classErrors[1][1]
                + " " + classErrors[2][0] + "/" + classErrors[2][1]
                + " | " + classErrors[3][0] + "/" + classErrors[3][1]);
        System.out.println(
                "------------------------------------------------\n\n\n");

        /* Classify last 10 iris data patterns with the trained classifier */
        double[] continuousInput = new double[(irisData[0].length - 1)];
        System.out.println("Probabilities for Incorrect Classifications");
        System.out.println(" Predicted ");
        System.out.println(
                " Class | Class | P(0) P(1) P(2) ");
        System.out.println(
                "-------------------------------------------------------");
        for (int i = 0; i < 10; i++) {
            int targetClassification
                    = (int) irisData[(irisData.length - 10) + i][0] - 1;
            System.arraycopy(irisData[(irisData.length - 10) + i], 1,
                    continuousInput, 0, (irisData[0].length - 1));
            // Probabilities array is freshly allocated by the classifier;
            // no need to preallocate one.
            double[] classifiedProbabilities
                    = nbTrainer.probabilities(continuousInput, null);
            int classification = nbTrainer.predictClass(continuousInput, null);
            System.out.print(irisName(classification) + " |");
            System.out.print(" " + irisName(targetClassification) + " |");
            for (int j = 0; j < nClasses; j++) {
                System.out.printf(" %2.3f ", classifiedProbabilities[j]);
            }
            System.out.println("");
        }
    }

    /**
     * Returns the iris species name for a 0-based class index, or
     * {@code "Missing"} for any unrecognized index.
     */
    private static String irisName(int classIndex) {
        switch (classIndex) {
            case 0:
                return "Setosa";
            case 1:
                return "Versicolour";
            case 2:
                return "Virginica";
            default:
                return "Missing";
        }
    }

    /**
     * Defines the user supplied probability distribution.
     * In this case, the probability distribution is Gaussian and results
     * should match the default case in Example 1.
     */
    public static class TestGaussFcn1 implements ProbabilityDistribution {

        /** sqrt(2*pi), used to normalize the Gaussian density. */
        private static final double SQRT_PI2 = 2.506628274631;

        private final double mean;  // distribution mean
        private final double stdev; // distribution standard deviation

        /**
         * Constructs a Gaussian distribution with fixed parameters.
         *
         * @param mean  the mean of the distribution
         * @param stdev the standard deviation of the distribution
         */
        public TestGaussFcn1(double mean, double stdev) {
            this.mean = mean;
            this.stdev = stdev;
        }

        /**
         * Evaluates the density at each point of {@code xData}.
         */
        @Override
        public double[] eval(double[] xData) {
            return eval(xData, null);
        }

        /**
         * Evaluates the density at each point of {@code xData}; the extra
         * parameters are ignored since mean/stdev are fixed at construction.
         */
        @Override
        public double[] eval(double[] xData, Object[] params) {
            double[] pdf = new double[xData.length];
            for (int i = 0; i < xData.length; i++) {
                pdf[i] = eval(xData[i], params);
            }
            return pdf;
        }

        /**
         * Evaluates the Gaussian density at a single point.
         */
        @Override
        public double eval(double xData, Object[] params) {
            return gaussianPdf(xData, mean, stdev);
        }

        /**
         * Returns {mean, stdev} as boxed values.
         */
        @Override
        public Object[] getParameters() {
            // Keep the Double[] runtime type of the original implementation.
            Double[] parms = new Double[2];
            parms[0] = this.mean;
            parms[1] = this.stdev;
            return parms;
        }

        /**
         * Gaussian probability density at {@code x} with the given mean and
         * standard deviation; returns NaN for any NaN input.
         */
        private static double gaussianPdf(double x, double mean, double stdev) {
            if (Double.isNaN(x) || Double.isNaN(mean) || Double.isNaN(stdev)) {
                return Double.NaN;
            }
            double z = x - mean;
            double e = -0.5 * (z * z) / (stdev * stdev);
            return Math.exp(e) / (SQRT_PI2 * stdev);
        }
    }
}