{
    "cells": [
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Classification problems with imbalanced inputs\n",
                "\n",
                "Oftentimes when we're doing real-world classification problems, we have the problem of **\"imbalanced classes\"**.\n",
                "\n",
                "Let's say we're analyzing a document dump, and trying to find the documents that are interesting to us. Maybe we're only interested in 10% of them! The fact that there's such a bias - 90% of them are uninteresting - **will mess with our classifier.** Let's take a look at [imbalanced-learn](https://imbalanced-learn.readthedocs.io/en/stable/) library to help fix this problem!"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### Prep work: Downloading necessary files\n",
                "Before we get started, we need to download all of the data we'll be using.\n",
                "* **recipes-indian.csv:** Indian classification recipes - a selection of recipe ingredient lists, with half of them being labeled as Indian cuisine\n",
                "* **recipes.csv:** recipes - a selection of recipe ingredient lists, with each labeled with the cuisine its from\n"
            ]
        },
        {
            "cell_type": "code",
            "metadata": {},
            "source": [
                "# Make data directory if it doesn't exist\n",
                "!mkdir -p data\n",
                "!wget -nc https://nyc3.digitaloceanspaces.com/ml-files-distro/v1/classification/data/recipes-indian.csv -P data\n",
                "!wget -nc https://nyc3.digitaloceanspaces.com/ml-files-distro/v1/classification/data/recipes.csv -P data"
            ],
            "outputs": [],
            "execution_count": null
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "We're going to go through this pretty quickly, so you should be familiar with vectorizing, classification, and confusion matrices going in."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 9,
            "metadata": {},
            "outputs": [],
            "source": [
                "import pandas as pd\n",
                "\n",
                "from sklearn.feature_extraction.text import TfidfVectorizer\n",
                "from sklearn.svm import LinearSVC\n",
                "from sklearn.model_selection import train_test_split\n",
                "from sklearn.metrics import confusion_matrix"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Our datasets\n",
                "\n",
                "We're going to be looking at two datasets today. They're both **recipes and ingredient lists**, and with both we're predicting whether we can **accurate determine which recipes are Indian**.\n",
                "\n",
                "Let's read them both in."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 10,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/html": [
                            "<div>\n",
                            "<style scoped>\n",
                            "    .dataframe tbody tr th:only-of-type {\n",
                            "        vertical-align: middle;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe tbody tr th {\n",
                            "        vertical-align: top;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe thead th {\n",
                            "        text-align: right;\n",
                            "    }\n",
                            "</style>\n",
                            "<table border=\"1\" class=\"dataframe\">\n",
                            "  <thead>\n",
                            "    <tr style=\"text-align: right;\">\n",
                            "      <th></th>\n",
                            "      <th>cuisine</th>\n",
                            "      <th>id</th>\n",
                            "      <th>ingredient_list</th>\n",
                            "      <th>is_indian</th>\n",
                            "    </tr>\n",
                            "  </thead>\n",
                            "  <tbody>\n",
                            "    <tr>\n",
                            "      <th>0</th>\n",
                            "      <td>indian</td>\n",
                            "      <td>23348</td>\n",
                            "      <td>minced ginger, garlic, oil, coriander powder, ...</td>\n",
                            "      <td>1</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>1</th>\n",
                            "      <td>indian</td>\n",
                            "      <td>18869</td>\n",
                            "      <td>chicken, chicken breasts</td>\n",
                            "      <td>1</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>2</th>\n",
                            "      <td>indian</td>\n",
                            "      <td>36405</td>\n",
                            "      <td>flour, rose essence, frying oil, powdered milk...</td>\n",
                            "      <td>1</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>3</th>\n",
                            "      <td>indian</td>\n",
                            "      <td>11494</td>\n",
                            "      <td>soda, ghee, sugar, khoa, maida flour, milk, oil</td>\n",
                            "      <td>1</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>4</th>\n",
                            "      <td>indian</td>\n",
                            "      <td>32675</td>\n",
                            "      <td>tumeric, garam masala, salt, chicken, curry le...</td>\n",
                            "      <td>1</td>\n",
                            "    </tr>\n",
                            "  </tbody>\n",
                            "</table>\n",
                            "</div>"
                        ],
                        "text/plain": [
                            "  cuisine     id                                    ingredient_list  is_indian\n",
                            "0  indian  23348  minced ginger, garlic, oil, coriander powder, ...          1\n",
                            "1  indian  18869                           chicken, chicken breasts          1\n",
                            "2  indian  36405  flour, rose essence, frying oil, powdered milk...          1\n",
                            "3  indian  11494    soda, ghee, sugar, khoa, maida flour, milk, oil          1\n",
                            "4  indian  32675  tumeric, garam masala, salt, chicken, curry le...          1"
                        ]
                    },
                    "execution_count": 10,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "df_balanced = pd.read_csv(\"data/recipes-indian.csv\")\n",
                "df_balanced['is_indian'] = (df_balanced.cuisine == \"indian\").astype(int)\n",
                "\n",
                "df_balanced.head()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 11,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/html": [
                            "<div>\n",
                            "<style scoped>\n",
                            "    .dataframe tbody tr th:only-of-type {\n",
                            "        vertical-align: middle;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe tbody tr th {\n",
                            "        vertical-align: top;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe thead th {\n",
                            "        text-align: right;\n",
                            "    }\n",
                            "</style>\n",
                            "<table border=\"1\" class=\"dataframe\">\n",
                            "  <thead>\n",
                            "    <tr style=\"text-align: right;\">\n",
                            "      <th></th>\n",
                            "      <th>cuisine</th>\n",
                            "      <th>id</th>\n",
                            "      <th>ingredient_list</th>\n",
                            "      <th>is_indian</th>\n",
                            "    </tr>\n",
                            "  </thead>\n",
                            "  <tbody>\n",
                            "    <tr>\n",
                            "      <th>0</th>\n",
                            "      <td>greek</td>\n",
                            "      <td>10259</td>\n",
                            "      <td>romaine lettuce, black olives, grape tomatoes,...</td>\n",
                            "      <td>0</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>1</th>\n",
                            "      <td>southern_us</td>\n",
                            "      <td>25693</td>\n",
                            "      <td>plain flour, ground pepper, salt, tomatoes, gr...</td>\n",
                            "      <td>0</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>2</th>\n",
                            "      <td>filipino</td>\n",
                            "      <td>20130</td>\n",
                            "      <td>eggs, pepper, salt, mayonaise, cooking oil, gr...</td>\n",
                            "      <td>0</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>3</th>\n",
                            "      <td>indian</td>\n",
                            "      <td>22213</td>\n",
                            "      <td>water, vegetable oil, wheat, salt</td>\n",
                            "      <td>1</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>4</th>\n",
                            "      <td>indian</td>\n",
                            "      <td>13162</td>\n",
                            "      <td>black pepper, shallots, cornflour, cayenne pep...</td>\n",
                            "      <td>1</td>\n",
                            "    </tr>\n",
                            "  </tbody>\n",
                            "</table>\n",
                            "</div>"
                        ],
                        "text/plain": [
                            "       cuisine     id                                    ingredient_list  \\\n",
                            "0        greek  10259  romaine lettuce, black olives, grape tomatoes,...   \n",
                            "1  southern_us  25693  plain flour, ground pepper, salt, tomatoes, gr...   \n",
                            "2     filipino  20130  eggs, pepper, salt, mayonaise, cooking oil, gr...   \n",
                            "3       indian  22213                  water, vegetable oil, wheat, salt   \n",
                            "4       indian  13162  black pepper, shallots, cornflour, cayenne pep...   \n",
                            "\n",
                            "   is_indian  \n",
                            "0          0  \n",
                            "1          0  \n",
                            "2          0  \n",
                            "3          1  \n",
                            "4          1  "
                        ]
                    },
                    "execution_count": 11,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "df_unbalanced = pd.read_csv(\"data/recipes.csv\")\n",
                "df_unbalanced['is_indian'] = (df_unbalanced.cuisine == \"indian\").astype(int)\n",
                "\n",
                "df_unbalanced.head()"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "They both look similar enough, right? A list of ingredients and an `is_indian` target column we'll be using as our label.\n",
                "\n",
                "### Finding the imbalance\n",
                "\n",
                "The real difference is how many of the recipes are Indian in each dataset. Let's take a look:"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 12,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "1    3000\n",
                            "0    3000\n",
                            "Name: is_indian, dtype: int64"
                        ]
                    },
                    "execution_count": 12,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "df_balanced.is_indian.value_counts()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 13,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "0    36771\n",
                            "1     3003\n",
                            "Name: is_indian, dtype: int64"
                        ]
                    },
                    "execution_count": 13,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "df_unbalanced.is_indian.value_counts()"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Ouch! That second dataset is really uneven - over ten times as many non-Indian recipes as there are Indian recipes!\n",
                "\n",
                "The thing is: **this is usually how data looks in the real world.** You rarely have even numbers between your classes, and you often thing \"more data is better data.\" We'll see how it plays out when we actually run our classifiers!\n",
                "\n",
                "## Testing our datasets\n",
                "\n",
                "We're going to use a `TfidfVectorizer` to convert ingredient lists to numbers, run a test/train split, and then train (and test) a `LinearSVC` classifier on the results. We'll start with the **balanced dataset**.\n",
                "\n",
                "### Balanced dataset"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 14,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "1    3000\n",
                            "0    3000\n",
                            "Name: is_indian, dtype: int64"
                        ]
                    },
                    "execution_count": 14,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "# Create a vectorizer and train it\n",
                "vectorizer = TfidfVectorizer()\n",
                "matrix = vectorizer.fit_transform(df_balanced.ingredient_list)\n",
                "\n",
                "# Features are our matrix of tf-idf values\n",
                "# labels are whether each recipe is Indian or not\n",
                "X = matrix\n",
                "y = df_balanced.is_indian\n",
                "\n",
                "# How many are Indian?\n",
                "y.value_counts()"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "We still have an even split, 3000 non-Indian recipes and 3000 Indian recipes. Let's run a test/train split and see how the results look."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 15,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/html": [
                            "<div>\n",
                            "<style scoped>\n",
                            "    .dataframe tbody tr th:only-of-type {\n",
                            "        vertical-align: middle;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe tbody tr th {\n",
                            "        vertical-align: top;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe thead th {\n",
                            "        text-align: right;\n",
                            "    }\n",
                            "</style>\n",
                            "<table border=\"1\" class=\"dataframe\">\n",
                            "  <thead>\n",
                            "    <tr style=\"text-align: right;\">\n",
                            "      <th></th>\n",
                            "      <th>Predicted not indian</th>\n",
                            "      <th>Predicted indian</th>\n",
                            "    </tr>\n",
                            "  </thead>\n",
                            "  <tbody>\n",
                            "    <tr>\n",
                            "      <th>Is not indian</th>\n",
                            "      <td>0.962815</td>\n",
                            "      <td>0.037185</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>Is indian</th>\n",
                            "      <td>0.048193</td>\n",
                            "      <td>0.951807</td>\n",
                            "    </tr>\n",
                            "  </tbody>\n",
                            "</table>\n",
                            "</div>"
                        ],
                        "text/plain": [
                            "               Predicted not indian  Predicted indian\n",
                            "Is not indian              0.962815          0.037185\n",
                            "Is indian                  0.048193          0.951807"
                        ]
                    },
                    "execution_count": 15,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "# Split into test and train data\n",
                "X_train, X_test, y_train, y_test = train_test_split(X, y)\n",
                "\n",
                "# Build a classifier and train it\n",
                "clf = LinearSVC()\n",
                "clf.fit(X_train, y_train)\n",
                "\n",
                "# Test our classifier and build a confusion matrix\n",
                "y_true = y_test\n",
                "y_pred = clf.predict(X_test)\n",
                "matrix = confusion_matrix(y_true, y_pred)\n",
                "\n",
                "label_names = pd.Series(['not indian', 'indian'])\n",
                "pd.DataFrame(matrix,\n",
                "     columns='Predicted ' + label_names,\n",
                "     index='Is ' + label_names).div(matrix.sum(axis=1), axis=0)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "**Our classifier looks pretty good!** Around 96% accuracy for predicting non-Indian food, and around 95% correctly predicting Indian food. High quality *and* even.\n",
                "\n",
                "Let's move on to see how it looks with our **unabalanced dataset**.\n",
                "\n",
                "### Unbalanced dataset "
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 16,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "0    36771\n",
                            "1     3003\n",
                            "Name: is_indian, dtype: int64"
                        ]
                    },
                    "execution_count": 16,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "# Create a vectorizer and train it\n",
                "vectorizer = TfidfVectorizer()\n",
                "matrix = vectorizer.fit_transform(df_unbalanced.ingredient_list)\n",
                "\n",
                "# Features are our matrix of tf-idf values\n",
                "# labels are whether each recipe is Indian or not\n",
                "X = matrix\n",
                "y = df_unbalanced.is_indian\n",
                "\n",
                "# How many are Indian?\n",
                "y.value_counts()"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Again: around 36k non-Indian recipes really really outweighing the 3,003 Indian recipes. While we love the world of \"more more more data,\" let's see what that imbalance does to our classifier."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 17,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/html": [
                            "<div>\n",
                            "<style scoped>\n",
                            "    .dataframe tbody tr th:only-of-type {\n",
                            "        vertical-align: middle;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe tbody tr th {\n",
                            "        vertical-align: top;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe thead th {\n",
                            "        text-align: right;\n",
                            "    }\n",
                            "</style>\n",
                            "<table border=\"1\" class=\"dataframe\">\n",
                            "  <thead>\n",
                            "    <tr style=\"text-align: right;\">\n",
                            "      <th></th>\n",
                            "      <th>Predicted not indian</th>\n",
                            "      <th>Predicted indian</th>\n",
                            "    </tr>\n",
                            "  </thead>\n",
                            "  <tbody>\n",
                            "    <tr>\n",
                            "      <th>Is not indian</th>\n",
                            "      <td>0.992150</td>\n",
                            "      <td>0.007850</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>Is indian</th>\n",
                            "      <td>0.180052</td>\n",
                            "      <td>0.819948</td>\n",
                            "    </tr>\n",
                            "  </tbody>\n",
                            "</table>\n",
                            "</div>"
                        ],
                        "text/plain": [
                            "               Predicted not indian  Predicted indian\n",
                            "Is not indian              0.992150          0.007850\n",
                            "Is indian                  0.180052          0.819948"
                        ]
                    },
                    "execution_count": 17,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "# Split our dataset is train and test data\n",
                "X_train, X_test, y_train, y_test = train_test_split(X, y)\n",
                "\n",
                "# Train the classifier on the training data\n",
                "clf = LinearSVC()\n",
                "clf.fit(X_train, y_train)\n",
                "\n",
                "# Test our classifier and build a confusion matrix\n",
                "y_true = y_test\n",
                "y_pred = clf.predict(X_test)\n",
                "matrix = confusion_matrix(y_true, y_pred)\n",
                "\n",
                "label_names = pd.Series(['not indian', 'indian'])\n",
                "pd.DataFrame(matrix,\n",
                "     columns='Predicted ' + label_names,\n",
                "     index='Is ' + label_names).div(matrix.sum(axis=1), axis=0)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Ouch!!! While we're doing **really well** at predicting non-Indian dishes, our ability to predict Indian dishes has plummeted to just over 80%.\n",
                "\n",
                "Why does this happen? An easy way to think about it is **when it's a risky decision, it's always safest to guess \"not Indian.\"** In fact, if we *always guessed non-Indian*, no matter what, we'd be right..."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 18,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "0.9244984160506864"
                        ]
                    },
                    "execution_count": 18,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "36771/(36771+3003)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "About 92% of the time! So how do we solve this problem?\n",
                "\n",
                "## Solving the problem\n",
                "\n",
                "Solving the problem of unbalanced (or biased) input classes is actually not too hard! There's a nice library that can give us a hand, [imbalanced-learn](https://imbalanced-learn.readthedocs.io/en/stable/).\n",
                "\n",
                "imbalanced-learn will **resample** our dataset, either generating new datapoints or pruning out existing datapoints, until the classes are evened out.\n",
                "\n",
                "### What do we resample?\n",
                "\n",
                "An important thing to note is that **the problem with bias happens when we train our model.** If we show our model a skewed view of the world, it'll carry that bias when making judgments in the future. When we add or remove datapoints to even out the problem, **we only need to do this for the training data.**\n",
                "\n",
                "We want to show the model an even view of the world, so we give it even data. The test data should still reflect the \"real\" world. Before we were looking at how imblanaced our overall dataset was, but now let's **just look at how biased the training data is.**"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 25,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "0    27599\n",
                            "1     2231\n",
                            "Name: is_indian, dtype: int64"
                        ]
                    },
                    "execution_count": 25,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "y_train.value_counts()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 26,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "0    0.92521\n",
                            "1    0.07479\n",
                            "Name: is_indian, dtype: float64"
                        ]
                    },
                    "execution_count": 26,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "y_train.value_counts(normalize=True)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Looks like a little over 7% of our training data is Indian - we'd like to get that up to 50%, so let's see what the imbalanced-learn library can do for us!"
            ]
        },
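        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Before we do, one quick double-check: the test data should keep its real-world imbalance, since we're only going to even out the training set."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# The test set stays untouched, so it should still be heavily skewed toward non-Indian\n",
                "y_test.value_counts(normalize=True)"
            ]
        },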
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### Undersampling\n",
                "\n",
                "If we're feeling guilty that there are so many additional non-Indian recipes, *we could always get rid of those extra non-Indian recipes!* In fact, the balanced dataset was me manually creating a new CSV from an even split of Indian/non-Indian recipes..\n",
                "\n",
                "Instead of manually digging through our dataset to even things out, though, we can rely on imbalanced-learn to do it automatically. We'll use the technique of **undersampling** to take those ~28k non-Indian recipes and randomly filter them down to around 2,000 to match the number of Indian recipes. (Remember we're only doing this with training data!)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 27,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "1    2231\n",
                            "0    2231\n",
                            "Name: is_indian, dtype: int64"
                        ]
                    },
                    "execution_count": 27,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "from imblearn.under_sampling import RandomUnderSampler\n",
                "\n",
                "resampler = RandomUnderSampler()\n",
                "# Resample X and y so there are equal numbers of each y\n",
                "X_train_resampled, y_train_resampled = resampler.fit_resample(X_train, y_train)\n",
                "\n",
                "y_train_resampled.value_counts()"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Okay, cool, equal numbers! Let's see how the classifier performs."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 28,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/html": [
                            "<div>\n",
                            "<style scoped>\n",
                            "    .dataframe tbody tr th:only-of-type {\n",
                            "        vertical-align: middle;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe tbody tr th {\n",
                            "        vertical-align: top;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe thead th {\n",
                            "        text-align: right;\n",
                            "    }\n",
                            "</style>\n",
                            "<table border=\"1\" class=\"dataframe\">\n",
                            "  <thead>\n",
                            "    <tr style=\"text-align: right;\">\n",
                            "      <th></th>\n",
                            "      <th>Predicted not indian</th>\n",
                            "      <th>Predicted indian</th>\n",
                            "    </tr>\n",
                            "  </thead>\n",
                            "  <tbody>\n",
                            "    <tr>\n",
                            "      <th>Is not indian</th>\n",
                            "      <td>0.957479</td>\n",
                            "      <td>0.042521</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>Is indian</th>\n",
                            "      <td>0.051813</td>\n",
                            "      <td>0.948187</td>\n",
                            "    </tr>\n",
                            "  </tbody>\n",
                            "</table>\n",
                            "</div>"
                        ],
                        "text/plain": [
                            "               Predicted not indian  Predicted indian\n",
                            "Is not indian              0.957479          0.042521\n",
                            "Is indian                  0.051813          0.948187"
                        ]
                    },
                    "execution_count": 28,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "# We already split our data, so we don't need to do that again\n",
                "\n",
                "# Train the classifier on the resampled training data\n",
                "clf = LinearSVC()\n",
                "clf.fit(X_train_resampled, y_train_resampled)\n",
                "\n",
                "# Build a confusion matrix\n",
                "y_true = y_test\n",
                "y_pred = clf.predict(X_test)\n",
                "matrix = confusion_matrix(y_true, y_pred)\n",
                "\n",
                "label_names = pd.Series(['not indian', 'indian'])\n",
                "pd.DataFrame(matrix,\n",
                "     columns='Predicted ' + label_names,\n",
                "     index='Is ' + label_names).div(matrix.sum(axis=1), axis=0)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Looking good! It performs as well as our other 3,000/3,000 split because, well, it's more or less the same thing (although the test data is \"realistically\" unbalanced)."
            ]
        },
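        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "One cost worth keeping in mind: undersampling throws away a big chunk of our training data. Here's a quick way to count exactly how many rows we discarded, using the variables we already have:"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# How many training rows did the undersampler throw away?\n",
                "X_train.shape[0] - X_train_resampled.shape[0]"
            ]
        },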
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### Oversampling\n",
                "\n",
                "Cutting out those 27,000 \"extra\" non-Indian recipes seems like such a bummer, though. Wouldn't it be nice if we somehow found another 25,000 Indian recipes to even up our unbalanced training dataset to 27k non-Indian and 27k Indian? It's possible with **oversampling!**\n",
                "\n",
                "Oversampling generates **new datapoints** based on your existing dataset. In this case we're going to use the `RandomOverSampler`, which just fills our dataset with **copies of the less-included class**. We'll have 27k Indian recipes, *but they'll be 25,0000 copies of the original ones*. Can that possibly help?"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 32,
            "metadata": {},
            "outputs": [],
            "source": [
                "from imblearn.over_sampling import RandomOverSampler\n",
                "\n",
                "resampler = RandomOverSampler()\n",
                "X_train_resampled, y_train_resampled = resampler.fit_resample(X_train, y_train)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 33,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "1    27599\n",
                            "0    27599\n",
                            "Name: is_indian, dtype: int64"
                        ]
                    },
                    "execution_count": 33,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "y_train_resampled.value_counts()"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Looking good, a nice even 27,599 apiece. Let's see how the classifier works out!"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 34,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/html": [
                            "<div>\n",
                            "<style scoped>\n",
                            "    .dataframe tbody tr th:only-of-type {\n",
                            "        vertical-align: middle;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe tbody tr th {\n",
                            "        vertical-align: top;\n",
                            "    }\n",
                            "\n",
                            "    .dataframe thead th {\n",
                            "        text-align: right;\n",
                            "    }\n",
                            "</style>\n",
                            "<table border=\"1\" class=\"dataframe\">\n",
                            "  <thead>\n",
                            "    <tr style=\"text-align: right;\">\n",
                            "      <th></th>\n",
                            "      <th>Predicted not indian</th>\n",
                            "      <th>Predicted indian</th>\n",
                            "    </tr>\n",
                            "  </thead>\n",
                            "  <tbody>\n",
                            "    <tr>\n",
                            "      <th>Is not indian</th>\n",
                            "      <td>0.969363</td>\n",
                            "      <td>0.030637</td>\n",
                            "    </tr>\n",
                            "    <tr>\n",
                            "      <th>Is indian</th>\n",
                            "      <td>0.068653</td>\n",
                            "      <td>0.931347</td>\n",
                            "    </tr>\n",
                            "  </tbody>\n",
                            "</table>\n",
                            "</div>"
                        ],
                        "text/plain": [
                            "               Predicted not indian  Predicted indian\n",
                            "Is not indian              0.969363          0.030637\n",
                            "Is indian                  0.068653          0.931347"
                        ]
                    },
                    "execution_count": 34,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "# We already split our dataset into train and test data\n",
                "\n",
                "# Train the classifier on the resampled training data\n",
                "clf = LinearSVC()\n",
                "clf.fit(X_train_resampled, y_train_resampled)\n",
                "\n",
                "# Build a confusion matrix with the result\n",
                "y_true = y_test\n",
                "y_pred = clf.predict(X_test)\n",
                "matrix = confusion_matrix(y_true, y_pred)\n",
                "\n",
                "label_names = pd.Series(['not indian', 'indian'])\n",
                "pd.DataFrame(matrix,\n",
                "     columns='Predicted ' + label_names,\n",
                "     index='Is ' + label_names).div(matrix.sum(axis=1), axis=0)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Also looking pretty good! A little bit better at predicting non-Indian dishes and a little bit worse at predicting Indian dishes, but it more or less evens out with the undersampled example. \n",
                "\n",
                "There are also other oversampling techniques that involve **creating synthetic data,** new datapoints that aren't *copies* of our data, but rather totally new ones. You can read more about them [on the imbalanced-learn page](https://imbalanced-learn.readthedocs.io/en/stable/over_sampling.html).\n",
                "\n",
                "## Review\n",
                "\n",
                "In this section we talked about the problem of **imbalanced classes**, where an uneven split in your labels can cause suboptimal classifier performance. We used the imbalanced-learn library to talk about two methods of solving the issue - undersampling and oversampling - which both boosted performance as compared to the imbalanced dataset.\n",
                "\n",
                "## Discussion topics\n",
                "\n",
                "What is the difference between oversampling and undersampling? Why might have oversampling done a better job predicting non-Indian recipes?\n",
                "\n",
                "Why did we only resample the training data, and not the test data?\n",
                "\n",
                "While the idea of automatically-generated fake data might sound more attractive than just re-using existing data, [what might be some issues with it](https://imbalanced-learn.readthedocs.io/en/stable/over_sampling.html)?\n",
                "\n",
                "Can we think of any times when we might *not* want a balanced dataset?"
            ]
        },
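        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Bonus: a quick look at SMOTE\n",
                "\n",
                "Here's the quick sketch of **SMOTE** promised above. Treat it as a starting point rather than a definitive recipe: it's the same split-then-resample-only-the-training-data pattern we used earlier, just with imbalanced-learn's `SMOTE` sampler, which builds brand-new minority-class datapoints by interpolating between existing ones instead of copying them."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from imblearn.over_sampling import SMOTE\n",
                "\n",
                "# SMOTE invents new Indian recipes by interpolating between existing ones,\n",
                "# instead of duplicating them like RandomOverSampler does\n",
                "resampler = SMOTE()\n",
                "X_train_resampled, y_train_resampled = resampler.fit_resample(X_train, y_train)\n",
                "\n",
                "# Train on the evened-out training data, test on the untouched test data\n",
                "clf = LinearSVC()\n",
                "clf.fit(X_train_resampled, y_train_resampled)\n",
                "\n",
                "# Build a confusion matrix with the result\n",
                "y_true = y_test\n",
                "y_pred = clf.predict(X_test)\n",
                "matrix = confusion_matrix(y_true, y_pred)\n",
                "\n",
                "label_names = pd.Series(['not indian', 'indian'])\n",
                "pd.DataFrame(matrix,\n",
                "     columns='Predicted ' + label_names,\n",
                "     index='Is ' + label_names).div(matrix.sum(axis=1), axis=0)"
            ]
        },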
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": []
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "Python 3",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.6.8"
        },
        "toc": {
            "base_numbering": 1,
            "nav_menu": {},
            "number_sections": true,
            "sideBar": true,
            "skip_h1_title": false,
            "title_cell": "Table of Contents",
            "title_sidebar": "Contents",
            "toc_cell": false,
            "toc_position": {},
            "toc_section_display": true,
            "toc_window_display": false
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}