public class CrossValidation extends Object implements Serializable, Cloneable
See CrossValidation.setStratifiedCrossValidation(boolean).
The cross-validated estimate of the risk function is given by
$$R^{CV}(d_k)=\frac{1}{N}\sum^{V}_{v=1} \sum_{(x_n,y_n)\in{\eta_v}}{L(y_n,d^{v}_k(x_n))}$$
Here \(L(y,d(x))\) is the loss incurred when the prediction is \(d(x)\) and the actual response is \(y\), \(V\) is the number of folds, and \(N\) is the total number of observations. The inner summation is over the examples in the test sample held out for each fold.
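The risk estimate above can be illustrated with a minimal, self-contained sketch. It is not the library's implementation: the class name `CvRiskSketch`, the array layout, and the choice of squared-error loss are assumptions made for illustration only.

```java
final class CvRiskSketch {
    /** Squared-error loss L(y, d(x)); a stand-in for whatever loss the model actually uses. */
    static double squaredLoss(double actual, double predicted) {
        double diff = actual - predicted;
        return diff * diff;
    }

    /**
     * Pools the held-out losses of every fold and divides by the total count N.
     * actual[v][i] and predicted[v][i] describe the i-th held-out example of fold v
     * (a hypothetical layout chosen for this sketch).
     */
    static double crossValidatedRisk(double[][] actual, double[][] predicted) {
        double totalLoss = 0.0;
        int n = 0;
        for (int v = 0; v < actual.length; v++) {          // outer sum over folds
            for (int i = 0; i < actual[v].length; i++) {   // inner sum over held-out examples
                totalLoss += squaredLoss(actual[v][i], predicted[v][i]);
                n++;
            }
        }
        return totalLoss / n;
    }
}
```

Because the losses are pooled before dividing, folds of unequal size are weighted by the number of examples they hold out rather than equally.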
If the predictive model is an instance of a decision tree, cross-validation is performed on each optimal sub-tree determined by cost-complexity pruning. Let \(\eta\) denote the full training data set and \(\eta_v\) the \(v^{th}\) sub-sample. Let \(d^{v}_k\) denote the set of predictions corresponding to the \(k\)th optimal sub-tree fitted on the training sample \({\eta-\eta_v}\). Two criteria are available for selecting one sub-tree from among the configurations. The first is the minimum cross-validated risk:
$$ k^* = \text{argmin}_k R^{CV}(d_k) $$
The second is the least complicated (smallest) sub-tree that satisfies:
$$ k^{**} = \text{max}(k): R^{CV}(d_k) \le R^{CV}(d_{k^*}) + SE(R^{CV}(d_{k})) $$
The standard error is approximated by
$$ SE(R^{CV}(d_{k})) \approx \sqrt{\frac{s^2}{N}} $$
with
$$ s^2 = \frac{1}{N} \sum^{V}_{v=1}\sum_{(x_n,y_n)\in{\eta_v}}\left(L(y_n,d_k^{v}(x_n)) - R^{CV}(d_{k})\right)^2 $$
The summation is over each fold and the held-out examples within each fold.
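The two selection criteria can be sketched as follows. This is an illustrative reading of the formulas, not the library's code: the class `OneSeRuleSketch` and the `losses[k][n]` layout are hypothetical, the standard error denominator is taken to be the total number of held-out examples, and a larger index k is assumed to mean a smaller sub-tree, as the max(k) in the second criterion suggests.

```java
final class OneSeRuleSketch {
    /**
     * losses[k][n] is the held-out loss of example n under sub-tree k, pooled over all folds.
     * Returns {kStar, kStarStar}.
     */
    static int[] select(double[][] losses) {
        int count = losses.length;
        double[] risk = new double[count];
        double[] se = new double[count];
        for (int k = 0; k < count; k++) {
            int n = losses[k].length;
            double sum = 0.0;
            for (double loss : losses[k]) sum += loss;
            risk[k] = sum / n;                       // R^CV(d_k)
            double squares = 0.0;
            for (double loss : losses[k]) {
                double dev = loss - risk[k];
                squares += dev * dev;
            }
            double s2 = squares / n;                 // s^2
            se[k] = Math.sqrt(s2 / n);               // SE(R^CV(d_k))
        }
        int kStar = 0;                               // first criterion: minimum risk
        for (int k = 1; k < count; k++) {
            if (risk[k] < risk[kStar]) kStar = k;
        }
        int kStarStar = kStar;                       // second criterion: largest admissible k
        for (int k = 0; k < count; k++) {
            if (risk[k] <= risk[kStar] + se[k]) kStarStar = Math.max(kStarStar, k);
        }
        return new int[] {kStar, kStarStar};
    }
}
```

The second criterion trades a small, statistically insignificant increase in estimated risk for a simpler tree.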
Constructor and Description |
---|
CrossValidation(PredictiveModel pm) - Creates a CrossValidation object. |
Modifier and Type | Method and Description |
---|---|
void | crossValidate() - Performs V-fold cross-validation. |
double | getCrossValidatedError() - Returns the cross-validated error. |
int | getNumberOfSampleFolds() - Returns the number of folds set for the cross-validation. |
int | getNumberOfThreads() - Returns the maximum number of java.lang.Thread instances that may be used for parallel processing. |
Random | getRandomObject() - Returns the random object being used in the permutation of the observations. |
double[] | getRiskStandardErrors() - Returns the estimated standard errors for the risk values. |
double[] | getRiskValues() - Returns the vector of risk values. |
boolean | isStratifiedCrossValidation() - Returns the flag to perform stratified cross-validation for a categorical response variable. |
void | setNumberOfSampleFolds(int nFolds) - Sets the number of folds to use in cross-validation. |
void | setNumberOfThreads(int numberOfThreads) - Sets the maximum number of java.lang.Thread instances that may be used for parallel processing. |
void | setRandomObject(Random r) - Sets the random object to be used in the permutation of observation data. |
void | setStratifiedCrossValidation(boolean stratify) - Sets the flag to perform stratified cross-validation. |
public CrossValidation(PredictiveModel pm) throws PredictiveModel.PredictiveModelException
Creates a CrossValidation object.
pm - an object of a class that extends PredictiveModel
PredictiveModel.PredictiveModelException - is thrown when an exception has occurred in the com.imsl.datamining.PredictiveModel. Superclass exceptions such as com.imsl.datamining.PredictiveModel.StateChangeException and com.imsl.datamining.PredictiveModel.SumOfProbabilitiesNotOneException should also be considered.
public void crossValidate() throws PredictiveModel.PredictiveModelException, NoSuchMethodException, InstantiationException, IllegalAccessException, InvocationTargetException
Performs V-fold cross-validation.
PredictiveModel.PredictiveModelException - is thrown when an exception occurs in the common PredictiveModel programming interface methods, or when an exception class that extends PredictiveModelException is thrown.
NoSuchMethodException - is thrown when the PredictiveModel subclass is missing a constructor with the expected signature (see com.imsl.datamining.PredictiveModel.PredictiveModel).
InstantiationException - is thrown when an object fails to instantiate. This may occur if the PredictiveModel subclass is not concrete.
IllegalAccessException - is thrown when the currently executing method does not have access to the definition of the specified class, field, method or constructor.
InvocationTargetException - is thrown when a wrapped exception is thrown by an invoked method or constructor.
public double getCrossValidatedError() throws PredictiveModel.StateChangeException
Returns the cross-validated error. … PredictiveModel object. If the response variable is quantitative/continuous, the error is the mean squared prediction error, also weighted if weights are set in the PredictiveModel object. If there are multiple model configurations, the minimum value is returned.
double, the cross-validated prediction error
PredictiveModel.StateChangeException - is thrown when an input parameter in the PredictiveModel has changed that might affect the model estimates or predictions.
public int getNumberOfThreads()
Returns the maximum number of java.lang.Thread instances that may be used for parallel processing.
int containing the maximum number of java.lang.Thread instances that may be used for parallel processing.
The actual number of threads used in parallel processing will be the lesser of numberOfThreads and nFolds, the number of folds set for cross-validation. This assessment is made to optimize use of resources.
public void setNumberOfThreads(int numberOfThreads)
Sets the maximum number of java.lang.Thread instances that may be used for parallel processing.
numberOfThreads - an int specifying the maximum number of java.lang.Thread instances that may be used for parallel processing.
The actual number of threads used in parallel processing will be the lesser of numberOfThreads and nFolds, the number of folds set for cross-validation. This assessment is made to optimize use of resources.
Default: numberOfThreads = 1.
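The thread-count rule described above can be sketched with the standard java.util.concurrent classes. This is a hedged illustration rather than the library's scheduling code: `FoldThreadingSketch`, `effectiveThreads`, and `runFolds` are hypothetical names, and the per-fold work is passed in as a plain function.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.IntToDoubleFunction;

final class FoldThreadingSketch {
    /** The lesser of the requested thread count and the number of folds. */
    static int effectiveThreads(int numberOfThreads, int nFolds) {
        return Math.min(numberOfThreads, nFolds);
    }

    /** Runs one task per fold on a pool that never holds more threads than folds. */
    static double[] runFolds(int nFolds, int numberOfThreads, IntToDoubleFunction foldTask) {
        ExecutorService pool =
                Executors.newFixedThreadPool(effectiveThreads(numberOfThreads, nFolds));
        try {
            List<Future<Double>> pending = new ArrayList<>();
            for (int v = 0; v < nFolds; v++) {
                final int fold = v;
                Callable<Double> task = () -> foldTask.applyAsDouble(fold);
                pending.add(pool.submit(task));
            }
            double[] results = new double[nFolds];
            for (int v = 0; v < nFolds; v++) {
                results[v] = pending.get(v).get();   // blocks until fold v finishes
            }
            return results;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting for folds", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("a fold task failed", e);
        } finally {
            pool.shutdown();
        }
    }
}
```

Capping the pool at the number of folds avoids creating threads that would have no fold to work on.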
public int getNumberOfSampleFolds()
Returns the number of folds set for the cross-validation.
int, the number of folds
public void setNumberOfSampleFolds(int nFolds)
Sets the number of folds to use in cross-validation.
nFolds - an int specifying the number of folds
nFolds must be between 1 and the number of observations (xy.length), inclusive. If nFolds = 1, the full data set is used once to generate the PredictiveModel. In other words, no cross-validation is performed. If 1 < xy.length/nFolds \(\le\) 3, leave-one-out cross-validation is performed.
Default: nFolds = 10.
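The constraints on nFolds can be read as a small decision rule, sketched below. The class `FoldModeSketch` and its enum are hypothetical, and the ratio xy.length/nFolds is treated as a real-valued quotient, which is an assumption: the text does not say whether integer division is intended.

```java
final class FoldModeSketch {
    enum Mode { NO_CROSS_VALIDATION, LEAVE_ONE_OUT, V_FOLD }

    /** Maps the number of observations and the requested folds to the behavior described above. */
    static Mode resolve(int nObservations, int nFolds) {
        if (nFolds < 1 || nFolds > nObservations) {
            throw new IllegalArgumentException(
                    "nFolds must be between 1 and " + nObservations + ", inclusive");
        }
        if (nFolds == 1) {
            return Mode.NO_CROSS_VALIDATION;   // full data set used once
        }
        double ratio = (double) nObservations / nFolds;   // assumption: real-valued ratio
        if (ratio > 1.0 && ratio <= 3.0) {
            return Mode.LEAVE_ONE_OUT;         // folds would hold too few observations
        }
        return Mode.V_FOLD;
    }
}
```

For example, 25 observations with the default of 10 folds gives a ratio of 2.5, so this sketch reports leave-one-out rather than 10-fold cross-validation.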
public boolean isStratifiedCrossValidation()
Returns the flag to perform stratified cross-validation for a categorical response variable.
When true, the method crossValidate creates the samples to have roughly the same proportion of each class level as occurs in the full training set. When false, regular V-fold cross-validation is performed. If the response variable is continuous, the flag has no effect.
boolean, the state of the flag
public void setStratifiedCrossValidation(boolean stratify)
Sets the flag to perform stratified cross-validation.
When true, the method crossValidate creates the samples to have roughly the same proportion of each class level as occurs in the full training set. When false, regular V-fold cross-validation is performed. If the response variable is continuous, the flag has no effect.
stratify - a boolean indicating whether or not stratified cross-validation should be performed for categorical response variables
Default: stratify = false
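One common way to build stratified folds is to shuffle the observations within each class and deal them to the folds in turn. The sketch below shows that approach under stated assumptions: `StratifiedFoldsSketch` and `assignFolds` are hypothetical names, class levels are encoded as int labels, and nothing here is claimed to match how crossValidate partitions the data internally.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;

final class StratifiedFoldsSketch {
    /** Returns foldOf[i], the fold assigned to observation i. */
    static int[] assignFolds(int[] classLabels, int nFolds, Random random) {
        // Group observation indices by class level; TreeMap keeps the class order stable.
        Map<Integer, List<Integer>> byClass = new TreeMap<>();
        for (int i = 0; i < classLabels.length; i++) {
            byClass.computeIfAbsent(classLabels[i], key -> new ArrayList<>()).add(i);
        }
        int[] foldOf = new int[classLabels.length];
        int next = 0;   // running counter so fold sizes stay balanced across classes
        for (List<Integer> members : byClass.values()) {
            Collections.shuffle(members, random);    // random order within the class
            for (int index : members) {
                foldOf[index] = next % nFolds;       // deal round-robin to the folds
                next++;
            }
        }
        return foldOf;
    }
}
```

With eight observations split evenly between two classes and two folds, each fold receives two observations of each class, matching the proportions of the full set.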
public void setRandomObject(Random r)
Sets the random object to be used in the permutation of observation data.
r - a Random object to be used in random permutation of observation data.
Specifying a seed for the Random object can produce repeatable/deterministic output.
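The effect of seeding can be demonstrated with java.util.Random alone. The sketch below uses a Fisher-Yates shuffle as a stand-in for the permutation step; the class `SeededPermutationSketch` is hypothetical and the library's actual permutation algorithm is not documented here.

```java
import java.util.Random;

final class SeededPermutationSketch {
    /** Returns a random permutation of 0..n-1 drawn from the supplied generator. */
    static int[] permutation(int n, Random random) {
        int[] order = new int[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        // Fisher-Yates shuffle: swap each position with a random earlier-or-equal one.
        for (int i = n - 1; i > 0; i--) {
            int j = random.nextInt(i + 1);
            int tmp = order[i];
            order[i] = order[j];
            order[j] = tmp;
        }
        return order;
    }
}
```

Two generators created with the same seed, for example `new Random(123L)`, yield the same permutation, which is what makes repeated cross-validation runs comparable.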
public double[] getRiskValues()
Returns the vector of risk values.
In most cases the length is 1. For DecisionTree, CrossValidation returns an array of length >= 1.
double array containing the estimated risk values.
public Random getRandomObject()
Returns the random object being used in the permutation of the observations.
Random object being used for permutations
public double[] getRiskStandardErrors()
Returns the estimated standard errors for the risk values.
In most cases the length is 1. For DecisionTree, CrossValidation returns an array of length >= 1.
double array containing the estimated standard errors for the risk values.
Copyright © 2020 Rogue Wave Software. All rights reserved.