This example applies cross-validation to the number of iterations parameter in stochastic gradient boosting. The number of iterations with the minimum cross-validated risk estimate is 50.
import com.imsl.datamining.CrossValidation;
import com.imsl.datamining.GradientBoosting;
import com.imsl.stat.Random;
/**
 * Applies cross-validation to select the number of iterations for
 * stochastic gradient boosting regression.
 *
 * <p>The candidate iteration counts are {10, 30, 50, 100, 500}; the
 * cross-validated risk estimate and its standard error are reported for
 * each, along with the minimum risk and the minimum risk plus one
 * standard error (the quantity used by the "one-standard-error" rule).
 */
public class GradientBoostingEx3 {

    public static void main(String[] args) throws Exception {
        // Column types for the training data: columns 0 and 3 are
        // categorical predictors; columns 1, 2, and 4 are continuous.
        // Column 4 is the (continuous) response — see the constructor call below.
        GradientBoosting.VariableType[] VarType = {
            GradientBoosting.VariableType.CATEGORICAL,
            GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS,
            GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS,
            GradientBoosting.VariableType.CATEGORICAL,
            GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS
        };

        double[][] trainingData = {
            {0.0, 0.4223019897, 1.7540411302, 3.0, 0.763836258},
            {0.0, 0.0907259332, 0.8722643796, 2.0, 1.859006285},
            {0.0, 0.1384744535, 0.838324877, 1.0, 0.249729405},
            {1.0, 0.5435024537, 1.2359190206, 4.0, 0.831992314},
            {0.0, 0.8359154933, 1.8527500411, 1.0, 1.089201049},
            {1.0, 0.3577950741, 0.3652825342, 3.0, 2.204364955},
            {1.0, 0.6799094002, 0.6610595905, 3.0, 1.44730419},
            {0.0, 0.5821297709, 1.6180879478, 1.0, 2.957565282},
            {1.0, 0.8229457375, 1.0201675948, 3.0, 2.872570117},
            {0.0, 0.0633462721, 0.4140600134, 1.0, 0.63906323},
            {1.0, 0.1019134156, 0.0677204356, 3.0, 1.493447564},
            {0.0, 0.1551713238, 1.541201456, 3.0, 1.90219884},
            {1.0, 0.8273822817, 0.2114979578, 3.0, 2.855730173},
            {0.0, 0.7955570114, 1.8757067556, 2.0, 2.930132627},
            {0.0, 0.6537275917, 1.2139678737, 2.0, 1.535853243},
            {1.0, 0.1243124125, 1.5130919744, 4.0, 2.733670775},
            {0.0, 0.2163864174, 0.7051185896, 2.0, 2.755841087},
            {0.0, 0.2522670308, 1.2821007571, 2.0, 0.342119491},
            {0.0, 0.8677104027, 1.9003869346, 2.0, 2.454376481},
            {1.0, 0.8670932774, 0.7993045617, 4.0, 2.732812615},
            {0.0, 0.5384287981, 0.1856947718, 1.0, 1.838702635},
            {0.0, 0.7236269342, 0.4993310347, 1.0, 1.030699128},
            {0.0, 0.0789361731, 1.011216166, 1.0, 2.539607478},
            {1.0, 0.7631686032, 0.0536725423, 2.0, 1.401761686},
            {0.0, 0.1157020777, 0.0123261618, 1.0, 2.098372295},
            {1.0, 0.1451248352, 1.9153951635, 3.0, 0.492650534},
            {1.0, 0.8497178114, 1.80941298, 4.0, 2.653985489},
            {0.0, 0.8027864883, 1.2631045617, 3.0, 2.716214291},
            {0.0, 0.798560373, 0.6872106791, 2.0, 2.763023936},
            {1.0, 0.1816879204, 0.4323868025, 4.0, 0.098090197},
            {1.0, 0.6301239238, 0.3670980479, 3.0, 0.02313788},
            {1.0, 0.0411311248, 0.0173408454, 3.0, 1.994786958},
            {1.0, 0.0427366099, 0.8114635572, 3.0, 2.966069741},
            {1.0, 0.4107826762, 0.1929467283, 4.0, 0.573832348},
            {0.0, 0.9441903098, 0.0729898885, 1.0, 1.710992303},
            {1.0, 0.3597549822, 0.2799857073, 2.0, 0.969428934},
            {0.0, 0.3741368004, 1.6052779425, 2.0, 1.866030486},
            {0.0, 0.3515911719, 0.3383029872, 1.0, 2.639469598},
            {0.0, 0.9184092905, 1.7116801264, 1.0, 1.380178652},
            {1.0, 0.77803064, 1.9830028405, 3.0, 1.834021992},
            {0.0, 0.573786814, 0.0258851023, 1.0, 1.52130144},
            {1.0, 0.3279244492, 0.6977945678, 4.0, 1.322451157},
            {0.0, 0.7924819048, 0.3694838509, 1.0, 2.369654865},
            {0.0, 0.9787846403, 1.1470323382, 2.0, 0.037156113},
            {1.0, 0.6910662795, 0.1019420708, 2.0, 2.58588334},
            {0.0, 0.1367050812, 0.6635301332, 2.0, 0.368273583},
            {0.0, 0.2826360366, 1.4468787988, 1.0, 2.705811968},
            {0.0, 0.4524727969, 0.7885378413, 2.0, 0.851228449},
            {0.0, 0.5118664701, 1.061143666, 1.0, 0.249325278},
            {0.0, 0.9965170731, 0.2068265025, 2.0, 0.9210639},
            {1.0, 0.7801500652, 1.565742691, 4.0, 1.827419217},
            {0.0, 0.2906187973, 1.7036567871, 2.0, 2.842997725},
            {0.0, 0.1753704017, 0.7124397112, 2.0, 1.262811961},
            {1.0, 0.7796778064, 0.3478030777, 3.0, 0.90719801},
            {1.0, 0.3889356288, 1.1771452101, 4.0, 1.298438454},
            {0.0, 0.9374473374, 1.1879778663, 1.0, 1.854424331},
            {1.0, 0.1939157653, 0.093336341, 4.0, 0.166025681},
            {1.0, 0.2023756928, 0.0623724433, 3.0, 0.536441906},
            {0.0, 0.1691352043, 1.1587338657, 2.0, 2.15494096},
            {1.0, 0.0921523357, 0.2247394961, 3.0, 2.006995301},
            {0.0, 0.819186907, 0.0392292971, 1.0, 1.282159743},
            {0.0, 0.9458126165, 1.5268264762, 1.0, 1.960050194},
            {0.0, 0.1373939656, 1.8025095677, 2.0, 0.633624267},
            {0.0, 0.0555424779, 0.5022063241, 2.0, 0.639495004},
            {1.0, 0.3581428374, 1.4436954968, 3.0, 1.408938169},
            {1.0, 0.1189418568, 0.8011626904, 4.0, 0.210266769},
            {1.0, 0.5782070206, 1.58215921, 3.0, 2.648622607},
            {0.0, 0.460689794, 0.0704823257, 1.0, 1.45671379},
            {0.0, 0.6959878858, 0.2245675903, 2.0, 1.849515461},
            {0.0, 0.1930288749, 0.6296302159, 2.0, 2.597390946},
            {0.0, 0.4912149447, 0.0713489084, 1.0, 0.426487798},
            {0.0, 0.3496920248, 1.0135462089, 1.0, 2.962295362},
            {1.0, 0.7716284667, 0.5387295927, 4.0, 0.736709363},
            {1.0, 0.3463061263, 0.7819578522, 4.0, 1.597238498},
            {1.0, 0.6897138762, 1.2793166582, 4.0, 2.376281484},
            {0.0, 0.2818824656, 1.4379718141, 3.0, 2.627468417},
            {0.0, 0.5659798421, 1.6243568249, 1.0, 1.624809581},
            {0.0, 0.7965560518, 0.3933029529, 2.0, 0.415849269},
            {0.0, 0.9156922165, 1.0465683565, 1.0, 2.802914008},
            {0.0, 0.8299879942, 1.2237155279, 1.0, 2.611676934},
            {0.0, 0.0241912066, 1.9213823564, 1.0, 0.659596571},
            {0.0, 0.0948590154, 0.3609640412, 1.0, 1.287687748},
            {0.0, 0.230467916, 1.9421709292, 3.0, 2.290064565},
            {0.0, 0.2209760561, 0.4812708795, 1.0, 1.862393057},
            {0.0, 0.4704530933, 0.2644400774, 1.0, 1.960189529},
            {1.0, 0.1986645423, 0.48924731, 2.0, 0.333790415},
            {0.0, 0.9201823308, 1.4247304946, 1.0, 0.367654009},
            {1.0, 0.8118424334, 0.1017034058, 2.0, 2.001390385},
            {1.0, 0.1347265388, 0.1362061207, 3.0, 1.151431168},
            {0.0, 0.9884603191, 1.5700038988, 2.0, 0.717332943},
            {0.0, 0.1964012324, 0.4306495111, 1.0, 1.689056823},
            {1.0, 0.4031848807, 1.1251849262, 4.0, 1.977734922},
            {1.0, 0.0341882701, 0.3717348906, 4.0, 1.830587439},
            {0.0, 0.5073120815, 1.7860476542, 3.0, 0.142862822},
            {0.0, 0.6363195451, 0.6631249222, 2.0, 1.211148724},
            {1.0, 0.1642774614, 1.1963615627, 3.0, 0.843113448},
            {0.0, 0.0945515088, 1.8669327218, 1.0, 2.417198514},
            {0.0, 0.2364508687, 1.4035215094, 2.0, 2.964026097},
            {1.0, 0.7490112646, 0.1778408242, 4.0, 2.343119453},
            {1.0, 0.5193473259, 0.3090019161, 3.0, 1.300277323}};

        // Response is column 4 (continuous), so this is a regression problem.
        GradientBoosting gb = new GradientBoosting(trainingData, 4, VarType);

        // Candidate values of the iterations parameter to cross-validate.
        int[] cvIterations = {10, 30, 50, 100, 500};
        gb.setIterationsArray(cvIterations);
        gb.setShrinkageParameter(0.05);
        gb.setSampleSizeProportion(0.5);
        // Fixed seeds make the stochastic boosting and fold selection reproducible.
        gb.setRandomObject(new Random(123457));

        CrossValidation cv = new CrossValidation(gb);
        cv.setRandomObject(new Random(123457));
        cv.crossValidate();

        double cvError = cv.getCrossValidatedError();
        double[] Rcv = cv.getRiskValues();
        double[] SERcv = cv.getRiskStandardErrors();

        System.out.println("\nModel \t Number Of Iterations\t CV Risk "
                + "SE of CV Risk ");
        for (int k = 0; k < Rcv.length; k++) {
            System.out.printf(" %d %3d %6.2f %7.2f\n",
                    k, cvIterations[k], Rcv[k], SERcv[k]);
        }

        System.out.printf("Minimum CV Risk Values: %5.4f\n", cvError);
        // FIX: the original example added SERcv[0] — the standard error of the
        // FIRST candidate model — regardless of which model achieved the
        // minimum risk. The "minimum risk + one standard error" quantity must
        // use the standard error of the minimum-risk model itself.
        int minIndex = indexOfMinimum(Rcv);
        System.out.printf("Minimum CV Risk + Standard error: %5.4f\n",
                (cvError + SERcv[minIndex]));
    }

    /**
     * Returns the index of the smallest element of {@code values}.
     * Ties are broken in favor of the earliest index.
     *
     * @param values a non-empty array
     * @return the index of the minimum value
     */
    private static int indexOfMinimum(double[] values) {
        int minIndex = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] < values[minIndex]) {
                minIndex = i;
            }
        }
        return minIndex;
    }
}
Output:

Model    Number Of Iterations    CV Risk    SE of CV Risk
  0              10              106.89        926.69
  1              30              102.32        850.83
  2              50               98.43        785.00
  3             100              111.95       1016.45
  4             500              124.51       1257.52
Minimum CV Risk Values: 98.4270
Minimum CV Risk + Standard error: 1025.1126

Link to Java source.