Master the art of creating complex, multi-panel visualizations with perfect layouts
Think of subplots like panels in a comic book - each tells part of the story, but together they create a complete narrative. Whether you're comparing different datasets, showing multiple perspectives, or creating a comprehensive dashboard, mastering subplot layouts is essential!
In professional data science, single plots rarely tell the whole story. You need to show trends alongside distributions, comparisons next to correlations, and summaries with details. Subplots let you compose these elements into cohesive, publication-ready visualizations!
Each section below demonstrates one layout technique with a runnable example:
The fundamental methods for creating subplot layouts
# Method 1: plt.subplots() -- create the figure and a grid of Axes in one call
import matplotlib.pyplot as plt
import numpy as np

# `axes` is a 2x2 NumPy array of Axes, indexed as axes[row, col]
fig, axes = plt.subplots(2, 2, figsize=(10, 8))
axes[0, 0].plot([1, 2, 3], [1, 4, 9])
axes[0, 1].scatter([1, 2, 3], [1, 4, 9])
axes[1, 0].bar([1, 2, 3], [1, 4, 9])
axes[1, 1].hist(np.random.randn(100))
# tight_layout() only adjusts the current figure, so call it once per figure
plt.tight_layout()

# Method 2: plt.subplot() -- pyplot's state-based interface: each call creates
# a panel and makes it current, and the following plt.* call draws into it
plt.figure(figsize=(10, 8))
plt.subplot(2, 2, 1)  # (rows, cols, index); index starts at 1 in the upper left
plt.plot([1, 2, 3], [1, 4, 9])
plt.subplot(2, 2, 2)
plt.scatter([1, 2, 3], [1, 4, 9])
plt.subplot(2, 2, 3)
plt.bar([1, 2, 3], [1, 4, 9])
plt.subplot(2, 2, 4)
plt.hist(np.random.randn(100))
plt.tight_layout()

# Method 3: add_subplot() -- object-oriented: keep a handle to each Axes
fig = plt.figure(figsize=(10, 8))
ax1 = fig.add_subplot(2, 2, 1)
ax2 = fig.add_subplot(2, 2, 2)
ax3 = fig.add_subplot(2, 2, 3)
ax4 = fig.add_subplot(2, 2, 4)
# Draw through the handles so this figure shows the same four panels
ax1.plot([1, 2, 3], [1, 4, 9])
ax2.scatter([1, 2, 3], [1, 4, 9])
ax3.bar([1, 2, 3], [1, 4, 9])
ax4.hist(np.random.randn(100))
plt.tight_layout()
plt.show()
GridSpec: create complex, asymmetric subplot arrangements
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

# --- Layout 1: a 3x3 grid whose panels span one or more cells ---
fig = plt.figure(figsize=(12, 8))
gs = gridspec.GridSpec(3, 3, figure=fig)

# Slicing a GridSpec selects a rectangular block of cells. np.s_ builds the
# very same index tuple that writing gs[...] directly would pass.
ax1, ax2, ax3, ax4, ax5 = (
    fig.add_subplot(gs[cell])
    for cell in (
        np.s_[0, :],    # top row, every column
        np.s_[1, :-1],  # middle row, first two columns
        np.s_[1:, -1],  # bottom two rows, last column
        np.s_[-1, 0],   # bottom-left cell
        np.s_[-1, 1],   # bottom-middle cell
    )
)

walk = np.random.randn(100).cumsum()
ax1.plot(walk)
ax1.set_title('Main Time Series')

scatter_x = np.random.randn(50)
scatter_y = np.random.randn(50)
ax2.scatter(scatter_x, scatter_y)
ax2.set_title('Correlation')

ax3.barh(['A', 'B', 'C', 'D'], [23, 45, 56, 78])
ax3.set_title('Categories')

ax4.pie([30, 25, 20, 25])
ax5.hist(np.random.randn(100))

# Gaps between cells, as fractions of the average cell width / height
gs.update(wspace=0.3, hspace=0.3)

# --- Layout 2: unequal column widths and row heights ---
fig2 = plt.figure(figsize=(12, 6))
gs2 = gridspec.GridSpec(
    2, 3,
    width_ratios=[1, 2, 1],  # middle column twice as wide as the others
    height_ratios=[2, 1],    # top row twice as tall as the bottom one
)
ax1 = fig2.add_subplot(gs2[0, :2])  # top row, first two columns
ax2 = fig2.add_subplot(gs2[0, 2])   # top row, last column
ax3 = fig2.add_subplot(gs2[1, :])   # entire bottom row

plt.tight_layout()  # acts on the current figure, i.e. fig2
plt.show()
subplot_mosaic: specify layouts visually with ASCII art (matplotlib 3.3+)
# Using subplot_mosaic for intuitive layouts
import matplotlib.pyplot as plt
import numpy as np

# Each character names a panel; repeating a character makes that panel span
# all the cells it appears in ("A" covers a 2x2 block, "C" spans two rows).
mosaic = """
AAB
AAC
DEC
"""
fig, axes = plt.subplot_mosaic(mosaic, figsize=(12, 8))

# The result is a mapping from mosaic label to Axes
axes['A'].plot(np.random.randn(100).cumsum())
axes['B'].scatter(np.random.randn(30), np.random.randn(30))
axes['C'].bar(['X', 'Y', 'Z'], [10, 20, 15])
axes['D'].pie([30, 70])
axes['E'].hist(np.random.randn(50))

panel_titles = {'A': 'Large Main Plot', 'B': 'Scatter', 'C': 'Bar Chart'}
for label, title in panel_titles.items():
    axes[label].set_title(title)

# Cells marked with the sentinel character stay empty, leaving gaps
mosaic2 = """
AB.
ACD
.CD
"""
fig2, axes2 = plt.subplot_mosaic(
    mosaic2,
    figsize=(10, 8),
    empty_sentinel=".",
    gridspec_kw={'wspace': 0.3, 'hspace': 0.3},
)

# Equivalent nested-list form: each inner list is one row of labels
mosaic_list = [
    ['top', 'top', 'right'],
    ['left', 'middle', 'right'],
    ['left', 'bottom', 'bottom'],
]
fig3, axes3 = plt.subplot_mosaic(mosaic_list, figsize=(10, 8))

plt.tight_layout()  # applies to the current figure (fig3)
plt.show()
Link axes across subplots for better comparison
# Shared axes for synchronized views
import matplotlib.pyplot as plt
import numpy as np

# Share x-axis: all three rows use the same x-limits, and only the bottom
# row keeps its x tick labels
fig, (ax1, ax2, ax3) = plt.subplots(3, 1,
                                    figsize=(10, 8),
                                    sharex=True)
x = np.linspace(0, 10, 100)
ax1.plot(x, np.sin(x))
ax1.set_title('Sin(x)')
ax2.plot(x, np.cos(x))
ax2.set_title('Cos(x)')
ax3.plot(x, np.sin(x) * np.cos(x))
ax3.set_title('Sin(x) × Cos(x)')
ax3.set_xlabel('X axis (shared)')

# Share y-axis within each row. The rows use different amplitudes, so you can
# see that each row gets its own common y-range.
fig2, axes = plt.subplots(2, 2,
                          figsize=(10, 8),
                          sharey='row')
for row, amplitude in enumerate([1, 10]):
    for col in range(2):
        axes[row, col].plot(x, amplitude * np.sin(x + col))

# Share both axes across every panel
fig3, axes3 = plt.subplots(2, 2,
                           figsize=(10, 8),
                           sharex='all',
                           sharey='all')
for frequency, ax_shared in enumerate(axes3.flat, start=1):
    ax_shared.plot(x, np.sin(frequency * x))
# Zoom/pan in one subplot affects all shared axes

# Twin axes (secondary y-axis on the right, same x-axis)
fig4, ax = plt.subplots(figsize=(10, 6))
x = np.arange(0, 10, 0.1)
y1 = np.sin(x)
y2 = np.exp(x/10)
color = 'tab:blue'
ax.set_xlabel('X data')
ax.set_ylabel('Sin(x)', color=color)
ax.plot(x, y1, color=color)
ax.tick_params(axis='y', labelcolor=color)
ax2 = ax.twinx()  # Secondary y-axis
color = 'tab:orange'
# Label matches the plotted curve: y2 is exp(x/10), not exp(x)
ax2.set_ylabel('Exp(x/10)', color=color)
ax2.plot(x, y2, color=color)
ax2.tick_params(axis='y', labelcolor=color)
fig4.tight_layout()
plt.show()
Perfect your subplot spacing and alignment
# Control spacing between subplots
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

# Each method gets its own figure. tight_layout() recomputes the same
# margin/gap parameters that subplots_adjust() sets, so running it on the
# Method 1 figure would silently replace the hand-tuned values.

# Method 1: subplots_adjust -- every value is a fraction of the figure size
fig, axes = plt.subplots(2, 2, figsize=(10, 8))
plt.subplots_adjust(
    left=0.1,    # Left edge of the subplot area
    right=0.9,   # Right edge of the subplot area
    top=0.9,     # Top edge of the subplot area
    bottom=0.1,  # Bottom edge of the subplot area
    wspace=0.3,  # Width space between subplots (fraction of average Axes width)
    hspace=0.4,  # Height space between subplots (fraction of average Axes height)
)

# Method 2: tight_layout -- automatic spacing. pad is the padding around the
# figure edge as a fraction of the font size (default 1.08).
fig, axes = plt.subplots(2, 2, figsize=(10, 8))
plt.tight_layout(pad=2.0)

# Method 3: constrained_layout (newer, better)
fig, axes = plt.subplots(2, 2,
                         figsize=(10, 8),
                         constrained_layout=True)

# Method 4: GridSpec spacing -- parameters live on the grid itself, so the
# grid must be attached to a figure and used to create the Axes
fig = plt.figure(figsize=(10, 8))
gs = gridspec.GridSpec(2, 2,
                       figure=fig,
                       wspace=0.4,
                       hspace=0.3,
                       left=0.1,
                       right=0.9,
                       top=0.95,
                       bottom=0.05)
for row in range(2):
    for col in range(2):
        fig.add_subplot(gs[row, col])

# Colorbar handling with proper spacing
fig, axes = plt.subplots(2, 2,
                         figsize=(10, 8),
                         constrained_layout=True)
images = [np.random.randn(10, 10) for _ in range(axes.size)]
# One colorbar can only describe every panel if all panels share a colour
# scale. Without vmin/vmax, each imshow autoscales to its own data.
vmin = min(img.min() for img in images)
vmax = max(img.max() for img in images)
for ax, img in zip(axes.flat, images):
    im = ax.imshow(img, vmin=vmin, vmax=vmax)
# Add single colorbar for all subplots
fig.colorbar(im, ax=axes.ravel().tolist(),
             shrink=0.6, location='right')
plt.show()
Add plot-in-plot and advanced annotations
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
import numpy as np
from scipy import stats

# --- Zoomed inset ---
fig, ax = plt.subplots(figsize=(10, 8))
# Main plot: a damped sine wave
x = np.linspace(0, 10, 1000)
y = np.sin(x) * np.exp(-x/10)
ax.plot(x, y, 'b-', linewidth=2)
ax.set_title('Signal with Detailed Inset')
# Create inset axes; the width/height strings are percentages of `ax`
axins = inset_axes(ax,
                   width="40%",        # Width of inset
                   height="30%",       # Height of inset
                   loc='upper right',  # Location
                   borderpad=3)        # Padding
# Re-evaluate the same signal on the zoomed interval x in [2, 4]
x_zoom = np.linspace(2, 4, 100)
y_zoom = np.sin(x_zoom) * np.exp(-x_zoom/10)
axins.plot(x_zoom, y_zoom, 'r-', linewidth=2)
axins.set_xlim(2, 4)
# Fit the y-limits to the zoomed data, plus a 5% margin. Over [2, 4] the
# signal falls from about 0.74 to about -0.51, so a narrow fixed window such
# as (0.3, 0.5) would clip most of the curve. The box drawn by mark_inset
# would then not match what the inset shows.
y_margin = 0.05 * (y_zoom.max() - y_zoom.min())
axins.set_ylim(y_zoom.min() - y_margin, y_zoom.max() + y_margin)
# Outline the zoomed region on `ax` and connect it to the inset's
# upper-left (2) and lower-right (4) corners
mark_inset(ax, axins, loc1=2, loc2=4, fc="none", ec="red")

# --- Multiple insets ---
fig2, ax2 = plt.subplots(figsize=(10, 8))
# Main data
data = np.random.randn(1000)
ax2.hist(data, bins=30, alpha=0.7, color='blue')
# Inset 1: QQ plot (probplot draws directly into the given Axes)
axins1 = inset_axes(ax2, width="30%", height="30%",
                    loc='upper left', borderpad=3)
stats.probplot(data, dist="norm", plot=axins1)
# Inset 2: Box plot
axins2 = inset_axes(ax2, width="30%", height="20%",
                    loc='lower right', borderpad=3)
axins2.boxplot(data, vert=False)

# --- Annotations between plots ---
fig3, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
# Plot 1
ax1.scatter([1, 2, 3], [1, 4, 9])
point1 = [2, 4]  # a scatter point, in ax1 data coordinates
# Plot 2: categorical bars sit at x = 0, 1, 2 for 'A', 'B', 'C'
ax2.bar(['A', 'B', 'C'], [4, 8, 12])
point2 = [0.5, 4]  # in ax2 data coordinates
# Connect points between subplots
con = ConnectionPatch(xyA=point1, xyB=point2,
                      coordsA='data', coordsB='data',
                      axesA=ax1, axesB=ax2,
                      color='red', linewidth=2,
                      arrowstyle="-|>")
# Add the patch to the figure rather than to one Axes. The arrow spans both
# Axes, so it should be drawn above them and not be clipped to either one.
fig3.add_artist(con)
plt.tight_layout()
plt.show()
Create a comprehensive financial analysis dashboard using advanced subplot techniques:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import pandas as pd
from datetime import datetime, timedelta

# Generate sample financial data (seeded so every run draws the same numbers)
np.random.seed(42)
dates = pd.date_range(start='2023-01-01', end='2024-12-31', freq='D')
n_days = len(dates)
# `dates` has calendar-day frequency, so annualise with 365 periods per year
periods_per_year = 365

# Stock price simulation: exponentiated cumulative sum of normal log-returns
initial_price = 100
returns = np.random.normal(0.0005, 0.02, n_days)
price = initial_price * np.exp(np.cumsum(returns))

# Volume data
volume = np.random.gamma(2, 2, n_days) * 1000000

# Moving averages (NaN until each rolling window is full)
ma_20 = pd.Series(price).rolling(window=20).mean()
ma_50 = pd.Series(price).rolling(window=50).mean()
ma_200 = pd.Series(price).rolling(window=200).mean()

# Portfolio components (allocation weights in %)
portfolio = {
    'Stocks': 45,
    'Bonds': 25,
    'Real Estate': 15,
    'Commodities': 10,
    'Cash': 5,
}

# Sector performance (independent random draws, not derived from `price`)
sectors = ['Tech', 'Finance', 'Healthcare', 'Energy', 'Consumer']
sector_returns = np.random.normal(0.08, 0.15, len(sectors))

# Risk metrics computed from the simulated price series, so the metrics table
# and the footer agree with the price chart
daily_returns = np.diff(price) / price[:-1]
total_return = (price[-1] / price[0] - 1) * 100
daily_std = np.std(daily_returns, ddof=1)
annual_volatility = daily_std * np.sqrt(periods_per_year) * 100
# Sharpe ratio with the risk-free rate assumed to be zero
sharpe_ratio = np.mean(daily_returns) / daily_std * np.sqrt(periods_per_year)
# Max drawdown: the worst fall from the running peak, in %
running_peak = np.maximum.accumulate(price)
max_drawdown = (price / running_peak - 1).min() * 100

# Create comprehensive dashboard
fig = plt.figure(figsize=(16, 12))
fig.suptitle('Financial Portfolio Dashboard - Q4 2024',
             fontsize=18, fontweight='bold')
# Create GridSpec: 4x4 grid, top row twice as tall as the others
gs = gridspec.GridSpec(4, 4,
                       width_ratios=[1, 1, 1, 1],
                       height_ratios=[2, 1, 1, 1],
                       hspace=0.3, wspace=0.3)

# 1. Main price chart with volume (top, spanning 3 columns)
ax1 = fig.add_subplot(gs[0, :3])
ax1_vol = ax1.twinx()
# ax1_vol is created after ax1, so at equal zorder it is drawn on top.
# Raise ax1 and hide its background so the price lines sit above the bars.
ax1.set_zorder(ax1_vol.get_zorder() + 1)
ax1.patch.set_visible(False)
# Price lines
ax1.plot(dates, price, 'b-', linewidth=1.5, label='Price', alpha=0.8)
ax1.plot(dates, ma_20, 'orange', linewidth=1, label='MA20', alpha=0.7)
ax1.plot(dates, ma_50, 'green', linewidth=1, label='MA50', alpha=0.7)
ax1.plot(dates, ma_200, 'red', linewidth=1, label='MA200', alpha=0.7)
# Volume bars, limited to roughly the bottom quarter of the panel
ax1_vol.bar(dates, volume, alpha=0.3, color='gray', width=1)
ax1_vol.set_ylim(0, volume.max() * 4)
ax1.set_title('Stock Price & Volume', fontweight='bold')
ax1.set_ylabel('Price ($)')
ax1_vol.set_ylabel('Volume')
ax1.legend(loc='upper left')
ax1.grid(True, alpha=0.3)
# Highlight important regions: within 10% of the low and of the high
ax1.axhspan(price.min(), price.min() * 1.1, alpha=0.2, color='red')
ax1.axhspan(price.max() * 0.9, price.max(), alpha=0.2, color='green')

# 2. Portfolio allocation (top right)
ax2 = fig.add_subplot(gs[0, 3])
colors = plt.cm.Set3(np.linspace(0, 1, len(portfolio)))
wedges, texts, autotexts = ax2.pie(portfolio.values(),
                                   labels=portfolio.keys(),
                                   autopct='%1.1f%%',
                                   colors=colors,
                                   startangle=90)
for autotext in autotexts:
    autotext.set_color('white')
    autotext.set_weight('bold')
    autotext.set_fontsize(9)
ax2.set_title('Portfolio Allocation', fontweight='bold')

# 3. Returns distribution (second row, left)
ax3 = fig.add_subplot(gs[1, :2])
ax3.hist(daily_returns * 100, bins=50, alpha=0.7,
         color='skyblue', edgecolor='black')
ax3.axvline(np.mean(daily_returns) * 100, color='red',
            linestyle='--', linewidth=2, label='Mean')
ax3.axvline(np.percentile(daily_returns * 100, 5),
            color='orange', linestyle='--', linewidth=2, label='5% VaR')
ax3.set_title('Daily Returns Distribution', fontweight='bold')
ax3.set_xlabel('Return (%)')
ax3.set_ylabel('Frequency')
ax3.legend()
ax3.grid(True, alpha=0.3, axis='y')

# 4. Sector performance (second row, right)
ax4 = fig.add_subplot(gs[1, 2:])
sector_pct = sector_returns * 100
ax4.barh(sectors, sector_pct,
         color=['green' if r > 0 else 'red' for r in sector_returns])
for i, pct in enumerate(sector_pct):
    # Label just beyond the bar's end: right of positive bars, left of
    # negative ones, so the text never sits on top of its own bar
    offset, align = (1, 'left') if pct > 0 else (-1, 'right')
    ax4.text(pct + offset, i, f'{pct:.1f}%',
             va='center', ha=align, fontweight='bold')
ax4.margins(x=0.25)  # leave horizontal room for the value labels
ax4.set_title('Sector Performance (YTD)', fontweight='bold')
ax4.set_xlabel('Return (%)')
ax4.axvline(0, color='black', linewidth=1)
ax4.grid(True, alpha=0.3, axis='x')

# 5. Correlation heatmap (third row, left half)
ax5 = fig.add_subplot(gs[2, :2])
n_assets = 6
asset_names = ['Stock A', 'Stock B', 'Bond X', 'Bond Y', 'Gold', 'REIT']
# Build a genuine correlation matrix. L @ L.T + D (D > 0 diagonal) is a
# positive-definite covariance; scaling by the standard deviations gives a
# unit diagonal with every entry in [-1, 1], including negative values.
loadings = np.random.randn(n_assets, 2)
cov = loadings @ loadings.T + np.diag(np.random.uniform(0.2, 1.0, n_assets))
asset_std = np.sqrt(np.diag(cov))
corr_matrix = cov / np.outer(asset_std, asset_std)
np.fill_diagonal(corr_matrix, 1)  # remove floating-point round-off on the diagonal
im = ax5.imshow(corr_matrix, cmap='RdBu_r', vmin=-1, vmax=1)
ax5.set_xticks(np.arange(n_assets))
ax5.set_yticks(np.arange(n_assets))
ax5.set_xticklabels(asset_names, rotation=45, ha='right')
ax5.set_yticklabels(asset_names)
# Add correlation values; white text on the strongly coloured cells
for i in range(n_assets):
    for j in range(n_assets):
        ax5.text(j, i, f'{corr_matrix[i, j]:.2f}',
                 ha="center", va="center",
                 color="white" if abs(corr_matrix[i, j]) > 0.5 else "black",
                 fontsize=8)
ax5.set_title('Asset Correlation Matrix', fontweight='bold')
plt.colorbar(im, ax=ax5, fraction=0.046, pad=0.04)

# 6. Risk-Return scatter (third row, right half)
ax6 = fig.add_subplot(gs[2, 2:])
n_investments = 20
risk = np.random.uniform(5, 25, n_investments)
expected_return = risk * 0.5 + np.random.normal(0, 3, n_investments)
investment_size = np.random.uniform(100, 1000, n_investments)
scatter = ax6.scatter(risk, expected_return, s=investment_size,
                      c=expected_return/risk, cmap='viridis',
                      alpha=0.6, edgecolors='black', linewidth=1)
# Efficient frontier (illustrative straight line)
ef_risk = np.linspace(5, 25, 50)
ef_return = ef_risk * 0.6
ax6.plot(ef_risk, ef_return, 'r--', linewidth=2,
         label='Efficient Frontier')
ax6.set_title('Risk-Return Profile', fontweight='bold')
ax6.set_xlabel('Risk (Volatility %)')
ax6.set_ylabel('Expected Return (%)')
ax6.legend()
ax6.grid(True, alpha=0.3)
plt.colorbar(scatter, ax=ax6, label='Sharpe Ratio',
             fraction=0.046, pad=0.04)

# 7. Performance metrics table (bottom row, left)
ax7 = fig.add_subplot(gs[3, :2])
ax7.axis('tight')
ax7.axis('off')
# The first four portfolio values come from the simulated series above. No
# benchmark series is simulated, so Beta, Alpha and the whole Benchmark
# column remain illustrative placeholders.
metrics = {
    'Metric': ['Total Return', 'Sharpe Ratio', 'Max Drawdown',
               'Volatility', 'Beta', 'Alpha'],
    'Value': [f'{total_return:.1f}%', f'{sharpe_ratio:.2f}',
              f'{max_drawdown:.1f}%', f'{annual_volatility:.1f}%',
              '0.95', '3.2%'],
    'Benchmark': ['18.2%', '1.45', '-12.1%', '14.8%', '1.00', '0.0%'],
}
table_data = list(zip(metrics['Metric'], metrics['Value'], metrics['Benchmark']))
table = ax7.table(cellText=table_data,
                  colLabels=['Metric', 'Portfolio', 'Benchmark'],
                  cellLoc='center',
                  loc='center',
                  colColours=['lightgray'] * 3)
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1.2, 1.5)
# Row 0 holds the column labels; data rows start at 1
for i in range(len(metrics['Metric']) + 1):
    for j in range(3):
        cell = table[(i, j)]
        if i == 0:
            cell.set_facecolor('#4472C4')
            cell.set_text_props(weight='bold', color='white')
        else:
            # Zebra striping on alternate data rows
            cell.set_facecolor('#F2F2F2' if i % 2 == 0 else 'white')
ax7.set_title('Performance Metrics', fontweight='bold', pad=20)

# 8. Monthly returns heatmap (bottom row, right)
ax8 = fig.add_subplot(gs[3, 2:])
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
          'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
years = ['2023', '2024']
# Independent random draws (illustrative, not derived from `price`)
monthly_returns = np.random.normal(0.01, 0.05, (2, 12))
im2 = ax8.imshow(monthly_returns * 100, cmap='RdYlGn',
                 vmin=-10, vmax=10, aspect='auto')
ax8.set_xticks(np.arange(12))
ax8.set_yticks(np.arange(2))
ax8.set_xticklabels(months, fontsize=8)
ax8.set_yticklabels(years)
# Add return values
for i in range(2):
    for j in range(12):
        val = monthly_returns[i, j] * 100
        color = 'white' if abs(val) > 5 else 'black'
        ax8.text(j, i, f'{val:.1f}%',
                 ha="center", va="center",
                 color=color, fontsize=8)
ax8.set_title('Monthly Returns Heatmap (%)', fontweight='bold')
plt.colorbar(im2, ax=ax8, fraction=0.046, pad=0.04)

# Add timestamp
fig.text(0.99, 0.01, f'Generated: {datetime.now().strftime("%Y-%m-%d %H:%M")}',
         fontsize=9, ha='right', style='italic', color='gray')
# Add performance summary (reuses the total_return computed above)
fig.text(0.5, 0.01,
         f'Portfolio Value: ${price[-1]*10000:,.0f} | '
         f'Total Return: {total_return:.1f}% | '
         f'Daily Avg: {np.mean(daily_returns)*100:.3f}%',
         fontsize=11, ha='center',
         bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.5))

# tight_layout arranges the Axes only, not free-standing fig.text artists.
# Keep a strip free at the bottom for the footer text and at the top for the
# suptitle, so they do not overlap the panels.
plt.tight_layout(rect=[0, 0.04, 1, 0.95])
plt.show()