public abstract class DecisionTree extends PredictiveModel implements Serializable, Cloneable
Abstract class for generating a decision tree for a single response variable and one or more predictor variables.
This package contains four of the most widely used algorithms for decision
trees (C45, ALACART, CHAID, and QUEST). The
user may also provide an alternate algorithm by extending the
DecisionTree or DecisionTreeInfoGain abstract class and
implementing the abstract method selectSplitVariable(
double[][], double[], double[], double[], int[]).
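For illustration, the hedged sketch below outlines what such a subclass could look like. The split rule (always splitting the first predictor at its weighted mean within the node), the class name MeanSplitTree, the assumption that the split point is returned through splitValue[0], and the package/import names are illustrative assumptions rather than library behavior.

```java
import com.imsl.datamining.PredictiveModel;
import com.imsl.datamining.decisionTree.DecisionTree; // package name assumed

/*
 * Hypothetical subclass for illustration only. It proposes a split on the
 * first non-response column at its weighted mean within the current node,
 * which is a placeholder rule, not C45, ALACART, CHAID, or QUEST. It assumes
 * selectSplitVariable is the only abstract member a subclass must supply.
 */
public class MeanSplitTree extends DecisionTree {

    public MeanSplitTree(double[][] xy, int responseColumnIndex,
                         PredictiveModel.VariableType[] varType) {
        super(xy, responseColumnIndex, varType);
    }

    @Override
    protected int selectSplitVariable(double[][] xy, double[] classCounts,
            double[] parentFreq, double[] splitValue, int[] splitPartition) {
        // Choose the first column that is not the response variable and
        // treat it as quantitative.
        int splitColumn = (getResponseColumnIndex() == 0) ? 1 : 0;

        // Weighted mean of that column; parentFreq indicates which
        // observations (and with what frequency) belong to the current node.
        double sum = 0.0, weight = 0.0;
        for (int i = 0; i < xy.length; i++) {
            sum += parentFreq[i] * xy[i][splitColumn];
            weight += parentFreq[i];
        }
        // Assumption: a quantitative split point is returned via splitValue[0].
        splitValue[0] = (weight > 0.0) ? sum / weight : 0.0;
        return splitColumn;
    }
}
```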
A strategy to address overfitting is to grow the tree as large as possible and then prune it back. Let T represent a decision tree generated by any of the methods above. The idea (from Breiman, et al.) is to find the smallest sub-tree of T that minimizes the cost-complexity measure

$$R_\alpha(T) = R(T) + \alpha\,|\tilde{T}|$$

where $\tilde{T}$ denotes the set of terminal nodes, $|\tilde{T}|$ represents the number of terminal nodes, and $\alpha \ge 0$ is a cost-complexity parameter. For a categorical target variable,

$$R(T) = \sum_{t \in \tilde{T}} R(t), \qquad R(t) = r(t)\,p(t), \qquad r(t) = \min_i \sum_j C(i \mid j)\,p(j \mid t),$$

where $p(t)$ is the probability that an observation falls into terminal node $t$, $p(j \mid t)$ is the probability that an observation in node $t$ belongs to class $j$, and $C(i \mid j)$ is the cost for misclassifying the actual class $j$ as $i$. Note that $C(j \mid j) = 0$ and $C(i \mid j) \ge 0$, for $i \ne j$.

When the target is continuous (and the problem is a regression problem), the metric is instead the mean squared error

$$R(T) = \frac{1}{N} \sum_{t \in \tilde{T}} \sum_{x_n \in t} \bigl(y_n - \bar{y}(t)\bigr)^2,$$

where $\bar{y}(t)$ is the mean response of the observations in terminal node $t$ and N is the total number of observations.
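As a small worked illustration with hypothetical numbers (not produced by the library): a tree with $|\tilde{T}| = 8$ terminal nodes, resubstitution cost $R(T) = 0.12$, and $\alpha = 0.01$ has

$$R_\alpha(T) = 0.12 + 0.01 \times 8 = 0.20.$$

At this value of $\alpha$, collapsing a branch improves the measure whenever the resulting increase in $R(T)$ is smaller than $\alpha$ times the reduction in the number of terminal nodes.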
This class implements the optimal pruning algorithm 10.1, page 294 in Breiman, et al. (1984). The result of the algorithm is a sequence of sub-trees $T_1 \supset T_2 \supset \cdots \supset T_M$ obtained by pruning the fully generated tree, $T_1$, until the sub-tree consists of the single root node, $T_M$. Corresponding to the sequence of sub-trees is the sequence of complexity values, $\alpha_1 < \alpha_2 < \cdots < \alpha_M$, where M is the number of steps it takes the algorithm to reach the root node. The sub-trees represent the optimally pruned sub-trees for the sequence of complexity values. The minimum complexity value can be set via setMinCostComplexityValue(double).
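For example, the sketch below inspects the complexity-value sequence of an already-fitted tree and prunes at the largest value in the sequence. It assumes `tree` is an instance of a concrete subclass, that the array returned by getCostComplexityValues() has getNumberOfComplexityValues() entries, and that the package name is as shown; treat it as an illustration rather than library documentation.

```java
import com.imsl.datamining.PredictiveModel;
import com.imsl.datamining.decisionTree.DecisionTree; // package name assumed

public final class PruningSketch {
    // Prune an already-fitted tree at the largest complexity value found.
    static void pruneAtLargestComplexityValue(DecisionTree tree)
            throws DecisionTree.PruningFailedToConvergeException,
                   PredictiveModel.StateChangeException {
        double[] alphas = tree.getCostComplexityValues(); // alpha_1 < ... < alpha_M
        int m = tree.getNumberOfComplexityValues();       // assumed == alphas.length
        System.out.println("Number of complexity values: " + m);

        // The largest value corresponds to the most heavily pruned sub-tree
        // in the sequence (closest to the single root node).
        tree.pruneTree(alphas[m - 1]);
        tree.printDecisionTree(false); // print the pruned, not the maximal, tree
    }
}
```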
The CrossValidation class can be used for model validation.
The BootstrapAggregation class provides predictions through a
resampling scheme.
Any observation or case with a missing response variable is eliminated from the analysis. If a predictor has a missing value, each algorithm skips that case when evaluating the given predictor. When making a prediction for a new case, if the split variable is missing, the prediction function applies the surrogate split variables and splitting rules in turn, if they have been estimated with the decision tree. Otherwise, the prediction function returns the prediction from the most recently visited non-terminal node. In this implementation, only ALACART estimates surrogate split variables when requested.
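Putting the pieces together, here is a hedged end-to-end sketch. Only the DecisionTree constructor is documented on this page, so the ALACART constructor, the PredictiveModel.VariableType constant names, and the toy data are illustrative assumptions.

```java
import com.imsl.datamining.PredictiveModel;
import com.imsl.datamining.decisionTree.ALACART; // class/constructor assumed

public class DecisionTreeFitSketch {
    public static void main(String[] args) throws Exception {
        // Toy data: a quantitative predictor in column 0 and a categorical
        // response in column 1, laid out as the DecisionTree constructor expects.
        double[][] xy = {
            {1.0, 0}, {1.5, 0}, {2.0, 0}, {2.5, 0},
            {3.0, 1}, {3.5, 1}, {4.0, 1}, {4.5, 1}
        };
        // VariableType constant names are assumed.
        PredictiveModel.VariableType[] varType = {
            PredictiveModel.VariableType.QUANTITATIVE_CONTINUOUS,
            PredictiveModel.VariableType.CATEGORICAL
        };

        ALACART tree = new ALACART(xy, 1, varType);
        // Relax the default size controls so this tiny data set can split.
        tree.setMinObsPerNode(2);
        tree.setMinObsPerChildNode(1);

        tree.fitModel();                     // grow the tree
        double[] fitted = tree.predict();    // in-sample predictions
        System.out.println("First fitted value: " + fitted[0]);

        // Test rows must have the same columns, in the same order, as xy.
        double[][] test = {{1.2, 0}, {4.2, 1}};
        double[] preds = tree.predict(test);
        System.out.println("Prediction for first test row: " + preds[0]);
        System.out.println("MSE: " + tree.getMeanSquaredPredictionError());
    }
}
```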
| Modifier and Type | Class and Description |
|---|---|
| static class | DecisionTree.MaxTreeSizeExceededException Exception thrown when the maximum tree size has been exceeded. |
| static class | DecisionTree.PruningFailedToConvergeException Exception thrown when pruning fails to converge. |
| static class | DecisionTree.PureNodeException Exception thrown when attempting to split a node that is already pure (response variable is constant). |

Nested classes/interfaces inherited from class com.imsl.datamining.PredictiveModel:
PredictiveModel.PredictiveModelException, PredictiveModel.StateChangeException, PredictiveModel.SumOfProbabilitiesNotOneException, PredictiveModel.VariableType

| Constructor and Description |
|---|
| DecisionTree(double[][] xy, int responseColumnIndex, PredictiveModel.VariableType[] varType) Constructs a DecisionTree object for a single response variable and multiple predictor variables. |
| Modifier and Type | Method and Description |
|---|---|
| void | fitModel() Fits the decision tree. |
| double[] | getCostComplexityValues() Returns an array containing cost-complexity values. |
| Tree | getDecisionTree() Returns a Tree object. |
| double | getFittedMeanSquaredError() Returns the mean squared error on the training data. |
| int | getMaxDepth() Returns the maximum depth a tree is allowed to have. |
| int | getMaxNodes() Returns the maximum number of TreeNode instances allowed in a tree. |
| double | getMeanSquaredPredictionError() Returns the mean squared error. |
| int | getMinObsPerChildNode() Returns the minimum number of observations that are required for any child node before performing a split. |
| int | getMinObsPerNode() Returns the minimum number of observations that are required in a node before performing a split. |
| int | getNumberOfComplexityValues() Returns the number of cost-complexity values determined. |
| protected int | getNumberOfSets(double[] parentFreqs, int[] splita) Returns the number of sets for a split. |
| boolean | isAutoPruningFlag() Returns the auto-pruning flag. |
| double[] | predict() Predicts the training examples (in-sample predictions) using the most recently grown tree. |
| double[] | predict(double[][] testData) Predicts new data using the most recently grown decision tree. |
| double[] | predict(double[][] testData, double[] testDataWeights) Predicts new weighted data using the most recently grown decision tree. |
| void | printDecisionTree(boolean printMaxTree) Prints the contents of the decision tree using distinct but general labels. |
| void | printDecisionTree(String responseName, String[] predictorNames, String[] classNames, String[] categoryNames, boolean printMaxTree) Prints the contents of the decision tree. |
| void | pruneTree(double gamma) Finds the minimum cost-complexity decision tree for the cost-complexity value, gamma. |
| protected abstract int | selectSplitVariable(double[][] xy, double[] classCounts, double[] parentFreq, double[] splitValue, int[] splitPartition) Abstract method for selecting the next split variable and split definition for the node. |
| void | setAutoPruningFlag(boolean autoPruningFlag) Sets the flag to automatically prune the tree during the fitting procedure. |
| protected void | setConfiguration(PredictiveModel pm) Sets the configuration of PredictiveModel to that of the input model. |
| void | setCostComplexityValues(double[] gammas) Sets the cost-complexity values. |
| void | setMaxDepth(int nLevels) Specifies the maximum tree depth allowed. |
| void | setMaxNodes(int maxNodes) Sets the maximum number of nodes allowed in a tree. |
| void | setMinCostComplexityValue(double minCostComplexity) Sets the minimum cost-complexity value. |
| void | setMinObsPerChildNode(int nObs) Specifies the minimum number of observations that a child node must have in order to split, one of several tree size and splitting control parameters. |
| void | setMinObsPerNode(int nObs) Specifies the minimum number of observations a node must have to allow a split, one of several tree size and splitting control parameters. |
Methods inherited from class com.imsl.datamining.PredictiveModel:
getClassCounts, getCostMatrix, getMaxNumberOfCategories, getNumberOfClasses, getNumberOfColumns, getNumberOfMissing, getNumberOfPredictors, getNumberOfRows, getNumberOfUniquePredictorValues, getPredictorIndexes, getPredictorTypes, getPrintLevel, getPriorProbabilities, getResponseColumnIndex, getResponseVariableAverage, getResponseVariableMostFrequentClass, getResponseVariableType, getTotalWeight, getVariableType, getWeights, getXY, isMustFitModelFlag, isUserFixedNClasses, setClassCounts, setCostMatrix, setMaxNumberOfCategories, setNumberOfClasses, setPredictorIndex, setPredictorTypes, setPrintLevel, setPriorProbabilities, setWeights

public DecisionTree(double[][] xy,
                    int responseColumnIndex,
                    PredictiveModel.VariableType[] varType)
Constructs a DecisionTree object for a single response variable and multiple predictor variables.
Parameters:
xy - a double matrix with rows containing the observations on the predictor variables and one response variable.
responseColumnIndex - an int specifying the column index of the response variable.
varType - a PredictiveModel.VariableType array containing the type of each variable.

public void fitModel()
throws PredictiveModel.PredictiveModelException,
DecisionTree.PruningFailedToConvergeException,
PredictiveModel.StateChangeException,
DecisionTree.PureNodeException,
PredictiveModel.SumOfProbabilitiesNotOneException,
DecisionTree.MaxTreeSizeExceededException
Fits the decision tree.
Specified by:
fitModel in class PredictiveModel
Throws:
PredictiveModel.PredictiveModelException - an exception has occurred in the com.imsl.datamining.PredictiveModel. Superclass exceptions should be considered, such as com.imsl.datamining.PredictiveModel.StateChangeException and com.imsl.datamining.PredictiveModel.SumOfProbabilitiesNotOneException.
DecisionTree.PruningFailedToConvergeException - pruning has failed to converge.
PredictiveModel.StateChangeException - an input parameter has changed that might affect the model estimates or predictions.
DecisionTree.PureNodeException - attempting to split a node that is already pure.
PredictiveModel.SumOfProbabilitiesNotOneException - the sum of probabilities is not approximately one.
DecisionTree.MaxTreeSizeExceededException - the maximum tree size has been exceeded.

public double[] getCostComplexityValues()
throws DecisionTree.PruningFailedToConvergeException,
PredictiveModel.StateChangeException
Returns an array containing cost-complexity values.
The cost-complexity values are found via the optimal pruning algorithm of Breiman, et al.
Returns:
a double array containing the cost-complexity values.
Throws:
DecisionTree.PruningFailedToConvergeException - pruning has failed to converge.
PredictiveModel.StateChangeException - an input parameter has changed that might affect the model estimates or predictions.

public Tree getDecisionTree() throws PredictiveModel.StateChangeException
Returns a Tree object.
Returns:
a Tree object containing the tree structure information.
Throws:
PredictiveModel.StateChangeException - an input parameter has changed that might affect the model estimates or predictions.

public double getFittedMeanSquaredError()
throws PredictiveModel.StateChangeException
Returns the mean squared error on the training data.
Returns:
a double equal to the mean squared error between the fitted value and the actual value of the response variable in the training data.
Throws:
PredictiveModel.StateChangeException - an input parameter has changed that might affect the model estimates or predictions.

public int getMaxDepth()
Returns the maximum depth a tree is allowed to have.
Returns:
an int indicating the maximum depth a tree is allowed to have.

public int getMaxNodes()
Returns the maximum number of TreeNode instances allowed in a tree.
Returns:
an int indicating the maximum number of nodes allowed in a tree.

public double getMeanSquaredPredictionError()
throws PredictiveModel.StateChangeException
Returns the mean squared error.
Returns:
a double equal to the mean squared error between the predicted value and the actual value of the response variable. The error is the in-sample fitted error if predict is first called with no arguments. Otherwise, the error is relative to the test data provided in the call to predict.
Throws:
PredictiveModel.StateChangeException - an input parameter has changed that might affect the model estimates or predictions.

public int getMinObsPerChildNode()
Returns the minimum number of observations that are required for any child node before performing a split.
Returns:
an int indicating the minimum number of observations that are required for any child node before performing a split.

public int getMinObsPerNode()
Returns the minimum number of observations that are required in a node before performing a split.
Returns:
an int indicating the minimum number of observations that are required in a node before performing a split.

public int getNumberOfComplexityValues()
Returns the number of cost-complexity values determined.
Returns:
an int indicating the number of cost-complexity values determined.

protected int getNumberOfSets(double[] parentFreqs,
                              int[] splita)
Returns the number of sets for a split.
Parameters:
parentFreqs - a double array containing frequencies of the response variable in the data subset of the parent node.
splita - an int array that contains the split partition determined in the selectSplitVariable(double[][], double[], double[], double[], int[]) method.
Returns:
an int that is the number of sets.

public boolean isAutoPruningFlag()
Returns the auto-pruning flag. See setAutoPruningFlag(boolean) for details.
Returns:
a boolean which, if true, means that the model is configured to perform automatic pruning.

public double[] predict()
throws PredictiveModel.StateChangeException
Predicts the training examples (in-sample predictions) using the most recently grown tree.
Specified by:
predict in class PredictiveModel
Returns:
a double array of fitted values of the response variable using the most recently grown decision tree. To populate fitted values, use the predict method without arguments.
Throws:
PredictiveModel.StateChangeException - an input parameter has changed that might affect the model estimates or predictions.

public double[] predict(double[][] testData)
throws PredictiveModel.StateChangeException
Predicts new data using the most recently grown decision tree.
Specified by:
predict in class PredictiveModel
Parameters:
testData - a double matrix containing test data for which predictions are to be made using the current tree. testData must have the same number of columns, arranged in the same order, as xy.
Returns:
a double array of predicted values of the response variable using the most recently grown decision tree.
Throws:
PredictiveModel.StateChangeException - an input parameter has changed that might affect the model estimates or predictions.

public double[] predict(double[][] testData,
double[] testDataWeights)
throws PredictiveModel.StateChangeException
Predicts new weighted data using the most recently grown decision tree.
Specified by:
predict in class PredictiveModel
Parameters:
testData - a double matrix containing test data for which predictions are to be made using the current tree. testData must have the same number of columns, arranged in the same order, as xy.
testDataWeights - a double array containing weights for each row of testData.
Returns:
a double array of predicted values of the response variable using the most recently grown decision tree.
Throws:
PredictiveModel.StateChangeException - an input parameter has changed that might affect the model estimates or predictions.

public void printDecisionTree(boolean printMaxTree)
Prints the contents of the decision tree using distinct but general labels.
This method uses default values for the variable labels when printing (see printDecisionTree(String, String[], String[], String[], boolean) for these values).
Parameters:
printMaxTree - a boolean indicating that the maximal tree should be printed. Otherwise, the pruned tree is printed.
public void printDecisionTree(String responseName, String[] predictorNames, String[] classNames, String[] categoryNames, boolean printMaxTree)
Prints the contents of the decision tree.
Parameters:
responseName - a String specifying a name for the response variable. If null, the default value is used. Default: responseName = Y.
predictorNames - a String array specifying names for the predictor variables. If null, the default values are used. Default: predictorNames = X0, X1, ...
classNames - a String array specifying names for the classes. If null, the default values are used. Default: classNames = 0, 1, 2, ...
categoryNames - a String array specifying names for the categories. If null, the default values are used. Default: categoryNames = 0, 1, 2, ...
printMaxTree - a boolean indicating that the maximal tree should be printed. Otherwise, the pruned tree is printed.
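A hedged usage sketch for the labeled form, assuming `tree` is an already-fitted subclass instance with a categorical response and at least one categorical predictor; the label values are illustrative.

```java
// Print the pruned tree with user-supplied labels (illustrative values).
static void printWithLabels(DecisionTree tree) {
    String responseName = "RiskLevel";
    String[] predictorNames = {"Age", "Income", "Region"};
    String[] classNames = {"Low", "High"};
    String[] categoryNames = {"North", "South", "East", "West"};
    tree.printDecisionTree(responseName, predictorNames, classNames,
                           categoryNames, false); // false: print the pruned tree
}
```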
public void pruneTree(double gamma)
Finds the minimum cost-complexity decision tree for the cost-complexity value, gamma.
The method implements the optimal pruning algorithm 10.1, page 294 in Breiman, et al. (1984). The result of the algorithm is a sequence of sub-trees $T_1 \supset T_2 \supset \cdots \supset T_M$ obtained by pruning the fully generated tree, $T_1$, until the sub-tree consists of the single root node, $T_M$. Corresponding to the sequence of sub-trees is the sequence of complexity values, $\alpha_1 < \alpha_2 < \cdots < \alpha_M$, where M is the number of steps it takes the algorithm to reach the root node. The sub-trees represent the optimally pruned sub-trees for the sequence of complexity values.
Pruned nodes are marked as terminal via Tree.setTerminalNode(int, boolean). No other changes are made to the tree structure, so the maximal tree can still be printed and reviewed. However, once a tree is pruned, all predictions use the pruned tree.
Parameters:
gamma - a double equal to the cost-complexity parameter.

protected abstract int selectSplitVariable(double[][] xy,
double[] classCounts,
double[] parentFreq,
double[] splitValue,
int[] splitPartition)
Abstract method for selecting the next split variable and split definition for the node.
Parameters:
xy - a double matrix containing the data.
classCounts - a double array containing the counts for each class of the response variable, when it is categorical.
parentFreq - a double array used to indicate which subset of the observations belong in the current node.
splitValue - a double array representing the resulting split point if the selected variable is quantitative.
splitPartition - an int array indicating the resulting split partition if the selected variable is categorical.
Returns:
an int specifying the column index of the split variable in xy.

public void setAutoPruningFlag(boolean autoPruningFlag)
Sets the flag to automatically prune the tree during the fitting procedure.
The default value is false. Set to true before calling fitModel() in order to prune the tree automatically. The pruning uses the cost-complexity value equal to minCostComplexityValue. See also pruneTree(double), which prunes the tree using a given cost-complexity value.
Parameters:
autoPruningFlag - a boolean value that, when true, indicates that the maximally grown tree should be automatically pruned in fitModel(). Default: autoPruningFlag = false.
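A brief sketch of the auto-pruning configuration, assuming `tree` is an unfitted concrete subclass instance (checked exceptions from fitModel() are omitted for brevity):

```java
// Prune automatically during fitting at the minimum cost-complexity value.
tree.setMinCostComplexityValue(0.01); // must be in [0.0, 1.0]
tree.setAutoPruningFlag(true);        // default is false
tree.fitModel();                      // grows the maximal tree, then prunes it
```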
protected void setConfiguration(PredictiveModel pm) throws DecisionTree.PruningFailedToConvergeException, PredictiveModel.StateChangeException, PredictiveModel.SumOfProbabilitiesNotOneException
Sets the configuration of PredictiveModel to that of the input model.
Overrides:
setConfiguration in class PredictiveModel
Parameters:
pm - a PredictiveModel object which is to have its attributes duplicated in this instance.
Throws:
DecisionTree.PruningFailedToConvergeException - pruning has failed to converge.
PredictiveModel.StateChangeException - an input parameter has changed that might affect the model estimates or predictions.
PredictiveModel.SumOfProbabilitiesNotOneException - the sum of the probabilities does not equal 1.

public void setCostComplexityValues(double[] gammas)
Sets the cost-complexity values.
The values are used during fitModel() when isAutoPruningFlag() returns true. This method is also used when copying the configuration of one tree to another.
Parameters:
gammas - a double array containing cost-complexity values. Default: gammas = the value set via setMinCostComplexityValue(double).
public void setMaxDepth(int nLevels)
Specifies the maximum tree depth allowed.
Parameters:
nLevels - an int specifying the maximum depth that the DecisionTree is allowed to have. nLevels should be strictly positive. Default: nLevels = 10.
public void setMaxNodes(int maxNodes)
Sets the maximum number of nodes allowed in a tree.
Parameters:
maxNodes - an int specifying the maximum number of nodes allowed in a tree. Default: maxNodes = 100.
public void setMinCostComplexityValue(double minCostComplexity)
Sets the minimum cost-complexity value.
Parameters:
minCostComplexity - a double indicating the smallest value to use in cost-complexity pruning. The value must be in [0.0, 1.0]. Default: minCostComplexity = 0.
public void setMinObsPerChildNode(int nObs)
Specifies the minimum number of observations that a child node must have in order to split, one of several tree size and splitting control parameters.
Parameters:
nObs - an int specifying the minimum number of observations that a child node must have in order to split the current node. nObs must be strictly positive and must not exceed the minimum number of observations required before a node can split (see setMinObsPerNode(int)). Default: nObs = 7.
public void setMinObsPerNode(int nObs)
Specifies the minimum number of observations a node must have to allow a split, one of several tree size and splitting control parameters.
Parameters:
nObs - an int specifying the number of observations the current node must have before considering a split. nObs should be greater than 1 but less than or equal to the number of observations in xy. Default: nObs = 21.
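Taken together, the size and splitting controls are set before fitting; a hedged sketch with illustrative values, assuming `tree` is an unfitted concrete subclass instance:

```java
// Tree size and splitting controls (defaults shown in the comments).
tree.setMaxDepth(5);           // maximum depth, default 10
tree.setMaxNodes(50);          // maximum number of nodes, default 100
tree.setMinObsPerNode(10);     // observations required to consider a split, default 21
tree.setMinObsPerChildNode(5); // observations required in each child, default 7
```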