Class BinaryClassification

java.lang.Object
com.imsl.datamining.neural.BinaryClassification
All Implemented Interfaces:
Serializable

public class BinaryClassification extends Object implements Serializable
Classifies patterns into two classes.

Uses a FeedForwardNetwork to solve a binary classification problem. In these problems, the target output for the network is the probability that the pattern falls into one of two classes. These probabilities are then used to assign new patterns to one of the two classes. Two examples of binary classification problems are determining whether a credit applicant is a good or bad credit risk based on credit history and other factors, and deciding whether a person should or should not receive a particular treatment based on physical and laboratory information.

The network is trained by minimizing the binary cross-entropy error, and output is calculated by applying the logistic activation function to the potential of the single output perceptron. Network output is the probability \(P(C_1)\) that the pattern belongs to the first class. The probability for the second class is then \(P(C_2) = 1 - P(C_1)\).
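
For reference, the logistic output and the binary cross-entropy error are conventionally written as shown here. The symbols are notation introduced for this note and do not come from the library: \(a\) is the potential of the output perceptron, \(y_i\) is the network output for pattern \(i\), \(t_i\) is 1 when pattern \(i\) belongs to the first class and 0 otherwise, and \(N\) is the number of patterns. This is the standard summed form; whether the library scales or averages the error is not stated on this page.

\[
P(C_1) = \frac{1}{1 + e^{-a}}, \qquad
E = -\sum_{i=1}^{N} \left[ t_i \ln y_i + (1 - t_i) \ln (1 - y_i) \right]
\]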

  • Constructor Summary

    Constructors
    Constructor
    Description
    BinaryClassification(Network network)
    Creates a binary classifier.
  • Method Summary

    Modifier and Type
    Method
    Description
    double[]
    computeStatistics(double[][] xData, int[] yData)
    Computes the classification error statistics for the supplied network patterns and their associated classifications.
    QuasiNewtonTrainer.Error
    getError()
    Returns the error function for use by QuasiNewtonTrainer for training a binary classification network.
    Network
    getNetwork()
    Returns the network being used for classification.
    int
    predictedClass(double[] x)
    Calculates the classification probabilities for the input pattern x, and returns either 0 or 1 identifying the class with the highest probability.
    double[]
    probabilities(double[] x)
    Returns classification probabilities for the input pattern x.
    void
    train(Trainer trainer, double[][] xData, int[] yData)
    Trains the classification neural network using supplied trainer and patterns.

    Methods inherited from class java.lang.Object

    clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
  • Constructor Details

    • BinaryClassification

      public BinaryClassification(Network network)
      Creates a binary classifier.
      Parameters:
      network - the neural network used for classification. Its OutputPerceptrons should use the logistic activation function, Activation.LOGISTIC.
  • Method Details

    • getNetwork

      public Network getNetwork()
      Returns the network being used for classification.
      Returns:
      the network set by the constructor.
    • train

      public void train(Trainer trainer, double[][] xData, int[] yData)
      Trains the classification neural network using supplied trainer and patterns.
      Parameters:
      trainer - A Trainer object, which is used to train the network. The error function of any QuasiNewtonTrainer included in trainer should be set to the error function returned by this class's getError method.
      xData - A double matrix containing the input training patterns. The number of columns in xData must equal the number of nodes in the input layer. Each row of xData contains a training pattern.
      yData - An int array containing the output classification values. These values must be 0 or 1.
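
      The constraints on xData and yData listed above can be checked before calling train. The following is a minimal sketch of such a check. The class TrainingDataCheck, its validate method, and the nInputs argument are hypothetical helpers written for illustration and are not part of the library. The sketch also assumes one classification value per training pattern, which this page implies but does not state.

```java
/** Hypothetical pre-training checks; not part of the library. */
final class TrainingDataCheck {

    private TrainingDataCheck() { }

    /**
     * Throws IllegalArgumentException if the data violate the documented
     * constraints, and returns normally otherwise.
     *
     * @param xData   training patterns, one pattern per row
     * @param yData   classification values, each expected to be 0 or 1
     * @param nInputs number of nodes in the input layer (supplied by the caller)
     */
    static void validate(double[][] xData, int[] yData, int nInputs) {
        if (xData == null || yData == null) {
            throw new IllegalArgumentException("xData and yData must not be null");
        }
        // Assumption: one classification value per training pattern.
        if (xData.length != yData.length) {
            throw new IllegalArgumentException(
                "xData has " + xData.length + " rows but yData has "
                + yData.length + " values");
        }
        for (int i = 0; i < xData.length; i++) {
            // Each row must supply one value per input node.
            if (xData[i] == null || xData[i].length != nInputs) {
                throw new IllegalArgumentException(
                    "row " + i + " of xData must have " + nInputs + " columns");
            }
            // Classification values must be 0 or 1.
            if (yData[i] != 0 && yData[i] != 1) {
                throw new IllegalArgumentException(
                    "yData[" + i + "] is " + yData[i] + ", expected 0 or 1");
            }
        }
    }
}
```

      Failing fast in this way reports the offending row index before any training iterations run.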
    • getError

      public QuasiNewtonTrainer.Error getError()
      Returns the error function for use by QuasiNewtonTrainer for training a binary classification network.
      Returns:
      an implementation of the binary cross-entropy error function.
    • predictedClass

      public int predictedClass(double[] x)
      Calculates the classification probabilities for the input pattern x, and returns either 0 or 1 identifying the class with the highest probability.

      This method is used to classify patterns into one of the two target classes based upon the pattern's values. The predicted classification is the class with the larger probability, that is, the class whose probability exceeds 0.5.

      Parameters:
      x - a double array containing the network input pattern to classify. The length of x must equal the number of nodes in the input layer.
      Returns:
      The classification predicted by the trained network for x. This will be either 0 or 1.
    • probabilities

      public double[] probabilities(double[] x)
      Returns classification probabilities for the input pattern x.

      Calculates the two probabilities for the pattern supplied: \(P(C_1)\) and \(P(C_2)\). The probability that the pattern belongs to the first class, \(P(C_1)\), is estimated using the logistic function of the OutputPerceptron's potential. The probability for the second class is calculated as \(P(C_2) = 1 - P(C_1)\). The predicted classification is the class with the larger probability, that is, the class whose probability exceeds 0.5.

      Parameters:
      x - a double array containing the network input pattern to classify. The length of x must equal the number of nodes in the input layer.
      Returns:
      the probability of x being in class \(C_1\), followed by the probability of x being in class \(C_2\).
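
      The calculation described above can be sketched in a few lines. This is an illustration of the arithmetic only, not the library's implementation. The class ProbabilitySketch and its methods are hypothetical, and the potential is passed in directly rather than computed from a network.

```java
/** Hypothetical illustration of the two-class probability calculation. */
final class ProbabilitySketch {

    private ProbabilitySketch() { }

    /** Logistic function applied to the output perceptron's potential. */
    static double logistic(double potential) {
        return 1.0 / (1.0 + Math.exp(-potential));
    }

    /**
     * Returns a two-element array {P(C1), P(C2)}, where P(C1) is the
     * logistic function of the potential and P(C2) = 1 - P(C1).
     */
    static double[] probabilities(double potential) {
        double p1 = logistic(potential);
        return new double[] { p1, 1.0 - p1 };
    }
}
```

      A potential of zero gives equal probabilities for both classes, and a positive potential makes the first class the more probable one.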
    • computeStatistics

      public double[] computeStatistics(double[][] xData, int[] yData)
      Computes the classification error statistics for the supplied network patterns and their associated classifications.

      The first element returned is the binary cross-entropy error; the second is the classification error rate. The classification error rate is calculated by comparing the estimated classification probabilities to the target classifications. If the estimated probability for the target class is less than 0.5, then this is tallied as a classification error.

      Parameters:
      xData - A double matrix specifying the input training patterns. The number of columns in xData must equal the number of Nodes in the InputLayer.
      yData - An int array containing the output classification patterns. The values in yData must be 0 or 1.
      Returns:
      A two-element double array containing the binary cross-entropy error and the classification error rate.
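
      The two statistics can be illustrated with a short sketch that starts from the estimated probability of each pattern's target class. The class StatisticsSketch and its statistics method are hypothetical and are not the library's implementation. Two assumptions are made here: the cross-entropy is the plain sum of negative log probabilities, and the error rate is reported as a fraction between 0 and 1. This page does not state how the library scales either value.

```java
/** Hypothetical illustration of the two classification statistics. */
final class StatisticsSketch {

    private StatisticsSketch() { }

    /**
     * @param targetClassProb estimated probability of the target class
     *                        for each pattern (one entry per pattern)
     * @return {cross-entropy error, classification error rate}
     */
    static double[] statistics(double[] targetClassProb) {
        if (targetClassProb == null || targetClassProb.length == 0) {
            throw new IllegalArgumentException("at least one pattern is required");
        }
        double crossEntropy = 0.0;
        int errors = 0;
        for (double p : targetClassProb) {
            // Assumption: summed (not averaged) binary cross-entropy.
            // A probability of exactly 0 yields an infinite error.
            crossEntropy -= Math.log(p);
            // A target-class probability below 0.5 counts as an error.
            if (p < 0.5) {
                errors++;
            }
        }
        // Assumption: error rate expressed as a fraction of all patterns.
        double errorRate = (double) errors / targetClassProb.length;
        return new double[] { crossEntropy, errorRate };
    }
}
```

      Working from the target-class probability keeps the sketch independent of how class labels map to network outputs, because the cross-entropy term for each pattern reduces to the negative log of that probability.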