Skip to main content

🎨 Subplots and Figure Composition

Master the art of creating complex, multi-panel visualizations with perfect layouts

🏗️ Building Visual Stories with Multiple Plots

Think of subplots like panels in a comic book - each tells part of the story, but together they create a complete narrative. Whether you're comparing different datasets, showing multiple perspectives, or creating a comprehensive dashboard, mastering subplot layouts is essential!

In professional data science, single plots rarely tell the whole story. You need to show trends alongside distributions, comparisons next to correlations, and summaries with details. Subplots let you compose these elements into cohesive, publication-ready visualizations!

📐 Understanding Figure Architecture

graph TB
    A["Figure<br/>The Canvas"] --> B[Axes/Subplot 1]
    A --> C[Axes/Subplot 2]
    A --> D[Axes/Subplot 3]
    A --> E[Axes/Subplot N]
    B --> F[Title]
    B --> G[X-axis]
    B --> H[Y-axis]
    B --> I[Data]
    B --> J[Legend]
    style A fill:#f9f,stroke:#333,stroke-width:2px
    style B fill:#9ff,stroke:#333,stroke-width:2px
    style C fill:#9ff,stroke:#333,stroke-width:2px
    style D fill:#9ff,stroke:#333,stroke-width:2px
    style E fill:#9ff,stroke:#333,stroke-width:2px

🎮 Interactive Layout Builder

Click on different layouts to see them in action:

1
2
2×1 Vertical
1
2
1×2 Horizontal
1
2
3
4
2×2 Grid
1
2
3
4
5
6
3×2 Grid
1
2
3
Mixed Layout
1
2
3
4
Mosaic
0.2

🔧 Basic Subplot Creation

The fundamental methods for creating subplot layouts

# Method 1: plt.subplots()
import matplotlib.pyplot as plt
import numpy as np

# Create the figure and a 2x2 array of Axes in a single call
fig, axes = plt.subplots(2, 2, figsize=(10, 8))

# Access individual subplots by [row, column]
axes[0, 0].plot([1, 2, 3], [1, 4, 9])
axes[0, 1].scatter([1, 2, 3], [1, 4, 9])
axes[1, 0].bar([1, 2, 3], [1, 4, 9])
axes[1, 1].hist(np.random.randn(100))

# tight_layout() adjusts one figure at a time, so lay this figure out
# before the next one is created
fig.tight_layout()

# Method 2: plt.subplot()
plt.figure(figsize=(10, 8))

plt.subplot(2, 2, 1)  # (rows, cols, index)
plt.plot([1, 2, 3], [1, 4, 9])

plt.subplot(2, 2, 2)
plt.scatter([1, 2, 3], [1, 4, 9])

plt.subplot(2, 2, 3)
plt.bar([1, 2, 3], [1, 4, 9])

plt.subplot(2, 2, 4)
plt.hist(np.random.randn(100))

# Acts on the current figure, i.e. the one created for Method 2
plt.tight_layout()

# Method 3: add_subplot()
fig = plt.figure(figsize=(10, 8))

# Each call adds one Axes at the given (rows, cols, index) position;
# the Axes are left empty here to show only the creation step
ax1 = fig.add_subplot(2, 2, 1)
ax2 = fig.add_subplot(2, 2, 2)
ax3 = fig.add_subplot(2, 2, 3)
ax4 = fig.add_subplot(2, 2, 4)

fig.tight_layout()
plt.show()

📏 GridSpec for Advanced Layouts

Create complex, asymmetric subplot arrangements

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

# Canvas for the asymmetric layout
fig = plt.figure(figsize=(12, 8))

# 3x3 grid of cells attached to the figure
gs = gridspec.GridSpec(nrows=3, ncols=3, figure=fig)

# Every Axes claims a rectangular slice of the grid
ax1 = fig.add_subplot(gs[0, 0:3])  # row 0, all three columns
ax2 = fig.add_subplot(gs[1, 0:2])  # row 1, columns 0-1
ax3 = fig.add_subplot(gs[1:3, 2])  # rows 1-2, last column
ax4 = fig.add_subplot(gs[2, 0])    # bottom-left cell
ax5 = fig.add_subplot(gs[2, 1])    # bottom-middle cell

# Fill the panels
ax1.plot(np.cumsum(np.random.randn(100)))
ax1.set(title='Main Time Series')

ax2.scatter(np.random.randn(50), np.random.randn(50))
ax2.set(title='Correlation')

ax3.barh(['A', 'B', 'C', 'D'], [23, 45, 56, 78])
ax3.set(title='Categories')

ax4.pie([30, 25, 20, 25])
ax5.hist(np.random.randn(100))

# Spacing between grid cells
gs.update(wspace=0.3, hspace=0.3)

# Second figure: columns and rows with unequal relative sizes
fig2 = plt.figure(figsize=(12, 6))
gs2 = gridspec.GridSpec(
    nrows=2,
    ncols=3,
    width_ratios=[1, 2, 1],
    height_ratios=[2, 1],
)

ax1 = fig2.add_subplot(gs2[0, 0:2])  # top row, first two columns
ax2 = fig2.add_subplot(gs2[0, 2])    # top row, last column
ax3 = fig2.add_subplot(gs2[1, 0:3])  # bottom row, full width

plt.tight_layout()
plt.show()

🎭 Subplot Mosaic

Visual ASCII art layout specification (matplotlib 3.3+)

# Using subplot_mosaic for intuitive layouts
import matplotlib.pyplot as plt
import numpy as np

# Layout drawn as text: each letter names an Axes, and a letter repeated
# across neighbouring cells makes that Axes span them
mosaic = """
    AAB
    AAC
    DEC
"""

fig, axes = plt.subplot_mosaic(mosaic, figsize=(12, 8))

# The returned mapping is keyed by the letters used in the layout
axes['A'].plot(np.cumsum(np.random.randn(100)))
axes['B'].scatter(np.random.randn(30), np.random.randn(30))
axes['C'].bar(['X', 'Y', 'Z'], [10, 20, 15])
axes['D'].pie([30, 70])
axes['E'].hist(np.random.randn(50))

# Titles for the panels that have one
for label, title in (('A', 'Large Main Plot'),
                     ('B', 'Scatter'),
                     ('C', 'Bar Chart')):
    axes[label].set_title(title)

# Layout with unused cells
mosaic2 = """
    AB.
    ACD
    .CD
"""

fig2, axes2 = plt.subplot_mosaic(
    mosaic2,
    figsize=(10, 8),
    empty_sentinel=".",  # cells marked "." stay empty
    gridspec_kw=dict(wspace=0.3, hspace=0.3),
)

# The same idea expressed as nested lists with descriptive names
mosaic_list = [
    ['top', 'top', 'right'],
    ['left', 'middle', 'right'],
    ['left', 'bottom', 'bottom'],
]

fig3, axes3 = plt.subplot_mosaic(mosaic_list, figsize=(10, 8))

plt.tight_layout()
plt.show()

🔗 Sharing Axes

Link axes across subplots for better comparison

# Shared axes for synchronized views
import matplotlib.pyplot as plt
import numpy as np

# Three stacked panels that share one x-axis
fig, (ax1, ax2, ax3) = plt.subplots(
    nrows=3, ncols=1, figsize=(10, 8), sharex=True
)

x = np.linspace(0, 10, 100)

# (axes, y-values, title) for each stacked panel
for panel, values, title in (
    (ax1, np.sin(x), 'Sin(x)'),
    (ax2, np.cos(x), 'Cos(x)'),
    (ax3, np.sin(x) * np.cos(x), 'Sin(x) × Cos(x)'),
):
    panel.plot(x, values)
    panel.set_title(title)

ax3.set_xlabel('X axis (shared)')

# Share the y-axis, but only among panels in the same row
fig2, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8), sharey='row')

# Share both axes across every panel
fig3, axes3 = plt.subplots(
    nrows=2, ncols=2, figsize=(10, 8), sharex='all', sharey='all'
)

# Zoom/pan in one subplot affects all shared axes

# Twin axes: a second y-axis drawn over the same x-axis
fig4, ax = plt.subplots(figsize=(10, 6))

x = np.arange(0, 10, 0.1)
y1 = np.sin(x)
y2 = np.exp(x/10)

# Left-hand axis, coloured to match its curve
color = 'tab:blue'
ax.set_xlabel('X data')
ax.set_ylabel('Sin(x)', color=color)
ax.tick_params(axis='y', labelcolor=color)
ax.plot(x, y1, color=color)

# Right-hand axis, created from the left-hand one
ax2 = ax.twinx()
color = 'tab:orange'
ax2.set_ylabel('Exp(x)', color=color)
ax2.tick_params(axis='y', labelcolor=color)
ax2.plot(x, y2, color=color)

fig4.tight_layout()
plt.show()

✨ Spacing and Alignment

Perfect your subplot spacing and alignment

Subplot 1
wspace
Subplot 2
# Control spacing between subplots
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec  # used by Method 4
import numpy as np  # used for the colorbar example data

fig, axes = plt.subplots(2, 2, figsize=(10, 8))

# Method 1: subplots_adjust (manual margins and gaps)
plt.subplots_adjust(
    left=0.1,    # Left margin
    right=0.9,   # Right margin
    top=0.9,     # Top margin
    bottom=0.1,  # Bottom margin
    wspace=0.3,  # Width space between subplots
    hspace=0.4   # Height space between subplots
)

# Method 2: tight_layout (alternative to Method 1)
plt.tight_layout()  # Automatic spacing
plt.tight_layout(pad=2.0)  # With padding

# Method 3: constrained_layout (newer, better)
fig, axes = plt.subplots(2, 2,
                        figsize=(10, 8),
                        constrained_layout=True)

# Method 4: GridSpec spacing (spacing is set on the grid itself)
gs = gridspec.GridSpec(2, 2,
                      wspace=0.4,
                      hspace=0.3,
                      left=0.1,
                      right=0.9,
                      top=0.95,
                      bottom=0.05)

# Colorbar handling with proper spacing
fig, axes = plt.subplots(2, 2,
                        figsize=(10, 8),
                        constrained_layout=True)

# Use one color scale for every image: the colorbar below is built from a
# single image, so it only describes all panels if they share vmin/vmax
for ax in axes.flat:
    im = ax.imshow(np.random.randn(10, 10), vmin=-3, vmax=3)

# Add single colorbar for all subplots
fig.colorbar(im, ax=axes.ravel().tolist(),
            shrink=0.6, location='right')

🎨 Inset Axes and Annotations

Add plot-in-plot and advanced annotations

import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
import numpy as np
from scipy import stats

fig, ax = plt.subplots(figsize=(10, 8))

# Main plot: a sine wave with an exponentially decaying amplitude
x = np.linspace(0, 10, 1000)
y = np.sin(x) * np.exp(-x/10)
ax.plot(x, y, 'b-', linewidth=2)
ax.set_title('Signal with Detailed Inset')

# Create inset axes
axins = inset_axes(ax,
                   width="40%",  # Width of inset
                   height="30%",  # Height of inset
                   loc='upper right',  # Location
                   borderpad=3)  # Padding

# Plot in inset - the same signal evaluated over the zoomed x-range
x_zoom = np.linspace(2, 4, 100)
y_zoom = np.sin(x_zoom) * np.exp(-x_zoom/10)
axins.plot(x_zoom, y_zoom, 'r-', linewidth=2)
axins.set_xlim(2, 4)
axins.set_ylim(0.3, 0.5)

# Mark the zoomed region on the main axes and link it to the inset
mark_inset(ax, axins, loc1=2, loc2=4, fc="none", ec="red")

# Multiple insets
fig2, ax2 = plt.subplots(figsize=(10, 8))

# Main data
data = np.random.randn(1000)
ax2.hist(data, bins=30, alpha=0.7, color='blue')

# Inset 1: QQ plot
axins1 = inset_axes(ax2, width="30%", height="30%",
                    loc='upper left', borderpad=3)
stats.probplot(data, dist="norm", plot=axins1)

# Inset 2: Box plot
axins2 = inset_axes(ax2, width="30%", height="20%",
                    loc='lower right', borderpad=3)
axins2.boxplot(data, vert=False)

# Annotations between plots
fig3, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

# Plot 1
ax1.scatter([1, 2, 3], [1, 4, 9])
point1 = [2, 4]

# Plot 2
ax2.bar(['A', 'B', 'C'], [4, 8, 12])
point2 = [0.5, 4]

# Connect points between subplots; each end is given in the data
# coordinates of its own axes
con = ConnectionPatch(xyA=point1, xyB=point2,
                      coordsA='data', coordsB='data',
                      axesA=ax1, axesB=ax2,
                      color='red', linewidth=2,
                      arrowstyle="-|>")
# Add the patch to the figure rather than to one of the axes so that it is
# drawn on top of both axes regardless of their draw order
fig3.add_artist(con)

plt.tight_layout()
plt.show()

🌍 Real-World Example: Financial Dashboard

Create a comprehensive financial analysis dashboard using advanced subplot techniques:

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import pandas as pd
from datetime import datetime, timedelta

# Generate sample financial data (seeded so the dashboard is reproducible)
np.random.seed(42)
dates = pd.date_range(start='2023-01-01', end='2024-12-31', freq='D')
n_days = len(dates)

# Stock price simulation: cumulative sum of normally distributed log-returns
initial_price = 100
returns = np.random.normal(0.0005, 0.02, n_days)
price = initial_price * np.exp(np.cumsum(returns))

# Volume data
volume = np.random.gamma(2, 2, n_days) * 1000000

# Moving averages (the first window-1 entries are NaN and are not drawn)
ma_20 = pd.Series(price).rolling(window=20).mean()
ma_50 = pd.Series(price).rolling(window=50).mean()
ma_200 = pd.Series(price).rolling(window=200).mean()

# Portfolio components (percent of total)
portfolio = {
    'Stocks': 45,
    'Bonds': 25,
    'Real Estate': 15,
    'Commodities': 10,
    'Cash': 5
}

# Sector performance
sectors = ['Tech', 'Finance', 'Healthcare', 'Energy', 'Consumer']
sector_returns = np.random.normal(0.08, 0.15, len(sectors))

# Risk metrics; daily_returns has one element fewer than price
daily_returns = np.diff(price) / price[:-1]
# 30-day rolling volatility (computed here but not plotted below)
volatility = pd.Series(daily_returns).rolling(window=30).std()

# Create comprehensive dashboard
fig = plt.figure(figsize=(16, 12))
fig.suptitle('Financial Portfolio Dashboard - Q4 2024',
             fontsize=18, fontweight='bold')

# Create GridSpec: 4x4 grid with a taller first row for the main chart
gs = gridspec.GridSpec(4, 4,
                      width_ratios=[1, 1, 1, 1],
                      height_ratios=[2, 1, 1, 1],
                      hspace=0.3, wspace=0.3)

# 1. Main price chart with volume (top, spanning 3 columns)
ax1 = fig.add_subplot(gs[0, :3])
ax1_vol = ax1.twinx()

# Price lines
ax1.plot(dates, price, 'b-', linewidth=1.5, label='Price', alpha=0.8)
ax1.plot(dates, ma_20, 'orange', linewidth=1, label='MA20', alpha=0.7)
ax1.plot(dates, ma_50, 'green', linewidth=1, label='MA50', alpha=0.7)
ax1.plot(dates, ma_200, 'red', linewidth=1, label='MA200', alpha=0.7)

# Volume bars on the secondary y-axis
ax1_vol.bar(dates, volume, alpha=0.3, color='gray', width=1)

ax1.set_title('Stock Price & Volume', fontweight='bold')
ax1.set_ylabel('Price ($)')
ax1_vol.set_ylabel('Volume')
ax1.legend(loc='upper left')
ax1.grid(True, alpha=0.3)

# Highlight the bands within 10% of the period low (red) and high (green)
ax1.axhspan(price.min(), price.min() * 1.1, alpha=0.2, color='red')
ax1.axhspan(price.max() * 0.9, price.max(), alpha=0.2, color='green')

# 2. Portfolio allocation (top right)
ax2 = fig.add_subplot(gs[0, 3])
colors = plt.cm.Set3(np.linspace(0, 1, len(portfolio)))
# pie() converts its input with numpy, which cannot read dict views, so
# pass real lists for the values and labels
wedges, texts, autotexts = ax2.pie(list(portfolio.values()),
                                    labels=list(portfolio.keys()),
                                    autopct='%1.1f%%',
                                    colors=colors,
                                    startangle=90)

for autotext in autotexts:
    autotext.set_color('white')
    autotext.set_weight('bold')
    autotext.set_fontsize(9)

ax2.set_title('Portfolio Allocation', fontweight='bold')

# 3. Returns distribution (second row, left)
ax3 = fig.add_subplot(gs[1, :2])
ax3.hist(daily_returns * 100, bins=50, alpha=0.7,
        color='skyblue', edgecolor='black')
ax3.axvline(np.mean(daily_returns) * 100, color='red',
           linestyle='--', linewidth=2, label='Mean')
ax3.axvline(np.percentile(daily_returns * 100, 5),
           color='orange', linestyle='--', linewidth=2, label='5% VaR')

ax3.set_title('Daily Returns Distribution', fontweight='bold')
ax3.set_xlabel('Return (%)')
ax3.set_ylabel('Frequency')
ax3.legend()
ax3.grid(True, alpha=0.3, axis='y')

# 4. Sector performance (second row, right)
ax4 = fig.add_subplot(gs[1, 2:])
bars = ax4.barh(sectors, sector_returns * 100,
               color=['green' if r > 0 else 'red' for r in sector_returns])

# Place each value label just past the end of its bar: to the right of
# positive bars and to the left of negative ones, so it never sits on the bar
for i, ret in enumerate(sector_returns):
    pct = ret * 100
    is_gain = pct >= 0
    ax4.text(pct + (1 if is_gain else -1), i, f'{pct:.1f}%',
            va='center', ha='left' if is_gain else 'right',
            fontweight='bold')

ax4.set_title('Sector Performance (YTD)', fontweight='bold')
ax4.set_xlabel('Return (%)')
ax4.axvline(0, color='black', linewidth=1)
ax4.grid(True, alpha=0.3, axis='x')

# 5. Correlation heatmap (third row, left half)
ax5 = fig.add_subplot(gs[2, :2])

# Generate a symmetric matrix with ones on the diagonal
n_assets = 6
asset_names = ['Stock A', 'Stock B', 'Bond X', 'Bond Y', 'Gold', 'REIT']
corr_matrix = np.random.rand(n_assets, n_assets)
corr_matrix = (corr_matrix + corr_matrix.T) / 2
np.fill_diagonal(corr_matrix, 1)

im = ax5.imshow(corr_matrix, cmap='RdBu_r', vmin=-1, vmax=1)
ax5.set_xticks(np.arange(n_assets))
ax5.set_yticks(np.arange(n_assets))
ax5.set_xticklabels(asset_names, rotation=45, ha='right')
ax5.set_yticklabels(asset_names)

# Add correlation values; white text on strongly coloured cells
for i in range(n_assets):
    for j in range(n_assets):
        ax5.text(j, i, f'{corr_matrix[i, j]:.2f}',
                ha="center", va="center",
                color="white" if abs(corr_matrix[i, j]) > 0.5 else "black",
                fontsize=8)

ax5.set_title('Asset Correlation Matrix', fontweight='bold')
plt.colorbar(im, ax=ax5, fraction=0.046, pad=0.04)

# 6. Risk-Return scatter (third row, right half)
ax6 = fig.add_subplot(gs[2, 2:])

# Generate risk-return data
n_investments = 20
risk = np.random.uniform(5, 25, n_investments)
expected_return = risk * 0.5 + np.random.normal(0, 3, n_investments)
investment_size = np.random.uniform(100, 1000, n_investments)

# Marker area encodes investment size; colour encodes return per unit risk
scatter = ax6.scatter(risk, expected_return, s=investment_size,
                     c=expected_return/risk, cmap='viridis',
                     alpha=0.6, edgecolors='black', linewidth=1)

# Efficient frontier (illustrative straight line)
ef_risk = np.linspace(5, 25, 50)
ef_return = ef_risk * 0.6
ax6.plot(ef_risk, ef_return, 'r--', linewidth=2,
        label='Efficient Frontier')

ax6.set_title('Risk-Return Profile', fontweight='bold')
ax6.set_xlabel('Risk (Volatility %)')
ax6.set_ylabel('Expected Return (%)')
ax6.legend()
ax6.grid(True, alpha=0.3)
plt.colorbar(scatter, ax=ax6, label='Sharpe Ratio',
            fraction=0.046, pad=0.04)

# 7. Performance metrics table (bottom row, left)
ax7 = fig.add_subplot(gs[3, :2])
ax7.axis('tight')
ax7.axis('off')

metrics = {
    'Metric': ['Total Return', 'Sharpe Ratio', 'Max Drawdown',
               'Volatility', 'Beta', 'Alpha'],
    'Value': ['24.5%', '1.82', '-8.3%', '16.2%', '0.95', '3.2%'],
    'Benchmark': ['18.2%', '1.45', '-12.1%', '14.8%', '1.00', '0.0%']
}

table_data = list(zip(metrics['Metric'], metrics['Value'], metrics['Benchmark']))
table = ax7.table(cellText=table_data,
                  colLabels=['Metric', 'Portfolio', 'Benchmark'],
                  cellLoc='center',
                  loc='center',
                  colColours=['lightgray']*3)

table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1.2, 1.5)

# Row 0 is the header row; the remaining rows get alternating backgrounds
for i in range(len(metrics['Metric']) + 1):
    for j in range(3):
        cell = table[(i, j)]
        if i == 0:
            cell.set_facecolor('#4472C4')
            cell.set_text_props(weight='bold', color='white')
        else:
            cell.set_facecolor('#F2F2F2' if i % 2 == 0 else 'white')

ax7.set_title('Performance Metrics', fontweight='bold', pad=20)

# 8. Monthly returns heatmap (bottom row, right)
ax8 = fig.add_subplot(gs[3, 2:])

# Generate monthly returns: one row per year, one column per month
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
         'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
years = ['2023', '2024']
monthly_returns = np.random.normal(0.01, 0.05, (2, 12))

im2 = ax8.imshow(monthly_returns * 100, cmap='RdYlGn',
                vmin=-10, vmax=10, aspect='auto')
ax8.set_xticks(np.arange(12))
ax8.set_yticks(np.arange(2))
ax8.set_xticklabels(months, fontsize=8)
ax8.set_yticklabels(years)

# Add return values; white text on the strongly coloured cells
for i in range(2):
    for j in range(12):
        val = monthly_returns[i, j] * 100
        color = 'white' if abs(val) > 5 else 'black'
        ax8.text(j, i, f'{val:.1f}%',
                ha="center", va="center",
                color=color, fontsize=8)

ax8.set_title('Monthly Returns Heatmap (%)', fontweight='bold')
plt.colorbar(im2, ax=ax8, fraction=0.046, pad=0.04)

# Add timestamp (bottom-right corner, figure coordinates)
fig.text(0.99, 0.01, f'Generated: {datetime.now().strftime("%Y-%m-%d %H:%M")}',
        fontsize=9, ha='right', style='italic', color='gray')

# Add performance summary (bottom centre, figure coordinates)
total_return = (price[-1] / price[0] - 1) * 100
fig.text(0.5, 0.01,
        f'Portfolio Value: ${price[-1]*10000:,.0f} | '
        f'Total Return: {total_return:.1f}% | '
        f'Daily Avg: {np.mean(daily_returns)*100:.3f}%',
        fontsize=11, ha='center',
        bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.5))

# Keep the subplots inside a rectangle that leaves a strip free at the top
# for the suptitle and at the bottom for the footer text
plt.tight_layout(rect=(0, 0.04, 1, 0.95))
plt.show()

💡 Pro Tips for Perfect Subplots

⚠️ Common Pitfalls to Avoid