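Fisher's (1936) Iris data is often used for benchmarking classification algorithms. It consists of the following continuous input attributes and a classification target:
- Continuous attributes: X0 (sepal length), X1 (sepal width), X2 (petal length) and X3 (petal width).
- Classification (Iris type): Setosa, Versicolour or Virginica.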
This example trains a Naive Bayes classifier on the first 140 of the 150 patterns. It then uses the trained classifier to classify the remaining ten plants (all Virginica) from their sepal and petal measurements.
using System;
using Imsl.DataMining;
using NormalDistribution = Imsl.Stat.NormalDistribution;
/// <summary>
/// Example: trains an IMSL <c>NaiveBayesClassifier</c> on the first 140
/// patterns of Fisher's iris data, prints the training error counts, and
/// then classifies the last 10 patterns, printing predicted class, true
/// class and class probabilities for each.
/// </summary>
public class NaiveBayesClassifierEx1
{
    /// <summary>
    /// Number of trailing patterns withheld from training and classified
    /// afterwards with the trained model.
    /// </summary>
    private const int HoldOutCount = 10;

    /// <summary>Label printed for a class index with no entry in <see cref="classLabels"/>.</summary>
    private const string MissingLabel = "Missing |";

    /// <summary>
    /// Printed label for each zero-based class index (0 = Setosa,
    /// 1 = Versicolour, 2 = Virginica). The true-class column is printed
    /// with one extra leading space.
    /// </summary>
    private static readonly string[] classLabels =
        new string[] { "Setosa |", "Versicolour |", "Virginica |" };

    /// <summary>
    /// Fisher's iris data, one row per plant. Column 0 holds the one-based
    /// class (1.0 = Setosa, 2.0 = Versicolour, 3.0 = Virginica); columns 1-4
    /// hold sepal length, sepal width, petal length and petal width.
    /// Only the array reference is readonly: <see cref="ApplyDataCorrections"/>
    /// overwrites three of its elements before training.
    /// </summary>
    private static readonly double[][] irisFisherData = new double[][]{
        new double[]{1.0, 5.1, 3.5, 1.4, .2},
        new double[]{1.0, 4.9, 3.0, 1.4, .2},
        new double[]{1.0, 4.7, 3.2, 1.3, .2},
        new double[]{1.0, 4.6, 3.1, 1.5, .2},
        new double[]{1.0, 5.0, 3.6, 1.4, .2},
        new double[]{1.0, 5.4, 3.9, 1.7, .4},
        new double[]{1.0, 4.6, 3.4, 1.4, .3},
        new double[]{1.0, 5.0, 3.4, 1.5, .2},
        new double[]{1.0, 4.4, 2.9, 1.4, .2},
        new double[]{1.0, 4.9, 3.1, 1.5, .1},
        new double[]{1.0, 5.4, 3.7, 1.5, .2},
        new double[]{1.0, 4.8, 3.4, 1.6, .2},
        new double[]{1.0, 4.8, 3.0, 1.4, .1},
        new double[]{1.0, 4.3, 3.0, 1.1, .1},
        new double[]{1.0, 5.8, 4.0, 1.2, .2},
        new double[]{1.0, 5.7, 4.4, 1.5, .4},
        new double[]{1.0, 5.4, 3.9, 1.3, .4},
        new double[]{1.0, 5.1, 3.5, 1.4, .3},
        new double[]{1.0, 5.7, 3.8, 1.7, .3},
        new double[]{1.0, 5.1, 3.8, 1.5, .3},
        new double[]{1.0, 5.4, 3.4, 1.7, .2},
        new double[]{1.0, 5.1, 3.7, 1.5, .4},
        new double[]{1.0, 4.6, 3.6, 1.0, .2},
        new double[]{1.0, 5.1, 3.3, 1.7, .5},
        new double[]{1.0, 4.8, 3.4, 1.9, .2},
        new double[]{1.0, 5.0, 3.0, 1.6, .2},
        new double[]{1.0, 5.0, 3.4, 1.6, .4},
        new double[]{1.0, 5.2, 3.5, 1.5, .2},
        new double[]{1.0, 5.2, 3.4, 1.4, .2},
        new double[]{1.0, 4.7, 3.2, 1.6, .2},
        new double[]{1.0, 4.8, 3.1, 1.6, .2},
        new double[]{1.0, 5.4, 3.4, 1.5, .4},
        new double[]{1.0, 5.2, 4.1, 1.5, .1},
        new double[]{1.0, 5.5, 4.2, 1.4, .2},
        new double[]{1.0, 4.9, 3.1, 1.5, .2},
        new double[]{1.0, 5.0, 3.2, 1.2, .2},
        new double[]{1.0, 5.5, 3.5, 1.3, .2},
        new double[]{1.0, 4.9, 3.6, 1.4, .1},
        new double[]{1.0, 4.4, 3.0, 1.3, .2},
        new double[]{1.0, 5.1, 3.4, 1.5, .2},
        new double[]{1.0, 5.0, 3.5, 1.3, .3},
        new double[]{1.0, 4.5, 2.3, 1.3, .3},
        new double[]{1.0, 4.4, 3.2, 1.3, .2},
        new double[]{1.0, 5.0, 3.5, 1.6, .6},
        new double[]{1.0, 5.1, 3.8, 1.9, .4},
        new double[]{1.0, 4.8, 3.0, 1.4, .3},
        new double[]{1.0, 5.1, 3.8, 1.6, .2},
        new double[]{1.0, 4.6, 3.2, 1.4, .2},
        new double[]{1.0, 5.3, 3.7, 1.5, .2},
        new double[]{1.0, 5.0, 3.3, 1.4, .2},
        new double[]{2.0, 7.0, 3.2, 4.7, 1.4},
        new double[]{2.0, 6.4, 3.2, 4.5, 1.5},
        new double[]{2.0, 6.9, 3.1, 4.9, 1.5},
        new double[]{2.0, 5.5, 2.3, 4.0, 1.3},
        new double[]{2.0, 6.5, 2.8, 4.6, 1.5},
        new double[]{2.0, 5.7, 2.8, 4.5, 1.3},
        new double[]{2.0, 6.3, 3.3, 4.7, 1.6},
        new double[]{2.0, 4.9, 2.4, 3.3, 1.0},
        new double[]{2.0, 6.6, 2.9, 4.6, 1.3},
        new double[]{2.0, 5.2, 2.7, 3.9, 1.4},
        new double[]{2.0, 5.0, 2.0, 3.5, 1.0},
        new double[]{2.0, 5.9, 3.0, 4.2, 1.5},
        new double[]{2.0, 6.0, 2.2, 4.0, 1.0},
        new double[]{2.0, 6.1, 2.9, 4.7, 1.4},
        new double[]{2.0, 5.6, 2.9, 3.6, 1.3},
        new double[]{2.0, 6.7, 3.1, 4.4, 1.4},
        new double[]{2.0, 5.6, 3.0, 4.5, 1.5},
        new double[]{2.0, 5.8, 2.7, 4.1, 1.0},
        new double[]{2.0, 6.2, 2.2, 4.5, 1.5},
        new double[]{2.0, 5.6, 2.5, 3.9, 1.1},
        new double[]{2.0, 5.9, 3.2, 4.8, 1.8},
        new double[]{2.0, 6.1, 2.8, 4.0, 1.3},
        new double[]{2.0, 6.3, 2.5, 4.9, 1.5},
        new double[]{2.0, 6.1, 2.8, 4.7, 1.2},
        new double[]{2.0, 6.4, 2.9, 4.3, 1.3},
        new double[]{2.0, 6.6, 3.0, 4.4, 1.4},
        new double[]{2.0, 6.8, 2.8, 4.8, 1.4},
        new double[]{2.0, 6.7, 3.0, 5.0, 1.7},
        new double[]{2.0, 6.0, 2.9, 4.5, 1.5},
        new double[]{2.0, 5.7, 2.6, 3.5, 1.0},
        new double[]{2.0, 5.5, 2.4, 3.8, 1.1},
        new double[]{2.0, 5.5, 2.4, 3.7, 1.0},
        new double[]{2.0, 5.8, 2.7, 3.9, 1.2},
        new double[]{2.0, 6.0, 2.7, 5.1, 1.6},
        new double[]{2.0, 5.4, 3.0, 4.5, 1.5},
        new double[]{2.0, 6.0, 3.4, 4.5, 1.6},
        new double[]{2.0, 6.7, 3.1, 4.7, 1.5},
        new double[]{2.0, 6.3, 2.3, 4.4, 1.3},
        new double[]{2.0, 5.6, 3.0, 4.1, 1.3},
        new double[]{2.0, 5.5, 2.5, 4.0, 1.3},
        new double[]{2.0, 5.5, 2.6, 4.4, 1.2},
        new double[]{2.0, 6.1, 3.0, 4.6, 1.4},
        new double[]{2.0, 5.8, 2.6, 4.0, 1.2},
        new double[]{2.0, 5.0, 2.3, 3.3, 1.0},
        new double[]{2.0, 5.6, 2.7, 4.2, 1.3},
        new double[]{2.0, 5.7, 3.0, 4.2, 1.2},
        new double[]{2.0, 5.7, 2.9, 4.2, 1.3},
        new double[]{2.0, 6.2, 2.9, 4.3, 1.3},
        new double[]{2.0, 5.1, 2.5, 3.0, 1.1},
        new double[]{2.0, 5.7, 2.8, 4.1, 1.3},
        new double[]{3.0, 6.3, 3.3, 6.0, 2.5},
        new double[]{3.0, 5.8, 2.7, 5.1, 1.9},
        new double[]{3.0, 7.1, 3.0, 5.9, 2.1},
        new double[]{3.0, 6.3, 2.9, 5.6, 1.8},
        new double[]{3.0, 6.5, 3.0, 5.8, 2.2},
        new double[]{3.0, 7.6, 3.0, 6.6, 2.1},
        new double[]{3.0, 4.9, 2.5, 4.5, 1.7},
        new double[]{3.0, 7.3, 2.9, 6.3, 1.8},
        new double[]{3.0, 6.7, 2.5, 5.8, 1.8},
        new double[]{3.0, 7.2, 3.6, 6.1, 2.5},
        new double[]{3.0, 6.5, 3.2, 5.1, 2.0},
        new double[]{3.0, 6.4, 2.7, 5.3, 1.9},
        new double[]{3.0, 6.8, 3.0, 5.5, 2.1},
        new double[]{3.0, 5.7, 2.5, 5.0, 2.0},
        new double[]{3.0, 5.8, 2.8, 5.1, 2.4},
        new double[]{3.0, 6.4, 3.2, 5.3, 2.3},
        new double[]{3.0, 6.5, 3.0, 5.5, 1.8},
        new double[]{3.0, 7.7, 3.8, 6.7, 2.2},
        new double[]{3.0, 7.7, 2.6, 6.9, 2.3},
        new double[]{3.0, 6.0, 2.2, 5.0, 1.5},
        new double[]{3.0, 6.9, 3.2, 5.7, 2.3},
        new double[]{3.0, 5.6, 2.8, 4.9, 2.0},
        new double[]{3.0, 7.7, 2.8, 6.7, 2.0},
        new double[]{3.0, 6.3, 2.7, 4.9, 1.8},
        new double[]{3.0, 6.7, 3.3, 5.7, 2.1},
        new double[]{3.0, 7.2, 3.2, 6.0, 1.8},
        new double[]{3.0, 6.2, 2.8, 4.8, 1.8},
        new double[]{3.0, 6.1, 3.0, 4.9, 1.8},
        new double[]{3.0, 6.4, 2.8, 5.6, 2.1},
        new double[]{3.0, 7.2, 3.0, 5.8, 1.6},
        new double[]{3.0, 7.4, 2.8, 6.1, 1.9},
        new double[]{3.0, 7.9, 3.8, 6.4, 2.0},
        new double[]{3.0, 6.4, 2.8, 5.6, 2.2},
        new double[]{3.0, 6.3, 2.8, 5.1, 1.5},
        new double[]{3.0, 6.1, 2.6, 5.6, 1.4},
        new double[]{3.0, 7.7, 3.0, 6.1, 2.3},
        new double[]{3.0, 6.3, 3.4, 5.6, 2.4},
        new double[]{3.0, 6.4, 3.1, 5.5, 1.8},
        new double[]{3.0, 6.0, 3.0, 4.8, 1.8},
        new double[]{3.0, 6.9, 3.1, 5.4, 2.1},
        new double[]{3.0, 6.7, 3.1, 5.6, 2.4},
        new double[]{3.0, 6.9, 3.1, 5.1, 2.3},
        new double[]{3.0, 5.8, 2.7, 5.1, 1.9},
        new double[]{3.0, 6.8, 3.2, 5.9, 2.3},
        new double[]{3.0, 6.7, 3.3, 5.7, 2.5},
        new double[]{3.0, 6.7, 3.0, 5.2, 2.3},
        new double[]{3.0, 6.3, 2.5, 5.0, 1.9},
        new double[]{3.0, 6.5, 3.0, 5.2, 2.0},
        new double[]{3.0, 6.2, 3.4, 5.4, 2.3},
        new double[]{3.0, 5.9, 3.0, 5.1, 1.8}};

    /// <summary>
    /// Applies the data edits, trains the classifier on every pattern
    /// except the last <see cref="HoldOutCount"/>, prints the training
    /// error table, then classifies and prints the held-out patterns.
    /// </summary>
    /// <param name="args">Unused.</param>
    public static void Main(String[] args)
    {
        ApplyDataCorrections();

        int nTrain = irisFisherData.Length - HoldOutCount;
        int nNominal = 0;                                /* no nominal input attributes */
        int nContinuous = irisFisherData[0].Length - 1;  /* every column after the class column (four) */
        int nClasses = classLabels.Length;               /* three classification categories */

        /* Split the first nTrain rows into attribute vectors and zero-based class indices. */
        double[][] irisContinuousData = new double[nTrain][];
        int[] irisClassificationData = new int[nTrain];
        for (int row = 0; row < nTrain; row++)
        {
            irisContinuousData[row] = CopyAttributes(irisFisherData[row], nContinuous);
            irisClassificationData[row] = ClassIndex(irisFisherData[row]);
        }

        /* Model every continuous attribute with a NormalDistribution. */
        NaiveBayesClassifier nbTrainer =
            new NaiveBayesClassifier(nContinuous, nNominal, nClasses);
        for (int i = 0; i < nContinuous; i++)
        {
            nbTrainer.CreateContinuousAttribute(new NormalDistribution());
        }
        nbTrainer.Train(irisContinuousData, null, irisClassificationData);

        PrintTrainingErrors(nbTrainer.GetTrainingErrors());
        ClassifyHeldOut(nbTrainer, nTrain, nContinuous, nClasses);
    }

    /// <summary>
    /// Overwrites three values of <see cref="irisFisherData"/>, described by
    /// the original example as the data corrections from the KDD data mining
    /// archive.
    /// </summary>
    private static void ApplyDataCorrections()
    {
        /* NOTE(review): the UCI/KDD iris notes list 4.9,3.1,1.5,0.2 (35th
         * sample) and 4.9,3.6,1.4,0.1 (38th sample) as the correct values,
         * and the table above already holds exactly those. These assignments
         * therefore appear to reproduce the archive's erroneous copy rather
         * than correct it. Left unchanged so results match the published
         * output; confirm against the archive before relying on it. */
        irisFisherData[34][4] = 0.1;
        irisFisherData[37][2] = 3.1;
        irisFisherData[37][3] = 1.5;
    }

    /// <summary>
    /// Prints the training error table. Entry [k] is printed as
    /// "[k][0]/[k][1]" under class k; entry [3] is printed under "Total".
    /// </summary>
    /// <param name="classErrors">Result of <c>GetTrainingErrors()</c>.</param>
    private static void PrintTrainingErrors(int[][] classErrors)
    {
        Console.Out.WriteLine(
            " Iris Classification Training Error Rates");
        Console.Out.WriteLine(
            "------------------------------------------------");
        Console.Out.WriteLine(
            " Setosa Versicolour Virginica | Total");
        Console.Out.WriteLine(" " + classErrors[0][0] + "/" +
            classErrors[0][1] + " " + classErrors[1][0] + "/" +
            classErrors[1][1] + " " + classErrors[2][0] + "/" +
            classErrors[2][1] + " | " + classErrors[3][0] +
            "/" + classErrors[3][1]);
        Console.Out.WriteLine(
            "------------------------------------------------\n\n\n");
    }

    /// <summary>
    /// Classifies every row from <paramref name="firstRow"/> to the end of
    /// <see cref="irisFisherData"/>. For each row it prints the predicted
    /// class, the true class and the probability of each class.
    /// </summary>
    /// <param name="nbTrainer">The trained classifier.</param>
    /// <param name="firstRow">Index of the first held-out row.</param>
    /// <param name="nContinuous">Number of continuous attributes per row.</param>
    /// <param name="nClasses">Number of probabilities to print per row.</param>
    private static void ClassifyHeldOut(NaiveBayesClassifier nbTrainer,
        int firstRow, int nContinuous, int nClasses)
    {
        /* NOTE(review): despite this header, every held-out pattern is
         * printed, not only the misclassified ones. The text is kept so the
         * output matches the published listing. */
        Console.Out.WriteLine(
            "Probabilities for Incorrect Classifications");
        Console.Out.WriteLine(" Predicted ");
        Console.Out.WriteLine(
            " Class | Class | P(0) P(1) P(2) ");
        Console.Out.WriteLine(
            "-------------------------------------------------------");
        for (int row = firstRow; row < irisFisherData.Length; row++)
        {
            double[] continuousInput = CopyAttributes(irisFisherData[row], nContinuous);
            int targetClassification = ClassIndex(irisFisherData[row]);
            double[] classifiedProbabilities =
                nbTrainer.Probabilities(continuousInput, null);
            int classification = nbTrainer.PredictClass(continuousInput, null);

            Console.Out.Write(Label(classification));
            /* the true-class column carries one extra leading space */
            Console.Out.Write(" " + Label(targetClassification));
            for (int j = 0; j < nClasses; j++)
            {
                Console.Out.Write(" {0, 2:f3} ", classifiedProbabilities[j]);
            }
            Console.Out.WriteLine();
        }
    }

    /// <summary>
    /// Returns a new array holding the <paramref name="count"/> values that
    /// follow the class column of <paramref name="row"/>.
    /// </summary>
    private static double[] CopyAttributes(double[] row, int count)
    {
        double[] attributes = new double[count];
        Array.Copy(row, 1, attributes, 0, count);
        return attributes;
    }

    /// <summary>
    /// Converts the one-based class stored in column 0 of
    /// <paramref name="row"/> to a zero-based class index.
    /// </summary>
    private static int ClassIndex(double[] row)
    {
        return (int)row[0] - 1;
    }

    /// <summary>
    /// Printed label for a zero-based class index, or
    /// <see cref="MissingLabel"/> when the index has no label.
    /// </summary>
    private static string Label(int classIndex)
    {
        return (classIndex >= 0 && classIndex < classLabels.Length)
            ? classLabels[classIndex]
            : MissingLabel;
    }
}
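As the listing below shows, the Naive Bayes classifier misclassifies 20 of the 140 training patterns, all of them Virginica. It also labels 4 of the 10 held-out Virginica plants as Versicolour.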
Iris Classification Training Error Rates
------------------------------------------------
Setosa Versicolour Virginica | Total
0/50 0/50 20/40 | 20/140
------------------------------------------------
Probabilities for Incorrect Classifications
Predicted
Class | Class | P(0) P(1) P(2)
-------------------------------------------------------
Virginica | Virginica | 0.000 0.436 0.564
Virginica | Virginica | 0.000 0.466 0.534
Versicolour | Virginica | 0.000 0.542 0.458
Virginica | Virginica | 0.000 0.441 0.559
Virginica | Virginica | 0.000 0.412 0.588
Virginica | Virginica | 0.000 0.466 0.534
Versicolour | Virginica | 0.000 0.542 0.458
Versicolour | Virginica | 0.000 0.515 0.485
Virginica | Virginica | 0.000 0.460 0.540
Versicolour | Virginica | 0.000 0.551 0.449
Link to C# source.