This example applies the QUEST method to a simulated data set with 50 cases and three predictors of mixed type (two categorical and one continuous). It produces a maximally grown tree under the default controls and the pruned subtree obtained from cross-validation and minimal cost-complexity pruning. Notice that the pruned tree printed at the end keeps only the root node and its two children (the branch below node 2 is pruned), whereas the maximal tree has five nodes and three levels.
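Minimal cost-complexity pruning repeatedly removes the internal node with the smallest "weakest-link" value g(t) = (R(t) - R(T_t)) / (|T_t| - 1). Here R(t) is the cost of node t treated as a leaf, R(T_t) is the summed cost of the leaves below t, and |T_t| is the number of those leaves. The standalone sketch below is an illustration, not JMSL code, and the class and method names are hypothetical. It assumes that the Cost values printed for each node are the R(t) used in this formula, which is consistent with the output: each cost equals the node's share of cases times its misclassification rate, for example 0.66 x (1 - 0.455) = 0.36 for node 2. With those costs, the sketch reproduces the nonzero complexity values 0.02 and 0.04 reported in the cross-validation table.

```java
import java.util.Locale;

/*
 * Standalone sketch (not part of JMSL): weakest-link values for minimal
 * cost-complexity pruning, using the node costs printed in this example.
 * The class and method names are hypothetical.
 */
public class WeakestLinkSketch {

    /* g(t) = (R(t) - R(T_t)) / (number of leaves in T_t - 1) */
    static double weakestLink(double nodeCost, double... leafCosts) {
        double branchCost = 0.0;
        for (double c : leafCosts) {
            branchCost += c;
        }
        return (nodeCost - branchCost) / (leafCosts.length - 1);
    }

    public static void main(String[] args) {
        /* Costs R(t) of nodes 0..4 in the maximal tree */
        double[] cost = {0.620, 0.220, 0.360, 0.180, 0.160};

        /* Maximal tree: node 0 has leaves 1, 3, 4; node 2 has leaves 3, 4 */
        double gRoot = weakestLink(cost[0], cost[1], cost[3], cost[4]);
        double gNode2 = weakestLink(cost[2], cost[3], cost[4]);
        System.out.printf(Locale.US, "g(0) = %.2f, g(2) = %.2f%n", gRoot, gNode2);

        /* Node 2 has the smaller value, so it is the weakest link and is
           pruned first; the root's branch then has leaves 1 and 2 */
        double gRootAfter = weakestLink(cost[0], cost[1], cost[2]);
        System.out.printf(Locale.US, "after pruning node 2: g(0) = %.2f%n", gRootAfter);
    }
}
```

The root's value is recomputed after node 2 is removed, which is why the intermediate value 0.03 does not appear in the Complexity column.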
import com.imsl.datamining.*;
import com.imsl.datamining.decisionTree.*;
import com.imsl.stat.Random;

public class CrossValidationEx1 {

    public static void main(String[] args) throws Exception {
        /* Column types: X0 categorical, X1 continuous, X2 categorical,
           and the response Y (column 3) categorical */
        PredictiveModel.VariableType[] sim0VarType = {
            PredictiveModel.VariableType.CATEGORICAL,
            PredictiveModel.VariableType.QUANTITATIVE_CONTINUOUS,
            PredictiveModel.VariableType.CATEGORICAL,
            PredictiveModel.VariableType.CATEGORICAL
        };
        /* 50 simulated cases; columns are X0, X1, X2, Y */
        double[][] sim0XY = {
            {2, 25.92869, 0, 0},
            {1, 51.63245, 1, 1},
            {1, 25.78432, 0, 2},
            {0, 39.37948, 0, 3},
            {2, 24.65058, 0, 2},
            {2, 45.20084, 0, 2},
            {2, 52.67960, 1, 3},
            {1, 44.28342, 1, 3},
            {2, 40.63523, 1, 3},
            {2, 51.76094, 0, 3},
            {2, 26.30368, 0, 1},
            {2, 20.70230, 1, 0},
            {2, 38.74273, 1, 3},
            {2, 19.47333, 0, 0},
            {1, 26.42211, 0, 0},
            {2, 37.05986, 1, 0},
            {1, 51.67043, 1, 3},
            {0, 42.40156, 0, 3},
            {2, 33.90027, 1, 2},
            {1, 35.43282, 0, 0},
            {1, 44.30369, 0, 1},
            {0, 46.72387, 0, 2},
            {1, 46.99262, 0, 2},
            {0, 36.05923, 0, 3},
            {2, 36.83197, 1, 1},
            {1, 61.66257, 1, 2},
            {0, 25.67714, 0, 3},
            {1, 39.08567, 1, 0},
            {0, 48.84341, 1, 1},
            {1, 39.34391, 0, 3},
            {2, 24.73522, 0, 2},
            {1, 50.55251, 1, 3},
            {0, 31.34263, 1, 3},
            {1, 27.15795, 1, 0},
            {0, 31.72685, 0, 2},
            {0, 25.00408, 0, 3},
            {1, 26.35457, 1, 3},
            {2, 38.12343, 0, 1},
            {0, 49.94030, 0, 2},
            {1, 42.45779, 1, 3},
            {0, 38.80948, 1, 1},
            {0, 43.22799, 1, 1},
            {0, 41.87624, 0, 3},
            {2, 48.07820, 0, 2},
            {0, 43.23673, 1, 0},
            {2, 39.41294, 0, 3},
            {1, 23.93346, 0, 2},
            {2, 42.84130, 1, 3},
            {2, 30.40669, 0, 1},
            {0, 37.77389, 0, 2}
        };
        /* Seeded generator so the random assignment of cases to folds is reproducible */
        Random r = new Random(123457);
        r.setMultiplier(16807);

        /* Fit the QUEST tree; column 3 of sim0XY is the response */
        QUEST dt = new QUEST(sim0XY, 3, sim0VarType);
        dt.setAutoPruningFlag(true);
        dt.fitModel();

        /* print the maximal tree */
        dt.printDecisionTree(true);

        /* Estimate the CV risk and its standard error for each subtree */
        CrossValidation cv = new CrossValidation(dt);
        cv.setRandomObject(r);
        cv.crossValidate();
        double cvError = cv.getCrossValidatedError();
        double[] Rcv = cv.getRiskValues();
        double[] SERcv = cv.getRiskStandardErrors();

        System.out.println("\nTree \t Complexity\t CV Risk \t SE of CV Risk ");
        for (int k = 0; k < Rcv.length; k++) {
            System.out.printf(" %d \t %3.2f \t %5.4f \t %5.4f\n",
                    k, dt.getCostComplexityValues()[k], Rcv[k], SERcv[k]);
        }

        /* prune the tree using the selected complexity value */
        dt.pruneTree(dt.getCostComplexityValues()[0]);
        System.out.printf("Minimum CV Risk Values: %5.4f\n", cvError);
        /* SERcv[0] is used because the minimum CV risk occurs at tree 0 in this run */
        System.out.printf("Minimum CV Risk + Standard error: %5.4f\n",
                (cvError + SERcv[0]));

        /* print the pruned tree */
        dt.printDecisionTree(false);
    }
}
Decision Tree:
Node 0: Cost = 0.620, N= 50, Level = 0, Child nodes: 1 2
P(Y=0)= 0.180
P(Y=1)= 0.180
P(Y=2)= 0.260
P(Y=3)= 0.380
Predicted Y: 3
Node 1: Cost = 0.220, N= 17, Level = 1
Rule: X1 <= 35.031
P(Y=0)= 0.294
P(Y=1)= 0.118
P(Y=2)= 0.353
P(Y=3)= 0.235
Predicted Y: 2
Node 2: Cost = 0.360, N= 33, Level = 1, Child nodes: 3 4
Rule: X1 > 35.031
P(Y=0)= 0.121
P(Y=1)= 0.212
P(Y=2)= 0.212
P(Y=3)= 0.455
Predicted Y: 3
Node 3: Cost = 0.180, N= 19, Level = 2
Rule: X1 <= 43.265
P(Y=0)= 0.211
P(Y=1)= 0.211
P(Y=2)= 0.053
P(Y=3)= 0.526
Predicted Y: 3
Node 4: Cost = 0.160, N= 14, Level = 2
Rule: X1 > 43.265
P(Y=0)= 0.000
P(Y=1)= 0.214
P(Y=2)= 0.429
P(Y=3)= 0.357
Predicted Y: 2
Tree   Complexity   CV Risk   SE of CV Risk
 0       0.00       0.7041      0.0804
 1       0.02       0.7264      0.0856
 2       0.04       0.7281      0.0860
Minimum CV Risk Values: 0.7041
Minimum CV Risk + Standard error: 0.7845
Decision Tree:
Node 0: Cost = 0.620, N= 50, Level = 0, Child nodes: 1 2
P(Y=0)= 0.180
P(Y=1)= 0.180
P(Y=2)= 0.260
P(Y=3)= 0.380
Predicted Y: 3
Node 1: Cost = 0.220, N= 17, Level = 1
Rule: X1 <= 35.031
P(Y=0)= 0.294
P(Y=1)= 0.118
P(Y=2)= 0.353
P(Y=3)= 0.235
Predicted Y: 2
Node 2: Cost = 0.360, N= 33, Level = 1
Rule: X1 > 35.031
P(Y=0)= 0.121
P(Y=1)= 0.212
P(Y=2)= 0.212
P(Y=3)= 0.455
Predicted Y: 3
Pruned at Node id 2.
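The program prints both the minimum CV risk and the minimum plus one standard error. These are the two quantities behind the usual choices of complexity value: take the subtree with the smallest CV risk, or apply the one-standard-error rule and take the simplest subtree whose CV risk is within one standard error of that minimum. The standalone sketch below is not JMSL API, and its class and method names are hypothetical. It applies both rules to the Complexity, CV Risk, and SE columns printed above, assuming the subtrees are listed in order of increasing complexity value (so a larger index means a more heavily pruned tree). Because all three CV risks are at most 0.7041 + 0.0804 = 0.7845, the one-standard-error rule selects tree 2 (complexity 0.04), while the minimum-risk rule selects tree 0.

```java
import java.util.Locale;

/*
 * Standalone sketch (not part of JMSL): choosing a complexity value from
 * cross-validation results. Class and method names are hypothetical.
 * Assumes complexity values are sorted in increasing order, so a larger
 * index means a smaller (more heavily pruned) tree.
 */
public class PruningRuleSketch {

    /* Index of the subtree with the smallest CV risk (first one on ties) */
    static int minRiskIndex(double[] risk) {
        int best = 0;
        for (int k = 1; k < risk.length; k++) {
            if (risk[k] < risk[best]) {
                best = k;
            }
        }
        return best;
    }

    /* One-SE rule: largest index whose risk is <= min risk + SE of that minimum */
    static int oneSeIndex(double[] risk, double[] se) {
        int kMin = minRiskIndex(risk);
        double threshold = risk[kMin] + se[kMin];
        int choice = kMin;
        for (int k = kMin + 1; k < risk.length; k++) {
            if (risk[k] <= threshold) {
                choice = k;
            }
        }
        return choice;
    }

    public static void main(String[] args) {
        /* Values copied from the cross-validation table in the output */
        double[] complexity = {0.00, 0.02, 0.04};
        double[] risk = {0.7041, 0.7264, 0.7281};
        double[] se = {0.0804, 0.0856, 0.0860};

        int kMin = minRiskIndex(risk);
        int kSe = oneSeIndex(risk, se);
        System.out.printf(Locale.US, "Minimum-risk rule: tree %d, complexity %.2f%n",
                kMin, complexity[kMin]);
        System.out.printf(Locale.US, "One-SE rule: tree %d, complexity %.2f%n",
                kSe, complexity[kSe]);
    }
}
```

The one-standard-error rule accepts a small increase in estimated risk, within the noise of the cross-validation estimate, in exchange for a smaller tree.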