什么是BM25

BM25算法是一种常见用来做相关度打分的公式,思路比较简单,主要就是计算一个query里面所有词和文档的相关度,然后在把分数做累加操作,而每个词的相关度分数主要还是受到tf/idf的影响。

BM25算法是常见的用来计算query和文章相关度的相似度的。其实这个算法的原理很简单,就是将需要计算的query分词成w1,w2,…,wn,然后求出每一个词和文章的相关度,最后将这些相关度进行累加,最终就可以的得到文本相似度计算结果。

Score(Q,d)=inWiR(qi,d)Score(Q,d) = \sum_{i}^{n}W_{i}\cdot R(q_{i},d)

首先Wi表示第i个词的权重,这里一般会使用TF-IDF算法来计算词语的权重这个公式第二项R(qi,d)表示我们查询query中的每一个词和文章d的相关度。一般来说Wi的计算用逆项文本频率IDF的计算公式:

IDF(qi)=logN+0.5n(qi)+0.5IDF(q_i) = \log \frac{N+0.5}{n(q_i)+0.5}

在这个公式中,N表示文档的总数,n(qi)表示包含这个词的文章数,为了避免对数里面分母项等于0,给分子分母同时加上0.5,这个0.5被称作调教系数,所以当n(qi)越小的时候IDF值就越大,表示词的权重就越大。

接着来看公式中的第二项R(qi,d),第二项的计算公式:

R(qi,d)=fi(k1+1)fi+Kqfi(k2+1)qfi+k2R(q_i,d) = \frac{f_i(k_1+1)}{f_i+K} \cdot \frac{qf_i(k_2+1)}{qf_i+k_2}

在这个公式中,一般来说,k1k_1k2k_2和b都是调节因子,k1k_1=1、k2k_2=1、b = 0.75,qfiqf_i表示qiq_i在查询query中出现的频率,fif_i表示qiq_i在文档d中出现的频率,因为在一般的情况下,qiq_i在查询query中只会出现一次,因此把qfiqf_i=1和k2k_2=1代入上述公式中,后面一项就等于1,最终可以得到:

R(qi,d)=fi(k1+1)fi+KR(q_i,d) = \frac{f_i(k_1+1)}{f_i+K}

在这里其实K的值也是一个公式的缩写,把K展开:

K=k1(1b+bdlavg(dl))K = k_1 \cdot (1-b+b \cdot \frac{dl}{avg(dl)})

K的展开式中dl表示文档的长度,avg(dl)表示文档的平均长度,b是前面提到的调节因子,从公式中可以看出在文章长度比平均文章长度固定的情况下,调节因子b越大,文章长度占有的影响权重就越大,反之则越小。在调节因子b固定的时候,当文章的长度比文章的平均长度越大,则K越大,R(qi,d)R(q_i,d)就越小。

代码实现

一个简单的实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import math
import os
import jieba
import pickle
import logging

jieba.setLogLevel(log_level=logging.INFO)


class BM25Param(object):
def __init__(self, f, df, idf, length, avg_length, docs_list, line_length_list,k1=1.5, k2=1.0,b=0.75):
"""

:param f:
:param df:
:param idf:
:param length:
:param avg_length:
:param docs_list:
:param line_length_list:
:param k1: 可调整参数,[1.2, 2.0]
:param k2: 可调整参数,[1.2, 2.0]
:param b:
"""
self.f = f
self.df = df
self.k1 = k1
self.k2 = k2
self.b = b
self.idf = idf
self.length = length
self.avg_length = avg_length
self.docs_list = docs_list
self.line_length_list = line_length_list

def __str__(self):
return f"k1:{self.k1}, k2:{self.k2}, b:{self.b}"


class BM25(object):
_param_pkl = "data/param.pkl"
_docs_path = "data/data.txt"
_stop_words_path = "data/stop_words.txt"
_stop_words = []

def __init__(self, docs=""):
self.docs = docs
self.param: BM25Param = self._load_param()

def _load_stop_words(self):
if not os.path.exists(self._stop_words_path):
raise Exception(f"system stop words: {self._stop_words_path} not found")
stop_words = []
with open(self._stop_words_path, 'r', encoding='utf8') as reader:
for line in reader:
line = line.strip()
stop_words.append(line)
return stop_words

def _build_param(self):

def _cal_param(reader_obj):
f = [] # 列表的每一个元素是一个dict,dict存储着一个文档中每个词的出现次数
df = {} # 存储每个词及出现了该词的文档数量
idf = {} # 存储每个词的idf值
lines = reader_obj.readlines()
length = len(lines)
words_count = 0
docs_list = []
line_length_list =[]
for line in lines:
line = line.strip()
if not line:
continue
words = [word for word in jieba.lcut(line) if word and word not in self._stop_words]
line_length_list.append(len(words))
docs_list.append(line)
words_count += len(words)
tmp_dict = {}
for word in words:
tmp_dict[word] = tmp_dict.get(word, 0) + 1
f.append(tmp_dict)
for word in tmp_dict.keys():
df[word] = df.get(word, 0) + 1
for word, num in df.items():
idf[word] = math.log(length - num + 0.5) - math.log(num + 0.5)
param = BM25Param(f, df, idf, length, words_count / length, docs_list, line_length_list)
return param

# cal
if self.docs:
if not os.path.exists(self.docs):
raise Exception(f"input docs {self.docs} not found")
with open(self.docs, 'r', encoding='utf8') as reader:
param = _cal_param(reader)

else:
if not os.path.exists(self._docs_path):
raise Exception(f"system docs {self._docs_path} not found")
with open(self._docs_path, 'r', encoding='utf8') as reader:
param = _cal_param(reader)

with open(self._param_pkl, 'wb') as writer:
pickle.dump(param, writer)
return param

def _load_param(self):
self._stop_words = self._load_stop_words()
if self.docs:
param = self._build_param()
else:
if not os.path.exists(self._param_pkl):
param = self._build_param()
else:
with open(self._param_pkl, 'rb') as reader:
param = pickle.load(reader)
return param

def _cal_similarity(self, words, index):
score = 0
for word in words:
if word not in self.param.f[index]:
continue
molecular = self.param.idf[word] * self.param.f[index][word] * (self.param.k1 + 1)
denominator = self.param.f[index][word] + self.param.k1 * (1 - self.param.b +
self.param.b * self.param.line_length_list[index] /
self.param.avg_length)
score += molecular / denominator
return score

def cal_similarity(self, query: str):
"""
相似度计算,无排序结果
:param query: 待查询结果
:return: [(doc, score), ..]
"""
words = [word for word in jieba.lcut(query) if word and word not in self._stop_words]
score_list = []
for index in range(self.param.length):
score = self._cal_similarity(words, index)
score_list.append((self.param.docs_list[index], score))
return score_list

def cal_similarity_rank(self, query: str):
"""
相似度计算,排序
:param query: 待查询结果
:return: [(doc, score), ..]
"""
result = self.cal_similarity(query)
result.sort(key=lambda x: -x[1])
return result