Last active
April 13, 2020 17:57
-
-
Save apoorvnandan/6f91ea3f5d1734485a37065a4202b2cd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def prefix_beam_search(ctc,
                       alphabet,
                       blank_token,
                       end_token,
                       space_token,
                       lm,
                       k=25,
                       alpha=0.30,
                       beta=5,
                       prune=0.001):
    '''
    Perform prefix beam search on a CTC output matrix and return the best string.

    :param ctc: 2-D output matrix indexed as ctc[timestep][character_index]
    :param alphabet: list of strings in the order their probabilities are present in ctc output
    :param blank_token: string representing blank token
    :param end_token: string representing end token
    :param space_token: string representing space token
    :param lm: function to calculate language model probability of given string
    :param k: beam width; passed to get_k_most_probable_prefixes at each timestep
    :param alpha: language model weight, applied as an exponent on the lm probability
    :param beta: language model compensation; passed to get_k_most_probable_prefixes
    :param prune: threshold on the output matrix probability of a character.
        If the probability of a character is not greater than this threshold,
        we do not extend the prefix with it
    :return: the first prefix returned by get_k_most_probable_prefixes for the
        final timestep, with end_token stripped from both ends
    '''
    # Prepend an all-zero row so real timesteps are 1-indexed and timestep 0
    # holds only the initial state seeded below.
    padded = np.zeros((ctc.shape[0] + 1, ctc.shape[1]))
    padded[1:, :] = ctc
    ctc = padded
    total_timesteps = ctc.shape[0]

    # Column holding the blank probability. Resolved from the alphabet instead
    # of assuming it is the last column; falls back to the last column when
    # blank_token is not listed, which was the previous hard-coded behaviour.
    blank_index = alphabet.index(blank_token) if blank_token in alphabet else -1

    # #### Initialization ####
    # Pb / Pnb accumulate, per (timestep, prefix), the probability of the
    # prefix ending in blank / not ending in blank.
    # NOTE(review): Cache is defined elsewhere; this code relies on get()
    # returning 0 for unseen (timestep, prefix) keys - confirm.
    null_token = ''
    Pb, Pnb = Cache(), Cache()
    Pb.add(0, null_token, 1)
    Pnb.add(0, null_token, 0)
    prefix_list = [null_token]

    # #### Iterations ####
    for timestep in range(1, total_timesteps):
        probs = ctc[timestep]
        # Only characters whose probability exceeds the threshold are considered.
        candidate_indices = np.where(probs > prune)[0]
        # Set copy of the beam for constant-time membership tests below.
        beam = set(prefix_list)

        for prefix in prefix_list:
            # A prefix that already ends with the end token is carried forward unchanged.
            if prefix and prefix[-1] == end_token:
                Pb.add(timestep, prefix, Pb.get(timestep - 1, prefix))
                Pnb.add(timestep, prefix, Pnb.get(timestep - 1, prefix))
                continue

            for character_index in candidate_indices:
                character = alphabet[character_index]
                char_prob = probs[character_index]

                # #### Iterations : Case A ####
                # Blank: the prefix itself is unchanged and now ends in blank.
                if character == blank_token:
                    value = Pb.get(timestep, prefix) + char_prob * (
                        Pb.get(timestep - 1, prefix) + Pnb.get(timestep - 1, prefix))
                    Pb.add(timestep, prefix, value)
                    continue

                prefix_extended = prefix + character

                # #### Iterations : Case C ####
                # Repeated character: it extends the prefix only when the
                # prefix previously ended in blank; otherwise it collapses
                # into the same prefix.
                if prefix and character == prefix[-1]:
                    value = Pnb.get(timestep, prefix_extended) + char_prob * Pb.get(timestep - 1, prefix)
                    Pnb.add(timestep, prefix_extended, value)
                    value = Pnb.get(timestep, prefix) + char_prob * Pnb.get(timestep - 1, prefix)
                    Pnb.add(timestep, prefix, value)
                # #### Iterations : Case B ####
                # Space or end token after a prefix that has non-space
                # content: weight the extension by the language model.
                elif prefix.replace(space_token, '') and character in (space_token, end_token):
                    lm_prob = lm(prefix_extended.strip(space_token + end_token)) ** alpha
                    value = Pnb.get(timestep, prefix_extended) + lm_prob * char_prob * (
                        Pb.get(timestep - 1, prefix) + Pnb.get(timestep - 1, prefix))
                    Pnb.add(timestep, prefix_extended, value)
                # Any other character: plain extension of the prefix.
                else:
                    value = Pnb.get(timestep, prefix_extended) + char_prob * (
                        Pb.get(timestep - 1, prefix) + Pnb.get(timestep - 1, prefix))
                    Pnb.add(timestep, prefix_extended, value)

                # The extended prefix is not in the beam, so the outer loop
                # will not visit it; carry over its own previous-timestep
                # mass here (blank continuation and same-character repeat).
                if prefix_extended not in beam:
                    value = Pb.get(timestep, prefix_extended) + probs[blank_index] * (
                        Pb.get(timestep - 1, prefix_extended) + Pnb.get(timestep - 1, prefix_extended))
                    Pb.add(timestep, prefix_extended, value)
                    value = Pnb.get(timestep, prefix_extended) + char_prob * Pnb.get(timestep - 1, prefix_extended)
                    Pnb.add(timestep, prefix_extended, value)

        # NOTE(review): get_k_most_probable_prefixes is defined elsewhere;
        # presumably returns the k best prefixes, best first - confirm.
        prefix_list = get_k_most_probable_prefixes(Pb, Pnb, timestep, k, beta)

    # #### Output ####
    return prefix_list[0].strip(end_token)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment