This example illustrates the alternative constructor for gradient boosting, which accepts an instance of a regression tree to serve as the base learner. This form gives more flexibility in configuring the base learner.
import com.imsl.datamining.GradientBoosting; import com.imsl.datamining.decisionTree.ALACART; import com.imsl.stat.Random; public class GradientBoostingEx4 { public static void main(String[] args) throws Exception { GradientBoosting.VariableType[] VarType = { GradientBoosting.VariableType.CATEGORICAL, GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS, GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS, GradientBoosting.VariableType.CATEGORICAL, GradientBoosting.VariableType.QUANTITATIVE_CONTINUOUS }; double[][] trainingData = { {0.0, 0.4223019897, 1.7540411302, 3.0, 0.763836258}, {0.0, 0.0907259332, 0.8722643796, 2.0, 1.859006285}, {0.0, 0.1384744535, 0.838324877, 1.0, 0.249729405}, {1.0, 0.5435024537, 1.2359190206, 4.0, 0.831992314}, {0.0, 0.8359154933, 1.8527500411, 1.0, 1.089201049}, {1.0, 0.3577950741, 0.3652825342, 3.0, 2.204364955}, {1.0, 0.6799094002, 0.6610595905, 3.0, 1.44730419}, {0.0, 0.5821297709, 1.6180879478, 1.0, 2.957565282}, {1.0, 0.8229457375, 1.0201675948, 3.0, 2.872570117}, {0.0, 0.0633462721, 0.4140600134, 1.0, 0.63906323}, {1.0, 0.1019134156, 0.0677204356, 3.0, 1.493447564}, {0.0, 0.1551713238, 1.541201456, 3.0, 1.90219884}, {1.0, 0.8273822817, 0.2114979578, 3.0, 2.855730173}, {0.0, 0.7955570114, 1.8757067556, 2.0, 2.930132627}, {0.0, 0.6537275917, 1.2139678737, 2.0, 1.535853243}, {1.0, 0.1243124125, 1.5130919744, 4.0, 2.733670775}, {0.0, 0.2163864174, 0.7051185896, 2.0, 2.755841087}, {0.0, 0.2522670308, 1.2821007571, 2.0, 0.342119491}, {0.0, 0.8677104027, 1.9003869346, 2.0, 2.454376481}, {1.0, 0.8670932774, 0.7993045617, 4.0, 2.732812615}, {0.0, 0.5384287981, 0.1856947718, 1.0, 1.838702635}, {0.0, 0.7236269342, 0.4993310347, 1.0, 1.030699128}, {0.0, 0.0789361731, 1.011216166, 1.0, 2.539607478}, {1.0, 0.7631686032, 0.0536725423, 2.0, 1.401761686}, {0.0, 0.1157020777, 0.0123261618, 1.0, 2.098372295}, {1.0, 0.1451248352, 1.9153951635, 3.0, 0.492650534}, {1.0, 0.8497178114, 1.80941298, 4.0, 2.653985489}, {0.0, 0.8027864883, 1.2631045617, 
3.0, 2.716214291}, {0.0, 0.798560373, 0.6872106791, 2.0, 2.763023936}, {1.0, 0.1816879204, 0.4323868025, 4.0, 0.098090197}, {1.0, 0.6301239238, 0.3670980479, 3.0, 0.02313788}, {1.0, 0.0411311248, 0.0173408454, 3.0, 1.994786958}, {1.0, 0.0427366099, 0.8114635572, 3.0, 2.966069741}, {1.0, 0.4107826762, 0.1929467283, 4.0, 0.573832348}, {0.0, 0.9441903098, 0.0729898885, 1.0, 1.710992303}, {1.0, 0.3597549822, 0.2799857073, 2.0, 0.969428934}, {0.0, 0.3741368004, 1.6052779425, 2.0, 1.866030486}, {0.0, 0.3515911719, 0.3383029872, 1.0, 2.639469598}, {0.0, 0.9184092905, 1.7116801264, 1.0, 1.380178652}, {1.0, 0.77803064, 1.9830028405, 3.0, 1.834021992}, {0.0, 0.573786814, 0.0258851023, 1.0, 1.52130144}, {1.0, 0.3279244492, 0.6977945678, 4.0, 1.322451157}, {0.0, 0.7924819048, 0.3694838509, 1.0, 2.369654865}, {0.0, 0.9787846403, 1.1470323382, 2.0, 0.037156113}, {1.0, 0.6910662795, 0.1019420708, 2.0, 2.58588334}, {0.0, 0.1367050812, 0.6635301332, 2.0, 0.368273583}, {0.0, 0.2826360366, 1.4468787988, 1.0, 2.705811968}, {0.0, 0.4524727969, 0.7885378413, 2.0, 0.851228449}, {0.0, 0.5118664701, 1.061143666, 1.0, 0.249325278}, {0.0, 0.9965170731, 0.2068265025, 2.0, 0.9210639}, {1.0, 0.7801500652, 1.565742691, 4.0, 1.827419217}, {0.0, 0.2906187973, 1.7036567871, 2.0, 2.842997725}, {0.0, 0.1753704017, 0.7124397112, 2.0, 1.262811961}, {1.0, 0.7796778064, 0.3478030777, 3.0, 0.90719801}, {1.0, 0.3889356288, 1.1771452101, 4.0, 1.298438454}, {0.0, 0.9374473374, 1.1879778663, 1.0, 1.854424331}, {1.0, 0.1939157653, 0.093336341, 4.0, 0.166025681}, {1.0, 0.2023756928, 0.0623724433, 3.0, 0.536441906}, {0.0, 0.1691352043, 1.1587338657, 2.0, 2.15494096}, {1.0, 0.0921523357, 0.2247394961, 3.0, 2.006995301}, {0.0, 0.819186907, 0.0392292971, 1.0, 1.282159743}, {0.0, 0.9458126165, 1.5268264762, 1.0, 1.960050194}, {0.0, 0.1373939656, 1.8025095677, 2.0, 0.633624267}, {0.0, 0.0555424779, 0.5022063241, 2.0, 0.639495004}, {1.0, 0.3581428374, 1.4436954968, 3.0, 1.408938169}, {1.0, 0.1189418568, 0.8011626904, 
4.0, 0.210266769}, {1.0, 0.5782070206, 1.58215921, 3.0, 2.648622607}, {0.0, 0.460689794, 0.0704823257, 1.0, 1.45671379}, {0.0, 0.6959878858, 0.2245675903, 2.0, 1.849515461}, {0.0, 0.1930288749, 0.6296302159, 2.0, 2.597390946}, {0.0, 0.4912149447, 0.0713489084, 1.0, 0.426487798}, {0.0, 0.3496920248, 1.0135462089, 1.0, 2.962295362}, {1.0, 0.7716284667, 0.5387295927, 4.0, 0.736709363}, {1.0, 0.3463061263, 0.7819578522, 4.0, 1.597238498}, {1.0, 0.6897138762, 1.2793166582, 4.0, 2.376281484}, {0.0, 0.2818824656, 1.4379718141, 3.0, 2.627468417}, {0.0, 0.5659798421, 1.6243568249, 1.0, 1.624809581}, {0.0, 0.7965560518, 0.3933029529, 2.0, 0.415849269}, {0.0, 0.9156922165, 1.0465683565, 1.0, 2.802914008}, {0.0, 0.8299879942, 1.2237155279, 1.0, 2.611676934}, {0.0, 0.0241912066, 1.9213823564, 1.0, 0.659596571}, {0.0, 0.0948590154, 0.3609640412, 1.0, 1.287687748}, {0.0, 0.230467916, 1.9421709292, 3.0, 2.290064565}, {0.0, 0.2209760561, 0.4812708795, 1.0, 1.862393057}, {0.0, 0.4704530933, 0.2644400774, 1.0, 1.960189529}, {1.0, 0.1986645423, 0.48924731, 2.0, 0.333790415}, {0.0, 0.9201823308, 1.4247304946, 1.0, 0.367654009}, {1.0, 0.8118424334, 0.1017034058, 2.0, 2.001390385}, {1.0, 0.1347265388, 0.1362061207, 3.0, 1.151431168}, {0.0, 0.9884603191, 1.5700038988, 2.0, 0.717332943}, {0.0, 0.1964012324, 0.4306495111, 1.0, 1.689056823}, {1.0, 0.4031848807, 1.1251849262, 4.0, 1.977734922}, {1.0, 0.0341882701, 0.3717348906, 4.0, 1.830587439}, {0.0, 0.5073120815, 1.7860476542, 3.0, 0.142862822}, {0.0, 0.6363195451, 0.6631249222, 2.0, 1.211148724}, {1.0, 0.1642774614, 1.1963615627, 3.0, 0.843113448}, {0.0, 0.0945515088, 1.8669327218, 1.0, 2.417198514}, {0.0, 0.2364508687, 1.4035215094, 2.0, 2.964026097}, {1.0, 0.7490112646, 0.1778408242, 4.0, 2.343119453}, {1.0, 0.5193473259, 0.3090019161, 3.0, 1.300277323}}; double[][] testData = { {0.0, 0.0093314846, 0.0315045565, 1.0, 2.043737003}, {0.0, 0.0663379349, 0.0822378928, 2.0, 1.202557951}, {1.0, 0.9728333529, 0.8778284262, 4.0, 0.205940753}, 
{1.0, 0.7655418115, 0.3292853828, 4.0, 2.940793653}, {1.0, 0.1610695978, 0.3832762009, 4.0, 1.96753633}, {0.0, 0.0849463812, 1.4988451041, 2.0, 2.307902221}, {0.0, 0.7932621511, 1.2098399368, 1.0, 0.886761862}, {0.0, 0.1336030525, 0.2794256401, 2.0, 2.672175208}, {0.0, 0.4758480834, 0.0441179522, 1.0, 0.399722717}, {1.0, 0.1137434335, 0.922533263, 3.0, 1.927635631}}; ALACART dTree= new ALACART(trainingData,0,VarType); dTree.setMaxDepth(10); dTree.setMinObsPerNode(10); dTree.setMaxNodes(4); GradientBoosting gb = new GradientBoosting(dTree); gb.setShrinkageParameter(0.05); gb.setSampleSizeProportion(0.5); gb.setRandomObject(new Random(123457)); gb.fitModel(); /* Run gradient boosting, generating fitted values and predicted values on the test data. */ gb.predict(testData); /* Retrieve the fitted binomial probabilities on the training data. */ double[][] probabilities = gb.getClassProbabilities(); System.out.println("Training Data Probabilities vs Actuals"); for (int i = 0; i < probabilities.length; i++) { System.out.printf("%5.3f, %5.3f\n", probabilities[i][0], trainingData[i][0]); } System.out.printf("Training Data Loss Function Value: %5.5f\n", gb.getLossValue()); /* Retrieve the predicted binomial probabilities on the test data. */ double[][] testProbabilities = gb.getTestClassProbabilities(); System.out.println("Test Data Probabilities vs Actuals"); for (int i = 0; i < testProbabilities.length; i++) { System.out.printf("%5.3f, %5.3f\n", testProbabilities[i][0], testData[i][0]); } System.out.printf("Test Data Loss Function Value: %5.5f\n", gb.getTestLossValue()); } }
Training Data Probabilities vs Actuals 0.648, 0.000 0.173, 0.000 0.122, 0.000 0.784, 1.000 0.090, 0.000 0.766, 1.000 0.721, 1.000 0.090, 0.000 0.710, 1.000 0.129, 0.000 0.799, 1.000 0.660, 0.000 0.788, 1.000 0.136, 0.000 0.165, 0.000 0.763, 1.000 0.181, 0.000 0.149, 0.000 0.136, 0.000 0.802, 1.000 0.166, 0.000 0.122, 0.000 0.116, 0.000 0.253, 1.000 0.176, 0.000 0.648, 1.000 0.743, 1.000 0.673, 0.000 0.181, 0.000 0.812, 1.000 0.766, 1.000 0.799, 1.000 0.721, 1.000 0.853, 1.000 0.176, 0.000 0.240, 1.000 0.136, 0.000 0.150, 0.000 0.090, 0.000 0.648, 1.000 0.176, 0.000 0.802, 1.000 0.150, 0.000 0.173, 0.000 0.253, 1.000 0.181, 0.000 0.100, 0.000 0.181, 0.000 0.116, 0.000 0.240, 0.000 0.753, 1.000 0.136, 0.000 0.181, 0.000 0.766, 1.000 0.793, 1.000 0.116, 0.000 0.862, 1.000 0.799, 1.000 0.173, 0.000 0.788, 1.000 0.176, 0.000 0.095, 0.000 0.136, 0.000 0.181, 0.000 0.673, 1.000 0.802, 1.000 0.660, 1.000 0.176, 0.000 0.240, 0.000 0.181, 0.000 0.176, 0.000 0.116, 0.000 0.802, 1.000 0.802, 1.000 0.763, 1.000 0.673, 0.000 0.090, 0.000 0.199, 0.000 0.116, 0.000 0.111, 0.000 0.090, 0.000 0.150, 0.000 0.648, 0.000 0.129, 0.000 0.166, 0.000 0.190, 1.000 0.100, 0.000 0.253, 1.000 0.799, 1.000 0.142, 0.000 0.129, 0.000 0.793, 1.000 0.837, 1.000 0.648, 0.000 0.181, 0.000 0.710, 1.000 0.090, 0.000 0.149, 0.000 0.853, 1.000 0.788, 1.000 Training Data Loss Function Value: 0.62967 Test Data Probabilities vs Actuals 0.176, 0.000 0.253, 0.000 0.793, 1.000 0.845, 1.000 0.821, 1.000 0.149, 0.000 0.111, 0.000 0.240, 0.000 0.176, 0.000 0.710, 1.000 Test Data Loss Function Value: 0.48412

Link to Java source.