Example 3: Gradient Boosting

This example uses cross-validation to choose the number of iterations in stochastic gradient boosting. Of the candidate values tried (10, 30, 50, 100, and 500), 50 iterations gives the minimum cross-validated risk estimate.


import com.imsl.datamining.CrossValidation;
import com.imsl.datamining.GradientBoosting;
import com.imsl.stat.Random;

/**
 * Example: cross-validates several candidate iteration counts for a
 * {@code GradientBoosting} model and prints the risk estimate and its
 * standard error for each candidate, followed by the minimum risk and the
 * minimum risk plus the standard error of that same (minimum-risk) candidate.
 */
public class GradientBoostingEx3 {

    /**
     * Builds the model on a fixed 100-row data set, runs cross-validation
     * over the candidate iteration counts, and prints the results to
     * standard output.
     *
     * @param args command-line arguments (unused)
     * @throws Exception if model setup or cross-validation fails
     */
    public static void main(String[] args) throws Exception {
        // One entry per data column, in column order.
        GradientBoosting.VariableType[] variableTypes = {
            GradientBoosting.VariableType.CATEGORICAL,
            GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS,
            GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS,
            GradientBoosting.VariableType.CATEGORICAL,
            GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS
        };
        double[][] trainingData = {
            {0.0, 0.4223019897, 1.7540411302, 3.0, 0.763836258},
            {0.0, 0.0907259332, 0.8722643796, 2.0, 1.859006285},
            {0.0, 0.1384744535, 0.838324877, 1.0, 0.249729405},
            {1.0, 0.5435024537, 1.2359190206, 4.0, 0.831992314},
            {0.0, 0.8359154933, 1.8527500411, 1.0, 1.089201049},
            {1.0, 0.3577950741, 0.3652825342, 3.0, 2.204364955},
            {1.0, 0.6799094002, 0.6610595905, 3.0, 1.44730419},
            {0.0, 0.5821297709, 1.6180879478, 1.0, 2.957565282},
            {1.0, 0.8229457375, 1.0201675948, 3.0, 2.872570117},
            {0.0, 0.0633462721, 0.4140600134, 1.0, 0.63906323},
            {1.0, 0.1019134156, 0.0677204356, 3.0, 1.493447564},
            {0.0, 0.1551713238, 1.541201456, 3.0, 1.90219884},
            {1.0, 0.8273822817, 0.2114979578, 3.0, 2.855730173},
            {0.0, 0.7955570114, 1.8757067556, 2.0, 2.930132627},
            {0.0, 0.6537275917, 1.2139678737, 2.0, 1.535853243},
            {1.0, 0.1243124125, 1.5130919744, 4.0, 2.733670775},
            {0.0, 0.2163864174, 0.7051185896, 2.0, 2.755841087},
            {0.0, 0.2522670308, 1.2821007571, 2.0, 0.342119491},
            {0.0, 0.8677104027, 1.9003869346, 2.0, 2.454376481},
            {1.0, 0.8670932774, 0.7993045617, 4.0, 2.732812615},
            {0.0, 0.5384287981, 0.1856947718, 1.0, 1.838702635},
            {0.0, 0.7236269342, 0.4993310347, 1.0, 1.030699128},
            {0.0, 0.0789361731, 1.011216166, 1.0, 2.539607478},
            {1.0, 0.7631686032, 0.0536725423, 2.0, 1.401761686},
            {0.0, 0.1157020777, 0.0123261618, 1.0, 2.098372295},
            {1.0, 0.1451248352, 1.9153951635, 3.0, 0.492650534},
            {1.0, 0.8497178114, 1.80941298, 4.0, 2.653985489},
            {0.0, 0.8027864883, 1.2631045617, 3.0, 2.716214291},
            {0.0, 0.798560373, 0.6872106791, 2.0, 2.763023936},
            {1.0, 0.1816879204, 0.4323868025, 4.0, 0.098090197},
            {1.0, 0.6301239238, 0.3670980479, 3.0, 0.02313788},
            {1.0, 0.0411311248, 0.0173408454, 3.0, 1.994786958},
            {1.0, 0.0427366099, 0.8114635572, 3.0, 2.966069741},
            {1.0, 0.4107826762, 0.1929467283, 4.0, 0.573832348},
            {0.0, 0.9441903098, 0.0729898885, 1.0, 1.710992303},
            {1.0, 0.3597549822, 0.2799857073, 2.0, 0.969428934},
            {0.0, 0.3741368004, 1.6052779425, 2.0, 1.866030486},
            {0.0, 0.3515911719, 0.3383029872, 1.0, 2.639469598},
            {0.0, 0.9184092905, 1.7116801264, 1.0, 1.380178652},
            {1.0, 0.77803064, 1.9830028405, 3.0, 1.834021992},
            {0.0, 0.573786814, 0.0258851023, 1.0, 1.52130144},
            {1.0, 0.3279244492, 0.6977945678, 4.0, 1.322451157},
            {0.0, 0.7924819048, 0.3694838509, 1.0, 2.369654865},
            {0.0, 0.9787846403, 1.1470323382, 2.0, 0.037156113},
            {1.0, 0.6910662795, 0.1019420708, 2.0, 2.58588334},
            {0.0, 0.1367050812, 0.6635301332, 2.0, 0.368273583},
            {0.0, 0.2826360366, 1.4468787988, 1.0, 2.705811968},
            {0.0, 0.4524727969, 0.7885378413, 2.0, 0.851228449},
            {0.0, 0.5118664701, 1.061143666, 1.0, 0.249325278},
            {0.0, 0.9965170731, 0.2068265025, 2.0, 0.9210639},
            {1.0, 0.7801500652, 1.565742691, 4.0, 1.827419217},
            {0.0, 0.2906187973, 1.7036567871, 2.0, 2.842997725},
            {0.0, 0.1753704017, 0.7124397112, 2.0, 1.262811961},
            {1.0, 0.7796778064, 0.3478030777, 3.0, 0.90719801},
            {1.0, 0.3889356288, 1.1771452101, 4.0, 1.298438454},
            {0.0, 0.9374473374, 1.1879778663, 1.0, 1.854424331},
            {1.0, 0.1939157653, 0.093336341, 4.0, 0.166025681},
            {1.0, 0.2023756928, 0.0623724433, 3.0, 0.536441906},
            {0.0, 0.1691352043, 1.1587338657, 2.0, 2.15494096},
            {1.0, 0.0921523357, 0.2247394961, 3.0, 2.006995301},
            {0.0, 0.819186907, 0.0392292971, 1.0, 1.282159743},
            {0.0, 0.9458126165, 1.5268264762, 1.0, 1.960050194},
            {0.0, 0.1373939656, 1.8025095677, 2.0, 0.633624267},
            {0.0, 0.0555424779, 0.5022063241, 2.0, 0.639495004},
            {1.0, 0.3581428374, 1.4436954968, 3.0, 1.408938169},
            {1.0, 0.1189418568, 0.8011626904, 4.0, 0.210266769},
            {1.0, 0.5782070206, 1.58215921, 3.0, 2.648622607},
            {0.0, 0.460689794, 0.0704823257, 1.0, 1.45671379},
            {0.0, 0.6959878858, 0.2245675903, 2.0, 1.849515461},
            {0.0, 0.1930288749, 0.6296302159, 2.0, 2.597390946},
            {0.0, 0.4912149447, 0.0713489084, 1.0, 0.426487798},
            {0.0, 0.3496920248, 1.0135462089, 1.0, 2.962295362},
            {1.0, 0.7716284667, 0.5387295927, 4.0, 0.736709363},
            {1.0, 0.3463061263, 0.7819578522, 4.0, 1.597238498},
            {1.0, 0.6897138762, 1.2793166582, 4.0, 2.376281484},
            {0.0, 0.2818824656, 1.4379718141, 3.0, 2.627468417},
            {0.0, 0.5659798421, 1.6243568249, 1.0, 1.624809581},
            {0.0, 0.7965560518, 0.3933029529, 2.0, 0.415849269},
            {0.0, 0.9156922165, 1.0465683565, 1.0, 2.802914008},
            {0.0, 0.8299879942, 1.2237155279, 1.0, 2.611676934},
            {0.0, 0.0241912066, 1.9213823564, 1.0, 0.659596571},
            {0.0, 0.0948590154, 0.3609640412, 1.0, 1.287687748},
            {0.0, 0.230467916, 1.9421709292, 3.0, 2.290064565},
            {0.0, 0.2209760561, 0.4812708795, 1.0, 1.862393057},
            {0.0, 0.4704530933, 0.2644400774, 1.0, 1.960189529},
            {1.0, 0.1986645423, 0.48924731, 2.0, 0.333790415},
            {0.0, 0.9201823308, 1.4247304946, 1.0, 0.367654009},
            {1.0, 0.8118424334, 0.1017034058, 2.0, 2.001390385},
            {1.0, 0.1347265388, 0.1362061207, 3.0, 1.151431168},
            {0.0, 0.9884603191, 1.5700038988, 2.0, 0.717332943},
            {0.0, 0.1964012324, 0.4306495111, 1.0, 1.689056823},
            {1.0, 0.4031848807, 1.1251849262, 4.0, 1.977734922},
            {1.0, 0.0341882701, 0.3717348906, 4.0, 1.830587439},
            {0.0, 0.5073120815, 1.7860476542, 3.0, 0.142862822},
            {0.0, 0.6363195451, 0.6631249222, 2.0, 1.211148724},
            {1.0, 0.1642774614, 1.1963615627, 3.0, 0.843113448},
            {0.0, 0.0945515088, 1.8669327218, 1.0, 2.417198514},
            {0.0, 0.2364508687, 1.4035215094, 2.0, 2.964026097},
            {1.0, 0.7490112646, 0.1778408242, 4.0, 2.343119453},
            {1.0, 0.5193473259, 0.3090019161, 3.0, 1.300277323}};

        // NOTE(review): the second argument (4) is presumably the index of
        // the response column -- confirm against the GradientBoosting docs.
        GradientBoosting gb
                = new GradientBoosting(trainingData, 4, variableTypes);

        // Candidate iteration counts; the results table below pairs each
        // risk value with the entry at the same index of this array.
        int[] cvIterations = {10, 30, 50, 100, 500};
        gb.setIterationsArray(cvIterations);
        gb.setShrinkageParameter(0.05);
        gb.setSampleSizeProportion(0.5);
        // Fixed seeds on both the model and the cross-validation driver.
        gb.setRandomObject(new Random(123457));

        CrossValidation cv = new CrossValidation(gb);
        cv.setRandomObject(new Random(123457));
        cv.crossValidate();
        double cvError = cv.getCrossValidatedError();
        double[] cvRisks = cv.getRiskValues();
        double[] cvRiskStdErrors = cv.getRiskStandardErrors();

        System.out.println("\nModel \t Number Of Iterations\t   CV Risk       "
                + "SE of CV Risk ");
        for (int k = 0; k < cvRisks.length; k++) {
            System.out.printf("  %d             %3d                 %6.2f          %7.2f\n",
                    k, cvIterations[k], cvRisks[k], cvRiskStdErrors[k]);
        }

        // Pair the minimum risk with the standard error of the SAME
        // candidate. Using element 0 would add the standard error of the
        // first candidate regardless of which candidate has the minimum risk.
        int minRiskIndex = indexOfMinimum(cvRisks);

        System.out.printf("Minimum CV Risk Values: %5.4f\n", cvError);
        System.out.printf("Minimum CV Risk + Standard error: %5.4f\n",
                (cvError + cvRiskStdErrors[minRiskIndex]));
    }

    /**
     * Returns the index of the smallest element of {@code values}; ties
     * resolve to the lowest index.
     *
     * @param values a non-empty array
     * @return index of the minimum element
     * @throws IllegalArgumentException if {@code values} is empty
     */
    private static int indexOfMinimum(double[] values) {
        if (values.length == 0) {
            throw new IllegalArgumentException("values must not be empty");
        }
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] < values[best]) {
                best = i;
            }
        }
        return best;
    }
}

Output


Model 	 Number Of Iterations	   CV Risk       SE of CV Risk 
  0              10                 106.89           926.69
  1              30                 102.32           850.83
  2              50                  98.43           785.00
  3             100                 111.95          1016.45
  4             500                 124.51          1257.52
Minimum CV Risk Values: 98.4270
Minimum CV Risk + Standard error: 1025.1126
Link to Java source.