
Decision tree

DecisionTree

Bases: Metalearner_Base

A Decision Tree based Metalearner.

This class should be passed to an ensemble function or class, such as Stacking, which uses it to combine the predictions of the individual models.
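
The following is a minimal sketch of the combination step. It assumes that the metalearner receives the base models' class probabilities concatenated column-wise, together with one-hot encoded ground truth. The arrays `preds_model_a` and `preds_model_b` are hypothetical placeholders, and scikit-learn is called directly rather than through AUCMEDI's Stacking class.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Hypothetical class probabilities of two base models on 6 validation samples
# (3 classes each); in a real ensemble they come from the trained base models
rng = np.random.default_rng(0)
preds_model_a = rng.random((6, 3))
preds_model_b = rng.random((6, 3))

# One-hot encoded ground truth (one column per class)
y = np.eye(3, dtype=int)[[0, 1, 2, 0, 1, 2]]

# Metalearner features: base-model predictions concatenated side by side
x = np.concatenate([preds_model_a, preds_model_b], axis=1)  # shape (6, 6)

# Same estimator settings as DecisionTree.__init__()
meta = DecisionTreeClassifier(random_state=0).fit(x, y)

# One (n_samples, 2) array per class; column 1 holds P(class == 1)
proba = np.asarray(meta.predict_proba(x))     # shape (3, 6, 2)
combined = np.swapaxes(proba[:, :, 1], 0, 1)  # shape (6, 3)
print(combined.shape)  # (6, 3)
```

Predicting on the same features the tree was fitted on only illustrates the shapes. With no depth limit, the tree reproduces those training labels exactly.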

Info

Can be utilized for binary, multi-class and multi-label tasks.
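
In scikit-learn, a two-dimensional 0/1 target matrix is fitted as one binary output per column. This assumes, as the `predict()` post-processing implies, that labels arrive one-hot encoded (binary, multi-class) or multi-hot encoded (multi-label), and it is why one metalearner covers all three task types. The small multi-label example below uses hypothetical data chosen for illustration.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Hypothetical metalearner features for 4 samples
x = np.array([[0.9, 0.1, 0.8],
              [0.2, 0.7, 0.1],
              [0.8, 0.9, 0.7],
              [0.1, 0.2, 0.3]])

# Multi-hot targets: a sample may carry several labels, or none
y = np.array([[1, 0],
              [0, 1],
              [1, 1],
              [0, 0]])

model = DecisionTreeClassifier(random_state=0).fit(x, y)
print(model.n_outputs_)  # 2: one binary output per label column

# predict_proba returns one (n_samples, 2) array per label column
proba = model.predict_proba(x)
print([p.shape for p in proba])  # [(4, 2), (4, 2)]

# Same post-processing as DecisionTree.predict(): keep P(label == 1) per column
scores = np.swapaxes(np.asarray(proba)[:, :, 1], 0, 1)
print(scores.shape)  # (4, 2)
```

The indexing with `[:, :, 1]` assumes that every label column contains both 0 and 1 in the training data. If a column is constant, scikit-learn returns a single-column `(n_samples, 1)` array for that output, and the post-processing fails.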

Reference - Implementation

https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html

Scikit-learn: Machine Learning in Python, Pedregosa et al., JMLR 12, pp. 2825-2830, 2011. https://jmlr.csail.mit.edu/papers/v12/pedregosa11a.html

Source code in aucmedi/ensemble/metalearner/decision_tree.py
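import pickle

import numpy as np
from sklearn.tree import DecisionTreeClassifier

from aucmedi.ensemble.metalearner.ml_base import Metalearner_Base


class DecisionTree(Metalearner_Base):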
    """ A Decision Tree based Metalearner.

    This class should be passed to an ensemble function or class, such as Stacking, which uses it to combine the predictions of the individual models.

    !!! info
        Can be utilized for binary, multi-class and multi-label tasks.

    ???+ abstract "Reference - Implementation"
        https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html

        Scikit-learn: Machine Learning in Python, Pedregosa et al., JMLR 12, pp. 2825-2830, 2011.
        https://jmlr.csail.mit.edu/papers/v12/pedregosa11a.html
    """
    #---------------------------------------------#
    #                Initialization               #
    #---------------------------------------------#
    def __init__(self):
        self.model = DecisionTreeClassifier(random_state=0)

    #---------------------------------------------#
    #                  Training                   #
    #---------------------------------------------#
    def train(self, x, y):
        # Fit the tree; a one-hot/multi-hot y is treated as one binary output per class
        self.model = self.model.fit(x, y)

    #---------------------------------------------#
    #                  Prediction                 #
    #---------------------------------------------#
    def predict(self, data):
        # Compute prediction probabilities via fitted model
        pred = self.model.predict_proba(data)
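        # For one-hot/multi-hot targets, predict_proba returns a list with one
        # (n_samples, 2) array per class: stack it to (n_classes, n_samples, 2),
        # keep the probability of value 1 (index 1) and transpose the result
        # to the (n_samples, n_classes) layout used by the ensemble
        pred = np.asarray(pred)
        pred = np.swapaxes(pred[:,:,1], 0, 1)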
        # Return results as NumPy array
        return pred

    #---------------------------------------------#
    #              Dump Model to Disk             #
    #---------------------------------------------#
    def dump(self, path):
        # Dump model to disk via pickle
        with open(path, "wb") as pickle_writer:
            pickle.dump(self.model, pickle_writer)

    #---------------------------------------------#
    #             Load Model from Disk            #
    #---------------------------------------------#
    def load(self, path):
        # Load model from disk via pickle
        with open(path, "rb") as pickle_reader:
            self.model = pickle.load(pickle_reader)
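
The `dump()` and `load()` methods amount to a plain pickle round trip of the fitted scikit-learn estimator. The following self-contained sketch reproduces that round trip, using a temporary file and a small hypothetical model in place of the AUCMEDI class.

```python
import os
import pickle
import tempfile

import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Fit a small stand-in for the metalearner model (hypothetical data)
x = np.array([[0.1, 0.9], [0.8, 0.2], [0.4, 0.6], [0.7, 0.3]])
y = np.array([[0, 1], [1, 0], [0, 1], [1, 0]])
model = DecisionTreeClassifier(random_state=0).fit(x, y)

# Round trip through pickle, mirroring dump() and load()
path = os.path.join(tempfile.mkdtemp(), "metalearner.pickle")
with open(path, "wb") as pickle_writer:
    pickle.dump(model, pickle_writer)
with open(path, "rb") as pickle_reader:
    restored = pickle.load(pickle_reader)

# The restored tree yields identical probabilities
same = np.allclose(np.asarray(model.predict_proba(x)),
                   np.asarray(restored.predict_proba(x)))
print(same)  # True
```

Because `load()` unpickles whatever the file contains, metalearner files should only be loaded from trusted sources. In addition, scikit-learn warns when an estimator pickled with a different scikit-learn version is loaded.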