Augmenting
predict_augmenting(model, prediction_generator, n_cycles=10, aggregate='mean')
Inference Augmenting function that automatically augments unknown images for prediction.
The predictions of the augmented images are aggregated via the provided Aggregate function.
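The general idea can be sketched in plain NumPy. This is a minimal, hypothetical illustration of inference augmenting (also known as test-time augmentation), not AUCMEDI's implementation. `tta_predict`, `random_flip` and `dummy_predict` are stand-in functions assumed for illustration, and the aggregation shown is a simple mean over all cycles.

```python
import numpy as np

def tta_predict(images, predict, augment, n_cycles=10, rng=None):
    """Hypothetical sketch of inference augmenting (test-time augmentation).

    images:  array of shape (n_samples, height, width, channels)
    predict: callable mapping an image batch to class probabilities
             of shape (n_samples, n_classes)
    augment: callable(image, rng) returning a randomly augmented copy
    """
    rng = np.random.default_rng() if rng is None else rng
    cycle_preds = []
    for _ in range(n_cycles):
        # Create one augmented version of every sample in this cycle
        augmented = np.stack([augment(img, rng) for img in images])
        cycle_preds.append(predict(augmented))
    # Shape: (n_cycles, n_samples, n_classes)
    stacked = np.stack(cycle_preds)
    # Aggregate the predictions of all augmented copies per sample (mean)
    return stacked.mean(axis=0)

def random_flip(img, rng):
    # Hypothetical augmentation: flip along the width axis with probability 0.5
    return img[:, ::-1] if rng.random() < 0.5 else img

def dummy_predict(batch):
    # Hypothetical two-class "model": class-1 probability equals mean intensity
    p1 = batch.reshape(len(batch), -1).mean(axis=1)
    return np.stack([1.0 - p1, p1], axis=1)

images = np.random.default_rng(0).random((4, 8, 8, 1))
preds = tta_predict(images, dummy_predict, random_flip, n_cycles=5)
print(preds.shape)
```

Because the toy model only looks at mean intensity, which a flip does not change, the aggregated predictions here equal the predictions on the original images; with a real network each augmented copy would generally yield a slightly different prediction.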
Example
# Import libraries
from aucmedi.ensemble import predict_augmenting
from aucmedi import ImageAugmentation, DataGenerator
# Initialize testing DataGenerator with desired Data Augmentation
test_aug = ImageAugmentation(flip=True, rotate=True, brightness=False, contrast=False)
test_gen = DataGenerator(samples_test, "images_dir/",
                         data_aug=test_aug,
                         resize=model.meta_input,
                         standardize_mode=model.meta_standardize)
# Compute predictions via Augmenting
preds = predict_augmenting(model, test_gen, n_cycles=15, aggregate="majority_vote")
The Aggregate function can be included in multiple ways:
- passing a self-initialized instance of an AUCMEDI Aggregate function,
- using a string key to call an AUCMEDI Aggregate function by name, or
- implementing a custom Aggregate function by extending the AUCMEDI base class for Aggregate functions.
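The three options can be illustrated with a small, self-contained sketch. All names here (`AggregateBase`, `MeanAggregate`, `MaxAggregate`, `AGGREGATE_REGISTRY`, `resolve_aggregate`) are hypothetical and only mirror the idea of a base class plus a lookup by string key; they are not AUCMEDI's actual class names.

```python
from abc import ABC, abstractmethod
import numpy as np

class AggregateBase(ABC):
    """Hypothetical base class: turns stacked predictions into one prediction."""

    @abstractmethod
    def aggregate(self, preds: np.ndarray) -> np.ndarray:
        """preds has shape (n_cycles, n_classes) for a single sample."""

class MeanAggregate(AggregateBase):
    def aggregate(self, preds):
        return preds.mean(axis=0)

# Hypothetical registry mapping string keys to built-in aggregate classes
AGGREGATE_REGISTRY = {"mean": MeanAggregate}

def resolve_aggregate(aggregate):
    # String key: look up the class by name and instantiate it
    if isinstance(aggregate, str):
        if aggregate not in AGGREGATE_REGISTRY:
            raise ValueError(f"Unknown aggregate function: {aggregate!r}")
        return AGGREGATE_REGISTRY[aggregate]()
    # Already initialized instance (built-in or custom subclass): use as-is
    if isinstance(aggregate, AggregateBase):
        return aggregate
    raise TypeError("aggregate must be a string key or an AggregateBase instance")

class MaxAggregate(AggregateBase):
    # Custom aggregate function created by extending the base class
    def aggregate(self, preds):
        return preds.max(axis=0)

preds = np.array([[0.2, 0.8], [0.6, 0.4]])
print(resolve_aggregate("mean").aggregate(preds))
print(resolve_aggregate(MaxAggregate()).aggregate(preds))
```

Requiring custom functions to subclass a common base class lets the resolver accept any user-defined aggregation while still rejecting arguments of the wrong type early.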
Info
A description and a list of the implemented Aggregate functions can be found here: Aggregate
The Data Augmentation class instance from the DataGenerator will be used for inference augmenting.
It can either be predefined or remain None. If the data_aug is None, a Data Augmentation class
instance is automatically created which applies rotation and flipping augmentations.
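The fallback behaviour can be sketched as follows. This is a hypothetical NumPy illustration, not AUCMEDI's augmentation class: `default_inference_augmentation` and `get_augmentation` are names assumed for this example, and restricting rotation to multiples of 90 degrees is a simplification that may differ from the library's actual rotation augmentation.

```python
import numpy as np

def default_inference_augmentation(image, rng):
    """Hypothetical default: random flips plus a random 90-degree rotation.

    image: array of shape (height, width, channels)
    """
    # Random horizontal and vertical flips
    if rng.random() < 0.5:
        image = np.flip(image, axis=1)
    if rng.random() < 0.5:
        image = np.flip(image, axis=0)
    # Random rotation by 0, 90, 180 or 270 degrees in the image plane
    k = int(rng.integers(0, 4))
    return np.rot90(image, k=k, axes=(0, 1))

def get_augmentation(data_aug=None):
    # Use the predefined augmentation if one was given, otherwise the default
    return data_aug if data_aug is not None else default_inference_augmentation

rng = np.random.default_rng(42)
image = np.arange(16, dtype=float).reshape(4, 4, 1)
augment = get_augmentation(None)
augmented = augment(image, rng)
print(augmented.shape)
```

Flips and 90-degree rotations only rearrange pixels, so the augmented image contains exactly the same values as the input; for square images the shape is preserved as well.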
Warning
The passed DataGenerator will be re-initialized!
This can result in redundant image preparation if prepare_images=True.
Reference for Ensemble Learning Techniques
Dominik Müller, Iñaki Soto-Rey and Frank Kramer. (2022). An Analysis on Ensemble Learning optimized Medical Image Classification with Deep Convolutional Neural Networks. arXiv e-print: https://arxiv.org/abs/2201.11440
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | NeuralNetwork | Instance of an AUCMEDI neural network class. | required |
prediction_generator | DataGenerator | Data generator used for Augmenting-based inference. | required |
n_cycles | int | Number of image augmentations to create per sample. | 10 |
aggregate | str or Aggregate function | Aggregate function class instance or a string key for an AUCMEDI Aggregate function. | 'mean' |
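To make the n_cycles and aggregate parameters concrete, the sketch below aggregates the predictions of n_cycles = 3 augmented copies of a single sample in two ways. The majority vote shown here (argmax per cycle, then the most frequent class, returned as a one-hot vector) is one common definition and is an assumption; the output format of AUCMEDI's own "majority_vote" function may differ. The helper names are hypothetical.

```python
import numpy as np

def aggregate_mean(preds):
    # preds: (n_cycles, n_classes) -> average probability per class
    return preds.mean(axis=0)

def aggregate_majority_vote(preds):
    # Each cycle votes for its most probable class
    votes = np.argmax(preds, axis=1)
    counts = np.bincount(votes, minlength=preds.shape[1])
    # One-hot encode the class with the most votes (ties -> lowest class index)
    result = np.zeros(preds.shape[1])
    result[np.argmax(counts)] = 1.0
    return result

# Predictions of n_cycles=3 augmented copies of one sample, 2 classes
preds = np.array([
    [0.90, 0.10],
    [0.40, 0.60],
    [0.45, 0.55],
])
print(aggregate_mean(preds))
print(aggregate_majority_vote(preds))
```

Here the mean favours class 0 because one cycle is very confident, while the majority vote picks class 1 because two of the three cycles prefer it.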
Source code in aucmedi/ensemble/augmenting.py