Created
September 6, 2020 03:02
-
-
Save kaniska/3ae7368c7566456df1e78ae72b2ed751 to your computer and use it in GitHub Desktop.
Semantic Similarity using Spark NLP and Spark ML
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Databricks notebook source | |
import com.johnsnowlabs.nlp.annotators.Tokenizer
import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
import com.johnsnowlabs.nlp.embeddings.{BertEmbeddings, SentenceEmbeddings, WordEmbeddingsModel}
import com.johnsnowlabs.nlp.{DocumentAssembler, EmbeddingsFinisher, RecursivePipeline}
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.PipelineStage
import org.apache.spark.ml.feature.{BucketedRandomProjectionLSH, BucketedRandomProjectionLSHModel, LSH, Normalizer, SQLTransformer}
import org.apache.spark.ml.feature.{MinHashLSH, MinHashLSHModel}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
// COMMAND ---------- | |
// Notebook-level settings read by the cells below.
// Name of the table loaded with spark.read.table.
var inputTable: String = "test_data"
// Label for the stage currently running; reassigned before findSimilarity() is called.
var algoStage: String = "buildModel"
// Selects the word-embedding stage in buildPipeline: "basic" (GloVe) or "bert", case-insensitive.
var embeddingType: String = "basic"
// COMMAND ---------- | |
// Load the input table once; both corpora are projections of this DataFrame.
val df = spark.read.table(inputTable)
// Primary corpus: the "text2" column, renamed to "text" because the pipeline's
// DocumentAssembler reads its input from a column called "text".
val primaryCorpus = df.select("text2").withColumnRenamed("text2","text")
primaryCorpus.show(false)
/** | |
+-----------------------------------------+ | |
|text | | |
+-----------------------------------------+ | |
|Music Books | | |
|Books | | |
|Printer Cartridges Printers & All-in-Ones| | |
|Printer Refills | | |
|Books Manuals & Guides | | |
|Wallets Perfumes & Colognes | | |
+-----------------------------------------+ | |
**/ | |
// Secondary corpus: the "text1" column of the same table, renamed to "text" so it
// can be fed through the same pipeline as the primary corpus.
val secondaryCorpus = df.select("text1").withColumnRenamed("text1","text")
secondaryCorpus.show(false)
/** | |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | |
|text | | |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | |
|Loudspeaker Cabinets Automotive Electrical Parts & Accessories Mouse Pads Cell Phone Cases Sheet Music Music Network Cables Computer Cables Tripods Books Cable Connectors Stringed Instrument Replacement Parts Power Cables Video Games| | |
|Books | | |
|Printer Cartridges | | |
|Printer Cartridges | | |
|Books | | |
|Wallets Money Clips | | |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | |
**/ | |
// COMMAND ---------- | |
/**
 * Builds the sentence-embedding + LSH pipeline, fits it on `primaryCorpus`, and
 * saves the fitted model to `modelPath`, overwriting whatever is already there.
 *
 * Stages, in order: DocumentAssembler -> SentenceDetector -> Tokenizer ->
 * word embeddings (selected by the notebook-level `embeddingType`) ->
 * SentenceEmbeddings (AVERAGE pooling) -> EmbeddingsFinisher -> SQLTransformer
 * (one output row per sentence vector) -> Normalizer -> BucketedRandomProjectionLSH.
 *
 * @param modelPath destination of the fitted PipelineModel; defaults to the path
 *                  this notebook has always used, so `buildPipeline()` is unchanged
 * @throws IllegalArgumentException if `embeddingType` is neither "basic" nor "bert"
 *                                  (compared case-insensitively)
 */
def buildPipeline(modelPath: String = "/tmp/spark-nlp-model-v1"): Unit = {
  val documentAssembler = new DocumentAssembler()
    .setInputCol("text")
    .setOutputCol("document")

  val sentence = new SentenceDetector()
    .setInputCols("document")
    .setOutputCol("sentence")
    .setExplodeSentences(false)

  val tokenizer = new Tokenizer()
    .setInputCols(Array("sentence"))
    .setOutputCol("token")

  // Resolved once, up front: an unsupported embeddingType now fails here with a
  // clear message instead of slipping a null stage into the pipeline.
  val embeddings: PipelineStage = embeddingType match {
    case kind if "basic".equalsIgnoreCase(kind) =>
      WordEmbeddingsModel.pretrained("glove_100d", "en")
        .setInputCols("sentence", "token")
        .setOutputCol("embeddings")
        .setCaseSensitive(false)
    case kind if "bert".equalsIgnoreCase(kind) =>
      // NOTE(review): a *cased* BERT model is loaded with caseSensitive = false;
      // kept as in the original, but confirm this combination is intended.
      BertEmbeddings
        .pretrained("bert_base_cased", "en")
        .setInputCols(Array("sentence", "token"))
        .setOutputCol("embeddings")
        .setCaseSensitive(false)
        .setPoolingLayer(0)
    case other =>
      throw new IllegalArgumentException(
        s"""Unsupported embeddingType "$other"; expected "basic" or "bert".""")
  }

  val embeddingsSentence = new SentenceEmbeddings()
    .setInputCols(Array("sentence", "embeddings"))
    .setOutputCol("sentence_embeddings")
    .setPoolingStrategy("AVERAGE")

  val embeddingsFinisher = new EmbeddingsFinisher()
    .setInputCols("sentence_embeddings", "embeddings")
    .setOutputCols("sentence_embeddings_vectors", "embeddings_vectors")
    .setOutputAsVector(true)
    .setCleanAnnotations(false)

  // Turns the array of sentence vectors into one row per vector, exposed as "features".
  val explodeVectors = new SQLTransformer()
    .setStatement("SELECT EXPLODE(sentence_embeddings_vectors) AS features, * FROM __THIS__")

  val vectorNormalizer = new Normalizer()
    .setInputCol("features")
    .setOutputCol("normFeatures")
    .setP(1.0)

  // NOTE(review): the LSH stage hashes the raw "features" column, not the
  // "normFeatures" column produced just above — confirm that is intended.
  val similarityChecker = new BucketedRandomProjectionLSH()
    .setInputCol("features")
    .setOutputCol("hashes")
    .setBucketLength(6.0)
    .setNumHashTables(6)

  val pipeline = new RecursivePipeline()
    .setStages(Array(
      documentAssembler,
      sentence,
      tokenizer,
      embeddings,
      embeddingsSentence,
      embeddingsFinisher,
      explodeVectors,
      vectorNormalizer,
      similarityChecker))

  val pipelineModel = pipeline.fit(primaryCorpus)
  pipelineModel.write.overwrite().save(modelPath)
}
// COMMAND ---------- | |
// COMMAND ---------- | |
import org.apache.spark.sql.functions.{col, udf}
// Turns a distance into a score by subtracting it from 100; findSimilarity()
// applies it to the join's "distCol" column.
// NOTE(review): the parameter is a Long, so any fractional part of the distance
// is presumably dropped by Spark's implicit cast before subtraction — confirm.
val score = udf((s: Long) => 100 - s)
// COMMAND ---------- | |
/**
 * Loads the saved pipeline model, featurizes `primaryCorpus` and `secondaryCorpus`
 * with it, and prints a score for each pair of rows that came from the same
 * position of the two corpora.
 *
 * The row key is assigned BEFORE the pipeline runs. The pipeline's SQLTransformer
 * EXPLODEs sentence vectors, so a text that yields zero or several sentences would
 * otherwise shift the keys of every later row and pair up unrelated texts.
 *
 * @param modelPath         location of the model written by `buildPipeline`
 * @param distanceThreshold threshold passed to `approxSimilarityJoin`; pairs whose
 *                          distance exceeds it are not returned
 * @throws IllegalStateException if the last stage of the loaded model is not a
 *                               BucketedRandomProjectionLSHModel
 */
def findSimilarity(modelPath: String = "/tmp/spark-nlp-model-v1",
                   distanceThreshold: Double = 100.0): Unit = {
  // load it back in during production
  val similarityCheckingModel = PipelineModel.load(modelPath)

  // Checked match instead of a blind cast, so a wrong model fails with a clear message.
  val lshModel = similarityCheckingModel.stages.lastOption match {
    case Some(model: BucketedRandomProjectionLSHModel) => model
    case other =>
      val found = other.map(_.getClass.getName).getOrElse("no stages")
      throw new IllegalStateException(
        s"Expected the last stage of the model at $modelPath to be a " +
          s"BucketedRandomProjectionLSHModel but found: $found")
  }

  // Tags each source row with a key, runs the pipeline, and keeps the join columns.
  def featurize(corpus: DataFrame): DataFrame =
    similarityCheckingModel
      .transform(corpus.withColumn("rowkey", monotonically_increasing_id()))
      .select("text", "features", "normFeatures", "rowkey")

  val dfA = featurize(primaryCorpus)
  val dfB = featurize(secondaryCorpus)

  print("Approximately joining dfA and dfB :")
  // BucketedRandomProjectionLSH: keep only pairs that originate from the same source row.
  lshModel.approxSimilarityJoin(dfA, dfB, distanceThreshold)
    .where(col("datasetA.rowkey") === col("datasetB.rowkey"))
    .select(
      col("datasetA.text").alias("text1"),
      col("datasetB.text").alias("text2"),
      score(col("distCol")).alias("score"))
    .show()
}
// COMMAND ---------- | |
// Stage 1: fit the pipeline on the primary corpus and save the model.
buildPipeline()
// COMMAND ----------
// Stage 2: reload the saved model and print the row-wise similarity scores.
algoStage = "findSimilarity"
print(algoStage)
findSimilarity()
/** | |
findSimilarityApproximately joining dfA and dfB | |
:+--------------------+--------------------+-----+ | |
| text1| text2|score| | |
+--------------------+--------------------+-----+ | |
|Books Manuals & G...| Books| 97| | |
|Wallets Perfumes ...| Wallets Money Clips| 97| | |
| Printer Refills| Printer Cartridges| 98| | |
| Music Books|Loudspeaker Cabin...| 96| | |
| Books| Books| 100| | |
|Printer Cartridge...| Printer Cartridges| 98| | |
+--------------------+--------------------+-----+ | |
**/ | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment