- Start a new project in IntelliJ or in an IDE of your choice. Make sure the necessary JAR files are included.
- The package statement for the recipe is as follows:
package spark.ml.cookbook.chapter12
- Import the necessary packages for Scala and Spark:
import edu.umd.cloud9.collection.wikipedia.WikipediaPage
import edu.umd.cloud9.collection.wikipedia.language.EnglishWikipediaPage
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.clustering.LDA
import org.apache.spark.ml.feature._
import org.apache.spark.sql.SparkSession
- We define a function to parse a Wikipedia page and return the title and content text of the page:
def parseWikiPage(rawPage: String): Option[(String, String)] = {
val wikiPage = new EnglishWikipediaPage()
WikipediaPage.readPage(wikiPage, rawPage)
if (wikiPage.isEmpty
|| wikiPage.isDisambiguation
|| wikiPage.isRedirect
|| !wikiPage.isArticle) {
None
} else {
Some((wikiPage.getTitle, wikiPage.getContent))
}
}
- Let us define the location of the Wikipedia data dump:
val input = "../data/sparkml2/chapter12/enwiki_dump.xml"
- We create a job configuration for Hadoop XML streaming:
val jobConf = new JobConf()
jobConf.set("stream.recordreader.class", "org.apache.hadoop.streaming.StreamXmlRecordReader")
jobConf.set("stream.recordreader.begin", "<page>")
jobConf.set("stream.recordreader.end", "</page>")
- We set up the data path for Hadoop XML streaming processing:
FileInputFormat.addInputPath(jobConf, new Path(input))
- Create a SparkSession with configurations using the factory builder pattern:
val spark = SparkSession
.builder
.master("local[*]")
.appName("ProcessLDA App")
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config("spark.sql.warehouse.dir", ".")
.getOrCreate()
- We set the logging level to WARN, otherwise the output will be difficult to follow:
Logger.getRootLogger.setLevel(Level.WARN)
- We begin processing the huge Wikipedia data dump into article pages, taking a 10% sample of the file:
val wikiData = spark.sparkContext.hadoopRDD(
jobConf,
classOf[org.apache.hadoop.streaming.StreamInputFormat],
classOf[Text],
classOf[Text]).sample(false, .1)
- Next, we process our sampled data into an RDD containing a tuple of title and page content text, and finally generate a DataFrame (toDF requires import spark.implicits._ to be in scope):
import spark.implicits._

val df = wikiData.map(_._1.toString)
.flatMap(parseWikiPage)
.toDF("title", "text")
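- Optionally, we can sanity-check the resulting DataFrame by printing its schema and a few parsed pages; this is just a quick illustrative check using the standard DataFrame API:
df.printSchema()
df.show(3)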
- We now transform the text column of the DataFrame into raw words using Spark's RegexTokenizer for each Wikipedia page:
val tokenizer = new RegexTokenizer()
.setPattern("\\W+")
.setToLowercase(true)
.setMinTokenLength(4)
.setInputCol("text")
.setOutputCol("raw")
val rawWords = tokenizer.transform(df)
- The next step is to filter raw words by removing all stop words from the tokens:
val stopWords = new StopWordsRemover()
.setInputCol("raw")
.setOutputCol("words")
.setCaseSensitive(false)
val wordData = stopWords.transform(rawWords)
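- Optionally, the default English stop-word list can be inspected or extended via StopWordsRemover.loadDefaultStopWords; the extra words appended below are purely illustrative:
val customStopWords = StopWordsRemover.loadDefaultStopWords("english") ++ Array("also", "thus") // illustrative additions
val customRemover = new StopWordsRemover()
.setStopWords(customStopWords)
.setInputCol("raw")
.setOutputCol("words")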
- We generate term counts for the filtered tokens by using Spark's CountVectorizer class, resulting in a new DataFrame containing the column features:
val cvModel = new CountVectorizer()
.setInputCol("words")
.setOutputCol("features")
.setMinDF(2)
.fit(wordData)
val cv = cvModel.transform(wordData)
cv.cache()
The "MinDF" specifies the minimum number of different document terms that must appear in order to be included in the vocabulary.
- We now invoke Spark's LDA class to generate topics and the distributions of tokens to topics:
val lda = new LDA()
.setK(5)
.setMaxIter(10)
.setFeaturesCol("features")
val model = lda.fit(cv)
val transformed = model.transform(cv)
The "K" refers to how many topics and "MaxIter" maximum iterations to execute.
- We finally describe each generated topic with its five highest-weighted terms and display the result:
val topics = model.describeTopics(5)
topics.show(false)
- Now display the topics and the terms associated with them:
val vocabList = cvModel.vocabulary
topics.collect().foreach { r =>
println(" Topic: " + r.get(r.fieldIndex("topic")))
val termsWithWeights = r.getSeq[Int](r.fieldIndex("termIndices")).map(vocabList(_))
.zip(r.getSeq[Double](r.fieldIndex("termWeights")))
termsWithWeights.foreach(println)
}
The console output lists each topic together with its top terms and their weights.
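- The transformed DataFrame also carries a topicDistribution column holding the per-document topic mixture; a quick way to inspect it for a few pages:
transformed.select("title", "topicDistribution").show(3, false)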
- We close the program by stopping the SparkSession:
spark.stop()