Skip to content

Global argmax

GlobalArgmax

Bases: Aggregate_Base

Aggregate function based on Global Argmax.

This class should be passed to an ensemble function/class for combining predictions.

Source code in aucmedi/ensemble/aggregate/global_argmax.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class GlobalArgmax(Aggregate_Base):
    """ Aggregate function based on Global Argmax.

    The single highest value found anywhere in the prediction matrix is kept
    for its class, and the remaining probability mass (1 - that value) is
    split evenly across all other classes.

    This class should be passed to an ensemble function/class for combining predictions.
    """
    #---------------------------------------------#
    #                Initialization               #
    #---------------------------------------------#
    def __init__(self):
        # No hyperparameter adjustment required for this method, therefore skip
        pass

    #---------------------------------------------#
    #                  Aggregate                  #
    #---------------------------------------------#
    def aggregate(self, preds):
        """ Combine several predictions into one vector via the global argmax.

        Args:
            preds (numpy.ndarray or array-like):    Prediction values whose last
                axis indexes classes. Presumably shaped (n_predictions, n_classes)
                by the calling ensemble — confirm against the caller.

        Returns:
            pred (numpy.ndarray):                   1D vector of length n_classes.
                The class holding the global maximum receives that maximum; every
                other class receives (1 - maximum) / (n_classes - 1).

        Raises:
            ValueError:                             If `preds` is empty (raised by
                `np.argmax`).
        """
        preds = np.asarray(preds)
        n_classes = preds.shape[-1]

        # Locate the global maximum with a single scan. On ties np.argmax returns
        # the first occurrence in row-major order, and .flat walks that same order.
        argmax_flat = np.argmax(preds)
        max_prob = preds.flat[argmax_flat]
        # The class is the position along the last (class) axis.
        class_idx = np.unravel_index(argmax_flat, preds.shape)[-1]

        # Spread the remaining mass equally over the other classes. With a single
        # class there are no "other" classes, so the divisor is clamped to 1 to
        # avoid a division-by-zero warning. The fill value is overwritten below
        # in that case anyway.
        prob_remaining = np.divide(1 - max_prob, max(n_classes - 1, 1))
        pred = np.full((n_classes,), fill_value=prob_remaining)
        pred[class_idx] = max_prob

        # Return prediction
        return pred