public abstract class DecisionTree extends PredictiveModel implements Serializable, Cloneable
Abstract class for generating a decision tree for a single response variable and one or more predictor variables.
This package contains four of the most widely used algorithms for decision trees (C45, ALACART, CHAID, and QUEST). The user may also provide an alternate algorithm by extending the DecisionTree or DecisionTreeInfoGain abstract class and implementing the abstract method DecisionTree.selectSplitVariable(double[][], double[], double[], double[], double[], int[]).
A common strategy to address overfitting is to grow the tree as large as possible and then prune it back. Let T represent a decision tree generated by any of the methods above. The idea (from Breiman et al.) is to find the smallest sub-tree of T that minimizes the cost-complexity measure:
$$R_{\delta}(T)=R(T)+\delta|\tilde{T}|\mbox{,}$$
where \(\tilde{T}\) denotes the set of terminal nodes, \(|\tilde{T}|\) represents the number of terminal nodes, and \(\delta\geq0\) is a cost-complexity parameter. For a categorical target variable,
$$R(T)=\sum_{t\in\tilde{T}}{R(t)}=\sum_{t\in\tilde{T}}{r(t)p(t)}\mbox{,}$$ $$r(t)=\min_i{\sum_j{C(i|j)p(j|t)}}\mbox{,}$$ $$p(t)=\mbox{Pr}[x\in t]\;\mbox{and}\;p(j|t)=\mbox{Pr}[y=j|x\in t]\mbox{,}$$
and \(C(i|j)\) is the cost for misclassifying the actual class j as i. Note that \(C(j|j)=0\) and \(C(i|j)\gt0\), for \(i\neq j\).
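As a small check on the misclassification-cost formulas, the following self-contained sketch (plain Java, not part of the JMSL API; the cost matrix and probabilities are made-up values) computes \(r(t)=\min_i\sum_j C(i|j)p(j|t)\) and the node contribution \(r(t)p(t)\) for a toy two-class node:

```java
public class NodeRisk {
    // r(t) = min_i sum_j C(i|j) * p(j|t): the expected misclassification
    // cost at node t when predicting class i, minimized over i.
    static double nodeRisk(double[][] cost, double[] pGivenT) {
        double best = Double.POSITIVE_INFINITY;
        for (int i = 0; i < cost.length; i++) {
            double s = 0.0;
            for (int j = 0; j < pGivenT.length; j++) {
                s += cost[i][j] * pGivenT[j];
            }
            best = Math.min(best, s);
        }
        return best;
    }

    public static void main(String[] args) {
        // Hypothetical values: C(i|j) with C(j|j)=0, and p(j|t).
        double[][] cost = {{0.0, 4.0}, {1.0, 0.0}};
        double[] pGivenT = {0.7, 0.3};
        double pT = 0.25; // p(t) = Pr[x in t]
        double r = nodeRisk(cost, pGivenT);
        System.out.println("r(t) = " + r);      // 0.7: predicting class 1 is cheaper
        System.out.println("R(t) = " + r * pT); // contribution r(t)p(t) to R(T)
    }
}
```

Predicting class 0 would cost \(4\times0.3=1.2\), predicting class 1 costs \(1\times0.7=0.7\), so the node risk is 0.7.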
When the target is continuous (and the problem is a regression problem), the metric is instead the mean squared error
$$R(T)=\sum_{t\in{\tilde{T}}}R(t)=\frac{1}{N}\sum_{t\in{\tilde{T}}}\sum_{y_n\in{t}}(y_n-\hat{y}(t))^2$$
This class implements the optimal pruning algorithm 10.1, page 294 in Breiman et al. (1984). The result of the algorithm is a sequence of sub-trees \(T_{\max}\succ T_1\succ T_2\succ\ldots\succ T_{M-1}\succ\{t_0\}\) obtained by pruning the fully generated tree, \(T_{\max}\), until the sub-tree consists of the single root node, \(\{t_0\}\). Corresponding to the sequence of sub-trees is the sequence of complexity values, \(0\leq\delta_{\min}\lt\delta_1\lt\delta_2\lt\ldots\lt\delta_{M-1}\lt\delta_M\), where M is the number of steps it takes in the algorithm to reach the root node. The sub-trees represent the optimally pruned sub-trees for the sequence of complexity values. The minimum complexity \(\delta_{\min}\) can be set via an optional argument.
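The trade-off that \(R_{\delta}(T)=R(T)+\delta|\tilde{T}|\) encodes can be seen in a toy comparison (self-contained Java, not the library's pruning code; the sub-tree risks are invented values):

```java
public class CostComplexity {
    // R_delta(T) = R(T) + delta * |T~|, where nodeRisks holds R(t) for each
    // terminal node t of a candidate sub-tree.
    static double costComplexity(double[] nodeRisks, double delta) {
        double r = 0.0;
        for (double v : nodeRisks) {
            r += v;
        }
        return r + delta * nodeRisks.length;
    }

    public static void main(String[] args) {
        // Two hypothetical sub-trees of the same maximal tree: the larger one
        // fits the training data better (smaller R(T)) but has more leaves.
        double[] largeTree = {0.02, 0.03, 0.01, 0.04}; // 4 terminal nodes
        double[] smallTree = {0.08, 0.07};             // 2 terminal nodes
        for (double delta : new double[]{0.0, 0.05}) {
            System.out.printf("delta=%.2f  large=%.3f  small=%.3f%n",
                    delta,
                    costComplexity(largeTree, delta),
                    costComplexity(smallTree, delta));
        }
        // At delta=0 the large tree wins (0.10 vs 0.15); at delta=0.05 the
        // complexity penalty flips the preference (0.30 vs 0.25).
    }
}
```

Increasing \(\delta\) penalizes terminal nodes more heavily, which is why the pruning algorithm's sequence of growing \(\delta\) values corresponds to a sequence of ever-smaller optimal sub-trees.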
The CrossValidation class can be used for model validation.
The BootstrapAggregation class provides predictions through an ensemble of fitted trees, where the training is done on bootstrap samples.
The GradientBoosting class provides predictions through an ensemble of trees trained on random subsets of the data and iteratively refined using the stochastic gradient boosting algorithm.
The RandomTrees class provides predictions through an ensemble of fitted trees, where the training is done on bootstrap samples and random subsets of predictors.
Any observation or case with a missing response variable is eliminated from the analysis. If a predictor has a missing value, each algorithm skips that case when evaluating the given predictor. When making a prediction for a new case, if the split variable is missing, the prediction function applies the surrogate split variables and splitting rules in turn, if they have been estimated with the decision tree. Otherwise, the prediction function returns the prediction from the most recent non-terminal node. In this implementation, only ALACART estimates surrogate split variables, when requested.
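The case-skipping rule for missing predictor values can be sketched as follows (a self-contained illustration, not the library's implementation; NaN stands in for a missing value, and the data are made up):

```java
import java.util.ArrayList;
import java.util.List;

public class MissingValues {
    // A case with a missing value in the predictor under evaluation is
    // skipped for that predictor only; it can still contribute when a
    // different predictor is evaluated.
    static List<Integer> usableRows(double[][] xy, int predictorColumn) {
        List<Integer> rows = new ArrayList<>();
        for (int i = 0; i < xy.length; i++) {
            if (!Double.isNaN(xy[i][predictorColumn])) {
                rows.add(i);
            }
        }
        return rows;
    }

    public static void main(String[] args) {
        double[][] xy = {
            {1.0, 2.0, 0.0},        // complete case
            {Double.NaN, 3.0, 1.0}, // missing in predictor 0 only
            {2.5, Double.NaN, 0.0}, // missing in predictor 1 only
        };
        System.out.println(usableRows(xy, 0)); // [0, 2]
        System.out.println(usableRows(xy, 1)); // [0, 1]
    }
}
```

Row 1 is skipped only when predictor 0 is evaluated, and row 2 only when predictor 1 is evaluated; neither row is dropped from the analysis as a whole.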
Modifier and Type | Class and Description |
---|---|
static class | DecisionTree.MaxTreeSizeExceededException: Exception thrown when the maximum tree size has been exceeded. |
static class | DecisionTree.PruningFailedToConvergeException: Exception thrown when pruning fails to converge. |
static class | DecisionTree.PureNodeException: Exception thrown when attempting to split a node that is already pure (the response variable is constant). |
Nested classes inherited from class com.imsl.datamining.PredictiveModel: PredictiveModel.CloneNotSupportedException, PredictiveModel.PredictiveModelException, PredictiveModel.StateChangeException, PredictiveModel.SumOfProbabilitiesNotOneException, PredictiveModel.VariableType
Constructor and Description |
---|
DecisionTree(double[][] xy, int responseColumnIndex, PredictiveModel.VariableType[] varType): Constructs a DecisionTree object for a single response variable and multiple predictor variables. |
Modifier and Type | Method and Description |
---|---|
void | fitModel(): Fits the decision tree. |
double[] | getCostComplexityValues(): Returns an array containing the cost-complexity values. |
Tree | getDecisionTree(): Returns a Tree object. |
double | getFittedMeanSquaredError(): Returns the mean squared error on the training data. |
int | getMaxDepth(): Returns the maximum depth a tree is allowed to have. |
int | getMaxNodes(): Returns the maximum number of TreeNode instances allowed in a tree. |
double | getMeanSquaredPredictionError(): Returns the mean squared error. |
double | getMinCostComplexityValue(): Returns the minimum cost-complexity value. |
int | getMinObsPerChildNode(): Returns the minimum number of observations that are required for any child node before performing a split. |
int | getMinObsPerNode(): Returns the minimum number of observations that are required in a node before performing a split. |
int[] | getNodeAssigments(double[][] testData): Returns the terminal node assignments for each row of the test data. |
int | getNumberOfComplexityValues(): Returns the number of cost-complexity values determined by the pruning algorithm. |
int | getNumberOfRandomFeatures(): Returns the number of random features used in the splitting rules when randomFeatureSelection = true. |
boolean | isAutoPruningFlag(): Returns the current setting of the flag to automatically prune the decision tree. |
boolean | isRandomFeatureSelection(): Returns the current setting of the flag to perform random feature selection. |
double[] | predict(): Predicts the training examples (in-sample predictions) using the most recently grown tree. |
double[] | predict(double[][] testData): Predicts new data using the most recently grown decision tree. |
double[] | predict(double[][] testData, double[] testDataWeights): Predicts new weighted data using the most recently grown decision tree. |
void | printDecisionTree(boolean printMaxTree): Prints the contents of the decision tree using distinct but general labels. |
void | printDecisionTree(String responseName, String[] predictorNames, String[] classNames, String[] categoryNames, boolean printMaxTree): Prints the contents of the decision tree using labels. |
void | pruneTree(double gamma): Finds the minimum cost-complexity decision tree for the cost-complexity value gamma. |
protected abstract int | selectSplitVariable(double[][] xy, double[] classCounts, double[] parentFreq, double[] splitValue, double[] splitCriterionValue, int[] splitPartition): Abstract method for selecting the next split variable and split definition for the node. |
void | setAutoPruningFlag(boolean autoPruningFlag): Sets the flag to automatically prune the tree during the fitting procedure. |
protected void | setConfiguration(PredictiveModel pm): Sets the configuration of PredictiveModel to that of the input model. |
void | setCostComplexityValues(double[] gammas): Sets the cost-complexity values. |
void | setMaxDepth(int nLevels): Sets the maximum tree depth allowed. |
void | setMaxNodes(int maxNodes): Sets the maximum number of nodes allowed in a tree. |
void | setMinCostComplexityValue(double minCostComplexity): Sets the minimum cost-complexity value. |
void | setMinObsPerChildNode(int nObs): Sets the minimum number of observations that a child node must have in order to split. |
void | setMinObsPerNode(int nObs): Sets the minimum number of observations a node must have to allow a split. |
void | setNumberOfRandomFeatures(int numberOfRandomFeatures): Sets the number of predictors in the random subset to select from at each node. |
void | setRandomFeatureSelection(boolean selectRandomFeatures): Sets the flag to select split variables from a random subset of the features. |
Methods inherited from class com.imsl.datamining.PredictiveModel: clone, getClassCounts, getClassErrors, getClassLabels, getClassProbabilities, getCostMatrix, getMaxNumberOfCategories, getMaxNumberOfIterations, getNumberOfClasses, getNumberOfColumns, getNumberOfMissing, getNumberOfPredictors, getNumberOfRows, getNumberOfUniquePredictorValues, getPredictorIndexes, getPredictorTypes, getPrintLevel, getPriorProbabilities, getRandomObject, getResponseColumnIndex, getResponseVariableAverage, getResponseVariableMostFrequentClass, getResponseVariableType, getTotalWeight, getVariableType, getWeights, getXY, isConstantSeries, isMustFitModel, isUserFixedNClasses, setClassCounts, setClassLabels, setClassProbabilities, setCostMatrix, setMaxNumberOfCategories, setMaxNumberOfIterations, setMustFitModel, setNumberOfClasses, setPredictorIndex, setPredictorTypes, setPrintLevel, setPriorProbabilities, setRandomObject, setResponseColumnIndex, setTrainingData, setVariableType, setWeights
public DecisionTree(double[][] xy, int responseColumnIndex, PredictiveModel.VariableType[] varType)

Constructs a DecisionTree object for a single response variable and multiple predictor variables.

Parameters:
xy - a double matrix containing the training data and associated response values
responseColumnIndex - an int specifying the column index of the response variable
varType - a PredictiveModel.VariableType array containing the type of each variable

protected abstract int selectSplitVariable(double[][] xy, double[] classCounts, double[] parentFreq, double[] splitValue, double[] splitCriterionValue, int[] splitPartition)

Abstract method for selecting the next split variable and split definition for the node.

Parameters:
xy - a double matrix containing the data
classCounts - a double array containing the counts for each class of the response variable, when it is categorical
parentFreq - a double array used to indicate which subset of the observations belong in the current node
splitValue - a double array representing the resulting split point if the selected variable is quantitative
splitCriterionValue - a double, the value of the criterion used to determine the splitting variable
splitPartition - an int array indicating the resulting split partition if the selected variable is categorical

Returns:
an int specifying the index of the split variable in this.getPredictorIndexes()
public void fitModel() throws PredictiveModel.PredictiveModelException, DecisionTree.PruningFailedToConvergeException, PredictiveModel.StateChangeException, DecisionTree.PureNodeException, PredictiveModel.SumOfProbabilitiesNotOneException, DecisionTree.MaxTreeSizeExceededException

Fits the decision tree.

Specified by:
fitModel in class PredictiveModel

Throws:
PredictiveModel.PredictiveModelException - thrown when an exception occurs in com.imsl.datamining.PredictiveModel. Superclass exceptions should be considered, such as com.imsl.datamining.PredictiveModel.StateChangeException and com.imsl.datamining.PredictiveModel.SumOfProbabilitiesNotOneException.
DecisionTree.PruningFailedToConvergeException - thrown when pruning fails to converge
PredictiveModel.StateChangeException - thrown when an input parameter changes that might affect the model estimates or predictions
DecisionTree.PureNodeException - thrown when attempting to split a node that is already pure
PredictiveModel.SumOfProbabilitiesNotOneException - thrown when the sum of probabilities is not approximately one
DecisionTree.MaxTreeSizeExceededException - thrown when the maximum tree size has been exceeded

public int getNumberOfComplexityValues()

Returns the number of cost-complexity values determined by the pruning algorithm.

Returns:
an int, the number of cost-complexity values

public int[] getNodeAssigments(double[][] testData)

Returns the terminal node assignments for each row of the test data.

Parameters:
testData - a double matrix containing the test data. testData must have the same column structure and type as the training data.

Returns:
an int array containing the (0-based) terminal node ids for each observation in testData
public double[] getCostComplexityValues()

Returns an array containing the cost-complexity values.

Returns:
a double array containing the cost-complexity values. The cost-complexity values are found via the optimal pruning algorithm of Breiman et al.

public void setCostComplexityValues(double[] gammas)

Sets the cost-complexity values. The values are used in DecisionTree.fitModel() when DecisionTree.isAutoPruningFlag() returns true. This method is used when copying the configuration of one tree to another.

Parameters:
gammas - a double array containing cost-complexity values.
Default: gammas = the minimum cost-complexity value (see DecisionTree.setMinCostComplexityValue(double)).
public Tree getDecisionTree() throws PredictiveModel.StateChangeException

Returns a Tree object.

Returns:
a Tree object containing the tree structure information

Throws:
PredictiveModel.StateChangeException - an input parameter has changed that might affect the model estimates or predictions

public double getFittedMeanSquaredError() throws PredictiveModel.StateChangeException

Returns the mean squared error on the training data.

Returns:
a double equal to the mean squared error between the fitted value and the actual value of the response variable in the training data

Throws:
PredictiveModel.StateChangeException - an input parameter has changed that might affect the model estimates or predictions

public int getMaxDepth()

Returns the maximum depth a tree is allowed to have.

Returns:
an int, the maximum depth a tree is allowed to have

public void setMaxDepth(int nLevels)

Sets the maximum tree depth allowed.

Parameters:
nLevels - an int specifying the maximum depth that the DecisionTree is allowed to have. nLevels should be strictly positive.
Default: nLevels = 10.
public int getMaxNodes()

Returns the maximum number of TreeNode instances allowed in a tree.

Returns:
an int, the maximum number of nodes allowed in a tree

public void setMaxNodes(int maxNodes)

Sets the maximum number of nodes allowed in a tree.

Parameters:
maxNodes - an int specifying the maximum number of nodes allowed in a tree
Default: maxNodes = 100.
public double getMeanSquaredPredictionError() throws PredictiveModel.StateChangeException

Returns the mean squared error.

Returns:
a double equal to the mean squared error between the predicted value and the actual value of the response variable. The error is the in-sample fitted error if predict is first called with no arguments. Otherwise, the error is relative to the test data provided in the call to predict.

Throws:
PredictiveModel.StateChangeException - an input parameter has changed that might affect the model estimates or predictions

public int getMinObsPerChildNode()

Returns the minimum number of observations that are required for any child node before performing a split.

Returns:
an int, the minimum number of observations that are required for any child node before performing a split

public void setMinObsPerChildNode(int nObs)

Sets the minimum number of observations that a child node must have in order to split.

Parameters:
nObs - an int specifying the minimum number of observations that a child node must have in order to split the current node. nObs must be strictly positive. nObs must also be greater than the minimum number of observations required before a node can split (see DecisionTree.setMinObsPerNode(int)).
Default: nObs = 7.
public int getMinObsPerNode()

Returns the minimum number of observations that are required in a node before performing a split.

Returns:
an int indicating the minimum number of observations that are required in a node before performing a split

public void setMinObsPerNode(int nObs)

Sets the minimum number of observations a node must have to allow a split.

Parameters:
nObs - an int specifying the number of observations the current node must have before considering a split. nObs should be greater than 1 but less than or equal to the number of observations in xy.
Default: nObs = 21.
public int getNumberOfRandomFeatures()

Returns the number of random features used in the splitting rules when randomFeatureSelection = true.

Returns:
an int, the number of random features

public void setNumberOfRandomFeatures(int numberOfRandomFeatures)

Sets the number of predictors in the random subset to select from at each node.

Parameters:
numberOfRandomFeatures - an int, the number of predictors in the random subset
Default: numberOfRandomFeatures = \(\sqrt{p}\) for classification problems and \(p/3\) for regression problems, where p is the number of predictors in the training data.
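A sketch of how these defaults could be computed (plain Java, not the JMSL implementation; rounding down and flooring at one feature are assumptions, since the documentation gives only \(\sqrt{p}\) and \(p/3\)):

```java
public class RandomFeatureDefault {
    // Assumed default size of the random predictor subset: sqrt(p) for
    // classification, p/3 for regression, truncated to an int and kept
    // at least 1 (the truncation rule is this sketch's assumption).
    static int defaultRandomFeatures(int p, boolean classification) {
        int m = classification ? (int) Math.sqrt(p) : p / 3;
        return Math.max(1, m);
    }

    public static void main(String[] args) {
        System.out.println(defaultRandomFeatures(16, true));  // 4
        System.out.println(defaultRandomFeatures(16, false)); // 5
        System.out.println(defaultRandomFeatures(2, false));  // 1
    }
}
```

For example, with 16 predictors a classification tree would consider 4 randomly chosen predictors at each node under this rule, while a regression tree would consider 5.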
public void setRandomFeatureSelection(boolean selectRandomFeatures)

Sets the flag to select split variables from a random subset of the features.

Parameters:
selectRandomFeatures - a boolean indicating whether or not to select random features
Default: selectRandomFeatures = false.
public boolean isRandomFeatureSelection()

Returns the current setting of the flag to perform random feature selection.

Returns:
a boolean, the value of the flag. If true, the set of variables considered at each node is randomly selected. If the flag is false, all variables are considered at each node.

public boolean isAutoPruningFlag()

Returns the current setting of the flag to automatically prune the decision tree.

Returns:
a boolean, the value of the auto-pruning flag. If true, the model is configured to automatically prune the decision tree. If the flag is false, no pruning is performed.

public void setAutoPruningFlag(boolean autoPruningFlag)

Sets the flag to automatically prune the tree during the fitting procedure. The default value is false. Set to true before calling DecisionTree.fitModel() in order to prune the tree automatically. The pruning will use the cost-complexity value equal to minCostComplexityValue. See also DecisionTree.pruneTree(double), which prunes the tree using a given cost-complexity value.

Parameters:
autoPruningFlag - a boolean specifying the value of the flag. If true, the maximally grown tree is automatically pruned in DecisionTree.fitModel().
Default: autoPruningFlag = false.
public double[] predict() throws PredictiveModel.StateChangeException

Predicts the training examples (in-sample predictions) using the most recently grown tree.

Specified by:
predict in class PredictiveModel

Returns:
a double array containing fitted values of the response variable using the most recently grown decision tree. To populate fitted values, use the predict method without arguments.

Throws:
PredictiveModel.StateChangeException - thrown when an input parameter changes that might affect the model estimates or predictions

public double[] predict(double[][] testData) throws PredictiveModel.StateChangeException

Predicts new data using the most recently grown decision tree.

Specified by:
predict in class PredictiveModel

Parameters:
testData - a double matrix containing test data for which predictions are to be made using the current tree. testData must have the same number of columns in the same arrangement as xy.

Returns:
a double array containing predicted values

Throws:
PredictiveModel.StateChangeException - thrown when an input parameter changes that might affect the model estimates or predictions

public double[] predict(double[][] testData, double[] testDataWeights) throws PredictiveModel.StateChangeException

Predicts new weighted data using the most recently grown decision tree.

Specified by:
predict in class PredictiveModel

Parameters:
testData - a double matrix containing test data for which predictions are to be made using the current tree. testData must have the same number of columns in the same arrangement as xy.
testDataWeights - a double array containing weights for each row of testData

Returns:
a double array containing predicted values

Throws:
PredictiveModel.StateChangeException - thrown when an input parameter changes that might affect the model estimates or predictions

public void setMinCostComplexityValue(double minCostComplexity)
Sets the minimum cost-complexity value.

Parameters:
minCostComplexity - a double indicating the smallest value to use in cost-complexity pruning. The value must be in [0.0, 1.0].
Default: minCostComplexity = 0.

public double getMinCostComplexityValue()

Returns the minimum cost-complexity value.

Returns:
a double, the minimum cost-complexity value
public void printDecisionTree(boolean printMaxTree)

Prints the contents of the decision tree using distinct but general labels. This method uses default values for the variable labels when printing (see printDecisionTree(String, String[], String[], String[], boolean) for these values).

Parameters:
printMaxTree - a boolean indicating that the maximal tree should be printed. When true, the maximal tree is printed. Otherwise, the pruned tree is printed.

public void printDecisionTree(String responseName, String[] predictorNames, String[] classNames, String[] categoryNames, boolean printMaxTree)
Prints the contents of the decision tree using labels.

Parameters:
responseName - a String specifying a name for the response variable. If null, the default value is used.
Default: responseName = Y
predictorNames - a String array specifying names for the predictor variables. If null, the default value is used.
Default: predictorNames = X0, X1, ...
classNames - a String array specifying names for the class levels. If null, the default value is used.
Default: classNames = 0, 1, 2, ...
categoryNames - a String array specifying names for the categories of the predictor variables. If null, the default value is used.
Default: categoryNames = 0, 1, 2, ...
printMaxTree - a boolean indicating that the maximal tree should be printed. When true, the maximal tree is printed. Otherwise, the pruned tree is printed.

protected void setConfiguration(PredictiveModel pm)
Sets the configuration of PredictiveModel to that of the input model.

Overrides:
setConfiguration in class PredictiveModel

Parameters:
pm - a PredictiveModel object

public void pruneTree(double gamma)
Finds the minimum cost-complexity decision tree for the cost-complexity value, gamma. The method implements the optimal pruning algorithm 10.1, page 294 in Breiman et al. (1984). The result of the algorithm is a sequence of sub-trees \(T_{\max}\succ T_1\succ T_2\succ\ldots\succ T_{M-1}\succ\{t_0\}\) obtained by pruning the fully generated tree, \(T_{\max}\), until the sub-tree consists of the single root node, \(\{t_0\}\). Corresponding to the sequence of sub-trees is the sequence of complexity values, \(0\leq\delta_{\min}\lt\delta_1\lt\delta_2\lt\ldots\lt\delta_{M-1}\lt\delta_M\), where M is the number of steps it takes in the algorithm to reach the root node. The sub-trees represent the optimally pruned sub-trees for the sequence of complexity values.

The effect of the pruning is stored in the tree's terminal node array. That is, when the algorithm determines that the tree should be pruned at a particular node, it sets that node to be a terminal node using the method Tree.setTerminalNode(int, boolean). No other changes are made to the tree structure, so the maximal tree can still be printed and reviewed. However, once a tree is pruned, all predictions will use the pruned tree.

Parameters:
gamma - a double giving the value of the cost-complexity parameter

Copyright © 2020 Rogue Wave Software. All rights reserved.