GradCAM in PyTorch




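The image format and the library used to read it may differ from what is shown here. The snippet below reads the image, resizes it to 224x224, rearranges it into a batched channels-first array, and normalises each channel before moving it to the GPU.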

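# Imports are assumed here; the original post does not show them
import numpy as np
import torch
from skimage.io import imread
from skimage.transform import resize

img = imread('/content/tiger.jfif')  # or 'bulbul.jpg'
img = resize(img, (224, 224), preserve_range=True)  # float array, values still in [0, 255]
img = np.expand_dims(img.transpose((2, 0, 1)), 0)  # HWC -> NCHW with batch size 1
img /= 255.0  # scale to [0, 1]
mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))
std = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))
img = (img - mean) / std  # per-channel normalisation
inpimg = torch.from_numpy(img).to('cuda:0', torch.float32)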
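
Because the input may come from a different reader or file format, the normalisation step can be wrapped in a small helper that accepts grayscale, RGB, or RGBA arrays. The following is a minimal sketch rather than the author's code: `to_model_input` is a hypothetical name, it assumes integer inputs are 8-bit and float inputs are already scaled to [0, 1], and it leaves resizing to the caller.

```python
import numpy as np

# Same per-channel statistics as in the snippet above.
MEAN = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))
STD = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))


def to_model_input(img):
    """Hypothetical helper: HxW, HxWx3 or HxWx4 image -> normalised 1x3xHxW float32 array."""
    img = np.asarray(img)
    is_integer = np.issubdtype(img.dtype, np.integer)
    img = img.astype(np.float64)
    if is_integer:
        img = img / 255.0  # assumption: integer images are 8-bit, range [0, 255]
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)  # grayscale -> three identical channels
    img = img[..., :3]  # drop an alpha channel if present
    img = img.transpose((2, 0, 1))[np.newaxis]  # HWC -> NCHW with batch size 1
    return ((img - MEAN) / STD).astype(np.float32)


gray = np.zeros((4, 5), dtype=np.uint8)
rgba = np.full((2, 2, 4), 255, dtype=np.uint8)
gray_input = to_model_input(gray)
rgba_input = to_model_input(rgba)
```

The result can be passed to `torch.from_numpy(...)` in the same way as `img` above.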

Compute Gradient Class Activation Maps

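# Forward pass: gcmodel returns the class scores and the target layer's activations
out, acts = gcmodel(inpimg)
acts = acts.detach().cpu()
# Back-propagate the loss for the target class (index 600 here)
loss = nn.CrossEntropyLoss()(out, torch.from_numpy(np.array([600])).to('cuda:0'))
loss.backward()
# Gradients with respect to the activations, averaged into one weight per channel
grads = gcmodel.get_act_grads().detach().cpu()
pooled_grads = torch.mean(grads, dim=[0, 2, 3])
# Weight each activation channel by its pooled gradient
for i in range(acts.shape[1]):
    acts[:, i, :, :] *= pooled_grads[i]
# Average over channels and scale by the global maximum
heatmap_j = torch.mean(acts, dim=1).squeeze()
heatmap_j /= heatmap_j.max()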
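
The snippet above relies on a `gcmodel` whose definition is not shown: its forward pass returns both the class scores and the activations of a chosen layer, and `get_act_grads()` returns the gradients that reached those activations. One way to get that interface is a tensor hook. The sketch below is an assumption about how such a wrapper might look, not the author's implementation. The class name `GradCamModel`, the tiny stand-in network, and the 16x16 random input are all hypothetical, chosen only so the example runs on CPU without pretrained weights.

```python
import torch
import torch.nn as nn


class GradCamModel(nn.Module):
    """Hypothetical wrapper exposing (out, acts) and get_act_grads()."""

    def __init__(self, features, classifier):
        super().__init__()
        self.features = features      # convolutional part, ending at the target layer
        self.classifier = classifier  # maps the activations to class scores
        self._act_grads = None

    def _save_grads(self, grad):
        # Called by autograd during backward() with d(loss)/d(acts)
        self._act_grads = grad

    def forward(self, x):
        acts = self.features(x)
        if acts.requires_grad:
            acts.register_hook(self._save_grads)
        out = self.classifier(acts)
        return out, acts

    def get_act_grads(self):
        return self._act_grads


torch.manual_seed(0)
# Stand-in network (assumption): in practice these would be slices of a pretrained CNN
features = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
classifier = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10))
gcmodel = GradCamModel(features, classifier).eval()

x = torch.randn(1, 3, 16, 16)
out, acts = gcmodel(x)
loss = nn.CrossEntropyLoss()(out, torch.tensor([3]))
loss.backward()
grads = gcmodel.get_act_grads()  # same shape as acts
```

With a real backbone, `features` would end at the convolutional layer whose activations are to be visualised, and `classifier` would hold everything after it.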

Now, the heatmap needs to be resized and colour mapped.

Resize Heatmap

# Convert the CPU tensor to a NumPy array before handing it to skimage
heatmap_j = resize(heatmap_j.numpy(), (224, 224), preserve_range=True)
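
The resize can also be done in PyTorch before converting to NumPy. This is a sketch of an alternative, not what the post uses: `upsample_heatmap` is a hypothetical helper, the 7x7 input is made-up data, and bilinear interpolation may give slightly different values from skimage's `resize`.

```python
import torch
import torch.nn.functional as F


def upsample_heatmap(heatmap, size=(224, 224)):
    """Hypothetical helper: bilinear upsampling of a 2-D heatmap tensor to `size`."""
    # F.interpolate expects (N, C, H, W), so add batch and channel dimensions first
    resized = F.interpolate(heatmap[None, None], size=size,
                            mode='bilinear', align_corners=False)
    return resized[0, 0]


coarse = torch.rand(7, 7)        # stand-in for a coarse Grad-CAM map
fine = upsample_heatmap(coarse)  # 224 x 224 tensor
```

Calling `fine.numpy()` then yields an array that can be passed to the colour map.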

Colour Mapping

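# mpl.cm.get_cmap has been removed from recent Matplotlib releases; plt.get_cmap is the equivalent call
cmap = plt.get_cmap('jet', 256)
heatmap_j2 = cmap(heatmap_j, alpha=0.2)  # RGBA array with a constant alpha of 0.2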

Plotting

fig, axs = plt.subplots(1,1,figsize = (5,5))
axs.imshow((img*std+mean)[0].transpose(1,2,0))
axs.imshow(heatmap_j2)
plt.show()
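
If a single array is more convenient than two stacked `imshow` calls (for example, to save the overlay with `plt.imsave`), the blend can be computed directly. This is a minimal sketch under stated assumptions: `blend_heatmap` is a hypothetical helper, both inputs are assumed to lie in [0, 1], and the 4x4 arrays are made-up data.

```python
import numpy as np
import matplotlib.pyplot as plt


def blend_heatmap(image, heatmap, alpha=0.2):
    """Hypothetical helper: blend a [0, 1] heatmap (H, W) over a [0, 1] RGB image (H, W, 3)."""
    # Map the heatmap to RGB colours and drop the colormap's own alpha channel
    colours = plt.get_cmap('jet')(np.clip(heatmap, 0.0, 1.0))[..., :3]
    # Standard alpha compositing of the coloured heatmap over the image
    return np.clip((1.0 - alpha) * image + alpha * colours, 0.0, 1.0)


image = np.full((4, 4, 3), 0.5)                     # uniform grey stand-in image
heatmap = np.linspace(0.0, 1.0, 16).reshape(4, 4)   # stand-in heatmap
overlay = blend_heatmap(image, heatmap)
```

With the variables from the post, the call would be along the lines of `blend_heatmap((img*std+mean)[0].transpose(1,2,0), heatmap_j)`.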

Results

