Bases: Aggregate_Base
Aggregate function based on Global Argmax.
This class should be passed to an ensemble function/class for combining predictions.
Source code in aucmedi/ensemble/aggregate/global_argmax.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class GlobalArgmax(Aggregate_Base):
    """ Aggregate function based on Global Argmax.

    The single highest value across all provided predictions decides the
    class. That class receives the global maximum as its probability, and the
    remaining probability mass (1 - maximum) is distributed equally over all
    other classes.

    This class should be passed to an ensemble function/class for combining predictions.
    """
    #---------------------------------------------#
    #                Initialization               #
    #---------------------------------------------#
    def __init__(self):
        # No hyperparameter adjustment required for this method, therefore skip
        pass

    #---------------------------------------------#
    #                  Aggregate                  #
    #---------------------------------------------#
    def aggregate(self, preds):
        """ Combine multiple predictions into one via the global argmax.

        Args:
            preds: Array-like of predictions whose last axis indexes the
                classes. NOTE(review): presumably shape (n_predictions,
                n_classes); the original code indexed ``shape[1]``, so 2-D
                input is the case it was written for — confirm with callers.

        Returns:
            1-D numpy array of length n_classes. The class owning the global
            maximum holds that maximum; every other class holds
            (1 - maximum) / (n_classes - 1).
        """
        preds = np.asarray(preds)
        n_classes = preds.shape[-1]

        # Flat position of the global maximum; its coordinate on the last
        # axis is the class index it belongs to.
        flat_idx = np.argmax(preds)
        peak = preds.flat[flat_idx]
        class_idx = np.unravel_index(flat_idx, preds.shape)[-1]

        # Spread the leftover probability equally over the other classes.
        # max(..., 1) guards the single-class case: there are no other
        # classes then, and dividing by zero would only emit a RuntimeWarning
        # and produce an inf/nan filler that is overwritten right below.
        prob_remaining = np.divide(1 - peak, max(n_classes - 1, 1))
        pred = np.full((n_classes,), fill_value=prob_remaining)
        pred[class_idx] = peak
        # Return prediction
        return pred
|