public class BootstrapAggregation extends Object implements Serializable, Cloneable
Bootstrap aggregation, also known as bagging, generates predictions using predictive models. In the procedure, M bootstrap samples of size N are drawn with replacement from an original training set of size N. Sampling with replacement means that when an example is randomly selected, it is returned to the training set before the next draw, so it can be selected again. Thus a bootstrap sample can contain repeated examples or observations. Using each bootstrap sample as a separate training data set, the procedure fits a predictive model and then generates predictions. For each example, the M predictions are combined into a single predicted value by averaging if the problem is regression (continuous response variable) or by majority vote if the problem is classification (categorical response variable).
Originally proposed for decision trees, bagging leads to "improvements for unstable procedures," such as neural networks, classification and regression trees, and subset selection in linear regression. On the other hand, it can mildly degrade the performance of stable methods such as K-nearest neighbors (Breiman, 1996).
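The procedure above can be illustrated with a small, self-contained Java sketch. It does not use the IMSL classes: the class name BaggingSketch, its methods, and the base learner (a least-squares line on a single predictor) are hypothetical choices made for brevity. It draws M bootstrap samples of size N with replacement, fits one model per sample, and averages the M predictions; majorityVote shows the classification rule.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

// Hypothetical illustration of bagging; not the IMSL BootstrapAggregation code.
class BaggingSketch {

    // Draws n row indices uniformly from 0..n-1 with replacement,
    // so the same row can appear several times in one bootstrap sample.
    static int[] bootstrapIndices(int n, Random rng) {
        int[] idx = new int[n];
        for (int i = 0; i < n; i++) {
            idx[i] = rng.nextInt(n);
        }
        return idx;
    }

    // Base learner (assumption): least-squares line y = a + b*x fitted on the
    // rows listed in idx. Returns {a, b}.
    static double[] fitLine(double[] x, double[] y, int[] idx) {
        double mx = 0, my = 0;
        for (int i : idx) { mx += x[i]; my += y[i]; }
        mx /= idx.length;
        my /= idx.length;
        double sxy = 0, sxx = 0;
        for (int i : idx) {
            sxy += (x[i] - mx) * (y[i] - my);
            sxx += (x[i] - mx) * (x[i] - mx);
        }
        double b = (sxx == 0) ? 0 : sxy / sxx; // all drawn x identical: flat line
        return new double[] {my - b * mx, b};
    }

    // Regression: fit one model per bootstrap sample and average the M
    // predictions at each test point.
    static double[] baggedRegression(double[] x, double[] y, double[] xTest,
                                     int m, Random rng) {
        double[] sum = new double[xTest.length];
        for (int k = 0; k < m; k++) {
            double[] ab = fitLine(x, y, bootstrapIndices(x.length, rng));
            for (int t = 0; t < xTest.length; t++) {
                sum[t] += ab[0] + ab[1] * xTest[t];
            }
        }
        for (int t = 0; t < sum.length; t++) {
            sum[t] /= m;
        }
        return sum;
    }

    // Classification: majority vote over the M predicted class labels.
    // Ties go to the label that first reached the winning count.
    static int majorityVote(int[] labels) {
        Map<Integer, Integer> counts = new HashMap<>();
        int best = labels[0];
        for (int c : labels) {
            int n = counts.merge(c, 1, Integer::sum);
            if (n > counts.get(best)) {
                best = c;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] x = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            y[i] = 1 + 2 * x[i]; // noise-free line: each non-degenerate fit recovers it
        }
        Random rng = new Random(42); // fixed seed gives repeatable samples
        double[] pred = baggedRegression(x, y, new double[] {3.0, 10.0}, 50, rng);
        System.out.println(Arrays.toString(pred));
        System.out.println(majorityVote(new int[] {0, 1, 1, 2, 1}));
    }
}
```

With noise-free data, every bootstrap fit whose sampled x values are not all identical recovers y = 1 + 2x, so the bagged prediction at x = 3 is 7 up to rounding. A least-squares line is a stable learner, so, consistent with the remark on stable methods above, bagging it gains little; an unstable base learner such as a regression tree is where averaging is expected to help.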
Constructor | Description |
---|---|
BootstrapAggregation(PredictiveModel pm) | Constructs a BootstrapAggregation object to generate predictions of a PredictiveModel using bootstrap aggregation. |
Modifier and Type | Method | Description |
---|---|---|
void | aggregate() | Performs the bootstrap aggregation. |
double | getMeanSquaredPredictionError() | Returns the mean squared prediction error. |
int | getNumberOfThreads() | Returns the maximum number of java.lang.Thread instances that may be used for parallel processing. |
double[] | getPredictions() | Returns the predicted values. |
int | getPrintLevel() | Returns the current print level. |
void | setNumberBootstrapSamples(int nSamples) | Sets the number of bootstrap samples. |
void | setNumberOfThreads(int numberOfThreads) | Sets the maximum number of java.lang.Thread instances that may be used for parallel processing. |
void | setPrintLevel(int printLevel) | Sets the print level for the predictive model. |
void | setRandomObject(Random r) | Sets a random object for the bootstrap random sampling scheme. |
void | setTestData(double[][] testData) | Sets the test data to be predicted. |
void | setTestData(double[][] testData, double[] testDataWeights) | Sets the test data to be predicted, along with a weight for each row of the test data. |
public BootstrapAggregation(PredictiveModel pm)
Constructs a BootstrapAggregation object to generate predictions of a PredictiveModel using bootstrap aggregation.
pm - a PredictiveModel for which the predictions are to be generated.

public void aggregate() throws PredictiveModel.PredictiveModelException, NoSuchMethodException, InstantiationException, IllegalAccessException, InvocationTargetException
Performs the bootstrap aggregation.
PredictiveModel.PredictiveModelException - an exception has occurred in the common PredictiveModel programming interface methods, or an exception of a class that extends PredictiveModelException has been thrown.
NoSuchMethodException - the PredictiveModel subclass is missing a constructor with the expected signature (see PredictiveModel(double[][], int, com.imsl.datamining.PredictiveModel.VariableType[])).
InstantiationException - the PredictiveModel subclass could not be instantiated. This may occur if the subclass is not concrete.
IllegalAccessException - the currently executing method does not have access to the definition of the specified class, field, method, or constructor.
InvocationTargetException - an exception was thrown inside a method or constructor invoked through reflection. Use Throwable.getCause() to extract the initiating exception.

public double getMeanSquaredPredictionError()
Returns the mean squared prediction error, a double equal to the mean of the squared differences between the predicted and actual values of the response variable.
Note: the error is the in-sample fitted error unless test data are specified with setTestData().
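As a sketch of what this quantity measures, the hypothetical helper below computes the mean of the squared differences between predicted and actual responses. Skipping rows whose actual response is Double.NaN is an assumption of this sketch; how the library treats such rows is not stated here.

```java
// Hypothetical helper; not part of the IMSL API.
class PredictionErrorSketch {

    // Mean of (predicted[i] - actual[i])^2 over rows with a known response.
    // Assumption: rows whose actual response is Double.NaN are skipped.
    static double meanSquaredError(double[] predicted, double[] actual) {
        if (predicted.length != actual.length) {
            throw new IllegalArgumentException("length mismatch");
        }
        double sum = 0;
        int count = 0;
        for (int i = 0; i < actual.length; i++) {
            if (Double.isNaN(actual[i])) {
                continue; // missing response: nothing to compare against
            }
            double diff = predicted[i] - actual[i];
            sum += diff * diff;
            count++;
        }
        return count == 0 ? Double.NaN : sum / count;
    }

    public static void main(String[] args) {
        double[] predicted = {1.0, 2.0, 3.0};
        double[] actual = {1.0, Double.NaN, 5.0};
        System.out.println(meanSquaredError(predicted, actual)); // prints 2.0
    }
}
```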
public int getNumberOfThreads()
Returns an int containing the maximum number of java.lang.Thread instances that may be used for parallel processing.
The actual number of threads used in parallel processing is the lesser of numberOfThreads and nSamples, the number of bootstrap samples set for bootstrap aggregation. Capping the count this way avoids creating threads that would have no bootstrap sample to fit.
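The capping rule can be sketched with a fixed-size java.util.concurrent thread pool. The class ThreadCapSketch, the placeholder fitOneSample, and the one-task-per-sample layout are assumptions for illustration; the documentation does not describe how the library schedules its work.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical illustration of the thread cap; not the IMSL implementation.
class ThreadCapSketch {

    // Threads actually used: the lesser of the requested maximum and the
    // number of bootstrap samples (never fewer than one).
    static int effectiveThreads(int numberOfThreads, int nSamples) {
        return Math.max(1, Math.min(numberOfThreads, nSamples));
    }

    // Placeholder for "fit the model on bootstrap sample k".
    static double fitOneSample(int k) {
        return 0.5 * k;
    }

    // Runs one task per bootstrap sample on a pool of capped size.
    static double[] runSamples(int numberOfThreads, int nSamples) {
        ExecutorService pool =
            Executors.newFixedThreadPool(effectiveThreads(numberOfThreads, nSamples));
        try {
            List<Future<Double>> futures = new ArrayList<>();
            for (int k = 0; k < nSamples; k++) {
                final int sample = k;
                Callable<Double> task = () -> fitOneSample(sample);
                futures.add(pool.submit(task));
            }
            double[] results = new double[nSamples];
            for (int k = 0; k < nSamples; k++) {
                results[k] = futures.get(k).get(); // waits for sample k
            }
            return results;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt status
            throw new IllegalStateException("interrupted", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("a bootstrap fit failed", e.getCause());
        } finally {
            pool.shutdown(); // let the JVM exit once tasks finish
        }
    }

    public static void main(String[] args) {
        System.out.println(effectiveThreads(8, 50)); // prints 8
        System.out.println(effectiveThreads(8, 3));  // prints 3
        System.out.println(runSamples(4, 6).length); // prints 6
    }
}
```

With nSamples = 3 and numberOfThreads = 8, a pool of 8 would leave 5 threads with nothing to do; the cap sizes the pool to 3 instead.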
public double[] getPredictions()
Returns a double array of predicted values of the response variable for the examples in the test data.
The predicted values are generated by aggregate(), so call that method first. If testData is not specified, in-sample predictions are produced.
public int getPrintLevel()
Returns an int indicating the current print level.

printLevel | Action |
---|---|
0 | No printing. |
1 | Prints final results only. |
2 | Prints intermediate and final results. |
public void setNumberBootstrapSamples(int nSamples)
nSamples - an int specifying the number of bootstrap samples.
Default: nSamples = 50.
public void setNumberOfThreads(int numberOfThreads)
numberOfThreads - an int specifying the maximum number of java.lang.Thread instances that may be used for parallel processing.
The actual number of threads used in parallel processing is the lesser of numberOfThreads and nSamples, the number of bootstrap samples set for bootstrap aggregation. Capping the count this way avoids creating threads that would have no bootstrap sample to fit.
Default: numberOfThreads = 1.
public void setPrintLevel(int printLevel)
printLevel - an int specifying the level of printing to perform.

printLevel | Action |
---|---|
0 | No printing. |
1 | Prints final results only. |
2 | Prints intermediate and final results. |

Default: printLevel = 0.
public void setRandomObject(Random r)
r - a Random object.
Default: r is created inside the code, and its seed is set by the computer clock.
To obtain repeatable results, set the seed of the input r before calling this method. See Random for other options.
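The effect of seeding can be checked with plain java.util.Random: two generators given the same seed produce the same sequence of bootstrap indices, whether the seed is passed to the constructor or set afterwards with setSeed. drawIndices is a hypothetical helper; with the library, the seeded Random would be the object passed to setRandomObject.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical illustration; drawIndices is not part of the IMSL API.
class SeededSamplingSketch {

    // Draws n bootstrap row indices (with replacement) from r.
    static int[] drawIndices(Random r, int n) {
        int[] idx = new int[n];
        for (int i = 0; i < n; i++) {
            idx[i] = r.nextInt(n);
        }
        return idx;
    }

    public static void main(String[] args) {
        Random r1 = new Random();
        r1.setSeed(12345L);             // seed set before the object is used
        Random r2 = new Random(12345L); // equivalent: seed via the constructor
        int[] a = drawIndices(r1, 10);
        int[] b = drawIndices(r2, 10);
        System.out.println(Arrays.equals(a, b)); // prints true
    }
}
```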
public void setTestData(double[][] testData)
testData - a double matrix containing test data for which predictions are to be made using bagging. testData must have the same number of columns, in the same arrangement, as xy, the training data matrix of the PredictiveModel. Missing response variable values should be indicated with Double.NaN.
Default: If testData is not specified, in-sample predictions are produced (i.e., the original training set serves as the test data).
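The layout rule can be illustrated with two hypothetical helpers (withMissingResponse and requireSameWidth are not part of the library). The sketch assumes, for illustration only, that the response is the last column of xy; it copies new rows into a matrix of the same width and marks the unknown response with Double.NaN. Because NaN is not equal to itself, a missing value must be detected with Double.isNaN rather than ==.

```java
import java.util.Arrays;

// Hypothetical helpers; not part of the IMSL API.
class TestDataSketch {

    // Copies newRows and marks the response column as missing (Double.NaN),
    // keeping the same column layout as the training matrix.
    static double[][] withMissingResponse(double[][] newRows, int responseColumnIndex) {
        double[][] test = new double[newRows.length][];
        for (int i = 0; i < newRows.length; i++) {
            test[i] = newRows[i].clone(); // leave the caller's rows unchanged
            test[i][responseColumnIndex] = Double.NaN;
        }
        return test;
    }

    // Rejects test rows whose width differs from the training matrix xy.
    static void requireSameWidth(double[][] xy, double[][] testData) {
        for (double[] row : testData) {
            if (row.length != xy[0].length) {
                throw new IllegalArgumentException(
                    "test row has " + row.length + " columns; xy has " + xy[0].length);
            }
        }
    }

    public static void main(String[] args) {
        double[][] xy = {{1.0, 2.0, 3.5}, {2.0, 1.0, 4.0}}; // assumed: last column is the response
        double[][] test = withMissingResponse(new double[][] {{3.0, 3.0, 0.0}}, 2);
        requireSameWidth(xy, test);
        // NaN never compares equal to itself, so test with Double.isNaN.
        System.out.println(test[0][2] == Double.NaN);  // prints false
        System.out.println(Double.isNaN(test[0][2]));  // prints true
        System.out.println(Arrays.toString(test[0]));  // prints [3.0, 3.0, NaN]
    }
}
```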
public void setTestData(double[][] testData, double[] testDataWeights)
testData - a double matrix containing test data for which predictions are to be made using bagging. testData must have the same number of columns, in the same arrangement, as xy, the training data matrix of the PredictiveModel. Missing response variable values should be indicated with Double.NaN.
testDataWeights - a double array containing observation weights for the test data, one for each row of testData.
Default: If testData is not specified, in-sample predictions are produced (i.e., the original training set serves as the test data).
Copyright © 1970-2015 Rogue Wave Software
Built October 13 2015.