Classification problems with imbalanced inputs#
Oftentimes when we're doing real-world classification problems, we have the problem of "imbalanced classes".
Let's say we're analyzing a document dump, and trying to find the documents that are interesting to us. Maybe we're only interested in 10% of them! The fact that there's such a bias - 90% of them are uninteresting - will mess with our classifier. Let's take a look at the imbalanced-learn library to help fix this problem!
We're going to go through this pretty quickly, so you should be familiar with vectorizing, classification, and confusion matrices going in.
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import LinearSVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
Our datasets#
We're going to be looking at two datasets today. They're both recipes and ingredient lists, and with both we're testing whether we can accurately determine which recipes are Indian.
Let's read them both in.
df_balanced = pd.read_csv("data/recipes-indian.csv")
df_balanced['is_indian'] = (df_balanced.cuisine == "indian").astype(int)
df_balanced.head()
df_unbalanced = pd.read_csv("data/recipes.csv")
df_unbalanced['is_indian'] = (df_unbalanced.cuisine == "indian").astype(int)
df_unbalanced.head()
They both look similar enough, right? A list of ingredients and an is_indian target column we'll be using as our label.
Finding the imbalance#
The real difference is how many of the recipes are Indian in each dataset. Let's take a look:
df_balanced.is_indian.value_counts()
df_unbalanced.is_indian.value_counts()
Ouch! That second dataset is really uneven - over ten times as many non-Indian recipes as there are Indian recipes!
The thing is: this is usually how data looks in the real world. You rarely have even numbers between your classes, and you often think "more data is better data." We'll see how it plays out when we actually run our classifiers!
Testing our datasets#
We're going to use a TfidfVectorizer to convert ingredient lists to numbers, run a test/train split, and then train (and test) a LinearSVC classifier on the results. We'll start with the balanced dataset.
Balanced dataset#
# Create a vectorizer and train it
vectorizer = TfidfVectorizer()
matrix = vectorizer.fit_transform(df_balanced.ingredient_list)
# Features are our matrix of tf-idf values
# labels are whether each recipe is Indian or not
X = matrix
y = df_balanced.is_indian
# How many are Indian?
y.value_counts()
We still have an even split, 3000 non-Indian recipes and 3000 Indian recipes. Let's run a test/train split and see how the results look.
# Split into test and train data
X_train, X_test, y_train, y_test = train_test_split(X, y)
# Build a classifier and train it
clf = LinearSVC()
clf.fit(X_train, y_train)
# Test our classifier and build a confusion matrix
y_true = y_test
y_pred = clf.predict(X_test)
matrix = confusion_matrix(y_true, y_pred)
label_names = pd.Series(['not indian', 'indian'])
pd.DataFrame(matrix,
columns='Predicted ' + label_names,
index='Is ' + label_names).div(matrix.sum(axis=1), axis=0)
Our classifier looks pretty good! Around 96% accuracy for predicting non-Indian food, and around 95% correctly predicting Indian food. High quality and even.
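If the `.div(matrix.sum(axis=1), axis=0)` step in that cell looks mysterious: it divides each row of the confusion matrix by that row's total, so every row sums to 1 and the diagonal shows what share of each true class we got right. Here's a minimal sketch of the idea. The counts are made up for illustration; they are not the results from our recipes.

```python
import numpy as np
import pandas as pd

# Hypothetical confusion matrix counts, NOT the real results from above.
# Rows are the true class, columns are the predicted class.
counts = np.array([
    [720, 30],   # truly not Indian: 720 predicted not Indian, 30 predicted Indian
    [40, 710],   # truly Indian: 40 predicted not Indian, 710 predicted Indian
])

label_names = pd.Series(['not indian', 'indian'])
table = pd.DataFrame(counts,
                     columns='Predicted ' + label_names,
                     index='Is ' + label_names)

# Divide every row by that row's total so each row sums to 1
normalized = table.div(counts.sum(axis=1), axis=0)
print(normalized)
```

Reading along the diagonal of the normalized table gives the per-class hit rate, which is what we're comparing between datasets.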
Let's move on to see how it looks with our unbalanced dataset.
Unbalanced dataset#
# Create a vectorizer and train it
vectorizer = TfidfVectorizer()
matrix = vectorizer.fit_transform(df_unbalanced.ingredient_list)
# Features are our matrix of tf-idf values
# labels are whether each recipe is Indian or not
X = matrix
y = df_unbalanced.is_indian
# How many are Indian?
y.value_counts()
Again: around 36k non-Indian recipes really really outweighing the 3,003 Indian recipes. While we love the world of "more more more data," let's see what that imbalance does to our classifier.
# Split our dataset into train and test data
X_train, X_test, y_train, y_test = train_test_split(X, y)
# Train the classifier on the training data
clf = LinearSVC()
clf.fit(X_train, y_train)
# Test our classifier and build a confusion matrix
y_true = y_test
y_pred = clf.predict(X_test)
matrix = confusion_matrix(y_true, y_pred)
label_names = pd.Series(['not indian', 'indian'])
pd.DataFrame(matrix,
columns='Predicted ' + label_names,
index='Is ' + label_names).div(matrix.sum(axis=1), axis=0)
Ouch!!! While we're doing really well at predicting non-Indian dishes, our ability to predict Indian dishes has plummeted to just over 80%.
Why does this happen? An easy way to think about it: when the classifier isn't sure, the safest guess is always "not Indian." In fact, if we always guessed non-Indian, no matter what, we'd be right...
36771/(36771+3003)
About 92% of the time! So how do we solve this problem?
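Before we get to the fix, it's worth spelling out why that 92% is misleading. The sketch below uses the two class counts from the unbalanced dataset and scores an imaginary "classifier" that always answers "not Indian". The variable names are just for illustration.

```python
# Class counts taken from the unbalanced dataset above
n_not_indian = 36771
n_indian = 3003

# An imaginary classifier that always answers "not Indian":
# it gets every non-Indian recipe right and every Indian recipe wrong
total = n_not_indian + n_indian
accuracy = n_not_indian / total

# Share of each true class that is predicted correctly
hit_rate_not_indian = n_not_indian / n_not_indian
hit_rate_indian = 0 / n_indian

print(f"Overall accuracy: {accuracy:.3f}")
print(f"Non-Indian recipes found: {hit_rate_not_indian:.0%}")
print(f"Indian recipes found: {hit_rate_indian:.0%}")
```

High overall accuracy, but it never finds a single Indian recipe. That's why we look at each row of the confusion matrix instead of one overall number.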
Solving the problem#
Solving the problem of unbalanced (or biased) input classes is actually not too hard! There's a nice library that can give us a hand, imbalanced-learn.
imbalanced-learn will resample our dataset, either generating new datapoints or pruning out existing datapoints, until the classes are evened out.
What do we resample?#
An important thing to note is that the problem with bias happens when we train our model. If we show our model a skewed view of the world, it'll carry that bias when making judgments in the future. When we add or remove datapoints to even out the problem, we only need to do this for the training data.
We want to show the model an even view of the world, so we give it even data. The test data should still reflect the "real" world. Before we were looking at how imbalanced our overall dataset was, but now let's just look at how biased the training data is.
y_train.value_counts()
y_train.value_counts(normalize=True)
Looks like a little over 7% of our training data is Indian - we'd like to get that up to 50%, so let's see what the imbalanced-learn library can do for us!
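To make the "only resample the training data" idea concrete, here's a hand-rolled sketch using plain NumPy. Everything in it is an assumption for illustration: the labels are made up (about 7.5% Indian), and the balancing step is a simple stand-in, not what imbalanced-learn does internally.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical labels: 1 = Indian, 0 = not Indian (7.5% Indian)
y = np.array([1] * 300 + [0] * 3700)

# Split FIRST: shuffle the row positions, keep 75% for training, 25% for testing
order = rng.permutation(len(y))
cut = int(len(y) * 0.75)
train_idx, test_idx = order[:cut], order[cut:]

# Balance ONLY the training rows by randomly dropping non-Indian rows
train_pos = train_idx[y[train_idx] == 1]
train_neg = train_idx[y[train_idx] == 0]
kept_neg = rng.choice(train_neg, size=len(train_pos), replace=False)
balanced_train_idx = np.concatenate([train_pos, kept_neg])

print("Training share Indian:", y[balanced_train_idx].mean())
print("Test share Indian:", y[test_idx].mean())
```

The training rows end up split evenly, while the test rows keep roughly the same skew as the full dataset, which is what we want when we measure how the classifier would do in the "real" world.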
Undersampling#
If we're feeling guilty that there are so many additional non-Indian recipes, we could always get rid of those extra non-Indian recipes! In fact, the balanced dataset was me manually creating a new CSV from an even split of Indian/non-Indian recipes.
Instead of manually digging through our dataset to even things out, though, we can rely on imbalanced-learn to do it automatically. We'll use the technique of under sampling to take those ~28k non-Indian recipes and randomly filter them down to around 2,000 to match the number of Indian recipes. (Remember we're only doing this with training data!)
from imblearn.under_sampling import RandomUnderSampler
resampler = RandomUnderSampler()
# Resample X and y so there are equal numbers of each y
X_train_resampled, y_train_resampled = resampler.fit_resample(X_train, y_train)
y_train_resampled.value_counts()
Okay, cool, equal numbers! Let's see how the classifier performs.
# We already split our data, so we don't need to do that again
# Train the classifier on the resampled training data
clf = LinearSVC()
clf.fit(X_train_resampled, y_train_resampled)
# Build a confusion matrix
y_true = y_test
y_pred = clf.predict(X_test)
matrix = confusion_matrix(y_true, y_pred)
label_names = pd.Series(['not indian', 'indian'])
pd.DataFrame(matrix,
columns='Predicted ' + label_names,
index='Is ' + label_names).div(matrix.sum(axis=1), axis=0)
Looking good! It performs as well as our other 3,000/3,000 split because, well, it's more or less the same thing (although the test data is "realistically" unbalanced).
Oversampling#
Cutting out those 27,000 "extra" non-Indian recipes seems like such a bummer, though. Wouldn't it be nice if we somehow found another 25,000 Indian recipes to even up our unbalanced training dataset to 27k non-Indian and 27k Indian? It's possible with oversampling!
Oversampling generates new datapoints based on your existing dataset. In this case we're going to use the RandomOverSampler, which just fills our dataset with copies of the less-included class. We'll have 27k Indian recipes, but around 25,000 of them will be copies of the original ones. Can that possibly help?
from imblearn.over_sampling import RandomOverSampler
resampler = RandomOverSampler()
X_train_resampled, y_train_resampled = resampler.fit_resample(X_train, y_train)
y_train_resampled.value_counts()
Looking good, a nice even 27,599 apiece. Let's see how the classifier works out!
# We already split our dataset into train and test data
# Train the classifier on the resampled training data
clf = LinearSVC()
clf.fit(X_train_resampled, y_train_resampled)
# Build a confusion matrix with the result
y_true = y_test
y_pred = clf.predict(X_test)
matrix = confusion_matrix(y_true, y_pred)
label_names = pd.Series(['not indian', 'indian'])
pd.DataFrame(matrix,
columns='Predicted ' + label_names,
index='Is ' + label_names).div(matrix.sum(axis=1), axis=0)
Also looking pretty good! A little bit better at predicting non-Indian dishes and a little bit worse at predicting Indian dishes, but it more or less evens out with the under sampled example.
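If "copies of the less-included class" feels abstract, here's a tiny hand-rolled sketch of random oversampling with NumPy. The rows and sizes are hypothetical, and this is only meant to show the idea, not to mirror the library's internals.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical tiny training set: 3 minority rows, to be matched against 9 majority rows
X_minority = np.array([[1.0, 0.0], [0.9, 0.1], [0.8, 0.3]])
n_majority = 9

# Draw row positions WITH replacement until the minority class matches the majority
extra = rng.integers(0, len(X_minority), size=n_majority - len(X_minority))
X_oversampled = np.vstack([X_minority, X_minority[extra]])

print("Rows after oversampling:", len(X_oversampled))
print("Distinct rows:", len(np.unique(X_oversampled, axis=0)))
```

We end up with as many minority rows as majority rows, but the number of distinct rows hasn't changed: the classifier sees the same examples more often, not new ones.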
There are also other oversampling techniques that involve creating synthetic data, new datapoints that aren't copies of our data, but rather totally new ones. You can read more about them on the imbalanced-learn page.
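As a rough illustration of what "totally new" datapoints can mean, the sketch below builds synthetic rows by picking a minority-class row, finding its nearest minority-class neighbor, and placing a new point somewhere on the line between them. This is a simplified take on the interpolation idea behind SMOTE-style methods; the data and the `synthetic_point` helper are hypothetical, and it is not the imbalanced-learn implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical minority-class feature rows
X_minority = np.array([[1.0, 0.0], [0.8, 0.2], [0.6, 0.6], [0.9, 0.4]])

def synthetic_point(X, rng):
    """Return a new point on the segment between a random row and its nearest neighbor."""
    i = rng.integers(len(X))
    distances = np.linalg.norm(X - X[i], axis=1)
    distances[i] = np.inf            # ignore the point itself
    j = np.argmin(distances)         # nearest other minority row
    t = rng.random()                 # how far to move from row i toward row j
    return X[i] + t * (X[j] - X[i])

new_rows = np.array([synthetic_point(X_minority, rng) for _ in range(5)])
print(new_rows)
```

Each new row sits between two real ones, so it's plausible-looking but was never actually observed, which is worth keeping in mind for the discussion questions below.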
Review#
In this section we talked about the problem of imbalanced classes, where an uneven split in your labels can cause suboptimal classifier performance. We used the imbalanced-learn library to try two methods of solving the issue - under sampling and oversampling - which both improved our ability to predict Indian recipes as compared to the imbalanced dataset.
Discussion topics#
What is the difference between oversampling and under sampling? Why might oversampling have done a better job predicting non-Indian recipes?
Why did we only resample the training data, and not the test data?
While the idea of automatically-generated fake data might sound more attractive than just re-using existing data, what might be some issues with it?
Can we think of any times when we might not want a balanced dataset?