
Visualising Data with Matplotlib and Seaborn: A Comprehensive Guide

  • Writer: Vusi Kubheka
  • Nov 19, 2024
  • 2 min read

Data visualisation is an essential skill for data scientists and analysts, enabling them to uncover patterns, trends, and relationships hidden within datasets. This blog post introduces two powerful Python libraries, Matplotlib and Seaborn, using a dataset related to heart attack risk factors.


Why Matplotlib and Seaborn?


  • Matplotlib: A versatile library suitable for creating static, interactive, and animated visualisations.


  • Seaborn: Built on top of Matplotlib, Seaborn simplifies the creation of aesthetically pleasing and informative statistical graphics.


Below is a walk-through of various visualisation techniques and their corresponding Python code:


Dataset Overview


We will use a dataset called Heart Attack Data Set, loaded into a Pandas DataFrame.


import pandas as pd

# Reading .xlsx files requires the openpyxl package to be installed
data = pd.read_excel('Heart Attack Data Set spreadsheet.xlsx')
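
Before plotting, it can help to confirm that the columns used later in this post are present and free of missing values. The sketch below runs on a small hand-made stand-in DataFrame: the values are hypothetical, and only the column names (age, sex, serum cholesterol, target) are taken from the plotting code in this post. With the real spreadsheet, the same checks would be run on data instead of sample.

```python
import pandas as pd

# Hypothetical stand-in rows (made-up values); only the column names
# mirror those used by the plots in this post.
sample = pd.DataFrame({
    'age': [45, 52, 61, 38, 70],
    'sex': [1, 0, 1, 0, 1],
    'serum cholesterol': [210, 250, 289, 198, 305],
    'target': [0, 1, 1, 0, 1],
})

# Columns that the later plotting calls rely on.
required = ['age', 'sex', 'serum cholesterol', 'target']
missing = [col for col in required if col not in sample.columns]

print(sample.head())         # first rows of the table
print(sample.dtypes)         # data type of each column
print(sample.isna().sum())   # number of missing values per column
print('Missing columns:', missing)
```

If missing is not empty, the column names in the spreadsheet differ from the ones used here and the plotting calls would need to be adjusted.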

1. Correlation Heatmap


Visualisation Goal: Highlight relationships between numerical features.



import matplotlib.pyplot as plt
import seaborn as sns

# numeric_only=True skips non-numeric columns, which would otherwise
# raise an error in pandas 2.0 and later
correlation_matrix = data.corr(numeric_only=True)
plt.figure(figsize=(10, 10))
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm')
plt.xticks(rotation=45, ha='right')
plt.title('Correlation Heatmap')
plt.tight_layout()
plt.savefig('correlation_heatmap.png')
plt.show()
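
A correlation matrix is symmetric, so half of the heatmap repeats the other half. One possible refinement, sketched below, is to hide the upper triangle with a boolean mask. The data here is a randomly generated stand-in (hypothetical values, column names borrowed from this post), and the Agg backend is selected only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend; omit this in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical stand-in data; the real post uses the heart attack spreadsheet.
rng = np.random.default_rng(0)
sample = pd.DataFrame({
    'age': rng.integers(30, 80, size=50),
    'serum cholesterol': rng.integers(150, 350, size=50),
    'target': rng.integers(0, 2, size=50),
})

corr = sample.corr(numeric_only=True)

# True marks a cell to hide: everything strictly above the diagonal.
mask = np.triu(np.ones_like(corr, dtype=bool), k=1)

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(corr, mask=mask, annot=True, fmt='.2f',
            cmap='coolwarm', vmin=-1, vmax=1, ax=ax)
ax.set_title('Correlation Heatmap (lower triangle)')
fig.tight_layout()
# Call plt.show() or fig.savefig(...) here to view or keep the figure.
```

Fixing vmin and vmax at -1 and 1 keeps the colour scale centred on zero, so positive and negative correlations are directly comparable.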



2. Scatter Plot: Age vs. Serum Cholesterol


Visualisation Goal: Observe the distribution and potential relationships between two variables.

# hue='age' colours each point by age, adding a second visual cue for the x-axis
sns.scatterplot(data=data, x='age', y='serum cholesterol', hue='age')
plt.title('Scatter Plot: Serum Cholesterol vs. Age')
plt.xlabel('Age')
plt.ylabel('Serum Cholesterol')
plt.savefig('scatter_plot.png')
plt.show()



3. Count Plot: Heart Disease Distribution by Gender


Visualisation Goal: Compare the distribution of heart disease cases across genders.


sns.countplot(x='sex', hue='target', data=data)
plt.title('Distribution of Heart Disease by Gender')
plt.xlabel('Gender (0: Female; 1: Male)')
plt.ylabel('Count')
plt.legend(title='Heart Disease', labels=['No', 'Yes'])
plt.savefig('heart_disease_by_gender.png')
plt.show()
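
The bar heights in a count plot can be checked against a plain table of counts. The sketch below uses pd.crosstab on a tiny hand-made stand-in (hypothetical values; the column names sex and target follow the code above) to produce the same counts that countplot would draw, plus the share of each outcome within each gender.

```python
import pandas as pd

# Hypothetical stand-in rows; 0/1 coding follows the axis label used above.
sample = pd.DataFrame({
    'sex':    [0, 0, 0, 1, 1, 1, 1, 1],
    'target': [0, 1, 1, 0, 0, 1, 1, 1],
})

# Rows are sex, columns are target; each cell is one bar of the count plot.
counts = pd.crosstab(sample['sex'], sample['target'])
print(counts)

# Row-wise proportions make groups of different sizes comparable.
shares = pd.crosstab(sample['sex'], sample['target'], normalize='index')
print(shares.round(2))
```

Proportions are often more informative than raw counts when one gender is much better represented in the dataset than the other.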



4. Box Plot


Visualisation Goal: Summarise the distribution of cholesterol levels across age groups.


plt.figure(figsize=(8, 6))
sns.boxplot(x='age', y='serum cholesterol', data=data, hue='age')
plt.title('Box Plot of Serum Cholesterol by Age')
plt.xlabel('Age')
plt.ylabel('Serum Cholesterol')
plt.savefig('box_plot.png')
plt.show()
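
With one box per individual age, the plot can become crowded. A possible alternative, sketched below, is to bin ages into decades with pd.cut and draw one box per group. The data, the bin edges, and the 'age group' column name are all assumptions made for illustration, and the Agg backend is selected only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend; omit this in a notebook
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical stand-in data.
sample = pd.DataFrame({
    'age': [34, 41, 47, 52, 58, 63, 66, 71, 45, 55],
    'serum cholesterol': [190, 205, 230, 244, 260, 281, 270, 300, 215, 250],
})

# Assumed decade-wide bins; right=False makes each bin include its lower edge.
bins = [30, 40, 50, 60, 70, 80]
labels = ['30-39', '40-49', '50-59', '60-69', '70-79']
sample['age group'] = pd.cut(sample['age'], bins=bins, labels=labels, right=False)

fig, ax = plt.subplots(figsize=(8, 6))
sns.boxplot(data=sample, x='age group', y='serum cholesterol', ax=ax)
ax.set_title('Box Plot of Serum Cholesterol by Age Group')
ax.set_xlabel('Age Group')
ax.set_ylabel('Serum Cholesterol')
# Call plt.show() or fig.savefig(...) here to view or keep the figure.
```

The bin edges should be chosen to cover the full age range in the real data; any age outside the bins becomes a missing group.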



5. Histogram


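Visualisation Goal: Show the joint frequency distribution of age and serum cholesterol. Passing both x and y to histplot draws a bivariate histogram, in which each rectangle covers one bin of age and cholesterol values.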


sns.histplot(data=data, x='age', y='serum cholesterol', hue='age')
plt.title('Histogram Serum Cholesterol vs. Age')
plt.xlabel('Age')
plt.ylabel('Serum Cholesterol')
plt.savefig('histogram.png')
plt.show()



6. Violin Plot


Visualisation Goal: Combine box plot insights with density estimates.


sns.violinplot(data=data, x='age', y='serum cholesterol', hue='age')
plt.title('Violin Plot Serum Cholesterol vs. Age')
plt.xlabel('Age')
plt.ylabel('Serum Cholesterol')
plt.savefig('violin_plot.png')
plt.show()



7. Joint Plot


Visualisation Goal: Combine scatter and histogram views for deeper insights.


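# jointplot creates its own figure, so its size is set with height
# rather than with a separate plt.figure() call
g = sns.jointplot(data=data, x='age', y='serum cholesterol', hue='age', height=7)
# Title the whole figure; plt.title would only label one of the sub-axes
g.fig.suptitle('Joint Plot: Serum Cholesterol vs. Age', y=1.02)
g.fig.savefig('jointplot.png', bbox_inches='tight')
plt.show()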



8. Lmplot: Serum Cholesterol vs. Age


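Visualisation Goal: Draw a scatter plot with a fitted regression line onto a FacetGrid.

# hue is omitted here: lmplot fits one regression per hue level, and
# every age level contains a single x value, so no line could be fitted
sns.lmplot(data=data, x='age', y='serum cholesterol')
plt.title('Lmplot: Serum Cholesterol vs. Age')
plt.xlabel('Age')
plt.ylabel('Serum Cholesterol')
plt.savefig('lmplot.png', bbox_inches='tight')
plt.show()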



9. Regression Plot: Serum Cholesterol vs. Age


Visualisation Goal: Assess the linear relationship between age and serum cholesterol.


# Scatter plot with a fitted regression line and its confidence band
sns.regplot(data=data, x='age', y='serum cholesterol')
plt.title('Regplot: Serum Cholesterol vs. Age')
plt.xlabel('Age')
plt.ylabel('Serum Cholesterol')
plt.savefig('regplot.png')
plt.show()
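
A regression plot shows the fitted line but not its numbers. To put a figure on the relationship, a straight line can be fitted by least squares with np.polyfit, and the strength of the linear association read from the Pearson correlation. The sketch below uses a small hand-made stand-in (hypothetical values); it is a companion calculation, not a description of how regplot works internally.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in data.
sample = pd.DataFrame({
    'age': [40, 45, 50, 55, 60, 65],
    'serum cholesterol': [205, 218, 226, 241, 250, 262],
})

# Degree-1 polynomial fit: coefficients come back highest power first.
slope, intercept = np.polyfit(sample['age'], sample['serum cholesterol'], deg=1)

# Pearson correlation between the two columns.
r = sample['age'].corr(sample['serum cholesterol'])

print(f'slope = {slope:.2f} cholesterol units per year of age')
print(f'intercept = {intercept:.2f}')
print(f'Pearson r = {r:.3f}')
```

The slope gives the average change in serum cholesterol per additional year of age, in whatever units the dataset uses.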


10. Residual Plot

Visualisation Goal: Plot the residuals of a linear regression. residplot regresses y on x and then draws a scatter plot of the residuals.


# The y-axis shows residuals (observed minus fitted), not raw cholesterol values
sns.residplot(data=data, x='age', y='serum cholesterol')
plt.title('Residplot: Serum Cholesterol vs. Age')
plt.xlabel('Age')
plt.ylabel('Residuals')
plt.savefig('residplot.png')
plt.show()
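
To make the idea of a residual concrete, the same quantity can be computed by hand: fit a straight line, then subtract the fitted value from each observed value. The sketch below does this on a small hand-made stand-in (hypothetical values) using np.polyfit; it illustrates the calculation and is not taken from seaborn's own code.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in data.
sample = pd.DataFrame({
    'age': [40, 45, 50, 55, 60, 65],
    'serum cholesterol': [205, 218, 226, 241, 250, 262],
})

x = sample['age']
y = sample['serum cholesterol']

# Least-squares straight line through the points.
slope, intercept = np.polyfit(x, y, deg=1)

# Residual = observed value minus the value predicted by the line.
fitted = slope * x + intercept
residuals = y - fitted

print(residuals.round(2))
print('Sum of residuals:', round(residuals.sum(), 6))
```

For a least-squares line with an intercept, the residuals sum to zero up to rounding error. In a residual plot, points scattered evenly around zero with no visible curve or funnel shape suggest that a straight line is a reasonable description of the data.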










