An Insight into K-Nearest Neighbour
K-Nearest Neighbour, otherwise known as KNN, is a machine learning algorithm that stores all available cases and classifies new data based on a similarity measure. It is a supervised learning model and is most commonly used for classification, although it can also be used for regression.
What does it mean for a model to be supervised?
I bet many of you have had to deal with little children, whether younger siblings, cousins, neighbours, and the list goes on. Young children are curious about their surroundings and eager to learn about the things around them.
We can think of our computer as a young child. It does not know everything but is curious to learn.
Let’s say we want to teach our computer (which is like a child) what an apple looks like. We’ll show our computer a picture of an apple and tell it, yes, this is an apple. We will also show it a few pictures of other fruits and tell it, no, that is not an apple. Eventually, the computer will come to recognize what an apple looks like.
Essentially, this is how supervised machine learning works. The model is given labelled input data to learn from, and then has to produce an appropriate output when given unlabelled data.
What does it mean for a model to be a classification algorithm?
KNN is a classification algorithm, which means we use the training dataset to work out the boundaries that separate the target classes. Once we have these boundaries, we can use them to predict the target class of new data. In essence, a classification algorithm maps input data to a specific category.
Let’s look at the beautiful drawing I made. As you can see, there is a group of purple dots and a group of blue dots, situated separately. Then out of nowhere, there is a bright green dot and it does not belong anywhere… YET.
The purple dots represent the colour purple and the blue dots represent the colour blue, but green represents our mystery colour that could either be purple or blue.
So how do we find out whether the green dot is purple or blue?
To find out the true colour of the green dot, we’ll need to calculate the distance from the green dot to its neighbours. Let’s say we want to use 4 neighbours (k = 4).
The big circle in the drawing encloses the green dot’s 4 nearest neighbours. To determine whether the green dot belongs to the group of purple dots or the group of blue dots, we need to calculate the distance between the green dot and its neighbours. To do this we can use the Euclidean distance formula, d = sqrt((x-a)²+(y-b)²), where (x, y) is the position of the green dot and (a, b) is the position of a neighbour.
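To make the formula concrete, here is a minimal sketch in plain Python. The function name euclidean_distance and the coordinates are made up for illustration; they are not taken from the drawing.

```python
import math

def euclidean_distance(p, q):
    """Return the straight-line distance between 2-D points p=(x, y) and q=(a, b)."""
    (x, y), (a, b) = p, q
    return math.sqrt((x - a) ** 2 + (y - b) ** 2)

green = (1.0, 2.0)  # hypothetical position of the green dot
blue = (4.0, 6.0)   # hypothetical position of one blue neighbour

print(euclidean_distance(green, blue))  # 5.0, since sqrt(3² + 4²) = 5
```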
The KNN model will then run this formula to calculate the distances between the test data (the green dot) and the labelled datapoints (the blue and purple dots). It keeps the k closest points (here, the 4 neighbours inside the circle) and classifies the green dot as whichever class is most common among those neighbours.
From this simple representation of a KNN-based problem, we can determine from a glance that the green dot most likely belongs to the class of blue dots.
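The whole procedure (measure distances, keep the k nearest, take a majority vote) can be sketched in a few lines of plain Python. This is only an illustrative sketch: the function knn_classify and the coordinates are hypothetical stand-ins for the dots in the drawing, not what sklearn does internally.

```python
from collections import Counter
import math

def knn_classify(train_points, train_labels, query, k):
    """Label `query` by majority vote among its k nearest training points."""
    # Distance from the query to every labelled point, paired with that point's label.
    distances = [
        (math.dist(point, query), label)
        for point, label in zip(train_points, train_labels)
    ]
    # Keep the k closest neighbours.
    nearest = sorted(distances, key=lambda pair: pair[0])[:k]
    # Count the labels among those neighbours and return the most common one.
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]

# Hypothetical coordinates standing in for the dots in the drawing.
points = [(1, 1), (1, 2), (2, 1), (6, 6), (6, 7), (7, 6), (5, 6)]
labels = ["purple", "purple", "purple", "blue", "blue", "blue", "blue"]
green_dot = (5, 5)

print(knn_classify(points, labels, green_dot, k=4))  # blue
```

With an even k such as 4, a 2-2 tie between two classes is possible. Choosing an odd k is a common way to avoid that when there are only two classes.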
Pros of KNN
- Simple to implement
- No training time: the model simply stores the data
- KNN adjusts to new data. Since there is no explicit training, we can keep adding new data, and the predictions adjust without having to retrain the model
- A very flexible model in general, with a wide variety of hyperparameters and distance metrics
Cons of KNN
- Not great for large datasets, because every prediction compares the new point against the stored data
- The more dimensions your data has, the more difficult predictions will be
- KNN assigns the same importance to all features
- Sensitive to outliers and noise, meaning a single mislabelled example can change the class boundaries
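To illustrate the flexibility point from the list above, here is a hedged sketch that builds two KNN models differing only in their hyperparameters. It uses sklearn and the Iris dataset, both introduced later in this article. The particular settings (5 neighbours, Manhattan distance, distance-weighted votes) are arbitrary choices for illustration, not recommendations.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.4, random_state=42
)

# Two configurations that differ only in their hyperparameters.
configs = {
    "default metric, uniform votes": KNeighborsClassifier(n_neighbors=3),
    "manhattan, distance-weighted votes": KNeighborsClassifier(
        n_neighbors=5, metric="manhattan", weights="distance"
    ),
}

scores = {}
for name, model in configs.items():
    model.fit(X_train, y_train)
    scores[name] = model.score(X_test, y_test)  # accuracy on the held-out data
    print(f"{name}: {scores[name]:.3f}")
```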
KNN in Python
Now that we have explored the background of the K-Nearest Neighbours algorithm, let's do a walkthrough of it using Python.
About the dataset
We will use the Iris flower dataset, which records four measurements for each flower:
- sepal length in cm
- sepal width in cm
- petal length in cm
- petal width in cm
To read more: https://en.wikipedia.org/wiki/Iris_flower_data_set
Here are our imports:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
We will need numpy for making numpy arrays.
We will use matplotlib for data visualization.
We will import train_test_split from sklearn.model_selection to split our data into training and testing sets.
Since the iris data set is built into sklearn, we’ll need to import it from sklearn.
Lastly, we will import the KNeighborsClassifier from sklearn.neighbors (our knn model).
iris = load_iris()
#print(iris.feature_names) --> shows features
#print(iris.target_names) --> prints target names, 0 = setosa, 1 = versicolor, 2 = virginica
We call load_iris() to load the data and store the result in the variable “iris”.
While working with data it is helpful to look at different parts of our dataset, to learn more about what we are working with.
If we print iris.data, you’ll see something like this array, but much longer:
[[5.1 3.5 1.4 0.2]
[4.9 3. 1.4 0.2]
[4.7 3.2 1.3 0.2]
[4.6 3.1 1.5 0.2]
[5. 3.6 1.4 0.2]
[5.4 3.9 1.7 0.4]]
Now what are these random decimal values?
We can print our feature names to find out, and if you do you’ll see a list of values:
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
Essentially, this is a list of our feature names. As you can see, the features are sepal length, sepal width, petal length, and petal width. Each row of the array describes one flower, and the four values in a row are the features in the order listed.
Now, we will want to know what our target names are, and these names will represent the types of flowers in our classification. To do this we can print our target names of the dataset.
If you do so, you’ll see a list of the types of iris:
['setosa', 'versicolor', 'virginica']
The position of each type in the list matches its target value: 0 is setosa, 1 is versicolor, and 2 is virginica.
If we want to access any other information in our dataset, we can call print(iris.keys()), which displays the keys of the dataset. From there you can select specific keys to access specific info.
print(iris.keys())
OUTPUT: dict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names', 'filename'])
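Putting the inspection steps above together, a short sketch might look like this. The variable first_flower is just an illustrative name.

```python
from sklearn.datasets import load_iris

iris = load_iris()

print(iris.data.shape)    # (150, 4): 150 flowers, 4 features each
print(iris.target.shape)  # (150,): one class label per flower
print([str(name) for name in iris.target_names])  # ['setosa', 'versicolor', 'virginica']

# Pair the first row's values with their feature names.
first_flower = dict(zip(iris.feature_names, map(float, iris.data[0])))
print(first_flower)
```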
It’s always very helpful to visualize the data you are working with. Matplotlib is an awesome library for doing so in Python. Below is some code that plots sepal length against sepal width, coloured by the type of iris.
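x_i = 0  # feature on the x-axis: sepal length
y_i = 1  # feature on the y-axis: sepal width
# Map the colorbar ticks 0, 1, 2 to the iris type names
formatter = plt.FuncFormatter(lambda i, *args: iris.target_names[int(i)])
plt.scatter(iris.data[:, x_i], iris.data[:, y_i], c=iris.target)
plt.colorbar(ticks=[0, 1, 2], format=formatter)
plt.xlabel(iris.feature_names[x_i])
plt.ylabel(iris.feature_names[y_i])
plt.title('Iris Scatter Plot')
plt.show()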
x = iris['data']
y = iris['target']
We can use sklearn’s train_test_split function to split our data into two subsets: training data and testing data. Essentially with this function, we do not need to divide the dataset manually.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.4, random_state=42)
knn = KNeighborsClassifier(n_neighbors=3)
We’ve also defined our model, “knn”, which uses sklearn’s KNeighborsClassifier algorithm, and we have provided a parameter of 3 neighbors.
Next, we fit the model to our x_train and y_train data with the fit() method, and then measure its accuracy on the x_test and y_test data with the score() method. For this particular dataset, you’ll probably get an accuracy score of around 0.98 (about 98%).
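The fit and score steps described above are not shown in the snippets so far, so here is a self-contained sketch of them. It reuses the same split and model settings as above. The variable name accuracy is my own choice for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

iris = load_iris()
x, y = iris["data"], iris["target"]

x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.4, random_state=42
)
print(x_train.shape, x_test.shape)  # (90, 4) (60, 4)

knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(x_train, y_train)             # stores the labelled training examples
accuracy = knn.score(x_test, y_test)  # fraction of test flowers classified correctly
print(f"accuracy: {accuracy:.2f}")
```

Note that score() returns a fraction between 0 and 1, so multiply by 100 if you want a percentage.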
Now, what if we want to predict the class of a flower that is not in the dataset?
We can make an np.array that stores new values for our model to use as test data. For the specific values in the array below, the predict function outputs 0, meaning that the flower is setosa.
x_new_value = np.array([[5.0, 2.9, 1.0, 0.2]])
print(knn.predict(x_new_value)) # prints [0], and since 0 = setosa, the flower is setosa
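If you would rather see the flower's name than its number, one possible approach is to look the prediction up in iris.target_names. The sketch below rebuilds the model so that it runs on its own. The names predicted_index and predicted_name are illustrative.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

iris = load_iris()
x_train, x_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.4, random_state=42
)
knn = KNeighborsClassifier(n_neighbors=3).fit(x_train, y_train)

x_new_value = np.array([[5.0, 2.9, 1.0, 0.2]])
predicted_index = knn.predict(x_new_value)[0]            # class number
predicted_name = str(iris.target_names[predicted_index])  # class name
print(predicted_index, predicted_name)  # 0 setosa

# Share of the 3 nearest neighbours that belong to each class.
print(knn.predict_proba(x_new_value))
```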
Please note that all code within this article is my own code. If you would like to use or reference this code, go to my Github, where the repository is public.
If you liked this article, you may want to check out my MNIST handwritten digit classification, using a neural network in PyTorch :)
Hi, I’m Ashley, a 16-year-old coding nerd and A.I. enthusiast!
I hope you enjoyed reading my article, and if you did, feel free to check out some of my other pieces on Medium :)