Python Machine Learning – Scatter Plot
In machine learning, scatter plots are a useful way to visualize the relationship between two features (or variables) in your dataset. Scatter plots plot individual data points on a two-dimensional plane, with one feature on the x-axis and the other on the y-axis. They are commonly used for exploratory data analysis (EDA) to observe potential relationships, correlations, patterns, and outliers in the data.
1. Scatter Plot Basics
A scatter plot displays data points as dots on a graph. Each dot represents one observation from the dataset, with its position determined by the values of two features: one for the x-axis and one for the y-axis.
Example: Simple Scatter Plot in Python
import matplotlib.pyplot as plt
# Example data: two features
x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
y = [2, 3, 4, 6, 7, 8, 9, 10, 12]
# Create a scatter plot
plt.scatter(x, y)
# Add labels and title
plt.title('Simple Scatter Plot')
plt.xlabel('Feature X')
plt.ylabel('Feature Y')
# Show the plot
plt.show()
In this example:
- plt.scatter() creates the scatter plot.
- The x list represents the x-coordinates of the points (feature X), and the y list represents the y-coordinates (feature Y).
2. Scatter Plot in Machine Learning Context
In machine learning, scatter plots help you understand how two variables relate to each other, which is especially useful in supervised learning tasks such as regression or classification. The typical patterns to look for are:
- Positive Correlation: When one variable increases, the other tends to increase.
- Negative Correlation: When one variable increases, the other tends to decrease.
- No Correlation: No visible pattern between the two variables.
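The three patterns above can also be checked numerically. The following is a minimal sketch that uses the Pearson correlation coefficient from NumPy's np.corrcoef; the y_neg and y_none arrays and the pearson_r helper are hypothetical, introduced here only for illustration.

```python
import numpy as np

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
y_pos = np.array([2, 3, 4, 6, 7, 8, 9, 10, 12])  # rises with x (data from the first example)
y_neg = y_pos[::-1]                               # hypothetical: falls as x rises
y_none = np.array([4, 8, 2, 6, 5, 6, 2, 8, 4])    # hypothetical: no linear trend


def pearson_r(a, b):
    """Return the Pearson correlation coefficient of two 1-D arrays."""
    return float(np.corrcoef(a, b)[0, 1])


r_pos = pearson_r(x, y_pos)    # close to +1: positive correlation
r_neg = pearson_r(x, y_neg)    # close to -1: negative correlation
r_none = pearson_r(x, y_none)  # close to 0: no linear correlation

print(f"positive: {r_pos:.2f}, negative: {r_neg:.2f}, none: {r_none:.2f}")
```

Keep in mind that the coefficient only measures linear association, so a scatter plot can still reveal a curved pattern even when the value is near zero.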
3. Scatter Plot with Labels/Colors
In classification problems, scatter plots can be enhanced by coloring the data points according to their class labels. This helps visualize the class separation in the feature space.
Example: Scatter Plot with Class Labels
import matplotlib.pyplot as plt
# Example data with class labels
x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
y = [2, 3, 4, 6, 7, 8, 9, 10, 12]
labels = [0, 0, 1, 1, 1, 0, 1, 0, 0] # Class labels (0 or 1)
# Create a scatter plot with color based on labels
plt.scatter(x, y, c=labels, cmap='coolwarm')
# Add labels and title
plt.title('Scatter Plot with Class Labels')
plt.xlabel('Feature X')
plt.ylabel('Feature Y')
# Show the plot
plt.show()
In this example:
- The c=labels parameter colors the data points based on their class labels.
- cmap='coolwarm' specifies the color map used for the different classes.
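When the plot needs a legend that names each class, one possible approach is to draw each class with its own scatter call. This is a sketch of an alternative to the c=labels technique, not a replacement for it; the "Class 0" and "Class 1" legend names are placeholders.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
y = np.array([2, 3, 4, 6, 7, 8, 9, 10, 12])
labels = np.array([0, 0, 1, 1, 1, 0, 1, 0, 0])  # Class labels (0 or 1)

fig, ax = plt.subplots()

# Draw one scatter series per class so each gets its own legend entry
for cls in np.unique(labels):
    mask = labels == cls
    ax.scatter(x[mask], y[mask], label=f'Class {cls}')

ax.set_title('Scatter Plot with Class Legend')
ax.set_xlabel('Feature X')
ax.set_ylabel('Feature Y')
ax.legend()

plt.show()
```

Matplotlib assigns each series the next color from its default color cycle, so no color map is needed here.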
4. Scatter Plot with Size and Color
Scatter plots can also use the size and color of the data points to represent additional variables. For example, you can make the data points larger or smaller based on the magnitude of a third feature, or color the points based on a fourth feature.
Example: Scatter Plot with Size and Color
import matplotlib.pyplot as plt
# Example data with third and fourth features
x = [1, 2, 3, 4, 5, 6, 7, 8, 9]
y = [2, 3, 4, 6, 7, 8, 9, 10, 12]
sizes = [50, 80, 90, 200, 100, 150, 180, 300, 400] # Point sizes (third feature)
colors = [5, 20, 35, 50, 65, 80, 95, 110, 125] # Color values (fourth feature)
# Create a scatter plot with size and color variation
plt.scatter(x, y, s=sizes, c=colors, cmap='viridis', alpha=0.6, edgecolors='black')
# Add labels and title
plt.title('Scatter Plot with Size and Color')
plt.xlabel('Feature X')
plt.ylabel('Feature Y')
# Show the plot
plt.show()
In this example:
- s=sizes adjusts the size of each point based on the sizes list.
- c=colors colors the data points based on the colors list.
- alpha=0.6 sets the transparency of the points, and edgecolors='black' adds an outline to each point for better visibility.
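Raw feature values rarely fall in a range that works well as marker sizes, since Matplotlib interprets s as the marker area in points squared. The sketch below shows one way to rescale a feature linearly into a usable size range. The scale_to_sizes helper, its default bounds, and the third_feature values are hypothetical and chosen only for illustration.

```python
import numpy as np


def scale_to_sizes(values, min_size=20.0, max_size=400.0):
    """Linearly rescale feature values to marker areas between min_size and max_size."""
    values = np.asarray(values, dtype=float)
    span = values.max() - values.min()
    if span == 0:
        # All values are equal: use a single mid-range size for every point
        return np.full(values.shape, (min_size + max_size) / 2)
    return min_size + (values - values.min()) / span * (max_size - min_size)


third_feature = np.array([0.2, 1.5, 3.1, 0.9, 2.4])  # hypothetical raw feature values
sizes = scale_to_sizes(third_feature)
print(sizes)
```

The resulting array can be passed directly as s=sizes in plt.scatter().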
5. Scatter Plot for Regression
Scatter plots are useful in regression analysis to visualize the relationship between the independent (input) and dependent (output) variables. You can also plot a regression line to observe how well the data fits a linear model.
Example: Scatter Plot with Regression Line
import numpy as np
import matplotlib.pyplot as plt
# Example data
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
y = np.array([2, 3, 4, 6, 7, 8, 9, 10, 12])
# Create a scatter plot
plt.scatter(x, y)
# Calculate the regression line
slope, intercept = np.polyfit(x, y, 1)
regression_line = slope * x + intercept
# Plot the regression line
plt.plot(x, regression_line, color='red')
# Add labels and title
plt.title('Scatter Plot with Regression Line')
plt.xlabel('Feature X')
plt.ylabel('Feature Y')
# Show the plot
plt.show()
In this example:
- np.polyfit(x, y, 1) calculates the slope and intercept of the regression line.
- The regression line is plotted over the scatter plot in red.
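To put a number on how well the line fits, you can compute the coefficient of determination (R squared) from the same fit. This is a minimal sketch that reuses the example data; the variable names ss_res, ss_tot, and r_squared are introduced here for illustration.

```python
import numpy as np

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
y = np.array([2, 3, 4, 6, 7, 8, 9, 10, 12])

# Fit a straight line (degree-1 polynomial) by least squares
slope, intercept = np.polyfit(x, y, 1)
predicted = slope * x + intercept

# R^2 = 1 - (residual sum of squares / total sum of squares)
ss_res = np.sum((y - predicted) ** 2)
ss_tot = np.sum((y - y.mean()) ** 2)
r_squared = 1 - ss_res / ss_tot

print(f"slope={slope:.3f}, intercept={intercept:.3f}, R^2={r_squared:.3f}")
```

A value close to 1 indicates that the line explains most of the variation in y, which matches the tight linear pattern visible in the scatter plot.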
6. Scatter Plot for Multivariate Data
If you have more than two features in your dataset, you can use 3D scatter plots to visualize three variables simultaneously. For more than three dimensions, you’ll need advanced visualization techniques like pair plots or parallel coordinates.
Example: 3D Scatter Plot
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # Only needed on older Matplotlib versions to register the '3d' projection
import numpy as np
# Example data with three features
x = np.random.rand(100)
y = np.random.rand(100)
z = np.random.rand(100)
# Create a 3D scatter plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x, y, z)
# Add labels and title
ax.set_title('3D Scatter Plot')
ax.set_xlabel('Feature X')
ax.set_ylabel('Feature Y')
ax.set_zlabel('Feature Z')
# Show the plot
plt.show()
In this example, the third feature (z) is represented along the z-axis in a 3D plot, helping you visualize the relationship among three features.
7. Scatter Matrix
For datasets with more than two features, you can use a scatter matrix (also called a pair plot). A scatter matrix shows a scatter plot for every pair of features, allowing you to see correlations or relationships between each pair.
Example: Scatter Matrix Using Seaborn
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# Create a DataFrame with multiple features
data = {
'Feature X': [1, 2, 3, 4, 5, 6, 7, 8, 9],
'Feature Y': [2, 3, 4, 6, 7, 8, 9, 10, 12],
'Feature Z': [5, 3, 6, 8, 7, 9, 11, 15, 18]
}
df = pd.DataFrame(data)
# Create a scatter matrix
sns.pairplot(df)
# Show the plot
plt.show()
This scatter matrix shows pairwise scatter plots between Feature X, Feature Y, and Feature Z. It provides an overview of the relationships among all the features.
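A similar grid can also be drawn without Seaborn by using pandas' built-in pandas.plotting.scatter_matrix function. The sketch below reuses the same DataFrame; the figsize and diagonal settings are illustrative choices rather than requirements.

```python
import pandas as pd
import matplotlib.pyplot as plt

# Same DataFrame as in the Seaborn example
df = pd.DataFrame({
    'Feature X': [1, 2, 3, 4, 5, 6, 7, 8, 9],
    'Feature Y': [2, 3, 4, 6, 7, 8, 9, 10, 12],
    'Feature Z': [5, 3, 6, 8, 7, 9, 11, 15, 18],
})

# Draw pairwise scatter plots, with a histogram of each feature on the diagonal
axes = pd.plotting.scatter_matrix(df, figsize=(6, 6), diagonal='hist')

plt.show()
```

The function returns a 2D array of Matplotlib axes, one per feature pair, which you can use to customize individual panels.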
Conclusion
Scatter plots are a versatile tool in machine learning for understanding relationships between features, detecting patterns, and identifying potential outliers. They are essential in the exploratory data analysis phase, helping to inform feature selection and data preprocessing steps. You can create scatter plots using libraries like Matplotlib, Seaborn, and Pandas, and further enhance them by adding regression lines, coloring points based on class labels, or representing additional features with size and color.