@softwaredoug
Created January 1, 2023 20:51
Encoding Wikipedia sentences with a sentence encoder
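This script reads wikisent2.txt (one sentence per line, e.g. the Kaggle "Wikipedia Sentences" dump) and encodes it with two sentence-transformers models, all-MiniLM-L6-v2 and all-mpnet-base-v2. Embeddings are written in 20,000-sentence chunks under data/ so an interrupted run can resume where it left off, then concatenated into a single .npz file per model.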
import numpy as np
import os
from time import perf_counter
from sentence_transformers import SentenceTransformer, LoggingHandler
import logging

logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.DEBUG,
                    handlers=[LoggingHandler()])
def encode(sentences, chunk_size=20000):
    print("Loaded sentences")
    model_mini = SentenceTransformer('all-MiniLM-L6-v2')
    model_mpnet = SentenceTransformer('all-mpnet-base-v2')
    # Worker pool for multi-process (multi-GPU / multi-CPU) mpnet encoding
    pool = model_mpnet.start_multi_process_pool()
    start = perf_counter()
    for chunk in range(0, len(sentences), chunk_size):
        mini_fname = f"data/wikisent2_{chunk}.npz"
        mpnet_fname = f"data/wikisent2-mpnet_{chunk}.npz"
        begin = chunk
        end = chunk + chunk_size
        if not os.path.exists(mini_fname):
            print(f"Processing mini {chunk}")
            embeddings = model_mini.encode(sentences[begin:end],
                                           show_progress_bar=True)
            print(f"Encoded sentences chunk {chunk} ({begin}-{end}) - {perf_counter() - start}")
            np.savez(mini_fname, embeddings)
            print("Saved sentences")
        else:
            print(f"Skipping mini {chunk}")
        if not os.path.exists(mpnet_fname):
            print(f"Processing mpnet {chunk}")
            # Spread the mpnet encoding across the pool's worker processes
            embeddings = model_mpnet.encode_multi_process(sentences[begin:end], pool)
            print(f"Encoded sentences chunk {chunk} ({begin}-{end}) - {perf_counter() - start}")
            np.savez(mpnet_fname, embeddings)
            print("Saved sentences")
        else:
            print(f"Skipping mpnet {chunk}")
    # Shut down the worker processes once every chunk is encoded
    model_mpnet.stop_multi_process_pool(pool)
def append(encoding="mini"):
    # Iterate all chunk files in data/, load them, and append them into
    # a single file so the embeddings are easier to load in the future.
    files = []
    # Collect the chunk files for the requested encoding. Mini chunks have
    # no infix (wikisent2_<chunk>.npz); other encodings are named
    # wikisent2-<encoding>_<chunk>.npz. Skip any previously combined
    # *_all.npz output so reruns don't pick it up.
    for fname in os.listdir("data"):
        if not fname.endswith(".npz") or "all" in fname:
            continue
        if encoding == "mini":
            if fname.startswith("wikisent2_"):
                files.append(fname)
        elif fname.startswith(f"wikisent2-{encoding}"):
            files.append(fname)
    # Sort by chunk number so sentences keep their original order
    files.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
    # Load and append into one numpy array
    arrs = []
    for fname in files:
        print(f"Loading {fname}")
        arrs.append(np.load(f"data/{fname}")["arr_0"])
    print("Concatenating")
    arr = np.concatenate(arrs)
    print(arr.shape)
    np.savez(f"data/wikisent2_{encoding}_all.npz", arr)
if __name__ == "__main__":
    with open('wikisent2.txt') as f:
        # One sentence per line; drop the trailing newline
        sentences = [line.strip() for line in f]
    encode(sentences)
    # append("mini")
    append("mpnet")
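A minimal sketch of using the combined embeddings for nearest-neighbor search, assuming append("mpnet") above produced data/wikisent2_mpnet_all.npz (the filename and the top_k helper here are illustrative, not part of the gist):

import numpy as np
from sentence_transformers import SentenceTransformer

# Combined mpnet embeddings, one row per sentence in wikisent2.txt
embeddings = np.load("data/wikisent2_mpnet_all.npz")["arr_0"]
model = SentenceTransformer('all-mpnet-base-v2')

def top_k(query, k=5):
    # Cosine similarity: normalize both sides, then take dot products.
    # (all-mpnet-base-v2 outputs are already unit length, so a plain dot
    # product would also work; normalizing keeps the sketch general.)
    q = model.encode([query])[0]
    q = q / np.linalg.norm(q)
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    scores = normed @ q
    return np.argsort(-scores)[:k]  # indices into the sentence list

with open("wikisent2.txt") as f:
    sentences = [line.strip() for line in f]
for idx in top_k("Who painted the Mona Lisa?"):
    print(sentences[idx])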