In this topic, we'll continue to explore the matplotlib library and take a look at another plot type: the heatmap. A heatmap is a graph that uses color to encode values: the color of each cell depends on the value at that position in a data matrix. There are two kinds of heatmaps: a cluster heatmap (a matrix of colored cells) and a spatial heatmap (it has no cells; the values are treated as varying continuously across the plot). We'll focus on cluster heatmaps, as they're more popular for data visualization tasks.
Dealing with data
First, we need to prepare the data we want to plot. In this case, we've chosen the Seattle weather dataset. It contains daily records of precipitation, maximum and minimum temperature, and wind, along with a short description of the weather. It's provided by the Vega datasets project, so we need a special package to load it: install it with pip install vega_datasets.
Tip: For more information on these datasets, check out the GitHub repository of the project: you'll find the datasets and their descriptions there.
We'll work with the pandas library, so don't forget to import it along with matplotlib. Here are the import statements we need:
from vega_datasets import data
import matplotlib.pyplot as plt
import pandas as pd
Now, we're ready to go! Let's have a look at our dataset:
weather = data.seattle_weather()
weather.head()
Output:
+----+---------------------+-----------------+------------+------------+--------+-----------+
| | date | precipitation | temp_max | temp_min | wind | weather |
|----+---------------------+-----------------+------------+------------+--------+-----------|
| 0 | 2012-01-01 00:00:00 | 0 | 12.8 | 5 | 4.7 | drizzle |
| 1 | 2012-01-02 00:00:00 | 10.9 | 10.6 | 2.8 | 4.5 | rain |
| 2 | 2012-01-03 00:00:00 | 0.8 | 11.7 | 7.2 | 2.3 | rain |
| 3 | 2012-01-04 00:00:00 | 20.3 | 12.2 | 5.6 | 4.7 | rain |
| 4 | 2012-01-05 00:00:00 | 1.3 | 8.9 | 2.8 | 6.1 | rain |
+----+---------------------+-----------------+------------+------------+--------+-----------+
Heatmaps are popular as correlation plots. Let's see whether there's a correlation between precipitation, the maximum and minimum temperatures, and wind. The date and weather columns aren't numeric measurements, so they're of no use to us here.
Tip: If you're not quite sure what correlation means, check out the Correlation article by the Statistics Knowledge Portal.
Let's get rid of the columns we don't need and calculate the correlation:
weather = weather[['precipitation', 'temp_max', 'temp_min', 'wind']]
weather.corr()
Here's the dataframe we get:
+---------------+-----------------+------------+------------+------------+
| | precipitation | temp_max | temp_min | wind |
|---------------+-----------------+------------+------------+------------|
| precipitation | 1 | -0.228555 | -0.072684 | 0.328045 |
| temp_max | -0.228555 | 1 | 0.875687 | -0.164857 |
| temp_min | -0.072684 | 0.875687 | 1 | -0.0741852 |
| wind | 0.328045 | -0.164857 | -0.0741852 | 1 |
+---------------+-----------------+------------+------------+------------+
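Instead of listing the numeric columns by hand, you can let pandas pick them. The sketch below uses the five rows from the weather.head() output above as a tiny stand-in for the full dataset and shows two options: DataFrame.select_dtypes(include='number') and the numeric_only=True argument of DataFrame.corr() (available since pandas 1.5). This matters because in pandas 2.0 and later, calling corr() on a dataframe that still contains a text column such as weather raises an error.

```python
import pandas as pd

# The five rows from weather.head() above, used as a small stand-in dataset
df = pd.DataFrame({
    'date': pd.to_datetime(['2012-01-01', '2012-01-02', '2012-01-03',
                            '2012-01-04', '2012-01-05']),
    'precipitation': [0.0, 10.9, 0.8, 20.3, 1.3],
    'temp_max': [12.8, 10.6, 11.7, 12.2, 8.9],
    'temp_min': [5.0, 2.8, 7.2, 5.6, 2.8],
    'wind': [4.7, 4.5, 2.3, 4.7, 6.1],
    'weather': ['drizzle', 'rain', 'rain', 'rain', 'rain'],
})

# Option 1: keep only the numeric columns, then compute the correlation
corr_a = df.select_dtypes(include='number').corr()

# Option 2: ask corr() itself to skip non-numeric columns (pandas >= 1.5)
corr_b = df.corr(numeric_only=True)

print(corr_a.round(2))
```

Both options drop date and weather and produce the same matrix. With only five rows, the numbers differ from the full-dataset table above, so treat them as an illustration of the mechanics only.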
As you can see, the correlation values range from -1 to 1. Values close to 1, like 0.88 for temp_min and temp_max, indicate a strong positive correlation: when one variable increases, the other tends to increase too. The higher the maximum temperature, the higher the minimum. For the same reason, every variable has a correlation of exactly 1 with itself, which is why the diagonal consists of 1s.
Values close to -1 indicate a strong negative correlation: when one variable increases, the other tends to decrease. There's no example of a strong negative correlation in our dataframe. However, we do have values close to 0, for example, -0.07 between precipitation and temp_min. It means that there's almost no linear relationship between the minimum temperature and precipitation. Even the -0.23 between precipitation and temp_max is only a weak negative correlation. This makes sense, since precipitation can occur at any temperature.
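By default, DataFrame.corr() computes Pearson's correlation coefficient, r = Σ(x - x̄)(y - ȳ) / sqrt(Σ(x - x̄)² · Σ(y - ȳ)²), where x̄ and ȳ are the means of the two variables. The numerator is large and positive when both variables tend to be above (or below) their means at the same time, and dividing by the spread of each variable keeps the result between -1 and 1. Here's a minimal sketch of that formula in NumPy; pearson_r is our own helper, not a library function, and we check it against pandas on the temp_max and temp_min values from the weather.head() rows above.

```python
import numpy as np
import pandas as pd


def pearson_r(x, y):
    """Pearson correlation coefficient of two equal-length sequences (our own helper)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    dx = x - x.mean()  # deviations from the mean
    dy = y - y.mean()
    return (dx * dy).sum() / np.sqrt((dx ** 2).sum() * (dy ** 2).sum())


# temp_max and temp_min from the five rows of weather.head()
temp_max = [12.8, 10.6, 11.7, 12.2, 8.9]
temp_min = [5.0, 2.8, 7.2, 5.6, 2.8]

r = pearson_r(temp_max, temp_min)
r_pandas = pd.Series(temp_max).corr(pd.Series(temp_min))  # Pearson is the default method
print(f'{r:.4f} {r_pandas:.4f}')
```

Comparing a variable with itself makes the numerator equal to the denominator, which is where the 1s on the diagonal come from.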
We have everything we need for our first heatmap, so let's start!
A basic heatmap
There are several functions in matplotlib that you can use to create a heatmap. In this topic, we'll focus on plt.imshow(). The only argument it requires is X, the dataset to plot:
plt.imshow(weather.corr())
The dataset can be passed either as a positional or as a keyword argument, so plt.imshow(X=weather.corr()) and plt.imshow(weather.corr()) give the same result.
There's another important argument this function can take: interpolation. It takes a str that specifies the interpolation method used to draw the image. The default value is None, which makes matplotlib fall back to its rcParams["image.interpolation"] setting; this works well for our example, since the cells stay sharp and we get a cluster heatmap. However, if you need a spatial heatmap, you may want to change the interpolation value, for example, to 'bilinear'. The official documentation provides a list of the possible interpolation methods and a graph illustrating each of them.
Here's the plot we get when we pass our correlation dataframe to the function:
As you can see, with just one line, we've already got a valid heatmap. Yet it's hard to read: there's no title, no labels, and nothing explains what the colors mean. Let's fix that!
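To see what the interpolation argument changes, the sketch below draws the same matrix twice, side by side: once with 'nearest', which keeps every cell a solid block, and once with 'bilinear', which blends neighboring cells into a smooth, spatial-looking surface. To keep the snippet self-contained, it hard-codes the correlation values from the table above instead of loading the dataset.

```python
import matplotlib.pyplot as plt
import numpy as np

# Correlation matrix copied from the weather.corr() table above
corr = np.array([
    [1.0, -0.228555, -0.072684, 0.328045],
    [-0.228555, 1.0, 0.875687, -0.164857],
    [-0.072684, 0.875687, 1.0, -0.0741852],
    [0.328045, -0.164857, -0.0741852, 1.0],
])

plt.figure(figsize=(8, 4))

plt.subplot(1, 2, 1)
left = plt.imshow(corr, interpolation='nearest')  # solid cells: a cluster heatmap
plt.title("interpolation='nearest'")

plt.subplot(1, 2, 2)
right = plt.imshow(corr, interpolation='bilinear')  # blended cells: a spatial look
plt.title("interpolation='bilinear'")
```

In a standalone script (outside a notebook), add plt.show() at the end to open the window.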
Changing color and size
First of all, let's add a color bar, a scale that specifies the meaning of colors in our plot. All we need for that is the plt.colorbar() function without any arguments.
Tip: If you don't like the default colors, matplotlib allows you to change them with the cmap argument of plt.imshow(). There are many colormaps available; see the official documentation for the full list.
By default, the color bar is a vertical bar on the right side of the plot. If you want to put it beneath the heatmap, set the optional argument orientation='horizontal'. In our example, we prefer a vertical color bar, so we don't need this argument. We also switch to the Spectral colormap with cmap. Here's the code:
plt.imshow(weather.corr(), cmap="Spectral")
plt.colorbar()
And here's the plot we get:
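Thanks to the color bar, we now see that the blue cells stand for the highest correlation (1.0) and the red ones for the lowest (about -0.23 in our case). By default, the colors are stretched between the smallest and the largest value in the data, which is why the red end of the scale corresponds to -0.23 here rather than to -1.
If we want to change the size of our plot, we can get the current figure with plt.gcf() and call its set_size_inches() method; it takes the width and the height of the figure in inches. Try running the following code, or use different values to see how the plot changes:
plt.gcf().set_size_inches(7, 7)
Adding labels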
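By default, imshow() stretches the colormap between the smallest and the largest value in the data, so the same color can mean different values in different plots. A common tweak, not used elsewhere in this topic, is to pin the color scale to the full range of a correlation coefficient with the vmin and vmax arguments of plt.imshow(). The sketch below does that with the diverging 'coolwarm' colormap (our choice for illustration) and also shows orientation='horizontal' placing the color bar beneath the heatmap. The correlation values are hard-coded from the table above.

```python
import matplotlib.pyplot as plt
import numpy as np

# Correlation matrix copied from the weather.corr() table above
corr = np.array([
    [1.0, -0.228555, -0.072684, 0.328045],
    [-0.228555, 1.0, 0.875687, -0.164857],
    [-0.072684, 0.875687, 1.0, -0.0741852],
    [0.328045, -0.164857, -0.0741852, 1.0],
])

plt.figure()
# vmin/vmax fix the color scale to [-1, 1], so 0 always sits in the middle of the colormap
image = plt.imshow(corr, cmap='coolwarm', vmin=-1, vmax=1)
# Place the color bar below the heatmap instead of on the right
cbar = plt.colorbar(orientation='horizontal')
```

With a fixed scale, a weak correlation such as -0.23 shows up as a pale color, which arguably reflects its strength better than the saturated red it gets when the scale is stretched to the data.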
We can also add tick labels to make our plot more informative. Use the plt.xticks() and plt.yticks() functions for the X- and Y-axes, respectively:
plt.xticks(range(len(weather.corr().columns)), weather.corr().columns)
plt.yticks(range(len(weather.corr().columns)), weather.corr().columns)
The first argument specifies the tick positions (one per column: 0, 1, 2, 3), and the second one specifies their labels. You can also use the optional argument rotation to change the orientation of the labels. For example, rotation=90 places the labels vertically, which might be useful if the plot is small. In our example, we've used plt.gcf().set_size_inches(7, 7) to make the plot larger, so there's no need for rotation.
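Here's a sketch of the rotation option on a deliberately small plot, where long names like precipitation would otherwise overlap. The labels are the column names from our correlation table, and the values are hard-coded from that table so the snippet runs on its own.

```python
import matplotlib.pyplot as plt
import numpy as np

names = ['precipitation', 'temp_max', 'temp_min', 'wind']
# Correlation matrix copied from the weather.corr() table above
corr = np.array([
    [1.0, -0.228555, -0.072684, 0.328045],
    [-0.228555, 1.0, 0.875687, -0.164857],
    [-0.072684, 0.875687, 1.0, -0.0741852],
    [0.328045, -0.164857, -0.0741852, 1.0],
])

plt.figure(figsize=(3, 3))  # a small plot, where horizontal labels would collide
plt.imshow(corr, cmap='Spectral')
positions = range(len(names))              # one tick per column: 0, 1, 2, 3
plt.xticks(positions, names, rotation=90)  # vertical labels on the X-axis
plt.yticks(positions, names)
```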
Here's the result:
That's much better, isn't it? Now we can clearly see that there's a strong correlation between the maximum and minimum temperatures and a weak correlation between, for example, temperature and precipitation. Still, we lack the exact values. We can see that the correlation between wind and maximum temperature is somewhere around -0.2, but what's the exact value? matplotlib has no single function that writes the values into the cells of an imshow() heatmap, so we need a for loop that calls plt.text() for each cell:
labels = weather.corr().values
# plt.text() takes x (the column index) first, then y (the row index)
for a in range(labels.shape[1]):
    for b in range(labels.shape[0]):
        plt.text(a, b, '{:.2f}'.format(labels[b, a]), ha='center', va='center', color='black')
Tip: In case you've forgotten what {:.2f} means: it's a format specifier that formats a floating-point number with two digits after the decimal point, so -0.164857 becomes -0.16.
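The same annotation loop can be written a little more compactly. The sketch below uses NumPy's np.ndenumerate(), which yields each (row, column) index together with its value, and an f-string, where f'{value:.2f}' is equivalent to '{:.2f}'.format(value). The correlation values are hard-coded from the table above so the snippet runs without the dataset.

```python
import matplotlib.pyplot as plt
import numpy as np

# Correlation matrix copied from the weather.corr() table above
corr = np.array([
    [1.0, -0.228555, -0.072684, 0.328045],
    [-0.228555, 1.0, 0.875687, -0.164857],
    [-0.072684, 0.875687, 1.0, -0.0741852],
    [0.328045, -0.164857, -0.0741852, 1.0],
])

plt.figure(figsize=(5, 5))
plt.imshow(corr, cmap='Spectral')

# ndenumerate yields ((row, col), value) for every cell of the matrix
for (row, col), value in np.ndenumerate(corr):
    # plt.text() takes x (the column) first, then y (the row)
    plt.text(col, row, f'{value:.2f}', ha='center', va='center', color='black')
```

Because the loop walks over the array's own shape, it also works for non-square matrices without any index juggling.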
Before we take a look at the result, it'd be nice to add a title to our plot. We can do it with the plt.title() function that takes a str as an argument. We can also add the optional parameter fontsize:
plt.title('Weather in Seattle \n', fontsize=14)
If you want to add labels to the X- and Y-axes, use the separate functions plt.xlabel() and plt.ylabel(). They take a str as an argument. We won't use them in this example, since the plot is already clear enough.
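For completeness, here's how axis labels would look if your plot needs them; the label text 'Variable' is just an illustration. The sketch also uses the pad argument of plt.title(), which sets the gap between the title and the plot in points, as an alternative to the '\n' at the end of the title string above. The correlation values are hard-coded from the table above.

```python
import matplotlib.pyplot as plt
import numpy as np

# Correlation matrix copied from the weather.corr() table above
corr = np.array([
    [1.0, -0.228555, -0.072684, 0.328045],
    [-0.228555, 1.0, 0.875687, -0.164857],
    [-0.072684, 0.875687, 1.0, -0.0741852],
    [0.328045, -0.164857, -0.0741852, 1.0],
])

plt.figure()
plt.imshow(corr, cmap='Spectral')
plt.xlabel('Variable')  # label under the X-axis (illustrative text)
plt.ylabel('Variable')  # label next to the Y-axis (illustrative text)
# pad adds space (in points) between the title and the top of the plot
plt.title('Weather in Seattle', fontsize=14, pad=20)
```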
Here comes the final version of our correlation heatmap:
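If you want to reproduce this final version in one go, here's a sketch that puts all the steps of this topic together. To keep it self-contained, it rebuilds the correlation matrix from the values in the table above instead of loading the dataset; with the real data, use corr = weather.corr() instead.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Stand-in for weather.corr(): values copied from the correlation table above
names = ['precipitation', 'temp_max', 'temp_min', 'wind']
corr = pd.DataFrame(
    [[1.0, -0.228555, -0.072684, 0.328045],
     [-0.228555, 1.0, 0.875687, -0.164857],
     [-0.072684, 0.875687, 1.0, -0.0741852],
     [0.328045, -0.164857, -0.0741852, 1.0]],
    index=names, columns=names,
)

plt.figure()                       # start a fresh figure
plt.imshow(corr, cmap='Spectral')  # the heatmap itself
plt.colorbar()                     # scale explaining the colors
plt.gcf().set_size_inches(7, 7)    # enlarge the figure

ticks = range(len(corr.columns))   # one tick per column
plt.xticks(ticks, corr.columns)
plt.yticks(ticks, corr.columns)

values = corr.values
for row in range(values.shape[0]):
    for col in range(values.shape[1]):
        # x is the column index, y is the row index
        plt.text(col, row, '{:.2f}'.format(values[row, col]),
                 ha='center', va='center', color='black')

plt.title('Weather in Seattle \n', fontsize=14)
```

In a notebook, the plot appears automatically; in a standalone script, finish with plt.show().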
Conclusion
In this topic, we've covered the basics of heatmap creation with matplotlib. Let's quickly go through the main points:
- plt.imshow() creates a heatmap and has only one required argument, the dataset to plot;
- plt.colorbar() adds a color bar that shows what values the colors represent;
- plt.gcf().set_size_inches() changes the size of the plot;
- plt.xticks() and plt.yticks() set the positions and labels of the plot ticks;
- plt.xlabel() and plt.ylabel() set labels for the axes;
- plt.title() sets the title of the plot.
Now, you know how to create an annotated cluster heatmap, but of course, there's much more to learn. If you need more details, make sure to check out the official documentation. Now, let's practice!