eumnq8 发表于 2021-1-13 15:45:00

敏感词过滤优化记录

问题背景:
之前做的聊天项目中涉及敏感词过滤,当时只是采用了简单的字符串匹配,敏感词量还不大,性能还能接受。
最近刚好有时间,于是就研究了trie tree来实现敏感词过滤。

资料
trie tree的原理网上很多。

几种Trie树性能比较 https://www.hankcs.com/nlp/performance-comparison-of-several-trie-tree.html

找到了darts的c/c++版 https://github.com/s-yata/darts-clone

使用
主要涉及几个接口的使用:

build : 传入敏感词数组,构建trie tree

exactMatchSearch : 查找完全匹配的敏感词

commonPrefixSearch : 前缀匹配。找到已字符串开头开始匹配的敏感词,并返回匹配的数组

性能比较
主要是和之前普通的字符串匹配进行性能比较,已10000个敏感词做测试,差距在10000倍以上。理论上性能差距量级 = 敏感词数量。

测试代码
// redistest.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。
//
#include <iostream>
#include <stdint.h>
#include <cassert>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <set>
#include <string>
#include <vector>
#include <algorithm>
#include "windows.h"
#include "darts.h"
using namespace Darts;
void generate_valid_keys(std::size_t num_keys,
    std::set<std::string>* valid_keys) {
    std::vector<char> key;
    while (valid_keys->size() < num_keys) {
      key.resize(1 + (std::rand() % 8));
      for (std::size_t i = 0; i < key.size(); ++i) {
            key = 'A' + (std::rand() % 26);
      }
      valid_keys->insert(std::string(&key, key.size()));
    }
}
void generate_invalid_keys(std::size_t num_keys,
    const std::set<std::string>& valid_keys,
    std::set<std::string>* invalid_keys) {
    std::vector<char> key;
    while (invalid_keys->size() < num_keys) {
      key.resize(1 + (std::rand() % 8));
      for (std::size_t i = 0; i < key.size(); ++i) {
            key = 'A' + (std::rand() % 26);
      }
      std::string generated_key(&key, key.size());
      if (valid_keys.find(generated_key) == valid_keys.end())
            invalid_keys->insert(std::string(&key, key.size()));
    }
}
typedef std::vector<bool> Mask;
typedef std::vector<std::string> FilterWords;
typedef std::vector<Mask> MaskArray;
inline std::string to_lower_str(const std::string& s) {
    std::string s2;
    s2.resize(s.size());
    std::transform(s.begin(), s.end(), s2.begin(), ::towlower);
    return s2;
}
std::set<std::string> g_valid_keys;
Darts::DoubleArray g_dic;
void ReplaceAndReport(std::string sText)
{
    std::string sCopy = to_lower_str(sText);
    int nMaskLength = sCopy.length();
    Mask textMask(nMaskLength);
    FilterWords matchedWords;
    MaskArray wordMaskArray;
    for (const auto& word : g_valid_keys)
    {
      size_t pos = 0;
      if ((pos = sCopy.find(word, 0)) == std::string::npos) continue;
      matchedWords.push_back(word);
      size_t nWordLength = word.length();
      Mask wordMask(nMaskLength);
      do {
            std::fill_n(wordMask.begin() + pos, nWordLength, true);
            std::fill_n(textMask.begin() + pos, nWordLength, true);
            pos += nWordLength;
      } while ((pos = sCopy.find(word, pos)) != std::string::npos);
      wordMaskArray.push_back(std::move(wordMask));
    }
    if (matchedWords.empty()) return;
    int i = 0;
    std::replace_if(sText.begin(), sText.end(), [&i, &textMask](char) { return textMask; }, '*');
}
void ReplaceAndReportV2(std::string sText)
{
    std::string sCopy = to_lower_str(sText);
    static const std::size_t MAX_NUM_RESULTS = 16;
    for (int i = 0; i < sText.size();)
    {
      typename Darts::DoubleArray::result_pair_type results = { 0 };
      std::size_t num_results = g_dic.commonPrefixSearch(&sText, results, MAX_NUM_RESULTS);
      if (num_results > 0)
      {
            int offset = results[(std::min)(num_results, MAX_NUM_RESULTS)-1].length;
            std::fill_n(sText.begin() + i, offset, '*');
            i += offset;
      }
      else
      {
            if (sText < 0)//表示中文
            {
                i += 2;
            }
            else
            {
                i++;
            }
      }
    }
}
int main()
{
    static const std::size_t NUM_VALID_KEYS = 1 << 17;
    static const std::size_t NUM_INVALID_KEYS = 1 << 17;
    generate_valid_keys(NUM_VALID_KEYS, &g_valid_keys);
    g_valid_keys.insert("傻逼BCD");
    g_valid_keys.insert("傻逼A");
    g_valid_keys.insert("傻逼");
    std::vector<const char*> keys(g_valid_keys.size());
    std::vector<std::size_t> lengths(g_valid_keys.size());
    std::vector<typename Darts::DoubleArray::value_type> values(g_valid_keys.size());
    std::size_t key_id = 0;
    for (std::set<std::string>::const_iterator it = g_valid_keys.begin(); it != g_valid_keys.end(); ++it, ++key_id) {
      keys = it->c_str();
    }
   // 文档里面没有说明,敏感词数组需要排序之后才能build,所以需要使用set容器
    g_dic.build(keys.size(), &keys);
    std::string strTest = "你是个傻逼BCD啊你是个玩";
    int testTimes = 1000;
    DWORD a1 = GetTickCount();
    for (int i = 0; i < testTimes; ++i)
    {
      ReplaceAndReport(strTest);
    }
    printf("find count %d \n",GetTickCount() - a1);
    /*a1 = GetTickCount();
    int result;
    for (int i = 0; i < testTimes; ++i)
    {
      g_dic.exactMatchSearch(strTest.c_str(), result);
    }
    printf("exactMatchSearch count %d", GetTickCount() - a1);
    */
    a1 = GetTickCount();
    for (int i = 0; i < testTimes; ++i)
    {
      ReplaceAndReportV2(strTest);
    }
    printf("commonPrefixSearch count %d", GetTickCount() - a1);
}

页: [1]
查看完整版本: 敏感词过滤优化记录