Example 1: Gradient Boosting

This example uses stochastic gradient boosting to obtain fitted values for a continuous regression response variable on a small data set with 6 predictor variables.

import com.imsl.datamining.GradientBoosting;
import com.imsl.stat.Random;

/**
 * Stochastic gradient boosting regression example.
 *
 * <p>Fits a gradient boosting model to a small data set whose first
 * column is a continuous response and whose remaining six columns are
 * predictor variables (three continuous, three categorical), then
 * prints each fitted value next to the actual response along with the
 * final loss value.
 */
public class GradientBoostingEx1 {

    public static void main(String[] args) throws Exception {

        // Training data. Column 0 is the response; columns 1-6 are the
        // predictors in the order described by varTypes below.
        double[][] xy = {
            {4.45617685, 0.8587425048, 1.2705688183, 0.0, 0.0, 1.0, 0.836626959},
            {3.01895357, 0.8928761308, 1.3886538362, 2.0, 1.0, 2.0, 2.155131825},
            {5.16899757, 0.7385954093, 1.5773203815, 0.0, 4.0, 2.0, 0.075368922},
            {-0.23062048, 0.6227398487, 0.0228797458, 3.0, 4.0, 2.0, 0.070793233},
            {2.43144968, 0.8519553537, 1.2141886768, 2.0, 4.0, 2.0, 0.762200702},
            {2.28255119, 0.5578103897, 0.9185446175, 2.0, 4.0, 2.0, 0.085492814},
            {4.51650903, 0.4178302658, 1.3686663737, 0.0, 0.0, 0.0, 2.573941051},
            {5.42996967, 0.9829705667, 0.7817731784, 0.0, 5.0, 1.0, 0.865016054},
            {0.99551212, 0.3859238869, 0.2746516233, 3.0, 4.0, 0.0, 1.908151819},
            {1.23525017, 0.4165328839, 1.3154437956, 3.0, 4.0, 2.0, 2.752358041},
            {1.51599306, 0.2008399745, 0.9003028921, 3.0, 0.0, 2.0, 1.437127559},
            {2.72854297, 0.2072261081, 1.2282209327, 2.0, 5.0, 2.0, 0.68596562},
            {3.06956138, 0.9067490781, 0.8283077031, 2.0, 0.0, 2.0, 2.862403627},
            {1.81659279, 0.4506153886, 1.2822537781, 3.0, 4.0, 2.0, 1.710525684},
            {3.75978142, 0.2638894715, 0.4995447062, 0.0, 1.0, 1.0, 1.077172402},
            {5.72383445, 0.7682430062, 1.4758595745, 0.0, 3.0, 1.0, 2.365233736},
            {3.78155015, 0.6888140934, 0.4809393724, 0.0, 0.0, 1.0, 1.061246069},
            {3.60023233, 0.8470419827, 1.6149122352, 1.0, 1.0, 0.0, 0.01120048},
            {4.30238917, 0.9484412405, 1.6122899544, 1.0, 4.0, 2.0, 0.782038861},
            {-0.19206757, 0.7674867723, 0.01665624, 3.0, 5.0, 2.0, 2.924944949},
            {3.03246318, 0.8747456241, 1.6051767552, 2.0, 1.0, 0.0, 2.233971364},
            {1.56652306, 0.0947128241, 1.470864601, 3.0, 0.0, 1.0, 1.851705944},
            {2.77490671, 0.1347932827, 1.3693161067, 1.0, 2.0, 0.0, 0.795709459},
            {1.05042043, 0.258093959, 0.4679728113, 3.0, 5.0, 0.0, 2.897785557},
            {2.73366469, 0.152943752, 0.5244769375, 1.0, 4.0, 2.0, 2.712871963},
            {1.78996951, 0.7921472492, 0.4686144991, 2.0, 4.0, 1.0, 1.295327727},
            {1.10343272, 0.123231777, 0.563989053, 2.0, 4.0, 1.0, 0.510414582},
            {1.70883743, 0.1931027549, 1.8561577178, 3.0, 5.0, 1.0, 0.165721288},
            {2.17977731, 0.316932481, 1.3376214528, 2.0, 2.0, 0.0, 2.366607214},
            {2.46127675, 0.9601344266, 0.2090187217, 1.0, 3.0, 1.0, 0.846218965},
            {1.92249547, 0.1104206559, 1.739415036, 3.0, 0.0, 0.0, 0.652622544},
            {5.81907137, 0.7049566596, 1.6238740934, 0.0, 3.0, 0.0, 1.685337845},
            {2.04774497, 0.0480224835, 0.7510998738, 2.0, 5.0, 2.0, 1.400641323},
            {4.54023907, 0.0557708007, 1.0864350675, 0.0, 1.0, 1.0, 1.630408823},
            {3.66100874, 0.2939440177, 0.9709178614, 0.0, 1.0, 0.0, 0.06970193},
            {4.39253655, 0.0982369843, 1.2492676578, 0.0, 2.0, 2.0, 0.138188998},
            {3.23303353, 0.3775206071, 0.2937129182, 0.0, 0.0, 2.0, 1.070823081},
            {3.13800098, 0.7891691434, 1.90897633, 2.0, 3.0, 0.0, 1.240732062},
            {1.49034639, 0.2456938969, 0.9157859818, 3.0, 5.0, 0.0, 0.850803277},
            {0.09486277, 0.1240615626, 0.3891524528, 3.0, 5.0, 0.0, 2.532516038},
            {3.74460501, 0.0181218453, 1.4921644945, 1.0, 2.0, 1.0, 1.92839241},
            {3.24158796, 0.9203409508, 1.1644667462, 2.0, 3.0, 1.0, 1.956283022},
            {1.97796767, 0.5977597698, 0.5501609747, 2.0, 5.0, 2.0, 0.39384095},
            {4.15214037, 0.1433333508, 1.4292114358, 1.0, 0.0, 0.0, 1.114095218},
            {0.7799787, 0.8539819908, 0.7039108537, 3.0, 0.0, 1.0, 1.468978726},
            {2.01869009, 0.8919721926, 1.1436212659, 3.0, 4.0, 1.0, 2.09256257},
            {0.56311561, 0.0899261576, 0.7989077698, 3.0, 5.0, 0.0, 0.195650739},
            {4.74296429, 0.9625684835, 1.5732420743, 0.0, 3.0, 2.0, 2.685061853},
            {2.97981809, 0.5511086562, 1.6053283028, 2.0, 5.0, 2.0, 0.906810926},
            {2.82187135, 0.3869563073, 0.9321342241, 1.0, 5.0, 1.0, 0.756223386},
            {5.24390592, 0.3500950718, 1.7769328682, 0.0, 3.0, 2.0, 1.328165314},
            {3.17307157, 0.8798056154, 1.4647966106, 2.0, 5.0, 1.0, 0.561835038},
            {0.78246075, 0.1472158518, 0.4658273738, 2.0, 0.0, 0.0, 1.317240539},
            {1.57827027, 0.3415432149, 0.7513634153, 2.0, 2.0, 0.0, 1.502675544},
            {0.84104905, 0.1501226462, 0.9332020828, 3.0, 1.0, 2.0, 1.083374695},
            {2.63627352, 0.1707233109, 1.1676406977, 2.0, 3.0, 0.0, 2.236639737},
            {1.30863625, 0.2616807753, 0.8342161868, 3.0, 2.0, 2.0, 1.778402721},
            {2.7313073, 0.9616109401, 1.596915911, 3.0, 3.0, 1.0, 0.303127344},
            {3.56848173, 0.4072918599, 1.5345127448, 1.0, 2.0, 2.0, 1.47452504},
            {5.40152982, 0.7796053565, 1.3659530994, 0.0, 4.0, 1.0, 0.484531098},
            {3.94901823, 0.5052344366, 1.9319026601, 1.0, 2.0, 0.0, 2.504392843}};

        // One type entry per column of xy, in column order. The
        // response (column 0) is quantitative continuous.
        GradientBoosting.VariableType[] varTypes = {
            GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS,
            GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS,
            GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS,
            GradientBoosting.VariableType.CATEGORICAL,
            GradientBoosting.VariableType.CATEGORICAL,
            GradientBoosting.VariableType.CATEGORICAL,
            GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS
        };

        // Response variable is column index 0 of xy.
        GradientBoosting boosting = new GradientBoosting(xy, 0, varTypes);
        boosting.setShrinkageParameter(0.05);         // learning-rate shrinkage
        boosting.setSampleSizeProportion(0.5);        // stochastic subsample fraction
        boosting.setRandomObject(new Random(123457)); // fixed seed for reproducible output
        boosting.fitModel();
        double[] fitted = boosting.predict();

        // Print each fitted value beside the actual response value.
        System.out.println("Fitted Values vs Actuals");
        for (int row = 0; row < fitted.length; row++) {
            System.out.printf("%5.3f, %5.3f\n", fitted[row], xy[row][0]);
        }
        System.out.printf("Loss Value: %5.5f\n", boosting.getLossValue());
    }
}

Output

Fitted Values vs Actuals
4.341, 4.456
2.865, 3.019
4.485, 5.169
1.217, -0.231
2.757, 2.431
2.263, 2.283
4.212, 4.517
3.829, 5.430
1.177, 0.996
1.878, 1.235
1.473, 1.516
2.628, 2.729
2.352, 3.070
1.878, 1.817
3.511, 3.760
4.485, 5.724
3.551, 3.782
3.622, 3.600
3.622, 4.302
1.306, -0.192
2.902, 3.032
2.022, 1.567
3.349, 2.775
1.177, 1.050
2.648, 2.734
2.056, 1.790
1.927, 1.103
2.022, 1.709
2.628, 2.180
2.777, 2.461
2.022, 1.922
4.485, 5.819
2.039, 2.048
4.212, 4.540
3.929, 3.661
4.212, 4.393
3.511, 3.233
2.902, 3.138
1.473, 1.490
1.177, 0.095
3.493, 3.745
2.757, 3.242
1.968, 1.978
3.493, 4.152
1.306, 0.780
2.007, 2.019
1.366, 0.563
4.485, 4.743
2.813, 2.980
2.943, 2.822
4.356, 5.244
2.902, 3.173
1.927, 0.782
2.039, 1.578
1.473, 0.841
2.628, 2.636
1.473, 1.309
2.151, 2.731
3.493, 3.568
4.341, 5.402
3.533, 3.949
Loss Value: 0.35968
Link to Java source.