Example 4: Gradient Boosting

This example illustrates the alternative constructor for gradient boosting. This constructor accepts an instance of a decision tree (here, ALACART) to serve as the base learner, allowing more flexibility in controlling the base learner's configuration than the default constructor provides.
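Before the full example, a minimal sketch of the construction pattern may help: configure the tree first, then hand it to the constructor. The placeholder data set and the class name AltConstructorSketch below are illustrative only; the API calls are the same ones used in the full example.

import com.imsl.datamining.GradientBoosting;
import com.imsl.datamining.decisionTree.ALACART;

public class AltConstructorSketch {
    public static void main(String[] args) throws Exception {
        /* Placeholder data: binary response in column 0, one
           continuous predictor in column 1. */
        double[][] xy = {
            {0.0, 0.12}, {1.0, 0.91}, {0.0, 0.23}, {1.0, 0.87},
            {0.0, 0.18}, {1.0, 0.79}, {0.0, 0.31}, {1.0, 0.95},
            {0.0, 0.05}, {1.0, 0.68}
        };
        GradientBoosting.VariableType[] types = {
            GradientBoosting.VariableType.CATEGORICAL,
            GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS
        };

        /* Build and configure the base learner first... */
        ALACART tree = new ALACART(xy, 0, types);
        tree.setMaxDepth(3);

        /* ...then pass it to the alternative constructor and fit. */
        GradientBoosting gb = new GradientBoosting(tree);
        gb.fitModel();
    }
}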

import com.imsl.datamining.GradientBoosting;
import com.imsl.datamining.decisionTree.ALACART;
import com.imsl.stat.Random;

public class GradientBoostingEx4 {
    public static void main(String[] args) throws Exception {

        GradientBoosting.VariableType[] VarType = {
            GradientBoosting.VariableType.CATEGORICAL,
            GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS,
            GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS,
            GradientBoosting.VariableType.CATEGORICAL,
            GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS
        };
        double[][] trainingData = {
            {0.0, 0.4223019897, 1.7540411302, 3.0, 0.763836258},
            {0.0, 0.0907259332, 0.8722643796, 2.0, 1.859006285},
            {0.0, 0.1384744535, 0.838324877, 1.0, 0.249729405},
            {1.0, 0.5435024537, 1.2359190206, 4.0, 0.831992314},
            {0.0, 0.8359154933, 1.8527500411, 1.0, 1.089201049},
            {1.0, 0.3577950741, 0.3652825342, 3.0, 2.204364955},
            {1.0, 0.6799094002, 0.6610595905, 3.0, 1.44730419},
            {0.0, 0.5821297709, 1.6180879478, 1.0, 2.957565282},
            {1.0, 0.8229457375, 1.0201675948, 3.0, 2.872570117},
            {0.0, 0.0633462721, 0.4140600134, 1.0, 0.63906323},
            {1.0, 0.1019134156, 0.0677204356, 3.0, 1.493447564},
            {0.0, 0.1551713238, 1.541201456, 3.0, 1.90219884},
            {1.0, 0.8273822817, 0.2114979578, 3.0, 2.855730173},
            {0.0, 0.7955570114, 1.8757067556, 2.0, 2.930132627},
            {0.0, 0.6537275917, 1.2139678737, 2.0, 1.535853243},
            {1.0, 0.1243124125, 1.5130919744, 4.0, 2.733670775},
            {0.0, 0.2163864174, 0.7051185896, 2.0, 2.755841087},
            {0.0, 0.2522670308, 1.2821007571, 2.0, 0.342119491},
            {0.0, 0.8677104027, 1.9003869346, 2.0, 2.454376481},
            {1.0, 0.8670932774, 0.7993045617, 4.0, 2.732812615},
            {0.0, 0.5384287981, 0.1856947718, 1.0, 1.838702635},
            {0.0, 0.7236269342, 0.4993310347, 1.0, 1.030699128},
            {0.0, 0.0789361731, 1.011216166, 1.0, 2.539607478},
            {1.0, 0.7631686032, 0.0536725423, 2.0, 1.401761686},
            {0.0, 0.1157020777, 0.0123261618, 1.0, 2.098372295},
            {1.0, 0.1451248352, 1.9153951635, 3.0, 0.492650534},
            {1.0, 0.8497178114, 1.80941298, 4.0, 2.653985489},
            {0.0, 0.8027864883, 1.2631045617, 3.0, 2.716214291},
            {0.0, 0.798560373, 0.6872106791, 2.0, 2.763023936},
            {1.0, 0.1816879204, 0.4323868025, 4.0, 0.098090197},
            {1.0, 0.6301239238, 0.3670980479, 3.0, 0.02313788},
            {1.0, 0.0411311248, 0.0173408454, 3.0, 1.994786958},
            {1.0, 0.0427366099, 0.8114635572, 3.0, 2.966069741},
            {1.0, 0.4107826762, 0.1929467283, 4.0, 0.573832348},
            {0.0, 0.9441903098, 0.0729898885, 1.0, 1.710992303},
            {1.0, 0.3597549822, 0.2799857073, 2.0, 0.969428934},
            {0.0, 0.3741368004, 1.6052779425, 2.0, 1.866030486},
            {0.0, 0.3515911719, 0.3383029872, 1.0, 2.639469598},
            {0.0, 0.9184092905, 1.7116801264, 1.0, 1.380178652},
            {1.0, 0.77803064, 1.9830028405, 3.0, 1.834021992},
            {0.0, 0.573786814, 0.0258851023, 1.0, 1.52130144},
            {1.0, 0.3279244492, 0.6977945678, 4.0, 1.322451157},
            {0.0, 0.7924819048, 0.3694838509, 1.0, 2.369654865},
            {0.0, 0.9787846403, 1.1470323382, 2.0, 0.037156113},
            {1.0, 0.6910662795, 0.1019420708, 2.0, 2.58588334},
            {0.0, 0.1367050812, 0.6635301332, 2.0, 0.368273583},
            {0.0, 0.2826360366, 1.4468787988, 1.0, 2.705811968},
            {0.0, 0.4524727969, 0.7885378413, 2.0, 0.851228449},
            {0.0, 0.5118664701, 1.061143666, 1.0, 0.249325278},
            {0.0, 0.9965170731, 0.2068265025, 2.0, 0.9210639},
            {1.0, 0.7801500652, 1.565742691, 4.0, 1.827419217},
            {0.0, 0.2906187973, 1.7036567871, 2.0, 2.842997725},
            {0.0, 0.1753704017, 0.7124397112, 2.0, 1.262811961},
            {1.0, 0.7796778064, 0.3478030777, 3.0, 0.90719801},
            {1.0, 0.3889356288, 1.1771452101, 4.0, 1.298438454},
            {0.0, 0.9374473374, 1.1879778663, 1.0, 1.854424331},
            {1.0, 0.1939157653, 0.093336341, 4.0, 0.166025681},
            {1.0, 0.2023756928, 0.0623724433, 3.0, 0.536441906},
            {0.0, 0.1691352043, 1.1587338657, 2.0, 2.15494096},
            {1.0, 0.0921523357, 0.2247394961, 3.0, 2.006995301},
            {0.0, 0.819186907, 0.0392292971, 1.0, 1.282159743},
            {0.0, 0.9458126165, 1.5268264762, 1.0, 1.960050194},
            {0.0, 0.1373939656, 1.8025095677, 2.0, 0.633624267},
            {0.0, 0.0555424779, 0.5022063241, 2.0, 0.639495004},
            {1.0, 0.3581428374, 1.4436954968, 3.0, 1.408938169},
            {1.0, 0.1189418568, 0.8011626904, 4.0, 0.210266769},
            {1.0, 0.5782070206, 1.58215921, 3.0, 2.648622607},
            {0.0, 0.460689794, 0.0704823257, 1.0, 1.45671379},
            {0.0, 0.6959878858, 0.2245675903, 2.0, 1.849515461},
            {0.0, 0.1930288749, 0.6296302159, 2.0, 2.597390946},
            {0.0, 0.4912149447, 0.0713489084, 1.0, 0.426487798},
            {0.0, 0.3496920248, 1.0135462089, 1.0, 2.962295362},
            {1.0, 0.7716284667, 0.5387295927, 4.0, 0.736709363},
            {1.0, 0.3463061263, 0.7819578522, 4.0, 1.597238498},
            {1.0, 0.6897138762, 1.2793166582, 4.0, 2.376281484},
            {0.0, 0.2818824656, 1.4379718141, 3.0, 2.627468417},
            {0.0, 0.5659798421, 1.6243568249, 1.0, 1.624809581},
            {0.0, 0.7965560518, 0.3933029529, 2.0, 0.415849269},
            {0.0, 0.9156922165, 1.0465683565, 1.0, 2.802914008},
            {0.0, 0.8299879942, 1.2237155279, 1.0, 2.611676934},
            {0.0, 0.0241912066, 1.9213823564, 1.0, 0.659596571},
            {0.0, 0.0948590154, 0.3609640412, 1.0, 1.287687748},
            {0.0, 0.230467916, 1.9421709292, 3.0, 2.290064565},
            {0.0, 0.2209760561, 0.4812708795, 1.0, 1.862393057},
            {0.0, 0.4704530933, 0.2644400774, 1.0, 1.960189529},
            {1.0, 0.1986645423, 0.48924731, 2.0, 0.333790415},
            {0.0, 0.9201823308, 1.4247304946, 1.0, 0.367654009},
            {1.0, 0.8118424334, 0.1017034058, 2.0, 2.001390385},
            {1.0, 0.1347265388, 0.1362061207, 3.0, 1.151431168},
            {0.0, 0.9884603191, 1.5700038988, 2.0, 0.717332943},
            {0.0, 0.1964012324, 0.4306495111, 1.0, 1.689056823},
            {1.0, 0.4031848807, 1.1251849262, 4.0, 1.977734922},
            {1.0, 0.0341882701, 0.3717348906, 4.0, 1.830587439},
            {0.0, 0.5073120815, 1.7860476542, 3.0, 0.142862822},
            {0.0, 0.6363195451, 0.6631249222, 2.0, 1.211148724},
            {1.0, 0.1642774614, 1.1963615627, 3.0, 0.843113448},
            {0.0, 0.0945515088, 1.8669327218, 1.0, 2.417198514},
            {0.0, 0.2364508687, 1.4035215094, 2.0, 2.964026097},
            {1.0, 0.7490112646, 0.1778408242, 4.0, 2.343119453},
            {1.0, 0.5193473259, 0.3090019161, 3.0, 1.300277323}};

        double[][] testData = {
            {0.0, 0.0093314846, 0.0315045565, 1.0, 2.043737003},
            {0.0, 0.0663379349, 0.0822378928, 2.0, 1.202557951},
            {1.0, 0.9728333529, 0.8778284262, 4.0, 0.205940753},
            {1.0, 0.7655418115, 0.3292853828, 4.0, 2.940793653},
            {1.0, 0.1610695978, 0.3832762009, 4.0, 1.96753633},
            {0.0, 0.0849463812, 1.4988451041, 2.0, 2.307902221},
            {0.0, 0.7932621511, 1.2098399368, 1.0, 0.886761862},
            {0.0, 0.1336030525, 0.2794256401, 2.0, 2.672175208},
            {0.0, 0.4758480834, 0.0441179522, 1.0, 0.399722717},
            {1.0, 0.1137434335, 0.922533263, 3.0, 1.927635631}};

        /* Construct the base learner: an ALACART (CART method) decision
           tree with the response variable in column 0. */
        ALACART dTree = new ALACART(trainingData, 0, VarType);

        /* Configure the base learner before passing it to the booster. */
        dTree.setMaxDepth(10);
        dTree.setMinObsPerNode(10);
        dTree.setMaxNodes(4);

        /* The alternative constructor accepts the configured tree. */
        GradientBoosting gb = new GradientBoosting(dTree);
        gb.setShrinkageParameter(0.05);          // shrinkage (learning rate)
        gb.setSampleSizeProportion(0.5);         // stochastic boosting: sample half the data per iteration
        gb.setRandomObject(new Random(123457));  // fixed seed for reproducible sampling
        gb.fitModel();
        
        /* Generate fitted values on the training data and predicted
           values on the test data. */
        gb.predict(testData);
        /* Retrieve the fitted binomial probabilities on the training data. */
        double[][] probabilities = gb.getClassProbabilities();

        System.out.println("Training Data Probabilities vs Actuals");
        for (int i = 0; i < probabilities.length; i++) {
            System.out.printf("%5.3f, %5.3f\n", probabilities[i][0], 
                    trainingData[i][0]);
        }
        System.out.printf("Training Data Loss  Function Value: %5.5f\n", 
                gb.getLossValue());

        /* Retrieve the predicted binomial probabilities on the test data. */
        double[][] testProbabilities = gb.getTestClassProbabilities();
        System.out.println("Test Data Probabilities vs Actuals");
        for (int i = 0; i < testProbabilities.length; i++) {
            System.out.printf("%5.3f, %5.3f\n", testProbabilities[i][0],
                    testData[i][0]);
        }
        System.out.printf("Test Data Loss Function Value: %5.5f\n", 
                gb.getTestLossValue());
    }
}
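Note that setMaxNodes(4) restricts each base learner to a very small tree. Gradient boosting typically combines many such weak learners; the shrinkage parameter (0.05) scales each tree's contribution, and the sample size proportion (0.5) makes the fit stochastic by training each tree on a random half of the data, which is why the random seed is fixed for reproducibility.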

Output

Training Data Probabilities vs Actuals
0.648, 0.000
0.173, 0.000
0.122, 0.000
0.784, 1.000
0.090, 0.000
0.766, 1.000
0.721, 1.000
0.090, 0.000
0.710, 1.000
0.129, 0.000
0.799, 1.000
0.660, 0.000
0.788, 1.000
0.136, 0.000
0.165, 0.000
0.763, 1.000
0.181, 0.000
0.149, 0.000
0.136, 0.000
0.802, 1.000
0.166, 0.000
0.122, 0.000
0.116, 0.000
0.253, 1.000
0.176, 0.000
0.648, 1.000
0.743, 1.000
0.673, 0.000
0.181, 0.000
0.812, 1.000
0.766, 1.000
0.799, 1.000
0.721, 1.000
0.853, 1.000
0.176, 0.000
0.240, 1.000
0.136, 0.000
0.150, 0.000
0.090, 0.000
0.648, 1.000
0.176, 0.000
0.802, 1.000
0.150, 0.000
0.173, 0.000
0.253, 1.000
0.181, 0.000
0.100, 0.000
0.181, 0.000
0.116, 0.000
0.240, 0.000
0.753, 1.000
0.136, 0.000
0.181, 0.000
0.766, 1.000
0.793, 1.000
0.116, 0.000
0.862, 1.000
0.799, 1.000
0.173, 0.000
0.788, 1.000
0.176, 0.000
0.095, 0.000
0.136, 0.000
0.181, 0.000
0.673, 1.000
0.802, 1.000
0.660, 1.000
0.176, 0.000
0.240, 0.000
0.181, 0.000
0.176, 0.000
0.116, 0.000
0.802, 1.000
0.802, 1.000
0.763, 1.000
0.673, 0.000
0.090, 0.000
0.199, 0.000
0.116, 0.000
0.111, 0.000
0.090, 0.000
0.150, 0.000
0.648, 0.000
0.129, 0.000
0.166, 0.000
0.190, 1.000
0.100, 0.000
0.253, 1.000
0.799, 1.000
0.142, 0.000
0.129, 0.000
0.793, 1.000
0.837, 1.000
0.648, 0.000
0.181, 0.000
0.710, 1.000
0.090, 0.000
0.149, 0.000
0.853, 1.000
0.788, 1.000
Training Data Loss Function Value: 0.62967
Test Data Probabilities vs Actuals
0.176, 0.000
0.253, 0.000
0.793, 1.000
0.845, 1.000
0.821, 1.000
0.149, 0.000
0.111, 0.000
0.240, 0.000
0.176, 0.000
0.710, 1.000
Test Data Loss Function Value: 0.48412
Link to Java source.
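A note on the loss values: with a binary categorical response the reported probabilities are binomial, and the loss is presumably a binomial negative log-likelihood (deviance) criterion. The sketch below shows one common form of that computation, applied to the (probability, actual) pairs from the test output; the helper name computeDeviance and the factor of 2 are assumptions for illustration, not taken from the GradientBoosting documentation.

import static java.lang.Math.log;

public class DevianceSketch {
    /* Binomial deviance: 2/n * sum( -y*log(p) - (1-y)*log(1-p) ).
       This definition (and the factor of 2) is an assumption for
       illustration; it is not taken from the class documentation. */
    static double computeDeviance(double[] p, double[] y) {
        double sum = 0.0;
        for (int i = 0; i < p.length; i++) {
            sum -= y[i] * log(p[i]) + (1.0 - y[i]) * log(1.0 - p[i]);
        }
        return 2.0 * sum / p.length;
    }

    public static void main(String[] args) {
        /* The (probability, actual) pairs from the test output above. */
        double[] p = {0.176, 0.253, 0.793, 0.845, 0.821,
                      0.149, 0.111, 0.240, 0.176, 0.710};
        double[] y = {0.0, 0.0, 1.0, 1.0, 1.0,
                      0.0, 0.0, 0.0, 0.0, 1.0};
        System.out.printf("binomial deviance: %5.5f%n", computeDeviance(p, y));
    }
}

Applied to the rounded three-decimal probabilities above, this yields roughly 0.43 rather than the reported 0.48412, so the library likely works with full-precision probabilities and/or a slightly different scaling; the sketch only illustrates the shape of the computation.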