@@ -0,0 +1,563 @@
+"""
+Math question duplicate-check assistant: uses vector similarity to decide whether a question already exists.
+"""
+import json
+import os
+import pickle
+import re
+import numpy as np
+import faiss
+import pymysql
+import concurrent.futures
+from openai import OpenAI
+from typing import List, Dict, Tuple, Optional
+from config import DB_HOST, DB_PORT, DB_DATABASE, DB_USERNAME, DB_PASSWORD, OPENAI_API_KEY, OPENAI_BASE_URL
+
+
+class QuestionDuplicateChecker:
+    """Question duplicate checker."""
+
+    def __init__(self, index_path="questions_tem.index", metadata_path="questions_tem_metadata.pkl"):
+        """
+        Initialize the checker.
+
+        Args:
+            index_path: path where the FAISS index is stored
+            metadata_path: path where the metadata is stored
+        """
+        self.client = OpenAI(
+            api_key=OPENAI_API_KEY,
+            base_url=OPENAI_BASE_URL
+        )
+        self.index_path = index_path
+        self.metadata_path = metadata_path
+        self.index = None
+        self.metadata = []  # stores question IDs and text previews so matches can be displayed
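+        # Each metadata entry has the shape {'id': <question id>, 'text_preview': <first 100 chars of the cleaned stem>},
+        # mirroring what add_to_index() and sync_all_from_db() append below.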
+
+        # Weights for each part of a question
+        self.weights = {
+            'stem': 0.8,
+            'options': 0.05,
+            'answer': 0.05,
+            'solution': 0.1
+        }
+        # Dimension of the concatenated vector: 4 parts x 3072 = 12288
+        self.dimension = 3072 * 4
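+        # Why the square-root trick works: each part's embedding is L2-normalized and then scaled by
+        # sqrt(weight) before concatenation, so the inner product of two concatenated vectors equals
+        # sum_i w_i * cos(part_i, part_i'). Because the weights above sum to 1.0, that score stays in
+        # [-1, 1] and can be compared directly against the similarity thresholds used below.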
+
+        self._load_index()
+
+    def clean_text(self, text: str) -> str:
+        """Clean text: strip image tags and normalize whitespace."""
+        if not text:
+            return ""
+        # Remove <image .../> tags
+        text = re.sub(r'<image\s+[^>]*/>', '', text)
+        # Normalize whitespace: collapse runs of spaces/newlines into a single space
+        text = re.sub(r'\s+', ' ', text)
+        return text.strip()
+
+    def normalize_options(self, options) -> str:
+        """Normalize the option content."""
+        if not options:
+            return ""
+
+        data = None
+        if isinstance(options, str):
+            # If it is a string, try to parse it as JSON
+            text = options.strip()
+            if (text.startswith('{') and text.endswith('}')) or (text.startswith('[') and text.endswith(']')):
+                try:
+                    data = json.loads(text)
+                except json.JSONDecodeError:
+                    pass
+            if data is None:
+                return self.clean_text(options)
+        else:
+            data = options
+
+        if isinstance(data, dict):
+            # Sort and clean the dict entries
+            sorted_data = {str(k).strip(): self.clean_text(str(v)) for k, v in sorted(data.items())}
+            return self.clean_text(json.dumps(sorted_data, ensure_ascii=False))
+        elif isinstance(data, list):
+            # Clean the list entries
+            cleaned_data = [self.clean_text(str(v)) for v in data]
+            return self.clean_text(json.dumps(cleaned_data, ensure_ascii=False))
+
+        return self.clean_text(str(data))
+
+    def _load_index(self):
+        """Load an existing index or initialize a new one."""
+        if os.path.exists(self.index_path) and os.path.exists(self.metadata_path):
+            try:
+                self.index = faiss.read_index(self.index_path)
+                # Check the index dimension; reinitialize if it does not match (e.g. after a weighting upgrade)
+                if self.index.d != self.dimension:
+                    print(f"⚠️ 索引维度不匹配 ({self.index.d} != {self.dimension}),正在重新初始化...")
+                    self._init_new_index()
+                else:
+                    with open(self.metadata_path, 'rb') as f:
+                        self.metadata = pickle.load(f)
+                    print(f"✓ 已加载现有索引,包含 {len(self.metadata)} 道题目")
+            except Exception as e:
+                print(f"⚠️ 加载索引失败: {e},将初始化新索引")
+                self._init_new_index()
+        else:
+            self._init_new_index()
+
+    def _init_new_index(self):
+        """Initialize a new FAISS index."""
+        # Inner-product index, used together with the weighted concatenated vectors
+        self.index = faiss.IndexFlatIP(self.dimension)
+        self.metadata = []
+        print("✓ 已初始化新的FAISS索引")
+
+    def save_index(self):
+        """Save the index and the metadata."""
+        faiss.write_index(self.index, self.index_path)
+        with open(self.metadata_path, 'wb') as f:
+            pickle.dump(self.metadata, f)
+        print(f"✓ 索引和元数据已保存到 {self.index_path}")
+
+    def get_weighted_embedding(self, question: Dict) -> Optional[np.ndarray]:
+        """
+        Build the weighted concatenated vector.
+        Embeds the 4 parts, scales each by the square root of its weight, and concatenates them.
+        """
+        parts = {
+            'stem': self.clean_text(question.get('stem', '') or ''),
+            'options': self.normalize_options(question.get('options', '') or ''),
+            'answer': self.clean_text(question.get('answer', '') or ''),
+            'solution': self.clean_text(question.get('solution', '') or '')
+        }
+
+        embeddings = []
+        for key, text in parts.items():
+            # Embed this part
+            emb = self.get_embedding(text if text.strip() else "无")  # guard against empty text
+            if emb is None:
+                return None
+
+            # Normalize this part's vector
+            norm = np.linalg.norm(emb)
+            if norm > 0:
+                emb = emb / norm
+
+            # Scale by the square root of the weight: the dot product then equals the weighted sum of part similarities
+            weight_factor = np.sqrt(self.weights[key])
+            embeddings.append(emb * weight_factor)
+
+        # Concatenate into one long vector
+        mega_vector = np.concatenate(embeddings).astype('float32')
+        return mega_vector
+
+    def get_embedding(self, text: str) -> Optional[np.ndarray]:
+        """Get the embedding vector for a piece of text."""
+        try:
+            # Clean the text consistently
+            cleaned_text = self.clean_text(text)
+            if not cleaned_text.strip():
+                cleaned_text = "无"
+
+            response = self.client.embeddings.create(
+                model="text-embedding-3-large",
+                input=cleaned_text
+            )
+            embedding = np.array(response.data[0].embedding).astype('float32')
+            return embedding
+        except Exception as e:
+            print(f"❌ 获取embedding失败: {e}")
+            return None
+
+    def get_embeddings_batch(self, texts: List[str]) -> List[np.ndarray]:
+        """Get embedding vectors for a batch of texts."""
+        try:
+            if not texts:
+                return []
+
+            # Clean the texts consistently
+            cleaned_texts = []
+            for t in texts:
+                c = self.clean_text(t)
+                cleaned_texts.append(c if c.strip() else "无")
+
+            response = self.client.embeddings.create(
+                model="text-embedding-3-large",
+                input=cleaned_texts
+            )
+            return [np.array(data.embedding).astype('float32') for data in response.data]
+        except Exception as e:
+            print(f"❌ 批量获取 embedding 失败: {e}")
+            return []
+
+    def get_weighted_embeddings_batch(self, questions: List[Dict]) -> List[Optional[np.ndarray]]:
+        """
+        Build weighted concatenated vectors for a batch of questions.
+        """
+        all_texts = []
+        for q in questions:
+            # Collect the 4 parts and substitute placeholders for empty values
+            stem = self.clean_text(q.get('stem', '') or '')
+            opts = self.normalize_options(q.get('options', '') or '')
+            ans = self.clean_text(q.get('answer', '') or '')
+            sol = self.clean_text(q.get('solution', '') or '')
+
+            # The OpenAI API rejects empty strings, so fall back to "无"
+            all_texts.append(stem if stem.strip() else "无")
+            all_texts.append(opts if opts.strip() else "无")
+            all_texts.append(ans if ans.strip() else "无")
+            all_texts.append(sol if sol.strip() else "无")
+
+        # Batch request (4 parts per question)
+        all_embeddings = self.get_embeddings_batch(all_texts)
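+        # The embeddings come back in input order, so entries 4*i .. 4*i+3 belong to question i.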
+        if not all_embeddings:
+            return []
+
+        results = []
+        for i in range(len(questions)):
+            q_embs = all_embeddings[i*4 : (i+1)*4]
+            if len(q_embs) < 4:
+                results.append(None)
+                continue
+
+            weighted_parts = []
+            keys = ['stem', 'options', 'answer', 'solution']
+            for j, emb in enumerate(q_embs):
+                # Normalize
+                norm = np.linalg.norm(emb)
+                if norm > 0:
+                    emb = emb / norm
+                # Apply the weight
+                weight_factor = np.sqrt(self.weights[keys[j]])
+                weighted_parts.append(emb * weight_factor)
+
+            mega_vector = np.concatenate(weighted_parts).astype('float32')
+            results.append(mega_vector)
+
+        return results
+
+    def combine_content(self, q: Dict) -> str:
+        """Combine question content for preview display."""
+        stem = q.get('stem', '') or ''
+        return self.clean_text(stem)  # the preview is just the cleaned stem
+
+    def fetch_question_from_db(self, question_id: int) -> Optional[Dict]:
+        """Fetch a question from the database."""
+        try:
+            conn = pymysql.connect(
+                host=DB_HOST,
+                port=DB_PORT,
+                user=DB_USERNAME,
+                password=DB_PASSWORD,
+                database=DB_DATABASE,
+                charset='utf8mb4',
+                cursorclass=pymysql.cursors.DictCursor
+            )
+            with conn.cursor() as cursor:
+                sql = "SELECT id, stem, options, answer, solution, is_repeat FROM questions_tem WHERE id = %s"
+                cursor.execute(sql, (question_id,))
+                result = cursor.fetchone()
+                return result
+        except Exception as e:
+            print(f"❌ 数据库查询失败: {e}")
+            return None
+        finally:
+            if 'conn' in locals() and conn:
+                conn.close()
+
+    def confirm_repeat(self, question_id: int, is_repeat: int) -> bool:
+        """
+        Persist a manual review decision to the database and the vector index.
+        is_repeat: 0 means no similar question exists, 1 means it is a duplicate.
+        """
+        # 1. Update the database status
+        self.update_is_repeat(question_id, is_repeat)
+
+        # 2. If the question is marked as not duplicate, add it to the vector index
+        if is_repeat == 0:
+            question = self.fetch_question_from_db(question_id)
+            if question:
+                mega_vec = self.get_weighted_embedding(question)
+                if mega_vec is not None:
+                    # Reshape to the (1, d) matrix FAISS expects
+                    mega_vec = mega_vec.reshape(1, -1)
+                    self.add_to_index(question_id, mega_vec, self.clean_text(question.get('stem', ''))[:100])
+        return True
+
+    def update_is_repeat(self, question_id: int, is_repeat: int):
+        """Update the is_repeat field in the database."""
+        try:
+            conn = pymysql.connect(
+                host=DB_HOST,
+                port=DB_PORT,
+                user=DB_USERNAME,
+                password=DB_PASSWORD,
+                database=DB_DATABASE,
+                charset='utf8mb4'
+            )
+            with conn.cursor() as cursor:
+                sql = "UPDATE questions_tem SET is_repeat = %s WHERE id = %s"
+                cursor.execute(sql, (is_repeat, question_id))
+            conn.commit()
+            print(f"✓ 题目 ID {question_id} 的 is_repeat 已更新为 {is_repeat}")
+        except Exception as e:
+            print(f"❌ 更新 is_repeat 失败: {e}")
+        finally:
+            if 'conn' in locals() and conn:
+                conn.close()
+
+    def call_openai_for_duplicate_check(self, q1: Dict, q2: Dict) -> bool:
+        """
+        Call GPT-4o for an in-depth comparison to decide whether two questions are duplicates.
+        """
+        try:
+            prompt = f"""你是一个专业的数学题目查重专家。请对比以下两道题目,判断它们是否为“重复题”。
+“重复题”的定义:题目背景、考查知识点、核心逻辑以及数值完全一致,或者仅有极其微小的表述差异但不影响题目本质。
+
+题目 1:
+题干: {q1.get('stem', '')}
+选项: {q1.get('options', '')}
+答案: {q1.get('answer', '')}
+解析: {q1.get('solution', '')}
+
+题目 2:
+题干: {q2.get('stem', '')}
+选项: {q2.get('options', '')}
+答案: {q2.get('answer', '')}
+解析: {q2.get('solution', '')}
+
+请分析两道题的相似性,并给出最终结论。
+结论必须以 JSON 格式输出,包含以下字段:
+- is_duplicate: 布尔值,True 表示是重复题,False 表示不是
+- reason: 简短的分析理由
+
+只输出 JSON 即可,不要有其他解释说明。"""
+
+            response = self.client.chat.completions.create(
+                model="gpt-4o",
+                messages=[
+                    {"role": "system", "content": "你是一个专业的数学题目查重助手。"},
+                    {"role": "user", "content": prompt}
+                ],
+                response_format={"type": "json_object"}
+            )
+
+            result = json.loads(response.choices[0].message.content)
+            is_dup = result.get('is_duplicate', False)
+            reason = result.get('reason', '')
+            print(f"🤖 GPT-4o 判定结果: {'重复' if is_dup else '不重复'} | 理由: {reason}")
+            return is_dup
+        except Exception as e:
+            print(f"❌ GPT-4o 兜底计算失败: {e}")
+            return False
+
+    def check_duplicate(self, question_id: int, threshold: float = 0.85) -> Dict:
+        """
+        Main duplicate-check logic (pre-insert mode: neither the database nor the index is updated).
+        """
+        # 1. Fetch the question
+        question = self.fetch_question_from_db(question_id)
+        if not question:
+            return {"status": "error", "message": f"未找到ID为 {question_id} 的题目"}
+
+        # 2. Build the weighted concatenated vector
+        mega_vector = self.get_weighted_embedding(question)
+        if mega_vector is None:
+            return {"status": "error", "message": "无法生成题目向量"}
+
+        # 3. Search the index
+        if self.index.ntotal == 0:
+            return {
+                "status": "success",
+                "message": "向量库为空",
+                "is_duplicate": False,
+                "top_similar": []
+            }
+
+        # Retrieve the most similar questions
+        k = min(3, self.index.ntotal)
+        query_vec = mega_vector.reshape(1, -1)
+        similarities, indices = self.index.search(query_vec, k)
+
+        best_similar = None
+        is_duplicate = False
+        gpt_checked = False
+
+        for i in range(k):
+            idx = indices[0][i]
+            score = float(similarities[0][i])
+            if idx != -1:
+                similar_info = self.metadata[idx]
+                similar_id = similar_info['id']
+
+                if similar_id == question_id:
+                    continue
+
+                # Hits arrive in descending similarity, so keep the first non-self hit as the best match
+                if best_similar is None:
+                    best_similar = {
+                        "id": similar_id,
+                        "similarity": round(score, 4),
+                        "similar_point": similar_info.get('text_preview', '')
+                    }
+
+                if score >= threshold:
+                    is_duplicate = True
+                    break
+                elif score > 0.5:
+                    # Similarity between 0.5 and the threshold: fall back to a GPT-4o comparison
+                    similar_question = self.fetch_question_from_db(similar_id)
+                    if similar_question:
+                        is_duplicate = self.call_openai_for_duplicate_check(question, similar_question)
+                        gpt_checked = True
+                    break
+
+        if is_duplicate:
+            return {
+                "status": "warning",
+                "message": "该题目与已存在题目相似,请人工核验" + (" (经 GPT-4o 兜底确认)" if gpt_checked else ""),
+                "is_duplicate": True,
+                "gpt_checked": gpt_checked,
+                "top_similar": [best_similar]
+            }
+        else:
+            return {
+                "status": "success",
+                "message": "该题目通过查重",
+                "is_duplicate": False,
+                "top_similar": [best_similar] if best_similar else []
+            }
+
+    def check_duplicate_by_content(self, question_data: Dict, threshold: float = 0.85) -> Dict:
+        """
+        Duplicate check based on raw question content (pre-check mode, nothing is persisted).
+        """
+        # 1. Build the weighted concatenated vector
+        mega_vector = self.get_weighted_embedding(question_data)
+        if mega_vector is None:
+            return {"status": "error", "message": "无法生成题目向量"}
+
+        # 2. Search the index
+        if self.index.ntotal == 0:
+            return {"status": "success", "is_duplicate": False, "top_similar": []}
+
+        k = 1  # the pre-check only needs the single most similar question
+        query_vec = mega_vector.reshape(1, -1)
+        similarities, indices = self.index.search(query_vec, k)
+
+        if indices[0][0] != -1:
+            score = float(similarities[0][0])
+            similar_info = self.metadata[indices[0][0]]
+            similar_id = similar_info['id']
+            best_similar = {
+                "id": similar_id,
+                "similarity": round(score, 4),
+                "similar_point": similar_info.get('text_preview', '')
+            }
+
+            is_duplicate = False
+            gpt_checked = False
+
+            if score >= threshold:
+                is_duplicate = True
+            elif score > 0.5:
+                # Similarity between 0.5 and the threshold: fall back to a GPT-4o comparison
+                similar_question = self.fetch_question_from_db(similar_id)
+                if similar_question:
+                    is_duplicate = self.call_openai_for_duplicate_check(question_data, similar_question)
+                    gpt_checked = True
+
+            if is_duplicate:
+                return {
+                    "status": "warning",
+                    "is_duplicate": True,
+                    "gpt_checked": gpt_checked,
+                    "top_similar": [best_similar]
+                }
+            return {"status": "success", "is_duplicate": False, "top_similar": [best_similar]}
+
+        return {"status": "success", "is_duplicate": False, "top_similar": []}
+
+    def add_to_index(self, question_id: int, vector: np.ndarray, text: str):
+        """Add a question to the index."""
+        if any(m['id'] == question_id for m in self.metadata):
+            return
+
+        self.index.add(vector)
+        self.metadata.append({
+            'id': question_id,
+            'text_preview': self.clean_text(text)[:100]
+        })
+        self.save_index()
+
+    def sync_all_from_db(self, batch_size=50, max_workers=5):
+        """Sync every question in the database into the index (weighted mode, batched, multi-threaded)."""
+        print("正在进行全量同步 (优化版 - 加权模式)...")
+        existing_ids = {m['id'] for m in self.metadata}
+        try:
+            conn = pymysql.connect(
+                host=DB_HOST,
+                port=DB_PORT,
+                user=DB_USERNAME,
+                password=DB_PASSWORD,
+                database=DB_DATABASE,
+                charset='utf8mb4',
+                cursorclass=pymysql.cursors.DictCursor
+            )
+            with conn.cursor() as cursor:
+                sql = "SELECT id, stem, options, answer, solution FROM questions_tem"
+                cursor.execute(sql)
+                all_questions = cursor.fetchall()
+
+            new_questions = [q for q in all_questions if q['id'] not in existing_ids]
+            total_new = len(new_questions)
+
+            if total_new == 0:
+                print("✅ 已经是最新状态,无需同步。")
+                return
+
+            print(f"📊 数据库总计: {len(all_questions)}, 需同步新增: {total_new}")
+
+            # Process in chunks
+            chunks = [new_questions[i:i + batch_size] for i in range(0, total_new, batch_size)]
+
+            def process_chunk(chunk):
+                mega_vectors = self.get_weighted_embeddings_batch(chunk)
+                return chunk, mega_vectors
+
+            # Run the chunks concurrently in a thread pool
+            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+                future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
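+                # The worker threads only compute embeddings; index.add and metadata.append run in the
+                # main thread below as results are collected, so no extra locking is needed.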
+
+                count = 0
+                for future in concurrent.futures.as_completed(future_to_chunk):
+                    chunk, mega_vectors = future.result()
+                    if mega_vectors:
+                        for i, mega_vec in enumerate(mega_vectors):
+                            if mega_vec is not None:
+                                self.index.add(mega_vec.reshape(1, -1))
+                                self.metadata.append({
+                                    'id': chunk[i]['id'],
+                                    'text_preview': self.clean_text(chunk[i].get('stem', ''))[:100]
+                                })
+
+                    count += len(chunk)
+                    print(f"✅ 已同步进度: {count}/{total_new}")
+
+                    # Save a checkpoint every 500 questions
+                    if count % 500 == 0:
+                        self.save_index()
+
+            self.save_index()
+            print(f"🎉 同步完成!当前索引总数: {len(self.metadata)}")
+        except Exception as e:
+            print(f"❌ 同步失败: {e}")
+            import traceback
+            traceback.print_exc()
+        finally:
+            if 'conn' in locals() and conn:
+                conn.close()
+
+
+if __name__ == "__main__":
+    # Quick test
+    checker = QuestionDuplicateChecker()
+    # checker.sync_all_from_db()  # sync on the first run
+    # result = checker.check_duplicate(10)
+    # print(json.dumps(result, ensure_ascii=False, indent=2))
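+
+    # A minimal pre-check sketch for a draft question that is not yet in the database
+    # (the field values below are illustrative placeholders, not real data):
+    # draft = {"stem": "求 f(x) = x^2 的最小值", "options": "", "answer": "0", "solution": "当 x = 0 时取最小值 0"}
+    # result = checker.check_duplicate_by_content(draft)
+    # The result carries "status", "is_duplicate" and "top_similar" (id, similarity, similar_point).
+    # After manual review, persist the decision (0 = not a duplicate, 1 = duplicate):
+    # checker.confirm_repeat(10, 0)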