Date: 15/07/2022

Author: @kavindu404

In this mini blog series, I am implementing a multiclass classifier for MNIST digits from scratch. In this part, I classify the digits using pixel similarity, and I will try to improve the performance in each subsequent part. First, let's import fastai:

from fastai.vision.all import *

The MNIST dataset can be downloaded and extracted using the untar_data() function. With fastai, we can easily list the contents of the extracted directory.

path = untar_data(URLs.MNIST)
Path.BASE_PATH = path
path.ls()
(#2) [Path('testing'),Path('training')]

Let's first load the training data into separate objects. The ls() method returns an object of fastai's L class, which has all the functionality of a Python list() and more.

zeros = (path/'training'/'0').ls().sorted()
ones = (path/'training'/'1').ls().sorted()
twos = (path/'training'/'2').ls().sorted()
threes = (path/'training'/'3').ls().sorted()
fours = (path/'training'/'4').ls().sorted()
fives = (path/'training'/'5').ls().sorted()
sixes = (path/'training'/'6').ls().sorted()
sevens = (path/'training'/'7').ls().sorted()
eights = (path/'training'/'8').ls().sorted()
nines = (path/'training'/'9').ls().sorted()
zeros,ones,twos,threes,fours,fives,sixes,sevens,eights,nines
((#5923) [Path('training/0/1.png'),Path('training/0/1000.png'),Path('training/0/10005.png'),Path('training/0/10010.png'),Path('training/0/10022.png'),Path('training/0/10025.png'),Path('training/0/10026.png'),Path('training/0/10045.png'),Path('training/0/10069.png'),Path('training/0/10071.png')...],
 (#6742) [Path('training/1/10006.png'),Path('training/1/10007.png'),Path('training/1/1002.png'),Path('training/1/10020.png'),Path('training/1/10027.png'),Path('training/1/1003.png'),Path('training/1/10040.png'),Path('training/1/10048.png'),Path('training/1/10058.png'),Path('training/1/10067.png')...],
 (#5958) [Path('training/2/10009.png'),Path('training/2/10016.png'),Path('training/2/10024.png'),Path('training/2/10029.png'),Path('training/2/10072.png'),Path('training/2/10073.png'),Path('training/2/10075.png'),Path('training/2/10078.png'),Path('training/2/10081.png'),Path('training/2/10082.png')...],
 (#6131) [Path('training/3/10.png'),Path('training/3/10000.png'),Path('training/3/10011.png'),Path('training/3/10031.png'),Path('training/3/10034.png'),Path('training/3/10042.png'),Path('training/3/10052.png'),Path('training/3/1007.png'),Path('training/3/10074.png'),Path('training/3/10091.png')...],
 (#5842) [Path('training/4/10013.png'),Path('training/4/10018.png'),Path('training/4/10033.png'),Path('training/4/1004.png'),Path('training/4/1006.png'),Path('training/4/10060.png'),Path('training/4/1008.png'),Path('training/4/10103.png'),Path('training/4/10104.png'),Path('training/4/10114.png')...],
 (#5421) [Path('training/5/0.png'),Path('training/5/100.png'),Path('training/5/10008.png'),Path('training/5/10015.png'),Path('training/5/10030.png'),Path('training/5/10035.png'),Path('training/5/10049.png'),Path('training/5/10051.png'),Path('training/5/10056.png'),Path('training/5/10062.png')...],
 (#5918) [Path('training/6/10017.png'),Path('training/6/10032.png'),Path('training/6/10036.png'),Path('training/6/10037.png'),Path('training/6/10044.png'),Path('training/6/10053.png'),Path('training/6/10076.png'),Path('training/6/10089.png'),Path('training/6/10101.png'),Path('training/6/10108.png')...],
 (#6265) [Path('training/7/10002.png'),Path('training/7/1001.png'),Path('training/7/10014.png'),Path('training/7/10019.png'),Path('training/7/10039.png'),Path('training/7/10046.png'),Path('training/7/10050.png'),Path('training/7/10063.png'),Path('training/7/10077.png'),Path('training/7/10086.png')...],
 (#5851) [Path('training/8/10001.png'),Path('training/8/10012.png'),Path('training/8/10021.png'),Path('training/8/10041.png'),Path('training/8/10054.png'),Path('training/8/10057.png'),Path('training/8/10061.png'),Path('training/8/10064.png'),Path('training/8/10066.png'),Path('training/8/10079.png')...],
 (#5949) [Path('training/9/10003.png'),Path('training/9/10004.png'),Path('training/9/10023.png'),Path('training/9/10028.png'),Path('training/9/10038.png'),Path('training/9/10043.png'),Path('training/9/10047.png'),Path('training/9/1005.png'),Path('training/9/10055.png'),Path('training/9/10059.png')...])

Now that we have all the data separated into objects, let's stack them up. Each image becomes a float tensor scaled to the [0, 1] range by dividing by 255.

stacked_zeros = torch.stack([tensor(Image.open(o)) for o in zeros]).float()/255
stacked_ones = torch.stack([tensor(Image.open(o)) for o in ones]).float()/255
stacked_twos = torch.stack([tensor(Image.open(o)) for o in twos]).float()/255
stacked_threes = torch.stack([tensor(Image.open(o)) for o in threes]).float()/255
stacked_fours = torch.stack([tensor(Image.open(o)) for o in fours]).float()/255
stacked_fives = torch.stack([tensor(Image.open(o)) for o in fives]).float()/255
stacked_sixes = torch.stack([tensor(Image.open(o)) for o in sixes]).float()/255
stacked_sevens = torch.stack([tensor(Image.open(o)) for o in sevens]).float()/255
stacked_eights = torch.stack([tensor(Image.open(o)) for o in eights]).float()/255
stacked_nines = torch.stack([tensor(Image.open(o)) for o in nines]).float()/255
stacked_zeros.shape, stacked_ones.shape, stacked_twos.shape, stacked_threes.shape, stacked_fours.shape, stacked_fives.shape, stacked_sixes.shape, stacked_sevens.shape, stacked_eights.shape, stacked_nines.shape, 
(torch.Size([5923, 28, 28]),
 torch.Size([6742, 28, 28]),
 torch.Size([5958, 28, 28]),
 torch.Size([6131, 28, 28]),
 torch.Size([5842, 28, 28]),
 torch.Size([5421, 28, 28]),
 torch.Size([5918, 28, 28]),
 torch.Size([6265, 28, 28]),
 torch.Size([5851, 28, 28]),
 torch.Size([5949, 28, 28]))
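The ten nearly identical stacking lines above could also be written once as a helper. Here is a minimal sketch, using synthetic uint8 tensors in place of the Image.open results (stack_images and fake_imgs are illustrative names, not part of the post's code):

```python
import torch

def stack_images(imgs):
    # Stack N HxW images into one (N, H, W) float tensor scaled to [0, 1]
    return torch.stack([torch.as_tensor(im) for im in imgs]).float() / 255

# Synthetic stand-ins for 28x28 grayscale MNIST images
fake_imgs = [torch.randint(0, 256, (28, 28), dtype=torch.uint8) for _ in range(5)]
stacked = stack_images(fake_imgs)
print(stacked.shape)  # torch.Size([5, 28, 28])
```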

In our first attempt, we will use pixel similarity. So, first, let's calculate the mean for each digit.

mean0 = stacked_zeros.mean(0)
mean1 = stacked_ones.mean(0)
mean2 = stacked_twos.mean(0)
mean3 = stacked_threes.mean(0)
mean4 = stacked_fours.mean(0)
mean5 = stacked_fives.mean(0)
mean6 = stacked_sixes.mean(0)
mean7 = stacked_sevens.mean(0)
mean8 = stacked_eights.mean(0)
mean9 = stacked_nines.mean(0)
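mean(0) averages over the first (sample) dimension, collapsing a stack of N images into a single "average" image. A tiny sketch with 2x2 "images" (synthetic values, just to show the axis semantics):

```python
import torch

stack = torch.ones(3, 2, 2)  # three 2x2 "images", all pixels 1.0
stack[0] = 0.0               # make the first image all zeros
avg = stack.mean(0)          # average over the sample dimension -> shape (2, 2)
print(avg)                   # every pixel is (0 + 1 + 1) / 3 ≈ 0.6667
```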

The mean image for each digit represents the 'ideal' digit we expect. Let's take a look at the 'ideal' 2.

df1 = pd.DataFrame(mean2[0:29,0:23])
df1.style.set_properties(**{'font-size':'4.5pt'}).background_gradient('Greys')
(output: a 28×23 grid of mean pixel values, rendered with a grey background gradient that traces the faint shape of a 2)
We can also display an individual image; here, the second sample in the ones stack:

im = stacked_ones[1]
show_image(im)
(output: a plot of a handwritten 1)

Now, let's collect the validation dataset and stack it up the same way.

valid_zeros = (path/'testing'/'0').ls().sorted()
valid_ones = (path/'testing'/'1').ls().sorted()
valid_twos = (path/'testing'/'2').ls().sorted()
valid_threes = (path/'testing'/'3').ls().sorted()
valid_fours = (path/'testing'/'4').ls().sorted()
valid_fives = (path/'testing'/'5').ls().sorted()
valid_sixes = (path/'testing'/'6').ls().sorted()
valid_sevens = (path/'testing'/'7').ls().sorted()
valid_eights = (path/'testing'/'8').ls().sorted()
valid_nines = (path/'testing'/'9').ls().sorted()
valid_stacked_zeros = torch.stack([tensor(Image.open(o)) for o in valid_zeros]).float()/255
valid_stacked_ones = torch.stack([tensor(Image.open(o)) for o in valid_ones]).float()/255
valid_stacked_twos = torch.stack([tensor(Image.open(o)) for o in valid_twos]).float()/255
valid_stacked_threes = torch.stack([tensor(Image.open(o)) for o in valid_threes]).float()/255
valid_stacked_fours = torch.stack([tensor(Image.open(o)) for o in valid_fours]).float()/255
valid_stacked_fives = torch.stack([tensor(Image.open(o)) for o in valid_fives]).float()/255
valid_stacked_sixes = torch.stack([tensor(Image.open(o)) for o in valid_sixes]).float()/255
valid_stacked_sevens = torch.stack([tensor(Image.open(o)) for o in valid_sevens]).float()/255
valid_stacked_eights = torch.stack([tensor(Image.open(o)) for o in valid_eights]).float()/255
valid_stacked_nines = torch.stack([tensor(Image.open(o)) for o in valid_nines]).float()/255

To measure pixel similarity, we compute the distance between an input image and each 'ideal' digit, then pick the closest one. The distance() function returns the mean absolute difference (L1 distance) between two inputs. min_distance() finds the closest 'ideal' digit for a given input, and is_correct() checks whether that prediction matches the true label.

def distance(x,y): return (x-y).abs().mean((-1,-2))
mean_vec = [mean0, mean1, mean2, mean3, mean4, mean5, mean6, mean7, mean8, mean9]
def min_distance(x): 
    distances = [distance(x, o) for o in mean_vec]
    return distances.index(min(distances))
def is_correct(num, x): return num == min_distance(x)
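Here is a tiny self-contained check of the same nearest-mean logic, using 2x2 "digits" instead of 28x28 MNIST images (the values and names below are illustrative):

```python
import torch

def distance(x, y): return (x - y).abs().mean((-1, -2))

means = [torch.zeros(2, 2), torch.ones(2, 2)]    # 'ideal' 0 and 'ideal' 1
sample = torch.tensor([[0.9, 1.0], [0.8, 1.0]])  # much closer to the ones mean
dists = [distance(sample, m) for m in means]
pred = dists.index(min(dists))
print(pred)  # 1
```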

Let's sanity-check with an input. The image below is a 1, so asking whether it is classified as a 4 should return False.

is_correct(4, valid_stacked_ones[140])
False

Now that we have verified it is working, let's calculate the accuracy of the model. Here, we compute the fraction of correct predictions for each class and then average across the ten classes (a macro-average, since the classes have slightly different sizes).

acc_zeros = tensor([is_correct(0,o) for o in valid_stacked_zeros]).float().mean()
acc_ones = tensor([is_correct(1,o) for o in valid_stacked_ones]).float().mean()
acc_twos = tensor([is_correct(2,o) for o in valid_stacked_twos]).float().mean()
acc_threes = tensor([is_correct(3,o) for o in valid_stacked_threes]).float().mean()
acc_fours = tensor([is_correct(4,o) for o in valid_stacked_fours]).float().mean()
acc_fives = tensor([is_correct(5,o) for o in valid_stacked_fives]).float().mean()
acc_sixes = tensor([is_correct(6,o) for o in valid_stacked_sixes]).float().mean()
acc_sevens = tensor([is_correct(7,o) for o in valid_stacked_sevens]).float().mean()
acc_eights = tensor([is_correct(8,o) for o in valid_stacked_eights]).float().mean()
acc_nines = tensor([is_correct(9,o) for o in valid_stacked_nines]).float().mean()

acc= tensor([acc_zeros, acc_ones, acc_twos, acc_threes, acc_fours, acc_fives, acc_sixes, acc_sevens, acc_eights, acc_nines]).mean()
acc
tensor(0.6610)
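The per-image loops above can also be vectorized: distance() already reduces over the last two dimensions, so broadcasting gives all distances at once. A sketch with synthetic 2x2 data (the shapes mirror the real (N, 28, 28) batch vs (10, 28, 28) means case):

```python
import torch

def distance(x, y): return (x - y).abs().mean((-1, -2))

means = torch.stack([torch.zeros(2, 2), torch.ones(2, 2)])   # (K, 2, 2)
batch = torch.tensor([[[0.1, 0.0], [0.0, 0.2]],
                      [[0.9, 1.0], [1.0, 0.8]]])             # (N, 2, 2)
# (N, 1, 2, 2) against (K, 2, 2) broadcasts to an (N, K) distance matrix
dists = distance(batch.unsqueeze(1), means)
preds = dists.argmin(dim=1)
print(preds)  # tensor([0, 1])
```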

So, we have an accuracy of 66.1%. Given that we only considered pixel similarity, that is a reasonable baseline. In the next part, let's try to improve from here.