Class DecisionTree
- All Implemented Interfaces:
Serializable, Cloneable
- Direct Known Subclasses:
CHAID, DecisionTreeInfoGain, QUEST
Abstract class for generating a decision tree for a single response variable and one or more predictor variables.
Tree Generation Methods
This package contains four of the most widely used algorithms for decision
trees (C45, ALACART, CHAID, and QUEST). The
user may also provide an alternate algorithm by extending the
DecisionTree or DecisionTreeInfoGain abstract class and
implementing the abstract method selectSplitVariable(double[][], double[], double[], double[], double[], int[]).
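As an illustration, the following sketch extends DecisionTree with a placeholder splitting rule. The class name and the split logic are hypothetical; a real implementation would evaluate every candidate predictor against a splitting criterion:

import com.imsl.datamining.PredictiveModel;
import com.imsl.datamining.decisionTree.DecisionTree;

// Hypothetical subclass with a placeholder splitting rule (illustration only).
public class MyDecisionTree extends DecisionTree {

    public MyDecisionTree(double[][] xy, int responseColumnIndex,
            PredictiveModel.VariableType[] varType) {
        super(xy, responseColumnIndex, varType);
    }

    @Override
    protected int selectSplitVariable(double[][] xy, double[] classCounts,
            double[] parentFreq, double[] splitValue,
            double[] splitCriterionValue, int[] splitPartition) {
        // splitValue, splitCriterionValue, and splitPartition are output
        // arrays describing the proposed split. A real rule would score
        // every candidate; here we naively split the first predictor at
        // its first observed value.
        int col = getPredictorIndexes()[0];
        splitValue[0] = xy[0][col];
        splitCriterionValue[0] = 0.0; // placeholder criterion value
        return 0; // index into this.getPredictorIndexes()
    }
}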
Optimal Tree Size
Minimum Cost-complexity Pruning
A strategy to address overfitting is to grow the tree as large as possible and then prune it back. Let T represent a decision tree generated by any of the methods above. The idea (from Breiman et al.) is to find the smallest sub-tree of T that minimizes the cost-complexity measure:
$$R_{\delta}(T)=R(T)+\delta|\tilde{T}|\mbox{,}$$
where \(\tilde{T}\) denotes the set of terminal nodes, \(|\tilde{T}|\) represents the number of terminal nodes, and \(\delta\geq0\) is a cost-complexity parameter. For a categorical target variable,
$$R(T)=\sum_{t\in\tilde{T}}{R(t)}=\sum_{t\in\tilde{T}}{r(t)p(t)}\mbox{,}$$
$$r(t)=\min_i{\sum_j{C(i|j)p(j|t)}}\mbox{,}$$
$$p(t)=\mbox{Pr}[x\in t]\;\mbox{and}\;p(j|t)=\mbox{Pr}[y=j|x\in t]\mbox{,}$$
and \(C(i|j)\) is the cost of misclassifying the actual class j as i. Note that \(C(j|j)=0\) and \(C(i|j)\gt0\) for \(i\neq j\).
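For example, with unit misclassification costs (\(C(i|j)=1\) for all \(i\neq j\)), the inner sum reduces to \(1-p(i|t)\), so that $$r(t)=1-\max_j{p(j|t)}\mbox{,}$$ the resubstitution misclassification rate of node t. A node with class proportions (0.7, 0.2, 0.1) has \(r(t)=0.3\), weighted in \(R(T)\) by the node probability \(p(t)\).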
When the target is continuous (and the problem is a regression problem), the metric is instead the mean squared error
$$R(T)=\sum_{t\in{\tilde{T}}}R(t)=\frac{1}{N}\sum_{t\in{\tilde{T}}}\sum_{y_n\in{t}}(y_n-\hat{y}(t))^2$$
This class implements the optimal pruning algorithm 10.1, page 294 in Breiman et al. (1984). The result of the algorithm is a sequence of sub-trees \(T_{\max}\succ T_1\succ T_2\succ\ldots\succ T_{M-1}\succ\{t_0\}\) obtained by pruning the fully generated tree, \(T_{\max}\), until the sub-tree consists of the single root node, \(\{t_0\}\). Corresponding to the sequence of sub-trees is the sequence of complexity values, \(0\leq\delta_{\min}\lt\delta_1\lt\delta_2\lt\ldots\lt\delta_{M-1}\lt\delta_M\), where M is the number of steps it takes in the algorithm to reach the root node. The sub-trees represent the optimally pruned sub-trees for the sequence of complexity values. The minimum complexity \(\delta_{\min}\) can be set via an optional argument.
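As a sketch of this workflow in code, given tree, an instance of a concrete subclass (construction is sketched under Constructor Details below), the complexity sequence can be generated and inspected as follows; the variable names are illustrative:

tree.setAutoPruningFlag(true);  // generate cost-complexity values in fitModel()
tree.fitModel();                // grows the maximal tree, then prunes it
double[] deltas = tree.getCostComplexityValues();
int m = tree.getNumberOfComplexityValues();
System.out.println("pruning produced " + m + " complexity values");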
V-Fold Cross-Validation
The CrossValidation class can be used for model validation.
Prediction
Ensemble Methods
The BootstrapAggregation class provides predictions through an ensemble of fitted trees, where the training is done on bootstrap samples.
The GradientBoosting class provides predictions through an ensemble of trees trained on random subsets of the data and iteratively refined using the stochastic gradient boosting algorithm.
The RandomTrees class provides predictions through an ensemble of
fitted trees, where the training is done on bootstrap samples and random
subsets of predictors.
Missing Values
Any observation or case with a missing response variable is eliminated from
the analysis. If a predictor has a missing value, each algorithm skips that
case when evaluating the given predictor. When making a prediction for a new
case, if the split variable is missing, the prediction function applies the
surrogate split variables and splitting rules in turn, provided they were
estimated with the decision tree. Otherwise, the prediction function returns
the prediction from the most recent non-terminal node. In this
implementation, only ALACART estimates surrogate split variables when
requested.
-
Nested Class Summary
Nested Classes
- static class DecisionTree.MaxTreeSizeExceededException: Exception thrown when the maximum tree size has been exceeded.
- static class DecisionTree.PruningFailedToConvergeException: Exception thrown when pruning fails to converge.
- static class DecisionTree.PureNodeException: Exception thrown when attempting to split a node that is already pure (response variable is constant).
Nested classes/interfaces inherited from class com.imsl.datamining.PredictiveModel
PredictiveModel.CloneNotSupportedException, PredictiveModel.PredictiveModelException, PredictiveModel.StateChangeException, PredictiveModel.SumOfProbabilitiesNotOneException, PredictiveModel.VariableType -
Constructor Summary
Constructors
- DecisionTree(double[][] xy, int responseColumnIndex, PredictiveModel.VariableType[] varType): Constructs a DecisionTree object for a single response variable and multiple predictor variables.
-
Method Summary
Methods
- void fitModel(): Fits the decision tree.
- double[] getCostComplexityValues(): Returns an array containing cost-complexity values.
- Tree getDecisionTree(): Returns a Tree object.
- double getFittedMeanSquaredError(): Returns the mean squared error on the training data.
- int getMaxDepth(): Returns the maximum depth a tree is allowed to have.
- int getMaxNodes(): Returns the maximum number of TreeNode instances allowed in a tree.
- double getMeanSquaredPredictionError(): Returns the mean squared error.
- double getMinCostComplexityValue(): Returns the minimum cost-complexity value.
- int getMinObsPerChildNode(): Returns the minimum number of observations that are required for any child node before performing a split.
- int getMinObsPerNode(): Returns the minimum number of observations that are required in a node before performing a split.
- int[] getNodeAssigments(double[][] testData): Returns the terminal node assignments for each row of the test data.
- int getNumberOfComplexityValues(): Returns the number of cost complexity values determined by the pruning algorithm.
- int getNumberOfRandomFeatures(): Returns the number of random features used in the splitting rules when randomFeatureSelection = true.
- boolean isAutoPruningFlag(): Returns the current setting of the boolean to automatically prune the decision tree.
- boolean isRandomFeatureSelection(): Returns the current setting of the boolean to perform random feature selection.
- double[] predict(): Predicts the training examples (in-sample predictions) using the most recently grown tree.
- double[] predict(double[][] testData): Predicts new data using the most recently grown decision tree.
- double[] predict(double[][] testData, double[] testDataWeights): Predicts new weighted data using the most recently grown decision tree.
- void printDecisionTree(boolean printMaxTree): Prints the contents of the decision tree using distinct but general labels.
- void printDecisionTree(String responseName, String[] predictorNames, String[] classNames, String[] categoryNames, boolean printMaxTree): Prints the contents of the decision tree using labels.
- void pruneTree(double gamma): Finds the minimum cost-complexity decision tree for the cost-complexity value, gamma.
- protected abstract int selectSplitVariable(double[][] xy, double[] classCounts, double[] parentFreq, double[] splitValue, double[] splitCriterionValue, int[] splitPartition): Abstract method for selecting the next split variable and split definition for the node.
- void setAutoPruningFlag(boolean autoPruningFlag): Sets the flag to automatically prune the tree during the fitting procedure.
- protected void setConfiguration(PredictiveModel pm): Sets the configuration of PredictiveModel to that of the input model.
- void setCostComplexityValues(double[] gammas): Sets the cost-complexity values.
- void setMaxDepth(int nLevels): Sets the maximum tree depth allowed.
- void setMaxNodes(int maxNodes): Sets the maximum number of nodes allowed in a tree.
- void setMinCostComplexityValue(double minCostComplexity): Sets the minimum cost-complexity value.
- void setMinObsPerChildNode(int nObs): Sets the minimum number of observations that a child node must have in order to split.
- void setMinObsPerNode(int nObs): Sets the minimum number of observations a node must have to allow a split.
- void setNumberOfRandomFeatures(int numberOfRandomFeatures): Sets the number of predictors in the random subset to select from at each node.
- void setRandomFeatureSelection(boolean selectRandomFeatures): Sets the flag to select split variables from a random subset of the features.
Methods inherited from class com.imsl.datamining.PredictiveModel
clone, getClassCounts, getClassErrors, getClassErrors, getClassLabels, getClassProbabilities, getCostMatrix, getMaxNumberOfCategories, getMaxNumberOfIterations, getNumberOfClasses, getNumberOfColumns, getNumberOfMissing, getNumberOfPredictors, getNumberOfRows, getNumberOfUniquePredictorValues, getPredictorIndexes, getPredictorTypes, getPrintLevel, getPriorProbabilities, getRandomObject, getResponseColumnIndex, getResponseVariableAverage, getResponseVariableMostFrequentClass, getResponseVariableType, getTotalWeight, getVariableType, getWeights, getXY, isConstantSeries, isMustFitModel, isUserFixedNClasses, setClassCounts, setClassLabels, setClassProbabilities, setCostMatrix, setMaxNumberOfCategories, setMaxNumberOfIterations, setMustFitModel, setNumberOfClasses, setPredictorIndex, setPredictorTypes, setPrintLevel, setPriorProbabilities, setRandomObject, setResponseColumnIndex, setTrainingData, setVariableType, setWeights
-
Constructor Details
-
DecisionTree
Constructs a DecisionTree object for a single response variable and multiple predictor variables.
- Parameters:
xy - a double matrix containing the training data and associated response values
responseColumnIndex - an int specifying the column index of the response variable
varType - a PredictiveModel.VariableType array containing the type of each variable
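For illustration, a sketch of the expected data layout. It assumes the concrete ALACART subclass shares this constructor signature and that the PredictiveModel.VariableType constants are named as shown; the data values are made up:

import com.imsl.datamining.PredictiveModel.VariableType;
import com.imsl.datamining.decisionTree.ALACART;

// Columns of xy: [0] continuous predictor, [1] categorical predictor,
// [2] categorical response; varType holds one entry per column of xy.
double[][] xy = {
    {1.2, 0, 0}, {3.4, 1, 0}, {5.6, 0, 1}, {7.8, 1, 1}
};
VariableType[] varType = {
    VariableType.QUANTITATIVE_CONTINUOUS,
    VariableType.CATEGORICAL,
    VariableType.CATEGORICAL
};
DecisionTree tree = new ALACART(xy, 2, varType); // response in column 2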
-
-
Method Details
-
selectSplitVariable
protected abstract int selectSplitVariable(double[][] xy, double[] classCounts, double[] parentFreq, double[] splitValue, double[] splitCriterionValue, int[] splitPartition)
Abstract method for selecting the next split variable and split definition for the node.
- Parameters:
xy - a double matrix containing the data
classCounts - a double array containing the counts for each class of the response variable, when it is categorical
parentFreq - a double array used to indicate which subset of the observations belong in the current node
splitValue - a double array representing the resulting split point if the selected variable is quantitative
splitCriterionValue - a double, the value of the criterion used to determine the splitting variable
splitPartition - an int array indicating the resulting split partition if the selected variable is categorical
- Returns:
- an int specifying the index of the split variable in this.getPredictorIndexes()
-
fitModel
public void fitModel() throws PredictiveModel.PredictiveModelException, DecisionTree.PruningFailedToConvergeException, PredictiveModel.StateChangeException, DecisionTree.PureNodeException, PredictiveModel.SumOfProbabilitiesNotOneException, DecisionTree.MaxTreeSizeExceededException
Fits the decision tree. Implements the abstract method.
- Overrides:
fitModel in class PredictiveModel
- Throws:
PredictiveModel.PredictiveModelException - is thrown when an exception occurs in com.imsl.datamining.PredictiveModel. Superclass exceptions should be considered, such as com.imsl.datamining.PredictiveModel.StateChangeException and com.imsl.datamining.PredictiveModel.SumOfProbabilitiesNotOneException.
DecisionTree.PruningFailedToConvergeException - is thrown when pruning fails to converge.
PredictiveModel.StateChangeException - is thrown when an input parameter changes that might affect the model estimates or predictions.
DecisionTree.PureNodeException - is thrown when attempting to split a node that is already pure.
PredictiveModel.SumOfProbabilitiesNotOneException - is thrown when the sum of probabilities is not approximately one.
DecisionTree.MaxTreeSizeExceededException - is thrown when the maximum tree size has been exceeded.
-
getNumberOfComplexityValues
public int getNumberOfComplexityValues()
Returns the number of cost complexity values determined by the pruning algorithm.
- Returns:
- an int, the number of cost complexity values
-
getNodeAssigments
public int[] getNodeAssigments(double[][] testData)
Returns the terminal node assignments for each row of the test data.
- Parameters:
testData - a double matrix containing the test data. testData must have the same column structure and type as the training data.
- Returns:
- an int array containing the (0-based) terminal node ids for each observation in testData
-
getCostComplexityValues
public double[] getCostComplexityValues()
Returns an array containing cost-complexity values.
- Returns:
- a double array containing the cost-complexity values. The cost-complexity values are found via the optimal pruning algorithm of Breiman et al.
-
setCostComplexityValues
public void setCostComplexityValues(double[] gammas)
Sets the cost-complexity values. For the original tree, the values are generated in fitModel() when isAutoPruningFlag() returns true. This method is used when copying the configuration of one tree to another.
- Parameters:
gammas - a double array containing cost-complexity values
Default: gammas contains the single value set by setMinCostComplexityValue(double).
-
getDecisionTree
public Tree getDecisionTree() throws PredictiveModel.StateChangeException
Returns a Tree object.
- Returns:
- a Tree object containing the tree structure information
- Throws:
PredictiveModel.StateChangeException - an input parameter has changed that might affect the model estimates or predictions.
-
getFittedMeanSquaredError
public double getFittedMeanSquaredError() throws PredictiveModel.StateChangeException
Returns the mean squared error on the training data.
- Returns:
- a double equal to the mean squared error between the fitted value and the actual value of the response variable in the training data
- Throws:
PredictiveModel.StateChangeException - an input parameter has changed that might affect the model estimates or predictions.
-
getMaxDepth
public int getMaxDepth()
Returns the maximum depth a tree is allowed to have.
- Returns:
- an int, the maximum depth a tree is allowed to have
-
setMaxDepth
public void setMaxDepth(int nLevels)
Sets the maximum tree depth allowed.
- Parameters:
nLevels - an int specifying the maximum depth that the DecisionTree is allowed to have. nLevels should be strictly positive.
Default: nLevels = 10.
-
getMaxNodes
public int getMaxNodes()
Returns the maximum number of TreeNode instances allowed in a tree.
- Returns:
- an int, the maximum number of nodes allowed in a tree
-
setMaxNodes
public void setMaxNodes(int maxNodes)
Sets the maximum number of nodes allowed in a tree.
- Parameters:
maxNodes - an int specifying the maximum number of nodes allowed in a tree
Default: maxNodes = 100.
-
getMeanSquaredPredictionError
public double getMeanSquaredPredictionError() throws PredictiveModel.StateChangeException
Returns the mean squared error.
- Returns:
- a double equal to the mean squared error between the predicted value and the actual value of the response variable. The error is the in-sample fitted error if predict is first called with no arguments. Otherwise, the error is relative to the test data provided in the call to predict.
- Throws:
PredictiveModel.StateChangeException - an input parameter has changed that might affect the model estimates or predictions.
-
getMinObsPerChildNode
public int getMinObsPerChildNode()
Returns the minimum number of observations that are required for any child node before performing a split.
- Returns:
- an int, the minimum number of observations that are required for any child node before performing a split
-
setMinObsPerChildNode
public void setMinObsPerChildNode(int nObs)
Sets the minimum number of observations that a child node must have in order to split.
- Parameters:
nObs - an int specifying the minimum number of observations that a child node must have in order to split the current node. nObs must be strictly positive and smaller than the minimum number of observations required before a node can split (see setMinObsPerNode(int)).
Default: nObs = 7.
-
getMinObsPerNode
public int getMinObsPerNode()
Returns the minimum number of observations that are required in a node before performing a split.
- Returns:
- an int indicating the minimum number of observations that are required in a node before performing a split
-
setMinObsPerNode
public void setMinObsPerNode(int nObs)
Sets the minimum number of observations a node must have to allow a split.
- Parameters:
nObs - an int specifying the number of observations the current node must have before considering a split. nObs should be greater than 1 but less than or equal to the number of observations in xy.
Default: nObs = 21.
-
getNumberOfRandomFeatures
public int getNumberOfRandomFeatures()
Returns the number of random features used in the splitting rules when randomFeatureSelection = true.
- Returns:
- an int, the number of random features
-
setNumberOfRandomFeatures
public void setNumberOfRandomFeatures(int numberOfRandomFeatures)
Sets the number of predictors in the random subset to select from at each node.
- Parameters:
numberOfRandomFeatures - an int, the number of predictors in the random subset
Default: numberOfRandomFeatures = \(\sqrt{p}\) for classification problems and \(p/3\) for regression problems, where p is the number of predictors in the training data.
-
setRandomFeatureSelection
public void setRandomFeatureSelection(boolean selectRandomFeatures)
Sets the flag to select split variables from a random subset of the features.
- Parameters:
selectRandomFeatures - a boolean indicating whether or not to select random features
Default: selectRandomFeatures = false.
-
isRandomFeatureSelection
public boolean isRandomFeatureSelection()
Returns the current setting of the boolean to perform random feature selection.
- Returns:
- a boolean, the value of the flag. If true, the set of variables considered at each node is randomly selected. If the flag is false, all variables are considered at each node.
-
isAutoPruningFlag
public boolean isAutoPruningFlag()
Returns the current setting of the boolean to automatically prune the decision tree.
- Returns:
- a boolean, the value of the auto-pruning flag. If true, the model is configured to automatically prune the decision tree. If the flag is false, no pruning is performed.
-
setAutoPruningFlag
public void setAutoPruningFlag(boolean autoPruningFlag)
Sets the flag to automatically prune the tree during the fitting procedure. The default value is false. Set to true before calling fitModel() in order to prune the tree automatically. The pruning will use the cost-complexity value equal to minCostComplexityValue. See also pruneTree(double), which prunes the tree using a given cost-complexity value.
- Parameters:
autoPruningFlag - a boolean specifying the value of the flag. If true, the maximally grown tree is automatically pruned in fitModel().
Default: autoPruningFlag = false.
-
predict
public double[] predict() throws PredictiveModel.StateChangeException
Predicts the training examples (in-sample predictions) using the most recently grown tree.
- Specified by:
predict in class PredictiveModel
- Returns:
- a double array containing fitted values of the response variable using the most recently grown decision tree. Fitted values are populated by calling this predict method without arguments.
- Throws:
PredictiveModel.StateChangeException - is thrown when an input parameter changes that might affect the model estimates or predictions.
-
predict
public double[] predict(double[][] testData) throws PredictiveModel.StateChangeException
Predicts new data using the most recently grown decision tree.
- Specified by:
predict in class PredictiveModel
- Parameters:
testData - a double matrix containing test data for which predictions are to be made using the current tree. testData must have the same number of columns in the same arrangement as xy.
- Returns:
- a double array containing predicted values
- Throws:
PredictiveModel.StateChangeException - is thrown when an input parameter changes that might affect the model estimates or predictions.
-
predict
public double[] predict(double[][] testData, double[] testDataWeights) throws PredictiveModel.StateChangeException
Predicts new weighted data using the most recently grown decision tree.
- Specified by:
predict in class PredictiveModel
- Parameters:
testData - a double matrix containing test data for which predictions are to be made using the current tree. testData must have the same number of columns in the same arrangement as xy.
testDataWeights - a double array containing weights for each row of testData
- Returns:
- a double array containing predicted values
- Throws:
PredictiveModel.StateChangeException - is thrown when an input parameter changes that might affect the model estimates or predictions.
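As a sketch of the three predict variants, continuing the illustrative tree object from the constructor example (the test values are made up):

double[] fitted = tree.predict();                   // in-sample fitted values
double inSampleMse = tree.getMeanSquaredPredictionError();

double[][] testData = { {2.5, 0, 0}, {6.5, 1, 1} }; // same column layout as xy
double[] preds = tree.predict(testData);            // out-of-sample predictions

double[] weights = {1.0, 2.0};                      // one weight per test row
double[] weightedPreds = tree.predict(testData, weights);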
-
setMinCostComplexityValue
public void setMinCostComplexityValue(double minCostComplexity)
Sets the minimum cost-complexity value.
- Parameters:
minCostComplexity - a double indicating the smallest value to use in cost-complexity pruning. The value must be in [0.0, 1.0].
Default: minCostComplexity = 0.
-
getMinCostComplexityValue
public double getMinCostComplexityValue()
Returns the minimum cost-complexity value.
- Returns:
- a double, the minimum cost-complexity value
-
printDecisionTree
public void printDecisionTree(boolean printMaxTree)
Prints the contents of the decision tree using distinct but general labels. This method uses default values for the variable labels when printing (see printDecisionTree(String, String[], String[], String[], boolean) for these values).
- Parameters:
printMaxTree - a boolean indicating whether the maximal tree should be printed. When true, the maximal tree is printed. Otherwise, the pruned tree is printed.
-
printDecisionTree
public void printDecisionTree(String responseName, String[] predictorNames, String[] classNames, String[] categoryNames, boolean printMaxTree)
Prints the contents of the decision tree using labels.
- Parameters:
responseName - a String specifying a name for the response variable. If null, the default value is used. Default: responseName = Y
predictorNames - a String array specifying names for the predictor variables. If null, the default value is used. Default: predictorNames = X0, X1, ...
classNames - a String array specifying names for the class levels. If null, the default value is used. Default: classNames = 0, 1, 2, ...
categoryNames - a String array specifying names for the categories of the predictor variables. If null, the default value is used. Default: categoryNames = 0, 1, 2, ...
printMaxTree - a boolean indicating whether the maximal tree should be printed. When true, the maximal tree is printed. Otherwise, the pruned tree is printed.
-
setConfiguration
protected void setConfiguration(PredictiveModel pm)
Sets the configuration of PredictiveModel to that of the input model.
- Specified by:
setConfiguration in class PredictiveModel
- Parameters:
pm - a PredictiveModel object
-
pruneTree
public void pruneTree(double gamma)
Finds the minimum cost-complexity decision tree for the cost-complexity value, gamma. The method implements the optimal pruning algorithm 10.1, page 294 in Breiman et al. (1984). The result of the algorithm is a sequence of sub-trees \(T_{\max}\succ T_1\succ T_2\succ\ldots\succ T_{M-1}\succ\{t_0\}\) obtained by pruning the fully generated tree, \(T_{\max}\), until the sub-tree consists of the single root node, \(\{t_0\}\). Corresponding to the sequence of sub-trees is the sequence of complexity values, \(0\leq\delta_{\min}\lt\delta_1\lt\delta_2\lt\ldots\lt\delta_{M-1}\lt\delta_M\), where M is the number of steps it takes in the algorithm to reach the root node. The sub-trees represent the optimally pruned sub-trees for the sequence of complexity values.
The effect of the pruning is stored in the tree's terminal node array. That is, when the algorithm determines that the tree should be pruned at a particular node, it sets that node to be a terminal node using the method Tree.setTerminalNode(int, boolean). No other changes are made to the tree structure, so the maximal tree can still be printed and reviewed. However, once a tree is pruned, all predictions will use the pruned tree.
- Parameters:
gamma - a double giving the value of the cost-complexity parameter
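For example, a sketch continuing the illustrative tree object from the constructor example (the gamma value below is arbitrary):

tree.fitModel();                // grow the maximal tree
tree.pruneTree(0.05);           // illustrative gamma; marks pruned nodes as terminal
tree.printDecisionTree(true);   // the maximal tree is still intact and printable
tree.printDecisionTree(false);  // predictions now come from the pruned tree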
-