""" 数学题目查重助手 - 使用向量相似度比对题目是否重复 """ import json import os import pickle import re import numpy as np import faiss import pymysql import concurrent.futures from openai import OpenAI from typing import List, Dict, Tuple, Optional from config import DB_HOST, DB_PORT, DB_DATABASE, DB_USERNAME, DB_PASSWORD, OPENAI_API_KEY, OPENAI_BASE_URL class QuestionDuplicateChecker: """题目查重器""" def __init__(self, index_path="questions_tem.index", metadata_path="questions_tem_metadata.pkl"): """ 初始化查重器 Args: index_path: FAISS索引保存路径 metadata_path: 元数据保存路径 """ self.client = OpenAI( api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL ) self.index_path = index_path self.metadata_path = metadata_path self.index = None self.metadata = [] # 存储题目ID和文本,以便回显 # 权重设置 self.weights = { 'stem': 0.8, 'options': 0.05, 'answer': 0.05, 'solution': 0.1 } # 维度调整:4部分拼接后的总维度为 3072 * 4 = 12288 self.dimension = 3072 * 4 self._load_index() def clean_text(self, text: str) -> str: """清理文本:移除图片标签、归一化空白字符""" if not text: return "" # 移除 标签 text = re.sub(r']*/>', '', text) # 归一化空白:将多个空格、换行替换为单个空格 text = re.sub(r'\s+', ' ', text) return text.strip() def normalize_options(self, options) -> str: """标准化选项内容""" if not options: return "" data = None if isinstance(options, str): # 如果是字符串,尝试解析 JSON text = options.strip() if (text.startswith('{') and text.endswith('}')) or (text.startswith('[') and text.endswith(']')): try: data = json.loads(text) except: pass if data is None: return self.clean_text(options) else: data = options if isinstance(data, dict): # 排序并清理字典内容 sorted_data = {str(k).strip(): self.clean_text(str(v)) for k, v in sorted(data.items())} return self.clean_text(json.dumps(sorted_data, ensure_ascii=False)) elif isinstance(data, list): # 清理列表内容 cleaned_data = [self.clean_text(str(v)) for v in data] return self.clean_text(json.dumps(cleaned_data, ensure_ascii=False)) return self.clean_text(str(data)) def _load_index(self): """加载或初始化索引""" if os.path.exists(self.index_path) and os.path.exists(self.metadata_path): try: self.index = faiss.read_index(self.index_path) # 检查索引维度是否匹配,如果不匹配则重新初始化(针对权重升级) if self.index.d != self.dimension: print(f"⚠️ 索引维度不匹配 ({self.index.d} != {self.dimension}),正在重新初始化...") self._init_new_index() else: with open(self.metadata_path, 'rb') as f: self.metadata = pickle.load(f) print(f"✓ 已加载现有索引,包含 {len(self.metadata)} 道题目") except Exception as e: print(f"⚠️ 加载索引失败: {e},将初始化新索引") self._init_new_index() else: self._init_new_index() def _init_new_index(self): """初始化新的FAISS索引""" # 使用内积索引,配合权重拼接向量 self.index = faiss.IndexFlatIP(self.dimension) self.metadata = [] print("✓ 已初始化新的FAISS索引") def save_index(self): """保存索引和元数据""" faiss.write_index(self.index, self.index_path) with open(self.metadata_path, 'wb') as f: pickle.dump(self.metadata, f) print(f"✓ 索引和元数据已保存到 {self.index_path}") def get_weighted_embedding(self, question: Dict) -> np.ndarray: """ 获取加权拼接向量 计算 4 个部分的 embedding,并乘以各自权重的平方根后拼接 """ parts = { 'stem': self.clean_text(question.get('stem', '') or ''), 'options': self.normalize_options(question.get('options', '') or ''), 'answer': self.clean_text(question.get('answer', '') or ''), 'solution': self.clean_text(question.get('solution', '') or '') } embeddings = [] for key, text in parts.items(): # 获取单部分 embedding emb = self.get_embedding(text if text.strip() else "无") # 防止空文本 if emb is None: return None # 归一化该部分的向量 norm = np.linalg.norm(emb) if norm > 0: emb = emb / norm # 乘以权重平方根:这样点积结果就是权重加权和 weight_factor = np.sqrt(self.weights[key]) embeddings.append(emb * weight_factor) # 拼接成一个大向量 mega_vector = np.concatenate(embeddings).astype('float32') return mega_vector def get_embedding(self, text: str) -> np.ndarray: """获取文本的embedding向量""" try: # 统一清理文本 cleaned_text = self.clean_text(text) if not cleaned_text.strip(): cleaned_text = "无" response = self.client.embeddings.create( model="text-embedding-3-large", input=cleaned_text ) embedding = np.array(response.data[0].embedding).astype('float32') return embedding except Exception as e: print(f"❌ 获取embedding失败: {e}") return None def get_embeddings_batch(self, texts: List[str]) -> List[np.ndarray]: """批量获取向量化结果""" try: if not texts: return [] # 统一清理文本 cleaned_texts = [] for t in texts: c = self.clean_text(t) cleaned_texts.append(c if c.strip() else "无") response = self.client.embeddings.create( model="text-embedding-3-large", input=cleaned_texts ) return [np.array(data.embedding).astype('float32') for data in response.data] except Exception as e: print(f"❌ 批量获取 embedding 失败: {e}") return [] def get_weighted_embeddings_batch(self, questions: List[Dict]) -> List[np.ndarray]: """ 批量获取加权拼接向量 """ all_texts = [] for q in questions: # 获取 4 个部分并处理空值占位 stem = self.clean_text(q.get('stem', '') or '') opts = self.normalize_options(q.get('options', '') or '') ans = self.clean_text(q.get('answer', '') or '') sol = self.clean_text(q.get('solution', '') or '') # OpenAI API 不接受空字符串,强制填充 "无" all_texts.append(stem if stem.strip() else "无") all_texts.append(opts if opts.strip() else "无") all_texts.append(ans if ans.strip() else "无") all_texts.append(sol if sol.strip() else "无") # 批量请求 (每题 4 个部分) all_embeddings = self.get_embeddings_batch(all_texts) if not all_embeddings: return [] results = [] for i in range(len(questions)): q_embs = all_embeddings[i*4 : (i+1)*4] if len(q_embs) < 4: results.append(None) continue weighted_parts = [] keys = ['stem', 'options', 'answer', 'solution'] for j, emb in enumerate(q_embs): # 归一化 norm = np.linalg.norm(emb) if norm > 0: emb = emb / norm # 加权 weight_factor = np.sqrt(self.weights[keys[j]]) weighted_parts.append(emb * weight_factor) mega_vector = np.concatenate(weighted_parts).astype('float32') results.append(mega_vector) return results def combine_content(self, q: Dict) -> str: """组合题目内容用于预览显示""" stem = q.get('stem', '') or '' return self.clean_text(stem) # 预览以清理后的题干为主 def fetch_question_from_db(self, question_id: int) -> Optional[Dict]: """从数据库获取题目信息""" try: conn = pymysql.connect( host=DB_HOST, port=DB_PORT, user=DB_USERNAME, password=DB_PASSWORD, database=DB_DATABASE, charset='utf8mb4', cursorclass=pymysql.cursors.DictCursor ) with conn.cursor() as cursor: sql = "SELECT id, stem, options, answer, solution, is_repeat FROM questions_tem WHERE id = %s" cursor.execute(sql, (question_id,)) result = cursor.fetchone() return result except Exception as e: print(f"❌ 数据库查询失败: {e}") return None finally: if 'conn' in locals() and conn: conn.close() def confirm_repeat(self, question_id: int, is_repeat: int) -> bool: """ 人工确认结果后更新数据库和向量库 is_repeat: 0 代表无相似题,1 代表有重复 """ # 1. 更新数据库状态 self.update_is_repeat(question_id, is_repeat) # 2. 如果标记为非重复,将其加入向量库 if is_repeat == 0: question = self.fetch_question_from_db(question_id) if question: mega_vec = self.get_weighted_embedding(question) if mega_vec is not None: # 转换维度以便 FAISS 接收 mega_vec = mega_vec.reshape(1, -1) self.add_to_index(question_id, mega_vec, self.clean_text(question.get('stem', ''))[:100]) return True return True def update_is_repeat(self, question_id: int, is_repeat: int): """更新数据库中的 is_repeat 字段""" try: conn = pymysql.connect( host=DB_HOST, port=DB_PORT, user=DB_USERNAME, password=DB_PASSWORD, database=DB_DATABASE, charset='utf8mb4' ) with conn.cursor() as cursor: sql = "UPDATE questions_tem SET is_repeat = %s WHERE id = %s" cursor.execute(sql, (is_repeat, question_id)) conn.commit() print(f"✓ 题目 ID {question_id} 的 is_repeat 已更新为 {is_repeat}") except Exception as e: print(f"❌ 更新 is_repeat 失败: {e}") finally: if 'conn' in locals() and conn: conn.close() def call_openai_for_duplicate_check(self, q1: Dict, q2: Dict) -> bool: """ 调用 GPT-4o 进行深度对比,判断两道题是否为重复题 """ try: prompt = f"""你是一个专业的数学题目查重专家。请对比以下两道题目,判断它们是否为“重复题”。 “重复题”的定义:题目背景、考查知识点、核心逻辑以及数值完全一致,或者仅有极其微小的表述差异但不影响题目本质。 题目 1: 题干: {q1.get('stem', '')} 选项: {q1.get('options', '')} 答案: {q1.get('answer', '')} 解析: {q1.get('solution', '')} 题目 2: 题干: {q2.get('stem', '')} 选项: {q2.get('options', '')} 答案: {q2.get('answer', '')} 解析: {q2.get('solution', '')} 请分析两道题的相似性,并给出最终结论。 结论必须以 JSON 格式输出,包含以下字段: - is_duplicate: 布尔值,True 表示是重复题,False 表示不是 - reason: 简短的分析理由 只输出 JSON 即可,不要有其他解释说明。""" response = self.client.chat.completions.create( model="gpt-4o", messages=[ {"role": "system", "content": "你是一个专业的数学题目查重助手。"}, {"role": "user", "content": prompt} ], response_format={"type": "json_object"} ) result = json.loads(response.choices[0].message.content) is_dup = result.get('is_duplicate', False) reason = result.get('reason', '') print(f"🤖 GPT-4o 判定结果: {'重复' if is_dup else '不重复'} | 理由: {reason}") return is_dup except Exception as e: print(f"❌ GPT-4o 兜底计算失败: {e}") return False def check_duplicate(self, question_id: int, threshold: float = 0.85) -> Dict: """ 查重主逻辑 (未入库模式:不更新数据库,不自动入库) """ # 1. 获取题目信息 question = self.fetch_question_from_db(question_id) if not question: return {"status": "error", "message": f"未找到ID为 {question_id} 的题目"} # 2. 获取加权拼接向量 mega_vector = self.get_weighted_embedding(question) if mega_vector is None: return {"status": "error", "message": "无法生成题目向量"} # 3. 搜索索引 if self.index.ntotal == 0: return { "status": "success", "message": "向量库为空", "is_duplicate": False, "top_similar": [] } # 搜索最相似的题目 k = min(3, self.index.ntotal) query_vec = mega_vector.reshape(1, -1) similarities, indices = self.index.search(query_vec, k) best_similar = None is_duplicate = False gpt_checked = False for i in range(k): idx = indices[0][i] score = float(similarities[0][i]) if idx != -1: similar_info = self.metadata[idx] similar_id = similar_info['id'] if similar_id == question_id: continue best_similar = { "id": similar_id, "similarity": round(score, 4), "similar_point": similar_info.get('text_preview', '') } if score >= threshold: is_duplicate = True elif score > 0.5: # 相似度在 0.5 ~ 0.85 之间,调用 GPT-4o 兜底 similar_question = self.fetch_question_from_db(similar_id) if similar_question: is_duplicate = self.call_openai_for_duplicate_check(question, similar_question) gpt_checked = True break if is_duplicate: return { "status": "warning", "message": "该题目与已存在题目相似,请人工核验" + (" (经 GPT-4o 兜底确认)" if gpt_checked else ""), "is_duplicate": True, "gpt_checked": gpt_checked, "top_similar": [best_similar] } else: return { "status": "success", "message": "该题目通过查重", "is_duplicate": False, "top_similar": [best_similar] if best_similar else [] } def check_duplicate_by_content(self, question_data: Dict, threshold: float = 0.85) -> Dict: """ 基于原始文本内容进行查重 (预检模式) """ # 1. 获取加权拼接向量 mega_vector = self.get_weighted_embedding(question_data) if mega_vector is None: return {"status": "error", "message": "无法生成题目向量"} # 2. 检索 if self.index.ntotal == 0: return {"status": "success", "is_duplicate": False, "top_similar": []} k = 1 # 预检通常只需返回最相似的一个 query_vec = mega_vector.reshape(1, -1) similarities, indices = self.index.search(query_vec, k) if indices[0][0] != -1: score = float(similarities[0][0]) similar_info = self.metadata[indices[0][0]] similar_id = similar_info['id'] best_similar = { "id": similar_id, "similarity": round(score, 4), "similar_point": similar_info.get('text_preview', '') } is_duplicate = False gpt_checked = False if score >= threshold: is_duplicate = True elif score > 0.5: # 相似度在 0.5 ~ 0.85 之间,调用 GPT-4o 兜底 similar_question = self.fetch_question_from_db(similar_id) if similar_question: is_duplicate = self.call_openai_for_duplicate_check(question_data, similar_question) gpt_checked = True if is_duplicate: return { "status": "warning", "is_duplicate": True, "gpt_checked": gpt_checked, "top_similar": [best_similar] } return {"status": "success", "is_duplicate": False, "top_similar": [best_similar]} return {"status": "success", "is_duplicate": False, "top_similar": []} def add_to_index(self, question_id: int, vector: np.ndarray, text: str): """将题目加入索引""" if any(m['id'] == question_id for m in self.metadata): return self.index.add(vector) self.metadata.append({ 'id': question_id, 'text_preview': self.clean_text(text)[:100] }) self.save_index() def sync_all_from_db(self, batch_size=50, max_workers=5): """同步数据库中所有题目到索引 (支持加权模式 + 批量 + 多线程)""" print("正在进行全量同步 (优化版 - 加权模式)...") existing_ids = {m['id'] for m in self.metadata} try: conn = pymysql.connect( host=DB_HOST, port=DB_PORT, user=DB_USERNAME, password=DB_PASSWORD, database=DB_DATABASE, charset='utf8mb4', cursorclass=pymysql.cursors.DictCursor ) with conn.cursor() as cursor: sql = "SELECT id, stem, options, answer, solution FROM questions_tem" cursor.execute(sql) all_questions = cursor.fetchall() new_questions = [q for q in all_questions if q['id'] not in existing_ids] total_new = len(new_questions) if total_new == 0: print("✅ 已经是最新状态,无需同步。") return print(f"📊 数据库总计: {len(all_questions)}, 需同步新增: {total_new}") # 分块处理 chunks = [new_questions[i:i + batch_size] for i in range(0, total_new, batch_size)] def process_chunk(chunk): mega_vectors = self.get_weighted_embeddings_batch(chunk) return chunk, mega_vectors # 使用线程池并发 with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks} count = 0 for future in concurrent.futures.as_completed(future_to_chunk): chunk, mega_vectors = future.result() if mega_vectors: for i, mega_vec in enumerate(mega_vectors): if mega_vec is not None: self.index.add(mega_vec.reshape(1, -1)) self.metadata.append({ 'id': chunk[i]['id'], 'text_preview': self.clean_text(chunk[i].get('stem', ''))[:100] }) count += len(chunk) print(f"✅ 已同步进度: {count}/{total_new}") # 每 500 条保存一次 if count % 500 == 0: self.save_index() self.save_index() print(f"🎉 同步完成!当前索引总数: {len(self.metadata)}") except Exception as e: print(f"❌ 同步失败: {e}") import traceback traceback.print_exc() finally: if 'conn' in locals() and conn: conn.close() if __name__ == "__main__": # 测试代码 checker = QuestionDuplicateChecker() # checker.sync_all_from_db() # 首次运行时同步 # result = checker.check_duplicate(10) # print(json.dumps(result, ensure_ascii=False, indent=2))