Class MultiClassification

java.lang.Object
com.imsl.datamining.neural.MultiClassification
All Implemented Interfaces:
Serializable

public class MultiClassification extends Object implements Serializable
Classifies patterns into three or more classes.

Extends neural network analysis to multi-classification problems, in which the target output of the network is the probability that a pattern falls into each of several classes, where the number of classes is 3 or greater. These probabilities are then used to assign each pattern to one of the target classes. Typical applications include determining the credit classification for a business (excellent, good, fair or poor) and determining which of three or more treatments a patient should receive based on their physical, clinical and laboratory information. Using this class indicates that network training will minimize the multi-classification cross-entropy error and that the network outputs are the probabilities that the pattern belongs to each of the target classes. These probabilities are scaled to sum to 1.0 using softmax activation.
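To illustrate what "target output" means here, the standalone Java sketch below converts 1-based class labels (the form used by the yData arguments documented below) into target probability vectors: 1.0 for the known class and 0.0 for every other class. The class OneHotTargets and its encode method are hypothetical and not part of the IMSL API; whether MultiClassification builds such vectors internally is an assumption.

```java
/**
 * Hypothetical helper (not part of the IMSL API): converts 1-based class
 * labels into target probability vectors, one row per pattern.
 */
final class OneHotTargets {
    private OneHotTargets() {}

    /**
     * Returns a yData.length x nClasses matrix whose row i holds 1.0 in
     * column yData[i] - 1 and 0.0 elsewhere, i.e. the target probability
     * that pattern i belongs to each class.
     */
    static double[][] encode(int[] yData, int nClasses) {
        if (nClasses < 3) {
            throw new IllegalArgumentException("multi-classification needs 3 or more classes");
        }
        double[][] targets = new double[yData.length][nClasses];
        for (int i = 0; i < yData.length; i++) {
            int label = yData[i];
            if (label < 1 || label > nClasses) {
                throw new IllegalArgumentException("label out of range at row " + i + ": " + label);
            }
            targets[i][label - 1] = 1.0; // class k is stored in column k - 1
        }
        return targets;
    }
}
```

For the credit example, numbering excellent, good, fair and poor as classes 1 to 4, OneHotTargets.encode(new int[] {2, 4}, 4) yields the rows {0, 1, 0, 0} and {0, 0, 0, 1}.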

  • Constructor Summary

    MultiClassification(Network network)
        Creates a classifier.
  • Method Summary

    double[]  computeStatistics(double[][] xData, int[] yData)
        Computes classification statistics for the supplied network patterns and their associated classifications.
    QuasiNewtonTrainer.Error  getError()
        Returns the error function for use by QuasiNewtonTrainer for training a classification network.
    Network  getNetwork()
        Returns the network being used for classification.
    int  predictedClass(double[] x)
        Calculates the classification probabilities for the input pattern x, and returns the class with the highest probability.
    double[]  probabilities(double[] x)
        Returns classification probabilities for the input pattern x.
    void  train(Trainer trainer, double[][] xData, int[] yData)
        Trains the classification neural network using supplied training patterns.

    Methods inherited from class java.lang.Object

    clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
  • Constructor Details

    • MultiClassification

      public MultiClassification(Network network)
      Creates a classifier.
      Parameters:
      network - the neural network used for classification. Its OutputPerceptrons should use linear activation functions (Activation.LINEAR), and the number of OutputPerceptrons should equal the number of classes.
  • Method Details

    • getNetwork

      public Network getNetwork()
      Returns the network being used for classification.
      Returns:
      the network set by the constructor.
    • train

      public void train(Trainer trainer, double[][] xData, int[] yData)
      Trains the classification neural network using supplied training patterns.
      Parameters:
      trainer - A Trainer object used to train the network. The error function of any QuasiNewtonTrainer included in trainer should be set to the error function returned by this class's getError method.
      xData - A double matrix containing the input training patterns. Each row of xData contains one training pattern, and the number of columns in xData must equal the number of Nodes in the InputLayer.
      yData - An int array containing the target classification for each training pattern. These values must be in the range 1 to the number of OutputPerceptrons in the network.
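      The documented requirements on xData and yData can be checked before calling train. The sketch below is a hypothetical helper, not part of the IMSL API; it assumes the caller already knows the number of input Nodes (nInputs) and OutputPerceptrons (nOutputs), and that xData and yData supply one row and one label per training pattern.

```java
/**
 * Hypothetical pre-flight check (not part of the IMSL API) mirroring the
 * documented requirements of MultiClassification.train.
 */
final class TrainInputCheck {
    private TrainInputCheck() {}

    static void validate(double[][] xData, int[] yData, int nInputs, int nOutputs) {
        // Assumption: one class label per training pattern.
        if (xData.length != yData.length) {
            throw new IllegalArgumentException("xData and yData must have the same number of rows");
        }
        for (int i = 0; i < xData.length; i++) {
            // Each pattern must supply one value per input Node.
            if (xData[i].length != nInputs) {
                throw new IllegalArgumentException(
                    "row " + i + " has " + xData[i].length + " columns; expected " + nInputs);
            }
            // Class labels are 1-based: 1..nOutputs.
            if (yData[i] < 1 || yData[i] > nOutputs) {
                throw new IllegalArgumentException(
                    "yData[" + i + "] = " + yData[i] + " is outside 1.." + nOutputs);
            }
        }
    }
}
```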
    • getError

      public QuasiNewtonTrainer.Error getError()
      Returns the error function for use by QuasiNewtonTrainer for training a classification network. This error function combines the softmax activation function and the cross-entropy error function.
      Returns:
      an implementation of the multi-classification cross-entropy error function.
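      The following standalone sketch shows one common way to combine softmax with cross-entropy for a single pattern: the error is -ln of the softmax probability of the target class, computed from the linear output potentials z with the log-sum-exp trick so that large potentials do not overflow. It also returns the gradient with respect to z, which for this combination reduces to p_i - t_i (predicted probability minus the 0/1 target). SoftmaxCrossEntropy and its methods are hypothetical; this is not the IMSL implementation, and the exact form of the IMSL error (for example, summed versus averaged over patterns) is an assumption.

```java
/**
 * Hypothetical sketch (not the IMSL implementation): softmax combined with
 * the multi-classification cross-entropy error for one pattern.
 */
final class SoftmaxCrossEntropy {
    private SoftmaxCrossEntropy() {}

    /** Returns -ln(softmax_y(z)) for linear output potentials z and 1-based class y. */
    static double patternError(double[] z, int y) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) max = Math.max(max, v);
        double sum = 0.0;
        for (double v : z) sum += Math.exp(v - max);
        // ln(sum_j e^{z_j}) computed without overflow
        double logSumExp = max + Math.log(sum);
        return logSumExp - z[y - 1];
    }

    /** Returns dE/dz_i = p_i - t_i, where p = softmax(z) and t is the 0/1 target for class y. */
    static double[] gradient(double[] z, int y) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) max = Math.max(max, v);
        double[] g = new double[z.length];
        double sum = 0.0;
        for (int i = 0; i < z.length; i++) {
            g[i] = Math.exp(z[i] - max);
            sum += g[i];
        }
        for (int i = 0; i < z.length; i++) g[i] /= sum; // g now holds p
        g[y - 1] -= 1.0;                                   // subtract the target t
        return g;
    }
}
```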
    • predictedClass

      public int predictedClass(double[] x)
      Calculates the classification probabilities for the input pattern x, and returns the class with the highest probability.

      This method classifies a pattern into one of the target classes based on the pattern's values.

      Parameters:
      x - The double array containing the input pattern to classify. The length of x should equal the number of inputs in the network.
      Returns:
      The classification predicted by the trained network for x. This will be one of the integers 1, 2, ..., nClasses, where nClasses equals nOutputs, the number of outputs in the network (one output per class).
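      A minimal sketch of the final step of predictedClass, assuming the class probabilities have already been computed: take the index of the largest probability and convert it to a 1-based class number. PredictedClassSketch is hypothetical, and returning the first such class on ties is an assumption about tie handling. Because softmax divides every exp(z_i) by the same sum, the largest probability always belongs to the largest output potential, so the same class would be obtained from the raw potentials.

```java
/** Hypothetical sketch (not the IMSL implementation) of choosing the predicted class. */
final class PredictedClassSketch {
    private PredictedClassSketch() {}

    /** Returns the 1-based index of the largest probability (the first one on ties). */
    static int predictedClass(double[] probabilities) {
        int best = 0;
        for (int k = 1; k < probabilities.length; k++) {
            if (probabilities[k] > probabilities[best]) best = k;
        }
        return best + 1; // classes are numbered 1..nClasses
    }
}
```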
    • probabilities

      public double[] probabilities(double[] x)
      Returns classification probabilities for the input pattern x.

      The number of probabilities equals the number of target classes, which is the number of outputs in the FeedForwardNetwork. Each probability is computed by applying softmax activation to the OutputPerceptrons. The softmax function transforms the output potentials z into probabilities y by $$y_i = \mathrm{softmax}_i(z) = \frac{e^{z_i}}{\sum_{j=1}^{C} e^{z_j}},$$ where C is the number of target classes.

      Parameters:
      x - a double array containing the input pattern to classify. The length of x must equal the number of InputNodes.
      Returns:
      A double array containing the scaled probabilities.
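      The softmax formula for y_i can be sketched as follows. Subtracting the largest potential before exponentiating leaves every ratio unchanged but keeps Math.exp from overflowing when potentials are large. SoftmaxSketch is a hypothetical standalone helper, not the IMSL implementation.

```java
/** Hypothetical sketch (not the IMSL implementation) of softmax activation. */
final class SoftmaxSketch {
    private SoftmaxSketch() {}

    /** Returns y_i = exp(z_i) / sum_j exp(z_j) for output potentials z. */
    static double[] softmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) max = Math.max(max, v);
        double[] y = new double[z.length];
        double sum = 0.0;
        for (int i = 0; i < z.length; i++) {
            y[i] = Math.exp(z[i] - max); // subtracting max leaves each ratio unchanged
            sum += y[i];
        }
        for (int i = 0; i < y.length; i++) y[i] /= sum; // scale so the probabilities sum to 1.0
        return y;
    }
}
```

      For potentials {0, 0, 0, 0}, every class receives probability 0.25.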
    • computeStatistics

      public double[] computeStatistics(double[][] xData, int[] yData)
      Computes classification statistics for the supplied network patterns and their associated classifications.

      Method computeStatistics returns a two-element array in which the first element is the cross-entropy error and the second is the classification error rate. The classification error rate is calculated by comparing the estimated classification probabilities to the target classifications: if the estimated probability for the target class is not the largest among the target classes, the pattern is tallied as a classification error.

      Parameters:
      xData - A double matrix specifying the input training patterns. The number of columns in xData must equal the number of Nodes in the InputLayer.
      yData - An int array containing the target classification for each pattern. The values in yData must be in the range 1 to the number of OutputPerceptrons in the network.
      Returns:
      A double array containing the two statistics described above.
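      Given per-pattern class probabilities (for example, from probabilities), the two statistics can be sketched as below. ClassificationStats is hypothetical and not part of the IMSL API. Summing rather than averaging the cross-entropy over patterns, and counting a pattern as an error only when another class has a strictly larger probability (so a tie with the target class counts as correct), are assumptions rather than documented IMSL behavior.

```java
/** Hypothetical sketch (not the IMSL implementation) of the two computeStatistics values. */
final class ClassificationStats {
    private ClassificationStats() {}

    /**
     * Returns {crossEntropy, errorRate} for per-pattern class probabilities
     * (each row sums to 1.0) and 1-based target classes.
     */
    static double[] compute(double[][] probabilities, int[] yData) {
        double crossEntropy = 0.0;
        int errors = 0;
        for (int i = 0; i < probabilities.length; i++) {
            double[] p = probabilities[i];
            double pTarget = p[yData[i] - 1];
            crossEntropy -= Math.log(pTarget); // summed over patterns (assumption)
            for (double pk : p) {
                if (pk > pTarget) { // another class is more probable than the target
                    errors++;
                    break;
                }
            }
        }
        return new double[] {crossEntropy, (double) errors / probabilities.length};
    }
}
```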