
K-Means in scikit-learn


In your day-to-day life, you may have come across problems that require clustering or grouping of entities. For instance, if you work at a supermarket, you may have customers with different spending habits, whom you might group into high-value customers and regular customers. Similarly, items from multiple brands can be grouped into categories like noodles, oil, chocolates, and so on. This grouping of entities is called clustering. Clustering can be done either manually or with machine learning algorithms. In this topic, we will use the k-means algorithm in scikit-learn to cluster iris flowers into their respective species.

Dataset

As mentioned earlier, we will be clustering iris flowers using the Iris dataset, a popular dataset for learning statistical concepts. The Iris dataset has 150 instances of iris flowers belonging to three species: Iris setosa, Iris versicolor, and Iris virginica, with 50 samples of each species. Each sample has four features: sepal length, sepal width, petal length, and petal width (all measurements are in cm). The target is a categorical feature that takes the value 0, 1, or 2, corresponding to the species. For our convenience, we will only use two features, petal length and petal width, which makes it easier to visualize the data in 2D plots.

Let's load the dataset and separate the features from the labels.

from sklearn import datasets

iris = datasets.load_iris()
X = iris.data[:, 2:] # Using the petal length and petal width
y = iris.target

Note that since k-means is an unsupervised algorithm, we won't need the labels to perform clustering.

Splitting the data into train and test sets

In a machine learning project, it is common to have two different sets: the train set to train the model and the test set to test the efficacy of the trained model on unseen data. Similarly, we will split our dataset into a train set and a test set. It's important to note that, in real projects, there is a third set of data, the validation set, which is used to tune the hyperparameters of the model. For now, we will only use the train and the test set.

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)

In general, the random_state parameter ensures reproducibility across multiple re-runs and can be set to any arbitrary value. Here, we set it to 42; in this case, random_state controls how the dataset is shuffled before splitting.
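If you want to verify this, here is a minimal sketch showing that two splits made with the same random_state produce identical arrays:

import numpy as np

# Two splits with the same random_state yield exactly the same partition
X_a, _, y_a, _ = train_test_split(X, y, test_size=0.33, random_state=42)
X_b, _, y_b, _ = train_test_split(X, y, test_size=0.33, random_state=42)
print(np.array_equal(X_a, X_b) and np.array_equal(y_a, y_b))  # True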

Visualizing the train set

Now, let's visualize the train set using matplotlib. We visualize the data to pick the number of clusters, which we will use when initializing the k-means model later on.

import matplotlib.pyplot as plt
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=plt.cm.tab20b)
plt.xlabel("Petal length")
plt.ylabel("Petal width")
plt.show()

Here is the output:

A plot with three species clusters

From a visual inspection of the train set, we can see that there are 3 distinct clusters, one for each of setosa, versicolor, and virginica. As mentioned earlier, we will use the observed number of clusters for instantiating the k-means model.
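If you prefer a programmatic cross-check of this visual choice, a common heuristic is the "elbow" method: fit k-means for several values of k and watch the inertia_ attribute (the within-cluster sum of squared distances), which stops dropping sharply once the natural number of clusters is reached. A minimal sketch:

from sklearn.cluster import KMeans

for k in range(1, 6):
    # n_init is set explicitly because its default differs across scikit-learn versions
    model = KMeans(n_clusters=k, n_init=10, random_state=0).fit(X_train)
    print(k, round(model.inertia_, 2))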

Training the k-means model

Now, we will use the scikit-learn library to create a k-means model with 3 clusters. The model will then be fitted on the train set. K-means is an unsupervised machine learning algorithm, so we won't need to pass the actual outputs, or 'y' values, to train the model.

from sklearn.cluster import KMeans

kmeans = KMeans(n_clusters=3, random_state=0)
kmeans.fit(X_train)

Once again, we set random_state to 0 for reproducibility. Some other important parameters that can be passed to KMeans are (see the sketch after this list):

  • algorithm: the k-means variant to use; the default is "lloyd".

  • tol: relative tolerance, with regard to the Frobenius norm of the difference in the cluster centers of two consecutive iterations, to declare convergence; the default is 0.0001.

  • n_init: the number of times the algorithm is run with different centroid seeds.
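As a sketch, here is how these parameters might be passed explicitly. The values below are the usual defaults (note that in recent scikit-learn versions n_init defaults to "auto" instead of a fixed count); the variable name kmeans_explicit is chosen here to avoid overwriting the fitted model above:

from sklearn.cluster import KMeans

kmeans_explicit = KMeans(
    n_clusters=3,
    algorithm="lloyd",  # the classical Lloyd iteration
    tol=1e-4,           # convergence tolerance on the shift of cluster centers
    n_init=10,          # run 10 times with different centroid seeds, keep the best
    random_state=0,
)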

Similarly, once the model is fitted, you can access the following attributes:

  • cluster_centers_: coordinates of the cluster centers.

  • labels_: the cluster label of each point.

  • n_features_in_: the number of features seen during fitting.

These attributes can be accessed like the attributes of any Python object, e.g. kmeans.n_features_in_ or kmeans.labels_, as shown below.
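For example, after fitting the model above:

print(kmeans.cluster_centers_)  # a (3, 2) array: one (petal length, petal width) centroid per cluster
print(kmeans.labels_[:10])      # cluster assignments of the first 10 training samples
print(kmeans.n_features_in_)    # 2, since we kept only petal length and petal width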

Please refer to the scikit-learn documentation to learn more about the attributes and parameters of the KMeans model.

We can now use the model to assign the test data to their respective clusters.

test_labels = kmeans.predict(X_test)

Finally, let's plot the output predictions of our model on the test set.

plt.scatter(X_test[:, 0], X_test[:, 1], c=test_labels, cmap=plt.cm.tab20b)
plt.xlabel("Petal length")
plt.ylabel("Petal width")
plt.show()

Here is the output:

A plot with predictions for the three clusters

If you look carefully at the output above, you will find that the orange and pink clusters have been interchanged. In the train set, the cluster in the middle is colored orange, but, in the predicted output, the cluster in the top right is orange. Since k-means is an unsupervised algorithm, it may not number the clusters in the same way as the labels in the train set. Nevertheless, it does a good job of separating the data points into appropriate clusters. The cluster labels can always be reassigned, as the sketch below shows.
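For instance, one way to score the clustering without worrying about how the cluster indices are numbered is the adjusted Rand index, and a simple way to reassign the labels is to map each cluster to the most frequent true species among its members. A minimal sketch (note that y_test is used here only for evaluation; k-means itself never sees the labels):

import numpy as np
from sklearn.metrics import adjusted_rand_score

# Permutation-invariant agreement between predicted clusters and true species
print(adjusted_rand_score(y_test, test_labels))

# Map each cluster index to the most frequent true label among its members
mapping = {}
for cluster in np.unique(test_labels):
    mapping[cluster] = np.bincount(y_test[test_labels == cluster]).argmax()

remapped = np.array([mapping[cluster] for cluster in test_labels])
print((remapped == y_test).mean())  # fraction of test points assigned to the correct species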

Conclusion

Here is what you have learned in this topic:

  • K-means is an unsupervised machine learning algorithm that can be used to cluster data into 'k' different groups.

  • K-means doesn't require explicit labels to perform clustering.

  • We can perform data visualization with matplotlib to identify the number of clusters required.

  • The KMeans class in sklearn.cluster can be used to instantiate a k-means algorithm with the required number of clusters.

  • Just like any other algorithm in sklearn, we can use the fit() method to train the model and the predict() method to generate predictions using the model.
