Class BinaryClassification
- All Implemented Interfaces:
Serializable
Uses a FeedForwardNetwork to solve a binary classification problem.
In these problems, the target output for the network is the probability that
the pattern falls into one of two classes. These probabilities are then used to
assign new patterns to one of the two classes, that is, to predict which class they belong to.
Two examples of binary classification problems are determining whether a
credit applicant is a good or bad credit risk based on their credit history and other
factors, and deciding whether a person should receive a particular treatment based on their
physical and laboratory information. The network is trained by
minimizing the binary cross-entropy error, and its output is calculated by applying the logistic
activation function to the potential of the single output perceptron. The network output
is the probability \(P(C_1)\) that the pattern belongs to the first class.
The probability for the second class is then \(P(C_2) = 1 - P(C_1)\).
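For reference, the relationships described above can be written out explicitly. This is a sketch in our own notation, not taken from the library: \(z\) is the potential of the single output perceptron for a pattern, \(\sigma\) is the logistic function, and \(t_i\) is 1 when pattern \(i\) belongs to \(C_1\) and 0 otherwise. How yData encodes the two classes, and whether the library scales \(E\) (for example by \(1/N\)), is not stated on this page.

```latex
P(C_1) = \sigma(z) = \frac{1}{1 + e^{-z}}, \qquad
P(C_2) = 1 - \sigma(z) = \frac{e^{-z}}{1 + e^{-z}} = \sigma(-z)

E = -\sum_{i=1}^{N} \Bigl[ t_i \ln \sigma(z_i) + (1 - t_i) \ln\bigl(1 - \sigma(z_i)\bigr) \Bigr]
```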
Constructor Summary
Constructors
- BinaryClassification(FeedForwardNetwork network): Creates a binary classifier.
Method Summary
- double[] computeStatistics(double[][] xData, int[] yData): Computes the classification error statistics for the supplied network patterns and their associated classifications.
- getError(): Returns the error function for use by QuasiNewtonTrainer for training a binary classification network.
- getNetwork(): Returns the network being used for classification.
- int predictedClass(double[] x): Calculates the classification probabilities for the input pattern x, and returns either 0 or 1 identifying the class with the highest probability.
- double[] probabilities(double[] x): Returns classification probabilities for the input pattern x.
- void train(Trainer trainer, double[][] xData, int[] yData): Trains the classification neural network using the supplied trainer and patterns.
-
Constructor Details
-
BinaryClassification
public BinaryClassification(FeedForwardNetwork network) Creates a binary classifier.
- Parameters:
network - the neural network used for classification. Its OutputPerceptrons should use the logistic activation function, Activation.LOGISTIC.
-
-
Method Details
-
getNetwork
Returns the network being used for classification.
- Returns:
- the network set by the constructor.
-
train
public void train(Trainer trainer, double[][] xData, int[] yData) Trains the classification neural network using the supplied trainer and patterns.
- Parameters:
trainer - A Trainer object, which is used to train the network. The error function in any QuasiNewtonTrainer included in trainer should be set to the error function returned by this class's getError method.
xData - A double matrix containing the input training patterns. The number of columns in xData must equal the number of nodes in the input layer. Each row of xData contains a training pattern.
yData - An int array containing the output classification values. These values must be 0 or 1.
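To illustrate what "minimizing the binary cross-entropy error" does during training, here is a minimal sketch that fits a single logistic output unit with no hidden layer. It plainly swaps in batch gradient descent instead of the quasi-Newton training this class is designed for, and it does not use the library's Trainer API. The class LogisticUnitSketch, its methods, the learning rate, and the encoding t = 1 for \(C_1\) are all hypothetical. The update relies on the standard fact that, for a logistic output with cross-entropy error, the derivative with respect to the potential is \(\sigma(z) - t\).

```java
/**
 * Hypothetical sketch: fits one logistic output unit (no hidden layer) by
 * batch gradient descent on the binary cross-entropy error.
 * Not the JMSL Trainer or QuasiNewtonTrainer; all names are illustrative.
 */
final class LogisticUnitSketch {
    final double[] weights; // one weight per input node
    double bias;

    LogisticUnitSketch(int nInputs) {
        weights = new double[nInputs];
    }

    static double logistic(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Potential z = w . x + b of the single output unit. */
    double potential(double[] x) {
        double z = bias;
        for (int j = 0; j < weights.length; j++) {
            z += weights[j] * x[j];
        }
        return z;
    }

    /** Estimated P(C1) for input pattern x. */
    double probabilityClass1(double[] x) {
        return logistic(potential(x));
    }

    /**
     * t[i] is 1 when pattern i is in C1, else 0 (assumed encoding).
     * Since dE/dz_i = p_i - t_i, the weight gradient is sum_i (p_i - t_i) x_i.
     */
    void train(double[][] xData, int[] t, double learningRate, int epochs) {
        for (int epoch = 0; epoch < epochs; epoch++) {
            double[] gradW = new double[weights.length];
            double gradB = 0.0;
            for (int i = 0; i < xData.length; i++) {
                double delta = probabilityClass1(xData[i]) - t[i];
                for (int j = 0; j < weights.length; j++) {
                    gradW[j] += delta * xData[i][j];
                }
                gradB += delta;
            }
            for (int j = 0; j < weights.length; j++) {
                weights[j] -= learningRate * gradW[j];
            }
            bias -= learningRate * gradB;
        }
    }

    public static void main(String[] args) {
        // Toy 1-D data: patterns with large x are in C1 (t = 1).
        double[][] xData = {{-2.0}, {-1.0}, {1.0}, {2.0}};
        int[] t = {0, 0, 1, 1};
        LogisticUnitSketch unit = new LogisticUnitSketch(1);
        unit.train(xData, t, 0.1, 1000);
        for (double[] x : xData) {
            System.out.printf("x=%.1f  P(C1)=%.3f%n", x[0], unit.probabilityClass1(x));
        }
    }
}
```

A quasi-Newton trainer uses the same error and gradient information but chooses step directions from an approximation of the curvature, which usually converges in far fewer iterations than this fixed-step loop.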
-
getError
Returns the error function for use by QuasiNewtonTrainer for training a binary classification network.
- Returns:
- an implementation of the binary cross-entropy error function.
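An error function supplied to a gradient-based trainer has to provide both the error value and its derivative. The sketch below shows one numerically stable way to compute the per-pattern binary cross-entropy and its derivative with respect to the output potential. It is not the interface that getError actually returns; CrossEntropySketch and its methods are hypothetical, and the encoding t = 1 for \(C_1\) is an assumption. The rewrite in terms of softplus follows from \(\ln \sigma(z) = -\ln(1 + e^{-z})\) and \(\ln(1 - \sigma(z)) = -\ln(1 + e^{z})\).

```java
/**
 * Hypothetical sketch of a per-pattern binary cross-entropy error written in
 * terms of the output perceptron's potential z. It is not the object returned
 * by getError; all names here are illustrative.
 */
final class CrossEntropySketch {
    /** Numerically stable softplus: log(1 + e^a). */
    static double softplus(double a) {
        return Math.max(a, 0.0) + Math.log1p(Math.exp(-Math.abs(a)));
    }

    static double logistic(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /**
     * E = -[t ln sigma(z) + (1 - t) ln(1 - sigma(z))], t = 1 for C1, else 0
     * (assumed encoding). This equals softplus(-z) when t = 1 and softplus(z)
     * when t = 0, which avoids log(0) when |z| is large.
     */
    static double error(double z, int t) {
        return t == 1 ? softplus(-z) : softplus(z);
    }

    /** Analytic derivative dE/dz = sigma(z) - t. */
    static double gradient(double z, int t) {
        return logistic(z) - t;
    }

    public static void main(String[] args) {
        // Compare the analytic derivative with a central finite difference.
        double z = 0.7;
        double h = 1e-6;
        for (int t = 0; t <= 1; t++) {
            double numeric = (error(z + h, t) - error(z - h, t)) / (2 * h);
            System.out.printf("t=%d analytic=%.6f numeric=%.6f%n",
                    t, gradient(z, t), numeric);
        }
    }
}
```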
-
predictedClass
public int predictedClass(double[] x) Calculates the classification probabilities for the input pattern x, and returns either 0 or 1 identifying the class with the highest probability. This method is used to classify patterns into one of the two target classes based on the pattern's values. The predicted classification is the class with the largest probability, that is, the class whose probability is greater than 0.5.
- Parameters:
x - the double array containing the network input pattern to classify. The length of x should equal the number of inputs in the network.
- Returns:
- The classification predicted by the trained network for x. This will be either 0 or 1.
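Because \(P(C_2) = 1 - P(C_1)\), choosing the class with the larger probability is the same as checking whether \(P(C_1) > 0.5\), and since \(\sigma(z) > 0.5\) exactly when \(z > 0\), the decision reduces to the sign of the output potential. The sketch below shows that equivalence. DecisionRuleSketch is hypothetical, and it assumes that class 0 corresponds to \(C_1\) (the first element returned by probabilities). This page does not say how an exact tie at 0.5 is resolved; this sketch sends ties to class 1.

```java
/**
 * Hypothetical sketch of the predictedClass decision rule.
 * Assumption: class 0 is C1 and class 1 is C2. Ties at exactly 0.5 go to
 * class 1 here; the documentation does not specify tie handling.
 */
final class DecisionRuleSketch {
    static double logistic(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Picks class 0 (C1) when P(C1) > 0.5, otherwise class 1 (C2). */
    static int fromProbability(double pC1) {
        return pC1 > 0.5 ? 0 : 1;
    }

    /** Same rule on the potential: mathematically, sigma(z) > 0.5 iff z > 0. */
    static int fromPotential(double z) {
        return z > 0.0 ? 0 : 1;
    }

    public static void main(String[] args) {
        for (double z : new double[] {-2.0, -0.1, 0.3, 2.0}) {
            System.out.println(z + " -> " + fromProbability(logistic(z))
                    + " " + fromPotential(z));
        }
    }
}
```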
-
probabilities
public double[] probabilities(double[] x) Returns classification probabilities for the input pattern x. Calculates the two probabilities for the pattern supplied: \(P(C_1)\) and \(P(C_2)\). The probability that the pattern belongs to the first class, \(P(C_1)\), is estimated by applying the logistic function to the OutputPerceptron's potential. The probability for the second class is calculated as \(P(C_2) = 1 - P(C_1)\). The predicted classification is the class with the largest probability, that is, the class whose probability is greater than 0.5.
- Parameters:
x - a double array containing the network input pattern to classify. The length of x must equal the number of nodes in the input layer.
- Returns:
- the probability of x being in class \(C_1\), followed by the probability of x being in class \(C_2\).
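The two returned probabilities come from a single number, the output perceptron's potential. The following sketch starts from a given potential \(z\) rather than from an input pattern, so it does not call the library. ProbabilitiesSketch is a hypothetical name. It returns the pair in the documented order \(\{P(C_1), P(C_2)\}\).

```java
/**
 * Hypothetical sketch: maps an output potential z to the pair
 * {P(C1), P(C2)} = {sigma(z), 1 - sigma(z)}. Does not use the JMSL API.
 */
final class ProbabilitiesSketch {
    static double[] probabilities(double z) {
        double pC1 = 1.0 / (1.0 + Math.exp(-z)); // logistic activation
        return new double[] {pC1, 1.0 - pC1};    // P(C2) = 1 - P(C1)
    }

    public static void main(String[] args) {
        double[] p = probabilities(0.0);
        System.out.println(p[0] + ", " + p[1]); // prints "0.5, 0.5"
    }
}
```

Note that \(1 - \sigma(z) = \sigma(-z)\), so the two probabilities are mirror images: negating the potential swaps them.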
-
computeStatistics
public double[] computeStatistics(double[][] xData, int[] yData) Computes the classification error statistics for the supplied network patterns and their associated classifications. The first element returned is the binary cross-entropy error; the second is the classification error rate. The classification error rate is calculated by comparing the estimated classification probabilities to the target classifications: if the estimated probability for the target class is less than 0.5, the pattern is tallied as a classification error.
- Parameters:
xData - A double matrix specifying the input training patterns. The number of columns in xData must equal the number of Nodes in the InputLayer.
yData - An int array containing the output classifications. The values in yData must be 0 or 1.
- Returns:
- A two-element double array containing the binary cross-entropy error and the classification error rate.
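The two statistics can be illustrated directly from per-pattern probabilities. The sketch below takes already-computed \(P(C_1)\) values instead of input patterns, so no network is needed. StatisticsSketch is hypothetical. It assumes that a yData value of 0 means \(C_1\), and it sums the cross-entropy over patterns; the library may scale or normalize this differently. Following the description above, a pattern counts as an error when the probability of its target class is below 0.5.

```java
/**
 * Hypothetical sketch of the two statistics described for computeStatistics.
 * p1[i] is the estimated P(C1) for pattern i; yData[i] is 0 or 1.
 * Assumption: yData 0 means C1, 1 means C2. Cross-entropy is summed.
 */
final class StatisticsSketch {
    static double[] computeStatistics(double[] p1, int[] yData) {
        double crossEntropy = 0.0;
        int errors = 0;
        for (int i = 0; i < p1.length; i++) {
            // Probability the network assigns to the pattern's target class.
            double pTarget = (yData[i] == 0) ? p1[i] : 1.0 - p1[i];
            crossEntropy -= Math.log(pTarget); // infinite if pTarget == 0
            if (pTarget < 0.5) {
                errors++; // tallied as a classification error
            }
        }
        return new double[] {crossEntropy, (double) errors / p1.length};
    }

    public static void main(String[] args) {
        double[] p1 = {0.9, 0.2, 0.6, 0.4};
        int[] yData = {0, 1, 1, 0};
        double[] stats = computeStatistics(p1, yData);
        System.out.printf("cross-entropy=%.4f  error rate=%.2f%n", stats[0], stats[1]);
    }
}
```

In this toy data, the target-class probabilities are 0.9, 0.8, 0.4, and 0.4, so two of the four patterns are counted as errors and the error rate is 0.5.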
-