package com.imsl.test.example.datamining;
import com.imsl.datamining.*;
import com.imsl.stat.NormalDistribution;
/**
 * Trains a classifier on Fisher's Iris data.
 * <p>
 * This example trains a {@code NaiveBayesClassifier} on Fisher's Iris data.
 * The training data contains 140 of the 150 continuous patterns, and the
 * remaining ten are then used as a test set.
 * <p>
 * Fisher's (1936) Iris data is often used for benchmarking classification
 * algorithms. It consists of the following continuous input attributes and a
 * classification target:
 * <ul>
 * <li>Continuous input attributes:
 *   <ul>
 *   <li>Sepal Length</li>
 *   <li>Sepal Width</li>
 *   <li>Petal Length</li>
 *   <li>Petal Width</li>
 *   </ul></li>
 * <li>Classification target (Iris type):
 *   <ul>
 *   <li>Setosa</li>
 *   <li>Versicolour</li>
 *   <li>Virginica</li>
 *   </ul></li>
 * </ul>
 *
 * @see Code
 * @see Output
 *
 */
public class NaiveBayesClassifierEx1 {

    /** Number of trailing patterns held out of training and used as the test set. */
    private static final int N_TEST = 10;

    /**
     * Fisher's Iris data, one pattern per row. Column 0 is the 1-based class
     * label (1 = Setosa, 2 = Versicolour, 3 = Virginica); columns 1-4 are the
     * continuous attributes sepal length, sepal width, petal length and petal
     * width.
     */
    private static final double[][] irisData = {
        {1.0, 5.1, 3.5, 1.4, .2}, {1.0, 4.9, 3.0, 1.4, .2},
        {1.0, 4.7, 3.2, 1.3, .2}, {1.0, 4.6, 3.1, 1.5, .2},
        {1.0, 5.0, 3.6, 1.4, .2}, {1.0, 5.4, 3.9, 1.7, .4},
        {1.0, 4.6, 3.4, 1.4, .3}, {1.0, 5.0, 3.4, 1.5, .2},
        {1.0, 4.4, 2.9, 1.4, .2}, {1.0, 4.9, 3.1, 1.5, .1},
        {1.0, 5.4, 3.7, 1.5, .2}, {1.0, 4.8, 3.4, 1.6, .2},
        {1.0, 4.8, 3.0, 1.4, .1}, {1.0, 4.3, 3.0, 1.1, .1},
        {1.0, 5.8, 4.0, 1.2, .2}, {1.0, 5.7, 4.4, 1.5, .4},
        {1.0, 5.4, 3.9, 1.3, .4}, {1.0, 5.1, 3.5, 1.4, .3},
        {1.0, 5.7, 3.8, 1.7, .3}, {1.0, 5.1, 3.8, 1.5, .3},
        {1.0, 5.4, 3.4, 1.7, .2}, {1.0, 5.1, 3.7, 1.5, .4},
        {1.0, 4.6, 3.6, 1.0, .2}, {1.0, 5.1, 3.3, 1.7, .5},
        {1.0, 4.8, 3.4, 1.9, .2}, {1.0, 5.0, 3.0, 1.6, .2},
        {1.0, 5.0, 3.4, 1.6, .4}, {1.0, 5.2, 3.5, 1.5, .2},
        {1.0, 5.2, 3.4, 1.4, .2}, {1.0, 4.7, 3.2, 1.6, .2},
        {1.0, 4.8, 3.1, 1.6, .2}, {1.0, 5.4, 3.4, 1.5, .4},
        {1.0, 5.2, 4.1, 1.5, .1}, {1.0, 5.5, 4.2, 1.4, .2},
        {1.0, 4.9, 3.1, 1.5, .2}, {1.0, 5.0, 3.2, 1.2, .2},
        {1.0, 5.5, 3.5, 1.3, .2}, {1.0, 4.9, 3.6, 1.4, .1},
        {1.0, 4.4, 3.0, 1.3, .2}, {1.0, 5.1, 3.4, 1.5, .2},
        {1.0, 5.0, 3.5, 1.3, .3}, {1.0, 4.5, 2.3, 1.3, .3},
        {1.0, 4.4, 3.2, 1.3, .2}, {1.0, 5.0, 3.5, 1.6, .6},
        {1.0, 5.1, 3.8, 1.9, .4}, {1.0, 4.8, 3.0, 1.4, .3},
        {1.0, 5.1, 3.8, 1.6, .2}, {1.0, 4.6, 3.2, 1.4, .2},
        {1.0, 5.3, 3.7, 1.5, .2}, {1.0, 5.0, 3.3, 1.4, .2},
        {2.0, 7.0, 3.2, 4.7, 1.4}, {2.0, 6.4, 3.2, 4.5, 1.5},
        {2.0, 6.9, 3.1, 4.9, 1.5}, {2.0, 5.5, 2.3, 4.0, 1.3},
        {2.0, 6.5, 2.8, 4.6, 1.5}, {2.0, 5.7, 2.8, 4.5, 1.3},
        {2.0, 6.3, 3.3, 4.7, 1.6}, {2.0, 4.9, 2.4, 3.3, 1.0},
        {2.0, 6.6, 2.9, 4.6, 1.3}, {2.0, 5.2, 2.7, 3.9, 1.4},
        {2.0, 5.0, 2.0, 3.5, 1.0}, {2.0, 5.9, 3.0, 4.2, 1.5},
        {2.0, 6.0, 2.2, 4.0, 1.0}, {2.0, 6.1, 2.9, 4.7, 1.4},
        {2.0, 5.6, 2.9, 3.6, 1.3}, {2.0, 6.7, 3.1, 4.4, 1.4},
        {2.0, 5.6, 3.0, 4.5, 1.5}, {2.0, 5.8, 2.7, 4.1, 1.0},
        {2.0, 6.2, 2.2, 4.5, 1.5}, {2.0, 5.6, 2.5, 3.9, 1.1},
        {2.0, 5.9, 3.2, 4.8, 1.8}, {2.0, 6.1, 2.8, 4.0, 1.3},
        {2.0, 6.3, 2.5, 4.9, 1.5}, {2.0, 6.1, 2.8, 4.7, 1.2},
        {2.0, 6.4, 2.9, 4.3, 1.3}, {2.0, 6.6, 3.0, 4.4, 1.4},
        {2.0, 6.8, 2.8, 4.8, 1.4}, {2.0, 6.7, 3.0, 5.0, 1.7},
        {2.0, 6.0, 2.9, 4.5, 1.5}, {2.0, 5.7, 2.6, 3.5, 1.0},
        {2.0, 5.5, 2.4, 3.8, 1.1}, {2.0, 5.5, 2.4, 3.7, 1.0},
        {2.0, 5.8, 2.7, 3.9, 1.2}, {2.0, 6.0, 2.7, 5.1, 1.6},
        {2.0, 5.4, 3.0, 4.5, 1.5}, {2.0, 6.0, 3.4, 4.5, 1.6},
        {2.0, 6.7, 3.1, 4.7, 1.5}, {2.0, 6.3, 2.3, 4.4, 1.3},
        {2.0, 5.6, 3.0, 4.1, 1.3}, {2.0, 5.5, 2.5, 4.0, 1.3},
        {2.0, 5.5, 2.6, 4.4, 1.2}, {2.0, 6.1, 3.0, 4.6, 1.4},
        {2.0, 5.8, 2.6, 4.0, 1.2}, {2.0, 5.0, 2.3, 3.3, 1.0},
        {2.0, 5.6, 2.7, 4.2, 1.3}, {2.0, 5.7, 3.0, 4.2, 1.2},
        {2.0, 5.7, 2.9, 4.2, 1.3}, {2.0, 6.2, 2.9, 4.3, 1.3},
        {2.0, 5.1, 2.5, 3.0, 1.1}, {2.0, 5.7, 2.8, 4.1, 1.3},
        {3.0, 6.3, 3.3, 6.0, 2.5}, {3.0, 5.8, 2.7, 5.1, 1.9},
        {3.0, 7.1, 3.0, 5.9, 2.1}, {3.0, 6.3, 2.9, 5.6, 1.8},
        {3.0, 6.5, 3.0, 5.8, 2.2}, {3.0, 7.6, 3.0, 6.6, 2.1},
        {3.0, 4.9, 2.5, 4.5, 1.7}, {3.0, 7.3, 2.9, 6.3, 1.8},
        {3.0, 6.7, 2.5, 5.8, 1.8}, {3.0, 7.2, 3.6, 6.1, 2.5},
        {3.0, 6.5, 3.2, 5.1, 2.0}, {3.0, 6.4, 2.7, 5.3, 1.9},
        {3.0, 6.8, 3.0, 5.5, 2.1}, {3.0, 5.7, 2.5, 5.0, 2.0},
        {3.0, 5.8, 2.8, 5.1, 2.4}, {3.0, 6.4, 3.2, 5.3, 2.3},
        {3.0, 6.5, 3.0, 5.5, 1.8}, {3.0, 7.7, 3.8, 6.7, 2.2},
        {3.0, 7.7, 2.6, 6.9, 2.3}, {3.0, 6.0, 2.2, 5.0, 1.5},
        {3.0, 6.9, 3.2, 5.7, 2.3}, {3.0, 5.6, 2.8, 4.9, 2.0},
        {3.0, 7.7, 2.8, 6.7, 2.0}, {3.0, 6.3, 2.7, 4.9, 1.8},
        {3.0, 6.7, 3.3, 5.7, 2.1}, {3.0, 7.2, 3.2, 6.0, 1.8},
        {3.0, 6.2, 2.8, 4.8, 1.8}, {3.0, 6.1, 3.0, 4.9, 1.8},
        {3.0, 6.4, 2.8, 5.6, 2.1}, {3.0, 7.2, 3.0, 5.8, 1.6},
        {3.0, 7.4, 2.8, 6.1, 1.9}, {3.0, 7.9, 3.8, 6.4, 2.0},
        {3.0, 6.4, 2.8, 5.6, 2.2}, {3.0, 6.3, 2.8, 5.1, 1.5},
        {3.0, 6.1, 2.6, 5.6, 1.4}, {3.0, 7.7, 3.0, 6.1, 2.3},
        {3.0, 6.3, 3.4, 5.6, 2.4}, {3.0, 6.4, 3.1, 5.5, 1.8},
        {3.0, 6.0, 3.0, 4.8, 1.8}, {3.0, 6.9, 3.1, 5.4, 2.1},
        {3.0, 6.7, 3.1, 5.6, 2.4}, {3.0, 6.9, 3.1, 5.1, 2.3},
        {3.0, 5.8, 2.7, 5.1, 1.9}, {3.0, 6.8, 3.2, 5.9, 2.3},
        {3.0, 6.7, 3.3, 5.7, 2.5}, {3.0, 6.7, 3.0, 5.2, 2.3},
        {3.0, 6.3, 2.5, 5.0, 1.9}, {3.0, 6.5, 3.0, 5.2, 2.0},
        {3.0, 6.2, 3.4, 5.4, 2.3}, {3.0, 5.9, 3.0, 5.1, 1.8}
    };

    /**
     * Trains a naive Bayes classifier on all but the last {@value #N_TEST}
     * Iris patterns and prints the per-class training error rates. It then
     * prints the predicted class, target class and class probabilities for
     * each held-out test pattern.
     *
     * @param args unused
     * @throws Exception propagated from the IMSL classifier calls
     */
    public static void main(String[] args) throws Exception {
        /*
         * Data corrections described in the KDD data mining archive.
         * NOTE(review): the UCI iris.names notes say sample 35 should be
         * 4.9,3.1,1.5,0.2 and sample 38 should be 4.9,3.6,1.4,0.1. The table
         * above already holds those values, so these assignments appear to
         * reintroduce the UCI file's transcription errors rather than fix
         * them. They are left unchanged so this example's output is
         * preserved; confirm the intent.
         */
        irisData[34][4] = 0.1;
        irisData[37][2] = 3.1;
        irisData[37][3] = 1.5;

        int nAttributes = irisData[0].length - 1; // column 0 is the class label
        int nTrain = irisData.length - N_TEST;

        /* Train on the first nTrain (140) patterns of the Fisher Iris data */
        int[] irisClassificationData = new int[nTrain];
        double[][] irisContinuousData = new double[nTrain][nAttributes];
        for (int i = 0; i < nTrain; i++) {
            // Labels are stored 1-based in the table; the classifier is fed 0-based classes.
            irisClassificationData[i] = (int) irisData[i][0] - 1;
            System.arraycopy(irisData[i], 1, irisContinuousData[i], 0, nAttributes);
        }

        int nNominal = 0;              // no nominal input attributes
        int nContinuous = nAttributes; // four continuous input attributes
        int nClasses = 3;              // three classification categories
        NaiveBayesClassifier nbTrainer
                = new NaiveBayesClassifier(nContinuous, nNominal, nClasses);
        // Model every continuous attribute with a normal distribution.
        for (int i = 0; i < nContinuous; i++) {
            nbTrainer.createContinuousAttribute(new NormalDistribution());
        }
        nbTrainer.train(irisContinuousData, irisClassificationData);

        // Printed as "errors/count" for each class, with the overall totals in row 3.
        int[][] classErrors = nbTrainer.getTrainingErrors();
        System.out.println(" Iris Classification Training Error Rates");
        System.out.println("------------------------------------------------");
        System.out.println(" Setosa Versicolour Virginica | Total");
        System.out.println(" " + classErrors[0][0] + "/" + classErrors[0][1]
                + " " + classErrors[1][0] + "/" + classErrors[1][1]
                + " " + classErrors[2][0] + "/" + classErrors[2][1]
                + " | " + classErrors[3][0]
                + "/" + classErrors[3][1]);
        System.out.println(
                "------------------------------------------------\n\n\n");

        /* Classify the held-out test patterns with the trained classifier */
        double[] continuousInput = new double[nAttributes];
        // Every test pattern is printed, not only the misclassified ones.
        System.out.println("Probabilities for Test Set Classifications");
        System.out.println(" Predicted ");
        System.out.println(
                " Class | Class | P(0) P(1) P(2) ");
        System.out.println(
                "-------------------------------------------------------");
        for (int row = nTrain; row < irisData.length; row++) {
            int targetClassification = (int) irisData[row][0] - 1;
            System.arraycopy(irisData[row], 1, continuousInput, 0, nAttributes);
            double[] classifiedProbabilities
                    = nbTrainer.probabilities(continuousInput, null);
            int classification = nbTrainer.predictClass(continuousInput, null);

            System.out.print(className(classification) + " |");
            System.out.print(" " + className(targetClassification) + " |");
            for (int j = 0; j < nClasses; j++) {
                // Varargs printf autoboxes the double; no deprecated Double constructor needed.
                System.out.printf(" %2.3f ", classifiedProbabilities[j]);
            }
            System.out.println();
        }
    }

    /**
     * Returns the display name for a zero-based Iris class index.
     *
     * @param classIndex 0 = Setosa, 1 = Versicolour, 2 = Virginica
     * @return the class name, or {@code "Missing"} for any other index
     */
    private static String className(int classIndex) {
        switch (classIndex) {
            case 0:
                return "Setosa";
            case 1:
                return "Versicolour";
            case 2:
                return "Virginica";
            default:
                return "Missing";
        }
    }
}