MNIST Digit Classification in PyTorch


Heads Up

Just a heads up: I programmed this neural network in Python using PyTorch. I wrote the model in PyCharm, but I would advise that if you choose to write this code (or really any deep learning model), you use Google Colaboratory or Jupyter Notebooks, unless you can train models on your own GPU. I recommend this because training takes a bit of time (about 5 minutes per run for this particular project).

Dataset Information

The MNIST dataset contains 28-by-28 grayscale images of single handwritten digits between 0 and 9. The set consists of 70,000 images in total: 60,000 in the training set and 10,000 in the test set. This means there are 10 classes, one label for each of the digits 0 through 9.
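If you want to verify those numbers yourself, here is a minimal (optional) check that loads the raw dataset with torchvision and prints its sizes. It assumes torchvision is installed and will download MNIST to ./data on the first run:

import torchvision.datasets as datasets

train_raw = datasets.MNIST(root='./data', train=True, download=True)
test_raw = datasets.MNIST(root='./data', train=False, download=True)

print(len(train_raw))        # 60000 training images
print(len(test_raw))         # 10000 test images
print(train_raw[0][0].size)  # (28, 28) -- each sample is a 28x28 PIL image before any transform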

import torch as t
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
  1. We import torch as t, which gives us tensors, the DataLoader, and the optimizers.
  2. We’ll get our dataset from torchvision.datasets, and we’ll import it as “datasets”.
  3. Then we will import torchvision.transforms so we can transform our images to fit our model.
  4. Next we’ll import torch.nn as nn, and we will use this to build our actual neural network.
  5. Lastly, we’ll import matplotlib to visualize our results at the end.

Data


We compose a transform that converts each image to a tensor and normalizes its pixel values, then wrap the training and test sets in DataLoaders that serve shuffled batches of 10 images.

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])  # convert each image to a tensor and normalize its pixel values

mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = t.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = t.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)
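If you want to sanity-check the data pipeline, you can pull one batch from the loader and inspect its shape. This is just an optional check using the loaders defined above:

for images, labels in train_loader:
    print(images.shape)  # torch.Size([10, 1, 28, 28]) -- a batch of 10 single-channel 28x28 images
    print(labels.shape)  # torch.Size([10]) -- one digit label per image
    break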

The Model

The model is a simple fully connected network: it flattens each 28×28 image into 784 values, passes them through hidden layers of 100 and 50 units with ReLU activations, and outputs 10 scores, one per digit.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.linear1 = nn.Linear(28*28, 100)  # input layer: 784 pixels -> 100 units
        self.linear2 = nn.Linear(100, 50)     # hidden layer: 100 -> 50 units
        self.final = nn.Linear(50, 10)        # output layer: 50 -> 10 classes
        self.relu = nn.ReLU()

    def forward(self, img):  # convert + flatten
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.final(x)
        return x

net = Net()

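Before training, it’s worth confirming that the network maps a flattened image to one raw score (logit) per digit class. A quick, optional check using the net defined above:

dummy = t.randn(1, 28*28)  # a single fake flattened image
print(net(dummy).shape)    # torch.Size([1, 10]) -- one score per digit 0-9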

Loss Function and Training

cross_el = nn.CrossEntropyLoss()
optimizer = t.optim.Adam(net.parameters(), lr=0.001)  # learning rate of 1e-3
epochs = 10

for epoch in range(epochs):
    net.train()

    for data in train_loader:
        x, y = data
        optimizer.zero_grad()
        output = net(x.view(-1, 28*28))
        loss = cross_el(output, y)
        loss.backward()
        optimizer.step()
For every batch in every epoch, the loop will:

  1. Zero the gradients from the previous batch.
  2. Pass the batch through the network.
  3. Calculate the loss between the predictions and the true labels.
  4. Adjust the weights within the network to decrease the loss.
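If you’d like to watch the loss go down as the model trains, one option is to keep a running total per epoch and print it. This is a sketch of an optional variation on the loop above, not part of the original code:

for epoch in range(epochs):
    net.train()
    running_loss = 0.0
    for x, y in train_loader:
        optimizer.zero_grad()
        output = net(x.view(-1, 28*28))
        loss = cross_el(output, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()  # accumulate the batch loss
    print(f'epoch {epoch + 1}, average loss: {running_loss / len(train_loader):.4f}')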

Evaluating Our Model

With gradients disabled, we run the test set through the network and count how many predictions match the true labels.

correct = 0
total = 0

with t.no_grad():
    for data in test_loader:
        x, y = data
        output = net(x.view(-1, 784))
        for idx, i in enumerate(output):
            if t.argmax(i) == y[idx]:
                correct += 1
            total += 1

print(f'accuracy: {round(correct/total, 3)}')

Visualization

We’ve now completed our network and the necessary functions. Still, it is nice to add some visualization to show the data the network is working with and to check whether its predictions match the handwritten digits.

plt.imshow(x[3].view(28, 28))  # show the fourth image from the last test batch
plt.show()
print(t.argmax(net(x[3].view(-1, 784))[0]))  # the digit the network predicts for it
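To check more than one digit at a time, you could also plot a handful of test images alongside the network’s predictions. Here is a small optional sketch using the same x batch from the evaluation loop:

fig, axes = plt.subplots(1, 5)
for i, ax in enumerate(axes):
    ax.imshow(x[i].view(28, 28), cmap='gray')             # the test image
    ax.set_title(int(t.argmax(net(x[i].view(-1, 784)))))  # the predicted digit
    ax.axis('off')
plt.show()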

Contact me for any inquiries 🚀

Please note that all code within this article is my own. If you would like to use or reference this code, go to my GitHub, where the repository is public.
