Class CrossValidation
- All Implemented Interfaces:
Serializable,Cloneable
setStratifiedCrossValidation(boolean).
The cross-validated estimate of the risk function is given by
$$R^{CV}(d_k)=\frac{1}{N}\sum^{V}_{v=1} \sum_{(x_n,y_n)\in{\eta_v}}{L(y_n,d^{v}_k(x_n))}$$
Here \(L(y,d(x))\) is the loss incurred when the prediction is \(d(x)\) and the actual response is \(y\), \(N\) is the total number of observations, and \(V\) is the number of folds. The inner summation is over the examples in the test sample held out for each fold.
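As an illustration of the risk estimate above, the following minimal sketch averages a loss over every held-out example in every fold. The class name `CvRiskSketch`, the method `cvRisk`, and the array layout are hypothetical and not part of the library; the predictions are assumed to have been produced already by models fitted without the corresponding fold.

```java
import java.util.function.DoubleBinaryOperator;

/** Illustrative only: a V-fold cross-validated risk estimate (hypothetical helper). */
class CvRiskSketch {

    /**
     * @param actual    actual[v][i] is the response of the i-th held-out example in fold v
     * @param predicted predicted[v][i] is the prediction for that example from the
     *                  model fitted on the data with fold v removed
     * @param loss      the loss function L(y, d(x))
     * @return the sum of all losses divided by the total number of held-out examples
     */
    static double cvRisk(double[][] actual, double[][] predicted, DoubleBinaryOperator loss) {
        double total = 0.0;
        int n = 0; // N: every observation is held out exactly once
        for (int v = 0; v < actual.length; v++) {
            for (int i = 0; i < actual[v].length; i++) {
                total += loss.applyAsDouble(actual[v][i], predicted[v][i]);
                n++;
            }
        }
        return total / n;
    }

    public static void main(String[] args) {
        double[][] y = {{1.0, 2.0}, {3.0, 4.0}};
        double[][] d = {{1.0, 3.0}, {3.0, 2.0}};
        DoubleBinaryOperator squaredError = (a, p) -> (a - p) * (a - p);
        System.out.println(cvRisk(y, d, squaredError)); // losses 0, 1, 0, 4 -> prints 1.25
    }
}
```

Passing the loss as a function keeps the same routine usable for squared error or a 0/1 misclassification loss.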
If the predictive model is an instance of a decision tree, cross-validation is performed on each optimal sub-tree determined by cost-complexity pruning. Let \(\eta\) denote the full training data set and \(\eta_v\) the \(v^{th}\) sub-sample. Then \(d^{v}_k\) denotes the set of predictions corresponding to the \(k^{th}\) optimal sub-tree fitted on the training sample \({\eta-\eta_v}\).
Two criteria are available for selecting one sub-tree from among the configurations. The first is the minimum cross-validated risk:
$$ k^* = \text{argmin}_k R^{CV}(d_k) $$
The second is the least complicated (smallest) sub-tree that satisfies:
$$ k^{**} = \text{max}(k): R^{CV}(d_k) \le R^{CV}(d_{k^*}) + SE(R^{CV}(d_{k})) $$
The standard error is approximated by
$$ SE(R^{CV}(d_{k})) \approx \sqrt{\frac{s^2}{N}} $$
with
$$ s^2 = \frac{1}{N} \sum^{V}_{v=1}\sum_{(x_n,y_n)\in{\eta_v}}\left(L(y_n,d_k^{v}(x_n))- R^{CV}(d_{k})\right)^2 $$
The summation is over each fold and each held-out example within each fold.
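The two selection criteria and the standard-error approximation can be sketched as follows. This is not the library's implementation: `SubtreeSelectionSketch` and its methods are hypothetical names, a larger index k is assumed to denote a smaller sub-tree, and the standard error is assumed to divide by the total number of held-out examples.

```java
/** Illustrative only: choosing among pruned sub-trees by cross-validated risk. */
class SubtreeSelectionSketch {

    /** Standard error of the risk from the per-example losses of one sub-tree. */
    static double standardError(double[] losses) {
        int n = losses.length;
        double risk = 0.0;
        for (double l : losses) risk += l;
        risk /= n; // the cross-validated risk is the mean loss
        double s2 = 0.0;
        for (double l : losses) s2 += (l - risk) * (l - risk);
        s2 /= n;
        return Math.sqrt(s2 / n);
    }

    /** Index of the smallest cross-validated risk (the first one on ties). */
    static int argMinRisk(double[] risk) {
        int best = 0;
        for (int k = 1; k < risk.length; k++) {
            if (risk[k] < risk[best]) best = k;
        }
        return best;
    }

    /** Largest index k with risk[k] <= risk[kStar] + se[k], as in the criterion above. */
    static int oneStandardErrorChoice(double[] risk, double[] se) {
        int kStar = argMinRisk(risk);
        int chosen = kStar;
        for (int k = 0; k < risk.length; k++) {
            if (risk[k] <= risk[kStar] + se[k]) chosen = Math.max(chosen, k);
        }
        return chosen;
    }

    public static void main(String[] args) {
        double[] risk = {0.30, 0.22, 0.20, 0.23, 0.35};
        double[] se = {0.04, 0.04, 0.04, 0.04, 0.04};
        System.out.println(argMinRisk(risk));                 // prints 2
        System.out.println(oneStandardErrorChoice(risk, se)); // prints 3
    }
}
```

In the example, sub-tree 3 has a slightly higher risk than the minimum at index 2 but lies within one standard error of it, so the second criterion prefers it.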
Constructor Summary
- CrossValidation(PredictiveModel pm): Creates a CrossValidation object.
Method Summary
- void crossValidate(): Performs V-Fold cross-validation.
- double getCrossValidatedError(): Returns the cross-validated error.
- int getNumberOfSampleFolds(): Returns the number of folds set for the cross-validation.
- int getNumberOfThreads(): Returns the maximum number of java.lang.Thread instances that may be used for parallel processing.
- Random getRandomObject(): Returns the random object being used in the permutation of the observations.
- double[] getRiskStandardErrors(): Returns the estimated standard errors for the risk values.
- double[] getRiskValues(): Returns the vector of risk values.
- boolean isStratifiedCrossValidation(): Returns the flag to perform stratified cross-validation for a categorical response variable.
- void setNumberOfSampleFolds(int nFolds): Sets the number of folds to use in cross-validation.
- void setNumberOfThreads(int numberOfThreads): Sets the maximum number of java.lang.Thread instances that may be used for parallel processing.
- void setRandomObject(Random r): Sets the random object to be used in the permutation of observation data.
- void setStratifiedCrossValidation(boolean stratify): Sets the flag to perform stratified cross-validation.
-
Constructor Details
-
CrossValidation
Creates a CrossValidation object.
- Parameters:
pm - an object of a class that extends PredictiveModel
- Throws:
PredictiveModel.PredictiveModelException - is thrown when an exception has occurred in com.imsl.datamining.PredictiveModel. Superclass exceptions such as com.imsl.datamining.PredictiveModel.StateChangeException and com.imsl.datamining.PredictiveModel.SumOfProbabilitiesNotOneException should also be considered.
-
-
Method Details
-
crossValidate
public void crossValidate() throws PredictiveModel.PredictiveModelException, NoSuchMethodException, InstantiationException, IllegalAccessException, InvocationTargetException
Performs V-Fold cross-validation.
- Throws:
PredictiveModel.PredictiveModelException - is thrown when an exception occurs in the common PredictiveModel programming interface methods, or when an exception of a class that extends PredictiveModelException is raised.
NoSuchMethodException - is thrown when the PredictiveModel subclass is missing a constructor with the expected signature (see com.imsl.datamining.PredictiveModel.PredictiveModel).
InstantiationException - is thrown when an object fails to instantiate. This may occur if the PredictiveModel subclass is not concrete.
IllegalAccessException - is thrown when the currently executing method does not have access to the definition of the specified class, field, method or constructor.
InvocationTargetException - is thrown when a wrapped exception is thrown by an invoked method or constructor.
-
getCrossValidatedError
Returns the cross-validated error. If the response variable is categorical, the error is the misclassification rate, weighted by the misclassification costs and prior probabilities, which are attributes of the PredictiveModel object. If the response variable is quantitative/continuous, the error is the mean squared prediction error, also weighted if weights are set in the PredictiveModel object. If there are multiple model configurations, the minimum value is returned.
- Returns:
a double, the cross-validated prediction error
- Throws:
PredictiveModel.StateChangeException - is thrown when an input parameter in the PredictiveModel has changed that might affect the model estimates or predictions.
-
getNumberOfThreads
public int getNumberOfThreads()
Returns the maximum number of java.lang.Thread instances that may be used for parallel processing.
- Returns:
an int containing the maximum number of java.lang.Thread instances that may be used for parallel processing.
The actual number of threads used in parallel processing will be the lesser of numberOfThreads and nFolds, the number of folds set for cross-validation. This assessment is made to optimize use of resources.
-
setNumberOfThreads
public void setNumberOfThreads(int numberOfThreads)
Sets the maximum number of java.lang.Thread instances that may be used for parallel processing.
- Parameters:
numberOfThreads - an int specifying the maximum number of java.lang.Thread instances that may be used for parallel processing.
The actual number of threads used in parallel processing will be the lesser of numberOfThreads and nFolds, the number of folds set for cross-validation. This assessment is made to optimize use of resources.
Default: numberOfThreads = 1.
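The "lesser of numberOfThreads and nFolds" rule can be pictured with the sketch below, which sizes a thread pool that way and runs one task per fold. It only illustrates the idea under stated assumptions: `ParallelFoldsSketch`, `effectiveThreads`, and `evaluateFolds` are hypothetical names, and how the library actually schedules folds is not documented here.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.IntToDoubleFunction;

/** Illustrative only: hypothetical helper, not part of the library. */
class ParallelFoldsSketch {

    /** A pool never needs more workers than there are folds to process. */
    static int effectiveThreads(int numberOfThreads, int nFolds) {
        return Math.min(numberOfThreads, nFolds);
    }

    /** Evaluates foldLoss once per fold; both counts are assumed to be at least 1. */
    static double[] evaluateFolds(int nFolds, int numberOfThreads, IntToDoubleFunction foldLoss) {
        ExecutorService pool =
                Executors.newFixedThreadPool(effectiveThreads(numberOfThreads, nFolds));
        try {
            List<Future<Double>> pending = new ArrayList<>();
            for (int v = 0; v < nFolds; v++) {
                final int fold = v;
                Callable<Double> task = () -> foldLoss.applyAsDouble(fold);
                pending.add(pool.submit(task));
            }
            double[] losses = new double[nFolds];
            for (int v = 0; v < nFolds; v++) {
                losses[v] = pending.get(v).get(); // waits for fold v to finish
            }
            return losses;
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException("fold evaluation failed", e);
        } finally {
            pool.shutdown();
        }
    }
}
```

With 10 folds, requesting 16 threads yields a pool of 10, while requesting 8 yields a pool of 8.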
-
getNumberOfSampleFolds
public int getNumberOfSampleFolds()
Returns the number of folds set for the cross-validation.
- Returns:
an int, the number of folds
-
setNumberOfSampleFolds
public void setNumberOfSampleFolds(int nFolds)
Sets the number of folds to use in cross-validation.
- Parameters:
nFolds - an int specifying the number of folds. nFolds must be between 1 and the number of observations (xy.length), inclusive. If nFolds = 1, the full data set is used once to generate the PredictiveModel; in other words, no cross-validation is performed. If 1 < xy.length/nFolds \(\le\) 3, leave-one-out cross-validation is performed.
Default: nFolds = 10.
-
isStratifiedCrossValidation
public boolean isStratifiedCrossValidation()
Returns the flag to perform stratified cross-validation for a categorical response variable.
When true, the method crossValidate creates the samples to have roughly the same proportion of each class level as occurs in the full training set. When false, regular V-fold cross-validation is performed. If the response variable is continuous, the flag has no effect.
- Returns:
a boolean, the state of the flag
-
setStratifiedCrossValidation
public void setStratifiedCrossValidation(boolean stratify)
Sets the flag to perform stratified cross-validation.
When true, the method crossValidate creates the samples to have roughly the same proportion of each class level as occurs in the full training set. When false, regular V-fold cross-validation is performed. If the response variable is continuous, the flag has no effect.
- Parameters:
stratify - a boolean indicating whether or not stratified cross-validation should be performed for categorical response variables
Default: stratify = false
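To make "roughly the same proportion of each class level" concrete, here is one common way to build stratified folds: shuffle the observations within each class and deal them to the folds in rotation. This is a sketch of the general technique, not necessarily the scheme used by crossValidate; `StratifiedFoldsSketch` and `assignFolds` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;

/** Illustrative only: stratified assignment of observations to folds. */
class StratifiedFoldsSketch {

    /**
     * @param classLabels class level of each observation
     * @param nFolds      number of folds
     * @param random      source of randomness for the within-class shuffle
     * @return fold[i], the fold index in 0..nFolds-1 assigned to observation i
     */
    static int[] assignFolds(int[] classLabels, int nFolds, Random random) {
        // Group observation indices by class level.
        Map<Integer, List<Integer>> byClass = new TreeMap<>();
        for (int i = 0; i < classLabels.length; i++) {
            byClass.computeIfAbsent(classLabels[i], c -> new ArrayList<>()).add(i);
        }
        int[] fold = new int[classLabels.length];
        int next = 0; // the rotation continues across classes to keep fold sizes balanced
        for (List<Integer> members : byClass.values()) {
            Collections.shuffle(members, random);
            for (int index : members) {
                fold[index] = next % nFolds;
                next++;
            }
        }
        return fold;
    }
}
```

Because each class is dealt out in rotation, the number of observations of any one class differs by at most one between folds.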
-
setRandomObject
Sets the random object to be used in the permutation of observation data.
- Parameters:
r - a Random object to be used in random permutation of observation data.
Specifying a seed for the Random object can produce repeatable/deterministic output.
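The effect of seeding can be seen in the sketch below, which permutes the observation indices with a supplied java.util.Random and deals them into folds. The Fisher-Yates shuffle and the names `SeededFoldsSketch`, `permutation`, and `assignFolds` are assumptions made for illustration; the library's own permutation scheme is not documented here.

```java
import java.util.Random;

/** Illustrative only: a seeded permutation gives repeatable fold assignments. */
class SeededFoldsSketch {

    /** Fisher-Yates shuffle of the indices 0..n-1 driven by the supplied Random. */
    static int[] permutation(int n, Random random) {
        int[] p = new int[n];
        for (int i = 0; i < n; i++) p[i] = i;
        for (int i = n - 1; i > 0; i--) {
            int j = random.nextInt(i + 1);
            int tmp = p[i];
            p[i] = p[j];
            p[j] = tmp;
        }
        return p;
    }

    /** Fold index of each observation: its position in the permutation modulo nFolds. */
    static int[] assignFolds(int n, int nFolds, Random random) {
        int[] order = permutation(n, random);
        int[] fold = new int[n];
        for (int position = 0; position < n; position++) {
            fold[order[position]] = position % nFolds;
        }
        return fold;
    }
}
```

Two Random objects created with the same seed, for example `new Random(123L)`, produce the same permutation and therefore the same folds, which is what makes repeated runs comparable.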
-
getRiskValues
public double[] getRiskValues()
Returns the vector of risk values.
In most cases the length is 1. For DecisionTree, CrossValidation returns an array of length >= 1.
- Returns:
a double array containing the estimated risk values.
-
getRandomObject
Returns the random object being used in the permutation of the observations.
- Returns:
a Random object being used for permutations
-
getRiskStandardErrors
public double[] getRiskStandardErrors()
Returns the estimated standard errors for the risk values.
In most cases the length is 1. For DecisionTree, CrossValidation returns an array of length >= 1.
- Returns:
a double array containing the estimated standard errors for the risk values.
-