This example uses stochastic gradient boosting to obtain fitted values for a continuous response variable in a small data set with six predictor variables (three continuous and three categorical).
import com.imsl.datamining.GradientBoosting; import com.imsl.stat.Random; public class GradientBoostingEx1 { public static void main(String[] args) throws Exception { double[][] XY = { {4.45617685, 0.8587425048, 1.2705688183, 0.0, 0.0, 1.0, 0.836626959}, {3.01895357, 0.8928761308, 1.3886538362, 2.0, 1.0, 2.0, 2.155131825}, {5.16899757, 0.7385954093, 1.5773203815, 0.0, 4.0, 2.0, 0.075368922}, {-0.23062048, 0.6227398487, 0.0228797458, 3.0, 4.0, 2.0, 0.070793233}, {2.43144968, 0.8519553537, 1.2141886768, 2.0, 4.0, 2.0, 0.762200702}, {2.28255119, 0.5578103897, 0.9185446175, 2.0, 4.0, 2.0, 0.085492814}, {4.51650903, 0.4178302658, 1.3686663737, 0.0, 0.0, 0.0, 2.573941051}, {5.42996967, 0.9829705667, 0.7817731784, 0.0, 5.0, 1.0, 0.865016054}, {0.99551212, 0.3859238869, 0.2746516233, 3.0, 4.0, 0.0, 1.908151819}, {1.23525017, 0.4165328839, 1.3154437956, 3.0, 4.0, 2.0, 2.752358041}, {1.51599306, 0.2008399745, 0.9003028921, 3.0, 0.0, 2.0, 1.437127559}, {2.72854297, 0.2072261081, 1.2282209327, 2.0, 5.0, 2.0, 0.68596562}, {3.06956138, 0.9067490781, 0.8283077031, 2.0, 0.0, 2.0, 2.862403627}, {1.81659279, 0.4506153886, 1.2822537781, 3.0, 4.0, 2.0, 1.710525684}, {3.75978142, 0.2638894715, 0.4995447062, 0.0, 1.0, 1.0, 1.077172402}, {5.72383445, 0.7682430062, 1.4758595745, 0.0, 3.0, 1.0, 2.365233736}, {3.78155015, 0.6888140934, 0.4809393724, 0.0, 0.0, 1.0, 1.061246069}, {3.60023233, 0.8470419827, 1.6149122352, 1.0, 1.0, 0.0, 0.01120048}, {4.30238917, 0.9484412405, 1.6122899544, 1.0, 4.0, 2.0, 0.782038861}, {-0.19206757, 0.7674867723, 0.01665624, 3.0, 5.0, 2.0, 2.924944949}, {3.03246318, 0.8747456241, 1.6051767552, 2.0, 1.0, 0.0, 2.233971364}, {1.56652306, 0.0947128241, 1.470864601, 3.0, 0.0, 1.0, 1.851705944}, {2.77490671, 0.1347932827, 1.3693161067, 1.0, 2.0, 0.0, 0.795709459}, {1.05042043, 0.258093959, 0.4679728113, 3.0, 5.0, 0.0, 2.897785557}, {2.73366469, 0.152943752, 0.5244769375, 1.0, 4.0, 2.0, 2.712871963}, {1.78996951, 0.7921472492, 0.4686144991, 2.0, 4.0, 1.0, 1.295327727}, 
{1.10343272, 0.123231777, 0.563989053, 2.0, 4.0, 1.0, 0.510414582}, {1.70883743, 0.1931027549, 1.8561577178, 3.0, 5.0, 1.0, 0.165721288}, {2.17977731, 0.316932481, 1.3376214528, 2.0, 2.0, 0.0, 2.366607214}, {2.46127675, 0.9601344266, 0.2090187217, 1.0, 3.0, 1.0, 0.846218965}, {1.92249547, 0.1104206559, 1.739415036, 3.0, 0.0, 0.0, 0.652622544}, {5.81907137, 0.7049566596, 1.6238740934, 0.0, 3.0, 0.0, 1.685337845}, {2.04774497, 0.0480224835, 0.7510998738, 2.0, 5.0, 2.0, 1.400641323}, {4.54023907, 0.0557708007, 1.0864350675, 0.0, 1.0, 1.0, 1.630408823}, {3.66100874, 0.2939440177, 0.9709178614, 0.0, 1.0, 0.0, 0.06970193}, {4.39253655, 0.0982369843, 1.2492676578, 0.0, 2.0, 2.0, 0.138188998}, {3.23303353, 0.3775206071, 0.2937129182, 0.0, 0.0, 2.0, 1.070823081}, {3.13800098, 0.7891691434, 1.90897633, 2.0, 3.0, 0.0, 1.240732062}, {1.49034639, 0.2456938969, 0.9157859818, 3.0, 5.0, 0.0, 0.850803277}, {0.09486277, 0.1240615626, 0.3891524528, 3.0, 5.0, 0.0, 2.532516038}, {3.74460501, 0.0181218453, 1.4921644945, 1.0, 2.0, 1.0, 1.92839241}, {3.24158796, 0.9203409508, 1.1644667462, 2.0, 3.0, 1.0, 1.956283022}, {1.97796767, 0.5977597698, 0.5501609747, 2.0, 5.0, 2.0, 0.39384095}, {4.15214037, 0.1433333508, 1.4292114358, 1.0, 0.0, 0.0, 1.114095218}, {0.7799787, 0.8539819908, 0.7039108537, 3.0, 0.0, 1.0, 1.468978726}, {2.01869009, 0.8919721926, 1.1436212659, 3.0, 4.0, 1.0, 2.09256257}, {0.56311561, 0.0899261576, 0.7989077698, 3.0, 5.0, 0.0, 0.195650739}, {4.74296429, 0.9625684835, 1.5732420743, 0.0, 3.0, 2.0, 2.685061853}, {2.97981809, 0.5511086562, 1.6053283028, 2.0, 5.0, 2.0, 0.906810926}, {2.82187135, 0.3869563073, 0.9321342241, 1.0, 5.0, 1.0, 0.756223386}, {5.24390592, 0.3500950718, 1.7769328682, 0.0, 3.0, 2.0, 1.328165314}, {3.17307157, 0.8798056154, 1.4647966106, 2.0, 5.0, 1.0, 0.561835038}, {0.78246075, 0.1472158518, 0.4658273738, 2.0, 0.0, 0.0, 1.317240539}, {1.57827027, 0.3415432149, 0.7513634153, 2.0, 2.0, 0.0, 1.502675544}, {0.84104905, 0.1501226462, 0.9332020828, 3.0, 1.0, 
2.0, 1.083374695}, {2.63627352, 0.1707233109, 1.1676406977, 2.0, 3.0, 0.0, 2.236639737}, {1.30863625, 0.2616807753, 0.8342161868, 3.0, 2.0, 2.0, 1.778402721}, {2.7313073, 0.9616109401, 1.596915911, 3.0, 3.0, 1.0, 0.303127344}, {3.56848173, 0.4072918599, 1.5345127448, 1.0, 2.0, 2.0, 1.47452504}, {5.40152982, 0.7796053565, 1.3659530994, 0.0, 4.0, 1.0, 0.484531098}, {3.94901823, 0.5052344366, 1.9319026601, 1.0, 2.0, 0.0, 2.504392843}}; GradientBoosting.VariableType[] VarType = { GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS, GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS, GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS, GradientBoosting.VariableType.CATEGORICAL, GradientBoosting.VariableType.CATEGORICAL, GradientBoosting.VariableType.CATEGORICAL, GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS }; GradientBoosting gb = new GradientBoosting(XY,0,VarType); gb.setShrinkageParameter(0.05); gb.setSampleSizeProportion(0.5); gb.setRandomObject(new Random(123457)); gb.fitModel(); double[] fittedValues = gb.predict(); System.out.println("Fitted Values vs Actuals"); for (int i = 0; i < fittedValues.length; i++) { System.out.printf("%5.3f, %5.3f\n", fittedValues[i], XY[i][0]); } System.out.printf("Loss Value: %5.5f\n", gb.getLossValue()); } }
Fitted Values vs Actuals
4.341, 4.456
2.865, 3.019
4.485, 5.169
1.217, -0.231
2.757, 2.431
2.263, 2.283
4.212, 4.517
3.829, 5.430
1.177, 0.996
1.878, 1.235
1.473, 1.516
2.628, 2.729
2.352, 3.070
1.878, 1.817
3.511, 3.760
4.485, 5.724
3.551, 3.782
3.622, 3.600
3.622, 4.302
1.306, -0.192
2.902, 3.032
2.022, 1.567
3.349, 2.775
1.177, 1.050
2.648, 2.734
2.056, 1.790
1.927, 1.103
2.022, 1.709
2.628, 2.180
2.777, 2.461
2.022, 1.922
4.485, 5.819
2.039, 2.048
4.212, 4.540
3.929, 3.661
4.212, 4.393
3.511, 3.233
2.902, 3.138
1.473, 1.490
1.177, 0.095
3.493, 3.745
2.757, 3.242
1.968, 1.978
3.493, 4.152
1.306, 0.780
2.007, 2.019
1.366, 0.563
4.485, 4.743
2.813, 2.980
2.943, 2.822
4.356, 5.244
2.902, 3.173
1.927, 0.782
2.039, 1.578
1.473, 0.841
2.628, 2.636
1.473, 1.309
2.151, 2.731
3.493, 3.568
4.341, 5.402
3.533, 3.949
Loss Value: 0.35968

Link to Java source.