Bases: Aggregate_Base
  
      Aggregate function based on Global Argmax: the single highest probability found across all predictions determines the predicted class.
This class should be passed to an ensemble function/class for combining predictions.
        
          Source code in aucmedi/ensemble/aggregate/global_argmax.py (lines 30–58):
class GlobalArgmax(Aggregate_Base):
    """ Aggregate function based on Global Argmax.

    The class holding the single highest probability across all provided
    predictions is selected and keeps that probability. The remaining
    probability mass (``1 - max``) is distributed equally among all other
    classes.

    This class should be passed to an ensemble function/class for combining predictions.
    """
    #---------------------------------------------#
    #                Initialization               #
    #---------------------------------------------#
    def __init__(self):
        # No hyperparameter adjustment required for this method, therefore skip
        pass
    #---------------------------------------------#
    #                  Aggregate                  #
    #---------------------------------------------#
    def aggregate(self, preds):
        """ Combine several predictions into one via the global argmax.

        Args:
            preds (numpy.ndarray or array-like):    Predictions whose LAST axis
                                                    indexes the classes. Presumably
                                                    shape (n_predictions, n_classes),
                                                    e.g. one row per model — TODO
                                                    confirm against callers.

        Returns:
            pred (numpy.ndarray):                   Array of shape (n_classes,). The
                                                    class containing the global
                                                    maximum gets that maximum; every
                                                    other class gets
                                                    (1 - max) / (n_classes - 1).

        Raises:
            ValueError: If ``preds`` is empty (raised by ``np.amax``).
        """
        preds = np.asarray(preds)
        # Class axis is the last axis; use it consistently for both the
        # class count and the class index of the global maximum.
        n_classes = preds.shape[-1]
        # Identify global maximum and the class it belongs to
        max_prob = np.amax(preds)
        class_idx = np.unravel_index(np.argmax(preds), preds.shape)[-1]
        # Equally distribute the remaining probability over the other classes.
        # With a single class there are no "other" classes; the fill value is
        # overwritten below, so divide by 1 instead of 0 to avoid a spurious
        # divide-by-zero warning / inf.
        prob_remaining = np.divide(1 - max_prob, max(n_classes - 1, 1))
        pred = np.full((n_classes,), fill_value=prob_remaining)
        pred[class_idx] = max_prob
        # Return prediction
        return pred
 |