public class BinaryClassification extends Object implements Serializable
Uses a FeedForwardNetwork
to solve a binary classification problem.
In these problems, the target output for the network is the probability that
the pattern falls into one of two classes. These probabilities are then used to
assign or predict new patterns as belonging to one of the two classes.
Two examples of binary classification problems are determining whether a
credit applicant is a good or bad credit risk based upon their credit history and other
factors, and determining whether a
person should or should not receive a particular treatment based upon their
physical and laboratory information. The network is trained by
minimizing the binary cross-entropy error and output is calculated by applying the logistic
activation function to the potential of the single output. Network output
is the probability \(P(C_1)\) that the pattern belongs to the first class.
The probability for the second class is then \(P(C_2) = 1 - P(C_1)\).
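For reference, the logistic activation and the binary cross-entropy error mentioned above are commonly written as shown here. This is the textbook form, assumed for illustration only. The symbols \(a\) (potential of the output perceptron), \(t_i\) (0/1 target of pattern \(i\)), \(p_i\) (network output for pattern \(i\)), and \(N\) (number of patterns) are not taken from the original, and the library may scale or normalize the error differently.

\[ P(C_1) = \frac{1}{1 + e^{-a}}, \qquad E = -\sum_{i=1}^{N} \left[ t_i \ln p_i + (1 - t_i) \ln(1 - p_i) \right] \]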
Constructor and Description |
---|
BinaryClassification(Network network) Creates a binary classifier. |
Modifier and Type | Method and Description |
---|---|
double[] | computeStatistics(double[][] xData, int[] yData) Computes the classification error statistics for the supplied network patterns and their associated classifications. |
QuasiNewtonTrainer.Error | getError() Returns the error function for use by QuasiNewtonTrainer for training a binary classification network. |
Network | getNetwork() Returns the network being used for classification. |
int | predictedClass(double[] x) Calculates the classification probabilities for the input pattern x, and returns either 0 or 1 identifying the class with the highest probability. |
double[] | probabilities(double[] x) Returns classification probabilities for the input pattern x. |
void | train(Trainer trainer, double[][] xData, int[] yData) Trains the classification neural network using supplied trainer and patterns. |
public BinaryClassification(Network network)
Creates a binary classifier.
network - the neural network used for classification. Its OutputPerceptrons should use the logistic activation function, Activation.LOGISTIC.
public Network getNetwork()
Returns the network being used for classification.
public void train(Trainer trainer, double[][] xData, int[] yData)
Trains the classification neural network using supplied trainer and patterns.
trainer - A Trainer object, which is used to train the network. The error function in any QuasiNewton trainer included in trainer should be set to the error function from this class using the getError method provided by this class.
xData - A double matrix containing the input training patterns. The number of columns in xData must equal the number of nodes in the input layer. Each row of xData contains a training pattern.
yData - An int array containing the output classification values. These values must be 0 or 1.
public QuasiNewtonTrainer.Error getError()
Returns the error function for use by QuasiNewtonTrainer for training a binary classification network.
public int predictedClass(double[] x)
Calculates the classification probabilities for the input pattern x, and returns either 0 or 1 identifying the class with the highest probability.
This method is used to classify patterns into one of the two target classes based upon the pattern's values. The predicted classification is the class with the larger probability, i.e. the class whose probability is greater than 0.5.
x - the double array containing the network input pattern to classify. The length of x should be equal to the number of inputs in the network.
Returns the predicted classification for x. This will be either 0 or 1.
public double[] probabilities(double[] x)
Returns classification probabilities for the input pattern x.
Calculates the two probabilities for the pattern supplied:
\(P(C_1)\) and \(P(C_2)\).
The probability that the pattern belongs to the first class,
\(P(C_1)\), is estimated using the logistic function of
the OutputPerceptron's potential. The probability for the second class
is calculated as \(P(C_2) = 1 - P(C_1)\). The predicted
classification is the class with the larger probability, i.e. the class
whose probability is greater than 0.5.
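The calculation described above can be sketched in plain Java. This is a minimal illustration, not the library's implementation: the class `ProbabilitySketch`, its method names, and the potential value used in `main` are all hypothetical, and the potential is passed in directly rather than computed from a network.

```java
public class ProbabilitySketch {
    /** Logistic function: maps any real-valued potential into (0, 1). */
    public static double logistic(double potential) {
        return 1.0 / (1.0 + Math.exp(-potential));
    }

    /** Returns {P(C1), P(C2)} for a given output potential (hypothetical helper). */
    public static double[] probabilities(double potential) {
        double p1 = logistic(potential);
        return new double[] {p1, 1.0 - p1};
    }

    /** True when P(C1) exceeds 0.5, i.e. the first class is the more likely one. */
    public static boolean firstClassMoreLikely(double[] probs) {
        return probs[0] > 0.5;
    }

    public static void main(String[] args) {
        // 0.8 is an arbitrary potential chosen for illustration.
        double[] probs = probabilities(0.8);
        System.out.printf("P(C1)=%.4f P(C2)=%.4f%n", probs[0], probs[1]);
        System.out.println("First class more likely: " + firstClassMoreLikely(probs));
    }
}
```

Because the two probabilities sum to 1, exactly one of them exceeds 0.5 unless the potential is 0, in which case both equal 0.5.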
x - a double array containing the network input pattern to classify. The length of x must equal the number of nodes in the input layer.
Returns the probability of x being in class \(C_1\), followed by the probability of x being in class \(C_2\).
public double[] computeStatistics(double[][] xData, int[] yData)
The first element returned is the binary cross-entropy error; the second is the classification error rate. The classification error rate is calculated by comparing the estimated classification probabilities to the target classifications. If the estimated probability for the target class is less than 0.5, then this is tallied as a classification error.
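The two statistics can be illustrated with a small self-contained sketch. Several details here are assumptions rather than facts from the documentation: the class `ClassificationStatsSketch` is hypothetical, `pOne[i]` is taken to be the estimated probability that pattern i has target value 1, the cross-entropy is summed rather than averaged, and the error rate is expressed as a fraction rather than a percentage.

```java
public class ClassificationStatsSketch {
    /**
     * Returns {cross-entropy error, classification error rate}.
     * pOne[i]  : estimated probability that pattern i has target value 1 (assumed convention)
     * yData[i] : the 0/1 target classification of pattern i
     */
    public static double[] computeStatistics(double[] pOne, int[] yData) {
        double crossEntropy = 0.0;
        int errors = 0;
        for (int i = 0; i < yData.length; i++) {
            // Probability the model assigns to the pattern's actual target class.
            double pTarget = (yData[i] == 1) ? pOne[i] : 1.0 - pOne[i];
            // Summed, not averaged (an assumption about normalization).
            crossEntropy -= Math.log(pTarget);
            // Below 0.5 for the target class counts as a classification error.
            if (pTarget < 0.5) {
                errors++;
            }
        }
        return new double[] {crossEntropy, (double) errors / yData.length};
    }

    public static void main(String[] args) {
        // Made-up probabilities and targets, for illustration only.
        double[] pOne = {0.9, 0.2, 0.4};
        int[] yData = {1, 0, 1};
        double[] stats = computeStatistics(pOne, yData);
        System.out.println("Cross-entropy error: " + stats[0]);
        System.out.println("Classification error rate: " + stats[1]);
    }
}
```

In this made-up data, only the third pattern is misclassified, since its target class receives a probability of 0.4.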
xData - A double matrix specifying the input training patterns. The number of columns in xData must equal the number of Nodes in the InputLayer.
yData - An int array containing the output classification patterns. The values in yData must be 0 or 1.
Returns a double array containing the binary cross-entropy error and the classification error rate.
Copyright © 2020 Rogue Wave Software. All rights reserved.