Seaborn for Statistical Plots
Master the Art of Statistical Visualization! 🎨
Seaborn is a Python data visualization library based on matplotlib that provides a high-level interface for drawing attractive and informative statistical graphics. With Seaborn, you can create beautiful, publication-ready visualizations with just a few lines of code!
Why Seaborn?
Seaborn builds on top of matplotlib and integrates closely with pandas data structures. Here's what makes it special:
- 🎯 Statistical Focus: Built-in statistical estimation and aggregation
- 🎨 Beautiful Defaults: Professional-looking plots out of the box
- 🔧 Pandas Integration: Works seamlessly with DataFrame structures
- 📊 Complex Plots Made Easy: Create sophisticated visualizations with minimal code
Setting Up Seaborn
# Import libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Set the style
sns.set_theme(style="whitegrid")
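# Load sample datasets (fetched from the online seaborn-data repository and cached, so the first call needs internet access)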
tips = sns.load_dataset("tips")
iris = sns.load_dataset("iris")
titanic = sns.load_dataset("titanic")
print("Seaborn version:", sns.__version__)
print("Available built-in datasets:", sns.get_dataset_names())
Seaborn Styles and Themes
Seaborn offers several built-in themes to quickly change the look of your plots:
# Set different styles (each call replaces the previous one, so pick one)
sns.set_theme(style="whitegrid") # Clean with grid
sns.set_theme(style="darkgrid") # Dark background with grid
sns.set_theme(style="white") # Plain white background
sns.set_theme(style="dark") # Plain dark background
sns.set_theme(style="ticks") # White with tick marks
# Set context for different use cases (again, the last call wins)
sns.set_context("notebook") # Default
sns.set_context("paper") # Smaller elements
sns.set_context("talk") # Larger elements for presentations
sns.set_context("poster") # Even larger for posters
# Custom context scaling
sns.set_context("notebook", font_scale=1.5, rc={"lines.linewidth": 2.5})
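`set_theme` and `set_context` change matplotlib's settings globally. To style a single figure instead, `axes_style` can be used as a context manager. `plotting_context` returns the parameters a context would apply without activating it, which makes the scaling easy to inspect. Here is a minimal sketch; the plotted values are arbitrary placeholders.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import seaborn as sns

# Inspect the base font size each built-in context would apply (nothing is set globally)
for name in ["paper", "notebook", "talk", "poster"]:
    print(name, sns.plotting_context(name)["font.size"])

# Apply a style to this one figure only; previous settings are restored on exit
with sns.axes_style("darkgrid"):
    fig, ax = plt.subplots()
    ax.plot([0, 1, 2], [0, 1, 4])
    inner_facecolor = ax.get_facecolor()
plt.close(fig)
```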
Color Palettes
Color palettes are one of Seaborn's strongest features. You can use built-in palettes or create custom ones:
# Built-in color palettes
sns.color_palette("deep") # Default palette
sns.color_palette("pastel") # Soft colors
sns.color_palette("bright") # Vivid colors
sns.color_palette("dark") # Dark colors
sns.color_palette("colorblind") # Colorblind-friendly
# Sequential palettes (good for ordered data)
sns.color_palette("Blues") # Single hue gradient
sns.color_palette("rocket") # Multi-hue gradient
# Diverging palettes (good for data with a meaningful center)
sns.color_palette("coolwarm") # Cool to warm
sns.color_palette("RdBu") # Red to blue
# Custom palettes
custom_palette = ["#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FFEAA7"]
sns.set_palette(custom_palette)
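# Using palettes in plots: map x to hue so each bar takes a palette color
# (seaborn 0.13+ deprecates palette without hue; legend=False needs 0.13+)
plt.figure(figsize=(10, 6))
sns.barplot(data=tips, x="day", y="total_bill", hue="day", palette="Set2", legend=False)
plt.title("Average Bill by Day with the Set2 Palette")
plt.show()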
Distribution Plots
Histogram with KDE
# Basic histogram
sns.histplot(data=tips, x="total_bill")
plt.show()  # show each example on its own; otherwise they all draw on one set of axes
# With KDE overlay, normalized as a density
sns.histplot(data=tips, x="total_bill", kde=True, stat="density")
plt.show()
# One variable split by a hue category, stacked
sns.histplot(data=tips, x="total_bill", hue="time", multiple="stack")
plt.show()
# 2D histogram for bivariate data
sns.histplot(data=tips, x="total_bill", y="tip", bins=30)
plt.show()
KDE Plots
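# Basic KDE
sns.kdeplot(data=tips, x="total_bill")
plt.show()  # show each example on its own; otherwise they all draw on one set of axes
# Multiple distributions, one filled curve per hue level
sns.kdeplot(data=tips, x="total_bill", hue="time", fill=True)
plt.show()
# 2D KDE for bivariate relationships (5 contour levels)
sns.kdeplot(data=tips, x="total_bill", y="tip", levels=5)
plt.show()
# Cumulative distribution
sns.kdeplot(data=tips, x="total_bill", cumulative=True)
plt.show()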
Categorical Plots
Seaborn excels at visualizing categorical data with various plot types:
# Strip plot - shows all points
sns.stripplot(data=tips, x="day", y="total_bill", hue="sex", dodge=True)
plt.show()  # show each example on its own; otherwise they all draw on one set of axes
# Swarm plot - adjusts points so they do not overlap
sns.swarmplot(data=tips, x="day", y="total_bill", hue="sex", dodge=True)
plt.show()
# Box plot - shows quartiles and outliers
sns.boxplot(data=tips, x="day", y="total_bill", hue="smoker")
plt.show()
# Violin plot - combines a KDE with a box plot; split=True puts one hue level on each side
sns.violinplot(data=tips, x="day", y="total_bill", hue="sex", split=True)
plt.show()
# Bar plot - shows the mean, with error bars of one standard deviation
# (errorbar="sd" replaces ci="sd", deprecated in seaborn 0.12)
sns.barplot(data=tips, x="day", y="total_bill", hue="sex", errorbar="sd")
plt.show()
# Count plot - shows counts of observations
sns.countplot(data=tips, x="day", hue="sex", palette="Set2")
plt.show()
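Each axes-level function above also has a figure-level counterpart through `catplot`, selected with `kind`. `catplot` adds faceting with `col` and `row`. Here is a minimal sketch on hypothetical data; the column names are placeholders chosen to resemble the tips dataset.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(3)
n = 120
df = pd.DataFrame({
    "day": rng.choice(["Thu", "Fri", "Sat"], n),
    "time": rng.choice(["Lunch", "Dinner"], n),
    "bill": rng.gamma(4.0, 5.0, n),  # hypothetical right-skewed bills
})

# kind="box" draws a box plot in each facet; col="time" gives one panel per level
g = sns.catplot(data=df, x="day", y="bill", col="time", kind="box",
                order=["Thu", "Fri", "Sat"], height=3)
g.set_axis_labels("Day", "Bill")
print(g.axes.shape)
plt.close(g.figure)
```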
Regression and Relationship Plots
# Linear regression plot
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
sns.regplot(data=tips, x="total_bill", y="tip")
plt.title("Basic Linear Regression")
plt.subplot(1, 3, 2)
sns.regplot(data=tips, x="total_bill", y="tip", order=2, ci=95)
plt.title("Polynomial Regression (Order 2)")
plt.subplot(1, 3, 3)
sns.regplot(data=tips, x="total_bill", y="tip", lowess=True)  # lowess=True requires statsmodels
plt.title("LOWESS Smoothing")
plt.tight_layout()
plt.show()
# Residual plot
plt.figure(figsize=(10, 5))
sns.residplot(data=tips, x="total_bill", y="tip", lowess=True, color="g")  # lowess=True requires statsmodels
plt.title("Residual Plot")
plt.show()
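`regplot` draws the fitted line but does not return the model's coefficients. If you need the numbers, fit the same model yourself, for example with `numpy.polyfit`, or with statsmodels for full inference. This sketch uses hypothetical linear data and checks that the slope of the drawn line matches `polyfit`.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

rng = np.random.default_rng(4)
x = rng.uniform(0, 50, 100)
y = 0.15 * x + 1 + rng.normal(0, 1, 100)  # hypothetical bill/tip-like relationship

fig, ax = plt.subplots()
sns.regplot(x=x, y=y, ci=None, ax=ax)  # ci=None skips the bootstrap band

# The scatter is a collection, so the regression line is the first Line2D
line_x = np.asarray(ax.lines[0].get_xdata())
line_y = np.asarray(ax.lines[0].get_ydata())
drawn_slope = (line_y[-1] - line_y[0]) / (line_x[-1] - line_x[0])

# Ordinary least squares fit of the same first-order model
slope, intercept = np.polyfit(x, y, deg=1)
print(f"drawn slope={drawn_slope:.4f}, polyfit slope={slope:.4f}, intercept={intercept:.4f}")
plt.close(fig)
```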
Matrix Plots - Heatmaps and Clustering
# Correlation heatmap
plt.figure(figsize=(10, 8))
correlation_matrix = tips.select_dtypes(include=[np.number]).corr()
sns.heatmap(correlation_matrix,
annot=True, # Show values
cmap="coolwarm", # Color scheme
vmin=-1, vmax=1, # Scale limits
center=0, # Center point
square=True, # Square cells
linewidths=1, # Grid lines
cbar_kws={"shrink": 0.8})
plt.title("Correlation Heatmap")
plt.show()
# Clustermap (hierarchical clustering); it creates its own figure,
# so size it with figsize and title it through the returned grid
g = sns.clustermap(correlation_matrix,
cmap="RdBu_r",
annot=True,
fmt=".2f",
linewidths=0.5,
figsize=(10, 8))
g.figure.suptitle("Hierarchical Clustering of Features", y=1.02)
plt.show()
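A correlation matrix is symmetric, so a common refinement is to hide the redundant upper triangle with the heatmap's `mask` argument; cells where the mask is True are not drawn. Here is a sketch on hypothetical numeric columns.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(5)
base = rng.normal(size=200)
# Hypothetical columns with positive, zero, and negative relationships to "a"
df = pd.DataFrame({
    "a": base,
    "b": base * 0.8 + rng.normal(scale=0.5, size=200),
    "c": rng.normal(size=200),
    "d": -base + rng.normal(scale=1.0, size=200),
})
corr = df.corr()

# True on and above the diagonal, so only the lower triangle is drawn
mask = np.triu(np.ones_like(corr, dtype=bool))

fig, ax = plt.subplots(figsize=(5, 4))
sns.heatmap(corr, mask=mask, annot=True, fmt=".2f", cmap="coolwarm",
            vmin=-1, vmax=1, center=0, square=True, ax=ax)
plt.close(fig)
```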
Pair Plots and Joint Plots
# Pair plot - explore relationships between all numeric variables
sns.pairplot(iris, hue="species",
diag_kind="kde", # Diagonal plot type
markers=["o", "s", "D"], # Different markers
palette="husl")
# Advanced pair plot with regression
sns.pairplot(tips[["total_bill", "tip", "size"]],
kind="reg", # Regression plots
diag_kind="kde",
plot_kws={'line_kws':{'color':'red'}})
# Joint plot - detailed view of two variables
g = sns.jointplot(data=tips, x="total_bill", y="tip",
kind="hex", # Hexbin plot
marginal_kws=dict(bins=25))
# Different joint plot types (each call creates its own figure)
kinds = ["scatter", "kde", "hist", "hex", "reg", "resid"]
for kind in kinds:
    g = sns.jointplot(data=tips, x="total_bill", y="tip", kind=kind)
    g.figure.suptitle(f"Joint Plot - {kind.title()}", y=1.02)
plt.show()
FacetGrid - Multi-plot Grids
FacetGrid allows you to create a grid of plots based on categorical variables:
# Basic FacetGrid
g = sns.FacetGrid(tips, col="time", row="smoker",
height=4, aspect=1.2)
g.map(sns.scatterplot, "total_bill", "tip")
g.add_legend()
# With histogram
g = sns.FacetGrid(tips, col="day", col_wrap=2, height=3)
g.map(sns.histplot, "total_bill", kde=True)
# Complex mapping with multiple layers
g = sns.FacetGrid(tips, col="time", row="smoker",
hue="sex", palette="Set1",
height=4, aspect=1.2)
g.map(sns.scatterplot, "total_bill", "tip", alpha=0.7)
g.map(sns.regplot, "total_bill", "tip", scatter=False)
g.add_legend()
# Categorical plots in grid
g = sns.FacetGrid(tips, col="day", height=4, aspect=0.5)
g.map(sns.boxplot, "time", "total_bill", order=["Lunch", "Dinner"])
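For common cases, the figure-level functions `relplot`, `displot`, and `catplot` build the FacetGrid for you from `col`, `row`, and `hue` arguments. The returned grid can still be customized. This sketch mirrors the first FacetGrid example, using hypothetical data with tips-like column names.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(7)
n = 200
bill = rng.uniform(5, 50, n)
df = pd.DataFrame({
    "total_bill": bill,
    "tip": 0.15 * bill + rng.normal(0, 1, n),
    "time": rng.choice(["Lunch", "Dinner"], n),
    "smoker": rng.choice(["Yes", "No"], n),
})

# One call: rows by smoker, columns by time, a scatter plot in each panel
g = sns.relplot(data=df, x="total_bill", y="tip", col="time", row="smoker",
                kind="scatter", height=3, aspect=1.2)
g.set_axis_labels("Total bill", "Tip")
# With both row and col facets, the two templates are joined with " | "
g.set_titles(row_template="{row_name} smoker", col_template="{col_name}")
print(g.axes.shape)
plt.close(g.figure)
```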
Real-world Example: Sales Analysis Dashboard
# Create sample sales data
np.random.seed(42)
dates = pd.date_range('2024-01-01', periods=365)
sales_data = pd.DataFrame({
'date': dates,
'sales': np.random.normal(1000, 200, 365) + np.sin(np.arange(365) * 2 * np.pi / 365) * 300,
'region': np.random.choice(['North', 'South', 'East', 'West'], 365),
'product': np.random.choice(['A', 'B', 'C'], 365),
'customer_type': np.random.choice(['New', 'Returning'], 365)
})
sales_data['month'] = sales_data['date'].dt.month
sales_data['quarter'] = sales_data['date'].dt.quarter
# Create comprehensive dashboard
fig = plt.figure(figsize=(16, 12))
# 1. Time series with trend
ax1 = plt.subplot(3, 3, 1)
sns.lineplot(data=sales_data, x='date', y='sales', ax=ax1)
ax1.set_title('Sales Trend Over Time')
ax1.tick_params(axis='x', rotation=45)
# 2. Sales by region
ax2 = plt.subplot(3, 3, 2)
sns.boxplot(data=sales_data, x='region', y='sales', ax=ax2)
ax2.set_title('Sales Distribution by Region')
# 3. Product performance
ax3 = plt.subplot(3, 3, 3)
product_sales = sales_data.groupby('product')['sales'].mean().reset_index()
sns.barplot(data=product_sales, x='product', y='sales', hue='product', palette='viridis', legend=False, ax=ax3)
ax3.set_title('Average Sales by Product')
# 4. Customer type comparison
ax4 = plt.subplot(3, 3, 4)
sns.violinplot(data=sales_data, x='customer_type', y='sales', ax=ax4)
ax4.set_title('Sales by Customer Type')
# 5. Monthly patterns
ax5 = plt.subplot(3, 3, 5)
monthly_sales = sales_data.groupby('month')['sales'].mean().reset_index()
sns.lineplot(data=monthly_sales, x='month', y='sales', marker='o', ax=ax5)
ax5.set_title('Monthly Sales Pattern')
ax5.set_xticks(range(1, 13))
# 6. Region-Product heatmap
ax6 = plt.subplot(3, 3, 6)
pivot_table = sales_data.pivot_table(values='sales', index='region', columns='product', aggfunc='mean')
sns.heatmap(pivot_table, annot=True, fmt='.0f', cmap='YlOrRd', ax=ax6)
ax6.set_title('Sales Heatmap: Region vs Product')
# 7. Quarterly comparison
ax7 = plt.subplot(3, 3, 7)
sns.countplot(data=sales_data, x='quarter', hue='region', ax=ax7)
ax7.set_title('Transaction Count by Quarter')
# 8. Sales distribution
ax8 = plt.subplot(3, 3, 8)
sns.histplot(data=sales_data, x='sales', kde=True, bins=30, ax=ax8)
ax8.set_title('Overall Sales Distribution')
# 9. Correlation matrix
ax9 = plt.subplot(3, 3, 9)
numeric_cols = sales_data.select_dtypes(include=[np.number]).columns
corr_matrix = sales_data[numeric_cols].corr()
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', center=0, ax=ax9)
ax9.set_title('Feature Correlation')
plt.suptitle('Sales Analysis Dashboard', fontsize=16, y=1.02)
plt.tight_layout()
plt.show()
Best Practices and Tips
🎨 Choose the Right Plot
- Distributions: Use histplot, kdeplot, or ecdfplot
- Relationships: Use scatterplot, lineplot, or regplot
- Comparisons: Use barplot, boxplot, or violinplot
- Counts and proportions: Use countplot, or catplot(kind="count") for a faceted version
🎯 Optimize for Your Audience
- Use set_context() to adjust plot elements for different mediums
- Choose appropriate color palettes (sequential for ordered data, diverging for data with a meaningful center)
- Add clear titles and labels with plt.title() and plt.xlabel()
- Consider colorblind-friendly palettes for accessibility
⚡ Performance Tips
- Use DataFrame.sample() to plot a subset of large datasets
- Pre-aggregate data when possible before plotting
- Set errorbar=None (ci=None in seaborn versions before 0.12) in statistical plots to skip the bootstrapped confidence interval calculation
- Use matplotlib directly for simple plots that don't need statistical features
Common Patterns and Use Cases
Practice Exercises
Exercise 1: Create a Statistical Dashboard
Using the tips dataset, create a 4-panel dashboard showing:
- Distribution of tips by day (use violinplot)
- Relationship between bill and tip with regression line
- Average bill by time and day (use heatmap)
- Correlation heatmap of numeric variables
Exercise 2: Custom Color Palette
Create a custom color palette matching your brand colors and apply it to a multi-category plot. Use at least 5 different colors and demonstrate it with a barplot.
Exercise 3: Complex FacetGrid
Use FacetGrid to create a grid of plots showing how tips vary by:
- Columns: Day of the week
- Rows: Time (Lunch/Dinner)
- Hue: Smoker status
Key Takeaways
- 📊 Seaborn makes statistical visualization easy with high-level functions
- 🎨 Built-in themes and color palettes create professional-looking plots
- 🔗 Tight integration with pandas makes data manipulation seamless
- 📈 Statistical features like regression and confidence intervals are built-in
- 🔲 FacetGrid enables powerful multi-dimensional visualizations