public abstract class DecisionTreeInfoGain extends DecisionTree implements Serializable, Cloneable
Abstract class that extends DecisionTree for classes that use an information gain criterion.
Modifier and Type | Class and Description |
---|---|
static class |
DecisionTreeInfoGain.GainCriteria
Specifies which information gain criteria to use in determining the best
split at each node.
|
Nested classes inherited from class DecisionTree:
DecisionTree.MaxTreeSizeExceededException, DecisionTree.PruningFailedToConvergeException, DecisionTree.PureNodeException
Nested classes inherited from class PredictiveModel:
PredictiveModel.CloneNotSupportedException, PredictiveModel.PredictiveModelException, PredictiveModel.StateChangeException, PredictiveModel.SumOfProbabilitiesNotOneException, PredictiveModel.VariableType
Constructor and Description |
---|
DecisionTreeInfoGain(double[][] xy,
int responseColumnIndex,
PredictiveModel.VariableType[] varType)
Constructs a
DecisionTree object for a single response
variable and multiple predictor variables. |
Modifier and Type | Method and Description |
---|---|
protected double[][] |
getCountXY(double[][] xy,
int nRows,
int xIdx,
int yIdx,
int maxNumberOfCategories,
int[] uniqueX,
int[] uniqueY,
double[] frequencies)
Calculates a two-way frequency table with input frequencies.
|
protected double |
getCriteriaValueCategorical(double[][] tableXY,
double[] classCounts,
int nRows,
int maxNumCats)
Calculates and returns the value of the criterion on the node represented
by the data set S = xy.
|
protected abstract int |
selectSplitVariable(double[][] xy,
double[] classCounts,
double[] parentFreq,
double[] splitValue,
double[] splitCriterionValue,
int[] splitPartition)
Abstract method for selecting the next split variable and split
definition for the node.
|
void |
setGainCriteria(DecisionTreeInfoGain.GainCriteria gainCriteria)
Specifies which criteria to use in gain calculations in order to
determine the best split at each node.
|
void |
setUseRatio(boolean ratio)
Sets the flag to use or not use the gain ratio instead of the gain to
determine the best split.
|
boolean |
useGainRatio()
Returns whether or not the gain ratio is to be used instead of the gain
to determine the best split.
|
Methods inherited from class DecisionTree:
fitModel, getCostComplexityValues, getDecisionTree, getFittedMeanSquaredError, getMaxDepth, getMaxNodes, getMeanSquaredPredictionError, getMinCostComplexityValue, getMinObsPerChildNode, getMinObsPerNode, getNodeAssigments, getNumberOfComplexityValues, getNumberOfRandomFeatures, isAutoPruningFlag, isRandomFeatureSelection, predict, predict, predict, printDecisionTree, printDecisionTree, pruneTree, setAutoPruningFlag, setConfiguration, setCostComplexityValues, setMaxDepth, setMaxNodes, setMinCostComplexityValue, setMinObsPerChildNode, setMinObsPerNode, setNumberOfRandomFeatures, setRandomFeatureSelection
Methods inherited from class PredictiveModel:
clone, getClassCounts, getClassErrors, getClassLabels, getClassProbabilities, getCostMatrix, getMaxNumberOfCategories, getMaxNumberOfIterations, getNumberOfClasses, getNumberOfColumns, getNumberOfMissing, getNumberOfPredictors, getNumberOfRows, getNumberOfUniquePredictorValues, getPredictorIndexes, getPredictorTypes, getPrintLevel, getPriorProbabilities, getRandomObject, getResponseColumnIndex, getResponseVariableAverage, getResponseVariableMostFrequentClass, getResponseVariableType, getTotalWeight, getVariableType, getWeights, getXY, isConstantSeries, isMustFitModel, isUserFixedNClasses, setClassCounts, setClassLabels, setClassProbabilities, setCostMatrix, setMaxNumberOfCategories, setMaxNumberOfIterations, setMustFitModel, setNumberOfClasses, setPredictorIndex, setPredictorTypes, setPrintLevel, setPriorProbabilities, setRandomObject, setResponseColumnIndex, setTrainingData, setVariableType, setWeights
public DecisionTreeInfoGain(double[][] xy, int responseColumnIndex, PredictiveModel.VariableType[] varType)
Constructs a DecisionTree object for a single response variable and multiple predictor variables.
Parameters:
xy - a double matrix with rows containing the observations on the predictor variables and one response variable
responseColumnIndex - an int specifying the column index of the response variable
varType - a PredictiveModel.VariableType array containing the type of each variable

protected abstract int selectSplitVariable(double[][] xy, double[] classCounts, double[] parentFreq, double[] splitValue, double[] splitCriterionValue, int[] splitPartition)
Abstract method for selecting the next split variable and split definition for the node.
Specified by:
selectSplitVariable in class DecisionTree
Parameters:
xy - a double matrix containing the data
classCounts - a double array containing the counts for each class of the response variable, when it is categorical
parentFreq - a double array used to indicate which subset of the observations belongs in the current node
splitValue - a double array representing the resulting split point if the selected variable is quantitative
splitCriterionValue - a double array containing the value of the criterion used to determine the splitting variable
splitPartition - an int array indicating the resulting split partition if the selected variable is categorical
Returns:
an int specifying the column index of the split variable in this.getPredictorIndexes()
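As an illustration of what an implementation of this method might do for categorical data, the sketch below cross-tabulates each candidate predictor against the response using the node's frequencies, computes the Shannon information gain from each table, and picks the predictor with the largest gain. It is a stand-alone sketch, not the library's code: the class `SplitSelectionSketch`, its helper methods, the assumption that categories are coded 0..nCats-1, and the choice to return a position within the `predictorIndexes` array are all hypothetical.

```java
import java.util.Arrays;

/** Hypothetical stand-alone illustration; none of these names belong to the library. */
final class SplitSelectionSketch {

    /** Shannon entropy (base 2) of a vector of non-negative counts. */
    static double entropy(double[] counts) {
        double total = Arrays.stream(counts).sum();
        if (total <= 0.0) {
            return 0.0;
        }
        double h = 0.0;
        for (double c : counts) {
            if (c > 0.0) {
                double p = c / total;
                h -= p * Math.log(p) / Math.log(2.0);
            }
        }
        return h;
    }

    /**
     * Weighted two-way table: categories of column xIdx along the rows,
     * categories of column yIdx along the columns.
     * Assumption: category values are coded 0..nCats-1.
     */
    static double[][] crossTab(double[][] xy, int xIdx, int yIdx, int nCats, double[] freq) {
        double[][] table = new double[nCats][nCats];
        for (int i = 0; i < xy.length; i++) {
            table[(int) xy[i][xIdx]][(int) xy[i][yIdx]] += freq[i];
        }
        return table;
    }

    /** Information gain obtained by splitting on the row variable of the table. */
    static double gain(double[][] table) {
        double[] classTotals = new double[table[0].length];
        double total = 0.0;
        for (double[] row : table) {
            for (int j = 0; j < row.length; j++) {
                classTotals[j] += row[j];
                total += row[j];
            }
        }
        if (total <= 0.0) {
            return 0.0;
        }
        double conditional = 0.0;
        for (double[] row : table) {
            conditional += (Arrays.stream(row).sum() / total) * entropy(row);
        }
        return entropy(classTotals) - conditional;
    }

    /**
     * Returns the position (within predictorIndexes) of the predictor with the
     * largest gain and stores that gain in splitCriterionValue[0].
     * A frequency of 0 in parentFreq excludes a row from the current node.
     */
    static int selectSplitVariable(double[][] xy, int[] predictorIndexes, int yIdx,
                                   int nCats, double[] parentFreq,
                                   double[] splitCriterionValue) {
        int best = -1;
        double bestGain = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < predictorIndexes.length; k++) {
            double g = gain(crossTab(xy, predictorIndexes[k], yIdx, nCats, parentFreq));
            if (g > bestGain) {
                bestGain = g;
                best = k;
            }
        }
        splitCriterionValue[0] = bestGain;
        return best;
    }

    public static void main(String[] args) {
        // Columns: predictor 0, predictor 1, response (all binary, coded 0/1).
        double[][] xy = {{0, 0, 0}, {0, 1, 0}, {1, 0, 1}, {1, 1, 1}};
        double[] parentFreq = {1, 1, 1, 1};
        double[] criterion = new double[1];
        int best = selectSplitVariable(xy, new int[] {0, 1}, 2, 2, parentFreq, criterion);
        System.out.println("best predictor position: " + best + ", gain: " + criterion[0]);
    }
}
```

In the sample data the first predictor separates the two response classes completely, so it is the one selected. Output arrays such as `splitCriterionValue` are passed in by the caller and filled by the method, which mirrors the array-typed output parameters in the signature above.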
public void setGainCriteria(DecisionTreeInfoGain.GainCriteria gainCriteria)
Specifies which criteria to use in gain calculations in order to determine the best split at each node.
Parameters:
gainCriteria - a DecisionTreeInfoGain.GainCriteria specifying which criteria to use in gain calculations in order to determine the best split at each node
Default: gainCriteria = DecisionTreeInfoGain.GainCriteria.SHANNON_ENTROPY
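To make the role of the gain criterion concrete, the following sketch computes a node impurity from class counts under two measures: Shannon entropy, which this page names as the default, and the Gini index, which is included only as a commonly used alternative and is an assumption here since this page does not list the other enum values. The class `ImpuritySketch` and its local `Criterion` enum are hypothetical stand-ins, not the library's `GainCriteria` type.

```java
/** Hypothetical stand-alone illustration; not the library's GainCriteria type. */
final class ImpuritySketch {

    /** Local stand-in enum; only SHANNON_ENTROPY is confirmed by the documentation above. */
    enum Criterion { SHANNON_ENTROPY, GINI_INDEX }

    /** Impurity of a node described by its (possibly weighted) class counts. */
    static double impurity(double[] classCounts, Criterion criterion) {
        double total = 0.0;
        for (double c : classCounts) {
            total += c;
        }
        if (total <= 0.0) {
            return 0.0;
        }
        // Gini starts at 1 and subtracts squared proportions;
        // entropy starts at 0 and subtracts p * log2(p).
        double value = (criterion == Criterion.GINI_INDEX) ? 1.0 : 0.0;
        for (double c : classCounts) {
            if (c <= 0.0) {
                continue;
            }
            double p = c / total;
            if (criterion == Criterion.GINI_INDEX) {
                value -= p * p;
            } else {
                value -= p * Math.log(p) / Math.log(2.0);
            }
        }
        return value;
    }

    public static void main(String[] args) {
        double[] balanced = {5.0, 5.0};
        double[] pure = {10.0, 0.0};
        System.out.println("entropy, balanced: " + impurity(balanced, Criterion.SHANNON_ENTROPY));
        System.out.println("gini, balanced:    " + impurity(balanced, Criterion.GINI_INDEX));
        System.out.println("entropy, pure:     " + impurity(pure, Criterion.SHANNON_ENTROPY));
        System.out.println("gini, pure:        " + impurity(pure, Criterion.GINI_INDEX));
    }
}
```

Both measures are zero for a pure node and largest for evenly mixed classes; the gain of a split is the parent's impurity minus the weighted average impurity of the child nodes.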
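public boolean useGainRatio()
Returns whether or not the gain ratio is to be used instead of the gain to determine the best split.
Returns:
a boolean indicating if the gain ratio is to be used: true uses the gain ratio; false uses the gain.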
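public void setUseRatio(boolean ratio)
Sets the flag to use or not use the gain ratio instead of the gain to determine the best split.
Parameters:
ratio - a boolean indicating if the gain ratio is to be used: true uses the gain ratio; false uses the gain.
Default: useRatio=false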
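The difference between the gain and the gain ratio can be sketched as follows, using the usual definition in which the gain is divided by the split information (the entropy of the child-node sizes). This is an illustrative sketch under that assumption, not the library's implementation; the class `GainRatioSketch` and the method `splitCriterion` are hypothetical names.

```java
/** Hypothetical stand-alone illustration of gain versus gain ratio. */
final class GainRatioSketch {

    /** Shannon entropy (base 2) of a vector of non-negative counts. */
    static double entropy(double[] counts) {
        double total = 0.0;
        for (double c : counts) {
            total += c;
        }
        if (total <= 0.0) {
            return 0.0;
        }
        double h = 0.0;
        for (double c : counts) {
            if (c > 0.0) {
                double p = c / total;
                h -= p * Math.log(p) / Math.log(2.0);
            }
        }
        return h;
    }

    /**
     * Criterion value for a split described by a two-way table whose rows are
     * the categories of the predictor and whose columns are the response classes.
     * With useRatio set, the gain is divided by the split information.
     */
    static double splitCriterion(double[][] table, boolean useRatio) {
        double[] rowTotals = new double[table.length];
        double[] classTotals = new double[table[0].length];
        double total = 0.0;
        for (int i = 0; i < table.length; i++) {
            for (int j = 0; j < table[i].length; j++) {
                rowTotals[i] += table[i][j];
                classTotals[j] += table[i][j];
                total += table[i][j];
            }
        }
        if (total <= 0.0) {
            return 0.0;
        }
        double conditional = 0.0;
        for (int i = 0; i < table.length; i++) {
            conditional += (rowTotals[i] / total) * entropy(table[i]);
        }
        double gain = entropy(classTotals) - conditional;
        if (!useRatio) {
            return gain;
        }
        double splitInfo = entropy(rowTotals);
        // Guard against a split that puts every observation in one child.
        return splitInfo > 0.0 ? gain / splitInfo : 0.0;
    }

    public static void main(String[] args) {
        // A four-category predictor with one observation per category.
        double[][] table = {{1, 0}, {1, 0}, {0, 1}, {0, 1}};
        System.out.println("gain:       " + splitCriterion(table, false));
        System.out.println("gain ratio: " + splitCriterion(table, true));
    }
}
```

The sample table shows why the ratio is sometimes preferred: a predictor with many small categories achieves a high gain simply by fragmenting the data, and dividing by the split information reduces that advantage.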
protected double getCriteriaValueCategorical(double[][] tableXY, double[] classCounts, int nRows, int maxNumCats)
Calculates and returns the value of the criterion on the node represented by the data set S = xy.
Parameters:
tableXY - a double matrix
classCounts - a double array containing the total counts of the response variable by category
nRows - an int, the number of rows in xy
maxNumCats - an int, the maximum number of categorical values allowed in the problem
Returns:
a double, the value of the splitting criterion

protected double[][] getCountXY(double[][] xy, int nRows, int xIdx, int yIdx, int maxNumberOfCategories, int[] uniqueX, int[] uniqueY, double[] frequencies)
Calculates a two-way frequency table with input frequencies.
Parameters:
xy - a double matrix containing the data to be tabulated
nRows - an int, the number of rows in xy
xIdx - an int, the column index of x
yIdx - an int, the column index of y
maxNumberOfCategories - an int, the maximum number of categories in either x or y
uniqueX - an int array containing indicators for the categories of x
uniqueY - an int array containing indicators for the categories of y
frequencies - a double array of length nRows containing the row frequencies for the data
Returns:
a maxNumberOfCategories by maxNumberOfCategories double matrix containing the cross-tabulated frequencies for x and y. The categories of x vary along the rows and the categories of y vary along the columns.

Copyright © 2020 Rogue Wave Software. All rights reserved.