Topic models with Gensim#

Gensim is a popular library for topic modeling. Here we'll see how it stacks up to scikit-learn.

Gensim vs. Scikit-learn#

Gensim is a very, very popular piece of software for topic modeling (as is Mallet, if you're making a list). Since we're using scikit-learn for everything else, though, we stick with it when we get to topic modeling.

Since someone might show up one day offering us tens of thousands of dollars to demonstrate proficiency in Gensim, though, we might as well see how it works as compared to scikit-learn.

Our data#

We'll be using the same dataset as we did with scikit-learn: State of the Union addresses from 1790 to 2012, in which America's president addresses Congress about the coming year.

import pandas as pd

df = pd.read_csv("data/state-of-the-union.csv")

# Clean it up a little bit, removing non-word characters (numbers and ___ etc)
df.content = df.content.str.replace("[^A-Za-z ]", " ", regex=True)

df.head()
year content
0 1790 George Washington January Fellow Citi...
1 1790 State of the Union Address George Washington ...
2 1791 State of the Union Address George Washington ...
3 1792 State of the Union Address George Washington ...
4 1793 State of the Union Address George Washington ...

Using Gensim#

#!pip install --upgrade gensim
from gensim.utils import simple_preprocess

texts = df.content.apply(simple_preprocess)
from gensim import corpora

dictionary = corpora.Dictionary(texts)
dictionary.filter_extremes(no_below=5, no_above=0.5)

corpus = [dictionary.doc2bow(text) for text in texts]
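If you want to see what that actually built, you can peek at the bag-of-words version of a speech. Each document is now a list of (word id, count) pairs, and the dictionary translates the ids back into words. Something like this works:

# The first address as (word id, count) pairs
corpus[0][:10]

# Use the dictionary to turn those ids back into words
[(dictionary[word_id], count) for word_id, count in corpus[0][:10]]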
from gensim import models

tfidf = models.TfidfModel(corpus)
corpus_tfidf = tfidf[corpus]
n_topics = 15

# Build an LSI model
lsi_model = models.LsiModel(corpus_tfidf,
                            id2word=dictionary,
                            num_topics=n_topics)
lsi_model.print_topics()
[(0,
  '0.084*"tonight" + 0.073*"program" + 0.066*"ve" + 0.065*"help" + 0.065*"budget" + 0.065*"mexico" + 0.065*"americans" + 0.061*"programs" + 0.059*"jobs" + 0.058*"re"'),
 (1,
  '-0.206*"tonight" + -0.169*"ve" + -0.137*"re" + -0.136*"jobs" + -0.130*"americans" + -0.124*"budget" + -0.123*"help" + -0.116*"programs" + -0.112*"program" + -0.106*"billion"'),
 (2,
  '-0.199*"tonight" + -0.177*"ve" + -0.164*"re" + 0.137*"program" + -0.094*"jobs" + 0.092*"farm" + -0.092*"ll" + -0.091*"iraq" + 0.087*"veterans" + 0.081*"interstate"'),
 (3,
  '0.141*"program" + -0.121*"silver" + -0.114*"re" + -0.110*"cent" + 0.106*"communist" + -0.104*"ve" + -0.098*"tonight" + 0.097*"soviet" + 0.096*"programs" + -0.089*"gold"'),
 (4,
  '0.196*"iraq" + 0.171*"terrorists" + -0.143*"silver" + -0.133*"gold" + 0.124*"interstate" + 0.117*"iraqi" + -0.111*"programs" + 0.109*"al" + -0.099*"notes" + -0.099*"soviet"'),
 (5,
  '0.274*"iraq" + 0.237*"terrorists" + -0.187*"re" + -0.182*"ve" + 0.169*"iraqi" + 0.154*"al" + 0.138*"terror" + 0.128*"terrorist" + 0.101*"afghanistan" + -0.097*"ll"'),
 (6,
  '-0.221*"mexico" + -0.205*"texas" + -0.092*"kansas" + -0.092*"oregon" + -0.091*"paper" + 0.089*"silver" + -0.088*"mexican" + 0.086*"gentlemen" + -0.084*"california" + -0.077*"slavery"'),
 (7,
  '-0.159*"banks" + -0.147*"iraq" + -0.121*"veterans" + 0.113*"japanese" + 0.111*"vietnam" + -0.109*"terrorists" + -0.102*"bank" + 0.099*"soviet" + -0.097*"notes" + 0.097*"fighting"'),
 (8,
  '0.189*"silver" + -0.176*"mexico" + 0.150*"gold" + 0.138*"notes" + -0.135*"texas" + 0.111*"gentlemen" + 0.108*"circulation" + 0.105*"currency" + 0.102*"coinage" + 0.099*"paper"'),
 (9,
  '-0.341*"vietnam" + 0.205*"ve" + -0.168*"tonight" + 0.159*"re" + -0.112*"billion" + 0.107*"soviet" + 0.101*"ll" + 0.095*"planes" + -0.088*"programs" + -0.087*"interstate"'),
 (10,
  '0.195*"enemy" + -0.195*"soviet" + -0.148*"gentlemen" + -0.129*"ve" + 0.106*"savages" + 0.099*"vietnam" + 0.096*"militia" + 0.093*"whilst" + -0.084*"chambers" + -0.080*"oil"'),
 (11,
  '0.242*"soviet" + 0.201*"ve" + 0.174*"oil" + -0.143*"spain" + -0.100*"colonies" + -0.099*"democracy" + 0.098*"enemy" + 0.094*"militia" + 0.088*"afghanistan" + 0.085*"salt"'),
 (12,
  '-0.248*"gentlemen" + 0.156*"soviet" + 0.126*"spain" + 0.100*"colonies" + -0.099*"slavery" + -0.093*"kansas" + 0.090*"gold" + 0.089*"notes" + 0.085*"oil" + -0.082*"emancipation"'),
 (13,
  '0.238*"texas" + 0.226*"mexico" + -0.157*"slavery" + -0.143*"kansas" + 0.137*"mexican" + -0.120*"emancipation" + -0.118*"rebellion" + 0.117*"vietnam" + 0.107*"gentlemen" + 0.103*"annexation"'),
 (14,
  '-0.319*"vietnam" + -0.199*"tonight" + 0.092*"program" + -0.085*"planes" + 0.080*"housing" + 0.080*"forest" + -0.079*"fighting" + 0.076*"corporations" + 0.069*"militia" + -0.068*"colonies"')]

Gensim is all about how important each word is to the topic. Why not visualize it? First we'll make a dataframe that shows each topic, its top ten words, and their values.

n_words = 10

topic_words = pd.DataFrame({})

for i, topic in enumerate(lsi_model.get_topics()):
    top_feature_ids = topic.argsort()[-n_words:][::-1]
    feature_values = topic[top_feature_ids]
    words = [dictionary[id] for id in top_feature_ids]
    topic_df = pd.DataFrame({'value': feature_values, 'word': words, 'topic': i})
    topic_words = pd.concat([topic_words, topic_df], ignore_index=True)

topic_words.head()
value word topic
0 0.083982 tonight 0
1 0.073466 program 0
2 0.065711 ve 0
3 0.065221 help 0
4 0.065030 budget 0

Then we'll use seaborn to visualize it.

import matplotlib.pyplot as plt
import seaborn as sns

g = sns.FacetGrid(topic_words, col="topic", col_wrap=3, sharey=False)
g.map(plt.barh, "word", "value")
<seaborn.axisgrid.FacetGrid at 0x139663668>

Using LDA with Gensim#

Now we'll use LDA.

from gensim.utils import simple_preprocess

texts = df.content.apply(simple_preprocess)
from gensim import corpora

dictionary = corpora.Dictionary(texts)
dictionary.filter_extremes(no_below=5, no_above=0.5, keep_n=2000)
corpus = [dictionary.doc2bow(text) for text in texts]
from gensim import models

n_topics = 15

lda_model = models.LdaModel(corpus=corpus, num_topics=n_topics)
lda_model.print_topics()
[(0,
  '0.003*"1260" + 0.003*"1930" + 0.003*"1971" + 0.003*"1559" + 0.003*"1327" + 0.002*"151" + 0.002*"1986" + 0.002*"1446" + 0.002*"951" + 0.002*"266"'),
 (1,
  '0.003*"1626" + 0.003*"1986" + 0.003*"1559" + 0.003*"1784" + 0.002*"976" + 0.002*"440" + 0.002*"1257" + 0.002*"1060" + 0.002*"951" + 0.002*"151"'),
 (2,
  '0.004*"1986" + 0.003*"1242" + 0.003*"1971" + 0.003*"1260" + 0.003*"1626" + 0.003*"1989" + 0.002*"62" + 0.002*"151" + 0.002*"1974" + 0.002*"1545"'),
 (3,
  '0.005*"1559" + 0.002*"1626" + 0.002*"951" + 0.002*"1446" + 0.002*"578" + 0.002*"1327" + 0.002*"1459" + 0.002*"973" + 0.002*"976" + 0.002*"1865"'),
 (4,
  '0.005*"1260" + 0.005*"1930" + 0.004*"1999" + 0.003*"1971" + 0.003*"1242" + 0.003*"1559" + 0.003*"1974" + 0.003*"1986" + 0.002*"1651" + 0.002*"1644"'),
 (5,
  '0.004*"1559" + 0.004*"1986" + 0.004*"1242" + 0.003*"1930" + 0.003*"1260" + 0.003*"1974" + 0.003*"1989" + 0.003*"1971" + 0.002*"440" + 0.002*"1964"'),
 (6,
  '0.006*"1559" + 0.003*"151" + 0.003*"1327" + 0.003*"976" + 0.003*"951" + 0.002*"214" + 0.002*"1986" + 0.002*"116" + 0.002*"578" + 0.002*"1619"'),
 (7,
  '0.004*"1986" + 0.003*"1260" + 0.003*"1930" + 0.003*"1922" + 0.003*"1242" + 0.003*"1964" + 0.003*"1971" + 0.003*"1989" + 0.002*"1995" + 0.002*"1626"'),
 (8,
  '0.004*"1260" + 0.004*"1930" + 0.004*"1559" + 0.003*"1626" + 0.002*"440" + 0.002*"1971" + 0.002*"151" + 0.002*"1651" + 0.002*"1697" + 0.002*"1802"'),
 (9,
  '0.004*"1930" + 0.003*"1260" + 0.003*"1989" + 0.003*"1242" + 0.002*"1974" + 0.002*"1802" + 0.002*"951" + 0.002*"1559" + 0.002*"151" + 0.002*"1327"'),
 (10,
  '0.005*"1986" + 0.003*"1930" + 0.003*"1626" + 0.003*"1260" + 0.002*"1559" + 0.002*"1964" + 0.002*"1545" + 0.002*"1242" + 0.002*"1896" + 0.002*"1446"'),
 (11,
  '0.003*"1930" + 0.003*"976" + 0.003*"1626" + 0.002*"1986" + 0.002*"1446" + 0.002*"1974" + 0.002*"151" + 0.002*"1784" + 0.002*"887" + 0.002*"468"'),
 (12,
  '0.006*"1559" + 0.003*"151" + 0.003*"1784" + 0.003*"976" + 0.002*"1626" + 0.002*"1865" + 0.002*"1060" + 0.002*"19" + 0.002*"1446" + 0.002*"951"'),
 (13,
  '0.003*"1930" + 0.003*"1986" + 0.003*"1999" + 0.003*"1260" + 0.002*"1971" + 0.002*"1784" + 0.002*"1626" + 0.002*"1327" + 0.002*"1559" + 0.002*"1792"'),
 (14,
  '0.005*"1260" + 0.004*"1999" + 0.004*"1930" + 0.003*"1986" + 0.003*"1971" + 0.003*"1242" + 0.003*"1559" + 0.003*"1995" + 0.003*"1988" + 0.003*"1644"')]
import pyLDAvis
# Note: newer pyLDAvis releases renamed this module to pyLDAvis.gensim_models
import pyLDAvis.gensim

pyLDAvis.enable_notebook()
vis = pyLDAvis.gensim.prepare(lda_model, corpus, dictionary)
vis
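pyLDAvis gives you the big interactive picture. If you just want the topic mixture for a single speech, Gensim's get_document_topics takes a bag-of-words vector, roughly like this:

# Which topics does the first address lean on, and how heavily?
lda_model.get_document_topics(corpus[0])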