This example applies cross-validation to the number-of-iterations parameter in stochastic gradient boosting. Among the candidate values tried, the number of iterations with the minimum cross-validated risk estimate is 50.
import com.imsl.datamining.CrossValidation; import com.imsl.datamining.GradientBoosting; import com.imsl.stat.Random; public class GradientBoostingEx3 { public static void main(String[] args) throws Exception { GradientBoosting.VariableType[] VarType = { GradientBoosting.VariableType.CATEGORICAL, GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS, GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS, GradientBoosting.VariableType.CATEGORICAL, GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS }; double[][] trainingData = { {0.0, 0.4223019897, 1.7540411302, 3.0, 0.763836258}, {0.0, 0.0907259332, 0.8722643796, 2.0, 1.859006285}, {0.0, 0.1384744535, 0.838324877, 1.0, 0.249729405}, {1.0, 0.5435024537, 1.2359190206, 4.0, 0.831992314}, {0.0, 0.8359154933, 1.8527500411, 1.0, 1.089201049}, {1.0, 0.3577950741, 0.3652825342, 3.0, 2.204364955}, {1.0, 0.6799094002, 0.6610595905, 3.0, 1.44730419}, {0.0, 0.5821297709, 1.6180879478, 1.0, 2.957565282}, {1.0, 0.8229457375, 1.0201675948, 3.0, 2.872570117}, {0.0, 0.0633462721, 0.4140600134, 1.0, 0.63906323}, {1.0, 0.1019134156, 0.0677204356, 3.0, 1.493447564}, {0.0, 0.1551713238, 1.541201456, 3.0, 1.90219884}, {1.0, 0.8273822817, 0.2114979578, 3.0, 2.855730173}, {0.0, 0.7955570114, 1.8757067556, 2.0, 2.930132627}, {0.0, 0.6537275917, 1.2139678737, 2.0, 1.535853243}, {1.0, 0.1243124125, 1.5130919744, 4.0, 2.733670775}, {0.0, 0.2163864174, 0.7051185896, 2.0, 2.755841087}, {0.0, 0.2522670308, 1.2821007571, 2.0, 0.342119491}, {0.0, 0.8677104027, 1.9003869346, 2.0, 2.454376481}, {1.0, 0.8670932774, 0.7993045617, 4.0, 2.732812615}, {0.0, 0.5384287981, 0.1856947718, 1.0, 1.838702635}, {0.0, 0.7236269342, 0.4993310347, 1.0, 1.030699128}, {0.0, 0.0789361731, 1.011216166, 1.0, 2.539607478}, {1.0, 0.7631686032, 0.0536725423, 2.0, 1.401761686}, {0.0, 0.1157020777, 0.0123261618, 1.0, 2.098372295}, {1.0, 0.1451248352, 1.9153951635, 3.0, 0.492650534}, {1.0, 0.8497178114, 1.80941298, 4.0, 2.653985489}, {0.0, 0.8027864883, 1.2631045617, 3.0, 
2.716214291}, {0.0, 0.798560373, 0.6872106791, 2.0, 2.763023936}, {1.0, 0.1816879204, 0.4323868025, 4.0, 0.098090197}, {1.0, 0.6301239238, 0.3670980479, 3.0, 0.02313788}, {1.0, 0.0411311248, 0.0173408454, 3.0, 1.994786958}, {1.0, 0.0427366099, 0.8114635572, 3.0, 2.966069741}, {1.0, 0.4107826762, 0.1929467283, 4.0, 0.573832348}, {0.0, 0.9441903098, 0.0729898885, 1.0, 1.710992303}, {1.0, 0.3597549822, 0.2799857073, 2.0, 0.969428934}, {0.0, 0.3741368004, 1.6052779425, 2.0, 1.866030486}, {0.0, 0.3515911719, 0.3383029872, 1.0, 2.639469598}, {0.0, 0.9184092905, 1.7116801264, 1.0, 1.380178652}, {1.0, 0.77803064, 1.9830028405, 3.0, 1.834021992}, {0.0, 0.573786814, 0.0258851023, 1.0, 1.52130144}, {1.0, 0.3279244492, 0.6977945678, 4.0, 1.322451157}, {0.0, 0.7924819048, 0.3694838509, 1.0, 2.369654865}, {0.0, 0.9787846403, 1.1470323382, 2.0, 0.037156113}, {1.0, 0.6910662795, 0.1019420708, 2.0, 2.58588334}, {0.0, 0.1367050812, 0.6635301332, 2.0, 0.368273583}, {0.0, 0.2826360366, 1.4468787988, 1.0, 2.705811968}, {0.0, 0.4524727969, 0.7885378413, 2.0, 0.851228449}, {0.0, 0.5118664701, 1.061143666, 1.0, 0.249325278}, {0.0, 0.9965170731, 0.2068265025, 2.0, 0.9210639}, {1.0, 0.7801500652, 1.565742691, 4.0, 1.827419217}, {0.0, 0.2906187973, 1.7036567871, 2.0, 2.842997725}, {0.0, 0.1753704017, 0.7124397112, 2.0, 1.262811961}, {1.0, 0.7796778064, 0.3478030777, 3.0, 0.90719801}, {1.0, 0.3889356288, 1.1771452101, 4.0, 1.298438454}, {0.0, 0.9374473374, 1.1879778663, 1.0, 1.854424331}, {1.0, 0.1939157653, 0.093336341, 4.0, 0.166025681}, {1.0, 0.2023756928, 0.0623724433, 3.0, 0.536441906}, {0.0, 0.1691352043, 1.1587338657, 2.0, 2.15494096}, {1.0, 0.0921523357, 0.2247394961, 3.0, 2.006995301}, {0.0, 0.819186907, 0.0392292971, 1.0, 1.282159743}, {0.0, 0.9458126165, 1.5268264762, 1.0, 1.960050194}, {0.0, 0.1373939656, 1.8025095677, 2.0, 0.633624267}, {0.0, 0.0555424779, 0.5022063241, 2.0, 0.639495004}, {1.0, 0.3581428374, 1.4436954968, 3.0, 1.408938169}, {1.0, 0.1189418568, 0.8011626904, 4.0, 
0.210266769}, {1.0, 0.5782070206, 1.58215921, 3.0, 2.648622607}, {0.0, 0.460689794, 0.0704823257, 1.0, 1.45671379}, {0.0, 0.6959878858, 0.2245675903, 2.0, 1.849515461}, {0.0, 0.1930288749, 0.6296302159, 2.0, 2.597390946}, {0.0, 0.4912149447, 0.0713489084, 1.0, 0.426487798}, {0.0, 0.3496920248, 1.0135462089, 1.0, 2.962295362}, {1.0, 0.7716284667, 0.5387295927, 4.0, 0.736709363}, {1.0, 0.3463061263, 0.7819578522, 4.0, 1.597238498}, {1.0, 0.6897138762, 1.2793166582, 4.0, 2.376281484}, {0.0, 0.2818824656, 1.4379718141, 3.0, 2.627468417}, {0.0, 0.5659798421, 1.6243568249, 1.0, 1.624809581}, {0.0, 0.7965560518, 0.3933029529, 2.0, 0.415849269}, {0.0, 0.9156922165, 1.0465683565, 1.0, 2.802914008}, {0.0, 0.8299879942, 1.2237155279, 1.0, 2.611676934}, {0.0, 0.0241912066, 1.9213823564, 1.0, 0.659596571}, {0.0, 0.0948590154, 0.3609640412, 1.0, 1.287687748}, {0.0, 0.230467916, 1.9421709292, 3.0, 2.290064565}, {0.0, 0.2209760561, 0.4812708795, 1.0, 1.862393057}, {0.0, 0.4704530933, 0.2644400774, 1.0, 1.960189529}, {1.0, 0.1986645423, 0.48924731, 2.0, 0.333790415}, {0.0, 0.9201823308, 1.4247304946, 1.0, 0.367654009}, {1.0, 0.8118424334, 0.1017034058, 2.0, 2.001390385}, {1.0, 0.1347265388, 0.1362061207, 3.0, 1.151431168}, {0.0, 0.9884603191, 1.5700038988, 2.0, 0.717332943}, {0.0, 0.1964012324, 0.4306495111, 1.0, 1.689056823}, {1.0, 0.4031848807, 1.1251849262, 4.0, 1.977734922}, {1.0, 0.0341882701, 0.3717348906, 4.0, 1.830587439}, {0.0, 0.5073120815, 1.7860476542, 3.0, 0.142862822}, {0.0, 0.6363195451, 0.6631249222, 2.0, 1.211148724}, {1.0, 0.1642774614, 1.1963615627, 3.0, 0.843113448}, {0.0, 0.0945515088, 1.8669327218, 1.0, 2.417198514}, {0.0, 0.2364508687, 1.4035215094, 2.0, 2.964026097}, {1.0, 0.7490112646, 0.1778408242, 4.0, 2.343119453}, {1.0, 0.5193473259, 0.3090019161, 3.0, 1.300277323}}; GradientBoosting gb = new GradientBoosting(trainingData, 4, VarType); int[] cvIterations = {10, 30, 50, 100, 500}; gb.setIterationsArray(cvIterations); gb.setShrinkageParameter(0.05); 
gb.setSampleSizeProportion(0.5); gb.setRandomObject(new Random(123457)); CrossValidation cv = new CrossValidation(gb); cv.setRandomObject(new Random(123457)); cv.crossValidate(); double cvError = cv.getCrossValidatedError(); double[] Rcv = cv.getRiskValues(); double[] SERcv = cv.getRiskStandardErrors(); System.out.println("\nModel \t Number Of Iterations\t CV Risk " + "SE of CV Risk "); for (int k = 0; k < Rcv.length; k++) { System.out.printf(" %d %3d %6.2f %7.2f\n", k, cvIterations[k], Rcv[k], SERcv[k]); } System.out.printf("Minimum CV Risk Values: %5.4f\n", cvError); System.out.printf("Minimum CV Risk + Standard error: %5.4f\n", (cvError + SERcv[0])); } }
Output:

Model   Number Of Iterations   CV Risk   SE of CV Risk
  0             10              106.89       926.69
  1             30              102.32       850.83
  2             50               98.43       785.00
  3            100              111.95      1016.45
  4            500              124.51      1257.52
Minimum CV Risk Values: 98.4270
Minimum CV Risk + Standard error: 1025.1126

Link to Java source.