Class DecisionTree

java.lang.Object
com.imsl.datamining.PredictiveModel
com.imsl.datamining.decisionTree.DecisionTree
All Implemented Interfaces:
Serializable, Cloneable
Direct Known Subclasses:
CHAID, DecisionTreeInfoGain, QUEST

public abstract class DecisionTree extends PredictiveModel implements Serializable, Cloneable

Abstract class for generating a decision tree for a single response variable and one or more predictor variables.

Tree Generation Methods

This package contains four of the most widely used algorithms for decision trees (C45, ALACART, CHAID, and QUEST). The user may also provide an alternate algorithm by extending the DecisionTree or DecisionTreeInfoGain abstract classes and implementing the abstract method selectSplitVariable(double[][], double[], double[], double[], double[], int[]).

Optimal Tree Size

Minimum Cost-complexity Pruning

One strategy for addressing overfitting is to grow the tree as large as possible and then prune it back. Let T represent a decision tree generated by any of the methods above. The idea (from Breiman et al.) is to find the smallest sub-tree of T that minimizes the cost-complexity measure:

$$R_{\delta}(T)=R(T)+\delta|\tilde{T}|\mbox{,} $$

\(\tilde{T}\) denotes the set of terminal nodes. \(|\tilde{T}|\) represents the number of terminal nodes, and \(\delta\geq0\) is a cost-complexity parameter. For a categorical target variable

$$R(T)=\sum_{t\in\tilde{T}}R(t)=\sum_{t\in\tilde{T}}r(t)p(t)\mbox{,}$$
$$r(t)=\min_i\sum_j C(i|j)p(j|t)\mbox{,}$$
$$p(t)=\mbox{Pr}[x\in t]\;\mbox{and}\;p(j|t)=\mbox{Pr}[y=j|x\in t]\mbox{,}$$

and \(C(i|j)\) is the cost for misclassifying the actual class j as i. Note that \(C(j|j)=0\) and \(C(i|j)\gt0\), for \(i\neq j\).
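To make the classification measure concrete, the following Java sketch evaluates \(R_{\delta}(T)\) from the terminal nodes alone. It is not IMSL code: the class ClassificationCostComplexity and its methods are hypothetical, and it assumes p(t) and p(j|t) have already been estimated (for example, as proportions of training cases).

```java
/**
 * Sketch of the cost-complexity measure R_delta(T) = R(T) + delta * |T~| for a
 * categorical target. Only the terminal nodes are needed: for each one we pass
 * p(t) and the vector p(j|t). cost[i][j] holds C(i|j), the cost of classifying
 * an observation of actual class j as class i.
 */
class ClassificationCostComplexity {

    /** r(t) = min over i of sum_j C(i|j) p(j|t): expected cost of the best class label. */
    static double nodeRisk(double[][] cost, double[] pClassGivenNode) {
        double best = Double.POSITIVE_INFINITY;
        for (int i = 0; i < cost.length; i++) {
            double expectedCost = 0.0;
            for (int j = 0; j < pClassGivenNode.length; j++) {
                expectedCost += cost[i][j] * pClassGivenNode[j];
            }
            best = Math.min(best, expectedCost);
        }
        return best;
    }

    /** R_delta(T): sum over terminal nodes of r(t) p(t), plus delta times the number of terminal nodes. */
    static double costComplexity(double[][] cost, double[] pNode,
                                 double[][] pClassGivenNode, double delta) {
        double risk = 0.0;
        for (int t = 0; t < pNode.length; t++) {
            risk += nodeRisk(cost, pClassGivenNode[t]) * pNode[t];
        }
        return risk + delta * pNode.length;
    }

    public static void main(String[] args) {
        double[][] cost = {{0, 1}, {1, 0}};            // unit misclassification costs
        double[] pNode = {0.6, 0.4};                    // p(t) for two terminal nodes
        double[][] pClass = {{0.9, 0.1}, {0.25, 0.75}}; // p(j|t) for each terminal node
        System.out.println(costComplexity(cost, pNode, pClass, 0.0));
        System.out.println(costComplexity(cost, pNode, pClass, 0.01));
    }
}
```

With unit costs, r(t) is the error rate of the best single label in node t, so here R(T) = 0.1·0.6 + 0.25·0.4 = 0.16, and each terminal node adds δ to the measure.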

When the target is continuous (and the problem is a regression problem), the metric is instead the mean squared error

$$R(T)=\sum_{t\in{\tilde{T}}}R(t)=\frac{1}{N} \sum_{t\in{\tilde{T}}}\sum_{y_n\in{t}}(y_n-\hat{y}(t))^2$$
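For the regression case, here is a minimal sketch of R(T), assuming \(\hat{y}(t)\) is the mean response of the training cases in node t (the usual least-squares choice; the formula above does not say how \(\hat{y}(t)\) is computed). The RegressionRisk class and its leaf-assignment input are hypothetical.

```java
/**
 * Sketch: R(T) for a regression tree, i.e. the training mean squared error when
 * each observation is predicted by the mean response of its terminal node.
 */
class RegressionRisk {

    /**
     * @param y         response values y_n
     * @param leaf      leaf[n] is the (0-based) terminal node containing observation n
     * @param numLeaves number of terminal nodes
     */
    static double risk(double[] y, int[] leaf, int numLeaves) {
        double[] sum = new double[numLeaves];
        int[] count = new int[numLeaves];
        for (int n = 0; n < y.length; n++) {
            sum[leaf[n]] += y[n];
            count[leaf[n]]++;
        }
        double sse = 0.0;
        for (int n = 0; n < y.length; n++) {
            double yHat = sum[leaf[n]] / count[leaf[n]]; // node mean, assumed to be y-hat(t)
            double residual = y[n] - yHat;
            sse += residual * residual;
        }
        return sse / y.length; // divide by N, as in the formula for R(T)
    }

    public static void main(String[] args) {
        double[] y = {1.0, 3.0, 10.0, 12.0};
        int[] leaf = {0, 0, 1, 1};
        System.out.println(risk(y, leaf, 2));
    }
}
```

For y = {1, 3, 10, 12} split into two leaves, the node means are 2 and 11, every residual is ±1, and R(T) = 4/4 = 1.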

This class implements the optimal pruning algorithm 10.1 on page 294 of Breiman et al. (1984). The result of the algorithm is a sequence of sub-trees \(T_{\max}\succ T_1\succ T_2\succ\cdots\succ T_{M-1}\succ\{t_0\}\) obtained by pruning the fully generated tree, \(T_{\max}\), until the sub-tree consists of the single root node, \(\{t_0\}\). Corresponding to the sequence of sub-trees is the sequence of complexity values \(0\leq\delta_{\min}\lt\delta_1\lt\delta_2\lt\ldots\lt\delta_{M-1}\lt\delta_M\), where M is the number of steps the algorithm takes to reach the root node. Each sub-tree is the optimally pruned sub-tree for the corresponding complexity value. The minimum complexity \(\delta_{\min}\) can be set with setMinCostComplexityValue(double).
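The weakest-link step behind this pruning procedure can be sketched as follows. For each internal node t it computes \(g(t)=(R(t)-R(T_t))/(|\tilde{T}_t|-1)\), where \(T_t\) is the branch rooted at t, collapses every node attaining the minimum, and records that minimum as the next complexity value. This is an illustrative Java version with a hypothetical Node type, not the IMSL implementation; it starts from zero rather than from \(\delta_{\min}\), recomputes branch sums naively, and, like the approach described under pruneTree, collapses a node only by setting a terminal flag while keeping its children.

```java
import java.util.ArrayList;
import java.util.List;

class WeakestLinkPruning {

    /** Hypothetical binary tree node; risk is R(t), the risk if t were made terminal. */
    static final class Node {
        final double risk;
        final Node left, right;
        boolean terminal;

        Node(double risk, Node left, Node right) {
            this.risk = risk;
            this.left = left;
            this.right = right;
            this.terminal = (left == null && right == null);
        }
    }

    /** R(T_t): sum of R over the current terminal nodes of the branch rooted at t. */
    static double branchRisk(Node t) {
        return t.terminal ? t.risk : branchRisk(t.left) + branchRisk(t.right);
    }

    /** |T~_t|: number of current terminal nodes in the branch rooted at t. */
    static int leafCount(Node t) {
        return t.terminal ? 1 : leafCount(t.left) + leafCount(t.right);
    }

    /** g(t) for an internal node t. */
    static double link(Node t) {
        return (t.risk - branchRisk(t)) / (leafCount(t) - 1);
    }

    /** Smallest g(t) over the internal nodes of the branch rooted at t. */
    static double minLink(Node t) {
        if (t.terminal) {
            return Double.POSITIVE_INFINITY;
        }
        return Math.min(link(t), Math.min(minLink(t.left), minLink(t.right)));
    }

    /** Top-down, so each g(t) is evaluated before any node below t is collapsed. */
    static void collapse(Node t, double alpha, double tol) {
        if (t.terminal) {
            return;
        }
        if (link(t) <= alpha + tol) {
            t.terminal = true; // children are kept; only the flag changes
            return;
        }
        collapse(t.left, alpha, tol);
        collapse(t.right, alpha, tol);
    }

    /** Complexity values at which the tree shrinks, until only the root remains. */
    static List<Double> complexitySequence(Node root) {
        List<Double> alphas = new ArrayList<>();
        while (!root.terminal) {
            double alpha = minLink(root);
            collapse(root, alpha, 1e-12);
            alphas.add(alpha);
        }
        return alphas;
    }

    /** Hypothetical example: the root's left child has two leaves, its right child is a leaf. */
    static Node exampleTree() {
        Node left = new Node(0.20, new Node(0.05, null, null), new Node(0.10, null, null));
        return new Node(0.50, left, new Node(0.10, null, null));
    }

    public static void main(String[] args) {
        System.out.println(complexitySequence(exampleTree()));
    }
}
```

For the example tree the sequence is approximately {0.05, 0.2}: the left branch is pruned first because it buys the least risk reduction per additional terminal node (0.05 versus 0.125 for the root).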

V-Fold Cross-Validation

The CrossValidation class can be used for model validation.

Prediction

Ensemble Methods

The BootstrapAggregation class provides predictions through an ensemble of fitted trees, where the training is done on bootstrap samples.

The GradientBoosting class provides predictions through an ensemble of trees trained on random subsets of the data and iteratively refined using the stochastic gradient boosting algorithm.

The RandomTrees class provides predictions through an ensemble of fitted trees, where the training is done on bootstrap samples and random subsets of predictors.
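Bootstrap sampling, which both BootstrapAggregation and RandomTrees rely on, can be sketched in a few lines. The class below is hypothetical and only illustrates the idea: each tree gets n rows drawn with replacement from the n training rows, and a regression ensemble is assumed to combine its trees by a plain average (how the IMSL classes aggregate predictions is not stated here).

```java
import java.util.Arrays;
import java.util.Random;

/** Sketch of bootstrap sampling for a bagging-style ensemble (hypothetical helper). */
class BootstrapSample {

    /** Draws n row indices uniformly at random, with replacement, from 0..n-1. */
    static int[] draw(int n, Random rng) {
        int[] rows = new int[n];
        for (int i = 0; i < n; i++) {
            rows[i] = rng.nextInt(n);
        }
        return rows;
    }

    /** Copies the sampled rows of xy into a new training matrix of the same size. */
    static double[][] resample(double[][] xy, Random rng) {
        int[] rows = draw(xy.length, rng);
        double[][] sample = new double[rows.length][];
        for (int i = 0; i < rows.length; i++) {
            sample[i] = xy[rows[i]].clone();
        }
        return sample;
    }

    /** Combines per-tree predictions for one case by a plain average (assumed, regression). */
    static double average(double[] treePredictions) {
        return Arrays.stream(treePredictions).average().orElse(Double.NaN);
    }

    public static void main(String[] args) {
        double[][] xy = {{1.0, 10.0}, {2.0, 20.0}, {3.0, 30.0}, {4.0, 40.0}};
        double[][] sample = resample(xy, new Random(2024));
        System.out.println(Arrays.deepToString(sample));
        System.out.println(average(new double[] {10.0, 12.0, 11.0}));
    }
}
```

Because rows are drawn with replacement, some rows appear several times in a bootstrap sample and others not at all, which is what makes the fitted trees differ.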

Missing Values

Any observation or case with a missing response variable is eliminated from the analysis. If a predictor has a missing value, each algorithm skips that case when evaluating the given predictor. When making a prediction for a new case, if the split variable is missing, the prediction function applies the surrogate split variables and splitting rules in turn, provided they were estimated when the decision tree was fitted. Otherwise, the prediction function returns the prediction from the last non-terminal node reached. In this implementation, only ALACART estimates surrogate split variables, and only when requested.
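The prediction rule for missing values can be illustrated with a small tree walk. This sketch assumes missing values are encoded as Double.NaN and simplifies every rule, primary or surrogate, to "go left when x[variable] <= threshold"; the Split and Node types are hypothetical, and real surrogate rules may also send cases in the opposite direction.

```java
import java.util.List;

class MissingValuePrediction {

    /** Hypothetical split rule: go left when x[variable] <= threshold. */
    static final class Split {
        final int variable;
        final double threshold;

        Split(int variable, double threshold) {
            this.variable = variable;
            this.threshold = threshold;
        }

        /** Returns TRUE (left) or FALSE (right), or null when the variable is missing. */
        Boolean goesLeft(double[] x) {
            double v = x[variable];
            if (Double.isNaN(v)) {
                return null;
            }
            return v <= threshold;
        }
    }

    static final class Node {
        final double prediction;      // prediction stored at this node
        final Split primary;          // null for a terminal node
        final List<Split> surrogates; // tried in order when the primary variable is missing
        final Node left, right;

        Node(double prediction, Split primary, List<Split> surrogates, Node left, Node right) {
            this.prediction = prediction;
            this.primary = primary;
            this.surrogates = surrogates;
            this.left = left;
            this.right = right;
        }

        static Node leaf(double prediction) {
            return new Node(prediction, null, List.of(), null, null);
        }
    }

    static double predict(Node node, double[] x) {
        while (node.primary != null) {
            Boolean left = node.primary.goesLeft(x);
            // Primary split variable missing: apply the surrogate rules in turn.
            for (int s = 0; left == null && s < node.surrogates.size(); s++) {
                left = node.surrogates.get(s).goesLeft(x);
            }
            if (left == null) {
                return node.prediction; // no usable rule: stop at this non-terminal node
            }
            node = left ? node.left : node.right;
        }
        return node.prediction; // reached a terminal node
    }

    /** Hypothetical example: split on X0 at 5.0, with X1 <= 2.0 as the only surrogate. */
    static Node exampleTree() {
        return new Node(1.5, new Split(0, 5.0), List.of(new Split(1, 2.0)),
                Node.leaf(1.0), Node.leaf(2.0));
    }

    public static void main(String[] args) {
        Node root = exampleTree();
        System.out.println(predict(root, new double[] {3.0, Double.NaN}));
        System.out.println(predict(root, new double[] {Double.NaN, 4.0}));
        System.out.println(predict(root, new double[] {Double.NaN, Double.NaN}));
    }
}
```

In the example, a case with both X0 and X1 missing cannot be routed past the root, so it receives the root's own prediction.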

  • Constructor Details

    • DecisionTree

      public DecisionTree(double[][] xy, int responseColumnIndex, PredictiveModel.VariableType[] varType)
      Constructs a DecisionTree object for a single response variable and multiple predictor variables.
      Parameters:
      xy - a double matrix containing the training data and associated response values
      responseColumnIndex - an int specifying the column index of the response variable
      varType - a PredictiveModel.VariableType array containing the type of each variable
  • Method Details

    • selectSplitVariable

      protected abstract int selectSplitVariable(double[][] xy, double[] classCounts, double[] parentFreq, double[] splitValue, double[] splitCriterionValue, int[] splitPartition)
      Abstract method for selecting the next split variable and split definition for the node.
      Parameters:
      xy - a double matrix containing the data
      classCounts - a double array containing the counts for each class of the response variable, when it is categorical
      parentFreq - a double array used to indicate which subset of the observations belong in the current node
      splitValue - a double array representing the resulting split point if the selected variable is quantitative
      splitCriterionValue - a double array containing the value of the criterion used to determine the splitting variable
      splitPartition - an int array indicating the resulting split partition if the selected variable is categorical
      Returns:
      an int specifying the index of the split variable in this.getPredictorIndexes()
    • fitModel

      Fits the decision tree. Implements the abstract method.
      Overrides:
      fitModel in class PredictiveModel
      Throws:
      PredictiveModel.PredictiveModelException - is thrown when an exception occurs in com.imsl.datamining.PredictiveModel. Superclass exceptions, such as com.imsl.datamining.PredictiveModel.StateChangeException and com.imsl.datamining.PredictiveModel.SumOfProbabilitiesNotOneException, should also be considered.
      DecisionTree.PruningFailedToConvergeException - is thrown when pruning fails to converge.
      PredictiveModel.StateChangeException - is thrown when an input parameter changes that might affect the model estimates or predictions.
      DecisionTree.PureNodeException - is thrown when attempting to split a node that is already pure.
      PredictiveModel.SumOfProbabilitiesNotOneException - is thrown when the sum of probabilities is not approximately one.
      DecisionTree.MaxTreeSizeExceededException - is thrown when the maximum tree size has been exceeded.
    • getNumberOfComplexityValues

      public int getNumberOfComplexityValues()
      Returns the number of cost complexity values determined by the pruning algorithm.
      Returns:
      an int, the number of cost complexity values
    • getNodeAssigments

      public int[] getNodeAssigments(double[][] testData)
      Returns the terminal node assignments for each row of the test data.
      Parameters:
      testData - a double matrix containing the test data

      testData must have the same column structure and type as the training data.

      Returns:
      an int array containing the (0-based) terminal node IDs for each observation in testData
    • getCostComplexityValues

      public double[] getCostComplexityValues()
      Returns an array containing cost-complexity values.
      Returns:
      a double array containing the cost-complexity values

      The cost-complexity values are found via the optimal pruning algorithm of Breiman et al.

    • setCostComplexityValues

      public void setCostComplexityValues(double[] gammas)
      Sets the cost-complexity values. For the original tree, the values are generated in fitModel() when isAutoPruningFlag() returns true. This method is used when copying the configuration of one tree to another.
      Parameters:
      gammas - a double array containing cost-complexity values

      Default: gammas = the value set by setMinCostComplexityValue(double).

    • getDecisionTree

      public Tree getDecisionTree() throws PredictiveModel.StateChangeException
      Returns a Tree object.
      Returns:
      a Tree object containing the tree structure information
      Throws:
      PredictiveModel.StateChangeException - an input parameter has changed that might affect the model estimates or predictions.
    • getFittedMeanSquaredError

      public double getFittedMeanSquaredError() throws PredictiveModel.StateChangeException
      Returns the mean squared error on the training data.
      Returns:
      a double equal to the mean squared error between the fitted value and the actual value of the response variable in the training data
      Throws:
      PredictiveModel.StateChangeException - an input parameter has changed that might affect the model estimates or predictions.
    • getMaxDepth

      public int getMaxDepth()
      Returns the maximum depth a tree is allowed to have.
      Returns:
      an int, the maximum depth a tree is allowed to have
    • setMaxDepth

      public void setMaxDepth(int nLevels)
      Sets the maximum tree depth allowed.
      Parameters:
      nLevels - an int specifying the maximum depth that the DecisionTree is allowed to have. nLevels should be strictly positive.

      Default: nLevels = 10.

    • getMaxNodes

      public int getMaxNodes()
      Returns the maximum number of TreeNode instances allowed in a tree.
      Returns:
      an int, the maximum number of nodes allowed in a tree
    • setMaxNodes

      public void setMaxNodes(int maxNodes)
      Sets the maximum number of nodes allowed in a tree.
      Parameters:
      maxNodes - an int specifying the maximum number of nodes allowed in a tree

      Default: maxNodes = 100.

    • getMeanSquaredPredictionError

      public double getMeanSquaredPredictionError() throws PredictiveModel.StateChangeException
      Returns the mean squared error.
      Returns:
      a double equal to the mean squared error between the predicted value and the actual value of the response variable. The error is the in-sample fitted error if predict is first called with no arguments. Otherwise, the error is relative to the test data provided in the call to predict.
      Throws:
      PredictiveModel.StateChangeException - an input parameter has changed that might affect the model estimates or predictions.
    • getMinObsPerChildNode

      public int getMinObsPerChildNode()
      Returns the minimum number of observations that are required for any child node before performing a split.
      Returns:
      an int, the minimum number of observations that are required for any child node before performing a split
    • setMinObsPerChildNode

      public void setMinObsPerChildNode(int nObs)
      Sets the minimum number of observations that each child node must have for the current node to be split.
      Parameters:
      nObs - an int specifying the minimum number of observations that each child node must have in order to split the current node. nObs must be strictly positive. nObs must also be less than the minimum number of observations required before a node can split (see setMinObsPerNode(int)).

      Default: nObs = 7.

    • getMinObsPerNode

      public int getMinObsPerNode()
      Returns the minimum number of observations that are required in a node before performing a split.
      Returns:
      an int indicating the minimum number of observations that are required in a node before performing a split.
    • setMinObsPerNode

      public void setMinObsPerNode(int nObs)
      Sets the minimum number of observations a node must have to allow a split.
      Parameters:
      nObs - an int specifying the number of observations the current node must have before considering a split. nObs should be greater than 1 but less than or equal to the number of observations in xy.

      Default: nObs = 21.
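The size limits set by setMaxDepth, setMaxNodes, setMinObsPerNode, and setMinObsPerChildNode act together as stopping rules while the tree is grown. A hypothetical sketch of such a check, using the documented defaults, follows; it assumes binary or multiway splits whose children are all added at once, counts the root as level 1, and treats each limit as a simple inequality, which may differ in detail from the IMSL implementation.

```java
/** Hypothetical stopping-rule check using the documented default limits. */
class SplitRules {
    int maxDepth = 10;          // setMaxDepth default
    int maxNodes = 100;         // setMaxNodes default
    int minObsPerNode = 21;     // setMinObsPerNode default
    int minObsPerChildNode = 7; // setMinObsPerChildNode default

    /** Enforces 1 < perNode and 0 < perChild < perNode. */
    void setMinObs(int perNode, int perChild) {
        if (perNode <= 1 || perChild <= 0 || perChild >= perNode) {
            throw new IllegalArgumentException(
                    "require 1 < minObsPerNode and 0 < minObsPerChildNode < minObsPerNode");
        }
        minObsPerNode = perNode;
        minObsPerChildNode = perChild;
    }

    /**
     * True if a node with nObs observations at the given level (root = level 1, assumed)
     * may be split into children of the given sizes when the tree has nodeCount nodes.
     */
    boolean canSplit(int nObs, int level, int nodeCount, int[] childCounts) {
        if (nObs < minObsPerNode) return false;                      // node too small
        if (level >= maxDepth) return false;                         // children would exceed maxDepth
        if (nodeCount + childCounts.length > maxNodes) return false; // tree would exceed maxNodes
        for (int c : childCounts) {
            if (c < minObsPerChildNode) return false;                // a child would be too small
        }
        return true;
    }

    public static void main(String[] args) {
        SplitRules rules = new SplitRules();
        System.out.println(rules.canSplit(30, 1, 1, new int[] {15, 15}));
        System.out.println(rules.canSplit(30, 1, 1, new int[] {25, 5}));
    }
}
```

The setter check reflects why the child threshold must be the smaller one: with the defaults, a node needs 21 observations to be split, and each of its children then needs at least 7.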

    • getNumberOfRandomFeatures

      public int getNumberOfRandomFeatures()
      Returns the number of random features used in the splitting rules when randomFeatureSelection=true.
      Returns:
      an int, the number of random features
    • setNumberOfRandomFeatures

      public void setNumberOfRandomFeatures(int numberOfRandomFeatures)
      Sets the number of predictors in the random subset to select from at each node.
      Parameters:
      numberOfRandomFeatures - an int, the number of predictors in the random subset

      Default: numberOfRandomFeatures = \(\sqrt{p}\) for classification problems and p/3 for regression problems, where p is the number of predictors in the training data.

    • setRandomFeatureSelection

      public void setRandomFeatureSelection(boolean selectRandomFeatures)
      Sets the flag to select split variables from a random subset of the features.
      Parameters:
      selectRandomFeatures - a boolean, indicating whether or not to select random features

      Default: selectRandomFeatures = false.
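Random feature selection can be sketched as drawing a fresh subset of predictors, without replacement, at each node. The RandomFeatures class is hypothetical; rounding \(\sqrt{p}\) and p/3 down (with a floor of one predictor) is an assumption, since the documented default does not say how non-integer values are handled.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

class RandomFeatures {

    /** Default subset size: sqrt(p) for classification, p/3 for regression (rounded down, assumed). */
    static int defaultCount(int p, boolean classification) {
        int m = classification ? (int) Math.floor(Math.sqrt(p)) : p / 3;
        return Math.max(1, m); // always consider at least one predictor
    }

    /** Draws m distinct predictor indices (without replacement) from the candidates. */
    static List<Integer> select(List<Integer> predictors, int m, Random rng) {
        List<Integer> shuffled = new ArrayList<>(predictors);
        Collections.shuffle(shuffled, rng);
        return new ArrayList<>(shuffled.subList(0, Math.min(m, shuffled.size())));
    }

    public static void main(String[] args) {
        List<Integer> predictors = List.of(0, 1, 2, 3, 4, 5, 6, 7, 8);
        int m = defaultCount(predictors.size(), true); // sqrt(9) = 3
        System.out.println(select(predictors, m, new Random(11)));
    }
}
```

Only the predictors returned by select would then be evaluated by the split-selection step at that node.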

    • isRandomFeatureSelection

      public boolean isRandomFeatureSelection()
      Returns the current setting of the boolean to perform random feature selection.
      Returns:
      a boolean, the value of the flag. If true, the set of variables considered at each node is randomly selected. If the flag is false, all variables are considered at each node.
    • isAutoPruningFlag

      public boolean isAutoPruningFlag()
      Returns the current setting of the boolean to automatically prune the decision tree.
      Returns:
      a boolean, the value of the auto-pruning flag. If true, the model is configured to automatically prune the decision tree. If the flag is false, no pruning is performed.
    • setAutoPruningFlag

      public void setAutoPruningFlag(boolean autoPruningFlag)
      Sets the flag to automatically prune the tree during the fitting procedure.

      The default value is false. Set to true before calling fitModel() in order to prune the tree automatically. The pruning will use the cost-complexity value equal to minCostComplexityValue. See also pruneTree(double) which prunes the tree using a given cost-complexity value.

      Parameters:
      autoPruningFlag - a boolean, specifying the value of the flag. If true, the maximally grown tree should be automatically pruned in fitModel()

      Default: autoPruningFlag=false.

    • predict

      public double[] predict() throws PredictiveModel.StateChangeException
      Predicts the training examples (in-sample predictions) using the most recently grown tree.
      Specified by:
      predict in class PredictiveModel
      Returns:
      a double array containing fitted values of the response variable using the most recently grown decision tree. To populate fitted values, use the predict method without arguments.
      Throws:
      PredictiveModel.StateChangeException - is thrown when an input parameter changes that might affect the model estimates or predictions.
    • predict

      public double[] predict(double[][] testData) throws PredictiveModel.StateChangeException
      Predicts new data using the most recently grown decision tree.
      Specified by:
      predict in class PredictiveModel
      Parameters:
      testData - a double matrix containing test data for which predictions are to be made using the current tree. testData must have the same number of columns in the same arrangement as xy.
      Returns:
      a double array containing predicted values
      Throws:
      PredictiveModel.StateChangeException - is thrown when an input parameter changes that might affect the model estimates or predictions.
    • predict

      public double[] predict(double[][] testData, double[] testDataWeights) throws PredictiveModel.StateChangeException
      Predicts new weighted data using the most recently grown decision tree.
      Specified by:
      predict in class PredictiveModel
      Parameters:
      testData - a double matrix containing test data for which predictions are to be made using the current tree. testData must have the same number of columns in the same arrangement as xy.
      testDataWeights - a double array containing weights for each row of testData
      Returns:
      a double array containing predicted values
      Throws:
      PredictiveModel.StateChangeException - is thrown when an input parameter changes that might affect the model estimates or predictions.
    • setMinCostComplexityValue

      public void setMinCostComplexityValue(double minCostComplexity)
      Sets the value of the minimum cost-complexity value.
      Parameters:
      minCostComplexity - a double indicating the smallest value to use in cost-complexity pruning. The value must be in [0.0, 1.0].

      Default: minCostComplexity = 0.

    • getMinCostComplexityValue

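      public double getMinCostComplexityValue()
      Returns the minimum cost-complexity value.
      Returns:
      a double, the smallest value used in cost-complexity pruning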
    • printDecisionTree

      public void printDecisionTree(boolean printMaxTree)
      Prints the contents of the decision tree using distinct but general labels.

      This method uses default values for the variable labels when printing (see printDecisionTree(String, String[], String[], String[], boolean) for these values).

      Parameters:
      printMaxTree - a boolean indicating that the maximal tree should be printed. When true, the maximal tree is printed. Otherwise, the pruned tree is printed.
    • printDecisionTree

      public void printDecisionTree(String responseName, String[] predictorNames, String[] classNames, String[] categoryNames, boolean printMaxTree)
      Prints the contents of the decision tree using labels.
      Parameters:
      responseName - a String specifying a name for the response variable

      If null, the default value is used.

      Default: responseName = Y

      predictorNames - a String array specifying names for the predictor variables

      If null, the default value is used.

      Default: predictorNames = X0, X1, ...

      classNames - a String array specifying names for the class levels

      If null, the default value is used.

      Default: classNames = 0, 1, 2, ...

      categoryNames - a String array specifying names for the categories of the predictor variables

      If null, the default value is used.

      Default: categoryNames = 0, 1, 2, ...

      printMaxTree - a boolean indicating that the maximal tree should be printed. When true, the maximal tree is printed. Otherwise, the pruned tree is printed.
    • setConfiguration

      protected void setConfiguration(PredictiveModel pm)
      Sets the configuration of PredictiveModel to that of the input model.
      Specified by:
      setConfiguration in class PredictiveModel
      Parameters:
      pm - a PredictiveModel object
    • pruneTree

      public void pruneTree(double gamma)
      Finds the minimum cost-complexity decision tree for the cost-complexity value, gamma.

      The method implements the optimal pruning algorithm 10.1 on page 294 of Breiman et al. (1984). The result of the algorithm is a sequence of sub-trees \(T_{\max}\succ T_1\succ T_2\succ\cdots\succ T_{M-1}\succ\{t_0\}\) obtained by pruning the fully generated tree, \(T_{\max}\), until the sub-tree consists of the single root node, \(\{t_0\}\). Corresponding to the sequence of sub-trees is the sequence of complexity values \(0\leq\delta_{\min}\lt\delta_1\lt\delta_2\lt\ldots\lt\delta_{M-1}\lt\delta_M\), where M is the number of steps the algorithm takes to reach the root node. Each sub-tree is the optimally pruned sub-tree for the corresponding complexity value.

      The effect of the pruning is stored in the tree's terminal node array. That is, when the algorithm determines that the tree should be pruned at a particular node, it sets that node to be a terminal node using the method Tree.setTerminalNode(int, boolean).

      No other changes are made to the tree structure so that the maximal tree can still be printed and reviewed. However, once a tree is pruned, all the predictions will use the pruned tree.

      Parameters:
      gamma - a double giving the value of the cost-complexity parameter
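Because the pruning sequence is nested, a given gamma selects a single sub-tree: under Breiman's result, \(T_k\) is optimal for every gamma with \(\delta_k\leq\gamma\lt\delta_{k+1}\). The helper below finds that k by binary search over a strictly increasing array of complexity values, such as one returned by getCostComplexityValues(). It is a hypothetical sketch; the mapping from array index to sub-tree in the IMSL classes is an assumption.

```java
/** Hypothetical helper: which sub-tree in the pruning sequence applies to gamma. */
class PrunedSubtreeIndex {

    /**
     * delta must be strictly increasing. Returns the largest k with delta[k] <= gamma,
     * or -1 when gamma is smaller than every value (no pruning step applies).
     */
    static int indexFor(double[] delta, double gamma) {
        int lo = 0, hi = delta.length - 1, k = -1;
        while (lo <= hi) {
            int mid = (lo + hi) >>> 1;
            if (delta[mid] <= gamma) {
                k = mid;      // candidate; look for a larger index
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return k;
    }

    public static void main(String[] args) {
        double[] delta = {0.0, 0.05, 0.2};
        System.out.println(indexFor(delta, 0.1)); // prints 1
    }
}
```

For delta = {0.0, 0.05, 0.2} and gamma = 0.1 the result is 1, the sub-tree that is optimal on [0.05, 0.2).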