package com.imsl.test.example.datamining;
import com.imsl.datamining.*;
import com.imsl.datamining.decisionTree.*;
import com.imsl.stat.Random;
/**
 * Uses cross-validation to determine the optimally pruned decision tree.
 *
 * <p>In this example, the QUEST algorithm is first used to fit a decision
 * tree to simulated data. Then the minimal cost-complexity value is found via
 * cross-validation. The {@code CrossValidation} constructor accepts the
 * decision tree object, and the method {@code CrossValidation.crossValidate()}
 * is then called to run the procedure. Finally, the tree is pruned using the
 * result. Notice that the optimally pruned tree consists of just the root
 * node, whereas the maximal tree has five nodes and three levels.
 *
 * @see <a href="CrossValidationEx1.java">Code</a>
 * @see <a href="CrossValidationEx1.out">Output</a>
 */
public class CrossValidationEx1 {
    /**
     * The main method of the example.
     *
     * <p>Fits a QUEST decision tree to a small simulated data set, runs
     * cross-validation to estimate the risk at each cost-complexity value,
     * prints the risk table, prunes the tree with the first complexity value,
     * and prints both the maximal and the pruned tree.
     *
     * @param args command-line arguments (unused)
     * @throws Exception if model fitting or cross-validation fails
     */
    public static void main(String[] args) throws Exception {
        // Column types: response (categorical) plus three predictors.
        final PredictiveModel.VariableType[] columnTypes = {
            PredictiveModel.VariableType.CATEGORICAL,
            PredictiveModel.VariableType.QUANTITATIVE_CONTINUOUS,
            PredictiveModel.VariableType.CATEGORICAL,
            PredictiveModel.VariableType.CATEGORICAL
        };
        // Simulated observations; column 3 (index 3) is the response below.
        final double[][] observations = {
            {2, 25.92869, 0, 0}, {1, 51.63245, 1, 1}, {1, 25.78432, 0, 2},
            {0, 39.37948, 0, 3}, {2, 24.65058, 0, 2}, {2, 45.20084, 0, 2},
            {2, 52.67960, 1, 3}, {1, 44.28342, 1, 3}, {2, 40.63523, 1, 3},
            {2, 51.76094, 0, 3}, {2, 26.30368, 0, 1}, {2, 20.70230, 1, 0},
            {2, 38.74273, 1, 3}, {2, 19.47333, 0, 0}, {1, 26.42211, 0, 0},
            {2, 37.05986, 1, 0}, {1, 51.67043, 1, 3}, {0, 42.40156, 0, 3},
            {2, 33.90027, 1, 2}, {1, 35.43282, 0, 0}, {1, 44.30369, 0, 1},
            {0, 46.72387, 0, 2}, {1, 46.99262, 0, 2}, {0, 36.05923, 0, 3},
            {2, 36.83197, 1, 1}, {1, 61.66257, 1, 2}, {0, 25.67714, 0, 3},
            {1, 39.08567, 1, 0}, {0, 48.84341, 1, 1}, {1, 39.34391, 0, 3},
            {2, 24.73522, 0, 2}, {1, 50.55251, 1, 3}, {0, 31.34263, 1, 3},
            {1, 27.15795, 1, 0}, {0, 31.72685, 0, 2}, {0, 25.00408, 0, 3},
            {1, 26.35457, 1, 3}, {2, 38.12343, 0, 1}, {0, 49.94030, 0, 2},
            {1, 42.45779, 1, 3}, {0, 38.80948, 1, 1}, {0, 43.22799, 1, 1},
            {0, 41.87624, 0, 3}, {2, 48.07820, 0, 2}, {0, 43.23673, 1, 0},
            {2, 39.41294, 0, 3}, {1, 23.93346, 0, 2}, {2, 42.84130, 1, 3},
            {2, 30.40669, 0, 1}, {0, 37.77389, 0, 2}
        };

        // Fixed seed/multiplier so the cross-validation folds are reproducible.
        Random rng = new Random(123457);
        rng.setMultiplier(16807);

        // Fit a QUEST tree with automatic pruning enabled; column 3 is the
        // response variable.
        QUEST tree = new QUEST(observations, 3, columnTypes);
        tree.setAutoPruningFlag(true);
        tree.fitModel();

        // Print the maximal tree.
        tree.printDecisionTree(true);

        // Run cross-validation on the fitted tree.
        CrossValidation validator = new CrossValidation(tree);
        validator.setRandomObject(rng);
        validator.crossValidate();

        double minRisk = validator.getCrossValidatedError();
        double[] risks = validator.getRiskValues();
        double[] riskStdErrs = validator.getRiskStandardErrors();
        double[] complexities = tree.getCostComplexityValues();

        // Report the CV risk and its standard error per complexity value.
        System.out.println("\nTree \t Complexity\t CV Risk \t SE of CV Risk ");
        int row = 0;
        for (double risk : risks) {
            System.out.printf(" %d \t %3.2f \t %5.4f \t %5.4f\n",
                    row, complexities[row], risk, riskStdErrs[row]);
            row++;
        }

        // Prune the tree using the selected complexity value.
        tree.pruneTree(complexities[0]);
        System.out.printf("Minimum CV Risk Values: %5.4f\n", minRisk);
        System.out.printf("Minimum CV Risk + Standard error: %5.4f\n",
                (minRisk + riskStdErrs[0]));

        // Print the pruned tree.
        tree.printDecisionTree(false);
    }
}