KNN from Scratch

The K-Nearest Neighbors (KNN) classifier is one of the easiest classification methods to understand and one of the most basic classification models available.

KNN is a non-parametric method that classifies a point based on its distance to the training samples.

KNN is called a lazy algorithm. Technically, it does not build a model from the training data; i.e., it does not really learn anything in the training phase.

Instead, the training phase simply stores the training data in memory, and all the work happens in the testing phase.

In this tutorial, we’ll implement KNN from scratch using NumPy.


Let’s take a look at K-nearest neighbors from a graphical perspective. 

Let’s suppose that we have a dataset with two classes, circles and triangles. Visually, it looks like the following.

[Figure: a two-class dataset of circles and triangles]

Now let’s say we have a mystery point whose class we need to predict.

To find out which class it belongs to, we compute the (Euclidean) distance from the mystery point to every training sample and select the K nearest neighbors.
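For reference, the Euclidean distance between a query point x and a training sample y, each described by n features, is the square root of the summed squared feature differences:

```latex
d(\mathbf{x}, \mathbf{y}) = \sqrt{\sum_{i=1}^{n} (x_i - y_i)^2}
```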

K is the number of nearest training samples considered when predicting the label of an unlabeled test point.

The class label of the new point is determined by a majority vote of its K nearest neighbors: the new point is assigned to the class with the most votes.

For example, if we choose K = 3, the three closest neighbors of the new observation are two circles and one triangle.

[Figure: the three nearest neighbors of the mystery point]

Therefore, by majority vote, the mystery point is classified as a circle.
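The vote in this example can be reproduced in a couple of lines with Python’s `collections.Counter`. This is only an illustration using the neighbor labels from the example above; the implementation later in this tutorial counts votes with a plain dictionary instead.

```python
from collections import Counter

# Classes of the three nearest neighbors in the example (K = 3)
neighbor_labels = ["circle", "circle", "triangle"]

# Count the votes and take the class with the most votes
votes = Counter(neighbor_labels)
predicted_class, n_votes = votes.most_common(1)[0]
print(predicted_class, n_votes)  # circle 2
```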

The KNN algorithm can be summarized as follows:

  1. Calculate the distances between the new input and all the training data.
  2. Select the K nearest neighbors based on these distances.
  3. Classify the point based on a majority vote.
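The three steps above can be sketched for a single query point as follows. The function name `knn_classify_point` and the toy data are hypothetical, and this sketch tallies the vote with `np.unique` rather than a dictionary; with this choice, ties are broken in favor of the smallest label.

```python
import numpy as np

def knn_classify_point(x_new, X_train, y_train, k=3):
    """Classify one point with the three KNN steps (illustrative sketch)."""
    # 1. Euclidean distance from the new input to every training sample
    distances = np.sqrt(np.sum((X_train - x_new) ** 2, axis=1))
    # 2. Indices of the k nearest neighbors
    nearest = np.argsort(distances)[:k]
    # 3. Majority vote among their labels (np.unique returns sorted labels)
    labels, counts = np.unique(y_train[nearest], return_counts=True)
    return labels[np.argmax(counts)]

# Toy data: class 0 near the origin, class 1 near (5, 5)
X_train = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0],
                    [5.0, 5.0], [6.0, 5.0], [5.0, 6.0]])
y_train = np.array([0, 0, 0, 1, 1, 1])

print(knn_classify_point(np.array([0.5, 0.5]), X_train, y_train))  # 0
print(knn_classify_point(np.array([5.5, 5.5]), X_train, y_train))  # 1
```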

Now let’s create a simple KNN from scratch using Python.

First, let’s import the modules we’ll need and create a distance function that calculates the Euclidean distance between two points.
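A minimal sketch of this setup, assuming NumPy and the standard-library `operator` module are the modules needed. The name `euclidean_distance` is an assumption. Summing over the last axis lets the same function compute either one distance or, with broadcasting, the distances from one point to many.

```python
import operator  # imported here because the vote sorting later uses operator.itemgetter
import numpy as np

def euclidean_distance(x1, x2):
    """Euclidean distance along the last axis (works for one pair or many rows)."""
    x1 = np.asarray(x1, dtype=float)
    x2 = np.asarray(x2, dtype=float)
    return np.sqrt(np.sum((x1 - x2) ** 2, axis=-1))

print(euclidean_distance([0, 0], [3, 4]))            # 5.0
print(euclidean_distance([[0, 0], [6, 8]], [0, 0]))  # [ 0. 10.]
```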

Now let’s define the class.

In the `__init__` method, we’ll initialize K (the number of nearest neighbors) to 3.

Since KNN does not build a model from the training data, the fit method just stores the values.
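A sketch of the class skeleton, assuming the class is named `KNN` and the stored attributes are called `X_train` and `y_train` (the attribute names are assumptions). Having `fit` return `self` is an optional convenience that allows chaining.

```python
import numpy as np

class KNN:
    """K-nearest neighbors classifier (skeleton: constructor and fit only)."""

    def __init__(self, K=3):
        # Number of nearest neighbors that vote on each prediction
        self.K = K

    def fit(self, X_train, y_train):
        # KNN is lazy: "training" just stores the data for later use
        self.X_train = np.asarray(X_train, dtype=float)
        self.y_train = np.asarray(y_train)
        return self

model = KNN().fit([[0, 0], [1, 1]], [0, 1])
print(model.K, model.X_train.shape)  # 3 (2, 2)
```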

Now let’s define our predict method.

First, we create an empty list to store all our predictions.

Next, we loop over all the points in the test set.

For each test point, we calculate the distance between it and every point in the training set.

Then we sort the distances using `argsort()` and keep the first K entries in a list, `dist_sorted`. Because `argsort()` returns indices rather than the distances themselves, these are the indices of the K nearest training points.

We then create an empty dictionary that maps each neighbor’s class to its vote count.

We loop through all the indices in `dist_sorted` and add each neighbor’s class to the dictionary.

If the neighbor’s class is already present in the dictionary, we increment its count; otherwise, we set it to 1.

Next, we sort the dictionary in descending order of its values. The values in the dictionary are the number of votes each class received.

Passing `key=operator.itemgetter(1)` tells `sorted()` to order the dictionary’s (class, count) items by their second element, i.e., by the vote count.
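The counting and sorting steps can be tried in isolation. The list of neighbor classes below is hypothetical (K = 5), and the variable names are illustrative only.

```python
import operator

# Hypothetical classes of the K = 5 nearest neighbors of one test point
neighbor_classes = [1, 0, 1, 1, 0]

# Count the votes: increment a class already present, otherwise start it at 1
votes = {}
for label in neighbor_classes:
    if label in votes:
        votes[label] += 1
    else:
        votes[label] = 1

# Sort (class, count) pairs by count, highest first
sorted_votes = sorted(votes.items(), key=operator.itemgetter(1), reverse=True)
print(sorted_votes)        # [(1, 3), (0, 2)]
print(sorted_votes[0][0])  # 1
```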

Finally, we append the class label with the most votes to our list of predictions.
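Putting these steps together, a `predict` method might look like the sketch below. Only `dist_sorted` and the use of `argsort()` and `operator.itemgetter(1)` come from the description above. The other names (`neigh_count`, `sorted_neigh_count`), the `dict.get` shortcut for the increment-or-set-to-1 logic, and returning a NumPy array are assumptions.

```python
import operator
import numpy as np

class KNN:
    """K-nearest neighbors classifier following the steps described above."""

    def __init__(self, K=3):
        self.K = K  # number of neighbors that vote

    def fit(self, X_train, y_train):
        # Lazy learner: just store the training data
        self.X_train = np.asarray(X_train, dtype=float)
        self.y_train = np.asarray(y_train)
        return self

    def predict(self, X_test):
        predictions = []  # one predicted label per test point
        for x in np.asarray(X_test, dtype=float):
            # Euclidean distance from this test point to every training point
            distances = np.sqrt(np.sum((self.X_train - x) ** 2, axis=1))
            # argsort returns indices; keep those of the K nearest points
            dist_sorted = np.argsort(distances)[:self.K]
            # Count the votes for each neighbor's class
            neigh_count = {}
            for idx in dist_sorted:
                label = self.y_train[idx]
                neigh_count[label] = neigh_count.get(label, 0) + 1
            # Sort classes by vote count, highest first
            sorted_neigh_count = sorted(neigh_count.items(),
                                        key=operator.itemgetter(1), reverse=True)
            predictions.append(sorted_neigh_count[0][0])
        return np.array(predictions)

X_train = [[0, 0], [1, 0], [0, 1], [5, 5], [6, 5], [5, 6]]
y_train = [0, 0, 0, 1, 1, 1]
print(KNN(K=3).fit(X_train, y_train).predict([[0.5, 0.5], [5.5, 5.5]]))  # [0 1]
```

Because Python’s `sorted()` is stable, a tie in votes goes to whichever tied class was counted first, which in this sketch is the class of the nearer neighbor.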

Let’s test out the KNN model. As usual, let’s first import the necessary modules.

Now we can create a dataset with 150 points and 2 classes and split the data into training and test sets.

Now we can call the fit and predict methods of our KNN class.
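The dataset creation, split, and evaluation can be sketched with NumPy alone. Several details here are assumptions: the two Gaussian blobs, the random seed, the 80/20 split, and using accuracy as the metric. The original may well use a library helper for the data and the split. The KNN class is repeated in compact form so the snippet runs on its own.

```python
import operator
import numpy as np

class KNN:
    """Compact copy of the KNN class so this example runs on its own."""
    def __init__(self, K=3):
        self.K = K

    def fit(self, X_train, y_train):
        self.X_train, self.y_train = np.asarray(X_train, float), np.asarray(y_train)
        return self

    def predict(self, X_test):
        predictions = []
        for x in np.asarray(X_test, float):
            distances = np.sqrt(np.sum((self.X_train - x) ** 2, axis=1))
            dist_sorted = np.argsort(distances)[:self.K]
            votes = {}
            for idx in dist_sorted:
                votes[self.y_train[idx]] = votes.get(self.y_train[idx], 0) + 1
            best = sorted(votes.items(), key=operator.itemgetter(1), reverse=True)
            predictions.append(best[0][0])
        return np.array(predictions)

# 150 points in 2 classes: two Gaussian blobs (assumed stand-in for the original data)
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(loc=[0, 0], scale=1.0, size=(75, 2)),
               rng.normal(loc=[4, 4], scale=1.0, size=(75, 2))])
y = np.array([0] * 75 + [1] * 75)

# Shuffle, then hold out 20% of the points as the test set
perm = rng.permutation(len(X))
n_test = len(X) // 5
test_idx, train_idx = perm[:n_test], perm[n_test:]
X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

# Fit (store the data), predict, and measure accuracy on the held-out points
model = KNN(K=3).fit(X_train, y_train)
y_pred = model.predict(X_test)
accuracy = np.mean(y_pred == y_test)
print(f"Test accuracy: {accuracy:.2f}")
```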

The complete code for this tutorial can be found in this GitHub repo.
