  1. """
  2. 数学题目查重助手 - 使用向量相似度比对题目是否重复
  3. """
  4. import json
  5. import os
  6. import pickle
  7. import re
  8. import threading
  9. import time
  10. import tempfile
  11. import numpy as np
  12. import faiss
  13. import pymysql
  14. import concurrent.futures
  15. from openai import OpenAI
  16. from typing import List, Dict, Tuple, Optional
  17. from config import DB_HOST, DB_PORT, DB_DATABASE, DB_USERNAME, DB_PASSWORD, OPENAI_API_KEY, OPENAI_BASE_URL
  18. class QuestionDuplicateChecker:
  19. """题目查重器"""
  20. def __init__(self, index_path="questions_tem.index", metadata_path="questions_tem_metadata.pkl"):
  21. """
  22. 初始化查重器
  23. Args:
  24. index_path: FAISS索引保存路径
  25. metadata_path: 元数据保存路径
  26. """
  27. self.client = OpenAI(
  28. api_key=OPENAI_API_KEY,
  29. base_url=OPENAI_BASE_URL
  30. )
  31. base_dir = os.path.dirname(os.path.abspath(__file__))
  32. # 统一使用绝对路径,避免不同工作目录导致索引文件混乱
  33. self.index_path = index_path if os.path.isabs(index_path) else os.path.join(base_dir, index_path)
  34. self.metadata_path = metadata_path if os.path.isabs(metadata_path) else os.path.join(base_dir, metadata_path)
  35. self.index = None
  36. self.metadata = [] # 存储题目ID和文本,以便回显
  37. self.sync_lock = threading.Lock()
  38. self.reload_lock = threading.Lock()
  39. self.index_mtime = None
  40. self.metadata_mtime = None
  41. self.sync_in_progress = False
  42. self.reload_failures = 0
  43. self.next_reload_time = 0.0
  44. # 权重设置
  45. self.weights = {
  46. 'stem': 0.8,
  47. 'options': 0.05,
  48. 'answer': 0.05,
  49. 'solution': 0.1
  50. }
  51. # 维度调整:4部分拼接后的总维度为 3072 * 4 = 12288
  52. self.dimension = 3072 * 4
  53. # 限制 FAISS/OpenMP 线程,避免小机器 CPU 过载
  54. faiss_threads = max(1, min(2, os.cpu_count() or 1))
  55. try:
  56. faiss.omp_set_num_threads(faiss_threads)
  57. except Exception:
  58. pass
  59. self._load_index()
  60. def clean_text(self, text: str) -> str:
  61. """清理文本:移除图片标签、归一化空白字符"""
  62. if not text:
  63. return ""
  64. # 移除 <image .../> 标签
  65. text = re.sub(r'<image\s+[^>]*/>', '', text)
  66. # 归一化空白:将多个空格、换行替换为单个空格
  67. text = re.sub(r'\s+', ' ', text)
  68. return text.strip()
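    # A quick illustration of clean_text (hypothetical input, not from the
    # project's data): image tags are dropped and whitespace is collapsed, so
    #   clean_text('<image src="fig.png"/>  Solve for  x\n')  ->  'Solve for x'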
    def normalize_options(self, options) -> str:
        """Normalize option content to a canonical string."""
        if not options:
            return ""
        data = None
        if isinstance(options, str):
            # If it is a string, try to parse it as JSON.
            text = options.strip()
            if (text.startswith('{') and text.endswith('}')) or (text.startswith('[') and text.endswith(']')):
                try:
                    data = json.loads(text)
                except json.JSONDecodeError:
                    pass
            if data is None:
                return self.clean_text(options)
        else:
            data = options
        if isinstance(data, dict):
            # Sort keys and clean the values.
            sorted_data = {str(k).strip(): self.clean_text(str(v)) for k, v in sorted(data.items())}
            return self.clean_text(json.dumps(sorted_data, ensure_ascii=False))
        elif isinstance(data, list):
            # Clean list items.
            cleaned_data = [self.clean_text(str(v)) for v in data]
            return self.clean_text(json.dumps(cleaned_data, ensure_ascii=False))
        return self.clean_text(str(data))
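    # A sketch of what normalize_options produces (hypothetical input): dict
    # keys are sorted so equivalent option sets compare equal regardless of
    # their original order, e.g.
    #   normalize_options('{"B": "4", "A": "2"}')  ->  '{"A": "2", "B": "4"}'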
    def _load_index(self):
        """Load an existing index or initialize a new one."""
        if os.path.exists(self.index_path) and os.path.exists(self.metadata_path):
            try:
                self.index = faiss.read_index(self.index_path)
                # Re-initialize if the on-disk dimension no longer matches
                # (e.g. after a weighting upgrade).
                if self.index.d != self.dimension:
                    print(f"⚠️ Index dimension mismatch ({self.index.d} != {self.dimension}), re-initializing...")
                    self._init_new_index()
                else:
                    with open(self.metadata_path, 'rb') as f:
                        self.metadata = pickle.load(f)
                    print(f"✓ Loaded existing index with {len(self.metadata)} questions")
                    self._update_index_mtime()
            except Exception as e:
                print(f"⚠️ Failed to load index: {e}; initializing a new one")
                self._init_new_index()
        else:
            self._init_new_index()

    def _init_new_index(self):
        """Initialize a fresh FAISS index."""
        # Inner-product index, paired with the weight-concatenated vectors.
        self.index = faiss.IndexFlatIP(self.dimension)
        self.metadata = []
        self.index_mtime = None
        self.metadata_mtime = None
        self.reload_failures = 0
        self.next_reload_time = 0.0
        print("✓ Initialized a new FAISS index")

    def _update_index_mtime(self):
        """Record the last-modified times of the index files."""
        self.index_mtime = os.path.getmtime(self.index_path) if os.path.exists(self.index_path) else None
        self.metadata_mtime = os.path.getmtime(self.metadata_path) if os.path.exists(self.metadata_path) else None

    def ensure_index_loaded(self, force: bool = False) -> bool:
        """Make sure the index is loaded from disk (avoids a stale empty index under multi-process serving)."""
        # Back off after read failures so we do not hammer the disk with reloads.
        now = time.time()
        if not force and now < self.next_reload_time:
            return False
        index_exists = os.path.exists(self.index_path) and os.path.exists(self.metadata_path)
        if not index_exists:
            return False
        need_reload = force
        if not need_reload:
            # Prefer loading from disk when the in-memory index or metadata is empty.
            if self.index is None or self.index.ntotal == 0 or not self.metadata:
                need_reload = True
            else:
                # Reload if the files have been updated.
                current_index_mtime = os.path.getmtime(self.index_path)
                current_metadata_mtime = os.path.getmtime(self.metadata_path)
                if (self.index_mtime is None or self.metadata_mtime is None or
                        current_index_mtime > self.index_mtime or current_metadata_mtime > self.metadata_mtime):
                    need_reload = True
        if not need_reload:
            return False
        with self.reload_lock:
            # While a sync is running, avoid reloading a half-written index.
            if self.sync_in_progress and not force:
                return False
            # Double-check to avoid redundant loads.
            if not force and self.index is not None and self.index.ntotal > 0 and self.metadata:
                current_index_mtime = os.path.getmtime(self.index_path)
                current_metadata_mtime = os.path.getmtime(self.metadata_path)
                if (self.index_mtime is not None and self.metadata_mtime is not None and
                        current_index_mtime <= self.index_mtime and current_metadata_mtime <= self.metadata_mtime):
                    return False
            try:
                self.index = faiss.read_index(self.index_path)
                with open(self.metadata_path, 'rb') as f:
                    self.metadata = pickle.load(f)
                self._update_index_mtime()
                self.reload_failures = 0
                self.next_reload_time = 0.0
                print(f"↻ Reloaded index with {len(self.metadata)} questions")
                return True
            except Exception as e:
                print(f"⚠️ Failed to reload index: {e}")
                # Exponential backoff, capped at 60 seconds.
                self.reload_failures += 1
                backoff = min(60.0, 2 ** self.reload_failures)
                self.next_reload_time = time.time() + backoff
                return False
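    # The backoff schedule above, for reference (reload_failures -> delay in
    # seconds): 1 -> 2, 2 -> 4, 3 -> 8, 4 -> 16, 5 -> 32, 6 and up -> 60 (cap).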
    def save_index(self):
        """Persist the index and metadata."""
        # Write to temp files first, then atomically replace, so a concurrent
        # reader never sees a half-written (corrupted) index.
        index_dir = os.path.dirname(self.index_path)
        meta_dir = os.path.dirname(self.metadata_path)
        os.makedirs(index_dir, exist_ok=True)
        os.makedirs(meta_dir, exist_ok=True)
        tmp_index_fd, tmp_index_path = tempfile.mkstemp(prefix=".questions_tem.index.", dir=index_dir)
        tmp_meta_fd, tmp_meta_path = tempfile.mkstemp(prefix=".questions_tem_metadata.pkl.", dir=meta_dir)
        try:
            os.close(tmp_index_fd)
            os.close(tmp_meta_fd)
            faiss.write_index(self.index, tmp_index_path)
            with open(tmp_meta_path, 'wb') as f:
                pickle.dump(self.metadata, f)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_index_path, self.index_path)
            os.replace(tmp_meta_path, self.metadata_path)
            self._update_index_mtime()
            print(f"✓ Index and metadata saved to {self.index_path}")
        finally:
            # Clean up temp files (already gone if the replace succeeded).
            if os.path.exists(tmp_index_path):
                os.remove(tmp_index_path)
            if os.path.exists(tmp_meta_path):
                os.remove(tmp_meta_path)

    def get_weighted_embedding(self, question: Dict) -> Optional[np.ndarray]:
        """
        Build the weighted concatenated vector: embed the four parts
        separately, multiply each by the square root of its weight, and concatenate.
        """
        parts = {
            'stem': self.clean_text(question.get('stem', '') or ''),
            'options': self.normalize_options(question.get('options', '') or ''),
            'answer': self.clean_text(question.get('answer', '') or ''),
            'solution': self.clean_text(question.get('solution', '') or '')
        }
        embeddings = []
        for key, text in parts.items():
            # Embed one part; guard against empty text, which the API rejects.
            emb = self.get_embedding(text if text.strip() else "none")
            if emb is None:
                return None
            # L2-normalize the part vector.
            norm = np.linalg.norm(emb)
            if norm > 0:
                emb = emb / norm
            # Scale by sqrt(weight) so that the dot product of two concatenated
            # vectors equals the weighted sum of per-part similarities.
            weight_factor = np.sqrt(self.weights[key])
            embeddings.append(emb * weight_factor)
        # Concatenate into one large vector.
        mega_vector = np.concatenate(embeddings).astype('float32')
        return mega_vector
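    # Why sqrt(weight)? For questions Q and D with L2-normalized part vectors
    # q_i, d_i and weights w_i, the concatenated vectors satisfy
    #   <Q, D> = sum_i (sqrt(w_i) * q_i) . (sqrt(w_i) * d_i)
    #          = sum_i w_i * cos(q_i, d_i),
    # so the IndexFlatIP search ranks directly by the weighted average of
    # per-part cosine similarities. Since the weights sum to 1, scores stay in
    # [-1, 1] and are comparable against the 0.85 threshold used below.
    # (Illustrative numbers only: part similarities 0.9 / 1.0 / 1.0 / 0.8 with
    # the weights above give 0.8*0.9 + 0.05*1.0 + 0.05*1.0 + 0.1*0.8 = 0.90.)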
    def get_embedding(self, text: str) -> Optional[np.ndarray]:
        """Get the embedding vector for a piece of text."""
        try:
            # Clean the text consistently.
            cleaned_text = self.clean_text(text)
            if not cleaned_text.strip():
                cleaned_text = "none"
            response = self.client.embeddings.create(
                model="text-embedding-3-large",
                input=cleaned_text
            )
            embedding = np.array(response.data[0].embedding).astype('float32')
            return embedding
        except Exception as e:
            print(f"❌ Failed to get embedding: {e}")
            return None

    def get_embeddings_batch(self, texts: List[str]) -> List[np.ndarray]:
        """Get embeddings for a batch of texts."""
        try:
            if not texts:
                return []
            # Clean the texts consistently.
            cleaned_texts = []
            for t in texts:
                c = self.clean_text(t)
                cleaned_texts.append(c if c.strip() else "none")
            response = self.client.embeddings.create(
                model="text-embedding-3-large",
                input=cleaned_texts
            )
            return [np.array(data.embedding).astype('float32') for data in response.data]
        except Exception as e:
            print(f"❌ Failed to get embeddings in batch: {e}")
            return []

    def get_weighted_embeddings_batch(self, questions: List[Dict]) -> List[Optional[np.ndarray]]:
        """
        Build weighted concatenated vectors for a batch of questions.
        """
        all_texts = []
        for q in questions:
            # Extract the four parts, with placeholders for empty values.
            stem = self.clean_text(q.get('stem', '') or '')
            opts = self.normalize_options(q.get('options', '') or '')
            ans = self.clean_text(q.get('answer', '') or '')
            sol = self.clean_text(q.get('solution', '') or '')
            # The OpenAI API rejects empty strings, so fall back to "none".
            all_texts.append(stem if stem.strip() else "none")
            all_texts.append(opts if opts.strip() else "none")
            all_texts.append(ans if ans.strip() else "none")
            all_texts.append(sol if sol.strip() else "none")
        # One batched request (four parts per question).
        all_embeddings = self.get_embeddings_batch(all_texts)
        if not all_embeddings:
            return []
        results = []
        for i in range(len(questions)):
            q_embs = all_embeddings[i * 4: (i + 1) * 4]
            if len(q_embs) < 4:
                results.append(None)
                continue
            weighted_parts = []
            keys = ['stem', 'options', 'answer', 'solution']
            for j, emb in enumerate(q_embs):
                # Normalize.
                norm = np.linalg.norm(emb)
                if norm > 0:
                    emb = emb / norm
                # Weight.
                weight_factor = np.sqrt(self.weights[keys[j]])
                weighted_parts.append(emb * weight_factor)
            mega_vector = np.concatenate(weighted_parts).astype('float32')
            results.append(mega_vector)
        return results
    def combine_content(self, q: Dict) -> str:
        """Combine question content for preview display."""
        stem = q.get('stem', '') or ''
        return self.clean_text(stem)  # previews are based on the cleaned stem

    def fetch_question_from_db(self, question_id: int) -> Optional[Dict]:
        """Fetch a question record from the database."""
        conn = None
        try:
            conn = pymysql.connect(
                host=DB_HOST,
                port=DB_PORT,
                user=DB_USERNAME,
                password=DB_PASSWORD,
                database=DB_DATABASE,
                charset='utf8mb4',
                cursorclass=pymysql.cursors.DictCursor
            )
            with conn.cursor() as cursor:
                sql = "SELECT id, stem, options, answer, solution, is_repeat FROM questions_tem WHERE id = %s"
                cursor.execute(sql, (question_id,))
                result = cursor.fetchone()
                return result
        except Exception as e:
            print(f"❌ Database query failed: {e}")
            return None
        finally:
            if conn:
                conn.close()

    def confirm_repeat(self, question_id: int, is_repeat: int) -> bool:
        """
        Apply a manual review decision to the database and the vector store.

        is_repeat: 0 means no similar question exists, 1 means it is a duplicate.
        """
        # 1. Update the database flag.
        self.update_is_repeat(question_id, is_repeat)
        # 2. If marked as non-duplicate, add it to the vector store.
        if is_repeat == 0:
            question = self.fetch_question_from_db(question_id)
            if question:
                mega_vec = self.get_weighted_embedding(question)
                if mega_vec is not None:
                    # Reshape so FAISS accepts it as a single row.
                    mega_vec = mega_vec.reshape(1, -1)
                    self.add_to_index(question_id, mega_vec, self.clean_text(question.get('stem', ''))[:100])
                    return True
        return True

    def update_is_repeat(self, question_id: int, is_repeat: int):
        """Update the is_repeat column in the database."""
        conn = None
        try:
            conn = pymysql.connect(
                host=DB_HOST,
                port=DB_PORT,
                user=DB_USERNAME,
                password=DB_PASSWORD,
                database=DB_DATABASE,
                charset='utf8mb4'
            )
            with conn.cursor() as cursor:
                sql = "UPDATE questions_tem SET is_repeat = %s WHERE id = %s"
                cursor.execute(sql, (is_repeat, question_id))
            conn.commit()
            print(f"✓ is_repeat for question ID {question_id} updated to {is_repeat}")
        except Exception as e:
            print(f"❌ Failed to update is_repeat: {e}")
        finally:
            if conn:
                conn.close()
    def call_openai_for_duplicate_check(self, q1: Dict, q2: Dict) -> bool:
        """
        Call GPT-4o for an in-depth comparison of whether two questions are duplicates.
        """
        try:
            prompt = f"""You are a professional math question deduplication expert. Compare the two questions below and decide whether they are "duplicates".
Definition of "duplicate": the setting, tested knowledge points, core logic, and numeric values are identical, or the wording differs only trivially without changing the essence of the question.

Question 1:
Stem: {q1.get('stem', '')}
Options: {q1.get('options', '')}
Answer: {q1.get('answer', '')}
Solution: {q1.get('solution', '')}

Question 2:
Stem: {q2.get('stem', '')}
Options: {q2.get('options', '')}
Answer: {q2.get('answer', '')}
Solution: {q2.get('solution', '')}

Analyze the similarity of the two questions and give a final verdict.
The verdict must be output as JSON with the following fields:
- is_duplicate: boolean, True if the questions are duplicates, False otherwise
- reason: a brief justification

Output only the JSON, with no additional explanation."""
            response = self.client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": "You are a professional math question deduplication assistant."},
                    {"role": "user", "content": prompt}
                ],
                response_format={"type": "json_object"},
                temperature=0
            )
            result = json.loads(response.choices[0].message.content)
            is_dup = result.get('is_duplicate', False)
            reason = result.get('reason', '')
            print(f"🤖 GPT-4o verdict: {'duplicate' if is_dup else 'not duplicate'} | reason: {reason}")
            return is_dup
        except Exception as e:
            print(f"❌ GPT-4o fallback check failed: {e}")
            return False
    def check_duplicate(self, question_id: int, threshold: float = 0.85) -> Dict:
        """
        Main dedup logic (pre-insert mode: does not update the database or add to the index).
        """
        self.ensure_index_loaded()
        # 1. Fetch the question.
        question = self.fetch_question_from_db(question_id)
        if not question:
            return {"status": "error", "message": f"No question found with ID {question_id}"}
        # 2. Build the weighted concatenated vector.
        mega_vector = self.get_weighted_embedding(question)
        if mega_vector is None:
            return {"status": "error", "message": "Failed to generate the question vector"}
        # 3. Search the index.
        if self.index.ntotal == 0:
            return {
                "status": "success",
                "message": "The vector store is empty",
                "is_duplicate": False,
                "top_similar": []
            }
        # Search for the most similar questions.
        k = min(3, self.index.ntotal)
        query_vec = mega_vector.reshape(1, -1)
        similarities, indices = self.index.search(query_vec, k)
        best_similar = None
        is_duplicate = False
        gpt_checked = False
        for i in range(k):
            idx = indices[0][i]
            score = float(similarities[0][i])
            if idx != -1:
                similar_info = self.metadata[idx]
                similar_id = similar_info['id']
                if similar_id == question_id:
                    continue
                best_similar = {
                    "id": similar_id,
                    "similarity": round(score, 4),
                    "similar_point": similar_info.get('text_preview', '')
                }
                if score >= threshold:
                    is_duplicate = True
                elif score > 0.5:
                    # Similarity between 0.5 and 0.85: fall back to GPT-4o.
                    similar_question = self.fetch_question_from_db(similar_id)
                    if similar_question:
                        is_duplicate = self.call_openai_for_duplicate_check(question, similar_question)
                        gpt_checked = True
                break
        if is_duplicate:
            return {
                "status": "warning",
                "message": "This question is similar to an existing one; please review manually" + (" (confirmed by GPT-4o fallback)" if gpt_checked else ""),
                "is_duplicate": True,
                "gpt_checked": gpt_checked,
                "top_similar": [best_similar]
            }
        else:
            return {
                "status": "success",
                "message": "This question passed the duplicate check",
                "is_duplicate": False,
                "top_similar": [best_similar] if best_similar else []
            }
    def check_duplicate_by_content(self, question_data: Dict, threshold: float = 0.85) -> Dict:
        """
        Check for duplicates from raw question content (pre-check mode).
        """
        self.ensure_index_loaded()
        # 1. Build the weighted concatenated vector.
        mega_vector = self.get_weighted_embedding(question_data)
        if mega_vector is None:
            return {"status": "error", "message": "Failed to generate the question vector"}
        # 2. Search.
        if self.index.ntotal == 0:
            return {"status": "success", "is_duplicate": False, "top_similar": []}
        k = 1  # the pre-check usually only needs the single closest match
        query_vec = mega_vector.reshape(1, -1)
        similarities, indices = self.index.search(query_vec, k)
        if indices[0][0] != -1:
            score = float(similarities[0][0])
            similar_info = self.metadata[indices[0][0]]
            similar_id = similar_info['id']
            best_similar = {
                "id": similar_id,
                "similarity": round(score, 4),
                "similar_point": similar_info.get('text_preview', '')
            }
            is_duplicate = False
            gpt_checked = False
            if score >= threshold:
                is_duplicate = True
            elif score > 0.5:
                # Similarity between 0.5 and 0.85: fall back to GPT-4o.
                similar_question = self.fetch_question_from_db(similar_id)
                if similar_question:
                    is_duplicate = self.call_openai_for_duplicate_check(question_data, similar_question)
                    gpt_checked = True
            if is_duplicate:
                return {
                    "status": "warning",
                    "is_duplicate": True,
                    "gpt_checked": gpt_checked,
                    "top_similar": [best_similar]
                }
            return {"status": "success", "is_duplicate": False, "top_similar": [best_similar]}
        return {"status": "success", "is_duplicate": False, "top_similar": []}
    def get_question_data(self, question_id: int) -> Dict:
        """Return the metadata for a specific ID in the vector store, plus the total count."""
        self.ensure_index_loaded()
        total_count = self.index.ntotal if self.index else 0
        target_metadata = None
        for m in self.metadata:
            if m['id'] == question_id:
                target_metadata = m
                break
        return {
            "total_count": total_count,
            "question_data": target_metadata
        }

    def add_to_index(self, question_id: int, vector: np.ndarray, text: str):
        """Add a question to the index."""
        if any(m['id'] == question_id for m in self.metadata):
            return
        self.index.add(vector)
        self.metadata.append({
            'id': question_id,
            'text_preview': self.clean_text(text)[:100]
        })
        self.save_index()
    def sync_all_from_db(self, batch_size=50, max_workers=None):
        """Sync every question in the database into the index (weighted mode, batched, multithreaded)."""
        if not self.sync_lock.acquire(blocking=False):
            print("⏳ A sync is already in progress; ignoring duplicate request")
            return False
        self.sync_in_progress = True
        print("🔄 Starting full sync (optimized, weighted mode)...")
        existing_ids = {m['id'] for m in self.metadata}
        conn = None
        try:
            conn = pymysql.connect(
                host=DB_HOST,
                port=DB_PORT,
                user=DB_USERNAME,
                password=DB_PASSWORD,
                database=DB_DATABASE,
                charset='utf8mb4',
                cursorclass=pymysql.cursors.DictCursor
            )
            with conn.cursor() as cursor:
                print("📡 Reading all questions from the database...")
                sql = "SELECT id, stem, options, answer, solution FROM questions_tem"
                cursor.execute(sql)
                all_questions = cursor.fetchall()
                print(f"📦 Database load complete, {len(all_questions)} records in total")
            new_questions = [q for q in all_questions if q['id'] not in existing_ids]
            total_new = len(new_questions)
            if total_new == 0:
                print("✅ Already up to date; nothing to sync.")
                return True
            print(f"📊 Database total: {len(all_questions)}, new questions to sync: {total_new}")
            # Split into chunks.
            chunks = [new_questions[i:i + batch_size] for i in range(0, total_new, batch_size)]

            def process_chunk(chunk):
                mega_vectors = self.get_weighted_embeddings_batch(chunk)
                return chunk, mega_vectors

            # Run chunks through a thread pool (capped to avoid overloading small machines).
            worker_count = max_workers or max(1, min(2, os.cpu_count() or 1))
            with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as executor:
                future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
                count = 0
                for future in concurrent.futures.as_completed(future_to_chunk):
                    chunk, mega_vectors = future.result()
                    if mega_vectors:
                        for i, mega_vec in enumerate(mega_vectors):
                            if mega_vec is not None:
                                self.index.add(mega_vec.reshape(1, -1))
                                self.metadata.append({
                                    'id': chunk[i]['id'],
                                    'text_preview': self.clean_text(chunk[i].get('stem', ''))[:100]
                                })
                    count += len(chunk)
                    print(f"✅ Sync progress: {count}/{total_new}")
                    # Save a checkpoint every 500 records.
                    if count % 500 == 0:
                        self.save_index()
            self.save_index()
            print(f"🎉 Sync complete! Current index total: {len(self.metadata)}")
            return True
        except Exception as e:
            print(f"❌ Sync failed: {e}")
            import traceback
            traceback.print_exc()
            return False
        finally:
            if conn:
                conn.close()
            self.sync_in_progress = False
            if self.sync_lock.locked():
                self.sync_lock.release()

if __name__ == "__main__":
    # Test code
    checker = QuestionDuplicateChecker()
    # checker.sync_all_from_db()  # run once on first use to populate the index
    # result = checker.check_duplicate(10)
    # print(json.dumps(result, ensure_ascii=False, indent=2))
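    # A pre-check sketch in the same commented-out style (hypothetical field
    # values; check_duplicate_by_content never writes to the database or index):
    # draft = {
    #     "stem": "Solve for x: 2x + 3 = 7",
    #     "options": '{"A": "1", "B": "2", "C": "3", "D": "4"}',
    #     "answer": "B",
    #     "solution": "2x = 4, so x = 2",
    # }
    # result = checker.check_duplicate_by_content(draft)
    # print(json.dumps(result, ensure_ascii=False, indent=2))
    # And once a reviewer confirms question 10 is not a duplicate, persist the
    # verdict and add it to the index in one call:
    # checker.confirm_repeat(10, is_repeat=0)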