JMSL™ Numerical Library 5.0.1

com.imsl.datamining.neural
Class MultiClassification

java.lang.Object
  extended by com.imsl.datamining.neural.MultiClassification
All Implemented Interfaces:
Serializable

public class MultiClassification
extends Object
implements Serializable

Classifies patterns into three or more classes.

Extends neural network analysis to multi-classification problems. In these problems, the target output for the network is the probability that the pattern falls into each of several classes, where the number of classes is 3 or greater. These probabilities are then used to assign patterns to one of the target classes. Typical applications include determining the credit classification for a business (excellent, good, fair, or poor), and determining which of three or more treatments a patient should receive based upon their physical, clinical, and laboratory information. This class specifies that network training minimizes the multi-classification cross-entropy error, and that the network outputs are the probabilities that the pattern belongs to each of the target classes. These probabilities are scaled to sum to 1.0 using softmax activation.
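The integer class labels used by this class (1 through the number of classes) can be read as target probability vectors in which the true class has probability 1 and every other class has probability 0. The standalone sketch below assumes that one-hot reading; TargetEncoding and toTargetProbabilities are hypothetical names, not part of JMSL, and the mapping of credit ratings to class numbers is illustrative only.

```java
import java.util.Arrays;

/**
 * Hypothetical helper (not part of JMSL): converts 1-based class labels,
 * like the yData argument used by this class, into one-hot target
 * probability vectors. The one-hot reading is an assumption.
 */
class TargetEncoding {

    static double[][] toTargetProbabilities(int[] yData, int nClasses) {
        double[][] targets = new double[yData.length][nClasses];
        for (int p = 0; p < yData.length; p++) {
            int label = yData[p];
            if (label < 1 || label > nClasses) {
                throw new IllegalArgumentException(
                        "class label " + label + " is outside 1.." + nClasses);
            }
            // Labels are 1-based; Java arrays are 0-based.
            targets[p][label - 1] = 1.0;
        }
        return targets;
    }

    public static void main(String[] args) {
        // Illustrative coding: 1 = excellent, 2 = good, 3 = fair, 4 = poor.
        int[] yData = {2, 4, 1};
        double[][] targets = toTargetProbabilities(yData, 4);
        for (double[] row : targets) {
            System.out.println(Arrays.toString(row));
        }
    }
}
```

Each row of the result sums to 1.0, matching the requirement that class probabilities sum to 1.0.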

See Also:
Example 1, Example 2, Serialized Form

Constructor Summary
MultiClassification(Network network)
          Creates a classifier.
 
Method Summary
 double[] computeStatistics(double[][] xData, int[] yData)
          Computes classification statistics for the supplied network patterns and their associated classifications.
 QuasiNewtonTrainer.Error getError()
          Returns the error function for use by QuasiNewtonTrainer for training a classification network.
 Network getNetwork()
          Returns the network being used for classification.
 int predictedClass(double[] x)
          Calculates the classification probabilities for the input pattern x, and returns the class with the highest probability.
 double[] probabilities(double[] x)
          Returns classification probabilities for the input pattern x.
 void train(Trainer trainer, double[][] xData, int[] yData)
          Trains the classification neural network using supplied training patterns.
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Constructor Detail

MultiClassification

public MultiClassification(Network network)
Creates a classifier.

Parameters:
network - the neural network used for classification. Its OutputPerceptrons should use linear activation functions, Activation.LINEAR. The number of OutputPerceptrons should equal the number of classes.
Method Detail

computeStatistics

public double[] computeStatistics(double[][] xData,
                                  int[] yData)
Computes classification statistics for the supplied network patterns and their associated classifications.

Method computeStatistics returns a two-element array in which the first element is the cross-entropy error and the second is the classification error rate. The classification error rate is calculated by comparing the estimated classification probabilities to the target classifications. If the estimated probability for the target class is not the largest among the target classes, the pattern is tallied as a classification error.

Parameters:
xData - A double matrix specifying the input training patterns. The number of columns in xData must equal the number of Nodes in the InputLayer.
yData - An int array containing the output classification patterns. The values in yData must be in the range of one to the number of OutputPerceptrons in the network.
Returns:
A double array containing the two statistics described above.
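A standalone sketch of the two statistics, assuming the class probabilities have already been computed and that the cross-entropy error is summed over patterns (JMSL may scale it differently). ClassificationStats and statistics are hypothetical names. Here a pattern counts as an error only when another class has a strictly larger probability, which is one reading of "not the largest".

```java
/**
 * Standalone sketch (not the JMSL implementation) of the statistics
 * described for computeStatistics. Assumes probabilities are already
 * available and labels are 1-based.
 */
class ClassificationStats {

    /** Returns {crossEntropy, errorRate}; cross-entropy is summed over patterns (an assumption). */
    static double[] statistics(double[][] probs, int[] yData) {
        double crossEntropy = 0.0;
        int errors = 0;
        for (int p = 0; p < yData.length; p++) {
            int target = yData[p] - 1;             // 1-based label to 0-based index
            double pTarget = probs[p][target];
            crossEntropy -= Math.log(pTarget);      // -ln(probability of the true class)
            // Tally an error when some other class has a strictly larger probability.
            for (int c = 0; c < probs[p].length; c++) {
                if (probs[p][c] > pTarget) {
                    errors++;
                    break;
                }
            }
        }
        return new double[] {crossEntropy, (double) errors / yData.length};
    }

    public static void main(String[] args) {
        double[][] probs = {
            {0.7, 0.2, 0.1},   // true class 1: ranked first, not an error
            {0.5, 0.3, 0.2},   // true class 2: class 1 is larger, so an error
        };
        int[] yData = {1, 2};
        double[] stats = statistics(probs, yData);
        System.out.printf("cross-entropy = %.4f, error rate = %.2f%n", stats[0], stats[1]);
    }
}
```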

getError

public QuasiNewtonTrainer.Error getError()
Returns the error function for use by QuasiNewtonTrainer for training a classification network. This error function combines the softmax activation function and the cross-entropy error function.

Returns:
an implementation of the multi-classification cross-entropy error function.
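Combining softmax with cross-entropy is convenient because, for a one-hot target t and true class k, the error E = -ln y_k equals ln(sum_j e^{z_j}) - z_k, so its derivative with respect to each output potential reduces to y_i - t_i, and the error can be computed from the potentials with an overflow-safe log-sum-exp. The sketch below shows that standard math for a single pattern. It is not the QuasiNewtonTrainer.Error implementation; SoftmaxCrossEntropy and errorAndGradient are hypothetical names.

```java
import java.util.Arrays;

/**
 * Standalone sketch of a combined softmax + cross-entropy error for one
 * pattern. Names are hypothetical; this is not the JMSL implementation.
 */
class SoftmaxCrossEntropy {

    /**
     * Returns E = -ln(softmax_k(z)) for the 0-based target index k and
     * writes dE/dz_i = softmax_i(z) - t_i into gradient.
     */
    static double errorAndGradient(double[] z, int k, double[] gradient) {
        // Log-sum-exp with the maximum subtracted, so exp() cannot overflow.
        double max = Double.NEGATIVE_INFINITY;
        for (double zi : z) max = Math.max(max, zi);
        double sum = 0.0;
        for (double zi : z) sum += Math.exp(zi - max);
        double logSumExp = max + Math.log(sum);

        for (int i = 0; i < z.length; i++) {
            double yi = Math.exp(z[i] - logSumExp);   // softmax probability y_i
            double ti = (i == k) ? 1.0 : 0.0;          // one-hot target t_i
            gradient[i] = yi - ti;
        }
        // -ln(y_k) = logSumExp - z_k, computed without forming y_k first.
        return logSumExp - z[k];
    }

    public static void main(String[] args) {
        double[] z = {2.0, 1.0, 0.1};
        double[] grad = new double[z.length];
        double e = errorAndGradient(z, 0, grad);
        System.out.println("error = " + e);
        System.out.println("gradient = " + Arrays.toString(grad));
    }
}
```

Because the probabilities and the one-hot target both sum to 1, the gradient components always sum to zero.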

getNetwork

public Network getNetwork()
Returns the network being used for classification.

Returns:
the network set by the constructor.

predictedClass

public int predictedClass(double[] x)
Calculates the classification probabilities for the input pattern x, and returns the class with the highest probability.

This method classifies a pattern into one of the target classes based upon the pattern's values.

Parameters:
x - The double array containing the network input pattern to classify. The length of x should equal the number of inputs in the network.
Returns:
The classification predicted by the trained network for x. This will be one of the integers 1, 2, ..., nClasses, where nClasses is equal to nOutputs, the number of outputs in the network, which represents the number of classes.
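A standalone sketch of choosing a predicted class from a vector of class probabilities: take the index of the largest probability and report it as a 1-based class number (1, 2, ..., nClasses). PredictedClassSketch is a hypothetical name, and resolving ties in favor of the lowest class number is an assumption, not documented JMSL behavior.

```java
/**
 * Standalone sketch (not the JMSL implementation): returns the 1-based
 * class with the highest probability. Ties go to the lowest class number,
 * which is an assumption.
 */
class PredictedClassSketch {

    static int predictedClass(double[] probabilities) {
        int best = 0;
        for (int c = 1; c < probabilities.length; c++) {
            if (probabilities[c] > probabilities[best]) {
                best = c;   // strictly larger, so earlier classes win ties
            }
        }
        return best + 1;    // convert 0-based index to 1-based class number
    }

    public static void main(String[] args) {
        double[] probs = {0.15, 0.60, 0.25};
        System.out.println("predicted class = " + predictedClass(probs));
    }
}
```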

probabilities

public double[] probabilities(double[] x)
Returns classification probabilities for the input pattern x.

The number of probabilities is equal to the number of target classes, which is the number of outputs in the FeedForwardNetwork. Each probability is calculated by applying softmax activation to the potentials of the OutputPerceptrons. The softmax function transforms the output potential z into the probability y by

y_i = \mathrm{softmax}_i = \frac{e^{z_i}}{\sum_{j=1}^{C} e^{z_j}}

where C is the number of target classes.

Parameters:
x - a double array containing the input pattern to classify. The length of x must be equal to the number of InputNodes.
Returns:
A double array containing the scaled probabilities.
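A minimal Java sketch of the softmax formula in this section, assuming the output potentials are already available as an array. Subtracting the largest potential before exponentiating does not change any y_i (the factor cancels between numerator and denominator) but keeps exp() from overflowing; whether JMSL does the same is not stated here. SoftmaxSketch is a hypothetical name.

```java
import java.util.Arrays;

/**
 * Standalone sketch of y_i = exp(z_i) / sum_j exp(z_j).
 * Not the JMSL implementation; the class name is hypothetical.
 */
class SoftmaxSketch {

    static double[] softmax(double[] z) {
        // Shift by the maximum potential for numerical stability.
        double max = Double.NEGATIVE_INFINITY;
        for (double zi : z) max = Math.max(max, zi);

        double[] y = new double[z.length];
        double sum = 0.0;
        for (int i = 0; i < z.length; i++) {
            y[i] = Math.exp(z[i] - max);
            sum += y[i];
        }
        for (int i = 0; i < z.length; i++) {
            y[i] /= sum;   // scale so the probabilities sum to 1.0
        }
        return y;
    }

    public static void main(String[] args) {
        double[] potentials = {2.0, 1.0, 0.1};
        System.out.println(Arrays.toString(softmax(potentials)));
    }
}
```

Softmax preserves the ordering of the potentials, so the class with the largest potential also has the largest probability.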

train

public void train(Trainer trainer,
                  double[][] xData,
                  int[] yData)
Trains the classification neural network using supplied training patterns.

Parameters:
trainer - A Trainer object, which is used to train the network. The error function in any QuasiNewtonTrainer included in trainer should be set to the error function from this class using the getError method.
xData - A double matrix containing the input training patterns. The number of columns in xData must equal the number of Nodes in the InputLayer. Each row of xData contains a training pattern.
yData - An int array containing the output classification patterns. These values must be in the range of one to the number of OutputPerceptrons in the network.
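The shape rules stated for train and computeStatistics can be checked before training. The sketch below is a hypothetical pre-training check, not part of JMSL: it assumes one label per row of xData and takes the number of inputs and outputs as plain integers rather than querying a Network object.

```java
/**
 * Hypothetical pre-training check (not part of JMSL): every row of xData
 * has nInputs columns, yData has one label per row (an assumption), and
 * each label lies in 1..nOutputs.
 */
class TrainingDataCheck {

    static void validate(double[][] xData, int[] yData, int nInputs, int nOutputs) {
        if (xData.length != yData.length) {
            throw new IllegalArgumentException(
                    "xData has " + xData.length + " rows but yData has " + yData.length + " labels");
        }
        for (int p = 0; p < xData.length; p++) {
            if (xData[p].length != nInputs) {
                throw new IllegalArgumentException(
                        "row " + p + " has " + xData[p].length + " columns, expected " + nInputs);
            }
            if (yData[p] < 1 || yData[p] > nOutputs) {
                throw new IllegalArgumentException(
                        "label " + yData[p] + " in row " + p + " is outside 1.." + nOutputs);
            }
        }
    }

    public static void main(String[] args) {
        double[][] xData = {{0.1, 0.2}, {0.3, 0.4}};
        validate(xData, new int[] {1, 3}, 2, 3);   // passes: 2 inputs, 3 classes
        try {
            validate(xData, new int[] {0, 3}, 2, 3);   // label 0 is out of range
        } catch (IllegalArgumentException ex) {
            System.out.println("rejected: " + ex.getMessage());
        }
    }
}
```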


Copyright © 1970-2008 Visual Numerics, Inc.
Built July 8 2008.