Example: CrossValidation

This example applies the QUEST method to a simulated data set with 50 cases and three predictors of mixed types. A maximally grown tree under the default controls and the optimally pruned subtree obtained from cross-validation and minimal cost-complexity pruning are produced. Notice that the optimally pruned tree consists of the root node and its two children (the tree is pruned at node 2), whereas the maximal tree has five nodes on three levels.


import com.imsl.datamining.*;
import com.imsl.datamining.decisionTree.*;
import com.imsl.stat.Random;

public class CrossValidationEx1 {

    public static void main(String[] args) throws Exception {
        PredictiveModel.VariableType[] sim0VarType = {
            PredictiveModel.VariableType.CATEGORICAL,
            PredictiveModel.VariableType.QUANTITATIVE_CONTINUOUS,
            PredictiveModel.VariableType.CATEGORICAL,
            PredictiveModel.VariableType.CATEGORICAL
        };

        double[][] sim0XY = {
            {2, 25.92869, 0, 0},
            {1, 51.63245, 1, 1},
            {1, 25.78432, 0, 2},
            {0, 39.37948, 0, 3},
            {2, 24.65058, 0, 2},
            {2, 45.20084, 0, 2},
            {2, 52.67960, 1, 3},
            {1, 44.28342, 1, 3},
            {2, 40.63523, 1, 3},
            {2, 51.76094, 0, 3},
            {2, 26.30368, 0, 1},
            {2, 20.70230, 1, 0},
            {2, 38.74273, 1, 3},
            {2, 19.47333, 0, 0},
            {1, 26.42211, 0, 0},
            {2, 37.05986, 1, 0},
            {1, 51.67043, 1, 3},
            {0, 42.40156, 0, 3},
            {2, 33.90027, 1, 2},
            {1, 35.43282, 0, 0},
            {1, 44.30369, 0, 1},
            {0, 46.72387, 0, 2},
            {1, 46.99262, 0, 2},
            {0, 36.05923, 0, 3},
            {2, 36.83197, 1, 1},
            {1, 61.66257, 1, 2},
            {0, 25.67714, 0, 3},
            {1, 39.08567, 1, 0},
            {0, 48.84341, 1, 1},
            {1, 39.34391, 0, 3},
            {2, 24.73522, 0, 2},
            {1, 50.55251, 1, 3},
            {0, 31.34263, 1, 3},
            {1, 27.15795, 1, 0},
            {0, 31.72685, 0, 2},
            {0, 25.00408, 0, 3},
            {1, 26.35457, 1, 3},
            {2, 38.12343, 0, 1},
            {0, 49.94030, 0, 2},
            {1, 42.45779, 1, 3},
            {0, 38.80948, 1, 1},
            {0, 43.22799, 1, 1},
            {0, 41.87624, 0, 3},
            {2, 48.07820, 0, 2},
            {0, 43.23673, 1, 0},
            {2, 39.41294, 0, 3},
            {1, 23.93346, 0, 2},
            {2, 42.84130, 1, 3},
            {2, 30.40669, 0, 1},
            {0, 37.77389, 0, 2}
        };

        Random r = new Random(123457);
        r.setMultiplier(16807);
        QUEST dt = new QUEST(sim0XY, 3, sim0VarType);

        dt.setAutoPruningFlag(true);
        dt.fitModel();
        /* print the maximal tree */
        dt.printDecisionTree(true);

        CrossValidation cv = new CrossValidation(dt);
        cv.setRandomObject(r);
        cv.crossValidate();
        double cvError = cv.getCrossValidatedError();
        double[] Rcv = cv.getRiskValues();
        double[] SERcv = cv.getRiskStandardErrors();

        System.out.println("\nTree \t Complexity\t CV Risk \t SE of CV Risk ");
        for (int k = 0; k < Rcv.length; k++) {
            System.out.printf("  %d \t  %3.2f   \t   %5.4f \t   %5.4f\n",
                    k, dt.getCostComplexityValues()[k], Rcv[k], SERcv[k]);
        }
        /* prune the tree at the complexity value with the minimum CV risk
           (index 0 in the cost-complexity sequence) */
        dt.pruneTree(dt.getCostComplexityValues()[0]);

        System.out.printf("Minimum CV Risk Values: %5.4f\n", cvError);
        System.out.printf("Minimum CV Risk + Standard error: %5.4f\n",
                (cvError + SERcv[0]));

        /* print the pruned tree */
        dt.printDecisionTree(false);
    }
}
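The call to pruneTree above uses getCostComplexityValues()[0], the complexity value with the smallest cross-validated risk. A common alternative is the one-standard-error rule: select the most heavily pruned tree whose CV risk is within one standard error of the minimum. The sketch below is plain Java that does not use the JMSL API; the class and method names are invented for illustration, and the data come from the risk table this example prints.

```java
// Hypothetical helper (not part of JMSL): selects a pruning index by the
// one-standard-error rule from CV risks and their standard errors.
public class OneSERule {

    // Returns the index of the most heavily pruned tree (largest index,
    // i.e. largest complexity value) whose CV risk is within one standard
    // error of the minimum CV risk.
    public static int select(double[] risk, double[] se) {
        int minIdx = 0;
        for (int k = 1; k < risk.length; k++) {
            if (risk[k] < risk[minIdx]) {
                minIdx = k;
            }
        }
        double threshold = risk[minIdx] + se[minIdx];
        int chosen = minIdx;
        for (int k = 0; k < risk.length; k++) {
            if (risk[k] <= threshold && k > chosen) {
                chosen = k;
            }
        }
        return chosen;
    }

    public static void main(String[] args) {
        // CV risks and standard errors from the table printed by this example
        double[] rcv = {0.7041, 0.7264, 0.7281};
        double[] seRcv = {0.0804, 0.0856, 0.0860};
        System.out.println("Selected tree index: " + select(rcv, seRcv));
    }
}
```

For this example the threshold is 0.7041 + 0.0804 = 0.7845 (the "Minimum CV Risk + Standard error" line in the output), and all three risks fall under it, so the rule would pick index 2 (complexity 0.04) rather than index 0.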

Output


Decision Tree:


Node 0: Cost = 0.620, N= 50, Level = 0, Child nodes:  1  2 
P(Y=0)= 0.180
P(Y=1)= 0.180
P(Y=2)= 0.260
P(Y=3)= 0.380
Predicted Y:   3 
   
Node 1: Cost = 0.220, N= 17, Level = 1
   Rule: X1    <= 35.031
    P(Y=0)= 0.294
    P(Y=1)= 0.118
    P(Y=2)= 0.353
    P(Y=3)= 0.235
    Predicted Y:   2 
   
Node 2: Cost = 0.360, N= 33, Level = 1, Child nodes:  3  4 
   Rule: X1    > 35.031
    P(Y=0)= 0.121
    P(Y=1)= 0.212
    P(Y=2)= 0.212
    P(Y=3)= 0.455
    Predicted Y:   3 
      
Node 3: Cost = 0.180, N= 19, Level = 2
      Rule: X1       <= 43.265
        P(Y=0)= 0.211
        P(Y=1)= 0.211
        P(Y=2)= 0.053
        P(Y=3)= 0.526
        Predicted Y:   3 
      
Node 4: Cost = 0.160, N= 14, Level = 2
      Rule: X1       > 43.265
        P(Y=0)= 0.000
        P(Y=1)= 0.214
        P(Y=2)= 0.429
        P(Y=3)= 0.357
        Predicted Y:   2 

Tree 	 Complexity	 CV Risk 	 SE of CV Risk 
  0 	  0.00   	   0.7041 	   0.0804
  1 	  0.02   	   0.7264 	   0.0856
  2 	  0.04   	   0.7281 	   0.0860
Minimum CV Risk Values: 0.7041
Minimum CV Risk + Standard error: 0.7845

Decision Tree:


Node 0: Cost = 0.620, N= 50, Level = 0, Child nodes:  1  2 
P(Y=0)= 0.180
P(Y=1)= 0.180
P(Y=2)= 0.260
P(Y=3)= 0.380
Predicted Y:   3 
   
Node 1: Cost = 0.220, N= 17, Level = 1
   Rule: X1    <= 35.031
    P(Y=0)= 0.294
    P(Y=1)= 0.118
    P(Y=2)= 0.353
    P(Y=3)= 0.235
    Predicted Y:   2 
   
Node 2: Cost = 0.360, N= 33, Level = 1
   Rule: X1    > 35.031
    P(Y=0)= 0.121
    P(Y=1)= 0.212
    P(Y=2)= 0.212
    P(Y=3)= 0.455
    Predicted Y:   3 
Pruned at Node id 2.