# duplicate_checker.py
  1. """
  2. 数学题目查重助手 - 使用向量相似度比对题目是否重复
  3. """
  4. import json
  5. import os
  6. import pickle
  7. import re
  8. import numpy as np
  9. import faiss
  10. import pymysql
  11. import concurrent.futures
  12. from openai import OpenAI
  13. from typing import List, Dict, Tuple, Optional
  14. from config import DB_HOST, DB_PORT, DB_DATABASE, DB_USERNAME, DB_PASSWORD, OPENAI_API_KEY, OPENAI_BASE_URL


class QuestionDuplicateChecker:
    """Question duplicate checker."""

    def __init__(self, index_path="questions_tem.index", metadata_path="questions_tem_metadata.pkl"):
        """
        Initialize the checker.

        Args:
            index_path: path where the FAISS index is persisted
            metadata_path: path where the metadata is persisted
        """
        self.client = OpenAI(
            api_key=OPENAI_API_KEY,
            base_url=OPENAI_BASE_URL
        )
        self.index_path = index_path
        self.metadata_path = metadata_path
        self.index = None
        self.metadata = []  # question IDs and text previews, for display
        # Per-part weights
        self.weights = {
            'stem': 0.8,
            'options': 0.05,
            'answer': 0.05,
            'solution': 0.1
        }
        # Total dimension: four concatenated parts, 3072 * 4 = 12288
        self.dimension = 3072 * 4
        self._load_index()
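
    # Why sqrt(weight)? Each part vector is unit-normalized before being scaled
    # by sqrt(w_i) and concatenated, so the inner product of two concatenated
    # vectors is exactly the weighted sum of per-part cosine similarities:
    #
    #     dot(concat_i sqrt(w_i) e_i, concat_i sqrt(w_i) f_i) = sum_i w_i * cos(e_i, f_i)
    #
    # A minimal standalone check of the identity (illustrative, toy 8-dim vectors):
    #
    #     import numpy as np
    #     w = [0.8, 0.05, 0.05, 0.1]
    #     rng = np.random.default_rng(0)
    #     e = [v / np.linalg.norm(v) for v in rng.standard_normal((4, 8))]
    #     f = [v / np.linalg.norm(v) for v in rng.standard_normal((4, 8))]
    #     lhs = np.concatenate([np.sqrt(wi) * v for wi, v in zip(w, e)]) @ \
    #           np.concatenate([np.sqrt(wi) * v for wi, v in zip(w, f)])
    #     rhs = sum(wi * (ei @ fi) for wi, ei, fi in zip(w, e, f))
    #     assert np.isclose(lhs, rhs)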

    def clean_text(self, text: str) -> str:
        """Clean text: strip image tags and normalize whitespace."""
        if not text:
            return ""
        # Strip <image .../> tags
        text = re.sub(r'<image\s+[^>]*/>', '', text)
        # Normalize whitespace: collapse runs of spaces/newlines into one space
        text = re.sub(r'\s+', ' ', text)
        return text.strip()

    def normalize_options(self, options) -> str:
        """Normalize option content."""
        if not options:
            return ""
        data = None
        if isinstance(options, str):
            # For strings, try to parse as JSON
            text = options.strip()
            if (text.startswith('{') and text.endswith('}')) or (text.startswith('[') and text.endswith(']')):
                try:
                    data = json.loads(text)
                except json.JSONDecodeError:
                    pass
            if data is None:
                return self.clean_text(options)
        else:
            data = options
        if isinstance(data, dict):
            # Sort keys and clean values
            sorted_data = {str(k).strip(): self.clean_text(str(v)) for k, v in sorted(data.items())}
            return self.clean_text(json.dumps(sorted_data, ensure_ascii=False))
        elif isinstance(data, list):
            # Clean list entries
            cleaned_data = [self.clean_text(str(v)) for v in data]
            return self.clean_text(json.dumps(cleaned_data, ensure_ascii=False))
        return self.clean_text(str(data))
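
    # Example (hypothetical input): normalize_options('{"B": " 2 ", "A": "1"}')
    # parses the JSON, sorts the keys, and cleans the values, yielding
    # '{"A": "1", "B": "2"}', so semantically identical option sets compare
    # equal regardless of key order or stray whitespace.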

    def _load_index(self):
        """Load an existing index or initialize a new one."""
        if os.path.exists(self.index_path) and os.path.exists(self.metadata_path):
            try:
                self.index = faiss.read_index(self.index_path)
                # Re-initialize if the stored dimension no longer matches
                # (e.g. after a weighting-scheme upgrade)
                if self.index.d != self.dimension:
                    print(f"⚠️ Index dimension mismatch ({self.index.d} != {self.dimension}); re-initializing...")
                    self._init_new_index()
                else:
                    with open(self.metadata_path, 'rb') as f:
                        self.metadata = pickle.load(f)
                    print(f"✓ Loaded existing index with {len(self.metadata)} questions")
            except Exception as e:
                print(f"⚠️ Failed to load index: {e}; initializing a new one")
                self._init_new_index()
        else:
            self._init_new_index()

    def _init_new_index(self):
        """Initialize a fresh FAISS index."""
        # Inner-product index, paired with the weight-concatenated vectors
        self.index = faiss.IndexFlatIP(self.dimension)
        self.metadata = []
        print("✓ Initialized a new FAISS index")

    def save_index(self):
        """Persist the index and metadata."""
        faiss.write_index(self.index, self.index_path)
        with open(self.metadata_path, 'wb') as f:
            pickle.dump(self.metadata, f)
        print(f"✓ Index and metadata saved to {self.index_path}")

    def get_weighted_embedding(self, question: Dict) -> Optional[np.ndarray]:
        """
        Build the weighted concatenated vector: embed each of the four parts,
        scale each by the square root of its weight, then concatenate.
        """
        parts = {
            'stem': self.clean_text(question.get('stem', '') or ''),
            'options': self.normalize_options(question.get('options', '') or ''),
            'answer': self.clean_text(question.get('answer', '') or ''),
            'solution': self.clean_text(question.get('solution', '') or '')
        }
        embeddings = []
        for key, text in parts.items():
            # Embed this part; "无" ("none") stands in for empty text, which
            # the API rejects
            emb = self.get_embedding(text if text.strip() else "无")
            if emb is None:
                return None
            # Unit-normalize the part vector
            norm = np.linalg.norm(emb)
            if norm > 0:
                emb = emb / norm
            # Scale by sqrt(weight) so dot products become weighted sums
            weight_factor = np.sqrt(self.weights[key])
            embeddings.append(emb * weight_factor)
        # Concatenate into one large vector
        mega_vector = np.concatenate(embeddings).astype('float32')
        return mega_vector

    def get_embedding(self, text: str) -> Optional[np.ndarray]:
        """Embed a single text."""
        try:
            # Clean uniformly before embedding
            cleaned_text = self.clean_text(text)
            if not cleaned_text.strip():
                cleaned_text = "无"
            response = self.client.embeddings.create(
                model="text-embedding-3-large",
                input=cleaned_text
            )
            embedding = np.array(response.data[0].embedding).astype('float32')
            return embedding
        except Exception as e:
            print(f"❌ Failed to get embedding: {e}")
            return None

    def get_embeddings_batch(self, texts: List[str]) -> List[np.ndarray]:
        """Embed a batch of texts in a single API call."""
        try:
            if not texts:
                return []
            # Clean uniformly before embedding
            cleaned_texts = []
            for t in texts:
                c = self.clean_text(t)
                cleaned_texts.append(c if c.strip() else "无")
            response = self.client.embeddings.create(
                model="text-embedding-3-large",
                input=cleaned_texts
            )
            return [np.array(data.embedding).astype('float32') for data in response.data]
        except Exception as e:
            print(f"❌ Batch embedding failed: {e}")
            return []
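
    # Note: this sends all texts in one request. The embeddings endpoint caps
    # the number of inputs per request (2048 items at the time of writing), so
    # callers should keep batches well below that; sync_all_from_db's default
    # batch_size=50 translates to 200 texts per request here.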

    def get_weighted_embeddings_batch(self, questions: List[Dict]) -> List[np.ndarray]:
        """
        Build weighted concatenated vectors for a batch of questions.
        """
        all_texts = []
        for q in questions:
            # Extract the four parts, substituting placeholders for empty values
            stem = self.clean_text(q.get('stem', '') or '')
            opts = self.normalize_options(q.get('options', '') or '')
            ans = self.clean_text(q.get('answer', '') or '')
            sol = self.clean_text(q.get('solution', '') or '')
            # The OpenAI API rejects empty strings, so force the "无" placeholder
            all_texts.append(stem if stem.strip() else "无")
            all_texts.append(opts if opts.strip() else "无")
            all_texts.append(ans if ans.strip() else "无")
            all_texts.append(sol if sol.strip() else "无")
        # One batched request (4 parts per question)
        all_embeddings = self.get_embeddings_batch(all_texts)
        if not all_embeddings:
            return []
        results = []
        for i in range(len(questions)):
            q_embs = all_embeddings[i * 4: (i + 1) * 4]
            if len(q_embs) < 4:
                results.append(None)
                continue
            weighted_parts = []
            keys = ['stem', 'options', 'answer', 'solution']
            for j, emb in enumerate(q_embs):
                # Unit-normalize
                norm = np.linalg.norm(emb)
                if norm > 0:
                    emb = emb / norm
                # Apply sqrt(weight)
                weight_factor = np.sqrt(self.weights[keys[j]])
                weighted_parts.append(emb * weight_factor)
            mega_vector = np.concatenate(weighted_parts).astype('float32')
            results.append(mega_vector)
        return results
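
    # Layout note: all_texts is flattened as [q0.stem, q0.options, q0.answer,
    # q0.solution, q1.stem, ...], and the embeddings API returns results in
    # input order, so question i's four part vectors are recovered with the
    # slice all_embeddings[i*4:(i+1)*4].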

    def combine_content(self, q: Dict) -> str:
        """Combine question content for preview display."""
        stem = q.get('stem', '') or ''
        return self.clean_text(stem)  # previews use the cleaned stem

    def fetch_question_from_db(self, question_id: int) -> Optional[Dict]:
        """Fetch a question record from the database."""
        conn = None
        try:
            conn = pymysql.connect(
                host=DB_HOST,
                port=DB_PORT,
                user=DB_USERNAME,
                password=DB_PASSWORD,
                database=DB_DATABASE,
                charset='utf8mb4',
                cursorclass=pymysql.cursors.DictCursor
            )
            with conn.cursor() as cursor:
                sql = "SELECT id, stem, options, answer, solution, is_repeat FROM questions_tem WHERE id = %s"
                cursor.execute(sql, (question_id,))
                result = cursor.fetchone()
                return result
        except Exception as e:
            print(f"❌ Database query failed: {e}")
            return None
        finally:
            if conn:
                conn.close()

    def confirm_repeat(self, question_id: int, is_repeat: int) -> bool:
        """
        Apply a human reviewer's verdict to the database and the vector index.

        is_repeat: 0 means no similar question exists, 1 means it is a duplicate
        """
        # 1. Update the database flag
        self.update_is_repeat(question_id, is_repeat)
        # 2. If marked as non-duplicate, add the question to the vector index
        if is_repeat == 0:
            question = self.fetch_question_from_db(question_id)
            if question:
                mega_vec = self.get_weighted_embedding(question)
                if mega_vec is not None:
                    # Reshape to the (1, d) form FAISS expects
                    mega_vec = mega_vec.reshape(1, -1)
                    self.add_to_index(question_id, mega_vec, self.clean_text(question.get('stem', ''))[:100])
                    return True
        return True
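
    # Typical review flow (sketch): check_duplicate(qid) flags a candidate, a
    # human reviews the pair, then confirm_repeat(qid, 1) records the duplicate
    # or confirm_repeat(qid, 0) whitelists the question and adds it to the index.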

    def update_is_repeat(self, question_id: int, is_repeat: int):
        """Update the is_repeat column in the database."""
        conn = None
        try:
            conn = pymysql.connect(
                host=DB_HOST,
                port=DB_PORT,
                user=DB_USERNAME,
                password=DB_PASSWORD,
                database=DB_DATABASE,
                charset='utf8mb4'
            )
            with conn.cursor() as cursor:
                sql = "UPDATE questions_tem SET is_repeat = %s WHERE id = %s"
                cursor.execute(sql, (is_repeat, question_id))
            conn.commit()
            print(f"✓ Question ID {question_id}: is_repeat updated to {is_repeat}")
        except Exception as e:
            print(f"❌ Failed to update is_repeat: {e}")
        finally:
            if conn:
                conn.close()

    def call_openai_for_duplicate_check(self, q1: Dict, q2: Dict) -> bool:
        """
        Ask GPT-4o for a deep comparison of two questions to decide whether
        they are duplicates.
        """
        try:
            prompt = f"""You are an expert at detecting duplicate math questions. Compare the two questions below and decide whether they are "duplicates".
Definition of "duplicate": the scenario, tested knowledge points, core logic, and numeric values are identical, or differ only in trivial wording that does not change the substance of the question.

Question 1:
Stem: {q1.get('stem', '')}
Options: {q1.get('options', '')}
Answer: {q1.get('answer', '')}
Solution: {q1.get('solution', '')}

Question 2:
Stem: {q2.get('stem', '')}
Options: {q2.get('options', '')}
Answer: {q2.get('answer', '')}
Solution: {q2.get('solution', '')}

Analyze the similarity of the two questions and give a final verdict.
The verdict must be output as JSON with the following fields:
- is_duplicate: boolean, true if the questions are duplicates, false otherwise
- reason: a short justification

Output only the JSON, with no extra commentary."""
            response = self.client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": "You are a professional math-question duplicate-checking assistant."},
                    {"role": "user", "content": prompt}
                ],
                response_format={"type": "json_object"}
            )
            result = json.loads(response.choices[0].message.content)
            is_dup = result.get('is_duplicate', False)
            reason = result.get('reason', '')
            print(f"🤖 GPT-4o verdict: {'duplicate' if is_dup else 'not duplicate'} | reason: {reason}")
            return is_dup
        except Exception as e:
            print(f"❌ GPT-4o fallback check failed: {e}")
            return False
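
    # Illustrative model output for the prompt above:
    #     {"is_duplicate": true, "reason": "Same numbers and logic; only the wording differs."}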

    def check_duplicate(self, question_id: int, threshold: float = 0.85) -> Dict:
        """
        Main duplicate-check logic (read-only mode: updates neither the
        database nor the vector index).
        """
        # 1. Fetch the question
        question = self.fetch_question_from_db(question_id)
        if not question:
            return {"status": "error", "message": f"No question found with ID {question_id}"}
        # 2. Build the weighted concatenated vector
        mega_vector = self.get_weighted_embedding(question)
        if mega_vector is None:
            return {"status": "error", "message": "Could not generate the question vector"}
        # 3. Search the index
        if self.index.ntotal == 0:
            return {
                "status": "success",
                "message": "Vector index is empty",
                "is_duplicate": False,
                "top_similar": []
            }
        # Retrieve the top matches; k > 1 so a self-match can be skipped
        k = min(3, self.index.ntotal)
        query_vec = mega_vector.reshape(1, -1)
        similarities, indices = self.index.search(query_vec, k)
        best_similar = None
        is_duplicate = False
        gpt_checked = False
        for i in range(k):
            idx = indices[0][i]
            score = float(similarities[0][i])
            if idx == -1:
                continue
            similar_info = self.metadata[idx]
            similar_id = similar_info['id']
            if similar_id == question_id:
                continue  # skip the question itself
            # FAISS returns hits sorted by score, so the first non-self hit is
            # the best match; decide on it and stop, so weaker matches cannot
            # overwrite the verdict.
            best_similar = {
                "id": similar_id,
                "similarity": round(score, 4),
                "similar_point": similar_info.get('text_preview', '')
            }
            if score >= threshold:
                is_duplicate = True
            elif score > 0.5:
                # Score between 0.5 and the threshold: fall back to GPT-4o
                similar_question = self.fetch_question_from_db(similar_id)
                if similar_question:
                    is_duplicate = self.call_openai_for_duplicate_check(question, similar_question)
                    gpt_checked = True
            break
        if is_duplicate:
            return {
                "status": "warning",
                "message": "This question resembles an existing one; please review manually"
                           + (" (confirmed by GPT-4o fallback)" if gpt_checked else ""),
                "is_duplicate": True,
                "gpt_checked": gpt_checked,
                "top_similar": [best_similar]
            }
        else:
            return {
                "status": "success",
                "message": "Question passed the duplicate check",
                "is_duplicate": False,
                "top_similar": [best_similar] if best_similar else []
            }
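
    # Usage sketch (field values illustrative):
    #     result = checker.check_duplicate(123, threshold=0.85)
    #     # -> {"status": "warning", "is_duplicate": True, "gpt_checked": False,
    #     #     "top_similar": [{"id": 45, "similarity": 0.9132,
    #     #                      "similar_point": "A train leaves station A..."}], ...}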

    def check_duplicate_by_content(self, question_data: Dict, threshold: float = 0.85) -> Dict:
        """
        Duplicate check against raw text content (pre-check mode, before the
        question is inserted into the database).
        """
        # 1. Build the weighted concatenated vector
        mega_vector = self.get_weighted_embedding(question_data)
        if mega_vector is None:
            return {"status": "error", "message": "Could not generate the question vector"}
        # 2. Search
        if self.index.ntotal == 0:
            return {"status": "success", "is_duplicate": False, "top_similar": []}
        k = 1  # pre-check only needs the single closest match
        query_vec = mega_vector.reshape(1, -1)
        similarities, indices = self.index.search(query_vec, k)
        if indices[0][0] != -1:
            score = float(similarities[0][0])
            similar_info = self.metadata[indices[0][0]]
            similar_id = similar_info['id']
            best_similar = {
                "id": similar_id,
                "similarity": round(score, 4),
                "similar_point": similar_info.get('text_preview', '')
            }
            is_duplicate = False
            gpt_checked = False
            if score >= threshold:
                is_duplicate = True
            elif score > 0.5:
                # Score between 0.5 and the threshold: fall back to GPT-4o
                similar_question = self.fetch_question_from_db(similar_id)
                if similar_question:
                    is_duplicate = self.call_openai_for_duplicate_check(question_data, similar_question)
                    gpt_checked = True
            if is_duplicate:
                return {
                    "status": "warning",
                    "is_duplicate": True,
                    "gpt_checked": gpt_checked,
                    "top_similar": [best_similar]
                }
            return {"status": "success", "is_duplicate": False, "top_similar": [best_similar]}
        return {"status": "success", "is_duplicate": False, "top_similar": []}

    def add_to_index(self, question_id: int, vector: np.ndarray, text: str):
        """Add a question to the index."""
        if any(m['id'] == question_id for m in self.metadata):
            return  # already indexed
        self.index.add(vector)
        self.metadata.append({
            'id': question_id,
            'text_preview': self.clean_text(text)[:100]
        })
        self.save_index()
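
    # Invariant: self.metadata[i] describes the vector at FAISS position i, so
    # index.add(...) and metadata.append(...) must always occur as a pair, in order.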

    def sync_all_from_db(self, batch_size=50, max_workers=5):
        """Sync every question in the database into the index
        (weighted mode, batched requests, thread pool)."""
        print("Running full sync (optimized, weighted mode)...")
        existing_ids = {m['id'] for m in self.metadata}
        conn = None
        try:
            conn = pymysql.connect(
                host=DB_HOST,
                port=DB_PORT,
                user=DB_USERNAME,
                password=DB_PASSWORD,
                database=DB_DATABASE,
                charset='utf8mb4',
                cursorclass=pymysql.cursors.DictCursor
            )
            with conn.cursor() as cursor:
                sql = "SELECT id, stem, options, answer, solution FROM questions_tem"
                cursor.execute(sql)
                all_questions = cursor.fetchall()
            new_questions = [q for q in all_questions if q['id'] not in existing_ids]
            total_new = len(new_questions)
            if total_new == 0:
                print("✅ Already up to date; nothing to sync.")
                return
            print(f"📊 Total in database: {len(all_questions)}, new to sync: {total_new}")
            # Split into chunks
            chunks = [new_questions[i:i + batch_size] for i in range(0, total_new, batch_size)]

            def process_chunk(chunk):
                mega_vectors = self.get_weighted_embeddings_batch(chunk)
                return chunk, mega_vectors

            # Embed chunks concurrently; index writes stay on the main thread
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
                count = 0
                for future in concurrent.futures.as_completed(future_to_chunk):
                    chunk, mega_vectors = future.result()
                    if mega_vectors:
                        for i, mega_vec in enumerate(mega_vectors):
                            if mega_vec is not None:
                                self.index.add(mega_vec.reshape(1, -1))
                                self.metadata.append({
                                    'id': chunk[i]['id'],
                                    'text_preview': self.clean_text(chunk[i].get('stem', ''))[:100]
                                })
                    count += len(chunk)
                    print(f"✅ Sync progress: {count}/{total_new}")
                    # Checkpoint every 500 questions
                    if count % 500 == 0:
                        self.save_index()
            self.save_index()
            print(f"🎉 Sync complete! Index now holds {len(self.metadata)} entries")
        except Exception as e:
            print(f"❌ Sync failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            if conn:
                conn.close()


if __name__ == "__main__":
    # Quick manual test
    checker = QuestionDuplicateChecker()
    # checker.sync_all_from_db()  # run once on first use to backfill the index
    # result = checker.check_duplicate(10)
    # print(json.dumps(result, ensure_ascii=False, indent=2))
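    # After a human reviews a flagged question (sketch):
    # checker.confirm_repeat(10, 0)  # verdict: not a duplicate, so it also gets indexed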