基于Redis实现一个朴素贝叶斯文本分类器

基于朴素贝叶斯进行文本分类因为实现简单被广泛应用,很多开源的机器学习框架都提供了相应的实现。使用场景如新闻类内容的分类,垃圾信息识别等等。

原理

贝叶斯的定义就不具体说了,概率统计都有讲,如果没学过也没关系,去wikipedia上看看。非常朴素,值得一看。

分类器最重要的是先要有“分类”,而且分类之间应该是相对来说重合度比较低的,或者说是正交的。然后对每个分类进行词的统计,然后就会生成一个关键词序列,或者叫模型。有点像这个分类的DNA序列一样,即可以理解为这个分类在用词维度上的特征(特征向量)。然后就是将目标内容基于这一系列的“DNA”序列进行计算likelihood,值最大的那个就是了。

实践

具体的实现我就不上代码了。主要几步:

0. 抽样

从对应的分类中获取足够多的样本,并且要避免有脏样本,否则生成的分类器会不准确。

1. 分词

这个很容易理解,因为是要用到词,中文就一定会用到分词工具。Java里有paoding, PHP里有Jieba,甚至新浪SAE的分词服务…这方面有很多开源的解决方案。

2. 统计、生成DNA(模型)

针对每个分类的样本集合,统计每个关键词的出现频率。最终得到一个“词-次数”的Map。这里就用到Redis了,其实不一定要用Redis,用Redis主要是考虑到后续的模型自动反馈改进。如果只是不怕麻烦而手动生成的话,这里保存模型到哪里其实无所谓。

3. 权重设置

针对每个词在所在分类里出现的频率来设置一个后续计算的权重。最简单是:这个词在所有词中的出现概率(该词的出现次数 / 所在分类中样本的总词数)。用Redis的SortedSet来保存。

4. 使用

对于一个目标样本,分词,然后针对每个分类计算相应的权重值和。

5. 模型自动化反馈

这个还是比较重要的思路,因为模型也是会变动的,你的分类可能也在不断进化,比如法制新闻的内容可能随着法制热点或者法治进程的变化,产生不同的特征(这个例子有点YY了)。所以模型的产生应该随着新的样本不断地变化,所以生成模型的过程应该是保持一定的频率进行,同时,每次的使用模型进行分类的目标样本如果区分度较高,也可以直接加入到模型里进行计算。

 

优化

避免overfit – 模型并不容易达到足够高的准确率,准确率过高的原因可能是因为分类区分度很高,也许根本就不需要用分类器类做这件事,而准确率过低有可能是你的抽样有问题,或者你的分类有很大的重叠,比如生活和社会这种。所以如果是目标样本非常fit某个分类,建议就不要放到新的抽样样本里,因为不会产生新的“基因”了,如果这种样本过多,最终就容易overfit。

数量取胜 – 也是很朴素的想法,一个分类器不行,那么十个、二十个呢?如果能够把整个抽样提取、模型生成、模型反馈机制都自动化了,那么用多个分类器进行交叉验证,或者是来个投票机制。准确率应该也可以有很好的提升。

通过使用mahout的核心api来实现中文文本的分类

mahout版本:0.7 (mahout的安装请参考:https://cwiki.apache.org/MAHOUT/buildingmahout.html)

hadoop版本:1.0.3

lucene版本:3.6.0(+paoding1.0无法通过maven直接导入依赖,需要单独加到classpath)

 

mahout标准的分类demo展示了如何对英文文本内容分类,并且使用的是mahout的命令行脚本。下面主要是介绍使用mahout的api来完成同样的事,并且是中文。

 

标准的mahout文本分类分为以下几步:

  1. sequencing:将训练样本从简单的文本转换为hadoop标准的sequence格式
  2. vectorize: 向量化,将sequence格式的样本转换为向量形式。如“钓鱼岛是中国的”,先转换为分词后的序列“钓鱼,钓鱼岛,岛,中国”,然后给每个词一个index,最终转换为”1212:1,232:1, 16:1,789:1″这样的向量
  3. split: 样本分为训练集和测试集
  4. train: 根据训练集训练出模型
  5. test: 使用测试集测试模型并获得模型的准确率和confusionmatrix
  6. classify: 使用模型对新样本进行分类

下面分别介绍每个步骤:

 

sequencing

 

java代码:

args = new String[] { “-i”, sampleDir, “-o”, sequenceDir };

SequenceFilesFromDirectory sequenceJob = new SequenceFilesFromDirectory();

sequenceJob.setConf(getConf());

sequenceJob.run(args);

 

sampleDir是样本文件,组织结构是每个类一个子文件夹,文件夹内是对应的一个个样本文件。

执行后会在sequenceDir目录下看到一个chunk-0的文件,可以通过mahout的SequenceFileDumper转换为文本,转换后可以看到内容是这样:

 

Input Path: file:/Users/derekzhangv/Develop/temptest/testTopic/testTopic-seq/chunk-0

Key class: class org.apache.hadoop.io.Text Value Class: class org.apache.hadoop.io.Text

 

Key: /good/51057: Value: 浪潮之巅

 

Key: /good/55107: Value: 电动车

 

Key: /bad/85364: Value: 婴儿用品

 

Count: 3

 

(技巧:如果需要在命令行里方便地查看可以设置一个alias:“alias seqdump=’/Users/derekzhangv/Develop/mahout-0.7/bin/mahout seqdumper -i `pwd`/$1 | more’ ”)

 

vectorize

 

java代码:

args = new String[] { “-i”, sequenceDir, “-o”, vectorDir, “-lnorm”,

“-nv”, “-wt”, “tfidf”, “-s”, “2”// minSuport, default 2 , “-a”, “net.paoding.analysis.analyzer.PaodingAnalyzer” };

SparseVectorsFromSequenceFiles vectorizeJob = new SparseVectorsFromSequenceFiles();

vectorizeJob.setConf(getConf());

vectorizeJob.run(args);

 

参数:-i/-o是指定输入/输出的目录,-lnorm指定使用对数来normalize向量,-nv指定向量需要命名,-wt指定权重计算方法tf或者tfidf,-s指定最小支持数量(即只考虑出现指定次数或者以上的词),-a指定analyzer分词方法,这里只需要指定一个中文分词analyzer既可,这里是PaodingAnalyzer。

 

执行后会在[vectorDir]下产生如下文件:

df-count

dictionary.file-0

frequency.file-0

tf-vectors

tfidf-vectors

tokeninzed-documents

wordcount

可以用vectorDumper看看里面的内容。这里就不赘述了。

 

split

 

java代码:

args = new String[] { “-i”, tfidfVectorDir, “–trainingOutput”,trainVectorDir, “–testOutput”, testVectorDir,”–randomSelectionPct”, “40”, “–overwrite”, “–sequenceFiles”,

     “-xm”, “sequential” };

SplitInput splitJob = new SplitInput();

splitJob.setConf(getConf());

splitJob.run(args);

 

参数:–trainingOutput和–testOutput分别指定训练和测试样本的输出路径,–randomSelectionPct指定比例,–sequenceFiles说明输入文件格式为sequence,-xm执行方式,默认是map reduce,这里指定sequential

执行后会产生对应的两个文件夹,里面是part-r-00000文件,可以通过sequencedumper查看。

 

train

 

java代码:

args = new String[] { “-i”, trainVectorDir, “-o”, modelDir, “-li”,labelIndexDir, “-el”, “-ow” };

TrainNaiveBayesJob trainJob = new TrainNaiveBayesJob();

trainJob.setConf(getConf());

trainJob.run(args);

 

参数:-o指定模型输出路径,-li指定labelindex文件的路径,这个文件是在vectorize时候产生的。-el指明label是从样本中提取,-ow指明可以覆盖文件。

执行后会在[modelDir]下看到naiveBayesModel.bin文件,就是训练出来的结果了。

 

test

 

java代码:

args = new String[] { “-i”, testVectorDir, “-m”, modelDir, “-l”,

labelIndexDir, “-ow”, “-o”, testingDir };

TestNaiveBayesDriver testJob = new TestNaiveBayesDriver();

testJob.setConf(getConf());

testJob.run(args);

 

参数:-i指定测试样本vector的路径,-m模型所在路径,-l是label索引路径,-ow指明覆盖原来的输出,-o输出路径

输出的文件可以通过sequencedumper查看。

 

classify

 

java代码:(有点dirty)

Configuration conf = new Configuration();

AbstractNaiveBayesClassifier classifier = loadClassifier(topicId, conf);

Vector instance = buildInstance(topicId,text);

Vector r = classifier.classifyFull(instance);

Path labelIndexPath = new Path(this.tempDir + “/” + topicId +“/” + topicId

+ “-labelIndex”);

Map<Integer, String> labelMap = BayesUtils.readLabelIndex(conf,

labelIndexPath);

 

int bestIdx = Integer.MIN_VALUE;

double bestScore = Long.MIN_VALUE;

HashMap<String, Double> resultMap = new HashMap<String, Double>();

for (int i = 0; i < labelMap.size(); i++) {

Vector.Element element = r.getElement(i);

resultMap.put(labelMap.get(element.index()), element.get());

if (element.get() > bestScore) {

bestScore = element.get();

bestIdx = element.index();

}

}

ClassifyResult result = new ClassifyResult();

if (bestIdx != Integer.MIN_VALUE) {

String label = labelMap.get(bestIdx);

double score = bestScore;

result.setLabel(label);

result.setScore(score);

}

return result;

 

–附上buildInstance和loadClassifier的代码:

private AbstractNaiveBayesClassifier loadClassifier(String topicId,Configuration conf) throws IOException {

Path modelPath = new Path(this.modelDir + “/” + topicId + “-model”);

NaiveBayesModel model = NaiveBayesModel.materialize(modelPath, conf);

AbstractNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);

  return classifier;

}

 

private Vector buildInstance(String topicId,String text){

try {

reBuildDictionary(topicId);

  }  catch (IOException e) {

e.printStackTrace();

}

Vector vector = new RandomAccessSparseVector(FEATURES);

FeatureExtractor fe = new FeatureExtractor();

HashSet<String> fs = fe.extract(text);

for (String s : fs) {

int index = dictionary.get(s);

vector.setQuick(index, frequency.get(index));

}

return vector;

}

(这里就是比较dirty的地方,需要根据dictionary.file-0和frequency.file-0来构造索引,然后用来vectorize目标样本(要分类的样本)

private boolean dictRebuilt = false;

privatevoid reBuildDictionary(String topicId) throws IOException{

if(dictRebuilt) return;

Configuration conf = getConf();

 

    Path dictionaryFile = new Path(tempDir+“/”+topicId+“/”+topicId+“-vectors/dictionary.file-0”);

    // key is feature, value is the document frequency

    for (Pair<Text,IntWritable> record 

        : new SequenceFileIterable<Text,IntWritable>(dictionaryFile, true, conf)) {

      dictionary.put(record.getFirst().toString(), record.getSecond().get());

    }

    Path freqFile = new Path(tempDir+“/”+topicId+“/”+topicId+“-vectors/frequency.file-0”);

    // key is feature, value is the document frequency

    for (Pair<IntWritable,LongWritable> record 

        : new SequenceFileIterable<IntWritable,LongWritable>(freqFile, true, conf)) {

    frequency.put(record.getFirst().get(), record.getSecond().get());

    }

    dictRebuilt = true;

}

 

结束。