Look inside the model using PyTorch Hooks
What are Hooks?
In computer science, hooks are programming mechanisms that allow developers to intercept and alter the behaviour of software systems at specific points in their execution. Hooks enable customization and extension of software functionality without modifying the original source code. In machine learning, we use hooks to look inside the model to debug and monitor various layers.
In PyTorch, hooks are functions or callable objects that allow us to look into the internal operations of the model during its forward and backward passes. Hooks provide a flexible way to access and manipulate the activations, gradients, and parameters of layers. They can be used for various purposes, including debugging, visualization, feature extraction, and customizing the training process.
PyTorch has two types of hooks.
- Forward hooks: executed during the forward pass of the model, allowing us to access intermediate layer outputs.
- Backward hooks: executed during the backward pass of the model, providing access to gradients (a short sketch of both types follows this list).
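Here is a minimal sketch of both hook types on a stand-alone layer; the nn.Linear layer, shapes, and print statements are illustrative assumptions, not part of the model built later in this article.

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)

# Forward hook: receives the module, its input (a tuple), and its output.
def forward_hook(module, input, output):
    print("forward:", module.__class__.__name__, output.shape)

# Backward hook: receives the module and the gradients w.r.t. its input and output.
def backward_hook(module, grad_input, grad_output):
    print("backward:", module.__class__.__name__, grad_output[0].shape)

layer.register_forward_hook(forward_hook)
layer.register_full_backward_hook(backward_hook)

out = layer(torch.randn(3, 4))
out.sum().backward()   # triggers the backward hook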
In this article, we will implement hooks from scratch and also show how to use them for various use cases. I have embedded a few snippets of code; you can check the full code here.
Implement Hooks from Scratch
We will first build the behaviour of forward hooks by hand to observe each layer's output after its activation. For this, we will use the MNIST dataset and a small CNN, and we will watch how each layer's output mean and standard deviation vary during training.
Download the dataset from Hugging Face and load it into data loaders.
dataset = load_dataset("mnist")
train_d =
train_v = dataset['test']
class Dataset():
def __init__(self, x, y):
self.x,self.y = x,y
def __len__(self):
return len(self.x)
def __getitem__(self, i):
return { "x" : torch.tensor(np.array(self.x[i])[None] , dtype = torch.float),"y" :torch.tensor(self.y[i])}
train_dataset = Dataset(dataset['train']['image'] , dataset['train']['label'])
val_dataset = Dataset(dataset['test']['image'] , dataset['test']['label'])
train_dataloader = DataLoader(train_dataset , batch_size=256)
val_dataloader = DataLoader(val_dataset , batch_size=256)
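A quick sanity check of what the loaders return (just a sketch assuming the loaders defined above; the last batch of an epoch may be smaller):

batch = next(iter(train_dataloader))
print(batch['x'].shape, batch['y'].shape)  # expected: torch.Size([256, 1, 28, 28]) torch.Size([256])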
We will use convolutional layers for our model, and in the forward function, we will store the mean and standard deviation of each layer's output.
from torch import nn

def conv(ni, nf, ks=3, act=True):
    res = nn.Conv2d(ni, nf, stride=2, kernel_size=ks, padding=ks//2)
    if act: res = nn.Sequential(res, nn.ReLU())
    return res

def cnn_layers():
    return [
        conv(1, 8, ks=5),          # 14x14
        conv(8, 16),               # 7x7
        conv(16, 32),              # 4x4
        conv(32, 64),              # 2x2
        conv(64, 10, act=False),   # 1x1
        nn.Flatten()]
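To double-check the spatial sizes noted in the comments, we can push a dummy MNIST-sized batch through the stack (a throwaway check, assuming the cnn_layers defined above):

x = torch.zeros(1, 1, 28, 28)
for layer in cnn_layers():
    x = layer(x)
    print(layer.__class__.__name__, tuple(x.shape))
# the final Flatten should leave us with shape (1, 10), one logit per MNIST class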
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(cnn_layers())
        self.layer_means = [[] for _ in range(6)]
        self.layer_std = [[] for _ in range(6)]
    # define forward (not __call__) so nn.Module's machinery still runs
    def forward(self, x):
        for i, l in enumerate(self.layers):
            x = l(x)
            # record statistics of every layer's output on each forward pass
            self.layer_means[i].append(x.detach().cpu().numpy().mean())
            self.layer_std[i].append(x.detach().cpu().numpy().std())
        return x
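A minimal sketch of how this bookkeeping behaves (the dummy input is an assumption for illustration): every forward pass appends one entry per layer.

model = Model()
_ = model(torch.zeros(2, 1, 28, 28))
print([len(m) for m in model.layer_means])  # expected: [1, 1, 1, 1, 1, 1]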
In the fit function, we will use the SGD optimizer and the cross-entropy loss function to train our model.
import tqdm

def fit(model, train_dataloader, val_dataloader, loss_func, optimizer, epochs, device):
    for epoch in range(epochs):
        with tqdm.tqdm(total=len(train_dataloader), desc=" Train :: Epoch {}/{}".format(epoch + 1, epochs)) as pbar:
            model.train()
            for batch in train_dataloader:
                x = batch['x'].to(device)
                y = batch['y'].to(device)
                pred = model(x)
                loss = loss_func(pred, y)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                pbar.set_postfix({"loss": loss.item()})
                pbar.update(1)
        with tqdm.tqdm(total=len(val_dataloader), desc=" Val :: Epoch {}/{}".format(epoch + 1, epochs)) as pbar:
            model.eval()
            with torch.no_grad():
                for batch in val_dataloader:
                    x = batch['x'].to(device)
                    y = batch['y'].to(device)
                    pred = model(x)
                    loss = loss_func(pred, y)
                    pbar.set_postfix({"loss": loss.item()})
                    pbar.update(1)
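Putting it together might look like this (the learning rate, epoch count, and device choice are illustrative assumptions, not values from the article):

import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Model().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
fit(model, train_dataloader, val_dataloader, F.cross_entropy, optimizer, epochs=3, device=device)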
After training, we can analyse how the output of each layer varies.
import matplotlib.pyplot as plt

for l in model.layer_means:
    plt.plot(l)
plt.legend(range(6), loc='right');
Output: a plot of each layer's output mean over the training batches.
We can see that the output varies a lot. Ideally, we want the output of each layer to have mean zero and variance one. For this, you can use input normalization, batch normalization, and proper weight initialization, and instead of vanilla ReLU, we can use LeakyReLU. I will cover this in another post.
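As a rough sketch of what one of these remedies could look like (LeakyReLU plus Kaiming initialization; the slope value is an illustrative assumption, and the follow-up post covers this properly):

def conv_leaky(ni, nf, ks=3, act=True):
    layer = nn.Conv2d(ni, nf, stride=2, kernel_size=ks, padding=ks//2)
    # initialize weights to keep activation variance roughly stable across layers
    nn.init.kaiming_normal_(layer.weight, a=0.1, nonlinearity='leaky_relu')
    return nn.Sequential(layer, nn.LeakyReLU(0.1)) if act else layer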
PyTorch Hooks
Hooks are PyTorch objects that we can register on any nn.Module. A hook registered on a layer is executed whenever that layer is called, during the forward pass (forward hook) or the backward pass (backward hook). Hooks don't require us to rewrite the model.
A forward hook is a function that takes three arguments: the module (layer instance), the input to that module, and the output of that module. With hooks, we don't need to modify the model class at all.
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(cnn_layers())
    def forward(self, x):
        for i, l in enumerate(self.layers):
            x = l(x)
        return x
We will define a function append_ that takes the output passed to the hook and stores the mean and std of each layer.
model = Model()  # instantiate the hook-free model defined above

layer_means = [[] for _ in model.layers]
layer_stds = [[] for _ in model.layers]

def append_(i, module, input, output):
    layer_means[i].append(output.detach().cpu().numpy().mean())
    layer_stds[i].append(output.detach().cpu().numpy().std())
Now we will register our hook on each layer of the model.
from functools import partial

for i, m in enumerate(model.layers):
    m.register_forward_hook(partial(append_, i))
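Each call to register_forward_hook returns a RemovableHandle, so if we only want the statistics for a while (say, during training), we can keep the handles and detach the hooks afterwards. A small sketch, meant as an alternative to the loop above rather than an addition to it:

handles = [m.register_forward_hook(partial(append_, i)) for i, m in enumerate(model.layers)]
# ... train / run the model while the hooks collect statistics ...
for h in handles:
    h.remove()   # detach the hooks so later forward passes are not instrumented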
Use cases
PyTorch hooks are a powerful tool that can improve the performance, accuracy, and debugging of PyTorch models. They can be used to track the progress of model training, collect information about the model's predictions, visualize its internal state, and debug the model. We can also use them for gradient clipping, feature extraction, runtime augmentation, and much more.
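As one concrete example, feature extraction with a forward hook might look like the sketch below; the layer index, feature store, and "conv3" name are illustrative assumptions on top of the model defined earlier.

features = {}

def save_features(name, module, input, output):
    features[name] = output.detach()

# grab the activations of, say, the third conv block whenever the model runs
model.layers[2].register_forward_hook(partial(save_features, "conv3"))

_ = model(torch.zeros(1, 1, 28, 28))
print(features["conv3"].shape)   # expected: torch.Size([1, 32, 4, 4])

# gradient clipping is similar: a tensor-level hook can rewrite gradients during the backward pass
for p in model.parameters():
    p.register_hook(lambda grad: grad.clamp(-1, 1))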