
Initial commit: Math question duplicate checker with GPT-4o fallback

Hai Lin · 3 weeks ago
Current commit
b916f15a7f
7 files changed, 894 insertions(+), 0 deletions(-)
  1. .gitignore (+17, -0)
  2. API_DOCUMENT.md (+108, -0)
  3. app.py (+86, -0)
  4. check_duplicate_trigger.py (+42, -0)
  5. config.py (+29, -0)
  6. duplicate_checker.py (+563, -0)
  7. test_similarity.py (+49, -0)

+ 17 - 0
.gitignore

@@ -0,0 +1,17 @@
+# FAISS index files
+*.index
+*.pkl
+
+# Python
+__pycache__/
+*.pyc
+*.pyo
+*.pyd
+.Python
+env/
+venv/
+.env
+
+# IDEs
+.vscode/
+.idea/

+ 108 - 0
API_DOCUMENT.md

@@ -0,0 +1,108 @@
+# Math Question Duplicate-Check System: API Reference
+
+This document defines the API endpoints through which the question-entry system integrates with the duplicate-check system.
+
+## Basics
+- **Base URL**: `http://<server-ip>:8888`
+- **Content type**: `application/json`
+
+---
+
+## 1. Question Duplicate Pre-Check (content-based)
+
+While entering a question, the question-entry system calls this endpoint to check whether the input duplicates a question already in the library.
+
+- **Endpoint**: `/api/check_duplicate`
+- **Method**: `POST`
+- **Request parameters**:
+
+| Parameter | Type | Required | Description |
+| :--- | :--- | :--- | :--- |
+| stem | String | Yes | Question stem |
+| options | String/Object | No | Options (JSON string or object) |
+| answer | String | No | Answer |
+| solution | String | No | Worked solution |
+
+- **Response example (duplicate found)**:
+```json
+{
+  "code": -1,
+  "result": {
+    "repeatIdList": [
+      {
+        "questionsId": 12345,
+        "repeatMsg": "相似度: 0.9256。相似点: 某某数学题干..."
+      }
+    ]
+  }
+}
+```
+
+- **Response example (duplicate confirmed by the GPT-4o fallback)**:
+```json
+{
+  "code": -1,
+  "result": {
+    "repeatIdList": [
+      {
+        "questionsId": 12345,
+        "repeatMsg": "相似度: 0.7582 (经 GPT-4o 深度核验)。相似点: 某某数学题干..."
+      }
+    ]
+  }
+}
+```
+
+- **Response example (no duplicate)**:
+```json
+{
+  "code": 0,
+  "result": "ok"
+}
+```
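+
+- **Client sketch**: a minimal example of how the question-entry system might call this endpoint. This is a sketch only; the base URL and payload values are illustrative, and only `stem` is required:
+
+```python
+import requests
+
+BASE_URL = "http://localhost:8888"  # replace with the actual server address
+
+payload = {
+    "stem": "Solve 2x + 3 = 7.",  # hypothetical question content
+    "options": {"A": "1", "B": "2", "C": "3", "D": "4"},
+    "answer": "B",
+}
+resp = requests.post(f"{BASE_URL}/api/check_duplicate", json=payload)
+body = resp.json()
+if body["code"] == -1 and isinstance(body.get("result"), dict):
+    # Duplicates found: report each suspected match
+    for item in body["result"]["repeatIdList"]:
+        print(item["questionsId"], item["repeatMsg"])
+else:
+    print("No duplicate reported:", body)
+```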
+
+---
+
+## 2. Manually Confirm the Duplicate-Check Result
+
+After the question-entry system finishes saving a question and assigns it an ID, it calls this endpoint to report the question's final status. Based on that status, the duplicate-check system decides whether to add the question to the vector index.
+
+- **Endpoint**: `/api/confirm_repeat`
+- **Method**: `POST`
+- **Request parameters**:
+
+| Parameter | Type | Required | Description |
+| :--- | :--- | :--- | :--- |
+| questionId | Integer | Yes | Question ID in the database |
+| isRepeat | Integer | Yes | 0: confirmed not a duplicate (will be indexed); 1: confirmed duplicate (will not be indexed) |
+
+- **Response example**:
+```json
+{
+  "code": 0,
+  "result": "ok"
+}
+```
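+
+- **Client sketch**: a hypothetical follow-up to the pre-check above, reporting the reviewer's decision once the question has been saved and assigned an ID (the ID value is illustrative):
+
+```python
+import requests
+
+BASE_URL = "http://localhost:8888"  # replace with the actual server address
+
+# isRepeat=0: confirmed not a duplicate, so the checker adds it to its vector index
+resp = requests.post(
+    f"{BASE_URL}/api/confirm_repeat",
+    json={"questionId": 67890, "isRepeat": 0},
+)
+assert resp.json()["code"] == 0
+```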
+
+---
+
+## 3. Manually Trigger a Full Sync
+
+Synchronizes the existing questions in the `questions_tem` database table into the vector index.
+
+- **Endpoint**: `/api/sync`
+- **Method**: `POST`
+
+- **Response example**:
+```json
+{
+  "code": 0,
+  "result": "Sync completed"
+}
+```
+
+---
+
+## Status Codes
+- `code: 0`: the operation succeeded or the question passed the duplicate check.
+- `code: -1`: the operation failed, a parameter was invalid, or a duplicate question was found.

+ 86 - 0
app.py

@@ -0,0 +1,86 @@
+from flask import Flask, request, jsonify
+from duplicate_checker import QuestionDuplicateChecker
+
+app = Flask(__name__)
+# Initialize the checker once (module-level singleton, so the FAISS index is loaded only once)
+checker = QuestionDuplicateChecker()
+
+@app.route('/api/check_duplicate', methods=['POST'])
+def check_duplicate():
+    """
+    题目查重 API 接口 (提交前预检模式)
+    参数: stem, options, answer, solution
+    """
+    data = request.get_json()
+    if not data:
+        return jsonify({"code": -1, "message": "Missing content"}), 400
+
+    # Extract the content fields
+    question_data = {
+        "stem": data.get('stem', ''),
+        "options": data.get('options', ''),
+        "answer": data.get('answer', ''),
+        "solution": data.get('solution', '')
+    }
+
+    if not question_data["stem"]:
+        return jsonify({"code": -1, "message": "stem is required"}), 400
+
+    # Run the content-based duplicate check
+    result = checker.check_duplicate_by_content(question_data)
+
+    if result.get("status") == "error":
+        return jsonify({"code": -1, "message": result.get("message")}), 500
+
+    if result.get("is_duplicate"):
+        item = result["top_similar"][0]
+        gpt_info = " (经 GPT-4o 深度核验)" if result.get("gpt_checked") else ""
+        return jsonify({
+            "code": -1,
+            "result": {
+                "repeatIdList": [{
+                    "questionsId": item["id"],
+                    "repeatMsg": f"相似度: {item['similarity']}{gpt_info}。相似点: {item['similar_point']}"
+                }]
+            }
+        })
+    else:
+        return jsonify({"code": 0, "result": "ok"})
+
+@app.route('/api/sync', methods=['POST'])
+def sync_index():
+    """手动触发全量同步接口"""
+    try:
+        checker.sync_all_from_db()
+        return jsonify({"code": 0, "result": "Sync completed"})
+    except Exception as e:
+        return jsonify({"code": -1, "message": str(e)}), 500
+
+@app.route('/api/confirm_repeat', methods=['POST'])
+def confirm_repeat():
+    """
+    人工确认查重结果接口
+    参数: questionId, isRepeat (0: 无相似, 1: 有重复)
+    """
+    data = request.get_json()
+    if not data:
+        return jsonify({"code": -1, "message": "Missing JSON body"}), 400
+    
+    question_id = data.get('questionId')
+    is_repeat = data.get('isRepeat')
+
+    if question_id is None or is_repeat is None:
+        return jsonify({"code": -1, "message": "Missing questionId or isRepeat"}), 400
+
+    try:
+        success = checker.confirm_repeat(int(question_id), int(is_repeat))
+        if success:
+            return jsonify({"code": 0, "result": "ok"})
+        else:
+            return jsonify({"code": -1, "message": "Failed to update"}), 500
+    except Exception as e:
+        return jsonify({"code": -1, "message": str(e)}), 500
+
+if __name__ == '__main__':
+    # Start the service on port 8888
+    app.run(host='0.0.0.0', port=8888, debug=False)

+ 42 - 0
check_duplicate_trigger.py

@@ -0,0 +1,42 @@
+import sys
+import json
+from duplicate_checker import QuestionDuplicateChecker
+
+def main():
+    if len(sys.argv) < 2:
+        print("用法: python check_duplicate_trigger.py <question_id>")
+        return
+
+    try:
+        question_id = int(sys.argv[1])
+    except ValueError:
+        print("错误: question_id 必须是整数")
+        return
+
+    # Initialize the checker
+    checker = QuestionDuplicateChecker()
+
+    # Run the duplicate check
+    print(f"Running duplicate analysis for question ID {question_id}...")
+    result = checker.check_duplicate(question_id)
+
+    # Format the output to match the expected result structure
+    if result.get("status") == "success":
+        print("\n" + "="*30)
+        print(result["message"])
+        print("="*30)
+    elif result.get("status") == "warning":
+        print("\n" + "!"*30)
+        print(result["message"])
+        print("\n相似题目 Top 3:")
+        for item in result["top_similar"]:
+            print(f"- 题目 ID: {item['id']}")
+            print(f"  相似度: {item['similarity']}")
+            print(f"  相似点预览: {item.get('similar_point', '无')}")
+            print("-" * 20)
+        print("!"*30)
+    else:
+        print(f"\n❌ 出错: {result.get('message')}")
+
+if __name__ == "__main__":
+    main()

+ 29 - 0
config.py

@@ -0,0 +1,29 @@
+"""
+配置文件
+"""
+import os
+from dotenv import load_dotenv
+
+load_dotenv()
+
+# OpenAI API settings
+# NOTE: the key default is a placeholder; supply the real key via the environment (.env), never in source control.
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
+OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
+OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-5.1")
+
+# Knowledge-point matching settings
+CONFIDENCE_THRESHOLD = float(os.getenv("CONFIDENCE_THRESHOLD", "0.7"))  # confidence threshold; below it, generating reference knowledge points is allowed
+
+# Vector store settings
+# Path to the vector index files (a local path or a network share),
+# e.g. the local "knowledge_points.index" or the UNC path "\\\\server\\share\\knowledge_points.index"
+FAISS_INDEX_PATH = os.getenv("FAISS_INDEX_PATH", "knowledge_points.index")
+FAISS_METADATA_PATH = os.getenv("FAISS_METADATA_PATH", "knowledge_points_metadata.pkl")
+
+# Database settings
+# NOTE: the password default is a placeholder; set DB_PASSWORD in the environment.
+DB_HOST = os.getenv("DB_HOST", "rm-f8ze60yirdj8786u2wo.mysql.rds.aliyuncs.com")
+DB_PORT = int(os.getenv("DB_PORT", "3306"))
+DB_DATABASE = os.getenv("DB_DATABASE", "math-conten-online2")
+DB_USERNAME = os.getenv("DB_USERNAME", "root")
+DB_PASSWORD = os.getenv("DB_PASSWORD", "")
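+
+# Example .env for local development (illustrative values only; real credentials
+# must be supplied via the environment and never committed to source control):
+#   OPENAI_API_KEY=sk-...
+#   OPENAI_BASE_URL=https://api.openai.com/v1
+#   DB_HOST=localhost
+#   DB_PORT=3306
+#   DB_DATABASE=math-conten-online2
+#   DB_USERNAME=root
+#   DB_PASSWORD=...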
+

+ 563 - 0
duplicate_checker.py

@@ -0,0 +1,563 @@
+"""
+数学题目查重助手 - 使用向量相似度比对题目是否重复
+"""
+import json
+import os
+import pickle
+import re
+import numpy as np
+import faiss
+import pymysql
+import concurrent.futures
+from openai import OpenAI
+from typing import List, Dict, Tuple, Optional
+from config import DB_HOST, DB_PORT, DB_DATABASE, DB_USERNAME, DB_PASSWORD, OPENAI_API_KEY, OPENAI_BASE_URL
+
+class QuestionDuplicateChecker:
+    """题目查重器"""
+    
+    def __init__(self, index_path="questions_tem.index", metadata_path="questions_tem_metadata.pkl"):
+        """
+        初始化查重器
+        
+        Args:
+            index_path: FAISS索引保存路径
+            metadata_path: 元数据保存路径
+        """
+        self.client = OpenAI(
+            api_key=OPENAI_API_KEY,
+            base_url=OPENAI_BASE_URL
+        )
+        self.index_path = index_path
+        self.metadata_path = metadata_path
+        self.index = None
+        self.metadata = []  # stores question IDs and text previews for reporting
+        
+        # Per-part weights (sum to 1.0)
+        self.weights = {
+            'stem': 0.8,
+            'options': 0.05,
+            'answer': 0.05,
+            'solution': 0.1
+        }
+        # Dimension: four concatenated parts, 3072 * 4 = 12288 in total
+        self.dimension = 3072 * 4
+        
+        self._load_index()
+
+    def clean_text(self, text: str) -> str:
+        """清理文本:移除图片标签、归一化空白字符"""
+        if not text:
+            return ""
+        # 移除 <image .../> 标签
+        text = re.sub(r'<image\s+[^>]*/>', '', text)
+        # 归一化空白:将多个空格、换行替换为单个空格
+        text = re.sub(r'\s+', ' ', text)
+        return text.strip()
+
+    def normalize_options(self, options) -> str:
+        """标准化选项内容"""
+        if not options:
+            return ""
+        
+        data = None
+        if isinstance(options, str):
+            # If it is a string, try to parse it as JSON
+            text = options.strip()
+            if (text.startswith('{') and text.endswith('}')) or (text.startswith('[') and text.endswith(']')):
+                try:
+                    data = json.loads(text)
+                except json.JSONDecodeError:
+                    pass
+            if data is None:
+                return self.clean_text(options)
+        else:
+            data = options
+
+        if isinstance(data, dict):
+            # Sort keys and clean the dict values
+            sorted_data = {str(k).strip(): self.clean_text(str(v)) for k, v in sorted(data.items())}
+            return self.clean_text(json.dumps(sorted_data, ensure_ascii=False))
+        elif isinstance(data, list):
+            # Clean the list items
+            cleaned_data = [self.clean_text(str(v)) for v in data]
+            return self.clean_text(json.dumps(cleaned_data, ensure_ascii=False))
+        
+        return self.clean_text(str(data))
+
+    def _load_index(self):
+        """加载或初始化索引"""
+        if os.path.exists(self.index_path) and os.path.exists(self.metadata_path):
+            try:
+                self.index = faiss.read_index(self.index_path)
+                # Check that the index dimension matches; reinitialize if not (e.g. after a weighting change)
+                if self.index.d != self.dimension:
+                    print(f"⚠️ Index dimension mismatch ({self.index.d} != {self.dimension}); reinitializing...")
+                    self._init_new_index()
+                else:
+                    with open(self.metadata_path, 'rb') as f:
+                        self.metadata = pickle.load(f)
+                    print(f"✓ 已加载现有索引,包含 {len(self.metadata)} 道题目")
+            except Exception as e:
+                print(f"⚠️ 加载索引失败: {e},将初始化新索引")
+                self._init_new_index()
+        else:
+            self._init_new_index()
+
+    def _init_new_index(self):
+        """初始化新的FAISS索引"""
+        # 使用内积索引,配合权重拼接向量
+        self.index = faiss.IndexFlatIP(self.dimension)
+        self.metadata = []
+        print("✓ 已初始化新的FAISS索引")
+
+    def save_index(self):
+        """保存索引和元数据"""
+        faiss.write_index(self.index, self.index_path)
+        with open(self.metadata_path, 'wb') as f:
+            pickle.dump(self.metadata, f)
+        print(f"✓ 索引和元数据已保存到 {self.index_path}")
+
+    def get_weighted_embedding(self, question: Dict) -> np.ndarray:
+        """
+        获取加权拼接向量
+        计算 4 个部分的 embedding,并乘以各自权重的平方根后拼接
+        """
+        parts = {
+            'stem': self.clean_text(question.get('stem', '') or ''),
+            'options': self.normalize_options(question.get('options', '') or ''),
+            'answer': self.clean_text(question.get('answer', '') or ''),
+            'solution': self.clean_text(question.get('solution', '') or '')
+        }
+        
+        embeddings = []
+        for key, text in parts.items():
+            # Embed this single part; "无" ("none") is the placeholder for empty text
+            emb = self.get_embedding(text if text.strip() else "无")
+            if emb is None:
+                return None
+            
+            # L2-normalize this part's vector
+            norm = np.linalg.norm(emb)
+            if norm > 0:
+                emb = emb / norm
+                
+            # Scale by the square root of the weight (see the note after this method)
+            weight_factor = np.sqrt(self.weights[key])
+            embeddings.append(emb * weight_factor)
+            
+        # Concatenate into a single long vector
+        mega_vector = np.concatenate(embeddings).astype('float32')
+        return mega_vector
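+
+    # Note on the sqrt-weight trick: for L2-normalized part vectors a_k, b_k
+    # and weights w_k that sum to 1,
+    #     dot(concat([sqrt(w_k) * a_k]), concat([sqrt(w_k) * b_k])) == sum(w_k * dot(a_k, b_k))
+    # so an IndexFlatIP score over these concatenated vectors is exactly the
+    # weighted average of the per-part cosine similarities.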
+
+    def get_embedding(self, text: str) -> np.ndarray:
+        """获取文本的embedding向量"""
+        try:
+            # Clean the text consistently; fall back to the "无" placeholder for empty input
+            cleaned_text = self.clean_text(text)
+            if not cleaned_text.strip():
+                cleaned_text = "无"
+                
+            response = self.client.embeddings.create(
+                model="text-embedding-3-large",
+                input=cleaned_text
+            )
+            embedding = np.array(response.data[0].embedding).astype('float32')
+            return embedding
+        except Exception as e:
+            print(f"❌ 获取embedding失败: {e}")
+            return None
+
+    def get_embeddings_batch(self, texts: List[str]) -> List[np.ndarray]:
+        """批量获取向量化结果"""
+        try:
+            if not texts:
+                return []
+            
+            # Clean each text; substitute the "无" placeholder for empty strings
+            cleaned_texts = []
+            for t in texts:
+                c = self.clean_text(t)
+                cleaned_texts.append(c if c.strip() else "无")
+                
+            response = self.client.embeddings.create(
+                model="text-embedding-3-large",
+                input=cleaned_texts
+            )
+            return [np.array(data.embedding).astype('float32') for data in response.data]
+        except Exception as e:
+            print(f"❌ 批量获取 embedding 失败: {e}")
+            return []
+
+    def get_weighted_embeddings_batch(self, questions: List[Dict]) -> List[np.ndarray]:
+        """
+        批量获取加权拼接向量
+        """
+        all_texts = []
+        for q in questions:
+            # Extract the four parts, with placeholders for empty values
+            stem = self.clean_text(q.get('stem', '') or '')
+            opts = self.normalize_options(q.get('options', '') or '')
+            ans = self.clean_text(q.get('answer', '') or '')
+            sol = self.clean_text(q.get('solution', '') or '')
+            
+            # The OpenAI API rejects empty strings, so substitute the "无" placeholder
+            all_texts.append(stem if stem.strip() else "无")
+            all_texts.append(opts if opts.strip() else "无")
+            all_texts.append(ans if ans.strip() else "无")
+            all_texts.append(sol if sol.strip() else "无")
+        
+        # One batched request (4 parts per question)
+        all_embeddings = self.get_embeddings_batch(all_texts)
+        if not all_embeddings:
+            return []
+            
+        results = []
+        for i in range(len(questions)):
+            q_embs = all_embeddings[i*4 : (i+1)*4]
+            if len(q_embs) < 4:
+                results.append(None)
+                continue
+                
+            weighted_parts = []
+            keys = ['stem', 'options', 'answer', 'solution']
+            for j, emb in enumerate(q_embs):
+                # L2-normalize
+                norm = np.linalg.norm(emb)
+                if norm > 0:
+                    emb = emb / norm
+                # Apply the sqrt weight
+                weight_factor = np.sqrt(self.weights[keys[j]])
+                weighted_parts.append(emb * weight_factor)
+            
+            mega_vector = np.concatenate(weighted_parts).astype('float32')
+            results.append(mega_vector)
+            
+        return results
+
+    def combine_content(self, q: Dict) -> str:
+        """组合题目内容用于预览显示"""
+        stem = q.get('stem', '') or ''
+        return self.clean_text(stem) # 预览以清理后的题干为主
+
+    def fetch_question_from_db(self, question_id: int) -> Optional[Dict]:
+        """从数据库获取题目信息"""
+        try:
+            conn = pymysql.connect(
+                host=DB_HOST,
+                port=DB_PORT,
+                user=DB_USERNAME,
+                password=DB_PASSWORD,
+                database=DB_DATABASE,
+                charset='utf8mb4',
+                cursorclass=pymysql.cursors.DictCursor
+            )
+            with conn.cursor() as cursor:
+                sql = "SELECT id, stem, options, answer, solution, is_repeat FROM questions_tem WHERE id = %s"
+                cursor.execute(sql, (question_id,))
+                result = cursor.fetchone()
+                return result
+        except Exception as e:
+            print(f"❌ 数据库查询失败: {e}")
+            return None
+        finally:
+            if 'conn' in locals() and conn:
+                conn.close()
+
+    def confirm_repeat(self, question_id: int, is_repeat: int) -> bool:
+        """
+        人工确认结果后更新数据库和向量库
+        is_repeat: 0 代表无相似题,1 代表有重复
+        """
+        # 1. Update the status in the database
+        self.update_is_repeat(question_id, is_repeat)
+
+        # 2. If marked as not a duplicate, add the question to the vector index
+        if is_repeat == 0:
+            question = self.fetch_question_from_db(question_id)
+            if not question:
+                return False
+            mega_vec = self.get_weighted_embedding(question)
+            if mega_vec is None:
+                return False
+            # Reshape to the single-row matrix that FAISS expects
+            mega_vec = mega_vec.reshape(1, -1)
+            self.add_to_index(question_id, mega_vec, self.clean_text(question.get('stem', ''))[:100])
+        return True
+
+    def update_is_repeat(self, question_id: int, is_repeat: int):
+        """更新数据库中的 is_repeat 字段"""
+        try:
+            conn = pymysql.connect(
+                host=DB_HOST,
+                port=DB_PORT,
+                user=DB_USERNAME,
+                password=DB_PASSWORD,
+                database=DB_DATABASE,
+                charset='utf8mb4'
+            )
+            with conn.cursor() as cursor:
+                sql = "UPDATE questions_tem SET is_repeat = %s WHERE id = %s"
+                cursor.execute(sql, (is_repeat, question_id))
+                conn.commit()
+                print(f"✓ 题目 ID {question_id} 的 is_repeat 已更新为 {is_repeat}")
+        except Exception as e:
+            print(f"❌ 更新 is_repeat 失败: {e}")
+        finally:
+            if 'conn' in locals() and conn:
+                conn.close()
+
+    def call_openai_for_duplicate_check(self, q1: Dict, q2: Dict) -> bool:
+        """
+        调用 GPT-4o 进行深度对比,判断两道题是否为重复题
+        """
+        try:
+            prompt = f"""你是一个专业的数学题目查重专家。请对比以下两道题目,判断它们是否为“重复题”。
+“重复题”的定义:题目背景、考查知识点、核心逻辑以及数值完全一致,或者仅有极其微小的表述差异但不影响题目本质。
+
+题目 1:
+题干: {q1.get('stem', '')}
+选项: {q1.get('options', '')}
+答案: {q1.get('answer', '')}
+解析: {q1.get('solution', '')}
+
+题目 2:
+题干: {q2.get('stem', '')}
+选项: {q2.get('options', '')}
+答案: {q2.get('answer', '')}
+解析: {q2.get('solution', '')}
+
+请分析两道题的相似性,并给出最终结论。
+结论必须以 JSON 格式输出,包含以下字段:
+- is_duplicate: 布尔值,True 表示是重复题,False 表示不是
+- reason: 简短的分析理由
+
+只输出 JSON 即可,不要有其他解释说明。"""
+
+            response = self.client.chat.completions.create(
+                model="gpt-4o",
+                messages=[
+                    {"role": "system", "content": "你是一个专业的数学题目查重助手。"},
+                    {"role": "user", "content": prompt}
+                ],
+                response_format={"type": "json_object"}
+            )
+            
+            result = json.loads(response.choices[0].message.content)
+            is_dup = result.get('is_duplicate', False)
+            reason = result.get('reason', '')
+            print(f"🤖 GPT-4o 判定结果: {'重复' if is_dup else '不重复'} | 理由: {reason}")
+            return is_dup
+        except Exception as e:
+            print(f"❌ GPT-4o 兜底计算失败: {e}")
+            return False
+
+    def check_duplicate(self, question_id: int, threshold: float = 0.85) -> Dict:
+        """
+        查重主逻辑 (未入库模式:不更新数据库,不自动入库)
+        """
+        # 1. Fetch the question
+        question = self.fetch_question_from_db(question_id)
+        if not question:
+            return {"status": "error", "message": f"No question found with ID {question_id}"}
+
+        # 2. Build the weighted concatenated vector
+        mega_vector = self.get_weighted_embedding(question)
+        if mega_vector is None:
+            return {"status": "error", "message": "Failed to build the question vector"}
+
+        # 3. Search the index
+        if self.index.ntotal == 0:
+            return {
+                "status": "success",
+                "message": "The vector index is empty",
+                "is_duplicate": False,
+                "top_similar": []
+            }
+
+        # Find the most similar questions
+        k = min(3, self.index.ntotal)
+        query_vec = mega_vector.reshape(1, -1)
+        similarities, indices = self.index.search(query_vec, k)
+
+        best_similar = None
+        is_duplicate = False
+        gpt_checked = False
+        
+        for i in range(k):
+            idx = indices[0][i]
+            score = float(similarities[0][i])
+            if idx != -1:
+                similar_info = self.metadata[idx]
+                similar_id = similar_info['id']
+                
+                if similar_id == question_id:
+                    continue
+                
+                best_similar = {
+                    "id": similar_id,
+                    "similarity": round(score, 4),
+                    "similar_point": similar_info.get('text_preview', '')
+                }
+                
+                if score >= threshold:
+                    is_duplicate = True
+                elif score > 0.5:
+                    # Similarity between 0.5 and the threshold: fall back to a GPT-4o check
+                    similar_question = self.fetch_question_from_db(similar_id)
+                    if similar_question:
+                        is_duplicate = self.call_openai_for_duplicate_check(question, similar_question)
+                        gpt_checked = True
+                break
+
+        if is_duplicate:
+            return {
+                "status": "warning",
+                "message": "该题目与已存在题目相似,请人工核验" + (" (经 GPT-4o 兜底确认)" if gpt_checked else ""),
+                "is_duplicate": True,
+                "gpt_checked": gpt_checked,
+                "top_similar": [best_similar]
+            }
+        else:
+            return {
+                "status": "success",
+                "message": "该题目通过查重",
+                "is_duplicate": False,
+                "top_similar": [best_similar] if best_similar else []
+            }
+
+    def check_duplicate_by_content(self, question_data: Dict, threshold: float = 0.85) -> Dict:
+        """
+        基于原始文本内容进行查重 (预检模式)
+        """
+        # 1. Build the weighted concatenated vector
+        mega_vector = self.get_weighted_embedding(question_data)
+        if mega_vector is None:
+            return {"status": "error", "message": "Failed to build the question vector"}
+
+        # 2. Search
+        if self.index.ntotal == 0:
+            return {"status": "success", "is_duplicate": False, "top_similar": []}
+
+        k = 1  # the pre-check only needs the single most similar match
+        query_vec = mega_vector.reshape(1, -1)
+        similarities, indices = self.index.search(query_vec, k)
+
+        if indices[0][0] != -1:
+            score = float(similarities[0][0])
+            similar_info = self.metadata[indices[0][0]]
+            similar_id = similar_info['id']
+            best_similar = {
+                "id": similar_id,
+                "similarity": round(score, 4),
+                "similar_point": similar_info.get('text_preview', '')
+            }
+            
+            is_duplicate = False
+            gpt_checked = False
+            
+            if score >= threshold:
+                is_duplicate = True
+            elif score > 0.5:
+                # Similarity between 0.5 and the threshold: fall back to a GPT-4o check
+                similar_question = self.fetch_question_from_db(similar_id)
+                if similar_question:
+                    is_duplicate = self.call_openai_for_duplicate_check(question_data, similar_question)
+                    gpt_checked = True
+            
+            if is_duplicate:
+                return {
+                    "status": "warning", 
+                    "is_duplicate": True, 
+                    "gpt_checked": gpt_checked,
+                    "top_similar": [best_similar]
+                }
+            return {"status": "success", "is_duplicate": False, "top_similar": [best_similar]}
+        
+        return {"status": "success", "is_duplicate": False, "top_similar": []}
+
+    def add_to_index(self, question_id: int, vector: np.ndarray, text: str):
+        """将题目加入索引"""
+        if any(m['id'] == question_id for m in self.metadata):
+            return
+            
+        self.index.add(vector)
+        self.metadata.append({
+            'id': question_id,
+            'text_preview': self.clean_text(text)[:100]
+        })
+        self.save_index()
+
+    def sync_all_from_db(self, batch_size=50, max_workers=5):
+        """同步数据库中所有题目到索引 (支持加权模式 + 批量 + 多线程)"""
+        print("正在进行全量同步 (优化版 - 加权模式)...")
+        existing_ids = {m['id'] for m in self.metadata}
+        try:
+            conn = pymysql.connect(
+                host=DB_HOST,
+                port=DB_PORT,
+                user=DB_USERNAME,
+                password=DB_PASSWORD,
+                database=DB_DATABASE,
+                charset='utf8mb4',
+                cursorclass=pymysql.cursors.DictCursor
+            )
+            with conn.cursor() as cursor:
+                sql = "SELECT id, stem, options, answer, solution FROM questions_tem"
+                cursor.execute(sql)
+                all_questions = cursor.fetchall()
+            
+            new_questions = [q for q in all_questions if q['id'] not in existing_ids]
+            total_new = len(new_questions)
+            
+            if total_new == 0:
+                print("✅ 已经是最新状态,无需同步。")
+                return
+
+            print(f"📊 数据库总计: {len(all_questions)}, 需同步新增: {total_new}")
+
+            # Split into chunks
+            chunks = [new_questions[i:i + batch_size] for i in range(0, total_new, batch_size)]
+            
+            def process_chunk(chunk):
+                mega_vectors = self.get_weighted_embeddings_batch(chunk)
+                return chunk, mega_vectors
+
+            # Embed chunks concurrently with a thread pool
+            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+                future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
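+                # Note: only the embedding API calls run in the worker threads;
+                # index.add() and metadata updates stay on this main thread, since
+                # concurrent mutation of a FAISS index is not guaranteed thread-safe.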
+                
+                count = 0
+                for future in concurrent.futures.as_completed(future_to_chunk):
+                    chunk, mega_vectors = future.result()
+                    if mega_vectors:
+                        for i, mega_vec in enumerate(mega_vectors):
+                            if mega_vec is not None:
+                                self.index.add(mega_vec.reshape(1, -1))
+                                self.metadata.append({
+                                    'id': chunk[i]['id'],
+                                    'text_preview': self.clean_text(chunk[i].get('stem', ''))[:100]
+                                })
+                        
+                        count += len(chunk)
+                        print(f"✅ 已同步进度: {count}/{total_new}")
+                        
+                        # Persist the index every 500 questions
+                        if count % 500 == 0:
+                            self.save_index()
+            
+            self.save_index()
+            print(f"🎉 同步完成!当前索引总数: {len(self.metadata)}")
+        except Exception as e:
+            print(f"❌ 同步失败: {e}")
+            import traceback
+            traceback.print_exc()
+        finally:
+            if 'conn' in locals() and conn:
+                conn.close()
+
+if __name__ == "__main__":
+    # Quick manual test
+    checker = QuestionDuplicateChecker()
+    # checker.sync_all_from_db()  # run once on first use to populate the index
+    # result = checker.check_duplicate(10)
+    # print(json.dumps(result, ensure_ascii=False, indent=2))

+ 49 - 0
test_similarity.py

@@ -0,0 +1,49 @@
+import json
+import numpy as np
+from duplicate_checker import QuestionDuplicateChecker
+
+checker = QuestionDuplicateChecker()
+
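+# Sample user input: a Chinese geometry question expected to closely match
+# DB question 3094 (the text is kept in Chinese to match the stored data)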
+user_input = {
+    'stem': '如图,两个同心圆中,大圆的半径为$5$,小圆的半径为$3$,若大圆的弦$AB$与小圆有公共点,则$AB$的取值范围是(  )',
+    'options': '{"A": "$8\\le AB\\le10$", "B": "$8<AB\\le10$", "C": "$4\\le AB\\le5$", "D": "$4<AB\\le5$"}',
+    'answer': 'A',
+    'solution': '如图,过$O$点作$OC\\perp AB$于$C$,连接$OA$,则$AC=BC$。当$AB$与小圆相切时,大圆的弦$AB$与小圆有唯一公共点,$OC$取最大值$3$,此时$AC$取最小值,为$\\sqrt{5^2-3^2}=4$,∴弦$AB$的最小值为$2\\times4=8$;当点$C$与$O$重合时,$AB$的值最大,$AB$为大圆的直径,即$AB$的最大值为$10$,∴$AB$的取值范围是$8\\le AB\\le10$。<image src="https://file.chunsunqiuzhu.com/data/2026/01/17/20260117180759A644.png"/>'
+}
+
+# Fetch ID 3094
+q3094 = checker.fetch_question_from_db(3094)
+
+if not q3094:
+    print("Question 3094 not found")
+    raise SystemExit(1)
+
+print(f"Comparing user input with ID 3094\n")
+
+# Compute per-part embeddings to see where the similarity drops
+parts = ['stem', 'options', 'answer', 'solution']
+for p in parts:
+    u_text = user_input.get(p, '')
+    db_text = str(q3094.get(p, '') or '')
+    
+    u_emb = checker.get_embedding(u_text)
+    db_emb = checker.get_embedding(db_text)
+    
+    if u_emb is not None and db_emb is not None:
+        # Normalize
+        u_emb = u_emb / np.linalg.norm(u_emb)
+        db_emb = db_emb / np.linalg.norm(db_emb)
+        
+        sim = np.dot(u_emb, db_emb)
+        print(f"Part {p} similarity: {sim:.4f}")
+    else:
+        print(f"Part {p} failed to get embedding")
+
+mega_u = checker.get_weighted_embedding(user_input)
+mega_db = checker.get_weighted_embedding(q3094)
+if mega_u is not None and mega_db is not None:
+    overall_sim = np.dot(mega_u, mega_db)
+    print(f"\nOverall weighted similarity: {overall_sim:.4f}")
+else:
+    print("\nMega vector calculation failed")