Visualising Data with Matplotlib and Seaborn: A Comprehensive Guide
- Vusi Kubheka
- Nov 19, 2024
Data visualisation is an essential skill for data scientists and analysts, enabling them to uncover patterns, trends, and relationships hidden within datasets. This blog post introduces two powerful Python libraries, Matplotlib and Seaborn, using a dataset related to heart attack risk factors.
Why Matplotlib and Seaborn?
Matplotlib: A versatile library suitable for creating static, interactive, and animated visualisations.
Seaborn: Built on top of Matplotlib, Seaborn simplifies the creation of aesthetically pleasing and informative statistical graphics.
Below is a walk-through of various visualisation techniques and their corresponding Python code:
Dataset Overview
We will use the Heart Attack Data Set, stored in an Excel spreadsheet and loaded into a pandas DataFrame.
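```python
import pandas as pd

# Load the heart attack dataset from the Excel file into a DataFrame.
# Reading .xlsx files requires the openpyxl package to be installed.
data = pd.read_excel('Heart Attack Data Set spreadsheet.xlsx')
```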
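The plotting code in this post relies on the columns `age`, `serum cholesterol`, `sex` and `target`. A minimal sketch for checking that they exist and are numeric before plotting; the helper `check_required_columns` is hypothetical, and the small DataFrame is a synthetic stand-in, not rows from the Heart Attack Data Set.

```python
import pandas as pd

# Columns referenced by the plotting code in this post.
REQUIRED_COLUMNS = ('age', 'serum cholesterol', 'sex', 'target')


def check_required_columns(df: pd.DataFrame, required=REQUIRED_COLUMNS) -> list:
    """Hypothetical helper: list required columns that are missing or non-numeric."""
    problems = []
    for col in required:
        if col not in df.columns:
            problems.append(f"missing: {col}")
        elif not pd.api.types.is_numeric_dtype(df[col]):
            problems.append(f"non-numeric: {col}")
    return problems


# Synthetic stand-in rows for illustration only (not the real dataset).
sample = pd.DataFrame({
    'age': [45, 54, 63],
    'serum cholesterol': [210, 245, 198],
    'sex': [0, 1, 1],
    'target': [0, 1, 1],
})
print(check_required_columns(sample))  # → []
```

With the real data, an empty list from `check_required_columns(data)` means every column used below is present and numeric.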
1. Correlation Heatmap
Visualisation Goal: Highlight relationships between numerical features.
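```python
import matplotlib.pyplot as plt
import seaborn as sns

# Pairwise Pearson correlations between the numeric columns only.
# In pandas >= 2.0, corr() raises an error if non-numeric columns are present
# unless numeric_only=True is passed.
correlation_matrix = data.corr(numeric_only=True)

plt.figure(figsize=(10, 10))
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm')
plt.xticks(rotation=45, ha='right')  # angle the column labels so they don't overlap
plt.title('Correlation Heatmap')
plt.tight_layout()
plt.savefig('correlation_heatmap.png')
plt.show()
```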
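Each cell of the heatmap is a Pearson correlation coefficient, the default method of `DataFrame.corr()`: r = cov(x, y) / (σx · σy), which ranges from -1 to 1. A small sketch that computes r by hand with NumPy and checks it against pandas; the function name `pearson_r` and the sample values are illustrative assumptions.

```python
import numpy as np
import pandas as pd


def pearson_r(x, y) -> float:
    """Pearson correlation: centred cross-product divided by the product of spreads."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    dx = x - x.mean()
    dy = y - y.mean()
    return float((dx * dy).sum() / np.sqrt((dx ** 2).sum() * (dy ** 2).sum()))


# Synthetic values for illustration only.
df = pd.DataFrame({
    'age': [40, 50, 60, 70],
    'serum cholesterol': [200, 220, 230, 260],
})
manual = pearson_r(df['age'], df['serum cholesterol'])
from_pandas = df.corr(numeric_only=True).loc['age', 'serum cholesterol']
print(round(manual, 4), round(from_pandas, 4))
```

The annotated numbers in the heatmap are this value computed for every pair of numeric columns.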
2. Scatter Plot: Age vs. Serum Cholesterol
Visualisation Goal: Observe the distribution and potential relationships between two variables.
```python
sns.scatterplot(data=data, x='age', y='serum cholesterol', hue='age')
plt.title('Scatter Plot Serum Cholesterol vs. Age')
plt.xlabel('Age')
plt.ylabel('Serum Cholesterol')
plt.savefig('scatter_plot.png')
plt.show()
print("\n")
```
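For comparison, roughly the same scatter plot can be drawn with Matplotlib's object-oriented API alone; Seaborn's `hue` argument takes care of the colour mapping and key that have to be set up by hand here. A sketch, assuming synthetic values in place of `data` and an arbitrarily chosen `viridis` colormap.

```python
import matplotlib.pyplot as plt
import pandas as pd

# Synthetic stand-in for the real DataFrame (illustrative values only).
df = pd.DataFrame({
    'age': [39, 45, 52, 58, 61, 67],
    'serum cholesterol': [204, 236, 199, 275, 240, 262],
})

fig, ax = plt.subplots(figsize=(8, 6))
# c= maps each point's age onto the colormap, similar to hue='age' in Seaborn.
points = ax.scatter(df['age'], df['serum cholesterol'], c=df['age'], cmap='viridis')
fig.colorbar(points, ax=ax, label='Age')  # continuous colour key for the age values
ax.set_title('Scatter Plot Serum Cholesterol vs. Age')
ax.set_xlabel('Age')
ax.set_ylabel('Serum Cholesterol')
fig.savefig('scatter_plot_matplotlib.png')
```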
3. Count Plot: Heart Disease Distribution by Gender
Visualisation Goal: Compare the distribution of heart disease cases across genders.
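```python
# Count the cases for each gender, split by heart-disease status (target).
sns.countplot(x='sex', hue='target', data=data)
plt.title('Distribution of Heart Disease by Gender')
plt.xlabel('Gender (0: Female; 1: Male)')
plt.ylabel('Count')
# The legend labels assume target is coded 0 = no heart disease, 1 = heart disease.
plt.legend(title='Heart Disease', labels=['No', 'Yes'])
plt.savefig('heart_disease_by_gender.png')
plt.show()
```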
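The bar heights in a count plot are plain frequency counts, so the same numbers can be read off as a table with `pd.crosstab`, which is handy for checking the plot. A sketch on a few synthetic rows; like the legend above, it assumes `target` is coded 0 = no heart disease and 1 = heart disease.

```python
import pandas as pd

# Synthetic stand-in rows (illustrative only): sex 0 = female, 1 = male.
df = pd.DataFrame({
    'sex':    [0, 0, 0, 1, 1, 1, 1, 1],
    'target': [0, 1, 1, 0, 0, 0, 1, 1],
})

# Rows are genders, columns are heart-disease status; each cell is one bar's height.
counts = pd.crosstab(df['sex'], df['target'])
counts = counts.rename(index={0: 'Female', 1: 'Male'}, columns={0: 'No', 1: 'Yes'})
print(counts)
```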
4. Box Plot
Visualisation Goal: Summarise the distribution of serum cholesterol at each age (every distinct age value is treated as its own category on the x-axis).
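```python
plt.figure(figsize=(8, 6))
# One box per distinct age: the box spans the interquartile range (IQR), the inner
# line marks the median, and the whiskers reach the most extreme points within 1.5 x IQR.
sns.boxplot(x='age', y='serum cholesterol', data=data, hue='age')
plt.title('Box Plot of Serum Cholesterol by Age')
plt.xlabel('Age')
plt.ylabel('Serum Cholesterol')
plt.savefig('box_plot.png')
plt.show()
```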
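Because `x='age'` gives every distinct age its own box, the plot can become crowded. One option is to bin ages into ranges with `pd.cut` first. The sketch below computes the quartiles and IQR that a box would summarise for each age group; the bin edges, labels and rows are arbitrary, synthetic choices. Passing `x='age group'` to `sns.boxplot` would then draw one box per bin.

```python
import pandas as pd

# Synthetic stand-in rows (illustrative only).
df = pd.DataFrame({
    'age': [34, 38, 41, 47, 52, 55, 58, 63, 66, 71],
    'serum cholesterol': [180, 200, 210, 230, 240, 250, 260, 220, 270, 300],
})

# Hypothetical bins (30, 45], (45, 60], (60, 75]; pd.cut intervals are right-closed by default.
df['age group'] = pd.cut(df['age'], bins=[30, 45, 60, 75], labels=['31-45', '46-60', '61-75'])

# The statistics a box summarises, computed per age group.
grouped = df.groupby('age group', observed=True)['serum cholesterol']
summary = pd.DataFrame({
    'Q1': grouped.quantile(0.25),
    'median': grouped.median(),
    'Q3': grouped.quantile(0.75),
})
summary['IQR'] = summary['Q3'] - summary['Q1']
print(summary)
```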
5. Histogram
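Visualisation Goal: Show the joint frequency distribution of age and serum cholesterol. Because both `x` and `y` are given, `histplot` draws a bivariate (2D) histogram in which each cell is shaded according to how many observations fall in it.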
```python
sns.histplot(data=data, x='age', y='serum cholesterol', hue='age')
plt.title('Histogram Serum Cholesterol vs. Age')
plt.xlabel('Age')
plt.ylabel('Serum Cholesterol')
plt.savefig('histogram.png')
plt.show()
```
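The counting step behind a bivariate histogram can be sketched with `np.histogram2d`, which splits each axis into bins and counts the observations in every (age, cholesterol) cell. The two-bins-per-axis choice and the values are synthetic assumptions for illustration. For a one-dimensional histogram of cholesterol alone, `sns.histplot(data=data, x='serum cholesterol')` drops the `y` argument.

```python
import numpy as np

# Synthetic values for illustration only.
age = np.array([35, 42, 48, 51, 56, 59, 63, 68])
cholesterol = np.array([190, 205, 230, 240, 228, 262, 250, 280])

# Two equal-width bins per axis; counts[i, j] = points in age bin i and cholesterol bin j.
counts, age_edges, chol_edges = np.histogram2d(age, cholesterol, bins=[2, 2])
print(counts)      # → [[3. 1.]
                   #    [1. 3.]]
print(age_edges)   # → [35.  51.5 68. ]
print(chol_edges)  # → [190. 235. 280.]
```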
6. Violin Plot
Visualisation Goal: Combine box plot insights with density estimates.
```python
sns.violinplot(data=data, x='age', y='serum cholesterol', hue='age')
plt.title('Violin Plot Serum Cholesterol vs. Age')
plt.xlabel('Age')
plt.ylabel('Serum Cholesterol')
plt.savefig('violin_plot.png')
plt.show()
```
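The density outline of a violin is a kernel density estimate (KDE): a smooth curve built by placing a Gaussian bump on every observation and averaging them. A minimal NumPy sketch, assuming Scott's rule for the bandwidth (h = s · n^(-1/5), with s the sample standard deviation), which is Seaborn's default bandwidth method; Seaborn's own implementation adds options such as bandwidth adjustment and how far to extend the curve, which are omitted here. The function name and values are illustrative.

```python
import numpy as np


def gaussian_kde_scott(samples, grid):
    """Evaluate a 1-D Gaussian kernel density estimate on `grid` (Scott's-rule bandwidth)."""
    samples = np.asarray(samples, dtype=float)
    grid = np.asarray(grid, dtype=float)
    n = samples.size
    # Scott's rule in one dimension: h = s * n**(-1/5), s = sample std (ddof=1).
    h = samples.std(ddof=1) * n ** (-1 / 5)
    # Standardised distance from every grid point to every sample: shape (len(grid), n).
    z = (grid[:, None] - samples[None, :]) / h
    kernels = np.exp(-0.5 * z ** 2) / np.sqrt(2 * np.pi)
    # Average the bumps and divide by h so the curve integrates to 1.
    return kernels.mean(axis=1) / h


# Synthetic cholesterol values for illustration only.
cholesterol = np.array([198, 204, 215, 226, 233, 241, 250, 268, 275, 302])
grid = np.linspace(100, 400, 601)
density = gaussian_kde_scott(cholesterol, grid)
print(grid[density.argmax()])  # cholesterol value where the estimated density peaks
```

A violin plot draws this curve mirrored on both sides of a vertical axis, one violin per category.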
7. Joint Plot
Visualisation Goal: Show the scatter plot of age against serum cholesterol together with the distribution of each variable along the plot's margins.
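```python
# jointplot creates its own figure and returns a JointGrid, so calling plt.figure()
# beforehand would only leave an extra, empty figure.
g = sns.jointplot(data=data, x='age', y='serum cholesterol', hue='age')
# plt.title() would title only one of the grid's three axes;
# a figure-level suptitle sits above the whole layout instead.
g.figure.suptitle('Jointplot Serum Cholesterol vs. Age')
g.figure.tight_layout()
g.savefig('jointplot.png')
plt.show()
```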
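A joint plot is essentially three axes on one figure: a central scatter plot plus marginal distributions along the top and right edges. A rough sketch of that layout with Matplotlib's `fig.add_gridspec`, using histograms for the margins; the 4:1 size ratio, bin count and randomly generated data are arbitrary assumptions, and Seaborn's `JointGrid` handles much more (hue, KDE margins, spacing).

```python
import matplotlib.pyplot as plt
import numpy as np

# Randomly generated values for illustration only.
rng = np.random.default_rng(0)
age = rng.integers(30, 78, size=100)
cholesterol = 180 + 0.8 * age + rng.normal(0, 25, size=100)

fig = plt.figure(figsize=(8, 8))
# 2 x 2 grid: large joint panel bottom-left, margins top and right, top-right unused.
gs = fig.add_gridspec(2, 2, width_ratios=(4, 1), height_ratios=(1, 4),
                      wspace=0.05, hspace=0.05)
ax_joint = fig.add_subplot(gs[1, 0])
ax_top = fig.add_subplot(gs[0, 0], sharex=ax_joint)
ax_right = fig.add_subplot(gs[1, 1], sharey=ax_joint)

ax_joint.scatter(age, cholesterol, s=15)
ax_top.hist(age, bins=20)
ax_right.hist(cholesterol, bins=20, orientation='horizontal')

# Hide marginal tick labels that would duplicate the joint axes.
ax_top.tick_params(labelbottom=False)
ax_right.tick_params(labelleft=False)

ax_joint.set_xlabel('Age')
ax_joint.set_ylabel('Serum Cholesterol')
fig.suptitle('Joint Layout: Serum Cholesterol vs. Age')
fig.savefig('joint_layout_matplotlib.png')
```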
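8. Linear Model Plot (lmplot): Serum Cholesterol vs. Age
Visualisation Goal: Draw a scatter plot with a fitted linear regression line on a FacetGrid (Seaborn's figure-level grid, which can also split the data into panels or coloured groups).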
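```python
# lmplot builds its own FacetGrid, draws the points and overlays a linear
# regression fit with a confidence band.
# hue='age' is left out: it would fit a separate line for every individual age,
# and since each such group has a single x value, no fitted line would be visible.
sns.lmplot(data=data, x='age', y='serum cholesterol')
plt.title('Lmplot Serum Cholesterol vs. Age')
plt.xlabel('Age')
plt.ylabel('Serum Cholesterol')
plt.savefig('lmplot.png')
plt.show()
```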
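`hue` is most informative with `lmplot` when it names a categorical column such as `sex`: `sns.lmplot(data=data, x='age', y='serum cholesterol', hue='sex')` fits and draws one regression line per gender. The sketch below computes such per-group least-squares lines with `np.polyfit` so the slopes can be compared numerically; the rows are synthetic, constructed so that each group lies exactly on a known line.

```python
import numpy as np
import pandas as pd

# Synthetic rows (illustrative only), built to lie exactly on
# cholesterol = 150 + 1.5 * age (sex 0) and cholesterol = 180 + 1.0 * age (sex 1).
df = pd.DataFrame({
    'age': [40, 50, 60, 40, 50, 60],
    'sex': [0, 0, 0, 1, 1, 1],
})
df['serum cholesterol'] = np.where(df['sex'] == 0, 150 + 1.5 * df['age'], 180 + 1.0 * df['age'])

# One least-squares line per group, like the lines lmplot(hue='sex') draws.
fits = {}
for sex, group in df.groupby('sex'):
    slope, intercept = np.polyfit(group['age'], group['serum cholesterol'], deg=1)
    fits[sex] = (slope, intercept)
    print(f"sex={sex}: fitted cholesterol = {intercept:.1f} + {slope:.2f} * age")
```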
9. Regression Plot: Serum Cholesterol vs. Age
Visualisation Goal: Assess the linear relationship between age and serum cholesterol.
```python
sns.regplot(data=data, x='age', y='serum cholesterol')
plt.title('Regplot Serum Cholesterol vs. Age')
plt.xlabel('Age')
plt.ylabel('Serum Cholesterol')
plt.savefig('regplot.png')
plt.show()
print("\n")
```
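By default, `regplot` draws the least-squares line together with a 95% confidence band estimated by bootstrap resampling (`ci=95`, `n_boot=1000`). The idea can be sketched for the slope alone: refit the line on many resampled copies of the data and take the 2.5th and 97.5th percentiles of the slopes. Seaborn bootstraps the fitted values along the x range rather than the slope itself; the function name, seeds and randomly generated data here are assumptions for illustration.

```python
import numpy as np


def bootstrap_slope_ci(x, y, n_boot=1000, ci=95, seed=0):
    """Percentile bootstrap confidence interval for the least-squares slope of y on x."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    rng = np.random.default_rng(seed)
    n = x.size
    slopes = np.empty(n_boot)
    for i in range(n_boot):
        idx = rng.integers(0, n, size=n)  # resample rows with replacement
        slopes[i] = np.polyfit(x[idx], y[idx], deg=1)[0]
    tail = (100 - ci) / 2
    low, high = np.percentile(slopes, [tail, 100 - tail])
    return low, high


# Randomly generated data for illustration only: true slope 0.8 plus noise.
rng = np.random.default_rng(42)
age = rng.uniform(30, 77, size=200)
cholesterol = 200 + 0.8 * age + rng.normal(0, 20, size=200)
low, high = bootstrap_slope_ci(age, cholesterol)
print(f"95% bootstrap CI for the slope: [{low:.2f}, {high:.2f}]")
```

A wide interval, or one that straddles zero, suggests the apparent trend in the regression plot may not be reliable.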
10. Residual Plot: Serum Cholesterol vs. Age
Visualisation Goal: Plot the residuals of a linear regression. `residplot` regresses y on x and then draws a scatter plot of the residuals against x, which helps check whether a straight-line model is appropriate.
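```python
# Regress serum cholesterol on age, then plot each observation's residual
# (observed minus fitted value) against age.
sns.residplot(data=data, x='age', y='serum cholesterol')
plt.title('Residplot Serum Cholesterol vs. Age')
plt.xlabel('Age')
plt.ylabel('Residual (Serum Cholesterol)')
plt.savefig('residplot.png')
plt.show()
print("\n")
```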
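The values a residual plot shows can be reproduced in a few lines: fit a straight line of serum cholesterol on age by least squares, then subtract the fitted values from the observed ones. If the linear model fits well, the residuals scatter randomly around zero with no visible pattern. A NumPy sketch on synthetic values:

```python
import numpy as np

# Synthetic values for illustration only.
age = np.array([35, 41, 47, 52, 58, 63, 69], dtype=float)
cholesterol = np.array([205, 212, 236, 229, 251, 262, 258], dtype=float)

# Ordinary least-squares line: fitted cholesterol = intercept + slope * age.
slope, intercept = np.polyfit(age, cholesterol, deg=1)
fitted = intercept + slope * age

# Residual = observed - fitted; these are the y-values plotted against age.
residuals = cholesterol - fitted
for a, r in zip(age, residuals):
    print(f"age {a:.0f}: residual {r:+.1f}")
```

Because the fit includes an intercept, the residuals always sum to (numerically) zero, so the useful signal in the plot is any curvature or funnel shape rather than the overall level.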