Skip to content

Instantly share code, notes, and snippets.

@jsundram
Created November 17, 2023 20:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jsundram/63138254a0a378e1706fd77c8bcdf8b4 to your computer and use it in GitHub Desktop.
Save jsundram/63138254a0a378e1706fd77c8bcdf8b4 to your computer and use it in GitHub Desktop.
Use word embeddings to generate better guesses at NYT Connections. TODO: come up with an algorithm to actually solve connections... https://www.nytimes.com/games/connections
from collections import defaultdict
from scipy import spatial
from sklearn.manifold import TSNE
from sklearn.cluster import SpectralClustering, KMeans
import functools
import heapq
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
"""
Usage: eventually find some way to call this from the command line on a json
representation of the puzzle. For now:
1. open IPython
2. import solve
3. Read embeddings in: `d = solve.get_embeddings(glove_file)`
4. Supply an answer blob that looks like BLOB below.
5. guess(BLOB, d)
"""
# Sample puzzle answers, one group per line, written as
# "<category> - <word>, <word>, <word>, <word>".
# This is the layout get_words() parses (" - " between the category and its
# words, ", " between words).
BLOB = """Brought to the Beach - Umbrella, Sunscreen, Towel, Flip-Flop
Types of French Fries - Curly, Shoestring, Waffle, Wedge
Equivocate - Hedge, See-Saw, Waver, Yo-Yo
Second Words of Vodka Cocktails - Russian, Breeze, Mary, Mule"""
def get_embeddings(glovef):
    """Load GloVe word vectors, caching them in ./glove.pkl for fast reloads.

    First download the Wikipedia 2014 + Gigaword 5 vectors from
    http://nlp.stanford.edu/data/glove.6B.zip and extract them. Parsing the
    text file takes a while, so the parsed dict is pickled next to the
    current working directory and reused on later runs.

    Args:
        glovef: Path to a GloVe text file; each line is a word followed by
            its whitespace-separated vector components.

    Returns:
        A dict mapping each word to a float32 numpy array.
    """
    pkl = "./glove.pkl"
    if os.path.exists(pkl):
        # NOTE: unpickling can run arbitrary code; only load a cache that this
        # script wrote itself.
        try:
            with open(pkl, "rb") as f:
                return pickle.load(f)
        except (EOFError, pickle.UnpicklingError):
            # An empty or truncated cache (e.g. from an interrupted write) is
            # not fatal: fall through and rebuild it from the text file.
            print("Ignoring unreadable cache %s" % pkl)

    # Takes about 60s...
    print("Reading embeddings from %s ..." % glovef)
    d = {}
    with open(glovef, "r", encoding="utf-8") as f:
        for line in f:
            values = line.split()
            if not values:
                # Tolerate blank lines instead of failing on values[0].
                continue
            d[values[0]] = np.asarray(values[1:], "float32")

    with open(pkl, "wb") as f:
        # The file handle is required; omitting it raises TypeError and
        # leaves an empty cache file behind.
        pickle.dump(d, f)
    return d
def find_closest(word, edict, n):
    """Return the n nearest neighbours of word in edict by Euclidean distance.

    The n + 1 closest keys are selected and the first is dropped; that first
    hit is expected to be word itself, at distance zero.

    Args:
        word: Key in edict whose neighbours are wanted.
        edict: Mapping of word -> vector.
        n: Number of neighbours to return.

    Returns:
        A list of up to n keys of edict, nearest first.
    """
    target = edict[word]

    def distance_to_target(candidate):
        return spatial.distance.euclidean(edict[candidate], target)

    ranked = heapq.nsmallest(n + 1, edict, key=distance_to_target)
    return ranked[1:]
def get_words(blob=BLOB, delim=" - "):
    """Parse a puzzle blob into its groups and a flat word list.

    Each non-blank line has the form "<category><delim><word>, <word>, ...".
    Blank lines (such as a trailing newline) are skipped, and whitespace
    around each word is trimmed.

    Args:
        blob: Multi-line puzzle description.
        delim: Separator between the category name and its words.

    Returns:
        A (groups, words) tuple: groups maps category name -> list of words,
        and words is every word in group order.

    Raises:
        ValueError: If a non-blank line does not contain delim.
    """
    groups = {}
    for raw_line in blob.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        # Split on the first delimiter only, so later occurrences stay with
        # the word list instead of breaking the unpacking.
        groupname, sep, items = line.partition(delim)
        if not sep:
            raise ValueError("Expected %r in puzzle line: %r" % (delim, line))
        groups[groupname] = [w.strip() for w in items.split(",") if w.strip()]
    words = [w for members in groups.values() for w in members]
    return groups, words
def plot(words, e, cats, postfix="truth"):
    """Project the word vectors to 2-D with t-SNE and save a labelled plot.

    Args:
        words: List of words to plot.
        e: Mapping of word -> embedding vector.
        cats: Mapping of word -> integer category, used to pick the colour.
        postfix: Suffix for the output file name.

    The figure is written to "./<first five words joined by '-'>-<postfix>.png".
    """
    print("clustering")
    perplexity = 6  # NB: this needs to be less than len(words).
    vectors = np.array([e[w] for w in words])
    projector = TSNE(n_components=2, random_state=0, perplexity=perplexity)
    coords = projector.fit_transform(vectors)
    xs, ys = coords[:, 0], coords[:, 1]

    print("Plotting %s" % postfix)
    colors = ["C%d" % cats[w] for w in words]
    # marker="None" draws no visible points; the words themselves are the
    # markers (presumably the scatter is kept so the axes scale to the data).
    plt.scatter(xs, ys, c=colors, marker="None")
    for word, color, x, y in zip(words, colors, xs, ys):
        plt.annotate(
            word,
            xy=(x, y),
            xytext=(0, 0),
            textcoords="offset points",
            ha="center",
            va="center",
            c=color,
            fontsize=30,
        )

    title = "TSNE plot of %d words from GloVe 6B.300d vectors, <perplexity=%2.1f>" % (
        len(words),
        perplexity,
    )
    plt.title(title)
    filename = "-".join(words[:5] + [postfix])
    plt.savefig("./%s.png" % filename)
    plt.close()
def form_groups(labels, words):
    """Turn parallel label/word sequences into groups and a word -> label map.

    Args:
        labels: Cluster label for each word, in the same order as words.
        words: The words that were clustered.

    Returns:
        A (grouped, cats) tuple: grouped is a list of word lists, one per
        label in order of first appearance; cats maps each word to its label.
    """
    by_label = {}
    cats = {}
    for label, word in zip(labels, words):
        by_label.setdefault(label, []).append(word)
        cats[word] = label
    return list(by_label.values()), cats
def cluster_kmeans(words, e):
    """Cluster the words' vectors into four groups with k-means.

    Args:
        words: Words to cluster.
        e: Mapping of word -> embedding vector.

    Returns:
        The (grouped, cats) tuple produced by form_groups for the k-means
        labels.
    """
    vectors = np.array([e[w] for w in words])
    model = KMeans(n_clusters=4, random_state=0, n_init="auto")
    model.fit(vectors)
    return form_groups(model.labels_, words)
def cluster_spectral(words, e):
    """Cluster the words' vectors into four groups with spectral clustering.

    Uses a 3-nearest-neighbour affinity graph and "cluster_qr" label
    assignment.

    Args:
        words: Words to cluster.
        e: Mapping of word -> embedding vector.

    Returns:
        The (grouped, cats) tuple produced by form_groups for the spectral
        labels.
    """
    vectors = np.array([e[w] for w in words])
    model = SpectralClustering(
        n_clusters=4,
        affinity="nearest_neighbors",
        n_neighbors=3,
        assign_labels="cluster_qr",
    )
    model.fit_predict(vectors)
    return form_groups(model.labels_, words)
def guess(blob, d):
    """Plot the true grouping of a puzzle and two clustering-based guesses.

    Three plots are saved via plot(): the real categories ("truth"), the
    k-means clusters ("kmeans") and the spectral clusters ("spectral"). The
    word -> cluster mapping of each guess is also printed.

    Args:
        blob: Puzzle answers in the format get_words() parses.
        d: Mapping of lowercase word -> embedding vector.

    Raises:
        KeyError: If any puzzle word (lowercased) has no embedding in d; the
            message lists every missing word.
    """
    catdict, words = get_words(blob)
    print("Getting vectors for words: %s" % words)

    # Report all out-of-vocabulary words at once rather than failing on the
    # first one with a bare key.
    missing = [w for w in words if w.lower() not in d]
    if missing:
        raise KeyError("No embedding found for: %s" % ", ".join(missing))
    e = {w: d[w.lower()] for w in words}

    # Ground truth: number the categories in the order they appear.
    cats = {w: i for i, ws in enumerate(catdict.values()) for w in ws}
    plot(words, e, cats)

    _, kcats = cluster_kmeans(words, e)
    print("Guess: kmeans", kcats)
    plot(words, e, kcats, "kmeans")

    _, skats = cluster_spectral(words, e)
    print("Guess: spectral", skats)
    plot(words, e, skats, "spectral")
def main():
    """Load the GloVe vectors and run guess() on the sample puzzle BLOB."""
    # 400k lines
    glove_path = "/Users/jsundram/Downloads/glove.6B/glove.6B.300d.txt"
    embeddings = get_embeddings(glove_path)
    guess(BLOB, embeddings)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment