
Introduction to K-Nearest Neighbors


K-Nearest Neighbors (KNN) is a simple and intuitive supervised learning algorithm typically used to classify data. In this topic, we will walk you through the basic concepts behind KNNs.

Let's start with a simple exercise. Look at the following set of circles and triangles.

Illustrating KNN with two classes: circles and triangles

We now introduce a new object that we want to add to one of the two classes we have. Can you guess which class the new object belongs to?

Classify a new object based on its nearest neighbors

You probably guessed it right: it's a triangle! But why? Our brains classified it correctly because there are two triangles right next to our unknown object. This is exactly how KNNs work.

KNN algorithm

KNNs follow four simple steps to classify an unknown object.

  1. Pick an integer K, where K > 0.

  2. Compute the distance between the new object and each object in the dataset.

  3. Sort the data points by distance in ascending order, so that the nearest ones come first.

  4. Assign a label to the new object, according to its K nearest neighbors.

In order to classify the object, we take the top K nearest objects and have them "vote" for the new class. The new object is then assigned the class with the most votes.
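To make these steps concrete, here is a minimal sketch in Python; the helper names and the toy points are made up purely for illustration, not taken from the figures above:

```python
from collections import Counter
import math

def euclidean_distance(a, b):
    """Straight-line distance between two points of equal dimension."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def knn_classify(new_point, dataset, labels, k=3):
    """Classify new_point by a majority vote of its k nearest neighbors."""
    # Step 2: compute the distance to every object in the dataset
    distances = [euclidean_distance(new_point, point) for point in dataset]
    # Step 3: sort the points by distance, nearest first
    nearest = sorted(range(len(dataset)), key=lambda i: distances[i])[:k]
    # Step 4: let the k nearest neighbors vote for the label
    votes = Counter(labels[i] for i in nearest)
    return votes.most_common(1)[0][0]

# Toy data: two "triangles" close to the new object and one distant "circle".
points = [(1.0, 1.0), (1.5, 1.2), (5.0, 5.0)]
classes = ["triangle", "triangle", "circle"]
print(knn_classify((1.2, 1.1), points, classes, k=3))  # -> triangle
```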

So far so good, but there is one problem: how do we choose K, and how do we measure the distance between the neighboring objects? These are very important aspects of KNN: you can check the previous topics to learn more about choosing a suitable distance function. In the next section, we will discuss the rationale behind choosing K.

Choosing K

Choosing the ideal K value can be tricky.

A very small K will be easily affected by anomalies. You can see it in the image below, where we choose K=1 and, by coincidence, the nearest object is an anomaly. As a result, the new object is classified as a circle, while it clearly is a triangle. How do we avoid that? Well, we can try a larger K value.

Classify the new object with one nearest neighbor

Let's consider K=2.

Classify the new object with two nearest neighbors

This looks slightly better: we now have a triangle and a circle to vote for the new object's class. But this will result in a tie, and the algorithm will not be able to decide on a single label. You should account for such cases in your algorithm. In our specific example, since we have only two categories, choosing K to be odd will avoid a tie.
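Here is a hypothetical illustration of such a tie, using Python's `Counter` to count the votes; how you resolve it is up to your implementation:

```python
from collections import Counter

# With K=2, the two nearest neighbors are one triangle and one circle: one vote each.
votes = Counter(["triangle", "circle"])
(first_label, first_count), (_, second_count) = votes.most_common(2)

if first_count == second_count:
    # One possible tie-breaker: fall back to the single nearest neighbor.
    print("Tie: fall back to the nearest neighbor, or simply pick an odd K.")
else:
    print(first_label)
```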

Now let's try K=3.

Classify the new object with three nearest neighbors


It worked! We now have three voters: two triangles and one circle, so with 3-NN we can successfully classify the new object as a triangle.

So, we can conclude that choosing a larger K can help us generalize better. For instance, if we had two anomalies in our dataset, having K=5 would help us outvote them. Nevertheless, choosing a very large K is not a good idea either: we must make sure that the algorithm won't take all the data points into consideration.

Back to our dataset! Let's say we have a total of 4 triangles and 7 circles in our dataset, and we chose K=9. This will make the algorithm always classify the new point as a circle, even if its closest neighbors are triangles.

9-Nearest Neighbors classification result
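A quick vote count shows why; the numbers below come straight from the 4-triangle/7-circle example:

```python
from collections import Counter

# With K=9 and only 4 triangles in the whole dataset, at most 4 of the 9
# voters can ever be triangles, so the remaining 5 or more voters are circles.
best_case_for_triangles = Counter({"triangle": 4, "circle": 5})
print(best_case_for_triangles.most_common(1)[0][0])  # -> circle, no matter where the new point lies
```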

A rule of thumb is to have $K < \sqrt{n}$, where $n$ is the number of data points we have.

There are fancier ways to choose K, for example, by experimenting with different K values on smaller sets of data and choosing the K that predicts the correct labels of the existing data with the highest accuracy. This method is called cross-validation, and the data we experiment on is called the cross-validation dataset.
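A hedged sketch of this idea with scikit-learn is shown below; the generated data and the candidate K values are placeholders, and in practice you would plug in your own feature matrix `X` and labels `y`:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Stand-in dataset; replace with your own feature matrix X and labels y.
X, y = make_classification(n_samples=200, n_features=2, n_informative=2,
                           n_redundant=0, random_state=42)

scores_per_k = {}
for k in range(1, 16, 2):  # odd K values only, to avoid ties with two classes
    model = KNeighborsClassifier(n_neighbors=k)
    scores_per_k[k] = cross_val_score(model, X, y, cv=5).mean()

best_k = max(scores_per_k, key=scores_per_k.get)
print(best_k, round(scores_per_k[best_k], 3))
```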

Example

Let's try to apply what we've learned so far in a real-life example.

We will use KNNs to predict whether or not a student will get admitted to an Engineering school, based on their high school grades in Maths and Physics.

Table of students grades and the corresponding admission decision

Our dataset, shown in the figure above, contains the scores of 20 students in Physics and Maths, marked out of 100, and the corresponding admission decision.

Let's plot this dataset on a scatter plot:

A plot of twenty students' scores in Math and Physics

Taking a closer look at the graph, we can see that admitted students are mainly clustered towards the upper right corner, and non-admitted students are scattered across the rest of the space. We can also spot some anomalies in the data, such as the red point at the top right (82, 93) and the green point at (67, 39).

Now we will apply KNNs to all possible points in the grid, using the Euclidean distance as a distance function. The Euclidean distance is the length of the line that connects two points in Euclidean space. With $n$ being the number of dimensions, it is defined as follows:

$$d(x, y) = \sqrt{\sum_{i=1}^{n}(x_i - y_i)^2}$$

The charts below show the result of applying KNNs with different K values to all possible combinations of Physics and Maths scores.

Can you tell which is the best choice of K in our case?

Choose the best K value using visualization
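Charts like these can be produced roughly as follows with scikit-learn; note that the student scores in this sketch are placeholder values, not the actual table from the figure:

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Placeholder (Maths, Physics) scores and admission labels (1 = admitted, 0 = not).
X = np.array([[85, 90], [78, 82], [90, 88], [60, 55], [45, 70], [50, 40]])
y = np.array([1, 1, 1, 0, 0, 0])

# KNeighborsClassifier uses the Euclidean distance by default (metric="minkowski", p=2).
model = KNeighborsClassifier(n_neighbors=3).fit(X, y)

# Predict a label for every possible combination of scores on a 0-100 grid.
grid_maths, grid_physics = np.meshgrid(np.arange(0, 101), np.arange(0, 101))
grid_points = np.c_[grid_maths.ravel(), grid_physics.ravel()]
predictions = model.predict(grid_points).reshape(grid_maths.shape)
print(predictions.shape)  # (101, 101): one decision per grid point, ready to plot as a boundary
```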

Looking briefly at the graphs, we can see that with K=1, the boundary precisely wraps the 'Admitted' data points, and the result is affected by the anomalies. Thus, the model won't perform well on unseen data. Small values of K make the model sensitive to outliers, which might result in overfitting, meaning that the model doesn't detect the underlying structure but instead captures the specifics of a particular dataset. As for even Ks (K=2, K=4, K=6), there are cases where they seem to provide good results, but as mentioned earlier, choosing an even K can result in ties. So, depending on the implementation, there might be bias or randomness in the results, or ties might even break code that wasn't designed to handle such cases.

We use the rule $K < \sqrt{n}$, where in our dataset $\sqrt{n} = \sqrt{20} \approx 4.47$, which supports our choice of picking K=3.

Overall, determining machine learning model parameters (in our case K and the distance function) is not a straightforward task. It requires lots of experimenting and understanding of your data, and is often subject to change. For instance, our model might learn the data well and create a good decision boundary according to the data that was fed to it, but the data itself might not be a good representative of the problem. Such uncertainties leave room for your creativity and innovation!

Note that since KNN deals with distances, feature scaling (bringing all features to the same scale) will improve the quality of the predictions. For example, if some features lie in the $[0, 10]$ range (e.g., the house age, the number of rooms, etc.) and others lie in the thousands or more (e.g., the house price in USD, the annual income, etc.), the features in the smaller range will be diminished when the distances are computed and won't be properly taken into consideration.
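A hedged sketch of scaling features before KNN using scikit-learn's `Pipeline` follows; the house data below is, again, a stand-in:

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Stand-in features on very different scales: house age in years, price in USD.
X = np.array([[5, 350_000], [30, 120_000], [8, 400_000], [25, 150_000]])
y = np.array([1, 0, 1, 0])

# The scaler brings both features to a comparable scale before distances are computed,
# so the price column no longer drowns out the age column.
model = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=3))
model.fit(X, y)
print(model.predict([[10, 300_000]]))
```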

Conclusion

KNN is an algorithm that is easy to implement and often yields accurate results. It is flexible and can be tweaked to solve different problems including multi-feature classification, regression, and text analysis.

It's worth mentioning that KNN is considered a "lazy algorithm" since it doesn't make any inferences from the data during training, unlike the way logistic regression or neural networks learn weights. Instead, it simply stores all the data points and determines the result during the prediction phase.

Although KNN can perform well on small datasets, we need to calculate the distance between the new data point and each data point in the dataset, which makes it computationally expensive on large datasets, and saving the calculated distances consumes extra space as well.
