Created
November 17, 2023 20:23
-
-
Save jsundram/63138254a0a378e1706fd77c8bcdf8b4 to your computer and use it in GitHub Desktop.
Use word embeddings to generate better guesses at NYT Connections. TODO: come up with an algorithm to actually solve connections... https://www.nytimes.com/games/connections
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict | |
from scipy import spatial | |
from sklearn.manifold import TSNE | |
from sklearn.cluster import SpectralClustering, KMeans | |
import functools | |
import heapq | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import os | |
import pickle | |
""" | |
Usage: eventually find some way to call this from the command line on a json | |
representation of the puzzle. For now: | |
1. open IPython | |
2. import solve | |
3. Read embeddings in: `d = solve.get_embeddings(glove_file)` | |
4. Supply an answer blob that looks like BLOB below. | |
5. guess(BLOB, d) | |
""" | |
# Sample puzzle answers | |
BLOB = """Brought to the Beach - Umbrella, Sunscreen, Towel, Flip-Flop | |
Types of French Fries - Curly, Shoestring, Waffle, Wedge | |
Equivocate - Hedge, See-Saw, Waver, Yo-Yo | |
Second Words of Vodka Cocktails - Russian, Breeze, Mary, Mule""" | |
def get_embeddings(glovef):
    """Load GloVe embeddings as a dict mapping word -> float32 numpy vector.

    First need to download Wikipedia 2014 + Gigaword 5 vectors from:
    http://nlp.stanford.edu/data/glove.6B.zip
    and extract them. Because they take a while to load on each run
    (~60s for glove.6B.300d), the parsed dict is pickled to ./glove.pkl
    for fast reloading on subsequent calls.

    Args:
        glovef: path to an extracted GloVe .txt file
            (each line: word followed by its space-separated floats).

    Returns:
        dict of word (str) -> embedding (np.ndarray of float32).
    """
    pkl = "./glove.pkl"
    # Fast path: reuse the cached pickle if a previous run wrote one.
    if os.path.exists(pkl):
        with open(pkl, "rb") as f:
            return pickle.load(f)
    # Takes about 60s...
    print("Reading embeddings from %s ..." % glovef)
    d = {}
    with open(glovef, "r", encoding="utf-8") as f:
        for line in f:
            values = line.split()
            word = values[0]
            vector = np.asarray(values[1:], "float32")
            d[word] = vector
    # BUG FIX: pickle.dump requires the target file object as its second
    # argument; the original `pickle.dump(d)` raised a TypeError, so the
    # cache was never written.
    with open(pkl, "wb") as f:
        pickle.dump(d, f)
    return d
def find_closest(word, edict, n):
    """Return the n words in `edict` nearest to `word` (Euclidean distance).

    The word is always its own nearest neighbor (distance 0), so we ask
    heapq for n + 1 results and drop the first.
    """
    target = edict[word]

    def dist(w):
        return spatial.distance.euclidean(edict[w], v=target)

    return heapq.nsmallest(n + 1, edict, key=dist)[1:]
def get_words(blob=BLOB, delim=" - "):
    """Parse a puzzle blob into (groups, words).

    Each line of `blob` is "<group name><delim><comma-separated members>".
    Returns a dict of group name -> member list, plus a flat list of every
    member in order of appearance.
    """
    groups = {}
    for row in blob.split("\n"):
        name, members = row.strip().split(delim)
        groups[name] = members.split(", ")
    words = [w for members in groups.values() for w in members]
    return groups, words
def plot(words, e, cats, postfix="truth"): | |
print("clustering") | |
p = 6 # NB this needs to be less then len(words). | |
tsne = TSNE(n_components=2, random_state=0, perplexity=p) | |
Y = tsne.fit_transform(np.array([e[w] for w in words])) | |
print("Plotting %s" % postfix) | |
clr = ["C%d" % cats[w] for w in words] | |
plt.scatter(Y[:, 0], Y[:, 1], c=clr, marker="None") | |
for label, x, y in zip(words, Y[:, 0], Y[:, 1]): | |
clr = "C%d" % cats[label] | |
plt.annotate( | |
label, | |
xy=(x, y), | |
xytext=(0, 0), | |
textcoords="offset points", | |
ha="center", | |
va="center", | |
c=clr, | |
fontsize=30, | |
) | |
plt.title( | |
"TSNE plot of %d words from GloVe 6B.300d vectors, <perplexity=%2.1f>" | |
% (len(words), p) | |
) | |
plt.savefig("./%s.png" % "-".join(words[:5] + [postfix])) | |
plt.close() | |
def form_groups(labels, words):
    """Pair cluster labels with words.

    Returns (grouped, cats): `grouped` is a list of word lists, one per
    label in first-seen order; `cats` maps each word to its label.
    """
    by_label = defaultdict(list)
    cats = {}
    for label, word in zip(labels, words):
        by_label[label].append(word)
        cats[word] = label
    return list(by_label.values()), cats
def cluster_kmeans(words, e):
    """Cluster `words` into 4 groups via k-means on their embeddings."""
    vectors = np.array([e[w] for w in words])
    model = KMeans(n_clusters=4, random_state=0, n_init="auto")
    model.fit(vectors)
    return form_groups(model.labels_, words)
def cluster_spectral(words, e):
    """Cluster `words` into 4 groups via spectral clustering on embeddings."""
    vectors = np.array([e[w] for w in words])
    clusterer = SpectralClustering(
        4, affinity="nearest_neighbors", n_neighbors=3, assign_labels="cluster_qr"
    )
    # fit_predict returns labels_, so use its result directly.
    labels = clusterer.fit_predict(vectors)
    return form_groups(labels, words)
def guess(blob, d):
    """Print and plot clustering-based guesses for a Connections puzzle.

    Args:
        blob: puzzle answers formatted like BLOB ("Group - w1, w2, w3, w4" lines).
        d: word -> embedding dict, as returned by get_embeddings().

    Side effects: prints each guess and saves t-SNE plots for the ground
    truth, the k-means guess, and the spectral-clustering guess.
    """
    catdict, words = get_words(blob)
    # BUG FIX: message previously read "Tetting vectors" (typo for "Getting").
    print("Getting vectors for words: %s" % words)
    # GloVe's vocabulary is lowercase; puzzle words are capitalized.
    e = {w: d[w.lower()] for w in words}
    # Ground-truth category id per word, numbered by group order in the blob.
    cats = {w: i for (i, (cat, ws)) in enumerate(catdict.items()) for w in ws}
    plot(words, e, cats)
    grouped, kcats = cluster_kmeans(words, e)
    print("Guess: kmeans", kcats)
    plot(words, e, kcats, "kmeans")
    grouped, skats = cluster_spectral(words, e)
    print("Guess: spectral", skats)
    plot(words, e, skats, "spectral")
def main():
    """Load the GloVe embeddings and run the solver on the sample puzzle."""
    # glove.6B.300d.txt has 400k lines; adjust this path for your machine.
    glove_path = "/Users/jsundram/Downloads/glove.6B/glove.6B.300d.txt"
    embeddings = get_embeddings(glove_path)
    guess(BLOB, embeddings)
if __name__ == "__main__": | |
main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment