public class MultiClassification extends Object implements Serializable
Extends neural network analysis to solving multi-classification problems. In these problems, the target output for the network is the probability that the pattern falls into each of several classes, where the number of classes is 3 or greater. These probabilities are then used to assign patterns to one of the target classes. Typical applications include determining the credit classification for a business (excellent, good, fair or poor), and determining which of three or more treatments a patient should receive based upon their physical, clinical and laboratory information. This class signals that network training will minimize the multi-classification cross-entropy error, and that network outputs are the probabilities that the pattern belongs to each of the target classes. These probabilities are scaled to sum to 1.0 using softmax activation.
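For a training pattern, the "target output" described above can be pictured as a probability vector with 1.0 on the known class and 0.0 on every other class. The sketch below builds such vectors from 1-based class labels, the same numbering that yData uses in the methods below. TargetEncodingSketch and toTargets are hypothetical names for illustration only. They are not part of this API, and this is not necessarily how MultiClassification represents targets internally.

```java
/** Hypothetical helper (not part of the IMSL API): builds target probability vectors. */
final class TargetEncodingSketch {
    private TargetEncodingSketch() {}

    /**
     * Converts 1-based class labels (1..nClasses, as in yData) into target
     * probability vectors: 1.0 for the labeled class, 0.0 for all others.
     */
    static double[][] toTargets(int[] yData, int nClasses) {
        double[][] targets = new double[yData.length][nClasses];
        for (int i = 0; i < yData.length; i++) {
            int label = yData[i];
            if (label < 1 || label > nClasses) {
                throw new IllegalArgumentException(
                        "label out of range 1.." + nClasses + ": " + label);
            }
            targets[i][label - 1] = 1.0; // shift the 1-based label to a 0-based column
        }
        return targets;
    }
}
```

For example, toTargets(new int[] {1, 3, 2}, 3) yields the rows [1.0, 0.0, 0.0], [0.0, 0.0, 1.0] and [0.0, 1.0, 0.0].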
Constructor and Description |
---|
MultiClassification(Network network): Creates a classifier. |
Modifier and Type | Method and Description |
---|---|
double[] | computeStatistics(double[][] xData, int[] yData): Computes classification statistics for the supplied network patterns and their associated classifications. |
QuasiNewtonTrainer.Error | getError(): Returns the error function for use by QuasiNewtonTrainer for training a classification network. |
Network | getNetwork(): Returns the network being used for classification. |
int | predictedClass(double[] x): Calculates the classification probabilities for the input pattern x, and returns the class with the highest probability. |
double[] | probabilities(double[] x): Returns classification probabilities for the input pattern x. |
void | train(Trainer trainer, double[][] xData, int[] yData): Trains the classification neural network using supplied training patterns. |
public MultiClassification(Network network)
Creates a classifier.
Parameters:
network - the neural network used for classification. Its OutputPerceptrons should use linear activation functions, Activation.LINEAR. The number of OutputPerceptrons should equal the number of classes.

public Network getNetwork()
Returns: the network set by the constructor.

public void train(Trainer trainer, double[][] xData, int[] yData)
Trains the classification neural network using supplied training patterns.
Parameters:
trainer - A Trainer object, which is used to train the network. The error function in any QuasiNewtonTrainer included in trainer should be set to the error function from this class using the getError method.
xData - A double matrix containing the input training patterns. The number of columns in xData must equal the number of Nodes in the InputLayer. Each row of xData contains a training pattern.
yData - An int array containing the output classification patterns. These values must be in the range of one to the number of OutputPerceptrons in the network.

public QuasiNewtonTrainer.Error getError()
Returns the error function for use by QuasiNewtonTrainer for training a classification network. This error function combines the softmax activation function and the cross-entropy error function.

public int predictedClass(double[] x)
Calculates the classification probabilities for the input pattern x, and returns the class with the highest probability. This method classifies patterns into one of the target classes based upon the pattern's values.
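The selection rule described here can be sketched as follows, assuming the class probabilities for one pattern have already been computed (for example with probabilities). PredictedClassSketch is a hypothetical helper, not the library's implementation. Its tie-breaking rule (the lowest-numbered class wins) is an assumption, because this page does not say how ties are resolved.

```java
/** Hypothetical sketch of the predictedClass selection rule; not the IMSL implementation. */
final class PredictedClassSketch {
    private PredictedClassSketch() {}

    /**
     * Returns the 1-based number of the class with the largest probability.
     * Assumption: on a tie, the lowest-numbered class is returned.
     */
    static int predictedClass(double[] probabilities) {
        if (probabilities.length == 0) {
            throw new IllegalArgumentException("at least one class is required");
        }
        int best = 0;
        for (int k = 1; k < probabilities.length; k++) {
            if (probabilities[k] > probabilities[best]) { // strict '>' keeps the earlier class on ties
                best = k;
            }
        }
        return best + 1; // classes are numbered 1, 2, ..., nClasses
    }
}
```

For example, probabilities of {0.2, 0.5, 0.3} give class 2.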
Parameters:
x - The double array containing the network input pattern to classify. The length of x should equal the number of inputs in the network.
Returns: the predicted class for the input pattern x. This will be one of the integers 1, 2, ..., nClasses, where nClasses is equal to nOutputs, the number of outputs in the network, which equals the number of classes.

public double[] probabilities(double[] x)
Returns classification probabilities for the input pattern x.
The number of probabilities is equal to the number of target classes, which is the number of outputs in the FeedForwardNetwork. Each probability is calculated using the softmax activation over the OutputPerceptrons. The softmax function transforms the output potentials z_1, ..., z_C into the probabilities y_1, ..., y_C by

$$y_i = \mathrm{softmax}_i = \frac{e^{z_i}}{\sum_{j=1}^{C} e^{z_j}}$$

where C is the number of classes.
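A direct translation of the softmax formula into Java, offered as a sketch; SoftmaxSketch is a hypothetical name, not part of this API. It subtracts the largest potential before exponentiating. This does not change the result mathematically, because the common factor cancels between numerator and denominator, but it keeps Math.exp from overflowing for large potentials. Whether MultiClassification itself applies this stabilization is not stated on this page.

```java
/** Hypothetical sketch of the softmax transformation; not the IMSL implementation. */
final class SoftmaxSketch {
    private SoftmaxSketch() {}

    /**
     * Maps output potentials z_1..z_C to probabilities y_i = exp(z_i) / sum_j exp(z_j).
     * The maximum potential is subtracted first to avoid overflow; the ratio is unchanged.
     */
    static double[] softmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) {
            max = Math.max(max, v);
        }
        double[] y = new double[z.length];
        double sum = 0.0;
        for (int i = 0; i < z.length; i++) {
            y[i] = Math.exp(z[i] - max); // numerator, shifted by the common factor exp(-max)
            sum += y[i];
        }
        for (int i = 0; i < z.length; i++) {
            y[i] /= sum; // normalize so the probabilities sum to 1.0
        }
        return y;
    }
}
```

For example, softmax(new double[] {0.0, 0.0}) returns {0.5, 0.5}, and softmax(new double[] {1000.0, 1000.0}) returns the same values instead of NaN.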
Parameters:
x - a double array containing the input pattern to classify. The length of x must be equal to the number of InputNodes.
Returns: a double array containing the scaled probabilities.

public double[] computeStatistics(double[][] xData, int[] yData)
Method computeStatistics returns a two-element array where the first element is the cross-entropy error and the second is the classification error rate. The classification error rate is calculated by comparing the estimated classification probabilities to the target classifications. If the estimated probability for the target class is not the largest among the target classes, the pattern is tallied as a classification error.
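The two statistics can be sketched as below, working from probabilities already computed for each pattern (for example with probabilities). StatisticsSketch is a hypothetical name. Two details are assumptions rather than facts from this page: the cross-entropy is taken as the sum over patterns of minus the natural log of the probability assigned to the target class (the library may normalize or scale it differently), and the error rate is the fraction of patterns tallied as errors, with a pattern counted as an error only when some other class has a strictly larger probability.

```java
/** Hypothetical sketch of the two computeStatistics outputs; not the IMSL implementation. */
final class StatisticsSketch {
    private StatisticsSketch() {}

    /**
     * Returns {crossEntropy, errorRate} for per-pattern class probabilities and
     * 1-based target classes. Assumptions: crossEntropy = sum_i -ln(p_i[target_i]);
     * errorRate = (patterns whose target class is outranked) / (number of patterns).
     */
    static double[] computeStatistics(double[][] probs, int[] yData) {
        if (probs.length == 0 || probs.length != yData.length) {
            throw new IllegalArgumentException("need one probability row per target class, and at least one pattern");
        }
        double crossEntropy = 0.0;
        int errors = 0;
        for (int i = 0; i < probs.length; i++) {
            int target = yData[i] - 1; // 0-based column of the target class
            if (target < 0 || target >= probs[i].length) {
                throw new IllegalArgumentException("target class out of range: " + yData[i]);
            }
            double pTarget = probs[i][target];
            crossEntropy -= Math.log(pTarget);
            for (int k = 0; k < probs[i].length; k++) {
                if (k != target && probs[i][k] > pTarget) {
                    errors++; // another class outranks the target class
                    break;
                }
            }
        }
        return new double[] {crossEntropy, (double) errors / probs.length};
    }
}
```

For example, with probabilities {0.7, 0.2, 0.1} for a class-1 pattern and {0.1, 0.3, 0.6} for a class-2 pattern, only the second pattern is misclassified, so the error rate is 0.5 and the cross-entropy is -ln 0.7 - ln 0.3, about 1.56.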
Parameters:
xData - A double matrix specifying the input training patterns. The number of columns in xData must equal the number of Nodes in the InputLayer.
yData - An int array containing the output classification patterns. The values in yData must be in the range of one to the number of OutputPerceptrons in the network.
Returns: a double array containing the two statistics described above.

Copyright © 2020 Rogue Wave Software. All rights reserved.