
Working With Multiple Subplots

A figure can have more than one subplot. Earlier, you learned that you can obtain your figure and axes objects by calling plt.subplots() and passing in a figure size. This function can take two additional arguments:

  1. The number of rows
  2. The number of columns

These arguments determine how many axes objects will belong to the figure, and by extension, how many axes objects will be returned to you. In the following code, nrows is set to 1, and ncols is set to 2:

Python
import matplotlib.pyplot as plt
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2)

plt.subplots() returns the figure along with two axes objects, which you unpack into a tuple.

00:00 A figure can have more than one subplot. Earlier, we learned how we can obtain our Figure and Axes objects with the plt.subplots() function, passing in a figure size.

00:15 This function can also take two additional arguments, the number of rows and the number of columns.

00:23 These arguments dictate how many Axes objects will belong to the Figure, and by extension, how many Axes objects will be returned to us.

00:34 In this example, I’ve set nrows=1 and ncols=2, and so this function returns two Axes objects, which I store in this tuple.

00:48 If I set nrows=2 and ncols=2, the function would return four Axes.
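
As a minimal sketch of how nrows and ncols control the number of Axes, the snippet below builds a 1 by 2 and a 2 by 2 layout and counts the Axes attached to each Figure. The Agg backend line and the variable names n_axes_1x2 and n_axes_2x2 are assumptions added so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend; remove to open a window
import matplotlib.pyplot as plt

# One row, two columns: the second return value unpacks into two Axes.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2)
n_axes_1x2 = len(fig.axes)

# Two rows, two columns: four Axes are attached to the new Figure.
fig2, axes = plt.subplots(nrows=2, ncols=2)
n_axes_2x2 = len(fig2.axes)

print(n_axes_1x2, n_axes_2x2)  # 2 4
```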

00:58 Let’s see how we can modify these two Axes independently, creating two new visualizations in the process.

01:06 I’m here in a new file called plot2.py and to save on time, I have already imported pyplot and numpy at the top. Just like before, we’re going to get our randomized data using numpy. We’ll create a new variable called x, and that will store the one-dimensional ndarray obtained by calling randint() with a lower limit of 1, an upper limit of 11, and a size of 50, just like before. Now, I’ll create a new variable called y and that will store x plus a one-dimensional array of 50 random numbers, from 1 to 4 inclusive.

01:50 This means that a random number from 1 to 4 will be added to each element in our x ndarray.

01:59 I’m also going to create one more variable called data, and that will store the two-dimensional ndarray obtained by calling column_stack() with our x and y arrays.
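
The data setup described above might look like the following sketch. The seed value is an assumption added only to make runs repeatable; it is not mentioned in the video.

```python
import numpy as np

np.random.seed(444)  # assumption: any fixed seed works, used here for repeatability

# 50 random integers from 1 to 10 (the upper limit of 11 is excluded).
x = np.random.randint(low=1, high=11, size=50)

# Add a random integer from 1 to 4 (5 is excluded) to each element of x.
y = x + np.random.randint(1, 5, size=x.size)

# Pair the arrays up as two columns: one row per observation.
data = np.column_stack((x, y))

print(data.shape)  # (50, 2)
```

Note that randint() excludes its upper limit, which is why an upper limit of 5 yields values from 1 to 4.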

02:13 Now we’re done obtaining our data points, so we can use pyplot to obtain our Figure and Axes objects. This time around, we’re going to have one Figure and two Axes and so I’ll write fig, and then a tuple of (ax1, ax2).

02:35 Now we’ll call the subplots() function with an nrows value of 1, an ncols value of 2, and a figure size of (8, 4). The first Axes is going to be a scatter plot. I’ll write ax1.scatter(), passing in x for the x data and y for the y data.

03:03 Now the method needs to know how to style the scatter plot. For marker, I’ll give it a value of 'o'. The marker parameter sets the style of the dot, like circles versus x’s versus crosses.

03:19 You can learn about all the different options by viewing the documentation for the .scatter() function. Now I’ll give these circles a color of red and an edge color of blue.
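
A sketch of the scatter call described here, assuming the same random data as earlier in the lesson. It passes the edge color through the edgecolors keyword, and the Agg backend line is an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend; remove to open a window
import matplotlib.pyplot as plt
import numpy as np

x = np.random.randint(low=1, high=11, size=50)
y = x + np.random.randint(1, 5, size=x.size)

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))

# Red circles with blue outlines; 'o' selects the circle marker.
points = ax1.scatter(x, y, marker='o', c='r', edgecolors='b')
```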

03:33 Just like before, I’m going to set the Axes’s title, x label, and y label. Matplotlib renders text placed between dollar signs ('$') with its built-in TeX-style math engine, and so wrapping a variable name in dollar signs will italicize it.

03:50 It respects basic LaTeX formatting options. Let’s give this Axes some grid lines so it’s easier to match each point to an x and y value. To do this, I’m first going to call the .set_axisbelow() method with a value of True, which will tell this Axes to display ticks and grids behind the points.

04:16 Then, we can actually enable the grid by calling the .grid() method with a linestyle of two dashes ('--'). Once again, you can check out the documentation to see all of the different line styles.
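
The labeling and grid steps could be sketched as follows. The three sample points and the title text are placeholders chosen for illustration, not values taken from the video.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend; remove to open a window
import matplotlib.pyplot as plt

fig, ax1 = plt.subplots(figsize=(4, 4))
ax1.scatter([1, 2, 3], [2, 4, 5], marker='o', c='r', edgecolors='b')  # placeholder data

# Text between dollar signs is rendered as math, so x and y appear in italics.
ax1.set_title('Scatter: $x$ versus $y$')  # placeholder title
ax1.set_xlabel('$x$')
ax1.set_ylabel('$y$')

# Draw ticks and grid lines behind the points, then enable a dashed grid.
ax1.set_axisbelow(True)
ax1.grid(linestyle='--')
```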

04:31 Great. Our first Axes is done. The second Axes will be a histogram, and so I’ll call the .hist() method on our ax2 object.

04:43 I’ll pass in our data, which is the two-dimensional ndarray that resulted from stacking our x and y arrays side by side as two columns.

04:53 The method also needs a bins argument, which will set the Axes’s bins along the x-axis. I’ll use np.arange() with a lower bound of data.min() and an upper bound of data.max(). When the data spans 1 to 14, this generates an ndarray counting from 1 to 13 inclusive, because np.arange() excludes its upper bound.

05:19 Finally, the .hist() method needs a label parameter, so I’ll give it ('x', 'y').
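
Here is a hedged sketch of the histogram call, using the same kind of random data as earlier. The seed and the names counts, edges, and patches are assumptions for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend; remove to open a window
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(444)  # assumption: fixed seed for repeatability
x = np.random.randint(low=1, high=11, size=50)
y = x + np.random.randint(1, 5, size=x.size)
data = np.column_stack((x, y))

fig, ax2 = plt.subplots(figsize=(4, 4))

# np.arange() excludes the stop value, so the last bin edge is data.max() - 1.
bins = np.arange(data.min(), data.max())

# Each column of data becomes its own set of bars, labeled 'x' and 'y'.
counts, edges, patches = ax2.hist(data, bins=bins, label=('x', 'y'))
```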

05:26 This histogram will have two colored bars, so it’s important we distinguish between which one represents x and which one represents y. To do this, we can add a legend to our plot.

05:41 That’s as easy as ax2.legend(),

05:46 and I’ll give it a location of 0, which will tell matplotlib to render the legend in the best place to avoid overlapping with the bars drawn on the screen.

05:58 It should be noted that this can cause delays in rendering your figure when you have very large data sets to work with, so in that case, you can define the position manually by passing in two floats.
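
The two legend placement options might be compared like this. The manual coordinates (0.05, 0.8) are an arbitrary assumption; they position the legend's lower-left corner as fractions of the Axes width and height.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend; remove to open a window
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(444)  # assumption: fixed seed for repeatability
x = np.random.randint(low=1, high=11, size=50)
y = x + np.random.randint(1, 5, size=x.size)
data = np.column_stack((x, y))

fig, ax2 = plt.subplots(figsize=(4, 4))
ax2.hist(data, bins=np.arange(data.min(), data.max()), label=('x', 'y'))

# loc=0 is the numeric code for 'best': Matplotlib searches for a free spot.
legend_auto = ax2.legend(loc=0)

# Manual placement skips that search; calling legend() again replaces the first legend.
legend_manual = ax2.legend(loc=(0.05, 0.8))
```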

06:12 I almost always use 0, though. Just like before, I will give this Axes a title with the variable names italicized. Because we’ll be plotting two Axes side by side, it would look nicer if the y-axis ticks for this second Axes were on the right side of the plot, instead of on the left.

06:36 We can change this with ax2.yaxis.tick_right(). To finish up, I’ll add the dashed lines to our grid and I’ll call plt.show() so we actually get some output on our screen. And when I run this, you see both axes show up within our figure.
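
A small sketch of these finishing touches on an otherwise empty pair of Axes. The output filename is a hypothetical placeholder; in an interactive session you would call plt.show() instead of saving.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend; remove to open a window
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))

# Move the second Axes's y-axis ticks and labels to the right edge.
ax2.yaxis.tick_right()

# Dashed grid lines on the second Axes.
ax2.grid(linestyle='--')

# With an interactive backend, plt.show() would display the figure here.
fig.savefig('plot2.png')  # hypothetical output filename
```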

07:00 The scatter plot simply checks the elements in each position of both the x and y arrays, and then plots a point there. As for the histogram, that shows the frequency of x and y, meaning how often a specific number along the x-axis occurs in either the x array or the y array. Notice how the tick marks and the labels for the y-axis are on the right side, just as we specified. This code shows that it’s important we use a stateless approach with matplotlib.

07:38 If we relied on pyplot alone, it would be very difficult to customize each axes independently, because we wouldn’t have a direct object reference to each Axes.

07:50 We’d have to dig down deep into pyplot and find each Axes ourselves, which is even more difficult considering that pyplot keeps track of the current Axes, and we have two of them to deal with.
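
To illustrate the difference, this sketch contrasts a direct object reference with pyplot's notion of a current Axes. The title strings are placeholders.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend; remove to open a window
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2)

# Stateless: we hold a reference, so we can target either Axes directly.
ax1.set_title('set through ax1')

# Stateful: pyplot functions act on the current Axes, here the last one created.
current = plt.gca()
plt.title('set through pyplot')

# Reaching the other Axes through pyplot means switching the current Axes first.
plt.sca(ax1)
switched = plt.gca()
```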

08:05 Before I close off this video, I want to show you one other way we could use pyplot to get one Figure and multiple Axes objects.

08:16 I’m here in a blank Python shell, and I’m going to start by importing matplotlib.pyplot as plt, as we always do. I want to create a figure with four axes in a 2 by 2 grid, so I’ll write fig, ax = plt.subplots(nrows=2, ncols=2) and I’ll give it a figure size of, let’s say, (7, 7).

08:52 The figure size doesn’t really matter here since we won’t be displaying the figure on the screen. Notice here that I wrote ax instead of a tuple containing (ax1, ax2, ax3, ax4).

09:09 We could do that, but it might get difficult to manage if we have a figure with a lot of axes in it. But then if we’re trying to create four axes, what is ax? To answer that, I’ll use the Python type() function, passing in ax.

09:28 As you can see, ax is not an Axes, but a numpy.ndarray. If I just write ax to see its value, we can see that it’s a two-dimensional ndarray containing all four of our Axes objects. To get the first object, we can use square bracket notation, ax[0][0].

09:53 I put two zeros because it’s a two-dimensional array. And there’s our object. I mentioned in the NumPy video that each ndarray has a shape, which can help us better visualize the multi-dimensional array. To see the shape,

10:10 we can write ax.shape, and we see it’s a 2 by 2, because we have 2 rows and 2 columns.
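
The shell session described above can be reproduced with a sketch like this. The names first and same are assumptions for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend; remove to open a window
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(7, 7))

print(type(ax))  # <class 'numpy.ndarray'>
print(ax.shape)  # (2, 2)

# Two equivalent ways to reach the top-left Axes.
first = ax[0][0]
same = ax[0, 0]
```

The ax[0, 0] form indexes the row and column in one step, which is the more common NumPy idiom.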

10:20 If I ever wanted to store each Axes object in the array within its own separate variable, I can write ax1, ax2, ax3, ax4 = ax.flatten().

10:38 Now, we’ve got variables that reference each Axes object individually. I’ll run ax1, and you see there’s our first Axes.

10:51 This is all just another way to create a Figure with multiple Axes. It’s especially helpful if you have a Figure with many Axes, as you can manage them within an ndarray, instead of having variables for each one.
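
As a sketch of both styles, the snippet below unpacks the flattened array into four variables and also loops over the array directly. The 'Axes N' titles are placeholders.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend; remove to open a window
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(7, 7))

# flatten() walks the grid row by row: top-left, top-right, bottom-left, bottom-right.
ax1, ax2, ax3, ax4 = ax.flatten()

# With many Axes, looping over the array avoids one variable per Axes.
for index, axes in enumerate(ax.flat, start=1):
    axes.set_title(f'Axes {index}')
```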

alistairmclaren1 on March 21, 2020

Why when setting lower and upper limits for x the argument is “low=” and “high=” however for variable y only the integers 1 and 5 are used?

Ranit Pradhan on April 5, 2020

  File “<ipython-input-14-4e84e83ed077>”, line 19
    ax2.legend(loc=(0))
    ^
SyntaxError: invalid syntax

–what’s the solution?

ycc on Nov. 22, 2020

Regarding invalid syntax, I faced similar problem as well. After correcting line 17 on ax2.hist(data, bins... whereby I missed out one “)” which is not obvious at first, managed to run it without error.
