""" 数学题目查重助手 - 使用向量相似度比对题目是否重复 """ import json import os import pickle import re import threading import time import tempfile import numpy as np import faiss import pymysql import concurrent.futures from openai import OpenAI from typing import List, Dict, Tuple, Optional from config import DB_HOST, DB_PORT, DB_DATABASE, DB_USERNAME, DB_PASSWORD, OPENAI_API_KEY, OPENAI_BASE_URL class QuestionDuplicateChecker: """题目查重器""" def __init__(self, index_path="questions_tem.index", metadata_path="questions_tem_metadata.pkl"): """ 初始化查重器 Args: index_path: FAISS索引保存路径 metadata_path: 元数据保存路径 """ self.client = OpenAI( api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL ) base_dir = os.path.dirname(os.path.abspath(__file__)) # 统一使用绝对路径,避免不同工作目录导致索引文件混乱 self.index_path = index_path if os.path.isabs(index_path) else os.path.join(base_dir, index_path) self.metadata_path = metadata_path if os.path.isabs(metadata_path) else os.path.join(base_dir, metadata_path) self.index = None self.metadata = [] # 存储题目ID和文本,以便回显 self.sync_lock = threading.Lock() self.reload_lock = threading.Lock() self.index_mtime = None self.metadata_mtime = None self.sync_in_progress = False self.reload_failures = 0 self.next_reload_time = 0.0 # 权重设置 self.weights = { 'stem': 0.8, 'options': 0.05, 'answer': 0.05, 'solution': 0.1 } # 维度调整:4部分拼接后的总维度为 3072 * 4 = 12288 self.dimension = 3072 * 4 # 限制 FAISS/OpenMP 线程,避免小机器 CPU 过载 faiss_threads = max(1, min(2, os.cpu_count() or 1)) try: faiss.omp_set_num_threads(faiss_threads) except Exception: pass self._load_index() def clean_text(self, text: str) -> str: """清理文本:移除图片标签、归一化空白字符""" if not text: return "" # 移除 标签 text = re.sub(r']*/>', '', text) # 归一化空白:将多个空格、换行替换为单个空格 text = re.sub(r'\s+', ' ', text) return text.strip() def normalize_options(self, options) -> str: """标准化选项内容""" if not options: return "" data = None if isinstance(options, str): # 如果是字符串,尝试解析 JSON text = options.strip() if (text.startswith('{') and text.endswith('}')) or (text.startswith('[') and text.endswith(']')): try: data = json.loads(text) except: pass if data is None: return self.clean_text(options) else: data = options if isinstance(data, dict): # 排序并清理字典内容 sorted_data = {str(k).strip(): self.clean_text(str(v)) for k, v in sorted(data.items())} return self.clean_text(json.dumps(sorted_data, ensure_ascii=False)) elif isinstance(data, list): # 清理列表内容 cleaned_data = [self.clean_text(str(v)) for v in data] return self.clean_text(json.dumps(cleaned_data, ensure_ascii=False)) return self.clean_text(str(data)) def _load_index(self): """加载或初始化索引""" if os.path.exists(self.index_path) and os.path.exists(self.metadata_path): try: self.index = faiss.read_index(self.index_path) # 检查索引维度是否匹配,如果不匹配则重新初始化(针对权重升级) if self.index.d != self.dimension: print(f"⚠️ 索引维度不匹配 ({self.index.d} != {self.dimension}),正在重新初始化...") self._init_new_index() else: with open(self.metadata_path, 'rb') as f: self.metadata = pickle.load(f) print(f"✓ 已加载现有索引,包含 {len(self.metadata)} 道题目") self._update_index_mtime() except Exception as e: print(f"⚠️ 加载索引失败: {e},将初始化新索引") self._init_new_index() else: self._init_new_index() def _init_new_index(self): """初始化新的FAISS索引""" # 使用内积索引,配合权重拼接向量 self.index = faiss.IndexFlatIP(self.dimension) self.metadata = [] self.index_mtime = None self.metadata_mtime = None self.reload_failures = 0 self.next_reload_time = 0.0 print("✓ 已初始化新的FAISS索引") def _update_index_mtime(self): """更新索引文件的最后修改时间记录""" self.index_mtime = os.path.getmtime(self.index_path) if os.path.exists(self.index_path) else None self.metadata_mtime = os.path.getmtime(self.metadata_path) if os.path.exists(self.metadata_path) else None def ensure_index_loaded(self, force: bool = False) -> bool: """确保索引已从磁盘加载(多进程下避免空索引常驻)""" # 读取失败后做退避,避免频繁重载导致磁盘抖动 now = time.time() if not force and now < self.next_reload_time: return False index_exists = os.path.exists(self.index_path) and os.path.exists(self.metadata_path) if not index_exists: return False need_reload = force if not need_reload: # 空索引或元数据为空时,优先尝试从磁盘加载 if self.index is None or self.index.ntotal == 0 or not self.metadata: need_reload = True else: # 文件有更新则重载 current_index_mtime = os.path.getmtime(self.index_path) current_metadata_mtime = os.path.getmtime(self.metadata_path) if (self.index_mtime is None or self.metadata_mtime is None or current_index_mtime > self.index_mtime or current_metadata_mtime > self.metadata_mtime): need_reload = True if not need_reload: return False with self.reload_lock: # 同步进行中时,避免在索引尚未写完时反复重载 if self.sync_in_progress and not force: return False # 双重检查,避免重复加载 if not force and self.index is not None and self.index.ntotal > 0 and self.metadata: current_index_mtime = os.path.getmtime(self.index_path) current_metadata_mtime = os.path.getmtime(self.metadata_path) if (self.index_mtime is not None and self.metadata_mtime is not None and current_index_mtime <= self.index_mtime and current_metadata_mtime <= self.metadata_mtime): return False try: self.index = faiss.read_index(self.index_path) with open(self.metadata_path, 'rb') as f: self.metadata = pickle.load(f) self._update_index_mtime() self.reload_failures = 0 self.next_reload_time = 0.0 print(f"↻ 已重新加载索引,包含 {len(self.metadata)} 道题目") return True except Exception as e: print(f"⚠️ 重新加载索引失败: {e}") # 指数退避,最多 60 秒 self.reload_failures += 1 backoff = min(60.0, 2 ** self.reload_failures) self.next_reload_time = time.time() + backoff return False def save_index(self): """保存索引和元数据""" # 先写入临时文件,再原子替换,避免读写冲突导致索引损坏 index_dir = os.path.dirname(self.index_path) meta_dir = os.path.dirname(self.metadata_path) os.makedirs(index_dir, exist_ok=True) os.makedirs(meta_dir, exist_ok=True) tmp_index_fd, tmp_index_path = tempfile.mkstemp(prefix=".questions_tem.index.", dir=index_dir) tmp_meta_fd, tmp_meta_path = tempfile.mkstemp(prefix=".questions_tem_metadata.pkl.", dir=meta_dir) try: os.close(tmp_index_fd) os.close(tmp_meta_fd) faiss.write_index(self.index, tmp_index_path) with open(tmp_meta_path, 'wb') as f: pickle.dump(self.metadata, f) f.flush() os.fsync(f.fileno()) os.replace(tmp_index_path, self.index_path) os.replace(tmp_meta_path, self.metadata_path) self._update_index_mtime() print(f"✓ 索引和元数据已保存到 {self.index_path}") finally: # 清理临时文件(若 replace 成功则不存在) if os.path.exists(tmp_index_path): os.remove(tmp_index_path) if os.path.exists(tmp_meta_path): os.remove(tmp_meta_path) def get_weighted_embedding(self, question: Dict) -> np.ndarray: """ 获取加权拼接向量 计算 4 个部分的 embedding,并乘以各自权重的平方根后拼接 """ parts = { 'stem': self.clean_text(question.get('stem', '') or ''), 'options': self.normalize_options(question.get('options', '') or ''), 'answer': self.clean_text(question.get('answer', '') or ''), 'solution': self.clean_text(question.get('solution', '') or '') } embeddings = [] for key, text in parts.items(): # 获取单部分 embedding emb = self.get_embedding(text if text.strip() else "无") # 防止空文本 if emb is None: return None # 归一化该部分的向量 norm = np.linalg.norm(emb) if norm > 0: emb = emb / norm # 乘以权重平方根:这样点积结果就是权重加权和 weight_factor = np.sqrt(self.weights[key]) embeddings.append(emb * weight_factor) # 拼接成一个大向量 mega_vector = np.concatenate(embeddings).astype('float32') return mega_vector def get_embedding(self, text: str) -> np.ndarray: """获取文本的embedding向量""" try: # 统一清理文本 cleaned_text = self.clean_text(text) if not cleaned_text.strip(): cleaned_text = "无" response = self.client.embeddings.create( model="text-embedding-3-large", input=cleaned_text ) embedding = np.array(response.data[0].embedding).astype('float32') return embedding except Exception as e: print(f"❌ 获取embedding失败: {e}") return None def get_embeddings_batch(self, texts: List[str]) -> List[np.ndarray]: """批量获取向量化结果""" try: if not texts: return [] # 统一清理文本 cleaned_texts = [] for t in texts: c = self.clean_text(t) cleaned_texts.append(c if c.strip() else "无") response = self.client.embeddings.create( model="text-embedding-3-large", input=cleaned_texts ) return [np.array(data.embedding).astype('float32') for data in response.data] except Exception as e: print(f"❌ 批量获取 embedding 失败: {e}") return [] def get_weighted_embeddings_batch(self, questions: List[Dict]) -> List[np.ndarray]: """ 批量获取加权拼接向量 """ all_texts = [] for q in questions: # 获取 4 个部分并处理空值占位 stem = self.clean_text(q.get('stem', '') or '') opts = self.normalize_options(q.get('options', '') or '') ans = self.clean_text(q.get('answer', '') or '') sol = self.clean_text(q.get('solution', '') or '') # OpenAI API 不接受空字符串,强制填充 "无" all_texts.append(stem if stem.strip() else "无") all_texts.append(opts if opts.strip() else "无") all_texts.append(ans if ans.strip() else "无") all_texts.append(sol if sol.strip() else "无") # 批量请求 (每题 4 个部分) all_embeddings = self.get_embeddings_batch(all_texts) if not all_embeddings: return [] results = [] for i in range(len(questions)): q_embs = all_embeddings[i*4 : (i+1)*4] if len(q_embs) < 4: results.append(None) continue weighted_parts = [] keys = ['stem', 'options', 'answer', 'solution'] for j, emb in enumerate(q_embs): # 归一化 norm = np.linalg.norm(emb) if norm > 0: emb = emb / norm # 加权 weight_factor = np.sqrt(self.weights[keys[j]]) weighted_parts.append(emb * weight_factor) mega_vector = np.concatenate(weighted_parts).astype('float32') results.append(mega_vector) return results def combine_content(self, q: Dict) -> str: """组合题目内容用于预览显示""" stem = q.get('stem', '') or '' return self.clean_text(stem) # 预览以清理后的题干为主 def fetch_question_from_db(self, question_id: int) -> Optional[Dict]: """从数据库获取题目信息""" try: conn = pymysql.connect( host=DB_HOST, port=DB_PORT, user=DB_USERNAME, password=DB_PASSWORD, database=DB_DATABASE, charset='utf8mb4', cursorclass=pymysql.cursors.DictCursor ) with conn.cursor() as cursor: sql = "SELECT id, stem, options, answer, solution, is_repeat FROM questions_tem WHERE id = %s" cursor.execute(sql, (question_id,)) result = cursor.fetchone() return result except Exception as e: print(f"❌ 数据库查询失败: {e}") return None finally: if 'conn' in locals() and conn: conn.close() def confirm_repeat(self, question_id: int, is_repeat: int) -> bool: """ 人工确认结果后更新数据库和向量库 is_repeat: 0 代表无相似题,1 代表有重复 """ # 1. 更新数据库状态 self.update_is_repeat(question_id, is_repeat) # 2. 如果标记为非重复,将其加入向量库 if is_repeat == 0: question = self.fetch_question_from_db(question_id) if question: mega_vec = self.get_weighted_embedding(question) if mega_vec is not None: # 转换维度以便 FAISS 接收 mega_vec = mega_vec.reshape(1, -1) self.add_to_index(question_id, mega_vec, self.clean_text(question.get('stem', ''))[:100]) return True return True def update_is_repeat(self, question_id: int, is_repeat: int): """更新数据库中的 is_repeat 字段""" try: conn = pymysql.connect( host=DB_HOST, port=DB_PORT, user=DB_USERNAME, password=DB_PASSWORD, database=DB_DATABASE, charset='utf8mb4' ) with conn.cursor() as cursor: sql = "UPDATE questions_tem SET is_repeat = %s WHERE id = %s" cursor.execute(sql, (is_repeat, question_id)) conn.commit() print(f"✓ 题目 ID {question_id} 的 is_repeat 已更新为 {is_repeat}") except Exception as e: print(f"❌ 更新 is_repeat 失败: {e}") finally: if 'conn' in locals() and conn: conn.close() def call_openai_for_duplicate_check(self, q1: Dict, q2: Dict) -> bool: """ 调用 GPT-4o 进行深度对比,判断两道题是否为重复题 """ try: prompt = f"""你是一个专业的数学题目查重专家。请对比以下两道题目,判断它们是否为“重复题”。 “重复题”的定义:题目背景、考查知识点、核心逻辑以及数值完全一致,或者仅有极其微小的表述差异但不影响题目本质。 题目 1: 题干: {q1.get('stem', '')} 选项: {q1.get('options', '')} 答案: {q1.get('answer', '')} 解析: {q1.get('solution', '')} 题目 2: 题干: {q2.get('stem', '')} 选项: {q2.get('options', '')} 答案: {q2.get('answer', '')} 解析: {q2.get('solution', '')} 请分析两道题的相似性,并给出最终结论。 结论必须以 JSON 格式输出,包含以下字段: - is_duplicate: 布尔值,True 表示是重复题,False 表示不是 - reason: 简短的分析理由 只输出 JSON 即可,不要有其他解释说明。""" response = self.client.chat.completions.create( model="gpt-4o", messages=[ {"role": "system", "content": "你是一个专业的数学题目查重助手。"}, {"role": "user", "content": prompt} ], response_format={"type": "json_object"}, temperature=0 ) result = json.loads(response.choices[0].message.content) is_dup = result.get('is_duplicate', False) reason = result.get('reason', '') print(f"🤖 GPT-4o 判定结果: {'重复' if is_dup else '不重复'} | 理由: {reason}") return is_dup except Exception as e: print(f"❌ GPT-4o 兜底计算失败: {e}") return False def check_duplicate(self, question_id: int, threshold: float = 0.85) -> Dict: """ 查重主逻辑 (未入库模式:不更新数据库,不自动入库) """ self.ensure_index_loaded() # 1. 获取题目信息 question = self.fetch_question_from_db(question_id) if not question: return {"status": "error", "message": f"未找到ID为 {question_id} 的题目"} # 2. 获取加权拼接向量 mega_vector = self.get_weighted_embedding(question) if mega_vector is None: return {"status": "error", "message": "无法生成题目向量"} # 3. 搜索索引 if self.index.ntotal == 0: return { "status": "success", "message": "向量库为空", "is_duplicate": False, "top_similar": [] } # 搜索最相似的题目 k = min(3, self.index.ntotal) query_vec = mega_vector.reshape(1, -1) similarities, indices = self.index.search(query_vec, k) best_similar = None is_duplicate = False gpt_checked = False for i in range(k): idx = indices[0][i] score = float(similarities[0][i]) if idx != -1: similar_info = self.metadata[idx] similar_id = similar_info['id'] if similar_id == question_id: continue best_similar = { "id": similar_id, "similarity": round(score, 4), "similar_point": similar_info.get('text_preview', '') } if score >= threshold: is_duplicate = True elif score > 0.5: # 相似度在 0.5 ~ 0.85 之间,调用 GPT-4o 兜底 similar_question = self.fetch_question_from_db(similar_id) if similar_question: is_duplicate = self.call_openai_for_duplicate_check(question, similar_question) gpt_checked = True break if is_duplicate: return { "status": "warning", "message": "该题目与已存在题目相似,请人工核验" + (" (经 GPT-4o 兜底确认)" if gpt_checked else ""), "is_duplicate": True, "gpt_checked": gpt_checked, "top_similar": [best_similar] } else: return { "status": "success", "message": "该题目通过查重", "is_duplicate": False, "top_similar": [best_similar] if best_similar else [] } def check_duplicate_by_content(self, question_data: Dict, threshold: float = 0.85) -> Dict: """ 基于原始文本内容进行查重 (预检模式) """ self.ensure_index_loaded() # 1. 获取加权拼接向量 mega_vector = self.get_weighted_embedding(question_data) if mega_vector is None: return {"status": "error", "message": "无法生成题目向量"} # 2. 检索 if self.index.ntotal == 0: return {"status": "success", "is_duplicate": False, "top_similar": []} k = 1 # 预检通常只需返回最相似的一个 query_vec = mega_vector.reshape(1, -1) similarities, indices = self.index.search(query_vec, k) if indices[0][0] != -1: score = float(similarities[0][0]) similar_info = self.metadata[indices[0][0]] similar_id = similar_info['id'] best_similar = { "id": similar_id, "similarity": round(score, 4), "similar_point": similar_info.get('text_preview', '') } is_duplicate = False gpt_checked = False if score >= threshold: is_duplicate = True elif score > 0.5: # 相似度在 0.5 ~ 0.85 之间,调用 GPT-4o 兜底 similar_question = self.fetch_question_from_db(similar_id) if similar_question: is_duplicate = self.call_openai_for_duplicate_check(question_data, similar_question) gpt_checked = True if is_duplicate: return { "status": "warning", "is_duplicate": True, "gpt_checked": gpt_checked, "top_similar": [best_similar] } return {"status": "success", "is_duplicate": False, "top_similar": [best_similar]} return {"status": "success", "is_duplicate": False, "top_similar": []} def get_question_data(self, question_id: int) -> Dict: """获取向量库中特定ID的数据及总数""" self.ensure_index_loaded() total_count = self.index.ntotal if self.index else 0 target_metadata = None for m in self.metadata: if m['id'] == question_id: target_metadata = m break return { "total_count": total_count, "question_data": target_metadata } def add_to_index(self, question_id: int, vector: np.ndarray, text: str): """将题目加入索引""" if any(m['id'] == question_id for m in self.metadata): return self.index.add(vector) self.metadata.append({ 'id': question_id, 'text_preview': self.clean_text(text)[:100] }) self.save_index() def sync_all_from_db(self, batch_size=50, max_workers=None): """同步数据库中所有题目到索引 (支持加权模式 + 批量 + 多线程)""" if not self.sync_lock.acquire(blocking=False): print("⏳ 同步正在进行中,已忽略重复请求") return False self.sync_in_progress = True print("🔄 开始全量同步 (优化版 - 加权模式)...") existing_ids = {m['id'] for m in self.metadata} try: conn = pymysql.connect( host=DB_HOST, port=DB_PORT, user=DB_USERNAME, password=DB_PASSWORD, database=DB_DATABASE, charset='utf8mb4', cursorclass=pymysql.cursors.DictCursor ) with conn.cursor() as cursor: print("📡 正在从数据库读取所有题目数据...") sql = "SELECT id, stem, options, answer, solution FROM questions_tem" cursor.execute(sql) all_questions = cursor.fetchall() print(f"📦 数据库加载完成,共计 {len(all_questions)} 条记录") new_questions = [q for q in all_questions if q['id'] not in existing_ids] total_new = len(new_questions) if total_new == 0: print("✅ 已经是最新状态,无需同步。") return True print(f"📊 数据库总计: {len(all_questions)}, 需同步新增: {total_new}") # 分块处理 chunks = [new_questions[i:i + batch_size] for i in range(0, total_new, batch_size)] def process_chunk(chunk): mega_vectors = self.get_weighted_embeddings_batch(chunk) return chunk, mega_vectors # 使用线程池并发(小机器限制并发,避免 CPU 过载) worker_count = max_workers or max(1, min(2, os.cpu_count() or 1)) with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as executor: future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks} count = 0 for future in concurrent.futures.as_completed(future_to_chunk): chunk, mega_vectors = future.result() if mega_vectors: for i, mega_vec in enumerate(mega_vectors): if mega_vec is not None: self.index.add(mega_vec.reshape(1, -1)) self.metadata.append({ 'id': chunk[i]['id'], 'text_preview': self.clean_text(chunk[i].get('stem', ''))[:100] }) count += len(chunk) print(f"✅ 已同步进度: {count}/{total_new}") # 每 500 条保存一次 if count % 500 == 0: self.save_index() self.save_index() print(f"🎉 同步完成!当前索引总数: {len(self.metadata)}") return True except Exception as e: print(f"❌ 同步失败: {e}") import traceback traceback.print_exc() return False finally: if 'conn' in locals() and conn: conn.close() self.sync_in_progress = False if self.sync_lock.locked(): self.sync_lock.release() if __name__ == "__main__": # 测试代码 checker = QuestionDuplicateChecker() # checker.sync_all_from_db() # 首次运行时同步 # result = checker.check_duplicate(10) # print(json.dumps(result, ensure_ascii=False, indent=2))