Class QuasiNewtonTrainer
- All Implemented Interfaces:
Trainer, Serializable
Trains a network using the quasi-Newton optimizer MinUnconMultiVar.
-
Nested Class Summary
- protected class
- protected class
- static interface QuasiNewtonTrainer.Error: Error function to be minimized by the trainer.
- protected class: The Objective class is passed to the optimizer.
- protected class: The Objective class is passed to the optimizer.
-
Field Summary
- static final QuasiNewtonTrainer.Error SUM_OF_SQUARES: Compute the sum of squares error.
-
Constructor Summary
- QuasiNewtonTrainer(): Constructs a QuasiNewtonTrainer object.
-
Method Summary
- protected Object clone(): Clones a copy of the trainer.
- QuasiNewtonTrainer.Error getError(): Returns the function used to compute the error to be minimized.
- double[] getErrorGradient(): Returns the value of the gradient of the error function with respect to the weights.
- int getErrorStatus(): Returns the error status from the trainer.
- double getErrorValue(): Returns the final value of the error function.
- static Formatter getFormatter(): Returns the logging formatter object.
- static Logger getLogger(): Returns the Logger object.
- int getTrainingIterations(): Returns the number of iterations used during training.
- boolean getUseBackPropagation(): Returns the back propagation setting.
- protected void setEpochNumber(int num): Sets the epoch number for the trainer.
- void setError(QuasiNewtonTrainer.Error error): Sets the function used to compute the network error.
- void setGradientTolerance(double gradientTolerance): Sets the gradient tolerance.
- void setMaximumStepsize(double maximumStepsize): Sets the maximum step size.
- void setMaximumTrainingIterations(int maximumTrainingIterations): Sets the maximum number of iterations to use during training.
- protected void setParallelMode(ArrayList[] allLogRecords): Sets the trainer to be used in a multi-threaded EpochTrainer.
- void setStepTolerance(double stepTolerance): Sets the scaled step tolerance.
- void setUseBackPropagation(boolean flag): Sets whether or not to use the back propagation algorithm for gradient calculations during network training.
- void train(Network network, double[][] xData, double[][] yData): Trains the neural network using supplied training patterns.
-
Field Details
-
SUM_OF_SQUARES
public static final QuasiNewtonTrainer.Error SUM_OF_SQUARES
Compute the sum of squares error. The sum of squares error term is \(e(y,\hat{y})=(y-\hat{y})^2/2\). This is the default Error object used by QuasiNewtonTrainer.
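The error term above, and its derivative \(\partial e/\partial\hat{y} = \hat{y}-y\), can be sketched in plain Java. The class SumOfSquaresSketch and its methods are hypothetical: they are not the methods of the QuasiNewtonTrainer.Error interface, whose signatures are not listed on this page. Summing the term over all outputs of one pattern is also an assumption of this sketch.

```java
import java.util.Arrays;

// Hypothetical helper, not part of the IMSL API: evaluates the sum of squares
// error term e(y, yHat) = (y - yHat)^2 / 2, summed over the outputs of one
// training pattern (an assumption), and its derivative with respect to yHat.
final class SumOfSquaresSketch {
    private SumOfSquaresSketch() {}

    // Sum over outputs of (y[i] - yHat[i])^2 / 2.
    static double error(double[] y, double[] yHat) {
        checkLengths(y, yHat);
        double sum = 0.0;
        for (int i = 0; i < y.length; i++) {
            double residual = y[i] - yHat[i];
            sum += 0.5 * residual * residual;
        }
        return sum;
    }

    // Partial derivative of the error with respect to yHat[i]: yHat[i] - y[i].
    static double[] derivative(double[] y, double[] yHat) {
        checkLengths(y, yHat);
        double[] d = new double[y.length];
        for (int i = 0; i < y.length; i++) {
            d[i] = yHat[i] - y[i];
        }
        return d;
    }

    private static void checkLengths(double[] y, double[] yHat) {
        if (y.length != yHat.length) {
            throw new IllegalArgumentException("y and yHat must have the same length");
        }
    }

    public static void main(String[] args) {
        double[] y = {1.0, 0.0};
        double[] yHat = {0.0, 2.0};
        System.out.println(error(y, yHat));                       // prints 2.5
        System.out.println(Arrays.toString(derivative(y, yHat))); // prints [-1.0, 2.0]
    }
}
```

The factor 1/2 cancels the 2 produced by differentiating the square, which is why the derivative is simply the residual \(\hat{y}-y\).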
-
Constructor Details
-
QuasiNewtonTrainer
public QuasiNewtonTrainer() Constructs a QuasiNewtonTrainer object.
-
Method Details
-
setUseBackPropagation
public void setUseBackPropagation(boolean flag) Sets whether or not to use the back propagation algorithm for gradient calculations during network training. The quasi-Newton trainer can obtain the gradients it needs either numerically or with the back propagation algorithm; this method selects between the two. When flag is false, the trainer uses numerical gradients. Depending on the data and network architecture, one approach may be faster than the other, or less likely to converge to a local optimum.
- Parameters:
flag - A boolean specifying whether or not to use the back propagation algorithm for gradient calculations. Default value is true.
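The difference between the two gradient options can be illustrated on a toy model. This is a hypothetical sketch, not IMSL code: for a single linear neuron \(\hat{y} = w_0 + w_1 x\) with the sum of squares error, back propagation reduces to the chain rule, and the result is compared with a central-difference numerical gradient. The class GradientComparison and the model are assumptions made for illustration.

```java
import java.util.Arrays;

// Hypothetical illustration, not IMSL code: analytic (back propagation style)
// versus central-difference gradients for a single linear neuron
// yHat = w[0] + w[1] * x with error E(w) = sum_p (y_p - yHat_p)^2 / 2.
final class GradientComparison {
    private GradientComparison() {}

    static double error(double[] w, double[] x, double[] y) {
        double e = 0.0;
        for (int p = 0; p < x.length; p++) {
            double r = y[p] - (w[0] + w[1] * x[p]);
            e += 0.5 * r * r;
        }
        return e;
    }

    // Chain rule through the neuron: dE/dw = sum_p (yHat_p - y_p) * dyHat_p/dw.
    static double[] analyticGradient(double[] w, double[] x, double[] y) {
        double[] g = new double[2];
        for (int p = 0; p < x.length; p++) {
            double delta = (w[0] + w[1] * x[p]) - y[p]; // dE/dyHat for pattern p
            g[0] += delta;        // dyHat/dw0 = 1
            g[1] += delta * x[p]; // dyHat/dw1 = x
        }
        return g;
    }

    // Central differences: two extra error evaluations per weight.
    static double[] numericalGradient(double[] w, double[] x, double[] y, double h) {
        double[] g = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            double[] plus = w.clone();
            double[] minus = w.clone();
            plus[i] += h;
            minus[i] -= h;
            g[i] = (error(plus, x, y) - error(minus, x, y)) / (2.0 * h);
        }
        return g;
    }

    public static void main(String[] args) {
        double[] x = {0.0, 1.0, 2.0};
        double[] y = {1.0, 3.0, 5.0};
        double[] w = {0.5, 1.5};
        System.out.println("analytic:  " + Arrays.toString(analyticGradient(w, x, y))); // [-3.0, -4.0]
        System.out.println("numerical: " + Arrays.toString(numericalGradient(w, x, y, 1e-6)));
    }
}
```

The numerical approach needs 2n error evaluations per gradient for n weights, while the analytic pass needs one sweep over the patterns; that cost difference is one reason the choice can affect training time.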
-
getUseBackPropagation
public boolean getUseBackPropagation() Returns the back propagation setting.
- Returns:
- A boolean specifying whether or not back propagation is being used for gradient calculations.
-
clone
protected Object clone() Clones a copy of the trainer.
-
setParallelMode
protected void setParallelMode(ArrayList[] allLogRecords) Sets the trainer to be used in a multi-threaded EpochTrainer.
- Parameters:
allLogRecords - An ArrayList array containing the log records.
-
setEpochNumber
protected void setEpochNumber(int num) Sets the epoch number for the trainer.
- Parameters:
num - An int specifying the epoch number.
-
setMaximumStepsize
public void setMaximumStepsize(double maximumStepsize) Sets the maximum step size.
- Parameters:
maximumStepsize - A nonnegative double value specifying the maximum allowable step size in the optimizer.
- See Also:
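A common way for an optimizer to honor a maximum step size is to shorten any proposed step that is too long while keeping its direction. The sketch below shows that generic technique using the Euclidean length of the step; it is an assumption for illustration, not the internal logic of MinUnconMultiVar, which may use a scaled norm. The class StepCap is hypothetical.

```java
import java.util.Arrays;

// Generic sketch, not MinUnconMultiVar's code: a proposed step whose
// Euclidean length exceeds maxStep is scaled back, along the same
// direction, to length maxStep.
final class StepCap {
    private StepCap() {}

    static double[] capStep(double[] step, double maxStep) {
        if (maxStep < 0.0) {
            throw new IllegalArgumentException("maxStep must be nonnegative");
        }
        double sumSquares = 0.0;
        for (double s : step) {
            sumSquares += s * s;
        }
        double norm = Math.sqrt(sumSquares);
        double[] capped = step.clone();
        if (norm > maxStep) { // implies norm > 0 because maxStep >= 0
            double factor = maxStep / norm;
            for (int i = 0; i < capped.length; i++) {
                capped[i] *= factor;
            }
        }
        return capped;
    }

    public static void main(String[] args) {
        // A step of length 5 capped to length 1 keeps its direction (about [0.6, 0.8]).
        System.out.println(Arrays.toString(capStep(new double[]{3.0, 4.0}, 1.0)));
    }
}
```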
-
setMaximumTrainingIterations
public void setMaximumTrainingIterations(int maximumTrainingIterations) Sets the maximum number of iterations to use during training.
- Parameters:
maximumTrainingIterations - An int representing the maximum number of training iterations. Default: 100.
- See Also:
-
setStepTolerance
public void setStepTolerance(double stepTolerance) Sets the scaled step tolerance. The second stopping criterion for MinUnconMultiVar, the optimizer used by this Trainer, is that the scaled distance between the last two steps be less than the step tolerance.
- Parameters:
stepTolerance - A double specifying the step tolerance. Default: 3.66685e-11.
- See Also:
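This page does not define the scaled distance between the last two steps. One common definition, assumed here with unit scaling for every weight, is \(\max_i |x_i - x_i^{\text{prev}}| / \max(|x_i|, 1)\), which measures large weights relatively and small weights absolutely. The class ScaledStep is hypothetical and the formula may differ from the one used inside MinUnconMultiVar.

```java
// Hypothetical sketch of a scaled step test, assuming unit scaling:
//   scaledStep = max_i |x[i] - xPrev[i]| / max(|x[i]|, 1)
// The exact formula inside MinUnconMultiVar is not stated on this page.
final class ScaledStep {
    private ScaledStep() {}

    static double scaledStep(double[] x, double[] xPrev) {
        if (x.length != xPrev.length) {
            throw new IllegalArgumentException("x and xPrev must have the same length");
        }
        double worst = 0.0;
        for (int i = 0; i < x.length; i++) {
            double relative = Math.abs(x[i] - xPrev[i]) / Math.max(Math.abs(x[i]), 1.0);
            worst = Math.max(worst, relative);
        }
        return worst;
    }

    // Stopping test as described above: scaled step strictly less than the tolerance.
    static boolean stepToleranceMet(double[] x, double[] xPrev, double stepTolerance) {
        return scaledStep(x, xPrev) < stepTolerance;
    }

    public static void main(String[] args) {
        double[] xPrev = {99.0, 0.25};
        double[] x = {100.0, 0.5};
        System.out.println(scaledStep(x, xPrev));                    // prints 0.25
        System.out.println(stepToleranceMet(x, xPrev, 3.66685e-11)); // prints false
    }
}
```

Here the change of 1.0 in the weight near 100 counts as only 0.01, while the change of 0.25 in the small weight counts in full.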
-
setGradientTolerance
public void setGradientTolerance(double gradientTolerance) Sets the gradient tolerance.
- Parameters:
gradientTolerance - A double specifying the gradient tolerance. Default: cube root of machine precision.
- See Also:
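The default tolerances can be reproduced numerically. Taking machine precision \(\varepsilon\) as Math.ulp(1.0) \(= 2^{-52}\) (conventions differ; some texts use \(2^{-53}\)), \(\varepsilon^{1/3} \approx 6.06\times10^{-6}\), and \(\varepsilon^{2/3}\) agrees with the documented step tolerance default of 3.66685e-11. The scaled-gradient test below, \(\max_i |g_i|\max(|x_i|,1)/\max(|f|,1)\), is a common form assumed here with unit scaling; it is not taken from this page, and the class GradientToleranceSketch is hypothetical.

```java
// Hypothetical sketch: default tolerances derived from machine precision and a
// scaled-gradient stopping test. The scaled-gradient formula (unit scaling,
// function scale 1) and the "<=" comparison are assumptions.
final class GradientToleranceSketch {
    private GradientToleranceSketch() {}

    // Machine precision taken as the spacing of doubles at 1.0, i.e. 2^-52.
    static final double EPS = Math.ulp(1.0);
    // Cube root of machine precision, the documented default gradient tolerance.
    static final double DEFAULT_GRADIENT_TOLERANCE = Math.cbrt(EPS);

    // max_i |g[i]| * max(|x[i]|, 1) / max(|f|, 1)
    static double scaledGradient(double[] g, double[] x, double f) {
        if (g.length != x.length) {
            throw new IllegalArgumentException("g and x must have the same length");
        }
        double denominator = Math.max(Math.abs(f), 1.0);
        double worst = 0.0;
        for (int i = 0; i < g.length; i++) {
            double term = Math.abs(g[i]) * Math.max(Math.abs(x[i]), 1.0) / denominator;
            worst = Math.max(worst, term);
        }
        return worst;
    }

    static boolean gradientToleranceMet(double[] g, double[] x, double f, double tolerance) {
        return scaledGradient(g, x, f) <= tolerance;
    }

    public static void main(String[] args) {
        System.out.println("eps^(1/3) = " + DEFAULT_GRADIENT_TOLERANCE);
        System.out.println("eps^(2/3) = " + Math.pow(EPS, 2.0 / 3.0));
    }
}
```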
-
getTrainingIterations
public int getTrainingIterations() Returns the number of iterations used during training.
- Returns:
- An int representing the number of iterations used during training.
- See Also:
-
getErrorStatus
public int getErrorStatus() Returns the error status from the trainer.
- Specified by:
getErrorStatus in interface Trainer
- Returns:
- An int representing the error status from the trainer. Zero indicates that no errors were encountered during training. Any nonzero value indicates that some error condition arose during training. In many cases the trainer is able to recover from these conditions and produce a well-trained network.
Error Status: Condition
0: No error occurred during training.
1: The last global step failed to locate a lower point than the current error value. The current solution may be an approximate solution and no more accuracy is possible, or the step tolerance may be too large.
2: Relative function convergence; both the actual and predicted relative reductions in the error function are less than or equal to the relative function convergence tolerance.
3: Scaled step tolerance satisfied; the current point may be an approximate local solution, or the algorithm is making very slow progress and is not near a solution, or the step tolerance is too large.
4: MinUnconMultiVar.FalseConvergenceException thrown by the optimizer.
5: MinUnconMultiVar.MaxIterationsException thrown by the optimizer.
6: MinUnconMultiVar.UnboundedBelowException thrown by the optimizer.
- See Also:
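A caller that logs or reports training results might translate the status code into text. The sketch below is a hypothetical helper whose messages paraphrase the table above; TrainerStatus is not part of the IMSL API. Because a nonzero status does not necessarily mean the network is unusable, the helper only describes the code and leaves the decision to the caller.

```java
// Hypothetical helper, not part of the IMSL API: maps the value returned by
// getErrorStatus() to a short description paraphrasing the status table.
final class TrainerStatus {
    private TrainerStatus() {}

    static String describe(int status) {
        switch (status) {
            case 0: return "No error occurred during training.";
            case 1: return "Last global step failed to find a lower error value.";
            case 2: return "Relative function convergence.";
            case 3: return "Scaled step tolerance satisfied.";
            case 4: return "Optimizer reported false convergence.";
            case 5: return "Optimizer reached the maximum number of iterations.";
            case 6: return "Optimizer reported that the error function is unbounded below.";
            default: return "Unknown status " + status + ".";
        }
    }

    public static void main(String[] args) {
        for (int status = 0; status <= 7; status++) {
            System.out.println(status + ": " + describe(status));
        }
    }
}
```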
-
train
public void train(Network network, double[][] xData, double[][] yData) Trains the neural network using supplied training patterns. Each row of xData and yData contains a training pattern. The number of rows in these two arrays must be at least equal to the number of weights in the network.
- Specified by:
train in interface Trainer
- Parameters:
network - The Network to be trained.
xData - An input double matrix containing training patterns. The number of columns in xData must equal the number of nodes in the input layer.
yData - An output double matrix containing output training patterns. The number of columns in yData must equal the number of perceptrons in the output layer.
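The shape requirements for train() can be checked before training starts. This is a hypothetical pre-flight check; the class TrainingDataCheck is not part of the IMSL API. Its countWeights method assumes a fully connected feed-forward network in which every perceptron also has one bias weight; that counting convention is an assumption, so for a real Network the weight count should come from the network itself.

```java
// Hypothetical pre-flight check, not part of the IMSL API, mirroring the
// documented requirements: xData has nInputs columns, yData has nOutputs
// columns, both have the same number of rows, and that number is at least
// the number of network weights.
final class TrainingDataCheck {
    private TrainingDataCheck() {}

    // Assumes full connectivity between consecutive layers plus one bias per perceptron.
    static int countWeights(int[] layerSizes) {
        int total = 0;
        for (int k = 1; k < layerSizes.length; k++) {
            total += layerSizes[k - 1] * layerSizes[k]; // connection weights
            total += layerSizes[k];                     // bias weights
        }
        return total;
    }

    static void validate(double[][] xData, double[][] yData, int nInputs, int nOutputs, int nWeights) {
        if (xData.length != yData.length) {
            throw new IllegalArgumentException("xData and yData must have the same number of rows");
        }
        if (xData.length < nWeights) {
            throw new IllegalArgumentException("need at least " + nWeights
                    + " training patterns, got " + xData.length);
        }
        for (double[] row : xData) {
            if (row.length != nInputs) {
                throw new IllegalArgumentException("each xData row must have " + nInputs + " columns");
            }
        }
        for (double[] row : yData) {
            if (row.length != nOutputs) {
                throw new IllegalArgumentException("each yData row must have " + nOutputs + " columns");
            }
        }
    }

    public static void main(String[] args) {
        int nWeights = countWeights(new int[]{3, 4, 1}); // 3*4 + 4*1 + 4 + 1 = 21
        System.out.println("weights: " + nWeights);
        validate(new double[nWeights][3], new double[nWeights][1], 3, 1, nWeights);
        System.out.println("shapes OK");
    }
}
```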
-
getErrorValue
public double getErrorValue() Returns the final value of the error function.
- Specified by:
getErrorValue in interface Trainer
- Returns:
- A double representing the final value of the error function from the last training. Before training, NaN is returned.
-
getErrorGradient
public double[] getErrorGradient() Returns the value of the gradient of the error function with respect to the weights.
- Specified by:
getErrorGradient in interface Trainer
- Returns:
- A double array whose length is equal to the number of network weights, containing the value of the gradient of the error function with respect to the weights. Before training, null is returned.
-
getLogger
public static Logger getLogger() Returns the Logger object. This is the Logger used to trace this class. It is named com.imsl.datamining.neural.QuasiNewtonTrainer.
- Returns:
- The Logger object, if present, or null.
-
getFormatter
public static Formatter getFormatter() Returns the logging formatter object. Logger support requires JDK 1.4; with earlier versions, this method returns null. The returned Formatter is used as input to Handler.setFormatter(java.util.logging.Formatter) to format the output log.
- Returns:
- The Formatter object, if present, or null.
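The logger name and the Handler/Formatter hookup described above can be exercised with the standard java.util.logging package alone. In this sketch SimpleFormatter is a stand-in for the object getFormatter() would return, so it runs without the IMSL library on the classpath; the class TrainerLoggingSketch and its method are hypothetical. Which records the trainer actually emits, and at which levels, is not stated on this page.

```java
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.logging.Formatter;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.logging.SimpleFormatter;
import java.util.logging.StreamHandler;

// Sketch using only java.util.logging: attaches a formatted handler to the
// logger named after QuasiNewtonTrainer and captures one FINE record.
final class TrainerLoggingSketch {
    private TrainerLoggingSketch() {}

    static final String LOGGER_NAME = "com.imsl.datamining.neural.QuasiNewtonTrainer";

    static String captureFineRecord(String message) {
        Logger logger = Logger.getLogger(LOGGER_NAME);
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        Formatter formatter = new SimpleFormatter(); // stand-in for getFormatter()
        StreamHandler handler = new StreamHandler(buffer, formatter);
        handler.setLevel(Level.ALL);

        Level previousLevel = logger.getLevel(); // may be null (inherited level)
        logger.setLevel(Level.FINE);
        logger.setUseParentHandlers(false); // keep the record off the console
        logger.addHandler(handler);
        try {
            logger.fine(message);
            handler.flush();
        } finally {
            // Restore the logger so the sketch leaves no lasting side effects.
            logger.removeHandler(handler);
            logger.setLevel(previousLevel);
            logger.setUseParentHandlers(true);
        }
        return new String(buffer.toByteArray(), StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        System.out.print(captureFineRecord("training started"));
    }
}
```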
-
getError
public QuasiNewtonTrainer.Error getError() Returns the function used to compute the error to be minimized.
- Returns:
- The Error object containing the function to be minimized.
-
setError
public void setError(QuasiNewtonTrainer.Error error) Sets the function used to compute the network error.
- Parameters:
error - The Error object containing the function to be used to compute the network error. The default is to compute the sum of squares error, SUM_OF_SQUARES.
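The getError/setError pair follows a pluggable-strategy pattern with the sum of squares as the default. The sketch below illustrates that pattern only. ErrorFunction and ErrorConfig are hypothetical names, not the IMSL QuasiNewtonTrainer.Error interface, whose methods are not listed on this page, and rejecting null in setError is a design choice of the sketch rather than documented IMSL behavior.

```java
import java.util.Objects;

// Hypothetical stand-in for an error-function interface (not IMSL's Error).
interface ErrorFunction {
    // Error contribution of one training pattern.
    double error(double[] y, double[] yHat);
}

// Hypothetical holder showing a pluggable error function with a default.
final class ErrorConfig {
    // Default: sum over outputs of (y - yHat)^2 / 2.
    static final ErrorFunction SUM_OF_SQUARES = (y, yHat) -> {
        double sum = 0.0;
        for (int i = 0; i < y.length; i++) {
            double r = y[i] - yHat[i];
            sum += 0.5 * r * r;
        }
        return sum;
    };

    private ErrorFunction error = SUM_OF_SQUARES;

    ErrorFunction getError() {
        return error;
    }

    // Design choice of this sketch: a null error function is rejected.
    void setError(ErrorFunction error) {
        this.error = Objects.requireNonNull(error, "error");
    }

    public static void main(String[] args) {
        ErrorConfig config = new ErrorConfig();
        System.out.println(config.getError().error(new double[]{1.0}, new double[]{4.0})); // prints 4.5
        // Swap in a user-supplied function: squared error without the factor 1/2.
        config.setError((y, yHat) -> (y[0] - yHat[0]) * (y[0] - yHat[0]));
        System.out.println(config.getError().error(new double[]{1.0}, new double[]{4.0})); // prints 9.0
    }
}
```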
-