
Agglomerative clustering in scikit-learn


Agglomerative hierarchical clustering is an approach where every data point initially forms its own cluster. In successive steps, the closest clusters are merged until only one cluster, or a designated number of clusters, remains.
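
To make the merging process concrete, here is a minimal sketch with a few made-up toy points (not the dataset used below). SciPy's linkage function records every merge in a matrix, one row per step:

import numpy as np
from scipy.cluster import hierarchy

# Four toy 1D points; each one starts as its own cluster
points = np.array([[1.0], [2.0], [9.0], [10.0]])

# Each row of the linkage matrix describes one merge:
# [cluster index 1, cluster index 2, merge distance, size of the merged cluster]
merges = hierarchy.linkage(points, method="single")
print(merges)

The two close pairs are merged first, and the final step joins the resulting clusters at a much larger distance.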

This topic will explore how hierarchical clustering can be performed using scikit-learn.

The preliminary steps

For the purpose of this topic, we will be working with the digits dataset, which consists of 10 labels (digits from 0 to 9) and 64 features. We will first use t-SNE to reduce the dataset's dimensionality for visualization purposes. Following that, we will apply agglomerative clustering on the dataset, now represented by two dimensions.

A visualization of the dataset

The dataset can be loaded from scikit-learn as follows:

from sklearn.datasets import load_digits
from sklearn.preprocessing import StandardScaler

X, y = load_digits(return_X_y=True)

Since agglomerative clustering is a distance-based approach, feature standardization is required:

scaler = StandardScaler()
X = scaler.fit_transform(X)

As the final step in this section, we will perform dimensionality reduction with t-SNE to visualize in 2 dimensions:

from sklearn.manifold import TSNE

reduced_X = TSNE(n_components=2, random_state=42).fit_transform(X)

Building the dendrogram

A dendrogram is a tree-like graph that visualizes the process of merging the clusters in agglomerative clustering. The organization of the branches in a dendrogram indicates the level of similarity between the clusters. The height of the branching indicates the degree of similarity or dissimilarity: a higher branching point suggests a greater difference between the clusters.

from scipy.cluster import hierarchy
import matplotlib.pyplot as plt

clusters = hierarchy.linkage(reduced_X, method="ward")

plt.figure(figsize=(8, 6))
dendrogram = hierarchy.dendrogram(clusters)

plt.xticks([])
plt.xlabel("Samples")
plt.ylabel("Cluster distance")

plt.grid()

The snippet above results in the following dendrogram:

A dendrogram of the dataset

The granularity, or the number of clusters, increases as we start from the top of the dendrogram and move down. A horizontal line can be used to set the distance threshold. The number of clusters corresponds to the number of intersections between the clades (the vertical lines) and the distance threshold (the horizontal line). For example, if we set the cluster distance at 600, there are 4 intersections with the clades, indicating that there are 4 clusters present for the specified threshold.
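
As a minimal sketch, this can be checked programmatically by cutting the linkage matrix from the snippet above with SciPy's fcluster (600 is the example threshold discussed here):

import numpy as np

# Form flat clusters by cutting the dendrogram at the chosen distance threshold
labels_at_600 = hierarchy.fcluster(clusters, t=600, criterion="distance")
print(np.unique(labels_at_600).size)  # expected to match the count read off the dendrogram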

A brief note on how the cluster distance is calculated

The distance between the clusters is defined by the linkage criterion used. For example, single linkage defines the cluster distance as the minimum distance between two points belonging to different clusters. For two clusters A and B:

d(A, B) = min { d(a, b) : a ∈ A, b ∈ B }
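
For instance, here is a small illustrative computation of the single linkage distance for two made-up point sets (the coordinates are arbitrary):

import numpy as np
from scipy.spatial.distance import cdist

# Two small hypothetical clusters in 2D
A = np.array([[0.0, 0.0], [1.0, 0.0]])
B = np.array([[3.0, 0.0], [4.0, 1.0]])

# Single linkage: the smallest pairwise distance between a point of A and a point of B
print(cdist(A, B).min())  # 2.0, the distance between [1, 0] and [3, 0]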

It's important to note that trying to infer the most suitable distance threshold, and thus the optimal number of clusters, from the dendrogram is a heuristic approach. It does not guarantee an optimal solution, but it can be useful for understanding the data better, such as the order of merges or the distance/similarity between the clusters.

Fitting the model

Now, fitting and predicting is straightforward:

from sklearn.cluster import AgglomerativeClustering

clustering_model = AgglomerativeClustering(n_clusters=10, linkage="ward")
labels = clustering_model.fit_predict(reduced_X)

Below, we outline some of the most important AgglomerativeClustering() hyperparameters:

  • n_clusters (default: 2, int/ None) — the number of clusters to find. The algorithm will stop merging clusters when this number is reached. If we have a specific distance threshold, n_clusters should be set to None.
  • distance_threshold (default: None, float) — the linkage distance threshold above which the clusters will stop merging. When set to a float value, the n_clusters should be set to None and compute_full_tree should be set to True.
  • metric (default: None, str/ callable ) — the metric used to compute the distances. None corresponds to the Euclidean metric. str accepts any valid metric from the sklearn.metrics.pairwise module, and callable allows passing a custom metric function.
  • linkage (default: ward , ward/ complete/ average/ single) — determines which distance to use between sets of observations. The algorithm will merge the pairs of clusters that minimize this criterion.
    1. ward minimizes the within-cluster variance of the clusters being merged (in short, ward measures the distance between two clusters A and B as the increase in the total within-cluster sum of squares caused by merging them, and aims to minimize this increase);
    2. average uses the average of the distances of each observation of the two sets;
    3. complete uses the maximum distances between all observations of the two sets;
    4. single uses the minimum of the distances between all observations of the pairs of sets.
  • compute_full_tree (default: auto, bool/ auto) — controls early stopping of the tree construction at the n_clusters value. Must be True (or left at the default auto) if distance_threshold is not None.
  • connectivity (default: None, array-like/ callable) — a connectivity matrix (or a callable that builds one, for example from kneighbors_graph) defining which samples are neighbors and may be merged. Connectivity constraints impact the cluster shapes, the sensitivity to outliers, and the overall structure of the hierarchy; a usage sketch follows this list.
    The effect of the connectivity constraints on the noisy circles dataset and various linkages
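
As a rough usage sketch of the parameters above, the model can be driven by a distance threshold instead of a fixed cluster count, and a connectivity constraint can be built with kneighbors_graph (the threshold of 600 and the choice of 10 neighbors are illustrative values for this example, not recommendations):

from sklearn.cluster import AgglomerativeClustering
from sklearn.neighbors import kneighbors_graph

# Stop merging once the linkage distance exceeds the threshold
threshold_model = AgglomerativeClustering(
    n_clusters=None, distance_threshold=600, linkage="ward"
)
threshold_model.fit(reduced_X)
print(threshold_model.n_clusters_)  # number of clusters found for this threshold

# Constrain merges to the 10 nearest neighbors of each sample
connectivity = kneighbors_graph(reduced_X, n_neighbors=10, include_self=False)
constrained_model = AgglomerativeClustering(
    n_clusters=10, linkage="ward", connectivity=connectivity
)
constrained_labels = constrained_model.fit_predict(reduced_X)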

Visualizing the results

After fitting, the cluster labels can be accessed via the .labels_ attribute and plotted (alternatively, the labels variable from the previous code block could be used):

import seaborn as sns

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))

sns.scatterplot(
    ax=axes[0], data=reduced_X, x=reduced_X[:, 0], y=reduced_X[:, 1]
).set_title("Without clustering")
sns.scatterplot(
    ax=axes[1],
    data=reduced_X,
    x=reduced_X[:, 0],
    y=reduced_X[:, 1],
    hue=clustering_model.labels_,
).set_title("n_clusters = 10")

Two scatter plots: the original dataset and the clustering results

Since the ground truth labels are available, let's calculate the adjusted Rand score, which measures the similarity between two clusterings. A score of 1 indicates perfect agreement between the ground truth and the predicted labels, while random labelings score close to 0:

from sklearn.metrics.cluster import adjusted_rand_score

print(round(adjusted_rand_score(y, labels), 3))

We obtained a score of 0.775.

The effects of various linkage methods

The choice of linkage has a strong effect on agglomerative clustering. Let's see how different linkage types affect the clustering results:

def plot_linkage():
    linkages = ("Ward", "Complete", "Average", "Single")

    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))
    axes = axes.flatten()

    for i, linkage in enumerate(linkages):
        clustering = AgglomerativeClustering(linkage=linkage.lower(), n_clusters=10)
        clustering.fit(reduced_X)
        ari_score = round(adjusted_rand_score(y, clustering.labels_), 3)
        
        sns.scatterplot(
            ax=axes[i],
            data=reduced_X,
            x=reduced_X[:, 0],
            y=reduced_X[:, 1],
            hue=clustering.labels_,
        ).set_title(f"{linkage} (Adjusted Rand = {ari_score})")

plot_linkage()

The effect of different linkages on the clustering results with the adjusted Rand scores

One can observe that Ward, average, and complete linkages tend to produce the most evenly distributed clusters. With the single linkage, one cluster encompasses a large portion of the digits.
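
To check this beyond the plots, one could also print the cluster sizes produced by each linkage (a quick sketch reusing the variables defined above):

import numpy as np

# Number of samples assigned to each of the 10 clusters, per linkage
for linkage in ("ward", "complete", "average", "single"):
    clustering = AgglomerativeClustering(linkage=linkage, n_clusters=10)
    clustering.fit(reduced_X)
    print(linkage, np.bincount(clustering.labels_))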

Conclusion

You are now familiar with the usage of the AgglomerativeClustering() class, its main parameters, and how to visualize the dendrogram and the clustering results.
