
Now that you have some experience working with various plots in the matplotlib library, it's time to consolidate all that information. In this topic, we'll sum up what you already know and give you some tips and life hacks on how to visualize correlations. Let's start!

Correlation basics

If you've reached this topic, chances are you already know what a correlation is. A correlation describes the relationship between two variables. If one variable tends to increase as the other increases, that's a positive correlation. If one variable tends to decrease as the other increases, that's a negative correlation.

There are several different correlation coefficients, but the most popular one is Pearson's correlation coefficient (a.k.a. Pearson's r). If someone mentions a correlation without specifying the coefficient, they most likely mean Pearson's r, and that's the one we'll use in this topic too. One important thing: Pearson's correlation works with numeric data only, and it measures how close the relationship is to a straight line (a linear relationship). Techniques for finding associations in categorical data are more advanced.
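For reference, Pearson's r is the covariance of two variables divided by the product of their standard deviations. The minimal sketch below computes it directly with NumPy on a small made-up sample (the numbers are purely illustrative) and checks the result against np.corrcoef(), which returns the full 2x2 correlation matrix.

```python
import numpy as np

def pearson_r(x, y):
    """Pearson's correlation coefficient of two equal-length numeric sequences."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # deviations of each value from its mean
    dx = x - x.mean()
    dy = y - y.mean()
    # sum of the products of deviations, divided by the
    # square root of the product of the sums of squared deviations
    return (dx * dy).sum() / np.sqrt((dx ** 2).sum() * (dy ** 2).sum())

# made-up sample in which y tends to grow with x
x = [1, 2, 3, 4, 5]
y = [2, 4, 5, 4, 6]

r = pearson_r(x, y)
print(round(r, 3))                        # 0.853
# the off-diagonal entry of np.corrcoef's 2x2 matrix is the same r
print(round(np.corrcoef(x, y)[0, 1], 3))  # 0.853
```

You won't need to write this yourself in practice, since pandas computes it for you, but it shows there's no magic behind the number.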

Correlation values always range from -1 to +1, where -1 means a perfect negative correlation and +1 means a perfect positive correlation. A value of 0 means no correlation, in other words, no linear relationship between the variables. Here's a table that can help you interpret the correlation values:

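+--------------+-------------------------------+
| Range        | Meaning                       |
|--------------+-------------------------------|
| 0.7 to 1.0   | a strong positive correlation |
| 0.3 to 0.7   | a weak positive correlation   |
| -0.3 to 0.3  | a negligible correlation      |
| -0.7 to -0.3 | a weak negative correlation   |
| -1.0 to -0.7 | a strong negative correlation |
+--------------+-------------------------------+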
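If you find yourself checking this table often, you could encode it in a small helper. The function below, interpret_correlation, is our own hypothetical helper, not part of any library. Since the ranges in the table share their endpoints, we assume that a value sitting exactly on a boundary (for example, 0.7 or -0.3) belongs to the stronger category.

```python
def interpret_correlation(r):
    """Map a correlation coefficient to the verbal labels of the interpretation table.

    Assumption: boundary values (|r| == 0.3 or |r| == 0.7) go to the stronger category.
    """
    if not -1.0 <= r <= 1.0:
        raise ValueError(f"a correlation coefficient must lie in [-1, 1], got {r}")
    strength = abs(r)
    if strength >= 0.7:
        label = "strong"
    elif strength >= 0.3:
        label = "weak"
    else:
        return "a negligible correlation"
    direction = "positive" if r > 0 else "negative"
    return f"a {label} {direction} correlation"

for r in (0.7, 0.42, -0.22, -0.77):
    print(r, "->", interpret_correlation(r))
```

For these four values, the loop prints a strong positive, a weak positive, a negligible, and a strong negative correlation, respectively.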

Scatterplot

The simplest way to visualize a correlation is to use a scatterplot. You don't even need to calculate a coefficient! A scatterplot is a plot that uses dots to show values for two numeric variables. It's a good way to see if there's any association between the variables. However, keep in mind that it's not very precise.

In the next sections, we'll learn to calculate Pearson's coefficient, but for now, let's give the scatterplot a try.

As an example, we'll use the Boston Housing dataset. It's one of the datasets that come with the R programming language. To access it in Python, we've installed the pydataset library (pip install pydataset), imported it into our code, and saved the dataset in a variable named boston. Don't forget to import pandas as well, since we need it to work with the dataset, and, of course, matplotlib. Here's how you can do this:

import matplotlib.pyplot as plt
import pandas as pd
from pydataset import data

# loading the dataset
boston = data('Boston')

This dataset contains lots of information, but for our example, we are only interested in four variables: the concentration of nitrogen oxides in the air (nox), the weighted distance to five Boston employment centers (dis), the average number of rooms per dwelling (rm), and the median value of owner-occupied homes in thousands of dollars (medv).

Let's have a look at an example of a positive correlation. In our case, it's the relationship between the number of rooms (rm in the dataset) and the median value of houses (medv). Here's the code:

# create a scatterplot
plt.scatter(boston['rm'], boston['medv'])

# add a title and labels to the axes
plt.title('The correlation between average number of rooms and median value of homes', pad=20)
plt.xlabel('Average number of rooms')
plt.ylabel('Median value of a home ($1000s)')

# display the plot
plt.show()

The pad argument sets the distance (in points) between the title and the top of the plot.

Let's look at the plot:

A scatterplot of positively correlated features

Most of the data points form something like a line going from the lower-left corner to the upper-right. That's an indicator of a strong positive correlation. Do you remember what a positive correlation means? Right, homes with more rooms tend to have a higher value. In this case, the correlation value is about 0.7, which is quite high.
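The 0.7 value is Pearson's r for this pair. If you want to check a single pair of columns yourself, pandas has the Series.corr() method, which computes Pearson's r by default. The sketch below uses a small made-up DataFrame with rm and medv columns instead of the real dataset, so the printed number is not the Boston value; with the real data, the equivalent call would be boston['rm'].corr(boston['medv']).

```python
import pandas as pd

# made-up stand-in for two Boston columns (NOT the real data)
df = pd.DataFrame({
    'rm':   [5.5, 6.0, 6.2, 6.5, 7.0, 7.8],
    'medv': [15.0, 19.5, 21.0, 24.0, 30.5, 44.0],
})

# Series.corr() uses Pearson's method unless told otherwise
r = df['rm'].corr(df['medv'])
print(f"r = {r:.2f}")
```

The result doesn't depend on the order of the columns: df['medv'].corr(df['rm']) gives the same value.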

Now let's have a look at a negative correlation. For this, we've picked the relationship between the concentration of nitrogen oxides in the air (the nox variable) and the distance to the employment centers (the dis variable). Here's the code:

# create a scatterplot
plt.scatter(boston['nox'], boston['dis'])

# add a title and labels to the axes
plt.title('The correlation between NOx concentration and distance to the employment centers', pad=20)
plt.xlabel('NOx concentration (parts per 10 million)')
plt.ylabel('Distance to the employment centers')

# display the plot
plt.show()

Our scatterplot will look like this:

A scatterplot of negatively correlated features

A good number of dots form a line going from the upper-left corner to the lower-right. This indicates a strong negative correlation: as one variable increases, the other tends to decrease. In our case, areas with a higher NOx concentration tend to be closer to the employment centers. The correlation value is about -0.77. As you can see, scatterplots work best for spotting strong correlations; weaker ones might not be as visible. If the dots are scattered chaotically, or if they form a line parallel to one of the axes, there's little to no correlation.
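To see how the strength and direction of a correlation show up on a scatterplot without relying on a specific dataset, you can generate data yourself. The sketch below draws made-up samples with NumPy (the seed and the 0.5 noise factor are arbitrary choices), plots a positive, a negative, and an unrelated pair side by side with plt.subplots(), and puts each pair's Pearson's r, computed with np.corrcoef(), into the panel title.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
x = rng.normal(size=200)
noise = rng.normal(size=200)

# three made-up y variables: rising with x, falling with x, and unrelated to x
datasets = {
    'positive': x + 0.5 * noise,
    'negative': -x + 0.5 * noise,
    'none': noise,
}

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
coefficients = {}
for ax, (name, y) in zip(axes, datasets.items()):
    # Pearson's r for this pair
    r = np.corrcoef(x, y)[0, 1]
    coefficients[name] = r
    ax.scatter(x, y, s=10)
    ax.set_title(f'{name}: r = {r:.2f}')
    ax.set_xlabel('x')
    ax.set_ylabel('y')

fig.tight_layout()
plt.show()
```

Try increasing the noise factor: the clouds get rounder and the coefficients move toward 0.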

Finally, to gain a better understanding of what correlation looks like on a scatterplot, check out an interactive visualization by Kristoffer Magnusson. Try dragging the data points to see how each of them affects the correlation. You can also change the correlation value in the field to see how different values look on a plot.

Now that you know how to use scatterplots, let's move on to the other methods!

Correlation matrix with pandas

From now on, we'll be more precise and actually calculate the correlation. The pandas library has a dedicated method for that, so you don't need to do the math yourself.

What's more, you can create a pretty neat visualization with pandas alone! We know this topic is about matplotlib, but if you're in a hurry, give it a try.

Let's get back to our house prices dataset. For the sake of simplicity, we'll leave out most of the columns and work with only five variables: crim (crime rate), nox (concentration of nitrogen oxides in the air), rm (average number of rooms), dis (distance to the employment centers), and medv (median house value).

To calculate the correlation between them all, we only need the corr() method:

import pandas as pd
from pydataset import data

# load the dataset
boston = data('Boston')
# leave only the columns we need
boston = boston[['crim', 'nox', 'rm', 'dis', 'medv']]

# calculate the correlation
boston.corr()

As a result, you get a correlation matrix, a matrix with the correlation values for all variable pairs:

+------+-----------+-----------+-----------+-----------+-----------+
|      |      crim |       nox |        rm |       dis |      medv |
|------+-----------+-----------+-----------+-----------+-----------|
| crim |  1        |  0.420972 | -0.219247 | -0.37967  | -0.388305 |
| nox  |  0.420972 |  1        | -0.302188 | -0.76923  | -0.427321 |
| rm   | -0.219247 | -0.302188 |  1        |  0.205246 |  0.69536  |
| dis  | -0.37967  | -0.76923  |  0.205246 |  1        |  0.249929 |
| medv | -0.388305 | -0.427321 |  0.69536  |  0.249929 |  1        |
+------+-----------+-----------+-----------+-----------+-----------+
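Two handy details about corr(): it uses Pearson's method by default, so df.corr() and df.corr(method='pearson') return the same matrix, and you can read a single coefficient from the result by row and column label with .loc. The sketch below uses a small made-up DataFrame with three of the same column names, so its numbers differ from the table above.

```python
import numpy as np
import pandas as pd

# made-up stand-in with some of the column names used above (NOT the real data)
df = pd.DataFrame({
    'nox':  [0.40, 0.45, 0.50, 0.55, 0.60, 0.70],
    'dis':  [7.5, 6.0, 5.2, 4.1, 3.0, 1.9],
    'medv': [30.0, 26.5, 22.0, 21.0, 18.5, 14.0],
})

# Pearson's r for every pair of columns
corr = df.corr()

# read one coefficient by row label and column label
print(round(corr.loc['nox', 'dis'], 2))

# a correlation matrix is symmetric, with ones on the diagonal
print(np.allclose(corr.values, corr.values.T))
print(np.allclose(np.diag(corr.values), 1.0))
```

Because the matrix is symmetric, corr.loc['dis', 'nox'] returns the same value as corr.loc['nox', 'dis'].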

A perfect way to take a quick look at the data, right? The nice thing is that pandas lets you add colors to it! Save your matrix in a variable (called corr in our example) and use its style property, which gives you access to styling methods such as background_gradient(). Here's how you can do that:

corr = boston.corr()
corr.style.background_gradient(cmap='Spectral', axis=None)

The list of available colormaps can be found in the official matplotlib documentation. axis=None means that the colors are computed from the values of the entire matrix rather than column by column or row by row, so the same color always corresponds to the same value.

There's a lot more you can do with pandas styles. However, the other parameters are less straightforward, so if you need advanced styling, we advise you to use matplotlib. You can read about pandas styling in the official documentation.

Now let's finally look at our matrix:

A colored correlation matrix that assigns different colors to lower and higher correlation coefficients

Does it remind you of anything? Right, it's just like a heatmap! It's time to get back to matplotlib and move on to our last correlation visualization technique.

Correlation heatmap

Now let's learn the most popular way of plotting correlation data: a heatmap. It's a plot that uses color to convey values, which is just what we need for a correlation matrix!

As you remember, the main function we need here is plt.imshow(). We can pass it our correlation dataframe, and that's it! Here's the code:

import matplotlib.pyplot as plt
import pandas as pd
from pydataset import data

# load the dataset
boston = data('Boston')
# leave only the columns we need
boston = boston[['crim', 'nox', 'rm', 'dis', 'medv']]

# calculate the correlation
corr = boston.corr()
# create a heatmap
plt.imshow(corr)
plt.show()

The resulting plot is a valid heatmap, but it's somewhat hard to interpret:

A basic correlation heatmap of the Boston dataset

How can we make it better? Right, by adding labels to the x-ticks, y-ticks, and cells. A color bar on the side and a title can come in handy, too. Here's the code:

# create a heatmap and choose a colormap
plt.imshow(corr, cmap='coolwarm')

# add a colorbar
plt.colorbar()
# change the graph size
plt.gcf().set_size_inches(7, 7)

# add x-ticks, y-ticks, and a title
ticks = ['crime rate', 'nitrogen oxides', 'number of rooms', 'distance to the center', 'value of a house']
plt.xticks(range(len(corr.columns)), ticks, fontsize=12, rotation=90)
plt.yticks(range(len(corr.columns)), ticks, fontsize=12)
plt.title('House prices in Boston', fontsize=16, pad=20)

# add labels to the cells
labels = corr.values
for row in range(labels.shape[0]):
    for col in range(labels.shape[1]):
        # imshow draws element [row, col] at x=col, y=row
        plt.text(col, row, '{:.3f}'.format(labels[row, col]), ha='center', va='center', color='black')

plt.show()

To change the number of decimal places shown in the cells, change the format specifier inside plt.text(): {:.2f} for two digits after the decimal point, {:.3f} for three, and so on.
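Here's a quick illustration of what those format specifiers do to a single coefficient; we took 0.69536, the rm/medv value from the correlation matrix. The last line shows the same specifier inside an f-string, which you could use in plt.text() instead of str.format().

```python
value = 0.69536  # the rm/medv coefficient from the correlation matrix

print('{:.2f}'.format(value))  # 0.70  (two digits after the decimal point)
print('{:.3f}'.format(value))  # 0.695 (three digits)
print(f'{value:.1f}')          # 0.7   (one digit, f-string syntax)
```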

Now, our heatmap looks like this:

Customized correlation heatmap: color bar, tick labels, and text in cells

It's the most informative plot we've seen so far, right? By looking at it, you can tell that there's a weak positive correlation (0.421) between the NOx concentration and the crime rate, and only a negligible correlation (0.205) between the distance to the employment centers and the number of rooms.

As you can see, it's quite similar to the visualization we created with pandas in the previous section. The main difference is that matplotlib gives you far more control over the look of the plot: you can change its size, the placement of elements, the x-ticks and y-ticks, and so on.
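If you plot correlation heatmaps often, you might wrap these steps in a function. The sketch below is one possible version built on matplotlib's object-oriented interface (fig, ax = plt.subplots() instead of the plt.* state functions); the name plot_corr_heatmap and its parameters are our own hypothetical choices, not part of matplotlib. It also passes vmin=-1 and vmax=1 to imshow(), so that colors mean the same thing across different heatmaps. The usage example runs on made-up random data, not the Boston dataset.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def plot_corr_heatmap(corr, labels=None, decimals=2, ax=None):
    """Draw a correlation matrix as an annotated heatmap and return the Axes.

    corr     -- a square DataFrame, e.g. the result of DataFrame.corr()
    labels   -- optional tick labels; defaults to the column names
    decimals -- number of digits after the decimal point in each cell
    ax       -- an existing Axes to draw on; a new figure is created if None
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(7, 7))
    if labels is None:
        labels = list(corr.columns)

    # fix the color scale to the full range of Pearson's r
    image = ax.imshow(corr.values, cmap='coolwarm', vmin=-1, vmax=1)
    ax.figure.colorbar(image, ax=ax)

    positions = range(len(corr.columns))
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=90)
    ax.set_yticks(positions)
    ax.set_yticklabels(labels)

    # write each coefficient in its cell: element [row, col] sits at x=col, y=row
    for row in range(corr.shape[0]):
        for col in range(corr.shape[1]):
            ax.text(col, row, f'{corr.iat[row, col]:.{decimals}f}',
                    ha='center', va='center', color='black')
    return ax


# usage with made-up data (NOT the Boston dataset)
rng = np.random.default_rng(1)
a = rng.normal(size=100)
df = pd.DataFrame({
    'a': a,
    'b': a + rng.normal(size=100),  # related to a
    'c': rng.normal(size=100),      # unrelated to a
})
ax = plot_corr_heatmap(df.corr())
ax.set_title('Made-up data', pad=20)
plt.show()
```

Because the function accepts an existing Axes, you can also place several heatmaps side by side, for example by creating them with plt.subplots(1, 2) and passing each Axes in turn.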

Conclusion

In this topic, we've discussed what correlation is, how to calculate, interpret, and plot it on a graph. Now let's quickly go through the main points. What can you do to visualize correlation data?

  • Create a scatterplot with plt.scatter() to see if there's any strong correlation.

  • Use the corr() method in pandas to create a correlation matrix and corr.style.background_gradient() to add colors to it.

  • Use plt.imshow() to create a correlation heatmap.

Now that you know all that, it's time for some practice!
