Class MultiClassification
- All Implemented Interfaces:
Serializable
Extends neural network analysis to solving multi-classification problems. In these problems, the target output for the network is the probability that the pattern falls into each of several classes, where the number of classes is 3 or greater. These probabilities are then used to assign patterns to one of the target classes. Typical applications include determining the credit classification for a business (excellent, good, fair or poor), and determining which of three or more treatments a patient should receive based upon their physical, clinical and laboratory information. This class signals that network training will minimize the multi-classification cross-entropy error, and that network outputs are the probabilities that the pattern belongs to each of the target classes. These probabilities are scaled to sum to 1.0 using softmax activation.
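The integer class labels used by this class can be pictured as target probability vectors: the true class has probability 1.0 and every other class has 0.0. The sketch below shows that encoding in Java. It assumes the 1-based labels (1 to the number of classes) that train and computeStatistics expect. OneHotTargets and oneHot are hypothetical names and are not part of the API.

```java
import java.util.Arrays;

// Hypothetical helper for illustration; not part of the MultiClassification API.
class OneHotTargets {

    // Converts 1-based class labels (1..nClasses) into one-hot target vectors.
    static double[][] oneHot(int[] yData, int nClasses) {
        double[][] targets = new double[yData.length][nClasses];
        for (int p = 0; p < yData.length; p++) {
            int label = yData[p];
            if (label < 1 || label > nClasses) {
                throw new IllegalArgumentException(
                    "yData[" + p + "] = " + label + " is outside 1.." + nClasses);
            }
            // Class k is stored in column k - 1; every other entry stays 0.0.
            targets[p][label - 1] = 1.0;
        }
        return targets;
    }

    public static void main(String[] args) {
        // Credit ratings: 1 = excellent, 2 = good, 3 = fair, 4 = poor.
        int[] yData = {1, 3, 4};
        for (double[] row : oneHot(yData, 4)) {
            System.out.println(Arrays.toString(row));
        }
    }
}
```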
Constructor Summary
Constructors
- MultiClassification: Creates a classifier for the supplied neural network.
Method Summary
- double[] computeStatistics(double[][] xData, int[] yData): Computes classification statistics for the supplied network patterns and their associated classifications.
- getError(): Returns the error function for use by QuasiNewtonTrainer for training a classification network.
- getNetwork(): Returns the network being used for classification.
- int predictedClass(double[] x): Calculates the classification probabilities for the input pattern x, and returns the class with the highest probability.
- double[] probabilities(double[] x): Returns classification probabilities for the input pattern x.
- void train(Trainer trainer, double[][] xData, int[] yData): Trains the classification neural network using supplied training patterns.
Constructor Details
MultiClassification
Creates a classifier.
- Parameters:
network - the neural network used for classification. Its OutputPerceptrons should use the linear activation function, Activation.LINEAR. The number of OutputPerceptrons should equal the number of classes.
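Linear output activation means each OutputPerceptron passes its potential through unchanged; the softmax scaling is supplied separately, by the error function from getError during training and by probabilities afterward. The following is a minimal sketch of such a layer, assuming a dense weight matrix and one bias per output. LinearOutputLayer and potentials are hypothetical names, and the one-output-per-class check mirrors the constructor requirement above.

```java
import java.util.Arrays;

// Hypothetical sketch of a linear output layer; names are illustrative only.
class LinearOutputLayer {
    private final double[][] weights; // weights[i][j]: input j -> output perceptron i
    private final double[] bias;      // one bias per output perceptron

    LinearOutputLayer(double[][] weights, double[] bias, int nClasses) {
        // One output perceptron per class, as the constructor documentation requires.
        if (weights.length != nClasses || bias.length != nClasses) {
            throw new IllegalArgumentException("expected " + nClasses + " output perceptrons");
        }
        this.weights = weights;
        this.bias = bias;
    }

    // Linear activation is the identity, so each output is its raw potential z_i.
    double[] potentials(double[] inputs) {
        double[] z = new double[bias.length];
        for (int i = 0; i < z.length; i++) {
            double sum = bias[i];
            for (int j = 0; j < inputs.length; j++) {
                sum += weights[i][j] * inputs[j];
            }
            z[i] = sum;
        }
        return z;
    }

    public static void main(String[] args) {
        LinearOutputLayer layer = new LinearOutputLayer(
            new double[][] {{1.0, 0.0}, {0.0, 1.0}, {1.0, 1.0}},
            new double[] {0.0, 0.0, 0.5},
            3);
        System.out.println(Arrays.toString(layer.potentials(new double[] {2.0, 3.0})));
    }
}
```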
Method Details
getNetwork
Returns the network being used for classification.
- Returns:
the network set by the constructor.
train
Trains the classification neural network using supplied training patterns.
- Parameters:
trainer - A Trainer object, which is used to train the network. The error function in any QuasiNewtonTrainer included in trainer should be set to the error function from this class using the getError method.
xData - A double matrix containing the input training patterns. The number of columns in xData must equal the number of Nodes in the InputLayer. Each row of xData contains a training pattern.
yData - An int array containing the output classification patterns. These values must be in the range 1 to the number of OutputPerceptrons in the network.
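Because train expects 1-based class labels and one input column per input node, a quick pre-flight check can catch the most common data mistakes, such as 0-based labels, before training starts. This is a hypothetical helper, not part of MultiClassification, and it additionally assumes one label per row of xData.

```java
// Hypothetical pre-flight check; it mirrors the documented requirements of train
// but is not part of the MultiClassification API.
class TrainingDataCheck {

    static void validate(double[][] xData, int[] yData, int nInputs, int nOutputs) {
        // Assumption: one class label per training pattern (row of xData).
        if (xData.length != yData.length) {
            throw new IllegalArgumentException(
                "xData has " + xData.length + " rows but yData has " + yData.length + " labels");
        }
        for (int p = 0; p < xData.length; p++) {
            // Each row must have one column per input node.
            if (xData[p].length != nInputs) {
                throw new IllegalArgumentException(
                    "row " + p + " has " + xData[p].length + " columns, expected " + nInputs);
            }
            // Labels are 1-based: 1..nOutputs.
            if (yData[p] < 1 || yData[p] > nOutputs) {
                throw new IllegalArgumentException(
                    "yData[" + p + "] = " + yData[p] + " is outside 1.." + nOutputs);
            }
        }
    }

    public static void main(String[] args) {
        double[][] xData = {{0.1, 0.2}, {0.3, 0.4}};
        validate(xData, new int[] {1, 3}, 2, 3);
        System.out.println("valid");
        try {
            validate(xData, new int[] {0, 2}, 2, 3); // 0-based label: rejected
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```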
getError
Returns the error function for use by QuasiNewtonTrainer for training a classification network. This error function combines the softmax activation function and the cross-entropy error function.
- Returns:
an implementation of the multi-classification cross-entropy error function.
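One reason to combine softmax with cross-entropy is that the derivative of the error with respect to each output potential simplifies to y_i - t_i, where t_i is 1 for the target class and 0 otherwise. The standalone sketch below illustrates that combination for a single pattern. It does not reproduce the object returned by getError, and SoftmaxCrossEntropy, error and gradient are hypothetical names.

```java
import java.util.Arrays;

// Standalone illustration of combined softmax + cross-entropy for one pattern.
// Not the object returned by getError; all names are hypothetical.
class SoftmaxCrossEntropy {

    private static double max(double[] z) {
        double m = Double.NEGATIVE_INFINITY;
        for (double v : z) {
            m = Math.max(m, v);
        }
        return m;
    }

    // Cross-entropy -ln(y_target) for a 1-based target, computed as
    // logSumExp(z) - z_target so that tiny probabilities never reach log(0).
    static double error(double[] z, int target) {
        double m = max(z);
        double sum = 0.0;
        for (double v : z) {
            sum += Math.exp(v - m);
        }
        return m + Math.log(sum) - z[target - 1];
    }

    // Gradient of the error with respect to each potential: y_i - t_i.
    static double[] gradient(double[] z, int target) {
        double m = max(z);
        double[] g = new double[z.length];
        double sum = 0.0;
        for (int i = 0; i < z.length; i++) {
            g[i] = Math.exp(z[i] - m);
            sum += g[i];
        }
        for (int i = 0; i < g.length; i++) {
            g[i] /= sum; // g now holds the softmax probabilities y_i
        }
        g[target - 1] -= 1.0; // subtract the one-hot target t_i
        return g;
    }

    public static void main(String[] args) {
        double[] z = {1.0, 2.0, 0.5};
        System.out.println("error    = " + error(z, 2));
        System.out.println("gradient = " + Arrays.toString(gradient(z, 2)));
    }
}
```

A central finite-difference estimate of the slope of error should agree with gradient, which is a quick way to confirm the y_i - t_i simplification numerically.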
predictedClass
public int predictedClass(double[] x)
Calculates the classification probabilities for the input pattern x, and returns the class with the highest probability. This method classifies patterns into one of the target classes based upon the pattern's values.
- Parameters:
x - The double array containing the network input pattern to classify. The length of x should equal the number of inputs in the network.
- Returns:
The classification predicted by the trained network for x. This is one of the integers 1, 2, ..., nClasses, where nClasses equals nOutputs, the number of outputs in the network (one output per class).
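The final step of predictedClass amounts to an argmax over the class probabilities, shifted to the 1-based class numbering. A minimal sketch, assuming the probabilities are already available; the helper name is hypothetical, and so is the tie rule, since the documentation does not say how ties are resolved.

```java
// Hypothetical helper: picks the 1-based class with the highest probability.
class PredictedClassSketch {

    static int predictedClass(double[] probabilities) {
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            // Strict '>' keeps the lower-numbered class on ties (an assumption;
            // the documentation does not specify tie handling).
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best + 1; // convert 0-based index to class label 1..nClasses
    }

    public static void main(String[] args) {
        double[] probs = {0.15, 0.60, 0.25};
        System.out.println("predicted class: " + predictedClass(probs));
    }
}
```

Because softmax preserves the ordering of the output potentials, the same class would be selected by taking the argmax of the raw potentials z_i.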
probabilities
public double[] probabilities(double[] x)
Returns classification probabilities for the input pattern x. The number of probabilities is equal to the number of target classes, which is the number of outputs in the FeedForwardNetwork. Each probability is calculated by applying the softmax activation to the OutputPerceptrons. The softmax function transforms the output potentials z_1, ..., z_C into the probabilities y_i:
$$y_i = \mathrm{softmax}_i = \frac{e^{z_i}}{\sum_{j=1}^{C} e^{z_j}}, \qquad i = 1, \ldots, C$$
where C is the number of classes.
- Parameters:
x - a double array containing the input pattern to classify. The length of x must be equal to the number of InputNodes.
- Returns:
A double array containing the scaled probabilities.
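The softmax formula translates directly into code. The sketch below is a hypothetical Softmax helper rather than the library's implementation. It subtracts the largest potential before exponentiating, a standard trick that leaves the result unchanged because the common factor cancels in the ratio, while keeping Math.exp from overflowing.

```java
import java.util.Arrays;

// Minimal softmax sketch; the class name is hypothetical.
class Softmax {

    static double[] softmax(double[] z) {
        // Subtracting the largest potential cancels in the ratio
        // but keeps Math.exp from overflowing for large z.
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) {
            max = Math.max(max, v);
        }
        double[] y = new double[z.length];
        double sum = 0.0;
        for (int i = 0; i < z.length; i++) {
            y[i] = Math.exp(z[i] - max);
            sum += y[i];
        }
        for (int i = 0; i < y.length; i++) {
            y[i] /= sum; // scale so the probabilities sum to 1.0
        }
        return y;
    }

    public static void main(String[] args) {
        double[] y = softmax(new double[] {2.0, 1.0, 0.1});
        System.out.println(Arrays.toString(y));
        System.out.println("sum = " + Arrays.stream(y).sum());
    }
}
```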
computeStatistics
public double[] computeStatistics(double[][] xData, int[] yData)
Computes classification statistics for the supplied network patterns and their associated classifications.
Method computeStatistics returns a two-element array where the first element is the cross-entropy error and the second is the classification error rate. The classification error rate is calculated by comparing the estimated classification probabilities to the target classifications. If the estimated probability for the target class is not the largest among the target classes, the pattern is tallied as a classification error.
- Parameters:
xData - A double matrix specifying the input training patterns. The number of columns in xData must equal the number of Nodes in the InputLayer.
yData - An int array containing the output classification patterns. The values in yData must be in the range 1 to the number of OutputPerceptrons in the network.
- Returns:
A double array containing the two statistics described above.
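The two statistics can be reproduced from a table of predicted probabilities and 1-based labels. In this hypothetical sketch the cross-entropy is summed over patterns and the error rate is a fraction between 0 and 1; both are assumptions, since the documentation does not state whether the error is averaged or how the rate is scaled. A tie between the target class and another class is not counted as an error here, which is also an assumption.

```java
// Hypothetical sketch of the two statistics returned by computeStatistics,
// computed from already-available probabilities rather than from a network.
class ClassificationStats {

    // probs[p] holds the class probabilities for pattern p; yData holds 1-based labels.
    static double[] statistics(double[][] probs, int[] yData) {
        double crossEntropy = 0.0;
        int errors = 0;
        for (int p = 0; p < probs.length; p++) {
            int target = yData[p] - 1;
            // Assumption: the cross-entropy is summed (not averaged) over patterns.
            crossEntropy -= Math.log(probs[p][target]);
            // Count an error when some other class has a strictly larger probability.
            for (int i = 0; i < probs[p].length; i++) {
                if (i != target && probs[p][i] > probs[p][target]) {
                    errors++;
                    break;
                }
            }
        }
        // Assumption: the error rate is reported as a fraction of patterns.
        return new double[] {crossEntropy, (double) errors / probs.length};
    }

    public static void main(String[] args) {
        double[][] probs = {{0.7, 0.2, 0.1}, {0.1, 0.3, 0.6}, {0.2, 0.5, 0.3}};
        int[] yData = {1, 2, 2};
        double[] stats = statistics(probs, yData);
        System.out.println("cross-entropy = " + stats[0]);
        System.out.println("error rate    = " + stats[1]);
    }
}
```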