Created November 7, 2017 15:45
Save danemacaulay/217619291c36e3561ba24f30a7bd78de to your computer and use it in GitHub Desktop.
Show the most important features of a text-classifier pipeline.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from operator import itemgetter


def show_most_informative_features(
    model,
    text=None,
    n=50,
    *,
    vectorizer_step="vectorizer",
    classifier_step="classifier",
):
    """
    Return a two-column report of the most informative features of a
    text-classification Pipeline.

    The pipeline must contain a vectorizer step (e.g. a TfidfVectorizer)
    and a linear classifier step that exposes ``coef_``. The left column
    lists the ``n`` features with the largest weights; the right column
    lists the ``n`` features with the smallest (most negative) weights,
    most negative first.

    If ``text`` is given, each feature's weight is its contribution to the
    decision for that text: the feature's vectorized value in ``text``
    multiplied by its coefficient. The report is then prefixed with the
    quoted text and the pipeline's prediction for it.

    Only the first row of ``coef_`` is used, so for a multiclass model the
    report describes the first class only.

    Parameters
    ----------
    model : Pipeline
        Fitted pipeline whose ``named_steps`` contains both steps.
    text : str, optional
        Document to explain. If None, the raw coefficients are reported.
    n : int
        Number of features to show in each column.
    vectorizer_step, classifier_step : str
        Names of the vectorizer and classifier steps in ``model``.

    Returns
    -------
    str
        The newline-joined report.

    Raises
    ------
    TypeError
        If the classifier step has no ``coef_`` attribute.
    """
    import numpy as np

    # Extract the vectorizer and the classifier from the pipeline.
    vectorizer = model.named_steps[vectorizer_step]
    classifier = model.named_steps[classifier_step]

    # Only linear models expose per-feature coefficients.
    if not hasattr(classifier, "coef_"):
        raise TypeError(
            "Cannot compute most informative features on {} model.".format(
                classifier.__class__.__name__
            )
        )

    # coef_ may be sparse (after classifier.sparsify()) or 1-D; densify it
    # and coerce it to 2-D so row 0 is always the per-feature weight vector.
    coef = classifier.coef_
    if hasattr(coef, "toarray"):
        coef = coef.toarray()
    weights = np.atleast_2d(np.asarray(coef, dtype=float))[0]

    if text is not None:
        # Vectorize with the vectorizer step only. Pipeline.transform would
        # also run the classifier, which has no transform(). Each feature's
        # value in the text times its coefficient is its contribution.
        tvec = vectorizer.transform([text])
        if hasattr(tvec, "toarray"):
            tvec = tvec.toarray()
        weights = np.asarray(tvec, dtype=float)[0] * weights

    # get_feature_names() was removed in scikit-learn 1.2 in favour of
    # get_feature_names_out(); support both.
    if hasattr(vectorizer, "get_feature_names_out"):
        feature_names = vectorizer.get_feature_names_out()
    else:
        feature_names = vectorizer.get_feature_names()

    # Zip the weights with the feature names and sort, largest first.
    coefs = sorted(
        zip(weights, feature_names),
        key=itemgetter(0),
        reverse=True,
    )
    # Pair the n largest weights with the n smallest (most negative first).
    topn = zip(coefs[:n], coefs[:-(n + 1):-1])

    output = []
    # If text was given, prefix the report with it and its predicted label.
    if text is not None:
        output.append("\"{}\"".format(text))
        output.append("Classified as: {}".format(model.predict([text])[0]))
        output.append("")

    # Two columns: most positive features on the left, most negative on
    # the right.
    for (cp, fnp), (cn, fnn) in topn:
        output.append(
            "{:0.4f}{: >15} {:0.4f}{: >15}".format(cp, fnp, cn, fnn)
        )
    return "\n".join(output)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.