Gradient-weighted Class Activation Mapping (Grad-CAM) is a technique for making Convolutional Neural Network (CNN)-based models explainable by visualizing which image regions are important for a given prediction. In this post we walk through a Tensorflow implementation of Grad-CAM and apply it to the metastatic cancer detection CNN model created in the previous post.


[Figure: example patches and their Grad-CAM heatmaps: a non-tumor patch (p = 0.03) with its non-tumor map, and a tumor patch (p = 0.99) with its tumor map]

Grad-CAM uses the class-specific gradient information flowing into the final convolutional layer of a CNN to produce a coarse heatmap of region importance. We can use this heatmap to identify the parts of the image the CNN relied on when making its diagnosis, providing a useful cross-check.
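In the notation of the original Grad-CAM paper, the weight for feature-map channel k and class c is the spatial average of the gradients of the class score with respect to that channel's activations, and the heatmap is a ReLU of the weighted sum of the feature maps:

$$\alpha_k^c = \frac{1}{Z}\sum_i\sum_j \frac{\partial y^c}{\partial A_{ij}^k}, \qquad L_{\text{Grad-CAM}}^c = \mathrm{ReLU}\Big(\sum_k \alpha_k^c A^k\Big)$$

The compute_heatmap method below follows this recipe, with an additional gradient normalisation step.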

Despite the utility of the method, no standard implementation of Grad-CAM exists. However, versions that come close to our requirements can be found on GitHub and in other blogs, and are referenced below.

Grad-CAM implementation using Keras and Tensorflow

The first thing we need to do is import the required packages and define the GradCAM class, whose constructor takes the model and an optional layerName argument specifying the layer to visualize. If no layer is specified, the final convolutional layer will be used.

import tensorflow as tf
import numpy as np
import cv2 as cv

class GradCAM:

    def __init__(self, model, layerName=None):
        """
        model: pre-softmax layer (logit layer)
        """
        self.model = model
        self.layerName = layerName
            
        if self.layerName is None:
            self.layerName = self.find_target_layer()

Now we define a method that searches backwards through the model's layers for one with a 4-dimensional output shape (batch, height, width, channels), i.e. a convolutional feature map, and returns that layer's name.

    def find_target_layer(self):
        for layer in reversed(self.model.layers):
            if len(layer.output_shape) == 4:
                return layer.name
        raise ValueError("Could not find 4D layer. Cannot apply GradCAM")

To compute the heatmap we need to supply the following information:

  1. the inputs of the pretrained model
  2. the output of the final 4D layer
  3. the output softmax activations
  4. the classIdx we are interested in.

The heatmap computation relies on automatic differentiation to compute the gradient of a computation with respect to its input values. In Tensorflow this is done with the tf.GradientTape API, which records all operations executed within the context of the tape; Tensorflow then uses that tape to compute the gradients associated with the recorded operations.
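As a minimal sketch of the API on its own (independent of the Grad-CAM code), the tape below records y = x * x and then returns dy/dx:

# minimal tf.GradientTape sketch: differentiate y = x^2 at x = 3.0
x = tf.constant(3.0)
with tf.GradientTape() as tape:
    tape.watch(x)               # constants must be watched explicitly
    y = x * x
print(tape.gradient(y, x))      # tf.Tensor(6.0, shape=(), dtype=float32)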

    def compute_heatmap(self, image, classIdx, upsample_size, eps=1e-5):
        gradModel = tf.keras.Model(
            inputs = [self.model.inputs],
            outputs = [self.model.get_layer(self.layerName).output, self.model.output]
        )

        # record operations for automatic differentiation       
        with tf.GradientTape() as tape:
            inputs = tf.cast(image, tf.float32)
            (convOuts, preds) = gradModel(inputs) # preds after softmax
            loss = preds[:,classIdx]
        
        # compute gradients with automatic differentiation
        grads = tape.gradient(loss, convOuts)
        
        # discard batch
        convOuts = convOuts[0]
        grads = grads[0]
        norm_grads = tf.divide(grads, tf.reduce_mean(tf.square(grads)) + tf.constant(eps))

The final stage is to compute the channel weights by taking the mean of the normalised gradients over the spatial dimensions and then summing the weighted feature maps into the Grad-CAM visualization.

        # compute weights
        weights = tf.reduce_mean(norm_grads, axis=(0,1))
        cam = tf.reduce_sum(tf.multiply(weights, convOuts), axis=-1)
        
        # apply ReLU and rescale to [0, 1]
        cam = np.maximum(cam, 0)
        cam = cam / (np.max(cam) + eps)  # eps guards against an all-zero map
        cam = cv.resize(cam, upsample_size, interpolation=cv.INTER_LINEAR)
        
        # convert to 3D
        cam3 = np.expand_dims(cam, axis=2)
        cam3 = np.tile(cam3, [1,1,3])
        
        return cam3
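With the class defined, a heatmap for a single image can be computed as in the sketch below. Here model, x and the (96, 96) upsample size are assumptions: model stands for the trained CNN from the previous post, x for one preprocessed image array, and (96, 96) for the model's input width and height.

# hypothetical usage sketch -- `model` and `x` are assumed to exist
gradcam = GradCAM(model)                      # defaults to the last 4D (conv) layer
img_batch = np.expand_dims(x, axis=0)         # add a batch dimension
preds = model.predict(img_batch)
classIdx = int(np.argmax(preds[0]))           # 0 for a single-output (sigmoid) model
cam3 = gradcam.compute_heatmap(img_batch, classIdx, upsample_size=(96, 96))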

Now we can transparently overlay the Grad-CAM heatmap onto the original image to produce the output visualization.

def overlay_gradCAM(img, cam3, prob):
    # scale the heatmap by the predicted probability and apply a jet colour map
    cam3 = np.uint8(255 * cam3 * prob)
    cam3 = cv.applyColorMap(cam3, cv.COLORMAP_JET)

    # blend the coloured heatmap with the original image
    new_img = 0.3 * cam3 + 0.5 * img

    return (new_img * 255.0 / new_img.max()).astype("uint8")
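
Putting it together, the overlay can be produced as below. This is a sketch under the same assumptions as the earlier usage example; img is the original image at the same resolution as upsample_size, and prob is the predicted probability for the chosen class.

# hypothetical usage sketch, continuing from the compute_heatmap example
prob = float(preds[0][classIdx])              # predicted probability for classIdx
overlaid = overlay_gradCAM(img, cam3, prob)   # img must match upsample_size
cv.imwrite("gradcam_overlay.png", overlaid)   # save (or display) the result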