Introduction

I have hidden all the code from the previous post so we can focus on the experiments, but you can find it by downloading the notebook or opening it in Colab.

Collecting the Dataset
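This cell is hidden, but as a reference, a minimal sketch of what it does with the standard torchvision MNIST loaders could look like the following (the batch size and normalization constants here are illustrative choices, not taken from the hidden code):

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Convert the images to tensors and normalize them (constants are illustrative).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=256, shuffle=True)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False)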

Visualize the dataset
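Also hidden; a quick sketch that shows a few digits and their labels (it reuses the hypothetical train_loader from the sketch above):

import matplotlib.pyplot as plt

# Grab one batch and display the first eight digits with their labels.
images, labels = next(iter(train_loader))
fig, axes = plt.subplots(1, 8, figsize=(12, 2))
for ax, img, label in zip(axes, images, labels):
    ax.imshow(img.squeeze().numpy(), cmap='gray')
    ax.set_title(int(label))
    ax.axis('off')
plt.show()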

Recall the model class definition from the previous post:

class CNN(nn.Module):
    def __init__(self, dropout_rate=0.0, norm_w=True):
        super(CNN, self).__init__()

        self.norm_w = norm_w

        # Create the layers, with or without weight normalization.
        # weight_norm reparameterizes each weight tensor as w = g * v / ||v||.
        if self.norm_w:
            self.conv_layer1 = nn.utils.weight_norm(
                nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3), name='weight')
            self.conv_layer2 = nn.utils.weight_norm(nn.Conv2d(32, 64, 3), name='weight')
            self.fc = nn.utils.weight_norm(
                nn.Linear(in_features=64 * 5 * 5, out_features=10), name='weight')
          
        else:
            self.conv_layer1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
            self.conv_layer2 = nn.Conv2d(32, 64, 3)
            self.fc = nn.Linear(in_features=64 * 5 * 5, out_features=10)

        
        # input dimensions are Bx1x28x28 (BxCxHxW)
        self.batch_norm1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2,stride=2)

        self.batch_norm2 = nn.BatchNorm2d(64)

        self.flat = nn.Flatten()
        self.drop = nn.Dropout(p=dropout_rate)


    def forward(self, x):
        # Block 1
        out = self.conv_layer1(x)
        out = self.batch_norm1(out)
        out = self.relu(out)
        out = self.pool(out)

        # Block 2
        out = self.conv_layer2(out)
        out = self.batch_norm2(out)
        out = self.relu(out)
        out = self.pool(out)

        # Flatten the output from BxCxHxW to Bx(C*H*W)
        out = self.flat(out)
        out = self.drop(out)
        out = self.fc(out)
        
        return out

    def add_quant(self):
        '''
        Returns a new model wrapped with quantization stubs: QuantStub marks
        where tensors are converted to int8 and DeQuantStub where they are
        converted back to float.
        '''
        return nn.Sequential(torch.quantization.QuantStub(), self,
                             torch.quantization.DeQuantStub())

Training without PyTorch Model Quantization
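The training code is hidden as well. A stripped-down sketch of the float training loop (the optimizer choice and learning rate are assumptions; the hidden version also logs time, loss, accuracy and the validation metrics shown in the outputs below):

def train(model, loader, epochs, lr=1e-3):
    # Plain float32 training: cross-entropy loss + Adam (assumed).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        print(f'Epoch {epoch + 1}')
        model.train()
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
        # the hidden version also prints time, loss, accuracy and validation metrics here

    print('Finished Training')
    return model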

Helper functions for PyTorch quantization and evaluation
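These are also hidden. Assuming eager-mode post-training static quantization with the default fbgemm qconfig (which matches the observer warning printed in the outputs below), the helpers could look roughly like this; quantize_model, evaluate and the number of calibration batches are my reconstruction, not the original code:

import os
import time

def quantize_model(model, calibration_loader):
    # Post-training static quantization (eager mode): wrap the float model
    # with Quant/DeQuant stubs, attach observers, run a few calibration
    # batches, then convert the weights and activations to int8.
    qmodel = model.cpu().eval().add_quant()
    qmodel.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(qmodel, inplace=True)
    with torch.no_grad():
        for i, (x, _) in enumerate(calibration_loader):
            qmodel(x)
            if i >= 10:  # a handful of batches is enough for calibration
                break
    torch.quantization.convert(qmodel, inplace=True)
    return qmodel

def evaluate(model, loader, name):
    # Report on-disk size, test accuracy and wall-clock evaluation time.
    torch.save(model.state_dict(), name)
    print(f'Model Size: {os.path.getsize(name) / 1024:.2f} KB')
    start, correct, total = time.time(), 0, 0
    with torch.no_grad():
        for x, y in loader:
            correct += (model(x).argmax(dim=1) == y).sum().item()
            total += y.size(0)
    print(f'Accuracy: {correct / total:.4f}')
    print(f'Eval time: {time.time() - start:.2f}s')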

Chain everything together in a single train function
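Putting the hypothetical pieces above together, a sketch of train_model_and_quantize (its real signature is only known from how it is called below) might be:

def train_model_and_quantize(epochs, norm_w=False):
    # Build the model, train the float version, then quantize it and
    # evaluate both copies. The hidden version also prints the torchsummary
    # table of the architecture that appears in the output below.
    model = CNN(norm_w=norm_w)
    model = train(model, train_loader, epochs=epochs)

    print('Running Final Evaluation')
    print('Model Name: mnist_model_float.pt, Quantized: False')
    evaluate(model.cpu().eval(), test_loader, name='mnist_model_float.pt')

    qmodel = quantize_model(model, train_loader)
    print('Model Name: mnist_model_quantized.pt, Quantized: True')
    evaluate(qmodel, test_loader, name='mnist_model_quantized.pt')

    return model, qmodel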

Experiments

Let's pretend this is the first time we face this problem, so we haven't thought about normalizing the weights yet.

model_15, _ = train_model_and_quantize(epochs=15, norm_w=False)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 26, 26]             320
       BatchNorm2d-2           [-1, 32, 26, 26]              64
              ReLU-3           [-1, 32, 26, 26]               0
         MaxPool2d-4           [-1, 32, 13, 13]               0
            Conv2d-5           [-1, 64, 11, 11]          18,496
       BatchNorm2d-6           [-1, 64, 11, 11]             128
              ReLU-7           [-1, 64, 11, 11]               0
         MaxPool2d-8             [-1, 64, 5, 5]               0
           Flatten-9                 [-1, 1600]               0
          Dropout-10                 [-1, 1600]               0
           Linear-11                   [-1, 10]          16,010
================================================================
Total params: 35,018
Trainable params: 35,018
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.75
Params size (MB): 0.13
Estimated Total Size (MB): 0.89
----------------------------------------------------------------
Epoch 1
Time: 2.53 s  -  Loss: 0.2702  -  Categorical_Accuracy: 0.9447  -  Val_Loss: 0.0466  -  Categorical_Val_Accuracy: 0.9848
Epoch 2
Time: 2.21 s  -  Loss: 0.0523  -  Categorical_Accuracy: 0.9837  -  Val_Loss: 0.0416  -  Categorical_Val_Accuracy: 0.9869
Epoch 3
Time: 2.21 s  -  Loss: 0.0358  -  Categorical_Accuracy: 0.9885  -  Val_Loss: 0.0286  -  Categorical_Val_Accuracy: 0.9904
Epoch 4
Time: 2.20 s  -  Loss: 0.0291  -  Categorical_Accuracy: 0.9907  -  Val_Loss: 0.0337  -  Categorical_Val_Accuracy: 0.9888
Epoch 5
Time: 2.20 s  -  Loss: 0.0250  -  Categorical_Accuracy: 0.9920  -  Val_Loss: 0.0335  -  Categorical_Val_Accuracy: 0.9909
Epoch 6
Time: 2.22 s  -  Loss: 0.0204  -  Categorical_Accuracy: 0.9937  -  Val_Loss: 0.0306  -  Categorical_Val_Accuracy: 0.9898
Epoch 7
Time: 2.26 s  -  Loss: 0.0183  -  Categorical_Accuracy: 0.9945  -  Val_Loss: 0.0419  -  Categorical_Val_Accuracy: 0.9863
Epoch 8
Time: 2.23 s  -  Loss: 0.0168  -  Categorical_Accuracy: 0.9946  -  Val_Loss: 0.0379  -  Categorical_Val_Accuracy: 0.9900
Epoch 9
Time: 2.23 s  -  Loss: 0.0176  -  Categorical_Accuracy: 0.9941  -  Val_Loss: 0.0436  -  Categorical_Val_Accuracy: 0.9882
Epoch 10
Time: 2.23 s  -  Loss: 0.0159  -  Categorical_Accuracy: 0.9945  -  Val_Loss: 0.0376  -  Categorical_Val_Accuracy: 0.9889
Epoch 11
Time: 2.21 s  -  Loss: 0.0159  -  Categorical_Accuracy: 0.9949  -  Val_Loss: 0.0433  -  Categorical_Val_Accuracy: 0.9881
Epoch 12
Time: 2.21 s  -  Loss: 0.0132  -  Categorical_Accuracy: 0.9954  -  Val_Loss: 0.0428  -  Categorical_Val_Accuracy: 0.9894
Epoch 13
Time: 2.21 s  -  Loss: 0.0143  -  Categorical_Accuracy: 0.9956  -  Val_Loss: 0.0443  -  Categorical_Val_Accuracy: 0.9884
Epoch 14
Time: 2.22 s  -  Loss: 0.0113  -  Categorical_Accuracy: 0.9962  -  Val_Loss: 0.0397  -  Categorical_Val_Accuracy: 0.9896
Epoch 15
Time: 2.23 s  -  Loss: 0.0132  -  Categorical_Accuracy: 0.9961  -  Val_Loss: 0.0371  -  Categorical_Val_Accuracy: 0.9895
Finished Training
Running Final Evaluation
Model Name: mnist_model_float.pt, Quantized: False
Model Size: 142.11 KB
Accuracy: 0.9895
Eval time: 5.13s
/usr/local/lib/python3.7/dist-packages/torch/ao/quantization/observer.py:179: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  reduce_range will be deprecated in a future release of PyTorch."
Model Name: mnist_model_quantized.pt, Quantized: True
Model Size: 41.67 KB
Accuracy: 0.6951
Eval time: 6.34s

As we can see, the drop in performance of the quantized model is striking. What might have happened? Why are the results so much better when using Keras?

We could try training for just 1 epoch and see what happens.

model_1, _ = train_model_and_quantize(epochs=1, norm_w=False)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 26, 26]             320
       BatchNorm2d-2           [-1, 32, 26, 26]              64
              ReLU-3           [-1, 32, 26, 26]               0
         MaxPool2d-4           [-1, 32, 13, 13]               0
            Conv2d-5           [-1, 64, 11, 11]          18,496
       BatchNorm2d-6           [-1, 64, 11, 11]             128
              ReLU-7           [-1, 64, 11, 11]               0
         MaxPool2d-8             [-1, 64, 5, 5]               0
           Flatten-9                 [-1, 1600]               0
          Dropout-10                 [-1, 1600]               0
           Linear-11                   [-1, 10]          16,010
================================================================
Total params: 35,018
Trainable params: 35,018
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.75
Params size (MB): 0.13
Estimated Total Size (MB): 0.89
----------------------------------------------------------------
Epoch 1
Time: 2.34 s  -  Loss: 0.2260  -  Categorical_Accuracy: 0.9505  -  Val_Loss: 0.0466  -  Categorical_Val_Accuracy: 0.9838
Finished Training
Running Final Evaluation
Model Name: mnist_model_float.pt, Quantized: False
Model Size: 142.11 KB
Accuracy: 0.9838
Eval time: 5.25s
/usr/local/lib/python3.7/dist-packages/torch/ao/quantization/observer.py:179: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  reduce_range will be deprecated in a future release of PyTorch."
Model Name: mnist_model_quantized.pt, Quantized: True
Model Size: 41.67 KB
Accuracy: 0.9424
Eval time: 6.03s

The result of the quantization is better now. What is happening?

The first thing that comes to mind is to take a look at the weights 🔍

import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

def vis_model(model):
    # Collect every Conv2d and Linear weight and plot them all in one histogram.
    w = []
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            w.extend(m.weight.data.cpu().numpy().flatten())
    plt.hist(w, density=True, bins=128)
    plt.ylabel('Occurrences')
    plt.xlabel('Weight')

# Calling both on the same axes overlays the two histograms for comparison.
vis_model(model_15)
vis_model(model_1)

So the weights of the model trained for 15 epochs are larger. Could that be what is causing the problem?

If we plot how the weights of the Keras model are distributed, we will see that they are smaller (more similar to those of the model trained for one epoch).

The explanation may be that with larger weights the minimum and the maximum are further apart, so the "resolution" of the quantization is worse. One quick fix could be clamping the weights so that they cannot go above or below a certain threshold. But that would mean the model only really learns during the first few iterations, and it would produce a histogram like this one:

A better option would be to normalize the weights, so let's try it.
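Before looking at the results, here is a small toy computation (not from the original notebook) that makes the "resolution" argument concrete. It uses symmetric per-tensor int8 quantization with scale = max|w| / 127; PyTorch's default observers use a slightly different affine, reduced-range scheme, but the effect of a stretched range is the same:

import numpy as np

def fake_quantize(w, num_bits=8):
    # Symmetric per-tensor quantization: one scale for the whole tensor,
    # derived from the largest absolute value.
    qmax = 2 ** (num_bits - 1) - 1          # 127 for int8
    scale = np.abs(w).max() / qmax
    q = np.clip(np.round(w / scale), -qmax, qmax)
    return q * scale

rng = np.random.default_rng(0)
small = rng.normal(0, 0.1, 10_000)            # well-behaved weights
large = np.concatenate([small, [5.0, -5.0]])  # same weights plus two outliers

for name, w in [('small range', small), ('with outliers', large)]:
    err = np.abs(w - fake_quantize(w)).mean()
    print(f'{name}: mean abs quantization error = {err:.5f}')

# The outliers stretch the quantization range, so every other weight is
# represented with a coarser step and the average round-trip error grows.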

model_15_norm, _ = train_model_and_quantize(epochs=15, norm_w=True)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 26, 26]             320
       BatchNorm2d-2           [-1, 32, 26, 26]              64
              ReLU-3           [-1, 32, 26, 26]               0
         MaxPool2d-4           [-1, 32, 13, 13]               0
            Conv2d-5           [-1, 64, 11, 11]          18,496
       BatchNorm2d-6           [-1, 64, 11, 11]             128
              ReLU-7           [-1, 64, 11, 11]               0
         MaxPool2d-8             [-1, 64, 5, 5]               0
           Flatten-9                 [-1, 1600]               0
          Dropout-10                 [-1, 1600]               0
           Linear-11                   [-1, 10]          16,010
================================================================
Total params: 35,018
Trainable params: 35,018
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.75
Params size (MB): 0.13
Estimated Total Size (MB): 0.89
----------------------------------------------------------------
Epoch 1
Time: 2.44 s  -  Loss: 0.1988  -  Categorical_Accuracy: 0.9498  -  Val_Loss: 0.0593  -  Categorical_Val_Accuracy: 0.9806
Epoch 2
Time: 2.36 s  -  Loss: 0.0509  -  Categorical_Accuracy: 0.9848  -  Val_Loss: 0.0391  -  Categorical_Val_Accuracy: 0.9870
Epoch 3
Time: 2.51 s  -  Loss: 0.0383  -  Categorical_Accuracy: 0.9880  -  Val_Loss: 0.0392  -  Categorical_Val_Accuracy: 0.9871
Epoch 4
Time: 2.34 s  -  Loss: 0.0325  -  Categorical_Accuracy: 0.9901  -  Val_Loss: 0.0304  -  Categorical_Val_Accuracy: 0.9897
Epoch 5
Time: 2.37 s  -  Loss: 0.0279  -  Categorical_Accuracy: 0.9909  -  Val_Loss: 0.0394  -  Categorical_Val_Accuracy: 0.9877
Epoch 6
Time: 2.39 s  -  Loss: 0.0237  -  Categorical_Accuracy: 0.9928  -  Val_Loss: 0.0339  -  Categorical_Val_Accuracy: 0.9896
Epoch 7
Time: 2.33 s  -  Loss: 0.0208  -  Categorical_Accuracy: 0.9933  -  Val_Loss: 0.0284  -  Categorical_Val_Accuracy: 0.9899
Epoch 8
Time: 2.34 s  -  Loss: 0.0171  -  Categorical_Accuracy: 0.9948  -  Val_Loss: 0.0386  -  Categorical_Val_Accuracy: 0.9869
Epoch 9
Time: 2.35 s  -  Loss: 0.0154  -  Categorical_Accuracy: 0.9950  -  Val_Loss: 0.0476  -  Categorical_Val_Accuracy: 0.9849
Epoch 10
Time: 2.34 s  -  Loss: 0.0147  -  Categorical_Accuracy: 0.9952  -  Val_Loss: 0.0313  -  Categorical_Val_Accuracy: 0.9903
Epoch 11
Time: 2.34 s  -  Loss: 0.0116  -  Categorical_Accuracy: 0.9961  -  Val_Loss: 0.0375  -  Categorical_Val_Accuracy: 0.9885
Epoch 12
Time: 2.35 s  -  Loss: 0.0124  -  Categorical_Accuracy: 0.9960  -  Val_Loss: 0.0399  -  Categorical_Val_Accuracy: 0.9896
Epoch 13
Time: 2.34 s  -  Loss: 0.0099  -  Categorical_Accuracy: 0.9969  -  Val_Loss: 0.0283  -  Categorical_Val_Accuracy: 0.9918
Epoch 14
Time: 2.36 s  -  Loss: 0.0083  -  Categorical_Accuracy: 0.9973  -  Val_Loss: 0.0290  -  Categorical_Val_Accuracy: 0.9913
Epoch 15
Time: 2.43 s  -  Loss: 0.0067  -  Categorical_Accuracy: 0.9977  -  Val_Loss: 0.0337  -  Categorical_Val_Accuracy: 0.9907
Finished Training
Running Final Evaluation
Model Name: mnist_model_float.pt, Quantized: False
Model Size: 142.11 KB
Accuracy: 0.9907
Eval time: 5.51s
/usr/local/lib/python3.7/dist-packages/torch/ao/quantization/observer.py:179: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  reduce_range will be deprecated in a future release of PyTorch."
Model Name: mnist_model_quantized.pt, Quantized: True
Model Size: 41.67 KB
Accuracy: 0.9838
Eval time: 6.13s

We get better results! A much smaller drop in performance while still significantly reducing the size of the model.

And, as we can see, the weights are now normalized.

vis_model(model_15_norm)

Conclusion

After these experiments, the main takeaways are the following:

  • Looking at the network's weights to see what is happening is good practice and gives some insight into how the model is being (or has been) trained.

  • The values that our weights take really matter for quantization.

  • We have to be careful that the weights in our network do not grow too large. Since training rewards larger outputs (logits) for the correct class, it is plausible that the model pushes its weights to be as big as possible.

  • Normalizing the weights can lead to better quantized models. However, it can increase the training time for the same number of epochs; on the other hand, it often converges faster, so fewer epochs may be needed. You will have to study your problem carefully. No one said it was easy!