Augmenting

predict_augmenting(model, prediction_generator, n_cycles=10, aggregate='mean')

Inference Augmenting (also known as test-time augmentation) function, which automatically creates augmented copies of unknown images for prediction.

The predictions of the augmented images are aggregated via the provided Aggregate function.
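Conceptually, inference augmenting repeats every sample, predicts each augmented copy, and reduces the copies back to one prediction per sample. The NumPy sketch below illustrates that idea on toy data. `toy_model_predict` and `tta_predict` are hypothetical names, the random horizontal flip stands in for a real Data Augmentation instance, and the mean stands in for the Aggregate function; it is not the AUCMEDI implementation.

```python
import numpy as np


def toy_model_predict(batch):
    """Hypothetical stand-in for model.predict: returns 2-class probabilities.

    The class-1 probability is the mean intensity of the left half of each
    image, so the output changes when an image is flipped horizontally.
    """
    left = batch[:, :, : batch.shape[2] // 2]
    p1 = left.mean(axis=(1, 2))
    return np.stack([1.0 - p1, p1], axis=1)


def tta_predict(images, n_cycles, rng):
    """Predict every image n_cycles times under random flips, then average."""
    # Repeat each sample n_cycles times in a row: [a, a, ..., b, b, ...]
    repeated = np.repeat(images, n_cycles, axis=0)
    # Assumed augmentation: flip each copy horizontally with probability 0.5
    flips = rng.random(len(repeated)) < 0.5
    augmented = np.where(flips[:, None, None], repeated[:, :, ::-1], repeated)
    preds = toy_model_predict(augmented)
    # Aggregate the n_cycles predictions of each sample with the mean
    return preds.reshape(len(images), n_cycles, -1).mean(axis=1)


rng = np.random.default_rng(0)
images = rng.random((3, 8, 8))  # three toy grayscale "images"
print(tta_predict(images, n_cycles=10, rng=rng).shape)  # (3, 2)
```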

Example
# Import libraries
from aucmedi.ensemble import predict_augmenting
from aucmedi import ImageAugmentation, DataGenerator

# Initialize testing DataGenerator with desired Data Augmentation
test_aug = ImageAugmentation(flip=True, rotate=True, brightness=False, contrast=False)
test_gen = DataGenerator(samples_test, "images_dir/",
                         data_aug=test_aug,
                         resize=model.meta_input,
                         standardize_mode=model.meta_standardize)

# Compute predictions via Augmenting
preds = predict_augmenting(model, test_gen, n_cycles=15, aggregate="majority_vote")

The Aggregate function can be provided in multiple ways:

  • passing a self-initialized AUCMEDI Aggregate function instance,
  • passing a string key to select an AUCMEDI Aggregate function by name, or
  • implementing a custom Aggregate function by extending the AUCMEDI base class for Aggregate functions.
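The three options above boil down to one dispatch rule: a string is looked up in a name-to-class registry and instantiated, while any other object is used as-is as long as it provides an `aggregate()` method. The sketch below mirrors that rule with hypothetical classes (`AggregateBase`, `Mean`, `Median`, `Max`) and a hypothetical `AGGREGATE_REGISTRY`; these are not AUCMEDI's actual class names, and the registry contents are assumptions for illustration.

```python
import numpy as np


class AggregateBase:
    """Hypothetical base class: subclasses implement aggregate(preds)."""

    def aggregate(self, preds):
        raise NotImplementedError


class Mean(AggregateBase):
    def aggregate(self, preds):
        # Average the class probabilities over all augmented copies
        return np.mean(preds, axis=0)


class Median(AggregateBase):
    def aggregate(self, preds):
        # Per-class median over all augmented copies
        return np.median(preds, axis=0)


# Hypothetical name -> class registry, analogous to a string-key lookup
AGGREGATE_REGISTRY = {"mean": Mean, "median": Median}


def resolve_aggregate(aggregate):
    """Accept a registered string key or an object with an aggregate() method."""
    if isinstance(aggregate, str):
        if aggregate not in AGGREGATE_REGISTRY:
            raise ValueError(f"Unknown aggregate function: {aggregate!r}")
        return AGGREGATE_REGISTRY[aggregate]()
    return aggregate


class Max(AggregateBase):
    """Custom aggregate (illustrative only): per-class maximum."""

    def aggregate(self, preds):
        return np.max(preds, axis=0)
```

Usage follows the three bullet points: `resolve_aggregate("mean")`, `resolve_aggregate(Median())`, or `resolve_aggregate(Max())` for a custom subclass.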

Info

A description and list of the implemented Aggregate functions can be found here: Aggregate

The Data Augmentation class instance of the DataGenerator is used for inference augmenting. It can either be predefined or remain None. If data_aug is None, a Data Augmentation class instance that applies only flipping and rotation is created automatically (ImageAugmentation for 2D image inputs, VolumeAugmentation for 3D volume inputs).
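To make the default "flip and rotate only" policy concrete, here is a simplified NumPy sketch for a single square 2D image. `flip_rotate` is a hypothetical helper; it restricts rotation to multiples of 90 degrees so that no pixel values are interpolated, which is an assumption of this sketch and may differ from the rotation performed by AUCMEDI's ImageAugmentation.

```python
import numpy as np


def flip_rotate(image, rng):
    """Randomly flip and rotate a square image of shape (H, H) or (H, H, C).

    Simplified sketch: only axis flips and 90-degree rotations are used.
    """
    if rng.random() < 0.5:
        image = np.flip(image, axis=0)  # vertical flip
    if rng.random() < 0.5:
        image = np.flip(image, axis=1)  # horizontal flip
    k = int(rng.integers(0, 4))  # rotate by 0, 90, 180 or 270 degrees
    return np.rot90(image, k=k, axes=(0, 1))


rng = np.random.default_rng(0)
img = np.arange(16).reshape(4, 4)
print(flip_rotate(img, rng).shape)  # (4, 4)
```

Because flips and 90-degree rotations only permute pixels, every augmented copy contains exactly the same intensity values as the original, which keeps the prediction differences attributable to orientation alone.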

Warning

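The passed DataGenerator will be re-initialized: a new DataGenerator with the same settings is created, in which every sample is repeated n_cycles times. This can result in redundant image preparation if prepare_images=True.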
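The re-initialized generator receives the sample list expanded with `np.repeat`, which keeps all copies of one sample next to each other. That contiguous layout is what allows the predictions of sample `i` to be taken from rows `i*n_cycles` to `i*n_cycles + n_cycles - 1`. The short NumPy demonstration below uses made-up file names and contrasts this with `np.tile`, which would interleave the copies and break that slicing.

```python
import numpy as np

samples = np.array(["a.png", "b.png", "c.png"])  # hypothetical sample names
n_cycles = 3

# np.repeat keeps the copies of one sample contiguous ...
repeated = np.repeat(samples, n_cycles)
# ... so sample i owns the block of rows [i*n_cycles, (i+1)*n_cycles)
blocks = [repeated[i * n_cycles:(i + 1) * n_cycles] for i in range(len(samples))]

# np.tile would interleave the copies instead and break that slicing
tiled = np.tile(samples, n_cycles)

print(repeated.tolist())
print(tiled.tolist())
```

The expanded list also has `len(samples) * n_cycles` entries, so the new generator loads that many images.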

Reference for Ensemble Learning Techniques

Dominik Müller, Iñaki Soto-Rey and Frank Kramer. (2022). An Analysis on Ensemble Learning optimized Medical Image Classification with Deep Convolutional Neural Networks. arXiv e-print: https://arxiv.org/abs/2201.11440

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| model | NeuralNetwork | Instance of an AUCMEDI neural network class. | required |
| prediction_generator | DataGenerator | A data generator which will be used for Augmenting-based inference. | required |
| n_cycles | int | Number of augmented copies to create per sample. | 10 |
| aggregate | str or Aggregate function | Aggregate function class instance, or a string key of an AUCMEDI Aggregate function. | 'mean' |
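The choice of `aggregate` can change the final class decision. The sketch below compares a mean over the augmented copies with a simple vote count on three made-up predictions for one sample. Both functions are hypothetical simplifications; in particular, the output format of AUCMEDI's own `"majority_vote"` Aggregate function may differ from the vote shares returned here.

```python
import numpy as np


def mean_aggregate(preds):
    # Average class probabilities over the augmented copies
    return preds.mean(axis=0)


def majority_vote_aggregate(preds):
    """Simplified vote: each copy votes for its argmax class; return vote shares."""
    n_classes = preds.shape[1]
    votes = np.bincount(preds.argmax(axis=1), minlength=n_classes)
    return votes / len(preds)


# Three augmented copies of one sample, two classes (made-up values)
preds = np.array([[0.6, 0.4], [0.55, 0.45], [0.1, 0.9]])
print(mean_aggregate(preds).argmax())           # 1
print(majority_vote_aggregate(preds).argmax())  # 0
```

One confident copy pulls the mean towards class 1, while two weakly confident copies win the vote for class 0.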
Source code in aucmedi/ensemble/augmenting.py
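import numpy as np
from aucmedi import DataGenerator, ImageAugmentation, VolumeAugmentation
from aucmedi.ensemble.aggregate import aggregate_dict


def predict_augmenting(model, prediction_generator, n_cycles=10, aggregate="mean"):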
    """ Inference Augmenting function for automatically augmenting unknown images for prediction.

    The predictions of the augmented images are aggregated via the provided Aggregate function.

    ???+ example
        ```python
        # Import libraries
        from aucmedi.ensemble import predict_augmenting
        from aucmedi import ImageAugmentation, DataGenerator

        # Initialize testing DataGenerator with desired Data Augmentation
        test_aug = ImageAugmentation(flip=True, rotate=True, brightness=False, contrast=False)
        test_gen = DataGenerator(samples_test, "images_dir/",
                                 data_aug=test_aug,
                                 resize=model.meta_input,
                                 standardize_mode=model.meta_standardize)

        # Compute predictions via Augmenting
        preds = predict_augmenting(model, test_gen, n_cycles=15, aggregate="majority_vote")
        ```

    The Aggregate function can be provided in multiple ways:

    - passing a self-initialized AUCMEDI Aggregate function instance,
    - passing a string key to select an AUCMEDI Aggregate function by name, or
    - implementing a custom Aggregate function by extending the [AUCMEDI base class for Aggregate functions][aucmedi.ensemble.aggregate.agg_base]

    !!! info
        A description and list of the implemented Aggregate functions can be found here:
        [Aggregate][aucmedi.ensemble.aggregate]

    The Data Augmentation class instance from the DataGenerator will be used for inference augmenting.
    It can either be predefined or remain `None`. If the `data_aug` is `None`, a Data Augmentation class
    instance is automatically created which applies rotation and flipping augmentations.

    ???+ warning
        The passed DataGenerator will be re-initialized!
        This can result in redundant image preparation if `prepare_images=True`.

    ??? reference "Reference for Ensemble Learning Techniques"
        Dominik Müller, Iñaki Soto-Rey and Frank Kramer. (2022).
        An Analysis on Ensemble Learning optimized Medical Image Classification with Deep Convolutional Neural Networks.
        arXiv e-print: [https://arxiv.org/abs/2201.11440](https://arxiv.org/abs/2201.11440)

    Args:
        model (NeuralNetwork):                  Instance of an AUCMEDI neural network class.
        prediction_generator (DataGenerator):   A data generator which will be used for Augmenting-based inference.
        n_cycles (int):                         Number of augmented copies to create per sample.
        aggregate (str or Aggregate function):  Aggregate function class instance, or a string key of an AUCMEDI Aggregate function.
    """
    # Initialize aggregate function if required
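    if isinstance(aggregate, str):
        # Fail early on unknown keys instead of crashing later on agg_fun.aggregate()
        if aggregate not in aggregate_dict:
            raise ValueError("Unknown aggregate function: " + aggregate)
        agg_fun = aggregate_dict[aggregate]()
    else:
        agg_fun = aggregate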

    # Initialize image augmentation if none provided (only flip, rotate)
    if prediction_generator.data_aug is None and len(model.input_shape) == 3:
        data_aug = ImageAugmentation(flip=True, rotate=True, scale=False,
                                     brightness=False, contrast=False,
                                     saturation=False, hue=False, crop=False,
                                     grid_distortion=False, compression=False,
                                     gamma=False, gaussian_noise=False,
                                     gaussian_blur=False, downscaling=False,
                                     elastic_transform=False)
    elif prediction_generator.data_aug is None and len(model.input_shape) == 4:
        data_aug = VolumeAugmentation(flip=True, rotate=True, scale=False,
                                      brightness=False, contrast=False,
                                      saturation=False, hue=False, crop=False,
                                      grid_distortion=False, compression=False,
                                      gamma=False, gaussian_noise=False,
                                      gaussian_blur=False, downscaling=False,
                                      elastic_transform=False)
    else:
        data_aug = prediction_generator.data_aug

    # Multiply sample list for prediction according to number of cycles
    samples_aug = np.repeat(prediction_generator.samples, n_cycles)

    # Re-initialize DataGenerator for inference
    aug_gen = DataGenerator(samples_aug,
                            path_imagedir=prediction_generator.path_imagedir,
                            labels=None,
                            metadata=prediction_generator.metadata,
                            batch_size=prediction_generator.batch_size,
                            data_aug=data_aug,
                            seed=prediction_generator.seed,
                            subfunctions=prediction_generator.subfunctions,
                            shuffle=False,
                            standardize_mode=prediction_generator.standardize_mode,
                            resize=prediction_generator.resize,
                            grayscale=prediction_generator.grayscale,
                            prepare_images=prediction_generator.prepare_images,
                            sample_weights=None,
                            image_format=prediction_generator.image_format,
                            loader=prediction_generator.sample_loader,
                            workers=prediction_generator.workers,
                            **prediction_generator.kwargs)

    # Compute predictions with provided model
    preds_all = model.predict(aug_gen)

    # Ensemble inferences via aggregate function
    preds_ensembled = []
    for i in range(0, len(prediction_generator.samples)):
        # Identify subset for a single sample
        j = i*n_cycles
        subset = preds_all[j:j+n_cycles]
        # Aggregate predictions
        pred_sample = agg_fun.aggregate(subset)
        # Add prediction to prediction list
        preds_ensembled.append(pred_sample)
    # Convert prediction list to NumPy
    preds_ensembled = np.asarray(preds_ensembled)

    # Return ensembled predictions
    return preds_ensembled
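The loop at the end of `predict_augmenting` slices contiguous blocks of `n_cycles` rows and aggregates each block. For aggregators that reduce along the first axis, such as a plain mean, the same result can be obtained with a single reshape. The sketch below checks that equivalence on random data; `aggregate_loop` and `aggregate_reshape` are hypothetical helpers, and it assumes `preds_all` is a 2D array of shape `(n_samples * n_cycles, n_classes)`.

```python
import numpy as np


def aggregate_loop(preds_all, n_samples, n_cycles, agg):
    # Mirrors the per-sample loop of predict_augmenting
    out = []
    for i in range(n_samples):
        j = i * n_cycles
        out.append(agg(preds_all[j:j + n_cycles]))
    return np.asarray(out)


def aggregate_reshape(preds_all, n_samples, n_cycles):
    # Vectorized alternative, valid only for a mean along the copies axis
    return preds_all.reshape(n_samples, n_cycles, -1).mean(axis=1)


rng = np.random.default_rng(0)
preds_all = rng.random((12, 3))  # 4 samples x 3 cycles, 3 classes
loop = aggregate_loop(preds_all, 4, 3, lambda s: s.mean(axis=0))
vec = aggregate_reshape(preds_all, 4, 3)
print(np.allclose(loop, vec))  # True
```

The loop form remains the more general choice because a custom Aggregate function may need the whole `(n_cycles, n_classes)` subset at once, for example to count votes.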