9 March 2015

Naive Bayes on Apache Flink

In this blog post we are going to implement a Naive Bayes classifier in Apache Flink. We are going to use it for text classification by applying it to the 20 Newsgroups dataset. To understand what is going on, you should be familiar with Java and know what MapReduce is. If you have seen and understood a word count example in any system, you're good to go. If you haven't heard of MapReduce or haven't seen a word count example, you may first want to have a look at our introductory post "Hadoop and MapReduce".

Here we will use:

  • Java 7 (or 8)
  • Apache Maven
  • Apache Flink
  • Stanford NLP for text preprocessing

Apache Flink is a system for large scale data processing, like Apache Hadoop or Apache Spark. In this post we will use the Java API because it's simple to understand, although more verbose than the Scala API. The mapping to the Scala API should be straightforward, and if you want, you can implement the same thing in Spark with only minor modifications.

You don't need to install Apache Flink or build it from the sources, but you do need to have Apache Maven installed. You can download Maven from its official website, but if you have a recent version of Eclipse, IntelliJ IDEA or NetBeans, Maven already comes bundled, so just make sure that you create a Maven project in your IDE. If you're not familiar with Maven, here's a getting started guide.

Theory

Text Classification

Let's quickly review the theory. In Machine Learning, classification is the problem of identifying which class an observation belongs to. In text classification, the observations are documents, and the classes are document categories. So, based on some criteria, we need to assign a document to some category. A good example of text classification is deciding whether an email is spam or not. However, there don't have to be only two categories, and in this article we will try to assign documents to one of 20 categories.

How do we represent a document? The usual way is to see a document as a collection of words: $$\text{doc} = ( w_1 , w_2 , \ ... \ , w_n )$$ For simplicity, we will assume that the order of words in the document and the grammatical relationships between the words are not important; the only important things are the words themselves and how many occurrences of each word there are. In Information Retrieval and NLP this representation is called the Bag-of-Words model. This representation should be good enough for classification purposes: we may assume that some words tend to occur more frequently in one specific category than in others.
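
To make this concrete, here is a minimal, self-contained Java sketch (with a made-up document) that builds such a bag of words:

import java.util.HashMap;
import java.util.Map;

public class BagOfWordsExample {
    public static void main(String[] args) {
        String doc = "cookies and milk and more cookies";

        // ignore order and grammar, just count how many times each word occurs
        Map<String, Integer> bagOfWords = new HashMap<>();
        for (String word : doc.split("\\s+")) {
            Integer count = bagOfWords.get(word);
            bagOfWords.put(word, count == null ? 1 : count + 1);
        }

        System.out.println(bagOfWords); // e.g. {milk=1, more=1, cookies=2, and=2} (order not guaranteed)
    }
}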

Next, for text classification, the Bayes Rule says that the probability that a document $\text{doc}$ belongs to the category $\text{cat}$ is given by $$P(\text{cat} \mid \text{doc}) = \cfrac{P(\text{doc} \mid \text{cat}) \cdot P(\text{cat})}{P(\text{doc})}$$

Thomas Bayes

Naive Bayes

The Naive Bayes assumption is that all the words in $\text{doc}$ are independent. That is, if a document $\text{doc}$ is represented as $\text{doc} = ( w_1 , w_2 , \ ... \ , w_n )$, then the probability of seeing this document is $$P(\text{doc} \mid \text{cat}) = P( w_1 , w_2 , \ ... \ , w_n \mid \text{cat}) = \prod_i P(w_i \mid \text{cat})$$

To assign a document to some category, we select the category $\text{cat}^*$ with highest $P(\text{cat} \mid \text{doc})$:

$$\begin{align} \text{cat}^* & = \operatorname{argmax}_\text{cat} \cfrac{P(\text{doc} \mid \text{cat}) \cdot P(\text{cat})}{P(\text{doc})} \\ & = \operatorname{argmax}_\text{cat} \cfrac{\prod_i P(w_i \mid \text{cat}) P(\text{cat})}{P(\text{doc})} \end{align}$$

We can get rid of the denominator because it doesn't depend on $\text{cat}$:

$$\text{cat}^* = \operatorname{argmax}_\text{cat} \prod_i P(w_i \mid \text{cat}) P(\text{cat})$$

Note that we have to multiply many probability values, most of which are very small, and this may cause numerical underflow: the product becomes too small to be represented as a floating-point number. The typical solution to this problem is to take the logarithm of the expression. Since the logarithm is a monotonically increasing function, we get

$$\begin{align} \text{cat}^* & = \operatorname{argmax}_\text{cat} \prod_i P(w_i \mid \text{cat}) P(\text{cat}) \\ & = \operatorname{argmax}_\text{cat} \log \left[ \prod_i P(w_i \mid \text{cat}) P(\text{cat}) \right] \\ & = \operatorname{argmax}_\text{cat} \left[ \sum_i \log P(w_i \mid \text{cat}) + \log P(\text{cat}) \right] \end{align}$$
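
To see why this matters in practice, here is a tiny self-contained example (with made-up probabilities): the plain product of a thousand small values underflows to 0.0 in double precision, while the sum of their logarithms stays a perfectly ordinary number:

public class UnderflowDemo {
    public static void main(String[] args) {
        double product = 1.0;
        double logSum = 0.0;

        // a document with 1000 words, each with probability 1e-4 given the category
        for (int i = 0; i < 1000; i++) {
            product *= 1e-4;          // underflows to 0.0 long before the loop ends
            logSum += Math.log(1e-4); // stays at an ordinary magnitude
        }

        System.out.println(product); // 0.0
        System.out.println(logSum);  // approximately -9210.34
    }
}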


Estimating $P(\text{cat})$ and $P(w_i \mid \text{cat})$

So we have our formula:

$$\text{cat}^* = \operatorname{argmax}_\text{cat} \left[ \sum_i \log P(w_i \mid \text{cat}) + \log P(\text{cat}) \right]$$

We need two things here: $P(\text{cat})$ and $P(w_i \mid \text{cat})$. We estimate them using the data and obtain $\hat P(\text{cat})$ and $\hat P(w_i \mid \text{cat})$.

Estimation is very easy - it's just counting:

  • $\hat P(\text{cat}) = \cfrac{\text{count(cat)}}{\text{total count}}$: we count how many documents of this category there are and divide by the total number of documents
  • $\hat P(w_i \mid \text{cat}) = \cfrac{ \text{count($w_i$ in cat)} }{\sum_j \text{count($w_j$ in cat)}}$: in each category, we count how many times a word $w_i$ occurs and divide it by the total number of word occurrences in that category
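
For example (the numbers here are made up): if 500 of the 10000 training documents belong to rec.autos, and the word "engine" occurs 120 times among the 72000 word occurrences in that category, then

$$\hat P(\text{rec.autos}) = \cfrac{500}{10000} = 0.05, \qquad \hat P(\text{engine} \mid \text{rec.autos}) = \cfrac{120}{72000} \approx 0.0017$$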

So, with minor modification, we can reduce our calculations to the word count problem!


Smoothing

What happens if we never see a certain word $w_0$ during training, but it occurs during classification? In this case $\hat P(w_0 \mid \text{cat}) = 0$, and thus $\hat P(\text{doc} \mid \text{cat})$ is also zero: $\hat P(\text{doc} \mid \text{cat}) = \prod\limits_i \hat P(w_i \mid \text{cat}) = 0$. We would also have problems in the log domain: $\log \hat P(w_0 \mid \text{cat}) = \log 0 = -\infty$, so $\log \hat P(\text{doc} \mid \text{cat}) = \sum\limits_i \log \hat P(w_i \mid \text{cat}) = -\infty$.

How can we avoid that? Laplace Smoothing (or Additive Smoothing) is a simple and popular technique for this. We add some number $\lambda$ to each count:

$$\hat P(w_i \mid \text{cat}) = \cfrac{ \text{count($w_i$ in cat)} + \lambda }{ \sum_j \left(\text{count($w_j$ in cat)} + \lambda \right)}$$

Thus, unseen words will have estimated probability $$\hat P(w_0 \mid \text{cat}) = \cfrac{ \lambda }{\sum_j \left( \text{count($w_j$ in cat)} + \lambda \right)} \ne 0$$.

We can further rewrite this as

$$\hat P(w_i \mid \text{cat}) = \cfrac{ \text{count($w_i$ in cat)} + \lambda }{\sum_j \text{count($w_j$ in cat)} + \sum_j \lambda }$$

Then we note that in $\sum_j \lambda$ the index $j$ goes over all distinct words seen in training, so $\sum_j \lambda = \lambda \cdot \text{# distinct words}$:

$$\hat P(w_i \mid \text{cat}) = \cfrac{ \text{count($w_i$ in cat)} + \lambda }{\sum_j \text{count($w_j$ in cat)} + \lambda \cdot \text{# distinct words}}$$

Different values of the smoothing parameter $\lambda$ will lead to different probability estimates, and hence, by varying it, we can obtain different classification results. A very common value is $\lambda = 1$, and it's called "$+1$ smoothing".
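
To make the formula concrete, here is a minimal Java sketch of the smoothed estimate (the method name and parameter names are just for illustration):

static double smoothedProbability(long countInCat, long totalCountInCat,
                                  long distinctWords, double lambda) {
    // (count(w in cat) + lambda) / (sum_j count(w_j in cat) + lambda * #distinct words)
    return (countInCat + lambda) / (totalCountInCat + lambda * distinctWords);
}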

This is enough to start coding.

Coding

Data set

There are quite a lot of data sets for text classification on the Internet. A good collection of them can be found here: http://disi.unitn.it/moschitti/corpora.htm.

We are going to use one of them: the 20 Newsgroups dataset. Here's the description:

it contains 19997 articles for 20 categories taken from the Usenet newsgroups collection. We used the subject and the body of each message only. Some of the newsgroups are very closely related to each other (e.g., IBM computer system hardware / Macintosh computer system hardware), while others are highly unrelated (e.g. misc forsale / social religion and christian). This corpus is different from the previous corpora because it includes a larger vocabulary and words typically have more meanings. Moreover, the stylistic writing (e-mail dialogues) is very distant from the other more technical collections.

We are going to use the one with training/test split, so let's download it:

wget http://qwone.com/~jason/20Newsgroups/20news-bydate.tar.gz
tar -xzf 20news-bydate.tar.gz

If you want, you may skip over the preprocessing steps (discussed below) and download already processed data from here.

Initial Preprocessing

Let's have a look at what the files look like:

From: ho@cs.arizona.edu (Hilarie Orman)
Subject: Re: Licensing of public key implementations
Organization: U of Arizona, CS Dept, Tucson
Lines: 6

With regard to your speculations on NSA involvement in the creation 
of PKP, I find that it fails the test of Occam's butcher knife. Never
attribute to conspiracy what can be explained by forthright greed.

Hilarie Orman

The file above is from the "sci.crypt" category and has id=14989. So we see that it's indeed a big collection of email-like files, each of which is stored in a folder with the category name. Such a layout is quite uncommon in Big Data: there are lots of very small files, and it's not very efficient to store them on a distributed file system. So we first need to convert this set of small files into one big file.

To do this we can create a small Python script which will discard the header, remove all the line breaks and then write the entire content as a tuple (category, content) into a bigger file. The script is very simple, but if you want, you can have a look at it here. The results of this step can be found here.


Further Preprocessing

Now we need to work with the text data more, and it will involve the following steps:

  • tokenization: split a text into separate tokens, e.g. "I love cookies!" to ["I", "love", "cookies", "!"]
  • lemmatization: reduce all forms of the same word to the common one, e.g. "walk", "walking", "walked" all get reduced to "walk"
  • stop words removal: discard all words that carry no meaning but occur very commonly, e.g. "a", "the", "and", etc.

To do this we will use Stanford NLP, a library for natural language processing for Java.

First, let's add the library to our pom:

<dependency>
    <groupId>edu.stanford.nlp</groupId>
    <artifactId>stanford-corenlp</artifactId>
    <version>3.5.1</version>
</dependency>
<dependency>
    <groupId>edu.stanford.nlp</groupId>
    <artifactId>stanford-corenlp</artifactId>
    <version>3.5.1</version>
    <classifier>models</classifier>
</dependency>

We also will need Apache Commons Lang 3:

<dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-lang3</artifactId>
    <version>3.3.2</version>
</dependency>

And now we can use Stanford NLP to create a pipeline that will tokenize the text and return the lemmas:

// create
Properties props = new Properties();
props.put("annotators", "tokenize, ssplit, pos, lemma");
StanfordCoreNLP pipeline = new StanfordCoreNLP(props);

// use
Annotation document = new Annotation(body);
pipeline.annotate(document);

List<CoreLabel> tokenized = document.get(TokensAnnotation.class);

for (CoreLabel token : tokenized) {
    String lemma = token.get(LemmaAnnotation.class);
    // process lemma
}

(You may notice that we have ssplit and pos: they are necessary for lemmatization, and therefore they are also included in the pipeline.)

In our case, we also need to remove stop words, punctuation marks, emails and other tokens that contain no actual words. So we apply this simple filter before a lemma is returned from the preprocessor:

private boolean valid(String lemma) {
    if (lemma.length() < 2) {
        return false;
    }

    if (stopwords.contains(lemma)) {
        return false;
    }

    return StringUtils.isAlpha(lemma);
}

The last line uses a method of the StringUtils class from Commons Lang that checks whether a string contains only letters (and no digits, special characters, etc.).
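
For example:

StringUtils.isAlpha("cookies");   // true
StringUtils.isAlpha("cookies42"); // false: contains digits
StringUtils.isAlpha("e-mail");    // false: contains a non-letter character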

Lastly, in the pipeline we use a Part-of-Speech tagger before the lemmatization, and "bad" symbols like "^", "~", "#" tend to confuse the tagger, so it also makes sense to remove them before parsing.

The whole class for NLP preprocessing is here.

Now it's time to use Flink: we can use it to apply the NLP preprocessing to each document.

Let's start with adding dependencies to Flink:

<dependency>
    <groupId>org.apache.flink</groupId>
    <artifactId>flink-clients</artifactId>
    <version>0.8.1</version>
    <scope>provided</scope>
</dependency>
<dependency>
    <groupId>org.apache.flink</groupId>
    <artifactId>flink-java</artifactId>
    <version>0.8.1</version>
    <scope>provided</scope>
</dependency>

At the time of writing, the latest version is 0.8.1, but you may want to check if there's a newer version and use it. You can check this in Flink's repository at Maven Central.

If you have a problem like "missing artifact jdk.tools", then there's a hack to make it go away: you can exclude the dependency on jdk.tools from the flink dependency. See the whole pom.xml for details.

Let's set up a Flink job. For that, create a Java class with the main method and write the following:

ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSource<String> input = env.readTextFile(inputPath);
DataSet<Tuple2<String, String>> output = input.flatMap(new NlpProcessingMapper());
output.writeAsCsv(outputPath, "\n", "\t", WriteMode.OVERWRITE);

ExecutionEnvironment is the class that keeps all the information about the execution of Flink jobs. If you run this job locally, it will be executed in the local environment; if you submit it to a cluster, it will use a different environment, but you don't need to change anything in the code.

The environment knows how to read files from a specified location. In our case, the location is stored in the inputPath variable: this is the path to the data we want to process. Here we read a tab-separated text file and then apply a flatMap function to each document (that is, to each line of the input file). flatMap is similar to map, but for each input element it can produce zero or more output elements. What we want to do here is take a document, process it, and output a tuple (category, comma-separated-words) for valid documents, or no tuple at all if the input is invalid for some reason. The flatMap function takes an implementation of FlatMapFunction as its argument:

public static class NlpProcessingMapper extends 
        RichFlatMapFunction<String, Tuple2<String, String>> {

    private NlpPreprocessor processor;

    @Override
    public void open(Configuration parameters) throws Exception {
        super.open(parameters);
        processor = NlpPreprocessor.create();
    }

    @Override
    public void flatMap(String value, Collector<Tuple2<String, String>> out) 
                throws Exception {
        String[] split = value.split("\t");
        if (split.length < 3) {
            return;
        }

        String category = split[0];

        List<String> words = processor.processBody(split[2]);
        if (!words.isEmpty()) {
            out.collect(new Tuple2<>(category, StringUtils.join(words, ",")));
        }
    }
}

Note that NlpProcessingMapper doesn't implement the FlatMapFunction interface directly, but instead extends RichFlatMapFunction, which is an abstract implementation of FlatMapFunction. The reason is that we want to initialize the NlpPreprocessor before processing the data, and this can be done by overriding the open method. This method is called before the mapper is applied to the data, so NlpPreprocessor will already be initialized when we use it inside the flatMap function.

We can run it locally (just press run in your IDE) or on a server. In this post we will run only locally, but it's quite easy to run it on a stand-alone Flink instance, because we don't need to make any changes to the code.

And a final note: if you want speed and don't care much about the form of the words you end up using, you can use stemming instead of lemmatization: it should increase the processing speed because there will be no need to do POS tagging and other expensive things. You can read more about the difference between lemmatization and stemming here. For example, you can use a stemmer from Apache Lucene.
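
For example, a stemming-based tokenizer with Lucene might look like the sketch below. This is only an illustration: it assumes the lucene-analyzers-common artifact is on the classpath and a Lucene version where EnglishAnalyzer has a no-argument constructor; EnglishAnalyzer tokenizes, lowercases, removes standard English stop words and applies a Porter-style stemmer.

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.en.EnglishAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;

public class LuceneStemmer {

    private final EnglishAnalyzer analyzer = new EnglishAnalyzer();

    public List<String> stem(String text) throws IOException {
        List<String> stems = new ArrayList<>();
        try (TokenStream ts = analyzer.tokenStream("body", text)) {
            CharTermAttribute term = ts.addAttribute(CharTermAttribute.class);
            ts.reset();
            while (ts.incrementToken()) {
                stems.add(term.toString());
            }
            ts.end();
        }
        return stems;
    }
}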

Training: Probability Estimation

We are finally ready to move on to the training phase.

Recall the formula:

$$\text{cat}^* = \operatorname{argmax}_\text{cat} \left[ \sum_i \log P(w_i \mid \text{cat}) + \log P(\text{cat}) \right]$$

We reduced the problem to counting words:

$$\hat P(w_i \mid \text{cat}) = \cfrac{ \text{count($w_i$ in cat)} + \lambda }{\sum_j \text{count($w_j$ in cat)} + \lambda \cdot \text{# distinct words}}$$

We have three things to count here:

  1. $\text{count(} w_i \text{ in cat)}$: how many times a word $w_i$ is seen in category $\text{cat}$,
  2. $\sum_j \text{count(} w_j \text{ in cat)}$: the total number of word occurrences in category $\text{cat}$,
  3. $\text{# distinct words}$: the total number of distinct words in the corpus.

Now that we know what we're going to calculate, let's read the training data:

ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSource<String> input = env.readTextFile(Config.TRAIN_DATA);

Let's start with the prior $\hat P(\text{cat})$: we count the documents in each category and then divide each count by the total number of documents:

DataSet<Tuple2<String, Long>> labelFrequencies = 
        input.map(new LabelExtraction()).groupBy(0).sum(1);

public static class LabelExtraction implements 
            MapFunction<String, Tuple2<String, Long>> {
    @Override
    public Tuple2<String, Long> map(String value) throws Exception {
        return new Tuple2<>(value.split("\t")[0], 1L);
    }
}

It's very similar to the word count problem: for each document we output only its label along with a 1, so we have tuples (category, 1), and then we group by the category and sum over all these ones to get the total number of documents in the category.

Next, we need to normalize the counts so they sum up to one, and for this we find the total sum. To do it, we just sum over labelFrequencies again:

DataSet<Tuple1<Long>> totalSum = labelFrequencies.sum(1).project(1);

The sum over the (category, count) tuples also produces a tuple, with something in the first position and the total count in the second. We are interested only in the count, so we take a projection onto the second element (indices start from 0).

Now we need to divide all the elements of labelFrequencies by the total sum. This is a bit trickier: once we have obtained the total count, we need to send this value to all the mappers across the cluster. This is done by "broadcasting" the variable:

DataSet<Tuple2<String, Double>> priors = 
        labelFrequencies.map(new NormalizationMapper())
            .withBroadcastSet(totalSum, "totalSum");

public static class NormalizationMapper extends
        RichMapFunction<Tuple2<String, Long>, Tuple2<String, Double>> {
    private long totalSum;

    @Override
    public void open(Configuration parameters) throws Exception {
        super.open(parameters);
        List<Tuple1<Long>> totalSumList = 
                getRuntimeContext().getBroadcastVariable("totalSum");
        this.totalSum = totalSumList.get(0).f0;
    }

    @Override
    public Tuple2<String, Double> map(Tuple2<String, Long> value) 
                throws Exception {
        return new Tuple2<>(value.f0, ((double) value.f1) / totalSum);
    }
}

Note withBroadcastSet(totalSum, "totalSum"): this is how the value is broadcast to all the mappers. We again use a rich function, RichMapFunction, which allows us to access the runtime context and get the variable.

So we run this and get the following values:

alt.atheism 0.042
comp.sys.ibm.pc.hardware 0.051
comp.windows.x 0.052
misc.forsale 0.051
rec.autos 0.052
rec.motorcycles 0.052
rec.sport.baseball 0.052
rec.sport.hockey 0.053
... ...

We see that the priors are almost the same for all the classes, so we might just as well skip this computation, but let's keep it anyway for the more general case.

Next, we need to compute the word counts per category. To do this, first, for each word of a document we emit a tuple (category, word, 1), and then, after grouping by (category, word), we sum over the last column:

DataSet<Tuple3<String, String, Integer>> labelledWords = 
        input.flatMap(new TokenReaderMapper());
DataSet<Tuple3<String, String, Integer>> wordCount = 
        labelledWords.groupBy(1, 0).sum(2);

TokenReaderMapper here is very simple: it splits the input by \t (tab), then splits the second part further by ",", and finally outputs a tuple (category, word, 1) for each word:

public static class TokenReaderMapper implements 
            FlatMapFunction<String, Tuple3<String, String, Integer>> {

    @Override
    public void flatMap(String inputValue, 
                    Collector<Tuple3<String, String, Integer>> out)
            throws Exception {
        String[] split = inputValue.split("\t");
        String category = split[0];

        for (String word : split[1].split(",")) {
            out.collect(new Tuple3<>(category, word, 1));
        }
    }
}

As a result, it will produce tuples like these:

alt.atheism agnosticism 12
comp.windows.x dump 25
rec.autos broken 2

Finally, we need to know how many words there are per category, so we count them this way:

DataSet<Tuple2<String, Integer>> countPerCategory = 
        labelledWords.groupBy(0).sum(2).project(0, 2);

Since we have tuples (category, word, 1), we just group by category, sum over the column of ones and finally keep only the (category, count) tuples.

The result is something like this:

alt.atheism 72270
comp.sys.ibm.pc.hardware 51321
comp.windows.x 82099
... ...

You can find the code for the entire class WordCounterJob here.

You may wonder why we created two separate jobs: one for NLP preprocessing and one for counting. There's no particular reason, and in real life you would probably do it in one job. On the other hand, splitting it into two parts allows us to store the intermediate results, and it also makes debugging easier.

Now we're finally ready for classification.

Classification

Here we will assume that all the counts fit into memory: all the heavy lifting (i.e. counting) has already been done in the previous step, so we just use the results. For very large datasets this will not be the case, and we would need to do some joins before classification. Alternatively, we could use approximation techniques like the Count-min Sketch algorithm for getting the counts of words for each category.

In the classification part, for each element of the test set we try to predict the label with the Naive Bayes classifier, using the probability estimates from the previous step.

First we need to read the things we computed in the previous step:

DataSet<Tuple2<String, Double>> priors = 
        env.readTextFile(Config.OUT_PRIOR).map(new PriorsReaderMapper());
DataSet<Tuple2<String, Integer>> totalCountPerCategory = 
        env.readTextFile(Config.OUT_TOTAL_COUNT_PER_CAT)
           .map(new TotalCountPerCategoryMapper());
DataSet<Tuple3<String, String, Integer>> counts = 
        env.readTextFile(Config.OUT_COND_COUNT)
           .map(new CountsReaderMapper());

All mappers here just convert the text input to tuples.
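
For illustration, one of these reader mappers might look like the following sketch (the actual classes in the repository may differ; here we assume the training job wrote tab-separated lines):

public static class PriorsReaderMapper 
            implements MapFunction<String, Tuple2<String, Double>> {
    @Override
    public Tuple2<String, Double> map(String value) throws Exception {
        // each line looks like "category<TAB>prior"
        String[] split = value.split("\t");
        return new Tuple2<>(split[0], Double.parseDouble(split[1]));
    }
}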

Each test instance will be classified in a map function, ClassifierMapper, which will output a tuple (actualClass, predictedClass) for each document in the test set. But first we need to get all the precomputed data to the mappers, and we do it by broadcasting the corresponding data sets:

int smoothing = 1;
testSet.map(new ClassifierMapper(smoothing))
       .withBroadcastSet(priors, "priors")
       .withBroadcastSet(counts, "counts")
       .withBroadcastSet(totalCountPerCategory, "countPerCategory");

The ClassifierMapper is a RichMapFunction:

public static class ClassifierMapper 
                extends RichMapFunction<String, Tuple2<String, String>> {
    private int smoothing;
    private NaiveBayesClassifier classifier;

    public ClassifierMapper(int smoothing) {
        this.smoothing = smoothing;
    }

    @Override
    public void open(Configuration parameters) throws Exception {
        super.open(parameters);
        RuntimeContext ctx = getRuntimeContext();
        List<Tuple2<String, Integer>> countPerCategory = 
                ctx.getBroadcastVariable("countPerCategory");
        List<Tuple2<String, Double>> priors = 
                ctx.getBroadcastVariable("priors");
        List<Tuple3<String, String, Integer>> counts = 
                ctx.getBroadcastVariable("counts");

        classifier = new NaiveBayesClassifier(smoothing);
        classifier.init(priors, counts, countPerCategory);
    }

    @Override
    public Tuple2<String, String> map(String value) throws Exception {
        String[] split = value.split("\t");
        String actualLabel = split[0];

        String[] words = split[1].split(",");
        String predictedLabel = classifier.predict(words);

        return new Tuple2<>(actualLabel, predictedLabel);
    }
}

As previously, we access the broadcasted variables in the open method via the runtime context, and then we use them to initialize the classifier.

So the only remaining part here is the classifier itself.

Inside, it will have the following:

  • prior probabilities: a Map<String, Double> from category to double
  • count per category: a Map<String, Integer> from category to count
  • conditional word counts: a Table<String, String, Integer> from (category, word) to count
  • list of all labels List<String>

All these fields are initialized when we call classifier.init(priors, counts, countPerCategory);

A Table here is a class from Google Guava; in essence it's a map of maps: roughly equivalent to Map<String, Map<String, Integer>>, but with some additional functionality. To use Guava, you need to add this dependency to your pom:

<dependency>
    <groupId>com.google.guava</groupId>
    <artifactId>guava</artifactId>
    <version>18.0</version>
</dependency>
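
For illustration, here is a sketch of what the fields and the init method of NaiveBayesClassifier might look like (the names follow this post, but the actual implementation in the repository may differ; the Table is created with Guava's HashBasedTable, and the usual Flink tuple and java.util imports are assumed):

public class NaiveBayesClassifier {

    private final Map<String, Double> priors = new HashMap<>();
    private final Map<String, Integer> countPerCategory = new HashMap<>();
    private final Table<String, String, Integer> conditionalWordCounts = HashBasedTable.create();
    private final List<String> labels = new ArrayList<>();

    private final int smoothing;
    private int distinctWordCount;

    public NaiveBayesClassifier(int smoothing) {
        this.smoothing = smoothing;
    }

    public void init(List<Tuple2<String, Double>> priorList,
                     List<Tuple3<String, String, Integer>> countList,
                     List<Tuple2<String, Integer>> countPerCategoryList) {
        for (Tuple2<String, Double> prior : priorList) {
            priors.put(prior.f0, prior.f1);            // category -> P(cat)
            labels.add(prior.f0);
        }
        for (Tuple2<String, Integer> count : countPerCategoryList) {
            countPerCategory.put(count.f0, count.f1);  // category -> total word count
        }
        for (Tuple3<String, String, Integer> count : countList) {
            conditionalWordCounts.put(count.f0, count.f1, count.f2); // (category, word) -> count
        }
        // number of distinct words = number of distinct column keys in the table
        distinctWordCount = conditionalWordCounts.columnKeySet().size();
    }
}

The smoothing field and distinctWordCount are then used in the log-probability computation shown below.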

Finding the class with highest probability is trivial: we just try all possible classes and take the one with highest (log) probability.

double maxLog = Double.NEGATIVE_INFINITY;

String predictionLabel = "";
for (String label : labels) {
    double logProb = calculateLogP(label, words);
    if (logProb > maxLog) {
        maxLog = logProb;
        predictionLabel = label;
    }
}

And finally, the log-probability for each category is computed according to our formula

$$\text{cat}^* = \operatorname{argmax}_\text{cat} \left[ \sum_i \log P(w_i \mid \text{cat}) + \log P(\text{cat}) \right]$$

where $\hat P(w_i \mid \text{cat})$ is

$$\hat P(w_i \mid \text{cat}) = \cfrac{ \text{count($w_i$ in cat)} + \lambda }{\sum_j \text{count($w_j$ in cat)} + \lambda \cdot \text{# distinct words}}$$

Which in code looks like this:

double logProb = Math.log(priors.get(label)); // log P(cat); priors holds plain probabilities
Map<String, Integer> countsPerLabel = conditionalWordCounts.row(label);

// numerator terms
for (String word : words) {
    Integer count = countsPerLabel.get(word);
    if (count != null) {
        logProb = logProb + Math.log(count + smoothing);
    } else {
        logProb = logProb + Math.log(smoothing);
    }
}

// denominator terms
double denom = countPerCategory.get(label) + smoothing * distinctWordCount;
logProb = logProb - words.length * Math.log(denom);

For the numerator, we add $\log(\text{count($w_i$ in cat)} + \lambda)$ for each word of the document. Note that if a word is not present in the counts (this happens when we didn't see the word in category label during training), we add only $\log \lambda$. The denominator is the same for every word, so at the end we simply subtract $\log(\text{denom})$ multiplied by the number of words in the document.

Accuracy

There is one final thing: how do we check the accuracy of our model? From the previous step we have tuples (actual_label, predicted_label).

There's a simple way to calculate the accuracy from these tuples:

  • for each tuple emit (1, 1) if actual_label == predicted_label and (0, 1) otherwise
  • now reduce these tuples by summing element-wise: for two tuples (l1, r1) and (l2, r2) produce one tuple (l1 + l2, r1 + r2)
  • after reducing, the first position holds the number of correctly classified documents and the second the total number of documents
  • finally, divide the first by the second to get the accuracy

predictions.map(new MatchMapper())
           .reduce(new PerformaceEvaluatorReducer())
           .map(new PerformaceMapper());
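
For completeness, here is a sketch of what these three small functions might look like (the names match the snippet above; the actual code is in the repository):

public static class MatchMapper 
            implements MapFunction<Tuple2<String, String>, Tuple2<Integer, Integer>> {
    @Override
    public Tuple2<Integer, Integer> map(Tuple2<String, String> value) throws Exception {
        // (1, 1) for a correct prediction, (0, 1) otherwise
        int correct = value.f0.equals(value.f1) ? 1 : 0;
        return new Tuple2<>(correct, 1);
    }
}

public static class PerformaceEvaluatorReducer 
            implements ReduceFunction<Tuple2<Integer, Integer>> {
    @Override
    public Tuple2<Integer, Integer> reduce(Tuple2<Integer, Integer> t1,
                                           Tuple2<Integer, Integer> t2) throws Exception {
        // element-wise sum: (correctly classified so far, total so far)
        return new Tuple2<>(t1.f0 + t2.f0, t1.f1 + t2.f1);
    }
}

public static class PerformaceMapper 
            implements MapFunction<Tuple2<Integer, Integer>, Double> {
    @Override
    public Double map(Tuple2<Integer, Integer> value) throws Exception {
        // accuracy = correctly classified / total
        return ((double) value.f0) / value.f1;
    }
}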

So everything is ready and we can fire off the execution. After it completes, it reports an accuracy of 0.74. Not perfect, but not bad either, especially considering the nature of the data.

The full code for this class is here.

Conclusions

In this post we implemented the Naive Bayes algorithm for text classification. We set up an NLP pipeline and then did some simple counting in Apache Flink. Of course, the dataset that we used is quite small, and it would be much faster to do the same computation in Python or Java without any tools for parallel computation. Also, the accuracy of about 74% is not perfect and there is still a lot of room for improvement. For example, we could do additional preprocessing and be more careful with the way we do NLP: not just take the ready-to-use pipeline, but experiment with different settings and see which one produces the best results. We could also experiment more with the smoothing parameter $\lambda$: maybe $\lambda = 1$ is not the best value and we can do better.

However, the goal was to implement Naive Bayes in Flink, and we successfully accomplished it. We also saw that the computations for this classifier are easy to parallelize because they involve only counting, so the classifier can potentially handle very large data sets.

The entire project is accessible at github.

If you would like to learn more about Apache Flink, start with the official documentation: the docs for the current stable version (0.8) are available at https://flink.apache.org/docs/0.8/. You may also want to follow dataArtisans's blog - they are the company behind Apache Flink.

Acknowledgments

This post is based on materials presented in the "Scalable Data Mining" course by Sebastian Schelter, Christoph Boden and Juan Soto from TU Berlin, and I took this class as a part of the IT4BI curriculum in the winter semester of 2014-2015.

You can read more about classes we took at TU Berlin here, and if you are interested, you can check other posts tagged with IT4BI on this website.

Thanks for reading this post and stay tuned!

2 comments:

  1. Very nice work Alexey! A couple of suggestions from my side:

    I don't much like the idea of implementing your own classifier / scoring algorithm. Last year, I decided to store the output of the training algorithm in a PMML file and reuse the jpmml evaluator (https://github.com/jpmml/jpmml/tree/master/pmml-evaluator) in the classifier. I was just trying to avoid reinventing the wheel. That makes you code a bit more in the training algorithm, but it saves you a lot of time in the scoring part and makes it compatible with other software out there.

    Some thoughts on the scalability of the trainer would have been also nice to see. Because that's the point of using Apache Flink.

  2. Really interesting blog post Alexey. It shows a really good use case of Flink for large scale machine learning and is written very comprehensibly :-)
