import logging
import math
import os
import pickle

import jieba

# Silence jieba's default DEBUG chatter emitted when its dictionary loads.
jieba.setLogLevel(log_level=logging.INFO)
class BM25Param(object):
    """Container for the precomputed statistics the BM25 scorer needs.

    :param f: per-document term-frequency dicts, one dict per document
    :param df: document frequency of each term across the corpus
    :param idf: inverse document frequency of each term
    :param length: number of documents in the corpus
    :param avg_length: average document length (in tokens)
    :param docs_list: the raw document strings, in corpus order
    :param line_length_list: token count of each document, in corpus order
    :param k1: tunable term-frequency saturation parameter, typically in [1.2, 2.0]
    :param k2: tunable query-term parameter, typically in [1.2, 2.0]
              (note: default is 1.0, outside that usual range)
    :param b: tunable length-normalization parameter
    """

    def __init__(self, f, df, idf, length, avg_length, docs_list,
                 line_length_list, k1=1.5, k2=1.0, b=0.75):
        self.f = f
        self.df = df
        self.idf = idf
        self.length = length
        self.avg_length = avg_length
        self.docs_list = docs_list
        self.line_length_list = line_length_list
        self.k1 = k1
        self.k2 = k2
        self.b = b

    def __str__(self):
        return f"k1:{self.k1}, k2:{self.k2}, b:{self.b}"
class BM25(object):
    """BM25 ranking over a line-per-document corpus tokenized with jieba.

    Statistics are built from a text file (one document per line) and cached
    as a pickle so later runs skip the tokenization pass. When a custom
    `docs` path is given, parameters are always rebuilt from that file.
    """

    _param_pkl = "data/param.pkl"
    _docs_path = "data/data.txt"
    _stop_words_path = "data/stop_words.txt"
    _stop_words = []

    def __init__(self, docs=""):
        """
        :param docs: optional path to a custom corpus file (one document per
                     line); falls back to the bundled corpus when empty.
        """
        self.docs = docs
        self.param: BM25Param = self._load_param()

    def _load_stop_words(self):
        """Read the stop-word file, one word per line.

        :return: list of non-empty stop words
        :raises Exception: if the stop-word file is missing
        """
        if not os.path.exists(self._stop_words_path):
            raise Exception(f"system stop words: {self._stop_words_path} not found")
        stop_words = []
        with open(self._stop_words_path, 'r', encoding='utf8') as reader:
            for line in reader:
                line = line.strip()
                # Skip blank lines so "" never enters the stop-word set.
                if line:
                    stop_words.append(line)
        return stop_words

    def _build_param(self):
        """Tokenize the corpus, compute BM25 statistics, and cache them to disk."""

        def _cal_param(reader_obj):
            f = []                  # per-document term-frequency dicts
            df = {}                 # document frequency per term
            idf = {}                # inverse document frequency per term
            docs_list = []
            line_length_list = []
            words_count = 0
            for line in reader_obj.readlines():
                line = line.strip()
                if not line:
                    continue  # blank lines are not documents
                words = [word for word in jieba.lcut(line)
                         if word and word not in self._stop_words]
                line_length_list.append(len(words))
                docs_list.append(line)
                words_count += len(words)
                tmp_dict = {}
                for word in words:
                    tmp_dict[word] = tmp_dict.get(word, 0) + 1
                f.append(tmp_dict)
                for word in tmp_dict.keys():
                    df[word] = df.get(word, 0) + 1
            # FIX: count only non-empty documents. The previous len(lines)
            # included blank lines, which skewed avg_length and made
            # cal_similarity's range(length) index past the end of
            # docs_list / f / line_length_list (IndexError).
            length = len(docs_list)
            if length == 0:
                raise Exception("corpus is empty: no non-blank lines found")
            for word, num in df.items():
                # Robertson–Spärck Jones idf: log((N - n + 0.5) / (n + 0.5))
                idf[word] = math.log(length - num + 0.5) - math.log(num + 0.5)
            return BM25Param(f, df, idf, length, words_count / length,
                             docs_list, line_length_list)

        if self.docs:
            if not os.path.exists(self.docs):
                raise Exception(f"input docs {self.docs} not found")
            with open(self.docs, 'r', encoding='utf8') as reader:
                param = _cal_param(reader)
        else:
            if not os.path.exists(self._docs_path):
                raise Exception(f"system docs {self._docs_path} not found")
            with open(self._docs_path, 'r', encoding='utf8') as reader:
                param = _cal_param(reader)

        # Ensure the cache directory exists before writing the pickle.
        pkl_dir = os.path.dirname(self._param_pkl)
        if pkl_dir:
            os.makedirs(pkl_dir, exist_ok=True)
        with open(self._param_pkl, 'wb') as writer:
            pickle.dump(param, writer)
        return param

    def _load_param(self):
        """Return BM25 parameters: rebuild for a custom corpus or missing cache,
        otherwise load the pickled cache."""
        self._stop_words = self._load_stop_words()
        if self.docs:
            param = self._build_param()
        elif not os.path.exists(self._param_pkl):
            param = self._build_param()
        else:
            # NOTE(review): pickle.load runs arbitrary code for crafted files;
            # only load caches produced by this program or a trusted source.
            with open(self._param_pkl, 'rb') as reader:
                param = pickle.load(reader)
        return param

    def _cal_similarity(self, words, index):
        """BM25 score of pre-tokenized query `words` against document `index`.

        Implements the document-side BM25 term:
        idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avgdl)).
        The k2 / query-term-frequency component is intentionally omitted.
        """
        score = 0
        for word in words:
            if word not in self.param.f[index]:
                continue
            molecular = self.param.idf[word] * self.param.f[index][word] * (self.param.k1 + 1)
            denominator = self.param.f[index][word] + self.param.k1 * (
                1 - self.param.b +
                self.param.b * self.param.line_length_list[index] / self.param.avg_length)
            score += molecular / denominator
        return score

    def cal_similarity(self, query: str):
        """Score `query` against every document, unsorted.

        :param query: query string
        :return: [(doc, score), ...] in corpus order
        """
        words = [word for word in jieba.lcut(query)
                 if word and word not in self._stop_words]
        score_list = []
        for index in range(self.param.length):
            score = self._cal_similarity(words, index)
            score_list.append((self.param.docs_list[index], score))
        return score_list

    def cal_similarity_rank(self, query: str):
        """Score `query` against every document, best match first.

        :param query: query string
        :return: [(doc, score), ...] sorted by descending score
        """
        result = self.cal_similarity(query)
        result.sort(key=lambda x: -x[1])
        return result