K-Nearest Neighbors (KNN) is a simple and intuitive supervised learning algorithm typically used to classify data. In this topic, we will walk you through the basic concepts behind KNNs.
Let's start with a simple exercise. Look at the following set of circles and triangles.
We now introduce a new object that we want to add to one of the two classes we have. Can you guess which class the new object belongs to?
You probably guessed it right: it's a triangle! But why? Our brains classified it correctly because there are two triangles right next to our unknown object. This is exactly how KNNs work.
KNN algorithm
KNNs follow four simple steps to classify an unknown object.
- Pick an integer K, where K > 0.
- Get the distance between the new object and each object in the dataset.
- Sort the data points based on the distance in ascending order, so that the closest objects come first.
- Assign a label to the new object, according to its K nearest neighbors.
In order to classify the object, we take the top K nearest objects and have them "vote" for the new object's class. The new object is then assigned the class with the most votes.
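To make the steps concrete, here is a minimal sketch of the algorithm in Python. The function name knn_classify and the toy coordinates are purely illustrative, not part of any particular library:

```python
from collections import Counter
from math import dist  # Euclidean distance between two points (Python 3.8+)


def knn_classify(new_point, dataset, k):
    """Classify new_point by a majority vote of its k nearest neighbors.

    dataset is a list of (point, label) pairs, where each point is a tuple
    of numeric features.
    """
    # Step 2: get the distance between the new object and each object.
    distances = [(dist(new_point, point), label) for point, label in dataset]

    # Step 3: sort by distance so the closest objects come first.
    distances.sort(key=lambda pair: pair[0])

    # Step 4: let the K nearest neighbors vote and take the majority label.
    votes = [label for _, label in distances[:k]]
    return Counter(votes).most_common(1)[0][0]


# Step 1: pick K. Toy data: circles and triangles in a 2D plane.
data = [((1.0, 1.0), "circle"), ((1.5, 2.0), "circle"),
        ((5.0, 5.0), "triangle"), ((5.5, 4.5), "triangle")]
print(knn_classify((5.2, 4.8), data, k=3))  # -> triangle
```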
So far so good, but two questions remain: how do we choose K, and how do we measure the distance between neighboring objects? These are very important aspects of KNN: you can check the previous topics to learn more about choosing a suitable distance function. In the next section, we will discuss the rationale behind choosing K.
Choosing K
Choosing the ideal K value can be tricky.
A very small K will be easily affected by anomalies. You can see it in the image below, where we choose K=1 and, by coincidence, the nearest object is an anomaly. This classifies the new object as a circle, even though it clearly is a triangle. How do we avoid that? Well, we can try a larger K value.
Let's consider K=2.
This looks slightly better: we now have a triangle and a circle to vote for the new object's class. But this will result in a tie, and the algorithm will not be able to decide on a single label. You should account for such cases in your algorithm. In our specific example, since we have only two categories, choosing K to be odd will avoid a tie.
Now let's try K=3.
It worked! We now have three voters: two triangles and one circle, so with 3-NN we can successfully classify the new object as a triangle.
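If you want to see this effect in code, the knn_classify sketch from above can be reused on a made-up dataset where a single circle "anomaly" sits right next to the new point (the coordinates below are invented for illustration):

```python
# One circle "anomaly" sits right next to the new point at (4.0, 4.0).
data = [((4.2, 4.1), "circle"),                          # the anomaly
        ((4.5, 3.6), "triangle"), ((3.4, 4.6), "triangle"),
        ((1.0, 1.0), "circle"), ((1.5, 0.5), "circle")]

for k in (1, 2, 3):
    print(k, knn_classify((4.0, 4.0), data, k=k))
# K=1 -> circle   (fooled by the anomaly)
# K=2 -> tie: one circle vs. one triangle; this naive sketch just returns
#        whichever label was counted first, so the outcome is arbitrary
# K=3 -> triangle (two triangles outvote the anomaly)
```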
So, we can conclude that choosing a larger K can help us generalize better. For instance, if we had two anomalies in our dataset, having K=5 would help us outvote them. Nevertheless, choosing a very large K is not a good idea either: we must make sure that the algorithm won't take all the data points into consideration.
Back to our dataset! Let's say we have a total of 4 triangles and 7 circles, and we choose K=9. This will make the algorithm always classify the new point as a circle, even if its closest neighbors are all triangles.
A rule of thumb is to have K ≈ √N, where N is the number of data points we have.
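As a rough illustration, a helper like the one below applies the rule; the extra step of nudging the result down to an odd value is just one common convention (our assumption here), not part of the rule itself:

```python
from math import sqrt


def rule_of_thumb_k(n_points):
    """Pick K near sqrt(N), nudged down to an odd value to reduce ties."""
    k = max(1, round(sqrt(n_points)))
    return k if k % 2 == 1 else k - 1


print(rule_of_thumb_k(100))  # sqrt(100) = 10 -> 9
print(rule_of_thumb_k(20))   # sqrt(20) ≈ 4.47 -> 3
```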
Example
Let's try to apply what we've learned so far in a real-life example.
We will use KNNs to predict whether or not a student will get admitted to an Engineering school, based on their high school grades in Maths and Physics.
Our dataset, shown in the figure above, contains the scores of 20 students in Physics and Maths, marked out of 100, along with the corresponding admission decision.
Let's plot this dataset on a scatter plot graph:
Taking a closer look at the graph, we can see that admitted students are mainly clustered towards the upper right corner, while non-admitted students are scattered along the rest of the space. We can also spot some anomalies in the data, such as the red point at the top right (82, 93) and a green point at (67, 39).
Now we will apply KNNs to all possible points in the grid, using the Euclidean distance as the distance function. The Euclidean distance is the length of the line segment that connects two points p and q in Euclidean space, with n being the number of dimensions:

d(p, q) = √((p₁ − q₁)² + (p₂ − q₂)² + … + (pₙ − qₙ)²)

The charts below show the result of applying KNNs with different K values to all possible combinations of Physics and Maths scores.
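Decision-region charts like these are typically produced by classifying every point of a fine grid of scores. A sketch along the following lines shows the idea; it assumes scikit-learn and matplotlib are available, and the X/y values are made up to mimic the data (only the two anomalies at (82, 93) and (67, 39) come from the description above), so this is not the real 20-student table:

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier

# Illustrative (Maths, Physics) scores and admission labels; not the real dataset.
X = np.array([[85, 90], [78, 82], [92, 88], [60, 55], [45, 70], [70, 40],
              [88, 95], [52, 48], [67, 39], [82, 93]])
y = np.array([1, 1, 1, 0, 0, 0, 1, 0, 1, 0])  # 1 = admitted, 0 = not admitted

model = KNeighborsClassifier(n_neighbors=3, metric="euclidean")
model.fit(X, y)

# Classify every point of a grid covering all possible score combinations.
xx, yy = np.meshgrid(np.arange(0, 101), np.arange(0, 101))
grid = np.c_[xx.ravel(), yy.ravel()]
zz = model.predict(grid).reshape(xx.shape)

plt.contourf(xx, yy, zz, alpha=0.3)    # decision regions
plt.scatter(X[:, 0], X[:, 1], c=y)     # the training points
plt.xlabel("Maths score")
plt.ylabel("Physics score")
plt.show()
```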
Can you tell which is the best choice of K in our case?
When briefly looking at the graphs, we can see that with K=1, the boundary precisely wraps the 'Admitted' data points, and the result is affected by the anomalies. Thus, the model won't perform well on unseen data. Small values of K make the model sensitive to outliers, which might result in overfitting, meaning that the model doesn't capture the underlying structure but instead memorizes the specifics of the particular dataset.

As for even Ks (K=2, K=4, K=6), there are cases where they seem to provide good results, but as mentioned earlier, choosing an even K can result in ties. So, depending on the implementation, there might be bias or randomness in the chosen result, or it might even break code that wasn't designed to handle such cases.
We use the rule K ≈ √N, where in our dataset √N = √20 ≈ 4.47, which, rounded to a nearby odd value, supports our choice of picking K=3.
Overall, determining machine learning model parameters (in our case, K and the distance function) is not a straightforward task. It requires lots of experimenting and a good understanding of your data, and the choice is often subject to change. For instance, our model might learn the data well and create a good decision boundary according to the data that was fed to it, but the data itself might not be a good representation of the problem. Such uncertainties leave room for your creativity and innovation!
Conclusion
KNN is an algorithm that is easy to implement and often yields accurate results. It is flexible and can be tweaked to solve different problems including multi-feature classification, regression, and text analysis.
It's worth mentioning that KNN is considered a "lazy algorithm" since it doesn't make any inferences from the data during training, unlike the way logistic regression or neural networks learn weights. Instead, it simply stores all the data points and determines the result during the prediction phase.
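To illustrate this laziness, here is a tiny hypothetical wrapper around the knn_classify sketch from earlier; note that fit does nothing but remember the data:

```python
class LazyKNN:
    """Illustrative lazy learner: training just memorizes the data."""

    def __init__(self, k=3):
        self.k = k

    def fit(self, points, labels):
        # No weights are learned; the "model" is the data itself.
        self.data = list(zip(points, labels))
        return self

    def predict(self, new_point):
        # All the real work (distances, sorting, voting) happens now.
        return knn_classify(new_point, self.data, self.k)
```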
Although KNN can perform well on small datasets, classifying a new data point requires computing its distance to every data point in the dataset, which makes it computationally expensive on large datasets, and storing all those data points consumes extra space as well.