Example 2: Gradient Boosting

This example uses stochastic gradient boosting to obtain fitted probability estimates for a binary response variable with four predictor variables. Estimated probabilities are produced for the training data and for a small test data set. A probability less than or equal to 0.5 is associated with Y=0, while a probability greater than 0.5 is associated with Y=1; applying this threshold to the test probabilities yields the predicted classes on the test data.
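
As a concrete illustration of this decision rule, a column of fitted or predicted probabilities can be converted to class predictions with a short helper such as the sketch below (the method name toClassPredictions is illustrative only and is not part of the JMSL API):

public static double[] toClassPredictions(double[][] probabilities) {
    // Apply the 0.5 threshold described above: probabilities greater
    // than 0.5 map to class 1; all others map to class 0.
    double[] predictions = new double[probabilities.length];
    for (int i = 0; i < probabilities.length; i++) {
        predictions[i] = (probabilities[i][0] > 0.5) ? 1.0 : 0.0;
    }
    return predictions;
}

A helper like this could be applied to the probability matrices returned by getClassProbabilities() and getTestClassProbabilities() in the example below.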


import com.imsl.datamining.GradientBoosting;
import com.imsl.stat.Random;

public class GradientBoostingEx2 {

    public static void main(String[] args) throws Exception {

        GradientBoosting.VariableType[] varType = {
            GradientBoosting.VariableType.CATEGORICAL,
            GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS,
            GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS,
            GradientBoosting.VariableType.CATEGORICAL,
            GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS
        };
        double[][] trainingData = {
            {0.0, 0.4223019897, 1.7540411302, 3.0, 0.763836258},
            {0.0, 0.0907259332, 0.8722643796, 2.0, 1.859006285},
            {0.0, 0.1384744535, 0.838324877, 1.0, 0.249729405},
            {1.0, 0.5435024537, 1.2359190206, 4.0, 0.831992314},
            {0.0, 0.8359154933, 1.8527500411, 1.0, 1.089201049},
            {1.0, 0.3577950741, 0.3652825342, 3.0, 2.204364955},
            {1.0, 0.6799094002, 0.6610595905, 3.0, 1.44730419},
            {0.0, 0.5821297709, 1.6180879478, 1.0, 2.957565282},
            {1.0, 0.8229457375, 1.0201675948, 3.0, 2.872570117},
            {0.0, 0.0633462721, 0.4140600134, 1.0, 0.63906323},
            {1.0, 0.1019134156, 0.0677204356, 3.0, 1.493447564},
            {0.0, 0.1551713238, 1.541201456, 3.0, 1.90219884},
            {1.0, 0.8273822817, 0.2114979578, 3.0, 2.855730173},
            {0.0, 0.7955570114, 1.8757067556, 2.0, 2.930132627},
            {0.0, 0.6537275917, 1.2139678737, 2.0, 1.535853243},
            {1.0, 0.1243124125, 1.5130919744, 4.0, 2.733670775},
            {0.0, 0.2163864174, 0.7051185896, 2.0, 2.755841087},
            {0.0, 0.2522670308, 1.2821007571, 2.0, 0.342119491},
            {0.0, 0.8677104027, 1.9003869346, 2.0, 2.454376481},
            {1.0, 0.8670932774, 0.7993045617, 4.0, 2.732812615},
            {0.0, 0.5384287981, 0.1856947718, 1.0, 1.838702635},
            {0.0, 0.7236269342, 0.4993310347, 1.0, 1.030699128},
            {0.0, 0.0789361731, 1.011216166, 1.0, 2.539607478},
            {1.0, 0.7631686032, 0.0536725423, 2.0, 1.401761686},
            {0.0, 0.1157020777, 0.0123261618, 1.0, 2.098372295},
            {1.0, 0.1451248352, 1.9153951635, 3.0, 0.492650534},
            {1.0, 0.8497178114, 1.80941298, 4.0, 2.653985489},
            {0.0, 0.8027864883, 1.2631045617, 3.0, 2.716214291},
            {0.0, 0.798560373, 0.6872106791, 2.0, 2.763023936},
            {1.0, 0.1816879204, 0.4323868025, 4.0, 0.098090197},
            {1.0, 0.6301239238, 0.3670980479, 3.0, 0.02313788},
            {1.0, 0.0411311248, 0.0173408454, 3.0, 1.994786958},
            {1.0, 0.0427366099, 0.8114635572, 3.0, 2.966069741},
            {1.0, 0.4107826762, 0.1929467283, 4.0, 0.573832348},
            {0.0, 0.9441903098, 0.0729898885, 1.0, 1.710992303},
            {1.0, 0.3597549822, 0.2799857073, 2.0, 0.969428934},
            {0.0, 0.3741368004, 1.6052779425, 2.0, 1.866030486},
            {0.0, 0.3515911719, 0.3383029872, 1.0, 2.639469598},
            {0.0, 0.9184092905, 1.7116801264, 1.0, 1.380178652},
            {1.0, 0.77803064, 1.9830028405, 3.0, 1.834021992},
            {0.0, 0.573786814, 0.0258851023, 1.0, 1.52130144},
            {1.0, 0.3279244492, 0.6977945678, 4.0, 1.322451157},
            {0.0, 0.7924819048, 0.3694838509, 1.0, 2.369654865},
            {0.0, 0.9787846403, 1.1470323382, 2.0, 0.037156113},
            {1.0, 0.6910662795, 0.1019420708, 2.0, 2.58588334},
            {0.0, 0.1367050812, 0.6635301332, 2.0, 0.368273583},
            {0.0, 0.2826360366, 1.4468787988, 1.0, 2.705811968},
            {0.0, 0.4524727969, 0.7885378413, 2.0, 0.851228449},
            {0.0, 0.5118664701, 1.061143666, 1.0, 0.249325278},
            {0.0, 0.9965170731, 0.2068265025, 2.0, 0.9210639},
            {1.0, 0.7801500652, 1.565742691, 4.0, 1.827419217},
            {0.0, 0.2906187973, 1.7036567871, 2.0, 2.842997725},
            {0.0, 0.1753704017, 0.7124397112, 2.0, 1.262811961},
            {1.0, 0.7796778064, 0.3478030777, 3.0, 0.90719801},
            {1.0, 0.3889356288, 1.1771452101, 4.0, 1.298438454},
            {0.0, 0.9374473374, 1.1879778663, 1.0, 1.854424331},
            {1.0, 0.1939157653, 0.093336341, 4.0, 0.166025681},
            {1.0, 0.2023756928, 0.0623724433, 3.0, 0.536441906},
            {0.0, 0.1691352043, 1.1587338657, 2.0, 2.15494096},
            {1.0, 0.0921523357, 0.2247394961, 3.0, 2.006995301},
            {0.0, 0.819186907, 0.0392292971, 1.0, 1.282159743},
            {0.0, 0.9458126165, 1.5268264762, 1.0, 1.960050194},
            {0.0, 0.1373939656, 1.8025095677, 2.0, 0.633624267},
            {0.0, 0.0555424779, 0.5022063241, 2.0, 0.639495004},
            {1.0, 0.3581428374, 1.4436954968, 3.0, 1.408938169},
            {1.0, 0.1189418568, 0.8011626904, 4.0, 0.210266769},
            {1.0, 0.5782070206, 1.58215921, 3.0, 2.648622607},
            {0.0, 0.460689794, 0.0704823257, 1.0, 1.45671379},
            {0.0, 0.6959878858, 0.2245675903, 2.0, 1.849515461},
            {0.0, 0.1930288749, 0.6296302159, 2.0, 2.597390946},
            {0.0, 0.4912149447, 0.0713489084, 1.0, 0.426487798},
            {0.0, 0.3496920248, 1.0135462089, 1.0, 2.962295362},
            {1.0, 0.7716284667, 0.5387295927, 4.0, 0.736709363},
            {1.0, 0.3463061263, 0.7819578522, 4.0, 1.597238498},
            {1.0, 0.6897138762, 1.2793166582, 4.0, 2.376281484},
            {0.0, 0.2818824656, 1.4379718141, 3.0, 2.627468417},
            {0.0, 0.5659798421, 1.6243568249, 1.0, 1.624809581},
            {0.0, 0.7965560518, 0.3933029529, 2.0, 0.415849269},
            {0.0, 0.9156922165, 1.0465683565, 1.0, 2.802914008},
            {0.0, 0.8299879942, 1.2237155279, 1.0, 2.611676934},
            {0.0, 0.0241912066, 1.9213823564, 1.0, 0.659596571},
            {0.0, 0.0948590154, 0.3609640412, 1.0, 1.287687748},
            {0.0, 0.230467916, 1.9421709292, 3.0, 2.290064565},
            {0.0, 0.2209760561, 0.4812708795, 1.0, 1.862393057},
            {0.0, 0.4704530933, 0.2644400774, 1.0, 1.960189529},
            {1.0, 0.1986645423, 0.48924731, 2.0, 0.333790415},
            {0.0, 0.9201823308, 1.4247304946, 1.0, 0.367654009},
            {1.0, 0.8118424334, 0.1017034058, 2.0, 2.001390385},
            {1.0, 0.1347265388, 0.1362061207, 3.0, 1.151431168},
            {0.0, 0.9884603191, 1.5700038988, 2.0, 0.717332943},
            {0.0, 0.1964012324, 0.4306495111, 1.0, 1.689056823},
            {1.0, 0.4031848807, 1.1251849262, 4.0, 1.977734922},
            {1.0, 0.0341882701, 0.3717348906, 4.0, 1.830587439},
            {0.0, 0.5073120815, 1.7860476542, 3.0, 0.142862822},
            {0.0, 0.6363195451, 0.6631249222, 2.0, 1.211148724},
            {1.0, 0.1642774614, 1.1963615627, 3.0, 0.843113448},
            {0.0, 0.0945515088, 1.8669327218, 1.0, 2.417198514},
            {0.0, 0.2364508687, 1.4035215094, 2.0, 2.964026097},
            {1.0, 0.7490112646, 0.1778408242, 4.0, 2.343119453},
            {1.0, 0.5193473259, 0.3090019161, 3.0, 1.300277323}};

        double[][] testData = {
            {0.0, 0.0093314846, 0.0315045565, 1.0, 2.043737003},
            {0.0, 0.0663379349, 0.0822378928, 2.0, 1.202557951},
            {1.0, 0.9728333529, 0.8778284262, 4.0, 0.205940753},
            {1.0, 0.7655418115, 0.3292853828, 4.0, 2.940793653},
            {1.0, 0.1610695978, 0.3832762009, 4.0, 1.96753633},
            {0.0, 0.0849463812, 1.4988451041, 2.0, 2.307902221},
            {0.0, 0.7932621511, 1.2098399368, 1.0, 0.886761862},
            {0.0, 0.1336030525, 0.2794256401, 2.0, 2.672175208},
            {0.0, 0.4758480834, 0.0441179522, 1.0, 0.399722717},
            {1.0, 0.1137434335, 0.922533263, 3.0, 1.927635631}};

        /* Construct the gradient boosting model, with column 0 of the
         training data as the binary response variable.
         */
        GradientBoosting gb = new GradientBoosting(trainingData, 0, varType);
        // Set the shrinkage (learning rate) and the proportion of the
        // training data sampled at each iteration; sampling a proportion
        // less than 1.0 makes the boosting stochastic.
        gb.setShrinkageParameter(0.05);
        gb.setSampleSizeProportion(0.5);
        // Use a seeded random number generator for reproducible results.
        gb.setRandomObject(new Random(123457));
        gb.fitModel();

        /* Apply the fitted model to generate predicted values on the
         test data.
         */
        gb.predict(testData);
        /* Retrieve the fitted class probabilities on the training data.
         For binomial data, there will be only 1 column in the matrix. 
         */
        double[][] probabilities = gb.getClassProbabilities();

        System.out.println("Training Data Probabilities vs Actuals");
        for (int i = 0; i < probabilities.length; i++) {
            System.out.printf("%5.3f, %5.3f\n", probabilities[i][0],
                    trainingData[i][0]);
        }
        System.out.printf("Training Data Loss  Function Value: %5.5f\n",
                gb.getLossValue());

        /* Retrieve the predicted binomial probabilities on the test data. */
        double[][] testProbabilities = gb.getTestClassProbabilities();
        System.out.println("Test Data Probabilities vs Actuals");
        for (int i = 0; i < testProbabilities.length; i++) {
            System.out.printf("%5.3f, %5.3f\n", testProbabilities[i][0],
                    testData[i][0]);
        }
        System.out.printf("Test Data Loss Function Value: %5.5f\n",
                gb.getTestLossValue());
    }
}

Output

Training Data Probabilities vs Actuals
0.648, 0.000
0.173, 0.000
0.122, 0.000
0.784, 1.000
0.090, 0.000
0.766, 1.000
0.721, 1.000
0.090, 0.000
0.710, 1.000
0.129, 0.000
0.799, 1.000
0.660, 0.000
0.788, 1.000
0.136, 0.000
0.165, 0.000
0.763, 1.000
0.181, 0.000
0.149, 0.000
0.136, 0.000
0.802, 1.000
0.166, 0.000
0.122, 0.000
0.116, 0.000
0.253, 1.000
0.176, 0.000
0.648, 1.000
0.743, 1.000
0.673, 0.000
0.181, 0.000
0.812, 1.000
0.766, 1.000
0.799, 1.000
0.721, 1.000
0.853, 1.000
0.176, 0.000
0.240, 1.000
0.136, 0.000
0.150, 0.000
0.090, 0.000
0.648, 1.000
0.176, 0.000
0.802, 1.000
0.150, 0.000
0.173, 0.000
0.253, 1.000
0.181, 0.000
0.100, 0.000
0.181, 0.000
0.116, 0.000
0.240, 0.000
0.753, 1.000
0.136, 0.000
0.181, 0.000
0.766, 1.000
0.793, 1.000
0.116, 0.000
0.862, 1.000
0.799, 1.000
0.173, 0.000
0.788, 1.000
0.176, 0.000
0.095, 0.000
0.136, 0.000
0.181, 0.000
0.673, 1.000
0.802, 1.000
0.660, 1.000
0.176, 0.000
0.240, 0.000
0.181, 0.000
0.176, 0.000
0.116, 0.000
0.802, 1.000
0.802, 1.000
0.763, 1.000
0.673, 0.000
0.090, 0.000
0.199, 0.000
0.116, 0.000
0.111, 0.000
0.090, 0.000
0.150, 0.000
0.648, 0.000
0.129, 0.000
0.166, 0.000
0.190, 1.000
0.100, 0.000
0.253, 1.000
0.799, 1.000
0.142, 0.000
0.129, 0.000
0.793, 1.000
0.837, 1.000
0.648, 0.000
0.181, 0.000
0.710, 1.000
0.090, 0.000
0.149, 0.000
0.853, 1.000
0.788, 1.000
Training Data Loss Function Value: 0.62967
Test Data Probabilities vs Actuals
0.176, 0.000
0.253, 0.000
0.793, 1.000
0.845, 1.000
0.821, 1.000
0.149, 0.000
0.111, 0.000
0.240, 0.000
0.176, 0.000
0.710, 1.000
Test Data Loss Function Value: 0.48412
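
Applying the 0.5 threshold to the test probabilities above assigns all ten test observations to their actual classes.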