Class CrossValidation

java.lang.Object
com.imsl.datamining.CrossValidation
All Implemented Interfaces:
Serializable, Cloneable

public class CrossValidation extends Object implements Serializable, Cloneable
Performs V-fold cross-validation for predictive models. In V-fold cross-validation, the data set is partitioned randomly into V approximately equally sized sub-samples. The model is then trained V times, with each sub-sample removed in turn to serve as the test set. A loss or risk function is updated with the prediction errors on each fold. The total loss is averaged over all held-out observations across the folds and serves as a measure of the model's predictive performance. For categorical response variables (classification problems), the data can be stratified so that each of the V sub-samples contains roughly the same proportion of each class level as the full training sample; see setStratifiedCrossValidation(boolean). The cross-validated estimate of the risk function is given by

$$R^{CV}(d_k)=\frac{1}{N}\sum^{V}_{v=1} \sum_{(x_n,y_n)\in{\eta_v}}{L(y_n,d^{v}_k(x_n))}$$

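Here \(L(y,d(x))\) is the loss incurred when the prediction is \(d(x)\) and the actual response is y, N is the total number of observations, and \(\eta_v\) is the sub-sample held out in fold v. The inner summation is over the examples in the test sample held out for each fold.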
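The risk estimate above can be illustrated with a small, self-contained sketch. It is not the IMSL implementation: the "model" is a hypothetical mean predictor (it predicts the average of the retained training responses), the loss is assumed to be squared error, and the fold assignment `foldOf` is supplied by the caller. It shows the structure of the formula: train on \(\eta-\eta_v\), sum the losses over \(\eta_v\), and divide the grand total by N.

```java
/**
 * Sketch of the V-fold cross-validated risk
 *   R^CV = (1/N) * sum over folds v of sum over held-out (x_n, y_n) in eta_v of L(y_n, d^v(x_n)).
 * The "model" is a hypothetical mean predictor, not an IMSL PredictiveModel.
 */
final class VFoldRiskSketch {

    /** Squared-error loss L(y, d(x)); an assumed choice of loss function. */
    static double loss(double y, double prediction) {
        double diff = y - prediction;
        return diff * diff;
    }

    /**
     * @param y      responses of the full training set (eta)
     * @param foldOf foldOf[n] in [0, v) names the held-out sub-sample eta_v containing observation n
     * @param v      number of folds V
     */
    static double cvRisk(double[] y, int[] foldOf, int v) {
        int n = y.length;
        double totalLoss = 0.0;
        for (int fold = 0; fold < v; fold++) {
            // "Train" on eta - eta_v: the mean of the responses that are not held out.
            double sum = 0.0;
            int count = 0;
            for (int i = 0; i < n; i++) {
                if (foldOf[i] != fold) {
                    sum += y[i];
                    count++;
                }
            }
            double prediction = count > 0 ? sum / count : 0.0;
            // Accumulate the loss on every observation held out in eta_v.
            for (int i = 0; i < n; i++) {
                if (foldOf[i] == fold) {
                    totalLoss += loss(y[i], prediction);
                }
            }
        }
        // Divide by N, the total number of observations, not by the number of folds.
        return totalLoss / n;
    }
}
```

With y = {1, 2, 3, 4} and folds {0, 1, 0, 1}, each fold contributes a summed loss of 4, so the estimate is 8/4 = 2.0.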

If the predictive model is an instance of a decision tree, cross-validation is performed on each optimal sub-tree determined by cost-complexity pruning. Let \(\eta\) denote the full training data set and \(\eta_v\) the \(v^{th}\) sub-sample, and let \(d^{v}_k\) denote the set of predictions of the \(k^{th}\) optimal sub-tree fitted on the training sample \(\eta-\eta_v\). Two criteria are available for selecting one sub-tree from among the configurations. The first is the sub-tree with minimum risk,

$$ k^* = \text{argmin}_k R^{CV}(d_k) $$

The second is the least complicated (smallest) sub-tree that satisfies

$$ k^{**} = \max(k): R^{CV}(d_k) \le R^{CV}(d_{k^*}) + SE(R^{CV}(d_{k})) $$

The standard error is approximated by

$$ SE(R^{CV}(d_{k})) \approx \sqrt{\frac{s^2}{N}} $$

with

$$ s^2 = \frac{1}{N} \sum^{V}_{v=1}\sum_{(x_n,y_n)\in{\eta_v}}\left(L(y_n,d_k^{v}(x_n)) - R^{CV}(d_{k})\right)^2 $$

The summation is over each fold and over each observation in the sub-sample held out in that fold.
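The standard-error approximation can be computed directly from the held-out losses of one configuration k. This sketch assumes the per-observation losses \(L(y_n,d_k^{v}(x_n))\) have already been pooled across all folds into one array, and that the sample size in the SE expression is the same N (total number of observations) used in \(s^2\). The class and method names are hypothetical.

```java
/**
 * Sketch of SE(R^CV(d_k)) ~ sqrt(s^2 / N), where
 *   s^2 = (1/N) * sum over all held-out observations of (L_n - R^CV(d_k))^2.
 * losses[n] is the held-out loss of observation n for a single configuration k,
 * pooled over every fold (each observation is held out exactly once).
 */
final class CvStandardErrorSketch {

    /** R^CV(d_k): the mean of the pooled held-out losses. */
    static double risk(double[] losses) {
        double sum = 0.0;
        for (double l : losses) {
            sum += l;
        }
        return sum / losses.length;
    }

    /** Approximate standard error of R^CV(d_k). */
    static double standardError(double[] losses) {
        int n = losses.length;
        double r = risk(losses);
        double s2 = 0.0;
        for (double l : losses) {
            double d = l - r;
            s2 += d * d;
        }
        s2 /= n;                  // s^2 uses the 1/N normalization
        return Math.sqrt(s2 / n); // SE ~ sqrt(s^2 / N)
    }
}
```

For 0/1 misclassification losses {0, 1, 0, 1}, the risk is 0.5, \(s^2 = 0.25\), and the standard error is \(\sqrt{0.25/4} = 0.25\).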

  • Constructor Details

    • CrossValidation

      public CrossValidation(PredictiveModel pm) throws PredictiveModel.PredictiveModelException
      Creates a CrossValidation object.
      Parameters:
      pm - an object of a class that extends PredictiveModel
      Throws:
      PredictiveModel.PredictiveModelException - is thrown when an exception occurs in com.imsl.datamining.PredictiveModel. Exceptions that extend PredictiveModelException, such as com.imsl.datamining.PredictiveModel.StateChangeException and com.imsl.datamining.PredictiveModel.SumOfProbabilitiesNotOneException, should also be considered.
  • Method Details

    • crossValidate

      Performs V-fold cross-validation.
      Throws:
      PredictiveModel.PredictiveModelException - is thrown when an exception occurs in the common PredictiveModel programming interface methods or an exception class that has extended the PredictiveModelException class.
      NoSuchMethodException - is thrown when the PredictiveModel subclass is missing a constructor with the expected signature (see com.imsl.datamining.PredictiveModel.PredictiveModel).
      InstantiationException - is thrown when an object fails to instantiate. This may occur if the PredictiveModel subclass is not concrete.
      IllegalAccessException - is thrown when the currently executing method does not have access to the definition of the specified class, field, method or constructor.
      InvocationTargetException - is thrown when an invoked method or constructor throws an exception; the original exception is wrapped inside the InvocationTargetException.
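The reflection-related exceptions listed for crossValidate are the standard checked exceptions of java.lang.reflect.Constructor. The sketch below shows how each one can arise when a model object is created reflectively from its class. The constructor signature (a single double[][] argument) and the model classes are hypothetical; the expected signature for a PredictiveModel subclass is the one described in com.imsl.datamining.PredictiveModel.PredictiveModel, not this one.

```java
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;

/**
 * Sketch of reflective instantiation and the checked exceptions it can raise.
 * The model classes and the double[][] constructor are hypothetical.
 */
final class ReflectiveCopySketch {

    /** Hypothetical model with the expected constructor. */
    static class HypotheticalModel {
        final double[][] xy;
        public HypotheticalModel(double[][] xy) { this.xy = xy; }
    }

    /** Hypothetical model that lacks the expected constructor. */
    static class NoCopyConstructorModel {
        public NoCopyConstructorModel() { }
    }

    /** Hypothetical model whose constructor always fails. */
    static class FailingModel {
        public FailingModel(double[][] xy) {
            throw new IllegalArgumentException("rejected training data");
        }
    }

    static <T> T newModel(Class<T> cls, double[][] xy)
            throws NoSuchMethodException, InstantiationException,
                   IllegalAccessException, InvocationTargetException {
        // NoSuchMethodException: no public constructor taking a double[][].
        Constructor<T> ctor = cls.getConstructor(double[][].class);
        // InstantiationException: cls is abstract.
        // IllegalAccessException: the constructor is not accessible to the caller.
        // InvocationTargetException: the constructor itself threw; getCause() returns it.
        // The cast to Object keeps xy as a single argument instead of spreading it as varargs.
        return ctor.newInstance((Object) xy);
    }
}
```

Because all four exceptions extend ReflectiveOperationException, a caller that only needs to report failure can catch that common superclass.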
    • getCrossValidatedError

      public double getCrossValidatedError() throws PredictiveModel.StateChangeException
      Returns the cross-validated error.

      If the response variable is categorical, the error is the misclassification rate, weighted by the misclassification costs and prior probabilities, which are attributes of the PredictiveModel object. If the response variable is quantitative (continuous), the error is the mean squared prediction error, also weighted if weights are set in the PredictiveModel object. If there are multiple model configurations (for example, the optimal sub-trees of a decision tree), the minimum value is returned.

      Returns:
      a double, the cross-validated prediction error
      Throws:
      PredictiveModel.StateChangeException - is thrown when an input parameter in the PredictiveModel has changed that might affect the model estimates or predictions.
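For a continuous response, the error described above is a (possibly weighted) mean squared prediction error, and with several configurations the smallest value is reported. The sketch below assumes observation weights enter as a weighted average, \(\sum_i w_i (y_i-\hat y_i)^2 / \sum_i w_i\); that weighting scheme and the class name are assumptions, not a description of the library's internal computation.

```java
/**
 * Sketch of a weighted mean squared prediction error and of the
 * "minimum over configurations" rule. The weighting scheme is assumed.
 */
final class CvErrorSketch {

    /** sum(w_i * (y_i - yHat_i)^2) / sum(w_i); equal weights give the plain MSE. */
    static double weightedMse(double[] y, double[] yHat, double[] w) {
        double num = 0.0;
        double den = 0.0;
        for (int i = 0; i < y.length; i++) {
            double d = y[i] - yHat[i];
            num += w[i] * d * d;
            den += w[i];
        }
        return num / den;
    }

    /** Smallest risk among the model configurations (for example, pruned sub-trees). */
    static double minRisk(double[] riskValues) {
        double best = Double.POSITIVE_INFINITY;
        for (double r : riskValues) {
            best = Math.min(best, r);
        }
        return best;
    }
}
```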
    • getNumberOfThreads

      public int getNumberOfThreads()
      Returns the maximum number of java.lang.Thread instances that may be used for parallel processing.
      Returns:
      an int containing the maximum number of java.lang.Thread instances that may be used for parallel processing.

      The actual number of threads used in parallel processing is the lesser of numberOfThreads and nFolds, the number of folds set for cross-validation, so that no more threads are created than there are folds to process.

    • setNumberOfThreads

      public void setNumberOfThreads(int numberOfThreads)
      Sets the maximum number of java.lang.Thread instances that may be used for parallel processing.
      Parameters:
      numberOfThreads - an int specifying the maximum number of java.lang.Thread instances that may be used for parallel processing.

      The actual number of threads used in parallel processing is the lesser of numberOfThreads and nFolds, the number of folds set for cross-validation, so that no more threads are created than there are folds to process.

      Default: numberOfThreads = 1.
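The thread limit can be illustrated with java.util.concurrent: a fixed pool sized to min(numberOfThreads, nFolds) with one task per fold. This is a sketch of the general pattern, not the library's internals; the per-fold task is a hypothetical placeholder that simply returns its fold index where a real task would train on \(\eta-\eta_v\) and score \(\eta_v\). Both arguments are assumed to be at least 1.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/** Sketch of running one task per fold on at most min(numberOfThreads, nFolds) threads. */
final class FoldThreadsSketch {

    /** The documented rule: the lesser of numberOfThreads and nFolds. */
    static int effectiveThreads(int numberOfThreads, int nFolds) {
        return Math.min(numberOfThreads, nFolds);
    }

    /** Runs a placeholder task per fold and collects the results in fold order. */
    static double[] runFolds(int numberOfThreads, int nFolds)
            throws InterruptedException, ExecutionException {
        ExecutorService pool =
                Executors.newFixedThreadPool(effectiveThreads(numberOfThreads, nFolds));
        try {
            List<Future<Double>> results = new ArrayList<>();
            for (int v = 0; v < nFolds; v++) {
                final int fold = v;
                // Hypothetical placeholder for "train on eta - eta_v, return the loss on eta_v".
                Callable<Double> task = () -> (double) fold;
                results.add(pool.submit(task));
            }
            double[] perFold = new double[nFolds];
            for (int v = 0; v < nFolds; v++) {
                perFold[v] = results.get(v).get(); // blocks until fold v finishes
            }
            return perFold;
        } finally {
            pool.shutdown(); // release the worker threads
        }
    }
}
```

Sizing the pool to the number of folds avoids idle threads, since each fold is an independent unit of work.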

    • getNumberOfSampleFolds

      public int getNumberOfSampleFolds()
      Returns the number of folds set for the cross-validation.
      Returns:
      an int, the number of folds
    • setNumberOfSampleFolds

      public void setNumberOfSampleFolds(int nFolds)
      Sets the number of folds to use in cross validation.
      Parameters:
      nFolds - an int specifying the number of folds

      nFolds must be between 1 and the number of observations (xy.length), inclusive. If nFolds = 1, the full data set is used once to generate the PredictiveModel; in other words, no cross-validation is performed. If 1 < xy.length/nFolds \(\le\) 3, leave-one-out cross-validation is performed.

      Default: nFolds = 10.
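The rules for nFolds can be restated as a small function that returns the number of folds that would actually be used. This is a sketch of the documented rules only: the use of real-valued division for xy.length/nFolds and the choice of IllegalArgumentException for out-of-range input are assumptions, and FoldCountSketch is a hypothetical name.

```java
/** Sketch of the documented nFolds rules; returns the number of folds actually used. */
final class FoldCountSketch {

    static int effectiveFolds(int nFolds, int nObs) {
        if (nFolds < 1 || nFolds > nObs) {
            // Assumed error handling for an out-of-range value.
            throw new IllegalArgumentException("nFolds must be in [1, " + nObs + "]");
        }
        if (nFolds == 1) {
            return 1; // full data set used once; no cross-validation
        }
        double perFold = (double) nObs / nFolds; // assumed real-valued division
        if (perFold > 1.0 && perFold <= 3.0) {
            return nObs; // leave-one-out: one fold per observation
        }
        return nFolds;
    }
}
```

For example, with 100 observations, nFolds = 40 gives 2.5 observations per fold and falls back to leave-one-out (100 folds), while nFolds = 30 gives about 3.33 and is kept as 30.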

    • isStratifiedCrossValidation

      public boolean isStratifiedCrossValidation()
      Returns the flag to perform stratified cross-validation for a categorical response variable.

      When true, the method crossValidate creates the samples to have roughly the same proportion of each class level as occurs in the full training set. When false, regular V-fold cross-validation is performed. If the response variable is continuous, the flag has no effect.

      Returns:
      a boolean, the state of the flag
    • setStratifiedCrossValidation

      public void setStratifiedCrossValidation(boolean stratify)
      Sets the flag to perform stratified cross-validation.

      When true, the method crossValidate creates the samples to have roughly the same proportion of each class level as occurs in the full training set. When false, regular V-fold cross-validation is performed. If the response variable is continuous, the flag has no effect.

      Parameters:
      stratify - a boolean indicating whether or not stratified cross-validation should be performed for categorical response variables

      Default: stratify=false
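One common way to build stratified folds is to shuffle the observations within each class and then deal them to the folds round-robin, continuing the rotation from one class to the next so fold sizes stay balanced. The sketch below uses that technique as an illustration; it is an assumption, not a description of how crossValidate builds its samples, and StratifiedFoldsSketch is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;

/** Sketch of stratified fold assignment for a categorical response. */
final class StratifiedFoldsSketch {

    /** Returns foldOf[i] in [0, nFolds) with each class spread as evenly as possible. */
    static int[] assign(int[] classLabel, int nFolds, Random rng) {
        // Group observation indices by class level (TreeMap gives a deterministic class order).
        Map<Integer, List<Integer>> byClass = new TreeMap<>();
        for (int i = 0; i < classLabel.length; i++) {
            byClass.computeIfAbsent(classLabel[i], k -> new ArrayList<>()).add(i);
        }
        int[] foldOf = new int[classLabel.length];
        int next = 0; // the rotation continues across classes to balance fold sizes
        for (List<Integer> members : byClass.values()) {
            Collections.shuffle(members, rng); // random order within the class
            for (int idx : members) {
                foldOf[idx] = next;
                next = (next + 1) % nFolds;
            }
        }
        return foldOf;
    }
}
```

With six observations of class 0 and three of class 1 split into three folds, every fold receives two class-0 and one class-1 observation, matching the 2:1 ratio of the full sample.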

    • setRandomObject

      public void setRandomObject(Random r)
      Sets the random object to be used in the permutation of observation data.
      Parameters:
      r - a Random object to be used in random permutation of observation data.

      Specifying a seed for the Random object can produce repeatable/deterministic output.
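Why a seeded Random gives repeatable folds can be seen in a plain (unstratified) assignment: a Fisher-Yates shuffle of the observation indices driven by the supplied java.util.Random, followed by dealing positions to folds. Two Random objects built with the same seed yield the same permutation and therefore the same folds. The class name and the exact assignment scheme are illustrative assumptions.

```java
import java.util.Random;

/** Sketch of random fold assignment driven by a caller-supplied Random. */
final class RandomFoldsSketch {

    static int[] assign(int n, int nFolds, Random rng) {
        // Fisher-Yates shuffle of the indices 0..n-1.
        int[] perm = new int[n];
        for (int i = 0; i < n; i++) {
            perm[i] = i;
        }
        for (int i = n - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1);
            int tmp = perm[i];
            perm[i] = perm[j];
            perm[j] = tmp;
        }
        // Deal the shuffled positions to folds so fold sizes differ by at most one.
        int[] foldOf = new int[n];
        for (int pos = 0; pos < n; pos++) {
            foldOf[perm[pos]] = pos % nFolds;
        }
        return foldOf;
    }
}
```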

    • getRiskValues

      public double[] getRiskValues()
      Returns the vector of risk values.

      In most cases the array has length 1. For a DecisionTree, the array contains one risk value for each optimal sub-tree, so its length may be greater than 1.

      Returns:
      a double array containing the estimated risk values.
    • getRandomObject

      public Random getRandomObject()
      Returns the random object being used in the permutation of the observations.
      Returns:
      a Random object being used for permutations
    • getRiskStandardErrors

      public double[] getRiskStandardErrors()
      Returns the estimated standard errors for the risk values.

      In most cases the array has length 1. For a DecisionTree, the array contains one standard error for each risk value, so its length may be greater than 1.

      Returns:
      a double array containing the estimated standard errors for the risk values.
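The risk values and their standard errors together support the sub-tree selection criteria given in the class description. The sketch takes two arrays shaped like those returned by getRiskValues() and getRiskStandardErrors() and applies both criteria. It assumes that index k corresponds to the k-th optimal sub-tree and that a larger k means a smaller sub-tree; it follows the selection inequality as written in the class description, using the standard error of each candidate k. SubtreeSelectionSketch is a hypothetical name.

```java
/** Sketch of the k* (minimum risk) and k** (one-standard-error) selection rules. */
final class SubtreeSelectionSketch {

    /** k*: index of the minimum cross-validated risk (first index on ties). */
    static int minRiskIndex(double[] risk) {
        int best = 0;
        for (int k = 1; k < risk.length; k++) {
            if (risk[k] < risk[best]) {
                best = k;
            }
        }
        return best;
    }

    /** k**: largest k with risk[k] <= risk[k*] + se[k]; assumes larger k means a smaller tree. */
    static int oneSeIndex(double[] risk, double[] se) {
        int kStar = minRiskIndex(risk);
        int kStarStar = kStar;
        for (int k = 0; k < risk.length; k++) {
            if (risk[k] <= risk[kStar] + se[k]) {
                kStarStar = Math.max(kStarStar, k);
            }
        }
        return kStarStar;
    }
}
```

With risks {0.30, 0.20, 0.22, 0.35} and standard errors of 0.03, the minimum is at k = 1, and k = 2 is the largest index whose risk stays within one standard error of it.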