GradCAM in PyTorch
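Load and preprocess the input image. The image format, and the library used to read it, may differ from what is shown here. In this example the image is read and resized with scikit-image (`imread`, `resize`), converted to the 1 x 3 x 224 x 224 layout PyTorch expects, and normalised with the ImageNet mean and standard deviation.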
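import numpy as np
import torch
from skimage.io import imread
from skimage.transform import resize
# Read the image (alternative: 'bulbul.jpg') and resize it to the 224x224 input size
img = imread('/content/tiger.jfif')
img = resize(img, (224, 224), preserve_range=True)
# HxWxC -> 1xCxHxW, scaled to [0, 1]
img = np.expand_dims(img.transpose((2, 0, 1)), 0)
img /= 255.0
# Normalise with the ImageNet channel mean and standard deviation
mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))
std = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))
img = (img - mean) / std
inpimg = torch.from_numpy(img).to('cuda:0', torch.float32)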
Compute the Gradient-weighted Class Activation Map
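# Forward pass: gcmodel (defined earlier) returns the logits and the captured activations
out, acts = gcmodel(inpimg)
acts = acts.detach().cpu()
# Grad-CAM uses the gradient of the target class score itself; backpropagating a
# cross-entropy loss instead would give channel weights of roughly the opposite sign
class_idx = 600
out[0, class_idx].backward()
grads = gcmodel.get_act_grads().detach().cpu()
# One weight per channel: average the gradients over the batch and spatial dimensions
pooled_grads = torch.mean(grads, dim=[0, 2, 3])
# Scale each activation channel by its weight
for i in range(acts.shape[1]):
    acts[:, i, :, :] *= pooled_grads[i]
# Combine the channels, keep only positive evidence (ReLU) and normalise by the global maximum
heatmap_j = torch.relu(torch.mean(acts, dim=1).squeeze())
heatmap_j /= heatmap_j.max()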
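The weighting step follows the Grad-CAM recipe: each channel k of the activations A gets a weight alpha_k equal to the spatial average of the gradient of the class score y^c with respect to A^k, and the map is ReLU(sum_k alpha_k A^k). The sketch below applies that recipe to plain NumPy arrays shaped like `acts` and `grads`, i.e. (1, C, H, W). The function `grad_cam_heatmap` is hypothetical, and random arrays stand in for real network outputs. It sums over channels where the code above averages; the two differ only by the constant factor 1/C, which cancels in the normalisation.

```python
import numpy as np


def grad_cam_heatmap(acts, grads, eps=1e-8):
    """Hypothetical helper: Grad-CAM map from activations and gradients of shape (1, C, H, W)."""
    # One weight per channel: global-average-pool the gradients over batch and spatial dims
    weights = grads.mean(axis=(0, 2, 3))                    # shape (C,)
    # Weighted sum of the activation channels
    cam = np.tensordot(weights, acts[0], axes=([0], [0]))  # shape (H, W)
    # Keep only features with a positive influence on the class score
    cam = np.maximum(cam, 0.0)
    # Normalise to [0, 1]; eps avoids dividing by zero when the map is all zeros
    return cam / (cam.max() + eps)


rng = np.random.default_rng(0)
acts = rng.random((1, 512, 7, 7))        # stand-in for feature-map activations
grads = rng.normal(size=(1, 512, 7, 7))  # stand-in for their gradients
heatmap = grad_cam_heatmap(acts, grads)
print(heatmap.shape)  # (7, 7)
```

The ReLU matters: without it, regions that argue against the target class would also light up after normalisation.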
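The heatmap has the spatial resolution of the convolutional layer whose activations were captured (for example 7x7 for the last convolutional block of a ResNet with a 224x224 input). It therefore has to be resized to the input size and colour-mapped before it can be overlaid on the image.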
Resize Heatmap
# Convert the torch tensor to a NumPy array and upsample it to the input resolution
heatmap_j = resize(heatmap_j.numpy(), (224, 224), preserve_range=True)
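Upsampling spreads each heatmap cell over the corresponding patch of the input: with a 7x7 map and a 224x224 image, each cell covers a 32x32 patch. The sketch below swaps in plain nearest-neighbour upsampling with `np.repeat` to make that correspondence explicit. It is blockier than scikit-image's `resize`, which interpolates between cells by default for float input, and the helper `upsample_nearest` is hypothetical.

```python
import numpy as np


def upsample_nearest(heatmap, out_size=(224, 224)):
    """Hypothetical helper: nearest-neighbour upsampling by an integer factor per axis."""
    fy, fx = out_size[0] // heatmap.shape[0], out_size[1] // heatmap.shape[1]
    if (heatmap.shape[0] * fy, heatmap.shape[1] * fx) != tuple(out_size):
        raise ValueError("output size must be an integer multiple of the heatmap size")
    # Each heatmap cell becomes an fy x fx block of identical values
    return np.repeat(np.repeat(heatmap, fy, axis=0), fx, axis=1)


small = np.arange(49, dtype=float).reshape(7, 7) / 48.0  # toy 7x7 map in [0, 1]
big = upsample_nearest(small)
print(big.shape)  # (224, 224)
```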
Colour Mapping
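import matplotlib.pyplot as plt
# 256-entry 'jet' colormap; plt.get_cmap replaces mpl.cm.get_cmap, which newer Matplotlib releases removed
cmap = plt.get_cmap('jet', 256)
# Map each heatmap value to an RGBA colour with a constant alpha of 0.2
heatmap_j2 = cmap(heatmap_j, alpha=0.2)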
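A colormap called on an array returns one RGBA colour per value, so a 224x224 heatmap becomes a 224x224x4 array whose last channel is the requested alpha. The sketch below checks this on a toy gradient instead of a real heatmap. It uses `plt.get_cmap`, since `mpl.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in later releases.

```python
import numpy as np
import matplotlib.pyplot as plt

# 256-entry 'jet' lookup table, equivalent to the get_cmap('jet', 256) call above
cmap = plt.get_cmap('jet', 256)
heat = np.linspace(0.0, 1.0, 224 * 224).reshape(224, 224)  # toy heatmap in [0, 1]
rgba = cmap(heat, alpha=0.2)
print(rgba.shape)  # (224, 224, 4)
```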
Plotting
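fig, axs = plt.subplots(1, 1, figsize=(5, 5))
# Undo the normalisation and show the image in HxWxC layout, clipped to the valid [0, 1] range
axs.imshow(np.clip((img * std + mean)[0].transpose(1, 2, 0), 0, 1))
# Draw the semi-transparent heatmap on top
axs.imshow(heatmap_j2)
plt.show()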
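Layering the RGBA heatmap over the opaque image with `imshow` amounts, up to resampling, to "over" alpha compositing: out = alpha * heatmap_rgb + (1 - alpha) * image_rgb. Computing the blend explicitly gives a single array that could, for example, be written to disk with `plt.imsave`. Below is a minimal NumPy sketch with a hypothetical `overlay` helper and synthetic inputs.

```python
import numpy as np


def overlay(base_rgb, heat_rgba):
    """Hypothetical helper: composite an RGBA heatmap over an opaque RGB image ("over" operator)."""
    alpha = heat_rgba[..., 3:4]  # keep a trailing axis so it broadcasts over RGB
    return heat_rgba[..., :3] * alpha + base_rgb * (1.0 - alpha)


base = np.full((224, 224, 3), 0.5)  # stand-in for the de-normalised input image
heat = np.zeros((224, 224, 4))
heat[..., 0] = 1.0                  # pure red heatmap...
heat[..., 3] = 0.2                  # ...with the same alpha=0.2 used above
blended = overlay(base, heat)
print(blended[0, 0])
```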
Results
Trending AI/ML Article Identified & Digested via Granola by Ramsey Elbasheer; a Machine-Driven RSS Bot
via WordPress https://ramseyelbasheer.io/2021/07/31/gradcam-in-pytorch/