imsl.data_mining.QUESTDecisionTree

class QUESTDecisionTree(response_col_idx, var_type, alpha=0.05, min_n_node=7, min_split=21, max_x_cats=10, max_size=100, max_depth=10, priors=None, response_name='Y', var_names=None, class_names=None, categ_names=None)

Generate a decision tree using the QUEST method.

Generate a decision tree for a single response variable and two or more predictor variables using the QUEST method.

Parameters:
  • response_col_idx (int) – Column index of the response variable.
  • var_type ((N,) array_like) –

    Array indicating the type of each variable.

    var_type[i]   Type
    0             Categorical
    1             Ordered Discrete (Low, Med., High)
    2             Quantitative or Continuous
    3             Ignore this variable
  • alpha (float, optional) –

    The significance level for split variable selection. Valid values are in the range 0 < alpha < 1.0.

    Default is 0.05.

  • min_n_node (int, optional) –

    Do not split a node if one of its child nodes will have fewer than min_n_node observations.

    Default is 7.

  • min_split (int, optional) –

    Do not split a node if the node has fewer than min_split observations.

    Default is 21.

  • max_x_cats (int, optional) –

    Allow for up to max_x_cats categories for categorical predictor variables.

    Default is 10.

  • max_size (int, optional) –

    Stop growing the tree once it has reached max_size number of nodes.

    Default is 100.

  • max_depth (int, optional) –

    Stop growing the tree once it has reached max_depth number of levels.

    Default is 10.

  • priors ((N,) array_like, optional) – An array containing prior probabilities for class membership. The argument is ignored for continuous response variables. By default, the prior probabilities are estimated from the data.
  • response_name (string, optional) –

    A string representing the name of the response variable.

    Default is “Y”.

  • var_names (tuple, optional) –

    A tuple containing strings representing the names of predictors.

    Default is “X0”, “X1”, etc.

  • class_names (tuple, optional) –

    A tuple containing strings representing the names of the different classes in Y, assuming Y is of categorical type.

    Default is “0”, “1”, etc.

  • categ_names (tuple, optional) –

    A tuple containing strings representing the names of the different category levels for each predictor of categorical type.

    Default is “0”, “1”, etc.

Notes

The QUEST algorithm ([1]) is appropriate for a categorical response variable and predictors of either categorical or quantitative type. For each categorical predictor, imsl.data_mining.QUESTDecisionTree() performs a multi-way chi-square test of association between the predictor and Y. For each continuous predictor, it performs an ANOVA test to see whether the mean of the predictor varies among the groups of Y. Among these tests, the variable with the most significant result is selected as a potential splitting variable, say \(X_j\). If its p-value (adjusted for multiple tests) is less than the specified splitting threshold, then \(X_j\) is the splitting variable for the current node. If not, imsl.data_mining.QUESTDecisionTree() performs Levene’s test of homogeneity on each continuous variable X to see whether the variance of X differs among the groups of Y. Among these tests, the predictor with the most significant result, say \(X_i\), is again selected. If its p-value (adjusted for multiple tests) is less than the splitting threshold, \(X_i\) is the splitting variable. Otherwise, the node is not split.
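The two-stage selection described above can be sketched with SciPy's standard tests. This is an illustrative sketch only, not the code used by imsl.data_mining.QUESTDecisionTree. The function name select_split_variable is hypothetical, the Bonferroni adjustment is an assumed choice of multiple-test correction, and every predictor whose type is not 0 is treated as continuous.

```python
import numpy as np
from scipy import stats


def select_split_variable(x, y, var_type, alpha=0.05):
    """Return the column index of the chosen predictor, or None.

    x        : (n, p) array of predictors
    y        : (n,) array of class labels
    var_type : length-p sequence; 0 = categorical, anything else is
               treated as continuous (a simplification of this sketch)
    """
    classes = np.unique(y)
    n_preds = x.shape[1]

    # Stage 1: chi-square test for categorical predictors,
    # one-way ANOVA F test for continuous predictors.
    p_values = np.ones(n_preds)
    for j in range(n_preds):
        if var_type[j] == 0:
            levels = np.unique(x[:, j])
            table = np.array([[np.sum((x[:, j] == lev) & (y == c))
                               for c in classes] for lev in levels])
            p_values[j] = stats.chi2_contingency(table, correction=False)[1]
        else:
            groups = [x[y == c, j] for c in classes]
            p_values[j] = stats.f_oneway(*groups)[1]
    best = int(np.argmin(p_values))
    # Bonferroni adjustment: an assumption of this sketch.
    if min(1.0, p_values[best] * n_preds) < alpha:
        return best

    # Stage 2: Levene's test of equal variances, continuous predictors only.
    cont = [j for j in range(n_preds) if var_type[j] != 0]
    if not cont:
        return None
    lev_p = np.array([stats.levene(*[x[y == c, j] for c in classes],
                                   center="mean")[1] for j in cont])
    k = int(np.argmin(lev_p))
    if min(1.0, lev_p[k] * len(cont)) < alpha:
        return cont[k]
    return None  # no significant predictor: the node is not split


# Made-up data: column 0 is categorical and unrelated to y,
# column 1 is continuous and clearly shifted between the two classes.
y = np.repeat([0, 1], 30)
x = np.column_stack([np.tile([0.0, 1.0], 30),
                     np.concatenate([np.linspace(0.0, 1.0, 30),
                                     np.linspace(5.0, 6.0, 30)])])
chosen = select_split_variable(x, y, var_type=[0, 2])
print(chosen)
```

The second stage runs only when no first-stage test is significant, which lets a predictor whose class means agree but whose class variances differ still be chosen.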

Assuming a splitting variable is found, the next step is to determine how the variable should be split. If the selected variable \(X_j\) is continuous, a split point d is determined by quadratic discriminant analysis (QDA) of \(X_j\) into two populations determined by a binary partition of the response Y. The goal of this step is to group the classes of Y into two subsets or super classes, A and B. If there are only two classes in the response Y, the super classes are obvious. Otherwise, calculate the means and variances of \(X_j\) in each of the classes of Y. If the means are all equal, put the largest-sized class into group A and combine the rest to form group B. If they are not all equal, use a k-means clustering method (k = 2) on the class means to determine A and B.
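The grouping of the classes of Y into super classes A and B can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's implementation. The function name form_superclasses is hypothetical, the 2-means iteration is started from the smallest and largest class means, and "all equal" is checked with a floating-point tolerance.

```python
import numpy as np


def form_superclasses(xj, y):
    """Split the classes of y into two groups A and B using predictor xj."""
    classes, counts = np.unique(y, return_counts=True)
    if classes.size == 2:
        # Two classes: each class is its own super class.
        return [classes[0]], [classes[1]]

    means = np.array([xj[y == c].mean() for c in classes])
    if np.allclose(means, means[0]):
        # Equal means: the largest class forms A, all others form B.
        biggest = classes[int(np.argmax(counts))]
        return [biggest], [c for c in classes if c != biggest]

    # 2-means clustering of the class means (one-dimensional),
    # started from the two extreme means (an assumption of this sketch).
    centers = np.array([means.min(), means.max()])
    for _ in range(100):
        labels = np.abs(means[:, None] - centers[None, :]).argmin(axis=1)
        new_centers = np.array([means[labels == k].mean() for k in (0, 1)])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    group_a = [c for c, lab in zip(classes, labels) if lab == 0]
    group_b = [c for c, lab in zip(classes, labels) if lab == 1]
    return group_a, group_b


# Made-up data: classes 0 and 1 have nearby means, class 2 is far away.
xj = np.array([0.9, 1.1, 1.1, 1.3, 4.9, 5.1])
y = np.array([0, 0, 1, 1, 2, 2])
group_a, group_b = form_superclasses(xj, y)
print(group_a, group_b)
```

Starting from the extreme means keeps both clusters non-empty whenever the class means are not all equal.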

\(X_j\) in A and in B is assumed to be normally distributed with estimated means \(\bar{x}_{j|A}\), \(\bar{x}_{j|B}\), and variances \(S^2_{j|A}\), \(S^2_{j|B}\), respectively. The quadratic discriminant is the partition \(X_j\le d\) and \(X_j\gt d\) such that \(\mbox{Pr}\left(X_j,A\right)=\mbox{Pr}\left(X_j,B\right)\). The discriminant rule assigns an observation to A if \(x_{ij}\le d\) and to B if \(x_{ij}\gt d\). For d to maximally discriminate, the probabilities must be equal.
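Setting the two prior-weighted normal densities equal and taking logarithms gives a quadratic equation in d, which the sketch below solves. It is an illustration under stated assumptions rather than the library's code. The function name qda_split_point is hypothetical, the priors are estimated from the group sizes, both sample variances are assumed positive, and the rules for picking a root (the one nearer the mean of A) and for falling back to the midpoint are assumptions of this sketch.

```python
import numpy as np


def qda_split_point(x_a, x_b):
    """Return a split point d for one predictor observed in groups A and B."""
    n_a, n_b = len(x_a), len(x_b)
    m_a, m_b = np.mean(x_a), np.mean(x_b)
    v_a, v_b = np.var(x_a, ddof=1), np.var(x_b, ddof=1)

    # Equating n_a * N(d; m_a, v_a) and n_b * N(d; m_b, v_b) and taking
    # logs yields a*d**2 + b*d + c = 0 with the coefficients below.
    log_term = np.log((n_a / n_b) * np.sqrt(v_b / v_a))
    a = v_a - v_b
    b = 2.0 * (m_a * v_b - m_b * v_a)
    c = m_b**2 * v_a - m_a**2 * v_b + 2.0 * v_a * v_b * log_term

    if np.isclose(a, 0.0):
        # Equal variances: the equation is linear (requires m_a != m_b).
        return -c / b
    disc = b * b - 4.0 * a * c
    if disc < 0:
        # No real root: fall back to the midpoint (assumed rule).
        return 0.5 * (m_a + m_b)
    roots = (-b + np.array([-1.0, 1.0]) * np.sqrt(disc)) / (2.0 * a)
    # Two real roots: keep the one nearer the mean of A (assumed rule).
    return roots[np.argmin(np.abs(roots - m_a))]


# Equal sizes and equal variances: the split falls midway between the means.
d_equal = qda_split_point(np.array([1.0, 2.0, 3.0]),
                          np.array([7.0, 8.0, 9.0]))
print(d_equal)
```

With equal group sizes and equal variances the logarithmic term vanishes, so the rule reduces to the midpoint of the two means.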

If the selected variable \(X_j\) is categorical, it is first transformed using the method outlined in Loh and Shih ([1]), and then QDA is performed as above. The transformation is related to the discriminant coordinate (CRIMCOORD) approach due to Gnanadesikan ([2]).

References

[1] Loh, W.-Y. and Shih, Y.-S. (1997). Split Selection Methods for Classification Trees, Statistica Sinica, 7, 815-840. For information on the latest version of QUEST see: http://www.stat.wisc.edu/~loh/quest.html.
[2] Gnanadesikan, R. (1977). Methods for Statistical Data Analysis of Multivariate Observations. Wiley, New York.

Examples

This example applies the QUEST method to a simulated dataset with 50 cases and three predictors of mixed type. A maximally grown tree under the default controls and the optimally pruned sub-tree obtained from cross-validation and minimal cost complexity pruning are produced. Notice that the optimally pruned tree consists of just the root node, whereas the maximal tree has five nodes and three levels.

>>> import numpy as np
>>> import imsl.data_mining as dm
>>> xy = np.array([[2.0, 25.928690, 0.0, 0.0],
...                [1.0, 51.632450, 1.0, 1.0],
...                [1.0, 25.784321, 0.0, 2.0],
...                [0.0, 39.379478, 0.0, 3.0],
...                [2.0, 24.650579, 0.0, 2.0],
...                [2.0, 45.200840, 0.0, 2.0],
...                [2.0, 52.679600, 1.0, 3.0],
...                [1.0, 44.283421, 1.0, 3.0],
...                [2.0, 40.635231, 1.0, 3.0],
...                [2.0, 51.760941, 0.0, 3.0],
...                [2.0, 26.303680, 0.0, 1.0],
...                [2.0, 20.702299, 1.0, 0.0],
...                [2.0, 38.742729, 1.0, 3.0],
...                [2.0, 19.473330, 0.0, 0.0],
...                [1.0, 26.422110, 0.0, 0.0],
...                [2.0, 37.059860, 1.0, 0.0],
...                [1.0, 51.670429, 1.0, 3.0],
...                [0.0, 42.401562, 0.0, 3.0],
...                [2.0, 33.900269, 1.0, 2.0],
...                [1.0, 35.432819, 0.0, 0.0],
...                [1.0, 44.303692, 0.0, 1.0],
...                [0.0, 46.723869, 0.0, 2.0],
...                [1.0, 46.992619, 0.0, 2.0],
...                [0.0, 36.059231, 0.0, 3.0],
...                [2.0, 36.831970, 1.0, 1.0],
...                [1.0, 61.662571, 1.0, 2.0],
...                [0.0, 25.677139, 0.0, 3.0],
...                [1.0, 39.085670, 1.0, 0.0],
...                [0.0, 48.843410, 1.0, 1.0],
...                [1.0, 39.343910, 0.0, 3.0],
...                [2.0, 24.735220, 0.0, 2.0],
...                [1.0, 50.552509, 1.0, 3.0],
...                [0.0, 31.342630, 1.0, 3.0],
...                [1.0, 27.157949, 1.0, 0.0],
...                [0.0, 31.726851, 0.0, 2.0],
...                [0.0, 25.004080, 0.0, 3.0],
...                [1.0, 26.354570, 1.0, 3.0],
...                [2.0, 38.123428, 0.0, 1.0],
...                [0.0, 49.940300, 0.0, 2.0],
...                [1.0, 42.457790, 1.0, 3.0],
...                [0.0, 38.809479, 1.0, 1.0],
...                [0.0, 43.227989, 1.0, 1.0],
...                [0.0, 41.876240, 0.0, 3.0],
...                [2.0, 48.078201, 0.0, 2.0],
...                [0.0, 43.236729, 1.0, 0.0],
...                [2.0, 39.412941, 0.0, 3.0],
...                [1.0, 23.933460, 0.0, 2.0],
...                [2.0, 42.841301, 1.0, 3.0],
...                [2.0, 30.406691, 0.0, 1.0],
...                [0.0, 37.773891, 0.0, 2.0]])
>>> response_column_index = 3
>>> var_type = np.array([0, 2, 0, 0], dtype=int)
>>> with dm.QUESTDecisionTree(response_column_index,
...                           var_type) as decision_tree:
...     decision_tree.train(xy)
...     print(decision_tree)
... 
Decision Tree:
Node 0: Cost = 0.620, N = 50.0, Level = 0, Child Nodes:  1  2
P(Y=0) = 0.180
P(Y=1) = 0.180
P(Y=2) = 0.260
P(Y=3) = 0.380
Predicted Y: 3
   Node 1: Cost = 0.220, N = 17.0, Level = 1
   Rule: X1 <= 35.031
   P(Y=0) = 0.294
   P(Y=1) = 0.118
   P(Y=2) = 0.353
   P(Y=3) = 0.235
   Predicted Y: 2
   Node 2: Cost = 0.360, N = 33.0, Level = 1, Child Nodes:  3  4
   Rule: X1 > 35.031
   P(Y=0) = 0.121
   P(Y=1) = 0.212
   P(Y=2) = 0.212
   P(Y=3) = 0.455
   Predicted Y: 3
      Node 3: Cost = 0.180, N = 19.0, Level = 2
      Rule: X1 <= 43.265
      P(Y=0) = 0.211
      P(Y=1) = 0.211
      P(Y=2) = 0.053
      P(Y=3) = 0.526
      Predicted Y: 3
      Node 4: Cost = 0.160, N = 14.0, Level = 2
      Rule: X1 > 43.265
      P(Y=0) = 0.000
      P(Y=1) = 0.214
      P(Y=2) = 0.429
      P(Y=3) = 0.357
      Predicted Y: 2

This example uses the Kyphosis dataset. The 81 cases represent 81 children who have undergone surgery to correct a type of spinal deformity known as kyphosis. The response variable is the presence or absence of kyphosis after the surgery. The three predictors are: Age, the age of the patient in months; Start, the vertebra number where the surgery started; and Number, the number of vertebrae involved in the surgery. This example uses the QUEST method to produce a maximal tree. It also requests predictions for a test dataset consisting of 10 “new” cases.

>>> import numpy as np
>>> import imsl.data_mining as dm
>>> xy = np.array([[0.0, 71.0, 3.0, 5.0],
...                [0.0, 158.0, 3.0, 14.0],
...                [1.0, 128.0, 4.0, 5.0],
...                [0.0, 2.0, 5.0, 1.0],
...                [0.0, 1.0, 4.0, 15.0],
...                [0.0, 1.0, 2.0, 16.0],
...                [0.0, 61.0, 2.0, 17.0],
...                [0.0, 37.0, 3.0, 16.0],
...                [0.0, 113.0, 2.0, 16.0],
...                [1.0, 59.0, 6.0, 12.0],
...                [1.0, 82.0, 5.0, 14.0],
...                [0.0, 148.0, 3.0, 16.0],
...                [0.0, 18.0, 5.0, 2.0],
...                [0.0, 1.0, 4.0, 12.0],
...                [0.0, 168.0, 3.0, 18.0],
...                [0.0, 1.0, 3.0, 16.0],
...                [0.0, 78.0, 6.0, 15.0],
...                [0.0, 175.0, 5.0, 13.0],
...                [0.0, 80.0, 5.0, 16.0],
...                [0.0, 27.0, 4.0, 9.0],
...                [0.0, 22.0, 2.0, 16.0],
...                [1.0, 105.0, 6.0, 5.0],
...                [1.0, 96.0, 3.0, 12.0],
...                [0.0, 131.0, 2.0, 3.0],
...                [1.0, 15.0, 7.0, 2.0],
...                [0.0, 9.0, 5.0, 13.0],
...                [0.0, 8.0, 3.0, 6.0],
...                [0.0, 100.0, 3.0, 14.0],
...                [0.0, 4.0, 3.0, 16.0],
...                [0.0, 151.0, 2.0, 16.0],
...                [0.0, 31.0, 3.0, 16.0],
...                [0.0, 125.0, 2.0, 11.0],
...                [0.0, 130.0, 5.0, 13.0],
...                [0.0, 112.0, 3.0, 16.0],
...                [0.0, 140.0, 5.0, 11.0],
...                [0.0, 93.0, 3.0, 16.0],
...                [0.0, 1.0, 3.0, 9.0],
...                [1.0, 52.0, 5.0, 6.0],
...                [0.0, 20.0, 6.0, 9.0],
...                [1.0, 91.0, 5.0, 12.0],
...                [1.0, 73.0, 5.0, 1.0],
...                [0.0, 35.0, 3.0, 13.0],
...                [0.0, 143.0, 9.0, 3.0],
...                [0.0, 61.0, 4.0, 1.0],
...                [0.0, 97.0, 3.0, 16.0],
...                [1.0, 139.0, 3.0, 10.0],
...                [0.0, 136.0, 4.0, 15.0],
...                [0.0, 131.0, 5.0, 13.0],
...                [1.0, 121.0, 3.0, 3.0],
...                [0.0, 177.0, 2.0, 14.0],
...                [0.0, 68.0, 5.0, 10.0],
...                [0.0, 9.0, 2.0, 17.0],
...                [1.0, 139.0, 10.0, 6.0],
...                [0.0, 2.0, 2.0, 17.0],
...                [0.0, 140.0, 4.0, 15.0],
...                [0.0, 72.0, 5.0, 15.0],
...                [0.0, 2.0, 3.0, 13.0],
...                [1.0, 120.0, 5.0, 8.0],
...                [0.0, 51.0, 7.0, 9.0],
...                [0.0, 102.0, 3.0, 13.0],
...                [1.0, 130.0, 4.0, 1.0],
...                [1.0, 114.0, 7.0, 8.0],
...                [0.0, 81.0, 4.0, 1.0],
...                [0.0, 118.0, 3.0, 16.0],
...                [0.0, 118.0, 4.0, 16.0],
...                [0.0, 17.0, 4.0, 10.0],
...                [0.0, 195.0, 2.0, 17.0],
...                [0.0, 159.0, 4.0, 13.0],
...                [0.0, 18.0, 4.0, 11.0],
...                [0.0, 15.0, 5.0, 16.0],
...                [0.0, 158.0, 5.0, 14.0],
...                [0.0, 127.0, 4.0, 12.0],
...                [0.0, 87.0, 4.0, 16.0],
...                [0.0, 206.0, 4.0, 10.0],
...                [0.0, 11.0, 3.0, 15.0],
...                [0.0, 178.0, 4.0, 15.0],
...                [1.0, 157.0, 3.0, 13.0],
...                [0.0, 26.0, 7.0, 13.0],
...                [0.0, 120.0, 2.0, 13.0],
...                [1.0, 42.0, 7.0, 6.0],
...                [0.0, 36.0, 4.0, 13.0]])
>>> xy_test = np.array([[0.0, 71.0, 3.0, 5.0],
...                     [1.0, 128.0, 4.0, 5.0],
...                     [0.0, 1.0, 4.0, 15.0],
...                     [0.0, 61.0, 6.0, 10.0],
...                     [0.0, 113.0, 2.0, 16.0],
...                     [1.0, 82.0, 5.0, 14.0],
...                     [0.0, 148.0, 3.0, 16.0],
...                     [0.0, 1.0, 4.0, 12.0],
...                     [0.0, 1.0, 3.0, 16.0],
...                     [0.0, 175.0, 5.0, 13.0]])
>>> response_column_index = 0
>>> var_type = np.array([0, 2, 2, 2], dtype=int)
>>> names = ["Age", "Number", "Start"]
>>> class_names = ["Absent", "Present"]
>>> response_name = "Kyphosis"
>>> with dm.QUESTDecisionTree(response_column_index, var_type,
...                           min_n_node=5, min_split=10, max_x_cats=10,
...                           max_size=50, max_depth=10,
...                           response_name=response_name, var_names=names,
...                           class_names=class_names) as decision_tree:
...    decision_tree.train(xy)
...    predictions = decision_tree.predict(xy_test)
...    print(decision_tree)
... 
Decision Tree:
Node 0: Cost = 0.210, N = 81.0, Level = 0, Child Nodes:  1  4
P(Y=0) = 0.790
P(Y=1) = 0.210
Predicted Kyphosis: Absent
   Node 1: Cost = 0.074, N = 13.0, Level = 1, Child Nodes:  2  3
   Rule: Start <= 5.155
   P(Y=0) = 0.538
   P(Y=1) = 0.462
   Predicted Kyphosis: Absent
      Node 2: Cost = 0.025, N = 7.0, Level = 2
      Rule: Age <= 84.030
      P(Y=0) = 0.714
      P(Y=1) = 0.286
      Predicted Kyphosis: Absent
      Node 3: Cost = 0.025, N = 6.0, Level = 2
      Rule: Age > 84.030
      P(Y=0) = 0.333
      P(Y=1) = 0.667
      Predicted Kyphosis: Present
   Node 4: Cost = 0.136, N = 68.0, Level = 1, Child Nodes:  5  6
   Rule: Start > 5.155
   P(Y=0) = 0.838
   P(Y=1) = 0.162
   Predicted Kyphosis: Absent
      Node 5: Cost = 0.012, N = 6.0, Level = 2
      Rule: Start <= 8.862
      P(Y=0) = 0.167
      P(Y=1) = 0.833
      Predicted Kyphosis: Present
      Node 6: Cost = 0.074, N = 62.0, Level = 2, Child Nodes:  7  12
      Rule: Start > 8.862
      P(Y=0) = 0.903
      P(Y=1) = 0.097
      Predicted Kyphosis: Absent
         Node 7: Cost = 0.062, N = 28.0, Level = 3, Child Nodes:  8  9
         Rule: Start <= 13.092
         P(Y=0) = 0.821
         P(Y=1) = 0.179
         Predicted Kyphosis: Absent
            Node 8: Cost = 0.025, N = 15.0, Level = 4
            Rule: Age <= 91.722
            P(Y=0) = 0.867
            P(Y=1) = 0.133
            Predicted Kyphosis: Absent
            Node 9: Cost = 0.037, N = 13.0, Level = 4, Child Nodes:  10  11
            Rule: Age > 91.722
            P(Y=0) = 0.769
            P(Y=1) = 0.231
            Predicted Kyphosis: Absent
               Node 10: Cost = 0.037, N = 6.0, Level = 5
               Rule: Number <= 3.450
               P(Y=0) = 0.500
               P(Y=1) = 0.500
               Predicted Kyphosis: Absent
               Node 11: Cost = 0.000, N = 7.0, Level = 5
               Rule: Number > 3.450
               P(Y=0) = 1.000
               P(Y=1) = 0.000
               Predicted Kyphosis: Absent
         Node 12: Cost = 0.012, N = 34.0, Level = 3, Child Nodes:  13  14
         Rule: Start > 13.092
         P(Y=0) = 0.971
         P(Y=1) = 0.029
         Predicted Kyphosis: Absent
            Node 13: Cost = 0.012, N = 5.0, Level = 4
            Rule: Start <= 14.864
            P(Y=0) = 0.800
            P(Y=1) = 0.200
            Predicted Kyphosis: Absent
            Node 14: Cost = 0.000, N = 29.0, Level = 4
            Rule: Start > 14.864
            P(Y=0) = 1.000
            P(Y=1) = 0.000
            Predicted Kyphosis: Absent
>>> print("\nPredictions for test data:\n")
... 
Predictions for test data:
>>> print("  {:5s} {:8s} {:7s} {:8s}".format(names[0], names[1], names[2],
...                                           response_name))
  Age   Number   Start   Kyphosis
>>> n_rows = xy_test.shape[0]
>>> for i in range(n_rows):
...     idx = int(predictions.predictions[i])
...     print("{:5.0f} {:8.0f} {:7.0f}   {}".format(xy_test[i, 1],
...                                                 xy_test[i, 2],
...                                                 xy_test[i, 3],
...                                                 class_names[idx]))
   71        3       5   Absent
  128        4       5   Present
    1        4      15   Absent
   61        6      10   Absent
  113        2      16   Absent
   82        5      14   Absent
  148        3      16   Absent
    1        4      12   Absent
    1        3      16   Absent
  175        5      13   Absent
>>> print("\nMean squared prediction error: {}".format(
...     predictions.pred_err_ss))
... 
Mean squared prediction error: 0.1

Methods

predict(data[, weights])           Compute predicted values using a decision tree.
train(training_data[, weights])    Train a decision tree using training data and weights.

Attributes

categ_names      Return names of category levels for each categorical predictor.
class_names      Return names of different classes in Y.
n_classes        Return number of classes assumed by response variable.
n_levels         Return number of levels or depth of tree.
n_nodes          Return number of nodes or size of tree.
n_preds          Return number of predictors used in the model.
pred_n_values    Return number of values of predictor variables.
pred_type        Return types of predictor variables.
response_name    Return name of the response variable.
response_type    Return type of the response variable.
var_names        Return names of the predictors.