This example applies the QUEST method to a simulated data set with 50 cases and three predictors of mixed type. It produces a maximally grown tree under the default controls and the optimally pruned subtree obtained from cross-validation and minimal cost-complexity pruning. In the output, the maximal tree has five nodes and three levels, whereas the pruned tree is cut at node 2, leaving only the root node and its two children.
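The complexity values that this example prints can be related to the node costs in its output. The following is a minimal sketch, assuming the standard weakest-link rule from CART-style cost-complexity pruning: collapsing the subtree under a node raises the cost by (node cost minus the summed cost of its leaves), spread over the number of leaves removed. The class `WeakestLinkSketch` and the method `linkStrength` are hypothetical names used for illustration only and are not part of the JMSL API; whether the library computes the values in exactly this way is an assumption.

```java
import java.util.Locale;

public class WeakestLinkSketch {

    /**
     * Cost increase per removed leaf when the subtree rooted at a node is
     * collapsed into a single leaf (hypothetical helper, not a JMSL method).
     */
    static double linkStrength(double nodeCost, double[] leafCosts) {
        double subtreeCost = 0.0;
        for (double c : leafCosts) {
            subtreeCost += c;
        }
        return (nodeCost - subtreeCost) / (leafCosts.length - 1);
    }

    public static void main(String[] args) {
        // Node costs copied from the maximal tree in the example output:
        // node 0 = 0.620, node 1 = 0.220, node 2 = 0.360,
        // node 3 = 0.180, node 4 = 0.160 (leaves are nodes 1, 3, 4).

        // Step 1: candidates in the maximal tree.
        double atNode2 = linkStrength(0.360, new double[] {0.180, 0.160});
        double atRoot = linkStrength(0.620, new double[] {0.220, 0.180, 0.160});

        // Step 2: after node 2 is collapsed, only the root can be collapsed.
        double atRootAfter = linkStrength(0.620, new double[] {0.220, 0.360});

        System.out.printf(Locale.ROOT,
                "node 2: %.2f, root: %.2f, root after collapsing node 2: %.2f%n",
                atNode2, atRoot, atRootAfter);
    }
}
```

Under these assumptions, node 2 is the weakest link in the maximal tree (0.02 versus 0.03 at the root), and collapsing the root afterwards costs 0.04 per leaf. These agree with the values 0.02 and 0.04 in the Complexity column of the output.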
import com.imsl.datamining.*;
import com.imsl.datamining.decisionTree.*;
import com.imsl.stat.Random;

public class CrossValidationEx1 {

    public static void main(String[] args) throws Exception {

        PredictiveModel.VariableType[] sim0VarType = {
            PredictiveModel.VariableType.CATEGORICAL,
            PredictiveModel.VariableType.QUANTITATIVE_CONTINUOUS,
            PredictiveModel.VariableType.CATEGORICAL,
            PredictiveModel.VariableType.CATEGORICAL
        };

        double[][] sim0XY = {
            {2, 25.92869, 0, 0}, {1, 51.63245, 1, 1},
            {1, 25.78432, 0, 2}, {0, 39.37948, 0, 3},
            {2, 24.65058, 0, 2}, {2, 45.20084, 0, 2},
            {2, 52.67960, 1, 3}, {1, 44.28342, 1, 3},
            {2, 40.63523, 1, 3}, {2, 51.76094, 0, 3},
            {2, 26.30368, 0, 1}, {2, 20.70230, 1, 0},
            {2, 38.74273, 1, 3}, {2, 19.47333, 0, 0},
            {1, 26.42211, 0, 0}, {2, 37.05986, 1, 0},
            {1, 51.67043, 1, 3}, {0, 42.40156, 0, 3},
            {2, 33.90027, 1, 2}, {1, 35.43282, 0, 0},
            {1, 44.30369, 0, 1}, {0, 46.72387, 0, 2},
            {1, 46.99262, 0, 2}, {0, 36.05923, 0, 3},
            {2, 36.83197, 1, 1}, {1, 61.66257, 1, 2},
            {0, 25.67714, 0, 3}, {1, 39.08567, 1, 0},
            {0, 48.84341, 1, 1}, {1, 39.34391, 0, 3},
            {2, 24.73522, 0, 2}, {1, 50.55251, 1, 3},
            {0, 31.34263, 1, 3}, {1, 27.15795, 1, 0},
            {0, 31.72685, 0, 2}, {0, 25.00408, 0, 3},
            {1, 26.35457, 1, 3}, {2, 38.12343, 0, 1},
            {0, 49.94030, 0, 2}, {1, 42.45779, 1, 3},
            {0, 38.80948, 1, 1}, {0, 43.22799, 1, 1},
            {0, 41.87624, 0, 3}, {2, 48.07820, 0, 2},
            {0, 43.23673, 1, 0}, {2, 39.41294, 0, 3},
            {1, 23.93346, 0, 2}, {2, 42.84130, 1, 3},
            {2, 30.40669, 0, 1}, {0, 37.77389, 0, 2}
        };

        Random r = new Random(123457);
        r.setMultiplier(16807);

        QUEST dt = new QUEST(sim0XY, 3, sim0VarType);
        dt.setAutoPruningFlag(true);
        dt.fitModel();

        /* print the maximal tree */
        dt.printDecisionTree(true);

        CrossValidation cv = new CrossValidation(dt);
        cv.setRandomObject(r);
        cv.crossValidate();

        double cvError = cv.getCrossValidatedError();
        double[] Rcv = cv.getRiskValues();
        double[] SERcv = cv.getRiskStandardErrors();

        System.out.println("\nTree \t Complexity\t CV Risk \t SE of CV Risk ");
        for (int k = 0; k < Rcv.length; k++) {
            System.out.printf(" %d \t %3.2f \t %5.4f \t %5.4f\n",
                    k, dt.getCostComplexityValues()[k], Rcv[k], SERcv[k]);
        }

        /* prune the tree using the selected complexity value */
        dt.pruneTree(dt.getCostComplexityValues()[0]);
        System.out.printf("Minimum CV Risk Values: %5.4f\n", cvError);
        System.out.printf("Minimum CV Risk + Standard error: %5.4f\n",
                (cvError + SERcv[0]));

        /* print the pruned tree */
        dt.printDecisionTree(false);
    }
}
Decision Tree:

Node 0: Cost = 0.620, N= 50, Level = 0, Child nodes: 1 2
P(Y=0)= 0.180
P(Y=1)= 0.180
P(Y=2)= 0.260
P(Y=3)= 0.380
Predicted Y: 3
   Node 1: Cost = 0.220, N= 17, Level = 1
   Rule: X1 <= 35.031
   P(Y=0)= 0.294
   P(Y=1)= 0.118
   P(Y=2)= 0.353
   P(Y=3)= 0.235
   Predicted Y: 2
   Node 2: Cost = 0.360, N= 33, Level = 1, Child nodes: 3 4
   Rule: X1 > 35.031
   P(Y=0)= 0.121
   P(Y=1)= 0.212
   P(Y=2)= 0.212
   P(Y=3)= 0.455
   Predicted Y: 3
      Node 3: Cost = 0.180, N= 19, Level = 2
      Rule: X1 <= 43.265
      P(Y=0)= 0.211
      P(Y=1)= 0.211
      P(Y=2)= 0.053
      P(Y=3)= 0.526
      Predicted Y: 3
      Node 4: Cost = 0.160, N= 14, Level = 2
      Rule: X1 > 43.265
      P(Y=0)= 0.000
      P(Y=1)= 0.214
      P(Y=2)= 0.429
      P(Y=3)= 0.357
      Predicted Y: 2

Tree    Complexity    CV Risk    SE of CV Risk
 0      0.00          0.7041     0.0804
 1      0.02          0.7264     0.0856
 2      0.04          0.7281     0.0860
Minimum CV Risk Values: 0.7041
Minimum CV Risk + Standard error: 0.7845

Decision Tree:

Node 0: Cost = 0.620, N= 50, Level = 0, Child nodes: 1 2
P(Y=0)= 0.180
P(Y=1)= 0.180
P(Y=2)= 0.260
P(Y=3)= 0.380
Predicted Y: 3
   Node 1: Cost = 0.220, N= 17, Level = 1
   Rule: X1 <= 35.031
   P(Y=0)= 0.294
   P(Y=1)= 0.118
   P(Y=2)= 0.353
   P(Y=3)= 0.235
   Predicted Y: 2
   Node 2: Cost = 0.360, N= 33, Level = 1
   Rule: X1 > 35.031
   P(Y=0)= 0.121
   P(Y=1)= 0.212
   P(Y=2)= 0.212
   P(Y=3)= 0.455
   Predicted Y: 3
   Pruned at Node id 2.
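The output reports the minimum cross-validated risk plus its standard error (0.7845). One common use of such a threshold is the one-standard-error rule: choose the most heavily pruned tree whose cross-validated risk does not exceed it. The sketch below applies that rule to the three rows of the table above. It is an illustration only. The example itself prunes at `getCostComplexityValues()[0]`, and the class `OneSeRuleSketch` and method `select` are hypothetical names, not part of the JMSL API. The sketch also assumes the trees are listed in order of increasing complexity.

```java
import java.util.Locale;

public class OneSeRuleSketch {

    /**
     * Returns the index of the most heavily pruned tree whose CV risk is
     * within one standard error of the minimum CV risk. Assumes index k
     * increases with the complexity value (hypothetical helper).
     */
    static int select(double[] risk, double[] se) {
        // Locate the tree with the smallest cross-validated risk.
        int best = 0;
        for (int k = 1; k < risk.length; k++) {
            if (risk[k] < risk[best]) {
                best = k;
            }
        }
        double threshold = risk[best] + se[best];

        // Walk toward larger complexity and keep the last tree under the threshold.
        int chosen = best;
        for (int k = best; k < risk.length; k++) {
            if (risk[k] <= threshold) {
                chosen = k;
            }
        }
        return chosen;
    }

    public static void main(String[] args) {
        // Values copied from the table in the example output.
        double[] complexity = {0.00, 0.02, 0.04};
        double[] risk = {0.7041, 0.7264, 0.7281};
        double[] se = {0.0804, 0.0856, 0.0860};

        int k = select(risk, se);
        System.out.printf(Locale.ROOT, "Selected tree %d (complexity %.2f)%n",
                k, complexity[k]);
    }
}
```

With these numbers all three risks lie below 0.7845, so the rule would select tree 2 (complexity 0.04), the most heavily pruned candidate in the table.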