public class RandomTrees extends PredictiveModel implements Serializable, Cloneable
A random forest is an ensemble of decision trees. As in bootstrap aggregation (bagging), a tree is fit to each of M bootstrap samples drawn from the training data, and each tree is then used to generate predictions. For a regression problem (continuous response variable), the M predictions are combined into a single predicted value by averaging. For classification (categorical response variable), majority vote is used. A random forest also randomizes the predictors: in every tree, the splitting variable at every node is selected from a random subset of the predictors. Randomizing the predictors reduces correlation among the individual trees. The random forest was introduced by Leo Breiman (Breiman, 2001). Random Forests™ is the trademark term for this approach. See also Hastie, Tibshirani, and Friedman (2008) for further discussion.
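The two combination rules described above can be illustrated with a minimal, self-contained sketch. The class `EnsembleCombiner` and its methods are hypothetical helpers written for this illustration; they are not part of the `RandomTrees` API, and the tie-breaking rule in the vote is an assumption made here, not documented behavior.

```java
import java.util.HashMap;
import java.util.Map;

// Illustrative helper, not part of the RandomTrees API.
class EnsembleCombiner {

    // Regression: average the M per-tree predictions for one example.
    // Assumes a non-empty array.
    static double average(double[] treePredictions) {
        double sum = 0.0;
        for (double p : treePredictions) {
            sum += p;
        }
        return sum / treePredictions.length;
    }

    // Classification: return the class label chosen by the most trees.
    // Ties go to the smallest label, an arbitrary choice made for this sketch.
    static int majorityVote(int[] treePredictions) {
        Map<Integer, Integer> counts = new HashMap<>();
        for (int label : treePredictions) {
            counts.merge(label, 1, Integer::sum);
        }
        int bestLabel = Integer.MAX_VALUE;
        int bestCount = 0;
        for (Map.Entry<Integer, Integer> e : counts.entrySet()) {
            int label = e.getKey();
            int count = e.getValue();
            if (count > bestCount || (count == bestCount && label < bestLabel)) {
                bestLabel = label;
                bestCount = count;
            }
        }
        return bestLabel;
    }

    public static void main(String[] args) {
        System.out.println(average(new double[] {2.0, 3.0, 4.0}));   // prints 3.0
        System.out.println(majorityVote(new int[] {1, 0, 1, 2, 1})); // prints 1
    }
}
```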
Modifier and Type | Class and Description |
---|---|
static class | RandomTrees.ReflectiveOperationException Class that wraps exceptions thrown by reflective operations in core reflection. |
Nested classes inherited from class PredictiveModel: PredictiveModel.CloneNotSupportedException, PredictiveModel.PredictiveModelException, PredictiveModel.StateChangeException, PredictiveModel.SumOfProbabilitiesNotOneException, PredictiveModel.VariableType
Constructor and Description |
---|
RandomTrees(DecisionTree dt) Constructs a RandomTrees random forest of the input decision tree. |
RandomTrees(double[][] xy, int responseColumnIndex, PredictiveModel.VariableType[] varType) Constructs a RandomTrees random forest of ALACART decision trees. |
RandomTrees(RandomTrees rtModel) Constructs a copy of the input RandomTrees predictive model. |
Modifier and Type | Method and Description |
---|---|
RandomTrees | clone() Clones a RandomTrees predictive model. |
void | fitModel() Fits the random forest to the training data. |
int | getNumberOfRandomFeatures() Returns the number of random features used in the splitting rules. |
int | getNumberOfTrees() Returns the number of trees. |
double | getOutOfBagPredictionError() Returns the out-of-bag prediction error. |
double[] | getOutOfBagPredictions() Returns the out-of-bag predicted values for the examples in the training data. |
double[] | getVariableImportance() Returns the variable importance measure based on the out-of-bag prediction error. |
boolean | isCalculateVariableImportance() Returns the current setting of the boolean to calculate variable importance. |
double[] | predict() Returns the predicted values generated by the random forest on the training data. |
double[] | predict(double[][] testData) Returns the predicted values on the input test data. |
double[] | predict(double[][] testData, double[] testDataWeights) Returns the predicted values on the input test data and the test data weights. |
void | setCalculateVariableImportance(boolean calculate) Sets the boolean to calculate variable importance. |
protected void | setConfiguration(PredictiveModel pm) Sets the configuration of RandomTrees to that of the input model. |
void | setNumberOfRandomFeatures(int numberOfRandomFeatures) Sets the number of random features used in the splitting rules. |
void | setNumberOfThreads(int numberOfThreads) Sets the maximum number of java.lang.Thread instances that may be used for parallel processing. |
void | setNumberOfTrees(int numberOfTrees) Sets the number of trees to generate in the random forest. |
Methods inherited from class PredictiveModel: getClassCounts, getClassErrors, getClassLabels, getClassProbabilities, getCostMatrix, getMaxNumberOfCategories, getMaxNumberOfIterations, getNumberOfClasses, getNumberOfColumns, getNumberOfMissing, getNumberOfPredictors, getNumberOfRows, getNumberOfUniquePredictorValues, getPredictorIndexes, getPredictorTypes, getPrintLevel, getPriorProbabilities, getRandomObject, getResponseColumnIndex, getResponseVariableAverage, getResponseVariableMostFrequentClass, getResponseVariableType, getTotalWeight, getVariableType, getWeights, getXY, isConstantSeries, isMustFitModel, isUserFixedNClasses, setClassCounts, setClassLabels, setClassProbabilities, setCostMatrix, setMaxNumberOfCategories, setMaxNumberOfIterations, setMustFitModel, setNumberOfClasses, setPredictorIndex, setPredictorTypes, setPrintLevel, setPriorProbabilities, setRandomObject, setResponseColumnIndex, setTrainingData, setVariableType, setWeights
public RandomTrees(double[][] xy, int responseColumnIndex, PredictiveModel.VariableType[] varType)
Constructs a RandomTrees random forest of ALACART decision trees.
Parameters:
xy - a double matrix containing the training data
responseColumnIndex - an int, the column index for the response variable
varType - a PredictiveModel.VariableType array containing the type of each variable
public RandomTrees(DecisionTree dt)
Constructs a RandomTrees random forest of the input decision tree.
Parameters:
dt - a DecisionTree object
public RandomTrees(RandomTrees rtModel)
Constructs a copy of the input RandomTrees predictive model.
Parameters:
rtModel - a RandomTrees predictive model
public RandomTrees clone()
Clones a RandomTrees predictive model.
clone in class PredictiveModel
Returns:
the cloned RandomTrees predictive model
public void setNumberOfTrees(int numberOfTrees)
Sets the number of trees to generate in the random forest. The number of trees is equivalent to the number of bootstrap samples.
Parameters:
numberOfTrees - an int, the number of trees to generate
Default: numberOfTrees = 50
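Because each tree is fit to its own bootstrap sample, the number of trees fixes how many samples are drawn. The sketch below shows one way to draw such samples and to identify the out-of-bag rows for each. `BootstrapSampler` is a hypothetical class written for illustration; it does not reflect how `RandomTrees` draws its samples internally.

```java
import java.util.Arrays;
import java.util.Random;

// Illustrative helper, not part of the RandomTrees API.
class BootstrapSampler {

    // Draws n row indices with replacement from 0..n-1 (one bootstrap sample).
    static int[] bootstrapIndices(int n, Random rng) {
        int[] idx = new int[n];
        for (int i = 0; i < n; i++) {
            idx[i] = rng.nextInt(n);
        }
        return idx;
    }

    // Flags the rows that never appear in the sample; these are out-of-bag.
    static boolean[] outOfBag(int n, int[] sample) {
        boolean[] oob = new boolean[n];
        Arrays.fill(oob, true);
        for (int i : sample) {
            oob[i] = false;
        }
        return oob;
    }

    public static void main(String[] args) {
        int numberOfTrees = 50; // the documented default
        int nRows = 10;         // hypothetical training set size
        Random rng = new Random(123);
        for (int m = 0; m < numberOfTrees; m++) {
            int[] sample = bootstrapIndices(nRows, rng);
            boolean[] oob = outOfBag(nRows, sample);
            int oobCount = 0;
            for (boolean b : oob) {
                if (b) {
                    oobCount++;
                }
            }
            System.out.println("tree " + m + ": " + oobCount + " out-of-bag rows");
        }
    }
}
```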
public void setNumberOfRandomFeatures(int numberOfRandomFeatures)
Sets the number of random features used in the splitting rules.
Parameters:
numberOfRandomFeatures - an int, the number of predictors in the random subset
Default: numberOfRandomFeatures = \(\sqrt{p}\) for classification problems, \(\frac{p}{3}\) for regression problems, where \(p\) is the number of predictors in the training data.
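As a worked illustration of these defaults, the sketch below computes both values for a given number of predictors. The documentation does not say how non-integer results are rounded, so rounding down with a minimum of 1 is an assumption made here; `DefaultRandomFeatures` is a hypothetical helper, not a library class.

```java
// Illustrative helper, not part of the RandomTrees API.
class DefaultRandomFeatures {

    // Classification default: sqrt(p), rounded down (assumed), at least 1.
    static int forClassification(int p) {
        return Math.max(1, (int) Math.floor(Math.sqrt(p)));
    }

    // Regression default: p / 3, rounded down (assumed), at least 1.
    static int forRegression(int p) {
        return Math.max(1, p / 3);
    }

    public static void main(String[] args) {
        int p = 16; // hypothetical number of predictors
        System.out.println(forClassification(p)); // prints 4
        System.out.println(forRegression(p));     // prints 5
    }
}
```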
public int getNumberOfRandomFeatures()
Returns the number of random features used in the splitting rules.
Returns:
an int, the number of random features
public void setCalculateVariableImportance(boolean calculate)
Sets the boolean to calculate variable importance. When true, a permutation type variable importance measure is calculated during bootstrap aggregation.
Parameters:
calculate - a boolean indicating whether or not to calculate variable importance
Default: calculate = false
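A permutation type importance measure can be sketched as follows for a single fitted model and a set of held-out rows: shuffle one predictor column, then compare the prediction error before and after. This is a minimal illustration under stated assumptions, not the library's implementation. It uses mean squared error as the accuracy measure, stands in for a fitted tree with a plain function, and omits the averaging over trees; the class `PermutationImportance` and its method names are hypothetical.

```java
import java.util.Random;
import java.util.function.ToDoubleFunction;

// Illustrative helper, not part of the RandomTrees API.
class PermutationImportance {

    // Mean squared error of the model on rows whose response is in responseCol.
    static double mse(ToDoubleFunction<double[]> model, double[][] rows, int responseCol) {
        double sum = 0.0;
        for (double[] row : rows) {
            double diff = row[responseCol] - model.applyAsDouble(row);
            sum += diff * diff;
        }
        return sum / rows.length;
    }

    // Increase in MSE after randomly permuting the values of column col.
    static double importance(ToDoubleFunction<double[]> model, double[][] oobRows,
                             int responseCol, int col, Random rng) {
        double before = mse(model, oobRows, responseCol);
        // Work on a copy so the caller's data is left untouched.
        double[][] permuted = new double[oobRows.length][];
        for (int i = 0; i < oobRows.length; i++) {
            permuted[i] = oobRows[i].clone();
        }
        // Fisher-Yates shuffle applied to the single column col.
        for (int i = permuted.length - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1);
            double tmp = permuted[i][col];
            permuted[i][col] = permuted[j][col];
            permuted[j][col] = tmp;
        }
        double after = mse(model, permuted, responseCol);
        return after - before;
    }

    public static void main(String[] args) {
        // Column 0 is the response; it depends on column 1 only.
        double[][] rows = {
            {2.0, 1.0, 5.0}, {4.0, 2.0, 6.0}, {6.0, 3.0, 7.0}, {8.0, 4.0, 8.0}
        };
        ToDoubleFunction<double[]> model = r -> 2.0 * r[1]; // stand-in for a fitted tree
        Random rng = new Random(1);
        System.out.println("column 1: " + importance(model, rows, 0, 1, rng));
        System.out.println("column 2: " + importance(model, rows, 0, 2, rng));
    }
}
```

In this toy setup the unused predictor (column 2) has importance exactly 0, since permuting it cannot change the predictions.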
public boolean isCalculateVariableImportance()
Returns the current setting of the boolean to calculate variable importance.
Returns:
a boolean, the current setting of the flag
public int getNumberOfTrees()
Returns the number of trees.
Returns:
an int, the number of trees
public void setNumberOfThreads(int numberOfThreads)
Sets the maximum number of java.lang.Thread instances that may be used for parallel processing.
Parameters:
numberOfThreads - an int specifying the maximum number of java.lang.Thread instances that may be used for parallel processing.
The actual number of threads used in parallel processing will be the lesser of numberOfThreads and numberOfTrees, the number of trees in the random forest. This limit is applied to optimize the use of resources.
Default: numberOfThreads = 1.
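The thread limit described above amounts to taking the smaller of the two settings. The sketch below shows that rule together with one plausible way to size a thread pool from it, using only standard `java.util.concurrent` classes. `ThreadBudget` is a hypothetical class, and the per-tree task is a placeholder; how `RandomTrees` actually schedules its work is not documented here.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Illustrative helper, not part of the RandomTrees API.
class ThreadBudget {

    // The lesser of the requested thread count and the number of trees.
    static int effectiveThreads(int numberOfThreads, int numberOfTrees) {
        return Math.min(numberOfThreads, numberOfTrees);
    }

    public static void main(String[] args) throws Exception {
        int numberOfTrees = 3;   // hypothetical settings
        int numberOfThreads = 8;
        int nThreads = effectiveThreads(numberOfThreads, numberOfTrees);
        System.out.println("threads used: " + nThreads); // prints threads used: 3

        ExecutorService pool = Executors.newFixedThreadPool(nThreads);
        try {
            List<Future<Integer>> results = new ArrayList<>();
            for (int t = 0; t < numberOfTrees; t++) {
                final int treeId = t;
                // Placeholder task standing in for fitting one tree.
                Callable<Integer> task = () -> treeId;
                results.add(pool.submit(task));
            }
            for (Future<Integer> f : results) {
                System.out.println("tree " + f.get() + " done");
            }
        } finally {
            pool.shutdown();
        }
    }
}
```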
public void fitModel() throws PredictiveModel.PredictiveModelException
Fits the random forest to the training data.
fitModel in class PredictiveModel
Throws:
PredictiveModel.PredictiveModelException - thrown when an exception occurs in com.imsl.datamining.PredictiveModel. Exceptions defined in the superclass, such as com.imsl.datamining.PredictiveModel.StateChangeException and com.imsl.datamining.PredictiveModel.SumOfProbabilitiesNotOneException, should also be considered.
protected void setConfiguration(PredictiveModel pm) throws PredictiveModel.PredictiveModelException
Sets the configuration of RandomTrees to that of the input model.
setConfiguration in class PredictiveModel
Parameters:
pm - a RandomTrees object
Throws:
PredictiveModel.PredictiveModelException - thrown when an exception occurs in com.imsl.datamining.PredictiveModel. Exceptions defined in the superclass, such as com.imsl.datamining.PredictiveModel.StateChangeException and com.imsl.datamining.PredictiveModel.SumOfProbabilitiesNotOneException, should also be considered.
public double[] predict() throws PredictiveModel.PredictiveModelException
Returns the predicted values generated by the random forest on the training data.
predict in class PredictiveModel
Returns:
a double array containing the fitted values
Throws:
PredictiveModel.PredictiveModelException - thrown when an exception occurs in com.imsl.datamining.PredictiveModel. Exceptions defined in the superclass, such as com.imsl.datamining.PredictiveModel.StateChangeException and com.imsl.datamining.PredictiveModel.SumOfProbabilitiesNotOneException, should also be considered.
public double[] predict(double[][] testData) throws PredictiveModel.PredictiveModelException
Returns the predicted values on the input test data.
predict in class PredictiveModel
Parameters:
testData - a double matrix containing test data
Note: testData must have the same number of columns as xy, and the columns must be in the same arrangement as in xy.
Returns:
a double array containing the predicted values
Throws:
PredictiveModel.PredictiveModelException - thrown when an exception occurs in com.imsl.datamining.PredictiveModel. Exceptions defined in the superclass, such as com.imsl.datamining.PredictiveModel.StateChangeException and com.imsl.datamining.PredictiveModel.SumOfProbabilitiesNotOneException, should also be considered.
public double[] predict(double[][] testData, double[] testDataWeights) throws PredictiveModel.PredictiveModelException
Returns the predicted values for the input test data, using the given test data weights.
predict in class PredictiveModel
Parameters:
testData - a double matrix containing test data
testDataWeights - a double array containing weight values for each row of testData
Note: testData must have the same number of columns as xy, and the columns must be in the same arrangement as in xy.
Returns:
a double array containing the predicted values
Throws:
PredictiveModel.PredictiveModelException - thrown when an exception occurs in com.imsl.datamining.PredictiveModel. Exceptions defined in the superclass, such as com.imsl.datamining.PredictiveModel.StateChangeException and com.imsl.datamining.PredictiveModel.SumOfProbabilitiesNotOneException, should also be considered.
public double[] getOutOfBagPredictions()
Returns the out-of-bag predicted values for the examples in the training data.
Returns:
a double array containing the out-of-bag predictions
public double getOutOfBagPredictionError()
Returns the out-of-bag prediction error.
Returns:
a double, the out-of-bag prediction error
public double[] getVariableImportance()
Returns the variable importance measure based on the out-of-bag prediction error. Variable importance for a predictor is obtained by randomly permuting the out-of-bag values of the predictor and calculating the difference in predictive accuracy before and after the permutation. The measure is averaged over all the trees.
Returns:
a double array containing variable importance for each predictor
Copyright © 2020 Rogue Wave Software. All rights reserved.