When we need to create a large number of subplots of the same size, it can be quite inefficient to generate them one by one with the plt.subplot() or fig.add_subplot() functions. In this case, we can call plt.subplots() to generate an array of subplots at once.
plt.subplots() takes the number of rows and columns as input parameters and returns a Figure together with a grid of subplots stored in a NumPy array. When called without any input parameter, plt.subplots() creates a single subplot and is equivalent to calling plt.figure() followed by plt.subplot(); in that case, the second return value is a single Axes object rather than an array.
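To make the "grid stored in a NumPy array" concrete, the following minimal sketch creates a 2x3 grid and inspects what comes back. It assumes a non-interactive backend (Agg) so it can run without a display; the grid size is arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# A 2x3 grid: axarr is a 2-D NumPy array holding one Axes per subplot
fig, axarr = plt.subplots(2, 3)
print(isinstance(axarr, np.ndarray), axarr.shape)

# Every Axes in the array belongs to the same Figure
print(all(ax.figure is fig for ax in axarr.flat))

plt.close(fig)
```

The array shape mirrors the (nrows, ncols) arguments, so a subplot can be addressed as axarr[row, col].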
Here is a code snippet for demonstration:
import matplotlib.pyplot as plt
fig, axarr = plt.subplots(1,1)
print(type(fig))
print(type(axarr))
plt.show()
From the resulting screenshot, we can observe that plt.subplots() returns a Figure object and a single AxesSubplot object:
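The 1x1 call above returns a single Axes instead of an array because plt.subplots() squeezes out length-1 dimensions by default (squeeze=True). The sketch below contrasts that default with squeeze=False, which always returns a 2-D array. It checks the class with isinstance() because the exact class name printed by type() may differ between Matplotlib versions (newer releases report it simply as Axes); the Agg backend is an assumption for running headless.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for running without a display
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# Default squeeze=True: a 1x1 grid gives back one Axes object, not an array
fig1, ax = plt.subplots(1, 1)
print(isinstance(ax, Axes))

# squeeze=False: the result is always a 2-D array, here of shape (1, 1)
fig2, axarr = plt.subplots(1, 1, squeeze=False)
print(axarr.shape)
axarr[0, 0].plot([0, 1], [0, 1])  # index with both row and column

plt.close("all")
```

Passing squeeze=False can be convenient when the grid size is computed at run time and the code should always index with [row, col].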
The next example illustrates a more useful case of plt.subplots().
This time, we will create a figure of 3x4 subplots and label each in a nested for loop:
import matplotlib.pyplot as plt
fig, axarr = plt.subplots(3,4)
for i in range(3):
    for j in range(4):
        axarr[i][j].text(0.3,0.5,str(i)+','+str(j),fontsize=18)
plt.show()
Again, we can observe from this figure that the subplots are indexed by row first and then by column, as seen in the preceding examples:
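Because the array is laid out row by row, the nested loop can also be written as a single loop over axarr.flat, which visits the subplots in the same row-major order. This is a sketch of that alternative (the Agg backend is assumed for headless use); divmod() recovers each subplot's (row, column) position from its flat index.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for running without a display
import matplotlib.pyplot as plt

nrows, ncols = 3, 4
fig, axarr = plt.subplots(nrows, ncols)

# axarr.flat walks row 0 left to right, then row 1, and so on
for k, ax in enumerate(axarr.flat):
    i, j = divmod(k, ncols)  # flat index k -> (row, column)
    assert ax is axarr[i, j]
    ax.text(0.3, 0.5, f"{i},{j}", fontsize=18)

plt.close(fig)
```

The labels end up identical to those produced by the nested loop, which confirms that rows vary slowest and columns fastest.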
It is also possible to supply only one input parameter to plt.subplots(), which is interpreted as the number of rows; the resulting subplots are stacked vertically in a single column. As the plt.subplots() function essentially incorporates the plt.figure() function, we can also specify the figure dimensions through the figsize argument:
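A quick way to confirm both points is to inspect the returned objects. This sketch (Agg backend assumed) passes a single argument of 3, checks the array shape, reads the figure size back with fig.get_size_inches(), and compares the bottom edges of the subplot boxes to verify that they are stacked top to bottom.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for running without a display
import matplotlib.pyplot as plt

# A single positional argument is nrows: three subplots in one column
fig, axarr = plt.subplots(3, figsize=(8, 6))
print(axarr.shape)            # one entry per row
print(fig.get_size_inches())  # (width, height) in inches, as passed to figsize

# The bottom edge (figure coordinates) decreases from the first row to the last
bottoms = [ax.get_position().y0 for ax in axarr]
print(bottoms == sorted(bottoms, reverse=True))

plt.close(fig)
```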
import numpy as np
import matplotlib.pyplot as plt
x = np.arange(0.0, 1.0, 0.01)
y1 = np.sin(8*np.pi*x)
y2 = np.cos(8*np.pi*x)
# Draw 2x1 subplots: two rows, one column
fig, axarr = plt.subplots(2, figsize=(8,6))
axarr[0].plot(x,y1)
axarr[1].plot(x,y2,color='red')
plt.show()
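Note that the type of axarr is <class 'numpy.ndarray'>; because the grid has only one column, the array is one-dimensional and each subplot is accessed with a single index.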
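The same one-dimensional result appears for a single row of subplots, since the default squeeze=True drops any grid dimension of length 1. The following sketch (Agg backend assumed, grid sizes arbitrary) compares a 2x1 column with a 1x3 row.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for running without a display
import matplotlib.pyplot as plt

# With squeeze=True (the default), length-1 grid dimensions are dropped
fig_col, col = plt.subplots(2, 1)  # two rows, one column
fig_row, row = plt.subplots(1, 3)  # one row, three columns
print(col.shape, row.shape)

# Both arrays are indexed with a single integer
col[1].plot([0, 1], [1, 0])
row[2].plot([0, 1], [0, 1])

plt.close("all")
```

Only a grid with more than one row and more than one column yields a 2-D array that needs two indices.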
The preceding code results in the following figure with two rows of subplots: