What are word embeddings?

Our relationship is troubled! We like words, but computers like math. Word embeddings are a way of bridging that gap (and saving our love!).

The problem

You know how when we look at a crazy math formula, maybe our brain explodes a little?

\begin{equation*} \left( \sum_{k=1}^n a_k b_k \right)^2 \leq \left( \sum_{k=1}^n a_k^2 \right) \left( \sum_{k=1}^n b_k^2 \right) \end{equation*}

Yeah, that's exactly how computers feel when you use words. In the same way we might say "that weird angry capital E thing" to refer to Σ, a computer looks at the word "cat" and is like "uh, 0x63 0x61 0x74?"

While software might be able to understand that "cat" is three letters long, that it's a c and an a and a t, and even look up the definition in a dictionary for us, the computer doesn't really emotionally know what cats are. It can't feel what a cat is, know about its fur or how it meows, or know how it sleeps in the sun, tears apart our furniture, or cruelly makes us take it to the vet on Christmas Day.

Word embeddings are a way of bridging that gap, a way of using math to describe all of those delightful/horrible things about cats (and everything else).

An axis of meaning

Let's say we have the concept of a cat. Everything we know about a cat, thrown down on the screen, all of it sitting inside a little pink dot. We'll make it look computational so the computer doesn't get scared yet.

So far so good! Cats don't exist in the world by themselves, though; they exist in relation to other things. Like dogs, for example. Dogs are different from cats, so they should go... somewhere on the other side from cats, I guess?

Cool, great, amazing, wonderful.

It makes as much sense as something meaningless can, but it doesn't seem very much like math. Let's add an axis label to explain to the computer what's changing between "dog" on the left and "cat" on the right.

(Figure: dog on the left and cat on the right of a horizontal axis, with the left side labeled 'Less catlike'.)

If we count those little lines as points, we can see that cat is four points more catlike than dog. That's math! We can even put it into a pandas dataframe:

import pandas as pd

pd.DataFrame([
    { 'name': 'cat', 'cat_points': 4 },
    { 'name': 'dog', 'cat_points': 0 }
])
  name  cat_points
0  cat           4
1  dog           0
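
That "four points more catlike" is just a subtraction. Here is a tiny sketch that rebuilds the same two rows and indexes them by name so each animal can be looked up directly. The variable names `animals` and `difference` are our own choices, not part of the original notebook.

```python
import pandas as pd

# The same two animals as above, looked up by name instead of row number
animals = pd.DataFrame([
    {'name': 'cat', 'cat_points': 4},
    {'name': 'dog', 'cat_points': 0},
]).set_index('name')

# How many more cat points does a cat have than a dog?
difference = animals.loc['cat', 'cat_points'] - animals.loc['dog', 'cat_points']
print(difference)  # 4
```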

There are more animals than just cats and dogs, though, so let's add 'em! How about... a lion?

Lions are pretty catlike, but they're bigger and stronger and more powerful than most of the housecats that live with me (no offense). So we can give them slightly fewer cat points than cats, but nowhere near as few as dogs.

(Figure: lion added to the axis just left of cat; the left side is still labeled 'Less catlike'.)

And again, because computers love spreadsheets and counting, we can make another dataframe.

pd.DataFrame([
    { 'name': 'cat', 'cat_points': 4 },
    { 'name': 'dog', 'cat_points': 0 },
    { 'name': 'lion', 'cat_points': 3.5 }
])
   name  cat_points
0   cat         4.0
1   dog         0.0
2  lion         3.5
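
With a single axis, "how close is lion to everything else?" comes down to the absolute difference in cat points. The sketch below checks that. It uses the same hypothetical `animals` table indexed by name, plus a helper Series `distance_from_lion` that is also our own naming.

```python
import pandas as pd

animals = pd.DataFrame([
    {'name': 'cat', 'cat_points': 4},
    {'name': 'dog', 'cat_points': 0},
    {'name': 'lion', 'cat_points': 3.5},
]).set_index('name')

# Distance along one axis = absolute difference in cat points
distance_from_lion = (animals['cat_points'] - animals.loc['lion', 'cat_points']).abs()

# Skip lion's distance to itself, then pick the smallest remaining one
closest = distance_from_lion.drop('lion').idxmin()
print(closest)  # cat
```

Lion ends up 0.5 points from cat and 3.5 points from dog, which is exactly the "a little fewer cat points than cats" placement described above.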

I've heard rumors of even more animals, so let's keep going. How about wolves?

A wolf is definitely much closer to a dog than to a cat. Since a wolf is more intimidating than a dog, I think it sits even further from cat than dog does.

(Figure: wolf added to the axis just left of dog; the left side is labeled 'Less catlike'.)

And just so the computer won't feel left out, we can put it into a dataframe to make it nice and math-y.

pd.DataFrame([
    { 'name': 'cat', 'cat_points': 4 },
    { 'name': 'dog', 'cat_points': 0 },
    { 'name': 'lion', 'cat_points': 3.5 },
    { 'name': 'wolf', 'cat_points': -0.5 }
])
   name  cat_points
0   cat         4.0
1   dog         0.0
2  lion         3.5
3  wolf        -0.5
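
With only one number per animal, everything lives on a single line, and sorting by `cat_points` shows the order the computer now "believes" in. This is a small sketch; `animals` and `ordered` are illustrative names.

```python
import pandas as pd

animals = pd.DataFrame([
    {'name': 'cat', 'cat_points': 4},
    {'name': 'dog', 'cat_points': 0},
    {'name': 'lion', 'cat_points': 3.5},
    {'name': 'wolf', 'cat_points': -0.5},
])

# Line the animals up from least to most catlike
ordered = animals.sort_values('cat_points')
print(ordered['name'].tolist())  # ['wolf', 'dog', 'lion', 'cat']
```

Notice that dog lands right between wolf and lion.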

Another dimension

We've all been to the zoo, we're all animal scientists, we've all watched Beastars, and we're all very, very angry at this classification. Why are wolves and lions separated by dogs? How does that make any sense?

Sure, lions and cats are both felines, and wolves and dogs are both canines, but let's think about it:

  • Cats: totally domesticated
  • Dogs: totally domesticated
  • Wolves: totally wild
  • Lions: totally wild

If we're teaching our computer with just "hey, this is like a cat" or "hey, this is less like a cat," it isn't going to learn anything important. This is the nuance of our human experience of the world that computers are missing out on!

It's this nuance we're going to teach right now by giving our graph a brand new axis: wild or domesticated.

(Figure: cat, dog, lion, and wolf plotted on two axes, cat points and wildness, with the wildness axis labeled 'More wild'.)

Look at that beauty!!! It's explaining everything I could ever want. And just so we don't leave out the computer:

pd.DataFrame([
    { 'name': 'cat', 'cat_points': 4, 'wildness': 0.5 },
    { 'name': 'dog', 'cat_points': 0, 'wildness': 0 },
    { 'name': 'lion', 'cat_points': 3.5, 'wildness': 4 },
    { 'name': 'wolf', 'cat_points': -0.5, 'wildness': 4 }
])
   name  cat_points  wildness
0   cat         4.0       0.5
1   dog         0.0       0.0
2  lion         3.5       4.0
3  wolf        -0.5       4.0
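
With two numbers per animal, "how close are these two?" can be measured as a straight-line (Euclidean) distance between their rows. Euclidean distance is our choice for this illustration, and `distance` is a hypothetical helper, not something from the original notebook.

```python
import numpy as np
import pandas as pd

animals = pd.DataFrame([
    {'name': 'cat', 'cat_points': 4, 'wildness': 0.5},
    {'name': 'dog', 'cat_points': 0, 'wildness': 0},
    {'name': 'lion', 'cat_points': 3.5, 'wildness': 4},
    {'name': 'wolf', 'cat_points': -0.5, 'wildness': 4},
]).set_index('name')

def distance(a, b):
    """Straight-line (Euclidean) distance between two animals' rows."""
    return float(np.linalg.norm(animals.loc[a] - animals.loc[b]))

print(round(distance('lion', 'wolf'), 2))  # 4.0
print(round(distance('lion', 'dog'), 2))   # 5.32
```

On the cat axis alone, lion was 3.5 points from dog and 4 from wolf. Once wildness is added, wolf becomes the closer of the two.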

This is an excellent graph, and it's an excellent (if not perfect) way to describe all sorts of animals! We can describe a few just for fun:

  • Tigers (basically where lions are)
  • Killer whales (not catlike at all, pretty wild)
  • Worms (very very not catlike, a little wild)

We keep putting numbers in that chart, and the computer gets a better and better idea of which animals are similar to which other animals. Eventually it builds up a whole worldview of how catlike things are and how wild they are, and then it can probably analyze something very complicated about zoology!
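
"Which animal is this most like?" can be answered by finding the known row with the smallest distance to a new point. This is a sketch, not a definitive method. The coordinates for tiger and killer whale are our own guesses read off the descriptions above, and `most_similar` is a hypothetical helper.

```python
import numpy as np
import pandas as pd

animals = pd.DataFrame([
    {'name': 'cat', 'cat_points': 4, 'wildness': 0.5},
    {'name': 'dog', 'cat_points': 0, 'wildness': 0},
    {'name': 'lion', 'cat_points': 3.5, 'wildness': 4},
    {'name': 'wolf', 'cat_points': -0.5, 'wildness': 4},
]).set_index('name')

def most_similar(point):
    """Return the known animal whose row is closest (Euclidean) to `point`."""
    distances = np.linalg.norm(animals.to_numpy() - np.asarray(point), axis=1)
    return animals.index[distances.argmin()]

# Hypothetical (cat_points, wildness) positions, guessed from the list above
tiger = [3.5, 4]          # "basically where lions are"
killer_whale = [-4, 3.5]  # "not catlike at all, pretty wild"

print(most_similar(tiger))         # lion
print(most_similar(killer_whale))  # wolf
```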

There's a problem lurking around the corner, though, and it's this: our computer is interested in things that aren't animals.

A third dimension

We were feeling good for a hot second, but then we realized things other than animals existed. Like shoes, for example.

(Figure: the same two-axis plot of cat points and wildness with shoe added; the wildness axis is labeled 'More wild'.)

Shoes aren't like cats at all and are not very wild, but "not catlike, not wild" doesn't do a good job of describing them. It's like when we added wolves and lions and needed a new axis.

So what are we going to do? The exact same thing: add another piece of data to it! We'll call this axis something like "things you can wear."

df = pd.DataFrame([
    { 'name': 'cat', 'cat_points': 4, 'wildness': 0.5, 'wearability': 0.5 },
    { 'name': 'dog', 'cat_points': 0, 'wildness': 0, 'wearability': 0.25  },
    { 'name': 'lion', 'cat_points': 3.5, 'wildness': 4, 'wearability': -1  },
    { 'name': 'wolf', 'cat_points': -0.5, 'wildness': 4, 'wearability': -1  },
    { 'name': 'shoe', 'cat_points': -3.5, 'wildness': 0, 'wearability': 3  }
])
df
   name  cat_points  wildness  wearability
0   cat         4.0       0.5         0.50
1   dog         0.0       0.0         0.25
2  lion         3.5       4.0        -1.00
3  wolf        -0.5       4.0        -1.00
4  shoe        -3.5       0.0         3.00
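
Adding a third column doesn't change the distance math. Each row just carries one more coordinate to subtract. The sketch below ranks every animal by how far it sits from the shoe, again using Euclidean distance as an assumed measure. `things` and `ranking` are illustrative names.

```python
import numpy as np
import pandas as pd

things = pd.DataFrame([
    {'name': 'cat', 'cat_points': 4, 'wildness': 0.5, 'wearability': 0.5},
    {'name': 'dog', 'cat_points': 0, 'wildness': 0, 'wearability': 0.25},
    {'name': 'lion', 'cat_points': 3.5, 'wildness': 4, 'wearability': -1},
    {'name': 'wolf', 'cat_points': -0.5, 'wildness': 4, 'wearability': -1},
    {'name': 'shoe', 'cat_points': -3.5, 'wildness': 0, 'wearability': 3},
]).set_index('name')

# Subtract the shoe's row from every row, then take each row's length
offsets = things - things.loc['shoe']
distances = pd.Series(np.linalg.norm(offsets.to_numpy(), axis=1), index=things.index)

# Rank everything else by how close it is to the shoe
ranking = distances.drop('shoe').sort_values()
print(ranking.index.tolist())  # ['dog', 'wolf', 'cat', 'lion']
```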

And while two dimensions were all right, plotting this in three should really get our blood pumping! We're going to use a library called plotly to take care of this for us.

import plotly.graph_objects as go
import plotly.io as pio

# Render the figure inline inside a Jupyter notebook
pio.renderers.default = 'notebook'

# One pink marker per row, labeled with that row's name
fig = go.Figure(data=go.Scatter3d(
    x=df.cat_points,
    y=df.wildness,
    z=df.wearability,
    text=df.name,
    mode='markers+text',
    marker=dict(color='pink'),
))

# Give each of the three axes a title
fig.update_layout(scene=dict(
    xaxis_title='cat points',
    yaxis_title='wildness',
    zaxis_title='wearability',
))

pio.show(fig)
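
Strip away the plot and what the computer ends up with is a short list of numbers for each word. The sketch below turns the table into a plain word-to-numbers dictionary. The name `embeddings` and the choice of a dict are our own assumptions about how you might store it.

```python
import pandas as pd

things = pd.DataFrame([
    {'name': 'cat', 'cat_points': 4, 'wildness': 0.5, 'wearability': 0.5},
    {'name': 'dog', 'cat_points': 0, 'wildness': 0, 'wearability': 0.25},
    {'name': 'lion', 'cat_points': 3.5, 'wildness': 4, 'wearability': -1},
    {'name': 'wolf', 'cat_points': -0.5, 'wildness': 4, 'wearability': -1},
    {'name': 'shoe', 'cat_points': -3.5, 'wildness': 0, 'wearability': 3},
]).set_index('name')

# word -> [cat_points, wildness, wearability]
embeddings = {name: row.tolist() for name, row in things.iterrows()}
print(embeddings['cat'])  # [4.0, 0.5, 0.5]
```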