
Decoder

xai_decoder(data_gen, model, preds=None, method='gradcam', layerName=None, overlay=True, alpha=0.25, preprocess_overlay=True, out_path=None)

XAI Decoder function for automatic computation of Explainable AI heatmaps.

This function visualizes which image regions were crucial for the neural network model when computing a classification on the provided, previously unseen images.

  • If the out_path parameter is None, the images and heatmaps are returned as lists of NumPy arrays.
  • If a path is provided as out_path, the heatmaps are stored to disk as PNG files.
XAI Methods

The XAI Decoder can be run with different XAI methods as backbone.

A list of all implemented methods and their keys can be found here:
aucmedi.xai.methods
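The `method` parameter accepts either a key or a ready-made XAI method instance. The minimal sketch below mirrors the dispatch performed in the decoder source (an `xai_dict` lookup for known string keys, pass-through otherwise). `DummyGradCAM`, `METHOD_REGISTRY`, and `resolve_method` are hypothetical stand-ins, not AUCMEDI classes.

```python
# Hypothetical stand-ins: DummyGradCAM and METHOD_REGISTRY only illustrate
# the key-or-instance dispatch; they are not part of AUCMEDI.
class DummyGradCAM:
    def __init__(self, model, layerName=None):
        self.model = model
        self.layerName = layerName


METHOD_REGISTRY = {"gradcam": DummyGradCAM}


def resolve_method(method, model, layerName=None):
    """Return an XAI method instance for a registry key, or pass an instance through."""
    if isinstance(method, str) and method in METHOD_REGISTRY:
        # Known key: instantiate the registered class for this model
        return METHOD_REGISTRY[method](model, layerName=layerName)
    # Anything else is assumed to already be a usable XAI method instance
    return method


by_key = resolve_method("gradcam", model="net", layerName="conv5")
custom = DummyGradCAM("net")
by_instance = resolve_method(custom, model="other")
```

As in the decoder, an unrecognized string key is passed through unchanged, so a typo in `method` only surfaces later, when `compute_heatmap` is called on the string.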

Parameter: preprocess_overlay

The XAI method computation is based on the fully preprocessed image. However, the resulting XAI map sometimes needs to be mapped back onto the original image.

Subfunctions that drastically alter the image resolution, such as cropping, lead to an incorrect mapping. Therefore, it is recommended to apply the preprocessing subfunctions (without resizing, augmentation, or standardization) to the image on which the XAI heatmap is overlaid.
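To illustrate why cropping breaks the mapping, the following sketch uses a plain nearest-neighbour resize as a hypothetical stand-in for AUCMEDI's `Resize` subfunction and assumes the model input was a 4x4 centre crop of an 8x8 original image. Stretching the heatmap over the uncropped image moves the hotspot to rows/cols 0-3, whereas placing it on the crop region keeps it at rows/cols 2-3. With `preprocess_overlay=True`, the decoder loads the overlay image with subfunctions such as cropping applied but without resizing, which corresponds to the aligned case.

```python
import numpy as np


def resize_nearest(heatmap, shape):
    """Nearest-neighbour resize of a 2D heatmap to `shape` = (rows, cols)."""
    rows = np.arange(shape[0]) * heatmap.shape[0] // shape[0]
    cols = np.arange(shape[1]) * heatmap.shape[1] // shape[1]
    return heatmap[np.ix_(rows, cols)]


# Toy 2x2 heatmap with a single hotspot in the top-left cell of the model input
heatmap = np.zeros((2, 2))
heatmap[0, 0] = 1.0

# Naive mapping: stretch over the full 8x8 original image, ignoring that
# the model only saw the centre crop (rows/cols 2-5)
stretched = resize_nearest(heatmap, (8, 8))

# Crop-aware mapping: resize to the 4x4 crop and place it at the crop position
aligned = np.zeros((8, 8))
aligned[2:6, 2:6] = resize_nearest(heatmap, (4, 4))
```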

Example
# Import AUCMEDI components
from aucmedi import DataGenerator, NeuralNetwork
from aucmedi.xai import xai_decoder

# Create a DataGenerator for data I/O
# (samples: list of sample names in images_xray/, e.g. obtained via input_interface)
datagen = DataGenerator(samples[:3], "images_xray/", labels=None, resize=(299, 299))

# Get a model
model = NeuralNetwork(n_labels=3, channels=3, architecture="Xception",
                      input_shape=(299, 299))
model.load("model.xray.keras")

# Make some predictions
preds = model.predict(datagen)

# Compute XAI heatmaps via Grad-CAM (resulting heatmaps are stored in out_path)
xai_decoder(datagen, model, preds, method="gradcam", out_path="xai.xray_gradcam")

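Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data_gen | DataGenerator | A data generator which will be used for inference. | required |
| model | NeuralNetwork | Instance of an AUCMEDI neural network class. | required |
| preds | numpy.ndarray | NumPy array of classification predictions in one-hot-encoded (OHE) layout (output of an AUCMEDI prediction). If provided, a heatmap is computed only for the class with the highest prediction of each sample; if None, heatmaps are computed for all classes. | None |
| method | str or XAI method instance | Key of an XAI method (see aucmedi.xai.methods) or an XAI method class instance. By default, Grad-CAM is used. | 'gradcam' |
| layerName | str | Name of the convolutional layer used for heatmap computation. If None, the last convolutional layer is used. | None |
| overlay | bool | Whether the XAI heatmap should be plotted as an overlay on the original image. If False, only the XAI heatmap is stored. | True |
| alpha | float | Transparency value for overlaying the heatmap on the input image (range: [0, 1]). | 0.25 |
| preprocess_overlay | bool | Whether the DataGenerator's subfunctions (but no resizing, augmentation, or standardization) are applied to the image used for visualization. If False, the original image is loaded instead. Heatmaps are resized to the resolution of this image. | True |
| out_path | str | Output directory in which heatmaps are saved to disk in the image_format of the DataGenerator. Created if it does not exist. | None |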
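The `overlay` and `alpha` parameters describe a transparency blend of heatmap and input image. The sketch below assumes a simple linear alpha blend, `(1 - alpha) * image + alpha * heatmap`, with the normalized heatmap written into the red channel as a stand-in for a colormap. AUCMEDI's actual visualization (colormap, normalization) may differ, and `overlay_heatmap` is a hypothetical helper.

```python
import numpy as np


def overlay_heatmap(image, heatmap, alpha=0.25):
    """Blend a heatmap onto a grayscale uint8 image; returns an RGB uint8 image."""
    # Normalize the heatmap to [0, 1] (epsilon guards against a constant map)
    hm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
    # Assumption: red channel as a simple stand-in for a colormap
    hm_rgb = np.stack([hm * 255.0, np.zeros_like(hm), np.zeros_like(hm)], axis=-1)
    # Replicate the grayscale image into three channels
    img_rgb = np.repeat(image[..., None], 3, axis=-1).astype(np.float64)
    # Linear alpha blend: alpha weights the heatmap, (1 - alpha) the image
    blended = (1.0 - alpha) * img_rgb + alpha * hm_rgb
    return np.clip(blended, 0, 255).astype(np.uint8)


img = np.full((2, 2), 100, dtype=np.uint8)
hm = np.array([[1.0, 0.0], [0.0, 0.0]])
result = overlay_heatmap(img, hm, alpha=0.25)
```

With `alpha=0`, the image is returned unchanged (as RGB); larger values make the heatmap increasingly dominant.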

Returns:

| Name | Type | Description |
|---|---|---|
| images | list of numpy.ndarray | Images on which the heatmaps are based (one per sample). Only returned if the out_path parameter is None. |
| heatmaps | list of numpy.ndarray | XAI heatmaps (one per sample; one map per class if preds is None). Only returned if the out_path parameter is None. |
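The shape of the returned heatmaps depends on `preds`. The sketch below mirrors the per-sample loop of the decoder source, with a hypothetical `fake_heatmap` in place of a real XAI method: with `preds`, each entry holds a single map for the argmax class; without `preds`, each entry stacks one map per class along a leading axis. As in the decoder, both results are Python lists.

```python
import numpy as np


def collect_heatmaps(images, compute_heatmap, n_classes, preds=None):
    """Collect per-sample heatmaps: argmax class only if preds is given, else all classes."""
    res_img, res_xai = [], []
    for i, img in enumerate(images):
        if preds is not None:
            # One heatmap for the class with the highest prediction
            xai = compute_heatmap(img, int(np.argmax(preds[i])))
        else:
            # One heatmap per class, stacked along a new leading axis
            xai = np.array([compute_heatmap(img, ci) for ci in range(n_classes)])
        res_img.append(img)
        res_xai.append(xai)
    return res_img, res_xai


# Hypothetical stand-in for an XAI method: a map filled with the class index
def fake_heatmap(img, class_index):
    return np.full(img.shape[:-1], float(class_index))


images = [np.zeros((4, 4, 3)), np.ones((4, 4, 3))]
preds = np.array([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]])
imgs, maps_argmax = collect_heatmaps(images, fake_heatmap, n_classes=3, preds=preds)
_, maps_all = collect_heatmaps(images, fake_heatmap, n_classes=3)
```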

Source code in aucmedi/xai/decoder.py
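import os
import numpy as np
from aucmedi.xai.methods import xai_dict
from aucmedi.data_processing.subfunctions import Resize
# postprocess_output() is defined in the same module (not shown here)

def xai_decoder(data_gen, model, preds=None, method="gradcam", layerName=None,
                overlay=True, alpha=0.25, preprocess_overlay=True, out_path=None):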
    """ XAI Decoder function for automatic computation of Explainable AI heatmaps.

    This function visualizes which image regions were crucial for the neural network model
    when computing a classification on the provided, previously unseen images.

    - If the `out_path` parameter is None, the images and heatmaps are returned as lists of NumPy arrays.
    - If a path is provided as `out_path`, the heatmaps are stored to disk as PNG files.

    ???+ info "XAI Methods"
        The XAI Decoder can be run with different XAI methods as backbone.

        A list of all implemented methods and their keys can be found here: <br>
        [aucmedi.xai.methods][]

    ??? info "Parameter: preprocess_overlay"
        The XAI method computation is based on the fully preprocessed image.
        However, the resulting XAI map sometimes needs to be mapped back onto the original image.

        Subfunctions that drastically alter the image resolution, such as cropping, lead to
        an incorrect mapping. Therefore, it is recommended to apply the preprocessing
        subfunctions (without resizing, augmentation, or standardization) to the image on
        which the XAI heatmap is overlaid.

    ???+ example "Example"
        ```python
        # Import AUCMEDI components
        from aucmedi import DataGenerator, NeuralNetwork
        from aucmedi.xai import xai_decoder

        # Create a DataGenerator for data I/O
        # (samples: list of sample names in images_xray/, e.g. obtained via input_interface)
        datagen = DataGenerator(samples[:3], "images_xray/", labels=None, resize=(299, 299))

        # Get a model
        model = NeuralNetwork(n_labels=3, channels=3, architecture="Xception",
                              input_shape=(299, 299))
        model.load("model.xray.keras")

        # Make some predictions
        preds = model.predict(datagen)

        # Compute XAI heatmaps via Grad-CAM (resulting heatmaps are stored in out_path)
        xai_decoder(datagen, model, preds, method="gradcam", out_path="xai.xray_gradcam")
        ```

    Args:
        data_gen (DataGenerator):           A data generator which will be used for inference.
        model (NeuralNetwork):              Instance of an AUCMEDI neural network class.
        preds (numpy.ndarray):              NumPy array of classification predictions in one-hot-encoded (OHE) layout
                                            (output of an AUCMEDI prediction). If provided, a heatmap is computed only for
                                            the class with the highest prediction; if `None`, heatmaps are computed for all classes.
        method (str or XAI method):         Key of an XAI method or an XAI method class instance. By default, Grad-CAM is used.
        layerName (str):                    Name of the convolutional layer for heatmap computation. If `None`, the last conv layer is used.
        overlay (bool):                     Whether the XAI heatmap should be plotted as an overlay on the original image.
                                            If `False`, only the XAI heatmap will be stored.
        alpha (float):                      Transparency value for overlaying the heatmap on the input image (range: [0, 1]).
        preprocess_overlay (bool):          Whether subfunctions (but no resizing, augmentation, or standardization) are applied
                                            to the image used for visualization. Heatmaps are resized to the resolution of this image.
        out_path (str):                     Output directory in which heatmaps are saved to disk as provided `image_format` (DataGenerator).
                                            Created if it does not exist.

    Returns:
        images (list of numpy.ndarray):     Images on which the heatmaps are based (one per sample).
                                            Only returned if the `out_path` parameter is `None`.
        heatmaps (list of numpy.ndarray):   XAI heatmaps (one per sample; one map per class if `preds` is `None`).
                                            Only returned if the `out_path` parameter is `None`.
    """
    # Initialize & access some variables
    n_classes = model.n_labels
    sample_list = data_gen.samples
    # Prepare XAI output containers
    res_img = []
    res_xai = []
    # Create the output directory (including missing parents) if required
    if out_path is not None:
        os.makedirs(out_path, exist_ok=True)
    # Initialize XAI method: look up known string keys, otherwise use the given instance
    if isinstance(method, str) and method in xai_dict:
        xai_method = xai_dict[method](model.model, layerName=layerName)
    else:
        xai_method = method

    # Iterate over all samples
    for i in range(0, len(sample_list)):
        # Load overlay image
        if preprocess_overlay:
            img_org = data_gen.preprocess_image(i, 
                                                run_resize=False, 
                                                run_aug=False, 
                                                run_standardize=False)
            shape_org = img_org.shape[0:-1]
        # Load original image
        else:
            img_org = data_gen.sample_loader(sample_list[i],
                                             data_gen.path_imagedir,
                                             image_format=data_gen.image_format,
                                             grayscale=data_gen.grayscale,
                                             **data_gen.kwargs)
            shape_org = img_org.shape[0:-1]

        # Load processed image
        img_prc = data_gen.preprocess_image(i, run_aug=False)
        img_batch = np.expand_dims(img_prc, axis=0)
        # If preds given, compute heatmap only for argmax class
        if preds is not None:
            ci = np.argmax(preds[i])
            xai_map = xai_method.compute_heatmap(img_batch, class_index=ci)
            xai_map = Resize(shape=shape_org).transform(xai_map)
            postprocess_output(sample_list[i], img_org, xai_map, 
                               n_classes, data_gen, res_img, res_xai, 
                               overlay, out_path, alpha)
        # If no preds given, compute heatmap for all classes
        else:
            sample_maps = []
            for ci in range(0, n_classes):
                xai_map = xai_method.compute_heatmap(img_batch, class_index=ci)
                xai_map = Resize(shape=shape_org).transform(xai_map)
                sample_maps.append(xai_map)
            sample_maps = np.array(sample_maps)
            postprocess_output(sample_list[i], img_org, sample_maps, 
                               n_classes, data_gen, res_img, res_xai, 
                               overlay, out_path, alpha)
    # Return output directly if no output path is defined
    if out_path is None:
        return res_img, res_xai